Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
dec32dc6
Commit
dec32dc6
authored
Jan 31, 2025
by
ThomasNing
Browse files
Finish the feature and merge with develop on the computeV2
parents
71352c44
c5fff071
Changes
215
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
443 additions
and
381 deletions
+443
-381
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n256_instance.cpp
..._rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n256_instance.cpp
+0
-12
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n3072_instance.cpp
...rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n3072_instance.cpp
+0
-14
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n4096_instance.cpp
...rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n4096_instance.cpp
+0
-14
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n4096_tp_instance.cpp
...norm2d/instances/rmsnorm2d_fwd_fp16_n4096_tp_instance.cpp
+0
-14
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n512_instance.cpp
..._rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n512_instance.cpp
+0
-13
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n64_n128_instance.cpp
...norm2d/instances/rmsnorm2d_fwd_fp16_n64_n128_instance.cpp
+0
-12
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n768_instance.cpp
..._rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n768_instance.cpp
+0
-12
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_instance_common.hpp
.../10_rmsnorm2d/instances/rmsnorm2d_fwd_instance_common.hpp
+0
-65
example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp
example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp
+296
-54
example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.hpp
example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.hpp
+34
-85
example/ck_tile/10_rmsnorm2d/script/smoke_test.sh
example/ck_tile/10_rmsnorm2d/script/smoke_test.sh
+29
-25
example/ck_tile/12_smoothquant/example_smoothquant.cpp
example/ck_tile/12_smoothquant/example_smoothquant.cpp
+15
-15
example/ck_tile/12_smoothquant/instances/smoothquant_instance_common.hpp
.../12_smoothquant/instances/smoothquant_instance_common.hpp
+2
-2
example/ck_tile/12_smoothquant/smoothquant.cpp
example/ck_tile/12_smoothquant/smoothquant.cpp
+14
-14
example/ck_tile/12_smoothquant/smoothquant.hpp
example/ck_tile/12_smoothquant/smoothquant.hpp
+11
-11
example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n1024_instance.cpp
...thquant/instances/moe_smoothquant_bf16_n1024_instance.cpp
+9
-4
example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n1536_instance.cpp
...thquant/instances/moe_smoothquant_bf16_n1536_instance.cpp
+9
-4
example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n2048_instance.cpp
...thquant/instances/moe_smoothquant_bf16_n2048_instance.cpp
+9
-4
example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n256_instance.cpp
...othquant/instances/moe_smoothquant_bf16_n256_instance.cpp
+7
-3
example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n3072_instance.cpp
...thquant/instances/moe_smoothquant_bf16_n3072_instance.cpp
+8
-4
No files found.
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n256_instance.cpp
deleted
100644 → 0
View file @
71352c44
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
1
,
4
,
64
,
4
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
2
,
4
,
64
,
2
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
4
,
4
,
64
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n3072_instance.cpp
deleted
100644 → 0
View file @
71352c44
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
3
,
1
,
128
,
8
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
3
,
1
,
256
,
4
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
6
,
1
,
256
,
2
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
3
,
1
,
1024
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n4096_instance.cpp
deleted
100644 → 0
View file @
71352c44
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
2
,
1
,
256
,
8
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
4
,
1
,
256
,
4
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
2
,
1
,
1024
,
2
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
4
,
1
,
1024
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n4096_tp_instance.cpp
deleted
100644 → 0
View file @
71352c44
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
2
,
1
,
256
,
8
,
true
,
false
,
true
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
4
,
1
,
256
,
4
,
true
,
false
,
true
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
2
,
1
,
1024
,
2
,
true
,
false
,
true
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
4
,
1
,
1024
,
1
,
true
,
false
,
true
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n512_instance.cpp
deleted
100644 → 0
View file @
71352c44
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
1
,
4
,
64
,
8
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
2
,
4
,
64
,
4
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
4
,
4
,
64
,
2
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
8
,
4
,
64
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n64_n128_instance.cpp
deleted
100644 → 0
View file @
71352c44
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
1
,
4
,
64
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
1
,
4
,
64
,
2
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
2
,
4
,
64
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n768_instance.cpp
deleted
100644 → 0
View file @
71352c44
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
3
,
4
,
64
,
4
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
6
,
4
,
64
,
2
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
12
,
4
,
64
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_instance_common.hpp
deleted
100644 → 0
View file @
71352c44
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "rmsnorm2d_fwd.hpp"
#include <iostream>
#pragma once
using
S
=
ck_tile
::
stream_config
;
using
A
=
rmsnorm2d_fwd_args
;
template
<
typename
DataType_
,
ck_tile
::
index_t
Repeat_M_
,
// each thread repeat along M
ck_tile
::
index_t
Repeat_N_
,
// each thread repeat along N
ck_tile
::
index_t
ThreadPerBlock_M_
,
// num threads along M
ck_tile
::
index_t
ThreadPerBlock_N_
,
// num threads along N
ck_tile
::
index_t
Vector_N_
,
// vector size along N
bool
kPadN_
,
bool
kSaveInvRms_
,
bool
kTwoPass_
>
using
trait_
=
rmsnorm2d_fwd_traits_
<
DataType_
,
Repeat_M_
,
Repeat_N_
,
ThreadPerBlock_M_
,
ThreadPerBlock_N_
,
Vector_N_
,
kPadN_
,
kSaveInvRms_
,
kTwoPass_
>
;
template
<
typename
Traits_
>
float
rmsnorm2d_fwd_
(
const
S
&
s
,
A
a
)
{
using
DataType
=
typename
Traits_
::
DataType
;
using
PipelineProblem
=
ck_tile
::
Rmsnorm2dFwdPipelineProblem
<
typename
RmsnormTypeConfig
<
DataType
>::
XDataType
,
typename
RmsnormTypeConfig
<
DataType
>::
GammaDataType
,
typename
RmsnormTypeConfig
<
DataType
>::
ComputeDataType
,
typename
RmsnormTypeConfig
<
DataType
>::
YDataType
,
typename
RmsnormTypeConfig
<
DataType
>::
InvRmsDataType
,
typename
Traits_
::
Shape
,
Traits_
::
kPadN
,
Traits_
::
kSaveInvRms
,
Traits_
::
kTwoPass
>
;
using
OnePassPipeline
=
ck_tile
::
Rmsnorm2dFwdPipelineOnePass
<
PipelineProblem
>
;
using
TwoPassPipeline
=
ck_tile
::
Rmsnorm2dFwdPipelineTwoPass
<
PipelineProblem
>
;
using
Pipeline
=
std
::
conditional_t
<
Traits_
::
kTwoPass
,
TwoPassPipeline
,
OnePassPipeline
>
;
using
Kernel
=
ck_tile
::
Rmsnorm2dFwd
<
Pipeline
>
;
const
dim3
grids
=
Kernel
::
GridSize
(
a
);
constexpr
dim3
blocks
=
Kernel
::
BlockSize
();
constexpr
ck_tile
::
index_t
kBlockPerCu
=
1
;
auto
kargs
=
Kernel
::
MakeKargs
(
a
);
if
(
s
.
log_level_
>
0
)
std
::
cout
<<
", "
<<
Kernel
::
GetName
()
<<
std
::
flush
;
return
ck_tile
::
launch_kernel
(
s
,
ck_tile
::
make_kernel
<
blocks
.
x
,
kBlockPerCu
>
(
Kernel
{},
grids
,
blocks
,
0
,
kargs
));
}
example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp
View file @
dec32dc6
...
@@ -19,17 +19,37 @@ auto get_elimit<ck_tile::bf16_t>()
...
@@ -19,17 +19,37 @@ auto get_elimit<ck_tile::bf16_t>()
return
ck_tile
::
make_tuple
(
rtol
,
atol
);
return
ck_tile
::
make_tuple
(
rtol
,
atol
);
}
}
template
<
>
auto
get_elimit
<
ck_tile
::
int8_t
>
()
{
double
rtol
=
1e-02
;
double
atol
=
1.0
;
return
ck_tile
::
make_tuple
(
rtol
,
atol
);
}
auto
create_args
(
int
argc
,
char
*
argv
[])
auto
create_args
(
int
argc
,
char
*
argv
[])
{
{
ck_tile
::
ArgParser
arg_parser
;
ck_tile
::
ArgParser
arg_parser
;
arg_parser
.
insert
(
"m"
,
"3328"
,
"m dimension"
)
arg_parser
.
insert
(
"m"
,
"3328"
,
"m dimension"
)
.
insert
(
"n"
,
"4096"
,
"n dimension"
)
.
insert
(
"n"
,
"4096"
,
"n dimension"
)
.
insert
(
"stride"
,
"-1"
,
"stride per row, if -1 then equal to n"
)
.
insert
(
"x_stride"
,
"-1"
,
"x row_stride, if -1 then equal to n"
)
.
insert
(
"xr_stride"
,
"-1"
,
"x residule row_stride, if -1 then equal to n"
)
.
insert
(
"y_stride"
,
"-1"
,
"y row_stride, if -1 then equal to n"
)
.
insert
(
"yr_stride"
,
"-1"
,
"y residule row_stride, if -1 then equal to n"
)
.
insert
(
"e"
,
"1e-5"
,
"epsilon"
)
.
insert
(
"e"
,
"1e-5"
,
"epsilon"
)
.
insert
(
"save_rms"
,
"0"
,
"save rms(invrms) or not. set to 1 in training case"
)
.
insert
(
"save_rms"
,
"0"
,
"save rms(invrms) or not. set to 1 in training case"
)
.
insert
(
"v"
,
"1"
,
"cpu validation or not"
)
.
insert
(
"v"
,
"1"
,
"cpu validation or not"
)
.
insert
(
"kname"
,
"1"
,
"print kernel name or not"
)
.
insert
(
"kname"
,
"1"
,
"print kernel name or not"
)
.
insert
(
"prec"
,
"fp16"
,
"precision"
)
.
insert
(
"prec_i"
,
"fp16"
,
"input precision"
)
.
insert
(
"prec_o"
,
"auto"
,
"output precision, set auto will be the same as input"
)
.
insert
(
"prec_sm"
,
"auto"
,
"output quant scale type, set auto will use fp32. used when fquant=1"
)
.
insert
(
"prec_sy"
,
"auto"
,
"output quant scale type, set auto will use fp32. used when fquant=1 or 2"
)
.
insert
(
"fadd"
,
"0"
,
"fused-add, 0:no fused add, 1:preadd+store, 2:preadd only"
)
.
insert
(
"fquant"
,
"0"
,
"fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant"
)
.
insert
(
"warmup"
,
"5"
,
"cold iter"
)
.
insert
(
"warmup"
,
"5"
,
"cold iter"
)
.
insert
(
"repeat"
,
"20"
,
"hot iter"
);
.
insert
(
"repeat"
,
"20"
,
"hot iter"
);
...
@@ -37,28 +57,70 @@ auto create_args(int argc, char* argv[])
...
@@ -37,28 +57,70 @@ auto create_args(int argc, char* argv[])
return
std
::
make_tuple
(
result
,
arg_parser
);
return
std
::
make_tuple
(
result
,
arg_parser
);
}
}
template
<
typename
DataType
,
bool
SaveRms
>
template
<
typename
InDataType
,
typename
OutDataType
,
typename
SmoothScaleDataType
,
typename
YScaleDataType
,
bool
SaveRms
>
bool
run
(
const
ck_tile
::
ArgParser
&
arg_parser
)
bool
run
(
const
ck_tile
::
ArgParser
&
arg_parser
)
{
{
ck_tile
::
index_t
m
=
arg_parser
.
get_int
(
"m"
);
ck_tile
::
index_t
m
=
arg_parser
.
get_int
(
"m"
);
ck_tile
::
index_t
n
=
arg_parser
.
get_int
(
"n"
);
ck_tile
::
index_t
n
=
arg_parser
.
get_int
(
"n"
);
ck_tile
::
index_t
stride
=
arg_parser
.
get_int
(
"stride"
);
if
(
stride
<
0
)
stride
=
n
;
float
epsilon
=
arg_parser
.
get_float
(
"e"
);
float
epsilon
=
arg_parser
.
get_float
(
"e"
);
std
::
string
data_type
=
arg_parser
.
get_str
(
"prec"
);
int
kname
=
arg_parser
.
get_int
(
"kname"
);
int
kname
=
arg_parser
.
get_int
(
"kname"
);
int
do_validation
=
arg_parser
.
get_int
(
"v"
);
int
do_validation
=
arg_parser
.
get_int
(
"v"
);
int
fused_add
=
arg_parser
.
get_int
(
"fadd"
);
int
fused_quant
=
arg_parser
.
get_int
(
"fquant"
);
int
warmup
=
arg_parser
.
get_int
(
"warmup"
);
int
warmup
=
arg_parser
.
get_int
(
"warmup"
);
int
repeat
=
arg_parser
.
get_int
(
"repeat"
);
int
repeat
=
arg_parser
.
get_int
(
"repeat"
);
assert
(
stride
>=
n
);
ck_tile
::
index_t
x_stride
=
arg_parser
.
get_int
(
"x_stride"
);
if
(
x_stride
<
0
)
x_stride
=
n
;
ck_tile
::
index_t
xr_stride
=
arg_parser
.
get_int
(
"xr_stride"
);
if
(
xr_stride
<
0
)
xr_stride
=
n
;
ck_tile
::
index_t
y_stride
=
arg_parser
.
get_int
(
"y_stride"
);
if
(
y_stride
<
0
)
y_stride
=
n
;
ck_tile
::
index_t
yr_stride
=
arg_parser
.
get_int
(
"yr_stride"
);
if
(
yr_stride
<
0
)
yr_stride
=
n
;
assert
(
x_stride
>=
n
);
std
::
string
prec_i
=
arg_parser
.
get_str
(
"prec_i"
);
std
::
string
prec_o
=
arg_parser
.
get_str
(
"prec_o"
);
std
::
string
prec_sm
=
arg_parser
.
get_str
(
"prec_sm"
);
std
::
string
prec_sy
=
arg_parser
.
get_str
(
"prec_sy"
);
if
(
prec_o
==
"auto"
)
{
prec_o
=
prec_i
;
}
if
(
prec_sm
==
"auto"
)
{
prec_sm
=
"fp32"
;
}
if
(
prec_sy
==
"auto"
)
{
prec_sy
=
"fp32"
;
}
if
((
fused_quant
==
1
||
fused_quant
==
2
)
&&
prec_o
!=
"int8"
&&
prec_o
!=
"fp8"
)
{
std
::
cout
<<
"if fused_quant is 1 or 2, only support
\"
-prec_o=int8
\"
or
\"
-prec_o=fp8
\"
cases."
<<
std
::
endl
;
return
false
;
}
using
TypeConfig
=
RmsnormTypeConfig
<
DataType
>
;
using
TypeConfig
=
RmsnormTypeConfig
<
InDataType
,
OutDataType
,
SmoothScaleDataType
,
YScaleDataType
>
;
using
XDataType
=
typename
TypeConfig
::
XDataType
;
using
XDataType
=
typename
TypeConfig
::
XDataType
;
using
YDataType
=
typename
TypeConfig
::
YDataType
;
using
YDataType
=
typename
TypeConfig
::
YDataType
;
using
GammaDataType
=
typename
TypeConfig
::
GammaDataType
;
using
GammaDataType
=
typename
TypeConfig
::
GammaDataType
;
using
XResidualDataType
=
XDataType
;
using
YResidualDataType
=
XDataType
;
using
InvRmsDataType
=
using
InvRmsDataType
=
std
::
conditional_t
<
SaveRms
,
typename
TypeConfig
::
InvRmsDataType
,
ck_tile
::
null_type
>
;
std
::
conditional_t
<
SaveRms
,
typename
TypeConfig
::
InvRmsDataType
,
ck_tile
::
null_type
>
;
...
@@ -66,43 +128,84 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -66,43 +128,84 @@ bool run(const ck_tile::ArgParser& arg_parser)
using
ComputeDataType
=
typename
TypeConfig
::
ComputeDataType
;
using
ComputeDataType
=
typename
TypeConfig
::
ComputeDataType
;
// host verify
// host verify
ck_tile
::
HostTensor
<
XDataType
>
x_host
({
m
,
n
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
XDataType
>
x_host
({
m
,
n
},
{
x_
stride
,
1
});
ck_tile
::
HostTensor
<
GammaDataType
>
gamma_host
({
n
});
ck_tile
::
HostTensor
<
GammaDataType
>
gamma_host
({
n
});
ck_tile
::
HostTensor
<
SmoothScaleDataType
>
sm_scale_host
({
n
});
ck_tile
::
HostTensor
<
SmoothScaleDataType
>
sm_scale_host_dev
({
n
});
ck_tile
::
HostTensor
<
YDataType
>
y_host_ref
({
m
,
n
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
XResidualDataType
>
x_residual_host
({
m
,
n
},
{
xr_stride
,
1
});
ck_tile
::
HostTensor
<
YDataType
>
y_host_dev
({
m
,
n
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
YResidualDataType
>
y_residual_host
({
m
,
n
},
{
yr_stride
,
1
});
ck_tile
::
HostTensor
<
YDataType
>
y_host_ref
({
m
,
n
},
{
y_stride
,
1
});
ck_tile
::
HostTensor
<
YDataType
>
y_host_dev
({
m
,
n
},
{
y_stride
,
1
});
ck_tile
::
HostTensor
<
YScaleDataType
>
y_scale_host_ref
({
m
});
ck_tile
::
HostTensor
<
YScaleDataType
>
y_scale_host_dev
({
m
});
ck_tile
::
HostTensor
<
InvRmsDataType
>
invRms_host_ref
({
m
});
ck_tile
::
HostTensor
<
InvRmsDataType
>
invRms_host_ref
({
m
});
ck_tile
::
FillUniformDistribution
<
XDataType
>
{
-
.5
f
,
.5
f
}(
x_host
);
ck_tile
::
FillUniformDistribution
<
XDataType
>
{
-
.5
f
,
.5
f
}(
x_host
);
ck_tile
::
FillUniformDistribution
<
XResidualDataType
>
{
-
.5
f
,
.5
f
}(
x_residual_host
);
ck_tile
::
FillUniformDistribution
<
SmoothScaleDataType
>
{
-
1.
f
,
1.
f
}(
sm_scale_host
);
ck_tile
::
FillUniformDistribution
<
GammaDataType
>
{
-
.5
f
,
.5
f
}(
gamma_host
);
ck_tile
::
FillUniformDistribution
<
GammaDataType
>
{
-
.5
f
,
.5
f
}(
gamma_host
);
ck_tile
::
DeviceMem
x_buf
(
x_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
x_buf
(
x_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
gamma_buf
(
gamma_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
gamma_buf
(
gamma_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
y_buf
(
y_host_dev
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
y_buf
(
y_host_dev
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
y_scale_buf
(
y_scale_host_dev
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
sm_scale_buf
(
sm_scale_host_dev
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
x_residual_buf
(
x_residual_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
y_residual_buf
(
y_residual_host
.
get_element_space_size_in_bytes
());
x_buf
.
ToDevice
(
x_host
.
data
());
x_buf
.
ToDevice
(
x_host
.
data
());
gamma_buf
.
ToDevice
(
gamma_host
.
data
());
gamma_buf
.
ToDevice
(
gamma_host
.
data
());
x_residual_buf
.
ToDevice
(
x_residual_host
.
data
());
sm_scale_buf
.
ToDevice
(
sm_scale_host
.
data
());
auto
prec_str
=
[
&
]()
{
auto
base_str
=
prec_i
;
if
(
prec_i
!=
prec_o
)
{
base_str
+=
"|"
+
prec_o
;
}
if
(
fused_quant
==
1
)
{
base_str
+=
std
::
string
(
"("
)
+
prec_sy
+
")"
;
}
return
base_str
;
}();
std
::
cout
<<
"["
<<
data_type
<<
"]"
std
::
cout
<<
"["
<<
prec_str
<<
"]"
<<
" m:"
<<
m
<<
", n:"
<<
n
<<
", stride:"
<<
stride
<<
std
::
flush
;
<<
" m:"
<<
m
<<
", n:"
<<
n
<<
", x_stride:"
<<
x_stride
<<
", xr_stride:"
<<
xr_stride
<<
", y_stride:"
<<
y_stride
<<
", yr_stride:"
<<
yr_stride
<<
std
::
flush
;
rmsnorm2d_fwd_traits
traits
{
data_type
,
SaveRms
};
rmsnorm2d_fwd_traits
traits
{
prec_i
,
prec_o
,
prec_sm
,
prec_sy
,
SaveRms
,
fused_add
,
fused_quant
};
rmsnorm2d_fwd_args
args
{
x_buf
.
GetDeviceBuffer
(),
rmsnorm2d_fwd_args
args
{
x_buf
.
GetDeviceBuffer
(),
fused_add
!=
0
?
x_residual_buf
.
GetDeviceBuffer
()
:
nullptr
,
fused_quant
==
1
?
sm_scale_buf
.
GetDeviceBuffer
()
:
nullptr
,
gamma_buf
.
GetDeviceBuffer
(),
gamma_buf
.
GetDeviceBuffer
(),
y_buf
.
GetDeviceBuffer
(),
y_buf
.
GetDeviceBuffer
(),
nullptr
,
fused_add
==
1
?
y_residual_buf
.
GetDeviceBuffer
()
:
nullptr
,
fused_quant
!=
0
?
y_scale_buf
.
GetDeviceBuffer
()
:
nullptr
,
nullptr
,
// p_invRms, unsupported yet
epsilon
,
epsilon
,
m
,
m
,
n
,
n
,
stride
};
x_stride
,
// x row_stride
xr_stride
,
// x residule row stride
y_stride
,
// y row stride
yr_stride
};
// y residule row stride
float
ave_time
=
rmsnorm2d_fwd
(
float
ave_time
=
rmsnorm2d_fwd
(
traits
,
args
,
ck_tile
::
stream_config
{
nullptr
,
true
,
kname
?
1
:
0
,
warmup
,
repeat
});
traits
,
args
,
ck_tile
::
stream_config
{
nullptr
,
true
,
kname
?
1
:
0
,
warmup
,
repeat
});
std
::
size_t
num_byte
=
std
::
size_t
num_byte
=
sizeof
(
XDataType
)
*
m
*
n
+
sizeof
(
GammaDataType
)
*
n
+
sizeof
(
YDataType
)
*
m
*
n
;
sizeof
(
XDataType
)
*
m
*
n
+
sizeof
(
GammaDataType
)
*
n
+
sizeof
(
YDataType
)
*
m
*
n
;
num_byte
+=
SaveRms
?
sizeof
(
InvRmsDataType
)
*
m
*
n
:
0
;
num_byte
+=
fused_add
?
sizeof
(
XResidualDataType
)
*
m
*
n
:
0
;
num_byte
+=
((
fused_quant
==
1
)
||
(
fused_quant
==
2
))
?
sizeof
(
YScaleDataType
)
*
m
:
0
;
num_byte
+=
(
fused_quant
==
1
)
?
sizeof
(
SmoothScaleDataType
)
*
n
:
0
;
float
gb_per_sec
=
num_byte
/
1.E6
/
ave_time
;
float
gb_per_sec
=
num_byte
/
1.E6
/
ave_time
;
std
::
cout
<<
", "
<<
ave_time
*
1.E3
<<
" us, "
<<
gb_per_sec
<<
" GB/s"
<<
std
::
flush
;
std
::
cout
<<
", "
<<
ave_time
*
1.E3
<<
" us, "
<<
gb_per_sec
<<
" GB/s"
<<
std
::
flush
;
...
@@ -112,37 +215,134 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -112,37 +215,134 @@ bool run(const ck_tile::ArgParser& arg_parser)
if
(
do_validation
)
if
(
do_validation
)
{
{
// reference
// reference
if
(
fused_add
!=
0
)
{
// fused pre_add/pre_add_store
// TODO we accumulate directly to x_host for simplcity here...
std
::
transform
(
x_host
.
mData
.
cbegin
(),
x_host
.
mData
.
cend
(),
x_residual_host
.
mData
.
cbegin
(),
x_host
.
mData
.
begin
(),
[](
auto
x_
,
auto
r_
)
{
auto
o_
=
ck_tile
::
type_convert
<
ComputeDataType
>
(
x_
)
+
ck_tile
::
type_convert
<
ComputeDataType
>
(
r_
);
return
ck_tile
::
type_convert
<
XDataType
>
(
o_
);
});
}
if
(
fused_quant
!=
0
)
{
auto
dquant_functor
=
[
&
](
int
m_
,
auto
&
o_
,
auto
&
acc_
)
{
int
N_
=
acc_
.
mDesc
.
get_lengths
()[
1
];
if
(
fused_quant
==
1
)
{
for
(
int
n_
=
0
;
n_
<
N_
;
n_
++
)
{
// input smooth outlier
acc_
(
m_
,
n_
)
=
acc_
(
m_
,
n_
)
*
ck_tile
::
type_convert
<
ComputeDataType
>
(
sm_scale_host
(
n_
));
}
}
ComputeDataType
absmax
=
static_cast
<
ComputeDataType
>
(
0
);
for
(
int
n_
=
0
;
n_
<
N_
;
n_
++
)
{
const
auto
a
=
ck_tile
::
abs
(
acc_
(
m_
,
n_
));
absmax
=
a
>
absmax
?
a
:
absmax
;
}
// printf("cpu:absmax:%f\n", absmax);
constexpr
ComputeDataType
kMaxY
=
std
::
is_same
<
YDataType
,
ck_tile
::
fp8_t
>::
value
?
240.0
:
std
::
is_same
<
YDataType
,
ck_tile
::
int8_t
>::
value
?
127.0
:
0.0
;
ComputeDataType
y_scale
=
absmax
/
kMaxY
;
y_scale_host_ref
(
m_
)
=
ck_tile
::
type_convert
<
YScaleDataType
>
(
y_scale
);
for
(
int
n_
=
0
;
n_
<
N_
;
n_
++
)
{
o_
(
m_
,
n_
)
=
ck_tile
::
type_convert
<
YDataType
>
(
acc_
(
m_
,
n_
)
/
y_scale
);
}
};
ck_tile
::
reference_rmsnorm2d_fwd
<
XDataType
,
GammaDataType
,
ComputeDataType
,
YDataType
,
InvRmsDataType
>
(
x_host
,
gamma_host
,
y_host_ref
,
invRms_host_ref
,
epsilon
,
dquant_functor
);
}
else
{
ck_tile
::
reference_rmsnorm2d_fwd
<
XDataType
,
ck_tile
::
reference_rmsnorm2d_fwd
<
XDataType
,
GammaDataType
,
GammaDataType
,
ComputeDataType
,
ComputeDataType
,
YDataType
,
YDataType
,
InvRmsDataType
>
(
InvRmsDataType
>
(
x_host
,
gamma_host
,
y_host_ref
,
invRms_host_ref
,
epsilon
);
x_host
,
gamma_host
,
y_host_ref
,
invRms_host_ref
,
epsilon
);
}
y_buf
.
FromDevice
(
y_host_dev
.
data
());
y_buf
.
FromDevice
(
y_host_dev
.
data
());
auto
[
rtol
,
atol
]
=
get_elimit
<
DataType
>
();
ck_tile
::
HostTensor
<
YResidualDataType
>
y_residual_host_dev
({
m
,
n
},
{
yr_stride
,
1
});
if
(
stride
==
n
)
if
(
fused_add
==
1
)
{
y_residual_buf
.
FromDevice
(
y_residual_host_dev
.
data
());
}
auto
[
rtol
,
atol
]
=
get_elimit
<
YDataType
>
();
if
(
x_stride
==
n
)
{
{
pass
=
ck_tile
::
check_err
(
pass
=
ck_tile
::
check_err
(
y_host_dev
,
y_host_ref
,
std
::
string
(
"OUT Error: Incorrect results!"
),
rtol
,
atol
);
y_host_dev
,
y_host_ref
,
std
::
string
(
"
\n
OUT Error: Incorrect results!"
),
rtol
,
atol
);
if
(
fused_add
==
1
)
{
pass
&=
ck_tile
::
check_err
(
y_residual_host_dev
,
x_host
,
std
::
string
(
"
\n
ADD Error: Incorrect results!"
),
rtol
,
atol
);
}
}
}
else
else
{
{
for
(
int
i_r
=
0
;
i_r
<
m
;
i_r
++
)
for
(
int
i_r
=
0
;
i_r
<
m
;
i_r
++
)
{
{
std
::
vector
<
YDataType
>
y_host_dev_row
(
y_host_dev
.
begin
()
+
i_r
*
stride
,
std
::
vector
<
YDataType
>
y_host_dev_row
(
y_host_dev
.
begin
()
+
i_r
*
y_
stride
,
y_host_dev
.
begin
()
+
i_r
*
stride
+
n
);
y_host_dev
.
begin
()
+
i_r
*
y_
stride
+
n
);
std
::
vector
<
YDataType
>
y_host_ref_row
(
y_host_ref
.
begin
()
+
i_r
*
stride
,
std
::
vector
<
YDataType
>
y_host_ref_row
(
y_host_ref
.
begin
()
+
i_r
*
y_
stride
,
y_host_ref
.
begin
()
+
i_r
*
stride
+
n
);
y_host_ref
.
begin
()
+
i_r
*
y_
stride
+
n
);
pass
&=
ck_tile
::
check_err
(
y_host_dev_row
,
pass
&=
ck_tile
::
check_err
(
y_host_dev_row
,
y_host_ref_row
,
y_host_ref_row
,
std
::
string
(
"OUT["
)
+
std
::
to_string
(
i_r
)
+
std
::
string
(
"
\n
OUT["
)
+
std
::
to_string
(
i_r
)
+
std
::
string
(
"] Error: Incorrect results!"
),
rtol
,
atol
);
if
(
fused_add
==
1
)
{
std
::
vector
<
YResidualDataType
>
y_residual_host_dev_row
(
y_residual_host_dev
.
begin
()
+
i_r
*
yr_stride
,
y_residual_host_dev
.
begin
()
+
i_r
*
yr_stride
+
n
);
std
::
vector
<
YResidualDataType
>
y_residual_host_ref_row
(
x_host
.
begin
()
+
i_r
*
yr_stride
,
x_host
.
begin
()
+
i_r
*
yr_stride
+
n
);
pass
&=
ck_tile
::
check_err
(
y_residual_host_dev_row
,
y_residual_host_ref_row
,
std
::
string
(
"
\n
ADD["
)
+
std
::
to_string
(
i_r
)
+
std
::
string
(
"] Error: Incorrect results!"
),
std
::
string
(
"] Error: Incorrect results!"
),
rtol
,
rtol
,
atol
);
atol
);
}
}
}
}
}
if
(
fused_quant
==
1
)
{
y_scale_buf
.
FromDevice
(
y_scale_host_dev
.
data
());
pass
&=
ck_tile
::
check_err
(
y_scale_host_dev
,
y_scale_host_ref
,
std
::
string
(
"
\n
SCALE Error: Incorrect results!"
),
rtol
,
atol
);
}
std
::
cout
<<
", valid:"
<<
(
pass
?
"y"
:
"n"
)
<<
std
::
flush
<<
std
::
endl
;
std
::
cout
<<
", valid:"
<<
(
pass
?
"y"
:
"n"
)
<<
std
::
flush
<<
std
::
endl
;
}
}
...
@@ -156,23 +356,65 @@ int main(int argc, char* argv[])
...
@@ -156,23 +356,65 @@ int main(int argc, char* argv[])
if
(
!
result
)
if
(
!
result
)
return
-
1
;
return
-
1
;
const
std
::
string
data_type
=
arg_parser
.
get_str
(
"prec"
);
std
::
string
prec_i
=
arg_parser
.
get_str
(
"prec_i"
);
std
::
string
prec_o
=
arg_parser
.
get_str
(
"prec_o"
);
std
::
string
prec_sm
=
arg_parser
.
get_str
(
"prec_sm"
);
std
::
string
prec_sy
=
arg_parser
.
get_str
(
"prec_sy"
);
if
(
prec_o
==
"auto"
)
{
prec_o
=
prec_i
;
}
if
(
prec_sm
==
"auto"
)
{
prec_sm
=
"fp32"
;
}
if
(
prec_sy
==
"auto"
)
{
prec_sy
=
"fp32"
;
}
int
save_rms
=
arg_parser
.
get_int
(
"save_rms"
);
int
save_rms
=
arg_parser
.
get_int
(
"save_rms"
);
if
(
data_type
==
"fp16"
&&
save_rms
)
if
(
prec_i
==
"fp16"
&&
prec_o
==
"fp16"
&&
prec_sm
==
"fp32"
&&
prec_sy
==
"fp32"
&&
save_rms
)
{
return
run
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
float
,
true
>
(
arg_parser
)
?
0
:
-
2
;
}
else
if
(
prec_i
==
"fp16"
&&
prec_o
==
"fp16"
&&
prec_sm
==
"fp32"
&&
prec_sy
==
"fp32"
&&
!
save_rms
)
{
return
run
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
float
,
false
>
(
arg_parser
)
?
0
:
-
2
;
}
else
if
(
prec_i
==
"bf16"
&&
prec_o
==
"bf16"
&&
prec_sm
==
"fp32"
&&
prec_sy
==
"fp32"
&&
save_rms
)
{
return
run
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
true
>
(
arg_parser
)
?
0
:
-
2
;
}
else
if
(
prec_i
==
"bf16"
&&
prec_o
==
"bf16"
&&
prec_sm
==
"fp32"
&&
prec_sy
==
"fp32"
&&
!
save_rms
)
{
return
run
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
true
>
(
arg_parser
)
?
0
:
-
2
;
}
// dynamic quant case, only in inference
else
if
(
prec_i
==
"fp16"
&&
prec_o
==
"int8"
&&
prec_sm
==
"fp32"
&&
prec_sy
==
"fp32"
&&
!
save_rms
)
{
{
return
run
<
ck_tile
::
half_t
,
true
>
(
arg_parser
)
?
0
:
-
2
;
return
run
<
ck_tile
::
half_t
,
ck_tile
::
int8_t
,
float
,
float
,
true
>
(
arg_parser
)
?
0
:
-
2
;
}
}
else
if
(
data_type
==
"fp16"
&&
!
save_rms
)
else
if
(
prec_i
==
"bf16"
&&
prec_o
==
"int8"
&&
prec_sm
==
"fp32"
&&
prec_sy
==
"fp32"
&&
!
save_rms
)
{
{
return
run
<
ck_tile
::
half_t
,
fals
e
>
(
arg_parser
)
?
0
:
-
2
;
return
run
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
float
,
float
,
tru
e
>
(
arg_parser
)
?
0
:
-
2
;
}
}
else
if
(
data_type
==
"bf16"
&&
save_rms
)
else
if
(
prec_i
==
"fp16"
&&
prec_o
==
"fp8"
&&
prec_sm
==
"fp32"
&&
prec_sy
==
"fp32"
&&
!
save_rms
)
{
{
return
run
<
ck_tile
::
bf16_t
,
tru
e
>
(
arg_parser
)
?
0
:
-
2
;
return
run
<
ck_tile
::
half_t
,
ck_tile
::
fp8_t
,
float
,
float
,
fals
e
>
(
arg_parser
)
?
0
:
-
2
;
}
}
else
if
(
data_type
==
"bf16"
&&
!
save_rms
)
else
if
(
prec_i
==
"bf16"
&&
prec_o
==
"fp8"
&&
prec_sm
==
"fp32"
&&
prec_sy
==
"fp32"
&&
!
save_rms
)
{
{
return
run
<
ck_tile
::
bf16_t
,
tru
e
>
(
arg_parser
)
?
0
:
-
2
;
return
run
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
float
,
float
,
fals
e
>
(
arg_parser
)
?
0
:
-
2
;
}
}
return
-
3
;
return
-
3
;
...
...
example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.hpp
View file @
dec32dc6
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -8,27 +8,34 @@
...
@@ -8,27 +8,34 @@
#include "ck_tile/ops/rmsnorm2d.hpp"
#include "ck_tile/ops/rmsnorm2d.hpp"
#include <string>
#include <string>
template
<
typename
DataType
>
template
<
typename
InType
,
typename
OutType
,
typename
SmoothScaleDataType_
,
typename
YScaleDataType_
>
struct
RmsnormTypeConfig
;
struct
RmsnormTypeConfig
;
template
<
>
template
<
typename
OutType
,
typename
SmoothScaleDataType_
,
typename
YScaleDataType_
>
struct
RmsnormTypeConfig
<
ck_tile
::
half_t
>
struct
RmsnormTypeConfig
<
ck_tile
::
half_t
,
OutType
,
SmoothScaleDataType_
,
YScaleDataType_
>
{
{
using
XDataType
=
ck_tile
::
half_t
;
using
XDataType
=
ck_tile
::
half_t
;
using
YDataType
=
ck_tile
::
half_t
;
using
YDataType
=
OutType
;
using
GammaDataType
=
ck_tile
::
half_t
;
using
GammaDataType
=
ck_tile
::
half_t
;
using
InvRmsDataType
=
ck_tile
::
half_t
;
using
InvRmsDataType
=
ck_tile
::
half_t
;
using
ComputeDataType
=
float
;
using
ComputeDataType
=
float
;
using
SmoothScaleDataType
=
SmoothScaleDataType_
;
using
YScaleDataType
=
YScaleDataType_
;
};
};
template
<
>
template
<
typename
OutType
,
typename
SmoothScaleDataType_
,
typename
YScaleDataType_
>
struct
RmsnormTypeConfig
<
ck_tile
::
bf16_t
>
struct
RmsnormTypeConfig
<
ck_tile
::
bf16_t
,
OutType
,
SmoothScaleDataType_
,
YScaleDataType_
>
{
{
using
XDataType
=
ck_tile
::
bf16_t
;
using
XDataType
=
ck_tile
::
bf16_t
;
using
YDataType
=
ck_tile
::
bf16_t
;
using
YDataType
=
OutType
;
using
GammaDataType
=
ck_tile
::
bf16_t
;
using
GammaDataType
=
ck_tile
::
bf16_t
;
using
InvRmsDataType
=
ck_tile
::
bf16_t
;
using
InvRmsDataType
=
ck_tile
::
bf16_t
;
using
ComputeDataType
=
float
;
using
ComputeDataType
=
float
;
using
SmoothScaleDataType
=
SmoothScaleDataType_
;
using
YScaleDataType
=
YScaleDataType_
;
};
};
// runtime args
// runtime args
...
@@ -36,82 +43,24 @@ struct rmsnorm2d_fwd_args : public ck_tile::Rmsnorm2dFwdHostArgs
...
@@ -36,82 +43,24 @@ struct rmsnorm2d_fwd_args : public ck_tile::Rmsnorm2dFwdHostArgs
{
{
};
};
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
template
<
typename
DataType_
,
ck_tile
::
index_t
Repeat_M_
,
// each thread repeat along M
ck_tile
::
index_t
Repeat_N_
,
// each thread repeat along N
ck_tile
::
index_t
ThreadPerBlock_M_
,
// num threads along M
ck_tile
::
index_t
ThreadPerBlock_N_
,
// num threads along N
ck_tile
::
index_t
Vector_N_
,
// vector size along N
bool
kPadN_
,
bool
kSaveInvRms_
,
bool
kTwoPass_
>
struct
rmsnorm2d_fwd_traits_
{
using
DataType
=
ck_tile
::
remove_cvref_t
<
DataType_
>
;
static
constexpr
bool
is_warp_per_row
=
ThreadPerBlock_N_
<=
warpSize
;
static_assert
((
ThreadPerBlock_M_
*
ThreadPerBlock_N_
)
%
warpSize
==
0
);
static
constexpr
ck_tile
::
index_t
total_warps
=
(
ThreadPerBlock_M_
*
ThreadPerBlock_N_
)
/
warpSize
;
// num of warps along m
static
constexpr
ck_tile
::
index_t
BlockWarps_M
=
[]()
{
if
constexpr
(
is_warp_per_row
)
{
static_assert
(
warpSize
%
ThreadPerBlock_N_
==
0
);
return
total_warps
*
(
warpSize
/
ThreadPerBlock_N_
);
}
else
{
// static_assert(warpSize % ThreadPerBlock_M_ == 0);
return
total_warps
/
(
ThreadPerBlock_N_
/
warpSize
);
}
}();
// num of warps along n
static
constexpr
ck_tile
::
index_t
BlockWarps_N
=
[]()
{
if
constexpr
(
is_warp_per_row
)
{
static_assert
(
warpSize
%
ThreadPerBlock_N_
==
0
);
return
1
;
}
else
{
static_assert
(
ThreadPerBlock_N_
%
warpSize
==
0
);
return
ThreadPerBlock_N_
/
warpSize
;
}
}();
static
constexpr
ck_tile
::
index_t
Repeat_M
=
Repeat_M_
;
static
constexpr
ck_tile
::
index_t
Repeat_N
=
Repeat_N_
;
static
constexpr
ck_tile
::
index_t
Block_M
=
Repeat_M_
*
ThreadPerBlock_M_
;
static
constexpr
ck_tile
::
index_t
Block_N
=
Repeat_N_
*
ThreadPerBlock_N_
*
Vector_N_
;
static
constexpr
ck_tile
::
index_t
Warp_M
=
ThreadPerBlock_M_
/
BlockWarps_M
;
static
constexpr
ck_tile
::
index_t
Warp_N
=
ThreadPerBlock_N_
/
BlockWarps_N
*
Vector_N_
;
using
BlockTile
=
ck_tile
::
sequence
<
Block_M
,
Block_N
>
;
using
BlockWarps
=
ck_tile
::
sequence
<
BlockWarps_M
,
BlockWarps_N
>
;
using
WarpTile
=
ck_tile
::
sequence
<
Warp_M
,
Warp_N
>
;
using
Vector
=
ck_tile
::
sequence
<
1
,
Vector_N_
>
;
using
Shape
=
ck_tile
::
Generic2dBlockShape
<
BlockTile
,
BlockWarps
,
WarpTile
,
Vector
>
;
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
kSaveInvRms
=
kSaveInvRms_
;
static
constexpr
bool
kTwoPass
=
kTwoPass_
;
};
template
<
typename
Traits_
>
template
<
typename
Traits_
>
float
rmsnorm2d_fwd_
(
const
ck_tile
::
stream_config
&
s
,
rmsnorm2d_fwd_args
a
);
float
rmsnorm2d_fwd_
(
const
ck_tile
::
stream_config
&
s
,
rmsnorm2d_fwd_args
a
);
// This is the public API, will be generated by script
// This is the public API, will be generated by script
struct
rmsnorm2d_fwd_traits
struct
rmsnorm2d_fwd_traits
{
{
std
::
string
data_type
;
std
::
string
prec_i
;
// input precision
std
::
string
prec_o
;
// output precision
// if fused_quant == 1, need set prec_sm/prec_sy to proper string, otherwise can set
// arbitrary(will skip check) if fused_quant == 2, need set prec_sy to proper string, otherwise
// can set arbitrary(will skip check)
std
::
string
prec_sm
;
// x-scale, used for [1*N] input smooth quant
std
::
string
prec_sy
;
// y-scale, used for [M*1] output for next layer
bool
save_rms
;
bool
save_rms
;
int
fused_add
;
// 0:no-add, 1:pre-add-store, 2:pre-add
int
fused_quant
;
// 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
};
};
float
rmsnorm2d_fwd
(
rmsnorm2d_fwd_traits
,
rmsnorm2d_fwd_args
,
const
ck_tile
::
stream_config
&
);
float
rmsnorm2d_fwd
(
rmsnorm2d_fwd_traits
,
rmsnorm2d_fwd_args
,
const
ck_tile
::
stream_config
&
);
example/ck_tile/10_rmsnorm2d/script/smoke_test.sh
View file @
dec32dc6
#!/bin/sh
#!/bin/sh
EXE
=
"
$(
find
.
-name
tile_rmsnorm2d_fwd
-type
f |
head
-n
1
)
"
EXE
=
"
$(
find
.
-name
tile_rmsnorm2d_fwd
-type
f |
head
-n
1
)
"
for
fquant
in
""
"-fquant=1 -prec_o=int8"
"-fquant=2 -prec_o=int8"
"-fquant=1 -prec_o=fp8"
"-fquant=2 -prec_o=fp8"
;
do
for
pr_i
in
"fp16"
"bf16"
;
do
for
pr_i
in
"fp16"
"bf16"
;
do
$EXE
-prec
=
$pr_i
-m
=
99
-n
=
13
for
fadd
in
"0"
"1"
;
do
$EXE
-prec
=
$pr_i
-m
=
17
-n
=
16
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
99
-n
=
13
$EXE
-prec
=
$pr_i
-m
=
1
-n
=
100
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
17
-n
=
16
$EXE
-prec
=
$pr_i
-m
=
4
-n
=
128
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
1
-n
=
100
$EXE
-prec
=
$pr_i
-m
=
80
-n
=
127
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
4
-n
=
128
$EXE
-prec
=
$pr_i
-m
=
22
-n
=
255
-stride
=
256
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
80
-n
=
127
$EXE
-prec
=
$pr_i
-m
=
7
-n
=
599
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
22
-n
=
255
-stride
=
256
$EXE
-prec
=
$pr_i
-m
=
19
-n
=
512
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
7
-n
=
599
$EXE
-prec
=
$pr_i
-m
=
33
-n
=
313
-stride
=
1000
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
19
-n
=
512
$EXE
-prec
=
$pr_i
-m
=
11
-n
=
510
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
33
-n
=
313
-stride
=
1000
$EXE
-prec
=
$pr_i
-m
=
171
-n
=
676
-stride
=
818
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
11
-n
=
510
$EXE
-prec
=
$pr_i
-m
=
91
-n
=
636
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
171
-n
=
676
-stride
=
818
$EXE
-prec
=
$pr_i
-m
=
12
-n
=
768
-stride
=
800
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
91
-n
=
636
$EXE
-prec
=
$pr_i
-m
=
100
-n
=
766
-stride
=
812
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
12
-n
=
768
-stride
=
800
$EXE
-prec
=
$pr_i
-m
=
31
-n
=
1024
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
100
-n
=
766
-stride
=
812
$EXE
-prec
=
$pr_i
-m
=
64
-n
=
1000
-stride
=
1004
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
31
-n
=
1024
$EXE
-prec
=
$pr_i
-m
=
8
-n
=
1501
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
64
-n
=
1000
-stride
=
1004
$EXE
-prec
=
$pr_i
-m
=
3
-n
=
1826
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
8
-n
=
1501
$EXE
-prec
=
$pr_i
-m
=
5
-n
=
2040
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
3
-n
=
1826
$EXE
-prec
=
$pr_i
-m
=
7
-n
=
2734
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
5
-n
=
2040
$EXE
-prec
=
$pr_i
-m
=
1
-n
=
3182
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
7
-n
=
2734
$EXE
-prec
=
$pr_i
-m
=
9
-n
=
4096
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
1
-n
=
3182
$EXE
-prec
=
$pr_i
-m
=
3
-n
=
8192
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
9
-n
=
4096
$EXE
-prec
=
$pr_i
-m
=
1
-n
=
10547
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
3
-n
=
8192
$EXE
-prec
=
$pr_i
-m
=
3
-n
=
17134
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
1
-n
=
10547
#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=17134
done
done
done
done
example/ck_tile/12_smoothquant/example_smoothquant.cpp
View file @
dec32dc6
...
@@ -63,17 +63,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -63,17 +63,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
int
warmup
=
arg_parser
.
get_int
(
"warmup"
);
int
warmup
=
arg_parser
.
get_int
(
"warmup"
);
int
repeat
=
arg_parser
.
get_int
(
"repeat"
);
int
repeat
=
arg_parser
.
get_int
(
"repeat"
);
assert
(
stride
>=
n
);
assert
(
x_
stride
>=
n
);
using
XDataType
=
DataType
;
using
XDataType
=
DataType
;
using
X
ScaleDataType
=
float
;
using
Smooth
ScaleDataType
=
float
;
using
YScaleDataType
=
float
;
using
YScaleDataType
=
float
;
using
QYDataType
=
ck_tile
::
int8_t
;
using
QYDataType
=
ck_tile
::
int8_t
;
using
ComputeDataType
=
float
;
using
ComputeDataType
=
float
;
// host verify
// host verify
ck_tile
::
HostTensor
<
XDataType
>
x_host
({
m
,
n
},
{
x_stride
,
1
});
ck_tile
::
HostTensor
<
XDataType
>
x_host
({
m
,
n
},
{
x_stride
,
1
});
ck_tile
::
HostTensor
<
X
ScaleDataType
>
x
scale_host
({
n
});
ck_tile
::
HostTensor
<
Smooth
ScaleDataType
>
sm
scale_host
({
n
});
ck_tile
::
HostTensor
<
YScaleDataType
>
yscale_host_ref
({
m
},
{
1
});
ck_tile
::
HostTensor
<
YScaleDataType
>
yscale_host_ref
({
m
},
{
1
});
ck_tile
::
HostTensor
<
YScaleDataType
>
yscale_host_dev
({
m
},
{
1
});
ck_tile
::
HostTensor
<
YScaleDataType
>
yscale_host_dev
({
m
},
{
1
});
...
@@ -82,15 +82,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -82,15 +82,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile
::
HostTensor
<
QYDataType
>
qy_host_dev
({
m
,
n
},
{
y_stride
,
1
});
ck_tile
::
HostTensor
<
QYDataType
>
qy_host_dev
({
m
,
n
},
{
y_stride
,
1
});
ck_tile
::
FillUniformDistribution
<
XDataType
>
{
-
.5
f
,
.5
f
}(
x_host
);
ck_tile
::
FillUniformDistribution
<
XDataType
>
{
-
.5
f
,
.5
f
}(
x_host
);
ck_tile
::
FillUniformDistribution
<
X
ScaleDataType
>
{
1e-3
,
.5
f
}(
x
scale_host
);
ck_tile
::
FillUniformDistribution
<
Smooth
ScaleDataType
>
{
1e-3
,
.5
f
}(
sm
scale_host
);
ck_tile
::
DeviceMem
x_buf
(
x_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
x_buf
(
x_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
x
scale_buf
(
x
scale_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
sm
scale_buf
(
sm
scale_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
yscale_buf
(
yscale_host_dev
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
yscale_buf
(
yscale_host_dev
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
qy_buf
(
qy_host_dev
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
qy_buf
(
qy_host_dev
.
get_element_space_size_in_bytes
());
x_buf
.
ToDevice
(
x_host
.
data
());
x_buf
.
ToDevice
(
x_host
.
data
());
x
scale_buf
.
ToDevice
(
x
scale_host
.
data
());
sm
scale_buf
.
ToDevice
(
sm
scale_host
.
data
());
constexpr
bool
kTwoPass
=
true
;
constexpr
bool
kTwoPass
=
true
;
...
@@ -101,7 +101,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -101,7 +101,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
using
Shape
=
ck_tile
::
Generic2dBlockShape
<
BlockTile
,
BlockWarps
,
WarpTile
,
Vector
>
;
using
Shape
=
ck_tile
::
Generic2dBlockShape
<
BlockTile
,
BlockWarps
,
WarpTile
,
Vector
>
;
using
Problem
=
ck_tile
::
SmoothquantPipelineProblem
<
XDataType
,
using
Problem
=
ck_tile
::
SmoothquantPipelineProblem
<
XDataType
,
X
ScaleDataType
,
Smooth
ScaleDataType
,
ComputeDataType
,
ComputeDataType
,
YScaleDataType
,
YScaleDataType
,
QYDataType
,
QYDataType
,
...
@@ -115,7 +115,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -115,7 +115,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
using
Kernel
=
ck_tile
::
Smoothquant
<
Pipeline
>
;
using
Kernel
=
ck_tile
::
Smoothquant
<
Pipeline
>
;
ck_tile
::
SmoothquantHostArgs
args
{
x_buf
.
GetDeviceBuffer
(),
ck_tile
::
SmoothquantHostArgs
args
{
x_buf
.
GetDeviceBuffer
(),
x
scale_buf
.
GetDeviceBuffer
(),
sm
scale_buf
.
GetDeviceBuffer
(),
yscale_buf
.
GetDeviceBuffer
(),
yscale_buf
.
GetDeviceBuffer
(),
qy_buf
.
GetDeviceBuffer
(),
qy_buf
.
GetDeviceBuffer
(),
m
,
m
,
...
@@ -142,16 +142,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -142,16 +142,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
// smooth outlier
// smooth outlier
{
{
auto
f
=
[
&
](
auto
n_
)
{
auto
f
=
[
&
](
auto
n_
)
{
auto
v_
x
scale
=
ck_tile
::
type_convert
<
ComputeDataType
>
(
x
scale_host
(
n_
));
auto
v_
sm
scale
=
ck_tile
::
type_convert
<
ComputeDataType
>
(
sm
scale_host
(
n_
));
for
(
int
m_
=
0
;
m_
<
m
;
++
m_
)
for
(
int
m_
=
0
;
m_
<
m
;
++
m_
)
{
{
auto
v_x
=
ck_tile
::
type_convert
<
ComputeDataType
>
(
x_host
(
m_
,
n_
));
auto
v_x
=
ck_tile
::
type_convert
<
ComputeDataType
>
(
x_host
(
m_
,
n_
));
y_host
(
m_
,
n_
)
=
v_x
*
v_
x
scale
;
y_host
(
m_
,
n_
)
=
v_x
*
v_
sm
scale
;
}
}
};
};
ck_tile
::
make_ParallelTensorFunctor
(
f
,
x
scale_host
.
get_element_space_size
())(
ck_tile
::
make_ParallelTensorFunctor
(
f
,
sm
scale_host
.
get_element_space_size
())(
std
::
thread
::
hardware_concurrency
());
std
::
thread
::
hardware_concurrency
());
}
}
...
...
example/ck_tile/12_smoothquant/instances/smoothquant_instance_common.hpp
View file @
dec32dc6
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include <ck_tile/core.hpp>
#include "smoothquant.hpp"
#include "smoothquant.hpp"
...
@@ -35,7 +35,7 @@ float smoothquant_(const S& s, A a)
...
@@ -35,7 +35,7 @@ float smoothquant_(const S& s, A a)
using
PipelineProblem
=
ck_tile
::
SmoothquantPipelineProblem
<
using
PipelineProblem
=
ck_tile
::
SmoothquantPipelineProblem
<
typename
SmoothquantTypeConfig
<
DataType
>::
XDataType
,
typename
SmoothquantTypeConfig
<
DataType
>::
XDataType
,
typename
SmoothquantTypeConfig
<
DataType
>::
X
ScaleDataType
,
typename
SmoothquantTypeConfig
<
DataType
>::
Smooth
ScaleDataType
,
typename
SmoothquantTypeConfig
<
DataType
>::
ComputeDataType
,
typename
SmoothquantTypeConfig
<
DataType
>::
ComputeDataType
,
typename
SmoothquantTypeConfig
<
DataType
>::
YScaleDataType
,
typename
SmoothquantTypeConfig
<
DataType
>::
YScaleDataType
,
typename
SmoothquantTypeConfig
<
DataType
>::
QYDataType
,
typename
SmoothquantTypeConfig
<
DataType
>::
QYDataType
,
...
...
example/ck_tile/12_smoothquant/smoothquant.cpp
View file @
dec32dc6
...
@@ -67,14 +67,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -67,14 +67,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
using
TypeConfig
=
SmoothquantTypeConfig
<
DataType
>
;
using
TypeConfig
=
SmoothquantTypeConfig
<
DataType
>
;
using
XDataType
=
typename
TypeConfig
::
XDataType
;
using
XDataType
=
typename
TypeConfig
::
XDataType
;
using
X
ScaleDataType
=
typename
TypeConfig
::
X
ScaleDataType
;
using
Smooth
ScaleDataType
=
typename
TypeConfig
::
Smooth
ScaleDataType
;
using
YScaleDataType
=
typename
TypeConfig
::
YScaleDataType
;
using
YScaleDataType
=
typename
TypeConfig
::
YScaleDataType
;
using
QYDataType
=
typename
TypeConfig
::
QYDataType
;
using
QYDataType
=
typename
TypeConfig
::
QYDataType
;
using
ComputeDataType
=
typename
TypeConfig
::
ComputeDataType
;
using
ComputeDataType
=
typename
TypeConfig
::
ComputeDataType
;
// host verify
// host verify
ck_tile
::
HostTensor
<
XDataType
>
x_host
({
m
,
n
},
{
x_stride
,
1
});
ck_tile
::
HostTensor
<
XDataType
>
x_host
({
m
,
n
},
{
x_stride
,
1
});
ck_tile
::
HostTensor
<
X
ScaleDataType
>
x
scale_host
({
n
});
ck_tile
::
HostTensor
<
Smooth
ScaleDataType
>
sm
scale_host
({
n
});
ck_tile
::
HostTensor
<
YScaleDataType
>
yscale_host_ref
({
m
},
{
1
});
ck_tile
::
HostTensor
<
YScaleDataType
>
yscale_host_ref
({
m
},
{
1
});
ck_tile
::
HostTensor
<
YScaleDataType
>
yscale_host_dev
({
m
},
{
1
});
ck_tile
::
HostTensor
<
YScaleDataType
>
yscale_host_dev
({
m
},
{
1
});
...
@@ -83,15 +83,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -83,15 +83,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile
::
HostTensor
<
QYDataType
>
qy_host_dev
({
m
,
n
},
{
y_stride
,
1
});
ck_tile
::
HostTensor
<
QYDataType
>
qy_host_dev
({
m
,
n
},
{
y_stride
,
1
});
ck_tile
::
FillUniformDistribution
<
XDataType
>
{
-
.5
f
,
.5
f
}(
x_host
);
ck_tile
::
FillUniformDistribution
<
XDataType
>
{
-
.5
f
,
.5
f
}(
x_host
);
ck_tile
::
FillUniformDistribution
<
X
ScaleDataType
>
{
1e-3
,
.5
f
}(
x
scale_host
);
ck_tile
::
FillUniformDistribution
<
Smooth
ScaleDataType
>
{
1e-3
,
.5
f
}(
sm
scale_host
);
ck_tile
::
DeviceMem
x_buf
(
x_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
x_buf
(
x_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
x
scale_buf
(
x
scale_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
sm
scale_buf
(
sm
scale_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
yscale_buf
(
yscale_host_dev
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
yscale_buf
(
yscale_host_dev
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
qy_buf
(
qy_host_dev
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
qy_buf
(
qy_host_dev
.
get_element_space_size_in_bytes
());
x_buf
.
ToDevice
(
x_host
.
data
());
x_buf
.
ToDevice
(
x_host
.
data
());
x
scale_buf
.
ToDevice
(
x
scale_host
.
data
());
sm
scale_buf
.
ToDevice
(
sm
scale_host
.
data
());
std
::
cout
<<
"["
<<
data_type
<<
"]"
std
::
cout
<<
"["
<<
data_type
<<
"]"
<<
" m:"
<<
m
<<
", n:"
<<
n
<<
", x_stride:"
<<
x_stride
<<
", y_stride:"
<<
y_stride
<<
" m:"
<<
m
<<
", n:"
<<
n
<<
", x_stride:"
<<
x_stride
<<
", y_stride:"
<<
y_stride
...
@@ -100,7 +100,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -100,7 +100,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
smoothquant_traits
traits
{
data_type
};
smoothquant_traits
traits
{
data_type
};
smoothquant_args
args
{
x_buf
.
GetDeviceBuffer
(),
smoothquant_args
args
{
x_buf
.
GetDeviceBuffer
(),
x
scale_buf
.
GetDeviceBuffer
(),
sm
scale_buf
.
GetDeviceBuffer
(),
yscale_buf
.
GetDeviceBuffer
(),
yscale_buf
.
GetDeviceBuffer
(),
qy_buf
.
GetDeviceBuffer
(),
qy_buf
.
GetDeviceBuffer
(),
m
,
m
,
...
@@ -111,7 +111,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -111,7 +111,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
float
ave_time
=
smoothquant
(
float
ave_time
=
smoothquant
(
traits
,
args
,
ck_tile
::
stream_config
{
nullptr
,
true
,
kname
?
1
:
0
,
warmup
,
repeat
});
traits
,
args
,
ck_tile
::
stream_config
{
nullptr
,
true
,
kname
?
1
:
0
,
warmup
,
repeat
});
std
::
size_t
num_byte
=
sizeof
(
XDataType
)
*
m
*
n
+
sizeof
(
X
ScaleDataType
)
*
n
+
std
::
size_t
num_byte
=
sizeof
(
XDataType
)
*
m
*
n
+
sizeof
(
Smooth
ScaleDataType
)
*
n
+
sizeof
(
YScaleDataType
)
*
m
+
sizeof
(
QYDataType
)
*
m
*
n
;
sizeof
(
YScaleDataType
)
*
m
+
sizeof
(
QYDataType
)
*
m
*
n
;
float
gb_per_sec
=
num_byte
/
1.E6
/
ave_time
;
float
gb_per_sec
=
num_byte
/
1.E6
/
ave_time
;
...
@@ -126,16 +126,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -126,16 +126,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
// smooth outlier
// smooth outlier
{
{
auto
f
=
[
&
](
auto
n_
)
{
auto
f
=
[
&
](
auto
n_
)
{
auto
v_
x
scale
=
ck_tile
::
type_convert
<
ComputeDataType
>
(
x
scale_host
(
n_
));
auto
v_
sm
scale
=
ck_tile
::
type_convert
<
ComputeDataType
>
(
sm
scale_host
(
n_
));
for
(
int
m_
=
0
;
m_
<
m
;
++
m_
)
for
(
int
m_
=
0
;
m_
<
m
;
++
m_
)
{
{
auto
v_x
=
ck_tile
::
type_convert
<
ComputeDataType
>
(
x_host
(
m_
,
n_
));
auto
v_x
=
ck_tile
::
type_convert
<
ComputeDataType
>
(
x_host
(
m_
,
n_
));
y_host
(
m_
,
n_
)
=
v_x
*
v_
x
scale
;
y_host
(
m_
,
n_
)
=
v_x
*
v_
sm
scale
;
}
}
};
};
ck_tile
::
make_ParallelTensorFunctor
(
f
,
x
scale_host
.
get_element_space_size
())(
ck_tile
::
make_ParallelTensorFunctor
(
f
,
sm
scale_host
.
get_element_space_size
())(
std
::
thread
::
hardware_concurrency
());
std
::
thread
::
hardware_concurrency
());
}
}
...
...
example/ck_tile/12_smoothquant/smoothquant.hpp
View file @
dec32dc6
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -15,7 +15,7 @@ template <>
...
@@ -15,7 +15,7 @@ template <>
struct
SmoothquantTypeConfig
<
ck_tile
::
half_t
>
struct
SmoothquantTypeConfig
<
ck_tile
::
half_t
>
{
{
using
XDataType
=
ck_tile
::
half_t
;
using
XDataType
=
ck_tile
::
half_t
;
using
X
ScaleDataType
=
float
;
using
Smooth
ScaleDataType
=
float
;
using
YScaleDataType
=
float
;
using
YScaleDataType
=
float
;
using
QYDataType
=
ck_tile
::
int8_t
;
using
QYDataType
=
ck_tile
::
int8_t
;
using
ComputeDataType
=
float
;
using
ComputeDataType
=
float
;
...
@@ -25,7 +25,7 @@ template <>
...
@@ -25,7 +25,7 @@ template <>
struct
SmoothquantTypeConfig
<
ck_tile
::
bf16_t
>
struct
SmoothquantTypeConfig
<
ck_tile
::
bf16_t
>
{
{
using
XDataType
=
ck_tile
::
bf16_t
;
using
XDataType
=
ck_tile
::
bf16_t
;
using
X
ScaleDataType
=
float
;
using
Smooth
ScaleDataType
=
float
;
using
YScaleDataType
=
float
;
using
YScaleDataType
=
float
;
using
QYDataType
=
ck_tile
::
int8_t
;
using
QYDataType
=
ck_tile
::
int8_t
;
using
ComputeDataType
=
float
;
using
ComputeDataType
=
float
;
...
...
example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n1024_instance.cpp
View file @
dec32dc6
...
@@ -15,8 +15,13 @@ template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 16, 4, 64, 1, true
...
@@ -15,8 +15,13 @@ template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 16, 4, 64, 1, true
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 1, 1, 256, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 1, 1, 256, 4, true, false>>(const S&, A);
#endif
#endif
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
1
,
2
,
128
,
8
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
1
,
2
,
128
,
8
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
2
,
2
,
128
,
4
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
2
,
2
,
128
,
4
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
4
,
2
,
128
,
2
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
4
,
2
,
128
,
2
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
4
,
1
,
256
,
1
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
4
,
1
,
256
,
1
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
1
,
2
,
128
,
8
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
2
,
2
,
128
,
4
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
4
,
2
,
128
,
2
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
4
,
1
,
256
,
1
,
true
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
// clang-format on
example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n1536_instance.cpp
View file @
dec32dc6
...
@@ -6,8 +6,13 @@
...
@@ -6,8 +6,13 @@
// clang-format off
// clang-format off
// rm rn tm tn vn pd 2p
// rm rn tm tn vn pd 2p
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
3
,
4
,
64
,
8
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
3
,
4
,
64
,
8
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
3
,
2
,
128
,
4
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
3
,
2
,
128
,
4
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
3
,
1
,
256
,
2
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
3
,
1
,
256
,
2
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
6
,
1
,
256
,
1
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
6
,
1
,
256
,
1
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
3
,
4
,
64
,
8
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
3
,
2
,
128
,
4
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
3
,
1
,
256
,
2
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
6
,
1
,
256
,
1
,
true
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
// clang-format on
example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n2048_instance.cpp
View file @
dec32dc6
...
@@ -6,9 +6,14 @@
...
@@ -6,9 +6,14 @@
// clang-format off
// clang-format off
// rm rn tm tn vn pd 2p
// rm rn tm tn vn pd 2p
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
1
,
1
,
256
,
8
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
1
,
1
,
256
,
8
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
2
,
1
,
256
,
4
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
2
,
1
,
256
,
4
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
4
,
1
,
256
,
2
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
4
,
1
,
256
,
2
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
8
,
1
,
256
,
1
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
8
,
1
,
256
,
1
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
1
,
1
,
256
,
8
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
2
,
1
,
256
,
4
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
4
,
1
,
256
,
2
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
8
,
1
,
256
,
1
,
true
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
// clang-format on
example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n256_instance.cpp
View file @
dec32dc6
...
@@ -6,7 +6,11 @@
...
@@ -6,7 +6,11 @@
// clang-format off
// clang-format off
// rm rn tm tn vn pd 2p
// rm rn tm tn vn pd 2p
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
1
,
4
,
64
,
4
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
1
,
4
,
64
,
4
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
2
,
4
,
64
,
2
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
2
,
4
,
64
,
2
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
4
,
4
,
64
,
1
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
4
,
4
,
64
,
1
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
1
,
4
,
64
,
4
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
2
,
4
,
64
,
2
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
4
,
4
,
64
,
1
,
true
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
// clang-format on
example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n3072_instance.cpp
View file @
dec32dc6
...
@@ -6,9 +6,13 @@
...
@@ -6,9 +6,13 @@
// clang-format off
// clang-format off
// rm rn tm tn vn pd 2p
// rm rn tm tn vn pd 2p
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
3
,
1
,
128
,
8
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
3
,
1
,
128
,
8
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
3
,
1
,
256
,
4
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
3
,
1
,
256
,
4
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
6
,
1
,
256
,
2
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
6
,
1
,
256
,
2
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
3
,
1
,
1024
,
1
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
3
,
1
,
1024
,
1
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
3
,
1
,
128
,
8
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
3
,
1
,
256
,
4
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
6
,
1
,
256
,
2
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
3
,
1
,
1024
,
1
,
true
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
// clang-format on
Prev
1
2
3
4
5
6
7
…
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment