gaoqiong / composable_kernel

Commit 6ff00ed4, authored Jan 17, 2022 by Chao Liu
parent dcf48977

tweak example
Showing 1 changed file with 15 additions and 71 deletions

example/3_gemm_xdl_bias_relu_add/gemm_xdl_bias_relu_add.cpp  (+15 −71)
@@ -52,70 +52,13 @@ struct BiasReluAdd
     }
 };
 
-// v0 is from A * B
-// v1 is from C0
-// v2 is from C1
-struct BiasLeakyReluAdd
-{
-    template <typename T1, typename T2>
-    __host__ constexpr float operator()(float v0, T1 v1, T2 v2) const
-    {
-        float a = v0 + v1;
-        float b = 0.1 * a;
-        float c = b > 0 ? b : 0;
-        float d = c + v2;
-        return d;
-    }
-
-    template <typename T1, typename T2>
-    __device__ constexpr float operator()(float v0, T1 v1, T2 v2) const
-    {
-        constexpr float alpha     = 0.1;
-        constexpr float alpha_inv = 1.0 / alpha;
-
-        float a = v2 * alpha_inv;
-        float b = v1 + v0;
-        float c = max(b, float(0));
-        float d = alpha * (a + c);
-
-        return d;
-    }
-};
-
-struct BiasLeakyRelu
-{
-    template <typename T1, typename T2>
-    __host__ constexpr float operator()(float v0, T1 v1, T2) const
-    {
-        float a = v0 + v1;
-        float b = 0.1 * a;
-        float c = b > 0 ? b : 0;
-        return c;
-    }
-
-    template <typename T1, typename T2>
-    __device__ constexpr float operator()(float v0, T1 v1, T2) const
-    {
-        constexpr float alpha = 0.1;
-
-        float b = v1 + v0;
-        float c = max(b, float(0));
-        float d = alpha * c;
-
-        return d;
-    }
-};
-
 struct BiasAdd
 {
 #if 1
     // correct result
     // no scratch memory, good VGPR allocation (59)
     // good perf (101Tflops)
-    template <typename T1, typename T2>
-    __host__ __device__ constexpr float operator()(float v0, T1 v1, T2 v2) const
+    __host__ __device__ constexpr float operator()(float v0, ck::half_t v1, ck::half_t v2) const
     {
         constexpr float alpha = 0.1;
         constexpr float beta  = 0.2;
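The deleted device-side BiasLeakyReluAdd is worth a note: for alpha > 0, alpha * (v2 / alpha + max(v0 + v1, 0)) equals max(alpha * (v0 + v1), 0) + v2, so the device overload computed the same value as the host overload while ending in a single multiply by alpha that also folds in the v2 addend. A minimal standalone check of that identity (plain C++, not part of this file; the function names are ours):

#include <algorithm>
#include <cassert>
#include <cmath>

// host-style path: relu(0.1 * (v0 + v1)) + v2
float host_path(float v0, float v1, float v2)
{
    float b = 0.1f * (v0 + v1);
    return (b > 0 ? b : 0) + v2;
}

// device-style factoring: alpha * (v2 / alpha + max(v0 + v1, 0))
float device_path(float v0, float v1, float v2)
{
    constexpr float alpha     = 0.1f;
    constexpr float alpha_inv = 1.0f / alpha;
    return alpha * (v2 * alpha_inv + std::max(v1 + v0, 0.0f));
}

int main()
{
    // identical up to rounding, because alpha > 0 commutes with max()
    assert(std::fabs(host_path(2, 3, 4) - device_path(2, 3, 4)) < 1e-5f);
    assert(std::fabs(host_path(-2, -3, 4) - device_path(-2, -3, 4)) < 1e-5f);
}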
@@ -124,7 +67,7 @@ struct BiasAdd
         // compiler seems very volatile to the order of these calculation:
         // compiler is very eager to read AccVgpr (v0) out prematurely, resulting in register
         // over-allocation. Therefore, move v0 calculation to the very end
-        float a = T1(beta) * v1 + T2(gamma) * v2;
+        float a = ck::half_t(beta) * v1 + ck::half_t(gamma) * v2;
         float b = a + float(alpha) * v0;
         return b;
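With alpha = 0.1, beta = 0.2, and (judging by the reference branch further down) gamma = 0.3, both the old and new lines compute b = 0.1 * v0 + 0.2 * v1 + 0.3 * v2; the edit only pins the casts to ck::half_t now that the operator takes ck::half_t arguments instead of template parameters T1/T2.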
@@ -151,7 +94,7 @@ struct BiasAdd
     {
         return 0.1 * v0 + 0.2 * v1 + 0.3 * v2;
     }
-#elif 1
+#elif 0
     // wrong result
     // lots of scratch memory
     // huge perf drop
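Since the enclosing #if 1 is taken, the preprocessor never evaluates this #elif either way; flipping it from 1 to 0 just keeps the branch explicitly disabled if the leading condition is later switched back to 0 while experimenting.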
@@ -215,16 +158,15 @@ static void host_verify(const Tensor<AType>& a_m_k,
     auto f_mk_kn_mn = [&](auto m, auto n) {
         const int K = a_m_k.mDesc.GetLengths()[1];
 
-        double v = 0;
+        float acc = 0;
 
         for(int k = 0; k < K; ++k)
         {
-            v += static_cast<const double>(a_element_op(a_m_k(m, k))) *
-                 static_cast<const double>(b_element_op(b_k_n(k, n)));
+            acc += static_cast<const double>(a_element_op(a_m_k(m, k))) *
+                   static_cast<const double>(b_element_op(b_k_n(k, n)));
         }
 
-        c_m_n(m, n) = c_element_op(v,
-                                   static_cast<const double>(c0_m_n(m, n)),
-                                   static_cast<const double>(c1_m_n(m, n)));
+        c_m_n(m, n) = c_element_op(acc, c0_m_n(m, n), c1_m_n(m, n));
     };
 
     make_ParallelTensorFunctor(f_mk_kn_mn,
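The reference accumulator changes from double v to float acc, presumably so the host verification rounds the way the device GEMM's float accumulator does and the comparison tolerances line up. A standalone illustration (not from this file) of why the accumulator width matters:

#include <cstdio>

int main()
{
    float  acc_f = 0.0f;
    double acc_d = 0.0;
    for(long k = 0; k < (1L << 25); ++k)
    {
        acc_f += 1.0f; // stalls at 2^24 = 16777216: adding 1.0f no longer changes the value
        acc_d += 1.0;
    }
    std::printf("float acc: %.0f, double acc: %.0f\n", acc_f, acc_d);
    // prints: float acc: 16777216, double acc: 33554432
}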
@@ -249,9 +191,9 @@ int main(int argc, char* argv[])
     if(argc == 4)
     {
-        M = std::stoi(argv[4]);
-        N = std::stoi(argv[5]);
-        K = std::stoi(argv[6]);
+        do_verification = std::stoi(argv[1]);
+        init_method     = std::stoi(argv[2]);
+        nrepeat         = std::stoi(argv[3]);
     }
     else if(argc == 10)
     {
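Reading argv[4..6] under argc == 4 indexed past the argument vector; the fixed branch reads only the three run controls. A hypothetical invocation (the values and the binary name are examples, not from this commit):

./gemm_xdl_bias_relu_add 1 1 5   # do_verification init_method nrepeat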
@@ -337,7 +279,9 @@ int main(int argc, char* argv[])
     c0_m_n_device_buf.ToDevice(c0_m_n.mData.data());
     c1_m_n_device_buf.ToDevice(c1_m_n.mData.data());
 
-    auto c_element_op = BiasReluAdd{};
+    auto a_element_op = AOp{};
+    auto b_element_op = BOp{};
+    auto c_element_op = COp{};
 
     // do GEMM
     auto gemm = DeviceGemmInstance{};
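AOp, BOp, and COp are presumably aliases defined earlier in the file (outside this diff's context), with COp standing in for the BiasReluAdd epilogue previously constructed inline, while AOp/BOp replace the inline PassThrough{} arguments in the next hunk. For orientation, a minimal sketch of what an identity element-op such as PassThrough plausibly looks like (an illustration, not the library's actual definition):

// hypothetical sketch of an identity element-wise operator
struct PassThrough
{
    template <typename T>
    __host__ __device__ constexpr T operator()(T v) const
    {
        return v; // applied to each A/B element before the GEMM consumes it
    }
};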
@@ -354,8 +298,8 @@ int main(int argc, char* argv[])
                                       StrideA,
                                       StrideB,
                                       StrideC,
-                                      PassThrough{},
-                                      PassThrough{},
+                                      a_element_op,
+                                      b_element_op,
                                       c_element_op);
 
     if(!gemm.IsSupportedArgument(argument))