Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
06701e70
Unverified
Commit
06701e70
authored
Jul 09, 2024
by
Rostyslav Geyyer
Committed by
GitHub
Jul 09, 2024
Browse files
Merge branch 'develop' into lwpck-1815
parents
5800d24e
da42a889
Changes
156
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
3666 additions
and
257 deletions
+3666
-257
include/ck/tensor_operation/gpu/device/impl/device_grouped_query_attention_forward_wmma.hpp
...vice/impl/device_grouped_query_attention_forward_wmma.hpp
+3
-2
include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp
...device/impl/device_multi_query_attention_forward_wmma.hpp
+3
-2
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
+322
-0
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp
...grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp
+14
-6
include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp
.../tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp
+15
-7
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp
...ation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp
+26
-18
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
+27
-15
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp
...ration/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp
+2010
-0
include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp
...k/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp
+3
-2
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
...operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
+108
-1
include/ck/tensor_operation/gpu/warp/smfmac_xdlops_gemm.hpp
include/ck/tensor_operation/gpu/warp/smfmac_xdlops_gemm.hpp
+409
-0
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
+146
-1
include/ck/utility/amd_smfmac.hpp
include/ck/utility/amd_smfmac.hpp
+97
-0
include/ck/utility/amd_wmma.hpp
include/ck/utility/amd_wmma.hpp
+82
-0
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+1
-1
include/ck/utility/synchronization.hpp
include/ck/utility/synchronization.hpp
+17
-0
include/ck_tile/core/arch/amd_buffer_addressing.hpp
include/ck_tile/core/arch/amd_buffer_addressing.hpp
+333
-185
include/ck_tile/core/arch/arch.hpp
include/ck_tile/core/arch/arch.hpp
+3
-5
include/ck_tile/core/config.hpp
include/ck_tile/core/config.hpp
+13
-1
include/ck_tile/core/tensor/buffer_view.hpp
include/ck_tile/core/tensor/buffer_view.hpp
+34
-11
No files found.
include/ck/tensor_operation/gpu/device/impl/device_grouped_query_attention_forward_wmma.hpp
View file @
06701e70
...
...
@@ -61,7 +61,7 @@ __global__ void
bool
input_permute
,
bool
output_permute
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)
|| defined(__gfx12__)
)
// clang-format off
// ***************************************************
...
...
@@ -166,6 +166,7 @@ __global__ void
ignore
=
O
;
ignore
=
G0
;
ignore
=
G1
;
ignore
=
alpha
;
ignore
=
input_permute
;
ignore
=
output_permute
;
#endif // end of if (defined(__gfx11__))
...
...
@@ -596,7 +597,7 @@ struct DeviceGroupedQueryAttentionForward_Wmma
static
bool
IsSupportedArgument
(
const
RawArg
&
arg
)
{
if
(
ck
::
is_gfx11_supported
())
if
(
ck
::
is_gfx11_supported
()
||
ck
::
is_gfx12_supported
()
)
{
if
constexpr
(
!
(
is_same_v
<
Acc0DataType
,
float
>
||
is_same_v
<
Acc0DataType
,
int32_t
>
))
{
...
...
include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp
View file @
06701e70
...
...
@@ -60,7 +60,7 @@ __global__ void
bool
input_permute
,
bool
output_permute
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)
|| defined(__gfx12__)
)
// clang-format off
// ***************************************************
...
...
@@ -165,6 +165,7 @@ __global__ void
ignore
=
O
;
ignore
=
G0
;
ignore
=
G1
;
ignore
=
alpha
;
ignore
=
input_permute
;
ignore
=
output_permute
;
#endif // end of if (defined(__gfx11__))
...
...
@@ -594,7 +595,7 @@ struct DeviceMultiQueryAttentionForward_Wmma
static
bool
IsSupportedArgument
(
const
RawArg
&
arg
)
{
if
(
ck
::
is_gfx11_supported
())
if
(
ck
::
is_gfx11_supported
()
||
ck
::
is_gfx12_supported
()
)
{
if
constexpr
(
!
(
is_same_v
<
Acc0DataType
,
float
>
||
is_same_v
<
Acc0DataType
,
int32_t
>
))
{
...
...
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
View file @
06701e70
...
...
@@ -1404,4 +1404,326 @@ struct BlockToCTileMap_GemmStreamK
}
};
template
<
uint32_t
MPerBlock_
,
uint32_t
NPerBlock_
,
uint32_t
KPerBlock_
,
StreamKReductionStrategy
ReductionStrategy_
=
StreamKReductionStrategy
::
Atomic
,
uint32_t
TileSwizzleSubM_
=
8
,
index_t
GroupNum
=
8
,
index_t
M01_
=
4
>
struct
BlockToCTileMap_GemmStreamK_v2
{
static
constexpr
uint32_t
min_k_iters_per_sk_block
=
2
;
static
constexpr
uint32_t
MPerBlock
=
MPerBlock_
;
static
constexpr
uint32_t
NPerBlock
=
NPerBlock_
;
static
constexpr
uint32_t
KPerBlock
=
KPerBlock_
;
static
constexpr
StreamKReductionStrategy
ReductionStrategy
=
ReductionStrategy_
;
static
constexpr
uint32_t
tile_swizzle_sub_m
=
TileSwizzleSubM_
;
//--------------------------------------
// pass to device
mutable
uint32_t
sk_num_blocks
;
uint32_t
sk_num_big_blocks
;
uint32_t
dp_start_block_idx
;
uint32_t
reduction_start_block_idx
;
uint32_t
k_iters_per_big_block
;
MDiv2
n_tiles
;
MDiv
k_iters_per_tile
;
MDiv
equiv_tiles_big
;
// for reduction
MDiv
equiv_tiles_little
;
// for reduction
// prefer construct on host
__host__
__device__
BlockToCTileMap_GemmStreamK_v2
(
uint32_t
m
,
uint32_t
n
,
uint32_t
k
,
uint32_t
grid_size
=
1
,
uint32_t
streamk_sel
=
1
)
{
// total output tiles
uint32_t
num_tiles
=
math
::
integer_divide_ceil
(
m
,
MPerBlock
)
*
math
::
integer_divide_ceil
(
n
,
NPerBlock
);
k_iters_per_tile
=
MDiv
(
math
::
integer_divide_ceil
(
k
,
KPerBlock
));
uint32_t
dp_tiles
,
dp_num_blocks
,
sk_total_iters
;
// default to regular DP GEMM if sk blocks == 0
if
(
streamk_sel
==
0
)
{
sk_num_blocks
=
0
;
dp_tiles
=
num_tiles
;
sk_num_big_blocks
=
0
;
k_iters_per_big_block
=
0
;
dp_num_blocks
=
num_tiles
;
// all tile to be dp block
dp_start_block_idx
=
0
;
sk_total_iters
=
0
;
// clear this tiles
}
// 2-tile sk + DP GEMM
else
{
// check if there's enough work for DP+ stream-k
bool
bigEnough
=
num_tiles
>
grid_size
;
// select between stream-k strategies
uint32_t
sk_tiles
=
0
;
if
(
streamk_sel
==
1
)
// 1 tile stream-k
{
sk_tiles
=
bigEnough
?
(
num_tiles
%
grid_size
)
:
num_tiles
;
}
else
if
(
streamk_sel
==
2
)
// 2-tile stream-k
{
sk_tiles
=
bigEnough
?
(
grid_size
+
num_tiles
%
grid_size
)
:
num_tiles
;
}
else
if
(
streamk_sel
==
3
)
// 3-tile stream-k
{
sk_tiles
=
(
num_tiles
>
(
2
*
grid_size
))
?
(
2
*
grid_size
+
num_tiles
%
grid_size
)
:
num_tiles
;
}
else
if
(
streamk_sel
==
4
)
// 4-tile stream-k
{
sk_tiles
=
(
num_tiles
>
(
3
*
grid_size
))
?
(
3
*
grid_size
+
num_tiles
%
grid_size
)
:
num_tiles
;
}
sk_num_blocks
=
sk_tiles
;
// remaining tiles are DP tiles
dp_tiles
=
bigEnough
?
(
num_tiles
-
sk_tiles
)
:
0
;
sk_total_iters
=
k_iters_per_tile
.
get
()
*
sk_tiles
;
// k_iters_per_sk_block is the floor of avg each ck block loop over tiles.
// we need to decide how many iters for each sk block
// let m = k_iters_per_sk_block
// some of the sk block (little) will cover m iters, some (big) will cover m+1
// we have
// 1) l + b = sk_blocks
// 2) l * m + b * (m + 1) = sk_total_iters
// => (l + b) * m + b = sk_total_iters
// => sk_blocks * m + b = sk_total_iters
// => b = sk_total_iters - m * sk_blocks
// NOTE: big could be zero
uint32_t
k_iters_per_sk_block
=
sk_total_iters
/
sk_num_blocks
;
sk_num_big_blocks
=
sk_total_iters
-
k_iters_per_sk_block
*
sk_num_blocks
;
k_iters_per_big_block
=
k_iters_per_sk_block
+
1
;
dp_num_blocks
=
dp_tiles
;
dp_start_block_idx
=
sk_num_blocks
;
}
n_tiles
=
MDiv2
(
math
::
integer_divide_ceil
(
n
,
NPerBlock
));
// using multiple blocks for parallel reduction
reduction_start_block_idx
=
dp_start_block_idx
+
dp_num_blocks
;
if
constexpr
(
ReductionStrategy
==
StreamKReductionStrategy
::
Reduction
)
{
uint32_t
upper_big
=
math
::
lcm
(
k_iters_per_big_block
,
k_iters_per_tile
.
get
());
uint32_t
upper_little
=
math
::
lcm
(
k_iters_per_big_block
-
1
,
k_iters_per_tile
.
get
());
equiv_tiles_big
=
MDiv
(
upper_big
/
k_iters_per_tile
.
get
());
equiv_tiles_little
=
MDiv
(
upper_little
/
k_iters_per_tile
.
get
());
}
}
__host__
__device__
static
constexpr
index_t
CalculateGridSize
(
index_t
M
,
index_t
N
)
{
const
auto
M0
=
math
::
integer_divide_ceil
(
M
,
MPerBlock
);
const
auto
N0
=
math
::
integer_divide_ceil
(
N
,
NPerBlock
);
return
M0
*
N0
;
}
__host__
__device__
uint32_t
get_sk_total_iters
()
const
{
uint32_t
sk_total_iters
=
sk_num_big_blocks
*
k_iters_per_big_block
+
(
sk_num_blocks
-
sk_num_big_blocks
)
*
(
k_iters_per_big_block
-
1
);
return
sk_total_iters
;
}
__host__
__device__
uint32_t
get_sk_tiles
()
const
{
// tiles for sk
uint32_t
sk_total_iters
=
get_sk_total_iters
();
return
k_iters_per_tile
.
div
(
sk_total_iters
);
}
__host__
__device__
index_t
get_grid_dims
()
const
{
if
constexpr
(
ReductionStrategy
==
StreamKReductionStrategy
::
Reduction
)
{
// return dim3(reduction_start_block_idx + get_sk_tiles(), 1, 1);
return
reduction_start_block_idx
+
get_sk_tiles
();
}
else
return
reduction_start_block_idx
;
}
__device__
uint32_t
get_block_idx
()
const
{
// TODO: swizzle block index for better locality
return
__builtin_amdgcn_readfirstlane
(
blockIdx
.
x
);
}
__device__
void
get_block_itr
(
uint32_t
block_idx
,
uint32_t
&
iter_start
,
uint32_t
&
iter_end
)
const
{
if
(
block_idx
<
sk_num_big_blocks
)
{
iter_start
=
block_idx
*
k_iters_per_big_block
;
iter_end
=
iter_start
+
k_iters_per_big_block
;
}
else
if
(
block_idx
<
sk_num_blocks
)
{
iter_start
=
(
sk_num_big_blocks
*
k_iters_per_big_block
)
+
(
block_idx
-
sk_num_big_blocks
)
*
(
k_iters_per_big_block
-
1
);
iter_end
=
iter_start
+
(
k_iters_per_big_block
-
1
);
}
else
if
(
block_idx
>=
dp_start_block_idx
)
{
uint32_t
sk_total_iters
=
get_sk_total_iters
();
uint32_t
dp_iters_per_block
=
k_iters_per_tile
.
get
();
iter_start
=
sk_total_iters
+
(
block_idx
-
dp_start_block_idx
)
*
dp_iters_per_block
;
iter_end
=
iter_start
+
dp_iters_per_block
;
}
}
__device__
uint32_t
get_current_iter_length
(
uint32_t
iter_start
,
uint32_t
iter_end
,
uint32_t
total_iter_length
)
const
{
uint32_t
iter_length_mod
,
iter_length_quo
/*unused*/
;
k_iters_per_tile
.
divmod
(
iter_end
,
iter_length_quo
,
iter_length_mod
);
uint32_t
current_iter_length
=
math
::
min
(
iter_length_mod
==
0
?
(
iter_end
-
iter_start
)
:
iter_length_mod
,
total_iter_length
);
return
current_iter_length
;
}
__device__
uint32_t
get_tile_idx
(
uint32_t
iter
)
const
{
return
k_iters_per_tile
.
div
(
iter
);
}
__device__
void
get_tile_idx_with_offset
(
uint32_t
iter
,
uint32_t
&
tile_idx
,
uint32_t
&
iter_offset
)
const
{
k_iters_per_tile
.
divmod
(
iter
,
tile_idx
,
iter_offset
);
}
__device__
auto
tile_to_spatial
(
uint32_t
tile_idx
,
uint32_t
m
,
uint32_t
n
)
const
{
uint32_t
m_tile_idx
,
n_tile_idx
;
uint32_t
n_tiles_value
=
math
::
integer_divide_ceil
(
n
,
NPerBlock
);
n_tiles
.
divmod
(
tile_idx
,
n_tiles_value
,
m_tile_idx
,
n_tile_idx
);
// // swizzle tile
uint32_t
m_tiles
=
math
::
integer_divide_ceil
(
m
,
MPerBlock
);
uint32_t
tile_swizzle_sub_m_rem
=
m_tiles
%
tile_swizzle_sub_m
;
const
auto
sub_m_adapt
=
(
m_tile_idx
<
(
m_tiles
-
tile_swizzle_sub_m_rem
))
?
tile_swizzle_sub_m
:
tile_swizzle_sub_m_rem
;
uint32_t
m_tile_idx_sub0
,
m_tile_idx_sub1
;
m_tile_idx_sub0
=
m_tile_idx
/
tile_swizzle_sub_m
;
m_tile_idx_sub1
=
m_tile_idx
%
tile_swizzle_sub_m
;
uint32_t
tile_idx_local
=
n_tile_idx
+
m_tile_idx_sub1
*
n_tiles_value
;
uint32_t
m_tile_idx_with_adapt
,
n_tile_idx_with_adapt
;
n_tile_idx_with_adapt
=
tile_idx_local
/
sub_m_adapt
;
m_tile_idx_with_adapt
=
tile_idx_local
%
sub_m_adapt
;
return
make_tuple
(
m_tile_idx_with_adapt
+
m_tile_idx_sub0
*
tile_swizzle_sub_m
,
n_tile_idx_with_adapt
);
}
__host__
__device__
uint32_t
get_workspace_size_for_acc
(
uint32_t
acc_element_bytes
)
const
{
static
constexpr
uint32_t
alignment
=
128
;
uint32_t
acc_buffer_bytes
=
MPerBlock
*
NPerBlock
*
get_total_acc_buffers
()
*
acc_element_bytes
;
return
(
acc_buffer_bytes
+
alignment
-
1
)
/
alignment
*
alignment
;
}
__host__
__device__
uint32_t
get_workspace_size_for_semaphore
()
const
{
return
get_sk_tiles
()
*
sizeof
(
uint32_t
);
}
__host__
__device__
uint32_t
get_workspace_size
(
uint32_t
acc_element_bytes
)
const
{
return
get_workspace_size_for_acc
(
acc_element_bytes
)
+
get_workspace_size_for_semaphore
();
}
__host__
__device__
uint32_t
get_tile_intersections
(
uint32_t
tiles_
,
const
MDiv
&
equiv_tiles_
)
const
{
uint32_t
tile_idx_
=
tiles_
==
0
?
0
:
(
tiles_
-
1
);
uint32_t
max_equiv_tiles_
=
equiv_tiles_
.
get
()
-
1
;
uint32_t
quo_
,
rem_
;
equiv_tiles_
.
divmod
(
tile_idx_
,
quo_
,
rem_
);
return
quo_
*
max_equiv_tiles_
+
rem_
;
}
__host__
__device__
uint32_t
get_tiles_cover_sk_block
(
uint32_t
num_sk_blocks_
,
uint32_t
iters_per_sk_block_
)
const
{
return
k_iters_per_tile
.
div
(
num_sk_blocks_
*
iters_per_sk_block_
+
k_iters_per_tile
.
get
()
-
1
);
}
__host__
__device__
uint32_t
get_total_acc_buffers
()
const
{
uint32_t
tiles_cover_big_blocks
=
get_tiles_cover_sk_block
(
sk_num_big_blocks
,
k_iters_per_big_block
);
uint32_t
tiles_cover_little_blocks
=
get_tiles_cover_sk_block
(
sk_num_blocks
-
sk_num_big_blocks
,
k_iters_per_big_block
-
1
);
uint32_t
total_intersec_big
=
get_tile_intersections
(
tiles_cover_big_blocks
,
equiv_tiles_big
);
uint32_t
total_intersec_little
=
get_tile_intersections
(
tiles_cover_little_blocks
,
equiv_tiles_little
);
return
sk_num_blocks
+
total_intersec_big
+
total_intersec_little
;
}
__device__
uint32_t
get_acc_buffer_offset_from_tile
(
uint32_t
tile_idx_
)
const
{
// TODO: from big to little
uint32_t
tiles_cover_big_blocks
=
get_tiles_cover_sk_block
(
sk_num_big_blocks
,
k_iters_per_big_block
);
if
(
tile_idx_
<
tiles_cover_big_blocks
)
{
uint32_t
touched_sk_blocks
=
(
tile_idx_
*
k_iters_per_tile
.
get
()
+
k_iters_per_big_block
-
1
)
/
k_iters_per_big_block
;
uint32_t
current_intersec
=
get_tile_intersections
(
tile_idx_
,
equiv_tiles_big
);
return
touched_sk_blocks
+
current_intersec
;
}
else
{
uint32_t
iters_per_little_sk_block
=
k_iters_per_big_block
-
1
;
uint32_t
tile_idx_little_reverse
=
get_sk_tiles
()
-
tile_idx_
;
uint32_t
touched_sk_blocks
=
(
tile_idx_little_reverse
*
k_iters_per_tile
.
get
()
+
iters_per_little_sk_block
-
1
)
/
iters_per_little_sk_block
;
uint32_t
current_intersec
=
get_tile_intersections
(
tile_idx_little_reverse
,
equiv_tiles_little
);
return
get_total_acc_buffers
()
-
(
touched_sk_blocks
+
current_intersec
);
}
}
__device__
uint32_t
get_acc_buffer_offset_from_block
(
uint32_t
block_idx_
)
const
{
uint32_t
iters_per_big_sk_block
=
k_iters_per_big_block
;
uint32_t
iters_per_little_sk_block
=
k_iters_per_big_block
-
1
;
if
(
block_idx_
<
sk_num_big_blocks
)
{
uint32_t
touched_tiles
=
k_iters_per_tile
.
div
(
block_idx_
*
iters_per_big_sk_block
+
k_iters_per_tile
.
get
()
-
1
);
uint32_t
current_intersec
=
get_tile_intersections
(
touched_tiles
,
equiv_tiles_big
);
return
block_idx_
+
current_intersec
;
}
else
{
uint32_t
block_idx_little_reverse
=
sk_num_blocks
-
block_idx_
;
uint32_t
touched_tiles
=
k_iters_per_tile
.
div
(
block_idx_little_reverse
*
iters_per_little_sk_block
+
k_iters_per_tile
.
get
()
-
1
);
uint32_t
current_intersec
=
get_tile_intersections
(
touched_tiles
,
equiv_tiles_little
);
return
get_total_acc_buffers
()
-
(
block_idx_little_reverse
+
current_intersec
);
}
}
};
}
// namespace ck
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp
View file @
06701e70
...
...
@@ -371,12 +371,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
if
constexpr
(
B0EnableLds
)
{
// BK0_L_BK1 -> BK0_LRepeat_Lwaves_LPerWmma_BK1
constexpr
auto
B_K0
=
B0BlockDesc_
{}.
GetLength
(
I0
);
constexpr
auto
B_K1
=
B0BlockDesc_
{}.
GetLength
(
I2
);
constexpr
auto
B_K0
=
B0BlockDesc_
{}.
GetLength
(
I0
);
constexpr
auto
B_K1
=
B0BlockDesc_
{}.
GetLength
(
I2
);
#ifdef __gfx12__
constexpr
auto
B_KRow
=
I2
;
#else
constexpr
auto
B_KRow
=
I1
;
#endif
return
transform_tensor_descriptor
(
B0BlockDesc_
{},
make_tuple
(
make_unmerge_transform
(
make_tuple
(
Number
<
B_K0
>
{},
B_KRow
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
Number
<
B_K0
/
B_KRow
>
{},
B_KRow
)),
make_unmerge_transform
(
make_tuple
(
Number
<
LRepeat
>
{},
Number
<
LWaves
>
{},
Number
<
LPerWmma
>
{})),
make_pass_through_transform
(
Number
<
B_K1
>
{})),
...
...
@@ -428,12 +432,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
if
constexpr
(
B1EnableLds
)
{
// BL0_N_BL1 -> BL0_NRepeat_Nwaves_NPerWmma_BL1
constexpr
auto
B_L0
=
B1BlockDesc_
{}.
GetLength
(
I0
);
constexpr
auto
B_L1
=
B1BlockDesc_
{}.
GetLength
(
I2
);
constexpr
auto
B_L0
=
B1BlockDesc_
{}.
GetLength
(
I0
);
constexpr
auto
B_L1
=
B1BlockDesc_
{}.
GetLength
(
I2
);
#ifdef __gfx12__
constexpr
auto
B_LRow
=
I2
;
#else
constexpr
auto
B_LRow
=
I1
;
#endif
return
transform_tensor_descriptor
(
B1BlockDesc_
{},
make_tuple
(
make_unmerge_transform
(
make_tuple
(
Number
<
B_L0
>
{},
B_LRow
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
Number
<
B_L0
/
B_LRow
>
{},
B_LRow
)),
make_unmerge_transform
(
make_tuple
(
Number
<
NRepeat
>
{},
Number
<
NWaves
>
{},
Number
<
NPerWmma
>
{})),
make_pass_through_transform
(
Number
<
B_L1
>
{})),
...
...
include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp
View file @
06701e70
...
...
@@ -50,7 +50,7 @@ __global__ void
const
CElementwiseOperation
c_element_op
,
const
Block2CTileMap
block_2_ctile_map
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)
|| defined(__gfx12__)
)
__shared__
char
p_shared
[
GridwiseGemm
::
SharedMemTrait
::
lds_size
];
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
,
...
...
@@ -302,12 +302,16 @@ struct GridwiseFpAintBGemm_Wmma
if
constexpr
(
AEnableLds
)
{
// AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1
constexpr
auto
A_K0
=
ABlockDesc_
{}.
GetLength
(
I0
);
constexpr
auto
A_K1
=
ABlockDesc_
{}.
GetLength
(
I2
);
constexpr
auto
A_K0
=
ABlockDesc_
{}.
GetLength
(
I0
);
constexpr
auto
A_K1
=
ABlockDesc_
{}.
GetLength
(
I2
);
#ifdef __gfx12__
constexpr
auto
A_KRow
=
I2
;
#else
constexpr
auto
A_KRow
=
I1
;
#endif
return
transform_tensor_descriptor
(
ABlockDesc_
{},
make_tuple
(
make_unmerge_transform
(
make_tuple
(
Number
<
A_K0
>
{},
A_KRow
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
Number
<
A_K0
/
A_KRow
>
{},
A_KRow
)),
make_unmerge_transform
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
MWaves
>
{},
Number
<
MPerWmma
>
{})),
make_pass_through_transform
(
Number
<
A_K1
>
{})),
...
...
@@ -360,12 +364,16 @@ struct GridwiseFpAintBGemm_Wmma
if
constexpr
(
BEnableLds
)
{
// BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
constexpr
auto
B_K0
=
BBlockDesc_
{}.
GetLength
(
I0
);
constexpr
auto
B_K1
=
BBlockDesc_
{}.
GetLength
(
I2
);
constexpr
auto
B_K0
=
BBlockDesc_
{}.
GetLength
(
I0
);
constexpr
auto
B_K1
=
BBlockDesc_
{}.
GetLength
(
I2
);
#ifdef __gfx12__
constexpr
auto
B_KRow
=
I2
;
#else
constexpr
auto
B_KRow
=
I1
;
#endif
return
transform_tensor_descriptor
(
BBlockDesc_
{},
make_tuple
(
make_unmerge_transform
(
make_tuple
(
Number
<
B_K0
>
{},
B_KRow
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
Number
<
B_K0
/
B_KRow
>
{},
B_KRow
)),
make_unmerge_transform
(
make_tuple
(
Number
<
NRepeat
>
{},
Number
<
NWaves
>
{},
Number
<
NPerWmma
>
{})),
make_pass_through_transform
(
Number
<
B_K1
>
{})),
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp
View file @
06701e70
...
...
@@ -54,7 +54,7 @@ __global__ void
const
Block2CTileMap
block_2_ctile_map
,
const
ComputePtrOffsetOfBatch
compute_ptr_offset_of_batch
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)
|| defined(__gfx12__)
)
// offset base pointer for each work-group
const
index_t
num_blocks_per_batch
=
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
...
...
@@ -147,7 +147,7 @@ __global__ void
const
ComputePtrOffsetOfBatch
compute_ptr_offset_of_batch
,
const
Block2CTileMap
block_2_etile_map
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)
|| defined(__gfx12__)
)
// printf("entry kernel launch");
__shared__
char
p_shared
[
GridwiseOp
::
SharedMemTrait
::
lds_size
];
...
...
@@ -237,7 +237,7 @@ __global__ void
const
CDEElementwiseOperation
cde_element_op
,
const
Block2CTileMap
block_2_ctile_map
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)
|| defined(__gfx12__)
)
__shared__
char
p_shared
[
GridwiseOp
::
SharedMemTrait
::
lds_size
];
GridwiseOp
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
,
...
...
@@ -375,8 +375,9 @@ struct GridwiseGemmMultipleD_Wmma
}
else
{
constexpr
auto
A_KRow
=
I2
;
constexpr
auto
KWmmaPerblock
=
KPerBlock
/
WmmaK
;
constexpr
auto
K0PerWmma
=
WmmaK
/
2
/
K1
;
constexpr
auto
K0PerWmma
=
WmmaK
/
A_KRow
/
K1
;
// KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
KWmmaPerblock
>
{},
...
...
@@ -422,8 +423,9 @@ struct GridwiseGemmMultipleD_Wmma
}
else
{
constexpr
auto
B_KRow
=
I2
;
constexpr
auto
KWmmaPerblock
=
KPerBlock
/
WmmaK
;
constexpr
auto
K0PerWmma
=
WmmaK
/
2
/
K1
;
constexpr
auto
K0PerWmma
=
WmmaK
/
B_KRow
/
K1
;
// KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
KWmmaPerblock
>
{},
...
...
@@ -495,12 +497,16 @@ struct GridwiseGemmMultipleD_Wmma
if
constexpr
(
AEnableLds
)
{
// AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1
constexpr
auto
A_K0
=
ABlockDesc_
{}.
GetLength
(
I0
);
constexpr
auto
A_K1
=
ABlockDesc_
{}.
GetLength
(
I2
);
constexpr
auto
A_K0
=
ABlockDesc_
{}.
GetLength
(
I0
);
constexpr
auto
A_K1
=
ABlockDesc_
{}.
GetLength
(
I2
);
#ifdef __gfx12__
constexpr
auto
A_KRow
=
I2
;
#else
constexpr
auto
A_KRow
=
I1
;
#endif
return
transform_tensor_descriptor
(
ABlockDesc_
{},
make_tuple
(
make_unmerge_transform
(
make_tuple
(
Number
<
A_K0
>
{},
A_KRow
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
Number
<
A_K0
/
A_KRow
>
{},
A_KRow
)),
make_unmerge_transform
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
MWaves
>
{},
Number
<
MPerWmma
>
{})),
make_pass_through_transform
(
Number
<
A_K1
>
{})),
...
...
@@ -534,12 +540,16 @@ struct GridwiseGemmMultipleD_Wmma
if
constexpr
(
BEnableLds
)
{
// BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
constexpr
auto
B_K0
=
BBlockDesc_
{}.
GetLength
(
I0
);
constexpr
auto
B_K1
=
BBlockDesc_
{}.
GetLength
(
I2
);
constexpr
auto
B_K0
=
BBlockDesc_
{}.
GetLength
(
I0
);
constexpr
auto
B_K1
=
BBlockDesc_
{}.
GetLength
(
I2
);
#ifdef __gfx12__
constexpr
auto
B_KRow
=
I2
;
#else
constexpr
auto
B_KRow
=
I1
;
#endif
return
transform_tensor_descriptor
(
BBlockDesc_
{},
make_tuple
(
make_unmerge_transform
(
make_tuple
(
Number
<
B_K0
>
{},
B_KRow
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
Number
<
B_K0
/
B_KRow
>
{},
B_KRow
)),
make_unmerge_transform
(
make_tuple
(
Number
<
NRepeat
>
{},
Number
<
NWaves
>
{},
Number
<
NPerWmma
>
{})),
make_pass_through_transform
(
Number
<
B_K1
>
{})),
...
...
@@ -571,15 +581,12 @@ struct GridwiseGemmMultipleD_Wmma
// *Caution Here repeat is shuffle repeat
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat
()
{
constexpr
index_t
MWave
=
MPerBlock
/
(
MRepeat
*
MPerWmma
);
constexpr
index_t
NWave
=
NPerBlock
/
(
NRepeat
*
NPerWmma
);
constexpr
auto
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
CShuffleMRepeatPerShuffle
*
MWave
*
MPerWmma
>
{},
Number
<
CShuffleMRepeatPerShuffle
*
MWave
s
*
MPerWmma
>
{},
I1
,
Number
<
CShuffleNRepeatPerShuffle
*
NWave
*
NPerWmma
>
{}));
Number
<
CShuffleNRepeatPerShuffle
*
NWave
s
*
NPerWmma
>
{}));
return
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
;
}
...
...
@@ -799,8 +806,9 @@ struct GridwiseGemmMultipleD_Wmma
const
auto
M
=
e_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
e_grid_desc_m_n
.
GetLength
(
I1
);
const
auto
MBlock
=
M
/
MPerBlock
;
const
auto
NBlock
=
N
/
NPerBlock
;
const
auto
MBlock
=
M
/
MPerBlock
;
const
auto
NBlock
=
N
/
NPerBlock
;
const
auto
e_grid_desc_mblock_mperblock_nblock_nperblock
=
transform_tensor_descriptor
(
e_grid_desc_m_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
MBlock
,
Number
<
MPerBlock
>
{})),
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
View file @
06701e70
...
...
@@ -45,7 +45,7 @@ __global__ void
const
CElementwiseOperation
c_element_op
,
const
Block2CTileMap
block_2_ctile_map
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)
|| defined(__gfx12__)
)
__shared__
char
p_shared
[
GridwiseGemm
::
SharedMemTrait
::
lds_size
];
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
,
...
...
@@ -170,8 +170,9 @@ struct GridwiseGemm_Wmma
}
else
{
constexpr
auto
A_KRow
=
I2
;
constexpr
auto
KWmmaPerblock
=
KPerBlock
/
WmmaK
;
constexpr
auto
K0PerWmma
=
WmmaK
/
2
/
K1
;
constexpr
auto
K0PerWmma
=
WmmaK
/
A_KRow
/
K1
;
// KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
KWmmaPerblock
>
{},
...
...
@@ -217,8 +218,10 @@ struct GridwiseGemm_Wmma
}
else
{
constexpr
auto
B_KRow
=
I2
;
constexpr
auto
KWmmaPerblock
=
KPerBlock
/
WmmaK
;
constexpr
auto
K0PerWmma
=
WmmaK
/
2
/
K1
;
constexpr
auto
K0PerWmma
=
WmmaK
/
B_KRow
/
K1
;
// KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
KWmmaPerblock
>
{},
...
...
@@ -290,12 +293,17 @@ struct GridwiseGemm_Wmma
if
constexpr
(
AEnableLds
)
{
// AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1
constexpr
auto
A_K0
=
ABlockDesc_
{}.
GetLength
(
I0
);
constexpr
auto
A_K1
=
ABlockDesc_
{}.
GetLength
(
I2
);
constexpr
auto
A_K0
=
ABlockDesc_
{}.
GetLength
(
I0
);
constexpr
auto
A_K1
=
ABlockDesc_
{}.
GetLength
(
I2
);
#ifdef __gfx12__
constexpr
auto
A_KRow
=
I2
;
#else
constexpr
auto
A_KRow
=
I1
;
#endif
return
transform_tensor_descriptor
(
ABlockDesc_
{},
make_tuple
(
make_unmerge_transform
(
make_tuple
(
Number
<
A_K0
>
{},
A_KRow
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
Number
<
A_K0
/
A_KRow
>
{},
A_KRow
)),
make_unmerge_transform
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
MWaves
>
{},
Number
<
MPerWmma
>
{})),
make_pass_through_transform
(
Number
<
A_K1
>
{})),
...
...
@@ -348,12 +356,16 @@ struct GridwiseGemm_Wmma
if
constexpr
(
BEnableLds
)
{
// BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
constexpr
auto
B_K0
=
BBlockDesc_
{}.
GetLength
(
I0
);
constexpr
auto
B_K1
=
BBlockDesc_
{}.
GetLength
(
I2
);
constexpr
auto
B_K0
=
BBlockDesc_
{}.
GetLength
(
I0
);
constexpr
auto
B_K1
=
BBlockDesc_
{}.
GetLength
(
I2
);
#ifdef __gfx12__
constexpr
auto
B_KRow
=
I2
;
#else
constexpr
auto
B_KRow
=
I1
;
#endif
return
transform_tensor_descriptor
(
BBlockDesc_
{},
make_tuple
(
make_unmerge_transform
(
make_tuple
(
Number
<
B_K0
>
{},
B_KRow
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
Number
<
B_K0
/
B_KRow
>
{},
B_KRow
)),
make_unmerge_transform
(
make_tuple
(
Number
<
NRepeat
>
{},
Number
<
NWaves
>
{},
Number
<
NPerWmma
>
{})),
make_pass_through_transform
(
Number
<
B_K1
>
{})),
...
...
@@ -522,12 +534,6 @@ struct GridwiseGemm_Wmma
c_grid_desc_m_n
);
}
using
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
CGridDesc_M_N
{}))
>
;
using
DefaultBlock2CTileMap
=
remove_cvref_t
<
decltype
(
MakeDefaultBlock2CTileMap
(
CGridDesc_M_N
{},
1
,
1
))
>
;
struct
SharedMemTrait
{
// LDS allocation for A and B: be careful of alignment
...
...
@@ -559,6 +565,12 @@ struct GridwiseGemm_Wmma
b_block_space_size_aligned
*
sizeof
(
BDataType
));
};
using
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
CGridDesc_M_N
{}))
>
;
using
DefaultBlock2CTileMap
=
remove_cvref_t
<
decltype
(
MakeDefaultBlock2CTileMap
(
CGridDesc_M_N
{},
1
,
1
))
>
;
template
<
bool
HasMainKBlockLoop
,
typename
Block2CTileMap
=
DefaultBlock2CTileMap
>
__device__
static
void
Run
(
const
ADataType
*
__restrict__
p_a_grid
,
const
BDataType
*
__restrict__
p_b_grid
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp
0 → 100644
View file @
06701e70
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp
View file @
06701e70
...
...
@@ -35,8 +35,9 @@ __global__ void
const
Block2ETileMap
block_2_tile_map
,
const
ComputePtrOffsetOfStridedBatch
compute_ptr_offset_of_batch
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \
defined(__gfx12__))
GridwiseTensorRearrangeKernel
::
Run
(
in_grid_desc
,
p_in_global
,
out_grid_desc
,
...
...
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
View file @
06701e70
...
...
@@ -1304,7 +1304,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic
ElementwiseOperation
element_op_
;
};
// Specilized for WMMA
// Specilized for WMMA
-Navi3
// A single Wave32 is composed by double row
// Data exchange allowed between these two rows
// This RowLane Dst buf will be filled from two Src buf
...
...
@@ -1439,4 +1439,111 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
ElementwiseOperation
element_op_
{};
};
// Specilized for WMMA-Navi4
template
<
typename
SrcData
,
typename
DstData
,
typename
SrcDesc
,
typename
DstDesc
,
typename
ElementwiseOperation
,
typename
SliceLengths
,
typename
DimAccessOrder
,
index_t
DstVectorDim
,
index_t
DstScalarPerVector
,
bool
IntraRowSwizzlePerm
,
typename
enable_if
<
SrcDesc
::
IsKnownAtCompileTime
()
&&
DstDesc
::
IsKnownAtCompileTime
(),
bool
>
::
type
=
false
>
struct
ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow
{
static
constexpr
index_t
nDim
=
SliceLengths
::
Size
();
using
Index
=
MultiIndex
<
nDim
>
;
__device__
constexpr
ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow
(
const
Index
&
src_idx
)
{
static_assert
(
SrcDesc
::
IsKnownAtCompileTime
()
&&
DstDesc
::
IsKnownAtCompileTime
(),
"wrong! Desc need to known at compile-time"
);
static_assert
(
SliceLengths
::
At
(
Number
<
DstVectorDim
>
{})
%
DstScalarPerVector
==
0
,
"wrong! Not divisible"
);
ignore
=
src_idx
;
}
template
<
typename
SrcSliceOriginIdx
,
typename
DstSliceOriginIdx
,
typename
SrcBuffer
,
typename
DstBuffer
>
__device__
void
Run
(
const
SrcDesc
&
,
const
SrcSliceOriginIdx
&
,
const
SrcBuffer
&
src_buf
,
const
DstDesc
&
,
const
DstSliceOriginIdx
&
,
DstBuffer
&
dst_buf
)
const
{
static_assert
(
SrcDesc
::
IsKnownAtCompileTime
()
&&
DstDesc
::
IsKnownAtCompileTime
(),
"wrong! Desc need to known at compile-time"
);
static_assert
(
is_known_at_compile_time
<
remove_cvref_t
<
SrcSliceOriginIdx
>>::
value
&&
is_known_at_compile_time
<
remove_cvref_t
<
DstSliceOriginIdx
>>::
value
,
"wrong! SliceOrigin need to known at compile-time"
);
static_assert
(
SrcBuffer
::
IsStaticBuffer
()
&&
DstBuffer
::
IsStaticBuffer
(),
"wrong! Buffer need to be StaticBuffer"
);
// SrcDesc and src_slice_origin_idx are known at compile-time
constexpr
auto
src_desc
=
remove_cvref_t
<
SrcDesc
>
{};
constexpr
auto
dst_desc
=
remove_cvref_t
<
DstDesc
>
{};
constexpr
auto
src_slice_origin_idx
=
to_multi_index
(
SrcSliceOriginIdx
{});
constexpr
auto
dst_slice_origin_idx
=
to_multi_index
(
DstSliceOriginIdx
{});
// scalar per access on each dim
constexpr
auto
dst_scalar_per_access
=
generate_sequence
(
detail
::
lambda_scalar_per_access
<
DstVectorDim
,
DstScalarPerVector
>
{},
Number
<
nDim
>
{});
constexpr
auto
dst_scalar_step_in_vector
=
generate_sequence
(
detail
::
lambda_scalar_step_in_vector
<
DstVectorDim
>
{},
Number
<
nDim
>
{});
using
SpaceFillingCurve
=
SpaceFillingCurve
<
SliceLengths
,
DimAccessOrder
,
remove_cv_t
<
decltype
(
dst_scalar_per_access
)
>>
;
static_assert
(
DstScalarPerVector
==
SpaceFillingCurve
::
ScalarPerVector
,
"wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector"
);
constexpr
auto
num_access
=
SpaceFillingCurve
::
GetNumOfAccess
();
static_for
<
0
,
num_access
,
1
>
{}([
&
](
auto
idx_1d
)
{
constexpr
auto
idx_md
=
SpaceFillingCurve
::
GetIndex
(
idx_1d
);
// copy data from src_buf into dst_vector
static_for
<
0
,
DstScalarPerVector
,
1
>
{}([
&
](
auto
i
)
{
// src_desc error, non constexpr, caused by merge transform
constexpr
index_t
src_offset
=
src_desc
.
CalculateOffset
(
src_slice_origin_idx
+
idx_md
+
i
*
dst_scalar_step_in_vector
);
constexpr
index_t
dst_offset
=
dst_desc
.
CalculateOffset
(
dst_slice_origin_idx
+
idx_md
+
i
*
dst_scalar_step_in_vector
);
SrcData
v_this_row
;
// int type temp value due to intrinsic requirement
int
temp
=
0
;
// apply element-wise operation
element_op_
(
v_this_row
,
src_buf
[
Number
<
src_offset
>
{}]);
// apply intra-row permute.
if
constexpr
(
IntraRowSwizzlePerm
)
{
temp
=
__builtin_amdgcn_permlane16
(
temp
,
type_convert_sp
<
int
>
(
v_this_row
),
0xb3a29180
,
0xf7e6d5c4
,
1
,
0
);
v_this_row
=
type_convert_sp
<
SrcData
>
(
temp
);
}
// apply type convert
dst_buf
(
Number
<
dst_offset
>
{})
=
type_convert_sp
<
DstData
>
(
v_this_row
);
});
});
}
ElementwiseOperation
element_op_
{};
};
}
// namespace ck
include/ck/tensor_operation/gpu/warp/smfmac_xdlops_gemm.hpp
0 → 100644
View file @
06701e70
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/utility/math.hpp"
#include "ck/utility/amd_smfmac.hpp"
namespace
ck
{
enum
struct
SmfmacInstr
{
smfmac_f32_16x16x32f16
=
0
,
smfmac_f32_32x32x16f16
,
smfmac_f32_16x16x32bf16
,
smfmac_f32_32x32x16bf16
,
};
template
<
SmfmacInstr
instr
>
struct
smfmac_type
;
template
<
>
struct
smfmac
<
SmfmacInstr
::
smfmac_f32_16x16x32f16
>
{
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_per_blk
=
1
;
static
constexpr
index_t
num_regs_per_blk
=
4
;
static
constexpr
index_t
num_threads_per_blk
=
16
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
num_input_blks
=
4
;
static
constexpr
index_t
num_output_blks
=
1
;
static
constexpr
index_t
m_per_blk
=
16
;
static
constexpr
index_t
n_per_blk
=
16
;
static
constexpr
index_t
k_per_blk
=
8
;
static
constexpr
bool
is_k_reduction
=
true
;
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
const
int32_t
&
idx
,
FloatC
&
reg_c
)
const
{
intrin_smfmac_f32_16x16x32f16
<
MPerXdlops
,
NPerXdlops
>::
Run
(
a
,
b
,
idx
,
reg_c
);
}
};
template
<
>
struct
smfmac
<
SmfmacInstr
::
smfmac_f32_32x32x16f16
>
{
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_per_blk
=
4
;
static
constexpr
index_t
num_regs_per_blk
=
16
;
static
constexpr
index_t
num_threads_per_blk
=
32
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
num_input_blks
=
2
;
static
constexpr
index_t
num_output_blks
=
1
;
static
constexpr
index_t
m_per_blk
=
32
;
static
constexpr
index_t
n_per_blk
=
32
;
static
constexpr
index_t
k_per_blk
=
16
;
static
constexpr
bool
is_k_reduction
=
true
;
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
const
int32_t
&
idx
,
FloatC
&
reg_c
)
const
{
intrin_smfmac_f32_32x32x16f16
<
MPerXdlops
,
NPerXdlops
>::
Run
(
a
,
b
,
idx
,
reg_c
);
}
};
template
<
>
struct
smfmac
<
SmfmacInstr
::
smfmac_f32_16x16x32bf16
>
{
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_per_blk
=
1
;
static
constexpr
index_t
num_regs_per_blk
=
4
;
static
constexpr
index_t
num_threads_per_blk
=
16
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
num_input_blks
=
4
;
static
constexpr
index_t
num_output_blks
=
1
;
static
constexpr
index_t
m_per_blk
=
16
;
static
constexpr
index_t
n_per_blk
=
16
;
static
constexpr
index_t
k_per_blk
=
8
;
static
constexpr
bool
is_k_reduction
=
true
;
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
const
int32_t
&
idx
,
FloatC
&
reg_c
)
const
{
intrin_smfmac_f32_16x16x32bf16
<
MPerXdlops
,
NPerXdlops
>::
Run
(
a
,
b
,
idx
,
reg_c
);
}
};
template
<
>
struct
smfmac
<
SmfmacInstr
::
smfmac_f32_32x32x16bf16
>
{
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_per_blk
=
4
;
static
constexpr
index_t
num_regs_per_blk
=
16
;
static
constexpr
index_t
num_threads_per_blk
=
32
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
num_input_blks
=
2
;
static
constexpr
index_t
num_output_blks
=
1
;
static
constexpr
index_t
m_per_blk
=
32
;
static
constexpr
index_t
n_per_blk
=
32
;
static
constexpr
index_t
k_per_blk
=
16
;
static
constexpr
bool
is_k_reduction
=
true
;
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
const
int32_t
&
idx
,
FloatC
&
reg_c
)
const
{
intrin_smfmac_f32_32x32x16bf16
<
MPerXdlops
,
NPerXdlops
>::
Run
(
a
,
b
,
idx
,
reg_c
);
}
};
template
<
typename
base_type
,
index_t
MPerXdlops
,
index_t
NPerXdlops
,
typename
additional_type
=
base_type
>
struct
SmfmacSelector
{
template
<
typename
base_type_
,
index_t
MPerXdlops_
,
index_t
NPerXdlops_
,
typename
additional_type_
=
base_type_
>
static
constexpr
auto
GetSmfmac
();
template
<
>
static
constexpr
auto
GetSmfmac
<
half_t
,
16
,
16
>
()
{
return
SmfmacInstr
::
smfmac_f32_16x16x32f16
;
}
template
<
>
static
constexpr
auto
GetSmfmac
<
half_t
,
32
,
32
>
()
{
return
SmfmacInstr
::
smfmac_f32_32x32x16f16
;
}
template
<
>
static
constexpr
auto
GetSmfmac
<
bhalf_t
,
16
,
16
>
()
{
return
SmfmacInstr
::
smfmac_f32_16x16x32bf16
;
}
template
<
>
static
constexpr
auto
GetSmfmac
<
bhalf_t
,
32
,
32
>
()
{
return
SmfmacInstr
::
smfmac_f32_32x32x16bf16
;
}
static
constexpr
auto
selected_smfmac
=
smfmac_type
<
GetSmfmac
<
base_type
,
MPerXdlops
,
NPerXdlops
,
additional_type
>
()
>
{};
__host__
__device__
constexpr
SmfmacSelector
()
{
static_assert
(
selected_smfmac
.
group_size
*
selected_smfmac
.
num_groups_per_blk
==
selected_smfmac
.
num_regs_per_blk
,
"wrong! num_regs_per_blk"
);
static_assert
(
selected_smfmac
.
num_threads_per_blk
==
selected_smfmac
.
n_per_blk
,
"n_per_blk != num_threads_per_blk"
);
static_assert
(
selected_smfmac
.
num_regs_per_blk
*
selected_smfmac
.
num_input_blks
==
selected_smfmac
.
m_per_blk
,
"m_per_blk != num_input_blks * num_regs_per_blk"
);
static_assert
(
selected_smfmac
.
num_output_blks
==
selected_smfmac
.
num_input_blks
||
selected_smfmac
.
num_output_blks
==
1
,
"incorrect num_output_blks"
);
static_assert
(
selected_smfmac
.
num_regs_per_blk
*
selected_smfmac
.
wave_size
==
selected_smfmac
.
m_per_blk
*
selected_smfmac
.
n_per_blk
,
"num_regs_per_blk incorrect"
);
static_assert
(
selected_smfmac
.
is_k_reduction
||
(
selected_smfmac
.
num_input_blks
==
selected_smfmac
.
num_output_blks
),
"is_k_reduction wrong!"
);
}
static
constexpr
index_t
GetKPerXdlops
()
{
return
(
selected_smfmac
.
is_k_reduction
?
selected_smfmac
.
num_input_blks
:
1
)
*
selected_smfmac
.
k_per_blk
;
}
static
constexpr
index_t
GetK1PerXdlops
()
{
return
selected_smfmac
.
k_per_blk
;
}
};
template
<
typename
base_type
,
index_t
MPerXdlops
,
index_t
NPerXdlops
,
index_t
KPack
,
typename
additional_type
=
base_type
>
struct
SparseXdlopsGemm
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
using
CIndex
=
MultiIndex
<
2
>
;
using
CIndex4D
=
MultiIndex
<
4
>
;
__device__
static
constexpr
index_t
GetNumBlks
()
{
return
smfmac_instr
.
num_output_blks
;
}
__device__
static
constexpr
index_t
GetNumXdlops
()
{
return
MPerXdlops
*
NPerXdlops
/
(
smfmac_instr
.
m_per_blk
*
smfmac_instr
.
n_per_blk
*
smfmac_instr
.
num_output_blks
);
}
__host__
__device__
constexpr
SparseXdlopsGemm
()
{
static_assert
(
NPerXdlops
==
16
||
NPerXdlops
==
32
,
"Only support GemmNPerXdlops == 16 or 32 for smfmac xdlops"
);
static_assert
(
MPerXdlops
==
16
||
MPerXdlops
==
32
,
"Only support GemmMPerXdlops == 16 or 32 for smfmac xdlops"
);
static_assert
(
KPack
%
smfmac_instr
.
k_per_blk
==
0
,
"KPack cannot be divided by k_per_blk"
);
}
// XDL output supporting C = A * B
// M2_N2 -> M2_M3_M4_N2
template
<
typename
CDesc_M0_N0_M1_N1_M2_N2
>
__host__
__device__
static
constexpr
auto
MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
const
CDesc_M0_N0_M1_N1_M2_N2
&
c_desc_m0_n0_m1_n1_m2_n2
)
{
const
auto
M0
=
c_desc_m0_n0_m1_n1_m2_n2
.
GetLength
(
I0
);
const
auto
N0
=
c_desc_m0_n0_m1_n1_m2_n2
.
GetLength
(
I1
);
const
auto
M1
=
c_desc_m0_n0_m1_n1_m2_n2
.
GetLength
(
I2
);
const
auto
N1
=
c_desc_m0_n0_m1_n1_m2_n2
.
GetLength
(
I3
);
return
transform_tensor_descriptor
(
c_desc_m0_n0_m1_n1_m2_n2
,
make_tuple
(
make_pass_through_transform
(
M0
),
make_pass_through_transform
(
N0
),
make_pass_through_transform
(
M1
),
make_pass_through_transform
(
N1
),
make_unmerge_transform
(
make_tuple
(
Number
<
smfmac_instr
.
num_groups_per_blk
>
{},
Number
<
smfmac_instr
.
num_input_blks
>
{},
Number
<
smfmac_instr
.
group_size
>
{})),
make_pass_through_transform
(
Number
<
smfmac_instr
.
num_threads_per_blk
>
{})),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{},
Sequence
<
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
,
5
,
6
>
{},
Sequence
<
7
>
{}));
}
template
<
typename
CDesc_G_M0_N0_M1_N1_M2_N2
>
__host__
__device__
static
constexpr
auto
MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
(
const
CDesc_G_M0_N0_M1_N1_M2_N2
&
c_desc_g_m0_n0_m1_n1_m2_n2
)
{
const
auto
G
=
c_desc_g_m0_n0_m1_n1_m2_n2
.
GetLength
(
I0
);
const
auto
M0
=
c_desc_g_m0_n0_m1_n1_m2_n2
.
GetLength
(
I1
);
const
auto
N0
=
c_desc_g_m0_n0_m1_n1_m2_n2
.
GetLength
(
I2
);
const
auto
M1
=
c_desc_g_m0_n0_m1_n1_m2_n2
.
GetLength
(
I3
);
const
auto
N1
=
c_desc_g_m0_n0_m1_n1_m2_n2
.
GetLength
(
I4
);
return
transform_tensor_descriptor
(
c_desc_g_m0_n0_m1_n1_m2_n2
,
make_tuple
(
make_pass_through_transform
(
G
),
make_pass_through_transform
(
M0
),
make_pass_through_transform
(
N0
),
make_pass_through_transform
(
M1
),
make_pass_through_transform
(
N1
),
make_unmerge_transform
(
make_tuple
(
smfmac_instr
.
num_groups_per_blk
,
smfmac_instr
.
num_input_blks
,
smfmac_instr
.
group_size
)),
make_pass_through_transform
(
smfmac_instr
.
num_threads_per_blk
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{},
Sequence
<
5
>
{},
Sequence
<
6
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{},
Sequence
<
5
,
6
,
7
>
{},
Sequence
<
8
>
{}));
}
__device__
static
constexpr
index_t
GetRegSizePerXdlops
()
{
return
MPerXdlops
*
NPerXdlops
/
smfmac_instr
.
wave_size
;
}
__device__
static
constexpr
index_t
GetWaveSize
()
{
return
smfmac_instr
.
wave_size
;
}
template
<
class
FloatA
,
class
FloatB
,
class
Idx
,
class
FloatC
>
__device__
void
Run
(
const
FloatA
&
p_a_wave
,
const
FloatB
&
p_b_wave
,
const
Idx
&
idx
,
FloatC
&
p_c_thread
)
const
{
static_assert
(
is_same
<
base_type
,
half_t
>::
value
||
is_same
<
base_type
,
bhalf_t
>::
value
,
"base base_type must be half or bfloat16!"
);
static_for
<
0
,
KPack
/
smfmac_instr
.
k_per_blk
,
1
>
{}([
&
](
auto
k
)
{
smfmac_instr
.
template
run
<
MPerXdlops
,
NPerXdlops
>(
p_a_wave
[
k
],
p_b_wave
[
k
],
idx
[
k
],
p_c_thread
);
});
}
__device__
static
auto
GetLaneId
()
{
return
get_thread_local_1d_id
()
%
smfmac_instr
.
wave_size
;
}
__device__
static
auto
GetBlkIdx
()
{
const
auto
laneId
=
GetLaneId
();
constexpr
auto
threadidx_to_blk_idx_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
1
,
smfmac_instr
.
num_input_blks
,
smfmac_instr
.
num_threads_per_blk
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
blk_idx
=
threadidx_to_blk_idx_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
laneId
));
const
auto
blk_id
=
blk_idx
[
I1
];
const
auto
blk_td
=
blk_idx
[
I2
];
return
make_tuple
(
blk_id
,
blk_td
);
}
__host__
__device__
static
auto
CalculateAThreadOriginDataIndex
()
{
const
auto
laneId
=
GetLaneId
();
const
auto
blk_idx
=
GetBlkIdx
();
const
auto
blk_id
=
blk_idx
[
I0
];
const
auto
blk_td
=
blk_idx
[
I1
];
if
constexpr
(
smfmac_instr
.
is_k_reduction
)
{
return
make_tuple
(
blk_id
,
blk_td
);
}
else
{
return
make_tuple
(
0
,
laneId
);
}
}
__host__
__device__
static
auto
CalculateBThreadOriginDataIndex
()
{
const
auto
laneId
=
GetLaneId
();
const
auto
blk_idx
=
GetBlkIdx
();
const
auto
blk_id
=
blk_idx
[
I0
];
const
auto
blk_td
=
blk_idx
[
I1
];
if
constexpr
(
smfmac_instr
.
is_k_reduction
)
{
return
make_tuple
(
blk_id
,
blk_td
);
}
else
{
return
make_tuple
(
0
,
laneId
);
}
}
__device__
static
CIndex
GetBeginOfThreadBlk
(
index_t
xdlops_i
,
index_t
blk_i
)
{
const
auto
blk_idx
=
GetBlkIdx
();
const
auto
blk_id
=
blk_idx
[
I0
];
const
auto
blk_td
=
blk_idx
[
I1
];
index_t
n_offset
=
blk_i
*
smfmac_instr
.
n_per_blk
+
blk_td
;
index_t
m_offset
=
xdlops_i
*
smfmac_instr
.
m_per_blk
+
blk_id
*
smfmac_instr
.
group_size
;
return
CIndex
{
m_offset
,
n_offset
};
}
__device__
static
CIndex4D
GetBeginOfThreadBlk4D
(
index_t
/* xdlops_i */
,
index_t
/* blk_i */
)
{
const
auto
blk_idx
=
GetBlkIdx
();
const
auto
blk_id
=
blk_idx
[
I0
];
const
auto
blk_td
=
blk_idx
[
I1
];
return
CIndex4D
{
I0
,
blk_id
,
I0
,
blk_td
};
}
static
constexpr
auto
smfmac
=
SmfmacSelector
<
base_type
,
MPerXdlops
,
NPerXdlops
,
additional_type
>
{};
static
constexpr
auto
smfmac_instr
=
smfmac
.
selected_smfmac
;
static
constexpr
auto
KPerXdlops
=
smfmac
.
GetKPerXdlops
();
static
constexpr
auto
K1PerXdlops
=
smfmac
.
GetK1PerXdlops
();
static
constexpr
auto
K0PerXdlops
=
KPerXdlops
/
K1PerXdlops
;
__host__
__device__
static
constexpr
auto
GetCM0M1M2NThreadBlkLengths
()
{
return
make_tuple
(
Number
<
smfmac_instr
.
num_groups_per_blk
>
{},
I1
,
Number
<
smfmac_instr
.
group_size
>
{},
I1
);
}
};
}
// namespace ck
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
View file @
06701e70
...
...
@@ -11,12 +11,17 @@ namespace ck {
enum
struct
WmmaInstr
{
// gfx11
wmma_f32_16x16x16_f16
=
0
,
wmma_f32_16x16x16_bf16
,
wmma_f16_16x16x16_f16
,
wmma_bf16_16x16x16_bf16
,
wmma_i32_16x16x16_iu8
,
wmma_i32_16x16x16_iu4
wmma_i32_16x16x16_iu4
,
// gfx12
wmma_f32_16x16x16_f16_gfx12
,
wmma_f32_16x16x16_bf16_gfx12
,
wmma_i32_16x16x16_iu8_gfx12
,
};
/*
...
...
@@ -279,6 +284,122 @@ struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8,
}
};
// gfx12
// A-swizzled
template
<
index_t
WaveSize
>
struct
wmma_type
<
WmmaInstr
::
wmma_f32_16x16x16_f16_gfx12
,
WaveSize
,
typename
std
::
enable_if_t
<
WaveSize
==
32
||
WaveSize
==
64
>>
{
// Absolute fixing property
// * Data Pixel
static
constexpr
index_t
m_per_wmma
=
16
;
static
constexpr
index_t
n_per_wmma
=
16
;
static
constexpr
index_t
k_per_wmma
=
16
;
// static constexpr index_t src_a_data_size = 2;
// static constexpr index_t src_b_data_size = 2;
// static constexpr index_t acc_data_size = 4;
// * Thread mapping inside wave, num_thread_per_subgroups always alone N direction
static
constexpr
index_t
acc_data_size
=
4
;
static
constexpr
index_t
acc_pack_number
=
1
;
static
constexpr
index_t
num_thread_per_subgroups
=
n_per_wmma
;
// Wave mode dependent propety
static
constexpr
index_t
wave_size
=
Number
<
WaveSize
>
{};
// * Fixed in Navi3x, Will be wave mode dependent on Navi4x
// static constexpr index_t num_src_a_vgprs_per_wave = k_per_wmma / 2 * src_a_data_size / 4;
// static constexpr index_t num_src_b_vgprs_per_wave = k_per_wmma / 2 * src_b_data_size / 4;
// * num_acc_vgprs_per_wave alone M direction
// * num_subgroups alone M direction
static
constexpr
index_t
num_acc_vgprs_per_wave
=
m_per_wmma
*
n_per_wmma
/
wave_size
;
static
constexpr
index_t
num_subgroups
=
wave_size
/
num_thread_per_subgroups
;
template
<
index_t
MPerWmma
,
index_t
NPerWmma
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
FloatC
&
reg_c
)
const
{
static_assert
(
wave_size
==
32
,
"only support wave32 for gfx12 wmma"
);
if
constexpr
(
wave_size
==
32
)
{
intrin_wmma_f32_16x16x16_f16_w32_gfx12
<
MPerWmma
,
NPerWmma
>::
Run
(
a
,
b
,
reg_c
);
}
}
};
template
<
index_t
WaveSize
>
struct
wmma_type
<
WmmaInstr
::
wmma_f32_16x16x16_bf16_gfx12
,
WaveSize
,
typename
std
::
enable_if_t
<
WaveSize
==
32
||
WaveSize
==
64
>>
{
// Absolute fixing property
static
constexpr
index_t
m_per_wmma
=
16
;
static
constexpr
index_t
n_per_wmma
=
16
;
static
constexpr
index_t
k_per_wmma
=
16
;
// static constexpr index_t src_a_data_size = 2;
// static constexpr index_t src_b_data_size = 2;
static
constexpr
index_t
acc_data_size
=
4
;
static
constexpr
index_t
acc_pack_number
=
1
;
static
constexpr
index_t
num_thread_per_subgroups
=
n_per_wmma
;
// Wave mode dependent propety
static
constexpr
index_t
wave_size
=
Number
<
WaveSize
>
{};
// static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
// static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
static
constexpr
index_t
num_acc_vgprs_per_wave
=
m_per_wmma
*
n_per_wmma
/
wave_size
;
static
constexpr
index_t
num_subgroups
=
wave_size
/
num_thread_per_subgroups
;
template
<
index_t
MPerWmma
,
index_t
NPerWmma
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
FloatC
&
reg_c
)
const
{
static_assert
(
wave_size
==
32
,
"only support wave32 for gfx12 wmma"
);
if
constexpr
(
wave_size
==
32
)
{
intrin_wmma_f32_16x16x16_bf16_w32_gfx12
<
MPerWmma
,
NPerWmma
>::
Run
(
a
,
b
,
reg_c
);
}
}
};
template
<
index_t
WaveSize
>
struct
wmma_type
<
WmmaInstr
::
wmma_i32_16x16x16_iu8_gfx12
,
WaveSize
,
typename
std
::
enable_if_t
<
WaveSize
==
32
||
WaveSize
==
64
>>
{
// Absolute fixing property
static
constexpr
index_t
m_per_wmma
=
16
;
static
constexpr
index_t
n_per_wmma
=
16
;
static
constexpr
index_t
k_per_wmma
=
16
;
// static constexpr index_t src_a_data_size = 2;
// static constexpr index_t src_b_data_size = 2;
static
constexpr
index_t
acc_data_size
=
4
;
static
constexpr
index_t
acc_pack_number
=
1
;
static
constexpr
index_t
num_thread_per_subgroups
=
n_per_wmma
;
// Wave mode dependent propety
static
constexpr
index_t
wave_size
=
Number
<
WaveSize
>
{};
// static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
// static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
static
constexpr
index_t
num_acc_vgprs_per_wave
=
m_per_wmma
*
n_per_wmma
/
wave_size
;
static
constexpr
index_t
num_subgroups
=
wave_size
/
num_thread_per_subgroups
;
template
<
index_t
MPerWmma
,
index_t
NPerWmma
,
class
FloatA
,
class
FloatB
,
class
FloatC
,
bool
neg_a
=
false
,
bool
neg_b
=
false
,
bool
clamp
=
false
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
FloatC
&
reg_c
)
const
{
static_assert
(
wave_size
==
32
,
"only support wave32 for gfx12 wmma"
);
if
constexpr
(
wave_size
==
32
)
{
intrin_wmma_i32_16x16x16_iu8_w32_gfx12
<
MPerWmma
,
NPerWmma
,
neg_a
,
neg_b
,
clamp
>::
Run
(
a
,
b
,
reg_c
);
}
}
};
template
<
typename
src_type_a
,
typename
src_type_b
,
typename
dst_type
,
...
...
@@ -296,13 +417,21 @@ struct WmmaSelector
template
<
>
static
constexpr
auto
GetWmma
<
half_t
,
half_t
,
float
,
16
,
16
>
()
{
#ifdef __gfx12__
return
WmmaInstr
::
wmma_f32_16x16x16_f16_gfx12
;
#else
return
WmmaInstr
::
wmma_f32_16x16x16_f16
;
#endif
}
template
<
>
static
constexpr
auto
GetWmma
<
bhalf_t
,
bhalf_t
,
float
,
16
,
16
>
()
{
#ifdef __gfx12__
return
WmmaInstr
::
wmma_f32_16x16x16_bf16_gfx12
;
#else
return
WmmaInstr
::
wmma_f32_16x16x16_bf16
;
#endif
}
template
<
>
...
...
@@ -320,8 +449,13 @@ struct WmmaSelector
template
<
>
static
constexpr
auto
GetWmma
<
int8_t
,
int8_t
,
int
,
16
,
16
>
()
{
#ifdef __gfx12__
return
WmmaInstr
::
wmma_i32_16x16x16_iu8_gfx12
;
#else
return
WmmaInstr
::
wmma_i32_16x16x16_iu8
;
#endif
}
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template
<
>
static
constexpr
auto
GetWmma
<
int4_t
,
int4_t
,
int
,
16
,
16
>
()
...
...
@@ -502,6 +636,9 @@ struct WmmaGemm
__device__
static
auto
GetSubGroupId
()
{
static_assert
(
wmma_instr
.
num_thread_per_subgroups
*
wmma_instr
.
num_subgroups
==
wmma_instr
.
wave_size
,
""
);
return
(
GetLaneId
()
/
wmma_instr
.
num_thread_per_subgroups
)
%
wmma_instr
.
num_subgroups
;
}
...
...
@@ -516,12 +653,20 @@ struct WmmaGemm
__host__
__device__
static
auto
CalculateAThreadOriginDataIndex
()
{
#ifdef __gfx12__
return
GetLaneIdUnderSubGroup
();
#else
return
TransposeC
?
GetLaneIdUnderSubGroup
()
:
GetSwizzledLaneIdLow
();
#endif
}
__host__
__device__
static
auto
CalculateBThreadOriginDataIndex
()
{
#ifdef __gfx12__
return
GetLaneIdUnderSubGroup
();
#else
return
TransposeC
?
GetSwizzledLaneIdLow
()
:
GetLaneIdUnderSubGroup
();
#endif
}
__device__
static
CIndex
GetBeginOfThreadBlk
()
...
...
include/ck/utility/amd_smfmac.hpp
0 → 100644
View file @
06701e70
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#pragma once
namespace
ck
{
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_smfmac_f32_16x16x32f16
;
template
<
>
struct
intrin_smfmac_f32_16x16x32f16
<
16
,
16
>
{
template
<
class
FloatC
>
__device__
static
void
Run
(
const
half4_t
&
reg_a
,
const
half8_t
&
reg_b
,
const
int32_t
&
reg_idx
,
FloatC
&
reg_c
)
{
#if defined(__gfx94__)
reg_c
.
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_smfmac_f32_16x16x32_f16
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float4_t
>()[
Number
<
0
>
{}],
reg_idx
,
0
,
0
);
#else
ignore
=
reg_a
;
ignore
=
reg_b
;
ignore
=
reg_c
;
ignore
=
reg_idx
;
#endif
}
};
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_smfmac_f32_16x16x32bf16
;
template
<
>
struct
intrin_smfmac_f32_16x16x32bf16
<
16
,
16
>
{
template
<
class
FloatC
>
__device__
static
void
Run
(
const
bhalf4_t
&
reg_a
,
const
bhalf8_t
&
reg_b
,
const
int32_t
&
reg_idx
,
FloatC
&
reg_c
)
{
#if defined(__gfx94__)
reg_c
.
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_smfmac_f32_16x16x32_bf16
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float4_t
>()[
Number
<
0
>
{}],
reg_idx
,
0
,
0
);
#else
ignore
=
reg_a
;
ignore
=
reg_b
;
ignore
=
reg_c
;
ignore
=
reg_idx
;
#endif
}
};
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_smfmac_f32_32x32x16f16
;
template
<
>
struct
intrin_smfmac_f32_32x32x16f16
<
32
,
32
>
{
template
<
class
FloatC
>
__device__
static
void
Run
(
const
half4_t
&
reg_a
,
const
half8_t
&
reg_b
,
const
int32_t
&
reg_idx
,
FloatC
&
reg_c
)
{
#if defined(__gfx94__)
reg_c
.
template
AsType
<
float16_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_smfmac_f32_32x32x16_f16
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float16_t
>()[
Number
<
0
>
{}],
reg_idx
,
0
,
0
);
#else
ignore
=
reg_a
;
ignore
=
reg_b
;
ignore
=
reg_c
;
ignore
=
reg_idx
;
#endif
}
};
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_smfmac_f32_32x32x16bf16
;
template
<
>
struct
intrin_smfmac_f32_32x32x16bf16
<
32
,
32
>
{
template
<
class
FloatC
>
__device__
static
void
Run
(
const
bhalf4_t
&
reg_a
,
const
bhalf8_t
&
reg_b
,
const
int32_t
&
reg_idx
,
FloatC
&
reg_c
)
{
#if defined(__gfx94__)
reg_c
.
template
AsType
<
float16_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_smfmac_f32_32x32x16_bf16
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float16_t
>()[
Number
<
0
>
{}],
reg_idx
,
0
,
0
);
#else
ignore
=
reg_a
;
ignore
=
reg_b
;
ignore
=
reg_c
;
ignore
=
reg_idx
;
#endif
}
};
}
// namespace ck
include/ck/utility/amd_wmma.hpp
View file @
06701e70
...
...
@@ -257,5 +257,87 @@ struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp>
}
};
// gfx12
/********************************WAVE32 MODE***********************************************/
#if defined(__gfx1200__) || defined(__gfx1201__)
#define __gfx12__
#endif
// src: fp16, dst: fp32
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_wmma_f32_16x16x16_f16_w32_gfx12
;
template
<
>
struct
intrin_wmma_f32_16x16x16_f16_w32_gfx12
<
16
,
16
>
{
template
<
class
FloatC
>
__device__
static
void
Run
(
const
half8_t
&
reg_a
,
const
half8_t
&
reg_b
,
FloatC
&
reg_c
)
{
// * Inline assembly need to elimate the duplicated data load, compiler won't help you
// delete them.
// amd_assembly_wmma_f32_16x16x16_f16_w32(
// reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
#if defined(__gfx12__)
reg_c
.
template
AsType
<
float8_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float8_t
>()[
Number
<
0
>
{}]);
#else
ignore
=
reg_a
;
ignore
=
reg_b
;
ignore
=
reg_c
;
#endif
}
};
// src: bf16, dst: fp32
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_wmma_f32_16x16x16_bf16_w32_gfx12
;
template
<
>
struct
intrin_wmma_f32_16x16x16_bf16_w32_gfx12
<
16
,
16
>
{
template
<
class
FloatC
>
__device__
static
void
Run
(
const
bhalf8_t
&
reg_a
,
const
bhalf8_t
&
reg_b
,
FloatC
&
reg_c
)
{
#if defined(__gfx12__)
reg_c
.
template
AsType
<
float8_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float8_t
>()[
Number
<
0
>
{}]);
#else
ignore
=
reg_a
;
ignore
=
reg_b
;
ignore
=
reg_c
;
#endif
}
};
// src: iu8, dst: i32
template
<
index_t
MPerWave
,
index_t
NPerWave
,
bool
neg_a
,
bool
neg_b
,
bool
clamp
>
struct
intrin_wmma_i32_16x16x16_iu8_w32_gfx12
;
template
<
bool
neg_a
,
bool
neg_b
,
bool
clamp
>
struct
intrin_wmma_i32_16x16x16_iu8_w32_gfx12
<
16
,
16
,
neg_a
,
neg_b
,
clamp
>
{
template
<
class
FloatC
>
__device__
static
void
Run
(
const
int8x8_t
&
reg_a
,
const
int8x8_t
&
reg_b
,
FloatC
&
reg_c
)
{
#if defined(__gfx12__)
reg_c
.
template
AsType
<
int32x8_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12
(
neg_a
,
bit_cast
<
int32x2_t
>
(
reg_a
),
neg_b
,
bit_cast
<
int32x2_t
>
(
reg_b
),
reg_c
.
template
AsType
<
int32x8_t
>()[
Number
<
0
>
{}],
clamp
);
#else
ignore
=
reg_a
;
ignore
=
reg_b
;
ignore
=
reg_c
;
#endif
}
};
}
// namespace ck
#endif
include/ck/utility/data_type.hpp
View file @
06701e70
...
...
@@ -219,7 +219,7 @@ struct vector_type<T, 1, typename std::enable_if_t<is_native_type<T>()>>
}
};
int
static
err
=
0
;
__device__
int
static
err
=
0
;
template
<
typename
T
>
struct
vector_type
<
T
,
2
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>>
{
...
...
include/ck/utility/synchronization.hpp
View file @
06701e70
...
...
@@ -10,12 +10,20 @@ namespace ck {
__device__
void
block_sync_lds
()
{
#if CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
#ifdef __gfx12__
asm
volatile
(
"\
s_wait_dscnt 0x0
\n
\
s_barrier_signal -1
\n
\
s_barrier_wait -1 \
"
::
);
#else
// asm volatile("\
// s_waitcnt lgkmcnt(0) \n \
// s_barrier \
// " ::);
__builtin_amdgcn_s_waitcnt
(
0xc07f
);
__builtin_amdgcn_s_barrier
();
#endif
#else
__syncthreads
();
#endif
...
...
@@ -23,11 +31,20 @@ __device__ void block_sync_lds()
__device__
void
block_sync_lds_direct_load
()
{
#ifdef __gfx12__
asm
volatile
(
"\
s_wait_vmcnt 0x0
\n
\
s_wait_dscnt 0x0
\n
\
s_barrier_signal -1
\n
\
s_barrier_wait -1 \
"
::
);
#else
asm
volatile
(
"\
s_waitcnt vmcnt(0)
\n
\
s_waitcnt lgkmcnt(0)
\n
\
s_barrier \
"
::
);
#endif
}
__device__
void
s_nop
()
...
...
include/ck_tile/core/arch/amd_buffer_addressing.hpp
View file @
06701e70
This diff is collapsed.
Click to expand it.
include/ck_tile/core/arch/arch.hpp
View file @
06701e70
...
...
@@ -82,14 +82,12 @@ CK_TILE_DEVICE void block_sync_lds_direct_load()
"
::
);
}
CK_TILE_DEVICE
void
s_nop
()
CK_TILE_DEVICE
void
s_nop
(
index_t
cnt
=
0
)
{
#if 1
asm
volatile
(
"\
s_nop 0
\n
\
"
::
);
asm
volatile
(
"s_nop %0"
:
:
"n"
(
cnt
)
:
);
#else
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_sched_barrier
(
cnt
);
#endif
}
...
...
include/ck_tile/core/config.hpp
View file @
06701e70
...
...
@@ -17,7 +17,11 @@
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__)
#define __gfx11__
#endif
#if defined(__gfx1200__) || defined(__gfx1201__)
#define __gfx12__
#endif
#include "hip/hip_version.h"
#ifndef CK_TILE_DONT_USE_HIP_RUNTIME_HEADERS
#include "hip/hip_runtime.h"
#include "hip/hip_fp16.h"
...
...
@@ -144,6 +148,14 @@
#define CK_TILE_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1
#endif
#ifndef CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
#if HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 1 && HIP_VERSION_PATCH >= 40091
#define CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE 1
#else
#define CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE 0
#endif
#endif
#ifndef CK_TILE_DEBUG_LOG
#define CK_TILE_DEBUG_LOG 0
#endif
...
...
@@ -155,7 +167,7 @@
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x00020000
#elif defined(__gfx103__) // for GPU code
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31014000
#elif defined(__gfx11__) // for GPU code
#elif defined(__gfx11__)
|| defined(__gfx12__)
// for GPU code
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31004000
#endif
...
...
include/ck_tile/core/tensor/buffer_view.hpp
View file @
06701e70
...
...
@@ -69,6 +69,8 @@ struct buffer_view<address_space_enum::generic,
{
}
CK_TILE_HOST_DEVICE
void
init_raw
()
{}
CK_TILE_DEVICE
static
constexpr
address_space_enum
get_address_space
()
{
return
address_space_enum
::
generic
;
...
...
@@ -224,25 +226,36 @@ struct buffer_view<address_space_enum::global,
T
*
p_data_
=
nullptr
;
BufferSizeType
buffer_size_
;
int32x4_t
cached_buf_res_
;
remove_cvref_t
<
T
>
invalid_element_value_
=
T
{
0
};
CK_TILE_HOST_DEVICE
constexpr
buffer_view
()
:
p_data_
{},
buffer_size_
{},
invalid_element_value_
{}
:
p_data_
{},
buffer_size_
{},
cached_buf_res_
{
0
},
invalid_element_value_
{}
{
}
CK_TILE_HOST_DEVICE
constexpr
buffer_view
(
T
*
p_data
,
BufferSizeType
buffer_size
)
:
p_data_
{
p_data
},
buffer_size_
{
buffer_size
},
invalid_element_value_
{
0
}
:
p_data_
{
p_data
},
buffer_size_
{
buffer_size
},
cached_buf_res_
{
0
},
invalid_element_value_
{
0
}
{
}
CK_TILE_HOST_DEVICE
constexpr
buffer_view
(
T
*
p_data
,
BufferSizeType
buffer_size
,
T
invalid_element_value
)
:
p_data_
{
p_data
},
buffer_size_
{
buffer_size
},
invalid_element_value_
{
invalid_element_value
}
:
p_data_
{
p_data
},
buffer_size_
{
buffer_size
},
cached_buf_res_
{
0
},
invalid_element_value_
{
invalid_element_value
}
{
}
// this is non constexpr intentially (will call some intrinsic internally)
// Must call for buffers that need *_raw load/store
CK_TILE_HOST_DEVICE
void
init_raw
()
{
cached_buf_res_
=
make_wave_buffer_resource
(
p_data_
,
buffer_size_
*
sizeof
(
type
));
}
CK_TILE_DEVICE
static
constexpr
address_space_enum
get_address_space
()
{
return
address_space_enum
::
global
;
...
...
@@ -333,12 +346,15 @@ struct buffer_view<address_space_enum::global,
// i is offset of T, not X. i should be aligned to X
template
<
typename
X
,
bool
oob_conditional_check
=
true
,
bool
pre_nop
=
false
,
typename
std
::
enable_if
<
std
::
is_same
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
typename
vector_traits
<
remove_cvref_t
<
T
>>::
scalar_type
>::
value
,
bool
>::
type
=
false
>
CK_TILE_DEVICE
constexpr
auto
get_raw
(
remove_cvref_t
<
X
>&
dst
,
index_t
i
,
bool
is_valid_element
)
const
CK_TILE_DEVICE
constexpr
auto
get_raw
(
remove_cvref_t
<
X
>&
dst
,
index_t
i
,
bool
is_valid_element
,
bool_constant
<
pre_nop
>
=
{})
const
{
constexpr
index_t
scalar_per_t_vector
=
vector_traits
<
remove_cvref_t
<
T
>>::
vector_size
;
...
...
@@ -349,18 +365,21 @@ struct buffer_view<address_space_enum::global,
constexpr
index_t
t_per_x
=
scalar_per_x_vector
/
scalar_per_t_vector
;
amd_buffer_load_raw
<
remove_cvref_t
<
T
>
,
t_per_x
,
Coherence
,
oob_conditional_check
>
(
dst
,
p_data_
,
i
,
buffer_size_
,
is_valid_element
);
amd_buffer_load_raw
<
remove_cvref_t
<
T
>
,
t_per_x
,
Coherence
,
oob_conditional_check
,
pre_nop
>
(
dst
,
cached_buf_res_
,
i
,
is_valid_element
,
bool_constant
<
pre_nop
>
{}
);
}
// i is offset of T, not X. i should be aligned to X
template
<
typename
X
,
bool
pre_nop
=
false
,
typename
std
::
enable_if
<
std
::
is_same
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
typename
vector_traits
<
remove_cvref_t
<
T
>>::
scalar_type
>::
value
,
bool
>::
type
=
false
>
CK_TILE_DEVICE
constexpr
auto
async_get
(
remove_cvref_t
<
T
>*
smem
,
index_t
i
,
bool
/*is_valid_element*/
)
const
CK_TILE_DEVICE
constexpr
auto
async_get_raw
(
remove_cvref_t
<
T
>*
smem
,
index_t
i
,
bool
/*is_valid_element*/
,
bool_constant
<
pre_nop
>
=
{})
const
{
// X is vector of T
constexpr
index_t
scalar_per_t_vector
=
vector_traits
<
remove_cvref_t
<
T
>>::
vector_size
;
...
...
@@ -371,8 +390,8 @@ struct buffer_view<address_space_enum::global,
constexpr
index_t
t_per_x
=
scalar_per_x_vector
/
scalar_per_t_vector
;
amd_async_buffer_load_with_oob
<
remove_cvref_t
<
T
>
,
t_per_x
,
Coherence
>
(
smem
,
p_data_
,
i
,
buffer_size_
);
amd_async_buffer_load_with_oob
_raw
<
remove_cvref_t
<
T
>
,
t_per_x
,
Coherence
>
(
smem
,
cached_buf_res_
,
i
,
bool_constant
<
pre_nop
>
{}
);
}
// i is offset of T, not X. i should be aligned to X
...
...
@@ -627,6 +646,8 @@ struct buffer_view<address_space_enum::lds,
{
}
CK_TILE_HOST_DEVICE
void
init_raw
()
{}
CK_TILE_DEVICE
static
constexpr
address_space_enum
get_address_space
()
{
return
address_space_enum
::
lds
;
...
...
@@ -909,6 +930,8 @@ struct buffer_view<address_space_enum::vgpr,
{
}
CK_TILE_HOST_DEVICE
void
init_raw
()
{}
CK_TILE_DEVICE
static
constexpr
address_space_enum
get_address_space
()
{
return
address_space_enum
::
vgpr
;
...
...
Prev
1
2
3
4
5
6
7
8
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment