Commit 2af81173 authored by carlushuang's avatar carlushuang
Browse files

Only run the first ukernel instance of the requested A/B layout whose Mr/Nr evenly divide m/n

parent 35f95fe9
...@@ -32,13 +32,14 @@ ...@@ -32,13 +32,14 @@
using Row = ck::tensor_layout::gemm::RowMajor; using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor; using Col = ck::tensor_layout::gemm::ColumnMajor;
template <typename ALayout, typename BLayout>
using thread_gemm_avx2_mxn_6x16_instances = std::tuple< using thread_gemm_avx2_mxn_6x16_instances = std::tuple<
// clang-format off // clang-format off
// FloatA FloatB FloatC ALayout BLayout NTStore // FloatA FloatB FloatC ALayout BLayout NTStore
ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(float, float, float, Row, Row, false), ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(float, float, float, ALayout, BLayout, false),
ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(float, float, float, Row, Col, false), ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(float, float, float, ALayout, BLayout, false),
ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(float, float, float, Col, Row, false), ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(float, float, float, ALayout, BLayout, false),
ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(float, float, float, Col, Col, false) ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(float, float, float, ALayout, BLayout, false)
// ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(float, float, float, Row, Col, false) // ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(float, float, float, Row, Col, false)
// clang-format on // clang-format on
...@@ -239,11 +240,10 @@ void test_ukernel(ukenrel_t uk, ...@@ -239,11 +240,10 @@ void test_ukernel(ukenrel_t uk,
{ {
assert(m % uk.Mr_ == 0 && n % uk.Nr_ == 0); assert(m % uk.Mr_ == 0 && n % uk.Nr_ == 0);
data_type* p_a = mat_a; data_type* p_a = mat_a;
// data_type* p_b = mat_b; float* p_c = mat_c;
float* p_c = mat_c; param.p_a = p_a;
param.p_a = p_a; param.p_b = mat_b;
param.p_b = mat_b; param.p_c = p_c;
param.p_c = p_c;
for(uint32_t i_m = 0; i_m < m; i_m += uk.Mr_) for(uint32_t i_m = 0; i_m < m; i_m += uk.Mr_)
{ {
float* p_c_n = p_c; float* p_c_n = p_c;
...@@ -335,13 +335,16 @@ void test_cpu_ukernel(float alpha, uint32_t m, uint32_t n, uint32_t k) ...@@ -335,13 +335,16 @@ void test_cpu_ukernel(float alpha, uint32_t m, uint32_t n, uint32_t k)
ref_cpu_gemm_uk<data_type, ALayout, BLayout>(mat_a, mat_b, mat_c_ref, alpha, m, n, k); ref_cpu_gemm_uk<data_type, ALayout, BLayout>(mat_a, mat_b, mat_c_ref, alpha, m, n, k);
ck::static_for<0, std::tuple_size_v<thread_gemm_avx2_mxn_6x16_instances>, 1>{}([&](auto i) { using thread_gemm_instance = thread_gemm_avx2_mxn_6x16_instances<ALayout, BLayout>;
using uk_type = std::tuple_element_t<i, thread_gemm_avx2_mxn_6x16_instances>; bool found = false;
if constexpr(!std::is_same<typename uk_type::ALayout_, ALayout>::value ||
!std::is_same<typename uk_type::BLayout_, BLayout>::value) ck::static_for<0, std::tuple_size_v<thread_gemm_instance>, 1>{}([&](auto i) {
{ using uk_type = std::tuple_element_t<i, thread_gemm_instance>;
return; // if constexpr(!std::is_same<typename uk_type::ALayout_, ALayout>::value ||
} // !std::is_same<typename uk_type::BLayout_, BLayout>::value)
// {
// return;
// }
if(m % uk_type::Mr_ != 0 || n % uk_type::Nr_ != 0) if(m % uk_type::Mr_ != 0 || n % uk_type::Nr_ != 0)
return; return;
if((m != uk_type::Mr_ && std::is_same<typename uk_type::ALayout_, Col>::value) || if((m != uk_type::Mr_ && std::is_same<typename uk_type::ALayout_, Col>::value) ||
...@@ -349,12 +352,14 @@ void test_cpu_ukernel(float alpha, uint32_t m, uint32_t n, uint32_t k) ...@@ -349,12 +352,14 @@ void test_cpu_ukernel(float alpha, uint32_t m, uint32_t n, uint32_t k)
// only when k is the fast-changing dim of A/B can we handle multiple m, n // only when k is the fast-changing dim of A/B can we handle multiple m, n
return; return;
if(found)
return;
test_ukernel<data_type, ALayout, BLayout>(uk_type{}, mat_a, mat_b, mat_c, alpha, m, n, k); test_ukernel<data_type, ALayout, BLayout>(uk_type{}, mat_a, mat_b, mat_c, alpha, m, n, k);
bool is_valid = valid_vector(mat_c_ref, mat_c, m * n); bool is_valid = valid_vector(mat_c_ref, mat_c, m * n);
printf("vald:%s\n", is_valid ? "y" : "n"); printf("vald:%s\n", is_valid ? "y" : "n");
found = true;
// return ;
}); });
__aligned_free(mat_a); __aligned_free(mat_a);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment