Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
b321bd86
Commit
b321bd86
authored
Nov 29, 2024
by
rocking
Browse files
Support pure quant in instance library
parent
26f221eb
Changes
24
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
67 additions
and
49 deletions
+67
-49
example/ck_tile/12_smoothquant/instances/smoothquant_fwd_api.cpp
.../ck_tile/12_smoothquant/instances/smoothquant_fwd_api.cpp
+50
-42
example/ck_tile/12_smoothquant/instances/smoothquant_instance_common.hpp
.../12_smoothquant/instances/smoothquant_instance_common.hpp
+5
-3
example/ck_tile/12_smoothquant/smoothquant.cpp
example/ck_tile/12_smoothquant/smoothquant.cpp
+8
-3
example/ck_tile/12_smoothquant/smoothquant.hpp
example/ck_tile/12_smoothquant/smoothquant.hpp
+4
-1
No files found.
example/ck_tile/12_smoothquant/instances/smoothquant_fwd_api.cpp
View file @
b321bd86
...
...
@@ -11,7 +11,8 @@ template <typename DataType_,
ck_tile
::
index_t
ThreadPerBlock_N_
,
// num threads along N
ck_tile
::
index_t
Vector_N_
,
// vector size along N
bool
kPadN_
,
bool
kTwoPass_
>
bool
kTwoPass_
,
bool
kSmoothX_
>
using
trait_
=
smoothquant_traits_
<
DataType_
,
Repeat_M_
,
Repeat_N_
,
...
...
@@ -19,9 +20,10 @@ using trait_ = smoothquant_traits_<DataType_,
ThreadPerBlock_N_
,
Vector_N_
,
kPadN_
,
kTwoPass_
>
;
kTwoPass_
,
kSmoothX_
>
;
template
<
typename
data_type
>
template
<
typename
data_type
,
bool
smooth_x
>
float
smoothquant_dispatch
(
smoothquant_traits
/*t*/
,
smoothquant_args
a
,
const
ck_tile
::
stream_config
&
s
)
...
...
@@ -30,99 +32,99 @@ float smoothquant_dispatch(smoothquant_traits /*t*/,
// clang-format off
// rm rn tm tn vn pd 2p
if
(
a
.
n
<=
64
)
{
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
1
,
4
,
64
,
1
,
true
,
false
>>
(
s
,
a
);
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
1
,
4
,
64
,
1
,
true
,
false
,
smooth_x
>>
(
s
,
a
);
}
else
if
(
a
.
n
<=
128
)
{
if
(
a
.
n
%
2
==
0
)
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
1
,
4
,
64
,
2
,
true
,
false
>>
(
s
,
a
);
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
1
,
4
,
64
,
2
,
true
,
false
,
smooth_x
>>
(
s
,
a
);
else
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
2
,
4
,
64
,
1
,
true
,
false
>>
(
s
,
a
);
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
2
,
4
,
64
,
1
,
true
,
false
,
smooth_x
>>
(
s
,
a
);
}
else
if
(
a
.
n
<=
256
)
{
if
(
a
.
n
%
4
==
0
)
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
1
,
4
,
64
,
4
,
true
,
false
>>
(
s
,
a
);
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
1
,
4
,
64
,
4
,
true
,
false
,
smooth_x
>>
(
s
,
a
);
else
if
(
a
.
n
%
2
==
0
)
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
2
,
4
,
64
,
2
,
true
,
false
>>
(
s
,
a
);
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
2
,
4
,
64
,
2
,
true
,
false
,
smooth_x
>>
(
s
,
a
);
else
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
4
,
4
,
64
,
1
,
true
,
false
>>
(
s
,
a
);
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
4
,
4
,
64
,
1
,
true
,
false
,
smooth_x
>>
(
s
,
a
);
}
else
if
(
a
.
n
<=
512
)
{
if
(
a
.
n
%
8
==
0
)
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
1
,
4
,
64
,
8
,
true
,
false
>>
(
s
,
a
);
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
1
,
4
,
64
,
8
,
true
,
false
,
smooth_x
>>
(
s
,
a
);
else
if
(
a
.
n
%
4
==
0
)
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
2
,
4
,
64
,
4
,
true
,
false
>>
(
s
,
a
);
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
2
,
4
,
64
,
4
,
true
,
false
,
smooth_x
>>
(
s
,
a
);
else
if
(
a
.
n
%
2
==
0
)
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
4
,
4
,
64
,
2
,
true
,
false
>>
(
s
,
a
);
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
4
,
4
,
64
,
2
,
true
,
false
,
smooth_x
>>
(
s
,
a
);
else
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
8
,
4
,
64
,
1
,
true
,
false
>>
(
s
,
a
);
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
8
,
4
,
64
,
1
,
true
,
false
,
smooth_x
>>
(
s
,
a
);
}
else
if
(
a
.
n
<=
768
)
{
if
(
a
.
n
%
4
==
0
)
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
3
,
4
,
64
,
4
,
true
,
false
>>
(
s
,
a
);
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
3
,
4
,
64
,
4
,
true
,
false
,
smooth_x
>>
(
s
,
a
);
else
if
(
a
.
n
%
2
==
0
)
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
6
,
4
,
64
,
2
,
true
,
false
>>
(
s
,
a
);
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
6
,
4
,
64
,
2
,
true
,
false
,
smooth_x
>>
(
s
,
a
);
else
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
12
,
4
,
64
,
1
,
true
,
false
>>
(
s
,
a
);
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
12
,
4
,
64
,
1
,
true
,
false
,
smooth_x
>>
(
s
,
a
);
}
else
if
(
a
.
n
<=
1024
)
{
if
(
a
.
n
%
8
==
0
)
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
1
,
2
,
128
,
8
,
true
,
false
>>
(
s
,
a
);
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
1
,
2
,
128
,
8
,
true
,
false
,
smooth_x
>>
(
s
,
a
);
else
if
(
a
.
n
%
4
==
0
)
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
2
,
2
,
128
,
4
,
true
,
false
>>
(
s
,
a
);
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
2
,
2
,
128
,
4
,
true
,
false
,
smooth_x
>>
(
s
,
a
);
else
if
(
a
.
n
%
2
==
0
)
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
4
,
2
,
128
,
2
,
true
,
false
>>
(
s
,
a
);
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
4
,
2
,
128
,
2
,
true
,
false
,
smooth_x
>>
(
s
,
a
);
else
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
4
,
1
,
256
,
1
,
true
,
false
>>
(
s
,
a
);
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
4
,
1
,
256
,
1
,
true
,
false
,
smooth_x
>>
(
s
,
a
);
}
else
if
(
a
.
n
<=
1536
)
{
if
(
a
.
n
%
8
==
0
)
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
3
,
4
,
64
,
8
,
true
,
false
>>
(
s
,
a
);
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
3
,
4
,
64
,
8
,
true
,
false
,
smooth_x
>>
(
s
,
a
);
else
if
(
a
.
n
%
4
==
0
)
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
3
,
2
,
128
,
4
,
true
,
false
>>
(
s
,
a
);
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
3
,
2
,
128
,
4
,
true
,
false
,
smooth_x
>>
(
s
,
a
);
else
if
(
a
.
n
%
2
==
0
)
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
3
,
1
,
256
,
2
,
true
,
false
>>
(
s
,
a
);
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
3
,
1
,
256
,
2
,
true
,
false
,
smooth_x
>>
(
s
,
a
);
else
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
6
,
1
,
256
,
1
,
true
,
false
>>
(
s
,
a
);
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
6
,
1
,
256
,
1
,
true
,
false
,
smooth_x
>>
(
s
,
a
);
}
else
if
(
a
.
n
<=
2048
)
{
if
(
a
.
n
%
8
==
0
)
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
1
,
1
,
256
,
8
,
true
,
false
>>
(
s
,
a
);
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
1
,
1
,
256
,
8
,
true
,
false
,
smooth_x
>>
(
s
,
a
);
else
if
(
a
.
n
%
4
==
0
)
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
2
,
1
,
256
,
4
,
true
,
false
>>
(
s
,
a
);
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
2
,
1
,
256
,
4
,
true
,
false
,
smooth_x
>>
(
s
,
a
);
else
if
(
a
.
n
%
2
==
0
)
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
4
,
1
,
256
,
2
,
true
,
false
>>
(
s
,
a
);
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
4
,
1
,
256
,
2
,
true
,
false
,
smooth_x
>>
(
s
,
a
);
else
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
8
,
1
,
256
,
1
,
true
,
false
>>
(
s
,
a
);
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
8
,
1
,
256
,
1
,
true
,
false
,
smooth_x
>>
(
s
,
a
);
}
else
if
(
a
.
n
<=
3072
)
{
if
(
a
.
n
%
8
==
0
)
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
3
,
1
,
128
,
8
,
true
,
false
>>
(
s
,
a
);
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
3
,
1
,
128
,
8
,
true
,
false
,
smooth_x
>>
(
s
,
a
);
else
if
(
a
.
n
%
4
==
0
)
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
3
,
1
,
256
,
4
,
true
,
false
>>
(
s
,
a
);
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
3
,
1
,
256
,
4
,
true
,
false
,
smooth_x
>>
(
s
,
a
);
else
if
(
a
.
n
%
2
==
0
)
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
6
,
1
,
256
,
2
,
true
,
false
>>
(
s
,
a
);
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
6
,
1
,
256
,
2
,
true
,
false
,
smooth_x
>>
(
s
,
a
);
else
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
3
,
1
,
1024
,
1
,
true
,
false
>>
(
s
,
a
);
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
3
,
1
,
1024
,
1
,
true
,
false
,
smooth_x
>>
(
s
,
a
);
}
else
if
(
a
.
n
<=
4096
)
{
if
(
a
.
n
%
8
==
0
)
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
2
,
1
,
256
,
8
,
true
,
false
>>
(
s
,
a
);
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
2
,
1
,
256
,
8
,
true
,
false
,
smooth_x
>>
(
s
,
a
);
else
if
(
a
.
n
%
4
==
0
)
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
4
,
1
,
256
,
4
,
true
,
false
>>
(
s
,
a
);
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
4
,
1
,
256
,
4
,
true
,
false
,
smooth_x
>>
(
s
,
a
);
else
if
(
a
.
n
%
2
==
0
)
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
2
,
1
,
1024
,
2
,
true
,
false
>>
(
s
,
a
);
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
2
,
1
,
1024
,
2
,
true
,
false
,
smooth_x
>>
(
s
,
a
);
else
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
4
,
1
,
1024
,
1
,
true
,
false
>>
(
s
,
a
);
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
4
,
1
,
1024
,
1
,
true
,
false
,
smooth_x
>>
(
s
,
a
);
}
else
if
(
a
.
n
>
4096
)
{
if
(
a
.
n
%
8
==
0
)
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
2
,
1
,
256
,
8
,
true
,
true
>>
(
s
,
a
);
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
2
,
1
,
256
,
8
,
true
,
true
,
smooth_x
>>
(
s
,
a
);
else
if
(
a
.
n
%
4
==
0
)
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
4
,
1
,
256
,
4
,
true
,
true
>>
(
s
,
a
);
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
4
,
1
,
256
,
4
,
true
,
true
,
smooth_x
>>
(
s
,
a
);
else
if
(
a
.
n
%
2
==
0
)
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
2
,
1
,
1024
,
2
,
true
,
true
>>
(
s
,
a
);
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
2
,
1
,
1024
,
2
,
true
,
true
,
smooth_x
>>
(
s
,
a
);
else
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
4
,
1
,
1024
,
1
,
true
,
true
>>
(
s
,
a
);
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
4
,
1
,
1024
,
1
,
true
,
true
,
smooth_x
>>
(
s
,
a
);
}
return
r
;
// clang-format on
...
...
@@ -132,11 +134,17 @@ float smoothquant(smoothquant_traits t, smoothquant_args a, const ck_tile::strea
{
if
(
t
.
data_type
.
compare
(
"fp16"
)
==
0
)
{
return
smoothquant_dispatch
<
ck_tile
::
fp16_t
>
(
t
,
a
,
s
);
if
(
t
.
smooth_x
)
return
smoothquant_dispatch
<
ck_tile
::
fp16_t
,
true
>
(
t
,
a
,
s
);
else
return
smoothquant_dispatch
<
ck_tile
::
fp16_t
,
false
>
(
t
,
a
,
s
);
}
else
if
(
t
.
data_type
.
compare
(
"bf16"
)
==
0
)
{
return
smoothquant_dispatch
<
ck_tile
::
bf16_t
>
(
t
,
a
,
s
);
if
(
t
.
smooth_x
)
return
smoothquant_dispatch
<
ck_tile
::
bf16_t
,
true
>
(
t
,
a
,
s
);
else
return
smoothquant_dispatch
<
ck_tile
::
bf16_t
,
false
>
(
t
,
a
,
s
);
}
else
throw
std
::
runtime_error
(
"Without supported instances!"
);
...
...
example/ck_tile/12_smoothquant/instances/smoothquant_instance_common.hpp
View file @
b321bd86
...
...
@@ -18,7 +18,8 @@ template <typename DataType_,
ck_tile
::
index_t
ThreadPerBlock_N_
,
// num threads along N
ck_tile
::
index_t
Vector_N_
,
// vector size along N
bool
kPadN_
,
bool
kTwoPass_
>
bool
kTwoPass_
,
bool
kSmoothX_
>
using
trait_
=
smoothquant_traits_
<
DataType_
,
Repeat_M_
,
Repeat_N_
,
...
...
@@ -26,7 +27,8 @@ using trait_ = smoothquant_traits_<DataType_,
ThreadPerBlock_N_
,
Vector_N_
,
kPadN_
,
kTwoPass_
>
;
kTwoPass_
,
kSmoothX_
>
;
template
<
typename
Traits_
>
float
smoothquant_
(
const
S
&
s
,
A
a
)
...
...
@@ -42,7 +44,7 @@ float smoothquant_(const S& s, A a)
typename
Traits_
::
Shape
,
Traits_
::
kPadN
,
Traits_
::
kTwoPass
,
true
>
;
Traits_
::
kSmoothX
>
;
using
OnePassPipeline
=
ck_tile
::
SmoothquantPipelineOnePass
<
PipelineProblem
>
;
using
TwoPassPipeline
=
ck_tile
::
SmoothquantPipelineTwoPass
<
PipelineProblem
>
;
...
...
example/ck_tile/12_smoothquant/smoothquant.cpp
View file @
b321bd86
...
...
@@ -34,6 +34,7 @@ auto create_args(int argc, char* argv[])
arg_parser
.
insert
(
"m"
,
"3328"
,
"m dimension"
)
.
insert
(
"n"
,
"4096"
,
"n dimension"
)
.
insert
(
"stride"
,
"-1"
,
"stride per row, if -1 then equal to n"
)
.
insert
(
"sx"
,
"1"
,
"0 is pure quantization, 1 is to apply smoothquant"
)
.
insert
(
"v"
,
"1"
,
"cpu validation or not"
)
.
insert
(
"kname"
,
"1"
,
"print kernel name or not"
)
.
insert
(
"prec"
,
"fp16"
,
"precision"
)
...
...
@@ -53,6 +54,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
if
(
stride
<
0
)
stride
=
n
;
std
::
string
data_type
=
arg_parser
.
get_str
(
"prec"
);
bool
smooth_x
=
arg_parser
.
get_bool
(
"sx"
);
int
kname
=
arg_parser
.
get_int
(
"kname"
);
int
do_validation
=
arg_parser
.
get_int
(
"v"
);
int
warmup
=
arg_parser
.
get_int
(
"warmup"
);
...
...
@@ -92,7 +94,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
std
::
cout
<<
"["
<<
data_type
<<
"]"
<<
" m:"
<<
m
<<
", n:"
<<
n
<<
", stride:"
<<
stride
<<
std
::
flush
;
smoothquant_traits
traits
{
data_type
};
smoothquant_traits
traits
{
data_type
,
smooth_x
};
smoothquant_args
args
{
x_buf
.
GetDeviceBuffer
(),
xscale_buf
.
GetDeviceBuffer
(),
...
...
@@ -125,7 +127,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
for
(
int
m_
=
0
;
m_
<
m
;
++
m_
)
{
auto
v_x
=
ck_tile
::
type_convert
<
ComputeDataType
>
(
x_host
(
m_
,
n_
));
if
(
smooth_x
)
y_host
(
m_
,
n_
)
=
v_x
*
v_xscale
;
else
y_host
(
m_
,
n_
)
=
v_x
;
}
};
...
...
example/ck_tile/12_smoothquant/smoothquant.hpp
View file @
b321bd86
...
...
@@ -44,7 +44,8 @@ template <typename DataType_,
ck_tile
::
index_t
ThreadPerBlock_N_
,
// num threads along N
ck_tile
::
index_t
Vector_N_
,
// vector size along N
bool
kPadN_
,
bool
kTwoPass_
>
bool
kTwoPass_
,
bool
kSmoothX_
>
struct
smoothquant_traits_
{
using
DataType
=
ck_tile
::
remove_cvref_t
<
DataType_
>
;
...
...
@@ -100,6 +101,7 @@ struct smoothquant_traits_
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
kTwoPass
=
kTwoPass_
;
static
constexpr
bool
kSmoothX
=
kSmoothX_
;
};
template
<
typename
Traits_
>
...
...
@@ -109,6 +111,7 @@ float smoothquant_(const ck_tile::stream_config& s, smoothquant_args a);
struct
smoothquant_traits
{
std
::
string
data_type
;
bool
smooth_x
;
};
float
smoothquant
(
smoothquant_traits
,
smoothquant_args
,
const
ck_tile
::
stream_config
&
);
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment