Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
d714fa15
"...composable_kernel.git" did not exist on "4a1e97cf865fbf9fee3f02164e54c1d227299334"
Commit
d714fa15
authored
Mar 26, 2022
by
carlushuang
Browse files
avx2 ukernel ready, add test on L1D cache
parent
9a17e7fb
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
842 additions
and
0 deletions
+842
-0
include/ck/tensor_operation/cpu/thread/threadwise_gemm_avx2.hpp
...e/ck/tensor_operation/cpu/thread/threadwise_gemm_avx2.hpp
+305
-0
include/ck/tensor_operation/cpu/thread/threadwise_param.hpp
include/ck/tensor_operation/cpu/thread/threadwise_param.hpp
+26
-0
include/ck/utility/cpuid.hpp
include/ck/utility/cpuid.hpp
+193
-0
library/src/tensor_operation_instance/cpu/CMakeLists.txt
library/src/tensor_operation_instance/cpu/CMakeLists.txt
+0
-0
test/CMakeLists.txt
test/CMakeLists.txt
+2
-0
test/cpu_ukernel/CMakeLists.txt
test/cpu_ukernel/CMakeLists.txt
+1
-0
test/cpu_ukernel/cpu_gemm_uk.cpp
test/cpu_ukernel/cpu_gemm_uk.cpp
+315
-0
No files found.
include/ck/tensor_operation/cpu/thread/threadwise_gemm_avx2.hpp
0 → 100644
View file @
d714fa15
#ifndef CK_THREADWISE_GEMM_AVX2_HPP
#define CK_THREADWISE_GEMM_AVX2_HPP

#include "common_header.hpp"
#include "tensor_layout.hpp"
#include "math.hpp"
#include "threadwise_param.hpp"

namespace ck {
namespace cpu {

// AVX2 micro-kernel that accumulates an Mr x Nr register tile of C from A and B.
// Mr must be in [1, 6]; Nr must be 8 or 16 (one or two ymm registers per row).
// All strides travel through ThreadwiseGemmParam in units of bytes; the inline
// asm below hard-codes that struct's field offsets (0/8/16/24/32/40/48/56).
template <typename FloatA,
          typename FloatB,
          typename FloatC,
          index_t Mr,
          index_t Nr,
          typename ALayout, // default is k*m, trans->m*k
          typename BLayout, // default is n/8*k*n8, trans->k*n
          bool NonTemporalStore>
struct ThreadwiseGemmAvx2_MxN_6x16
{
    using ALayout_                          = ALayout;
    using BLayout_                          = BLayout;
    static constexpr auto Mr_               = Mr;
    static constexpr auto Nr_               = Nr;
    static constexpr auto NonTemporalStore_ = NonTemporalStore;

    __host__ constexpr ThreadwiseGemmAvx2_MxN_6x16()
    {
        static_assert(Mr <= 6 && Mr >= 1 && (Nr == 8 || Nr == 16), "wrong! Mr x Nr not valid");
    }

    // Runs the ukernel once: param->Kr rank-1 updates are accumulated into the
    // register tile, scaled by param->alpha (skipped when alpha bit-compares
    // equal to 1.0f), then added into the C tile in memory (C is read-modify-
    // written with vaddps/vmovaps, so p_c must be 32-byte aligned and
    // pre-initialized by the caller).
    __host__ static void Run(ThreadwiseGemmParam* param)
    {
        /* 6x16 ukernel
         *
         *         Mat_B
         *        |ymm12  |ymm13  |
         *  Mat_A +--------+--------+
         *  ymm14 |ymm0   |ymm1   |   cycle 0
         *  ymm15 |ymm2   |ymm3   |   cycle 1
         *  ymm14 |ymm4   |ymm5   |   cycle 2
         *  ymm15 |ymm6   |ymm7   |   cycle 3
         *  ymm14 |ymm8   |ymm9   |   cycle 4
         *  ymm15 |ymm10  |ymm11  |   Mat_C cycle 5
         *
         * ALayout:ColumnMajor (k*m), lda not needed
         * ALayout:RowMajor (m*k), lda = k
         * BLayout:ColumnMajor (n/8*k*n8), ldb = k*n8. At least this should be 8 continuous n for a
         * ymm register BLayout:RowMajor (k*n), ldb not needed
         *
         * lda/ldb/ldc all in unit of byte
         *
         */
        // clang-format off
        __asm__ __volatile__ (
        "L_GemmAvx2_MxN_6x16_Entry%=:\n"
        // expose compile-time template arguments to the assembler as symbols
        ".set m_Mr, %c[m_Mr]\n"
        ".set m_Nr, %c[m_Nr]\n"
        ".set m_TransA, %c[m_TransA]\n"
        ".set m_TransB, %c[m_TransB]\n"
        ".set m_NTStore, %c[m_NTStore]\n"
        ".set m_ABytes, %c[m_ABytes]\n"
        ".set m_BBytes, %c[m_BBytes]\n"
        ".set m_CBytes, %c[m_CBytes]\n"
        // load pointers/sizes from ThreadwiseGemmParam (offsets match threadwise_param.hpp)
        "movq (%[m_param]), %%rax\n"                   // p_a
        "movq 8(%[m_param]), %%rbx\n"                  // p_b
        "movq 24(%[m_param]), %%rsi\n"                 // Kr
        ".if m_TransA != 0\n"
        "movq 32(%[m_param]), %%rcx\n"                 // lda
        ".endif\n"
        ".if m_TransB == 0\n"
        "movq 40(%[m_param]), %%rdx\n"                 // ldb
        ".endif\n"
        // broadcast one scalar, with or without a scaled-index addressing mode
        ".macro vbroadcastss_%= r_base, r_stride, i_scale, i_offset, ymm\n"
        ".if \\i_scale != 0\n"
        "vbroadcastss \\i_offset(\\r_base, \\r_stride, \\i_scale), \\ymm\n"
        ".else\n"
        "vbroadcastss \\i_offset(\\r_base), \\ymm\n"
        ".endif\n"
        ".endm\n"
        // aligned 8-float load, with or without a scaled-index addressing mode
        ".macro vmovaps_%= r_base, r_stride, i_scale, i_offset, ymm\n"
        ".if \\i_scale != 0\n"
        "vmovaps \\i_offset(\\r_base, \\r_stride, \\i_scale), \\ymm\n"
        ".else\n"
        "vmovaps \\i_offset(\\r_base), \\ymm\n"
        ".endif\n"
        ".endm\n"
        ".macro vbroadcast_a%= i_k, i_m, ymm\n"        // A in rax(r8, r9), lda in rcx
        ".if m_TransA == 0\n"
        "vbroadcastss_%= %%rax, 0, 0, (\\i_m + \\i_k * m_Mr) * 4, \\ymm\n"
        ".else\n"
        ".if (\\i_m == 0) || (\\i_m == 1) || (\\i_m == 2)\n"
        "vbroadcastss_%= %%rax, %%rcx, \\i_m, \\i_k * 4, \\ymm\n"
        ".else\n"
        "vbroadcastss_%= %%r8, %%rcx, \\i_m-3, \\i_k * 4, \\ymm\n"
        ".endif\n"
        ".endif\n"
        ".endm\n"
        ".macro vload_b%= i_k, i_n, ymm\n"             // B in rbx, lda in rdx, i_n should be 0, 1
        ".if m_BBytes == 4\n"
        ".if m_TransB == 0\n"
        "vmovaps_%= %%rbx, %%rdx, \\i_n, \\i_k*4*8, \\ymm\n"
        ".else\n"
        "vmovaps_%= %%rbx, 0, 0, (\\i_k*m_Nr + \\i_n*8)*4, \\ymm\n"
        ".endif\n"
        ".else\n"
        ".if m_TransB == 0\n"
        "vmovaps_%= %%rbx, %%rdx, \\i_n, \\i_k*4*8, \\ymm\n"
        ".else\n"
        "vmovaps_%= %%rbx, 0, 0, (\\i_k*m_Nr + \\i_n*8)*4, \\ymm\n"
        ".endif\n"
        ".endif\n"
        ".endm\n"
        // zero exactly the accumulators this Mr x Nr variant uses
        " vxorps %%ymm0, %%ymm0, %%ymm0\n"
        ".if (m_Nr > 8)\n vxorps %%ymm1, %%ymm1, %%ymm1\n .endif\n"
        ".if (m_Mr > 1)\n vxorps %%ymm2, %%ymm2, %%ymm2\n .endif\n"
        ".if (m_Mr > 1) && (m_Nr > 8)\n vxorps %%ymm3, %%ymm3, %%ymm3\n .endif\n"
        ".if (m_Mr > 2)\n vxorps %%ymm4, %%ymm4, %%ymm4\n .endif\n"
        ".if (m_Mr > 2) && (m_Nr > 8)\n vxorps %%ymm5, %%ymm5, %%ymm5\n .endif\n"
        ".if (m_Mr > 3)\n vxorps %%ymm6, %%ymm6, %%ymm6\n .endif\n"
        ".if (m_Mr > 3) && (m_Nr > 8)\n vxorps %%ymm7, %%ymm7, %%ymm7\n .endif\n"
        ".if (m_Mr > 4)\n vxorps %%ymm8, %%ymm8, %%ymm8\n .endif\n"
        ".if (m_Mr > 4) && (m_Nr > 8)\n vxorps %%ymm9, %%ymm9, %%ymm9\n .endif\n"
        ".if (m_Mr > 5)\n vxorps %%ymm10, %%ymm10, %%ymm10\n .endif\n"
        ".if (m_Mr > 5) && (m_Nr > 8)\n vxorps %%ymm11, %%ymm11, %%ymm11\n .endif\n"
        // transposed A with Mr > 3: keep rows 3..5 reachable via r8 = rax + 3*lda
        ".if m_TransA != 0\n"
        ".if m_Mr > 3\n"
        "lea (%%rcx, %%rcx, 2), %%r9\n"
        "lea (%%rax, %%r9), %%r8\n"
        ".endif\n"
        ".endif\n"
        "cmp $4, %%rsi\n"
        "jl L_GemmAvx2_MxN_6x16_K_Loop_Remain%=\n"
        // main loop, unrolled by 4 over k
        "L_GemmAvx2_MxN_6x16_K_Loop_Start%=:\n"
        ".irp i_k, 0, 1, 2, 3\n"
        " vload_b%= \\i_k, 0, %%ymm12\n"                                                  // B
        ".if (m_Nr > 8)\n vload_b%= \\i_k, 1, %%ymm13\n .endif\n"                         // B
        " vbroadcast_a%= \\i_k, 0, %%ymm14\n"                                             // A broadcast 0
        ".if (m_Mr > 1)\n vbroadcast_a%= \\i_k, 1, %%ymm15\n .endif\n"                    // A broadcast 1
        " vfmadd231ps %%ymm12, %%ymm14, %%ymm0\n"                                         // 0x0
        ".if (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm14, %%ymm1\n .endif\n"                // 0x1
        ".if (m_Mr > 1)\n vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n .endif\n"                // 1x0
        ".if (m_Mr > 1) && (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm15, %%ymm3\n .endif\n"  // 1x1
        ".if (m_Mr > 2)\n vbroadcast_a%= \\i_k, 2, %%ymm14\n .endif\n"                    // A broadcast 2
        ".if (m_Mr > 3)\n vbroadcast_a%= \\i_k, 3, %%ymm15\n .endif\n"                    // A broadcast 3
        ".if (m_Mr > 2)\n vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n .endif\n"                // 2x0
        ".if (m_Mr > 2) && (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm14, %%ymm5\n .endif\n"  // 2x1
        ".if (m_Mr > 3)\n vfmadd231ps %%ymm12, %%ymm15, %%ymm6\n .endif\n"                // 3x0
        ".if (m_Mr > 3) && (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm15, %%ymm7\n .endif\n"  // 3x1
        ".if (m_Mr > 4)\n vbroadcast_a%= \\i_k, 4, %%ymm14\n .endif\n"                    // A broadcast 4
        ".if (m_Mr > 5)\n vbroadcast_a%= \\i_k, 5, %%ymm15\n .endif\n"                    // A broadcast 5
        ".if (m_Mr > 4)\n vfmadd231ps %%ymm12, %%ymm14, %%ymm8\n .endif\n"                // 4x0
        ".if (m_Mr > 4) && (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm14, %%ymm9\n .endif\n"  // 4x1
        ".if (m_Mr > 5)\n vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n .endif\n"               // 5x0
        ".if (m_Mr > 5) && (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm15, %%ymm11\n .endif\n" // 5x1
        ".endr\n"
        // advance A/B pointers by 4 k-steps (element size 4 bytes)
        ".if m_TransA != 0\n"
        " lea 4*4(%%rax), %%rax\n"
        ".if m_Mr > 3\n lea 4*4(%%r8), %%r8\n .endif\n"
        ".else\n"
        " lea m_Mr * 4 * 4(%%rax), %%rax\n"
        ".endif\n"
        ".if m_TransB != 0\n"
        " lea m_Nr * 4 * 4(%%rbx), %%rbx\n"
        ".else\n"
        " lea 8 * 4 * 4(%%rbx), %%rbx\n"
        ".endif\n"
        "sub $4, %%rsi\n"
        "cmp $4, %%rsi\n"
        "jge L_GemmAvx2_MxN_6x16_K_Loop_Start%=\n"
        "testq %%rsi, %%rsi\n"
        "je L_GemmAvx2_MxN_6x16_K_Loop_End%=\n"
        // remainder loop, one k-step per iteration (same schedule, i_k fixed to 0)
        "L_GemmAvx2_MxN_6x16_K_Loop_Remain%=:\n"
        " vload_b%= 0, 0, %%ymm12\n"                                                      // B
        ".if (m_Nr > 8)\n vload_b%= 0, 1, %%ymm13\n .endif\n"                             // B
        " vbroadcast_a%= 0, 0, %%ymm14\n"                                                 // A broadcast 0
        ".if (m_Mr > 1)\n vbroadcast_a%= 0, 1, %%ymm15\n .endif\n"                        // A broadcast 1
        " vfmadd231ps %%ymm12, %%ymm14, %%ymm0\n"                                         // 0x0
        ".if (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm14, %%ymm1\n .endif\n"                // 0x1
        ".if (m_Mr > 1)\n vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n .endif\n"                // 1x0
        ".if (m_Mr > 1) && (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm15, %%ymm3\n .endif\n"  // 1x1
        ".if (m_Mr > 2)\n vbroadcast_a%= 0, 2, %%ymm14\n .endif\n"                        // A broadcast 2
        ".if (m_Mr > 3)\n vbroadcast_a%= 0, 3, %%ymm15\n .endif\n"                        // A broadcast 3
        ".if (m_Mr > 2)\n vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n .endif\n"                // 2x0
        ".if (m_Mr > 2) && (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm14, %%ymm5\n .endif\n"  // 2x1
        ".if (m_Mr > 3)\n vfmadd231ps %%ymm12, %%ymm15, %%ymm6\n .endif\n"                // 3x0
        ".if (m_Mr > 3) && (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm15, %%ymm7\n .endif\n"  // 3x1
        ".if (m_Mr > 4)\n vbroadcast_a%= 0, 4, %%ymm14\n .endif\n"                        // A broadcast 4
        ".if (m_Mr > 5)\n vbroadcast_a%= 0, 5, %%ymm15\n .endif\n"                        // A broadcast 5
        ".if (m_Mr > 4)\n vfmadd231ps %%ymm12, %%ymm14, %%ymm8\n .endif\n"                // 4x0
        ".if (m_Mr > 4) && (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm14, %%ymm9\n .endif\n"  // 4x1
        ".if (m_Mr > 5)\n vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n .endif\n"               // 5x0
        ".if (m_Mr > 5) && (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm15, %%ymm11\n .endif\n" // 5x1
        // advance A/B pointers by a single k-step
        ".if m_TransA != 0\n"
        " lea 4(%%rax), %%rax\n"
        ".if m_Mr > 3\n lea 4(%%r8), %%r8\n .endif\n"
        ".else\n"
        " lea m_Mr * 4(%%rax), %%rax\n"
        ".endif\n"
        ".if m_TransB != 0\n"
        " lea m_Nr * 4(%%rbx), %%rbx\n"
        ".else\n"
        " lea 8*4(%%rbx), %%rbx\n"
        ".endif\n"
        "sub $1, %%rsi\n"
        "jne L_GemmAvx2_MxN_6x16_K_Loop_Remain%=\n"
        "L_GemmAvx2_MxN_6x16_K_Loop_End%=:\n"
        // skip the scale pass when alpha's bit pattern equals 1.0f (0x3f800000)
        "mov 56(%[m_param]), %%eax\n"                  // alpha
        "cmp $0x3f800000, %%eax\n"
        "je L_GemmAvx2_MxN_6x16_Update_C%=\n"
        "vbroadcastss 56(%[m_param]), %%ymm12\n"
        " vmulps %%ymm12, %%ymm0, %%ymm0\n"                                               // 0x0
        ".if (m_Nr > 8)\n vmulps %%ymm12, %%ymm1, %%ymm1\n .endif\n"                      // 0x1
        ".if (m_Mr > 1)\n vmulps %%ymm12, %%ymm2, %%ymm2\n .endif\n"                      // 1x0
        ".if (m_Mr > 1) && (m_Nr > 8)\n vmulps %%ymm12, %%ymm3, %%ymm3\n .endif\n"        // 1x1
        ".if (m_Mr > 2)\n vmulps %%ymm12, %%ymm4, %%ymm4\n .endif\n"                      // 2x0
        ".if (m_Mr > 2) && (m_Nr > 8)\n vmulps %%ymm12, %%ymm5, %%ymm5\n .endif\n"        // 2x1
        ".if (m_Mr > 3)\n vmulps %%ymm12, %%ymm6, %%ymm6\n .endif\n"                      // 3x0
        ".if (m_Mr > 3) && (m_Nr > 8)\n vmulps %%ymm12, %%ymm7, %%ymm7\n .endif\n"        // 3x1
        ".if (m_Mr > 4)\n vmulps %%ymm12, %%ymm8, %%ymm8\n .endif\n"                      // 4x0
        ".if (m_Mr > 4) && (m_Nr > 8)\n vmulps %%ymm12, %%ymm9, %%ymm9\n .endif\n"        // 4x1
        ".if (m_Mr > 5)\n vmulps %%ymm12, %%ymm10, %%ymm10\n .endif\n"                    // 5x0
        ".if (m_Mr > 5) && (m_Nr > 8)\n vmulps %%ymm12, %%ymm11, %%ymm11\n .endif\n"      // 5x1
        "L_GemmAvx2_MxN_6x16_Update_C%=:\n"
        "movq 16(%[m_param]), %%rax\n"                 // p_c
        "movq 48(%[m_param]), %%rdi\n"                 // ldc
        // row pointers: rax, rbx, rcx, rdx, r8, r9 = consecutive rows of C
        ".if (m_Mr > 1)\n lea (%%rax, %%rdi, 1), %%rbx\n .endif\n"
        ".if (m_Mr > 2)\n lea (%%rbx, %%rdi, 1), %%rcx\n .endif\n"
        ".if (m_Mr > 3)\n lea (%%rcx, %%rdi, 1), %%rdx\n .endif\n"
        ".if (m_Mr > 4)\n lea (%%rdx, %%rdi, 1), %%r8\n .endif\n"
        ".if (m_Mr > 5)\n lea (%%r8, %%rdi, 1), %%r9\n .endif\n"
        // accumulate into memory: C tile += register tile
        " vaddps (%%rax), %%ymm0, %%ymm0\n"
        ".if (m_Nr > 8)\n vaddps 32(%%rax), %%ymm1, %%ymm1\n .endif\n"
        ".if (m_Mr > 1)\n vaddps (%%rbx), %%ymm2, %%ymm2\n .endif\n"
        ".if (m_Mr > 1) && (m_Nr > 8)\n vaddps 32(%%rbx), %%ymm3, %%ymm3\n .endif\n"
        ".if (m_Mr > 2)\n vaddps (%%rcx), %%ymm4, %%ymm4\n .endif\n"
        ".if (m_Mr > 2) && (m_Nr > 8)\n vaddps 32(%%rcx), %%ymm5, %%ymm5\n .endif\n"
        ".if (m_Mr > 3)\n vaddps (%%rdx), %%ymm6, %%ymm6\n .endif\n"
        ".if (m_Mr > 3) && (m_Nr > 8)\n vaddps 32(%%rdx), %%ymm7, %%ymm7\n .endif\n"
        ".if (m_Mr > 4)\n vaddps (%%r8), %%ymm8, %%ymm8\n .endif\n"
        ".if (m_Mr > 4) && (m_Nr > 8)\n vaddps 32(%%r8), %%ymm9, %%ymm9\n .endif\n"
        ".if (m_Mr > 5)\n vaddps (%%r9), %%ymm10, %%ymm10\n .endif\n"
        ".if (m_Mr > 5) && (m_Nr > 8)\n vaddps 32(%%r9), %%ymm11, %%ymm11\n .endif\n"
        // store the tile back (NOTE(review): m_NTStore is declared but the
        // non-temporal store path is not emitted yet -- always vmovaps here)
        " vmovaps %%ymm0, (%%rax)\n"
        ".if (m_Nr > 8)\n vmovaps %%ymm1, 32(%%rax)\n .endif\n"
        ".if (m_Mr > 1)\n vmovaps %%ymm2, (%%rbx)\n .endif\n"
        ".if (m_Mr > 1) && (m_Nr > 8)\n vmovaps %%ymm3, 32(%%rbx)\n .endif\n"
        ".if (m_Mr > 2)\n vmovaps %%ymm4, (%%rcx)\n .endif\n"
        ".if (m_Mr > 2) && (m_Nr > 8)\n vmovaps %%ymm5, 32(%%rcx)\n .endif\n"
        ".if (m_Mr > 3)\n vmovaps %%ymm6, (%%rdx)\n .endif\n"
        ".if (m_Mr > 3) && (m_Nr > 8)\n vmovaps %%ymm7, 32(%%rdx)\n .endif\n"
        ".if (m_Mr > 4)\n vmovaps %%ymm8, (%%r8)\n .endif\n"
        ".if (m_Mr > 4) && (m_Nr > 8)\n vmovaps %%ymm9, 32(%%r8)\n .endif\n"
        ".if (m_Mr > 5)\n vmovaps %%ymm10, (%%r9)\n .endif\n"
        ".if (m_Mr > 5) && (m_Nr > 8)\n vmovaps %%ymm11, 32(%%r9)\n .endif\n"
        "L_GemmAvx2_MxN_6x16_Exit%=:\n"
        :
        : [m_Mr] "i"(Mr),
          [m_Nr] "i"(Nr),
          [m_TransA] "i"(std::is_same<ck::tensor_layout::gemm::RowMajor, ALayout>::value ? 1 : 0),
          [m_TransB] "i"(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value ? 1 : 0),
          [m_NTStore] "i"(NonTemporalStore),
          [m_ABytes] "i"(sizeof(FloatA)),
          [m_BBytes] "i"(sizeof(FloatB)),
          [m_CBytes] "i"(sizeof(FloatC)),
          [m_param] "r"(param)
        : "cc",
          "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9",
          "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7",
          "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", "ymm15");
        // clang-format on
    }
};

} // namespace cpu
} // namespace ck
#endif
include/ck/tensor_operation/cpu/thread/threadwise_param.hpp
0 → 100644
View file @
d714fa15
#ifndef CK_THREADWISE_PARAM_HPP
#define CK_THREADWISE_PARAM_HPP

