Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
08c9433e
Commit
08c9433e
authored
Dec 04, 2021
by
Chao Liu
Browse files
fix relu
parent
41cdd380
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
102 additions
and
16 deletions
+102
-16
example/1_gemm_xdl/gemm_xdl.cpp
example/1_gemm_xdl/gemm_xdl.cpp
+32
-11
example/2_gemm_xdl_bias_relu_add/gemm_xdl_bias_relu_add.cpp
example/2_gemm_xdl_bias_relu_add/gemm_xdl_bias_relu_add.cpp
+35
-3
example/4_conv_xdl_bias_relu_add/conv_xdl_bias_relu_add.cpp
example/4_conv_xdl_bias_relu_add/conv_xdl_bias_relu_add.cpp
+35
-2
No files found.
example/1_gemm_xdl/gemm_xdl.cpp
View file @
08c9433e
...
...
@@ -123,17 +123,9 @@ struct DeviceGemmInstance<float,
int
main
(
int
argc
,
char
*
argv
[])
{
if
(
argc
!=
4
)
{
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3: run kernel # of times (>1)
\n
"
);
exit
(
0
);
}
const
bool
do_verification
=
std
::
stoi
(
argv
[
1
]);
const
int
init_method
=
std
::
stoi
(
argv
[
2
]);
const
int
nrepeat
=
std
::
stoi
(
argv
[
3
]);
bool
do_verification
=
0
;
int
init_method
=
0
;
int
nrepeat
=
5
;
// GEMM shape
ck
::
index_t
M
=
3840
;
...
...
@@ -144,6 +136,35 @@ int main(int argc, char* argv[])
ck
::
index_t
StrideB
=
4096
;
ck
::
index_t
StrideC
=
4096
;
// Command-line overrides for the values initialised above.
//   3 arguments: verification flag, init method, repeat count
//                (M, N, K and the strides keep their initial values)
//   9 arguments: the three above followed by M, N, K, StrideA, StrideB, StrideC
// Any other argument count prints the usage text and exits.
// NOTE(review): running with no arguments also lands in the usage branch —
// confirm whether a bare invocation should run with the initial values instead.
if(argc == 4)
{
    // With argc == 4 only argv[1]..argv[3] hold arguments: argv[4] is the null
    // terminator and argv[5], argv[6] are past the end of the array, so the
    // shape must not be read from the command line in this branch.
    do_verification = std::stoi(argv[1]);
    init_method     = std::stoi(argv[2]);
    nrepeat         = std::stoi(argv[3]);
}
else if(argc == 10)
{
    do_verification = std::stoi(argv[1]);
    init_method     = std::stoi(argv[2]);
    nrepeat         = std::stoi(argv[3]);

    M = std::stoi(argv[4]);
    N = std::stoi(argv[5]);
    K = std::stoi(argv[6]);

    StrideA = std::stoi(argv[7]);
    StrideB = std::stoi(argv[8]);
    StrideC = std::stoi(argv[9]);
}
else
{
    printf("arg1: verification (0=no, 1=yes)\n");
    printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
    printf("arg3: run kernel # of times (>1)\n");
    printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
    exit(0);
}
// matrix data type
using
ADataType
=
ck
::
half_t
;
using
BDataType
=
ck
::
half_t
;
...
...
example/2_gemm_xdl_bias_relu_add/gemm_xdl_bias_relu_add.cpp
View file @
08c9433e
...
...
@@ -20,10 +20,42 @@
// 0 in the "n" dimension
// assume C1 and C have same layout C
// Element-wise epilogue: add a bias, apply ReLU, then add a second tensor.
//   result = max(v0 + v1, 0) + v2
// Both overloads take v0 as float and v1 / v2 as arbitrary types T1 / T2, and
// return float.
struct BiasReluAdd
{
    // Host overload: evaluates the formula step by step in float.
    template <typename T1, typename T2>
    __host__ constexpr float operator()(float v0, T1 v1, T2 v2) const
    {
        const float biased    = v0 + v1;
        const float rectified = biased > 0 ? biased : 0;

        return rectified + v2;
    }

    // Device overload: same formula expressed as a single select instead of
    // max(). The test v0 > -v1 is mathematically the same as v0 + v1 > 0; when
    // it holds the result is (v1 + v2) + v0, otherwise the ReLU term is zero
    // and only v2 remains. Note the sum is associated as (v1 + v2) + v0, with
    // v1 + v2 evaluated in the operand types before conversion to float, so
    // rounding may differ slightly from the host overload.
    // NOTE(review): presumably the select form replaces the max()-based one to
    // work around a device-side problem — confirm against the commit history.
    template <typename T1, typename T2>
    __device__ constexpr float operator()(float v0, T1 v1, T2 v2) const
    {
        const float bias_plus_residual = v1 + v2;

        return (v0 > -v1) ? bias_plus_residual + v0 : v2;
    }
};
// v0 is from A * B
// v1 is from C0
// v2 is from C1
struct
BiasReluAdd
struct
Bias
Leaky
ReluAdd
{
template
<
typename
T1
,
typename
T2
>
__host__
constexpr
float
operator
()(
float
v0
,
T1
v1
,
T2
v2
)
const
...
...
@@ -51,7 +83,7 @@ struct BiasReluAdd
}
};
struct
BiasRelu
struct
Bias
Leaky
Relu
{
template
<
typename
T1
,
typename
T2
>
__host__
constexpr
float
operator
()(
float
v0
,
T1
v1
,
T2
)
const
...
...
@@ -99,7 +131,7 @@ struct BiasAdd
}
#elif 0
float
alpha
=
0.1
;
float
beta
=
0.2
;
float
beta
=
0.2
;
float
gamma
=
0.3
;
// wrong result
...
...
example/4_conv_xdl_bias_relu_add/conv_xdl_bias_relu_add.cpp
View file @
08c9433e
...
...
@@ -23,7 +23,7 @@ struct PassThrough
}
};
struct
BiasReluAdd
struct
Bias
Leaky
ReluAdd
{
template
<
typename
T1
,
typename
T2
>
__host__
constexpr
float
operator
()(
float
v0
,
T1
v1
,
T2
v2
)
const
...
...
@@ -97,7 +97,39 @@ struct BiasReluAdd
}
};
struct
BiasRelu
// Element-wise epilogue: add a bias, apply ReLU, then add a second tensor.
//   result = max(v0 + v1, 0) + v2
// Both overloads take v0 as float and v1 / v2 as arbitrary types T1 / T2, and
// return float.
struct BiasReluAdd
{
    // Host overload: evaluates the formula step by step in float.
    template <typename T1, typename T2>
    __host__ constexpr float operator()(float v0, T1 v1, T2 v2) const
    {
        const float biased    = v0 + v1;
        const float rectified = biased > 0 ? biased : 0;

        return rectified + v2;
    }

    // Device overload: same formula expressed as a single select instead of
    // max(). The test v0 > -v1 is mathematically the same as v0 + v1 > 0; when
    // it holds the result is (v1 + v2) + v0, otherwise the ReLU term is zero
    // and only v2 remains. Note the sum is associated as (v1 + v2) + v0, with
    // v1 + v2 evaluated in the operand types before conversion to float, so
    // rounding may differ slightly from the host overload.
    // NOTE(review): presumably the select form replaces the max()-based one to
    // work around a device-side problem — confirm against the commit history.
    template <typename T1, typename T2>
    __device__ constexpr float operator()(float v0, T1 v1, T2 v2) const
    {
        const float bias_plus_residual = v1 + v2;

        return (v0 > -v1) ? bias_plus_residual + v0 : v2;
    }
};
struct
BiasLeakyRelu
{
template
<
typename
T1
,
typename
T2
>
__host__
constexpr
float
operator
()(
float
v0
,
T1
v1
,
T2
)
const
...
...
@@ -377,6 +409,7 @@ int main(int argc, char* argv[])
std
::
size_t
num_btype
=
sizeof
(
InDataType
)
*
(
N
*
C
*
Hi
*
Wi
)
+
sizeof
(
WeiDataType
)
*
(
K
*
C
*
Y
*
X
)
+
sizeof
(
OutDataType
)
*
(
N
*
K
*
Ho
*
Wo
)
+
sizeof
(
OutDataType
)
*
(
K
)
+
sizeof
(
OutDataType
)
*
(
N
*
K
*
Ho
*
Wo
);
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment