"vscode:/vscode.git/clone" did not exist on "239716a70405b424d33f411cd1948b99e127092f"
Commit 8551dd43 authored by Anthony Chang

start with dY

parent ecd5f7c9
......@@ -40,7 +40,7 @@ using DeviceGemmInstance1 = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffl
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// clang-format on
using DeviceGemmInstance = DeviceGemmInstance0;
using DeviceGemmInstance = DeviceGemmInstance1;
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
......
......@@ -44,9 +44,14 @@ using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Scale = ck::tensor_operation::element_wise::Scale;
using DataType = F16;
using AccDataType = F32;
using ShuffleDataType = F32;
using QKVElementOp = PassThrough;
using YElementOp = PassThrough;
using DataType = F16;
using AccDataType = F32;
using ShuffleDataType = F32;
using Acc0BiasDataType = ck::Tuple<>;
using Acc1BiasDataType = ck::Tuple<>;
static constexpr ck::index_t NumDimG = 2;
static constexpr ck::index_t NumDimM = 1;
......@@ -54,7 +59,6 @@ static constexpr ck::index_t NumDimN = 1;
static constexpr ck::index_t NumDimK = 1;
static constexpr ck::index_t NumDimO = 1;
#if 0
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
static constexpr auto MaskingSpec =
ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
......@@ -63,7 +67,70 @@ static constexpr auto TensorSpecQ = ck::tensor_operation::device::TensorSpeciali
static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default;
#endif
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<
NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
DataType,
Acc0BiasDataType,
Acc1BiasDataType,
AccDataType,
ShuffleDataType,
QKVElementOp,
QKVElementOp,
Scale,
QKVElementOp,
YElementOp,
GemmSpec,
TensorSpecQ,
TensorSpecK,
TensorSpecV,
TensorSpecY,
1,
256,
128, // MPerBlock
128, // NPerBlock
32, // KPerBlock
64, // Gemm1NPerBlock
32, // Gemm1KPerBlock
8, // AK1
8, // BK1
2, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
2, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>, // BBlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<16, 16, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
1,
4,
2,
false,
1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
// Ref Gemm0: S = alpha * Q * K^T
// fp16 in, fp32 out
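For reference, a minimal host-side sketch of this first gemm, using plain float buffers for brevity (the real example routes fp16 inputs with fp32 accumulation through ck::tensor_operation::host::ReferenceGemm; the function and argument names below are illustrative only):

#include <vector>

// S = alpha * Q * K^T for a single head; Q is [M, K], K is [N, K], S is [M, N], all row-major.
void ref_gemm0(const std::vector<float>& Q,
               const std::vector<float>& K,
               std::vector<float>& S,
               int M, int N, int Kdim, float alpha)
{
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(int k = 0; k < Kdim; ++k)
                acc += Q[m * Kdim + k] * K[n * Kdim + k]; // K^T: both operands stored [rows, K]
            S[m * N + n] = alpha * acc;
        }
}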
......@@ -306,6 +373,7 @@ int run(int argc, char* argv[])
DeviceMem vgrad_device_buf(sizeof(DataType) * v_gs_os_ns.mDesc.GetElementSpaceSize());
DeviceMem ygrad_device_buf(sizeof(DataType) * y_gs_ms_os.mDesc.GetElementSpaceSize());
// TODO ANT: make sure K/V gradients are zeroed
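// (hedged sketch, illustrating the TODO above) The dV store path in the gridwise kernel uses
// InMemoryDataOperationEnum::AtomicAdd, so the K/V gradient buffers would need to be zero-filled
// before launch, e.g.:
//   hipMemset(kgrad_device_buf.GetDeviceBuffer(), 0,
//             sizeof(DataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize());
//   hipMemset(vgrad_device_buf.GetDeviceBuffer(), 0,
//             sizeof(DataType) * v_gs_os_ns.mDesc.GetElementSpaceSize());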
q_device_buf.ToDevice(q_gs_ms_ks.mData.data());
k_device_buf.ToDevice(k_gs_ns_ks.mData.data());
v_device_buf.ToDevice(v_gs_os_ns.mData.data());
......@@ -313,14 +381,17 @@ int run(int argc, char* argv[])
ygrad_device_buf.ToDevice(y_gs_ms_os.mData.data());
// TODO ANT: attention backward kernel
#if 0
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(
static_cast<DataType*>(q_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(k_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(v_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(y_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(ygrad_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(qgrad_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(kgrad_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(vgrad_device_buf.GetDeviceBuffer()),
{}, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases;
q_gs_ms_ks_lengths,
......@@ -335,11 +406,11 @@ int run(int argc, char* argv[])
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
q_element_op,
k_element_op,
s_element_op,
v_element_op,
y_element_op);
QKVElementOp{},
QKVElementOp{},
Scale{alpha},
QKVElementOp{},
YElementOp{});
if(!gemm.IsSupportedArgument(argument))
{
......@@ -361,7 +432,6 @@ int run(int argc, char* argv[])
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
#endif
bool pass = true;
if(do_verification)
......
......@@ -185,6 +185,21 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
"wrong!");
}
// transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
__host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
{
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
return make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, N, M0, M1, M2));
}
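The reordering of the trailing indices above follows from the transpose identity

    (A_{xdl} B_{xdl})^{T} = B_{xdl}^{T} A_{xdl}^{T}

so the per-thread accumulator keeps its (MRepeat, NRepeat, wave, wave) leading dimensions while the per-xdlops split (M0, M1, M2, N) is read back in the order (N, M0, M1, M2), letting the same registers serve as the transposed tile.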
// XDL output supporting C_xdl = A_xdl * B_xdl
__host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
......
......@@ -23,8 +23,7 @@ namespace tensor_operation {
namespace device {
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename DataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename AccElementwiseOperation,
......@@ -34,6 +33,8 @@ template <typename GridwiseGemm,
typename BGridDesc_BK0_N_BK1,
typename B1GridDesc_BK0_N_BK1,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename VGradGridDescriptor_N_O,
typename YGradGridDesc_M0_O_M1,
typename Block2CTileMap,
typename ComputeBasePtrOfStridedBatch,
typename C0MatrixMask,
......@@ -43,10 +44,14 @@ __global__ void
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
const FloatAB* __restrict__ p_b1_grid,
FloatC* __restrict__ p_c_grid,
const DataType* __restrict__ p_a_grid,
const DataType* __restrict__ p_b_grid,
const DataType* __restrict__ p_b1_grid,
const DataType* __restrict__ p_c_grid,
const DataType* __restrict__ p_ygrad_grid,
DataType* __restrict__ p_qgrad_grid,
DataType* __restrict__ p_kgrad_grid,
DataType* __restrict__ p_vgrad_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const AccElementwiseOperation acc_element_op,
......@@ -57,6 +62,10 @@ __global__ void
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
// const QGradGridDescriptor_M_K qgrad_grid_desc_m_k, // TODO ANT: add dQ/dK args
// const KGradGridDescriptor_N_K kgrad_grid_desc_n_k,
const VGradGridDescriptor_N_O vgrad_grid_desc_n_o,
const YGradGridDesc_M0_O_M1 ygrad_grid_desc_m0_o_m1,
const Block2CTileMap block_2_ctile_map,
const index_t batch_count,
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
......@@ -68,6 +77,8 @@ __global__ void
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
// NOTE: assumes QKVY has the same layout as dQ/dK/dV/dY therefore being able to reuse batch
// offsets
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetABasePtr(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
......@@ -81,6 +92,10 @@ __global__ void
p_b_grid + b_batch_offset,
p_b1_grid + b1_batch_offset,
p_c_grid + c_batch_offset,
p_ygrad_grid + c_batch_offset,
p_qgrad_grid + a_batch_offset,
p_kgrad_grid + b_batch_offset,
p_vgrad_grid + b1_batch_offset,
p_shared,
a_element_op,
b_element_op,
......@@ -91,6 +106,8 @@ __global__ void
b_grid_desc_bk0_n_bk1,
b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
vgrad_grid_desc_n_o,
ygrad_grid_desc_m0_o_m1,
block_2_ctile_map,
c0_matrix_mask);
#else
......@@ -122,10 +139,7 @@ template <index_t NumDimG,
index_t NumDimN,
index_t NumDimK,
index_t NumDimO, // NumDimGemm1N
typename ADataType,
typename BDataType,
typename B1DataType,
typename CDataType,
typename DataType,
typename Acc0BiasDataType,
typename Acc1BiasDataType,
typename GemmAccDataType,
......@@ -183,15 +197,16 @@ template <index_t NumDimG,
MaskingSpecialization MaskingSpec,
LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
#if 0
: public DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
ADataType,
BDataType,
B1DataType,
CDataType,
DataType,
DataType,
DataType,
DataType,
Acc0BiasDataType,
Acc1BiasDataType,
AElementwiseOperation,
......@@ -200,6 +215,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
B1ElementwiseOperation,
CElementwiseOperation,
MaskingSpec>
#endif
{
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
"Number of dimension must be greater than 0");
......@@ -226,6 +242,22 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr index_t Q_K1 = 8;
static constexpr index_t K_K1 = 8;
static constexpr index_t V_N1 = 2;
static constexpr index_t Q_M1 = 2;
static constexpr index_t K_N1 = 2;
static constexpr index_t V_O1 = 8;
static constexpr index_t Y_O1 = 8;
static constexpr index_t Y_M1 = 2;
static constexpr auto padder = GemmGemmPadder<GemmSpec,
Number<MPerBlock>,
Number<NPerBlock>,
Number<KPerBlock>,
Number<Gemm1NPerBlock>>{};
using Transform = TransformBatchedContractionContractionToBatchedGemmGemm<
Sequence<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO>,
Sequence<MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock>,
......@@ -235,6 +267,18 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
B1Spec,
CSpec>;
/*
Descriptors for inputs:
Q, K, V, Y, dY, per-row softmax stats
Descriptors for outputs:
dQ, dK, dV
*/
// Q in Gemm A position
static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
const std::vector<index_t>& a_gs_ms_ks_strides_vec)
{
......@@ -243,6 +287,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
Number<AK1>{});
}
// K in Gemm B0 position
static auto MakeBGridDescriptor_BK0_N_BK1(const std::vector<index_t>& b_gs_ns_ks_lengths_vec,
const std::vector<index_t>& b_gs_ns_ks_strides_vec)
{
......@@ -251,6 +296,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
Number<BK1>{});
}
// V in Gemm B1 position
static auto
MakeB1GridDescriptor_BK0_N_BK1(const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths_vec,
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides_vec)
......@@ -261,6 +307,114 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
Number<B1K1>{});
}
//
// dV = P^T * dY
//
// VGrad in Gemm C position
static auto MakeVGradGridDescriptor_N_O(const std::vector<index_t>& v_gs_os_ns_lengths_vec,
const std::vector<index_t>& v_gs_os_ns_strides_vec)
{
// v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major.
// Here directly rearrange lengths/strides before constructing tensor descriptor to reduce
// transformation overhead
const index_t num_dims = NumDimG + NumDimN + NumDimO;
// 0, 1, .. NumDimG - 1
std::vector<index_t> gs_ids(NumDimG);
std::iota(gs_ids.begin(), gs_ids.end(), 0);
// NumDimG, NumDimG + 1, ... NumDimG + NumDimO - 1
std::vector<index_t> os_ids(NumDimO);
std::iota(os_ids.begin(), os_ids.end(), NumDimG);
// NumDimG + NumDimO, NumDimG + NumDimO + 1, ... NumDimG + NumDimO + NumDimN - 1
std::vector<index_t> ns_ids(NumDimN);
std::iota(ns_ids.begin(), ns_ids.end(), NumDimG + NumDimO);
std::vector<index_t> ids_old2new;
ids_old2new.insert(ids_old2new.end(), gs_ids.begin(), gs_ids.end());
ids_old2new.insert(ids_old2new.end(), ns_ids.begin(), ns_ids.end());
ids_old2new.insert(ids_old2new.end(), os_ids.begin(), os_ids.end());
std::vector<index_t> v_gs_ns_os_lengths_vec(num_dims), v_gs_ns_os_strides_vec(num_dims);
for(int i = 0; i < num_dims; i++)
{
index_t id_new = ids_old2new[i];
v_gs_ns_os_lengths_vec[i] = v_gs_os_ns_lengths_vec[id_new];
v_gs_ns_os_strides_vec[i] = v_gs_os_ns_strides_vec[id_new];
}
const auto vgrad_desc_nraw_oraw =
MakeGridDescriptorPair<NumDimG, NumDimN, NumDimO, TensorSpecialization::Default>(
v_gs_ns_os_lengths_vec, v_gs_ns_os_strides_vec)
.second;
return PadTensorDescriptor(vgrad_desc_nraw_oraw,
make_tuple(NPerBlock, Gemm1NPerBlock),
Sequence<padder.PadM, padder.PadO>{});
}
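A standalone illustration of the rearrangement above with the example's NumDimG = 2, NumDimN = NumDimO = 1 and hypothetical lengths (note that, despite the name, id_new in the loop above is the position to read from in the original gs_os_ns order):

#include <cstdio>
#include <numeric>
#include <vector>

int main()
{
    const int NumDimG = 2, NumDimN = 1, NumDimO = 1;
    // hypothetical v_gs_os_ns lengths: [G0, G1, O, N]
    std::vector<int> v_gs_os_ns_lengths = {4, 6, 64, 128};

    std::vector<int> gs_ids(NumDimG), os_ids(NumDimO), ns_ids(NumDimN);
    std::iota(gs_ids.begin(), gs_ids.end(), 0);                 // 0, 1
    std::iota(os_ids.begin(), os_ids.end(), NumDimG);           // 2
    std::iota(ns_ids.begin(), ns_ids.end(), NumDimG + NumDimO); // 3

    std::vector<int> ids_old2new;
    ids_old2new.insert(ids_old2new.end(), gs_ids.begin(), gs_ids.end());
    ids_old2new.insert(ids_old2new.end(), ns_ids.begin(), ns_ids.end());
    ids_old2new.insert(ids_old2new.end(), os_ids.begin(), os_ids.end()); // {0, 1, 3, 2}

    std::vector<int> v_gs_ns_os_lengths(ids_old2new.size());
    for(int i = 0; i < static_cast<int>(ids_old2new.size()); ++i)
        v_gs_ns_os_lengths[i] = v_gs_os_ns_lengths[ids_old2new[i]];

    // prints "4 6 128 64": the N and O dimensions swapped into gs_ns_os order
    for(int len : v_gs_ns_os_lengths)
        std::printf("%d ", len);
    std::printf("\n");
    return 0;
}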
template <typename YGridDesc_M_O, typename Number>
static auto MakeYGradGridDescriptor_M0_O_M1(const YGridDesc_M_O& ygrad_grid_desc_m_o,
const Number& M1)
{
const auto M = ygrad_grid_desc_m_o.GetLength(I0);
const auto O = ygrad_grid_desc_m_o.GetLength(I1);
const auto M0 = M / M1;
return transform_tensor_descriptor(
ygrad_grid_desc_m_o,
make_tuple(make_unmerge_transform(make_tuple(M0, M1)), make_pass_through_transform(O)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
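In effect this reinterprets a row-major dY[M, O] view as [M0, O, M1] with M1 = Y_M1 (2 in this file), the same K1-style split used for the other operands; the coordinate map is

    (m_0, o, m_1) \mapsto (m_0 \cdot M_1 + m_1,\; o), \qquad M_0 = M / M_1 .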
// can we construct YGrad_m0_o_m1 from Y_m_o?
// static auto MakeYGradGridDescriptor_M0_O_M1(const std::vector<index_t>&
// y_gs_ms_os_lengths_vec,
// const std::vector<index_t>&
// y_gs_ms_os_strides_vec)
// {
// }
//
// dP = dY * V^T
//
// YGrad in Gemm A position
static auto MakeYGradGridDescriptor_O0_M_O1(const std::vector<index_t>& y_gs_ms_os_lengths_vec,
const std::vector<index_t>& y_gs_ms_os_strides_vec)
{
return Transform::MakeAGridDescriptor_AK0_M_AK1(
Transform::MakeAGridDescriptor_M_K(y_gs_ms_os_lengths_vec, y_gs_ms_os_strides_vec),
Number<Y_O1>{});
}
//
// dQ = alpha * dS * K
//
// QGrad in Gemm C position
static auto MakeQGradGridDescriptor_M_K(const std::vector<index_t>& q_gs_ms_ks_lengths_vec,
const std::vector<index_t>& q_gs_ms_ks_strides_vec)
{
return Transform::MakeCGridDescriptor_M_N(q_gs_ms_ks_lengths_vec, q_gs_ms_ks_strides_vec);
}
//
// dK = alpha * dS^T * Q
//
// KGrad in Gemm C position
static auto MakeKGradGridDescriptor_N_K(const std::vector<index_t>& k_gs_ns_ks_lengths_vec,
const std::vector<index_t>& k_gs_ns_ks_strides_vec)
{
return Transform::MakeCGridDescriptor_M_N(k_gs_ns_ks_lengths_vec, k_gs_ns_ks_strides_vec);
}
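Taken together, the descriptor factories above cover the attention backward dataflow sketched by the comments in this file; as a summary (the dS line is the standard softmax backward identity, added here for context and not stated in this diff):

\begin{aligned}
S  &= \alpha\, Q K^{T}, \qquad P = \operatorname{softmax}(S), \qquad Y = P V \\
dV &= P^{T}\, dY \\
dP &= dY\, V^{T} \\
dS_{ij} &= P_{ij}\Bigl(dP_{ij} - \textstyle\sum_{k} dP_{ik}\, P_{ik}\Bigr) \\
dQ &= \alpha\, dS\, K \\
dK &= \alpha\, dS^{T} Q
\end{aligned}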
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {}));
......@@ -270,6 +424,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using VGradGridDesc_N_O = decltype(MakeVGradGridDescriptor_N_O({}, {}));
using YGradGridDesc_M0_O_M1 =
decltype(MakeYGradGridDescriptor_M0_O_M1(CGridDesc_M_N{}, Number<Y_M1>{}));
constexpr static auto make_MaskOutPredicate()
{
if constexpr(MaskingSpec == MaskingSpecialization::MaskDisabled)
......@@ -325,10 +483,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
// GridwiseGemm
using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle<
ADataType, // TODO: distinguish A/B datatype
DataType, // TODO: distinguish A/B datatype
GemmAccDataType,
CShuffleDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
......@@ -391,10 +548,14 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
struct Argument : public BaseArgument
{
Argument(
const ADataType* p_a_grid,
const BDataType* p_b_grid,
const B1DataType* p_b1_grid,
CDataType* p_c_grid,
const DataType* p_a_grid,
const DataType* p_b_grid,
const DataType* p_b1_grid,
const DataType* p_c_grid, // for dS
const DataType* p_ygrad_grid,
DataType* p_qgrad_grid,
DataType* p_kgrad_grid,
DataType* p_vgrad_grid,
const std::array<void*, NumAcc0Bias> p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths,
......@@ -420,6 +581,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
p_b_grid_{p_b_grid},
p_b1_grid_{p_b1_grid},
p_c_grid_{p_c_grid},
p_vgrad_grid_{p_vgrad_grid},
a_grid_desc_ak0_m_ak1_{
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
b_grid_desc_bk0_n_bk1_{
......@@ -428,6 +590,13 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
c_grid_desc_m_n_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides)},
// dV = P^T * dY
vgrad_grid_desc_n_o_{DeviceOp::MakeVGradGridDescriptor_N_O(c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides)},
/* PTrans descriptor will be constructed in kernel */
ygrad_grid_desc_m0_o_m1_{
DeviceOp::MakeYGradGridDescriptor_M0_O_M1(c_grid_desc_m_n_, Number<Y_M1>{})},
// batch offsets
a_grid_desc_g_m_k_{
Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
b_grid_desc_g_n_k_{
......@@ -501,10 +670,14 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
}
// pointers
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
const B1DataType* p_b1_grid_;
CDataType* p_c_grid_;
const DataType* p_a_grid_;
const DataType* p_b_grid_;
const DataType* p_b1_grid_;
const DataType* p_c_grid_;
const DataType* p_ygrad_grid_;
DataType* p_vgrad_grid_;
DataType* p_qgrad_grid_;
DataType* p_kgrad_grid_;
// tensor descriptor
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
......@@ -518,6 +691,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_;
VGradGridDesc_N_O vgrad_grid_desc_n_o_;
YGradGridDesc_M0_O_M1 ygrad_grid_desc_m0_o_m1_;
// block-to-c-tile map
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
......@@ -566,8 +742,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
auto launch_kernel = [&](auto has_main_k_block_loop_) {
const auto kernel = kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
DataType,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
......@@ -577,6 +752,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
DeviceOp::BGridDesc_BK0_N_BK1,
DeviceOp::B1GridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
DeviceOp::VGradGridDesc_N_O,
DeviceOp::YGradGridDesc_M0_O_M1,
typename GridwiseGemm::DefaultBlock2CTileMap,
ComputeBasePtrOfStridedBatch,
C0MatrixMask,
......@@ -591,6 +768,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
arg.p_b_grid_,
arg.p_b1_grid_,
arg.p_c_grid_,
arg.p_ygrad_grid_,
arg.p_qgrad_grid_,
arg.p_kgrad_grid_,
arg.p_vgrad_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.acc_element_op_,
......@@ -600,6 +781,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
arg.b_grid_desc_bk0_n_bk1_,
arg.b1_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.vgrad_grid_desc_n_o_,
arg.ygrad_grid_desc_m0_o_m1_,
arg.block_2_ctile_map_,
arg.batch_count_,
arg.compute_base_ptr_of_batch_,
......@@ -622,7 +805,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
const StreamConfig& stream_config = StreamConfig{}) // override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
......@@ -705,16 +888,20 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
bool IsSupportedArgument(const BaseArgument* p_arg) // override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(
const ADataType* p_a,
const BDataType* p_b,
const B1DataType* p_b1,
CDataType* p_c,
const DataType* p_a,
const DataType* p_b,
const DataType* p_b1,
const DataType* p_c,
const DataType* p_ygrad_grid,
DataType* p_qgrad_grid,
DataType* p_kgrad_grid,
DataType* p_vgrad_grid,
const std::array<void*, NumAcc0Bias> p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths,
......@@ -741,6 +928,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
p_b,
p_b1,
p_c,
p_ygrad_grid,
p_qgrad_grid,
p_kgrad_grid,
p_vgrad_grid,
p_acc0_biases,
p_acc1_biases,
a_gs_ms_ks_lengths,
......@@ -770,7 +961,11 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
const void* p_a,
const void* p_b,
const void* p_b1,
void* p_c,
const void* p_c,
const void* p_ygrad_grid,
void* p_qgrad_grid,
void* p_kgrad_grid,
void* p_vgrad_grid,
const std::array<void*, NumAcc0Bias> p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths,
......@@ -791,12 +986,16 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op) override
CElementwiseOperation c_element_op) // override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<const B1DataType*>(p_b1),
static_cast<CDataType*>(p_c),
return std::make_unique<Argument>(static_cast<const DataType*>(p_a),
static_cast<const DataType*>(p_b),
static_cast<const DataType*>(p_b1),
static_cast<const DataType*>(p_c),
static_cast<const DataType*>(p_ygrad_grid),
static_cast<DataType*>(p_qgrad_grid),
static_cast<DataType*>(p_kgrad_grid),
static_cast<DataType*>(p_vgrad_grid),
p_acc0_biases, // cast in struct Argument
p_acc1_biases, // cast in struct Argument
a_gs_ms_ks_lengths,
......@@ -819,13 +1018,13 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
std::unique_ptr<BaseInvoker> MakeInvokerPointer() // override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
std::string GetTypeString() const // override
{
auto str = std::stringstream();
......
......@@ -209,7 +209,8 @@ struct BlockToCTileMap_KSplit_M00_N0_M01Adapt
const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I0), MPerBlock);
const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I1), NPerBlock);
block_1d_id = block_1d_id % (M0 * N0 * KSplit_); // hide groups
// TODO ANT: is this necessary?
// block_1d_id = block_1d_id % (M0 * N0 * KSplit_); // hide groups
const index_t idx_ksplit = block_1d_id / (M0 * N0);
block_1d_id = block_1d_id % (M0 * N0);
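A quick worked example of the remaining mapping, with illustrative values M0 = 4, N0 = 8, KSplit_ = 2 and block_1d_id = 37:

    idx_ksplit  = 37 / (4 * 8) = 1
    block_1d_id = 37 % (4 * 8) = 5    (then decomposed over the 4 x 8 tile grid)

With the commented-out modulo, ids of M0 * N0 * KSplit_ = 64 and above would previously have wrapped back into the first KSplit group.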
......
......@@ -18,10 +18,9 @@
namespace ck {
template <typename FloatAB,
template <typename DataType,
typename FloatGemmAcc,
typename FloatCShuffle,
typename FloatC,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename AccElementwiseOperation,
......@@ -107,6 +106,105 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
static constexpr auto B1K0 = Number<Gemm1KPerBlock / B1K1Value>{};
static constexpr auto B1K1 = Number<B1K1Value>{};
// VGrad Gemm
template <index_t Sum_M_ = MPerXdl * 2>
struct VGradGemmTile_N_O_M_
{
static constexpr index_t Free0_N = NPerBlock;
static constexpr index_t Free1_O = Gemm1NPerBlock;
static constexpr index_t Sum_M = Sum_M_;
static constexpr index_t P_M1 = 8; // P will be row-major
static constexpr index_t P_M0 = Sum_M / P_M1;
static constexpr index_t P_LdsPad = 0; // how many multiples of M1 per N * M1 elements
static constexpr index_t YGrad_M1 = 2; // dY assumed row-major, typically =2 for fp16
static constexpr index_t YGrad_M0 = Sum_M / YGrad_M1;
static constexpr index_t YGrad_LdsPad = 0; // how many multiples of M1 per N * M1 elements
static_assert(Sum_M % MPerXdl == 0, "");
static constexpr index_t YGrad_SrcVectorDim = 1; // Free1_O dimension
static constexpr index_t YGrad_SrcScalarPerVector = 4;
static constexpr index_t GemmNWave = 2;
static constexpr index_t GemmOWave = BlockSize / get_warp_size() / GemmNWave;
static constexpr index_t GemmNRepeat = Free0_N / GemmNWave / MPerXdl;
static constexpr index_t GemmORepeat = Free1_O / GemmOWave / NPerXdl;
static constexpr index_t GemmMPack =
math::max(math::lcm(P_M1, YGrad_M1),
MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
using YGrad_BlockSliceLengths = Sequence<YGrad_M0, Free1_O, YGrad_M1>;
using YGrad_ThreadClusterLengths =
Sequence<BlockSize / (Free1_O / YGrad_SrcScalarPerVector),
Free1_O / YGrad_SrcScalarPerVector,
1>;
using YGrad_ThreadClusterArrangeOrder = Sequence<0, 2, 1>;
__host__ __device__ static constexpr auto GetPBlockDescriptor_M0_N_M1()
{
constexpr index_t P_M0 = Sum_M / P_M1;
return make_naive_tensor_descriptor(
make_tuple(Number<P_M0>{}, Number<Free0_N>{}, Number<P_M1>{}),
make_tuple(Number<Free0_N + P_LdsPad>{} * Number<P_M1>{}, Number<P_M1>{}, I1));
}
__host__ __device__ static constexpr auto GetYGradBlockDescriptor_M0_O_M1()
{
constexpr index_t YGrad_M0 = Sum_M / YGrad_M1;
return make_naive_tensor_descriptor(
make_tuple(Number<YGrad_M0>{}, Number<Free1_O>{}, Number<YGrad_M1>{}),
make_tuple(
Number<Free1_O + YGrad_LdsPad>{} * Number<YGrad_M1>{}, Number<YGrad_M1>{}, I1));
}
__host__ __device__ static constexpr auto GetPBlockSliceLengths_M0_N0_M1_N1_M2_N2()
{
// perform manual unmerge: m -> m_repeat, m_waves, m_per_xdl
constexpr index_t m = Sum_M - 1;
constexpr index_t m2 = m % MPerXdl;
constexpr index_t m1 = m / MPerXdl % Gemm0MWaves;
constexpr index_t m0 = m / MPerXdl / Gemm0MWaves % MXdlPerWave;
// perform manual unmerge: n -> n_repeat, n_waves, n_per_xdl
constexpr index_t n = Free0_N - 1;
constexpr index_t n2 = n % NPerXdl;
constexpr index_t n1 = n / NPerXdl % Gemm0NWaves;
constexpr index_t n0 = n / NPerXdl / Gemm0NWaves % NXdlPerWave;
// assume 256 decomposed into 2 x 4 x 32
// 1d idx ( 32 - 1) -> 3d idx 0, 0, 31 -> 3d dim 1 x 1 x 32
// 1d idx (256 - 1) -> 3d idx 1, 3, 31 -> 3d dim 2 x 4 x 32
return Sequence<m0, n0, m1, n1, m2, n2>{} + Sequence<1, 1, 1, 1, 1, 1>{};
}
// template <typename PBlockDesc_M0_N_M1>
// __host__ __device__ static constexpr auto
// MakePMmaTileDescriptor_N0_N1_N2_M(const PBlockDesc_M0_N_M1&)
// {
// constexpr auto lengths = GetPBlockSliceLengths_M0_N0_M1_N1_M2_N2();
// return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<lengths[I0], lengths[I2],
// lengths[I4]>(
// PBlockDesc_M0_N_M1{});
// }
// template <typename BBlockDesc_BK0_N_BK1>
// __host__ __device__ static constexpr auto
// MakeYGradMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
// {
// constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
// return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<NXdlPerWave, NWaves, NPerXdl>(
// BBlockDesc_BK0_N_BK1{});
// }
};
using VGradGemmTile_N_O_M = VGradGemmTile_N_O_M_<>; // tune later
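A standalone check of the tile arithmetic above, instantiated with the example config from this commit (BlockSize = 256, MPerBlock = NPerBlock = 128, Gemm1NPerBlock = 64, MPerXdl = NPerXdl = 32, MXdlPerWave = 1) and a 64-lane wavefront; the plain ints and the Gemm0MWaves formula are assumptions of this sketch, not CK definitions:

#include <cstdio>

int main()
{
    // example instance parameters (from the DeviceGemmInstance above)
    constexpr int BlockSize = 256, MPerBlock = 128, NPerBlock = 128, Gemm1NPerBlock = 64;
    constexpr int MPerXdl = 32, NPerXdl = 32, MXdlPerWave = 1;
    constexpr int WarpSize = 64; // gfx9 wavefront size assumed

    // VGradGemmTile_N_O_M_<> defaults
    constexpr int Free0_N = NPerBlock;              // 128
    constexpr int Free1_O = Gemm1NPerBlock;         // 64
    constexpr int Sum_M   = MPerXdl * 2;            // 64
    constexpr int P_M1 = 8,     P_M0     = Sum_M / P_M1;     // 8 x 8
    constexpr int YGrad_M1 = 2, YGrad_M0 = Sum_M / YGrad_M1; // 32 x 2
    constexpr int YGrad_SrcScalarPerVector = 4;

    constexpr int GemmNWave   = 2;
    constexpr int GemmOWave   = BlockSize / WarpSize / GemmNWave; // 2
    constexpr int GemmNRepeat = Free0_N / GemmNWave / MPerXdl;    // 2
    constexpr int GemmORepeat = Free1_O / GemmOWave / NPerXdl;    // 1

    // dY block copy: a <32, 64, 2> slice covered by a <16, 16, 1> thread cluster
    constexpr int YGrad_Cluster0 = BlockSize / (Free1_O / YGrad_SrcScalarPerVector); // 16
    static_assert(YGrad_Cluster0 * (Free1_O / YGrad_SrcScalarPerVector) == BlockSize, "");

    // GetPBlockSliceLengths M side (assuming Gemm0MWaves = MPerBlock / (MXdlPerWave * MPerXdl) = 4)
    constexpr int Gemm0MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
    constexpr int m  = Sum_M - 1;                               // 63
    constexpr int m2 = m % MPerXdl;                             // 31 -> length 32 (per-xdl)
    constexpr int m1 = m / MPerXdl % Gemm0MWaves;               // 1  -> length 2  (waves)
    constexpr int m0 = m / MPerXdl / Gemm0MWaves % MXdlPerWave; // 0  -> length 1  (repeats)

    std::printf("P tile %dx%dx%d, dY tile %dx%dx%d, waves %dx%d, repeats %dx%d\n",
                P_M0, Free0_N, P_M1, YGrad_M0, Free1_O, YGrad_M1,
                GemmNWave, GemmOWave, GemmNRepeat, GemmORepeat);
    std::printf("P slice lengths along M: %d x %d x %d = %d rows\n",
                m0 + 1, m1 + 1, m2 + 1, (m0 + 1) * (m1 + 1) * (m2 + 1));
    return 0;
}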
// QGrad Gemm
// KGrad Gemm
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<decltype(
......@@ -188,14 +286,23 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
}
__host__ __device__ static constexpr auto
GetPBlockDescriptor_NBlock_NPerBlock_MBlock_MPerBlock()
{
constexpr auto ptrans_block_desc = make_naive_tensor_descriptor_packed(make_tuple(
I1, Number<VGradGemmTile_N_O_M::Free0_N>{}, I1, Number<VGradGemmTile_N_O_M::Sum_M>{}));
return ptrans_block_desc;
}
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
const index_t gemm0_bytes_end = (SharedMemTrait::a_block_space_size_aligned +
SharedMemTrait::b_block_space_size_aligned) *
sizeof(FloatAB);
sizeof(DataType);
const index_t gemm1_bytes_end =
(SharedMemTrait::b1_block_space_offset + SharedMemTrait::b1_block_space_size_aligned) *
sizeof(FloatAB);
sizeof(DataType);
const index_t softmax_bytes_end = (SharedMemTrait::reduction_space_offset +
SharedMemTrait::reduction_space_size_aligned) *
sizeof(FloatGemmAcc);
......@@ -338,11 +445,19 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
};
template <bool HasMainKBlockLoop, typename Block2CTileMap, typename C0MatrixMask>
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
const FloatAB* __restrict__ p_b1_grid,
FloatC* __restrict__ p_c_grid,
template <bool HasMainKBlockLoop,
typename Block2CTileMap,
typename C0MatrixMask,
typename VGradGridDescriptor_N_O,
typename YGradGridDesc_M0_O_M1>
__device__ static void Run(const DataType* __restrict__ p_a_grid,
const DataType* __restrict__ p_b_grid,
const DataType* __restrict__ p_b1_grid,
const DataType* __restrict__ p_c_grid,
const DataType* __restrict__ p_ygrad_grid,
DataType* __restrict__ p_qgrad_grid,
DataType* __restrict__ p_kgrad_grid,
DataType* __restrict__ p_vgrad_grid,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
......@@ -354,6 +469,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const VGradGridDescriptor_N_O& vgrad_grid_desc_n_o,
const YGradGridDesc_M0_O_M1& ygrad_grid_desc_m0_o_m1,
const Block2CTileMap& block_2_ctile_map,
const C0MatrixMask& c0_matrix_mask)
{
......@@ -363,7 +480,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
const auto b1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b1_grid, b1_grid_desc_bk0_n_bk1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
const auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// divide block work by [M, N]
......@@ -404,8 +521,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
Sequence<AK0, MPerBlock, AK1>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
DataType,
DataType,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
......@@ -435,8 +552,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
Sequence<BK0, NPerBlock, BK1>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
DataType,
DataType,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
......@@ -465,11 +582,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
// sanity check
constexpr index_t KPack = math::max(
math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
math::lcm(AK1, BK1), MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_v2<
BlockSize,
FloatAB,
DataType,
FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
......@@ -489,11 +606,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
// LDS allocation for A and B: be careful of alignment
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared) + SharedMemTrait::a_block_space_offset,
static_cast<DataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared) + SharedMemTrait::b_block_space_offset,
static_cast<DataType*>(p_shared) + SharedMemTrait::b_block_space_offset,
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
......@@ -566,7 +683,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
// A1 matrix blockwise copy
auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic<
FloatGemmAcc,
FloatAB,
DataType,
decltype(acc_thread_desc_k0_m_k1),
decltype(a1_thread_desc_k0_m_k1),
tensor_operation::element_wise::PassThrough,
......@@ -584,8 +701,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
Sequence<B1K0, Gemm1NPerBlock, B1K1>,
B1BlockTransferThreadClusterLengths_BK0_N_BK1,
B1BlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
DataType,
DataType,
decltype(b1_grid_desc_bk0_n_bk1),
decltype(b1_block_desc_bk0_n_bk1),
B1BlockTransferSrcAccessOrder,
......@@ -606,12 +723,12 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
auto a1_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
auto a1_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, DataType>(
a1_thread_desc_k0_m_k1.GetElementSpaceSize());
// reuse LDS space for gemm0's b_block_buf
auto b1_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared) + SharedMemTrait::b1_block_space_offset,
static_cast<DataType*>(p_shared) + SharedMemTrait::b1_block_space_offset,
b1_block_desc_bk0_n_bk1.GetElementSpaceSize());
// selected_mfma.group_size or B1K1 <= Gemm1KPack <= selected_mfma.group_size
......@@ -624,11 +741,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
// cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
// therefore we may just as well assign Gemm1KPack = group_size
constexpr index_t Gemm1KPack =
MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.group_size;
MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma.group_size;
auto gemm1_blockwise_gemm = BlockwiseGemmXdlops_v2<
BlockSize,
FloatAB,
DataType,
FloatGemmAcc,
decltype(a1_thread_desc_k0_m_k1),
decltype(b1_block_desc_bk0_n_bk1),
......@@ -644,7 +761,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
Gemm1KPack,
true, // TransposeC
Gemm1KPack, // AMmaKStride
Gemm1KPack * XdlopsGemm<FloatAB, MPerXdl, NPerXdl, Gemm1KPack, false>{}.K0PerXdlops>{
Gemm1KPack * XdlopsGemm<DataType, MPerXdl, NPerXdl, Gemm1KPack, false>{}.K0PerXdlops>{
// BMmaKStride
make_tuple(0, 0, 0, 0)}; // A_origin
......@@ -713,6 +830,263 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
running_max = NumericLimits<FloatGemmAcc>::Lowest();
running_max_new = NumericLimits<FloatGemmAcc>::Lowest();
//
// dV
//
// P vgpr to lds: writes vgprs of a subgroup to LDS and transform into m0_n_m1
// m0, n0 are m/n repeat per wave
// m1, n1 are number of waves
constexpr auto p_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
constexpr auto p_dst_block_desc_m0_n_m1 =
VGradGemmTile_N_O_M::GetPBlockDescriptor_M0_N_M1();
constexpr auto p_block_lengths =
blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
constexpr auto P_M0 = p_block_lengths[I0]; // repeats
constexpr auto P_N0 = p_block_lengths[I1];
constexpr auto P_M1 = p_block_lengths[I2]; // waves
constexpr auto P_N1 = p_block_lengths[I3];
constexpr auto P_M2 = p_block_lengths[I4]; // xdl
constexpr auto P_N2 = p_block_lengths[I5];
constexpr auto P_N3 = p_block_lengths[I6];
constexpr auto P_N4 = p_block_lengths[I7];
constexpr auto p_dst_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = [&]() constexpr
{
constexpr auto p_dst_block_desc_m_n = transform_tensor_descriptor(
p_dst_block_desc_m0_n_m1,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(VGradGemmTile_N_O_M::P_M0, VGradGemmTile_N_O_M::P_M1)),
make_pass_through_transform(VGradGemmTile_N_O_M::Free0_N)),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return transform_tensor_descriptor(
p_dst_block_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(P_M0, P_M1, P_M2)),
make_unmerge_transform(make_tuple(P_N0, P_N1, P_N2, P_N3, P_N4))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5, 6, 7>{}));
}
();
// TODO ANT: check lds offset
auto p_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<DataType*>(p_shared), p_dst_block_desc_m0_n_m1.GetElementSpaceSize());
const auto p_dst_thread_origin = [&]() {
const auto c_thread_mtx_on_block =
blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
const auto m_thread_data_on_block_to_m0_m1_m2_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(P_M0, P_M1, P_M2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_block_idx =
m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_block));
const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(P_N0, P_N1, P_N2, P_N3, P_N4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_block_idx =
n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));
return make_tuple(0, // mrepeat
0, // nrepeat
m_thread_data_on_block_idx[I1], // mwave
n_thread_data_on_block_idx[I1], // nwave
m_thread_data_on_block_idx[I2], // xdlops
n_thread_data_on_block_idx[I2],
n_thread_data_on_block_idx[I3],
n_thread_data_on_block_idx[I4]);
}();
constexpr auto p_block_slice_lengths_m0_n0_m1_n1_m2_n2 = // mrepeat, nrepeat, mwaves,
// nwaves, mperxdl, nperxdl
VGradGemmTile_N_O_M::GetPBlockSliceLengths_M0_N0_M1_N1_M2_N2();
// how to properly perform copy for a sub-workgroup?
auto p_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
FloatGemmAcc,
DataType,
decltype(p_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
decltype(p_dst_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
tensor_operation::element_wise::PassThrough,
Sequence<p_block_slice_lengths_m0_n0_m1_n1_m2_n2[I0], // ThreadSliceLengths
p_block_slice_lengths_m0_n0_m1_n1_m2_n2[I1],
I1,
I1,
I1,
P_N2,
I1,
P_N4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7, // DstVectorDim
1, // DstScalarPerVector
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
true>{
p_dst_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_multi_index(p_dst_thread_origin[I0],
p_dst_thread_origin[I1],
p_dst_thread_origin[I2] % p_block_slice_lengths_m0_n0_m1_n1_m2_n2[I2],
p_dst_thread_origin[I3] % p_block_slice_lengths_m0_n0_m1_n1_m2_n2[I3],
p_dst_thread_origin[I4],
p_dst_thread_origin[I5],
p_dst_thread_origin[I6],
p_dst_thread_origin[I7]),
tensor_operation::element_wise::PassThrough{}};
// construct space filling curve
// p_thread_copy_vgpr_to_lds.Run();
constexpr auto ygrad_dst_block_desc_m0_o_m1 =
VGradGemmTile_N_O_M::GetYGradBlockDescriptor_M0_O_M1();
auto ygrad_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock,
tensor_operation::element_wise::PassThrough,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
typename VGradGemmTile_N_O_M::YGrad_BlockSliceLengths,
typename VGradGemmTile_N_O_M::YGrad_ThreadClusterLengths,
typename VGradGemmTile_N_O_M::YGrad_ThreadClusterArrangeOrder,
DataType,
DataType,
decltype(ygrad_grid_desc_m0_o_m1),
decltype(ygrad_dst_block_desc_m0_o_m1),
typename VGradGemmTile_N_O_M::YGrad_ThreadClusterArrangeOrder, // access order == thread
// order
Sequence<1, 0, 2>,
VGradGemmTile_N_O_M::YGrad_SrcVectorDim,
2, // DstVectorDim
VGradGemmTile_N_O_M::YGrad_SrcScalarPerVector,
VGradGemmTile_N_O_M::YGrad_M1,
1,
1,
true,
true,
1>(ygrad_grid_desc_m0_o_m1,
make_multi_index(0, gemm1_n_block_data_idx_on_grid, 0),
tensor_operation::element_wise::PassThrough{},
ygrad_dst_block_desc_m0_o_m1,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
auto vgrad_blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<
BlockSize,
DataType,
FloatGemmAcc,
decltype(p_dst_block_desc_m0_n_m1),
decltype(ygrad_dst_block_desc_m0_o_m1),
MPerXdl,
NPerXdl,
VGradGemmTile_N_O_M::GemmNRepeat, // NRepeat
VGradGemmTile_N_O_M::GemmORepeat, // ORepeat
VGradGemmTile_N_O_M::GemmMPack>{};
constexpr auto vgrad_block_lengths =
vgrad_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths();
const auto vgrad_grid_desc_n0_o0_n1_o1_n2_o2 = transform_tensor_descriptor(
vgrad_grid_desc_n_o,
make_tuple(
make_unmerge_transform(make_tuple(I1, // may place a dummy variable
p_block_slice_lengths_m0_n0_m1_n1_m2_n2[I2],
p_block_slice_lengths_m0_n0_m1_n1_m2_n2[I4])),
make_unmerge_transform(make_tuple(I1,
p_block_slice_lengths_m0_n0_m1_n1_m2_n2[I3],
p_block_slice_lengths_m0_n0_m1_n1_m2_n2[I5]))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}));
constexpr auto vgrad_thread_desc_n0_o0_n1_o1_n2_n3_n4_o2 =
vgrad_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
const auto vgrad_grid_desc_n0_o0_n1_o1_n2_n3_n4_o2 =
vgrad_blockwise_gemm.xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(
vgrad_grid_desc_n0_o0_n1_o1_n2_o2);
const auto vgrad_thread_mtx_on_block_n_o =
vgrad_blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
constexpr auto vgrad_block_desc_n0_o0_n1_o1_n2_n3_n4_o2 =
decltype(vgrad_blockwise_gemm)::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto VGrad_N0 = vgrad_block_desc_n0_o0_n1_o1_n2_n3_n4_o2.GetLength(I0);
constexpr auto VGrad_O0 = vgrad_block_desc_n0_o0_n1_o1_n2_n3_n4_o2.GetLength(I1);
constexpr auto VGrad_N1 = vgrad_block_desc_n0_o0_n1_o1_n2_n3_n4_o2.GetLength(I2);
constexpr auto VGrad_O1 = vgrad_block_desc_n0_o0_n1_o1_n2_n3_n4_o2.GetLength(I3);
constexpr auto VGrad_N2 = vgrad_block_desc_n0_o0_n1_o1_n2_n3_n4_o2.GetLength(I4);
constexpr auto VGrad_N3 = vgrad_block_desc_n0_o0_n1_o1_n2_n3_n4_o2.GetLength(I5);
constexpr auto VGrad_N4 = vgrad_block_desc_n0_o0_n1_o1_n2_n3_n4_o2.GetLength(I6);
constexpr auto VGrad_O2 = vgrad_block_desc_n0_o0_n1_o1_n2_n3_n4_o2.GetLength(I7);
const index_t n_thread_data_idx_on_grid =
vgrad_thread_mtx_on_block_n_o[I0]; // TODO ANT: step n after each Gemm1 outer loop
const index_t o_thread_data_idx_on_grid =
vgrad_thread_mtx_on_block_n_o[I1] + gemm1_n_block_data_idx_on_grid;
const auto n_thread_data_on_grid_to_n0_n1_n2_n3_n4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(
make_tuple(VGrad_N0, VGrad_N1, VGrad_N2, VGrad_N3, VGrad_N4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_nd_idx_on_grid =
n_thread_data_on_grid_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_idx_on_grid));
const auto o_thread_data_on_grid_to_o0_o1_o2_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(VGrad_O0, VGrad_O1, VGrad_O2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto o_thread_data_nd_idx_on_grid =
o_thread_data_on_grid_to_o0_o1_o2_adaptor.CalculateBottomIndex(
make_multi_index(o_thread_data_idx_on_grid));
auto vgrad_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
FloatGemmAcc,
DataType,
decltype(vgrad_thread_desc_n0_o0_n1_o1_n2_n3_n4_o2),
decltype(vgrad_grid_desc_n0_o0_n1_o1_n2_n3_n4_o2),
tensor_operation::element_wise::PassThrough, // CElementwiseOperation
decltype(vgrad_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
.GetLengths()), // SliceLengths
Sequence<0, 1, 2, 3, 4, 5, 6, 7>, // AccessOrder
7, // VectorDim
1, // ScalarPerVector
InMemoryDataOperationEnum::AtomicAdd, // GlobalMemoryDataOperation
1,
true>(vgrad_grid_desc_n0_o0_n1_o1_n2_n3_n4_o2,
make_multi_index(n_thread_data_nd_idx_on_grid[I0],
o_thread_data_nd_idx_on_grid[I0],
n_thread_data_nd_idx_on_grid[I1],
o_thread_data_nd_idx_on_grid[I1],
n_thread_data_nd_idx_on_grid[I2],
n_thread_data_nd_idx_on_grid[I3],
n_thread_data_nd_idx_on_grid[I4],
o_thread_data_nd_idx_on_grid[I2]),
tensor_operation::element_wise::PassThrough{});
// TODO ANT: ygrad slice window step size
#if 0
// gemm1 K loop
index_t gemm1_k_block_outer_index = 0;
do
......@@ -928,8 +1302,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
block_sync_lds(); // wait for gemm1 LDS read
} while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop
#endif
// shuffle C and write out
// TODO ANT:
// shuffle dQ and write
if constexpr(false)
{
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
Gemm1NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
......@@ -1054,7 +1431,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
FloatCShuffle, // typename SrcData,
FloatC, // typename DstData,
DataType, // typename DstData,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
......
......@@ -54,7 +54,8 @@ template <typename SrcData,
typename SrcDesc,
typename DstDesc,
typename ElementwiseOperation,
typename SliceLengths,
typename SliceLengths, // TODO ANT: can we generalize to allow sub-wg slice transfer? need
// to distinguish what dimensions are spread across waves
typename DimAccessOrder,
index_t DstVectorDim,
index_t DstScalarPerVector,
......
......@@ -19,4 +19,37 @@ struct ThisThreadBlock
__device__ static index_t GetThreadId() { return get_thread_local_1d_id(); }
};
template <index_t ThreadPerBlock>
struct SubThreadBlock
{
static constexpr index_t kNumThread_ = ThreadPerBlock;
__device__ SubThreadBlock(int mwave, int nwave) : mwave_(mwave), nwave_(nwave) {}
__device__ static constexpr index_t GetNumOfThread() { return kNumThread_; }
template <typename Tuple2>
__device__ constexpr bool IsBelong(const Tuple2& mwave_range, const Tuple2& nwave_range)
{
// wave_range[I0] inclusive, wave_range[I1] exclusive
if(mwave_ < mwave_range[I0])
return false;
else if(mwave_ >= mwave_range[I1])
return false;
else if(nwave_ < nwave_range[I0])
return false;
else if(nwave_ >= nwave_range[I1])
return false;
else
return true;
}
__device__ static index_t GetThreadId() { return get_thread_local_1d_id(); }
private:
index_t mwave_, nwave_;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
};
} // namespace ck
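The new SubThreadBlock lets a copy or gemm be gated to a rectangular subset of waves; below is a host-side analogue of its half-open range check (not the CK device class, plain ints instead of Number<> tuples):

#include <cstdio>
#include <utility>

// mirrors SubThreadBlock::IsBelong: ranges are [first, second) in wave coordinates
bool is_belong(int mwave, int nwave,
               std::pair<int, int> mwave_range, std::pair<int, int> nwave_range)
{
    return mwave >= mwave_range.first && mwave < mwave_range.second &&
           nwave >= nwave_range.first && nwave < nwave_range.second;
}

int main()
{
    // e.g. only a 2 x 1 block of waves participates in some sub-workgroup transfer
    std::printf("%d\n", is_belong(1, 0, {0, 2}, {0, 1})); // 1
    std::printf("%d\n", is_belong(2, 0, {0, 2}, {0, 1})); // 0 (mwave out of range)
    return 0;
}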
add_test_executable(test_space_filling_curve space_filling_curve.cpp)
add_test_executable(test_threadwise_copy test_threadwise_copy.cpp)