Unverified Commit 00bc6e20 authored by Younes Belkada, committed by GitHub

[`PEFT`] Add HFTracer support for PEFT (#23006)



* add hack fx

* continue hacking

* final changes

* Test

* Add a keys method

* Fix keys method

* revert unneeded changes

* small nit

---------
Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>
parent 304aacac
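What this change enables, as a minimal sketch (assuming peft is installed; the "gpt2" checkpoint and the LoRA config below are illustrative, not part of this diff): once HFTracer accepts PeftModel subclasses, a PEFT-wrapped model can be symbolically traced like a plain PreTrainedModel.

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM
from transformers.utils.fx import symbolic_trace

# Wrapping a base model with a LoRA adapter yields a PeftModelForCausalLM,
# one of the class names whitelisted in _SPECIAL_SUPPORTED_MODELS below.
base = AutoModelForCausalLM.from_pretrained("gpt2")
peft_model = get_peft_model(base, LoraConfig(task_type="CAUSAL_LM"))

# HFTracer now treats the PeftModel root like a PreTrainedModel and can
# generate dummy inputs (input_ids, labels, ...) for it.
traced = symbolic_trace(peft_model, input_names=["input_ids", "attention_mask"])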
@@ -122,6 +122,7 @@ from .import_utils import (
     is_ninja_available,
     is_onnx_available,
     is_pandas_available,
+    is_peft_available,
     is_phonemizer_available,
     is_protobuf_available,
     is_psutil_available,
@@ -28,6 +28,7 @@ import torch
 from packaging import version
 from torch import nn
 from torch.fx import Graph, GraphModule, Proxy, Tracer
+from torch.fx._compatibility import compatibility
 from torch.fx.proxy import ParameterProxy
 from .. import PretrainedConfig, PreTrainedModel, logging
@@ -53,10 +54,14 @@ from ..models.auto.modeling_auto import (
     MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
     MODEL_MAPPING_NAMES,
 )
-from ..utils import ENV_VARS_TRUE_VALUES, TORCH_FX_REQUIRED_VERSION, is_torch_fx_available
+from ..utils import ENV_VARS_TRUE_VALUES, TORCH_FX_REQUIRED_VERSION, is_peft_available, is_torch_fx_available
 from ..utils.versions import importlib_metadata
+if is_peft_available():
+    from peft import PeftModel
 logger = logging.get_logger(__name__)
 _IS_IN_DEBUG_MODE = os.environ.get("FX_DEBUG_MODE", "").upper() in ENV_VARS_TRUE_VALUES
@@ -164,6 +169,8 @@ _SPECIAL_SUPPORTED_MODELS = [
     "GPT2DoubleHeadsModel",
     "Speech2Text2Decoder",
     "TrOCRDecoder",
+    "PeftModelForCausalLM",
+    "PeftModelForSeq2SeqLM",
     # TODO: add support for them as it should be quite easy to do so (small blocking issues).
     # XLNetForQuestionAnswering,
 ]
@@ -724,6 +731,7 @@ class HFTracer(Tracer):
         "clamp",
         "finfo",
     ]
+    supported_archs = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)
     def __init__(self, autowrap_modules=(math,), autowrap_functions=()):
         super().__init__(autowrap_modules=autowrap_modules, autowrap_functions=autowrap_functions)
@@ -794,6 +802,8 @@ class HFTracer(Tracer):
             *get_values(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES),
             *get_values(MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES),
             "GPT2DoubleHeadsModel",
+            "PeftModelForCausalLM",
+            "PeftModelForSeq2SeqLM",
         ]:
             inputs_dict["labels"] = torch.zeros(shape, dtype=torch.long, device=device)
         elif model_class_name in [*get_values(MODEL_FOR_CTC_MAPPING_NAMES)]:
@@ -1044,7 +1054,9 @@ class HFTracer(Tracer):
                 continue
             # We enforce that root must either be a PreTrainedModel or deserialized from a serialized traced model to
             # be able to use HFTracer._generate_dummy_input.
-            if isinstance(root, PreTrainedModel) or type(root).__qualname__.startswith("_deserialize_graph_module"):
+            if isinstance(root, self.supported_archs) or type(root).__qualname__.startswith(
+                "_deserialize_graph_module"
+            ):
                 inputs.update(self._generate_dummy_input(root, input_name, shape))
             else:
                 raise RuntimeError(
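The isinstance() check above works whether or not peft is installed because supported_archs is built conditionally: isinstance accepts a tuple of types, so the PeftModel entry simply disappears when the optional dependency is absent. In isolation, the pattern looks roughly like this (a sketch; check_root is a hypothetical helper, not part of the diff):

from transformers import PreTrainedModel
from transformers.utils import is_peft_available

if is_peft_available():
    from peft import PeftModel
    supported_archs = (PreTrainedModel, PeftModel)
else:
    supported_archs = (PreTrainedModel,)

def check_root(model):
    # isinstance() takes a tuple of types, so the check degrades
    # gracefully when the optional peft dependency is missing.
    return isinstance(model, supported_archs)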
@@ -1157,6 +1169,17 @@ class HFTracer(Tracer):
                 m, module_qualified_name
             )
+    @compatibility(is_backward_compatible=True)
+    def keys(self, obj: "Proxy") -> Any:
+        """Called when a proxy object has the keys() method called.
+        This is what happens when ** is called on a proxy. This should return an iterator if ** is supposed to work
+        in your custom tracer.
+        """
+        attribute = HFAttribute(obj, "keys")()
+        if obj.node.target == "**kwargs":
+            return attribute._metadata
+        return attribute
 def get_concrete_args(model: nn.Module, input_names: List[str]):
     sig = inspect.signature(model.forward)
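The keys() override is needed because torch.fx invokes Tracer.keys() whenever traced code applies ** to a proxy, and PEFT models forward their inputs to the wrapped base model via **kwargs. A hypothetical module that would exercise this path during tracing (illustrative only, not from this diff):

import torch

class KwargsForwarder(torch.nn.Module):
    def __init__(self, inner: torch.nn.Module):
        super().__init__()
        self.inner = inner

    def forward(self, **kwargs):
        # During symbolic tracing, `**kwargs` triggers tracer.keys() on the
        # proxy standing in for kwargs; HFTracer answers with the recorded
        # metadata keys instead of failing.
        return self.inner(**kwargs)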
@@ -407,6 +407,10 @@ def is_torch_fx_available():
     return _torch_fx_available
+def is_peft_available():
+    return importlib.util.find_spec("peft") is not None
 def is_bs4_available():
     return importlib.util.find_spec("bs4") is not None
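importlib.util.find_spec only looks the package up on sys.path without importing it, so the availability probe is cheap and side-effect free. Typical use, mirroring the guarded import added in fx.py above:

from transformers.utils import is_peft_available

if is_peft_available():
    from peft import PeftModel  # safe: the distribution is installed
else:
    PeftModel = None  # the feature degrades gracefully without peft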