# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by HuggingFace's TRL library.
# https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/ppo_trainer.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import os
import sys
import warnings
from types import MethodType
from typing import TYPE_CHECKING, Any, Optional

import torch
from accelerate.utils import DistributedDataParallelKwargs
from tqdm import tqdm
from transformers import GenerationConfig, Trainer, TrainerControl, TrainerState
from transformers.optimization import get_scheduler
from transformers.trainer import DEFAULT_CALLBACKS
from transformers.trainer_callback import CallbackHandler
from transformers.trainer_pt_utils import remove_dummy_checkpoint
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
from trl import PPOConfig, PPOTrainer
from trl import __version__ as trl_version
from trl.models.utils import unwrap_model_for_generation
from typing_extensions import override

from ...extras import logging
from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor, torch_gc
from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
from .ppo_utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm


if TYPE_CHECKING:
    from datasets import Dataset
    from transformers import (
        DataCollatorWithPadding,
        PreTrainedTokenizer,
        ProcessorMixin,
        Seq2SeqTrainingArguments,
        TrainerCallback,
    )
    from trl import AutoModelForCausalLMWithValueHead

    from ...hparams import FinetuningArguments, GeneratingArguments, ModelArguments


logger = logging.get_logger(__name__)


class CustomPPOTrainer(PPOTrainer, Trainer):
    r"""Inherit PPOTrainer."""

    def __init__(
        self,
        model_args: "ModelArguments",
        training_args: "Seq2SeqTrainingArguments",
        finetuning_args: "FinetuningArguments",
        generating_args: "GeneratingArguments",
        callbacks: Optional[list["TrainerCallback"]],
        model: "AutoModelForCausalLMWithValueHead",
        reward_model: Optional["AutoModelForCausalLMWithValueHead"],
        ref_model: Optional["AutoModelForCausalLMWithValueHead"],
        tokenizer: "PreTrainedTokenizer",
        processor: Optional["ProcessorMixin"],
        data_collator: "DataCollatorWithPadding",
        train_dataset: Optional["Dataset"] = None,
        eval_dataset: Optional["Dataset"] = None,
    ) -> None:
        if eval_dataset is not None:
            raise NotImplementedError("PPOTrainer does not support eval dataset yet.")

        # Check that the installed TRL version is compatible (0.8.6 <= version <= 0.9.6)
        from transformers.utils.versions import require_version

        require_version(
            "trl>=0.8.6,<=0.9.6",
            "Incompatible TRL version detected. LLaMA-Factory PPO training requires trl>=0.8.6,<=0.9.6, "
            f"but found version {trl_version}. Install a compatible version with: `pip install \"trl>=0.8.6,<=0.9.6\"`, "
            "or bypass this check with: `DISABLE_VERSION_CHECK=1 llamafactory-cli train example_ppo.yaml`",
        )

        backward_batch_size = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
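        # NOTE: in TRL's PPOConfig, `batch_size` is the number of rollout samples collected before
        # each optimization phase, while `mini_batch_size` samples are consumed per forward/backward
        # pass, so one rollout is split into batch_size / mini_batch_size passes per PPO epoch.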
        ppo_config = PPOConfig(
            model_name=model_args.model_name_or_path,
            learning_rate=training_args.learning_rate,
            mini_batch_size=training_args.per_device_train_batch_size,
            batch_size=backward_batch_size * finetuning_args.ppo_buffer_size,
            gradient_accumulation_steps=training_args.gradient_accumulation_steps,
            ppo_epochs=finetuning_args.ppo_epochs,
            max_grad_norm=training_args.max_grad_norm,
            seed=training_args.seed,
            optimize_device_cache=True,
            target=finetuning_args.ppo_target,
            use_score_scaling=finetuning_args.ppo_score_norm,
            use_score_norm=finetuning_args.ppo_score_norm,
            whiten_rewards=finetuning_args.ppo_whiten_rewards,
            accelerator_kwargs={"step_scheduler_with_optimizer": False},
            log_with=training_args.report_to[0] if training_args.report_to else None,
            project_kwargs={"logging_dir": training_args.logging_dir},
        )

        # Add deepspeed config
        if training_args.deepspeed_plugin is not None:
            ppo_config.accelerator_kwargs["kwargs_handlers"] = [
                DistributedDataParallelKwargs(find_unused_parameters=training_args.ddp_find_unused_parameters)
            ]
            ppo_config.accelerator_kwargs["deepspeed_plugin"] = training_args.deepspeed_plugin
            if ppo_config.log_with is not None:
                logger.warning_rank0("PPOTrainer cannot use external logger when DeepSpeed is enabled.")
                ppo_config.log_with = None

        # Create optimizer and scheduler
        if training_args.max_steps > 0:
            num_training_steps = training_args.max_steps
        else:
            total_train_batch_size = backward_batch_size * finetuning_args.ppo_buffer_size * training_args.world_size
            num_training_steps = training_args.num_train_epochs * math.ceil(
                len(train_dataset) / total_train_batch_size
            )

        optimizer = self.create_optimizer(model, training_args, finetuning_args)
        scheduler = self.create_scheduler(training_args, num_training_steps, optimizer)

        PPOTrainer.__init__(
            self,
            config=ppo_config,
            model=model,
            ref_model=ref_model,
            tokenizer=tokenizer,
            dataset=train_dataset,
            optimizer=optimizer,
            data_collator=data_collator,
            lr_scheduler=scheduler,
        )

        self.args = training_args
        self.model_args = model_args
        self.finetuning_args = finetuning_args
        self.reward_model = reward_model
        self.current_device = get_current_device()  # patch for deepspeed training

        self.generation_config = GenerationConfig(
            pad_token_id=self.tokenizer.pad_token_id,
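            # treat the tokenizer's additional special tokens as extra end-of-sequence markers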
            eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
            **generating_args.to_dict(),
        )

        self.state = TrainerState()
        self.control = TrainerControl()
        self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
        self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
        callbacks = DEFAULT_CALLBACKS if callbacks is None else DEFAULT_CALLBACKS + callbacks
        self.callback_handler = CallbackHandler(
            callbacks, self.accelerator.unwrap_model(self.model), self.tokenizer, self.optimizer, self.lr_scheduler
        )
        if self.args.max_steps > 0:
            logger.info_rank0("max_steps is given, it will override any value given in num_train_epochs")

        self.amp_context = torch.autocast(self.current_device.type)
        warnings.simplefilter("ignore")  # remove gc warnings on ref model

        if finetuning_args.reward_model_type == "full":
            if self.is_deepspeed_enabled:
                if not (
                    getattr(reward_model.pretrained_model, "is_loaded_in_8bit", False)
                    or getattr(reward_model.pretrained_model, "is_loaded_in_4bit", False)
                ):  # quantized models are already set on the correct device
                    self.reward_model = self._prepare_deepspeed(self.reward_model)
            else:
                self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True)

        self.add_callback(FixValueHeadModelCallback)

        if processor is not None:
            self.add_callback(SaveProcessorCallback(processor))

        if finetuning_args.use_badam:
            from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore

            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
            self.add_callback(BAdamCallback)

    def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None:
        r"""Implement training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer."""
        if resume_from_checkpoint is not None:
            raise ValueError("`resume_from_checkpoint` will be supported in a future version.")

        total_train_batch_size = (
            self.args.per_device_train_batch_size
            * self.args.gradient_accumulation_steps
            * self.finetuning_args.ppo_buffer_size
            * self.args.world_size
        )
        if self.args.max_steps > 0:
            num_examples = total_train_batch_size * self.args.max_steps
            num_train_epochs = sys.maxsize
            max_steps = self.args.max_steps
            steps_in_epoch = self.args.max_steps
        else:
            len_dataloader = len(self.dataloader)
            num_examples = len(self.dataset)
            num_train_epochs = self.args.num_train_epochs
            max_steps = math.ceil(num_train_epochs * len_dataloader)
            steps_in_epoch = len_dataloader

        self.state.max_steps = max_steps
        self.state.num_train_epochs = num_train_epochs
        self.state.is_local_process_zero = self.is_local_process_zero()
        self.state.is_world_process_zero = self.is_world_process_zero()

        logger.info_rank0("***** Running training *****")
        logger.info_rank0(f"  Num examples = {num_examples:,}")
        logger.info_rank0(f"  Num Epochs = {num_train_epochs:,}")
        logger.info_rank0(f"  Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}")
        logger.info_rank0(
            f"  Total train batch size (w. parallel, buffer, distributed & accumulation) = {total_train_batch_size:,}"
        )
        logger.info_rank0(f"  Gradient Accumulation steps = {self.args.gradient_accumulation_steps:,}")
        logger.info_rank0(f"  Num optimization epochs per batch = {self.finetuning_args.ppo_epochs:,}")
        logger.info_rank0(f"  Total training steps = {max_steps:,}")
        logger.info_rank0(f"  Number of trainable parameters = {count_parameters(self.model)[0]:,}")

        dataiter = iter(self.dataloader)
        loss_meter = AverageMeter()
        reward_meter = AverageMeter()
        self.callback_handler.on_train_begin(self.args, self.state, self.control)

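        # Each step: (1) draw a rollout batch, (2) generate responses and score them with the
        # reward model under torch.no_grad(), (3) run `ppo_epochs` of PPO updates via self.step().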
        for step in tqdm(range(max_steps), disable=not self.is_local_process_zero()):
            try:
                batch = next(dataiter)
            except StopIteration:
                dataiter = iter(self.dataloader)
                batch = next(dataiter)

            # Get inputs
            self.model.eval()
            self.tokenizer.padding_side = "right"  # change padding side
            queries, responses, rewards = [], [], []
            for idx in range(0, self.config.batch_size, self.config.mini_batch_size):
                mini_batch = {
                    "input_ids": batch["input_ids"][idx : idx + self.config.mini_batch_size],
                    "attention_mask": batch["attention_mask"][idx : idx + self.config.mini_batch_size],
                }
                mini_batch_queries, mini_batch_responses = self.get_inputs(mini_batch)
                mini_batch_rewards = self.get_rewards(mini_batch_queries, mini_batch_responses)
                queries.extend(mini_batch_queries)
                responses.extend(mini_batch_responses)
                rewards.extend(mini_batch_rewards)

            # Run PPO step
            self.model.train()
            stats = self.step(queries, responses, rewards)
            self.tokenizer.padding_side = "left"  # restore padding side
            loss_meter.update(float(stats["ppo/loss/total"]), n=len(rewards))
            reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards))

            if self.config.log_with is not None:
                try:
                    batch["query"] = self.tokenizer.batch_decode(queries, skip_special_tokens=True)
                    batch["response"] = self.tokenizer.batch_decode(responses, skip_special_tokens=True)
                    self.log_stats(stats, batch, rewards)
                except Exception:
                    logger.warning_rank0("Failed to save stats due to unknown errors.")

            self.state.global_step += 1
            self.callback_handler.on_step_end(self.args, self.state, self.control)

            if self.is_local_process_zero() and (step + 1) % self.args.logging_steps == 0:
                logs = dict(
                    loss=round(loss_meter.avg, 4),
                    reward=round(reward_meter.avg, 4),
                    learning_rate=stats["ppo/learning_rate"],
                    epoch=round(step / steps_in_epoch, 2),
                )
                tqdm.write(str(logs))
                logs["step"] = step
                self.state.log_history.append(logs)
                self.callback_handler.on_log(self.args, self.state, self.control, logs)
                loss_meter.reset()
                reward_meter.reset()

            if (step + 1) % self.args.save_steps == 0:  # save checkpoint
                self.save_model(
                    os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
                )
                self.callback_handler.on_save(self.args, self.state, self.control)

            if self.control.should_epoch_stop or self.control.should_training_stop:
                break

        self.callback_handler.on_train_end(self.args, self.state, self.control)

    @override
    def create_optimizer(
        self,
        model: "AutoModelForCausalLMWithValueHead",
        training_args: "Seq2SeqTrainingArguments",
        finetuning_args: "FinetuningArguments",
    ) -> "torch.optim.Optimizer":
        optimizer = create_custom_optimizer(model, training_args, finetuning_args)
        if optimizer is None:
            decay_params, nodecay_params = [], []
            decay_param_names = self.get_decay_parameter_names(model)
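            # get_decay_parameter_names() excludes bias and layer-norm parameters,
            # which are conventionally exempt from weight decay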
            for name, param in model.named_parameters():
                if param.requires_grad:
                    if name in decay_param_names:
                        decay_params.append(param)
                    else:
                        nodecay_params.append(param)

            optim_class, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
            param_groups = [
                dict(params=nodecay_params),
                dict(params=decay_params, weight_decay=training_args.weight_decay),
            ]
            optimizer = optim_class(param_groups, **optim_kwargs)

        return optimizer

    @override
    def create_scheduler(
        self, training_args: "Seq2SeqTrainingArguments", num_training_steps: int, optimizer: "torch.optim.Optimizer"
    ) -> "torch.optim.lr_scheduler.LRScheduler":
        create_custom_scheduler(training_args, num_training_steps, optimizer)
        lr_scheduler = get_scheduler(
            training_args.lr_scheduler_type,
            optimizer=optimizer,
            num_warmup_steps=training_args.get_warmup_steps(num_training_steps),
            num_training_steps=num_training_steps,
        )
        return lr_scheduler

    @torch.no_grad()
    def get_inputs(self, batch: dict[str, "torch.Tensor"]) -> tuple[list["torch.Tensor"], list["torch.Tensor"]]:
        r"""Generate model's responses given queries."""
        if batch["input_ids"].size(0) == 1:  # handle llama2 ppo with gradient accumulation > 1
            start_index = (batch["input_ids"][0] != self.tokenizer.pad_token_id).nonzero()[0].item()
            for k, v in batch.items():
                batch[k] = v[:, start_index:]

        with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
            if self.model_args.upcast_layernorm:
                layernorm_params = dump_layernorm(unwrapped_model)

            generate_output: torch.Tensor = unwrapped_model.generate(
                generation_config=self.generation_config, logits_processor=get_logits_processor(), **batch
            )
            if self.model_args.upcast_layernorm:
                restore_layernorm(unwrapped_model, layernorm_params)

        query = batch["input_ids"].detach().cpu()
        response = generate_output[:, batch["input_ids"].size(-1) :].detach().cpu()
        queries, responses = [], []
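        # queries are left-padded and responses are right-padded, so strip the leading pad tokens
        # from each query and the trailing pad tokens (keeping eos) from each response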
        for i in range(len(query)):
            query_start_index = (query[i] != self.tokenizer.pad_token_id).nonzero()[0].item()
            response_indexes = (response[i] != self.tokenizer.pad_token_id).nonzero()
            if len(response_indexes) == 0:  # allow empty response
                response_length = 1
            elif self.tokenizer.eos_token_id == self.tokenizer.pad_token_id:  # include eos token
                response_length = response_indexes[-1].item() + 2
            else:
                response_length = response_indexes[-1].item() + 1

            queries.append(query[i, query_start_index:])  # remove padding from left
            responses.append(response[i, :response_length])  # remove padding from right

        return queries, responses

    @torch.no_grad()
    def get_rewards(
        self,
        queries: list["torch.Tensor"],
        responses: list["torch.Tensor"],
    ) -> list["torch.Tensor"]:
        r"""Compute scores using given reward model.

        Both inputs and outputs are put on CPU.
        """
        if self.finetuning_args.reward_model_type == "api":
            token_ids = [torch.cat((q, r), dim=-1).tolist() for q, r in zip(queries, responses)]
            messages = self.tokenizer.batch_decode(token_ids, skip_special_tokens=False)
            return get_rewards_from_server(self.reward_model, messages)

        batch: dict[str, torch.Tensor] = self.prepare_model_inputs(queries, responses)
        unwrapped_model: AutoModelForCausalLMWithValueHead = self.accelerator.unwrap_model(self.model)

        if self.finetuning_args.reward_model_type in ["lora", "oft"]:
            replace_model(unwrapped_model, target="reward")
            reward_model = self.model
        else:
            reward_model = self.reward_model

        with unwrap_model_for_generation(reward_model, self.accelerator), self.amp_context:  # support bf16
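            # the value-head model returns a (lm_logits, loss, values) tuple; index -1 takes values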
            values: torch.Tensor = reward_model(**batch, return_dict=True, use_cache=False)[-1]

        if self.finetuning_args.reward_model_type in ["lora", "oft"]:
            replace_model(unwrapped_model, target="default")

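        # `values` has shape (batch_size, seq_len); take the value at the last non-padded position
        # (index attention_mask.sum(-1) - 1) of each sequence as its scalar reward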
        rewards = values.gather(dim=-1, index=(batch["attention_mask"].sum(dim=-1, keepdim=True) - 1))
        return rewards.float().detach()  # use fp32 type

    @override
    def batched_forward_pass(
        self,
        model: "AutoModelForCausalLMWithValueHead",
        queries: "torch.Tensor",
        responses: "torch.Tensor",
        model_inputs: dict[str, Any],
        return_logits: bool = False,
        response_masks: Optional["torch.Tensor"] = None,
    ) -> tuple["torch.Tensor", Optional["torch.Tensor"], "torch.Tensor", "torch.Tensor"]:
        r"""Calculate model outputs in multiple batches.

        Subclass and override to inject custom behavior.
        """
        from trl.core import logprobs_from_logits

        torch_gc()
        bs = len(queries)
        fbs = self.config.mini_batch_size
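        # evaluate the batch in chunks of `mini_batch_size` to bound peak activation memory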
        all_logprobs = []
        all_logits = []
        all_masks = []
        all_values = []

        for i in range(math.ceil(bs / fbs)):
            input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()}
            query_batch = queries[i * fbs : (i + 1) * fbs]
            response_batch = responses[i * fbs : (i + 1) * fbs]
            if response_masks is not None:
                response_masks_batch = response_masks[i * fbs : (i + 1) * fbs]
            input_ids = input_kwargs["input_ids"]
            attention_mask = input_kwargs["attention_mask"]

            with self.amp_context:  # support bf16
                logits, _, values = model(**input_kwargs, return_dict=True, use_cache=False)

            logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
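            # logits at position t score the token at position t + 1, hence the one-token shift
            # applied to the labels above and to the attention mask below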
            masks = torch.zeros_like(attention_mask)
            masks[:, :-1] = attention_mask[:, 1:]

            for j in range(len(query_batch)):
                start = len(query_batch[j]) - 1
                if attention_mask[j, 0] == 0:  # offset left padding
                    start += attention_mask[j, :].nonzero()[0].item()
                end = start + len(response_batch[j])

                if response_masks is not None:
                    response_masks_batch = torch.cat((torch.zeros_like(query_batch[j]), response_masks_batch[j]))[1:]

                masks[j, :start] = 0
                masks[j, end:] = 0
                if response_masks is not None:
                    masks[j, start:end] = masks[j, start:end] * response_masks_batch[j][start:end]

            if return_logits:
                all_logits.append(logits)
            else:
                del logits

            all_values.append(values)
            all_logprobs.append(logprobs)
            all_masks.append(masks)

        return (
            torch.cat(all_logprobs),
            torch.cat(all_logits)[:, :-1] if return_logits else None,
            torch.cat(all_values)[:, :-1],
            torch.cat(all_masks)[:, :-1],
        )

    @override
    def save_model(self, output_dir: Optional[str] = None) -> None:
        r"""Save model checkpoint.

        Subclass and override to inject custom behavior.
        """
        if output_dir is None:
            output_dir = self.args.output_dir

        if self.is_fsdp_enabled or self.is_deepspeed_enabled:
            try:
                state_dict = self.accelerator.get_state_dict(self.model)  # must be called at all ranks
                if self.args.should_save:
                    self._save(output_dir, state_dict=state_dict)
            except ValueError:
                logger.warning_rank0(
                    " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead,"
                    " use zero_to_fp32.py to recover weights"
                )
                if self.args.should_save:
                    self._save(output_dir, state_dict={})
                # remove the dummy state_dict
                remove_dummy_checkpoint(self.args.should_save, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
                self.model.save_checkpoint(output_dir)

        elif self.args.should_save:
            unwrapped_model: AutoModelForCausalLMWithValueHead = self.accelerator.unwrap_model(self.model)
            self._save(output_dir, state_dict=unwrapped_model.state_dict())