Commit 2af81173 authored by carlushuang's avatar carlushuang
Browse files

Only run the first A/B layout instance that matches m/n

parent 35f95fe9
......@@ -32,13 +32,14 @@
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <typename ALayout, typename BLayout>
using thread_gemm_avx2_mxn_6x16_instances = std::tuple<
// clang-format off
// FloatA FloatB FloatC ALayout BLayout NTStore
ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(float, float, float, Row, Row, false),
ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(float, float, float, Row, Col, false),
ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(float, float, float, Col, Row, false),
ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(float, float, float, Col, Col, false)
ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(float, float, float, ALayout, BLayout, false),
ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(float, float, float, ALayout, BLayout, false),
ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(float, float, float, ALayout, BLayout, false),
ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(float, float, float, ALayout, BLayout, false)
// ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(float, float, float, Row, Col, false)
// clang-format on
......@@ -239,7 +240,6 @@ void test_ukernel(ukenrel_t uk,
{
assert(m % uk.Mr_ == 0 && n % uk.Nr_ == 0);
data_type* p_a = mat_a;
// data_type* p_b = mat_b;
float* p_c = mat_c;
param.p_a = p_a;
param.p_b = mat_b;
......@@ -335,13 +335,16 @@ void test_cpu_ukernel(float alpha, uint32_t m, uint32_t n, uint32_t k)
ref_cpu_gemm_uk<data_type, ALayout, BLayout>(mat_a, mat_b, mat_c_ref, alpha, m, n, k);
ck::static_for<0, std::tuple_size_v<thread_gemm_avx2_mxn_6x16_instances>, 1>{}([&](auto i) {
using uk_type = std::tuple_element_t<i, thread_gemm_avx2_mxn_6x16_instances>;
if constexpr(!std::is_same<typename uk_type::ALayout_, ALayout>::value ||
!std::is_same<typename uk_type::BLayout_, BLayout>::value)
{
return;
}
using thread_gemm_instance = thread_gemm_avx2_mxn_6x16_instances<ALayout, BLayout>;
bool found = false;
ck::static_for<0, std::tuple_size_v<thread_gemm_instance>, 1>{}([&](auto i) {
using uk_type = std::tuple_element_t<i, thread_gemm_instance>;
// if constexpr(!std::is_same<typename uk_type::ALayout_, ALayout>::value ||
// !std::is_same<typename uk_type::BLayout_, BLayout>::value)
// {
// return;
// }
if(m % uk_type::Mr_ != 0 || n % uk_type::Nr_ != 0)
return;
if((m != uk_type::Mr_ && std::is_same<typename uk_type::ALayout_, Col>::value) ||
......@@ -349,12 +352,14 @@ void test_cpu_ukernel(float alpha, uint32_t m, uint32_t n, uint32_t k)
// only when k is the fast-changing dim of A/B can we do multiple m, n
return;
if(found)
return;
test_ukernel<data_type, ALayout, BLayout>(uk_type{}, mat_a, mat_b, mat_c, alpha, m, n, k);
bool is_valid = valid_vector(mat_c_ref, mat_c, m * n);
printf("vald:%s\n", is_valid ? "y" : "n");
// return ;
found = true;
});
__aligned_free(mat_a);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment