import os
import math
import torch
from tqdm import tqdm
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple

from transformers import GenerationConfig, Trainer, TrainerState, TrainerControl

from trl import PPOTrainer
from trl.core import LengthSampler, PPODecorators, logprobs_from_logits

from llmtuner.extras.logging import get_logger
from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
from llmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model

if TYPE_CHECKING:
    from transformers import Seq2SeqTrainingArguments, TrainerCallback
    from trl import AutoModelForCausalLMWithValueHead
    from llmtuner.hparams import GeneratingArguments


logger = get_logger(__name__)


class CustomPPOTrainer(PPOTrainer, Trainer):
    r"""
    Inherits PPOTrainer.
    """

    def __init__(
        self,
        training_args: "Seq2SeqTrainingArguments",
        generating_args: "GeneratingArguments",
        callbacks: List["TrainerCallback"],
        compute_dtype: torch.dtype,
        **kwargs
    ):
        PPOTrainer.__init__(self, **kwargs)
        if getattr(self.accelerator.state, "deepspeed_plugin", None) is not None:
            raise ValueError("PPOTrainer is incompatible with DeepSpeed.")

        self.args = training_args
        self.generating_args = generating_args
        self.log_callback, self.save_callback = callbacks[0], callbacks[1]
        self.compute_dtype = compute_dtype
        self.state = TrainerState()
        self.control = TrainerControl()

    def ppo_train(self, max_target_length: int) -> None:
        r"""
        Implements the training loop for the PPO stage, like _inner_training_loop() in Hugging Face's Trainer.
        """
        total_train_batch_size = (
            self.args.per_device_train_batch_size * self.args.gradient_accumulation_steps * self.args.world_size
        )
        len_dataloader = len(self.dataloader)
        num_examples = len(self.dataset)
        num_train_epochs = self.args.num_train_epochs
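        # one PPO step consumes one batch from the dataloader, so the total number of
        # optimization steps equals the number of epochs times the number of batches per epoch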
        max_steps = math.ceil(num_train_epochs * len_dataloader)

        self.state.max_steps = max_steps
        self.state.num_train_epochs = num_train_epochs
        self.state.is_local_process_zero = self.is_local_process_zero()
        self.state.is_world_process_zero = self.is_world_process_zero()

        if self.is_world_process_zero():
            logger.info("***** Running training *****")
            logger.info(f"  Num examples = {num_examples}")
            logger.info(f"  Num Epochs = {num_train_epochs}")
            logger.info(f"  Instantaneous batch size per device = {self.args.per_device_train_batch_size}")
            logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
            logger.info(f"  Gradient Accumulation steps = {self.args.gradient_accumulation_steps}")
            logger.info(f"  Total optimization steps = {max_steps}")
            logger.info(f"  Number of trainable parameters = {count_parameters(self.model)[0]}")

        # Keyword arguments for `model.generate`
        generating_args = self.generating_args.to_dict()
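        # stop generation at the EOS token or at any additional special token,
        # and use the tokenizer's pad token for padding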
        generating_args.update(dict(
            eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
            pad_token_id=self.tokenizer.pad_token_id
        ))

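        # sample a fresh target length in [max_target_length // 2, max_target_length)
        # for every generation step so that response lengths vary across rollouts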
        length_sampler = LengthSampler(max_target_length // 2, max_target_length)
        unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)

        dataiter = iter(self.dataloader)
        steps_trained = 0
        loss_meter = AverageMeter()
        reward_meter = AverageMeter()
        self.log_callback.on_train_begin(self.args, self.state, self.control)

        for step in tqdm(range(max_steps), disable=not self.is_local_process_zero()):
            batch = next(dataiter)
            steps_trained += 1

            # Cast to inference mode
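            # disable gradient checkpointing and enable the KV cache for faster generation;
            # layer-norm weights are temporarily cast to the compute dtype and restored below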
            unwrapped_model.gradient_checkpointing_disable()
            unwrapped_model.config.use_cache = True
            unwrapped_model, layer_norm_params = cast_layernorm_dtype(unwrapped_model, self.compute_dtype)
            self.model.eval()

            # Get inputs
            queries, responses = self.get_inputs(batch, length_sampler, generating_args)
            self.tokenizer.padding_side = "right" # change padding side
            rewards = self.get_rewards(queries, responses, unwrapped_model)

            # Cast to training mode
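            # re-enable gradient checkpointing (incompatible with the KV cache) and
            # restore the original layer-norm dtype before the PPO optimization step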
            unwrapped_model.gradient_checkpointing_enable()
            unwrapped_model.config.use_cache = False
            unwrapped_model, _ = cast_layernorm_dtype(unwrapped_model, self.compute_dtype, layer_norm_params)
            self.model.train()

            # Run PPO step
            stats = self.step(queries, responses, rewards)
            self.tokenizer.padding_side = "left" # restore padding side
            loss_meter.update(stats["ppo/loss/total"], n=len(rewards))
            reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards))

            self.state.global_step += 1
            self.log_callback.on_step_end(self.args, self.state, self.control)

            if self.is_local_process_zero() and (step+1) % self.args.logging_steps == 0:
                logs = dict(
                    loss=round(loss_meter.avg, 4),
                    reward=round(reward_meter.avg, 4),
                    learning_rate=stats["ppo/learning_rate"],
                    epoch=round(step / len_dataloader, 2)
                )
                tqdm.write(str(logs))
                logs["step"] = step
                self.state.log_history.append(logs)
                self.log_callback.on_log(self.args, self.state, self.control)
                loss_meter.reset()
                reward_meter.reset()

            if (step+1) % self.args.save_steps == 0: # save checkpoint
                self.save_model(os.path.join(self.args.output_dir, f"checkpoint-{step+1}"))

            if self.control.should_epoch_stop or self.control.should_training_stop:
                break

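            # restart the dataloader once every batch of the current epoch has been consumed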
            if steps_trained == len_dataloader:
                dataiter = iter(self.dataloader)
                steps_trained = 0

        self.log_callback.on_train_end(
            self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
        )

    @torch.no_grad()
    def get_inputs(
        self,
        batch: Dict[str, torch.Tensor],
        length_sampler: Callable,
        generating_args: Dict[str, Any]
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        r"""
        Generates the model's responses given the queries.
        """
        generating_args["max_new_tokens"] = length_sampler()
        gen_kwargs = dict(
            generation_config=GenerationConfig(**generating_args),
            logits_processor=get_logits_processor(),
            **batch
        )

        input_ids = batch["input_ids"]
        unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
        response: torch.Tensor = unwrapped_model.generate(**gen_kwargs)
        query, response = input_ids.detach().cpu(), response[:, input_ids.size(-1):].detach().cpu()

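        # queries are left-padded for generation and responses are right-padded,
        # so strip the padding from the corresponding side of each sequence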
        queries, responses = [], []
        for i in range(len(query)):
            query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0]
            response_index = (response[i] != self.tokenizer.pad_token_id).nonzero()

            if len(response_index) == 0:
                response_length = 1 # allow empty response
            elif self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
                response_length = response_index[-1] + 2 # keep the EOS token
            else:
                response_length = response_index[-1] + 1

            queries.append(query[i, query_length:]) # remove padding from left
            responses.append(response[i, :response_length]) # remove padding from right

        return queries, responses

    @torch.no_grad()
    def get_rewards(
        self,
        queries: List[torch.Tensor],
        responses: List[torch.Tensor],
        unwrapped_model: "AutoModelForCausalLMWithValueHead"
    ) -> List[torch.Tensor]:
        r"""
        Computes scores using the given reward model.
        """
        replace_model(unwrapped_model, target="reward")
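        # the value head now produces reward scores; the default (policy) weights are
        # swapped back in by replace_model(..., target="default") at the end of this method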
        batch = self.prepare_model_inputs(queries, responses)

        with torch.cuda.amp.autocast(dtype=self.compute_dtype): # support bf16
            _, _, values = self.model(**batch, output_hidden_states=True, return_dict=True)

        if values.size(0) != batch["input_ids"].size(0): # adapt to chatglm2
            values = torch.transpose(values, 0, 1)

        rewards = []
        for i in range(values.size(0)):
            end_index = batch["attention_mask"][i].nonzero()[-1] # use the score on the EOS token
            rewards.append(values[i, end_index].float().detach().cpu()) # use fp32 type

        replace_model(unwrapped_model, target="default")
        return rewards

    @PPODecorators.empty_cuda_cache()
    def batched_forward_pass(
        self,
        model: "AutoModelForCausalLMWithValueHead",
        queries: torch.Tensor,
        responses: torch.Tensor,
        model_inputs: dict,
        return_logits: Optional[bool] = False,
        response_masks: Optional[torch.Tensor] = None
    ):
        r"""
        Calculates model outputs in multiple batches.

        Subclass and override to inject custom behavior.
        """
        bs = len(queries)
        fbs = self.config.mini_batch_size
        all_logprobs = []
        all_logits = []
        all_masks = []
        all_values = []

        for i in range(math.ceil(bs / fbs)):
            input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()}
            query_batch = queries[i * fbs : (i + 1) * fbs]
            response_batch = responses[i * fbs : (i + 1) * fbs]
            if response_masks is not None:
                response_masks_batch = response_masks[i * fbs : (i + 1) * fbs]
            input_ids = input_kwargs["input_ids"]
            attention_mask = input_kwargs["attention_mask"]

            with torch.cuda.amp.autocast(dtype=self.compute_dtype): # support bf16
                logits, _, values = model(**input_kwargs)

            if values.size(0) != input_ids.size(0): # adapt to chatglm2
                values = torch.transpose(values, 0, 1)

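            # log-probabilities are computed for predicting the next token, so the labels
            # and the attention mask are shifted one position to stay aligned with the logits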
            logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
            masks = torch.zeros_like(attention_mask)
            masks[:, :-1] = attention_mask[:, 1:]

            for j in range(len(query_batch)):
                start = len(query_batch[j]) - 1
                if attention_mask[j, 0] == 0: # offset left padding
                    start += attention_mask[j, :].nonzero()[0]
                end = start + len(response_batch[j])

                if response_masks is not None:
                    response_masks_batch = torch.cat(
                        (torch.zeros_like(query_batch[j]), response_masks_batch[j])
                    )[1:]

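                # keep only the response tokens in the mask so the PPO loss ignores
                # the query (and any left padding)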
                masks[j, :start] = 0
                masks[j, end:] = 0
                if response_masks is not None:
                    masks[j, start:end] = masks[j, start:end] * response_masks_batch[j][start:end]

            if return_logits:
                all_logits.append(logits)
            else:
                del logits

            all_values.append(values)
            all_logprobs.append(logprobs)
            all_masks.append(masks)

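        # drop the final position so the logits, values and masks align with the
        # next-token log-probabilities, whose length is seq_len - 1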
        return (
            torch.cat(all_logprobs),
            torch.cat(all_logits)[:, :-1] if return_logits else None,
            torch.cat(all_values)[:, :-1],
            torch.cat(all_masks)[:, :-1],
        )

    def save_model(self, output_dir: Optional[str] = None) -> None:
        r"""
        Saves model checkpoint.

        Subclass and override to inject custom behavior.
        """
        if self.args.should_save:
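            # `should_save` is true only for processes responsible for writing checkpoints
            # (typically the main process), so other ranks skip the disk write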
            self._save(output_dir)
            self.save_callback.on_save(
                self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
            )