Commit cf875a91 authored by Anthony Chang

properly check size K in IsSupportedArgument()

parent 55b19fce
@@ -448,13 +448,15 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
               b1_element_op_{b1_element_op},
               c_element_op_{c_element_op},
               batch_count_(Batch),
-              compute_base_ptr_of_batch_{BatchStrideA, BatchStrideB, BatchStrideB1, BatchStrideC}
+              compute_base_ptr_of_batch_{BatchStrideA, BatchStrideB, BatchStrideB1, BatchStrideC},
+              lengths_m_n_k_o_{MRaw, NRaw, KRaw, Gemm1NRaw}
         {
             if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
                                            b_grid_desc_bk0_n_bk1_,
                                            b1_grid_desc_bk0_n_bk1_,
                                            c_grid_desc_m_n_,
-                                           block_2_ctile_map_))
+                                           block_2_ctile_map_,
+                                           lengths_m_n_k_o_))
             {
                 c_grid_desc_mblock_mperblock_nblock_nperblock_ =
                     GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
@@ -481,6 +483,9 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
         CElementwiseOperation c_element_op_;
         index_t batch_count_;
         ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
+        // For robust IsSupportedArgument() check
+        std::vector<index_t> lengths_m_n_k_o_;
     };

     // Invoker
@@ -494,7 +499,8 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
                                                arg.b_grid_desc_bk0_n_bk1_,
                                                arg.b1_grid_desc_bk0_n_bk1_,
                                                arg.c_grid_desc_m_n_,
-                                               arg.block_2_ctile_map_))
+                                               arg.block_2_ctile_map_,
+                                               arg.lengths_m_n_k_o_))
             {
                 throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
             }
@@ -588,7 +594,8 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
                                               arg.b_grid_desc_bk0_n_bk1_,
                                               arg.b1_grid_desc_bk0_n_bk1_,
                                               arg.c_grid_desc_m_n_,
-                                              arg.block_2_ctile_map_);
+                                              arg.block_2_ctile_map_,
+                                              arg.lengths_m_n_k_o_);
     }

     // polymorphic
...
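The device-op change above boils down to one pattern: the Argument caches the raw (unpadded) problem lengths at construction time and forwards them to the gridwise validity check, so IsSupportedArgument() can reject sizes that the padded descriptors would otherwise hide. Below is a minimal standalone sketch of that pattern; the simplified Argument and CheckValidity shapes and the AK1/BK1 value of 8 are illustrative assumptions, not the library's actual API.

#include <cstdint>
#include <vector>

using index_t = std::int32_t;

// Hypothetical, simplified stand-in for the real Argument struct.
struct Argument
{
    Argument(index_t MRaw, index_t NRaw, index_t KRaw, index_t ORaw)
        : lengths_m_n_k_o_{MRaw, NRaw, KRaw, ORaw}
    {
    }

    // Raw problem sizes kept alongside the (possibly padded) descriptors,
    // so a robust support check can still see the original K.
    std::vector<index_t> lengths_m_n_k_o_;
};

// Sketch of a validity check that consults the raw K rather than the padded one.
bool CheckValidity(const std::vector<index_t>& lengths_m_n_k_o, index_t AK1, index_t BK1)
{
    const index_t KRaw = lengths_m_n_k_o[2];
    return KRaw % AK1 == 0 && KRaw % BK1 == 0;
}

bool IsSupportedArgument(const Argument& arg)
{
    // AK1 = BK1 = 8 is an assumed example value, not taken from the commit.
    return CheckValidity(arg.lengths_m_n_k_o_, /*AK1=*/8, /*BK1=*/8);
}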
@@ -200,7 +200,8 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
                                const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
                                const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1,
                                const CGridDesc_M_N& c_grid_desc_m_n,
-                               const Block2CTileMap& block_2_ctile_map)
+                               const Block2CTileMap& block_2_ctile_map,
+                               const std::vector<index_t>& lengths_m_n_k_o)
     {
         static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
                           (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
@@ -216,6 +217,13 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
             return false;
         }

+        // K is rounded up to the nearest multiple of K1 during tensor transformation, so check KRaw instead
+        const auto KRaw = lengths_m_n_k_o[2];
+        if(!(KRaw % AK1 == 0 && KRaw % BK1 == 0))
+        {
+            return false;
+        }
+
         if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0 &&
              Gemm1N % Gemm1NPerBlock == 0))
         {
...
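The guard added to CheckValidity() matters because K is rounded up to a multiple of K1 when the K0/K1 descriptors are built, so the padded K always divides K1 cleanly and the old check could not catch an odd raw K. Here is a short example of how the divisibility test behaves, assuming AK1 = BK1 = 8 (an illustrative value; the real K1 parameters come from the kernel instance).

#include <cstdio>

// True when the raw K splits evenly into K1-sized chunks for both the A and B operands.
bool RawKDivisible(int KRaw, int AK1, int BK1)
{
    return KRaw % AK1 == 0 && KRaw % BK1 == 0;
}

int main()
{
    // Assumed AK1 = BK1 = 8 for illustration.
    std::printf("K=128 -> %d\n", RawKDivisible(128, 8, 8)); // 1: supported
    std::printf("K=120 -> %d\n", RawKDivisible(120, 8, 8)); // 1: divisible; padding handles the rest
    std::printf("K=129 -> %d\n", RawKDivisible(129, 8, 8)); // 0: odd K cannot be split into K1 chunks
    std::printf("K=33  -> %d\n", RawKDivisible(33, 8, 8));  // 0: matches the OddK test case below
    return 0;
}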
@@ -68,7 +68,8 @@ TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_OddN)
     this->Run();
 }

-TYPED_TEST(TestBatchedGemmGemmFP16, DISABLED_Test_FP16_OddK)
+// Currently expected that no kernels can support this case
+TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_OddK)
 {
     this->lengths_ = std::vector<std::vector<int>>{
         {128, 128, 33, 128, 1},
@@ -108,7 +109,7 @@ using ck::tensor_operation::device::GemmSpecialization;
 TEST(TestBatchedGemmGemmInterface, GemmSpecializationSizeMatch)
 {
-    int P = 129; // requires padding
+    int P = 120; // requires padding
     int Q = 128; // do not require padding

     // IsSupported(M, N, K, O)
@@ -134,18 +135,11 @@ TEST(TestBatchedGemmGemmInterface, GemmSpecializationSizeMatch)
 TEST(TestBatchedGemmGemmInterface, GemmSpecializationSizeMismatch)
 {
-    int P = 129; // requires padding
-    int Q = 128; // do not require padding
-
     // IsSupported(M, N, K, O)
     // clang-format off
-    EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKPadding>{}.IsSupported(Q, Q, Q, P));
-    EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKPadding>{}.IsSupported(Q, Q, P, P));
-    EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKPadding>{}.IsSupported(Q, P, Q, P));
-    EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKPadding>{}.IsSupported(P, Q, Q, P));
-    EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKPadding>{}.IsSupported(Q, P, P, P));
-    EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKPadding>{}.IsSupported(P, P, Q, P));
-    EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKPadding>{}.IsSupported(P, Q, P, P));
-    EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKPadding>{}.IsSupported(P, P, P, P));
+    EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::Default>{}.IsSupported(128, 128, 120, 128));
+    EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKPadding>{}.IsSupported(128, 128, 128, 120));
+    // Kernel can't support odd K because K must be a multiple of the K1 values of both A and B
+    EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 129, 128));
     // clang-format on
 }