gaoqiong / composable_kernel
Commit 66fd7712, authored Mar 27, 2022 by carlushuang
add 4x24 ukernel
Parent: 3a4df3da
Showing 2 changed files with 312 additions and 5 deletions
include/ck/tensor_operation/cpu/thread/threadwise_gemm_avx2.hpp   +286 -0
test/cpu_ukernel/cpu_gemm_uk.cpp                                  +26  -5
include/ck/tensor_operation/cpu/thread/threadwise_gemm_avx2.hpp
...
@@ -300,6 +300,292 @@ struct ThreadwiseGemmAvx2_MxN_6x16
    }
};
template <typename FloatA,
          typename FloatB,
          typename FloatC,
          index_t Mr,
          index_t Nr,
          typename ALayout, // default is k*m, trans->m*k
          typename BLayout, // default is n/8*k*n8, trans->k*n
          bool NonTemporalStore>
struct ThreadwiseGemmAvx2_MxN_4x24
{
    using ALayout_ = ALayout;
    using BLayout_ = BLayout;

    static constexpr auto Mr_               = Mr;
    static constexpr auto Nr_               = Nr;
    static constexpr auto NonTemporalStore_ = NonTemporalStore;

    __host__ constexpr ThreadwiseGemmAvx2_MxN_4x24()
    {
        static_assert(Mr <= 4 && Mr >= 1 && (Nr == 8 || Nr == 16 || Nr == 24),
                      "wrong! Mr x Nr not valid");
    }

    __host__ static void Run(ThreadwiseGemmParam* param)
    {
        /* 4x24 ukernel
         *
         *          Mat_B
         *         |ymm12   |ymm13   |ymm14   |
         * Mat_A   +--------+--------+--------+
         * ymm15   |ymm0    |ymm1    |ymm2    |
         *         |ymm3    |ymm4    |ymm5    |
         *         |ymm6    |ymm7    |ymm8    |
         *         |ymm9    |ymm10   |ymm11   |
         *
         * ALayout:ColumnMajor (k*m), lda not needed
         * ALayout:RowMajor (m*k), lda = k
         * BLayout:ColumnMajor (n/8*k*n8), ldb = k*n8. At least this should be 8 continuous n
         *   for a ymm register
         * BLayout:RowMajor (k*n), ldb not needed
         *
         * lda/ldb/ldc all in unit of byte
         *
         */
        // clang-format off
        __asm__ __volatile__ (
            "L_GemmAvx2_MxN_4x24_Entry%=:\n"
            ".set   m_Mr,       %c[m_Mr]\n"
            ".set   m_Nr,       %c[m_Nr]\n"
            ".set   m_TransA,   %c[m_TransA]\n"
            ".set   m_TransB,   %c[m_TransB]\n"
            ".set   m_NTStore,  %c[m_NTStore]\n"
            ".set   m_ABytes,   %c[m_ABytes]\n"
            ".set   m_BBytes,   %c[m_BBytes]\n"
            ".set   m_CBytes,   %c[m_CBytes]\n"

            "movq   (%[m_param]),   %%rax\n"    // p_a
            "movq  8(%[m_param]),   %%rbx\n"    // p_b
            "movq 24(%[m_param]),   %%rsi\n"    // Kr
            ".if m_TransA != 0\n"
            "movq 32(%[m_param]),   %%rcx\n"    // lda
            ".endif\n"
            ".if m_TransB == 0\n"
            "movq 40(%[m_param]),   %%rdx\n"    // ldb
            ".endif\n"

            ".macro vbroadcastss_%= r_base, r_stride, i_scale, i_offset, ymm\n"
            ".if \\i_scale != 0\n"
            "vbroadcastss \\i_offset(\\r_base, \\r_stride, \\i_scale), \\ymm\n"
            ".else\n"
            "vbroadcastss \\i_offset(\\r_base), \\ymm\n"
            ".endif\n"
            ".endm\n"

            ".macro vmovups_%= r_base, r_stride, i_scale, i_offset, ymm\n"
            ".if \\i_scale != 0\n"
            "vmovups \\i_offset(\\r_base, \\r_stride, \\i_scale), \\ymm\n"
            ".else\n"
            "vmovups \\i_offset(\\r_base), \\ymm\n"
            ".endif\n"
            ".endm\n"

            ".macro vbroadcast_a%= i_k, i_m, ymm\n"     // A in rax(r8), lda in rcx
            ".if m_TransA == 0\n"
            "vbroadcastss_%= %%rax, 0, 0, (\\i_m + \\i_k * m_Mr) * 4, \\ymm\n"
            ".else\n"
            ".if (\\i_m == 0) || (\\i_m == 1)\n"
            "vbroadcastss_%= %%rax, %%rcx, \\i_m, \\i_k * 4, \\ymm\n"
            ".else\n"
            "vbroadcastss_%= %%r8, %%rcx, \\i_m-2, \\i_k * 4, \\ymm\n"
            ".endif\n"
            ".endif\n"
            ".endm\n"

            ".macro vload_b%= i_k, i_n, ymm\n"          // B in rbx, lda in rdx, i_n should be 0, 1, 2
            ".if m_BBytes == 4\n"
            ".if m_TransB == 0\n"
            "vmovups_%= %%rbx, %%rdx, \\i_n, \\i_k*4*8, \\ymm\n"
            ".else\n"
            "vmovups_%= %%rbx, 0, 0, (\\i_k*m_Nr + \\i_n*8)*4, \\ymm\n"
            ".endif\n"
            ".else\n"
            ".if m_TransB == 0\n"
            "vmovups_%= %%rbx, %%rdx, \\i_n, \\i_k*4*8, \\ymm\n"
            ".else\n"
            "vmovups_%= %%rbx, 0, 0, (\\i_k*m_Nr + \\i_n*8)*4, \\ymm\n"
            ".endif\n"
            ".endif\n"
            ".endm\n"

            "                               vxorps %%ymm0,  %%ymm0,  %%ymm0\n"
            ".if (m_Nr > 8)\n               vxorps %%ymm1,  %%ymm1,  %%ymm1\n .endif\n"
            ".if (m_Nr >16)\n               vxorps %%ymm2,  %%ymm2,  %%ymm2\n .endif\n"
            ".if (m_Mr > 1)\n               vxorps %%ymm3,  %%ymm3,  %%ymm3\n .endif\n"
            ".if (m_Mr > 1) && (m_Nr > 8)\n vxorps %%ymm4,  %%ymm4,  %%ymm4\n .endif\n"
            ".if (m_Mr > 1) && (m_Nr >16)\n vxorps %%ymm5,  %%ymm5,  %%ymm5\n .endif\n"
            ".if (m_Mr > 2)\n               vxorps %%ymm6,  %%ymm6,  %%ymm6\n .endif\n"
            ".if (m_Mr > 2) && (m_Nr > 8)\n vxorps %%ymm7,  %%ymm7,  %%ymm7\n .endif\n"
            ".if (m_Mr > 2) && (m_Nr >16)\n vxorps %%ymm8,  %%ymm8,  %%ymm8\n .endif\n"
            ".if (m_Mr > 3)\n               vxorps %%ymm9,  %%ymm9,  %%ymm9\n .endif\n"
            ".if (m_Mr > 3) && (m_Nr > 8)\n vxorps %%ymm10, %%ymm10, %%ymm10\n .endif\n"
            ".if (m_Mr > 3) && (m_Nr > 8)\n vxorps %%ymm11, %%ymm11, %%ymm11\n .endif\n"

            ".if m_TransA != 0\n"
            ".if m_Mr > 2\n"
            "lea (%%rax, %%rcx, 2), %%r8\n"
            ".endif\n"
            ".endif\n"

            "cmp $4, %%rsi\n"
            "jl  L_GemmAvx2_MxN_4x24_K_Loop_Remain%=\n"

            "L_GemmAvx2_MxN_4x24_K_Loop_Start%=:\n"
            ".irp i_k, 0, 1, 2, 3\n"
            "                               vload_b%= \\i_k, 0, %%ymm12\n"                          // B
            ".if (m_Nr > 8)\n               vload_b%= \\i_k, 1, %%ymm13\n .endif\n"                 // B
            ".if (m_Nr >16)\n               vload_b%= \\i_k, 2, %%ymm14\n .endif\n"                 // B
            "                               vbroadcast_a%= \\i_k, 0, %%ymm15\n"                     // A broadcast 0
            "                               vfmadd231ps %%ymm12, %%ymm15, %%ymm0\n"                 // 0x0
            ".if (m_Nr > 8)\n               vfmadd231ps %%ymm13, %%ymm15, %%ymm1\n .endif\n"        // 0x1
            ".if (m_Nr >16)\n               vfmadd231ps %%ymm14, %%ymm15, %%ymm2\n .endif\n"        // 0x2
            ".if (m_Mr > 1)\n               vbroadcast_a%= \\i_k, 1, %%ymm15\n .endif\n"            // A broadcast 1
            ".if (m_Mr > 1)\n               vfmadd231ps %%ymm12, %%ymm15, %%ymm3\n .endif\n"        // 1x0
            ".if (m_Mr > 1) && (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm15, %%ymm4\n .endif\n"        // 1x1
            ".if (m_Mr > 1) && (m_Nr >16)\n vfmadd231ps %%ymm14, %%ymm15, %%ymm5\n .endif\n"        // 1x2
            ".if (m_Mr > 2)\n               vbroadcast_a%= \\i_k, 2, %%ymm15\n .endif\n"            // A broadcast 2
            ".if (m_Mr > 2)\n               vfmadd231ps %%ymm12, %%ymm15, %%ymm6\n .endif\n"        // 2x0
            ".if (m_Mr > 2) && (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm15, %%ymm7\n .endif\n"        // 2x1
            ".if (m_Mr > 2) && (m_Nr >16)\n vfmadd231ps %%ymm14, %%ymm15, %%ymm8\n .endif\n"        // 2x2
            ".if (m_Mr > 3)\n               vbroadcast_a%= \\i_k, 3, %%ymm15\n .endif\n"            // A broadcast 3
            ".if (m_Mr > 3)\n               vfmadd231ps %%ymm12, %%ymm15, %%ymm9\n .endif\n"        // 3x0
            ".if (m_Mr > 3) && (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm15, %%ymm10\n .endif\n"       // 3x1
            ".if (m_Mr > 3) && (m_Nr >16)\n vfmadd231ps %%ymm14, %%ymm15, %%ymm11\n .endif\n"       // 3x2
            ".endr\n"

            ".if m_TransA != 0\n"
            "   lea 4*4(%%rax), %%rax\n"
            ".if m_Mr > 2\n lea 4*4(%%r8), %%r8\n .endif\n"
            ".else\n"
            "   lea m_Mr * 4 * 4(%%rax), %%rax\n"
            ".endif\n"

            ".if m_TransB != 0\n"
            "   lea m_Nr * 4 * 4(%%rbx), %%rbx\n"
            ".else\n"
            "   lea 8 * 4 * 4(%%rbx), %%rbx\n"
            ".endif\n"

            "sub $4, %%rsi\n"
            "cmp $4, %%rsi\n"
            "jge L_GemmAvx2_MxN_4x24_K_Loop_Start%=\n"

            "testq %%rsi, %%rsi\n"
            "je L_GemmAvx2_MxN_4x24_K_Loop_End%=\n"

            "L_GemmAvx2_MxN_4x24_K_Loop_Remain%=:\n"
            "                               vload_b%= 0, 0, %%ymm12\n"                              // B
            ".if (m_Nr > 8)\n               vload_b%= 0, 1, %%ymm13\n .endif\n"                     // B
            ".if (m_Nr >16)\n               vload_b%= 0, 2, %%ymm14\n .endif\n"                     // B
            "                               vbroadcast_a%= 0, 0, %%ymm15\n"                         // A broadcast 0
            "                               vfmadd231ps %%ymm12, %%ymm15, %%ymm0\n"                 // 0x0
            ".if (m_Nr > 8)\n               vfmadd231ps %%ymm13, %%ymm15, %%ymm1\n .endif\n"        // 0x1
            ".if (m_Nr >16)\n               vfmadd231ps %%ymm14, %%ymm15, %%ymm2\n .endif\n"        // 0x2
            ".if (m_Mr > 1)\n               vbroadcast_a%= 0, 1, %%ymm15\n .endif\n"                // A broadcast 1
            ".if (m_Mr > 1)\n               vfmadd231ps %%ymm12, %%ymm15, %%ymm3\n .endif\n"        // 1x0
            ".if (m_Mr > 1) && (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm15, %%ymm4\n .endif\n"        // 1x1
            ".if (m_Mr > 1) && (m_Nr >16)\n vfmadd231ps %%ymm14, %%ymm15, %%ymm5\n .endif\n"        // 1x2
            ".if (m_Mr > 2)\n               vbroadcast_a%= 0, 2, %%ymm15\n .endif\n"                // A broadcast 2
            ".if (m_Mr > 2)\n               vfmadd231ps %%ymm12, %%ymm15, %%ymm6\n .endif\n"        // 2x0
            ".if (m_Mr > 2) && (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm15, %%ymm7\n .endif\n"        // 2x1
            ".if (m_Mr > 2) && (m_Nr >16)\n vfmadd231ps %%ymm14, %%ymm15, %%ymm8\n .endif\n"        // 2x2
            ".if (m_Mr > 3)\n               vbroadcast_a%= 0, 3, %%ymm15\n .endif\n"                // A broadcast 3
            ".if (m_Mr > 3)\n               vfmadd231ps %%ymm12, %%ymm15, %%ymm9\n .endif\n"        // 3x0
            ".if (m_Mr > 3) && (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm15, %%ymm10\n .endif\n"       // 3x1
            ".if (m_Mr > 3) && (m_Nr >16)\n vfmadd231ps %%ymm14, %%ymm15, %%ymm11\n .endif\n"       // 3x2

            ".if m_TransA != 0\n"
            "   lea 4(%%rax), %%rax\n"
            ".if m_Mr > 3\n lea 4(%%r8), %%r8\n .endif\n"
            ".else\n"
            "   lea m_Mr * 4(%%rax), %%rax\n"
            ".endif\n"

            ".if m_TransB != 0\n"
            "   lea m_Nr * 4(%%rbx), %%rbx\n"
            ".else\n"
            "   lea 8*4(%%rbx), %%rbx\n"
            ".endif\n"

            "sub $1, %%rsi\n"
            "jne L_GemmAvx2_MxN_4x24_K_Loop_Remain%=\n"

            "L_GemmAvx2_MxN_4x24_K_Loop_End%=:\n"
            "mov 56(%[m_param]), %%eax\n"       // alpha
            "cmp $0x3f800000, %%eax\n"
            "je L_GemmAvx2_MxN_4x24_Update_C%=\n"
            "vbroadcastss 56(%[m_param]), %%ymm12\n"
            "                               vmulps %%ymm12, %%ymm0,  %%ymm0\n"                      // 0x0
            ".if (m_Nr > 8)\n               vmulps %%ymm12, %%ymm1,  %%ymm1\n .endif\n"             // 0x1
            ".if (m_Nr >16)\n               vmulps %%ymm12, %%ymm2,  %%ymm2\n .endif\n"             // 0x2
            ".if (m_Mr > 1)\n               vmulps %%ymm12, %%ymm3,  %%ymm3\n .endif\n"             // 1x0
            ".if (m_Mr > 1) && (m_Nr > 8)\n vmulps %%ymm12, %%ymm4,  %%ymm4\n .endif\n"             // 1x1
            ".if (m_Mr > 1) && (m_Nr >16)\n vmulps %%ymm12, %%ymm5,  %%ymm5\n .endif\n"             // 1x2
            ".if (m_Mr > 2)\n               vmulps %%ymm12, %%ymm6,  %%ymm6\n .endif\n"             // 2x0
            ".if (m_Mr > 2) && (m_Nr > 8)\n vmulps %%ymm12, %%ymm7,  %%ymm7\n .endif\n"             // 2x1
            ".if (m_Mr > 2) && (m_Nr >16)\n vmulps %%ymm12, %%ymm8,  %%ymm8\n .endif\n"             // 2x2
            ".if (m_Mr > 3)\n               vmulps %%ymm12, %%ymm9,  %%ymm9\n .endif\n"             // 3x0
            ".if (m_Mr > 3) && (m_Nr > 8)\n vmulps %%ymm12, %%ymm10, %%ymm10\n .endif\n"            // 3x1
            ".if (m_Mr > 3) && (m_Nr >16)\n vmulps %%ymm12, %%ymm11, %%ymm11\n .endif\n"            // 3x2

            "L_GemmAvx2_MxN_4x24_Update_C%=:\n"
            "movq 16(%[m_param]), %%rax\n"      // p_c
            "movq 48(%[m_param]), %%rdi\n"      // ldc
            ".if (m_Mr > 1)\n lea (%%rax, %%rdi, 1), %%rbx\n .endif\n"
            ".if (m_Mr > 2)\n lea (%%rbx, %%rdi, 1), %%rcx\n .endif\n"
            ".if (m_Mr > 3)\n lea (%%rcx, %%rdi, 1), %%rdx\n .endif\n"

            "                               vaddps   (%%rax), %%ymm0,  %%ymm0\n"
            ".if (m_Nr > 8)\n               vaddps 32(%%rax), %%ymm1,  %%ymm1\n .endif\n"
            ".if (m_Nr >16)\n               vaddps 64(%%rax), %%ymm2,  %%ymm2\n .endif\n"
            ".if (m_Mr > 1)\n               vaddps   (%%rbx), %%ymm3,  %%ymm3\n .endif\n"
            ".if (m_Mr > 1) && (m_Nr > 8)\n vaddps 32(%%rbx), %%ymm4,  %%ymm4\n .endif\n"
            ".if (m_Mr > 1) && (m_Nr >16)\n vaddps 64(%%rbx), %%ymm5,  %%ymm5\n .endif\n"
            ".if (m_Mr > 2)\n               vaddps   (%%rcx), %%ymm6,  %%ymm6\n .endif\n"
            ".if (m_Mr > 2) && (m_Nr > 8)\n vaddps 32(%%rcx), %%ymm7,  %%ymm7\n .endif\n"
            ".if (m_Mr > 2) && (m_Nr >16)\n vaddps 64(%%rcx), %%ymm8,  %%ymm8\n .endif\n"
            ".if (m_Mr > 3)\n               vaddps   (%%rdx), %%ymm9,  %%ymm9\n .endif\n"
            ".if (m_Mr > 3) && (m_Nr > 8)\n vaddps 32(%%rdx), %%ymm10, %%ymm10\n .endif\n"
            ".if (m_Mr > 3) && (m_Nr >16)\n vaddps 64(%%rdx), %%ymm11, %%ymm11\n .endif\n"

            "                               vmovups %%ymm0,    (%%rax)\n"
            ".if (m_Nr > 8)\n               vmovups %%ymm1,  32(%%rax)\n .endif\n"
            ".if (m_Nr >16)\n               vmovups %%ymm2,  64(%%rax)\n .endif\n"
            ".if (m_Mr > 1)\n               vmovups %%ymm3,    (%%rbx)\n .endif\n"
            ".if (m_Mr > 1) && (m_Nr > 8)\n vmovups %%ymm4,  32(%%rbx)\n .endif\n"
            ".if (m_Mr > 1) && (m_Nr >16)\n vmovups %%ymm5,  64(%%rbx)\n .endif\n"
            ".if (m_Mr > 2)\n               vmovups %%ymm6,    (%%rcx)\n .endif\n"
            ".if (m_Mr > 2) && (m_Nr > 8)\n vmovups %%ymm7,  32(%%rcx)\n .endif\n"
            ".if (m_Mr > 2) && (m_Nr >16)\n vmovups %%ymm8,  64(%%rcx)\n .endif\n"
            ".if (m_Mr > 3)\n               vmovups %%ymm9,    (%%rdx)\n .endif\n"
            ".if (m_Mr > 3) && (m_Nr > 8)\n vmovups %%ymm10, 32(%%rdx)\n .endif\n"
            ".if (m_Mr > 3) && (m_Nr >16)\n vmovups %%ymm11, 64(%%rdx)\n .endif\n"

            "L_GemmAvx2_MxN_4x24_Exit%=:\n"
            :
            :
            [m_Mr]      "i" (Mr),
            [m_Nr]      "i" (Nr),
            [m_TransA]  "i" (std::is_same<ck::tensor_layout::gemm::RowMajor, ALayout>::value ? 1 : 0),
            [m_TransB]  "i" (std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value ? 1 : 0),
            [m_NTStore] "i" (NonTemporalStore),
            [m_ABytes]  "i" (sizeof(FloatA)),
            [m_BBytes]  "i" (sizeof(FloatB)),
            [m_CBytes]  "i" (sizeof(FloatC)),
            [m_param]   "r" (param)
            :
            "cc", "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8",
            "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7",
            "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", "ymm15"
        );
        // clang-format on
    }
};

} // namespace cpu
} // namespace ck
#endif
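
For readers of the diff: the new struct computes an Mr x Nr tile of C += alpha * A * B, holding B in ymm12/13/14 (8 floats each), broadcasting one A element at a time through ymm15, unrolling the K loop by 4, and skipping the alpha pass when the alpha bits equal 0x3f800000 (1.0f). Below is a minimal scalar sketch of that tile update, assuming row-major A, B, and C with element (not byte) strides; the function name and raw-pointer interface are illustrative and not part of the commit.

// Scalar reference for the Mr x Nr tile update performed by the 4x24 ukernel,
// assuming row-major A (lda floats per row), row-major B (ldb floats per row)
// and row-major C (ldc floats per row). Strides here are in elements, unlike
// the kernel, which takes them in bytes. The function name is illustrative.
void ref_tile_update(const float* a, const float* b, float* c,
                     int Mr, int Nr, int Kr,
                     int lda, int ldb, int ldc, float alpha)
{
    for(int im = 0; im < Mr; ++im)           // rows: broadcast lanes of A (ymm15)
    {
        for(int in = 0; in < Nr; ++in)       // columns: ymm12/13/14 hold 8 floats each
        {
            float acc = 0.0f;
            for(int ik = 0; ik < Kr; ++ik)   // K loop, unrolled by 4 in the assembly
                acc += a[im * lda + ik] * b[ik * ldb + in];
            c[im * ldc + in] += alpha * acc; // the asm skips the alpha multiply when alpha == 1.0f
        }
    }
}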
test/cpu_ukernel/cpu_gemm_uk.cpp
...
@@ -29,6 +29,20 @@
// #define ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(FA, FB, FC, TA, TB, NT) \
//     ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 6, 16, TA, TB, NT>
#define ITERATE_THREAD_GEMM_AVX2_MXN_4X24_INSTANCE(FA, FB, FC, TA, TB, NT) \
ck::cpu::ThreadwiseGemmAvx2_MxN_4x24<FA, FB, FC, 4, 24, TA, TB, NT>, \
ck::cpu::ThreadwiseGemmAvx2_MxN_4x24<FA, FB, FC, 3, 24, TA, TB, NT>, \
ck::cpu::ThreadwiseGemmAvx2_MxN_4x24<FA, FB, FC, 2, 24, TA, TB, NT>, \
ck::cpu::ThreadwiseGemmAvx2_MxN_4x24<FA, FB, FC, 1, 24, TA, TB, NT>, \
ck::cpu::ThreadwiseGemmAvx2_MxN_4x24<FA, FB, FC, 4, 16, TA, TB, NT>, \
ck::cpu::ThreadwiseGemmAvx2_MxN_4x24<FA, FB, FC, 3, 16, TA, TB, NT>, \
ck::cpu::ThreadwiseGemmAvx2_MxN_4x24<FA, FB, FC, 2, 16, TA, TB, NT>, \
ck::cpu::ThreadwiseGemmAvx2_MxN_4x24<FA, FB, FC, 1, 16, TA, TB, NT>, \
ck::cpu::ThreadwiseGemmAvx2_MxN_4x24<FA, FB, FC, 4, 8, TA, TB, NT>, \
ck::cpu::ThreadwiseGemmAvx2_MxN_4x24<FA, FB, FC, 3, 8, TA, TB, NT>, \
ck::cpu::ThreadwiseGemmAvx2_MxN_4x24<FA, FB, FC, 2, 8, TA, TB, NT>, \
ck::cpu::ThreadwiseGemmAvx2_MxN_4x24<FA, FB, FC, 1, 8, TA, TB, NT>
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
...
@@ -45,6 +59,17 @@ using thread_gemm_avx2_mxn_6x16_instances = std::tuple<
    // clang-format on
    >;
template <typename ALayout, typename BLayout>
using thread_gemm_avx2_mxn_4x24_instances = std::tuple<
    // clang-format off
    //                                          FloatA FloatB FloatC ALayout  BLayout  NTStore
    ITERATE_THREAD_GEMM_AVX2_MXN_4X24_INSTANCE( float, float, float, ALayout, BLayout, false),
    ITERATE_THREAD_GEMM_AVX2_MXN_4X24_INSTANCE( float, float, float, ALayout, BLayout, false),
    ITERATE_THREAD_GEMM_AVX2_MXN_4X24_INSTANCE( float, float, float, ALayout, BLayout, false),
    ITERATE_THREAD_GEMM_AVX2_MXN_4X24_INSTANCE( float, float, float, ALayout, BLayout, false)
    // clang-format on
    >;
void dump_cache_hierarchy()
{
    auto dump_cache_type = [&](const ck::cpu::cpuid_cache_type& type) {
...
@@ -336,15 +361,11 @@ void test_cpu_ukernel(float alpha, uint32_t m, uint32_t n, uint32_t k)
    ref_cpu_gemm_uk<data_type, ALayout, BLayout>(mat_a, mat_b, mat_c_ref, alpha, m, n, k);
    using thread_gemm_instance = thread_gemm_avx2_mxn_6x16_instances<ALayout, BLayout>;
    // using thread_gemm_instance = thread_gemm_avx2_mxn_4x24_instances<ALayout, BLayout>;

    bool found = false;
    ck::static_for<0, std::tuple_size_v<thread_gemm_instance>, 1>{}([&](auto i) {
        using uk_type = std::tuple_element_t<i, thread_gemm_instance>;

        // if constexpr(!std::is_same<typename uk_type::ALayout_, ALayout>::value ||
        //             !std::is_same<typename uk_type::BLayout_, BLayout>::value)
        // {
        //     return;
        // }
        if(m % uk_type::Mr_ != 0 || n % uk_type::Nr_ != 0)
            return;
        if((m != uk_type::Mr_ && std::is_same<typename uk_type::ALayout_, Col>::value) ||
...
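
For context, a hypothetical driver for one instance from this tuple might look like the sketch below. The ThreadwiseGemmParam field names (p_a, p_b, p_c, Kr, lda, ldb, ldc, alpha) are inferred from the offset comments in the assembly (0, 8, 16, 24, 32, 40, 48, 56) and should be checked against the actual struct in threadwise_gemm_avx2.hpp; the alias uk_4x24 and the function run_one_tile are illustrative only.

// Hypothetical usage sketch: run a single 4x24 tile with row-major A and B.
// Field names are assumptions based on the assembly's offset comments;
// casts may be needed depending on the declared pointer types.
using uk_4x24 = ck::cpu::ThreadwiseGemmAvx2_MxN_4x24<float, float, float, 4, 24, Row, Row, false>;

void run_one_tile(const float* a, const float* b, float* c, uint64_t k, uint64_t n)
{
    ck::cpu::ThreadwiseGemmParam param{};
    param.p_a   = a;                    // 4 x k tile of A (row-major, rows strided by lda)
    param.p_b   = b;                    // packed k x 24 tile of B (row-major path ignores ldb)
    param.p_c   = c;                    // 4 x 24 tile of C, accumulated in place
    param.Kr    = k;
    param.lda   = k * sizeof(float);    // strides are in bytes, as the kernel comment notes
    param.ldb   = n * sizeof(float);
    param.ldc   = n * sizeof(float);
    param.alpha = 1.0f;                 // bit pattern 0x3f800000 makes the kernel skip scaling
    uk_4x24::Run(&param);
}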