"docs/source/community.md" did not exist on "223084e42b57cd0d8e78de38e15a42d5d6b04391"
Unverified Commit 4818bf7a authored by Pavel Belevich's avatar Pavel Belevich Committed by GitHub
Browse files

HFTracer.trace should use/return self.graph to be compatible with torch.fx.Tracer (#15824)

parent ad0d7d17
......@@ -472,7 +472,7 @@ class HFTracer(Tracer):
self._patch_leaf_functions_for_root(root)
graph = super().trace(root, concrete_args=concrete_args)
self.graph = super().trace(root, concrete_args=concrete_args)
self._patch_leaf_functions_for_root(root, restore=True)
......@@ -482,16 +482,16 @@ class HFTracer(Tracer):
# This is necessary because concrete args are added as input to the traced module since
# https://github.com/pytorch/pytorch/pull/55888.
# A PR that solves this was posted: https://github.com/pytorch/pytorch/pull/59569 but it was not merged yet.
for node in graph.nodes:
for node in self.graph.nodes:
if node.op == "placeholder":
# Removing default values for inputs as the forward pass will fail with them.
if node.target in input_names:
node.args = ()
# It is a concrete arg so it is not used and should be removed.
else:
graph.erase_node(node)
self.graph.erase_node(node)
return graph
return self.graph
def _insert_module_as_submodule(self, mod: nn.Module) -> str:
"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment