Commit e10cbe9e authored by Alan Turner
Browse files

Add constexpr IsSupported

parent d9676215
...@@ -611,6 +611,99 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -611,6 +611,99 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
return true; return true;
} }
// Constexpr-evaluable check of the vector load/store requirements.
//
// For each of A, B, B1 and C, one raw length is selected by the tensor's layout
// and must be a multiple of that tensor's ScalarPerVector constant:
//   A  : RowMajor -> KRaw_,      ColumnMajor -> MRaw_
//   B  : RowMajor -> NRaw_,      ColumnMajor -> KRaw_
//   B1 : RowMajor -> Gemm1NRaw_, ColumnMajor -> NRaw_
//   C  : RowMajor -> Gemm1NRaw_, ColumnMajor -> MRaw_
//
// Returns false if any layout is neither RowMajor nor ColumnMajor, or if any of
// the four divisibility tests fails; returns true otherwise.
static constexpr bool IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_, index_t Gemm1NRaw_)
{
    using RowLayout = ck::tensor_layout::gemm::RowMajor;
    using ColLayout = ck::tensor_layout::gemm::ColumnMajor;

    constexpr bool a_is_row  = is_same_v<ALayout, RowLayout>;
    constexpr bool b_is_row  = is_same_v<BLayout, RowLayout>;
    constexpr bool b1_is_row = is_same_v<B1Layout, RowLayout>;
    constexpr bool c_is_row  = is_same_v<CLayout, RowLayout>;

    // Every tensor has to be either row-major or column-major.
    constexpr bool layouts_known = (a_is_row || is_same_v<ALayout, ColLayout>) &&
                                   (b_is_row || is_same_v<BLayout, ColLayout>) &&
                                   (b1_is_row || is_same_v<B1Layout, ColLayout>) &&
                                   (c_is_row || is_same_v<CLayout, ColLayout>);

    if constexpr(!layouts_known)
    {
        return false;
    }
    else
    {
        // Raw length checked against each tensor's vector width.
        // FIXME: not rigorous for the column-major selections
        const index_t a_len  = a_is_row ? KRaw_ : MRaw_;
        const index_t b_len  = b_is_row ? NRaw_ : KRaw_;
        const index_t b1_len = b1_is_row ? Gemm1NRaw_ : NRaw_;
        const index_t c_len  = c_is_row ? Gemm1NRaw_ : MRaw_;

        // Evaluated in A, B, B1, C order, same as the individual checks.
        return a_len % ABlockTransferSrcScalarPerVector == 0 &&
               b_len % BBlockTransferSrcScalarPerVector == 0 &&
               b1_len % B1BlockTransferSrcScalarPerVector == 0 &&
               c_len % CShuffleBlockTransferScalarPerVector_NPerBlock == 0;
    }
}
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(!ck::is_xdl_supported()) if(!ck::is_xdl_supported())
...@@ -625,52 +718,12 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -625,52 +718,12 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
const auto KRaw = arg.raw_lengths_m_n_k_o_[2]; const auto KRaw = arg.raw_lengths_m_n_k_o_[2];
const auto Gemm1NRaw = arg.raw_lengths_m_n_k_o_[3]; const auto Gemm1NRaw = arg.raw_lengths_m_n_k_o_[3];
// Check scalar per vector requirement
const auto a_extent_lowest =
is_same_v<tensor_layout::gemm::RowMajor, ALayout> ? KRaw : MRaw;
const auto b_extent_lowest =
is_same_v<tensor_layout::gemm::RowMajor, BLayout> ? NRaw : KRaw;
const auto b1_extent_lowest =
is_same_v<tensor_layout::gemm::RowMajor, B1Layout> ? Gemm1NRaw : NRaw;
const auto c_extent_lowest =
is_same_v<tensor_layout::gemm::RowMajor, CLayout> ? Gemm1NRaw : MRaw;
if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 &&
b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
{
return false;
}
return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
arg.b1_grid_desc_bk0_n_bk1_, arg.b1_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_, arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_); arg.block_2_ctile_map_) and
} IsSupported(MRaw, NRaw, KRaw, Gemm1NRaw);
static constexpr bool IsSupported(index_t MRaw, index_t NRaw, index_t KRaw, index_t Gemm1NRaw)
{
// Check scalar per vector requirement
const auto a_extent_lowest =
is_same_v<tensor_layout::gemm::RowMajor, ALayout> ? KRaw : MRaw;
const auto b_extent_lowest =
is_same_v<tensor_layout::gemm::RowMajor, BLayout> ? NRaw : KRaw;
const auto b1_extent_lowest =
is_same_v<tensor_layout::gemm::RowMajor, B1Layout> ? Gemm1NRaw : NRaw;
const auto c_extent_lowest =
is_same_v<tensor_layout::gemm::RowMajor, CLayout> ? Gemm1NRaw : MRaw;
if (!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 &&
b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
{
return false;
}
return true;
} }
// polymorphic // polymorphic
...@@ -861,7 +914,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -861,7 +914,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
remove_cvref_t<decltype(MakeCGridDescriptor_M_N(CDesc{}))>; remove_cvref_t<decltype(MakeCGridDescriptor_M_N(CDesc{}))>;
// GridwiseGemm // GridwiseGemm
using GridwiseGemmSpec = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle<
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
GemmAccDataType, GemmAccDataType,
CShuffleDataType, CShuffleDataType,
...@@ -928,8 +981,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -928,8 +981,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1; B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1;
CGridDesc_M_N c_grid_desc_m_n; CGridDesc_M_N c_grid_desc_m_n;
C0MatrixMask c0_matrix_mask; C0MatrixMask c0_matrix_mask;
typename GridwiseGemmSpec::DefaultBlock2CTileMap block_2_ctile_map; typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map;
typename GridwiseGemmSpec::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_descriptor_mblock_mperblock_nblock_nperblock; typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_descriptor_mblock_mperblock_nblock_nperblock;
// element-wise op // element-wise op
AElementwiseOperation a_element_op; AElementwiseOperation a_element_op;
...@@ -952,23 +1005,27 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -952,23 +1005,27 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
b_grid_desc_bk0_n_bk1{MakeBGridDescriptor_BK0_N_BK1(b)}, b_grid_desc_bk0_n_bk1{MakeBGridDescriptor_BK0_N_BK1(b)},
b1_grid_desc_bk0_n_bk1{MakeB1GridDescriptor_BK0_N_BK1(b1)}, b1_grid_desc_bk0_n_bk1{MakeB1GridDescriptor_BK0_N_BK1(b1)},
c_grid_desc_m_n{MakeCGridDescriptor_M_N(c)}, c_grid_desc_m_n{MakeCGridDescriptor_M_N(c)},
block_2_ctile_map{GridwiseGemmSpec::MakeDefaultBlock2CTileMap( block_2_ctile_map{GridwiseGemm::MakeDefaultBlock2CTileMap(
c_grid_desc_m_n)}, c_grid_desc_m_n)},
c_grid_descriptor_mblock_mperblock_nblock_nperblock{ c_grid_descriptor_mblock_mperblock_nblock_nperblock{
GridwiseGemmSpec::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n)}, GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n)},
has_main_k_block_loop{GridwiseGemmSpec::CalculateHasMainKBlockLoop( has_main_k_block_loop{GridwiseGemm::CalculateHasMainKBlockLoop(
a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2))}, a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2))},
c0_matrix_mask{c.GetLength(I1)}, c0_matrix_mask{c.GetLength(I1)},
a_element_op{a_element_op_}, a_element_op{a_element_op_},
b_element_op{b_element_op_}, b_element_op{b_element_op_},
b1_element_op{b1_element_op_}, b1_element_op{b1_element_op_},
c_element_op{c_element_op_}, c_element_op{c_element_op_},
is_valid{GridwiseGemmSpec::CheckValidity( is_valid{GridwiseGemm::CheckValidity(
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
b1_grid_desc_bk0_n_bk1, b1_grid_desc_bk0_n_bk1,
c_grid_desc_m_n, c_grid_desc_m_n,
block_2_ctile_map)} block_2_ctile_map) and
IsSupported(a_grid_desc_ak0_m_ak1.GetLength(I1),
b_grid_desc_bk0_n_bk1.GetLength(I1),
a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2),
b1_grid_desc_bk0_n_bk1.GetLength(I1))}
{ {
} }
...@@ -1001,17 +1058,13 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -1001,17 +1058,13 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
const ADataType* __restrict__ p_b1_grid, const ADataType* __restrict__ p_b1_grid,
CDataType* __restrict__ p_c_grid) CDataType* __restrict__ p_c_grid)
{ {
assert(desc.is_valid and assert(desc.is_valid);
IsSupported(desc.a_grid_desc_ak0_m_ak1.GetLength(I1), __shared__ char p_shared_block[Desc::GridwiseGemm::GetSharedMemoryNumberOfByte()];
desc.b_grid_desc_bk0_n_bk1.GetLength(I1),
desc.a_grid_desc_ak0_m_ak1.GetLength(I0) * desc.a_grid_desc_ak0_m_ak1.GetLength(I2),
desc.b1_grid_desc_bk0_n_bk1.GetLength(I1)));
__shared__ char p_shared_block[Desc::GridwiseGemmSpec::GetSharedMemoryNumberOfByte()];
AccElementwiseOperation acc_element_op{scale}; AccElementwiseOperation acc_element_op{scale};
if(desc.has_main_k_block_loop) if(desc.has_main_k_block_loop)
{ {
Desc::GridwiseGemmSpec::template Run<true>(p_a_grid, Desc::GridwiseGemm::template Run<true>(p_a_grid,
p_b_grid, p_b_grid,
p_b1_grid, p_b1_grid,
p_c_grid, p_c_grid,
...@@ -1030,7 +1083,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -1030,7 +1083,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
} }
else else
{ {
Desc::GridwiseGemmSpec::template Run<false>(p_a_grid, Desc::GridwiseGemm::template Run<false>(p_a_grid,
p_b_grid, p_b_grid,
p_b1_grid, p_b1_grid,
p_c_grid, p_c_grid,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment