Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
a11cf2c6
Unverified
Commit
a11cf2c6
authored
Jan 24, 2025
by
arai713
Committed by
GitHub
Jan 24, 2025
Browse files
Merge branch 'develop' into codegen_hiprtc
parents
a72e9efa
64d5c4d6
Changes
156
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
572 additions
and
967 deletions
+572
-967
example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp
example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp
+1
-0
example/ck_tile/15_fused_moe/instances/fused_moegemm_api.cpp
example/ck_tile/15_fused_moe/instances/fused_moegemm_api.cpp
+56
-4
example/ck_tile/15_fused_moe/instances/fused_moegemm_api_internal.hpp
...ile/15_fused_moe/instances/fused_moegemm_api_internal.hpp
+25
-15
example/ck_tile/15_fused_moe/instances/fused_moegemm_api_traits.hpp
..._tile/15_fused_moe/instances/fused_moegemm_api_traits.hpp
+4
-2
example/ck_tile/15_fused_moe/instances/fused_moegemm_bf16_m32.cpp
...ck_tile/15_fused_moe/instances/fused_moegemm_bf16_m32.cpp
+12
-1
example/ck_tile/15_fused_moe/instances/fused_moegemm_fp16_m32.cpp
...ck_tile/15_fused_moe/instances/fused_moegemm_fp16_m32.cpp
+13
-1
example/ck_tile/15_fused_moe/main.cpp
example/ck_tile/15_fused_moe/main.cpp
+55
-52
example/ck_tile/16_batched_gemm/batched_gemm.cpp
example/ck_tile/16_batched_gemm/batched_gemm.cpp
+4
-4
example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc
example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc
+66
-58
example/ck_tile/17_grouped_gemm/grouped_gemm.cpp
example/ck_tile/17_grouped_gemm/grouped_gemm.cpp
+1
-2
example/ck_tile/17_grouped_gemm/grouped_gemm.hpp
example/ck_tile/17_grouped_gemm/grouped_gemm.hpp
+4
-4
example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc
example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc
+55
-14
example/ck_tile/17_grouped_gemm/utils.hpp
example/ck_tile/17_grouped_gemm/utils.hpp
+0
-38
include/ck/ck.hpp
include/ck/ck.hpp
+5
-0
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
...mpl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
+1
-14
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
...or_operation/gpu/element/unary_element_wise_operation.hpp
+213
-686
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+0
-2
include/ck/utility/dynamic_buffer.hpp
include/ck/utility/dynamic_buffer.hpp
+17
-6
include/ck/utility/type_convert.hpp
include/ck/utility/type_convert.hpp
+40
-63
include/ck_tile/core.hpp
include/ck_tile/core.hpp
+0
-1
No files found.
example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp
View file @
a11cf2c6
...
@@ -41,6 +41,7 @@ float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_conf
...
@@ -41,6 +41,7 @@ float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_conf
t
.
prec_sq
,
t
.
prec_sq
,
t
.
prec_kw
,
t
.
prec_kw
,
t
.
block_m
,
t
.
block_m
,
t
.
activation
,
t
.
gate_only
,
t
.
gate_only
,
t
.
fused_quant
};
t
.
fused_quant
};
auto
a1
=
fused_moegemm_args
{
auto
a1
=
fused_moegemm_args
{
...
...
example/ck_tile/15_fused_moe/instances/fused_moegemm_api.cpp
View file @
a11cf2c6
...
@@ -17,15 +17,67 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile:
...
@@ -17,15 +17,67 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile:
// clang-format off
// clang-format off
float
r
=
-
1
;
float
r
=
-
1
;
if
(
t
.
prec_i
==
"bf16"
&&
t
.
prec_w
==
"bf16"
&&
t
.
prec_o
==
"bf16"
&&
t
.
prec_st
==
"fp32"
&&
if
(
t
.
prec_i
==
"bf16"
&&
t
.
prec_w
==
"bf16"
&&
t
.
prec_o
==
"bf16"
&&
t
.
prec_st
==
"fp32"
&&
t
.
prec_sw
==
"fp32"
&&
t
.
prec_sq
==
"fp32"
&&
t
.
prec_kw
==
"fp32"
&&
t
.
block_m
==
32
&&
t
.
gate_only
==
1
)
t
.
prec_sw
==
"fp32"
&&
t
.
prec_sq
==
"fp32"
&&
t
.
prec_kw
==
"fp32"
&&
t
.
block_m
==
32
&&
t
.
gate_only
==
1
&&
t
.
activation
==
0
)
{
{
using
t_
=
fmoe_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
512
,
128
,
128
>
,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
1
,
0
>
;
constexpr
ck_tile
::
index_t
act_
=
0
;
constexpr
ck_tile
::
index_t
go_
=
1
;
using
t_
=
fmoe_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
512
,
128
,
128
>
,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
act_
,
go_
,
0
>
;
r
=
fused_moegemm_
<
t_
>
(
s
,
a
);
}
else
if
(
t
.
prec_i
==
"bf16"
&&
t
.
prec_w
==
"bf16"
&&
t
.
prec_o
==
"bf16"
&&
t
.
prec_st
==
"fp32"
&&
t
.
prec_sw
==
"fp32"
&&
t
.
prec_sq
==
"fp32"
&&
t
.
prec_kw
==
"fp32"
&&
t
.
block_m
==
32
&&
t
.
gate_only
==
0
&&
t
.
activation
==
0
)
{
constexpr
ck_tile
::
index_t
act_
=
0
;
constexpr
ck_tile
::
index_t
go_
=
0
;
using
t_
=
fmoe_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
512
,
128
,
128
>
,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
act_
,
go_
,
0
>
;
r
=
fused_moegemm_
<
t_
>
(
s
,
a
);
}
else
if
(
t
.
prec_i
==
"fp16"
&&
t
.
prec_w
==
"fp16"
&&
t
.
prec_o
==
"fp16"
&&
t
.
prec_st
==
"fp32"
&&
t
.
prec_sw
==
"fp32"
&&
t
.
prec_sq
==
"fp32"
&&
t
.
prec_kw
==
"fp32"
&&
t
.
block_m
==
32
&&
t
.
gate_only
==
1
&&
t
.
activation
==
0
)
{
constexpr
ck_tile
::
index_t
act_
=
0
;
constexpr
ck_tile
::
index_t
go_
=
1
;
using
t_
=
fmoe_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
512
,
128
,
128
>
,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
act_
,
go_
,
0
>
;
r
=
fused_moegemm_
<
t_
>
(
s
,
a
);
}
else
if
(
t
.
prec_i
==
"fp16"
&&
t
.
prec_w
==
"fp16"
&&
t
.
prec_o
==
"fp16"
&&
t
.
prec_st
==
"fp32"
&&
t
.
prec_sw
==
"fp32"
&&
t
.
prec_sq
==
"fp32"
&&
t
.
prec_kw
==
"fp32"
&&
t
.
block_m
==
32
&&
t
.
gate_only
==
0
&&
t
.
activation
==
0
)
{
constexpr
ck_tile
::
index_t
act_
=
0
;
constexpr
ck_tile
::
index_t
go_
=
0
;
using
t_
=
fmoe_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
512
,
128
,
128
>
,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
act_
,
go_
,
0
>
;
r
=
fused_moegemm_
<
t_
>
(
s
,
a
);
}
else
if
(
t
.
prec_i
==
"bf16"
&&
t
.
prec_w
==
"bf16"
&&
t
.
prec_o
==
"bf16"
&&
t
.
prec_st
==
"fp32"
&&
t
.
prec_sw
==
"fp32"
&&
t
.
prec_sq
==
"fp32"
&&
t
.
prec_kw
==
"fp32"
&&
t
.
block_m
==
32
&&
t
.
gate_only
==
1
&&
t
.
activation
==
1
)
{
constexpr
ck_tile
::
index_t
act_
=
1
;
constexpr
ck_tile
::
index_t
go_
=
1
;
using
t_
=
fmoe_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
512
,
128
,
128
>
,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
act_
,
go_
,
0
>
;
r
=
fused_moegemm_
<
t_
>
(
s
,
a
);
}
else
if
(
t
.
prec_i
==
"bf16"
&&
t
.
prec_w
==
"bf16"
&&
t
.
prec_o
==
"bf16"
&&
t
.
prec_st
==
"fp32"
&&
t
.
prec_sw
==
"fp32"
&&
t
.
prec_sq
==
"fp32"
&&
t
.
prec_kw
==
"fp32"
&&
t
.
block_m
==
32
&&
t
.
gate_only
==
0
&&
t
.
activation
==
1
)
{
constexpr
ck_tile
::
index_t
act_
=
1
;
constexpr
ck_tile
::
index_t
go_
=
0
;
using
t_
=
fmoe_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
512
,
128
,
128
>
,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
act_
,
go_
,
0
>
;
r
=
fused_moegemm_
<
t_
>
(
s
,
a
);
}
else
if
(
t
.
prec_i
==
"fp16"
&&
t
.
prec_w
==
"fp16"
&&
t
.
prec_o
==
"fp16"
&&
t
.
prec_st
==
"fp32"
&&
t
.
prec_sw
==
"fp32"
&&
t
.
prec_sq
==
"fp32"
&&
t
.
prec_kw
==
"fp32"
&&
t
.
block_m
==
32
&&
t
.
gate_only
==
1
&&
t
.
activation
==
1
)
{
constexpr
ck_tile
::
index_t
act_
=
1
;
constexpr
ck_tile
::
index_t
go_
=
1
;
using
t_
=
fmoe_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
512
,
128
,
128
>
,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
act_
,
go_
,
0
>
;
r
=
fused_moegemm_
<
t_
>
(
s
,
a
);
r
=
fused_moegemm_
<
t_
>
(
s
,
a
);
}
}
else
if
(
t
.
prec_i
==
"fp16"
&&
t
.
prec_w
==
"fp16"
&&
t
.
prec_o
==
"fp16"
&&
t
.
prec_st
==
"fp32"
&&
else
if
(
t
.
prec_i
==
"fp16"
&&
t
.
prec_w
==
"fp16"
&&
t
.
prec_o
==
"fp16"
&&
t
.
prec_st
==
"fp32"
&&
t
.
prec_sw
==
"fp32"
&&
t
.
prec_sq
==
"fp32"
&&
t
.
prec_kw
==
"fp32"
&&
t
.
block_m
==
32
&&
t
.
gate_only
==
1
)
t
.
prec_sw
==
"fp32"
&&
t
.
prec_sq
==
"fp32"
&&
t
.
prec_kw
==
"fp32"
&&
t
.
block_m
==
32
&&
t
.
gate_only
==
0
&&
t
.
activation
==
1
)
{
{
using
t_
=
fmoe_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
512
,
128
,
128
>
,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
1
,
0
>
;
constexpr
ck_tile
::
index_t
act_
=
1
;
constexpr
ck_tile
::
index_t
go_
=
0
;
using
t_
=
fmoe_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
512
,
128
,
128
>
,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
act_
,
go_
,
0
>
;
r
=
fused_moegemm_
<
t_
>
(
s
,
a
);
r
=
fused_moegemm_
<
t_
>
(
s
,
a
);
}
}
// clang-format on
// clang-format on
...
...
example/ck_tile/15_fused_moe/instances/fused_moegemm_api_internal.hpp
View file @
a11cf2c6
...
@@ -21,8 +21,18 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
...
@@ -21,8 +21,18 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
typename
Ts_
::
BlockTile_1
,
typename
Ts_
::
BlockTile_1
,
typename
Ts_
::
WarpPerBlock_0
,
typename
Ts_
::
WarpPerBlock_0
,
typename
Ts_
::
WarpTile_0
>
;
typename
Ts_
::
WarpTile_0
>
;
using
f_problem
=
ck_tile
::
FusedMoeGemmPipelineProblem
<
typename
Ts_
::
ADataType
,
constexpr
auto
get_activation_
=
[]()
{
if
constexpr
(
Ts_
::
Activation
==
0
)
{
return
ck_tile
::
element_wise
::
FastGeluAsm
{};
}
else
return
ck_tile
::
element_wise
::
Silu
{};
};
using
f_act_
=
ck_tile
::
remove_cvref_t
<
decltype
(
get_activation_
())
>
;
using
f_problem
=
ck_tile
::
FusedMoeGemmPipelineProblem
<
typename
Ts_
::
ADataType
,
typename
Ts_
::
GDataType
,
typename
Ts_
::
GDataType
,
typename
Ts_
::
DDataType
,
typename
Ts_
::
DDataType
,
typename
Ts_
::
AccDataType
,
typename
Ts_
::
AccDataType
,
...
@@ -33,7 +43,7 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
...
@@ -33,7 +43,7 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
typename
Ts_
::
YSmoothScaleDataType
,
typename
Ts_
::
YSmoothScaleDataType
,
typename
Ts_
::
TopkWeightDataType
,
typename
Ts_
::
TopkWeightDataType
,
typename
Ts_
::
IndexDataType
,
typename
Ts_
::
IndexDataType
,
ck_tile
::
element_wise
::
FastGeluAsm
,
// TODO: hardcoded
f_act_
,
// TODO: hardcoded
f_shape
,
f_shape
,
f_traits
>
;
f_traits
>
;
...
...
example/ck_tile/15_fused_moe/instances/fused_moegemm_api_traits.hpp
View file @
a11cf2c6
...
@@ -16,6 +16,7 @@ template <typename I,
...
@@ -16,6 +16,7 @@ template <typename I,
typename
BlockTIle_
,
// seq<b_token, b_interm, b_hidden, b_down>
typename
BlockTIle_
,
// seq<b_token, b_interm, b_hidden, b_down>
typename
WarpPerBlock_
,
typename
WarpPerBlock_
,
typename
WarpTile_
,
// seq<*,*,*>, used to select mfma
typename
WarpTile_
,
// seq<*,*,*>, used to select mfma
ck_tile
::
index_t
Activation_
=
0
,
// 0: Gelu 1: Silu
ck_tile
::
index_t
GateOnly_
=
0
,
ck_tile
::
index_t
GateOnly_
=
0
,
ck_tile
::
index_t
FusedQuant_
=
0
>
ck_tile
::
index_t
FusedQuant_
=
0
>
struct
fmoe_
// traits, ugly name, only used for internal
struct
fmoe_
// traits, ugly name, only used for internal
...
@@ -44,10 +45,11 @@ struct fmoe_ // traits, ugly name, only used for internal
...
@@ -44,10 +45,11 @@ struct fmoe_ // traits, ugly name, only used for internal
using
WarpPerBlock_0
=
ck_tile
::
remove_cvref_t
<
WarpPerBlock_
>
;
using
WarpPerBlock_0
=
ck_tile
::
remove_cvref_t
<
WarpPerBlock_
>
;
using
WarpTile_0
=
ck_tile
::
remove_cvref_t
<
WarpTile_
>
;
using
WarpTile_0
=
ck_tile
::
remove_cvref_t
<
WarpTile_
>
;
using
BlockTile_1
=
ck_tile
::
sequence
<
BT_
,
BD_
,
BI_
/
(
GateOnly_
?
1
:
2
)
>
;
using
BlockTile_1
=
ck_tile
::
sequence
<
BT_
,
BD_
,
BI_
>
;
using
WarpPerBlock_1
=
ck_tile
::
remove_cvref_t
<
WarpPerBlock_
>
;
using
WarpPerBlock_1
=
ck_tile
::
remove_cvref_t
<
WarpPerBlock_
>
;
using
WarpTile_1
=
ck_tile
::
remove_cvref_t
<
WarpTile_
>
;
using
WarpTile_1
=
ck_tile
::
remove_cvref_t
<
WarpTile_
>
;
static
constexpr
ck_tile
::
index_t
Activation
=
Activation_
;
// 0: Gelu 1: Silu
static
constexpr
ck_tile
::
index_t
GateOnly
=
GateOnly_
;
static
constexpr
ck_tile
::
index_t
GateOnly
=
GateOnly_
;
static
constexpr
ck_tile
::
index_t
FusedQuant
=
FusedQuant_
;
static
constexpr
ck_tile
::
index_t
FusedQuant
=
FusedQuant_
;
};
};
example/ck_tile/15_fused_moe/instances/fused_moegemm_bf16_m32.cpp
View file @
a11cf2c6
...
@@ -8,7 +8,18 @@
...
@@ -8,7 +8,18 @@
// clang-format off
// clang-format off
template
float
fused_moegemm_
<
template
float
fused_moegemm_
<
fmoe_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
512
,
128
,
128
>,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
1
,
0
>
fmoe_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
512
,
128
,
128
>,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
0
,
0
,
0
>
>
(
const
ck_tile
::
stream_config
&
s
,
fused_moegemm_args
a
);
>
(
const
ck_tile
::
stream_config
&
s
,
fused_moegemm_args
a
);
template
float
fused_moegemm_
<
fmoe_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
512
,
128
,
128
>,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
0
,
1
,
0
>
>
(
const
ck_tile
::
stream_config
&
s
,
fused_moegemm_args
a
);
template
float
fused_moegemm_
<
fmoe_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
512
,
128
,
128
>,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
1
,
0
,
0
>
>
(
const
ck_tile
::
stream_config
&
s
,
fused_moegemm_args
a
);
template
float
fused_moegemm_
<
fmoe_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
512
,
128
,
128
>,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
1
,
1
,
0
>
>
(
const
ck_tile
::
stream_config
&
s
,
fused_moegemm_args
a
);
// clang-format on
// clang-format on
example/ck_tile/15_fused_moe/instances/fused_moegemm_fp16_m32.cpp
View file @
a11cf2c6
...
@@ -8,7 +8,19 @@
...
@@ -8,7 +8,19 @@
// clang-format off
// clang-format off
template
float
fused_moegemm_
<
template
float
fused_moegemm_
<
fmoe_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
512
,
128
,
128
>,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
1
,
0
>
fmoe_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
512
,
128
,
128
>,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
0
,
0
,
0
>
>
(
const
ck_tile
::
stream_config
&
s
,
fused_moegemm_args
a
);
template
float
fused_moegemm_
<
fmoe_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
512
,
128
,
128
>,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
0
,
1
,
0
>
>
(
const
ck_tile
::
stream_config
&
s
,
fused_moegemm_args
a
);
template
float
fused_moegemm_
<
fmoe_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
512
,
128
,
128
>,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
1
,
0
,
0
>
>
(
const
ck_tile
::
stream_config
&
s
,
fused_moegemm_args
a
);
template
float
fused_moegemm_
<
fmoe_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
512
,
128
,
128
>,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
1
,
1
,
0
>
>
(
const
ck_tile
::
stream_config
&
s
,
fused_moegemm_args
a
);
>
(
const
ck_tile
::
stream_config
&
s
,
fused_moegemm_args
a
);
// clang-format on
// clang-format on
example/ck_tile/15_fused_moe/main.cpp
View file @
a11cf2c6
...
@@ -108,12 +108,14 @@ auto create_args(int argc, char* argv[])
...
@@ -108,12 +108,14 @@ auto create_args(int argc, char* argv[])
.
insert
(
.
insert
(
"gate_only"
,
"1"
,
"w0(gate/up) style, 0:gate+up will double interm size, 1:only gate"
)
"gate_only"
,
"1"
,
"w0(gate/up) style, 0:gate+up will double interm size, 1:only gate"
)
.
insert
(
"api"
,
"0"
,
"benchmark api set: 0:fused-moe(moe-gemm+moe-sorting), 1:moe-gemm"
)
.
insert
(
"api"
,
"0"
,
"benchmark api set: 0:fused-moe(moe-gemm+moe-sorting), 1:moe-gemm"
)
.
insert
(
"act"
,
"0"
,
"activation after first gemm. 0:gelu, 1:silu"
)
.
insert
(
"balance"
,
.
insert
(
"balance"
,
"0"
,
"0"
,
"if set to 1, will try balance the expert in topk-ids(convenient for testing)"
)
"if set to 1, will try balance the expert in topk-ids(convenient for testing)"
)
.
insert
(
"init"
,
.
insert
(
"init"
,
"2"
,
"1"
,
"init method. 0:random stepped float(fast). 1: random uniform, 2:rand normalized"
"init method. 0:random stepped float(fast). 1: random uniform[-0.5, 0.5], 2:rand "
"normalized[0, 1]"
"normalized(slow)"
)
"normalized(slow)"
)
.
insert
(
"seed"
,
"11939"
,
"seed used to do random"
)
.
insert
(
"seed"
,
"11939"
,
"seed used to do random"
)
.
insert
(
"warmup"
,
"5"
,
"cold iter"
)
.
insert
(
"warmup"
,
"5"
,
"cold iter"
)
...
@@ -135,6 +137,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -135,6 +137,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile
::
index_t
intermediate_size
=
arg_parser
.
get_int
(
"i"
);
ck_tile
::
index_t
intermediate_size
=
arg_parser
.
get_int
(
"i"
);
ck_tile
::
index_t
stride
=
arg_parser
.
get_int
(
"stride"
);
ck_tile
::
index_t
stride
=
arg_parser
.
get_int
(
"stride"
);
ck_tile
::
index_t
block_m
=
arg_parser
.
get_int
(
"bm"
);
ck_tile
::
index_t
block_m
=
arg_parser
.
get_int
(
"bm"
);
ck_tile
::
index_t
activation
=
arg_parser
.
get_int
(
"act"
);
if
(
stride
<
0
)
if
(
stride
<
0
)
stride
=
hidden_size
;
stride
=
hidden_size
;
std
::
string
prec_i
=
arg_parser
.
get_str
(
"prec_i"
);
std
::
string
prec_i
=
arg_parser
.
get_str
(
"prec_i"
);
...
@@ -194,11 +197,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -194,11 +197,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
return
std
::
string
(
", st:"
)
+
std
::
to_string
(
stride
);
return
std
::
string
(
", st:"
)
+
std
::
to_string
(
stride
);
}();
}();
std
::
cout
<<
"["
<<
api_str
<<
"|"
<<
prec_str
<<
"]"
std
::
cout
<<
"["
<<
api_str
<<
"|"
<<
prec_str
<<
"]"
<<
" t:"
<<
tokens
<<
", e:"
<<
experts
<<
", k:"
<<
topk
<<
stride_str
<<
" t:"
<<
tokens
<<
", e:"
<<
experts
<<
", k:"
<<
topk
<<
stride_str
<<
", hidden:"
<<
hidden_size
<<
", interm:"
<<
intermediate_size
<<
", tp:"
<<
tp
<<
", hidden:"
<<
hidden_size
<<
", interm:"
<<
intermediate_size
<<
", tp:"
<<
tp
<<
", shrd_interm:"
<<
shared_intermediate_size_0
<<
"|"
<<
shared_intermediate_size_1
<<
", act:"
<<
", go:"
<<
gate_only
<<
", q:"
<<
fused_quant
<<
std
::
flush
;
<<
activation
// << ", shrd_interm:" << shared_intermediate_size_0 << "|" << shared_intermediate_size_1
<<
(
gate_only
?
", g1u0"
:
", g1u1"
)
<<
", q:"
<<
fused_quant
<<
std
::
flush
;
using
TypeConfig
=
FusedMoeGemmTypeConfig
<
I
,
W
,
O
,
ST
,
SW
,
SQ
,
KW
>
;
using
TypeConfig
=
FusedMoeGemmTypeConfig
<
I
,
W
,
O
,
ST
,
SW
,
SQ
,
KW
>
;
using
ADataType
=
typename
TypeConfig
::
ADataType
;
using
ADataType
=
typename
TypeConfig
::
ADataType
;
...
@@ -370,6 +376,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -370,6 +376,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
prec_sq
,
prec_sq
,
prec_kw
,
prec_kw
,
block_m
,
block_m
,
activation
,
gate_only
,
gate_only
,
fused_quant
};
fused_quant
};
...
@@ -389,7 +396,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -389,7 +396,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
num_sorted_tiles_buf
.
GetDeviceBuffer
(),
num_sorted_tiles_buf
.
GetDeviceBuffer
(),
block_m
,
block_m
,
hidden_size
,
hidden_size
,
shared_
intermediate_size
_0
,
intermediate_size
/
tp
,
tokens
,
tokens
,
experts
,
experts
,
topk
,
topk
,
...
@@ -408,6 +415,28 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -408,6 +415,28 @@ bool run(const ck_tile::ArgParser& arg_parser)
<<
cal_tbps
(
ave_time
)
<<
" TB/s"
<<
std
::
flush
;
<<
cal_tbps
(
ave_time
)
<<
" TB/s"
<<
std
::
flush
;
bool
pass
=
true
;
bool
pass
=
true
;
#define CPU_FUSED_MOE(act_type_) \
ck_tile::reference_fused_moe<AccDataType, act_type_>(a_host, \
g_host, \
d_host, \
sa_host, \
sg_host, \
sd_host, \
sy_host, \
o_host, \
sorted_token_ids_host, \
sorted_weight_host, \
sorted_expert_ids_host, \
num_sorted_tiles_host, \
topk_ids_host, \
block_m, \
tokens, \
experts, \
hidden_size, \
intermediate_size / tp, \
topk, \
gate_only)
if
(
do_validation
)
if
(
do_validation
)
{
{
ck_tile
::
reference_moe_sorting
<
TopkWeightDataType
,
IndexDataType
>
(
ck_tile
::
reference_moe_sorting
<
TopkWeightDataType
,
IndexDataType
>
(
...
@@ -419,28 +448,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -419,28 +448,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
num_sorted_tiles_host
.
mData
[
0
],
num_sorted_tiles_host
.
mData
[
0
],
experts
,
experts
,
block_m
);
block_m
);
if
(
activation
==
0
)
ck_tile
::
reference_fused_moe
<
AccDataType
,
ck_tile
::
element_wise
::
Gelu
>
(
{
a_host
,
CPU_FUSED_MOE
(
ck_tile
::
element_wise
::
Gelu
);
g_host
,
}
d_host
,
else
sa_host
,
{
sg_host
,
CPU_FUSED_MOE
(
ck_tile
::
element_wise
::
Silu
);
sd_host
,
}
sy_host
,
o_host
,
sorted_token_ids_host
,
sorted_weight_host
,
sorted_expert_ids_host
,
num_sorted_tiles_host
,
topk_ids_host
,
block_m
,
tokens
,
experts
,
hidden_size
,
shared_intermediate_size_0
,
topk
,
gate_only
);
auto
o_dev
=
o_buf
.
ToHost
<
ODataType
>
();
auto
o_dev
=
o_buf
.
ToHost
<
ODataType
>
();
// o_dev.savetxt("gpu-out.txt", "float");
// o_dev.savetxt("gpu-out.txt", "float");
...
@@ -491,6 +506,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -491,6 +506,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
prec_sq
,
prec_sq
,
prec_kw
,
prec_kw
,
block_m
,
block_m
,
activation
,
gate_only
,
gate_only
,
fused_quant
};
fused_quant
};
...
@@ -507,7 +523,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -507,7 +523,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
sorted_expert_ids_buf
.
GetDeviceBuffer
(),
sorted_expert_ids_buf
.
GetDeviceBuffer
(),
num_sorted_tiles_buf
.
GetDeviceBuffer
(),
num_sorted_tiles_buf
.
GetDeviceBuffer
(),
hidden_size
,
hidden_size
,
shared_
intermediate_size
_0
,
intermediate_size
/
tp
,
tokens
,
tokens
,
experts
,
experts
,
topk
,
topk
,
...
@@ -529,27 +545,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -529,27 +545,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
if
(
do_validation
)
if
(
do_validation
)
{
{
ck_tile
::
reference_fused_moe
<
AccDataType
,
ck_tile
::
element_wise
::
Gelu
>
(
if
(
activation
==
0
)
a_host
,
{
g_host
,
CPU_FUSED_MOE
(
ck_tile
::
element_wise
::
Gelu
);
d_host
,
}
sa_host
,
else
sg_host
,
{
sd_host
,
CPU_FUSED_MOE
(
ck_tile
::
element_wise
::
Silu
);
sy_host
,
}
o_host
,
sorted_token_ids_host
,
sorted_weight_host
,
sorted_expert_ids_host
,
num_sorted_tiles_host
,
topk_ids_host
,
block_m
,
tokens
,
experts
,
hidden_size
,
shared_intermediate_size_0
,
topk
,
gate_only
);
auto
o_dev
=
o_buf
.
ToHost
<
ODataType
>
();
auto
o_dev
=
o_buf
.
ToHost
<
ODataType
>
();
// o_dev.savetxt("gpu-out.txt", "float");
// o_dev.savetxt("gpu-out.txt", "float");
...
...
example/ck_tile/16_batched_gemm/batched_gemm.cpp
View file @
a11cf2c6
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024
-2025
, Advanced Micro Devices, Inc. All rights reserved.
#include <hip/hip_runtime.h>
#include <hip/hip_runtime.h>
...
@@ -51,7 +51,7 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
...
@@ -51,7 +51,7 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
ck_tile
::
sequence
<
M_Warp
,
N_Warp
,
K_Warp
>
,
ck_tile
::
sequence
<
M_Warp
,
N_Warp
,
K_Warp
>
,
ck_tile
::
sequence
<
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
>>
;
ck_tile
::
sequence
<
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
>>
;
using
TilePartitioner
=
ck_tile
::
GemmTilePartitioner
<
CodegenGemmShape
>
;
using
TilePartitioner
=
ck_tile
::
GemmTile
2D
Partitioner
<
CodegenGemmShape
>
;
using
GemmEpilogue
=
std
::
conditional_t
<
using
GemmEpilogue
=
std
::
conditional_t
<
CShuffleEpilogue
,
CShuffleEpilogue
,
...
@@ -63,8 +63,8 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
...
@@ -63,8 +63,8 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
kOutputRank
,
kOutputRank
,
1
,
1
,
0
,
0
,
TilePartitioner
::
k
M
,
TilePartitioner
::
M
PerBlock
,
TilePartitioner
::
k
N
>>
,
TilePartitioner
::
N
PerBlock
>>
,
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
kPadM
,
kPadN
>>>
;
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
kPadM
,
kPadN
>>>
;
...
...
example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc
View file @
a11cf2c6
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024
-2025
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
template
<
typename
Layout
>
static
constexpr
inline
auto
is_row_major
(
Layout
layout_
)
{
return
ck_tile
::
bool_constant
<
std
::
is_same_v
<
ck_tile
::
remove_cvref_t
<
decltype
(
layout_
)
>
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>>
{};
}
auto
calculate_rtol_atol
(
const
ck_tile
::
index_t
K
,
const
ck_tile
::
index_t
kbatch
,
const
float
max_accumulated_value
)
{
using
ComputeType
=
std
::
conditional_t
<
sizeof
(
ADataType
)
<
sizeof
(
BDataType
),
ADataType
,
BDataType
>
;
// Calculate thresholds
const
auto
rtol
=
ck_tile
::
get_relative_threshold
<
ComputeType
,
CDataType
,
AccDataType
>
(
ck_tile
::
integer_divide_ceil
(
K
,
kbatch
));
const
auto
atol
=
ck_tile
::
get_absolute_threshold
<
ComputeType
,
CDataType
,
AccDataType
>
(
max_accumulated_value
/
kbatch
,
ck_tile
::
integer_divide_ceil
(
K
,
kbatch
));
// Calculate error due to split_k accumulation
const
auto
rtol_split_k
=
ck_tile
::
get_relative_threshold
<
CDataType
,
CDataType
,
CDataType
>
(
kbatch
);
const
auto
atol_split_k
=
ck_tile
::
get_absolute_threshold
<
CDataType
,
CDataType
,
CDataType
>
(
max_accumulated_value
,
kbatch
);
// Use higher threshold
return
ck_tile
::
make_tuple
(
std
::
max
(
rtol
,
rtol_split_k
),
std
::
max
(
atol
,
atol_split_k
));
}
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
float
invoke_batched_gemm
(
ck_tile
::
DeviceMem
&
a_m_k_dev_buf
,
float
invoke_batched_gemm
(
ck_tile
::
DeviceMem
&
a_m_k_dev_buf
,
ck_tile
::
DeviceMem
&
b_k_n_dev_buf
,
ck_tile
::
DeviceMem
&
b_k_n_dev_buf
,
...
@@ -86,56 +113,16 @@ int run_batched_gemm_example_with_layouts(int argc,
...
@@ -86,56 +113,16 @@ int run_batched_gemm_example_with_layouts(int argc,
int
n_warmup
=
arg_parser
.
get_int
(
"warmup"
);
int
n_warmup
=
arg_parser
.
get_int
(
"warmup"
);
int
n_repeat
=
arg_parser
.
get_int
(
"repeat"
);
int
n_repeat
=
arg_parser
.
get_int
(
"repeat"
);
using
namespace
ck_tile
::
literals
;
stride_A
=
ck_tile
::
get_default_stride
(
M
,
K
,
stride_A
,
is_row_major
(
a_layout
));
stride_B
=
ck_tile
::
get_default_stride
(
K
,
N
,
stride_B
,
is_row_major
(
b_layout
));
stride_C
=
ck_tile
::
get_default_stride
(
M
,
N
,
stride_C
,
is_row_major
(
c_layout
));
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
batch_count_
,
ck_tile
::
HostTensor
<
ADataType
>
a_m_k
(
ck_tile
::
host_tensor_descriptor
(
std
::
size_t
row
,
batch_count
,
M
,
K
,
stride_A
,
batch_stride_A
,
is_row_major
(
a_layout
)));
std
::
size_t
col
,
ck_tile
::
HostTensor
<
BDataType
>
b_k_n
(
ck_tile
::
host_tensor_descriptor
(
std
::
size_t
stride
,
batch_count
,
K
,
N
,
stride_B
,
batch_stride_B
,
is_row_major
(
b_layout
)));
std
::
size_t
batch_stride
,
ck_tile
::
HostTensor
<
CDataType
>
c_m_n_dev_result
(
ck_tile
::
host_tensor_descriptor
(
auto
layout
)
{
batch_count
,
M
,
N
,
stride_C
,
batch_stride_C
,
is_row_major
(
c_layout
)));
if
constexpr
(
std
::
is_same_v
<
decltype
(
layout
),
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
ck_tile
::
HostTensorDescriptor
({
batch_count_
,
row
,
col
},
{
batch_stride
,
stride
,
1_
uz
});
}
else
{
return
ck_tile
::
HostTensorDescriptor
({
batch_count_
,
row
,
col
},
{
batch_stride
,
1_
uz
,
stride
});
}
};
auto
f_get_default_stride
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
if
(
stride
==
0
)
{
// give a chance if stride is zero, return a default packed stride
if
constexpr
(
std
::
is_same_v
<
decltype
(
layout
),
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
col
;
}
else
{
return
row
;
}
}
else
return
stride
;
};
stride_A
=
f_get_default_stride
(
M
,
K
,
stride_A
,
a_layout
);
stride_B
=
f_get_default_stride
(
K
,
N
,
stride_B
,
b_layout
);
stride_C
=
f_get_default_stride
(
M
,
N
,
stride_C
,
c_layout
);
ck_tile
::
HostTensor
<
ADataType
>
a_m_k
(
f_host_tensor_descriptor
(
batch_count
,
M
,
K
,
stride_A
,
batch_stride_A
,
a_layout
));
ck_tile
::
HostTensor
<
BDataType
>
b_k_n
(
f_host_tensor_descriptor
(
batch_count
,
K
,
N
,
stride_B
,
batch_stride_B
,
b_layout
));
ck_tile
::
HostTensor
<
CDataType
>
c_m_n_dev_result
(
f_host_tensor_descriptor
(
batch_count
,
M
,
N
,
stride_C
,
batch_stride_C
,
c_layout
));
ck_tile
::
FillUniformDistribution
<
ADataType
>
{
-
5.
f
,
5.
f
}(
a_m_k
);
ck_tile
::
FillUniformDistribution
<
ADataType
>
{
-
5.
f
,
5.
f
}(
a_m_k
);
ck_tile
::
FillUniformDistribution
<
BDataType
>
{
-
5.
f
,
5.
f
}(
b_k_n
);
ck_tile
::
FillUniformDistribution
<
BDataType
>
{
-
5.
f
,
5.
f
}(
b_k_n
);
...
@@ -171,23 +158,33 @@ int run_batched_gemm_example_with_layouts(int argc,
...
@@ -171,23 +158,33 @@ int run_batched_gemm_example_with_layouts(int argc,
if
(
arg_parser
.
get_int
(
"v"
)
==
1
)
if
(
arg_parser
.
get_int
(
"v"
)
==
1
)
{
{
ck_tile
::
HostTensor
<
CDataType
>
c_m_n_host_ref
(
ck_tile
::
HostTensor
<
CDataType
>
c_m_n_host_ref
(
ck_tile
::
host_tensor_descriptor
(
f_host_tensor_descriptor
(
batch_count
,
M
,
N
,
stride_C
,
batch_stride_C
,
CLayout
{}));
batch_count
,
M
,
N
,
stride_C
,
batch_stride_C
,
is_row_major
(
CLayout
)
{}));
c_m_n_host_ref
.
SetZero
();
c_m_n_host_ref
.
SetZero
();
const
auto
b_n_k
=
b_k_n
.
transpose
({
0
,
2
,
1
});
const
auto
b_n_k
=
b_k_n
.
transpose
({
0
,
2
,
1
});
ck_tile
::
reference_batched_gemm
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
>
(
ck_tile
::
reference_batched_gemm
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
>
(
a_m_k
,
b_n_k
,
c_m_n_host_ref
);
a_m_k
,
b_n_k
,
c_m_n_host_ref
);
const
float
max_accumulated_value
=
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
c_m_n_host_ref
);
*
std
::
max_element
(
c_m_n_host_ref
.
mData
.
begin
(),
c_m_n_host_ref
.
mData
.
end
());
const
auto
rtol_atol
=
calculate_rtol_atol
(
K
,
kbatch
,
max_accumulated_value
);
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
c_m_n_host_ref
,
"Error: Incorrect results!"
,
rtol_atol
.
at
(
ck_tile
::
number
<
0
>
{}),
rtol_atol
.
at
(
ck_tile
::
number
<
1
>
{}));
std
::
cout
<<
"Relative error threshold: "
<<
rtol_atol
.
at
(
ck_tile
::
number
<
0
>
{})
<<
" Absolute error threshold: "
<<
rtol_atol
.
at
(
ck_tile
::
number
<
1
>
{})
<<
std
::
endl
;
std
::
cout
<<
"The CPU veification result is:"
<<
(
pass
?
"correct"
:
"fail"
)
<<
std
::
endl
;
std
::
cout
<<
"The CPU veification result is:"
<<
(
pass
?
"correct"
:
"fail"
)
<<
std
::
endl
;
}
}
else
if
(
arg_parser
.
get_int
(
"v"
)
==
2
)
else
if
(
arg_parser
.
get_int
(
"v"
)
==
2
)
{
{
ck_tile
::
HostTensor
<
CDataType
>
c_m_n_gpu_ref
(
ck_tile
::
HostTensor
<
CDataType
>
c_m_n_gpu_ref
(
ck_tile
::
host_tensor_descriptor
(
f_host_tensor_descriptor
(
batch_count
,
M
,
N
,
stride_C
,
batch_stride_C
,
CLayout
{}));
batch_count
,
M
,
N
,
stride_C
,
batch_stride_C
,
is_row_major
(
CLayout
)
{}));
ck_tile
::
DeviceMem
c_m_n_gpu_buf_ref
(
c_m_n_gpu_ref
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
c_m_n_gpu_buf_ref
(
c_m_n_gpu_ref
.
get_element_space_size_in_bytes
());
c_m_n_gpu_ref
.
SetZero
();
c_m_n_gpu_ref
.
SetZero
();
c_m_n_gpu_buf_ref
.
SetZero
();
c_m_n_gpu_buf_ref
.
SetZero
();
...
@@ -240,7 +237,18 @@ int run_batched_gemm_example_with_layouts(int argc,
...
@@ -240,7 +237,18 @@ int run_batched_gemm_example_with_layouts(int argc,
ck_tile
::
hip_check_error
(
hipFree
(
d_C
));
ck_tile
::
hip_check_error
(
hipFree
(
d_C
));
c_m_n_gpu_buf_ref
.
FromDevice
(
c_m_n_gpu_ref
.
data
());
c_m_n_gpu_buf_ref
.
FromDevice
(
c_m_n_gpu_ref
.
data
());
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
c_m_n_gpu_ref
);
const
float
max_accumulated_value
=
*
std
::
max_element
(
c_m_n_gpu_ref
.
mData
.
begin
(),
c_m_n_gpu_ref
.
mData
.
end
());
const
auto
rtol_atol
=
calculate_rtol_atol
(
K
,
kbatch
,
max_accumulated_value
);
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
c_m_n_gpu_ref
,
"Error: Incorrect results!"
,
rtol_atol
.
at
(
ck_tile
::
number
<
0
>
{}),
rtol_atol
.
at
(
ck_tile
::
number
<
1
>
{}));
std
::
cout
<<
"Relative error threshold: "
<<
rtol_atol
.
at
(
ck_tile
::
number
<
0
>
{})
<<
" Absolute error threshold: "
<<
rtol_atol
.
at
(
ck_tile
::
number
<
1
>
{})
<<
std
::
endl
;
std
::
cout
<<
"The GPU verification result is: "
<<
(
pass
?
"correct"
:
"fail"
)
<<
std
::
endl
;
std
::
cout
<<
"The GPU verification result is: "
<<
(
pass
?
"correct"
:
"fail"
)
<<
std
::
endl
;
}
}
...
...
example/ck_tile/17_grouped_gemm/grouped_gemm.cpp
View file @
a11cf2c6
...
@@ -15,7 +15,6 @@
...
@@ -15,7 +15,6 @@
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/host.hpp"
#include "grouped_gemm.hpp"
#include "grouped_gemm.hpp"
#include "utils.hpp"
namespace
{
namespace
{
...
@@ -102,7 +101,7 @@ using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner,
...
@@ -102,7 +101,7 @@ using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner,
GemmEpilogue
<
CLayout
>>
;
GemmEpilogue
<
CLayout
>>
;
};
// namespace
};
// namespace
std
::
size_t
G
et
W
orkspace
S
ize
(
const
std
::
vector
<
grouped_gemm_kargs
>&
gemm_descs
)
std
::
size_t
g
et
_w
orkspace
_s
ize
(
const
std
::
vector
<
grouped_gemm_kargs
>&
gemm_descs
)
{
{
return
::
Kernel
<
std
::
nullptr_t
,
std
::
nullptr_t
,
std
::
nullptr_t
>::
GetWorkSpaceSize
(
gemm_descs
);
return
::
Kernel
<
std
::
nullptr_t
,
std
::
nullptr_t
,
std
::
nullptr_t
>::
GetWorkSpaceSize
(
gemm_descs
);
}
}
...
...
example/ck_tile/17_grouped_gemm/grouped_gemm.hpp
View file @
a11cf2c6
...
@@ -52,8 +52,8 @@ auto create_args(int argc, char* argv[])
...
@@ -52,8 +52,8 @@ auto create_args(int argc, char* argv[])
return
std
::
make_tuple
(
result
,
arg_parser
);
return
std
::
make_tuple
(
result
,
arg_parser
);
}
}
std
::
size_t
G
et
W
orkspace
S
ize
(
const
std
::
vector
<
grouped_gemm_kargs
>&
gemm_descs
);
std
::
size_t
g
et
_w
orkspace
_s
ize
(
const
std
::
vector
<
grouped_gemm_kargs
>&
gemm_descs
);
float
grouped_gemm
_calc
(
const
std
::
vector
<
grouped_gemm_kargs
>&
gemm_descs
,
float
grouped_gemm
(
const
std
::
vector
<
grouped_gemm_kargs
>&
gemm_descs
,
const
ck_tile
::
stream_config
&
s
,
const
ck_tile
::
stream_config
&
s
,
void
*
p_workspace_
);
void
*
p_workspace_
);
example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc
View file @
a11cf2c6
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024
-2025
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
template
<
typename
Layout
>
static
constexpr
inline
auto
is_row_major
(
Layout
layout_
)
{
return
ck_tile
::
bool_constant
<
std
::
is_same_v
<
ck_tile
::
remove_cvref_t
<
decltype
(
layout_
)
>
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>>
{};
}
/// Computes (rtol, atol) error thresholds for verifying a (possibly split-K)
/// GEMM result against a reference.
///
/// @param K                     Reduction dimension length.
/// @param kbatch                Number of split-K partitions (1 = no split).
/// @param max_accumulated_value Largest value observed in the reference output,
///                              used to scale the absolute threshold.
/// @return ck_tile tuple {rtol, atol} — the looser of the per-partition
///         accumulation bound and the split-K combination bound.
auto calculate_rtol_atol(const ck_tile::index_t K,
                         const ck_tile::index_t kbatch,
                         const float max_accumulated_value)
{
    // Accumulation error is driven by the narrower of the two input types.
    using ComputeType =
        std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;

    // Each split-K partition accumulates ceil(K / kbatch) products.
    const auto k_per_batch = ck_tile::integer_divide_ceil(K, kbatch);

    // Thresholds for accumulation inside one partition.
    const auto rtol =
        ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(k_per_batch);
    const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
        max_accumulated_value / kbatch, k_per_batch);

    // Thresholds for combining the kbatch partial results in CDataType.
    const auto rtol_split_k =
        ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
    const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
        max_accumulated_value, kbatch);

    // Report the looser (larger) of the two bounds in each dimension.
    return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
float
invoke_gemm
(
int
n_warmup
,
float
invoke_gemm
(
int
n_warmup
,
int
n_repeat
,
int
n_repeat
,
...
@@ -11,7 +38,7 @@ float invoke_gemm(int n_warmup,
...
@@ -11,7 +38,7 @@ float invoke_gemm(int n_warmup,
{
{
ck_tile
::
DeviceMem
gemm_workspace
;
ck_tile
::
DeviceMem
gemm_workspace
;
gemm_workspace
.
Realloc
(
G
et
W
orkspace
S
ize
(
args
));
gemm_workspace
.
Realloc
(
g
et
_w
orkspace
_s
ize
(
args
));
float
ave_time
=
grouped_gemm
<
ALayout
,
BLayout
,
CLayout
>
(
float
ave_time
=
grouped_gemm
<
ALayout
,
BLayout
,
CLayout
>
(
args
,
args
,
...
@@ -108,16 +135,19 @@ int run_grouped_gemm_example_with_layouts(int argc,
...
@@ -108,16 +135,19 @@ int run_grouped_gemm_example_with_layouts(int argc,
const
ck_tile
::
index_t
N
=
Ns
[
i
];
const
ck_tile
::
index_t
N
=
Ns
[
i
];
const
ck_tile
::
index_t
K
=
Ks
[
i
];
const
ck_tile
::
index_t
K
=
Ks
[
i
];
stride_As
[
i
]
=
f_get_default_stride
(
M
,
N
,
stride_As
[
i
],
a_layout
);
stride_As
[
i
]
=
stride_Bs
[
i
]
=
f_get_default_stride
(
K
,
N
,
stride_Bs
[
i
],
b_layout
);
ck_tile
::
get_default_stride
(
M
,
N
,
stride_As
[
i
],
is_row_major
(
a_layout
));
stride_Cs
[
i
]
=
f_get_default_stride
(
M
,
N
,
stride_Cs
[
i
],
CLayout
{});
stride_Bs
[
i
]
=
ck_tile
::
get_default_stride
(
K
,
N
,
stride_Bs
[
i
],
is_row_major
(
b_layout
));
a_m_k_tensors
.
push_back
(
stride_Cs
[
i
]
=
ck_tile
::
HostTensor
<
ADataType
>
(
f_host_tensor_descriptor
(
M
,
K
,
stride_As
[
i
],
a_layout
)));
ck_tile
::
get_default_stride
(
M
,
N
,
stride_Cs
[
i
],
is_row_major
(
CLayout
{}));
b_k_n_tensors
.
push_back
(
ck_tile
::
HostTensor
<
BDataType
>
(
f_host_tensor_descriptor
(
K
,
N
,
stride_Bs
[
i
],
b_layout
)));
a_m_k_tensors
.
push_back
(
ck_tile
::
HostTensor
<
ADataType
>
(
ck_tile
::
host_tensor_descriptor
(
M
,
K
,
stride_As
[
i
],
is_row_major
(
a_layout
))));
b_k_n_tensors
.
push_back
(
ck_tile
::
HostTensor
<
BDataType
>
(
ck_tile
::
host_tensor_descriptor
(
K
,
N
,
stride_Bs
[
i
],
is_row_major
(
b_layout
))));
c_m_n_tensors
.
push_back
(
ck_tile
::
HostTensor
<
CDataType
>
(
c_m_n_tensors
.
push_back
(
ck_tile
::
HostTensor
<
CDataType
>
(
f_
host_tensor_descriptor
(
M
,
N
,
stride_Cs
[
i
],
CLayout
{})));
ck_tile
::
host_tensor_descriptor
(
M
,
N
,
stride_Cs
[
i
],
is_row_major
(
CLayout
{})))
)
;
std
::
cout
<<
"gemm["
<<
i
<<
"]"
std
::
cout
<<
"gemm["
<<
i
<<
"]"
<<
" a_m_k: "
<<
a_m_k_tensors
[
i
]
.
mDesc
<<
" b_k_n: "
<<
b_k_n_tensors
[
i
]
.
mDesc
<<
" a_m_k: "
<<
a_m_k_tensors
[
i
]
.
mDesc
<<
" b_k_n: "
<<
b_k_n_tensors
[
i
]
.
mDesc
...
@@ -157,12 +187,23 @@ int run_grouped_gemm_example_with_layouts(int argc,
...
@@ -157,12 +187,23 @@ int run_grouped_gemm_example_with_layouts(int argc,
{
{
for
(
int
i
=
0
;
i
<
group_count
;
++
i
)
for
(
int
i
=
0
;
i
<
group_count
;
++
i
)
{
{
ck_tile
::
HostTensor
<
CDataType
>
c_m_n_host_ref
(
ck_tile
::
HostTensor
<
CDataType
>
c_m_n_host_ref
(
ck_tile
::
host_tensor_descriptor
(
f_host_tensor_descriptor
(
Ms
[
i
],
Ns
[
i
],
stride_Cs
[
i
],
CLayout
{}));
Ms
[
i
],
Ns
[
i
],
stride_Cs
[
i
],
is_row_major
(
CLayout
{}))
)
;
c_m_n_host_ref
.
SetZero
();
c_m_n_host_ref
.
SetZero
();
ck_tile
::
reference_gemm
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
>
(
ck_tile
::
reference_gemm
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
>
(
a_m_k_tensors
[
i
],
b_k_n_tensors
[
i
],
c_m_n_host_ref
);
a_m_k_tensors
[
i
],
b_k_n_tensors
[
i
],
c_m_n_host_ref
);
pass
&=
ck_tile
::
check_err
(
c_m_n_tensors
[
i
],
c_m_n_host_ref
);
const
float
max_accumulated_value
=
*
std
::
max_element
(
c_m_n_host_ref
.
mData
.
begin
(),
c_m_n_host_ref
.
mData
.
end
());
const
auto
rtol_atol
=
calculate_rtol_atol
(
Ks
[
i
],
1
/*kbatch*/
,
max_accumulated_value
);
pass
&=
ck_tile
::
check_err
(
c_m_n_tensors
[
i
],
c_m_n_host_ref
,
"Error: Incorrect results!"
,
rtol_atol
.
at
(
ck_tile
::
number
<
0
>
{}),
rtol_atol
.
at
(
ck_tile
::
number
<
1
>
{}));
std
::
cout
<<
"gemm["
<<
i
<<
"] Relative error threshold: "
<<
rtol_atol
.
at
(
ck_tile
::
number
<
0
>
{})
<<
" Absolute error threshold: "
<<
rtol_atol
.
at
(
ck_tile
::
number
<
1
>
{})
<<
std
::
endl
;
}
}
std
::
cout
<<
"The CPU veification result is:"
<<
(
pass
?
"correct"
:
"fail"
)
<<
std
::
endl
;
std
::
cout
<<
"The CPU veification result is:"
<<
(
pass
?
"correct"
:
"fail"
)
<<
std
::
endl
;
}
}
...
...
example/ck_tile/17_grouped_gemm/utils.hpp
deleted
100644 → 0
View file @
a72e9efa
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
template
<
typename
TLayout
>
constexpr
auto
f_host_tensor_descriptor
(
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
TLayout
layout
)
{
using
namespace
ck_tile
::
literals
;
if
constexpr
(
std
::
is_same_v
<
decltype
(
layout
),
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
ck_tile
::
HostTensorDescriptor
({
row
,
col
},
{
stride
,
1
_uz
});
}
else
{
return
ck_tile
::
HostTensorDescriptor
({
row
,
col
},
{
1
_uz
,
stride
});
}
}
/// Resolves a leading-dimension stride for a {row, col} matrix.
/// A caller-supplied non-zero `stride` is returned unchanged; a value of 0
/// means "derive from the layout": row-major rows span `col` elements,
/// column-major columns span `row` elements.
template <typename TLayout>
constexpr auto f_get_default_stride(std::size_t row,
                                    std::size_t col,
                                    std::size_t stride,
                                    TLayout layout)
{
    // Explicit stride always wins.
    if(stride != 0)
    {
        return stride;
    }

    if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
    {
        return col;
    }
    else
    {
        return row;
    }
}
include/ck/ck.hpp
View file @
a11cf2c6
...
@@ -17,7 +17,9 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
...
@@ -17,7 +17,9 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
#endif
#endif
// to do: add various levels of logging with CK_LOG_LEVEL
// to do: add various levels of logging with CK_LOG_LEVEL
#ifndef CK_TIME_KERNEL
#define CK_TIME_KERNEL 1
#define CK_TIME_KERNEL 1
#endif
// constant address space for kernel parameter
// constant address space for kernel parameter
// https://llvm.org/docs/AMDGPUUsage.html#address-spaces
// https://llvm.org/docs/AMDGPUUsage.html#address-spaces
...
@@ -155,6 +157,9 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
...
@@ -155,6 +157,9 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
// LDS direct loads using inline assembly
// LDS direct loads using inline assembly
#define CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM 0
#define CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM 0
// set rounding to nearest even as default for bf16 conversions
#define CK_USE_RNE_BF16_CONVERSION 1
// set rounding to nearest even as default for f8 conversions
// set rounding to nearest even as default for f8 conversions
#define CK_USE_SR_F8_CONVERSION 0
#define CK_USE_SR_F8_CONVERSION 0
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
View file @
a11cf2c6
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -122,19 +122,6 @@ __global__ void
...
@@ -122,19 +122,6 @@ __global__ void
static_for
<
0
,
NumDTensor
,
1
>
{}(
static_for
<
0
,
NumDTensor
,
1
>
{}(
[
&
](
auto
i
)
{
p_ds_grid_grp
(
i
)
=
p_ds_grid
[
i
]
+
ds_group_offset
[
i
];
});
[
&
](
auto
i
)
{
p_ds_grid_grp
(
i
)
=
p_ds_grid
[
i
]
+
ds_group_offset
[
i
];
});
if
constexpr
(
is_same_v
<
AElementwiseOperation
,
element_wise
::
DynamicUnaryOp
>
)
{
a_element_op
.
InitUnaryOpPtrOnDevice
();
}
if
constexpr
(
is_same_v
<
BElementwiseOperation
,
element_wise
::
DynamicUnaryOp
>
)
{
b_element_op
.
InitUnaryOpPtrOnDevice
();
}
if
constexpr
(
is_same_v
<
CDEElementwiseOperation
,
element_wise
::
DynamicUnaryOp
>
)
{
cde_element_op
.
InitUnaryOpPtrOnDevice
();
}
if
constexpr
(
isMultiA
||
isMultiB
)
if
constexpr
(
isMultiA
||
isMultiB
)
{
{
AsPointer
p_as_grid_grp
;
AsPointer
p_as_grid_grp
;
...
...
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
View file @
a11cf2c6
This diff is collapsed.
Click to expand it.
include/ck/utility/data_type.hpp
View file @
a11cf2c6
...
@@ -31,8 +31,6 @@ struct pk_i4_t
...
@@ -31,8 +31,6 @@ struct pk_i4_t
type
data
;
type
data
;
__host__
__device__
constexpr
pk_i4_t
()
:
data
{
type
{}}
{}
__host__
__device__
constexpr
pk_i4_t
()
:
data
{
type
{}}
{}
__host__
__device__
constexpr
pk_i4_t
(
type
init
)
:
data
{
init
}
{}
__host__
__device__
constexpr
pk_i4_t
(
type
init
)
:
data
{
init
}
{}
__host__
__device__
constexpr
operator
float
()
const
{
return
static_cast
<
int8_t
>
(
data
);
}
};
};
inline
constexpr
auto
next_pow2
(
uint32_t
x
)
inline
constexpr
auto
next_pow2
(
uint32_t
x
)
...
...
include/ck/utility/dynamic_buffer.hpp
View file @
a11cf2c6
This diff is collapsed.
Click to expand it.
include/ck/utility/type_convert.hpp
View file @
a11cf2c6
This diff is collapsed.
Click to expand it.
include/ck_tile/core.hpp
View file @
a11cf2c6
...
@@ -54,7 +54,6 @@
...
@@ -54,7 +54,6 @@
#include "ck_tile/core/tensor/tile_window_linear.hpp"
#include "ck_tile/core/tensor/tile_window_linear.hpp"
#include "ck_tile/core/tensor/tile_window_utils.hpp"
#include "ck_tile/core/tensor/tile_window_utils.hpp"
#include "ck_tile/core/tensor/update_tile.hpp"
#include "ck_tile/core/tensor/update_tile.hpp"
#include "ck_tile/core/utility/amd_address_space.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/functional_with_tuple.hpp"
#include "ck_tile/core/utility/functional_with_tuple.hpp"
...
...
Prev
1
2
3
4
5
6
7
8
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment