Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
e6ee6594
Commit
e6ee6594
authored
Apr 01, 2022
by
carlushuang
Browse files
non-temporal store support
parent
a6e310af
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
101 additions
and
38 deletions
+101
-38
include/ck/tensor_operation/cpu/thread/threadwise_gemm_avx2.hpp
...e/ck/tensor_operation/cpu/thread/threadwise_gemm_avx2.hpp
+87
-26
test/cpu_ukernel/cpu_gemm_uk.cpp
test/cpu_ukernel/cpu_gemm_uk.cpp
+14
-12
No files found.
include/ck/tensor_operation/cpu/thread/threadwise_gemm_avx2.hpp
View file @
e6ee6594
...
...
@@ -297,6 +297,7 @@ struct ThreadwiseGemmAvx2_MxN_6x16
".if (m_Mr > 5)
\n
vaddps (%%r9), %%ymm10, %%ymm10
\n
.endif
\n
"
".if (m_Mr > 5) && (m_Nr > 8)
\n
vaddps 32(%%r9), %%ymm11, %%ymm11
\n
.endif
\n
"
".if m_NTStore == 0
\n
"
" vmovups %%ymm0, (%%rax)
\n
"
".if (m_Nr > 8)
\n
vmovups %%ymm1, 32(%%rax)
\n
.endif
\n
"
".if (m_Mr > 1)
\n
vmovups %%ymm2, (%%rbx)
\n
.endif
\n
"
...
...
@@ -309,6 +310,20 @@ struct ThreadwiseGemmAvx2_MxN_6x16
".if (m_Mr > 4) && (m_Nr > 8)
\n
vmovups %%ymm9, 32(%%r8)
\n
.endif
\n
"
".if (m_Mr > 5)
\n
vmovups %%ymm10, (%%r9)
\n
.endif
\n
"
".if (m_Mr > 5) && (m_Nr > 8)
\n
vmovups %%ymm11, 32(%%r9)
\n
.endif
\n
"
".else
\n
"
" vmovntps %%ymm0, (%%rax)
\n
"
".if (m_Nr > 8)
\n
vmovntps %%ymm1, 32(%%rax)
\n
.endif
\n
"
".if (m_Mr > 1)
\n
vmovntps %%ymm2, (%%rbx)
\n
.endif
\n
"
".if (m_Mr > 1) && (m_Nr > 8)
\n
vmovntps %%ymm3, 32(%%rbx)
\n
.endif
\n
"
".if (m_Mr > 2)
\n
vmovntps %%ymm4, (%%rcx)
\n
.endif
\n
"
".if (m_Mr > 2) && (m_Nr > 8)
\n
vmovntps %%ymm5, 32(%%rcx)
\n
.endif
\n
"
".if (m_Mr > 3)
\n
vmovntps %%ymm6, (%%rdx)
\n
.endif
\n
"
".if (m_Mr > 3) && (m_Nr > 8)
\n
vmovntps %%ymm7, 32(%%rdx)
\n
.endif
\n
"
".if (m_Mr > 4)
\n
vmovntps %%ymm8, (%%r8)
\n
.endif
\n
"
".if (m_Mr > 4) && (m_Nr > 8)
\n
vmovntps %%ymm9, 32(%%r8)
\n
.endif
\n
"
".if (m_Mr > 5)
\n
vmovntps %%ymm10, (%%r9)
\n
.endif
\n
"
".if (m_Mr > 5) && (m_Nr > 8)
\n
vmovntps %%ymm11, 32(%%r9)
\n
.endif
\n
"
".endif
\n
"
"L_GemmAvx2_MxN_6x16_Exit%=:
\n
"
:
:
...
...
@@ -506,19 +521,34 @@ struct ThreadwiseGemmAvx2_MxN_6x16
if
constexpr
(
Mr
>
5
&&
Nr
>
8
)
ymm11
=
_mm256_mul_ps
(
ymm12
,
ymm11
);
}
_mm256_storeu_ps
(
p_c
+
0
*
ldc
+
0
*
8
,
ymm0
);
if
constexpr
(
Nr
>
8
)
_mm256_storeu_ps
(
p_c
+
0
*
ldc
+
1
*
8
,
ymm1
);
if
constexpr
(
Mr
>
1
)
_mm256_storeu_ps
(
p_c
+
1
*
ldc
+
0
*
8
,
ymm2
);
if
constexpr
(
Mr
>
1
&&
Nr
>
8
)
_mm256_storeu_ps
(
p_c
+
1
*
ldc
+
1
*
8
,
ymm3
);
if
constexpr
(
Mr
>
2
)
_mm256_storeu_ps
(
p_c
+
2
*
ldc
+
0
*
8
,
ymm4
);
if
constexpr
(
Mr
>
2
&&
Nr
>
8
)
_mm256_storeu_ps
(
p_c
+
2
*
ldc
+
1
*
8
,
ymm5
);
if
constexpr
(
Mr
>
3
)
_mm256_storeu_ps
(
p_c
+
3
*
ldc
+
0
*
8
,
ymm6
);
if
constexpr
(
Mr
>
3
&&
Nr
>
8
)
_mm256_storeu_ps
(
p_c
+
3
*
ldc
+
1
*
8
,
ymm7
);
if
constexpr
(
Mr
>
4
)
_mm256_storeu_ps
(
p_c
+
4
*
ldc
+
0
*
8
,
ymm8
);
if
constexpr
(
Mr
>
4
&&
Nr
>
8
)
_mm256_storeu_ps
(
p_c
+
4
*
ldc
+
1
*
8
,
ymm9
);
if
constexpr
(
Mr
>
5
)
_mm256_storeu_ps
(
p_c
+
5
*
ldc
+
0
*
8
,
ymm10
);
if
constexpr
(
Mr
>
5
&&
Nr
>
8
)
_mm256_storeu_ps
(
p_c
+
5
*
ldc
+
1
*
8
,
ymm11
);
// clang-format on
if
constexpr
(
NonTemporalStore
)
{
if
constexpr
(
Nr
>
8
)
_mm256_stream_ps
(
p_c
+
0
*
ldc
+
1
*
8
,
ymm1
);
if
constexpr
(
Mr
>
1
)
_mm256_stream_ps
(
p_c
+
1
*
ldc
+
0
*
8
,
ymm2
);
if
constexpr
(
Mr
>
1
&&
Nr
>
8
)
_mm256_stream_ps
(
p_c
+
1
*
ldc
+
1
*
8
,
ymm3
);
if
constexpr
(
Mr
>
2
)
_mm256_stream_ps
(
p_c
+
2
*
ldc
+
0
*
8
,
ymm4
);
if
constexpr
(
Mr
>
2
&&
Nr
>
8
)
_mm256_stream_ps
(
p_c
+
2
*
ldc
+
1
*
8
,
ymm5
);
if
constexpr
(
Mr
>
3
)
_mm256_stream_ps
(
p_c
+
3
*
ldc
+
0
*
8
,
ymm6
);
if
constexpr
(
Mr
>
3
&&
Nr
>
8
)
_mm256_stream_ps
(
p_c
+
3
*
ldc
+
1
*
8
,
ymm7
);
if
constexpr
(
Mr
>
4
)
_mm256_stream_ps
(
p_c
+
4
*
ldc
+
0
*
8
,
ymm8
);
if
constexpr
(
Mr
>
4
&&
Nr
>
8
)
_mm256_stream_ps
(
p_c
+
4
*
ldc
+
1
*
8
,
ymm9
);
if
constexpr
(
Mr
>
5
)
_mm256_stream_ps
(
p_c
+
5
*
ldc
+
0
*
8
,
ymm10
);
if
constexpr
(
Mr
>
5
&&
Nr
>
8
)
_mm256_stream_ps
(
p_c
+
5
*
ldc
+
1
*
8
,
ymm11
);
}
else
{
_mm256_storeu_ps
(
p_c
+
0
*
ldc
+
0
*
8
,
ymm0
);
if
constexpr
(
Nr
>
8
)
_mm256_storeu_ps
(
p_c
+
0
*
ldc
+
1
*
8
,
ymm1
);
if
constexpr
(
Mr
>
1
)
_mm256_storeu_ps
(
p_c
+
1
*
ldc
+
0
*
8
,
ymm2
);
if
constexpr
(
Mr
>
1
&&
Nr
>
8
)
_mm256_storeu_ps
(
p_c
+
1
*
ldc
+
1
*
8
,
ymm3
);
if
constexpr
(
Mr
>
2
)
_mm256_storeu_ps
(
p_c
+
2
*
ldc
+
0
*
8
,
ymm4
);
if
constexpr
(
Mr
>
2
&&
Nr
>
8
)
_mm256_storeu_ps
(
p_c
+
2
*
ldc
+
1
*
8
,
ymm5
);
if
constexpr
(
Mr
>
3
)
_mm256_storeu_ps
(
p_c
+
3
*
ldc
+
0
*
8
,
ymm6
);
if
constexpr
(
Mr
>
3
&&
Nr
>
8
)
_mm256_storeu_ps
(
p_c
+
3
*
ldc
+
1
*
8
,
ymm7
);
if
constexpr
(
Mr
>
4
)
_mm256_storeu_ps
(
p_c
+
4
*
ldc
+
0
*
8
,
ymm8
);
if
constexpr
(
Mr
>
4
&&
Nr
>
8
)
_mm256_storeu_ps
(
p_c
+
4
*
ldc
+
1
*
8
,
ymm9
);
if
constexpr
(
Mr
>
5
)
_mm256_storeu_ps
(
p_c
+
5
*
ldc
+
0
*
8
,
ymm10
);
if
constexpr
(
Mr
>
5
&&
Nr
>
8
)
_mm256_storeu_ps
(
p_c
+
5
*
ldc
+
1
*
8
,
ymm11
);
}
// clang-format on
#endif
}
};
...
...
@@ -803,6 +833,7 @@ struct ThreadwiseGemmAvx2_MxN_4x24
".if (m_Mr > 3) && (m_Nr > 8)
\n
vaddps 32(%%rdx), %%ymm10, %%ymm10
\n
.endif
\n
"
".if (m_Mr > 3) && (m_Nr >16)
\n
vaddps 64(%%rdx), %%ymm11, %%ymm11
\n
.endif
\n
"
".if m_NTStore == 0
\n
"
" vmovups %%ymm0, (%%rax)
\n
"
".if (m_Nr > 8)
\n
vmovups %%ymm1, 32(%%rax)
\n
.endif
\n
"
".if (m_Nr >16)
\n
vmovups %%ymm2, 64(%%rax)
\n
.endif
\n
"
...
...
@@ -815,6 +846,20 @@ struct ThreadwiseGemmAvx2_MxN_4x24
".if (m_Mr > 3)
\n
vmovups %%ymm9, (%%rdx)
\n
.endif
\n
"
".if (m_Mr > 3) && (m_Nr > 8)
\n
vmovups %%ymm10, 32(%%rdx)
\n
.endif
\n
"
".if (m_Mr > 3) && (m_Nr >16)
\n
vmovups %%ymm11, 64(%%rdx)
\n
.endif
\n
"
".else
\n
"
" vmovntps %%ymm0, (%%rax)
\n
"
".if (m_Nr > 8)
\n
vmovntps %%ymm1, 32(%%rax)
\n
.endif
\n
"
".if (m_Nr >16)
\n
vmovntps %%ymm2, 64(%%rax)
\n
.endif
\n
"
".if (m_Mr > 1)
\n
vmovntps %%ymm3, (%%rbx)
\n
.endif
\n
"
".if (m_Mr > 1) && (m_Nr > 8)
\n
vmovntps %%ymm4, 32(%%rbx)
\n
.endif
\n
"
".if (m_Mr > 1) && (m_Nr >16)
\n
vmovntps %%ymm5, 64(%%rbx)
\n
.endif
\n
"
".if (m_Mr > 2)
\n
vmovntps %%ymm6, (%%rcx)
\n
.endif
\n
"
".if (m_Mr > 2) && (m_Nr > 8)
\n
vmovntps %%ymm7, 32(%%rcx)
\n
.endif
\n
"
".if (m_Mr > 2) && (m_Nr >16)
\n
vmovntps %%ymm8, 64(%%rcx)
\n
.endif
\n
"
".if (m_Mr > 3)
\n
vmovntps %%ymm9, (%%rdx)
\n
.endif
\n
"
".if (m_Mr > 3) && (m_Nr > 8)
\n
vmovntps %%ymm10, 32(%%rdx)
\n
.endif
\n
"
".if (m_Mr > 3) && (m_Nr >16)
\n
vmovntps %%ymm11, 64(%%rdx)
\n
.endif
\n
"
".endif
\n
"
"L_GemmAvx2_MxN_4x24_Exit%=:
\n
"
:
:
...
...
@@ -1012,19 +1057,35 @@ struct ThreadwiseGemmAvx2_MxN_4x24
if
constexpr
(
Mr
>
3
&&
Nr
>
16
)
ymm11
=
_mm256_mul_ps
(
ymm12
,
ymm11
);
}
_mm256_storeu_ps
(
p_c
+
0
*
ldc
+
0
*
8
,
ymm0
);
if
constexpr
(
Nr
>
8
)
_mm256_storeu_ps
(
p_c
+
0
*
ldc
+
1
*
8
,
ymm1
);
if
constexpr
(
Nr
>
16
)
_mm256_storeu_ps
(
p_c
+
0
*
ldc
+
2
*
8
,
ymm2
);
if
constexpr
(
Mr
>
1
)
_mm256_storeu_ps
(
p_c
+
1
*
ldc
+
0
*
8
,
ymm3
);
if
constexpr
(
Mr
>
1
&&
Nr
>
8
)
_mm256_storeu_ps
(
p_c
+
1
*
ldc
+
1
*
8
,
ymm4
);
if
constexpr
(
Mr
>
1
&&
Nr
>
16
)
_mm256_storeu_ps
(
p_c
+
1
*
ldc
+
2
*
8
,
ymm5
);
if
constexpr
(
Mr
>
2
)
_mm256_storeu_ps
(
p_c
+
2
*
ldc
+
0
*
8
,
ymm6
);
if
constexpr
(
Mr
>
2
&&
Nr
>
8
)
_mm256_storeu_ps
(
p_c
+
2
*
ldc
+
1
*
8
,
ymm7
);
if
constexpr
(
Mr
>
2
&&
Nr
>
16
)
_mm256_storeu_ps
(
p_c
+
2
*
ldc
+
2
*
8
,
ymm8
);
if
constexpr
(
Mr
>
3
)
_mm256_storeu_ps
(
p_c
+
3
*
ldc
+
0
*
8
,
ymm9
);
if
constexpr
(
Mr
>
3
&&
Nr
>
8
)
_mm256_storeu_ps
(
p_c
+
3
*
ldc
+
1
*
8
,
ymm10
);
if
constexpr
(
Mr
>
3
&&
Nr
>
16
)
_mm256_storeu_ps
(
p_c
+
3
*
ldc
+
2
*
8
,
ymm11
);
// clang-format on
if
constexpr
(
NonTemporalStore
)
{
_mm256_stream_ps
(
p_c
+
0
*
ldc
+
0
*
8
,
ymm0
);
if
constexpr
(
Nr
>
8
)
_mm256_stream_ps
(
p_c
+
0
*
ldc
+
1
*
8
,
ymm1
);
if
constexpr
(
Nr
>
16
)
_mm256_stream_ps
(
p_c
+
0
*
ldc
+
2
*
8
,
ymm2
);
if
constexpr
(
Mr
>
1
)
_mm256_stream_ps
(
p_c
+
1
*
ldc
+
0
*
8
,
ymm3
);
if
constexpr
(
Mr
>
1
&&
Nr
>
8
)
_mm256_stream_ps
(
p_c
+
1
*
ldc
+
1
*
8
,
ymm4
);
if
constexpr
(
Mr
>
1
&&
Nr
>
16
)
_mm256_stream_ps
(
p_c
+
1
*
ldc
+
2
*
8
,
ymm5
);
if
constexpr
(
Mr
>
2
)
_mm256_stream_ps
(
p_c
+
2
*
ldc
+
0
*
8
,
ymm6
);
if
constexpr
(
Mr
>
2
&&
Nr
>
8
)
_mm256_stream_ps
(
p_c
+
2
*
ldc
+
1
*
8
,
ymm7
);
if
constexpr
(
Mr
>
2
&&
Nr
>
16
)
_mm256_stream_ps
(
p_c
+
2
*
ldc
+
2
*
8
,
ymm8
);
if
constexpr
(
Mr
>
3
)
_mm256_stream_ps
(
p_c
+
3
*
ldc
+
0
*
8
,
ymm9
);
if
constexpr
(
Mr
>
3
&&
Nr
>
8
)
_mm256_stream_ps
(
p_c
+
3
*
ldc
+
1
*
8
,
ymm10
);
if
constexpr
(
Mr
>
3
&&
Nr
>
16
)
_mm256_stream_ps
(
p_c
+
3
*
ldc
+
2
*
8
,
ymm11
);
}
else
{
_mm256_storeu_ps
(
p_c
+
0
*
ldc
+
0
*
8
,
ymm0
);
if
constexpr
(
Nr
>
8
)
_mm256_storeu_ps
(
p_c
+
0
*
ldc
+
1
*
8
,
ymm1
);
if
constexpr
(
Nr
>
16
)
_mm256_storeu_ps
(
p_c
+
0
*
ldc
+
2
*
8
,
ymm2
);
if
constexpr
(
Mr
>
1
)
_mm256_storeu_ps
(
p_c
+
1
*
ldc
+
0
*
8
,
ymm3
);
if
constexpr
(
Mr
>
1
&&
Nr
>
8
)
_mm256_storeu_ps
(
p_c
+
1
*
ldc
+
1
*
8
,
ymm4
);
if
constexpr
(
Mr
>
1
&&
Nr
>
16
)
_mm256_storeu_ps
(
p_c
+
1
*
ldc
+
2
*
8
,
ymm5
);
if
constexpr
(
Mr
>
2
)
_mm256_storeu_ps
(
p_c
+
2
*
ldc
+
0
*
8
,
ymm6
);
if
constexpr
(
Mr
>
2
&&
Nr
>
8
)
_mm256_storeu_ps
(
p_c
+
2
*
ldc
+
1
*
8
,
ymm7
);
if
constexpr
(
Mr
>
2
&&
Nr
>
16
)
_mm256_storeu_ps
(
p_c
+
2
*
ldc
+
2
*
8
,
ymm8
);
if
constexpr
(
Mr
>
3
)
_mm256_storeu_ps
(
p_c
+
3
*
ldc
+
0
*
8
,
ymm9
);
if
constexpr
(
Mr
>
3
&&
Nr
>
8
)
_mm256_storeu_ps
(
p_c
+
3
*
ldc
+
1
*
8
,
ymm10
);
if
constexpr
(
Mr
>
3
&&
Nr
>
16
)
_mm256_storeu_ps
(
p_c
+
3
*
ldc
+
2
*
8
,
ymm11
);
}
// clang-format on
#endif
}
};
...
...
test/cpu_ukernel/cpu_gemm_uk.cpp
View file @
e6ee6594
...
...
@@ -54,17 +54,18 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
using
AType
=
float
;
using
BType
=
float
;
using
CType
=
float
;
#define NTStore false
template
<
typename
ALayout
,
typename
BLayout
>
using
thread_gemm_avx2_mxn_6x16_instances
=
std
::
tuple
<
// clang-format off
// FloatA FloatB FloatC ALayout BLayout NTStore
ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE
(
AType
,
BType
,
CType
,
ALayout
,
BLayout
,
fals
e
),
ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE
(
AType
,
BType
,
CType
,
ALayout
,
BLayout
,
fals
e
),
ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE
(
AType
,
BType
,
CType
,
ALayout
,
BLayout
,
fals
e
),
ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE
(
AType
,
BType
,
CType
,
ALayout
,
BLayout
,
fals
e
)
ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE
(
AType
,
BType
,
CType
,
ALayout
,
BLayout
,
NTStor
e
),
ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE
(
AType
,
BType
,
CType
,
ALayout
,
BLayout
,
NTStor
e
),
ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE
(
AType
,
BType
,
CType
,
ALayout
,
BLayout
,
NTStor
e
),
ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE
(
AType
,
BType
,
CType
,
ALayout
,
BLayout
,
NTStor
e
)
// ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(AType, BType, CType, ALayout, BLayout,
fals
e)
// ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(AType, BType, CType, ALayout, BLayout,
NTStor
e)
// clang-format on
>
;
...
...
@@ -72,10 +73,10 @@ template <typename ALayout, typename BLayout>
using
thread_gemm_avx2_mxn_4x24_instances
=
std
::
tuple
<
// clang-format off
// FloatA FloatB FloatC ALayout BLayout NTStore
ITERATE_THREAD_GEMM_AVX2_MXN_4X24_INSTANCE
(
AType
,
BType
,
CType
,
ALayout
,
BLayout
,
fals
e
),
ITERATE_THREAD_GEMM_AVX2_MXN_4X24_INSTANCE
(
AType
,
BType
,
CType
,
ALayout
,
BLayout
,
fals
e
),
ITERATE_THREAD_GEMM_AVX2_MXN_4X24_INSTANCE
(
AType
,
BType
,
CType
,
ALayout
,
BLayout
,
fals
e
),
ITERATE_THREAD_GEMM_AVX2_MXN_4X24_INSTANCE
(
AType
,
BType
,
CType
,
ALayout
,
BLayout
,
fals
e
)
ITERATE_THREAD_GEMM_AVX2_MXN_4X24_INSTANCE
(
AType
,
BType
,
CType
,
ALayout
,
BLayout
,
NTStor
e
),
ITERATE_THREAD_GEMM_AVX2_MXN_4X24_INSTANCE
(
AType
,
BType
,
CType
,
ALayout
,
BLayout
,
NTStor
e
),
ITERATE_THREAD_GEMM_AVX2_MXN_4X24_INSTANCE
(
AType
,
BType
,
CType
,
ALayout
,
BLayout
,
NTStor
e
),
ITERATE_THREAD_GEMM_AVX2_MXN_4X24_INSTANCE
(
AType
,
BType
,
CType
,
ALayout
,
BLayout
,
NTStor
e
)
// clang-format on
>
;
...
...
@@ -306,8 +307,10 @@ void test_ukernel(ukenrel_t uk,
#pragma omp parallel reduction(+ : us)
{
int
tid
=
omp_get_thread_num
();
float
*
private_c
=
reinterpret_cast
<
float
*>
(
malloc
(
m
*
n
*
sizeof
(
float
)));
int
tid
=
omp_get_thread_num
();
DeviceAlignedMemCPU
private_c_mem
(
m
*
n
*
sizeof
(
float
),
32
);
float
*
private_c
=
reinterpret_cast
<
float
*>
(
private_c_mem
.
mpDeviceBuf
);
// float * private_c = mat_c + tid * m * n;
ck
::
cpu
::
ThreadwiseGemmParam
param
;
param
.
p_a
=
mat_a
;
...
...
@@ -343,7 +346,6 @@ void test_ukernel(ukenrel_t uk,
invoke_uk
(
param
,
private_c
);
memcpy
(
mat_c
+
tid
*
m
*
n
,
private_c
,
m
*
n
*
sizeof
(
float
));
free
(
private_c
);
}
us
=
us
/
max_threads
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment