Commit 0938ae70 authored by zhaoying1

fix save method of adapter_model.bin

parent 1b73554f
llmtuner/tuner/rm/trainer.py

@@ -2,9 +2,9 @@ import os
 import json
 import torch
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+from transformers import Trainer
 
 from llmtuner.extras.logging import get_logger
-from llmtuner.tuner.core.trainer import PeftTrainer
 
 if TYPE_CHECKING:
     from transformers.trainer import PredictionOutput
@@ -14,7 +14,7 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
-class PairwisePeftTrainer(PeftTrainer):
+class PairwiseTrainer(Trainer):
     r"""
     Inherits PeftTrainer to compute pairwise loss.
     """
@@ -32,21 +32,54 @@ class PairwisePeftTrainer(PeftTrainer):
         r"""
         Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.
 
-        We use score on the EOS token to represent reward of the whole sentence.
-
-        Subclass and override to inject custom behavior. It should not be directly used by external scripts.
-
-        Note that the first element will be removed from the output tuple.
+        Subclass and override to inject custom behavior.
+
+        Note that the first element will be removed from the output tuple.
+        See: https://github.com/huggingface/transformers/blob/v4.30.2/src/transformers/trainer.py#L3509
         """
-        batch_size = inputs["input_ids"].size(0) // 2
+        # Compute rewards
         _, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
         if values.size(0) != inputs["input_ids"].size(0): # adapt to chatglm2
             values = torch.transpose(values, 0, 1)
-        r_accept, r_reject = values[:, -1].split(batch_size, dim=0)
-        loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
-        return (loss, [loss, r_accept, r_reject]) if return_outputs else loss
+
+        # Split the inputs and rewards into two parts, chosen and rejected
+        batch_size = inputs["input_ids"].size(0) // 2
+        chosen_input_ids, rejected_input_ids = inputs["input_ids"][:batch_size], inputs["input_ids"][batch_size:]
+        chosen_attn_mask, rejected_attn_mask = (
+            inputs["attention_mask"][:batch_size], inputs["attention_mask"][batch_size:]
+        )
+        chosen_rewards, rejected_rewards = values[:batch_size], values[batch_size:]
+        chosen_scores, rejected_scores = [], []
+
+        # Compute pairwise loss. Only backprop on the different tokens before padding
+        # Inspired by: https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/reward_model.py
+        loss = 0
+        for i in range(batch_size):
+            chosen_length = chosen_attn_mask[i].nonzero()[-1] + 1
+            rejected_length = rejected_attn_mask[i].nonzero()[-1] + 1
+            check_divergence = (chosen_input_ids[i] != rejected_input_ids[i]).nonzero()
+
+            if len(check_divergence) == 0:
+                end_index = chosen_length
+                div_index = end_index - 1
+            else:
+                end_index = max(chosen_length, rejected_length)
+                div_index = check_divergence[0]
+
+            assert div_index > 0
+            chosen_trunc_rewards = chosen_rewards[i, div_index:end_index]
+            rejected_trunc_rewards = rejected_rewards[i, div_index:end_index]
+            if return_outputs: # use the score on the EOS token for inference
+                chosen_scores.append(chosen_rewards[i, chosen_length-1])
+                rejected_scores.append(rejected_rewards[i, rejected_length-1])
+            loss += -torch.nn.functional.logsigmoid(chosen_trunc_rewards - rejected_trunc_rewards).mean()
+
+        loss = loss / batch_size
+        if return_outputs:
+            chosen_scores, rejected_scores = torch.stack(chosen_scores), torch.stack(rejected_scores)
+            return loss, [loss, chosen_scores, rejected_scores]
+
+        return loss
 
     def save_predictions(
         self,
@@ -63,10 +96,10 @@ class PairwisePeftTrainer(PeftTrainer):
         output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
         logger.info(f"Saving prediction results to {output_prediction_file}")
-        acc_scores, rej_scores = predict_results.predictions
+        chosen_scores, rejected_scores = predict_results.predictions
 
         with open(output_prediction_file, "w", encoding="utf-8") as writer:
             res: List[str] = []
-            for acc_score, rej_score in zip(acc_scores, rej_scores):
-                res.append(json.dumps({"accept": round(float(acc_score), 2), "reject": round(float(rej_score), 2)}))
+            for c_score, r_score in zip(chosen_scores, rejected_scores):
+                res.append(json.dumps({"chosen": round(float(c_score), 2), "rejected": round(float(r_score), 2)}))
             writer.write("\n".join(res))
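
Note: the rewritten compute_loss replaces the old final-token objective with a pairwise loss averaged over the token span where the chosen and rejected sequences diverge, i.e. -logsigmoid(r_chosen - r_rejected) per token. A minimal standalone sketch of that arithmetic on made-up per-token rewards (the values, shapes, and div/end indices below are illustrative only, not from the repo):

```python
# Sketch of the truncated pairwise loss on dummy per-token rewards.
import torch

# Per-token rewards for one chosen/rejected pair (seq_len = 6)
chosen_rewards = torch.tensor([0.1, 0.3, 0.8, 1.2, 1.5, 1.7])
rejected_rewards = torch.tensor([0.1, 0.3, 0.2, -0.1, -0.4, -0.6])

# Tokens 0-1 are the shared prompt, so the loss skips them
div_index, end_index = 2, 6
chosen_trunc = chosen_rewards[div_index:end_index]
rejected_trunc = rejected_rewards[div_index:end_index]

# -log(sigmoid(r_chosen - r_rejected)), averaged over the divergent tokens
loss = -torch.nn.functional.logsigmoid(chosen_trunc - rejected_trunc).mean()
print(loss.item())  # small positive number, since chosen outscores rejected
```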
llmtuner/tuner/rm/workflow.py

 # Inspired by:
 # https://github.com/lvwerra/trl/blob/main/examples/summarization/scripts/reward_summarization.py
 # https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
 
 from typing import TYPE_CHECKING, Optional, List
 from transformers import Seq2SeqTrainingArguments
 
 from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
+from llmtuner.extras.callbacks import SavePeftModelCallback
 from llmtuner.extras.ploting import plot_loss
 from llmtuner.tuner.core import load_model_and_tokenizer
 from llmtuner.tuner.rm.metric import compute_accuracy
 from llmtuner.tuner.rm.collator import PairwiseDataCollatorWithPadding
-from llmtuner.tuner.rm.trainer import PairwisePeftTrainer
+from llmtuner.tuner.rm.trainer import PairwiseTrainer
 
 if TYPE_CHECKING:
     from transformers import TrainerCallback
@@ -34,13 +34,12 @@ def run_rm(
     training_args = Seq2SeqTrainingArguments(**training_args_dict)
 
     # Initialize our Trainer
-    trainer = PairwisePeftTrainer(
-        finetuning_args=finetuning_args,
+    trainer = PairwiseTrainer(
         model=model,
         args=training_args,
         tokenizer=tokenizer,
         data_collator=data_collator,
-        callbacks=callbacks,
+        callbacks=callbacks + [SavePeftModelCallback()],
         compute_metrics=compute_accuracy,
         **split_dataset(dataset, data_args, training_args)
     )
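
Note: the adapter_model.bin fix named in the commit message lives in SavePeftModelCallback (imported from llmtuner.extras.callbacks), whose body is not part of this diff. Below is a rough, hypothetical sketch of the idea it implements — re-saving only the PEFT adapter weights into each checkpoint folder, since the value-head wrapper around the reward model breaks the Trainer's default save. It assumes a trl-style model exposing a pretrained_model attribute; the class and attribute names are illustrative, not the repo's actual code:

```python
# Hypothetical reconstruction, NOT the code from llmtuner.extras.callbacks.
import os
from transformers import TrainerCallback
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR


class PeftSavingCallback(TrainerCallback):  # illustrative name
    def on_save(self, args, state, control, **kwargs):
        # Trainer has just written checkpoint-<step>; re-save the adapter there
        checkpoint_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
        model = kwargs.get("model")
        if model is not None and hasattr(model, "pretrained_model"):
            # assumption: the value-head wrapper exposes the PEFT model as
            # pretrained_model; save_pretrained writes adapter_model.bin
            model.pretrained_model.save_pretrained(checkpoint_dir)
        return control
```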
llmtuner/tuner/sft/trainer.py

@@ -4,10 +4,10 @@ import torch
 import numpy as np
 import torch.nn as nn
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+from transformers import Seq2SeqTrainer
 
 from llmtuner.extras.constants import IGNORE_INDEX
 from llmtuner.extras.logging import get_logger
-from llmtuner.tuner.core.trainer import PeftTrainer
 
 if TYPE_CHECKING:
     from transformers.trainer import PredictionOutput
@@ -16,7 +16,7 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
-class Seq2SeqPeftTrainer(PeftTrainer):
+class CustomSeq2SeqTrainer(Seq2SeqTrainer):
     r"""
     Inherits PeftTrainer to compute generative metrics such as BLEU and ROUGE.
     """
@@ -33,27 +33,29 @@ class Seq2SeqPeftTrainer(PeftTrainer):
         Subclass and override to inject custom behavior.
         """
-        prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1)
-        if prompt_len > label_len:
-            inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"])
-        if label_len > prompt_len:
-            inputs["input_ids"] = self._pad_tensors_to_target_len(inputs["input_ids"], inputs["labels"])
-            if "attention_mask" in inputs:
-                inputs["attention_mask"] = self._pad_tensors_to_target_len(
-                    inputs["attention_mask"], inputs["labels"], pad_token_id=0
-                )
-            if "position_ids" in inputs:
-                inputs["position_ids"] = self._pad_tensors_to_target_len(
-                    inputs["position_ids"], inputs["labels"], pad_token_id=0
-                )
+        if self.args.predict_with_generate:
+            assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor."
+            assert self.tokenizer.pad_token_id is not None, "Pad token is required."
+            prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1)
+            if prompt_len > label_len:
+                inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"])
+            if label_len > prompt_len:
+                inputs["input_ids"] = self._pad_tensors_to_target_len(inputs["input_ids"], inputs["labels"])
+                if "attention_mask" in inputs:
+                    inputs["attention_mask"] = self._pad_tensors_to_target_len(
+                        inputs["attention_mask"], inputs["labels"], pad_token_id=0
+                    )
+                if "position_ids" in inputs:
+                    inputs["position_ids"] = self._pad_tensors_to_target_len(
+                        inputs["position_ids"], inputs["labels"], pad_token_id=0
+                    )
 
         loss, generated_tokens, labels = super().prediction_step(
             model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
         )
-        if generated_tokens is not None:
-            generated_tokens[:, :max(prompt_len, label_len)] = (
-                self.tokenizer.pad_token_id * torch.ones_like(generated_tokens[:, :max(prompt_len, label_len)])
-            )
+        if generated_tokens is not None and self.args.predict_with_generate:
+            generated_tokens[:, :max(prompt_len, label_len)] = self.tokenizer.pad_token_id
+            generated_tokens = generated_tokens.contiguous()
 
         return loss, generated_tokens, labels
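
Note: the masking step exists because, with left-padded inputs, generate() returns the prompt followed by the completion, so the first max(prompt_len, label_len) positions are overwritten with pad ids to keep generative metrics from scoring prompt tokens. A toy illustration (dummy token ids; pad_token_id assumed to be 0):

```python
# Toy illustration of the prompt-masking step above (dummy ids, pad_token_id=0).
import torch

pad_token_id = 0
prompt_len = 3
# generate() output layout: [prompt tokens | generated tokens]
generated_tokens = torch.tensor([[11, 12, 13, 21, 22, 23]])

generated_tokens[:, :prompt_len] = pad_token_id
print(generated_tokens)  # tensor([[ 0,  0,  0, 21, 22, 23]])
```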
@@ -65,16 +67,8 @@ class Seq2SeqPeftTrainer(PeftTrainer):
     ) -> torch.Tensor:
         r"""
         Pads the tensor to the same length as the target tensor.
+
+        Should only be called when predict_with_generate=True.
         """
-        if pad_token_id is None:
-            if self.tokenizer is not None and hasattr(self.tokenizer, "pad_token_id"):
-                assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor."
-                pad_token_id = self.tokenizer.pad_token_id
-            else:
-                raise ValueError("PAD token is required.")
+        pad_token_id = pad_token_id if pad_token_id is not None else self.tokenizer.pad_token_id
         padded_tensor = pad_token_id * torch.ones_like(tgt_tensor)
         padded_tensor[:, -src_tensor.shape[-1]:] = src_tensor # adopt left-padding
         return padded_tensor.contiguous() # in contiguous memory
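
Note: the simplified helper builds a pad-filled tensor shaped like tgt_tensor and copies src_tensor into its right-most slots, i.e. left-padding. A quick standalone check of that behavior (dummy values; pad_token_id assumed to be 0):

```python
# Standalone check of the left-padding behavior (dummy values, pad_token_id=0).
import torch

pad_token_id = 0
src_tensor = torch.tensor([[5, 6, 7]])              # source length 3
tgt_tensor = torch.zeros((1, 5), dtype=torch.long)  # target length 5

padded = pad_token_id * torch.ones_like(tgt_tensor)
padded[:, -src_tensor.shape[-1]:] = src_tensor  # adopt left-padding
print(padded)  # tensor([[0, 0, 5, 6, 7]])
```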