import inspect
import logging

import torch
import torch.nn as nn

from nni.retiarii.utils import version_larger_equal

_logger = logging.getLogger(__name__)

def wrap_module(original_class):
    """Patch ``original_class.__init__`` so every new instance records its constructor arguments.

    After wrapping, each instance gets an ``_init_parameters`` dict. It holds the keyword
    arguments it was constructed with, plus the positional arguments keyed by the matching
    parameter names from ``inspect.signature(original_class)``. The original ``__init__`` is
    kept in ``original_class.bak_init_for_inject`` so ``unwrap_module`` can restore it.

    Wrapping is idempotent. A class that already holds its own ``bak_init_for_inject`` is
    returned unchanged. Wrapping it again would back up the wrapper instead of the real
    ``__init__``, and would take argument names from the wrapper's ``(*args, **kws)`` signature.

    When a class and one of its base classes are both wrapped, the outermost constructor call
    (the class actually being instantiated) wins. The base wrapper reached through
    ``super().__init__(...)`` does not overwrite the record. If the outermost signature cannot
    name every positional argument (e.g. ``__init__(self, *args, **kwargs)``), recording is
    left to a wrapped base class. If no wrapper records, the positional arguments are paired
    with the signature's parameter names in order after construction; surplus ones are dropped.

    Args:
        original_class: the class to patch in place.

    Returns:
        ``original_class`` itself.
    """
    # Look only at the class's own namespace: a subclass of a wrapped class inherits the
    # attribute through the MRO but still needs a wrapper of its own.
    if 'bak_init_for_inject' in vars(original_class):
        return original_class

    orig_init = original_class.__init__
    signature = inspect.signature(original_class)
    argname_list = list(signature.parameters.keys())
    # Positional parameters always precede the others in a signature, so the first
    # ``num_positional`` names are exactly those a positional argument can bind to.
    num_positional = sum(
        1 for param in signature.parameters.values()
        if param.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
    )
    # Make copy of original __init__, so we can call it without recursion
    original_class.bak_init_for_inject = orig_init

    def _named_arguments(args, kws):
        # Keyword arguments first, then positional ones keyed by parameter name.
        full_args = dict(kws)
        full_args.update(zip(argname_list, args))
        return full_args

    def __init__(self, *args, **kws):
        # On a freshly created instance, an entry here can only have been placed by an outer
        # wrapper earlier in this same constructor chain; that record takes precedence.
        if '_init_parameters' not in self.__dict__ and len(args) <= num_positional:
            self._init_parameters = _named_arguments(args, kws)

        orig_init(self, *args, **kws)  # Call the original __init__

        if '_init_parameters' not in self.__dict__:
            # No wrapper in the chain could name every positional argument; keep a best-effort record.
            self._init_parameters = _named_arguments(args, kws)

    original_class.__init__ = __init__  # Set the class' __init__ to the new one
    return original_class

def unwrap_module(wrapped_class):
    """Restore the ``__init__`` that ``wrap_module`` replaced on ``wrapped_class``.

    Only a class that was itself wrapped, i.e. holds ``bak_init_for_inject`` in its own
    namespace, is touched. A plain ``hasattr`` check would also see the attribute inherited
    from a wrapped base class. It would then overwrite the subclass's ``__init__`` with the
    base's original one and fail in ``delattr``. Classes that are not wrapped are left as they
    are, so calling this repeatedly is harmless.

    Args:
        wrapped_class: the class to restore in place.

    Returns:
        None
    """
    if 'bak_init_for_inject' not in vars(wrapped_class):
        return None
    wrapped_class.__init__ = wrapped_class.bak_init_for_inject
    delattr(wrapped_class, 'bak_init_for_inject')
    return None

# ``torch.nn`` classes (by attribute name) whose constructors ``inject_pytorch_nn`` wraps and
# ``remove_inject_pytorch_nn`` restores, in registration order. Version-gated classes are
# appended by ``_injectable_nn_classes``.
_INJECTED_NN_CLASS_NAMES = (
    'Identity', 'Linear',
    'Conv1d', 'Conv2d', 'Conv3d',
    'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d',
    'Threshold', 'ReLU', 'Hardtanh', 'ReLU6', 'Sigmoid', 'Tanh',
    'Softmax', 'Softmax2d', 'LogSoftmax',
    'ELU', 'SELU', 'CELU', 'GLU', 'GELU', 'Hardshrink', 'LeakyReLU', 'LogSigmoid',
    'Softplus', 'Softshrink', 'MultiheadAttention', 'PReLU', 'Softsign', 'Softmin',
    'Tanhshrink', 'RReLU',
    'AvgPool1d', 'AvgPool2d', 'AvgPool3d',
    'MaxPool1d', 'MaxPool2d', 'MaxPool3d',
    'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d',
    'FractionalMaxPool2d', 'FractionalMaxPool3d',
    'LPPool1d', 'LPPool2d',
    'LocalResponseNorm',
    'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d',
    'InstanceNorm1d', 'InstanceNorm2d', 'InstanceNorm3d',
    'LayerNorm', 'GroupNorm', 'SyncBatchNorm',
    'Dropout', 'Dropout2d', 'Dropout3d', 'AlphaDropout', 'FeatureAlphaDropout',
    'ReflectionPad1d', 'ReflectionPad2d',
    'ReplicationPad2d', 'ReplicationPad1d', 'ReplicationPad3d',
    'CrossMapLRN2d',
    'Embedding', 'EmbeddingBag',
    'RNNBase', 'RNN', 'LSTM', 'GRU',
    'RNNCellBase', 'RNNCell', 'LSTMCell', 'GRUCell',
    'PixelShuffle',
    'Upsample', 'UpsamplingNearest2d', 'UpsamplingBilinear2d',
    'PairwiseDistance',
    'AdaptiveMaxPool1d', 'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d',
    'AdaptiveAvgPool1d', 'AdaptiveAvgPool2d', 'AdaptiveAvgPool3d',
    'TripletMarginLoss',
    'ZeroPad2d', 'ConstantPad1d', 'ConstantPad2d', 'ConstantPad3d',
    'Bilinear', 'CosineSimilarity',
    'Unfold', 'Fold',
    'AdaptiveLogSoftmaxWithLoss',
    'TransformerEncoder', 'TransformerDecoder',
    'TransformerEncoderLayer', 'TransformerDecoderLayer', 'Transformer',
    'Flatten', 'Hardsigmoid',
)


def _injectable_nn_classes():
    """Return the ``torch.nn`` classes to (un)wrap for the installed PyTorch version.

    ``Hardswish`` is added from torch 1.6.0 on. ``SiLU``, ``Unflatten`` and
    ``TripletMarginWithDistanceLoss`` are added from torch 1.7.0 on. Every attribute is
    resolved before any class is patched, so a missing class raises ``AttributeError``
    without leaving ``torch.nn`` partially patched.
    """
    names = list(_INJECTED_NN_CLASS_NAMES)
    if version_larger_equal(torch.__version__, '1.6.0'):
        names.append('Hardswish')
    if version_larger_equal(torch.__version__, '1.7.0'):
        names.extend(['SiLU', 'Unflatten', 'TripletMarginWithDistanceLoss'])
    return [getattr(nn, name) for name in names]


def remove_inject_pytorch_nn():
    """Undo ``inject_pytorch_nn``: restore the original constructor of every injectable class.

    Each class is passed to ``unwrap_module``, which restores ``__init__`` from the backup
    that ``wrap_module`` stored.
    """
    for module_class in _injectable_nn_classes():
        unwrap_module(module_class)


def inject_pytorch_nn():
    """Wrap the constructor of every injectable ``torch.nn`` class with ``wrap_module``.

    Afterwards, new instances of these classes record their constructor arguments in
    ``_init_parameters``. ``remove_inject_pytorch_nn`` reverses the patch.
    """
    for module_class in _injectable_nn_classes():
        wrap_module(module_class)

