Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
f18cfec4
Unverified
Commit
f18cfec4
authored
Feb 14, 2025
by
Haocong WANG
Committed by
GitHub
Feb 14, 2025
Browse files
Merge branch 'develop' into update_cka8w8_uc
parents
4599ee00
0e5e29c4
Changes
163
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
368 additions
and
81 deletions
+368
-81
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp
..._batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp
+19
-5
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
...id/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
+9
-1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp
...pu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp
+9
-1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp
...grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp
+9
-1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp
...gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp
+10
-3
include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp
...eration/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp
+9
-1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp
...pu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp
+18
-4
include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp
...grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp
+10
-3
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp
...ion/gpu/grid/gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp
+9
-2
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp
...ration/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp
+9
-2
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp
...nsor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp
+10
-3
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
...nsor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
+9
-2
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp
...ration/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp
+13
-4
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp
...tion/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp
+9
-1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp
...tion/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp
+10
-3
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp
+9
-1
include/ck_tile/core.hpp
include/ck_tile/core.hpp
+1
-1
include/ck_tile/core/utility/transpose_vectors.hpp
include/ck_tile/core/utility/transpose_vectors.hpp
+73
-43
include/ck_tile/host.hpp
include/ck_tile/host.hpp
+1
-0
include/ck_tile/host/concat.hpp
include/ck_tile/host/concat.hpp
+122
-0
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp
View file @
f18cfec4
...
...
@@ -343,10 +343,16 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle
const
auto
M
=
d0_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
d0_grid_desc_m_n
.
GetLength
(
I1
);
constexpr
auto
mfma
=
MfmaSelector
<
FloatAB
,
MPerXdl
,
NPerXdl
>::
selected_mfma
;
constexpr
auto
N3
=
mfma
.
num_groups_per_blk
;
constexpr
auto
N4
=
mfma
.
num_input_blks
;
constexpr
auto
N5
=
mfma
.
group_size
;
constexpr
bool
is_single_rate_mfma
=
((
is_same
<
FloatAB
,
half_t
>::
value
||
is_same
<
FloatAB
,
bhalf_t
>::
value
)
&&
math
::
lcm
(
AK1
,
BK1
)
<=
4
)
?
true
:
false
;
constexpr
auto
mfma
=
MfmaSelector
<
FloatAB
,
MPerXdl
,
NPerXdl
,
FloatAB
,
is_single_rate_mfma
>::
selected_mfma
;
constexpr
auto
N3
=
mfma
.
num_groups_per_blk
;
constexpr
auto
N4
=
mfma
.
num_input_blks
;
constexpr
auto
N5
=
mfma
.
group_size
;
return
transform_tensor_descriptor
(
d0_grid_desc_m_n
,
make_tuple
(
make_unmerge_transform
(
...
...
@@ -552,8 +558,16 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle
// acc1[m][o] += acc[m][n] * B1[n][o]
// sanity check
constexpr
auto
lcm_AK1_BK1
=
math
::
lcm
(
AK1
,
BK1
);
constexpr
bool
is_single_rate_mfma
=
((
is_same
<
FloatAB
,
half_t
>::
value
||
is_same
<
FloatAB
,
bhalf_t
>::
value
)
&&
lcm_AK1_BK1
<=
4
)
?
true
:
false
;
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1
,
BK1
),
MfmaSelector
<
FloatAB
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
lcm_AK1_BK1
,
MfmaSelector
<
FloatAB
,
MPerXdl
,
NPerXdl
,
FloatAB
,
is_single_rate_mfma
>::
selected_mfma
.
k_per_blk
);
auto
blockwise_gemm
=
BlockwiseGemmXdlops_v2
<
BlockSize
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
View file @
f18cfec4
...
...
@@ -469,8 +469,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
// acc1[m][o] += acc[m][n] * B1[n][o]
// sanity check
constexpr
auto
lcm_AK1_BK1
=
math
::
lcm
(
AK1
,
BK1
);
constexpr
bool
is_single_rate_mfma
=
((
is_same
<
FloatAB
,
half_t
>::
value
||
is_same
<
FloatAB
,
bhalf_t
>::
value
)
&&
lcm_AK1_BK1
<=
4
)
?
true
:
false
;
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1
,
BK1
),
MfmaSelector
<
FloatAB
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
lcm_AK1_BK1
,
MfmaSelector
<
FloatAB
,
MPerXdl
,
NPerXdl
,
FloatAB
,
is_single_rate_mfma
>::
selected_mfma
.
k_per_blk
);
auto
blockwise_gemm
=
BlockwiseGemmXdlops_v2
<
BlockSize
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp
View file @
f18cfec4
...
...
@@ -498,8 +498,16 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
constexpr
auto
lcm_AK1_BK1
=
math
::
lcm
(
AK1
,
BK1
);
constexpr
bool
is_single_rate_mfma
=
((
is_same
<
FloatAB
,
half_t
>::
value
||
is_same
<
FloatAB
,
bhalf_t
>::
value
)
&&
lcm_AK1_BK1
<=
4
)
?
true
:
false
;
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1
,
BK1
),
MfmaSelector
<
FloatAB
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
lcm_AK1_BK1
,
MfmaSelector
<
FloatAB
,
MPerXdl
,
NPerXdl
,
FloatAB
,
is_single_rate_mfma
>::
selected_mfma
.
k_per_blk
);
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector
<
BlockSize
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp
View file @
f18cfec4
...
...
@@ -464,8 +464,16 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
constexpr
auto
lcm_AK1_BK1
=
math
::
lcm
(
AK1
,
BK1
);
constexpr
bool
is_single_rate_mfma
=
((
is_same
<
FloatAB
,
half_t
>::
value
||
is_same
<
FloatAB
,
bhalf_t
>::
value
)
&&
lcm_AK1_BK1
<=
4
)
?
true
:
false
;
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1
,
BK1
),
MfmaSelector
<
FloatAB
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
lcm_AK1_BK1
,
MfmaSelector
<
FloatAB
,
MPerXdl
,
NPerXdl
,
FloatAB
,
is_single_rate_mfma
>::
selected_mfma
.
k_per_blk
);
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector
<
BlockSize
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp
View file @
f18cfec4
...
...
@@ -599,9 +599,16 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1
,
BK1
),
MfmaSelector
<
AComputeType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
constexpr
auto
lcm_AK1_BK1
=
math
::
lcm
(
AK1
,
BK1
);
constexpr
bool
is_single_rate_mfma
=
((
is_same
<
AComputeType
,
half_t
>::
value
||
is_same
<
AComputeType
,
bhalf_t
>::
value
)
&&
lcm_AK1_BK1
<=
4
)
?
true
:
false
;
constexpr
index_t
KPack
=
math
::
max
(
lcm_AK1_BK1
,
MfmaSelector
<
AComputeType
,
MPerXdl
,
NPerXdl
,
AComputeType
,
is_single_rate_mfma
>::
selected_mfma
.
k_per_blk
);
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector
<
BlockSize
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp
View file @
f18cfec4
...
...
@@ -451,8 +451,16 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
constexpr
auto
lcm_AK1_BK1
=
math
::
lcm
(
AK1
,
BK1
);
constexpr
bool
is_single_rate_mfma
=
((
is_same
<
FloatAB
,
half_t
>::
value
||
is_same
<
FloatAB
,
bhalf_t
>::
value
)
&&
lcm_AK1_BK1
<=
4
)
?
true
:
false
;
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1
,
BK1
),
MfmaSelector
<
FloatAB
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
lcm_AK1_BK1
,
MfmaSelector
<
FloatAB
,
MPerXdl
,
NPerXdl
,
FloatAB
,
is_single_rate_mfma
>::
selected_mfma
.
k_per_blk
);
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector
<
BlockSize
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp
View file @
f18cfec4
...
...
@@ -581,9 +581,16 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
constexpr
auto
lcm_AK1_BK1
=
math
::
lcm
(
AK1
,
BK1
);
constexpr
bool
is_single_rate_mfma
=
((
is_same
<
ABDataType
,
half_t
>::
value
||
is_same
<
ABDataType
,
bhalf_t
>::
value
)
&&
lcm_AK1_BK1
<=
4
)
?
true
:
false
;
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1
,
BK1
),
MfmaSelector
<
ABDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
math
::
max
(
lcm_AK1_BK1
,
MfmaSelector
<
ABDataType
,
MPerXdl
,
NPerXdl
,
ABDataType
,
is_single_rate_mfma
>::
selected_mfma
.
k_per_blk
);
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector
<
BlockSize
,
...
...
@@ -1006,9 +1013,16 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
constexpr
auto
lcm_AK1_BK1
=
math
::
lcm
(
AK1
,
BK1
);
constexpr
bool
is_single_rate_mfma
=
((
is_same
<
ABDataType
,
half_t
>::
value
||
is_same
<
ABDataType
,
bhalf_t
>::
value
)
&&
lcm_AK1_BK1
<=
4
)
?
true
:
false
;
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1
,
BK1
),
MfmaSelector
<
ABDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
math
::
max
(
lcm_AK1_BK1
,
MfmaSelector
<
ABDataType
,
MPerXdl
,
NPerXdl
,
ABDataType
,
is_single_rate_mfma
>::
selected_mfma
.
k_per_blk
);
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector
<
BlockSize
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp
View file @
f18cfec4
...
...
@@ -595,9 +595,16 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1
,
BK1
),
MfmaSelector
<
ComputeType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
constexpr
auto
lcm_AK1_BK1
=
math
::
lcm
(
AK1
,
BK1
);
constexpr
bool
is_single_rate_mfma
=
((
is_same
<
ComputeType
,
half_t
>::
value
||
is_same
<
ComputeType
,
bhalf_t
>::
value
)
&&
lcm_AK1_BK1
<=
4
)
?
true
:
false
;
constexpr
index_t
KPack
=
math
::
max
(
lcm_AK1_BK1
,
MfmaSelector
<
ComputeType
,
MPerXdl
,
NPerXdl
,
ComputeType
,
is_single_rate_mfma
>::
selected_mfma
.
k_per_blk
);
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector
<
BlockSize
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp
View file @
f18cfec4
...
...
@@ -79,9 +79,16 @@ struct GridwiseGemm_xdl_cshuffle_v3
static
constexpr
auto
AK1Number
=
Number
<
AK1Value
>
{};
static
constexpr
auto
BK1Number
=
Number
<
BK1Value
>
{};
static
constexpr
auto
lcm_AK1_BK1
=
math
::
lcm
(
AK1Number
,
BK1Number
);
static
constexpr
bool
is_single_rate_mfma
=
((
is_same
<
ComputeTypeA
,
half_t
>::
value
||
is_same
<
ComputeTypeA
,
bhalf_t
>::
value
)
&&
lcm_AK1_BK1
<=
4
)
?
true
:
false
;
static
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1Number
,
BK1Number
),
MfmaSelector
<
ComputeTypeA
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
math
::
max
(
lcm_AK1_BK1
,
MfmaSelector
<
ComputeTypeA
,
MPerXdl
,
NPerXdl
,
ComputeTypeA
,
is_single_rate_mfma
>::
selected_mfma
.
k_per_blk
);
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp
100755 → 100644
View file @
f18cfec4
...
...
@@ -139,9 +139,16 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
static
constexpr
auto
AK1Number
=
Number
<
AK1Value
>
{};
static
constexpr
auto
BK1Number
=
Number
<
BK1Value
>
{};
static
constexpr
auto
lcm_AK1_BK1
=
math
::
lcm
(
AK1Number
,
BK1Number
);
static
constexpr
bool
is_single_rate_mfma
=
((
is_same
<
ComputeTypeA
,
half_t
>::
value
||
is_same
<
ComputeTypeA
,
bhalf_t
>::
value
)
&&
lcm_AK1_BK1
<=
4
)
?
true
:
false
;
static
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1Number
,
BK1Number
),
MfmaSelector
<
ComputeTypeA
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
math
::
max
(
lcm_AK1_BK1
,
MfmaSelector
<
ComputeTypeA
,
MPerXdl
,
NPerXdl
,
ComputeTypeA
,
is_single_rate_mfma
>::
selected_mfma
.
k_per_blk
);
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
__host__
static
auto
CalculateMPadded
(
index_t
M
)
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp
View file @
f18cfec4
...
...
@@ -869,9 +869,16 @@ struct GridwiseGemm_xdl_cshuffle_v2
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1Number
,
BK1Number
),
MfmaSelector
<
ComputeTypeA
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
constexpr
auto
lcm_AK1_BK1
=
math
::
lcm
(
AK1Number
,
BK1Number
);
constexpr
bool
is_single_rate_mfma
=
((
is_same
<
ComputeTypeA
,
half_t
>::
value
||
is_same
<
ComputeTypeA
,
bhalf_t
>::
value
)
&&
lcm_AK1_BK1
<=
4
)
?
true
:
false
;
constexpr
index_t
KPack
=
math
::
max
(
lcm_AK1_BK1
,
MfmaSelector
<
ComputeTypeA
,
MPerXdl
,
NPerXdl
,
ComputeTypeA
,
is_single_rate_mfma
>::
selected_mfma
.
k_per_blk
);
// auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
// BlockSize,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
View file @
f18cfec4
...
...
@@ -147,9 +147,16 @@ struct GridwiseGemm_xdl_cshuffle_v3
static
constexpr
auto
AK1Number
=
Number
<
AK1Value
>
{};
static
constexpr
auto
BK1Number
=
Number
<
BK1Value
>
{};
static
constexpr
auto
lcm_AK1_BK1
=
math
::
lcm
(
AK1Number
,
BK1Number
);
static
constexpr
bool
is_single_rate_mfma
=
((
is_same
<
ComputeTypeA
,
half_t
>::
value
||
is_same
<
ComputeTypeA
,
bhalf_t
>::
value
)
&&
lcm_AK1_BK1
<=
4
)
?
true
:
false
;
static
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1Number
,
BK1Number
),
MfmaSelector
<
ComputeTypeA
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
math
::
max
(
lcm_AK1_BK1
,
MfmaSelector
<
ComputeTypeA
,
MPerXdl
,
NPerXdl
,
ComputeTypeA
,
is_single_rate_mfma
>::
selected_mfma
.
k_per_blk
);
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp
View file @
f18cfec4
...
...
@@ -155,9 +155,16 @@ struct GridwiseGemm_xdl_cshuffle_v3
static
constexpr
auto
AK1Number
=
Number
<
AK1Value
>
{};
static
constexpr
auto
BK1Number
=
Number
<
BK1Value
>
{};
static
constexpr
auto
lcm_AK1_BK1
=
math
::
lcm
(
AK1Number
,
BK1Number
);
static
constexpr
bool
is_single_rate_mfma
=
((
is_same
<
ComputeTypeA
,
half_t
>::
value
||
is_same
<
ComputeTypeA
,
bhalf_t
>::
value
)
&&
lcm_AK1_BK1
<=
4
)
?
true
:
false
;
static
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1Number
,
BK1Number
),
MfmaSelector
<
ComputeTypeA
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
math
::
max
(
lcm_AK1_BK1
,
MfmaSelector
<
ComputeTypeA
,
MPerXdl
,
NPerXdl
,
ComputeTypeA
,
is_single_rate_mfma
>::
selected_mfma
.
k_per_blk
);
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
...
...
@@ -1424,7 +1431,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
// b scale
// static_assert(KPerBlock <= ScaleBlockK);
static
constexpr
auto
mfma
=
MfmaSelector
<
ComputeTypeA
,
MPerXdl
,
NPerXdl
>
{};
static
constexpr
auto
mfma
=
MfmaSelector
<
ComputeTypeA
,
MPerXdl
,
NPerXdl
,
ComputeTypeA
,
is_single_rate_mfma
>
{};
static
constexpr
auto
KPerXdlops
=
mfma
.
GetKPerXdlops
();
static
constexpr
auto
K1PerXdlops
=
mfma
.
GetK1PerXdlops
();
static
constexpr
auto
K0PerXdlops
=
KPerXdlops
/
K1PerXdlops
;
...
...
@@ -1895,7 +1903,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
KPerBlock
);
// B scale
static
constexpr
auto
mfma
=
MfmaSelector
<
ComputeTypeA
,
MPerXdl
,
NPerXdl
>
{};
static
constexpr
auto
mfma
=
MfmaSelector
<
ComputeTypeA
,
MPerXdl
,
NPerXdl
,
ComputeTypeA
,
is_single_rate_mfma
>
{};
static
constexpr
auto
KPerXdlops
=
mfma
.
GetKPerXdlops
();
static
constexpr
auto
K1PerXdlops
=
mfma
.
GetK1PerXdlops
();
static
constexpr
auto
K0PerXdlops
=
KPerXdlops
/
K1PerXdlops
;
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp
View file @
f18cfec4
...
...
@@ -489,8 +489,16 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
constexpr
auto
lcm_AK1_BK1
=
math
::
lcm
(
AK1
,
BK1
);
constexpr
bool
is_single_rate_mfma
=
((
is_same
<
FloatAB
,
half_t
>::
value
||
is_same
<
FloatAB
,
bhalf_t
>::
value
)
&&
lcm_AK1_BK1
<=
4
)
?
true
:
false
;
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1
,
BK1
),
MfmaSelector
<
FloatAB
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
lcm_AK1_BK1
,
MfmaSelector
<
FloatAB
,
MPerXdl
,
NPerXdl
,
FloatAB
,
is_single_rate_mfma
>::
selected_mfma
.
k_per_blk
);
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector
<
BlockSize
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp
View file @
f18cfec4
...
...
@@ -487,9 +487,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
else
if
(
TileMathThreadGroup
::
IsBelong
())
{
// branch early for math wave
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1
,
BK1
),
MfmaSelector
<
ABDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
constexpr
auto
lcm_AK1_BK1
=
math
::
lcm
(
AK1
,
BK1
);
constexpr
bool
is_single_rate_mfma
=
((
is_same
<
ABDataType
,
half_t
>::
value
||
is_same
<
ABDataType
,
bhalf_t
>::
value
)
&&
lcm_AK1_BK1
<=
4
)
?
true
:
false
;
constexpr
index_t
KPack
=
math
::
max
(
lcm_AK1_BK1
,
MfmaSelector
<
ABDataType
,
MPerXdl
,
NPerXdl
,
ABDataType
,
is_single_rate_mfma
>::
selected_mfma
.
k_per_blk
);
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
<
TileMathThreadGroupSize
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp
View file @
f18cfec4
...
...
@@ -446,8 +446,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
constexpr
auto
lcm_AK1_BK1
=
math
::
lcm
(
AK1
,
BK1
);
constexpr
bool
is_single_rate_mfma
=
((
is_same
<
FloatAB
,
half_t
>::
value
||
is_same
<
FloatAB
,
bhalf_t
>::
value
)
&&
lcm_AK1_BK1
<=
4
)
?
true
:
false
;
constexpr
index_t
k_pack
=
math
::
max
(
math
::
lcm
(
AK1
,
BK1
),
MfmaSelector
<
FloatAB
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
lcm_AK1_BK1
,
MfmaSelector
<
FloatAB
,
MPerXdl
,
NPerXdl
,
FloatAB
,
is_single_rate_mfma
>::
selected_mfma
.
k_per_blk
);
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
<
BlockSize
,
...
...
include/ck_tile/core.hpp
View file @
f18cfec4
...
...
@@ -27,12 +27,12 @@
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/int8.hpp"
#include "ck_tile/core/numeric/pk_int4.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/numeric/null_type.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include "ck_tile/core/numeric/pk_int4.hpp"
#include "ck_tile/core/numeric/type_convert.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
#include "ck_tile/core/tensor/buffer_view.hpp"
...
...
include/ck_tile/core/utility/transpose_vectors.hpp
View file @
f18cfec4
...
...
@@ -68,52 +68,82 @@ struct transpose_vectors
}
else
if
constexpr
(
sizeof
(
S
)
==
1
)
{
static_assert
((
NX
%
4
==
0
&&
NY
%
4
==
0
),
"wrong!"
);
static_assert
((
(
NX
%
4
==
0
&&
NY
%
4
==
0
)
||
(
NX
%
2
==
0
&&
NY
%
2
==
0
))
,
"wrong!"
);
using
S4
=
array
<
S
,
4
>
;
// typename array<S, 4>::type;
// loop over 4x4 tile and transpose data from vx_tuple into vy_tuple
static_for
<
0
,
NY
,
4
>
{}([
&
](
auto
iy
)
{
static_for
<
0
,
NX
,
4
>
{}([
&
](
auto
ix
)
{
// 4 int8x4 data from vx_tuple
const
int32_t
x_s4_0
=
bit_cast
<
int32_t
>
(
vx_tuple
[
ix
].
template
get_as
<
S4
>()[
iy
/
I4
]);
const
int32_t
x_s4_1
=
bit_cast
<
int32_t
>
(
vx_tuple
[
ix
+
I1
].
template
get_as
<
S4
>()[
iy
/
I4
]);
const
int32_t
x_s4_2
=
bit_cast
<
int32_t
>
(
vx_tuple
[
ix
+
I2
].
template
get_as
<
S4
>()[
iy
/
I4
]);
const
int32_t
x_s4_3
=
bit_cast
<
int32_t
>
(
vx_tuple
[
ix
+
I3
].
template
get_as
<
S4
>()[
iy
/
I4
]);
// transpose
int32_t
t_s4_0
,
t_s4_1
;
int32_t
y_s4_0
,
y_s4_1
,
y_s4_2
,
y_s4_3
;
constexpr
int32_t
m0
=
0x05010400
;
constexpr
int32_t
m1
=
0x05040100
;
constexpr
int32_t
m2
=
0x07060302
;
constexpr
int32_t
m3
=
0x07030602
;
// ex: v_perm_b32(0x 11 22 33 44, 0x 55 66 77 88, 0x 05 01 04 00) -> 0x33774488
// -- -- -- -- -- -- -- -- - - - -
// index 7 6 5 4 3 2 1 0 33 77 44 88
// index is reversed because of little endianness (least significant bits first)
t_s4_0
=
__builtin_amdgcn_perm
(
x_s4_1
,
x_s4_0
,
m0
);
t_s4_1
=
__builtin_amdgcn_perm
(
x_s4_3
,
x_s4_2
,
m0
);
y_s4_0
=
__builtin_amdgcn_perm
(
t_s4_1
,
t_s4_0
,
m1
);
y_s4_1
=
__builtin_amdgcn_perm
(
t_s4_1
,
t_s4_0
,
m2
);
t_s4_0
=
__builtin_amdgcn_perm
(
x_s4_1
,
x_s4_0
,
m3
);
t_s4_1
=
__builtin_amdgcn_perm
(
x_s4_3
,
x_s4_2
,
m3
);
y_s4_2
=
__builtin_amdgcn_perm
(
t_s4_1
,
t_s4_0
,
m1
);
y_s4_3
=
__builtin_amdgcn_perm
(
t_s4_1
,
t_s4_0
,
m2
);
// 4 int8x4 data from vy_tuple
vy_tuple
(
iy
).
template
get_as
<
S4
>()(
ix
/
I4
)
=
bit_cast
<
S4
>
(
y_s4_0
);
vy_tuple
(
iy
+
I1
).
template
get_as
<
S4
>()(
ix
/
I4
)
=
bit_cast
<
S4
>
(
y_s4_1
);
vy_tuple
(
iy
+
I2
).
template
get_as
<
S4
>()(
ix
/
I4
)
=
bit_cast
<
S4
>
(
y_s4_2
);
vy_tuple
(
iy
+
I3
).
template
get_as
<
S4
>()(
ix
/
I4
)
=
bit_cast
<
S4
>
(
y_s4_3
);
using
S2
=
array
<
S
,
2
>
;
// typename array<S, 4>::type;
if
constexpr
(
NX
%
4
==
0
&&
NY
%
4
==
0
)
{
// loop over 4x4 tile and transpose data from vx_tuple into vy_tuple
static_for
<
0
,
NY
,
4
>
{}([
&
](
auto
iy
)
{
static_for
<
0
,
NX
,
4
>
{}([
&
](
auto
ix
)
{
// 4 int8x4 data from vx_tuple
const
int32_t
x_s4_0
=
bit_cast
<
int32_t
>
(
vx_tuple
[
ix
].
template
get_as
<
S4
>()[
iy
/
I4
]);
const
int32_t
x_s4_1
=
bit_cast
<
int32_t
>
(
vx_tuple
[
ix
+
I1
].
template
get_as
<
S4
>()[
iy
/
I4
]);
const
int32_t
x_s4_2
=
bit_cast
<
int32_t
>
(
vx_tuple
[
ix
+
I2
].
template
get_as
<
S4
>()[
iy
/
I4
]);
const
int32_t
x_s4_3
=
bit_cast
<
int32_t
>
(
vx_tuple
[
ix
+
I3
].
template
get_as
<
S4
>()[
iy
/
I4
]);
// transpose
int32_t
t_s4_0
,
t_s4_1
;
int32_t
y_s4_0
,
y_s4_1
,
y_s4_2
,
y_s4_3
;
constexpr
int32_t
m0
=
0x05010400
;
constexpr
int32_t
m1
=
0x05040100
;
constexpr
int32_t
m2
=
0x07060302
;
constexpr
int32_t
m3
=
0x07030602
;
// ex: v_perm_b32(0x 11 22 33 44, 0x 55 66 77 88, 0x 05 01 04 00) ->
// 0x33774488
// -- -- -- -- -- -- -- -- - - - -
// index 7 6 5 4 3 2 1 0 33 77 44 88
// index is reversed because of little endianness (least significant bits
// first)
t_s4_0
=
__builtin_amdgcn_perm
(
x_s4_1
,
x_s4_0
,
m0
);
t_s4_1
=
__builtin_amdgcn_perm
(
x_s4_3
,
x_s4_2
,
m0
);
y_s4_0
=
__builtin_amdgcn_perm
(
t_s4_1
,
t_s4_0
,
m1
);
y_s4_1
=
__builtin_amdgcn_perm
(
t_s4_1
,
t_s4_0
,
m2
);
t_s4_0
=
__builtin_amdgcn_perm
(
x_s4_1
,
x_s4_0
,
m3
);
t_s4_1
=
__builtin_amdgcn_perm
(
x_s4_3
,
x_s4_2
,
m3
);
y_s4_2
=
__builtin_amdgcn_perm
(
t_s4_1
,
t_s4_0
,
m1
);
y_s4_3
=
__builtin_amdgcn_perm
(
t_s4_1
,
t_s4_0
,
m2
);
// 4 int8x4 data from vy_tuple
vy_tuple
(
iy
).
template
get_as
<
S4
>()(
ix
/
I4
)
=
bit_cast
<
S4
>
(
y_s4_0
);
vy_tuple
(
iy
+
I1
).
template
get_as
<
S4
>()(
ix
/
I4
)
=
bit_cast
<
S4
>
(
y_s4_1
);
vy_tuple
(
iy
+
I2
).
template
get_as
<
S4
>()(
ix
/
I4
)
=
bit_cast
<
S4
>
(
y_s4_2
);
vy_tuple
(
iy
+
I3
).
template
get_as
<
S4
>()(
ix
/
I4
)
=
bit_cast
<
S4
>
(
y_s4_3
);
});
});
});
}
else
if
constexpr
(
NX
%
2
==
0
&&
NY
%
2
==
0
)
{
static_for
<
0
,
NY
,
2
>
{}([
&
](
auto
ix
)
{
static_for
<
0
,
NX
,
2
>
{}([
&
](
auto
iy
)
{
const
int16_t
x_s2_0
=
bit_cast
<
int16_t
>
(
vx_tuple
[
ix
].
template
get_as
<
S2
>()[
iy
/
I2
]);
const
int16_t
x_s2_1
=
bit_cast
<
int16_t
>
(
vx_tuple
[
ix
+
I1
].
template
get_as
<
S2
>()[
iy
/
I2
]);
constexpr
int32_t
m0
=
0x05040100
;
constexpr
int32_t
m1
=
0x07060302
;
const
int32_t
x0_32
=
static_cast
<
int32_t
>
(
x_s2_0
&
0xFFFF
);
const
int32_t
x1_32
=
static_cast
<
int32_t
>
(
x_s2_1
&
0xFFFF
);
const
int32_t
y_s2_0
=
__builtin_amdgcn_perm
(
x1_32
,
x0_32
,
m0
);
const
int32_t
y_s2_1
=
__builtin_amdgcn_perm
(
x1_32
,
x0_32
,
m1
);
vy_tuple
(
iy
).
template
get_as
<
S2
>()[
ix
/
I2
]
=
bit_cast
<
S2
>
(
static_cast
<
int16_t
>
(
y_s2_0
&
0xFFFF
));
vy_tuple
(
iy
+
I1
).
template
get_as
<
S2
>()[
ix
/
I2
]
=
bit_cast
<
S2
>
(
static_cast
<
int16_t
>
(
y_s2_1
&
0xFFFF
));
});
});
}
}
else
{
...
...
include/ck_tile/host.hpp
View file @
f18cfec4
...
...
@@ -5,6 +5,7 @@
#include "ck_tile/host/arg_parser.hpp"
#include "ck_tile/host/check_err.hpp"
#include "ck_tile/host/concat.hpp"
#include "ck_tile/host/convolution_host_tensor_descriptor_helper.hpp"
#include "ck_tile/host/convolution_parameter.hpp"
#include "ck_tile/host/device_memory.hpp"
...
...
include/ck_tile/host/concat.hpp
0 → 100644
View file @
f18cfec4
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
namespace
ck_tile
{
template
<
typename
T
>
struct
IsCharArray
:
std
::
false_type
{
};
template
<
std
::
size_t
N
>
struct
IsCharArray
<
char
[
N
]
>
:
std
::
true_type
{
};
template
<
std
::
size_t
N
>
struct
IsCharArray
<
const
char
[
N
]
>
:
std
::
true_type
{
};
template
<
std
::
size_t
N
>
struct
IsCharArray
<
char
(
&
)[
N
]
>
:
std
::
true_type
{
};
template
<
std
::
size_t
N
>
struct
IsCharArray
<
const
char
(
&
)[
N
]
>
:
std
::
true_type
{
};
template
<
typename
...
Ts
>
inline
constexpr
bool
AllConvertibleToStringView
=
((
std
::
is_convertible_v
<
Ts
,
std
::
string_view
>
||
IsCharArray
<
Ts
>::
value
||
std
::
is_same_v
<
Ts
,
char
>
)
&&
...);
template
<
typename
...
Ts
>
[[
nodiscard
]]
auto
concat
(
const
Ts
&
...
xs
)
->
std
::
enable_if_t
<!
AllConvertibleToStringView
<
Ts
...
>
,
std
::
string
>
{
using
::
operator
<<
;
thread_local
std
::
ostringstream
oss
;
oss
.
str
(
""
);
(
oss
<<
...
<<
xs
);
return
oss
.
str
();
}
template
<
std
::
size_t
N
>
[[
nodiscard
]]
constexpr
inline
std
::
size_t
getSize
(
char
(
&
)[
N
])
noexcept
{
return
N
;
}
template
<
std
::
size_t
N
>
[[
nodiscard
]]
constexpr
inline
std
::
size_t
getSize
(
const
char
(
&
)[
N
])
noexcept
{
return
N
;
}
[[
nodiscard
]]
constexpr
inline
std
::
size_t
getSize
(
const
char
*
s
)
noexcept
{
const
char
*
end
=
s
;
while
(
*
end
++
!=
0
)
{}
return
end
-
s
-
1
;
}
[[
nodiscard
]]
constexpr
inline
std
::
size_t
getSize
(
const
char
&
)
noexcept
{
return
1
;
}
[[
nodiscard
]]
inline
std
::
size_t
getSize
(
const
std
::
string
&
s
)
noexcept
{
return
s
.
size
();
}
[[
nodiscard
]]
constexpr
inline
std
::
size_t
getSize
(
const
std
::
string_view
&
s
)
noexcept
{
return
s
.
size
();
}
template
<
typename
...
Ts
>
auto
concatInto
(
std
::
string
&
result
,
const
Ts
&
...
xs
)
->
std
::
enable_if_t
<
AllConvertibleToStringView
<
Ts
...
>
,
void
>
{
const
std
::
size_t
space
=
(
1
+
...
+
getSize
(
xs
));
result
.
reserve
(
result
.
size
()
+
space
);
((
result
+=
xs
),
...);
}
template
<
typename
...
Ts
>
[[
nodiscard
]]
auto
concat
(
const
Ts
&
...
xs
)
->
std
::
enable_if_t
<
AllConvertibleToStringView
<
Ts
...
>
,
std
::
string
>
{
std
::
string
result
;
concatInto
(
result
,
xs
...);
return
result
;
}
// Function for types convertible to std::string_view
template
<
typename
Sep
,
typename
First
,
typename
...
Rest
>
[[
nodiscard
]]
auto
concat
(
Sep
sep
,
const
First
&
first
,
const
Rest
&
...
rest
)
->
std
::
enable_if_t
<
AllConvertibleToStringView
<
First
,
Rest
...
>
,
std
::
string
>
{
std
::
string
result
;
result
+=
first
;
((
result
+=
sep
,
result
+=
rest
),
...);
return
result
;
}
// Function for other types
template
<
typename
Sep
,
typename
First
,
typename
...
Rest
>
[[
nodiscard
]]
auto
concat
(
Sep
sep
,
const
First
&
first
,
const
Rest
&
...
rest
)
->
std
::
enable_if_t
<!
AllConvertibleToStringView
<
First
,
Rest
...
>
,
std
::
string
>
{
using
::
operator
<<
;
thread_local
std
::
ostringstream
oss
;
oss
.
str
(
""
);
oss
<<
first
;
((
oss
<<
sep
<<
rest
),
...);
return
oss
.
str
();
}
}
// namespace ck_tile
Prev
1
2
3
4
5
6
7
…
9
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment