Unverified Commit 5a911b30 authored by chicm-ms's avatar chicm-ms Committed by GitHub
Browse files

Refactor model compression directory structure (#2501)

parent b5f4d218
......@@ -21,7 +21,7 @@ For each module, we should prepare four functions, three for shape inference and
## Usage
```python
from nni.compression.speedup.torch import ModelSpeedup
from nni.compression.torch import ModelSpeedup
# model: the model you want to speed up
# dummy_input: dummy input of the model, given to `jit.trace`
# masks_file: the mask file created by pruning algorithms
......
......@@ -6,8 +6,7 @@ import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from models.cifar10.vgg import VGG
from nni.compression.speedup.torch import ModelSpeedup
from nni.compression.torch import apply_compression_results
from nni.compression.torch import apply_compression_results, ModelSpeedup
torch.manual_seed(0)
use_mask = True
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .pruning import *
from .quantization import *
from .compressor import Compressor, Pruner, Quantizer
from .pruners import *
from .weight_rank_filter_pruners import *
from .activation_rank_filter_pruners import *
from .quantizers import *
from .apply_compression import apply_compression_results
from .gradient_rank_filter_pruners import *
from .speedup import ModelSpeedup
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .pruners import *
from .weight_rank_filter_pruners import *
from .activation_rank_filter_pruners import *
from .apply_compression import apply_compression_results
from .gradient_rank_filter_pruners import *
......@@ -4,8 +4,8 @@
import logging
import torch
from schema import And, Optional
from .utils import CompressorSchema
from .compressor import Pruner
from ..utils.config_validation import CompressorSchema
from ..compressor import Pruner
__all__ = ['ActivationAPoZRankFilterPruner', 'ActivationMeanRankFilterPruner']
......
......@@ -3,7 +3,7 @@
import logging
import torch
from .compressor import Pruner
from ..compressor import Pruner
__all__ = ['TaylorFOWeightFilterPruner']
......
......@@ -5,8 +5,8 @@ import copy
import logging
import torch
from schema import And, Optional
from .compressor import Pruner
from .utils import CompressorSchema
from ..utils.config_validation import CompressorSchema
from ..compressor import Pruner
__all__ = ['LevelPruner', 'AGP_Pruner', 'SlimPruner', 'LotteryTicketPruner']
......
......@@ -4,8 +4,8 @@
import logging
import torch
from schema import And, Optional
from .utils import CompressorSchema
from .compressor import Pruner
from ..utils.config_validation import CompressorSchema
from ..compressor import Pruner
__all__ = ['L1FilterPruner', 'L2FilterPruner', 'FPGMPruner']
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .quantizers import *
......@@ -4,8 +4,8 @@
import logging
import torch
from schema import Schema, And, Or, Optional
from .utils import CompressorSchema
from .compressor import Quantizer, QuantGrad, QuantType
from ..utils.config_validation import CompressorSchema
from ..compressor import Quantizer, QuantGrad, QuantType
__all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer', 'BNNQuantizer']
......
......@@ -10,8 +10,7 @@ from torchvision.models.vgg import vgg16
from torchvision.models.resnet import resnet18
from unittest import TestCase, main
from nni.compression.torch import L1FilterPruner, apply_compression_results
from nni.compression.speedup.torch import ModelSpeedup
from nni.compression.torch import L1FilterPruner, apply_compression_results, ModelSpeedup
torch.manual_seed(0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment