gaoqiong / composable_kernel_ROCM · Commits

Commit 9e063018
authored Oct 30, 2024 by carlushuang

    dynamic-quant ready

parent e2935465

Showing 16 changed files with 579 additions and 112 deletions (+579 / -112)
example/ck_tile/02_layernorm2d/generate.py                                        (+81  / -55)
example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp                                (+117 / -14)
example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp                                (+10  / -7)
example/ck_tile/02_layernorm2d/script/smoke_test.sh                               (+27  / -25)
include/ck_tile/core.hpp                                                          (+1   / -0)
include/ck_tile/core/numeric/int8.hpp                                             (+105 / -0)
include/ck_tile/core/numeric/type_convert.hpp                                     (+4   / -0)
include/ck_tile/host/check_err.hpp                                                (+55  / -0)
include/ck_tile/host/reference/reference_layernorm2d_fwd.hpp                      (+32  / -5)
include/ck_tile/ops/epilogue.hpp                                                  (+1   / -0)
include/ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp                           (+95  / -0)
include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp                 (+36  / -1)
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp    (+8   / -1)
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp     (+2   / -0)
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp    (+3   / -0)
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp               (+2   / -4)
example/ck_tile/02_layernorm2d/generate.py

@@ -30,8 +30,7 @@ FUSED_ADD_ENUM_STR_MAP = [
-FUSED_FUSED_SWEEP_STR_MAP = ['no', 'renorm', 'dequant']
+FUSED_FUSED_SWEEP_STR_MAP = ['no', 'dquant']
 DATA_TYPE_MAP = {'fp16' : 'ck_tile::fp16_t',
                  'bf16' : 'ck_tile::bf16_t',

@@ -48,6 +47,7 @@ class layernorm_fwd_codegen:
 // this is used to pattern-match internal kernel implementation, not to instantiate kernel
 template <typename XDataType_,
           typename YDataType_,
+          typename YScaleDataType_,
           ck_tile::index_t Repeat_M_,         // each thread repeat along M
           ck_tile::index_t Repeat_N_,         // each thread repeat along N
           ck_tile::index_t ThreadPerBlock_M_, // num threads along M

@@ -62,6 +62,7 @@ struct layernorm2d_fwd_traits_
 {
     using XDataType = ck_tile::remove_cvref_t<XDataType_>;
     using YDataType = ck_tile::remove_cvref_t<YDataType_>;
+    using YScaleDataType = ck_tile::remove_cvref_t<YScaleDataType_>;
     static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize;
     static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0);

@@ -121,6 +122,7 @@ struct layernorm2d_fwd_traits_
 template <typename XDataType_,
           typename YDataType_,
+          typename YScaleDataType_,
           ck_tile::index_t Repeat_M_,         // each thread repeat along M
           ck_tile::index_t Repeat_N_,         // each thread repeat along N
           ck_tile::index_t ThreadPerBlock_M_, // num threads along M

@@ -133,6 +135,7 @@ template <typename XDataType_,
           int kFusedSweep_>
 using traits_ = layernorm2d_fwd_traits_<XDataType_,
                                         YDataType_,
+                                        YScaleDataType_,
                                         Repeat_M_,
                                         Repeat_N_,
                                         ThreadPerBlock_M_,

@@ -165,7 +168,8 @@ float layernorm2d_fwd_(const S& s, A a)
 {{
     using XDataType = typename Traits_::XDataType;
     using YDataType = typename Traits_::YDataType;
-    using ComputeDataType = typename LayerNormTypeConfig<XDataType, YDataType>::ComputeDataType;
+    using YScaleDataType  = typename Traits_::YScaleDataType;
+    using ComputeDataType = typename LayerNormTypeConfig<XDataType, YDataType, YScaleDataType>::ComputeDataType;
     using PipelineTraits = ck_tile::Layernorm2dFwdTraits<Traits_::kPadN,
                                                          Traits_::kSaveMeanInvStd,

@@ -173,13 +177,14 @@ float layernorm2d_fwd_(const S& s, A a)
                                                          static_cast<ck_tile::Layernorm2dFusedAddEnum>(Traits_::kFusedAdd),
                                                          static_cast<ck_tile::Layernorm2dFusedSweepEnum>(Traits_::kFusedSweep)>;
     using PipelineProblem = ck_tile::Layernorm2dFwdPipelineProblem<
-        typename LayerNormTypeConfig<XDataType, YDataType>::XDataType,
-        typename LayerNormTypeConfig<XDataType, YDataType>::GammaDataType,
-        typename LayerNormTypeConfig<XDataType, YDataType>::BetaDataType,
-        typename LayerNormTypeConfig<XDataType, YDataType>::ComputeDataType,
-        typename LayerNormTypeConfig<XDataType, YDataType>::YDataType,
-        typename LayerNormTypeConfig<XDataType, YDataType>::MeanDataType,
-        typename LayerNormTypeConfig<XDataType, YDataType>::InvStdDataType,
+        typename LayerNormTypeConfig<XDataType, YDataType, YScaleDataType>::XDataType,
+        typename LayerNormTypeConfig<XDataType, YDataType, YScaleDataType>::GammaDataType,
+        typename LayerNormTypeConfig<XDataType, YDataType, YScaleDataType>::BetaDataType,
+        typename LayerNormTypeConfig<XDataType, YDataType, YScaleDataType>::ComputeDataType,
+        typename LayerNormTypeConfig<XDataType, YDataType, YScaleDataType>::YDataType,
+        typename LayerNormTypeConfig<XDataType, YDataType, YScaleDataType>::MeanDataType,
+        typename LayerNormTypeConfig<XDataType, YDataType, YScaleDataType>::InvStdDataType,
+        typename LayerNormTypeConfig<XDataType, YDataType, YScaleDataType>::YScaleDataType,
         typename Traits_::Shape,
         PipelineTraits>;

@@ -190,7 +195,12 @@ float layernorm2d_fwd_(const S& s, A a)
     using Default2DEpilogueProblem = ck_tile::Default2DEpilogueProblem<ComputeDataType, YDataType, false, Traits_::kPadN, false>;
     using Default2DEpilogue        = ck_tile::Default2DEpilogue<Default2DEpilogueProblem>;
-    using Epilogue = Default2DEpilogue;
+    using DynamicQuantEpilogueProblem = ck_tile::DynamicQuantEpilogueProblem<ComputeDataType, YScaleDataType, YDataType,
+        ck_tile::DynamicQuantEpilogueTraits<false, Traits_::kPadN, false, true/*max3*/>>;
+    using DynamicQuantEpilogue = ck_tile::DynamicQuantEpilogue<DynamicQuantEpilogueProblem>;
+
+    using Epilogue = std::conditional_t<Traits_::kFusedSweep == 1, DynamicQuantEpilogue, Default2DEpilogue>;
     using Kernel = ck_tile::Layernorm2dFwd<Pipeline, Epilogue>;

@@ -247,7 +257,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
 #include "layernorm2d_fwd_api_common.hpp"
 // clang-format off
-//                        prec_i prec_o        rm rn tm tn vn pd mv 2p add sweep
+//                        prec_i prec_o prec_s rm rn tm tn vn pd mv 2p add sweep
 {F_instance_def}
 // clang-format on

@@ -325,6 +335,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
     class h_traits:
         F_XDataType : str
         F_YDataType : str
+        F_YScaleDataType : str
         F_Repeat_M : int
         F_Repeat_N : int
         F_ThreadPerBlock_M : int

@@ -338,7 +349,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
         @property
         def trait_name(self) -> str:
-            t_ = f'{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}'
+            t_ = f'{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {DATA_TYPE_MAP[self.F_YScaleDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}'
            t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveMeanInvStd_):5}'
            t_ += f', {BOOL_MAP(self.F_kTwoPass_):5}, {self.F_kFusedAdd:4}, {self.F_kFusedSweep:4}'
            return t_

@@ -424,7 +435,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
                     inner_str += self.API_INNER_CASE.format(F_if=get_if_str(idx_in_n, len_in_n, False),
                                                             F_VEC_COND=_cond, F_instance_func=ins.call_name)
                     #inner_str = inner_str + vec_str
-                n_cnd = f'(a.n <= {n_})' if n_ != 'big' else ''
+                n_cnd = f'(a.n <= {n_})' if (i_n < len(blob_per_t) - 1) else ''
                 n_str += self.API_PER_N_CASE.format(F_if=get_if_str(i_n, len(blob_per_t)), F_N_COND=n_cnd, F_inner_dispatch=inner_str)
             prec_i, prec_o = dtype_.split(',')
             d_str += self.API_PER_DTYPE.format(F_if=get_if_str(i_d, len(t_dtype_dict), False), F_i_type=prec_i, F_o_type=prec_o, F_per_n_case=n_str)

@@ -440,61 +451,76 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
     h_traits   = layernorm_fwd_codegen.h_traits
     h_instance = layernorm_fwd_codegen.h_instance

+    dynamic_quant_out_dtype = ['int8']
     # some predefined support range
     # (prec_i,prec_o) for simplicity this string will be used as key for dict
-    dtype_list = [('fp16,fp16'), ('bf16,bf16')]
+    dtype_list = [('fp16,fp16'), ('bf16,bf16'), ('fp16,int8'), ('bf16,int8')]   # NOTE: only fused-dynamic-quant use int8 out
     fused_add_list = [0, 1, 2]
-    fused_sweep_list = [0]
-    #                                rm  rn tm   tn  vn  pd    mv     2p   add sweep
-    h_trait_dict = {
-        '64'   : [h_traits('x', 'y', 1,  1, 4,   64, 1, True, False, False, 0, 0)],
-        '128'  : [h_traits('x', 'y', 1,  1, 4,   64, 2, True, False, False, 0, 0), h_traits('x', 'y', 1,  2, 4,   64, 1, True, False, False, 0, 0)],
-        '256'  : [h_traits('x', 'y', 1,  1, 4,   64, 4, True, False, False, 0, 0), h_traits('x', 'y', 1,  2, 4,   64, 2, True, False, False, 0, 0),
-                  h_traits('x', 'y', 1,  4, 4,   64, 1, True, False, False, 0, 0)],
-        '512'  : [h_traits('x', 'y', 1,  1, 4,   64, 8, True, False, False, 0, 0), h_traits('x', 'y', 1,  2, 4,   64, 4, True, False, False, 0, 0),
-                  h_traits('x', 'y', 1,  4, 4,   64, 2, True, False, False, 0, 0), h_traits('x', 'y', 1,  8, 4,   64, 1, True, False, False, 0, 0)],
-        '768'  : [h_traits('x', 'y', 1,  3, 4,   64, 4, True, False, False, 0, 0), h_traits('x', 'y', 1,  6, 4,   64, 2, True, False, False, 0, 0),
-                  h_traits('x', 'y', 1, 12, 4,   64, 1, True, False, False, 0, 0)],
-        '1024' : [h_traits('x', 'y', 1,  1, 2,  128, 8, True, False, False, 0, 0), h_traits('x', 'y', 1,  2, 2,  128, 4, True, False, False, 0, 0),
-                  h_traits('x', 'y', 1,  4, 2,  128, 2, True, False, False, 0, 0), h_traits('x', 'y', 1,  4, 1,  256, 1, True, False, False, 0, 0)],
-        '1536' : [h_traits('x', 'y', 1,  3, 4,   64, 8, True, False, False, 0, 0), h_traits('x', 'y', 1,  3, 2,  128, 4, True, False, False, 0, 0),
-                  h_traits('x', 'y', 1,  3, 1,  256, 2, True, False, False, 0, 0), h_traits('x', 'y', 1,  6, 1,  256, 1, True, False, False, 0, 0)],
-        '2048' : [h_traits('x', 'y', 1,  1, 1,  256, 8, True, False, False, 0, 0), h_traits('x', 'y', 1,  2, 1,  256, 4, True, False, False, 0, 0),
-                  h_traits('x', 'y', 1,  4, 1,  256, 2, True, False, False, 0, 0), h_traits('x', 'y', 1,  8, 1,  256, 1, True, False, False, 0, 0)],
-        '3072' : [h_traits('x', 'y', 1,  3, 1,  128, 8, True, False, False, 0, 0), h_traits('x', 'y', 1,  3, 1,  256, 4, True, False, False, 0, 0),
-                  h_traits('x', 'y', 1,  6, 1,  256, 2, True, False, False, 0, 0), h_traits('x', 'y', 1,  3, 1, 1024, 1, True, False, False, 0, 0)],
-        '4096' : [h_traits('x', 'y', 1,  1, 1,  512, 8, True, False, False, 0, 0), h_traits('x', 'y', 1,  4, 1,  256, 4, True, False, False, 0, 0),
-                  h_traits('x', 'y', 1,  2, 1, 1024, 2, True, False, False, 0, 0), h_traits('x', 'y', 1,  4, 1, 1024, 1, True, False, False, 0, 0)],
-        'big'  : [h_traits('x', 'y', 1,  2, 1,  256, 8, True, False, True,  0, 0), h_traits('x', 'y', 1,  4, 1,  256, 4, True, False, True,  0, 0),
-                  h_traits('x', 'y', 1,  2, 1, 1024, 2, True, False, True,  0, 0), h_traits('x', 'y', 1,  4, 1, 1024, 1, True, False, True,  0, 0)]}
+    fused_sweep_list = [0, 1]   # NOTE: only single pass can use fused dynamic quant
+    #                                     rm  rn tm   tn  vn  pd    mv     2p   add sweep
+    h_trait_dict = {
+        '64'   : [h_traits('x', 'y', 's', 1,  1, 4,   64, 1, True, False, False, 0, 0)],
+        '128'  : [h_traits('x', 'y', 's', 1,  1, 4,   64, 2, True, False, False, 0, 0), h_traits('x', 'y', 's', 1,  2, 4,   64, 1, True, False, False, 0, 0)],
+        '256'  : [h_traits('x', 'y', 's', 1,  1, 4,   64, 4, True, False, False, 0, 0), h_traits('x', 'y', 's', 1,  2, 4,   64, 2, True, False, False, 0, 0),
+                  h_traits('x', 'y', 's', 1,  4, 4,   64, 1, True, False, False, 0, 0)],
+        '512'  : [h_traits('x', 'y', 's', 1,  1, 4,   64, 8, True, False, False, 0, 0), h_traits('x', 'y', 's', 1,  2, 4,   64, 4, True, False, False, 0, 0),
+                  h_traits('x', 'y', 's', 1,  4, 4,   64, 2, True, False, False, 0, 0), h_traits('x', 'y', 's', 1,  8, 4,   64, 1, True, False, False, 0, 0)],
+        '768'  : [h_traits('x', 'y', 's', 1,  3, 4,   64, 4, True, False, False, 0, 0), h_traits('x', 'y', 's', 1,  6, 4,   64, 2, True, False, False, 0, 0),
+                  h_traits('x', 'y', 's', 1, 12, 4,   64, 1, True, False, False, 0, 0)],
+        '1024' : [h_traits('x', 'y', 's', 1,  1, 2,  128, 8, True, False, False, 0, 0), h_traits('x', 'y', 's', 1,  2, 2,  128, 4, True, False, False, 0, 0),
+                  h_traits('x', 'y', 's', 1,  4, 2,  128, 2, True, False, False, 0, 0), h_traits('x', 'y', 's', 1,  4, 1,  256, 1, True, False, False, 0, 0)],
+        '1536' : [h_traits('x', 'y', 's', 1,  3, 4,   64, 8, True, False, False, 0, 0), h_traits('x', 'y', 's', 1,  3, 2,  128, 4, True, False, False, 0, 0),
+                  h_traits('x', 'y', 's', 1,  3, 1,  256, 2, True, False, False, 0, 0), h_traits('x', 'y', 's', 1,  6, 1,  256, 1, True, False, False, 0, 0)],
+        '2048' : [h_traits('x', 'y', 's', 1,  1, 1,  256, 8, True, False, False, 0, 0), h_traits('x', 'y', 's', 1,  2, 1,  256, 4, True, False, False, 0, 0),
+                  h_traits('x', 'y', 's', 1,  4, 1,  256, 2, True, False, False, 0, 0), h_traits('x', 'y', 's', 1,  8, 1,  256, 1, True, False, False, 0, 0)],
+        '3072' : [h_traits('x', 'y', 's', 1,  3, 1,  128, 8, True, False, False, 0, 0), h_traits('x', 'y', 's', 1,  3, 1,  256, 4, True, False, False, 0, 0),
+                  h_traits('x', 'y', 's', 1,  6, 1,  256, 2, True, False, False, 0, 0), h_traits('x', 'y', 's', 1,  3, 1, 1024, 1, True, False, False, 0, 0)],
+        '4096' : [h_traits('x', 'y', 's', 1,  1, 1,  512, 8, True, False, False, 0, 0), h_traits('x', 'y', 's', 1,  4, 1,  256, 4, True, False, False, 0, 0),
+                  h_traits('x', 'y', 's', 1,  2, 1, 1024, 2, True, False, False, 0, 0), h_traits('x', 'y', 's', 1,  4, 1, 1024, 1, True, False, False, 0, 0)],
+        '6144' : [h_traits('x', 'y', 's', 1,  3, 1,  256, 8, True, False, False, 0, 0), h_traits('x', 'y', 's', 1,  3, 1,  512, 4, True, False, False, 0, 0),
+                  h_traits('x', 'y', 's', 1,  3, 1, 1024, 2, True, False, False, 0, 0), h_traits('x', 'y', 's', 1,  6, 1, 1024, 1, True, False, False, 0, 0)],
+        '8192' : [h_traits('x', 'y', 's', 1,  4, 1,  256, 8, True, False, False, 0, 0), h_traits('x', 'y', 's', 1,  4, 1,  512, 4, True, False, False, 0, 0),
+                  h_traits('x', 'y', 's', 1,  4, 1, 1024, 2, True, False, False, 0, 0), h_traits('x', 'y', 's', 1,  8, 1, 1024, 1, True, False, False, 0, 0)],
+        'big'  : [h_traits('x', 'y', 's', 1,  2, 1,  256, 8, True, False, True,  0, 0), h_traits('x', 'y', 's', 1,  4, 1,  256, 4, True, False, True,  0, 0),
+                  h_traits('x', 'y', 's', 1,  2, 1, 1024, 2, True, False, True,  0, 0), h_traits('x', 'y', 's', 1,  4, 1, 1024, 1, True, False, True,  0, 0)]}

     total_blob = list()
     for hs_key in h_trait_dict:
         hs        = h_trait_dict[hs_key]
         current_n = hs[0].F_Repeat_N * hs[0].F_ThreadPerBlock_N * hs[0].F_Vector_N
         for dtype, fused_add, fused_sweep in itertools.product(dtype_list, fused_add_list, fused_sweep_list):
             prec_i, prec_o = dtype.split(',')
+            if prec_o in dynamic_quant_out_dtype and fused_sweep != 1:
+                continue    # skip non dynamic quant case
+            if fused_sweep == 1 and hs_key == 'big':
+                continue
             current_hs = list()
             for chs_ in hs:
                 h_ = copy.copy(chs_)    # copy the base instance out
                 h_.F_XDataType      = prec_i
                 h_.F_YDataType      = prec_o
+                h_.F_YScaleDataType = prec_i
                 h_.F_kFusedAdd      = fused_add
                 h_.F_kFusedSweep    = fused_sweep
                 current_hs.append(h_)
                 # + "\n"
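Note: the generator now threads a third data type (the per-row quant scale) through every instance and swaps the epilogue when kFusedSweep == 1. The standalone sketch below only mirrors that selection idea with stub types so it compiles on its own; Default2DEpilogueStub / DynamicQuantEpilogueStub are stand-ins, not the ck_tile classes.

// Minimal sketch of the "Epilogue = conditional_t<kFusedSweep == 1, ...>" selection
// the generated code performs. All names here are illustrative stand-ins.
#include <type_traits>
#include <iostream>

struct Default2DEpilogueStub    { static constexpr const char* name = "default_2d"; };
struct DynamicQuantEpilogueStub { static constexpr const char* name = "dynamic_quant"; };

template <int kFusedSweep>
using SelectedEpilogue =
    std::conditional_t<kFusedSweep == 1, DynamicQuantEpilogueStub, Default2DEpilogueStub>;

int main()
{
    // fsweep=0 -> plain 2D store epilogue, fsweep=1 -> per-row dynamic-quant epilogue
    std::cout << SelectedEpilogue<0>::name << "\n"; // prints "default_2d"
    std::cout << SelectedEpilogue<1>::name << "\n"; // prints "dynamic_quant"
}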
example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp

@@ -32,8 +32,11 @@ auto create_args(int argc, char* argv[])
         .insert("kname", "1", "print kernel name or not")
         .insert("prec_i", "fp16", "input precision")
         .insert("prec_o", "auto", "output precision, set auto will be the same as input")
+        .insert("prec_s", "auto", "output quant scale type, set auto will be the same as input. used when fsweep=1")
         .insert("fadd", "0", "fused-add, 0:no fused add, 1:preadd+store, 2:preadd only")
-        .insert("fsweep", "0", "fused-sweep")
+        .insert("fsweep", "0", "fused-sweep, 0:no, 1:fused-dynamic-quant")
         .insert("warmup", "5", "cold iter")
         .insert("repeat", "20", "hot iter");

@@ -41,7 +44,7 @@ auto create_args(int argc, char* argv[])
     return std::make_tuple(result, arg_parser);
 }

-template <typename InDataType, typename OutDataType, bool SaveMeanVar>
+template <typename InDataType, typename OutDataType, typename ScaleDataType, bool SaveMeanVar>
 bool run(const ck_tile::ArgParser& arg_parser)
 {
     ck_tile::index_t m = arg_parser.get_int("m");

@@ -52,27 +55,38 @@ bool run(const ck_tile::ArgParser& arg_parser)
     float epsilon      = arg_parser.get_float("e");
     std::string prec_i = arg_parser.get_str("prec_i");
     std::string prec_o = arg_parser.get_str("prec_o");
+    std::string prec_s = arg_parser.get_str("prec_s");
     if(prec_o == "auto")
     {
         prec_o = prec_i;
     }
+    if(prec_s == "auto")
+    {
+        prec_s = prec_i;
+    }
     int kname         = arg_parser.get_int("kname");
     int do_validation = arg_parser.get_int("v");
     int warmup        = arg_parser.get_int("warmup");
     int repeat        = arg_parser.get_int("repeat");
     int fused_add     = arg_parser.get_int("fadd");
     int fused_sweep   = arg_parser.get_int("fsweep");
+    if(fused_sweep == 1 && prec_o != "int8")
+    {
+        std::cout << "if fused_sweep is 1, only support \"-prec_o=int8\" case" << std::endl;
+        return false;
+    }

     assert(stride >= n);

-    using TypeConfig = LayerNormTypeConfig<InDataType, OutDataType>;
+    using TypeConfig = LayerNormTypeConfig<InDataType, OutDataType, ScaleDataType>;

     using XDataType     = typename TypeConfig::XDataType;
     using YDataType     = typename TypeConfig::YDataType;
     using GammaDataType = typename TypeConfig::GammaDataType;
     using BetaDataType  = typename TypeConfig::BetaDataType;

     using SXDataType = XDataType;
-    using SYDataType = YDataType;
+    using SYDataType = XDataType;

     using MeanDataType = std::conditional_t<SaveMeanVar, typename TypeConfig::MeanDataType, ck_tile::null_type>;

@@ -94,6 +108,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
     ck_tile::HostTensor<MeanDataType> mean_host_ref({m});
     ck_tile::HostTensor<InvStdDataType> invStd_host_ref({m});

+    ck_tile::HostTensor<ScaleDataType> y_scale_host_ref({m});
+    ck_tile::HostTensor<ScaleDataType> y_scale_host_dev({m});

     ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
     ck_tile::FillUniformDistribution<GammaDataType>{-.5f, .5f}(gamma_host);

@@ -103,6 +119,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
     ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes());
     ck_tile::DeviceMem beta_buf(beta_host.get_element_space_size_in_bytes());
     ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem y_scale_buf(y_scale_host_dev.get_element_space_size_in_bytes());
     ck_tile::DeviceMem sx_buf(sx_host.get_element_space_size_in_bytes());
     ck_tile::DeviceMem sy_buf(sy_host.get_element_space_size_in_bytes());

@@ -112,10 +129,23 @@ bool run(const ck_tile::ArgParser& arg_parser)
     beta_buf.ToDevice(beta_host.data());
     sx_buf.ToDevice(sx_host.data());

-    std::cout << "[" << prec_i << "]"
+    auto prec_str = [&]() {
+        auto base_str = prec_i;
+        if(prec_i != prec_o)
+        {
+            base_str += "|" + prec_o;
+        }
+        if(fused_sweep == 1)
+        {
+            base_str += std::string("(") + prec_s + ")";
+        }
+        return base_str;
+    }();
+
+    std::cout << "[" << prec_str << "]"
               << " m:" << m << ", n:" << n << ", stride:" << stride << std::flush;

-    layernorm2d_fwd_traits traits{prec_i, prec_o, SaveMeanVar, fused_add, fused_sweep};
+    layernorm2d_fwd_traits traits{prec_i, prec_o, prec_s, SaveMeanVar, fused_add, fused_sweep};

     layernorm2d_fwd_args args{x_buf.GetDeviceBuffer(),
                               fused_add != 0 ? sx_buf.GetDeviceBuffer() : nullptr,

@@ -125,6 +155,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
                               fused_add == 1 ? sy_buf.GetDeviceBuffer() : nullptr,
                               nullptr,
                               nullptr,
+                              fused_sweep == 1 ? y_scale_buf.GetDeviceBuffer() : nullptr,
                               epsilon,
                               m,
                               n,

@@ -170,6 +201,50 @@ bool run(const ck_tile::ArgParser& arg_parser)
-                                           InvStdDataType>(
-            x_host, gamma_host, beta_host, y_host_ref, mean_host_ref, invStd_host_ref, epsilon);
+        if(fused_sweep == 1)
+        {
+            auto dquant_functor = [&](int m_, auto o_, auto acc_) {
+                int N_                 = acc_.mDesc.get_lengths()[1];
+                ComputeDataType absmax = 0;
+                for(int n_ = 0; n_ < N_; n_++)
+                {
+                    const auto a = abs(acc_(m_, n_));
+                    absmax       = a > absmax ? a : absmax;
+                }
+                y_scale_host_ref(m_) = absmax / 127.0;
+                for(int n_ = 0; n_ < N_; n_++)
+                {
+                    o_(m_, n_) = static_cast<YDataType>(acc_(m_, n_) / y_scale_host_ref(m_));
+                }
+            };
+
+            ck_tile::reference_layernorm2d_fwd<XDataType,
+                                               GammaDataType,
+                                               BetaDataType,
+                                               ComputeDataType,
+                                               YDataType,
+                                               MeanDataType,
+                                               InvStdDataType>(x_host,
+                                                               gamma_host,
+                                                               beta_host,
+                                                               y_host_ref,
+                                                               mean_host_ref,
+                                                               invStd_host_ref,
+                                                               epsilon,
+                                                               dquant_functor);
+        }
+        else
+        {
+            ck_tile::reference_layernorm2d_fwd<XDataType,
+                                               GammaDataType,
+                                               BetaDataType,
+                                               ComputeDataType,
+                                               YDataType,
+                                               MeanDataType,
+                                               InvStdDataType>(
+                x_host, gamma_host, beta_host, y_host_ref, mean_host_ref, invStd_host_ref, epsilon);
+        }

         y_buf.FromDevice(y_host_dev.data());
         ck_tile::HostTensor<SYDataType> sy_host_dev({m, n}, {stride, 1});

@@ -179,6 +254,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
         }

         auto [rtol, atol] = get_elimit<InDataType>();

         if(stride == n)
         {
             pass = ck_tile::check_err(

@@ -218,6 +294,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
                 }
             }
         }
+        if(fused_sweep == 1)
+        {
+            y_scale_buf.FromDevice(y_scale_host_dev.data());
+            pass &= ck_tile::check_err(y_scale_host_dev,
+                                       y_scale_host_ref,
+                                       std::string("SCALE Error: Incorrect results!"),
+                                       rtol,
+                                       atol);
+        }

         std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
     }

@@ -233,26 +318,44 @@ int main(int argc, char* argv[])
     std::string prec_i = arg_parser.get_str("prec_i");
     std::string prec_o = arg_parser.get_str("prec_o");
+    std::string prec_s = arg_parser.get_str("prec_s");
     if(prec_o == "auto")
     {
         prec_o = prec_i;
     }
+    if(prec_s == "auto")
+    {
+        prec_s = prec_i;
+    }
     int save_mv = arg_parser.get_int("save_mv");

-    if(prec_i == "fp16" && prec_o == "fp16" && save_mv)
-    {
-        return run<ck_tile::half_t, ck_tile::half_t, true>(arg_parser) ? 0 : -2;
-    }
-    else if(prec_i == "fp16" && prec_o == "fp16" && !save_mv)
-    {
-        return run<ck_tile::half_t, ck_tile::half_t, false>(arg_parser) ? 0 : -2;
-    }
-    else if(prec_i == "bf16" && prec_o == "bf16" && save_mv)
-    {
-        return run<ck_tile::bf16_t, ck_tile::bf16_t, true>(arg_parser) ? 0 : -2;
-    }
-    else if(prec_i == "bf16" && prec_o == "bf16" && !save_mv)
-    {
-        return run<ck_tile::bf16_t, ck_tile::bf16_t, true>(arg_parser) ? 0 : -2;
-    }
+    // no dynamic quant case
+    if(prec_i == "fp16" && prec_o == "fp16" && prec_s == "fp16" && save_mv)
+    {
+        return run<ck_tile::half_t, ck_tile::half_t, ck_tile::half_t, true>(arg_parser) ? 0 : -2;
+    }
+    else if(prec_i == "fp16" && prec_o == "fp16" && prec_s == "fp16" && !save_mv)
+    {
+        return run<ck_tile::half_t, ck_tile::half_t, ck_tile::half_t, false>(arg_parser) ? 0 : -2;
+    }
+    else if(prec_i == "bf16" && prec_o == "bf16" && prec_s == "bf16" && save_mv)
+    {
+        return run<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, true>(arg_parser) ? 0 : -2;
+    }
+    else if(prec_i == "bf16" && prec_o == "bf16" && prec_s == "bf16" && !save_mv)
+    {
+        return run<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, true>(arg_parser) ? 0 : -2;
+    }
+    // dynamic quant case, only in inference
+    else if(prec_i == "fp16" && prec_o == "int8" && prec_s == "fp16" && !save_mv)
+    {
+        return run<ck_tile::half_t, ck_tile::int8_t, ck_tile::half_t, false>(arg_parser) ? 0 : -2;
+    }
+    else if(prec_i == "bf16" && prec_o == "int8" && prec_s == "bf16" && !save_mv)
+    {
+        return run<ck_tile::bf16_t, ck_tile::int8_t, ck_tile::bf16_t, false>(arg_parser) ? 0 : -2;
+    }

     return -3;
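Note: the host-side dquant_functor above is the whole dynamic-quant contract in miniature: per row, take the absolute maximum of the fp32 layernorm output, store scale = absmax / 127, and write each element as element / scale cast to int8. The standalone sketch below reproduces that arithmetic on a plain std::vector (not ck_tile::HostTensor) just to show the numbers.

// Per-row dynamic quantization as done by the reference dquant_functor, on plain vectors.
#include <cstdint>
#include <cmath>
#include <cstddef>
#include <vector>
#include <iostream>

int main()
{
    std::vector<float> row = {0.12f, -3.5f, 1.75f, 0.9f}; // one fp32 layernorm output row

    float absmax = 0.f;
    for(float v : row)
        absmax = std::max(absmax, std::fabs(v));

    float scale = absmax / 127.f; // per-row y_scale that the kernel stores out

    std::vector<std::int8_t> q(row.size());
    for(std::size_t i = 0; i < row.size(); ++i)
        q[i] = static_cast<std::int8_t>(row[i] / scale); // -3.5 -> -127, 1.75 -> 63 (truncated)

    std::cout << "scale=" << scale << ", q[1]=" << int(q[1]) << "\n";
}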
example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp

@@ -8,11 +8,11 @@
 #include "ck_tile/ops/layernorm2d.hpp"
 #include <string>

-template <typename InType, typename OutType>
+template <typename InType, typename OutType, typename YScaleDataType_>
 struct LayerNormTypeConfig;

-template <typename OutType>
-struct LayerNormTypeConfig<ck_tile::half_t, OutType>
+template <typename OutType, typename YScaleDataType_>
+struct LayerNormTypeConfig<ck_tile::half_t, OutType, YScaleDataType_>
 {
     using XDataType = ck_tile::half_t;
     using YDataType = OutType;

@@ -21,10 +21,11 @@ struct LayerNormTypeConfig<ck_tile::half_t, OutType>
     using MeanDataType    = ck_tile::half_t;
     using InvStdDataType  = ck_tile::half_t;
     using ComputeDataType = float;
+    using YScaleDataType  = YScaleDataType_;
 };

-template <typename OutType>
-struct LayerNormTypeConfig<ck_tile::bf16_t, OutType>
+template <typename OutType, typename YScaleDataType_>
+struct LayerNormTypeConfig<ck_tile::bf16_t, OutType, YScaleDataType_>
 {
     using XDataType = ck_tile::bf16_t;
     using YDataType = OutType;

@@ -33,6 +34,7 @@ struct LayerNormTypeConfig<ck_tile::bf16_t, OutType>
     using MeanDataType    = ck_tile::bf16_t;
     using InvStdDataType  = ck_tile::bf16_t;
     using ComputeDataType = float;
+    using YScaleDataType  = YScaleDataType_;
 };

 // runtime args

@@ -48,9 +50,10 @@ struct layernorm2d_fwd_traits
 {
     std::string prec_i;
     std::string prec_o;
+    std::string prec_s; // scale value, used as scale factor store out when fused_sweep=1

     bool save_mean_var;
-    int fused_add;   // 0:no-add, 1:pre-add, 2:post-add
-    int fused_sweep; // 0:no-sweep,
+    int fused_add;   // 0:no-add, 1:pre-add-store, 2:pre-add
+    int fused_sweep; // 0:no-sweep, 1:dynamic-quant
 };

 float layernorm2d_fwd(layernorm2d_fwd_traits, layernorm2d_fwd_args, const ck_tile::stream_config&);
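Note: with the new prec_s field, the fused dynamic-quant path is requested by the combination prec_o=int8, fused_sweep=1, plus a scale type. The sketch below only mirrors the struct's fields locally so it compiles standalone; it is an illustration of how the traits would be filled, not the library's dispatch code.

// How the extended runtime traits would look for fp16 in, int8 out, fp16 per-row scale.
#include <string>
#include <iostream>

struct layernorm2d_fwd_traits_mirror // mirrors the fields of layernorm2d_fwd_traits above
{
    std::string prec_i;
    std::string prec_o;
    std::string prec_s; // only meaningful when fused_sweep == 1
    bool save_mean_var;
    int fused_add;   // 0:no-add, 1:pre-add-store, 2:pre-add
    int fused_sweep; // 0:no-sweep, 1:dynamic-quant
};

int main()
{
    layernorm2d_fwd_traits_mirror t{"fp16", "int8", "fp16", /*save_mean_var=*/false,
                                    /*fused_add=*/0, /*fused_sweep=*/1};
    std::cout << t.prec_i << " -> " << t.prec_o << " (scale " << t.prec_s << ")\n";
}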
example/ck_tile/02_layernorm2d/script/smoke_test.sh

@@ -2,32 +2,34 @@
 # call from top of CK folder
 EXE=./build/bin/tile_example_layernorm2d_fwd

+for fsweep in "" "-fsweep=1 -prec_o=int8" ; do
 for pr_i in "fp16" "bf16" ; do
 for fadd in "0" "1" "2" ; do
-$EXE -prec_i=$pr_i -fadd=$fadd -m=99 -n=13
-$EXE -prec_i=$pr_i -fadd=$fadd -m=17 -n=16
-$EXE -prec_i=$pr_i -fadd=$fadd -m=1 -n=100
-$EXE -prec_i=$pr_i -fadd=$fadd -m=4 -n=128
-$EXE -prec_i=$pr_i -fadd=$fadd -m=80 -n=127
-$EXE -prec_i=$pr_i -fadd=$fadd -m=22 -n=255 -stride=256
-$EXE -prec_i=$pr_i -fadd=$fadd -m=7 -n=599
-$EXE -prec_i=$pr_i -fadd=$fadd -m=19 -n=512
-$EXE -prec_i=$pr_i -fadd=$fadd -m=33 -n=313 -stride=1000
-$EXE -prec_i=$pr_i -fadd=$fadd -m=11 -n=510
-$EXE -prec_i=$pr_i -fadd=$fadd -m=171 -n=676 -stride=818
-$EXE -prec_i=$pr_i -fadd=$fadd -m=91 -n=636
-$EXE -prec_i=$pr_i -fadd=$fadd -m=12 -n=768 -stride=800
-$EXE -prec_i=$pr_i -fadd=$fadd -m=100 -n=766 -stride=812
-$EXE -prec_i=$pr_i -fadd=$fadd -m=31 -n=1024
-$EXE -prec_i=$pr_i -fadd=$fadd -m=64 -n=1000 -stride=1004
-$EXE -prec_i=$pr_i -fadd=$fadd -m=8 -n=1501
-$EXE -prec_i=$pr_i -fadd=$fadd -m=3 -n=1826
-$EXE -prec_i=$pr_i -fadd=$fadd -m=5 -n=2040
-$EXE -prec_i=$pr_i -fadd=$fadd -m=7 -n=2734
-$EXE -prec_i=$pr_i -fadd=$fadd -m=1 -n=3182
-$EXE -prec_i=$pr_i -fadd=$fadd -m=9 -n=4096
-$EXE -prec_i=$pr_i -fadd=$fadd -m=3 -n=8192
-$EXE -prec_i=$pr_i -fadd=$fadd -m=1 -n=10547
-$EXE -prec_i=$pr_i -fadd=$fadd -m=3 -n=17134
+$EXE -prec_i=$pr_i -fadd=$fadd $fsweep -m=99 -n=13
+$EXE -prec_i=$pr_i -fadd=$fadd $fsweep -m=17 -n=16
+$EXE -prec_i=$pr_i -fadd=$fadd $fsweep -m=1 -n=100
+$EXE -prec_i=$pr_i -fadd=$fadd $fsweep -m=4 -n=128
+$EXE -prec_i=$pr_i -fadd=$fadd $fsweep -m=80 -n=127
+$EXE -prec_i=$pr_i -fadd=$fadd $fsweep -m=22 -n=255 -stride=256
+$EXE -prec_i=$pr_i -fadd=$fadd $fsweep -m=7 -n=599
+$EXE -prec_i=$pr_i -fadd=$fadd $fsweep -m=19 -n=512
+$EXE -prec_i=$pr_i -fadd=$fadd $fsweep -m=33 -n=313 -stride=1000
+$EXE -prec_i=$pr_i -fadd=$fadd $fsweep -m=11 -n=510
+$EXE -prec_i=$pr_i -fadd=$fadd $fsweep -m=171 -n=676 -stride=818
+$EXE -prec_i=$pr_i -fadd=$fadd $fsweep -m=91 -n=636
+$EXE -prec_i=$pr_i -fadd=$fadd $fsweep -m=12 -n=768 -stride=800
+$EXE -prec_i=$pr_i -fadd=$fadd $fsweep -m=100 -n=766 -stride=812
+$EXE -prec_i=$pr_i -fadd=$fadd $fsweep -m=31 -n=1024
+$EXE -prec_i=$pr_i -fadd=$fadd $fsweep -m=64 -n=1000 -stride=1004
+$EXE -prec_i=$pr_i -fadd=$fadd $fsweep -m=8 -n=1501
+$EXE -prec_i=$pr_i -fadd=$fadd $fsweep -m=3 -n=1826
+$EXE -prec_i=$pr_i -fadd=$fadd $fsweep -m=5 -n=2040
+$EXE -prec_i=$pr_i -fadd=$fadd $fsweep -m=7 -n=2734
+$EXE -prec_i=$pr_i -fadd=$fadd $fsweep -m=1 -n=3182
+$EXE -prec_i=$pr_i -fadd=$fadd $fsweep -m=9 -n=4096
+$EXE -prec_i=$pr_i -fadd=$fadd $fsweep -m=3 -n=8192
+#$EXE -prec_i=$pr_i -fadd=$fadd $fsweep -m=1 -n=10547
+#$EXE -prec_i=$pr_i -fadd=$fadd $fsweep -m=3 -n=17134
 done
 done
+done
include/ck_tile/core.hpp

@@ -24,6 +24,7 @@
 #include "ck_tile/core/numeric/bfloat16.hpp"
 #include "ck_tile/core/numeric/float8.hpp"
 #include "ck_tile/core/numeric/half.hpp"
+#include "ck_tile/core/numeric/int8.hpp"
 #include "ck_tile/core/numeric/integer.hpp"
 #include "ck_tile/core/numeric/integral_constant.hpp"
 #include "ck_tile/core/numeric/math.hpp"
include/ck_tile/core/numeric/int8.hpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include "ck_tile/core/utility/random.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include <stdint.h>
#include <type_traits>

#pragma once

namespace ck_tile {

// use int8_t directly for int8 arithmetic
// here one can use ck_tile::int8_t to access original int8_t
using int8_t = int8_t;

// limits
template <class T>
struct numeric;

template <>
struct numeric<int8_t>
{
    // minimum finite value, or minimum positive normalized value for float
    CK_TILE_HOST_DEVICE static constexpr int8_t min() { return int8_t(-128); }

    // minimum finite value
    CK_TILE_HOST_DEVICE static constexpr int8_t lowest() { return int8_t(-128); }

    // maximum finite value
    CK_TILE_HOST_DEVICE static constexpr int8_t max() { return int8_t(127); }

    // difference between 1.0 and next value representable by float
    CK_TILE_HOST_DEVICE static constexpr int8_t epsilon() { return 1; } // not used

    CK_TILE_HOST_DEVICE static constexpr int8_t round_error() { return 1; } // not used

    // positive infinity value
    CK_TILE_HOST_DEVICE static constexpr int8_t infinity() { return 1; } // not used

    // quiet NaN
    CK_TILE_HOST_DEVICE static constexpr int8_t quiet_NaN() { return 1; } // not used

    // signaling NaN
    CK_TILE_HOST_DEVICE static constexpr int8_t signaling_NaN() { return 1; } // not used

    // smallest positive subnormal value
    CK_TILE_HOST_DEVICE static constexpr int8_t denorm_min() { return 1; } // not used

    CK_TILE_HOST_DEVICE static constexpr int8_t zero() { return 0; }
};

#if 0
template <typename T>
struct numeric_traits;

template <>
struct numeric_traits<int8_t>
{
    static constexpr int exp            = 5;
    static constexpr int mant           = 10;
    static constexpr int bias           = 15;
    static constexpr uint16_t nan_mask  = 0x7C00;
    static constexpr uint16_t head_mask = 0xFC00;
    static constexpr uint16_t mant_mask = 0x3FF;
    static constexpr uint16_t exp_mask  = 0x1F;
    static constexpr uint32_t Inf       = 0x7C00;
    static constexpr uint32_t NegInf    = 0xFC00;
    static constexpr uint32_t NaN       = 0x7C01;
    static constexpr uint32_t Neg0      = 0x8000;
    using bitwise_type = uint16_t;
};
#endif

CK_TILE_HOST_DEVICE
constexpr float int8_to_float(const int8_t& x) { return static_cast<float>(x); }

CK_TILE_HOST_DEVICE
constexpr int8_t float_to_int8(const float& x) { return static_cast<int8_t>(x); }

} // namespace ck_tile
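Note: the only values this trait really contributes to the dynamic-quant path are the int8 bounds; the epilogue divides the row absmax by numeric<ODataType>::max(), i.e. 127. The standalone sketch below uses std::numeric_limits as a stand-in for the ck_tile trait, just to show what those limits feed into.

// Illustration of what numeric<int8_t>::lowest()/max() are used for (stand-in via <limits>).
#include <cstdint>
#include <limits>
#include <iostream>

int main()
{
    constexpr int8_t lo = std::numeric_limits<int8_t>::lowest(); // -128, like numeric<int8_t>::lowest()
    constexpr int8_t hi = std::numeric_limits<int8_t>::max();    //  127, like numeric<int8_t>::max()

    float row_absmax = 2.54f;
    float y_scale    = row_absmax / static_cast<float>(hi); // 0.02 -- the per-row scale written out

    std::cout << int(lo) << " " << int(hi) << " scale=" << y_scale << "\n";
}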
include/ck_tile/core/numeric/type_convert.hpp

@@ -10,6 +10,7 @@
 #include "ck_tile/core/numeric/half.hpp"
 #include "ck_tile/core/numeric/bfloat16.hpp"
 #include "ck_tile/core/numeric/float8.hpp"
+#include "ck_tile/core/numeric/int8.hpp"

 namespace ck_tile {

@@ -60,6 +61,9 @@ CK_TILE_TYPE_CONVERT(bf16_t, bf16, float, float)
 CK_TILE_TYPE_CONVERT(fp8_t, fp8, float, float)
 CK_TILE_TYPE_CONVERT(bf8_t, bf8, float, float)

+CK_TILE_TYPE_CONVERT(float, float, int8_t, int8)
+CK_TILE_TYPE_CONVERT(int8_t, int8, float, float)
+
 #undef CK_TILE_TYPE_CONVERT
 #endif
include/ck_tile/host/check_err.hpp

@@ -396,4 +396,59 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<...
     return res;
 }

+#if 0
+// TODO: Note, int8 validation is risky, need more check
+template <typename Range, typename RefRange>
+typename std::enable_if<
+    std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
+        std::is_same_v<ranges::range_value_t<Range>, int8_t>,
+    bool>::type CK_TILE_HOST
+check_err(const Range& out,
+          const RefRange& ref,
+          const std::string& msg = "Error[int8]: Incorrect results!",
+          double rtol            = 1e-3,
+          double atol            = 1e-3,
+          bool /*allow_infinity_ref*/ = false)
+{
+    if(out.size() != ref.size())
+    {
+        std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
+                  << std::endl;
+        return false;
+    }
+
+    bool res{true};
+    int err_count = 0;
+    double err    = 0;
+    // TODO: This is a hack. We should have proper specialization for bf16_t data type.
+    double max_err = std::numeric_limits<float>::min();
+    for(std::size_t i = 0; i < ref.size(); ++i)
+    {
+        const double o = type_convert<float>(*std::next(std::begin(out), i));
+        const double r = type_convert<float>(*std::next(std::begin(ref), i));
+        err            = std::abs(o - r);
+        if(err > atol + rtol * std::abs(r))
+        {
+            max_err = err > max_err ? err : max_err;
+            err_count++;
+            if(err_count < 5)
+            {
+                std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
+                          << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
+            }
+            res = false;
+        }
+    }
+    if(!res)
+    {
+        const float error_percent =
+            static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
+        std::cerr << "max err: " << max_err;
+        std::cerr << ", number of errors: " << err_count;
+        std::cerr << ", " << error_percent << "% wrong values" << std::endl;
+    }
+    return res;
+}
+#endif
 } // namespace ck_tile
include/ck_tile/host/reference/reference_layernorm2d_fwd.hpp

@@ -8,20 +8,44 @@
 namespace ck_tile {

+// Note: for simplicity, each functor only care about single M
+struct reference_layernorm2d_default_epilogue
+{
+    template <typename OutDataType, typename AccDataType>
+    void operator()(int m, HostTensor<OutDataType>& o, const HostTensor<AccDataType>& acc)
+    {
+        const int N = acc.mDesc.get_lengths()[1];
+        for(int n = 0; n < N; ++n)
+        {
+            o(m, n) = ck_tile::type_convert<OutDataType>(acc(m, n));
+        }
+    }
+
+    template <typename OutDataType, typename AccDataType>
+    auto operator()(int m, const HostTensor<AccDataType>& acc)
+    {
+        HostTensor<OutDataType> o(acc.get_lengths(), acc.get_strides());
+        operator()(m, o, acc);
+        return o;
+    }
+};
+
 template <typename XDataType,
           typename GammaDataType,
           typename BetaDataType,
           typename ComputeDataType,
           typename YDataType,
           typename MeanDataType,
-          typename InvStdDataType>
+          typename InvStdDataType,
+          typename Epilogue = reference_layernorm2d_default_epilogue>
 void reference_layernorm2d_fwd(const HostTensor<XDataType>& x_m_n,
                                const HostTensor<GammaDataType>& gamma_n,
                                const HostTensor<BetaDataType>& beta_n,
                                HostTensor<YDataType>& y_m_n,
                                HostTensor<MeanDataType>& mean_m,
                                HostTensor<InvStdDataType>& invStd_m,
-                               ComputeDataType epsilon)
+                               ComputeDataType epsilon,
+                               Epilogue epilogue_functor = {})
 {
     auto layernorm2d_fwd_func = [&](auto m) {
         const int N = x_m_n.mDesc.get_lengths()[1];

@@ -51,16 +75,19 @@ void reference_layernorm2d_fwd(const HostTensor<XDataType>& x_m_n,
         if constexpr(!std::is_same_v<InvStdDataType, ck_tile::null_type>)
             invStd_m(m) = ck_tile::type_convert<InvStdDataType>(divisor);

+        HostTensor<ComputeDataType> acc(x_m_n.get_lengths(), x_m_n.get_strides());
+
         for(int n = 0; n < N; ++n)
         {
             ComputeDataType x     = ck_tile::type_convert<ComputeDataType>(x_m_n(m, n));
             ComputeDataType gamma = ck_tile::type_convert<ComputeDataType>(gamma_n(n));
             ComputeDataType beta  = ck_tile::type_convert<ComputeDataType>(beta_n(n));

-            auto y = (x - mean) * divisor;
-            y      = y * gamma + beta;
+            auto a_ = (x - mean) * divisor;
+            a_      = a_ * gamma + beta;

-            y_m_n(m, n) = ck_tile::type_convert<YDataType>(y);
+            acc(m, n) = a_;
         }
+
+        epilogue_functor(m, y_m_n, acc);
     };

     make_ParallelTensorFunctor(layernorm2d_fwd_func,
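Note: the change above splits the host reference into "compute an fp32 acc row" and "hand the row to an epilogue functor", so validation can plug in the dquant behaviour without touching the math. The sketch below only demonstrates that calling convention with plain vectors (rows) as stand-ins for HostTensor; the type and function names are illustrative, not the library's.

// Epilogue-functor hook in miniature: the row loop fills acc, then calls epilogue(m, out, acc).
#include <vector>
#include <iostream>

using Row = std::vector<float>;

struct default_epilogue
{
    void operator()(int /*m*/, Row& o, const Row& acc) const { o = acc; } // plain copy/convert
};

template <typename Epilogue = default_epilogue>
void reference_rowwise(const std::vector<Row>& acc_rows, std::vector<Row>& out, Epilogue epilogue = {})
{
    for(int m = 0; m < static_cast<int>(acc_rows.size()); ++m)
        epilogue(m, out[m], acc_rows[m]); // a dquant-style functor would rescale to int8 here
}

int main()
{
    std::vector<Row> acc = {{1.f, 2.f}, {3.f, 4.f}};
    std::vector<Row> out(2);
    reference_rowwise(acc, out);
    std::cout << out[1][1] << "\n"; // prints 4
}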
include/ck_tile/ops/epilogue.hpp

@@ -5,4 +5,5 @@
 #include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
 #include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
+#include "ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"

namespace ck_tile {

template <bool kPadM_, bool kPadN_, bool UseRawStore_ = true, bool UseMax3_ = false>
struct DynamicQuantEpilogueTraits
{
    static constexpr bool kPadM       = kPadM_;
    static constexpr bool kPadN       = kPadN_;
    static constexpr bool UseRawStore = UseRawStore_;
    static constexpr bool UseMax3     = UseMax3_;
};

// this epilogue just store out a M*N matrix, row major
template <typename AccDataType_, typename YScaleDataType_, typename ODataType_, typename Traits_>
struct DynamicQuantEpilogueProblem
{
    using AccDataType    = remove_cvref_t<AccDataType_>;
    using YScaleDataType = remove_cvref_t<YScaleDataType_>;
    using ODataType      = remove_cvref_t<ODataType_>;
    using Traits         = remove_cvref_t<Traits_>;
};

template <typename Problem_, typename Policy_ = void>
struct DynamicQuantEpilogue
{
    using Problem        = remove_cvref_t<Problem_>;
    using AccDataType    = remove_cvref_t<typename Problem::AccDataType>;
    using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>;
    using ODataType      = remove_cvref_t<typename Problem::ODataType>;

    static constexpr bool kPadM       = Problem::Traits::kPadM;
    static constexpr bool kPadN       = Problem::Traits::kPadN;
    static constexpr bool UseRawStore = Problem::Traits::UseRawStore;
    static constexpr bool UseMax3     = Problem::Traits::UseMax3;

    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }

    // TODO: this function assume store out vector size is the same as OAccTile last dimension size
    //       how do we fix this ?
    template <typename ODramWindowTmp, typename YScaleWindow, typename OAccTile>
    CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp,
                                   YScaleWindow& y_scale_window,
                                   const OAccTile& o_acc_tile)
    {
        // compute row max
        auto reduce_row_absmax = BlockReduce2D{o_acc_tile, type_convert<AccDataType>(0)};

        auto row_absmax = [&]() {
            if constexpr(UseMax3 && std::is_same_v<AccDataType, float>)
            {
                const auto f_max = [](auto acc_, auto v_0_) { return max(acc_, abs(v_0_)); };
                // const auto f_max3 = [](auto acc_, auto v_0_, auto v_1_) {
                //     float rtn;
                //     asm volatile("v_max3_f32 %0, %1, abs(%2), abs(%3)"
                //                  : "=v"(rtn)
                //                  : "v"(acc_), "v"(v_0_), "v"(v_1_));
                //     return rtn;
                // };
                // return reduce_row_absmax(f_max3, f_max, sequence<1, 2>{});
                return reduce_row_absmax(f_max);
            }
            else
            {
                const auto f_max = [](auto acc_, auto v_0_) { return max(acc_, abs(v_0_)); };
                return reduce_row_absmax(f_max);
            }
        }();

        // here y_scale is Acc type, need convert to YScale type later
        auto y_scale = tile_elementwise_in(
            [&](const auto& v_) { return v_ / type_convert<AccDataType>(numeric<ODataType>::max()); },
            row_absmax);

        store_tile(y_scale_window, cast_tile<YScaleDataType>(y_scale));

        // TODO: this is ugly
        if constexpr(UseRawStore && (kPadM || kPadN))
        {
            store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
            buffer_store_fence();
        }
        else
        {
            store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
        }
    }
};

} // namespace ck_tile
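Note: a rough sketch of how this header is meant to be wired up, mirroring the types generate.py emits (acc type, scale type, output type, then the traits). It assumes the ck_tile headers from this commit are on the include path; it is illustrative only and not a complete kernel.

// Compile-time wiring of the dynamic-quant epilogue (illustrative, not a full pipeline).
#include "ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp"

using AccType    = float;            // ComputeDataType in the generated instances
using YScaleType = ck_tile::half_t;  // prec_s
using OutType    = ck_tile::int8_t;  // prec_o

// kPadM=false, kPadN=true, UseRawStore=false, UseMax3=true (max3 path requires fp32 acc)
using Traits   = ck_tile::DynamicQuantEpilogueTraits<false, true, false, true>;
using Problem  = ck_tile::DynamicQuantEpilogueProblem<AccType, YScaleType, OutType, Traits>;
using Epilogue = ck_tile::DynamicQuantEpilogue<Problem>;

// Inside the one-pass pipeline this is then invoked as:
//   Epilogue{}(y_dram_window, y_scale_window, acc_tile);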
include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp

@@ -21,6 +21,7 @@ struct Layernorm2dFwdHostArgs
     void* p_sy; // shortcut output, set to nullptr if no
     void* p_mean;
     void* p_invStd;
+    void* p_y_scale; // store out a dynamic quant per row, used by next layer. nullptr if not used

     float epsilon;

@@ -44,6 +45,7 @@ struct Layernorm2dFwd
     using YDataType      = remove_cvref_t<typename Problem::YDataType>;
     using MeanDataType   = remove_cvref_t<typename Problem::MeanDataType>;
     using InvStdDataType = remove_cvref_t<typename Problem::InvStdDataType>;
+    using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>;

     // for simplicity, shortcut input/output type is same as X
     using SXDataType = XDataType;

@@ -81,6 +83,7 @@ struct Layernorm2dFwd
         void* p_sy; // shortcut output, set to nullptr if no
         void* p_mean;
         void* p_invStd;
+        void* p_y_scale; // store out a dynamic quant value, used in next layer

         float epsilon;

@@ -100,6 +103,7 @@ struct Layernorm2dFwd
                       hargs.p_sy,
                       hargs.p_mean,
                       hargs.p_invStd,
+                      hargs.p_y_scale,
                       hargs.epsilon,
                       hargs.m,
                       hargs.n,

@@ -120,6 +124,7 @@ struct Layernorm2dFwd
     template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
     template <> struct t2s<ck_tile::fp8_t>  { static constexpr const char * name = "fp8"; };
     template <> struct t2s<ck_tile::bf8_t>  { static constexpr const char * name = "bf8"; };
+    template <> struct t2s<ck_tile::int8_t> { static constexpr const char * name = "int8"; };
     // clang-format on

     // in byte

@@ -140,7 +145,18 @@ struct Layernorm2dFwd
             if (kTwoPass) n += "_2p";
             return n;
         }();
-        return _SS_("layernorm2d_fwd_") + _SS_(t2s<XDataType>::name) + "_" +
+        auto prec_str = [&]() {
+            std::string base_str = _SS_(t2s<XDataType>::name);
+            if(!std::is_same_v<XDataType, YDataType>)
+            {
+                base_str += _SS_("_") + _SS_(t2s<YDataType>::name);
+            }
+            if(kFusedSweep == Layernorm2dFusedSweepEnum::DYNAMIC_QUANT)
+            {
+                base_str += _SS_("_s") + _SS_(t2s<YScaleDataType>::name);
+            }
+            return base_str;
+        }();
+
+        return _SS_("layernorm2d_fwd_") + _SS_(prec_str) + "_" +
             _TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::WarpPerBlock_M) + "x" +
             _TS_(S_::WarpPerBlock_N) + "_" + _TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" +
             _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" + _SS_(Pipeline::name) + surfix;

@@ -295,6 +311,24 @@ struct Layernorm2dFwd
                 return make_null_tile_window(make_tuple(number<Block_M>{}));
         }();

+        auto y_scale_window = [&]() {
+            if constexpr(kFusedSweep == Layernorm2dFusedSweepEnum::DYNAMIC_QUANT)
+            {
+                const auto win_ = [&]() {
+                    const auto tmp_0_ = make_naive_tensor_view_packed<address_space_enum::global>(
+                        static_cast<YScaleDataType*>(kargs.p_y_scale),
+                        make_tuple(kargs.m),
+                        number<1>{});
+
+                    return pad_tensor_view(
+                        tmp_0_, make_tuple(number<Block_M>{}), sequence<kPadM>{});
+                }();
+                return make_tile_window(win_, make_tuple(number<Block_M>{}), {iM});
+            }
+            else
+                return make_null_tile_window(make_tuple(number<Block_M>{}));
+        }();
+
         __shared__ char smem[GetSmemSize()];

         Pipeline{}(x_window,

@@ -305,6 +339,7 @@ struct Layernorm2dFwd
                    sy_window,
                    mean_window,
                    inv_std_window,
+                   y_scale_window,
                    static_cast<const ComputeDataType>(kargs.epsilon),
                    kargs.n,
                    smem,
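Note: the new p_y_scale argument implies a simple host-side contract, visible in the example program above and in the packed 1-D view the kernel builds: one scale per row, length m, stride 1. The sketch below just sizes such a buffer; float is a stand-in for whatever YScaleDataType (prec_s) is actually used.

// Host-side sizing of the per-row y_scale buffer (illustrative; plain std::vector, not DeviceMem).
#include <vector>
#include <iostream>

int main()
{
    const int m = 1024;            // number of rows fed to layernorm2d_fwd
    std::vector<float> y_scale(m); // exactly m scales, packed, stride 1

    // ... copy to a device buffer and pass its pointer as p_y_scale when fused_sweep == 1 ...

    std::cout << "y_scale bytes: " << y_scale.size() * sizeof(float) << "\n";
}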
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp

@@ -59,6 +59,7 @@ struct Layernorm2dFwdPipelineOnePass
               typename SYWindow,
               typename MeanWindow,
               typename InvStdWindow,
+              typename YScaleWindow,
               typename Epilogue>
     CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
                                    const SXWindow& sx_window_,

@@ -68,6 +69,7 @@ struct Layernorm2dFwdPipelineOnePass
                                    const SYWindow& sy_window_,
                                    MeanWindow& mean_window,
                                    InvStdWindow& inv_std_window,
+                                   YScaleWindow& y_scale_window,
                                    ComputeDataType epsilon,
                                    ck_tile::index_t row_size,
                                    void* smem,

@@ -143,7 +145,12 @@ struct Layernorm2dFwdPipelineOnePass
             ln(idx) = ln_;
         });

-        Epilogue{}(y_window_, ln);
+        if constexpr(kFusedSweep == Layernorm2dFusedSweepEnum::DYNAMIC_QUANT)
+        {
+            Epilogue{}(y_window_, y_scale_window, ln);
+        }
+        else
+            Epilogue{}(y_window_, ln);
     }
 };
 } // namespace ck_tile
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp

@@ -14,6 +14,7 @@ template <typename XDataType_,
           typename YDataType_,
           typename MeanDataType_,
           typename InvStdDataType_,
+          typename YScaleDataType_,
           typename BlockShape_,
           typename Traits_>
 struct Layernorm2dFwdPipelineProblem

@@ -25,6 +26,7 @@ struct Layernorm2dFwdPipelineProblem
     using YDataType      = remove_cvref_t<YDataType_>;
     using MeanDataType   = remove_cvref_t<MeanDataType_>;
     using InvStdDataType = remove_cvref_t<InvStdDataType_>;
+    using YScaleDataType = remove_cvref_t<YScaleDataType_>;
     using BlockShape     = remove_cvref_t<BlockShape_>;

     static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp

@@ -58,6 +58,7 @@ struct Layernorm2dFwdPipelineTwoPass
               typename SYWindow,
               typename MeanWindow,
               typename InvStdWindow,
+              typename YScaleWindow,
               typename Epilogue>
     CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
                                    const SXWindow& sx_window_,

@@ -67,6 +68,7 @@ struct Layernorm2dFwdPipelineTwoPass
                                    const SYWindow& sy_window_,
                                    MeanWindow& mean_window,
                                    InvStdWindow& inv_std_window,
+                                   YScaleWindow& /*y_scale_window*/,
                                    ComputeDataType epsilon,
                                    ck_tile::index_t row_size,
                                    void* smem,

@@ -189,6 +191,7 @@ struct Layernorm2dFwdPipelineTwoPass
             ln(idx) = ln_;
         });

+        static_assert(kFusedSweep != Layernorm2dFusedSweepEnum::DYNAMIC_QUANT);
         Epilogue{}(y_window, ln);

         move_tile_window(x_window, {0, -Block_N});
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp

@@ -26,15 +26,13 @@ template<> struct Layernorm2dFusedAddEnumName<Layernorm2dFusedAddEnum::PRE_ADD>
 enum class Layernorm2dFusedSweepEnum
 {
     NO_SWEEP      = 0,
-    RENORM        = 1,
-    DYNAMIC_QUANT = 2,
+    DYNAMIC_QUANT = 1,
 };

 // clang-format off
 template<Layernorm2dFusedSweepEnum E> struct Layernorm2dFusedSweepEnumName;
 template<> struct Layernorm2dFusedSweepEnumName<Layernorm2dFusedSweepEnum::NO_SWEEP> { static constexpr const char * name = "no"; };
-template<> struct Layernorm2dFusedSweepEnumName<Layernorm2dFusedSweepEnum::RENORM> { static constexpr const char * name = "renorm"; };
-template<> struct Layernorm2dFusedSweepEnumName<Layernorm2dFusedSweepEnum::DYNAMIC_QUANT> { static constexpr const char * name = "dequant"; };
+template<> struct Layernorm2dFusedSweepEnumName<Layernorm2dFusedSweepEnum::DYNAMIC_QUANT> { static constexpr const char * name = "dquant"; };
 // clang-format on

 template <bool kPadN_,
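Note: with RENORM removed, DYNAMIC_QUANT is renumbered to 1, which is why the pipelines and the generator all compare kFusedSweep against 1. The standalone sketch below only mirrors the compile-time dispatch pattern the pipelines use; the enum and function names are stand-ins, not the ck_tile ones.

// if constexpr dispatch on the fused-sweep enum, as used by the one-pass pipeline.
#include <iostream>

enum class FusedSweep { NO_SWEEP = 0, DYNAMIC_QUANT = 1 };

template <FusedSweep kFusedSweep>
void run_epilogue()
{
    if constexpr(kFusedSweep == FusedSweep::DYNAMIC_QUANT)
        std::cout << "store y (int8) + per-row y_scale\n";
    else
        std::cout << "store y only\n";
}

int main()
{
    run_epilogue<FusedSweep::NO_SWEEP>();
    run_epilogue<FusedSweep::DYNAMIC_QUANT>();
}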