Unverified Commit cc332b26 authored by robin Han's avatar robin Han Committed by GitHub
Browse files

add switch for onnx exporter (#564)

parent 21143568
"""Modified from https://github.com/pytorch/pytorch.""" """Modified from https://github.com/pytorch/pytorch."""
import os
import numpy as np import numpy as np
import torch import torch
from torch.nn.modules.utils import _pair, _single, _triple from torch.nn.modules.utils import _pair, _single, _triple
...@@ -21,14 +23,27 @@ def _interpolate(name, dim, interpolate_mode): ...@@ -21,14 +23,27 @@ def _interpolate(name, dim, interpolate_mode):
'Constant', value_t=torch.tensor([], dtype=torch.float32)) 'Constant', value_t=torch.tensor([], dtype=torch.float32))
if scales is None: if scales is None:
input_size = g.op('Shape', input) if 'ONNX_BACKEND' in os.environ and os.environ[
input_size_beg = sym_help._slice_helper( 'ONNX_BACKEND'] == 'TensorRT':
g, input_size, axes=[0], ends=[2], starts=[0]) input_size = input.type().sizes()
output_size = g.op( # slice the first two dim
'Cast', input_size = input_size[:2]
output_size, # convert output_size to int type
to_i=sym_help.cast_pytorch_to_onnx['Long']) output_size = sym_help._maybe_get_const(output_size, 'is')
output_size = g.op('Concat', input_size_beg, output_size, axis_i=0) input_size.extend(output_size)
output_size = g.op(
'Constant',
value_t=torch.tensor(input_size, dtype=torch.int64))
else:
input_size = g.op('Shape', input)
input_size_beg = sym_help._slice_helper(
g, input_size, axes=[0], ends=[2], starts=[0])
output_size = g.op(
'Cast',
output_size,
to_i=sym_help.cast_pytorch_to_onnx['Long'])
output_size = g.op(
'Concat', input_size_beg, output_size, axis_i=0)
scales = g.op( scales = g.op(
'Constant', value_t=torch.tensor([], dtype=torch.float32)) 'Constant', value_t=torch.tensor([], dtype=torch.float32))
return g.op( return g.op(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment