Unverified Commit bccda3d8 authored by chicm-ms's avatar chicm-ms Committed by GitHub
Browse files

Fix _graph_utils import (#2675)

parent 0f33bc7e
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
import logging import logging
import torch import torch
from nni._graph_utils import build_module_graph
from nni.compression.torch.utils.mask_conflict import fix_mask_conflict from nni.compression.torch.utils.mask_conflict import fix_mask_conflict
from .compress_modules import replace_module from .compress_modules import replace_module
from .infer_shape import ModuleMasks, infer_from_mask, infer_from_inshape, infer_from_outshape from .infer_shape import ModuleMasks, infer_from_mask, infer_from_inshape, infer_from_outshape
...@@ -51,6 +50,8 @@ class ModelSpeedup: ...@@ -51,6 +50,8 @@ class ModelSpeedup:
map_location : str map_location : str
the device on which masks are placed, same to map_location in ```torch.load``` the device on which masks are placed, same to map_location in ```torch.load```
""" """
from nni._graph_utils import build_module_graph
self.bound_model = model self.bound_model = model
self.masks = torch.load(masks_file, map_location) self.masks = torch.load(masks_file, map_location)
self.inferred_masks = dict() # key: module_name, value: ModuleMasks self.inferred_masks = dict() # key: module_name, value: ModuleMasks
......
...@@ -4,8 +4,6 @@ ...@@ -4,8 +4,6 @@
import csv import csv
import logging import logging
from nni._graph_utils import TorchModuleGraph
__all__ = ['ChannelDependency', 'GroupDependency', 'CatPaddingDependency'] __all__ = ['ChannelDependency', 'GroupDependency', 'CatPaddingDependency']
CONV_TYPE = 'aten::_convolution' CONV_TYPE = 'aten::_convolution'
...@@ -19,6 +17,8 @@ class Dependency: ...@@ -19,6 +17,8 @@ class Dependency:
""" """
Build the graph for the model. Build the graph for the model.
""" """
from nni._graph_utils import TorchModuleGraph
# check if the input is legal # check if the input is legal
if traced_model is None: if traced_model is None:
# user should provide model & dummy_input to trace # user should provide model & dummy_input to trace
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment