Commit 099c470e authored by wangshaojie6's avatar wangshaojie6
Browse files

do some tests

parent 731febb6
;origin loop
; Baseline instruction order for one loop iteration (pseudo-assembly sketch).
; The "next k" ds_reads below write the same v_lda[0:7] / v_ldb[0:7] registers
; that the first group of MFMAs reads, so in this schedule all eight MFMAs of
; the current k are listed before those reads.
.origin_loop_start:
; read the A/B operands for the current k into v_lda[0:7] / v_ldb[0:7]
ds_read2_b64 v_lda[0:3]
ds_read2_b64 v_ldb[0:3]
ds_read2_b64 v_lda[4:7]
ds_read2_b64 v_ldb[4:7]
; eight MFMAs for the current k: A pairs {0:1, 2:3, 4:5, 6:7} against the B pairs
v_mfma v_lda[0:1], v_ldb[0:1]
v_mfma v_lda[2:3], v_ldb[2:3]
v_mfma v_lda[0:1], v_ldb[4:5]
v_mfma v_lda[2:3], v_ldb[6:7]
v_mfma v_lda[4:5], v_ldb[0:1]
v_mfma v_lda[6:7], v_ldb[2:3]
v_mfma v_lda[4:5], v_ldb[4:5]
v_mfma v_lda[6:7], v_ldb[6:7]
; reload the same registers with the operands at the "next k" offset
ds_read2_b64 v_lda[0:3] offset: next k
ds_read2_b64 v_lda[4:7] offset: next k
ds_read2_b64 v_ldb[0:3] offset: next k
ds_read2_b64 v_ldb[4:7] offset: next k
s_barrier
; first four MFMAs of the next k (they use v_lda[0:3])
v_mfma v_lda[0:1], v_ldb[0:1]
v_mfma v_lda[2:3], v_ldb[2:3]
v_mfma v_lda[0:1], v_ldb[4:5]
v_mfma v_lda[2:3], v_ldb[6:7]
; pack v_gla[0:3] into v_lda[0:3] and write it out; v_lda[0:3] is reused as the
; pack destination here, and the MFMAs that remain below only read v_lda[4:7]
v_pack v_lda[0], v_gla[0], v_gla[1], lo
v_pack v_lda[1], v_gla[0], v_gla[1], hi
v_pack v_lda[2], v_gla[2], v_gla[3], lo
v_pack v_lda[3], v_gla[2], v_gla[3], hi
ds_write2_b64 v_lda[0:1], v_lda[2:3]
; pack v_glb[0:3] into the separate v_pkb[0:3] registers and write it out
v_pack v_pkb[0], v_glb[0], v_glb[1], lo
v_pack v_pkb[1], v_glb[0], v_glb[1], hi
v_pack v_pkb[2], v_glb[2], v_glb[3], lo
v_pack v_pkb[3], v_glb[2], v_glb[3], hi
ds_write2_b64 v_pkb[0:1], v_pkb[2:3]
s_barrier
; slice-window moves followed by the buffer loads that refill v_gla / v_glb
v_move_slice_window 0
v_move_slice_window 1
; ... ~60 valus
buffer_load_dwordx4 v_gla[0:3]
buffer_load_dwordx4 v_glb[0:3]
; remaining four MFMAs of the next k (they use v_lda[4:7])
v_mfma v_lda[4:5], v_ldb[0:1]
v_mfma v_lda[6:7], v_ldb[2:3]
v_mfma v_lda[4:5], v_ldb[4:5]
v_mfma v_lda[6:7], v_ldb[6:7]
s_branch origin_loop_start
;optimized loop
; Reordered schedule for one loop iteration (pseudo-assembly sketch).
; The "next k" ds_reads target a second register set, v_lda[8:15] / v_ldb[8:15],
; instead of overwriting v_lda[0:7] / v_ldb[0:7]. The eight next-k MFMAs are
; then spread between the barrier, pack/write, slice-window and buffer-load
; instructions rather than being issued in two solid groups.
.optimized_loop_start:
; read the A/B operands for the current k into v_lda[0:7] / v_ldb[0:7]
ds_read2_b64 v_lda[0:3]
ds_read2_b64 v_ldb[0:3]
ds_read2_b64 v_lda[4:7]
ds_read2_b64 v_ldb[4:7]
; eight MFMAs for the current k
v_mfma v_lda[0:1], v_ldb[0:1]
v_mfma v_lda[2:3], v_ldb[2:3]
v_mfma v_lda[0:1], v_ldb[4:5]
v_mfma v_lda[2:3], v_ldb[6:7]
v_mfma v_lda[4:5], v_ldb[0:1]
v_mfma v_lda[6:7], v_ldb[2:3]
v_mfma v_lda[4:5], v_ldb[4:5]
v_mfma v_lda[6:7], v_ldb[6:7]
; read the operands at the "next k" offset into the second register set
ds_read2_b64 v_lda[8:11] offset: next k
ds_read2_b64 v_lda[12:15] offset: next k
ds_read2_b64 v_ldb[8:11] offset: next k
ds_read2_b64 v_ldb[12:15] offset: next k
; next-k MFMAs follow the same pairing as the current-k group above, with every
; register index shifted by 8; the other instructions are interleaved between them
v_mfma v_lda[8:9], v_ldb[8:9]
s_barrier
v_mfma v_lda[10:11], v_ldb[10:11]
; pack v_gla[0:3] into v_lda[0:3] and write it out
v_pack v_lda[0], v_gla[0], v_gla[1], lo
v_pack v_lda[1], v_gla[0], v_gla[1], hi
v_pack v_lda[2], v_gla[2], v_gla[3], lo
v_pack v_lda[3], v_gla[2], v_gla[3], hi
ds_write2_b64 v_lda[0:1], v_lda[2:3]
v_mfma v_lda[8:9], v_ldb[12:13]
; pack v_glb[0:3] into v_pkb[0:3] and write it out
v_pack v_pkb[0], v_glb[0], v_glb[1], lo
v_pack v_pkb[1], v_glb[0], v_glb[1], hi
v_pack v_pkb[2], v_glb[2], v_glb[3], lo
v_pack v_pkb[3], v_glb[2], v_glb[3], hi
ds_write2_b64 v_pkb[0:1], v_pkb[2:3]
v_mfma v_lda[10:11], v_ldb[14:15]
s_barrier
v_mfma v_lda[12:13], v_ldb[8:9]
v_move_slice_window 0
; A operand is v_lda[14:15] here, mirroring "v_lda[6:7], v_ldb[2:3]" in the
; current-k group; v_lda[12:13] would repeat an A pair and skip this product
v_mfma v_lda[14:15], v_ldb[10:11]
v_move_slice_window 1
buffer_load_dwordx4 v_gla[0:3]
v_mfma v_lda[12:13], v_ldb[12:13]
buffer_load_dwordx4 v_glb[0:3]
v_mfma v_lda[14:15], v_ldb[14:15]
s_branch optimized_loop_start
......@@ -42,14 +42,14 @@ using DeviceConvBwdDataInstance = ck::tensor_operation::device::
OutElementOp, // OutElementwiseOperation
ConvBwdDefault, // ConvolutionBackwardDataSpecialization
256, // BlockSize
128, // MPerBlock
128, // NPerBlock
256, // MPerBlock
256, // NPerBlock
4, // K0PerBlock
8, // K1
32, // MPerXdl
32, // NPerXdl
2, // MXdlPerWave
2, // NXdlPerWave
4, // MXdlPerWave
4, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
......@@ -61,7 +61,7 @@ using DeviceConvBwdDataInstance = ck::tensor_operation::device::
S<2, 0, 1>, // BBlockTransferThreadClusterArrangeOrder
S<0, 2, 1>, // BBlockTransferSrcAccessOrder
1, // BBlockTransferSrcVectorDim
2, // BBlockTransferSrcScalarPerVector
4, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_K1
true, // BBlockLdsAddExtraN
7,
......
......@@ -44,22 +44,22 @@ using DeviceConvBwdWeightInstance = ck::tensor_operation::device::
WeiElementOp, // WeiElementwiseOperation
OutElementOp, // OutElementwiseOperation
256, // BlockSize
256, // MPerBlock
256, // NPerBlock
4, // K0PerBlock
64, // MPerBlock
128, // NPerBlock
8, // K0PerBlock
8, // K1
32, // MPerXdl
32, // NPerXdl
4, // MXdlPerWave
4, // NXdlPerWave
S<1, 4, 32, 2>, // ABlockTransferThreadClusterLengths_K0_M_K1
1, // MXdlPerWave
2, // NXdlPerWave
S<1, 8, 8, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<0, 3, 1, 2>, // ABlockTransferThreadClusterArrangeOrder
S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
4, // ABlockTransferDstScalarPerVector_K1
2, // ABlockTransferDstScalarPerVector_K1
true, // ABlockLdsAddExtraM
S<1, 4, 32, 2>, // BBlockTransferThreadClusterLengths_K0_N_K1
S<1, 8, 16, 2>, // BBlockTransferThreadClusterLengths_K0_N_K1
S<0, 3, 1, 2>, // BBlockTransferThreadClusterArrangeOrder
S<0, 2, 1, 3>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
......
......@@ -279,6 +279,24 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
// });
//});
//static_for<0, KPerThread, KPack>{}([&](auto k) {
// static_for<0, MRepeat, 1>{}([&](auto m0) {
// //read from lds for A
// a_thread_copy_.Run();
// });
// static_for<0, NRepeat, 1>{}([&](auto n0) {
// //read from lds for B
// b_thread_copy_.Run();
// });
//
// static_for<0, MRepeat, 1>{}([&](auto m0) {
// static_for<0, NRepeat, 1>{}([&](auto n0) {
// // do mfma within k
// xdlops_gemm.template Run();
// });
// });
//});
static_for<0, KPerThread, KPack>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
// read A
......
......@@ -673,15 +673,27 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
index_t k0_block_data_begin = 0;
block_sync_lds();
//do
//{
// blockwise_gemm.Run();
//
// block_sync_lds();
//
// a_blockwise_copy.MoveSrcSliceWindow();
// b_blockwise_copy.MoveSrcSliceWindow();
//
// a_blockwise_copy.RunWrite();
// b_blockwise_copy.RunWrite();
//
// a_blockwise_copy.RunRead();
// block_sync_lds();
// b_blockwise_copy.RunRead();
//
// k0 += K0PerBlock;
//} while(k0 < (K0 - K0PerBlock));
do
{
//a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf);
//block_sync_lds();
//b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf);
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
......
......@@ -19,7 +19,13 @@ class asm_file_analyser:
self.core_loop_txt_bb0 = self.gen_core_loop_txt(".LBB0_1")
self.core_loop_txt_bb1 = self.gen_core_loop_txt(".LBB1_1")
self.next_free_vgpr = self.find_next_free_vgpr(asm_txt)
self.next_free_vgpr = self.find_next_free_vgpr(self.asm_txt)
self.vgpr_limit_number = self.find_vgpr_limit(self.asm_txt)
self.asm_txt_max_vgpr = self.set_vgpr_to_max(self.asm_txt)
print(self.vgpr_limit_number)
#assert False
self.core_loop_txt_bb0_enlarge_ds_read_vgpr_bb0 = self.enlarge_ds_read(self.core_loop_txt_bb0)
self.core_loop_txt_bb0_enlarge_ds_read_vgpr_bb1 = self.enlarge_ds_read(self.core_loop_txt_bb1)
......@@ -33,10 +39,10 @@ class asm_file_analyser:
self.reshuffle_inst_slot_bb0 = self.mfma_shuffle(self.interleave_vmfma_bb0, self.interleave_other_bb0, self.inst_weight_dict_bb0)
self.reshuffle_inst_slot_bb1 = self.mfma_shuffle(self.interleave_vmfma_bb1, self.interleave_other_bb1, self.inst_weight_dict_bb1)
self.new_asm_txt_bb0 = self.gen_new_asm_txt(self.interleave_vmfma_bb0, self.interleave_other_bb0, self.reshuffle_inst_slot_bb0, self.asm_txt, self.core_loop_txt_bb0_enlarge_ds_read_vgpr_bb0)
self.new_asm_txt_bb0 = self.gen_new_asm_txt(self.interleave_vmfma_bb0, self.interleave_other_bb0, self.reshuffle_inst_slot_bb0, self.asm_txt_max_vgpr, self.core_loop_txt_bb0_enlarge_ds_read_vgpr_bb0)
for line in self.new_asm_txt_bb0:
print(line)
#for line in self.new_asm_txt_bb0:
# print(line)
self.new_asm_txt_bb1 = self.gen_new_asm_txt(self.interleave_vmfma_bb1, self.interleave_other_bb1, self.reshuffle_inst_slot_bb1, self.new_asm_txt_bb0, self.core_loop_txt_bb0_enlarge_ds_read_vgpr_bb1)
......@@ -58,9 +64,38 @@ class asm_file_analyser:
numvpgr_str = re.findall(r'(?<=; NumVgprs: )\d*', line)
if len(numvpgr_str) != 0:
next_free_vgpr = int(numvpgr_str[0])
print(next_free_vgpr)
#print(next_free_vgpr)
return next_free_vgpr
def find_vgpr_limit(self, asm_txt):
    """Derive the per-kernel VGPR budget from metadata comments in ``asm_txt``.

    Every line is scanned for ``; LDSByteSize: <n>`` and ``; NumAgprs: <n>``.
    If a field occurs more than once, the last occurrence wins. The result is
    ``256 // min(64*1024 // lds_size, 256 // agpr_size)``.

    NOTE(review): the constants presumably stand for 64 KiB of LDS and a
    256-entry register budget on the target GPU -- confirm against the
    hardware this tool is meant for.

    A field whose value is 0 is treated as imposing no limit and is left out
    of the ``min`` instead of causing a division by zero.

    Args:
        asm_txt: iterable of assembly text lines.

    Returns:
        int: the VGPR count that the LDS / AGPR usage leaves room for.

    Raises:
        ValueError: if either metadata field is missing, or if the two fields
            together do not yield a positive divisor (both are 0, or one of
            them exceeds its budget).
    """
    lds_size = None
    agpr_size = None
    for line in asm_txt:
        # `\d+` rather than `\d*`: an empty match would make int() fail.
        lds_size_re = re.search(r'(?<=; LDSByteSize: )\d+', line)
        if lds_size_re:
            lds_size = int(lds_size_re.group())
        agpr_size_re = re.search(r'(?<=; NumAgprs: )\d+', line)
        if agpr_size_re:
            agpr_size = int(agpr_size_re.group())

    if lds_size is None or agpr_size is None:
        raise ValueError(
            "asm text lacks '; LDSByteSize:' and/or '; NumAgprs:' metadata")

    # Collect only the resources that actually constrain the result.
    divisors = []
    if lds_size > 0:
        divisors.append(64 * 1024 // lds_size)
    if agpr_size > 0:
        divisors.append(256 // agpr_size)

    divisor = min(divisors) if divisors else 0
    if divisor <= 0:
        raise ValueError(
            f"cannot derive a VGPR limit from LDSByteSize={lds_size}, "
            f"NumAgprs={agpr_size}")
    return 256 // divisor
def set_vgpr_to_max(self, asm_txt):
    """Return a copy of ``asm_txt`` with the VGPR-count directives overridden.

    Any line containing ``.vgpr_count:`` or ``.amdhsa_next_free_vgpr`` is
    rebuilt as ``<text before the directive><directive> <limit>\\n``, where
    ``<limit>`` is ``self.vgpr_limit_number``. Whatever followed the directive
    on the original line is dropped. All other lines are passed through
    unchanged. ``.vgpr_count:`` takes precedence if both appear on one line.

    Args:
        asm_txt: iterable of assembly text lines.

    Returns:
        list[str]: the rewritten lines, in input order.
    """
    # (text searched for, text emitted in its place before the number)
    directives = (
        (".vgpr_count:", ".vgpr_count: "),
        (".amdhsa_next_free_vgpr", ".amdhsa_next_free_vgpr "),
    )
    asm_max_vgpr = []
    for line in asm_txt:
        for marker, replacement in directives:
            col = line.find(marker)
            if col != -1:
                # Keep the original leading text (indentation) of the line.
                asm_max_vgpr.append(
                    f"{line[:col]}{replacement}{self.vgpr_limit_number}\n")
                break
        else:
            asm_max_vgpr.append(line)
    return asm_max_vgpr
def enlarge_ds_read(self, core_loop_txt):
new_core_loop = []
ds_read_list = []
......@@ -97,20 +132,25 @@ class asm_file_analyser:
v_pair = re.findall(r'v\[\d*:\d*]', line)
#print(i, v_pair)
new_line = line
replace_dict = {}
for i_rep in vgpr_replacement_list:
if i > i_rep[0]:
if v_pair[0] in i_rep[1].keys():
new_line = new_line.replace(v_pair[0], i_rep[1][v_pair[0]])
#new_line = new_line.replace(v_pair[0], i_rep[1][v_pair[0]])
replace_dict[v_pair[0]] = i_rep[1][v_pair[0]]
if v_pair[1] in i_rep[1].keys():
new_line = new_line.replace(v_pair[1], i_rep[1][v_pair[1]])
#new_line = new_line.replace(v_pair[1], i_rep[1][v_pair[1]])
replace_dict[v_pair[1]] = i_rep[1][v_pair[1]]
#print(new_line)
#print(replace_dict)
for v_rep in replace_dict:
new_line = new_line.replace(v_rep, replace_dict[v_rep])
core_loop_suf_vgpr.append(new_line)
else:
core_loop_suf_vgpr.append(line)
#print(vgpr_replacement_list)
#for i in core_loop_suf_vgpr:
# print(i)
return core_loop_suf_vgpr
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment