Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
85b2fee8
Commit
85b2fee8
authored
Jun 06, 2023
by
rocking
Browse files
Add compute datatype for reference code.
Prevent error in bf16
parent
7f09b8a0
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
7 additions
and
2 deletions
+7
-2
example/49_maxpool2d_bwd/maxpool2d_bwd_common.hpp
example/49_maxpool2d_bwd/maxpool2d_bwd_common.hpp
+1
-1
library/include/ck/library/reference_tensor_operation/cpu/reference_maxpool_bwd.hpp
.../reference_tensor_operation/cpu/reference_maxpool_bwd.hpp
+6
-1
No files found.
example/49_maxpool2d_bwd/maxpool2d_bwd_common.hpp
View file @
85b2fee8
...
@@ -213,7 +213,7 @@ bool maxpool_bwd_test(bool do_verification,
...
@@ -213,7 +213,7 @@ bool maxpool_bwd_test(bool do_verification,
ref_pooling_fwd_invoker
.
Run
(
ref_pooling_fwd_argument
);
ref_pooling_fwd_invoker
.
Run
(
ref_pooling_fwd_argument
);
using
ReferencePoolingBwdInstance
=
ck
::
tensor_operation
::
host
::
using
ReferencePoolingBwdInstance
=
ck
::
tensor_operation
::
host
::
ReferenceMaxPoolBwd
<
DOutDataType
,
IndexDataType
,
DInDataType
,
PassThrough
>
;
ReferenceMaxPoolBwd
<
DOutDataType
,
IndexDataType
,
float
,
DInDataType
,
PassThrough
>
;
auto
ref_pooling_bwd
=
ReferencePoolingBwdInstance
{};
auto
ref_pooling_bwd
=
ReferencePoolingBwdInstance
{};
auto
ref_pooling_bwd_invoker
=
ref_pooling_bwd
.
MakeInvoker
();
auto
ref_pooling_bwd_invoker
=
ref_pooling_bwd
.
MakeInvoker
();
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_maxpool_bwd.hpp
View file @
85b2fee8
...
@@ -21,6 +21,7 @@ using namespace std;
...
@@ -21,6 +21,7 @@ using namespace std;
template
<
typename
DOutDataType
,
template
<
typename
DOutDataType
,
typename
IndexDataType
,
typename
IndexDataType
,
typename
ConputeDataType
,
typename
DInDataType
,
typename
DInDataType
,
typename
ElementwiseOperation
>
typename
ElementwiseOperation
>
struct
ReferenceMaxPoolBwd
:
public
device
::
BaseOperator
struct
ReferenceMaxPoolBwd
:
public
device
::
BaseOperator
...
@@ -49,13 +50,17 @@ struct ReferenceMaxPoolBwd : public device::BaseOperator
...
@@ -49,13 +50,17 @@ struct ReferenceMaxPoolBwd : public device::BaseOperator
{
{
int
din_length
=
arg
.
din_
.
GetElementSpaceSize
();
int
din_length
=
arg
.
din_
.
GetElementSpaceSize
();
int
dout_length
=
arg
.
dout_
.
GetElementSpaceSize
();
int
dout_length
=
arg
.
dout_
.
GetElementSpaceSize
();
std
::
vector
<
ConputeDataType
>
buf
(
din_length
);
for
(
int
i
=
0
;
i
<
dout_length
;
++
i
)
for
(
int
i
=
0
;
i
<
dout_length
;
++
i
)
{
{
int
index
=
arg
.
indices_
.
mData
[
i
];
int
index
=
arg
.
indices_
.
mData
[
i
];
if
(
index
>=
0
&&
index
<
din_length
)
if
(
index
>=
0
&&
index
<
din_length
)
arg
.
din_
.
mData
[
index
]
+=
arg
.
dout_
.
mData
[
i
];
buf
[
index
]
+=
ck
::
type_convert
<
ConputeDataType
>
(
arg
.
dout_
.
mData
[
i
]
)
;
}
}
for
(
int
i
=
0
;
i
<
din_length
;
++
i
)
arg
.
din_
.
mData
[
i
]
=
ck
::
type_convert
<
DInDataType
>
(
buf
[
i
]);
return
0
;
return
0
;
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment