Commit 763e26be authored by letaoqin

kernel add all code, need debug

parent de53e421
@@ -199,6 +199,7 @@ using ReferenceDropoutInstance =
 template <typename TensorQ,
           typename TensorK,
           typename TensorV,
+          typename TensorD,
           typename TensorS,
           typename TensorP,
           typename TensorZ,
@@ -207,6 +208,7 @@ template <typename TensorQ,
 void run_attention_fwd_host(const TensorQ& q_g_m_k,
                             const TensorK& k_g_n_k,
                             const TensorV& v_g_n_o,
+                            const TensorD& d_g_m_n,
                             const float alpha,
                             TensorS& s_g_m_n,
                             TensorP& p_g_m_n,
@@ -226,6 +228,9 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
     ref_gemm0_invoker.Run(ref_gemm0_argument);
+    // bias
+    s_g_m_n.ForEach(
+        [&](auto& self, auto idx) { self(idx) += ck::type_convert<AccDataType>(d_g_m_n(idx)); });
     // masking
     auto M = s_g_m_n.GetLengths()[1];
     auto N = s_g_m_n.GetLengths()[2];
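For context, the hunk above makes the host-side reference add the bias tensor elementwise into the attention scores before masking and softmax. A minimal standalone sketch of that arithmetic, using plain nested loops and hypothetical G/M/N extents instead of CK's Tensor/ForEach machinery:

#include <cstddef>
#include <vector>

// Sketch only: s holds the alpha * Q*K^T scores and d the bias, both with
// logical shape [G, M, N] flattened row-major; float stands in for AccDataType.
void add_bias_to_scores(std::vector<float>& s,
                        const std::vector<float>& d,
                        std::size_t G, std::size_t M, std::size_t N)
{
    for(std::size_t g = 0; g < G; ++g)
        for(std::size_t m = 0; m < M; ++m)
            for(std::size_t n = 0; n < N; ++n)
            {
                const std::size_t idx = (g * M + m) * N + n;
                s[idx] += d[idx]; // same elementwise update as s_g_m_n.ForEach above
            }
}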
@@ -261,7 +266,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
 int run(int argc, char* argv[])
 {
     bool do_verification = true;
-    int init_method      = 2; // method 1 will have slightly higher error; TODO: to investigate
+    int init_method      = 1; // method 1 will have slightly higher error; TODO: to investigate
     bool time_kernel     = true;
     // Overall QKV matrices shape
@@ -409,11 +414,13 @@ int run(int argc, char* argv[])
     {
     case 0: break;
     case 1:
+        // q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
         q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
         k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
         v_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
         ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
         d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<Acc0BiasDataType>{-2, 2});
+        // d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{0});
         break;
     case 2:
         q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<InputDataType>{0.0, 1.0});
@@ -509,6 +516,7 @@ int run(int argc, char* argv[])
     q_device_buf.ToDevice(q_gs_ms_ks.mData.data());
     k_device_buf.ToDevice(k_gs_ns_ks.mData.data());
+    d_device_buf.ToDevice(d_gs_ms_ns.mData.data());
     z_device_buf.ToDevice(z_gs_ms_ns.mData.data());
     v_device_buf.ToDevice(v_gs_os_ns.mData.data());
    ygrad_device_buf.ToDevice(ygrad_gs_ms_os.mData.data());
@@ -611,7 +619,7 @@ int run(int argc, char* argv[])
         (sizeof(InputDataType) * M * K + sizeof(InputDataType) * K * N +
          sizeof(InputDataType) * N * O + sizeof(InputDataType) * M * O * size_t(2) +
          sizeof(OutputDataType) * M * K + sizeof(OutputDataType) * K * N +
-         sizeof(OutputDataType) * N * O) *
+         sizeof(OutputDataType) * N * O + sizeof(Acc0BiasDataType) * M * N) *
             BatchCount +
         sizeof(LSEDataType) * M * BatchCount;
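The bandwidth estimate above gains one term, sizeof(Acc0BiasDataType) * M * N, for reading the bias tile once per batch. A rough worked example of the same expression with placeholder shapes and element sizes (assumed 2-byte input/output/bias types and a 4-byte LSE type, not necessarily the example's defaults):

#include <cstddef>

// Sketch only: mirrors the num_btype expression in the hunk above.
std::size_t estimate_bytes_moved()
{
    const std::size_t M = 512, N = 512, K = 64, O = 64, BatchCount = 96;
    const std::size_t in = 2, out = 2, bias = 2, lse = 4; // assumed element sizes
    return (in * M * K + in * K * N +     // presumably Q and K reads
            in * N * O + in * M * O * 2 + // V read plus Y and dY reads
            out * M * K + out * K * N +   // dQ and dK writes
            out * N * O + bias * M * N) * // dV write plus the new bias read
               BatchCount +
           lse * M * BatchCount;          // per-row log-sum-exp values
}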
@@ -635,6 +643,7 @@ int run(int argc, char* argv[])
         run_attention_fwd_host(q_g_m_k,
                                k_g_n_k,
                                v_g_n_o,
+                               d_g_m_n,
                                alpha,
                                s_g_m_n,
                                p_g_m_n,
...
@@ -1181,7 +1181,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
     struct D0Loader
     {
-        __host__ __device__ static constexpr auto GetD0BlockDescriptor_M0_N0_M1_M2_N1_M()
+        __host__ __device__ static constexpr auto GetD0BlockDescriptor_M0_N0_M1_M2_N1_M3()
         {
             // B1 matrix in LDS memory, dst of blockwise copy
             return make_naive_tensor_descriptor(
@@ -1193,8 +1193,28 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                     D0M3,
                     I1));
         }
+        __host__ __device__ static constexpr auto GetD0BlockReadDescriptor_N0_N1_M0_M1_M2_M3()
+        {
+            constexpr auto d0_raw_m0_n_m1 =
+                make_naive_tensor_descriptor(make_tuple(D0M2, Number<NPerBlock>{}, D0M3),
+                                             make_tuple(Number<NPerBlock>{} * D0M3, D0M3, I1));
+            constexpr auto d0_n0_n1_m0_m1_m2_m3 = transform_tensor_descriptor(
+                d0_raw_m0_n_m1,
+                make_tuple(make_unmerge_transform(make_tuple((D0M2 * D0M3) / I8, I2, I4 / D0M3)),
+                           make_unmerge_transform(
+                               make_tuple(Number<NPerBlock / NPerXdl>{}, Number<NPerXdl>{})),
+                           make_pass_through_transform(D0M3)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
+                make_tuple(Sequence<2, 3, 4>{}, Sequence<0, 1>{}, Sequence<5>{}));
+            return d0_n0_n1_m0_m1_m2_m3;
+        }
         static constexpr auto d0_block_desc_m0_n0_m1_m2_n1_m3 =
-            GetD0BlockDescriptor_M0_N0_M1_M2_N1_M();
+            GetD0BlockDescriptor_M0_N0_M1_M2_N1_M3();
+        static constexpr auto d0_block_desc_n0_n1_m0_m1_m2_m3 =
+            GetD0BlockReadDescriptor_N0_N1_M0_M1_M2_M3();
+        static constexpr auto d0_thread_desc_ =
+            make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I4, I1, I4 / D0M3, D0M3));

         using D0BlockwiseCopy = ThreadGroupTensorSliceTransfer_v4r1<
             ThisThreadBlock,
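The new GetD0BlockReadDescriptor_N0_N1_M0_M1_M2_M3 reinterprets the same LDS bias tile so that, presumably, each thread can fetch its fragment in the (n0, n1, m0, m1, m2, m3) order the xdlops layout expects. The building block is the unmerge transform, which splits one linear dimension into several with the last index varying fastest; a plain-C++ sketch of that index mapping, with placeholder lengths rather than the kernel's actual D0M2/D0M3/NPerBlock values:

#include <array>
#include <cstddef>

// Sketch only: split a flat coordinate into (i0, i1, i2) with lengths (_, L1, L2),
// i2 fastest-varying, and merge it back. This is the index arithmetic an
// unmerge/merge transform pair encodes.
std::array<std::size_t, 3> unmerge3(std::size_t flat, std::size_t L1, std::size_t L2)
{
    return {flat / (L1 * L2), (flat / L2) % L1, flat % L2};
}

std::size_t merge3(const std::array<std::size_t, 3>& idx, std::size_t L1, std::size_t L2)
{
    return (idx[0] * L1 + idx[1]) * L2 + idx[2];
}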
@@ -1219,6 +1239,17 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
             false,
             true, // DstResetCoord
             1>;
+        using D0ThreadCopy =
+            ThreadwiseTensorSliceTransfer_v4<D0DataType,                                // SrcData
+                                             D0DataType,                                // DstData
+                                             decltype(d0_block_desc_n0_n1_m0_m1_m2_m3), // SrcDesc
+                                             decltype(d0_thread_desc_),                 // DstDesc
+                                             Sequence<1, 1, 4, 1, 2, 2>,                // SliceLengths
+                                             Sequence<0, 1, 2, 3, 4, 5>,                // DimAccessOrder
+                                             5,                                         // SrcVectorDim
+                                             D0M3.value,                                // SrcScalarPerVector
+                                             D0M3.value>;
     };

 template <bool HasMainKBlockLoop,
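One thing worth flagging while this is still in the "need debug" state: the hard-coded SliceLengths Sequence<1, 1, 4, 1, 2, 2> spell out d0_thread_desc_'s shape {1, 1, 4, 1, 4 / D0M3, D0M3} only when D0M3 == 2 (16 scalars per thread, read as D0M3-wide vectors along the last dimension). A guard one could drop next to the alias, reusing the D0M3 constant from the surrounding struct:

// Sketch only: keep D0ThreadCopy's slice shape in sync with d0_thread_desc_.
static_assert(D0M3.value == 2, "D0ThreadCopy SliceLengths assume 4 / D0M3 == 2");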
@@ -1546,14 +1577,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
             qgrad_thread_origin_on_grid_m0_o0_m1_o1_m2_o2_o3_o4,
             scale_rp_dropout);
-        // D0
-        auto d0_block_copy_global_to_lds =
-            typename D0Loader::D0BlockwiseCopy(d0_grid_desc_m0_n0_m1_m2_n1_m3,
-                                               make_multi_index(0, block_work_idx_n, 0, 0, 0, 0),
-                                               tensor_operation::element_wise::PassThrough{},
-                                               D0Loader::d0_block_desc_m0_n0_m1_m2_n1_m3,
-                                               make_multi_index(0, 0, 0, 0, 0, 0),
-                                               tensor_operation::element_wise::PassThrough{});
         //
         // Blockwise softmax
         //
@@ -1703,7 +1726,27 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                 0,                //
                 wave_m_n_id[I1]), // NPerXdl
             tensor_operation::element_wise::PassThrough{}};
+        // if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 46)
+        // {
+        //     printf("get_thread_local_1d_id(): %d, wave_id[I0]: %d wave_id[I1]: %d "
+        //            "wave_m_n_id[I0]: %d wave_m_n_id[I1]: %d \n",
+        //            get_thread_local_1d_id(),
+        //            wave_id[I0],
+        //            wave_id[I1],
+        //            wave_m_n_id[I0],
+        //            wave_m_n_id[I1]);
+        // }
+        // D0
+        auto d0_block_copy_global_to_lds =
+            typename D0Loader::D0BlockwiseCopy(d0_grid_desc_m0_n0_m1_m2_n1_m3,
+                                               make_multi_index(0, block_work_idx_n, 0, 0, 0, 0),
+                                               tensor_operation::element_wise::PassThrough{},
+                                               D0Loader::d0_block_desc_m0_n0_m1_m2_n1_m3,
+                                               make_multi_index(0, 0, 0, 0, 0, 0),
+                                               tensor_operation::element_wise::PassThrough{});
+        auto d0_thread_copy_lds_to_vgpr = typename D0Loader::D0ThreadCopy(
+            make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0, 0));
+        ignore = d0_thread_copy_lds_to_vgpr;
         //
         // set up Y dot dY
         //
@@ -1940,6 +1983,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
             // add bias
             if constexpr(!is_same<D0DataType, void>::value)
             {
+                static constexpr auto c_thread_desc_ =
+                    make_naive_tensor_descriptor_packed(make_tuple(D0M1, Number<16>{}));
                 const auto d0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                     p_d0_grid, d0_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
@@ -1947,15 +1993,47 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                     static_cast<GemmDataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
                     D0Loader::d0_block_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
-                static_for<0, D0M1, 1>{}([&](auto) {
+                auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>(
+                    D0Loader::d0_thread_desc_.GetElementSpaceSize());
+                ignore = d0_thread_buf;
+                static_for<0, D0M1, 1>{}([&](auto mr) {
+                    // load data into LDS
                     d0_block_copy_global_to_lds.RunRead(d0_grid_desc_m0_n0_m1_m2_n1_m3,
                                                         d0_grid_buf);
                     d0_block_copy_global_to_lds.MoveSrcSliceWindow(
                         d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0));
                     d0_block_copy_global_to_lds.RunWrite(D0Loader::d0_block_desc_m0_n0_m1_m2_n1_m3,
                                                          d0_block_buf);
+                    block_sync_lds();
+                    // read data from LDS
+                    d0_thread_copy_lds_to_vgpr.Run(D0Loader::d0_block_desc_n0_n1_m0_m1_m2_m3,
+                                                   make_tuple(I0, I0, I0, I0, I0, I0),
+                                                   d0_block_buf,
+                                                   D0Loader::d0_thread_desc_,
+                                                   make_tuple(I0, I0, I0, I0, I0, I0),
+                                                   d0_thread_buf);
+                    // bias add
+                    static_for<0, D0Loader::d0_thread_desc_.GetElementSpaceSize(), 1>{}(
+                        [&](auto i) {
+                            constexpr index_t c_offset =
+                                c_thread_desc_.CalculateOffset(make_tuple(mr, i));
+                            s_slash_p_thread_buf(Number<c_offset>{}) +=
+                                ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]);
+                            // if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0)
+                            // {
+                            //     printf("c_offset: %d s_slash_p_thread_buf(Number<c_offset>{}): %f, "
+                            //            "d0_thread_buf[i]: %f\n",
+                            //            c_offset,
+                            //            s_slash_p_thread_buf(Number<c_offset>{}),
+                            //            ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]));
+                            // }
+                        });
                 });

                 d0_block_copy_global_to_lds.MoveSrcSliceWindow(
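Stripped of CK's descriptors, the loop above does the following for each outer repeat mr: stage one D0 tile in LDS, copy this thread's 16-element fragment into registers, then accumulate it into the matching attention-score registers, where c_thread_desc_ maps (mr, i) to the packed offset mr * 16 + i. A plain-C++ sketch of that per-thread accumulation, with assumed sizes rather than the kernel's real buffers:

// Sketch only: kRepeats stands in for D0M1 and float for FloatGemmAcc/D0DataType.
constexpr int kRepeats      = 4;  // placeholder value
constexpr int kTileElements = 16; // scalars per thread per repeat, as in c_thread_desc_

void add_bias_fragment(float (&scores)[kRepeats * kTileElements],
                       const float (&bias_fragment)[kTileElements],
                       int mr)
{
    for(int i = 0; i < kTileElements; ++i)
    {
        const int c_offset = mr * kTileElements + i; // CalculateOffset(make_tuple(mr, i))
        scores[c_offset] += bias_fragment[i];
    }
}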
@@ -2036,11 +2114,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
             // dV = P_drop^T * dY
             {
                 // TODO: explore using dynamic buffer for a1 thread buffer
-                // For a1_blockwise_copy, the goal is to satisfy pipeline requirements RunRead(),
-                // RunWrite(), and MoveSliceWindow(). But it is impossible to implement given that
-                // the A1 source buffer is static buffer holding the output of first GEMM and
-                // requires constexpr offset by design. Therefore, we pass tensor coordinate offset
-                // explicitly in Run() below.
+                // For a1_blockwise_copy, the goal is to satisfy pipeline requirements
+                // RunRead(), RunWrite(), and MoveSliceWindow(). But it is impossible to
+                // implement given that the A1 source buffer is a static buffer holding the
+                // output of the first GEMM and requires constexpr offset by design. Therefore,
+                // we pass the tensor coordinate offset explicitly in Run() below.

                 // preload data into LDS
                 vgrad_gemm_tile_ygrad_blockwise_copy.RunRead(ygrad_grid_desc_m0_o_m1,
@@ -2196,11 +2274,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
             // dK = scalar * dS^T * Q
             {
                 // TODO: explore using dynamic buffer for a1 thread buffer
-                // For a1_blockwise_copy, the goal is to satisfy pipeline requirements RunRead(),
-                // RunWrite(), and MoveSliceWindow(). But it is impossible to implement given that
-                // the A1 source buffer is static buffer holding the output of first GEMM and
-                // requires constexpr offset by design. Therefore, we pass tensor coordinate offset
-                // explicitly in Run() below.
+                // For a1_blockwise_copy, the goal is to satisfy pipeline requirements
+                // RunRead(), RunWrite(), and MoveSliceWindow(). But it is impossible to
+                // implement given that the A1 source buffer is a static buffer holding the
+                // output of the first GEMM and requires constexpr offset by design. Therefore,
+                // we pass the tensor coordinate offset explicitly in Run() below.

                 // preload data into LDS
                 kgrad_gemm_tile_q_blockwise_copy.RunRead(q_grid_desc_m0_k_m1, q_grid_buf);
@@ -2286,7 +2364,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                 z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
                     z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
                     make_multi_index(-1, 0, 0, 0, 0, 0, 0, 0, 0, 0));
-
             } while(0 < gemm0_m_block_outer_index--); // end j loop

             // shuffle dK&dV and write
...
@@ -111,6 +111,11 @@ struct StaticBufferTupleOfVector
         return base::operator()(i_v).template AsType<S>()(i_s);
     }
+    // non-const counterpart of the accessor above (two independent indices)
+    template <index_t Iv, index_t Is>
+    __host__ __device__ constexpr S& operator()(Number<Iv> i_v, Number<Is> i_s)
+    {
+        return base::operator()(i_v).template AsType<S>()(i_s);
+    }
     // Get X
     // i is offset of S, not X. i should be aligned to X
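The added overload is just the non-const counterpart of the (i_v, i_s) getter above, so callers can write through the same vector-index/scalar-index addressing. A toy standalone analogue of that const/non-const pair, using plain arrays instead of CK's vector_type machinery:

#include <cstddef>

// Sketch only: NumVec vectors of SPerV scalars each, with paired accessors.
template <typename S, std::size_t NumVec, std::size_t SPerV>
struct TinyStaticBuffer
{
    S data[NumVec][SPerV] = {};

    constexpr const S& operator()(std::size_t i_v, std::size_t i_s) const
    {
        return data[i_v][i_s];
    }
    constexpr S& operator()(std::size_t i_v, std::size_t i_s) // enables writes: buf(2, 1) += x;
    {
        return data[i_v][i_s];
    }
};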