Commit 25907115 authored by Bartlomiej Wroblewski's avatar Bartlomiej Wroblewski
Browse files

Add checks for correctness of vector lenghts in DL GEMM

parent 21ff7850
......@@ -273,6 +273,9 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
block_2_ctile_map_{},
M01_{M01},
N01_{N01},
M_raw{M},
N_raw{N},
K_raw{K},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op}
......@@ -314,6 +317,10 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
index_t M01_;
index_t N01_;
index_t M_raw;
index_t N_raw;
index_t K_raw;
// TODO: unused since gridwise_gemm_dl_v1r3 does NOT support prologue for the time being.
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
......@@ -485,6 +492,50 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
static bool IsSupportedArgument(const Argument& arg)
{
// Make sure that the M, N, K dimensions before padding are divisible by respective vector
// lengths.
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
constexpr auto A_K_vec_length =
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(0) *
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(3);
if(arg.K_raw % A_K_vec_length != 0)
{
return false;
}
}
else
{
constexpr auto A_M_vec_lenght =
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(1) *
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(2);
if(arg.M_raw % A_M_vec_lenght != 0)
{
return false;
}
}
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
constexpr auto B_N_vec_lenght =
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(1) *
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(2);
if(arg.N_raw % B_N_vec_lenght != 0)
{
return false;
}
}
else
{
constexpr auto B_K_vec_length =
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(0) *
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(3);
if(arg.K_raw % B_K_vec_length != 0)
{
return false;
}
}
if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" ||
ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
ck::get_device_name() == "gfx1102")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment