gaoqiong / composable_kernel · Commits · 3cc7ac0a
"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "7ae450a1bc1d429c4ac43099a32249b79285e146"
Commit 3cc7ac0a authored Mar 27, 2022 by carlushuang
Parent: 66fd7712

add online cvt f16->f32

Showing 2 changed files with 150 additions and 90 deletions (+150 -90)
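Context for the change: "online cvt f16->f32" means the microkernels keep A/B in fp16 (2-byte) storage and widen each element to fp32 with vcvtph2ps on the fly while computing, instead of pre-converting whole buffers. A minimal sketch of the idea using F16C/AVX intrinsics (the helper name is hypothetical, not part of this commit, which implements the same thing in GAS inline assembly):

```cpp
// Sketch only; requires -mavx2 -mf16c.
#include <immintrin.h>
#include <cstdint>

// Load 8 consecutive fp16 elements of B and widen them to 8 fp32 lanes.
static inline __m256 load8_f16_as_f32(const uint16_t* p)
{
    __m128i h = _mm_loadu_si128(reinterpret_cast<const __m128i*>(p));
    return _mm256_cvtph_ps(h); // vcvtph2ps: 8 x binary16 -> 8 x binary32
}
```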
include/ck/tensor_operation/cpu/thread/threadwise_gemm_avx2.hpp   +96 -38
test/cpu_ukernel/cpu_gemm_uk.cpp                                  +54 -52
include/ck/tensor_operation/cpu/thread/threadwise_gemm_avx2.hpp
@@ -90,30 +90,59 @@ struct ThreadwiseGemmAvx2_MxN_6x16
         ".endif\n"
         ".endm\n"
+        ".macro vpbroadcastw_%= r_base, r_stride, i_scale, i_offset, xmm\n"
+        ".if \\i_scale != 0\n"
+        "vpbroadcastw \\i_offset(\\r_base, \\r_stride, \\i_scale), \\xmm\n"
+        ".else\n"
+        "vpbroadcastw \\i_offset(\\r_base), \\xmm\n"
+        ".endif\n"
+        ".endm\n"
+        ".macro vcvtph2ps_%= r_base, r_stride, i_scale, i_offset, ymm\n"
+        ".if \\i_scale != 0\n"
+        "vcvtph2ps \\i_offset(\\r_base, \\r_stride, \\i_scale), \\ymm\n"
+        ".else\n"
+        "vcvtph2ps \\i_offset(\\r_base), \\ymm\n"
+        ".endif\n"
+        ".endm\n"
         ".macro vbroadcast_a%= i_k, i_m, ymm\n" // A in rax(r8, r9), lda in rcx
-        ".if m_TransA == 0\n"
-        "vbroadcastss_%= %%rax, 0, 0, (\\i_m + \\i_k * m_Mr) * 4, \\ymm\n"
-        ".else\n"
-        ".if (\\i_m == 0) || (\\i_m == 1) || (\\i_m == 2)\n"
-        "vbroadcastss_%= %%rax, %%rcx, \\i_m, \\i_k * 4, \\ymm\n"
-        ".else\n"
-        "vbroadcastss_%= %%r8, %%rcx, \\i_m-3, \\i_k * 4, \\ymm\n"
-        ".endif\n"
-        ".endif\n"
+        ".if m_ABytes == 4\n"
+        ".if m_TransA == 0\n"
+        "vbroadcastss_%= %%rax, 0, 0, (\\i_m + \\i_k * m_Mr) * m_ABytes, \\ymm\n"
+        ".else\n"
+        ".if (\\i_m == 0) || (\\i_m == 1) || (\\i_m == 2)\n"
+        "vbroadcastss_%= %%rax, %%rcx, \\i_m, \\i_k * m_ABytes, \\ymm\n"
+        ".else\n"
+        "vbroadcastss_%= %%r8, %%rcx, \\i_m-3, \\i_k * m_ABytes, \\ymm\n"
+        ".endif\n"
+        ".endif\n"
+        ".else\n"
+        ".if m_TransA == 0\n"
+        "vpbroadcastw_%= %%rax, 0, 0, (\\i_m + \\i_k * m_Mr) * m_ABytes, %%xmm15\n"
+        ".else\n"
+        ".if (\\i_m == 0) || (\\i_m == 1) || (\\i_m == 2)\n"
+        "vpbroadcastw_%= %%rax, %%rcx, \\i_m, \\i_k * m_ABytes, %%xmm15\n"
+        ".else\n"
+        "vpbroadcastw_%= %%r8, %%rcx, \\i_m-3, \\i_k * m_ABytes, %%xmm15\n"
+        ".endif\n"
+        ".endif\n"
+        "vcvtph2ps %%xmm15, \\ymm\n"
+        ".endif\n"
         ".endm\n"
         ".macro vload_b%= i_k, i_n, ymm\n" // B in rbx, lda in rdx, i_n should be 0, 1
         ".if m_BBytes == 4\n"
         ".if m_TransB == 0\n"
-        "vmovups_%= %%rbx, %%rdx, \\i_n, \\i_k*4*8, \\ymm\n"
+        "vmovups_%= %%rbx, %%rdx, \\i_n, \\i_k*m_BBytes*8, \\ymm\n"
         ".else\n"
-        "vmovups_%= %%rbx, 0, 0, (\\i_k*m_Nr + \\i_n*8)*4, \\ymm\n"
+        "vmovups_%= %%rbx, 0, 0, (\\i_k*m_Nr + \\i_n*8)*m_BBytes, \\ymm\n"
         ".endif\n"
         ".else\n"
         ".if m_TransB == 0\n"
-        "vmovups_%= %%rbx, %%rdx, \\i_n, \\i_k*4*8, \\ymm\n"
+        "vcvtph2ps_%= %%rbx, %%rdx, \\i_n, \\i_k*m_BBytes*8, \\ymm\n"
         ".else\n"
-        "vmovups_%= %%rbx, 0, 0, (\\i_k*m_Nr + \\i_n*8)*4, \\ymm\n"
+        "vcvtph2ps_%= %%rbx, 0, 0, (\\i_k*m_Nr + \\i_n*8)*m_BBytes, \\ymm\n"
         ".endif\n"
         ".endif\n"
         ".endm\n"
@@ -168,15 +197,15 @@ struct ThreadwiseGemmAvx2_MxN_6x16
         ".endr\n"
         ".if m_TransA != 0\n"
-        " lea 4*4(%%rax), %%rax\n"
+        " lea 4*m_ABytes(%%rax), %%rax\n"
-        ".if m_Mr > 3\n lea 4*4(%%r8), %%r8\n .endif\n"
+        ".if m_Mr > 3\n lea 4*m_ABytes(%%r8), %%r8\n .endif\n"
         ".else\n"
-        " lea m_Mr * 4 * 4(%%rax), %%rax\n"
+        " lea m_Mr * 4 * m_ABytes(%%rax), %%rax\n"
         ".endif\n"
         ".if m_TransB != 0\n"
-        " lea m_Nr * 4 * 4(%%rbx), %%rbx\n"
+        " lea m_Nr * 4 * m_BBytes(%%rbx), %%rbx\n"
         ".else\n"
-        " lea 8 * 4 * 4(%%rbx), %%rbx\n"
+        " lea 8 * 4 * m_BBytes(%%rbx), %%rbx\n"
         ".endif\n"
         "sub $4, %%rsi\n"
@@ -210,15 +239,15 @@ struct ThreadwiseGemmAvx2_MxN_6x16
         ".if (m_Mr > 5) && (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm15, %%ymm11\n .endif\n" // 5x1
         ".if m_TransA != 0\n"
-        " lea 4(%%rax), %%rax\n"
+        " lea m_ABytes(%%rax), %%rax\n"
-        ".if m_Mr > 3\n lea 4(%%r8), %%r8\n .endif\n"
+        ".if m_Mr > 3\n lea m_ABytes(%%r8), %%r8\n .endif\n"
         ".else\n"
-        " lea m_Mr * 4(%%rax), %%rax\n"
+        " lea m_Mr * m_ABytes(%%rax), %%rax\n"
         ".endif\n"
         ".if m_TransB != 0\n"
-        " lea m_Nr * 4(%%rbx), %%rbx\n"
+        " lea m_Nr * m_BBytes(%%rbx), %%rbx\n"
         ".else\n"
-        " lea 8*4(%%rbx), %%rbx\n"
+        " lea 8*m_BBytes(%%rbx), %%rbx\n"
         ".endif\n"
         "sub $1, %%rsi\n"
@@ -380,30 +409,59 @@ struct ThreadwiseGemmAvx2_MxN_4x24
         ".endif\n"
         ".endm\n"
+        ".macro vpbroadcastw_%= r_base, r_stride, i_scale, i_offset, xmm\n"
+        ".if \\i_scale != 0\n"
+        "vpbroadcastw \\i_offset(\\r_base, \\r_stride, \\i_scale), \\xmm\n"
+        ".else\n"
+        "vpbroadcastw \\i_offset(\\r_base), \\xmm\n"
+        ".endif\n"
+        ".endm\n"
+        ".macro vcvtph2ps_%= r_base, r_stride, i_scale, i_offset, ymm\n"
+        ".if \\i_scale != 0\n"
+        "vcvtph2ps \\i_offset(\\r_base, \\r_stride, \\i_scale), \\ymm\n"
+        ".else\n"
+        "vcvtph2ps \\i_offset(\\r_base), \\ymm\n"
+        ".endif\n"
+        ".endm\n"
         ".macro vbroadcast_a%= i_k, i_m, ymm\n" // A in rax(r8), lda in rcx
-        ".if m_TransA == 0\n"
-        "vbroadcastss_%= %%rax, 0, 0, (\\i_m + \\i_k * m_Mr) * 4, \\ymm\n"
-        ".else\n"
-        ".if (\\i_m == 0) || (\\i_m == 1)\n"
-        "vbroadcastss_%= %%rax, %%rcx, \\i_m, \\i_k * 4, \\ymm\n"
-        ".else\n"
-        "vbroadcastss_%= %%r8, %%rcx, \\i_m-2, \\i_k * 4, \\ymm\n"
-        ".endif\n"
-        ".endif\n"
+        ".if m_ABytes == 4\n"
+        ".if m_TransA == 0\n"
+        "vbroadcastss_%= %%rax, 0, 0, (\\i_m + \\i_k * m_Mr) * m_ABytes, \\ymm\n"
+        ".else\n"
+        ".if (\\i_m == 0) || (\\i_m == 1)\n"
+        "vbroadcastss_%= %%rax, %%rcx, \\i_m, \\i_k * m_ABytes, \\ymm\n"
+        ".else\n"
+        "vbroadcastss_%= %%r8, %%rcx, \\i_m-2, \\i_k * m_ABytes, \\ymm\n"
+        ".endif\n"
+        ".endif\n"
+        ".else\n"
+        ".if m_TransA == 0\n"
+        "vpbroadcastw_%= %%rax, 0, 0, (\\i_m + \\i_k * m_Mr) * m_ABytes, %%xmm15\n"
+        ".else\n"
+        ".if (\\i_m == 0) || (\\i_m == 1)\n"
+        "vpbroadcastw_%= %%rax, %%rcx, \\i_m, \\i_k * m_ABytes, %%xmm15\n"
+        ".else\n"
+        "vpbroadcastw_%= %%r8, %%rcx, \\i_m-2, \\i_k * m_ABytes, %%xmm15\n"
+        ".endif\n"
+        ".endif\n"
+        "vcvtph2ps %%xmm15, \\ymm\n"
+        ".endif\n"
         ".endm\n"
         ".macro vload_b%= i_k, i_n, ymm\n" // B in rbx, lda in rdx, i_n should be 0, 1, 2
         ".if m_BBytes == 4\n"
         ".if m_TransB == 0\n"
-        "vmovups_%= %%rbx, %%rdx, \\i_n, \\i_k*4*8, \\ymm\n"
+        "vmovups_%= %%rbx, %%rdx, \\i_n, \\i_k*m_BBytes*8, \\ymm\n"
         ".else\n"
-        "vmovups_%= %%rbx, 0, 0, (\\i_k*m_Nr + \\i_n*8)*4, \\ymm\n"
+        "vmovups_%= %%rbx, 0, 0, (\\i_k*m_Nr + \\i_n*8)*m_BBytes, \\ymm\n"
         ".endif\n"
         ".else\n"
         ".if m_TransB == 0\n"
-        "vmovups_%= %%rbx, %%rdx, \\i_n, \\i_k*4*8, \\ymm\n"
+        "vcvtph2ps_%= %%rbx, %%rdx, \\i_n, \\i_k*m_BBytes*8, \\ymm\n"
         ".else\n"
-        "vmovups_%= %%rbx, 0, 0, (\\i_k*m_Nr + \\i_n*8)*4, \\ymm\n"
+        "vcvtph2ps_%= %%rbx, 0, 0, (\\i_k*m_Nr + \\i_n*8)*m_BBytes, \\ymm\n"
         ".endif\n"
         ".endif\n"
         ".endm\n"
@@ -457,15 +515,15 @@ struct ThreadwiseGemmAvx2_MxN_4x24
         ".endr\n"
         ".if m_TransA != 0\n"
-        " lea 4*4(%%rax), %%rax\n"
+        " lea 4*m_ABytes(%%rax), %%rax\n"
-        ".if m_Mr > 2\n lea 4*4(%%r8), %%r8\n .endif\n"
+        ".if m_Mr > 2\n lea 4*m_ABytes(%%r8), %%r8\n .endif\n"
         ".else\n"
-        " lea m_Mr * 4 * 4(%%rax), %%rax\n"
+        " lea m_Mr * 4 * m_ABytes(%%rax), %%rax\n"
         ".endif\n"
         ".if m_TransB != 0\n"
-        " lea m_Nr * 4 * 4(%%rbx), %%rbx\n"
+        " lea m_Nr * 4 * m_BBytes(%%rbx), %%rbx\n"
         ".else\n"
-        " lea 8 * 4 * 4(%%rbx), %%rbx\n"
+        " lea 8 * 4 * m_BBytes(%%rbx), %%rbx\n"
         ".endif\n"
         "sub $4, %%rsi\n"
@@ -499,15 +557,15 @@ struct ThreadwiseGemmAvx2_MxN_4x24
         ".if (m_Mr > 3) && (m_Nr >16)\n vfmadd231ps %%ymm14, %%ymm15, %%ymm11\n .endif\n" // 3x2
         ".if m_TransA != 0\n"
-        " lea 4(%%rax), %%rax\n"
+        " lea m_ABytes(%%rax), %%rax\n"
-        ".if m_Mr > 3\n lea 4(%%r8), %%r8\n .endif\n"
+        ".if m_Mr > 3\n lea m_ABytes(%%r8), %%r8\n .endif\n"
         ".else\n"
-        " lea m_Mr * 4(%%rax), %%rax\n"
+        " lea m_Mr * m_ABytes(%%rax), %%rax\n"
         ".endif\n"
         ".if m_TransB != 0\n"
-        " lea m_Nr * 4(%%rbx), %%rbx\n"
+        " lea m_Nr * m_BBytes(%%rbx), %%rbx\n"
         ".else\n"
-        " lea 8*4(%%rbx), %%rbx\n"
+        " lea 8*m_BBytes(%%rbx), %%rbx\n"
         ".endif\n"
         "sub $1, %%rsi\n"
test/cpu_ukernel/cpu_gemm_uk.cpp
@@ -7,6 +7,7 @@
 #include <tuple>
 #include <memory>
 #include <chrono>
+#include <half.hpp>
 #include "config.hpp"
 #include "print.hpp"
 #include "cpuid.hpp"
@@ -26,7 +27,7 @@
     ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 2, 8, TA, TB, NT>, \
     ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 1, 8, TA, TB, NT>
-// #define ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(FA, FB, FC, TA, TB, NT) \
+//#define ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(FA, FB, FC, TA, TB, NT) \
 //    ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 6, 16, TA, TB, NT>
 #define ITERATE_THREAD_GEMM_AVX2_MXN_4X24_INSTANCE(FA, FB, FC, TA, TB, NT) \
@@ -46,16 +47,22 @@
 using Row = ck::tensor_layout::gemm::RowMajor;
 using Col = ck::tensor_layout::gemm::ColumnMajor;
+// using AType = half_float::half;
+// using BType = half_float::half;
+using AType = float;
+using BType = float;
+using CType = float;
 template <typename ALayout, typename BLayout>
 using thread_gemm_avx2_mxn_6x16_instances = std::tuple<
     // clang-format off
     // FloatA FloatB FloatC ALayout BLayout NTStore
-    ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(float, float, float, ALayout, BLayout, false),
-    ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(float, float, float, ALayout, BLayout, false),
-    ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(float, float, float, ALayout, BLayout, false),
-    ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(float, float, float, ALayout, BLayout, false)
-    // ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(float, float, float, Row, Col, false)
+    ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(AType, BType, CType, ALayout, BLayout, false),
+    ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(AType, BType, CType, ALayout, BLayout, false),
+    ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(AType, BType, CType, ALayout, BLayout, false),
+    ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(AType, BType, CType, ALayout, BLayout, false)
+    // ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(AType, BType, CType, ALayout, BLayout, false)
     // clang-format on
     >;
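With the typedefs above, switching the whole test between fp32 and fp16 inputs is a two-line edit; C stays fp32 because the kernel accumulates and stores in fp32. The fp16 configuration, taken directly from the commented lines in the diff:

```cpp
// fp16-input configuration: uncomment the half typedefs, drop the float ones
// (half.hpp is already included above).
using AType = half_float::half;
using BType = half_float::half;
using CType = float; // output and accumulation remain fp32
```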
@@ -63,10 +70,10 @@ template <typename ALayout, typename BLayout>
 using thread_gemm_avx2_mxn_4x24_instances = std::tuple<
     // clang-format off
     // FloatA FloatB FloatC ALayout BLayout NTStore
-    ITERATE_THREAD_GEMM_AVX2_MXN_4X24_INSTANCE(float, float, float, ALayout, BLayout, false),
-    ITERATE_THREAD_GEMM_AVX2_MXN_4X24_INSTANCE(float, float, float, ALayout, BLayout, false),
-    ITERATE_THREAD_GEMM_AVX2_MXN_4X24_INSTANCE(float, float, float, ALayout, BLayout, false),
-    ITERATE_THREAD_GEMM_AVX2_MXN_4X24_INSTANCE(float, float, float, ALayout, BLayout, false)
+    ITERATE_THREAD_GEMM_AVX2_MXN_4X24_INSTANCE(AType, BType, CType, ALayout, BLayout, false),
+    ITERATE_THREAD_GEMM_AVX2_MXN_4X24_INSTANCE(AType, BType, CType, ALayout, BLayout, false),
+    ITERATE_THREAD_GEMM_AVX2_MXN_4X24_INSTANCE(AType, BType, CType, ALayout, BLayout, false),
+    ITERATE_THREAD_GEMM_AVX2_MXN_4X24_INSTANCE(AType, BType, CType, ALayout, BLayout, false)
     // clang-format on
     >;
@@ -175,14 +182,9 @@ bool valid_vector(const float* ref, const float* rhs, uint32_t elem)
     return err == 0;
 }
-template <typename data_type, typename ALayout, typename BLayout>
-void ref_cpu_gemm_uk(const data_type* a,
-                     const data_type* b,
-                     float* c,
-                     float alpha,
-                     uint32_t m,
-                     uint32_t n,
-                     uint32_t k)
+template <typename FloatA, typename FloatB, typename ALayout, typename BLayout>
+void ref_cpu_gemm_uk(
+    const FloatA* a, const FloatB* b, float* c, float alpha, uint32_t m, uint32_t n, uint32_t k)
 {
     auto a_offset = [&](uint32_t im, uint32_t ik) {
         if constexpr(std::is_same<Row, ALayout>::value)
@@ -216,7 +218,8 @@ void ref_cpu_gemm_uk(const data_type* a,
             float acc = .0f;
             for(uint32_t ik = 0; ik < k; ik++)
             {
-                acc += a[a_offset(im, ik)] * b[b_offset(ik, in)];
+                acc += static_cast<float>(a[a_offset(im, ik)]) *
+                       static_cast<float>(b[b_offset(ik, in)]);
             }
             acc *= alpha;
             c[c_offset(im, in)] = acc;
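The static_casts make the reference GEMM multiply and accumulate in fp32 regardless of the storage type of A and B, matching what the assembly kernel does after vcvtph2ps; without widening, half_float's operators would round each product back to binary16 first. A minimal illustration (assumes the half library's default behavior of rounding arithmetic results to half):

```cpp
#include <half.hpp>

// fp32 multiply-accumulate over fp16 storage, as the patched reference does.
float dot_widened(const half_float::half* a, const half_float::half* b, int k)
{
    float acc = 0.0f;
    for(int i = 0; i < k; ++i)
        acc += static_cast<float>(a[i]) * static_cast<float>(b[i]);
    return acc;
}
```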
@@ -224,10 +227,10 @@ void ref_cpu_gemm_uk(const data_type* a,
     }
 }
-template <typename data_type, typename ALayout, typename BLayout, typename ukenrel_t>
+template <typename FloatA, typename FloatB, typename ALayout, typename BLayout, typename ukenrel_t>
 void test_ukernel(ukenrel_t uk,
-                  data_type* mat_a,
-                  data_type* mat_b,
+                  FloatA* mat_a,
+                  FloatB* mat_b,
                   float* mat_c,
                   float alpha,
                   uint32_t m,
@@ -239,8 +242,8 @@ void test_ukernel(ukenrel_t uk,
     param.p_b = mat_b;
     param.p_c = mat_c;
     param.Kr = k;
-    param.lda = (std::is_same<Row, ALayout>::value ? k : m) * sizeof(data_type);
-    param.ldb = (std::is_same<Row, BLayout>::value ? n : k * 8) * sizeof(data_type);
+    param.lda = (std::is_same<Row, ALayout>::value ? k : m) * sizeof(FloatA);
+    param.ldb = (std::is_same<Row, BLayout>::value ? n : k * 8) * sizeof(FloatB);
     param.ldc = n * sizeof(float);
     param.alpha = alpha;
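Note that lda/ldb are byte strides, which is why they must scale with sizeof(FloatA)/sizeof(FloatB): a row-major fp16 A with k columns has half the leading dimension in bytes of its fp32 counterpart. For example (values hypothetical):

```cpp
#include <cstdint>
#include <cstdio>

int main()
{
    uint32_t k = 64; // hypothetical row length of row-major A
    std::printf("lda fp32 = %u bytes\n", k * uint32_t(sizeof(float)));    // 256
    std::printf("lda fp16 = %u bytes\n", k * uint32_t(sizeof(uint16_t))); // 128
    return 0;
}
```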
@@ -248,10 +251,10 @@ void test_ukernel(ukenrel_t uk,
     if constexpr(std::is_same<Row, ALayout>::value && std::is_same<Row, BLayout>::value)
     {
         assert(m % uk.Mr_ == 0 && n == uk.Nr_);
-        data_type* p_a = mat_a;
+        FloatA* p_a = mat_a;
         float* p_c = mat_c;
         param.p_a = p_a;
         param.p_c = p_c;
         for(uint32_t i_m = 0; i_m < m; i_m += uk.Mr_)
         {
             uk.Run(&param);
@@ -264,15 +267,15 @@ void test_ukernel(ukenrel_t uk,
     else if constexpr(std::is_same<Row, ALayout>::value && std::is_same<Col, BLayout>::value)
     {
         assert(m % uk.Mr_ == 0 && n % uk.Nr_ == 0);
-        data_type* p_a = mat_a;
+        FloatA* p_a = mat_a;
         float* p_c = mat_c;
         param.p_a = p_a;
         param.p_b = mat_b;
         param.p_c = p_c;
         for(uint32_t i_m = 0; i_m < m; i_m += uk.Mr_)
         {
             float* p_c_n = p_c;
-            float* p_b_n = mat_b;
+            FloatB* p_b_n = mat_b;
             for(uint32_t i_n = 0; i_n < n; i_n += uk.Nr_)
             {
                 uk.Run(&param);
@@ -296,10 +299,10 @@ void test_ukernel(ukenrel_t uk,
     else
     {
         assert(m % uk.Mr_ == 0 && n % uk.Nr_ == 0);
-        data_type* p_b = mat_b;
+        FloatB* p_b = mat_b;
         float* p_c = mat_c;
         param.p_b = p_b;
         param.p_c = p_c;
         for(uint32_t i_n = 0; i_n < n; i_n += uk.Nr_)
         {
             uk.Run(&param);
@@ -343,14 +346,12 @@ void test_ukernel(ukenrel_t uk,
 }
 // implement small ukernel on L1
-template <typename data_type, typename ALayout, typename BLayout>
+template <typename FloatA, typename FloatB, typename ALayout, typename BLayout>
 void test_cpu_ukernel(float alpha, uint32_t m, uint32_t n, uint32_t k)
 {
-    data_type* mat_a =
-        reinterpret_cast<data_type*>(__aligned_malloc(m * k * sizeof(data_type), 32));
-    data_type* mat_b =
-        reinterpret_cast<data_type*>(__aligned_malloc(k * n * sizeof(data_type), 32));
+    FloatA* mat_a = reinterpret_cast<FloatA*>(__aligned_malloc(m * k * sizeof(FloatA), 32));
+    FloatB* mat_b = reinterpret_cast<FloatB*>(__aligned_malloc(k * n * sizeof(FloatB), 32));
     float* mat_c = reinterpret_cast<float*>(__aligned_malloc(m * n * sizeof(float), 32));
     float* mat_c_ref = reinterpret_cast<float*>(__aligned_malloc(m * n * sizeof(float), 32));
     memset(mat_c_ref, 0, m * n * sizeof(float));
@@ -358,11 +359,11 @@ void test_cpu_ukernel(float alpha, uint32_t m, uint32_t n, uint32_t k)
     rand_vector(mat_a, m * k);
     rand_vector(mat_b, k * n);
-    ref_cpu_gemm_uk<data_type, ALayout, BLayout>(mat_a, mat_b, mat_c_ref, alpha, m, n, k);
+    ref_cpu_gemm_uk<FloatA, FloatB, ALayout, BLayout>(mat_a, mat_b, mat_c_ref, alpha, m, n, k);
     using thread_gemm_instance = thread_gemm_avx2_mxn_6x16_instances<ALayout, BLayout>;
     // using thread_gemm_instance = thread_gemm_avx2_mxn_4x24_instances<ALayout, BLayout>;
     bool found = false;
     ck::static_for<0, std::tuple_size_v<thread_gemm_instance>, 1>{}([&](auto i) {
         using uk_type = std::tuple_element_t<i, thread_gemm_instance>;
@@ -376,7 +377,8 @@ void test_cpu_ukernel(float alpha, uint32_t m, uint32_t n, uint32_t k)
         if(found)
             return;
-        test_ukernel<data_type, ALayout, BLayout>(uk_type{}, mat_a, mat_b, mat_c, alpha, m, n, k);
+        test_ukernel<FloatA, FloatB, ALayout, BLayout>(
+            uk_type{}, mat_a, mat_b, mat_c, alpha, m, n, k);
         bool is_valid = valid_vector(mat_c_ref, mat_c, m * n);
         printf("vald:%s\n", is_valid ? "y" : "n");
@@ -406,8 +408,8 @@ int main(int argc, char** argv)
         alpha = std::atof(argv[4]);
     }
     dump_cache_hierarchy();
-    test_cpu_ukernel<float, Row, Row>(alpha, m, n, k);
-    test_cpu_ukernel<float, Row, Col>(alpha, m, n, k);
-    test_cpu_ukernel<float, Col, Row>(alpha, m, n, k);
-    test_cpu_ukernel<float, Col, Col>(alpha, m, n, k);
+    test_cpu_ukernel<AType, BType, Row, Row>(alpha, m, n, k);
+    test_cpu_ukernel<AType, BType, Row, Col>(alpha, m, n, k);
+    test_cpu_ukernel<AType, BType, Col, Row>(alpha, m, n, k);
+    test_cpu_ukernel<AType, BType, Col, Col>(alpha, m, n, k);
 }