Unverified Commit 8a1fdd53 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Remove PyTorch version larger equal in Retiarii (#4622)

parent 50ab44ba
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .api import * from .api import *
from .component import * from .component import *
from .nn import * from .nn import *
from .hypermodule import * from .hypermodule import *
\ No newline at end of file
...@@ -12,7 +12,7 @@ import torch.nn as nn ...@@ -12,7 +12,7 @@ import torch.nn as nn
from nni.common.serializer import Translatable from nni.common.serializer import Translatable
from nni.retiarii.serializer import basic_unit from nni.retiarii.serializer import basic_unit
from nni.retiarii.utils import STATE_DICT_PY_MAPPING_PARTIAL from nni.retiarii.utils import STATE_DICT_PY_MAPPING_PARTIAL
from .utils import Mutable, generate_new_label, get_fixed_value from .mutation_utils import Mutable, generate_new_label, get_fixed_value
__all__ = ['LayerChoice', 'InputChoice', 'ValueChoice', 'Placeholder', 'ChosenInputs'] __all__ = ['LayerChoice', 'InputChoice', 'ValueChoice', 'Placeholder', 'ChosenInputs']
......
...@@ -10,8 +10,8 @@ import torch ...@@ -10,8 +10,8 @@ import torch
import torch.nn as nn import torch.nn as nn
from .api import ChosenInputs, LayerChoice, InputChoice from .api import ChosenInputs, LayerChoice, InputChoice
from .nn import ModuleList from .nn import ModuleList # pylint: disable=no-name-in-module
from .utils import generate_new_label from .mutation_utils import generate_new_label
class _ListIdentity(nn.Identity): class _ListIdentity(nn.Identity):
......
...@@ -10,7 +10,7 @@ from nni.retiarii.utils import STATE_DICT_PY_MAPPING_PARTIAL ...@@ -10,7 +10,7 @@ from nni.retiarii.utils import STATE_DICT_PY_MAPPING_PARTIAL
from .api import LayerChoice from .api import LayerChoice
from .cell import Cell from .cell import Cell
from .nasbench101 import NasBench101Cell, NasBench101Mutator from .nasbench101 import NasBench101Cell, NasBench101Mutator
from .utils import Mutable, generate_new_label, get_fixed_value from .mutation_utils import Mutable, generate_new_label, get_fixed_value
__all__ = ['Repeat', 'Cell', 'NasBench101Cell', 'NasBench101Mutator', 'NasBench201Cell'] __all__ = ['Repeat', 'Cell', 'NasBench101Cell', 'NasBench101Mutator', 'NasBench201Cell']
......
...@@ -8,7 +8,7 @@ import torch.nn as nn ...@@ -8,7 +8,7 @@ import torch.nn as nn
from nni.retiarii.serializer import basic_unit from nni.retiarii.serializer import basic_unit
from .api import LayerChoice from .api import LayerChoice
from .utils import generate_new_label from .mutation_utils import generate_new_label
__all__ = ['AutoActivation'] __all__ = ['AutoActivation']
......
...@@ -9,7 +9,7 @@ import torch.nn as nn ...@@ -9,7 +9,7 @@ import torch.nn as nn
from nni.retiarii.mutator import InvalidMutation, Mutator from nni.retiarii.mutator import InvalidMutation, Mutator
from nni.retiarii.graph import Model from nni.retiarii.graph import Model
from .api import InputChoice, ValueChoice, LayerChoice from .api import InputChoice, ValueChoice, LayerChoice
from .utils import Mutable, generate_new_label, get_fixed_dict from .mutation_utils import Mutable, generate_new_label, get_fixed_dict
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
......
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
from packaging.version import Version import inspect
import warnings
from pathlib import Path
import torch import torch
import torch.nn as nn import torch.nn as nn
from ...serializer import basic_unit # To make auto-completion happy, we generate a _nn.py that lists out all the classes.
# NOTE: support pytorch version >= 1.5.0 nn_cache_file_path = Path(__file__).parent / '_nn.py'
__all__ = [ cache_valid = False
'Module', 'Sequential', 'ModuleList', # TODO: 'ModuleDict', 'ParameterList', 'ParameterDict',
'Identity', 'Linear', 'Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d', if nn_cache_file_path.exists():
'ConvTranspose2d', 'ConvTranspose3d', 'Threshold', 'ReLU', 'Hardtanh', 'ReLU6', from . import _nn # pylint: disable=no-name-in-module
'Sigmoid', 'Tanh', 'Softmax', 'Softmax2d', 'LogSoftmax', 'ELU', 'SELU', 'CELU', 'GLU', 'GELU', 'Hardshrink', # valid only when torch version match
'LeakyReLU', 'LogSigmoid', 'Softplus', 'Softshrink', 'MultiheadAttention', 'PReLU', 'Softsign', 'Softmin', if _nn._torch_version == torch.__version__:
'Tanhshrink', 'RReLU', 'AvgPool1d', 'AvgPool2d', 'AvgPool3d', 'MaxPool1d', 'MaxPool2d', cache_valid = True
'MaxPool3d', 'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d', 'FractionalMaxPool2d', "FractionalMaxPool3d",
'LPPool1d', 'LPPool2d', 'LocalResponseNorm', 'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'InstanceNorm1d', if not cache_valid:
'InstanceNorm2d', 'InstanceNorm3d', 'LayerNorm', 'GroupNorm', 'SyncBatchNorm', _NO_WRAP_CLASSES = [
'Dropout', 'Dropout2d', 'Dropout3d', 'AlphaDropout', 'FeatureAlphaDropout', # not an nn.Module
'ReflectionPad1d', 'ReflectionPad2d', 'ReplicationPad2d', 'ReplicationPad1d', 'ReplicationPad3d', 'Parameter',
'CrossMapLRN2d', 'Embedding', 'EmbeddingBag', 'RNNBase', 'RNN', 'LSTM', 'GRU', 'RNNCellBase', 'RNNCell', 'ParameterList',
'LSTMCell', 'GRUCell', 'PixelShuffle', 'Upsample', 'UpsamplingNearest2d', 'UpsamplingBilinear2d', 'UninitializedBuffer',
'PairwiseDistance', 'AdaptiveMaxPool1d', 'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d', 'AdaptiveAvgPool1d', 'UninitializedParameter',
'AdaptiveAvgPool2d', 'AdaptiveAvgPool3d', 'TripletMarginLoss', 'ZeroPad2d', 'ConstantPad1d', 'ConstantPad2d',
'ConstantPad3d', 'Bilinear', 'CosineSimilarity', 'Unfold', 'Fold', # arguments are special
'AdaptiveLogSoftmaxWithLoss', 'TransformerEncoder', 'TransformerDecoder', 'Module',
'TransformerEncoderLayer', 'TransformerDecoderLayer', 'Transformer', 'Sequential',
'Flatten', 'Hardsigmoid'
] # utilities
'Container',
if Version(torch.__version__) >= Version('1.6.0'): 'DataParallel',
__all__.append('Hardswish') ]
if Version(torch.__version__) >= Version('1.7.0'): _WRAP_WITHOUT_TAG_CLASSES = [
__all__.extend(['Unflatten', 'SiLU', 'TripletMarginWithDistanceLoss']) # special support on graph engine
'ModuleList',
'ModuleDict',
Module = nn.Module ]
Sequential = nn.Sequential code = [
ModuleList = basic_unit(nn.ModuleList, basic_unit_tag=False) '# This file is auto-generated to make auto-completion work.',
'# When pytorch version does not match, it will get automatically updated.',
Identity = basic_unit(nn.Identity) '# pylint: skip-file',
Linear = basic_unit(nn.Linear) f'_torch_version = "{torch.__version__}"',
Conv1d = basic_unit(nn.Conv1d) 'import torch.nn as nn',
Conv2d = basic_unit(nn.Conv2d) 'from nni.retiarii.serializer import basic_unit',
Conv3d = basic_unit(nn.Conv3d) ]
ConvTranspose1d = basic_unit(nn.ConvTranspose1d)
ConvTranspose2d = basic_unit(nn.ConvTranspose2d) all_names = []
ConvTranspose3d = basic_unit(nn.ConvTranspose3d)
Threshold = basic_unit(nn.Threshold) # Add modules, classes, functions in torch.nn into this module.
ReLU = basic_unit(nn.ReLU) for name, obj in inspect.getmembers(torch.nn):
Hardtanh = basic_unit(nn.Hardtanh) if inspect.isclass(obj):
ReLU6 = basic_unit(nn.ReLU6) if name in _NO_WRAP_CLASSES:
Sigmoid = basic_unit(nn.Sigmoid) code.append(f'{name} = nn.{name}')
Tanh = basic_unit(nn.Tanh) elif not issubclass(obj, nn.Module):
Softmax = basic_unit(nn.Softmax) # It should never go here
Softmax2d = basic_unit(nn.Softmax2d) # We did it to play safe
LogSoftmax = basic_unit(nn.LogSoftmax) warnings.warn(f'{obj} is found to be not a nn.Module, which is unexpected. '
ELU = basic_unit(nn.ELU) 'It means your PyTorch version might not be supported.', RuntimeWarning)
SELU = basic_unit(nn.SELU) code.append(f'{name} = nn.{name}')
CELU = basic_unit(nn.CELU) elif name in _WRAP_WITHOUT_TAG_CLASSES:
GLU = basic_unit(nn.GLU) code.append(f'{name} = basic_unit(nn.{name}, basic_unit_tag=False)')
GELU = basic_unit(nn.GELU) else:
Hardshrink = basic_unit(nn.Hardshrink) code.append(f'{name} = basic_unit(nn.{name})')
LeakyReLU = basic_unit(nn.LeakyReLU)
LogSigmoid = basic_unit(nn.LogSigmoid) all_names.append(name)
Softplus = basic_unit(nn.Softplus)
Softshrink = basic_unit(nn.Softshrink) elif inspect.isfunction(obj) or inspect.ismodule(obj):
MultiheadAttention = basic_unit(nn.MultiheadAttention) code.append(f'{name} = nn.{name}') # no modification
PReLU = basic_unit(nn.PReLU) all_names.append(name)
Softsign = basic_unit(nn.Softsign)
Softmin = basic_unit(nn.Softmin) code.append(f'__all__ = {all_names}')
Tanhshrink = basic_unit(nn.Tanhshrink)
RReLU = basic_unit(nn.RReLU) with nn_cache_file_path.open('w') as fp:
AvgPool1d = basic_unit(nn.AvgPool1d) fp.write('\n'.join(code))
AvgPool2d = basic_unit(nn.AvgPool2d)
AvgPool3d = basic_unit(nn.AvgPool3d)
MaxPool1d = basic_unit(nn.MaxPool1d) # Import all modules from generated _nn.py
MaxPool2d = basic_unit(nn.MaxPool2d)
MaxPool3d = basic_unit(nn.MaxPool3d) from . import _nn # pylint: disable=no-name-in-module
MaxUnpool1d = basic_unit(nn.MaxUnpool1d) __all__ = _nn.__all__
MaxUnpool2d = basic_unit(nn.MaxUnpool2d) from ._nn import * # pylint: disable=import-error, wildcard-import
MaxUnpool3d = basic_unit(nn.MaxUnpool3d)
FractionalMaxPool2d = basic_unit(nn.FractionalMaxPool2d)
FractionalMaxPool3d = basic_unit(nn.FractionalMaxPool3d)
LPPool1d = basic_unit(nn.LPPool1d)
LPPool2d = basic_unit(nn.LPPool2d)
LocalResponseNorm = basic_unit(nn.LocalResponseNorm)
BatchNorm1d = basic_unit(nn.BatchNorm1d)
BatchNorm2d = basic_unit(nn.BatchNorm2d)
BatchNorm3d = basic_unit(nn.BatchNorm3d)
InstanceNorm1d = basic_unit(nn.InstanceNorm1d)
InstanceNorm2d = basic_unit(nn.InstanceNorm2d)
InstanceNorm3d = basic_unit(nn.InstanceNorm3d)
LayerNorm = basic_unit(nn.LayerNorm)
GroupNorm = basic_unit(nn.GroupNorm)
SyncBatchNorm = basic_unit(nn.SyncBatchNorm)
Dropout = basic_unit(nn.Dropout)
Dropout2d = basic_unit(nn.Dropout2d)
Dropout3d = basic_unit(nn.Dropout3d)
AlphaDropout = basic_unit(nn.AlphaDropout)
FeatureAlphaDropout = basic_unit(nn.FeatureAlphaDropout)
ReflectionPad1d = basic_unit(nn.ReflectionPad1d)
ReflectionPad2d = basic_unit(nn.ReflectionPad2d)
ReplicationPad2d = basic_unit(nn.ReplicationPad2d)
ReplicationPad1d = basic_unit(nn.ReplicationPad1d)
ReplicationPad3d = basic_unit(nn.ReplicationPad3d)
CrossMapLRN2d = basic_unit(nn.CrossMapLRN2d)
Embedding = basic_unit(nn.Embedding)
EmbeddingBag = basic_unit(nn.EmbeddingBag)
RNNBase = basic_unit(nn.RNNBase)
RNN = basic_unit(nn.RNN)
LSTM = basic_unit(nn.LSTM)
GRU = basic_unit(nn.GRU)
RNNCellBase = basic_unit(nn.RNNCellBase)
RNNCell = basic_unit(nn.RNNCell)
LSTMCell = basic_unit(nn.LSTMCell)
GRUCell = basic_unit(nn.GRUCell)
PixelShuffle = basic_unit(nn.PixelShuffle)
Upsample = basic_unit(nn.Upsample)
UpsamplingNearest2d = basic_unit(nn.UpsamplingNearest2d)
UpsamplingBilinear2d = basic_unit(nn.UpsamplingBilinear2d)
PairwiseDistance = basic_unit(nn.PairwiseDistance)
AdaptiveMaxPool1d = basic_unit(nn.AdaptiveMaxPool1d)
AdaptiveMaxPool2d = basic_unit(nn.AdaptiveMaxPool2d)
AdaptiveMaxPool3d = basic_unit(nn.AdaptiveMaxPool3d)
AdaptiveAvgPool1d = basic_unit(nn.AdaptiveAvgPool1d)
AdaptiveAvgPool2d = basic_unit(nn.AdaptiveAvgPool2d)
AdaptiveAvgPool3d = basic_unit(nn.AdaptiveAvgPool3d)
TripletMarginLoss = basic_unit(nn.TripletMarginLoss)
ZeroPad2d = basic_unit(nn.ZeroPad2d)
ConstantPad1d = basic_unit(nn.ConstantPad1d)
ConstantPad2d = basic_unit(nn.ConstantPad2d)
ConstantPad3d = basic_unit(nn.ConstantPad3d)
Bilinear = basic_unit(nn.Bilinear)
CosineSimilarity = basic_unit(nn.CosineSimilarity)
Unfold = basic_unit(nn.Unfold)
Fold = basic_unit(nn.Fold)
AdaptiveLogSoftmaxWithLoss = basic_unit(nn.AdaptiveLogSoftmaxWithLoss)
TransformerEncoder = basic_unit(nn.TransformerEncoder)
TransformerDecoder = basic_unit(nn.TransformerDecoder)
TransformerEncoderLayer = basic_unit(nn.TransformerEncoderLayer)
TransformerDecoderLayer = basic_unit(nn.TransformerDecoderLayer)
Transformer = basic_unit(nn.Transformer)
Flatten = basic_unit(nn.Flatten)
Hardsigmoid = basic_unit(nn.Hardsigmoid)
if Version(torch.__version__) >= Version('1.6.0'):
Hardswish = basic_unit(nn.Hardswish)
if Version(torch.__version__) >= Version('1.7.0'):
SiLU = basic_unit(nn.SiLU)
Unflatten = basic_unit(nn.Unflatten)
TripletMarginWithDistanceLoss = basic_unit(nn.TripletMarginWithDistanceLoss)
...@@ -170,7 +170,13 @@ def _torchscript_patch(cls) -> None: ...@@ -170,7 +170,13 @@ def _torchscript_patch(cls) -> None:
cls._get_nni_attr = torch.jit.ignore(cls._get_nni_attr) cls._get_nni_attr = torch.jit.ignore(cls._get_nni_attr)
if hasattr(cls, 'trace_symbol'): if hasattr(cls, 'trace_symbol'):
# these must all exist or all non-exist # these must all exist or all non-exist
cls.trace_symbol = torch.jit.unused(cls.trace_symbol) try:
cls.trace_args = torch.jit.unused(cls.trace_args) cls.trace_symbol = torch.jit.unused(cls.trace_symbol)
cls.trace_kwargs = torch.jit.unused(cls.trace_kwargs) cls.trace_args = torch.jit.unused(cls.trace_args)
cls.trace_copy = torch.jit.ignore(cls.trace_copy) cls.trace_kwargs = torch.jit.unused(cls.trace_kwargs)
cls.trace_copy = torch.jit.ignore(cls.trace_copy)
except AttributeError as e:
if 'property' in str(e):
raise RuntimeError('Trace on PyTorch module failed. Your PyTorch version might be outdated. '
'Please try to upgrade PyTorch.')
raise
...@@ -1001,3 +1001,10 @@ class Shared(unittest.TestCase): ...@@ -1001,3 +1001,10 @@ class Shared(unittest.TestCase):
for _ in range(10): for _ in range(10):
model = _apply_all_mutators(init_model, mutators, sampler) model = _apply_all_mutators(init_model, mutators, sampler)
assert (model.evaluator.trace_kwargs['x'], model.evaluator.trace_kwargs['y']) in [(1, 2), (3, 4)] assert (model.evaluator.trace_kwargs['x'], model.evaluator.trace_kwargs['y']) in [(1, 2), (3, 4)]
def test_retiarii_nn_import(self):
dummy = torch.zeros(1, 16, 32, 24)
nn.init.uniform_(dummy)
conv = nn.Conv2d(1, 3, 1)
param = nn.Parameter(torch.zeros(1, 3, 24, 24))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment