Unverified Commit 8a1fdd53 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Remove PyTorch version larger equal in Retiarii (#4622)

parent 50ab44ba
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .api import *
from .component import *
from .nn import *
from .hypermodule import *
\ No newline at end of file
from .hypermodule import *
......@@ -12,7 +12,7 @@ import torch.nn as nn
from nni.common.serializer import Translatable
from nni.retiarii.serializer import basic_unit
from nni.retiarii.utils import STATE_DICT_PY_MAPPING_PARTIAL
from .utils import Mutable, generate_new_label, get_fixed_value
from .mutation_utils import Mutable, generate_new_label, get_fixed_value
__all__ = ['LayerChoice', 'InputChoice', 'ValueChoice', 'Placeholder', 'ChosenInputs']
......
......@@ -10,8 +10,8 @@ import torch
import torch.nn as nn
from .api import ChosenInputs, LayerChoice, InputChoice
from .nn import ModuleList
from .utils import generate_new_label
from .nn import ModuleList # pylint: disable=no-name-in-module
from .mutation_utils import generate_new_label
class _ListIdentity(nn.Identity):
......
......@@ -10,7 +10,7 @@ from nni.retiarii.utils import STATE_DICT_PY_MAPPING_PARTIAL
from .api import LayerChoice
from .cell import Cell
from .nasbench101 import NasBench101Cell, NasBench101Mutator
from .utils import Mutable, generate_new_label, get_fixed_value
from .mutation_utils import Mutable, generate_new_label, get_fixed_value
__all__ = ['Repeat', 'Cell', 'NasBench101Cell', 'NasBench101Mutator', 'NasBench201Cell']
......
......@@ -8,7 +8,7 @@ import torch.nn as nn
from nni.retiarii.serializer import basic_unit
from .api import LayerChoice
from .utils import generate_new_label
from .mutation_utils import generate_new_label
__all__ = ['AutoActivation']
......
......@@ -9,7 +9,7 @@ import torch.nn as nn
from nni.retiarii.mutator import InvalidMutation, Mutator
from nni.retiarii.graph import Model
from .api import InputChoice, ValueChoice, LayerChoice
from .utils import Mutable, generate_new_label, get_fixed_dict
from .mutation_utils import Mutable, generate_new_label, get_fixed_dict
_logger = logging.getLogger(__name__)
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from packaging.version import Version
import inspect
import warnings
from pathlib import Path
import torch
import torch.nn as nn
from ...serializer import basic_unit
# NOTE: support pytorch version >= 1.5.0
__all__ = [
'Module', 'Sequential', 'ModuleList', # TODO: 'ModuleDict', 'ParameterList', 'ParameterDict',
'Identity', 'Linear', 'Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d',
'ConvTranspose2d', 'ConvTranspose3d', 'Threshold', 'ReLU', 'Hardtanh', 'ReLU6',
'Sigmoid', 'Tanh', 'Softmax', 'Softmax2d', 'LogSoftmax', 'ELU', 'SELU', 'CELU', 'GLU', 'GELU', 'Hardshrink',
'LeakyReLU', 'LogSigmoid', 'Softplus', 'Softshrink', 'MultiheadAttention', 'PReLU', 'Softsign', 'Softmin',
'Tanhshrink', 'RReLU', 'AvgPool1d', 'AvgPool2d', 'AvgPool3d', 'MaxPool1d', 'MaxPool2d',
'MaxPool3d', 'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d', 'FractionalMaxPool2d', "FractionalMaxPool3d",
'LPPool1d', 'LPPool2d', 'LocalResponseNorm', 'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'InstanceNorm1d',
'InstanceNorm2d', 'InstanceNorm3d', 'LayerNorm', 'GroupNorm', 'SyncBatchNorm',
'Dropout', 'Dropout2d', 'Dropout3d', 'AlphaDropout', 'FeatureAlphaDropout',
'ReflectionPad1d', 'ReflectionPad2d', 'ReplicationPad2d', 'ReplicationPad1d', 'ReplicationPad3d',
'CrossMapLRN2d', 'Embedding', 'EmbeddingBag', 'RNNBase', 'RNN', 'LSTM', 'GRU', 'RNNCellBase', 'RNNCell',
'LSTMCell', 'GRUCell', 'PixelShuffle', 'Upsample', 'UpsamplingNearest2d', 'UpsamplingBilinear2d',
'PairwiseDistance', 'AdaptiveMaxPool1d', 'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d', 'AdaptiveAvgPool1d',
'AdaptiveAvgPool2d', 'AdaptiveAvgPool3d', 'TripletMarginLoss', 'ZeroPad2d', 'ConstantPad1d', 'ConstantPad2d',
'ConstantPad3d', 'Bilinear', 'CosineSimilarity', 'Unfold', 'Fold',
'AdaptiveLogSoftmaxWithLoss', 'TransformerEncoder', 'TransformerDecoder',
'TransformerEncoderLayer', 'TransformerDecoderLayer', 'Transformer',
'Flatten', 'Hardsigmoid'
]
if Version(torch.__version__) >= Version('1.6.0'):
__all__.append('Hardswish')
if Version(torch.__version__) >= Version('1.7.0'):
__all__.extend(['Unflatten', 'SiLU', 'TripletMarginWithDistanceLoss'])
Module = nn.Module
Sequential = nn.Sequential
ModuleList = basic_unit(nn.ModuleList, basic_unit_tag=False)
Identity = basic_unit(nn.Identity)
Linear = basic_unit(nn.Linear)
Conv1d = basic_unit(nn.Conv1d)
Conv2d = basic_unit(nn.Conv2d)
Conv3d = basic_unit(nn.Conv3d)
ConvTranspose1d = basic_unit(nn.ConvTranspose1d)
ConvTranspose2d = basic_unit(nn.ConvTranspose2d)
ConvTranspose3d = basic_unit(nn.ConvTranspose3d)
Threshold = basic_unit(nn.Threshold)
ReLU = basic_unit(nn.ReLU)
Hardtanh = basic_unit(nn.Hardtanh)
ReLU6 = basic_unit(nn.ReLU6)
Sigmoid = basic_unit(nn.Sigmoid)
Tanh = basic_unit(nn.Tanh)
Softmax = basic_unit(nn.Softmax)
Softmax2d = basic_unit(nn.Softmax2d)
LogSoftmax = basic_unit(nn.LogSoftmax)
ELU = basic_unit(nn.ELU)
SELU = basic_unit(nn.SELU)
CELU = basic_unit(nn.CELU)
GLU = basic_unit(nn.GLU)
GELU = basic_unit(nn.GELU)
Hardshrink = basic_unit(nn.Hardshrink)
LeakyReLU = basic_unit(nn.LeakyReLU)
LogSigmoid = basic_unit(nn.LogSigmoid)
Softplus = basic_unit(nn.Softplus)
Softshrink = basic_unit(nn.Softshrink)
MultiheadAttention = basic_unit(nn.MultiheadAttention)
PReLU = basic_unit(nn.PReLU)
Softsign = basic_unit(nn.Softsign)
Softmin = basic_unit(nn.Softmin)
Tanhshrink = basic_unit(nn.Tanhshrink)
RReLU = basic_unit(nn.RReLU)
AvgPool1d = basic_unit(nn.AvgPool1d)
AvgPool2d = basic_unit(nn.AvgPool2d)
AvgPool3d = basic_unit(nn.AvgPool3d)
MaxPool1d = basic_unit(nn.MaxPool1d)
MaxPool2d = basic_unit(nn.MaxPool2d)
MaxPool3d = basic_unit(nn.MaxPool3d)
MaxUnpool1d = basic_unit(nn.MaxUnpool1d)
MaxUnpool2d = basic_unit(nn.MaxUnpool2d)
MaxUnpool3d = basic_unit(nn.MaxUnpool3d)
FractionalMaxPool2d = basic_unit(nn.FractionalMaxPool2d)
FractionalMaxPool3d = basic_unit(nn.FractionalMaxPool3d)
LPPool1d = basic_unit(nn.LPPool1d)
LPPool2d = basic_unit(nn.LPPool2d)
LocalResponseNorm = basic_unit(nn.LocalResponseNorm)
BatchNorm1d = basic_unit(nn.BatchNorm1d)
BatchNorm2d = basic_unit(nn.BatchNorm2d)
BatchNorm3d = basic_unit(nn.BatchNorm3d)
InstanceNorm1d = basic_unit(nn.InstanceNorm1d)
InstanceNorm2d = basic_unit(nn.InstanceNorm2d)
InstanceNorm3d = basic_unit(nn.InstanceNorm3d)
LayerNorm = basic_unit(nn.LayerNorm)
GroupNorm = basic_unit(nn.GroupNorm)
SyncBatchNorm = basic_unit(nn.SyncBatchNorm)
Dropout = basic_unit(nn.Dropout)
Dropout2d = basic_unit(nn.Dropout2d)
Dropout3d = basic_unit(nn.Dropout3d)
AlphaDropout = basic_unit(nn.AlphaDropout)
FeatureAlphaDropout = basic_unit(nn.FeatureAlphaDropout)
ReflectionPad1d = basic_unit(nn.ReflectionPad1d)
ReflectionPad2d = basic_unit(nn.ReflectionPad2d)
ReplicationPad2d = basic_unit(nn.ReplicationPad2d)
ReplicationPad1d = basic_unit(nn.ReplicationPad1d)
ReplicationPad3d = basic_unit(nn.ReplicationPad3d)
CrossMapLRN2d = basic_unit(nn.CrossMapLRN2d)
Embedding = basic_unit(nn.Embedding)
EmbeddingBag = basic_unit(nn.EmbeddingBag)
RNNBase = basic_unit(nn.RNNBase)
RNN = basic_unit(nn.RNN)
LSTM = basic_unit(nn.LSTM)
GRU = basic_unit(nn.GRU)
RNNCellBase = basic_unit(nn.RNNCellBase)
RNNCell = basic_unit(nn.RNNCell)
LSTMCell = basic_unit(nn.LSTMCell)
GRUCell = basic_unit(nn.GRUCell)
PixelShuffle = basic_unit(nn.PixelShuffle)
Upsample = basic_unit(nn.Upsample)
UpsamplingNearest2d = basic_unit(nn.UpsamplingNearest2d)
UpsamplingBilinear2d = basic_unit(nn.UpsamplingBilinear2d)
PairwiseDistance = basic_unit(nn.PairwiseDistance)
AdaptiveMaxPool1d = basic_unit(nn.AdaptiveMaxPool1d)
AdaptiveMaxPool2d = basic_unit(nn.AdaptiveMaxPool2d)
AdaptiveMaxPool3d = basic_unit(nn.AdaptiveMaxPool3d)
AdaptiveAvgPool1d = basic_unit(nn.AdaptiveAvgPool1d)
AdaptiveAvgPool2d = basic_unit(nn.AdaptiveAvgPool2d)
AdaptiveAvgPool3d = basic_unit(nn.AdaptiveAvgPool3d)
TripletMarginLoss = basic_unit(nn.TripletMarginLoss)
ZeroPad2d = basic_unit(nn.ZeroPad2d)
ConstantPad1d = basic_unit(nn.ConstantPad1d)
ConstantPad2d = basic_unit(nn.ConstantPad2d)
ConstantPad3d = basic_unit(nn.ConstantPad3d)
Bilinear = basic_unit(nn.Bilinear)
CosineSimilarity = basic_unit(nn.CosineSimilarity)
Unfold = basic_unit(nn.Unfold)
Fold = basic_unit(nn.Fold)
AdaptiveLogSoftmaxWithLoss = basic_unit(nn.AdaptiveLogSoftmaxWithLoss)
TransformerEncoder = basic_unit(nn.TransformerEncoder)
TransformerDecoder = basic_unit(nn.TransformerDecoder)
TransformerEncoderLayer = basic_unit(nn.TransformerEncoderLayer)
TransformerDecoderLayer = basic_unit(nn.TransformerDecoderLayer)
Transformer = basic_unit(nn.Transformer)
Flatten = basic_unit(nn.Flatten)
Hardsigmoid = basic_unit(nn.Hardsigmoid)
if Version(torch.__version__) >= Version('1.6.0'):
Hardswish = basic_unit(nn.Hardswish)
if Version(torch.__version__) >= Version('1.7.0'):
SiLU = basic_unit(nn.SiLU)
Unflatten = basic_unit(nn.Unflatten)
TripletMarginWithDistanceLoss = basic_unit(nn.TripletMarginWithDistanceLoss)
# To make auto-completion happy, we generate a _nn.py that lists out all the classes.
nn_cache_file_path = Path(__file__).parent / '_nn.py'
cache_valid = False
if nn_cache_file_path.exists():
from . import _nn # pylint: disable=no-name-in-module
# valid only when torch version match
if _nn._torch_version == torch.__version__:
cache_valid = True
if not cache_valid:
_NO_WRAP_CLASSES = [
# not an nn.Module
'Parameter',
'ParameterList',
'UninitializedBuffer',
'UninitializedParameter',
# arguments are special
'Module',
'Sequential',
# utilities
'Container',
'DataParallel',
]
_WRAP_WITHOUT_TAG_CLASSES = [
# special support on graph engine
'ModuleList',
'ModuleDict',
]
code = [
'# This file is auto-generated to make auto-completion work.',
'# When pytorch version does not match, it will get automatically updated.',
'# pylint: skip-file',
f'_torch_version = "{torch.__version__}"',
'import torch.nn as nn',
'from nni.retiarii.serializer import basic_unit',
]
all_names = []
# Add modules, classes, functions in torch.nn into this module.
for name, obj in inspect.getmembers(torch.nn):
if inspect.isclass(obj):
if name in _NO_WRAP_CLASSES:
code.append(f'{name} = nn.{name}')
elif not issubclass(obj, nn.Module):
# It should never go here
# We did it to play safe
warnings.warn(f'{obj} is found to be not a nn.Module, which is unexpected. '
'It means your PyTorch version might not be supported.', RuntimeWarning)
code.append(f'{name} = nn.{name}')
elif name in _WRAP_WITHOUT_TAG_CLASSES:
code.append(f'{name} = basic_unit(nn.{name}, basic_unit_tag=False)')
else:
code.append(f'{name} = basic_unit(nn.{name})')
all_names.append(name)
elif inspect.isfunction(obj) or inspect.ismodule(obj):
code.append(f'{name} = nn.{name}') # no modification
all_names.append(name)
code.append(f'__all__ = {all_names}')
with nn_cache_file_path.open('w') as fp:
fp.write('\n'.join(code))
# Import all modules from generated _nn.py
from . import _nn # pylint: disable=no-name-in-module
__all__ = _nn.__all__
from ._nn import * # pylint: disable=import-error, wildcard-import
......@@ -170,7 +170,13 @@ def _torchscript_patch(cls) -> None:
cls._get_nni_attr = torch.jit.ignore(cls._get_nni_attr)
if hasattr(cls, 'trace_symbol'):
# these must all exist or all non-exist
cls.trace_symbol = torch.jit.unused(cls.trace_symbol)
cls.trace_args = torch.jit.unused(cls.trace_args)
cls.trace_kwargs = torch.jit.unused(cls.trace_kwargs)
cls.trace_copy = torch.jit.ignore(cls.trace_copy)
try:
cls.trace_symbol = torch.jit.unused(cls.trace_symbol)
cls.trace_args = torch.jit.unused(cls.trace_args)
cls.trace_kwargs = torch.jit.unused(cls.trace_kwargs)
cls.trace_copy = torch.jit.ignore(cls.trace_copy)
except AttributeError as e:
if 'property' in str(e):
raise RuntimeError('Trace on PyTorch module failed. Your PyTorch version might be outdated. '
'Please try to upgrade PyTorch.')
raise
......@@ -1001,3 +1001,10 @@ class Shared(unittest.TestCase):
for _ in range(10):
model = _apply_all_mutators(init_model, mutators, sampler)
assert (model.evaluator.trace_kwargs['x'], model.evaluator.trace_kwargs['y']) in [(1, 2), (3, 4)]
def test_retiarii_nn_import(self):
dummy = torch.zeros(1, 16, 32, 24)
nn.init.uniform_(dummy)
conv = nn.Conv2d(1, 3, 1)
param = nn.Parameter(torch.zeros(1, 3, 24, 24))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment