Unverified Commit f7b7edac authored by chicm-ms's avatar chicm-ms Committed by GitHub
Browse files

graphutils supports torch17 (#3076)

parent b6233e52
...@@ -15,6 +15,7 @@ LIST_CONSTRUCT_KIND = 'prim::ListConstruct' ...@@ -15,6 +15,7 @@ LIST_CONSTRUCT_KIND = 'prim::ListConstruct'
LIST_UNPACK_KIND = 'prim::ListUnpack' LIST_UNPACK_KIND = 'prim::ListUnpack'
TUPLE_CONSTRUCT_KIND = 'prim::TupleConstruct' TUPLE_CONSTRUCT_KIND = 'prim::TupleConstruct'
TUPLE_UNPACK_KIND = 'prim::TupleUnpack' TUPLE_UNPACK_KIND = 'prim::TupleUnpack'
CONSTANT_KIND = 'prim::Constant'
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
...@@ -68,9 +69,11 @@ class TorchGraph: ...@@ -68,9 +69,11 @@ class TorchGraph:
'Please provide model & dummy_input or the traced_model as inputs') 'Please provide model & dummy_input or the traced_model as inputs')
def _trace(self, model, dummy_input): def _trace(self, model, dummy_input):
with torch.onnx.set_training(model, False): training = model.training
self.trace = torch.jit.trace(model, dummy_input) model.eval()
torch._C._jit_pass_inline(self.trace.graph) self.trace = torch.jit.trace(model, dummy_input)
torch._C._jit_pass_inline(self.trace.graph)
model.train(training)
class TorchProtoGraph(TorchGraph): class TorchProtoGraph(TorchGraph):
...@@ -282,27 +285,35 @@ class TorchModuleGraph(TorchGraph): ...@@ -282,27 +285,35 @@ class TorchModuleGraph(TorchGraph):
self.global_count += 1 self.global_count += 1
op_type = node.kind() op_type = node.kind()
node_group = [node] node_group = [node]
inputs = list() inputs = set()
outputs = list() outputs = set()
node_queue = queue.Queue() node_queue = queue.Queue()
node_queue.put(node) node_queue.put(node)
while not node_queue.empty(): while not node_queue.empty():
curr_node = node_queue.get() curr_node = node_queue.get()
for _input in curr_node.inputs(): for _input in curr_node.inputs():
if _input.node().kind() == CONSTANT_KIND:
continue
input_name = _input.debugName() input_name = _input.debugName()
if input_name in output_to_node and output_to_node[input_name] in nodes: if input_name in output_to_node:
predecessor_node = output_to_node[input_name] for predecessor_node in output_to_node[input_name]:
if not self._is_key_func(predecessor_node): if predecessor_node in nodes:
node_group.append(predecessor_node) if not self._is_key_func(predecessor_node):
node_queue.put(predecessor_node) if predecessor_node not in node_group:
else: node_group.append(predecessor_node)
inputs.append(input_name) node_queue.put(predecessor_node)
else:
inputs.add(input_name)
else:
inputs.add(input_name)
else: else:
inputs.append(input_name) inputs.add(input_name)
for output in node.outputs(): for output in node.outputs():
outputs.append(output.debugName()) if output.node().kind() == CONSTANT_KIND:
continue
outputs.add(output.debugName())
nodepy = NodePyGroup(node_name, unique_name, module_type, op_type, nodepy = NodePyGroup(node_name, unique_name, module_type, op_type,
node_group, inputs=inputs, outputs=outputs, key_node=node) node_group, inputs=list(inputs), outputs=list(outputs), key_node=node)
return nodepy return nodepy
def _expand_module_node(self, node, node_name, unique_name, op_type, nodes, def _expand_module_node(self, node, node_name, unique_name, op_type, nodes,
...@@ -342,36 +353,46 @@ class TorchModuleGraph(TorchGraph): ...@@ -342,36 +353,46 @@ class TorchModuleGraph(TorchGraph):
if not op_type: if not op_type:
op_type = node.kind() op_type = node.kind()
node_group = [node] node_group = [node]
inputs = list() inputs = set()
outputs = list() outputs = set()
node_queue = queue.Queue() node_queue = queue.Queue()
node_queue.put(node) node_queue.put(node)
visited = {node} visited = {node}
while not node_queue.empty(): while not node_queue.empty():
curr_node = node_queue.get() curr_node = node_queue.get()
for _input in curr_node.inputs(): for _input in curr_node.inputs():
if _input.node().kind() == CONSTANT_KIND:
continue
input_name = _input.debugName() input_name = _input.debugName()
if input_name in output_to_node and output_to_node[input_name] in nodes: if input_name in output_to_node:
predecessor_node = output_to_node[input_name] for predecessor_node in output_to_node[input_name]:
if predecessor_node not in visited: if predecessor_node in nodes:
node_group.append(predecessor_node) if predecessor_node not in visited:
node_queue.put(predecessor_node) node_group.append(predecessor_node)
visited.add(predecessor_node) node_queue.put(predecessor_node)
visited.add(predecessor_node)
else:
inputs.add(input_name)
else: else:
inputs.append(input_name) inputs.add(input_name)
for _output in curr_node.outputs(): for _output in curr_node.outputs():
if _output.node().kind() == CONSTANT_KIND:
continue
output_name = _output.debugName() output_name = _output.debugName()
if output_name in input_to_node and input_to_node[output_name] in nodes: if output_name in input_to_node:
successor_node = input_to_node[output_name] for successor_node in input_to_node[output_name]:
if successor_node not in visited: if successor_node in nodes:
node_group.append(successor_node) if successor_node not in visited:
node_queue.put(successor_node) node_group.append(successor_node)
visited.add(successor_node) node_queue.put(successor_node)
visited.add(successor_node)
else:
outputs.add(output_name)
else: else:
outputs.append(output_name) outputs.add(output_name)
nodepy = NodePyGroup(node_name, unique_name, module_type, op_type, nodepy = NodePyGroup(node_name, unique_name, module_type, op_type,
node_group, inputs=inputs, outputs=outputs) node_group, inputs=list(inputs), outputs=list(outputs))
return nodepy return nodepy
def _extract_cat_info(self, node_group, cpp_node): def _extract_cat_info(self, node_group, cpp_node):
...@@ -544,7 +565,7 @@ class TorchModuleGraph(TorchGraph): ...@@ -544,7 +565,7 @@ class TorchModuleGraph(TorchGraph):
input_to_node[_input].append(node) input_to_node[_input].append(node)
for output in node.outputs: for output in node.outputs:
assert not output in output_to_node, \ assert not output in output_to_node, \
"One output cannot be generated by multiple nodes" "One output cannot be generated by multiple nodes %s" % output
output_to_node[output] = node output_to_node[output] = node
return name_to_node, input_to_node, output_to_node return name_to_node, input_to_node, output_to_node
...@@ -642,12 +663,22 @@ class TorchModuleGraph(TorchGraph): ...@@ -642,12 +663,22 @@ class TorchModuleGraph(TorchGraph):
omit_useless_nodes = True omit_useless_nodes = True
graph = self.trace.graph graph = self.trace.graph
_logger.debug(graph) _logger.debug(graph)
# build output mapping, from output debugName to its node # build input/output mapping, from input/output debugName to its node
output_to_node = {x.debugName(): n for n in graph.nodes() input_to_node = defaultdict(list)
for x in n.outputs()} output_to_node = defaultdict(list)
# build input mapping, from input debugName to its node for node in graph.nodes():
input_to_node = {x.debugName(): n for n in graph.nodes() if node.kind() == CONSTANT_KIND:
for x in n.inputs()} continue
for x in node.outputs():
if x.node().kind() == CONSTANT_KIND:
continue
output_to_node[x.debugName()].append(node)
assert len(output_to_node[x.debugName()]) <= 1, "One output cannot be generated by multiple nodes %s" % x.debugName()
for x in node.inputs():
if x.node().kind() == CONSTANT_KIND:
continue
input_to_node[x.debugName()].append(node)
# build module mapping, from module name to all nodes (as list) under this module scope # build module mapping, from module name to all nodes (as list) under this module scope
module_to_nodes = defaultdict(list) module_to_nodes = defaultdict(list)
# the mapping of function (non-module in forward) to nodes, key is scope name # the mapping of function (non-module in forward) to nodes, key is scope name
...@@ -668,6 +699,8 @@ class TorchModuleGraph(TorchGraph): ...@@ -668,6 +699,8 @@ class TorchModuleGraph(TorchGraph):
# associate module name with their trace graph nodes # associate module name with their trace graph nodes
for node in graph.nodes(): for node in graph.nodes():
if node.kind() == CONSTANT_KIND:
continue
module_name = self._get_module_name(node.scopeName()) module_name = self._get_module_name(node.scopeName())
if module_name in self.leaf_modules: if module_name in self.leaf_modules:
module_to_nodes[module_name].append(node) module_to_nodes[module_name].append(node)
......
...@@ -36,9 +36,11 @@ def fix_mask_conflict(masks, model=None, dummy_input=None, traced=None): ...@@ -36,9 +36,11 @@ def fix_mask_conflict(masks, model=None, dummy_input=None, traced=None):
# this traced model. # this traced model.
if traced is None: if traced is None:
assert model is not None and dummy_input is not None assert model is not None and dummy_input is not None
with torch.onnx.set_training(model, False): training = model.training
# We need to trace the model in this way, else it will have problems model.eval()
traced = torch.jit.trace(model, dummy_input) # We need to trace the model in eval mode
traced = torch.jit.trace(model, dummy_input)
model.train(training)
fix_group_mask = GroupMaskConflict(masks, model, dummy_input, traced) fix_group_mask = GroupMaskConflict(masks, model, dummy_input, traced)
masks = fix_group_mask.fix_mask() masks = fix_group_mask.fix_mask()
......
...@@ -34,7 +34,7 @@ jobs: ...@@ -34,7 +34,7 @@ jobs:
set -e set -e
sudo apt-get install -y pandoc sudo apt-get install -y pandoc
python3 -m pip install -U --upgrade pygments python3 -m pip install -U --upgrade pygments
python3 -m pip install -U torch==1.5.0+cpu torchvision==0.6.0+cpu -f https://download.pytorch.org/whl/torch_stable.html python3 -m pip install -U torch==1.7.0+cpu torchvision==0.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
python3 -m pip install -U tensorflow==2.3.1 python3 -m pip install -U tensorflow==2.3.1
python3 -m pip install -U keras==2.4.2 python3 -m pip install -U keras==2.4.2
python3 -m pip install -U gym onnx peewee thop python3 -m pip install -U gym onnx peewee thop
...@@ -96,7 +96,7 @@ jobs: ...@@ -96,7 +96,7 @@ jobs:
- script: | - script: |
set -e set -e
python3 -m pip install -U torch==1.3.1+cpu torchvision==0.4.2+cpu -f https://download.pytorch.org/whl/torch_stable.html python3 -m pip install -U torch==1.5.0+cpu torchvision==0.6.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
python3 -m pip install -U tensorflow==1.15.2 python3 -m pip install -U tensorflow==1.15.2
python3 -m pip install -U keras==2.1.6 python3 -m pip install -U keras==2.1.6
python3 -m pip install -U gym onnx peewee python3 -m pip install -U gym onnx peewee
......
...@@ -61,7 +61,6 @@ channel_dependency_ground_truth = { ...@@ -61,7 +61,6 @@ channel_dependency_ground_truth = {
unittest.TestLoader.sortTestMethodsUsing = None unittest.TestLoader.sortTestMethodsUsing = None
@unittest.skipIf(torch.__version__ >= '1.6.0', 'not supported')
class AnalysisUtilsTest(TestCase): class AnalysisUtilsTest(TestCase):
@unittest.skipIf(torch.__version__ < "1.3.0", "not supported") @unittest.skipIf(torch.__version__ < "1.3.0", "not supported")
def test_channel_dependency(self): def test_channel_dependency(self):
......
...@@ -47,7 +47,6 @@ def generate_random_sparsity_v2(model): ...@@ -47,7 +47,6 @@ def generate_random_sparsity_v2(model):
return cfg_list return cfg_list
@unittest.skipIf(torch.__version__ >= '1.6.0', 'not supported')
class DependencyawareTest(TestCase): class DependencyawareTest(TestCase):
@unittest.skipIf(torch.__version__ < "1.3.0", "not supported") @unittest.skipIf(torch.__version__ < "1.3.0", "not supported")
def test_dependency_aware_pruning(self): def test_dependency_aware_pruning(self):
......
...@@ -177,7 +177,6 @@ def channel_prune(model): ...@@ -177,7 +177,6 @@ def channel_prune(model):
pruner.compress() pruner.compress()
pruner.export_model(model_path=MODEL_FILE, mask_path=MASK_FILE) pruner.export_model(model_path=MODEL_FILE, mask_path=MASK_FILE)
@unittest.skipIf(torch.__version__ >= '1.6.0', 'not supported')
class SpeedupTestCase(TestCase): class SpeedupTestCase(TestCase):
def test_speedup_vgg16(self): def test_speedup_vgg16(self):
prune_model_l1(vgg16()) prune_model_l1(vgg16())
......
...@@ -264,7 +264,6 @@ class SimpleDataset: ...@@ -264,7 +264,6 @@ class SimpleDataset:
def __len__(self): def __len__(self):
return 1000 return 1000
@unittest.skipIf(torch.__version__ >= '1.6.0', 'not supported')
class PrunerTestCase(TestCase): class PrunerTestCase(TestCase):
def test_pruners(self): def test_pruners(self):
pruners_test(bias=True) pruners_test(bias=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment