Unverified Commit 5dc80762 authored by J-shang's avatar J-shang Committed by GitHub
Browse files

[Compression] pruning speedup support RecursiveScriptModule (#4801)

* support RecursiveScriptModule
parent 9644cf69
...@@ -57,7 +57,7 @@ class TorchGraph: ...@@ -57,7 +57,7 @@ class TorchGraph:
assert torch.__version__ >= '1.3.1' assert torch.__version__ >= '1.3.1'
# check if the input is legal # check if the input is legal
if traced_model is not None: if traced_model is not None:
assert isinstance(traced_model, torch.jit.TopLevelTracedModule) assert isinstance(traced_model, torch.jit.TopLevelTracedModule) or isinstance(traced_model, torch.jit.RecursiveScriptModule)
self.trace = traced_model self.trace = traced_model
# it's ok if the graph is already unpacked # it's ok if the graph is already unpacked
torch._C._jit_pass_inline(self.trace.graph) torch._C._jit_pass_inline(self.trace.graph)
...@@ -709,7 +709,7 @@ class TorchModuleGraph(TorchGraph): ...@@ -709,7 +709,7 @@ class TorchModuleGraph(TorchGraph):
self.leaf_modules = self._extract_leaf_modules() self.leaf_modules = self._extract_leaf_modules()
module_to_type = {name: parse_traced_name( module_to_type = {name: parse_traced_name(
module._name) for name, module in self.trace.named_modules()} module._name if hasattr(module, '_name') else module.original_name) for name, module in self.trace.named_modules()}
# associate module name with their trace graph nodes # associate module name with their trace graph nodes
for node in graph.nodes(): for node in graph.nodes():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment