Unverified Commit cc332b26 authored by robin Han's avatar robin Han Committed by GitHub
Browse files

add switch for onnx exporter (#564)

parent 21143568
"""Modified from https://github.com/pytorch/pytorch."""
import os
import numpy as np
import torch
from torch.nn.modules.utils import _pair, _single, _triple
......@@ -21,14 +23,27 @@ def _interpolate(name, dim, interpolate_mode):
'Constant', value_t=torch.tensor([], dtype=torch.float32))
if scales is None:
input_size = g.op('Shape', input)
input_size_beg = sym_help._slice_helper(
g, input_size, axes=[0], ends=[2], starts=[0])
output_size = g.op(
'Cast',
output_size,
to_i=sym_help.cast_pytorch_to_onnx['Long'])
output_size = g.op('Concat', input_size_beg, output_size, axis_i=0)
if 'ONNX_BACKEND' in os.environ and os.environ[
'ONNX_BACKEND'] == 'TensorRT':
input_size = input.type().sizes()
# slice the first two dim
input_size = input_size[:2]
# convert output_size to int type
output_size = sym_help._maybe_get_const(output_size, 'is')
input_size.extend(output_size)
output_size = g.op(
'Constant',
value_t=torch.tensor(input_size, dtype=torch.int64))
else:
input_size = g.op('Shape', input)
input_size_beg = sym_help._slice_helper(
g, input_size, axes=[0], ends=[2], starts=[0])
output_size = g.op(
'Cast',
output_size,
to_i=sym_help.cast_pytorch_to_onnx['Long'])
output_size = g.op(
'Concat', input_size_beg, output_size, axis_i=0)
scales = g.op(
'Constant', value_t=torch.tensor([], dtype=torch.float32))
return g.op(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment