gaoqiong / composable_kernel / Commits / ca8b5c79

Commit ca8b5c79 authored May 25, 2023 by carlushuang

    update reduction for streamk(not ready yet)

parent b2a49620

Showing 7 changed files with 670 additions and 59 deletions (+670 -59)
Changed files:

  include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1r2.hpp   +10  -0
  include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp               +37  -1
  include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp                           +176 -27
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp                 +323 -24
  include/ck/utility/magic_division.hpp                                                 +32  -7
  include/ck/utility/math.hpp                                                           +16  -0
  include/ck/utility/workgroup_barrier.hpp (new file)                                   +76  -0
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1r2.hpp

@@ -111,6 +111,16 @@ struct ThreadGroupTensorSliceTransfer_v6r1r2
         }
     }

+    __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
+    {
+        threadwise_transfer_.SetSrcSliceOrigin(src_desc, src_slice_origin_idx);
+    }
+
+    __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
+    {
+        threadwise_transfer_.SetDstSliceOrigin(dst_desc, dst_slice_origin_idx);
+    }
+
     private:
     static constexpr auto thread_cluster_desc_ =
         make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
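The two setters let a caller re-anchor an existing transfer object instead of constructing a new one per destination tile. The call pattern below is taken from the gridwise stream-k epilogue later in this commit (a fragment, not a stand-alone example; the descriptors and buffers are defined there):

    // Re-point the destination window of an existing transfer, then run it.
    c_block_copy_lds_to_partial_acc.SetDstSliceOrigin(
        c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
        make_tuple(mxdlperwave, I0, nxdlperwave, I0));
    c_block_copy_lds_to_partial_acc.template Run<decltype(c_block_buf),
                                                 decltype(c_block_buf),
                                                 InMemoryDataOperationEnum::Set>(
        c_block_desc_mblock_mpershuffle_nblock_npershuffle,
        c_block_buf,
        c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
        c_partial_acc_buf);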
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp

@@ -141,7 +141,21 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
            const auto kernel = kernel_gemm_xdlops_streamk<GridwiseGemm>;

            // TODO: remove clear buffer for streamk kernels
-           hipGetErrorString(
-               hipMemset(karg.p_c_grid, 0, karg.M * karg.N * sizeof(CDataType)));
+           if constexpr(GridwiseGemm::Block2CTileMap::ReductionStrategy ==
+                        StreamKReductionStrategy::Atomic)
+           {
+               hipGetErrorString(
+                   hipMemset(karg.p_c_grid, 0, karg.M * karg.N * sizeof(CDataType)));
+           }
+           else if constexpr(GridwiseGemm::Block2CTileMap::ReductionStrategy ==
+                             StreamKReductionStrategy::Reduction)
+           {
+               char* workspace_semaphore = reinterpret_cast<char*>(karg.p_workspace_);
+               workspace_semaphore =
+                   workspace_semaphore +
+                   karg.block_mapping.get_workspace_size_for_acc(sizeof(GridwiseGemm::FloatAcc));
+               hipGetErrorString(hipMemset(
+                   workspace_semaphore, 0, karg.block_mapping.get_workspace_size_for_semaphore()));
+           }

            ave_time = launch_and_time_kernel(stream_config,
                                              kernel,

@@ -151,6 +165,7 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
                                              karg.p_a_grid,
                                              karg.p_b_grid,
                                              karg.p_c_grid,
+                                             karg.p_workspace_,
                                              karg.M,
                                              karg.N,
                                              karg.K,

@@ -170,6 +185,27 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
        }
    };

+   size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
+   {
+       const Argument* p_arg = dynamic_cast<const Argument*>(pArg);
+       if constexpr(GridwiseGemm::Block2CTileMap::ReductionStrategy ==
+                    StreamKReductionStrategy::Reduction)
+       {
+           return p_arg->block_mapping.get_workspace_size(sizeof(GridwiseGemm::FloatAcc));
+       }
+       else
+       {
+           return 0;
+       }
+   }
+
+   void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override
+   {
+       Argument* pArg_     = dynamic_cast<Argument*>(pArg);
+       pArg_->p_workspace_ = p_workspace;
+   }
+
    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
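The two overrides above only describe the workspace contract; the sizes come from BlockToCTileMap_GemmStreamK further down. As a stand-alone illustration of that layout (plain C++ with hypothetical tile and buffer counts, not CK types): the workspace is a 128-byte-aligned partial-accumulation region followed by one uint32_t semaphore per stream-k tile, and only the semaphore region needs the hipMemset above, since the accumulation region is fully overwritten by the stream-k blocks.

    #include <cstdint>
    #include <cstdio>

    // Mirrors get_workspace_size_for_acc(): acc buffers, rounded up to 128-byte alignment.
    static uint32_t workspace_size_for_acc(uint32_t m_per_block, uint32_t n_per_block,
                                           uint32_t total_acc_buffers, uint32_t acc_element_bytes)
    {
        const uint32_t alignment = 128;
        uint32_t bytes = m_per_block * n_per_block * total_acc_buffers * acc_element_bytes;
        return (bytes + alignment - 1) / alignment * alignment;
    }

    int main()
    {
        // Hypothetical numbers, for illustration only.
        uint32_t acc_bytes = workspace_size_for_acc(256, 128, 10, sizeof(float));
        uint32_t sem_bytes = 8 * sizeof(uint32_t); // one semaphore per sk tile
        std::printf("workspace = %u acc bytes + %u semaphore bytes (only the latter is memset)\n",
                    acc_bytes, sem_bytes);
        return 0;
    }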
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp

@@ -637,44 +637,53 @@ struct BlockToCTileMap_3DGrid_KSplit
     }
 };

+enum StreamKReductionStrategy
+{
+    Atomic = 0, // sk block use atomic to do reduction
+    Reduction,  // let some workgroup responsible for doing the reduction operation
+};
+
 template <uint32_t MPerBlock_,
           uint32_t NPerBlock_,
           uint32_t KPerBlock_,
+          StreamKReductionStrategy ReductionStrategy_ = StreamKReductionStrategy::Atomic,
           uint32_t TileSwizzleSubM_ = 8>
 struct BlockToCTileMap_GemmStreamK
 {
     static constexpr uint32_t min_k_iters_per_sk_block = 2;
     static constexpr uint32_t MPerBlock                = MPerBlock_;
     static constexpr uint32_t NPerBlock                = NPerBlock_;
     static constexpr uint32_t KPerBlock                = KPerBlock_;
+    static constexpr StreamKReductionStrategy ReductionStrategy = ReductionStrategy_;
     static constexpr uint32_t tile_swizzle_sub_m       = TileSwizzleSubM_;

     //--------------------------------------
     // pass to device
     uint32_t sk_num_blocks;
     uint32_t sk_num_big_blocks;
-    uint32_t sk_total_iters;
+    // uint32_t sk_total_iters;
     uint32_t dp_start_block_idx;
-    uint32_t dp_iters_per_block;
-    uint32_t dp_num_blocks;
+    // uint32_t dp_iters_per_block;
+    // uint32_t dp_num_blocks;
+    uint32_t reduction_start_block_idx;
     uint32_t k_iters_per_big_block;
+    // uint32_t tiles_cover_big_blocks; // for reduction
+    // uint32_t total_acc_buffers;      // for reduction
+    MDiv2 n_tiles;
     MDiv k_iters_per_tile;
-    MDiv n_tiles;
+    MDiv eqav_tiles_big;    // for reduction
+    MDiv eqav_tiles_little; // for reduction
     // MDiv tile_swizzle_sub_m_rem;

     //--------------------------------------

+    static int env_get_int(const char* var_name, int default_int)
+    {
+        char* v = getenv(var_name);
+        int r   = default_int;
+        if(v)
+            r = atoi(v);
+        return r;
+    }
+
     // prefer construct on host
     BlockToCTileMap_GemmStreamK(uint32_t m,
                                 uint32_t n,

@@ -727,8 +736,9 @@ struct BlockToCTileMap_GemmStreamK
             sk_tiles = partial_dispatche_tiles + num_cu;
         }

-        dp_iters_per_block = k_iters_per_tile.get();
-        sk_total_iters     = k_iters_per_tile.get() * sk_tiles;
+        uint32_t dp_iters_per_block = k_iters_per_tile.get();
+        uint32_t sk_total_iters     = k_iters_per_tile.get() * sk_tiles;
+        uint32_t dp_num_blocks      = 0;

         {
             uint32_t min_sk_tiles = (sk_tiles >= num_cu) ? num_cu : (sk_tiles + 1);

@@ -775,7 +785,6 @@ struct BlockToCTileMap_GemmStreamK
             // give a chance to control num of sk blocks
             sk_num_blocks = sk_blocks != 0xffffffff ? sk_blocks : sk_num_blocks;
-            sk_num_blocks = env_get_int("sk_num_blocks", sk_num_blocks);

             if(sk_num_blocks == 0)
             {

@@ -807,7 +816,16 @@ struct BlockToCTileMap_GemmStreamK
                 dp_start_block_idx = (sk_num_blocks + num_cu - 1) / num_cu * num_cu;
             }
         }
-        n_tiles = MDiv(math::integer_divide_ceil(n, NPerBlock));
+        n_tiles                   = MDiv2(math::integer_divide_ceil(n, NPerBlock));
+        reduction_start_block_idx = dp_start_block_idx + dp_num_blocks;
+
+        if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
+        {
+            uint32_t upper_big    = math::lcm(k_iters_per_big_block, k_iters_per_tile.get());
+            uint32_t upper_little = math::lcm(k_iters_per_big_block - 1, k_iters_per_tile.get());
+            eqav_tiles_big        = MDiv(upper_big / k_iters_per_tile.get());
+            eqav_tiles_little     = MDiv(upper_little / k_iters_per_tile.get());
+        }

         // tile_swizzle_sub_m_rem =
         //     MDiv(math::integer_divide_ceil(m, MPerBlock) % tile_swizzle_sub_m);

@@ -831,9 +849,28 @@ struct BlockToCTileMap_GemmStreamK
                                       k_iters_per_big_block);
     }

+    __host__ __device__ uint32_t get_sk_total_iters() const
+    {
+        uint32_t sk_total_iters =
+            sk_num_big_blocks * k_iters_per_big_block +
+            (sk_num_blocks - sk_num_big_blocks) * (k_iters_per_big_block - 1);
+        return sk_total_iters;
+    }
+
+    __host__ __device__ uint32_t get_sk_tiles() const
+    {
+        // tiles for sk
+        uint32_t sk_total_iters = get_sk_total_iters();
+        return k_iters_per_tile.div(sk_total_iters);
+    }
+
     __host__ __device__ dim3 get_grid_dims() const
     {
-        return dim3(dp_start_block_idx + dp_num_blocks, 1, 1);
+        if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
+        {
+            return dim3(reduction_start_block_idx + get_sk_tiles(), 1, 1);
+        }
+        else
+            return dim3(reduction_start_block_idx, 1, 1);
     }

     __device__ uint32_t get_block_idx() const

@@ -858,6 +895,8 @@ struct BlockToCTileMap_GemmStreamK
         }
         else if(block_idx >= dp_start_block_idx)
         {
+            uint32_t sk_total_iters     = get_sk_total_iters();
+            uint32_t dp_iters_per_block = k_iters_per_tile.get();
             iter_start = sk_total_iters + (block_idx - dp_start_block_idx) * dp_iters_per_block;
             iter_end   = iter_start + dp_iters_per_block;
         }

@@ -882,10 +921,11 @@ struct BlockToCTileMap_GemmStreamK
         k_iters_per_tile.divmod(iter, tile_idx, iter_offset);
     }

-    __device__ auto tile_to_spatial(uint32_t tile_idx, uint32_t m, uint32_t /*n*/) const
+    __device__ auto tile_to_spatial(uint32_t tile_idx, uint32_t m, uint32_t n) const
     {
         uint32_t m_tile_idx, n_tile_idx;
-        n_tiles.divmod(tile_idx, m_tile_idx, n_tile_idx);
+        uint32_t n_tiles_value = math::integer_divide_ceil(n, NPerBlock);
+        n_tiles.divmod(tile_idx, n_tiles_value, m_tile_idx, n_tile_idx);
         // return make_tuple(m_tile_idx, n_tile_idx);

         // swizzle tile

@@ -901,7 +941,7 @@ struct BlockToCTileMap_GemmStreamK
         m_tile_idx_sub0 = m_tile_idx / tile_swizzle_sub_m;
         m_tile_idx_sub1 = m_tile_idx % tile_swizzle_sub_m;
-        uint32_t tile_idx_local = n_tile_idx + m_tile_idx_sub1 * n_tiles.get();
+        uint32_t tile_idx_local = n_tile_idx + m_tile_idx_sub1 * n_tiles_value;

         uint32_t m_tile_idx_with_adapt, n_tile_idx_with_adapt;

@@ -911,6 +951,115 @@ struct BlockToCTileMap_GemmStreamK
         return make_tuple(m_tile_idx_with_adapt + m_tile_idx_sub0 * tile_swizzle_sub_m,
                           n_tile_idx_with_adapt);
     }

+    // __host__ __device__ uint32_t get_workspace_offset_for_semaphore() const
+    // {
+    //     // workspace contains 2 part, 1) partial reduction buffer 2) semaphore for cross-wg sync
+    //     // we let 1) start from offset:0, 2) start from the end of 1)
+    //     // NOTE: offset is in unit of byte
+    //     return get_total_acc_buffers() *
+    // }
+
+    __host__ __device__ uint32_t get_workspace_size_for_acc(uint32_t acc_element_bytes) const
+    {
+        static constexpr uint32_t alignment = 128;
+        uint32_t acc_buffer_bytes =
+            MPerBlock * NPerBlock * get_total_acc_buffers() * acc_element_bytes;
+        return (acc_buffer_bytes + alignment - 1) / alignment * alignment;
+    }
+
+    __host__ __device__ uint32_t get_workspace_size_for_semaphore() const
+    {
+        return get_sk_tiles() * sizeof(uint32_t);
+    }
+
+    __host__ __device__ uint32_t get_workspace_size(uint32_t acc_element_bytes) const
+    {
+        return get_workspace_size_for_acc(acc_element_bytes) + get_workspace_size_for_semaphore();
+    }
+
+    __device__ uint32_t get_tile_intersections(uint32_t tiles_, const MDiv& eqav_tiles_) const
+    {
+        uint32_t tile_idx_       = tiles_ == 0 ? 0 : (tiles_ - 1);
+        uint32_t max_eqav_tiles_ = eqav_tiles_.get() - 1;
+        uint32_t quo_, rem_;
+        eqav_tiles_.divmod(tile_idx_, quo_, rem_);
+        return quo_ * max_eqav_tiles_ + rem_;
+        // return tile_idx_ / eqav_tiles_ * max_eqav_tiles_ + (tile_idx_ % eqav_tiles_);
+    }
+
+    __host__ __device__ uint32_t get_tiles_cover_sk_block(uint32_t num_sk_blocks_,
+                                                          uint32_t iters_per_sk_block_) const
+    {
+        return k_iters_per_tile.div(num_sk_blocks_ * iters_per_sk_block_ +
+                                    k_iters_per_tile.get() - 1);
+    }
+
+    __host__ __device__ uint32_t get_total_acc_buffers() const
+    {
+        uint32_t tiles_cover_big_blocks =
+            get_tiles_cover_sk_block(sk_num_big_blocks, k_iters_per_big_block);
+        uint32_t tiles_cover_little_blocks =
+            get_tiles_cover_sk_block(sk_num_blocks - sk_num_big_blocks, k_iters_per_big_block - 1);
+
+        uint32_t total_intersec_big =
+            get_tile_intersections(tiles_cover_big_blocks, eqav_tiles_big);
+        uint32_t total_intersec_little =
+            get_tile_intersections(tiles_cover_little_blocks, eqav_tiles_little);
+
+        return sk_num_blocks + total_intersec_big + total_intersec_little;
+    }
+
+    __device__ uint32_t get_acc_buffer_offset_from_tile(uint32_t tile_idx_) const
+    {
+        // TODO: from big to little
+        uint32_t tiles_cover_big_blocks =
+            get_tiles_cover_sk_block(sk_num_big_blocks, k_iters_per_big_block);
+
+        if(tile_idx_ < tiles_cover_big_blocks)
+        {
+            uint32_t touched_sk_blocks =
+                (tile_idx_ * k_iters_per_tile.get() + k_iters_per_big_block - 1) /
+                k_iters_per_big_block;
+            uint32_t current_intersec = get_tile_intersections(tile_idx_, eqav_tiles_big);
+            return touched_sk_blocks + current_intersec;
+        }
+        else
+        {
+            uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1;
+            uint32_t tile_idx_little_reverse   = get_sk_tiles() - tile_idx_;
+            uint32_t touched_sk_blocks =
+                (tile_idx_little_reverse * k_iters_per_tile.get() + iters_per_little_sk_block - 1) /
+                iters_per_little_sk_block;
+            uint32_t current_intersec =
+                get_tile_intersections(tile_idx_little_reverse, eqav_tiles_little);
+            // printf("reverse tile:%u, %u/%u\n", tile_idx_little_reverse, touched_sk_blocks,
+            //        current_intersec);
+            return get_total_acc_buffers() - (touched_sk_blocks + current_intersec);
+        }
+    }
+
+    __device__ uint32_t get_acc_buffer_offset_from_block(uint32_t block_idx_) const
+    {
+        uint32_t iters_per_big_sk_block    = k_iters_per_big_block;
+        uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1;
+        if(block_idx_ < sk_num_big_blocks)
+        {
+            // uint32_t touched_tiles = (block_idx_ * iters_per_big_sk_block + iters - 1) / iters;
+            uint32_t touched_tiles = k_iters_per_tile.div(block_idx_ * iters_per_big_sk_block +
+                                                          k_iters_per_tile.get() - 1);
+            uint32_t current_intersec = get_tile_intersections(touched_tiles, eqav_tiles_big);
+            return block_idx_ + current_intersec;
+        }
+        else
+        {
+            uint32_t block_idx_little_reverse = sk_num_blocks - block_idx_;
+            // uint32_t touched_tiles = (block_idx_little_reverse * iters_per_little_sk_block +
+            //                           iters - 1) / iters;
+            uint32_t touched_tiles = k_iters_per_tile.div(
+                block_idx_little_reverse * iters_per_little_sk_block + k_iters_per_tile.get() - 1);
+            uint32_t current_intersec = get_tile_intersections(touched_tiles, eqav_tiles_little);
+            return get_total_acc_buffers() - (block_idx_little_reverse + current_intersec);
+        }
+    }
 };

 } // namespace ck
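For orientation, the sk/dp split that get_sk_total_iters() and get_block_itr() encode can be checked on the host with plain integers. A stand-alone sketch with hypothetical sizes (not CK code): "big" sk blocks take k_iters_per_big_block K-iterations each and the remaining sk blocks take one fewer, so the per-block iteration ranges tile [0, sk_total_iters) exactly, with data-parallel blocks picking up iterations from sk_total_iters onward.

    #include <cstdint>
    #include <cstdio>

    int main()
    {
        // Hypothetical configuration, for illustration only.
        uint32_t sk_num_blocks         = 6;
        uint32_t sk_num_big_blocks     = 2;
        uint32_t k_iters_per_big_block = 5;

        // Same formula as get_sk_total_iters().
        uint32_t sk_total_iters =
            sk_num_big_blocks * k_iters_per_big_block +
            (sk_num_blocks - sk_num_big_blocks) * (k_iters_per_big_block - 1);

        uint32_t iter = 0;
        for(uint32_t b = 0; b < sk_num_blocks; ++b)
        {
            uint32_t len = (b < sk_num_big_blocks) ? k_iters_per_big_block
                                                   : k_iters_per_big_block - 1;
            std::printf("sk block %u: iters [%u, %u)\n", b, iter, iter + len);
            iter += len;
        }
        std::printf("sk_total_iters = %u (walked %u)\n", sk_total_iters, iter);
        return 0;
    }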
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp

@@ -14,8 +14,9 @@
 #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1r2.hpp"
 #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v3.hpp"
+#include "ck/utility/workgroup_barrier.hpp"
+#include "ck/utility/reduction_functions_accumulate.hpp"

 namespace ck {

@@ -28,6 +29,7 @@ __global__ void
     kernel_gemm_xdlops_streamk(const typename GridwiseGemm::FloatAB* p_a_grid,
                                const typename GridwiseGemm::FloatAB* p_b_grid,
                                typename GridwiseGemm::FloatC* p_c_grid,
+                               void* p_workspace,
                                index_t M,
                                index_t N,
                                index_t K,

@@ -45,6 +47,7 @@ __global__ void
     GridwiseGemm::Run(p_a_grid,
                       p_b_grid,
                       p_c_grid,
+                      p_workspace,
                       M,
                       N,
                       K,

@@ -52,18 +55,26 @@ __global__ void
                       StrideB,
                       StrideC,
                       block_mapping,
                       static_cast<void*>(p_shared));
 #else
-    ignore = karg;
-    ignore = b2c_map;
+    ignore = p_a_grid;
+    ignore = p_b_grid;
+    ignore = p_c_grid;
+    ignore = p_workspace;
+    ignore = M;
+    ignore = N;
+    ignore = K;
+    ignore = StrideA;
+    ignore = StrideB;
+    ignore = StrideC;
+    ignore = block_mapping;
 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }

 template <index_t BlockSize,
           typename Block2CTileMap_,
           typename FloatAB_,
-          typename FloatAcc,
+          typename FloatAcc_,
           typename FloatC_,
           typename ALayout,
           typename BLayout,

@@ -117,6 +128,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
     static constexpr auto KPerBlock = K0PerBlock * K1;

     using ThisThreadBlock = ThisThreadBlock<BlockSize>;
+    using FloatAcc        = FloatAcc_;
     using FloatCShuffle   = FloatAcc;

     using Block2CTileMap = Block2CTileMap_;

@@ -292,7 +304,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
             math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);

         constexpr auto c_block_size =
-            GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize();
+            GetCBlockDescriptor_MBlock_MPerShuffle_NBlock_NPerShuffle().GetElementSpaceSize();

         return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
                              sizeof(FloatAB),

@@ -372,7 +384,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
     }

     __host__ __device__ static constexpr auto
-    GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
+    GetCBlockDescriptor_MBlock_MPerShuffle_NBlock_NPerShuffle()
     {
         constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL);
         constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL);

@@ -384,11 +396,54 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
                        Number<CShuffleNRepeatPerShuffle * NWave * NPerXDL>{}));
     }

+    __host__ __device__ static constexpr auto
+    GetCBlockDescriptor_MShuffleRepeat_MPerShuffle_NShuffleRepeat_NPerShuffle()
+    {
+        constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL);
+        constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL);
+
+        return make_naive_tensor_descriptor_packed(
+            make_tuple(Number<MRepeat / CShuffleMRepeatPerShuffle>{},
+                       Number<CShuffleMRepeatPerShuffle * MWave * MPerXDL>{},
+                       Number<NRepeat / CShuffleNRepeatPerShuffle>{},
+                       Number<CShuffleNRepeatPerShuffle * NWave * NPerXDL>{}));
+    }
+
+    __host__ __device__ static constexpr auto GetClusterLengthReduction()
+    {
+        // TODO: assume C is row major
+        // TODO: we always first loop over N, then M
+        constexpr auto NPerBlockPow2 = math::next_power_of_two<NPerBlock>();
+        constexpr auto NPerBlockReduction =
+            NPerBlockPow2 / CBlockTransferScalarPerVector_NWaveNPerXDL;
+        constexpr auto MPerBlockReduction =
+            (BlockSize + NPerBlockReduction - 1) / NPerBlockReduction;
+        return Sequence<MPerBlockReduction, NPerBlockReduction>{};
+    }
+
+    __host__ __device__ static constexpr auto GetPartialAccBlockDescriptor()
+    {
+        const auto c_partial_acc_block_m_n = [&]() {
+            if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
+            {
+                return make_naive_tensor_descriptor(make_tuple(MPerBlock, NPerBlock),
+                                                    make_tuple(NPerBlock, I1));
+            }
+            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
+            {
+                return make_naive_tensor_descriptor(make_tuple(MPerBlock, NPerBlock),
+                                                    make_tuple(I1, MPerBlock));
+            }
+        }();
+        return c_partial_acc_block_m_n;
+    }
+
     using CGridDesc_M_N = remove_cvref_t<decltype(MakeCGridDescriptor_M_N(1, 1, 1, 1, 1))>;

     __device__ static void Run(const FloatAB* p_a_grid,
                                const FloatAB* p_b_grid,
                                FloatC* p_c_grid,
+                               void* p_workspace,
                                index_t M,
                                index_t N,
                                index_t K,

@@ -425,6 +480,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
         auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

+        // ignore = p_workspace; // TODO: for reduction
+
         // lds max alignment
         constexpr auto max_lds_align = K1;

@@ -468,16 +525,187 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
         uint32_t block_idx = block_mapping.get_block_idx();
         bool is_sk_block   = block_idx < block_mapping.sk_num_blocks;
-        bool is_dp_block   = block_idx >= block_mapping.dp_start_block_idx;
+        bool is_dp_block   = block_idx >= block_mapping.dp_start_block_idx &&
+                             block_idx < block_mapping.reduction_start_block_idx;
+        bool is_reduction_block = block_idx >= block_mapping.reduction_start_block_idx;
+        bool is_padding_block   = block_idx >= block_mapping.sk_num_blocks &&
+                                  block_idx < block_mapping.dp_start_block_idx;

         uint32_t iter_start, iter_end;
         block_mapping.get_block_itr(block_idx, iter_start, iter_end);
         uint32_t total_iter_length = iter_end - iter_start;

         // if(threadIdx.x == 0)
         //     printf("xxx bid:%d, is_sk_block:%d, is_dp_block:%d\n", static_cast<int>(blockIdx.x),
         //            is_sk_block, is_dp_block);

-        if(!is_sk_block && !is_dp_block)
+        if(is_padding_block)
             return;

+        uint32_t* p_semaphore = reinterpret_cast<uint32_t*>(
+            reinterpret_cast<char*>(p_workspace) +
+            block_mapping.get_workspace_size_for_acc(sizeof(FloatAcc)));
+
+        if constexpr(Block2CTileMap::ReductionStrategy == StreamKReductionStrategy::Reduction)
+        {
+            if(is_reduction_block)
+            {
+                // descriptors
+                constexpr auto cluster_length_reduce = GetClusterLengthReduction();
+                constexpr auto MReduceIters =
+                    math::integer_divide_ceil(Number<MPerBlock>{}, cluster_length_reduce.At(I0));
+                constexpr auto NReduceIters = math::integer_divide_ceil(
+                    Number<NPerBlock>{},
+                    cluster_length_reduce.At(I1) * CBlockTransferScalarPerVector_NWaveNPerXDL);
+                constexpr auto acc_thread_buf_desc = make_naive_tensor_descriptor_packed(
+                    make_tuple(Number<CBlockTransferScalarPerVector_NWaveNPerXDL>{}));
+                constexpr auto c_partial_acc_block_m_n = GetPartialAccBlockDescriptor();
+
+                constexpr auto partial_acc_load_step_n = make_multi_index(
+                    0, cluster_length_reduce.At(I1) * CBlockTransferScalarPerVector_NWaveNPerXDL);
+                constexpr auto partial_acc_load_step_n_reverse = make_multi_index(
+                    0,
+                    -1 * (MReduceIters - 1) * cluster_length_reduce.At(I1) *
+                        CBlockTransferScalarPerVector_NWaveNPerXDL);
+                constexpr auto partial_acc_load_step_m =
+                    make_multi_index(cluster_length_reduce.At(I0), 0);
+
+                constexpr auto partial_acc_store_step_n = make_multi_index(
+                    0,
+                    0,
+                    0,
+                    cluster_length_reduce.At(I1) * CBlockTransferScalarPerVector_NWaveNPerXDL);
+                constexpr auto partial_acc_store_step_n_reverse = make_multi_index(
+                    0,
+                    0,
+                    0,
+                    -1 * (MReduceIters - 1) * cluster_length_reduce.At(I1) *
+                        CBlockTransferScalarPerVector_NWaveNPerXDL);
+                constexpr auto partial_acc_store_step_m =
+                    make_multi_index(0, cluster_length_reduce.At(I0), 0, 0);
+
+                StaticBuffer<AddressSpaceEnum::Vgpr,
+                             FloatAcc,
+                             CBlockTransferScalarPerVector_NWaveNPerXDL,
+                             true>
+                    parcial_acc_buf;
+                StaticBuffer<AddressSpaceEnum::Vgpr,
+                             FloatAcc,
+                             CBlockTransferScalarPerVector_NWaveNPerXDL,
+                             true>
+                    acc_buf;
+                acc_buf.Clear();
+
+                // start to compute
+                auto reduction_idx = blockIdx.x - block_mapping.reduction_start_block_idx;
+                auto spatial_idx   = block_mapping.tile_to_spatial(reduction_idx, m, n);
+
+                workgroup_barrier wg_barrier(p_semaphore);
+
+                uint32_t tile_acc_offset_start =
+                    block_mapping.get_acc_buffer_offset_from_tile(reduction_idx);
+                uint32_t tile_acc_offset_end =
+                    block_mapping.get_acc_buffer_offset_from_tile(reduction_idx + 1);
+
+                auto acc_load = ThreadwiseTensorSliceTransfer_v2<
+                    FloatAcc,                                             // SrcData,
+                    FloatAcc,                                             // DstData,
+                    decltype(c_partial_acc_block_m_n),                    // SrcDesc,
+                    decltype(acc_thread_buf_desc),                        // DstDesc,
+                    Sequence<CBlockTransferScalarPerVector_NWaveNPerXDL>, // SliceLengths,
+                    Sequence<I0>,                                         // DimAccessOrder,
+                    2,                                                    // SrcVectorDim,
+                    CBlockTransferScalarPerVector_NWaveNPerXDL,           // SrcScalarPerVector,
+                    1,                                                    // SrcScalarStrideInVector,
+                    false                                                 // SrcResetCoordinateAfterRun,
+                    >{c_partial_acc_block_m_n,
+                      make_multi_index(static_cast<index_t>(tile_acc_offset_start), I0, I0)};
+
+                auto acc_store = ThreadwiseTensorSliceTransfer_v1r3<
+                    FloatAcc,                                                // SrcData,
+                    FloatC,                                                  // DstData,
+                    decltype(acc_thread_buf_desc),                           // SrcDesc,
+                    decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), // DstDesc,
+                    CElementwiseOperation,                                   // ElementwiseOperation,
+                    Sequence<CBlockTransferScalarPerVector_NWaveNPerXDL>,    // SliceLengths,
+                    Sequence<I0>,                                            // DimAccessOrder,
+                    2,                                                       // DstVectorDim,
+                    CBlockTransferScalarPerVector_NWaveNPerXDL,              // DstScalarPerVector,
+                    InMemoryDataOperationEnum::Set, // InMemoryDataOperationEnum DstInMemOp,
+                    1,                              // DstScalarStrideInVector,
+                    false                           // DstResetCoordinateAfterRun,
+                    >{c_grid_desc_mblock_mperblock_nblock_nperblock,
+                      make_multi_index(spatial_idx[I0], I0, spatial_idx[I1], I0),
+                      CElementwiseOperation{}};
+
+                // block synchronization
+                wg_barrier.wait_eq(reduction_idx, tile_acc_offset_end - tile_acc_offset_start);
+
+                using Accumulation = ck::detail::
+                    AccumulateWithNanCheck<false /*PropagateNan*/, reduce::Add, FloatAcc>;
+
+                static_for<0, MReduceIters, 1>{}([&](auto i_m_reduce) {
+                    static_for<0, NReduceIters, 1>{}([&](auto i_n_reduce) {
+                        for(auto i = tile_acc_offset_start; i < tile_acc_offset_end; i++)
+                        {
+                            auto c_partial_acc_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+                                static_cast<FloatAcc*>(p_workspace) + i,
+                                c_partial_acc_block_m_n.GetElementSpaceSize());
+
+                            acc_load.Run(c_partial_acc_block_m_n,
+                                         c_partial_acc_buf,
+                                         acc_thread_buf_desc,
+                                         make_multi_index(I0),
+                                         parcial_acc_buf);
+
+                            static_for<0, CBlockTransferScalarPerVector_NWaveNPerXDL, 1>{}(
+                                [&](auto i_vec) {
+                                    constexpr auto offset =
+                                        acc_thread_buf_desc.CalculateOffset(make_tuple(i_vec));
+                                    Accumulation::Calculate(acc_buf(Number<offset>{}),
+                                                            parcial_acc_buf[Number<offset>{}]);
+                                });
+                        }
+
+                        acc_store.Run(acc_thread_buf_desc,
+                                      make_multi_index(I0),
+                                      acc_buf,
+                                      c_grid_desc_mblock_mperblock_nblock_nperblock,
+                                      c_grid_buf);
+
+                        if constexpr(i_n_reduce != (NReduceIters - 1))
+                        {
+                            acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
+                                                        partial_acc_load_step_n);
+                            acc_store.MoveDstSliceWindow(
+                                c_grid_desc_mblock_mperblock_nblock_nperblock,
+                                partial_acc_load_step_n);
+                        }
+                        else
+                        {
+                            acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
+                                                        partial_acc_load_step_n_reverse);
+                            acc_store.MoveDstSliceWindow(
+                                c_grid_desc_mblock_mperblock_nblock_nperblock,
+                                partial_acc_store_step_n_reverse);
+                        }
+                    });
+                    if constexpr(i_m_reduce != MReduceIters - 1)
+                    {
+                        acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
+                                                    partial_acc_load_step_m);
+                        acc_store.MoveDstSliceWindow(c_grid_desc_mblock_mperblock_nblock_nperblock,
+                                                     partial_acc_store_step_m);
+                    }
+                });
+                return;
+            }
+        }
+
+        // offset for last acc buffer of this block
+        uint32_t block_acc_offset =
+            (block_mapping.get_acc_buffer_offset_from_block(block_idx + 1) - 1) * MPerBlock *
+            NPerBlock;
+
         while(true)
         {
             uint32_t current_iter_length = __builtin_amdgcn_readfirstlane(

@@ -602,15 +830,22 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
             constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6);
             constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7);

-            constexpr auto c_block_desc_mblock_mperblock_nblock_nperblock =
-                GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
+            constexpr auto c_block_desc_mblock_mpershuffle_nblock_npershuffle =
+                GetCBlockDescriptor_MBlock_MPerShuffle_NBlock_NPerShuffle();
+            constexpr auto c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle =
+                GetCBlockDescriptor_MShuffleRepeat_MPerShuffle_NShuffleRepeat_NPerShuffle();

             auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                 static_cast<FloatCShuffle*>(p_shared_block),
-                c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
+                c_block_desc_mblock_mpershuffle_nblock_npershuffle.GetElementSpaceSize());
+
+            auto c_partial_acc_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+                static_cast<FloatAcc*>(p_workspace) + block_acc_offset,
+                c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle.GetElementSpaceSize());

             constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
-                c_block_desc_mblock_mperblock_nblock_nperblock,
+                c_block_desc_mblock_mpershuffle_nblock_npershuffle,
                 make_tuple(make_freeze_transform(I0), // freeze mblock
                            make_unmerge_transform(
                                make_tuple(CShuffleMRepeatPerShuffle,

@@ -701,14 +936,14 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
                 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
                 FloatCShuffle,        // typename SrcData,
                 FloatC,               // typename DstData,
-                decltype(c_block_desc_mblock_mperblock_nblock_nperblock),
+                decltype(c_block_desc_mblock_mpershuffle_nblock_npershuffle),
                 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
                 Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
                 3,                    // index_t VectorDim,
                 CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector,
                 true,  // bool ThreadTransferSrcResetCoordinateAfterRun,
                 false> // bool ThreadTransferDstResetCoordinateAfterRun
-                {c_block_desc_mblock_mperblock_nblock_nperblock,
+                {c_block_desc_mblock_mpershuffle_nblock_npershuffle,
                  make_multi_index(0, 0, 0, 0),
                  c_grid_desc_mblock_mperblock_nblock_nperblock,
                  make_multi_index(__builtin_amdgcn_readfirstlane(spatial_idx[I0]),

@@ -717,6 +952,32 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
                                   0),
                 c_element_op};

+            // LDS to global partial acc
+            auto c_block_copy_lds_to_partial_acc = ThreadGroupTensorSliceTransfer_v6r1r2<
+                ThisThreadBlock,       // index_t BlockSize,
+                CElementwiseOperation, // ElementwiseOperation,
+                                       // InMemoryDataOperationEnum::Set, // DstInMemOp,
+                Sequence<1,
+                         CShuffleMRepeatPerShuffle * MWave * MPerXDL,
+                         1,
+                         CShuffleNRepeatPerShuffle * NWave * NPerXDL>, // BlockSliceLengths,
+                CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
+                Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
+                FloatCShuffle,        // typename SrcData,
+                FloatC,               // typename DstData,
+                decltype(c_block_desc_mblock_mpershuffle_nblock_npershuffle),
+                decltype(c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle),
+                Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
+                3,                    // index_t VectorDim,
+                CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector,
+                true,  // bool ThreadTransferSrcResetCoordinateAfterRun,
+                false> // bool ThreadTransferDstResetCoordinateAfterRun
+                {c_block_desc_mblock_mpershuffle_nblock_npershuffle,
+                 make_multi_index(0, 0, 0, 0),
+                 c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
+                 make_multi_index(0, 0, 0, 0),
+                 c_element_op};
+
             constexpr auto mxdlperwave_forward_step =
                 make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXDL, 0, 0);
             constexpr auto nxdlperwave_forward_step =

@@ -757,19 +1018,42 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
                     c_block_copy_lds_to_global.template Run<decltype(c_block_buf),
                                                             decltype(c_grid_buf),
                                                             InMemoryDataOperationEnum::Set>(
-                        c_block_desc_mblock_mperblock_nblock_nperblock,
+                        c_block_desc_mblock_mpershuffle_nblock_npershuffle,
                         c_block_buf,
                         c_grid_desc_mblock_mperblock_nblock_nperblock,
                         c_grid_buf);
                 else if(is_sk_block)
-                    c_block_copy_lds_to_global.template Run<decltype(c_block_buf),
-                                                            decltype(c_grid_buf),
-                                                            InMemoryDataOperationEnum::AtomicAdd>(
-                        c_block_desc_mblock_mperblock_nblock_nperblock,
-                        c_block_buf,
-                        c_grid_desc_mblock_mperblock_nblock_nperblock,
-                        c_grid_buf);
+                {
+                    if constexpr(Block2CTileMap::ReductionStrategy ==
+                                 StreamKReductionStrategy::Reduction)
+                    {
+                        // constexpr offset
+                        c_block_copy_lds_to_partial_acc.SetDstSliceOrigin(
+                            c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
+                            make_tuple(mxdlperwave, I0, nxdlperwave, I0));
+
+                        c_block_copy_lds_to_partial_acc
+                            .template Run<decltype(c_block_buf),
+                                          decltype(c_block_buf),
+                                          InMemoryDataOperationEnum::Set>(
+                                c_block_desc_mblock_mpershuffle_nblock_npershuffle,
+                                c_block_buf,
+                                c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
+                                c_partial_acc_buf);
+                    }
+                    else if constexpr(Block2CTileMap::ReductionStrategy ==
+                                      StreamKReductionStrategy::Atomic)
+                    {
+                        c_block_copy_lds_to_global
+                            .template Run<decltype(c_block_buf),
+                                          decltype(c_grid_buf),
+                                          InMemoryDataOperationEnum::AtomicAdd>(
+                                c_block_desc_mblock_mpershuffle_nblock_npershuffle,
+                                c_block_buf,
+                                c_grid_desc_mblock_mperblock_nblock_nperblock,
+                                c_grid_buf);
+                    }
+                }

                 // move on nxdlperwave dimension
                 if constexpr(nxdlperwave_forward_sweep &&

@@ -795,6 +1079,17 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
                         mxdlperwave_forward_step);
                 }
             });
+
+            if constexpr(Block2CTileMap::ReductionStrategy ==
+                         StreamKReductionStrategy::Reduction)
+            {
+                if(is_sk_block)
+                {
+                    // increase the counter for this tile
+                    workgroup_barrier wg_barrier(p_semaphore);
+                    wg_barrier.inc(tile_idx);
+                }
+            }
         }

         // exit condition

@@ -802,6 +1097,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
             if(iter_end <= iter_start)
                 break;

+            if constexpr(Block2CTileMap::ReductionStrategy == StreamKReductionStrategy::Reduction)
+            {
+                block_acc_offset -= MPerBlock * NPerBlock;
+            }
+
             // make sure next loop LDS is ready for use
             block_sync_lds();
         }
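The control flow added above is easier to see stripped of the tensor descriptors: stream-k blocks deposit partial C tiles into the workspace and bump the tile's semaphore with wg_barrier.inc; the reduction workgroup assigned to that tile calls wg_barrier.wait_eq until the expected number of partials has arrived, then folds them together with reduce::Add and writes the final tile. A host-side sketch of that handshake (hypothetical sizes, a one-element "tile", sequential instead of concurrent):

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main()
    {
        const uint32_t partials_per_tile = 3;            // tile_acc_offset_end - tile_acc_offset_start
        std::vector<float> workspace(partials_per_tile); // partial accumulation buffers
        uint32_t semaphore = 0;                          // one counter per sk tile, zeroed at launch

        // "sk blocks": store a partial result, then signal.
        for(uint32_t b = 0; b < partials_per_tile; ++b)
        {
            workspace[b] = 1.0f + b; // partial result of this block's K-slice
            ++semaphore;             // stands in for wg_barrier.inc(tile_idx)
        }

        // "reduction block": wait until all partials arrived, then accumulate into C.
        if(semaphore == partials_per_tile) // stands in for wg_barrier.wait_eq(...)
        {
            float acc = 0.0f;
            for(float p : workspace)
                acc += p; // Accumulation::Calculate with reduce::Add
            std::printf("tile result = %f\n", acc);
        }
        return 0;
    }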
include/ck/utility/magic_division.hpp

@@ -178,21 +178,46 @@ struct MDiv
         ck::tie(multiplier, shift) = MagicDivision::CalculateMagicNumbers(divisor_);
     }

-    __host__ __device__ uint32_t div(uint32_t dividend) const
+    __host__ __device__ uint32_t div(uint32_t dividend_) const
     {
-        return MagicDivision::DoMagicDivision(dividend, multiplier, shift);
+        return MagicDivision::DoMagicDivision(dividend_, multiplier, shift);
     }

     __host__ __device__ void
-    divmod(uint32_t dividend, uint32_t& quotient, uint32_t& remainder) const
+    divmod(uint32_t dividend_, uint32_t& quotient_, uint32_t& remainder_) const
     {
-        quotient  = div(dividend);
-        remainder = dividend - (quotient * divisor);
+        quotient_  = div(dividend_);
+        remainder_ = dividend_ - (quotient_ * divisor);
     }

-    __host__ __device__ uint32_t operator/(uint32_t dividend) const { return div(dividend); }
-
     __host__ __device__ uint32_t get() const { return divisor; }
 };

+struct MDiv2
+{
+    // 1 dword -> 2 dword storage, divisor need compute from runtime
+    uint32_t multiplier;
+    uint32_t shift; // TODO: 8 bit is enough
+
+    // prefer construct on host
+    __host__ __device__ MDiv2(uint32_t divisor_)
+    {
+        ck::tie(multiplier, shift) = MagicDivision::CalculateMagicNumbers(divisor_);
+    }
+
+    __host__ __device__ MDiv2() : multiplier(0), shift(0) {}
+
+    __host__ __device__ uint32_t div(uint32_t dividend_) const
+    {
+        return MagicDivision::DoMagicDivision(dividend_, multiplier, shift);
+    }
+
+    __host__ __device__ void divmod(uint32_t dividend_,
+                                    uint32_t divisor_,
+                                    uint32_t& quotient_,
+                                    uint32_t& remainder_) const
+    {
+        quotient_  = div(dividend_);
+        remainder_ = dividend_ - (quotient_ * divisor_);
+    }
+};
+
 } // namespace ck
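MDiv2 drops the stored divisor to save a dword, so divmod() needs the divisor passed back in by the caller (as block_to_ctile_map.hpp does with n_tiles_value). A stand-alone sketch of that contract, using plain integer division in place of MagicDivision and hypothetical example values:

    #include <cassert>
    #include <cstdint>

    // The member stands in for the multiplier/shift pair; the real class divides
    // via MagicDivision::DoMagicDivision instead of the '/' operator.
    struct MDiv2Sketch
    {
        uint32_t divisor_hint;
        explicit MDiv2Sketch(uint32_t d) : divisor_hint(d) {}
        uint32_t div(uint32_t dividend) const { return dividend / divisor_hint; }
        void divmod(uint32_t dividend, uint32_t divisor, uint32_t& q, uint32_t& r) const
        {
            q = div(dividend);
            r = dividend - q * divisor; // caller must supply the same divisor again
        }
    };

    int main()
    {
        MDiv2Sketch m(7);
        uint32_t q, r;
        m.divmod(45, 7, q, r);
        assert(q == 6 && r == 3); // 45 = 6*7 + 3
        return 0;
    }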
include/ck/utility/math.hpp

@@ -240,5 +240,21 @@ struct less
     __host__ __device__ constexpr bool operator()(T x, T y) const { return x < y; }
 };

+template <index_t X>
+__host__ __device__ constexpr auto next_power_of_two()
+{
+    // TODO: X need to be 2 ~ 0x7fffffff. 0, 1, or larger than 0x7fffffff will compile fail
+    constexpr index_t Y = 1 << (32 - __builtin_clz(X - 1));
+    return Y;
+}
+
+template <index_t X>
+__host__ __device__ constexpr auto next_power_of_two(Number<X> x)
+{
+    // TODO: X need to be 2 ~ 0x7fffffff. 0, 1, or larger than 0x7fffffff will compile fail
+    constexpr index_t Y = 1 << (32 - __builtin_clz(x.value - 1));
+    return Number<Y>{};
+}
+
 } // namespace math
 } // namespace ck
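The formula can be sanity-checked at compile time. A stand-alone sketch (uint32_t instead of ck::index_t, no Number<> wrapper; valid for X >= 2, per the TODO note in the header):

    #include <cstdint>

    // 1 << (32 - clz(X - 1)) rounds X up to the next power of two.
    template <uint32_t X>
    constexpr uint32_t next_pow2()
    {
        return 1u << (32 - __builtin_clz(X - 1));
    }

    static_assert(next_pow2<2>() == 2, "already a power of two stays put");
    static_assert(next_pow2<96>() == 128, "96 rounds up to 128");
    static_assert(next_pow2<128>() == 128, "powers of two are fixed points");

    int main() { return 0; }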
include/ck/utility/workgroup_barrier.hpp  (new file, 0 → 100644)

#pragma once

#include <hip/hip_runtime.h>
#include <stdint.h>

namespace ck {

struct workgroup_barrier
{
    __device__ workgroup_barrier(uint32_t* ptr) : base_ptr(ptr) {}

    __device__ uint32_t ld(uint32_t offset)
    {
#if 0
        float d = llvm_amdgcn_raw_buffer_load_fp32(
            amdgcn_make_buffer_resource(base_ptr), 0, offset, AMDGCN_BUFFER_GLC);
        union cvt
        {
            float f32;
            uint32_t u32;
        };
        cvt x;
        x.f32 = d;
        return x.u32;
#endif
        return __atomic_load_n(base_ptr + offset, __ATOMIC_RELAXED);
    }

    __device__ void wait_eq(uint32_t offset, uint32_t value)
    {
        if(threadIdx.x == 0)
        {
            while(ld(offset) != value) {}
        }
        __syncthreads();
    }

    __device__ void wait_lt(uint32_t offset, uint32_t value)
    {
        if(threadIdx.x == 0)
        {
            while(ld(offset) < value) {}
        }
        __syncthreads();
    }

    __device__ void wait_set(uint32_t offset, uint32_t compare, uint32_t value)
    {
        if(threadIdx.x == 0)
        {
            while(atomicCAS(base_ptr + offset, compare, value) != compare) {}
        }
        __syncthreads();
    }

    // enter critical zone, assume buffer is zero when launch kernel
    __device__ void aquire(uint32_t offset) { wait_set(offset, 0, 1); }

    // exit critical zone, assume buffer is zero when launch kernel
    __device__ void release(uint32_t offset) { wait_set(offset, 1, 0); }

    __device__ void inc(uint32_t offset)
    {
        __syncthreads();
        if(threadIdx.x == 0)
        {
            atomicAdd(base_ptr + offset, 1);
        }
    }

    uint32_t* base_ptr;
};

} // namespace ck
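A hedged usage sketch, not part of the commit: one workgroup increments a counter slot and another spins on it with wait_eq, mirroring how the stream-k kernel signals tile completion. It assumes a ROCm/hipcc build with the repository's include directory on the include path and, as the header's comments require, a zero-initialized counter buffer.

    #include <hip/hip_runtime.h>
    #include <cstdio>
    #include "ck/utility/workgroup_barrier.hpp"

    __global__ void barrier_demo(uint32_t* counter, uint32_t* out)
    {
        ck::workgroup_barrier barrier(counter);
        if(blockIdx.x == 0)
        {
            barrier.inc(0); // producer: bump slot 0 once this block is done
        }
        else
        {
            barrier.wait_eq(0, 1); // consumer: spin until slot 0 reaches 1
            if(threadIdx.x == 0)
                out[0] = 1;
        }
    }

    int main()
    {
        uint32_t *counter, *out;
        hipMalloc(&counter, sizeof(uint32_t));
        hipMalloc(&out, sizeof(uint32_t));
        hipMemset(counter, 0, sizeof(uint32_t)); // the barrier assumes zeroed storage
        hipMemset(out, 0, sizeof(uint32_t));
        barrier_demo<<<2, 64>>>(counter, out);
        hipDeviceSynchronize();
        uint32_t host_out = 0;
        hipMemcpy(&host_out, out, sizeof(uint32_t), hipMemcpyDeviceToHost);
        std::printf("consumer released: %u\n", host_out);
        hipFree(counter);
        hipFree(out);
        return 0;
    }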