gaoqiong / composable_kernel_ROCM / Commits

Commit a11cf2c6 (unverified)
Authored Jan 24, 2025 by arai713; committed by GitHub on Jan 24, 2025
Merge branch 'develop' into codegen_hiprtc
Parents: a72e9efa, 64d5c4d6
Changes: 156. Showing 20 changed files with 572 additions and 967 deletions (+572, -967).
example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp                  +1    -0
example/ck_tile/15_fused_moe/instances/fused_moegemm_api.cpp              +56   -4
example/ck_tile/15_fused_moe/instances/fused_moegemm_api_internal.hpp     +25   -15
example/ck_tile/15_fused_moe/instances/fused_moegemm_api_traits.hpp       +4    -2
example/ck_tile/15_fused_moe/instances/fused_moegemm_bf16_m32.cpp         +12   -1
example/ck_tile/15_fused_moe/instances/fused_moegemm_fp16_m32.cpp         +13   -1
example/ck_tile/15_fused_moe/main.cpp                                     +55   -52
example/ck_tile/16_batched_gemm/batched_gemm.cpp                          +4    -4
example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc              +66   -58
example/ck_tile/17_grouped_gemm/grouped_gemm.cpp                          +1    -2
example/ck_tile/17_grouped_gemm/grouped_gemm.hpp                          +4    -4
example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc              +55   -14
example/ck_tile/17_grouped_gemm/utils.hpp                                 +0    -38
include/ck/ck.hpp                                                         +5    -0
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp  +1  -14
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp  +213  -686
include/ck/utility/data_type.hpp                                          +0    -2
include/ck/utility/dynamic_buffer.hpp                                     +17   -6
include/ck/utility/type_convert.hpp                                       +40   -63
include/ck_tile/core.hpp                                                  +0    -1
example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp

@@ -41,6 +41,7 @@ float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_config& s)
                                   t.prec_sq,
                                   t.prec_kw,
                                   t.block_m,
+                                  t.activation,
                                   t.gate_only,
                                   t.fused_quant};
    auto a1 = fused_moegemm_args{
...
example/ck_tile/15_fused_moe/instances/fused_moegemm_api.cpp

@@ -17,15 +17,67 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile::stream_config& s)
     // clang-format off
     float r = -1;
-    if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
-       t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1)
-    {
-        using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float,
-                         S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>;
-        r = fused_moegemm_<t_>(s, a);
-    }
+    if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
+       t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 &&
+       t.gate_only == 1 && t.activation == 0)
+    {
+        constexpr ck_tile::index_t act_ = 0;
+        constexpr ck_tile::index_t go_  = 1;
+        using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float,
+                         S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
+        r = fused_moegemm_<t_>(s, a);
+    }
+    else if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" && t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 0 && t.activation == 0)
+    {
+        constexpr ck_tile::index_t act_ = 0; constexpr ck_tile::index_t go_ = 0;
+        using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
+        r = fused_moegemm_<t_>(s, a);
+    }
+    else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" && t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1 && t.activation == 0)
+    {
+        constexpr ck_tile::index_t act_ = 0; constexpr ck_tile::index_t go_ = 1;
+        using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
+        r = fused_moegemm_<t_>(s, a);
+    }
+    else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" && t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 0 && t.activation == 0)
+    {
+        constexpr ck_tile::index_t act_ = 0; constexpr ck_tile::index_t go_ = 0;
+        using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
+        r = fused_moegemm_<t_>(s, a);
+    }
+    else if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" && t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1 && t.activation == 1)
+    {
+        constexpr ck_tile::index_t act_ = 1; constexpr ck_tile::index_t go_ = 1;
+        using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
+        r = fused_moegemm_<t_>(s, a);
+    }
+    else if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" && t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 0 && t.activation == 1)
+    {
+        constexpr ck_tile::index_t act_ = 1; constexpr ck_tile::index_t go_ = 0;
+        using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
+        r = fused_moegemm_<t_>(s, a);
+    }
+    else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" && t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1 && t.activation == 1)
+    {
+        constexpr ck_tile::index_t act_ = 1; constexpr ck_tile::index_t go_ = 1;
+        using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
+        r = fused_moegemm_<t_>(s, a);
+    }
-    else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" &&
-            t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1)
-    {
-        using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float,
-                         S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>;
-        r = fused_moegemm_<t_>(s, a);
-    }
+    else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" &&
+            t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 &&
+            t.gate_only == 0 && t.activation == 1)
+    {
+        constexpr ck_tile::index_t act_ = 1;
+        constexpr ck_tile::index_t go_  = 0;
+        using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float,
+                         S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
+        r = fused_moegemm_<t_>(s, a);
+    }
     // clang-format on
...
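The dispatch above maps runtime trait strings ("bf16", gate_only, activation, ...) onto compile-time template parameters, with r = -1 signalling that no instance exists for the requested combination. A minimal stand-alone sketch of the same pattern follows; the names (run_kernel, dispatch) are illustrative only and not part of the CK API.

    #include <iostream>
    #include <string>

    template <int Activation, int GateOnly>
    float run_kernel()
    {
        // A real implementation would launch the corresponding GPU kernel instance here.
        return static_cast<float>(Activation * 10 + GateOnly);
    }

    float dispatch(const std::string& prec, int activation, int gate_only)
    {
        float r = -1;
        if(prec == "bf16" && activation == 0 && gate_only == 1)      { r = run_kernel<0, 1>(); }
        else if(prec == "bf16" && activation == 0 && gate_only == 0) { r = run_kernel<0, 0>(); }
        else if(prec == "bf16" && activation == 1 && gate_only == 1) { r = run_kernel<1, 1>(); }
        else if(prec == "bf16" && activation == 1 && gate_only == 0) { r = run_kernel<1, 0>(); }
        return r; // -1 means "no instance available for this configuration"
    }

    int main()
    {
        std::cout << dispatch("bf16", 1, 0) << '\n'; // selects run_kernel<1, 0>()
    }

Each runtime branch has to name an instantiation that was compiled somewhere, which is why the instance files below grow from one to four explicit instantiations per precision.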
example/ck_tile/15_fused_moe/instances/fused_moegemm_api_internal.hpp

@@ -21,21 +21,31 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
                                            typename Ts_::BlockTile_1,
                                            typename Ts_::WarpPerBlock_0,
                                            typename Ts_::WarpTile_0>;

-    using f_problem = ck_tile::FusedMoeGemmPipelineProblem<typename Ts_::ADataType,
-                                                           typename Ts_::GDataType,
-                                                           typename Ts_::DDataType,
-                                                           typename Ts_::AccDataType,
-                                                           typename Ts_::ODataType,
-                                                           typename Ts_::AScaleDataType,
-                                                           typename Ts_::GScaleDataType,
-                                                           typename Ts_::DScaleDataType,
-                                                           typename Ts_::YSmoothScaleDataType,
-                                                           typename Ts_::TopkWeightDataType,
-                                                           typename Ts_::IndexDataType,
-                                                           ck_tile::element_wise::FastGeluAsm, // TODO: hardcoded
-                                                           f_shape,
-                                                           f_traits>;
+    constexpr auto get_activation_ = []() {
+        if constexpr(Ts_::Activation == 0)
+        {
+            return ck_tile::element_wise::FastGeluAsm{};
+        }
+        else
+            return ck_tile::element_wise::Silu{};
+    };
+    using f_act_ = ck_tile::remove_cvref_t<decltype(get_activation_())>;
+
+    using f_problem = ck_tile::FusedMoeGemmPipelineProblem<typename Ts_::ADataType,
+                                                           typename Ts_::GDataType,
+                                                           typename Ts_::DDataType,
+                                                           typename Ts_::AccDataType,
+                                                           typename Ts_::ODataType,
+                                                           typename Ts_::AScaleDataType,
+                                                           typename Ts_::GScaleDataType,
+                                                           typename Ts_::DScaleDataType,
+                                                           typename Ts_::YSmoothScaleDataType,
+                                                           typename Ts_::TopkWeightDataType,
+                                                           typename Ts_::IndexDataType,
+                                                           f_act_, // TODO: hardcoded
+                                                           f_shape,
+                                                           f_traits>;

     // using f_pipeline    = ck_tile::FusedMoeGemmPipeline_FlatmmEx<f_problem>;
     using f_pipeline    = ck_tile::FusedMoeGemmPipeline_FlatmmUk<f_problem>;
...
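The immediately invoked constexpr lambda above is a compact way to choose one of several functor types at compile time without writing a partial specialization: only the surviving if-constexpr branch is instantiated, and decltype() captures its return type. A self-contained sketch of the idiom, using two toy activation functors rather than the CK ones, looks like this:

    #include <cmath>
    #include <iostream>
    #include <type_traits>

    struct FastGelu { float operator()(float x) const { return 0.5f * x * (1.f + std::tanh(0.7978845608f * (x + 0.044715f * x * x * x))); } };
    struct Silu     { float operator()(float x) const { return x / (1.f + std::exp(-x)); } };

    template <int Activation>
    float apply(float x)
    {
        // Returning different types from the two branches is fine: only one branch
        // survives instantiation, and decltype() names the type it returns.
        constexpr auto get_activation = []() {
            if constexpr(Activation == 0) { return FastGelu{}; }
            else                          { return Silu{}; }
        };
        using Act = std::remove_cv_t<decltype(get_activation())>;
        return Act{}(x);
    }

    int main()
    {
        std::cout << apply<0>(1.0f) << ' ' << apply<1>(1.0f) << '\n';
    }

The toy FastGelu formula is only an illustration; the diff itself selects between ck_tile::element_wise::FastGeluAsm and ck_tile::element_wise::Silu.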
example/ck_tile/15_fused_moe/instances/fused_moegemm_api_traits.hpp

@@ -15,7 +15,8 @@ template <typename I,
           typename KW,
           typename BlockTIle_,               // seq<b_token, b_interm, b_hidden, b_down>
           typename WarpPerBlock_,
           typename WarpTile_,                // seq<*,*,*>, used to select mfma
+          ck_tile::index_t Activation_ = 0,  // 0: Gelu 1: Silu
           ck_tile::index_t GateOnly_   = 0,
           ck_tile::index_t FusedQuant_ = 0>
 struct fmoe_ // traits, ugly name, only used for internal
...
@@ -44,10 +45,11 @@ struct fmoe_ // traits, ugly name, only used for internal
     using WarpPerBlock_0 = ck_tile::remove_cvref_t<WarpPerBlock_>;
     using WarpTile_0     = ck_tile::remove_cvref_t<WarpTile_>;

-    using BlockTile_1    = ck_tile::sequence<BT_, BD_, BI_ / (GateOnly_ ? 1 : 2)>;
+    using BlockTile_1    = ck_tile::sequence<BT_, BD_, BI_>;
     using WarpPerBlock_1 = ck_tile::remove_cvref_t<WarpPerBlock_>;
     using WarpTile_1     = ck_tile::remove_cvref_t<WarpTile_>;

+    static constexpr ck_tile::index_t Activation = Activation_; // 0: Gelu 1: Silu
     static constexpr ck_tile::index_t GateOnly   = GateOnly_;
     static constexpr ck_tile::index_t FusedQuant = FusedQuant_;
 };
example/ck_tile/15_fused_moe/instances/fused_moegemm_bf16_m32.cpp

@@ -8,7 +8,18 @@
 // clang-format off
-template float fused_moegemm_<fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float,
-    S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>>(const ck_tile::stream_config& s, fused_moegemm_args a);
+template float fused_moegemm_<fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float,
+    S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 0, 0>>(const ck_tile::stream_config& s, fused_moegemm_args a);
+template float fused_moegemm_<fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float,
+    S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 1, 0>>(const ck_tile::stream_config& s, fused_moegemm_args a);
+template float fused_moegemm_<fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float,
+    S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0, 0>>(const ck_tile::stream_config& s, fused_moegemm_args a);
+template float fused_moegemm_<fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float,
+    S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 1, 0>>(const ck_tile::stream_config& s, fused_moegemm_args a);
 // clang-format on
example/ck_tile/15_fused_moe/instances/fused_moegemm_fp16_m32.cpp

@@ -8,7 +8,19 @@
 // clang-format off
-template float fused_moegemm_<fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float,
-    S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>>(const ck_tile::stream_config& s, fused_moegemm_args a);
+template float fused_moegemm_<fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float,
+    S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 0, 0>>(const ck_tile::stream_config& s, fused_moegemm_args a);
+template float fused_moegemm_<fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float,
+    S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 1, 0>>(const ck_tile::stream_config& s, fused_moegemm_args a);
+template float fused_moegemm_<fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float,
+    S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0, 0>>(const ck_tile::stream_config& s, fused_moegemm_args a);
+template float fused_moegemm_<fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float,
+    S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 1, 0>>(const ck_tile::stream_config& s, fused_moegemm_args a);
 // clang-format on
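These per-precision .cpp files exist so that each kernel variant is compiled once, in its own translation unit, and every caller only sees a declaration. The mechanism is plain C++ explicit template instantiation; a reduced sketch (the kernel_ function and Traits types are hypothetical, not the CK symbols):

    #include <iostream>

    // Normally in a header: declaration only, no definition visible to most callers.
    template <typename Traits>
    float kernel_(float x);

    // Normally in one .cpp file: the definition plus the explicit instantiations
    // that other translation units are allowed to link against.
    template <typename Traits>
    float kernel_(float x) { return x * Traits::scale; }

    struct TraitsA { static constexpr float scale = 2.0f; };
    struct TraitsB { static constexpr float scale = 4.0f; };

    template float kernel_<TraitsA>(float x); // emitted in this TU
    template float kernel_<TraitsB>(float x); // emitted in this TU

    int main() { std::cout << kernel_<TraitsA>(3.f) << '\n'; }

Adding the Activation template parameter therefore doubles the number of instantiations each instance file has to emit, which is exactly what the +12/+13 line counts above reflect.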
example/ck_tile/15_fused_moe/main.cpp

@@ -108,12 +108,14 @@ auto create_args(int argc, char* argv[])
         .insert("gate_only", "1", "w0(gate/up) style, 0:gate+up will double interm size, 1:only gate")
         .insert("api", "0", "benchmark api set: 0:fused-moe(moe-gemm+moe-sorting), 1:moe-gemm")
+        .insert("act", "0", "activation after first gemm. 0:gelu, 1:silu")
         .insert("balance", "0", "if set to 1, will try balance the expert in topk-ids(convenient for testing)")
-        .insert("init", "2", "init method. 0:random stepped float(fast). 1: random uniform, 2:rand normalized(slow)")
+        .insert("init", "1", "init method. 0:random stepped float(fast). 1: random uniform[-0.5, 0.5], 2:rand "
+                             "normalized[0, 1]")
         .insert("seed", "11939", "seed used to do random")
         .insert("warmup", "5", "cold iter")
...
@@ -135,6 +137,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
     ck_tile::index_t intermediate_size = arg_parser.get_int("i");
     ck_tile::index_t stride            = arg_parser.get_int("stride");
     ck_tile::index_t block_m           = arg_parser.get_int("bm");
+    ck_tile::index_t activation        = arg_parser.get_int("act");
     if(stride < 0)
         stride = hidden_size;
     std::string prec_i = arg_parser.get_str("prec_i");
...
@@ -194,11 +197,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
         return std::string(", st:") + std::to_string(stride);
     }();

-    std::cout << "[" << api_str << "|" << prec_str << "]" << " t:" << tokens << ", e:" << experts
-              << ", k:" << topk << stride_str << ", hidden:" << hidden_size
-              << ", interm:" << intermediate_size << ", tp:" << tp
-              << ", shrd_interm:" << shared_intermediate_size_0 << "|" << shared_intermediate_size_1
-              << ", go:" << gate_only << ", q:" << fused_quant << std::flush;
+    std::cout << "[" << api_str << "|" << prec_str << "]" << " t:" << tokens << ", e:" << experts
+              << ", k:" << topk << stride_str << ", hidden:" << hidden_size
+              << ", interm:" << intermediate_size << ", tp:" << tp << ", act:" << activation
+              // << ", shrd_interm:" << shared_intermediate_size_0 << "|" << shared_intermediate_size_1
+              << (gate_only ? ", g1u0" : ", g1u1") << ", q:" << fused_quant << std::flush;

     using TypeConfig = FusedMoeGemmTypeConfig<I, W, O, ST, SW, SQ, KW>;
     using ADataType  = typename TypeConfig::ADataType;
...
@@ -370,6 +376,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
                              prec_sq,
                              prec_kw,
                              block_m,
+                             activation,
                              gate_only,
                              fused_quant};
...
@@ -389,7 +396,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
                              num_sorted_tiles_buf.GetDeviceBuffer(),
                              block_m,
                              hidden_size,
-                             shared_intermediate_size_0,
+                             intermediate_size / tp,
                              tokens,
                              experts,
                              topk,
...
@@ -408,6 +415,28 @@ bool run(const ck_tile::ArgParser& arg_parser)
                   << cal_tbps(ave_time) << " TB/s" << std::flush;

         bool pass = true;

+#define CPU_FUSED_MOE(act_type_)                                                  \
+    ck_tile::reference_fused_moe<AccDataType, act_type_>(a_host,                  \
+                                                         g_host,                  \
+                                                         d_host,                  \
+                                                         sa_host,                 \
+                                                         sg_host,                 \
+                                                         sd_host,                 \
+                                                         sy_host,                 \
+                                                         o_host,                  \
+                                                         sorted_token_ids_host,   \
+                                                         sorted_weight_host,      \
+                                                         sorted_expert_ids_host,  \
+                                                         num_sorted_tiles_host,   \
+                                                         topk_ids_host,           \
+                                                         block_m,                 \
+                                                         tokens,                  \
+                                                         experts,                 \
+                                                         hidden_size,             \
+                                                         intermediate_size / tp,  \
+                                                         topk,                    \
+                                                         gate_only)

         if(do_validation)
         {
             ck_tile::reference_moe_sorting<TopkWeightDataType, IndexDataType>(
...
@@ -419,28 +448,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
                 num_sorted_tiles_host.mData[0],
                 experts,
                 block_m);

-            ck_tile::reference_fused_moe<AccDataType, ck_tile::element_wise::Gelu>(
-                a_host, g_host, d_host, sa_host, sg_host, sd_host, sy_host, o_host,
-                sorted_token_ids_host, sorted_weight_host, sorted_expert_ids_host,
-                num_sorted_tiles_host, topk_ids_host, block_m, tokens, experts,
-                hidden_size, shared_intermediate_size_0, topk, gate_only);
+            if(activation == 0)
+            {
+                CPU_FUSED_MOE(ck_tile::element_wise::Gelu);
+            }
+            else
+            {
+                CPU_FUSED_MOE(ck_tile::element_wise::Silu);
+            }

             auto o_dev = o_buf.ToHost<ODataType>();
             // o_dev.savetxt("gpu-out.txt", "float");
...
@@ -491,6 +506,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
                              prec_sq,
                              prec_kw,
                              block_m,
+                             activation,
                              gate_only,
                              fused_quant};
...
@@ -507,7 +523,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
                              sorted_expert_ids_buf.GetDeviceBuffer(),
                              num_sorted_tiles_buf.GetDeviceBuffer(),
                              hidden_size,
-                             shared_intermediate_size_0,
+                             intermediate_size / tp,
                              tokens,
                              experts,
                              topk,
...
@@ -529,27 +545,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
         if(do_validation)
         {
-            ck_tile::reference_fused_moe<AccDataType, ck_tile::element_wise::Gelu>(
-                a_host, g_host, d_host, sa_host, sg_host, sd_host, sy_host, o_host,
-                sorted_token_ids_host, sorted_weight_host, sorted_expert_ids_host,
-                num_sorted_tiles_host, topk_ids_host, block_m, tokens, experts,
-                hidden_size, shared_intermediate_size_0, topk, gate_only);
+            if(activation == 0)
+            {
+                CPU_FUSED_MOE(ck_tile::element_wise::Gelu);
+            }
+            else
+            {
+                CPU_FUSED_MOE(ck_tile::element_wise::Silu);
+            }

             auto o_dev = o_buf.ToHost<ODataType>();
             // o_dev.savetxt("gpu-out.txt", "float");
...
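The CPU_FUSED_MOE macro only exists so the long reference-call argument list is written once while the runtime act flag still picks a compile-time activation type. The same selection can be expressed without a macro by funnelling the arguments through one templated helper; a hedged sketch with toy stand-ins (not ck_tile::reference_fused_moe):

    #include <cstddef>
    #include <iostream>
    #include <vector>

    struct Gelu { float operator()(float x) const { return x; } }; // toy stand-ins
    struct Silu { float operator()(float x) const { return x; } };

    template <typename Act>
    void reference_moe(const std::vector<float>& in, std::vector<float>& out)
    {
        Act act{};
        out.resize(in.size());
        for(std::size_t i = 0; i < in.size(); ++i) { out[i] = act(in[i]); }
    }

    void validate(int activation, const std::vector<float>& in, std::vector<float>& out)
    {
        // Runtime flag -> one of two template instantiations, mirroring
        // the if(activation == 0) { CPU_FUSED_MOE(Gelu); } else { ... } block above.
        if(activation == 0) { reference_moe<Gelu>(in, out); }
        else                { reference_moe<Silu>(in, out); }
    }

    int main()
    {
        std::vector<float> in{1.f, 2.f}, out;
        validate(1, in, out);
        std::cout << out[0] << '\n';
    }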
example/ck_tile/16_batched_gemm/batched_gemm.cpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

 #include <hip/hip_runtime.h>
...
@@ -51,7 +51,7 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stream_config& s)
                            ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
                            ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;

-    using TilePartitioner = ck_tile::GemmTilePartitioner<CodegenGemmShape>;
+    using TilePartitioner = ck_tile::GemmTile2DPartitioner<CodegenGemmShape>;

     using GemmEpilogue = std::conditional_t<CShuffleEpilogue,
...
@@ -63,8 +63,8 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stream_config& s)
                                             kOutputRank,
                                             1,
                                             0,
-                                            TilePartitioner::kM,
-                                            TilePartitioner::kN>>,
+                                            TilePartitioner::MPerBlock,
+                                            TilePartitioner::NPerBlock>>,
         ck_tile::Default2DEpilogue<
             ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>>;
...
example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

+template <typename Layout>
+static constexpr inline auto is_row_major(Layout layout_)
+{
+    return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
+                                                 ck_tile::tensor_layout::gemm::RowMajor>>{};
+}
+
+auto calculate_rtol_atol(const ck_tile::index_t K,
+                         const ck_tile::index_t kbatch,
+                         const float max_accumulated_value)
+{
+    using ComputeType =
+        std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
+    // Calculate thresholds
+    const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
+        ck_tile::integer_divide_ceil(K, kbatch));
+    const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
+        max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
+    // Calculate error due to split_k accumulation
+    const auto rtol_split_k =
+        ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
+    const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
+        max_accumulated_value, kbatch);
+    // Use higher threshold
+    return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
+}

 template <typename ALayout, typename BLayout, typename CLayout>
 float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
                           ck_tile::DeviceMem& b_k_n_dev_buf,
...
@@ -86,56 +113,16 @@ int run_batched_gemm_example_with_layouts(int argc,
     int n_warmup = arg_parser.get_int("warmup");
     int n_repeat = arg_parser.get_int("repeat");

-    using namespace ck_tile::literals;
-
-    auto f_host_tensor_descriptor = [](std::size_t batch_count_, std::size_t row, std::size_t col,
-                                       std::size_t stride, std::size_t batch_stride, auto layout) {
-        if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
-        {
-            return ck_tile::HostTensorDescriptor({batch_count_, row, col}, {batch_stride, stride, 1_uz});
-        }
-        else
-        {
-            return ck_tile::HostTensorDescriptor({batch_count_, row, col}, {batch_stride, 1_uz, stride});
-        }
-    };
-
-    auto f_get_default_stride = [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
-        if(stride == 0)
-        {
-            // give a chance if stride is zero, return a default packed stride
-            if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
-            {
-                return col;
-            }
-            else
-            {
-                return row;
-            }
-        }
-        else
-            return stride;
-    };
-
-    stride_A = f_get_default_stride(M, K, stride_A, a_layout);
-    stride_B = f_get_default_stride(K, N, stride_B, b_layout);
-    stride_C = f_get_default_stride(M, N, stride_C, c_layout);
-
-    ck_tile::HostTensor<ADataType> a_m_k(f_host_tensor_descriptor(batch_count, M, K, stride_A, batch_stride_A, a_layout));
-    ck_tile::HostTensor<BDataType> b_k_n(f_host_tensor_descriptor(batch_count, K, N, stride_B, batch_stride_B, b_layout));
-    ck_tile::HostTensor<CDataType> c_m_n_dev_result(f_host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, c_layout));
+    stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
+    stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
+    stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(c_layout));
+
+    ck_tile::HostTensor<ADataType> a_m_k(
+        ck_tile::host_tensor_descriptor(batch_count, M, K, stride_A, batch_stride_A, is_row_major(a_layout)));
+    ck_tile::HostTensor<BDataType> b_k_n(
+        ck_tile::host_tensor_descriptor(batch_count, K, N, stride_B, batch_stride_B, is_row_major(b_layout)));
+    ck_tile::HostTensor<CDataType> c_m_n_dev_result(
+        ck_tile::host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, is_row_major(c_layout)));

     ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
     ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);
...
@@ -171,23 +158,33 @@ int run_batched_gemm_example_with_layouts(int argc,
     if(arg_parser.get_int("v") == 1)
     {
-        ck_tile::HostTensor<CDataType> c_m_n_host_ref(
-            f_host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, CLayout{}));
+        ck_tile::HostTensor<CDataType> c_m_n_host_ref(ck_tile::host_tensor_descriptor(
+            batch_count, M, N, stride_C, batch_stride_C, is_row_major(CLayout{})));
         c_m_n_host_ref.SetZero();

         const auto b_n_k = b_k_n.transpose({0, 2, 1});
         ck_tile::reference_batched_gemm<ADataType, BDataType, AccDataType, CDataType>(
             a_m_k, b_n_k, c_m_n_host_ref);
-        pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_host_ref);
+        const float max_accumulated_value =
+            *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
+        const auto rtol_atol = calculate_rtol_atol(K, kbatch, max_accumulated_value);
+        pass = ck_tile::check_err(c_m_n_dev_result,
+                                  c_m_n_host_ref,
+                                  "Error: Incorrect results!",
+                                  rtol_atol.at(ck_tile::number<0>{}),
+                                  rtol_atol.at(ck_tile::number<1>{}));
+        std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
+                  << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
+                  << std::endl;
         std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl;
     }
     else if(arg_parser.get_int("v") == 2)
     {
-        ck_tile::HostTensor<CDataType> c_m_n_gpu_ref(
-            f_host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, CLayout{}));
+        ck_tile::HostTensor<CDataType> c_m_n_gpu_ref(ck_tile::host_tensor_descriptor(
+            batch_count, M, N, stride_C, batch_stride_C, is_row_major(CLayout{})));
         ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_gpu_ref.get_element_space_size_in_bytes());
         c_m_n_gpu_ref.SetZero();
         c_m_n_gpu_buf_ref.SetZero();
...
@@ -240,7 +237,18 @@ int run_batched_gemm_example_with_layouts(int argc,
         ck_tile::hip_check_error(hipFree(d_C));

         c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data());
-        pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_gpu_ref);
+        const float max_accumulated_value =
+            *std::max_element(c_m_n_gpu_ref.mData.begin(), c_m_n_gpu_ref.mData.end());
+        const auto rtol_atol = calculate_rtol_atol(K, kbatch, max_accumulated_value);
+        pass = ck_tile::check_err(c_m_n_dev_result,
+                                  c_m_n_gpu_ref,
+                                  "Error: Incorrect results!",
+                                  rtol_atol.at(ck_tile::number<0>{}),
+                                  rtol_atol.at(ck_tile::number<1>{}));
+        std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
+                  << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
+                  << std::endl;
         std::cout << "The GPU verification result is: " << (pass ? "correct" : "fail") << std::endl;
     }
...
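calculate_rtol_atol scales the accepted error with both the reduction length K/kbatch and the largest accumulated magnitude, then keeps the looser of the plain and split-K thresholds. A self-contained sketch of that policy with the formulas spelled out (illustrative epsilon handling, not the ck_tile threshold helpers):

    #include <algorithm>
    #include <cmath>
    #include <iostream>
    #include <limits>
    #include <utility>

    // Toy thresholds: relative error grows with the number of accumulated terms,
    // absolute error additionally grows with the magnitude being accumulated.
    float relative_threshold(int n_terms) { return n_terms * std::numeric_limits<float>::epsilon(); }
    float absolute_threshold(float max_value, int n_terms) { return std::abs(max_value) * relative_threshold(n_terms); }

    std::pair<float, float> calc_rtol_atol(int K, int kbatch, float max_accumulated_value)
    {
        const int  terms = (K + kbatch - 1) / kbatch;              // ceil(K / kbatch)
        const auto rtol  = relative_threshold(terms);
        const auto atol  = absolute_threshold(max_accumulated_value / kbatch, terms);
        // Extra slack for accumulating kbatch partial results in split-K mode.
        const auto rtol_split_k = relative_threshold(kbatch);
        const auto atol_split_k = absolute_threshold(max_accumulated_value, kbatch);
        return {std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)};
    }

    int main()
    {
        auto [rtol, atol] = calc_rtol_atol(4096, 4, 250.f);
        std::cout << "rtol=" << rtol << " atol=" << atol << '\n';
    }

Switching check_err from its defaults to these computed tolerances is what lets the fp16/bf16 examples validate long-K problems without spurious failures.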
example/ck_tile/17_grouped_gemm/grouped_gemm.cpp

@@ -15,7 +15,6 @@
 #include "ck_tile/ops/gemm.hpp"
 #include "ck_tile/host.hpp"
 #include "grouped_gemm.hpp"
-#include "utils.hpp"

 namespace {
...
@@ -102,7 +101,7 @@ using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner,
                                           GemmEpilogue<CLayout>>;
 }; // namespace

-std::size_t GetWorkspaceSize(const std::vector<grouped_gemm_kargs>& gemm_descs)
+std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs)
 {
     return ::Kernel<std::nullptr_t, std::nullptr_t, std::nullptr_t>::GetWorkSpaceSize(gemm_descs);
 }
...
example/ck_tile/17_grouped_gemm/grouped_gemm.hpp

@@ -52,8 +52,8 @@ auto create_args(int argc, char* argv[])
     return std::make_tuple(result, arg_parser);
 }

-std::size_t GetWorkspaceSize(const std::vector<grouped_gemm_kargs>& gemm_descs);
+std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs);

-float grouped_gemm_calc(const std::vector<grouped_gemm_kargs>& gemm_descs,
-                        const ck_tile::stream_config& s,
-                        void* p_workspace_);
+float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
+                   const ck_tile::stream_config& s,
+                   void* p_workspace_);
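The renamed get_workspace_size keeps the usual two-step pattern: ask the kernel how much scratch memory the group descriptors need, allocate it once, then hand the same buffer to every launch. A toy host-side sketch of that calling sequence (plain std::vector in place of device memory, hypothetical GemmDesc type):

    #include <cstddef>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    struct GemmDesc { int M, N, K; };

    // Stand-in for the kernel's workspace query; here one counter slot per group.
    std::size_t get_workspace_size(const std::vector<GemmDesc>& descs)
    {
        return descs.size() * sizeof(std::uint32_t);
    }

    int main()
    {
        std::vector<GemmDesc> descs{{128, 128, 64}, {256, 64, 32}};
        // 1) query, 2) allocate once, 3) pass the buffer pointer to every launch.
        std::vector<std::byte> workspace(get_workspace_size(descs));
        std::cout << "workspace bytes: " << workspace.size() << '\n';
    }

In the example itself the allocation goes through ck_tile::DeviceMem::Realloc(get_workspace_size(args)), as the next file shows.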
example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

+template <typename Layout>
+static constexpr inline auto is_row_major(Layout layout_)
+{
+    return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
+                                                 ck_tile::tensor_layout::gemm::RowMajor>>{};
+}
+
+auto calculate_rtol_atol(const ck_tile::index_t K,
+                         const ck_tile::index_t kbatch,
+                         const float max_accumulated_value)
+{
+    using ComputeType =
+        std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
+    // Calculate thresholds
+    const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
+        ck_tile::integer_divide_ceil(K, kbatch));
+    const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
+        max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
+    // Calculate error due to split_k accumulation
+    const auto rtol_split_k =
+        ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
+    const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
+        max_accumulated_value, kbatch);
+    // Use higher threshold
+    return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
+}

 template <typename ALayout, typename BLayout, typename CLayout>
 float invoke_gemm(int n_warmup,
                   int n_repeat,
...
@@ -11,7 +38,7 @@ float invoke_gemm(int n_warmup,
 {
     ck_tile::DeviceMem gemm_workspace;
-    gemm_workspace.Realloc(GetWorkspaceSize(args));
+    gemm_workspace.Realloc(get_workspace_size(args));

     float ave_time = grouped_gemm<ALayout, BLayout, CLayout>(args,
...
@@ -108,16 +135,19 @@ int run_grouped_gemm_example_with_layouts(int argc,
         const ck_tile::index_t N = Ns[i];
         const ck_tile::index_t K = Ks[i];

-        stride_As[i] = f_get_default_stride(M, N, stride_As[i], a_layout);
-        stride_Bs[i] = f_get_default_stride(K, N, stride_Bs[i], b_layout);
-        stride_Cs[i] = f_get_default_stride(M, N, stride_Cs[i], CLayout{});
-
-        a_m_k_tensors.push_back(ck_tile::HostTensor<ADataType>(
-            f_host_tensor_descriptor(M, K, stride_As[i], a_layout)));
-        b_k_n_tensors.push_back(ck_tile::HostTensor<BDataType>(
-            f_host_tensor_descriptor(K, N, stride_Bs[i], b_layout)));
+        stride_As[i] = ck_tile::get_default_stride(M, N, stride_As[i], is_row_major(a_layout));
+        stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], is_row_major(b_layout));
+        stride_Cs[i] = ck_tile::get_default_stride(M, N, stride_Cs[i], is_row_major(CLayout{}));
+
+        a_m_k_tensors.push_back(ck_tile::HostTensor<ADataType>(
+            ck_tile::host_tensor_descriptor(M, K, stride_As[i], is_row_major(a_layout))));
+        b_k_n_tensors.push_back(ck_tile::HostTensor<BDataType>(
+            ck_tile::host_tensor_descriptor(K, N, stride_Bs[i], is_row_major(b_layout))));
         c_m_n_tensors.push_back(ck_tile::HostTensor<CDataType>(
-            f_host_tensor_descriptor(M, N, stride_Cs[i], CLayout{})));
+            ck_tile::host_tensor_descriptor(M, N, stride_Cs[i], is_row_major(CLayout{}))));

         std::cout << "gemm[" << i << "]"
                   << " a_m_k: " << a_m_k_tensors[i].mDesc
                   << " b_k_n: " << b_k_n_tensors[i].mDesc
...
@@ -157,12 +187,23 @@ int run_grouped_gemm_example_with_layouts(int argc,
     {
         for(int i = 0; i < group_count; ++i)
         {
-            ck_tile::HostTensor<CDataType> c_m_n_host_ref(
-                f_host_tensor_descriptor(Ms[i], Ns[i], stride_Cs[i], CLayout{}));
+            ck_tile::HostTensor<CDataType> c_m_n_host_ref(ck_tile::host_tensor_descriptor(
+                Ms[i], Ns[i], stride_Cs[i], is_row_major(CLayout{})));
             c_m_n_host_ref.SetZero();

             ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
                 a_m_k_tensors[i], b_k_n_tensors[i], c_m_n_host_ref);

-            pass &= ck_tile::check_err(c_m_n_tensors[i], c_m_n_host_ref);
+            const float max_accumulated_value =
+                *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
+            const auto rtol_atol = calculate_rtol_atol(Ks[i], 1 /*kbatch*/, max_accumulated_value);
+            pass &= ck_tile::check_err(c_m_n_tensors[i],
+                                       c_m_n_host_ref,
+                                       "Error: Incorrect results!",
+                                       rtol_atol.at(ck_tile::number<0>{}),
+                                       rtol_atol.at(ck_tile::number<1>{}));
+            std::cout << "gemm[" << i << "] Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
+                      << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
+                      << std::endl;
         }
         std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl;
     }
...
example/ck_tile/17_grouped_gemm/utils.hpp (deleted, 100644 → 0)

-// SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
-
-#pragma once
-
-template <typename TLayout>
-constexpr auto f_host_tensor_descriptor(std::size_t row, std::size_t col, std::size_t stride, TLayout layout)
-{
-    using namespace ck_tile::literals;
-
-    if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
-    {
-        return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz});
-    }
-    else
-    {
-        return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride});
-    }
-}
-
-template <typename TLayout>
-constexpr auto f_get_default_stride(std::size_t row, std::size_t col, std::size_t stride, TLayout layout)
-{
-    if(stride == 0)
-    {
-        if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
-        {
-            return col;
-        }
-        else
-        {
-            return row;
-        }
-    }
-    else
-        return stride;
-}
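The deleted helpers only computed packed strides for row- versus column-major layouts; the examples now call the shared ck_tile::host_tensor_descriptor and ck_tile::get_default_stride utilities instead. The stride rule itself is simple, as this stand-alone sketch shows (toy function, not the ck_tile class):

    #include <array>
    #include <cstddef>
    #include <iostream>

    // For a row-major rows x cols matrix the packed strides are {cols, 1};
    // for column-major they are {1, rows}. stride == 0 means "use the packed default".
    std::array<std::size_t, 2> make_strides(std::size_t rows, std::size_t cols,
                                            std::size_t stride, bool row_major)
    {
        if(stride == 0) { stride = row_major ? cols : rows; }
        return row_major ? std::array<std::size_t, 2>{stride, 1}
                         : std::array<std::size_t, 2>{1, stride};
    }

    int main()
    {
        auto s = make_strides(4, 8, 0, /*row_major=*/true);
        std::cout << s[0] << ", " << s[1] << '\n'; // 8, 1
    }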
include/ck/ck.hpp

@@ -17,7 +17,9 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
 #endif
 // to do: add various levels of logging with CK_LOG_LEVEL

+#ifndef CK_TIME_KERNEL
 #define CK_TIME_KERNEL 1
+#endif

 // constant address space for kernel parameter
 // https://llvm.org/docs/AMDGPUUsage.html#address-spaces
...
@@ -155,6 +157,9 @@
 // LDS direct loads using inline assembly
 #define CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM 0

 // set rounding to nearest even as default for bf16 conversions
 #define CK_USE_RNE_BF16_CONVERSION 1

+// set rounding to nearest even as default for f8 conversions
+#define CK_USE_SR_F8_CONVERSION 0
...
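Wrapping CK_TIME_KERNEL in #ifndef lets a build system override the default from the compiler command line instead of editing the header. A minimal illustration of the pattern:

    // Compile with e.g. -DCK_TIME_KERNEL=0 to override the header default.
    #ifndef CK_TIME_KERNEL
    #define CK_TIME_KERNEL 1
    #endif

    #include <iostream>

    int main()
    {
    #if CK_TIME_KERNEL
        std::cout << "kernel timing enabled\n";
    #else
        std::cout << "kernel timing disabled\n";
    #endif
    }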
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
...
@@ -122,19 +122,6 @@ __global__ void
         static_for<0, NumDTensor, 1>{}(
             [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_group_offset[i]; });

-        if constexpr(is_same_v<AElementwiseOperation, element_wise::DynamicUnaryOp>)
-        {
-            a_element_op.InitUnaryOpPtrOnDevice();
-        }
-        if constexpr(is_same_v<BElementwiseOperation, element_wise::DynamicUnaryOp>)
-        {
-            b_element_op.InitUnaryOpPtrOnDevice();
-        }
-        if constexpr(is_same_v<CDEElementwiseOperation, element_wise::DynamicUnaryOp>)
-        {
-            cde_element_op.InitUnaryOpPtrOnDevice();
-        }

         if constexpr(isMultiA || isMultiB)
         {
             AsPointer p_as_grid_grp;
...
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp

@@ -247,32 +247,6 @@ struct DequantPack8
     constexpr const static bool is_pack8_invocable = true;
 };

-#pragma clang diagnostic push
-#pragma clang diagnostic ignored "-Wnon-virtual-dtor"
-struct UnaryOpBase
-{
-    public:
-    __host__ __device__ ~UnaryOpBase() = default;
-    __host__ __device__ constexpr UnaryOpBase() = default;
-    __host__ __device__ constexpr UnaryOpBase(const UnaryOpBase&) = default;
-    __host__ __device__ constexpr UnaryOpBase(UnaryOpBase&&) = default;
-    __host__ __device__ UnaryOpBase& operator=(const UnaryOpBase&) = default;
-    __host__ __device__ UnaryOpBase& operator=(UnaryOpBase&&) = default;
-
-    __host__ __device__ virtual inline void operator()(float& y, const float& x) const = 0;
-    __host__ __device__ virtual inline void operator()(double& y, const double& x) const = 0;
-    __host__ __device__ virtual inline void operator()(int32_t& y, const int32_t& x) const = 0;
-    __host__ __device__ virtual inline void operator()(int8_t& y, const int8_t& x) const = 0;
-    __host__ __device__ virtual inline void operator()(half_t& y, const half_t& x) const = 0;
-    __host__ __device__ virtual inline void operator()(bhalf_t& y, const bhalf_t& x) const = 0;
-};

 struct PassThroughPack2
 {
     template <typename Y, typename X>
...
@@ -304,27 +278,8 @@ struct PassThroughPack2
     constexpr const static bool is_pack2_invocable = true;
 };

-struct PassThrough final : public UnaryOpBase
+struct PassThrough
 {
-    __host__ __device__ constexpr PassThrough() = default;
-    __host__ __device__ constexpr PassThrough(const PassThrough&) = default;
-    __host__ __device__ constexpr PassThrough(PassThrough&&) = default;
-    __host__ __device__ PassThrough& operator=(const PassThrough&) = default;
-    __host__ __device__ PassThrough& operator=(PassThrough&&) = default;
-    __host__ __device__ ~PassThrough() = default;
-
-    __host__ __device__ inline void operator()(float& y, const float& x) const final { y = x; }
-    __host__ __device__ inline void operator()(double& y, const double& x) const final { y = x; }
-    __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final { y = x; }
-    __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final { y = x; }
-    __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final { y = x; }
-    __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final { y = x; }
-
     template <typename Y, typename X>
     __host__ __device__ void operator()(Y& y, const X& x) const;
...
@@ -334,6 +289,12 @@ struct PassThrough
     {
         y = x;
     }

+    template <>
+    __host__ __device__ void operator()<double, double>(double& y, const double& x) const
+    {
+        y = x;
+    }

     template <>
     __host__ __device__ void operator()<float, double>(float& y, const double& x) const
     {
...
@@ -346,12 +307,36 @@ struct PassThrough
         y = type_convert<double>(x);
     }

+    template <>
+    __host__ __device__ void operator()<float, float>(float& y, const float& x) const
+    {
+        y = x;
+    }
+
+    template <>
+    __host__ __device__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
+    {
+        y = x;
+    }

     template <>
     __host__ __device__ void operator()<half_t, float>(half_t& y, const float& x) const
     {
         y = type_convert<half_t>(x);
     }

+    template <>
+    __host__ __device__ void operator()<bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x) const
+    {
+        y = x;
+    }
+
+    template <>
+    __host__ __device__ void operator()<int32_t, int32_t>(int32_t& y, const int32_t& x) const
+    {
+        y = x;
+    }

     template <>
     __host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
     {
...
@@ -376,6 +361,12 @@ struct PassThrough
         y = type_convert<float>(x);
     }

+    template <>
+    __host__ __device__ void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
+    {
+        y = x;
+    }

     template <>
     __host__ __device__ void operator()<half_t, int8_t>(half_t& y, const int8_t& x) const
     {
...
@@ -675,45 +666,21 @@ struct UnarySquare
     };
 };

-struct UnaryAbs final : public UnaryOpBase
+struct UnaryAbs
 {
-    __host__ __device__ constexpr UnaryAbs() = default;
-    __host__ __device__ constexpr UnaryAbs(const UnaryAbs&) = default;
-    __host__ __device__ constexpr UnaryAbs(UnaryAbs&&) = default;
-    __host__ __device__ UnaryAbs& operator=(const UnaryAbs&) = default;
-    __host__ __device__ UnaryAbs& operator=(UnaryAbs&&) = default;
-    __host__ __device__ ~UnaryAbs() = default;
-
-    __host__ __device__ inline void operator()(float& y, const float& x) const final { y = ck::math::abs(x); }
-    __host__ __device__ inline void operator()(double& y, const double& x) const final { y = ck::math::abs(x); }
-    __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final { y = ck::math::abs(x); }
-    __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final { y = ck::math::abs(x); }
-    __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final { y = math::abs(x); }
-    __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final { y = ck::math::abs(x); }
-};
+    template <typename T>
+    __host__ __device__ void operator()(T& y, const T& x) const
+    {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
+                          is_same<T, int8_t>::value,
+                      "Data type is not supported by this operation!");
+
+        y = ck::math::abs(x);
+    }

     template <>
     __host__ __device__ void operator()(f8_t& y, const f8_t& x) const
     {
         y = ck::type_convert<f8_t>(ck::math::abs(ck::type_convert<float>(x)));
...
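After this refactor the element-wise ops are plain templated functors again: one operator() covers every supported arithmetic type, and a static_assert rejects the rest at compile time instead of relying on a virtual base class. A reduced, stand-alone version of the pattern (toy Abs functor, not ck::element_wise::UnaryAbs):

    #include <cstdint>
    #include <iostream>
    #include <type_traits>

    struct Abs
    {
        template <typename T>
        void operator()(T& y, const T& x) const
        {
            static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
                              std::is_same_v<T, std::int32_t> || std::is_same_v<T, std::int8_t>,
                          "Data type is not supported by this operation!");
            y = x < 0 ? -x : x;
        }
    };

    int main()
    {
        float        yf = 0.f;
        std::int32_t yi = 0, xi = -7;
        Abs{}(yf, -3.5f);
        Abs{}(yi, xi);
        std::cout << yf << ' ' << yi << '\n';
        // Calling Abs{} with an unsupported type fails at compile time via the static_assert.
    }

Compared with the removed UnaryOpBase design, this keeps the ops trivially copyable and free of vtables, which matters when the functor is passed by value into a GPU kernel.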
@@ -732,41 +699,20 @@ struct UnarySqrt
     };
 };

-struct Relu final : public UnaryOpBase
+struct Relu
 {
-    __host__ __device__ constexpr Relu() = default;
-    __host__ __device__ constexpr Relu(const Relu&) = default;
-    __host__ __device__ constexpr Relu(Relu&&) = default;
-    __host__ __device__ Relu& operator=(const Relu&) = default;
-    __host__ __device__ Relu& operator=(Relu&&) = default;
-    __host__ __device__ ~Relu() = default;
-
-    __host__ __device__ inline void operator()(float& y, const float& x) const final { y = x > 0 ? x : 0; }
-    __host__ __device__ inline void operator()(double& y, const double& x) const final { y = x > 0 ? x : 0; }
-    __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final { y = x > 0 ? x : 0; }
-    __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final { y = x > 0 ? x : 0; }
-    __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final { y = x > 0 ? x : 0; }
-    __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final
+    template <typename T>
+    __host__ __device__ void operator()(T& y, const T& x) const
+    {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
+                          is_same<T, int8_t>::value,
+                      "Data type is not supported by this operation!");
+        y = x > 0 ? x : 0;
+    }
+
+    template <>
+    __host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const
     {
         float x_f32 = type_convert<float>(x);
         float y_f32 = x_f32 > 0 ? x_f32 : 0;
...
@@ -913,52 +859,18 @@ struct Gelu
     }
 };

-struct Sigmoid final : public UnaryOpBase
+struct Sigmoid
 {
-    __host__ __device__ constexpr Sigmoid() = default;
-    __host__ __device__ constexpr Sigmoid(const Sigmoid&) = default;
-    __host__ __device__ constexpr Sigmoid(Sigmoid&&) = default;
-    __host__ __device__ Sigmoid& operator=(const Sigmoid&) = default;
-    __host__ __device__ Sigmoid& operator=(Sigmoid&&) = default;
-    __host__ __device__ ~Sigmoid() = default;
-
-    __host__ __device__ inline void operator()(float& y, const float& x) const final
-    {
-        constexpr float one = type_convert<float>(1);
-        y = one / (one + math::exp(-x));
-    }
-    ... (matching overrides for double, int32_t, int8_t and half_t) ...
-    __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final
-    {
-        constexpr float one = type_convert<float>(1);
-        float x_f32 = ck::type_convert<float>(x);
-        float y_f32 = one / (one + ck::math::exp(x_f32));
-        y = ck::type_convert<bhalf_t>(y_f32);
-    }
+    template <typename T>
+    __host__ __device__ void operator()(T& y, const T& x) const
+    {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
+                          is_same<T, int32_t>::value,
+                      "Data type is not supported by this operation!");
+
+        constexpr T one = type_convert<T>(1);
+        y = one / (one + math::exp(-x));
+    };
 };

 struct Silu
...
@@ -974,44 +886,18 @@ struct Silu
     };
 };

-struct TanH final : public UnaryOpBase
+struct TanH
 {
-    __host__ __device__ constexpr TanH() = default;
-    __host__ __device__ constexpr TanH(const TanH&) = default;
-    __host__ __device__ constexpr TanH(TanH&&) = default;
-    __host__ __device__ TanH& operator=(const TanH&) = default;
-    __host__ __device__ TanH& operator=(TanH&&) = default;
-    __host__ __device__ ~TanH() = default;
-
-    __host__ __device__ inline void operator()(float& y, const float& x) const final { y = math::tanh(x); }
-    __host__ __device__ inline void operator()(double& y, const double& x) const final { y = ck::math::tanh(x); }
-    __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final { y = ck::math::tanh(x); }
-    __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final { y = ck::math::tanh(x); }
-    __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final { y = ck::math::tanh(x); }
-    __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final { y = ck::math::tanh(x); }
+    template <typename T>
+    __host__ __device__ void operator()(T& y, const T& x) const
+    {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
+                          is_same<T, int32_t>::value,
+                      "Data type is not supported by this operation!");
+
+        y = math::tanh(x);
+    };
 };

 struct ACos
...
@@ -1252,418 +1138,138 @@ struct Rcp
     };
 };

-struct Swish final : public UnaryOpBase
+struct Swish
 {
-    __host__ __device__ constexpr Swish(const Swish&) = default;
-    __host__ __device__ constexpr Swish(Swish&&) = default;
-    __host__ __device__ ~Swish() = default;
-
-    __host__ __device__ Swish(float beta = 1.0f) : beta_(beta) {}
-
-    __host__ __device__ float get_beta() const { return beta_; }
-
-    const float beta_;
-
-    __host__ __device__ inline void operator()(float& y, const float& x) const final
-    {
-        float bx = -beta_ * type_convert<float>(x);
-        y = type_convert<float>(x / (1.f + ck::math::exp(bx)));
-    }
-    ... (matching overrides for double, int32_t, int8_t, half_t and bhalf_t, each via type_convert) ...
+    Swish(float beta = 1.0f) : beta_(beta) {}

     template <typename Y, typename X>
     __host__ __device__ void operator()(Y& y, const X& x) const
     {
-        static_assert(is_same<X, float>::value || is_same<X, double>::value ||
-                          is_same<X, half_t>::value,
+        static_assert(is_same<X, float>::value || is_same<X, double>::value ||
+                          is_same<X, ck::half_t>::value || is_same<X, int8_t>::value,
                       "Data type is not supported by this operation!");

-        static_assert(is_same<Y, float>::value || is_same<Y, double>::value ||
-                          is_same<Y, half_t>::value,
+        static_assert(is_same<Y, float>::value || is_same<Y, double>::value ||
+                          is_same<Y, ck::half_t>::value || is_same<Y, int8_t>::value,
                       "Data type is not supported by this operation!");

         float bx = -beta_ * type_convert<float>(x);
         y = type_convert<Y>(x / (1.f + math::exp(bx)));
     }

     const float beta_;
 };
-struct SoftRelu final : public UnaryOpBase
+struct SoftRelu
 {
-    __host__ __device__ constexpr SoftRelu(const SoftRelu&) = default;
-    __host__ __device__ constexpr SoftRelu(SoftRelu&&) = default;
-    __host__ __device__ ~SoftRelu() = default;
-
-    __host__ __device__ SoftRelu(float alpha = 1.0f) : alpha_(alpha) {}
-
-    __host__ __device__ float get_alpha() const { return alpha_; }
-
-    const float alpha_;
-
-    __host__ __device__ inline void operator()(float& y, const float& x) const final
-    {
-        float casted_alpha  = type_convert<float>(alpha_);
-        constexpr float one = type_convert<float>(1);
-        y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha;
-    }
-    ... (matching overrides for double, int32_t, int8_t, half_t and bhalf_t) ...
+    SoftRelu(float alpha = 1.f) : alpha_(alpha){};
+
+    template <typename T>
+    __host__ __device__ void operator()(T& y, const T& x) const
+    {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
+                          is_same<T, int8_t>::value,
+                      "Data type is not supported by this operation!");
+
+        T casted_alpha  = type_convert<T>(alpha_);
+        constexpr T one = type_convert<T>(1);
+        y = math::log(one + math::exp(x * casted_alpha)) / casted_alpha;
+    }
+    const float alpha_;
 };
-struct Power final : public UnaryOpBase
+struct Power
 {
-    __host__ __device__ constexpr Power(const Power&) = default;
-    __host__ __device__ constexpr Power(Power&&) = default;
-    __host__ __device__ ~Power() = default;
-
-    Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f)
-        : alpha_(alpha), beta_(beta), gamma_(gamma){};
+    __host__ __device__ Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f)
+        : alpha_(alpha), beta_(beta), gamma_(gamma)
+    {
+    }
+
+    template <typename T>
+    __host__ __device__ void operator()(T& y, const T& x) const
+    {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
+                          is_same<T, int8_t>::value,
+                      "Data type is not supported by this operation!");
+
+        T casted_alpha = type_convert<T>(alpha_);
+        T casted_beta  = type_convert<T>(beta_);
+        T casted_gamma = type_convert<T>(gamma_);
+
+        T shifted_scaled_x = casted_alpha + casted_beta * x;
+        y = math::pow(shifted_scaled_x, casted_gamma);
+    }

     __host__ __device__ float get_alpha() const { return alpha_; }
     __host__ __device__ float get_beta() const { return beta_; }
     __host__ __device__ float get_gamma() const { return gamma_; }

     const float alpha_;
     const float beta_;
     const float gamma_;

-    __host__ __device__ inline void operator()(float& y, const float& x) const final
-    {
-        float casted_alpha = type_convert<float>(alpha_);
-        float casted_beta  = type_convert<float>(beta_);
-        float casted_gamma = type_convert<float>(gamma_);
-        float shifted_scaled_x = casted_alpha + casted_beta * x;
-        y = ck::math::pow(shifted_scaled_x, casted_gamma);
-    }
-    ... (matching overrides for double, int32_t, int8_t, half_t and bhalf_t) ...
 };
-struct ClippedRelu final : public UnaryOpBase
+struct ClippedRelu
 {
-    __host__ __device__ constexpr ClippedRelu(const ClippedRelu&) = default;
-    __host__ __device__ constexpr ClippedRelu(ClippedRelu&&) = default;
-    __host__ __device__ ~ClippedRelu() = default;
-
-    ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta){};
+    __host__ __device__ ClippedRelu(float alpha = 0.f, float beta = 1.f)
+        : alpha_(alpha), beta_(beta)
+    {
+    }
+
+    template <typename T>
+    __host__ __device__ void operator()(T& y, const T& x) const
+    {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
+                          is_same<T, int8_t>::value,
+                      "Data type is not supported by this operation!");
+
+        T casted_alpha = type_convert<T>(alpha_);
+        T casted_beta  = type_convert<T>(beta_);
+        y = math::min(casted_beta, math::max(casted_alpha, x));
+    }

     __host__ __device__ float get_alpha() const { return alpha_; }
     __host__ __device__ float get_beta() const { return beta_; }

     const float alpha_;
     const float beta_;

-    __host__ __device__ inline void operator()(float& y, const float& x) const final
-    {
-        float casted_alpha = type_convert<float>(alpha_);
-        float casted_beta  = type_convert<float>(beta_);
-        y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x));
-    }
-    ... (matching overrides for double, int32_t, int8_t, half_t and bhalf_t) ...
 };
struct LeakyRelu
{
    // Diff context: the UnaryOpBase-style variant on the other side of this hunk uses a
    // default alpha of 0.f, a get_alpha() accessor and per-type `final` overrides; its
    // bhalf_t overload takes [[maybe_unused]] parameters and its body is not shown in this
    // hunk. The plain template form is reconstructed below.

    LeakyRelu(float alpha = 0.01f) : alpha_(alpha){};

    template <typename T>
    __host__ __device__ void operator()(T& y, const T& x) const
    {
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
                          is_same<T, int8_t>::value,
                      "Data type is not supported by this operation!");

        T casted_alpha = type_convert<T>(alpha_);
        y              = x >= 0 ? x : x * casted_alpha;
    }

    const float alpha_;
};
struct Elu
{
    // Diff context: the other side of this hunk is the UnaryOpBase-style variant with a
    // get_alpha() accessor and per-type `final` overrides, including a bhalf_t overload that
    // computes y = x > 0 ? x : alpha * expm1(x) directly in bhalf_t.

    Elu(float alpha = 1.f) : alpha_(alpha){};

    template <typename T>
    __host__ __device__ void operator()(T& y, const T& x) const
    {
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
                          is_same<T, int8_t>::value,
                      "Data type is not supported by this operation!");

        T casted_alpha = type_convert<T>(alpha_);
        y              = x > 0 ? x : casted_alpha * math::expm1(x);
    }

    const float alpha_;
};
struct Logistic
{
    // Diff context: the other side of this hunk is the UnaryOpBase-style variant with a
    // get_alpha() accessor (default alpha = 1.0f) and per-type `final` overrides.

    Logistic(float alpha = 1.f) : alpha_(alpha){};

    template <typename T>
    __host__ __device__ void operator()(T& y, const T& x) const
    {
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
                          is_same<T, int8_t>::value,
                      "Data type is not supported by this operation!");

        T casted_alpha  = type_convert<T>(alpha_);
        constexpr T one = type_convert<T>(1);
        y               = casted_alpha / (one + ck::math::exp(-x) * casted_alpha);
    }

    const float alpha_;
};
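These functors are plain callables that write the activation of `x` into `y`. Below is a minimal host-side sketch of how they are invoked; it is illustrative only and assumes the header above is included and compiled with a HIP-aware compiler so that `__host__ __device__` expands correctly.

#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"

namespace ew = ck::tensor_operation::element_wise;

void activation_examples()
{
    float y = 0.f;

    // LeakyRelu: y = x for x >= 0, alpha * x otherwise
    ew::LeakyRelu leaky{0.01f};
    leaky(y, -2.0f); // y = -2.0f * 0.01f
    leaky(y, 3.0f);  // non-negative inputs pass through: y = 3.0f

    // ClippedRelu: y = min(beta, max(alpha, x))
    ew::ClippedRelu clipped{/*alpha=*/0.f, /*beta=*/6.f};
    clipped(y, 7.5f); // clamped to the [alpha, beta] range: y = 6.f
}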
struct ConvInvscale
...
...
@@ -1728,7 +1334,7 @@ struct ConvScaleRelu
    __host__ __device__ void operator()<f8_t, float>(f8_t& e, const float& c) const
    {
        float x;
        // one side of this one-line change calls Relu{}(x, ...) directly, the other spells
        // the template call out explicitly:
        Relu{}.template operator()<float>(x, c * scale_in_ * scale_wei_);
        e = type_convert<f8_t>(x * scale_out_);
    };
...
...
@@ -1809,225 +1415,138 @@ struct FastNumericArrayConverter<uint8_t, half_t, N>
struct DynamicUnaryOp
{
    // Diff context: one side of this hunk implements DynamicUnaryOp on top of a
    // UnaryOpBase* unary_op_ptr_ plus float alpha/beta/gamma scalars -- a copy-assignment
    // operator that copies the pointer and type tag, constructors whose bodies assign
    // unary_op_type_ and the scalars via get_alpha()/get_beta()/get_gamma(), a destructor
    // and an InitUnaryOpPtrOnDevice() that delete/new the concrete op through a
    // switch(unary_op_type_), a __device__ operator() that dispatches through
    // unary_op_ptr_->operator()(y, x), and a __host__ operator() that switches on the type
    // tag and invokes default-constructed ops (Swish{}, Sigmoid{}, PassThrough{}, ...).
    // The other side, reconstructed below, stores the concrete op objects by value and
    // dispatches through a single __host__ __device__ switch.

    __host__ __device__ DynamicUnaryOp() = delete;

    __host__ __device__ DynamicUnaryOp(const Swish& swish)
        : unary_op_type_(UnaryOpType::Swish), swish_{swish.beta_}
    {
    }
    __host__ __device__ DynamicUnaryOp(const Swish&& swish)
        : unary_op_type_(UnaryOpType::Swish), swish_{swish.beta_}
    {
    }

    __host__ __device__ DynamicUnaryOp(const Sigmoid&) : unary_op_type_(UnaryOpType::Sigmoid) {}
    __host__ __device__ DynamicUnaryOp(const Sigmoid&&) : unary_op_type_(UnaryOpType::Sigmoid) {}

    __host__ __device__ DynamicUnaryOp(const PassThrough&)
        : unary_op_type_(UnaryOpType::PassThrough)
    {
    }
    __host__ __device__ DynamicUnaryOp(const PassThrough&&)
        : unary_op_type_(UnaryOpType::PassThrough)
    {
    }

    __host__ __device__ DynamicUnaryOp(const Logistic& logistic)
        : unary_op_type_(UnaryOpType::Logistic), logistic_{logistic.alpha_}
    {
    }
    __host__ __device__ DynamicUnaryOp(const Logistic&& logistic)
        : unary_op_type_(UnaryOpType::Logistic), logistic_{logistic.alpha_}
    {
    }

    __host__ __device__ DynamicUnaryOp(const TanH&) : unary_op_type_(UnaryOpType::TanH) {}
    __host__ __device__ DynamicUnaryOp(const TanH&&) : unary_op_type_(UnaryOpType::TanH) {}

    __host__ __device__ DynamicUnaryOp(const Relu&) : unary_op_type_(UnaryOpType::Relu) {}
    __host__ __device__ DynamicUnaryOp(const Relu&&) : unary_op_type_(UnaryOpType::Relu) {}

    __host__ __device__ DynamicUnaryOp(const SoftRelu& softrelu)
        : unary_op_type_(UnaryOpType::SoftRelu), soft_relu_{softrelu.alpha_}
    {
    }
    __host__ __device__ DynamicUnaryOp(const SoftRelu&& softrelu)
        : unary_op_type_(UnaryOpType::SoftRelu), soft_relu_{softrelu.alpha_}
    {
    }

    __host__ __device__ DynamicUnaryOp(const UnaryAbs&) : unary_op_type_(UnaryOpType::UnaryAbs) {}
    __host__ __device__ DynamicUnaryOp(const UnaryAbs&&) : unary_op_type_(UnaryOpType::UnaryAbs) {}

    __host__ __device__ DynamicUnaryOp(const Power& pow)
        : unary_op_type_(UnaryOpType::Power), power_(pow.alpha_, pow.beta_, pow.gamma_)
    {
    }
    __host__ __device__ DynamicUnaryOp(const Power&& pow)
        : unary_op_type_(UnaryOpType::Power), power_(pow.alpha_, pow.beta_, pow.gamma_)
    {
    }

    __host__ __device__ DynamicUnaryOp(const ClippedRelu& clippedrelu)
        : unary_op_type_(UnaryOpType::ClippedRelu),
          clipped_relu_{clippedrelu.alpha_, clippedrelu.beta_}
    {
    }
    __host__ __device__ DynamicUnaryOp(const ClippedRelu&& clippedrelu)
        : unary_op_type_(UnaryOpType::ClippedRelu),
          clipped_relu_{clippedrelu.alpha_, clippedrelu.beta_}
    {
    }

    __host__ __device__ DynamicUnaryOp(const LeakyRelu& leakyrelu)
        : unary_op_type_(UnaryOpType::LeakyRelu), leaky_relu_{leakyrelu.alpha_}
    {
    }
    __host__ __device__ DynamicUnaryOp(const LeakyRelu&& leakyrelu)
        : unary_op_type_(UnaryOpType::LeakyRelu), leaky_relu_{leakyrelu.alpha_}
    {
    }

    __host__ __device__ DynamicUnaryOp(const Elu& elu)
        : unary_op_type_(UnaryOpType::Elu), elu_{elu.alpha_}
    {
    }
    __host__ __device__ DynamicUnaryOp(const Elu&& elu)
        : unary_op_type_(UnaryOpType::Elu), elu_{elu.alpha_}
    {
    }

    __host__ __device__ DynamicUnaryOp(const DynamicUnaryOp& dynamic_op) = default;

    __host__ __device__ ~DynamicUnaryOp() {}

    template <typename Y, typename X>
    __host__ __device__ void operator()(Y& y, const X& x) const
    {
        isSupported<X, Y>();
        switch(unary_op_type_)
        {
        case(UnaryOpType::Swish): swish_(y, x); break;
        case(UnaryOpType::Sigmoid): sigmoid_(y, x); break;
        case(UnaryOpType::PassThrough): pass_through_(y, x); break;
        case(UnaryOpType::Logistic): logistic_(y, x); break;
        case(UnaryOpType::TanH): tanh_(y, x); break;
        case(UnaryOpType::Relu): relu_(y, x); break;
        case(UnaryOpType::SoftRelu): soft_relu_(y, x); break;
        case(UnaryOpType::UnaryAbs): unary_abs_(y, x); break;
        case(UnaryOpType::Power): power_(y, x); break;
        case(UnaryOpType::ClippedRelu): clipped_relu_(y, x); break;
        case(UnaryOpType::LeakyRelu): leaky_relu_(y, x); break;
        case(UnaryOpType::Elu): elu_(y, x); break;
        default: break;
        }
    }

    template <typename X, typename Y>
    __device__ __host__ constexpr void isSupported() const
    {
        static_assert(std::is_same<X, Y>::value, "X and Y must be of the same type");
        static_assert(is_same<X, float>::value || is_same<X, double>::value ||
                          is_same<X, bhalf_t>::value || is_same<X, half_t>::value ||
                          is_same<X, int32_t>::value || is_same<X, int8_t>::value,
                      "Data type is not supported by this operation!");
    }

    // bhalf_t is handled by converting through float and reusing the float path
    template <>
    __host__ __device__ void operator()<bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x) const
    {
        float y_float;
        float x_float = type_convert<float>(x);
        this->operator()(y_float, x_float);
        y = type_convert<bhalf_t>(y_float);
    }

    private:
    ...
    ...
@@ -2049,12 +1568,20 @@ struct DynamicUnaryOp
    public:
    UnaryOpType unary_op_type_;

    Swish swish_;
    Sigmoid sigmoid_;
    PassThrough pass_through_;
    Logistic logistic_;
    TanH tanh_;
    Relu relu_;
    SoftRelu soft_relu_;
    UnaryAbs unary_abs_;
    Power power_;
    ClippedRelu clipped_relu_;
    LeakyRelu leaky_relu_;
    Elu elu_;

    // (the other side of this hunk instead declares: UnaryOpBase* unary_op_ptr_ = nullptr;
    //  float alpha; float beta; float gamma;)
};
#pragma clang diagnostic pop
} // namespace element_wise
} // namespace tensor_operation
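For reference, here is a sketch of how DynamicUnaryOp is used as a type-erased wrapper around one of the functors above. This is not code from the diff; it assumes the value-member variant reconstructed above and host-side compilation.

#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"

namespace ew = ck::tensor_operation::element_wise;

void dynamic_unary_op_example()
{
    // The constructor records the op kind in unary_op_type_ and copies the parameters
    // into the corresponding member (here leaky_relu_).
    ew::DynamicUnaryOp op{ew::LeakyRelu{0.1f}};

    float y = 0.f;
    op(y, -4.0f); // switch(unary_op_type_) dispatches to leaky_relu_: y = -4.0f * 0.1f
    op(y, 2.0f);  // non-negative input passes through: y = 2.0f
}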
...
...
include/ck/utility/data_type.hpp
View file @ a11cf2c6
...
...
@@ -31,8 +31,6 @@ struct pk_i4_t
    type data;

    __host__ __device__ constexpr pk_i4_t() : data{type{}} {}
    __host__ __device__ constexpr pk_i4_t(type init) : data{init} {}

    __host__ __device__ constexpr operator float() const { return static_cast<int8_t>(data); }
};

inline constexpr auto next_pow2(uint32_t x)
...
...
include/ck/utility/dynamic_buffer.hpp
View file @ a11cf2c6
...
...
@@ -29,6 +29,13 @@ struct DynamicBuffer
    ElementSpaceSize element_space_size_;
    T invalid_element_value_ = T{0};

    static constexpr index_t PackedSize = []() {
        if constexpr(is_same_v<remove_cvref_t<T>, pk_i4_t>)
            return 2;
        else
            return 1;
    }();

    __host__ __device__ constexpr DynamicBuffer(T* p_data, ElementSpaceSize element_space_size)
        : p_data_{p_data}, element_space_size_{element_space_size}
    {
...
...
@@ -82,14 +89,18 @@ struct DynamicBuffer
            return amd_buffer_load_invalid_element_return_zero<remove_cvref_t<T>,
                                                               t_per_x,
                                                               coherence>(
                p_data_, i, is_valid_element, element_space_size_ / PackedSize);
            // (was: element_space_size_)
        }
        else
        {
            return amd_buffer_load_invalid_element_return_customized_value<remove_cvref_t<T>,
                                                                            t_per_x,
                                                                            coherence>(
                p_data_,
                i,
                is_valid_element,
                element_space_size_ / PackedSize,
                invalid_element_value_);
            // (was: element_space_size_)
        }
    }
    else
...
...
@@ -191,7 +202,7 @@ struct DynamicBuffer
            dst_buf.p_data_, dst_offset, is_valid_element, element_space_size_ / PackedSize);
        // (was: element_space_size_)
    }

    template <typename X,
...
...
@@ -226,7 +237,7 @@ struct DynamicBuffer
        constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

        amd_buffer_store<remove_cvref_t<T>, t_per_x, coherence>(
            x, p_data_, i, is_valid_element, element_space_size_ / PackedSize);
        // (was: element_space_size_)
    }
    else if constexpr(GetAddressSpace() == AddressSpaceEnum::Lds &&
                      is_same<typename scalar_type<remove_cvref_t<T>>::type, int8_t>::value &&
...
...
@@ -378,7 +389,7 @@ struct DynamicBuffer
        constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

        amd_buffer_atomic_add<remove_cvref_t<T>, t_per_x>(
            x, p_data_, i, is_valid_element, element_space_size_ / PackedSize);
        // (was: element_space_size_)
    }
    else
    {
...
...
@@ -417,7 +428,7 @@ struct DynamicBuffer
        constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

        amd_buffer_atomic_max<remove_cvref_t<T>, t_per_x>(
            x, p_data_, i, is_valid_element, element_space_size_ / PackedSize);
        // (was: element_space_size_)
    }
    else if(is_valid_element)
    {
...
...
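The PackedSize adjustment above exists because pk_i4_t packs two 4-bit values into one stored element, so the element space passed to the AMD buffer intrinsics has to be halved for that type. Below is a minimal stand-alone sketch of the same arithmetic; the names are illustrative and not taken from the repository.

#include <cstdint>
#include <type_traits>

struct pk_i4_t_like { std::int8_t data; }; // stand-in for ck::pk_i4_t (two int4 per byte)

template <typename T>
inline constexpr int packed_size = std::is_same_v<T, pk_i4_t_like> ? 2 : 1;

template <typename T>
constexpr int storage_units(int logical_elements)
{
    // mirrors the element_space_size_ / PackedSize adjustment in DynamicBuffer
    return logical_elements / packed_size<T>;
}

static_assert(storage_units<float>(256) == 256);
static_assert(storage_units<pk_i4_t_like>(256) == 128);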
include/ck/utility/type_convert.hpp
View file @ a11cf2c6
...
...
@@ -14,6 +14,41 @@ namespace ck {
#define __gfx94__
#endif
// Declare a template function for bf16 conversion using RTN
template <typename Y, typename X>
__host__ __device__ constexpr Y bf16_convert_rtn(X x);

// Convert fp32 to bf16 with RTN if higher precision is needed
template <>
inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, float>(float x)
{
    // NaN check
    if(x != x)
    {
        return uint16_t(0x7FC0);
    }

    union
    {
        float fp32;
        uint32_t int32;
    } u = {x};

    const uint32_t first_bf16_mantisa_bit = ((u.int32 >> 16) & 1);
    constexpr uint32_t rounding_bias      = uint32_t((1 << 15) - 1);

    return uint16_t((u.int32 + first_bf16_mantisa_bit + rounding_bias) >> 16);
}

// convert fp16 to bfp16 via fp32 with RTN if higher precision is needed
template <>
inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(half_t x)
{
    float x_fp32 = static_cast<float>(x);

    return bf16_convert_rtn<bhalf_t>(x_fp32);
}
// Convert X to Y, both X and Y are non-const data types.
template <typename Y,
          typename X,
...
...
@@ -51,17 +86,15 @@ inline __host__ __device__ constexpr float type_convert<float, bhalf_t>(bhalf_t
    return u.fp32;
}
// convert fp32 to bfp16, round to nearest even
template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float x)
{
    union
    {
        float fp32;
        uint32_t int32;
    } u = {x};

#if CK_USE_RNE_BF16_CONVERSION
    return bf16_convert_rtn<bhalf_t>(x);
#else
    return uint16_t(u.int32 >> 16);
#endif
}
// convert bfp16 to fp16 via fp32
...
...
@@ -635,60 +668,4 @@ inline __host__ __device__ void array_convert(Array<Y, NumElems>& y, const Array
}
}
// Declare a template function for bf16 conversion using RTN
template <typename Y, typename X>
__host__ __device__ constexpr Y bf16_convert_rtn(X x);

// Convert fp32 to bf16 with RTN if higher precision is needed
template <>
inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, float>(float x)
{
    union
    {
        float fp32;
        uint32_t int32;
    } u = {x};

    // When the exponent bits are not all 1s, then the value is zero, normal,
    // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
    // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
    // This causes the bfloat16's mantissa to be incremented by 1 if the 16
    // least significant bits of the float mantissa are greater than 0x8000,
    // or if they are equal to 0x8000 and the least significant bit of the
    // bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
    // the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
    // has the value 0x7f, then incrementing it causes it to become 0x00 and
    // the exponent is incremented by one, which is the next higher FP value
    // to the unrounded bfloat16 value. When the bfloat16 value is subnormal
    // with an exponent of 0x00 and a mantissa of 0x7f, it may be rounded up
    // to a normal value with an exponent of 0x01 and a mantissa of 0x00.
    // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
    // incrementing it causes it to become an exponent of 0xFF and a mantissa
    // of 0x00, which is Inf, the next higher value to the unrounded value.
    bool flag0 = ~u.int32 & 0x7f800000;

    // When all of the exponent bits are 1, the value is Inf or NaN.
    // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
    // mantissa bit. Quiet NaN is indicated by the most significant mantissa
    // bit being 1. Signaling NaN is indicated by the most significant
    // mantissa bit being 0 but some other bit(s) being 1. If any of the
    // lower 16 bits of the mantissa are 1, we set the least significant bit
    // of the bfloat16 mantissa, in order to preserve signaling NaN in case
    // the bfloat16's mantissa bits are all 0.
    bool flag1 = !flag0 && (u.int32 & 0xffff);

    u.int32 += flag0 ? 0x7fff + ((u.int32 >> 16) & 1) : 0; // Round to nearest, round to even
    u.int32 |= flag1 ? 0x10000 : 0x0;                      // Preserve signaling NaN

    return uint16_t(u.int32 >> 16);
}

// convert fp16 to bfp16 via fp32 with RTN if higher precision is needed
template <>
inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(half_t x)
{
    float x_fp32 = static_cast<float>(x);

    return bf16_convert_rtn<bhalf_t>(x_fp32);
}
} // namespace ck
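The rounding-bias form of bf16_convert_rtn above implements round-to-nearest-even: adding 0x7FFF plus the current least-significant bit of the would-be bf16 mantissa rounds ties toward the even mantissa. The following small host-only check of the tie-breaking behaviour is illustrative and not part of the repository.

#include <cassert>
#include <cstdint>

// Round an fp32 bit pattern to bf16 bits the same way the rounding-bias formula does.
inline std::uint16_t rne_bf16_bits(std::uint32_t fp32_bits)
{
    const std::uint32_t first_bf16_mantisa_bit = (fp32_bits >> 16) & 1u;
    const std::uint32_t rounding_bias          = (1u << 15) - 1u; // 0x7FFF
    return std::uint16_t((fp32_bits + first_bf16_mantisa_bit + rounding_bias) >> 16);
}

int main()
{
    // 0x3F808000 lies exactly halfway between bf16 0x3F80 (1.0) and 0x3F81; the tie
    // breaks toward the even mantissa, so the result stays at 0x3F80.
    assert(rne_bf16_bits(0x3F808000u) == 0x3F80);

    // 0x3F818000 lies halfway between 0x3F81 and 0x3F82; the even neighbour is 0x3F82.
    assert(rne_bf16_bits(0x3F818000u) == 0x3F82);
    return 0;
}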
include/ck_tile/core.hpp
View file @ a11cf2c6
...
...
@@ -54,7 +54,6 @@
#include "ck_tile/core/tensor/tile_window_linear.hpp"
#include "ck_tile/core/tensor/tile_window_utils.hpp"
#include "ck_tile/core/tensor/update_tile.hpp"
#include "ck_tile/core/utility/amd_address_space.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/functional_with_tuple.hpp"
...
...