Unverified Commit e3ee45aa authored by regisss's avatar regisss Committed by GitHub
Browse files

Enable to use custom tracer in FX `symbolic_trace` (#23105)



* Enable to use custom tracer in FX `symbolic_trace`

* Integrate feedback from review

* Formatting
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

---------
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 441658dd
......@@ -1207,6 +1207,7 @@ def symbolic_trace(
model: PreTrainedModel,
input_names: Optional[List[str]] = None,
disable_check: bool = False,
tracer_cls: Type[HFTracer] = HFTracer,
) -> GraphModule:
"""
Performs symbolic tracing on the model.
......@@ -1218,6 +1219,8 @@ def symbolic_trace(
The names of the inputs of the traced model. If unset, model.dummy_inputs.keys() are used instead.
disable_check (`bool`, *optional*, defaults to `False`):
If `True`, no check is done before trying to trace the model, this is mostly usesul for debugging purposes.
tracer_cls (`Type[HFTracer]`, *optional*, defaults to `HFTracer`):
The tracer class to use for instantiating the tracer. If unset, `HFTracer` is used instead.
Returns:
`torch.fx.GraphModule`: A GraphModule constructed by recording operations seen while tracing the model.
......@@ -1240,7 +1243,7 @@ def symbolic_trace(
check_if_model_is_supported(model)
# Tracing.
tracer = HFTracer()
tracer = tracer_cls()
traced_graph = tracer.trace(model, concrete_args=concrete_args)
traced = torch.fx.GraphModule(model, traced_graph)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment