gaoqiong / composable_kernel_ROCM

Commit 5490b99c, authored Apr 10, 2024 by Harisankar Sadasivan
Parent: 5b1e2442

2 tile streamk with reduction
Showing 3 changed files with 68 additions and 117 deletions (+68, -117):
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp (+15, -38)
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp (+22, -33)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp (+31, -46)
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp
@@ -78,7 +78,8 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
         BlockToCTileMap_GemmStreamK<MPerBlock,
                                     NPerBlock,
                                     K0PerBlock * K1,
-                                    StreamKReductionStrategy::Atomic>,
+                                    // StreamKReductionStrategy::Atomic>,
+                                    StreamKReductionStrategy::Reduction>,
         ADataType, // TODO: distinguish A/B datatype
         AccDataType,
         CDataType,
@@ -139,11 +140,20 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
             // stream-k: calculate the number of blocks to be launched based on #CUs and #occupancy
-            // dim3 grid_dims = karg.block_mapping.get_grid_dims(karg.num_cu, karg.occupancy);
-            dim3 grid_dims = karg.block_mapping.get_grid_dims();
-            float ave_time = 0;
+            int occupancy, num_cu;
+            const auto kernel = kernel_gemm_xdlops_streamk<GridwiseGemm>;
+            hip_check_error(
+                hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));
+            hipDeviceProp_t dev_prop;
+            hipDevice_t dev;
+            hip_check_error(hipGetDevice(&dev));
+            hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
+            num_cu = dev_prop.multiProcessorCount;
+            dim3 grid_dims = karg.block_mapping.sk_num_blocks;
+            printf("Recommended #stream-k blocks (assuming full GPU availability): %0d\n",
+                   num_cu * occupancy);
+            float ave_time = 0;
             // TODO: remove clear buffer for streamk kernels
             if constexpr(GridwiseGemm::Block2CTileMap::ReductionStrategy ==
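For context, the added invoker code sizes the stream-k launch from the device's compute-unit count and the kernel's occupancy. Below is a minimal standalone HIP sketch of that same query; the dummy kernel, block size, and variable names are placeholders, not part of this commit.

#include <hip/hip_runtime.h>
#include <cstdio>

__global__ void dummy_kernel() {} // stand-in for kernel_gemm_xdlops_streamk<GridwiseGemm>

int main()
{
    int dev = 0;
    hipGetDevice(&dev);

    hipDeviceProp_t dev_prop;
    hipGetDeviceProperties(&dev_prop, dev);
    int num_cu = dev_prop.multiProcessorCount; // number of compute units on this GPU

    int occupancy = 0; // max resident blocks per CU for this kernel at this block size
    hipOccupancyMaxActiveBlocksPerMultiprocessor(
        &occupancy, dummy_kernel, /*blockSize=*/256, /*dynSharedMemPerBlk=*/0);

    std::printf("Recommended #stream-k blocks (assuming full GPU availability): %d\n",
                num_cu * occupancy);
    return 0;
}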
@@ -272,30 +282,8 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
                              CElementwiseOperation,
                              uint32_t NumSKBlocks = 0)
    {
-        const auto kernel = kernel_gemm_xdlops_streamk<GridwiseGemm>;
-        int occupancy, num_cu;
-        hip_check_error(
-            hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));
-        hipDeviceProp_t dev_prop;
-        hipDevice_t dev;
-        hip_check_error(hipGetDevice(&dev));
-        hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
-        num_cu = dev_prop.multiProcessorCount;
-        printf("Assuming full GPU availability, recommended stream-k grid size for tuning :%0d\n",
-               num_cu * occupancy);
-        return Argument{p_a,
-                        p_b,
-                        p_c,
-                        M,
-                        N,
-                        K,
-                        StrideA,
-                        StrideB,
-                        StrideC,
-                        static_cast<uint32_t>(num_cu),
-                        static_cast<uint32_t>(occupancy),
-                        NumSKBlocks};
+        return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, NumSKBlocks};
    }

    static auto MakeInvoker() { return Invoker{}; }
@@ -315,15 +303,6 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
                          CElementwiseOperation,
                          index_t NumSKBlocks = 0) override
    {
-        const auto kernel = kernel_gemm_xdlops_streamk<GridwiseGemm>;
-        int occupancy, num_cu;
-        hip_check_error(
-            hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));
-        hipDeviceProp_t dev_prop;
-        hipDevice_t dev;
-        hip_check_error(hipGetDevice(&dev));
-        hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
-        num_cu = dev_prop.multiProcessorCount;
        return std::make_unique<Argument>(reinterpret_cast<const ADataType*>(p_a),
                                          reinterpret_cast<const BDataType*>(p_b),
@@ -334,8 +313,6 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
                                          StrideA,
                                          StrideB,
                                          StrideC,
-                                          static_cast<uint32_t>(num_cu),
-                                          static_cast<uint32_t>(occupancy),
                                          static_cast<uint32_t>(NumSKBlocks));
    }
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
@@ -1007,16 +1007,11 @@ struct BlockToCTileMap_GemmStreamK
    uint32_t k_iters_per_big_block;
    MDiv2 n_tiles;
    MDiv k_iters_per_tile;
-    MDiv eqav_tiles_big;    // for reduction
-    MDiv eqav_tiles_little; // for reduction
+    MDiv equiv_tiles_big;    // for reduction
+    MDiv equiv_tiles_little; // for reduction

    // prefer construct on host
-    BlockToCTileMap_GemmStreamK(
-        uint32_t m, uint32_t n, uint32_t k, uint32_t num_cu, uint32_t occupancy, uint32_t sk_blocks = 0)
+    BlockToCTileMap_GemmStreamK(uint32_t m, uint32_t n, uint32_t k, uint32_t sk_blocks = 0)
    {
        // total output tiles
        uint32_t num_tiles =
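With this change the mapping no longer captures num_cu/occupancy at construction; callers pass them to get_grid_dims() instead (see the hunk at -1124 below). A hedged host-side sketch of how the slimmed-down type might be used follows; it assumes the header above is on the include path, that the class and the StreamKReductionStrategy enum are reachable through it in namespace ck, and the tile sizes and values are illustrative only.

#include <cstdint>
#include <hip/hip_runtime.h>
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"

// Illustrative tile sizes; a real instantiation uses the GEMM's MPerBlock/NPerBlock/K0PerBlock*K1.
using Map = ck::BlockToCTileMap_GemmStreamK<256, 128, 32, ck::StreamKReductionStrategy::Reduction>;

dim3 make_streamk_grid(uint32_t M, uint32_t N, uint32_t K, uint32_t num_cu, uint32_t occupancy)
{
    Map block_mapping(M, N, K, /*sk_blocks=*/num_cu * occupancy); // new 4-argument constructor
    return block_mapping.get_grid_dims(num_cu, occupancy);        // num_cu * occupancy blocks
}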
@@ -1027,6 +1022,7 @@ struct BlockToCTileMap_GemmStreamK
+        // default to regular DP GEMM if sk blocks == 0
        sk_num_blocks = sk_blocks;
        if(sk_num_blocks == 0 || sk_num_blocks == 0xFFFFFFFF)
        {
            sk_num_blocks = 0;
@@ -1042,7 +1038,7 @@ struct BlockToCTileMap_GemmStreamK
        else
        {
            // grid size
-            uint32_t grid_size = occupancy * num_cu;
+            uint32_t grid_size = sk_num_blocks;
            // check if there's enough work for DP+ stream-k
            bool bigEnough = num_tiles > grid_size;
            // max of 2 sk tiles per block
@@ -1068,7 +1064,7 @@ struct BlockToCTileMap_GemmStreamK
            k_iters_per_big_block = k_iters_per_sk_block + 1;
            dp_num_blocks         = dp_tiles;
-            dp_start_block_idx    = (sk_num_blocks + num_cu - 1) / num_cu * num_cu;
+            dp_start_block_idx    = sk_num_blocks;
        }
        n_tiles = MDiv2(math::integer_divide_ceil(n, NPerBlock));
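The hunk above keeps the usual stream-k split: the stream-k tiles' K-iterations are spread over sk_num_blocks, "big" blocks absorb the remainder by running one extra iteration each, and the data-parallel blocks now start immediately after the stream-k blocks. A minimal standalone host-side sketch of that split, with made-up example numbers rather than CK's exact code:

#include <cstdint>
#include <cstdio>

int main()
{
    uint32_t sk_tiles         = 7;  // output tiles handled by the stream-k blocks (example)
    uint32_t k_iters_per_tile = 12; // K-loop iterations needed per output tile (example)
    uint32_t sk_num_blocks    = 16; // grid size chosen for the stream-k portion (example)

    uint32_t total_iters           = sk_tiles * k_iters_per_tile;
    uint32_t k_iters_per_sk_block  = total_iters / sk_num_blocks; // "little" block workload
    uint32_t sk_num_big_blocks     = total_iters % sk_num_blocks; // blocks doing one extra iteration
    uint32_t k_iters_per_big_block = k_iters_per_sk_block + 1;    // as in the hunk above

    std::printf("big blocks: %u x %u iters, little blocks: %u x %u iters\n",
                sk_num_big_blocks, k_iters_per_big_block,
                sk_num_blocks - sk_num_big_blocks, k_iters_per_sk_block);
    return 0;
}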
@@ -1079,8 +1075,8 @@ struct BlockToCTileMap_GemmStreamK
        {
            uint32_t upper_big    = math::lcm(k_iters_per_big_block, k_iters_per_tile.get());
            uint32_t upper_little = math::lcm(k_iters_per_big_block - 1, k_iters_per_tile.get());
-            eqav_tiles_big    = MDiv(upper_big / k_iters_per_tile.get());
-            eqav_tiles_little = MDiv(upper_little / k_iters_per_tile.get());
+            equiv_tiles_big    = MDiv(upper_big / k_iters_per_tile.get());
+            equiv_tiles_little = MDiv(upper_little / k_iters_per_tile.get());
        }
#if 0
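For intuition on the renamed equiv_tiles_* members: taking the lcm of a block's per-block K-iteration count with k_iters_per_tile and dividing by k_iters_per_tile gives, per the naming, the period in output tiles after which tile boundaries and stream-k block boundaries line up again; MDiv only wraps that value for fast integer division on the GPU. A small standalone sketch with std::lcm and example values (not CK code):

#include <cstdint>
#include <cstdio>
#include <numeric>

int main()
{
    uint32_t k_iters_per_tile      = 12; // example values only
    uint32_t k_iters_per_big_block = 5;

    // after this many output tiles, the tile / big-sk-block boundary pattern repeats
    uint32_t upper_big       = std::lcm(k_iters_per_big_block, k_iters_per_tile); // 60
    uint32_t equiv_tiles_big = upper_big / k_iters_per_tile;                      // 5

    std::printf("boundary pattern repeats every %u tiles (lcm = %u iters)\n",
                equiv_tiles_big, upper_big);
    return 0;
}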
@@ -1091,8 +1087,7 @@ struct BlockToCTileMap_GemmStreamK
               "sk_tiles:%u, workspace(acc float):%u\n",
               num_cu,
               occupancy,
-               // get_grid_dims(num_cu, occupancy).x,
-               get_grid_dims().x,
+               get_grid_dims(num_cu, occupancy).x,
               num_tiles,
               dp_tiles,
               sk_num_big_blocks,
@@ -1124,15 +1119,9 @@ struct BlockToCTileMap_GemmStreamK
    }

-    // __host__ __device__ constexpr dim3 get_grid_dims(int num_cu, int occupancy) const
-    __host__ __device__ constexpr dim3 get_grid_dims() const
+    __host__ __device__ constexpr dim3 get_grid_dims(uint32_t num_cu, uint32_t occupancy) const
    {
-        if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
-        {
-            return dim3(reduction_start_block_idx + get_sk_tiles(), 1, 1);
-        }
-        else
-            return dim3(reduction_start_block_idx, 1, 1);
-        // return dim3(num_cu * occupancy, 1, 1); // HS
+        return dim3(num_cu * occupancy, 1, 1);
    }

    __host__ __device__ uint32_t total_blocks_allocated() const
    {
@@ -1240,13 +1229,13 @@ struct BlockToCTileMap_GemmStreamK
    }

    __host__ __device__ uint32_t get_tile_intersections(uint32_t tiles_,
-                                                        const MDiv& eqav_tiles_) const
+                                                        const MDiv& equiv_tiles_) const
    {
-        uint32_t tile_idx_       = tiles_ == 0 ? 0 : (tiles_ - 1);
-        uint32_t max_eqav_tiles_ = eqav_tiles_.get() - 1;
+        uint32_t tile_idx_        = tiles_ == 0 ? 0 : (tiles_ - 1);
+        uint32_t max_equiv_tiles_ = equiv_tiles_.get() - 1;
        uint32_t quo_, rem_;
-        eqav_tiles_.divmod(tile_idx_, quo_, rem_);
-        return quo_ * max_eqav_tiles_ + rem_;
+        equiv_tiles_.divmod(tile_idx_, quo_, rem_);
+        return quo_ * max_equiv_tiles_ + rem_;
    }

    __host__ __device__ uint32_t get_tiles_cover_sk_block(uint32_t num_sk_blocks_,
@@ -1264,9 +1253,9 @@ struct BlockToCTileMap_GemmStreamK
            get_tiles_cover_sk_block(sk_num_blocks - sk_num_big_blocks, k_iters_per_big_block - 1);
        uint32_t total_intersec_big =
-            get_tile_intersections(tiles_cover_big_blocks, eqav_tiles_big);
+            get_tile_intersections(tiles_cover_big_blocks, equiv_tiles_big);
        uint32_t total_intersec_little =
-            get_tile_intersections(tiles_cover_little_blocks, eqav_tiles_little);
+            get_tile_intersections(tiles_cover_little_blocks, equiv_tiles_little);

        return sk_num_blocks + total_intersec_big + total_intersec_little;
    }
@@ -1281,7 +1270,7 @@ struct BlockToCTileMap_GemmStreamK
            uint32_t touched_sk_blocks =
                (tile_idx_ * k_iters_per_tile.get() + k_iters_per_big_block - 1) /
                k_iters_per_big_block;
-            uint32_t current_intersec = get_tile_intersections(tile_idx_, eqav_tiles_big);
+            uint32_t current_intersec = get_tile_intersections(tile_idx_, equiv_tiles_big);
            return touched_sk_blocks + current_intersec;
        }
        else
@@ -1292,7 +1281,7 @@ struct BlockToCTileMap_GemmStreamK
                (tile_idx_little_reverse * k_iters_per_tile.get() + iters_per_little_sk_block - 1) /
                iters_per_little_sk_block;
            uint32_t current_intersec =
-                get_tile_intersections(tile_idx_little_reverse, eqav_tiles_little);
+                get_tile_intersections(tile_idx_little_reverse, equiv_tiles_little);
            return get_total_acc_buffers() - (touched_sk_blocks + current_intersec);
        }
    }
@@ -1305,7 +1294,7 @@ struct BlockToCTileMap_GemmStreamK
        {
            uint32_t touched_tiles = k_iters_per_tile.div(block_idx_ * iters_per_big_sk_block +
                                                          k_iters_per_tile.get() - 1);
-            uint32_t current_intersec = get_tile_intersections(touched_tiles, eqav_tiles_big);
+            uint32_t current_intersec = get_tile_intersections(touched_tiles, equiv_tiles_big);
            return block_idx_ + current_intersec;
        }
        else
@@ -1313,7 +1302,7 @@ struct BlockToCTileMap_GemmStreamK
            uint32_t block_idx_little_reverse = sk_num_blocks - block_idx_;
            uint32_t touched_tiles =
                k_iters_per_tile.div(block_idx_little_reverse * iters_per_little_sk_block +
                                     k_iters_per_tile.get() - 1);
-            uint32_t current_intersec = get_tile_intersections(touched_tiles, eqav_tiles_little);
+            uint32_t current_intersec = get_tile_intersections(touched_tiles, equiv_tiles_little);
            return get_total_acc_buffers() - (block_idx_little_reverse + current_intersec);
        }
    }
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp
@@ -145,7 +145,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
        index_t StrideA;
        index_t StrideB;
        index_t StrideC;
-        index_t num_cu, occupancy;
        // stream-k arguments
        Block2CTileMap block_mapping;

        Argument(const FloatAB* p_a_grid_,
@@ -157,8 +156,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
                 index_t StrideA_,
                 index_t StrideB_,
                 index_t StrideC_,
-                 uint32_t num_cu_,
-                 uint32_t occupancy_,
                 uint32_t num_sk_blocks_)
            : p_a_grid(p_a_grid_),
              p_b_grid(p_b_grid_),
@@ -169,9 +166,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
              StrideA(StrideA_),
              StrideB(StrideB_),
              StrideC(StrideC_),
-              num_cu(num_cu_),
-              occupancy(occupancy_),
-              block_mapping(M, N, K, num_cu_, occupancy_, num_sk_blocks_)
+              block_mapping(M, N, K, num_sk_blocks_)
        {
        }
@@ -523,9 +518,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
        // gridwise GEMM pipeline
        const auto gridwise_gemm_pipeline = GridwiseGemmPipeline_v3();

-        uint32_t* p_semaphore =
-            reinterpret_cast<uint32_t*>(reinterpret_cast<char*>(p_workspace) +
-                                        block_mapping.get_workspace_size_for_acc(sizeof(FloatAcc)));
        // offset for last acc buffer of this block
        uint32_t block_acc_offset =
@@ -536,6 +528,33 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
        uint32_t total_iter_length;

+        // get nperblock, mperblock for reduction
+        constexpr auto cluster_length_reduce = GetClusterLengthReduction();
+        constexpr auto reduce_desc           = make_cluster_descriptor(cluster_length_reduce);
+        const auto reduce_thread_cluster_idx =
+            reduce_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
+        const auto thread_m_cluster_id = reduce_thread_cluster_idx[I0];
+        const auto thread_n_cluster_id = reduce_thread_cluster_idx[I1];
+
+        // calculate total Mreduce iterations for block
+        constexpr auto MReduceIters =
+            math::integer_divide_ceil(Number<MPerBlock>{}, cluster_length_reduce.At(I0));
+        // calculate total Nreduce iterations for block
+        constexpr auto NReduceIters = math::integer_divide_ceil(
+            Number<NPerBlock>{},
+            cluster_length_reduce.At(I1) * Number<CBlockTransferScalarPerVector_NWaveNPerXDL>{});
+
+        // thread buf LOAD descriptor
+        constexpr auto acc_thread_buf_load_desc = make_naive_tensor_descriptor_packed(
+            make_tuple(I1, Number<CBlockTransferScalarPerVector_NWaveNPerXDL>{}));
+        // thread buf STORE descriptor
+        constexpr auto acc_thread_buf_store_desc = make_naive_tensor_descriptor_packed(
+            make_tuple(I1, I1, I1, Number<CBlockTransferScalarPerVector_NWaveNPerXDL>{}));
+
+#pragma unroll
+        // stream-k: for new work for all the persistent blocks.
+        for(; block_idx < block_mapping.total_blocks_allocated(); block_idx += gridDim.x)
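The new loop makes the workgroups persistent: a fixed grid strides over all allocated work items (the stream-k pieces and, under the Reduction strategy, the reduction work after reduction_start_block_idx) instead of each block handling exactly one item, which is also why the later return becomes a continue. A minimal HIP sketch of that grid-stride pattern under assumed names (not CK code):

#include <hip/hip_runtime.h>

__global__ void persistent_blocks(float* out, unsigned int total_work_units)
{
    // each launched block strides over work units; gridDim.x can be much smaller than the total
    for(unsigned int work_idx = blockIdx.x; work_idx < total_work_units; work_idx += gridDim.x)
    {
        // ... process work unit `work_idx` with the whole block ...
        if(threadIdx.x == 0)
            out[work_idx] = static_cast<float>(work_idx);
    }
}

int main()
{
    const unsigned int total = 1000;
    float* d_out = nullptr;
    hipMalloc(&d_out, total * sizeof(float));
    // a small fixed grid (e.g. num_cu * occupancy blocks) covers all 1000 work units
    hipLaunchKernelGGL(persistent_blocks, dim3(64), dim3(256), 0, 0, d_out, total);
    hipDeviceSynchronize();
    hipFree(d_out);
    return 0;
}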
@@ -558,28 +577,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
            if constexpr(Block2CTileMap::ReductionStrategy == StreamKReductionStrategy::Reduction)
            {
                is_reduction_block = block_idx >= block_mapping.reduction_start_block_idx;
                if(is_reduction_block)
                {
-                    // descriptors
-                    constexpr auto cluster_length_reduce = GetClusterLengthReduction();
-                    constexpr auto reduce_desc = make_cluster_descriptor(cluster_length_reduce);
-                    const auto reduce_thread_cluster_idx = reduce_desc.CalculateBottomIndex(
-                        make_multi_index(get_thread_local_1d_id()));
-                    const auto thread_m_cluster_id = reduce_thread_cluster_idx[I0];
-                    const auto thread_n_cluster_id = reduce_thread_cluster_idx[I1];
-
-                    constexpr auto MReduceIters = math::integer_divide_ceil(
-                        Number<MPerBlock>{}, cluster_length_reduce.At(I0));
-                    constexpr auto NReduceIters = math::integer_divide_ceil(
-                        Number<NPerBlock>{},
-                        cluster_length_reduce.At(I1) *
-                            Number<CBlockTransferScalarPerVector_NWaveNPerXDL>{});
-
-                    constexpr auto acc_thread_buf_load_desc = make_naive_tensor_descriptor_packed(
-                        make_tuple(I1, Number<CBlockTransferScalarPerVector_NWaveNPerXDL>{}));
-                    constexpr auto acc_thread_buf_store_desc = make_naive_tensor_descriptor_packed(
-                        make_tuple(I1, I1, I1, Number<CBlockTransferScalarPerVector_NWaveNPerXDL>{}));
-
                    constexpr auto c_partial_acc_block_m_n = GetPartialAccBlockDescriptor();
@@ -622,8 +622,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
                    auto reduction_idx = block_idx - block_mapping.reduction_start_block_idx;
                    auto spatial_idx   = block_mapping.tile_to_spatial(reduction_idx, m, n);

-                    workgroup_barrier wg_barrier(p_semaphore);
-
                    uint32_t tile_acc_offset_start =
                        block_mapping.get_acc_buffer_offset_from_tile(reduction_idx);
                    uint32_t tile_acc_offset_end =
@@ -669,9 +667,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
                            CBlockTransferScalarPerVector_NWaveNPerXDL),
                        CElementwiseOperation{}};

-                    // block synchronization
-                    wg_barrier.wait_eq(reduction_idx, tile_acc_offset_end - tile_acc_offset_start);
-
#if 0
                if(threadIdx.x == 0) {
                    printf("bid:%d, rid:%d, os:%d,%d, spatial:%d,%d\n", static_cast<int>(block_idx),
@@ -750,9 +745,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
                            partial_acc_store_step_m);
                    }
                }
-                return;
+                continue;
            }
        }

        while(true)
        {
            uint32_t current_iter_length = __builtin_amdgcn_readfirstlane(
@@ -1131,17 +1127,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
                        mxdlperwave_forward_step);
                }
            });

-            if constexpr(Block2CTileMap::ReductionStrategy == StreamKReductionStrategy::Reduction)
-            {
-                if(is_sk_block)
-                {
-                    // increase the counter for this tile
-                    workgroup_barrier wg_barrier(p_semaphore);
-                    wg_barrier.inc(tile_idx);
-                }
-            }
-
        }

        // exit condition
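The deleted wg_barrier.inc / wait_eq pair was a per-tile counter handshake: each stream-k block bumped its tile's counter in the semaphore workspace after writing a partial accumulator, and the reduction block spun until the counter matched the number of partials before reducing them. A rough HIP sketch of that pattern with plain atomics, under hypothetical names (this is not CK's workgroup_barrier API):

#include <hip/hip_runtime.h>

// p_counters: one zero-initialized unsigned int per output tile, in global memory (hypothetical layout).
__device__ void tile_counter_inc(unsigned int* p_counters, unsigned int tile_idx)
{
    __threadfence(); // make this block's partial results visible to other blocks
    __syncthreads(); // every thread of the producing block has finished its stores
    if(threadIdx.x == 0)
        atomicAdd(&p_counters[tile_idx], 1u); // announce one more partial for this tile
}

__device__ void tile_counter_wait_eq(unsigned int* p_counters, unsigned int tile_idx,
                                     unsigned int expected)
{
    if(threadIdx.x == 0)
    {
        // spin until every contributing stream-k block has signalled this tile
        while(atomicAdd(&p_counters[tile_idx], 0u) != expected) {}
        __threadfence(); // order the counter read before the partial-accumulator loads
    }
    __syncthreads(); // release the rest of the reduction block
}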