Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
54b9c729
Commit
54b9c729
authored
May 08, 2022
by
wangshaojie6
Browse files
add a script to shuffle mfma instructions to get better performance
parent
01136036
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
334 additions
and
3 deletions
+334
-3
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
...e/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
+77
-2
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
...tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
+3
-1
script/shuffle_v_mfma/shuffle_v_mfma.py
script/shuffle_v_mfma/shuffle_v_mfma.py
+254
-0
No files found.
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
View file @
54b9c729
...
...
@@ -257,6 +257,80 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatAB
>
(
b_thread_desc_
.
GetElementSpaceSize
());
#if 1
//static_for<0, KPerThread, KPack>{}([&](auto k) {
// static_for<0, MRepeat, 1>{}([&](auto m0) {
// // read A
// a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
// make_tuple(m0, I0, I0, k),
// a_block_buf,
// a_thread_desc_,
// make_tuple(m0, I0, I0, k),
// a_thread_buf);
// });
// static_for<0, NRepeat, 1>{}([&](auto n0) {
// // read B
// b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
// make_tuple(n0, I0, I0, k),
// b_block_buf,
// b_thread_desc_,
// make_tuple(n0, I0, I0, k),
// b_thread_buf);
// });
//});
static_for
<
0
,
KPerThread
,
KPack
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
// read A
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
k
),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
I0
,
k
),
a_thread_buf
);
});
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
// read B
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
k
),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
I0
,
k
),
b_thread_buf
);
});
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
FloatAB
,
KPack
>
a_thread_vec
;
vector_type
<
FloatAB
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
i
)
{
a_thread_vec
.
template
AsType
<
FloatAB
>()(
i
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
0
,
0
,
k
+
i
))
>
{}];
b_thread_vec
.
template
AsType
<
FloatAB
>()(
i
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
0
,
0
,
k
+
i
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
FloatAB
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
// TODO: insert setprio in more precise manner since we
// could have more than >1 MFMA instructions in single call
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
});
});
});
#else
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
// read A
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
...
...
@@ -299,16 +373,17 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
});
});
});
#endif
}
private:
// A[M0, M1, M2, KPerThread]
static
constexpr
auto
a_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
I1
,
I1
,
Number
<
KPer
Thread
>
{}));
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
I1
,
I1
,
Number
<
KPer
Block
>
{}));
// B[N0, N1, N2, KPerThread]
static
constexpr
auto
b_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
NRepeat
>
{},
I1
,
I1
,
Number
<
KPer
Thread
>
{}));
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
NRepeat
>
{},
I1
,
I1
,
Number
<
KPer
Block
>
{}));
// C[M, N, NumRegXdlops]
static
constexpr
auto
c_thread_desc_
=
make_naive_tensor_descriptor_packed
(
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
View file @
54b9c729
...
...
@@ -671,13 +671,14 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
if
constexpr
(
HasMainKBlockLoop
)
{
index_t
k0_block_data_begin
=
0
;
block_sync_lds
();
do
{
//a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf);
block_sync_lds
();
//
block_sync_lds();
//b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf);
...
...
@@ -692,6 +693,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
b_blockwise_copy
.
RunWrite
(
b_b_k0_n_k1_block_desc
,
b_block_buf
);
a_blockwise_copy
.
RunRead
(
a_b_k0_m_k1_grid_desc
,
a_grid_buf
);
block_sync_lds
();
b_blockwise_copy
.
RunRead
(
b_b_k0_n_k1_grid_desc
,
b_grid_buf
);
k0_block_data_begin
+=
K0PerBlock
;
...
...
script/shuffle_v_mfma/shuffle_v_mfma.py
0 → 100644
View file @
54b9c729
import
re
def read_example_asm_file(asm_file_path):
    """Return the content of the assembly file as a list of lines.

    Each element keeps its trailing newline, exactly as produced by
    ``file.readlines()``.
    """
    with open(asm_file_path) as asm_file:
        return asm_file.readlines()
def write_example_asm_file_back(asm_file_path, new_asm_txt):
    """Overwrite ``asm_file_path`` with the given sequence of lines.

    The lines are written verbatim; no newline is appended, so every element
    is expected to carry its own line terminator.
    """
    with open(asm_file_path, "w") as asm_file:
        asm_file.writelines(new_asm_txt)
class asm_file_analyser:
    """Re-orders the instructions of the core loops of a compiled kernel.

    The input is the text of an assembly file as a list of lines.  Two core
    loops are looked up by their labels (``.LBB0_1`` and ``.LBB1_1``); for each
    one the tail of the loop body (everything after the last ``ds_read``) is
    rewritten so that the non-MFMA instructions are distributed between the
    ``v_mfma`` instructions.  The final text is available as
    ``new_asm_txt_bb1`` (both loops rewritten); ``new_asm_txt_bb0`` has only
    the first loop rewritten.

    A loop whose label is not present in the file is left untouched.
    """

    # (substring, weight) pairs checked in order; the first substring found in
    # a line decides its weight.  The numbers are the hard-coded costs used by
    # the shuffling heuristic -- presumably approximate issue cycles, confirm
    # against the target ISA before tuning.
    _INST_WEIGHTS = (
        ("ds_write2_b64", 30),
        ("v_mul_lo_u32", 8),
        ("v_mul_hi_u32", 8),
        ("v_mfma", 56),
        ("buffer_load_dword", 30),
        ("s_barrier", 52),
        (";", 0),
    )

    def __init__(self, asm_txt):
        """Run the whole analysis/rewrite pipeline on ``asm_txt`` (list of lines)."""
        self.asm_txt = asm_txt

        # 1. Extract the two core loops and the first unused VGPR index.
        self.core_loop_txt_bb0 = self.gen_core_loop_txt(".LBB0_1")
        self.core_loop_txt_bb1 = self.gen_core_loop_txt(".LBB1_1")
        self.next_free_vgpr = self.find_next_free_vgpr(asm_txt)

        # 2. Give repeated ds_read2_b64 destinations their own registers.
        #    (The attribute names are kept as-is because they are visible to callers.)
        self.core_loop_txt_bb0_enlarge_ds_read_vgpr_bb0 = self.enlarge_ds_read(self.core_loop_txt_bb0)
        self.core_loop_txt_bb0_enlarge_ds_read_vgpr_bb1 = self.enlarge_ds_read(self.core_loop_txt_bb1)

        # 3. Weigh every instruction line of the rewritten loops.
        self.inst_weight_dict_bb0 = self.gen_inst_weight(self.core_loop_txt_bb0_enlarge_ds_read_vgpr_bb0)
        self.inst_weight_dict_bb1 = self.gen_inst_weight(self.core_loop_txt_bb0_enlarge_ds_read_vgpr_bb1)

        # 4. Split the tail of each loop into MFMA lines and all other lines.
        self.interleave_vmfma_bb0, self.interleave_other_bb0 = self.get_interleave_start_line(
            self.core_loop_txt_bb0_enlarge_ds_read_vgpr_bb0)
        self.interleave_vmfma_bb1, self.interleave_other_bb1 = self.get_interleave_start_line(
            self.core_loop_txt_bb0_enlarge_ds_read_vgpr_bb1)

        # 5. Compute the new instruction order.
        self.reshuffle_inst_slot_bb0 = self.mfma_shuffle(
            self.interleave_vmfma_bb0, self.interleave_other_bb0, self.inst_weight_dict_bb0)
        self.reshuffle_inst_slot_bb1 = self.mfma_shuffle(
            self.interleave_vmfma_bb1, self.interleave_other_bb1, self.inst_weight_dict_bb1)

        # 6. Splice the first loop into the original text, then the second loop
        #    into the result of the first splice.
        self.new_asm_txt_bb0 = self.gen_new_asm_txt(
            self.interleave_vmfma_bb0,
            self.interleave_other_bb0,
            self.reshuffle_inst_slot_bb0,
            self.asm_txt,
            self.core_loop_txt_bb0_enlarge_ds_read_vgpr_bb0)
        # Existing diagnostic output: dumps the intermediate text to stdout.
        for line in self.new_asm_txt_bb0:
            print(line)
        self.new_asm_txt_bb1 = self.gen_new_asm_txt(
            self.interleave_vmfma_bb1,
            self.interleave_other_bb1,
            self.reshuffle_inst_slot_bb1,
            self.new_asm_txt_bb0,
            self.core_loop_txt_bb0_enlarge_ds_read_vgpr_bb1)

    def gen_core_loop_txt(self, branch_name):
        """Return the lines of the loop that starts at label ``branch_name``.

        Collection starts at the first line containing ``"<branch_name>:"``
        (that line included) and stops at the line holding
        ``s_cbranch_scc1 <branch_name>`` (that line excluded).  An empty list is
        returned when the label does not occur in ``self.asm_txt``.
        """
        label = f"{branch_name}:"
        # The negative look-ahead keeps ".LBB0_1" from also matching a branch
        # to ".LBB0_10"; \s+ accepts either a space or a tab before the operand.
        back_branch = re.compile(
            r"s_cbranch_scc1\s+" + re.escape(branch_name) + r"(?![\w.$])")

        core_loop = []
        inside_loop = False
        for line in self.asm_txt:
            if label in line:
                inside_loop = True
            if back_branch.search(line):
                inside_loop = False
            if inside_loop:
                core_loop.append(line)
        return core_loop

    def find_next_free_vgpr(self, asm_txt):
        """Return the VGPR count announced by the ``; NumVgprs: <n>`` comment.

        When several such comments exist the last one wins.  The value is
        also printed to stdout.

        Raises:
            ValueError: if no ``; NumVgprs: <n>`` line is present.
        """
        next_free_vgpr = None
        for line in asm_txt:
            # \d+ (not \d*) so that a comment without a number is skipped
            # instead of producing int('').
            found = re.search(r'(?<=; NumVgprs: )\d+', line)
            if found:
                next_free_vgpr = int(found.group())
        if next_free_vgpr is None:
            raise ValueError("no '; NumVgprs: <n>' line found in the assembly text")
        print(next_free_vgpr)
        return next_free_vgpr

    def enlarge_ds_read(self, core_loop_txt):
        """Give repeated ``ds_read2_b64`` destinations fresh registers.

        The first ``ds_read2_b64 v[a:b]`` with a given destination is kept.
        Every later read into the same destination is redirected to four new
        registers starting at ``self.next_free_vgpr`` (advancing by four per
        redirected read), and the ``v_mfma`` lines that follow it are patched
        to consume the new register pairs instead of the old ones.

        Only the first two ``v[x:y]`` operands of a ``v_mfma`` line are
        considered, and each operand is rewritten by the first applicable
        redirection only.

        Returns a new list of lines; ``core_loop_txt`` is not modified.
        """
        new_core_loop = []
        seen_ds_reads = set()
        next_free_vgpr = self.next_free_vgpr
        # Entries are [index of the redirected ds_read, {old pair: new pair}].
        vgpr_replacement_list = []

        for i, line in enumerate(core_loop_txt):
            find_ds_read = re.search(r'ds_read2_b64 v\[\d*:\d*\]', line)
            if not find_ds_read:
                new_core_loop.append(line)
                continue

            ds_read_key = find_ds_read.group()
            if ds_read_key not in seen_ds_reads:
                seen_ds_reads.add(ds_read_key)
                new_core_loop.append(line)
                continue

            new_core_loop.append(
                re.sub(r'v\[\d*:\d*\]', f"v[{next_free_vgpr}:{next_free_vgpr+3}]", line))
            old_vgpr = int(re.findall(r'(?<=v\[)\d+', line)[0])
            vgpr_replacement_mfma = {
                f"v[{old_vgpr}:{old_vgpr+1}]": f"v[{next_free_vgpr}:{next_free_vgpr+1}]",
                f"v[{old_vgpr+2}:{old_vgpr+3}]": f"v[{next_free_vgpr+2}:{next_free_vgpr+3}]",
            }
            vgpr_replacement_list.append([i, vgpr_replacement_mfma])
            next_free_vgpr += 4

        core_loop_suf_vgpr = []
        for i, line in enumerate(new_core_loop):
            if "v_mfma" not in line:
                core_loop_suf_vgpr.append(line)
                continue

            # Slicing instead of indexing [0]/[1]: an MFMA line with fewer than
            # two VGPR ranges is passed through instead of raising IndexError.
            v_pair = re.findall(r'v\[\d*:\d*]', line)[:2]
            new_line = line
            for rep_index, mapping in vgpr_replacement_list:
                # A redirection only affects MFMAs located after its ds_read.
                if i > rep_index:
                    for operand in v_pair:
                        if operand in mapping:
                            new_line = new_line.replace(operand, mapping[operand])
            core_loop_suf_vgpr.append(new_line)
        return core_loop_suf_vgpr

    def gen_inst_weight(self, core_loop_txt):
        """Map every line of ``core_loop_txt`` to its scheduling weight.

        The weight is taken from the first matching entry of
        ``_INST_WEIGHTS``; lines matching none of them weigh 0 when blank and
        4 otherwise.  The dictionary is keyed by the full line text, so
        identical lines share one entry.
        """
        inst_weight_dict = {}
        for line in core_loop_txt:
            for marker, weight in self._INST_WEIGHTS:
                if marker in line:
                    inst_weight_dict[line] = weight
                    break
            else:
                inst_weight_dict[line] = 0 if len(line.strip()) == 0 else 4
        return inst_weight_dict

    def get_interleave_start_line(self, core_loop_txt):
        """Split the loop tail into MFMA lines and the remaining lines.

        Only lines after the last line containing ``ds_read`` are examined.
        ``v_mfma`` lines go to the first list; an ``s_waitcnt lgkmcnt(``
        line immediately preceding a ``v_mfma`` is moved with it.  All other
        lines go to the second list.  Whenever a non-MFMA line mentions a
        register number already used by a collected MFMA, both lists are
        restarted from that line.

        Returns:
            ``(interleave_vmfma, interleave_other)`` -- two lists of
            ``[line, index_in_core_loop]`` entries.
        """
        vgpr_for_vmfma = set()
        interleave_vmfma = []
        interleave_other = []

        i_last_ds_read = 0
        for i_inst, line in enumerate(core_loop_txt):
            if "ds_read" in line:
                i_last_ds_read = i_inst

        for i_inst in range(i_last_ds_read + 1, len(core_loop_txt)):
            line = core_loop_txt[i_inst]
            if "v_mfma" in line:
                prev_line = core_loop_txt[i_inst - 1]
                if "s_waitcnt lgkmcnt(" in prev_line:
                    # The wait normally is the entry just appended to the
                    # "other" list; only remove it when that is really the case.
                    if interleave_other and interleave_other[-1][1] == i_inst - 1:
                        interleave_other.pop()
                    interleave_vmfma.append([prev_line, i_inst - 1])
                interleave_vmfma.append([line, i_inst])
                for v_str in re.findall(r"(?<=v\[)\d+:\d+(?=\])", line):
                    first, last = v_str.split(":")
                    vgpr_for_vmfma.add(int(first))
                    vgpr_for_vmfma.add(int(last))
            else:
                if any(int(x) in vgpr_for_vmfma for x in re.findall(r"(?<=v)\d+", line)):
                    # Register conflict with an earlier MFMA: start over here.
                    vgpr_for_vmfma = set()
                    interleave_vmfma = []
                    interleave_other = []
                interleave_other.append([line, i_inst])
        return interleave_vmfma, interleave_other

    def mfma_shuffle(self, interleave_vmfma, interleave_other, inst_weight_dict, gap=56):
        """Interleave the non-MFMA lines between the MFMA lines.

        After every ``v_mfma`` line, pending lines from ``interleave_other``
        are emitted in order for as long as their accumulated weight does not
        exceed ``gap``.  Nothing is inserted directly after an
        ``s_waitcnt lgkmcnt(`` entry.  Lines still pending after the last
        MFMA are appended at the end.

        Args:
            interleave_vmfma: ``[line, index]`` entries of the MFMA group.
            interleave_other: ``[line, index]`` entries of all other lines.
            inst_weight_dict: line text -> weight, see ``gen_inst_weight``.
            gap: weight budget available after one MFMA (default 56, the
                weight assigned to ``v_mfma`` itself).

        Returns:
            The list of line texts in their new order.
        """
        reshuffle_inst_slot = []
        i_inst_others = 0
        n_others = len(interleave_other)

        for vmfma in interleave_vmfma:
            vmfma_line = vmfma[0]
            reshuffle_inst_slot.append(vmfma_line)
            if "s_waitcnt lgkmcnt(" in vmfma_line:
                continue

            tmp_gap = 0
            while i_inst_others < n_others:
                candidate = interleave_other[i_inst_others][0]
                tmp_gap += inst_weight_dict[candidate]
                if tmp_gap > gap:
                    break
                reshuffle_inst_slot.append(candidate)
                i_inst_others += 1

        reshuffle_inst_slot.extend(entry[0] for entry in interleave_other[i_inst_others:])
        return reshuffle_inst_slot

    def gen_new_asm_txt(self, interleave_vmfma, interleave_other, reshuffle_inst_slot, asm_txt, core_loop_txt):
        """Splice the reshuffled core loop back into the full assembly text.

        The reshuffled lines replace the lines of ``core_loop_txt`` starting
        at the smallest index recorded in the two interleave lists.  The
        resulting loop then replaces, line for line, the lines of ``asm_txt``
        starting at the first line equal to the first line of the loop.

        When ``core_loop_txt`` is empty (the loop label was not found) a copy
        of ``asm_txt`` is returned unchanged.
        """
        if not core_loop_txt:
            return list(asm_txt)

        start_candidates = [
            group[0][1] for group in (interleave_vmfma, interleave_other) if group]
        # With nothing to interleave the start lies past the end: no line is replaced.
        reshuffle_start_point = min(start_candidates) if start_candidates else len(core_loop_txt)

        n_slots = len(reshuffle_inst_slot)
        new_core_loop = [
            reshuffle_inst_slot[i - reshuffle_start_point]
            if 0 <= i - reshuffle_start_point < n_slots
            else original_line
            for i, original_line in enumerate(core_loop_txt)
        ]

        new_asm_txt = []
        first_core_line = new_core_loop[0]
        cursor = None  # position in new_core_loop once the loop start was found
        for asm_line in asm_txt:
            if cursor is None and asm_line == first_core_line:
                cursor = 0
            if cursor is not None and cursor < len(new_core_loop):
                new_asm_txt.append(new_core_loop[cursor])
                cursor += 1
            else:
                new_asm_txt.append(asm_line)
        return new_asm_txt
if __name__ == "__main__":
    import sys

    # The assembly file is rewritten in place.  The historical default target
    # is kept; an optional first command-line argument selects another file.
    default_asm_path = "example/11_conv2d_bwd_weight/conv2d_bwd_weight_xdl-hip-amdgcn-amd-amdhsa-gfx908.s"
    asm_path = sys.argv[1] if len(sys.argv) > 1 else default_asm_path

    asm_txt = read_example_asm_file(asm_path)
    asm_analyser = asm_file_analyser(asm_txt)
    # new_asm_txt_bb1 holds the text with both core loops reshuffled.
    write_example_asm_file_back(asm_path, asm_analyser.new_asm_txt_bb1)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment