Commit bf86f84b authored by Anthony Chang

properly check size requirement given SrcScalarPerVector in IsSupportedArgument()
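In short, the commit makes IsSupportedArgument() reject problem sizes whose contiguous ("lowest") dimension is not a multiple of the configured SrcScalarPerVector. A minimal standalone sketch of the rule being enforced; the names here (Layout, lowest_extent, vector_load_supported) are illustrative, not CK's API:

```cpp
#include <cstdint>

// Illustrative layout tag; CK uses tensor_layout::gemm::RowMajor / ColumnMajor.
enum class Layout { RowMajor, ColMajor };

// For a row-major matrix the column index is contiguous in memory, so a
// vectorized load runs along the column extent; for column-major, the row extent.
std::int64_t lowest_extent(Layout layout, std::int64_t rows, std::int64_t cols)
{
    return layout == Layout::RowMajor ? cols : rows;
}

// A vector load of scalar_per_vector elements is only safe when the raw
// (unpadded) lowest extent divides evenly; otherwise the tail load would
// read past the end of the tensor.
bool vector_load_supported(std::int64_t extent_lowest, std::int64_t scalar_per_vector)
{
    return extent_lowest % scalar_per_vector == 0;
}
```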

parent cf875a91
@@ -449,14 +449,14 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
c_element_op_{c_element_op},
batch_count_(Batch),
compute_base_ptr_of_batch_{BatchStrideA, BatchStrideB, BatchStrideB1, BatchStrideC},
-lengths_m_n_k_o_{MRaw, NRaw, KRaw, Gemm1NRaw}
+raw_lengths_m_n_k_o_{MRaw, NRaw, KRaw, Gemm1NRaw}
{
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
b_grid_desc_bk0_n_bk1_,
b1_grid_desc_bk0_n_bk1_,
c_grid_desc_m_n_,
block_2_ctile_map_,
-lengths_m_n_k_o_))
+raw_lengths_m_n_k_o_))
{
c_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
@@ -485,7 +485,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
// For robust IsSupportedArgument() check
-std::vector<index_t> lengths_m_n_k_o_;
+std::vector<index_t> raw_lengths_m_n_k_o_;
};
// Invoker
@@ -500,7 +500,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
arg.b1_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_,
-arg.lengths_m_n_k_o_))
+arg.raw_lengths_m_n_k_o_))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
@@ -590,12 +590,37 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
return false;
}
+// Note: we need raw lengths since threadwise copy cannot handle vector loads when
+// part of a vector is out of bounds
+const auto MRaw = arg.raw_lengths_m_n_k_o_[0];
+const auto NRaw = arg.raw_lengths_m_n_k_o_[1];
+const auto KRaw = arg.raw_lengths_m_n_k_o_[2];
+const auto Gemm1NRaw = arg.raw_lengths_m_n_k_o_[3];
+// Check scalar-per-vector requirement on the contiguous (lowest) dimension of each tensor
+const auto a_extent_lowest =
+    is_same_v<tensor_layout::gemm::RowMajor, ALayout> ? KRaw : MRaw;
+const auto b_extent_lowest =
+    is_same_v<tensor_layout::gemm::RowMajor, BLayout> ? NRaw : KRaw;
+const auto b1_extent_lowest =
+    is_same_v<tensor_layout::gemm::RowMajor, B1Layout> ? Gemm1NRaw : NRaw;
+const auto c_extent_lowest =
+    is_same_v<tensor_layout::gemm::RowMajor, CLayout> ? Gemm1NRaw : MRaw;
+if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
+     b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 &&
+     b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
+     c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
+{
+    return false;
+}
return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.b1_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_,
-arg.lengths_m_n_k_o_);
+arg.raw_lengths_m_n_k_o_);
}
// polymorphic
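Why the check above must use raw rather than padded lengths: with a padding GemmSpecialization the descriptor extents are rounded up to tile multiples, which for typical tile sizes are already multiples of the vector width, so a divisibility check against them would always pass. A tiny self-contained illustration, with made-up numbers rather than values from this kernel:

```cpp
#include <iostream>

int main()
{
    const int o_raw    = 129; // raw (unpadded) extent of the contiguous dimension
    const int o_padded = 136; // hypothetical extent after padding to a multiple of 8
    const int spv      = 8;   // scalars moved per vectorized load

    std::cout << (o_padded % spv == 0) << '\n'; // 1: a padded-length check would wrongly pass
    std::cout << (o_raw % spv == 0) << '\n';    // 0: the raw-length check correctly rejects
}
```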
@@ -78,7 +78,8 @@ TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_OddK)
this->Run();
}
-TYPED_TEST(TestBatchedGemmGemmFP16, DISABLED_Test_FP16_OddO)
+// If the kernel's B1Layout is RowMajor, odd O sizes are expected to be unsupported
+TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_OddO)
{
this->lengths_ = std::vector<std::vector<int>>{
{128, 128, 32, 129, 1},
@@ -141,5 +142,7 @@ TEST(TestBatchedGemmGemmInterface, GemmSpecializationSizeMismatch)
EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKPadding>{}.IsSupported(128, 128, 128, 120));
// Kernel can't support odd K because K must be a multiple of the K1 values of either A or B
EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 129, 128));
+// Kernel can't support an odd O size because B1SrcScalarPerVector=8 requires SizeO % 8 == 0
+EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 128, 129));
// clang-format on
}
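For context, this is roughly how the strengthened check surfaces at a call site. The wiring below follows the usual CK client-example pattern, but DeviceOpInstance is a stand-in name and MakeArgument's parameters are elided, so treat it as a sketch rather than code from this commit:

```cpp
// Hypothetical call site; DeviceOpInstance stands in for a concrete
// DeviceBatchedGemmGemm_Xdl_CShuffle instantiation.
auto device_op = DeviceOpInstance{};
auto argument  = device_op.MakeArgument(/* pointers, raw lengths, strides, element ops */);

// With this commit, an unsupported size such as O = 129 (B1SrcScalarPerVector = 8)
// is rejected here instead of causing out-of-bounds vector loads inside the kernel.
if(!device_op.IsSupportedArgument(argument))
{
    throw std::runtime_error("wrong! this device op does not support the given problem size");
}

auto invoker = device_op.MakeInvoker();
invoker.Run(argument, StreamConfig{});
```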