# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
FSDP PPO Trainer with Ray-based single controller.
This trainer supports model-agnostic model initialization with HuggingFace.
"""

import os
import uuid
from collections import defaultdict
from copy import deepcopy
from dataclasses import dataclass, field
from enum import Enum, IntEnum, auto
from typing import Any, Dict, List, Optional, Type

import numpy as np
import ray
import torch
from ray.experimental.tqdm_ray import tqdm
from torchdata.stateful_dataloader import StatefulDataLoader
from transformers import PreTrainedTokenizer, ProcessorMixin

from ..protocol import DataProto, pad_dataproto_to_divisor, unpad_dataproto
from ..single_controller.base import Worker
from ..single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
from ..single_controller.ray.base import create_colocated_worker_cls
from ..utils import torch_functional as VF
from ..utils.checkpoint import CHECKPOINT_TRACKER, remove_obsolete_ckpt
from ..utils.logger import Tracker
from ..utils.py_functional import convert_dict_to_str, timer
from ..utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance
from ..workers.fsdp_workers import FSDPWorker
from ..workers.reward import FunctionRewardManager
from . import core_algos
from .config import PPOConfig
from .metrics import compute_data_metrics, compute_throughout_metrics, compute_timing_metrics, reduce_metrics


class Role(IntEnum):
    """
    To create more roles dynamically, you can subclass Role and add new members
    """

    Actor = auto()
    Rollout = auto()
    ActorRollout = auto()
    Critic = auto()
    RefPolicy = auto()
    RewardModel = auto()
    ActorRolloutRef = auto()


class AdvantageEstimator(str, Enum):
    """
    Using an enumeration class to avoid spelling errors in adv_estimator
    """

    GAE = "gae"
    GRPO = "grpo"
    REINFORCE_PLUS_PLUS = "reinforce_plus_plus"
    REMAX = "remax"
    RLOO = "rloo"
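    # How these are used in `compute_advantage` below: GAE relies on the critic's value
    # estimates; GRPO and RLOO compare each response against the other responses sampled
    # for the same prompt (grouped by `uid`); REINFORCE++ uses discounted token-level
    # returns; REMAX subtracts a greedy-decoding reward baseline.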


@dataclass
class ResourcePoolManager:
    """
    Define a resource pool specification. Resource pool will be initialized first.
    """

    resource_pool_spec: dict[str, list[int]]
    mapping: dict[Role, str]
    resource_pool_dict: dict[str, RayResourcePool] = field(default_factory=dict)

    def create_resource_pool(self):
        for resource_pool_name, process_on_nodes in self.resource_pool_spec.items():
            # max_colocate_count means the number of WorkerGroups (i.e. processes) in each RayResourcePool
            # For the FSDP backend, we recommend max_colocate_count=1, which merges all WorkerGroups into one.
            # For the Megatron backend, we recommend max_colocate_count>1, which can use different WorkerGroups for different models.
            resource_pool = RayResourcePool(
                process_on_nodes=process_on_nodes, use_gpu=True, max_colocate_count=1, name_prefix=resource_pool_name
            )
            self.resource_pool_dict[resource_pool_name] = resource_pool

        self._check_resource_available()

    def get_resource_pool(self, role: Role) -> RayResourcePool:
        """Get the resource pool of the worker."""
        return self.resource_pool_dict[self.mapping[role]]

    def get_num_gpus(self) -> int:
        """Get the number of gpus in this cluster."""
        return sum([n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes])

    def _check_resource_available(self):
        """Check if the resource pool can be satisfied in this ray cluster."""
        gpus_available = ray.available_resources().get("GPU", 0)
        gpus_required = self.get_num_gpus()
        if gpus_available < gpus_required:
            raise ValueError(f"Total available GPUs {gpus_available} is less than total desired GPUs {gpus_required}.")


def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.KLController, kl_penalty="kl"):
    token_level_scores = data.batch["token_level_scores"]
    batch_size = data.batch.batch_size[0]
    response_mask = data.batch["response_mask"]

    # compute the KL divergence between the reference policy and the current policy
    kld = core_algos.compute_kl(data.batch["old_log_probs"], data.batch["ref_log_probs"], kl_penalty=kl_penalty)
    kld = kld * response_mask  # (batch_size, response_length)
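    # token-level reward shaping below: reward_t = score_t - kl_coef * KL_t, with the KL
    # term counted only on response tokens via the response mask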

    data.batch["token_level_rewards"] = token_level_scores - kl_ctrl.kl_coef * kld

    current_kl = VF.masked_mean(kld, mask=response_mask, dim=-1)  # average over sequence
    current_kl = torch.mean(current_kl, dim=0).item()
    metrics = {"critic/kl": current_kl, "critic/kl_coef": kl_ctrl.kl_coef}

    # According to https://github.com/huggingface/trl/blob/v0.11.0/trl/trainer/ppo_trainer.py#L880
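    # if an adaptive controller is configured, this update nudges kl_coef toward the target KL
    # based on the KL observed in this batch; a fixed controller leaves kl_coef unchanged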
    kl_ctrl.update(current_kl=current_kl, n_steps=batch_size)
    return data, metrics


def compute_advantage(data: DataProto, adv_estimator: AdvantageEstimator, gamma: float = 1.0, lam: float = 1.0):
    token_level_rewards = data.batch["token_level_rewards"]
    response_mask = data.batch["response_mask"]
    index = data.non_tensor_batch["uid"]
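    # `uid` marks which prompt each response was sampled from; the group-based
    # estimators (GRPO, RLOO) compute their baselines within each uid group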
    if adv_estimator == AdvantageEstimator.GAE:
        values = data.batch["values"]
        advantages, returns = core_algos.compute_gae_advantage_return(
            token_level_rewards, values, response_mask, gamma, lam
        )
    elif adv_estimator == AdvantageEstimator.GRPO:
        advantages, returns = core_algos.compute_grpo_outcome_advantage(token_level_rewards, response_mask, index)
    elif adv_estimator == AdvantageEstimator.REINFORCE_PLUS_PLUS:
        advantages, returns = core_algos.compute_reinforce_plus_plus_outcome_advantage(
            token_level_rewards, response_mask, gamma
        )
    elif adv_estimator == AdvantageEstimator.REMAX:
        reward_baselines = data.batch["reward_baselines"]
        advantages, returns = core_algos.compute_remax_outcome_advantage(
            token_level_rewards, reward_baselines, response_mask
        )
    elif adv_estimator == AdvantageEstimator.RLOO:
        advantages, returns = core_algos.compute_rloo_outcome_advantage(token_level_rewards, response_mask, index)
    else:
        raise NotImplementedError

    data.batch["advantages"] = advantages
    data.batch["returns"] = returns
    return data


class RayPPOTrainer:
    """
    Note that this trainer runs on the driver process on a single CPU/GPU node.
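
    Example (a minimal driver-side sketch; it assumes `config`, `tokenizer`, the dataloaders,
    and the reward managers are already constructed, and uses a single 8-GPU pool):

        trainer = RayPPOTrainer(
            config=config,
            tokenizer=tokenizer,
            processor=None,
            train_dataloader=train_dataloader,
            val_dataloader=val_dataloader,
            role_worker_mapping={Role.ActorRollout: FSDPWorker},
            resource_pool_manager=ResourcePoolManager(
                resource_pool_spec={"global_pool": [8]}, mapping={Role.ActorRollout: "global_pool"}
            ),
            reward_fn=reward_fn,
            val_reward_fn=val_reward_fn,
        )
        trainer.init_workers()
        trainer.fit()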
    """

    def __init__(
        self,
        config: PPOConfig,
        tokenizer: PreTrainedTokenizer,
        processor: Optional[ProcessorMixin],
        train_dataloader: StatefulDataLoader,
        val_dataloader: StatefulDataLoader,
        role_worker_mapping: dict[Role, Type[Worker]],
        resource_pool_manager: ResourcePoolManager,
        ray_worker_group_cls: Type[RayWorkerGroup] = RayWorkerGroup,
        reward_fn: Optional[FunctionRewardManager] = None,
        val_reward_fn: Optional[FunctionRewardManager] = None,
    ):
        self.tokenizer = tokenizer
        self.processor = processor
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.config = config
        self.reward_fn = reward_fn
        self.val_reward_fn = val_reward_fn

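        # hybrid engine: the actor and the rollout engine are colocated in a single
        # worker group (Role.ActorRollout) and therefore share the same GPUs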
        self.hybrid_engine = config.worker.hybrid_engine
        if self.hybrid_engine:
            assert Role.ActorRollout in role_worker_mapping, (
                f"ActorRollout should be included in {role_worker_mapping.keys()}."
            )
        else:
            raise NotImplementedError

        self.role_worker_mapping = role_worker_mapping
        self.resource_pool_manager = resource_pool_manager
        self.use_reward_model = Role.RewardModel in role_worker_mapping
        self.ray_worker_group_cls = ray_worker_group_cls

        # define KL control
        if Role.RefPolicy in role_worker_mapping and not config.algorithm.disable_kl:
            self.use_reference_policy = True
            self.kl_ctrl = core_algos.get_kl_controller(config.algorithm)
        else:
            self.use_reference_policy = False
            self.kl_ctrl = core_algos.FixedKLController(init_kl_coef=0.0)
            print("KL is disabled, no KL metrics will be logged. Please set `kl_coef=0` to log KL metrics.")

        if config.algorithm.adv_estimator == AdvantageEstimator.GAE:
            self.use_critic = True
        else:
            self.use_critic = False

        if config.algorithm.adv_estimator not in list(AdvantageEstimator):
            raise NotImplementedError(f"Unknown advantage estimator: {config.algorithm.adv_estimator}.")

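        # sanity checks: the (repeated) rollout batch must split evenly into the
        # global and per-device micro batches used by the actor and critic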
        if config.data.rollout_batch_size % config.worker.actor.global_batch_size != 0:
            raise ValueError("Rollout batch size must be divisible by actor global batch size.")

        if (
            config.data.rollout_batch_size * config.worker.rollout.n
        ) % config.worker.actor.micro_batch_size_per_device_for_experience != 0:
            raise ValueError(
                "Rollout batch size * rollout.n must be divisible by actor micro batch size for experience."
            )

        if self.use_critic:
            if config.data.rollout_batch_size % config.worker.critic.global_batch_size != 0:
                raise ValueError("Rollout batch size must be divisible by critic global batch size.")

            if (
                config.data.rollout_batch_size * config.worker.rollout.n
            ) % config.worker.critic.micro_batch_size_per_device_for_experience != 0:
                raise ValueError(
                    "Rollout batch size * rollout.n must be divisible by critic micro batch size for experience."
                )

        if (
            config.algorithm.adv_estimator in (AdvantageEstimator.GRPO, AdvantageEstimator.RLOO)
            and config.worker.rollout.n == 1
        ):
            raise ValueError("GRPO and RLOO algorithm need `config.worker.rollout.n > 1`.")

        if config.trainer.max_steps is not None:
            self.training_steps = config.trainer.max_steps
        else:
            self.training_steps = len(train_dataloader) * config.trainer.total_epochs

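        # propagate the training horizon to the optimizer configs (used, e.g., by LR schedulers)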
        config.worker.actor.optim.training_steps = self.training_steps
        config.worker.critic.optim.training_steps = self.training_steps
        print(f"Total training steps: {self.training_steps}")

    def _maybe_log_val_generations(
        self, inputs: List[str], outputs: List[str], labels: List[str], scores: List[float]
    ) -> None:
        """Log a table of validation samples"""
        if self.config.trainer.val_generations_to_log <= 0:
            return

        # Create tuples of (input, output, label, score) and sort by input text
        samples = list(zip(inputs, outputs, labels, scores))
        samples.sort(key=lambda x: x[0])  # Sort by input text

        # Use fixed random seed for deterministic shuffling
        rng = np.random.RandomState(42)
        rng.shuffle(samples)

        samples = samples[: self.config.trainer.val_generations_to_log]
        self.logger.log_generation(samples, self.global_step)

    def _validate(self) -> Dict[str, Any]:
        reward_tensor_lst = []
        # Lists to collect samples for the table
        sample_inputs, sample_outputs, sample_labels, sample_scores = [], [], [], []
        reward_metrics_lst = defaultdict(list)
        for batch_dict in self.val_dataloader:
            test_batch = DataProto.from_single_dict(batch_dict)
            # Store original inputs
            input_ids = test_batch.batch["input_ids"]
            input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]
            sample_inputs.extend(input_texts)

            if "multi_modal_data" in test_batch.non_tensor_batch.keys():
                test_gen_batch = test_batch.pop(
                    batch_keys=["input_ids", "attention_mask", "position_ids"],
                    non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data"],
                )
            else:
                test_gen_batch = test_batch.pop(
                    batch_keys=["input_ids", "attention_mask", "position_ids"],
                    non_tensor_batch_keys=["raw_prompt_ids"],
                )

            test_gen_batch.meta_info = self.config.worker.rollout.val_override_config
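            # pad the batch so it divides evenly across the rollout workers, generate, then strip the padding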
            test_gen_batch, pad_size = pad_dataproto_to_divisor(test_gen_batch, self.actor_rollout_wg.world_size)
            test_output_gen_batch = self.actor_rollout_wg.generate_sequences(test_gen_batch)
            test_output_gen_batch = unpad_dataproto(test_output_gen_batch, pad_size=pad_size)

            # Store generated outputs
            output_ids = test_output_gen_batch.batch["responses"]
            output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids]
            sample_outputs.extend(output_texts)
            sample_labels.extend(test_batch.non_tensor_batch["ground_truth"].tolist())
            test_batch = test_batch.union(test_output_gen_batch)

            # evaluate using reward_function
            reward_tensor, reward_metrics = ray.get(self.val_reward_fn.compute_reward.remote(test_batch))

            # Store scores
            scores = reward_tensor.sum(-1).cpu().tolist()
            sample_scores.extend(scores)

            reward_tensor_lst.append(reward_tensor)
            for key, value in reward_metrics.items():
                reward_metrics_lst[key].extend(value)

        self._maybe_log_val_generations(sample_inputs, sample_outputs, sample_labels, sample_scores)
        reward_score = torch.cat(reward_tensor_lst, dim=0).sum(-1).mean().item()
        val_reward_metrics = {f"val/{key}_reward": value for key, value in reduce_metrics(reward_metrics_lst).items()}
        return {"val/reward_score": reward_score, **val_reward_metrics}

    def init_workers(self) -> None:
        """Init resource pool and worker group"""
        self.resource_pool_manager.create_resource_pool()
        self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()}
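        # map each resource pool to the worker classes that will be colocated on it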

        # create actor and rollout
        if self.hybrid_engine:
            resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout)
            actor_rollout_cls = RayClassWithInitArgs(
                cls=self.role_worker_mapping[Role.ActorRollout], config=self.config.worker, role="actor_rollout"
            )
            self.resource_pool_to_cls[resource_pool]["actor_rollout"] = actor_rollout_cls
        else:
            raise NotImplementedError

        # create critic
        if self.use_critic:
            resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic)
            critic_cls = RayClassWithInitArgs(
                cls=self.role_worker_mapping[Role.Critic], config=self.config.worker, role="critic"
            )
            self.resource_pool_to_cls[resource_pool]["critic"] = critic_cls

        # create reference policy if needed
        if self.use_reference_policy:
            resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy)
            ref_policy_cls = RayClassWithInitArgs(
                self.role_worker_mapping[Role.RefPolicy], config=self.config.worker, role="ref"
            )
            self.resource_pool_to_cls[resource_pool]["ref"] = ref_policy_cls

        # create a reward model worker if needed
        if self.use_reward_model:
            # we create a RM here
            resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel)
            rm_cls = RayClassWithInitArgs(
                cls=self.role_worker_mapping[Role.RewardModel], config=self.config.worker, role="reward"
            )
            self.resource_pool_to_cls[resource_pool]["rm"] = rm_cls

        # initialize WorkerGroup
        # NOTE: if you want to use a different resource pool for each role, which can support different parallel sizes,
        # you should not use `create_colocated_worker_cls`. Instead, directly pass different resource pools to different worker groups.
        # See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information.
        all_wg: Dict[str, FSDPWorker] = {}
        self.wg_dicts = []
        for resource_pool, class_dict in self.resource_pool_to_cls.items():
            worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
            wg_dict = self.ray_worker_group_cls(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls)
            spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())
            all_wg.update(spawn_wg)
            # keep a reference to the WorkerDict to support ray >= 2.31. Ref: https://github.com/ray-project/ray/pull/45699
            self.wg_dicts.append(wg_dict)

        if self.use_critic:
            self.critic_wg = all_wg["critic"]
            self.critic_wg.init_model()

        if self.use_reference_policy:
            self.ref_policy_wg = all_wg["ref"]
            self.ref_policy_wg.init_model()

        if self.use_reward_model:
            self.rm_wg = all_wg["rm"]
            self.rm_wg.init_model()

        # create the rollout worker last so that vLLM can better estimate the KV cache memory
        self.actor_rollout_wg = all_wg["actor_rollout"]
        self.actor_rollout_wg.init_model()

    def _save_checkpoint(self) -> None:
        # path: {save_checkpoint_path}/global_step_{global_step}/{actor,critic}
        remove_obsolete_ckpt(
            self.config.trainer.save_checkpoint_path, self.global_step, self.config.trainer.save_limit
        )
        folder_path = os.path.join(self.config.trainer.save_checkpoint_path, f"global_step_{self.global_step}")
        actor_path = os.path.join(folder_path, "actor")
        self.actor_rollout_wg.save_checkpoint(actor_path)

        if self.use_critic:
            critic_path = os.path.join(folder_path, "critic")
            self.critic_wg.save_checkpoint(critic_path)

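        # persist the dataloader state so a resumed run continues from the same data position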
        dataloader_path = os.path.join(folder_path, "dataloader.pt")
        dataloader_state_dict = self.train_dataloader.state_dict()
        torch.save(dataloader_state_dict, dataloader_path)

        last_global_step_path = os.path.join(self.config.trainer.save_checkpoint_path, CHECKPOINT_TRACKER)
        with open(last_global_step_path, "w") as f:
            f.write(str(self.global_step))

    def _load_checkpoint(self) -> None:
        if self.config.trainer.load_checkpoint_path is None:
            return

        if "global_step_" not in self.config.trainer.load_checkpoint_path.strip(os.path.sep).split(os.path.sep)[-1]:
            raise ValueError("`load_checkpoint_path` should end with `global_step_*`.")

        print(f"Load from checkpoint: {self.config.trainer.load_checkpoint_path}.")
        self.global_step = int(self.config.trainer.load_checkpoint_path.strip(os.path.sep).split("global_step_")[-1])
        actor_path = os.path.join(self.config.trainer.load_checkpoint_path, "actor")
        self.actor_rollout_wg.load_checkpoint(actor_path)
        if self.use_critic:
            critic_path = os.path.join(self.config.trainer.load_checkpoint_path, "critic")
            self.critic_wg.load_checkpoint(critic_path)

        dataloader_path = os.path.join(self.config.trainer.load_checkpoint_path, "dataloader.pt")
        if os.path.exists(dataloader_path):
            dataloader_state_dict = torch.load(dataloader_path, weights_only=False)
            self.train_dataloader.load_state_dict(dataloader_state_dict)
        else:
            print(f"No dataloader state found at {dataloader_path}, will start from scratch.")

    def _balance_batch(self, batch: DataProto, metrics: Dict[str, Any], logging_prefix: str = "global_seqlen") -> None:
        """Reorder the data on single controller such that each dp rank gets similar total tokens"""
        attention_mask = batch.batch["attention_mask"]
        batch_size = attention_mask.shape[0]
        global_seqlen_lst = batch.batch["attention_mask"].view(batch_size, -1).sum(-1).tolist()  # (train_batch_size,)
        world_size = self.actor_rollout_wg.world_size
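        # partition sample indices into `world_size` equal-sized groups with roughly equal total sequence length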
        global_partition_lst = get_seqlen_balanced_partitions(
            global_seqlen_lst, k_partitions=world_size, equal_size=True
        )
        # reorder based on index. The data will be automatically equally partitioned by dispatch function
        global_idx = torch.tensor([j for partition in global_partition_lst for j in partition])
        batch.reorder(global_idx)
        global_balance_stats = log_seqlen_unbalance(
            seqlen_list=global_seqlen_lst, partitions=global_partition_lst, prefix=logging_prefix
        )
        metrics.update(global_balance_stats)

    def fit(self):
        """
        The training loop of PPO.
        The driver process only needs to call the compute functions of the worker group through RPC to construct the PPO dataflow.
        The lightweight advantage computation is done on the driver process.
        """
        self.logger = Tracker(loggers=self.config.trainer.logger, config=self.config.to_dict())
        self.global_step = 0
        val_metrics: Optional[Dict[str, Any]] = None

        # load checkpoint before doing anything
        self._load_checkpoint()

        # perform validation before training
        # currently, we only support validation using the reward_function.
        if self.val_reward_fn is not None and self.config.trainer.val_before_train:
            val_metrics = self._validate()
            self.logger.log(data=val_metrics, step=self.global_step)
            if self.config.trainer.val_only:
                return

        for _ in tqdm(range(self.config.trainer.total_epochs), desc="Epoch", position=0):
            for batch_dict in tqdm(self.train_dataloader, desc="Running step", position=1):
                self.global_step += 1
                if self.global_step > self.training_steps:
                    break

                metrics, timing_raw = {}, {}
                batch: DataProto = DataProto.from_single_dict(batch_dict)

                # pop those keys for generation
                if "multi_modal_data" in batch.non_tensor_batch.keys():
                    gen_batch = batch.pop(
                        batch_keys=["input_ids", "attention_mask", "position_ids"],
                        non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data"],
                    )
                else:
                    gen_batch = batch.pop(
                        batch_keys=["input_ids", "attention_mask", "position_ids"],
                        non_tensor_batch_keys=["raw_prompt_ids"],
                    )

                with timer("step", timing_raw):
                    # generate a batch
                    with timer("gen", timing_raw):  # wg: worker group
                        gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)

                    if self.config.algorithm.adv_estimator == "remax":
                        with timer("gen_max", timing_raw):
                            gen_baseline_batch = deepcopy(gen_batch)
                            gen_baseline_batch.meta_info["temperature"] = 0
                            gen_baseline_batch.meta_info["n"] = 1
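                            # greedy decoding (temperature 0, one response per prompt) provides the REMAX reward baseline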
                            gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)

                            batch = batch.union(gen_baseline_output)
                            reward_baseline_tensor, _ = ray.get(self.reward_fn.compute_reward.remote(batch))
                            reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)

                            batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))
                            batch.batch["reward_baselines"] = reward_baseline_tensor
                            del gen_baseline_batch, gen_baseline_output

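                    # tag each prompt with a unique id, then repeat the batch so every sampled
                    # response carries its prompt's uid; group-based estimators rely on this grouping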
                    batch.non_tensor_batch["uid"] = np.array(
                        [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object
                    )
                    # repeat to align with repeated responses in rollout
                    batch = batch.repeat(repeat_times=self.config.worker.rollout.n, interleave=True)
                    batch = batch.union(gen_batch_output)
                    batch.non_tensor_batch.pop("multi_modal_data", None)

                    # balance the number of valid tokens on each DP rank.
                    # Note that this breaks the order of data inside the batch.
                    # Please take care when you implement group-based advantage computation such as GRPO and RLOO.
                    self._balance_batch(batch, metrics=metrics)

                    # compute global_valid tokens
                    batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()

                    # compute reward
                    with timer("reward", timing_raw):
                        reward_ref = self.reward_fn.compute_reward.remote(batch)

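                    # the "old" log-probs for PPO are recomputed with the actor worker, since the
                    # rollout engine's numerics may differ slightly from the training-time actor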
                    # recompute old_log_probs
                    with timer("old", timing_raw):
                        old_log_probs = self.actor_rollout_wg.compute_log_probs(batch)
                        batch = batch.union(old_log_probs)

                    # compute ref_log_probs
                    if self.use_reference_policy:
                        with timer("ref", timing_raw):
                            ref_log_probs = self.ref_policy_wg.compute_ref_log_probs(batch)
                            batch = batch.union(ref_log_probs)

                    # compute values
                    if self.use_critic:
                        with timer("values", timing_raw):
                            values = self.critic_wg.compute_values(batch)
                            batch = batch.union(values)

                    with timer("adv", timing_raw):
                        # get token level scores
                        reward_tensor, reward_metrics = ray.get(reward_ref)
                        batch.batch["token_level_scores"] = reward_tensor
                        reward_metrics = {f"reward/{k}": v for k, v in reduce_metrics(reward_metrics).items()}
                        metrics.update(reward_metrics)

                        # apply kl penalty if available
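                        # if `use_kl_loss` is enabled, the KL term is presumably applied inside the
                        # actor loss instead, so the reward-level penalty is skipped here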
                        if not self.config.algorithm.use_kl_loss and self.use_reference_policy:
                            # apply kl penalty to reward
                            batch, kl_metrics = apply_kl_penalty(batch, self.kl_ctrl, self.config.algorithm.kl_penalty)
                            metrics.update(kl_metrics)
                        else:
                            batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]

                        # compute advantages, executed on the driver process
                        batch = compute_advantage(
                            batch,
                            adv_estimator=self.config.algorithm.adv_estimator,
                            gamma=self.config.algorithm.gamma,
                            lam=self.config.algorithm.lam,
                        )

                    # update critic
                    if self.use_critic:
                        with timer("update_critic", timing_raw):
                            critic_output = self.critic_wg.update_critic(batch)

                        critic_metrics = reduce_metrics(critic_output.non_tensor_batch)
                        metrics.update(critic_metrics)

                    # update actor
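                    # skip actor updates for the first `critic_warmup` steps so the critic can warm up first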
                    if self.config.trainer.critic_warmup <= self.global_step:
                        with timer("update_actor", timing_raw):
                            actor_output = self.actor_rollout_wg.update_actor(batch)

                        actor_metrics = reduce_metrics(actor_output.non_tensor_batch)
                        metrics.update(actor_metrics)

                    # validate
                    if (
                        self.val_reward_fn is not None
                        and self.config.trainer.val_freq > 0
                        and self.global_step % self.config.trainer.val_freq == 0
                    ):
                        with timer("validation", timing_raw):
                            val_metrics = self._validate()

                        metrics.update(val_metrics)

                    if self.config.trainer.save_freq > 0 and self.global_step % self.config.trainer.save_freq == 0:
                        with timer("save_checkpoint", timing_raw):
                            self._save_checkpoint()

                # collect metrics
                num_gpus = self.resource_pool_manager.get_num_gpus()
                metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
                metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
                metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, num_gpus=num_gpus))

                self.logger.log(data=metrics, step=self.global_step)

        # perform validation after training
        if self.val_reward_fn is not None:
            if (
                val_metrics is None
                or self.config.trainer.val_freq <= 0
                or self.global_step % self.config.trainer.val_freq != 0
            ):
                val_metrics = self._validate()
                self.logger.log(data=val_metrics, step=self.global_step)

            print(f"Final validation metrics: {convert_dict_to_str(val_metrics)}")

        if self.config.trainer.save_freq <= 0 or self.global_step % self.config.trainer.save_freq != 0:
            self._save_checkpoint()