Unverified Commit e21a6984 authored by liuzhe-lz's avatar liuzhe-lz Committed by GitHub
Browse files

[v2.0] Refactor code hierarchy (part 2) (#2987)

parent f98ee672
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
import functools import functools
import logging import logging
from . import trial from .. import trial
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
......
...@@ -2,6 +2,4 @@ ...@@ -2,6 +2,4 @@
# Licensed under the MIT license. # Licensed under the MIT license.
from .speedup import ModelSpeedup from .speedup import ModelSpeedup
from .pruning import *
from .quantization import *
from .compressor import Compressor, Pruner, Quantizer from .compressor import Compressor, Pruner, Quantizer
...@@ -3,8 +3,8 @@ ...@@ -3,8 +3,8 @@
import logging import logging
import torch import torch
from nni.compression.torch.utils.mask_conflict import fix_mask_conflict from nni.compression.pytorch.utils.mask_conflict import fix_mask_conflict
from nni.compression.torch.utils.utils import get_module_by_name from nni.compression.pytorch.utils.utils import get_module_by_name
from .compress_modules import replace_module from .compress_modules import replace_module
from .infer_shape import ModuleMasks, infer_from_mask, infer_from_inshape, infer_from_outshape, set_conv_prune_dim from .infer_shape import ModuleMasks, infer_from_mask, infer_from_inshape, infer_from_outshape, set_conv_prune_dim
...@@ -29,7 +29,7 @@ class ModelSpeedup: ...@@ -29,7 +29,7 @@ class ModelSpeedup:
map_location : str map_location : str
the device on which masks are placed, same to map_location in ```torch.load``` the device on which masks are placed, same to map_location in ```torch.load```
""" """
from nni._graph_utils import build_module_graph from nni.common.graph_utils import build_module_graph
self.bound_model = model self.bound_model = model
self.masks = torch.load(masks_file, map_location) self.masks = torch.load(masks_file, map_location)
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
from nni.compression.torch.compressor import PrunerModuleWrapper from nni.compression.pytorch.compressor import PrunerModuleWrapper
try: try:
from thop import profile from thop import profile
...@@ -132,4 +132,4 @@ custom_mask_ops = { ...@@ -132,4 +132,4 @@ custom_mask_ops = {
nn.Conv2d: count_convNd_mask, nn.Conv2d: count_convNd_mask,
nn.Conv3d: count_convNd_mask, nn.Conv3d: count_convNd_mask,
nn.Linear: count_linear_mask, nn.Linear: count_linear_mask,
} }
\ No newline at end of file
...@@ -9,7 +9,7 @@ from collections import OrderedDict ...@@ -9,7 +9,7 @@ from collections import OrderedDict
import numpy as np import numpy as np
import torch.nn as nn import torch.nn as nn
from ..pruning.constants_pruner import PRUNER_DICT # FIXME: I don't know where "utils" should be
SUPPORTED_OP_NAME = ['Conv2d', 'Conv1d'] SUPPORTED_OP_NAME = ['Conv2d', 'Conv1d']
SUPPORTED_OP_TYPE = [getattr(nn, name) for name in SUPPORTED_OP_NAME] SUPPORTED_OP_TYPE = [getattr(nn, name) for name in SUPPORTED_OP_NAME]
...@@ -63,6 +63,8 @@ class SensitivityAnalysis: ...@@ -63,6 +63,8 @@ class SensitivityAnalysis:
This value is effective only when the early_stop_mode is set. This value is effective only when the early_stop_mode is set.
""" """
from nni.algorithms.compression.pytorch.pruning.constants_pruner import PRUNER_DICT
self.model = model self.model = model
self.val_func = val_func self.val_func = val_func
self.target_layer = OrderedDict() self.target_layer = OrderedDict()
......
...@@ -17,7 +17,7 @@ class Dependency: ...@@ -17,7 +17,7 @@ class Dependency:
""" """
Build the graph for the model. Build the graph for the model.
""" """
from nni._graph_utils import TorchModuleGraph from nni.common.graph_utils import TorchModuleGraph
# check if the input is legal # check if the input is legal
if traced_model is None: if traced_model is None:
......
...@@ -2,4 +2,3 @@ ...@@ -2,4 +2,3 @@
# Licensed under the MIT license. # Licensed under the MIT license.
from .compressor import Compressor, Pruner from .compressor import Compressor, Pruner
from .pruning import *
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
Example: Example:
from nni.nnicli import Experiment from nni.experiment import Experiment
exp = Experiment() exp = Experiment()
exp.start_experiment('../../../../examples/trials/mnist-pytorch/config.yml') exp.start_experiment('../../../../examples/trials/mnist-pytorch/config.yml')
...@@ -196,16 +196,16 @@ class TrialJob: ...@@ -196,16 +196,16 @@ class TrialJob:
Trial job id. Trial job id.
status: str status: str
Job status. Job status.
hyperParameters: list of `nnicli.TrialHyperParameters` hyperParameters: list of `nni.experiment.TrialHyperParameters`
See `nnicli.TrialHyperParameters`. See `nni.experiment.TrialHyperParameters`.
logPath: str logPath: str
Log path. Log path.
startTime: int startTime: int
Job start time (timestamp). Job start time (timestamp).
endTime: int endTime: int
Job end time (timestamp). Job end time (timestamp).
finalMetricData: list of `nnicli.TrialMetricData` finalMetricData: list of `nni.experiment.TrialMetricData`
See `nnicli.TrialMetricData`. See `nni.experiment.TrialMetricData`.
parameter_index: int parameter_index: int
Parameter index. Parameter index.
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment