Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
5dbaf3c2
Commit
5dbaf3c2
authored
Aug 18, 2021
by
Jing Zhang
Browse files
refactor xdlops, hide c desc
parent
370c9245
Changes
2
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
155 additions
and
171 deletions
+155
-171
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
...kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
+8
-8
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
+147
-163
No files found.
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
View file @
5dbaf3c2
...
@@ -118,13 +118,14 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
...
@@ -118,13 +118,14 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
__host__
__device__
static
constexpr
auto
GetCM0N0M1N1M2M3M4N2ThreadDescriptor
()
__host__
__device__
static
constexpr
auto
GetCM0N0M1N1M2M3M4N2ThreadDescriptor
()
{
{
///\to-do: hide xdl clayout into xdlops-gemm
constexpr
auto
c_m0_m1_m2_n_tblk_lens
=
xdlops_gemm
.
GetCM0M1M2NThreadBlkLengths
();
constexpr
auto
CXdlopsLayout
=
xdlops_gemm
.
GetCXdlopsLayout
();
constexpr
auto
M0
=
Number
<
CXdlopsLayout
.
M1
()
>
{};
constexpr
auto
M0
=
c_m0_m1_m2_n_tblk_lens
[
I0
];
constexpr
auto
M2
=
Number
<
CXdlopsLayout
.
M0
()
>
{};
constexpr
auto
M1
=
c_m0_m1_m2_n_tblk_lens
[
I1
];
constexpr
auto
M2
=
c_m0_m1_m2_n_tblk_lens
[
I2
];
constexpr
auto
N
=
c_m0_m1_m2_n_tblk_lens
[
I3
];
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
I1
,
I1
,
I1
,
M0
,
I
1
,
M2
,
I1
));
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
I1
,
I1
,
I1
,
M0
,
M
1
,
M2
,
N
));
}
}
__host__
__device__
static
constexpr
auto
GetCM0N0M1N1M2M3M4N2BlockDescriptor
()
__host__
__device__
static
constexpr
auto
GetCM0N0M1N1M2M3M4N2BlockDescriptor
()
...
@@ -195,7 +196,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
...
@@ -195,7 +196,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
vector_type
<
FloatAB
,
K1
>
b_thread_vec
;
vector_type
<
FloatAB
,
K1
>
b_thread_vec
;
static_for
<
0
,
KPerBlock
,
xdlops_gemm
.
KPerXdlops
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
KPerBlock
,
xdlops_gemm
.
KPerXdlops
/
xdlops_gemm
.
KPerThread
>
{}([
&
](
auto
k0
)
{
// read A
// read A
a_thread_copy_
.
Run
(
a_k0_m0_m1_m2_k1_block_desc
,
a_thread_copy_
.
Run
(
a_k0_m0_m1_m2_k1_block_desc
,
make_tuple
(
k0
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
k0
,
I0
,
I0
,
I0
,
I0
),
...
@@ -212,8 +213,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
...
@@ -212,8 +213,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
),
b_thread_buf
);
b_thread_buf
);
using
mfma_input_type
=
using
mfma_input_type
=
typename
vector_type
<
FloatAB
,
xdlops_gemm
.
KPerThread
>::
type
;
typename
vector_type
<
FloatAB
,
xdlops_gemm
.
mfma_type
.
k_per_blk
>::
type
;
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
...
...
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
View file @
5dbaf3c2
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment