Commit 763e26be authored by letaoqin

kernel: add all code, still needs debugging

parent de53e421
......@@ -199,6 +199,7 @@ using ReferenceDropoutInstance =
template <typename TensorQ,
typename TensorK,
typename TensorV,
typename TensorD,
typename TensorS,
typename TensorP,
typename TensorZ,
......@@ -207,6 +208,7 @@ template <typename TensorQ,
void run_attention_fwd_host(const TensorQ& q_g_m_k,
const TensorK& k_g_n_k,
const TensorV& v_g_n_o,
const TensorD& d_g_m_n,
const float alpha,
TensorS& s_g_m_n,
TensorP& p_g_m_n,
......@@ -226,6 +228,9 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
ref_gemm0_invoker.Run(ref_gemm0_argument);
// bias
s_g_m_n.ForEach(
[&](auto& self, auto idx) { self(idx) += ck::type_convert<AccDataType>(d_g_m_n(idx)); });
// masking
auto M = s_g_m_n.GetLengths()[1];
auto N = s_g_m_n.GetLengths()[2];
......@@ -261,7 +266,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
int run(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 2; // method 1 will have slightly higher error; TODO: to investigate
int init_method = 1; // method 1 will have slightly higher error; TODO: to investigate
bool time_kernel = true;
// Overall QKV matrices shape
......@@ -409,11 +414,13 @@ int run(int argc, char* argv[])
{
case 0: break;
case 1:
// q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<Acc0BiasDataType>{-2, 2});
// d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{0});
break;
case 2:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<InputDataType>{0.0, 1.0});
......@@ -509,6 +516,7 @@ int run(int argc, char* argv[])
q_device_buf.ToDevice(q_gs_ms_ks.mData.data());
k_device_buf.ToDevice(k_gs_ns_ks.mData.data());
d_device_buf.ToDevice(d_gs_ms_ns.mData.data());
z_device_buf.ToDevice(z_gs_ms_ns.mData.data());
v_device_buf.ToDevice(v_gs_os_ns.mData.data());
ygrad_device_buf.ToDevice(ygrad_gs_ms_os.mData.data());
......@@ -611,7 +619,7 @@ int run(int argc, char* argv[])
(sizeof(InputDataType) * M * K + sizeof(InputDataType) * K * N +
sizeof(InputDataType) * N * O + sizeof(InputDataType) * M * O * size_t(2) +
sizeof(OutputDataType) * M * K + sizeof(OutputDataType) * K * N +
sizeof(OutputDataType) * N * O) *
sizeof(OutputDataType) * N * O + sizeof(Acc0BiasDataType) * M * N) *
BatchCount +
sizeof(LSEDataType) * M * BatchCount;
......@@ -635,6 +643,7 @@ int run(int argc, char* argv[])
run_attention_fwd_host(q_g_m_k,
k_g_n_k,
v_g_n_o,
d_g_m_n,
alpha,
s_g_m_n,
p_g_m_n,
......
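For reference, the host-side change above wires an additive bias tensor D into the forward attention reference: the score matrix S = alpha * Q * K^T gets D added element-wise before masking and softmax, and the result feeds the usual P * V product. The standalone sketch below illustrates that data flow with plain std::vector matrices; the names (Q, Kmat, V, D) and the extents M, N, K, O are hypothetical stand-ins, and this is not the CK reference implementation (run_attention_fwd_host), only a simplified model of where the bias enters.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main()
{
    const int M = 4, N = 4, K = 8, O = 8; // hypothetical sequence/head extents
    const float alpha = 1.f / std::sqrt(float(K));

    std::vector<float> Q(M * K, 0.01f), Kmat(N * K, 0.02f), V(N * O, 0.03f);
    std::vector<float> D(M * N, 0.10f); // additive bias, same shape as the score matrix
    std::vector<float> S(M * N), P(M * N), Y(M * O, 0.f);

    // S = alpha * Q * K^T + D  (bias is added before masking/softmax, as in the diff)
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(int k = 0; k < K; ++k)
                acc += Q[m * K + k] * Kmat[n * K + k];
            S[m * N + n] = alpha * acc + D[m * N + n];
        }

    // P = softmax(S) along each row
    for(int m = 0; m < M; ++m)
    {
        float row_max = S[m * N + 0];
        for(int n = 1; n < N; ++n)
            row_max = std::max(row_max, S[m * N + n]);
        float sum = 0.f;
        for(int n = 0; n < N; ++n)
        {
            P[m * N + n] = std::exp(S[m * N + n] - row_max);
            sum += P[m * N + n];
        }
        for(int n = 0; n < N; ++n)
            P[m * N + n] /= sum;
    }

    // Y = P * V
    for(int m = 0; m < M; ++m)
        for(int o = 0; o < O; ++o)
            for(int n = 0; n < N; ++n)
                Y[m * O + o] += P[m * N + n] * V[n * O + o];

    std::printf("Y[0][0] = %f\n", Y[0][0]);
    return 0;
}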
......@@ -1181,7 +1181,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
struct D0Loader
{
__host__ __device__ static constexpr auto GetD0BlockDescriptor_M0_N0_M1_M2_N1_M()
__host__ __device__ static constexpr auto GetD0BlockDescriptor_M0_N0_M1_M2_N1_M3()
{
// B1 matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
......@@ -1193,8 +1193,28 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
D0M3,
I1));
}
__host__ __device__ static constexpr auto GetD0BlockReadDescriptor_N0_N1_M0_M1_M2_M3()
{
constexpr auto d0_raw_m0_n_m1 =
make_naive_tensor_descriptor(make_tuple(D0M2, Number<NPerBlock>{}, D0M3),
make_tuple(Number<NPerBlock>{} * D0M3, D0M3, I1));
constexpr auto d0_n0_n1_m0_m1_m2_m3 = transform_tensor_descriptor(
d0_raw_m0_n_m1,
make_tuple(make_unmerge_transform(make_tuple((D0M2 * D0M3) / I8, I2, I4 / D0M3)),
make_unmerge_transform(
make_tuple(Number<NPerBlock / NPerXdl>{}, Number<NPerXdl>{})),
make_pass_through_transform(D0M3)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<2, 3, 4>{}, Sequence<0, 1>{}, Sequence<5>{}));
return d0_n0_n1_m0_m1_m2_m3;
}
static constexpr auto d0_block_desc_m0_n0_m1_m2_n1_m3 =
GetD0BlockDescriptor_M0_N0_M1_M2_N1_M();
GetD0BlockDescriptor_M0_N0_M1_M2_N1_M3();
static constexpr auto d0_block_desc_n0_n1_m0_m1_m2_m3 =
GetD0BlockReadDescriptor_N0_N1_M0_M1_M2_M3();
static constexpr auto d0_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I4, I1, I4 / D0M3, D0M3));
using D0BlockwiseCopy = ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock,
......@@ -1219,6 +1239,17 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
false,
true, // DstResetCoord
1>;
using D0ThreadCopy =
ThreadwiseTensorSliceTransfer_v4<D0DataType, // SrcData
D0DataType, // DstData
decltype(d0_block_desc_n0_n1_m0_m1_m2_m3), // SrcDesc
decltype(d0_thread_desc_), // DstDesc
Sequence<1, 1, 4, 1, 2, 2>, // SliceLengths
Sequence<0, 1, 2, 3, 4, 5>, // DimAccessOrder
5, // SrcVectorDim
D0M3.value, // SrcScalarPerVector
D0M3.value>;
};
template <bool HasMainKBlockLoop,
......@@ -1546,14 +1577,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
qgrad_thread_origin_on_grid_m0_o0_m1_o1_m2_o2_o3_o4,
scale_rp_dropout);
// D0
auto d0_block_copy_global_to_lds =
typename D0Loader::D0BlockwiseCopy(d0_grid_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(0, block_work_idx_n, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{},
D0Loader::d0_block_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(0, 0, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{});
//
// Blockwise softmax
//
......@@ -1703,7 +1726,27 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
0, //
wave_m_n_id[I1]), // NPerXdl
tensor_operation::element_wise::PassThrough{}};
// if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 46)
// {
// printf("get_thread_local_1d_id(): %d, wave_id[I0]: %d wave_id[I1]: %d "
// "wave_m_n_id[I0]: %d wave_m_n_id[I1]: %d \n",
// get_thread_local_1d_id(),
// wave_id[I0],
// wave_id[I1],
// wave_m_n_id[I0],
// wave_m_n_id[I1]);
// }
// D0
auto d0_block_copy_global_to_lds =
typename D0Loader::D0BlockwiseCopy(d0_grid_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(0, block_work_idx_n, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{},
D0Loader::d0_block_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(0, 0, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{});
auto d0_thread_copy_lds_to_vgpr = typename D0Loader::D0ThreadCopy(
make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0, 0));
ignore = d0_thread_copy_lds_to_vgpr;
//
// set up Y dot dY
//
......@@ -1940,6 +1983,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// add bias
if constexpr(!is_same<D0DataType, void>::value)
{
static constexpr auto c_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(D0M1, Number<16>{}));
const auto d0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d0_grid, d0_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
......@@ -1947,7 +1993,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static_cast<GemmDataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
D0Loader::d0_block_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
static_for<0, D0M1, 1>{}([&](auto) {
auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>(
D0Loader::d0_thread_desc_.GetElementSpaceSize());
ignore = d0_thread_buf;
static_for<0, D0M1, 1>{}([&](auto mr) {
// load data to lds
d0_block_copy_global_to_lds.RunRead(d0_grid_desc_m0_n0_m1_m2_n1_m3,
d0_grid_buf);
......@@ -1956,6 +2007,33 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
d0_block_copy_global_to_lds.RunWrite(D0Loader::d0_block_desc_m0_n0_m1_m2_n1_m3,
d0_block_buf);
block_sync_lds();
// read data from lds
d0_thread_copy_lds_to_vgpr.Run(D0Loader::d0_block_desc_n0_n1_m0_m1_m2_m3,
make_tuple(I0, I0, I0, I0, I0, I0),
d0_block_buf,
D0Loader::d0_thread_desc_,
make_tuple(I0, I0, I0, I0, I0, I0),
d0_thread_buf);
// bias add
static_for<0, D0Loader::d0_thread_desc_.GetElementSpaceSize(), 1>{}(
[&](auto i) {
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(mr, i));
s_slash_p_thread_buf(Number<c_offset>{}) +=
ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]);
// if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0)
// {
// printf("c_offset: %d s_slash_p_thread_buf(Number<c_offset>{}):
// %f, "
// "d0_thread_buf[i]: %f\n",
// c_offset,
// s_slash_p_thread_buf(Number<c_offset>{}),
// ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]));
// }
});
});
d0_block_copy_global_to_lds.MoveSrcSliceWindow(
......@@ -2036,11 +2114,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// dV = P_drop^T * dY
{
// TODO: explore using dynamic buffer for a1 thread buffer
// For a1_blockwise_copy, the goal is to satisfy pipeline requirements RunRead(),
// RunWrite(), and MoveSliceWindow(). But it is impossible to implement given that
// the A1 source buffer is static buffer holding the output of first GEMM and
// requires constexpr offset by design. Therefore, we pass tensor coordinate offset
// explicitly in Run() below.
// For a1_blockwise_copy, the goal is to satisfy pipeline requirements
// RunRead(), RunWrite(), and MoveSliceWindow(). But it is impossible to
// implement given that the A1 source buffer is static buffer holding the output
// of first GEMM and requires constexpr offset by design. Therefore, we pass
// tensor coordinate offset explicitly in Run() below.
// preload data into LDS
vgrad_gemm_tile_ygrad_blockwise_copy.RunRead(ygrad_grid_desc_m0_o_m1,
......@@ -2196,11 +2274,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// dK = scalar * dS^T * Q
{
// TODO: explore using dynamic buffer for a1 thread buffer
// For a1_blockwise_copy, the goal is to satisfy pipeline requirements RunRead(),
// RunWrite(), and MoveSliceWindow(). But it is impossible to implement given that
// the A1 source buffer is static buffer holding the output of first GEMM and
// requires constexpr offset by design. Therefore, we pass tensor coordinate offset
// explicitly in Run() below.
// For a1_blockwise_copy, the goal is to satisfy pipeline requirements
// RunRead(), RunWrite(), and MoveSliceWindow(). But it is impossible to
// implement given that the A1 source buffer is static buffer holding the output
// of first GEMM and requires constexpr offset by design. Therefore, we pass
// tensor coordinate offset explicitly in Run() below.
// preload data into LDS
kgrad_gemm_tile_q_blockwise_copy.RunRead(q_grid_desc_m0_k_m1, q_grid_buf);
......@@ -2286,7 +2364,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_multi_index(-1, 0, 0, 0, 0, 0, 0, 0, 0, 0));
} while(0 < gemm0_m_block_outer_index--); // end j loop
// shuffle dK&dV and write
......
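The kernel-side additions above stage the bias the same way the other operand tiles are staged: D0BlockwiseCopy moves a block tile of D0 from global memory into LDS, D0ThreadCopy pulls each thread's slice from LDS into registers, and the slice is accumulated onto the thread-private score fragment (s_slash_p_thread_buf) inside the static_for over D0M1. The host-only sketch below models that three-stage pattern (global buffer -> shared tile -> per-thread fragment -> accumulate). The tile and thread extents and all buffer names (d0_global, d0_block, d0_thread, score) are made up for illustration; this is not the CK implementation.

#include <algorithm>
#include <array>
#include <cstdio>
#include <vector>

int main()
{
    constexpr int MPerTile = 8, NPerTile = 16;      // hypothetical block tile of the bias
    constexpr int ThreadsM = 4, ThreadsN = 4;       // hypothetical thread layout
    constexpr int MPerThread = MPerTile / ThreadsM; // per-thread slice: 2 rows
    constexpr int NPerThread = NPerTile / ThreadsN; // per-thread slice: 4 columns

    std::vector<float> d0_global(MPerTile * NPerTile);
    for(int i = 0; i < MPerTile * NPerTile; ++i)
        d0_global[i] = 0.01f * i;

    // stand-in for the LDS block buffer filled by the blockwise RunRead/RunWrite pair
    std::vector<float> d0_block(MPerTile * NPerTile);
    std::copy(d0_global.begin(), d0_global.end(), d0_block.begin());

    // per-thread score accumulators, flattened over the whole tile for simplicity
    std::vector<float> score(MPerTile * NPerTile, 1.0f);

    for(int tm = 0; tm < ThreadsM; ++tm)
        for(int tn = 0; tn < ThreadsN; ++tn)
        {
            // stand-in for the threadwise copy: this thread's slice from "LDS" to registers
            std::array<float, MPerThread * NPerThread> d0_thread{};
            for(int i = 0; i < MPerThread; ++i)
                for(int j = 0; j < NPerThread; ++j)
                {
                    const int m = tm * MPerThread + i;
                    const int n = tn * NPerThread + j;
                    d0_thread[i * NPerThread + j] = d0_block[m * NPerTile + n];
                }

            // bias add into the thread-private score fragment
            for(int i = 0; i < MPerThread; ++i)
                for(int j = 0; j < NPerThread; ++j)
                {
                    const int m = tm * MPerThread + i;
                    const int n = tn * NPerThread + j;
                    score[m * NPerTile + n] += d0_thread[i * NPerThread + j];
                }
        }

    std::printf("score[0] = %f\n", score[0]);
    return 0;
}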
......@@ -111,6 +111,11 @@ struct StaticBufferTupleOfVector
return base::operator()(i_v).template AsType<S>()(i_s);
}
template <index_t Iv, index_t Is>
__host__ __device__ constexpr S& operator()(Number<Iv> i_v, Number<Is> i_s)
{
return base::operator()(i_v).template AsType<S>()(i_s);
}
// Get X
// i is offset of S, not X. i should be aligned to X
......
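The StaticBufferTupleOfVector change adds a mutable operator() overload next to the existing const accessor, which is what allows a caller to accumulate into a stored scalar (as the kernel does with += on s_slash_p_thread_buf). A minimal standalone analogue is sketched below; TinyStaticBuffer is a hypothetical type used only to show the const/non-const pair, not the CK class.

#include <array>
#include <cstdio>

template <typename S, int N>
struct TinyStaticBuffer
{
    std::array<S, N> data{};

    // read-only access
    constexpr const S& operator()(int i) const { return data[i]; }

    // mutable access; returning S& is what enables in-place "+=" on an element
    constexpr S& operator()(int i) { return data[i]; }
};

int main()
{
    TinyStaticBuffer<float, 4> buf;
    buf(0) = 1.0f;
    buf(0) += 0.5f; // requires the non-const overload
    std::printf("buf(0) = %f\n", buf(0));
    return 0;
}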