Unverified commit efdd4366 authored by Younes Belkada, committed by GitHub

FIX [`PEFT` / `Trainer`] Better handle PEFT + quantized compiled models (#29055)

* handle peft + compiled models

* add tests

* fixup

* adapt from suggestions

* clarify comment
parent 5e95dcab
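
In short, `Trainer` now fails fast when it receives a quantized base model that has been wrapped with `torch.compile()`. A minimal sketch of the resulting behavior, as plain Python outside the diff (assumes a CUDA-capable environment with `peft` and `bitsandbytes` installed; the checkpoint name is the tiny test model used in the new test below):

    import torch
    from peft import LoraConfig, get_peft_model
    from transformers import AutoModelForCausalLM, Trainer, TrainingArguments

    model = AutoModelForCausalLM.from_pretrained(
        "hf-internal-testing/tiny-random-LlamaForCausalLM", load_in_4bit=True
    )
    model = get_peft_model(model, LoraConfig(task_type="CAUSAL_LM"))

    args = TrainingArguments(output_dir="out", learning_rate=1e-9)

    # Supported: pass the non-compiled PEFT model to the Trainer.
    trainer = Trainer(model, args)

    # Unsupported: a compiled quantized model now raises a ValueError at
    # Trainer init instead of failing later inside the training loop.
    compiled = torch.compile(model)
    # Trainer(compiled, args)  # -> ValueError
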
@@ -429,6 +429,12 @@ class Trainer:
             getattr(model, "hf_quantizer", None) is not None and model.hf_quantizer.is_trainable
         )
 
+        # Filter out quantized + compiled models
+        if _is_quantized_and_base_model and hasattr(model, "_orig_mod"):
+            raise ValueError(
+                "You cannot fine-tune a quantized model with `torch.compile()`. Please make sure to pass a non-compiled model when fine-tuning a quantized model with PEFT."
+            )
+
         # At this stage the model is already loaded
         if _is_quantized_and_base_model and not _is_peft_model(model):
             raise ValueError(
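For context on the `hasattr(model, "_orig_mod")` check: `torch.compile` wraps an `nn.Module` in an `OptimizedModule` that keeps the original module under `_orig_mod`, so the presence of that attribute is a reliable marker for compiled models. A standalone sketch (plain Python, not part of the diff):

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 4)
    compiled = torch.compile(model)

    # torch.compile returns an OptimizedModule wrapper that keeps the original
    # module under `_orig_mod`; plain modules do not have this attribute.
    print(hasattr(model, "_orig_mod"))     # False
    print(hasattr(compiled, "_orig_mod"))  # True
    print(compiled._orig_mod is model)     # True
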
@@ -62,6 +62,7 @@ from transformers.testing_utils import (
     require_deepspeed,
     require_intel_extension_for_pytorch,
     require_optuna,
+    require_peft,
     require_ray,
     require_safetensors,
     require_sentencepiece,
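
`require_peft` follows the same pattern as the other `require_*` decorators in `transformers.testing_utils`: it skips the decorated test when the dependency is missing. A rough sketch of that pattern (an illustration, not the exact library source):

    import importlib.util
    import unittest

    def require_peft(test_case):
        # Skip the decorated test when peft is not installed (sketch of the
        # require_* pattern, not the exact transformers implementation).
        if importlib.util.find_spec("peft") is None:
            return unittest.skip("test requires PEFT")(test_case)
        return test_case
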
@@ -873,6 +874,42 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         train_output = trainer.train()
         self.assertEqual(train_output.global_step, 10)
 
+    @require_peft
+    @require_bitsandbytes
+    def test_bnb_compile(self):
+        from peft import LoraConfig, get_peft_model
+
+        # Simply tests that initializing a Trainer with a PEFT + compiled quantized model errors
+        # out cleanly. QLoRA + torch.compile is not really supported yet, so the Trainer should
+        # raise a clear ValueError at init rather than letting training fail later.
+        tiny_model = AutoModelForCausalLM.from_pretrained(
+            "hf-internal-testing/tiny-random-LlamaForCausalLM", load_in_4bit=True
+        )
+
+        peft_config = LoraConfig(
+            r=8,
+            lora_alpha=32,
+            target_modules=["q_proj", "k_proj", "v_proj"],
+            lora_dropout=0.05,
+            bias="none",
+            task_type="CAUSAL_LM",
+        )
+        tiny_model = get_peft_model(tiny_model, peft_config)
+
+        tiny_model = torch.compile(tiny_model)
+
+        x = torch.randint(0, 100, (128,))
+        train_dataset = RepeatDataset(x)
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            args = TrainingArguments(
+                tmp_dir,
+                learning_rate=1e-9,
+                logging_steps=5,
+            )
+            with self.assertRaises(ValueError):
+                _ = Trainer(tiny_model, args, train_dataset=train_dataset)  # noqa
+
     @require_bitsandbytes
     def test_rmsprop_bnb(self):
         config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)