Commit 4cb63955 authored by Rostyslav Geyyer's avatar Rostyslav Geyyer
Browse files

Update for compatibility

parent 2ab1dbf9
......@@ -66,7 +66,8 @@ template <typename ALayout,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler(),
PipelineVersion PipelineVer = PipelineVersion::v1,
typename ComputeType = CDataType>
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA>
struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
BLayout,
CLayout,
......@@ -131,7 +132,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched,
PipelineVer,
ComputeType>;
ComputeTypeA,
ComputeTypeB>;
using Argument = typename GridwiseGemm::Argument;
......
......@@ -522,6 +522,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
ABDataType,
ABDataType,
AccDataType,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
......
......@@ -628,7 +628,8 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
Gemm1KPack,
false, // TransposeC
Gemm1KPack, // AMmaKStride
Gemm1KPack * XdlopsGemm<FloatAB, MPerXdl, NPerXdl, Gemm1KPack, false>{}.K0PerXdlops>{
Gemm1KPack *
XdlopsGemm<FloatAB, MPerXdl, NPerXdl, Gemm1KPack, FloatAB, false>{}.K0PerXdlops>{
// BMmaKStride
make_tuple(0, 0, 0, 0)}; // A_origin
......
......@@ -880,7 +880,12 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
Gemm1KPack,
false, // TransposeC
Gemm1KPack, // AMmaKStride
Gemm1KPack * XdlopsGemm<A0B0B1DataType, Gemm0MPerXdl, Gemm0NPerXdl, Gemm1KPack, false>{}
Gemm1KPack * XdlopsGemm<A0B0B1DataType,
Gemm0MPerXdl,
Gemm0NPerXdl,
Gemm1KPack,
A0B0B1DataType,
false>{}
.K0PerXdlops>{ // BMmaKStride
make_tuple(0, 0, 0, 0)}; // A_origin
......
......@@ -794,7 +794,8 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle
Gemm1KPack,
true, // TransposeC
Gemm1KPack, // AMmaKStride
Gemm1KPack * XdlopsGemm<FloatAB, MPerXdl, NPerXdl, Gemm1KPack, false>{}.K0PerXdlops>{
Gemm1KPack *
XdlopsGemm<FloatAB, MPerXdl, NPerXdl, Gemm1KPack, FloatAB, false>{}.K0PerXdlops>{
// BMmaKStride
make_tuple(0, 0, 0, 0)}; // A_origin
......
......@@ -649,7 +649,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
Gemm1KPack,
true, // TransposeC
Gemm1KPack, // AMmaKStride
Gemm1KPack * XdlopsGemm<FloatAB, MPerXdl, NPerXdl, Gemm1KPack, false>{}.K0PerXdlops>{
Gemm1KPack *
XdlopsGemm<FloatAB, MPerXdl, NPerXdl, Gemm1KPack, FloatAB, false>{}.K0PerXdlops>{
// BMmaKStride
make_tuple(0, 0, 0, 0)}; // A_origin
......
......@@ -504,6 +504,7 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
FloatAB,
FloatAB,
FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
......
......@@ -470,6 +470,7 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
FloatAB,
FloatAB,
FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
......
......@@ -568,6 +568,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
ComputeDataType,
ComputeDataType,
AccDataType,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
......
......@@ -602,6 +602,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
ComputeType,
ComputeType,
AccDataType,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
......
......@@ -457,6 +457,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
FloatAB,
FloatAB,
FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
......
......@@ -588,6 +588,7 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
ABDataType,
ABDataType,
AccDataType,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
......@@ -1012,6 +1013,7 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
ABDataType,
ABDataType,
AccDataType,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
......
......@@ -108,7 +108,8 @@ template <typename ALayout,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched,
PipelineVersion PipelineVer = PipelineVersion::v1,
typename ComputeType = FloatC>
typename ComputeTypeA = FloatC,
typename ComputeTypeB = ComputeTypeA>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
{
static constexpr auto I0 = Number<0>{};
......@@ -547,8 +548,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
constexpr auto c_block_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
return math::max((a_block_space_size_aligned * sizeof(ComputeType) +
b_block_space_size_aligned * sizeof(ComputeType)),
return math::max((a_block_space_size_aligned * sizeof(ComputeTypeA) +
b_block_space_size_aligned * sizeof(ComputeTypeB)),
c_block_size * sizeof(FloatCShuffle));
}
......@@ -750,7 +751,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
FloatA,
ComputeType,
ComputeTypeA,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
......@@ -781,7 +782,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
FloatB,
ComputeType,
ComputeTypeB,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
......@@ -809,13 +810,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
constexpr index_t KPack =
math::max(math::lcm(AK1Number, BK1Number),
MfmaSelector<ComputeType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
constexpr index_t KPack = math::max(
math::lcm(AK1Number, BK1Number),
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB>::selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
ComputeType,
ComputeTypeA,
ComputeTypeB,
FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
......@@ -833,10 +835,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ComputeType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
static_cast<ComputeTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ComputeType*>(p_shared) + a_block_space_size_aligned,
static_cast<ComputeTypeB*>(p_shared) + a_block_space_size_aligned,
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
......
......@@ -495,6 +495,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
FloatAB,
FloatAB,
FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
......
......@@ -494,6 +494,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<
TileMathThreadGroupSize,
ABDataType,
ABDataType,
FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
......
......@@ -737,6 +737,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
auto blockwise_gemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatABAdjusted,
FloatABAdjusted,
FloatAcc,
decltype(a_k0_m_k1_block_desc),
......
......@@ -490,6 +490,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
auto blockwise_gemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB,
FloatAB,
FloatAcc,
decltype(a_block_desc_k0_m_k1),
......
......@@ -426,6 +426,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
using BlockwiseGemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatABAdjusted,
FloatABAdjusted,
FloatAcc,
decltype(a_block_desc_k0_m_k1),
......@@ -571,6 +572,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
FloatABAdjusted,
FloatABAdjusted,
FloatAcc,
decltype(a_block_desc_k0_m_k1),
decltype(b_block_desc_k0_n_k1),
......
......@@ -762,6 +762,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
ComputeType,
ComputeType,
FloatAcc,
decltype(a_k0_m_k1_block_desc),
decltype(b_k0_n_k1_block_desc),
......
......@@ -451,6 +451,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
auto blockwise_gemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB,
FloatAB,
FloatAcc,
decltype(a_block_desc_ak0_m_ak1),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment