Unverified Commit e9d4e893 authored by Chao Liu, committed by GitHub

fix build (#434)

* fix

* fix

* add instance
parent aa0b0515
@@ -332,7 +332,10 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
               block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
               a_element_op_{a_element_op},
               b_element_op_{b_element_op},
-              cde_element_op_{cde_element_op}
+              cde_element_op_{cde_element_op},
+              MRaw_{MRaw},
+              NRaw_{NRaw},
+              KRaw_{KRaw}
         {
             // populate pointer, desc for Ds
             static_for<0, NumDTensor, 1>{}([&](auto i) {
@@ -400,6 +403,11 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
         AElementwiseOperation a_element_op_;
         BElementwiseOperation b_element_op_;
         CDEElementwiseOperation cde_element_op_;
+
+        // for checking vector load/store
+        index_t MRaw_;
+        index_t NRaw_;
+        index_t KRaw_;
     };

     // Invoker
@@ -486,6 +494,86 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
             return false;
         }

+        // check vector load/store
+        {
+            using Row = ck::tensor_layout::gemm::RowMajor;
+            using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+            // check vector load of A
+            if constexpr(is_same_v<ALayout, Row> && ABlockTransferSrcVectorDim == 2)
+            {
+                if(arg.KRaw_ % ABlockTransferSrcScalarPerVector != 0)
+                {
+                    return false;
+                }
+            }
+            else if constexpr(is_same_v<ALayout, Col> && ABlockTransferSrcVectorDim == 1)
+            {
+                // FIXME: not rigorous
+                if(arg.MRaw_ % ABlockTransferSrcScalarPerVector != 0)
+                {
+                    return false;
+                }
+            }
+            else
+            {
+                return false;
+            }
+
+            // check vector load of B
+            if constexpr(is_same_v<BLayout, Col> && BBlockTransferSrcVectorDim == 2)
+            {
+                if(arg.KRaw_ % BBlockTransferSrcScalarPerVector != 0)
+                {
+                    return false;
+                }
+            }
+            else if constexpr(is_same_v<BLayout, Row> && BBlockTransferSrcVectorDim == 1)
+            {
+                // FIXME: not rigorous
+                if(arg.NRaw_ % BBlockTransferSrcScalarPerVector != 0)
+                {
+                    return false;
+                }
+            }
+            else
+            {
+                return false;
+            }
+
+            // check vector load of Ds
+            // only support RowMajor for now
+            bool all_valid = true;
+
+            static_for<0, NumDTensor, 1>{}([&](auto i) {
+                using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
+
+                if constexpr(!is_same_v<DLayout, Row>)
+                {
+                    all_valid = false;
+                }
+            });
+
+            if(!all_valid)
+            {
+                return false;
+            }
+
+            // check vector store of E
+            // only support RowMajor for now
+            if constexpr(is_same_v<ELayout, Row>)
+            {
+                if(arg.NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0)
+                {
+                    return false;
+                }
+            }
+            else
+            {
+                return false;
+            }
+        }
+
         return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
                                            arg.b_grid_desc_n_k_,
                                            arg.ds_grid_desc_m_n_,
...
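The block added above rejects GEMM problems whose raw (unpadded) dimensions are not divisible by the configured vector widths: a row-major A is read contiguously along K, a column-major A along M, and the analogous rule applies to B and to the row-major E store. The sketch below restates that divisibility rule as a standalone program; the function name `is_vector_access_ok` and the hard-coded sizes are illustrative only and not part of composable_kernel.

```cpp
// Illustrative only -- not composable_kernel code. A vectorized access moves
// scalar_per_vector contiguous elements at once, so the raw length of the
// contiguous dimension must be an exact multiple of the vector width.
#include <cstdint>
#include <iostream>

using index_t = std::int32_t;

bool is_vector_access_ok(index_t contiguous_raw_length, index_t scalar_per_vector)
{
    return contiguous_raw_length % scalar_per_vector == 0;
}

int main()
{
    const index_t MRaw = 1000, NRaw = 1024, KRaw = 768;

    // Row-major A is contiguous along K; column-major A is contiguous along M.
    std::cout << is_vector_access_ok(KRaw, 8) << '\n';  // 1: 768 is a multiple of 8
    std::cout << is_vector_access_ok(MRaw, 16) << '\n'; // 0: 1000 % 16 == 8, so reject

    // Row-major E is stored contiguously along N.
    std::cout << is_vector_access_ok(NRaw, 8) << '\n';  // 1: 1024 is a multiple of 8
}
```

The raw lengths (`MRaw_`, `NRaw_`, `KRaw_` stored in the argument struct) are what the check inspects, since a vector access must not straddle the edge of the real, unpadded data.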
@@ -3,27 +3,27 @@
 #include <cstring>

-// int profile_gemm(int, char*[]);
-// int profile_gemm_splitk(int, char*[]);
-// int profile_gemm_bilinear(int, char*[]);
-// int profile_gemm_add_add_fastgelu(int, char*[]);
-// int profile_gemm_reduce(int, char*[]);
-// int profile_gemm_bias_add_reduce(int, char*[]);
-// int profile_batched_gemm(int, char*[]);
-// int profile_batched_gemm_gemm(int, char*[]);
-// int profile_batched_gemm_add_relu_gemm_add(int, char*[]);
-// int profile_batched_gemm_reduce(int, char*[]);
-// int profile_grouped_gemm(int, char*[]);
-// int profile_conv_fwd(int, char*[]);
-// int profile_conv_fwd_bias_relu(int, char*[]);
-// int profile_conv_fwd_bias_relu_add(int, char*[]);
-// int profile_conv_bwd_data(int, char*[]);
-// int profile_conv_bwd_weight(int, char*[]);
-// int profile_grouped_conv_fwd(int, char*[]);
-// int profile_normalization(int, char*[]);
+int profile_gemm(int, char*[]);
+int profile_gemm_splitk(int, char*[]);
+int profile_gemm_bilinear(int, char*[]);
+int profile_gemm_add_add_fastgelu(int, char*[]);
+int profile_gemm_reduce(int, char*[]);
+int profile_gemm_bias_add_reduce(int, char*[]);
+int profile_batched_gemm(int, char*[]);
+int profile_batched_gemm_gemm(int, char*[]);
+int profile_batched_gemm_add_relu_gemm_add(int, char*[]);
+int profile_batched_gemm_reduce(int, char*[]);
+int profile_grouped_gemm(int, char*[]);
+int profile_conv_fwd(int, char*[]);
+int profile_conv_fwd_bias_relu(int, char*[]);
+int profile_conv_fwd_bias_relu_add(int, char*[]);
+int profile_conv_bwd_data(int, char*[]);
+int profile_conv_bwd_weight(int, char*[]);
+int profile_grouped_conv_fwd(int, char*[]);
+int profile_normalization(int, char*[]);
 int profile_layernorm(int, char*[]);
 int profile_groupnorm(int, char*[]);
-// int profile_reduce(int, char*[]);
+int profile_reduce(int, char*[]);

 static void print_helper_message()
 {
@@ -57,7 +57,6 @@ int main(int argc, char* argv[])
         return 0;
     }
-#if 0
     else if(strcmp(argv[1], "gemm") == 0)
     {
         return profile_gemm(argc, argv);
@@ -134,7 +133,6 @@ int main(int argc, char* argv[])
     {
         return profile_normalization(argc, argv);
     }
-#endif
     else if(strcmp(argv[1], "layernorm") == 0)
     {
         return profile_layernorm(argc, argv);
...
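The second file is the profiler driver: the declarations are uncommented and the `#if 0`/`#endif` pair around the dispatch is removed, so those operations are selectable from the command line again. Below is a reduced sketch of that argv-based dispatch pattern; the stub bodies and the usage string are stand-ins for illustration, not the real profiler entry points.

```cpp
// Illustrative stub of the profiler's argv-based dispatch; only the shape of
// the control flow mirrors the real main(), the bodies are placeholders.
#include <cstdio>
#include <cstring>

static int profile_gemm(int, char*[]) { std::puts("profiling gemm..."); return 0; }
static int profile_layernorm(int, char*[]) { std::puts("profiling layernorm..."); return 0; }

int main(int argc, char* argv[])
{
    if(argc < 2)
    {
        std::puts("usage: profiler <operation> [operation args...]");
        return 1;
    }
    else if(strcmp(argv[1], "gemm") == 0)
    {
        return profile_gemm(argc, argv);
    }
    else if(strcmp(argv[1], "layernorm") == 0)
    {
        return profile_layernorm(argc, argv);
    }

    std::puts("unknown operation");
    return 1;
}
```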