Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
af9b4f25
"...composable_kernel.git" did not exist on "0bdbd358ecee2eb434e27bc900a3e248bcd2b2bd"
Commit
af9b4f25
authored
Jun 27, 2023
by
rocking
Browse files
Refine naming
parent
6ab0ace0
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
77 additions
and
114 deletions
+77
-114
library/include/ck/library/reference_tensor_operation/cpu/reference_avgpool_bwd.hpp
.../reference_tensor_operation/cpu/reference_avgpool_bwd.hpp
+77
-114
No files found.
library/include/ck/library/reference_tensor_operation/cpu/reference_avgpool_bwd.hpp
View file @
af9b4f25
...
@@ -14,52 +14,43 @@ namespace ck {
...
@@ -14,52 +14,43 @@ namespace ck {
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
host
{
namespace
host
{
// input descriptor in [N, C, Do, Ho, Wo] order
//
d
input descriptor in [N, C, Do, Ho, Wo] order
// output descriptor in [N, C, Di, Hi, Wi] order
//
d
output descriptor in [N, C, Di, Hi, Wi] order
// phyiscal layout is irrelavent
// phyiscal layout is irrelavent
template
<
ck
::
index_t
NDimSpatial
,
template
<
ck
::
index_t
NDimSpatial
,
typename
InDataType
,
typename
DInDataType
,
typename
OutDataType
,
typename
DOutDataType
,
typename
InElementwiseOperation
,
typename
OutElementwiseOperation
,
typename
std
::
enable_if
<
NDimSpatial
>
=
1
&&
NDimSpatial
<=
3
,
bool
>::
type
=
false
>
typename
std
::
enable_if
<
NDimSpatial
>
=
1
&&
NDimSpatial
<=
3
,
bool
>::
type
=
false
>
struct
ReferenceAvgPoolBwd
:
public
device
::
BaseOperator
struct
ReferenceAvgPoolBwd
:
public
device
::
BaseOperator
{
{
// Argument
// Argument
struct
Argument
:
public
device
::
BaseArgument
struct
Argument
:
public
device
::
BaseArgument
{
{
Argument
(
Tensor
<
InDataType
>&
input
,
Argument
(
Tensor
<
D
InDataType
>&
d
input
,
const
Tensor
<
OutDataType
>&
output
,
const
Tensor
<
D
OutDataType
>&
d
output
,
std
::
vector
<
ck
::
index_t
>
window_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
window_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
window_strides
,
std
::
vector
<
ck
::
index_t
>
window_strides
,
std
::
vector
<
ck
::
index_t
>
window_dilations
,
std
::
vector
<
ck
::
index_t
>
window_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
dinput_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
vector
<
ck
::
index_t
>
dinput_right_pads
)
InElementwiseOperation
in_element_op
,
:
dinput_
{
dinput
},
OutElementwiseOperation
out_element_op
)
doutput_
{
doutput
},
:
input_
{
input
},
output_
{
output
},
window_spatial_lengths_
{
window_spatial_lengths
},
window_spatial_lengths_
{
window_spatial_lengths
},
conv_strides_
{
window_strides
},
window_strides_
{
window_strides
},
conv_dilations_
{
window_dilations
},
window_dilations_
{
window_dilations
},
in_left_pads_
{
input_left_pads
},
in_left_pads_
{
dinput_left_pads
},
in_right_pads_
{
input_right_pads
},
in_right_pads_
{
dinput_right_pads
}
in_element_op_
{
in_element_op
},
out_element_op_
{
out_element_op
}
{
{
}
}
Tensor
<
InDataType
>&
input_
;
Tensor
<
D
InDataType
>&
d
input_
;
const
Tensor
<
OutDataType
>&
output_
;
const
Tensor
<
D
OutDataType
>&
d
output_
;
std
::
vector
<
ck
::
index_t
>
window_spatial_lengths_
;
std
::
vector
<
ck
::
index_t
>
window_spatial_lengths_
;
std
::
vector
<
index_t
>
conv
_strides_
;
std
::
vector
<
index_t
>
window
_strides_
;
std
::
vector
<
index_t
>
conv
_dilations_
;
std
::
vector
<
index_t
>
window
_dilations_
;
std
::
vector
<
index_t
>
in_left_pads_
;
std
::
vector
<
index_t
>
in_left_pads_
;
std
::
vector
<
index_t
>
in_right_pads_
;
std
::
vector
<
index_t
>
in_right_pads_
;
InElementwiseOperation
in_element_op_
;
OutElementwiseOperation
out_element_op_
;
};
};
// Invoker
// Invoker
...
@@ -69,8 +60,8 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator
...
@@ -69,8 +60,8 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator
float
Run
(
const
Argument
&
arg
)
float
Run
(
const
Argument
&
arg
)
{
{
if
(
!
(
arg
.
input_
.
GetNumOfDimension
()
==
NDimSpatial
+
2
&&
if
(
!
(
arg
.
d
input_
.
GetNumOfDimension
()
==
NDimSpatial
+
2
&&
arg
.
output_
.
GetNumOfDimension
()
==
NDimSpatial
+
2
))
arg
.
d
output_
.
GetNumOfDimension
()
==
NDimSpatial
+
2
))
{
{
throw
std
::
runtime_error
(
"wrong! inconsistent dimension"
);
throw
std
::
runtime_error
(
"wrong! inconsistent dimension"
);
}
}
...
@@ -79,7 +70,7 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator
...
@@ -79,7 +70,7 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator
{
{
auto
f_ncw
=
[
&
](
auto
n
,
auto
c
,
auto
wi
)
{
auto
f_ncw
=
[
&
](
auto
n
,
auto
c
,
auto
wi
)
{
std
::
size_t
X
=
arg
.
window_spatial_lengths_
[
0
];
std
::
size_t
X
=
arg
.
window_spatial_lengths_
[
0
];
std
::
size_t
Wo
=
arg
.
output_
.
GetLengths
()[
2
];
std
::
size_t
Wo
=
arg
.
d
output_
.
GetLengths
()[
2
];
float
v_acc
=
0
;
float
v_acc
=
0
;
...
@@ -87,35 +78,28 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator
...
@@ -87,35 +78,28 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator
{
{
auto
w_tmp
=
static_cast
<
ck
::
long_index_t
>
(
wi
)
+
auto
w_tmp
=
static_cast
<
ck
::
long_index_t
>
(
wi
)
+
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
0
])
-
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
0
])
-
static_cast
<
ck
::
long_index_t
>
(
x
*
arg
.
conv
_dilations_
[
0
]);
static_cast
<
ck
::
long_index_t
>
(
x
*
arg
.
window
_dilations_
[
0
]);
if
(
w_tmp
%
arg
.
conv
_strides_
[
0
]
==
0
)
if
(
w_tmp
%
arg
.
window
_strides_
[
0
]
==
0
)
{
{
auto
wo
=
static_cast
<
ck
::
long_index_t
>
(
w_tmp
)
/
auto
wo
=
static_cast
<
ck
::
long_index_t
>
(
w_tmp
)
/
static_cast
<
ck
::
long_index_t
>
(
arg
.
conv
_strides_
[
0
]);
static_cast
<
ck
::
long_index_t
>
(
arg
.
window
_strides_
[
0
]);
if
(
wo
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
wo
)
<
Wo
)
if
(
wo
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
wo
)
<
Wo
)
{
{
float
v_out
=
0
;
v_acc
+=
ck
::
type_convert
<
float
>
(
arg
.
doutput_
(
n
,
c
,
wo
));
arg
.
out_element_op_
(
v_out
,
ck
::
type_convert
<
float
>
(
arg
.
output_
(
n
,
c
,
wo
)));
v_acc
+=
v_out
;
}
}
}
}
}
}
v_acc
/=
ck
::
type_convert
<
float
>
(
X
);
float
v_in
;
arg
.
in_element_op_
(
v_in
,
v_acc
);
arg
.
input_
(
n
,
c
,
wi
)
=
ck
::
type_convert
<
InDataType
>
(
v_acc
);
v_acc
/=
ck
::
type_convert
<
float
>
(
X
);
arg
.
dinput_
(
n
,
c
,
wi
)
=
ck
::
type_convert
<
DInDataType
>
(
v_acc
);
};
};
make_ParallelTensorFunctor
(
f_ncw
,
make_ParallelTensorFunctor
(
f_ncw
,
arg
.
input_
.
GetLengths
()[
0
],
arg
.
d
input_
.
GetLengths
()[
0
],
arg
.
input_
.
GetLengths
()[
1
],
arg
.
d
input_
.
GetLengths
()[
1
],
arg
.
input_
.
GetLengths
()[
2
])(
arg
.
d
input_
.
GetLengths
()[
2
])(
std
::
thread
::
hardware_concurrency
());
std
::
thread
::
hardware_concurrency
());
return
0
;
return
0
;
...
@@ -126,8 +110,8 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator
...
@@ -126,8 +110,8 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator
std
::
size_t
Y
=
arg
.
window_spatial_lengths_
[
0
];
std
::
size_t
Y
=
arg
.
window_spatial_lengths_
[
0
];
std
::
size_t
X
=
arg
.
window_spatial_lengths_
[
1
];
std
::
size_t
X
=
arg
.
window_spatial_lengths_
[
1
];
std
::
size_t
Ho
=
arg
.
output_
.
GetLengths
()[
2
];
std
::
size_t
Ho
=
arg
.
d
output_
.
GetLengths
()[
2
];
std
::
size_t
Wo
=
arg
.
output_
.
GetLengths
()[
3
];
std
::
size_t
Wo
=
arg
.
d
output_
.
GetLengths
()[
3
];
float
v_acc
=
0
;
float
v_acc
=
0
;
...
@@ -135,11 +119,11 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator
...
@@ -135,11 +119,11 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator
{
{
auto
h_tmp
=
static_cast
<
ck
::
long_index_t
>
(
hi
)
+
auto
h_tmp
=
static_cast
<
ck
::
long_index_t
>
(
hi
)
+
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
0
])
-
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
0
])
-
static_cast
<
ck
::
long_index_t
>
(
y
*
arg
.
conv
_dilations_
[
0
]);
static_cast
<
ck
::
long_index_t
>
(
y
*
arg
.
window
_dilations_
[
0
]);
if
(
h_tmp
%
arg
.
conv
_strides_
[
0
]
==
0
)
if
(
h_tmp
%
arg
.
window
_strides_
[
0
]
==
0
)
{
{
auto
ho
=
static_cast
<
ck
::
long_index_t
>
(
h_tmp
)
/
auto
ho
=
static_cast
<
ck
::
long_index_t
>
(
h_tmp
)
/
static_cast
<
ck
::
long_index_t
>
(
arg
.
conv
_strides_
[
0
]);
static_cast
<
ck
::
long_index_t
>
(
arg
.
window
_strides_
[
0
]);
if
(
ho
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
ho
)
<
Ho
)
if
(
ho
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
ho
)
<
Ho
)
{
{
for
(
std
::
size_t
x
=
0
;
x
<
X
;
++
x
)
for
(
std
::
size_t
x
=
0
;
x
<
X
;
++
x
)
...
@@ -147,40 +131,32 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator
...
@@ -147,40 +131,32 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator
auto
w_tmp
=
auto
w_tmp
=
static_cast
<
ck
::
long_index_t
>
(
wi
)
+
static_cast
<
ck
::
long_index_t
>
(
wi
)
+
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
1
])
-
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
1
])
-
static_cast
<
ck
::
long_index_t
>
(
x
*
arg
.
conv
_dilations_
[
1
]);
static_cast
<
ck
::
long_index_t
>
(
x
*
arg
.
window
_dilations_
[
1
]);
if
(
w_tmp
%
arg
.
conv
_strides_
[
1
]
==
0
)
if
(
w_tmp
%
arg
.
window
_strides_
[
1
]
==
0
)
{
{
auto
wo
=
auto
wo
=
static_cast
<
ck
::
long_index_t
>
(
w_tmp
)
/
static_cast
<
ck
::
long_index_t
>
(
w_tmp
)
/
static_cast
<
ck
::
long_index_t
>
(
arg
.
conv
_strides_
[
1
]);
static_cast
<
ck
::
long_index_t
>
(
arg
.
window
_strides_
[
1
]);
if
(
wo
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
wo
)
<
Wo
)
if
(
wo
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
wo
)
<
Wo
)
{
{
float
v_out
=
0
;
v_acc
+=
arg
.
out_element_op_
(
ck
::
type_convert
<
float
>
(
arg
.
doutput_
(
n
,
c
,
ho
,
wo
));
v_out
,
ck
::
type_convert
<
float
>
(
arg
.
output_
(
n
,
c
,
ho
,
wo
)));
v_acc
+=
v_out
;
}
}
}
}
}
}
}
}
}
}
}
}
v_acc
/=
ck
::
type_convert
<
float
>
(
Y
*
X
);
float
v_in
;
arg
.
in_element_op_
(
v_in
,
v_acc
);
v_acc
/=
ck
::
type_convert
<
float
>
(
Y
*
X
);
arg
.
dinput_
(
n
,
c
,
hi
,
wi
)
=
ck
::
type_convert
<
DInDataType
>
(
v_acc
);
arg
.
input_
(
n
,
c
,
hi
,
wi
)
=
ck
::
type_convert
<
InDataType
>
(
v_acc
);
};
};
make_ParallelTensorFunctor
(
f_nchw
,
make_ParallelTensorFunctor
(
f_nchw
,
arg
.
input_
.
GetLengths
()[
0
],
arg
.
d
input_
.
GetLengths
()[
0
],
arg
.
input_
.
GetLengths
()[
1
],
arg
.
d
input_
.
GetLengths
()[
1
],
arg
.
input_
.
GetLengths
()[
2
],
arg
.
d
input_
.
GetLengths
()[
2
],
arg
.
input_
.
GetLengths
()[
3
])(
arg
.
d
input_
.
GetLengths
()[
3
])(
std
::
thread
::
hardware_concurrency
());
std
::
thread
::
hardware_concurrency
());
return
0
;
return
0
;
...
@@ -192,9 +168,9 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator
...
@@ -192,9 +168,9 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator
std
::
size_t
Y
=
arg
.
window_spatial_lengths_
[
1
];
std
::
size_t
Y
=
arg
.
window_spatial_lengths_
[
1
];
std
::
size_t
X
=
arg
.
window_spatial_lengths_
[
2
];
std
::
size_t
X
=
arg
.
window_spatial_lengths_
[
2
];
std
::
size_t
Do
=
arg
.
output_
.
GetLengths
()[
2
];
std
::
size_t
Do
=
arg
.
d
output_
.
GetLengths
()[
2
];
std
::
size_t
Ho
=
arg
.
output_
.
GetLengths
()[
3
];
std
::
size_t
Ho
=
arg
.
d
output_
.
GetLengths
()[
3
];
std
::
size_t
Wo
=
arg
.
output_
.
GetLengths
()[
4
];
std
::
size_t
Wo
=
arg
.
d
output_
.
GetLengths
()[
4
];
float
v_acc
=
0
;
float
v_acc
=
0
;
...
@@ -202,11 +178,11 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator
...
@@ -202,11 +178,11 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator
{
{
auto
d_tmp
=
static_cast
<
ck
::
long_index_t
>
(
di
)
+
auto
d_tmp
=
static_cast
<
ck
::
long_index_t
>
(
di
)
+
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
0
])
-
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
0
])
-
static_cast
<
ck
::
long_index_t
>
(
z
*
arg
.
conv
_dilations_
[
0
]);
static_cast
<
ck
::
long_index_t
>
(
z
*
arg
.
window
_dilations_
[
0
]);
if
(
d_tmp
%
arg
.
conv
_strides_
[
0
]
==
0
)
if
(
d_tmp
%
arg
.
window
_strides_
[
0
]
==
0
)
{
{
auto
do_
=
static_cast
<
ck
::
long_index_t
>
(
d_tmp
)
/
auto
do_
=
static_cast
<
ck
::
long_index_t
>
(
d_tmp
)
/
static_cast
<
ck
::
long_index_t
>
(
arg
.
conv
_strides_
[
0
]);
static_cast
<
ck
::
long_index_t
>
(
arg
.
window
_strides_
[
0
]);
if
(
do_
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
do_
)
<
Do
)
if
(
do_
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
do_
)
<
Do
)
{
{
for
(
std
::
size_t
y
=
0
;
y
<
Y
;
++
y
)
for
(
std
::
size_t
y
=
0
;
y
<
Y
;
++
y
)
...
@@ -214,12 +190,12 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator
...
@@ -214,12 +190,12 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator
auto
h_tmp
=
auto
h_tmp
=
static_cast
<
ck
::
long_index_t
>
(
hi
)
+
static_cast
<
ck
::
long_index_t
>
(
hi
)
+
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
1
])
-
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
1
])
-
static_cast
<
ck
::
long_index_t
>
(
y
*
arg
.
conv
_dilations_
[
1
]);
static_cast
<
ck
::
long_index_t
>
(
y
*
arg
.
window
_dilations_
[
1
]);
if
(
h_tmp
%
arg
.
conv
_strides_
[
1
]
==
0
)
if
(
h_tmp
%
arg
.
window
_strides_
[
1
]
==
0
)
{
{
auto
ho
=
auto
ho
=
static_cast
<
ck
::
long_index_t
>
(
h_tmp
)
/
static_cast
<
ck
::
long_index_t
>
(
h_tmp
)
/
static_cast
<
ck
::
long_index_t
>
(
arg
.
conv
_strides_
[
1
]);
static_cast
<
ck
::
long_index_t
>
(
arg
.
window
_strides_
[
1
]);
if
(
ho
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
ho
)
<
Ho
)
if
(
ho
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
ho
)
<
Ho
)
{
{
for
(
std
::
size_t
x
=
0
;
x
<
X
;
++
x
)
for
(
std
::
size_t
x
=
0
;
x
<
X
;
++
x
)
...
@@ -228,23 +204,18 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator
...
@@ -228,23 +204,18 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator
static_cast
<
ck
::
long_index_t
>
(
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
2
])
-
arg
.
in_left_pads_
[
2
])
-
static_cast
<
ck
::
long_index_t
>
(
static_cast
<
ck
::
long_index_t
>
(
x
*
arg
.
conv
_dilations_
[
2
]);
x
*
arg
.
window
_dilations_
[
2
]);
if
(
w_tmp
%
arg
.
conv
_strides_
[
2
]
==
0
)
if
(
w_tmp
%
arg
.
window
_strides_
[
2
]
==
0
)
{
{
auto
wo
=
static_cast
<
ck
::
long_index_t
>
(
w_tmp
)
/
auto
wo
=
static_cast
<
ck
::
long_index_t
>
(
w_tmp
)
/
static_cast
<
ck
::
long_index_t
>
(
static_cast
<
ck
::
long_index_t
>
(
arg
.
conv
_strides_
[
2
]);
arg
.
window
_strides_
[
2
]);
if
(
wo
>=
0
&&
if
(
wo
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
wo
)
<
Wo
)
ck
::
type_convert
<
std
::
size_t
>
(
wo
)
<
Wo
)
{
{
float
v_out
=
0
;
v_acc
+=
ck
::
type_convert
<
float
>
(
arg
.
out_element_op_
(
arg
.
doutput_
(
n
,
c
,
do_
,
ho
,
wo
));
v_out
,
ck
::
type_convert
<
float
>
(
arg
.
output_
(
n
,
c
,
do_
,
ho
,
wo
)));
v_acc
+=
v_out
;
}
}
}
}
}
}
...
@@ -254,21 +225,17 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator
...
@@ -254,21 +225,17 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator
}
}
}
}
}
}
v_acc
/=
ck
::
type_convert
<
float
>
(
Z
*
Y
*
X
);
float
v_in
;
v_acc
/=
ck
::
type_convert
<
float
>
(
Z
*
Y
*
X
);
arg
.
dinput_
(
n
,
c
,
di
,
hi
,
wi
)
=
ck
::
type_convert
<
DInDataType
>
(
v_acc
);
arg
.
in_element_op_
(
v_in
,
v_acc
);
arg
.
input_
(
n
,
c
,
di
,
hi
,
wi
)
=
ck
::
type_convert
<
InDataType
>
(
v_acc
);
};
};
make_ParallelTensorFunctor
(
f_ncdhw
,
make_ParallelTensorFunctor
(
f_ncdhw
,
arg
.
input_
.
GetLengths
()[
0
],
arg
.
d
input_
.
GetLengths
()[
0
],
arg
.
input_
.
GetLengths
()[
1
],
arg
.
d
input_
.
GetLengths
()[
1
],
arg
.
input_
.
GetLengths
()[
2
],
arg
.
d
input_
.
GetLengths
()[
2
],
arg
.
input_
.
GetLengths
()[
3
],
arg
.
d
input_
.
GetLengths
()[
3
],
arg
.
input_
.
GetLengths
()[
4
])(
arg
.
d
input_
.
GetLengths
()[
4
])(
std
::
thread
::
hardware_concurrency
());
std
::
thread
::
hardware_concurrency
());
return
0
;
return
0
;
...
@@ -290,30 +257,26 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator
...
@@ -290,30 +257,26 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator
bool
IsSupportedArgument
(
const
device
::
BaseArgument
*
)
override
{
return
true
;
}
bool
IsSupportedArgument
(
const
device
::
BaseArgument
*
)
override
{
return
true
;
}
static
auto
MakeArgument
(
Tensor
<
InDataType
>&
input
,
static
auto
MakeArgument
(
Tensor
<
D
InDataType
>&
d
input
,
const
Tensor
<
OutDataType
>&
output
,
const
Tensor
<
D
OutDataType
>&
d
output
,
std
::
vector
<
ck
::
index_t
>
window_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
window_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
window_strides
,
std
::
vector
<
ck
::
index_t
>
window_strides
,
std
::
vector
<
ck
::
index_t
>
window_dilations
,
std
::
vector
<
ck
::
index_t
>
window_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
dinput_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
vector
<
ck
::
index_t
>
dinput_right_pads
)
InElementwiseOperation
in_element_op
,
OutElementwiseOperation
out_element_op
)
{
{
if
(
window_spatial_lengths
.
size
()
!=
NDimSpatial
||
window_strides
.
size
()
!=
NDimSpatial
||
if
(
window_spatial_lengths
.
size
()
!=
NDimSpatial
||
window_strides
.
size
()
!=
NDimSpatial
||
window_dilations
.
size
()
!=
NDimSpatial
||
input_left_pads
.
size
()
!=
NDimSpatial
||
window_dilations
.
size
()
!=
NDimSpatial
||
d
input_left_pads
.
size
()
!=
NDimSpatial
||
input_right_pads
.
size
()
!=
NDimSpatial
)
d
input_right_pads
.
size
()
!=
NDimSpatial
)
throw
std
::
runtime_error
(
"dimension is incorrect"
);
throw
std
::
runtime_error
(
"dimension is incorrect"
);
return
Argument
{
input
,
return
Argument
{
d
input
,
output
,
d
output
,
window_spatial_lengths
,
window_spatial_lengths
,
window_strides
,
window_strides
,
window_dilations
,
window_dilations
,
input_left_pads
,
dinput_left_pads
,
input_right_pads
,
dinput_right_pads
};
in_element_op
,
out_element_op
};
}
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment