import os
import json
import torch
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from transformers import Trainer

from llmtuner.extras.logging import get_logger

if TYPE_CHECKING:
    from transformers.trainer import PredictionOutput
    from transformers.modeling_utils import PreTrainedModel


logger = get_logger(__name__)


class PairwiseTrainer(Trainer):
    r"""
    Inherits PeftTrainer to compute pairwise loss.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.can_return_loss = True # override property to return eval_loss
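        # (In transformers' Trainer.prediction_step, the eval loss is only kept
        # when labels are present or can_return_loss is set; reward pairs carry
        # no label column, hence the override.)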

    def compute_loss(
        self,
        model: "PreTrainedModel",
        inputs: Dict[str, torch.Tensor],
        return_outputs: Optional[bool] = False
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
        r"""
        Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.

        Subclass and override to inject custom behavior.

        Note that the first element will be removed from the output tuple.
        See: https://github.com/huggingface/transformers/blob/v4.30.2/src/transformers/trainer.py#L3509
        """
        # Compute rewards
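        # Assumes a value-head model (e.g. trl's AutoModelForCausalLMWithValueHead)
        # whose forward pass returns (lm_logits, loss, values); only the
        # per-token values are needed here.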
        _, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
        if values.size(0) != inputs["input_ids"].size(0): # adapt to chatglm2
            values = torch.transpose(values, 0, 1)
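        # `values` should now have shape (batch_size, seq_len): one scalar
        # reward per token for every sequence in the doubled batch.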

        # Split the inputs and rewards into two parts, chosen and rejected
        batch_size = inputs["input_ids"].size(0) // 2
        chosen_input_ids, rejected_input_ids = inputs["input_ids"][:batch_size], inputs["input_ids"][batch_size:]
        chosen_attn_mask, rejected_attn_mask = (
            inputs["attention_mask"][:batch_size], inputs["attention_mask"][batch_size:]
        )
        chosen_rewards, rejected_rewards = values[:batch_size], values[batch_size:]
        chosen_scores, rejected_scores = [], []

        # Compute pairwise loss. Only backpropagate through the tokens that differ, before the padding
        # Inspired by: https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/reward_model.py
        loss = 0
        for i in range(batch_size):
            chosen_length = chosen_attn_mask[i].nonzero()[-1] + 1
            rejected_length = rejected_attn_mask[i].nonzero()[-1] + 1
            check_divergence = (chosen_input_ids[i] != rejected_input_ids[i]).nonzero()

            if len(check_divergence) == 0:
                end_index = chosen_length
                div_index = end_index - 1
            else:
                end_index = max(chosen_length, rejected_length)
                div_index = check_divergence[0]

            assert div_index > 0
            chosen_trunc_rewards = chosen_rewards[i, div_index:end_index]
            rejected_trunc_rewards = rejected_rewards[i, div_index:end_index]
            if return_outputs: # use the score on the EOS token for inference
                chosen_scores.append(chosen_rewards[i, chosen_length-1])
                rejected_scores.append(rejected_rewards[i, rejected_length-1])
            loss += -torch.nn.functional.logsigmoid(chosen_trunc_rewards - rejected_trunc_rewards).mean()

        loss = loss / batch_size
        if return_outputs:
            chosen_scores, rejected_scores = torch.stack(chosen_scores), torch.stack(rejected_scores)
            return loss, [loss, chosen_scores, rejected_scores]

        return loss
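
    # A minimal numeric sketch of the truncation logic above, using hypothetical
    # token ids for illustration only:
    #   chosen_input_ids[i]   = [p, p, a, b, PAD]  -> chosen_length = 4
    #   rejected_input_ids[i] = [p, p, c, d, e]    -> rejected_length = 5
    # with a shared prompt prefix p. The first differing position is index 2, so
    # div_index = 2, end_index = max(4, 5) = 5, and this example contributes
    #   -logsigmoid(chosen_rewards[i, 2:5] - rejected_rewards[i, 2:5]).mean()
    # to the loss, i.e. only the positions where the responses diverge count.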

    def save_predictions(
        self,
        predict_results: "PredictionOutput"
    ) -> None:
        r"""
        Saves model predictions to `output_dir`.

        A custom behavior that is not contained in Seq2SeqTrainer.
        """
        if not self.is_world_process_zero():
            return

        output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
        logger.info(f"Saving prediction results to {output_prediction_file}")

        chosen_scores, rejected_scores = predict_results.predictions
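        # `predict_results.predictions` holds the (chosen_scores, rejected_scores)
        # pair from compute_loss(..., return_outputs=True), with the leading loss
        # element already stripped by the Trainer (see the note in compute_loss).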

        with open(output_prediction_file, "w", encoding="utf-8") as writer:
            res: List[str] = []
            for c_score, r_score in zip(chosen_scores, rejected_scores):
                res.append(json.dumps({"chosen": round(float(c_score), 2), "rejected": round(float(r_score), 2)}))
            writer.write("\n".join(res))
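

# A minimal runnable sketch of the pairwise objective on toy reward tensors.
# All numbers below are hypothetical and purely illustrative; this block is not
# part of the trainer itself.
if __name__ == "__main__":
    # Per-token rewards over a hypothetical three-token diverged span.
    chosen_trunc_rewards = torch.tensor([0.8, 1.2, 0.5])
    rejected_trunc_rewards = torch.tensor([0.1, -0.3, 0.4])
    # Same formula as in compute_loss: the reward margin goes through
    # logsigmoid, so the loss shrinks as chosen rewards exceed rejected ones.
    toy_loss = -torch.nn.functional.logsigmoid(
        chosen_trunc_rewards - rejected_trunc_rewards
    ).mean()
    print(f"toy pairwise loss: {toy_loss.item():.4f}")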