Unverified Commit 3e178e7c authored by Muhammed Emin Ozturk, committed by GitHub

Merge branch 'develop' into muozturk_bf16fp8_streamk

parents 27fb084f 0e5e29c4
@@ -15,6 +15,7 @@ struct fused_moe_args
     const void* g_scale_ptr;        // [e, 1, n], gate(up) scale
     const void* d_scale_ptr;        // [e, 1, k], down scale
     const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input
+    const void* local_expert_mask_ptr; // [e], local_expert_mask_ptr for EP
     void* o_ptr;                    // [m, k], output token (no need to do zeroing)
     const void* topk_ids_ptr;       // [tokens, topk]
@@ -48,6 +49,8 @@ struct fused_moe_traits
     int activation;  // 0:gelu, 1:silu
     int gate_only;   // 0:g1u0, 1:g1u1
     int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
+    bool local_expert_masking; // whether to mask experts as local experts (EP)
 };

 float fused_moe(fused_moe_traits, fused_moe_args, const ck_tile::stream_config&);
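
With expert parallelism (EP), each rank owns only a slice of the experts, and the new `local_expert_mask_ptr`/`local_expert_masking` pair lets the kernels drop tokens routed to non-local experts. A minimal host-side sketch of how such a mask could be built — the contiguous partitioning and the helper name are illustrative assumptions, not part of this API:

```cpp
#include <cstdint>
#include <vector>

// Hypothetical helper: mark the experts owned by this EP rank, assuming the
// experts are partitioned contiguously across ranks. mask[e] == 1 keeps
// expert e; mask[e] == 0 tells moe-sorting to skip its tokens.
std::vector<int32_t> make_local_expert_mask(int num_experts, int ep_rank, int ep_size)
{
    std::vector<int32_t> mask(num_experts, 0);
    const int per_rank = num_experts / ep_size; // assume it divides evenly
    for(int e = ep_rank * per_rank; e < (ep_rank + 1) * per_rank; ++e)
        mask[e] = 1;
    return mask; // upload to device memory and pass as local_expert_mask_ptr
}
```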
@@ -11,6 +11,7 @@ struct fused_moesorting_trait
 {
     std::string index_type;
     std::string weight_type; // currently always float
+    bool local_expert_masking; // whether to mask experts as local experts (EP)
 };

 struct fused_moesorting_args : public ck_tile::MoeSortingHostArgs
......
@@ -17,10 +17,11 @@ float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_config& s)
         return 1;
     }();

-    auto t0 = fused_moesorting_trait{"int32", "fp32"};
+    auto t0 = fused_moesorting_trait{"int32", "fp32", t.local_expert_masking};
     auto a0 = fused_moesorting_args{
         a.topk_ids_ptr,          // const void* p_topk_ids;
         a.topk_weight_ptr,       // const void* p_weights;
+        a.local_expert_mask_ptr, // const void* p_local_expert_mask;
         a.sorted_token_ids_ptr,  // void* p_sorted_token_ids;
         a.sorted_weight_ptr,     // void* p_sorted_weights;
         a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids;
......
@@ -24,11 +24,16 @@
         return ave_time;
 #else

-#define MOE_SORTING_DISPATCH_(sub_token_tile_, sub_token_onshot_) \
+#define MOE_SORTING_DISPATCH_(sub_token_tile_, sub_token_onshot_, local_expert_masking_) \
     constexpr ck_tile::index_t sub_token_tile = sub_token_tile_; \
     constexpr bool sub_token_onshot = sub_token_onshot_; \
-    using ms_problem = \
-        ck_tile::MoeSortingProblemEx<index_t, ms_weight_type, sub_token_tile, sub_token_onshot>; \
+    constexpr bool local_expert_masking = local_expert_masking_; \
+    using ms_problem = ck_tile::MoeSortingProblemEx<index_t, \
+                                                    ms_weight_type, \
+                                                    sub_token_tile, \
+                                                    sub_token_onshot, \
+                                                    local_expert_masking>; \
     using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
     auto kargs = kernel::MakeKargs(a); \
     const dim3 grids = kernel::GridSize(a); \
@@ -38,6 +43,44 @@
         s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
     return ave_time;

+#define MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_) \
+    if(row_ % 8 == 0) \
+    { \
+        MOE_SORTING_DISPATCH_(8, sub_token_onshot_, local_expert_masking_); \
+    } \
+    else if(row_ % 4 == 0) \
+    { \
+        MOE_SORTING_DISPATCH_(4, sub_token_onshot_, local_expert_masking_); \
+    } \
+    else if(row_ % 2 == 0) \
+    { \
+        MOE_SORTING_DISPATCH_(2, sub_token_onshot_, local_expert_masking_); \
+    } \
+    else \
+    { \
+        MOE_SORTING_DISPATCH_(1, sub_token_onshot_, local_expert_masking_); \
+    }
+
+#define MOE_SORTING_DISPATCH_SUBTO_(row_, local_expert_masking_) \
+    if(is_sub_token_onshot) \
+    { \
+        MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, true, local_expert_masking_) \
+    } \
+    else \
+    { \
+        MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, false, local_expert_masking_) \
+    }
+
+#define MOE_SORTING_DISPATCH_EMASK_(row_) \
+    if(is_local_expert_masking) \
+    { \
+        MOE_SORTING_DISPATCH_SUBTO_(row_, true) \
+    } \
+    else \
+    { \
+        MOE_SORTING_DISPATCH_SUBTO_(row_, false) \
+    }
+
 #endif

 #if !MOE_SORTING_USE_EX_KERNEL
@@ -116,45 +159,10 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_tile::stream_config& s)
     auto sub_token_ = r_ - 2;
     r_ = (r_ - 2) / 8;
     bool is_sub_token_onshot = a.tokens <= sub_token_;
+    bool is_local_expert_masking = t.local_expert_masking;
     (void)c_;
-    if(is_sub_token_onshot)
-    {
-        if(r_ % 8 == 0)
-        {
-            MOE_SORTING_DISPATCH_(8, true);
-        }
-        else if(r_ % 4 == 0)
-        {
-            MOE_SORTING_DISPATCH_(4, true);
-        }
-        else if(r_ % 2 == 0)
-        {
-            MOE_SORTING_DISPATCH_(2, true);
-        }
-        else
-        {
-            MOE_SORTING_DISPATCH_(1, true);
-        }
-    }
-    else
-    {
-        if(r_ % 8 == 0)
-        {
-            MOE_SORTING_DISPATCH_(8, false);
-        }
-        else if(r_ % 4 == 0)
-        {
-            MOE_SORTING_DISPATCH_(4, false);
-        }
-        else if(r_ % 2 == 0)
-        {
-            MOE_SORTING_DISPATCH_(2, false);
-        }
-        else
-        {
-            MOE_SORTING_DISPATCH_(1, false);
-        }
-    }
+    MOE_SORTING_DISPATCH_EMASK_(r_);
     // MOE_SORTING_DISPATCH_ETILE(0, 0);
 #endif
 }
......
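
The `EMASK_` → `SUBTO_` → `SUB_TOKEN_` → `DISPATCH_` macro chain is the usual trick for lowering runtime flags into template parameters: each level branches on one runtime value and forwards a compile-time constant, so the call site above collapses to a single `MOE_SORTING_DISPATCH_EMASK_(r_)`. A macro-free sketch of the same pattern, with hypothetical names:

```cpp
// Stand-in for instantiating ck_tile::MoeSortingProblemEx and launching the
// kernel; all three knobs are template (compile-time) parameters here.
template <int SubTokenTile, bool OneShot, bool LocalExpertMasking>
void launch_moe_sorting()
{
    // using problem = ck_tile::MoeSortingProblemEx<...>; launch the kernel...
}

// Pick the largest tile that divides the row count, pinning one more flag
// per level until everything is a compile-time constant.
template <bool OneShot, bool Masking>
void dispatch_tile(int rows)
{
    if(rows % 8 == 0)
        launch_moe_sorting<8, OneShot, Masking>();
    else if(rows % 4 == 0)
        launch_moe_sorting<4, OneShot, Masking>();
    else if(rows % 2 == 0)
        launch_moe_sorting<2, OneShot, Masking>();
    else
        launch_moe_sorting<1, OneShot, Masking>();
}

void dispatch(int rows, bool one_shot, bool masking)
{
    if(one_shot && masking)
        dispatch_tile<true, true>(rows);
    else if(one_shot)
        dispatch_tile<true, false>(rows);
    else if(masking)
        dispatch_tile<false, true>(rows);
    else
        dispatch_tile<false, false>(rows);
}
```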
@@ -162,6 +162,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
     int tp        = arg_parser.get_int("tp");
     int init      = arg_parser.get_int("init");
     uint32_t seed = arg_parser.get_uint32("seed");
+    bool local_expert_masking = false; // TODO...

     // w0 (Gate+Up or Gate only, N size)
     ck_tile::index_t shared_intermediate_size_0 = intermediate_size * (gate_only ? 1 : 2) / tp;
@@ -230,6 +231,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
     ck_tile::HostTensor<YSmoothScaleDataType> sy_host({shared_intermediate_size_1}); // smooth-quant
     ck_tile::HostTensor<IndexDataType> topk_ids_host({tokens, topk});         // to be sorted
     ck_tile::HostTensor<TopkWeightDataType> topk_weight_host({tokens, topk}); // to be sorted
+    ck_tile::HostTensor<IndexDataType> local_expert_mask_host({experts});

     int max_num_tokens_padded = topk * tokens + experts * block_m - topk;
     ck_tile::HostTensor<IndexDataType> sorted_token_ids_host({max_num_tokens_padded});
@@ -355,6 +357,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
     ck_tile::DeviceMem sg_buf(sg_host);
     ck_tile::DeviceMem sd_buf(sd_host);
     ck_tile::DeviceMem sy_buf(sy_host);
+    ck_tile::DeviceMem local_expert_mask_buf(local_expert_mask_host);
     ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
     ck_tile::DeviceMem topk_ids_buf(topk_ids_host);
@@ -378,7 +381,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
                            block_m,
                            activation,
                            gate_only,
-                           fused_quant};
+                           fused_quant,
+                           local_expert_masking};

     fused_moe_args args{a_buf.GetDeviceBuffer(),
                         fused_quant != 0 ? sa_buf.GetDeviceBuffer() : nullptr,
@@ -387,6 +391,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
                         fused_quant != 0 ? sg_buf.GetDeviceBuffer() : nullptr,
                         fused_quant != 0 ? sd_buf.GetDeviceBuffer() : nullptr,
                         fused_quant == 1 ? sy_buf.GetDeviceBuffer() : nullptr,
+                        local_expert_masking ? local_expert_mask_buf.GetDeviceBuffer()
+                                             : nullptr,
                         o_buf.GetDeviceBuffer(),
                         topk_ids_buf.GetDeviceBuffer(),
                         topk_weight_buf.GetDeviceBuffer(),
@@ -442,12 +448,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
         ck_tile::reference_moe_sorting<TopkWeightDataType, IndexDataType>(
             topk_ids_host,
             topk_weight_host,
+            local_expert_mask_host,
             sorted_token_ids_host,
             sorted_weight_host,
             sorted_expert_ids_host,
             num_sorted_tiles_host.mData[0],
             experts,
-            block_m);
+            block_m,
+            local_expert_masking);
         if(activation == 0)
         {
             CPU_FUSED_MOE(ck_tile::element_wise::Gelu);
@@ -472,12 +480,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
         ck_tile::reference_moe_sorting<TopkWeightDataType, IndexDataType>(
             topk_ids_host,
             topk_weight_host,
+            local_expert_mask_host,
             sorted_token_ids_host,
             sorted_weight_host,
             sorted_expert_ids_host,
             num_sorted_tiles_host.mData[0],
             experts,
-            block_m);
+            block_m,
+            local_expert_masking);

     // done, preparing GPU buffer
     ck_tile::DeviceMem a_buf(a_host);
......
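
`reference_moe_sorting` now receives the mask tensor plus the flag, so the CPU reference reproduces the masked behavior the device kernel is verified against. A plausible sketch of the semantics implied by the argument list — names and details here are assumptions, not the library's implementation:

```cpp
#include <cstdint>
#include <vector>

// Hypothetical reference behavior: bucket token slots per expert, skipping
// experts whose mask entry is 0 when masking is enabled.
void reference_sorting_sketch(const std::vector<int32_t>& topk_ids, // [tokens * topk]
                              const std::vector<int32_t>& mask,     // [experts]
                              bool local_expert_masking,
                              std::vector<std::vector<int32_t>>& per_expert)
{
    for(std::size_t i = 0; i < topk_ids.size(); ++i)
    {
        const int32_t e = topk_ids[i];
        if(local_expert_masking && mask[e] == 0)
            continue; // routed to a remote expert; this rank emits nothing
        per_expert[e].push_back(static_cast<int32_t>(i));
    }
}
```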
@@ -610,6 +610,96 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
         return true;
     }

+    static constexpr bool
+    IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_, index_t Gemm1NRaw_)
+    {
+        // check vector load/store
+        using Row = ck::tensor_layout::gemm::RowMajor;
+        using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+        // check vector load of A
+        if constexpr(is_same_v<ALayout, Row>)
+        {
+            if(KRaw_ % ABlockTransferSrcScalarPerVector != 0)
+            {
+                return false;
+            }
+        }
+        else if constexpr(is_same_v<ALayout, Col>)
+        {
+            if(MRaw_ % ABlockTransferSrcScalarPerVector != 0)
+            {
+                return false;
+            }
+        }
+        else
+        {
+            return false;
+        }
+
+        // check vector load of B
+        if constexpr(is_same_v<BLayout, Row>)
+        {
+            if(NRaw_ % BBlockTransferSrcScalarPerVector != 0)
+            {
+                return false;
+            }
+        }
+        else if constexpr(is_same_v<BLayout, Col>)
+        {
+            if(KRaw_ % BBlockTransferSrcScalarPerVector != 0)
+            {
+                return false;
+            }
+        }
+        else
+        {
+            return false;
+        }
+
+        // check vector load of B1
+        if constexpr(is_same_v<B1Layout, Row>)
+        {
+            if(Gemm1NRaw_ % B1BlockTransferSrcScalarPerVector != 0)
+            {
+                return false;
+            }
+        }
+        else if constexpr(is_same_v<B1Layout, Col>)
+        {
+            if(NRaw_ % B1BlockTransferSrcScalarPerVector != 0)
+            {
+                return false;
+            }
+        }
+        else
+        {
+            return false;
+        }
+
+        // check vector store of C
+        if constexpr(is_same_v<CLayout, Row>)
+        {
+            if(Gemm1NRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
+            {
+                return false;
+            }
+        }
+        else if constexpr(is_same_v<CLayout, Col>)
+        {
+            if(MRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
+            {
+                return false;
+            }
+        }
+        else
+        {
+            return false;
+        }
+
+        return true;
+    }
     static bool IsSupportedArgument(const Argument& arg)
     {
         if(!ck::is_xdl_supported())
@@ -624,29 +714,12 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
         const auto KRaw      = arg.raw_lengths_m_n_k_o_[2];
         const auto Gemm1NRaw = arg.raw_lengths_m_n_k_o_[3];

-        // Check scalar per vector requirement
-        const auto a_extent_lowest =
-            is_same_v<tensor_layout::gemm::RowMajor, ALayout> ? KRaw : MRaw;
-        const auto b_extent_lowest =
-            is_same_v<tensor_layout::gemm::RowMajor, BLayout> ? NRaw : KRaw;
-        const auto b1_extent_lowest =
-            is_same_v<tensor_layout::gemm::RowMajor, B1Layout> ? Gemm1NRaw : NRaw;
-        const auto c_extent_lowest =
-            is_same_v<tensor_layout::gemm::RowMajor, CLayout> ? Gemm1NRaw : MRaw;
-
-        if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
-             b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 &&
-             b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
-             c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
-        {
-            return false;
-        }
-
         return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
                                            arg.b_grid_desc_bk0_n_bk1_,
                                            arg.b1_grid_desc_bk0_n_bk1_,
                                            arg.c_grid_desc_m_n_,
-                                           arg.block_2_ctile_map_);
+                                           arg.block_2_ctile_map_) and
+               IsSupported(MRaw, NRaw, KRaw, Gemm1NRaw);
     }

     // polymorphic
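
The refactor hoists the scalar-per-vector checks into a `static constexpr` `IsSupported`, so the same predicate can be evaluated at compile time by the new `Descriptor` below as well as at runtime in `IsSupportedArgument`. The rule being enforced: a vectorized transfer reads or writes `ScalarPerVector` elements along the tensor's contiguous dimension (K for row-major A, M for column-major A, and so on), so that extent must be divisible by the vector width. A standalone sketch of the rule for a single operand:

```cpp
// Sketch: a row-major [M, K] operand is contiguous along K, a column-major
// one along M; vectorized access needs that extent to divide evenly.
constexpr bool vector_access_ok(bool row_major, int M, int K, int scalar_per_vector)
{
    const int contiguous_extent = row_major ? K : M;
    return contiguous_extent % scalar_per_vector == 0;
}

static_assert(vector_access_ok(true, 128, 64, 8), "K = 64 is a multiple of 8");
static_assert(!vector_access_ok(true, 128, 60, 8), "K = 60 is not");
```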
@@ -764,6 +837,268 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
         return str.str();
     }

+    template <class ADesc, class BDesc, class B1Desc, class CDesc>
+    struct Descriptor
+    {
+        template <class AGridDescriptor>
+        static constexpr auto MakeAGridDescriptor_AK0_M_AK1(const AGridDescriptor& a_grid_desc)
+        {
+            const auto a_grid_desc_m_k = DeviceOp::matrix_padder.PadADescriptor_M_K(a_grid_desc);
+            const auto M   = a_grid_desc_m_k.GetLength(I0);
+            const auto K   = a_grid_desc_m_k.GetLength(I1);
+            const auto AK0 = K / AK1;
+
+            return transform_tensor_descriptor(
+                a_grid_desc_m_k,
+                make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
+                           make_pass_through_transform(M)),
+                make_tuple(Sequence<1>{}, Sequence<0>{}),
+                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+        }
+
+        template <class BGridDescriptor>
+        static constexpr auto MakeBGridDescriptor_BK0_N_BK1(const BGridDescriptor& b_grid_desc)
+        {
+            const auto b_grid_desc_n_k = DeviceOp::matrix_padder.PadBDescriptor_N_K(b_grid_desc);
+            const auto N   = b_grid_desc_n_k.GetLength(I0);
+            const auto K   = b_grid_desc_n_k.GetLength(I1);
+            const auto BK0 = K / BK1;
+
+            return transform_tensor_descriptor(
+                b_grid_desc_n_k,
+                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
+                           make_pass_through_transform(N)),
+                make_tuple(Sequence<1>{}, Sequence<0>{}),
+                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+        }
+
+        template <class B1GridDescriptor>
+        static constexpr auto MakeB1GridDescriptor_BK0_N_BK1(const B1GridDescriptor& b1_grid_desc)
+        {
+            const auto b1_grid_desc_n_k = DeviceOp::matrix_padder.PadB1Descriptor_N_K(b1_grid_desc);
+            const auto N    = b1_grid_desc_n_k.GetLength(I0);
+            const auto K    = b1_grid_desc_n_k.GetLength(I1);
+            const auto B1K0 = K / B1K1;
+
+            return transform_tensor_descriptor(
+                b1_grid_desc_n_k,
+                make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
+                           make_pass_through_transform(N)),
+                make_tuple(Sequence<1>{}, Sequence<0>{}),
+                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+        }
+
+        template <class CGridDescriptor>
+        static constexpr auto MakeCGridDescriptor_M_N(const CGridDescriptor& c_grid_desc)
+        {
+            return DeviceOp::matrix_padder.PadCDescriptor_M_N(c_grid_desc);
+        }
+
+        using AGridDesc_AK0_M_AK1 =
+            remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1(ADesc{}))>;
+        using BGridDesc_BK0_N_BK1 =
+            remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1(BDesc{}))>;
+        using B1GridDesc_BK0_N_BK1 =
+            remove_cvref_t<decltype(MakeB1GridDescriptor_BK0_N_BK1(B1Desc{}))>;
+        using CGridDesc_M_N = remove_cvref_t<decltype(MakeCGridDescriptor_M_N(CDesc{}))>;
+
+        // GridwiseGemm
+        using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle<
+            ADataType, // TODO: distinguish A/B datatype
+            GemmAccDataType,
+            CShuffleDataType,
+            CDataType,
+            AElementwiseOperation,
+            BElementwiseOperation,
+            AccElementwiseOperation,
+            B1ElementwiseOperation,
+            CElementwiseOperation,
+            InMemoryDataOperationEnum::Set,
+            AGridDesc_AK0_M_AK1,
+            BGridDesc_BK0_N_BK1,
+            B1GridDesc_BK0_N_BK1,
+            CGridDesc_M_N,
+            NumGemmKPrefetchStage,
+            BlockSize,
+            MPerBlock,
+            NPerBlock,
+            KPerBlock,
+            Gemm1NPerBlock,
+            Gemm1KPerBlock,
+            AK1,
+            BK1,
+            B1K1,
+            MPerXDL,
+            NPerXDL,
+            MXdlPerWave,
+            NXdlPerWave,
+            Gemm1NXdlPerWave,
+            ABlockTransferThreadClusterLengths_AK0_M_AK1,
+            ABlockTransferThreadClusterArrangeOrder,
+            ABlockTransferSrcAccessOrder,
+            ABlockTransferSrcVectorDim,
+            ABlockTransferSrcScalarPerVector,
+            ABlockTransferDstScalarPerVector_AK1,
+            true,
+            ABlockLdsExtraM,
+            BBlockTransferThreadClusterLengths_BK0_N_BK1,
+            BBlockTransferThreadClusterArrangeOrder,
+            BBlockTransferSrcAccessOrder,
+            BBlockTransferSrcVectorDim,
+            BBlockTransferSrcScalarPerVector,
+            BBlockTransferDstScalarPerVector_BK1,
+            true,
+            BBlockLdsExtraN,
+            B1BlockTransferThreadClusterLengths_BK0_N_BK1,
+            B1BlockTransferThreadClusterArrangeOrder,
+            B1BlockTransferSrcAccessOrder,
+            B1BlockTransferSrcVectorDim,
+            B1BlockTransferSrcScalarPerVector,
+            B1BlockTransferDstScalarPerVector_BK1,
+            false,
+            B1BlockLdsExtraN,
+            CShuffleMXdlPerWavePerShuffle,
+            CShuffleNXdlPerWavePerShuffle,
+            CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
+            CShuffleBlockTransferScalarPerVector_NPerBlock,
+            LoopSched,
+            matrix_padder.PadN,
+            MaskOutUpperTriangle>;
+
+        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1;
+        BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1;
+        B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1;
+        CGridDesc_M_N c_grid_desc_m_n;
+        C0MatrixMask c0_matrix_mask;
+        typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map;
+        typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+            c_grid_descriptor_mblock_mperblock_nblock_nperblock;
+
+        // element-wise op
+        AElementwiseOperation a_element_op;
+        BElementwiseOperation b_element_op;
+        B1ElementwiseOperation b1_element_op;
+        CElementwiseOperation c_element_op;
+
+        bool has_main_k_block_loop = true;
+        bool is_valid              = false;
+
+        constexpr Descriptor(ADesc a,
+                             BDesc b,
+                             B1Desc b1,
+                             CDesc c,
+                             AElementwiseOperation a_element_op_,
+                             BElementwiseOperation b_element_op_,
+                             B1ElementwiseOperation b1_element_op_,
+                             CElementwiseOperation c_element_op_)
+            : a_grid_desc_ak0_m_ak1{MakeAGridDescriptor_AK0_M_AK1(a)},
+              b_grid_desc_bk0_n_bk1{MakeBGridDescriptor_BK0_N_BK1(b)},
+              b1_grid_desc_bk0_n_bk1{MakeB1GridDescriptor_BK0_N_BK1(b1)},
+              c_grid_desc_m_n{MakeCGridDescriptor_M_N(c)},
+              block_2_ctile_map{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n)},
+              c_grid_descriptor_mblock_mperblock_nblock_nperblock{
+                  GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                      c_grid_desc_m_n)},
+              has_main_k_block_loop{GridwiseGemm::CalculateHasMainKBlockLoop(
+                  a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2))},
+              c0_matrix_mask{c.GetLength(I1)},
+              a_element_op{a_element_op_},
+              b_element_op{b_element_op_},
+              b1_element_op{b1_element_op_},
+              c_element_op{c_element_op_},
+              is_valid{GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1,
+                                                   b_grid_desc_bk0_n_bk1,
+                                                   b1_grid_desc_bk0_n_bk1,
+                                                   c_grid_desc_m_n,
+                                                   block_2_ctile_map) and
+                       IsSupported(a_grid_desc_ak0_m_ak1.GetLength(I1),
+                                   b_grid_desc_bk0_n_bk1.GetLength(I1),
+                                   a_grid_desc_ak0_m_ak1.GetLength(I0) *
+                                       a_grid_desc_ak0_m_ak1.GetLength(I2),
+                                   b1_grid_desc_bk0_n_bk1.GetLength(I1))}
+        {
+        }
+
+        constexpr bool IsValid() const { return is_valid; }
+    };
+
+    template <class ADesc, class BDesc, class B1Desc, class CDesc>
+    static constexpr auto
+    make_descriptor(ADesc a,
+                    BDesc b,
+                    B1Desc b1,
+                    CDesc c,
+                    AElementwiseOperation a_element_op   = AElementwiseOperation{},
+                    BElementwiseOperation b_element_op   = BElementwiseOperation{},
+                    B1ElementwiseOperation b1_element_op = B1ElementwiseOperation{},
+                    CElementwiseOperation c_element_op   = CElementwiseOperation{})
+    {
+        return Descriptor<ADesc, BDesc, B1Desc, CDesc>(
+            a, b, b1, c, a_element_op, b_element_op, b1_element_op, c_element_op);
+    }
+
+    template <class Desc>
+    __device__ static void Run(const Desc& desc,
+                               const float scale,
+                               const ADataType* __restrict__ p_a_grid,
+                               const ADataType* __restrict__ p_b_grid,
+                               const ADataType* __restrict__ p_b1_grid,
+                               CDataType* __restrict__ p_c_grid)
+    {
+#ifndef __HIPCC_RTC__
+        assert(desc.is_valid);
+#endif
+        __shared__ char p_shared_block[Desc::GridwiseGemm::GetSharedMemoryNumberOfByte()];
+
+        AccElementwiseOperation acc_element_op{scale};
+
+        if(desc.has_main_k_block_loop)
+        {
+            Desc::GridwiseGemm::template Run<true>(
+                p_a_grid,
+                p_b_grid,
+                p_b1_grid,
+                p_c_grid,
+                p_shared_block,
+                desc.a_element_op,
+                desc.b_element_op,
+                acc_element_op,
+                desc.b1_element_op,
+                desc.c_element_op,
+                desc.a_grid_desc_ak0_m_ak1,
+                desc.b_grid_desc_bk0_n_bk1,
+                desc.b1_grid_desc_bk0_n_bk1,
+                desc.c_grid_descriptor_mblock_mperblock_nblock_nperblock,
+                desc.block_2_ctile_map,
+                desc.c0_matrix_mask);
+        }
+        else
+        {
+            Desc::GridwiseGemm::template Run<false>(
+                p_a_grid,
+                p_b_grid,
+                p_b1_grid,
+                p_c_grid,
+                p_shared_block,
+                desc.a_element_op,
+                desc.b_element_op,
+                acc_element_op,
+                desc.b1_element_op,
+                desc.c_element_op,
+                desc.a_grid_desc_ak0_m_ak1,
+                desc.b_grid_desc_bk0_n_bk1,
+                desc.b1_grid_desc_bk0_n_bk1,
+                desc.c_grid_descriptor_mblock_mperblock_nblock_nperblock,
+                desc.block_2_ctile_map,
+                desc.c0_matrix_mask);
+        }
+    }
 };

 } // namespace device
......
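
The `Descriptor` added above packages the transformed grid descriptors, element-wise ops, and a precomputed `is_valid` flag into one constexpr-constructible object with a static `__device__ Run`, which suits runtime-compilation (`__HIPCC_RTC__`) flows where the host validates once and the kernel just consumes the result. A toy sketch of that construct-and-validate-once pattern (types and checks are invented for illustration):

```cpp
// Toy descriptor: validity is computed once in the constexpr constructor,
// can be rejected at compile time, and device code simply trusts it.
struct ToyDescriptor
{
    int m, n, k;
    bool is_valid;

    constexpr ToyDescriptor(int m_, int n_, int k_)
        : m{m_},
          n{n_},
          k{k_},
          is_valid{m_ % 8 == 0 && n_ % 8 == 0 && k_ % 8 == 0} // stand-in checks
    {
    }
    constexpr bool IsValid() const { return is_valid; }
};

constexpr ToyDescriptor desc{128, 256, 64};
static_assert(desc.IsValid(), "unsupported shapes are rejected before launch");
```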
@@ -1495,10 +1495,13 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
         // if workspace is not allocated
         if(!arg.p_workspace_)
         {
-            std::cerr << "Warning: Workspace for "
-                         "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle::Argument is not "
-                         "allocated, use SetWorkSpacePointer."
-                      << std::endl;
+            if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
+            {
+                std::cout << "Warning: Workspace for "
+                             "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle::Argument is not "
+                             "allocated, use SetWorkSpacePointer."
+                          << std::endl;
+            }
             return false;
         }
         if(!ck::is_xdl_supported())
......
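
Gating the workspace warning behind CK's logging switch keeps instance-selection sweeps quiet, since `IsSupportedArgument` is routinely probed on configurations that are expected to fail. The same opt-in pattern, sketched with a plain `getenv` check (CK's actual `CK_ENV`/`EnvIsEnabled` machinery differs):

```cpp
#include <cstdlib>
#include <iostream>

// Minimal stand-in for env-gated diagnostics: print only when the user
// opts in, e.g. by exporting CK_LOGGING=1.
inline bool logging_enabled()
{
    const char* v = std::getenv("CK_LOGGING");
    return v != nullptr && v[0] == '1';
}

inline bool check_workspace(void* p_workspace)
{
    if(p_workspace == nullptr)
    {
        if(logging_enabled())
            std::cout << "Warning: workspace not allocated, use SetWorkSpacePointer.\n";
        return false;
    }
    return true;
}
```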
@@ -515,9 +515,16 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
        // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
        // register

        // sanity check
+       constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
+       constexpr bool is_single_rate_mfma =
+           ((is_same<ABDataType, half_t>::value || is_same<ABDataType, bhalf_t>::value) &&
+            lcm_AK1_BK1 <= 4)
+               ? true
+               : false;
        constexpr index_t KPack =
-           math::max(math::lcm(AK1, BK1),
-                     MfmaSelector<ABDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
+           math::max(lcm_AK1_BK1,
+                     MfmaSelector<ABDataType, MPerXdl, NPerXdl, ABDataType, is_single_rate_mfma>::
+                         selected_mfma.k_per_blk);

        auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
            BlockSize,
......
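
This hunk and the remaining file changes below are the same fix applied kernel by kernel: for fp16/bf16 inputs, when `lcm(AK1, BK1) <= 4` each thread holds too narrow a K-slice to feed a double-rate MFMA, so the fifth `MfmaSelector` parameter forces the single-rate instruction and `KPack` is derived from that choice. A standalone sketch of the selection, with illustrative `k_per_blk` values (a double-rate MFMA consumes twice the K of its single-rate counterpart):

```cpp
#include <algorithm>
#include <numeric>

// Assumed per-instruction K depth: 8 for a single-rate MFMA, 16 for the
// double-rate variant. Real values depend on the chosen MFMA shape.
constexpr int k_per_blk(bool single_rate) { return single_rate ? 8 : 16; }

constexpr int select_kpack(int ak1, int bk1, bool is_fp16_or_bf16)
{
    const int lcm_k1 = std::lcm(ak1, bk1);
    // Narrow K-slices cannot feed the double-rate instruction, so fall back
    // to single-rate for fp16/bf16 when lcm(AK1, BK1) <= 4.
    const bool single_rate = is_fp16_or_bf16 && lcm_k1 <= 4;
    return std::max(lcm_k1, k_per_blk(single_rate));
}

static_assert(select_kpack(4, 4, true) == 8, "forced single-rate");
static_assert(select_kpack(8, 8, true) == 16, "double-rate allowed");
```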
@@ -448,8 +448,16 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
        // acc1[m][o] += acc[m][n] * B1[n][o]

        // sanity check
+       constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
+       constexpr bool is_single_rate_mfma =
+           ((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
+            lcm_AK1_BK1 <= 4)
+               ? true
+               : false;
        constexpr index_t KPack = math::max(
-           math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
+           lcm_AK1_BK1,
+           MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma>::selected_mfma
+               .k_per_blk);

        auto blockwise_gemm = BlockwiseGemmXdlops_v2<
            BlockSize,
......
@@ -361,8 +361,16 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
        const auto M = d0_grid_desc_m_n.GetLength(I0);
        const auto N = d0_grid_desc_m_n.GetLength(I1);

-       constexpr auto mfma =
-           MfmaSelector<A0B0B1DataType, Gemm0MPerXdl, Gemm0NPerXdl>::selected_mfma;
+       constexpr bool is_single_rate_mfma =
+           ((is_same<A0B0B1DataType, half_t>::value || is_same<A0B0B1DataType, bhalf_t>::value) &&
+            math::lcm(A0K1, B0K1) <= 4)
+               ? true
+               : false;
+       constexpr auto mfma = MfmaSelector<A0B0B1DataType,
+                                          Gemm0MPerXdl,
+                                          Gemm0NPerXdl,
+                                          A0B0B1DataType,
+                                          is_single_rate_mfma>::selected_mfma;
        constexpr auto N3 = mfma.num_groups_per_blk;
        constexpr auto N5 = mfma.group_size;

        return transform_tensor_descriptor(
@@ -643,9 +651,19 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
        // acc1[m][o] += acc[m][n] * B1[n][o]

        // sanity check
-       constexpr index_t KPack = math::max(
-           math::lcm(A0K1, B0K1),
-           MfmaSelector<A0B0B1DataType, Gemm0MPerXdl, Gemm0NPerXdl>::selected_mfma.k_per_blk);
+       constexpr auto lcm_A0K1_B0K1 = math::lcm(A0K1, B0K1);
+       constexpr bool is_single_rate_mfma =
+           ((is_same<A0B0B1DataType, half_t>::value || is_same<A0B0B1DataType, bhalf_t>::value) &&
+            lcm_A0K1_B0K1 <= 4)
+               ? true
+               : false;
+       constexpr index_t KPack =
+           math::max(lcm_A0K1_B0K1,
+                     MfmaSelector<A0B0B1DataType,
+                                  Gemm0MPerXdl,
+                                  Gemm0NPerXdl,
+                                  A0B0B1DataType,
+                                  is_single_rate_mfma>::selected_mfma.k_per_blk);

        auto blockwise_gemm0 = BlockwiseGemmXdlops_v2<
            BlockSize,
......
@@ -343,7 +343,13 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle
        const auto M = d0_grid_desc_m_n.GetLength(I0);
        const auto N = d0_grid_desc_m_n.GetLength(I1);

-       constexpr auto mfma = MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma;
+       constexpr bool is_single_rate_mfma =
+           ((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
+            math::lcm(AK1, BK1) <= 4)
+               ? true
+               : false;
+       constexpr auto mfma =
+           MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma>::selected_mfma;
        constexpr auto N3 = mfma.num_groups_per_blk;
        constexpr auto N4 = mfma.num_input_blks;
        constexpr auto N5 = mfma.group_size;
@@ -552,8 +558,16 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle
        // acc1[m][o] += acc[m][n] * B1[n][o]

        // sanity check
+       constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
+       constexpr bool is_single_rate_mfma =
+           ((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
+            lcm_AK1_BK1 <= 4)
+               ? true
+               : false;
        constexpr index_t KPack = math::max(
-           math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
+           lcm_AK1_BK1,
+           MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma>::selected_mfma
+               .k_per_blk);

        auto blockwise_gemm = BlockwiseGemmXdlops_v2<
            BlockSize,
......
@@ -469,8 +469,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
        // acc1[m][o] += acc[m][n] * B1[n][o]

        // sanity check
+       constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
+       constexpr bool is_single_rate_mfma =
+           ((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
+            lcm_AK1_BK1 <= 4)
+               ? true
+               : false;
        constexpr index_t KPack = math::max(
-           math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
+           lcm_AK1_BK1,
+           MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma>::selected_mfma
+               .k_per_blk);

        auto blockwise_gemm = BlockwiseGemmXdlops_v2<
            BlockSize,
......
@@ -498,8 +498,16 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
        // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
        // register

        // sanity check
+       constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
+       constexpr bool is_single_rate_mfma =
+           ((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
+            lcm_AK1_BK1 <= 4)
+               ? true
+               : false;
        constexpr index_t KPack = math::max(
-           math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
+           lcm_AK1_BK1,
+           MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma>::selected_mfma
+               .k_per_blk);

        auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
            BlockSize,
......
@@ -464,8 +464,16 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
        // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
        // register

        // sanity check
+       constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
+       constexpr bool is_single_rate_mfma =
+           ((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
+            lcm_AK1_BK1 <= 4)
+               ? true
+               : false;
        constexpr index_t KPack = math::max(
-           math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
+           lcm_AK1_BK1,
+           MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma>::selected_mfma
+               .k_per_blk);

        auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
            BlockSize,
......
@@ -599,9 +599,16 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
        // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
        // register

        // sanity check
-       constexpr index_t KPack =
-           math::max(math::lcm(AK1, BK1),
-                     MfmaSelector<AComputeType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
+       constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
+       constexpr bool is_single_rate_mfma =
+           ((is_same<AComputeType, half_t>::value || is_same<AComputeType, bhalf_t>::value) &&
+            lcm_AK1_BK1 <= 4)
+               ? true
+               : false;
+       constexpr index_t KPack = math::max(
+           lcm_AK1_BK1,
+           MfmaSelector<AComputeType, MPerXdl, NPerXdl, AComputeType, is_single_rate_mfma>::
+               selected_mfma.k_per_blk);

        auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
            BlockSize,
......
@@ -451,8 +451,16 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
        // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
        // register

        // sanity check
+       constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
+       constexpr bool is_single_rate_mfma =
+           ((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
+            lcm_AK1_BK1 <= 4)
+               ? true
+               : false;
        constexpr index_t KPack = math::max(
-           math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
+           lcm_AK1_BK1,
+           MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma>::selected_mfma
+               .k_per_blk);

        auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
            BlockSize,
......
@@ -581,9 +581,16 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
        // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
        // register

        // sanity check
+       constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
+       constexpr bool is_single_rate_mfma =
+           ((is_same<ABDataType, half_t>::value || is_same<ABDataType, bhalf_t>::value) &&
+            lcm_AK1_BK1 <= 4)
+               ? true
+               : false;
        constexpr index_t KPack =
-           math::max(math::lcm(AK1, BK1),
-                     MfmaSelector<ABDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
+           math::max(lcm_AK1_BK1,
+                     MfmaSelector<ABDataType, MPerXdl, NPerXdl, ABDataType, is_single_rate_mfma>::
+                         selected_mfma.k_per_blk);

        auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
            BlockSize,
@@ -1006,9 +1013,16 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
        // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
        // register

        // sanity check
+       constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
+       constexpr bool is_single_rate_mfma =
+           ((is_same<ABDataType, half_t>::value || is_same<ABDataType, bhalf_t>::value) &&
+            lcm_AK1_BK1 <= 4)
+               ? true
+               : false;
        constexpr index_t KPack =
-           math::max(math::lcm(AK1, BK1),
-                     MfmaSelector<ABDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
+           math::max(lcm_AK1_BK1,
+                     MfmaSelector<ABDataType, MPerXdl, NPerXdl, ABDataType, is_single_rate_mfma>::
+                         selected_mfma.k_per_blk);

        auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
            BlockSize,
......
@@ -595,9 +595,16 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
        // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
        // register

        // sanity check
-       constexpr index_t KPack =
-           math::max(math::lcm(AK1, BK1),
-                     MfmaSelector<ComputeType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
+       constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
+       constexpr bool is_single_rate_mfma =
+           ((is_same<ComputeType, half_t>::value || is_same<ComputeType, bhalf_t>::value) &&
+            lcm_AK1_BK1 <= 4)
+               ? true
+               : false;
+       constexpr index_t KPack = math::max(
+           lcm_AK1_BK1,
+           MfmaSelector<ComputeType, MPerXdl, NPerXdl, ComputeType, is_single_rate_mfma>::
+               selected_mfma.k_per_blk);

        auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
            BlockSize,
......
@@ -79,9 +79,16 @@ struct GridwiseGemm_xdl_cshuffle_v3
    static constexpr auto AK1Number = Number<AK1Value>{};
    static constexpr auto BK1Number = Number<BK1Value>{};

+   static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
+   static constexpr bool is_single_rate_mfma =
+       ((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
+        lcm_AK1_BK1 <= 4)
+           ? true
+           : false;
    static constexpr index_t KPack =
-       math::max(math::lcm(AK1Number, BK1Number),
-                 MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
+       math::max(lcm_AK1_BK1,
+                 MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeA, is_single_rate_mfma>::
+                     selected_mfma.k_per_blk);

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;
......
@@ -139,9 +139,16 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
    static constexpr auto AK1Number = Number<AK1Value>{};
    static constexpr auto BK1Number = Number<BK1Value>{};

+   static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
+   static constexpr bool is_single_rate_mfma =
+       ((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
+        lcm_AK1_BK1 <= 4)
+           ? true
+           : false;
    static constexpr index_t KPack =
-       math::max(math::lcm(AK1Number, BK1Number),
-                 MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
+       math::max(lcm_AK1_BK1,
+                 MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeA, is_single_rate_mfma>::
+                     selected_mfma.k_per_blk);

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;

    __host__ static auto CalculateMPadded(index_t M)
......