Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
788f4786
Commit
788f4786
authored
Dec 02, 2021
by
Chao Liu
Browse files
tweak
parent
5a79ff1e
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
112 additions
and
42 deletions
+112
-42
example/4_conv_xdl_bias_relu_add/conv_xdl_bias_relu_add.cpp
example/4_conv_xdl_bias_relu_add/conv_xdl_bias_relu_add.cpp
+112
-42
No files found.
example/4_conv_xdl_bias_relu_add/conv_xdl_bias_relu_add.cpp
View file @
788f4786
...
@@ -26,14 +26,74 @@ struct PassThrough
...
@@ -26,14 +26,74 @@ struct PassThrough
struct
BiasReluAdd
struct
BiasReluAdd
{
{
template
<
typename
T1
,
typename
T2
>
template
<
typename
T1
,
typename
T2
>
__host__
__device__
constexpr
float
operator
()(
float
v0
,
T1
v1
,
T2
v2
)
const
__host__
constexpr
float
operator
()(
float
v0
,
T1
v1
,
T2
v2
)
const
{
{
float
a
=
v0
+
v1
;
float
b
=
0.1
*
a
;
float
c
=
b
>
0
?
b
:
0
;
float
d
=
c
+
v2
;
return
d
;
}
template
<
typename
T1
,
typename
T2
>
__device__
constexpr
float
operator
()(
float
v0
,
T1
v1
,
T2
v2
)
const
{
#if 0
// this use not too many registers, but use fp64 mul
float a = v0 + v1;
float b = 0.1 * a;
float c = b > 0 ? b : 0;
float d = c + v2;
return d;
#elif
0
// this spill register
float
a
=
v0
+
v1
;
float
a
=
v0
+
v1
;
float
b
=
float
(
0.1
)
*
a
;
float
b
=
float
(
0.1
)
*
a
;
float
c
=
b
>
0
?
b
:
0
;
float
c
=
b
>
0
?
b
:
0
;
float
d
=
c
+
v2
;
float
d
=
c
+
v2
;
return
d
;
return
d
;
#elif 0
// this use lots of registers (but no spill)
constexpr
float
alpha
=
0.1
;
constexpr
float
alpha_inv
=
1.0
/
alpha
;
float
a
=
v2
*
alpha_inv
;
float
b
=
v1
+
v0
;
float
c
=
b
>
0
?
b
:
0
;
float
d
=
alpha
*
(
a
+
c
);
return
d
;
#elif 1
// this use lots of registers (but no spill), 89 Tflops
constexpr
float
alpha
=
0.1
;
constexpr
float
alpha_inv
=
1.0
/
alpha
;
float
a
=
v2
*
alpha_inv
;
float
b
=
v1
+
v0
;
float
c
=
max
(
b
,
float
(
0
));
float
d
=
alpha
*
(
a
+
c
);
return
d
;
#elif 1
// this spill registers, 89 Tflops
float
a
=
v0
+
v1
;
float
alpha
=
0.1
;
float
b
;
asm
volatile
(
"
\n
\
v_mul_f32_e32 %0, %1, %2
\n
\
"
:
"=v"
(
b
)
:
"s"
(
alpha
),
"v"
(
a
));
float
c
=
b
>
0
?
b
:
0
;
float
d
=
c
+
v2
;
return
d
;
#endif
}
}
};
};
...
@@ -114,53 +174,63 @@ void host_reference_calculation(const Tensor<TIn>& in_n_c_hi_wi,
...
@@ -114,53 +174,63 @@ void host_reference_calculation(const Tensor<TIn>& in_n_c_hi_wi,
int
main
(
int
argc
,
char
*
argv
[])
int
main
(
int
argc
,
char
*
argv
[])
{
{
// if(argc != 4)
bool
do_verification
=
0
;
int
init_method
=
0
;
int
nrepeat
=
5
;
// Conv shape
ck
::
index_t
N
=
128
;
ck
::
index_t
K
=
256
;
ck
::
index_t
C
=
192
;
ck
::
index_t
Y
=
3
;
ck
::
index_t
X
=
3
;
ck
::
index_t
Hi
=
71
;
ck
::
index_t
Wi
=
71
;
ck
::
index_t
conv_stride_h
=
2
;
ck
::
index_t
conv_stride_w
=
2
;
ck
::
index_t
conv_dilation_h
=
1
;
ck
::
index_t
conv_dilation_w
=
1
;
ck
::
index_t
in_left_pad_h
=
1
;
ck
::
index_t
in_left_pad_w
=
1
;
ck
::
index_t
in_right_pad_h
=
1
;
ck
::
index_t
in_right_pad_w
=
1
;
if
(
argc
==
4
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
nrepeat
=
std
::
stoi
(
argv
[
3
]);
}
if
(
argc
==
19
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
nrepeat
=
std
::
stoi
(
argv
[
3
]);
N
=
std
::
stoi
(
argv
[
4
]);
K
=
std
::
stoi
(
argv
[
5
]);
C
=
std
::
stoi
(
argv
[
6
]);
Y
=
std
::
stoi
(
argv
[
7
]);
X
=
std
::
stoi
(
argv
[
8
]);
Hi
=
std
::
stoi
(
argv
[
9
]);
Wi
=
std
::
stoi
(
argv
[
10
]);
conv_stride_h
=
std
::
stoi
(
argv
[
11
]);
conv_stride_w
=
std
::
stoi
(
argv
[
12
]);
conv_dilation_h
=
std
::
stoi
(
argv
[
13
]);
conv_dilation_w
=
std
::
stoi
(
argv
[
14
]);
in_left_pad_h
=
std
::
stoi
(
argv
[
15
]);
in_left_pad_w
=
std
::
stoi
(
argv
[
16
]);
in_right_pad_h
=
std
::
stoi
(
argv
[
17
]);
in_right_pad_w
=
std
::
stoi
(
argv
[
18
]);
}
else
{
{
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3: run kernel # of times (>1)
\n
"
);
printf
(
"arg3: run kernel # of times (>1)
\n
"
);
//
exit(0);
exit
(
0
);
}
}
const
bool
do_verification
=
std
::
stoi
(
argv
[
1
]);
const
int
init_method
=
std
::
stoi
(
argv
[
2
]);
const
int
nrepeat
=
std
::
stoi
(
argv
[
3
]);
// Conv shape
#if 0
const ck::index_t N = 128;
const ck::index_t K = 256;
const ck::index_t C = 192;
const ck::index_t Y = 3;
const ck::index_t X = 3;
const ck::index_t Hi = 71;
const ck::index_t Wi = 71;
const ck::index_t conv_stride_h = 2;
const ck::index_t conv_stride_w = 2;
const ck::index_t conv_dilation_h = 1;
const ck::index_t conv_dilation_w = 1;
const ck::index_t in_left_pad_h = 1;
const ck::index_t in_left_pad_w = 1;
const ck::index_t in_right_pad_h = 1;
const ck::index_t in_right_pad_w = 1;
#else
const
ck
::
index_t
N
=
std
::
stoi
(
argv
[
4
]);
const
ck
::
index_t
K
=
std
::
stoi
(
argv
[
5
]);
const
ck
::
index_t
C
=
std
::
stoi
(
argv
[
6
]);
const
ck
::
index_t
Y
=
std
::
stoi
(
argv
[
7
]);
const
ck
::
index_t
X
=
std
::
stoi
(
argv
[
8
]);
const
ck
::
index_t
Hi
=
std
::
stoi
(
argv
[
9
]);
const
ck
::
index_t
Wi
=
std
::
stoi
(
argv
[
10
]);
const
ck
::
index_t
conv_stride_h
=
std
::
stoi
(
argv
[
11
]);
const
ck
::
index_t
conv_stride_w
=
std
::
stoi
(
argv
[
12
]);
const
ck
::
index_t
conv_dilation_h
=
std
::
stoi
(
argv
[
13
]);
const
ck
::
index_t
conv_dilation_w
=
std
::
stoi
(
argv
[
14
]);
const
ck
::
index_t
in_left_pad_h
=
std
::
stoi
(
argv
[
15
]);
const
ck
::
index_t
in_left_pad_w
=
std
::
stoi
(
argv
[
16
]);
const
ck
::
index_t
in_right_pad_h
=
std
::
stoi
(
argv
[
17
]);
const
ck
::
index_t
in_right_pad_w
=
std
::
stoi
(
argv
[
18
]);
#endif
const
ck
::
index_t
YEff
=
(
Y
-
1
)
*
conv_dilation_h
+
1
;
const
ck
::
index_t
YEff
=
(
Y
-
1
)
*
conv_dilation_h
+
1
;
const
ck
::
index_t
XEff
=
(
X
-
1
)
*
conv_dilation_w
+
1
;
const
ck
::
index_t
XEff
=
(
X
-
1
)
*
conv_dilation_w
+
1
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment