"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "7ae450a1bc1d429c4ac43099a32249b79285e146"
Commit 3cc7ac0a authored by carlushuang
Browse files

add online cvt f16->f32

parent 66fd7712
...@@ -90,30 +90,59 @@ struct ThreadwiseGemmAvx2_MxN_6x16 ...@@ -90,30 +90,59 @@ struct ThreadwiseGemmAvx2_MxN_6x16
".endif\n" ".endif\n"
".endm\n" ".endm\n"
// Assembler helper macro wrapping `vpbroadcastw` with a memory source operand.
//   r_base / r_stride : base and index registers of the address
//   i_scale           : index scale; 0 means "no index register"
//   i_offset          : constant displacement
//   xmm               : destination register
// Callers pass `0, 0` for r_stride/i_scale when they want plain
// displacement(base) addressing, so the i_scale == 0 case drops the index part.
// The `%=` suffix is presumably there to keep the macro name unique per asm
// statement - confirm against the GCC Extended Asm documentation.
".macro vpbroadcastw_%= r_base, r_stride, i_scale, i_offset, xmm\n"
".if \\i_scale != 0\n"
"vpbroadcastw \\i_offset(\\r_base, \\r_stride, \\i_scale), \\xmm\n"
".else\n"
"vpbroadcastw \\i_offset(\\r_base), \\xmm\n"
".endif\n"
".endm\n"
// Assembler helper macro wrapping `vcvtph2ps` with a memory source operand, i.e.
// the f16 -> f32 conversion happens as part of the load into the ymm argument.
// Addressing follows the same convention as the other load helpers in this asm
// block: i_offset(r_base, r_stride, i_scale) when i_scale != 0, otherwise
// i_offset(r_base) with the index register omitted.
".macro vcvtph2ps_%= r_base, r_stride, i_scale, i_offset, ymm\n"
".if \\i_scale != 0\n"
"vcvtph2ps \\i_offset(\\r_base, \\r_stride, \\i_scale), \\ymm\n"
".else\n"
"vcvtph2ps \\i_offset(\\r_base), \\ymm\n"
".endif\n"
".endm\n"
// vbroadcast_a: broadcast the A element selected by (i_k, i_m) into the ymm argument.
// The lines below are a two-column diff rendering (old text, then new text); the
// description applies to the new column.
//   m_ABytes == 4 : the element is broadcast directly with vbroadcastss_.
//   otherwise     : a 16-bit element is broadcast into xmm15 with vpbroadcastw_
//                   and then widened into the destination with vcvtph2ps, so
//                   xmm15 is used as scratch on this path.
//   m_TransA == 0 : element offset is (i_m + i_k * m_Mr) * m_ABytes from rax.
//   m_TransA != 0 : i_m is used as the scale on rcx; rows 0..2 are addressed from
//                   rax and rows 3.. from r8 with scale i_m-3 (presumably because
//                   r8 points three strides past rax - TODO confirm at the setup code).
".macro vbroadcast_a%= i_k, i_m, ymm\n" // A in rax(r8, r9), lda in rcx ".macro vbroadcast_a%= i_k, i_m, ymm\n" // A in rax(r8, r9), lda in rcx
".if m_TransA == 0\n" ".if m_ABytes == 4\n"
"vbroadcastss_%= %%rax, 0, 0, (\\i_m + \\i_k * m_Mr) * 4, \\ymm\n" ".if m_TransA == 0\n"
"vbroadcastss_%= %%rax, 0, 0, (\\i_m + \\i_k * m_Mr) * m_ABytes, \\ymm\n"
".else\n"
".if (\\i_m == 0) || (\\i_m == 1) || (\\i_m == 2)\n"
"vbroadcastss_%= %%rax, %%rcx, \\i_m, \\i_k * m_ABytes, \\ymm\n"
".else\n"
"vbroadcastss_%= %%r8, %%rcx, \\i_m-3, \\i_k * m_ABytes, \\ymm\n"
".endif\n"
".endif\n"
".else\n" ".else\n"
".if (\\i_m == 0) || (\\i_m == 1) || (\\i_m == 2)\n" ".if m_TransA == 0\n"
"vbroadcastss_%= %%rax, %%rcx, \\i_m, \\i_k * 4, \\ymm\n" "vpbroadcastw_%= %%rax, 0, 0, (\\i_m + \\i_k * m_Mr) * m_ABytes, %%xmm15\n"
".else\n" ".else\n"
"vbroadcastss_%= %%r8, %%rcx, \\i_m-3, \\i_k * 4, \\ymm\n" ".if (\\i_m == 0) || (\\i_m == 1) || (\\i_m == 2)\n"
"vpbroadcastw_%= %%rax, %%rcx, \\i_m, \\i_k * m_ABytes, %%xmm15\n"
".else\n"
"vpbroadcastw_%= %%r8, %%rcx, \\i_m-3, \\i_k * m_ABytes, %%xmm15\n"
".endif\n"
".endif\n" ".endif\n"
"vcvtph2ps %%xmm15, \\ymm\n"
".endif\n" ".endif\n"
".endm\n" ".endm\n"
// vload_b: load the B vector selected by (i_k, i_n) into the ymm argument.
// The lines below are a two-column diff rendering (old text, then new text); the
// description applies to the new column.
//   m_BBytes == 4 : plain vmovups_ load.
//   otherwise     : vcvtph2ps_ load, converting f16 to f32 on the way in
//                   (the old column used vmovups_ with a fixed 4-byte size here).
//   m_TransB == 0 : i_n is used as the scale on rdx, displacement i_k*m_BBytes*8.
//   m_TransB != 0 : displacement (i_k*m_Nr + i_n*8)*m_BBytes from rbx, no index.
".macro vload_b%= i_k, i_n, ymm\n" // B in rbx, ldb in rdx, i_n should be 0, 1 ".macro vload_b%= i_k, i_n, ymm\n" // B in rbx, ldb in rdx, i_n should be 0, 1
".if m_BBytes == 4\n" ".if m_BBytes == 4\n"
".if m_TransB == 0\n" ".if m_TransB == 0\n"
"vmovups_%= %%rbx, %%rdx, \\i_n, \\i_k*4*8, \\ymm\n" "vmovups_%= %%rbx, %%rdx, \\i_n, \\i_k*m_BBytes*8, \\ymm\n"
".else\n" ".else\n"
"vmovups_%= %%rbx, 0, 0, (\\i_k*m_Nr + \\i_n*8)*4, \\ymm\n" "vmovups_%= %%rbx, 0, 0, (\\i_k*m_Nr + \\i_n*8)*m_BBytes, \\ymm\n"
".endif\n" ".endif\n"
".else\n" ".else\n"
".if m_TransB == 0\n" ".if m_TransB == 0\n"
"vmovups_%= %%rbx, %%rdx, \\i_n, \\i_k*4*8, \\ymm\n" "vcvtph2ps_%= %%rbx, %%rdx, \\i_n, \\i_k*m_BBytes*8, \\ymm\n"
".else\n" ".else\n"
"vmovups_%= %%rbx, 0, 0, (\\i_k*m_Nr + \\i_n*8)*4, \\ymm\n" "vcvtph2ps_%= %%rbx, 0, 0, (\\i_k*m_Nr + \\i_n*8)*m_BBytes, \\ymm\n"
".endif\n" ".endif\n"
".endif\n" ".endif\n"
".endm\n" ".endm\n"
...@@ -168,15 +197,15 @@ struct ThreadwiseGemmAvx2_MxN_6x16 ...@@ -168,15 +197,15 @@ struct ThreadwiseGemmAvx2_MxN_6x16
".endr\n" ".endr\n"
".if m_TransA != 0\n" ".if m_TransA != 0\n"
" lea 4*4(%%rax), %%rax\n" " lea 4*m_ABytes(%%rax), %%rax\n"
".if m_Mr > 3\n lea 4*4(%%r8), %%r8\n .endif\n" ".if m_Mr > 3\n lea 4*m_ABytes(%%r8), %%r8\n .endif\n"
".else\n" ".else\n"
" lea m_Mr * 4 * 4(%%rax), %%rax\n" " lea m_Mr * 4 * m_ABytes(%%rax), %%rax\n"
".endif\n" ".endif\n"
".if m_TransB != 0\n" ".if m_TransB != 0\n"
" lea m_Nr * 4 * 4(%%rbx), %%rbx\n" " lea m_Nr * 4 * m_BBytes(%%rbx), %%rbx\n"
".else\n" ".else\n"
" lea 8 * 4 * 4(%%rbx), %%rbx\n" " lea 8 * 4 * m_BBytes(%%rbx), %%rbx\n"
".endif\n" ".endif\n"
"sub $4, %%rsi\n" "sub $4, %%rsi\n"
...@@ -210,15 +239,15 @@ struct ThreadwiseGemmAvx2_MxN_6x16 ...@@ -210,15 +239,15 @@ struct ThreadwiseGemmAvx2_MxN_6x16
".if (m_Mr > 5) && (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm15, %%ymm11\n .endif\n" // 5x1 ".if (m_Mr > 5) && (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm15, %%ymm11\n .endif\n" // 5x1
".if m_TransA != 0\n" ".if m_TransA != 0\n"
" lea 4(%%rax), %%rax\n" " lea m_ABytes(%%rax), %%rax\n"
".if m_Mr > 3\n lea 4(%%r8), %%r8\n .endif\n" ".if m_Mr > 3\n lea m_ABytes(%%r8), %%r8\n .endif\n"
".else\n" ".else\n"
" lea m_Mr * 4(%%rax), %%rax\n" " lea m_Mr * m_ABytes(%%rax), %%rax\n"
".endif\n" ".endif\n"
".if m_TransB != 0\n" ".if m_TransB != 0\n"
" lea m_Nr * 4(%%rbx), %%rbx\n" " lea m_Nr * m_BBytes(%%rbx), %%rbx\n"
".else\n" ".else\n"
" lea 8*4(%%rbx), %%rbx\n" " lea 8*m_BBytes(%%rbx), %%rbx\n"
".endif\n" ".endif\n"
"sub $1, %%rsi\n" "sub $1, %%rsi\n"
...@@ -380,30 +409,59 @@ struct ThreadwiseGemmAvx2_MxN_4x24 ...@@ -380,30 +409,59 @@ struct ThreadwiseGemmAvx2_MxN_4x24
".endif\n" ".endif\n"
".endm\n" ".endm\n"
// Assembler helper macro wrapping `vpbroadcastw` with a memory source operand.
//   r_base / r_stride : base and index registers of the address
//   i_scale           : index scale; 0 means "no index register"
//   i_offset          : constant displacement
//   xmm               : destination register
// Callers pass `0, 0` for r_stride/i_scale when they want plain
// displacement(base) addressing, so the i_scale == 0 case drops the index part.
".macro vpbroadcastw_%= r_base, r_stride, i_scale, i_offset, xmm\n"
".if \\i_scale != 0\n"
"vpbroadcastw \\i_offset(\\r_base, \\r_stride, \\i_scale), \\xmm\n"
".else\n"
"vpbroadcastw \\i_offset(\\r_base), \\xmm\n"
".endif\n"
".endm\n"
// Assembler helper macro wrapping `vcvtph2ps` with a memory source operand, i.e.
// the f16 -> f32 conversion happens as part of the load into the ymm argument.
// Uses i_offset(r_base, r_stride, i_scale) when i_scale != 0, otherwise
// i_offset(r_base) with the index register omitted.
".macro vcvtph2ps_%= r_base, r_stride, i_scale, i_offset, ymm\n"
".if \\i_scale != 0\n"
"vcvtph2ps \\i_offset(\\r_base, \\r_stride, \\i_scale), \\ymm\n"
".else\n"
"vcvtph2ps \\i_offset(\\r_base), \\ymm\n"
".endif\n"
".endm\n"
// vbroadcast_a: broadcast the A element selected by (i_k, i_m) into the ymm argument.
// The lines below are a two-column diff rendering (old text, then new text); the
// description applies to the new column.
//   m_ABytes == 4 : the element is broadcast directly with vbroadcastss_.
//   otherwise     : a 16-bit element is broadcast into xmm15 with vpbroadcastw_
//                   and then widened into the destination with vcvtph2ps, so
//                   xmm15 is used as scratch on this path.
//   m_TransA == 0 : element offset is (i_m + i_k * m_Mr) * m_ABytes from rax.
//   m_TransA != 0 : rows 0..1 are addressed from rax and rows 2.. from r8 with
//                   scale i_m-2, rcx being the stride register.
// NOTE(review): rows >= 2 read through r8 here, yet the single-step K tail of this
// kernel advances r8 only under `m_Mr > 3` while the 4-step loop uses `m_Mr > 2`.
// Confirm whether m_Mr == 3 with m_TransA != 0 is handled correctly.
".macro vbroadcast_a%= i_k, i_m, ymm\n" // A in rax(r8), lda in rcx ".macro vbroadcast_a%= i_k, i_m, ymm\n" // A in rax(r8), lda in rcx
".if m_TransA == 0\n" ".if m_ABytes == 4\n"
"vbroadcastss_%= %%rax, 0, 0, (\\i_m + \\i_k * m_Mr) * 4, \\ymm\n" ".if m_TransA == 0\n"
"vbroadcastss_%= %%rax, 0, 0, (\\i_m + \\i_k * m_Mr) * m_ABytes, \\ymm\n"
".else\n"
".if (\\i_m == 0) || (\\i_m == 1)\n"
"vbroadcastss_%= %%rax, %%rcx, \\i_m, \\i_k * m_ABytes, \\ymm\n"
".else\n"
"vbroadcastss_%= %%r8, %%rcx, \\i_m-2, \\i_k * m_ABytes, \\ymm\n"
".endif\n"
".endif\n"
".else\n" ".else\n"
".if (\\i_m == 0) || (\\i_m == 1)\n" ".if m_TransA == 0\n"
"vbroadcastss_%= %%rax, %%rcx, \\i_m, \\i_k * 4, \\ymm\n" "vpbroadcastw_%= %%rax, 0, 0, (\\i_m + \\i_k * m_Mr) * m_ABytes, %%xmm15\n"
".else\n" ".else\n"
"vbroadcastss_%= %%r8, %%rcx, \\i_m-2, \\i_k * 4, \\ymm\n" ".if (\\i_m == 0) || (\\i_m == 1)\n"
"vpbroadcastw_%= %%rax, %%rcx, \\i_m, \\i_k * m_ABytes, %%xmm15\n"
".else\n"
"vpbroadcastw_%= %%r8, %%rcx, \\i_m-2, \\i_k * m_ABytes, %%xmm15\n"
".endif\n"
".endif\n" ".endif\n"
"vcvtph2ps %%xmm15, \\ymm\n"
".endif\n" ".endif\n"
".endm\n" ".endm\n"
// vload_b: load the B vector selected by (i_k, i_n) into the ymm argument.
// The lines below are a two-column diff rendering (old text, then new text); the
// description applies to the new column.
//   m_BBytes == 4 : plain vmovups_ load.
//   otherwise     : vcvtph2ps_ load, converting f16 to f32 on the way in
//                   (the old column used vmovups_ with a fixed 4-byte size here).
//   m_TransB == 0 : i_n is used as the scale on rdx, displacement i_k*m_BBytes*8.
//   m_TransB != 0 : displacement (i_k*m_Nr + i_n*8)*m_BBytes from rbx, no index.
".macro vload_b%= i_k, i_n, ymm\n" // B in rbx, ldb in rdx, i_n should be 0, 1, 2 ".macro vload_b%= i_k, i_n, ymm\n" // B in rbx, ldb in rdx, i_n should be 0, 1, 2
".if m_BBytes == 4\n" ".if m_BBytes == 4\n"
".if m_TransB == 0\n" ".if m_TransB == 0\n"
"vmovups_%= %%rbx, %%rdx, \\i_n, \\i_k*4*8, \\ymm\n" "vmovups_%= %%rbx, %%rdx, \\i_n, \\i_k*m_BBytes*8, \\ymm\n"
".else\n" ".else\n"
"vmovups_%= %%rbx, 0, 0, (\\i_k*m_Nr + \\i_n*8)*4, \\ymm\n" "vmovups_%= %%rbx, 0, 0, (\\i_k*m_Nr + \\i_n*8)*m_BBytes, \\ymm\n"
".endif\n" ".endif\n"
".else\n" ".else\n"
".if m_TransB == 0\n" ".if m_TransB == 0\n"
"vmovups_%= %%rbx, %%rdx, \\i_n, \\i_k*4*8, \\ymm\n" "vcvtph2ps_%= %%rbx, %%rdx, \\i_n, \\i_k*m_BBytes*8, \\ymm\n"
".else\n" ".else\n"
"vmovups_%= %%rbx, 0, 0, (\\i_k*m_Nr + \\i_n*8)*4, \\ymm\n" "vcvtph2ps_%= %%rbx, 0, 0, (\\i_k*m_Nr + \\i_n*8)*m_BBytes, \\ymm\n"
".endif\n" ".endif\n"
".endif\n" ".endif\n"
".endm\n" ".endm\n"
...@@ -457,15 +515,15 @@ struct ThreadwiseGemmAvx2_MxN_4x24 ...@@ -457,15 +515,15 @@ struct ThreadwiseGemmAvx2_MxN_4x24
".endr\n" ".endr\n"
".if m_TransA != 0\n" ".if m_TransA != 0\n"
" lea 4*4(%%rax), %%rax\n" " lea 4*m_ABytes(%%rax), %%rax\n"
".if m_Mr > 2\n lea 4*4(%%r8), %%r8\n .endif\n" ".if m_Mr > 2\n lea 4*m_ABytes(%%r8), %%r8\n .endif\n"
".else\n" ".else\n"
" lea m_Mr * 4 * 4(%%rax), %%rax\n" " lea m_Mr * 4 * m_ABytes(%%rax), %%rax\n"
".endif\n" ".endif\n"
".if m_TransB != 0\n" ".if m_TransB != 0\n"
" lea m_Nr * 4 * 4(%%rbx), %%rbx\n" " lea m_Nr * 4 * m_BBytes(%%rbx), %%rbx\n"
".else\n" ".else\n"
" lea 8 * 4 * 4(%%rbx), %%rbx\n" " lea 8 * 4 * m_BBytes(%%rbx), %%rbx\n"
".endif\n" ".endif\n"
"sub $4, %%rsi\n" "sub $4, %%rsi\n"
...@@ -499,15 +557,15 @@ struct ThreadwiseGemmAvx2_MxN_4x24 ...@@ -499,15 +557,15 @@ struct ThreadwiseGemmAvx2_MxN_4x24
".if (m_Mr > 3) && (m_Nr >16)\n vfmadd231ps %%ymm14, %%ymm15, %%ymm11\n .endif\n" // 3x2 ".if (m_Mr > 3) && (m_Nr >16)\n vfmadd231ps %%ymm14, %%ymm15, %%ymm11\n .endif\n" // 3x2
".if m_TransA != 0\n" ".if m_TransA != 0\n"
" lea 4(%%rax), %%rax\n" " lea m_ABytes(%%rax), %%rax\n"
".if m_Mr > 3\n lea 4(%%r8), %%r8\n .endif\n" ".if m_Mr > 3\n lea m_ABytes(%%r8), %%r8\n .endif\n"
".else\n" ".else\n"
" lea m_Mr * 4(%%rax), %%rax\n" " lea m_Mr * m_ABytes(%%rax), %%rax\n"
".endif\n" ".endif\n"
".if m_TransB != 0\n" ".if m_TransB != 0\n"
" lea m_Nr * 4(%%rbx), %%rbx\n" " lea m_Nr * m_BBytes(%%rbx), %%rbx\n"
".else\n" ".else\n"
" lea 8*4(%%rbx), %%rbx\n" " lea 8*m_BBytes(%%rbx), %%rbx\n"
".endif\n" ".endif\n"
"sub $1, %%rsi\n" "sub $1, %%rsi\n"
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <tuple> #include <tuple>
#include <memory> #include <memory>
#include <chrono> #include <chrono>
#include <half.hpp>
#include "config.hpp" #include "config.hpp"
#include "print.hpp" #include "print.hpp"
#include "cpuid.hpp" #include "cpuid.hpp"
...@@ -26,7 +27,7 @@ ...@@ -26,7 +27,7 @@
ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 2, 8, TA, TB, NT>, \ ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 2, 8, TA, TB, NT>, \
ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 1, 8, TA, TB, NT> ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 1, 8, TA, TB, NT>
// #define ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(FA, FB, FC, TA, TB, NT) \ //#define ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(FA, FB, FC, TA, TB, NT) \
// ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 6, 16, TA, TB, NT> // ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 6, 16, TA, TB, NT>
#define ITERATE_THREAD_GEMM_AVX2_MXN_4X24_INSTANCE(FA, FB, FC, TA, TB, NT) \ #define ITERATE_THREAD_GEMM_AVX2_MXN_4X24_INSTANCE(FA, FB, FC, TA, TB, NT) \
...@@ -46,16 +47,22 @@ ...@@ -46,16 +47,22 @@
using Row = ck::tensor_layout::gemm::RowMajor; using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor; using Col = ck::tensor_layout::gemm::ColumnMajor;
// using AType = half_float::half;
// using BType = half_float::half;
using AType = float;
using BType = float;
using CType = float;
template <typename ALayout, typename BLayout> template <typename ALayout, typename BLayout>
using thread_gemm_avx2_mxn_6x16_instances = std::tuple< using thread_gemm_avx2_mxn_6x16_instances = std::tuple<
// clang-format off // clang-format off
// FloatA FloatB FloatC ALayout BLayout NTStore // FloatA FloatB FloatC ALayout BLayout NTStore
ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(float, float, float, ALayout, BLayout, false), ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE( AType, BType, CType, ALayout, BLayout, false),
ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(float, float, float, ALayout, BLayout, false), ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE( AType, BType, CType, ALayout, BLayout, false),
ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(float, float, float, ALayout, BLayout, false), ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE( AType, BType, CType, ALayout, BLayout, false),
ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(float, float, float, ALayout, BLayout, false) ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE( AType, BType, CType, ALayout, BLayout, false)
// ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(float, float, float, Row, Col, false) // ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(AType, BType, CType, ALayout, BLayout, false)
// clang-format on // clang-format on
>; >;
...@@ -63,10 +70,10 @@ template <typename ALayout, typename BLayout> ...@@ -63,10 +70,10 @@ template <typename ALayout, typename BLayout>
using thread_gemm_avx2_mxn_4x24_instances = std::tuple< using thread_gemm_avx2_mxn_4x24_instances = std::tuple<
// clang-format off // clang-format off
// FloatA FloatB FloatC ALayout BLayout NTStore // FloatA FloatB FloatC ALayout BLayout NTStore
ITERATE_THREAD_GEMM_AVX2_MXN_4X24_INSTANCE(float, float, float, ALayout, BLayout, false), ITERATE_THREAD_GEMM_AVX2_MXN_4X24_INSTANCE( AType, BType, CType, ALayout, BLayout, false),
ITERATE_THREAD_GEMM_AVX2_MXN_4X24_INSTANCE(float, float, float, ALayout, BLayout, false), ITERATE_THREAD_GEMM_AVX2_MXN_4X24_INSTANCE( AType, BType, CType, ALayout, BLayout, false),
ITERATE_THREAD_GEMM_AVX2_MXN_4X24_INSTANCE(float, float, float, ALayout, BLayout, false), ITERATE_THREAD_GEMM_AVX2_MXN_4X24_INSTANCE( AType, BType, CType, ALayout, BLayout, false),
ITERATE_THREAD_GEMM_AVX2_MXN_4X24_INSTANCE(float, float, float, ALayout, BLayout, false) ITERATE_THREAD_GEMM_AVX2_MXN_4X24_INSTANCE( AType, BType, CType, ALayout, BLayout, false)
// clang-format on // clang-format on
>; >;
...@@ -175,14 +182,9 @@ bool valid_vector(const float* ref, const float* rhs, uint32_t elem) ...@@ -175,14 +182,9 @@ bool valid_vector(const float* ref, const float* rhs, uint32_t elem)
return err == 0; return err == 0;
} }
template <typename data_type, typename ALayout, typename BLayout> template <typename FloatA, typename FloatB, typename ALayout, typename BLayout>
void ref_cpu_gemm_uk(const data_type* a, void ref_cpu_gemm_uk(
const data_type* b, const FloatA* a, const FloatB* b, float* c, float alpha, uint32_t m, uint32_t n, uint32_t k)
float* c,
float alpha,
uint32_t m,
uint32_t n,
uint32_t k)
{ {
auto a_offset = [&](uint32_t im, uint32_t ik) { auto a_offset = [&](uint32_t im, uint32_t ik) {
if constexpr(std::is_same<Row, ALayout>::value) if constexpr(std::is_same<Row, ALayout>::value)
...@@ -216,7 +218,8 @@ void ref_cpu_gemm_uk(const data_type* a, ...@@ -216,7 +218,8 @@ void ref_cpu_gemm_uk(const data_type* a,
float acc = .0f; float acc = .0f;
for(uint32_t ik = 0; ik < k; ik++) for(uint32_t ik = 0; ik < k; ik++)
{ {
acc += a[a_offset(im, ik)] * b[b_offset(ik, in)]; acc += static_cast<float>(a[a_offset(im, ik)]) *
static_cast<float>(b[b_offset(ik, in)]);
} }
acc *= alpha; acc *= alpha;
c[c_offset(im, in)] = acc; c[c_offset(im, in)] = acc;
...@@ -224,10 +227,10 @@ void ref_cpu_gemm_uk(const data_type* a, ...@@ -224,10 +227,10 @@ void ref_cpu_gemm_uk(const data_type* a,
} }
} }
template <typename data_type, typename ALayout, typename BLayout, typename ukenrel_t> template <typename FloatA, typename FloatB, typename ALayout, typename BLayout, typename ukenrel_t>
void test_ukernel(ukenrel_t uk, void test_ukernel(ukenrel_t uk,
data_type* mat_a, FloatA* mat_a,
data_type* mat_b, FloatB* mat_b,
float* mat_c, float* mat_c,
float alpha, float alpha,
uint32_t m, uint32_t m,
...@@ -239,8 +242,8 @@ void test_ukernel(ukenrel_t uk, ...@@ -239,8 +242,8 @@ void test_ukernel(ukenrel_t uk,
param.p_b = mat_b; param.p_b = mat_b;
param.p_c = mat_c; param.p_c = mat_c;
param.Kr = k; param.Kr = k;
param.lda = (std::is_same<Row, ALayout>::value ? k : m) * sizeof(data_type); param.lda = (std::is_same<Row, ALayout>::value ? k : m) * sizeof(FloatA);
param.ldb = (std::is_same<Row, BLayout>::value ? n : k * 8) * sizeof(data_type); param.ldb = (std::is_same<Row, BLayout>::value ? n : k * 8) * sizeof(FloatB);
param.ldc = n * sizeof(float); param.ldc = n * sizeof(float);
param.alpha = alpha; param.alpha = alpha;
...@@ -248,10 +251,10 @@ void test_ukernel(ukenrel_t uk, ...@@ -248,10 +251,10 @@ void test_ukernel(ukenrel_t uk,
if constexpr(std::is_same<Row, ALayout>::value && std::is_same<Row, BLayout>::value) if constexpr(std::is_same<Row, ALayout>::value && std::is_same<Row, BLayout>::value)
{ {
assert(m % uk.Mr_ == 0 && n == uk.Nr_); assert(m % uk.Mr_ == 0 && n == uk.Nr_);
data_type* p_a = mat_a; FloatA* p_a = mat_a;
float* p_c = mat_c; float* p_c = mat_c;
param.p_a = p_a; param.p_a = p_a;
param.p_c = p_c; param.p_c = p_c;
for(uint32_t i_m = 0; i_m < m; i_m += uk.Mr_) for(uint32_t i_m = 0; i_m < m; i_m += uk.Mr_)
{ {
uk.Run(&param); uk.Run(&param);
...@@ -264,15 +267,15 @@ void test_ukernel(ukenrel_t uk, ...@@ -264,15 +267,15 @@ void test_ukernel(ukenrel_t uk,
else if constexpr(std::is_same<Row, ALayout>::value && std::is_same<Col, BLayout>::value) else if constexpr(std::is_same<Row, ALayout>::value && std::is_same<Col, BLayout>::value)
{ {
assert(m % uk.Mr_ == 0 && n % uk.Nr_ == 0); assert(m % uk.Mr_ == 0 && n % uk.Nr_ == 0);
data_type* p_a = mat_a; FloatA* p_a = mat_a;
float* p_c = mat_c; float* p_c = mat_c;
param.p_a = p_a; param.p_a = p_a;
param.p_b = mat_b; param.p_b = mat_b;
param.p_c = p_c; param.p_c = p_c;
for(uint32_t i_m = 0; i_m < m; i_m += uk.Mr_) for(uint32_t i_m = 0; i_m < m; i_m += uk.Mr_)
{ {
float* p_c_n = p_c; float* p_c_n = p_c;
float* p_b_n = mat_b; FloatB* p_b_n = mat_b;
for(uint32_t i_n = 0; i_n < n; i_n += uk.Nr_) for(uint32_t i_n = 0; i_n < n; i_n += uk.Nr_)
{ {
uk.Run(&param); uk.Run(&param);
...@@ -296,10 +299,10 @@ void test_ukernel(ukenrel_t uk, ...@@ -296,10 +299,10 @@ void test_ukernel(ukenrel_t uk,
else else
{ {
assert(m % uk.Mr_ == 0 && n % uk.Nr_ == 0); assert(m % uk.Mr_ == 0 && n % uk.Nr_ == 0);
data_type* p_b = mat_b; FloatB* p_b = mat_b;
float* p_c = mat_c; float* p_c = mat_c;
param.p_b = p_b; param.p_b = p_b;
param.p_c = p_c; param.p_c = p_c;
for(uint32_t i_n = 0; i_n < n; i_n += uk.Nr_) for(uint32_t i_n = 0; i_n < n; i_n += uk.Nr_)
{ {
uk.Run(&param); uk.Run(&param);
...@@ -343,14 +346,12 @@ void test_ukernel(ukenrel_t uk, ...@@ -343,14 +346,12 @@ void test_ukernel(ukenrel_t uk,
} }
// implement small ukernel on L1 // implement small ukernel on L1
template <typename data_type, typename ALayout, typename BLayout> template <typename FloatA, typename FloatB, typename ALayout, typename BLayout>
void test_cpu_ukernel(float alpha, uint32_t m, uint32_t n, uint32_t k) void test_cpu_ukernel(float alpha, uint32_t m, uint32_t n, uint32_t k)
{ {
data_type* mat_a = FloatA* mat_a = reinterpret_cast<FloatA*>(__aligned_malloc(m * k * sizeof(FloatA), 32));
reinterpret_cast<data_type*>(__aligned_malloc(m * k * sizeof(data_type), 32)); FloatB* mat_b = reinterpret_cast<FloatB*>(__aligned_malloc(k * n * sizeof(FloatB), 32));
data_type* mat_b = float* mat_c = reinterpret_cast<float*>(__aligned_malloc(m * n * sizeof(float), 32));
reinterpret_cast<data_type*>(__aligned_malloc(k * n * sizeof(data_type), 32));
float* mat_c = reinterpret_cast<float*>(__aligned_malloc(m * n * sizeof(float), 32));
float* mat_c_ref = reinterpret_cast<float*>(__aligned_malloc(m * n * sizeof(float), 32)); float* mat_c_ref = reinterpret_cast<float*>(__aligned_malloc(m * n * sizeof(float), 32));
memset(mat_c_ref, 0, m * n * sizeof(float)); memset(mat_c_ref, 0, m * n * sizeof(float));
...@@ -358,11 +359,11 @@ void test_cpu_ukernel(float alpha, uint32_t m, uint32_t n, uint32_t k) ...@@ -358,11 +359,11 @@ void test_cpu_ukernel(float alpha, uint32_t m, uint32_t n, uint32_t k)
rand_vector(mat_a, m * k); rand_vector(mat_a, m * k);
rand_vector(mat_b, k * n); rand_vector(mat_b, k * n);
ref_cpu_gemm_uk<data_type, ALayout, BLayout>(mat_a, mat_b, mat_c_ref, alpha, m, n, k); ref_cpu_gemm_uk<FloatA, FloatB, ALayout, BLayout>(mat_a, mat_b, mat_c_ref, alpha, m, n, k);
using thread_gemm_instance = thread_gemm_avx2_mxn_6x16_instances<ALayout, BLayout>; using thread_gemm_instance = thread_gemm_avx2_mxn_6x16_instances<ALayout, BLayout>;
// using thread_gemm_instance = thread_gemm_avx2_mxn_4x24_instances<ALayout, BLayout>; // using thread_gemm_instance = thread_gemm_avx2_mxn_4x24_instances<ALayout, BLayout>;
bool found = false; bool found = false;
ck::static_for<0, std::tuple_size_v<thread_gemm_instance>, 1>{}([&](auto i) { ck::static_for<0, std::tuple_size_v<thread_gemm_instance>, 1>{}([&](auto i) {
using uk_type = std::tuple_element_t<i, thread_gemm_instance>; using uk_type = std::tuple_element_t<i, thread_gemm_instance>;
...@@ -376,7 +377,8 @@ void test_cpu_ukernel(float alpha, uint32_t m, uint32_t n, uint32_t k) ...@@ -376,7 +377,8 @@ void test_cpu_ukernel(float alpha, uint32_t m, uint32_t n, uint32_t k)
if(found) if(found)
return; return;
test_ukernel<data_type, ALayout, BLayout>(uk_type{}, mat_a, mat_b, mat_c, alpha, m, n, k); test_ukernel<FloatA, FloatB, ALayout, BLayout>(
uk_type{}, mat_a, mat_b, mat_c, alpha, m, n, k);
bool is_valid = valid_vector(mat_c_ref, mat_c, m * n); bool is_valid = valid_vector(mat_c_ref, mat_c, m * n);
printf("vald:%s\n", is_valid ? "y" : "n"); printf("vald:%s\n", is_valid ? "y" : "n");
...@@ -406,8 +408,8 @@ int main(int argc, char** argv) ...@@ -406,8 +408,8 @@ int main(int argc, char** argv)
alpha = std::atof(argv[4]); alpha = std::atof(argv[4]);
} }
dump_cache_hierarchy(); dump_cache_hierarchy();
test_cpu_ukernel<float, Row, Row>(alpha, m, n, k); test_cpu_ukernel<AType, BType, Row, Row>(alpha, m, n, k);
test_cpu_ukernel<float, Row, Col>(alpha, m, n, k); test_cpu_ukernel<AType, BType, Row, Col>(alpha, m, n, k);
test_cpu_ukernel<float, Col, Row>(alpha, m, n, k); test_cpu_ukernel<AType, BType, Col, Row>(alpha, m, n, k);
test_cpu_ukernel<float, Col, Col>(alpha, m, n, k); test_cpu_ukernel<AType, BType, Col, Col>(alpha, m, n, k);
} }
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment