Commit b5fbb74b authored by ltqin

add GemmDataType

parent 9096e2af
@@ -61,6 +61,7 @@
 using YElementOp = PassThrough;
 using VElementOp = Scale;
 using DataType = F16;
+using GemmDataType = F16;
 using AccDataType = F32;
 using ShuffleDataType = F32;
 using LSEDataType = F32;
@@ -97,6 +98,7 @@ using DeviceGemmInstance =
         NumDimK,
         NumDimO,
         DataType,
+        GemmDataType,
         ZDataType,
         LSEDataType,
         Acc0BiasDataType,
@@ -164,6 +166,7 @@ using DeviceGemmInstance =
         NumDimK,
         NumDimO,
         DataType,
+        GemmDataType,
         ZDataType,
         LSEDataType,
         Acc0BiasDataType,
@@ -220,7 +223,7 @@ using DeviceGemmInstance =
         2, // CShuffleNXdlPerWavePerShuffle
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8, // CShuffleBlockTransferScalarPerVector_NPerBlock
-        MaskingSpec>; // MaskingSpecialization
+        MaskingSpec>; // MaskingSpecialization
 #endif
 // Ref Gemm0: S = alpha * Q * K^T
 // fp16 in, fp32 out
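The two comment lines above introduce the example's host-side check: S = alpha * Q * K^T with fp16 inputs and fp32 accumulation. A minimal sketch of what such a reference computes, assuming row-major Q (M x K) and K (N x K); this is illustrative, not the reference op the example actually instantiates:

    #include <cstddef>
    #include <vector>

    // Naive check for S = alpha * Q * K^T. Inputs are held as float for brevity;
    // the example feeds fp16 values and accumulates in fp32 (AccDataType = F32).
    void ref_gemm0(const std::vector<float>& Q,  // M x K, row-major
                   const std::vector<float>& Kt, // N x K, row-major, so Q * Kt^T
                   std::vector<float>& S,        // M x N, row-major
                   std::size_t M, std::size_t N, std::size_t K, float alpha)
    {
        for(std::size_t m = 0; m < M; ++m)
            for(std::size_t n = 0; n < N; ++n)
            {
                float acc = 0.f; // fp32 accumulator
                for(std::size_t k = 0; k < K; ++k)
                    acc += Q[m * K + k] * Kt[n * K + k];
                S[m * N + n] = alpha * acc;
            }
    }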
@@ -331,14 +334,14 @@ int run(int argc, char* argv[])
     // y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
     // y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
     // y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
-    ck::index_t M = 512;
-    ck::index_t N = 512;
+    ck::index_t M = 512;
+    ck::index_t N = 512;
 #if USING_K128
-    ck::index_t K = 128;
-    ck::index_t O = 128;
+    ck::index_t K = 128;
+    ck::index_t O = 128;
 #else
-    ck::index_t K = 64;
-    ck::index_t O = 64;
+    ck::index_t K = 64;
+    ck::index_t O = 64;
 #endif
     ck::index_t G0 = 3;
     ck::index_t G1 = 2;
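The three comments at the top of this hunk describe the output layout: attention is computed per flattened batch g = g0 * G1 + g1, and the result is then exposed as [G0, M, G1, O]. A copy-based sketch of that permute([0, 2, 1, 3]) index mapping (hypothetical helper; the example typically realizes this through strides rather than an explicit copy):

    #include <cstddef>
    #include <vector>

    // y_g0_g1_m_o -> y_g0_m_g1_o, i.e. swap the G1 and M axes.
    std::vector<float> permute_0213(const std::vector<float>& y,
                                    std::size_t G0, std::size_t G1,
                                    std::size_t M, std::size_t O)
    {
        std::vector<float> out(G0 * M * G1 * O);
        for(std::size_t g0 = 0; g0 < G0; ++g0)
            for(std::size_t g1 = 0; g1 < G1; ++g1)
                for(std::size_t m = 0; m < M; ++m)
                    for(std::size_t o = 0; o < O; ++o)
                        out[((g0 * M + m) * G1 + g1) * O + o] =
                            y[((g0 * G1 + g1) * M + m) * O + o];
        return out;
    }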
......
@@ -49,7 +49,7 @@ template <typename GridwiseGemm,
           bool HasMainKBlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/1)
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
 #endif
     kernel_batched_multihead_attention_backward_xdl_cshuffle_v2(
         const DataType* __restrict__ p_a_grid,
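Only whitespace changes in this hunk, but the attribute is worth a note: the first argument of __launch_bounds__ caps the block size the kernel will be launched with, and the second is an occupancy hint (CK's macro name reads it as the minimum blocks resident per CU) that lets the compiler budget registers accordingly. A minimal usage sketch with a hypothetical kernel:

    #include <hip/hip_runtime.h>

    // Occupancy-bounded kernel: at most 256 threads per block, occupancy hint 1,
    // mirroring how the macros above expand. Illustrative only.
    __global__ void __launch_bounds__(256, 1) scale_kernel(float* p, int n, float s)
    {
        const int i = blockIdx.x * blockDim.x + threadIdx.x;
        if(i < n)
            p[i] *= s;
    }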
@@ -171,6 +171,7 @@ template <index_t NumDimG,
           index_t NumDimK,
           index_t NumDimO, // NumDimGemm1N
           typename DataType,
+          typename GemmDataType,
           typename ZDataType,
           typename LSEDataType,
           typename Acc0BiasDataType,
@@ -595,9 +596,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
     // GridwiseGemm
     using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
         DataType, // TODO: distinguish A/B datatype
-        LSEDataType,
+        GemmDataType,
         GemmAccDataType,
         CShuffleDataType,
+        LSEDataType,
         AElementwiseOperation,
         BElementwiseOperation,
         AccElementwiseOperation,
......
@@ -602,7 +602,6 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
         {
             is_lse_storing_ = false;
         }
-    }

     void Print() const
......
@@ -21,6 +21,7 @@
 namespace ck {

 template <typename DataType,
+          typename GemmDataType,
           typename FloatGemmAcc,
           typename FloatCShuffle,
           typename FloatLSE,
@@ -85,21 +86,6 @@ template <typename DataType,
           PipelineVersion PipelineVer = PipelineVersion::v1>
 struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
 {
-    template <typename T>
-    struct TypeMap
-    {
-        using type = T;
-    };
-#if defined(__gfx90a_masking__)
-    template <>
-    struct TypeMap<ck::half_t>
-    {
-        using type = ck::bhalf_t;
-    };
-#endif
-
-    using LDSDataType = typename TypeMap<DataType>::type;
-
     static_assert(LoopSched == LoopScheduler::Default,
                   "Non-default loop scheduler is currently not supported");

@@ -141,7 +127,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         const auto M = z_grid_desc_m_n.GetLength(I0);
         const auto N = z_grid_desc_m_n.GetLength(I1);

-        constexpr auto mfma = MfmaSelector<LDSDataType, MPerXdl, NPerXdl>::selected_mfma;
+        constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
         constexpr auto N3 = mfma.num_groups_per_blk;
         constexpr auto N4 = mfma.num_input_blks;
         constexpr auto N5 = mfma.group_size;
@@ -157,7 +143,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
     __host__ __device__ static constexpr auto
     MakeZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const index_t M, const index_t N)
     {
-        constexpr auto mfma = MfmaSelector<LDSDataType, MPerXdl, NPerXdl>::selected_mfma;
+        constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
         constexpr auto N3 = mfma.num_groups_per_blk;
         constexpr auto N4 = mfma.num_input_blks;
         constexpr auto N5 = mfma.group_size;
@@ -471,7 +457,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
             ABlockTransferThreadClusterLengths_AK0_M_AK1,
             ABlockTransferThreadClusterArrangeOrder,
             DataType,
-            LDSDataType,
+            GemmDataType,
             GridDesc_K0_M_K1,
             decltype(a_block_desc_ak0_m_ak1),
             ABlockTransferSrcAccessOrder,
@@ -496,7 +482,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
             BBlockTransferThreadClusterLengths_BK0_N_BK1,
             BBlockTransferThreadClusterArrangeOrder,
             DataType,
-            LDSDataType,
+            GemmDataType,
             GridDesc_K0_N_K1,
             decltype(b_block_desc_bk0_n_bk1),
             BBlockTransferSrcAccessOrder,
@@ -513,12 +499,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         static constexpr index_t KPack =
             math::max(math::lcm(AK1, BK1),
-                      MfmaSelector<LDSDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
+                      MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);

         // Blockwise gemm with transposed XDL output
         using BlockwiseGemm = BlockwiseGemmXdlops_v2<
             BlockSize,
-            LDSDataType,
+            GemmDataType,
             FloatGemmAcc,
             decltype(a_block_desc_ak0_m_ak1),
             decltype(b_block_desc_bk0_n_bk1),
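KPack is the k-chunk both operands are packed to per instruction: the least common multiple of the A-side and B-side vector widths, raised to at least what the selected mfma consumes per block. A worked check with assumed values (AK1, BK1, and k_per_blk below are illustrative, not read from the diff):

    #include <algorithm>
    #include <numeric>

    constexpr int AK1       = 8; // assumed A-side vector width
    constexpr int BK1       = 4; // assumed B-side vector width
    constexpr int k_per_blk = 4; // assumed mfma k elements per block

    // lcm(8, 4) = 8 >= 4, so both operands are packed in chunks of 8 along K.
    constexpr int KPack = std::max(std::lcm(AK1, BK1), k_per_blk);
    static_assert(KPack == 8, "illustrative value");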
@@ -580,7 +566,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         using ABlockwiseCopy = ThreadwiseTensorSliceTransfer_StaticToStatic<
             FloatGemmAcc,
-            LDSDataType,
+            GemmDataType,
             decltype(a_src_thread_desc_k0_m_k1),
             decltype(a_thread_desc_k0_m_k1),
             tensor_operation::element_wise::PassThrough,
@@ -599,7 +585,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
             B1BlockTransferThreadClusterLengths_BK0_N_BK1,
             B1BlockTransferThreadClusterArrangeOrder,
             DataType,
-            LDSDataType,
+            GemmDataType,
             GridDesc_K0_N_K1,
             decltype(b_block_desc_bk0_n_bk1),
             B1BlockTransferSrcAccessOrder,
@@ -630,11 +616,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
         // therefore we may just as well assign Gemm1KPack = group_size
         static constexpr index_t GemmKPack =
-            MfmaSelector<LDSDataType, MPerXdl, NPerXdl>::selected_mfma.group_size;
+            MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma.group_size;

         using BlockwiseGemm = BlockwiseGemmXdlops_v2<
             BlockSize,
-            LDSDataType,
+            GemmDataType,
             FloatGemmAcc,
             decltype(a_thread_desc_k0_m_k1),
             decltype(b_block_desc_bk0_n_bk1),
@@ -650,7 +636,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
             GemmKPack,
             true,      // TransposeC
             GemmKPack, // AMmaKStride
-            GemmKPack * XdlopsGemm<LDSDataType, MPerXdl, NPerXdl, GemmKPack, false>{}
+            GemmKPack * XdlopsGemm<GemmDataType, MPerXdl, NPerXdl, GemmKPack, false>{}
                 .K0PerXdlops /* BMmaKStride */>;
     };
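The truncated comment two hunks above carries the constraint behind GemmKPack: gemm1 consumes gemm0's transposed XDL output directly from registers, where each lane holds its k-values in chunks of group_size that are not contiguous across groups, so a larger Gemm1KPack would pair a-fragment entries such as [0:3, 8:11] with contiguous b entries [0:7] and misalign the summation index; pinning GemmKPack to the mfma's group_size keeps both operands' k-indices aligned.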
@@ -682,7 +668,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         static constexpr index_t GemmORepeat = Free1_O / GemmOWave / NPerXdl;
         static constexpr index_t GemmMPack =
             math::max(math::lcm(A_M1, B_M1),
-                      MfmaSelector<LDSDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
+                      MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);

         using BBlockSliceLengths = Sequence<B_M0, Free1_O, B_M1>;
         using BThreadClusterLengths =
@@ -807,7 +793,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         template <typename ElementwiseOp = tensor_operation::element_wise::PassThrough>
         using ABlockwiseCopy = ThreadwiseTensorSliceTransfer_v1r3<
             FloatGemmAcc,
-            LDSDataType,
+            GemmDataType,
             decltype(a_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
             decltype(a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
             ElementwiseOp,
@@ -837,7 +823,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
             typename Gemm2Params_N_O_M::BThreadClusterLengths,
             typename Gemm2Params_N_O_M::BThreadClusterArrangeOrder,
             DataType,
-            LDSDataType,
+            GemmDataType,
             GridDesc_M0_O_M1,
             decltype(b_block_desc_m0_o_m1),
             typename Gemm2Params_N_O_M::BThreadClusterArrangeOrder, // access order == thread order
@@ -854,7 +840,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         using BlockwiseGemm =
             BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
-                                                                LDSDataType,
+                                                                GemmDataType,
                                                                 FloatGemmAcc,
                                                                 decltype(a_block_desc_m0_n_m1),
                                                                 decltype(b_block_desc_m0_o_m1),
@@ -1095,7 +1081,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         static constexpr auto b2_block_desc_m0_o_m1 =
             GetB2BlockDescriptor_M0_O_M1<Gemm2Params_N_O_M>();

-        static constexpr auto max_lds_align = Number<16 / sizeof(LDSDataType)>{};
+        static constexpr auto max_lds_align = Number<16 / sizeof(GemmDataType)>{};

         static constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
             a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
@@ -1131,13 +1117,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
     {
         const index_t gemm0_bytes_end = (SharedMemTrait::a_block_space_size_aligned +
                                          SharedMemTrait::b_block_space_size_aligned) *
-                                        sizeof(LDSDataType);
+                                        sizeof(GemmDataType);
         const index_t gemm1_bytes_end =
             (SharedMemTrait::b1_block_space_offset + SharedMemTrait::b1_block_space_size_aligned) *
-            sizeof(LDSDataType);
+            sizeof(GemmDataType);
         const index_t vgrad_gemm_bytes_end = (SharedMemTrait::p_block_space_size_aligned +
                                               SharedMemTrait::ygrad_block_space_size_aligned) *
-                                             sizeof(LDSDataType);
+                                             sizeof(GemmDataType);
         const index_t softmax_bytes_end = (SharedMemTrait::reduction_space_offset +
                                            SharedMemTrait::reduction_space_size_aligned) *
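These byte counts are running high-water marks into the single p_shared allocation; each region's element count is first rounded up to max_lds_align = 16 / sizeof(GemmDataType) elements, i.e. to a 16-byte boundary. A toy recomputation of the pattern (all sizes below are assumptions, not values from the kernel):

    #include <cstddef>

    // Round x up to a multiple of m, mirroring math::integer_least_multiple.
    constexpr std::size_t integer_least_multiple(std::size_t x, std::size_t m)
    {
        return (x + m - 1) / m * m;
    }

    constexpr std::size_t elem_size     = 2;              // sizeof(GemmDataType) for f16
    constexpr std::size_t max_lds_align = 16 / elem_size; // 8 elements = 16 bytes
    constexpr std::size_t a_elems       = 128 * 36;       // made-up A tile element count
    constexpr std::size_t b_elems       = 128 * 36;       // made-up B tile element count

    constexpr std::size_t a_aligned = integer_least_multiple(a_elems, max_lds_align);
    constexpr std::size_t b_aligned = integer_least_multiple(b_elems, max_lds_align);

    // gemm0's LDS high-water mark in bytes, as in gemm0_bytes_end above.
    constexpr std::size_t gemm0_bytes_end = (a_aligned + b_aligned) * elem_size;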
@@ -1243,11 +1229,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         // Gemm0: LDS allocation for A and B: be careful of alignment
         auto gemm0_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<LDSDataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
+            static_cast<GemmDataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
             Gemm0::a_block_desc_ak0_m_ak1.GetElementSpaceSize());

         auto gemm0_b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<LDSDataType*>(p_shared) + SharedMemTrait::b_block_space_offset,
+            static_cast<GemmDataType*>(p_shared) + SharedMemTrait::b_block_space_offset,
             Gemm0::b_block_desc_bk0_n_bk1.GetElementSpaceSize());

         // Gemm0: gridwise GEMM pipeline
@@ -1339,11 +1325,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
             decltype(s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4())>;

         // Gemm1: VGPR allocation for A and LDS allocation for B
-        auto gemm1_a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, LDSDataType>(
+        auto gemm1_a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, GemmDataType>(
             Gemm1::a_thread_desc_k0_m_k1.GetElementSpaceSize());

         auto gemm1_b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<LDSDataType*>(p_shared) + SharedMemTrait::b1_block_space_offset,
+            static_cast<GemmDataType*>(p_shared) + SharedMemTrait::b1_block_space_offset,
             Gemm1::b_block_desc_bk0_n_bk1.GetElementSpaceSize());

         // dQ: transform input and output tensor descriptors
@@ -1535,11 +1521,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         // Gemm2: LDS allocation for A and B: be careful of alignment
         auto gemm2_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<LDSDataType*>(p_shared) + SharedMemTrait::a2_block_space_offset,
+            static_cast<GemmDataType*>(p_shared) + SharedMemTrait::a2_block_space_offset,
             Gemm2::a_block_desc_m0_n_m1.GetElementSpaceSize());

         auto gemm2_b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<LDSDataType*>(p_shared) + SharedMemTrait::b2_block_space_offset,
+            static_cast<GemmDataType*>(p_shared) + SharedMemTrait::b2_block_space_offset,
             Gemm2::b_block_desc_m0_o_m1.GetElementSpaceSize());

         // dV: transform input and output tensor descriptors
......