gaoqiong / composable_kernel_ROCM · Commits

Commit 5b1e2442, authored Mar 22, 2024 by Harisankar Sadasivan
Parent: 2ae16e90

    2-tile sk+ DP with atomics for FP16

Showing 4 changed files with 697 additions and 722 deletions.
Changed files:

  example/01_gemm/README.md                                                  (+22, -0)
  include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp    (+13, -19)
  include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp                (+62, -123)
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp      (+600, -580)
example/01_gemm/README.md

@@ -21,3 +21,25 @@ Warm up
Context (tail of the existing example output):

Start running 5 times...
Perf: 1.19685 ms, 107.657 TFlops, 78.8501 GB/s

New section appended to the README:

# Instructions for ```example_gemm_xdl_streamk```

## Run ```example_gemm_xdl_streamk```
```bash
# arg1: verification (0=no, 1=yes)
# arg2: initialization (0=no init, 1=integer value, 2=decimal value)
# arg3: time kernel (0=no, 1=yes)
# arg4 to 9: M (256x), N (128x), K (32x), StrideA, StrideB, StrideC
# arg10: NumSKBlocks (optional; defaults to DP GEMM)
bin/example_gemm_xdl_streamk 1 2 1 3840 4096 4096 4096 4096 4096 312
```
Result (MI250 @ 1700 MHz, 181 TFlops peak FP16 on 1 die)
```
a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1}
b_k_n: dim 2, lengths {4096, 4096}, strides {4096, 1}
c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
Recommended grid size :312
Perf: 1.21689 ms, 105.884 TFlops, 79.2748 GB/s, GemmXdlStreamK_RRR_B256_Vec8x2x8_128x128x4x8
```
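The "Recommended grid size: 312" printed above is the CU count times the kernel's per-CU occupancy, which is exactly what the device-level change below prints for tuning. As a minimal standalone sketch of the same query (not part of the commit): the kernel below is a trivial stand-in for the real `kernel_gemm_xdlops_streamk` instantiation, and the 256-thread block size is assumed from the instance name in the README output.

```cpp
#include <hip/hip_runtime.h>
#include <cstdio>

// Stand-in kernel; the library queries occupancy for kernel_gemm_xdlops_streamk<GridwiseGemm>.
__global__ void dummy_kernel() {}

int main()
{
    constexpr int BLOCK_SIZE = 256; // assumed, matches the B256 instance name above

    int dev = 0;
    (void)hipGetDevice(&dev);

    hipDeviceProp_t prop;
    (void)hipGetDeviceProperties(&prop, dev);

    int occupancy = 0;
    (void)hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, dummy_kernel, BLOCK_SIZE, 0);

    // One workgroup per CU per occupancy slot; 104 CUs on one MI250 GCD times an occupancy of 3
    // is consistent with the 312 reported in the example output.
    std::printf("recommended stream-k grid size: %d\n", prop.multiProcessorCount * occupancy);
    return 0;
}
```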
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp

@@ -137,6 +137,8 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
Two comment lines are added above the (unchanged) grid-size computation:

                "setting");
        }
        // stream-k: calculate the number of blocks to be launched based on #CUs and #occupancy
        // dim3 grid_dims = karg.block_mapping.get_grid_dims(karg.num_cu, karg.occupancy);
        dim3 grid_dims = karg.block_mapping.get_grid_dims();

        float ave_time = 0;

@@ -268,22 +270,19 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
In the overload that returns Argument{...}, the NumSKBlocks default changes from 0xffffffff to 0,
the occupancy query passes 0 bytes of dynamic LDS instead of
GridwiseGemm::GetSharedMemoryNumberOfByte(), the intermediate hipError_t rtn variable is dropped
in favor of wrapping each HIP call directly in hip_check_error, and the recommended stream-k grid
size is printed:

                          AElementwiseOperation,
                          BElementwiseOperation,
                          CElementwiseOperation,
                          uint32_t NumSKBlocks = 0)
    {
        const auto kernel = kernel_gemm_xdlops_streamk<GridwiseGemm>;
        int occupancy, num_cu;
        hip_check_error(
            hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));
        hipDeviceProp_t dev_prop;
        hipDevice_t dev;
        hip_check_error(hipGetDevice(&dev));
        hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
        num_cu = dev_prop.multiProcessorCount;
        printf("Assuming full GPU availability, recommended stream-k grid size for tuning :%0d\n",
               num_cu * occupancy);
        return Argument{p_a,
                        p_b,
                        ...

@@ -318,17 +317,12 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
The overload that returns std::make_unique<Argument>(...) gets the same cleanup of the occupancy
and device-property queries:

    {
        const auto kernel = kernel_gemm_xdlops_streamk<GridwiseGemm>;
        int occupancy, num_cu;
        hip_check_error(
            hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));
        hipDeviceProp_t dev_prop;
        hipDevice_t dev;
        hip_check_error(hipGetDevice(&dev));
        hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
        num_cu = dev_prop.multiProcessorCount;
        return std::make_unique<Argument>(reinterpret_cast<const ADataType*>(p_a),
                                          ...
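The refactor above funnels every HIP runtime call through hip_check_error. CK ships its own hip_check_error utility; purely as an illustration of that wrapping pattern (not CK's actual implementation), a minimal equivalent could look like the following, with CHECK_HIP and check_hip as hypothetical names:

```cpp
#include <hip/hip_runtime.h>
#include <cstdio>
#include <stdexcept>

// Illustrative stand-in for CK's hip_check_error: abort loudly on any non-success HIP status.
inline void check_hip(hipError_t err, const char* file, int line)
{
    if(err != hipSuccess)
    {
        std::fprintf(stderr, "HIP error %s at %s:%d\n", hipGetErrorString(err), file, line);
        throw std::runtime_error("HIP call failed");
    }
}
#define CHECK_HIP(call) check_hip((call), __FILE__, __LINE__)

int main()
{
    int dev = 0;
    hipDeviceProp_t prop;
    CHECK_HIP(hipGetDevice(&dev));                 // same wrapping style as the new code above
    CHECK_HIP(hipGetDeviceProperties(&prop, dev));
    std::printf("CUs: %d\n", prop.multiProcessorCount);
    return 0;
}
```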
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp

@@ -1010,142 +1010,69 @@ struct BlockToCTileMap_GemmStreamK
The host-side constructor is rewritten. The previous heuristic derived sk_tiles from full and
partial dispatches of num_cu workgroups, then searched tentative_sk_blocks in [min_sk_tiles,
max_sk_tiles) for the split with the lowest estimated score (penalizing uneven divides and
cross-sk-block overhead), falling back to pure DP when the best score was no better than a plain
DP pass over a tile. The new version keeps at most two stream-k tiles per block and dispatches
the remaining tiles as DP blocks; sk_blocks now defaults to 0, which selects a regular DP GEMM:

    MDiv eqav_tiles_big;    // for reduction
    MDiv eqav_tiles_little; // for reduction

    // prefer construct on host
    BlockToCTileMap_GemmStreamK(uint32_t m,
                                uint32_t n,
                                uint32_t k,
                                uint32_t num_cu,
                                uint32_t occupancy,
                                uint32_t sk_blocks = 0)
    {
        // total output tiles
        uint32_t num_tiles =
            math::integer_divide_ceil(m, MPerBlock) * math::integer_divide_ceil(n, NPerBlock);
        k_iters_per_tile = MDiv(math::integer_divide_ceil(k, KPerBlock));

        uint32_t dp_tiles, dp_num_blocks, sk_total_iters;

        // default to regular DP GEMM if sk blocks == 0
        sk_num_blocks = sk_blocks;
        if(sk_num_blocks == 0 || sk_num_blocks == 0xFFFFFFFF)
        {
            sk_num_blocks         = 0;
            dp_tiles              = num_tiles;
            sk_num_big_blocks     = 0;
            k_iters_per_big_block = 0;

            dp_num_blocks      = num_tiles; // all tile to be dp block
            dp_start_block_idx = 0;
            sk_total_iters     = 0; // clear this tiles
        }
        // 2-tile sk + DP GEMM
        else
        {
            // grid size
            uint32_t grid_size = occupancy * num_cu;
            // check if there's enough work for DP + stream-k
            bool bigEnough = num_tiles > grid_size;
            // max of 2 sk tiles per block
            uint32_t sk_tiles = bigEnough ? grid_size + num_tiles % grid_size : num_tiles;
            // remaining tiles are DP tiles
            dp_tiles = bigEnough ? (num_tiles - sk_tiles) : 0;

            sk_total_iters = k_iters_per_tile.get() * sk_tiles;

            // k_iters_per_sk_block is the floor of avg each ck block loop over tiles.
            // we need to decide how many iters for each sk block
            // let m = k_iters_per_sk_block
            // some of the sk block (little) will cover m iters, some (big) will cover m+1
            // we have
            // 1) l + b = sk_blocks
            // 2) l * m + b * (m + 1) = sk_total_iters
            //    => (l + b) * m + b = sk_total_iters
            //    => sk_blocks * m + b = sk_total_iters
            //    => b = sk_total_iters - m * sk_blocks
            // NOTE: big could be zero
            uint32_t k_iters_per_sk_block = sk_total_iters / sk_num_blocks;
            sk_num_big_blocks     = sk_total_iters - k_iters_per_sk_block * sk_num_blocks;
            k_iters_per_big_block = k_iters_per_sk_block + 1;

            dp_num_blocks      = dp_tiles;
            dp_start_block_idx = (sk_num_blocks + num_cu - 1) / num_cu * num_cu;
        }

        n_tiles                   = MDiv2(math::integer_divide_ceil(n, NPerBlock));
        reduction_start_block_idx = dp_start_block_idx + dp_num_blocks;

        if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
        ...

@@ -1157,13 +1084,14 @@ and @@ -1171,7 +1099,7 @@ struct BlockToCTileMap_GemmStreamK
Two small hunks update the #if 0 debug printf: the grid-size label and the dp_iters_per_block
field/argument are adjusted, and the grid size is read from get_grid_dims().x (the
get_grid_dims(num_cu, occupancy).x variant is left only as a comment).

@@ -1195,7 +1123,8 @@ and @@ -1203,6 +1132,16 @@ struct BlockToCTileMap_GemmStreamK
get_grid_dims() keeps its parameterless form (the num_cu/occupancy signature appears only as a
comment), and a total_blocks_allocated() helper is added for the persistent-block loop in the
gridwise kernel:

        return k_iters_per_tile.div(sk_total_iters);
    }

    // __host__ __device__ constexpr dim3 get_grid_dims(int num_cu, int occupancy) const
    __host__ __device__ dim3 get_grid_dims() const
    {
        if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
        {
            ...
        }
        else
            return dim3(reduction_start_block_idx, 1, 1);
        // return dim3(num_cu * occupancy, 1, 1); // HS
    }

    __host__ __device__ uint32_t total_blocks_allocated() const
    {
        if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
        {
            return __builtin_amdgcn_readfirstlane(reduction_start_block_idx + get_sk_tiles());
        }
        else
            return __builtin_amdgcn_readfirstlane(reduction_start_block_idx);
    }

    __device__ uint32_t get_block_idx() const
    ...
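The big/little split in the constructor above follows directly from the two equations in its comment: with l little blocks doing m iterations and b big blocks doing m + 1, l + b = sk_blocks and l·m + b·(m + 1) = sk_total_iters give b = sk_total_iters − m·sk_blocks. A small host-side sketch of that arithmetic with made-up numbers (the inputs below are illustrative, not taken from the commit):

```cpp
#include <cstdint>
#include <cstdio>

int main()
{
    // Illustrative inputs: 8 stream-k tiles of 32 K-iterations each, split across 12 sk blocks.
    uint32_t k_iters_per_tile = 32;
    uint32_t sk_tiles         = 8;
    uint32_t sk_blocks        = 12;

    uint32_t sk_total_iters = k_iters_per_tile * sk_tiles; // 256

    // l + b = sk_blocks and l*m + b*(m+1) = sk_total_iters  =>  b = sk_total_iters - m*sk_blocks
    uint32_t m = sk_total_iters / sk_blocks;     // 21 iterations per "little" block
    uint32_t b = sk_total_iters - m * sk_blocks; // 4 "big" blocks doing m+1 = 22 iterations
    uint32_t l = sk_blocks - b;                  // 8 "little" blocks

    std::printf("little=%u x %u iters, big=%u x %u iters, total=%u\n",
                l, m, b, m + 1, l * m + b * (m + 1)); // total == 256, every iteration covered
    return 0;
}
```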
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp

@@ -145,6 +145,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
The Argument struct gains the CU count and occupancy as members:

        index_t StrideA;
        index_t StrideB;
        index_t StrideC;
        index_t num_cu, occupancy; // stream-k arguments
        Block2CTileMap block_mapping;

        Argument(const FloatAB* p_a_grid_,

@@ -156,8 +157,8 @@ and @@ -168,7 +169,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
The constructor parameters are renamed num_cu_ / occupancy_ and forwarded both to the new members
and to the block mapping:

                 index_t StrideC_,
                 uint32_t num_cu_,
                 uint32_t occupancy_,
                 uint32_t num_sk_blocks_)
            : p_a_grid(p_a_grid_),
              p_b_grid(p_b_grid_),
              ...
              StrideA(StrideA_),
              StrideB(StrideB_),
              StrideC(StrideC_),
              num_cu(num_cu_),
              occupancy(occupancy_),
              block_mapping(M, N, K, num_cu_, occupancy_, num_sk_blocks_)
        {
        }

@@ -452,16 +455,16 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
The block index is now read once near the top of the kernel body, after the padded sizes and
strides are set up:

        uint32_t stride_a = StrideA;
        uint32_t stride_b = StrideB;
        uint32_t stride_c = StrideC;

        uint32_t block_idx = block_mapping.get_block_idx();

        const auto a_k0_m_k1_grid_desc = MakeAGridDescriptor_K0_M_K1(m, pad_m, k, pad_k, stride_a);
        const auto b_k0_n_k1_grid_desc = MakeBGridDescriptor_K0_N_K1(k, pad_k, n, pad_n, stride_b);
        const auto c_grid_desc_m_n     = MakeCGridDescriptor_M_N(m, pad_m, n, pad_n, stride_c);

@@ -520,623 +523,640 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
The kernel body is restructured into a persistent-block loop. Previously a block classified its
(fixed) index once as sk / dp / padding / reduction, handled padding with an early return, and ran
either the reduction path or a single while(true) main loop. Now each launched workgroup strides
over all allocated logical blocks, re-classifies itself for every logical block, skips padding
blocks with continue, and runs the reduction path or the stream-k/DP main loop per logical block.
The reduction path itself is unchanged apart from being nested in the loop and computing
reduction_idx from block_idx rather than blockIdx.x. Condensed, the new control flow is:

        // gridwise GEMM pipeline
        const auto gridwise_gemm_pipeline = GridwiseGemmPipeline_v3();

        uint32_t* p_semaphore =
            reinterpret_cast<uint32_t*>(reinterpret_cast<char*>(p_workspace) +
                                        block_mapping.get_workspace_size_for_acc(sizeof(FloatAcc)));

        uint32_t iter_start, iter_end;
        bool is_sk_block, is_dp_block, is_padding_block, is_reduction_block;
        uint32_t total_iter_length;

#pragma unroll
        // stream-k: for new work for all the persistent blocks.
        for(; block_idx < block_mapping.total_blocks_allocated(); block_idx += gridDim.x)
        {
            is_sk_block = block_idx < block_mapping.sk_num_blocks;
            is_dp_block = block_idx >= block_mapping.dp_start_block_idx &&
                          block_idx < block_mapping.reduction_start_block_idx;
            is_padding_block = block_idx >= block_mapping.sk_num_blocks &&
                               block_idx < block_mapping.dp_start_block_idx;
            if(is_padding_block)
            {
                continue;
            }
            block_mapping.get_block_itr(block_idx, iter_start, iter_end);
            total_iter_length = iter_end - iter_start;

            if constexpr(Block2CTileMap::ReductionStrategy == StreamKReductionStrategy::Reduction)
            {
                is_reduction_block = block_idx >= block_mapping.reduction_start_block_idx;
                if(is_reduction_block)
                {
                    // descriptors, threadwise acc_load / acc_store, the per-tile semaphore wait,
                    // and the cross-block accumulation over the partial-acc workspace (unchanged)
                    ...
                    return;
                }
            }

            // offset for last acc buffer of this block
            uint32_t block_acc_offset =
                (block_mapping.get_acc_buffer_offset_from_block(block_idx + 1) - 1) * MPerBlock *
                NPerBlock;

            while(true)
            {
                uint32_t current_iter_length = __builtin_amdgcn_readfirstlane(
                    block_mapping.get_current_iter_length(iter_start, iter_end, total_iter_length));
                uint32_t tile_idx, iter_offset;
                block_mapping.get_tile_idx_with_offset(iter_end - 1, tile_idx, iter_offset);
                iter_offset = __builtin_amdgcn_readfirstlane(iter_offset - current_iter_length + 1);
                auto spatial_idx = block_mapping.tile_to_spatial(tile_idx, m, n);

                // A/B blockwise copies and the gridwise GEMM pipeline over
                // num_k_block_main_loop = current_iter_length iterations (unchanged)
                ...

                // output: register to global memory via the LDS shuffle (unchanged):
                //   - DP blocks store with InMemoryDataOperationEnum::Set
                //   - SK blocks under the Reduction strategy store partial accumulators to the
                //     workspace and bump the per-tile semaphore (wg_barrier.inc(tile_idx))
                //   - SK blocks under the Atomic strategy store with
                //     InMemoryDataOperationEnum::AtomicAdd
                ...

                // exit condition
                iter_end -= current_iter_length;
                if(iter_end <= iter_start)
                    break;
                if constexpr(Block2CTileMap::ReductionStrategy ==
                             StreamKReductionStrategy::Reduction)
                {
                    block_acc_offset -= MPerBlock * NPerBlock;
                }
                // make sure next loop LDS is ready for use
                block_sync_lds();
            }
        }
    }

...
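The persistent-block pattern introduced above decouples the launched grid size from the number of logical work blocks: each workgroup strides through logical block indices by gridDim.x, and blocks that share an output tile combine results with atomics. A reduced HIP sketch of that dispatch shape, with a trivial per-block sum standing in for the GEMM main loop (not the kernel from the commit):

```cpp
#include <hip/hip_runtime.h>
#include <cstdio>
#include <vector>

// Each launched workgroup processes logical blocks bid, bid + gridDim.x, bid + 2*gridDim.x, ...
// mirroring the for(; block_idx < total_blocks_allocated(); block_idx += gridDim.x) loop above.
__global__ void persistent_sum(const float* in, float* out, uint32_t num_logical_blocks,
                               uint32_t elems_per_block)
{
    for(uint32_t bid = blockIdx.x; bid < num_logical_blocks; bid += gridDim.x)
    {
        const float* chunk = in + size_t(bid) * elems_per_block;
        float partial      = 0.f;
        for(uint32_t i = threadIdx.x; i < elems_per_block; i += blockDim.x)
            partial += chunk[i];
        // Stand-in for the SK-block epilogue: partial results from workgroups touching the same
        // output are combined with atomics, as in the Atomic strategy of the commit.
        atomicAdd(out, partial);
    }
}

int main()
{
    const uint32_t num_logical_blocks = 64, elems_per_block = 1024;
    const size_t n = size_t(num_logical_blocks) * elems_per_block;

    std::vector<float> h_in(n, 1.f);
    float *d_in = nullptr, *d_out = nullptr;
    hipMalloc(&d_in, n * sizeof(float));
    hipMalloc(&d_out, sizeof(float));
    hipMemcpy(d_in, h_in.data(), n * sizeof(float), hipMemcpyHostToDevice);
    hipMemset(d_out, 0, sizeof(float));

    // Launch fewer workgroups than logical blocks; the kernel loops to cover the rest.
    hipLaunchKernelGGL(persistent_sum, dim3(8), dim3(256), 0, 0,
                       d_in, d_out, num_logical_blocks, elems_per_block);

    float h_out = 0.f;
    hipMemcpy(&h_out, d_out, sizeof(float), hipMemcpyDeviceToHost);
    std::printf("sum = %f (expected %zu)\n", h_out, n);

    hipFree(d_in);
    hipFree(d_out);
    return 0;
}
```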