Commit 0938ae70 authored by zhaoying1

fix save method of adapter_model.bin

parent 1b73554f
@@ -2,9 +2,9 @@ import os
 import json
 import torch
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+from transformers import Trainer
 from llmtuner.extras.logging import get_logger
-from llmtuner.tuner.core.trainer import PeftTrainer

 if TYPE_CHECKING:
     from transformers.trainer import PredictionOutput
@@ -14,7 +14,7 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)

-class PairwisePeftTrainer(PeftTrainer):
+class PairwiseTrainer(Trainer):
     r"""
     Inherits PeftTrainer to compute pairwise loss.
     """
@@ -32,21 +32,54 @@ class PairwisePeftTrainer(PeftTrainer):
         r"""
         Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.
-        We use score on the EOS token to represent reward of the whole sentence.
-        Subclass and override to inject custom behavior.
+        Subclass and override to inject custom behavior.
+        It should not be directly used by external scripts.

         Note that the first element will be removed from the output tuple.
         See: https://github.com/huggingface/transformers/blob/v4.30.2/src/transformers/trainer.py#L3509
         """
-        batch_size = inputs["input_ids"].size(0) // 2
+        # Compute rewards
         _, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
         if values.size(0) != inputs["input_ids"].size(0): # adapt to chatglm2
             values = torch.transpose(values, 0, 1)
-        r_accept, r_reject = values[:, -1].split(batch_size, dim=0)
-        loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
-        return (loss, [loss, r_accept, r_reject]) if return_outputs else loss
+
+        # Split the inputs and rewards into two parts, chosen and rejected
+        batch_size = inputs["input_ids"].size(0) // 2
+        chosen_input_ids, rejected_input_ids = inputs["input_ids"][:batch_size], inputs["input_ids"][batch_size:]
+        chosen_attn_mask, rejected_attn_mask = (
+            inputs["attention_mask"][:batch_size], inputs["attention_mask"][batch_size:]
+        )
+        chosen_rewards, rejected_rewards = values[:batch_size], values[batch_size:]
+        chosen_scores, rejected_scores = [], []
+
+        # Compute pairwise loss. Only backprop on the different tokens before padding
+        # Inspired by: https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/reward_model.py
+        loss = 0
+        for i in range(batch_size):
+            chosen_length = chosen_attn_mask[i].nonzero()[-1] + 1
+            rejected_length = rejected_attn_mask[i].nonzero()[-1] + 1
+            check_divergence = (chosen_input_ids[i] != rejected_input_ids[i]).nonzero()
+
+            if len(check_divergence) == 0:
+                end_index = chosen_length
+                div_index = end_index - 1
+            else:
+                end_index = max(chosen_length, rejected_length)
+                div_index = check_divergence[0]
+
+            assert div_index > 0
+            chosen_trunc_rewards = chosen_rewards[i, div_index:end_index]
+            rejected_trunc_rewards = rejected_rewards[i, div_index:end_index]
+            if return_outputs: # use the score on the EOS token for inference
+                chosen_scores.append(chosen_rewards[i, chosen_length-1])
+                rejected_scores.append(rejected_rewards[i, rejected_length-1])
+            loss += -torch.nn.functional.logsigmoid(chosen_trunc_rewards - rejected_trunc_rewards).mean()
+
+        loss = loss / batch_size
+        if return_outputs:
+            chosen_scores, rejected_scores = torch.stack(chosen_scores), torch.stack(rejected_scores)
+            return loss, [loss, chosen_scores, rejected_scores]
+
+        return loss

     def save_predictions(
         self,
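For reference, the new compute_loss accumulates a pairwise log-sigmoid (Bradley-Terry style) loss over the reward slices where the chosen and rejected sequences diverge. Below is a minimal standalone sketch of that loss; the toy tensors and the helper name pairwise_loss are illustrative and not part of the diff.

import torch
import torch.nn.functional as F

def pairwise_loss(chosen_rewards: torch.Tensor, rejected_rewards: torch.Tensor,
                  div_index: int, end_index: int) -> torch.Tensor:
    # Mirror the div_index/end_index slicing above: only positions where the two
    # sequences differ (and are not padding) contribute to the loss.
    chosen_trunc = chosen_rewards[div_index:end_index]
    rejected_trunc = rejected_rewards[div_index:end_index]
    return -F.logsigmoid(chosen_trunc - rejected_trunc).mean()

# Toy per-token rewards for a pair whose responses diverge at position 3.
chosen = torch.tensor([0.1, 0.2, 0.3, 1.5, 1.7, 2.0])
rejected = torch.tensor([0.1, 0.2, 0.3, -0.4, -0.2, 0.1])
print(pairwise_loss(chosen, rejected, div_index=3, end_index=6))  # small value: chosen ranked above rejected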
@@ -63,10 +96,10 @@ class PairwisePeftTrainer(PeftTrainer):
         output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
         logger.info(f"Saving prediction results to {output_prediction_file}")
-        acc_scores, rej_scores = predict_results.predictions
+        chosen_scores, rejected_scores = predict_results.predictions

         with open(output_prediction_file, "w", encoding="utf-8") as writer:
             res: List[str] = []
-            for acc_score, rej_score in zip(acc_scores, rej_scores):
-                res.append(json.dumps({"accept": round(float(acc_score), 2), "reject": round(float(rej_score), 2)}))
+            for c_score, r_score in zip(chosen_scores, rejected_scores):
+                res.append(json.dumps({"chosen": round(float(c_score), 2), "rejected": round(float(r_score), 2)}))
             writer.write("\n".join(res))
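With the renamed keys, each record in generated_predictions.jsonl now carries "chosen" and "rejected" scores rounded to two decimals. The snippet below only illustrates the record format; the numeric values are made up.

import json
# Illustrative record only; actual scores depend on the reward model.
print(json.dumps({"chosen": round(1.2345, 2), "rejected": round(-0.6789, 2)}))
# {"chosen": 1.23, "rejected": -0.68}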
 # Inspired by:
-# https://github.com/lvwerra/trl/blob/main/examples/summarization/scripts/reward_summarization.py
 # https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py

 from typing import TYPE_CHECKING, Optional, List
 from transformers import Seq2SeqTrainingArguments

 from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
+from llmtuner.extras.callbacks import SavePeftModelCallback
 from llmtuner.extras.ploting import plot_loss
 from llmtuner.tuner.core import load_model_and_tokenizer
 from llmtuner.tuner.rm.metric import compute_accuracy
 from llmtuner.tuner.rm.collator import PairwiseDataCollatorWithPadding
-from llmtuner.tuner.rm.trainer import PairwisePeftTrainer
+from llmtuner.tuner.rm.trainer import PairwiseTrainer

 if TYPE_CHECKING:
     from transformers import TrainerCallback
@@ -34,13 +34,12 @@ def run_rm(
     training_args = Seq2SeqTrainingArguments(**training_args_dict)

     # Initialize our Trainer
-    trainer = PairwisePeftTrainer(
-        finetuning_args=finetuning_args,
+    trainer = PairwiseTrainer(
         model=model,
         args=training_args,
         tokenizer=tokenizer,
         data_collator=data_collator,
-        callbacks=callbacks,
+        callbacks=callbacks + [SavePeftModelCallback()],
         compute_metrics=compute_accuracy,
         **split_dataset(dataset, data_args, training_args)
     )
...
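The SavePeftModelCallback added above ties this diff to the commit message: with PairwiseTrainer inheriting directly from transformers' Trainer instead of PeftTrainer, saving the adapter weights (adapter_model.bin) is handled by a callback. Its implementation is not shown in this diff; the sketch below is only a hypothetical illustration of what such a callback could do, assuming the wrapped model exposes a PEFT-style save_pretrained. The class name AdapterSaveCallback and all details are assumptions, not the code in llmtuner.extras.callbacks.

import os
from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl

class AdapterSaveCallback(TrainerCallback):  # hypothetical stand-in, not SavePeftModelCallback itself
    def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        # Trainer passes the model among the callback kwargs; save only the adapter
        # weights into the current checkpoint directory.
        model = kwargs.get("model")
        output_dir = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
        if model is not None and hasattr(model, "save_pretrained"):
            model.save_pretrained(output_dir)  # for a PEFT model this writes adapter_model.bin
        return control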
@@ -4,10 +4,10 @@ import torch
 import numpy as np
 import torch.nn as nn
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+from transformers import Seq2SeqTrainer

 from llmtuner.extras.constants import IGNORE_INDEX
 from llmtuner.extras.logging import get_logger
-from llmtuner.tuner.core.trainer import PeftTrainer

 if TYPE_CHECKING:
     from transformers.trainer import PredictionOutput
@@ -16,7 +16,7 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)

-class Seq2SeqPeftTrainer(PeftTrainer):
+class CustomSeq2SeqTrainer(Seq2SeqTrainer):
     r"""
     Inherits PeftTrainer to compute generative metrics such as BLEU and ROUGE.
     """
@@ -33,27 +33,29 @@ class Seq2SeqPeftTrainer(PeftTrainer):
         Subclass and override to inject custom behavior.
         """
-        prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1)
-        if prompt_len > label_len:
-            inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"])
-        if label_len > prompt_len:
-            inputs["input_ids"] = self._pad_tensors_to_target_len(inputs["input_ids"], inputs["labels"])
-            if "attention_mask" in inputs:
-                inputs["attention_mask"] = self._pad_tensors_to_target_len(
-                    inputs["attention_mask"], inputs["labels"], pad_token_id=0
-                )
-            if "position_ids" in inputs:
-                inputs["position_ids"] = self._pad_tensors_to_target_len(
-                    inputs["position_ids"], inputs["labels"], pad_token_id=0
-                )
+        if self.args.predict_with_generate:
+            assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor."
+            assert self.tokenizer.pad_token_id is not None, "Pad token is required."
+            prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1)
+            if prompt_len > label_len:
+                inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"])
+            if label_len > prompt_len:
+                inputs["input_ids"] = self._pad_tensors_to_target_len(inputs["input_ids"], inputs["labels"])
+                if "attention_mask" in inputs:
+                    inputs["attention_mask"] = self._pad_tensors_to_target_len(
+                        inputs["attention_mask"], inputs["labels"], pad_token_id=0
+                    )
+                if "position_ids" in inputs:
+                    inputs["position_ids"] = self._pad_tensors_to_target_len(
+                        inputs["position_ids"], inputs["labels"], pad_token_id=0
+                    )

         loss, generated_tokens, labels = super().prediction_step(
             model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
         )
-        if generated_tokens is not None:
-            generated_tokens[:, :max(prompt_len, label_len)] = (
-                self.tokenizer.pad_token_id * torch.ones_like(generated_tokens[:, :max(prompt_len, label_len)])
-            )
+        if generated_tokens is not None and self.args.predict_with_generate:
+            generated_tokens[:, :max(prompt_len, label_len)] = self.tokenizer.pad_token_id
+            generated_tokens = generated_tokens.contiguous()

         return loss, generated_tokens, labels
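As a standalone illustration of the masking step above: when predict_with_generate is on, the first max(prompt_len, label_len) columns of generated_tokens hold the left-padded prompt, so they are overwritten with the pad token before metrics are computed. The toy values below are illustrative only.

import torch

pad_token_id = 0
prompt_len, label_len = 4, 3
generated_tokens = torch.arange(1, 13).reshape(2, 6)  # pretend batch of 2 sequences, length 6
generated_tokens[:, :max(prompt_len, label_len)] = pad_token_id  # blank out the prompt part
print(generated_tokens)
# tensor([[ 0,  0,  0,  0,  5,  6],
#         [ 0,  0,  0,  0, 11, 12]])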
@@ -65,16 +67,8 @@ class Seq2SeqPeftTrainer(PeftTrainer):
     ) -> torch.Tensor:
         r"""
         Pads the tensor to the same length as the target tensor.
-
-        Should only be called when predict_with_generate=True.
         """
-        if pad_token_id is None:
-            if self.tokenizer is not None and hasattr(self.tokenizer, "pad_token_id"):
-                assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor."
-                pad_token_id = self.tokenizer.pad_token_id
-            else:
-                raise ValueError("PAD token is required.")
+        pad_token_id = pad_token_id if pad_token_id is not None else self.tokenizer.pad_token_id
         padded_tensor = pad_token_id * torch.ones_like(tgt_tensor)
         padded_tensor[:, -src_tensor.shape[-1]:] = src_tensor # adopt left-padding
         return padded_tensor.contiguous() # in contiguous memory
...
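For clarity, the left-padding that _pad_tensors_to_target_len performs can be reproduced in isolation; the free function below is an illustrative stand-in for the method, not the trainer code itself.

import torch

def pad_to_target_len(src_tensor: torch.Tensor, tgt_tensor: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    padded_tensor = pad_token_id * torch.ones_like(tgt_tensor)  # all-pad tensor shaped like the target
    padded_tensor[:, -src_tensor.shape[-1]:] = src_tensor       # right-align the source, i.e. left padding
    return padded_tensor.contiguous()

src = torch.tensor([[7, 8, 9]])               # shorter tensor, e.g. labels
tgt = torch.zeros(1, 5, dtype=torch.long)     # target shape, e.g. input_ids of length 5
print(pad_to_target_len(src, tgt, pad_token_id=0))  # tensor([[0, 0, 7, 8, 9]])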