"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "baf7e5c927744122c89ab1270c6c312541c7eb41"
Unverified Commit 7ec9128e authored by Michael Benayoun's avatar Michael Benayoun Committed by GitHub
Browse files

FX function refactor (#17625)



* Function refactor

* Update src/transformers/utils/fx.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent edb672ac
......@@ -21,7 +21,7 @@ import math
import operator
import random
import warnings
from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union
from typing import Any, Callable, Dict, List, Optional, Type, Union
import torch
from packaging import version
......@@ -883,13 +883,7 @@ class HFTracer(Tracer):
def proxy(self, node):
return HFProxy(node, self)
def trace(
self,
root: PreTrainedModel,
concrete_args: Optional[Dict[str, Any]] = None,
method_names: Optional[Iterable[str]] = None,
) -> Graph:
def trace(self, root: PreTrainedModel, concrete_args: Optional[Dict[str, Any]] = None) -> Graph:
if concrete_args is None:
concrete_args = {}
......@@ -1012,9 +1006,32 @@ class HFTracer(Tracer):
)
def get_concrete_args(model: nn.Module, input_names: List[str]):
sig = inspect.signature(model.forward)
if not (set(input_names) <= set(sig.parameters.keys())):
formatted_input_names = input_names[0] if len(input_names) == 1 else ", ".join(input_names)
formatted_allowed_input_names = ", ".join(sig.parameters.keys())
raise ValueError(
f"The model does not have input(s) named: {formatted_input_names}, expected a subset of the following:"
f" {formatted_allowed_input_names}"
)
return {p.name: p.default for p in sig.parameters.values() if p.name not in input_names}
def check_if_model_is_supported(model: PreTrainedModel):
if model.__class__.__name__ not in _SUPPORTED_MODELS:
supported_model_names = ", ".join(_SUPPORTED_MODELS)
raise NotImplementedError(
f"Model {model.__class__.__name__} is not supported yet, supported models: {supported_model_names}"
)
def symbolic_trace(
model: PreTrainedModel,
input_names: Optional[List[str]] = None,
disable_check: bool = False,
) -> GraphModule:
"""
......@@ -1025,6 +1042,8 @@ def symbolic_trace(
The model to trace.
input_names (`List[str]`, *optional*):
The names of the inputs of the traced model. If unset, model.dummy_inputs.keys() are used instead.
disable_check (`bool`, *optional*, defaults to `False`):
If `True`, no check is done before trying to trace the model, this is mostly usesul for debugging purposes.
Returns:
`torch.fx.GraphModule`: A GraphModule constructed by recording operations seen while tracing the model.
......@@ -1041,24 +1060,10 @@ def symbolic_trace(
input_names = model.dummy_inputs.keys()
input_names = list(input_names)
concrete_args = get_concrete_args(model, input_names)
sig = inspect.signature(model.forward)
if not (set(input_names) <= set(sig.parameters.keys())):
formatted_input_names = input_names[0] if len(input_names) == 1 else ", ".join(input_names)
formatted_allowed_input_names = ", ".join(sig.parameters.keys())
raise ValueError(
f"The model does not have input(s) named: {formatted_input_names}, expected a subset of the following:"
f" {formatted_allowed_input_names}"
)
concrete_args = {p.name: p.default for p in sig.parameters.values() if p.name not in input_names}
if model.__class__.__name__ not in _SUPPORTED_MODELS:
supported_model_names = ", ".join(_SUPPORTED_MODELS)
raise NotImplementedError(
f"Model {model.__class__.__name__} is not supported yet, supported models: {supported_model_names}"
)
if not disable_check:
check_if_model_is_supported(model)
# Tracing.
tracer = HFTracer()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment