"src/include/gridwise_direct_convolution_1.hip.hpp" did not exist on "73480fee3635310aedbbec68b6084c94cfd2457d"
Unverified Commit 5dc80762 authored by J-shang's avatar J-shang Committed by GitHub
Browse files

[Compression] pruning speedup support RecursiveScriptModule (#4801)

* support RecursiveScriptModule
parent 9644cf69
......@@ -57,7 +57,7 @@ class TorchGraph:
assert torch.__version__ >= '1.3.1'
# check if the input is legal
if traced_model is not None:
assert isinstance(traced_model, torch.jit.TopLevelTracedModule)
assert isinstance(traced_model, torch.jit.TopLevelTracedModule) or isinstance(traced_model, torch.jit.RecursiveScriptModule)
self.trace = traced_model
# it's ok if the graph is already unpacked
torch._C._jit_pass_inline(self.trace.graph)
......@@ -709,7 +709,7 @@ class TorchModuleGraph(TorchGraph):
self.leaf_modules = self._extract_leaf_modules()
module_to_type = {name: parse_traced_name(
module._name) for name, module in self.trace.named_modules()}
module._name if hasattr(module, '_name') else module.original_name) for name, module in self.trace.named_modules()}
# associate module name with their trace graph nodes
for node in graph.nodes():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment