Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
00b1a7e2
Commit
00b1a7e2
authored
Oct 30, 2024
by
rocking
Browse files
Add smoothquant instance library
parent
d6b0e59e
Changes
28
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
572 additions
and
3 deletions
+572
-3
example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n64_n128_instance.cpp
...othquant/instances/smoothquant_fp16_n64_n128_instance.cpp
+12
-0
example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n768_instance.cpp
..._smoothquant/instances/smoothquant_fp16_n768_instance.cpp
+12
-0
example/ck_tile/12_smoothquant/instances/smoothquant_fwd_api.cpp
.../ck_tile/12_smoothquant/instances/smoothquant_fwd_api.cpp
+150
-0
example/ck_tile/12_smoothquant/instances/smoothquant_instance_common.hpp
.../12_smoothquant/instances/smoothquant_instance_common.hpp
+62
-0
example/ck_tile/12_smoothquant/smoothquant.cpp
example/ck_tile/12_smoothquant/smoothquant.cpp
+219
-0
example/ck_tile/12_smoothquant/smoothquant.hpp
example/ck_tile/12_smoothquant/smoothquant.hpp
+114
-0
include/ck_tile/ops/smoothquant/kernel/smoothquant_shape.hpp
include/ck_tile/ops/smoothquant/kernel/smoothquant_shape.hpp
+1
-1
include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_problem.hpp
...ops/smoothquant/pipeline/smoothquant_pipeline_problem.hpp
+2
-2
No files found.
example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n64_n128_instance.cpp
0 → 100644
View file @
00b1a7e2
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
template
float
smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
1
,
4
,
64
,
1
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
1
,
4
,
64
,
2
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
2
,
4
,
64
,
1
,
true
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n768_instance.cpp
0 → 100644
View file @
00b1a7e2
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
template
float
smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
3
,
4
,
64
,
4
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
6
,
4
,
64
,
2
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
12
,
4
,
64
,
1
,
true
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/12_smoothquant/instances/smoothquant_fwd_api.cpp
0 → 100644
View file @
00b1a7e2
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "smoothquant.hpp"
template
<
typename
DataType_
,
ck_tile
::
index_t
Repeat_M_
,
// each thread repeat along M
ck_tile
::
index_t
Repeat_N_
,
// each thread repeat along N
ck_tile
::
index_t
ThreadPerBlock_M_
,
// num threads along M
ck_tile
::
index_t
ThreadPerBlock_N_
,
// num threads along N
ck_tile
::
index_t
Vector_N_
,
// vector size along N
bool
kPadN_
,
bool
kTwoPass_
>
using
trait_
=
smoothquant_traits_
<
DataType_
,
Repeat_M_
,
Repeat_N_
,
ThreadPerBlock_M_
,
ThreadPerBlock_N_
,
Vector_N_
,
kPadN_
,
kTwoPass_
>
;
template
<
typename
data_type
>
float
smoothquant_dispatch
(
smoothquant_traits
/*t*/
,
smoothquant_args
a
,
const
ck_tile
::
stream_config
&
s
)
{
#if 1
float
r
=
-
1
;
// clang-format off
// rm rn tm tn vn pd 2p
if
(
a
.
n
<=
64
)
{
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
1
,
4
,
64
,
1
,
true
,
false
>>
(
s
,
a
);
}
else
if
(
a
.
n
<=
128
)
{
if
(
a
.
n
%
2
==
0
)
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
1
,
4
,
64
,
2
,
true
,
false
>>
(
s
,
a
);
else
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
2
,
4
,
64
,
1
,
true
,
false
>>
(
s
,
a
);
}
else
if
(
a
.
n
<=
256
)
{
if
(
a
.
n
%
4
==
0
)
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
1
,
4
,
64
,
4
,
true
,
false
>>
(
s
,
a
);
else
if
(
a
.
n
%
2
==
0
)
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
2
,
4
,
64
,
2
,
true
,
false
>>
(
s
,
a
);
else
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
4
,
4
,
64
,
1
,
true
,
false
>>
(
s
,
a
);
}
else
if
(
a
.
n
<=
512
)
{
/*if (a.n % 8 == 0)
r = smoothquant_<trait_<data_type, 1, 1, 4, 64, 8, true, false>>(s, a);
else */
if
(
a
.
n
%
4
==
0
)
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
2
,
4
,
64
,
4
,
true
,
false
>>
(
s
,
a
);
else
if
(
a
.
n
%
2
==
0
)
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
4
,
4
,
64
,
2
,
true
,
false
>>
(
s
,
a
);
else
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
8
,
4
,
64
,
1
,
true
,
false
>>
(
s
,
a
);
}
else
if
(
a
.
n
<=
768
)
{
if
(
a
.
n
%
4
==
0
)
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
3
,
4
,
64
,
4
,
true
,
false
>>
(
s
,
a
);
else
if
(
a
.
n
%
2
==
0
)
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
6
,
4
,
64
,
2
,
true
,
false
>>
(
s
,
a
);
else
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
12
,
4
,
64
,
1
,
true
,
false
>>
(
s
,
a
);
}
else
if
(
a
.
n
<=
1024
)
{
/*if (a.n % 8 == 0)
r = smoothquant_<trait_<data_type, 1, 1, 2, 128, 8, true, false>>(s, a);
else */
if
(
a
.
n
%
4
==
0
)
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
2
,
2
,
128
,
4
,
true
,
false
>>
(
s
,
a
);
else
if
(
a
.
n
%
2
==
0
)
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
4
,
2
,
128
,
2
,
true
,
false
>>
(
s
,
a
);
else
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
4
,
1
,
256
,
1
,
true
,
false
>>
(
s
,
a
);
}
else
if
(
a
.
n
<=
1536
)
{
/*if (a.n % 8 == 0)
r = smoothquant_<trait_<data_type, 1, 3, 4, 64, 8, true, false>>(s, a);
else */
if
(
a
.
n
%
4
==
0
)
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
3
,
2
,
128
,
4
,
true
,
false
>>
(
s
,
a
);
else
if
(
a
.
n
%
2
==
0
)
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
3
,
1
,
256
,
2
,
true
,
false
>>
(
s
,
a
);
else
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
6
,
1
,
256
,
1
,
true
,
false
>>
(
s
,
a
);
}
else
if
(
a
.
n
<=
2048
)
{
/*if (a.n % 8 == 0)
r = smoothquant_<trait_<data_type, 1, 1, 1, 256, 8, true, false>>(s, a);
else */
if
(
a
.
n
%
4
==
0
)
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
2
,
1
,
256
,
4
,
true
,
false
>>
(
s
,
a
);
else
if
(
a
.
n
%
2
==
0
)
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
4
,
1
,
256
,
2
,
true
,
false
>>
(
s
,
a
);
else
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
8
,
1
,
256
,
1
,
true
,
false
>>
(
s
,
a
);
}
else
if
(
a
.
n
<=
3072
)
{
/*if (a.n % 8 == 0)
r = smoothquant_<trait_<data_type, 1, 3, 1, 128, 8, true, false>>(s, a);
else */
if
(
a
.
n
%
4
==
0
)
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
3
,
1
,
256
,
4
,
true
,
false
>>
(
s
,
a
);
else
if
(
a
.
n
%
2
==
0
)
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
6
,
1
,
256
,
2
,
true
,
false
>>
(
s
,
a
);
else
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
3
,
1
,
1024
,
1
,
true
,
false
>>
(
s
,
a
);
}
else
if
(
a
.
n
<=
4096
)
{
/*if (a.n % 8 == 0)
r = smoothquant_<trait_<data_type, 1, 2, 1, 256, 8, true, false>>(s, a);
else */
if
(
a
.
n
%
4
==
0
)
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
4
,
1
,
256
,
4
,
true
,
false
>>
(
s
,
a
);
else
if
(
a
.
n
%
2
==
0
)
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
2
,
1
,
1024
,
2
,
true
,
false
>>
(
s
,
a
);
else
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
4
,
1
,
1024
,
1
,
true
,
false
>>
(
s
,
a
);
}
else
if
(
a
.
n
>
4096
)
{
/*if (a.n % 8 == 0)
r = smoothquant_<trait_<data_type, 1, 2, 1, 256, 8, true, true>>(s, a);
else */
if
(
a
.
n
%
4
==
0
)
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
4
,
1
,
256
,
4
,
true
,
true
>>
(
s
,
a
);
else
if
(
a
.
n
%
2
==
0
)
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
2
,
1
,
1024
,
2
,
true
,
true
>>
(
s
,
a
);
else
r
=
smoothquant_
<
trait_
<
data_type
,
1
,
4
,
1
,
1024
,
1
,
true
,
true
>>
(
s
,
a
);
}
return
r
;
#else
return
smoothquant_
<
trait_
<
data_type
,
1
,
1
,
4
,
64
,
1
,
true
,
false
>>
(
s
,
a
);
#endif
// clang-format on
}
float
smoothquant
(
smoothquant_traits
t
,
smoothquant_args
a
,
const
ck_tile
::
stream_config
&
s
)
{
float
r
=
-
1
;
if
(
t
.
data_type
.
compare
(
"fp16"
)
==
0
)
{
return
smoothquant_dispatch
<
ck_tile
::
fp16_t
>
(
t
,
a
,
s
);
}
else
if
(
t
.
data_type
.
compare
(
"bf16"
)
==
0
)
{
return
smoothquant_dispatch
<
ck_tile
::
bf16_t
>
(
t
,
a
,
s
);
}
if
(
r
<
0
)
throw
std
::
runtime_error
(
"Without supported instances!"
);
return
r
;
}
example/ck_tile/12_smoothquant/instances/smoothquant_instance_common.hpp
0 → 100644
View file @
00b1a7e2
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "smoothquant.hpp"
#include <iostream>
#pragma once
using
S
=
ck_tile
::
stream_config
;
using
A
=
smoothquant_args
;
template
<
typename
DataType_
,
ck_tile
::
index_t
Repeat_M_
,
// each thread repeat along M
ck_tile
::
index_t
Repeat_N_
,
// each thread repeat along N
ck_tile
::
index_t
ThreadPerBlock_M_
,
// num threads along M
ck_tile
::
index_t
ThreadPerBlock_N_
,
// num threads along N
ck_tile
::
index_t
Vector_N_
,
// vector size along N
bool
kPadN_
,
bool
kTwoPass_
>
using
trait_
=
smoothquant_traits_
<
DataType_
,
Repeat_M_
,
Repeat_N_
,
ThreadPerBlock_M_
,
ThreadPerBlock_N_
,
Vector_N_
,
kPadN_
,
kTwoPass_
>
;
template
<
typename
Traits_
>
float
smoothquant_
(
const
S
&
s
,
A
a
)
{
using
DataType
=
typename
Traits_
::
DataType
;
using
PipelineProblem
=
ck_tile
::
SmoothquantPipelineProblem
<
typename
SmoothquantTypeConfig
<
DataType
>::
XDataType
,
typename
SmoothquantTypeConfig
<
DataType
>::
XScaleDataType
,
typename
SmoothquantTypeConfig
<
DataType
>::
ComputeDataType
,
typename
SmoothquantTypeConfig
<
DataType
>::
YScaleDataType
,
typename
SmoothquantTypeConfig
<
DataType
>::
QYDataType
,
typename
Traits_
::
Shape
,
Traits_
::
kPadN
,
Traits_
::
kTwoPass
>
;
using
OnePassPipeline
=
ck_tile
::
SmoothquantPipelineOnePass
<
PipelineProblem
>
;
using
TwoPassPipeline
=
ck_tile
::
SmoothquantPipelineTwoPass
<
PipelineProblem
>
;
using
Pipeline
=
std
::
conditional_t
<
Traits_
::
kTwoPass
,
TwoPassPipeline
,
OnePassPipeline
>
;
using
Kernel
=
ck_tile
::
Smoothquant
<
Pipeline
>
;
const
dim3
grids
=
Kernel
::
GridSize
(
a
);
constexpr
dim3
blocks
=
Kernel
::
BlockSize
();
constexpr
ck_tile
::
index_t
kBlockPerCu
=
1
;
auto
kargs
=
Kernel
::
MakeKargs
(
a
);
if
(
s
.
log_level_
>
0
)
std
::
cout
<<
", "
<<
Kernel
::
GetName
()
<<
std
::
flush
;
return
ck_tile
::
launch_kernel
(
s
,
ck_tile
::
make_kernel
<
blocks
.
x
,
kBlockPerCu
>
(
Kernel
{},
grids
,
blocks
,
0
,
kargs
));
}
example/ck_tile/12_smoothquant/smoothquant.cpp
0 → 100644
View file @
00b1a7e2
#include "ck_tile/host.hpp"
#include "smoothquant.hpp"
#include <cstring>
// different threshold for different dtype
template
<
typename
DataType
>
auto
get_elimit
()
{
double
rtol
=
1e-5
;
double
atol
=
1e-5
;
return
ck_tile
::
make_tuple
(
rtol
,
atol
);
}
template
<
>
auto
get_elimit
<
ck_tile
::
bf16_t
>
()
{
double
rtol
=
1e-5
;
double
atol
=
1e-5
;
return
ck_tile
::
make_tuple
(
rtol
,
atol
);
}
template
<
>
auto
get_elimit
<
ck_tile
::
int8_t
>
()
{
// due to rounding, int8 quantization might have 1 abs error
double
rtol
=
1
;
double
atol
=
1
;
return
ck_tile
::
make_tuple
(
rtol
,
atol
);
}
auto
create_args
(
int
argc
,
char
*
argv
[])
{
ck_tile
::
ArgParser
arg_parser
;
arg_parser
.
insert
(
"m"
,
"3328"
,
"m dimension"
)
.
insert
(
"n"
,
"4096"
,
"n dimension"
)
.
insert
(
"stride"
,
"-1"
,
"stride per row, if -1 then equal to n"
)
.
insert
(
"e"
,
"1e-5"
,
"epsilon"
)
.
insert
(
"v"
,
"1"
,
"cpu validation or not"
)
.
insert
(
"kname"
,
"1"
,
"print kernel name or not"
)
.
insert
(
"prec"
,
"fp16"
,
"precision"
)
.
insert
(
"warmup"
,
"0"
,
"cold iter"
)
.
insert
(
"repeat"
,
"1"
,
"hot iter"
);
bool
result
=
arg_parser
.
parse
(
argc
,
argv
);
return
std
::
make_tuple
(
result
,
arg_parser
);
}
template
<
typename
DataType
>
bool
run
(
const
ck_tile
::
ArgParser
&
arg_parser
)
{
ck_tile
::
index_t
m
=
arg_parser
.
get_int
(
"m"
);
ck_tile
::
index_t
n
=
arg_parser
.
get_int
(
"n"
);
ck_tile
::
index_t
stride
=
arg_parser
.
get_int
(
"stride"
);
if
(
stride
<
0
)
stride
=
n
;
std
::
string
data_type
=
arg_parser
.
get_str
(
"prec"
);
int
kname
=
arg_parser
.
get_int
(
"kname"
);
int
do_validation
=
arg_parser
.
get_int
(
"v"
);
int
warmup
=
arg_parser
.
get_int
(
"warmup"
);
int
repeat
=
arg_parser
.
get_int
(
"repeat"
);
assert
(
stride
>=
n
);
using
TypeConfig
=
SmoothquantTypeConfig
<
DataType
>
;
using
XDataType
=
typename
TypeConfig
::
XDataType
;
using
XScaleDataType
=
typename
TypeConfig
::
XScaleDataType
;
using
YScaleDataType
=
typename
TypeConfig
::
YScaleDataType
;
using
QYDataType
=
typename
TypeConfig
::
QYDataType
;
using
ComputeDataType
=
typename
TypeConfig
::
ComputeDataType
;
// host verify
ck_tile
::
HostTensor
<
XDataType
>
x_host
({
m
,
n
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
XScaleDataType
>
xscale_host
({
n
});
ck_tile
::
HostTensor
<
YScaleDataType
>
yscale_host_ref
({
m
},
{
1
});
ck_tile
::
HostTensor
<
YScaleDataType
>
yscale_host_dev
({
m
},
{
1
});
ck_tile
::
HostTensor
<
QYDataType
>
qy_host_ref
({
m
,
n
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
QYDataType
>
qy_host_dev
({
m
,
n
},
{
stride
,
1
});
ck_tile
::
FillUniformDistribution
<
XDataType
>
{
-
.5
f
,
.5
f
}(
x_host
);
ck_tile
::
FillUniformDistribution
<
XScaleDataType
>
{
1e-3
,
.5
f
}(
xscale_host
);
ck_tile
::
DeviceMem
x_buf
(
x_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
xscale_buf
(
xscale_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
yscale_buf
(
yscale_host_dev
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
qy_buf
(
qy_host_dev
.
get_element_space_size_in_bytes
());
x_buf
.
ToDevice
(
x_host
.
data
());
xscale_buf
.
ToDevice
(
xscale_host
.
data
());
std
::
cout
<<
"["
<<
data_type
<<
"]"
<<
" m:"
<<
m
<<
", n:"
<<
n
<<
", stride:"
<<
stride
<<
std
::
flush
;
smoothquant_traits
traits
{
data_type
};
smoothquant_args
args
{
x_buf
.
GetDeviceBuffer
(),
xscale_buf
.
GetDeviceBuffer
(),
yscale_buf
.
GetDeviceBuffer
(),
qy_buf
.
GetDeviceBuffer
(),
m
,
n
,
stride
};
float
ave_time
=
smoothquant
(
traits
,
args
,
ck_tile
::
stream_config
{
nullptr
,
true
,
kname
?
1
:
0
,
warmup
,
repeat
});
std
::
size_t
num_byte
=
sizeof
(
XDataType
)
*
m
*
n
+
sizeof
(
XScaleDataType
)
*
n
+
sizeof
(
YScaleDataType
)
*
m
+
sizeof
(
QYDataType
)
*
m
*
n
;
float
gb_per_sec
=
num_byte
/
1.E6
/
ave_time
;
std
::
cout
<<
", "
<<
ave_time
*
1.E3
<<
" us, "
<<
gb_per_sec
<<
" GB/s"
<<
std
::
flush
;
bool
pass
=
true
;
if
(
do_validation
)
{
using
YDataType
=
ComputeDataType
;
ck_tile
::
HostTensor
<
ComputeDataType
>
y_host
({
m
,
n
},
{
stride
,
1
});
// smooth outlier
{
auto
f
=
[
&
](
auto
n_
)
{
auto
v_xscale
=
ck_tile
::
type_convert
<
ComputeDataType
>
(
xscale_host
(
n_
));
for
(
int
m_
=
0
;
m_
<
m
;
++
m_
)
{
auto
v_x
=
ck_tile
::
type_convert
<
ComputeDataType
>
(
x_host
(
m_
,
n_
));
y_host
(
m_
,
n_
)
=
v_x
*
v_xscale
;
}
};
ck_tile
::
make_ParallelTensorFunctor
(
f
,
xscale_host
.
get_element_space_size
())(
std
::
thread
::
hardware_concurrency
());
}
// yscale
{
ck_tile
::
HostTensor
<
YDataType
>
y_rowwise_amax_host
({
m
});
using
ReduceAmax
=
ck_tile
::
ReduceOp
::
AbsMax
;
ck_tile
::
reference_reduce
<
ComputeDataType
,
ComputeDataType
,
YDataType
>
(
y_host
,
y_rowwise_amax_host
,
ReduceAmax
{});
auto
op
=
[](
const
auto
&
v0
)
{
return
v0
/
ck_tile
::
type_convert
<
ComputeDataType
>
(
ck_tile
::
numeric
<
QYDataType
>::
max
());
};
ck_tile
::
reference_unary_elementwise
<
YDataType
,
YScaleDataType
,
ComputeDataType
>
(
y_rowwise_amax_host
,
yscale_host_ref
,
op
);
yscale_buf
.
FromDevice
(
yscale_host_dev
.
mData
.
data
());
auto
[
rtol
,
atol
]
=
get_elimit
<
YScaleDataType
>
();
pass
&=
ck_tile
::
check_err
(
yscale_host_dev
,
yscale_host_ref
,
std
::
string
(
"yscale Error: Incorrect results!"
),
rtol
,
atol
);
}
// rowwise quantization
{
ck_tile
::
reference_rowwise_quantization2d
<
YDataType
,
YScaleDataType
,
QYDataType
>
(
y_host
,
yscale_host_ref
,
qy_host_ref
);
qy_buf
.
FromDevice
(
qy_host_dev
.
data
());
auto
[
rtol
,
atol
]
=
get_elimit
<
QYDataType
>
();
if
(
stride
==
n
)
{
pass
=
ck_tile
::
check_err
(
qy_host_dev
,
qy_host_ref
,
std
::
string
(
"qy Error: Incorrect results!"
),
rtol
,
atol
);
}
else
{
for
(
int
i_r
=
0
;
i_r
<
m
;
i_r
++
)
{
std
::
vector
<
QYDataType
>
qy_host_dev_row
(
qy_host_dev
.
begin
()
+
i_r
*
stride
,
qy_host_dev
.
begin
()
+
i_r
*
stride
+
n
);
std
::
vector
<
QYDataType
>
qy_host_ref_row
(
qy_host_ref
.
begin
()
+
i_r
*
stride
,
qy_host_ref
.
begin
()
+
i_r
*
stride
+
n
);
pass
&=
ck_tile
::
check_err
(
qy_host_dev_row
,
qy_host_ref_row
,
std
::
string
(
"qy["
)
+
std
::
to_string
(
i_r
)
+
std
::
string
(
"] Error: Incorrect results!"
),
rtol
,
atol
);
}
}
}
std
::
cout
<<
", valid:"
<<
(
pass
?
"y"
:
"n"
)
<<
std
::
flush
<<
std
::
endl
;
}
return
pass
;
}
int
main
(
int
argc
,
char
*
argv
[])
{
auto
[
result
,
arg_parser
]
=
create_args
(
argc
,
argv
);
if
(
!
result
)
return
-
1
;
const
std
::
string
data_type
=
arg_parser
.
get_str
(
"prec"
);
if
(
data_type
==
"fp16"
)
{
return
run
<
ck_tile
::
half_t
>
(
arg_parser
)
?
0
:
-
2
;
}
else
if
(
data_type
==
"bf16"
)
{
return
run
<
ck_tile
::
bf16_t
>
(
arg_parser
)
?
0
:
-
2
;
}
return
-
3
;
}
example/ck_tile/12_smoothquant/smoothquant.hpp
0 → 100644
View file @
00b1a7e2
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/smoothquant.hpp"
#include <string>
template
<
typename
DataType
>
struct
SmoothquantTypeConfig
;
template
<
>
struct
SmoothquantTypeConfig
<
ck_tile
::
half_t
>
{
using
XDataType
=
ck_tile
::
half_t
;
using
XScaleDataType
=
float
;
using
YScaleDataType
=
float
;
using
QYDataType
=
ck_tile
::
int8_t
;
using
ComputeDataType
=
float
;
};
template
<
>
struct
SmoothquantTypeConfig
<
ck_tile
::
bf16_t
>
{
using
XDataType
=
ck_tile
::
bf16_t
;
using
XScaleDataType
=
float
;
using
YScaleDataType
=
float
;
using
QYDataType
=
ck_tile
::
int8_t
;
using
ComputeDataType
=
float
;
};
// runtime args
struct
smoothquant_args
:
public
ck_tile
::
SmoothquantHostArgs
{
};
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
template
<
typename
DataType_
,
ck_tile
::
index_t
Repeat_M_
,
// each thread repeat along M
ck_tile
::
index_t
Repeat_N_
,
// each thread repeat along N
ck_tile
::
index_t
ThreadPerBlock_M_
,
// num threads along M
ck_tile
::
index_t
ThreadPerBlock_N_
,
// num threads along N
ck_tile
::
index_t
Vector_N_
,
// vector size along N
bool
kPadN_
,
bool
kTwoPass_
>
struct
smoothquant_traits_
{
using
DataType
=
ck_tile
::
remove_cvref_t
<
DataType_
>
;
static
constexpr
bool
is_warp_per_row
=
ThreadPerBlock_N_
<=
warpSize
;
static_assert
((
ThreadPerBlock_M_
*
ThreadPerBlock_N_
)
%
warpSize
==
0
);
static
constexpr
ck_tile
::
index_t
total_warps
=
(
ThreadPerBlock_M_
*
ThreadPerBlock_N_
)
/
warpSize
;
// num of warps along m
static
constexpr
ck_tile
::
index_t
BlockWarps_M
=
[]()
{
if
constexpr
(
is_warp_per_row
)
{
static_assert
(
warpSize
%
ThreadPerBlock_N_
==
0
);
return
total_warps
*
(
warpSize
/
ThreadPerBlock_N_
);
}
else
{
// static_assert(warpSize % ThreadPerBlock_M_ == 0);
return
total_warps
/
(
ThreadPerBlock_N_
/
warpSize
);
}
}();
// num of warps along n
static
constexpr
ck_tile
::
index_t
BlockWarps_N
=
[]()
{
if
constexpr
(
is_warp_per_row
)
{
static_assert
(
warpSize
%
ThreadPerBlock_N_
==
0
);
return
1
;
}
else
{
static_assert
(
ThreadPerBlock_N_
%
warpSize
==
0
);
return
ThreadPerBlock_N_
/
warpSize
;
}
}();
static
constexpr
ck_tile
::
index_t
Repeat_M
=
Repeat_M_
;
static
constexpr
ck_tile
::
index_t
Repeat_N
=
Repeat_N_
;
static
constexpr
ck_tile
::
index_t
Block_M
=
Repeat_M_
*
ThreadPerBlock_M_
;
static
constexpr
ck_tile
::
index_t
Block_N
=
Repeat_N_
*
ThreadPerBlock_N_
*
Vector_N_
;
static
constexpr
ck_tile
::
index_t
Warp_M
=
ThreadPerBlock_M_
/
BlockWarps_M
;
static
constexpr
ck_tile
::
index_t
Warp_N
=
ThreadPerBlock_N_
/
BlockWarps_N
*
Vector_N_
;
using
BlockTile
=
ck_tile
::
sequence
<
Block_M
,
Block_N
>
;
using
BlockWarps
=
ck_tile
::
sequence
<
BlockWarps_M
,
BlockWarps_N
>
;
using
WarpTile
=
ck_tile
::
sequence
<
Warp_M
,
Warp_N
>
;
using
Vector
=
ck_tile
::
sequence
<
1
,
Vector_N_
>
;
using
Shape
=
ck_tile
::
SmoothquantShape
<
BlockTile
,
BlockWarps
,
WarpTile
,
Vector
>
;
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
kTwoPass
=
kTwoPass_
;
};
template
<
typename
Traits_
>
float
smoothquant_
(
const
ck_tile
::
stream_config
&
s
,
smoothquant_args
a
);
// This is the public API, will be generated by script
struct
smoothquant_traits
{
std
::
string
data_type
;
};
float
smoothquant
(
smoothquant_traits
,
smoothquant_args
,
const
ck_tile
::
stream_config
&
);
include/ck_tile/ops/smoothquant/kernel/smoothquant_shape.hpp
View file @
00b1a7e2
...
@@ -41,7 +41,7 @@ template <typename BlockTile_, // block size, seq<M, N>
...
@@ -41,7 +41,7 @@ template <typename BlockTile_, // block size, seq<M, N>
typename
WarpTile_
,
// warp size, seq<M, N>
typename
WarpTile_
,
// warp size, seq<M, N>
typename
Vector_
,
// contiguous pixels(vector size) along seq<M, N>
typename
Vector_
,
// contiguous pixels(vector size) along seq<M, N>
index_t
BlockSize_
=
index_t
BlockSize_
=
warpSize
*
reduce_on_sequence
(
WarpPerBlock_
{}
,
multiplies
{}
,
number
<
1
>{})
>
warpSize
*
reduce_on_sequence
(
WarpPerBlock_
{}
,
multiplies
{}
,
number
<
1
>{})
>
struct
SmoothquantShape
struct
SmoothquantShape
{
{
// block size
// block size
...
...
include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_problem.hpp
View file @
00b1a7e2
...
@@ -28,8 +28,8 @@ struct SmoothquantPipelineProblem
...
@@ -28,8 +28,8 @@ struct SmoothquantPipelineProblem
static
constexpr
bool
kNeedCrossLaneSync
=
BlockShape
::
ThreadPerWarp_N
>
1
;
static
constexpr
bool
kNeedCrossLaneSync
=
BlockShape
::
ThreadPerWarp_N
>
1
;
static
constexpr
bool
kNeedCrossWarpSync
=
BlockShape
::
WarpPerBlock_N
>
1
;
static
constexpr
bool
kNeedCrossWarpSync
=
BlockShape
::
WarpPerBlock_N
>
1
;
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
kTwoPass
=
kTwoPass_
;
static
constexpr
bool
kTwoPass
=
kTwoPass_
;
};
};
}
// namespace ck_tile
}
// namespace ck_tile
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment