Commit 5d90769e authored by guangzlu

bwd pass now

parent 644a45e9
@@ -43,7 +43,7 @@ Kernel outputs:
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v1.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_forward_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
@@ -72,13 +72,14 @@ using Scale = ck::tensor_operation::element_wise::Scale;
using QKVElementOp = PassThrough;
using YElementOp = PassThrough;
using DataType = F16;
using InputDataType = F16;
using OutputDataType = F16;
using GemmDataType = F16;
using AccDataType = F32;
using ShuffleDataType = F32;
using LSEDataType = F32;
using ZDataType = INT32; // INT32
using ZDataType = U16; // INT32
using Acc0BiasDataType = ck::Tuple<>;
using Acc1BiasDataType = ck::Tuple<>;
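Note (illustrative sketch, not part of this commit): ZDataType types the per-element dropout randoms stored alongside the attention scores, so its width directly sets that buffer's device footprint; moving from a 32-bit to a 16-bit type roughly halves it. A minimal size estimate, with hypothetical extents G (batch * heads), M and N that are not taken from the example's argument parsing:

// hypothetical extents G, M, N; one Z value per score element assumed
std::size_t z_bytes = static_cast<std::size_t>(G) * M * N * sizeof(ZDataType);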
@@ -89,7 +90,7 @@ static constexpr ck::index_t NumDimK = 1;
static constexpr ck::index_t NumDimO = 1;
// When OutputDataType == F32, bwd CShuffleBlockTransferScalarPerVector_NPerBlock = 4
// When OutputDataType == F16/BF16, bwd CShuffleBlockTransferScalarPerVector_NPerBlock = 8
static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 4;
// static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 4;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
#if USING_MASK
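Note (illustrative sketch, not part of this commit): the constant commented out above encodes the rule stated in the comment — 4 scalars per vector for an F32 output, 8 for F16/BF16, i.e. a fixed 16-byte vectorized store either way — and the new instances below hard-code 8, consistent with OutputDataType = F16. A compile-time selection along those lines could look like this (name suffixed _sketch to mark it as hypothetical):

#include <type_traits>
// 4 x 4-byte floats or 8 x 2-byte halves: 16 bytes per vectorized store either way
static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock_sketch =
    std::is_same_v<OutputDataType, F32> ? 4 : 8;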
@@ -189,8 +190,7 @@ using DeviceGemmInstanceBWD =
NumDimN,
NumDimK,
NumDimO,
InputDataType,
OutputDataType,
DataType,
GemmDataType,
ZDataType,
LSEDataType,
@@ -248,9 +248,8 @@ using DeviceGemmInstanceBWD =
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec, // MaskingSpecialization
Deterministic>;
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
#elif(DIM <= 64)
using DeviceGemmInstanceFWD =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
@@ -330,8 +329,7 @@ using DeviceGemmInstanceBWD =
NumDimN,
NumDimK,
NumDimO,
InputDataType,
OutputDataType,
DataType,
GemmDataType,
ZDataType,
LSEDataType,
@@ -352,7 +350,7 @@ using DeviceGemmInstanceBWD =
1,
256,
128, // MPerBlock
128, // NPerBlock
64, // NPerBlock
64, // KPerBlock
64, // Gemm1NPerBlock
32, // Gemm1KPerBlock
@@ -362,9 +360,9 @@ using DeviceGemmInstanceBWD =
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
2, // NXdlPerWave
2, // Gemm1NXdlPerWave
2, // Gemm2NXdlPerWave
1, // Gemm2NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
@@ -387,11 +385,10 @@ using DeviceGemmInstanceBWD =
2,
false,
1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec, // MaskingSpecialization
Deterministic>;
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
// using DeviceGemmInstanceBWD =
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
@@ -733,7 +730,7 @@ int run(int argc, char* argv[])
bool input_permute = false;
bool output_permute = false;
float p_drop = 0.2;
float p_drop = 0.0;
const unsigned long long seed = 1;
const unsigned long long offset = 0;
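Note (illustrative sketch, not part of this commit): with the usual inverted-dropout convention, each attention score is kept with probability (1 - p_drop) and survivors are rescaled, so setting p_drop = 0.0 makes the dropout stage an identity, which simplifies checking the backward pass. Assuming that convention:

float keep_prob = 1.f - p_drop;                             // 1.0 when p_drop == 0.0
float rescale   = keep_prob > 0.f ? 1.f / keep_prob : 0.f;  // survivors scaled by 1/keep_prob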
@@ -1574,8 +1574,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(block_work_idx[I0], // MBlockId
0, // NBlockId
make_multi_index(0, // MBlockId
block_work_idx[I0], // NBlockId
0, // mrepeat
0, // nrepeat
wave_id[I0], // MWaveId
@@ -1995,7 +1995,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
decltype(i)>(s_slash_p_thread_buf,
ph,
global_elem_id +
i.value * id_step,
id_step * i.value,
z_tenor_buffer);
z_thread_copy_vgpr_to_global.Run(
@@ -2054,6 +2054,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
auto global_elem_id =
MRaw * NRaw * g_idx + m_global * NRaw + n_global; // unique element global 1d id
// index_t id_step = Acc0TileIterator::GetNumOfAccess() / n0.value;
ignore = z_grid_buf;
// P_dropped
blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf), true>(
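Note (illustrative sketch, not part of this commit): global_elem_id above flattens the (g_idx, m_global, n_global) coordinate of a score element into the unique 1-D id that the blockwise dropout consumes alongside the Philox state ph. A worked round trip under the kernel's MRaw x NRaw layout:

// forward: unique 1-D id, as in the kernel
const ck::index_t elem_id = MRaw * NRaw * g_idx + m_global * NRaw + n_global;
// inverse (sketch only): recover the coordinate from elem_id
const ck::index_t g = elem_id / (MRaw * NRaw);
const ck::index_t m = (elem_id % (MRaw * NRaw)) / NRaw;
const ck::index_t n = elem_id % NRaw;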
@@ -2241,7 +2243,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
qgrad_grid_desc_m0_o0_m1_o1_m2_o2_o3_o4, Gemm1::c_block_slice_copy_step); // step M
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
make_multi_index(1, 0, 0, 0, 0, 0, 0, 0, 0, 0));
lse_thread_copy_global_to_vgpr.MoveSrcSliceWindow(
lse_grid_desc_mblock_mrepeat_mwave_mperxdl, make_multi_index(1, 0, 0, 0));
y_threadwise_copy.MoveSrcSliceWindow(y_grid_desc_mblock_mperblock_oblock_operblock,