"""Patch ``torch.nn`` classes so that instances remember their constructor arguments."""

import functools
import inspect
import logging

import torch
import torch.nn as nn

from nni.retiarii.utils import version_larger_equal

_logger = logging.getLogger(__name__)

# Per-instance marker set while a wrapped constructor further out in the
# ``super().__init__`` chain is the one recording ``_init_parameters``.
_RECORDING_FLAG = '_inject_recording_init_parameters'


def wrap_module(original_class):
    """Patch ``original_class.__init__`` so every instance records its constructor arguments.

    After patching, constructing an instance stores an ``_init_parameters`` dict on it.
    The dict maps parameter names to the values that were explicitly passed: keyword
    arguments first, then positional arguments matched to their parameter names.
    Defaults that were not passed are not recorded.

    The original ``__init__`` is kept on the class as ``bak_init_for_inject`` so that
    :func:`unwrap_module` can restore it.

    Behaviour details:

    * Idempotent: a class that already has its own ``bak_init_for_inject`` is
      returned unchanged. Stacking wrappers would make the original ``__init__``
      unrecoverable.
    * Outermost call wins: when a wrapped class reaches a wrapped base class through
      ``super().__init__`` (e.g. ``nn.ReLU6`` -> ``nn.Hardtanh``), only the outer
      constructor records. The recorded arguments therefore belong to the class that
      was actually instantiated.
    * ``*args`` pass-through: if more positional arguments are passed than the
      signature has named positional parameters, this level does not record. A wrapped
      base class reached via ``super().__init__`` records instead; if there is none,
      ``_init_parameters`` is not set.

    Args:
        original_class: class to patch in place.

    Returns:
        ``original_class`` itself.
    """
    # Look only at the class's own namespace: a subclass of a wrapped class inherits
    # ``bak_init_for_inject`` but still needs a wrapper of its own.
    if 'bak_init_for_inject' in vars(original_class):
        return original_class

    orig_init = original_class.__init__
    positional_names = [
        param.name
        for param in inspect.signature(original_class).parameters.values()
        if param.kind in (inspect.Parameter.POSITIONAL_ONLY,
                          inspect.Parameter.POSITIONAL_OR_KEYWORD)
    ]

    # Make copy of original __init__, so we can call it without recursion and restore it later.
    original_class.bak_init_for_inject = orig_init

    @functools.wraps(orig_init)
    def __init__(self, *args, **kws):
        state = self.__dict__
        # Record only if no outer wrapped constructor is already recording for this
        # instance, and every positional argument can be given a parameter name.
        records_here = (not state.get(_RECORDING_FLAG, False)
                        and len(args) <= len(positional_names))
        if not records_here:
            orig_init(self, *args, **kws)
            return

        full_args = dict(kws)
        full_args.update(zip(positional_names, args))
        self._init_parameters = full_args

        state[_RECORDING_FLAG] = True
        try:
            orig_init(self, *args, **kws)  # Call the original __init__
        finally:
            state.pop(_RECORDING_FLAG, None)

    original_class.__init__ = __init__  # Set the class' __init__ to the new one
    return original_class


def unwrap_module(wrapped_class):
    """Undo :func:`wrap_module` on ``wrapped_class``.

    Restores the ``__init__`` saved in ``bak_init_for_inject`` and deletes that
    attribute. A class that was not wrapped itself is left untouched. This includes a
    subclass that merely inherits ``bak_init_for_inject`` from a wrapped base.

    Args:
        wrapped_class: class to restore in place.

    Returns:
        None.
    """
    if 'bak_init_for_inject' in vars(wrapped_class):
        wrapped_class.__init__ = wrapped_class.bak_init_for_inject
        delattr(wrapped_class, 'bak_init_for_inject')
    return None


def _injected_nn_classes():
    """Return the ``torch.nn`` classes handled by inject/remove, in a fixed order.

    Shared by :func:`inject_pytorch_nn` and :func:`remove_inject_pytorch_nn` so both
    always operate on exactly the same set of classes. Classes that only exist in
    newer torch releases are included only when the installed version has them.
    """
    classes = [
        nn.Identity, nn.Linear,
        nn.Conv1d, nn.Conv2d, nn.Conv3d,
        nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d,
        nn.Threshold, nn.ReLU, nn.Hardtanh, nn.ReLU6, nn.Sigmoid, nn.Tanh,
        nn.Softmax, nn.Softmax2d, nn.LogSoftmax,
        nn.ELU, nn.SELU, nn.CELU, nn.GLU, nn.GELU, nn.Hardshrink, nn.LeakyReLU,
        nn.LogSigmoid, nn.Softplus, nn.Softshrink, nn.MultiheadAttention, nn.PReLU,
        nn.Softsign, nn.Softmin, nn.Tanhshrink, nn.RReLU,
        nn.AvgPool1d, nn.AvgPool2d, nn.AvgPool3d,
        nn.MaxPool1d, nn.MaxPool2d, nn.MaxPool3d,
        nn.MaxUnpool1d, nn.MaxUnpool2d, nn.MaxUnpool3d,
        nn.FractionalMaxPool2d, nn.FractionalMaxPool3d,
        nn.LPPool1d, nn.LPPool2d,
        nn.LocalResponseNorm,
        nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
        nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d,
        nn.LayerNorm, nn.GroupNorm, nn.SyncBatchNorm,
        nn.Dropout, nn.Dropout2d, nn.Dropout3d, nn.AlphaDropout, nn.FeatureAlphaDropout,
        nn.ReflectionPad1d, nn.ReflectionPad2d,
        nn.ReplicationPad2d, nn.ReplicationPad1d, nn.ReplicationPad3d,
        nn.CrossMapLRN2d,
        nn.Embedding, nn.EmbeddingBag,
        nn.RNNBase, nn.RNN, nn.LSTM, nn.GRU,
        nn.RNNCellBase, nn.RNNCell, nn.LSTMCell, nn.GRUCell,
        nn.PixelShuffle,
        nn.Upsample, nn.UpsamplingNearest2d, nn.UpsamplingBilinear2d,
        nn.PairwiseDistance,
        nn.AdaptiveMaxPool1d, nn.AdaptiveMaxPool2d, nn.AdaptiveMaxPool3d,
        nn.AdaptiveAvgPool1d, nn.AdaptiveAvgPool2d, nn.AdaptiveAvgPool3d,
        nn.TripletMarginLoss,
        nn.ZeroPad2d, nn.ConstantPad1d, nn.ConstantPad2d, nn.ConstantPad3d,
        nn.Bilinear, nn.CosineSimilarity,
        nn.Unfold, nn.Fold,
        nn.AdaptiveLogSoftmaxWithLoss,
        nn.TransformerEncoder, nn.TransformerDecoder,
        nn.TransformerEncoderLayer, nn.TransformerDecoderLayer, nn.Transformer,
        nn.Flatten, nn.Hardsigmoid,
    ]
    if version_larger_equal(torch.__version__, '1.6.0'):
        classes.append(nn.Hardswish)
    if version_larger_equal(torch.__version__, '1.7.0'):
        classes.extend([nn.SiLU, nn.Unflatten, nn.TripletMarginWithDistanceLoss])
    return classes


def remove_inject_pytorch_nn():
    """Restore the original ``__init__`` of every class patched by :func:`inject_pytorch_nn`.

    Classes that are not currently wrapped are skipped by :func:`unwrap_module`.
    """
    for module_class in _injected_nn_classes():
        unwrap_module(module_class)


def inject_pytorch_nn():
    """Wrap every supported ``torch.nn`` class so new instances record ``_init_parameters``.

    Calling this more than once is harmless: already-wrapped classes are skipped by
    :func:`wrap_module`.
    """
    for module_class in _injected_nn_classes():
        wrap_module(module_class)