Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
920a752b
Commit
920a752b
authored
Jun 15, 2023
by
Adam Osewski
Browse files
clang-format
parent
209c1e50
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
43 additions
and
96 deletions
+43
-96
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_direct_c_write_out.hpp
...u/grid/gridwise_gemm_xdlops_splitk_direct_c_write_out.hpp
+0
-48
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_direct_c_write_out_roofline.hpp
...idwise_gemm_xdlops_splitk_direct_c_write_out_roofline.hpp
+43
-48
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_direct_c_write_out.hpp
View file @
920a752b
...
...
@@ -717,53 +717,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out
auto
b_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
p_b_block
,
b_k0_n_k1_block_desc
.
GetElementSpaceSize
());
#if 0
// preload data into LDS
{
a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf);
a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf);
}
// Initialize C
c_thread_buf.Clear();
// main body
if constexpr(HasMainKBlockLoop)
{
index_t k0_block_data_begin = 0;
do
{
a_blockwise_copy.MoveSrcSliceWindow(a_b_k0_m_k1_grid_desc, a_block_slice_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_b_k0_n_k1_grid_desc, b_block_slice_copy_step);
a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf);
block_sync_lds();
b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf);
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf);
k0_block_data_begin += K0PerBlock;
} while(k0_block_data_begin < (karg.K0 - K0PerBlock));
}
// tail
{
block_sync_lds();
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
}
#else
// gridwise GEMM pipeline
const
index_t
num_k_block_main_loop
=
__builtin_amdgcn_readfirstlane
(
(
a_b_k0_m_k1_grid_desc
.
GetLength
(
I1
)
*
a_b_k0_m_k1_grid_desc
.
GetLength
(
I3
))
/
...
...
@@ -786,7 +739,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out
blockwise_gemm
,
c_thread_buf
,
num_k_block_main_loop
);
#endif
// output: register to global memory
{
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_direct_c_write_out_roofline.hpp
View file @
920a752b
...
...
@@ -767,10 +767,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out
constexpr
auto
N3
=
c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I6
);
constexpr
auto
N4
=
c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I7
);
constexpr
auto
c_thread_desc_m0_n0_m1_n1_m2_m3_n2_n3
=
constexpr
auto
c_thread_desc_m0_n0_m1_n1_m2_m3_n2_n3
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
M0
,
N0
,
I1
,
I1
,
I2
,
I1
,
I1
,
Number
<
8
>
{}));
make_tuple
(
M0
,
N0
,
I1
,
I1
,
I2
,
I1
,
I1
,
Number
<
8
>
{}));
const
auto
M0_grid
=
c_grid_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I0
);
const
auto
N0_grid
=
c_grid_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I1
);
...
...
@@ -781,10 +780,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out
const
auto
N3_grid
=
c_grid_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I6
);
const
auto
N4_grid
=
c_grid_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I7
);
// if (blockIdx.x == 0 && ThisThreadBlock::GetThreadId() == 0)
// {
// printf("grid: [M0: %d, N0: %d, M1: %d, N1: %d, M2: %d, N2: %d, N3: %d, N4: %d]\n",
// printf("grid: [M0: %d, N0: %d, M1: %d, N1: %d, M2: %d, N2: %d, N3: %d, N4:
// %d]\n",
// M0_grid,
// N0_grid,
// M1_grid,
...
...
@@ -797,28 +796,26 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out
const
auto
c_grid_desc_m0_n0_m1_n1_m2_n234_tmp
=
transform_tensor_descriptor
(
c_grid_desc_m0_n0_m1_n1_m2_n2_n3_n4
,
make_tuple
(
make_pass_through_transform
(
M
0_grid
),
make_pass_through_transform
(
N0
_grid
),
make_pass_through_transform
(
M
1_grid
),
make_pass_through_transform
(
N1
_grid
),
make_pass_through_transform
(
M2_grid
),
make_merge_transform
(
make_tuple
(
N3_grid
,
N2_grid
,
N4_grid
))
// num_groups_per_blk * group_size
),
make_tuple
(
make_pass_through_transform
(
M0_grid
),
make_pass_through_transform
(
N
0_grid
),
make_pass_through_transform
(
M1
_grid
),
make_pass_through_transform
(
N
1_grid
),
make_pass_through_transform
(
M2
_grid
),
make_merge_transform
(
make_tuple
(
N3_grid
,
N2_grid
,
N4_grid
))
// num_groups_per_blk * group_size
),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{},
Sequence
<
5
,
6
,
7
>
{}
),
Sequence
<
5
,
6
,
7
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{},
Sequence
<
5
>
{}
));
Sequence
<
5
>
{}));
// if (blockIdx.x == 0 && ThisThreadBlock::GetThreadId() == 0)
// {
...
...
@@ -834,28 +831,25 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out
const
auto
c_grid_desc_m0_n0_m1_n1_m2_m3_n2_n3_new
=
transform_tensor_descriptor
(
c_grid_desc_m0_n0_m1_n1_m2_n234_tmp
,
make_tuple
(
make_pass_through_transform
(
M0_grid
),
// M0 - MRepeat / MXdlPerWave
make_pass_through_transform
(
N0_grid
),
// N0 - NRepeat / NXdlPerWave
make_pass_through_transform
(
M1_grid
),
// M1 - MWaves
make_pass_through_transform
(
N1_grid
),
// N1 - NWaves
make_unmerge_transform
(
make_tuple
(
I2
,
Number
<
16
>
{})),
// M2 -> (M2: 2, M3: 16)
make_unmerge_transform
(
make_tuple
(
I4
,
Number
<
8
>
{}))
// N2, N3
make_pass_through_transform
(
M0_grid
),
// M0 - MRepeat / MXdlPerWave
make_pass_through_transform
(
N0_grid
),
// N0 - NRepeat / NXdlPerWave
make_pass_through_transform
(
M1_grid
),
// M1 - MWaves
make_pass_through_transform
(
N1_grid
),
// N1 - NWaves
make_unmerge_transform
(
make_tuple
(
I2
,
Number
<
16
>
{})),
// M2 -> (M2: 2, M3: 16)
make_unmerge_transform
(
make_tuple
(
I4
,
Number
<
8
>
{}))
// N2, N3
),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{},
Sequence
<
5
>
{}
),
Sequence
<
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
,
5
>
{},
Sequence
<
6
,
7
>
{}
)
);
Sequence
<
6
,
7
>
{}));
// if (blockIdx.x == 0 && ThisThreadBlock::GetThreadId() == 0)
// {
...
...
@@ -872,13 +866,12 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out
const
auto
wave_idx
=
blockwise_gemm
.
GetWaveIdx
();
const
auto
lane_id_to_m3_n2_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
Number
<
16
>
{},
I4
))),
make_tuple
(
Sequence
<
0
,
1
>
{}),
make_tuple
(
Sequence
<
0
>
{})
);
make_tuple
(
make_merge_transform
(
make_tuple
(
Number
<
16
>
{},
I4
))),
make_tuple
(
Sequence
<
0
,
1
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
lane_data_idx_on_block
=
lane_id_to_m3_n2_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
wave_idx
[
I2
]));
const
auto
lane_data_idx_on_block
=
lane_id_to_m3_n2_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
wave_idx
[
I2
]));
// if (blockIdx.x == 0 && (ThisThreadBlock::GetThreadId() == 0 ||
// ThisThreadBlock::GetThreadId() == 16 ||
...
...
@@ -918,10 +911,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out
decltype
(
c_thread_desc_m0_n0_m1_n1_m2_m3_n2_n3
),
decltype
(
c_grid_desc_m0_n0_m1_n1_m2_m3_n2_n3_new
),
CElementwiseOperation
,
Sequence
<
M0
,
N0
,
I1
,
I1
,
I2
,
I1
,
I1
,
8
>
,
// SliceLengths
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
>
,
// CThreadTransferDstAccessOrder,
7
,
// CThreadTransferDstVectorDim,
8
,
// CThreadTransferDstScalarPerVector,
Sequence
<
M0
,
N0
,
I1
,
I1
,
I2
,
I1
,
I1
,
8
>
,
// SliceLengths
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
>
,
// CThreadTransferDstAccessOrder,
7
,
// CThreadTransferDstVectorDim,
8
,
// CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation
,
1
,
true
>
{
c_grid_desc_m0_n0_m1_n1_m2_m3_n2_n3_new
,
...
...
@@ -936,18 +929,21 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out
c_element_op
};
// if (blockIdx.x == 0 || blockIdx.x == 5)
// { // M1, N1, M2, N2, N3
// { // M1, N1, M2, N2,
// N3
// if (ThisThreadBlock::GetThreadId() == 0 ||
// ThisThreadBlock::GetThreadId() == 3 || // [ 0, 0, 0, 3, 0]
// ThisThreadBlock::GetThreadId() == 16 || // [ 0, 0, 4, 0, 0]
// ThisThreadBlock::GetThreadId() == 33 || // [ 0, 0, 8, 1, 0]
// ThisThreadBlock::GetThreadId() == 64 || // [ 0, 1, 0, 0, 0]
// ThisThreadBlock::GetThreadId() == 96 || // [ 0, 1, 8, 0, 0]
// ThisThreadBlock::GetThreadId() == 130 || // [ 1, 0, 0, 2, 0]
// ThisThreadBlock::GetThreadId() == 224 // [ 1, 1, 8, 0, 0]
// ThisThreadBlock::GetThreadId() == 3 || // [ 0, 0, 0, 3,
// 0] ThisThreadBlock::GetThreadId() == 16 || // [ 0, 0, 4, 0,
// 0] ThisThreadBlock::GetThreadId() == 33 || // [ 0, 0, 8, 1,
// 0] ThisThreadBlock::GetThreadId() == 64 || // [ 0, 1, 0, 0,
// 0] ThisThreadBlock::GetThreadId() == 96 || // [ 0, 1, 8, 0,
// 0] ThisThreadBlock::GetThreadId() == 130 || // [ 1, 0, 0, 2,
// 0] ThisThreadBlock::GetThreadId() == 224 // [ 1, 1, 8, 0,
// 0]
// )
// {
// printf("[B:%d, T:%d] -> dst_slice_origin_idx: [%d, %d, %d, %d, %d, %d, %d]\n",
// printf("[B:%d, T:%d] -> dst_slice_origin_idx: [%d, %d, %d, %d, %d, %d,
// %d]\n",
// get_block_1d_id(),
// ThisThreadBlock::GetThreadId(),
// m_thread_data_on_grid_idx[I0],
...
...
@@ -960,7 +956,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out
// }
// }
c_thread_copy
.
Run
(
c_thread_desc_m0_n0_m1_n1_m2_m3_n2_n3
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
c_thread_buf
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment