Unverified Commit bccda3d8 authored by chicm-ms's avatar chicm-ms Committed by GitHub
Browse files

Fix _graph_utils import (#2675)

parent 0f33bc7e
......@@ -3,7 +3,6 @@
import logging
import torch
from nni._graph_utils import build_module_graph
from nni.compression.torch.utils.mask_conflict import fix_mask_conflict
from .compress_modules import replace_module
from .infer_shape import ModuleMasks, infer_from_mask, infer_from_inshape, infer_from_outshape
......@@ -51,6 +50,8 @@ class ModelSpeedup:
map_location : str
the device on which masks are placed, same to map_location in ```torch.load```
"""
from nni._graph_utils import build_module_graph
self.bound_model = model
self.masks = torch.load(masks_file, map_location)
self.inferred_masks = dict() # key: module_name, value: ModuleMasks
......
......@@ -4,8 +4,6 @@
import csv
import logging
from nni._graph_utils import TorchModuleGraph
__all__ = ['ChannelDependency', 'GroupDependency', 'CatPaddingDependency']
CONV_TYPE = 'aten::_convolution'
......@@ -19,6 +17,8 @@ class Dependency:
"""
Build the graph for the model.
"""
from nni._graph_utils import TorchModuleGraph
# check if the input is legal
if traced_model is None:
# user should provide model & dummy_input to trace
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment