Unverified Commit 11d8dd53 authored by robin Han's avatar robin Han Committed by GitHub
Browse files

support ONNX adaptive average pooling (#504)



* support ONNX adaptive average pooling

* fix double quotes
Co-authored-by: default avatarKai Chen <chenkaidev@gmail.com>
parent 5e3f56f8
...@@ -305,6 +305,50 @@ def softmax(g, input, dim, dtype=None): ...@@ -305,6 +305,50 @@ def softmax(g, input, dim, dtype=None):
return softmax return softmax
def _adaptive_pool(name, type, tuple_fn, fn=None):
@parse_args('v', 'is')
def symbolic_fn(g, input, output_size):
if output_size == [1] * len(output_size) and type == 'AveragePool':
return g.op('GlobalAveragePool', input)
if not input.isCompleteTensor():
if output_size == [1] * len(output_size):
return g.op('GlobalMaxPool', input), None
raise NotImplementedError(
'[Adaptive pool]:input size not accessible')
dim = input.type().sizes()[2:]
if output_size == [1] * len(output_size) and type == 'MaxPool':
return g.op('GlobalMaxPool', input), None
# compute stride = floor(input_size / output_size)
s = [int(dim[i] / output_size[i]) for i in range(0, len(dim))]
# compute kernel_size = input_size - (output_size - 1) * stride
k = [dim[i] - (output_size[i] - 1) * s[i] for i in range(0, len(dim))]
# call max_poolxd_with_indices to get indices in the output
if type == 'MaxPool':
return fn(g, input, k, k, (0, ) * len(dim), (1, ) * len(dim),
False)
output = g.op(
type,
input,
kernel_shape_i=tuple_fn(k),
strides_i=tuple_fn(s),
ceil_mode_i=False)
return output
return symbolic_fn
adaptive_avg_pool1d = _adaptive_pool('adaptive_avg_pool1d', 'AveragePool',
_single)
adaptive_avg_pool2d = _adaptive_pool('adaptive_avg_pool2d', 'AveragePool',
_pair)
adaptive_avg_pool3d = _adaptive_pool('adaptive_avg_pool3d', 'AveragePool',
_triple)
def register_extra_symbolics(opset=11): def register_extra_symbolics(opset=11):
register_op('one_hot', one_hot, '', opset) register_op('one_hot', one_hot, '', opset)
register_op('im2col', im2col, '', opset) register_op('im2col', im2col, '', opset)
...@@ -317,6 +361,9 @@ def register_extra_symbolics(opset=11): ...@@ -317,6 +361,9 @@ def register_extra_symbolics(opset=11):
register_op('avg_pool1d', avg_pool1d, '', opset) register_op('avg_pool1d', avg_pool1d, '', opset)
register_op('avg_pool2d', avg_pool2d, '', opset) register_op('avg_pool2d', avg_pool2d, '', opset)
register_op('avg_pool3d', avg_pool3d, '', opset) register_op('avg_pool3d', avg_pool3d, '', opset)
register_op('adaptive_avg_pool1d', adaptive_avg_pool1d, '', opset)
register_op('adaptive_avg_pool2d', adaptive_avg_pool2d, '', opset)
register_op('adaptive_avg_pool3d', adaptive_avg_pool3d, '', opset)
register_op('masked_select', masked_select, '', opset) register_op('masked_select', masked_select, '', opset)
register_op('upsample_nearest1d', upsample_nearest1d, '', opset) register_op('upsample_nearest1d', upsample_nearest1d, '', opset)
register_op('upsample_nearest2d', upsample_nearest2d, '', opset) register_op('upsample_nearest2d', upsample_nearest2d, '', opset)
......
...@@ -39,4 +39,4 @@ PARROTS_EXTENSION_REGISTER(tin_shift_backward) ...@@ -39,4 +39,4 @@ PARROTS_EXTENSION_REGISTER(tin_shift_backward)
.input(2) .input(2)
.output(1) .output(1)
.apply(tin_shift_backward_cuda) .apply(tin_shift_backward_cuda)
.done(); .done();
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment