Unverified Commit 5a911b30 authored by chicm-ms, committed by GitHub

Refactor model compression directory structure (#2501)

parent b5f4d218
@@ -21,7 +21,7 @@ For each module, we should prepare four functions, three for shape inference and
 ## Usage
 ```python
-from nni.compression.speedup.torch import ModelSpeedup
+from nni.compression.torch import ModelSpeedup
 # model: the model you want to speed up
 # dummy_input: dummy input of the model, given to `jit.trace`
 # masks_file: the mask file created by pruning algorithms
...
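For reference, a minimal sketch of what the documented usage looks like with the new import path. The variable names follow the comments in the hunk above; the model, input shape, mask path, and the `speedup_model()` call are illustrative assumptions, not a verbatim copy of the doc.

```python
import torch
from torchvision.models import resnet18
from nni.compression.torch import ModelSpeedup  # new flat import path after this refactor

# model: the model you want to speed up
model = resnet18()
# dummy_input: dummy input of the model, given to `jit.trace`
dummy_input = torch.randn(1, 3, 224, 224)
# masks_file: the mask file created by pruning algorithms (path is illustrative)
masks_file = './mask.pth'

# Assumed usage pattern: build the helper, then rewrite the masked model in place.
m_speedup = ModelSpeedup(model, dummy_input, masks_file)
m_speedup.speedup_model()
```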
@@ -6,8 +6,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torchvision import datasets, transforms
 from models.cifar10.vgg import VGG
-from nni.compression.speedup.torch import ModelSpeedup
-from nni.compression.torch import apply_compression_results
+from nni.compression.torch import apply_compression_results, ModelSpeedup
 torch.manual_seed(0)
 use_mask = True
...
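The merged import above serves the example's `use_mask` switch: masks are either applied to the weights in place, or compiled away by `ModelSpeedup`. A minimal sketch of that branch, using a torchvision VGG as a stand-in for the example's CIFAR-10 VGG and an illustrative mask path; treat every name below as hypothetical rather than a copy of the script.

```python
import torch
from torchvision.models import vgg16
from nni.compression.torch import apply_compression_results, ModelSpeedup

torch.manual_seed(0)
use_mask = True                            # mirrors the flag set in the example script

model = vgg16()                            # stand-in for models.cifar10.vgg.VGG
dummy_input = torch.randn(1, 3, 224, 224)
masks_file = './mask_vgg16.pth'            # masks exported earlier by a pruner

if use_mask:
    # Multiply the masks into the weights; tensor shapes stay unchanged.
    apply_compression_results(model, masks_file)
else:
    # Physically remove the pruned channels so inference gets faster.
    ModelSpeedup(model, dummy_input, masks_file).speedup_model()
```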
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
+from .pruning import *
+from .quantization import *
 from .compressor import Compressor, Pruner, Quantizer
-from .pruners import *
-from .weight_rank_filter_pruners import *
-from .activation_rank_filter_pruners import *
-from .quantizers import *
-from .apply_compression import apply_compression_results
-from .gradient_rank_filter_pruners import *
+from .speedup import ModelSpeedup
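With the reorganized `__init__.py`, pruners, quantizers, `apply_compression_results`, and `ModelSpeedup` are all re-exported from the single `nni.compression.torch` namespace, so the old `nni.compression.speedup.torch` path should no longer be needed. A quick illustration; the names are taken from the re-exports above and the `__all__` lists in the hunks below, and nothing beyond them is assumed.

```python
# All of these resolve through the re-exports in nni/compression/torch/__init__.py.
from nni.compression.torch import (
    Compressor,
    Pruner,
    Quantizer,
    LevelPruner,
    L1FilterPruner,
    QAT_Quantizer,
    apply_compression_results,
    ModelSpeedup,
)
```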
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+from .pruners import *
+from .weight_rank_filter_pruners import *
+from .activation_rank_filter_pruners import *
+from .apply_compression import apply_compression_results
+from .gradient_rank_filter_pruners import *
@@ -4,8 +4,8 @@
 import logging
 import torch
 from schema import And, Optional
-from .utils import CompressorSchema
-from .compressor import Pruner
+from ..utils.config_validation import CompressorSchema
+from ..compressor import Pruner
 __all__ = ['ActivationAPoZRankFilterPruner', 'ActivationMeanRankFilterPruner']
...
@@ -3,7 +3,7 @@
 import logging
 import torch
-from .compressor import Pruner
+from ..compressor import Pruner
 __all__ = ['TaylorFOWeightFilterPruner']
...
@@ -5,8 +5,8 @@ import copy
 import logging
 import torch
 from schema import And, Optional
-from .compressor import Pruner
-from .utils import CompressorSchema
+from ..utils.config_validation import CompressorSchema
+from ..compressor import Pruner
 __all__ = ['LevelPruner', 'AGP_Pruner', 'SlimPruner', 'LotteryTicketPruner']
...
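The relocated pruners keep their public interface; only their import path changes. Below is a minimal sketch of driving `LevelPruner` under the standard NNI config-list convention; the sparsity value, `op_types`, and output paths are illustrative assumptions rather than part of this commit.

```python
import torch
from torchvision.models import resnet18
from nni.compression.torch import LevelPruner

model = resnet18()
# Illustrative config: prune 50% of the weights in every supported layer type.
config_list = [{'sparsity': 0.5, 'op_types': ['default']}]

pruner = LevelPruner(model, config_list)
pruner.compress()                          # wraps target layers with masking logic

# Export the masked weights and the masks for the later apply / speed-up steps.
pruner.export_model(model_path='pruned_resnet18.pth', mask_path='mask_resnet18.pth')
```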
@@ -4,8 +4,8 @@
 import logging
 import torch
 from schema import And, Optional
-from .utils import CompressorSchema
-from .compressor import Pruner
+from ..utils.config_validation import CompressorSchema
+from ..compressor import Pruner
 __all__ = ['L1FilterPruner', 'L2FilterPruner', 'FPGMPruner']
...
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+from .quantizers import *
@@ -4,8 +4,8 @@
 import logging
 import torch
 from schema import Schema, And, Or, Optional
-from .utils import CompressorSchema
-from .compressor import Quantizer, QuantGrad, QuantType
+from ..utils.config_validation import CompressorSchema
+from ..compressor import Quantizer, QuantGrad, QuantType
 __all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer', 'BNNQuantizer']
...
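Likewise, the quantizers move under `nni.compression.torch` without an interface change. A sketch of configuring `QAT_Quantizer` for 8-bit weights, assuming the usual NNI quantization config keys; the bit width and op types are illustrative choices.

```python
import torch
from torchvision.models import resnet18
from nni.compression.torch import QAT_Quantizer

model = resnet18()
# Illustrative config: simulate 8-bit weight quantization on conv and linear layers.
config_list = [{
    'quant_types': ['weight'],
    'quant_bits': {'weight': 8},
    'op_types': ['Conv2d', 'Linear'],
}]

quantizer = QAT_Quantizer(model, config_list)
quantizer.compress()                       # inserts fake-quantization into the wrapped layers
```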
@@ -10,8 +10,7 @@ from torchvision.models.vgg import vgg16
 from torchvision.models.resnet import resnet18
 from unittest import TestCase, main
-from nni.compression.torch import L1FilterPruner, apply_compression_results
-from nni.compression.speedup.torch import ModelSpeedup
+from nni.compression.torch import L1FilterPruner, apply_compression_results, ModelSpeedup
 torch.manual_seed(0)
...
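The updated test imports suggest the end-to-end flow the suite exercises: prune, export masks, then compare a mask-applied model against a `ModelSpeedup`-shrunk one. Below is a self-contained sketch of that flow under assumed file names, sparsity, and model choice; it approximates the test rather than reproducing it.

```python
import torch
from torchvision.models import vgg16
from nni.compression.torch import L1FilterPruner, apply_compression_results, ModelSpeedup

torch.manual_seed(0)
dummy_input = torch.randn(2, 3, 224, 224)

# Prune 50% of the filters in every Conv2d layer, then export weights and masks.
model = vgg16()
config_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]
pruner = L1FilterPruner(model, config_list)
pruner.compress()
pruner.export_model(model_path='model.pth', mask_path='mask.pth')

# Path A: fresh model with the masks multiplied into its weights.
model_masked = vgg16()
model_masked.load_state_dict(torch.load('model.pth'))
apply_compression_results(model_masked, 'mask.pth')

# Path B: fresh model physically shrunk by ModelSpeedup.
model_small = vgg16()
model_small.load_state_dict(torch.load('model.pth'))
ModelSpeedup(model_small, dummy_input, 'mask.pth').speedup_model()

# Both paths should behave the same at inference time, up to numerical noise.
model_masked.eval()
model_small.eval()
with torch.no_grad():
    diff = (model_masked(dummy_input) - model_small(dummy_input)).abs().max()
print('max abs difference between masked and sped-up model:', diff.item())
```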