#include "common_header.hpp"
#include "math.hpp"

namespace ck
{
namespace cpu
{

// Argument bundle passed by pointer into the AVX2 GEMM micro-kernels.
// The struct is packed and the inline asm in threadwise_gemm_avx2.hpp reads the
// members through hard-coded byte offsets (p_a:0, p_b:8, p_c:16, Kr:24, lda:32,
// ldb:40, ldc:48, alpha:56) -- do not reorder, insert, or resize fields without
// updating that asm.
struct ThreadwiseGemmParam
{
    const void* p_a; // A operand
    const void* p_b; // B operand
    void* p_c;       // C operand; the kernel reads, accumulates, and writes it back
    uint64_t Kr;     // number of k iterations
    uint64_t lda;    // in unit of byte
    uint64_t ldb;    // in unit of byte
    uint64_t ldc;    // in unit of byte
    float alpha;     // tile scale; the asm skips scaling when this bit-equals 1.0f
    uint32_t _pack0; // explicit tail padding for the packed struct
} __attribute__((packed));

} // namespace cpu
} // namespace ck
#endif
include/ck/utility/cpuid.hpp
0 → 100644
View file @
d714fa15
#ifndef CK_CPUID_HPP
#define CK_CPUID_HPP

// uint32_t is used throughout this header; previously the header compiled only
// when a prior include happened to pull in <cstdint>.
#include <cstdint>

namespace ck
{
namespace cpu
{

// CPU vendor as decoded from the CPUID leaf-0 vendor string.
enum cpuid_vendor
{
    cpuid_vendor_intel = 0,
    cpuid_vendor_amd   = 1,
    cpuid_vendor_other = 2,
};

// Cache type field (EAX[4:0]) of the deterministic cache parameter leaf.
enum cpuid_cache_type
{
    cpuid_cache_type_null    = 0, // no more caches at this sub-leaf index
    cpuid_cache_type_dcache  = 1,
    cpuid_cache_type_icache  = 2,
    cpuid_cache_type_unified = 3,
};
// Raw output registers of a single CPUID invocation.
struct cpuid_raw
{
    uint32_t eax{0};
    uint32_t ebx{0};
    uint32_t ecx{0};
    uint32_t edx{0};
};

// Decoded description of one cache level; size is in bytes.
struct cpuid_cache_detail
{
    uint32_t size{0};
    uint32_t type{0};            // one of cpuid_cache_type
    uint32_t cache_line_size{0}; // in bytes
    uint32_t associativity{0};
    uint32_t sets{0};
    uint32_t partitions{0};
    uint32_t shared_by_procs{0};  // in HT, usually maybe 2 threads per core, hence for L1/L2,
                                  // usually this maybe 2, unless turn of HT
    uint32_t cores_per_socket{0}; // hardware cores in a physical socket. there maybe multiple
                                  // sockets on the chip. TODO: may not needed?
    uint32_t flags{0};
};

// Cache hierarchy decoded by cpuid_query_cache(); levels the hardware does not
// report keep their zero-initialized defaults (size == 0).
struct cpuid_cache_hierarchy
{
    cpuid_cache_detail l1i;
    cpuid_cache_detail l1d;
    cpuid_cache_detail l2;
    cpuid_cache_detail l3;
    cpuid_cache_detail l4;
};
// Executes CPUID with the given leaf (eax) and sub-leaf (ecx) and returns all
// four output registers.
// some leaf feature require ecx value.
// for others, ecx actually not used.
static inline cpuid_raw cpuid(uint32_t eax, uint32_t ecx)
{
    uint32_t ebx, edx;
    // Bind the operands directly to eax/ebx/ecx/edx. The previous version used
    // generic "r" operands plus hand-written movs around cpuid without listing
    // eax/ebx/ecx/edx as clobbers, so the register allocator was free to place
    // an operand in one of those registers and have cpuid silently clobber it.
    asm __volatile__("cpuid"
                     : "=a"(eax), "=b"(ebx), "=c"(ecx), "=d"(edx)
                     : "0"(eax), "2"(ecx));
    return {eax, ebx, ecx, edx};
}
// Identifies the CPU vendor from the 12-byte vendor string returned by CPUID
// leaf 0 in ebx:edx:ecx (each register holds 4 ASCII bytes, little-endian).
// BUGFIX: the edx constants of the two AMD checks were swapped -- "AuthenticAMD"
// was tested against "sbet" (0x74656273) and "AMDisbetter!" against "enti"
// (0x69746E65), so neither string could ever match.
static inline cpuid_vendor cpuid_query_vendor()
{
    cpuid_raw r = cpuid(0, 0);

    // "GenuineIntel"
    if(r.ebx == 0x756E6547U /*Genu*/ && r.edx == 0x49656E69U /*ineI*/ &&
       r.ecx == 0x6C65746EU /*ntel*/)
    {
        return cpuid_vendor_intel;
    }
    // "AuthenticAMD"
    if(r.ebx == 0x68747541U /*Auth*/ && r.edx == 0x69746E65U /*enti*/ &&
       r.ecx == 0x444D4163U /*cAMD*/)
    {
        return cpuid_vendor_amd;
    }
    // "AMDisbetter!" -- reported by some early AMD engineering samples
    if(r.ebx == 0x69444D41U /*AMDi*/ && r.edx == 0x74656273U /*sbet*/ &&
       r.ecx == 0x21726574U /*ter!*/)
    {
        return cpuid_vendor_amd;
    }
    return cpuid_vendor_other;
}
// Walks the deterministic cache parameter leaves (0x4 on Intel, 0x8000001d on
// AMD), decoding each reported cache into the returned hierarchy. Sub-leaves
// are enumerated until the null cache type terminates the list; unreported
// levels keep their zero defaults.
static inline cpuid_cache_hierarchy cpuid_query_cache()
{
    cpuid_cache_hierarchy cache_hierarchy;
    cpuid_vendor vendor = cpuid_query_vendor();

    uint32_t leaf_cache_id = vendor == cpuid_vendor_amd ? 0x8000001d : 0x4;

    for(uint32_t ecx_idx = 0;; ecx_idx++)
    {
        cpuid_raw r = cpuid(leaf_cache_id, ecx_idx);

        uint32_t cache_type = r.eax & 0x1f;
        if(cache_type == cpuid_cache_type_null)
            break; // Null, no more cache

        // all count-like fields are stored minus one in the raw registers
        uint32_t cache_level           = (r.eax >> 5) & 0x7;
        uint32_t cache_shared_by_cores = 1 + ((r.eax >> 14) & 0xfff);
        uint32_t cache_lpp_cores       = 1 + ((r.eax >> 26) & 0x3f);
        uint32_t cache_line_size       = 1 + (r.ebx & 0xfff);
        uint32_t cache_partitions      = 1 + ((r.ebx >> 12) & 0x3ff);
        uint32_t cache_associativity   = 1 + (r.ebx >> 22);
        uint32_t cache_sets            = 1 + r.ecx;

        // shared decode of one cpuid_cache_detail; previously this 8-field
        // assignment block was copy-pasted once per cache level
        auto fill_detail = [&](cpuid_cache_detail& detail) {
            detail.size = cache_partitions * cache_sets * cache_associativity * cache_line_size;
            detail.type             = cache_type;
            detail.cache_line_size  = cache_line_size;
            detail.associativity    = cache_associativity;
            detail.sets             = cache_sets;
            detail.partitions       = cache_partitions;
            detail.shared_by_procs  = cache_shared_by_cores;
            detail.cores_per_socket = cache_lpp_cores;
        };

        bool is_data =
            cache_type == cpuid_cache_type_dcache || cache_type == cpuid_cache_type_unified;

        switch(cache_level)
        {
        case 1:
            if(is_data)
                fill_detail(cache_hierarchy.l1d);
            else if(cache_type == cpuid_cache_type_icache)
                fill_detail(cache_hierarchy.l1i);
            break;
        case 2:
            if(is_data)
                fill_detail(cache_hierarchy.l2);
            break;
        case 3:
            if(is_data)
                fill_detail(cache_hierarchy.l3);
            break;
        case 4:
            if(is_data)
                fill_detail(cache_hierarchy.l4);
            break;
        }
    }
    return cache_hierarchy;
}

} // namespace cpu
} // namespace ck
#endif
library/src/tensor_operation_instance/cpu/CMakeLists.txt
0 → 100644
View file @
d714fa15
test/CMakeLists.txt
View file @
d714fa15
...
@@ -10,6 +10,7 @@ include_directories(BEFORE
...
@@ -10,6 +10,7 @@ include_directories(BEFORE
${
PROJECT_SOURCE_DIR
}
/include/ck/tensor_operation/gpu/warp
${
PROJECT_SOURCE_DIR
}
/include/ck/tensor_operation/gpu/warp
${
PROJECT_SOURCE_DIR
}
/include/ck/tensor_operation/gpu/thread
${
PROJECT_SOURCE_DIR
}
/include/ck/tensor_operation/gpu/thread
${
PROJECT_SOURCE_DIR
}
/include/ck/tensor_operation/gpu/element
${
PROJECT_SOURCE_DIR
}
/include/ck/tensor_operation/gpu/element
${
PROJECT_SOURCE_DIR
}
/include/ck/tensor_operation/cpu/thread
${
PROJECT_SOURCE_DIR
}
/library/include/ck/library/host_tensor
${
PROJECT_SOURCE_DIR
}
/library/include/ck/library/host_tensor
${
PROJECT_SOURCE_DIR
}
/library/include/ck/library/tensor_operation_instance
${
PROJECT_SOURCE_DIR
}
/library/include/ck/library/tensor_operation_instance
${
PROJECT_SOURCE_DIR
}
/library/include/ck/library/tensor_operation_instance/gpu/reduce
${
PROJECT_SOURCE_DIR
}
/library/include/ck/library/tensor_operation_instance/gpu/reduce
...
@@ -39,3 +40,4 @@ add_subdirectory(gemm_split_k)
...
@@ -39,3 +40,4 @@ add_subdirectory(gemm_split_k)
add_subdirectory
(
conv2d_fwd
)
add_subdirectory
(
conv2d_fwd
)
add_subdirectory
(
convnd_fwd
)
add_subdirectory
(
convnd_fwd
)
add_subdirectory
(
conv2d_bwd_data
)
add_subdirectory
(
conv2d_bwd_data
)
add_subdirectory
(
cpu_ukernel
)
test/cpu_ukernel/CMakeLists.txt
0 → 100644
View file @
d714fa15
# Standalone test for the AVX2 CPU GEMM micro-kernels (benchmarks on the L1D
# cache tile and validates against a scalar reference GEMM).
add_test_executable(test_cpu_gemm_uk cpu_gemm_uk.cpp)
test/cpu_ukernel/cpu_gemm_uk.cpp
0 → 100644
View file @
d714fa15
#include <chrono>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <initializer_list>
#include <iostream>
#include <memory>
#include <sstream>
#include <stdlib.h>
#include <string>
#include <tuple>

#include "config.hpp"
#include "print.hpp"
#include "cpuid.hpp"
#include "threadwise_gemm_avx2.hpp"
// Expands to the full comb of micro-kernel instantiations for one data-type /
// layout combination: Mr = 6..1 crossed with Nr = {16, 8} (12 instances).
#define ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(FA, FB, FC, TA, TB, NT)   \
    ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 6, 16, TA, TB, NT>,     \
        ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 5, 16, TA, TB, NT>, \
        ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 4, 16, TA, TB, NT>, \
        ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 3, 16, TA, TB, NT>, \
        ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 2, 16, TA, TB, NT>, \
        ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 1, 16, TA, TB, NT>, \
        ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 6, 8, TA, TB, NT>,  \
        ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 5, 8, TA, TB, NT>,  \
        ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 4, 8, TA, TB, NT>,  \
        ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 3, 8, TA, TB, NT>,  \
        ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 2, 8, TA, TB, NT>,  \
        ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 1, 8, TA, TB, NT>

// #define ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(FA, FB, FC, TA, TB, NT) \
//     ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 6, 16, TA, TB, NT>
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

// Every micro-kernel variant under test: all 12 (Mr, Nr) tiles for each of the
// four A/B layout combinations (48 instances total).
using thread_gemm_avx2_mxn_6x16_instances = std::tuple<
    // clang-format off
    //                                         FloatA FloatB FloatC ALayout BLayout NTStore
    ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(float, float, float, Row,    Row,    false),
    ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(float, float, float, Row,    Col,    false),
    ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(float, float, float, Col,    Row,    false),
    ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(float, float, float, Col,    Col,    false)
    // ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(float, float, float, Row, Col, false)
    // clang-format on
    >;
void
dump_cache_hierarchy
()
{
auto
dump_cache_type
=
[
&
](
const
ck
::
cpu
::
cpuid_cache_type
&
type
)
{
if
(
type
==
ck
::
cpu
::
cpuid_cache_type_dcache
)
printf
(
"data cache"
);
else
if
(
type
==
ck
::
cpu
::
cpuid_cache_type_icache
)
printf
(
"inst cache"
);
else
if
(
type
==
ck
::
cpu
::
cpuid_cache_type_unified
)
printf
(
"unif cache"
);
};
auto
dump_cache_detail
=
[
&
](
const
ck
::
cpu
::
cpuid_cache_detail
&
detail
)
{
dump_cache_type
(
static_cast
<
const
ck
::
cpu
::
cpuid_cache_type
>
(
detail
.
type
));
printf
(
" size:%u, cache_line:%u, associativity:%u, sets:%u, partitions:%u, shared by "
"procs:%u(%u)
\n
"
,
detail
.
size
,
detail
.
cache_line_size
,
detail
.
associativity
,
detail
.
sets
,
detail
.
partitions
,
detail
.
shared_by_procs
,
detail
.
cores_per_socket
);
};
ck
::
cpu
::
cpuid_cache_hierarchy
cache
=
ck
::
cpu
::
cpuid_query_cache
();
if
(
cache
.
l1d
.
size
!=
0
)
{
printf
(
"l1 "
);
dump_cache_detail
(
cache
.
l1d
);
}
if
(
cache
.
l1i
.
size
!=
0
)
{
printf
(
"l1 "
);
dump_cache_detail
(
cache
.
l1i
);
}
if
(
cache
.
l2
.
size
!=
0
)
{
printf
(
"l2 "
);
dump_cache_detail
(
cache
.
l2
);
}
if
(
cache
.
l3
.
size
!=
0
)
{
printf
(
"l3 "
);
dump_cache_detail
(
cache
.
l3
);
}
if
(
cache
.
l4
.
size
!=
0
)
{
printf
(
"l4 "
);
dump_cache_detail
(
cache
.
l4
);
}
}
// Allocates `required_bytes` aligned to `alignment` (must be a power of two).
// The original malloc() pointer is stashed in the slot immediately before the
// returned address so __aligned_free() can recover it. Returns nullptr on a
// bad alignment or when malloc fails.
// FIX: the slack computation previously used `int offset`, silently narrowing
// the size_t expression (and overflowing for huge alignments); now size_t.
void* __aligned_malloc(size_t required_bytes, size_t alignment)
{
    if(alignment == 0 || (alignment & (alignment - 1))) // check pow of 2
        return nullptr;

    void* p1;  // original block
    void** p2; // aligned block

    // worst-case alignment slack plus room for the stashed original pointer
    size_t offset = alignment - 1 + sizeof(void*);
    if((p1 = malloc(required_bytes + offset)) == nullptr)
    {
        return nullptr;
    }
    p2     = reinterpret_cast<void**>((reinterpret_cast<size_t>(p1) + offset) & ~(alignment - 1));
    p2[-1] = p1; // remember the block malloc actually returned
    return p2;
}
// Releases a block returned by __aligned_malloc() by recovering the original
// malloc() pointer stored one slot before the aligned address.
void __aligned_free(void* p)
{
    void** aligned = reinterpret_cast<void**>(p);
    free(aligned[-1]);
}
// Fills v[0..elem) with pseudo-random values in [0, 1) drawn as (rand()%100)/100.
// The RNG is seeded from the wall clock exactly once per process.
template <typename T>
void rand_vector(T* v, int elem)
{
    static int seeded = 0;
    if(!seeded)
    {
        srand(time(nullptr));
        seeded = 1;
    }

    for(int idx = 0; idx < elem; ++idx)
    {
        v[idx] = (static_cast<T>(rand() % 100)) / 100.0f;
    }
}
// Elementwise closeness check: |ref[i] - rhs[i]| must not exceed
// atol + rtol * |ref[i]|. Every offending element is logged; returns true only
// when no element mismatched.
bool valid_vector(const float* ref, const float* rhs, uint32_t elem)
{
    const float rtol = 1e-5;
    const float atol = 1e-8;

    uint32_t mismatches = 0;
    for(uint32_t idx = 0; idx < elem; ++idx)
    {
        const float diff      = std::abs(ref[idx] - rhs[idx]);
        const float tolerance = atol + rtol * std::abs(ref[idx]);
        if(diff > tolerance)
        {
            printf("diff at %u, ref:%f, rhs:%f\n", idx, ref[idx], rhs[idx]);
            ++mismatches;
        }
    }
    return mismatches == 0;
}
template
<
typename
data_type
,
typename
ALayout
,
typename
BLayout
>
void
ref_cpu_gemm_uk
(
const
data_type
*
a
,
const
data_type
*
b
,
float
*
c
,
float
alpha
,
uint32_t
m
,
uint32_t
n
,
uint32_t
k
)
{
auto
a_offset
=
[
&
](
uint32_t
im
,
uint32_t
ik
)
{
if
constexpr
(
std
::
is_same
<
Row
,
ALayout
>::
value
)
{
return
im
*
k
+
ik
;
}
else
{
return
ik
*
m
+
im
;
}
};
auto
b_offset
=
[
&
](
uint32_t
ik
,
uint32_t
in
)
{
if
constexpr
(
std
::
is_same
<
Row
,
BLayout
>::
value
)
{
return
ik
*
n
+
in
;
}
else
{
// n*k*n8
return
(
in
/
8
)
*
k
*
8
+
ik
*
8
+
in
%
8
;
}
};
auto
c_offset
=
[
&
](
uint32_t
im
,
uint32_t
in
)
{
return
im
*
n
+
in
;
};
for
(
uint32_t
im
=
0
;
im
<
m
;
im
++
)
{
for
(
uint32_t
in
=
0
;
in
<
n
;
in
++
)
{
float
acc
=
.0
f
;
for
(
uint32_t
ik
=
0
;
ik
<
k
;
ik
++
)
{
acc
+=
a
[
a_offset
(
im
,
ik
)]
*
b
[
b_offset
(
ik
,
in
)];
}
acc
*=
alpha
;
c
[
c_offset
(
im
,
in
)]
=
acc
;
}
}
}
// Benchmarks one micro-kernel instance on an (m, n, k) problem and leaves a
// final clean result in mat_c for the caller to validate. Prints the kernel
// id, measured time per call, and GFLOPS.
template <typename data_type, typename ALayout, typename BLayout, typename ukenrel_t>
void test_ukernel(ukenrel_t uk,
                  data_type* mat_a,
                  data_type* mat_b,
                  float* mat_c,
                  float alpha,
                  uint32_t m,
                  uint32_t n,
                  uint32_t k)
{
    ck::cpu::ThreadwiseGemmParam param;
    param.p_a = mat_a;
    param.p_b = mat_b;
    param.p_c = mat_c;
    param.Kr  = k;
    // leading dimensions are in bytes; column-major B is the packed n/8*k*n8 layout
    param.lda = (std::is_same<Row, ALayout>::value ? k : m) * sizeof(data_type);
    param.ldb = (std::is_same<Row, BLayout>::value ? n : k * 8) * sizeof(data_type);
    param.ldc = n * sizeof(float);
    param.alpha = alpha;

    printf("gemm_uk_%dx%d_%c%c: ", uk.Mr_, uk.Nr_, ALayout::name[0], BLayout::name[0]);
    fflush(stdout);
    // printf("%s: ", typeid(uk).name());fflush(stdout);

    memset(mat_c, 0, m * n * sizeof(float));

    // size the repeat count so the timed section performs roughly 7e10 flops
    int repeat = 7e10 / (2 * m * n * k);

    // warm-up pass (1/5 of the measured iterations) before timing
    for(int i = 0; i < (repeat / 5); i++)
    {
        uk.Run(&param);
    }

    auto t0 = std::chrono::high_resolution_clock::now();
    for(int i = 0; i < repeat; i++)
    {
        uk.Run(&param);
    }
    auto t1 = std::chrono::high_resolution_clock::now();

    // average microseconds per kernel invocation
    double us =
        static_cast<double>(
            std::chrono::duration_cast<std::chrono::microseconds>(t1 - t0).count()) /
        repeat;
    double gflops = static_cast<double>(2 * m * n * k) * 1e-3 / us;

    // one clean run into a zeroed C so the caller can check correctness
    // (the kernel accumulates into C, so the benchmark loops left garbage)
    memset(mat_c, 0, m * n * sizeof(float));
    uk.Run(&param);

    printf("m:%u, n:%u, k:%u, alpha:%f, cost:%lfus, GFLOPS:%lf, ", m, n, k, alpha, us, gflops);
    fflush(stdout);
}
// implement small ukernel on L1
template
<
typename
data_type
,
typename
ALayout
,
typename
BLayout
>
void
test_cpu_ukernel
(
float
alpha
,
uint32_t
m
,
uint32_t
n
,
uint32_t
k
)
{
data_type
*
mat_a
=
reinterpret_cast
<
data_type
*>
(
__aligned_malloc
(
m
*
k
*
sizeof
(
data_type
),
32
));
data_type
*
mat_b
=
reinterpret_cast
<
data_type
*>
(
__aligned_malloc
(
k
*
n
*
sizeof
(
data_type
),
32
));
float
*
mat_c
=
reinterpret_cast
<
float
*>
(
__aligned_malloc
(
m
*
n
*
sizeof
(
float
),
32
));
float
*
mat_c_ref
=
reinterpret_cast
<
float
*>
(
__aligned_malloc
(
m
*
n
*
sizeof
(
float
),
32
));
memset
(
mat_c_ref
,
0
,
m
*
n
*
sizeof
(
float
));
rand_vector
(
mat_a
,
m
*
k
);
rand_vector
(
mat_b
,
k
*
n
);
ref_cpu_gemm_uk
<
data_type
,
ALayout
,
BLayout
>
(
mat_a
,
mat_b
,
mat_c_ref
,
alpha
,
m
,
n
,
k
);
ck
::
static_for
<
0
,
std
::
tuple_size_v
<
thread_gemm_avx2_mxn_6x16_instances
>
,
1
>
{}([
&
](
auto
i
)
{
using
uk_type
=
std
::
tuple_element_t
<
i
,
thread_gemm_avx2_mxn_6x16_instances
>
;
if
constexpr
(
!
std
::
is_same
<
typename
uk_type
::
ALayout_
,
ALayout
>::
value
||
!
std
::
is_same
<
typename
uk_type
::
BLayout_
,
BLayout
>::
value
)
{
return
;
}
if
(
uk_type
::
Mr_
!=
m
||
uk_type
::
Nr_
!=
n
)
return
;
test_ukernel
<
data_type
,
ALayout
,
BLayout
>
(
uk_type
{},
mat_a
,
mat_b
,
mat_c
,
alpha
,
m
,
n
,
k
);
bool
is_valid
=
valid_vector
(
mat_c_ref
,
mat_c
,
m
*
n
);
printf
(
"vald:%s
\n
"
,
is_valid
?
"y"
:
"n"
);
// return ;
});
__aligned_free
(
mat_a
);
__aligned_free
(
mat_b
);
__aligned_free
(
mat_c
);
__aligned_free
(
mat_c_ref
);
}
// Usage: test_cpu_gemm_uk [m n k [alpha]]
// Defaults exercise the largest 6x16 tile with k = 64. For each of the four
// A/B layout combinations, the matching ukernel instance is benchmarked and
// validated against a scalar reference.
int main(int argc, char** argv)
{
    int m       = 6;
    int n       = 16;
    int k       = 64;
    float alpha = 1.0f;

    // m, n, and k must all be supplied to override the defaults
    if(argc > 3)
    {
        m = std::atoi(argv[1]);
        n = std::atoi(argv[2]);
        k = std::atoi(argv[3]);
    }
    if(argc > 4)
    {
        alpha = std::atof(argv[4]);
    }

    dump_cache_hierarchy();

    test_cpu_ukernel<float, Row, Row>(alpha, m, n, k);
    test_cpu_ukernel<float, Row, Col>(alpha, m, n, k);
    test_cpu_ukernel<float, Col, Row>(alpha, m, n, k);
    test_cpu_ukernel<float, Col, Col>(alpha, m, n, k);
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment