gaoqiong / composable_kernel · Commits

Commit f10bfbf9, authored Mar 28, 2022 by carlushuang
parent 3cc7ac0a

    add avx2 intrinsic

Showing 2 changed files with 391 additions and 2 deletions:

  include/ck/config.hpp                                            +4   -0
  include/ck/tensor_operation/cpu/thread/threadwise_gemm_avx2.hpp  +387 -2
include/ck/config.hpp

...
@@ -163,6 +163,10 @@
 #define CK_WORKAROUND_GITHUB_135 1
 #endif
 
+#ifndef CK_USE_X86_INLINE_ASM
+#define CK_USE_X86_INLINE_ASM 1
+#endif
+
 namespace ck {
 
 enum struct InMemoryDataOperationEnum_t
...
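The new macro defaults the micro-kernels to the hand-written inline-asm path. Below is a minimal sketch (not part of the commit; the translation unit is hypothetical) of how the switch is consumed, assuming the usual practice of overriding an #ifndef-guarded default on the compiler command line with -DCK_USE_X86_INLINE_ASM=0:

// hypothetical consumer TU; not part of the commit
#include "ck/config.hpp"

#if CK_USE_X86_INLINE_ASM
// the hand-written inline-asm micro-kernels are compiled in (the default)
#else
// the portable AVX2-intrinsic micro-kernels are compiled in
#endif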
include/ck/tensor_operation/cpu/thread/threadwise_gemm_avx2.hpp

 #ifndef CK_THREADWISE_GEMM_AVX2_HPP
 #define CK_THREADWISE_GEMM_AVX2_HPP
 
+#if CK_USE_X86_INLINE_ASM == 0
+#include <immintrin.h>
+#endif
+
 #include "common_header.hpp"
 #include "tensor_layout.hpp"
 #include "math.hpp"
...
@@ -51,7 +54,7 @@ struct ThreadwiseGemmAvx2_MxN_6x16
  * lda/ldb/ldc all in unit of byte
  *
  */
+#if CK_USE_X86_INLINE_ASM
 // clang-format off
 __asm__ __volatile__(
 "L_GemmAvx2_MxN_6x16_Entry%=:\n"
...
@@ -326,6 +329,197 @@ struct ThreadwiseGemmAvx2_MxN_6x16
 "ymm14", "ymm15"
 );
 // clang-format on
+#else
+        __m256 ymm0, ymm1, ymm2,  ymm3,  ymm4,  ymm5,  ymm6,  ymm7,
+               ymm8, ymm9, ymm10, ymm11, ymm12, ymm13, ymm14, ymm15;
+
+        const FloatA* p_a = reinterpret_cast<const FloatA*>(param->p_a);
+        const FloatB* p_b = reinterpret_cast<const FloatB*>(param->p_b);
+        float* p_c        = reinterpret_cast<float*>(param->p_c);
+        uint64_t Kr       = param->Kr;
+        uint64_t lda      = param->lda / sizeof(FloatA);
+        uint64_t ldb      = param->ldb / sizeof(FloatB);
+        uint64_t ldc      = param->ldc / sizeof(float);
+        // float alpha = param->alpha;
+
+        auto broadcast_a = [&](const int i_k, const int i_m, __m256& ymm) {
+            if constexpr(std::is_same<FloatA, float>::value)
+            {
+                if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, ALayout>::value)
+                {
+                    ymm = _mm256_broadcast_ss(p_a + i_m * lda + i_k);
+                }
+                else
+                {
+                    ymm = _mm256_broadcast_ss(p_a + i_k * Mr + i_m);
+                }
+            }
+            else
+            {
+                // static_assert();
+                // Not supported for now. The intrinsic path should perhaps never take fp16
+                // input, convert, and broadcast to ymm (not enough registers).
+                // The code below seems to produce wrong results...
+                if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, ALayout>::value)
+                {
+                    ymm = _mm256_cvtph_ps(_mm_set1_epi16(*(p_a + i_m * lda + i_k)));
+                }
+                else
+                {
+                    ymm = _mm256_cvtph_ps(_mm_set1_epi16(*(p_a + i_k * Mr + i_m)));
+                }
+            }
+        };
+
+        auto load_b = [&](const int i_k, const int i_n, __m256& ymm) {
+            if constexpr(std::is_same<FloatB, float>::value)
+            {
+                if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value)
+                {
+                    ymm = _mm256_loadu_ps(p_b + i_k * Nr + i_n * 8);
+                }
+                else
+                {
+                    ymm = _mm256_loadu_ps(p_b + i_k * 8 + i_n * ldb);
+                }
+            }
+            else
+            {
+                if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value)
+                {
+                    ymm = _mm256_cvtph_ps(_mm_loadu_si128(
+                        reinterpret_cast<__m128i const*>(p_b + i_k * Nr + i_n * 8)));
+                }
+                else
+                {
+                    ymm = _mm256_cvtph_ps(_mm_loadu_si128(
+                        reinterpret_cast<__m128i const*>(p_b + i_k * 8 + i_n * ldb)));
+                }
+            }
+        };
+
+        // clang-format off
+        ymm0 = _mm256_loadu_ps(p_c + 0 * ldc + 0 * 8);
+        if constexpr(Nr > 8)           ymm1  = _mm256_loadu_ps(p_c + 0 * ldc + 1 * 8);
+        if constexpr(Mr > 1)           ymm2  = _mm256_loadu_ps(p_c + 1 * ldc + 0 * 8);
+        if constexpr(Mr > 1 && Nr > 8) ymm3  = _mm256_loadu_ps(p_c + 1 * ldc + 1 * 8);
+        if constexpr(Mr > 2)           ymm4  = _mm256_loadu_ps(p_c + 2 * ldc + 0 * 8);
+        if constexpr(Mr > 2 && Nr > 8) ymm5  = _mm256_loadu_ps(p_c + 2 * ldc + 1 * 8);
+        if constexpr(Mr > 3)           ymm6  = _mm256_loadu_ps(p_c + 3 * ldc + 0 * 8);
+        if constexpr(Mr > 3 && Nr > 8) ymm7  = _mm256_loadu_ps(p_c + 3 * ldc + 1 * 8);
+        if constexpr(Mr > 4)           ymm8  = _mm256_loadu_ps(p_c + 4 * ldc + 0 * 8);
+        if constexpr(Mr > 4 && Nr > 8) ymm9  = _mm256_loadu_ps(p_c + 4 * ldc + 1 * 8);
+        if constexpr(Mr > 5)           ymm10 = _mm256_loadu_ps(p_c + 5 * ldc + 0 * 8);
+        if constexpr(Mr > 5 && Nr > 8) ymm11 = _mm256_loadu_ps(p_c + 5 * ldc + 1 * 8);
+
+        while(Kr > 4)
+        {
+#pragma unroll
+            for(int i_k = 0; i_k < 4; i_k++)
+            {
+                load_b(i_k, 0, ymm12);
+                if constexpr(Nr > 8) load_b(i_k, 1, ymm13);
+                broadcast_a(i_k, 0, ymm14);
+                if constexpr(Mr > 1) broadcast_a(i_k, 1, ymm15);
+
+                ymm0 = _mm256_fmadd_ps(ymm12, ymm14, ymm0);
+                if constexpr(Nr > 8)           ymm1 = _mm256_fmadd_ps(ymm13, ymm14, ymm1);
+                if constexpr(Mr > 1)           ymm2 = _mm256_fmadd_ps(ymm12, ymm15, ymm2);
+                if constexpr(Mr > 1 && Nr > 8) ymm3 = _mm256_fmadd_ps(ymm13, ymm15, ymm3);
+
+                if constexpr(Mr > 2) broadcast_a(i_k, 2, ymm14);
+                if constexpr(Mr > 3) broadcast_a(i_k, 3, ymm15);
+                if constexpr(Mr > 2)           ymm4 = _mm256_fmadd_ps(ymm12, ymm14, ymm4);
+                if constexpr(Mr > 2 && Nr > 8) ymm5 = _mm256_fmadd_ps(ymm13, ymm14, ymm5);
+                if constexpr(Mr > 3)           ymm6 = _mm256_fmadd_ps(ymm12, ymm15, ymm6);
+                if constexpr(Mr > 3 && Nr > 8) ymm7 = _mm256_fmadd_ps(ymm13, ymm15, ymm7);
+
+                if constexpr(Mr > 4) broadcast_a(i_k, 4, ymm14);
+                if constexpr(Mr > 5) broadcast_a(i_k, 5, ymm15);
+                if constexpr(Mr > 4)           ymm8  = _mm256_fmadd_ps(ymm12, ymm14, ymm8);
+                if constexpr(Mr > 4 && Nr > 8) ymm9  = _mm256_fmadd_ps(ymm13, ymm14, ymm9);
+                if constexpr(Mr > 5)           ymm10 = _mm256_fmadd_ps(ymm12, ymm15, ymm10);
+                if constexpr(Mr > 5 && Nr > 8) ymm11 = _mm256_fmadd_ps(ymm13, ymm15, ymm11);
+            }
+            if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, ALayout>::value)
+            { p_a += 4; }
+            else
+            { p_a += Mr * 4; }
+            if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value)
+            { p_b += Nr * 4; }
+            else
+            { p_b += 4 * 8; }
+            Kr -= 4;
+        }
+
+        while(Kr != 0)
+        {
+            load_b(0, 0, ymm12);
+            if constexpr(Nr > 8) load_b(0, 1, ymm13);
+            broadcast_a(0, 0, ymm14);
+            if constexpr(Mr > 1) broadcast_a(0, 1, ymm15);
+
+            ymm0 = _mm256_fmadd_ps(ymm12, ymm14, ymm0);
+            if constexpr(Nr > 8)           ymm1 = _mm256_fmadd_ps(ymm13, ymm14, ymm1);
+            if constexpr(Mr > 1)           ymm2 = _mm256_fmadd_ps(ymm12, ymm15, ymm2);
+            if constexpr(Mr > 1 && Nr > 8) ymm3 = _mm256_fmadd_ps(ymm13, ymm15, ymm3);
+
+            if constexpr(Mr > 2) broadcast_a(0, 2, ymm14);
+            if constexpr(Mr > 3) broadcast_a(0, 3, ymm15);
+            if constexpr(Mr > 2)           ymm4 = _mm256_fmadd_ps(ymm12, ymm14, ymm4);
+            if constexpr(Mr > 2 && Nr > 8) ymm5 = _mm256_fmadd_ps(ymm13, ymm14, ymm5);
+            if constexpr(Mr > 3)           ymm6 = _mm256_fmadd_ps(ymm12, ymm15, ymm6);
+            if constexpr(Mr > 3 && Nr > 8) ymm7 = _mm256_fmadd_ps(ymm13, ymm15, ymm7);
+
+            if constexpr(Mr > 4) broadcast_a(0, 4, ymm14);
+            if constexpr(Mr > 5) broadcast_a(0, 5, ymm15);
+            if constexpr(Mr > 4)           ymm8  = _mm256_fmadd_ps(ymm12, ymm14, ymm8);
+            if constexpr(Mr > 4 && Nr > 8) ymm9  = _mm256_fmadd_ps(ymm13, ymm14, ymm9);
+            if constexpr(Mr > 5)           ymm10 = _mm256_fmadd_ps(ymm12, ymm15, ymm10);
+            if constexpr(Mr > 5 && Nr > 8) ymm11 = _mm256_fmadd_ps(ymm13, ymm15, ymm11);
+
+            if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, ALayout>::value)
+            { p_a += 1; }
+            else
+            { p_a += Mr * 1; }
+            if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value)
+            { p_b += Nr * 1; }
+            else
+            { p_b += 1 * 8; }
+            Kr--;
+        }
+
+        if(param->alpha != 1.0f)
+        {
+            ymm12 = _mm256_broadcast_ss(reinterpret_cast<float const*>(&param->alpha));
+            ymm0 = _mm256_mul_ps(ymm12, ymm0);
+            if constexpr(Nr > 8)           ymm1  = _mm256_mul_ps(ymm12, ymm1);
+            if constexpr(Mr > 1)           ymm2  = _mm256_mul_ps(ymm12, ymm2);
+            if constexpr(Mr > 1 && Nr > 8) ymm3  = _mm256_mul_ps(ymm12, ymm3);
+            if constexpr(Mr > 2)           ymm4  = _mm256_mul_ps(ymm12, ymm4);
+            if constexpr(Mr > 2 && Nr > 8) ymm5  = _mm256_mul_ps(ymm12, ymm5);
+            if constexpr(Mr > 3)           ymm6  = _mm256_mul_ps(ymm12, ymm6);
+            if constexpr(Mr > 3 && Nr > 8) ymm7  = _mm256_mul_ps(ymm12, ymm7);
+            if constexpr(Mr > 4)           ymm8  = _mm256_mul_ps(ymm12, ymm8);
+            if constexpr(Mr > 4 && Nr > 8) ymm9  = _mm256_mul_ps(ymm12, ymm9);
+            if constexpr(Mr > 5)           ymm10 = _mm256_mul_ps(ymm12, ymm10);
+            if constexpr(Mr > 5 && Nr > 8) ymm11 = _mm256_mul_ps(ymm12, ymm11);
+        }
+
+        _mm256_storeu_ps(p_c + 0 * ldc + 0 * 8, ymm0);
+        if constexpr(Nr > 8)           _mm256_storeu_ps(p_c + 0 * ldc + 1 * 8, ymm1);
+        if constexpr(Mr > 1)           _mm256_storeu_ps(p_c + 1 * ldc + 0 * 8, ymm2);
+        if constexpr(Mr > 1 && Nr > 8) _mm256_storeu_ps(p_c + 1 * ldc + 1 * 8, ymm3);
+        if constexpr(Mr > 2)           _mm256_storeu_ps(p_c + 2 * ldc + 0 * 8, ymm4);
+        if constexpr(Mr > 2 && Nr > 8) _mm256_storeu_ps(p_c + 2 * ldc + 1 * 8, ymm5);
+        if constexpr(Mr > 3)           _mm256_storeu_ps(p_c + 3 * ldc + 0 * 8, ymm6);
+        if constexpr(Mr > 3 && Nr > 8) _mm256_storeu_ps(p_c + 3 * ldc + 1 * 8, ymm7);
+        if constexpr(Mr > 4)           _mm256_storeu_ps(p_c + 4 * ldc + 0 * 8, ymm8);
+        if constexpr(Mr > 4 && Nr > 8) _mm256_storeu_ps(p_c + 4 * ldc + 1 * 8, ymm9);
+        if constexpr(Mr > 5)           _mm256_storeu_ps(p_c + 5 * ldc + 0 * 8, ymm10);
+        if constexpr(Mr > 5 && Nr > 8) _mm256_storeu_ps(p_c + 5 * ldc + 1 * 8, ymm11);
+        // clang-format on
+#endif
     }
 };
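For reference, the added intrinsic path boils down to one rank-1 update per k step: broadcast a single A element across all eight lanes, load an 8-wide row of packed B, and FMA into a C accumulator that stays resident in a register for the whole K loop. A standalone, simplified sketch of that pattern (illustrative names, a 1x8 tile, row-major A, B packed in 8-wide rows; not the CK API), compilable with -mavx2 -mfma:

#include <immintrin.h>
#include <cstdint>

// C[0..7] += sum_k A[k] * B[k][0..7], with the C tile kept in a register
// across the whole K loop, as in the kernels above.
void gemm_1x8(const float* a, const float* b, float* c, uint64_t K)
{
    __m256 acc = _mm256_loadu_ps(c);              // load the C tile once
    for(uint64_t k = 0; k < K; ++k)
    {
        __m256 va = _mm256_broadcast_ss(a + k);   // splat A[k] into all lanes
        __m256 vb = _mm256_loadu_ps(b + k * 8);   // one packed 8-float B row
        acc       = _mm256_fmadd_ps(vb, va, acc); // acc += vb * va
    }
    _mm256_storeu_ps(c, acc);                     // write the tile back once
}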
...
@@ -370,7 +564,7 @@ struct ThreadwiseGemmAvx2_MxN_4x24
  * lda/ldb/ldc all in unit of byte
  *
  */
+#if CK_USE_X86_INLINE_ASM
 // clang-format off
 __asm__ __volatile__(
 "L_GemmAvx2_MxN_4x24_Entry%=:\n"
...
@@ -641,6 +835,197 @@ struct ThreadwiseGemmAvx2_MxN_4x24
 "ymm14", "ymm15"
 );
 // clang-format on
+#else
+        __m256 ymm0, ymm1, ymm2,  ymm3,  ymm4,  ymm5,  ymm6,  ymm7,
+               ymm8, ymm9, ymm10, ymm11, ymm12, ymm13, ymm14, ymm15;
+
+        const FloatA* p_a = reinterpret_cast<const FloatA*>(param->p_a);
+        const FloatB* p_b = reinterpret_cast<const FloatB*>(param->p_b);
+        float* p_c        = reinterpret_cast<float*>(param->p_c);
+        uint64_t Kr       = param->Kr;
+        uint64_t lda      = param->lda / sizeof(FloatA);
+        uint64_t ldb      = param->ldb / sizeof(FloatB);
+        uint64_t ldc      = param->ldc / sizeof(float);
+        // float alpha = param->alpha;
+
+        auto broadcast_a = [&](const int i_k, const int i_m, __m256& ymm) {
+            if constexpr(std::is_same<FloatA, float>::value)
+            {
+                if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, ALayout>::value)
+                {
+                    ymm = _mm256_broadcast_ss(p_a + i_m * lda + i_k);
+                }
+                else
+                {
+                    ymm = _mm256_broadcast_ss(p_a + i_k * Mr + i_m);
+                }
+            }
+            else
+            {
+                // static_assert();
+                // Not supported for now. The intrinsic path should perhaps never take fp16
+                // input, convert, and broadcast to ymm (not enough registers).
+                // The code below seems to produce wrong results...
+                if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, ALayout>::value)
+                {
+                    ymm = _mm256_cvtph_ps(_mm_set1_epi16(*(p_a + i_m * lda + i_k)));
+                }
+                else
+                {
+                    ymm = _mm256_cvtph_ps(_mm_set1_epi16(*(p_a + i_k * Mr + i_m)));
+                }
+            }
+        };
+
+        auto load_b = [&](const int i_k, const int i_n, __m256& ymm) {
+            if constexpr(std::is_same<FloatB, float>::value)
+            {
+                if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value)
+                {
+                    ymm = _mm256_loadu_ps(p_b + i_k * Nr + i_n * 8);
+                }
+                else
+                {
+                    ymm = _mm256_loadu_ps(p_b + i_k * 8 + i_n * ldb);
+                }
+            }
+            else
+            {
+                if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value)
+                {
+                    ymm = _mm256_cvtph_ps(_mm_loadu_si128(
+                        reinterpret_cast<__m128i const*>(p_b + i_k * Nr + i_n * 8)));
+                }
+                else
+                {
+                    ymm = _mm256_cvtph_ps(_mm_loadu_si128(
+                        reinterpret_cast<__m128i const*>(p_b + i_k * 8 + i_n * ldb)));
+                }
+            }
+        };
+
+        // clang-format off
+        ymm0 = _mm256_loadu_ps(p_c + 0 * ldc + 0 * 8);
+        if constexpr(Nr > 8)            ymm1  = _mm256_loadu_ps(p_c + 0 * ldc + 1 * 8);
+        if constexpr(Nr > 16)           ymm2  = _mm256_loadu_ps(p_c + 0 * ldc + 2 * 8);
+        if constexpr(Mr > 1)            ymm3  = _mm256_loadu_ps(p_c + 1 * ldc + 0 * 8);
+        if constexpr(Mr > 1 && Nr > 8)  ymm4  = _mm256_loadu_ps(p_c + 1 * ldc + 1 * 8);
+        if constexpr(Mr > 1 && Nr > 16) ymm5  = _mm256_loadu_ps(p_c + 1 * ldc + 2 * 8);
+        if constexpr(Mr > 2)            ymm6  = _mm256_loadu_ps(p_c + 2 * ldc + 0 * 8);
+        if constexpr(Mr > 2 && Nr > 8)  ymm7  = _mm256_loadu_ps(p_c + 2 * ldc + 1 * 8);
+        if constexpr(Mr > 2 && Nr > 16) ymm8  = _mm256_loadu_ps(p_c + 2 * ldc + 2 * 8);
+        if constexpr(Mr > 3)            ymm9  = _mm256_loadu_ps(p_c + 3 * ldc + 0 * 8);
+        if constexpr(Mr > 3 && Nr > 8)  ymm10 = _mm256_loadu_ps(p_c + 3 * ldc + 1 * 8);
+        if constexpr(Mr > 3 && Nr > 16) ymm11 = _mm256_loadu_ps(p_c + 3 * ldc + 2 * 8);
+
+        while(Kr > 4)
+        {
+#pragma unroll
+            for(int i_k = 0; i_k < 4; i_k++)
+            {
+                load_b(i_k, 0, ymm12);
+                if constexpr(Nr > 8)  load_b(i_k, 1, ymm13);
+                if constexpr(Nr > 16) load_b(i_k, 2, ymm14);
+
+                broadcast_a(i_k, 0, ymm15);
+                ymm0 = _mm256_fmadd_ps(ymm12, ymm15, ymm0);
+                if constexpr(Nr > 8)            ymm1  = _mm256_fmadd_ps(ymm13, ymm15, ymm1);
+                if constexpr(Nr > 16)           ymm2  = _mm256_fmadd_ps(ymm14, ymm15, ymm2);
+
+                if constexpr(Mr > 1)            broadcast_a(i_k, 1, ymm15);
+                if constexpr(Mr > 1)            ymm3  = _mm256_fmadd_ps(ymm12, ymm15, ymm3);
+                if constexpr(Mr > 1 && Nr > 8)  ymm4  = _mm256_fmadd_ps(ymm13, ymm15, ymm4);
+                if constexpr(Mr > 1 && Nr > 16) ymm5  = _mm256_fmadd_ps(ymm14, ymm15, ymm5);
+
+                if constexpr(Mr > 2)            broadcast_a(i_k, 2, ymm15);
+                if constexpr(Mr > 2)            ymm6  = _mm256_fmadd_ps(ymm12, ymm15, ymm6);
+                if constexpr(Mr > 2 && Nr > 8)  ymm7  = _mm256_fmadd_ps(ymm13, ymm15, ymm7);
+                if constexpr(Mr > 2 && Nr > 16) ymm8  = _mm256_fmadd_ps(ymm14, ymm15, ymm8);
+
+                if constexpr(Mr > 3)            broadcast_a(i_k, 3, ymm15);
+                if constexpr(Mr > 3)            ymm9  = _mm256_fmadd_ps(ymm12, ymm15, ymm9);
+                if constexpr(Mr > 3 && Nr > 8)  ymm10 = _mm256_fmadd_ps(ymm13, ymm15, ymm10);
+                if constexpr(Mr > 3 && Nr > 16) ymm11 = _mm256_fmadd_ps(ymm14, ymm15, ymm11);
+            }
+            if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, ALayout>::value)
+            { p_a += 4; }
+            else
+            { p_a += Mr * 4; }
+            if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value)
+            { p_b += Nr * 4; }
+            else
+            { p_b += 4 * 8; }
+            Kr -= 4;
+        }
+
+        while(Kr != 0)
+        {
+            load_b(0, 0, ymm12);
+            if constexpr(Nr > 8)  load_b(0, 1, ymm13);
+            if constexpr(Nr > 16) load_b(0, 2, ymm14);
+
+            broadcast_a(0, 0, ymm15);
+            ymm0 = _mm256_fmadd_ps(ymm12, ymm15, ymm0);
+            if constexpr(Nr > 8)            ymm1  = _mm256_fmadd_ps(ymm13, ymm15, ymm1);
+            if constexpr(Nr > 16)           ymm2  = _mm256_fmadd_ps(ymm14, ymm15, ymm2);
+
+            if constexpr(Mr > 1)            broadcast_a(0, 1, ymm15);
+            if constexpr(Mr > 1)            ymm3  = _mm256_fmadd_ps(ymm12, ymm15, ymm3);
+            if constexpr(Mr > 1 && Nr > 8)  ymm4  = _mm256_fmadd_ps(ymm13, ymm15, ymm4);
+            if constexpr(Mr > 1 && Nr > 16) ymm5  = _mm256_fmadd_ps(ymm14, ymm15, ymm5);
+
+            if constexpr(Mr > 2)            broadcast_a(0, 2, ymm15);
+            if constexpr(Mr > 2)            ymm6  = _mm256_fmadd_ps(ymm12, ymm15, ymm6);
+            if constexpr(Mr > 2 && Nr > 8)  ymm7  = _mm256_fmadd_ps(ymm13, ymm15, ymm7);
+            if constexpr(Mr > 2 && Nr > 16) ymm8  = _mm256_fmadd_ps(ymm14, ymm15, ymm8);
+
+            if constexpr(Mr > 3)            broadcast_a(0, 3, ymm15);
+            if constexpr(Mr > 3)            ymm9  = _mm256_fmadd_ps(ymm12, ymm15, ymm9);
+            if constexpr(Mr > 3 && Nr > 8)  ymm10 = _mm256_fmadd_ps(ymm13, ymm15, ymm10);
+            if constexpr(Mr > 3 && Nr > 16) ymm11 = _mm256_fmadd_ps(ymm14, ymm15, ymm11);
+
+            if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, ALayout>::value)
+            { p_a += 1; }
+            else
+            { p_a += Mr * 1; }
+            if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value)
+            { p_b += Nr * 1; }
+            else
+            { p_b += 1 * 8; }
+            Kr--;
+        }
+
+        if(param->alpha != 1.0f)
+        {
+            ymm12 = _mm256_broadcast_ss(reinterpret_cast<float const*>(&param->alpha));
+            ymm0 = _mm256_mul_ps(ymm12, ymm0);
+            if constexpr(Nr > 8)            ymm1  = _mm256_mul_ps(ymm12, ymm1);
+            if constexpr(Nr > 16)           ymm2  = _mm256_mul_ps(ymm12, ymm2);
+            if constexpr(Mr > 1)            ymm3  = _mm256_mul_ps(ymm12, ymm3);
+            if constexpr(Mr > 1 && Nr > 8)  ymm4  = _mm256_mul_ps(ymm12, ymm4);
+            if constexpr(Mr > 1 && Nr > 16) ymm5  = _mm256_mul_ps(ymm12, ymm5);
+            if constexpr(Mr > 2)            ymm6  = _mm256_mul_ps(ymm12, ymm6);
+            if constexpr(Mr > 2 && Nr > 8)  ymm7  = _mm256_mul_ps(ymm12, ymm7);
+            if constexpr(Mr > 2 && Nr > 16) ymm8  = _mm256_mul_ps(ymm12, ymm8);
+            if constexpr(Mr > 3)            ymm9  = _mm256_mul_ps(ymm12, ymm9);
+            if constexpr(Mr > 3 && Nr > 8)  ymm10 = _mm256_mul_ps(ymm12, ymm10);
+            if constexpr(Mr > 3 && Nr > 16) ymm11 = _mm256_mul_ps(ymm12, ymm11);
+        }
+
+        _mm256_storeu_ps(p_c + 0 * ldc + 0 * 8, ymm0);
+        if constexpr(Nr > 8)            _mm256_storeu_ps(p_c + 0 * ldc + 1 * 8, ymm1);
+        if constexpr(Nr > 16)           _mm256_storeu_ps(p_c + 0 * ldc + 2 * 8, ymm2);
+        if constexpr(Mr > 1)            _mm256_storeu_ps(p_c + 1 * ldc + 0 * 8, ymm3);
+        if constexpr(Mr > 1 && Nr > 8)  _mm256_storeu_ps(p_c + 1 * ldc + 1 * 8, ymm4);
+        if constexpr(Mr > 1 && Nr > 16) _mm256_storeu_ps(p_c + 1 * ldc + 2 * 8, ymm5);
+        if constexpr(Mr > 2)            _mm256_storeu_ps(p_c + 2 * ldc + 0 * 8, ymm6);
+        if constexpr(Mr > 2 && Nr > 8)  _mm256_storeu_ps(p_c + 2 * ldc + 1 * 8, ymm7);
+        if constexpr(Mr > 2 && Nr > 16) _mm256_storeu_ps(p_c + 2 * ldc + 2 * 8, ymm8);
+        if constexpr(Mr > 3)            _mm256_storeu_ps(p_c + 3 * ldc + 0 * 8, ymm9);
+        if constexpr(Mr > 3 && Nr > 8)  _mm256_storeu_ps(p_c + 3 * ldc + 1 * 8, ymm10);
+        if constexpr(Mr > 3 && Nr > 16) _mm256_storeu_ps(p_c + 3 * ldc + 2 * 8, ymm11);
+        // clang-format on
+#endif
     }
 };
...
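Note how both blockings exactly fill the sixteen-entry ymm register file: the 6x16 kernel holds 6 x 2 = 12 accumulators (ymm0-ymm11) plus two B vectors (ymm12, ymm13) and two A broadcasts (ymm14, ymm15), while the 4x24 kernel holds 4 x 3 = 12 accumulators plus three B vectors (ymm12-ymm14) and one A broadcast (ymm15); 12 + 4 = 16 in both cases. That zero-slack budget is consistent with the author's comment that the fp16 A path has no spare register for a convert-and-broadcast temporary.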