gaoqiong / composable_kernel_ROCM · Commits · 5b1e2442

Commit 5b1e2442, authored Mar 22, 2024 by Harisankar Sadasivan

2-tile sk + DP with atomics for FP16

parent 2ae16e90

Showing 4 changed files with 697 additions and 722 deletions.
- example/01_gemm/README.md (+22, -0)
- include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp (+13, -19)
- include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp (+62, -123)
- include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp (+600, -580)
example/01_gemm/README.md
@@ -21,3 +21,25 @@ Warm up
Start running 5 times...
Perf: 1.19685 ms, 107.657 TFlops, 78.8501 GB/s
```
# Instructions for `example_gemm_xdl_streamk`

## Run `example_gemm_xdl_streamk`

```bash
# arg1: verification (0=no, 1=yes)
# arg2: initialization (0=no init, 1=integer value, 2=decimal value)
# arg3: time kernel (0=no, 1=yes)
# arg4 to 9: M (256x), N (128x), K (32x), StrideA, StrideB, StrideC
# arg10: NumSKBlocks (optional; defaults to DP GEMM)
bin/example_gemm_xdl_streamk 1 2 1 3840 4096 4096 4096 4096 4096 312
```

Result (MI250 @ 1700 MHz, 181 TFlops peak FP16 on 1 die):

```
a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1}
b_k_n: dim 2, lengths {4096, 4096}, strides {4096, 1}
c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
Recommended grid size :312
Perf: 1.21689 ms, 105.884 TFlops, 79.2748 GB/s, GemmXdlStreamK_RRR_B256_Vec8x2x8_128x128x4x8
```
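As a quick sanity check on the numbers above (not part of the commit), the reported TFlops follows directly from the GEMM FLOP count 2·M·N·K and the measured kernel time; a minimal C++ check:

```cpp
// Sanity check: derive the reported TFlops from M, N, K and the measured time.
#include <cstdio>

int main()
{
    const double M = 3840, N = 4096, K = 4096;
    const double ms = 1.21689; // measured kernel time from the log above

    const double flop   = 2.0 * M * N * K; // one multiply + one add per MAC
    const double tflops = flop / (ms * 1e-3) / 1e12;

    std::printf("%.3f TFlops\n", tflops); // prints ~105.884, matching the log
    return 0;
}
```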
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp
```diff
@@ -137,6 +137,8 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
                          "setting");
         }

+        // stream-k: calculate the number of blocks to be launched based on #CUs and #occupancy
+        // dim3 grid_dims = karg.block_mapping.get_grid_dims(karg.num_cu, karg.occupancy);
         dim3 grid_dims = karg.block_mapping.get_grid_dims();

         float ave_time = 0;
```
...
```diff
@@ -268,22 +270,19 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
                              AElementwiseOperation,
                              BElementwiseOperation,
                              CElementwiseOperation,
-                             uint32_t NumSKBlocks = 0xffffffff)
+                             uint32_t NumSKBlocks = 0)
     {
         const auto kernel = kernel_gemm_xdlops_streamk<GridwiseGemm>;
         int occupancy, num_cu;
-        hipError_t rtn;
-        rtn = hipOccupancyMaxActiveBlocksPerMultiprocessor(
-            &occupancy, kernel, BlockSize, GridwiseGemm::GetSharedMemoryNumberOfByte());
-        hip_check_error(rtn);
+        hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
+            &occupancy, kernel, BlockSize, 0));
         hipDeviceProp_t dev_prop;
         hipDevice_t dev;
-        rtn = hipGetDevice(&dev);
-        hip_check_error(rtn);
-        rtn = hipGetDeviceProperties(&dev_prop, dev);
-        hip_check_error(rtn);
+        hip_check_error(hipGetDevice(&dev));
+        hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
         num_cu = dev_prop.multiProcessorCount;
+        printf("Assuming full GPU availability, recommended stream-k grid size for tuning :%0d\n",
+               num_cu * occupancy);

         return Argument{p_a,
                         p_b,
```
...
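For context, the occupancy query used above can be reproduced in a standalone HIP host program. This is a minimal sketch, assuming a placeholder kernel (`dummy_kernel` is illustrative, not from the repo) in place of `kernel_gemm_xdlops_streamk<GridwiseGemm>`; it derives the same `num_cu * occupancy` grid-size recommendation:

```cpp
#include <hip/hip_runtime.h>
#include <cstdio>

__global__ void dummy_kernel() {} // stand-in for the real stream-k kernel

int main()
{
    constexpr int BlockSize = 256; // matches the example's B256 workgroup size

    // number of BlockSize-thread workgroups resident per CU for this kernel
    // (last argument is dynamic LDS per block; 0 here, as in the new code)
    int occupancy = 0;
    (void)hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, dummy_kernel, BlockSize, 0);

    hipDevice_t dev;
    hipDeviceProp_t dev_prop;
    (void)hipGetDevice(&dev);
    (void)hipGetDeviceProperties(&dev_prop, dev);
    const int num_cu = dev_prop.multiProcessorCount;

    // one workgroup per CU per occupancy slot fills the device exactly once
    std::printf("recommended stream-k grid size: %d\n", num_cu * occupancy);
    return 0;
}
```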
```diff
@@ -318,17 +317,12 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
     {
         const auto kernel = kernel_gemm_xdlops_streamk<GridwiseGemm>;
         int occupancy, num_cu;
-        hipError_t rtn;
-        rtn = hipOccupancyMaxActiveBlocksPerMultiprocessor(
-            &occupancy, kernel, BlockSize, GridwiseGemm::GetSharedMemoryNumberOfByte());
-        hip_check_error(rtn);
+        hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
+            &occupancy, kernel, BlockSize, 0));
         hipDeviceProp_t dev_prop;
         hipDevice_t dev;
-        rtn = hipGetDevice(&dev);
-        hip_check_error(rtn);
-        rtn = hipGetDeviceProperties(&dev_prop, dev);
-        hip_check_error(rtn);
+        hip_check_error(hipGetDevice(&dev));
+        hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
         num_cu = dev_prop.multiProcessorCount;

         return std::make_unique<Argument>(reinterpret_cast<const ADataType*>(p_a),
```
...
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
```diff
@@ -1010,142 +1010,69 @@ struct BlockToCTileMap_GemmStreamK
     MDiv eqav_tiles_big;    // for reduction
     MDiv eqav_tiles_little; // for reduction
     // MDiv tile_swizzle_sub_m_rem;

     //--------------------------------------
     // prefer construct on host
     BlockToCTileMap_GemmStreamK(uint32_t m,
                                 uint32_t n,
                                 uint32_t k,
                                 uint32_t num_cu,
                                 uint32_t occupancy,
-                                uint32_t sk_blocks = 0xffffffff)
+                                uint32_t sk_blocks = 0)
     {
         // total output tiles
         uint32_t num_tiles =
             math::integer_divide_ceil(m, MPerBlock) * math::integer_divide_ceil(n, NPerBlock);
         k_iters_per_tile = MDiv(math::integer_divide_ceil(k, KPerBlock));

-        // one cu can hold one wg at one time, from the whole chip's point of view
-        // if number of wg is same as num_cu, we call it 1 dispatch
-        // if number of wg is 2x num_cu, we call it 2 dispatches.
-        // one dispatch can deliver wg same as num_cu (full dispatch), or less than num_cu
-        // (partial dispatch)
-        uint32_t full_dispatches         = num_tiles / num_cu;
-        uint32_t full_dispatch_tiles     = full_dispatches * num_cu;
-        uint32_t partial_dispatche_tiles = num_tiles - full_dispatch_tiles;
-
-        uint32_t sk_occupancy = occupancy;
-        uint32_t dp_tiles     = full_dispatch_tiles;
-        uint32_t sk_tiles     = partial_dispatche_tiles;
-
-        if(full_dispatches < occupancy)
-        {
-            // in this case, we allocate all blocks as sk blocks
-            // sk_occupancy = occupancy - full_dispatches;
-            sk_occupancy = 1; // TODO: single occ seems better
-            dp_tiles     = full_dispatch_tiles;
-            sk_tiles     = partial_dispatche_tiles;
-        }
-        else if((occupancy > 1) && (full_dispatches % occupancy == occupancy - 1))
-        {
-            // e.g. occupancy = 2, full_dispatches = 3, 5, 7 ...
-            //      occupancy = 3, full_dispatches = 5, 8, 11 ...
-            //      occupancy = 4, full_dispatches = 7, 11 ...
-            sk_occupancy = 1; // left 1 slot for sk occupancy
-            dp_tiles     = full_dispatch_tiles;
-            sk_tiles     = partial_dispatche_tiles;
-        }
-        else
-        {
-            // otherwise, we reduce 1 dispatch from dp, together with the partial dispatch,
-            // to construct the sk dispatch
-            sk_occupancy = occupancy - ((full_dispatches - 1) % occupancy);
-            dp_tiles     = full_dispatch_tiles - num_cu;
-            sk_tiles     = partial_dispatche_tiles + num_cu;
-        }
-
-        // uint32_t dp_iters_per_block = k_iters_per_tile.get();
-        uint32_t sk_total_iters = k_iters_per_tile.get() * sk_tiles;
-        uint32_t dp_num_blocks  = 0;
-
-        {
-            uint32_t min_sk_tiles = (sk_tiles >= num_cu) ? num_cu : (sk_tiles + 1);
-            uint32_t max_sk_tiles =
-                (sk_tiles >= num_cu)
-                    ? num_cu * sk_occupancy
-                    : math::min(num_cu, sk_total_iters / min_k_iters_per_sk_block);
-
-            // if we use dp for the sk-blocks, how many iters do we need
-            uint32_t dp_for_sk_iters = k_iters_per_tile.get();
-
-            uint32_t best_sk_score = std::numeric_limits<int>::max();
-            // we need to find the smallest sk iters
-            for(uint32_t tentative_sk_blocks = min_sk_tiles; tentative_sk_blocks < max_sk_tiles;
-                tentative_sk_blocks++)
-            {
-                uint32_t tentative_sk_iters_per_block =
-                    (sk_total_iters + tentative_sk_blocks - 1) / tentative_sk_blocks;
-                uint32_t tentative_sk_iters = tentative_sk_iters_per_block;
-                uint32_t sk_blocks_per_tile = (tentative_sk_blocks + sk_tiles - 1) / sk_tiles;
-
-                // TODO: carefully adjust this parameter
-                // the more sk_blocks_per_tile, the worse the overhead
-                uint32_t cross_sk_blocks_overhead = sk_blocks_per_tile;
-                if(tentative_sk_blocks % sk_tiles != 0)
-                {
-                    // penalty for uneven divide
-                    cross_sk_blocks_overhead +=
-                        sk_blocks_per_tile * tentative_sk_iters_per_block / 50;
-                }
-
-                uint32_t tentative_sk_score = tentative_sk_iters + cross_sk_blocks_overhead;
-                if(tentative_sk_score < best_sk_score)
-                {
-                    best_sk_score = tentative_sk_score;
-                    sk_num_blocks = tentative_sk_blocks;
-                }
-            }
-
-            if(best_sk_score >= dp_for_sk_iters)
-            {
-                sk_num_blocks = 0;
-            }
-
-            // give a chance to control the number of sk blocks
-            sk_num_blocks = sk_blocks != 0xffffffff ? sk_blocks : sk_num_blocks;
-
-            if(sk_num_blocks == 0)
-            {
-                sk_num_big_blocks     = 0;
-                k_iters_per_big_block = 0;
-
-                dp_num_blocks      = num_tiles; // all tiles become dp blocks
-                dp_start_block_idx = 0;
-                sk_total_iters     = 0; // clear sk iters
-            }
-            else
-            {
-                // k_iters_per_sk_block is the floor of the avg iters each sk block covers.
-                // we need to decide how many iters for each sk block
-                // let m = k_iters_per_sk_block
-                // some of the sk blocks (little) will cover m iters, some (big) will cover m+1
-                // we have
-                //  1) l + b = sk_blocks
-                //  2) l * m + b * (m + 1) = sk_total_iters
-                //      => (l + b) * m + b = sk_total_iters
-                //      => sk_blocks * m + b = sk_total_iters
-                //      => b = sk_total_iters - m * sk_blocks
-                // NOTE: big could be zero
-                uint32_t k_iters_per_sk_block = sk_total_iters / sk_num_blocks;
-                sk_num_big_blocks     = sk_total_iters - k_iters_per_sk_block * sk_num_blocks;
-                k_iters_per_big_block = k_iters_per_sk_block + 1;
-
-                dp_num_blocks      = dp_tiles;
-                dp_start_block_idx = (sk_num_blocks + num_cu - 1) / num_cu * num_cu;
-            }
-        }
+        uint32_t dp_tiles, dp_num_blocks, sk_total_iters;
+
+        // default to regular DP GEMM if sk blocks == 0
+        sk_num_blocks = sk_blocks;
+        if(sk_num_blocks == 0 || sk_num_blocks == 0xFFFFFFFF)
+        {
+            sk_num_blocks = 0;
+            dp_tiles      = num_tiles;
+
+            sk_num_big_blocks     = 0;
+            k_iters_per_big_block = 0;
+
+            dp_num_blocks      = num_tiles; // all tiles become dp blocks
+            dp_start_block_idx = 0;
+            sk_total_iters     = 0; // clear sk iters
+        }
+        // 2-tile sk + DP GEMM
+        else
+        {
+            // grid size
+            uint32_t grid_size = occupancy * num_cu;
+            // check if there's enough work for DP + stream-k
+            bool bigEnough = num_tiles > grid_size;
+            // max of 2 sk tiles per block
+            uint32_t sk_tiles = bigEnough ? grid_size + num_tiles % grid_size : num_tiles;
+            // remaining tiles are DP tiles
+            dp_tiles = bigEnough ? (num_tiles - sk_tiles) : 0;
+
+            sk_total_iters = k_iters_per_tile.get() * sk_tiles;
+
+            // k_iters_per_sk_block is the floor of the avg iters each sk block covers.
+            // we need to decide how many iters for each sk block
+            // let m = k_iters_per_sk_block
+            // some of the sk blocks (little) will cover m iters, some (big) will cover m+1
+            // we have
+            //  1) l + b = sk_blocks
+            //  2) l * m + b * (m + 1) = sk_total_iters
+            //      => (l + b) * m + b = sk_total_iters
+            //      => sk_blocks * m + b = sk_total_iters
+            //      => b = sk_total_iters - m * sk_blocks
+            // NOTE: big could be zero
+            uint32_t k_iters_per_sk_block = sk_total_iters / sk_num_blocks;
+            sk_num_big_blocks     = sk_total_iters - k_iters_per_sk_block * sk_num_blocks;
+            k_iters_per_big_block = k_iters_per_sk_block + 1;
+
+            dp_num_blocks      = dp_tiles;
+            dp_start_block_idx = (sk_num_blocks + num_cu - 1) / num_cu * num_cu;
+        }

         n_tiles = MDiv2(math::integer_divide_ceil(n, NPerBlock));

         // using multiple blocks for parallel reduction
         reduction_start_block_idx = dp_start_block_idx + dp_num_blocks;

         if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
```
...
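To make the new 2-tile sk + DP split concrete, here is a small host-side worked example (a sketch, not repo code). The assumptions: MI250 single die with num_cu = 104 and occupancy = 3 (so grid_size = 312, matching "Recommended grid size :312" in the README above), and 128x128x32 tiles inferred from the instance name `GemmXdlStreamK_RRR_B256_Vec8x2x8_128x128x4x8`:

```cpp
#include <cstdint>
#include <cstdio>

int main()
{
    const uint32_t M = 3840, N = 4096, K = 4096;           // README example shape
    const uint32_t MPerBlock = 128, NPerBlock = 128, KPerBlock = 32; // assumed tile sizes
    const uint32_t num_cu = 104, occupancy = 3, sk_blocks = 312;     // assumed device config

    const uint32_t num_tiles =
        ((M + MPerBlock - 1) / MPerBlock) * ((N + NPerBlock - 1) / NPerBlock); // 30 * 32 = 960
    const uint32_t k_iters_per_tile = (K + KPerBlock - 1) / KPerBlock;         // 128

    const uint32_t grid_size = occupancy * num_cu;                             // 312
    const bool     bigEnough = num_tiles > grid_size;                          // true
    // sk blocks cover at most 2 tiles each: one full dispatch plus the remainder
    const uint32_t sk_tiles = bigEnough ? grid_size + num_tiles % grid_size : num_tiles; // 336
    const uint32_t dp_tiles = bigEnough ? num_tiles - sk_tiles : 0;                      // 624

    const uint32_t sk_total_iters       = k_iters_per_tile * sk_tiles; // 43008
    const uint32_t k_iters_per_sk_block = sk_total_iters / sk_blocks;  // m = 137 (little blocks)
    const uint32_t sk_num_big_blocks =
        sk_total_iters - k_iters_per_sk_block * sk_blocks;             // b = 264 (big blocks, m+1 iters)

    // check the big/little split: 264 * 138 + 48 * 137 == 43008 == sk_total_iters
    std::printf("sk_tiles=%u dp_tiles=%u big=%u little=%u\n",
                sk_tiles, dp_tiles, sk_num_big_blocks, sk_blocks - sk_num_big_blocks);
    return 0;
}
```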
```diff
@@ -1157,13 +1084,14 @@ struct BlockToCTileMap_GemmStreamK
         }
 #if 0
-        printf("cu:%d, occupancy:%d, grids:%d, num_tiles:%d, dp_tiles:%d, sk_num_big_blocks:%d, "
+        printf("cu:%d, occupancy:%d, grid_size:%d, num_tiles:%d, dp_tiles:%d, sk_num_big_blocks:%d, "
                "sk_num_blocks:%d, "
-               "sk_total_iters:%d, dp_start_block_idx:%d, dp_iters_per_block:%d, dp_num_blocks:%d, "
+               "sk_total_iters:%d, dp_start_block_idx:%d, dp_num_blocks:%d, "
                "k_iters_per_tile:%d, k_iters_per_big_block:%d, reduction_start_block_idx:%u, "
                "sk_tiles:%u, workspace(acc float):%u\n",
                num_cu,
                occupancy,
+               // get_grid_dims(num_cu, occupancy).x,
                get_grid_dims().x,
                num_tiles,
                dp_tiles,
```
...
```diff
@@ -1171,7 +1099,7 @@ struct BlockToCTileMap_GemmStreamK
                sk_num_blocks,
                sk_total_iters,
                dp_start_block_idx,
-               dp_iters_per_block,
                dp_num_blocks,
                k_iters_per_tile.get(),
                k_iters_per_big_block,
```
...
```diff
@@ -1195,7 +1123,8 @@ struct BlockToCTileMap_GemmStreamK
         return k_iters_per_tile.div(sk_total_iters);
     }

-    __host__ __device__ dim3 get_grid_dims() const
+    // __host__ __device__ constexpr dim3 get_grid_dims(int num_cu, int occupancy) const
+    __host__ __device__ constexpr dim3 get_grid_dims() const
     {
         if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
         {
```
...
```diff
@@ -1203,6 +1132,16 @@ struct BlockToCTileMap_GemmStreamK
         }
         else
             return dim3(reduction_start_block_idx, 1, 1);
+        // return dim3(num_cu * occupancy, 1, 1); // HS
     }

+    __host__ __device__ uint32_t total_blocks_allocated() const
+    {
+        if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
+        {
+            return __builtin_amdgcn_readfirstlane(reduction_start_block_idx + get_sk_tiles());
+        }
+        else
+            return __builtin_amdgcn_readfirstlane(reduction_start_block_idx);
+    }
+
     __device__ uint32_t get_block_idx() const
```
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp

This diff is collapsed.