Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
47ac3767
Commit 47ac3767 authored Jun 15, 2023 by rocking
Browse files
Save din_length_raw
parent
283f9b62
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing 1 changed file with 9 additions and 10 deletions
+9
-10
include/ck/tensor_operation/gpu/device/impl/device_index_pool_bwd_impl.hpp
..._operation/gpu/device/impl/device_index_pool_bwd_impl.hpp
+9
-10
No files found.
include/ck/tensor_operation/gpu/device/impl/device_index_pool_bwd_impl.hpp
View file @
47ac3767
...
@@ -95,6 +95,7 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
...
@@ -95,6 +95,7 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
:
p_dout_
{
p_dout
},
:
p_dout_
{
p_dout
},
p_indices_
{
p_indices
},
p_indices_
{
p_indices
},
p_din_
{
p_din
},
p_din_
{
p_din
},
din_length_raw_
{
din_length
},
blockSize_
{
256
},
blockSize_
{
256
},
gridSize_
{
104
},
// FIXME - Calculate the grid size by number of CU in the future
gridSize_
{
104
},
// FIXME - Calculate the grid size by number of CU in the future
windowOverlap_
{
false
}
windowOverlap_
{
false
}
...
@@ -112,6 +113,7 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
...
@@ -112,6 +113,7 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
const
DOutDataType
*
p_dout_
;
const
DOutDataType
*
p_dout_
;
const
IndexDataType
*
p_indices_
;
const
IndexDataType
*
p_indices_
;
DInDataType
*
p_din_
;
DInDataType
*
p_din_
;
index_t
din_length_raw_
;
index_t
blockSize_
;
index_t
blockSize_
;
index_t
gridSize_
;
index_t
gridSize_
;
bool
windowOverlap_
;
bool
windowOverlap_
;
...
@@ -123,12 +125,12 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
...
@@ -123,12 +125,12 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
{
{
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
{
index_t
din_length_raw
=
arg
.
din_grid_desc_
.
GetTransforms
()[
I0
].
GetUpperLengths
()[
I0
];
if
constexpr
(
is_same_v
<
DInDataType
,
float
>
||
is_same_v
<
DInDataType
,
double
>
)
if
constexpr
(
is_same_v
<
DInDataType
,
float
>
||
is_same_v
<
DInDataType
,
double
>
)
{
{
hip_check_error
(
hipMemsetAsync
(
hip_check_error
(
hipMemsetAsync
(
arg
.
p_din_
,
arg
.
p_din_
,
0
,
din_length_raw
*
sizeof
(
DInDataType
),
stream_config
.
stream_id_
));
0
,
arg
.
din_length_raw_
*
sizeof
(
DInDataType
),
stream_config
.
stream_id_
));
if
(
arg
.
windowOverlap_
)
if
(
arg
.
windowOverlap_
)
{
{
...
@@ -181,7 +183,7 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
...
@@ -181,7 +183,7 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
hip_check_error
(
hip_check_error
(
hipMemsetAsync
(
arg
.
p_workspace_
,
hipMemsetAsync
(
arg
.
p_workspace_
,
0
,
0
,
din_length_raw
*
sizeof
(
DInDataType_AutomicAddPreCast
),
arg
.
din_length_raw_
*
sizeof
(
DInDataType_AutomicAddPreCast
),
stream_config
.
stream_id_
));
stream_config
.
stream_id_
));
const
auto
put_kernel
=
kernel_put_element_1d
<
GridwisePutElementAtomicAdd
,
const
auto
put_kernel
=
kernel_put_element_1d
<
GridwisePutElementAtomicAdd
,
...
@@ -236,7 +238,7 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
...
@@ -236,7 +238,7 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
hip_check_error
(
hipMemsetAsync
(
arg
.
p_din_
,
hip_check_error
(
hipMemsetAsync
(
arg
.
p_din_
,
0
,
0
,
din_length_raw
*
sizeof
(
DInDataType
),
arg
.
din_length_raw_
*
sizeof
(
DInDataType
),
stream_config
.
stream_id_
));
stream_config
.
stream_id_
));
return
launch_and_time_kernel
(
stream_config
,
return
launch_and_time_kernel
(
stream_config
,
...
@@ -270,10 +272,7 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
...
@@ -270,10 +272,7 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
if
(
!
needCast
)
if
(
!
needCast
)
return
0
;
return
0
;
else
else
{
return
pArg_
->
din_length_raw_
*
sizeof
(
DInDataType_AutomicAddPreCast
);
index_t
din_length
=
pArg_
->
din_grid_desc_
.
GetTransforms
()[
I0
].
GetUpperLengths
()[
I0
];
return
din_length
*
sizeof
(
DInDataType_AutomicAddPreCast
);
}
};
};
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please register or sign in to comment