Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
c58d92d3
"vscode:/vscode.git/clone" did not exist on "bc1bb798fc436f4ac436b8fe6cf78dabfb865581"
Commit
c58d92d3
authored
Apr 22, 2022
by
myamlak
Browse files
Cleaning part I
parent
c4a678fc
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
44 additions
and
42 deletions
+44
-42
library/include/ck/library/host_tensor/host_tensor.hpp
library/include/ck/library/host_tensor/host_tensor.hpp
+1
-1
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp
...eference_tensor_operation/cpu/reference_conv_bwd_data.hpp
+16
-16
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp
...ary/reference_tensor_operation/cpu/reference_conv_fwd.hpp
+15
-15
library/include/ck/library/utility/conv_fwd_util.hpp
library/include/ck/library/utility/conv_fwd_util.hpp
+12
-10
No files found.
library/include/ck/library/host_tensor/host_tensor.hpp
View file @
c58d92d3
...
...
@@ -154,7 +154,7 @@ struct ParallelTensorFunctor
{
std
::
array
<
std
::
size_t
,
NDIM
>
indices
;
for
(
in
t
idim
=
0
;
idim
<
NDIM
;
++
idim
)
for
(
std
::
size_
t
idim
=
0
;
idim
<
NDIM
;
++
idim
)
{
indices
[
idim
]
=
i
/
mStrides
[
idim
];
i
-=
indices
[
idim
]
*
mStrides
[
idim
];
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp
View file @
c58d92d3
...
...
@@ -72,9 +72,9 @@ struct ReferenceConvBwdData : public device::BaseOperator
if
constexpr
(
NumDimSpatial
==
1
)
{
auto
f_ncw
=
[
&
](
auto
n
,
auto
c
,
auto
wi
)
{
std
::
size_
t
K
=
arg
.
weight_
.
mDesc
.
GetLengths
()[
0
];
std
::
size_
t
X
=
arg
.
weight_
.
mDesc
.
GetLengths
()[
2
];
std
::
size_
t
Wo
=
arg
.
output_
.
mDesc
.
GetLengths
()[
2
];
in
t
K
=
arg
.
weight_
.
mDesc
.
GetLengths
()[
0
];
in
t
X
=
arg
.
weight_
.
mDesc
.
GetLengths
()[
2
];
in
t
Wo
=
arg
.
output_
.
mDesc
.
GetLengths
()[
2
];
AccDataType
v_acc
=
0
;
...
...
@@ -119,12 +119,12 @@ struct ReferenceConvBwdData : public device::BaseOperator
else
if
constexpr
(
NumDimSpatial
==
2
)
{
auto
f_nchw
=
[
&
](
auto
n
,
auto
c
,
auto
hi
,
auto
wi
)
{
std
::
size_
t
K
=
arg
.
weight_
.
mDesc
.
GetLengths
()[
0
];
std
::
size_
t
Y
=
arg
.
weight_
.
mDesc
.
GetLengths
()[
2
];
std
::
size_
t
X
=
arg
.
weight_
.
mDesc
.
GetLengths
()[
3
];
in
t
K
=
arg
.
weight_
.
mDesc
.
GetLengths
()[
0
];
in
t
Y
=
arg
.
weight_
.
mDesc
.
GetLengths
()[
2
];
in
t
X
=
arg
.
weight_
.
mDesc
.
GetLengths
()[
3
];
std
::
size_
t
Ho
=
arg
.
output_
.
mDesc
.
GetLengths
()[
2
];
std
::
size_
t
Wo
=
arg
.
output_
.
mDesc
.
GetLengths
()[
3
];
in
t
Ho
=
arg
.
output_
.
mDesc
.
GetLengths
()[
2
];
in
t
Wo
=
arg
.
output_
.
mDesc
.
GetLengths
()[
3
];
AccDataType
v_acc
=
0
;
...
...
@@ -183,14 +183,14 @@ struct ReferenceConvBwdData : public device::BaseOperator
else
if
constexpr
(
NumDimSpatial
==
3
)
{
auto
f_ncdhw
=
[
&
](
auto
n
,
auto
c
,
auto
di
,
auto
hi
,
auto
wi
)
{
std
::
size_
t
K
=
arg
.
weight_
.
mDesc
.
GetLengths
()[
0
];
std
::
size_
t
Z
=
arg
.
weight_
.
mDesc
.
GetLengths
()[
2
];
std
::
size_
t
Y
=
arg
.
weight_
.
mDesc
.
GetLengths
()[
3
];
std
::
size_
t
X
=
arg
.
weight_
.
mDesc
.
GetLengths
()[
4
];
std
::
size_
t
Do
=
arg
.
output_
.
mDesc
.
GetLengths
()[
2
];
std
::
size_
t
Ho
=
arg
.
output_
.
mDesc
.
GetLengths
()[
3
];
std
::
size_
t
Wo
=
arg
.
output_
.
mDesc
.
GetLengths
()[
4
];
in
t
K
=
arg
.
weight_
.
mDesc
.
GetLengths
()[
0
];
in
t
Z
=
arg
.
weight_
.
mDesc
.
GetLengths
()[
2
];
in
t
Y
=
arg
.
weight_
.
mDesc
.
GetLengths
()[
3
];
in
t
X
=
arg
.
weight_
.
mDesc
.
GetLengths
()[
4
];
in
t
Do
=
arg
.
output_
.
mDesc
.
GetLengths
()[
2
];
in
t
Ho
=
arg
.
output_
.
mDesc
.
GetLengths
()[
3
];
in
t
Wo
=
arg
.
output_
.
mDesc
.
GetLengths
()[
4
];
AccDataType
v_acc
=
0
;
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp
View file @
c58d92d3
...
...
@@ -88,13 +88,13 @@ struct ReferenceConvFwd : public device::BaseOperator
auto
f_ncw
=
[
&
](
auto
n
,
auto
k
,
auto
wo
)
{
float
v_acc
=
0
;
for
(
int
c
=
0
;
c
<
arg
.
weight_
.
mDesc
.
GetLengths
()[
1
];
++
c
)
for
(
int
c
=
0
;
c
<
static_cast
<
int
>
(
arg
.
weight_
.
mDesc
.
GetLengths
()[
1
]
)
;
++
c
)
{
for
(
int
x
=
0
;
x
<
arg
.
weight_
.
mDesc
.
GetLengths
()[
2
];
++
x
)
for
(
int
x
=
0
;
x
<
static_cast
<
int
>
(
arg
.
weight_
.
mDesc
.
GetLengths
()[
2
]
)
;
++
x
)
{
int
wi
=
wo
*
arg
.
conv_strides_
[
0
]
+
x
*
arg
.
conv_dilations_
[
0
]
-
arg
.
in_left_pads_
[
0
];
if
(
wi
>=
0
&&
wi
<
arg
.
input_
.
mDesc
.
GetLengths
()[
2
])
if
(
wi
>=
0
&&
wi
<
static_cast
<
int
>
(
arg
.
input_
.
mDesc
.
GetLengths
()[
2
])
)
{
float
v_in
;
float
v_wei
;
...
...
@@ -128,18 +128,18 @@ struct ReferenceConvFwd : public device::BaseOperator
auto
f_nchw
=
[
&
](
auto
n
,
auto
k
,
auto
ho
,
auto
wo
)
{
float
v_acc
=
0
;
for
(
int
c
=
0
;
c
<
arg
.
weight_
.
mDesc
.
GetLengths
()[
1
];
++
c
)
for
(
int
c
=
0
;
c
<
static_cast
<
int
>
(
arg
.
weight_
.
mDesc
.
GetLengths
()[
1
]
)
;
++
c
)
{
for
(
int
y
=
0
;
y
<
arg
.
weight_
.
mDesc
.
GetLengths
()[
2
];
++
y
)
for
(
int
y
=
0
;
y
<
static_cast
<
int
>
(
arg
.
weight_
.
mDesc
.
GetLengths
()[
2
]
)
;
++
y
)
{
int
hi
=
ho
*
arg
.
conv_strides_
[
0
]
+
y
*
arg
.
conv_dilations_
[
0
]
-
arg
.
in_left_pads_
[
0
];
for
(
int
x
=
0
;
x
<
arg
.
weight_
.
mDesc
.
GetLengths
()[
3
];
++
x
)
for
(
int
x
=
0
;
x
<
static_cast
<
int
>
(
arg
.
weight_
.
mDesc
.
GetLengths
()[
3
]
)
;
++
x
)
{
int
wi
=
wo
*
arg
.
conv_strides_
[
1
]
+
x
*
arg
.
conv_dilations_
[
1
]
-
arg
.
in_left_pads_
[
1
];
if
(
hi
>=
0
&&
hi
<
arg
.
input_
.
mDesc
.
GetLengths
()[
2
]
&&
wi
>=
0
&&
wi
<
arg
.
input_
.
mDesc
.
GetLengths
()[
3
])
if
(
hi
>=
0
&&
hi
<
static_cast
<
int
>
(
arg
.
input_
.
mDesc
.
GetLengths
()[
2
]
)
&&
wi
>=
0
&&
wi
<
static_cast
<
int
>
(
arg
.
input_
.
mDesc
.
GetLengths
()[
3
])
)
{
float
v_in
;
float
v_wei
;
...
...
@@ -174,23 +174,23 @@ struct ReferenceConvFwd : public device::BaseOperator
auto
f_nchw
=
[
&
](
auto
n
,
auto
k
,
auto
d_o
,
auto
ho
,
auto
wo
)
{
float
v_acc
=
0
;
for
(
int
c
=
0
;
c
<
arg
.
weight_
.
mDesc
.
GetLengths
()[
1
];
++
c
)
for
(
int
c
=
0
;
c
<
static_cast
<
int
>
(
arg
.
weight_
.
mDesc
.
GetLengths
()[
1
]
)
;
++
c
)
{
for
(
int
z
=
0
;
z
<
arg
.
weight_
.
mDesc
.
GetLengths
()[
2
];
++
z
)
for
(
int
z
=
0
;
z
<
static_cast
<
int
>
(
arg
.
weight_
.
mDesc
.
GetLengths
()[
2
]
)
;
++
z
)
{
int
di
=
d_o
*
arg
.
conv_strides_
[
0
]
+
z
*
arg
.
conv_dilations_
[
0
]
-
arg
.
in_left_pads_
[
0
];
for
(
int
y
=
0
;
y
<
arg
.
weight_
.
mDesc
.
GetLengths
()[
3
];
++
y
)
for
(
int
y
=
0
;
y
<
static_cast
<
int
>
(
arg
.
weight_
.
mDesc
.
GetLengths
()[
3
]
)
;
++
y
)
{
int
hi
=
ho
*
arg
.
conv_strides_
[
1
]
+
y
*
arg
.
conv_dilations_
[
1
]
-
arg
.
in_left_pads_
[
1
];
for
(
int
x
=
0
;
x
<
arg
.
weight_
.
mDesc
.
GetLengths
()[
4
];
++
x
)
for
(
int
x
=
0
;
x
<
static_cast
<
int
>
(
arg
.
weight_
.
mDesc
.
GetLengths
()[
4
]
)
;
++
x
)
{
int
wi
=
wo
*
arg
.
conv_strides_
[
2
]
+
x
*
arg
.
conv_dilations_
[
2
]
-
arg
.
in_left_pads_
[
2
];
if
(
di
>=
0
&&
di
<
arg
.
input_
.
mDesc
.
GetLengths
()[
2
]
&&
hi
>=
0
&&
hi
<
arg
.
input_
.
mDesc
.
GetLengths
()[
3
]
&&
wi
>=
0
&&
wi
<
arg
.
input_
.
mDesc
.
GetLengths
()[
4
])
if
(
di
>=
0
&&
di
<
static_cast
<
int
>
(
arg
.
input_
.
mDesc
.
GetLengths
()[
2
]
)
&&
hi
>=
0
&&
hi
<
static_cast
<
int
>
(
arg
.
input_
.
mDesc
.
GetLengths
()[
3
]
)
&&
wi
>=
0
&&
wi
<
static_cast
<
int
>
(
arg
.
input_
.
mDesc
.
GetLengths
()[
4
])
)
{
float
v_in
;
float
v_wei
;
...
...
library/include/ck/library/utility/conv_fwd_util.hpp
View file @
c58d92d3
...
...
@@ -145,11 +145,12 @@ struct ConvParams
input_left_pads
(
left_pads
),
input_right_pads
(
right_pads
)
{
if
(
filter_spatial_lengths
.
size
()
!=
num_dim_spatial
||
input_spatial_lengths
.
size
()
!=
num_dim_spatial
||
conv_filter_strides
.
size
()
!=
num_dim_spatial
||
conv_filter_dilations
.
size
()
!=
num_dim_spatial
||
input_left_pads
.
size
()
!=
num_dim_spatial
||
input_right_pads
.
size
()
!=
num_dim_spatial
)
if
(
static_cast
<
ck
::
index_t
>
(
filter_spatial_lengths
.
size
())
!=
num_dim_spatial
||
static_cast
<
ck
::
index_t
>
(
input_spatial_lengths
.
size
())
!=
num_dim_spatial
||
static_cast
<
ck
::
index_t
>
(
conv_filter_strides
.
size
())
!=
num_dim_spatial
||
static_cast
<
ck
::
index_t
>
(
conv_filter_dilations
.
size
())
!=
num_dim_spatial
||
static_cast
<
ck
::
index_t
>
(
input_left_pads
.
size
())
!=
num_dim_spatial
||
static_cast
<
ck
::
index_t
>
(
input_right_pads
.
size
())
!=
num_dim_spatial
)
{
throw
(
std
::
runtime_error
(
"ConvParams::GetOutputSpatialLengths: "
...
...
@@ -173,11 +174,12 @@ struct ConvParams
std
::
vector
<
ck
::
index_t
>
GetOutputSpatialLengths
()
const
{
if
(
filter_spatial_lengths
.
size
()
!=
num_dim_spatial
||
input_spatial_lengths
.
size
()
!=
num_dim_spatial
||
conv_filter_strides
.
size
()
!=
num_dim_spatial
||
conv_filter_dilations
.
size
()
!=
num_dim_spatial
||
input_left_pads
.
size
()
!=
num_dim_spatial
||
input_right_pads
.
size
()
!=
num_dim_spatial
)
if
(
static_cast
<
ck
::
index_t
>
(
filter_spatial_lengths
.
size
())
!=
num_dim_spatial
||
static_cast
<
ck
::
index_t
>
(
input_spatial_lengths
.
size
())
!=
num_dim_spatial
||
static_cast
<
ck
::
index_t
>
(
conv_filter_strides
.
size
())
!=
num_dim_spatial
||
static_cast
<
ck
::
index_t
>
(
conv_filter_dilations
.
size
())
!=
num_dim_spatial
||
static_cast
<
ck
::
index_t
>
(
input_left_pads
.
size
())
!=
num_dim_spatial
||
static_cast
<
ck
::
index_t
>
(
input_right_pads
.
size
())
!=
num_dim_spatial
)
{
throw
(
std
::
runtime_error
(
"ConvParams::GetOutputSpatialLengths: "
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment