Unverified Commit 00bc6e20 authored by Younes Belkada's avatar Younes Belkada Committed by GitHub
Browse files

[`PEFT`] Add HFTracer support for PEFT (#23006)



* add hack fx

* continue hacking

* final changes

* Test

* Add a keys method

* Fix keys method

* revert unneeded changes

* small nit

---------
Co-authored-by: default avatarMichael Benayoun <mickbenayoun@gmail.com>
parent 304aacac
...@@ -122,6 +122,7 @@ from .import_utils import ( ...@@ -122,6 +122,7 @@ from .import_utils import (
is_ninja_available, is_ninja_available,
is_onnx_available, is_onnx_available,
is_pandas_available, is_pandas_available,
is_peft_available,
is_phonemizer_available, is_phonemizer_available,
is_protobuf_available, is_protobuf_available,
is_psutil_available, is_psutil_available,
......
...@@ -28,6 +28,7 @@ import torch ...@@ -28,6 +28,7 @@ import torch
from packaging import version from packaging import version
from torch import nn from torch import nn
from torch.fx import Graph, GraphModule, Proxy, Tracer from torch.fx import Graph, GraphModule, Proxy, Tracer
from torch.fx._compatibility import compatibility
from torch.fx.proxy import ParameterProxy from torch.fx.proxy import ParameterProxy
from .. import PretrainedConfig, PreTrainedModel, logging from .. import PretrainedConfig, PreTrainedModel, logging
...@@ -53,10 +54,14 @@ from ..models.auto.modeling_auto import ( ...@@ -53,10 +54,14 @@ from ..models.auto.modeling_auto import (
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES, MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
MODEL_MAPPING_NAMES, MODEL_MAPPING_NAMES,
) )
from ..utils import ENV_VARS_TRUE_VALUES, TORCH_FX_REQUIRED_VERSION, is_torch_fx_available from ..utils import ENV_VARS_TRUE_VALUES, TORCH_FX_REQUIRED_VERSION, is_peft_available, is_torch_fx_available
from ..utils.versions import importlib_metadata from ..utils.versions import importlib_metadata
if is_peft_available():
from peft import PeftModel
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
_IS_IN_DEBUG_MODE = os.environ.get("FX_DEBUG_MODE", "").upper() in ENV_VARS_TRUE_VALUES _IS_IN_DEBUG_MODE = os.environ.get("FX_DEBUG_MODE", "").upper() in ENV_VARS_TRUE_VALUES
...@@ -164,6 +169,8 @@ _SPECIAL_SUPPORTED_MODELS = [ ...@@ -164,6 +169,8 @@ _SPECIAL_SUPPORTED_MODELS = [
"GPT2DoubleHeadsModel", "GPT2DoubleHeadsModel",
"Speech2Text2Decoder", "Speech2Text2Decoder",
"TrOCRDecoder", "TrOCRDecoder",
"PeftModelForCausalLM",
"PeftModelForSeq2SeqLM"
# TODO: add support for them as it should be quite easy to do so (small blocking issues). # TODO: add support for them as it should be quite easy to do so (small blocking issues).
# XLNetForQuestionAnswering, # XLNetForQuestionAnswering,
] ]
...@@ -724,6 +731,7 @@ class HFTracer(Tracer): ...@@ -724,6 +731,7 @@ class HFTracer(Tracer):
"clamp", "clamp",
"finfo", "finfo",
] ]
supported_archs = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)
def __init__(self, autowrap_modules=(math,), autowrap_functions=()): def __init__(self, autowrap_modules=(math,), autowrap_functions=()):
super().__init__(autowrap_modules=autowrap_modules, autowrap_functions=autowrap_functions) super().__init__(autowrap_modules=autowrap_modules, autowrap_functions=autowrap_functions)
...@@ -794,6 +802,8 @@ class HFTracer(Tracer): ...@@ -794,6 +802,8 @@ class HFTracer(Tracer):
*get_values(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES), *get_values(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES),
*get_values(MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES), *get_values(MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES),
"GPT2DoubleHeadsModel", "GPT2DoubleHeadsModel",
"PeftModelForCausalLM",
"PeftModelForSeq2SeqLM",
]: ]:
inputs_dict["labels"] = torch.zeros(shape, dtype=torch.long, device=device) inputs_dict["labels"] = torch.zeros(shape, dtype=torch.long, device=device)
elif model_class_name in [*get_values(MODEL_FOR_CTC_MAPPING_NAMES)]: elif model_class_name in [*get_values(MODEL_FOR_CTC_MAPPING_NAMES)]:
...@@ -1044,7 +1054,9 @@ class HFTracer(Tracer): ...@@ -1044,7 +1054,9 @@ class HFTracer(Tracer):
continue continue
# We enforce that root must either be a PreTrainedModel or deserialized from a serialized traced model to # We enforce that root must either be a PreTrainedModel or deserialized from a serialized traced model to
# be able to use HFTracer._generate_dummy_input. # be able to use HFTracer._generate_dummy_input.
if isinstance(root, PreTrainedModel) or type(root).__qualname__.startswith("_deserialize_graph_module"): if isinstance(root, self.supported_archs) or type(root).__qualname__.startswith(
"_deserialize_graph_module"
):
inputs.update(self._generate_dummy_input(root, input_name, shape)) inputs.update(self._generate_dummy_input(root, input_name, shape))
else: else:
raise RuntimeError( raise RuntimeError(
...@@ -1157,6 +1169,17 @@ class HFTracer(Tracer): ...@@ -1157,6 +1169,17 @@ class HFTracer(Tracer):
m, module_qualified_name m, module_qualified_name
) )
@compatibility(is_backward_compatible=True)
def keys(self, obj: "Proxy") -> Any:
"""Called when a proxy object is has the keys() method called.
This is what happens when ** is called on a proxy. This should return an iterator if ** is supposed to work in
your custom tracer.
"""
attribute = HFAttribute(obj, "keys")()
if obj.node.target == "**kwargs":
return attribute._metadata
return attribute
def get_concrete_args(model: nn.Module, input_names: List[str]): def get_concrete_args(model: nn.Module, input_names: List[str]):
sig = inspect.signature(model.forward) sig = inspect.signature(model.forward)
......
...@@ -407,6 +407,10 @@ def is_torch_fx_available(): ...@@ -407,6 +407,10 @@ def is_torch_fx_available():
return _torch_fx_available return _torch_fx_available
def is_peft_available():
return importlib.util.find_spec("peft") is not None
def is_bs4_available(): def is_bs4_available():
return importlib.util.find_spec("bs4") is not None return importlib.util.find_spec("bs4") is not None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment