Unverified Commit 80b6cb3b authored by liuzhe-lz's avatar liuzhe-lz Committed by GitHub
Browse files

Merge pull request #3030 from microsoft/v2.0

Merge v2.0 into master
parents 77dac12b ff1af7f2
......@@ -2,6 +2,4 @@
# Licensed under the MIT license.
from .speedup import ModelSpeedup
from .pruning import *
from .quantization import *
from .compressor import Compressor, Pruner, Quantizer
......@@ -3,8 +3,8 @@
import logging
import torch
from nni.compression.torch.utils.mask_conflict import fix_mask_conflict
from nni.compression.torch.utils.utils import get_module_by_name
from nni.compression.pytorch.utils.mask_conflict import fix_mask_conflict
from nni.compression.pytorch.utils.utils import get_module_by_name
from .compress_modules import replace_module
from .infer_shape import ModuleMasks, infer_from_mask, infer_from_inshape, infer_from_outshape, set_conv_prune_dim
......@@ -29,7 +29,7 @@ class ModelSpeedup:
map_location : str
the device on which masks are placed, same to map_location in ```torch.load```
"""
from nni._graph_utils import build_module_graph
from nni.common.graph_utils import build_module_graph
self.bound_model = model
self.masks = torch.load(masks_file, map_location)
......
......@@ -3,7 +3,7 @@
import torch
import torch.nn as nn
from nni.compression.torch.compressor import PrunerModuleWrapper
from nni.compression.pytorch.compressor import PrunerModuleWrapper
try:
from thop import profile
......@@ -132,4 +132,4 @@ custom_mask_ops = {
nn.Conv2d: count_convNd_mask,
nn.Conv3d: count_convNd_mask,
nn.Linear: count_linear_mask,
}
\ No newline at end of file
}
......@@ -9,7 +9,7 @@ from collections import OrderedDict
import numpy as np
import torch.nn as nn
from ..pruning.constants_pruner import PRUNER_DICT
# FIXME: I don't know where "utils" should be
SUPPORTED_OP_NAME = ['Conv2d', 'Conv1d']
SUPPORTED_OP_TYPE = [getattr(nn, name) for name in SUPPORTED_OP_NAME]
......@@ -63,6 +63,8 @@ class SensitivityAnalysis:
This value is effective only when the early_stop_mode is set.
"""
from nni.algorithms.compression.pytorch.pruning.constants_pruner import PRUNER_DICT
self.model = model
self.val_func = val_func
self.target_layer = OrderedDict()
......
......@@ -17,7 +17,7 @@ class Dependency:
"""
Build the graph for the model.
"""
from nni._graph_utils import TorchModuleGraph
from nni.common.graph_utils import TorchModuleGraph
# check if the input is legal
if traced_model is None:
......
......@@ -2,4 +2,3 @@
# Licensed under the MIT license.
from .compressor import Compressor, Pruner
from .pruning import *
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment