Unverified Commit 7ec9128e authored by Michael Benayoun's avatar Michael Benayoun Committed by GitHub
Browse files

FX function refactor (#17625)



* Function refactor

* Update src/transformers/utils/fx.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent edb672ac
...@@ -21,7 +21,7 @@ import math ...@@ -21,7 +21,7 @@ import math
import operator import operator
import random import random
import warnings import warnings
from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union from typing import Any, Callable, Dict, List, Optional, Type, Union
import torch import torch
from packaging import version from packaging import version
...@@ -883,13 +883,7 @@ class HFTracer(Tracer): ...@@ -883,13 +883,7 @@ class HFTracer(Tracer):
def proxy(self, node): def proxy(self, node):
return HFProxy(node, self) return HFProxy(node, self)
def trace( def trace(self, root: PreTrainedModel, concrete_args: Optional[Dict[str, Any]] = None) -> Graph:
self,
root: PreTrainedModel,
concrete_args: Optional[Dict[str, Any]] = None,
method_names: Optional[Iterable[str]] = None,
) -> Graph:
if concrete_args is None: if concrete_args is None:
concrete_args = {} concrete_args = {}
...@@ -1012,9 +1006,32 @@ class HFTracer(Tracer): ...@@ -1012,9 +1006,32 @@ class HFTracer(Tracer):
) )
def get_concrete_args(model: nn.Module, input_names: List[str]):
sig = inspect.signature(model.forward)
if not (set(input_names) <= set(sig.parameters.keys())):
formatted_input_names = input_names[0] if len(input_names) == 1 else ", ".join(input_names)
formatted_allowed_input_names = ", ".join(sig.parameters.keys())
raise ValueError(
f"The model does not have input(s) named: {formatted_input_names}, expected a subset of the following:"
f" {formatted_allowed_input_names}"
)
return {p.name: p.default for p in sig.parameters.values() if p.name not in input_names}
def check_if_model_is_supported(model: PreTrainedModel):
if model.__class__.__name__ not in _SUPPORTED_MODELS:
supported_model_names = ", ".join(_SUPPORTED_MODELS)
raise NotImplementedError(
f"Model {model.__class__.__name__} is not supported yet, supported models: {supported_model_names}"
)
def symbolic_trace( def symbolic_trace(
model: PreTrainedModel, model: PreTrainedModel,
input_names: Optional[List[str]] = None, input_names: Optional[List[str]] = None,
disable_check: bool = False,
) -> GraphModule: ) -> GraphModule:
""" """
...@@ -1025,6 +1042,8 @@ def symbolic_trace( ...@@ -1025,6 +1042,8 @@ def symbolic_trace(
The model to trace. The model to trace.
input_names (`List[str]`, *optional*): input_names (`List[str]`, *optional*):
The names of the inputs of the traced model. If unset, model.dummy_inputs.keys() are used instead. The names of the inputs of the traced model. If unset, model.dummy_inputs.keys() are used instead.
disable_check (`bool`, *optional*, defaults to `False`):
If `True`, no check is done before trying to trace the model, this is mostly usesul for debugging purposes.
Returns: Returns:
`torch.fx.GraphModule`: A GraphModule constructed by recording operations seen while tracing the model. `torch.fx.GraphModule`: A GraphModule constructed by recording operations seen while tracing the model.
...@@ -1041,24 +1060,10 @@ def symbolic_trace( ...@@ -1041,24 +1060,10 @@ def symbolic_trace(
input_names = model.dummy_inputs.keys() input_names = model.dummy_inputs.keys()
input_names = list(input_names) input_names = list(input_names)
concrete_args = get_concrete_args(model, input_names)
sig = inspect.signature(model.forward) if not disable_check:
check_if_model_is_supported(model)
if not (set(input_names) <= set(sig.parameters.keys())):
formatted_input_names = input_names[0] if len(input_names) == 1 else ", ".join(input_names)
formatted_allowed_input_names = ", ".join(sig.parameters.keys())
raise ValueError(
f"The model does not have input(s) named: {formatted_input_names}, expected a subset of the following:"
f" {formatted_allowed_input_names}"
)
concrete_args = {p.name: p.default for p in sig.parameters.values() if p.name not in input_names}
if model.__class__.__name__ not in _SUPPORTED_MODELS:
supported_model_names = ", ".join(_SUPPORTED_MODELS)
raise NotImplementedError(
f"Model {model.__class__.__name__} is not supported yet, supported models: {supported_model_names}"
)
# Tracing. # Tracing.
tracer = HFTracer() tracer = HFTracer()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment