Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
018e939f
Commit
018e939f
authored
Jan 22, 2025
by
Jiming Ruan
Browse files
Modify tests and bug fix
parent
54617a85
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
118 additions
and
73 deletions
+118
-73
example/ck_tile/10_rmsnorm2d/generate.py
example/ck_tile/10_rmsnorm2d/generate.py
+103
-73
example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp
example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp
+6
-0
include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp
...ps/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp
+7
-0
include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp
...ps/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp
+2
-0
No files found.
example/ck_tile/10_rmsnorm2d/generate.py
View file @
018e939f
...
@@ -61,6 +61,8 @@ template <typename XDataType_,
...
@@ -61,6 +61,8 @@ template <typename XDataType_,
ck_tile::index_t Vector_N_, // vector size along N
ck_tile::index_t Vector_N_, // vector size along N
bool kPadN_,
bool kPadN_,
bool kSaveInvRms_,
bool kSaveInvRms_,
bool kFastFDiv_,
bool kWelford_,
bool kTwoPass_,
bool kTwoPass_,
ck_tile::index_t kFusedAdd_ = 0,
ck_tile::index_t kFusedAdd_ = 0,
ck_tile::index_t kFusedQuant_ = 0>
ck_tile::index_t kFusedQuant_ = 0>
...
@@ -122,6 +124,8 @@ struct rmsnorm2d_fwd_traits_
...
@@ -122,6 +124,8 @@ struct rmsnorm2d_fwd_traits_
static constexpr bool kPadN = kPadN_;
static constexpr bool kPadN = kPadN_;
static constexpr bool kSaveInvRms = kSaveInvRms_;
static constexpr bool kSaveInvRms = kSaveInvRms_;
static constexpr bool kFastFDiv = kFastFDiv_;
static constexpr bool kWelford = kWelford_;
static constexpr bool kTwoPass = kTwoPass_;
static constexpr bool kTwoPass = kTwoPass_;
static constexpr ck_tile::index_t kFusedAdd = kFusedAdd_;
static constexpr ck_tile::index_t kFusedAdd = kFusedAdd_;
static constexpr ck_tile::index_t kFusedQuant = kFusedQuant_;
static constexpr ck_tile::index_t kFusedQuant = kFusedQuant_;
...
@@ -138,6 +142,8 @@ template <typename XDataType_,
...
@@ -138,6 +142,8 @@ template <typename XDataType_,
ck_tile::index_t Vector_N_, // vector size along N
ck_tile::index_t Vector_N_, // vector size along N
bool kPadN_,
bool kPadN_,
bool kSaveInvRms_,
bool kSaveInvRms_,
bool kFastFDiv_,
bool kWelford_,
bool kTwoPass_,
bool kTwoPass_,
int kFusedAdd_,
int kFusedAdd_,
int kFusedQuant_>
int kFusedQuant_>
...
@@ -152,6 +158,8 @@ using traits_ = rmsnorm2d_fwd_traits_<XDataType_,
...
@@ -152,6 +158,8 @@ using traits_ = rmsnorm2d_fwd_traits_<XDataType_,
Vector_N_,
Vector_N_,
kPadN_,
kPadN_,
kSaveInvRms_,
kSaveInvRms_,
kFastFDiv_,
kWelford_,
kTwoPass_,
kTwoPass_,
kFusedAdd_,
kFusedAdd_,
kFusedQuant_>;
kFusedQuant_>;
...
@@ -185,6 +193,8 @@ float rmsnorm2d_fwd_(const S& s, A a)
...
@@ -185,6 +193,8 @@ float rmsnorm2d_fwd_(const S& s, A a)
using PipelineTraits =
using PipelineTraits =
ck_tile::Rmsnorm2dFwdTraits<Traits_::kPadN,
ck_tile::Rmsnorm2dFwdTraits<Traits_::kPadN,
Traits_::kSaveInvRms,
Traits_::kSaveInvRms,
Traits_::kFastFDiv,
Traits_::kWelford,
Traits_::kTwoPass,
Traits_::kTwoPass,
static_cast<ck_tile::Rmsnorm2dFusedAddEnum>(Traits_::kFusedAdd),
static_cast<ck_tile::Rmsnorm2dFusedAddEnum>(Traits_::kFusedAdd),
static_cast<ck_tile::Rmsnorm2dFusedQuantEnum>(Traits_::kFusedQuant)>;
static_cast<ck_tile::Rmsnorm2dFusedQuantEnum>(Traits_::kFusedQuant)>;
...
@@ -204,12 +214,14 @@ float rmsnorm2d_fwd_(const S& s, A a)
...
@@ -204,12 +214,14 @@ float rmsnorm2d_fwd_(const S& s, A a)
using TwoPassPipeline = ck_tile::Rmsnorm2dFwdPipelineTwoPass<PipelineProblem>;
using TwoPassPipeline = ck_tile::Rmsnorm2dFwdPipelineTwoPass<PipelineProblem>;
using Pipeline = std::conditional_t<Traits_::kTwoPass, TwoPassPipeline, OnePassPipeline>;
using Pipeline = std::conditional_t<Traits_::kTwoPass, TwoPassPipeline, OnePassPipeline>;
using Default2DEpilogueProblem = ck_tile::Default2DEpilogueProblem<ComputeDataType, YDataType, false, Traits_::kPadN, false>;
static constexpr bool UseRawStore = sizeof(YDataType) == 4;
using Default2DEpilogueProblem = ck_tile::Default2DEpilogueProblem<ComputeDataType, YDataType, false, Traits_::kPadN, UseRawStore>;
using Default2DEpilogue = ck_tile::Default2DEpilogue<Default2DEpilogueProblem>;
using Default2DEpilogue = ck_tile::Default2DEpilogue<Default2DEpilogueProblem>;
static constexpr bool UseSmoothInputScale = Traits_::kFusedQuant == 1;
static constexpr bool UseSmoothInputScale = Traits_::kFusedQuant == 1;
using DynamicQuantEpilogueProblem = ck_tile::DynamicQuantEpilogueProblem<ComputeDataType, SmoothScaleDataType, YScaleDataType, YDataType, typename Traits_::Shape,
using DynamicQuantEpilogueProblem = ck_tile::DynamicQuantEpilogueProblem<ComputeDataType, SmoothScaleDataType, YScaleDataType, YDataType, typename Traits_::Shape,
ck_tile::DynamicQuantEpilogueTraits<false, Traits_::kPadN, UseSmoothInputScale,
fals
e, true/*max3*/>>;
ck_tile::DynamicQuantEpilogueTraits<false, Traits_::kPadN, UseSmoothInputScale,
UseRawStor
e, true/*max3*/>>;
using DynamicQuantEpilogue = ck_tile::DynamicQuantEpilogue<DynamicQuantEpilogueProblem>;
using DynamicQuantEpilogue = ck_tile::DynamicQuantEpilogue<DynamicQuantEpilogueProblem>;
...
@@ -262,24 +274,21 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
...
@@ -262,24 +274,21 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
#include "rmsnorm2d_fwd_api_common.hpp"
#include "rmsnorm2d_fwd_api_common.hpp"
// clang-format off
// clang-format off
// rm rn tm tn vn pd rms 2p
// rm rn tm tn vn pd rms
welford
2p
{F_instance_def}
{F_instance_def}
// clang-format on
// clang-format on
"""
"""
API_PER_DTYPE
=
"""
API_PER_DTYPE
=
""" {F_if}(t.prec_i ==
\"
{F_i_type}
\"
&& t.prec_o ==
\"
{F_o_type}
\"
){{
{F_if}(t.prec_i ==
\"
{F_i_type}
\"
&& t.prec_o ==
\"
{F_o_type}
\"
){{
{F_per_n_case}
{F_per_n_case}
}}
}}
"""
"""
API_PER_N_CASE
=
"""
API_PER_N_CASE
=
""" {F_if} {F_N_COND} {{
{F_if} {F_N_COND} {{
{F_inner_dispatch}
{F_inner_dispatch}
}}
}}
"""
"""
API_INNER_CASE
=
"""
API_INNER_CASE
=
""" {F_if} {F_VEC_COND}
{F_if} {F_VEC_COND}
r={F_instance_func}(s, a);
r={F_instance_func}(s, a);
"""
"""
...
@@ -362,6 +371,8 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
...
@@ -362,6 +371,8 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
F_Vector_N
:
int
F_Vector_N
:
int
F_kPadN
:
bool
F_kPadN
:
bool
F_kSaveInvRms
:
bool
F_kSaveInvRms
:
bool
F_kFastFDiv_
:
bool
F_kWelford_
:
bool
F_kTwoPass
:
bool
F_kTwoPass
:
bool
F_kFusedAdd
:
int
F_kFusedAdd
:
int
F_kFusedQuant
:
int
F_kFusedQuant
:
int
...
@@ -369,7 +380,7 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
...
@@ -369,7 +380,7 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
@
property
@
property
def
trait_name
(
self
)
->
str
:
def
trait_name
(
self
)
->
str
:
t_
=
f
'
{
DATA_TYPE_MAP
[
self
.
F_XDataType
]
}
,
{
DATA_TYPE_MAP
[
self
.
F_YDataType
]
}
,
{
DATA_TYPE_MAP
[
self
.
F_SmoothScaleDataType
]
}
,
{
DATA_TYPE_MAP
[
self
.
F_YScaleDataType
]
}
,
{
self
.
F_Repeat_M
:
2
}
,
{
self
.
F_Repeat_N
:
2
}
,
{
self
.
F_ThreadPerBlock_M
:
2
}
,
{
self
.
F_ThreadPerBlock_N
:
4
}
'
t_
=
f
'
{
DATA_TYPE_MAP
[
self
.
F_XDataType
]
}
,
{
DATA_TYPE_MAP
[
self
.
F_YDataType
]
}
,
{
DATA_TYPE_MAP
[
self
.
F_SmoothScaleDataType
]
}
,
{
DATA_TYPE_MAP
[
self
.
F_YScaleDataType
]
}
,
{
self
.
F_Repeat_M
:
2
}
,
{
self
.
F_Repeat_N
:
2
}
,
{
self
.
F_ThreadPerBlock_M
:
2
}
,
{
self
.
F_ThreadPerBlock_N
:
4
}
'
t_
+=
f
',
{
self
.
F_Vector_N
:
2
}
,
{
BOOL_MAP
(
self
.
F_kPadN
):
5
}
,
{
BOOL_MAP
(
self
.
F_kSaveInvRms
):
5
}
'
t_
+=
f
',
{
self
.
F_Vector_N
:
2
}
,
{
BOOL_MAP
(
self
.
F_kPadN
):
5
}
,
{
BOOL_MAP
(
self
.
F_kSaveInvRms
):
5
}
,
{
BOOL_MAP
(
self
.
F_kFastFDiv_
):
5
}
,
{
BOOL_MAP
(
self
.
F_kWelford_
):
5
}
'
t_
+=
f
',
{
BOOL_MAP
(
self
.
F_kTwoPass
):
5
}
,
{
self
.
F_kFusedAdd
:
4
}
,
{
self
.
F_kFusedQuant
:
4
}
'
t_
+=
f
',
{
BOOL_MAP
(
self
.
F_kTwoPass
):
5
}
,
{
self
.
F_kFusedAdd
:
4
}
,
{
self
.
F_kFusedQuant
:
4
}
'
return
t_
return
t_
...
@@ -422,11 +433,10 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
...
@@ -422,11 +433,10 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
def
name_common_header
(
self
)
->
str
:
def
name_common_header
(
self
)
->
str
:
return
'rmsnorm2d_fwd_api_common'
return
'rmsnorm2d_fwd_api_common'
@
property
def
content_api
(
self
,
args
)
->
str
:
def
content_api
(
self
)
->
str
:
# 1 sort based on dtype
# 1 sort based on dtype
t_dtype_dict
=
dict
()
t_dtype_dict
=
dict
()
blobs
=
self
.
get_blobs
()
blobs
=
self
.
get_blobs
(
args
)
for
blob
in
blobs
:
for
blob
in
blobs
:
if
blob
.
F_DataTypePair
not
in
t_dtype_dict
:
if
blob
.
F_DataTypePair
not
in
t_dtype_dict
:
t_dtype_dict
[
blob
.
F_DataTypePair
]
=
{}
t_dtype_dict
[
blob
.
F_DataTypePair
]
=
{}
...
@@ -462,8 +472,8 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
...
@@ -462,8 +472,8 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
inner_str
+=
self
.
API_INNER_CASE
.
format
(
F_if
=
get_if_str
(
idx_in_n
,
len_in_n
,
False
),
inner_str
+=
self
.
API_INNER_CASE
.
format
(
F_if
=
get_if_str
(
idx_in_n
,
len_in_n
,
False
),
F_VEC_COND
=
_cond
,
F_instance_func
=
ins
.
call_name
)
F_VEC_COND
=
_cond
,
F_instance_func
=
ins
.
call_name
)
#inner_str = inner_str + vec_str
#inner_str = inner_str + vec_str
n_cnd
=
f
'(a.n <=
{
n_
}
)'
if
(
i_n
<
len
(
blob_per_t
)
-
1
)
else
''
n_cnd
=
f
'(a.n <=
{
n_
}
)'
if
isinstance
(
n_
,
int
)
else
''
n_str
+=
self
.
API_PER_N_CASE
.
format
(
F_if
=
get_if_str
(
i_n
,
len
(
blob_per_t
)),
F_N_COND
=
n_cnd
,
F_inner_dispatch
=
inner_str
)
n_str
+=
self
.
API_PER_N_CASE
.
format
(
F_if
=
get_if_str
(
i_n
,
len
(
blob_per_t
)
,
not
isinstance
(
n_
,
int
)
),
F_N_COND
=
n_cnd
,
F_inner_dispatch
=
inner_str
)
prec_i
,
prec_o
=
dtype_
.
split
(
','
)
prec_i
,
prec_o
=
dtype_
.
split
(
','
)
d_str
+=
self
.
API_PER_DTYPE
.
format
(
F_if
=
get_if_str
(
i_d
,
len
(
t_dtype_dict
),
False
),
F_i_type
=
prec_i
,
F_o_type
=
prec_o
,
F_per_n_case
=
n_str
)
d_str
+=
self
.
API_PER_DTYPE
.
format
(
F_if
=
get_if_str
(
i_d
,
len
(
t_dtype_dict
),
False
),
F_i_type
=
prec_i
,
F_o_type
=
prec_o
,
F_per_n_case
=
n_str
)
...
@@ -474,7 +484,7 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
...
@@ -474,7 +484,7 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
def
content_common_header
(
self
)
->
str
:
def
content_common_header
(
self
)
->
str
:
return
self
.
API_COMMON_HEADER
.
format
(
F_traits_define
=
self
.
API_TRAITS_DEFINE
)
return
self
.
API_COMMON_HEADER
.
format
(
F_traits_define
=
self
.
API_TRAITS_DEFINE
)
def
get_blobs
(
self
):
def
get_blobs
(
self
,
args
):
h_traits
=
rmsnorm_fwd_codegen
.
h_traits
h_traits
=
rmsnorm_fwd_codegen
.
h_traits
h_instance
=
rmsnorm_fwd_codegen
.
h_instance
h_instance
=
rmsnorm_fwd_codegen
.
h_instance
...
@@ -485,60 +495,61 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
...
@@ -485,60 +495,61 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
dtype_list
=
[(
'fp16,fp16'
),
(
'bf16,bf16'
),
dtype_list
=
[(
'fp16,fp16'
),
(
'bf16,bf16'
),
(
'fp16,int8'
),
(
'bf16,int8'
),
(
'fp16,int8'
),
(
'bf16,int8'
),
(
'fp16,fp8'
),
(
'bf16,fp8'
)]
# NOTE: only fused-dynamic-quant use int8 out
(
'fp16,fp8'
),
(
'bf16,fp8'
)]
# NOTE: only fused-dynamic-quant use int8 out
types_8bit
=
(
'int8'
,
'fp8'
)
types_16bit
=
(
'int16'
,
'fp16'
,
'bf16'
)
#fused_add_list = [0, 1, 2]
#fused_add_list = [0, 1, 2]
#fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused (smooth) dynamic quant
fused_add_list
=
[
0
,
1
]
fused_add_list
=
[
0
,
1
]
fused_sweep_list
=
[
0
,
1
,
2
]
# NOTE: only single pass can use fused (smooth) dynamic quant
fused_sweep_list
=
[
0
,
1
,
2
]
# NOTE: only single pass can use fused (smooth) dynamic quant
# rm rn tm tn vn pd mv
2p
add
sweep
# rm rn tm tn vn pd mv
fdiv welford 2p
add sweep
h_trait_dict
=
{
'64'
:
[
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
1
,
8
,
8
,
8
,
True
,
False
,
False
,
0
,
0
),
h_trait_dict
=
{
'64'
:
[
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
1
,
8
,
8
,
8
,
True
,
False
,
True
,
True
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
1
,
4
,
16
,
4
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
1
,
4
,
16
,
4
,
True
,
False
,
True
,
True
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
1
,
4
,
64
,
1
,
True
,
False
,
False
,
0
,
0
)],
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
1
,
4
,
64
,
1
,
True
,
False
,
True
,
True
,
False
,
0
,
0
)],
'128'
:
[
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
1
,
4
,
16
,
8
,
True
,
False
,
False
,
0
,
0
),
'128'
:
[
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
1
,
4
,
16
,
8
,
True
,
False
,
True
,
True
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
1
,
4
,
64
,
2
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
1
,
4
,
64
,
2
,
True
,
False
,
True
,
True
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
2
,
4
,
64
,
1
,
True
,
False
,
False
,
0
,
0
)],
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
2
,
4
,
64
,
1
,
True
,
False
,
True
,
True
,
False
,
0
,
0
)],
'256'
:
[
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
1
,
4
,
64
,
4
,
True
,
False
,
False
,
0
,
0
),
'256'
:
[
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
1
,
4
,
64
,
4
,
True
,
False
,
True
,
True
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
2
,
4
,
64
,
2
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
2
,
4
,
64
,
2
,
True
,
False
,
True
,
True
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
4
,
4
,
64
,
1
,
True
,
False
,
False
,
0
,
0
)],
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
4
,
4
,
64
,
1
,
True
,
False
,
True
,
True
,
False
,
0
,
0
)],
'512'
:
[
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
1
,
4
,
64
,
8
,
True
,
False
,
False
,
0
,
0
),
'512'
:
[
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
1
,
4
,
64
,
8
,
True
,
False
,
True
,
True
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
2
,
4
,
64
,
4
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
2
,
4
,
64
,
4
,
True
,
False
,
True
,
True
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
4
,
4
,
64
,
2
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
4
,
4
,
64
,
2
,
True
,
False
,
True
,
True
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
8
,
4
,
64
,
1
,
True
,
False
,
False
,
0
,
0
)],
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
8
,
4
,
64
,
1
,
True
,
False
,
True
,
True
,
False
,
0
,
0
)],
'768'
:
[
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
3
,
4
,
64
,
4
,
True
,
False
,
False
,
0
,
0
),
'768'
:
[
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
3
,
4
,
64
,
4
,
True
,
False
,
True
,
True
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
6
,
4
,
64
,
2
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
6
,
4
,
64
,
2
,
True
,
False
,
True
,
True
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
12
,
4
,
64
,
1
,
True
,
False
,
False
,
0
,
0
)],
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
12
,
4
,
64
,
1
,
True
,
False
,
True
,
True
,
False
,
0
,
0
)],
'1024'
:[
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
1
,
2
,
128
,
8
,
True
,
False
,
False
,
0
,
0
),
'1024'
:[
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
1
,
2
,
128
,
8
,
True
,
False
,
True
,
True
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
2
,
2
,
128
,
4
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
2
,
2
,
128
,
4
,
True
,
False
,
True
,
True
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
4
,
2
,
128
,
2
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
4
,
2
,
128
,
2
,
True
,
False
,
True
,
True
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
4
,
1
,
256
,
1
,
True
,
False
,
False
,
0
,
0
)],
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
4
,
1
,
256
,
1
,
True
,
False
,
True
,
True
,
False
,
0
,
0
)],
'1536'
:[
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
3
,
4
,
64
,
8
,
True
,
False
,
False
,
0
,
0
),
'1536'
:[
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
3
,
4
,
64
,
8
,
True
,
False
,
True
,
True
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
3
,
2
,
128
,
4
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
3
,
2
,
128
,
4
,
True
,
False
,
True
,
True
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
3
,
1
,
256
,
2
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
3
,
1
,
256
,
2
,
True
,
False
,
True
,
True
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
6
,
1
,
256
,
1
,
True
,
False
,
False
,
0
,
0
)],
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
6
,
1
,
256
,
1
,
True
,
False
,
True
,
True
,
False
,
0
,
0
)],
'2048'
:[
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
1
,
1
,
256
,
8
,
True
,
False
,
False
,
0
,
0
),
'2048'
:[
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
1
,
1
,
256
,
8
,
True
,
False
,
True
,
True
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
2
,
1
,
256
,
4
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
2
,
1
,
256
,
4
,
True
,
False
,
True
,
True
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
4
,
1
,
256
,
2
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
4
,
1
,
256
,
2
,
True
,
False
,
True
,
True
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
8
,
1
,
256
,
1
,
True
,
False
,
False
,
0
,
0
)],
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
8
,
1
,
256
,
1
,
True
,
False
,
True
,
True
,
False
,
0
,
0
)],
'3072'
:[
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
3
,
1
,
128
,
8
,
True
,
False
,
False
,
0
,
0
),
'3072'
:[
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
3
,
1
,
128
,
8
,
True
,
False
,
True
,
True
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
3
,
1
,
256
,
4
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
3
,
1
,
256
,
4
,
True
,
False
,
True
,
True
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
6
,
1
,
256
,
2
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
6
,
1
,
256
,
2
,
True
,
False
,
True
,
True
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
3
,
1
,
1024
,
1
,
True
,
False
,
False
,
0
,
0
)],
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
3
,
1
,
1024
,
1
,
True
,
False
,
True
,
True
,
False
,
0
,
0
)],
'4096'
:[
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
2
,
1
,
256
,
8
,
True
,
False
,
False
,
0
,
0
),
'4096'
:[
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
2
,
1
,
256
,
8
,
True
,
False
,
True
,
True
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
4
,
1
,
256
,
4
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
4
,
1
,
256
,
4
,
True
,
False
,
True
,
True
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
2
,
1
,
1024
,
2
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
2
,
1
,
1024
,
2
,
True
,
False
,
True
,
True
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
4
,
1
,
1024
,
1
,
True
,
False
,
False
,
0
,
0
)],
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
4
,
1
,
1024
,
1
,
True
,
False
,
True
,
True
,
False
,
0
,
0
)],
'6144'
:[
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
3
,
1
,
256
,
8
,
True
,
False
,
False
,
0
,
0
),
'6144'
:[
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
3
,
1
,
256
,
8
,
True
,
False
,
True
,
True
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
3
,
1
,
512
,
4
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
3
,
1
,
512
,
4
,
True
,
False
,
True
,
True
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
3
,
1
,
1024
,
2
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
3
,
1
,
1024
,
2
,
True
,
False
,
True
,
True
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
6
,
1
,
1024
,
1
,
True
,
False
,
False
,
0
,
0
)],
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
6
,
1
,
1024
,
1
,
True
,
False
,
True
,
True
,
False
,
0
,
0
)],
'8192'
:[
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
4
,
1
,
256
,
8
,
True
,
False
,
False
,
0
,
0
),
'8192'
:[
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
4
,
1
,
256
,
8
,
True
,
False
,
True
,
True
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
4
,
1
,
512
,
4
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
4
,
1
,
512
,
4
,
True
,
False
,
True
,
True
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
4
,
1
,
1024
,
2
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
4
,
1
,
1024
,
2
,
True
,
False
,
True
,
True
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
8
,
1
,
1024
,
1
,
True
,
False
,
False
,
0
,
0
)],
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
8
,
1
,
1024
,
1
,
True
,
False
,
True
,
True
,
False
,
0
,
0
)],
'big'
:[
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
2
,
1
,
256
,
8
,
True
,
False
,
True
,
0
,
0
),
'big'
:[
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
2
,
1
,
256
,
8
,
True
,
False
,
True
,
True
,
True
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
4
,
1
,
256
,
4
,
True
,
False
,
True
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
4
,
1
,
256
,
4
,
True
,
False
,
True
,
True
,
True
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
2
,
1
,
1024
,
2
,
True
,
False
,
True
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
2
,
1
,
1024
,
2
,
True
,
False
,
True
,
True
,
True
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
4
,
1
,
1024
,
1
,
True
,
False
,
True
,
0
,
0
)]}
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
4
,
1
,
1024
,
1
,
True
,
False
,
True
,
True
,
True
,
0
,
0
)]}
total_blob
=
list
()
total_blob
=
list
()
for
hs_key
in
h_trait_dict
:
for
hs_key
in
h_trait_dict
:
hs
=
h_trait_dict
[
hs_key
]
hs
=
h_trait_dict
[
hs_key
]
...
@@ -559,16 +570,27 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
...
@@ -559,16 +570,27 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
h_
.
F_YScaleDataType
=
scale_y
h_
.
F_YScaleDataType
=
scale_y
h_
.
F_kFusedAdd
=
fused_add
h_
.
F_kFusedAdd
=
fused_add
h_
.
F_kFusedQuant
=
fused_quant
h_
.
F_kFusedQuant
=
fused_quant
# disable welford update for 8bit and 16 bit smallN
if
not
h_
.
F_kTwoPass
:
#disable 16 bit when set args disable_16b_welford
if
args
.
disable_16b_welford
and
prec_i
in
types_16bit
:
h_
.
F_kWelford_
=
False
#disable 8bit by default
elif
prec_i
in
types_8bit
or
prec_o
in
types_8bit
:
h_
.
F_kWelford_
=
False
#disable 16bit small N
elif
prec_i
in
types_16bit
and
hs_key
==
'64'
:
h_
.
F_kWelford_
=
False
current_hs
.
append
(
h_
)
# + "\n"
current_hs
.
append
(
h_
)
# + "\n"
#f.write(str(f.parent / GEN_DIR / (blobs.api_common_header_
#f.write(str(f.parent / GEN_DIR / (blobs.api_common_header_
current_n_str
=
'big'
if
hs_key
==
'big'
else
current_n
current_n_str
=
'big'
if
hs_key
==
'big'
else
current_n
total_blob
.
append
(
h_instance
(
dtype
,
current_n_str
,
fused_add
,
fused_quant
,
current_hs
))
total_blob
.
append
(
h_instance
(
dtype
,
current_n_str
,
fused_add
,
fused_quant
,
current_hs
))
return
total_blob
return
total_blob
def
list_blobs
(
self
)
->
None
:
def
list_blobs
(
self
,
args
)
->
None
:
w_p
=
Path
(
self
.
working_path
)
w_p
=
Path
(
self
.
working_path
)
list_p
=
w_p
/
'rmsnorm2d_fwd_blobs.txt'
list_p
=
w_p
/
'rmsnorm2d_fwd_blobs.txt'
blobs
=
self
.
get_blobs
()
blobs
=
self
.
get_blobs
(
args
)
with
list_p
.
open
(
'w'
)
as
list_f
:
with
list_p
.
open
(
'w'
)
as
list_f
:
# api related file
# api related file
list_f
.
write
(
str
(
w_p
/
(
self
.
name_api
+
".cpp"
))
+
"
\n
"
)
list_f
.
write
(
str
(
w_p
/
(
self
.
name_api
+
".cpp"
))
+
"
\n
"
)
...
@@ -577,11 +599,12 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
...
@@ -577,11 +599,12 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
for
b
in
blobs
:
for
b
in
blobs
:
list_f
.
write
(
str
(
w_p
/
(
b
.
name
+
".cpp"
))
+
"
\n
"
)
list_f
.
write
(
str
(
w_p
/
(
b
.
name
+
".cpp"
))
+
"
\n
"
)
def
gen_blobs
(
self
)
->
None
:
def
gen_blobs
(
self
,
args
)
->
None
:
w_p
=
Path
(
self
.
working_path
)
w_p
=
Path
(
self
.
working_path
)
(
w_p
/
(
self
.
name_api
+
".cpp"
)).
write_text
(
self
.
content_api
)
w_str
=
self
.
content_api
(
args
)
(
w_p
/
(
self
.
name_api
+
".cpp"
)).
write_text
(
w_str
)
(
w_p
/
(
self
.
name_common_header
+
".hpp"
)).
write_text
(
self
.
content_common_header
)
(
w_p
/
(
self
.
name_common_header
+
".hpp"
)).
write_text
(
self
.
content_common_header
)
blobs
=
self
.
get_blobs
()
blobs
=
self
.
get_blobs
(
args
)
for
b
in
blobs
:
for
b
in
blobs
:
(
w_p
/
(
b
.
name
+
".cpp"
)).
write_text
(
b
.
content
)
(
w_p
/
(
b
.
name
+
".cpp"
)).
write_text
(
b
.
content
)
...
@@ -590,14 +613,14 @@ def list_blobs(args):
...
@@ -590,14 +613,14 @@ def list_blobs(args):
api_list
=
args
.
api
.
split
(
','
)
api_list
=
args
.
api
.
split
(
','
)
for
api
in
api_list
:
for
api
in
api_list
:
if
api
==
'fwd'
:
if
api
==
'fwd'
:
rmsnorm_fwd_codegen
(
args
.
working_path
,
args
.
filter
).
list_blobs
()
rmsnorm_fwd_codegen
(
args
.
working_path
,
args
.
filter
).
list_blobs
(
args
)
def
gen_blobs
(
args
):
def
gen_blobs
(
args
):
api_list
=
args
.
api
.
split
(
','
)
api_list
=
args
.
api
.
split
(
','
)
for
api
in
api_list
:
for
api
in
api_list
:
if
api
==
'fwd'
:
if
api
==
'fwd'
:
rmsnorm_fwd_codegen
(
args
.
working_path
,
args
.
filter
).
gen_blobs
()
rmsnorm_fwd_codegen
(
args
.
working_path
,
args
.
filter
).
gen_blobs
(
args
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
@@ -666,6 +689,13 @@ if __name__ == "__main__":
...
@@ -666,6 +689,13 @@ if __name__ == "__main__":
help
=
"codegen receipt."
help
=
"codegen receipt."
)
)
parser
.
add_argument
(
"--disable_16b_welford"
,
default
=
False
,
required
=
False
,
help
=
"enable/disable welford for 16bit datatype n > 64"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
# print(f'{args.list_blobs}-{args.gen_blobs}')
# print(f'{args.list_blobs}-{args.gen_blobs}')
...
...
example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp
View file @
018e939f
...
@@ -200,6 +200,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -200,6 +200,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
float
ave_time
=
rmsnorm2d_fwd
(
float
ave_time
=
rmsnorm2d_fwd
(
traits
,
args
,
ck_tile
::
stream_config
{
nullptr
,
true
,
kname
?
1
:
0
,
warmup
,
repeat
});
traits
,
args
,
ck_tile
::
stream_config
{
nullptr
,
true
,
kname
?
1
:
0
,
warmup
,
repeat
});
if
(
ave_time
<
0
)
{
std
::
cout
<<
" not supported!"
<<
std
::
endl
<<
std
::
flush
;
return
false
;
}
std
::
size_t
num_byte
=
std
::
size_t
num_byte
=
sizeof
(
XDataType
)
*
m
*
n
+
sizeof
(
GammaDataType
)
*
n
+
sizeof
(
YDataType
)
*
m
*
n
;
sizeof
(
XDataType
)
*
m
*
n
+
sizeof
(
GammaDataType
)
*
n
+
sizeof
(
YDataType
)
*
m
*
n
;
num_byte
+=
SaveRms
?
sizeof
(
InvRmsDataType
)
*
m
*
n
:
0
;
num_byte
+=
SaveRms
?
sizeof
(
InvRmsDataType
)
*
m
*
n
:
0
;
...
...
include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp
View file @
018e939f
...
@@ -120,6 +120,13 @@ struct Rmsnorm2dFwdPipelineOnePass
...
@@ -120,6 +120,13 @@ struct Rmsnorm2dFwdPipelineOnePass
block_norm_reduce_sync
(
square_mean
,
cur_count
);
block_norm_reduce_sync
(
square_mean
,
cur_count
);
block_norm_reduce_cross_warp_sync
(
square_mean
,
cur_count
,
smem
);
block_norm_reduce_cross_warp_sync
(
square_mean
,
cur_count
,
smem
);
if
constexpr
(
!
kWelford
)
{
sweep_tile
(
square_mean
,
[
&
](
auto
idx
)
{
square_mean
(
idx
)
=
square_mean
(
idx
)
/
type_convert
<
ComputeDataType
>
(
row_size
);
});
}
// compute inv-rms
// compute inv-rms
auto
inv_rms
=
tile_elementwise_in
(
auto
inv_rms
=
tile_elementwise_in
(
[
&
](
const
auto
&
v_
)
{
[
&
](
const
auto
&
v_
)
{
...
...
include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp
View file @
018e939f
...
@@ -70,6 +70,8 @@ struct Rmsnorm2dFwdPipelineTwoPass
...
@@ -70,6 +70,8 @@ struct Rmsnorm2dFwdPipelineTwoPass
void
*
smem
,
void
*
smem
,
Epilogue
)
const
Epilogue
)
const
{
{
static_assert
(
kWelford
==
true
,
"2 pass only supports welford merge"
);
auto
x_window
=
auto
x_window
=
make_tile_window
(
x_window_
,
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>());
make_tile_window
(
x_window_
,
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>());
auto
gamma_window
=
make_tile_window
(
auto
gamma_window
=
make_tile_window
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment