Commit 5d90769e authored by guangzlu

bwd pass now

parent 644a45e9
@@ -43,7 +43,7 @@ Kernel outputs:
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
-#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v1.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v3.hpp"
 #include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp"
 #include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_forward_xdl_cshuffle.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
@@ -72,13 +72,14 @@ using Scale = ck::tensor_operation::element_wise::Scale;
 using QKVElementOp = PassThrough;
 using YElementOp = PassThrough;
+using DataType = F16;
 using InputDataType = F16;
 using OutputDataType = F16;
 using GemmDataType = F16;
 using AccDataType = F32;
 using ShuffleDataType = F32;
 using LSEDataType = F32;
-using ZDataType = INT32; // INT32
+using ZDataType = U16; // INT32
 using Acc0BiasDataType = ck::Tuple<>;
 using Acc1BiasDataType = ck::Tuple<>;
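
Note: assuming U16 is a 16-bit unsigned alias and that the Z tensor holds one random value per attention-score element (as in the CK attention examples), narrowing ZDataType from a 32-bit to a 16-bit type halves the dropout tensor's global-memory footprint. A back-of-the-envelope sketch with purely illustrative shapes (not values from this commit):

    #include <cstddef>
    #include <cstdint>
    // Illustrative shapes only: batch G0, heads G1, sequence lengths M and N.
    constexpr std::size_t G0 = 4, G1 = 16, M = 1024, N = 1024;
    constexpr std::size_t z_bytes_u16 = G0 * G1 * M * N * sizeof(std::uint16_t); // 128 MiB
    constexpr std::size_t z_bytes_i32 = G0 * G1 * M * N * sizeof(std::int32_t);  // 256 MiB
    static_assert(z_bytes_i32 == 2 * z_bytes_u16, "16-bit Z halves the dropout-tensor traffic");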
@@ -89,7 +90,7 @@ static constexpr ck::index_t NumDimK = 1;
 static constexpr ck::index_t NumDimO = 1;
 // When OutputDataType == F32, bwd CShuffleBlockTransferScalarPerVector_NPerBlock = 4
 // When OutputDataType == F16/BF16, bwd CShuffleBlockTransferScalarPerVector_NPerBlock = 8
-static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 4;
+// static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 4;
 static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
 #if USING_MASK
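
Note: the guidance in the two comments above amounts to keeping the CShuffle output store at 16 bytes (128 bits) per thread regardless of the output element type; a minimal check of that arithmetic, using std::uint16_t as a stand-in for the 2-byte F16/BF16 types:

    #include <cstdint>
    static_assert(4 * sizeof(float)         == 16, "F32: 4 elements per vector store = 16 bytes");
    static_assert(8 * sizeof(std::uint16_t) == 16, "F16/BF16: 8 elements per vector store = 16 bytes");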
@@ -189,8 +190,7 @@ using DeviceGemmInstanceBWD =
         NumDimN,
         NumDimK,
         NumDimO,
-        InputDataType,
-        OutputDataType,
+        DataType,
         GemmDataType,
         ZDataType,
         LSEDataType,
@@ -248,9 +248,8 @@ using DeviceGemmInstanceBWD =
         1,              // CShuffleMXdlPerWavePerShuffle
         1,              // CShuffleNXdlPerWavePerShuffle
         S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
-        CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
-        MaskingSpec,    // MaskingSpecialization
-        Deterministic>;
+        8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
+        MaskingSpec>;   // MaskingSpecialization
 #elif(DIM <= 64)
 using DeviceGemmInstanceFWD =
     ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
@@ -330,8 +329,7 @@ using DeviceGemmInstanceBWD =
         NumDimN,
         NumDimK,
         NumDimO,
-        InputDataType,
-        OutputDataType,
+        DataType,
         GemmDataType,
         ZDataType,
         LSEDataType,
@@ -352,7 +350,7 @@ using DeviceGemmInstanceBWD =
         1,
         256,
         128, // MPerBlock
-        128, // NPerBlock
+        64,  // NPerBlock
         64,  // KPerBlock
         64,  // Gemm1NPerBlock
         32,  // Gemm1KPerBlock
@@ -362,9 +360,9 @@ using DeviceGemmInstanceBWD =
         32, // MPerXDL
         32, // NPerXDL
         1,  // MXdlPerWave
-        4,  // NXdlPerWave
+        2,  // NXdlPerWave
         2,  // Gemm1NXdlPerWave
-        2,  // Gemm2NXdlPerWave
+        1,  // Gemm2NXdlPerWave
         S<4, 64, 1>, // ABlockTransfer
         S<1, 0, 2>,
         S<1, 0, 2>,
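
Note: the smaller NPerBlock in the previous hunk and the halved NXdlPerWave here stay consistent with the 256-thread workgroup, assuming the usual CK xdl decomposition MWaves = MPerBlock / (MXdlPerWave * MPerXDL), NWaves = NPerBlock / (NXdlPerWave * NPerXDL) and a wave size of 64; a sketch of that arithmetic:

    constexpr int BlockSize = 256;
    constexpr int WaveSize  = 64;
    constexpr int MWaves    = 128 / (1 * 32); // MPerBlock / (MXdlPerWave * MPerXDL) = 4
    constexpr int NWaves    =  64 / (2 * 32); // NPerBlock / (NXdlPerWave * NPerXDL) = 1
    static_assert(MWaves * NWaves * WaveSize == BlockSize, "tile parameters still fill one workgroup");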
@@ -387,11 +385,10 @@ using DeviceGemmInstanceBWD =
         2,
         false,
         1,              // CShuffleMXdlPerWavePerShuffle
-        2,              // CShuffleNXdlPerWavePerShuffle
+        1,              // CShuffleNXdlPerWavePerShuffle
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
-        CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
-        MaskingSpec,    // MaskingSpecialization
-        Deterministic>;
+        8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
+        MaskingSpec>;   // MaskingSpecialization
 // using DeviceGemmInstanceBWD =
 //     ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
@@ -733,7 +730,7 @@ int run(int argc, char* argv[])
     bool input_permute  = false;
     bool output_permute = false;
-    float p_drop = 0.2;
+    float p_drop = 0.0;
     const unsigned long long seed   = 1;
     const unsigned long long offset = 0;
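
Note: with p_drop set to 0.0 the example now runs the backward pass with dropout effectively disabled. For reference, the usual inverted-dropout bookkeeping that a non-zero p_drop implies (illustration only, not code from this commit):

    float p_drop        = 0.0f;          // as set above
    float p_keep        = 1.0f - p_drop; // probability an element of P survives
    float rescale_value = 1.0f / p_keep; // surviving elements are scaled by this; equals 1.0 here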
@@ -1574,8 +1574,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
             InMemoryDataOperationEnum::Set,
             1, // DstScalarStrideInVector
             true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-                  make_multi_index(block_work_idx[I0], // MBlockId
-                                   0,                  // NBlockId
+                  make_multi_index(0,                  // MBlockId
+                                   block_work_idx[I0], // NBlockId
                                    0,                  // mrepeat
                                    0,                  // nrepeat
                                    wave_id[I0],        // MWaveId
@@ -1995,7 +1995,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
                     decltype(i)>(s_slash_p_thread_buf,
                                  ph,
                                  global_elem_id +
-                                     i.value * id_step,
+                                     id_step * i.value,
                                  z_tenor_buffer);
                 z_thread_copy_vgpr_to_global.Run(
@@ -2054,6 +2054,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
             auto global_elem_id =
                 MRaw * NRaw * g_idx + m_global * NRaw + n_global; // unique element global 1d id
+            // index_t id_step = Acc0TileIterator::GetNumOfAccess() / n0.value;
+
             ignore = z_grid_buf;
             // P_dropped
             blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf), true>(
@@ -2241,7 +2243,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
                 qgrad_grid_desc_m0_o0_m1_o1_m2_o2_o3_o4, Gemm1::c_block_slice_copy_step); // step M
             z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
                 z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-                make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
+                make_multi_index(1, 0, 0, 0, 0, 0, 0, 0, 0, 0));
             lse_thread_copy_global_to_vgpr.MoveSrcSliceWindow(
                 lse_grid_desc_mblock_mrepeat_mwave_mperxdl, make_multi_index(1, 0, 0, 0));
             y_threadwise_copy.MoveSrcSliceWindow(y_grid_desc_mblock_mperblock_oblock_operblock,
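
Note: read together with the MBlockId/NBlockId swap in the earlier z-copy hunk, the destination window of z_thread_copy_vgpr_to_global now appears to start at the workgroup's N block and advance one step along M0 per outer iteration, rather than the other way around. A rough sketch of that stepping over the ten-dimensional (M0, N0, M1, N1, M2, N2, M3, N3, N4, N5) descriptor; variable names are illustrative, not the kernel's:

    int n_block_id  = 0;                                       // stand-in for block_work_idx[I0]
    int z_start[10] = {0, n_block_id, 0, 0, 0, 0, 0, 0, 0, 0}; // MBlockId = 0, NBlockId fixed per workgroup
    int z_step[10]  = {1, 0, 0, 0, 0, 0, 0, 0, 0, 0};          // MoveDstSliceWindow: +1 along M0 each outer loop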