"...src/git@developer.sourcefind.cn:orangecat/ollama.git" did not exist on "ecd2f176277db4f074e25a2c3646b04b51cec119"
Unverified Commit d58b7f51 authored by Illia Silin's avatar Illia Silin Committed by GitHub
Browse files

Make sure that GEMM sizes in K dimension are supported. (#527)

* apply new K-dimension check in gemm_xdl_cshuffle

* add K-dim check to gemm_xdl and batched_gemm_xdl

* fix syntax

* fix syntax

* clean-up the debug output
parent 614a7b1b
...@@ -373,7 +373,8 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout, ...@@ -373,7 +373,8 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
N01_{N01}, N01_{N01},
a_element_op_{a_element_op}, a_element_op_{a_element_op},
b_element_op_{b_element_op}, b_element_op_{b_element_op},
c_element_op_{c_element_op} c_element_op_{c_element_op},
kraw_{K}
{ {
if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_,
b_grid_desc_k0_n_k1_, b_grid_desc_k0_n_k1_,
...@@ -401,6 +402,7 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout, ...@@ -401,6 +402,7 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
AElementwiseOperation a_element_op_; AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_; BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_; CElementwiseOperation c_element_op_;
index_t kraw_;
}; };
// Invoker // Invoker
...@@ -410,6 +412,7 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout, ...@@ -410,6 +412,7 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
#if 0
{ {
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
...@@ -422,6 +425,7 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout, ...@@ -422,6 +425,7 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
std::cout << "arg.c_grid_desc_m_n_{" << arg.c_grid_desc_m_n_.GetLength(I0) << ", " std::cout << "arg.c_grid_desc_m_n_{" << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
} }
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_, arg.b_grid_desc_k0_n_k1_,
...@@ -528,6 +532,11 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout, ...@@ -528,6 +532,11 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(arg.kraw_ % K1 != 0)
{
return false;
}
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_, arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_, arg.c_grid_desc_m_n_,
......
...@@ -549,6 +549,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -549,6 +549,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
float ave_time = 0; float ave_time = 0;
for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++) for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
{ {
#if 0
{ {
std::cout << "arg.a_grid_desc_k0_m_k1_container_{" std::cout << "arg.a_grid_desc_k0_m_k1_container_{"
<< arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", " << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", "
...@@ -581,6 +582,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -581,6 +582,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<< arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I5) << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I5)
<< " ) " << std::endl; << " ) " << std::endl;
} }
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i], arg.b_grid_desc_k0_n_k1_container_[i],
......
...@@ -265,7 +265,8 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout, ...@@ -265,7 +265,8 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
N01_{N01}, N01_{N01},
a_element_op_{a_element_op}, a_element_op_{a_element_op},
b_element_op_{b_element_op}, b_element_op_{b_element_op},
c_element_op_{c_element_op} c_element_op_{c_element_op},
kraw_{K}
{ {
a_grid_desc_k0_m_k1_ = DeviceGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA); a_grid_desc_k0_m_k1_ = DeviceGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
b_grid_desc_k0_n_k1_ = DeviceGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB); b_grid_desc_k0_n_k1_ = DeviceGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
...@@ -299,6 +300,7 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout, ...@@ -299,6 +300,7 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
AElementwiseOperation a_element_op_; AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_; BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_; CElementwiseOperation c_element_op_;
index_t kraw_;
}; };
// Invoker // Invoker
...@@ -443,6 +445,11 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout, ...@@ -443,6 +445,11 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
return false; return false;
} }
if(arg.kraw_ % K1 != 0)
{
return false;
}
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_, arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_, arg.c_grid_desc_m_n_,
......
...@@ -422,7 +422,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -422,7 +422,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)}, block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
a_element_op_{a_element_op}, a_element_op_{a_element_op},
b_element_op_{b_element_op}, b_element_op_{b_element_op},
c_element_op_{c_element_op} c_element_op_{c_element_op},
kraw_{KRaw}
{ {
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
b_grid_desc_bk0_n_bk1_, b_grid_desc_bk0_n_bk1_,
...@@ -448,6 +449,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -448,6 +449,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
AElementwiseOperation a_element_op_; AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_; BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_; CElementwiseOperation c_element_op_;
index_t kraw_;
}; };
// Invoker // Invoker
...@@ -578,6 +580,15 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -578,6 +580,15 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
return false; return false;
} }
if((arg.kraw_ % AK1 != 0 || arg.kraw_ % BK1 != 0) &&
!(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding ||
GemmSpec == GemmSpecialization::KPadding))
{
return false;
}
return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_, arg.c_grid_desc_m_n_,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment