Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
283f9b62
Commit
283f9b62
authored
Jun 15, 2023
by
rocking
Browse files
Move zero-initialization of din into the device operator
parent
a2598b8a
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
14 additions
and
7 deletions
+14
-7
example/49_maxpool2d_bwd/maxpool2d_bwd_common.hpp
example/49_maxpool2d_bwd/maxpool2d_bwd_common.hpp
+0
-1
include/ck/tensor_operation/gpu/device/impl/device_index_pool_bwd_impl.hpp
..._operation/gpu/device/impl/device_index_pool_bwd_impl.hpp
+14
-6
No files found.
example/49_maxpool2d_bwd/maxpool2d_bwd_common.hpp
View file @
283f9b62
...
@@ -116,7 +116,6 @@ bool maxpool_bwd_test(bool do_verification,
...
@@ -116,7 +116,6 @@ bool maxpool_bwd_test(bool do_verification,
in_device_buf
.
ToDevice
(
in_n_c_hi_wi
.
mData
.
data
());
in_device_buf
.
ToDevice
(
in_n_c_hi_wi
.
mData
.
data
());
dout_device_buf
.
ToDevice
(
dout_n_c_ho_wo
.
mData
.
data
());
dout_device_buf
.
ToDevice
(
dout_n_c_ho_wo
.
mData
.
data
());
din_device_buf
.
SetZero
();
auto
pool_fwd
=
DevicePoolFwdInstance
{};
auto
pool_fwd
=
DevicePoolFwdInstance
{};
auto
pool_fwd_invoker_ptr
=
pool_fwd
.
MakeInvokerPointer
();
auto
pool_fwd_invoker_ptr
=
pool_fwd
.
MakeInvokerPointer
();
...
...
include/ck/tensor_operation/gpu/device/impl/device_index_pool_bwd_impl.hpp
View file @
283f9b62
...
@@ -123,8 +123,13 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
...
@@ -123,8 +123,13 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
{
{
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
{
index_t
din_length_raw
=
arg
.
din_grid_desc_
.
GetTransforms
()[
I0
].
GetUpperLengths
()[
I0
];
if
constexpr
(
is_same_v
<
DInDataType
,
float
>
||
is_same_v
<
DInDataType
,
double
>
)
if
constexpr
(
is_same_v
<
DInDataType
,
float
>
||
is_same_v
<
DInDataType
,
double
>
)
{
{
hip_check_error
(
hipMemsetAsync
(
arg
.
p_din_
,
0
,
din_length_raw
*
sizeof
(
DInDataType
),
stream_config
.
stream_id_
));
if
(
arg
.
windowOverlap_
)
if
(
arg
.
windowOverlap_
)
{
{
const
auto
put_kernel
=
kernel_put_element_1d
<
GridwisePutElementAtomicAdd
,
const
auto
put_kernel
=
kernel_put_element_1d
<
GridwisePutElementAtomicAdd
,
...
@@ -173,13 +178,11 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
...
@@ -173,13 +178,11 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
if
(
arg
.
p_workspace_
==
nullptr
)
if
(
arg
.
p_workspace_
==
nullptr
)
throw
std
::
runtime_error
(
"wrong! WorkSpace pointer has not been set"
);
throw
std
::
runtime_error
(
"wrong! WorkSpace pointer has not been set"
);
index_t
din_length_raw
=
arg
.
din_grid_desc_
.
GetTransforms
()[
I0
].
GetUpperLengths
()[
I0
];
hip_check_error
(
hip_check_error
(
hipMemset
(
arg
.
p_workspace_
,
hipMemsetAsync
(
arg
.
p_workspace_
,
0
,
0
,
din_length_raw
*
sizeof
(
DInDataType_AutomicAddPreCast
)));
din_length_raw
*
sizeof
(
DInDataType_AutomicAddPreCast
),
stream_config
.
stream_id_
));
const
auto
put_kernel
=
kernel_put_element_1d
<
GridwisePutElementAtomicAdd
,
const
auto
put_kernel
=
kernel_put_element_1d
<
GridwisePutElementAtomicAdd
,
InOutGrid1dDesc
,
InOutGrid1dDesc
,
...
@@ -231,6 +234,11 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
...
@@ -231,6 +234,11 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
DInDataType
,
DInDataType
,
PassThrough
>
;
PassThrough
>
;
hip_check_error
(
hipMemsetAsync
(
arg
.
p_din_
,
0
,
din_length_raw
*
sizeof
(
DInDataType
),
stream_config
.
stream_id_
));
return
launch_and_time_kernel
(
stream_config
,
return
launch_and_time_kernel
(
stream_config
,
put_kernel
,
put_kernel
,
dim3
(
arg
.
gridSize_
),
dim3
(
arg
.
gridSize_
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment