"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "8e5a1b2abb319c0d6e23f4f9c86c9064ac5aae89"
Unverified Commit c74befc9 authored by Michael Benayoun's avatar Michael Benayoun Committed by GitHub
Browse files

HFTracer.trace can now take callables and torch.nn.Module (#18457)

* Enable HFTracer to trace with custom dummy inputs instead of pre-computed ones

* Add HFTracer.trace docstring, and make it possible to handle callable and torch.nn.Module in general

* Remove pdb comment

* Apply suggestions
parent fc1d841b
......@@ -882,11 +882,51 @@ class HFTracer(Tracer):
def proxy(self, node):
return HFProxy(node, self)
def trace(self, root: PreTrainedModel, concrete_args: Optional[Dict[str, Any]] = None) -> Graph:
def trace(
self,
root: Union[torch.nn.Module, Callable[..., Any]],
concrete_args: Optional[Dict[str, Any]] = None,
dummy_inputs: Optional[Dict[str, Any]] = None,
complete_concrete_args_with_inputs_not_in_dummy_inputs: bool = True,
) -> Graph:
"""
Traces `root` and returns the corresponding FX `torch.fx.Graph` representation. `root` can either be a
`torch.nn.Module` instance or a Python callable. Note that after this call, `self.root` may be different from
the `root` passed in here. For example, when a free function is passed to `trace()`, we will create a
`torch.nn.Module` instance to use as the root and add embedded constants to.
Args:
root (`torch.nn.Module` or `Callable`):
Either a `torch.nn.Module`` or a function to be traced through. If root is not a
[`~transformers.PreTrainedModel`], then `dummy_inputs` must be passed, otherwise tracing will fail.
concrete_args (`Dict[str, Any], *optional*):
Concrete arguments that should not be treated as Proxies
dummy_inputs (`Dict[str, Any]`, *optional*):
The dummy inputs needed to handle data-dependent control-flow if `root` is not a
[`~transformers.PreTrainedModel`]. It can also be used when `root` is a
[`~transformers.PreTrainedModel`] to specify custom dummy inputs for a subset or all the model inputs.
complete_concrete_args_with_inputs_not_in_dummy_inputs (`bool`, *optional*, defaults to `True`):
If `True`, and `dummy_inputs` is specified, every argument that `root` can take that is not in
`dummy_inputs` and not in `concrete_args` will be added to `concrete_args`, otherwise does nothing.
Returns:
`torch.fx.Graph`:
A FX `torch.fx.Graph` representing the semantics of the passed-in `root`.
"""
sig = inspect.signature(root.forward if isinstance(root, torch.nn.Module) else root)
if concrete_args is None:
concrete_args = {}
sig = inspect.signature(root.forward)
if dummy_inputs is not None and complete_concrete_args_with_inputs_not_in_dummy_inputs:
for param in sig.parameters.values():
if param.name in dummy_inputs:
continue
if param.default is inspect.Parameter.empty:
raise ValueError(f"You need to specify a default value for the parameter {param.name}.")
concrete_args.update({p.name: p.default for p in sig.parameters.values() if p.name not in dummy_inputs})
input_names = sig.parameters.keys() - concrete_args.keys()
# Creating a random input shape to generate dummy inputs.
......@@ -898,11 +938,24 @@ class HFTracer(Tracer):
num_choices = _generate_random_int(low=2, high=5)
shape.insert(1, num_choices)
inputs = {}
inputs = dict(dummy_inputs) if dummy_inputs is not None else {}
for input_name in input_names:
inputs.update(self._generate_dummy_input(root, input_name, shape))
if input_name in inputs:
continue
# We enforce that root must either be a PreTrainedModel or deserialized from a serialized traced model to
# be able to use HFTracer._generate_dummy_input.
if isinstance(root, PreTrainedModel) or type(root).__qualname__.startswith("_deserialize_graph_module"):
inputs.update(self._generate_dummy_input(root, input_name, shape))
else:
raise RuntimeError(
f"Could not generate input named {input_name} for because root is not a"
" transformers.PreTrainedModel."
)
concrete_metas = {input_name: input_.to("meta") for input_name, input_ in inputs.items()}
concrete_metas = {
input_name: input_.to("meta") if isinstance(input_, torch.Tensor) else input_
for input_name, input_ in inputs.items()
}
for param in sig.parameters.values():
if param.kind == inspect.Parameter.VAR_KEYWORD and param.name not in input_names:
concrete_metas[f"**{param.name}"] = {}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment