Commit e10cbe9e authored by Alan Turner

Add constexpr IsSupported

parent d9676215
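This change makes the shape-legality check a constexpr member, so the divisibility test that IsSupportedArgument runs when an argument is created can also be evaluated at compile time whenever the raw extents are known constants. The following is a minimal, self-contained sketch of that pattern, with toy type names and vector widths chosen purely for illustration (not the library's real template parameters):

#include <cstdint>
#include <type_traits>

using index_t = std::int32_t;

struct RowMajor {};
struct ColumnMajor {};

// Toy stand-in for one device-op instantiation: the A layout and the A-side
// vector width are the only knobs modelled here.
template <typename ALayout, index_t ABlockTransferSrcScalarPerVector>
struct ToyDeviceOp
{
    // Vectorized loads read along the contiguous dimension, so that extent
    // must divide evenly by the vector width; otherwise the configuration is
    // rejected, mirroring the checks added in this commit.
    static constexpr bool IsSupported(index_t MRaw, index_t KRaw)
    {
        if constexpr(std::is_same_v<ALayout, RowMajor>)
        {
            return KRaw % ABlockTransferSrcScalarPerVector == 0;
        }
        else
        {
            return MRaw % ABlockTransferSrcScalarPerVector == 0;
        }
    }
};

// Because the predicate is constexpr, unsupported shapes can be filtered at
// compile time, e.g. in a static_assert or when pruning a list of instances.
static_assert(ToyDeviceOp<RowMajor, 8>::IsSupported(256, 64));
static_assert(!ToyDeviceOp<RowMajor, 8>::IsSupported(256, 63));

At run time the device op still calls the same predicate from IsSupportedArgument (and the kernel asserts desc.is_valid), so the constexpr form only adds an earlier opportunity to reject bad configurations.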
@@ -611,6 +611,99 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
        return true;
    }

    static constexpr bool IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_, index_t Gemm1NRaw_)
    {
        // check vector load/store
        using Row = ck::tensor_layout::gemm::RowMajor;
        using Col = ck::tensor_layout::gemm::ColumnMajor;

        // check vector load of A
        if constexpr(is_same_v<ALayout, Row>)
        {
            if(KRaw_ % ABlockTransferSrcScalarPerVector != 0)
            {
                return false;
            }
        }
        else if constexpr(is_same_v<ALayout, Col>)
        {
            // FIXME: not rigorous
            if(MRaw_ % ABlockTransferSrcScalarPerVector != 0)
            {
                return false;
            }
        }
        else
        {
            return false;
        }

        // check vector load of B
        if constexpr(is_same_v<BLayout, Row>)
        {
            if(NRaw_ % BBlockTransferSrcScalarPerVector != 0)
            {
                return false;
            }
        }
        else if constexpr(is_same_v<BLayout, Col>)
        {
            // FIXME: not rigorous
            if(KRaw_ % BBlockTransferSrcScalarPerVector != 0)
            {
                return false;
            }
        }
        else
        {
            return false;
        }

        // check vector load of B1
        if constexpr(is_same_v<B1Layout, Row>)
        {
            if(Gemm1NRaw_ % B1BlockTransferSrcScalarPerVector != 0)
            {
                return false;
            }
        }
        else if constexpr(is_same_v<B1Layout, Col>)
        {
            // FIXME: not rigorous
            if(NRaw_ % B1BlockTransferSrcScalarPerVector != 0)
            {
                return false;
            }
        }
        else
        {
            return false;
        }

        // check vector load of C
        if constexpr(is_same_v<CLayout, Row>)
        {
            if(Gemm1NRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
            {
                return false;
            }
        }
        else if constexpr(is_same_v<CLayout, Col>)
        {
            // FIXME: not rigorous
            if(MRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
            {
                return false;
            }
        }
        else
        {
            return false;
        }

        return true;
    }

    static bool IsSupportedArgument(const Argument& arg)
    {
        if(!ck::is_xdl_supported())
@@ -625,52 +718,12 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
        const auto KRaw = arg.raw_lengths_m_n_k_o_[2];
        const auto Gemm1NRaw = arg.raw_lengths_m_n_k_o_[3];

        // Check scalar per vector requirement
        const auto a_extent_lowest =
            is_same_v<tensor_layout::gemm::RowMajor, ALayout> ? KRaw : MRaw;
        const auto b_extent_lowest =
            is_same_v<tensor_layout::gemm::RowMajor, BLayout> ? NRaw : KRaw;
        const auto b1_extent_lowest =
            is_same_v<tensor_layout::gemm::RowMajor, B1Layout> ? Gemm1NRaw : NRaw;
        const auto c_extent_lowest =
            is_same_v<tensor_layout::gemm::RowMajor, CLayout> ? Gemm1NRaw : MRaw;

        if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
             b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 &&
             b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
             c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
        {
            return false;
        }

        return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
                                           arg.b_grid_desc_bk0_n_bk1_,
                                           arg.b1_grid_desc_bk0_n_bk1_,
                                           arg.c_grid_desc_m_n_,
                                           arg.block_2_ctile_map_);
    }

    static constexpr bool IsSupported(index_t MRaw, index_t NRaw, index_t KRaw, index_t Gemm1NRaw)
    {
        // Check scalar per vector requirement
        const auto a_extent_lowest =
            is_same_v<tensor_layout::gemm::RowMajor, ALayout> ? KRaw : MRaw;
        const auto b_extent_lowest =
            is_same_v<tensor_layout::gemm::RowMajor, BLayout> ? NRaw : KRaw;
        const auto b1_extent_lowest =
            is_same_v<tensor_layout::gemm::RowMajor, B1Layout> ? Gemm1NRaw : NRaw;
        const auto c_extent_lowest =
            is_same_v<tensor_layout::gemm::RowMajor, CLayout> ? Gemm1NRaw : MRaw;

        if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
             b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 &&
             b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
             c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
        {
            return false;
        }

        return true;
                                           arg.block_2_ctile_map_) and
               IsSupported(MRaw, NRaw, KRaw, Gemm1NRaw);
    }

    // polymorphic
@@ -861,7 +914,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
        remove_cvref_t<decltype(MakeCGridDescriptor_M_N(CDesc{}))>;

    // GridwiseGemm
    using GridwiseGemmSpec = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle<
    using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle<
        ADataType, // TODO: distinguish A/B datatype
        GemmAccDataType,
        CShuffleDataType,
@@ -928,8 +981,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
        B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1;
        CGridDesc_M_N c_grid_desc_m_n;
        C0MatrixMask c0_matrix_mask;
        typename GridwiseGemmSpec::DefaultBlock2CTileMap block_2_ctile_map;
        typename GridwiseGemmSpec::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_descriptor_mblock_mperblock_nblock_nperblock;
        typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map;
        typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_descriptor_mblock_mperblock_nblock_nperblock;

        // element-wise op
        AElementwiseOperation a_element_op;
@@ -952,23 +1005,27 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
              b_grid_desc_bk0_n_bk1{MakeBGridDescriptor_BK0_N_BK1(b)},
              b1_grid_desc_bk0_n_bk1{MakeB1GridDescriptor_BK0_N_BK1(b1)},
              c_grid_desc_m_n{MakeCGridDescriptor_M_N(c)},
              block_2_ctile_map{GridwiseGemmSpec::MakeDefaultBlock2CTileMap(
              block_2_ctile_map{GridwiseGemm::MakeDefaultBlock2CTileMap(
                  c_grid_desc_m_n)},
              c_grid_descriptor_mblock_mperblock_nblock_nperblock{
                  GridwiseGemmSpec::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n)},
              has_main_k_block_loop{GridwiseGemmSpec::CalculateHasMainKBlockLoop(
                  GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n)},
              has_main_k_block_loop{GridwiseGemm::CalculateHasMainKBlockLoop(
                  a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2))},
              c0_matrix_mask{c.GetLength(I1)},
              a_element_op{a_element_op_},
              b_element_op{b_element_op_},
              b1_element_op{b1_element_op_},
              c_element_op{c_element_op_},
              is_valid{GridwiseGemmSpec::CheckValidity(
              is_valid{GridwiseGemm::CheckValidity(
                  a_grid_desc_ak0_m_ak1,
                  b_grid_desc_bk0_n_bk1,
                  b1_grid_desc_bk0_n_bk1,
                  c_grid_desc_m_n,
                  block_2_ctile_map)}
                  block_2_ctile_map) and
                  IsSupported(a_grid_desc_ak0_m_ak1.GetLength(I1),
                              b_grid_desc_bk0_n_bk1.GetLength(I1),
                              a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2),
                              b1_grid_desc_bk0_n_bk1.GetLength(I1))}
        {
        }
@@ -1001,17 +1058,13 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
                    const ADataType* __restrict__ p_b1_grid,
                    CDataType* __restrict__ p_c_grid)
    {
        assert(desc.is_valid and
               IsSupported(desc.a_grid_desc_ak0_m_ak1.GetLength(I1),
                           desc.b_grid_desc_bk0_n_bk1.GetLength(I1),
                           desc.a_grid_desc_ak0_m_ak1.GetLength(I0) * desc.a_grid_desc_ak0_m_ak1.GetLength(I2),
                           desc.b1_grid_desc_bk0_n_bk1.GetLength(I1)));
        __shared__ char p_shared_block[Desc::GridwiseGemmSpec::GetSharedMemoryNumberOfByte()];
        assert(desc.is_valid);
        __shared__ char p_shared_block[Desc::GridwiseGemm::GetSharedMemoryNumberOfByte()];

        AccElementwiseOperation acc_element_op{scale};

        if(desc.has_main_k_block_loop)
        {
            Desc::GridwiseGemmSpec::template Run<true>(p_a_grid,
            Desc::GridwiseGemm::template Run<true>(p_a_grid,
                                                   p_b_grid,
                                                   p_b1_grid,
                                                   p_c_grid,
@@ -1030,7 +1083,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
        }
        else
        {
            Desc::GridwiseGemmSpec::template Run<false>(p_a_grid,
            Desc::GridwiseGemm::template Run<false>(p_a_grid,
                                                    p_b_grid,
                                                    p_b1_grid,
                                                    p_c_grid,
......