gaoqiong / composable_kernel · Commits

Commit a8b336a0
Authored May 31, 2023 by carlushuang
Parent 15e80eb1

    clean code and use atomic streamk by default
Showing 3 changed files with 5 additions and 56 deletions (+5 -56):
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp  +1 -1
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp              +3 -31
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp    +1 -24
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp

```diff
@@ -78,7 +78,7 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
         BlockToCTileMap_GemmStreamK<MPerBlock,
                                     NPerBlock,
                                     K0PerBlock * K1,
-                                    StreamKReductionStrategy::Reduction>,
+                                    StreamKReductionStrategy::Atomic>,
         ADataType, // TODO: distinguish A/B datatype
         AccDataType,
         CDataType,
```
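This is the substantive change of the commit: the default block-to-C-tile map now uses `StreamKReductionStrategy::Atomic` instead of `::Reduction`. In stream-K scheduling, several workgroups can each produce a partial sum for the same C tile; the two strategies differ only in how those partials are combined. Below is a minimal host-side sketch of the two schemes, assuming those semantics; the helper names are illustrative, not CK API.

```cpp
#include <atomic>
#include <vector>

// Atomic strategy: every contributing block adds its partial tile sum
// directly into C. No reduction workspace is needed, but C must be
// zero-initialized and the summation order is non-deterministic.
// (std::atomic<float>::fetch_add needs C++20; on the GPU this would be
// a global atomic add.)
void combine_atomic(std::atomic<float>& c, const std::vector<float>& partials)
{
    for(float p : partials)
        c.fetch_add(p);
}

// Reduction strategy: partials are staged in a workspace buffer, and a
// dedicated reduction pass sums them and writes C exactly once, giving
// deterministic results at the cost of extra memory and synchronization.
void combine_reduction(float& c, const std::vector<float>& workspace)
{
    float acc = 0.0f;
    for(float p : workspace)
        acc += p;
    c = acc;
}
```

Defaulting to `Atomic` drops the workspace and cross-workgroup synchronization that the `Reduction` path needs, which is the simpler of the two paths.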
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
```diff
@@ -693,23 +693,11 @@ struct BlockToCTileMap_GemmStreamK
     // pass to device
     uint32_t sk_num_blocks;
     uint32_t sk_num_big_blocks;
-    // uint32_t sk_total_iters;
     uint32_t dp_start_block_idx;
-    // uint32_t dp_iters_per_block;
-    // uint32_t dp_num_blocks;
     uint32_t reduction_start_block_idx;
     uint32_t k_iters_per_big_block;
-    // uint32_t tiles_cover_big_blocks; // for reduction
-    // uint32_t total_acc_buffers;      // for reduction
     MDiv2 n_tiles;
     MDiv k_iters_per_tile;
     MDiv eqav_tiles_big;    // for reduction
     MDiv eqav_tiles_little; // for reduction
```
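The fields that survive describe the stream-K work split: `sk_num_big_blocks` blocks each run `k_iters_per_big_block` K-iterations, the remaining stream-K ("little") blocks run one iteration fewer (see `iters_per_little_sk_block = k_iters_per_big_block - 1` in a later hunk), and data-parallel blocks begin at `dp_start_block_idx`. A hedged sketch of how such a split can be derived; `split_sk_iters` is illustrative, not a CK function:

```cpp
#include <cstdint>

// Illustrative: split sk_total_iters K-iterations across num_sk_blocks so
// that "big" blocks take one more iteration than "little" blocks, mirroring
// the sk_num_big_blocks / k_iters_per_big_block fields above.
struct SkSplit
{
    uint32_t sk_num_big_blocks;
    uint32_t k_iters_per_big_block; // little blocks run this minus one
};

SkSplit split_sk_iters(uint32_t sk_total_iters, uint32_t num_sk_blocks)
{
    uint32_t iters_per_block = sk_total_iters / num_sk_blocks; // floor
    uint32_t big_blocks      = sk_total_iters % num_sk_blocks; // remainder
    // Big blocks absorb the remainder: big = floor + 1, little = floor,
    // so big_blocks * (floor + 1) + (num - big_blocks) * floor = total.
    return SkSplit{big_blocks, iters_per_block + 1};
}
```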
```diff
@@ -859,9 +847,7 @@ struct BlockToCTileMap_GemmStreamK
             eqav_tiles_little = MDiv(upper_little / k_iters_per_tile.get());
         }
-        // tile_swizzle_sub_m_rem =
-        //     MDiv(math::integer_divide_ceil(m, MPerBlock) % tile_swizzle_sub_m);
+#if 0
         printf("cu:%d, occupancy:%d, grids:%d, num_tiles:%d, dp_tiles:%d, sk_num_big_blocks:%d, "
                "sk_num_blocks:%d, "
                "sk_total_iters:%d, dp_start_block_idx:%d, dp_iters_per_block:%d, dp_num_blocks:%d, "
@@ -883,6 +869,7 @@ struct BlockToCTileMap_GemmStreamK
                reduction_start_block_idx,
                get_sk_tiles(),
                get_workspace_size(sizeof(float)));
+#endif
     }

     __host__ __device__ uint32_t get_sk_total_iters() const
```
```diff
@@ -962,7 +949,6 @@ struct BlockToCTileMap_GemmStreamK
         uint32_t m_tile_idx, n_tile_idx;
         uint32_t n_tiles_value = math::integer_divide_ceil(n, NPerBlock);
         n_tiles.divmod(tile_idx, n_tiles_value, m_tile_idx, n_tile_idx);
-        // return make_tuple(m_tile_idx, n_tile_idx);

         // swizzle tile
         uint32_t m_tiles = math::integer_divide_ceil(m, MPerBlock);
```
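`MDiv`/`MDiv2` are CK's fast-division helpers: they precompute a multiplicative "magic number" so runtime division by an invariant denominator becomes a multiply-and-shift on the GPU. The `divmod` call above therefore just splits the linear tile index into row-major (m, n) tile coordinates; a plain-division equivalent (same results, minus the magic-number trick) looks like this:

```cpp
#include <cstdint>

// Plain-division stand-in for the MDiv2::divmod call above: recover the
// (m, n) coordinate of a linear tile index laid out row-major over tiles.
inline void tile_divmod(uint32_t tile_idx,
                        uint32_t n_tiles_value, // number of tiles along N
                        uint32_t& m_tile_idx,
                        uint32_t& n_tile_idx)
{
    m_tile_idx = tile_idx / n_tiles_value;
    n_tile_idx = tile_idx % n_tiles_value;
}
```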
```diff
@@ -983,18 +969,10 @@ struct BlockToCTileMap_GemmStreamK
         n_tile_idx_with_adapt = tile_idx_local / sub_m_adapt;
         m_tile_idx_with_adapt = tile_idx_local % sub_m_adapt;
-        // sub_m_adapt.divmod(tile_idx_local, n_tile_idx_with_adapt, m_tile_idx_with_adapt);

         return make_tuple(m_tile_idx_with_adapt + m_tile_idx_sub0 * tile_swizzle_sub_m,
                           n_tile_idx_with_adapt);
     }

-    // __host__ __device__ uint32_t get_workspace_offset_for_semaphore() const
-    // {
-    //     // workspace contains 2 part, 1) partial reduction buffer 2) semaphore for cross-wg sync
-    //     // we let 1) start from offset:0, 2) start from the end of 1)
-    //     // NOTE: offset is in unit of byte
-    //     return get_total_acc_buffers() *
-    // }
     __host__ __device__ uint32_t get_workspace_size_for_acc(uint32_t acc_element_bytes) const
     {
         static constexpr uint32_t alignment = 128;
```
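The surviving swizzle keeps the plain `/` and `%` form (the commented-out `sub_m_adapt.divmod` alternative is deleted): inside a band of `tile_swizzle_sub_m` tile rows, consecutive linear indices advance along M first, so neighboring workgroups touch nearby C tiles. A self-contained restatement of that transform, using the variable names from the hunk:

```cpp
#include <cstdint>
#include <tuple>

// Restatement of the sub-M swizzle above: tile_idx_local indexes
// column-major inside a band of sub_m_adapt tile rows; m_tile_idx_sub0
// selects which band of tile_swizzle_sub_m rows we are in.
std::tuple<uint32_t, uint32_t> swizzle_tile(uint32_t tile_idx_local,
                                            uint32_t m_tile_idx_sub0,
                                            uint32_t sub_m_adapt,
                                            uint32_t tile_swizzle_sub_m)
{
    uint32_t n_tile_idx_with_adapt = tile_idx_local / sub_m_adapt;
    uint32_t m_tile_idx_with_adapt = tile_idx_local % sub_m_adapt;
    return {m_tile_idx_with_adapt + m_tile_idx_sub0 * tile_swizzle_sub_m,
            n_tile_idx_with_adapt};
}
```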
```diff
@@ -1021,7 +999,6 @@ struct BlockToCTileMap_GemmStreamK
         uint32_t quo_, rem_;
         eqav_tiles_.divmod(tile_idx_, quo_, rem_);
         return quo_ * max_eqav_tiles_ + rem_;
-        // return tile_idx_ / eqav_tiles_ * max_eqav_tiles_ + (tile_idx_ % eqav_tiles_);
     }

     __host__ __device__ uint32_t get_tiles_cover_sk_block(uint32_t num_sk_blocks_,
```
```diff
@@ -1068,8 +1045,6 @@ struct BlockToCTileMap_GemmStreamK
                 iters_per_little_sk_block;
             uint32_t current_intersec =
                 get_tile_intersections(tile_idx_little_reverse, eqav_tiles_little);
-            // printf("reverse tile:%u, %u/%u\n", tile_idx_little_reverse, touched_sk_blocks,
-            //        current_intersec);
             return get_total_acc_buffers() - (touched_sk_blocks + current_intersec);
         }
     }
```
```diff
@@ -1080,7 +1055,6 @@ struct BlockToCTileMap_GemmStreamK
         uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1;
         if(block_idx_ < sk_num_big_blocks)
         {
-            // uint32_t touched_tiles = (block_idx_ * iters_per_big_sk_block + iters - 1) / iters;
             uint32_t touched_tiles = k_iters_per_tile.div(block_idx_ * iters_per_big_sk_block +
                                                           k_iters_per_tile.get() - 1);
             uint32_t current_intersec = get_tile_intersections(touched_tiles, eqav_tiles_big);
```
```diff
@@ -1089,9 +1063,7 @@ struct BlockToCTileMap_GemmStreamK
         else
         {
             uint32_t block_idx_little_reverse = sk_num_blocks - block_idx_;
-            // uint32_t touched_tiles = (block_idx_little_reverse * iters_per_little_sk_block +
-            //                           iters - 1) / iters;
             uint32_t touched_tiles = k_iters_per_tile.div(
                 block_idx_little_reverse * iters_per_little_sk_block + k_iters_per_tile.get() - 1);
             uint32_t current_intersec = get_tile_intersections(touched_tiles, eqav_tiles_little);
             return get_total_acc_buffers() - (block_idx_little_reverse + current_intersec);
```
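For context on the workspace these hunks feed into: `get_workspace_size_for_acc` (opened in the `@@ -983` hunk above) sizes the partial-accumulation buffer used by the `Reduction` strategy and rounds it up to the 128-byte `alignment` constant. A minimal sketch of that rounding, assuming the buffer holds `total_acc_buffers` tiles of `MPerBlock * NPerBlock` accumulators; the exact formula in the file is not fully shown in this diff:

```cpp
#include <cstdint>

// Hedged sketch of a 128-byte-aligned accumulation-workspace size. The real
// buffer count comes from get_total_acc_buffers(), which this diff only
// shows in use, not in full.
uint32_t workspace_size_for_acc(uint32_t total_acc_buffers,
                                uint32_t tile_elems, // e.g. MPerBlock * NPerBlock
                                uint32_t acc_element_bytes)
{
    constexpr uint32_t alignment = 128;
    uint32_t raw = total_acc_buffers * tile_elems * acc_element_bytes;
    return ((raw + alignment - 1) / alignment) * alignment; // round up
}
```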
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp
```diff
@@ -478,8 +478,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
         auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

-        // ignore = p_workspace; // TODO: for reduction
-
         // lds max alignment
         constexpr auto max_lds_align = K1;
```
```diff
@@ -531,9 +529,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
         uint32_t iter_start, iter_end;
         block_mapping.get_block_itr(block_idx, iter_start, iter_end);
         uint32_t total_iter_length = iter_end - iter_start;

-        // if(threadIdx.x == 0)
-        //     printf("xxx bid:%d, is_sk_block:%d, is_dp_block:%d\n", static_cast<int>(blockIdx.x),
-        //            is_sk_block, is_dp_block);
         if(is_padding_block)
             return;
```
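Each block asks `block_mapping.get_block_itr` for its half-open range of K-iterations and exits immediately if it is a padding block. An illustrative version of such a mapping under the big/little convention from `block_to_ctile_map.hpp`; this sketches the scheme, not the CK implementation:

```cpp
#include <cstdint>

// Illustrative stream-K iteration ranges: big blocks (indices
// [0, sk_num_big_blocks)) run k_iters_per_big_block iterations, the other
// stream-K blocks run one fewer, laid out back to back over the K space.
void get_block_itr_sketch(uint32_t block_idx,
                          uint32_t sk_num_big_blocks,
                          uint32_t k_iters_per_big_block,
                          uint32_t& iter_start,
                          uint32_t& iter_end)
{
    uint32_t iters_per_little = k_iters_per_big_block - 1;
    if(block_idx < sk_num_big_blocks)
    {
        iter_start = block_idx * k_iters_per_big_block;
        iter_end   = iter_start + k_iters_per_big_block;
    }
    else
    {
        uint32_t big_span = sk_num_big_blocks * k_iters_per_big_block;
        uint32_t little   = block_idx - sk_num_big_blocks;
        iter_start = big_span + little * iters_per_little;
        iter_end   = iter_start + iters_per_little;
    }
}
```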
```diff
@@ -654,14 +650,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
 #if 0
             if(threadIdx.x == 0) {
-                // if(reduction_idx == 0){
-                //     printf("(cluster red:%d,%d)bid:%d, rid:%d, os:%d-%d(%d), spatial:%d-%d, tid:%d, %d, %d\n",
-                //         cluster_length_reduce.At(I0).value,
-                //         cluster_length_reduce.At(I1).value, static_cast<int>(blockIdx.x),
-                //         reduction_idx, tile_acc_offset_start, tile_acc_offset_end,
-                //         tile_acc_offset_end - tile_acc_offset_start, spatial_idx[I0],
-                //         spatial_idx[I1], static_cast<int>(threadIdx.x), thread_m_cluster_id,
-                //         thread_n_cluster_id);
                 printf("bid:%d, rid:%d, os:%d,%d, spatial:%d,%d\n", static_cast<int>(blockIdx.x),
                     reduction_idx, __builtin_amdgcn_readfirstlane(tile_acc_offset_start), __builtin_amdgcn_readfirstlane(tile_acc_offset_end),
                     __builtin_amdgcn_readfirstlane(spatial_idx[I0]),
```
```diff
@@ -672,7 +660,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
             using Accumulation = ck::detail::
                 AccumulateWithNanCheck<false /*PropagateNan*/, reduce::Add, FloatAcc>;

-            // static_for<0, MReduceIters, 1>{}([&](auto i_m_reduce) {
             for(int i_m = 0; i_m < MReduceIters; i_m++)
             {
                 static_for<0, NReduceIters, 1>{}([&](auto i_n_reduce) {
```
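With `PropagateNan = false`, `AccumulateWithNanCheck<false, reduce::Add, FloatAcc>` reduces to a plain add, so the loop above simply folds each staged partial tile into the running accumulator. A rough stand-in for that inner step:

```cpp
// Rough stand-in for the accumulation above: fold one partial tile buffer
// into the in-register accumulator. AccumulateWithNanCheck with
// PropagateNan = false and reduce::Add boils down to acc += v.
template <typename FloatAcc>
void accumulate_partial(FloatAcc* acc, const FloatAcc* partial, int n)
{
    for(int i = 0; i < n; ++i)
        acc[i] += partial[i];
}
```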
```diff
@@ -731,7 +718,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
                     }
                 }
                 });
-                // if constexpr(i_m_reduce != MReduceIters - 1)
                 {
                     acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
                                                 partial_acc_load_step_m);
```
```diff
@@ -739,7 +725,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
                                                  partial_acc_store_step_m);
                 }
             }
-            //});
             return;
         }
     }
```
```diff
@@ -767,14 +752,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
         const index_t k0_block_data_idx_on_grid =
             __builtin_amdgcn_readfirstlane(iter_offset * K0PerBlock);

-        // if(threadIdx.x == 0)
-        //     printf("[%s], bid:%d, block_idx:%d, tile_idx:%d(%d, %d, %d), iter_start:%d(%d |
-        //     %d), iter_end:%d, len:%d\n",
-        //         is_sk_block ? "sk_block" : (is_dp_block ? "dp_block" : "other   "),
-        //         static_cast<int>(blockIdx.x), block_idx, tile_idx, m_block_data_idx_on_grid,
-        //         n_block_data_idx_on_grid, k0_block_data_idx_on_grid, iter_end -
-        //         current_iter_length, iter_offset, iter_start, iter_end, current_iter_length);
         // A matrix blockwise copy
         auto a_blockwise_copy =
             ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
```
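One detail worth noting in this last hunk: the K0 offset is wrapped in `__builtin_amdgcn_readfirstlane`, which broadcasts lane 0's value so the compiler can treat it as wavefront-uniform and keep the address arithmetic in scalar registers. A tiny HIP usage sketch; the kernel itself is illustrative:

```cpp
#include <hip/hip_runtime.h>

// HIP sketch: readfirstlane marks a value as wavefront-uniform (SGPR),
// as done for k0_block_data_idx_on_grid above.
__global__ void uniform_offset_demo(const float* p, float* out, int iter_offset, int K0PerBlock)
{
    const int k0 = __builtin_amdgcn_readfirstlane(iter_offset * K0PerBlock);
    out[threadIdx.x] = p[k0 + threadIdx.x];
}
```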