Unverified Commit e3ee45aa authored by regisss's avatar regisss Committed by GitHub
Browse files

Enable to use custom tracer in FX `symbolic_trace` (#23105)



* Enable to use custom tracer in FX `symbolic_trace`

* Integrate feedback from review

* Formatting
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

---------
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 441658dd
...@@ -1207,6 +1207,7 @@ def symbolic_trace( ...@@ -1207,6 +1207,7 @@ def symbolic_trace(
model: PreTrainedModel, model: PreTrainedModel,
input_names: Optional[List[str]] = None, input_names: Optional[List[str]] = None,
disable_check: bool = False, disable_check: bool = False,
tracer_cls: Type[HFTracer] = HFTracer,
) -> GraphModule: ) -> GraphModule:
""" """
Performs symbolic tracing on the model. Performs symbolic tracing on the model.
...@@ -1218,6 +1219,8 @@ def symbolic_trace( ...@@ -1218,6 +1219,8 @@ def symbolic_trace(
The names of the inputs of the traced model. If unset, model.dummy_inputs.keys() are used instead. The names of the inputs of the traced model. If unset, model.dummy_inputs.keys() are used instead.
disable_check (`bool`, *optional*, defaults to `False`): disable_check (`bool`, *optional*, defaults to `False`):
If `True`, no check is done before trying to trace the model, this is mostly usesul for debugging purposes. If `True`, no check is done before trying to trace the model, this is mostly usesul for debugging purposes.
tracer_cls (`Type[HFTracer]`, *optional*, defaults to `HFTracer`):
The tracer class to use for instantiating the tracer. If unset, `HFTracer` is used instead.
Returns: Returns:
`torch.fx.GraphModule`: A GraphModule constructed by recording operations seen while tracing the model. `torch.fx.GraphModule`: A GraphModule constructed by recording operations seen while tracing the model.
...@@ -1240,7 +1243,7 @@ def symbolic_trace( ...@@ -1240,7 +1243,7 @@ def symbolic_trace(
check_if_model_is_supported(model) check_if_model_is_supported(model)
# Tracing. # Tracing.
tracer = HFTracer() tracer = tracer_cls()
traced_graph = tracer.trace(model, concrete_args=concrete_args) traced_graph = tracer.trace(model, concrete_args=concrete_args)
traced = torch.fx.GraphModule(model, traced_graph) traced = torch.fx.GraphModule(model, traced_graph)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment