Commit ec9c2b5e authored by Anthony Chang

dK validates

parent 2d55c14c
@@ -647,7 +647,10 @@ int run(int argc, char* argv[])
1e-2);
std::cout << "Checking kgrad:\n";
pass &= ck::utils::check_err(kgrad_gs_ns_ks_device_result.mData,
kgrad_gs_ns_ks_host_result.mData);
kgrad_gs_ns_ks_host_result.mData,
"error",
1e-2,
1e-2);
std::cout << "Checking vgrad:\n";
pass &= ck::utils::check_err(vgrad_gs_os_ns_device_result.mData,
vgrad_gs_os_ns_host_result.mData,
@@ -656,7 +659,7 @@ int run(int argc, char* argv[])
1e-2);
}
return pass ? 0 : 1;
return pass ? (std::cout << "pass\n", 0) : (std::cout << "fail\n", 1);
}
int main(int argc, char* argv[]) { return run(argc, argv); }
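For context, here is a minimal standalone sketch of the device-vs-host comparison the hunk above extends. The helper below and the meaning of the extra arguments ("error" as the message, then a relative and an absolute tolerance of 1e-2) are assumptions for illustration, not ck::utils::check_err's exact signature:

#include <cmath>
#include <cstdio>
#include <string>
#include <vector>

// Hypothetical stand-in for a tolerance-based comparison of device vs. host results.
bool check_close(const std::vector<float>& out,
                 const std::vector<float>& ref,
                 const std::string& msg,
                 double rtol,
                 double atol)
{
    if(out.size() != ref.size())
    {
        std::printf("%s: size mismatch\n", msg.c_str());
        return false;
    }
    for(std::size_t i = 0; i < out.size(); ++i)
    {
        const double err = std::fabs(double(out[i]) - double(ref[i]));
        if(err > atol + rtol * std::fabs(double(ref[i])))
        {
            std::printf("%s: mismatch at %zu (%f vs %f)\n", msg.c_str(), i, out[i], ref[i]);
            return false;
        }
    }
    return true;
}

int main()
{
    std::vector<float> device_result{1.0f, 2.0f}, host_result{1.001f, 2.0f};
    const bool pass = check_close(device_result, host_result, "error", 1e-2, 1e-2);
    return pass ? (std::printf("pass\n"), 0) : (std::printf("fail\n"), 1);
}

The pass/fail reporting in main mirrors the updated return statement in run() above.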
@@ -789,7 +789,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
template <typename CGradDesc_N_O>
__host__ __device__ static const auto
MakeCGridDesc_N0_O0_N1_O1_N2_O2_O3_O4(CGradDesc_N_O c_grid_desc_n_o)
MakeCGridDesc_N0_O0_N1_O1_N2_O2_O3_O4(const CGradDesc_N_O& c_grid_desc_n_o)
{
// HACK: for the unmerge transform, the length of the highest dim is irrelevant, so we put the
// dummy variable I1 there
@@ -859,7 +859,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
// PGrad Gemm has the same layout as P = Q * K^T Gemm (A row-major B col-major)
struct PGradGemmTile_M_N_O
{
// TODO ANT:
// TODO:
// Make all input tensors 2D and transform them into appropriate 3D form in kernel to make
// things more concise
template <typename YGradGridDesc_M0_O_M1_>
@@ -957,6 +957,48 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
}
};
struct KGradGemmTile_N_K_M
{
// B position
template <typename QGridDesc_K0_M_K1_>
__device__ static const auto
MakeQGridDesc_M0_K_M1(const QGridDesc_K0_M_K1_& q_grid_desc_k0_m_k1)
{
const auto Q_K0 = q_grid_desc_k0_m_k1.GetLength(I0);
const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
const auto Q_K1 = q_grid_desc_k0_m_k1.GetLength(I2);
constexpr auto Q_M1 = B1K1;
const auto Q_M0 = M / Q_M1;
const auto q_grid_desc_m0_k_m1 = transform_tensor_descriptor(
q_grid_desc_k0_m_k1,
make_tuple(make_unmerge_transform(make_tuple(Q_M0, Q_M1)),
make_merge_transform_v3_division_mod(make_tuple(Q_K0, Q_K1))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return q_grid_desc_m0_k_m1;
}
// C position
template <typename KGridDesc_K0_N_K1_>
__device__ static const auto
MakeKGradGridDesc_N_K(const KGridDesc_K0_N_K1_& k_grid_desc_k0_n_k1)
{
const auto K_K0 = k_grid_desc_k0_n_k1.GetLength(I0);
const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
const auto K_K1 = k_grid_desc_k0_n_k1.GetLength(I2);
return transform_tensor_descriptor(
k_grid_desc_k0_n_k1,
make_tuple(make_pass_through_transform(N),
make_merge_transform_v3_division_mod(make_tuple(K_K0, K_K1))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
};
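As an aside, a tiny self-contained sketch of the re-indexing MakeQGridDesc_M0_K_M1 above performs: the (K0, M, K1) descriptor of Q is viewed as (M0, K, M1), with M unmerged into (M0, M1) where M1 plays the role of B1K1, and (K0, K1) merged back into K. No data moves; only the addressing changes. The sizes below are toy values assumed for illustration:

#include <cstdio>

int main()
{
    // Toy sizes; M1 here plays the role of B1K1 in the kernel.
    constexpr int K0 = 2, K1 = 4, M = 6, M1 = 2;
    constexpr int K = K0 * K1, M0 = M / M1;

    // Element (m0, k, m1) of the new view addresses element (k0, m, k1) of the old view.
    for(int m0 = 0; m0 < M0; ++m0)
        for(int k = 0; k < K; ++k)
            for(int m1 = 0; m1 < M1; ++m1)
            {
                const int m  = m0 * M1 + m1; // unmerge of M into (M0, M1)
                const int k0 = k / K1;       // merge of (K0, K1) into K
                const int k1 = k % K1;
                std::printf("(m0=%d, k=%d, m1=%d) -> (k0=%d, m=%d, k1=%d)\n",
                            m0, k, m1, k0, m, k1);
            }
    return 0;
}

MakeKGradGridDesc_N_K is the simpler case of the same idea: N passes through and (K0, K1) merge into K.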
struct SharedMemTrait
{
// LDS allocation for A and B: be careful of alignment
@@ -1067,7 +1109,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
p_v_grid, v_grid_desc_n0_o_n1.GetElementSpaceSize());
const auto y_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_y_grid, y_grid_desc_mblock_mperblock_oblock_operblock.GetElementSpaceSize());
auto lse_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
const auto lse_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_lse_grid, lse_grid_desc_m.GetElementSpaceSize());
const auto ygrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ygrad_grid, ygrad_grid_desc_m0_o_m1.GetElementSpaceSize());
@@ -1075,6 +1117,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
p_vgrad_grid, vgrad_grid_desc_n_o.GetElementSpaceSize());
auto qgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_qgrad_grid, q_grid_desc_k0_m_k1.GetElementSpaceSize());
auto kgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_kgrad_grid, k_grid_desc_k0_n_k1.GetElementSpaceSize());
// divide block work by [M, O]
const auto block_work_idx =
@@ -1095,6 +1139,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
const index_t o_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * Gemm1NPerBlock);
// 6 GEMM operations are categorized into 3 buckets. SizeK == SizeO == head_dim
// S_MNK / dP_MNO Gemm (Gemm0 rcr)
// Y_MON / dQ_MKN Gemm (Gemm1 rrr)
// dV_NOM / dK_NKM Gemm (Gemm2 crr)
//
// set up S / dP Gemm (type 1 rcr)
//
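For orientation on the bucket comment above, the following is a naive single-head host reference of the same six GEMMs. It is a sketch under assumed row-major float layouts, with the softmax scale and masking omitted; it mirrors the math only, not the kernel's tiling or the Gemm0/1/2 data layouts:

#include <algorithm>
#include <cmath>
#include <vector>

// Naive single-head reference:
//   forward:  S = Q * K^T, P = softmax(S) row-wise
//   backward: dV = P^T * dY, dP = dY * V^T, dS = softmax-backward(P, dP),
//             dQ = dS * K, dK = dS^T * Q
void attention_backward_reference(int M, int N, int Kdim, int O,
                                  const std::vector<float>& Q,    // M x Kdim
                                  const std::vector<float>& Kmat, // N x Kdim
                                  const std::vector<float>& V,    // N x O
                                  const std::vector<float>& dY,   // M x O
                                  std::vector<float>& dQ,         // M x Kdim
                                  std::vector<float>& dK,         // N x Kdim
                                  std::vector<float>& dV)         // N x O
{
    std::vector<float> P(M * N), dP(M * N), dS(M * N);
    dQ.assign(M * Kdim, 0.f);
    dK.assign(N * Kdim, 0.f);
    dV.assign(N * O, 0.f);

    // forward: P = softmax(Q * K^T) row-wise
    for(int i = 0; i < M; ++i)
    {
        float row_max = -INFINITY, row_sum = 0.f;
        for(int j = 0; j < N; ++j)
        {
            float s = 0.f;
            for(int k = 0; k < Kdim; ++k)
                s += Q[i * Kdim + k] * Kmat[j * Kdim + k];
            P[i * N + j] = s;
            row_max = std::max(row_max, s);
        }
        for(int j = 0; j < N; ++j)
            row_sum += (P[i * N + j] = std::exp(P[i * N + j] - row_max));
        for(int j = 0; j < N; ++j)
            P[i * N + j] /= row_sum;
    }

    // dV = P^T * dY (same N-major output layout as dK)
    for(int j = 0; j < N; ++j)
        for(int o = 0; o < O; ++o)
        {
            float acc = 0.f;
            for(int i = 0; i < M; ++i)
                acc += P[i * N + j] * dY[i * O + o];
            dV[j * O + o] = acc;
        }

    // dP = dY * V^T (same M x N layout as S / P)
    for(int i = 0; i < M; ++i)
        for(int j = 0; j < N; ++j)
        {
            float acc = 0.f;
            for(int o = 0; o < O; ++o)
                acc += dY[i * O + o] * V[j * O + o];
            dP[i * N + j] = acc;
        }

    // softmax backward: dS_ij = P_ij * (dP_ij - sum_j' dP_ij' * P_ij')
    for(int i = 0; i < M; ++i)
    {
        float dot = 0.f;
        for(int j = 0; j < N; ++j)
            dot += dP[i * N + j] * P[i * N + j];
        for(int j = 0; j < N; ++j)
            dS[i * N + j] = P[i * N + j] * (dP[i * N + j] - dot);
    }

    // dQ = dS * K
    for(int i = 0; i < M; ++i)
        for(int k = 0; k < Kdim; ++k)
        {
            float acc = 0.f;
            for(int j = 0; j < N; ++j)
                acc += dS[i * N + j] * Kmat[j * Kdim + k];
            dQ[i * Kdim + k] = acc;
        }

    // dK = dS^T * Q (the GEMM this commit wires up in the kernel)
    for(int j = 0; j < N; ++j)
        for(int k = 0; k < Kdim; ++k)
        {
            float acc = 0.f;
            for(int i = 0; i < M; ++i)
                acc += dS[i * N + j] * Q[i * Kdim + k];
            dK[j * Kdim + k] = acc;
        }
}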
@@ -1211,11 +1260,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
QGradGemmTile_M_K_N::MakeQGradGridDesc_MBlock_MPerBlock_KBlock_KPerBlock(
q_grid_desc_k0_m_k1);
// dQ: Gemm A matrix blockwise copy
// dQ: A matrix blockwise copy
auto qgrad_gemm_tile_sgrad_blockwise_copy =
typename Gemm1::ABlockwiseCopy{tensor_operation::element_wise::PassThrough{}};
// dQ: Gemm B matrix blockwise copy
// dQ: B matrix blockwise copy
auto qgrad_gemm_tile_k_blockwise_copy =
typename Gemm1::template BBlockwiseCopy<decltype(k_grid_desc_n0_k_n1)>(
k_grid_desc_n0_k_n1,
@@ -1357,9 +1406,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
tensor_operation::element_wise::PassThrough{});
// dV: blockwise gemm
auto vgrad_blockwise_gemm = typename Gemm2::BlockwiseGemm{};
auto v_slash_k_grad_blockwise_gemm = typename Gemm2::BlockwiseGemm{};
auto vgrad_thread_buf = vgrad_blockwise_gemm.GetCThreadBuffer();
auto v_slash_k_grad_thread_buf = v_slash_k_grad_blockwise_gemm.GetCThreadBuffer();
// dV: C VGPR-to-global copy
const auto vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4 =
@@ -1376,6 +1425,45 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
tensor_operation::element_wise::PassThrough{});
// dK: transform input and output tensor descriptors
const auto q_grid_desc_m0_k_m1 =
KGradGemmTile_N_K_M::MakeQGridDesc_M0_K_M1(q_grid_desc_k0_m_k1);
const auto kgrad_grid_desc_n_k =
KGradGemmTile_N_K_M::MakeKGradGridDesc_N_K(k_grid_desc_k0_n_k1);
// dK: A matrix VGPR-to-LDS blockwise copy
auto kgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds = typename Gemm2::ABlockwiseCopy{
Gemm2::a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
Gemm2::MakeAThreadOriginOnBlock_M0_N0_M1_N1_M2_N2_N3_N4(),
tensor_operation::element_wise::PassThrough{}};
// dK: B matrix global-to-LDS blockwise copy
auto kgrad_gemm_tile_q_blockwise_copy =
typename Gemm2::template BBlockwiseCopy<decltype(q_grid_desc_m0_k_m1)>(
q_grid_desc_m0_k_m1,
make_multi_index(m_block_data_idx_on_grid / Gemm2Params_N_O_M::B_M1,
o_block_data_idx_on_grid,
0),
tensor_operation::element_wise::PassThrough{},
Gemm2::b_block_desc_m0_o_m1,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
// dK: blockwise gemm
/* reuse v_slash_k_grad_blockwise_gemm, v_slash_k_grad_thread_buf */
// dK: C VGPR-to-global copy
const auto kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4 =
Gemm2::MakeCGridDesc_N0_O0_N1_O1_N2_O2_O3_O4(kgrad_grid_desc_n_k);
const auto kgrad_thread_origin_on_grid_n0_o0_n1_o1_n2_o2_o3_o4 =
Gemm2::GetCThreadOriginOnBlock_N0_O0_N1_O1_N2_O2_O3_O4() +
make_multi_index(I0, block_work_idx[I1], I0, I0, I0, I0, I0, I0);
auto kgrad_thread_copy_vgpr_to_global = typename Gemm2::template CBlockwiseCopy<decltype(
kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4)>(
kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
kgrad_thread_origin_on_grid_n0_o0_n1_o1_n2_o2_o3_o4,
tensor_operation::element_wise::PassThrough{});
//
// set up Y dot dY
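A note for context on the "Y dot dY" step introduced above (a standard softmax-attention backward identity, stated here for orientation rather than taken from this commit): because Y = P * V, the per-row correction term needed for dS can be read off Y and dY directly,

sum_j dP[i][j] * P[i][j] = sum_j (sum_o dY[i][o] * V[j][o]) * P[i][j] = sum_o dY[i][o] * Y[i][o],

so dS[i][j] = P[i][j] * (dP[i][j] - y_i . dy_i); the per-row dot product of Y and dY replaces an extra reduction over dP.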
@@ -1618,38 +1706,36 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
block_sync_lds(); // wait for gemm1 LDS read
constexpr auto p_block_slice_lengths_m0_n0_m1_n1 =
typename Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1{};
SubThreadBlock<BlockSize> p_thread_copy_subgroup(s_blockwise_gemm.GetWaveIdx()[I0],
s_blockwise_gemm.GetWaveIdx()[I1]);
SubThreadBlock<BlockSize> gemm2_a_copy_subgroup(s_blockwise_gemm.GetWaveIdx()[I0],
s_blockwise_gemm.GetWaveIdx()[I1]);
constexpr index_t num_vgrad_gemm_loop = MPerBlock / Gemm2Params_N_O_M::Sum_M;
static_assert(vgrad_gemm_tile_p_block_slice_window_iterator.GetNumOfAccess() ==
num_vgrad_gemm_loop,
constexpr index_t num_gemm2_loop = MPerBlock / Gemm2Params_N_O_M::Sum_M;
static_assert(Gemm2::ASrcBlockSliceWindowIterator::GetNumOfAccess() == num_gemm2_loop,
"");
// TODO: tune gemm2 pipeline
// dV = P^T * dY
vgrad_thread_buf.Clear();
static_for<0, num_vgrad_gemm_loop, 1>{}([&](auto vgrad_gemm_loop_idx) { // gemm dV
v_slash_k_grad_thread_buf.Clear();
static_for<0, num_gemm2_loop, 1>{}([&](auto gemm2_loop_idx) { // gemm dV
// load VGrad Gemm B
vgrad_gemm_tile_ygrad_blockwise_copy.RunRead(ygrad_grid_desc_m0_o_m1,
ygrad_grid_buf);
// load VGrad Gemm A
const auto p_nd_idx =
vgrad_gemm_tile_p_block_slice_window_iterator.GetIndexTupleOfNumber(
vgrad_gemm_loop_idx);
constexpr auto mwave_range =
make_tuple(p_nd_idx[I2], p_nd_idx[I2] + p_block_slice_lengths_m0_n0_m1_n1[I2]);
constexpr auto nwave_range =
make_tuple(p_nd_idx[I3], p_nd_idx[I3] + p_block_slice_lengths_m0_n0_m1_n1[I3]);
if(p_thread_copy_subgroup.IsBelong(mwave_range, nwave_range))
const auto p_slice_idx =
Gemm2::ASrcBlockSliceWindowIterator::GetIndexTupleOfNumber(gemm2_loop_idx);
constexpr auto mwave_range = make_tuple(
p_slice_idx[I2],
p_slice_idx[I2] + Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I2));
constexpr auto nwave_range = make_tuple(
p_slice_idx[I3],
p_slice_idx[I3] + Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I3));
if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range))
{
vgrad_gemm_tile_p_thread_copy_vgpr_to_lds.Run(
Gemm2::a_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_tuple(p_nd_idx[I0], p_nd_idx[I1], I0, I0, I0, I0, I0, I0),
make_tuple(p_slice_idx[I0], p_slice_idx[I1], I0, I0, I0, I0, I0, I0),
s_slash_p_thread_buf,
Gemm2::a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
gemm2_a_block_buf);
@@ -1665,13 +1751,14 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
gemm2_b_block_buf);
block_sync_lds(); // sync before read
vgrad_blockwise_gemm.Run(gemm2_a_block_buf, gemm2_b_block_buf, vgrad_thread_buf);
v_slash_k_grad_blockwise_gemm.Run(
gemm2_a_block_buf, gemm2_b_block_buf, v_slash_k_grad_thread_buf);
}); // end gemm dV
// atomic_add dV
vgrad_thread_copy_vgpr_to_global.Run(Gemm2::c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
vgrad_thread_buf,
v_slash_k_grad_thread_buf,
vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
vgrad_grid_buf);
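On the atomic_add above: block work is divided over [M, O] while dV (and later dK) reduce over M, so workgroups assigned to different M blocks presumably accumulate into the same output tile, which is why the C copy adds rather than stores. A minimal HIP-style sketch of that accumulation pattern, purely illustrative and unrelated to the actual Gemm2 CBlockwiseCopy implementation:

#include <hip/hip_runtime.h>

// Each block adds its partial tile into the shared global result; atomicAdd makes
// concurrent contributions from different blocks safe.
__global__ void accumulate_partial_tile(float* out_global, const float* partial, int n)
{
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if(i < n)
        atomicAdd(&out_global[i], partial[i]);
}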
@@ -1777,6 +1864,56 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
}
} // end gemm dQ
// dK = dS^T * Q
v_slash_k_grad_thread_buf.Clear();
static_for<0, num_gemm2_loop, 1>{}([&](auto gemm2_loop_idx) { // gemm dK
// load KGrad Gemm B
kgrad_gemm_tile_q_blockwise_copy.RunRead(q_grid_desc_m0_k_m1, q_grid_buf);
// load KGrad Gemm A
const auto sgrad_slice_idx =
Gemm2::ASrcBlockSliceWindowIterator::GetIndexTupleOfNumber(gemm2_loop_idx);
constexpr auto mwave_range =
make_tuple(sgrad_slice_idx[I2],
sgrad_slice_idx[I2] +
Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I2));
constexpr auto nwave_range =
make_tuple(sgrad_slice_idx[I3],
sgrad_slice_idx[I3] +
Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I3));
if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range))
{
kgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds.Run(
Gemm2::a_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_tuple(
sgrad_slice_idx[I0], sgrad_slice_idx[I1], I0, I0, I0, I0, I0, I0),
sgrad_thread_buf,
Gemm2::a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
gemm2_a_block_buf);
}
// the q (kgrad gemm B matrix) slice window is moved with MoveSrcSliceWindow() since it reads a dynamic buffer
// the sgrad (kgrad gemm A matrix) slice window is moved by loop index
kgrad_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(q_grid_desc_m0_k_m1,
Gemm2::b_block_slice_copy_step);
block_sync_lds(); // sync before write
kgrad_gemm_tile_q_blockwise_copy.RunWrite(Gemm2::b_block_desc_m0_o_m1,
gemm2_b_block_buf);
block_sync_lds(); // sync before read
v_slash_k_grad_blockwise_gemm.Run(
gemm2_a_block_buf, gemm2_b_block_buf, v_slash_k_grad_thread_buf);
}); // end gemm dK
// atomic_add dK
kgrad_thread_copy_vgpr_to_global.Run(Gemm2::c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
v_slash_k_grad_thread_buf,
kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
kgrad_grid_buf);
// move slice window
s_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(
q_grid_desc_k0_m_k1,
@@ -1794,6 +1931,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
pgrad_gemm_tile_v_blockwise_copy.MoveSrcSliceWindow(
v_grid_desc_o0_n_o1,
pgrad_gemm_tile_v_block_reset_copy_step); // rewind O and step N
kgrad_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(
q_grid_desc_m0_k_m1,
Gemm2::b_block_reset_copy_step); // rewind M
kgrad_thread_copy_vgpr_to_global.MoveDstSliceWindow(
kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4, Gemm2::c_block_slice_copy_step); // step N
} while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop