Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
35f95fe9
Commit
35f95fe9
authored
Mar 27, 2022
by
carlushuang
Browse files
movaps->movups, and support loop over L1
parent
e72c0c43
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
911 additions
and
839 deletions
+911
-839
include/ck/tensor_operation/cpu/thread/threadwise_gemm_avx2.hpp
...e/ck/tensor_operation/cpu/thread/threadwise_gemm_avx2.hpp
+305
-305
include/ck/tensor_operation/cpu/thread/threadwise_param.hpp
include/ck/tensor_operation/cpu/thread/threadwise_param.hpp
+26
-26
include/ck/utility/cpuid.hpp
include/ck/utility/cpuid.hpp
+193
-193
test/cpu_ukernel/cpu_gemm_uk.cpp
test/cpu_ukernel/cpu_gemm_uk.cpp
+387
-315
No files found.
include/ck/tensor_operation/cpu/thread/threadwise_gemm_avx2.hpp
View file @
35f95fe9
...
@@ -82,11 +82,11 @@ struct ThreadwiseGemmAvx2_MxN_6x16
...
@@ -82,11 +82,11 @@ struct ThreadwiseGemmAvx2_MxN_6x16
".endif
\n
"
".endif
\n
"
".endm
\n
"
".endm
\n
"
".macro vmov
a
ps_%= r_base, r_stride, i_scale, i_offset, ymm
\n
"
".macro vmov
u
ps_%= r_base, r_stride, i_scale, i_offset, ymm
\n
"
".if
\\
i_scale != 0
\n
"
".if
\\
i_scale != 0
\n
"
"vmov
a
ps
\\
i_offset(
\\
r_base,
\\
r_stride,
\\
i_scale),
\\
ymm
\n
"
"vmov
u
ps
\\
i_offset(
\\
r_base,
\\
r_stride,
\\
i_scale),
\\
ymm
\n
"
".else
\n
"
".else
\n
"
"vmov
a
ps
\\
i_offset(
\\
r_base),
\\
ymm
\n
"
"vmov
u
ps
\\
i_offset(
\\
r_base),
\\
ymm
\n
"
".endif
\n
"
".endif
\n
"
".endm
\n
"
".endm
\n
"
...
@@ -105,15 +105,15 @@ struct ThreadwiseGemmAvx2_MxN_6x16
...
@@ -105,15 +105,15 @@ struct ThreadwiseGemmAvx2_MxN_6x16
".macro vload_b%= i_k, i_n, ymm
\n
"
// B in rbx, lda in rdx, i_n should be 0, 1
".macro vload_b%= i_k, i_n, ymm
\n
"
// B in rbx, lda in rdx, i_n should be 0, 1
".if m_BBytes == 4
\n
"
".if m_BBytes == 4
\n
"
".if m_TransB == 0
\n
"
".if m_TransB == 0
\n
"
"vmov
a
ps_%= %%rbx, %%rdx,
\\
i_n,
\\
i_k*4*8,
\\
ymm
\n
"
"vmov
u
ps_%= %%rbx, %%rdx,
\\
i_n,
\\
i_k*4*8,
\\
ymm
\n
"
".else
\n
"
".else
\n
"
"vmov
a
ps_%= %%rbx, 0, 0, (
\\
i_k*m_Nr +
\\
i_n*8)*4,
\\
ymm
\n
"
"vmov
u
ps_%= %%rbx, 0, 0, (
\\
i_k*m_Nr +
\\
i_n*8)*4,
\\
ymm
\n
"
".endif
\n
"
".endif
\n
"
".else
\n
"
".else
\n
"
".if m_TransB == 0
\n
"
".if m_TransB == 0
\n
"
"vmov
a
ps_%= %%rbx, %%rdx,
\\
i_n,
\\
i_k*4*8,
\\
ymm
\n
"
"vmov
u
ps_%= %%rbx, %%rdx,
\\
i_n,
\\
i_k*4*8,
\\
ymm
\n
"
".else
\n
"
".else
\n
"
"vmov
a
ps_%= %%rbx, 0, 0, (
\\
i_k*m_Nr +
\\
i_n*8)*4,
\\
ymm
\n
"
"vmov
u
ps_%= %%rbx, 0, 0, (
\\
i_k*m_Nr +
\\
i_n*8)*4,
\\
ymm
\n
"
".endif
\n
"
".endif
\n
"
".endif
\n
"
".endif
\n
"
".endm
\n
"
".endm
\n
"
...
@@ -265,18 +265,18 @@ struct ThreadwiseGemmAvx2_MxN_6x16
...
@@ -265,18 +265,18 @@ struct ThreadwiseGemmAvx2_MxN_6x16
".if (m_Mr > 5)
\n
vaddps (%%r9), %%ymm10, %%ymm10
\n
.endif
\n
"
".if (m_Mr > 5)
\n
vaddps (%%r9), %%ymm10, %%ymm10
\n
.endif
\n
"
".if (m_Mr > 5) && (m_Nr > 8)
\n
vaddps 32(%%r9), %%ymm11, %%ymm11
\n
.endif
\n
"
".if (m_Mr > 5) && (m_Nr > 8)
\n
vaddps 32(%%r9), %%ymm11, %%ymm11
\n
.endif
\n
"
" vmov
a
ps %%ymm0, (%%rax)
\n
"
" vmov
u
ps %%ymm0, (%%rax)
\n
"
".if (m_Nr > 8)
\n
vmov
a
ps %%ymm1, 32(%%rax)
\n
.endif
\n
"
".if (m_Nr > 8)
\n
vmov
u
ps %%ymm1, 32(%%rax)
\n
.endif
\n
"
".if (m_Mr > 1)
\n
vmov
a
ps %%ymm2, (%%rbx)
\n
.endif
\n
"
".if (m_Mr > 1)
\n
vmov
u
ps %%ymm2, (%%rbx)
\n
.endif
\n
"
".if (m_Mr > 1) && (m_Nr > 8)
\n
vmov
a
ps %%ymm3, 32(%%rbx)
\n
.endif
\n
"
".if (m_Mr > 1) && (m_Nr > 8)
\n
vmov
u
ps %%ymm3, 32(%%rbx)
\n
.endif
\n
"
".if (m_Mr > 2)
\n
vmov
a
ps %%ymm4, (%%rcx)
\n
.endif
\n
"
".if (m_Mr > 2)
\n
vmov
u
ps %%ymm4, (%%rcx)
\n
.endif
\n
"
".if (m_Mr > 2) && (m_Nr > 8)
\n
vmov
a
ps %%ymm5, 32(%%rcx)
\n
.endif
\n
"
".if (m_Mr > 2) && (m_Nr > 8)
\n
vmov
u
ps %%ymm5, 32(%%rcx)
\n
.endif
\n
"
".if (m_Mr > 3)
\n
vmov
a
ps %%ymm6, (%%rdx)
\n
.endif
\n
"
".if (m_Mr > 3)
\n
vmov
u
ps %%ymm6, (%%rdx)
\n
.endif
\n
"
".if (m_Mr > 3) && (m_Nr > 8)
\n
vmov
a
ps %%ymm7, 32(%%rdx)
\n
.endif
\n
"
".if (m_Mr > 3) && (m_Nr > 8)
\n
vmov
u
ps %%ymm7, 32(%%rdx)
\n
.endif
\n
"
".if (m_Mr > 4)
\n
vmov
a
ps %%ymm8, (%%r8)
\n
.endif
\n
"
".if (m_Mr > 4)
\n
vmov
u
ps %%ymm8, (%%r8)
\n
.endif
\n
"
".if (m_Mr > 4) && (m_Nr > 8)
\n
vmov
a
ps %%ymm9, 32(%%r8)
\n
.endif
\n
"
".if (m_Mr > 4) && (m_Nr > 8)
\n
vmov
u
ps %%ymm9, 32(%%r8)
\n
.endif
\n
"
".if (m_Mr > 5)
\n
vmov
a
ps %%ymm10, (%%r9)
\n
.endif
\n
"
".if (m_Mr > 5)
\n
vmov
u
ps %%ymm10, (%%r9)
\n
.endif
\n
"
".if (m_Mr > 5) && (m_Nr > 8)
\n
vmov
a
ps %%ymm11, 32(%%r9)
\n
.endif
\n
"
".if (m_Mr > 5) && (m_Nr > 8)
\n
vmov
u
ps %%ymm11, 32(%%r9)
\n
.endif
\n
"
"L_GemmAvx2_MxN_6x16_Exit%=:
\n
"
"L_GemmAvx2_MxN_6x16_Exit%=:
\n
"
:
:
:
:
...
...
include/ck/tensor_operation/cpu/thread/threadwise_param.hpp
View file @
35f95fe9
include/ck/utility/cpuid.hpp
View file @
35f95fe9
test/cpu_ukernel/cpu_gemm_uk.cpp
View file @
35f95fe9
...
@@ -218,6 +218,74 @@ void test_ukernel(ukenrel_t uk,
...
@@ -218,6 +218,74 @@ void test_ukernel(ukenrel_t uk,
param
.
ldc
=
n
*
sizeof
(
float
);
param
.
ldc
=
n
*
sizeof
(
float
);
param
.
alpha
=
alpha
;
param
.
alpha
=
alpha
;
auto
invoke_uk
=
[
&
]()
{
if
constexpr
(
std
::
is_same
<
Row
,
ALayout
>::
value
&&
std
::
is_same
<
Row
,
BLayout
>::
value
)
{
assert
(
m
%
uk
.
Mr_
==
0
&&
n
==
uk
.
Nr_
);
data_type
*
p_a
=
mat_a
;
float
*
p_c
=
mat_c
;
param
.
p_a
=
p_a
;
param
.
p_c
=
p_c
;
for
(
uint32_t
i_m
=
0
;
i_m
<
m
;
i_m
+=
uk
.
Mr_
)
{
uk
.
Run
(
&
param
);
p_a
+=
uk
.
Mr_
*
k
;
p_c
+=
uk
.
Mr_
*
n
;
param
.
p_a
=
p_a
;
param
.
p_c
=
p_c
;
}
}
else
if
constexpr
(
std
::
is_same
<
Row
,
ALayout
>::
value
&&
std
::
is_same
<
Col
,
BLayout
>::
value
)
{
assert
(
m
%
uk
.
Mr_
==
0
&&
n
%
uk
.
Nr_
==
0
);
data_type
*
p_a
=
mat_a
;
// data_type* p_b = mat_b;
float
*
p_c
=
mat_c
;
param
.
p_a
=
p_a
;
param
.
p_b
=
mat_b
;
param
.
p_c
=
p_c
;
for
(
uint32_t
i_m
=
0
;
i_m
<
m
;
i_m
+=
uk
.
Mr_
)
{
float
*
p_c_n
=
p_c
;
float
*
p_b_n
=
mat_b
;
for
(
uint32_t
i_n
=
0
;
i_n
<
n
;
i_n
+=
uk
.
Nr_
)
{
uk
.
Run
(
&
param
);
p_b_n
+=
uk
.
Nr_
*
k
;
// Nr_/8*k*8
p_c_n
+=
uk
.
Nr_
;
param
.
p_b
=
p_b_n
;
param
.
p_c
=
p_c_n
;
}
p_a
+=
uk
.
Mr_
*
k
;
p_c
+=
uk
.
Mr_
*
n
;
param
.
p_a
=
p_a
;
param
.
p_b
=
mat_b
;
param
.
p_c
=
p_c
;
}
}
else
if
constexpr
(
std
::
is_same
<
Col
,
ALayout
>::
value
&&
std
::
is_same
<
Row
,
BLayout
>::
value
)
{
assert
(
m
==
uk
.
Mr_
&&
n
==
uk
.
Nr_
);
uk
.
Run
(
&
param
);
}
else
{
assert
(
m
%
uk
.
Mr_
==
0
&&
n
%
uk
.
Nr_
==
0
);
data_type
*
p_b
=
mat_b
;
float
*
p_c
=
mat_c
;
param
.
p_b
=
p_b
;
param
.
p_c
=
p_c
;
for
(
uint32_t
i_n
=
0
;
i_n
<
n
;
i_n
+=
uk
.
Nr_
)
{
uk
.
Run
(
&
param
);
p_b
+=
uk
.
Nr_
*
k
;
// Nr_/8*k*8
p_c
+=
uk
.
Nr_
;
param
.
p_b
=
p_b
;
param
.
p_c
=
p_c
;
}
}
};
printf
(
"gemm_uk_%dx%d_%c%c: "
,
uk
.
Mr_
,
uk
.
Nr_
,
ALayout
::
name
[
0
],
BLayout
::
name
[
0
]);
printf
(
"gemm_uk_%dx%d_%c%c: "
,
uk
.
Mr_
,
uk
.
Nr_
,
ALayout
::
name
[
0
],
BLayout
::
name
[
0
]);
fflush
(
stdout
);
fflush
(
stdout
);
// printf("%s: ", typeid(uk).name());fflush(stdout);
// printf("%s: ", typeid(uk).name());fflush(stdout);
...
@@ -227,13 +295,13 @@ void test_ukernel(ukenrel_t uk,
...
@@ -227,13 +295,13 @@ void test_ukernel(ukenrel_t uk,
for
(
int
i
=
0
;
i
<
(
repeat
/
5
);
i
++
)
for
(
int
i
=
0
;
i
<
(
repeat
/
5
);
i
++
)
{
{
uk
.
Run
(
&
param
);
invoke_uk
(
);
}
}
auto
t0
=
std
::
chrono
::
high_resolution_clock
::
now
();
auto
t0
=
std
::
chrono
::
high_resolution_clock
::
now
();
for
(
int
i
=
0
;
i
<
repeat
;
i
++
)
for
(
int
i
=
0
;
i
<
repeat
;
i
++
)
{
{
uk
.
Run
(
&
param
);
invoke_uk
(
);
}
}
auto
t1
=
std
::
chrono
::
high_resolution_clock
::
now
();
auto
t1
=
std
::
chrono
::
high_resolution_clock
::
now
();
...
@@ -243,7 +311,7 @@ void test_ukernel(ukenrel_t uk,
...
@@ -243,7 +311,7 @@ void test_ukernel(ukenrel_t uk,
double
gflops
=
static_cast
<
double
>
(
2
*
m
*
n
*
k
)
*
1e-3
/
us
;
double
gflops
=
static_cast
<
double
>
(
2
*
m
*
n
*
k
)
*
1e-3
/
us
;
memset
(
mat_c
,
0
,
m
*
n
*
sizeof
(
float
));
memset
(
mat_c
,
0
,
m
*
n
*
sizeof
(
float
));
uk
.
Run
(
&
param
);
invoke_uk
(
);
printf
(
"m:%u, n:%u, k:%u, alpha:%f, cost:%lfus, GFLOPS:%lf, "
,
m
,
n
,
k
,
alpha
,
us
,
gflops
);
printf
(
"m:%u, n:%u, k:%u, alpha:%f, cost:%lfus, GFLOPS:%lf, "
,
m
,
n
,
k
,
alpha
,
us
,
gflops
);
fflush
(
stdout
);
fflush
(
stdout
);
...
@@ -274,7 +342,11 @@ void test_cpu_ukernel(float alpha, uint32_t m, uint32_t n, uint32_t k)
...
@@ -274,7 +342,11 @@ void test_cpu_ukernel(float alpha, uint32_t m, uint32_t n, uint32_t k)
{
{
return
;
return
;
}
}
if
(
uk_type
::
Mr_
!=
m
||
uk_type
::
Nr_
!=
n
)
if
(
m
%
uk_type
::
Mr_
!=
0
||
n
%
uk_type
::
Nr_
!=
0
)
return
;
if
((
m
!=
uk_type
::
Mr_
&&
std
::
is_same
<
typename
uk_type
::
ALayout_
,
Col
>::
value
)
||
(
n
!=
uk_type
::
Nr_
&&
std
::
is_same
<
typename
uk_type
::
BLayout_
,
Row
>::
value
))
// only k is the fast changing dim of A/B can we do muldiplt m, n
return
;
return
;
test_ukernel
<
data_type
,
ALayout
,
BLayout
>
(
uk_type
{},
mat_a
,
mat_b
,
mat_c
,
alpha
,
m
,
n
,
k
);
test_ukernel
<
data_type
,
ALayout
,
BLayout
>
(
uk_type
{},
mat_a
,
mat_b
,
mat_c
,
alpha
,
m
,
n
,
k
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment