Commit 0938ae70 authored by zhaoying1

fix save method of adapter_model.bin

parent 1b73554f
llmtuner/tuner/dpo/collator.py
@@ -16,7 +16,7 @@ class DPODataCollatorWithPadding(DataCollatorForSeq2Seq):
             if self.tokenizer.padding_side == "left":
                 start, end = feature.size(0) - answer_len, feature.size(0)
             else:
-                start, end = prompt_len, answer_len
+                start, end = prompt_len, prompt_len + answer_len
             padded_tensor = self.label_pad_token_id * torch.ones_like(feature)
             padded_tensor[start:end] = feature[start:end]
             padded_labels.append(padded_tensor)
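The one-line fix above is easy to misread, so here is a minimal sketch of what the corrected slice selects (the `IGNORE_INDEX` value and toy tensors are illustrative, not taken from this repository):

```python
import torch

IGNORE_INDEX = -100  # label id ignored by the loss (assumption)

# right-padded feature: 2 prompt tokens, 3 answer tokens, 3 pads
feature = torch.tensor([11, 12, 21, 22, 23, 0, 0, 0])
prompt_len, answer_len = 2, 3

start, end = prompt_len, prompt_len + answer_len       # fixed span: [2, 5)
padded_tensor = IGNORE_INDEX * torch.ones_like(feature)
padded_tensor[start:end] = feature[start:end]
print(padded_tensor)  # tensor([-100, -100, 21, 22, 23, -100, -100, -100])
```

The old span `(prompt_len, answer_len)` would keep only `feature[2:3]` here, and becomes empty whenever `answer_len <= prompt_len`, so answer tokens were silently dropped from the labels.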
......
llmtuner/tuner/dpo/trainer.py
 import torch
 from collections import defaultdict
 from peft import PeftModel
 from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
 from transformers import BatchEncoding, Trainer
 from trl import DPOTrainer
+# from trl.trainer.utils import disable_dropout_in_model
 
 from llmtuner.extras.constants import IGNORE_INDEX
-from llmtuner.tuner.core.trainer import PeftModelMixin
 
 if TYPE_CHECKING:
     from transformers import PreTrainedModel
-    from llmtuner.hparams import FinetuningArguments
 
 
-class DPOPeftTrainer(PeftModelMixin, DPOTrainer):
+class CustomDPOTrainer(DPOTrainer):
 
     def __init__(
         self,
-        finetuning_args: "FinetuningArguments",
+        beta: float,
+        model: Union["PreTrainedModel", torch.nn.Module],
         ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
         disable_dropout: Optional[bool] = True,
         **kwargs
     ):
-        self.finetuning_args = finetuning_args
+        # if disable_dropout:
+        #     disable_dropout_in_model(model)
+        #     if ref_model is not None:
+        #         disable_dropout_in_model(ref_model)
         self.ref_model = ref_model
         self.use_dpo_data_collator = True  # hack to avoid warning
         self.label_pad_token_id = IGNORE_INDEX
         self.padding_value = 0
-        self.beta = finetuning_args.dpo_beta
+        self.beta = beta
         self._stored_metrics = defaultdict(lambda: defaultdict(list))
 
-        Trainer.__init__(self, **kwargs)
+        Trainer.__init__(self, model=model, **kwargs)
         if not hasattr(self, "accelerator"):
             raise AttributeError("Please update `transformers`.")
 
         if ref_model is not None:
-            self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
+            if self.is_deepspeed_enabled:
+                self.ref_model, = self.accelerator._prepare_deepspeed(self.ref_model)
+                self.ref_model.eval()
+            else:
+                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
 
     def concatenated_forward(
         self,
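Passing `beta` explicitly (instead of reading `finetuning_args.dpo_beta` inside the trainer) decouples this class from the project's hyperparameter objects. For reference, a generic sketch of the objective that `beta` feeds into; this is the textbook DPO formula, not code from this file:

```python
import torch
import torch.nn.functional as F

def dpo_loss(
    policy_chosen_logps: torch.Tensor,
    policy_rejected_logps: torch.Tensor,
    reference_chosen_logps: torch.Tensor,
    reference_rejected_logps: torch.Tensor,
    beta: float = 0.1,
) -> torch.Tensor:
    # beta scales the implicit reward margin; larger values keep the
    # policy closer to the reference model.
    logits = (policy_chosen_logps - policy_rejected_logps) - (
        reference_chosen_logps - reference_rejected_logps
    )
    return -F.logsigmoid(beta * logits).mean()
```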
@@ -42,27 +50,12 @@ class DPOPeftTrainer(PeftModelMixin, DPOTrainer):
         batch: Optional[Dict[str, torch.Tensor]] = None
     ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
         batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()})  # avoid error
 
-        unwrapped_model: "PreTrainedModel" = self.accelerator.unwrap_model(self.model)
-        if not torch.is_grad_enabled():
-            unwrapped_model.gradient_checkpointing_disable()
-
-        if model is None and isinstance(unwrapped_model, PeftModel):  # peft model has no ref_model
-            with unwrapped_model.disable_adapter():
-                all_logits = self.model(
-                    input_ids=batch_copied["input_ids"],
-                    attention_mask=batch_copied["attention_mask"],
-                    return_dict=True
-                ).logits.to(torch.float32)
-        else:
-            all_logits = model(
-                input_ids=batch_copied["input_ids"],
-                attention_mask=batch_copied["attention_mask"],
-                return_dict=True
-            ).logits.to(torch.float32)
-
-        if not torch.is_grad_enabled():
-            unwrapped_model.gradient_checkpointing_enable()
+        all_logits = model(
+            input_ids=batch_copied["input_ids"],
+            attention_mask=batch_copied["attention_mask"],
+            return_dict=True
+        ).logits.to(torch.float32)
 
         all_logps = self._get_batch_logps(
             all_logits,
......
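The branch deleted from `concatenated_forward` handled the LoRA case where no separate reference model exists: disabling the adapter on a `PeftModel` recovers the frozen base weights, so the same network can produce reference logits. A standalone sketch of that trick (the helper name is illustrative):

```python
import torch
from peft import PeftModel

def reference_logits(
    policy: PeftModel, input_ids: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor:
    # With LoRA, the base weights stay frozen, so switching the adapter off
    # makes the policy network behave exactly like the reference model.
    with torch.no_grad(), policy.disable_adapter():
        return policy(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        ).logits.to(torch.float32)
```

After this change the method always forwards through whichever `model` it receives, so any adapter toggling has to happen in the caller or in trl's own reference-model handling.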
llmtuner/tuner/dpo/workflow.py
@@ -10,7 +10,7 @@ from llmtuner.extras.constants import IGNORE_INDEX
 from llmtuner.extras.ploting import plot_loss
 from llmtuner.tuner.core import load_model_and_tokenizer
 from llmtuner.tuner.dpo.collator import DPODataCollatorWithPadding
-from llmtuner.tuner.dpo.trainer import DPOPeftTrainer
+from llmtuner.tuner.dpo.trainer import CustomDPOTrainer
 
 if TYPE_CHECKING:
     from transformers import TrainerCallback
@@ -37,10 +37,10 @@ def run_dpo(
     training_args = Seq2SeqTrainingArguments(**training_args_dict)
 
     # Initialize our Trainer
-    trainer = DPOPeftTrainer(
-        finetuning_args=finetuning_args,
-        ref_model=deepcopy(model) if not isinstance(model, PeftModel) else None,
+    trainer = CustomDPOTrainer(
+        beta=finetuning_args.dpo_beta,
+        model=model,
+        ref_model=deepcopy(model) if not isinstance(model, PeftModel) else None,
         args=training_args,
         tokenizer=tokenizer,
         data_collator=data_collator,
......
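Note the reference-model rule in the call above: a frozen copy is only materialized for full-parameter training, since a LoRA policy already carries the frozen base weights. A sketch of the two cases:

```python
from copy import deepcopy
from peft import PeftModel

# Full fine-tuning: every weight will move, so keep a detached snapshot of the
# starting model as the reference. LoRA: pass None and let the trainer's
# reference handling fall back to the frozen base weights, saving a full copy.
ref_model = deepcopy(model) if not isinstance(model, PeftModel) else None
```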
llmtuner/tuner/ppo/trainer.py
@@ -2,29 +2,27 @@ import os
 import math
 import torch
 from tqdm import tqdm
-from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
 
-from transformers import TrainerState, TrainerControl
+from transformers import GenerationConfig, Trainer, TrainerState, TrainerControl
 from trl import PPOTrainer
 from trl.core import LengthSampler, PPODecorators, logprobs_from_logits
 
 from llmtuner.extras.logging import get_logger
 from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
-from llmtuner.tuner.core.trainer import PeftTrainer
 from llmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model
 
 if TYPE_CHECKING:
-    from transformers import Seq2SeqTrainingArguments
+    from transformers import Seq2SeqTrainingArguments, TrainerCallback
     from trl import AutoModelForCausalLMWithValueHead
-    from llmtuner.extras.callbacks import LogCallback
-    from llmtuner.hparams import FinetuningArguments, GeneratingArguments
+    from llmtuner.hparams import GeneratingArguments
 
 logger = get_logger(__name__)
 
 
-class PPOPeftTrainer(PPOTrainer, PeftTrainer):
+class CustomPPOTrainer(PPOTrainer, Trainer):
     r"""
     Inherits PPOTrainer.
     """
@@ -32,17 +30,18 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
     def __init__(
         self,
         training_args: "Seq2SeqTrainingArguments",
-        finetuning_args: "FinetuningArguments",
         generating_args: "GeneratingArguments",
-        callbacks: List["LogCallback"],
+        callbacks: List["TrainerCallback"],
         compute_dtype: torch.dtype,
         **kwargs
     ):
         PPOTrainer.__init__(self, **kwargs)
+        if getattr(self.accelerator.state, "deepspeed_plugin", None) is not None:
+            raise ValueError("PPOTrainer is incompatible with DeepSpeed.")
+
         self.args = training_args
-        self.finetuning_args = finetuning_args
         self.generating_args = generating_args
-        self.log_callback = callbacks[0]
+        self.log_callback, self.save_callback = callbacks[0], callbacks[1]
         self.compute_dtype = compute_dtype
         self.state = TrainerState()
         self.control = TrainerControl()
@@ -75,10 +74,11 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         logger.info(f"  Number of trainable parameters = {count_parameters(self.model)[0]}")
 
         # Keyword arguments for `model.generate`
-        gen_kwargs = self.generating_args.to_dict()
-        gen_kwargs["eos_token_id"] = list(set([self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids))
-        gen_kwargs["pad_token_id"] = self.tokenizer.pad_token_id
-        gen_kwargs["logits_processor"] = get_logits_processor()
+        generating_args = self.generating_args.to_dict()
+        generating_args.update(dict(
+            eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
+            pad_token_id=self.tokenizer.pad_token_id
+        ))
 
         length_sampler = LengthSampler(max_target_length // 2, max_target_length)
         unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
@@ -96,17 +96,23 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
             # Cast to inference mode
             unwrapped_model.gradient_checkpointing_disable()
             unwrapped_model.config.use_cache = True
-            unwrapped_model, layer_norm_params = cast_layernorm_dtype(unwrapped_model, self.compute_dtype)
+            self.model.eval()
 
             # Get inputs
-            queries, responses = self.get_inputs(batch, length_sampler, **gen_kwargs)
+            queries, responses = self.get_inputs(batch, length_sampler, generating_args)
+            self.tokenizer.padding_side = "right"  # change padding side
             rewards = self.get_rewards(queries, responses, unwrapped_model)
 
             # Cast to training mode
             unwrapped_model.gradient_checkpointing_enable()
             unwrapped_model.config.use_cache = False
-            unwrapped_model, _ = cast_layernorm_dtype(unwrapped_model, self.compute_dtype, layer_norm_params)
+            self.model.train()
 
             # Run PPO step
             stats = self.step(queries, responses, rewards)
+            self.tokenizer.padding_side = "left"  # restore padding side
 
             loss_meter.update(stats["ppo/loss/total"], n=len(rewards))
             reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards))
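The two `padding_side` flips above are the crux of this hunk: decoder-only generation needs real tokens flush against the right edge of the prompt (left padding), while the reward scoring and PPO forward passes here pad on the right. A quick illustration (gpt2 is used only as an example; the exact ids are whatever the tokenizer produces):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
tok.pad_token = tok.eos_token

tok.padding_side = "left"
print(tok(["hi", "hello there"], padding=True)["input_ids"])
# pad id appears at the *front* of the shorter row

tok.padding_side = "right"
print(tok(["hi", "hello there"], padding=True)["input_ids"])
# pad id now trails the shorter row
```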
@@ -137,36 +143,44 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
                 dataiter = iter(self.dataloader)
                 steps_trained = 0
 
-        self.log_callback.on_train_end(self.args, self.state, self.control)
+        self.log_callback.on_train_end(
+            self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
+        )
 
     @torch.no_grad()
     def get_inputs(
         self,
         batch: Dict[str, torch.Tensor],
-        length_sampler: Optional[Callable] = None,
-        **generation_kwargs
+        length_sampler: Callable,
+        generating_args: Dict[str, Any]
     ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
         r"""
         Generates model's responses given queries.
         """
-        if length_sampler is not None:
-            generation_kwargs["max_new_tokens"] = length_sampler()
-
-        self.model, layer_norm_params = cast_layernorm_dtype(self.model, self.compute_dtype)
+        generating_args["max_new_tokens"] = length_sampler()
+        gen_kwargs = dict(
+            generation_config=GenerationConfig(**generating_args),
+            logits_processor=get_logits_processor(),
+            **batch
+        )
 
+        input_ids = batch["input_ids"]
         unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
-        response: torch.Tensor = unwrapped_model.generate(**batch, **generation_kwargs)
 
-        self.model, _ = cast_layernorm_dtype(self.model, self.compute_dtype, layer_norm_params)
+        # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
+        # Inspired by: https://github.com/huggingface/transformers/blob/v4.28.1/src/transformers/trainer_seq2seq.py#L273
+        if unwrapped_model.pretrained_model.generation_config._from_model_config:
+            unwrapped_model.pretrained_model.generation_config._from_model_config = False
+
+        response: torch.Tensor = unwrapped_model.generate(**gen_kwargs)
 
-        query, response = batch["input_ids"].detach().cpu(), response[:, batch["input_ids"].size(-1):].detach().cpu()
+        query, response = input_ids.detach().cpu(), response[:, input_ids.size(-1):].detach().cpu()
         queries, responses = [], []
         for i in range(len(query)):
             query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0]
-            response_length = (response[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
+            response_index = (response[i] != self.tokenizer.pad_token_id).nonzero()
+
+            if len(response_index) == 0:
+                response_length = 1  # allow empty response
+            elif self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
+                response_length = response_index[-1] + 2  # save the EOS token
+            else:
+                response_length = response_index[-1] + 1
+
             queries.append(query[i, query_length:])  # remove padding from left
             responses.append(response[i, :response_length])  # remove padding from right
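The new trimming logic covers two cases the old one-liner missed: a completely empty response, and tokenizers where PAD and EOS share an id (trimming every pad would then also strip the real EOS that the reward model needs to score). A minimal sketch with illustrative ids:

```python
import torch

pad_token_id = 0
eos_token_id = 0  # many chat models reuse EOS as PAD

response = torch.tensor([42, 43, 0, 0, 0])  # two tokens, EOS(=pad), then pads
response_index = (response != pad_token_id).nonzero()

if len(response_index) == 0:
    response_length = 1  # allow empty response
elif pad_token_id == eos_token_id:
    response_length = response_index[-1] + 2  # keep one trailing EOS
else:
    response_length = response_index[-1] + 1

print(response[:response_length])  # tensor([42, 43, 0]) -- EOS preserved
```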
@@ -191,7 +205,11 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         if values.size(0) != batch["input_ids"].size(0):  # adapt to chatglm2
             values = torch.transpose(values, 0, 1)
 
-        rewards = [reward for reward in values[:, -1].float().detach().cpu()]  # use fp32 type
+        rewards = []
+        for i in range(values.size(0)):
+            end_index = batch["attention_mask"][i].nonzero()[-1]  # use the score on the EOS token
+            rewards.append(values[i, end_index].float().detach().cpu())  # use fp32 type
 
         replace_model(unwrapped_model, target="default")
         return rewards
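Because the batch is right-padded during scoring, `values[:, -1]` read the value head's output at a padding position for every sequence shorter than the batch maximum. The replacement indexes the last attended token instead. A shape-level sketch (random values, illustrative mask):

```python
import torch

values = torch.randn(2, 6)  # (batch, seq_len) value-head outputs
attention_mask = torch.tensor([
    [1, 1, 1, 0, 0, 0],  # short sequence, right-padded
    [1, 1, 1, 1, 1, 1],
])

rewards = []
for i in range(values.size(0)):
    end_index = attention_mask[i].nonzero()[-1]  # last real token (the EOS)
    rewards.append(values[i, end_index].float())

# values[:, -1] would have scored a pad token for row 0.
```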
@@ -202,7 +220,8 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         queries: torch.Tensor,
         responses: torch.Tensor,
         model_inputs: dict,
-        return_logits: Optional[bool] = False
+        return_logits: Optional[bool] = False,
+        response_masks: Optional[torch.Tensor] = None
     ):
         r"""
         Calculates model outputs in multiple batches.
@@ -220,6 +239,8 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
             input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()}
             query_batch = queries[i * fbs : (i + 1) * fbs]
             response_batch = responses[i * fbs : (i + 1) * fbs]
+            if response_masks is not None:
+                response_masks_batch = response_masks[i * fbs : (i + 1) * fbs]
             input_ids = input_kwargs["input_ids"]
             attention_mask = input_kwargs["attention_mask"]
@@ -235,12 +256,19 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
             for j in range(len(query_batch)):
                 start = len(query_batch[j]) - 1
                 if attention_mask[j, 0] == 0:  # offset left padding
                     start += attention_mask[j, :].nonzero()[0]
                 end = start + len(response_batch[j])
 
+                if response_masks is not None:
+                    response_masks_batch = torch.cat(
+                        (torch.zeros_like(query_batch[j]), response_masks_batch[j])
+                    )[1:]
+
                 masks[j, :start] = 0
                 masks[j, end:] = 0
+                if response_masks is not None:
+                    masks[j, start:end] = masks[j, start:end] * response_masks_batch[j][start:end]
 
             if return_logits:
                 all_logits.append(logits)
@@ -266,3 +294,6 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         """
         if self.args.should_save:
             self._save(output_dir)
+            self.save_callback.on_save(
+                self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
+            )
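This hook is what the commit title refers to: `_save` alone could write a checkpoint whose `adapter_model.bin` held the wrong state dict, so a dedicated callback now re-saves the unwrapped model on every checkpoint. The body of `SavePeftModelCallback` is not part of this diff; the sketch below is a plausible reconstruction under that assumption, not the actual implementation:

```python
import os
from transformers import TrainerCallback

class SavePeftModelCallbackSketch(TrainerCallback):
    """Hypothetical stand-in: after each checkpoint is written, save the PEFT
    backbone itself so adapter_model.bin contains the adapter weights."""

    def on_save(self, args, state, control, **kwargs):
        model = kwargs.get("model")
        output_dir = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
        # a value-head model wraps the PEFT model as `pretrained_model`
        backbone = getattr(model, "pretrained_model", model)
        if hasattr(backbone, "save_pretrained"):
            backbone.save_pretrained(output_dir)
        return control
```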
llmtuner/tuner/ppo/workflow.py
@@ -4,14 +4,14 @@ import math
 from trl import PPOConfig
 from torch.optim import AdamW
 from typing import TYPE_CHECKING, Optional, List
-from transformers import DataCollatorForSeq2Seq
+from transformers import DataCollatorWithPadding
 from transformers.optimization import get_scheduler
-from transformers.utils.versions import require_version
 
 from llmtuner.dsets import get_dataset, preprocess_dataset
+from llmtuner.extras.callbacks import SavePeftModelCallback
 from llmtuner.extras.ploting import plot_loss
 from llmtuner.tuner.core import load_model_and_tokenizer
-from llmtuner.tuner.ppo.trainer import PPOPeftTrainer
+from llmtuner.tuner.ppo.trainer import CustomPPOTrainer
 
 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, TrainerCallback
@@ -29,7 +29,9 @@ def run_ppo(
     dataset = get_dataset(model_args, data_args)
     model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="ppo")
     dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="ppo")
-    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=tokenizer.pad_token_id)
+
+    tokenizer.padding_side = "left"  # use left-padding in generation while using right-padding in training
+    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
 
     ppo_config = PPOConfig(
         model_name=model_args.model_name_or_path,
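Switching to `DataCollatorWithPadding` matches the PPO data flow: rollout batches only need `input_ids` and `attention_mask`, and there are no labels to pad, which is all `DataCollatorForSeq2Seq`'s `label_pad_token_id` was for. A quick sketch (model name illustrative):

```python
from transformers import AutoTokenizer, DataCollatorWithPadding

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"  # left-pad prompts for decoder-only generation

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
batch = data_collator([{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}])
print(batch["input_ids"].shape, batch["attention_mask"].shape)  # both (2, 3)
```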
@@ -44,7 +46,6 @@ def run_ppo(
     )
 
     if finetuning_args.ppo_score_norm:
-        require_version("trl>=0.5.1.dev0", "To fix: pip install git+https://github.com/huggingface/trl.git")
         ppo_config.use_score_scaling = True
         ppo_config.use_score_norm = True
@@ -61,11 +62,10 @@ def run_ppo(
     )
 
     # Initialize our Trainer
-    ppo_trainer = PPOPeftTrainer(
+    ppo_trainer = CustomPPOTrainer(
         training_args=training_args,
-        finetuning_args=finetuning_args,
         generating_args=generating_args,
-        callbacks=callbacks,
+        callbacks=callbacks + [SavePeftModelCallback()],
         compute_dtype=model_args.compute_dtype,
         config=ppo_config,
         model=model,
......
llmtuner/tuner/pt/workflow.py
@@ -2,12 +2,11 @@
 import math
 from typing import TYPE_CHECKING, Optional, List
-from transformers import DataCollatorForLanguageModeling
+from transformers import DataCollatorForLanguageModeling, Trainer
 
 from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
 from llmtuner.extras.ploting import plot_loss
 from llmtuner.tuner.core import load_model_and_tokenizer
-from llmtuner.tuner.core.trainer import PeftTrainer
 
 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, TrainerCallback
@@ -27,8 +26,7 @@ def run_pt(
     data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
 
     # Initialize our Trainer
-    trainer = PeftTrainer(
-        finetuning_args=finetuning_args,
+    trainer = Trainer(
         model=model,
         args=training_args,
         tokenizer=tokenizer,
......