Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
099c470e
"...composable_kernel_rocm.git" did not exist on "e5376be4acc6fb8554c5ff5430b8f2750bc939c9"
Commit
099c470e
authored
May 13, 2022
by
wangshaojie6
Browse files
do some tests
parent
731febb6
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
232 additions
and
56 deletions
+232
-56
conv2d_wrw_outline.s
conv2d_wrw_outline.s
+106
-0
example/10_conv2d_bwd_data/conv2d_bwd_data_xdl.cpp
example/10_conv2d_bwd_data/conv2d_bwd_data_xdl.cpp
+5
-5
example/11_conv2d_bwd_weight/conv2d_bwd_weight_xdl.cpp
example/11_conv2d_bwd_weight/conv2d_bwd_weight_xdl.cpp
+8
-8
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
...e/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
+44
-26
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
...tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
+19
-7
script/shuffle_v_mfma/shuffle_v_mfma.py
script/shuffle_v_mfma/shuffle_v_mfma.py
+50
-10
No files found.
conv2d_wrw_outline.s
0 → 100644
View file @
099c470e
; Outline (pseudo-assembly, not assemblable as-is) of the conv2d wrw inner loop
; before and after instruction re-scheduling.
;
; Register groups as used below:
;   v_gla / v_glb : destinations of buffer_load_dwordx4 for A / B
;   v_lda / v_ldb : destinations of ds_read2_b64 for A / B, and MFMA operands
;   v_pkb         : v_pack destination for B before ds_write2_b64
;                   (A is packed into v_lda[0:3] instead)
;
; Each 8-register A/B group feeds 8 MFMAs in the order
;   (A0,B0) (A1,B1) (A0,B2) (A1,B3) (A2,B0) (A3,B1) (A2,B2) (A3,B3)
; where A0..A3 / B0..B3 are the four consecutive register pairs of the group.

;origin loop
.origin_loop_start:
; read current k into v_lda/v_ldb[0:7]
ds_read2_b64 v_lda[0:3]
ds_read2_b64 v_ldb[0:3]
ds_read2_b64 v_lda[4:7]
ds_read2_b64 v_ldb[4:7]

v_mfma v_lda[0:1], v_ldb[0:1]
v_mfma v_lda[2:3], v_ldb[2:3]
v_mfma v_lda[0:1], v_ldb[4:5]
v_mfma v_lda[2:3], v_ldb[6:7]
v_mfma v_lda[4:5], v_ldb[0:1]
v_mfma v_lda[6:7], v_ldb[2:3]
v_mfma v_lda[4:5], v_ldb[4:5]
v_mfma v_lda[6:7], v_ldb[6:7]

; read next k, overwriting the same registers [0:7]
ds_read2_b64 v_lda[0:3] offset:next k
ds_read2_b64 v_lda[4:7] offset:next k
ds_read2_b64 v_ldb[0:3] offset:next k
ds_read2_b64 v_ldb[4:7] offset:next k

s_barrier

v_mfma v_lda[0:1], v_ldb[0:1]
v_mfma v_lda[2:3], v_ldb[2:3]
v_mfma v_lda[0:1], v_ldb[4:5]
v_mfma v_lda[2:3], v_ldb[6:7]

; pack the previously loaded A data (v_lda[0:3] is free after the MFMAs above)
v_pack v_lda[0], v_gla[0], v_gla[1], lo
v_pack v_lda[1], v_gla[0], v_gla[1], hi
v_pack v_lda[2], v_gla[2], v_gla[3], lo
v_pack v_lda[3], v_gla[2], v_gla[3], hi
ds_write2_b64 v_lda[0:1], v_lda[2:3]

; pack the previously loaded B data
v_pack v_pkb[0], v_glb[0], v_glb[1], lo
v_pack v_pkb[1], v_glb[0], v_glb[1], hi
v_pack v_pkb[2], v_glb[2], v_glb[3], lo
v_pack v_pkb[3], v_glb[2], v_glb[3], hi
ds_write2_b64 v_pkb[0:1], v_pkb[2:3]

s_barrier

v_move_slice_window 0
v_move_slice_window 1
; ... ~60 VALU instructions

buffer_load_dwordx4 v_gla[0:3]
buffer_load_dwordx4 v_glb[0:3]

v_mfma v_lda[4:5], v_ldb[0:1]
v_mfma v_lda[6:7], v_ldb[2:3]
v_mfma v_lda[4:5], v_ldb[4:5]
v_mfma v_lda[6:7], v_ldb[6:7]

s_branch origin_loop_start

;optimized loop
.optimized_loop_start:
; read current k into v_lda/v_ldb[0:7]
ds_read2_b64 v_lda[0:3]
ds_read2_b64 v_ldb[0:3]
ds_read2_b64 v_lda[4:7]
ds_read2_b64 v_ldb[4:7]

v_mfma v_lda[0:1], v_ldb[0:1]
v_mfma v_lda[2:3], v_ldb[2:3]
v_mfma v_lda[0:1], v_ldb[4:5]
v_mfma v_lda[2:3], v_ldb[6:7]
v_mfma v_lda[4:5], v_ldb[0:1]
v_mfma v_lda[6:7], v_ldb[2:3]
v_mfma v_lda[4:5], v_ldb[4:5]
v_mfma v_lda[6:7], v_ldb[6:7]

; read next k into a second register group [8:15] instead of reusing [0:7]
ds_read2_b64 v_lda[8:11] offset:next k
ds_read2_b64 v_lda[12:15] offset:next k
ds_read2_b64 v_ldb[8:11] offset:next k
ds_read2_b64 v_ldb[12:15] offset:next k

; from here on the 8 next-k MFMAs are interleaved with the barrier, pack,
; ds_write, slice-window move and buffer_load work
v_mfma v_lda[8:9], v_ldb[8:9]
s_barrier
v_mfma v_lda[10:11], v_ldb[10:11]

v_pack v_lda[0], v_gla[0], v_gla[1], lo
v_pack v_lda[1], v_gla[0], v_gla[1], hi
v_pack v_lda[2], v_gla[2], v_gla[3], lo
v_pack v_lda[3], v_gla[2], v_gla[3], hi
ds_write2_b64 v_lda[0:1], v_lda[2:3]

v_mfma v_lda[8:9], v_ldb[12:13]

v_pack v_pkb[0], v_glb[0], v_glb[1], lo
v_pack v_pkb[1], v_glb[0], v_glb[1], hi
v_pack v_pkb[2], v_glb[2], v_glb[3], lo
v_pack v_pkb[3], v_glb[2], v_glb[3], hi
ds_write2_b64 v_pkb[0:1], v_pkb[2:3]

v_mfma v_lda[10:11], v_ldb[14:15]
s_barrier
v_mfma v_lda[12:13], v_ldb[8:9]
v_move_slice_window 0
; A operand is [14:15] here, matching the (A3,B1) slot of the pairing order
v_mfma v_lda[14:15], v_ldb[10:11]
v_move_slice_window 1
buffer_load_dwordx4 v_gla[0:3]
v_mfma v_lda[12:13], v_ldb[12:13]
buffer_load_dwordx4 v_glb[0:3]
v_mfma v_lda[14:15], v_ldb[14:15]

s_branch optimized_loop_start
example/10_conv2d_bwd_data/conv2d_bwd_data_xdl.cpp
View file @
099c470e
...
...
@@ -42,14 +42,14 @@ using DeviceConvBwdDataInstance = ck::tensor_operation::device::
OutElementOp
,
// OutElementwiseOperation
ConvBwdDefault
,
// ConvolutionBackwardDataSpecialization
256
,
// BlockSize
128
,
// MPerBlock
128
,
// NPerBlock
256
,
// MPerBlock
256
,
// NPerBlock
4
,
// K0PerBlock
8
,
// K1
32
,
// MPerXdl
32
,
// NPerXdl
2
,
// MXdlPerWave
2
,
// NXdlPerWave
4
,
// MXdlPerWave
4
,
// NXdlPerWave
S
<
4
,
64
,
1
>
,
// ABlockTransferThreadClusterLengths_K0_M_K1
S
<
1
,
0
,
2
>
,
// ABlockTransferThreadClusterArrangeOrder
S
<
1
,
0
,
2
>
,
// ABlockTransferSrcAccessOrder
...
...
@@ -61,7 +61,7 @@ using DeviceConvBwdDataInstance = ck::tensor_operation::device::
S
<
2
,
0
,
1
>
,
// BBlockTransferThreadClusterArrangeOrder
S
<
0
,
2
,
1
>
,
// BBlockTransferSrcAccessOrder
1
,
// BBlockTransferSrcVectorDim
2
,
// BBlockTransferSrcScalarPerVector
4
,
// BBlockTransferSrcScalarPerVector
8
,
// BBlockTransferDstScalarPerVector_K1
true
,
// BBlockLdsAddExtraN
7
,
...
...
example/11_conv2d_bwd_weight/conv2d_bwd_weight_xdl.cpp
View file @
099c470e
...
...
@@ -44,22 +44,22 @@ using DeviceConvBwdWeightInstance = ck::tensor_operation::device::
WeiElementOp
,
// WeiElementwiseOperation
OutElementOp
,
// OutElementwiseOperation
256
,
// BlockSize
25
6
,
// MPerBlock
256
,
// NPerBlock
4
,
// K0PerBlock
6
4
,
// MPerBlock
128
,
// NPerBlock
8
,
// K0PerBlock
8
,
// K1
32
,
// MPerXdl
32
,
// NPerXdl
4
,
// MXdlPerWave
4
,
// NXdlPerWave
S
<
1
,
4
,
32
,
2
>
,
// ABlockTransferThreadClusterLengths_K0_M_K1
1
,
// MXdlPerWave
2
,
// NXdlPerWave
S
<
1
,
8
,
8
,
4
>
,
// ABlockTransferThreadClusterLengths_K0_M_K1
S
<
0
,
3
,
1
,
2
>
,
// ABlockTransferThreadClusterArrangeOrder
S
<
0
,
2
,
1
,
3
>
,
// ABlockTransferSrcAccessOrder
2
,
// ABlockTransferSrcVectorDim
8
,
// ABlockTransferSrcScalarPerVector
4
,
// ABlockTransferDstScalarPerVector_K1
2
,
// ABlockTransferDstScalarPerVector_K1
true
,
// ABlockLdsAddExtraM
S
<
1
,
4
,
32
,
2
>
,
// BBlockTransferThreadClusterLengths_K0_N_K1
S
<
1
,
8
,
16
,
2
>
,
// BBlockTransferThreadClusterLengths_K0_N_K1
S
<
0
,
3
,
1
,
2
>
,
// BBlockTransferThreadClusterArrangeOrder
S
<
0
,
2
,
1
,
3
>
,
// BBlockTransferSrcAccessOrder
2
,
// BBlockTransferSrcVectorDim
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
View file @
099c470e
...
...
@@ -279,6 +279,24 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
// });
//});
//static_for<0, KPerThread, KPack>{}([&](auto k) {
// static_for<0, MRepeat, 1>{}([&](auto m0) {
// //read from lds for A
// a_thread_copy_.Run();
// });
// static_for<0, NRepeat, 1>{}([&](auto n0) {
// //read from lds for B
// b_thread_copy_.Run();
// });
//
// static_for<0, MRepeat, 1>{}([&](auto m0) {
// static_for<0, NRepeat, 1>{}([&](auto n0) {
// // do mfma within k
// xdlops_gemm.template Run();
// });
// });
//});
static_for
<
0
,
KPerThread
,
KPack
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
// read A
...
...
@@ -299,34 +317,34 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
b_thread_buf
);
});
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
FloatAB
,
KPack
>
a_thread_vec
;
vector_type
<
FloatAB
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
i
)
{
a_thread_vec
.
template
AsType
<
FloatAB
>()(
i
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
0
,
0
,
k
+
i
))
>
{}];
b_thread_vec
.
template
AsType
<
FloatAB
>()(
i
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
0
,
0
,
k
+
i
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
FloatAB
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
// TODO: insert setprio in more precise manner since we
// could have more than >1 MFMA instructions in single call
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
FloatAB
,
KPack
>
a_thread_vec
;
vector_type
<
FloatAB
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
i
)
{
a_thread_vec
.
template
AsType
<
FloatAB
>()(
i
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
0
,
0
,
k
+
i
))
>
{}];
b_thread_vec
.
template
AsType
<
FloatAB
>()(
i
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
0
,
0
,
k
+
i
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
FloatAB
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
// TODO: insert setprio in more precise manner since we
// could have more than >1 MFMA instructions in single call
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
});
});
});
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
View file @
099c470e
...
...
@@ -673,15 +673,27 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
index_t
k0_block_data_begin
=
0
;
block_sync_lds
();
//do
//{
// blockwise_gemm.Run();
//
// block_sync_lds();
//
// a_blockwise_copy.MoveSrcSliceWindow();
// b_blockwise_copy.MoveSrcSliceWindow();
//
// a_blockwise_copy.RunWrite();
// b_blockwise_copy.RunWrite();
//
// a_blockwise_copy.RunRead();
// block_sync_lds();
// b_blockwise_copy.RunRead();
//
// k0 += K0PerBlock;
//} while(k0 < (K0 - K0PerBlock));
do
{
//a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf);
//block_sync_lds();
//b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf);
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
block_sync_lds
();
...
...
script/shuffle_v_mfma/shuffle_v_mfma.py
View file @
099c470e
...
...
@@ -19,7 +19,13 @@ class asm_file_analyser:
self
.
core_loop_txt_bb0
=
self
.
gen_core_loop_txt
(
".LBB0_1"
)
self
.
core_loop_txt_bb1
=
self
.
gen_core_loop_txt
(
".LBB1_1"
)
self
.
next_free_vgpr
=
self
.
find_next_free_vgpr
(
asm_txt
)
self
.
next_free_vgpr
=
self
.
find_next_free_vgpr
(
self
.
asm_txt
)
self
.
vgpr_limit_number
=
self
.
find_vgpr_limit
(
self
.
asm_txt
)
self
.
asm_txt_max_vgpr
=
self
.
set_vgpr_to_max
(
self
.
asm_txt
)
print
(
self
.
vgpr_limit_number
)
#assert False
self
.
core_loop_txt_bb0_enlarge_ds_read_vgpr_bb0
=
self
.
enlarge_ds_read
(
self
.
core_loop_txt_bb0
)
self
.
core_loop_txt_bb0_enlarge_ds_read_vgpr_bb1
=
self
.
enlarge_ds_read
(
self
.
core_loop_txt_bb1
)
...
...
@@ -33,10 +39,10 @@ class asm_file_analyser:
self
.
reshuffle_inst_slot_bb0
=
self
.
mfma_shuffle
(
self
.
interleave_vmfma_bb0
,
self
.
interleave_other_bb0
,
self
.
inst_weight_dict_bb0
)
self
.
reshuffle_inst_slot_bb1
=
self
.
mfma_shuffle
(
self
.
interleave_vmfma_bb1
,
self
.
interleave_other_bb1
,
self
.
inst_weight_dict_bb1
)
self
.
new_asm_txt_bb0
=
self
.
gen_new_asm_txt
(
self
.
interleave_vmfma_bb0
,
self
.
interleave_other_bb0
,
self
.
reshuffle_inst_slot_bb0
,
self
.
asm_txt
,
self
.
core_loop_txt_bb0_enlarge_ds_read_vgpr_bb0
)
self
.
new_asm_txt_bb0
=
self
.
gen_new_asm_txt
(
self
.
interleave_vmfma_bb0
,
self
.
interleave_other_bb0
,
self
.
reshuffle_inst_slot_bb0
,
self
.
asm_txt
_max_vgpr
,
self
.
core_loop_txt_bb0_enlarge_ds_read_vgpr_bb0
)
for
line
in
self
.
new_asm_txt_bb0
:
print
(
line
)
#
for line in self.new_asm_txt_bb0:
#
print(line)
self
.
new_asm_txt_bb1
=
self
.
gen_new_asm_txt
(
self
.
interleave_vmfma_bb1
,
self
.
interleave_other_bb1
,
self
.
reshuffle_inst_slot_bb1
,
self
.
new_asm_txt_bb0
,
self
.
core_loop_txt_bb0_enlarge_ds_read_vgpr_bb1
)
...
...
@@ -58,9 +64,38 @@ class asm_file_analyser:
numvpgr_str
=
re
.
findall
(
r
'(?<=; NumVgprs: )\d*'
,
line
)
if
len
(
numvpgr_str
)
!=
0
:
next_free_vgpr
=
int
(
numvpgr_str
[
0
])
print
(
next_free_vgpr
)
#
print(next_free_vgpr)
return
next_free_vgpr
def find_vgpr_limit(self, asm_txt):
    """Derive the per-wave VGPR budget from the kernel's resource comments.

    Scans ``asm_txt`` for the ``; LDSByteSize: <n>`` and ``; NumAgprs: <n>``
    comment lines. When a key occurs more than once, the last occurrence
    wins. The number of concurrent waves is taken as the smaller of
    ``64 * 1024 // lds_size`` and ``256 // agpr_size``, and the result is
    ``256 // waves``.

    Args:
        asm_txt: iterable of assembly text lines.

    Returns:
        The VGPR count (int) that keeps the wave count computed above.

    Raises:
        ValueError: if either resource comment is absent from ``asm_txt``.
    """
    lds_size = None
    agpr_size = None
    for line in asm_txt:
        # \d+ (not \d*) so a key with no digits after it is treated as absent
        # instead of reaching int('').
        lds_match = re.search(r'(?<=; LDSByteSize: )\d+', line)
        if lds_match:
            lds_size = int(lds_match.group())
        agpr_match = re.search(r'(?<=; NumAgprs: )\d+', line)
        if agpr_match:
            agpr_size = int(agpr_match.group())

    if lds_size is None or agpr_size is None:
        raise ValueError(
            "asm text does not contain both '; LDSByteSize:' and '; NumAgprs:' metadata")

    # A resource that is not used at all (size 0) cannot limit the wave count,
    # so it is left out instead of dividing by zero.
    wave_limits = []
    if lds_size > 0:
        wave_limits.append(64 * 1024 // lds_size)
    if agpr_size > 0:
        wave_limits.append(256 // agpr_size)

    # Clamp to at least one wave: a size above its budget would otherwise give
    # a wave count of 0 and a division by zero below.
    # NOTE(review): when neither resource is used this falls back to one wave
    # (full 256 budget) -- confirm that is the intended default.
    waves = max(1, min(wave_limits, default=1))
    return 256 // waves
def set_vgpr_to_max(self, asm_txt):
    """Return a copy of ``asm_txt`` with the VGPR counts rewritten.

    Any line containing ``.vgpr_count:`` or ``.amdhsa_next_free_vgpr`` is
    replaced by the text preceding that key, the key itself, and
    ``self.vgpr_limit_number``, terminated by a newline. Whatever followed
    the key on the original line is dropped. All other lines are passed
    through unchanged.

    Args:
        asm_txt: iterable of assembly text lines.

    Returns:
        A new list of lines; ``asm_txt`` itself is not modified.
    """
    rewritten = []
    for line in asm_txt:
        # Search each key once and reuse the column. ".vgpr_count:" takes
        # precedence if a line somehow contains both keys.
        meta_col = line.find(".vgpr_count:")
        if meta_col != -1:
            rewritten.append(f"{line[:meta_col]}.vgpr_count: {self.vgpr_limit_number}\n")
            continue
        hsa_col = line.find(".amdhsa_next_free_vgpr")
        if hsa_col != -1:
            rewritten.append(f"{line[:hsa_col]}.amdhsa_next_free_vgpr {self.vgpr_limit_number}\n")
            continue
        rewritten.append(line)
    return rewritten
def
enlarge_ds_read
(
self
,
core_loop_txt
):
new_core_loop
=
[]
ds_read_list
=
[]
...
...
@@ -97,20 +132,25 @@ class asm_file_analyser:
v_pair
=
re
.
findall
(
r
'v\[\d*:\d*]'
,
line
)
#print(i, v_pair)
new_line
=
line
replace_dict
=
{}
for
i_rep
in
vgpr_replacement_list
:
if
i
>
i_rep
[
0
]:
if
v_pair
[
0
]
in
i_rep
[
1
].
keys
():
new_line
=
new_line
.
replace
(
v_pair
[
0
],
i_rep
[
1
][
v_pair
[
0
]])
#new_line = new_line.replace(v_pair[0], i_rep[1][v_pair[0]])
replace_dict
[
v_pair
[
0
]]
=
i_rep
[
1
][
v_pair
[
0
]]
if
v_pair
[
1
]
in
i_rep
[
1
].
keys
():
new_line
=
new_line
.
replace
(
v_pair
[
1
],
i_rep
[
1
][
v_pair
[
1
]])
#new_line = new_line.replace(v_pair[1], i_rep[1][v_pair[1]])
replace_dict
[
v_pair
[
1
]]
=
i_rep
[
1
][
v_pair
[
1
]]
#print(replace_dict)
for
v_rep
in
replace_dict
:
new_line
=
new_line
.
replace
(
v_rep
,
replace_dict
[
v_rep
])
#print(new_line)
core_loop_suf_vgpr
.
append
(
new_line
)
else
:
core_loop_suf_vgpr
.
append
(
line
)
#print(vgpr_replacement_list)
#for i in core_loop_suf_vgpr:
# print(i)
return
core_loop_suf_vgpr
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment