# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
FSDP PPO Trainer with Ray-based single controller.
This trainer supports model-agnostic model initialization with HuggingFace.
"""

import os
import uuid
from collections import defaultdict
from contextlib import contextmanager
from copy import deepcopy
from dataclasses import dataclass, field
from enum import Enum, IntEnum, auto
from typing import Any, Callable, Dict, List, Optional, Tuple, Type

import numpy as np
import ray
import torch
from codetiming import Timer
from ray.experimental.tqdm_ray import tqdm
from torchdata.stateful_dataloader import StatefulDataLoader
from transformers import PreTrainedTokenizer, ProcessorMixin

from ..protocol import DataProto, pad_dataproto_to_divisor, unpad_dataproto
from ..single_controller.base import Worker
from ..single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
from ..single_controller.ray.base import create_colocated_worker_cls
from ..utils import torch_functional as VF
from ..utils.checkpoint import CHECKPOINT_TRACKER, remove_obsolete_ckpt
from ..utils.logger import Tracker
from ..utils.py_functional import convert_dict_to_str
from ..utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance
from ..workers.fsdp_workers import FSDPWorker
from . import core_algos
from .config import PPOConfig
from .metrics import compute_data_metrics, compute_throughout_metrics, compute_timing_metrics, reduce_metrics


class Role(IntEnum):
    """
    To create more roles dynamically, you can subclass Role and add new members
    """

    Actor = auto()
    Rollout = auto()
    ActorRollout = auto()
    Critic = auto()
    RefPolicy = auto()
    RewardModel = auto()
    ActorRolloutRef = auto()


class AdvantageEstimator(str, Enum):
    """
    Using an enumeration class to avoid spelling errors in adv_estimator
    """

    GAE = "gae"
    GRPO = "grpo"
    REINFORCE_PLUS_PLUS = "reinforce_plus_plus"
    REMAX = "remax"
    RLOO = "rloo"


@dataclass
class ResourcePoolManager:
    """
    Define a resource pool specification. Resource pool will be initialized first.
    """

    resource_pool_spec: dict[str, list[int]]
    mapping: dict[Role, str]
    resource_pool_dict: dict[str, RayResourcePool] = field(default_factory=dict)
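
    # Example (illustrative): a single pool spanning two nodes with 8 GPUs each,
    # with every role mapped onto it:
    #   resource_pool_spec = {"global_pool": [8, 8]}
    #   mapping = {Role.ActorRollout: "global_pool", Role.Critic: "global_pool", ...}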

    def create_resource_pool(self):
        for resource_pool_name, process_on_nodes in self.resource_pool_spec.items():
            # max_colocate_count means the number of WorkerGroups (i.e. processes) in each RayResourcePool
            # For the FSDP backend, we recommend max_colocate_count=1, which merges all WorkerGroups into one.
            # For the Megatron backend, we recommend max_colocate_count>1, so that different models can use different WorkerGroups.
            resource_pool = RayResourcePool(
                process_on_nodes=process_on_nodes, use_gpu=True, max_colocate_count=1, name_prefix=resource_pool_name
            )
            self.resource_pool_dict[resource_pool_name] = resource_pool

        self._check_resource_available()

    def get_resource_pool(self, role: Role) -> RayResourcePool:
        """Get the resource pool of the worker."""
        return self.resource_pool_dict[self.mapping[role]]

    def get_num_gpus(self) -> int:
        """Get the number of gpus in this cluster."""
        return sum([n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes])

    def _check_resource_available(self):
        """Check if the resource pool can be satisfied in this ray cluster."""
        gpus_available = ray.available_resources().get("GPU", 0)
        gpus_required = self.get_num_gpus()
        if gpus_available < gpus_required:
            raise ValueError(f"Total available GPUs {gpus_available} is less than total desired GPUs {gpus_required}.")


def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.KLController, kl_penalty="kl"):
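    """Apply a token-level KL penalty to the reward scores.

    Computes the KL between the rollout policy and the reference policy on response tokens,
    sets `token_level_rewards = token_level_scores - kl_coef * kld`, and updates the KL
    controller with the batch-averaged KL. Returns the updated batch and KL metrics.
    """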
    token_level_scores = data.batch["token_level_scores"]
    batch_size = data.batch.batch_size[0]
    response_mask = data.batch["response_mask"]

    # compute kl between ref_policy and current policy
    kld = core_algos.compute_kl(data.batch["old_log_probs"], data.batch["ref_log_probs"], kl_penalty=kl_penalty)
    kld = kld * response_mask  # (batch_size, response_length)

    data.batch["token_level_rewards"] = token_level_scores - kl_ctrl.kl_coef * kld

    current_kl = VF.masked_mean(kld, mask=response_mask, dim=-1)  # average over sequence
    current_kl = torch.mean(current_kl, dim=0).item()
    metrics = {"critic/kl": current_kl, "critic/kl_coef": kl_ctrl.kl_coef}

    # According to https://github.com/huggingface/trl/blob/v0.11.0/trl/trainer/ppo_trainer.py#L880
    kl_ctrl.update(current_kl=current_kl, n_steps=batch_size)
    return data, metrics


def compute_advantage(data: DataProto, adv_estimator: AdvantageEstimator, gamma: float = 1.0, lam: float = 1.0):
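    """Compute advantages and returns on the driver process.

    GAE uses the critic values; GRPO and RLOO build group-wise baselines from rollouts
    sharing the same `uid`; REMAX subtracts the precomputed `reward_baselines`;
    REINFORCE++ uses discounted token-level returns.
    """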
    token_level_rewards = data.batch["token_level_rewards"]
    response_mask = data.batch["response_mask"]
    index = data.non_tensor_batch["uid"]
    if adv_estimator == AdvantageEstimator.GAE:
        values = data.batch["values"]
        advantages, returns = core_algos.compute_gae_advantage_return(
            token_level_rewards, values, response_mask, gamma, lam
        )
    elif adv_estimator == AdvantageEstimator.GRPO:
        advantages, returns = core_algos.compute_grpo_outcome_advantage(token_level_rewards, response_mask, index)
    elif adv_estimator == AdvantageEstimator.REINFORCE_PLUS_PLUS:
        advantages, returns = core_algos.compute_reinforce_plus_plus_outcome_advantage(
            token_level_rewards, response_mask, gamma
        )
    elif adv_estimator == AdvantageEstimator.REMAX:
        reward_baselines = data.batch["reward_baselines"]
        advantages, returns = core_algos.compute_remax_outcome_advantage(
            token_level_rewards, reward_baselines, response_mask
        )
    elif adv_estimator == AdvantageEstimator.RLOO:
        advantages, returns = core_algos.compute_rloo_outcome_advantage(token_level_rewards, response_mask, index)
    else:
        raise NotImplementedError

    data.batch["advantages"] = advantages
    data.batch["returns"] = returns
    return data


@contextmanager
def _timer(name: str, timing_raw: Dict[str, float]):
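    """Time the enclosed block and record the elapsed seconds in `timing_raw[name]`."""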
    with Timer(name=name, logger=None) as timer:
        yield

    timing_raw[name] = timer.last


class RayPPOTrainer:
    """
    Note that this trainer runs on the driver process on a single CPU/GPU node.
    """

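    # Minimal usage sketch (illustrative; the config, tokenizer, dataloaders, worker mapping,
    # resource pool manager and reward functions are assumed to be prepared by a launcher script):
    #
    #   trainer = RayPPOTrainer(
    #       config=config,
    #       tokenizer=tokenizer,
    #       processor=None,
    #       train_dataloader=train_dataloader,
    #       val_dataloader=val_dataloader,
    #       role_worker_mapping={Role.ActorRollout: ray.remote(FSDPWorker)},
    #       resource_pool_manager=resource_pool_manager,
    #       reward_fn=reward_fn,
    #       val_reward_fn=val_reward_fn,
    #   )
    #   trainer.init_workers()  # create resource pools and worker groups
    #   trainer.fit()           # run the PPO training loop
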
    def __init__(
        self,
        config: PPOConfig,
        tokenizer: PreTrainedTokenizer,
        processor: Optional[ProcessorMixin],
        train_dataloader: StatefulDataLoader,
        val_dataloader: StatefulDataLoader,
        role_worker_mapping: dict[Role, Type[Worker]],
        resource_pool_manager: ResourcePoolManager,
        ray_worker_group_cls: Type[RayWorkerGroup] = RayWorkerGroup,
        reward_fn: Optional[Callable[[DataProto], Tuple[torch.Tensor, Dict[str, List[float]]]]] = None,
        val_reward_fn: Optional[Callable[[DataProto], Tuple[torch.Tensor, Dict[str, List[float]]]]] = None,
    ):
        self.tokenizer = tokenizer
        self.processor = processor
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.config = config
        self.reward_fn = reward_fn
        self.val_reward_fn = val_reward_fn

        self.hybrid_engine = config.worker.hybrid_engine
        if self.hybrid_engine:
            assert Role.ActorRollout in role_worker_mapping, (
                f"ActorRollout should be included in {role_worker_mapping.keys()}."
            )
        else:
            raise NotImplementedError

        self.role_worker_mapping = role_worker_mapping
        self.resource_pool_manager = resource_pool_manager
        self.use_reward_model = Role.RewardModel in role_worker_mapping
        self.ray_worker_group_cls = ray_worker_group_cls

        # define KL control
        if Role.RefPolicy in role_worker_mapping and not config.algorithm.disable_kl:
            self.use_reference_policy = True
            self.kl_ctrl = core_algos.get_kl_controller(config.algorithm)
        else:
            self.use_reference_policy = False
            self.kl_ctrl = core_algos.FixedKLController(init_kl_coef=0.0)
            print("KL penalty is disabled; no KL metrics will be logged. To log KL metrics without a penalty, set `kl_coef=0` instead of disabling KL.")

        if config.algorithm.adv_estimator == AdvantageEstimator.GAE:
            self.use_critic = True
        else:
            self.use_critic = False

        if config.algorithm.adv_estimator not in list(AdvantageEstimator):
            raise NotImplementedError(f"Unknown advantage estimator: {config.algorithm.adv_estimator}.")

        if config.data.rollout_batch_size % config.worker.actor.global_batch_size != 0:
            raise ValueError("Rollout batch size must be divisible by actor global batch size.")

        if (
            config.data.rollout_batch_size * config.worker.rollout.n
        ) % config.worker.actor.micro_batch_size_per_device_for_experience != 0:
            raise ValueError(
                "Rollout batch size * rollout.n must be divisible by actor micro batch size for experience."
            )

        if self.use_critic:
            if config.data.rollout_batch_size % config.worker.critic.global_batch_size != 0:
                raise ValueError("Rollout batch size must be divisible by critic global batch size.")

            if (
                config.data.rollout_batch_size * config.worker.rollout.n
            ) % config.worker.critic.micro_batch_size_per_device_for_experience != 0:
                raise ValueError(
                    "Rollout batch size * rollout.n must be divisible by critic micro batch size for experience."
                )

        if (
            config.algorithm.adv_estimator in (AdvantageEstimator.GRPO, AdvantageEstimator.RLOO)
            and config.worker.rollout.n == 1
        ):
            raise ValueError("GRPO and RLOO algorithms require `config.worker.rollout.n > 1`.")

        if config.trainer.max_steps is not None:
            self.training_steps = config.trainer.max_steps
        else:
            self.training_steps = len(train_dataloader) * config.trainer.total_episodes

        config.worker.actor.optim.training_steps = self.training_steps
        config.worker.critic.optim.training_steps = self.training_steps
        print(f"Total training steps: {self.training_steps}")

    def _maybe_log_val_generations(
        self, inputs: List[str], outputs: List[str], labels: List[str], scores: List[float]
    ) -> None:
        """Log a table of validation samples"""
        if self.config.trainer.val_generations_to_log <= 0:
            return

        # Create tuples of (input, output, label, score) and sort by input text
        samples = list(zip(inputs, outputs, labels, scores))
        samples.sort(key=lambda x: x[0])  # Sort by input text

        # Use fixed random seed for deterministic shuffling
        rng = np.random.RandomState(42)
        rng.shuffle(samples)

        samples = samples[: self.config.trainer.val_generations_to_log]
        self.logger.log_generation(samples, self.global_step)

    def _validate(self) -> Dict[str, Any]:
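        """Generate on the validation set, score the outputs with `val_reward_fn`, and return aggregated reward metrics."""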
        reward_tensor_lst = []
        # Lists to collect samples for the table
        sample_inputs, sample_outputs, sample_labels, sample_scores = [], [], [], []
        reward_metrics_lst = defaultdict(list)
        for batch_dict in self.val_dataloader:
            test_batch = DataProto.from_single_dict(batch_dict)
            # Store original inputs
            input_ids = test_batch.batch["input_ids"]
            input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]
            sample_inputs.extend(input_texts)

            if "multi_modal_data" in test_batch.non_tensor_batch.keys():
                test_gen_batch = test_batch.pop(
                    batch_keys=["input_ids", "attention_mask", "position_ids"],
                    non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data"],
                )
            else:
                test_gen_batch = test_batch.pop(
                    batch_keys=["input_ids", "attention_mask", "position_ids"],
                    non_tensor_batch_keys=["raw_prompt_ids"],
                )

            test_gen_batch.meta_info = self.config.worker.rollout.val_override_config
            test_gen_batch, pad_size = pad_dataproto_to_divisor(test_gen_batch, self.actor_rollout_wg.world_size)
            test_output_gen_batch = self.actor_rollout_wg.generate_sequences(test_gen_batch)
            test_output_gen_batch = unpad_dataproto(test_output_gen_batch, pad_size=pad_size)
            print("validation generation end")

            # Store generated outputs
            output_ids = test_output_gen_batch.batch["responses"]
            output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids]
            sample_outputs.extend(output_texts)
            sample_labels.extend(test_batch.non_tensor_batch["ground_truth"].tolist())
            test_batch = test_batch.union(test_output_gen_batch)

            # evaluate using reward_function
            reward_tensor, reward_metrics = self.val_reward_fn(test_batch)

            # Store scores
            scores = reward_tensor.sum(-1).cpu().tolist()
            sample_scores.extend(scores)

            reward_tensor_lst.append(reward_tensor)
            for key, value in reward_metrics.items():
                reward_metrics_lst[key].extend(value)

        self._maybe_log_val_generations(sample_inputs, sample_outputs, sample_labels, sample_scores)
        reward_score = torch.cat(reward_tensor_lst, dim=0).sum(-1).mean().item()
        val_reward_metrics = {f"val/{key}_reward": value for key, value in reduce_metrics(reward_metrics_lst).items()}
        return {"val/reward_score": reward_score, **val_reward_metrics}

    def init_workers(self) -> None:
        """Init resource pool and worker group"""
        self.resource_pool_manager.create_resource_pool()
        self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()}

        # create actor and rollout
        if self.hybrid_engine:
            resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout)
            actor_rollout_cls = RayClassWithInitArgs(
                cls=self.role_worker_mapping[Role.ActorRollout], config=self.config.worker, role="actor_rollout"
            )
            self.resource_pool_to_cls[resource_pool]["actor_rollout"] = actor_rollout_cls
        else:
            raise NotImplementedError

        # create critic
        if self.use_critic:
            resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic)
            critic_cls = RayClassWithInitArgs(
                cls=self.role_worker_mapping[Role.Critic], config=self.config.worker, role="critic"
            )
            self.resource_pool_to_cls[resource_pool]["critic"] = critic_cls

        # create reference policy if needed
        if self.use_reference_policy:
            resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy)
            ref_policy_cls = RayClassWithInitArgs(
                self.role_worker_mapping[Role.RefPolicy], config=self.config.worker, role="ref"
            )
            self.resource_pool_to_cls[resource_pool]["ref"] = ref_policy_cls

        # create a reward model if reward_fn is None
        if self.use_reward_model:
            # we create a RM here
            resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel)
            rm_cls = RayClassWithInitArgs(
                cls=self.role_worker_mapping[Role.RewardModel], config=self.config.worker, role="reward"
            )
            self.resource_pool_to_cls[resource_pool]["rm"] = rm_cls

        # initialize WorkerGroup
        # NOTE: if you want to use a different resource pool for each role, e.g. to support different parallel sizes,
        # you should not use `create_colocated_worker_cls`. Instead, pass a different resource pool to each worker group.
        # See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information.
        all_wg: Dict[str, FSDPWorker] = {}
        self.wg_dicts = []
        for resource_pool, class_dict in self.resource_pool_to_cls.items():
            worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
            wg_dict = self.ray_worker_group_cls(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls)
            spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())
            all_wg.update(spawn_wg)
            # keep the reference of WorkerDict to support ray >= 2.31. Ref: https://github.com/ray-project/ray/pull/45699
            self.wg_dicts.append(wg_dict)

        if self.use_critic:
            self.critic_wg = all_wg["critic"]
            self.critic_wg.init_model()

        if self.use_reference_policy:
            self.ref_policy_wg = all_wg["ref"]
            self.ref_policy_wg.init_model()

        if self.use_reward_model:
            self.rm_wg = all_wg["rm"]
            self.rm_wg.init_model()

        # create the rollout workers last so that vLLM can better estimate the available KV cache memory
        self.actor_rollout_wg = all_wg["actor_rollout"]
        self.actor_rollout_wg.init_model()

    def _save_checkpoint(self) -> None:
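        """Save the actor (and critic) checkpoints, the dataloader state, and the last-step tracker file."""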
        # path: {save_checkpoint_path}/global_step_{global_step}/{actor,critic}
        remove_obsolete_ckpt(
            self.config.trainer.save_checkpoint_path, self.global_step, self.config.trainer.save_limit
        )
        folder_path = os.path.join(self.config.trainer.save_checkpoint_path, f"global_step_{self.global_step}")
        actor_path = os.path.join(folder_path, "actor")
        self.actor_rollout_wg.save_checkpoint(actor_path)

        if self.use_critic:
            critic_path = os.path.join(folder_path, "critic")
            self.critic_wg.save_checkpoint(critic_path)

        dataloader_path = os.path.join(folder_path, "dataloader.pt")
        dataloader_state_dict = self.train_dataloader.state_dict()
        torch.save(dataloader_state_dict, dataloader_path)

        last_global_step_path = os.path.join(self.config.trainer.save_checkpoint_path, CHECKPOINT_TRACKER)
        with open(last_global_step_path, "w") as f:
            f.write(str(self.global_step))

    def _load_checkpoint(self) -> None:
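        """Resume training state: restore the global step, actor/critic weights, and the dataloader state."""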
        if self.config.trainer.load_checkpoint_path is None:
            return

        if "global_step_" not in self.config.trainer.load_checkpoint_path.strip(os.path.sep).split(os.path.sep)[-1]:
            raise ValueError("`load_checkpoint_path` should end with `global_step_*`.")

        print(f"Load from checkpoint: {self.config.trainer.load_checkpoint_path}.")
        self.global_step = int(self.config.trainer.load_checkpoint_path.strip(os.path.sep).split("global_step_")[-1])
        actor_path = os.path.join(self.config.trainer.load_checkpoint_path, "actor")
        self.actor_rollout_wg.load_checkpoint(actor_path)
        if self.use_critic:
            critic_path = os.path.join(self.config.trainer.load_checkpoint_path, "critic")
            self.critic_wg.load_checkpoint(critic_path)

        dataloader_path = os.path.join(self.config.trainer.load_checkpoint_path, "dataloader.pt")
        if os.path.exists(dataloader_path):
            dataloader_state_dict = torch.load(dataloader_path, weights_only=False)
            self.train_dataloader.load_state_dict(dataloader_state_dict)
        else:
            print(f"No dataloader state found at {dataloader_path}, will start from scratch.")

    def _balance_batch(self, batch: DataProto, metrics: Dict[str, Any], logging_prefix: str = "global_seqlen") -> None:
        """Reorder the data on single controller such that each dp rank gets similar total tokens"""
        attention_mask = batch.batch["attention_mask"]
        batch_size = attention_mask.shape[0]
        global_seqlen_lst = batch.batch["attention_mask"].view(batch_size, -1).sum(-1).tolist()  # (train_batch_size,)
        world_size = self.actor_rollout_wg.world_size
        global_partition_lst = get_seqlen_balanced_partitions(
            global_seqlen_lst, k_partitions=world_size, equal_size=True
        )
        # reorder based on index. The data will be automatically equally partitioned by dispatch function
        global_idx = torch.tensor([j for partition in global_partition_lst for j in partition])
        batch.reorder(global_idx)
        global_balance_stats = log_seqlen_unbalance(
            seqlen_list=global_seqlen_lst, partitions=global_partition_lst, prefix=logging_prefix
        )
        metrics.update(global_balance_stats)

    def fit(self):
        """
        The training loop of PPO.
        The driver process only needs to call the compute functions of the worker groups through RPC to construct the PPO dataflow.
        The lightweight advantage computation is done on the driver process.
        """
        self.logger = Tracker(loggers=self.config.trainer.logger, config=self.config.to_dict())
        self.global_step = 0
        val_metrics: Optional[Dict[str, Any]] = None

        # load checkpoint before doing anything
        self._load_checkpoint()

        # perform validation before training
        # currently, we only support validation using the reward_function.
        if self.val_reward_fn is not None and self.config.trainer.val_before_train:
            val_metrics = self._validate()
            self.logger.log(data=val_metrics, step=self.global_step)
            if self.config.trainer.val_only:
                return

        for _ in tqdm(range(self.config.trainer.total_episodes), desc="Episode", position=0):
            for batch_dict in tqdm(self.train_dataloader, desc="Running step", position=1):
                self.global_step += 1
                if self.global_step > self.training_steps:
                    break

                metrics, timing_raw = {}, {}
                batch: DataProto = DataProto.from_single_dict(batch_dict)

                # pop those keys for generation
                if "multi_modal_data" in batch.non_tensor_batch.keys():
                    gen_batch = batch.pop(
                        batch_keys=["input_ids", "attention_mask", "position_ids"],
                        non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data"],
                    )
                else:
                    gen_batch = batch.pop(
                        batch_keys=["input_ids", "attention_mask", "position_ids"],
                        non_tensor_batch_keys=["raw_prompt_ids"],
                    )

                with _timer("step", timing_raw):
                    # generate a batch
                    with _timer("gen", timing_raw):  # wg: worker group
                        gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)

                    if self.config.algorithm.adv_estimator == "remax":
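                        # REMAX: use the reward of a greedy rollout (temperature 0, n=1) as the per-prompt baseline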
                        with _timer("gen_max", timing_raw):
                            gen_baseline_batch = deepcopy(gen_batch)
                            gen_baseline_batch.meta_info["temperature"] = 0
                            gen_baseline_batch.meta_info["n"] = 1
                            gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)

                            batch = batch.union(gen_baseline_output)
                            reward_baseline_tensor, _ = self.reward_fn(batch)
                            reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)

                            batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))
                            batch.batch["reward_baselines"] = reward_baseline_tensor
                            del gen_baseline_batch, gen_baseline_output

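                    # assign a unique id to each prompt so group-based estimators (GRPO/RLOO) can group its repeated rollouts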
                    batch.non_tensor_batch["uid"] = np.array(
                        [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object
                    )
                    # repeat to align with repeated responses in rollout
                    batch = batch.repeat(repeat_times=self.config.worker.rollout.n, interleave=True)
                    batch = batch.union(gen_batch_output)
                    batch.non_tensor_batch.pop("multi_modal_data", None)

                    # compute reward
                    with _timer("reward", timing_raw):
                        if self.use_reward_model:
                            raise NotImplementedError("Reward model is not supported yet.")

                        # we combine with rule-based rm
                        reward_tensor, reward_metrics = self.reward_fn(batch)
                        batch.batch["token_level_scores"] = reward_tensor
                        reward_metrics = {
                            f"reward/{key}": value for key, value in reduce_metrics(reward_metrics).items()
                        }
                        metrics.update(reward_metrics)

                    # balance the number of valid tokens on each dp rank.
                    # Note that this breaks the order of the data inside the batch.
                    # Take care when implementing group-based advantage computation such as GRPO and RLOO.
                    self._balance_batch(batch, metrics=metrics)

                    # compute global_valid tokens
                    batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()

                    # recompute old_log_probs
                    with _timer("old", timing_raw):
                        old_log_probs = self.actor_rollout_wg.compute_log_probs(batch)
                        batch = batch.union(old_log_probs)

                    # compute ref_log_probs
                    if self.use_reference_policy:
                        with _timer("ref", timing_raw):
                            ref_log_probs = self.ref_policy_wg.compute_ref_log_probs(batch)
                            batch = batch.union(ref_log_probs)

                    # compute values
                    if self.use_critic:
                        with _timer("values", timing_raw):
                            values = self.critic_wg.compute_values(batch)
                            batch = batch.union(values)

                    with _timer("adv", timing_raw):
                        # apply kl penalty if available
                        if not self.config.algorithm.use_kl_loss and self.use_reference_policy:
                            # apply kl penalty to reward
                            batch, kl_metrics = apply_kl_penalty(
                                batch, kl_ctrl=self.kl_ctrl, kl_penalty=self.config.algorithm.kl_penalty
                            )
                            metrics.update(kl_metrics)
                        else:
                            batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]

                        # compute advantages, executed on the driver process
                        batch = compute_advantage(
                            batch,
                            adv_estimator=self.config.algorithm.adv_estimator,
                            gamma=self.config.algorithm.gamma,
                            lam=self.config.algorithm.lam,
                        )

                    # update critic
                    if self.use_critic:
                        with _timer("update_critic", timing_raw):
                            critic_output = self.critic_wg.update_critic(batch)

                        critic_metrics = reduce_metrics(critic_output.non_tensor_batch)
                        metrics.update(critic_metrics)

                    # update actor
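                    # actor updates are skipped until `critic_warmup` steps have passed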
                    if self.config.trainer.critic_warmup <= self.global_step:
                        with _timer("update_actor", timing_raw):
                            actor_output = self.actor_rollout_wg.update_actor(batch)

                        actor_metrics = reduce_metrics(actor_output.non_tensor_batch)
                        metrics.update(actor_metrics)

                    # validate
                    if (
                        self.val_reward_fn is not None
                        and self.config.trainer.val_freq > 0
                        and self.global_step % self.config.trainer.val_freq == 0
                    ):
                        with _timer("validation", timing_raw):
                            val_metrics = self._validate()

                        metrics.update(val_metrics)

                    if self.config.trainer.save_freq > 0 and self.global_step % self.config.trainer.save_freq == 0:
                        with _timer("save_checkpoint", timing_raw):
                            self._save_checkpoint()

                # collect metrics
                num_gpus = self.resource_pool_manager.get_num_gpus()
                metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
                metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
                metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, num_gpus=num_gpus))

                self.logger.log(data=metrics, step=self.global_step)

        # perform validation after training
        if self.val_reward_fn is not None:
            if (
                val_metrics is None
                or self.config.trainer.val_freq <= 0
                or self.global_step % self.config.trainer.val_freq != 0
            ):
                val_metrics = self._validate()
                self.logger.log(data=val_metrics, step=self.global_step)

            print(f"Final validation metrics: {convert_dict_to_str(val_metrics)}")

        if self.config.trainer.save_freq <= 0 or self.global_step % self.config.trainer.save_freq != 0:
            self._save_checkpoint()