"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "06f561687c94f572b03ef71d707b697401b34ce9"
Unverified commit e0769530, authored by Clara Pohland, committed by GitHub

Trainer._load_from_checkpoint - support loading multiple Peft adapters (#30505)



* Trainer: load checkpoint model with multiple adapters

* Trainer._load_from_checkpoint support multiple active adapters

* PeftModel.set_adapter does not support multiple adapters yet

* Trainer._load_from_checkpoint test multiple adapters

---------
Co-authored-by: Clara Luise Pohland <clara-luise.pohland@telekom.de>
parent aa64f086
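
For context, a minimal usage sketch of the workflow this change enables (not taken from the PR): the output directory, the toy dataset, and the adapter names are illustrative, while the model id and LoRA settings mirror the test added below.

```python
import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments

model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Attach two LoRA adapters; "adapter1" is created first and is therefore the active one.
peft_config = LoraConfig(r=4, lora_alpha=16, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM")
model = get_peft_model(model, peft_config, "adapter1")
model.add_adapter("adapter2", peft_config)


class ToyCausalLMDataset(torch.utils.data.Dataset):
    """Tiny stand-in dataset: each item is a dict with input_ids and labels."""

    def __init__(self, tokenizer, texts):
        self.encodings = [tokenizer(text)["input_ids"] for text in texts]

    def __len__(self):
        return len(self.encodings)

    def __getitem__(self, idx):
        ids = self.encodings[idx]
        return {"input_ids": ids, "labels": ids}


train_dataset = ToyCausalLMDataset(tokenizer, ["hello world"] * 16)

args = TrainingArguments(
    "multi_adapter_run",  # illustrative output directory
    per_device_train_batch_size=1,
    save_steps=5,
    max_steps=10,
    use_cpu=True,
    report_to="none",
)
trainer = Trainer(model, args, tokenizer=tokenizer, train_dataset=train_dataset)
trainer.train()

# The checkpoint stores each adapter in its own sub-directory
# (multi_adapter_run/checkpoint-5/adapter1 and .../adapter2). With this patch,
# _load_from_checkpoint recognises that layout, reloads both adapters, and
# re-activates the one that was active when the checkpoint was written.
trainer = Trainer(model, args, tokenizer=tokenizer, train_dataset=train_dataset)
trainer.train(resume_from_checkpoint="multi_adapter_run/checkpoint-5")
```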
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -2413,6 +2413,20 @@ class Trainer:
             # this checks the FSDP state dict when `FULL_STATE_DICT` is used
             or os.path.isfile(os.path.join(resume_from_checkpoint, f"{FSDP_MODEL_NAME}.bin"))
         )
+        # if multiple adapters exist, they get saved in sub directories
+        adapter_subdirs = (
+            [
+                folder_name
+                for folder_name in os.listdir(resume_from_checkpoint)
+                if os.path.isdir(os.path.join(resume_from_checkpoint, folder_name))
+                and (
+                    os.path.isfile(os.path.join(resume_from_checkpoint, folder_name, ADAPTER_WEIGHTS_NAME))
+                    or os.path.isfile(os.path.join(resume_from_checkpoint, folder_name, ADAPTER_SAFE_WEIGHTS_NAME))
+                )
+            ]
+            if os.path.isdir(resume_from_checkpoint)
+            else []
+        )

         if is_fsdp_ckpt and not self.is_fsdp_enabled:
             raise ValueError(f"Checkpoint found at {resume_from_checkpoint} is only supported when using PyTorch FSDP")
@@ -2430,6 +2444,7 @@ class Trainer:
                 ]
             )
             or is_fsdp_ckpt
+            or adapter_subdirs
         ):
             raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")
@@ -2503,7 +2518,14 @@ class Trainer:
             # If train a model using PEFT & LoRA, assume that adapter have been saved properly.
             if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
                 if os.path.exists(resume_from_checkpoint):
-                    model.load_adapter(resume_from_checkpoint, model.active_adapter, is_trainable=True)
+                    if adapter_subdirs:
+                        active_adapter = model.active_adapter
+                        for subdir_name in adapter_subdirs:
+                            peft_id = os.path.join(resume_from_checkpoint, subdir_name)
+                            model.load_adapter(peft_id, subdir_name, is_trainable=(subdir_name == active_adapter))
+                        model.set_adapter(active_adapter)
+                    else:
+                        model.load_adapter(resume_from_checkpoint, model.active_adapter, is_trainable=True)
                 else:
                     logger.warning(
                         "The intermediate checkpoints of PEFT may not be saved correctly, "
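
The loop in the hunk above roughly corresponds to the following plain PEFT sequence when rebuilding the same state from a fresh base model; a hedged sketch in which the paths and adapter names are illustrative. It also reflects the commit note that PeftModel.set_adapter does not support multiple adapters yet, which is why only the previously active adapter is re-activated and loaded as trainable.

```python
from peft import PeftModel
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")

# "adapter1" was active when checkpoint-5 was written, so it is loaded as trainable;
# every other adapter sub-directory is loaded frozen (is_trainable defaults to False).
model = PeftModel.from_pretrained(
    base, "multi_adapter_run/checkpoint-5/adapter1", adapter_name="adapter1", is_trainable=True
)
model.load_adapter("multi_adapter_run/checkpoint-5/adapter2", adapter_name="adapter2")

# PeftModel.set_adapter takes a single adapter name (see the commit notes), so the
# Trainer only re-activates the adapter that was active before resuming.
model.set_adapter("adapter1")
```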
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -964,6 +964,63 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
             with self.assertRaises(ValueError):
                 _ = Trainer(tiny_model, args, train_dataset=train_dataset)  # noqa

+    @require_peft
+    def test_multiple_peft_adapters(self):
+        from peft import LoraConfig, get_peft_model
+
+        # Tests if resuming from checkpoint works if the model has multiple adapters
+
+        MODEL_ID = "hf-internal-testing/tiny-random-LlamaForCausalLM"
+        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+        tiny_model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
+
+        peft_config = LoraConfig(
+            r=4,
+            lora_alpha=16,
+            lora_dropout=0.05,
+            bias="none",
+            task_type="CAUSAL_LM",
+        )
+        tiny_model = get_peft_model(tiny_model, peft_config, "adapter1")
+        tiny_model.add_adapter("adapter2", peft_config)
+
+        train_dataset = LineByLineTextDataset(
+            tokenizer=tokenizer,
+            file_path=PATH_SAMPLE_TEXT,
+            block_size=tokenizer.max_len_single_sentence,
+        )
+        for example in train_dataset.examples:
+            example["labels"] = example["input_ids"]
+
+        tokenizer.pad_token = tokenizer.eos_token
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            args = TrainingArguments(
+                tmpdir,
+                per_device_train_batch_size=1,
+                learning_rate=1e-9,
+                save_steps=5,
+                logging_steps=5,
+                max_steps=10,
+                use_cpu=True,
+            )
+            trainer = Trainer(tiny_model, args, tokenizer=tokenizer, train_dataset=train_dataset)
+
+            trainer.train()
+            parameters = dict(tiny_model.named_parameters())
+            state = dataclasses.asdict(trainer.state)
+
+            # Reinitialize trainer
+            trainer = Trainer(tiny_model, args, tokenizer=tokenizer, train_dataset=train_dataset)
+
+            checkpoint = os.path.join(tmpdir, "checkpoint-5")
+
+            trainer.train(resume_from_checkpoint=checkpoint)
+            parameters1 = dict(tiny_model.named_parameters())
+            state1 = dataclasses.asdict(trainer.state)
+            self.assertEqual(parameters, parameters1)
+            self.check_trainer_state_are_the_same(state, state1)
+
     @require_bitsandbytes
     def test_rmsprop_bnb(self):
         config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)