Unverified Commit 58d5c2fa authored by QuanluZhang, committed by GitHub

[retiarii] refactor of pytorch operators (#3365)

parent 59521d33
import logging
from typing import List
from typing import List, Tuple, Any
from ..graph import IllegalGraphError, Edge, Graph, Node, Model
......@@ -32,9 +32,26 @@ def _sorted_incoming_edges(node: Node) -> List[Edge]:
raise IllegalGraphError(node.graph, 'Node {} has bad inputs'.format(node.name))
def _format_inputs(node: Node) -> List[str]:
def _format_inputs(node: Node) -> Tuple[List[str], List[Any]]:
"""
Format the inputs of a given node
Parameters
----------
node : Node
a graph node whose inputs are retrieved and formatted
Returns
-------
list
the list of input names
list
the list of input values; if an input is of a simple type, its value is recorded,
otherwise the corresponding entry is None
"""
edges = _sorted_incoming_edges(node)
inputs = []
inputs_value = []
for edge in edges:
if edge.head.name == '_inputs':
assert isinstance(edge.head_slot, int)
......@@ -44,14 +61,21 @@ def _format_inputs(node: Node) -> List[str]:
else:
# when input has no name, e.g., forward(*_inputs)
inputs.append('_inputs[{}]'.format(edge.head_slot))
inputs_value.append(None)
else:
if edge.head_slot is None:
# when the input comes from a single-output operator
inputs.append('{}'.format(edge.head.name))
if edge.head.operation.type in ('prim::Constant', 'prim::GetAttr') and \
'value' in edge.head.operation.parameters:
inputs_value.append(edge.head.operation.parameters['value'])
else:
inputs_value.append(None)
else:
# when the input comes from a multi-output operator: needs to know which one it comes from
inputs.append('{}[{}]'.format(edge.head.name, edge.head_slot))
return inputs
inputs_value.append(None)
return inputs, inputs_value
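A minimal illustration (editor's sketch, not part of this diff) of how the two returned lists line up, assuming a hypothetical node fed by a `prim::Constant` holding 1 and by the graph input `_inputs[0]`:
inputs = ['constant2', '_inputs[0]']   # names as _format_inputs would produce them
inputs_value = [1, None]               # 1 recovered from the prim::Constant, None for the tensor input
assert len(inputs) == len(inputs_value)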
def _remove_prefix(names, graph_name):
......@@ -80,6 +104,8 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str
node_codes = []
for node in nodes:
if node.operation:
if node.operation.type == 'shared':
continue
pkg_name = node.operation.get_import_pkg()
if pkg_name is not None:
import_pkgs.add(pkg_name)
......@@ -101,12 +127,15 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str
sorted_nodes = graph.topo_sort()
for node in sorted_nodes:
if node.operation:
inputs = _format_inputs(node)
inputs, inputs_value = _format_inputs(node)
inputs = _remove_prefix(inputs, graph_name)
node_name = _remove_prefix(node.name, graph_name)
edge_codes.append(node.operation.to_forward_code(node_name, node_name, inputs))
submodule_name = node_name
if node.operation.type == 'shared':
submodule_name = _remove_prefix(node.operation.parameters['reference'], graph_name)
edge_codes.append(node.operation.to_forward_code(submodule_name, node_name, inputs, inputs_value))
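An editor's sketch (hypothetical names, not part of this diff) of the forward line emitted for a shared node: the referenced submodule is used as the field while the node keeps its own output name.
submodule_name, node_name, inputs = 'conv1', 'conv1_2', ['x']
line = f'{node_name} = self.{submodule_name}({", ".join(inputs)})'
assert line == 'conv1_2 = self.conv1(x)'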
output_names = _format_inputs(graph.output_node)
output_names, _ = _format_inputs(graph.output_node)
output_names = _remove_prefix(output_names, graph_name)
if not output_names:
raise RuntimeError('"forward" function should have return value(s): {}, {}, {}'.format(output_names, graph_name, graph.output_node))
......
......@@ -5,9 +5,9 @@ import torch
from ..graph import Graph, Model, Node
from ..nn.pytorch import InputChoice, LayerChoice, Placeholder
from ..operation import Cell
from ..operation import Cell, Operation
from ..utils import get_records
from .op_types import MODULE_EXCEPT_LIST, BasicOpsPT, OpTypeName
from .op_types import MODULE_EXCEPT_LIST, OpTypeName
from .utils import _convert_name, build_full_name
_logger = logging.getLogger(__name__)
......@@ -19,29 +19,7 @@ class GraphConverter:
self.global_graph_id = 0
self.modules_arg = get_records()
def _add_edge(self, ir_graph, node, graph_inputs, node_index, new_node, output_remap, ignore_first=False):
"""
Parameters
----------
ir_graph : Graph
node : torch._C.Node
graph_inputs : List[torch._C.Value]
a list of a script graph's inputs
node_index : Dict
new_node : Node
newly created ir node corresponding to `node`
output_remap : Dict
ignore_first : bool
if it is true, skip the first input
"""
is_single_input = (len([_input for _input in node.inputs()]) - (1 if ignore_first else 0)) == 1
new_node_input_idx = 0
for _input in node.inputs():
if ignore_first:
ignore_first = False
continue
# handle source node
def _add_edge_handle_source_node(self, _input, graph_inputs, ir_graph, output_remap, node_index):
if _input in graph_inputs:
idx = graph_inputs.index(_input)
src_node = ir_graph.input_node
......@@ -66,31 +44,63 @@ class GraphConverter:
src_node_idx = idx
assert isinstance(ir_predecessor_node, Node)
src_node = ir_predecessor_node
return src_node, src_node_idx
def _add_edge(self, ir_graph, node, graph_inputs, node_index, new_node, output_remap, ignore_first=False):
"""
Parameters
----------
ir_graph : Graph
node : torch._C.Node
graph_inputs : List[torch._C.Value]
a list of a script graph's inputs
node_index : Dict
new_node : Node
newly created ir node corresponding to `node`
output_remap : Dict
ignore_first : bool
if it is true, skip the first input
"""
is_single_input = (len([_input for _input in node.inputs()]) - (1 if ignore_first else 0)) == 1
new_node_input_idx = 0
for _input in node.inputs():
if ignore_first:
ignore_first = False
continue
# handle source node
src_node, src_node_idx = self._add_edge_handle_source_node(_input, graph_inputs, ir_graph, output_remap, node_index)
# handle destination node
dst_node = new_node
if is_single_input:
dst_node_idx = None
else:
dst_node_idx = new_node_input_idx
# create edge
ir_graph.add_edge(head=(src_node, src_node_idx), tail=(dst_node, dst_node_idx))
new_node_input_idx += 1
def create_prim_constant_node(self, ir_graph, node, module_name):
attrs = {}
if node.outputsAt(0).toIValue() is not None:
attrs = {'value': node.outputsAt(0).toIValue()}
# NOTE: compare with the string rather than the type, because the type is defined in pytorch C code.
# `.kind()` can also be used here
if node.outputsAt(0).type().str() == 'None':
attrs = {'type': 'None'}
else:
attrs = {'type': node.outputsAt(0).type().str(), 'value': node.outputsAt(0).toIValue()}
self.global_seq += 1
new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.Constant, self.global_seq),
node.kind(), attrs)
return new_node
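For reference, an editor's sketch (hypothetical values) of the two shapes the `attrs` dict above can take:
attrs_for_none = {'type': 'None'}
attrs_for_int = {'type': 'int', 'value': 4}
assert 'value' not in attrs_for_none and attrs_for_int['value'] == 4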
def handle_prim_attr_node(self, node):
def handle_prim_attr_node(self, node, module):
assert node.hasAttribute('name')
attrs = {'name': node.s('name'), 'input': node.inputsAt(0).debugName()}
value = None
if node.inputsAt(0).debugName() == 'self':
_val = getattr(module, node.s('name'))
# TODO: serialize complex data type, and output proper error message
if isinstance(_val, (int, float, str, bool)):
value = _val
attrs = {'name': node.s('name'), 'input': node.inputsAt(0).debugName(), 'value': value}
return node.kind(), attrs
def _remove_mangle(self, module_type_str):
......@@ -124,7 +134,10 @@ class GraphConverter:
for hidden_node in to_removes:
hidden_node.remove()
def handle_graph_nodes(self, script_module, sm_graph, module, module_name, ir_model, ir_graph):
def handle_graph_nodes(self, script_module, sm_graph,
module, module_name,
ir_model, ir_graph,
shared_module_index=None):
"""
Convert torch script node to our node ir, and build our graph ir
......@@ -142,6 +155,10 @@ class GraphConverter:
the whole graph ir
ir_graph : Graph
the graph ir of ```module```
shared_module_index : dict
used to track the modules for which an ir node has already been created;
if such a module is invoked again, the new ir node can simply reference the existing one.
this is how we identify shared modules (i.e., one module invoked multiple times in the `forward` function)
Returns
-------
......@@ -159,6 +176,8 @@ class GraphConverter:
ir_graph._add_input(_convert_name(_input.debugName()))
node_index = {} # graph node to graph ir node
if shared_module_index is None:
shared_module_index = {}
# some nodes do not have an output but modify a variable in place, for example aten::append
# %17 : Tensor[] = aten::append(%out.1, %16)
......@@ -167,6 +186,7 @@ class GraphConverter:
# key: tensor (%out.1), value: node (this node)
output_remap = {}
# ===================handle control flow: if===================
def handle_if_condition(cond_tensor):
"""
to calculate the condition, we only deal with the following op types by tracing back
......@@ -189,8 +209,45 @@ class GraphConverter:
left = _generate_expr(tensor.node().inputsAt(0))
right = _generate_expr(tensor.node().inputsAt(1))
return f'({left} == {right})'
elif tensor.node().kind() == 'aten::le':
left = _generate_expr(tensor.node().inputsAt(0))
right = _generate_expr(tensor.node().inputsAt(1))
return f'({left} <= {right})'
elif tensor.node().kind() == 'aten::ge':
left = _generate_expr(tensor.node().inputsAt(0))
right = _generate_expr(tensor.node().inputsAt(1))
return f'({left} >= {right})'
elif tensor.node().kind() == 'aten::__not__':
value = _generate_expr(tensor.node().inputsAt(0))
return f'(not {value})'
elif tensor.node().kind() == 'aten::Bool':
value = _generate_expr(tensor.node().inputsAt(0))
return f'bool({value})'
elif tensor.node().kind() == 'aten::__is__':
left = _generate_expr(tensor.node().inputsAt(0))
right = _generate_expr(tensor.node().inputsAt(1))
return f'({left} is {right})'
elif tensor.node().kind() == 'aten::__isnot__':
left = _generate_expr(tensor.node().inputsAt(0))
right = _generate_expr(tensor.node().inputsAt(1))
return f'({left} is not {right})'
elif tensor.node().kind() == 'aten::ne':
left = _generate_expr(tensor.node().inputsAt(0))
right = _generate_expr(tensor.node().inputsAt(1))
return f'({left} != {right})'
elif tensor.node().kind() == 'aten::gt':
left = _generate_expr(tensor.node().inputsAt(0))
right = _generate_expr(tensor.node().inputsAt(1))
return f'({left} > {right})'
elif tensor.node().kind() == 'aten::lt':
left = _generate_expr(tensor.node().inputsAt(0))
right = _generate_expr(tensor.node().inputsAt(1))
return f'({left} < {right})'
elif tensor.node().kind() == 'prim::If':
raise RuntimeError('Have not supported `if A and/or B`, please use two `if` statements instead.')
else:
raise RuntimeError(f'Unsupported op type {tensor.node().kind()} in if condition')
raise RuntimeError(f'Unsupported op type {tensor.node().kind()} in if condition, '
'please decorate the corresponding class with "@blackbox_module".')
expr = _generate_expr(cond_tensor)
return eval(expr)
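An editor's sketch of the expression string `_generate_expr` would build for a hypothetical `aten::eq` condition over two constants, evaluated the same way as above:
left, right = '1', '2'
expr = f'({left} == {right})'
assert eval(expr) is False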
......@@ -217,24 +274,18 @@ class GraphConverter:
last_block_node = None
for node in blocks[chosen_block].nodes():
last_block_node = handle_single_node(node)
self.global_seq += 1
new_node = ir_graph.add_node(build_full_name(module_name, 'noop_identity', self.global_seq), 'noop_identity')
self._add_edge(ir_graph, blocks[chosen_block].returnNode(), graph_inputs, node_index, new_node, output_remap)
last_block_node = new_node
return last_block_node
def handle_single_node(node):
"""
Parameters
----------
node : torch._C.Node
the node from TorchScript graph
Returns
-------
Node
the created node ir
"""
if node.kind() == 'prim::CallMethod':
# ===================handle function call===================
def handle_function_callmethod(node):
# get and handle the first input, which should be an nn.Module
assert node.hasAttribute('name')
if node.s('name') == 'forward':
# NOTE: "forward__0" is hacky, LSTM instance is parsed to call forward__0 in torchscript
if node.s('name') in ['forward', 'forward__0']:
# node.inputsAt(0).type() is <class 'torch._C.ClassType'>
submodule_type_str = self._remove_mangle(node.inputsAt(0).type().str())
submodule = node.inputsAt(0).node()
......@@ -268,7 +319,7 @@ class GraphConverter:
assert predecessor.hasAttribute('name')
assert predecessor.inputsAt(0).debugName() == 'self'
predecessor_name = predecessor.s('name')
# FIXME: exchange
# TODO: exchange submodule_name and predecessor_name
submodule_full_name = build_full_name(module_name, [submodule_name, predecessor_name])
predecessor_obj = getattr(module, predecessor_name)
submodule_obj = getattr(predecessor_obj, submodule_name)
......@@ -277,8 +328,16 @@ class GraphConverter:
else:
raise RuntimeError('Unsupported module case: {}'.format(submodule.inputsAt(0).type().str()))
# TODO: match subgraph with maintained graphs
# build cell
if submodule_full_name in shared_module_index:
# this module is invoked more than once and its ir node has already been created,
# so create a reference node for it.
# example: {"name": "conv2", "operation": {"type": "shared", "parameters": {"reference": "conv1"}}}
self.global_seq += 1
shared_node_name = build_full_name(submodule_full_name, '', self.global_seq)
shared_type_operation = Operation.new('shared', {'reference': submodule_full_name})
subcell = ir_graph.add_node(shared_node_name, shared_type_operation)
else:
# this module is processed for the first time, build cell for it
if subgraph is None:
# if we do not parse this module's graph, we create Node for this module
subcell = ir_graph.add_node(submodule_full_name, submodule_type_str, sub_m_attrs)
......@@ -290,11 +349,67 @@ class GraphConverter:
# Graph already created, create Cell for it
new_cell = Cell(cell_name=submodule_full_name, parameters=sub_m_attrs)
subcell = ir_graph.add_node(submodule_full_name, new_cell)
shared_module_index[submodule_full_name] = subcell
node_index[node] = subcell
# connect the cell into graph
self._add_edge(ir_graph, node, graph_inputs, node_index, subcell, output_remap, ignore_first=True)
else:
raise RuntimeError('unsupported CallMethod {}'.format(node.s('name')))
# handle normal member function
assert hasattr(script_module, node.s('name'))
# TODO: support non member functions
assert node.inputsAt(0).debugName() == 'self'
script_method = getattr(script_module, node.s('name')) # <class 'torch._C.ScriptMethod'>
# step #1: generate graph ir for this method
method_ir_graph = Graph(model=ir_model, graph_id=-100, name='temp_graph', _internal=True)
method_node_index = self.handle_graph_nodes(script_module, script_method.graph, module,
module_name, ir_model, method_ir_graph, shared_module_index)
for _output in script_method.graph.outputs():
method_ir_graph._add_output(_convert_name(_output.debugName()))
predecessor_node_outputs = [o for o in _output.node().outputs()]
if len(predecessor_node_outputs) == 1:
src_node_idx = None
else:
src_node_idx = predecessor_node_outputs.index(_output)
method_ir_graph.add_edge(head=(method_node_index[_output.node()], src_node_idx),
tail=(method_ir_graph.output_node, None))
self.refine_graph(method_ir_graph)
# step #2: merge this graph to its module graph
for h_node in method_ir_graph.hidden_nodes:
h_node.graph = ir_graph
ir_graph.hidden_nodes.append(h_node)
for edge in method_ir_graph.edges:
edge.graph = ir_graph
if edge.head == method_ir_graph.input_node:
# this is a member method, 'self' is the first argument, thus +1
_input = node.inputsAt(edge.head_slot + 1)
src_node, src_node_idx = self._add_edge_handle_source_node(_input, graph_inputs, ir_graph, output_remap, node_index)
edge.head = src_node
edge.head_slot = src_node_idx
if edge.tail == method_ir_graph.output_node:
# since the following nodes have not been created yet, skip this edge;
# edge.head is the node that produces this method's output
# TODO: check whether there could be multiple output nodes
node_index[node] = edge.head
continue
ir_graph.edges.append(edge)
# ===================handle each single node===================
def handle_single_node(node):
"""
Parameters
----------
node : torch._C.Node
the node from TorchScript graph
Returns
-------
Node
the created node ir
"""
if node.kind() == 'prim::CallMethod':
handle_function_callmethod(node)
elif node.kind() == 'prim::CallFunction':
func_type_str = self._remove_mangle(node.inputsAt(0).type().str())
func = node.inputsAt(0).node()
......@@ -310,30 +425,14 @@ class GraphConverter:
elif node.kind() == 'prim::Constant':
new_node = self.create_prim_constant_node(ir_graph, node, module_name)
node_index[node] = new_node
elif node.kind() == 'prim::ListConstruct':
elif node.kind() in ['prim::ListConstruct', 'prim::ListUnpack', 'prim::TupleConstruct', 'prim::TupleUnpack']:
self.global_seq += 1
new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.ListConstruct, self.global_seq), node.kind())
prim_op_name = node.kind().split('::')[-1]
new_node = ir_graph.add_node(build_full_name(module_name, prim_op_name, self.global_seq), node.kind())
node_index[node] = new_node
self._add_edge(ir_graph, node, graph_inputs, node_index, new_node, output_remap)
elif node.kind() == 'prim::TupleConstruct':
self.global_seq += 1
new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.TupleConstruct, self.global_seq), node.kind())
node_index[node] = new_node
self._add_edge(ir_graph, node, graph_inputs, node_index, new_node, output_remap)
elif node.kind() == 'aten::append':
self.global_seq += 1
aten_node = ir_graph.add_node(build_full_name(module_name, BasicOpsPT[node.kind()], self.global_seq), node.kind())
node_index[node] = aten_node
self._add_edge(ir_graph, node, graph_inputs, node_index, aten_node, output_remap)
output_remap[node.inputsAt(0)] = node
elif node.kind().startswith('aten::'):
# handle aten::XXX
self.global_seq += 1
aten_node = ir_graph.add_node(build_full_name(module_name, BasicOpsPT[node.kind()], self.global_seq), node.kind())
node_index[node] = aten_node
self._add_edge(ir_graph, node, graph_inputs, node_index, aten_node, output_remap)
elif node.kind() == 'prim::GetAttr':
node_type, attrs = self.handle_prim_attr_node(node)
node_type, attrs = self.handle_prim_attr_node(node, module)
self.global_seq += 1
new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.Attr, self.global_seq),
node_type, attrs)
......@@ -345,6 +444,26 @@ class GraphConverter:
elif node.kind() == 'prim::Loop':
# refer to https://gist.github.com/liuzhe-lz/90c35d9dd6fd7f3f32544940151ab186
raise RuntimeError('Loop has not been supported yet!')
elif node.kind().startswith('prim::'):
self.global_seq += 1
prim_op_name = node.kind().replace('::', '__')
prim_node = ir_graph.add_node(build_full_name(module_name, prim_op_name, self.global_seq), node.kind())
node_index[node] = prim_node
self._add_edge(ir_graph, node, graph_inputs, node_index, prim_node, output_remap)
elif node.kind() == 'aten::append':
self.global_seq += 1
aten_op_name = node.kind().replace('::', '__')
aten_node = ir_graph.add_node(build_full_name(module_name, aten_op_name, self.global_seq), node.kind())
node_index[node] = aten_node
self._add_edge(ir_graph, node, graph_inputs, node_index, aten_node, output_remap)
output_remap[node.inputsAt(0)] = node
elif node.kind().startswith('aten::'):
# handle aten::XXX
self.global_seq += 1
aten_op_name = node.kind().replace('::', '__')
aten_node = ir_graph.add_node(build_full_name(module_name, aten_op_name, self.global_seq), node.kind())
node_index[node] = aten_node
self._add_edge(ir_graph, node, graph_inputs, node_index, aten_node, output_remap)
else:
raise RuntimeError('Unsupported kind: {}'.format(node.kind()))
......@@ -378,6 +497,11 @@ class GraphConverter:
new_slice_node = ir_graph.add_node(build_full_name(head_node.name, 'merged'), OpTypeName.MergedSlice)
if len(head_node.incoming_edges) == 4:
# when slicing a one-dimensional list, there are only 4 inputs, so no merge is needed
for edge in head_node.incoming_edges:
edge.tail = new_slice_node
for edge in head_node.outgoing_edges:
edge.head = new_slice_node
ir_graph.hidden_nodes.remove(head_node)
break
assert len(head_node.incoming_edges) == 5
for edge in head_node.incoming_edges:
......@@ -478,7 +602,7 @@ class GraphConverter:
m_attrs = self._handle_valuechoice(module)
elif original_type_name == OpTypeName.Placeholder:
m_attrs = self.modules_arg[id(module)]
elif original_type_name in torch.nn.__dict__:
elif module.__class__.__module__.startswith('torch.nn') and original_type_name in torch.nn.__dict__:
# this is a basic module from pytorch, no need to parse its graph
assert id(module) in self.modules_arg, f'{original_type_name} arguments are not recorded'
m_attrs = self.modules_arg[id(module)]
......
......@@ -9,34 +9,8 @@ class OpTypeName(str, Enum):
"""
Attr = 'Attr'
Constant = 'Constant'
ListConstruct = 'ListConstruct'
TupleConstruct = 'TupleConstruct'
LayerChoice = 'LayerChoice'
InputChoice = 'InputChoice'
ValueChoice = 'ValueChoice'
Placeholder = 'Placeholder'
MergedSlice = 'MergedSlice'
# deal with aten op
BasicOpsPT = {
'aten::mean': 'Mean',
'aten::relu': 'Relu',
'aten::add': 'Add',
'aten::__getitem__': 'getitem',
'aten::append': 'Append',
'aten::len': 'Len',
'aten::slice': 'Slice',
'aten::cat': 'Cat',
'aten::size': 'Size',
'aten::view': 'View',
'aten::reshape': 'Reshape',
'aten::eq': 'Eq',
'aten::Bool': 'Bool',
'aten::empty': 'Empty',
'aten::zeros': 'Zeros',
'aten::chunk': 'Chunk',
'aten::add_': 'Add_' # %out.3 : Tensor = aten::add_(%out.1, %connection.1, %4)
}
BasicOpsTF = {}
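With the explicit `BasicOpsPT` table removed, aten/prim op names in the converter are now derived mechanically from `node.kind()`; an editor's sketch of that derivation with a hypothetical kind:
kind = 'aten::view'
assert kind.replace('::', '__') == 'aten__view'                  # used when building node names
assert 'prim::ListConstruct'.split('::')[-1] == 'ListConstruct'  # used for prim ops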
......@@ -45,7 +45,7 @@ class ValueChoiceMutator(Mutator):
chosen = self.choice(self.candidates)
for node in self.nodes:
target = model.get_node_by_name(node.name)
target.update_operation('prim::Constant', {'value': chosen})
target.update_operation('prim::Constant', {'type': type(chosen).__name__, 'value': chosen})
def process_inline_mutation(model: Model) -> Optional[List[Mutator]]:
......
......@@ -83,6 +83,31 @@ class Operation:
class PyTorchOperation(Operation):
@classmethod
def _find_subclass(cls, subclass_name):
if cls.to_class_name(subclass_name) is not None:
subclass_name = 'ModuleOperator'
if cls.is_functional(subclass_name):
subclass_name = 'FunctionalOperator'
for subclass in cls.__subclasses__():
if hasattr(subclass, '_ori_type_name') and \
subclass_name in subclass._ori_type_name:
return subclass
return cls
@classmethod
def to_class_name(cls, type_name) -> str:
if type_name.startswith('__torch__.'):
return type_name[len('__torch__.'):]
elif type_name.startswith('__mutated__.'):
return type_name[len('__mutated__.'):]
else:
return None
@classmethod
def is_functional(cls, type_name) -> bool:
return type_name.startswith('Function.')
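A hedged sanity check of the helpers above (editor's sketch, intended to run after the class is defined, with hypothetical type names):
assert PyTorchOperation.to_class_name('__torch__.mymodel.MyConv') == 'mymodel.MyConv'
assert PyTorchOperation.is_functional('Function.relu') is True
assert PyTorchOperation.to_class_name('aten::cat') is None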
def _to_class_name(self) -> str:
if self.type.startswith('__torch__.'):
return self.type[len('__torch__.'):]
......@@ -106,59 +131,27 @@ class PyTorchOperation(Operation):
return f'self.{field} = {self._to_class_name()}({kw_params})'
return None
def to_forward_code(self, field: str, output: str, inputs: List[str]) -> str:
from .converter.op_types import OpTypeName
if self._to_class_name() is not None:
return f'{output} = self.{field}({", ".join(inputs)})'
elif self.type.startswith('Function.'):
func_name = self.type[len('Function.'):]
return f'{output} = F.{func_name}({", ".join(inputs)})'
elif self.type == 'prim::Constant':
if self.parameters:
value = self.parameters['value']
else:
value = None
return f'{output} = {value}'
elif self.type == 'prim::ListConstruct':
return f'{output} = [{", ".join(inputs)}]'
elif self.type == 'prim::TupleConstruct':
return f'{output} = ({", ".join(inputs)})'
elif self.type == 'prim::GetAttr':
return f"{output} = {self.parameters['input']}.{self.parameters['name']}"
elif self.type == 'aten::mean':
return f'{output} = torch.mean({inputs[0]}, {", ".join(inputs[1:-1])}, out={inputs[-1]})'
elif self.type == 'aten::__getitem__':
assert len(inputs) == 2
return f'{output} = {inputs[0]}[{inputs[1]}]'
elif self.type == 'aten::append':
assert len(inputs) == 2
return f'_, {output} = {inputs[0]}.append({inputs[1]}), {inputs[0]}'
elif self.type == 'aten::cat':
assert len(inputs) == 2
return f'{output} = torch.cat({inputs[0]}, dim={inputs[1]})'
elif self.type == 'aten::add':
return f'{output} = ' + ' + '.join(inputs)
elif self.type == OpTypeName.MergedSlice:
assert (len(inputs) - 1) % 4 == 0
slices = []
dim = int((len(inputs) - 1) / 4)
for i in range(dim):
slices.append(f'{inputs[i*4+2]}:{inputs[i*4+3]}:{inputs[i*4+4]}')
slice_str = ','.join(slices)
return f'{output} = {inputs[0]}[{slice_str}]'
elif self.type == 'aten::size':
assert len(inputs) == 2
return f'{output} = {inputs[0]}.size({inputs[1]})'
elif self.type == 'aten::view':
assert len(inputs) == 2
return f'{output} = {inputs[0]}.view({inputs[1]})'
elif self.type == 'aten::reshape':
assert len(inputs) == 2
return f'{output} = {inputs[0]}.reshape({inputs[1]})'
elif self.type == 'aten::slice':
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
"""
Parameters
----------
field : str
the name of member submodule
output : str
the output name (lvalue) of this line of code
inputs : List[str]
variables used in this line of code
inputs_value : List[Any]
some variables are actually constants; their real values are recorded in ```inputs_value```.
if an input is not constant, None is placed at the corresponding index
Returns
-------
str
generated code line
"""
if self.type == 'aten::slice':
raise RuntimeError('not supposed to have aten::slice operation')
elif self.type == 'aten::Bool':
return f'{output} = bool({inputs[0]})'
else:
raise RuntimeError(f'unsupported operation type: {self.type} ? {self._to_class_name()}')
......@@ -212,6 +205,8 @@ class Cell(PyTorchOperation):
# TODO: ugly, think about how to refactor this part
return _convert_name(self.cell_name)
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
return f'{output} = self.{field}({", ".join(inputs)})'
class _IOPseudoOperation(Operation):
"""
......
from typing import (Any, List)
import torch
from ..operation import PyTorchOperation
class relu(PyTorchOperation):
def to_init_code(self, field):
return ''
mem_format = [
'torch.contiguous_format', # 0
'torch.preserve_format', # 1
'torch.channels_last', # 2
]
def to_forward_code(self, field, output, *inputs) -> str:
assert len(inputs) == 1
return f'{output} = nn.functional.relu({inputs[0]})'
# this snippet is copied from torch/onnx/symbolic_helper.py,
# the original definition is in c10/core/ScalarType.h
# This indicates each scalar type's corresponding
scalar_type_to_pytorch_type = [
'torch.uint8', # 0
'torch.int8', # 1
'torch.short', # 2
'torch.int', # 3
'torch.int64', # 4
'torch.half', # 5
'torch.float', # 6
'torch.double', # 7
'torch.complex32', # 8
'torch.complex64', # 9
'torch.complex128', # 10
'torch.bool', # 11
]
class NoOpIdentity(PyTorchOperation):
"""
this operator type is added by us
"""
_ori_type_name = ['noop_identity']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
return f'{output} = {", ".join(inputs)}'
class Flatten(PyTorchOperation):
def to_init_code(self, field):
return ''
class ModuleOperator(PyTorchOperation):
_ori_type_name = ['ModuleOperator', 'shared']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
return f'{output} = self.{field}({", ".join(inputs)})'
def to_forward_code(self, field, output, *inputs) -> str:
assert len(inputs) == 1
return f'{output} = {inputs[0]}.view({inputs[0]}.size(0), -1)'
class FunctionalOperator(PyTorchOperation):
_ori_type_name = ['FunctionalOperator']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
func_name = self.type[len('Function.'):]
if not hasattr(torch.nn.functional, func_name):
raise RuntimeError('For now, we only support calling independent functions from `torch.nn.functional`, '
f'{func_name} is not in it.')
return f'{output} = F.{func_name}({", ".join(inputs)})'
class PrimConstant(PyTorchOperation):
_ori_type_name = ['prim::Constant']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
# TODO: refactor this part, maybe we can remove the code gen of prim::Constant
# TODO: deal with all the types
if self.parameters['type'] == 'None':
return f'{output} = None'
elif self.parameters['type'] in ('int', 'float', 'bool', 'int[]'):
return f'{output} = {self.parameters["value"]}'
elif self.parameters['type'] == 'str':
str_val = self.parameters["value"]
return f'{output} = "{str_val}"'
elif self.parameters['type'] == 'Device':
value = self.parameters['value']
return f'{output} = torch.device("{value}")'
else:
raise RuntimeError(f'unsupported type of prim::Constant: {self.parameters["type"]}')
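An editor's illustration (hypothetical values) of the lines this emits for a few constant types:
#   {'type': 'int',    'value': 1}          ->  out1 = 1
#   {'type': 'str',    'value': 'bilinear'} ->  out2 = "bilinear"
#   {'type': 'Device', 'value': 'cpu'}      ->  out3 = torch.device("cpu")
params = {'type': 'str', 'value': 'bilinear'}
assert f'out2 = "{params["value"]}"' == 'out2 = "bilinear"'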
class PrimListConstruct(PyTorchOperation):
_ori_type_name = ['prim::ListConstruct']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
return f'{output} = [{", ".join(inputs)}]'
class PrimListUnpack(PyTorchOperation):
_ori_type_name = ['prim::ListUnpack']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
return f'{output} = {inputs[0]}'
class ToDevice(PyTorchOperation):
def to_init_code(self, field):
return ''
class PrimTupleConstruct(PyTorchOperation):
_ori_type_name = ['prim::TupleConstruct']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
return f'{output} = ({", ".join(inputs)})'
def to_forward_code(self, field, output, inputs) -> str:
class PrimTupleUnpack(PyTorchOperation):
_ori_type_name = ['prim::TupleUnpack']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
# there is a single output here, because the following code uses an index to access the unpacked values
assert len(inputs) == 1
return f"{output} = {inputs[0]}.to('{self.parameters['device']}')"
return f'{output} = {inputs[0]}'
class PrimGetAttr(PyTorchOperation):
_ori_type_name = ['prim::GetAttr']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
if self.parameters['value'] is not None:
return f"{output} = {self.parameters['value']}"
else:
return f"{output} = {self.parameters['input']}.{self.parameters['name']}"
class Dense(PyTorchOperation):
def to_init_code(self, field):
return f"self.{field} = nn.Linear({self.parameters['in_features']}, {self.parameters['out_features']})"
class SimpleMember(PyTorchOperation):
_ori_type_name = ['prim::is_cuda', 'prim::data']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
member_name = self.type.split('::')[-1]
return f'{output} = {inputs[0]}.{member_name}'
def to_forward_code(self, field, output, *inputs) -> str:
assert len(inputs) == 1
return f'{output} = self.{field}({inputs[0]})'
class AtenContiguous(PyTorchOperation):
_ori_type_name = ['aten::contiguous']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
# defined in pytorch/c10/core/MemoryFormat.h
assert inputs_value[1] in [0, 1, 2]
return f'{output} = {inputs[0]}.contiguous(memory_format={mem_format[inputs_value[1]]})'
class AtenGetitem(PyTorchOperation):
_ori_type_name = ['aten::__getitem__']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
assert len(inputs) == 2
return f'{output} = {inputs[0]}[{inputs[1]}]'
class Softmax(PyTorchOperation):
def to_init_code(self, field):
return ''
class AtenAppend(PyTorchOperation):
_ori_type_name = ['aten::append']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
assert len(inputs) == 2
return f'_, {output} = {inputs[0]}.append({inputs[1]}), {inputs[0]}'
def to_forward_code(self, field, output, *inputs) -> str:
assert len(inputs) == 1
return f'{output} = F.softmax({inputs[0]}, -1)'
class MergedSlice(PyTorchOperation):
_ori_type_name = ['MergedSlice']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
if (len(inputs) - 1) % 4 == 0:
slices = []
dim = int((len(inputs) - 1) / 4)
for i in range(dim):
slices.append(f'{inputs[i*4+2]}:{inputs[i*4+3]}:{inputs[i*4+4]}')
slice_str = ','.join(slices)
return f'{output} = {inputs[0]}[{slice_str}]'
elif len(inputs) == 4:
# this case is for simple list
return f'{output} = {inputs[0]}[{inputs[1]}:{inputs[2]}:{inputs[3]}]'
else:
raise RuntimeError('Unsupported slice pattern')
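An editor's sketch (hypothetical inputs) of how nine inputs encode a two-dimensional slice under the first branch above:
inputs = ['x', '0', '0', '10', '1', '1', '2', '8', '1']
dim = (len(inputs) - 1) // 4
slices = [f'{inputs[i*4+2]}:{inputs[i*4+3]}:{inputs[i*4+4]}' for i in range(dim)]
assert f"out = x[{','.join(slices)}]" == 'out = x[0:10:1,2:8:1]'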
# the following Aten* classes handle aten ops that are not methods of torch.Tensor
class AtenBool(PyTorchOperation):
_ori_type_name = ['aten::Bool']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
return f'{output} = bool({inputs[0]})'
class AtenNot(PyTorchOperation):
_ori_type_name = ['aten::__not__']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
return f'{output} = not {inputs[0]}'
class AtenCat(PyTorchOperation):
_ori_type_name = ['aten::cat']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
assert len(inputs) == 2
return f'{output} = torch.cat({inputs[0]}, dim={inputs[1]})'
#====================================
class AtenTensors(PyTorchOperation):
_ori_type_name = ['aten::full', 'aten::full_like', 'aten::empty_like',
'aten::ones_like', 'aten::zeros_like', 'aten::rand',
'aten::randn', 'aten::scalar_tensor', 'aten::new_full',
'aten::new_empty', 'aten::new_zeros', 'aten::arange',
'aten::tensor', 'aten::ones', 'aten::zeros']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
schemas = torch._C._jit_get_schemas_for_operator(self.type)
# match number of inputs
overloaded_defs = [len(s.arguments) for s in schemas]
matched = overloaded_defs.index(len(inputs))
args_list = []
for idx, arg in enumerate(schemas[matched].arguments):
if arg.name == 'dtype':
arg_str = f'dtype={scalar_type_to_pytorch_type[inputs_value[idx]]}' if inputs_value[idx] is not None else ''
elif arg.name == 'layout':
if inputs_value[idx] is not None:
arg_str = f'layout=torch.strided'
print('Warning: only support `torch.strided` for now!!!')
else:
arg_str = ''
elif arg.name == 'device':
arg_str = f'device=torch.device({inputs[idx]})' if inputs_value[idx] is not None else ''
elif arg.name == 'memory_format':
arg_str = f'memory_format={mem_format[inputs_value[idx]]}' if inputs_value[idx] is not None else ''
elif arg.name == 'pin_memory':
# TODO: deal with this argument
continue
elif arg.name == 'requires_grad':
arg_str = f'requires_grad={inputs[idx]}' if inputs_value[idx] else ''
elif str(arg.type).startswith('Optional['):
arg_str = f'{arg.name}={inputs[idx]}'
else:
arg_str = f'{inputs[idx]}'
if arg_str != '':
args_list.append(arg_str)
op_name = self.type.split('::')[-1]
if hasattr(torch, op_name):
return f'{output} = torch.{op_name}({", ".join(args_list)})'
else:
return f'{output} = {inputs[0]}.{op_name}({", ".join(args_list[1:])})'
#====================================
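An editor's sketch (hypothetical dtype index) of how a recorded `dtype` value is turned into an argument string via the `scalar_type_to_pytorch_type` table defined earlier in this file:
dtype_idx = 6                                   # as recorded in inputs_value
arg_str = f'dtype={scalar_type_to_pytorch_type[dtype_idx]}'
assert arg_str == 'dtype=torch.float'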
class AtenFloordiv(PyTorchOperation):
_ori_type_name = ['aten::floordiv']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
return f'{output} = {inputs[0]} // {inputs[1]}'
class AtenLen(PyTorchOperation):
_ori_type_name = ['aten::len']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
return f'{output} = len({inputs[0]})'
class AtenIntImplicit(PyTorchOperation):
_ori_type_name = ['aten::IntImplicit', 'aten::Float', 'aten::Int', 'aten::ScalarImplicit']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
if self.type.endswith('Implicit'):
return f'{output} = {inputs[0]}'
elif self.type == 'aten::Int':
return f'{output} = int({inputs[0]})'
elif self.type == 'aten::Float':
return f'{output} = float({inputs[0]})'
class AtenIndex(PyTorchOperation):
_ori_type_name = ['aten::index']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
return f'{output} = {inputs[0]}[{inputs[1]}]'
ManuallyChooseDef = {
'aten::flatten': [('start_dim', 'int', '0'), ('end_dim', 'int', '-1')],
'aten::split': [('split_size', 'int', 'None'), ('dim', 'int', '0')]
}
TensorOpExceptions = {
'aten::sub': lambda output, inputs: f'{output} = {inputs[0]} - {inputs[1]}', # example: x.size(1) - 3
'aten::add': lambda output, inputs: f'{output} = {inputs[0]} + {inputs[1]}' # example: input.shape[0] + 5
}
TorchOpExclude = ['aten::Size', 'aten::as_tensor', 'aten::device',
'aten::manual_seed', 'aten::quantized_gru', 'aten::quantized_lstm',
'aten::save', 'aten::tensor', 'aten::wait'
]
def _hidden(name):
return name.startswith('_') and not name.startswith('__')
def _emit_args(args):
# filter out the `out` argument here
return [(arg.name, str(arg.type), str(arg.default_value)) for arg in args] # if arg.name != 'out'
def _get_tensor_ops():
def is_tensor_method(schema):
if len(schema.arguments) == 0:
return False
self = schema.arguments[0]
if self.name != 'self':
return False
if not self.type.isSubtypeOf(torch._C.TensorType.get()):
return False
return True
op_args = {}
# discover methods
for elem in dir(torch.Tensor):
if not _hidden(elem):
schemas = torch._C._jit_get_schemas_for_operator("aten::" + elem)
for schema in schemas:
if is_tensor_method(schema):
op_name = 'aten::' + elem
args = _emit_args(schema.arguments[1:])
if op_name in op_args:
op_args[op_name].append(args)
else:
op_args[op_name] = [args]
return op_args.keys(), op_args
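An editor's usage sketch of the discovery helper above (requires a working torch installation; `aten::view` is assumed to be among the discovered Tensor methods):
op_names, op_args = _get_tensor_ops()
assert 'aten::view' in op_names
assert all(name.startswith('aten::') for name in op_names)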
def _get_torch_ops():
torch_op_args = {}
for mod in torch.jit._builtins._modules_containing_builtins:
name = mod.__name__
if name == 'torch._C._nn':
continue
# only process 'torch.XXX'
for elem in dir(mod):
builtin = torch.jit._builtins._find_builtin(getattr(mod, elem))
if builtin is not None:
schemas = torch._C._jit_get_schemas_for_operator(builtin)
for schema in schemas:
# remove _tan but not __and__
if not _hidden(elem):
op_name = 'aten::' + elem
if len(schema.arguments) > 0 and schema.arguments[0].name == 'self':
continue
args = _emit_args(schema.arguments)
if op_name in torch_op_args:
torch_op_args[op_name].append(args)
else:
torch_op_args[op_name] = [args]
return torch_op_args.keys(), torch_op_args
def _get_torch_ops_exclude_tensor_ops():
tensor_op_names, _ = _get_tensor_ops()
torch_op_names, torch_ops = _get_torch_ops()
torch_exclude_ops = {}
for name in torch_op_names:
if name not in tensor_op_names:
if name not in TorchOpExclude:
# exclude the ops that are not in
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
torch_exclude_ops[name] = torch_ops[name]
return torch_exclude_ops.keys(), torch_exclude_ops
class TensorOps(PyTorchOperation):
"""
corresponding to _get_tensor_ops in torch.jit.supported_ops
"""
_ori_type_name, _op_args = _get_tensor_ops()
comparison_ops = {'aten::eq': '==', 'aten::ne': '!=', 'aten::le': '<=', 'aten::ge': '>=', 'aten::lt': '<', 'aten::gt': '>'}
@staticmethod
def _get_matched_args(_type, inputs):
def has_same_arg_name(matched):
concated_names = []
for i, each in enumerate(matched):
name = ','.join([arg[0] for arg in each])
concated_names.append(name)
for i in range(len(concated_names) - 1):
if concated_names[i] != concated_names[i+1]:
return False
return True
overloaded_defs = TensorOps._op_args[_type]
matched = []
for each in overloaded_defs:
# plus 1 because we skip the first argument when generating tensor op def
if len(each) + 1 == len(inputs):
matched.append(each)
if len(matched) == 1:
return matched[0]
elif len(matched) > 1:
# TODO: match with arg's type. manually choose for now
if has_same_arg_name(matched):
# return any one is okay
return matched[0]
elif _type in ManuallyChooseDef:
return ManuallyChooseDef[_type]
else:
raise RuntimeError(f'tensor op type {_type} has more than one matched: {matched}')
else:
if _type in TensorOpExceptions:
return None
raise RuntimeError(f'tensor op type {_type} has no matched')
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
# TODO: deal with conditional ops
if self.type in TensorOps.comparison_ops:
return f'{output} = ({inputs[0]} {TensorOps.comparison_ops[self.type]} {inputs[1]})'
matched_args = TensorOps._get_matched_args(self.type, inputs)
if matched_args is None:
return TensorOpExceptions[self.type](output, inputs)
op_name = self.type.split('::')[-1]
args_str = ', '.join([f'{name}={inputs[i+1]}' for i, (name, t, default) in enumerate(matched_args)])
return f'{output} = {inputs[0]}.{op_name}({args_str})'
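An editor's illustration (hypothetical names) of the inline code emitted for a comparison op via the `comparison_ops` table above:
op_type, output, inputs = 'aten::eq', 'out1', ['a', 'b']
line = f'{output} = ({inputs[0]} {TensorOps.comparison_ops[op_type]} {inputs[1]})'
assert line == 'out1 = (a == b)'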
class TorchOps(PyTorchOperation):
"""
corresponding to _get_nn_functional_ops in torch.jit.supported_ops
"""
_ori_type_name, _op_args = _get_torch_ops_exclude_tensor_ops()
# add 'aten::pixel_shuffle'
_op_args['aten::pixel_shuffle'] = [[('input', 'Tensor', 'None'), ('upscale_factor', 'Optional[int]', 'None')]]
_ori_type_name = _op_args.keys()
@staticmethod
def _get_matched_args(_type, inputs):
def has_same_arg_name(matched):
concated_names = []
for i, each in enumerate(matched):
name = ','.join([arg[0] for arg in each])
concated_names.append(name)
for i in range(len(concated_names) - 1):
if concated_names[i] != concated_names[i+1]:
return False
return True
overloaded_defs = TorchOps._op_args[_type]
matched = []
for each in overloaded_defs:
if len(each) == len(inputs):
matched.append(each)
if len(matched) == 1:
return matched[0]
elif len(matched) > 1:
# TODO: match with arg's type. manually choose for now
if has_same_arg_name(matched):
# return any one is okay
return matched[0]
else:
raise RuntimeError(f'torch op type {_type} has more than one matched: {matched}')
else:
raise RuntimeError(f'torch op type {_type} has no matched')
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
matched_args = TorchOps._get_matched_args(self.type, inputs)
op_name = self.type.split('::')[-1]
args_str = ', '.join([f'{name}={inputs[i]}' if t.startswith('Optional[') else f'{inputs[i]}' \
for i, (name, t, default) in enumerate(matched_args)])
return f'{output} = torch.{op_name}({args_str})'
class AtenAvgpool2d(PyTorchOperation):
# NOTE: it is not included in the above aten ops for an unknown reason
_ori_type_name = ['aten::avg_pool2d']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
return f'{output} = F.avg_pool2d({", ".join(inputs)})'
\ No newline at end of file
......@@ -162,6 +162,14 @@ def _get_module_name(cls):
f'please launch the experiment under the directory where "{main_file_path.name}" is located.')
module_name = main_file_path.stem
break
# NOTE: this is hacky. torchscript retrieves LSTM's source code during parsing,
# so to make sure the source code can be found, we assign the original LSTM's __module__
# to the wrapped LSTM's __module__
# TODO: find out all the modules that have the same requirement as LSTM
if f'{cls.__module__}.{cls.__name__}' == 'torch.nn.modules.rnn.LSTM':
module_name = cls.__module__
return module_name
......
......@@ -250,7 +250,9 @@ stages:
- script: |
cd test
python -m pytest ut
python -m pytest ut --ignore=ut/retiarii/test_convert_basic.py \
--ignore=ut/retiarii/test_convert_operators.py \
--ignore=ut/retiarii/test_convert_pytorch.py
displayName: Python unit test
- script: |
......
import inspect
import logging
import torch
import torch.nn as nn
from nni.retiarii.utils import add_record, del_record, version_larger_equal
_logger = logging.getLogger(__name__)
def wrap_module(original_class):
orig_init = original_class.__init__
argname_list = list(inspect.signature(original_class).parameters.keys())
# Make copy of original __init__, so we can call it without recursion
original_class.bak_init_for_inject = orig_init
if hasattr(original_class, '__del__'):
orig_del = original_class.__del__
original_class.bak_del_for_inject = orig_del
else:
orig_del = None
original_class.bak_del_for_inject = None
def __init__(self, *args, **kws):
full_args = {}
full_args.update(kws)
for i, arg in enumerate(args):
full_args[argname_list[i]] = arg
add_record(id(self), full_args)
orig_init(self, *args, **kws) # Call the original __init__
def __del__(self):
del_record(id(self))
if orig_del is not None:
orig_del(self)
original_class.__init__ = __init__ # Set the class' __init__ to the new one
original_class.__del__ = __del__
return original_class
def unwrap_module(wrapped_class):
if hasattr(wrapped_class, 'bak_init_for_inject'):
wrapped_class.__init__ = wrapped_class.bak_init_for_inject
delattr(wrapped_class, 'bak_init_for_inject')
if hasattr(wrapped_class, 'bak_del_for_inject'):
if wrapped_class.bak_del_for_inject is not None:
wrapped_class.__del__ = wrapped_class.bak_del_for_inject
delattr(wrapped_class, 'bak_del_for_inject')
return None
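An editor's usage sketch of the wrapping pair above, packaged as a hypothetical helper so it does not run at import time:
def _demo_wrap_records():
    # wrap nn.Linear, record its init arguments on construction, then restore it
    wrap_module(nn.Linear)
    layer = nn.Linear(4, 8)     # add_record(id(layer), {'in_features': 4, 'out_features': 8})
    unwrap_module(nn.Linear)
    return layer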
def remove_inject_pytorch_nn():
Identity = unwrap_module(nn.Identity)
Linear = unwrap_module(nn.Linear)
Conv1d = unwrap_module(nn.Conv1d)
Conv2d = unwrap_module(nn.Conv2d)
Conv3d = unwrap_module(nn.Conv3d)
ConvTranspose1d = unwrap_module(nn.ConvTranspose1d)
ConvTranspose2d = unwrap_module(nn.ConvTranspose2d)
ConvTranspose3d = unwrap_module(nn.ConvTranspose3d)
Threshold = unwrap_module(nn.Threshold)
ReLU = unwrap_module(nn.ReLU)
Hardtanh = unwrap_module(nn.Hardtanh)
ReLU6 = unwrap_module(nn.ReLU6)
Sigmoid = unwrap_module(nn.Sigmoid)
Tanh = unwrap_module(nn.Tanh)
Softmax = unwrap_module(nn.Softmax)
Softmax2d = unwrap_module(nn.Softmax2d)
LogSoftmax = unwrap_module(nn.LogSoftmax)
ELU = unwrap_module(nn.ELU)
SELU = unwrap_module(nn.SELU)
CELU = unwrap_module(nn.CELU)
GLU = unwrap_module(nn.GLU)
GELU = unwrap_module(nn.GELU)
Hardshrink = unwrap_module(nn.Hardshrink)
LeakyReLU = unwrap_module(nn.LeakyReLU)
LogSigmoid = unwrap_module(nn.LogSigmoid)
Softplus = unwrap_module(nn.Softplus)
Softshrink = unwrap_module(nn.Softshrink)
MultiheadAttention = unwrap_module(nn.MultiheadAttention)
PReLU = unwrap_module(nn.PReLU)
Softsign = unwrap_module(nn.Softsign)
Softmin = unwrap_module(nn.Softmin)
Tanhshrink = unwrap_module(nn.Tanhshrink)
RReLU = unwrap_module(nn.RReLU)
AvgPool1d = unwrap_module(nn.AvgPool1d)
AvgPool2d = unwrap_module(nn.AvgPool2d)
AvgPool3d = unwrap_module(nn.AvgPool3d)
MaxPool1d = unwrap_module(nn.MaxPool1d)
MaxPool2d = unwrap_module(nn.MaxPool2d)
MaxPool3d = unwrap_module(nn.MaxPool3d)
MaxUnpool1d = unwrap_module(nn.MaxUnpool1d)
MaxUnpool2d = unwrap_module(nn.MaxUnpool2d)
MaxUnpool3d = unwrap_module(nn.MaxUnpool3d)
FractionalMaxPool2d = unwrap_module(nn.FractionalMaxPool2d)
FractionalMaxPool3d = unwrap_module(nn.FractionalMaxPool3d)
LPPool1d = unwrap_module(nn.LPPool1d)
LPPool2d = unwrap_module(nn.LPPool2d)
LocalResponseNorm = unwrap_module(nn.LocalResponseNorm)
BatchNorm1d = unwrap_module(nn.BatchNorm1d)
BatchNorm2d = unwrap_module(nn.BatchNorm2d)
BatchNorm3d = unwrap_module(nn.BatchNorm3d)
InstanceNorm1d = unwrap_module(nn.InstanceNorm1d)
InstanceNorm2d = unwrap_module(nn.InstanceNorm2d)
InstanceNorm3d = unwrap_module(nn.InstanceNorm3d)
LayerNorm = unwrap_module(nn.LayerNorm)
GroupNorm = unwrap_module(nn.GroupNorm)
SyncBatchNorm = unwrap_module(nn.SyncBatchNorm)
Dropout = unwrap_module(nn.Dropout)
Dropout2d = unwrap_module(nn.Dropout2d)
Dropout3d = unwrap_module(nn.Dropout3d)
AlphaDropout = unwrap_module(nn.AlphaDropout)
FeatureAlphaDropout = unwrap_module(nn.FeatureAlphaDropout)
ReflectionPad1d = unwrap_module(nn.ReflectionPad1d)
ReflectionPad2d = unwrap_module(nn.ReflectionPad2d)
ReplicationPad2d = unwrap_module(nn.ReplicationPad2d)
ReplicationPad1d = unwrap_module(nn.ReplicationPad1d)
ReplicationPad3d = unwrap_module(nn.ReplicationPad3d)
CrossMapLRN2d = unwrap_module(nn.CrossMapLRN2d)
Embedding = unwrap_module(nn.Embedding)
EmbeddingBag = unwrap_module(nn.EmbeddingBag)
RNNBase = unwrap_module(nn.RNNBase)
RNN = unwrap_module(nn.RNN)
LSTM = unwrap_module(nn.LSTM)
GRU = unwrap_module(nn.GRU)
RNNCellBase = unwrap_module(nn.RNNCellBase)
RNNCell = unwrap_module(nn.RNNCell)
LSTMCell = unwrap_module(nn.LSTMCell)
GRUCell = unwrap_module(nn.GRUCell)
PixelShuffle = unwrap_module(nn.PixelShuffle)
Upsample = unwrap_module(nn.Upsample)
UpsamplingNearest2d = unwrap_module(nn.UpsamplingNearest2d)
UpsamplingBilinear2d = unwrap_module(nn.UpsamplingBilinear2d)
PairwiseDistance = unwrap_module(nn.PairwiseDistance)
AdaptiveMaxPool1d = unwrap_module(nn.AdaptiveMaxPool1d)
AdaptiveMaxPool2d = unwrap_module(nn.AdaptiveMaxPool2d)
AdaptiveMaxPool3d = unwrap_module(nn.AdaptiveMaxPool3d)
AdaptiveAvgPool1d = unwrap_module(nn.AdaptiveAvgPool1d)
AdaptiveAvgPool2d = unwrap_module(nn.AdaptiveAvgPool2d)
AdaptiveAvgPool3d = unwrap_module(nn.AdaptiveAvgPool3d)
TripletMarginLoss = unwrap_module(nn.TripletMarginLoss)
ZeroPad2d = unwrap_module(nn.ZeroPad2d)
ConstantPad1d = unwrap_module(nn.ConstantPad1d)
ConstantPad2d = unwrap_module(nn.ConstantPad2d)
ConstantPad3d = unwrap_module(nn.ConstantPad3d)
Bilinear = unwrap_module(nn.Bilinear)
CosineSimilarity = unwrap_module(nn.CosineSimilarity)
Unfold = unwrap_module(nn.Unfold)
Fold = unwrap_module(nn.Fold)
AdaptiveLogSoftmaxWithLoss = unwrap_module(nn.AdaptiveLogSoftmaxWithLoss)
TransformerEncoder = unwrap_module(nn.TransformerEncoder)
TransformerDecoder = unwrap_module(nn.TransformerDecoder)
TransformerEncoderLayer = unwrap_module(nn.TransformerEncoderLayer)
TransformerDecoderLayer = unwrap_module(nn.TransformerDecoderLayer)
Transformer = unwrap_module(nn.Transformer)
Flatten = unwrap_module(nn.Flatten)
Hardsigmoid = unwrap_module(nn.Hardsigmoid)
if version_larger_equal(torch.__version__, '1.6.0'):
Hardswish = unwrap_module(nn.Hardswish)
if version_larger_equal(torch.__version__, '1.7.0'):
SiLU = unwrap_module(nn.SiLU)
Unflatten = unwrap_module(nn.Unflatten)
TripletMarginWithDistanceLoss = unwrap_module(nn.TripletMarginWithDistanceLoss)
def inject_pytorch_nn():
Identity = wrap_module(nn.Identity)
Linear = wrap_module(nn.Linear)
Conv1d = wrap_module(nn.Conv1d)
Conv2d = wrap_module(nn.Conv2d)
Conv3d = wrap_module(nn.Conv3d)
ConvTranspose1d = wrap_module(nn.ConvTranspose1d)
ConvTranspose2d = wrap_module(nn.ConvTranspose2d)
ConvTranspose3d = wrap_module(nn.ConvTranspose3d)
Threshold = wrap_module(nn.Threshold)
ReLU = wrap_module(nn.ReLU)
Hardtanh = wrap_module(nn.Hardtanh)
ReLU6 = wrap_module(nn.ReLU6)
Sigmoid = wrap_module(nn.Sigmoid)
Tanh = wrap_module(nn.Tanh)
Softmax = wrap_module(nn.Softmax)
Softmax2d = wrap_module(nn.Softmax2d)
LogSoftmax = wrap_module(nn.LogSoftmax)
ELU = wrap_module(nn.ELU)
SELU = wrap_module(nn.SELU)
CELU = wrap_module(nn.CELU)
GLU = wrap_module(nn.GLU)
GELU = wrap_module(nn.GELU)
Hardshrink = wrap_module(nn.Hardshrink)
LeakyReLU = wrap_module(nn.LeakyReLU)
LogSigmoid = wrap_module(nn.LogSigmoid)
Softplus = wrap_module(nn.Softplus)
Softshrink = wrap_module(nn.Softshrink)
MultiheadAttention = wrap_module(nn.MultiheadAttention)
PReLU = wrap_module(nn.PReLU)
Softsign = wrap_module(nn.Softsign)
Softmin = wrap_module(nn.Softmin)
Tanhshrink = wrap_module(nn.Tanhshrink)
RReLU = wrap_module(nn.RReLU)
AvgPool1d = wrap_module(nn.AvgPool1d)
AvgPool2d = wrap_module(nn.AvgPool2d)
AvgPool3d = wrap_module(nn.AvgPool3d)
MaxPool1d = wrap_module(nn.MaxPool1d)
MaxPool2d = wrap_module(nn.MaxPool2d)
MaxPool3d = wrap_module(nn.MaxPool3d)
MaxUnpool1d = wrap_module(nn.MaxUnpool1d)
MaxUnpool2d = wrap_module(nn.MaxUnpool2d)
MaxUnpool3d = wrap_module(nn.MaxUnpool3d)
FractionalMaxPool2d = wrap_module(nn.FractionalMaxPool2d)
FractionalMaxPool3d = wrap_module(nn.FractionalMaxPool3d)
LPPool1d = wrap_module(nn.LPPool1d)
LPPool2d = wrap_module(nn.LPPool2d)
LocalResponseNorm = wrap_module(nn.LocalResponseNorm)
BatchNorm1d = wrap_module(nn.BatchNorm1d)
BatchNorm2d = wrap_module(nn.BatchNorm2d)
BatchNorm3d = wrap_module(nn.BatchNorm3d)
InstanceNorm1d = wrap_module(nn.InstanceNorm1d)
InstanceNorm2d = wrap_module(nn.InstanceNorm2d)
InstanceNorm3d = wrap_module(nn.InstanceNorm3d)
LayerNorm = wrap_module(nn.LayerNorm)
GroupNorm = wrap_module(nn.GroupNorm)
SyncBatchNorm = wrap_module(nn.SyncBatchNorm)
Dropout = wrap_module(nn.Dropout)
Dropout2d = wrap_module(nn.Dropout2d)
Dropout3d = wrap_module(nn.Dropout3d)
AlphaDropout = wrap_module(nn.AlphaDropout)
FeatureAlphaDropout = wrap_module(nn.FeatureAlphaDropout)
ReflectionPad1d = wrap_module(nn.ReflectionPad1d)
ReflectionPad2d = wrap_module(nn.ReflectionPad2d)
ReplicationPad2d = wrap_module(nn.ReplicationPad2d)
ReplicationPad1d = wrap_module(nn.ReplicationPad1d)
ReplicationPad3d = wrap_module(nn.ReplicationPad3d)
CrossMapLRN2d = wrap_module(nn.CrossMapLRN2d)
Embedding = wrap_module(nn.Embedding)
EmbeddingBag = wrap_module(nn.EmbeddingBag)
RNNBase = wrap_module(nn.RNNBase)
RNN = wrap_module(nn.RNN)
LSTM = wrap_module(nn.LSTM)
GRU = wrap_module(nn.GRU)
RNNCellBase = wrap_module(nn.RNNCellBase)
RNNCell = wrap_module(nn.RNNCell)
LSTMCell = wrap_module(nn.LSTMCell)
GRUCell = wrap_module(nn.GRUCell)
PixelShuffle = wrap_module(nn.PixelShuffle)
Upsample = wrap_module(nn.Upsample)
UpsamplingNearest2d = wrap_module(nn.UpsamplingNearest2d)
UpsamplingBilinear2d = wrap_module(nn.UpsamplingBilinear2d)
PairwiseDistance = wrap_module(nn.PairwiseDistance)
AdaptiveMaxPool1d = wrap_module(nn.AdaptiveMaxPool1d)
AdaptiveMaxPool2d = wrap_module(nn.AdaptiveMaxPool2d)
AdaptiveMaxPool3d = wrap_module(nn.AdaptiveMaxPool3d)
AdaptiveAvgPool1d = wrap_module(nn.AdaptiveAvgPool1d)
AdaptiveAvgPool2d = wrap_module(nn.AdaptiveAvgPool2d)
AdaptiveAvgPool3d = wrap_module(nn.AdaptiveAvgPool3d)
TripletMarginLoss = wrap_module(nn.TripletMarginLoss)
ZeroPad2d = wrap_module(nn.ZeroPad2d)
ConstantPad1d = wrap_module(nn.ConstantPad1d)
ConstantPad2d = wrap_module(nn.ConstantPad2d)
ConstantPad3d = wrap_module(nn.ConstantPad3d)
Bilinear = wrap_module(nn.Bilinear)
CosineSimilarity = wrap_module(nn.CosineSimilarity)
Unfold = wrap_module(nn.Unfold)
Fold = wrap_module(nn.Fold)
AdaptiveLogSoftmaxWithLoss = wrap_module(nn.AdaptiveLogSoftmaxWithLoss)
TransformerEncoder = wrap_module(nn.TransformerEncoder)
TransformerDecoder = wrap_module(nn.TransformerDecoder)
TransformerEncoderLayer = wrap_module(nn.TransformerEncoderLayer)
TransformerDecoderLayer = wrap_module(nn.TransformerDecoderLayer)
Transformer = wrap_module(nn.Transformer)
Flatten = wrap_module(nn.Flatten)
Hardsigmoid = wrap_module(nn.Hardsigmoid)
if version_larger_equal(torch.__version__, '1.6.0'):
Hardswish = wrap_module(nn.Hardswish)
if version_larger_equal(torch.__version__, '1.7.0'):
SiLU = wrap_module(nn.SiLU)
Unflatten = wrap_module(nn.Unflatten)
TripletMarginWithDistanceLoss = wrap_module(nn.TripletMarginWithDistanceLoss)
......@@ -35,16 +35,29 @@ class MnistNet(nn.Module):
x = self.fc2(x)
return F.log_softmax(x, dim=1)
# NOTE: a blackbox module cannot be defined inside a class or function
@blackbox_module
class Linear(nn.Module):
def __init__(self, d_embed, d_proj):
super().__init__()
self.linear = nn.Linear(d_embed, d_proj)
def forward(self, input):
if len(input.size()) <= 2:
return self.linear(input)
size = input.size()[:2]
out = self.linear(input.view(size[0] * size[1], -1))
return out.view(size[0], size[1], -1)
class TestConvert(unittest.TestCase):
@staticmethod
def _match_state_dict(current_values, expected_format):
result = {}
for k, v in expected_format.items():
for cv in current_values:
for idx, cv in enumerate(current_values):
if cv.shape == v.shape:
result[k] = cv
current_values.remove(cv)
current_values.pop(idx)
break
return result
......@@ -53,6 +66,9 @@ class TestConvert(unittest.TestCase):
model_ir = convert_to_graph(script_module, model)
model_code = model_to_pytorch_script(model_ir)
from .inject_nn import remove_inject_pytorch_nn
remove_inject_pytorch_nn()
exec_vars = {}
exec(model_code + '\n\nconverted_model = _model()', exec_vars)
converted_model = exec_vars['converted_model']
......@@ -134,18 +150,17 @@ class TestConvert(unittest.TestCase):
model = DCGANGenerator(nz, ngf, nc)
self.checkExportImport(model, input)
@unittest.skip('this test has an if condition that needs to be handled') # FIXME
def test_neural_style(self):
class TransformerNet(torch.nn.Module):
class TransformerNet(nn.Module):
def __init__(self):
super(TransformerNet, self).__init__()
# Initial convolution layers
self.conv1 = ConvLayer(3, 32, kernel_size=9, stride=1)
self.in1 = torch.nn.InstanceNorm2d(32, affine=True)
self.in1 = nn.InstanceNorm2d(32, affine=True)
self.conv2 = ConvLayer(32, 64, kernel_size=3, stride=2)
self.in2 = torch.nn.InstanceNorm2d(64, affine=True)
self.in2 = nn.InstanceNorm2d(64, affine=True)
self.conv3 = ConvLayer(64, 128, kernel_size=3, stride=2)
self.in3 = torch.nn.InstanceNorm2d(128, affine=True)
self.in3 = nn.InstanceNorm2d(128, affine=True)
# Residual layers
self.res1 = ResidualBlock(128)
self.res2 = ResidualBlock(128)
......@@ -154,12 +169,12 @@ class TestConvert(unittest.TestCase):
self.res5 = ResidualBlock(128)
# Upsampling Layers
self.deconv1 = UpsampleConvLayer(128, 64, kernel_size=3, stride=1, upsample=2)
self.in4 = torch.nn.InstanceNorm2d(64, affine=True)
self.in4 = nn.InstanceNorm2d(64, affine=True)
self.deconv2 = UpsampleConvLayer(64, 32, kernel_size=3, stride=1, upsample=2)
self.in5 = torch.nn.InstanceNorm2d(32, affine=True)
self.in5 = nn.InstanceNorm2d(32, affine=True)
self.deconv3 = ConvLayer(32, 3, kernel_size=9, stride=1)
# Non-linearities
self.relu = torch.nn.ReLU()
self.relu = nn.ReLU()
def forward(self, X):
y = self.relu(self.in1(self.conv1(X)))
......@@ -175,19 +190,19 @@ class TestConvert(unittest.TestCase):
y = self.deconv3(y)
return y
class ConvLayer(torch.nn.Module):
class ConvLayer(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride):
super(ConvLayer, self).__init__()
reflection_padding = kernel_size // 2
self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)
self.reflection_pad = nn.ReflectionPad2d(reflection_padding)
self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride)
def forward(self, x):
out = self.reflection_pad(x)
out = self.conv2d(out)
return out
class ResidualBlock(torch.nn.Module):
class ResidualBlock(nn.Module):
"""ResidualBlock
introduced in: https://arxiv.org/abs/1512.03385
recommended architecture: http://torch.ch/blog/2016/02/04/resnets.html
......@@ -196,10 +211,10 @@ class TestConvert(unittest.TestCase):
def __init__(self, channels):
super(ResidualBlock, self).__init__()
self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1)
self.in1 = torch.nn.InstanceNorm2d(channels, affine=True)
self.in1 = nn.InstanceNorm2d(channels, affine=True)
self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1)
self.in2 = torch.nn.InstanceNorm2d(channels, affine=True)
self.relu = torch.nn.ReLU()
self.in2 = nn.InstanceNorm2d(channels, affine=True)
self.relu = nn.ReLU()
def forward(self, x):
residual = x
......@@ -208,7 +223,7 @@ class TestConvert(unittest.TestCase):
out = out + residual
return out
class UpsampleConvLayer(torch.nn.Module):
class UpsampleConvLayer(nn.Module):
"""UpsampleConvLayer
Upsamples the input and then does a convolution. This method gives better results
compared to ConvTranspose2d.
......@@ -219,10 +234,10 @@ class TestConvert(unittest.TestCase):
super(UpsampleConvLayer, self).__init__()
self.upsample = upsample
if upsample:
self.upsample_layer = torch.nn.Upsample(mode='nearest', scale_factor=upsample)
self.upsample_layer = nn.Upsample(mode='nearest', scale_factor=upsample)
reflection_padding = kernel_size // 2
self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)
self.reflection_pad = nn.ReflectionPad2d(reflection_padding)
self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride)
def forward(self, x):
x_in = x
......@@ -254,50 +269,40 @@ class TestConvert(unittest.TestCase):
self.checkExportImport(Policy(), (torch.rand(1, 4),))
@unittest.skip('Replaced init error.') # FIXME
def test_snli(self):
class Bottle(nn.Module):
def forward(self, input):
if len(input.size()) <= 2:
return super(Bottle, self).forward(input)
size = input.size()[:2]
out = super(Bottle, self).forward(input.view(size[0] * size[1], -1))
return out.view(size[0], size[1], -1)
class Linear(Bottle, nn.Linear):
pass
class Encoder(nn.Module):
def __init__(self, config):
super(Encoder, self).__init__()
self.config = config
input_size = config.d_proj if config.projection else config.d_embed
dropout = 0 if config.n_layers == 1 else config.dp_ratio
self.rnn = nn.LSTM(input_size=input_size, hidden_size=config.d_hidden,
num_layers=config.n_layers, dropout=dropout,
bidirectional=config.birnn)
#self.config = config
input_size = config["d_proj"] if config["projection"] else config["d_embed"]
dropout = 0 if config["n_layers"] == 1 else config["dp_ratio"]
self.rnn = nn.LSTM(input_size=input_size, hidden_size=config["d_hidden"],
num_layers=config["n_layers"], dropout=dropout,
bidirectional=config["birnn"])
self.n_cells = config["n_cells"]
self.d_hidden = config["d_hidden"]
self.birnn = config["birnn"]
def forward(self, inputs):
batch_size = inputs.size()[1]
state_shape = self.config.n_cells, batch_size, self.config.d_hidden
state_shape = self.n_cells, batch_size, self.d_hidden
h0 = c0 = inputs.new_zeros(state_shape)
outputs, (ht, ct) = self.rnn(inputs, (h0, c0))
return ht[-1] if not self.config.birnn else ht[-2:].transpose(0, 1).contiguous().view(batch_size, -1)
return ht[-1] if not self.birnn else ht[-2:].transpose(0, 1).contiguous().view(batch_size, -1)
class SNLIClassifier(nn.Module):
def __init__(self, config):
super(SNLIClassifier, self).__init__()
self.config = config
self.embed = nn.Embedding(config.n_embed, config.d_embed)
self.projection = Linear(config.d_embed, config.d_proj)
self.embed = nn.Embedding(config["n_embed"], config["d_embed"])
self.projection = Linear(config["d_embed"], config["d_proj"])
self.encoder = Encoder(config)
self.dropout = nn.Dropout(p=config.dp_ratio)
self.dropout = nn.Dropout(p=config["dp_ratio"])
self.relu = nn.ReLU()
seq_in_size = 2 * config.d_hidden
if self.config.birnn:
seq_in_size = 2 * config["d_hidden"]
if config["birnn"]:
seq_in_size *= 2
lin_config = [seq_in_size] * 2
self.out = nn.Sequential(
......@@ -310,15 +315,17 @@ class TestConvert(unittest.TestCase):
Linear(*lin_config),
self.relu,
self.dropout,
Linear(seq_in_size, config.d_out))
Linear(seq_in_size, config["d_out"]))
self.fix_emb = config["fix_emb"]
self.project = config["projection"]
def forward(self, premise, hypothesis):
prem_embed = self.embed(premise)
hypo_embed = self.embed(hypothesis)
if self.config.fix_emb:
if self.fix_emb:
prem_embed = prem_embed.detach()
hypo_embed = hypo_embed.detach()
if self.config.projection:
if self.project:
prem_embed = self.relu(self.projection(prem_embed))
hypo_embed = self.relu(self.projection(hypo_embed))
premise = self.encoder(prem_embed)
......@@ -326,23 +333,24 @@ class TestConvert(unittest.TestCase):
scores = self.out(torch.cat([premise, hypothesis], 1))
return scores
class Config:
n_embed = 100
d_embed = 100
d_proj = 300
dp_ratio = 0.0 # For deterministic testing TODO: change by fixing seed in checkTrace?
d_hidden = 30
birnn = True
d_out = 300
fix_emb = True
projection = True
n_layers = 2
n_cells = 4 # 2 * n_layers because birnn = True
Config = {
"n_embed": 100,
"d_embed": 100,
"d_proj": 300,
"dp_ratio": 0.0, # For deterministic testing TOD": change by fixing seed in checkTrace?,
"d_hidden": 30,
"birnn": True,
"d_out": 300,
"fix_emb": True,
"projection": True,
"n_layers": 2,
"n_cells": 4 # 2 * n_layers because birnn = True,
}
premise = torch.LongTensor(48, 64).random_(0, 100)
hypothesis = torch.LongTensor(24, 64).random_(0, 100)
self.checkExportImport(SNLIClassifier(Config()), (premise, hypothesis))
self.checkExportImport(SNLIClassifier(Config), (premise, hypothesis))
def test_super_resolution(self):
class Net(nn.Module):
......@@ -367,16 +375,16 @@ class TestConvert(unittest.TestCase):
net = Net(upscale_factor=4)
self.checkExportImport(net, (torch.rand(5, 1, 32, 32),))
@unittest.skip('Need to support operator prim::ListUnpack') # FIXME
@unittest.skip('Need to support Loop') # FIXME
def test_time_sequence_prediction(self):
class Sequence(torch.jit.ScriptModule):
class Sequence(nn.Module): #torch.jit.ScriptModule
def __init__(self):
super(Sequence, self).__init__()
self.lstm1 = nn.LSTMCell(1, 51)
self.lstm2 = nn.LSTMCell(51, 51)
self.linear = nn.Linear(51, 1)
@torch.jit.script_method
#@torch.jit.script_method
def forward(self, input):
# TODO: add future as input with default val
# see https://github.com/pytorch/pytorch/issues/8724
......@@ -414,7 +422,7 @@ class TestConvert(unittest.TestCase):
self.checkExportImport(Traced(), (torch.rand(3, 4),))
@unittest.skip('Unsupported callmethod encode') # FIXME
@unittest.skip('incorrectly assigned weights') # FIXME
def test_vae(self):
class VAE(nn.Module):
def __init__(self):
......@@ -449,11 +457,11 @@ class TestConvert(unittest.TestCase):
self.checkExportImport(VAE().eval(), (torch.rand(128, 1, 28, 28),))
@unittest.skip('torchvision models are not supported yet') # FIXME
def test_torchvision_resnet18(self):
from .inject_nn import inject_pytorch_nn
inject_pytorch_nn()
self.checkExportImport(torchvision.models.resnet18().eval(), (torch.ones(1, 3, 224, 224),))
@unittest.skip('Unsupported CallMethod _forward_impl') # FIXME
def test_resnet(self):
def conv1x1(in_planes, out_planes, stride=1):
"""1x1 convolution"""
......@@ -464,7 +472,7 @@ class TestConvert(unittest.TestCase):
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=1, bias=False)
class BasicBlock(torch.jit.ScriptModule):
class BasicBlock(nn.Module): #torch.jit.ScriptModule
expansion = 1
__constants__ = ['downsample']
......@@ -478,7 +486,8 @@ class TestConvert(unittest.TestCase):
self.downsample = downsample
self.stride = stride
@torch.jit.script_method
# NOTE: do not annotate with @torch.jit.script_method, otherwise the module id does not match the recorded arguments
#@torch.jit.script_method
def forward(self, x):
residual = x
......@@ -497,7 +506,8 @@ class TestConvert(unittest.TestCase):
return out
class ResNet(torch.jit.ScriptModule):
# NOTE: cannot inherit torch.jit.ScriptModule, otherwise there is an error: 'RecursiveScriptModule' object has no attribute 'graph'
class ResNet(nn.Module): #torch.jit.ScriptModule
__constants__ = ['layer1', 'layer2', 'layer3', 'layer4']
def __init__(self, block, layers, num_classes=1000):
......@@ -538,7 +548,8 @@ class TestConvert(unittest.TestCase):
return nn.Sequential(*layers)
@torch.jit.script_method
# NOTE: do not annotate with @torch.jit.script_method, otherwise the module id does not match the recorded arguments
#@torch.jit.script_method
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
......@@ -558,10 +569,11 @@ class TestConvert(unittest.TestCase):
resnet18 = ResNet(BasicBlock, [2, 2, 2, 2])
self.checkExportImport(torchvision.models.resnet18().eval(), (torch.randn(1, 3, 224, 224),))
self.checkExportImport(resnet18, (torch.randn(1, 3, 224, 224),))
@unittest.skip('torchvision models are not supported yet') # FIXME
def test_alexnet(self):
from .inject_nn import inject_pytorch_nn
inject_pytorch_nn()
x = torch.ones(1, 3, 224, 224)
model = torchvision.models.AlexNet()
self.checkExportImport(model, (x,))
import os
import sys
import unittest
import numpy as np
import torch
import torch.nn.functional as F
import torchvision
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import blackbox_module
from nni.retiarii.converter import convert_to_graph
from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.utils import get_records
# following pytorch v1.7.1
class TestConvert(unittest.TestCase):
@staticmethod
def _match_state_dict(current_values, expected_format):
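# The converted model's parameter names differ from the original's, so greedily match each converted
# state_dict entry to an original tensor of the same shape (adequate for these small test models).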
result = {}
for k, v in expected_format.items():
for idx, cv in enumerate(current_values):
if cv.shape == v.shape:
result[k] = cv
current_values.pop(idx)
break
return result
def checkExportImport(self, model, input, check_value=True):
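# Round-trip check: script the model, convert it to graph IR, regenerate PyTorch code, execute the code,
# load the original weights into the regenerated model, and compare the outputs of both models.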
script_module = torch.jit.script(model)
model_ir = convert_to_graph(script_module, model)
model_code = model_to_pytorch_script(model_ir)
print(model_code)
exec_vars = {}
exec(model_code + '\n\nconverted_model = _model()', exec_vars)
converted_model = exec_vars['converted_model']
converted_state_dict = self._match_state_dict(list(model.state_dict().values()),
dict(converted_model.state_dict()))
converted_model.load_state_dict(converted_state_dict)
with torch.no_grad():
expected_output = model.eval()(*input)
converted_output = converted_model.eval()(*input)
if check_value:
self.assertEqual(len(converted_output), len(expected_output))
for a, b in zip(converted_output, expected_output):
if hasattr(a, 'dtype') and a.dtype == torch.bool:
self.assertEqual((a ^ b), False)
elif isinstance((a - b), int):
self.assertEqual((a - b), 0)
else:
self.assertLess((a - b).abs().max().item(), 1E-4)
return converted_model
# skip torch.Tensor.new_tensor as it is not supported by jit
def test_basic_new_full(self):
class SimpleOp(nn.Module):
def forward(self, x):
# requires_grad is not supported by jit
# aten::new_full(Tensor self, int[] size, Scalar fill_value, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor):
# Keyword argument requires_grad unknown.
out = x.new_full((3, 4), 3.141592, dtype=torch.float32, device=torch.device('cpu'))
return out
self.checkExportImport(SimpleOp(), (torch.ones((2,), dtype=torch.float64), ))
def test_basic_new_empty(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = x.new_empty((2, 3), dtype=torch.int8, device=torch.device('cpu'))
return out
self.checkExportImport(SimpleOp(), (torch.ones(()), ), check_value=False)
# skip torch.Tensor.new_ones as it is not supported by jit
# requires_grad=False is not supported by jit
def test_basic_new_zeros(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = x.new_zeros((2, 3))
return out
self.checkExportImport(SimpleOp(), (torch.tensor((), dtype=torch.int32), ))
def test_basic_is_cuda(self):
class SimpleOp(nn.Module):
def forward(self, x):
return torch.tensor([x.is_cuda], dtype=torch.bool, device=torch.device('cpu'))
self.checkExportImport(SimpleOp(), (torch.tensor((), dtype=torch.int32), ))
# is_quantized
# is_meta
# device
# grad
# ndim
# T
# real
# imag
def test_basic_abs(self):
class SimpleOp(nn.Module):
def forward(self, x):
out1 = x.abs()
out11 = x.absolute()
out2 = torch.abs(x)
#out3 = x.abs_()
#out33 = x.absolute_()
return out1, out11, out2#, out3, out33
self.checkExportImport(SimpleOp(), (torch.tensor([-1, -2, 3]), ))
# TODO: topological sort should be improved
#def forward(self, x__1):
# __Acos2 = x__1.acos()
# __Acos_3 = x__1.acos_()
# __Acos1 = x__1.acos()
# __TupleConstruct4 = (__Acos1,__Acos2,__Acos_3)
# return __TupleConstruct4
def test_basic_acos_asin_atan(self):
class SimpleOp(nn.Module):
def forward(self, x, y):
out1 = x.acos()
out2 = torch.acos(x)
# TODO: add back this line
#out = x.acos_()
out3 = x.asin()
out4 = torch.asin(x)
out5 = x.atan()
out6 = torch.atan(x)
out7 = x.atan2(y)
out8 = torch.atan2(x, y)
return out1, out2, out3, out4, out5, out6, out7, out8#, out
self.checkExportImport(SimpleOp(), (torch.tensor([-1.0, -0.5, 0.2]), torch.tensor([1.0, 0.6, -0.3]), ))
# arccos is not supported by jit
def test_basic_add(self):
class SimpleOp(nn.Module):
def forward(self, x):
t = torch.tensor([-1.0, -0.5, 0.2])
out1 = x.add(t)
out2 = x.add(t, alpha=2)
#out3 = x.add_(t)
return out1, out2#, out3
self.checkExportImport(SimpleOp(), (torch.tensor([-1.0, -0.5, 0.2]), ))
def test_basic_addbmm(self):
class SimpleOp(nn.Module):
def forward(self, x, y, z, m):
out1 = x.addbmm(y, z, beta=2, alpha=3)
out2 = torch.addbmm(x, y, z, beta=2, alpha=3)
#out3 = x.addbmm_(y, z, beta=2, alpha=3)
out3 = m.baddbmm(y, z, beta=2, alpha=3)
out4 = torch.baddbmm(m, y, z, beta=2, alpha=3)
out5 = torch.bmm(y, z) # deterministic is not supported by jit
return out1, out2, out3, out4, out5
self.checkExportImport(SimpleOp(), (torch.randn(3, 5), torch.randn(10, 3, 4), torch.randn(10, 4, 5), torch.randn(10, 3, 5), ))
def test_basic_addcdiv(self):
class SimpleOp(nn.Module):
def forward(self, x, y, z):
out1 = x.addcdiv(y, z, value=2)
out2 = torch.addcdiv(x, y, z, value=2)
# addcdiv_
return out1, out2
self.checkExportImport(SimpleOp(), (torch.randn(1, 3), torch.randn(3, 1), torch.randn(1, 3), ))
def test_basic_addcmul(self):
class SimpleOp(nn.Module):
def forward(self, x, y, z):
out1 = x.addcmul(y, z, value=0.1)
out2 = torch.addcmul(x, y, z, value=0.1)
# addcmul_
return out1, out2
self.checkExportImport(SimpleOp(), (torch.randn(1, 3), torch.randn(3, 1), torch.randn(1, 3), ))
def test_basic_addmm(self):
class SimpleOp(nn.Module):
def forward(self, x, y, z):
out1 = x.addmm(y, z, beta=0.1, alpha=0.2)
out2 = torch.addmm(x, y, z, beta=0.1, alpha=0.2)
# addmm_
return out1, out2
self.checkExportImport(SimpleOp(), (torch.randn(2, 3), torch.randn(2, 3), torch.randn(3, 3), ))
def test_basic_addmv(self):
class SimpleOp(nn.Module):
def forward(self, x, y, z):
out1 = x.addmv(y, z, beta=0.1, alpha=0.2)
out2 = torch.addmv(x, y, z, beta=0.1, alpha=0.2)
return out1, out2
self.checkExportImport(SimpleOp(), (torch.randn(2), torch.randn(2, 3), torch.randn(3), ))
def test_basic_addr(self):
class SimpleOp(nn.Module):
def forward(self, x, y, z):
out1 = x.addr(y, z, beta=2, alpha=3)
out2 = torch.addr(x, y, z, beta=2, alpha=3)
return out1, out2
self.checkExportImport(SimpleOp(), (torch.zeros(3, 2), torch.arange(1., 4.), torch.arange(1., 3.), ))
def test_basic_allclose(self):
class SimpleOp(nn.Module):
def forward(self, x, y):
out1 = x.allclose(y, rtol=1e-05, atol=1e-08, equal_nan=False)
out2 = torch.allclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False)
return out1, out2
self.checkExportImport(SimpleOp(), (torch.tensor([10000., 1e-07]), torch.tensor([10000.1, 1e-08]), ))
def test_basic_angle(self):
class SimpleOp(nn.Module):
def forward(self, x):
out1 = x.angle()
out2 = torch.angle(x)
return out1, out2
self.checkExportImport(SimpleOp(), (torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j]), ))
# skip apply_(callable) for now
def test_basic_argmax_argmin(self):
class SimpleOp(nn.Module):
def forward(self, x):
out1 = x.argmax()
out2 = torch.argmax(x)
out3 = x.argmax(dim=1)
out4 = torch.argmax(x, dim=1)
out5 = x.argmax(dim=1, keepdim=True)
o1 = x.argmin()
o2 = torch.argmin(x)
o3 = x.argmin(dim=1)
o4 = x.argmin(dim=1, keepdim=True)
return out1, out2, out3, out4, out5, o1, o2, o3, o4
self.checkExportImport(SimpleOp(), (torch.randn(4, 4), ))
def test_basic_argsort(self):
class SimpleOp(nn.Module):
def forward(self, x):
out1 = x.argsort()
out2 = x.argsort(dim=1)
out3 = x.argsort(dim=1, descending=True)
out4 = torch.argsort(x, dim=1, descending=True)
return out1, out2, out3, out4
self.checkExportImport(SimpleOp(), (torch.randn(4, 4), ))
# skip backward(gradient=None, retain_graph=None, create_graph=False)
def test_basic_bernoulli(self):
class SimpleOp(nn.Module):
def forward(self, x):
# generator=torch.Generator() is not supported by jit
out = x.bernoulli()
return out
self.checkExportImport(SimpleOp(), (torch.ones(3, 3), ))
# bfloat16/bool/byte/char are not supported by jit
def test_basic_bincount(self):
class SimpleOp(nn.Module):
def forward(self, x, y):
out1 = x.bincount()
out2 = torch.bincount(x)
out3 = x.bincount(weights=y)
out4 = x.bincount(weights=y, minlength=2)
return out1, out2, out3, out4
self.checkExportImport(SimpleOp(), (torch.randint(0, 8, (5,), dtype=torch.int64), torch.linspace(0, 1, steps=5), ))
def test_basic_bitwise(self):
class SimpleOp(nn.Module):
def forward(self, x, y):
out1 = x.bitwise_not()
out2 = x.bitwise_and(y)
out3 = x.bitwise_or(y)
out4 = x.bitwise_xor(y)
return out1, out2, out3, out4
self.checkExportImport(SimpleOp(), (torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8), ))
# cauchy_ is not supported yet
def test_ceil(self):
class SimpleOp(nn.Module):
def forward(self, x):
out1 = x.ceil()
return out1
self.checkExportImport(SimpleOp(), (torch.randn(4), ))
\ No newline at end of file
'''
The tests in this file are copied and transformed from
`https://github.com/pytorch/pytorch/blob/master/test/onnx/test_operators.py`
'''
import os
import sys
import unittest
from typing import (Dict)
import numpy as np
import torch
import torch.nn.functional as F
import torchvision
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import blackbox_module
from nni.retiarii.converter import convert_to_graph
from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.utils import get_records
# following pytorch v1.7.1
class TestOperators(unittest.TestCase):
@staticmethod
def _match_state_dict(current_values, expected_format):
result = {}
for k, v in expected_format.items():
for idx, cv in enumerate(current_values):
if cv.shape == v.shape:
result[k] = cv
current_values.pop(idx)
break
return result
def checkExportImport(self, model, input, check_value=True):
script_module = torch.jit.script(model)
model_ir = convert_to_graph(script_module, model)
model_code = model_to_pytorch_script(model_ir)
#print(model_code)
exec_vars = {}
exec(model_code + '\n\nconverted_model = _model()', exec_vars)
converted_model = exec_vars['converted_model']
converted_state_dict = self._match_state_dict(list(model.state_dict().values()),
dict(converted_model.state_dict()))
converted_model.load_state_dict(converted_state_dict)
with torch.no_grad():
expected_output = model.eval()(*input)
converted_output = converted_model.eval()(*input)
if check_value:
try:
self.assertEqual(len(converted_output), len(expected_output))
for a, b in zip(converted_output, expected_output):
torch.eq(a, b)
except:
self.assertEqual(converted_output, expected_output)
return converted_model
def test_basic_basic(self):
class SimpleOp(nn.Module):
def forward(self, x, y):
out = -torch.sigmoid(torch.tanh(x * (x + y)))
return out
x = torch.tensor([0.4], requires_grad=True)
y = torch.tensor([0.7], requires_grad=True)
self.checkExportImport(SimpleOp(), (x, y, ))
def test_basic_view(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = x.view(1, 1)
return out
x = torch.tensor([0.0], requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_index(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = x[0]
return out
x = torch.tensor([[0.0]], requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_type_as(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = x.type_as(x)
return out
x = torch.tensor([0.0], requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_addconstant(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = x + 1
return out
x = torch.randn(2, 3, requires_grad=True).double()
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_add_broadcast(self):
class SimpleOp(nn.Module):
def forward(self, x, y):
out = x + y
return out
x = torch.randn(2, 3, requires_grad=True).double()
y = torch.randn(3, requires_grad=True).double()
self.checkExportImport(SimpleOp(), (x, y, ))
def test_basic_add_left_broadcast(self):
class SimpleOp(nn.Module):
def forward(self, x, y):
out = x + y
return out
x = torch.randn(3, requires_grad=True).double()
y = torch.randn(2, 3, requires_grad=True).double()
self.checkExportImport(SimpleOp(), (x, y, ))
def test_basic_add_size1_broadcast(self):
class SimpleOp(nn.Module):
def forward(self, x, y):
out = x + y
return out
x = torch.randn(2, 3, requires_grad=True).double()
y = torch.randn(2, 1, requires_grad=True).double()
self.checkExportImport(SimpleOp(), (x, y, ))
def test_basic_add_size1_right_broadcast(self):
class SimpleOp(nn.Module):
def forward(self, x, y):
out = x + y
return out
x = torch.randn(2, 3, requires_grad=True).double()
y = torch.randn(3, requires_grad=True).double()
self.checkExportImport(SimpleOp(), (x, y, ))
def test_basic_add_size1_singleton_broadcast(self):
class SimpleOp(nn.Module):
def forward(self, x, y):
out = x + y
return out
x = torch.randn(2, 3, requires_grad=True).double()
y = torch.randn(1, 3, requires_grad=True).double()
self.checkExportImport(SimpleOp(), (x, y, ))
def test_basic_rsub(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = 1 - x
return out
x = torch.randn(2, 3, requires_grad=True).double()
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_transpose(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = x.transpose(0, 1).transpose(1, 0)
return out
x = torch.tensor([[0.0, 1.0], [2.0, 3.0]], requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_chunk(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = x.chunk(2)
return out
x = torch.tensor([0.0, 1.0, 2.0], requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_split(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = torch.split(x, 2, 1)
return out
x = torch.tensor([[0.0, 1.0, 1.0, 0.0, 2.0, 2.0], [2.0, 3.0, 3.0, 2.0, 1.0, 1.0]])
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_split_with_sizes(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = torch.split(x, [2, 1, 3], 1)
return out
x = torch.tensor([[0.0, 1.0, 1.0, 0.0, 2.0, 2.0], [2.0, 3.0, 3.0, 2.0, 1.0, 1.0]])
self.checkExportImport(SimpleOp(), (x, ))
@unittest.skip('cannot be parsed by jit')
def test_basic_concat2(self):
class SimpleOp(nn.Module):
def forward(self, inputs):
out = torch.cat(inputs, 1)
return out
x = torch.randn(2, 3)
y = torch.randn(2, 3)
self.checkExportImport(SimpleOp(), ((x, y), ))
def test_basic_addmm(self):
class SimpleOp(nn.Module):
def forward(self, x, y, z):
out = torch.addmm(torch.addmm(z, x, y), x, y)
return out
m1 = torch.randn(2, 3, requires_grad=True)
m2 = torch.randn(3, 4, requires_grad=True)
m3 = torch.randn(4, requires_grad=True)
self.checkExportImport(SimpleOp(), (m1, m2, m3, ))
def test_basic_permute2(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = x.permute(0, 1, 4, 2, 5, 3)
return out
x = torch.tensor([[[[[[0.0]]]]]], requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_params(self):
class SimpleOp(nn.Module):
def forward(self, x, y):
out = -torch.sigmoid(torch.tanh(x * (x + y)))
return out
x = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
y = torch.nn.Parameter(torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True))
self.checkExportImport(SimpleOp(), (x, y, ))
def test_basic_params_onnx_irv4(self):
class SimpleOp(nn.Module):
def forward(self, x, y):
out = -torch.sigmoid(torch.tanh(x * (x + y)))
return out
x = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
y = torch.nn.Parameter(torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True))
self.checkExportImport(SimpleOp(), (x, y, ))
def test_basic_clip(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = torch.clamp(x, min=-0.5, max=0.5)
return out
x = torch.randn(3, 4, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_clip_min(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = x.clamp(min=-0.1)
return out
x = torch.randn(1, 2, 3, 4, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_clip_max(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = x.clamp(max=0.1)
return out
x = torch.randn(1, 2, 3, 4, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
@unittest.skip('cannot be parsed by jit')
def test_basic_hardtanh(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = torch.nn.Hardtanh(-0.5, 0.5)(x)
return out
x = torch.randn(3, 4, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_full(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = torch.full(x.shape, 2., dtype=torch.float32, layout=torch.strided, device=torch.device('cpu'))
return out
x = torch.randn(3, 4, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_full_like(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = torch.full_like(x, 2, memory_format=torch.preserve_format)
return out
x = torch.randn(3, 4, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_max(self):
class SimpleOp(nn.Module):
def forward(self, x, y):
out = torch.max(x, y)
return out
x = torch.randn(3, 4, requires_grad=True)
y = torch.randn(3, 4, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, y, ))
def test_basic_min(self):
class SimpleOp(nn.Module):
def forward(self, x, y):
out = torch.min(x, y)
return out
x = torch.randn(3, 4, requires_grad=True)
y = torch.randn(3, 4, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, y, ))
def test_basic_mean(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = torch.mean(x)
return out
x = torch.randn(1, 2, 3, 4, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_reduced_mean(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = torch.mean(x, dim=2)
return out
x = torch.randn(1, 2, 3, 4, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_reduced_mean_keepdim(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = torch.mean(x, dim=(2, 3), keepdim=True)
return out
x = torch.randn(1, 2, 3, 4, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_sum(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = torch.sum(x)
return out
x = torch.randn(1, 2, 3, 4, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_reduced_sum(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = torch.sum(x, dim=(1, 2))
return out
x = torch.randn(1, 2, 3, 4, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_reduced_sum_keepdim(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = torch.sum(x, dim=2, keepdim=True)
return out
x = torch.randn(1, 2, 3, 4, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_prod(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = torch.prod(x)
return out
x = torch.randn(1, 2, 3, 4, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_reduced_prod(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = torch.prod(x, dim=2)
return out
x = torch.randn(1, 2, 3, 4, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_reduced_prod_keepdim(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = torch.prod(x, dim=2, keepdim=True)
return out
x = torch.randn(1, 2, 3, 4, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_sqrt(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = torch.sqrt(x)
return out
x = torch.randn(3, 4, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_rsqrt(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = torch.rsqrt(x)
return out
x = torch.randn(3, 4, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_equal(self):
class SimpleOp(nn.Module):
def forward(self, x, y):
out = x == y
return out
x = torch.randn(1, 2, 3, 1, requires_grad=False).int()
y = torch.randn(1, 4, requires_grad=False).int()
self.checkExportImport(SimpleOp(), (x, y, ))
def test_basic_lt(self):
class SimpleOp(nn.Module):
def forward(self, x, y):
out = x < y
return out
x = torch.randn(1, 2, 3, 1, requires_grad=False).int()
y = torch.randn(1, 4, requires_grad=False).int()
self.checkExportImport(SimpleOp(), (x, y, ))
def test_basic_gt(self):
class SimpleOp(nn.Module):
def forward(self, x, y):
out = x > y
return out
x = torch.randn(1, 2, 3, 1, requires_grad=False).int()
y = torch.randn(1, 4, requires_grad=False).int()
self.checkExportImport(SimpleOp(), (x, y, ))
def test_basic_le(self):
class SimpleOp(nn.Module):
def forward(self, x, y):
out = x <= y
return out
x = torch.randn(3, 4, requires_grad=False).int()
y = torch.randn(3, 4, requires_grad=False).int()
self.checkExportImport(SimpleOp(), (x, y, ))
def test_basic_ge(self):
class SimpleOp(nn.Module):
def forward(self, x, y):
out = x >= y
return out
x = torch.randn(3, 4, requires_grad=False).int()
y = torch.randn(3, 4, requires_grad=False).int()
self.checkExportImport(SimpleOp(), (x, y, ))
def test_basic_exp(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = x.exp()
return out
x = torch.randn(3, 4, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_sin(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = x.sin()
return out
x = torch.randn(3, 4, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_cos(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = x.cos()
return out
x = torch.randn(3, 4, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_tan(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = x.tan()
return out
x = torch.randn(3, 4, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_asin(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = x.asin()
return out
x = torch.rand(3, 4, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_acos(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = x.acos()
return out
x = torch.rand(3, 4, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_slice(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = x[:, 1:2]
return out
x = torch.rand(3, 4, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_slice_dynamic(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = x[x.size(0):, x.size(1) - 3]
return out
x = torch.rand(3, 4, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_sign(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = x.sign()
return out
x = torch.rand(3, 4, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_narrow(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = torch.narrow(x, 0, 0, 2)
return out
x = torch.randn(3, 3, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_atan(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = x.atan()
return out
x = torch.randn(3, 4, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_view_flatten(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = x.view(x.size()[0], x.numel() // x.size()[0])
return out
x = torch.randn(1, 2, 3, 4, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_flatten(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = torch.flatten(x)
return out
x = torch.randn(1, 2, 3, 4, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_flatten2D(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = torch.flatten(x, 1)
return out
x = torch.randn(1, 2, 3, 4, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_isnan(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = torch.isnan(x)
return out
x = torch.tensor([1, float('nan'), 2])
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_argmax(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = torch.argmax(x, dim=1)
return out
x = torch.randn(4, 4, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_pow(self):
class SimpleOp(nn.Module):
def forward(self, x, y):
out = x.pow(y)
return out
x = torch.randn(1, 2, 3, 4, requires_grad=True)
y = torch.randn(1, 2, 3, 4, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, y, ))
def test_basic_repeat(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = x.repeat(1, 2, 3, 4)
return out
x = torch.randn(1, 2, 3, 4, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_repeat_dim_overflow(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = x.repeat(1, 2, 3, 4)
return out
x = torch.randn(1, 2, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_norm_p1(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = x.norm(p=1, dim=2)
return out
x = torch.randn(1, 2, 3, 4, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_norm_p2(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = x.norm(p=2, dim=2)
return out
x = torch.randn(1, 2, 3, 4, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_upsample_nearest_size(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = torch.nn.functional.interpolate(x, size=16, mode='nearest')
return out
x = torch.randn(1, 2, 3, 4, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_unsqueeze(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = x.unsqueeze(len(x.shape))
return out
x = torch.randn(3, 4, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_implicit_expand(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = x + 1
return out
x = torch.randn(3, 4, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_reduce_sum_negative_indices(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = x.sum(-1)
return out
x = torch.randn(3, 4, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_randn(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = torch.randn(1, 2, 3, 4) + x
return out
x = torch.randn(1, 2, 3, 4)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_rand(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = torch.rand(1, 2, 3, 4) + x
return out
x = torch.rand(1, 2, 3, 4)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_empty_like(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = torch.empty_like(x)
return out
x = torch.randn(5, 8, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_empty_like_opset7(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = torch.empty_like(x)
return out
x = torch.randn(5, 8, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_zeros_like(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = torch.zeros_like(x)
return out
x = torch.randn(5, 8, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_ones_like(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = torch.ones_like(x)
return out
x = torch.randn(6, 10, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_expand(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = x.expand(4, 6, 2)
return out
x = torch.randn(6, 1, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_ne(self):
class SimpleOp(nn.Module):
def forward(self, x, y):
out = torch.ne(x, y)
return out
x = torch.randn(1, 2, 3, 1, requires_grad=False).int()
y = torch.randn(1, 4, requires_grad=False).int()
self.checkExportImport(SimpleOp(), (x, y, ))
def test_basic_reducemax(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = torch.max(x)
return out
x = torch.randn(1, 2, 3, 4)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_reducemin(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = torch.min(x)
return out
x = torch.randn(1, 2, 3, 4)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_erf(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = x.erf()
return out
x = torch.randn(1, 2, 3, 4)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_dropout(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = torch.max(torch.nn.functional.dropout(x, training=False))
return out
x = torch.randn(3, 4, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_dropout_default(self):
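# F.dropout defaults to training=True, so the output is stochastic; only check that conversion succeeds (hence check_value=False below).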
class SimpleOp(nn.Module):
def forward(self, x):
out = torch.max(torch.nn.functional.dropout(x,))
return out
x = torch.randn(3, 4, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ), check_value=False)
def test_basic_dropout_training(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = torch.max(torch.nn.functional.dropout(x))
return out
x = torch.randn(3, 4, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ), check_value=False)
def test_basic_nonzero(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = torch.nonzero(x)
return out
x = torch.tensor([[[2., 2.], [1., 0.]], [[0., 0.], [1., 1.]]], requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_gather(self):
class SimpleOp(nn.Module):
def forward(self, data, index):
out = data.gather(1, index)
return out
data = torch.randn(3, 4, 3, requires_grad=True)
index = torch.tensor([2, 0]).view(1, 2, 1).expand(3, 2, 3)
self.checkExportImport(SimpleOp(), (data, index, ))
def test_basic_gather_opset11(self):
class SimpleOp(nn.Module):
def forward(self, data, index):
out = data.gather(1, index)
return out
data = torch.randn(3, 4, 3, requires_grad=True)
index = torch.tensor([2, 0]).view(1, 2, 1).expand(3, 2, 3)
self.checkExportImport(SimpleOp(), (data, index, ))
def test_basic_scatter_add(self):
class SimpleOp(nn.Module):
def forward(self, data, indices, values):
out = data.scatter_add(1, indices, values)
return out
data = torch.tensor([[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]])
indices = torch.tensor([[1, 0], [0, 1], [0, 1]], dtype=torch.int64)
values = torch.tensor([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]])
self.checkExportImport(SimpleOp(), (data, indices, values, ))
def test_basic_scatter_add_opset11(self):
class SimpleOp(nn.Module):
def forward(self, data, indices, values):
out = data.scatter_add(1, indices, values)
return out
data = torch.tensor([[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]])
indices = torch.tensor([[1, 0], [0, 1], [0, 1]], dtype=torch.int64)
values = torch.tensor([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]])
self.checkExportImport(SimpleOp(), (data, indices, values, ))
def test_basic_master_opset(self):
class SimpleOp(nn.Module):
def forward(self, x, y):
out = x + y
return out
x = torch.randn(2, 3).float()
y = torch.randn(2, 3).float()
self.checkExportImport(SimpleOp(), (x, y, ))
def test_basic_std(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = torch.std(x, dim=(0, 1), unbiased=True, keepdim=True)
return out
x = torch.randn(2, 3, 4).float()
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_cumsum(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = torch.cumsum(x, dim=1)
return out
x = torch.randn(2, 3, 4, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_pixel_shuffle(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = torch.pixel_shuffle(x, upscale_factor=2)
return out
x = torch.randn(2, 8, 3, 4).float()
self.checkExportImport(SimpleOp(), (x, ))
@unittest.skip('skip as torch.norm is called with prim::CallFunction, also torch.norm is deprecated')
def test_basic_frobenius_norm(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = torch.norm(x, p="fro", dim=(0, 1), keepdim=True)
return out
x = torch.randn(2, 3, 4).float()
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_unfold(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = x.unfold(dimension=2, size=2, step=2)
return out
x = torch.randn(2, 3, 4, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_remainder(self):
class SimpleOp(nn.Module):
def forward(self, x, y):
out = torch.remainder(x, y)
return out
x = torch.randn(2, 3, 4)
y = torch.randn(2, 1, 4)
self.checkExportImport(SimpleOp(), (x, y, ))
def test_basic_fmod(self):
class SimpleOp(nn.Module):
def forward(self, x, y):
out = torch.fmod(x, y)
return out
x = torch.randn(2, 3, 4)
y = torch.randn(2, 1, 4)
self.checkExportImport(SimpleOp(), (x, y, ))
def test_basic_gelu(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = torch.nn.functional.gelu(x)
return out
x = torch.randn(2, 3, 4, 5, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
@unittest.skip('skip as it is called with prim::CallFunction and the function definition is unknown')
def test_basic_unique(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = torch.unique(x, dim=0, sorted=True, return_inverse=False, return_counts=True)
return out
x = torch.randint(3, (2, 3, 4, 5)).float()
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_meshgrid(self):
class SimpleOp(nn.Module):
def forward(self, x, y, z):
out = torch.meshgrid(x, y, z)
return out
x = torch.ones(3, requires_grad=True)
y = torch.zeros(4, requires_grad=True)
z = torch.ones(5, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, y, z, ))
def test_basic_topk(self):
class SimpleOp(nn.Module):
def forward(self, x, k):
out = torch.topk(x, k)
return out
x = torch.arange(1., 6., requires_grad=True)
k = torch.tensor(3)
self.checkExportImport(SimpleOp(), (x, k, ))
def test_basic_topk_smallest_unsorted(self):
class SimpleOp(nn.Module):
def forward(self, x, k):
out = torch.topk(x, k, largest=False, sorted=False)
return out
x = torch.arange(1., 6., requires_grad=True)
k = torch.tensor(3)
self.checkExportImport(SimpleOp(), (x, k, ))
def test_basic_baddbmm(self):
class SimpleOp(nn.Module):
def forward(self, x, b1, b2):
out = torch.baddbmm(x, b1, b2)
return out
x = torch.randn(10, 3, 5)
b1 = torch.randn(10, 3, 4)
b2 = torch.randn(10, 4, 5)
self.checkExportImport(SimpleOp(), (x, b1, b2, ))
def test_basic_round(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = torch.round(x)
return out
x = torch.tensor([0.9920, -1.0362, -1.5000, 2.5000], requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_dim(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = torch.scalar_tensor(x.dim())
return out
x = torch.ones((2, 2), requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_det(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = torch.det(x)
return out
x = torch.randn(2, 3, 5, 5, device=torch.device('cpu'))
self.checkExportImport(SimpleOp(), (x, ))
# the followings are more complex tests
def test_mm(self):
class SimpleOp(nn.Module):
def forward(self, x, y):
out = torch.mm(x, y)
return out
m1 = torch.randn(2, 3, requires_grad=True)
m2 = torch.randn(3, 4, requires_grad=True)
self.checkExportImport(SimpleOp(), (m1, m2))
def test_basic_pad(self):
class SimpleOp(nn.Module):
def __init__(self):
super().__init__()
self.m = nn.ReflectionPad2d((2, 3, 0, 1))
def forward(self, x):
out = self.m(x)
return out
x = torch.tensor([[[[0.0, 1.0, 1.0, 1.0], [2.0, 3.0, 7.0, 7.0]]]], requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_batchnorm(self):
class SimpleOp(nn.Module):
def __init__(self):
super().__init__()
self.m = nn.BatchNorm2d(2)
def forward(self, x):
out = self.m(x)
return out
x = torch.ones(2, 2, 2, 2, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_batchnorm_1d(self):
class SimpleOp(nn.Module):
def __init__(self):
super().__init__()
self.m = nn.BatchNorm1d(2)
def forward(self, x):
out = self.m(x)
return out
x = torch.ones(2, 2, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_conv(self):
class SimpleOp(nn.Module):
def __init__(self):
super().__init__()
self.m = nn.Conv2d(16, 13, 3, bias=False)
def forward(self, x):
out = self.m(x)
return out
x = torch.ones(20, 16, 50, 40, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_conv_onnx_irv4_opset8(self):
# This test point checks that for opset 8 (or lower), even if
# keep_initializers_as_inputs is set to False, it is ignored,
# and initializers are listed as ONNX graph inputs, in accordance
# with ONNX IR v3 semantics (which apply to opset version <= 8).
class SimpleOp(nn.Module):
def __init__(self):
super().__init__()
self.m = nn.Conv2d(2, 4, 3, bias=False)
self.m.weight.data.fill_(1.0)
def forward(self, x):
out = self.m(x)
return out
x = torch.ones(1, 2, 5, 7, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_convtranspose(self):
class SimpleOp(nn.Module):
def __init__(self):
super().__init__()
self.m = nn.ConvTranspose2d(3, 3, 3, stride=3, bias=False,
padding=1, output_padding=2)
def forward(self, x):
out = self.m(x)
return out
x = torch.ones(2, 3, 4, 5, requires_grad=True)
self.checkExportImport(SimpleOp(), (x,))
def test_basic_maxpool(self):
class SimpleOp(nn.Module):
def __init__(self):
super().__init__()
self.m = nn.MaxPool1d(3, stride=2)
def forward(self, x):
out = self.m(x)
return out
x = torch.randn(20, 16, 50)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_maxpool_dilations(self):
class SimpleOp(nn.Module):
def __init__(self):
super().__init__()
self.m = nn.MaxPool1d(2, stride=1, dilation=2)
def forward(self, x):
out = self.m(x)
return out
x = torch.randn(20, 16, 50)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_avg_pool2d(self):
class SimpleOp(nn.Module):
def __init__(self):
super().__init__()
self.m = nn.AvgPool2d(3, stride=2)
def forward(self, x):
out = self.m(x)
return out
x = torch.randn(20, 16, 50, 32)
self.checkExportImport(SimpleOp(), (x, ))
@unittest.skip('jit error: "Return value was annotated as having type Tensor but is actually of type Tuple[Tensor, Tensor]"')
def test_basic_maxpool_indices(self):
class SimpleOp(nn.Module):
def __init__(self):
super().__init__()
self.m = nn.MaxPool1d(3, stride=2, return_indices=True)
def forward(self, x):
out = self.m(x)
return out
x = torch.randn(20, 16, 50)
self.checkExportImport(SimpleOp(), (x, ))
@unittest.skip("jit error: Tried to access nonexistent attribute or method 'at' of type '__torch__.test_convert_operators.MyFun'")
def test_at_op(self):
from torch.autograd import Function
x = torch.randn(3, 4)
class MyFun(Function):
@staticmethod
def symbolic(g, x):
return g.at("add", x, x)
@staticmethod
def forward(ctx, x):
return x + x
class MyModule(nn.Module):
def forward(self, x):
return MyFun.apply(x)
self.checkExportImport(MyModule(), x)
def test_basic_logsoftmax(self):
class SimpleOp(nn.Module):
def __init__(self):
super().__init__()
self.m = nn.LogSoftmax(dim=3)
def forward(self, x):
out = self.m(x)
return out
x = torch.randn(1, 2, 3, 4, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_elu(self):
class SimpleOp(nn.Module):
def __init__(self):
super().__init__()
self.m = nn.ELU()
def forward(self, x):
out = self.m(x)
return out
x = torch.randn(1, 2, 3, 4, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_selu(self):
class SimpleOp(nn.Module):
def __init__(self):
super().__init__()
self.m = nn.SELU()
def forward(self, x):
out = self.m(x)
return out
x = torch.randn(1, 2, 3, 4, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_upsample_nearest_scale(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = torch.nn.functional.interpolate(x, scale_factor=2.,
mode='nearest', recompute_scale_factor=False)
return out
x = torch.randn(1, 2, 3, 4, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_upsample_nearest_scale_default_scale_factor(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = torch.nn.functional.interpolate(x, scale_factor=2.,
mode='nearest')
return out
x = torch.randn(1, 2, 3, 4, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_batchnorm_noaffine(self):
class SimpleOp(nn.Module):
def __init__(self):
super().__init__()
self.m = nn.BatchNorm2d(128, affine=False, momentum=0.3)
def forward(self, x):
out = self.m(x)
return out
x = torch.randn(128, 128, 1, 1, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
def test_embedding_bags(self):
class SimpleOp(nn.Module):
def __init__(self):
super().__init__()
self.m = nn.EmbeddingBag(10, 8)
def forward(self, x, y):
out = self.m(x, y)
return out
input = torch.tensor([1, 2, 3, 4]).long()
offset = torch.tensor([0]).long()
self.checkExportImport(SimpleOp(), (input, offset, ))
def test_basic_rrelu(self):
class SimpleOp(nn.Module):
def __init__(self):
super().__init__()
self.m = nn.RReLU()
def forward(self, x):
out = self.m(x)
return out
x = torch.randn(1, 2, 3, 4)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_prelu(self):
class SimpleOp(nn.Module):
def __init__(self):
super().__init__()
self.m = nn.PReLU(2)
def forward(self, x):
out = self.m(x)
return out
x = torch.randn(1, 2, 3, 4)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_log_sigmoid(self):
class SimpleOp(nn.Module):
def __init__(self):
super().__init__()
self.m = nn.LogSigmoid()
def forward(self, x):
out = self.m(x)
return out
x = torch.randn(1, 2, 3, 4)
self.checkExportImport(SimpleOp(), (x, ))
def test_basic_linear(self):
class SimpleOp(nn.Module):
def __init__(self):
super().__init__()
self.m = nn.Linear(4, 5, bias=True)
def forward(self, x):
out = self.m(x)
return out
x = torch.randn(3, 4)
self.checkExportImport(SimpleOp(), (x, ))
def test_retain_param_name_disabled(self):
class MyModule(nn.Module):
def __init__(self):
super(MyModule, self).__init__()
self.fc1 = nn.Linear(4, 5, bias=False)
self.fc1.weight.data.fill_(2.)
self.fc2 = nn.Linear(5, 6, bias=False)
self.fc2.weight.data.fill_(3.)
def forward(self, x):
return self.fc2(self.fc1(x))
x = torch.randn(3, 4).float()
self.checkExportImport(MyModule(), (x, ))
@unittest.skip('Segmentation fault')
def test_dict(self):
class MyModel(nn.Module):
def forward(self, x_in: Dict):
x_out = {}
x_out["test_key_out"] = torch.add(x_in[list(x_in.keys())[0]], list(x_in.keys())[0])
return x_out
x = {torch.tensor(1.): torch.randn(1, 2, 3)}
self.checkExportImport(MyModel(), (x, ))
def test_arange_dynamic(self):
class TestModel(nn.Module):
def forward(self, input):
out = torch.arange(input.shape[0], input.shape[0] + 5, 0.5)
return out
input = torch.randn(5, 3, 2)
self.checkExportImport(TestModel(), (input, ))
def test_bitshift(self):
class BitshiftModel(nn.Module):
def forward(self, input, input2):
return input >> 1, input2 >> 2
input = torch.arange(24, dtype=torch.float32).reshape(3, 4, 2)
input2 = torch.arange(24, dtype=torch.uint8).reshape(3, 4, 2)
self.checkExportImport(BitshiftModel(), (input, input2, ))
def test_layer_norm_aten(self):
class SimpleOp(nn.Module):
def __init__(self):
super().__init__()
self.m = nn.LayerNorm([10, 10])
def forward(self, x):
out = self.m(x)
return out
x = torch.randn(20, 5, 10, 10)
self.checkExportImport(SimpleOp(), (x, ))
\ No newline at end of file
'''
The tests in this file are copied and transformed from
https://github.com/pytorch/pytorch/blob/master/test/onnx/test_pytorch_onnx_onnxruntime.py
'''
import os
import sys
import unittest
from typing import (Dict)
import numpy as np
import torch
import torch.nn.functional as F
import torchvision
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import blackbox_module
from nni.retiarii.converter import convert_to_graph
from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.utils import get_records
class TestPytorch(unittest.TestCase):
@staticmethod
def _match_state_dict(current_values, expected_format):
result = {}
for k, v in expected_format.items():
for idx, cv in enumerate(current_values):
if cv.shape == v.shape:
result[k] = cv
current_values.pop(idx)
break
return result
def run_test(self, model, input, check_value=True):
script_module = torch.jit.script(model)
model_ir = convert_to_graph(script_module, model)
model_code = model_to_pytorch_script(model_ir)
print(model_code)
from .inject_nn import remove_inject_pytorch_nn
remove_inject_pytorch_nn()
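# NOTE: undo any torch.nn patching done by inject_pytorch_nn (used by some tests here) so the generated code runs against the original torch.nn modules.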
exec_vars = {}
exec(model_code + '\n\nconverted_model = _model()', exec_vars)
converted_model = exec_vars['converted_model']
converted_state_dict = self._match_state_dict(list(model.state_dict().values()),
dict(converted_model.state_dict()))
converted_model.load_state_dict(converted_state_dict)
with torch.no_grad():
expected_output = model.eval()(*input)
converted_output = converted_model.eval()(*input)
if check_value:
try:
self.assertEqual(len(converted_output), len(expected_output))
for a, b in zip(converted_output, expected_output):
torch.eq(a, b)
except:
self.assertEqual(converted_output, expected_output)
return converted_model
def test_embedding_model_with_external_data(self):
class LargeModel(nn.Module):
def __init__(self):
super(LargeModel, self).__init__()
dim = 15
n = 4 * 100
self.emb = nn.Embedding(n, dim)
self.lin1 = nn.Linear(dim, 1)
self.seq = nn.Sequential(
self.emb,
self.lin1,
)
def forward(self, input):
return self.seq(input)
model = LargeModel()
x = torch.tensor([2], dtype=torch.long)
self.run_test(model, (x, ))
@unittest.skip('skip for now, as it needs inject_nn')
def test_mobilenet_v2_with_external_data(self):
model = torchvision.models.mobilenet_v2(pretrained=True)
x = torch.randn(2, 3, 224, 224, requires_grad=True)
# We turn ONNX Runtime optimization off in this test,
# because the external data format is not supported by the ORT optimizer.
# Once that support is added, we can set ort_optim_on=True (default).
self.run_test(model, (x, ))
def test_attribute_with_external_data(self):
class LargeModel(nn.Module):
def forward(self, x):
return x + torch.ones(2, 1024)
x = torch.randn(2, 1)
self.run_test(LargeModel(), (x, ))
@unittest.skip('skip as it has loop')
def test_subgraph_with_external_data(self):
class LargeModel(nn.Module):
def forward(self, x):
for i in range(x.size(0)):
x = x + torch.ones(2, 1024)
return x
x = torch.randn(2, 1)
self.run_test((LargeModel()), (x, ))
def test_fuse_conv_bn1d(self):
class Fuse(nn.Module):
def __init__(self):
super(Fuse, self).__init__()
self.conv = nn.Conv1d(16, 33, 3, stride=2)
self.bn = nn.BatchNorm1d(33)
def forward(self, x):
out = self.conv(x)
return self.bn(out)
model = Fuse()
x = torch.randn(20, 16, 50, requires_grad=True)
self.run_test(model, (x,))
def test_fuse_conv_bn2d(self):
class Fuse(nn.Module):
def __init__(self):
super(Fuse, self).__init__()
self.conv = nn.Conv2d(3, 2, kernel_size=1, stride=2, padding=3, bias=False)
self.bn = nn.BatchNorm2d(2)
def forward(self, x):
out = self.conv(x)
return self.bn(out)
model = Fuse()
x = torch.randn(2, 3, 2, 2, requires_grad=True)
self.run_test(model, (x,))
def test_fuse_conv_bn3d(self):
class Fuse(nn.Module):
def __init__(self):
super(Fuse, self).__init__()
self.conv = nn.Conv3d(3, 2, (3, 5, 2), stride=(2, 1, 1), padding=(3, 2, 0), bias=False)
self.bn = nn.BatchNorm3d(2)
def forward(self, x):
out = self.conv(x)
return self.bn(out)
model = Fuse()
x = torch.randn(2, 3, 10, 50, 100, requires_grad=True)
self.run_test(model, (x,))
@unittest.skip('have not supported register_buffer yet')
def test_reshape_constant_fold(self):
class Reshape(nn.Module):
def __init__(self, ):
super(Reshape, self).__init__()
self.register_buffer("weight", torch.ones(5))
def forward(self, x):
scale_1 = self.weight.reshape(1, -1, 1, 1)
return x * scale_1
x = torch.randn(4, 5)
self.run_test(Reshape(), (x,))
def run_word_language_model(self, model_name):
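# Build a small RNNModel of the given type and convert it together with its initial hidden state.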
ntokens = 50
emsize = 5
nhid = 5
nlayers = 5
dropout = 0.2
tied = False
batchsize = 5
model = word_language_model.RNNModel(model_name, ntokens, emsize,
nhid, nlayers, dropout, tied,
batchsize)
x = torch.arange(0, ntokens).long().view(-1, batchsize)
# Only the CPU version is supported, since the tracer does not work with GPU RNNs.
self.run_test(model, (x, model.hidden))
def get_image_from_url(self, url, size=(300, 200)):
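# Download an image, cache it next to this test file, resize it, and return it as a CHW float tensor.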
import os
from urllib.parse import urlsplit
from urllib import request
from PIL import Image
from torchvision import transforms
from torch._utils_internal import get_writable_path
filename = os.path.basename(urlsplit(url)[2])
data_dir = get_writable_path(os.path.join(os.path.dirname(__file__)))
path = os.path.join(data_dir, filename)
data = request.urlopen(url, timeout=15).read()
with open(path, 'wb') as f:
f.write(data)
image = Image.open(path).convert("RGB")
image = image.resize(size, Image.BILINEAR)
to_tensor = transforms.ToTensor()
return to_tensor(image)
def get_test_images(self):
image_url = "http://farm3.staticflickr.com/2469/3915380994_2e611b1779_z.jpg"
image = self.get_image_from_url(url=image_url, size=(100, 320))
image_url2 = "https://pytorch.org/tutorials/_static/img/tv_tutorial/tv_image05.png"
image2 = self.get_image_from_url(url=image_url2, size=(250, 380))
return [image], [image2]
@unittest.skip('does not support `if A and/or B`')
def test_faster_rcnn(self):
from .inject_nn import inject_pytorch_nn
inject_pytorch_nn()
model = torchvision.models.detection.faster_rcnn.fasterrcnn_resnet50_fpn(pretrained=True, min_size=200,
max_size=300)
model.eval()
x = torch.randn(2, 3, 200, 300, requires_grad=True)
self.run_test(model, (x,))
dummy_image = [torch.ones(3, 100, 100) * 0.3]
images, test_images = self.get_test_images()
self.run_test(model, (images,))
self.run_test(model, (dummy_image,))
@unittest.skip('does not support `if A and/or B`')
def test_mask_rcnn(self):
from .inject_nn import inject_pytorch_nn
inject_pytorch_nn()
model = torchvision.models.detection.mask_rcnn.maskrcnn_resnet50_fpn(pretrained=True, min_size=200,
max_size=300)
images, test_images = self.get_test_images()
self.run_test(model, (images,))
dummy_image = [torch.ones(3, 100, 100) * 0.3]
self.run_test(model, (dummy_image,))
@unittest.skip('does not support `if A and/or B`')
def test_keypoint_rcnn(self):
from .inject_nn import inject_pytorch_nn
inject_pytorch_nn()
model = torchvision.models.detection.keypoint_rcnn.keypointrcnn_resnet50_fpn(pretrained=True, min_size=200,
max_size=300)
images, test_images = self.get_test_images()
self.run_test(model, (images,))
dummy_images = [torch.ones(3, 100, 100) * 0.3]
self.run_test(model, (dummy_images,))
def test_shufflenet_v2_dynamic_axes(self):
from .inject_nn import inject_pytorch_nn
inject_pytorch_nn()
model = torchvision.models.shufflenet_v2_x0_5(pretrained=True)
dummy_input = torch.randn(1, 3, 224, 224, requires_grad=True)
test_inputs = torch.randn(3, 3, 224, 224, requires_grad=True)
self.run_test(model, (dummy_input,))
@unittest.skip('')
def test_word_language_model_RNN_TANH(self):
self.run_word_language_model("RNN_TANH")
@unittest.skip('')
def test_word_language_model_RNN_RELU(self):
self.run_word_language_model("RNN_RELU")
@unittest.skip('')
def test_word_language_model_LSTM(self):
self.run_word_language_model("LSTM")
@unittest.skip('')
def test_word_language_model_GRU(self):
self.run_word_language_model("GRU")
def test_index_1d(self):
class MyModel(nn.Module):
def forward(self, input):
return input[0]
m1 = torch.randn(3, 4, 5, 6, 7)
self.run_test(MyModel(), (m1, ))
def test_index_2d_1dimslice(self):
class MyModel(nn.Module):
def forward(self, input):
return input[0:1, :]
m1 = torch.randn(3, 4, 5, 6, 7)
self.run_test(MyModel(), (m1, ))
def test_index_2d_sliceint(self):
class MyModel(nn.Module):
def forward(self, input):
return input[1, :]
m1 = torch.randn(3, 4, 5, 6, 7)
self.run_test(MyModel(), (m1, ))
def test_index_2d_neg_slice(self):
class MyModel(nn.Module):
def forward(self, input):
return input[0:-1, :]
m1 = torch.randn(3, 4, 5, 6, 7)
self.run_test(MyModel(), (m1, ))
def test_index_mask(self):
class MyModel(nn.Module):
def forward(self, input):
return input[torch.tensor([0, 1, 0], dtype=torch.uint8)]
m1 = torch.randn(3, 4, 5, 6, 7)
self.run_test(MyModel(), (m1, ))
class MyModel(nn.Module):
def forward(self, input):
return input[torch.tensor([0, 1, 0], dtype=torch.bool)]
m1 = torch.randn(3, 4, 5, 6, 7)
self.run_test(MyModel(), (m1, ))
def test_data(self):
class Data(nn.Module):
def forward(self, x):
return x.new_zeros(x.data.size())
x = torch.randn(3, 4)
self.run_test(Data(), (x, ))
def test_index_mask_nd(self):
class MyModel(nn.Module):
def forward(self, input):
return input[input > 0]
m1 = torch.randn(3, 4, 5, 6, 7)
self.run_test(MyModel(), (m1, ))
@unittest.skip("Tried to access nonexistent attribute or method 'keys' of type 'Tensor (inferred)'.")
def test_dict(self):
class MyModel(nn.Module):
def forward(self, x_in):
x_out = {}
x_out["test_key_out"] = torch.add(x_in[list(x_in.keys())[0]], list(x_in.keys())[0])
return x_out
x = {torch.tensor(1.): torch.randn(1, 2, 3)}
self.run_test(MyModel(), (x, {}))
@unittest.skip("Unsupported operation: indexing tensor with unsupported index type 'str'.")
def test_dict_str(self):
class MyModel(nn.Module):
def forward(self, x_in):
x_out = {}
x_out["test_key_out"] = torch.add(x_in["test_key_in"], 2.)
return x_out
x = {"test_key_in": torch.randn(1, 2, 3)}
self.run_test(MyModel(), (x, {}))
@unittest.skip('Convert graph error')
def test_optional_inputs_with_no_optionals(self):
class NoOptionalModel(nn.Module):
def forward(self, input):
return input
# Without empty optional arguments dictionary
x = torch.randn(2, 3)
self.run_test(NoOptionalModel(), (x,))
# With empty optional arguments dictionary
y = torch.randn(2, 3)
self.run_test(NoOptionalModel(), (y, {}))
# NOTE: torch script gets an incorrect graph...
def test_optional_inputs_with_mixed_optionals(self):
class MixedModel(nn.Module):
def forward(self, x: 'Tensor', y: 'Tensor', z: 'Tensor'):
if y is not None:
return x + y
if z is not None:
return x + z
return x
x = torch.randn(2, 3)
y = torch.randn(2, 3)
z = torch.randn(2, 3)
# Without optional arguments dictionary
self.run_test(MixedModel(), (x, y, None))
#self.run_test(MixedModel(), (x, None, z, ))
# With optional arguments dictionary
#self.run_test(MixedModel(), (x, {'y': y, 'z': None}))
#self.run_test(MixedModel(), (x, {'y': None, 'z': z}))
#self.run_test(MixedModel(), (x, {'z': z}))
#self.run_test(MixedModel(), (x, {'y': y}))
@unittest.skip('torch script gets an incorrect graph...')
def test_optional_inputs_with_all_optionals(self):
class AllOptionalModel(nn.Module):
def forward(self, y, z):
if y is not None:
return y
if z is not None:
return z
y = torch.randn(2, 3)
# Without optional arguments dictionary
self.run_test(AllOptionalModel(), (y, None))
# With optional arguments dictionary
#self.run_test(AllOptionalModel(), {'y': y, 'z': None})
@unittest.skip('torch script gets an incorrect graph...')
def test_none_as_input(self):
class Model(nn.Module):
def forward(self, x, y):
if y is not None:
return x + y
return x
x = torch.randn(2, 3)
self.run_test(Model(), (x, None))
@unittest.skip('jit cannot correctly deal with tuple as input argument')
def test_none_as_tuple_input(self):
class Model(nn.Module):
def forward(self, x, y):
if y[0] is not None:
return x + y[0]
if y[1] is not None:
return x + y[1]
return x
x = torch.randn(2, 3)
y = torch.randn(2, 3)
self.run_test(Model(), (x, (None, y)))
def test_cste_script(self):
class MyModel(nn.Module):
def forward(self, x):
return torch.zeros(x.size(0)), torch.ones((x.size(1), x.size(0)), dtype=torch.int64)
x = torch.randn(3, 4)
self.run_test(MyModel(), (x, ))
def test_scalar_tensor(self):
class test(nn.Module):
def forward(self, input):
return torch.scalar_tensor(input.size(0)), \
torch.scalar_tensor(input.size(1), dtype=torch.int64)
x = torch.randn(2, 3, 4)
y = torch.randn(7, 8, 9)
model = test()
self.run_test(model, (x, ))
def test_tensor(self):
class ScalarInputModel(nn.Module):
def forward(self, input):
return torch.tensor(input.shape[1])
x = torch.randn(3, 4)
self.run_test(ScalarInputModel(), (x, ))
class TensorInputModel(nn.Module):
def forward(self, input):
return torch.tensor([input.shape[0], input.shape[1]])
x = torch.randn(3, 4)
self.run_test(TensorInputModel(), (x, ))
class FloatInputModel(nn.Module):
def forward(self, input):
return torch.tensor([float(input)])
x = torch.randn(1)
self.run_test(FloatInputModel(), (x, ))
class InputWithDtypeModel(nn.Module):
def forward(self, input):
return torch.tensor(input.shape[1], dtype=torch.long)
x = torch.randn(3, 4)
self.run_test(InputWithDtypeModel(), (x, ))
class MixedInputModel(nn.Module):
def forward(self, input):
return torch.tensor([input.shape[0], int(input)])
x = torch.randn(1)
self.run_test(MixedInputModel(), (x, ))
def test_hardtanh(self):
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.m = nn.Hardtanh(-1.5, 2.5)
def forward(self, x):
return self.m(x)
x = torch.arange(-5, 5).to(dtype=torch.float32)
self.run_test(MyModel(), (x, ))
def test_hardtanh_script_with_default_values(self):
class MyModel(nn.Module):
def forward(self, x):
return F.hardtanh(x)
x = torch.arange(-5, 5).to(dtype=torch.float32)
self.run_test(MyModel(), (x, ))
def test_hardswish(self):
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.m = nn.Hardswish()
def forward(self, x):
return self.m(x)
x = torch.rand(3, 3).to(dtype=torch.float32)
self.run_test(MyModel(), (x, ))
# Testing edge cases
x = torch.tensor(3).to(dtype=torch.float32)
self.run_test(MyModel(), (x, ))
x = torch.tensor(-3).to(dtype=torch.float32)
self.run_test(MyModel(), (x, ))
def test_hardswish_script(self):
class MyModel(nn.Module):
def forward(self, x):
return F.hardswish(x)
x = torch.rand(3, 3).to(dtype=torch.float32)
self.run_test(MyModel(), (x, ))
def test_clamp(self):
class ClampModel(nn.Module):
def forward(self, x):
return x.clamp(-0.5, 0.5)
x = torch.randn(3, 4)
self.run_test(ClampModel(), (x, ))
class ClampMinModel(nn.Module):
def forward(self, x):
return x.clamp(min=-0.5)
x = torch.randn(3, 4)
self.run_test(ClampMinModel(), (x, ))
class ClampMaxModel(nn.Module):
def forward(self, x):
return x.clamp(max=0.5)
x = torch.randn(3, 4)
self.run_test(ClampMaxModel(), (x, ))
def test_clamp_dyn(self):
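# Clamp bounds are taken from runtime tensor sizes rather than constants.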
class ClampMaxModel(nn.Module):
def forward(self, x):
return x.clamp(None, x.size(0))
x = torch.arange(16).view(4, 4).float()
self.run_test(ClampMaxModel(), (x, ))
class ClampMinModel(nn.Module):
def forward(self, x):
return x.clamp(x.size(0), None)
x = torch.arange(16).view(4, 4).float()
self.run_test(ClampMinModel(), (x, ))
class ClampMinMaxModel(nn.Module):
def forward(self, x):
return x.clamp(x.size(0), x.size(1))
x = torch.arange(16).view(2, 8).float()
self.run_test(ClampMinMaxModel(), (x, ))
def test_full_trace(self):
class FullModel(nn.Module):
def forward(self, x):
return torch.full((3, 4), x, dtype=torch.long)
x = torch.tensor(12)
self.run_test(FullModel(), (x, ))
def test_full_script(self):
class FullModelScripting(nn.Module):
def forward(self, x):
return torch.full((3, 4), x, dtype=torch.long)
x = torch.tensor(12)
self.run_test(FullModelScripting(), (x, ))
def test_fuse_addmm(self):
class AddmmModel(nn.Module):
def forward(self, x):
return torch.mm(x, x) + x
x = torch.ones(3, 3)
self.run_test(AddmmModel(), (x, ))
def test_maxpool(self):
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.m = nn.MaxPool1d(2, stride=1)
def forward(self, x):
return self.m(x)
x = torch.randn(20, 16, 50)
self.run_test(MyModel(), (x, ))
def test_conv(self):
class TraceModel(nn.Module):
def __init__(self):
super(TraceModel, self).__init__()
self.conv1 = nn.Conv1d(16, 33, 3, stride=2)
self.conv2 = nn.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
self.conv3 = nn.Conv3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(4, 2, 0))
def forward(self, input1, input2, input3):
return self.conv1(input1), self.conv2(input2), self.conv3(input3)
x1 = torch.randn(20, 16, 50)
x2 = torch.randn(20, 16, 50, 100)
x3 = torch.randn(20, 16, 10, 50, 100)
self.run_test(TraceModel(), (x1, x2, x3, ))
def test_conv_shape_inference(self):
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv2 = nn.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
def forward(self, input):
return self.conv2(input) + 2
x = torch.randn(20, 16, 50, 100)
self.run_test(Model(), (x, ))
def test_conv_transpose(self):
class TraceModel(nn.Module):
def __init__(self):
super(TraceModel, self).__init__()
self.conv1 = nn.ConvTranspose1d(16, 33, 3, stride=2)
self.conv2 = nn.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
self.conv3 = nn.ConvTranspose3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(4, 2, 0))
def forward(self, input1, input2, input3):
return self.conv1(input1), self.conv2(input2), self.conv3(input3)
x1 = torch.randn(20, 16, 50)
x2 = torch.randn(20, 16, 50, 100)
x3 = torch.randn(20, 16, 10, 50, 100)
self.run_test(TraceModel(), (x1, x2, x3, ))
# Conversion of Transpose depends on the input shape being known.
# The following test only works when ONNX shape inference is enabled.
def test_transpose_infer_shape(self):
class TransposeModule(nn.Module):
def __init__(self):
super(TransposeModule, self).__init__()
self.conv = nn.Conv2d(3, 1, 3, stride=2)
def forward(self, x):
x = self.conv(x)
return x.transpose(0, 1)
x = torch.randn(32, 3, 64, 64)
y = torch.randn(16, 3, 8, 64)
self.run_test(TransposeModule(), (x, ))
def squeeze_model_tests(self, d, x1):
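# Shared helper: squeeze along dim d, or along all dims when d is None.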
class Squeeze(nn.Module):
def __init__(self, d):
super(Squeeze, self).__init__()
self.d = d
def forward(self, x):
if self.d is not None:
return torch.squeeze(x, dim=self.d)
else:
return torch.squeeze(x)
self.run_test(Squeeze(d), (x1, ))
def test_squeeze_without_no_op(self):
x = torch.randn(2, 1, 4)
self.squeeze_model_tests(1, x)
def test_squeeze_neg_without_no_op(self):
x = torch.randn(2, 1, 4)
self.squeeze_model_tests(-2, x)
def test_squeeze_all_dims(self):
x_squeeze = torch.randn(2, 1, 4)
self.squeeze_model_tests(None, x_squeeze)
def test_squeeze_no_op(self):
x_noop = torch.randn(2, 1, 4)
self.squeeze_model_tests(2, x_noop)
def test_squeeze_runtime_dim(self):
class Squeeze(nn.Module):
def forward(self, d1, d2):
t = torch.zeros(d1[0], d2[0])
return t.squeeze(0)
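# d1 = [1] makes squeeze(0) drop the first dim; d3 = [3] makes it a no-op.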
d1 = torch.tensor([1])
d3 = torch.tensor([3])
d4 = torch.tensor([4])
self.run_test(Squeeze(), (d1, d4))
self.run_test(Squeeze(), (d3, d4))
def test_squeeze(self):
class Squeeze(nn.Module):
def forward(self, x):
return torch.squeeze(x, dim=-2)
x = torch.randn(2, 1, 4)
self.run_test(Squeeze(), (x, ))
def test_unsqueeze(self):
class Unsqueeze(nn.Module):
def forward(self, x):
return torch.unsqueeze(x, dim=-2)
x = torch.randn(2, 3, 4)
self.run_test(Unsqueeze(), (x, ))
def test_maxpool_default_stride(self):
class MaxPoolModel(nn.Module):
def forward(self, x):
return F.max_pool2d(x, 2)
model = MaxPoolModel()
x = torch.randn(10, 20, 16, 50)
self.run_test(model, (x, ))
def test_maxpool_adaptive(self):
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.m = nn.AdaptiveMaxPool1d((5), return_indices=False)
def forward(self, x):
return self.m(x)
x = torch.randn(20, 16, 50, requires_grad=True)
self.run_test(MyModel(), (x, ))
def test_maxpool_2d(self):
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.m = nn.MaxPool2d(5, padding=(1, 2))
def forward(self, x):
return self.m(x)
x = torch.randn(1, 20, 16, 50, requires_grad=True)
self.run_test(MyModel(), (x, ))
def test_maxpool_1d_ceil(self):
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.m = nn.MaxPool1d(3, 2, ceil_mode=True)
def forward(self, x):
return self.m(x)
x = torch.randn(20, 16, 50)
self.run_test(MyModel(), (x, ))
def test_maxpool_2d_ceil(self):
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.m = nn.MaxPool2d(3, 2, ceil_mode=True)
def forward(self, x):
return self.m(x)
x = torch.randn(20, 16, 50, 32)
self.run_test(MyModel(), (x, ))
def test_maxpool_3d_ceil(self):
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.m = nn.MaxPool3d(3, 2, ceil_mode=True)
def forward(self, x):
return self.m(x)
x = torch.randn(20, 16, 50, 44, 31)
self.run_test(MyModel(), (x, ))
@unittest.skip('jit error: Return value was annotated as having type Tensor but is actually of type Tuple[Tensor, Tensor]')
def test_maxpool_with_indices(self):
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.m = nn.MaxPool1d(2, stride=1, return_indices=True)
def forward(self, x):
return self.m(x)
x = torch.randn(20, 16, 50)
self.run_test(MyModel(), (x, ))
def test_maxpool_dilation(self):
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.m = nn.MaxPool1d(2, stride=1, dilation=2)
def forward(self, x):
return self.m(x)
x = torch.randn(20, 16, 50)
self.run_test(MyModel(), (x, ))
def test_avgpool_default_stride(self):
class AvgPoolModel(nn.Module):
def forward(self, x):
return F.avg_pool2d(x, 2)
model = AvgPoolModel()
x = torch.randn(10, 20, 16, 50)
self.run_test(model, (x, ))
def test_avgpool(self):
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.m = nn.AvgPool1d(2, stride=1)
def forward(self, x):
return self.m(x)
x = torch.randn(20, 16, 50)
self.run_test(MyModel(), (x, ))
def test_avgpool_1d_ceil(self):
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.m = nn.AvgPool1d(3, 2, ceil_mode=True)
def forward(self, x):
return self.m(x)
x = torch.randn(1, 1, 7)
self.run_test(MyModel(), (x, ))
def test_avgpool_2d_ceil(self):
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.m = nn.AvgPool2d(3, 2, ceil_mode=True)
def forward(self, x):
return self.m(x)
x = torch.randn(20, 16, 50, 32)
self.run_test(MyModel(), (x, ))
def test_avgpool_3d_ceil(self):
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.m = nn.AvgPool3d(3, 2, ceil_mode=True)
def forward(self, x):
return self.m(x)
x = torch.randn(20, 16, 50, 44, 31)
self.run_test(MyModel(), (x, ))
@unittest.skip('Unsupported op type aten::is_floating_point in if condition')
def test_floating_point(self):
class FloatingPoint(nn.Module):
def forward(self, x):
if x.is_floating_point():
return x.new_zeros(x.shape)
return x.new_zeros(x.shape)
x = torch.randn(2, 3, 4)
self.run_test(FloatingPoint(), (x, ))
class FloatingPoint(nn.Module):
def forward(self, x):
if x.size(0) > 1:
a = x + 2
if a.is_floating_point():
return x + 1
return x + 1
return x
x = torch.randn(2, 3, 4)
self.run_test(FloatingPoint(), (x, ))
# Operator rank mismatch between outputs of two branches for opsets below 11.
@unittest.skip('Unsupported op type aten::size in if condition')
def test_floating_point_infer_dtype(self):
class FloatingPoint(nn.Module):
def forward(self, x):
if x.size(0) > 1:
a = x + 2
if a.is_floating_point():
return x.new_zeros(x.shape[1:])
return x.new_zeros(x.shape)
return x
x = torch.randn(2, 3, 4)
self.run_test(FloatingPoint(), (x, ))
class FloatingPoint(nn.Module):
def forward(self, x):
if x.size(0) > 1:
a = x + 2
if a.is_floating_point():
return x + 1
return x
return x
x = torch.randn(2, 3, 4).to(torch.int32)
self.run_test(FloatingPoint(), (x, ))
def test_arithmetic(self):
class ArithmeticModule(nn.Module):
def forward(self, x):
x = x + 2
x = x - 4
x = x * 6
x = x / 8
return x
x = torch.randn(2, 3, 4)
self.run_test(ArithmeticModule(), (x, ))
# In scripting, the first transpose node does not carry shape and dtype info.
# The following test only works when ONNX shape inference is enabled.
def test_arithmetic_infer_dtype(self):
class ArithmeticModule(nn.Module):
def forward(self, x):
x = x.t()
x = x + 2
x = x - 4
x = x * 6
x = x / 8
return x
x = torch.randn(2, 3)
self.run_test(ArithmeticModule(), (x, ))
@unittest.skip('tensor op type aten::to has more than one matched')
def test_floor_div(self):
class FloorDivModule(nn.Module):
def forward(self, x, y):
return x // 3, x // 2., \
x.to(dtype=torch.float64) // 3, x.to(dtype=torch.float64) // 2., \
x.to(dtype=torch.int64) // 3, x.to(dtype=torch.int64) // 2., \
x // (y + 1.).to(dtype=torch.int64), x // y, \
x.to(dtype=torch.float64) // y.to(dtype=torch.int64), x.to(dtype=torch.float64) // y.to(dtype=torch.float64), \
x.to(dtype=torch.int64) // y.to(dtype=torch.int64), x.to(dtype=torch.int64) // y
x = torch.randn(2, 3, 4)
y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4)
self.run_test(FloorDivModule(), (x, y))
def test_floor_div_script(self):
class FloorDivModule(nn.Module):
def forward(self, x, y):
return x // 3, x // 2., x // y
x = torch.randn(2, 3, 4)
y = torch.randn(2, 3, 4)
self.run_test(FloorDivModule(), (x, y))
def test_floordiv(self):
class FloordivModule(nn.Module):
def forward(self, x):
return x.new_zeros(x.size(2) // x.size(1))
x = torch.randn(2, 3, 4)
self.run_test(FloordivModule(), (x,))
def test_div(self):
class DivModule(nn.Module):
def forward(self, x, y):
return torch.true_divide(x, y)
x = torch.randn(2, 3, 4).to(torch.int)
y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.int)
self.run_test(DivModule(), (x, y))
self.run_test(DivModule(), (x.float(), y.float()))
# Note: div cannot (generally) be exported via scripting
# since its type promotion logic is dependent on knowing the scalar types
# of the input tensors. That is, the ONNX graph is dependent on the
# data type of the inputs. This makes it appropriate for tracing only.
def test_div_promotion_trace(self):
class DivModule(nn.Module):
def forward(self, x, y):
return torch.true_divide(x, y)
x = torch.randn(2, 3, 4).to(torch.int)
y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.int)
prev_default = torch.get_default_dtype()
torch.set_default_dtype(torch.float)
self.run_test(DivModule(), (x, y))
torch.set_default_dtype(torch.double)
self.run_test(DivModule(), (x, y))
torch.set_default_dtype(prev_default)
# In scripting, x and y do not carry shape and dtype info.
# The following test only works when ONNX shape inference is enabled.
def test_div_promotion_script(self):
class DivModule(nn.Module):
def forward(self, x, y):
# Add a transpose to hide shape/type information.
# Otherwise shape and type are still available from the input.
x = x.transpose(1, 2)
y = y.transpose(1, 2)
return torch.true_divide(x, y)
x = torch.randn(2, 3, 4).to(torch.int)
y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.int)
prev_default = torch.get_default_dtype()
# 1. x,y are int, and output is float.
# This can be handled by the default case, where both are cast to float.
# It works even if type of x, y are unknown.
torch.set_default_dtype(torch.float)
self.run_test((DivModule()), (x, y))
# 2. x,y are int, and output is double.
# This can be handled by the default case, where both are cast to double.
# It works even if type of x, y are unknown.
torch.set_default_dtype(torch.double)
self.run_test((DivModule()), (x, y))
# 3. x is int, y is double, and output is double.
# This can only be handled when both type of x and y are known.
torch.set_default_dtype(prev_default)
x = torch.randn(2, 3, 4).to(torch.int)
y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.double)
self.run_test((DivModule()), (x, y))
def test_slice_trace(self):
class MyModule(nn.Module):
def forward(self, x):
return x[0:1]
x = torch.randn(3)
self.run_test(MyModule(), (x, ))
def test_slice_neg(self):
class NegSlice(nn.Module):
def forward(self, x):
return x[-1:]
x = torch.randn(3, 4, 5)
self.run_test(NegSlice(), (x, ))
def test_slice_neg_large(self):
class NegSlice(nn.Module):
def forward(self, x):
return x[:, :, -3:-1, :, -1]
x = torch.randn(3, 4, 5, 6, 7)
self.run_test(NegSlice(), (x, ))
def test_slice_neg_large_negone(self):
class NegSlice(nn.Module):
def forward(self, x):
return x[:, :, :, :, -1]
x = torch.randn(3, 4, 5, 6, 7)
self.run_test(NegSlice(), (x, ))
@unittest.skip('strange torch script graph')
def test_slice_with_input_index(self):
class InputIndexSlice(nn.Module):
def forward(self, x, y):
x[:y.size(0), 0, :] = y
return x
x = torch.zeros((56, 6, 256))
y = torch.rand((22, 256))
self.run_test(InputIndexSlice(), (x, y))
@unittest.skip('Loop is not supported yet')
def test_slice_dynamic(self):
class DynamicSliceExportMod(nn.Module):
def forward(self, x):
results = []
for i in range(4):
results.append(x[:x.size(0) - i, i:x.size(2), i:3])
return results
x = torch.rand(5, 5, 5)
y = torch.randn(6, 7, 8)
self.run_test(DynamicSliceExportMod(), (x, ))
def test_slice_dynamic_script(self):
class DynamicSliceModel(nn.Module):
def forward(self, x):
return x[1:x.size(1)]
x = torch.rand(1, 2)
self.run_test(DynamicSliceModel(), (x, ))
def test_slice_dynamic_shape_script(self):
class DynamicSliceModel(nn.Module):
def forward(self, x):
return x.new_zeros(x.shape[1:x.size(2)])
x = torch.rand(1, 2, 3, 4)
self.run_test(DynamicSliceModel(), (x, ))
@unittest.skip('Loop is not supported yet')
def test_slice_dynamic_to_end(self):
class DynamicSliceExportMod(nn.Module):
def forward(self, x):
results = []
for i in range(4):
results.append(x[:, i:, x.size(2) - 5])
return results
x = torch.rand(5, 5, 5)
self.run_test(DynamicSliceExportMod(), (x, ))
def test_square(self):
class Square(nn.Module):
def forward(self, x):
return torch.square(x)
x = torch.randn(2, 3, 4)
self.run_test(Square(), (x, ))
def test_arange_dynamic(self):
class ArangeModel(nn.Module):
def forward(self, input):
return torch.arange(input.shape[0]), \
torch.arange(12), \
torch.arange(start=input.shape[0], end=input.shape[0] + 5)
x = torch.randn(5, 3, 2)
y = torch.randn(8, 3, 2)
self.run_test(ArangeModel(), (x, ))
@unittest.skip('mismatched aten::arange definition, does not support `out`')
def test_dynamic_arange_out(self):
class ArangeOutModel(nn.Module):
def forward(self, end):
out_t = torch.tensor([1], dtype=torch.int64)
return torch.arange(end, out=out_t)
x = torch.tensor(8)
self.run_test(ArangeOutModel(), (x, ))
@unittest.skip('mismatched aten::arange definition, does not support `out`')
def test_dynamic_arange_start_out(self):
class ArangeStartOutModel(nn.Module):
def forward(self, start, end):
out_t = torch.tensor([1], dtype=torch.int64)
return torch.arange(start.size(0), end, out=out_t)
x = torch.randn(2, 3, 4)
y = torch.tensor(8)
self.run_test(ArangeStartOutModel(), (x, y))
def test_arange(self):
class ArangeModel(nn.Module):
def forward(self, start, end):
return torch.arange(start.size(0), end, 1.5, dtype=torch.int64)
x = torch.randn(2, 3, 4)
y = torch.tensor(8.5, dtype=torch.float)
self.run_test(ArangeModel(), (x, y))
@unittest.skip('mismatched aten::arange definition, does not support `out`')
def test_arange_out(self):
class ArangeOutModel(nn.Module):
def forward(self, end):
out_t = torch.tensor([1], dtype=torch.float)
return torch.arange(end, out=out_t)
x = torch.tensor(8.5, dtype=torch.float)
self.run_test(ArangeOutModel(), (x, ))
@unittest.skip('mismatched aten::arange definition, does not support `out`')
def test_arange_start_out(self):
class ArangeStartOutModel(nn.Module):
def forward(self, start, end):
out_t = torch.tensor([1], dtype=torch.float)
return torch.arange(start.size(0), end, out=out_t)
x = torch.randn(2, 3, 4)
y = torch.tensor(8.5, dtype=torch.float)
self.run_test(ArangeStartOutModel(), (x, y))
def test_arange_no_type(self):
class ArangeModel(nn.Module):
def forward(self, end):
return torch.arange(end), \
torch.arange(0, end)
x = torch.tensor(6.2, dtype=torch.float)
self.run_test(ArangeModel(), (x, ))
def test_size(self):
class SizeModel(nn.Module):
def forward(self, input):
return torch.arange(input.size(0)), torch.arange(input.size(-1)), torch.ones(input.shape)
x = torch.randn(5, 3, 2)
self.run_test(SizeModel(), (x, ))
def test_size2(self):
class SizeModel(nn.Module):
def __init__(self, a, b):
super().__init__()
self.a = a
self.b = b
def forward(self, input):
if self.a < self.b:
return torch.arange(input.size(0)), torch.arange(input.size(-1)), torch.ones(input.shape)
x = torch.randn(5, 3, 2)
self.run_test(SizeModel(10, 5), (x, ))
\ No newline at end of file
......@@ -167,6 +167,7 @@ class TestHighLevelAPI(unittest.TestCase):
mutator = mutators[0].bind_sampler(EnuemrateSampler())
model1 = mutator.apply(model)
model2 = mutator.apply(model)
self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 3, 3))
self.assertEqual(self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 3, 3)).size(), torch.Size([1, 3, 3, 3]))
self.assertAlmostEqual(self._get_converted_pytorch_model(model2)(torch.randn(1, 3, 3, 3)).abs().sum().item(), 0)
......