Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
81caa447
"vscode:/vscode.git/clone" did not exist on "c52dfdef791a5721480f9a950d0931dc9ea5cdb8"
Commit
81caa447
authored
Jun 08, 2023
by
rocking
Browse files
Fix bf16 error
parent
1a1059ab
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
10 additions
and
10 deletions
+10
-10
library/include/ck/library/reference_tensor_operation/cpu/reference_pool_fwd.hpp
...ary/reference_tensor_operation/cpu/reference_pool_fwd.hpp
+10
-10
No files found.
library/include/ck/library/reference_tensor_operation/cpu/reference_pool_fwd.hpp
View file @
81caa447
...
...
@@ -100,8 +100,8 @@ struct ReferencePoolingFwd : public device::BaseOperator
wi
>=
0
&&
wi
<
static_cast
<
ck
::
index_t
>
(
arg
.
in_
.
mDesc
.
GetLengths
()[
4
]))
{
ComputeDataType
currVal
=
static_cast
<
ComputeDataType
>
(
arg
.
in_
(
n
,
c
,
di
,
hi
,
wi
));
ComputeDataType
currVal
=
ck
::
type_convert
<
ComputeDataType
>
(
arg
.
in_
(
n
,
c
,
di
,
hi
,
wi
));
in_elementwise_op
(
currVal
,
currVal
);
...
...
@@ -112,7 +112,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
}
acc_elementwise_op
(
accuVal
,
accuVal
);
arg
.
out_
(
n
,
c
,
do_
,
ho
,
wo
)
=
accuVal
;
arg
.
out_
(
n
,
c
,
do_
,
ho
,
wo
)
=
ck
::
type_convert
<
OutDataType
>
(
accuVal
)
;
};
make_ParallelTensorFunctor
(
f_ncdhw
,
...
...
@@ -151,8 +151,8 @@ struct ReferencePoolingFwd : public device::BaseOperator
wi
>=
0
&&
wi
<
static_cast
<
ck
::
index_t
>
(
arg
.
in_
.
mDesc
.
GetLengths
()[
4
]))
{
ComputeDataType
currVal
=
static_cast
<
ComputeDataType
>
(
arg
.
in_
(
n
,
c
,
di
,
hi
,
wi
));
ComputeDataType
currVal
=
ck
::
type_convert
<
ComputeDataType
>
(
arg
.
in_
(
n
,
c
,
di
,
hi
,
wi
));
IndexDataType
currIndex
=
arg
.
in_
.
GetOffsetFromMultiIndex
(
n
,
c
,
di
,
hi
,
wi
);
...
...
@@ -166,7 +166,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
acc_elementwise_op
(
accuVal
,
accuVal
);
arg
.
out_
(
n
,
c
,
do_
,
ho
,
wo
)
=
accuVal
;
arg
.
out_
(
n
,
c
,
do_
,
ho
,
wo
)
=
ck
::
type_convert
<
OutDataType
>
(
accuVal
)
;
arg
.
out_indices_
(
n
,
c
,
do_
,
ho
,
wo
)
=
accuIndex
;
};
...
...
@@ -212,7 +212,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
wi
<
static_cast
<
ck
::
index_t
>
(
arg
.
in_
.
mDesc
.
GetLengths
()[
3
]))
{
ComputeDataType
currVal
=
static_cas
t
<
ComputeDataType
>
(
arg
.
in_
(
n
,
c
,
hi
,
wi
));
ck
::
type_conver
t
<
ComputeDataType
>
(
arg
.
in_
(
n
,
c
,
hi
,
wi
));
in_elementwise_op
(
currVal
,
currVal
);
...
...
@@ -222,7 +222,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
}
acc_elementwise_op
(
accuVal
,
accuVal
);
arg
.
out_
(
n
,
c
,
ho
,
wo
)
=
accuVal
;
arg
.
out_
(
n
,
c
,
ho
,
wo
)
=
ck
::
type_convert
<
OutDataType
>
(
accuVal
)
;
};
make_ParallelTensorFunctor
(
f_nchw
,
...
...
@@ -255,7 +255,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
wi
<
static_cast
<
ck
::
index_t
>
(
arg
.
in_
.
mDesc
.
GetLengths
()[
3
]))
{
ComputeDataType
currVal
=
static_cas
t
<
ComputeDataType
>
(
arg
.
in_
(
n
,
c
,
hi
,
wi
));
ck
::
type_conver
t
<
ComputeDataType
>
(
arg
.
in_
(
n
,
c
,
hi
,
wi
));
IndexDataType
currIndex
=
arg
.
in_
.
GetOffsetFromMultiIndex
(
n
,
c
,
hi
,
wi
);
...
...
@@ -268,7 +268,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
}
acc_elementwise_op
(
accuVal
,
accuVal
);
arg
.
out_
(
n
,
c
,
ho
,
wo
)
=
accuVal
;
arg
.
out_
(
n
,
c
,
ho
,
wo
)
=
ck
::
type_convert
<
OutDataType
>
(
accuVal
)
;
arg
.
out_indices_
(
n
,
c
,
ho
,
wo
)
=
accuIndex
;
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment