Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
c32f0c3b
Commit
c32f0c3b
authored
Jul 01, 2022
by
carlushuang
Browse files
update xdnn desc as cmd line
parent
8d0e00a0
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
422 additions
and
62 deletions
+422
-62
example/cpu_01_conv2d_fwd/cpu_conv2d_fwd.cpp
example/cpu_01_conv2d_fwd/cpu_conv2d_fwd.cpp
+68
-31
example/cpu_02_conv2d_fwd_bias_relu_add/cpu_conv2d_fwd_bias_relu_add.cpp
...conv2d_fwd_bias_relu_add/cpu_conv2d_fwd_bias_relu_add.cpp
+68
-31
include/ck/host_utility/xdnn_desc.hpp
include/ck/host_utility/xdnn_desc.hpp
+286
-0
No files found.
example/cpu_01_conv2d_fwd/cpu_conv2d_fwd.cpp
View file @
c32f0c3b
...
@@ -12,6 +12,7 @@
...
@@ -12,6 +12,7 @@
#include "element_wise_operation_cpu.hpp"
#include "element_wise_operation_cpu.hpp"
#include "dynamic_buffer_cpu.hpp"
#include "dynamic_buffer_cpu.hpp"
#include "envvar.hpp"
#include "envvar.hpp"
#include "xdnn_desc.hpp"
#include <omp.h>
#include <omp.h>
#define AVX2_DATA_ALIGNMENT 32
#define AVX2_DATA_ALIGNMENT 32
...
@@ -237,6 +238,39 @@ int main(int argc, char* argv[])
...
@@ -237,6 +238,39 @@ int main(int argc, char* argv[])
ck
::
index_t
in_right_pad_h
=
1
;
ck
::
index_t
in_right_pad_h
=
1
;
ck
::
index_t
in_right_pad_w
=
1
;
ck
::
index_t
in_right_pad_w
=
1
;
if
(
ck
::
getenv_int
(
"CK_USE_XDNN_DESC"
,
0
)
==
1
)
{
assert
(
argc
==
4
);
data_type
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
ck
::
desc_t
xdnn_desc
;
if
(
str2desc
(
&
xdnn_desc
,
argv
[
3
])
==
XDNN_OK
)
{
N
=
xdnn_desc
.
mb
;
K
=
xdnn_desc
.
oc
;
C
=
xdnn_desc
.
ic
;
Y
=
xdnn_desc
.
kh
;
X
=
xdnn_desc
.
kw
;
Hi
=
xdnn_desc
.
ih
;
Wi
=
xdnn_desc
.
iw
;
conv_stride_h
=
xdnn_desc
.
sh
;
conv_stride_w
=
xdnn_desc
.
sw
;
conv_dilation_h
=
xdnn_desc
.
dh
;
conv_dilation_w
=
xdnn_desc
.
dw
;
in_left_pad_h
=
xdnn_desc
.
ph
;
in_left_pad_w
=
xdnn_desc
.
pw
;
in_right_pad_h
=
xdnn_desc
.
ph
;
in_right_pad_w
=
xdnn_desc
.
pw
;
}
else
{
printf
(
"fail to parse xdnn arg:%s
\n
"
,
argv
[
3
]);
}
}
else
{
if
(
argc
==
1
)
if
(
argc
==
1
)
{
{
data_type
=
0
;
data_type
=
0
;
...
@@ -276,6 +310,7 @@ int main(int argc, char* argv[])
...
@@ -276,6 +310,7 @@ int main(int argc, char* argv[])
"RightPx
\n
"
);
"RightPx
\n
"
);
exit
(
1
);
exit
(
1
);
}
}
}
auto
Run
=
[
&
](
auto
input_type
,
auto
wei_type
,
auto
out_type
)
{
auto
Run
=
[
&
](
auto
input_type
,
auto
wei_type
,
auto
out_type
)
{
using
InDataType
=
decltype
(
input_type
);
using
InDataType
=
decltype
(
input_type
);
...
@@ -333,6 +368,8 @@ int main(int argc, char* argv[])
...
@@ -333,6 +368,8 @@ int main(int argc, char* argv[])
<<
", Dilation(H, W):"
<<
conv_dilation_h
<<
", "
<<
conv_dilation_w
<<
", Dilation(H, W):"
<<
conv_dilation_h
<<
", "
<<
conv_dilation_w
<<
", Threads:"
<<
omp_get_max_threads
()
<<
std
::
endl
;
<<
", Threads:"
<<
omp_get_max_threads
()
<<
std
::
endl
;
fflush
(
stdout
);
int
per_pixel_check
=
0
;
int
per_pixel_check
=
0
;
switch
(
init_method
)
switch
(
init_method
)
{
{
...
...
example/cpu_02_conv2d_fwd_bias_relu_add/cpu_conv2d_fwd_bias_relu_add.cpp
View file @
c32f0c3b
...
@@ -12,6 +12,7 @@
...
@@ -12,6 +12,7 @@
#include "element_wise_operation_cpu.hpp"
#include "element_wise_operation_cpu.hpp"
#include "dynamic_buffer_cpu.hpp"
#include "dynamic_buffer_cpu.hpp"
#include "envvar.hpp"
#include "envvar.hpp"
#include "xdnn_desc.hpp"
#include <omp.h>
#include <omp.h>
#define AVX2_DATA_ALIGNMENT 32
#define AVX2_DATA_ALIGNMENT 32
...
@@ -273,6 +274,39 @@ int main(int argc, char* argv[])
...
@@ -273,6 +274,39 @@ int main(int argc, char* argv[])
ck
::
index_t
in_right_pad_h
=
1
;
ck
::
index_t
in_right_pad_h
=
1
;
ck
::
index_t
in_right_pad_w
=
1
;
ck
::
index_t
in_right_pad_w
=
1
;
if
(
ck
::
getenv_int
(
"CK_USE_XDNN_DESC"
,
0
)
==
1
)
{
assert
(
argc
==
4
);
data_type
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
ck
::
desc_t
xdnn_desc
;
if
(
str2desc
(
&
xdnn_desc
,
argv
[
3
])
==
XDNN_OK
)
{
N
=
xdnn_desc
.
mb
;
K
=
xdnn_desc
.
oc
;
C
=
xdnn_desc
.
ic
;
Y
=
xdnn_desc
.
kh
;
X
=
xdnn_desc
.
kw
;
Hi
=
xdnn_desc
.
ih
;
Wi
=
xdnn_desc
.
iw
;
conv_stride_h
=
xdnn_desc
.
sh
;
conv_stride_w
=
xdnn_desc
.
sw
;
conv_dilation_h
=
xdnn_desc
.
dh
;
conv_dilation_w
=
xdnn_desc
.
dw
;
in_left_pad_h
=
xdnn_desc
.
ph
;
in_left_pad_w
=
xdnn_desc
.
pw
;
in_right_pad_h
=
xdnn_desc
.
ph
;
in_right_pad_w
=
xdnn_desc
.
pw
;
}
else
{
printf
(
"fail to parse xdnn arg:%s
\n
"
,
argv
[
3
]);
}
}
else
{
if
(
argc
==
1
)
if
(
argc
==
1
)
{
{
data_type
=
0
;
data_type
=
0
;
...
@@ -312,6 +346,7 @@ int main(int argc, char* argv[])
...
@@ -312,6 +346,7 @@ int main(int argc, char* argv[])
"RightPx
\n
"
);
"RightPx
\n
"
);
exit
(
1
);
exit
(
1
);
}
}
}
auto
Run
=
[
&
](
auto
input_type
,
auto
wei_type
,
auto
out_type
)
{
auto
Run
=
[
&
](
auto
input_type
,
auto
wei_type
,
auto
out_type
)
{
using
InDataType
=
decltype
(
input_type
);
using
InDataType
=
decltype
(
input_type
);
...
@@ -389,6 +424,8 @@ int main(int argc, char* argv[])
...
@@ -389,6 +424,8 @@ int main(int argc, char* argv[])
<<
", Dilation(H, W):"
<<
conv_dilation_h
<<
", "
<<
conv_dilation_w
<<
", Dilation(H, W):"
<<
conv_dilation_h
<<
", "
<<
conv_dilation_w
<<
", Threads:"
<<
omp_get_max_threads
()
<<
std
::
endl
;
<<
", Threads:"
<<
omp_get_max_threads
()
<<
std
::
endl
;
fflush
(
stdout
);
int
per_pixel_check
=
0
;
int
per_pixel_check
=
0
;
switch
(
init_method
)
switch
(
init_method
)
{
{
...
...
include/ck/host_utility/xdnn_desc.hpp
0 → 100644
View file @
c32f0c3b
#pragma once
#include <string>
#include <vector>
#include <functional>
#define XDNN_OK 0
#define XDNN_FAIL 1
namespace
ck
{
int
sanitize_desc
(
int
&
ndims
,
std
::
vector
<
std
::
reference_wrapper
<
int64_t
>>
d
,
std
::
vector
<
std
::
reference_wrapper
<
int64_t
>>
h
,
std
::
vector
<
std
::
reference_wrapper
<
int64_t
>>
w
,
const
std
::
vector
<
int64_t
>&
def_values
,
bool
must_have_spatial
)
{
size_t
N
=
d
.
size
();
assert
(
h
.
size
()
==
N
&&
w
.
size
()
==
N
&&
def_values
.
size
()
==
N
);
ndims
=
5
;
// check output spatial values
const
bool
no_d
=
d
[
0
].
get
()
==
0
;
const
bool
no_h
=
h
[
0
].
get
()
==
0
;
const
bool
no_w
=
w
[
0
].
get
()
==
0
;
if
(
no_d
)
ndims
--
;
if
(
no_d
&&
no_h
)
ndims
--
;
if
(
no_d
&&
no_h
&&
no_w
)
ndims
--
;
if
(
must_have_spatial
&&
ndims
<=
2
)
return
XDNN_FAIL
;
if
(
ndims
==
5
)
{
if
(
no_h
&&
no_w
)
{
// User specified values for the d dimension but not values for h
// and w dimensions. Propagate d values to h and w dimensions.
for
(
size_t
n
=
0
;
n
<
N
;
++
n
)
w
[
n
].
get
()
=
h
[
n
].
get
()
=
d
[
n
].
get
();
}
else
if
(
!
no_h
&&
!
no_w
)
{
// User specified them all, good to go.
}
else
{
// Problem is not cubic and one of h or w dimension is missing.
return
XDNN_FAIL
;
}
}
else
if
(
ndims
==
4
&&
no_w
)
{
// User specified values for the h dimension but not values for the w
// dimension. Propagate h values to the w dimension.
for
(
size_t
n
=
0
;
n
<
N
;
++
n
)
w
[
n
].
get
()
=
h
[
n
].
get
();
}
for
(
size_t
n
=
0
;
n
<
N
;
++
n
)
{
if
(
ndims
<
5
)
d
[
n
].
get
()
=
def_values
[
n
];
if
(
ndims
<
4
)
h
[
n
].
get
()
=
def_values
[
n
];
if
(
ndims
<
3
)
w
[
n
].
get
()
=
def_values
[
n
];
}
return
XDNN_OK
;
}
struct
desc_t
{
int64_t
g
,
mb
;
int64_t
ic
,
id
,
ih
,
iw
;
int64_t
oc
,
od
,
oh
,
ow
;
int64_t
kd
,
kh
,
kw
;
int64_t
sd
,
sh
,
sw
;
int64_t
pd
,
ph
,
pw
;
int64_t
pd_r
,
ph_r
,
pw_r
;
// End side padding for each dimension
int64_t
dd
,
dh
,
dw
;
bool
has_groups
;
const
char
*
name
;
int
ndims
;
// Initialize dependent opposite-side paddings values
// from the shape parameters
void
init_pad_r
(
bool
is_deconv
)
{
pw_r
=
opp_pad
(
is_deconv
,
iw
,
ow
,
kw
,
sw
,
pw
,
dw
);
ph_r
=
opp_pad
(
is_deconv
,
ih
,
oh
,
kh
,
sh
,
ph
,
dh
);
pd_r
=
opp_pad
(
is_deconv
,
id
,
od
,
kd
,
sd
,
pd
,
dd
);
}
int64_t
desc_nelems
(
int
arg
,
int
mask
)
const
;
private:
int64_t
opp_pad
(
bool
is_deconv
,
int64_t
i
,
int64_t
o
,
int64_t
k
,
int64_t
s
,
int64_t
p
,
int64_t
d
)
const
{
return
is_deconv
?
(
i
-
1
)
*
s
-
o
+
((
k
-
1
)
*
(
d
+
1
)
+
1
)
-
p
:
(
o
-
1
)
*
s
-
i
+
((
k
-
1
)
*
(
d
+
1
)
+
1
)
-
p
;
}
};
static
inline
int
str2desc
(
desc_t
*
desc
,
const
char
*
str
,
bool
is_deconv
=
false
)
{
/* canonical form:
* gXmbX_icXidXihXiwX_ocXodXohXowX_kdXkhXkwX_sdXshXswX_pdXphXpwX_ddXdhXdwXnS
*
* where X is number, S - string
* note: symbol `_` is ignored
*
* implicit rules:
* - if smaller dimensions are not specified => square or cubic form;
* - if output is undefined => compute output;
* - if padding is undefined => compute trivial padding;
*/
desc_t
d
{
0
};
d
.
g
=
1
;
d
.
mb
=
2
;
d
.
sd
=
d
.
sh
=
d
.
sw
=
1
;
d
.
pd
=
d
.
ph
=
d
.
pw
=
-
1
;
const
char
*
s
=
str
;
assert
(
s
);
#define CASE_NN(prb, c) \
do \
{ \
if(!strncmp(prb, s, strlen(prb))) \
{ \
ok = 1; \
s += strlen(prb); \
char* end_s; \
d.c = strtol(s, &end_s, 10); \
s += (end_s - s); \
/* check any # groups, including one, works correctly */
\
if(!strncmp(prb, "g", 1)) \
d.has_groups = true; \
if(d.c < 0) \
return XDNN_FAIL; \
/* printf("@@@debug: %s: %d\n", prb, d. c); */
\
} \
} while(0)
#define CASE_N(c) CASE_NN(#c, c)
while
(
*
s
)
{
int
ok
=
0
;
CASE_N
(
g
);
CASE_N
(
mb
);
CASE_N
(
ic
);
CASE_N
(
id
);
CASE_N
(
ih
);
CASE_N
(
iw
);
CASE_N
(
oc
);
CASE_N
(
od
);
CASE_N
(
oh
);
CASE_N
(
ow
);
CASE_N
(
kd
);
CASE_N
(
kh
);
CASE_N
(
kw
);
CASE_N
(
sd
);
CASE_N
(
sh
);
CASE_N
(
sw
);
CASE_N
(
pd
);
CASE_N
(
ph
);
CASE_N
(
pw
);
CASE_N
(
dd
);
CASE_N
(
dh
);
CASE_N
(
dw
);
if
(
*
s
==
'n'
)
{
d
.
name
=
s
+
1
;
break
;
}
if
(
*
s
==
'_'
)
++
s
;
if
(
!
ok
)
return
XDNN_FAIL
;
}
#undef CASE_NN
#undef CASE_N
if
(
d
.
has_groups
&&
d
.
g
<=
0
)
return
XDNN_FAIL
;
if
(
d
.
ic
==
0
||
d
.
oc
==
0
)
return
XDNN_FAIL
;
if
(
d
.
sd
<=
0
||
d
.
sh
<=
0
||
d
.
sw
<=
0
)
return
XDNN_FAIL
;
auto
compute_out
=
[](
bool
is_deconv
,
int64_t
i
,
int64_t
k
,
int64_t
s
,
int64_t
p
,
int64_t
d
)
{
if
(
is_deconv
)
return
(
i
-
1
)
*
s
+
(
k
-
1
)
*
(
d
+
1
)
-
2
*
p
+
1
;
else
return
(
i
-
((
k
-
1
)
*
(
d
+
1
)
+
1
)
+
2
*
p
)
/
s
+
1
;
};
auto
compute_pad
=
[](
bool
is_deconv
,
int64_t
o
,
int64_t
i
,
int64_t
k
,
int64_t
s
,
int64_t
d
)
{
if
(
is_deconv
)
return
((
i
-
1
)
*
s
-
o
+
((
k
-
1
)
*
(
d
+
1
)
+
1
))
/
2
;
else
return
((
o
-
1
)
*
s
-
i
+
((
k
-
1
)
*
(
d
+
1
)
+
1
))
/
2
;
};
const
bool
no_d
=
(
d
.
id
|
d
.
kd
|
d
.
od
|
d
.
dd
)
==
0
&&
d
.
sd
==
1
&&
d
.
pd
<
1
;
const
bool
no_h
=
(
d
.
ih
|
d
.
kh
|
d
.
oh
|
d
.
dh
)
==
0
&&
d
.
sh
==
1
&&
d
.
ph
<
1
;
const
bool
no_w
=
(
d
.
iw
|
d
.
kw
|
d
.
ow
|
d
.
dw
)
==
0
&&
d
.
sw
==
1
&&
d
.
pw
<
1
;
// printf("no_h:%d, no_w:%d, d.iw:%d\n", no_h, no_w, d.iw);
if
(
!
no_d
)
{
if
(
!
d
.
id
||
!
d
.
kd
)
return
XDNN_FAIL
;
if
(
!
d
.
od
)
{
if
(
d
.
pd
<
0
)
d
.
pd
=
0
;
d
.
od
=
compute_out
(
is_deconv
,
d
.
id
,
d
.
kd
,
d
.
sd
,
d
.
pd
,
d
.
dd
);
if
(
d
.
od
<=
0
)
return
XDNN_FAIL
;
}
else
if
(
d
.
pd
<
0
)
d
.
pd
=
compute_pad
(
is_deconv
,
d
.
od
,
d
.
id
,
d
.
kd
,
d
.
sd
,
d
.
dd
);
}
if
(
!
no_h
)
{
if
(
!
d
.
ih
||
!
d
.
kh
)
return
XDNN_FAIL
;
if
(
!
d
.
oh
)
{
if
(
d
.
ph
<
0
)
d
.
ph
=
0
;
d
.
oh
=
compute_out
(
is_deconv
,
d
.
ih
,
d
.
kh
,
d
.
sh
,
d
.
ph
,
d
.
dh
);
if
(
d
.
oh
<=
0
)
return
XDNN_FAIL
;
}
else
if
(
d
.
ph
<
0
)
d
.
ph
=
compute_pad
(
is_deconv
,
d
.
oh
,
d
.
ih
,
d
.
kh
,
d
.
sh
,
d
.
dh
);
}
if
(
!
no_w
)
{
if
(
!
d
.
iw
||
!
d
.
kw
)
return
XDNN_FAIL
;
if
(
!
d
.
ow
)
{
if
(
d
.
pw
<
0
)
d
.
pw
=
0
;
d
.
ow
=
compute_out
(
is_deconv
,
d
.
iw
,
d
.
kw
,
d
.
sw
,
d
.
pw
,
d
.
dw
);
if
(
d
.
ow
<=
0
)
return
XDNN_FAIL
;
}
else
if
(
d
.
pw
<
0
)
d
.
pw
=
compute_pad
(
is_deconv
,
d
.
ow
,
d
.
iw
,
d
.
kw
,
d
.
sw
,
d
.
dw
);
}
if
(
sanitize_desc
(
d
.
ndims
,
{
d
.
od
,
d
.
id
,
d
.
kd
,
d
.
sd
,
d
.
pd
,
d
.
dd
},
{
d
.
oh
,
d
.
ih
,
d
.
kh
,
d
.
sh
,
d
.
ph
,
d
.
dh
},
{
d
.
ow
,
d
.
iw
,
d
.
kw
,
d
.
sw
,
d
.
pw
,
d
.
dw
},
{
1
,
1
,
1
,
1
,
0
,
0
},
true
)
!=
XDNN_OK
)
return
XDNN_FAIL
;
d
.
init_pad_r
(
is_deconv
);
*
desc
=
d
;
// TODO: this is difference CK~OneDNN
d
.
dh
++
;
d
.
dw
++
;
d
.
dd
++
;
return
XDNN_OK
;
}
}
// namespace ck
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment