Commit 174e013d authored by danyao12's avatar danyao12
Browse files

bwd mqa/gqa w/ permute

parent 104f9da6
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp" #include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v1.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v1_protro.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp" #include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
...@@ -74,10 +74,12 @@ __global__ void ...@@ -74,10 +74,12 @@ __global__ void
const CElementwiseOperation c_element_op, const CElementwiseOperation c_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const BGridDesc_BK0_N_BK1 bgrad_grid_desc_bk0_n_bk1,
const D0GridDescriptor_M0_N0_M1_M2_N1_M3 d0_grid_desc_m0_n0_m1_m2_n1_m3, const D0GridDescriptor_M0_N0_M1_M2_N1_M3 d0_grid_desc_m0_n0_m1_m2_n1_m3,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3 const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1, const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
const B1GridDesc_BK0_N_BK1 b1grad_grid_desc_bk0_n_bk1,
const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
const LSEGridDescriptor_M lse_grid_desc_m, const LSEGridDescriptor_M lse_grid_desc_m,
...@@ -166,9 +168,11 @@ __global__ void ...@@ -166,9 +168,11 @@ __global__ void
c_element_op, c_element_op,
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
bgrad_grid_desc_bk0_n_bk1,
d0_grid_desc_m0_n0_m1_m2_n1_m3, d0_grid_desc_m0_n0_m1_m2_n1_m3,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
b1_grid_desc_bk0_n_bk1, b1_grid_desc_bk0_n_bk1,
b1grad_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
lse_grid_desc_m, lse_grid_desc_m,
ygrad_grid_desc_o0_m_o1, ygrad_grid_desc_o0_m_o1,
...@@ -204,9 +208,11 @@ __global__ void ...@@ -204,9 +208,11 @@ __global__ void
c_element_op, c_element_op,
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
bgrad_grid_desc_bk0_n_bk1,
d0_grid_desc_m0_n0_m1_m2_n1_m3, d0_grid_desc_m0_n0_m1_m2_n1_m3,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
b1_grid_desc_bk0_n_bk1, b1_grid_desc_bk0_n_bk1,
b1grad_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
lse_grid_desc_m, lse_grid_desc_m,
ygrad_grid_desc_o0_m_o1, ygrad_grid_desc_o0_m_o1,
...@@ -238,9 +244,11 @@ __global__ void ...@@ -238,9 +244,11 @@ __global__ void
ignore = c_element_op; ignore = c_element_op;
ignore = a_grid_desc_ak0_m_ak1; ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1; ignore = b_grid_desc_bk0_n_bk1;
ignore = bgrad_grid_desc_bk0_n_bk1;
ignore = d0_grid_desc_m0_n0_m1_m2_n1_m3; ignore = d0_grid_desc_m0_n0_m1_m2_n1_m3;
ignore = c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3; ignore = c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3;
ignore = b1_grid_desc_bk0_n_bk1; ignore = b1_grid_desc_bk0_n_bk1;
ignore = b1grad_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = lse_grid_desc_m; ignore = lse_grid_desc_m;
ignore = ygrad_grid_desc_o0_m_o1; ignore = ygrad_grid_desc_o0_m_o1;
...@@ -810,9 +818,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -810,9 +818,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)}, DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
b_grid_desc_bk0_n_bk1_{ b_grid_desc_bk0_n_bk1_{
DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)}, DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
bgrad_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(
bgrad_gs_ns_ks_lengths, bgrad_gs_ns_ks_strides)},
z_grid_desc_m_n_{MakeZGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)}, z_grid_desc_m_n_{MakeZGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
b1_grid_desc_bk0_n_bk1_{DeviceOp::MakeVGridDescriptor_O0_N_O1( b1_grid_desc_bk0_n_bk1_{DeviceOp::MakeVGridDescriptor_O0_N_O1(
b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)}, b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
b1grad_grid_desc_bk0_n_bk1_{DeviceOp::MakeVGridDescriptor_O0_N_O1(
b1grad_gs_gemm1ns_gemm1ks_lengths, b1grad_gs_gemm1ns_gemm1ks_strides)},
y_grid_desc_m_o_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths, y_grid_desc_m_o_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides)}, c_gs_ms_gemm1ns_strides)},
lse_grid_desc_m_{DeviceOp::MakeLSEGridDescriptor_M(lse_gs_ms_lengths[NumDimG])}, lse_grid_desc_m_{DeviceOp::MakeLSEGridDescriptor_M(lse_gs_ms_lengths[NumDimG])},
...@@ -831,8 +843,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -831,8 +843,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
c_gs_ms_gemm1ns_strides)}, c_gs_ms_gemm1ns_strides)},
z_grid_desc_g_m_n_{ z_grid_desc_g_m_n_{
Transform::MakeC0GridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)}, Transform::MakeC0GridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
bgrad_grid_desc_g_n_k_{ bgrad_grid_desc_g_n_k_{Transform::MakeB0GridDescriptor_G_N_K(bgrad_gs_ns_ks_lengths,
Transform::MakeB0GridDescriptor_G_N_K(bgrad_gs_ns_ks_lengths, bgrad_gs_ns_ks_strides)}, bgrad_gs_ns_ks_strides)},
b1grad_grid_desc_g_n_k_{Transform::MakeB1GridDescriptor_G_N_K( b1grad_grid_desc_g_n_k_{Transform::MakeB1GridDescriptor_G_N_K(
b1grad_gs_gemm1ns_gemm1ks_lengths, b1grad_gs_gemm1ns_gemm1ks_strides)}, b1grad_gs_gemm1ns_gemm1ks_lengths, b1grad_gs_gemm1ns_gemm1ks_strides)},
y_grid_desc_mblock_mperblock_oblock_operblock_{}, y_grid_desc_mblock_mperblock_oblock_operblock_{},
...@@ -940,8 +952,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -940,8 +952,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
<< bgrad_grid_desc_g_n_k_.GetLength(I1) << ", " << bgrad_grid_desc_g_n_k_.GetLength(I1) << ", "
<< bgrad_grid_desc_g_n_k_.GetLength(I2) << '\n'; << bgrad_grid_desc_g_n_k_.GetLength(I2) << '\n';
// bgrad_grid_desc_g_n_k_.Print(); // bgrad_grid_desc_g_n_k_.Print();
std::cout << "b1grad_grid_desc_g_n_k_: " << b1grad_grid_desc_g_n_k_.GetLength(I0) << ", " std::cout << "b1grad_grid_desc_g_n_k_: " << b1grad_grid_desc_g_n_k_.GetLength(I0)
<< b1grad_grid_desc_g_n_k_.GetLength(I1) << ", " << ", " << b1grad_grid_desc_g_n_k_.GetLength(I1) << ", "
<< b1grad_grid_desc_g_n_k_.GetLength(I2) << '\n'; << b1grad_grid_desc_g_n_k_.GetLength(I2) << '\n';
// b1grad_grid_desc_g_n_k_.Print(); // b1grad_grid_desc_g_n_k_.Print();
} }
...@@ -963,9 +975,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -963,9 +975,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// tensor descriptor // tensor descriptor
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
BGridDesc_BK0_N_BK1 bgrad_grid_desc_bk0_n_bk1_;
typename GridwiseGemm::D0GridDescriptor_M0_N0_M1_M2_N1_M3 d0_grid_desc_m0_n0_m1_m2_n1_m3_; typename GridwiseGemm::D0GridDescriptor_M0_N0_M1_M2_N1_M3 d0_grid_desc_m0_n0_m1_m2_n1_m3_;
ZGridDesc_M_N z_grid_desc_m_n_; ZGridDesc_M_N z_grid_desc_m_n_;
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_; B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
B1GridDesc_BK0_N_BK1 b1grad_grid_desc_bk0_n_bk1_;
YGridDesc_M_O y_grid_desc_m_o_; YGridDesc_M_O y_grid_desc_m_o_;
LSEGridDesc_M lse_grid_desc_m_; LSEGridDesc_M lse_grid_desc_m_;
KGridDesc_N_K k_grid_desc_n_k_; KGridDesc_N_K k_grid_desc_n_k_;
...@@ -1094,9 +1108,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1094,9 +1108,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
arg.c_element_op_, arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_, arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
arg.bgrad_grid_desc_bk0_n_bk1_,
arg.d0_grid_desc_m0_n0_m1_m2_n1_m3_, arg.d0_grid_desc_m0_n0_m1_m2_n1_m3_,
arg.c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_, arg.c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg.b1_grid_desc_bk0_n_bk1_, arg.b1_grid_desc_bk0_n_bk1_,
arg.b1grad_grid_desc_bk0_n_bk1_,
arg.y_grid_desc_mblock_mperblock_oblock_operblock_, arg.y_grid_desc_mblock_mperblock_oblock_operblock_,
arg.lse_grid_desc_m_, arg.lse_grid_desc_m_,
arg.ygrad_grid_desc_o0_m_o1_, arg.ygrad_grid_desc_o0_m_o1_,
...@@ -1163,7 +1179,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1163,7 +1179,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const index_t b1_gemm1n = const index_t b1_gemm1n =
arg.b1_grid_desc_bk0_n_bk1_.GetLength(I0) * arg.b1_grid_desc_bk0_n_bk1_.GetLength(I2); arg.b1_grid_desc_bk0_n_bk1_.GetLength(I0) * arg.b1_grid_desc_bk0_n_bk1_.GetLength(I2);
if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n && c_g % b_g == 0 && b_g <= c_g)) if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n && c_g % b_g == 0 &&
b_g <= c_g))
{ {
return false; return false;
} }
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp" #include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2_protro.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp" #include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
...@@ -74,10 +74,12 @@ __global__ void ...@@ -74,10 +74,12 @@ __global__ void
const CElementwiseOperation c_element_op, const CElementwiseOperation c_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const BGridDesc_BK0_N_BK1 bgrad_grid_desc_bk0_n_bk1,
const D0GridDescriptor_M0_N0_M1_M2_N1_M3 d0_grid_desc_m0_n0_m1_m2_n1_m3, const D0GridDescriptor_M0_N0_M1_M2_N1_M3 d0_grid_desc_m0_n0_m1_m2_n1_m3,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3 const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1, const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
const B1GridDesc_BK0_N_BK1 b1grad_grid_desc_bk0_n_bk1,
const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
const LSEGridDescriptor_M lse_grid_desc_m, const LSEGridDescriptor_M lse_grid_desc_m,
...@@ -167,9 +169,11 @@ __global__ void ...@@ -167,9 +169,11 @@ __global__ void
c_element_op, c_element_op,
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
bgrad_grid_desc_bk0_n_bk1,
d0_grid_desc_m0_n0_m1_m2_n1_m3, d0_grid_desc_m0_n0_m1_m2_n1_m3,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
b1_grid_desc_bk0_n_bk1, b1_grid_desc_bk0_n_bk1,
b1grad_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
lse_grid_desc_m, lse_grid_desc_m,
ygrad_grid_desc_m0_o_m1, ygrad_grid_desc_m0_o_m1,
...@@ -205,9 +209,11 @@ __global__ void ...@@ -205,9 +209,11 @@ __global__ void
c_element_op, c_element_op,
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
bgrad_grid_desc_bk0_n_bk1,
d0_grid_desc_m0_n0_m1_m2_n1_m3, d0_grid_desc_m0_n0_m1_m2_n1_m3,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
b1_grid_desc_bk0_n_bk1, b1_grid_desc_bk0_n_bk1,
b1grad_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
lse_grid_desc_m, lse_grid_desc_m,
ygrad_grid_desc_m0_o_m1, ygrad_grid_desc_m0_o_m1,
...@@ -239,9 +245,11 @@ __global__ void ...@@ -239,9 +245,11 @@ __global__ void
ignore = c_element_op; ignore = c_element_op;
ignore = a_grid_desc_ak0_m_ak1; ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1; ignore = b_grid_desc_bk0_n_bk1;
ignore = bgrad_grid_desc_bk0_n_bk1;
ignore = d0_grid_desc_m0_n0_m1_m2_n1_m3; ignore = d0_grid_desc_m0_n0_m1_m2_n1_m3;
ignore = c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3; ignore = c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3;
ignore = b1_grid_desc_bk0_n_bk1; ignore = b1_grid_desc_bk0_n_bk1;
ignore = b1grad_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = lse_grid_desc_m; ignore = lse_grid_desc_m;
ignore = ygrad_grid_desc_m0_o_m1; ignore = ygrad_grid_desc_m0_o_m1;
...@@ -827,9 +835,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -827,9 +835,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)}, DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
b_grid_desc_bk0_n_bk1_{ b_grid_desc_bk0_n_bk1_{
DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)}, DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
bgrad_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(
bgrad_gs_ns_ks_lengths, bgrad_gs_ns_ks_strides)},
z_grid_desc_m_n_{MakeZGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)}, z_grid_desc_m_n_{MakeZGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
b1_grid_desc_bk0_n_bk1_{DeviceOp::MakeVGridDescriptor_O0_N_O1( b1_grid_desc_bk0_n_bk1_{DeviceOp::MakeVGridDescriptor_O0_N_O1(
b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)}, b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
b1grad_grid_desc_bk0_n_bk1_{DeviceOp::MakeVGridDescriptor_O0_N_O1(
b1grad_gs_gemm1ns_gemm1ks_lengths, b1grad_gs_gemm1ns_gemm1ks_strides)},
y_grid_desc_m_o_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths, y_grid_desc_m_o_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides)}, c_gs_ms_gemm1ns_strides)},
lse_grid_desc_m_{DeviceOp::MakeLSEGridDescriptor_M(lse_gs_ms_lengths[NumDimG])}, lse_grid_desc_m_{DeviceOp::MakeLSEGridDescriptor_M(lse_gs_ms_lengths[NumDimG])},
...@@ -847,8 +859,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -847,8 +859,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
c_gs_ms_gemm1ns_strides)}, c_gs_ms_gemm1ns_strides)},
z_grid_desc_g_m_n_{ z_grid_desc_g_m_n_{
Transform::MakeC0GridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)}, Transform::MakeC0GridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
bgrad_grid_desc_g_n_k_{ bgrad_grid_desc_g_n_k_{Transform::MakeB0GridDescriptor_G_N_K(bgrad_gs_ns_ks_lengths,
Transform::MakeB0GridDescriptor_G_N_K(bgrad_gs_ns_ks_lengths, bgrad_gs_ns_ks_strides)}, bgrad_gs_ns_ks_strides)},
b1grad_grid_desc_g_n_k_{Transform::MakeB1GridDescriptor_G_N_K( b1grad_grid_desc_g_n_k_{Transform::MakeB1GridDescriptor_G_N_K(
b1grad_gs_gemm1ns_gemm1ks_lengths, b1grad_gs_gemm1ns_gemm1ks_strides)}, b1grad_gs_gemm1ns_gemm1ks_lengths, b1grad_gs_gemm1ns_gemm1ks_strides)},
y_grid_desc_mblock_mperblock_oblock_operblock_{}, y_grid_desc_mblock_mperblock_oblock_operblock_{},
...@@ -957,8 +969,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -957,8 +969,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
<< bgrad_grid_desc_g_n_k_.GetLength(I1) << ", " << bgrad_grid_desc_g_n_k_.GetLength(I1) << ", "
<< bgrad_grid_desc_g_n_k_.GetLength(I2) << '\n'; << bgrad_grid_desc_g_n_k_.GetLength(I2) << '\n';
// bgrad_grid_desc_g_n_k_.Print(); // bgrad_grid_desc_g_n_k_.Print();
std::cout << "b1grad_grid_desc_g_n_k_: " << b1grad_grid_desc_g_n_k_.GetLength(I0) << ", " std::cout << "b1grad_grid_desc_g_n_k_: " << b1grad_grid_desc_g_n_k_.GetLength(I0)
<< b1grad_grid_desc_g_n_k_.GetLength(I1) << ", " << ", " << b1grad_grid_desc_g_n_k_.GetLength(I1) << ", "
<< b1grad_grid_desc_g_n_k_.GetLength(I2) << '\n'; << b1grad_grid_desc_g_n_k_.GetLength(I2) << '\n';
// b1grad_grid_desc_g_n_k_.Print(); // b1grad_grid_desc_g_n_k_.Print();
} }
...@@ -980,9 +992,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -980,9 +992,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// tensor descriptor // tensor descriptor
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
BGridDesc_BK0_N_BK1 bgrad_grid_desc_bk0_n_bk1_;
typename GridwiseGemm::D0GridDescriptor_M0_N0_M1_M2_N1_M3 d0_grid_desc_m0_n0_m1_m2_n1_m3_; typename GridwiseGemm::D0GridDescriptor_M0_N0_M1_M2_N1_M3 d0_grid_desc_m0_n0_m1_m2_n1_m3_;
ZGridDesc_M_N z_grid_desc_m_n_; ZGridDesc_M_N z_grid_desc_m_n_;
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_; B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
B1GridDesc_BK0_N_BK1 b1grad_grid_desc_bk0_n_bk1_;
YGridDesc_M_O y_grid_desc_m_o_; YGridDesc_M_O y_grid_desc_m_o_;
LSEGridDesc_M lse_grid_desc_m_; LSEGridDesc_M lse_grid_desc_m_;
KGridDesc_N_K k_grid_desc_n_k_; KGridDesc_N_K k_grid_desc_n_k_;
...@@ -1115,9 +1129,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1115,9 +1129,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
arg.c_element_op_, arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_, arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
arg.bgrad_grid_desc_bk0_n_bk1_,
arg.d0_grid_desc_m0_n0_m1_m2_n1_m3_, arg.d0_grid_desc_m0_n0_m1_m2_n1_m3_,
arg.c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_, arg.c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg.b1_grid_desc_bk0_n_bk1_, arg.b1_grid_desc_bk0_n_bk1_,
arg.b1grad_grid_desc_bk0_n_bk1_,
arg.y_grid_desc_mblock_mperblock_oblock_operblock_, arg.y_grid_desc_mblock_mperblock_oblock_operblock_,
arg.lse_grid_desc_m_, arg.lse_grid_desc_m_,
arg.ygrad_grid_desc_m0_o_m1_, arg.ygrad_grid_desc_m0_o_m1_,
...@@ -1196,7 +1212,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1196,7 +1212,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const index_t b1_gemm1n = const index_t b1_gemm1n =
arg.b1_grid_desc_bk0_n_bk1_.GetLength(I0) * arg.b1_grid_desc_bk0_n_bk1_.GetLength(I2); arg.b1_grid_desc_bk0_n_bk1_.GetLength(I0) * arg.b1_grid_desc_bk0_n_bk1_.GetLength(I2);
if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n && c_g % b_g == 0 && b_g <= c_g)) if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n && c_g % b_g == 0 &&
b_g <= c_g))
{ {
return false; return false;
} }
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp" #include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v1.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v1_protro.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp" #include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
...@@ -87,8 +87,8 @@ __global__ void ...@@ -87,8 +87,8 @@ __global__ void
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetABasePtr(g_idx))); static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetABasePtr(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetBBasePtr(gkv_idx))); arg_ptr[group_id].compute_base_ptr_of_batch_.GetBBasePtr(gkv_idx)));
const long_index_t z_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t z_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetZBasePtr(g_idx))); static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetZBasePtr(g_idx)));
const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>( const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
...@@ -97,10 +97,12 @@ __global__ void ...@@ -97,10 +97,12 @@ __global__ void
static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetCBasePtr(g_idx))); static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetCBasePtr(g_idx)));
const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>( const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
arg_ptr[group_id].compute_base_ptr_of_batch_.GetLSEBasePtr(g_idx))); arg_ptr[group_id].compute_base_ptr_of_batch_.GetLSEBasePtr(g_idx)));
const long_index_t bgrad_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t bgrad_batch_offset =
static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetBGradBasePtr(g_idx))); __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
const long_index_t b1grad_batch_offset = __builtin_amdgcn_readfirstlane( arg_ptr[group_id].compute_base_ptr_of_batch_.GetBGradBasePtr(g_idx)));
static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetB1GradBasePtr(g_idx))); const long_index_t b1grad_batch_offset =
__builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
arg_ptr[group_id].compute_base_ptr_of_batch_.GetB1GradBasePtr(g_idx)));
const index_t global_thread_id = get_thread_global_1d_id(); const index_t global_thread_id = get_thread_global_1d_id();
ck::philox ph(seed, global_thread_id, offset); ck::philox ph(seed, global_thread_id, offset);
...@@ -146,9 +148,11 @@ __global__ void ...@@ -146,9 +148,11 @@ __global__ void
c_element_op, c_element_op,
arg_ptr[group_id].a_grid_desc_ak0_m_ak1_, arg_ptr[group_id].a_grid_desc_ak0_m_ak1_,
arg_ptr[group_id].b_grid_desc_bk0_n_bk1_, arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].bgrad_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].d0_grid_desc_m0_n0_m1_m2_n1_m3_, arg_ptr[group_id].d0_grid_desc_m0_n0_m1_m2_n1_m3_,
arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_, arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_, arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].b1grad_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].y_grid_desc_mblock_mperblock_oblock_operblock_, arg_ptr[group_id].y_grid_desc_mblock_mperblock_oblock_operblock_,
arg_ptr[group_id].lse_grid_desc_m_, arg_ptr[group_id].lse_grid_desc_m_,
arg_ptr[group_id].ygrad_grid_desc_o0_m_o1_, arg_ptr[group_id].ygrad_grid_desc_o0_m_o1_,
...@@ -185,9 +189,11 @@ __global__ void ...@@ -185,9 +189,11 @@ __global__ void
c_element_op, c_element_op,
arg_ptr[group_id].a_grid_desc_ak0_m_ak1_, arg_ptr[group_id].a_grid_desc_ak0_m_ak1_,
arg_ptr[group_id].b_grid_desc_bk0_n_bk1_, arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].bgrad_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].d0_grid_desc_m0_n0_m1_m2_n1_m3_, arg_ptr[group_id].d0_grid_desc_m0_n0_m1_m2_n1_m3_,
arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_, arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_, arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].b1grad_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].y_grid_desc_mblock_mperblock_oblock_operblock_, arg_ptr[group_id].y_grid_desc_mblock_mperblock_oblock_operblock_,
arg_ptr[group_id].lse_grid_desc_m_, arg_ptr[group_id].lse_grid_desc_m_,
arg_ptr[group_id].ygrad_grid_desc_o0_m_o1_, arg_ptr[group_id].ygrad_grid_desc_o0_m_o1_,
...@@ -520,7 +526,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -520,7 +526,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
static auto MakeD0GridDescriptor_M_N(const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths, static auto MakeD0GridDescriptor_M_N(const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides) const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides)
{ {
return Transform::MakeC0GridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths, return Transform::MakeC0GridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths,
acc0_bias_gs_ms_ns_strides); acc0_bias_gs_ms_ns_strides);
} }
...@@ -529,7 +534,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -529,7 +534,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
MakeD0GridDescriptor_G_M_N(const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths, MakeD0GridDescriptor_G_M_N(const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides) const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides)
{ {
return Transform::MakeC0GridDescriptor_G_M_N(acc0_bias_gs_ms_ns_lengths, return Transform::MakeC0GridDescriptor_G_M_N(acc0_bias_gs_ms_ns_lengths,
acc0_bias_gs_ms_ns_strides); acc0_bias_gs_ms_ns_strides);
} }
...@@ -736,9 +740,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -736,9 +740,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// tensor descriptors for block/thread-wise copy // tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
BGridDesc_BK0_N_BK1 bgrad_grid_desc_bk0_n_bk1_;
typename GridwiseGemm::D0GridDescriptor_M0_N0_M1_M2_N1_M3 d0_grid_desc_m0_n0_m1_m2_n1_m3_; typename GridwiseGemm::D0GridDescriptor_M0_N0_M1_M2_N1_M3 d0_grid_desc_m0_n0_m1_m2_n1_m3_;
ZGridDesc_M_N z_grid_desc_m_n_; ZGridDesc_M_N z_grid_desc_m_n_;
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_; B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
B1GridDesc_BK0_N_BK1 b1grad_grid_desc_bk0_n_bk1_;
YGridDesc_M_O y_grid_desc_m_o_; YGridDesc_M_O y_grid_desc_m_o_;
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
...@@ -871,6 +877,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -871,6 +877,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides); problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides);
const auto b_grid_desc_bk0_n_bk1 = DeviceOp::MakeBGridDescriptor_BK0_N_BK1( const auto b_grid_desc_bk0_n_bk1 = DeviceOp::MakeBGridDescriptor_BK0_N_BK1(
problem_desc.b_gs_ns_ks_lengths, problem_desc.b_gs_ns_ks_strides); problem_desc.b_gs_ns_ks_lengths, problem_desc.b_gs_ns_ks_strides);
const auto bgrad_grid_desc_bk0_n_bk1 = DeviceOp::MakeBGridDescriptor_BK0_N_BK1(
problem_desc.bgrad_gs_ns_ks_lengths, problem_desc.bgrad_gs_ns_ks_strides);
std::vector<index_t> tmp_d0_gs_ms_ns_lengths; std::vector<index_t> tmp_d0_gs_ms_ns_lengths;
std::vector<index_t> tmp_d0_gs_ms_ns_strides; std::vector<index_t> tmp_d0_gs_ms_ns_strides;
...@@ -893,6 +901,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -893,6 +901,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const auto b1_grid_desc_bk0_n_bk1 = DeviceOp::MakeVGridDescriptor_O0_N_O1( const auto b1_grid_desc_bk0_n_bk1 = DeviceOp::MakeVGridDescriptor_O0_N_O1(
problem_desc.b1_gs_gemm1ns_gemm1ks_lengths, problem_desc.b1_gs_gemm1ns_gemm1ks_lengths,
problem_desc.b1_gs_gemm1ns_gemm1ks_strides); problem_desc.b1_gs_gemm1ns_gemm1ks_strides);
const auto b1grad_grid_desc_bk0_n_bk1 = DeviceOp::MakeVGridDescriptor_O0_N_O1(
problem_desc.b1grad_gs_gemm1ns_gemm1ks_lengths,
problem_desc.b1grad_gs_gemm1ns_gemm1ks_strides);
const auto y_grid_desc_m_o = Transform::MakeCGridDescriptor_M_N( const auto y_grid_desc_m_o = Transform::MakeCGridDescriptor_M_N(
problem_desc.c_gs_ms_gemm1ns_lengths, problem_desc.c_gs_ms_gemm1ns_strides); problem_desc.c_gs_ms_gemm1ns_lengths, problem_desc.c_gs_ms_gemm1ns_strides);
...@@ -983,9 +994,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -983,9 +994,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
p_vgrad_grid, p_vgrad_grid,
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
bgrad_grid_desc_bk0_n_bk1,
d0_grid_desc_m0_n0_m1_m2_n1_m3, d0_grid_desc_m0_n0_m1_m2_n1_m3,
z_grid_desc_m_n, z_grid_desc_m_n,
b1_grid_desc_bk0_n_bk1, b1_grid_desc_bk0_n_bk1,
b1grad_grid_desc_bk0_n_bk1,
y_grid_desc_m_o, y_grid_desc_m_o,
y_grid_desc_mblock_mperblock_oblock_operblock, y_grid_desc_mblock_mperblock_oblock_operblock,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
...@@ -1186,7 +1199,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1186,7 +1199,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const index_t b1_gemm1n = kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I0) * const index_t b1_gemm1n = kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I0) *
kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I2); kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I2);
if(!(c_g == device_arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n && c_g % b_g == 0 && c_g / b_g == arg.h_ratio_)) if(!(c_g == device_arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n &&
c_g % b_g == 0 && c_g / b_g == arg.h_ratio_))
{ {
return false; return false;
} }
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp" #include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2_protro.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp" #include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
...@@ -87,8 +87,8 @@ __global__ void ...@@ -87,8 +87,8 @@ __global__ void
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetABasePtr(g_idx))); static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetABasePtr(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetBBasePtr(gkv_idx))); arg_ptr[group_id].compute_base_ptr_of_batch_.GetBBasePtr(gkv_idx)));
const long_index_t z_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t z_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetZBasePtr(g_idx))); static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetZBasePtr(g_idx)));
const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>( const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
...@@ -97,10 +97,12 @@ __global__ void ...@@ -97,10 +97,12 @@ __global__ void
static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetCBasePtr(g_idx))); static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetCBasePtr(g_idx)));
const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>( const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
arg_ptr[group_id].compute_base_ptr_of_batch_.GetLSEBasePtr(g_idx))); arg_ptr[group_id].compute_base_ptr_of_batch_.GetLSEBasePtr(g_idx)));
const long_index_t bgrad_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t bgrad_batch_offset =
static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetBGradBasePtr(g_idx))); __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
const long_index_t b1grad_batch_offset = __builtin_amdgcn_readfirstlane( arg_ptr[group_id].compute_base_ptr_of_batch_.GetBGradBasePtr(g_idx)));
static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetB1GradBasePtr(g_idx))); const long_index_t b1grad_batch_offset =
__builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
arg_ptr[group_id].compute_base_ptr_of_batch_.GetB1GradBasePtr(g_idx)));
const index_t global_thread_id = get_thread_global_1d_id(); const index_t global_thread_id = get_thread_global_1d_id();
ck::philox ph(seed, global_thread_id, offset); ck::philox ph(seed, global_thread_id, offset);
...@@ -145,9 +147,11 @@ __global__ void ...@@ -145,9 +147,11 @@ __global__ void
c_element_op, c_element_op,
arg_ptr[group_id].a_grid_desc_ak0_m_ak1_, arg_ptr[group_id].a_grid_desc_ak0_m_ak1_,
arg_ptr[group_id].b_grid_desc_bk0_n_bk1_, arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].bgrad_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].d0_grid_desc_m0_n0_m1_m2_n1_m3_, arg_ptr[group_id].d0_grid_desc_m0_n0_m1_m2_n1_m3_,
arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_, arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_, arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].b1grad_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].y_grid_desc_mblock_mperblock_oblock_operblock_, arg_ptr[group_id].y_grid_desc_mblock_mperblock_oblock_operblock_,
arg_ptr[group_id].lse_grid_desc_m_, arg_ptr[group_id].lse_grid_desc_m_,
arg_ptr[group_id].ygrad_grid_desc_m0_o_m1_, arg_ptr[group_id].ygrad_grid_desc_m0_o_m1_,
...@@ -184,9 +188,11 @@ __global__ void ...@@ -184,9 +188,11 @@ __global__ void
c_element_op, c_element_op,
arg_ptr[group_id].a_grid_desc_ak0_m_ak1_, arg_ptr[group_id].a_grid_desc_ak0_m_ak1_,
arg_ptr[group_id].b_grid_desc_bk0_n_bk1_, arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].bgrad_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].d0_grid_desc_m0_n0_m1_m2_n1_m3_, arg_ptr[group_id].d0_grid_desc_m0_n0_m1_m2_n1_m3_,
arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_, arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_, arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].b1grad_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].y_grid_desc_mblock_mperblock_oblock_operblock_, arg_ptr[group_id].y_grid_desc_mblock_mperblock_oblock_operblock_,
arg_ptr[group_id].lse_grid_desc_m_, arg_ptr[group_id].lse_grid_desc_m_,
arg_ptr[group_id].ygrad_grid_desc_m0_o_m1_, arg_ptr[group_id].ygrad_grid_desc_m0_o_m1_,
...@@ -582,7 +588,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -582,7 +588,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static auto MakeD0GridDescriptor_M_N(const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths, static auto MakeD0GridDescriptor_M_N(const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides) const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides)
{ {
return Transform::MakeC0GridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths, return Transform::MakeC0GridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths,
acc0_bias_gs_ms_ns_strides); acc0_bias_gs_ms_ns_strides);
} }
...@@ -591,7 +596,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -591,7 +596,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
MakeD0GridDescriptor_G_M_N(const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths, MakeD0GridDescriptor_G_M_N(const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides) const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides)
{ {
return Transform::MakeC0GridDescriptor_G_M_N(acc0_bias_gs_ms_ns_lengths, return Transform::MakeC0GridDescriptor_G_M_N(acc0_bias_gs_ms_ns_lengths,
acc0_bias_gs_ms_ns_strides); acc0_bias_gs_ms_ns_strides);
} }
...@@ -806,9 +810,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -806,9 +810,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// tensor descriptors for block/thread-wise copy // tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
BGridDesc_BK0_N_BK1 bgrad_grid_desc_bk0_n_bk1_;
typename GridwiseGemm::D0GridDescriptor_M0_N0_M1_M2_N1_M3 d0_grid_desc_m0_n0_m1_m2_n1_m3_; typename GridwiseGemm::D0GridDescriptor_M0_N0_M1_M2_N1_M3 d0_grid_desc_m0_n0_m1_m2_n1_m3_;
ZGridDesc_M_N z_grid_desc_m_n_; ZGridDesc_M_N z_grid_desc_m_n_;
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_; B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
B1GridDesc_BK0_N_BK1 b1grad_grid_desc_bk0_n_bk1_;
YGridDesc_M_O y_grid_desc_m_o_; YGridDesc_M_O y_grid_desc_m_o_;
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
...@@ -941,6 +947,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -941,6 +947,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides); problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides);
const auto b_grid_desc_bk0_n_bk1 = DeviceOp::MakeBGridDescriptor_BK0_N_BK1( const auto b_grid_desc_bk0_n_bk1 = DeviceOp::MakeBGridDescriptor_BK0_N_BK1(
problem_desc.b_gs_ns_ks_lengths, problem_desc.b_gs_ns_ks_strides); problem_desc.b_gs_ns_ks_lengths, problem_desc.b_gs_ns_ks_strides);
const auto bgrad_grid_desc_bk0_n_bk1 = DeviceOp::MakeBGridDescriptor_BK0_N_BK1(
problem_desc.bgrad_gs_ns_ks_lengths, problem_desc.bgrad_gs_ns_ks_strides);
std::vector<index_t> tmp_d0_gs_ms_ns_lengths; std::vector<index_t> tmp_d0_gs_ms_ns_lengths;
std::vector<index_t> tmp_d0_gs_ms_ns_strides; std::vector<index_t> tmp_d0_gs_ms_ns_strides;
...@@ -963,6 +971,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -963,6 +971,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const auto b1_grid_desc_bk0_n_bk1 = DeviceOp::MakeVGridDescriptor_O0_N_O1( const auto b1_grid_desc_bk0_n_bk1 = DeviceOp::MakeVGridDescriptor_O0_N_O1(
problem_desc.b1_gs_gemm1ns_gemm1ks_lengths, problem_desc.b1_gs_gemm1ns_gemm1ks_lengths,
problem_desc.b1_gs_gemm1ns_gemm1ks_strides); problem_desc.b1_gs_gemm1ns_gemm1ks_strides);
const auto b1grad_grid_desc_bk0_n_bk1 = DeviceOp::MakeVGridDescriptor_O0_N_O1(
problem_desc.b1grad_gs_gemm1ns_gemm1ks_lengths,
problem_desc.b1grad_gs_gemm1ns_gemm1ks_strides);
const auto y_grid_desc_m_o = Transform::MakeCGridDescriptor_M_N( const auto y_grid_desc_m_o = Transform::MakeCGridDescriptor_M_N(
problem_desc.c_gs_ms_gemm1ns_lengths, problem_desc.c_gs_ms_gemm1ns_strides); problem_desc.c_gs_ms_gemm1ns_lengths, problem_desc.c_gs_ms_gemm1ns_strides);
...@@ -1053,9 +1064,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1053,9 +1064,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
p_vgrad_grid, p_vgrad_grid,
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
bgrad_grid_desc_bk0_n_bk1,
d0_grid_desc_m0_n0_m1_m2_n1_m3, d0_grid_desc_m0_n0_m1_m2_n1_m3,
z_grid_desc_m_n, z_grid_desc_m_n,
b1_grid_desc_bk0_n_bk1, b1_grid_desc_bk0_n_bk1,
b1grad_grid_desc_bk0_n_bk1,
y_grid_desc_m_o, y_grid_desc_m_o,
y_grid_desc_mblock_mperblock_oblock_operblock, y_grid_desc_mblock_mperblock_oblock_operblock,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
...@@ -1255,7 +1268,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1255,7 +1268,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const index_t b1_gemm1n = kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I0) * const index_t b1_gemm1n = kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I0) *
kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I2); kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I2);
if(!(c_g == device_arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n && c_g % b_g == 0 && c_g / b_g == arg.h_ratio_)) if(!(c_g == device_arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n &&
c_g % b_g == 0 && c_g / b_g == arg.h_ratio_))
{ {
return false; return false;
} }
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/utility/philox_rand.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_softmax.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_dropout.hpp"
namespace ck {
template <typename InputDataType,
typename D0DataType,
typename OutputDataType,
typename ZDataType,
typename GemmDataType,
typename FloatGemmAcc,
typename FloatCShuffle,
typename FloatLSE,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename SElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename QGridDesc_K0_M_K1,
typename KGridDesc_K0_N_K1,
typename KGridDesc_N_K,
typename D0GridDesc_M_N,
typename ZGridDesc_M_N,
typename VGridDesc_O0_N_O1,
typename YGridDesc_M_O,
typename LSEGridDesc_M,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t Gemm1NPerBlock,
index_t Gemm1KPerBlock,
index_t Gemm2KPerBlock,
index_t AK1Value,
index_t BK1Value,
index_t B1K1Value,
index_t MPerXdl,
index_t NPerXdl,
index_t MXdlPerWave,
index_t NXdlPerWave,
index_t Gemm1NXdlPerWave,
index_t Gemm2NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool AThreadTransferSrcResetCoordinateAfterRun, // ignored
index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BThreadTransferSrcResetCoordinateAfterRun, // ignored
index_t BBlockLdsExtraN,
index_t D0BlockTransferSrcScalarPerVector,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched,
bool PadN,
bool MaskOutUpperTriangle,
bool Deterministic,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
{
static_assert(KPerBlock == Gemm1NPerBlock);
static_assert(MPerBlock % Gemm1KPerBlock == 0);
static_assert(NPerBlock % Gemm2KPerBlock == 0);
static_assert(LoopSched == LoopScheduler::Default,
"Non-default loop scheduler is currently not supported");
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
static constexpr auto WaveSize = 64;
// K1 should be Number<...>
// Gemm0
static constexpr auto AK0 = Number<KPerBlock / AK1Value>{};
static constexpr auto BK0 = Number<KPerBlock / BK1Value>{};
static constexpr auto AK1 = Number<AK1Value>{};
static constexpr auto BK1 = Number<BK1Value>{};
static constexpr auto Gemm0MWaves = MPerBlock / (MPerXdl * MXdlPerWave);
static constexpr auto Gemm0NWaves = NPerBlock / (NPerXdl * NXdlPerWave);
static constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
static constexpr auto V_K2 = BK1;
static constexpr auto V_K1 = mfma.num_input_blks;
static constexpr auto V_K0 = KPerBlock / V_K1 / V_K2;
static constexpr auto V_N1 = NXdlPerWave;
static constexpr auto DropoutNThread = mfma.num_input_blks; // 2
// get_random_16x8() generates 16 random numbers each time
static constexpr auto DropoutTile = Number<DropoutNThread * 16>{}; // 32
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
// C desc for source in blockwise copy
__host__ __device__ static constexpr auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(const ZGridDesc_M_N& z_grid_desc_m_n)
{
const auto M = z_grid_desc_m_n.GetLength(I0);
const auto N = z_grid_desc_m_n.GetLength(I1);
constexpr auto M3 = mfma.num_groups_per_blk;
constexpr auto M4 = mfma.num_input_blks;
constexpr auto M5 = mfma.group_size;
return transform_tensor_descriptor(
z_grid_desc_m_n,
make_tuple(make_unmerge_transform(
make_tuple(M / MPerBlock, MXdlPerWave, Gemm0MWaves, M3, M4, M5)),
make_unmerge_transform(
make_tuple(N / NPerBlock, NXdlPerWave, Gemm0NWaves, NPerXdl))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 6, 7, 8>{}, Sequence<1, 3, 5, 9>{}));
}
__host__ __device__ static constexpr auto GetPaddedSize(const index_t size)
{
return math::integer_divide_ceil(size, DropoutTile) * DropoutTile;
}
__device__ static auto GetGemm0WaveIdx()
{
const index_t thread_id = get_thread_local_1d_id();
constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(Gemm0MWaves, Gemm0NWaves, WaveSize))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
__device__ static auto GetGemm0WaveMNIdx(const index_t thread_id)
{
constexpr auto wave_threadid_to_mn_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(WaveSize / MPerXdl, MPerXdl))),
make_tuple(Sequence<0, 1>{}),
make_tuple(Sequence<0>{}));
return wave_threadid_to_mn_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
// A matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(AK0, Number<MPerBlock>{}, AK1),
make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1, AK1, I1));
}
__host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
{
// B matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(BK0, Number<NPerBlock>{}, BK1),
make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1));
}
__host__ __device__ static constexpr auto GetVThreadDescriptor_K0_K1_N0_N1_N2_N3_K2()
{
// V matrix in Vgpr, dst of threadwise copy
return make_naive_tensor_descriptor_packed(
make_tuple(Number<V_K0>{}, I1, I1, Number<V_N1>{}, I1, I1, Number<V_K2>{}));
}
template <typename AccThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2>
__host__ __device__ static constexpr auto GetA1SrcThreadDescriptor_AK0PerBlock_MPerBlock_AK1(
const AccThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2& acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2)
{
// acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 to a_src_thread_desc_k0_m_k1
// m0_m1_m2_m3 -> k0
// n0_n1_n2 -> m
// m4 -> k1
// NOTE: had to use merge_v3 or will spit out compilation errors
const auto m0 = acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0);
const auto n0 = acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1);
const auto m1 = acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2);
const auto n1 = acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3);
const auto m2 = acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4);
const auto m3 = acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5);
const auto m4 = acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6);
const auto n2 = acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7);
return transform_tensor_descriptor(
acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_tuple(make_merge_transform_v3_division_mod(make_tuple(m0, m1, m2, m3)),
make_merge_transform_v3_division_mod(make_tuple(n0, n1, n2)),
make_pass_through_transform(m4)),
make_tuple(Sequence<0, 2, 4, 5>{}, Sequence<1, 3, 7>{}, Sequence<6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
}
__host__ __device__ static constexpr auto
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
{
constexpr index_t MWave = Gemm0NWaves;
constexpr index_t NWave = Gemm0MWaves;
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl>{},
I1,
Number<CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>{}));
return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
}
template <typename Gemm2Param>
__host__ __device__ static constexpr auto GetA2BlockDescriptor_K0_M_K1()
{
return make_naive_tensor_descriptor(
make_tuple(Number<Gemm2Param::A_K0>{},
Number<Gemm2Param::Gemm2_M>{},
Number<Gemm2Param::A_K1>{}),
make_tuple(Number<Gemm2Param::Gemm2_M + Gemm2Param::A_LdsPad>{} *
Number<Gemm2Param::A_K1>{},
Number<Gemm2Param::A_K1>{},
I1));
}
__host__ __device__ static constexpr bool
CheckValidity(const QGridDesc_K0_M_K1& q_grid_desc_k0_m_k1,
const KGridDesc_K0_N_K1& k_grid_desc_k0_n_k1,
const VGridDesc_O0_N_O1& v_grid_desc_o0_n_o1,
const YGridDesc_M_O& y_grid_desc_m_o)
{
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
"Invalid tuning param!");
const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
const auto K = q_grid_desc_k0_m_k1.GetLength(I0) * q_grid_desc_k0_m_k1.GetLength(I2);
const auto O = v_grid_desc_o0_n_o1.GetLength(I0) * v_grid_desc_o0_n_o1.GetLength(I2);
// This assumption reduces implemention complexity by categorizing 6 separate GEMMs into 3
// types of GEMM operations, therefore some code body can be reused accordingly
// P_MNK / dP_MNO Gemm (Gemm0 rcr)
// Y_MON / dQ_MKN Gemm (Gemm1 rrr)
// dV_NOM / dK_NKM Gemm (Gemm2 crr)
if(O != K)
{
std::cerr << "SizeK must be equal to SizeO (equal attention head size)" << '\n';
return false;
}
if(!(M == y_grid_desc_m_o.GetLength(I0) && O == y_grid_desc_m_o.GetLength(I1)))
{
return false;
}
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0 &&
O % Gemm1NPerBlock == 0))
{
return false;
}
// check gemm1 gridwise gemm pipeline
if(!(NPerBlock % Gemm1KPerBlock == 0))
{
return false;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true;
}
__host__ __device__ static constexpr auto
MakeYGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock(const YGridDesc_M_O& y_grid_desc_m_o)
{
const auto M = y_grid_desc_m_o.GetLength(I0);
const auto O = y_grid_desc_m_o.GetLength(I1);
const auto MBlock = M / MPerBlock;
const auto OBlock = O / Gemm1NPerBlock;
const auto y_grid_desc_mblock_mperblock_oblock_operblock = transform_tensor_descriptor(
y_grid_desc_m_o,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
make_unmerge_transform(make_tuple(OBlock, Number<Gemm1NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
return y_grid_desc_mblock_mperblock_oblock_operblock;
}
template <typename SrcBlockwiseGemm>
__host__ __device__ static constexpr auto
MakeLSEGridDescriptor_MB_M0_M1_M2_M3_M4(const LSEGridDesc_M& lse_grid_desc_m)
{
const index_t M = lse_grid_desc_m.GetLength(I0);
const index_t MBlock = M / MPerBlock;
constexpr auto SrcBlockDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
SrcBlockwiseGemm::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
// M0 MXdlPerWave, M1 MWave, M2 num_groups_per_blk, M3 num_input_blks, M4 group_size
const auto M0 = SrcBlockDesc_M0_N0_M1_N1_M2_M3_M4_N2.GetLength(I0);
const auto M1 = SrcBlockDesc_M0_N0_M1_N1_M2_M3_M4_N2.GetLength(I2);
const auto M2 = SrcBlockDesc_M0_N0_M1_N1_M2_M3_M4_N2.GetLength(I4);
const auto M3 = SrcBlockDesc_M0_N0_M1_N1_M2_M3_M4_N2.GetLength(I5);
const auto M4 = SrcBlockDesc_M0_N0_M1_N1_M2_M3_M4_N2.GetLength(I6);
const auto lse_grid_desc_mb_m0_m1_m2_m3_m4 = transform_tensor_descriptor(
lse_grid_desc_m,
make_tuple(make_unmerge_transform(make_tuple(MBlock, M0, M1, M2, M3, M4))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2, 3, 4, 5>{}));
return lse_grid_desc_mb_m0_m1_m2_m3_m4;
}
__device__ static auto MakeKGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock(
const KGridDesc_K0_N_K1& k_grid_desc_k0_n_k1)
{
const auto O0 = k_grid_desc_k0_n_k1.GetLength(I0);
const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
const auto O1 = k_grid_desc_k0_n_k1.GetLength(I2);
const auto O = O0 * O1;
const auto NBlock = N / NPerBlock;
const auto OBlock = O / Gemm1NPerBlock;
const auto k_grid_desc_n_o = transform_tensor_descriptor(
k_grid_desc_k0_n_k1,
make_tuple(make_pass_through_transform(N),
make_merge_transform_v3_division_mod(make_tuple(O0, O1))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return transform_tensor_descriptor(
k_grid_desc_n_o,
make_tuple(make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{})),
make_unmerge_transform(make_tuple(OBlock, Number<Gemm1NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
}
__device__ static auto MakeVGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock(
const VGridDesc_O0_N_O1& v_grid_desc_o0_n_o1)
{
const auto O0 = v_grid_desc_o0_n_o1.GetLength(I0);
const auto N = v_grid_desc_o0_n_o1.GetLength(I1);
const auto O1 = v_grid_desc_o0_n_o1.GetLength(I2);
const auto O = O0 * O1;
const auto NBlock = N / NPerBlock;
const auto OBlock = O / Gemm1NPerBlock;
const auto v_grid_desc_n_o = transform_tensor_descriptor(
v_grid_desc_o0_n_o1,
make_tuple(make_pass_through_transform(N),
make_merge_transform_v3_division_mod(make_tuple(O0, O1))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return transform_tensor_descriptor(
v_grid_desc_n_o,
make_tuple(make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{})),
make_unmerge_transform(make_tuple(OBlock, Number<Gemm1NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
}
__device__ static auto MakeQGradGridDesc_M_K(const QGridDesc_K0_M_K1& q_grid_desc_k0_m_k1)
{
const auto K_K0 = q_grid_desc_k0_m_k1.GetLength(I0);
const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
const auto K_K1 = q_grid_desc_k0_m_k1.GetLength(I2);
return transform_tensor_descriptor(
q_grid_desc_k0_m_k1,
make_tuple(make_pass_through_transform(M),
make_merge_transform_v3_division_mod(make_tuple(K_K0, K_K1))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
// return block_id to C matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto
MakeDefaultBlock2CTileMap(const KGridDesc_N_K& k_grid_desc_n_k)
{
return BlockToCTileMap_M00_N0_M01Adapt<NPerBlock, KPerBlock, KGridDesc_N_K>(
k_grid_desc_n_k);
}
using YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock = remove_cvref_t<decltype(
MakeYGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock(YGridDesc_M_O{}))>;
using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(KGridDesc_N_K{}))>;
using ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3 = remove_cvref_t<decltype(
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(ZGridDesc_M_N{}))>;
// Q / K / V / dY
struct GemmBlockwiseCopy
{
__device__ static auto
MakeVGridDescriptor_K0_K1_N0_N1_N2_N3_K2(const VGridDesc_O0_N_O1& v_grid_desc_o0_n_o1)
{
const auto K0_ = v_grid_desc_o0_n_o1.GetLength(I0);
const auto N_ = v_grid_desc_o0_n_o1.GetLength(I1);
const auto K1_ = v_grid_desc_o0_n_o1.GetLength(I2);
constexpr auto V_N3 = NPerXdl;
constexpr auto V_N2 = Gemm0NWaves;
const auto V_N0 = N_ / NPerBlock;
const auto v_grid_desc_n_k = transform_tensor_descriptor(
v_grid_desc_o0_n_o1,
make_tuple(make_pass_through_transform(N_),
make_merge_transform_v3_division_mod(make_tuple(K0_, K1_))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return transform_tensor_descriptor(
v_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(V_N0, V_N1, V_N2, V_N3)),
make_unmerge_transform(make_tuple(V_K0, V_K1, V_K2))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<2, 3, 4, 5>{}, Sequence<0, 1, 6>{}));
}
// Q matrix in LDS memory, dst of blockwise copy
static constexpr auto q_block_desc_k0_m_k1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// K matrix in LDS memory, dst of blockwise copy
static constexpr auto k_block_desc_k0_n_k1 =
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// V matrix in Vgpr, dst of threadwise copy
static constexpr auto v_thread_desc_k0_k1_n0_n1_n2_n3_k2 =
GetVThreadDescriptor_K0_K1_N0_N1_N2_N3_K2();
// dY matrix in LDS memory, dst of blockwise copy
static constexpr auto ygrad_block_desc_k0_m_k1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
template <typename GridDesc_K0_M_K1>
using QBlockwiseCopy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
tensor_operation::element_wise::PassThrough,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0, MPerBlock, AK1>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
InputDataType,
GemmDataType,
GridDesc_K0_M_K1,
decltype(q_block_desc_k0_m_k1),
ABlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
true, // SrcResetCoord
true, // DstResetCoord
NumGemmKPrefetchStage>;
template <typename GridDesc_K0_N_K1>
using KBlockwiseCopy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
tensor_operation::element_wise::PassThrough,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<BK0, NPerBlock, BK1>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
InputDataType,
GemmDataType,
GridDesc_K0_N_K1,
decltype(k_block_desc_k0_n_k1),
BBlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
1,
true, // SrcResetCoord
true, // DstResetCoord
NumGemmKPrefetchStage>;
template <typename GridDesc_K0_K1_N0_N1_N2_N3_K2>
using VBlockwiseCopy =
ThreadwiseTensorSliceTransfer_v2<InputDataType,
GemmDataType,
GridDesc_K0_K1_N0_N1_N2_N3_K2,
decltype(v_thread_desc_k0_k1_n0_n1_n2_n3_k2),
decltype(
v_thread_desc_k0_k1_n0_n1_n2_n3_k2.GetLengths()),
Sequence<0, 1, 2, 3, 4, 5, 6>,
6,
BK1,
1,
true /* ResetCoordAfterRun */>;
template <typename GridDesc_K0_M_K1>
using YGradBlockwiseCopy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
tensor_operation::element_wise::PassThrough,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0, MPerBlock, AK1>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
InputDataType,
GemmDataType,
GridDesc_K0_M_K1,
decltype(ygrad_block_desc_k0_m_k1),
ABlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
true, // SrcResetCoord
true, // DstResetCoord
NumGemmKPrefetchStage>;
static constexpr auto gemm_tile_q_block_slice_copy_step =
make_multi_index(0, -MPerBlock, 0);
static constexpr auto gemm_tile_ygrad_block_slice_copy_step =
make_multi_index(0, -MPerBlock, 0);
};
// dP Gemm (type 1 rcc)
template <typename BSrcThreadDesc_K0_K1_N0_N1_N2_N3_K2>
struct Gemm0
{
// A matrix in LDS memory, dst of blockwise copy
static constexpr auto a_block_desc_ak0_m_ak1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
template <typename BThreadDesc_K0_K1_N0_N1_N2_N3_K2>
__host__ __device__ static constexpr auto GetBThreadDescriptor_K0_N_K1(
const BThreadDesc_K0_K1_N0_N1_N2_N3_K2& b_thread_desc_k0_k1_n0_n1_n2_n3_k2)
{
// b_thread_desc_k0_k1_n0_n1_n2_n3_k2 to b_thread_desc_k0_n_k1
// k0_k1 -> k0
// n0_n1_n2_n3 -> n
// k2 -> k1
const auto k0 = b_thread_desc_k0_k1_n0_n1_n2_n3_k2.GetLength(I0);
const auto k1 = b_thread_desc_k0_k1_n0_n1_n2_n3_k2.GetLength(I1);
const auto n0 = b_thread_desc_k0_k1_n0_n1_n2_n3_k2.GetLength(I2);
const auto n1 = b_thread_desc_k0_k1_n0_n1_n2_n3_k2.GetLength(I3);
const auto n2 = b_thread_desc_k0_k1_n0_n1_n2_n3_k2.GetLength(I4);
const auto n3 = b_thread_desc_k0_k1_n0_n1_n2_n3_k2.GetLength(I5);
const auto k2 = b_thread_desc_k0_k1_n0_n1_n2_n3_k2.GetLength(I6);
return transform_tensor_descriptor(
b_thread_desc_k0_k1_n0_n1_n2_n3_k2,
make_tuple(make_merge_transform_v3_division_mod(make_tuple(k0, k1)),
make_merge_transform_v3_division_mod(make_tuple(n0, n1, n2, n3)),
make_pass_through_transform(k2)),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3, 4, 5>{}, Sequence<6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
}
static constexpr auto b_src_thread_desc_k0_n_k1 =
GetBThreadDescriptor_K0_N_K1(BSrcThreadDesc_K0_K1_N0_N1_N2_N3_K2{});
template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto
MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
{
constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<MXdlPerWave, MWaves, MPerXdl>(
ABlockDesc_AK0_M_AK1{});
}
template <typename BBlockDesc_BK0_N_BK1>
__host__ __device__ static constexpr auto
MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
{
return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<NXdlPerWave, 1, 1>(
BBlockDesc_BK0_N_BK1{});
}
static constexpr index_t KPack = math::max(math::lcm(AK1, BK1), mfma.k_per_blk);
// Blockwise gemm with transposed XDL output
using BlockwiseGemm = BlockwiseGemmXdlops_v2<
BlockSize,
GemmDataType,
FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_src_thread_desc_k0_n_k1),
decltype(MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(a_block_desc_ak0_m_ak1)),
decltype(MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(b_src_thread_desc_k0_n_k1)),
MPerBlock,
NPerBlock,
KPerBlock,
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
KPack,
false,
KPack * XdlopsGemm<GemmDataType, MPerXdl, NPerXdl, KPack, false>{}.K0PerXdlops,
KPack>;
};
// dV / dK Gemm (type 2 rrr)
template <typename ASrcThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2,
typename ASrcBlockDesc_M0_N0_M1_N1_M2_M3_M4_N2>
struct Gemm1
{
private:
static constexpr auto m0 = ASrcThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2{}.GetLength(I0);
static constexpr auto n0 = ASrcThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2{}.GetLength(I1);
static constexpr auto m1 = ASrcThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2{}.GetLength(I2);
static constexpr auto n1 = ASrcThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2{}.GetLength(I3);
static constexpr auto m2 = ASrcThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2{}.GetLength(I4);
static constexpr auto m3 = ASrcThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2{}.GetLength(I5);
static constexpr auto m4 = ASrcThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2{}.GetLength(I6);
static constexpr auto n2 = ASrcThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2{}.GetLength(I7);
// M2 num_groups_per_blk, M3 num_input_blks, M4 group_size
static constexpr auto M3 = ASrcBlockDesc_M0_N0_M1_N1_M2_M3_M4_N2{}.GetLength(I5);
public:
static constexpr auto AThreadSliceLength_K0 = Number<Gemm1KPerBlock / m4 / M3>{};
static constexpr auto AThreadSliceLength_M = Number<n0 * n1 * n2>{};
static constexpr auto AThreadSliceLength_K1 = Number<m4>{};
// A source matrix layout in AccVGPR
static constexpr auto a_src_thread_desc_k0_m_k1 =
GetA1SrcThreadDescriptor_AK0PerBlock_MPerBlock_AK1(
ASrcThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2{});
// A matrix in VGPR memory, dst of AccVGPR-to-VGPR copy
static constexpr auto a_thread_desc_k0_m_k1 = make_naive_tensor_descriptor_packed(
make_tuple(AThreadSliceLength_K0, AThreadSliceLength_M, AThreadSliceLength_K1));
// B matrix in LDS memory, dst of blockwise copy
static constexpr auto b_block_desc_bn0_k_bn1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto
MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
{
return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<NXdlPerWave, 1, 1>(
ABlockDesc_AK0_M_AK1{});
}
template <typename BBlockDesc_BK0_N_BK1>
__host__ __device__ static constexpr auto
MakeGemm1BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
{
return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<Gemm1NXdlPerWave, 1, 1>(
BBlockDesc_BK0_N_BK1{});
}
static constexpr auto ASrcScalarPerVector = m4;
using AThreadSliceLengths_K0_M_K1 = decltype(a_thread_desc_k0_m_k1.GetLengths());
template <typename ElementwiseOp = tensor_operation::element_wise::PassThrough>
using ABlockwiseCopy =
ThreadwiseTensorSliceTransfer_StaticToStatic<FloatGemmAcc,
GemmDataType,
decltype(a_src_thread_desc_k0_m_k1),
decltype(a_thread_desc_k0_m_k1),
ElementwiseOp,
AThreadSliceLengths_K0_M_K1,
Sequence<1, 0, 2>,
2,
ASrcScalarPerVector>;
// for a_block_slice_copy_step to be able to address static buffers, it MUST be a
// tuple-based container as well as containing ONLY integral constants
static constexpr auto a_block_slice_copy_step = make_tuple(AThreadSliceLength_K0, I0, I0);
// selected_mfma.group_size or B1K1 <= Gemm1KPack <= selected_mfma.group_size
// selected_mfma.k_per_blk <= Gemm1KPack
//
// Following similar rationale behind Gemm0KPack, let Gemm1KPack be the lowest common
// multiples of A1K1 (predetermined by selected_mfma.group_size) and B1K1. But in this case
// Gemm1KPack can't be higher than A1K1 itself because A1 matrix is distributed in VGPRs
// with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
// cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
// therefore we may just as well assign Gemm1KPack = group_size
static constexpr index_t GemmKPack = mfma.group_size;
static constexpr index_t GemmMWave = Gemm0NWaves; // 4 // 4
static constexpr index_t GemmNWave = Gemm0MWaves; // 1 // 1
static constexpr index_t GemmMRepeat = NXdlPerWave; // 1 // 1
static constexpr index_t GemmNRepeat = Gemm1NXdlPerWave; // 1 // 2
static constexpr index_t GemmKLoop = MPerBlock / Gemm1KPerBlock; // 128/32=4 // 64/32=2
static constexpr index_t B_K3 = GemmKPack; // 4 // 4
static constexpr index_t B_K2 = M3; // 2 // 2
static constexpr index_t B_K1 = Gemm1KPerBlock / B_K2 / B_K3; // 4 // 4
static constexpr index_t B_K0 = GemmKLoop; // 4 // 2
__host__ __device__ static constexpr auto MakeBBlockDesc_N0_N1_N2_K0_K1_K2_K3()
{
const auto N0_ = b_block_desc_bn0_k_bn1.GetLength(I0);
const auto K_ = b_block_desc_bn0_k_bn1.GetLength(I1);
const auto N1_ = b_block_desc_bn0_k_bn1.GetLength(I2);
constexpr auto b_block_desc_n_k = transform_tensor_descriptor( //(32, 128) //(64, 64)
b_block_desc_bn0_k_bn1,
make_tuple(
make_merge_transform_v3_division_mod(make_tuple(N0_, N1_)), //(4, 8) //(8, 8)
make_pass_through_transform(K_)), // 128 // 64
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return transform_tensor_descriptor(
b_block_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(
GemmNRepeat, GemmNWave, NPerXdl)), //(1, 1, 32) //(2, 1, 32)
make_unmerge_transform(
make_tuple(B_K0, B_K1, B_K2, B_K3))), //(4, 4, 2, 4) //(2, 4, 2, 4)
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5, 6>{}));
}
static constexpr auto b_block_desc_n0_n1_n2_k0_k1_k2_k3 =
MakeBBlockDesc_N0_N1_N2_K0_K1_K2_K3();
using BThreadSlice_N0_N1_N2_K0_K1_K2_K3 = Sequence<GemmNRepeat, 1, 1, 1, B_K1, 1, B_K3>;
static constexpr auto b_thread_desc_n0_n1_n2_k0_k1_k2_k3 =
make_naive_tensor_descriptor_packed(
make_tuple(Number<GemmNRepeat>{}, I1, I1, I1, Number<B_K1>{}, I1, Number<B_K3>{}));
__host__ __device__ static constexpr auto MakeBThreadDesc_K0_N_K1()
{
constexpr auto b_thread_desc_n_k = transform_tensor_descriptor(
b_thread_desc_n0_n1_n2_k0_k1_k2_k3,
make_tuple(
make_merge_transform_v3_division_mod(make_tuple(Number<GemmNRepeat>{}, I1, I1)),
make_merge_transform_v3_division_mod(
make_tuple(I1, Number<B_K1>{}, I1, Number<B_K3>{}))),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5, 6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return transform_tensor_descriptor(
b_thread_desc_n_k,
make_tuple(make_pass_through_transform(Number<GemmNRepeat>{}),
make_unmerge_transform(make_tuple(Number<B_K1>{}, Number<B_K3>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
}
static constexpr auto b_thread_desc_k0_n_k1 = MakeBThreadDesc_K0_N_K1();
using BBlockwiseCopy =
ThreadwiseTensorSliceTransfer_v2<GemmDataType,
GemmDataType,
decltype(b_block_desc_n0_n1_n2_k0_k1_k2_k3),
decltype(b_thread_desc_n0_n1_n2_k0_k1_k2_k3),
BThreadSlice_N0_N1_N2_K0_K1_K2_K3,
Sequence<0, 1, 2, 3, 4, 5, 6>,
6,
1,
1,
true>;
static constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, 0, 1, 0, 0, 0);
static constexpr auto b_block_reset_copy_step = make_multi_index(0, 0, 0, -B_K0, 0, 0, 0);
using BlockwiseGemm = BlockwiseGemmXdlops_v2<
BlockSize,
GemmDataType,
FloatGemmAcc,
decltype(a_thread_desc_k0_m_k1),
decltype(b_thread_desc_k0_n_k1),
decltype(MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(a_thread_desc_k0_m_k1)),
decltype(MakeGemm1BMmaTileDescriptor_N0_N1_N2_K(b_thread_desc_k0_n_k1)),
NPerBlock,
Gemm1NPerBlock,
Gemm1KPerBlock,
MPerXdl,
NPerXdl,
NXdlPerWave,
Gemm1NXdlPerWave,
GemmKPack,
true, // TransposeC
GemmKPack, // AMmaKStride
GemmKPack>;
};
// dQ Gemm (type 3 crr)
// Describes tuning parameter for C2_m_n = A2_m_k * B2_k_n
struct Gemm2Params
{
static constexpr index_t Gemm2_M = MPerBlock; // 128 // 64
static constexpr index_t Gemm2_K = NPerBlock; // 128 // 128
static constexpr index_t Gemm2_N = Gemm1NPerBlock; // 32 // 64
static constexpr index_t Sum_K = Gemm2KPerBlock;
static constexpr index_t A_K1 = 8; // dS will be row-major
static constexpr index_t A_K0 = Sum_K / A_K1;
static constexpr index_t A_LdsPad = 0; // how many multiples of K1 per M * K1 elements
static_assert(Sum_K % NPerXdl == 0, "");
static constexpr index_t GemmNWave = Gemm2_N / Gemm2NXdlPerWave / NPerXdl; // 1 // 2
static constexpr index_t GemmMWave = BlockSize / get_warp_size() / GemmNWave; // 4 // 2
static constexpr index_t GemmNRepeat = Gemm2NXdlPerWave; // 1 // 1
static constexpr index_t GemmMRepeat = Gemm2_M / GemmMWave / MPerXdl; // 1 // 1
static constexpr index_t GemmKLoop = Gemm2_K / Sum_K; // 2 // 2
static constexpr index_t GemmKPack = math::max(A_K1, mfma.k_per_blk);
static constexpr index_t B_K3 = GemmKPack; // 8
static constexpr index_t B_K2 =
XdlopsGemm<GemmDataType, MPerXdl, NPerXdl, GemmKPack, false>{}.K0PerXdlops; // 2
static constexpr index_t B_K1 = Sum_K / B_K2 / B_K3; // 4
static constexpr index_t B_K0 = GemmKLoop; // 2
__host__ __device__ static constexpr auto GetABlockSliceLengths_M0_K0_M1_K1_M2_K2()
{
// perform manual unmerge: n -> n_repeat, n_waves, n_per_xdl
constexpr index_t k = Sum_K - 1;
constexpr index_t k2 = k % NPerXdl;
constexpr index_t k1 = k / NPerXdl % Gemm0NWaves;
constexpr index_t k0 = k / NPerXdl / Gemm0NWaves % NXdlPerWave;
// perform manual unmerge: m -> m_repeat, m_waves, m_per_xdl
constexpr index_t m = Gemm2_M - 1;
constexpr index_t m2 = m % MPerXdl;
constexpr index_t m1 = m / MPerXdl % Gemm0MWaves;
constexpr index_t m0 = m / MPerXdl / Gemm0MWaves % MXdlPerWave;
// assume 256 decomposed into 2 x 4 x 32
// 1d idx ( 32 - 1) -> 3d idx 0, 0, 31 -> 3d dim 1 x 1 x 32
// 1d idx (256 - 1) -> 3d idx 1, 3, 31 -> 3d dim 2 x 4 x 32
return Sequence<m0, k0, m1, k1, m2, k2>{} + Sequence<1, 1, 1, 1, 1, 1>{};
}
__host__ __device__ static constexpr auto GetABlockSliceLengths_M0_K0_M1_K1()
{
return generate_sequence_v2(
[](auto I) { return GetABlockSliceLengths_M0_K0_M1_K1_M2_K2().At(I); },
Number<4>{});
}
using ABlockSliceLengths_M0_K0_M1_K1 =
decltype(GetABlockSliceLengths_M0_K0_M1_K1()); //(2, 1, 1, 2) //(4, 1, 1, 2)
};
// dQ Gemm (type 3 crr)
template <typename Gemm2Params, typename ASrcBlockwiseGemm>
struct Gemm2
{
private:
static constexpr auto a_src_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
ASrcBlockwiseGemm::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
static constexpr auto M0 = a_src_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0); // repeat
static constexpr auto N0 = a_src_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1);
static constexpr auto M1 = a_src_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2); // wave
static constexpr auto N1 = a_src_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3);
static constexpr auto M2 = a_src_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4); // xdl
static constexpr auto M3 = a_src_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5);
static constexpr auto M4 = a_src_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6);
static constexpr auto N2 = a_src_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7);
public:
// A source matrix layout in VGPR, src of VGPR-to-LDS copy
static constexpr auto a_src_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
ASrcBlockwiseGemm::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
// A matrix in LDS memory, dst of blockwise copy
static constexpr auto a_block_desc_k0_m_k1 = GetA2BlockDescriptor_K0_M_K1<Gemm2Params>();
// // B matrix in LDS memory, dst of blockwise copy
static constexpr auto b_block_desc_n0_k_n1 =
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
template <typename ABlockDesc_K0_M_K1>
__host__ __device__ static constexpr auto
MakeGemm2AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_K0_M_K1&)
{
return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<Gemm2Params::GemmMRepeat,
Gemm2Params::GemmMWave,
MPerXdl>(ABlockDesc_K0_M_K1{});
}
template <typename BBlockDesc_K0_N_K1>
__host__ __device__ static constexpr auto
MakeGemm2BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_K0_N_K1&)
{
return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<Gemm2Params::GemmNRepeat, 1, 1>(
BBlockDesc_K0_N_K1{});
}
__host__ __device__ static constexpr auto MakeABlockDesc_M0_K0_M1_K1_M2_M3_M4_K2()
{
const auto K0_ = a_block_desc_k0_m_k1.GetLength(I0);
const auto M_ = a_block_desc_k0_m_k1.GetLength(I1);
const auto K1_ = a_block_desc_k0_m_k1.GetLength(I2);
const auto a_block_desc_m_k = transform_tensor_descriptor(
a_block_desc_k0_m_k1,
make_tuple(make_merge_transform_v3_division_mod(make_tuple(K0_, K1_)),
make_pass_through_transform(M_)),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
// HACK: for unmerge transform, the length of highest dim is irrelevant so we put dummy
// variable I1 there
return transform_tensor_descriptor(
a_block_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(I1, M1, M2, M3, M4)),
make_unmerge_transform(make_tuple(I1, N1, N2))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{}));
}
// Note: we will perform sub-workgroup VGPR-to-LDS copy to save LDS space, therefore the
// destination coordinate can overlap between wavefronts in a workgroup as seen in the mod
// operation before returning the values
__host__ __device__ static auto MakeAThreadOriginOnBlock_M0_K0_M1_K1_M2_M3_M4_K2()
{
const auto a_thread_origin_on_block_idx =
ASrcBlockwiseGemm::CalculateCThreadOriginDataIndex8D(I0, I0, I0, I0);
constexpr auto a_block_slice_lengths_m0_k0_m1_k1 =
typename Gemm2Params::ABlockSliceLengths_M0_K0_M1_K1{}; // mrepeat, nrepeat,
// mwaves, nwaves,
return make_tuple(
a_thread_origin_on_block_idx[I0], // mrepeat
a_thread_origin_on_block_idx[I1], // nrepeat
a_thread_origin_on_block_idx[I2] % a_block_slice_lengths_m0_k0_m1_k1[I2], // mwave
a_thread_origin_on_block_idx[I3] % a_block_slice_lengths_m0_k0_m1_k1[I3], // nwave
a_thread_origin_on_block_idx[I4], // xdlops
a_thread_origin_on_block_idx[I5],
a_thread_origin_on_block_idx[I6],
a_thread_origin_on_block_idx[I7]);
}
static constexpr auto a_block_desc_m0_k0_m1_k1_m2_m3_m4_k2 =
MakeABlockDesc_M0_K0_M1_K1_M2_M3_M4_K2();
using ASrcBlockSliceWindowIterator =
SpaceFillingCurve<Sequence<M0, N0, M1, N1>,
Sequence<0, 1, 2, 3>,
typename Gemm2Params::ABlockSliceLengths_M0_K0_M1_K1,
false>;
template <typename ElementwiseOp = tensor_operation::element_wise::PassThrough>
using ABlockwiseCopy = ThreadwiseTensorSliceTransfer_v1r3<
FloatGemmAcc,
GemmDataType,
decltype(a_src_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(a_block_desc_m0_k0_m1_k1_m2_m3_m4_k2),
ElementwiseOp,
Sequence<Gemm2Params::ABlockSliceLengths_M0_K0_M1_K1::At(I0), // ThreadSliceLengths
Gemm2Params::ABlockSliceLengths_M0_K0_M1_K1::At(I1),
I1,
I1,
M2,
I1,
M4,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7, // DstVectorDim
1, // DstScalarPerVector
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
true>;
__host__ __device__ static constexpr auto MakeBBlockDesc_N0_N1_N2_K0_K1_K2_K3()
{
const auto N0_ = b_block_desc_n0_k_n1.GetLength(I0);
const auto K_ = b_block_desc_n0_k_n1.GetLength(I1);
const auto N1_ = b_block_desc_n0_k_n1.GetLength(I2);
constexpr auto b_block_desc_n_k = transform_tensor_descriptor( //(32, 128) //(64, 128)
b_block_desc_n0_k_n1,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(N0_, N1_)), //(4, 8) //(8, 8)
make_pass_through_transform(K_)), // 128 // 128
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return transform_tensor_descriptor(
b_block_desc_n_k,
make_tuple(
make_unmerge_transform(make_tuple(Gemm2Params::GemmNRepeat,
Gemm2Params::GemmNWave,
NPerXdl)), //(1, 1, 32) //(1, 2, 32)
make_unmerge_transform(
make_tuple(Gemm2Params::B_K0,
Gemm2Params::B_K1,
Gemm2Params::B_K2,
Gemm2Params::B_K3))), //(2, 4, 2, 8) //(2, 4, 2, 8)
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5, 6>{}));
}
static constexpr auto b_block_desc_n0_n1_n2_k0_k1_k2_k3 =
MakeBBlockDesc_N0_N1_N2_K0_K1_K2_K3();
using BThreadSlice_N0_N1_N2_K0_K1_K2_K3 =
Sequence<Gemm2Params::GemmNRepeat, 1, 1, 1, Gemm2Params::B_K1, 1, Gemm2Params::B_K3>;
static constexpr auto b_thread_desc_n0_n1_n2_k0_k1_k2_k3 =
make_naive_tensor_descriptor_packed(make_tuple(Number<Gemm2Params::GemmNRepeat>{},
I1,
I1,
I1,
Number<Gemm2Params::B_K1>{},
I1,
Number<Gemm2Params::B_K3>{}));
__host__ __device__ static constexpr auto MakeBThreadDesc_K0_N_K1()
{
constexpr auto b_thread_desc_n_k = transform_tensor_descriptor(
b_thread_desc_n0_n1_n2_k0_k1_k2_k3,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(Number<Gemm2Params::GemmNRepeat>{}, I1, I1)),
make_merge_transform_v3_division_mod(make_tuple(
I1, Number<Gemm2Params::B_K1>{}, I1, Number<Gemm2Params::B_K3>{}))),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5, 6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return transform_tensor_descriptor(
b_thread_desc_n_k,
make_tuple(make_pass_through_transform(Number<Gemm2Params::GemmNRepeat>{}),
make_unmerge_transform(make_tuple(Number<Gemm2Params::B_K1>{},
Number<Gemm2Params::B_K3>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
}
static constexpr auto b_thread_desc_k0_n_k1 = MakeBThreadDesc_K0_N_K1();
using BBlockwiseCopy =
ThreadwiseTensorSliceTransfer_v2<GemmDataType,
GemmDataType,
decltype(b_block_desc_n0_n1_n2_k0_k1_k2_k3),
decltype(b_thread_desc_n0_n1_n2_k0_k1_k2_k3),
BThreadSlice_N0_N1_N2_K0_K1_K2_K3,
Sequence<0, 1, 2, 3, 4, 5, 6>,
6,
1,
1,
true>;
static constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, 0, 1, 0, 0, 0);
static constexpr auto b_block_reset_copy_step =
make_multi_index(0, 0, 0, -Gemm2Params::B_K0, 0, 0, 0);
using BlockwiseGemm = BlockwiseGemmXdlops_v2<
BlockSize,
GemmDataType,
FloatGemmAcc,
decltype(a_block_desc_k0_m_k1),
decltype(b_thread_desc_k0_n_k1),
decltype(MakeGemm2AMmaTileDescriptor_M0_M1_M2_K(a_block_desc_k0_m_k1)),
decltype(MakeGemm2BMmaTileDescriptor_N0_N1_N2_K(b_thread_desc_k0_n_k1)),
MPerBlock,
Gemm1NPerBlock,
Gemm2Params::Sum_K,
MPerXdl,
NPerXdl,
Gemm2Params::GemmMRepeat,
Gemm2Params::GemmNRepeat,
Gemm2Params::GemmKPack,
true, // TransposeC
Gemm2Params::GemmKPack *
XdlopsGemm<GemmDataType, MPerXdl, NPerXdl, Gemm2Params::GemmKPack, false>{}
.K0PerXdlops,
Gemm2Params::GemmKPack>;
static constexpr auto c_block_slice_copy_step =
make_multi_index(-Gemm2Params::GemmMRepeat, 0, 0, 0, 0, 0, 0, 0);
template <typename CGradDesc_M_N>
__host__ __device__ static auto
MakeCGridDesc_M0_N0_M1_N1_M2_N2_N3_N4(const CGradDesc_M_N& c_grid_desc_m_n)
{
// HACK: for unmerge transform, the length of highest dim is irrelevant so we put dummy
// variable I1 there
const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(I1, Gemm2Params::GemmMWave, MPerXdl)),
make_unmerge_transform(make_tuple(I1, Gemm2Params::GemmNWave, NPerXdl))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}));
const auto c_grid_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
BlockwiseGemm{}.xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(
c_grid_desc_m0_n0_m1_n1_m2_n2);
return c_grid_desc_m0_n0_m1_n1_m2_n2_n3_n4;
}
static constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
BlockwiseGemm::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
__host__ __device__ static auto GetCThreadOriginOnBlock_M0_N0_M1_N1_M2_N2_N3_N4()
{
return to_multi_index(BlockwiseGemm::CalculateCThreadOriginDataIndex8D(I0, I0, I0, I0));
}
template <typename CGridDesc_M0_N0_M1_N1_M2_N2_N3_N4,
typename ElementwiseOp = tensor_operation::element_wise::PassThrough>
using CBlockwiseCopy = ThreadwiseTensorSliceTransfer_v1r3<
FloatGemmAcc,
OutputDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
CGridDesc_M0_N0_M1_N1_M2_N2_N3_N4,
ElementwiseOp, // CElementwiseOperation
decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLengths()), // SliceLengths
Sequence<0, 1, 2, 3, 4, 5, 6, 7>, // AccessOrder
7, // VectorDim
2, // ScalarPerVector
InMemoryDataOperationEnum::AtomicAdd, // GlobalMemoryDataOperation
1, // DstScalarStrideInVector
true>;
};
// S Gemm (type 3 rcc)
struct Gemm3
{
// A matrix in LDS memory, dst of blockwise copy
static constexpr auto a_block_desc_ak0_m_ak1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B matrix in LDS memory, dst of blockwise copy
static constexpr auto b_block_desc_bk0_n_bk1 =
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto
MakeGemm3AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
{
constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<MXdlPerWave, MWaves, MPerXdl>(
ABlockDesc_AK0_M_AK1{});
}
template <typename BBlockDesc_BK0_N_BK1>
__host__ __device__ static constexpr auto
MakeGemm3BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
{
constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<NXdlPerWave, NWaves, NPerXdl>(
BBlockDesc_BK0_N_BK1{});
}
static constexpr index_t KPack = math::max(math::lcm(AK1, BK1), mfma.k_per_blk);
// Blockwise gemm with transposed XDL output
using BlockwiseGemm = BlockwiseGemmXdlops_v2<
BlockSize,
GemmDataType,
FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
decltype(MakeGemm3AMmaTileDescriptor_M0_M1_M2_K(a_block_desc_ak0_m_ak1)),
decltype(MakeGemm3BMmaTileDescriptor_N0_N1_N2_K(b_block_desc_bk0_n_bk1)),
MPerBlock,
NPerBlock,
KPerBlock,
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
KPack>;
};
template <index_t BlockSize_, index_t BlockSliceLength_M_, index_t BlockSliceLength_O_>
struct YDotYGrad_M_O_
{
static constexpr index_t SrcScalarPerVector = 16 / sizeof(InputDataType);
static constexpr auto ThreadClusterLength_O =
Number<BlockSliceLength_O_ / SrcScalarPerVector>{};
static constexpr auto ThreadClusterLength_M = Number<BlockSize_ / ThreadClusterLength_O>{};
static constexpr auto ThreadSliceLength_O = Number<SrcScalarPerVector>{};
static constexpr auto ThreadSliceLength_M =
Number<BlockSliceLength_M_ * ThreadClusterLength_O / BlockSize_>{};
// dY matrix in LDS memory, dst of blockwise copy
static constexpr auto ygrad_block_desc_o0_m_o1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
__host__ __device__ static constexpr auto MakeYGradBlockDesc_M_O()
{
const auto O0_ = ygrad_block_desc_o0_m_o1.GetLength(I0);
const auto M_ = ygrad_block_desc_o0_m_o1.GetLength(I1);
const auto O1_ = ygrad_block_desc_o0_m_o1.GetLength(I2);
static_assert(O0_ * O1_ == BlockSliceLength_O_, "");
static_assert(M_ == BlockSliceLength_M_, "");
return transform_tensor_descriptor( //(128, 64)
ygrad_block_desc_o0_m_o1,
make_tuple(make_merge_transform_v3_division_mod(make_tuple(O0_, O1_)), //(8, 8)
make_pass_through_transform(M_)), // 128
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
}
static constexpr auto ygrad_block_desc_m_o = MakeYGradBlockDesc_M_O();
static_assert(ThreadClusterLength_O * ThreadSliceLength_O == BlockSliceLength_O_, "");
static_assert(ThreadClusterLength_M * ThreadSliceLength_M == BlockSliceLength_M_, "");
using SrcBufType = StaticBuffer<AddressSpaceEnum::Vgpr,
FloatGemmAcc,
ThreadSliceLength_M * ThreadSliceLength_O,
true>;
using DstBufType =
StaticBuffer<AddressSpaceEnum::Vgpr, FloatGemmAcc, ThreadSliceLength_M, true>;
};
using YDotYGrad_M_O = YDotYGrad_M_O_<BlockSize, MPerBlock, Gemm1NPerBlock>;
// D0
static constexpr auto D0M2 = Number<4>{};
static constexpr auto D0M1 = Number<MPerXdl>{} / D0M2;
static constexpr auto D0M0 = Number<MPerBlock>{} / Number<MPerXdl>{};
__host__ __device__ static constexpr auto
MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(const D0GridDesc_M_N& d0_grid_desc_m_n)
{
const auto M = d0_grid_desc_m_n.GetLength(I0);
const auto N = d0_grid_desc_m_n.GetLength(I1);
const auto MBlock = M / MPerBlock;
const auto NBlock = N / NPerBlock;
const auto d0_grid_desc_m0_n0_m1_m2_n1_m3 = transform_tensor_descriptor(
d0_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(MBlock, D0M0, D0M1, D0M2)),
make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 3, 5>{}, Sequence<1, 4>{}));
return d0_grid_desc_m0_n0_m1_m2_n1_m3;
}
using D0GridDescriptor_M0_N0_M1_M2_N1_M3 =
remove_cvref_t<decltype(MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(D0GridDesc_M_N{}))>;
struct D0Operator
{
template <typename DataType>
struct TypeTransform
{
using Type = DataType;
static constexpr index_t Size0 = sizeof(DataType);
static constexpr index_t Size = sizeof(DataType);
};
template <>
struct TypeTransform<void>
{
using Type = ck::half_t;
static constexpr index_t Size0 = 0;
static constexpr index_t Size = sizeof(ck::half_t);
};
static constexpr index_t NThreadClusterLengths = 32;
static_assert(NPerXdl == 32);
static_assert(D0BlockTransferSrcScalarPerVector * NThreadClusterLengths <= NPerBlock,
"D0BlockTransferSrcScalarPerVector * NThreadClusterLengths <= NPerBlock");
__host__ __device__ static constexpr auto GetD0BlockGlobalDescriptor_M0_N0_M1_M2_N1_M3()
{
return make_naive_tensor_descriptor_packed(
make_tuple(I1, I1, I1, D0M1, Number<NPerBlock>{}, D0M2));
}
__host__ __device__ static constexpr auto GetD0BlockVgprDescriptor_N0_N1_M0_M1_M2()
{
constexpr auto d0_raw_m0_n_m1 =
make_naive_tensor_descriptor_packed(make_tuple(D0M1, Number<NPerBlock>{}, D0M2));
constexpr auto d0_n0_n1_m0_m1_m2 = transform_tensor_descriptor(
d0_raw_m0_n_m1,
make_tuple(make_unmerge_transform(make_tuple(D0M1 / I2, I2)),
make_unmerge_transform(
make_tuple(Number<NPerBlock / NPerXdl>{}, Number<NPerXdl>{})),
make_pass_through_transform(D0M2)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<2, 3>{}, Sequence<0, 1>{}, Sequence<4>{}));
return d0_n0_n1_m0_m1_m2;
}
static constexpr auto d0_block_dst_desc_m0_n0_m1_m2_n1_m3 =
GetD0BlockGlobalDescriptor_M0_N0_M1_M2_N1_M3();
static constexpr auto d0_block_src_desc_n0_n1_m0_m1_m2 =
GetD0BlockVgprDescriptor_N0_N1_M0_M1_M2();
static constexpr auto d0_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I4, I1, D0M2));
static constexpr auto& d0grad_block_dst_desc_n0_n1_m0_m1_m2 =
d0_block_src_desc_n0_n1_m0_m1_m2;
static constexpr auto& d0grad_block_src_desc_m0_n0_m1_m2_n1_m3 =
d0_block_dst_desc_m0_n0_m1_m2_n1_m3;
using D0BlockwiseCopyGlobalToLds = ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock,
tensor_operation::element_wise::PassThrough,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<I1, I1, I1, D0M1, NPerBlock, D0M2>, // BlockSliceLengths
Sequence<1,
1,
1,
BlockSize / NThreadClusterLengths,
NThreadClusterLengths,
1>, // ThreadClusterLengths
Sequence<0, 1, 2, 3, 5, 4>, // ThreadClusterArrangeOrder
typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData
D0GridDescriptor_M0_N0_M1_M2_N1_M3, // SrcDesc
decltype(d0_block_dst_desc_m0_n0_m1_m2_n1_m3), // DstDesc
Sequence<0, 1, 2, 3, 5, 4>, // SrcDimAccessOrder
Sequence<0, 1, 2, 4, 3, 5>, // DstDimAccessOrder
4, // SrcVectorDim
5, // DstVectorDim
4, // SrcScalarPerVector
4, // DstScalarPerVector
1,
1,
true,
true, // DstResetCoord
1>;
using D0ThreadwiseCopyLdsToVgpr =
ThreadwiseTensorSliceTransfer_v4<typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData
decltype(d0_block_src_desc_n0_n1_m0_m1_m2), // SrcDesc
decltype(d0_thread_desc_), // DstDesc
Sequence<1, 1, 4, 1, 4>, // SliceLengths
Sequence<0, 1, 2, 3, 4>, // DimAccessOrder
4, // SrcVectorDim
4, // SrcScalarPerVector
2>;
using D0GradThreadwiseCopyVgprToLds = ThreadwiseTensorSliceTransfer_v1r3<
FloatGemmAcc,
typename TypeTransform<D0DataType>::Type,
decltype(d0_thread_desc_),
decltype(d0grad_block_dst_desc_n0_n1_m0_m1_m2),
tensor_operation::element_wise::Scale, // CElementwiseOperation
Sequence<1, 1, 4, 1, 4>, // SliceLengths
Sequence<0, 1, 2, 3, 4>, // AccessOrder
4, // VectorDim
4, // ScalarPerVector
InMemoryDataOperationEnum::Set, // GlobalMemoryDataOperation
1, // DstScalarStrideInVector
true>;
using D0GradBlockwiseCopyLdsToGlobal = ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock,
tensor_operation::element_wise::PassThrough,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<I1, I1, I1, D0M1, NPerBlock, D0M2>, // BlockSliceLengths
Sequence<1,
1,
1,
BlockSize / NThreadClusterLengths,
NThreadClusterLengths,
1>, // ThreadClusterLengths
Sequence<0, 1, 2, 3, 5, 4>, // ThreadClusterArrangeOrder
typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData
decltype(d0grad_block_src_desc_m0_n0_m1_m2_n1_m3), // SrcDesc
D0GridDescriptor_M0_N0_M1_M2_N1_M3, // DstDesc
Sequence<0, 1, 2, 4, 3, 5>, // SrcDimAccessOrder
Sequence<0, 1, 2, 3, 5, 4>, // DstDimAccessOrder
5, // SrcVectorDim
4, // DstVectorDim
4, // SrcScalarPerVector
D0BlockTransferSrcScalarPerVector, // DstScalarPerVector
1,
1,
true,
true, // DstResetCoord
1>;
};
struct SharedMemTrait
{
// // LDS allocation for A and B: be careful of alignment
static constexpr auto q_block_desc_k0_m_k1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
static constexpr auto k_block_desc_k0_n_k1 =
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
static constexpr auto ygrad_block_desc_k0_m_k1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
static constexpr auto p_slash_sgrad_block_desc_k0_m_k1 =
GetA2BlockDescriptor_K0_M_K1<Gemm2Params>();
static constexpr auto max_lds_align = Number<16 / sizeof(GemmDataType)>{};
static constexpr auto q_block_space_size_aligned =
math::integer_least_multiple(q_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto k_block_space_size_aligned =
math::integer_least_multiple(k_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto ygrad_block_space_size_aligned = math::integer_least_multiple(
ygrad_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto p_slash_sgrad_block_space_size_aligned = math::integer_least_multiple(
p_slash_sgrad_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto k_block_space_offset = 0;
static constexpr auto ygrad_block_space_offset = k_block_space_size_aligned.value;
static constexpr auto q_block_space_offset =
k_block_space_size_aligned.value + ygrad_block_space_size_aligned.value;
static constexpr auto p_slash_sgrad_block_space_offset =
k_block_space_size_aligned.value + ygrad_block_space_size_aligned.value +
q_block_space_size_aligned.value;
// LDS allocation for reduction
static constexpr index_t reduction_space_size_aligned =
math::integer_least_multiple(BlockSize, max_lds_align);
static constexpr auto reduction_space_offset =
(k_block_space_size_aligned.value + ygrad_block_space_size_aligned.value +
q_block_space_size_aligned.value) *
sizeof(GemmDataType) / sizeof(FloatGemmAcc);
static constexpr auto d0_block_space_size_aligned = math::integer_least_multiple(
D0Operator::d0_block_dst_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize(), max_lds_align);
static constexpr auto d0_block_space_offset =
(k_block_space_size_aligned.value + ygrad_block_space_size_aligned.value +
q_block_space_size_aligned.value) *
sizeof(GemmDataType) / D0Operator::template TypeTransform<D0DataType>::Size;
// LDS allocation for C shuffle in LDS
static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
static constexpr auto c_block_space_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
};
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
const index_t p_slash_sgrad_bytes_end =
(SharedMemTrait::p_slash_sgrad_block_space_offset +
SharedMemTrait::p_slash_sgrad_block_space_size_aligned) *
sizeof(GemmDataType);
const index_t softmax_bytes_end = (SharedMemTrait::reduction_space_offset +
SharedMemTrait::reduction_space_size_aligned) *
sizeof(FloatGemmAcc);
const index_t d0_bytes_end =
(SharedMemTrait::d0_block_space_offset + SharedMemTrait::d0_block_space_size_aligned) *
D0Operator::template TypeTransform<D0DataType>::Size0;
const index_t c_block_bytes_end =
SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
return math::max(
p_slash_sgrad_bytes_end, softmax_bytes_end, d0_bytes_end, c_block_bytes_end);
}
template <bool HasMainKBlockLoop,
bool IsDropout,
typename Block2CTileMap,
typename C0MatrixMask,
typename YGradGridDesc_O0_M_O1>
__device__ static void
Run(const InputDataType* __restrict__ p_q_grid,
const InputDataType* __restrict__ p_k_grid,
const D0DataType* __restrict__ p_d0_grid,
ZDataType* __restrict__ p_z_grid,
const InputDataType* __restrict__ p_v_grid,
const InputDataType* __restrict__ p_y_grid,
const FloatLSE* __restrict__ p_lse_grid,
const InputDataType* __restrict__ p_ygrad_grid,
OutputDataType* __restrict__ p_qgrad_grid,
OutputDataType* __restrict__ p_kgrad_grid,
D0DataType* __restrict__ p_d0grad_grid,
OutputDataType* __restrict__ p_vgrad_grid,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const SElementwiseOperation& s_element_op,
const B1ElementwiseOperation& b1_element_op,
const CElementwiseOperation& c_element_op,
const QGridDesc_K0_M_K1& q_grid_desc_k0_m_k1,
const KGridDesc_K0_N_K1& k_grid_desc_k0_n_k1,
const KGridDesc_K0_N_K1& kgrad_grid_desc_k0_n_k1,
const D0GridDescriptor_M0_N0_M1_M2_N1_M3& d0_grid_desc_m0_n0_m1_m2_n1_m3,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3&
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
const VGridDesc_O0_N_O1& v_grid_desc_o0_n_o1,
const VGridDesc_O0_N_O1& vgrad_grid_desc_o0_n_o1,
const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock&
y_grid_desc_mblock_mperblock_oblock_operblock,
const LSEGridDesc_M& lse_grid_desc_m,
const YGradGridDesc_O0_M_O1& ygrad_grid_desc_o0_m_o1,
const Block2CTileMap& block_2_ctile_map,
const C0MatrixMask& c0_matrix_mask,
const float p_drop,
ck::philox& ph,
const index_t z_random_matrix_offset,
const index_t raw_n_padded,
const index_t block_idx_n)
{
const FloatGemmAcc p_dropout = type_convert<FloatGemmAcc>(1.0f - p_drop);
const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout);
const uint8_t p_dropout_in_uint8_t =
__builtin_amdgcn_readfirstlane(uint8_t(std::floor(p_dropout * 255.0)));
const tensor_operation::element_wise::Scale scale_rp_dropout(s_element_op.Value() *
rp_dropout);
const auto q_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_q_grid, q_grid_desc_k0_m_k1.GetElementSpaceSize());
const auto k_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_k_grid, k_grid_desc_k0_n_k1.GetElementSpaceSize());
const auto v_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_v_grid, v_grid_desc_o0_n_o1.GetElementSpaceSize());
const auto y_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_y_grid, y_grid_desc_mblock_mperblock_oblock_operblock.GetElementSpaceSize());
const auto lse_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_lse_grid, lse_grid_desc_m.GetElementSpaceSize());
const auto ygrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ygrad_grid, ygrad_grid_desc_o0_m_o1.GetElementSpaceSize());
auto vgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_vgrad_grid, vgrad_grid_desc_o0_n_o1.GetElementSpaceSize());
auto qgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_qgrad_grid, q_grid_desc_k0_m_k1.GetElementSpaceSize());
auto kgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_kgrad_grid, kgrad_grid_desc_k0_n_k1.GetElementSpaceSize());
// divide block work by [N, K]
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
const index_t block_work_idx_n = Deterministic ? block_idx_n : block_work_idx[I0];
// HACK: this force n_block_data_idx_on_grid into SGPR
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx_n * NPerBlock);
const index_t num_gemm0_m_block_outer_loop = q_grid_desc_k0_m_k1.GetLength(I1) / MPerBlock;
constexpr index_t num_gemm1_k_block_inner_loop = MPerBlock / Gemm1KPerBlock;
// 6 GEMM operations are categorized into 4 buckets. SizeK == SizeO == head_dim
// dP_MNO Gemm (Gemm0 rcc)
// dV_NOM / dK_NKM Gemm (Gemm1 rrr)
// Y_MON / dQ_MKN Gemm (Gemm2 crr)
// S_MNK Gemm (Gemm3 rcc)
// LDS allocation for Q / K / V / dY
auto q_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<GemmDataType*>(p_shared) + SharedMemTrait::q_block_space_offset,
GemmBlockwiseCopy::q_block_desc_k0_m_k1.GetElementSpaceSize());
auto k_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<GemmDataType*>(p_shared) + SharedMemTrait::k_block_space_offset,
GemmBlockwiseCopy::k_block_desc_k0_n_k1.GetElementSpaceSize());
auto v_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, GemmDataType>(
GemmBlockwiseCopy::v_thread_desc_k0_k1_n0_n1_n2_n3_k2.GetElementSpaceSize());
auto ygrad_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<GemmDataType*>(p_shared) + SharedMemTrait::ygrad_block_space_offset,
GemmBlockwiseCopy::ygrad_block_desc_k0_m_k1.GetElementSpaceSize());
// Q matrix blockwise copy
auto gemm_tile_q_blockwise_copy =
typename GemmBlockwiseCopy::template QBlockwiseCopy<decltype(q_grid_desc_k0_m_k1)>(
q_grid_desc_k0_m_k1,
make_multi_index(0,
MPerBlock * (num_gemm0_m_block_outer_loop - 1),
0), // will loop over GemmM dimension
a_element_op,
GemmBlockwiseCopy::q_block_desc_k0_m_k1,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
// K matrix blockwise copy
auto gemm_tile_k_blockwise_copy =
typename GemmBlockwiseCopy::template KBlockwiseCopy<decltype(k_grid_desc_k0_n_k1)>(
k_grid_desc_k0_n_k1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op,
GemmBlockwiseCopy::k_block_desc_k0_n_k1,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
const auto v_grid_desc_k0_k1_n0_n1_n2_n3_k2 =
GemmBlockwiseCopy::MakeVGridDescriptor_K0_K1_N0_N1_N2_N3_K2(v_grid_desc_o0_n_o1);
const auto wave_id = GetGemm0WaveIdx();
const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63
// V matrix blockwise copy
auto gemm_tile_v_blockwise_copy =
typename GemmBlockwiseCopy::template VBlockwiseCopy<decltype(
v_grid_desc_k0_k1_n0_n1_n2_n3_k2)>(
v_grid_desc_k0_k1_n0_n1_n2_n3_k2,
make_multi_index(
0, wave_m_n_id[I0], block_work_idx_n, 0, wave_id[I1], wave_m_n_id[I1], 0));
ignore = b1_element_op;
// dY matrix blockwise copy
auto gemm_tile_ygrad_blockwise_copy =
typename GemmBlockwiseCopy::template YGradBlockwiseCopy<decltype(
ygrad_grid_desc_o0_m_o1)>(
ygrad_grid_desc_o0_m_o1,
make_multi_index(0,
MPerBlock * (num_gemm0_m_block_outer_loop - 1),
0), // will loop over GemmM dimension
a_element_op,
GemmBlockwiseCopy::ygrad_block_desc_k0_m_k1,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
//
// set up dP Gemm (type 1 rcc)
//
using Gemm0 = Gemm0<decltype(GemmBlockwiseCopy::v_thread_desc_k0_k1_n0_n1_n2_n3_k2)>;
// dP: blockwise gemm
auto pgrad_blockwise_gemm = typename Gemm0::BlockwiseGemm{};
pgrad_blockwise_gemm.SetBBlockStartWindow(make_tuple(0, 0, 0, 0));
auto pgrad_thread_buf = pgrad_blockwise_gemm.GetCThreadBuffer();
//
// set up S Gemm (type 4 rcc)
//
// S: blockwise gemm
auto s_blockwise_gemm = typename Gemm3::BlockwiseGemm{}; // TransposeC
auto s_slash_p_thread_buf = s_blockwise_gemm.GetCThreadBuffer();
//
// set up dV / dK Gemm (type 2 rrr)
//
using Gemm1 =
Gemm1<decltype(s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()),
decltype(s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2())>;
// Gemm1: VGPR allocation for A and B
auto gemm1_a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, GemmDataType>(
Gemm1::a_thread_desc_k0_m_k1.GetElementSpaceSize());
auto gemm1_b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, GemmDataType>(
Gemm1::b_thread_desc_n0_n1_n2_k0_k1_k2_k3.GetElementSpaceSize());
// dV: A matrix blockwise copy
auto vgrad_gemm_tile_p_blockwise_copy =
typename Gemm1::template ABlockwiseCopy<tensor_operation::element_wise::Relu>{
tensor_operation::element_wise::Relu{}}; // relu(P-dropped)
// dV: blockwise gemm
auto vgrad_blockwise_gemm =
typename Gemm1::BlockwiseGemm{make_tuple(0, 0, 0, 0), make_tuple(0, 0, 0, 0)};
// dV: B matrix blockwise copy
auto ygrad_thread_origin = vgrad_blockwise_gemm.CalculateBThreadOriginDataIndex();
// dV: B matrix LDS-to-VGPR blockwise copy
auto vgrad_gemm_tile_ygrad_blockwise_copy = typename Gemm1::BBlockwiseCopy{
Gemm1::b_block_desc_n0_n1_n2_k0_k1_k2_k3,
make_multi_index(0, // nrepeat
ygrad_thread_origin[I1], // nwave
ygrad_thread_origin[I2], // nperxdl
0, // k0
0, // k1
ygrad_thread_origin[I3] / Gemm1::GemmKPack, // k2
0)};
auto vgrad_thread_buf = vgrad_blockwise_gemm.GetCThreadBuffer();
// dV: transform input and output tensor descriptors
auto vgrad_grid_desc_nblock_nperblock_oblock_operblock =
MakeVGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock(vgrad_grid_desc_o0_n_o1);
// dK: A matrix blockwise copy
auto kgrad_gemm_tile_sgrad_blockwise_copy =
typename Gemm1::template ABlockwiseCopy<tensor_operation::element_wise::PassThrough>{
tensor_operation::element_wise::PassThrough{}};
// dK: blockwise gemm
auto kgrad_blockwise_gemm =
typename Gemm1::BlockwiseGemm{make_tuple(0, 0, 0, 0), make_tuple(0, 0, 0, 0)};
// dK: B matrix blockwise copy
auto q_thread_origin = kgrad_blockwise_gemm.CalculateBThreadOriginDataIndex();
// dK: B matrix LDS-to-VGPR blockwise copy
auto kgrad_gemm_tile_q_blockwise_copy = typename Gemm1::BBlockwiseCopy{
Gemm1::b_block_desc_n0_n1_n2_k0_k1_k2_k3,
make_multi_index(0, // nrepeat
q_thread_origin[I1], // nwave
q_thread_origin[I2], // nperxdl
0, // k0
0, // k1
q_thread_origin[I3] / Gemm1::GemmKPack, // k2
0)};
auto kgrad_thread_buf = kgrad_blockwise_gemm.GetCThreadBuffer();
// dK: transform input and output tensor descriptors
auto kgrad_grid_desc_nblock_nperblock_oblock_operblock =
MakeKGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock(kgrad_grid_desc_k0_n_k1);
//
// set up dQ Gemm (type 3 crr)
//
using Gemm2 = Gemm2<Gemm2Params, decltype(pgrad_blockwise_gemm)>;
// Gemm2: LDS allocation for A and B: be careful of alignment
auto gemm2_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<GemmDataType*>(p_shared) + SharedMemTrait::p_slash_sgrad_block_space_offset,
Gemm2::a_block_desc_k0_m_k1.GetElementSpaceSize());
auto gemm2_b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, GemmDataType>(
Gemm2::b_thread_desc_n0_n1_n2_k0_k1_k2_k3.GetElementSpaceSize());
// dQ: A matrix VGPR-to-LDS blockwise copy
auto qgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds =
typename Gemm2::template ABlockwiseCopy<tensor_operation::element_wise::PassThrough>{
Gemm2::a_block_desc_m0_k0_m1_k1_m2_m3_m4_k2,
Gemm2::MakeAThreadOriginOnBlock_M0_K0_M1_K1_M2_M3_M4_K2(),
tensor_operation::element_wise::PassThrough{}};
// dQ: blockwise gemm
auto qgrad_blockwise_gemm = typename Gemm2::BlockwiseGemm{};
qgrad_blockwise_gemm.SetBBlockStartWindow(make_tuple(0, 0, 0, 0));
auto k_thread_origin = qgrad_blockwise_gemm.CalculateBThreadOriginDataIndex();
// dQ: B matrix LDS-to-VGPR blockwise copy
auto qgrad_gemm_tile_k_blockwise_copy = typename Gemm2::BBlockwiseCopy{
Gemm2::b_block_desc_n0_n1_n2_k0_k1_k2_k3,
make_multi_index(0, // nrepeat
k_thread_origin[I1], // nwave
k_thread_origin[I2], // nperxdl
0, // k0
0, // k1
k_thread_origin[I3] / Gemm2Params::GemmKPack, // k2
0)}; // k3
auto qgrad_thread_buf = qgrad_blockwise_gemm.GetCThreadBuffer();
// dQ: transform output tensor descriptors
const auto qgrad_grid_desc_m_k = MakeQGradGridDesc_M_K(q_grid_desc_k0_m_k1);
const auto qgrad_grid_desc_m0_o0_m1_o1_m2_o2_o3_o4 =
Gemm2::MakeCGridDesc_M0_N0_M1_N1_M2_N2_N3_N4(qgrad_grid_desc_m_k);
// dQ: C VGPR-to-global copy
const auto qgrad_thread_origin_on_grid_m0_o0_m1_o1_m2_o2_o3_o4 =
Gemm2::GetCThreadOriginOnBlock_M0_N0_M1_N1_M2_N2_N3_N4() +
make_multi_index((num_gemm0_m_block_outer_loop - 1) * Gemm2Params::GemmMRepeat,
I0,
I0,
I0,
I0,
I0,
I0,
I0);
auto qgrad_thread_copy_vgpr_to_global = typename Gemm2::template CBlockwiseCopy<
decltype(qgrad_grid_desc_m0_o0_m1_o1_m2_o2_o3_o4),
decltype(scale_rp_dropout)>(qgrad_grid_desc_m0_o0_m1_o1_m2_o2_o3_o4,
qgrad_thread_origin_on_grid_m0_o0_m1_o1_m2_o2_o3_o4,
scale_rp_dropout);
//
// Blockwise softmax
//
// get acc0 8D thread cluster
constexpr auto thread_cluster_m0_n0_m1_n1_m2_m3_m4_n2 =
s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths() /
s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths();
constexpr auto tm0 = thread_cluster_m0_n0_m1_n1_m2_m3_m4_n2.At(I0);
constexpr auto tn0 = thread_cluster_m0_n0_m1_n1_m2_m3_m4_n2.At(I1);
constexpr auto tm1 = thread_cluster_m0_n0_m1_n1_m2_m3_m4_n2.At(I2);
constexpr auto tn1 = thread_cluster_m0_n0_m1_n1_m2_m3_m4_n2.At(I3);
constexpr auto tm2 = thread_cluster_m0_n0_m1_n1_m2_m3_m4_n2.At(I4);
constexpr auto tm3 = thread_cluster_m0_n0_m1_n1_m2_m3_m4_n2.At(I5);
constexpr auto tm4 = thread_cluster_m0_n0_m1_n1_m2_m3_m4_n2.At(I6);
constexpr auto tn2 = thread_cluster_m0_n0_m1_n1_m2_m3_m4_n2.At(I7);
// get acc0 thread map
constexpr auto n0_m_n1_to_m_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(tn0 * tn1, tn2)),
make_pass_through_transform(I1)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
constexpr auto threadid_to_n0_m_n1_adaptor = make_single_stage_tensor_adaptor(
make_tuple(
make_merge_transform(make_tuple(tn0 * tn1, tm0 * tm1 * tm2 * tm3 * tm4, tn2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto threadid_to_m_n_thread_cluster_adaptor =
chain_tensor_adaptors(n0_m_n1_to_m_n_adaptor, threadid_to_n0_m_n1_adaptor);
// get acc0 2D thread cluster & 2D thread slice
constexpr auto thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto m0 = thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0);
constexpr auto n0 = thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1);
constexpr auto m1 = thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2);
constexpr auto n1 = thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3);
constexpr auto m2 = thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4);
constexpr auto m3 = thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5);
constexpr auto m4 = thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6);
constexpr auto n2 = thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7);
constexpr auto thread_cluster_desc_m_n = make_naive_tensor_descriptor_packed(
make_tuple(tm0 * tm1 * tm2 * tm3 * tm4, tn0 * tn1 * tn2));
constexpr auto thread_slice_desc_m_n =
make_naive_tensor_descriptor_packed(make_tuple(m0 * m1 * m2 * m3 * m4, n0 * n1 * n2));
auto blockwise_softmax = BlockwiseSoftmax<BlockSize,
FloatGemmAcc,
decltype(threadid_to_m_n_thread_cluster_adaptor),
decltype(thread_cluster_desc_m_n),
decltype(thread_slice_desc_m_n)>{};
auto blockwise_dropout = BlockwiseDropout<FloatGemmAcc, decltype(thread_slice_desc_m_n)>{
p_dropout_in_uint8_t, rp_dropout};
auto lse_grid_desc_mb_m0_m1_m2_m3_m4 =
MakeLSEGridDescriptor_MB_M0_M1_M2_M3_M4<decltype(s_blockwise_gemm)>(lse_grid_desc_m);
constexpr auto lse_thread_desc_mb_m0_m1_m2_m3_m4 =
make_naive_tensor_descriptor_packed(make_tuple(I1, m0, m1, m2, m3, m4));
auto lse_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatLSE>(
lse_thread_desc_mb_m0_m1_m2_m3_m4.GetElementSpaceSize());
auto acc0_thread_origin = s_blockwise_gemm.CalculateCThreadOriginDataIndex8D(
Number<0>{}, Number<0>{}, Number<0>{}, Number<0>{});
auto lse_thread_copy_global_to_vgpr =
ThreadwiseTensorSliceTransfer_v2<FloatLSE,
FloatLSE,
decltype(lse_grid_desc_mb_m0_m1_m2_m3_m4),
decltype(lse_thread_desc_mb_m0_m1_m2_m3_m4),
Sequence<1, m0, m1, m2, m3, m4>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
1,
1,
true /* ResetCoordAfterRun */>{
lse_grid_desc_mb_m0_m1_m2_m3_m4,
make_multi_index(num_gemm0_m_block_outer_loop - 1, // mblock
acc0_thread_origin[I0], // mrepeat
acc0_thread_origin[I2], // mwave
acc0_thread_origin[I4], // mperxdl
acc0_thread_origin[I5],
acc0_thread_origin[I6])};
//
// z vgpr copy to global
//
// z matrix threadwise desc
constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3 =
make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId
I1, // NBlockID
m0, // MRepeat
n0, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MGroupNum
m3, // MInputNum
m4, // RegisterNum
n2)); // NPerXdl
StaticBuffer<AddressSpaceEnum::Vgpr,
uint8_t,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize(),
true>
z_tensor_buffer;
z_tensor_buffer.Clear();
auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize());
auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
uint8_t,
ZDataType,
decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
tensor_operation::element_wise::PassThrough,
Sequence<I1, // MBlockId
I1, // NBlockID
m0, // MRepeat
n0, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
m3, // NGroupNum
m4, // NInputNum
n2>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
9, // DstVectorDim
1, // DstScalarPerVector
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_multi_index(num_gemm0_m_block_outer_loop - 1, // MBlockId
block_work_idx_n, // NBlockId
0, // MRepeat
0, // NRepeat
wave_id[I0], // MWaveId
wave_id[I1], // NWaveId
0, // MPerXdl
wave_m_n_id[I0], //
0, //
wave_m_n_id[I1]), // NPerXdl
tensor_operation::element_wise::PassThrough{}};
//
// set up Y dot dY
//
// m0, n0 are m/n repeat per wave
// m1, n1 are number of waves
constexpr auto p_block_lengths =
s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths();
constexpr auto P_M0 = p_block_lengths[I0]; // repeats
constexpr auto P_M1 = p_block_lengths[I2]; // waves
constexpr auto P_M2 = p_block_lengths[I4]; // xdl
constexpr auto P_M3 = p_block_lengths[I5];
constexpr auto P_M4 = p_block_lengths[I6];
constexpr auto y_thread_desc_m0_m1_o0_o1 = make_naive_tensor_descriptor_packed(make_tuple(
I1, YDotYGrad_M_O::ThreadSliceLength_M, I1, YDotYGrad_M_O::ThreadSliceLength_O));
constexpr auto ygrad_thread_desc_m_o = make_naive_tensor_descriptor_packed(
make_tuple(YDotYGrad_M_O::ThreadSliceLength_M, YDotYGrad_M_O::ThreadSliceLength_O));
constexpr auto y_thread_cluster_desc =
make_cluster_descriptor(Sequence<I1,
YDotYGrad_M_O::ThreadClusterLength_M,
I1,
YDotYGrad_M_O::ThreadClusterLength_O>{},
Sequence<0, 1, 2, 3>{});
const auto y_thread_cluster_idx =
y_thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
constexpr auto ygrad_thread_cluster_desc = make_cluster_descriptor(
Sequence<YDotYGrad_M_O::ThreadClusterLength_M, YDotYGrad_M_O::ThreadClusterLength_O>{},
Sequence<0, 1>{});
const auto ygrad_thread_cluster_idx = ygrad_thread_cluster_desc.CalculateBottomIndex(
make_multi_index(get_thread_local_1d_id()));
const auto y_thread_data_on_block_idx =
y_thread_cluster_idx * y_thread_desc_m0_m1_o0_o1.GetLengths();
const auto ygrad_thread_data_on_block_idx =
ygrad_thread_cluster_idx * ygrad_thread_desc_m_o.GetLengths();
const auto y_thread_data_on_grid_idx =
make_multi_index(num_gemm0_m_block_outer_loop - 1, I0, I0, I0) +
y_thread_data_on_block_idx;
// performs for y
auto y_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
InputDataType,
FloatGemmAcc,
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
decltype(y_thread_desc_m0_m1_o0_o1),
decltype(y_thread_desc_m0_m1_o0_o1.GetLengths()),
Sequence<0, 1, 2, 3>,
3, // SrcVectorDim
YDotYGrad_M_O::SrcScalarPerVector, // SrcScalarPerVector
1, // SrcScalarStrideInVector
true /* ResetCoordAfterRun */>(y_grid_desc_mblock_mperblock_oblock_operblock,
y_thread_data_on_grid_idx);
// performs for ygrad
auto ygrad_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
GemmDataType,
FloatGemmAcc,
decltype(YDotYGrad_M_O::ygrad_block_desc_m_o),
decltype(ygrad_thread_desc_m_o),
decltype(ygrad_thread_desc_m_o.GetLengths()),
Sequence<0, 1>,
1, // SrcVectorDim
YDotYGrad_M_O::SrcScalarPerVector, // SrcScalarPerVector
1, // SrcScalarStrideInVector
true /* ResetCoordAfterRun */>(YDotYGrad_M_O::ygrad_block_desc_m_o,
ygrad_thread_data_on_block_idx);
auto y_thread_buf = typename YDotYGrad_M_O::SrcBufType{};
auto ygrad_thread_buf = typename YDotYGrad_M_O::SrcBufType{};
auto y_dot_ygrad_thread_accum_buf = typename YDotYGrad_M_O::DstBufType{};
auto y_dot_ygrad_block_accum_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatGemmAcc*>(p_shared) + SharedMemTrait::reduction_space_offset,
MPerBlock);
constexpr auto y_dot_ygrad_block_desc_mb_m0_m1_m2_m3_m4 =
make_naive_tensor_descriptor_packed(make_tuple(I1, P_M0, P_M1, P_M2, P_M3, P_M4));
constexpr auto y_dot_ygrad_thread_desc_mb_m0_m1_m2_m3_m4 =
lse_thread_desc_mb_m0_m1_m2_m3_m4; // reuse LSE thread descriptor because
// per-thread LSE data and y_dot_ygrad is
// tiled the same way
auto y_dot_ygrad_thread_copy_lds_to_vgpr =
ThreadwiseTensorSliceTransfer_v2<FloatGemmAcc,
FloatGemmAcc,
decltype(y_dot_ygrad_block_desc_mb_m0_m1_m2_m3_m4),
decltype(y_dot_ygrad_thread_desc_mb_m0_m1_m2_m3_m4),
Sequence<1, m0, m1, m2, m3, m4>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
1,
1,
true /* ResetCoordAfterRun */>{
y_dot_ygrad_block_desc_mb_m0_m1_m2_m3_m4,
make_multi_index(I0, // mblock
acc0_thread_origin[I0], // mrepeat
acc0_thread_origin[I2], // mwave
acc0_thread_origin[I4], // mperxdl
acc0_thread_origin[I5],
acc0_thread_origin[I6])};
auto y_dot_ygrad_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatGemmAcc>(
y_dot_ygrad_thread_desc_mb_m0_m1_m2_m3_m4.GetElementSpaceSize());
// gemm0 M loop
index_t gemm0_m_block_outer_index = num_gemm0_m_block_outer_loop - 1;
// D0
auto d0_block_copy_global_to_lds = typename D0Operator::D0BlockwiseCopyGlobalToLds(
d0_grid_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{},
D0Operator::d0_block_dst_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(0, 0, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{});
auto d0_thread_copy_lds_to_vgpr = typename D0Operator::D0ThreadwiseCopyLdsToVgpr(
make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0));
auto& d0grad_grid_desc_m0_n0_m1_m2_n1_m3 = d0_grid_desc_m0_n0_m1_m2_n1_m3;
auto d0grad_thread_copy_vgpr_to_lds = typename D0Operator::D0GradThreadwiseCopyVgprToLds(
D0Operator::d0grad_block_dst_desc_n0_n1_m0_m1_m2,
make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0),
tensor_operation::element_wise::Scale{rp_dropout});
auto d0grad_block_copy_lds_to_global = typename D0Operator::D0GradBlockwiseCopyLdsToGlobal(
D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(0, 0, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{},
d0grad_grid_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{});
if constexpr(Deterministic)
{
block_sync_lds();
}
// Initialize dK&dV
kgrad_thread_buf.Clear();
vgrad_thread_buf.Clear();
// load k
gemm_tile_k_blockwise_copy.Run(k_grid_desc_k0_n_k1,
k_grid_buf,
GemmBlockwiseCopy::k_block_desc_k0_n_k1,
k_block_buf,
I0);
// load v
gemm_tile_v_blockwise_copy.Run(v_grid_desc_k0_k1_n0_n1_n2_n3_k2,
v_grid_buf,
GemmBlockwiseCopy::v_thread_desc_k0_k1_n0_n1_n2_n3_k2,
make_tuple(I0, I0, I0, I0, I0, I0, I0),
v_thread_buf);
do
{
auto m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(gemm0_m_block_outer_index * MPerBlock);
if(c0_matrix_mask.IsTileSkippable(
m_block_data_idx_on_grid, n_block_data_idx_on_grid, MPerBlock, NPerBlock))
{
continue;
}
// load ygrad
gemm_tile_ygrad_blockwise_copy.Run(ygrad_grid_desc_o0_m_o1,
ygrad_grid_buf,
GemmBlockwiseCopy::ygrad_block_desc_k0_m_k1,
ygrad_block_buf,
I0);
block_sync_lds();
//
// calculate Y dot dY
//
// clear accum buffers
y_dot_ygrad_thread_accum_buf.Clear();
y_dot_ygrad_block_accum_buf.Clear();
y_threadwise_copy.Run(y_grid_desc_mblock_mperblock_oblock_operblock,
y_grid_buf,
y_thread_desc_m0_m1_o0_o1,
make_tuple(I0, I0, I0, I0),
y_thread_buf);
ygrad_threadwise_copy.Run(YDotYGrad_M_O::ygrad_block_desc_m_o,
ygrad_block_buf,
ygrad_thread_desc_m_o,
make_tuple(I0, I0),
ygrad_thread_buf);
static_for<0, YDotYGrad_M_O::ThreadSliceLength_M, 1>{}([&](auto iM) {
static_for<0, YDotYGrad_M_O::ThreadSliceLength_O, 1>{}([&](auto iO) {
constexpr auto y_offset =
y_thread_desc_m0_m1_o0_o1.CalculateOffset(make_multi_index(I0, iM, I0, iO));
constexpr auto ygrad_offset =
ygrad_thread_desc_m_o.CalculateOffset(make_multi_index(iM, iO));
y_dot_ygrad_thread_accum_buf(iM) +=
y_thread_buf[Number<y_offset>{}] * ygrad_thread_buf[Number<ygrad_offset>{}];
});
});
// blockwise reduction using atomic_add
block_sync_lds();
static_for<0, YDotYGrad_M_O::ThreadSliceLength_M, 1>{}([&](auto iM) {
const auto idx_on_block = y_thread_data_on_block_idx[I1] + iM;
y_dot_ygrad_block_accum_buf.AtomicAdd(idx_on_block,
true,
y_dot_ygrad_thread_accum_buf[iM] *
p_dropout); // p_dropoutD1
});
block_sync_lds();
// distribute y_dot_ygrad to threads; LDS accum buffer can be safely reused after
// barrier
y_dot_ygrad_thread_copy_lds_to_vgpr.Run(y_dot_ygrad_block_desc_mb_m0_m1_m2_m3_m4,
y_dot_ygrad_block_accum_buf,
y_dot_ygrad_thread_desc_mb_m0_m1_m2_m3_m4,
make_tuple(I0, I0, I0, I0, I0, I0),
y_dot_ygrad_thread_buf);
lse_thread_copy_global_to_vgpr.Run(lse_grid_desc_mb_m0_m1_m2_m3_m4,
lse_grid_buf,
lse_thread_desc_mb_m0_m1_m2_m3_m4,
make_tuple(I0, I0, I0, I0, I0, I0),
lse_thread_buf);
// gemm dP
// dP = dY * V^T
pgrad_thread_buf.Clear();
pgrad_blockwise_gemm.Run(ygrad_block_buf, v_thread_buf, pgrad_thread_buf);
// gemm S
// S = Q * K^T
s_slash_p_thread_buf.Clear();
gemm_tile_q_blockwise_copy.Run(q_grid_desc_k0_m_k1,
q_grid_buf,
GemmBlockwiseCopy::q_block_desc_k0_m_k1,
q_block_buf,
I0);
block_sync_lds();
s_blockwise_gemm.Run(q_block_buf, k_block_buf, s_slash_p_thread_buf);
// 8d thread_desc in thread scope
constexpr auto c_thread_lengths =
s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths();
// 8d block_desc in block scope
constexpr auto c_block_lengths =
s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths();
constexpr auto M0 = c_block_lengths[I0];
constexpr auto N0 = c_block_lengths[I1];
constexpr auto M1 = c_block_lengths[I2];
constexpr auto N1 = c_block_lengths[I3];
constexpr auto M2 = c_block_lengths[I4];
constexpr auto M3 = c_block_lengths[I5];
constexpr auto M4 = c_block_lengths[I6];
constexpr auto N2 = c_block_lengths[I7];
// works like multi-dimension static_for (static_ford), but provides both the linear
// index as well as n-d index
using Acc0TileIterator = SpaceFillingCurve<
decltype(c_thread_lengths),
typename arithmetic_sequence_gen<0, c_thread_lengths.Size(), 1>::type,
typename uniform_sequence_gen<c_thread_lengths.Size(), 1>::type,
false>; // SnakeCurved
constexpr auto block_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(M0, M1, M2, M3, M4)),
make_unmerge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{}));
// do MNK padding or upper triangular masking
if constexpr(MaskOutUpperTriangle || PadN)
{
static_for<0, Acc0TileIterator::GetNumOfAccess(), 1>{}([&](auto i) {
auto acc0_thread_idx = Acc0TileIterator::GetIndex(i) + acc0_thread_origin;
auto m_local =
block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto n_local =
block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
bool masked_flag = c0_matrix_mask.IsMaskedElement(m_global, n_global);
s_element_op(s_slash_p_thread_buf(i),
masked_flag ? -ck::NumericLimits<float>::Infinity()
: s_slash_p_thread_buf[i]);
});
}
else
{
static_for<0, s_slash_p_thread_buf.Size(), 1>{}([&](auto i) {
s_element_op(s_slash_p_thread_buf(i), s_slash_p_thread_buf[i]);
});
}
block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
// add bias
if constexpr(!is_same<D0DataType, void>::value)
{
if(p_d0_grid != nullptr)
{
static constexpr auto& c_thread_desc = s_blockwise_gemm.GetCThreadDesc();
const auto d0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d0_grid, d0_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
auto d0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
D0Operator::d0_block_dst_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>(
D0Operator::d0_thread_desc_.GetElementSpaceSize());
static_for<0, D0M0, 1>{}([&](auto mr) {
// load data to lds
d0_block_copy_global_to_lds.RunRead(d0_grid_desc_m0_n0_m1_m2_n1_m3,
d0_grid_buf);
d0_block_copy_global_to_lds.MoveSrcSliceWindow(
d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0));
d0_block_copy_global_to_lds.RunWrite(
D0Operator::d0_block_dst_desc_m0_n0_m1_m2_n1_m3, d0_block_buf);
block_sync_lds();
// read data form lds
d0_thread_copy_lds_to_vgpr.Run(D0Operator::d0_block_src_desc_n0_n1_m0_m1_m2,
make_tuple(I0, I0, I0, I0, I0),
d0_block_buf,
D0Operator::d0_thread_desc_,
make_tuple(I0, I0, I0, I0, I0),
d0_thread_buf);
// bias add
static_for<0, d0_thread_buf.Size(), 1>{}([&](auto i) {
constexpr index_t c_offset =
c_thread_desc.CalculateOffset(make_tuple(mr, I0, i));
s_slash_p_thread_buf(Number<c_offset>{}) +=
ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]);
});
});
d0_block_copy_global_to_lds.MoveSrcSliceWindow(
d0_grid_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(-1, 0, -D0M0.value, 0, 0, 0));
}
}
// P_i: = softmax(scalar * S_i:)
// scaling is already performed in the preceding statements with s_element_op
blockwise_softmax.RunWithPreCalcStats(s_slash_p_thread_buf, lse_thread_buf);
// save z to global
if constexpr(IsDropout)
{
if(p_z_grid)
{
auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
auto m_local =
block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto n_local =
block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
auto global_tile_id = z_random_matrix_offset +
(m_global / DropoutTile) * DropoutTile * raw_n_padded +
(n_global / DropoutTile) * DropoutTile;
auto global_elem_id = global_tile_id + (wave_m_n_id[I0] * M4) +
(n_global % DropoutTile) * raw_n_padded;
blockwise_dropout
.template ApplyDropoutAttnBwdSaveZ<decltype(s_slash_p_thread_buf),
decltype(z_tensor_buffer),
decltype(DropoutTile),
true>(s_slash_p_thread_buf,
ph,
global_elem_id,
z_tensor_buffer,
raw_n_padded);
z_thread_copy_vgpr_to_global.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tensor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
z_grid_buf);
}
else
{
ignore = z_grid_buf;
auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
auto m_local =
block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto n_local =
block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
auto global_tile_id = z_random_matrix_offset +
(m_global / DropoutTile) * DropoutTile * raw_n_padded +
(n_global / DropoutTile) * DropoutTile;
auto global_elem_id = global_tile_id + (wave_m_n_id[I0] * M4) +
(n_global % DropoutTile) * raw_n_padded;
// P_dropped
blockwise_dropout.template ApplyDropoutAttnBwd<decltype(s_slash_p_thread_buf),
decltype(DropoutTile),
true>(
s_slash_p_thread_buf, ph, global_elem_id, raw_n_padded);
}
}
block_sync_lds(); // wait for gemm1 LDS read
// dS = P * (dP - Y_dot_dY)
auto& sgrad_thread_buf = pgrad_thread_buf;
constexpr auto pgrad_thread_tile_iterator =
pgrad_blockwise_gemm.MakeCThreadTileIterator();
constexpr auto pgrad_thread_idx_to_m_n_adaptor =
pgrad_blockwise_gemm.MakeCThreadIndexAdaptor8DTo2D();
static_for<0, pgrad_thread_tile_iterator.GetNumOfAccess(), 1>{}([&](auto i) {
constexpr auto pgrad_thread_idx = pgrad_thread_tile_iterator.GetIndex(i);
constexpr auto m =
pgrad_thread_idx_to_m_n_adaptor.CalculateBottomIndex(pgrad_thread_idx)[I0];
// dS and P has same thread buf layout
bool undropped_flag = s_slash_p_thread_buf[i] >= 0;
sgrad_thread_buf(i) =
s_slash_p_thread_buf[i] *
(undropped_flag ? (pgrad_thread_buf[i] - y_dot_ygrad_thread_buf[Number<m>{}])
: y_dot_ygrad_thread_buf[Number<m>{}]);
});
// output bias grad
if constexpr(!is_same<D0DataType, void>::value)
{
if(p_d0grad_grid != nullptr)
{
auto d0grad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d0grad_grid, d0grad_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
auto d0grad_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
static_for<0, D0M0, 1>{}([&](auto mr) {
d0grad_thread_copy_vgpr_to_lds.Run(
D0Operator::d0_thread_desc_,
make_tuple(mr, I0, I0, I0, I0),
sgrad_thread_buf,
D0Operator::d0grad_block_dst_desc_n0_n1_m0_m1_m2,
d0grad_block_buf);
block_sync_lds();
// write data from lds to global
d0grad_block_copy_lds_to_global.Run(
D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3,
d0grad_block_buf,
d0grad_grid_desc_m0_n0_m1_m2_n1_m3,
d0grad_grid_buf,
I0);
d0grad_block_copy_lds_to_global.MoveDstSliceWindow(
d0grad_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0));
});
d0grad_block_copy_lds_to_global.MoveDstSliceWindow(
d0grad_grid_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(-1, 0, -D0M0.value, 0, 0, 0));
}
}
// gemm dV
// dV = P_drop^T * dY
{
// TODO: explore using dynamic buffer for a1 thread buffer
// For a1_blockwise_copy, the goal is to satisfy pipeline requirements RunRead(),
// RunWrite(), and MoveSliceWindow(). But it is impossible to implement given that
// the A1 source buffer is static buffer holding the output of first GEMM and
// requires constexpr offset by design. Therefore, we pass tensor coordinate offset
// explicitly in Run() below.
// main body
static_for<0, num_gemm1_k_block_inner_loop, 1>{}([&](auto i) {
vgrad_gemm_tile_p_blockwise_copy.Run(Gemm1::a_src_thread_desc_k0_m_k1,
Gemm1::a_block_slice_copy_step * i,
s_slash_p_thread_buf,
Gemm1::a_thread_desc_k0_m_k1,
make_tuple(I0, I0, I0),
gemm1_a_thread_buf);
vgrad_gemm_tile_ygrad_blockwise_copy.Run(
Gemm1::b_block_desc_n0_n1_n2_k0_k1_k2_k3,
ygrad_block_buf,
Gemm1::b_thread_desc_n0_n1_n2_k0_k1_k2_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0),
gemm1_b_thread_buf);
vgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
Gemm1::b_block_desc_n0_n1_n2_k0_k1_k2_k3, Gemm1::b_block_slice_copy_step);
block_sync_lds();
vgrad_blockwise_gemm.Run(
gemm1_a_thread_buf, gemm1_b_thread_buf, vgrad_thread_buf);
// block_sync_lds();
});
} // end gemm dV
// gemm dK
// dK = scalar * dS^T * Q
{
// TODO: explore using dynamic buffer for a1 thread buffer
// For a1_blockwise_copy, the goal is to satisfy pipeline requirements RunRead(),
// RunWrite(), and MoveSliceWindow(). But it is impossible to implement given that
// the A1 source buffer is static buffer holding the output of first GEMM and
// requires constexpr offset by design. Therefore, we pass tensor coordinate offset
// explicitly in Run() below.
// main body
static_for<0, num_gemm1_k_block_inner_loop, 1>{}([&](auto i) {
kgrad_gemm_tile_sgrad_blockwise_copy.Run(Gemm1::a_src_thread_desc_k0_m_k1,
Gemm1::a_block_slice_copy_step * i,
sgrad_thread_buf,
Gemm1::a_thread_desc_k0_m_k1,
make_tuple(I0, I0, I0),
gemm1_a_thread_buf);
kgrad_gemm_tile_q_blockwise_copy.Run(Gemm1::b_block_desc_n0_n1_n2_k0_k1_k2_k3,
q_block_buf,
Gemm1::b_thread_desc_n0_n1_n2_k0_k1_k2_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0),
gemm1_b_thread_buf);
kgrad_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(
Gemm1::b_block_desc_n0_n1_n2_k0_k1_k2_k3, Gemm1::b_block_slice_copy_step);
block_sync_lds();
kgrad_blockwise_gemm.Run(
gemm1_a_thread_buf, gemm1_b_thread_buf, kgrad_thread_buf);
// block_sync_lds();
});
} // end gemm dK
SubThreadBlock<BlockSize> gemm2_a_copy_subgroup(s_blockwise_gemm.GetWaveIdx()[I0],
s_blockwise_gemm.GetWaveIdx()[I1]);
constexpr index_t num_gemm2_loop = NPerBlock / Gemm2Params::Sum_K;
static_assert(Gemm2::ASrcBlockSliceWindowIterator::GetNumOfAccess() == num_gemm2_loop,
"");
// TODO: tune gemm2 pipeline
// gemm dQ
// dQ = scalar * dS * K
qgrad_thread_buf.Clear();
static_for<0, num_gemm2_loop, 1>{}([&](auto gemm2_loop_idx) { // gemm dQ
// load VGrad Gemm A
const auto sgrad_slice_idx =
Gemm2::ASrcBlockSliceWindowIterator::GetIndexTupleOfNumber(gemm2_loop_idx);
constexpr auto mwave_range = make_tuple(
sgrad_slice_idx[I2],
sgrad_slice_idx[I2] + Gemm2Params::ABlockSliceLengths_M0_K0_M1_K1::At(I2));
constexpr auto nwave_range = make_tuple(
sgrad_slice_idx[I3],
sgrad_slice_idx[I3] + Gemm2Params::ABlockSliceLengths_M0_K0_M1_K1::At(I3));
block_sync_lds(); // sync before write
if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range))
{
qgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds.Run(
Gemm2::a_src_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_tuple(
sgrad_slice_idx[I0], sgrad_slice_idx[I1], I0, I0, I0, I0, I0, I0),
sgrad_thread_buf,
Gemm2::a_block_desc_m0_k0_m1_k1_m2_m3_m4_k2,
gemm2_a_block_buf);
}
qgrad_gemm_tile_k_blockwise_copy.Run(Gemm2::b_block_desc_n0_n1_n2_k0_k1_k2_k3,
k_block_buf,
Gemm2::b_thread_desc_n0_n1_n2_k0_k1_k2_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0),
gemm2_b_thread_buf);
qgrad_gemm_tile_k_blockwise_copy.MoveSrcSliceWindow(
Gemm2::b_block_desc_n0_n1_n2_k0_k1_k2_k3, Gemm2::b_block_slice_copy_step);
block_sync_lds(); // sync before read
qgrad_blockwise_gemm.Run(gemm2_a_block_buf, gemm2_b_thread_buf, qgrad_thread_buf);
}); // end gemm dQ
// atomic_add dQ
qgrad_thread_copy_vgpr_to_global.Run(Gemm2::c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
qgrad_thread_buf,
qgrad_grid_desc_m0_o0_m1_o1_m2_o2_o3_o4,
qgrad_grid_buf);
// move slice window
gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(
q_grid_desc_k0_m_k1,
GemmBlockwiseCopy::gemm_tile_q_block_slice_copy_step); // step M
gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
ygrad_grid_desc_o0_m_o1,
GemmBlockwiseCopy::gemm_tile_ygrad_block_slice_copy_step); // step M
vgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
Gemm1::b_block_desc_n0_n1_n2_k0_k1_k2_k3,
Gemm1::b_block_reset_copy_step); // rewind M
qgrad_gemm_tile_k_blockwise_copy.MoveSrcSliceWindow(
Gemm2::b_block_desc_n0_n1_n2_k0_k1_k2_k3,
Gemm2::b_block_reset_copy_step); // rewind K
kgrad_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(
Gemm1::b_block_desc_n0_n1_n2_k0_k1_k2_k3,
Gemm1::b_block_reset_copy_step); // rewind M
qgrad_thread_copy_vgpr_to_global.MoveDstSliceWindow(
qgrad_grid_desc_m0_o0_m1_o1_m2_o2_o3_o4, Gemm2::c_block_slice_copy_step); // step M
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_multi_index(-1, 0, 0, 0, 0, 0, 0, 0, 0, 0));
lse_thread_copy_global_to_vgpr.MoveSrcSliceWindow(lse_grid_desc_mb_m0_m1_m2_m3_m4,
make_multi_index(-1, 0, 0, 0, 0, 0));
y_threadwise_copy.MoveSrcSliceWindow(y_grid_desc_mblock_mperblock_oblock_operblock,
make_multi_index(-1, 0, 0, 0));
} while(0 < gemm0_m_block_outer_index--); // end j loop
// shuffle dK&dV and write
{
static_assert(Gemm1::GemmMRepeat % CShuffleMXdlPerWavePerShuffle == 0 &&
Gemm1::GemmNRepeat % CShuffleNXdlPerWavePerShuffle == 0,
"wrong!");
constexpr index_t MWave = Gemm1::GemmMWave;
constexpr index_t NWave = Gemm1::GemmNWave;
// TODO: hacky, fix it!
// thread desc same with kgrad_blockwise_gemm
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
vgrad_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
// TODO: hacky, fix it!
// c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp is only used to get lengths
// block desc same with kgrad_blockwise_gemm
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp =
vgrad_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0);
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1);
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2);
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3);
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5);
constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6);
constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7);
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatCShuffle*>(p_shared),
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_tuple(
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
M1, // M1 = MWave
M2)), // M2 = MPerXdl
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
N1, // N1 = NWave
N2, // N2 * N3 * N4 = NPerXdl
N3,
N4))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(
Sequence<>{}, Sequence<0, 2, 4>{}, Sequence<>{}, Sequence<1, 3, 5, 6, 7>{}));
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
// index same with kgrad_blockwise_gemm
const auto c_thread_mtx_on_block =
vgrad_blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
const auto m_thread_data_on_block_to_m0_m1_m2_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_block_idx =
m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_block));
const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3, N4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_block_idx =
n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));
// shuffle: threadwise copy C from VGPR to LDS
auto vgrad_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<FloatGemmAcc,
FloatCShuffle,
decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
decltype(c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
SElementwiseOperation,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
I1,
I1,
N2,
I1,
N4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>{
c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
n_thread_data_on_block_idx[I2],
n_thread_data_on_block_idx[I3],
n_thread_data_on_block_idx[I4]),
tensor_operation::element_wise::Scale{rp_dropout}};
// shuffle: blockwise copy C from LDS to global
auto vgrad_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
ThisThreadBlock, // ThreadGroup
CElementwiseOperation, // ElementwiseOperation,
CGlobalMemoryDataOperation, // DstInMemOp,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
FloatCShuffle, // typename SrcData,
OutputDataType, // typename DstData,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
decltype(vgrad_grid_desc_nblock_nperblock_oblock_operblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
true, // bool ThreadTransferSrcResetCoordinateAfterRun,
false> // bool ThreadTransferDstResetCoordinateAfterRun>
{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(0, 0, 0, 0),
vgrad_grid_desc_nblock_nperblock_oblock_operblock,
make_multi_index(block_work_idx_n, 0, block_work_idx[I1], 0),
c_element_op};
// shuffle: threadwise copy C from VGPR to LDS
auto kgrad_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<FloatGemmAcc,
FloatCShuffle,
decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
decltype(c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
SElementwiseOperation,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
I1,
I1,
N2,
I1,
N4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>{
c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
n_thread_data_on_block_idx[I2],
n_thread_data_on_block_idx[I3],
n_thread_data_on_block_idx[I4]),
scale_rp_dropout};
// shuffle: blockwise copy C from LDS to global
auto kgrad_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
ThisThreadBlock, // ThreadGroup
CElementwiseOperation, // ElementwiseOperation,
CGlobalMemoryDataOperation, // DstInMemOp,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
FloatCShuffle, // typename SrcData,
OutputDataType, // typename DstData,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
decltype(kgrad_grid_desc_nblock_nperblock_oblock_operblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
true, // bool ThreadTransferSrcResetCoordinateAfterRun,
false> // bool ThreadTransferDstResetCoordinateAfterRun>
{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(0, 0, 0, 0),
kgrad_grid_desc_nblock_nperblock_oblock_operblock,
make_multi_index(block_work_idx_n, 0, block_work_idx[I1], 0),
c_element_op};
// space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr = SpaceFillingCurve<
Sequence<Gemm1::GemmMRepeat, Gemm1::GemmNRepeat, 1, 1, 1, N2, 1, N4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
1,
1,
1,
N2,
1,
N4>>{};
// space filling curve for shuffled blockwise C in global mem
constexpr auto sfc_c_global =
SpaceFillingCurve<Sequence<1, NPerBlock, 1, Gemm1NPerBlock>,
Sequence<0, 2, 1, 3>,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
// dK
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
block_sync_lds();
// each thread write its data from VGPR to LDS
kgrad_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
kgrad_thread_buf,
c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
c_shuffle_block_buf);
// make sure it's safe to read from LDS
block_sync_lds();
// each block copy its data from LDS to global
kgrad_shuffle_block_copy_lds_to_global.Run(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_shuffle_block_buf,
kgrad_grid_desc_nblock_nperblock_oblock_operblock,
kgrad_grid_buf);
if constexpr(access_id < num_access - 1)
{
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
// move on C
kgrad_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
kgrad_grid_desc_nblock_nperblock_oblock_operblock, c_global_step);
}
});
// dV
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
block_sync_lds();
// each thread write its data from VGPR to LDS
vgrad_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
vgrad_thread_buf,
c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
c_shuffle_block_buf);
// make sure it's safe to read from LDS
block_sync_lds();
// each block copy its data from LDS to global
vgrad_shuffle_block_copy_lds_to_global.Run(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_shuffle_block_buf,
vgrad_grid_desc_nblock_nperblock_oblock_operblock,
vgrad_grid_buf);
if constexpr(access_id < num_access - 1)
{
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
// move on C
vgrad_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
vgrad_grid_desc_nblock_nperblock_oblock_operblock, c_global_step);
}
});
}
}
};
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/utility/philox_rand.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_softmax.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_dropout.hpp"
namespace ck {
template <typename InputDataType,
typename D0DataType,
typename OutputDataType,
typename ZDataType,
typename GemmDataType,
typename FloatGemmAcc,
typename FloatCShuffle,
typename FloatLSE,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename SElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename QGridDesc_K0_M_K1,
typename KGridDesc_K0_N_K1,
typename KGridDesc_N_K,
typename D0GridDesc_M_N,
typename ZGridDesc_M_N,
typename VGridDesc_O0_N_O1,
typename YGridDesc_M_O,
typename LSEGridDesc_M,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t Gemm1NPerBlock,
index_t Gemm1KPerBlock,
index_t Gemm2KPerBlock,
index_t AK1Value,
index_t BK1Value,
index_t B1K1Value,
index_t MPerXdl,
index_t NPerXdl,
index_t MXdlPerWave,
index_t NXdlPerWave,
index_t Gemm1NXdlPerWave,
index_t Gemm2NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool AThreadTransferSrcResetCoordinateAfterRun, // ignored
index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BThreadTransferSrcResetCoordinateAfterRun, // ignored
index_t BBlockLdsExtraN,
index_t D0BlockTransferSrcScalarPerVector,
typename B1BlockTransferThreadClusterLengths_BK0_N_BK1,
typename B1BlockTransferThreadClusterArrangeOrder,
typename B1BlockTransferSrcAccessOrder,
index_t B1BlockTransferSrcVectorDim,
index_t B1BlockTransferSrcScalarPerVector,
index_t B1BlockTransferDstScalarPerVector_BK1,
bool B1ThreadTransferSrcResetCoordinateAfterRun,
index_t B1BlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched,
bool PadN,
bool MaskOutUpperTriangle,
bool Deterministic,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
{
static_assert(Gemm1NPerBlock % KPerBlock == 0);
static_assert(MPerBlock % Gemm1KPerBlock == 0);
static_assert(NPerBlock % Gemm2KPerBlock == 0);
static_assert(LoopSched == LoopScheduler::Default,
"Non-default loop scheduler is currently not supported");
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
static constexpr auto WaveSize = 64;
// K1 should be Number<...>
// Gemm0
static constexpr auto AK0 = Number<KPerBlock / AK1Value>{};
static constexpr auto BK0 = Number<KPerBlock / BK1Value>{};
static constexpr auto AK1 = Number<AK1Value>{};
static constexpr auto BK1 = Number<BK1Value>{};
static constexpr auto Gemm0MWaves = MPerBlock / (MPerXdl * MXdlPerWave);
static constexpr auto Gemm0NWaves = NPerBlock / (NPerXdl * NXdlPerWave);
// Gemm1
static constexpr auto B1K0 = Number<Gemm1KPerBlock / B1K1Value>{};
static constexpr auto B1K1 = Number<B1K1Value>{};
static constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
static constexpr auto K_K0 = Number<Gemm1NPerBlock / BK1Value>{};
static constexpr auto V_K3 = BK1;
static constexpr auto V_K2 = mfma.num_input_blks;
static constexpr auto V_K1 = KPerBlock / V_K2 / V_K3;
static constexpr auto V_K0 = Gemm1NPerBlock / KPerBlock;
static constexpr auto V_N1 = NXdlPerWave;
static constexpr auto DropoutNThread = mfma.num_input_blks; // 2
// get_random_16x8() generates 16 random numbers each time
static constexpr auto DropoutTile = Number<DropoutNThread * 16>{}; // 32
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
// C desc for source in blockwise copy
__host__ __device__ static constexpr auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(const ZGridDesc_M_N& z_grid_desc_m_n)
{
const auto M = z_grid_desc_m_n.GetLength(I0);
const auto N = z_grid_desc_m_n.GetLength(I1);
constexpr auto M3 = mfma.num_groups_per_blk;
constexpr auto M4 = mfma.num_input_blks;
constexpr auto M5 = mfma.group_size;
return transform_tensor_descriptor(
z_grid_desc_m_n,
make_tuple(make_unmerge_transform(
make_tuple(M / MPerBlock, MXdlPerWave, Gemm0MWaves, M3, M4, M5)),
make_unmerge_transform(
make_tuple(N / NPerBlock, NXdlPerWave, Gemm0NWaves, NPerXdl))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 6, 7, 8>{}, Sequence<1, 3, 5, 9>{}));
}
__host__ __device__ static constexpr auto GetPaddedSize(const index_t size)
{
return math::integer_divide_ceil(size, DropoutTile) * DropoutTile;
}
__device__ static auto GetGemm0WaveIdx()
{
const index_t thread_id = get_thread_local_1d_id();
constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(Gemm0MWaves, Gemm0NWaves, WaveSize))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
__device__ static auto GetGemm0WaveMNIdx(const index_t thread_id)
{
constexpr auto wave_threadid_to_mn_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(WaveSize / MPerXdl, MPerXdl))),
make_tuple(Sequence<0, 1>{}),
make_tuple(Sequence<0>{}));
return wave_threadid_to_mn_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
// A matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(AK0, Number<MPerBlock>{}, AK1),
make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1, AK1, I1));
}
__host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
{
// B matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(BK0, Number<NPerBlock>{}, BK1),
make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1));
}
__host__ __device__ static constexpr auto GetKBlockDescriptor_K0PerBlock_NPerBlock_K1()
{
// K matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(K_K0, Number<NPerBlock>{}, BK1),
make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1));
}
__host__ __device__ static constexpr auto GetVThreadDescriptor_K0_K1_K2_N0_N1_N2_N3_K3()
{
// V matrix in Vgpr, dst of threadwise copy
return make_naive_tensor_descriptor_packed(
make_tuple(I1, Number<V_K1>{}, I1, I1, Number<V_N1>{}, I1, I1, Number<V_K3>{}));
}
template <typename AccThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2>
__host__ __device__ static constexpr auto GetA1SrcThreadDescriptor_AK0PerBlock_MPerBlock_AK1(
const AccThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2& acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2)
{
// acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 to a_src_thread_desc_k0_m_k1
// m0_m1_m2_m3 -> k0
// n0_n1_n2 -> m
// m4 -> k1
// NOTE: had to use merge_v3 or will spit out compilation errors
const auto m0 = acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0);
const auto n0 = acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1);
const auto m1 = acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2);
const auto n1 = acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3);
const auto m2 = acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4);
const auto m3 = acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5);
const auto m4 = acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6);
const auto n2 = acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7);
return transform_tensor_descriptor(
acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_tuple(make_merge_transform_v3_division_mod(make_tuple(m0, m1, m2, m3)),
make_merge_transform_v3_division_mod(make_tuple(n0, n1, n2)),
make_pass_through_transform(m4)),
make_tuple(Sequence<0, 2, 4, 5>{}, Sequence<1, 3, 7>{}, Sequence<6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
}
__host__ __device__ static constexpr auto GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1()
{
// B1 matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(B1K0, Number<Gemm1NPerBlock>{}, B1K1),
make_tuple(Number<Gemm1NPerBlock + B1BlockLdsExtraN>{} * B1K1, B1K1, I1));
}
__host__ __device__ static constexpr auto
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
{
constexpr index_t MWave = Gemm0NWaves;
constexpr index_t NWave = Gemm0MWaves;
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl>{},
I1,
Number<CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>{}));
return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
}
template <typename Gemm2Param>
__host__ __device__ static constexpr auto GetA2BlockDescriptor_K0_M_K1()
{
return make_naive_tensor_descriptor(
make_tuple(Number<Gemm2Param::A_K0>{},
Number<Gemm2Param::Gemm2_M>{},
Number<Gemm2Param::A_K1>{}),
make_tuple(Number<Gemm2Param::Gemm2_M + Gemm2Param::A_LdsPad>{} *
Number<Gemm2Param::A_K1>{},
Number<Gemm2Param::A_K1>{},
I1));
}
template <typename Gemm2Param>
__host__ __device__ static constexpr auto GetB2BlockDescriptor_K0_N_K1()
{
return make_naive_tensor_descriptor(
make_tuple(Number<Gemm2Param::B_K0>{},
Number<Gemm2Param::Gemm2_N>{},
Number<Gemm2Param::B_K1>{}),
make_tuple(Number<Gemm2Param::Gemm2_N + Gemm2Param::B_LdsPad>{} *
Number<Gemm2Param::B_K1>{},
Number<Gemm2Param::B_K1>{},
I1));
}
__host__ __device__ static constexpr bool
CheckValidity(const QGridDesc_K0_M_K1& q_grid_desc_k0_m_k1,
const KGridDesc_K0_N_K1& k_grid_desc_k0_n_k1,
const VGridDesc_O0_N_O1& v_grid_desc_o0_n_o1,
const YGridDesc_M_O& y_grid_desc_m_o)
{
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
"Invalid tuning param!");
const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
const auto K = q_grid_desc_k0_m_k1.GetLength(I0) * q_grid_desc_k0_m_k1.GetLength(I2);
const auto O = v_grid_desc_o0_n_o1.GetLength(I0) * v_grid_desc_o0_n_o1.GetLength(I2);
// This assumption reduces implemention complexity by categorizing 6 separate GEMMs into 3
// types of GEMM operations, therefore some code body can be reused accordingly
// P_MNK / dP_MNO Gemm (Gemm0 rcr)
// Y_MON / dQ_MKN Gemm (Gemm1 rrr)
// dV_NOM / dK_NKM Gemm (Gemm2 crr)
if(O != K)
{
std::cerr << "SizeK must be equal to SizeO (equal attention head size)" << '\n';
return false;
}
if(!(M == y_grid_desc_m_o.GetLength(I0) && O == y_grid_desc_m_o.GetLength(I1)))
{
return false;
}
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0 &&
O % Gemm1NPerBlock == 0))
{
return false;
}
// check gemm0 gridwise gemm pipeline
const auto num_gemm0_k_loop = K / KPerBlock;
if(!GridwiseGemmPipe::IsSupported(num_gemm0_k_loop))
{
return false;
}
// check gemm1 gridwise gemm pipeline
if(!(NPerBlock % Gemm1KPerBlock == 0))
{
return false;
}
const auto num_gemm1_k_inner_loop = NPerBlock / Gemm1KPerBlock;
if(!GridwiseGemmPipe::IsSupported(num_gemm1_k_inner_loop))
{
return false;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true;
}
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
const index_t num_loop = K / KPerBlock;
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
}
__host__ __device__ static constexpr auto
MakeYGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock(const YGridDesc_M_O& y_grid_desc_m_o)
{
const auto M = y_grid_desc_m_o.GetLength(I0);
const auto O = y_grid_desc_m_o.GetLength(I1);
const auto MBlock = M / MPerBlock;
const auto OBlock = O / Gemm1NPerBlock;
const auto y_grid_desc_mblock_mperblock_oblock_operblock = transform_tensor_descriptor(
y_grid_desc_m_o,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
make_unmerge_transform(make_tuple(OBlock, Number<Gemm1NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
return y_grid_desc_mblock_mperblock_oblock_operblock;
}
template <typename SrcBlockwiseGemm>
__host__ __device__ static constexpr auto
MakeLSEGridDescriptor_MB_M0_M1_M2_M3_M4(const LSEGridDesc_M& lse_grid_desc_m)
{
const index_t M = lse_grid_desc_m.GetLength(I0);
const index_t MBlock = M / MPerBlock;
constexpr auto SrcBlockDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
SrcBlockwiseGemm::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
// M0 MXdlPerWave, M1 MWave, M2 num_groups_per_blk, M3 num_input_blks, M4 group_size
const auto M0 = SrcBlockDesc_M0_N0_M1_N1_M2_M3_M4_N2.GetLength(I0);
const auto M1 = SrcBlockDesc_M0_N0_M1_N1_M2_M3_M4_N2.GetLength(I2);
const auto M2 = SrcBlockDesc_M0_N0_M1_N1_M2_M3_M4_N2.GetLength(I4);
const auto M3 = SrcBlockDesc_M0_N0_M1_N1_M2_M3_M4_N2.GetLength(I5);
const auto M4 = SrcBlockDesc_M0_N0_M1_N1_M2_M3_M4_N2.GetLength(I6);
const auto lse_grid_desc_mb_m0_m1_m2_m3_m4 = transform_tensor_descriptor(
lse_grid_desc_m,
make_tuple(make_unmerge_transform(make_tuple(MBlock, M0, M1, M2, M3, M4))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2, 3, 4, 5>{}));
return lse_grid_desc_mb_m0_m1_m2_m3_m4;
}
__device__ static auto MakeKGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock(
const KGridDesc_K0_N_K1& k_grid_desc_k0_n_k1)
{
const auto O0 = k_grid_desc_k0_n_k1.GetLength(I0);
const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
const auto O1 = k_grid_desc_k0_n_k1.GetLength(I2);
const auto O = O0 * O1;
const auto NBlock = N / NPerBlock;
const auto OBlock = O / Gemm1NPerBlock;
const auto k_grid_desc_n_o = transform_tensor_descriptor(
k_grid_desc_k0_n_k1,
make_tuple(make_pass_through_transform(N),
make_merge_transform_v3_division_mod(make_tuple(O0, O1))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return transform_tensor_descriptor(
k_grid_desc_n_o,
make_tuple(make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{})),
make_unmerge_transform(make_tuple(OBlock, Number<Gemm1NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
}
__device__ static auto MakeVGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock(
const VGridDesc_O0_N_O1& v_grid_desc_o0_n_o1)
{
const auto O0 = v_grid_desc_o0_n_o1.GetLength(I0);
const auto N = v_grid_desc_o0_n_o1.GetLength(I1);
const auto O1 = v_grid_desc_o0_n_o1.GetLength(I2);
const auto O = O0 * O1;
const auto NBlock = N / NPerBlock;
const auto OBlock = O / Gemm1NPerBlock;
const auto v_grid_desc_n_o = transform_tensor_descriptor(
v_grid_desc_o0_n_o1,
make_tuple(make_pass_through_transform(N),
make_merge_transform_v3_division_mod(make_tuple(O0, O1))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return transform_tensor_descriptor(
v_grid_desc_n_o,
make_tuple(make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{})),
make_unmerge_transform(make_tuple(OBlock, Number<Gemm1NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
}
__device__ static auto MakeQGradGridDesc_M_K(const QGridDesc_K0_M_K1& q_grid_desc_k0_m_k1)
{
const auto K0_ = q_grid_desc_k0_m_k1.GetLength(I0);
const auto M_ = q_grid_desc_k0_m_k1.GetLength(I1);
const auto K1_ = q_grid_desc_k0_m_k1.GetLength(I2);
return transform_tensor_descriptor(
q_grid_desc_k0_m_k1,
make_tuple(make_pass_through_transform(M_),
make_merge_transform_v3_division_mod(make_tuple(K0_, K1_))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
// return block_id to C matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto
MakeDefaultBlock2CTileMap(const KGridDesc_N_K& k_grid_desc_n_k)
{
return BlockToCTileMap_M00_N0_M01Adapt<NPerBlock, Gemm1NPerBlock, KGridDesc_N_K>(
k_grid_desc_n_k);
}
using YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock = remove_cvref_t<decltype(
MakeYGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock(YGridDesc_M_O{}))>;
using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(KGridDesc_N_K{}))>;
using ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3 = remove_cvref_t<decltype(
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(ZGridDesc_M_N{}))>;
// K / V
struct GemmBlockwiseCopy
{
__device__ static auto
MakeVGridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(const VGridDesc_O0_N_O1& v_grid_desc_o0_n_o1)
{
const auto K0_ = v_grid_desc_o0_n_o1.GetLength(I0);
const auto N_ = v_grid_desc_o0_n_o1.GetLength(I1);
const auto K1_ = v_grid_desc_o0_n_o1.GetLength(I2);
constexpr auto V_N3 = NPerXdl;
constexpr auto V_N2 = Gemm0NWaves;
const auto V_N0 = N_ / NPerBlock;
const auto v_grid_desc_n_k = transform_tensor_descriptor(
v_grid_desc_o0_n_o1,
make_tuple(make_pass_through_transform(N_),
make_merge_transform_v3_division_mod(make_tuple(K0_, K1_))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return transform_tensor_descriptor(
v_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(V_N0, V_N1, V_N2, V_N3)),
make_unmerge_transform(make_tuple(V_K0, V_K1, V_K2, V_K3))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<3, 4, 5, 6>{}, Sequence<0, 1, 2, 7>{}));
}
// K matrix in LDS, dst of blockwise copy
static constexpr auto k_block_desc_k0_n_k1 = GetKBlockDescriptor_K0PerBlock_NPerBlock_K1();
// V matrix in Vgpr, dst of threadwise copy
static constexpr auto v_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3 =
GetVThreadDescriptor_K0_K1_K2_N0_N1_N2_N3_K3();
template <typename GridDesc_K0_N_K1>
using KBlockwiseCopy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
tensor_operation::element_wise::PassThrough,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<K_K0, NPerBlock, BK1>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
InputDataType,
GemmDataType,
GridDesc_K0_N_K1,
decltype(k_block_desc_k0_n_k1),
BBlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
1,
true, // SrcResetCoord
true, // DstResetCoord
NumGemmKPrefetchStage>;
template <typename GridDesc_K0_K1_k2_N0_N1_N2_N3_K3>
using VBlockwiseCopy = ThreadwiseTensorSliceTransfer_v2<
InputDataType,
GemmDataType,
GridDesc_K0_K1_k2_N0_N1_N2_N3_K3,
decltype(v_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3),
decltype(v_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetLengths()),
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
BK1,
1,
true /* ResetCoordAfterRun */>;
static constexpr auto VBlockBufferSize = V_K0;
static constexpr auto v_block_slice_copy_step = make_multi_index(1, 0, 0, 0, 0, 0, 0, 0);
};
// dP Gemm (type 1 rcc, B in Vgpr)
template <typename BSrcThreadDesc_K0_K1_K2_N0_N1_N2_N3_K3>
struct Gemm0
{
// A matrix in LDS memory, dst of blockwise copy
static constexpr auto a_block_desc_ak0_m_ak1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
template <typename BThreadDesc_K0_K1_K2_N0_N1_N2_N3_K3>
__host__ __device__ static constexpr auto GetBThreadDescriptor_K0_N_K1(
const BThreadDesc_K0_K1_K2_N0_N1_N2_N3_K3& b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3)
{
// b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3 to b_thread_desc_k0_n_k1
// k0_k1_k2 -> k0
// n0_n1_n2_n3 -> n
// k3 -> k1
const auto k0 = b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetLength(I0);
const auto k1 = b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetLength(I1);
const auto k2 = b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetLength(I2);
const auto n0 = b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetLength(I3);
const auto n1 = b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetLength(I4);
const auto n2 = b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetLength(I5);
const auto n3 = b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetLength(I6);
const auto k3 = b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetLength(I7);
return transform_tensor_descriptor(
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(make_merge_transform_v3_division_mod(make_tuple(k0, k1, k2)),
make_merge_transform_v3_division_mod(make_tuple(n0, n1, n2, n3)),
make_pass_through_transform(k3)),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5, 6>{}, Sequence<7>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
}
static constexpr auto b_src_thread_desc_k0_n_k1 =
GetBThreadDescriptor_K0_N_K1(BSrcThreadDesc_K0_K1_K2_N0_N1_N2_N3_K3{});
template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto
MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
{
constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<MXdlPerWave, MWaves, MPerXdl>(
ABlockDesc_AK0_M_AK1{});
}
template <typename BBlockDesc_BK0_N_BK1>
__host__ __device__ static constexpr auto
MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
{
return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<NXdlPerWave, 1, 1>(
BBlockDesc_BK0_N_BK1{});
}
template <typename GridDesc_K0_M_K1>
using ABlockwiseCopy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
AElementwiseOperation,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0, MPerBlock, AK1>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
InputDataType,
GemmDataType,
GridDesc_K0_M_K1,
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
true, // SrcResetCoord
true, // DstResetCoord
NumGemmKPrefetchStage>;
static constexpr index_t KPack = math::max(math::lcm(AK1, BK1), mfma.k_per_blk);
// Blockwise gemm with transposed XDL output
using BlockwiseGemm = BlockwiseGemmXdlops_v2<
BlockSize,
GemmDataType,
FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_src_thread_desc_k0_n_k1),
decltype(MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(a_block_desc_ak0_m_ak1)),
decltype(MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(b_src_thread_desc_k0_n_k1)),
MPerBlock,
NPerBlock,
KPerBlock,
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
KPack,
false,
KPack * XdlopsGemm<GemmDataType, MPerXdl, NPerXdl, KPack, false>{}.K0PerXdlops,
KPack>;
static constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
};
// dV / dK Gemm (type 2 rrr)
template <typename ASrcThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2,
typename ASrcBlockDesc_M0_N0_M1_N1_M2_M3_M4_N2>
struct Gemm1
{
private:
static constexpr auto m0 = ASrcThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2{}.GetLength(I0);
static constexpr auto n0 = ASrcThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2{}.GetLength(I1);
static constexpr auto m1 = ASrcThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2{}.GetLength(I2);
static constexpr auto n1 = ASrcThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2{}.GetLength(I3);
static constexpr auto m2 = ASrcThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2{}.GetLength(I4);
static constexpr auto m3 = ASrcThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2{}.GetLength(I5);
static constexpr auto m4 = ASrcThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2{}.GetLength(I6);
static constexpr auto n2 = ASrcThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2{}.GetLength(I7);
// M2 num_groups_per_blk, M3 num_input_blks, M4 group_size
static constexpr auto M3 = ASrcBlockDesc_M0_N0_M1_N1_M2_M3_M4_N2{}.GetLength(I5);
public:
static constexpr auto AThreadSliceLength_K0 = Number<Gemm1KPerBlock / m4 / M3>{};
static constexpr auto AThreadSliceLength_M = Number<n0 * n1 * n2>{};
static constexpr auto AThreadSliceLength_K1 = Number<m4>{};
// A source matrix layout in AccVGPR
static constexpr auto a_src_thread_desc_k0_m_k1 =
GetA1SrcThreadDescriptor_AK0PerBlock_MPerBlock_AK1(
ASrcThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2{});
// A matrix in VGPR memory, dst of AccVGPR-to-VGPR copy
static constexpr auto a_thread_desc_k0_m_k1 = make_naive_tensor_descriptor_packed(
make_tuple(AThreadSliceLength_K0, AThreadSliceLength_M, AThreadSliceLength_K1));
// B matrix in LDS memory, dst of blockwise copy
static constexpr auto b_block_desc_bk0_n_bk1 =
GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto
MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
{
return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<NXdlPerWave, 1, 1>(
ABlockDesc_AK0_M_AK1{});
}
template <typename BBlockDesc_BK0_N_BK1>
__host__ __device__ static constexpr auto
MakeGemm1BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
{
constexpr index_t Gemm1NWaves = Gemm1NPerBlock / (Gemm1NXdlPerWave * NPerXdl);
return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<Gemm1NXdlPerWave, Gemm1NWaves, NPerXdl>(
BBlockDesc_BK0_N_BK1{});
}
static constexpr auto ASrcScalarPerVector = m4;
using AThreadSliceLengths_K0_M_K1 = decltype(a_thread_desc_k0_m_k1.GetLengths());
template <typename ElementwiseOp = tensor_operation::element_wise::PassThrough>
using ABlockwiseCopy =
ThreadwiseTensorSliceTransfer_StaticToStatic<FloatGemmAcc,
GemmDataType,
decltype(a_src_thread_desc_k0_m_k1),
decltype(a_thread_desc_k0_m_k1),
ElementwiseOp,
AThreadSliceLengths_K0_M_K1,
Sequence<1, 0, 2>,
2,
ASrcScalarPerVector>;
template <typename GridDesc_K0_N_K1>
using BBlockwiseCopy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<B1K0, Gemm1NPerBlock, B1K1>,
B1BlockTransferThreadClusterLengths_BK0_N_BK1,
B1BlockTransferThreadClusterArrangeOrder,
InputDataType,
GemmDataType,
GridDesc_K0_N_K1,
decltype(b_block_desc_bk0_n_bk1),
B1BlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
B1BlockTransferSrcVectorDim,
2,
B1BlockTransferSrcScalarPerVector,
B1BlockTransferDstScalarPerVector_BK1,
1,
1,
B1ThreadTransferSrcResetCoordinateAfterRun,
true, // DstResetCoord
NumGemmKPrefetchStage>;
// for a_block_slice_copy_step to be able to address static buffers, it MUST be a
// tuple-based container as well as containing ONLY integral constants
static constexpr auto a_block_slice_copy_step = make_tuple(AThreadSliceLength_K0, I0, I0);
static constexpr auto b_block_slice_copy_step =
make_multi_index(Gemm1KPerBlock / B1K1, 0, 0);
// selected_mfma.group_size or B1K1 <= Gemm1KPack <= selected_mfma.group_size
// selected_mfma.k_per_blk <= Gemm1KPack
//
// Following similar rationale behind Gemm0KPack, let Gemm1KPack be the lowest common
// multiples of A1K1 (predetermined by selected_mfma.group_size) and B1K1. But in this case
// Gemm1KPack can't be higher than A1K1 itself because A1 matrix is distributed in VGPRs
// with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
// cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
// therefore we may just as well assign Gemm1KPack = group_size
static constexpr index_t GemmKPack = mfma.group_size;
using BlockwiseGemm = BlockwiseGemmXdlops_v2<
BlockSize,
GemmDataType,
FloatGemmAcc,
decltype(a_thread_desc_k0_m_k1),
decltype(b_block_desc_bk0_n_bk1),
decltype(MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(a_thread_desc_k0_m_k1)),
decltype(MakeGemm1BMmaTileDescriptor_N0_N1_N2_K(b_block_desc_bk0_n_bk1)),
NPerBlock,
Gemm1NPerBlock,
Gemm1KPerBlock,
MPerXdl,
NPerXdl,
NXdlPerWave,
Gemm1NXdlPerWave,
GemmKPack,
true, // TransposeC
GemmKPack, // AMmaKStride
GemmKPack * XdlopsGemm<GemmDataType, MPerXdl, NPerXdl, GemmKPack, false>{}
.K0PerXdlops /* BMmaKStride */>;
};
// dQ Gemm (type 3 crr)
// Describes tuning parameter for C2_m_n = A2_m_k * B2_k_n
struct Gemm2Params
{
static constexpr index_t Gemm2_M = MPerBlock; // 64
static constexpr index_t Gemm2_K = NPerBlock; // 128
static constexpr index_t Gemm2_N = Gemm1NPerBlock; // 128
static constexpr index_t Sum_K = Gemm2KPerBlock;
static constexpr index_t A_K1 = 8; // dS will be row-major
static constexpr index_t A_K0 = Sum_K / A_K1;
static constexpr index_t A_LdsPad = 0; // how many multiples of K1 per M * K1 elements
static_assert(Sum_K % NPerXdl == 0, "");
static constexpr index_t GemmNWave = Gemm2_N / Gemm2NXdlPerWave / NPerXdl; // 1 // 2
static constexpr index_t GemmMWave = BlockSize / get_warp_size() / GemmNWave; // 4 // 2
static constexpr index_t GemmNRepeat = Gemm2NXdlPerWave; // 1 // 1
static constexpr index_t GemmMRepeat = Gemm2_M / GemmMWave / MPerXdl; // 1 // 1
static constexpr index_t GemmKLoop = Gemm2_K / Sum_K; // 2 // 2
static constexpr index_t GemmKPack = math::max(A_K1, mfma.k_per_blk);
static constexpr index_t B_K3 = GemmKPack; // 8
static constexpr index_t B_K2 =
XdlopsGemm<GemmDataType, MPerXdl, NPerXdl, GemmKPack, false>{}.K0PerXdlops; // 2
static constexpr index_t B_K1 = Sum_K / B_K2 / B_K3; // 4
static constexpr index_t B_K0 = GemmKLoop; // 2
__host__ __device__ static constexpr auto GetABlockSliceLengths_M0_K0_M1_K1_M2_K2()
{
// perform manual unmerge: n -> n_repeat, n_waves, n_per_xdl
constexpr index_t k = Sum_K - 1;
constexpr index_t k2 = k % NPerXdl;
constexpr index_t k1 = k / NPerXdl % Gemm0NWaves;
constexpr index_t k0 = k / NPerXdl / Gemm0NWaves % NXdlPerWave;
// perform manual unmerge: m -> m_repeat, m_waves, m_per_xdl
constexpr index_t m = Gemm2_M - 1;
constexpr index_t m2 = m % MPerXdl;
constexpr index_t m1 = m / MPerXdl % Gemm0MWaves;
constexpr index_t m0 = m / MPerXdl / Gemm0MWaves % MXdlPerWave;
// assume 256 decomposed into 2 x 4 x 32
// 1d idx ( 32 - 1) -> 3d idx 0, 0, 31 -> 3d dim 1 x 1 x 32
// 1d idx (256 - 1) -> 3d idx 1, 3, 31 -> 3d dim 2 x 4 x 32
return Sequence<m0, k0, m1, k1, m2, k2>{} + Sequence<1, 1, 1, 1, 1, 1>{};
}
__host__ __device__ static constexpr auto GetABlockSliceLengths_M0_K0_M1_K1()
{
return generate_sequence_v2(
[](auto I) { return GetABlockSliceLengths_M0_K0_M1_K1_M2_K2().At(I); },
Number<4>{});
}
using ABlockSliceLengths_M0_K0_M1_K1 =
decltype(GetABlockSliceLengths_M0_K0_M1_K1()); //(2, 1, 1, 2) //(4, 1, 1, 2)
};
// dQ Gemm (type 3 crr)
template <typename Gemm2Params, typename ASrcBlockwiseGemm, typename BSrcBlockDesc_N0_K_N1>
struct Gemm2
{
private:
static constexpr auto a_src_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
ASrcBlockwiseGemm::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
static constexpr auto M0 = a_src_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0); // repeat
static constexpr auto N0 = a_src_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1);
static constexpr auto M1 = a_src_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2); // wave
static constexpr auto N1 = a_src_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3);
static constexpr auto M2 = a_src_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4); // xdl
static constexpr auto M3 = a_src_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5);
static constexpr auto M4 = a_src_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6);
static constexpr auto N2 = a_src_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7);
public:
// A source matrix layout in VGPR, src of VGPR-to-LDS copy
static constexpr auto a_src_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
ASrcBlockwiseGemm::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
// A matrix in LDS memory, dst of blockwise copy
static constexpr auto a_block_desc_k0_m_k1 = GetA2BlockDescriptor_K0_M_K1<Gemm2Params>();
template <typename ABlockDesc_K0_M_K1>
__host__ __device__ static constexpr auto
MakeGemm2AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_K0_M_K1&)
{
return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<Gemm2Params::GemmMRepeat,
Gemm2Params::GemmMWave,
MPerXdl>(ABlockDesc_K0_M_K1{});
}
template <typename BBlockDesc_K0_N_K1>
__host__ __device__ static constexpr auto
MakeGemm2BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_K0_N_K1&)
{
return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<Gemm2Params::GemmNRepeat, 1, 1>(
BBlockDesc_K0_N_K1{});
}
__host__ __device__ static constexpr auto MakeABlockDesc_M0_K0_M1_K1_M2_M3_M4_K2()
{
const auto K0_ = a_block_desc_k0_m_k1.GetLength(I0);
const auto M_ = a_block_desc_k0_m_k1.GetLength(I1);
const auto K1_ = a_block_desc_k0_m_k1.GetLength(I2);
const auto a_block_desc_m_k = transform_tensor_descriptor(
a_block_desc_k0_m_k1,
make_tuple(make_merge_transform_v3_division_mod(make_tuple(K0_, K1_)),
make_pass_through_transform(M_)),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
// HACK: for unmerge transform, the length of highest dim is irrelevant so we put dummy
// variable I1 there
return transform_tensor_descriptor(
a_block_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(I1, M1, M2, M3, M4)),
make_unmerge_transform(make_tuple(I1, N1, N2))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{}));
}
// Note: we will perform sub-workgroup VGPR-to-LDS copy to save LDS space, therefore the
// destination coordinate can overlap between wavefronts in a workgroup as seen in the mod
// operation before returning the values
__host__ __device__ static auto MakeAThreadOriginOnBlock_M0_K0_M1_K1_M2_M3_M4_K2()
{
const auto a_thread_origin_on_block_idx =
ASrcBlockwiseGemm::CalculateCThreadOriginDataIndex8D(I0, I0, I0, I0);
constexpr auto a_block_slice_lengths_m0_k0_m1_k1 =
typename Gemm2Params::ABlockSliceLengths_M0_K0_M1_K1{}; // mrepeat, nrepeat,
// mwaves, nwaves,
return make_tuple(
a_thread_origin_on_block_idx[I0], // mrepeat
a_thread_origin_on_block_idx[I1], // nrepeat
a_thread_origin_on_block_idx[I2] % a_block_slice_lengths_m0_k0_m1_k1[I2], // mwave
a_thread_origin_on_block_idx[I3] % a_block_slice_lengths_m0_k0_m1_k1[I3], // nwave
a_thread_origin_on_block_idx[I4], // xdlops
a_thread_origin_on_block_idx[I5],
a_thread_origin_on_block_idx[I6],
a_thread_origin_on_block_idx[I7]);
}
static constexpr auto a_block_desc_m0_k0_m1_k1_m2_m3_m4_k2 =
MakeABlockDesc_M0_K0_M1_K1_M2_M3_M4_K2();
using ASrcBlockSliceWindowIterator =
SpaceFillingCurve<Sequence<M0, N0, M1, N1>,
Sequence<0, 1, 2, 3>,
typename Gemm2Params::ABlockSliceLengths_M0_K0_M1_K1,
false>;
template <typename ElementwiseOp = tensor_operation::element_wise::PassThrough>
using ABlockwiseCopy = ThreadwiseTensorSliceTransfer_v1r3<
FloatGemmAcc,
GemmDataType,
decltype(a_src_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(a_block_desc_m0_k0_m1_k1_m2_m3_m4_k2),
ElementwiseOp,
Sequence<Gemm2Params::ABlockSliceLengths_M0_K0_M1_K1::At(I0), // ThreadSliceLengths
Gemm2Params::ABlockSliceLengths_M0_K0_M1_K1::At(I1),
I1,
I1,
M2,
I1,
M4,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7, // DstVectorDim
1, // DstScalarPerVector
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
true>;
__host__ __device__ static constexpr auto MakeBBlockDesc_N0_N1_N2_K0_K1_K2_K3()
{
const auto N0_ = BSrcBlockDesc_N0_K_N1{}.GetLength(I0);
const auto K_ = BSrcBlockDesc_N0_K_N1{}.GetLength(I1);
const auto N1_ = BSrcBlockDesc_N0_K_N1{}.GetLength(I2);
constexpr auto b_block_desc_n_k = transform_tensor_descriptor( //(32, 128) //(64, 128)
BSrcBlockDesc_N0_K_N1{},
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(N0_, N1_)), //(4, 8) //(8, 8)
make_pass_through_transform(K_)), // 128 // 128
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return transform_tensor_descriptor(
b_block_desc_n_k,
make_tuple(
make_unmerge_transform(make_tuple(Gemm2Params::GemmNRepeat,
Gemm2Params::GemmNWave,
NPerXdl)), //(1, 1, 32) //(1, 2, 32)
make_unmerge_transform(
make_tuple(Gemm2Params::B_K0,
Gemm2Params::B_K1,
Gemm2Params::B_K2,
Gemm2Params::B_K3))), //(2, 4, 2, 8) //(2, 4, 2, 8)
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5, 6>{}));
}
static constexpr auto b_block_desc_n0_n1_n2_k0_k1_k2_k3 =
MakeBBlockDesc_N0_N1_N2_K0_K1_K2_K3();
using BThreadSlice_N0_N1_N2_K0_K1_K2_K3 =
Sequence<Gemm2Params::GemmNRepeat, 1, 1, 1, Gemm2Params::B_K1, 1, Gemm2Params::B_K3>;
static constexpr auto b_thread_desc_n0_n1_n2_k0_k1_k2_k3 =
make_naive_tensor_descriptor_packed(make_tuple(Number<Gemm2Params::GemmNRepeat>{},
I1,
I1,
I1,
Number<Gemm2Params::B_K1>{},
I1,
Number<Gemm2Params::B_K3>{}));
__host__ __device__ static constexpr auto MakeBThreadDesc_K0_N_K1()
{
constexpr auto b_thread_desc_n_k = transform_tensor_descriptor(
b_thread_desc_n0_n1_n2_k0_k1_k2_k3,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(Number<Gemm2Params::GemmNRepeat>{}, I1, I1)),
make_merge_transform_v3_division_mod(make_tuple(
I1, Number<Gemm2Params::B_K1>{}, I1, Number<Gemm2Params::B_K3>{}))),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5, 6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return transform_tensor_descriptor(
b_thread_desc_n_k,
make_tuple(make_pass_through_transform(Number<Gemm2Params::GemmNRepeat>{}),
make_unmerge_transform(make_tuple(Number<Gemm2Params::B_K1>{},
Number<Gemm2Params::B_K3>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
}
static constexpr auto b_thread_desc_k0_n_k1 = MakeBThreadDesc_K0_N_K1();
using BBlockwiseCopy =
ThreadwiseTensorSliceTransfer_v2<GemmDataType,
GemmDataType,
decltype(b_block_desc_n0_n1_n2_k0_k1_k2_k3),
decltype(b_thread_desc_n0_n1_n2_k0_k1_k2_k3),
BThreadSlice_N0_N1_N2_K0_K1_K2_K3,
Sequence<0, 1, 2, 3, 4, 5, 6>,
6,
1,
1,
true>;
static constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, 0, 1, 0, 0, 0);
static constexpr auto b_block_reset_copy_step =
make_multi_index(0, 0, 0, -Gemm2Params::B_K0, 0, 0, 0);
using BlockwiseGemm = BlockwiseGemmXdlops_v2<
BlockSize,
GemmDataType,
FloatGemmAcc,
decltype(a_block_desc_k0_m_k1),
decltype(b_thread_desc_k0_n_k1),
decltype(MakeGemm2AMmaTileDescriptor_M0_M1_M2_K(a_block_desc_k0_m_k1)),
decltype(MakeGemm2BMmaTileDescriptor_N0_N1_N2_K(b_thread_desc_k0_n_k1)),
MPerBlock,
Gemm1NPerBlock,
Gemm2Params::Sum_K,
MPerXdl,
NPerXdl,
Gemm2Params::GemmMRepeat,
Gemm2Params::GemmNRepeat,
Gemm2Params::GemmKPack,
true, // TransposeC
Gemm2Params::GemmKPack *
XdlopsGemm<GemmDataType, MPerXdl, NPerXdl, Gemm2Params::GemmKPack, false>{}
.K0PerXdlops,
Gemm2Params::GemmKPack>;
static constexpr auto c_block_slice_copy_step =
make_multi_index(-Gemm2Params::GemmMRepeat, 0, 0, 0, 0, 0, 0, 0);
template <typename CGradDesc_M_N>
__host__ __device__ static auto
MakeCGridDesc_M0_N0_M1_N1_M2_N2_N3_N4(const CGradDesc_M_N& c_grid_desc_m_n)
{
// HACK: for unmerge transform, the length of highest dim is irrelevant so we put dummy
// variable I1 there
const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(I1, Gemm2Params::GemmMWave, MPerXdl)),
make_unmerge_transform(make_tuple(I1, Gemm2Params::GemmNWave, NPerXdl))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}));
const auto c_grid_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
BlockwiseGemm{}.xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(
c_grid_desc_m0_n0_m1_n1_m2_n2);
return c_grid_desc_m0_n0_m1_n1_m2_n2_n3_n4;
}
static constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
BlockwiseGemm::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
__host__ __device__ static auto GetCThreadOriginOnBlock_M0_N0_M1_N1_M2_N2_N3_N4()
{
return to_multi_index(BlockwiseGemm::CalculateCThreadOriginDataIndex8D(I0, I0, I0, I0));
}
template <typename CGridDesc_M0_N0_M1_N1_M2_N2_N3_N4,
typename ElementwiseOp = tensor_operation::element_wise::PassThrough>
using CBlockwiseCopy = ThreadwiseTensorSliceTransfer_v1r3<
FloatGemmAcc,
OutputDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
CGridDesc_M0_N0_M1_N1_M2_N2_N3_N4,
ElementwiseOp, // CElementwiseOperation
decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLengths()), // SliceLengths
Sequence<0, 1, 2, 3, 4, 5, 6, 7>, // AccessOrder
7, // VectorDim
2, // ScalarPerVector
InMemoryDataOperationEnum::AtomicAdd, // GlobalMemoryDataOperation
1, // DstScalarStrideInVector
true>;
};
// S Gemm (type 4 rcc, B in LDS)
template <typename BSrcBlockDesc_K0_N_K1>
struct Gemm3
{
// A matrix in LDS memory, dst of blockwise copy
static constexpr auto a_block_desc_ak0_m_ak1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto
MakeGemm3AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
{
constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<MXdlPerWave, MWaves, MPerXdl>(
ABlockDesc_AK0_M_AK1{});
}
template <typename BBlockDesc_BK0_N_BK1>
__host__ __device__ static constexpr auto
MakeGemm3BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
{
constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<NXdlPerWave, NWaves, NPerXdl>(
BBlockDesc_BK0_N_BK1{});
}
template <typename GridDesc_K0_M_K1>
using ABlockwiseCopy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
AElementwiseOperation,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0, MPerBlock, AK1>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
InputDataType,
GemmDataType,
GridDesc_K0_M_K1,
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
true, // SrcResetCoord
true, // DstResetCoord
NumGemmKPrefetchStage>;
static constexpr index_t KPack = math::max(math::lcm(AK1, BK1), mfma.k_per_blk);
// Blockwise gemm with transposed XDL output
using BlockwiseGemm = BlockwiseGemmXdlops_v2<
BlockSize,
GemmDataType,
FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1),
BSrcBlockDesc_K0_N_K1,
decltype(MakeGemm3AMmaTileDescriptor_M0_M1_M2_K(a_block_desc_ak0_m_ak1)),
decltype(MakeGemm3BMmaTileDescriptor_N0_N1_N2_K(BSrcBlockDesc_K0_N_K1{})),
MPerBlock,
NPerBlock,
KPerBlock,
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
KPack>;
static constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
static constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, 0, KPerBlock);
static constexpr auto b_block_reset_copy_step = make_multi_index(0, 0, 0, -Gemm1NPerBlock);
};
template <index_t BlockSize_, index_t BlockSliceLength_M_, index_t BlockSliceLength_O_>
struct YDotYGrad_M_O_
{
static constexpr index_t SrcScalarPerVector = 16 / sizeof(InputDataType);
static constexpr auto ThreadClusterLength_O =
Number<BlockSliceLength_O_ / SrcScalarPerVector>{};
static constexpr auto ThreadClusterLength_M = Number<BlockSize_ / ThreadClusterLength_O>{};
static constexpr auto ThreadSliceLength_O = Number<SrcScalarPerVector>{};
static constexpr auto ThreadSliceLength_M =
Number<BlockSliceLength_M_ * ThreadClusterLength_O / BlockSize_>{};
static_assert(ThreadClusterLength_O * ThreadSliceLength_O == BlockSliceLength_O_, "");
static_assert(ThreadClusterLength_M * ThreadSliceLength_M == BlockSliceLength_M_, "");
using SrcBufType = StaticBuffer<AddressSpaceEnum::Vgpr,
FloatGemmAcc,
ThreadSliceLength_M * ThreadSliceLength_O,
true>;
using DstBufType =
StaticBuffer<AddressSpaceEnum::Vgpr, FloatGemmAcc, ThreadSliceLength_M, true>;
};
using YDotYGrad_M_O = YDotYGrad_M_O_<BlockSize, MPerBlock, Gemm1NPerBlock>;
// PGrad Gemm has the same layout as P = Q * K^T Gemm (A row-major B col-major)
struct PGradGemmTile_M_N_O
{
// TODO:
// Make all input tensors 2D and transform them into appropriate 3D form in kernel to make
// things more concise
template <typename YGradGridDesc_M0_O_M1_>
__device__ static auto
MakeYGradGridDesc_O0_M_O1(const YGradGridDesc_M0_O_M1_& ygrad_grid_desc_m0_o_m1)
{
const auto M0 = ygrad_grid_desc_m0_o_m1.GetLength(I0);
const auto O = ygrad_grid_desc_m0_o_m1.GetLength(I1);
const auto M1 = ygrad_grid_desc_m0_o_m1.GetLength(I2);
constexpr auto Y_O1 = AK1;
const auto Y_O0 = O / Y_O1;
const auto ygrad_grid_desc_o0_m_o1 = transform_tensor_descriptor(
ygrad_grid_desc_m0_o_m1,
make_tuple(make_unmerge_transform(make_tuple(Y_O0, Y_O1)),
make_merge_transform_v3_division_mod(make_tuple(M0, M1))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return ygrad_grid_desc_o0_m_o1;
}
};
// QGrad Gemm has the same layout as Y = P * V Gemm (A in acc B row-major)
struct QGradGemmTile_M_K_N
{
template <typename KGridDesc_K0_N_K1_>
__device__ static auto MakeKGridDesc_N0_K_N1(const KGridDesc_K0_N_K1_& k_grid_desc_k0_n_k1)
{
const auto K0_ = k_grid_desc_k0_n_k1.GetLength(I0);
const auto N_ = k_grid_desc_k0_n_k1.GetLength(I1);
const auto K1_ = k_grid_desc_k0_n_k1.GetLength(I2);
constexpr auto N1_ = B1K1;
const auto N0_ = N_ / N1_;
const auto k_grid_desc_n0_k_n1 = transform_tensor_descriptor(
k_grid_desc_k0_n_k1,
make_tuple(make_unmerge_transform(make_tuple(N0_, N1_)),
make_merge_transform_v3_division_mod(make_tuple(K0_, K1_))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return k_grid_desc_n0_k_n1;
}
};
struct KGradGemmTile_N_K_M
{
// B position
template <typename QGridDesc_K0_M_K1_>
__device__ static auto MakeQGridDesc_M0_K_M1(const QGridDesc_K0_M_K1_& q_grid_desc_k0_m_k1)
{
const auto Q_K0 = q_grid_desc_k0_m_k1.GetLength(I0);
const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
const auto Q_K1 = q_grid_desc_k0_m_k1.GetLength(I2);
constexpr auto Q_M1 = B1K1;
const auto Q_M0 = M / Q_M1;
const auto q_grid_desc_m0_k_m1 = transform_tensor_descriptor(
q_grid_desc_k0_m_k1,
make_tuple(make_unmerge_transform(make_tuple(Q_M0, Q_M1)),
make_merge_transform_v3_division_mod(make_tuple(Q_K0, Q_K1))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return q_grid_desc_m0_k_m1;
}
};
// D0
static constexpr auto D0M2 = Number<4>{};
static constexpr auto D0M1 = Number<MPerXdl>{} / D0M2;
static constexpr auto D0M0 = Number<MPerBlock>{} / Number<MPerXdl>{};
__host__ __device__ static constexpr auto
MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(const D0GridDesc_M_N& d0_grid_desc_m_n)
{
const auto M = d0_grid_desc_m_n.GetLength(I0);
const auto N = d0_grid_desc_m_n.GetLength(I1);
const auto MBlock = M / MPerBlock;
const auto NBlock = N / NPerBlock;
const auto d0_grid_desc_m0_n0_m1_m2_n1_m3 = transform_tensor_descriptor(
d0_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(MBlock, D0M0, D0M1, D0M2)),
make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 3, 5>{}, Sequence<1, 4>{}));
return d0_grid_desc_m0_n0_m1_m2_n1_m3;
}
using D0GridDescriptor_M0_N0_M1_M2_N1_M3 =
remove_cvref_t<decltype(MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(D0GridDesc_M_N{}))>;
struct D0Operator
{
template <typename DataType>
struct TypeTransform
{
using Type = DataType;
static constexpr index_t Size0 = sizeof(DataType);
static constexpr index_t Size = sizeof(DataType);
};
template <>
struct TypeTransform<void>
{
using Type = ck::half_t;
static constexpr index_t Size0 = 0;
static constexpr index_t Size = sizeof(ck::half_t);
};
static constexpr index_t NThreadClusterLengths = 32;
static_assert(NPerXdl == 32);
static_assert(D0BlockTransferSrcScalarPerVector * NThreadClusterLengths <= NPerBlock,
"D0BlockTransferSrcScalarPerVector * NThreadClusterLengths <= NPerBlock");
__host__ __device__ static constexpr auto GetD0BlockGlobalDescriptor_M0_N0_M1_M2_N1_M3()
{
return make_naive_tensor_descriptor_packed(
make_tuple(I1, I1, I1, D0M1, Number<NPerBlock>{}, D0M2));
}
__host__ __device__ static constexpr auto GetD0BlockVgprDescriptor_N0_N1_M0_M1_M2()
{
constexpr auto d0_raw_m0_n_m1 =
make_naive_tensor_descriptor_packed(make_tuple(D0M1, Number<NPerBlock>{}, D0M2));
constexpr auto d0_n0_n1_m0_m1_m2 = transform_tensor_descriptor(
d0_raw_m0_n_m1,
make_tuple(make_unmerge_transform(make_tuple(D0M1 / I2, I2)),
make_unmerge_transform(
make_tuple(Number<NPerBlock / NPerXdl>{}, Number<NPerXdl>{})),
make_pass_through_transform(D0M2)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<2, 3>{}, Sequence<0, 1>{}, Sequence<4>{}));
return d0_n0_n1_m0_m1_m2;
}
static constexpr auto d0_block_dst_desc_m0_n0_m1_m2_n1_m3 =
GetD0BlockGlobalDescriptor_M0_N0_M1_M2_N1_M3();
static constexpr auto d0_block_src_desc_n0_n1_m0_m1_m2 =
GetD0BlockVgprDescriptor_N0_N1_M0_M1_M2();
static constexpr auto d0_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I4, I1, D0M2));
static constexpr auto& d0grad_block_dst_desc_n0_n1_m0_m1_m2 =
d0_block_src_desc_n0_n1_m0_m1_m2;
static constexpr auto& d0grad_block_src_desc_m0_n0_m1_m2_n1_m3 =
d0_block_dst_desc_m0_n0_m1_m2_n1_m3;
using D0BlockwiseCopyGlobalToLds = ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock,
tensor_operation::element_wise::PassThrough,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<I1, I1, I1, D0M1, NPerBlock, D0M2>, // BlockSliceLengths
Sequence<1,
1,
1,
BlockSize / NThreadClusterLengths,
NThreadClusterLengths,
1>, // ThreadClusterLengths
Sequence<0, 1, 2, 3, 5, 4>, // ThreadClusterArrangeOrder
typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData
D0GridDescriptor_M0_N0_M1_M2_N1_M3, // SrcDesc
decltype(d0_block_dst_desc_m0_n0_m1_m2_n1_m3), // DstDesc
Sequence<0, 1, 2, 3, 5, 4>, // SrcDimAccessOrder
Sequence<0, 1, 2, 4, 3, 5>, // DstDimAccessOrder
4, // SrcVectorDim
5, // DstVectorDim
4, // SrcScalarPerVector
4, // DstScalarPerVector
1,
1,
true,
true, // DstResetCoord
1>;
using D0ThreadwiseCopyLdsToVgpr =
ThreadwiseTensorSliceTransfer_v4<typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData
decltype(d0_block_src_desc_n0_n1_m0_m1_m2), // SrcDesc
decltype(d0_thread_desc_), // DstDesc
Sequence<1, 1, 4, 1, 4>, // SliceLengths
Sequence<0, 1, 2, 3, 4>, // DimAccessOrder
4, // SrcVectorDim
4, // SrcScalarPerVector
2>;
using D0GradThreadwiseCopyVgprToLds = ThreadwiseTensorSliceTransfer_v1r3<
FloatGemmAcc,
typename TypeTransform<D0DataType>::Type,
decltype(d0_thread_desc_),
decltype(d0grad_block_dst_desc_n0_n1_m0_m1_m2),
tensor_operation::element_wise::Scale, // CElementwiseOperation
Sequence<1, 1, 4, 1, 4>, // SliceLengths
Sequence<0, 1, 2, 3, 4>, // AccessOrder
4, // VectorDim
4, // ScalarPerVector
InMemoryDataOperationEnum::Set, // GlobalMemoryDataOperation
1, // DstScalarStrideInVector
true>;
using D0GradBlockwiseCopyLdsToGlobal = ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock,
tensor_operation::element_wise::PassThrough,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<I1, I1, I1, D0M1, NPerBlock, D0M2>, // BlockSliceLengths
Sequence<1,
1,
1,
BlockSize / NThreadClusterLengths,
NThreadClusterLengths,
1>, // ThreadClusterLengths
Sequence<0, 1, 2, 3, 5, 4>, // ThreadClusterArrangeOrder
typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData
decltype(d0grad_block_src_desc_m0_n0_m1_m2_n1_m3), // SrcDesc
D0GridDescriptor_M0_N0_M1_M2_N1_M3, // DstDesc
Sequence<0, 1, 2, 4, 3, 5>, // SrcDimAccessOrder
Sequence<0, 1, 2, 3, 5, 4>, // DstDimAccessOrder
5, // SrcVectorDim
4, // DstVectorDim
4, // SrcScalarPerVector
D0BlockTransferSrcScalarPerVector, // DstScalarPerVector
1,
1,
true,
true, // DstResetCoord
1>;
};
struct SharedMemTrait
{
// LDS allocation for K
static constexpr auto k_block_desc_k0_n_k1 = GetKBlockDescriptor_K0PerBlock_NPerBlock_K1();
// LDS allocation for A and B: be careful of alignment
static constexpr auto a_block_desc_ak0_m_ak1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
static constexpr auto b1_block_desc_bk0_n_bk1 =
GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
static constexpr auto a2_block_desc_k0_m_k1 = GetA2BlockDescriptor_K0_M_K1<Gemm2Params>();
static constexpr auto max_lds_align = Number<16 / sizeof(GemmDataType)>{};
static constexpr auto k_block_space_size_aligned =
math::integer_least_multiple(k_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple(
b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
static constexpr auto a2_block_space_size_aligned = math::integer_least_multiple(
a2_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto k_block_space_offset = 0;
static constexpr auto a_block_space_offset = k_block_space_size_aligned.value;
static constexpr auto b1_block_space_offset = k_block_space_size_aligned.value;
static constexpr auto a2_block_space_offset = k_block_space_size_aligned.value;
// LDS allocation for reduction
static constexpr index_t reduction_space_size_aligned =
math::integer_least_multiple(BlockSize, max_lds_align);
static constexpr auto reduction_space_offset =
(math::max(a_block_space_size_aligned.value,
b1_block_space_size_aligned.value,
a2_block_space_size_aligned.value) +
k_block_space_size_aligned.value) *
sizeof(GemmDataType) / sizeof(FloatGemmAcc);
static constexpr auto d0_block_space_size_aligned = math::integer_least_multiple(
D0Operator::d0_block_dst_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize(), max_lds_align);
static constexpr auto d0_block_space_offset =
k_block_space_size_aligned.value * sizeof(GemmDataType) /
D0Operator::template TypeTransform<D0DataType>::Size;
// LDS allocation for C shuffle in LDS
static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
static constexpr auto c_block_space_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
};
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
const index_t gemm0_bytes_end = (SharedMemTrait::k_block_space_size_aligned +
SharedMemTrait::a_block_space_size_aligned) *
sizeof(GemmDataType);
const index_t gemm1_bytes_end = (SharedMemTrait::k_block_space_size_aligned +
SharedMemTrait::b1_block_space_size_aligned) *
sizeof(GemmDataType);
const index_t gemm2_bytes_end = (SharedMemTrait::k_block_space_size_aligned +
SharedMemTrait::a2_block_space_size_aligned) *
sizeof(GemmDataType);
const index_t gemm3_bytes_end = (SharedMemTrait::k_block_space_size_aligned +
SharedMemTrait::a_block_space_size_aligned) *
sizeof(GemmDataType);
const index_t softmax_bytes_end = (SharedMemTrait::reduction_space_offset +
SharedMemTrait::reduction_space_size_aligned) *
sizeof(FloatGemmAcc);
const index_t d0_bytes_end =
(SharedMemTrait::d0_block_space_offset + SharedMemTrait::d0_block_space_size_aligned) *
D0Operator::template TypeTransform<D0DataType>::Size0;
const index_t c_block_bytes_end =
SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
return math::max(gemm0_bytes_end,
gemm1_bytes_end,
gemm2_bytes_end,
gemm3_bytes_end,
softmax_bytes_end,
d0_bytes_end,
c_block_bytes_end);
}
template <bool HasMainKBlockLoop,
bool IsDropout,
typename Block2CTileMap,
typename C0MatrixMask,
typename YGradGridDesc_M0_O_M1>
__device__ static void
Run(const InputDataType* __restrict__ p_q_grid,
const InputDataType* __restrict__ p_k_grid,
const D0DataType* __restrict__ p_d0_grid,
ZDataType* __restrict__ p_z_grid,
const InputDataType* __restrict__ p_v_grid,
const InputDataType* __restrict__ p_y_grid,
const FloatLSE* __restrict__ p_lse_grid,
const InputDataType* __restrict__ p_ygrad_grid,
OutputDataType* __restrict__ p_qgrad_grid,
OutputDataType* __restrict__ p_kgrad_grid,
D0DataType* __restrict__ p_d0grad_grid,
OutputDataType* __restrict__ p_vgrad_grid,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const SElementwiseOperation& s_element_op,
const B1ElementwiseOperation& b1_element_op,
const CElementwiseOperation& c_element_op,
const QGridDesc_K0_M_K1& q_grid_desc_k0_m_k1,
const KGridDesc_K0_N_K1& k_grid_desc_k0_n_k1,
const KGridDesc_K0_N_K1& kgrad_grid_desc_k0_n_k1,
const D0GridDescriptor_M0_N0_M1_M2_N1_M3& d0_grid_desc_m0_n0_m1_m2_n1_m3,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3&
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
const VGridDesc_O0_N_O1& v_grid_desc_o0_n_o1,
const VGridDesc_O0_N_O1& vgrad_grid_desc_o0_n_o1,
const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock&
y_grid_desc_mblock_mperblock_oblock_operblock,
const LSEGridDesc_M& lse_grid_desc_m,
const YGradGridDesc_M0_O_M1& ygrad_grid_desc_m0_o_m1,
const Block2CTileMap& block_2_ctile_map,
const C0MatrixMask& c0_matrix_mask,
const float p_drop,
ck::philox& ph,
const index_t z_random_matrix_offset,
const index_t raw_n_padded,
const index_t block_idx_n)
{
const FloatGemmAcc p_dropout = type_convert<FloatGemmAcc>(1.0f - p_drop);
const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout);
const uint8_t p_dropout_in_uint8_t =
__builtin_amdgcn_readfirstlane(uint8_t(std::floor(p_dropout * 255.0)));
const tensor_operation::element_wise::Scale scale_rp_dropout(s_element_op.Value() *
rp_dropout);
const auto q_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_q_grid, q_grid_desc_k0_m_k1.GetElementSpaceSize());
const auto k_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_k_grid, k_grid_desc_k0_n_k1.GetElementSpaceSize());
const auto v_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_v_grid, v_grid_desc_o0_n_o1.GetElementSpaceSize());
const auto y_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_y_grid, y_grid_desc_mblock_mperblock_oblock_operblock.GetElementSpaceSize());
const auto lse_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_lse_grid, lse_grid_desc_m.GetElementSpaceSize());
const auto ygrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ygrad_grid, ygrad_grid_desc_m0_o_m1.GetElementSpaceSize());
auto vgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_vgrad_grid, vgrad_grid_desc_o0_n_o1.GetElementSpaceSize());
auto qgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_qgrad_grid, q_grid_desc_k0_m_k1.GetElementSpaceSize());
auto kgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_kgrad_grid, kgrad_grid_desc_k0_n_k1.GetElementSpaceSize());
// divide block work by [N, K]
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
const index_t block_work_idx_n = Deterministic ? block_idx_n : block_work_idx[I0];
// HACK: this force n_block_data_idx_on_grid into SGPR
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx_n * NPerBlock);
const index_t num_gemm0_m_block_outer_loop = q_grid_desc_k0_m_k1.GetLength(I1) / MPerBlock;
constexpr index_t num_gemm1_k_block_inner_loop = MPerBlock / Gemm1KPerBlock;
// 6 GEMM operations are categorized into 4 buckets. SizeK == SizeO == head_dim
// dP_MNO Gemm (Gemm0 rcc)
// dV_NOM / dK_NKM Gemm (Gemm1 rrr)
// Y_MON / dQ_MKN Gemm (Gemm2 crr)
// S_MNK Gemm (Gemm3 rcc)
// LDS allocation for K
auto k_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<GemmDataType*>(p_shared) + SharedMemTrait::k_block_space_offset,
GemmBlockwiseCopy::k_block_desc_k0_n_k1.GetElementSpaceSize());
// K matrix blockwise copy
auto gemm_tile_k_blockwise_copy =
typename GemmBlockwiseCopy::template KBlockwiseCopy<decltype(k_grid_desc_k0_n_k1)>(
k_grid_desc_k0_n_k1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op,
GemmBlockwiseCopy::k_block_desc_k0_n_k1,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
// Vgpr allocation for V
auto v_thread_buf = generate_tuple(
[&](auto i) {
ignore = i;
return StaticBuffer<
AddressSpaceEnum::Vgpr,
GemmDataType,
GemmBlockwiseCopy::v_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetElementSpaceSize(),
true>{};
},
Number<GemmBlockwiseCopy::VBlockBufferSize>{});
const auto v_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3 =
GemmBlockwiseCopy::MakeVGridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(v_grid_desc_o0_n_o1);
const auto wave_id = GetGemm0WaveIdx();
const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63
// V matrix blockwise copy
auto gemm_tile_v_blockwise_copy =
typename GemmBlockwiseCopy::template VBlockwiseCopy<decltype(
v_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3)>(
v_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_multi_index(
0, 0, wave_m_n_id[I0], block_work_idx_n, 0, wave_id[I1], wave_m_n_id[I1], 0));
//
// set up dP Gemm (type 1 rcc)
//
using Gemm0 = Gemm0<decltype(GemmBlockwiseCopy::v_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3)>;
// Gemm0: LDS allocation for A and B: be careful of alignment
auto gemm0_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<GemmDataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
Gemm0::a_block_desc_ak0_m_ak1.GetElementSpaceSize());
// dP: transform input tensor descriptors
const auto ygrad_grid_desc_o0_m_o1 =
PGradGemmTile_M_N_O::MakeYGradGridDesc_O0_M_O1(ygrad_grid_desc_m0_o_m1);
// dP: A matrix blockwise copy
auto pgrad_gemm_tile_ygrad_blockwise_copy =
typename Gemm0::template ABlockwiseCopy<decltype(ygrad_grid_desc_o0_m_o1)>(
ygrad_grid_desc_o0_m_o1,
make_multi_index(0,
MPerBlock * (num_gemm0_m_block_outer_loop - 1),
0), // will loop over GemmM dimension
tensor_operation::element_wise::PassThrough{},
Gemm0::a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
// dP: blockwise gemm
auto pgrad_blockwise_gemm = typename Gemm0::BlockwiseGemm{};
pgrad_blockwise_gemm.SetBBlockStartWindow(make_tuple(0, 0, 0, 0));
auto pgrad_thread_buf = pgrad_blockwise_gemm.GetCThreadBuffer();
const auto pgrad_gemm_tile_ygrad_block_reset_copy_step =
make_multi_index(-ygrad_grid_desc_o0_m_o1.GetLength(I0), -MPerBlock, 0);
constexpr index_t num_ok_block_main_loop = Gemm1NPerBlock / KPerBlock;
//
// set up S Gemm (type 4 rcc)
//
using Gemm3 = Gemm3<decltype(GemmBlockwiseCopy::k_block_desc_k0_n_k1)>;
// Gemm3: LDS allocation for A and B: be careful of alignment
auto gemm3_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<GemmDataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
Gemm3::a_block_desc_ak0_m_ak1.GetElementSpaceSize());
// S: A matrix blockwise copy
auto s_gemm_tile_q_blockwise_copy =
typename Gemm3::template ABlockwiseCopy<decltype(q_grid_desc_k0_m_k1)>(
q_grid_desc_k0_m_k1,
make_multi_index(0,
MPerBlock * (num_gemm0_m_block_outer_loop - 1),
0), // will loop over GemmM dimension
a_element_op,
Gemm3::a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
// S: blockwise gemm
auto s_blockwise_gemm = typename Gemm3::BlockwiseGemm{}; // TransposeC
auto s_slash_p_thread_buf = s_blockwise_gemm.GetCThreadBuffer();
const auto s_gemm_tile_q_block_reset_copy_step =
make_multi_index(-q_grid_desc_k0_m_k1.GetLength(I0), -MPerBlock, 0);
//
// set up dV / dK Gemm (type 2 rrr)
//
using Gemm1 =
Gemm1<decltype(s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()),
decltype(s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2())>;
// Gemm1: VGPR allocation for A and LDS allocation for B
auto gemm1_a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, GemmDataType>(
Gemm1::a_thread_desc_k0_m_k1.GetElementSpaceSize());
auto gemm1_b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<GemmDataType*>(p_shared) + SharedMemTrait::b1_block_space_offset,
Gemm1::b_block_desc_bk0_n_bk1.GetElementSpaceSize());
// dV: A matrix blockwise copy
auto vgrad_gemm_tile_p_blockwise_copy =
typename Gemm1::template ABlockwiseCopy<tensor_operation::element_wise::Relu>{
tensor_operation::element_wise::Relu{}}; // relu(P-dropped)
// dV: B matrix blockwise copy
auto vgrad_gemm_tile_ygrad_blockwise_copy =
typename Gemm1::template BBlockwiseCopy<decltype(ygrad_grid_desc_m0_o_m1)>(
ygrad_grid_desc_m0_o_m1,
make_multi_index(MPerBlock / B1K1 * (num_gemm0_m_block_outer_loop - 1), 0, 0),
b1_element_op,
Gemm1::b_block_desc_bk0_n_bk1, // there n actually is k, k is N, so name can be
// b_block_desc_bn0_k_bn1
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
const auto vgrad_gemm_tile_ygrad_block_next_copy_step =
make_multi_index(-2 * MPerBlock / B1K1, 0, 0);
// dV: blockwise gemm
auto vgrad_blockwise_gemm =
typename Gemm1::BlockwiseGemm{make_tuple(0, 0, 0, 0)}; // A_origin
auto vgrad_thread_buf = vgrad_blockwise_gemm.GetCThreadBuffer();
// dV: transform input and output tensor descriptors
auto vgrad_grid_desc_nblock_nperblock_oblock_operblock =
MakeVGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock(vgrad_grid_desc_o0_n_o1);
// dK: transform input and output tensor descriptors
const auto q_grid_desc_m0_k_m1 =
KGradGemmTile_N_K_M::MakeQGridDesc_M0_K_M1(q_grid_desc_k0_m_k1);
// dK: A matrix blockwise copy
auto kgrad_gemm_tile_sgrad_blockwise_copy =
typename Gemm1::template ABlockwiseCopy<tensor_operation::element_wise::PassThrough>{
tensor_operation::element_wise::PassThrough{}};
// dK: B matrix blockwise copy
auto kgrad_gemm_tile_q_blockwise_copy =
typename Gemm1::template BBlockwiseCopy<decltype(q_grid_desc_m0_k_m1)>(
q_grid_desc_m0_k_m1,
make_multi_index(MPerBlock / B1K1 * (num_gemm0_m_block_outer_loop - 1), 0, 0),
b1_element_op,
Gemm1::b_block_desc_bk0_n_bk1, // there n actually is k, k is N, so name can be
// b_block_desc_bn0_k_bn1
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
const auto kgrad_gemm_tile_q_block_next_copy_step =
make_multi_index(-2 * MPerBlock / B1K1, 0, 0);
// dK: blockwise gemm
auto kgrad_blockwise_gemm =
typename Gemm1::BlockwiseGemm{make_tuple(0, 0, 0, 0)}; // A_origin
auto kgrad_thread_buf = kgrad_blockwise_gemm.GetCThreadBuffer();
// dK: transform input and output tensor descriptors
auto kgrad_grid_desc_nblock_nperblock_oblock_operblock =
MakeKGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock(kgrad_grid_desc_k0_n_k1);
//
// set up dQ Gemm (type 3 crr)
//
using Gemm2 = Gemm2<Gemm2Params,
decltype(pgrad_blockwise_gemm),
decltype(GemmBlockwiseCopy::k_block_desc_k0_n_k1)>;
// Gemm2: LDS allocation for A and B: be careful of alignment
auto gemm2_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<GemmDataType*>(p_shared) + SharedMemTrait::a2_block_space_offset,
Gemm2::a_block_desc_k0_m_k1.GetElementSpaceSize());
auto gemm2_b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, GemmDataType>(
Gemm2::b_thread_desc_n0_n1_n2_k0_k1_k2_k3.GetElementSpaceSize());
// dQ: A matrix VGPR-to-LDS blockwise copy
auto qgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds =
typename Gemm2::template ABlockwiseCopy<tensor_operation::element_wise::PassThrough>{
Gemm2::a_block_desc_m0_k0_m1_k1_m2_m3_m4_k2,
Gemm2::MakeAThreadOriginOnBlock_M0_K0_M1_K1_M2_M3_M4_K2(),
tensor_operation::element_wise::PassThrough{}};
// dQ: blockwise gemm
auto qgrad_blockwise_gemm = typename Gemm2::BlockwiseGemm{};
qgrad_blockwise_gemm.SetBBlockStartWindow(make_tuple(0, 0, 0, 0));
auto k_thread_origin = qgrad_blockwise_gemm.CalculateBThreadOriginDataIndex();
// dQ: B matrix LDS-to-VGPR blockwise copy
auto qgrad_gemm_tile_k_blockwise_copy = typename Gemm2::BBlockwiseCopy{
Gemm2::b_block_desc_n0_n1_n2_k0_k1_k2_k3,
make_multi_index(0, // nrepeat
k_thread_origin[I1], // nwave
k_thread_origin[I2], // nperxdl
0, // k0
0, // k1
k_thread_origin[I3] / Gemm2Params::GemmKPack, // k2
0)}; // k3
auto qgrad_thread_buf = qgrad_blockwise_gemm.GetCThreadBuffer();
// dQ: transform output tensor descriptors
const auto qgrad_grid_desc_m_k = MakeQGradGridDesc_M_K(q_grid_desc_k0_m_k1);
const auto qgrad_grid_desc_m0_o0_m1_o1_m2_o2_o3_o4 =
Gemm2::MakeCGridDesc_M0_N0_M1_N1_M2_N2_N3_N4(qgrad_grid_desc_m_k);
// dQ: C VGPR-to-global copy
const auto qgrad_thread_origin_on_grid_m0_o0_m1_o1_m2_o2_o3_o4 =
Gemm2::GetCThreadOriginOnBlock_M0_N0_M1_N1_M2_N2_N3_N4() +
make_multi_index((num_gemm0_m_block_outer_loop - 1) * Gemm2Params::GemmMRepeat,
I0,
I0,
I0,
I0,
I0,
I0,
I0);
auto qgrad_thread_copy_vgpr_to_global = typename Gemm2::template CBlockwiseCopy<
decltype(qgrad_grid_desc_m0_o0_m1_o1_m2_o2_o3_o4),
decltype(scale_rp_dropout)>(qgrad_grid_desc_m0_o0_m1_o1_m2_o2_o3_o4,
qgrad_thread_origin_on_grid_m0_o0_m1_o1_m2_o2_o3_o4,
scale_rp_dropout);
//
// Blockwise softmax
//
// get acc0 8D thread cluster
constexpr auto thread_cluster_m0_n0_m1_n1_m2_m3_m4_n2 =
s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths() /
s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths();
constexpr auto tm0 = thread_cluster_m0_n0_m1_n1_m2_m3_m4_n2.At(I0);
constexpr auto tn0 = thread_cluster_m0_n0_m1_n1_m2_m3_m4_n2.At(I1);
constexpr auto tm1 = thread_cluster_m0_n0_m1_n1_m2_m3_m4_n2.At(I2);
constexpr auto tn1 = thread_cluster_m0_n0_m1_n1_m2_m3_m4_n2.At(I3);
constexpr auto tm2 = thread_cluster_m0_n0_m1_n1_m2_m3_m4_n2.At(I4);
constexpr auto tm3 = thread_cluster_m0_n0_m1_n1_m2_m3_m4_n2.At(I5);
constexpr auto tm4 = thread_cluster_m0_n0_m1_n1_m2_m3_m4_n2.At(I6);
constexpr auto tn2 = thread_cluster_m0_n0_m1_n1_m2_m3_m4_n2.At(I7);
// get acc0 thread map
constexpr auto n0_m_n1_to_m_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(tn0 * tn1, tn2)),
make_pass_through_transform(I1)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
constexpr auto threadid_to_n0_m_n1_adaptor = make_single_stage_tensor_adaptor(
make_tuple(
make_merge_transform(make_tuple(tn0 * tn1, tm0 * tm1 * tm2 * tm3 * tm4, tn2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto threadid_to_m_n_thread_cluster_adaptor =
chain_tensor_adaptors(n0_m_n1_to_m_n_adaptor, threadid_to_n0_m_n1_adaptor);
// get acc0 2D thread cluster & 2D thread slice
constexpr auto thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto m0 = thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0);
constexpr auto n0 = thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1);
constexpr auto m1 = thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2);
constexpr auto n1 = thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3);
constexpr auto m2 = thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4);
constexpr auto m3 = thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5);
constexpr auto m4 = thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6);
constexpr auto n2 = thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7);
constexpr auto thread_cluster_desc_m_n = make_naive_tensor_descriptor_packed(
make_tuple(tm0 * tm1 * tm2 * tm3 * tm4, tn0 * tn1 * tn2));
constexpr auto thread_slice_desc_m_n =
make_naive_tensor_descriptor_packed(make_tuple(m0 * m1 * m2 * m3 * m4, n0 * n1 * n2));
auto blockwise_softmax = BlockwiseSoftmax<BlockSize,
FloatGemmAcc,
decltype(threadid_to_m_n_thread_cluster_adaptor),
decltype(thread_cluster_desc_m_n),
decltype(thread_slice_desc_m_n)>{};
auto blockwise_dropout = BlockwiseDropout<FloatGemmAcc, decltype(thread_slice_desc_m_n)>{
p_dropout_in_uint8_t, rp_dropout};
auto lse_grid_desc_mb_m0_m1_m2_m3_m4 =
MakeLSEGridDescriptor_MB_M0_M1_M2_M3_M4<decltype(s_blockwise_gemm)>(lse_grid_desc_m);
constexpr auto lse_thread_desc_mb_m0_m1_m2_m3_m4 =
make_naive_tensor_descriptor_packed(make_tuple(I1, m0, m1, m2, m3, m4));
auto lse_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatLSE>(
lse_thread_desc_mb_m0_m1_m2_m3_m4.GetElementSpaceSize());
auto acc0_thread_origin = s_blockwise_gemm.CalculateCThreadOriginDataIndex8D(
Number<0>{}, Number<0>{}, Number<0>{}, Number<0>{});
auto lse_thread_copy_global_to_vgpr =
ThreadwiseTensorSliceTransfer_v2<FloatLSE,
FloatLSE,
decltype(lse_grid_desc_mb_m0_m1_m2_m3_m4),
decltype(lse_thread_desc_mb_m0_m1_m2_m3_m4),
Sequence<1, m0, m1, m2, m3, m4>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
1,
1,
true /* ResetCoordAfterRun */>{
lse_grid_desc_mb_m0_m1_m2_m3_m4,
make_multi_index(num_gemm0_m_block_outer_loop - 1, // mblock
acc0_thread_origin[I0], // mrepeat
acc0_thread_origin[I2], // mwave
acc0_thread_origin[I4], // mperxdl
acc0_thread_origin[I5],
acc0_thread_origin[I6])};
//
// z vgpr copy to global
//
// z matrix threadwise desc
constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3 =
make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId
I1, // NBlockID
m0, // MRepeat
n0, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MGroupNum
m3, // MInputNum
m4, // RegisterNum
n2)); // NPerXdl
StaticBuffer<AddressSpaceEnum::Vgpr,
uint8_t,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize(),
true>
z_tensor_buffer;
z_tensor_buffer.Clear();
auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize());
auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
uint8_t,
ZDataType,
decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
tensor_operation::element_wise::PassThrough,
Sequence<I1, // MBlockId
I1, // NBlockID
m0, // MRepeat
n0, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
m3, // NGroupNum
m4, // NInputNum
n2>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
9, // DstVectorDim
1, // DstScalarPerVector
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_multi_index(num_gemm0_m_block_outer_loop - 1, // MBlockId
block_work_idx_n, // NBlockId
0, // MRepeat
0, // NRepeat
wave_id[I0], // MWaveId
wave_id[I1], // NWaveId
0, // MPerXdl
wave_m_n_id[I0], //
0, //
wave_m_n_id[I1]), // NPerXdl
tensor_operation::element_wise::PassThrough{}};
//
// set up Y dot dY
//
// m0, n0 are m/n repeat per wave
// m1, n1 are number of waves
constexpr auto p_block_lengths =
s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths();
constexpr auto P_M0 = p_block_lengths[I0]; // repeats
constexpr auto P_M1 = p_block_lengths[I2]; // waves
constexpr auto P_M2 = p_block_lengths[I4]; // xdl
constexpr auto P_M3 = p_block_lengths[I5];
constexpr auto P_M4 = p_block_lengths[I6];
constexpr auto y_thread_desc_m0_m1_o0_o1 = make_naive_tensor_descriptor_packed(make_tuple(
I1, YDotYGrad_M_O::ThreadSliceLength_M, I1, YDotYGrad_M_O::ThreadSliceLength_O));
constexpr auto y_thread_cluster_desc =
make_cluster_descriptor(Sequence<I1,
YDotYGrad_M_O::ThreadClusterLength_M,
I1,
YDotYGrad_M_O::ThreadClusterLength_O>{},
Sequence<0, 1, 2, 3>{});
const auto y_thread_cluster_idx =
y_thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
const auto y_thread_data_on_block_idx =
y_thread_cluster_idx * y_thread_desc_m0_m1_o0_o1.GetLengths();
const auto y_thread_data_on_grid_idx =
make_multi_index(num_gemm0_m_block_outer_loop - 1, I0, I0, I0) +
y_thread_data_on_block_idx;
// performs double duty for both y and ygrad
auto yygrad_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
InputDataType,
FloatGemmAcc,
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
decltype(y_thread_desc_m0_m1_o0_o1),
decltype(y_thread_desc_m0_m1_o0_o1.GetLengths()),
Sequence<0, 1, 2, 3>,
3, // SrcVectorDim
YDotYGrad_M_O::SrcScalarPerVector, // SrcScalarPerVector
1, // SrcScalarStrideInVector
true /* ResetCoordAfterRun */,
false /* InvalidElementAsNaN */>(y_grid_desc_mblock_mperblock_oblock_operblock,
y_thread_data_on_grid_idx);
auto y_thread_buf = typename YDotYGrad_M_O::SrcBufType{};
auto ygrad_thread_buf = typename YDotYGrad_M_O::SrcBufType{};
auto y_dot_ygrad_thread_accum_buf = typename YDotYGrad_M_O::DstBufType{};
auto y_dot_ygrad_block_accum_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatGemmAcc*>(p_shared) + SharedMemTrait::reduction_space_offset,
MPerBlock);
constexpr auto y_dot_ygrad_block_desc_mb_m0_m1_m2_m3_m4 =
make_naive_tensor_descriptor_packed(make_tuple(I1, P_M0, P_M1, P_M2, P_M3, P_M4));
constexpr auto y_dot_ygrad_thread_desc_mb_m0_m1_m2_m3_m4 =
lse_thread_desc_mb_m0_m1_m2_m3_m4; // reuse LSE thread descriptor because
// per-thread LSE data and y_dot_ygrad is
// tiled the same way
auto y_dot_ygrad_thread_copy_lds_to_vgpr =
ThreadwiseTensorSliceTransfer_v2<FloatGemmAcc,
FloatGemmAcc,
decltype(y_dot_ygrad_block_desc_mb_m0_m1_m2_m3_m4),
decltype(y_dot_ygrad_thread_desc_mb_m0_m1_m2_m3_m4),
Sequence<1, m0, m1, m2, m3, m4>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
1,
1,
true /* ResetCoordAfterRun */>{
y_dot_ygrad_block_desc_mb_m0_m1_m2_m3_m4,
make_multi_index(I0, // mblock
acc0_thread_origin[I0], // mrepeat
acc0_thread_origin[I2], // mwave
acc0_thread_origin[I4], // mperxdl
acc0_thread_origin[I5],
acc0_thread_origin[I6])};
auto y_dot_ygrad_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatGemmAcc>(
y_dot_ygrad_thread_desc_mb_m0_m1_m2_m3_m4.GetElementSpaceSize());
// gemm0 M loop
index_t gemm0_m_block_outer_index = num_gemm0_m_block_outer_loop - 1;
// D0
auto d0_block_copy_global_to_lds = typename D0Operator::D0BlockwiseCopyGlobalToLds(
d0_grid_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{},
D0Operator::d0_block_dst_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(0, 0, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{});
auto d0_thread_copy_lds_to_vgpr = typename D0Operator::D0ThreadwiseCopyLdsToVgpr(
make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0));
auto& d0grad_grid_desc_m0_n0_m1_m2_n1_m3 = d0_grid_desc_m0_n0_m1_m2_n1_m3;
auto d0grad_thread_copy_vgpr_to_lds = typename D0Operator::D0GradThreadwiseCopyVgprToLds(
D0Operator::d0grad_block_dst_desc_n0_n1_m0_m1_m2,
make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0),
tensor_operation::element_wise::Scale{rp_dropout});
auto d0grad_block_copy_lds_to_global = typename D0Operator::D0GradBlockwiseCopyLdsToGlobal(
D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(0, 0, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{},
d0grad_grid_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{});
if constexpr(Deterministic)
{
block_sync_lds();
}
// Initialize dK&dV
kgrad_thread_buf.Clear();
vgrad_thread_buf.Clear();
// load k
gemm_tile_k_blockwise_copy.Run(k_grid_desc_k0_n_k1,
k_grid_buf,
GemmBlockwiseCopy::k_block_desc_k0_n_k1,
k_block_buf,
I0);
// load v
static_for<0, GemmBlockwiseCopy::VBlockBufferSize, 1>{}([&](auto ii) {
gemm_tile_v_blockwise_copy.Run(v_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
v_grid_buf,
GemmBlockwiseCopy::v_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
v_thread_buf(Number<ii>{}));
gemm_tile_v_blockwise_copy.MoveSrcSliceWindow(
v_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, GemmBlockwiseCopy::v_block_slice_copy_step);
});
do
{
auto m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(gemm0_m_block_outer_index * MPerBlock);
if(c0_matrix_mask.IsTileSkippable(
m_block_data_idx_on_grid, n_block_data_idx_on_grid, MPerBlock, NPerBlock))
{
continue;
}
//
// calculate Y dot dY
//
// clear accum buffers
y_dot_ygrad_thread_accum_buf.Clear();
y_dot_ygrad_block_accum_buf.Clear();
yygrad_threadwise_copy.Run(y_grid_desc_mblock_mperblock_oblock_operblock,
y_grid_buf,
y_thread_desc_m0_m1_o0_o1,
make_tuple(I0, I0, I0, I0),
y_thread_buf);
yygrad_threadwise_copy.Run(y_grid_desc_mblock_mperblock_oblock_operblock,
ygrad_grid_buf,
y_thread_desc_m0_m1_o0_o1,
make_tuple(I0, I0, I0, I0),
ygrad_thread_buf);
static_for<0, YDotYGrad_M_O::ThreadSliceLength_M, 1>{}([&](auto iM) {
static_for<0, YDotYGrad_M_O::ThreadSliceLength_O, 1>{}([&](auto iO) {
constexpr auto offset =
y_thread_desc_m0_m1_o0_o1.CalculateOffset(make_multi_index(I0, iM, I0, iO));
y_dot_ygrad_thread_accum_buf(iM) +=
y_thread_buf[Number<offset>{}] * ygrad_thread_buf[Number<offset>{}];
});
});
// blockwise reduction using atomic_add
block_sync_lds();
static_for<0, YDotYGrad_M_O::ThreadSliceLength_M, 1>{}([&](auto iM) {
const auto idx_on_block = y_thread_data_on_block_idx[I1] + iM;
y_dot_ygrad_block_accum_buf.AtomicAdd(idx_on_block,
true,
y_dot_ygrad_thread_accum_buf[iM] *
p_dropout); // p_dropoutD1
});
block_sync_lds();
// distribute y_dot_ygrad to threads;
// LDS accum buffer can be safely reused after barrier
y_dot_ygrad_thread_copy_lds_to_vgpr.Run(y_dot_ygrad_block_desc_mb_m0_m1_m2_m3_m4,
y_dot_ygrad_block_accum_buf,
y_dot_ygrad_thread_desc_mb_m0_m1_m2_m3_m4,
make_tuple(I0, I0, I0, I0, I0, I0),
y_dot_ygrad_thread_buf);
block_sync_lds();
lse_thread_copy_global_to_vgpr.Run(lse_grid_desc_mb_m0_m1_m2_m3_m4,
lse_grid_buf,
lse_thread_desc_mb_m0_m1_m2_m3_m4,
make_tuple(I0, I0, I0, I0, I0, I0),
lse_thread_buf);
// S = Q * K^T
{
// preload data into LDS
s_gemm_tile_q_blockwise_copy.RunRead(q_grid_desc_k0_m_k1, q_grid_buf);
s_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(q_grid_desc_k0_m_k1,
Gemm3::a_block_slice_copy_step);
block_sync_lds(); // wait for previous LDS read
s_slash_p_thread_buf.Clear();
s_gemm_tile_q_blockwise_copy.RunWrite(Gemm3::a_block_desc_ak0_m_ak1,
gemm3_a_block_buf);
// main body
if constexpr(HasMainKBlockLoop)
{
index_t i = 0;
do
{
s_gemm_tile_q_blockwise_copy.RunRead(q_grid_desc_k0_m_k1, q_grid_buf);
block_sync_lds();
s_blockwise_gemm.Run(gemm3_a_block_buf, k_block_buf, s_slash_p_thread_buf);
s_blockwise_gemm.MoveBBlockSrcSliceWindow(Gemm3::b_block_slice_copy_step);
block_sync_lds();
s_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(
q_grid_desc_k0_m_k1, Gemm3::a_block_slice_copy_step);
s_gemm_tile_q_blockwise_copy.RunWrite(Gemm3::a_block_desc_ak0_m_ak1,
gemm3_a_block_buf);
++i;
} while(i < (num_ok_block_main_loop - 1));
}
// tail
{
block_sync_lds();
s_blockwise_gemm.Run(gemm3_a_block_buf, k_block_buf, s_slash_p_thread_buf);
s_blockwise_gemm.MoveBBlockSrcSliceWindow(Gemm3::b_block_slice_copy_step);
}
} // end gemm S
// 8d thread_desc in thread scope
constexpr auto c_thread_lengths =
s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths();
// 8d block_desc in block scope
constexpr auto c_block_lengths =
s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths();
constexpr auto M0 = c_block_lengths[I0];
constexpr auto N0 = c_block_lengths[I1];
constexpr auto M1 = c_block_lengths[I2];
constexpr auto N1 = c_block_lengths[I3];
constexpr auto M2 = c_block_lengths[I4];
constexpr auto M3 = c_block_lengths[I5];
constexpr auto M4 = c_block_lengths[I6];
constexpr auto N2 = c_block_lengths[I7];
// works like multi-dimension static_for (static_ford), but provides both the linear
// index as well as n-d index
using Acc0TileIterator = SpaceFillingCurve<
decltype(c_thread_lengths),
typename arithmetic_sequence_gen<0, c_thread_lengths.Size(), 1>::type,
typename uniform_sequence_gen<c_thread_lengths.Size(), 1>::type,
false>; // SnakeCurved
constexpr auto block_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(M0, M1, M2, M3, M4)),
make_unmerge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{}));
// do MNK padding or upper triangular masking
if constexpr(MaskOutUpperTriangle || PadN)
{
static_for<0, Acc0TileIterator::GetNumOfAccess(), 1>{}([&](auto i) {
auto acc0_thread_idx = Acc0TileIterator::GetIndex(i) + acc0_thread_origin;
auto m_local =
block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto n_local =
block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
bool masked_flag = c0_matrix_mask.IsMaskedElement(m_global, n_global);
s_element_op(s_slash_p_thread_buf(i),
masked_flag ? -ck::NumericLimits<float>::Infinity()
: s_slash_p_thread_buf[i]);
});
}
else
{
static_for<0, s_slash_p_thread_buf.Size(), 1>{}([&](auto i) {
s_element_op(s_slash_p_thread_buf(i), s_slash_p_thread_buf[i]);
});
}
block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
// add bias
if constexpr(!is_same<D0DataType, void>::value)
{
if(p_d0_grid != nullptr)
{
static constexpr auto& c_thread_desc = s_blockwise_gemm.GetCThreadDesc();
const auto d0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d0_grid, d0_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
auto d0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
D0Operator::d0_block_dst_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>(
D0Operator::d0_thread_desc_.GetElementSpaceSize());
static_for<0, D0M0, 1>{}([&](auto mr) {
// load data to lds
d0_block_copy_global_to_lds.RunRead(d0_grid_desc_m0_n0_m1_m2_n1_m3,
d0_grid_buf);
d0_block_copy_global_to_lds.MoveSrcSliceWindow(
d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0));
d0_block_copy_global_to_lds.RunWrite(
D0Operator::d0_block_dst_desc_m0_n0_m1_m2_n1_m3, d0_block_buf);
block_sync_lds();
// read data form lds
d0_thread_copy_lds_to_vgpr.Run(D0Operator::d0_block_src_desc_n0_n1_m0_m1_m2,
make_tuple(I0, I0, I0, I0, I0),
d0_block_buf,
D0Operator::d0_thread_desc_,
make_tuple(I0, I0, I0, I0, I0),
d0_thread_buf);
// bias add
static_for<0, d0_thread_buf.Size(), 1>{}([&](auto i) {
constexpr index_t c_offset =
c_thread_desc.CalculateOffset(make_tuple(mr, I0, i));
s_slash_p_thread_buf(Number<c_offset>{}) +=
ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]);
});
});
d0_block_copy_global_to_lds.MoveSrcSliceWindow(
d0_grid_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(-1, 0, -D0M0.value, 0, 0, 0));
}
}
// P_i: = softmax(scalar * S_i:)
// scaling is already performed in the preceding statements with s_element_op
blockwise_softmax.RunWithPreCalcStats(s_slash_p_thread_buf, lse_thread_buf);
// save z to global
if constexpr(IsDropout)
{
if(p_z_grid)
{
auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
auto m_local =
block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto n_local =
block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
auto global_tile_id = z_random_matrix_offset +
(m_global / DropoutTile) * DropoutTile * raw_n_padded +
(n_global / DropoutTile) * DropoutTile;
auto global_elem_id = global_tile_id + (wave_m_n_id[I0] * M4) +
(n_global % DropoutTile) * raw_n_padded;
blockwise_dropout
.template ApplyDropoutAttnBwdSaveZ<decltype(s_slash_p_thread_buf),
decltype(z_tensor_buffer),
decltype(DropoutTile),
true>(s_slash_p_thread_buf,
ph,
global_elem_id,
z_tensor_buffer,
raw_n_padded);
z_thread_copy_vgpr_to_global.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tensor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
z_grid_buf);
}
else
{
ignore = z_grid_buf;
auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
auto m_local =
block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto n_local =
block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
auto global_tile_id = z_random_matrix_offset +
(m_global / DropoutTile) * DropoutTile * raw_n_padded +
(n_global / DropoutTile) * DropoutTile;
auto global_elem_id = global_tile_id + (wave_m_n_id[I0] * M4) +
(n_global % DropoutTile) * raw_n_padded;
// P_dropped
blockwise_dropout.template ApplyDropoutAttnBwd<decltype(s_slash_p_thread_buf),
decltype(DropoutTile),
true>(
s_slash_p_thread_buf, ph, global_elem_id, raw_n_padded);
}
}
block_sync_lds(); // wait for gemm1 LDS read
// gemm dV
// dV = P_drop^T * dY
{
// TODO: explore using dynamic buffer for a1 thread buffer
// For a1_blockwise_copy, the goal is to satisfy pipeline requirements RunRead(),
// RunWrite(), and MoveSliceWindow(). But it is impossible to implement given that
// the A1 source buffer is static buffer holding the output of first GEMM and
// requires constexpr offset by design. Therefore, we pass tensor coordinate offset
// explicitly in Run() below.
// preload data into LDS
vgrad_gemm_tile_ygrad_blockwise_copy.RunRead(ygrad_grid_desc_m0_o_m1,
ygrad_grid_buf);
vgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
ygrad_grid_desc_m0_o_m1, Gemm1::b_block_slice_copy_step);
block_sync_lds(); // wait for previous LDS read
vgrad_gemm_tile_ygrad_blockwise_copy.RunWrite(Gemm1::b_block_desc_bk0_n_bk1,
gemm1_b_block_buf);
// main body
if constexpr(num_gemm1_k_block_inner_loop > 1)
{
static_for<0, num_gemm1_k_block_inner_loop - 1, 1>{}([&](auto i) {
vgrad_gemm_tile_p_blockwise_copy.Run(Gemm1::a_src_thread_desc_k0_m_k1,
Gemm1::a_block_slice_copy_step * i,
s_slash_p_thread_buf,
Gemm1::a_thread_desc_k0_m_k1,
make_tuple(I0, I0, I0),
gemm1_a_thread_buf);
vgrad_gemm_tile_ygrad_blockwise_copy.RunRead(ygrad_grid_desc_m0_o_m1,
ygrad_grid_buf);
block_sync_lds();
vgrad_blockwise_gemm.Run(
gemm1_a_thread_buf, gemm1_b_block_buf, vgrad_thread_buf);
block_sync_lds();
vgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
ygrad_grid_desc_m0_o_m1, Gemm1::b_block_slice_copy_step);
vgrad_gemm_tile_ygrad_blockwise_copy.RunWrite(Gemm1::b_block_desc_bk0_n_bk1,
gemm1_b_block_buf);
});
}
// tail
{
vgrad_gemm_tile_p_blockwise_copy.Run(
Gemm1::a_src_thread_desc_k0_m_k1,
Gemm1::a_block_slice_copy_step * Number<num_gemm1_k_block_inner_loop - 1>{},
s_slash_p_thread_buf,
Gemm1::a_thread_desc_k0_m_k1,
make_tuple(I0, I0, I0),
gemm1_a_thread_buf);
block_sync_lds();
vgrad_blockwise_gemm.Run(
gemm1_a_thread_buf, gemm1_b_block_buf, vgrad_thread_buf);
}
} // end gemm dV
// gemm dP
block_sync_lds();
// dP = dY * V^T
// assume size K == size O so HasMainKBlockLoop is the same
{
// preload data into LDS
pgrad_gemm_tile_ygrad_blockwise_copy.RunRead(ygrad_grid_desc_o0_m_o1,
ygrad_grid_buf);
pgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
ygrad_grid_desc_o0_m_o1, Gemm0::a_block_slice_copy_step);
block_sync_lds(); // wait for previous LDS read
pgrad_thread_buf.Clear();
pgrad_gemm_tile_ygrad_blockwise_copy.RunWrite(Gemm0::a_block_desc_ak0_m_ak1,
gemm0_a_block_buf);
// main body
if constexpr(num_ok_block_main_loop > 1)
{
static_for<0, num_ok_block_main_loop - 1, 1>{}([&](auto i) {
pgrad_gemm_tile_ygrad_blockwise_copy.RunRead(ygrad_grid_desc_o0_m_o1,
ygrad_grid_buf);
block_sync_lds();
pgrad_blockwise_gemm.Run(
gemm0_a_block_buf, v_thread_buf(Number<i>{}), pgrad_thread_buf);
block_sync_lds();
pgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
ygrad_grid_desc_o0_m_o1, Gemm0::a_block_slice_copy_step);
pgrad_gemm_tile_ygrad_blockwise_copy.RunWrite(Gemm0::a_block_desc_ak0_m_ak1,
gemm0_a_block_buf);
});
}
// tail
{
block_sync_lds();
pgrad_blockwise_gemm.Run(gemm0_a_block_buf,
v_thread_buf(Number<num_ok_block_main_loop - 1>{}),
pgrad_thread_buf);
}
} // end gemm dP
// dS = P * (dP - Y_dot_dY)
auto& sgrad_thread_buf = pgrad_thread_buf;
constexpr auto pgrad_thread_tile_iterator =
pgrad_blockwise_gemm.MakeCThreadTileIterator();
constexpr auto pgrad_thread_idx_to_m_n_adaptor =
pgrad_blockwise_gemm.MakeCThreadIndexAdaptor8DTo2D();
static_for<0, pgrad_thread_tile_iterator.GetNumOfAccess(), 1>{}([&](auto i) {
constexpr auto pgrad_thread_idx = pgrad_thread_tile_iterator.GetIndex(i);
constexpr auto m =
pgrad_thread_idx_to_m_n_adaptor.CalculateBottomIndex(pgrad_thread_idx)[I0];
// dS and P has same thread buf layout
bool undropped_flag = s_slash_p_thread_buf[i] >= 0;
sgrad_thread_buf(i) =
s_slash_p_thread_buf[i] *
(undropped_flag ? (pgrad_thread_buf[i] - y_dot_ygrad_thread_buf[Number<m>{}])
: y_dot_ygrad_thread_buf[Number<m>{}]);
});
// output bias grad
if constexpr(!is_same<D0DataType, void>::value)
{
if(p_d0grad_grid != nullptr)
{
auto d0grad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d0grad_grid, d0grad_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
auto d0grad_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
static_for<0, D0M0, 1>{}([&](auto mr) {
d0grad_thread_copy_vgpr_to_lds.Run(
D0Operator::d0_thread_desc_,
make_tuple(mr, I0, I0, I0, I0),
sgrad_thread_buf,
D0Operator::d0grad_block_dst_desc_n0_n1_m0_m1_m2,
d0grad_block_buf);
block_sync_lds();
// write data from lds to global
d0grad_block_copy_lds_to_global.Run(
D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3,
d0grad_block_buf,
d0grad_grid_desc_m0_n0_m1_m2_n1_m3,
d0grad_grid_buf,
I0);
d0grad_block_copy_lds_to_global.MoveDstSliceWindow(
d0grad_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0));
});
d0grad_block_copy_lds_to_global.MoveDstSliceWindow(
d0grad_grid_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(-1, 0, -D0M0.value, 0, 0, 0));
}
}
SubThreadBlock<BlockSize> gemm2_a_copy_subgroup(s_blockwise_gemm.GetWaveIdx()[I0],
s_blockwise_gemm.GetWaveIdx()[I1]);
constexpr index_t num_gemm2_loop = NPerBlock / Gemm2Params::Sum_K;
static_assert(Gemm2::ASrcBlockSliceWindowIterator::GetNumOfAccess() == num_gemm2_loop,
"");
// TODO: tune gemm2 pipeline
// gemm dQ
// dQ = scalar * dS * K
qgrad_thread_buf.Clear();
static_for<0, num_gemm2_loop, 1>{}([&](auto gemm2_loop_idx) { // gemm dQ
// load QGrad Gemm A
const auto sgrad_slice_idx =
Gemm2::ASrcBlockSliceWindowIterator::GetIndexTupleOfNumber(gemm2_loop_idx);
constexpr auto mwave_range = make_tuple(
sgrad_slice_idx[I2],
sgrad_slice_idx[I2] + Gemm2Params::ABlockSliceLengths_M0_K0_M1_K1::At(I2));
constexpr auto nwave_range = make_tuple(
sgrad_slice_idx[I3],
sgrad_slice_idx[I3] + Gemm2Params::ABlockSliceLengths_M0_K0_M1_K1::At(I3));
block_sync_lds(); // sync before write
if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range))
{
qgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds.Run(
Gemm2::a_src_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_tuple(
sgrad_slice_idx[I0], sgrad_slice_idx[I1], I0, I0, I0, I0, I0, I0),
sgrad_thread_buf,
Gemm2::a_block_desc_m0_k0_m1_k1_m2_m3_m4_k2,
gemm2_a_block_buf);
}
qgrad_gemm_tile_k_blockwise_copy.Run(Gemm2::b_block_desc_n0_n1_n2_k0_k1_k2_k3,
k_block_buf,
Gemm2::b_thread_desc_n0_n1_n2_k0_k1_k2_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0),
gemm2_b_thread_buf);
qgrad_gemm_tile_k_blockwise_copy.MoveSrcSliceWindow(
Gemm2::b_block_desc_n0_n1_n2_k0_k1_k2_k3, Gemm2::b_block_slice_copy_step);
block_sync_lds(); // sync before read
qgrad_blockwise_gemm.Run(gemm2_a_block_buf, gemm2_b_thread_buf, qgrad_thread_buf);
}); // end gemm dQ
// atomic_add dQ
qgrad_thread_copy_vgpr_to_global.Run(Gemm2::c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
qgrad_thread_buf,
qgrad_grid_desc_m0_o0_m1_o1_m2_o2_o3_o4,
qgrad_grid_buf);
// gemm dK
// dK = scalar * dS^T * Q
{
// TODO: explore using dynamic buffer for a1 thread buffer
// For a1_blockwise_copy, the goal is to satisfy pipeline requirements RunRead(),
// RunWrite(), and MoveSliceWindow(). But it is impossible to implement given that
// the A1 source buffer is static buffer holding the output of first GEMM and
// requires constexpr offset by design. Therefore, we pass tensor coordinate offset
// explicitly in Run() below.
// preload data into LDS
kgrad_gemm_tile_q_blockwise_copy.RunRead(q_grid_desc_m0_k_m1, q_grid_buf);
kgrad_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(q_grid_desc_m0_k_m1,
Gemm1::b_block_slice_copy_step);
block_sync_lds(); // wait for previous LDS read
kgrad_gemm_tile_q_blockwise_copy.RunWrite(Gemm1::b_block_desc_bk0_n_bk1,
gemm1_b_block_buf);
// main body
if constexpr(num_gemm1_k_block_inner_loop > 1)
{
static_for<0, num_gemm1_k_block_inner_loop - 1, 1>{}([&](auto i) {
kgrad_gemm_tile_sgrad_blockwise_copy.Run(Gemm1::a_src_thread_desc_k0_m_k1,
Gemm1::a_block_slice_copy_step * i,
sgrad_thread_buf,
Gemm1::a_thread_desc_k0_m_k1,
make_tuple(I0, I0, I0),
gemm1_a_thread_buf);
kgrad_gemm_tile_q_blockwise_copy.RunRead(q_grid_desc_m0_k_m1, q_grid_buf);
block_sync_lds();
kgrad_blockwise_gemm.Run(
gemm1_a_thread_buf, gemm1_b_block_buf, kgrad_thread_buf);
block_sync_lds();
kgrad_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(
q_grid_desc_m0_k_m1, Gemm1::b_block_slice_copy_step);
kgrad_gemm_tile_q_blockwise_copy.RunWrite(Gemm1::b_block_desc_bk0_n_bk1,
gemm1_b_block_buf);
});
}
// tail
{
kgrad_gemm_tile_sgrad_blockwise_copy.Run(
Gemm1::a_src_thread_desc_k0_m_k1,
Gemm1::a_block_slice_copy_step * Number<num_gemm1_k_block_inner_loop - 1>{},
sgrad_thread_buf,
Gemm1::a_thread_desc_k0_m_k1,
make_tuple(I0, I0, I0),
gemm1_a_thread_buf);
block_sync_lds();
kgrad_blockwise_gemm.Run(
gemm1_a_thread_buf, gemm1_b_block_buf, kgrad_thread_buf);
}
} // end gemm dK
// move slice window
s_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(
q_grid_desc_k0_m_k1,
s_gemm_tile_q_block_reset_copy_step); // rewind K and step M
pgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
ygrad_grid_desc_o0_m_o1,
pgrad_gemm_tile_ygrad_block_reset_copy_step); // rewind O and step M
qgrad_gemm_tile_k_blockwise_copy.MoveSrcSliceWindow(
Gemm2::b_block_desc_n0_n1_n2_k0_k1_k2_k3,
Gemm2::b_block_reset_copy_step); // rewind N
kgrad_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(
q_grid_desc_m0_k_m1, kgrad_gemm_tile_q_block_next_copy_step); // step M
vgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
ygrad_grid_desc_m0_o_m1, vgrad_gemm_tile_ygrad_block_next_copy_step); // step M
qgrad_thread_copy_vgpr_to_global.MoveDstSliceWindow(
qgrad_grid_desc_m0_o0_m1_o1_m2_o2_o3_o4, Gemm2::c_block_slice_copy_step); // step M
lse_thread_copy_global_to_vgpr.MoveSrcSliceWindow(lse_grid_desc_mb_m0_m1_m2_m3_m4,
make_multi_index(-1, 0, 0, 0, 0, 0));
yygrad_threadwise_copy.MoveSrcSliceWindow(y_grid_desc_mblock_mperblock_oblock_operblock,
make_multi_index(-1, 0, 0, 0));
s_blockwise_gemm.MoveBBlockSrcSliceWindow(Gemm3::b_block_reset_copy_step);
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_multi_index(-1, 0, 0, 0, 0, 0, 0, 0, 0, 0));
} while(0 < gemm0_m_block_outer_index--); // end j loop
// shuffle dK&dV and write
{
static_assert(NXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
Gemm1NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
"wrong!");
constexpr index_t MWave = Gemm0NWaves;
constexpr index_t NWave = Gemm0MWaves;
// TODO: hacky, fix it!
// thread desc same with kgrad_blockwise_gemm
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
vgrad_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
// TODO: hacky, fix it!
// c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp is only used to get lengths
// block desc same with kgrad_blockwise_gemm
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp =
vgrad_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0);
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1);
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2);
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3);
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5);
constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6);
constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7);
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatCShuffle*>(p_shared),
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_tuple(
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
M1, // M1 = MWave
M2)), // M2 = MPerXdl
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
N1, // N1 = NWave
N2, // N2 * N3 * N4 = NPerXdl
N3,
N4))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(
Sequence<>{}, Sequence<0, 2, 4>{}, Sequence<>{}, Sequence<1, 3, 5, 6, 7>{}));
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
// index same with kgrad_blockwise_gemm
const auto c_thread_mtx_on_block =
vgrad_blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
const auto m_thread_data_on_block_to_m0_m1_m2_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_block_idx =
m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_block));
const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3, N4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_block_idx =
n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));
// shuffle: threadwise copy C from VGPR to LDS
auto vgrad_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<FloatGemmAcc,
FloatCShuffle,
decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
decltype(c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
SElementwiseOperation,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
I1,
I1,
N2,
I1,
N4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>{
c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
n_thread_data_on_block_idx[I2],
n_thread_data_on_block_idx[I3],
n_thread_data_on_block_idx[I4]),
tensor_operation::element_wise::Scale{rp_dropout}};
// shuffle: blockwise copy C from LDS to global
auto vgrad_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
ThisThreadBlock, // ThreadGroup
CElementwiseOperation, // ElementwiseOperation,
CGlobalMemoryDataOperation, // DstInMemOp,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
FloatCShuffle, // typename SrcData,
OutputDataType, // typename DstData,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
decltype(vgrad_grid_desc_nblock_nperblock_oblock_operblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
true, // bool ThreadTransferSrcResetCoordinateAfterRun,
false> // bool ThreadTransferDstResetCoordinateAfterRun>
{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(0, 0, 0, 0),
vgrad_grid_desc_nblock_nperblock_oblock_operblock,
make_multi_index(block_work_idx_n, 0, block_work_idx[I1], 0),
c_element_op};
// shuffle: threadwise copy C from VGPR to LDS
auto kgrad_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<FloatGemmAcc,
FloatCShuffle,
decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
decltype(c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
SElementwiseOperation,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
I1,
I1,
N2,
I1,
N4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>{
c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
n_thread_data_on_block_idx[I2],
n_thread_data_on_block_idx[I3],
n_thread_data_on_block_idx[I4]),
scale_rp_dropout};
// shuffle: blockwise copy C from LDS to global
auto kgrad_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
ThisThreadBlock, // ThreadGroup
CElementwiseOperation, // ElementwiseOperation,
CGlobalMemoryDataOperation, // DstInMemOp,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
FloatCShuffle, // typename SrcData,
OutputDataType, // typename DstData,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
decltype(kgrad_grid_desc_nblock_nperblock_oblock_operblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
true, // bool ThreadTransferSrcResetCoordinateAfterRun,
false> // bool ThreadTransferDstResetCoordinateAfterRun>
{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(0, 0, 0, 0),
kgrad_grid_desc_nblock_nperblock_oblock_operblock,
make_multi_index(block_work_idx_n, 0, block_work_idx[I1], 0),
c_element_op};
// space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<NXdlPerWave, Gemm1NXdlPerWave, 1, 1, 1, N2, 1, N4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
1,
1,
1,
N2,
1,
N4>>{};
// space filling curve for shuffled blockwise C in global mem
constexpr auto sfc_c_global =
SpaceFillingCurve<Sequence<1, NPerBlock, 1, Gemm1NPerBlock>,
Sequence<0, 2, 1, 3>,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
// dK
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
block_sync_lds();
// each thread write its data from VGPR to LDS
kgrad_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
kgrad_thread_buf,
c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
c_shuffle_block_buf);
// make sure it's safe to read from LDS
block_sync_lds();
// each block copy its data from LDS to global
kgrad_shuffle_block_copy_lds_to_global.Run(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_shuffle_block_buf,
kgrad_grid_desc_nblock_nperblock_oblock_operblock,
kgrad_grid_buf);
if constexpr(access_id < num_access - 1)
{
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
// move on C
kgrad_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
kgrad_grid_desc_nblock_nperblock_oblock_operblock, c_global_step);
}
});
// dV
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
block_sync_lds();
// each thread write its data from VGPR to LDS
vgrad_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
vgrad_thread_buf,
c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
c_shuffle_block_buf);
// make sure it's safe to read from LDS
block_sync_lds();
// each block copy its data from LDS to global
vgrad_shuffle_block_copy_lds_to_global.Run(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_shuffle_block_buf,
vgrad_grid_desc_nblock_nperblock_oblock_operblock,
vgrad_grid_buf);
if constexpr(access_id < num_access - 1)
{
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
// move on C
vgrad_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
vgrad_grid_desc_nblock_nperblock_oblock_operblock, c_global_step);
}
});
}
}
};
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment