# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import math
import sys
import unittest
from unittest import TestCase, main
from nni.algorithms.compression.pytorch.pruning import LevelPruner, SlimPruner, FPGMPruner, L1FilterPruner, \
    L2FilterPruner, AGPPruner, ActivationMeanRankFilterPruner, ActivationAPoZRankFilterPruner, \
    TaylorFOWeightFilterPruner, NetAdaptPruner, SimulatedAnnealingPruner, ADMMPruner, \
    AutoCompressPruner, AMCPruner

sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))))
from ut.sdk.models.pytorch_models.mobilenet import MobileNet

def validate_sparsity(wrapper, sparsity, bias=False):
    """Assert that a pruned module wrapper's masks are close to ``sparsity``.

    The weight mask is always checked. The bias mask is also checked when
    ``bias`` is truthy and ``wrapper.bias_mask`` is not ``None``. A mask passes
    when its fraction of zero entries is within 0.1 of ``sparsity``.

    Raises:
        AssertionError: if any checked mask is outside the tolerance.
    """
    def _check(mask):
        zero_count = (mask == 0).sum().item()
        achieved = zero_count / mask.numel()
        message = 'actual sparsity: {:.2f}, target sparsity: {:.2f}'.format(achieved, sparsity)
        assert math.isclose(achieved, sparsity, abs_tol=0.1), message

    checked = [wrapper.weight_mask]
    with_bias = bias and wrapper.bias_mask is not None
    if with_bias:
        checked = checked + [wrapper.bias_mask]
    for mask in checked:
        _check(mask)

# Per-pruner test settings, keyed by the pruner names that pruners_test() accepts.
#
# Every entry has 'pruner_class' and 'config_list'. Entries exercised by pruners_test()
# also need 'validators', a list of callables. Each is called with the compressed Model
# and asserts on its masks; an empty list skips mask checks. The remaining keys are
# pruner-specific:
#   'trainer'               - pruners_test() reads this only for the autocompress_* entries;
#                             agp/slim/taylorfo/mean_activation/apoz/admm are instead given
#                             pruners_test()'s own trainer closure.
#   'evaluator'             - constant-score stub (always 0.9) for netadapt,
#                             simulatedannealing and autocompress_*.
#   'short_term_fine_tuner' - netadapt only; identity stub.
#   'base_algo'             - autocompress_* only.
#   'dummy_input'           - NOTE(review): pruners_test() passes its own (2, 1, 28, 28)
#                             dummy input instead of these; they are still sampled with
#                             torch.randn when this module is imported.
# 'amc' has no 'validators' key and is read only by PrunerTestCase.testAMC.
# NOTE: PrunerTestCase.test_agp_pruner writes to prune_config['agp']['config_list'][0],
# so this dict is shared mutable state across tests.
prune_config = {
    'level': {
        'pruner_class': LevelPruner,
        'config_list': [{
            'sparsity': 0.5,
            'op_types': ['default'],
        }],
        'validators': [
            lambda model: validate_sparsity(model.conv1, 0.5, False),
            lambda model: validate_sparsity(model.fc, 0.5, False)
        ]
    },
    'agp': {
        'pruner_class': AGPPruner,
        'config_list': [{
            'sparsity': 0.8,
            'op_types': ['Conv2d']
        }],
        'trainer': lambda model, optimizer, criterion, epoch: model,
        'validators': []
    },
    'slim': {
        'pruner_class': SlimPruner,
        'config_list': [{
            'sparsity': 0.7,
            'op_types': ['BatchNorm2d']
        }],
        'trainer': lambda model, optimizer, criterion, epoch: model,
        'validators': [
            lambda model: validate_sparsity(model.bn1, 0.7, model.bias)
        ]
    },
    'fpgm': {
        'pruner_class': FPGMPruner,
        'config_list':[{
            'sparsity': 0.5,
            'op_types': ['Conv2d']
        }],
        'validators': [
            lambda model: validate_sparsity(model.conv1, 0.5, model.bias)
        ]
    },
    'l1': {
        'pruner_class': L1FilterPruner,
        'config_list': [{
            'sparsity': 0.5,
            'op_types': ['Conv2d'],
        }],
        'validators': [
            lambda model: validate_sparsity(model.conv1, 0.5, model.bias)
        ]
    },
    'l2': {
        'pruner_class': L2FilterPruner,
        'config_list': [{
            'sparsity': 0.5,
            'op_types': ['Conv2d'],
        }],
        'validators': [
            lambda model: validate_sparsity(model.conv1, 0.5, model.bias)
        ]
    },
    'taylorfo': {
        'pruner_class': TaylorFOWeightFilterPruner,
        'config_list': [{
            'sparsity': 0.5,
            'op_types': ['Conv2d'],
        }],
        'trainer': lambda model, optimizer, criterion, epoch: model,
        'validators': [
            lambda model: validate_sparsity(model.conv1, 0.5, model.bias)
        ]
    },
    'mean_activation': {
        'pruner_class': ActivationMeanRankFilterPruner,
        'config_list': [{
            'sparsity': 0.5,
            'op_types': ['Conv2d'],
        }],
        'trainer': lambda model, optimizer, criterion, epoch: model,
        'validators': [
            lambda model: validate_sparsity(model.conv1, 0.5, model.bias)
        ]
    },
    'apoz': {
        'pruner_class': ActivationAPoZRankFilterPruner,
        'config_list': [{
            'sparsity': 0.5,
            'op_types': ['Conv2d'],
        }],
        'trainer': lambda model, optimizer, criterion, epoch: model,
        'validators': [
            lambda model: validate_sparsity(model.conv1, 0.5, model.bias)
        ]
    },
    'netadapt': {
        'pruner_class': NetAdaptPruner,
        'config_list': [{
            'sparsity': 0.5,
            'op_types': ['Conv2d']
        }],
        'short_term_fine_tuner': lambda model: model, 
        'evaluator':lambda model: 0.9,
        'validators': []
    },
    'simulatedannealing': {
        'pruner_class': SimulatedAnnealingPruner,
        'config_list': [{
            'sparsity': 0.5,
            'op_types': ['Conv2d']
        }],
        'evaluator':lambda model: 0.9,
        'validators': []
    },
    'admm': {
        'pruner_class': ADMMPruner,
        'config_list': [{
            'sparsity': 0.5,
            'op_types': ['Conv2d'],
        }],
        'trainer': lambda model, optimizer, criterion, epoch : model, 
        'validators': [
            lambda model: validate_sparsity(model.conv1, 0.5, model.bias)
        ]
    },
    'autocompress_l1': {
        'pruner_class': AutoCompressPruner,
        'config_list': [{
            'sparsity': 0.5,
            'op_types': ['Conv2d'],
        }],
        'base_algo': 'l1',
        'trainer': lambda model, optimizer, criterion, epoch : model,
        'evaluator': lambda model: 0.9,
        'dummy_input': torch.randn([64, 1, 28, 28]),
        'validators': []
    },
    'autocompress_l2': {
        'pruner_class': AutoCompressPruner,
        'config_list': [{
            'sparsity': 0.5,
            'op_types': ['Conv2d'],
        }],
        'base_algo': 'l2',
        'trainer': lambda model, optimizer, criterion, epoch : model,
        'evaluator': lambda model: 0.9,
        'dummy_input': torch.randn([64, 1, 28, 28]),
        'validators': []
    },
    'autocompress_fpgm': {
        'pruner_class': AutoCompressPruner,
        'config_list': [{
            'sparsity': 0.5,
            'op_types': ['Conv2d'],
        }],
        'base_algo': 'fpgm',
        'trainer': lambda model, optimizer, criterion, epoch : model,
        'evaluator': lambda model: 0.9,
        'dummy_input': torch.randn([64, 1, 28, 28]),
        'validators': []
    },
    # AMC is configured with op types only (no 'sparsity' key); see PrunerTestCase.testAMC.
    'amc': {
        'pruner_class': AMCPruner,
        'config_list':[{
            'op_types': ['Conv2d', 'Linear']
        }]
    }
}

class Model(nn.Module):
    """Tiny conv -> batch-norm -> global-average-pool -> linear network used as a pruning target.

    Input has 1 channel (``conv1`` is ``Conv2d(1, 8, ...)``); output has 2 features per sample.

    Args:
        bias: whether ``conv1`` and ``fc`` carry bias terms. The flag is also kept as
            ``self.bias`` so callers can decide whether bias masks should be validated.
    """

    def __init__(self, bias=True):
        super().__init__()
        # Registration order (conv1, bn1, pool, fc) is kept fixed: it determines the
        # parameter-initialisation RNG sequence and the module/parameter order.
        self.conv1 = nn.Conv2d(1, 8, kernel_size=3, padding=1, bias=bias)
        self.bn1 = nn.BatchNorm2d(8)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(8, 2, bias=bias)
        self.bias = bias

    def forward(self, x):
        batch_size = x.size(0)
        features = self.bn1(self.conv1(x))
        flattened = self.pool(features).view(batch_size, -1)
        return self.fc(flattened)

class SimpleDataset:
    """Map-style dataset of 1000 synthetic samples for exercising DataLoader plumbing.

    Each access returns a freshly sampled ``(3, 32, 32)`` random tensor paired with the
    constant float label ``1.``; the index value is ignored.
    """

    def __len__(self):
        return 1000

    def __getitem__(self, index):
        image = torch.randn(3, 32, 32)
        label = 1.
        return image, label

def train(model, train_loader, criterion, optimizer):
    """Run two optimisation steps on a fixed synthetic batch; a fast stand-in for training.

    Args:
        model: module that accepts a ``(2, 1, 28, 28)`` input batch. Presumably it returns
            logits compatible with ``criterion`` and the labels ``[0, 1]``; verify against callers.
        train_loader: accepted so the signature matches a real training loop, but
            intentionally unused. A synthetic batch keeps the tests fast.
        criterion: loss callable invoked as ``criterion(output, labels)``.
        optimizer: optimizer over ``model``'s parameters; stepped twice.

    Returns:
        None. ``model`` is left in training mode and its parameters are updated in place.
    """
    model.train()
    device = next(model.parameters()).device
    # The batch is built once and reused for both steps, on the model's device so this
    # works on CPU and CUDA alike.
    x = torch.randn(2, 1, 28, 28).to(device)
    y = torch.tensor([0, 1], dtype=torch.long, device=device)

    for _ in range(2):
        loss = criterion(model(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

def pruners_test(pruner_names=('level', 'agp', 'slim', 'fpgm', 'l1', 'l2', 'taylorfo', 'mean_activation', 'apoz',
                               'netadapt', 'simulatedannealing', 'admm', 'autocompress_l1', 'autocompress_l2',
                               'autocompress_fpgm'), bias=True):
    """Compress a fresh ``Model`` with each listed pruner, export it, and run its validators.

    Args:
        pruner_names: keys of ``prune_config`` to exercise, in order. The default is an
            immutable tuple to avoid the shared-mutable-default pitfall; any iterable of
            names (e.g. a list) is accepted.
        bias: forwarded to ``Model``; controls whether ``conv1``/``fc`` have biases.

    The temporary files listed in ``tmp_files`` are removed from the working directory
    afterwards, even when a pruner or validator raises.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dummy_input = torch.randn(2, 1, 28, 28).to(device)

    criterion = torch.nn.CrossEntropyLoss()
    train_loader = torch.utils.data.DataLoader(SimpleDataset(), batch_size=16, shuffle=False, drop_last=True)

    # Parameter names must stay as-is: pruners may invoke the trainer with keyword arguments.
    def trainer(model, optimizer, criterion, epoch):
        return train(model, train_loader, criterion, optimizer)

    tmp_files = ['./model_tmp.pth', './mask_tmp.pth', './onnx_tmp.pth', './search_history.csv', './search_result.json']
    try:
        for pruner_name in pruner_names:
            print('testing {}...'.format(pruner_name))
            cfg = prune_config[pruner_name]
            pruner_class = cfg['pruner_class']
            config_list = cfg['config_list']

            model = Model(bias=bias).to(device)
            optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

            # Each pruner family takes a different set of constructor arguments.
            if pruner_name == 'netadapt':
                pruner = pruner_class(model, config_list, short_term_fine_tuner=cfg['short_term_fine_tuner'],
                                      evaluator=cfg['evaluator'])
            elif pruner_name == 'simulatedannealing':
                pruner = pruner_class(model, config_list, evaluator=cfg['evaluator'])
            elif pruner_name in ('agp', 'slim', 'taylorfo', 'apoz', 'mean_activation'):
                pruner = pruner_class(model, config_list, trainer=trainer, optimizer=optimizer, criterion=criterion)
            elif pruner_name == 'admm':
                pruner = pruner_class(model, config_list, trainer=trainer)
            elif pruner_name.startswith('autocompress'):
                pruner = pruner_class(model, config_list, trainer=cfg['trainer'], evaluator=cfg['evaluator'],
                                      criterion=torch.nn.CrossEntropyLoss(), dummy_input=dummy_input,
                                      base_algo=cfg['base_algo'])
            else:
                pruner = pruner_class(model, config_list)

            pruner.compress()
            pruner.export_model('./model_tmp.pth', './mask_tmp.pth', './onnx_tmp.pth',
                                input_shape=(2, 1, 28, 28), device=device)

            for validate in cfg['validators']:
                validate(model)
    finally:
        # Clean up even on failure so a broken pruner does not leave files in the CWD.
        for path in tmp_files:
            if os.path.exists(path):
                os.remove(path)


def _test_agp(pruning_algorithm):
    """Run AGPPruner with ``pruning_algorithm`` on a fresh ``Model`` and check ``conv1``'s sparsity.

    The shared ``prune_config['agp']`` config list and the synthetic ``train`` loop are used.
    The zero fraction of ``conv1.weight_mask`` must lie within 0.2 of the pruner's
    ``compute_target_sparsity`` value for the first config entry.
    """
    loader = torch.utils.data.DataLoader(SimpleDataset(), batch_size=16, shuffle=False, drop_last=True)
    net = Model()
    sgd = torch.optim.SGD(net.parameters(), lr=0.01)

    # Parameter names are kept: the pruner may call the trainer with keyword arguments.
    def trainer(model, optimizer, criterion, epoch):
        return train(model, loader, criterion, optimizer)

    agp_configs = prune_config['agp']['config_list']
    pruner = AGPPruner(net, agp_configs, optimizer=sgd, trainer=trainer,
                       criterion=torch.nn.CrossEntropyLoss(), pruning_algorithm=pruning_algorithm)
    pruner.compress()

    expected = pruner.compute_target_sparsity(agp_configs[0])
    conv_mask = net.conv1.weight_mask
    achieved = (conv_mask == 0).sum().item() / conv_mask.numel()
    # Loose tolerance: with only a few channels, channel-level pruning cannot hit the target exactly.
    assert math.isclose(achieved, expected, abs_tol=0.2)


class PrunerTestCase(TestCase):
    """End-to-end smoke tests for the NNI pruners configured in ``prune_config``."""

    def test_pruners(self):
        pruners_test(bias=True)

    def test_pruners_no_bias(self):
        pruners_test(bias=False)

    def test_agp_pruner(self):
        for pruning_algorithm in ['l1', 'l2', 'fpgm', 'taylorfo', 'apoz']:
            _test_agp(pruning_algorithm)

        # 'level' is exercised with op_types=['default']. _test_agp reads the shared
        # prune_config['agp'] entry, so patch it only for this call and always restore it.
        # Under unittest's default alphabetical ordering this method runs before
        # test_pruners*, which read the same entry.
        agp_config = prune_config['agp']['config_list'][0]
        saved_op_types = agp_config['op_types']
        agp_config['op_types'] = ['default']
        try:
            _test_agp('level')
        finally:
            agp_config['op_types'] = saved_op_types

    def testAMC(self):
        model = MobileNet(n_class=10)

        # Stub validation: always reports the same score.
        def validate(val_loader, model):
            return 80.

        val_loader = torch.utils.data.DataLoader(SimpleDataset(), batch_size=16, shuffle=False, drop_last=True)
        config_list = prune_config['amc']['config_list']
        pruner = AMCPruner(model, config_list, validate, val_loader, train_episode=1)
        pruner.compress()

        file_paths = ['./model_tmp.pth', './mask_tmp.pth', './onnx_tmp.pth']
        try:
            pruner.export_model('./model_tmp.pth', './mask_tmp.pth', './onnx_tmp.pth', input_shape=(2, 3, 32, 32))
        finally:
            # Remove exported files even if the export fails partway through.
            for f in file_paths:
                if os.path.exists(f):
                    os.remove(f)

if __name__ == '__main__':
    # Discover and run every TestCase in this module through the unittest CLI runner.
    unittest.main()
