Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
97f133d2
Commit
97f133d2
authored
Apr 26, 2022
by
myamlak
Browse files
Explicit static_cast to ck::type_convert
parent
b411ee3b
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
27 additions
and
27 deletions
+27
-27
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp
...ary/reference_tensor_operation/cpu/reference_conv_fwd.hpp
+15
-15
library/include/ck/library/utility/conv_fwd_util.hpp
library/include/ck/library/utility/conv_fwd_util.hpp
+12
-12
No files found.
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp
View file @
97f133d2
...
...
@@ -88,13 +88,13 @@ struct ReferenceConvFwd : public device::BaseOperator
auto
f_ncw
=
[
&
](
auto
n
,
auto
k
,
auto
wo
)
{
float
v_acc
=
0
;
for
(
int
c
=
0
;
c
<
static_cas
t
<
int
>
(
arg
.
weight_
.
mDesc
.
GetLengths
()[
1
]);
++
c
)
for
(
int
c
=
0
;
c
<
ck
::
type_conver
t
<
int
>
(
arg
.
weight_
.
mDesc
.
GetLengths
()[
1
]);
++
c
)
{
for
(
int
x
=
0
;
x
<
static_cas
t
<
int
>
(
arg
.
weight_
.
mDesc
.
GetLengths
()[
2
]);
++
x
)
for
(
int
x
=
0
;
x
<
ck
::
type_conver
t
<
int
>
(
arg
.
weight_
.
mDesc
.
GetLengths
()[
2
]);
++
x
)
{
int
wi
=
wo
*
arg
.
conv_strides_
[
0
]
+
x
*
arg
.
conv_dilations_
[
0
]
-
arg
.
in_left_pads_
[
0
];
if
(
wi
>=
0
&&
wi
<
static_cas
t
<
int
>
(
arg
.
input_
.
mDesc
.
GetLengths
()[
2
]))
if
(
wi
>=
0
&&
wi
<
ck
::
type_conver
t
<
int
>
(
arg
.
input_
.
mDesc
.
GetLengths
()[
2
]))
{
float
v_in
;
float
v_wei
;
...
...
@@ -128,18 +128,18 @@ struct ReferenceConvFwd : public device::BaseOperator
auto
f_nchw
=
[
&
](
auto
n
,
auto
k
,
auto
ho
,
auto
wo
)
{
float
v_acc
=
0
;
for
(
int
c
=
0
;
c
<
static_cas
t
<
int
>
(
arg
.
weight_
.
mDesc
.
GetLengths
()[
1
]);
++
c
)
for
(
int
c
=
0
;
c
<
ck
::
type_conver
t
<
int
>
(
arg
.
weight_
.
mDesc
.
GetLengths
()[
1
]);
++
c
)
{
for
(
int
y
=
0
;
y
<
static_cas
t
<
int
>
(
arg
.
weight_
.
mDesc
.
GetLengths
()[
2
]);
++
y
)
for
(
int
y
=
0
;
y
<
ck
::
type_conver
t
<
int
>
(
arg
.
weight_
.
mDesc
.
GetLengths
()[
2
]);
++
y
)
{
int
hi
=
ho
*
arg
.
conv_strides_
[
0
]
+
y
*
arg
.
conv_dilations_
[
0
]
-
arg
.
in_left_pads_
[
0
];
for
(
int
x
=
0
;
x
<
static_cas
t
<
int
>
(
arg
.
weight_
.
mDesc
.
GetLengths
()[
3
]);
++
x
)
for
(
int
x
=
0
;
x
<
ck
::
type_conver
t
<
int
>
(
arg
.
weight_
.
mDesc
.
GetLengths
()[
3
]);
++
x
)
{
int
wi
=
wo
*
arg
.
conv_strides_
[
1
]
+
x
*
arg
.
conv_dilations_
[
1
]
-
arg
.
in_left_pads_
[
1
];
if
(
hi
>=
0
&&
hi
<
static_cas
t
<
int
>
(
arg
.
input_
.
mDesc
.
GetLengths
()[
2
])
&&
wi
>=
0
&&
wi
<
static_cas
t
<
int
>
(
arg
.
input_
.
mDesc
.
GetLengths
()[
3
]))
if
(
hi
>=
0
&&
hi
<
ck
::
type_conver
t
<
int
>
(
arg
.
input_
.
mDesc
.
GetLengths
()[
2
])
&&
wi
>=
0
&&
wi
<
ck
::
type_conver
t
<
int
>
(
arg
.
input_
.
mDesc
.
GetLengths
()[
3
]))
{
float
v_in
;
float
v_wei
;
...
...
@@ -174,23 +174,23 @@ struct ReferenceConvFwd : public device::BaseOperator
auto
f_nchw
=
[
&
](
auto
n
,
auto
k
,
auto
d_o
,
auto
ho
,
auto
wo
)
{
float
v_acc
=
0
;
for
(
int
c
=
0
;
c
<
static_cas
t
<
int
>
(
arg
.
weight_
.
mDesc
.
GetLengths
()[
1
]);
++
c
)
for
(
int
c
=
0
;
c
<
ck
::
type_conver
t
<
int
>
(
arg
.
weight_
.
mDesc
.
GetLengths
()[
1
]);
++
c
)
{
for
(
int
z
=
0
;
z
<
static_cas
t
<
int
>
(
arg
.
weight_
.
mDesc
.
GetLengths
()[
2
]);
++
z
)
for
(
int
z
=
0
;
z
<
ck
::
type_conver
t
<
int
>
(
arg
.
weight_
.
mDesc
.
GetLengths
()[
2
]);
++
z
)
{
int
di
=
d_o
*
arg
.
conv_strides_
[
0
]
+
z
*
arg
.
conv_dilations_
[
0
]
-
arg
.
in_left_pads_
[
0
];
for
(
int
y
=
0
;
y
<
static_cas
t
<
int
>
(
arg
.
weight_
.
mDesc
.
GetLengths
()[
3
]);
++
y
)
for
(
int
y
=
0
;
y
<
ck
::
type_conver
t
<
int
>
(
arg
.
weight_
.
mDesc
.
GetLengths
()[
3
]);
++
y
)
{
int
hi
=
ho
*
arg
.
conv_strides_
[
1
]
+
y
*
arg
.
conv_dilations_
[
1
]
-
arg
.
in_left_pads_
[
1
];
for
(
int
x
=
0
;
x
<
static_cas
t
<
int
>
(
arg
.
weight_
.
mDesc
.
GetLengths
()[
4
]);
++
x
)
for
(
int
x
=
0
;
x
<
ck
::
type_conver
t
<
int
>
(
arg
.
weight_
.
mDesc
.
GetLengths
()[
4
]);
++
x
)
{
int
wi
=
wo
*
arg
.
conv_strides_
[
2
]
+
x
*
arg
.
conv_dilations_
[
2
]
-
arg
.
in_left_pads_
[
2
];
if
(
di
>=
0
&&
di
<
static_cas
t
<
int
>
(
arg
.
input_
.
mDesc
.
GetLengths
()[
2
])
&&
hi
>=
0
&&
hi
<
static_cas
t
<
int
>
(
arg
.
input_
.
mDesc
.
GetLengths
()[
3
])
&&
wi
>=
0
&&
wi
<
static_cas
t
<
int
>
(
arg
.
input_
.
mDesc
.
GetLengths
()[
4
]))
if
(
di
>=
0
&&
di
<
ck
::
type_conver
t
<
int
>
(
arg
.
input_
.
mDesc
.
GetLengths
()[
2
])
&&
hi
>=
0
&&
hi
<
ck
::
type_conver
t
<
int
>
(
arg
.
input_
.
mDesc
.
GetLengths
()[
3
])
&&
wi
>=
0
&&
wi
<
ck
::
type_conver
t
<
int
>
(
arg
.
input_
.
mDesc
.
GetLengths
()[
4
]))
{
float
v_in
;
float
v_wei
;
...
...
library/include/ck/library/utility/conv_fwd_util.hpp
View file @
97f133d2
...
...
@@ -145,12 +145,12 @@ struct ConvParams
input_left_pads
(
left_pads
),
input_right_pads
(
right_pads
)
{
if
(
static_cas
t
<
ck
::
index_t
>
(
filter_spatial_lengths
.
size
())
!=
num_dim_spatial
||
static_cas
t
<
ck
::
index_t
>
(
input_spatial_lengths
.
size
())
!=
num_dim_spatial
||
static_cas
t
<
ck
::
index_t
>
(
conv_filter_strides
.
size
())
!=
num_dim_spatial
||
static_cas
t
<
ck
::
index_t
>
(
conv_filter_dilations
.
size
())
!=
num_dim_spatial
||
static_cas
t
<
ck
::
index_t
>
(
input_left_pads
.
size
())
!=
num_dim_spatial
||
static_cas
t
<
ck
::
index_t
>
(
input_right_pads
.
size
())
!=
num_dim_spatial
)
if
(
ck
::
type_conver
t
<
ck
::
index_t
>
(
filter_spatial_lengths
.
size
())
!=
num_dim_spatial
||
ck
::
type_conver
t
<
ck
::
index_t
>
(
input_spatial_lengths
.
size
())
!=
num_dim_spatial
||
ck
::
type_conver
t
<
ck
::
index_t
>
(
conv_filter_strides
.
size
())
!=
num_dim_spatial
||
ck
::
type_conver
t
<
ck
::
index_t
>
(
conv_filter_dilations
.
size
())
!=
num_dim_spatial
||
ck
::
type_conver
t
<
ck
::
index_t
>
(
input_left_pads
.
size
())
!=
num_dim_spatial
||
ck
::
type_conver
t
<
ck
::
index_t
>
(
input_right_pads
.
size
())
!=
num_dim_spatial
)
{
throw
(
std
::
runtime_error
(
"ConvParams::GetOutputSpatialLengths: "
...
...
@@ -174,12 +174,12 @@ struct ConvParams
std
::
vector
<
ck
::
index_t
>
GetOutputSpatialLengths
()
const
{
if
(
static_cas
t
<
ck
::
index_t
>
(
filter_spatial_lengths
.
size
())
!=
num_dim_spatial
||
static_cas
t
<
ck
::
index_t
>
(
input_spatial_lengths
.
size
())
!=
num_dim_spatial
||
static_cas
t
<
ck
::
index_t
>
(
conv_filter_strides
.
size
())
!=
num_dim_spatial
||
static_cas
t
<
ck
::
index_t
>
(
conv_filter_dilations
.
size
())
!=
num_dim_spatial
||
static_cas
t
<
ck
::
index_t
>
(
input_left_pads
.
size
())
!=
num_dim_spatial
||
static_cas
t
<
ck
::
index_t
>
(
input_right_pads
.
size
())
!=
num_dim_spatial
)
if
(
ck
::
type_conver
t
<
ck
::
index_t
>
(
filter_spatial_lengths
.
size
())
!=
num_dim_spatial
||
ck
::
type_conver
t
<
ck
::
index_t
>
(
input_spatial_lengths
.
size
())
!=
num_dim_spatial
||
ck
::
type_conver
t
<
ck
::
index_t
>
(
conv_filter_strides
.
size
())
!=
num_dim_spatial
||
ck
::
type_conver
t
<
ck
::
index_t
>
(
conv_filter_dilations
.
size
())
!=
num_dim_spatial
||
ck
::
type_conver
t
<
ck
::
index_t
>
(
input_left_pads
.
size
())
!=
num_dim_spatial
||
ck
::
type_conver
t
<
ck
::
index_t
>
(
input_right_pads
.
size
())
!=
num_dim_spatial
)
{
throw
(
std
::
runtime_error
(
"ConvParams::GetOutputSpatialLengths: "
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment