Unverified Commit 58d5c2fa authored by QuanluZhang's avatar QuanluZhang Committed by GitHub
Browse files

[retiarii] refactor of pytorch operators (#3365)

parent 59521d33
import logging import logging
from typing import List from typing import List, Tuple, Any
from ..graph import IllegalGraphError, Edge, Graph, Node, Model from ..graph import IllegalGraphError, Edge, Graph, Node, Model
...@@ -32,9 +32,26 @@ def _sorted_incoming_edges(node: Node) -> List[Edge]: ...@@ -32,9 +32,26 @@ def _sorted_incoming_edges(node: Node) -> List[Edge]:
raise IllegalGraphError(node.graph, 'Node {} has bad inputs'.format(node.name)) raise IllegalGraphError(node.graph, 'Node {} has bad inputs'.format(node.name))
def _format_inputs(node: Node) -> List[str]: def _format_inputs(node: Node) -> Tuple[List[str], List[Any]]:
"""
Format the inputs of a given node
Parameters
----------
node : Node
a graph node, get and format its inputs
Returns
-------
list
the list of input names
list
the list of input values, if an input is simple type, record its value,
otherwise the value is None
"""
edges = _sorted_incoming_edges(node) edges = _sorted_incoming_edges(node)
inputs = [] inputs = []
inputs_value = []
for edge in edges: for edge in edges:
if edge.head.name == '_inputs': if edge.head.name == '_inputs':
assert isinstance(edge.head_slot, int) assert isinstance(edge.head_slot, int)
...@@ -44,14 +61,21 @@ def _format_inputs(node: Node) -> List[str]: ...@@ -44,14 +61,21 @@ def _format_inputs(node: Node) -> List[str]:
else: else:
# when input has no name, e.g., forward(*_inputs) # when input has no name, e.g., forward(*_inputs)
inputs.append('_inputs[{}]'.format(edge.head_slot)) inputs.append('_inputs[{}]'.format(edge.head_slot))
inputs_value.append(None)
else: else:
if edge.head_slot is None: if edge.head_slot is None:
# when the input comes from a single-output operator # when the input comes from a single-output operator
inputs.append('{}'.format(edge.head.name)) inputs.append('{}'.format(edge.head.name))
if edge.head.operation.type in ('prim::Constant', 'prim::GetAttr') and \
'value' in edge.head.operation.parameters:
inputs_value.append(edge.head.operation.parameters['value'])
else:
inputs_value.append(None)
else: else:
# when the input comes from a multi-output operator: needs to know which one it comes from # when the input comes from a multi-output operator: needs to know which one it comes from
inputs.append('{}[{}]'.format(edge.head.name, edge.head_slot)) inputs.append('{}[{}]'.format(edge.head.name, edge.head_slot))
return inputs inputs_value.append(None)
return inputs, inputs_value
def _remove_prefix(names, graph_name): def _remove_prefix(names, graph_name):
...@@ -80,6 +104,8 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str ...@@ -80,6 +104,8 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str
node_codes = [] node_codes = []
for node in nodes: for node in nodes:
if node.operation: if node.operation:
if node.operation.type == 'shared':
continue
pkg_name = node.operation.get_import_pkg() pkg_name = node.operation.get_import_pkg()
if pkg_name is not None: if pkg_name is not None:
import_pkgs.add(pkg_name) import_pkgs.add(pkg_name)
...@@ -101,12 +127,15 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str ...@@ -101,12 +127,15 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str
sorted_nodes = graph.topo_sort() sorted_nodes = graph.topo_sort()
for node in sorted_nodes: for node in sorted_nodes:
if node.operation: if node.operation:
inputs = _format_inputs(node) inputs, inputs_value = _format_inputs(node)
inputs = _remove_prefix(inputs, graph_name) inputs = _remove_prefix(inputs, graph_name)
node_name = _remove_prefix(node.name, graph_name) node_name = _remove_prefix(node.name, graph_name)
edge_codes.append(node.operation.to_forward_code(node_name, node_name, inputs)) submodule_name = node_name
if node.operation.type == 'shared':
submodule_name = _remove_prefix(node.operation.parameters['reference'], graph_name)
edge_codes.append(node.operation.to_forward_code(submodule_name, node_name, inputs, inputs_value))
output_names = _format_inputs(graph.output_node) output_names, _ = _format_inputs(graph.output_node)
output_names = _remove_prefix(output_names, graph_name) output_names = _remove_prefix(output_names, graph_name)
if not output_names: if not output_names:
raise RuntimeError('"forward" function should have return value(s): {}, {}, {}'.format(output_names, graph_name, graph.output_node)) raise RuntimeError('"forward" function should have return value(s): {}, {}, {}'.format(output_names, graph_name, graph.output_node))
......
This diff is collapsed.
...@@ -9,34 +9,8 @@ class OpTypeName(str, Enum): ...@@ -9,34 +9,8 @@ class OpTypeName(str, Enum):
""" """
Attr = 'Attr' Attr = 'Attr'
Constant = 'Constant' Constant = 'Constant'
ListConstruct = 'ListConstruct'
TupleConstruct = 'TupleConstruct'
LayerChoice = 'LayerChoice' LayerChoice = 'LayerChoice'
InputChoice = 'InputChoice' InputChoice = 'InputChoice'
ValueChoice = 'ValueChoice' ValueChoice = 'ValueChoice'
Placeholder = 'Placeholder' Placeholder = 'Placeholder'
MergedSlice = 'MergedSlice' MergedSlice = 'MergedSlice'
# deal with aten op
BasicOpsPT = {
'aten::mean': 'Mean',
'aten::relu': 'Relu',
'aten::add': 'Add',
'aten::__getitem__': 'getitem',
'aten::append': 'Append',
'aten::len': 'Len',
'aten::slice': 'Slice',
'aten::cat': 'Cat',
'aten::size': 'Size',
'aten::view': 'View',
'aten::reshape': 'Reshape',
'aten::eq': 'Eq',
'aten::Bool': 'Bool',
'aten::empty': 'Empty',
'aten::zeros': 'Zeros',
'aten::chunk': 'Chunk',
'aten::add_': 'Add_' # %out.3 : Tensor = aten::add_(%out.1, %connection.1, %4)
}
BasicOpsTF = {}
...@@ -45,7 +45,7 @@ class ValueChoiceMutator(Mutator): ...@@ -45,7 +45,7 @@ class ValueChoiceMutator(Mutator):
chosen = self.choice(self.candidates) chosen = self.choice(self.candidates)
for node in self.nodes: for node in self.nodes:
target = model.get_node_by_name(node.name) target = model.get_node_by_name(node.name)
target.update_operation('prim::Constant', {'value': chosen}) target.update_operation('prim::Constant', {'type': type(chosen).__name__, 'value': chosen})
def process_inline_mutation(model: Model) -> Optional[List[Mutator]]: def process_inline_mutation(model: Model) -> Optional[List[Mutator]]:
......
...@@ -83,6 +83,31 @@ class Operation: ...@@ -83,6 +83,31 @@ class Operation:
class PyTorchOperation(Operation): class PyTorchOperation(Operation):
@classmethod
def _find_subclass(cls, subclass_name):
if cls.to_class_name(subclass_name) is not None:
subclass_name = 'ModuleOperator'
if cls.is_functional(subclass_name):
subclass_name = 'FunctionalOperator'
for subclass in cls.__subclasses__():
if hasattr(subclass, '_ori_type_name') and \
subclass_name in subclass._ori_type_name:
return subclass
return cls
@classmethod
def to_class_name(cls, type_name) -> str:
if type_name.startswith('__torch__.'):
return type_name[len('__torch__.'):]
elif type_name.startswith('__mutated__.'):
return type_name[len('__mutated__.'):]
else:
return None
@classmethod
def is_functional(cls, type_name) -> bool:
return type_name.startswith('Function.')
def _to_class_name(self) -> str: def _to_class_name(self) -> str:
if self.type.startswith('__torch__.'): if self.type.startswith('__torch__.'):
return self.type[len('__torch__.'):] return self.type[len('__torch__.'):]
...@@ -106,59 +131,27 @@ class PyTorchOperation(Operation): ...@@ -106,59 +131,27 @@ class PyTorchOperation(Operation):
return f'self.{field} = {self._to_class_name()}({kw_params})' return f'self.{field} = {self._to_class_name()}({kw_params})'
return None return None
def to_forward_code(self, field: str, output: str, inputs: List[str]) -> str: def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
from .converter.op_types import OpTypeName """
if self._to_class_name() is not None: Parameters
return f'{output} = self.{field}({", ".join(inputs)})' ----------
elif self.type.startswith('Function.'): field : str
func_name = self.type[len('Function.'):] the name of member submodule
return f'{output} = F.{func_name}({", ".join(inputs)})' output : str
elif self.type == 'prim::Constant': the output name (lvalue) of this line of code
if self.parameters: inputs : List[str]
value = self.parameters['value'] variables used in this line of code
else: inputs_value : List[Any]
value = None some variables are actually constant, their real values are recorded in ```inputs_value```.
return f'{output} = {value}' if not constant, we simply put None at the corresponding index
elif self.type == 'prim::ListConstruct':
return f'{output} = [{", ".join(inputs)}]' Returns
elif self.type == 'prim::TupleConstruct': -------
return f'{output} = ({", ".join(inputs)})' str
elif self.type == 'prim::GetAttr': generated code line
return f"{output} = {self.parameters['input']}.{self.parameters['name']}" """
elif self.type == 'aten::mean': if self.type == 'aten::slice':
return f'{output} = torch.mean({inputs[0]}, {", ".join(inputs[1:-1])}, out={inputs[-1]})'
elif self.type == 'aten::__getitem__':
assert len(inputs) == 2
return f'{output} = {inputs[0]}[{inputs[1]}]'
elif self.type == 'aten::append':
assert len(inputs) == 2
return f'_, {output} = {inputs[0]}.append({inputs[1]}), {inputs[0]}'
elif self.type == 'aten::cat':
assert len(inputs) == 2
return f'{output} = torch.cat({inputs[0]}, dim={inputs[1]})'
elif self.type == 'aten::add':
return f'{output} = ' + ' + '.join(inputs)
elif self.type == OpTypeName.MergedSlice:
assert (len(inputs) - 1) % 4 == 0
slices = []
dim = int((len(inputs) - 1) / 4)
for i in range(dim):
slices.append(f'{inputs[i*4+2]}:{inputs[i*4+3]}:{inputs[i*4+4]}')
slice_str = ','.join(slices)
return f'{output} = {inputs[0]}[{slice_str}]'
elif self.type == 'aten::size':
assert len(inputs) == 2
return f'{output} = {inputs[0]}.size({inputs[1]})'
elif self.type == 'aten::view':
assert len(inputs) == 2
return f'{output} = {inputs[0]}.view({inputs[1]})'
elif self.type == 'aten::reshape':
assert len(inputs) == 2
return f'{output} = {inputs[0]}.reshape({inputs[1]})'
elif self.type == 'aten::slice':
raise RuntimeError('not supposed to have aten::slice operation') raise RuntimeError('not supposed to have aten::slice operation')
elif self.type == 'aten::Bool':
return f'{output} = bool({inputs[0]})'
else: else:
raise RuntimeError(f'unsupported operation type: {self.type} ? {self._to_class_name()}') raise RuntimeError(f'unsupported operation type: {self.type} ? {self._to_class_name()}')
...@@ -212,6 +205,8 @@ class Cell(PyTorchOperation): ...@@ -212,6 +205,8 @@ class Cell(PyTorchOperation):
# TODO: ugly, think about how to refactor this part # TODO: ugly, think about how to refactor this part
return _convert_name(self.cell_name) return _convert_name(self.cell_name)
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
return f'{output} = self.{field}({", ".join(inputs)})'
class _IOPseudoOperation(Operation): class _IOPseudoOperation(Operation):
""" """
......
This diff is collapsed.
...@@ -162,6 +162,14 @@ def _get_module_name(cls): ...@@ -162,6 +162,14 @@ def _get_module_name(cls):
f'please launch the experiment under the directory where "{main_file_path.name}" is located.') f'please launch the experiment under the directory where "{main_file_path.name}" is located.')
module_name = main_file_path.stem module_name = main_file_path.stem
break break
# NOTE: this is hacky. As torchscript retrieves LSTM's source code to do something.
# to make LSTM's source code can be found, we should assign original LSTM's __module__ to
# the wrapped LSTM's __module__
# TODO: find out all the modules that have the same requirement as LSTM
if f'{cls.__module__}.{cls.__name__}' == 'torch.nn.modules.rnn.LSTM':
module_name = cls.__module__
return module_name return module_name
......
...@@ -250,7 +250,9 @@ stages: ...@@ -250,7 +250,9 @@ stages:
- script: | - script: |
cd test cd test
python -m pytest ut python -m pytest ut --ignore=ut/retiarii/test_convert_basic.py \
--ignore=ut/retiarii/test_convert_operators.py \
--ignore=ut/retiarii/test_convert_pytorch.py
displayName: Python unit test displayName: Python unit test
- script: | - script: |
......
import inspect
import logging
import torch
import torch.nn as nn
from nni.retiarii.utils import add_record, del_record, version_larger_equal
_logger = logging.getLogger(__name__)
def wrap_module(original_class):
orig_init = original_class.__init__
argname_list = list(inspect.signature(original_class).parameters.keys())
# Make copy of original __init__, so we can call it without recursion
original_class.bak_init_for_inject = orig_init
if hasattr(original_class, '__del__'):
orig_del = original_class.__del__
original_class.bak_del_for_inject = orig_del
else:
orig_del = None
original_class.bak_del_for_inject = None
def __init__(self, *args, **kws):
full_args = {}
full_args.update(kws)
for i, arg in enumerate(args):
full_args[argname_list[i]] = arg
add_record(id(self), full_args)
orig_init(self, *args, **kws) # Call the original __init__
def __del__(self):
del_record(id(self))
if orig_del is not None:
orig_del(self)
original_class.__init__ = __init__ # Set the class' __init__ to the new one
original_class.__del__ = __del__
return original_class
def unwrap_module(wrapped_class):
if hasattr(wrapped_class, 'bak_init_for_inject'):
wrapped_class.__init__ = wrapped_class.bak_init_for_inject
delattr(wrapped_class, 'bak_init_for_inject')
if hasattr(wrapped_class, 'bak_del_for_inject'):
if wrapped_class.bak_del_for_inject is not None:
wrapped_class.__del__ = wrapped_class.bak_del_for_inject
delattr(wrapped_class, 'bak_del_for_inject')
return None
def remove_inject_pytorch_nn():
Identity = unwrap_module(nn.Identity)
Linear = unwrap_module(nn.Linear)
Conv1d = unwrap_module(nn.Conv1d)
Conv2d = unwrap_module(nn.Conv2d)
Conv3d = unwrap_module(nn.Conv3d)
ConvTranspose1d = unwrap_module(nn.ConvTranspose1d)
ConvTranspose2d = unwrap_module(nn.ConvTranspose2d)
ConvTranspose3d = unwrap_module(nn.ConvTranspose3d)
Threshold = unwrap_module(nn.Threshold)
ReLU = unwrap_module(nn.ReLU)
Hardtanh = unwrap_module(nn.Hardtanh)
ReLU6 = unwrap_module(nn.ReLU6)
Sigmoid = unwrap_module(nn.Sigmoid)
Tanh = unwrap_module(nn.Tanh)
Softmax = unwrap_module(nn.Softmax)
Softmax2d = unwrap_module(nn.Softmax2d)
LogSoftmax = unwrap_module(nn.LogSoftmax)
ELU = unwrap_module(nn.ELU)
SELU = unwrap_module(nn.SELU)
CELU = unwrap_module(nn.CELU)
GLU = unwrap_module(nn.GLU)
GELU = unwrap_module(nn.GELU)
Hardshrink = unwrap_module(nn.Hardshrink)
LeakyReLU = unwrap_module(nn.LeakyReLU)
LogSigmoid = unwrap_module(nn.LogSigmoid)
Softplus = unwrap_module(nn.Softplus)
Softshrink = unwrap_module(nn.Softshrink)
MultiheadAttention = unwrap_module(nn.MultiheadAttention)
PReLU = unwrap_module(nn.PReLU)
Softsign = unwrap_module(nn.Softsign)
Softmin = unwrap_module(nn.Softmin)
Tanhshrink = unwrap_module(nn.Tanhshrink)
RReLU = unwrap_module(nn.RReLU)
AvgPool1d = unwrap_module(nn.AvgPool1d)
AvgPool2d = unwrap_module(nn.AvgPool2d)
AvgPool3d = unwrap_module(nn.AvgPool3d)
MaxPool1d = unwrap_module(nn.MaxPool1d)
MaxPool2d = unwrap_module(nn.MaxPool2d)
MaxPool3d = unwrap_module(nn.MaxPool3d)
MaxUnpool1d = unwrap_module(nn.MaxUnpool1d)
MaxUnpool2d = unwrap_module(nn.MaxUnpool2d)
MaxUnpool3d = unwrap_module(nn.MaxUnpool3d)
FractionalMaxPool2d = unwrap_module(nn.FractionalMaxPool2d)
FractionalMaxPool3d = unwrap_module(nn.FractionalMaxPool3d)
LPPool1d = unwrap_module(nn.LPPool1d)
LPPool2d = unwrap_module(nn.LPPool2d)
LocalResponseNorm = unwrap_module(nn.LocalResponseNorm)
BatchNorm1d = unwrap_module(nn.BatchNorm1d)
BatchNorm2d = unwrap_module(nn.BatchNorm2d)
BatchNorm3d = unwrap_module(nn.BatchNorm3d)
InstanceNorm1d = unwrap_module(nn.InstanceNorm1d)
InstanceNorm2d = unwrap_module(nn.InstanceNorm2d)
InstanceNorm3d = unwrap_module(nn.InstanceNorm3d)
LayerNorm = unwrap_module(nn.LayerNorm)
GroupNorm = unwrap_module(nn.GroupNorm)
SyncBatchNorm = unwrap_module(nn.SyncBatchNorm)
Dropout = unwrap_module(nn.Dropout)
Dropout2d = unwrap_module(nn.Dropout2d)
Dropout3d = unwrap_module(nn.Dropout3d)
AlphaDropout = unwrap_module(nn.AlphaDropout)
FeatureAlphaDropout = unwrap_module(nn.FeatureAlphaDropout)
ReflectionPad1d = unwrap_module(nn.ReflectionPad1d)
ReflectionPad2d = unwrap_module(nn.ReflectionPad2d)
ReplicationPad2d = unwrap_module(nn.ReplicationPad2d)
ReplicationPad1d = unwrap_module(nn.ReplicationPad1d)
ReplicationPad3d = unwrap_module(nn.ReplicationPad3d)
CrossMapLRN2d = unwrap_module(nn.CrossMapLRN2d)
Embedding = unwrap_module(nn.Embedding)
EmbeddingBag = unwrap_module(nn.EmbeddingBag)
RNNBase = unwrap_module(nn.RNNBase)
RNN = unwrap_module(nn.RNN)
LSTM = unwrap_module(nn.LSTM)
GRU = unwrap_module(nn.GRU)
RNNCellBase = unwrap_module(nn.RNNCellBase)
RNNCell = unwrap_module(nn.RNNCell)
LSTMCell = unwrap_module(nn.LSTMCell)
GRUCell = unwrap_module(nn.GRUCell)
PixelShuffle = unwrap_module(nn.PixelShuffle)
Upsample = unwrap_module(nn.Upsample)
UpsamplingNearest2d = unwrap_module(nn.UpsamplingNearest2d)
UpsamplingBilinear2d = unwrap_module(nn.UpsamplingBilinear2d)
PairwiseDistance = unwrap_module(nn.PairwiseDistance)
AdaptiveMaxPool1d = unwrap_module(nn.AdaptiveMaxPool1d)
AdaptiveMaxPool2d = unwrap_module(nn.AdaptiveMaxPool2d)
AdaptiveMaxPool3d = unwrap_module(nn.AdaptiveMaxPool3d)
AdaptiveAvgPool1d = unwrap_module(nn.AdaptiveAvgPool1d)
AdaptiveAvgPool2d = unwrap_module(nn.AdaptiveAvgPool2d)
AdaptiveAvgPool3d = unwrap_module(nn.AdaptiveAvgPool3d)
TripletMarginLoss = unwrap_module(nn.TripletMarginLoss)
ZeroPad2d = unwrap_module(nn.ZeroPad2d)
ConstantPad1d = unwrap_module(nn.ConstantPad1d)
ConstantPad2d = unwrap_module(nn.ConstantPad2d)
ConstantPad3d = unwrap_module(nn.ConstantPad3d)
Bilinear = unwrap_module(nn.Bilinear)
CosineSimilarity = unwrap_module(nn.CosineSimilarity)
Unfold = unwrap_module(nn.Unfold)
Fold = unwrap_module(nn.Fold)
AdaptiveLogSoftmaxWithLoss = unwrap_module(nn.AdaptiveLogSoftmaxWithLoss)
TransformerEncoder = unwrap_module(nn.TransformerEncoder)
TransformerDecoder = unwrap_module(nn.TransformerDecoder)
TransformerEncoderLayer = unwrap_module(nn.TransformerEncoderLayer)
TransformerDecoderLayer = unwrap_module(nn.TransformerDecoderLayer)
Transformer = unwrap_module(nn.Transformer)
Flatten = unwrap_module(nn.Flatten)
Hardsigmoid = unwrap_module(nn.Hardsigmoid)
if version_larger_equal(torch.__version__, '1.6.0'):
Hardswish = unwrap_module(nn.Hardswish)
if version_larger_equal(torch.__version__, '1.7.0'):
SiLU = unwrap_module(nn.SiLU)
Unflatten = unwrap_module(nn.Unflatten)
TripletMarginWithDistanceLoss = unwrap_module(nn.TripletMarginWithDistanceLoss)
def inject_pytorch_nn():
Identity = wrap_module(nn.Identity)
Linear = wrap_module(nn.Linear)
Conv1d = wrap_module(nn.Conv1d)
Conv2d = wrap_module(nn.Conv2d)
Conv3d = wrap_module(nn.Conv3d)
ConvTranspose1d = wrap_module(nn.ConvTranspose1d)
ConvTranspose2d = wrap_module(nn.ConvTranspose2d)
ConvTranspose3d = wrap_module(nn.ConvTranspose3d)
Threshold = wrap_module(nn.Threshold)
ReLU = wrap_module(nn.ReLU)
Hardtanh = wrap_module(nn.Hardtanh)
ReLU6 = wrap_module(nn.ReLU6)
Sigmoid = wrap_module(nn.Sigmoid)
Tanh = wrap_module(nn.Tanh)
Softmax = wrap_module(nn.Softmax)
Softmax2d = wrap_module(nn.Softmax2d)
LogSoftmax = wrap_module(nn.LogSoftmax)
ELU = wrap_module(nn.ELU)
SELU = wrap_module(nn.SELU)
CELU = wrap_module(nn.CELU)
GLU = wrap_module(nn.GLU)
GELU = wrap_module(nn.GELU)
Hardshrink = wrap_module(nn.Hardshrink)
LeakyReLU = wrap_module(nn.LeakyReLU)
LogSigmoid = wrap_module(nn.LogSigmoid)
Softplus = wrap_module(nn.Softplus)
Softshrink = wrap_module(nn.Softshrink)
MultiheadAttention = wrap_module(nn.MultiheadAttention)
PReLU = wrap_module(nn.PReLU)
Softsign = wrap_module(nn.Softsign)
Softmin = wrap_module(nn.Softmin)
Tanhshrink = wrap_module(nn.Tanhshrink)
RReLU = wrap_module(nn.RReLU)
AvgPool1d = wrap_module(nn.AvgPool1d)
AvgPool2d = wrap_module(nn.AvgPool2d)
AvgPool3d = wrap_module(nn.AvgPool3d)
MaxPool1d = wrap_module(nn.MaxPool1d)
MaxPool2d = wrap_module(nn.MaxPool2d)
MaxPool3d = wrap_module(nn.MaxPool3d)
MaxUnpool1d = wrap_module(nn.MaxUnpool1d)
MaxUnpool2d = wrap_module(nn.MaxUnpool2d)
MaxUnpool3d = wrap_module(nn.MaxUnpool3d)
FractionalMaxPool2d = wrap_module(nn.FractionalMaxPool2d)
FractionalMaxPool3d = wrap_module(nn.FractionalMaxPool3d)
LPPool1d = wrap_module(nn.LPPool1d)
LPPool2d = wrap_module(nn.LPPool2d)
LocalResponseNorm = wrap_module(nn.LocalResponseNorm)
BatchNorm1d = wrap_module(nn.BatchNorm1d)
BatchNorm2d = wrap_module(nn.BatchNorm2d)
BatchNorm3d = wrap_module(nn.BatchNorm3d)
InstanceNorm1d = wrap_module(nn.InstanceNorm1d)
InstanceNorm2d = wrap_module(nn.InstanceNorm2d)
InstanceNorm3d = wrap_module(nn.InstanceNorm3d)
LayerNorm = wrap_module(nn.LayerNorm)
GroupNorm = wrap_module(nn.GroupNorm)
SyncBatchNorm = wrap_module(nn.SyncBatchNorm)
Dropout = wrap_module(nn.Dropout)
Dropout2d = wrap_module(nn.Dropout2d)
Dropout3d = wrap_module(nn.Dropout3d)
AlphaDropout = wrap_module(nn.AlphaDropout)
FeatureAlphaDropout = wrap_module(nn.FeatureAlphaDropout)
ReflectionPad1d = wrap_module(nn.ReflectionPad1d)
ReflectionPad2d = wrap_module(nn.ReflectionPad2d)
ReplicationPad2d = wrap_module(nn.ReplicationPad2d)
ReplicationPad1d = wrap_module(nn.ReplicationPad1d)
ReplicationPad3d = wrap_module(nn.ReplicationPad3d)
CrossMapLRN2d = wrap_module(nn.CrossMapLRN2d)
Embedding = wrap_module(nn.Embedding)
EmbeddingBag = wrap_module(nn.EmbeddingBag)
RNNBase = wrap_module(nn.RNNBase)
RNN = wrap_module(nn.RNN)
LSTM = wrap_module(nn.LSTM)
GRU = wrap_module(nn.GRU)
RNNCellBase = wrap_module(nn.RNNCellBase)
RNNCell = wrap_module(nn.RNNCell)
LSTMCell = wrap_module(nn.LSTMCell)
GRUCell = wrap_module(nn.GRUCell)
PixelShuffle = wrap_module(nn.PixelShuffle)
Upsample = wrap_module(nn.Upsample)
UpsamplingNearest2d = wrap_module(nn.UpsamplingNearest2d)
UpsamplingBilinear2d = wrap_module(nn.UpsamplingBilinear2d)
PairwiseDistance = wrap_module(nn.PairwiseDistance)
AdaptiveMaxPool1d = wrap_module(nn.AdaptiveMaxPool1d)
AdaptiveMaxPool2d = wrap_module(nn.AdaptiveMaxPool2d)
AdaptiveMaxPool3d = wrap_module(nn.AdaptiveMaxPool3d)
AdaptiveAvgPool1d = wrap_module(nn.AdaptiveAvgPool1d)
AdaptiveAvgPool2d = wrap_module(nn.AdaptiveAvgPool2d)
AdaptiveAvgPool3d = wrap_module(nn.AdaptiveAvgPool3d)
TripletMarginLoss = wrap_module(nn.TripletMarginLoss)
ZeroPad2d = wrap_module(nn.ZeroPad2d)
ConstantPad1d = wrap_module(nn.ConstantPad1d)
ConstantPad2d = wrap_module(nn.ConstantPad2d)
ConstantPad3d = wrap_module(nn.ConstantPad3d)
Bilinear = wrap_module(nn.Bilinear)
CosineSimilarity = wrap_module(nn.CosineSimilarity)
Unfold = wrap_module(nn.Unfold)
Fold = wrap_module(nn.Fold)
AdaptiveLogSoftmaxWithLoss = wrap_module(nn.AdaptiveLogSoftmaxWithLoss)
TransformerEncoder = wrap_module(nn.TransformerEncoder)
TransformerDecoder = wrap_module(nn.TransformerDecoder)
TransformerEncoderLayer = wrap_module(nn.TransformerEncoderLayer)
TransformerDecoderLayer = wrap_module(nn.TransformerDecoderLayer)
Transformer = wrap_module(nn.Transformer)
Flatten = wrap_module(nn.Flatten)
Hardsigmoid = wrap_module(nn.Hardsigmoid)
if version_larger_equal(torch.__version__, '1.6.0'):
Hardswish = wrap_module(nn.Hardswish)
if version_larger_equal(torch.__version__, '1.7.0'):
SiLU = wrap_module(nn.SiLU)
Unflatten = wrap_module(nn.Unflatten)
TripletMarginWithDistanceLoss = wrap_module(nn.TripletMarginWithDistanceLoss)
...@@ -35,16 +35,29 @@ class MnistNet(nn.Module): ...@@ -35,16 +35,29 @@ class MnistNet(nn.Module):
x = self.fc2(x) x = self.fc2(x)
return F.log_softmax(x, dim=1) return F.log_softmax(x, dim=1)
# NOTE: blackbox module cannot be placed within class or function
@blackbox_module
class Linear(nn.Module):
def __init__(self, d_embed, d_proj):
super().__init__()
self.linear = nn.Linear(d_embed, d_proj)
def forward(self, input):
if len(input.size()) <= 2:
return self.linear(input)
size = input.size()[:2]
out = self.linear(input.view(size[0] * size[1], -1))
return out.view(size[0], size[1], -1)
class TestConvert(unittest.TestCase): class TestConvert(unittest.TestCase):
@staticmethod @staticmethod
def _match_state_dict(current_values, expected_format): def _match_state_dict(current_values, expected_format):
result = {} result = {}
for k, v in expected_format.items(): for k, v in expected_format.items():
for cv in current_values: for idx, cv in enumerate(current_values):
if cv.shape == v.shape: if cv.shape == v.shape:
result[k] = cv result[k] = cv
current_values.remove(cv) current_values.pop(idx)
break break
return result return result
...@@ -53,6 +66,9 @@ class TestConvert(unittest.TestCase): ...@@ -53,6 +66,9 @@ class TestConvert(unittest.TestCase):
model_ir = convert_to_graph(script_module, model) model_ir = convert_to_graph(script_module, model)
model_code = model_to_pytorch_script(model_ir) model_code = model_to_pytorch_script(model_ir)
from .inject_nn import remove_inject_pytorch_nn
remove_inject_pytorch_nn()
exec_vars = {} exec_vars = {}
exec(model_code + '\n\nconverted_model = _model()', exec_vars) exec(model_code + '\n\nconverted_model = _model()', exec_vars)
converted_model = exec_vars['converted_model'] converted_model = exec_vars['converted_model']
...@@ -134,18 +150,17 @@ class TestConvert(unittest.TestCase): ...@@ -134,18 +150,17 @@ class TestConvert(unittest.TestCase):
model = DCGANGenerator(nz, ngf, nc) model = DCGANGenerator(nz, ngf, nc)
self.checkExportImport(model, input) self.checkExportImport(model, input)
@unittest.skip('this test has a if condition that needs to be handle') # FIXME
def test_neural_style(self): def test_neural_style(self):
class TransformerNet(torch.nn.Module): class TransformerNet(nn.Module):
def __init__(self): def __init__(self):
super(TransformerNet, self).__init__() super(TransformerNet, self).__init__()
# Initial convolution layers # Initial convolution layers
self.conv1 = ConvLayer(3, 32, kernel_size=9, stride=1) self.conv1 = ConvLayer(3, 32, kernel_size=9, stride=1)
self.in1 = torch.nn.InstanceNorm2d(32, affine=True) self.in1 = nn.InstanceNorm2d(32, affine=True)
self.conv2 = ConvLayer(32, 64, kernel_size=3, stride=2) self.conv2 = ConvLayer(32, 64, kernel_size=3, stride=2)
self.in2 = torch.nn.InstanceNorm2d(64, affine=True) self.in2 = nn.InstanceNorm2d(64, affine=True)
self.conv3 = ConvLayer(64, 128, kernel_size=3, stride=2) self.conv3 = ConvLayer(64, 128, kernel_size=3, stride=2)
self.in3 = torch.nn.InstanceNorm2d(128, affine=True) self.in3 = nn.InstanceNorm2d(128, affine=True)
# Residual layers # Residual layers
self.res1 = ResidualBlock(128) self.res1 = ResidualBlock(128)
self.res2 = ResidualBlock(128) self.res2 = ResidualBlock(128)
...@@ -154,12 +169,12 @@ class TestConvert(unittest.TestCase): ...@@ -154,12 +169,12 @@ class TestConvert(unittest.TestCase):
self.res5 = ResidualBlock(128) self.res5 = ResidualBlock(128)
# Upsampling Layers # Upsampling Layers
self.deconv1 = UpsampleConvLayer(128, 64, kernel_size=3, stride=1, upsample=2) self.deconv1 = UpsampleConvLayer(128, 64, kernel_size=3, stride=1, upsample=2)
self.in4 = torch.nn.InstanceNorm2d(64, affine=True) self.in4 = nn.InstanceNorm2d(64, affine=True)
self.deconv2 = UpsampleConvLayer(64, 32, kernel_size=3, stride=1, upsample=2) self.deconv2 = UpsampleConvLayer(64, 32, kernel_size=3, stride=1, upsample=2)
self.in5 = torch.nn.InstanceNorm2d(32, affine=True) self.in5 = nn.InstanceNorm2d(32, affine=True)
self.deconv3 = ConvLayer(32, 3, kernel_size=9, stride=1) self.deconv3 = ConvLayer(32, 3, kernel_size=9, stride=1)
# Non-linearities # Non-linearities
self.relu = torch.nn.ReLU() self.relu = nn.ReLU()
def forward(self, X): def forward(self, X):
y = self.relu(self.in1(self.conv1(X))) y = self.relu(self.in1(self.conv1(X)))
...@@ -175,19 +190,19 @@ class TestConvert(unittest.TestCase): ...@@ -175,19 +190,19 @@ class TestConvert(unittest.TestCase):
y = self.deconv3(y) y = self.deconv3(y)
return y return y
class ConvLayer(torch.nn.Module): class ConvLayer(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride): def __init__(self, in_channels, out_channels, kernel_size, stride):
super(ConvLayer, self).__init__() super(ConvLayer, self).__init__()
reflection_padding = kernel_size // 2 reflection_padding = kernel_size // 2
self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding) self.reflection_pad = nn.ReflectionPad2d(reflection_padding)
self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride) self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride)
def forward(self, x): def forward(self, x):
out = self.reflection_pad(x) out = self.reflection_pad(x)
out = self.conv2d(out) out = self.conv2d(out)
return out return out
class ResidualBlock(torch.nn.Module): class ResidualBlock(nn.Module):
"""ResidualBlock """ResidualBlock
introduced in: https://arxiv.org/abs/1512.03385 introduced in: https://arxiv.org/abs/1512.03385
recommended architecture: http://torch.ch/blog/2016/02/04/resnets.html recommended architecture: http://torch.ch/blog/2016/02/04/resnets.html
...@@ -196,10 +211,10 @@ class TestConvert(unittest.TestCase): ...@@ -196,10 +211,10 @@ class TestConvert(unittest.TestCase):
def __init__(self, channels): def __init__(self, channels):
super(ResidualBlock, self).__init__() super(ResidualBlock, self).__init__()
self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1) self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1)
self.in1 = torch.nn.InstanceNorm2d(channels, affine=True) self.in1 = nn.InstanceNorm2d(channels, affine=True)
self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1) self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1)
self.in2 = torch.nn.InstanceNorm2d(channels, affine=True) self.in2 = nn.InstanceNorm2d(channels, affine=True)
self.relu = torch.nn.ReLU() self.relu = nn.ReLU()
def forward(self, x): def forward(self, x):
residual = x residual = x
...@@ -208,7 +223,7 @@ class TestConvert(unittest.TestCase): ...@@ -208,7 +223,7 @@ class TestConvert(unittest.TestCase):
out = out + residual out = out + residual
return out return out
class UpsampleConvLayer(torch.nn.Module): class UpsampleConvLayer(nn.Module):
"""UpsampleConvLayer """UpsampleConvLayer
Upsamples the input and then does a convolution. This method gives better results Upsamples the input and then does a convolution. This method gives better results
compared to ConvTranspose2d. compared to ConvTranspose2d.
...@@ -219,10 +234,10 @@ class TestConvert(unittest.TestCase): ...@@ -219,10 +234,10 @@ class TestConvert(unittest.TestCase):
super(UpsampleConvLayer, self).__init__() super(UpsampleConvLayer, self).__init__()
self.upsample = upsample self.upsample = upsample
if upsample: if upsample:
self.upsample_layer = torch.nn.Upsample(mode='nearest', scale_factor=upsample) self.upsample_layer = nn.Upsample(mode='nearest', scale_factor=upsample)
reflection_padding = kernel_size // 2 reflection_padding = kernel_size // 2
self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding) self.reflection_pad = nn.ReflectionPad2d(reflection_padding)
self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride) self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride)
def forward(self, x): def forward(self, x):
x_in = x x_in = x
...@@ -254,50 +269,40 @@ class TestConvert(unittest.TestCase): ...@@ -254,50 +269,40 @@ class TestConvert(unittest.TestCase):
self.checkExportImport(Policy(), (torch.rand(1, 4),)) self.checkExportImport(Policy(), (torch.rand(1, 4),))
@unittest.skip('Replaced init error.') # FIXME
def test_snli(self): def test_snli(self):
class Bottle(nn.Module):
def forward(self, input):
if len(input.size()) <= 2:
return super(Bottle, self).forward(input)
size = input.size()[:2]
out = super(Bottle, self).forward(input.view(size[0] * size[1], -1))
return out.view(size[0], size[1], -1)
class Linear(Bottle, nn.Linear):
pass
class Encoder(nn.Module): class Encoder(nn.Module):
def __init__(self, config): def __init__(self, config):
super(Encoder, self).__init__() super(Encoder, self).__init__()
self.config = config #self.config = config
input_size = config.d_proj if config.projection else config.d_embed input_size = config["d_proj"] if config["projection"] else config["d_embed"]
dropout = 0 if config.n_layers == 1 else config.dp_ratio dropout = 0 if config["n_layers"] == 1 else config["dp_ratio"]
self.rnn = nn.LSTM(input_size=input_size, hidden_size=config.d_hidden, self.rnn = nn.LSTM(input_size=input_size, hidden_size=config["d_hidden"],
num_layers=config.n_layers, dropout=dropout, num_layers=config["n_layers"], dropout=dropout,
bidirectional=config.birnn) bidirectional=config["birnn"])
self.n_cells = config["n_cells"]
self.d_hidden = config["d_hidden"]
self.birnn = config["birnn"]
def forward(self, inputs): def forward(self, inputs):
batch_size = inputs.size()[1] batch_size = inputs.size()[1]
state_shape = self.config.n_cells, batch_size, self.config.d_hidden state_shape = self.n_cells, batch_size, self.d_hidden
h0 = c0 = inputs.new_zeros(state_shape) h0 = c0 = inputs.new_zeros(state_shape)
outputs, (ht, ct) = self.rnn(inputs, (h0, c0)) outputs, (ht, ct) = self.rnn(inputs, (h0, c0))
return ht[-1] if not self.config.birnn else ht[-2:].transpose(0, 1).contiguous().view(batch_size, -1) return ht[-1] if not self.birnn else ht[-2:].transpose(0, 1).contiguous().view(batch_size, -1)
class SNLIClassifier(nn.Module): class SNLIClassifier(nn.Module):
def __init__(self, config): def __init__(self, config):
super(SNLIClassifier, self).__init__() super(SNLIClassifier, self).__init__()
self.config = config self.embed = nn.Embedding(config["n_embed"], config["d_embed"])
self.embed = nn.Embedding(config.n_embed, config.d_embed) self.projection = Linear(config["d_embed"], config["d_proj"])
self.projection = Linear(config.d_embed, config.d_proj)
self.encoder = Encoder(config) self.encoder = Encoder(config)
self.dropout = nn.Dropout(p=config.dp_ratio) self.dropout = nn.Dropout(p=config["dp_ratio"])
self.relu = nn.ReLU() self.relu = nn.ReLU()
seq_in_size = 2 * config.d_hidden seq_in_size = 2 * config["d_hidden"]
if self.config.birnn: if config["birnn"]:
seq_in_size *= 2 seq_in_size *= 2
lin_config = [seq_in_size] * 2 lin_config = [seq_in_size] * 2
self.out = nn.Sequential( self.out = nn.Sequential(
...@@ -310,15 +315,17 @@ class TestConvert(unittest.TestCase): ...@@ -310,15 +315,17 @@ class TestConvert(unittest.TestCase):
Linear(*lin_config), Linear(*lin_config),
self.relu, self.relu,
self.dropout, self.dropout,
Linear(seq_in_size, config.d_out)) Linear(seq_in_size, config["d_out"]))
self.fix_emb = config["fix_emb"]
self.project = config["projection"]
def forward(self, premise, hypothesis): def forward(self, premise, hypothesis):
prem_embed = self.embed(premise) prem_embed = self.embed(premise)
hypo_embed = self.embed(hypothesis) hypo_embed = self.embed(hypothesis)
if self.config.fix_emb: if self.fix_emb:
prem_embed = prem_embed.detach() prem_embed = prem_embed.detach()
hypo_embed = hypo_embed.detach() hypo_embed = hypo_embed.detach()
if self.config.projection: if self.project:
prem_embed = self.relu(self.projection(prem_embed)) prem_embed = self.relu(self.projection(prem_embed))
hypo_embed = self.relu(self.projection(hypo_embed)) hypo_embed = self.relu(self.projection(hypo_embed))
premise = self.encoder(prem_embed) premise = self.encoder(prem_embed)
...@@ -326,23 +333,24 @@ class TestConvert(unittest.TestCase): ...@@ -326,23 +333,24 @@ class TestConvert(unittest.TestCase):
scores = self.out(torch.cat([premise, hypothesis], 1)) scores = self.out(torch.cat([premise, hypothesis], 1))
return scores return scores
class Config: Config = {
n_embed = 100 "n_embed": 100,
d_embed = 100 "d_embed": 100,
d_proj = 300 "d_proj": 300,
dp_ratio = 0.0 # For deterministic testing TODO: change by fixing seed in checkTrace? "dp_ratio": 0.0, # For deterministic testing TOD": change by fixing seed in checkTrace?,
d_hidden = 30 "d_hidden": 30,
birnn = True "birnn": True,
d_out = 300 "d_out": 300,
fix_emb = True "fix_emb": True,
projection = True "projection": True,
n_layers = 2 "n_layers": 2,
n_cells = 4 # 2 * n_layers because birnn = True "n_cells": 4 # 2 * n_layers because birnn = True,
}
premise = torch.LongTensor(48, 64).random_(0, 100) premise = torch.LongTensor(48, 64).random_(0, 100)
hypothesis = torch.LongTensor(24, 64).random_(0, 100) hypothesis = torch.LongTensor(24, 64).random_(0, 100)
self.checkExportImport(SNLIClassifier(Config()), (premise, hypothesis)) self.checkExportImport(SNLIClassifier(Config), (premise, hypothesis))
def test_super_resolution(self): def test_super_resolution(self):
class Net(nn.Module): class Net(nn.Module):
...@@ -367,16 +375,16 @@ class TestConvert(unittest.TestCase): ...@@ -367,16 +375,16 @@ class TestConvert(unittest.TestCase):
net = Net(upscale_factor=4) net = Net(upscale_factor=4)
self.checkExportImport(net, (torch.rand(5, 1, 32, 32),)) self.checkExportImport(net, (torch.rand(5, 1, 32, 32),))
@unittest.skip('Need to support operator prim::ListUnpack') # FIXME @unittest.skip('Need to support Loop') # FIXME
def test_time_sequence_prediction(self): def test_time_sequence_prediction(self):
class Sequence(torch.jit.ScriptModule): class Sequence(nn.Module): #torch.jit.ScriptModule
def __init__(self): def __init__(self):
super(Sequence, self).__init__() super(Sequence, self).__init__()
self.lstm1 = nn.LSTMCell(1, 51) self.lstm1 = nn.LSTMCell(1, 51)
self.lstm2 = nn.LSTMCell(51, 51) self.lstm2 = nn.LSTMCell(51, 51)
self.linear = nn.Linear(51, 1) self.linear = nn.Linear(51, 1)
@torch.jit.script_method #@torch.jit.script_method
def forward(self, input): def forward(self, input):
# TODO: add future as input with default val # TODO: add future as input with default val
# see https://github.com/pytorch/pytorch/issues/8724 # see https://github.com/pytorch/pytorch/issues/8724
...@@ -414,7 +422,7 @@ class TestConvert(unittest.TestCase): ...@@ -414,7 +422,7 @@ class TestConvert(unittest.TestCase):
self.checkExportImport(Traced(), (torch.rand(3, 4),)) self.checkExportImport(Traced(), (torch.rand(3, 4),))
@unittest.skip('Unsupported callmethod encode') # FIXME @unittest.skip('incorrectly assigned weights') # FIXME
def test_vae(self): def test_vae(self):
class VAE(nn.Module): class VAE(nn.Module):
def __init__(self): def __init__(self):
...@@ -449,11 +457,11 @@ class TestConvert(unittest.TestCase): ...@@ -449,11 +457,11 @@ class TestConvert(unittest.TestCase):
self.checkExportImport(VAE().eval(), (torch.rand(128, 1, 28, 28),)) self.checkExportImport(VAE().eval(), (torch.rand(128, 1, 28, 28),))
@unittest.skip('torchvision models are not supported yet') # FIXME
def test_torchvision_resnet18(self): def test_torchvision_resnet18(self):
from .inject_nn import inject_pytorch_nn
inject_pytorch_nn()
self.checkExportImport(torchvision.models.resnet18().eval(), (torch.ones(1, 3, 224, 224),)) self.checkExportImport(torchvision.models.resnet18().eval(), (torch.ones(1, 3, 224, 224),))
@unittest.skip('Unsupported CallMethod _forward_impl') # FIXME
def test_resnet(self): def test_resnet(self):
def conv1x1(in_planes, out_planes, stride=1): def conv1x1(in_planes, out_planes, stride=1):
"""1x1 convolution""" """1x1 convolution"""
...@@ -464,7 +472,7 @@ class TestConvert(unittest.TestCase): ...@@ -464,7 +472,7 @@ class TestConvert(unittest.TestCase):
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=1, bias=False) padding=1, bias=False)
class BasicBlock(torch.jit.ScriptModule): class BasicBlock(nn.Module): #torch.jit.ScriptModule
expansion = 1 expansion = 1
__constants__ = ['downsample'] __constants__ = ['downsample']
...@@ -478,7 +486,8 @@ class TestConvert(unittest.TestCase): ...@@ -478,7 +486,8 @@ class TestConvert(unittest.TestCase):
self.downsample = downsample self.downsample = downsample
self.stride = stride self.stride = stride
@torch.jit.script_method # NOTE: jit cannot be annotated, otherwise, module id is not matched for recorded arguments
#@torch.jit.script_method
def forward(self, x): def forward(self, x):
residual = x residual = x
...@@ -497,7 +506,8 @@ class TestConvert(unittest.TestCase): ...@@ -497,7 +506,8 @@ class TestConvert(unittest.TestCase):
return out return out
class ResNet(torch.jit.ScriptModule): # NOTE: cannot inherit torch.jit.ScriptModule, otherwise, there would be error: 'RecursiveScriptModule' object has no attribute 'graph'
class ResNet(nn.Module): #torch.jit.ScriptModule
__constants__ = ['layer1', 'layer2', 'layer3', 'layer4'] __constants__ = ['layer1', 'layer2', 'layer3', 'layer4']
def __init__(self, block, layers, num_classes=1000): def __init__(self, block, layers, num_classes=1000):
...@@ -538,7 +548,8 @@ class TestConvert(unittest.TestCase): ...@@ -538,7 +548,8 @@ class TestConvert(unittest.TestCase):
return nn.Sequential(*layers) return nn.Sequential(*layers)
@torch.jit.script_method # NOTE: jit cannot be annotated, otherwise, module id is not matched for recorded arguments
#@torch.jit.script_method
def forward(self, x): def forward(self, x):
x = self.conv1(x) x = self.conv1(x)
x = self.bn1(x) x = self.bn1(x)
...@@ -558,10 +569,11 @@ class TestConvert(unittest.TestCase): ...@@ -558,10 +569,11 @@ class TestConvert(unittest.TestCase):
resnet18 = ResNet(BasicBlock, [2, 2, 2, 2]) resnet18 = ResNet(BasicBlock, [2, 2, 2, 2])
self.checkExportImport(torchvision.models.resnet18().eval(), (torch.randn(1, 3, 224, 224),)) self.checkExportImport(resnet18, (torch.randn(1, 3, 224, 224),))
@unittest.skip('torchvision models are not supported yet') # FIXME
def test_alexnet(self): def test_alexnet(self):
from .inject_nn import inject_pytorch_nn
inject_pytorch_nn()
x = torch.ones(1, 3, 224, 224) x = torch.ones(1, 3, 224, 224)
model = torchvision.models.AlexNet() model = torchvision.models.AlexNet()
self.checkExportImport(model, (x,)) self.checkExportImport(model, (x,))
import os
import sys
import unittest
import numpy as np
import torch
import torch.nn.functional as F
import torchvision
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import blackbox_module
from nni.retiarii.converter import convert_to_graph
from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.utils import get_records
# following pytorch v1.7.1
class TestConvert(unittest.TestCase):
@staticmethod
def _match_state_dict(current_values, expected_format):
result = {}
for k, v in expected_format.items():
for idx, cv in enumerate(current_values):
if cv.shape == v.shape:
result[k] = cv
current_values.pop(idx)
break
return result
def checkExportImport(self, model, input, check_value=True):
script_module = torch.jit.script(model)
model_ir = convert_to_graph(script_module, model)
model_code = model_to_pytorch_script(model_ir)
print(model_code)
exec_vars = {}
exec(model_code + '\n\nconverted_model = _model()', exec_vars)
converted_model = exec_vars['converted_model']
converted_state_dict = self._match_state_dict(list(model.state_dict().values()),
dict(converted_model.state_dict()))
converted_model.load_state_dict(converted_state_dict)
with torch.no_grad():
expected_output = model.eval()(*input)
converted_output = converted_model.eval()(*input)
if check_value:
self.assertEqual(len(converted_output), len(expected_output))
for a, b in zip(converted_output, expected_output):
if hasattr(a, 'dtype') and a.dtype == torch.bool:
self.assertEqual((a ^ b), False)
elif isinstance((a - b), int):
self.assertEqual((a - b), 0)
else:
self.assertLess((a - b).abs().max().item(), 1E-4)
return converted_model
# skip torch.Tensor.new_tensor as it is not supported by jit
def test_basic_new_full(self):
class SimpleOp(nn.Module):
def forward(self, x):
# requires_grad is not supported by jit
# aten::new_full(Tensor self, int[] size, Scalar fill_value, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor):
# Keyword argument requires_grad unknown.
out = x.new_full((3, 4), 3.141592, dtype=torch.float32, device=torch.device('cpu'))
return out
self.checkExportImport(SimpleOp(), (torch.ones((2,), dtype=torch.float64), ))
def test_basic_new_empty(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = x.new_empty((2, 3), dtype=torch.int8, device=torch.device('cpu'))
return out
self.checkExportImport(SimpleOp(), (torch.ones(()), ), check_value=False)
# skip torch.Tensor.new_ones as it is not supported by jit
# requires_grad=False is not supported by jit
def test_basic_new_zeros(self):
class SimpleOp(nn.Module):
def forward(self, x):
out = x.new_zeros((2, 3))
return out
self.checkExportImport(SimpleOp(), (torch.tensor((), dtype=torch.int32), ))
def test_basic_is_cuda(self):
class SimpleOp(nn.Module):
def forward(self, x):
return torch.tensor([x.is_cuda], dtype=torch.bool, device=torch.device('cpu'))
self.checkExportImport(SimpleOp(), (torch.tensor((), dtype=torch.int32), ))
# is_quantized
# is_meta
# device
# grad
# ndim
# T
# real
# imag
def test_basic_abs(self):
class SimpleOp(nn.Module):
def forward(self, x):
out1 = x.abs()
out11 = x.absolute()
out2 = torch.abs(x)
#out3 = x.abs_()
#out33 = x.absolute_()
return out1, out11, out2#, out3, out33
self.checkExportImport(SimpleOp(), (torch.tensor([-1, -2, 3]), ))
# TODO: topological sort should be improved
#def forward(self, x__1):
# __Acos2 = x__1.acos()
# __Acos_3 = x__1.acos_()
# __Acos1 = x__1.acos()
# __TupleConstruct4 = (__Acos1,__Acos2,__Acos_3)
# return __TupleConstruct4
def test_basic_acos_asin_atan(self):
class SimpleOp(nn.Module):
def forward(self, x, y):
out1 = x.acos()
out2 = torch.acos(x)
# TODO: add back this line
#out = x.acos_()
out3 = x.asin()
out4 = torch.asin(x)
out5 = x.atan()
out6 = torch.atan(x)
out7 = x.atan2(y)
out8 = torch.atan2(x, y)
return out1, out2, out3, out4, out5, out6, out7, out8#, out
self.checkExportImport(SimpleOp(), (torch.tensor([-1.0, -0.5, 0.2]), torch.tensor([1.0, 0.6, -0.3]), ))
# arccos is not supported by jit
def test_basic_add(self):
class SimpleOp(nn.Module):
def forward(self, x):
t = torch.tensor([-1.0, -0.5, 0.2])
out1 = x.add(t)
out2 = x.add(t, alpha=2)
#out3 = x.add_(t)
return out1, out2#, out3
self.checkExportImport(SimpleOp(), (torch.tensor([-1.0, -0.5, 0.2]), ))
def test_basic_addbmm(self):
class SimpleOp(nn.Module):
def forward(self, x, y, z, m):
out1 = x.addbmm(y, z, beta=2, alpha=3)
out2 = torch.addbmm(x, y, z, beta=2, alpha=3)
#out3 = x.addbmm_(y, z, beta=2, alpha=3)
out3 = m.baddbmm(y, z, beta=2, alpha=3)
out4 = torch.baddbmm(m, y, z, beta=2, alpha=3)
out5 = torch.bmm(y, z) # deterministic is not supported by jit
return out1, out2, out3, out4, out5
self.checkExportImport(SimpleOp(), (torch.randn(3, 5), torch.randn(10, 3, 4), torch.randn(10, 4, 5), torch.randn(10, 3, 5), ))
def test_basic_addcdiv(self):
class SimpleOp(nn.Module):
def forward(self, x, y, z):
out1 = x.addcdiv(y, z, value=2)
out2 = torch.addcdiv(x, y, z, value=2)
# addcdiv_
return out1, out2
self.checkExportImport(SimpleOp(), (torch.randn(1, 3), torch.randn(3, 1), torch.randn(1, 3), ))
def test_basic_addcmul(self):
class SimpleOp(nn.Module):
def forward(self, x, y, z):
out1 = x.addcmul(y, z, value=0.1)
out2 = torch.addcmul(x, y, z, value=0.1)
# addcmul_
return out1, out2
self.checkExportImport(SimpleOp(), (torch.randn(1, 3), torch.randn(3, 1), torch.randn(1, 3), ))
def test_basic_addmm(self):
class SimpleOp(nn.Module):
def forward(self, x, y, z):
out1 = x.addmm(y, z, beta=0.1, alpha=0.2)
out2 = torch.addmm(x, y, z, beta=0.1, alpha=0.2)
# addmm_
return out1, out2
self.checkExportImport(SimpleOp(), (torch.randn(2, 3), torch.randn(2, 3), torch.randn(3, 3), ))
def test_basic_addmv(self):
class SimpleOp(nn.Module):
def forward(self, x, y, z):
out1 = x.addmv(y, z, beta=0.1, alpha=0.2)
out2 = torch.addmv(x, y, z, beta=0.1, alpha=0.2)
return out1, out2
self.checkExportImport(SimpleOp(), (torch.randn(2), torch.randn(2, 3), torch.randn(3), ))
def test_basic_addr(self):
class SimpleOp(nn.Module):
def forward(self, x, y, z):
out1 = x.addr(y, z, beta=2, alpha=3)
out2 = torch.addr(x, y, z, beta=2, alpha=3)
return out1, out2
self.checkExportImport(SimpleOp(), (torch.zeros(3, 2), torch.arange(1., 4.), torch.arange(1., 3.), ))
def test_basic_allclose(self):
class SimpleOp(nn.Module):
def forward(self, x, y):
out1 = x.allclose(y, rtol=1e-05, atol=1e-08, equal_nan=False)
out2 = torch.allclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False)
return out1, out2
self.checkExportImport(SimpleOp(), (torch.tensor([10000., 1e-07]), torch.tensor([10000.1, 1e-08]), ))
def test_basic_angle(self):
class SimpleOp(nn.Module):
def forward(self, x):
out1 = x.angle()
out2 = torch.angle(x)
return out1, out2
self.checkExportImport(SimpleOp(), (torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j]), ))
# skip apply_(callable) for now
def test_basic_argmax_argmin(self):
class SimpleOp(nn.Module):
def forward(self, x):
out1 = x.argmax()
out2 = torch.argmax(x)
out3 = x.argmax(dim=1)
out4 = torch.argmax(x, dim=1)
out5 = x.argmax(dim=1, keepdim=True)
o1 = x.argmin()
o2 = torch.argmin(x)
o3 = x.argmin(dim=1)
o4 = x.argmin(dim=1, keepdim=True)
return out1, out2, out3, out4, out5, o1, o2, o3, o4
self.checkExportImport(SimpleOp(), (torch.randn(4, 4), ))
def test_basic_argsort(self):
class SimpleOp(nn.Module):
def forward(self, x):
out1 = x.argsort()
out2 = x.argsort(dim=1)
out3 = x.argsort(dim=1, descending=True)
out4 = torch.argsort(x, dim=1, descending=True)
return out1, out2, out3, out4
self.checkExportImport(SimpleOp(), (torch.randn(4, 4), ))
# skip backward(gradient=None, retain_graph=None, create_graph=False)
def test_basic_bernoulli(self):
class SimpleOp(nn.Module):
def forward(self, x):
# generator=torch.Generator() is not supported by jit
out = x.bernoulli()
return out
self.checkExportImport(SimpleOp(), (torch.ones(3, 3), ))
# bfloat16/bool/byte/char is not supported by jit
def test_basic_bincount(self):
class SimpleOp(nn.Module):
def forward(self, x, y):
out1 = x.bincount()
out2 = torch.bincount(x)
out3 = x.bincount(weights=y)
out4 = x.bincount(weights=y, minlength=2)
return out1, out2, out3, out4
self.checkExportImport(SimpleOp(), (torch.randint(0, 8, (5,), dtype=torch.int64), torch.linspace(0, 1, steps=5), ))
def test_basic_bitwise(self):
class SimpleOp(nn.Module):
def forward(self, x, y):
out1 = x.bitwise_not()
out2 = x.bitwise_and(y)
out3 = x.bitwise_or(y)
out4 = x.bitwise_xor(y)
return out1, out2, out3, out4
self.checkExportImport(SimpleOp(), (torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8), ))
# cauchy_ is not supported yet
def test_ceil(self):
class SimpleOp(nn.Module):
def forward(self, x):
out1 = x.ceil()
return out1
self.checkExportImport(SimpleOp(), (torch.randn(4), ))
\ No newline at end of file
This diff is collapsed.
This diff is collapsed.
...@@ -167,6 +167,7 @@ class TestHighLevelAPI(unittest.TestCase): ...@@ -167,6 +167,7 @@ class TestHighLevelAPI(unittest.TestCase):
mutator = mutators[0].bind_sampler(EnuemrateSampler()) mutator = mutators[0].bind_sampler(EnuemrateSampler())
model1 = mutator.apply(model) model1 = mutator.apply(model)
model2 = mutator.apply(model) model2 = mutator.apply(model)
self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 3, 3))
self.assertEqual(self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 3, 3)).size(), torch.Size([1, 3, 3, 3])) self.assertEqual(self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 3, 3)).size(), torch.Size([1, 3, 3, 3]))
self.assertAlmostEqual(self._get_converted_pytorch_model(model2)(torch.randn(1, 3, 3, 3)).abs().sum().item(), 0) self.assertAlmostEqual(self._get_converted_pytorch_model(model2)(torch.randn(1, 3, 3, 3)).abs().sum().item(), 0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment