Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
64d5c4d6
Unverified
Commit
64d5c4d6
authored
Jan 24, 2025
by
ruanjm
Committed by
GitHub
Jan 24, 2025
Browse files
Implement fp8 quant for layernorm and rmsnorm (#1814)
parent
5b9b083d
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
67 additions
and
19 deletions
+67
-19
example/ck_tile/02_layernorm2d/CMakeLists.txt
example/ck_tile/02_layernorm2d/CMakeLists.txt
+1
-1
example/ck_tile/02_layernorm2d/generate.py
example/ck_tile/02_layernorm2d/generate.py
+5
-3
example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp
example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp
+28
-4
example/ck_tile/02_layernorm2d/script/smoke_test.sh
example/ck_tile/02_layernorm2d/script/smoke_test.sh
+1
-1
example/ck_tile/10_rmsnorm2d/CMakeLists.txt
example/ck_tile/10_rmsnorm2d/CMakeLists.txt
+1
-1
example/ck_tile/10_rmsnorm2d/generate.py
example/ck_tile/10_rmsnorm2d/generate.py
+5
-3
example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp
example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp
+19
-3
example/ck_tile/10_rmsnorm2d/script/smoke_test.sh
example/ck_tile/10_rmsnorm2d/script/smoke_test.sh
+2
-2
include/ck_tile/host/check_err.hpp
include/ck_tile/host/check_err.hpp
+5
-1
No files found.
example/ck_tile/02_layernorm2d/CMakeLists.txt
View file @
64d5c4d6
...
@@ -33,7 +33,7 @@ target_sources(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${LAYERNORM2D_FWD_GEN_BLOBS})
...
@@ -33,7 +33,7 @@ target_sources(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${LAYERNORM2D_FWD_GEN_BLOBS})
set
(
EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS
)
set
(
EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS
)
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
list
(
APPEND EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal
)
list
(
APPEND EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal
--offload-compress
)
target_compile_options
(
${
EXAMPLE_LAYERNORM2D_FWD
}
PRIVATE
${
EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS
}
)
target_compile_options
(
${
EXAMPLE_LAYERNORM2D_FWD
}
PRIVATE
${
EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS
}
)
...
...
example/ck_tile/02_layernorm2d/generate.py
View file @
64d5c4d6
...
@@ -39,7 +39,8 @@ FUSED_FUSED_SWEEP_STR_MAP = [
...
@@ -39,7 +39,8 @@ FUSED_FUSED_SWEEP_STR_MAP = [
DATA_TYPE_MAP
=
{
'fp32'
:
'float'
,
DATA_TYPE_MAP
=
{
'fp32'
:
'float'
,
'fp16'
:
'ck_tile::fp16_t'
,
'fp16'
:
'ck_tile::fp16_t'
,
'bf16'
:
'ck_tile::bf16_t'
,
'bf16'
:
'ck_tile::bf16_t'
,
'int8'
:
'ck_tile::int8_t'
}
'int8'
:
'ck_tile::int8_t'
,
'fp8'
:
'ck_tile::fp8_t'
}
def
BOOL_MAP
(
b_
)
->
str
:
def
BOOL_MAP
(
b_
)
->
str
:
if
b_
:
if
b_
:
...
@@ -504,12 +505,13 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
...
@@ -504,12 +505,13 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
h_traits
=
layernorm_fwd_codegen
.
h_traits
h_traits
=
layernorm_fwd_codegen
.
h_traits
h_instance
=
layernorm_fwd_codegen
.
h_instance
h_instance
=
layernorm_fwd_codegen
.
h_instance
dynamic_quant_out_dtype
=
[
'int8'
]
dynamic_quant_out_dtype
=
[
'int8'
,
'fp8'
]
# some predefined support range
# some predefined support range
# (prec_i,prec_o) for simplicity this string will be used as key for dict
# (prec_i,prec_o) for simplicity this string will be used as key for dict
scale_list
=
[(
'fp32,fp32'
)]
scale_list
=
[(
'fp32,fp32'
)]
dtype_list
=
[(
'fp16,fp16'
),
(
'bf16,bf16'
),
dtype_list
=
[(
'fp16,fp16'
),
(
'bf16,bf16'
),
(
'fp16,int8'
),
(
'bf16,int8'
)]
# NOTE: only fused-dynamic-quant use int8 out
(
'fp16,int8'
),
(
'bf16,int8'
),
(
'fp16,fp8'
),
(
'bf16,fp8'
)]
# NOTE: only fused-dynamic-quant use int8 or fp8 out
types_8bit
=
(
'int8'
,
'fp8'
)
types_8bit
=
(
'int8'
,
'fp8'
)
types_16bit
=
(
'int16'
,
'fp16'
,
'bf16'
)
types_16bit
=
(
'int16'
,
'fp16'
,
'bf16'
)
#fused_add_list = [0, 1, 2]
#fused_add_list = [0, 1, 2]
...
...
example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp
View file @
64d5c4d6
...
@@ -20,6 +20,14 @@ auto get_elimit<ck_tile::bf16_t>()
...
@@ -20,6 +20,14 @@ auto get_elimit<ck_tile::bf16_t>()
return
ck_tile
::
make_tuple
(
rtol
,
atol
);
return
ck_tile
::
make_tuple
(
rtol
,
atol
);
}
}
template
<
>
auto
get_elimit
<
ck_tile
::
int8_t
>
()
{
double
rtol
=
1e-2
;
double
atol
=
1.0
;
return
ck_tile
::
make_tuple
(
rtol
,
atol
);
}
auto
create_args
(
int
argc
,
char
*
argv
[])
auto
create_args
(
int
argc
,
char
*
argv
[])
{
{
ck_tile
::
ArgParser
arg_parser
;
ck_tile
::
ArgParser
arg_parser
;
...
@@ -97,9 +105,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -97,9 +105,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
int
xbias
=
arg_parser
.
get_int
(
"xbias"
);
int
xbias
=
arg_parser
.
get_int
(
"xbias"
);
int
fused_add
=
arg_parser
.
get_int
(
"fadd"
);
int
fused_add
=
arg_parser
.
get_int
(
"fadd"
);
int
fused_quant
=
arg_parser
.
get_int
(
"fquant"
);
int
fused_quant
=
arg_parser
.
get_int
(
"fquant"
);
if
(
fused_quant
==
1
&&
prec_o
!=
"int8"
)
if
(
fused_quant
==
1
&&
prec_o
!=
"int8"
&&
prec_o
!=
"fp8"
)
{
{
std
::
cout
<<
"if fused_quant is 1, only support
\"
-prec_o=int8
\"
case"
<<
std
::
endl
;
std
::
cout
<<
"if fused_quant is 1 or 2, only support
\"
-prec_o=int8
\"
or
\"
-prec_o=fp8
\"
cases."
<<
std
::
endl
;
return
false
;
return
false
;
}
}
...
@@ -291,7 +301,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -291,7 +301,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
absmax
=
a
>
absmax
?
a
:
absmax
;
absmax
=
a
>
absmax
?
a
:
absmax
;
}
}
// printf("cpu:absmax:%f\n", absmax);
// printf("cpu:absmax:%f\n", absmax);
ComputeDataType
y_scale
=
absmax
/
static_cast
<
ComputeDataType
>
(
127.0
);
constexpr
ComputeDataType
kMaxY
=
std
::
is_same
<
YDataType
,
ck_tile
::
fp8_t
>::
value
?
240.0
:
std
::
is_same
<
YDataType
,
ck_tile
::
int8_t
>::
value
?
127.0
:
0.0
;
ComputeDataType
y_scale
=
absmax
/
kMaxY
;
y_scale_host_ref
(
m_
)
=
ck_tile
::
type_convert
<
YScaleDataType
>
(
y_scale
);
y_scale_host_ref
(
m_
)
=
ck_tile
::
type_convert
<
YScaleDataType
>
(
y_scale
);
for
(
int
n_
=
0
;
n_
<
N_
;
n_
++
)
for
(
int
n_
=
0
;
n_
<
N_
;
n_
++
)
{
{
...
@@ -334,7 +348,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -334,7 +348,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
y_residual_buf
.
FromDevice
(
y_residual_host_dev
.
data
());
y_residual_buf
.
FromDevice
(
y_residual_host_dev
.
data
());
}
}
auto
[
rtol
,
atol
]
=
get_elimit
<
In
DataType
>
();
auto
[
rtol
,
atol
]
=
get_elimit
<
Out
DataType
>
();
if
(
x_stride
==
n
)
if
(
x_stride
==
n
)
{
{
...
@@ -452,6 +466,16 @@ int main(int argc, char* argv[])
...
@@ -452,6 +466,16 @@ int main(int argc, char* argv[])
{
{
return
run
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
float
,
float
,
false
>
(
arg_parser
)
?
0
:
-
2
;
return
run
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
float
,
float
,
false
>
(
arg_parser
)
?
0
:
-
2
;
}
}
else
if
(
prec_i
==
"fp16"
&&
prec_o
==
"fp8"
&&
prec_sm
==
"fp32"
&&
prec_sy
==
"fp32"
&&
!
save_mv
)
{
return
run
<
ck_tile
::
half_t
,
ck_tile
::
fp8_t
,
float
,
float
,
false
>
(
arg_parser
)
?
0
:
-
2
;
}
else
if
(
prec_i
==
"bf16"
&&
prec_o
==
"fp8"
&&
prec_sm
==
"fp32"
&&
prec_sy
==
"fp32"
&&
!
save_mv
)
{
return
run
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
float
,
float
,
false
>
(
arg_parser
)
?
0
:
-
2
;
}
return
-
3
;
return
-
3
;
}
}
example/ck_tile/02_layernorm2d/script/smoke_test.sh
View file @
64d5c4d6
#!/bin/sh
#!/bin/sh
EXE
=
"
$(
find
.
-name
tile_example_layernorm2d_fwd
-type
f |
head
-n
1
)
"
EXE
=
"
$(
find
.
-name
tile_example_layernorm2d_fwd
-type
f |
head
-n
1
)
"
for
fquant
in
""
"-fquant=1 -prec_o=int8"
;
do
for
fquant
in
""
"-fquant=1 -prec_o=int8"
"-fquant=1 -prec_o=fp8"
;
do
for
pr_i
in
"fp16"
"bf16"
;
do
for
pr_i
in
"fp16"
"bf16"
;
do
for
fadd
in
"0"
"1"
;
do
for
fadd
in
"0"
"1"
;
do
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
99
-n
=
13
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
99
-n
=
13
...
...
example/ck_tile/10_rmsnorm2d/CMakeLists.txt
View file @
64d5c4d6
...
@@ -33,7 +33,7 @@ target_sources(${TILE_RMSNORM2D_FWD} PRIVATE ${RMSNORM2D_FWD_GEN_BLOBS})
...
@@ -33,7 +33,7 @@ target_sources(${TILE_RMSNORM2D_FWD} PRIVATE ${RMSNORM2D_FWD_GEN_BLOBS})
set
(
TILE_RMSNORM2D_FWD_COMPILE_OPTIONS
)
set
(
TILE_RMSNORM2D_FWD_COMPILE_OPTIONS
)
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
list
(
APPEND TILE_RMSNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal
)
list
(
APPEND TILE_RMSNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal
--offload-compress
)
target_compile_options
(
${
TILE_RMSNORM2D_FWD
}
PRIVATE
${
TILE_RMSNORM2D_FWD_COMPILE_OPTIONS
}
)
target_compile_options
(
${
TILE_RMSNORM2D_FWD
}
PRIVATE
${
TILE_RMSNORM2D_FWD_COMPILE_OPTIONS
}
)
...
...
example/ck_tile/10_rmsnorm2d/generate.py
View file @
64d5c4d6
...
@@ -37,7 +37,8 @@ FUSED_FUSED_SWEEP_STR_MAP = [
...
@@ -37,7 +37,8 @@ FUSED_FUSED_SWEEP_STR_MAP = [
DATA_TYPE_MAP
=
{
'fp32'
:
'float'
,
DATA_TYPE_MAP
=
{
'fp32'
:
'float'
,
'fp16'
:
'ck_tile::fp16_t'
,
'fp16'
:
'ck_tile::fp16_t'
,
'bf16'
:
'ck_tile::bf16_t'
,
'bf16'
:
'ck_tile::bf16_t'
,
'int8'
:
'ck_tile::int8_t'
}
'int8'
:
'ck_tile::int8_t'
,
'fp8'
:
'ck_tile::fp8_t'
}
def
BOOL_MAP
(
b_
)
->
str
:
def
BOOL_MAP
(
b_
)
->
str
:
if
b_
:
if
b_
:
...
@@ -477,12 +478,13 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
...
@@ -477,12 +478,13 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
h_traits
=
rmsnorm_fwd_codegen
.
h_traits
h_traits
=
rmsnorm_fwd_codegen
.
h_traits
h_instance
=
rmsnorm_fwd_codegen
.
h_instance
h_instance
=
rmsnorm_fwd_codegen
.
h_instance
dynamic_quant_out_dtype
=
[
'int8'
]
dynamic_quant_out_dtype
=
[
'int8'
,
'fp8'
]
# some predefined support range
# some predefined support range
# (prec_i,prec_o) for simplicity this string will be used as key for dict
# (prec_i,prec_o) for simplicity this string will be used as key for dict
scale_list
=
[(
'fp32,fp32'
)]
scale_list
=
[(
'fp32,fp32'
)]
dtype_list
=
[(
'fp16,fp16'
),
(
'bf16,bf16'
),
dtype_list
=
[(
'fp16,fp16'
),
(
'bf16,bf16'
),
(
'fp16,int8'
),
(
'bf16,int8'
)]
# NOTE: only fused-dynamic-quant use int8 out
(
'fp16,int8'
),
(
'bf16,int8'
),
(
'fp16,fp8'
),
(
'bf16,fp8'
)]
# NOTE: only fused-dynamic-quant use int8 out
#fused_add_list = [0, 1, 2]
#fused_add_list = [0, 1, 2]
#fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused (smooth) dynamic quant
#fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused (smooth) dynamic quant
fused_add_list
=
[
0
,
1
]
fused_add_list
=
[
0
,
1
]
...
...
example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp
View file @
64d5c4d6
...
@@ -105,9 +105,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -105,9 +105,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
prec_sy
=
"fp32"
;
prec_sy
=
"fp32"
;
}
}
if
((
fused_quant
==
1
||
fused_quant
==
2
)
&&
prec_o
!=
"int8"
)
if
((
fused_quant
==
1
||
fused_quant
==
2
)
&&
prec_o
!=
"int8"
&&
prec_o
!=
"fp8"
)
{
{
std
::
cout
<<
"if fused_quant is 1, only support
\"
-prec_o=int8
\"
case"
<<
std
::
endl
;
std
::
cout
<<
"if fused_quant is 1 or 2, only support
\"
-prec_o=int8
\"
or
\"
-prec_o=fp8
\"
cases."
<<
std
::
endl
;
return
false
;
return
false
;
}
}
...
@@ -248,7 +250,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -248,7 +250,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
absmax
=
a
>
absmax
?
a
:
absmax
;
absmax
=
a
>
absmax
?
a
:
absmax
;
}
}
// printf("cpu:absmax:%f\n", absmax);
// printf("cpu:absmax:%f\n", absmax);
ComputeDataType
y_scale
=
absmax
/
static_cast
<
ComputeDataType
>
(
127.0
);
constexpr
ComputeDataType
kMaxY
=
std
::
is_same
<
YDataType
,
ck_tile
::
fp8_t
>::
value
?
240.0
:
std
::
is_same
<
YDataType
,
ck_tile
::
int8_t
>::
value
?
127.0
:
0.0
;
ComputeDataType
y_scale
=
absmax
/
kMaxY
;
y_scale_host_ref
(
m_
)
=
ck_tile
::
type_convert
<
YScaleDataType
>
(
y_scale
);
y_scale_host_ref
(
m_
)
=
ck_tile
::
type_convert
<
YScaleDataType
>
(
y_scale
);
for
(
int
n_
=
0
;
n_
<
N_
;
n_
++
)
for
(
int
n_
=
0
;
n_
<
N_
;
n_
++
)
{
{
...
@@ -400,6 +406,16 @@ int main(int argc, char* argv[])
...
@@ -400,6 +406,16 @@ int main(int argc, char* argv[])
{
{
return
run
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
float
,
float
,
true
>
(
arg_parser
)
?
0
:
-
2
;
return
run
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
float
,
float
,
true
>
(
arg_parser
)
?
0
:
-
2
;
}
}
else
if
(
prec_i
==
"fp16"
&&
prec_o
==
"fp8"
&&
prec_sm
==
"fp32"
&&
prec_sy
==
"fp32"
&&
!
save_rms
)
{
return
run
<
ck_tile
::
half_t
,
ck_tile
::
fp8_t
,
float
,
float
,
false
>
(
arg_parser
)
?
0
:
-
2
;
}
else
if
(
prec_i
==
"bf16"
&&
prec_o
==
"fp8"
&&
prec_sm
==
"fp32"
&&
prec_sy
==
"fp32"
&&
!
save_rms
)
{
return
run
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
float
,
float
,
false
>
(
arg_parser
)
?
0
:
-
2
;
}
return
-
3
;
return
-
3
;
}
}
example/ck_tile/10_rmsnorm2d/script/smoke_test.sh
View file @
64d5c4d6
#!/bin/sh
#!/bin/sh
EXE
=
"
$(
find
.
-name
tile_rmsnorm2d_fwd
-type
f |
head
-n
1
)
"
EXE
=
"
$(
find
.
-name
tile_rmsnorm2d_fwd
-type
f |
head
-n
1
)
"
for
fquant
in
""
"-fquant=1 -prec_o=int8"
"-fquant=2 -prec_o=int8"
;
do
for
fquant
in
""
"-fquant=1 -prec_o=int8"
"-fquant=2 -prec_o=int8"
"-fquant=1 -prec_o=fp8"
"-fquant=2 -prec_o=fp8"
;
do
for
pr_i
in
"fp16"
"bf16"
;
do
for
pr_i
in
"fp16"
"bf16"
;
do
for
fadd
in
"0"
"1"
;
do
for
fadd
in
"0"
"1"
;
do
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
99
-n
=
13
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
99
-n
=
13
...
@@ -27,7 +27,7 @@ $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=2734
...
@@ -27,7 +27,7 @@ $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=2734
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
1
-n
=
3182
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
1
-n
=
3182
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
9
-n
=
4096
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
9
-n
=
4096
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
3
-n
=
8192
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
3
-n
=
8192
#
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=10547
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
1
-n
=
10547
#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=17134
#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=17134
done
done
done
done
...
...
include/ck_tile/host/check_err.hpp
View file @
64d5c4d6
...
@@ -443,7 +443,11 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
...
@@ -443,7 +443,11 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
}
}
if
(
!
res
)
if
(
!
res
)
{
{
std
::
cerr
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
"max err: "
<<
max_err
<<
std
::
endl
;
const
float
error_percent
=
static_cast
<
float
>
(
err_count
)
/
static_cast
<
float
>
(
out
.
size
())
*
100.
f
;
std
::
cerr
<<
"max err: "
<<
max_err
;
std
::
cerr
<<
", number of errors: "
<<
err_count
;
std
::
cerr
<<
", "
<<
error_percent
<<
"% wrong values"
<<
std
::
endl
;
}
}
return
res
;
return
res
;
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment