Unverified Commit bfe94a39 authored by takuoko's avatar takuoko Committed by GitHub
Browse files

[Enhacne] Support maybe_raise_or_warn for peft (#5653)

* Support maybe_raise_or_warn for peft

* fix by comment

* unwrap function
parent c9c5436c
......@@ -49,6 +49,7 @@ from ..utils import (
get_class_from_dynamic_module,
is_accelerate_available,
is_accelerate_version,
is_peft_available,
is_torch_version,
is_transformers_available,
logging,
......@@ -270,6 +271,20 @@ def warn_deprecated_model_variant(pretrained_model_name_or_path, use_auth_token,
)
def _unwrap_model(model):
"""Unwraps a model."""
if is_compiled_module(model):
model = model._orig_mod
if is_peft_available():
from peft import PeftModel
if isinstance(model, PeftModel):
model = model.base_model.model
return model
def maybe_raise_or_warn(
library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module
):
......@@ -287,9 +302,8 @@ def maybe_raise_or_warn(
# Dynamo wraps the original model in a private class.
# I didn't find a public API to get the original class.
sub_model = passed_class_obj[name]
model_cls = sub_model.__class__
if is_compiled_module(sub_model):
model_cls = sub_model._orig_mod.__class__
unwrapped_sub_model = _unwrap_model(sub_model)
model_cls = unwrapped_sub_model.__class__
if not issubclass(model_cls, expected_class_obj):
raise ValueError(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment