Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
0ab4fa0f
Commit
0ab4fa0f
authored
Jul 10, 2023
by
rocking
Browse files
Check if argument is valid
parent
400cb28e
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
50 additions
and
7 deletions
+50
-7
example/51_avgpool3d_bwd/avgpool3d_bwd_fp16.cpp
example/51_avgpool3d_bwd/avgpool3d_bwd_fp16.cpp
+2
-1
include/ck/tensor_operation/gpu/device/device_avgpool_bwd.hpp
...ude/ck/tensor_operation/gpu/device/device_avgpool_bwd.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_avgpool3d_bwd_impl.hpp
...r_operation/gpu/device/impl/device_avgpool3d_bwd_impl.hpp
+47
-5
No files found.
example/51_avgpool3d_bwd/avgpool3d_bwd_fp16.cpp
View file @
0ab4fa0f
...
@@ -75,7 +75,8 @@ bool pool3d_bwd_test(bool do_verification,
...
@@ -75,7 +75,8 @@ bool pool3d_bwd_test(bool do_verification,
std
::
vector
<
ck
::
index_t
>
dinput_right_pads
)
std
::
vector
<
ck
::
index_t
>
dinput_right_pads
)
{
{
using
DevicePoolBwdInstance
=
using
DevicePoolBwdInstance
=
ck
::
tensor_operation
::
device
::
DeviceAvgPool3dBwdImpl
<
DOutDataType
,
ck
::
tensor_operation
::
device
::
DeviceAvgPool3dBwdImpl
<
3
,
DOutDataType
,
DInDataType
,
DInDataType
,
ComputeDataType
,
// ComputeDataType
ComputeDataType
,
// ComputeDataType
64
,
// BlockSize
64
,
// BlockSize
...
...
include/ck/tensor_operation/gpu/device/device_avgpool_bwd.hpp
View file @
0ab4fa0f
...
@@ -12,7 +12,7 @@ namespace ck {
...
@@ -12,7 +12,7 @@ namespace ck {
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
template
<
typename
DOutDataType
,
typename
DInDataType
>
template
<
index_t
NDimSpatial
,
typename
DOutDataType
,
typename
DInDataType
>
struct
DeviceAvgPoolBwd
:
public
BaseOperator
struct
DeviceAvgPoolBwd
:
public
BaseOperator
{
{
virtual
std
::
unique_ptr
<
BaseArgument
>
virtual
std
::
unique_ptr
<
BaseArgument
>
...
...
include/ck/tensor_operation/gpu/device/impl/device_avgpool3d_bwd_impl.hpp
View file @
0ab4fa0f
...
@@ -23,7 +23,8 @@ namespace device {
...
@@ -23,7 +23,8 @@ namespace device {
// Out = AvgPoolFwd(In)
// Out = AvgPoolFwd(In)
// Din = AvgPoolBwd(Dout)
// Din = AvgPoolBwd(Dout)
// Pooling dimension = D, H, W
// Pooling dimension = D, H, W
template
<
typename
DOutDataType
,
template
<
index_t
NDimSpatial
,
typename
DOutDataType
,
typename
DInDataType
,
typename
DInDataType
,
typename
ComputeDataType
,
typename
ComputeDataType
,
ck
::
index_t
BlockSize
,
ck
::
index_t
BlockSize
,
...
@@ -33,10 +34,8 @@ template <typename DOutDataType,
...
@@ -33,10 +34,8 @@ template <typename DOutDataType,
ck
::
index_t
KThreadSliceSize
,
ck
::
index_t
KThreadSliceSize
,
ck
::
index_t
InSrcOutDstVectorSize
,
ck
::
index_t
InSrcOutDstVectorSize
,
bool
IsFastestDimReduced
>
bool
IsFastestDimReduced
>
struct
DeviceAvgPool3dBwdImpl
:
public
DeviceAvgPoolBwd
<
DOutDataType
,
DInDataType
>
struct
DeviceAvgPool3dBwdImpl
:
public
DeviceAvgPoolBwd
<
NDimSpatial
,
DOutDataType
,
DInDataType
>
{
{
static
constexpr
index_t
NDimSpatial
=
3
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
...
@@ -356,6 +355,10 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
...
@@ -356,6 +355,10 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
std
::
vector
<
ck
::
index_t
>
input_right_pads
)
std
::
vector
<
ck
::
index_t
>
input_right_pads
)
:
p_dout_grid_
{
p_dout
},
:
p_dout_grid_
{
p_dout
},
p_din_grid_
{
p_din
},
p_din_grid_
{
p_din
},
dout_n_c_wos_lengths_
{
dout_n_c_wos_lengths
},
din_n_c_wos_length_
{
din_n_c_wos_length
},
dout_n_c_wos_strides_
{
dout_n_c_wos_strides
},
din_n_c_wos_strides_
{
din_n_c_wos_strides
},
num_reduce_
{
1
},
num_reduce_
{
1
},
div_element_op_
{
window_lengths
[
0
]
*
window_lengths
[
1
]
*
window_lengths
[
2
]}
div_element_op_
{
window_lengths
[
0
]
*
window_lengths
[
1
]
*
window_lengths
[
2
]}
{
{
...
@@ -407,6 +410,10 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
...
@@ -407,6 +410,10 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
const
DOutDataType
*
p_dout_grid_
;
const
DOutDataType
*
p_dout_grid_
;
DInDataType
*
p_din_grid_
;
DInDataType
*
p_din_grid_
;
std
::
vector
<
ck
::
index_t
>
dout_n_c_wos_lengths_
;
std
::
vector
<
ck
::
index_t
>
din_n_c_wos_length_
;
std
::
vector
<
ck
::
index_t
>
dout_n_c_wos_strides_
;
std
::
vector
<
ck
::
index_t
>
din_n_c_wos_strides_
;
int
num_reduce_
;
int
num_reduce_
;
std
::
vector
<
DoutGridDesc_M_K
>
dout_grid_desc_m_k_container_
;
std
::
vector
<
DoutGridDesc_M_K
>
dout_grid_desc_m_k_container_
;
...
@@ -468,7 +475,31 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
...
@@ -468,7 +475,31 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
{
ignore
=
arg
;
constexpr
index_t
Rank
=
NDimSpatial
+
2
;
int
doutFastestDim
=
-
1
;
int
dinFastestDim
=
-
1
;
for
(
int
i
=
0
;
i
<
Rank
;
++
i
)
{
if
(
arg
.
dout_n_c_wos_strides_
[
i
]
==
1
)
doutFastestDim
=
i
;
if
(
arg
.
din_n_c_wos_strides_
[
i
]
==
1
)
dinFastestDim
=
i
;
}
if
(
doutFastestDim
==
-
1
||
dinFastestDim
==
-
1
)
{
if
constexpr
(
InSrcOutDstVectorSize
!=
1
)
return
false
;
}
else
{
if
(
arg
.
dout_n_c_wos_lengths_
[
doutFastestDim
]
%
InSrcOutDstVectorSize
!=
0
)
return
false
;
if
(
arg
.
din_n_c_wos_length_
[
dinFastestDim
]
%
InSrcOutDstVectorSize
!=
0
)
return
false
;
}
return
true
;
return
true
;
}
}
...
@@ -490,6 +521,17 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
...
@@ -490,6 +521,17 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
)
override
std
::
vector
<
ck
::
index_t
>
input_right_pads
)
override
{
{
constexpr
index_t
Rank
=
NDimSpatial
+
2
;
if
(
dout_n_c_wos_strides
.
size
()
!=
Rank
||
din_n_c_wos_strides
.
size
()
!=
Rank
||
dout_n_c_wos_lengths
.
size
()
!=
Rank
||
din_n_c_wos_length
.
size
()
!=
Rank
)
throw
std
::
runtime_error
(
"dimension is incorrect"
);
if
(
window_lengths
.
size
()
!=
NDimSpatial
||
window_strides
.
size
()
!=
NDimSpatial
||
window_dilations
.
size
()
!=
NDimSpatial
||
input_left_pads
.
size
()
!=
NDimSpatial
||
input_right_pads
.
size
()
!=
NDimSpatial
)
throw
std
::
runtime_error
(
"dimension is incorrect"
);
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
DOutDataType
*>
(
p_dout
),
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
DOutDataType
*>
(
p_dout
),
static_cast
<
DInDataType
*>
(
p_din
),
static_cast
<
DInDataType
*>
(
p_din
),
dout_n_c_wos_lengths
,
dout_n_c_wos_lengths
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment