# fsdp_workers.py
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
FSDP worker used for PPO training; one worker can host the actor, critic, rollout and/or reference model.
"""

chenych's avatar
chenych committed
18
from typing import Literal, Optional, Union
chenych's avatar
chenych committed
19

chenych's avatar
chenych committed
20
21
import numpy as np
import psutil
chenych's avatar
chenych committed
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch
import torch.distributed as dist
from accelerate import init_empty_weights
from codetiming import Timer
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import CPUOffload, MixedPrecision, ShardingStrategy
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoModelForTokenClassification,
    AutoModelForVision2Seq,
    GenerationConfig,
    PreTrainedModel,
)
from transformers.modeling_utils import no_init_weights

chenych's avatar
chenych committed
39
40
41
42
43
44
45
from ..models.monkey_patch import apply_ulysses_patch
from ..protocol import DataProto
from ..single_controller.base import Worker
from ..single_controller.base.decorator import Dispatch, register
from ..utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
from ..utils.flops_counter import FlopsCounter
from ..utils.fsdp_utils import (
chenych's avatar
chenych committed
46
47
48
49
50
51
52
    get_fsdp_wrap_policy,
    get_init_fn,
    load_fsdp_model,
    load_fsdp_optimizer,
    offload_fsdp_model,
    offload_fsdp_optimizer,
)
chenych's avatar
chenych committed
53
54
55
56
57
58
59
from ..utils.model_utils import print_gpu_memory_usage, print_model_size
from ..utils.tokenizer import get_processor, get_tokenizer
from ..utils.torch_dtypes import PrecisionType
from ..utils.torch_functional import AnyPrecisionAdamW, get_constant_schedule_with_warmup
from .actor import DataParallelPPOActor
from .config import ActorConfig, CriticConfig, FSDPConfig, ModelConfig, OptimConfig, RefConfig, WorkerConfig
from .critic import DataParallelPPOCritic
chenych's avatar
Update  
chenych committed
60
from .rollout import vLLMRollout
chenych's avatar
chenych committed
61
62
from .sharding_manager import FSDPVLLMShardingManager
from .sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager
chenych's avatar
chenych committed
63
64
65
66
67
68
69
70
71
72


class FSDPWorker(Worker):
    def __init__(
        self,
        config: WorkerConfig,
        role: Literal["actor", "critic", "rollout", "ref", "actor_rollout", "actor_rollout_ref"],
    ):
        """Initialise the process group, role flags and offload settings.

        Args:
            config: Worker configuration; the ``actor``, ``critic`` or ``ref``
                sub-config is consulted depending on ``role``.
            role: Which component(s) this worker hosts.
        """
        super().__init__()
        self.config = config
        self.role = role

        if not dist.is_initialized():
            dist.init_process_group(backend="nccl")

        # improve numerical stability
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False

        self._is_actor = self.role in ("actor", "actor_rollout", "actor_rollout_ref")
        self._is_critic = self.role == "critic"
        self._is_rollout = self.role in ("rollout", "actor_rollout", "actor_rollout_ref")
        self._is_ref = self.role in ("ref", "actor_rollout_ref")

        # Defaults; overridden below by the sub-config of the first matching role.
        self._use_param_offload = False
        self._use_optimizer_offload = False

        # Precedence is actor > critic > ref; a rollout-only worker configures nothing.
        if self._is_actor:
            role_name = "actor"
        elif self._is_critic:
            role_name = "critic"
        elif self._is_ref:
            role_name = "ref"
        else:
            role_name = None

        if role_name == "actor":
            role_config = self.config.actor
        elif role_name == "critic":
            role_config = self.config.critic
        elif role_name == "ref":
            role_config = self.config.ref

        if role_name is not None:
            self._use_param_offload = role_config.offload.offload_params
            # NOTE: it seems that manual offload is slower than FSDP offload
            # (ref has no optimizer, so only parameter offload applies to it).
            if role_name != "ref":
                self._use_optimizer_offload = role_config.offload.offload_optimizer
            self._init_config(role_config, role_name)

    def _init_config(
        self, config: Union[ActorConfig, CriticConfig, RefConfig], role: Literal["actor", "critic", "ref"]
    ):
        """Build the device meshes for ``role`` and derive its per-device batch size.

        Sets ``self.device_mesh``, ``self.ulysses_device_mesh`` and
        ``self.ulysses_sharding_manager``. For configs that carry a
        ``global_batch_size`` attribute, the global batch size is scaled in place
        by ``rollout.n`` and ``config.global_batch_size_per_device`` is computed
        and validated. Configs without that attribute (the ref model) return
        right after the meshes are built.

        Args:
            config: Role-specific config providing ``fsdp.fsdp_size`` and
                ``ulysses_sequence_parallel_size``.
            role: Role name; used only in log and error messages.

        Raises:
            ValueError: If the per-device batch size is zero, is not divisible by
                the update micro batch size, or FSDP CPU offload is requested
                together with gradient accumulation.
        """
        world_size = dist.get_world_size()
        fsdp_size = config.fsdp.fsdp_size
        if fsdp_size <= 0 or fsdp_size >= world_size:
            # Shard across every rank: a single 1-D mesh.
            self.device_mesh = init_device_mesh("cuda", mesh_shape=(world_size,), mesh_dim_names=("fsdp",))
        else:  # hsdp
            self.device_mesh = init_device_mesh(
                "cuda", mesh_shape=(world_size // fsdp_size, fsdp_size), mesh_dim_names=("ddp", "fsdp")
            )

        if config.ulysses_sequence_parallel_size > 1:
            self.ulysses_device_mesh = init_device_mesh(
                "cuda",
                mesh_shape=(
                    world_size // config.ulysses_sequence_parallel_size,
                    config.ulysses_sequence_parallel_size,
                ),
                mesh_dim_names=("dp", "sp"),
            )
        else:
            self.ulysses_device_mesh = None

        self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)

        if not hasattr(config, "global_batch_size"):  # ref model
            return

        if self.config.rollout.n > 1:
            # Mutates the config in place: each prompt yields ``n`` responses.
            config.global_batch_size *= self.config.rollout.n
            self.print_rank0(f"{role} will use global batch size {config.global_batch_size}.")

        # Multiply by the sequence-parallel size BEFORE floor-dividing by the mesh
        # size. Dividing first truncates early, which under-counts the per-device
        # batch (or yields 0) even when global_batch_size * ulysses size is at
        # least the number of GPUs, contradicting the error message below.
        config.global_batch_size_per_device = (
            config.global_batch_size * config.ulysses_sequence_parallel_size
        ) // self.device_mesh.size()
        if config.global_batch_size_per_device == 0:
            raise ValueError(f"{role} global batch size * ulysses size must be larger than num gpus.")

        if config.global_batch_size_per_device % config.micro_batch_size_per_device_for_update != 0:
            raise ValueError(f"{role} global batch size per device must be divisible by the micro batch size.")

        # Per-device batch != micro batch means several micro steps per update,
        # i.e. gradient accumulation.
        if (
            config.fsdp.enable_cpu_offload
            and config.global_batch_size_per_device != config.micro_batch_size_per_device_for_update
        ):
            raise ValueError(f"{role} cannot use FSDP's CPU offload when gradient accumulation is enabled.")
chenych's avatar
chenych committed
148
149
150
151
152

    def _build_model_optimizer(
        self,
        model_config: ModelConfig,
        fsdp_config: FSDPConfig,
        optim_config: Optional[OptimConfig],
        padding_free: bool = False,
    ) -> None:
        """Load the HuggingFace model, wrap it in FSDP and create the optimizer.

        Populates ``self.tokenizer``, ``self.processor``, ``self.model_config``,
        ``self.generation_config``, ``self.fsdp_module``, ``self.optimizer`` and
        ``self.lr_scheduler``. The optimizer and scheduler are ``None`` unless
        this worker is an actor or critic.

        Args:
            model_config: Paths and model-level switches (tokenizer/model path,
                config overrides, gradient checkpointing, vision-tower freezing).
            fsdp_config: FSDP settings (dtypes, sharding, CPU offload, rank-0 init).
                ``use_orig_params`` is set to ``True`` on this object when a
                vision tower is frozen.
            optim_config: Optimizer settings; read only for actor/critic roles.
            padding_free: When true, the Ulysses patch is applied for this model type.

        Raises:
            NotImplementedError: If ``optim_config.strategy`` is neither
                ``"adamw"`` nor ``"adamw_bf16"``.
        """
        self.tokenizer = get_tokenizer(
            model_config.tokenizer_path,
            trust_remote_code=model_config.trust_remote_code,
            use_fast=True,
        )
        self.processor = get_processor(
            model_config.tokenizer_path,
            trust_remote_code=model_config.trust_remote_code,
            use_fast=True,
        )
        # Special token ids come from the tokenizer; override_config is applied on top.
        self.model_config = AutoConfig.from_pretrained(
            model_config.model_path,
            trust_remote_code=model_config.trust_remote_code,
            bos_token_id=self.tokenizer.bos_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
            **model_config.override_config,
        )

        # Best effort: fall back to a generation config derived from the model
        # config when none can be loaded from the model path.
        try:
            self.generation_config = GenerationConfig.from_pretrained(model_config.model_path)
        except Exception:
            self.generation_config = GenerationConfig.from_model_config(self.model_config)

        self.print_rank0(f"Model config: {self.model_config}")

        if padding_free:
            apply_ulysses_patch(self.model_config.model_type)
            self.print_rank0("Ulysses patch applied!")

        # Default dtype: fp32 for trainable roles (actor/critic), bf16 otherwise.
        if fsdp_config.torch_dtype is None:
            torch_dtype = torch.float32 if self._is_actor or self._is_critic else torch.bfloat16
        else:
            torch_dtype = PrecisionType.to_dtype(fsdp_config.torch_dtype)

        # Critic uses a token-classification head; otherwise pick the vision2seq
        # class when the config type is registered for it, else a causal LM.
        if self._is_critic:
            auto_class = AutoModelForTokenClassification
        elif type(self.model_config) in AutoModelForVision2Seq._model_mapping.keys():
            auto_class = AutoModelForVision2Seq
        else:
            auto_class = AutoModelForCausalLM

        # With rank-0 init, only local rank 0 along the "fsdp" mesh dim loads real
        # weights (on CPU); the other ranks build an empty-weight model instead.
        if (not fsdp_config.enable_rank0_init) or self.device_mesh.get_local_rank("fsdp") == 0:
            model = auto_class.from_pretrained(
                model_config.model_path,
                config=self.model_config,
                torch_dtype=torch_dtype,
                attn_implementation="flash_attention_2",
                device_map="cpu" if fsdp_config.enable_rank0_init else "cuda",
                low_cpu_mem_usage=True,
                trust_remote_code=model_config.trust_remote_code,
            )
        else:
            with no_init_weights(), init_empty_weights():
                model = auto_class.from_config(
                    self.model_config,
                    torch_dtype=torch_dtype,
                    attn_implementation="flash_attention_2",
                    trust_remote_code=model_config.trust_remote_code,
                )

        assert isinstance(model, PreTrainedModel)  # lint
        model.tie_weights()  # avoid hanging
        model = model.to(torch_dtype)
        if model_config.enable_gradient_checkpointing:
            model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

        # Roles that are neither actor nor critic are never trained: freeze everything.
        if not (self._is_actor or self._is_critic):
            model.requires_grad_(False)

        if model_config.freeze_vision_tower:
            if hasattr(model, "visual"):
                model.visual.requires_grad_(False)
                # Side effect on the caller's config object: FSDP is asked to keep
                # original parameters once part of the model is frozen.
                fsdp_config.use_orig_params = True
                self.print_rank0("Vision tower is set to not trainable.")
            else:
                self.print_rank0("No vision tower found.")

        # Collective: every rank must reach this point before FSDP wrapping starts.
        dist.barrier()
        print_model_size(model)
        print_gpu_memory_usage("After huggingface model init")
        mixed_precision = MixedPrecision(
            param_dtype=PrecisionType.to_dtype(fsdp_config.mp_param_dtype),
            reduce_dtype=PrecisionType.to_dtype(fsdp_config.mp_reduce_dtype),
            buffer_dtype=PrecisionType.to_dtype(fsdp_config.mp_buffer_dtype),
        )
        auto_wrap_policy = get_fsdp_wrap_policy(model)
        self.print_rank0(f"FSDP wrap policy: {auto_wrap_policy}.")

        # A 2-D mesh (built as ("ddp", "fsdp") in _init_config) selects the hybrid
        # strategies; a 1-D mesh selects the plain ones.
        if self.device_mesh.ndim == 2:
            if fsdp_config.enable_full_shard:
                sharding_strategy = ShardingStrategy.HYBRID_SHARD
            else:
                sharding_strategy = ShardingStrategy._HYBRID_SHARD_ZERO2
        else:
            if fsdp_config.enable_full_shard:
                sharding_strategy = ShardingStrategy.FULL_SHARD
            else:
                sharding_strategy = ShardingStrategy.SHARD_GRAD_OP

        if fsdp_config.enable_cpu_offload:
            cpu_offload = CPUOffload(offload_params=True)
        else:
            cpu_offload = None

        if fsdp_config.enable_rank0_init:
            sync_module_states = True
            # NOTE(review): weight loading above keys on the local rank along the
            # "fsdp" mesh dim, while this keys on the global ``self.rank``. Confirm
            # the two agree when a 2-D (hybrid) mesh is in use.
            param_init_fn = get_init_fn(model, device="cuda") if self.rank != 0 else None
        else:
            sync_module_states = False
            param_init_fn = None

        self.fsdp_module = FSDP(
            model,
            sharding_strategy=sharding_strategy,
            cpu_offload=cpu_offload,
            auto_wrap_policy=auto_wrap_policy,
            mixed_precision=mixed_precision,
            param_init_fn=param_init_fn,
            device_id=torch.cuda.current_device(),
            sync_module_states=sync_module_states,
            forward_prefetch=False,
            use_orig_params=fsdp_config.use_orig_params,
            device_mesh=self.device_mesh,
        )
        print_gpu_memory_usage("After FSDP module init")

        # Only trainable roles get an optimizer and LR scheduler.
        if self._is_actor or self._is_critic:
            if optim_config.strategy == "adamw":
                self.optimizer = torch.optim.AdamW(
                    self.fsdp_module.parameters(),
                    lr=optim_config.lr,
                    betas=optim_config.betas,
                    weight_decay=optim_config.weight_decay,
                    fused=True,
                )
            elif optim_config.strategy == "adamw_bf16":
                self.optimizer = AnyPrecisionAdamW(
                    self.fsdp_module.parameters(),
                    lr=optim_config.lr,
                    betas=optim_config.betas,
                    weight_decay=optim_config.weight_decay,
                )
            else:
                raise NotImplementedError(f"Optimizer {optim_config.strategy} not supported.")

            # Warmup length is a fraction of the total number of training steps.
            num_warmup_steps = int(optim_config.lr_warmup_ratio * optim_config.training_steps)
            self.lr_scheduler = get_constant_schedule_with_warmup(
                optimizer=self.optimizer, num_warmup_steps=num_warmup_steps
            )
            print_gpu_memory_usage("After optimizer init")
        else:
            self.optimizer, self.lr_scheduler = None, None

    def _build_rollout(self) -> None:
        """Create the vLLM rollout engine and the FSDP<->vLLM sharding manager.

        Sets ``self.rollout`` and ``self.rollout_sharding_manager``. Reads
        ``self.tokenizer`` and ``self.fsdp_module``, so the model must already
        have been built.
        """
        tp_size = self.config.rollout.tensor_parallel_size
        assert self.world_size % tp_size == 0, (
            f"rollout world size: {self.world_size} is not divisible by tp size: {tp_size}"
        )

        # 2-D mesh: data-parallel replicas x tensor-parallel ranks.
        rollout_device_mesh = init_device_mesh(
            "cuda",
            mesh_shape=(self.world_size // tp_size, tp_size),
            mesh_dim_names=("dp", "tp"),
        )

        self.rollout = vLLMRollout(
            model_path=self.config.actor.model.model_path,
            config=self.config.rollout,
            tokenizer=self.tokenizer,
        )
        self.rollout_sharding_manager = FSDPVLLMShardingManager(
            module=self.fsdp_module,
            inference_engine=self.rollout.inference_engine,
            device_mesh=rollout_device_mesh,
        )
        print_gpu_memory_usage("After vllm init")
chenych's avatar
chenych committed
328
329
330
331
332
333
334
335

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def init_model(self):
        """Build the model(s) and helper objects for this worker's role.

        Selects the model/FSDP/optimizer configs for the role (critic first,
        then actor, then ref), builds the FSDP module and optimizer, applies
        the initial offload, and then creates the actor, critic, rollout and
        reference wrappers that the role requires. Actor and critic workers
        also get a FLOPs counter and a checkpoint manager.

        Raises:
            ValueError: If the role is none of actor, critic or ref (for
                example a rollout-only worker, which has no model to build).
        """
        if self._is_critic:
            model_config = self.config.critic.model
            fsdp_config = self.config.critic.fsdp
            optim_config = self.config.critic.optim
            padding_free = self.config.critic.padding_free
            role = "critic"
        elif self._is_actor:
            model_config = self.config.actor.model
            fsdp_config = self.config.actor.fsdp
            optim_config = self.config.actor.optim
            padding_free = self.config.actor.padding_free
            role = "actor"
        elif self._is_ref:
            # The reference model is loaded from the actor's model config.
            model_config = self.config.actor.model
            fsdp_config = self.config.ref.fsdp
            optim_config = None
            padding_free = self.config.ref.padding_free
            role = "ref"
        else:
            # The local ``role`` is unbound on this path; formatting it would raise
            # UnboundLocalError instead of the intended ValueError.
            raise ValueError(f"Unknown role {self.role}.")

        # One of actor / critic / ref is guaranteed here (the else branch raised).
        self._build_model_optimizer(
            model_config=model_config,
            fsdp_config=fsdp_config,
            optim_config=optim_config,
            padding_free=padding_free,
        )
        if self._use_param_offload:
            offload_fsdp_model(self.fsdp_module)
            print_gpu_memory_usage(f"After offload {role} model during init")

        if self._use_optimizer_offload:
            offload_fsdp_optimizer(optimizer=self.optimizer)
            print_gpu_memory_usage(f"After offload {role} optimizer during init")

        if self._is_actor:
            self.actor = DataParallelPPOActor(
                config=self.config.actor,
                actor_module=self.fsdp_module,
                actor_optimizer=self.optimizer,
            )

        if self._is_critic:
            self.critic = DataParallelPPOCritic(
                config=self.config,
                critic_module=self.fsdp_module,
                critic_optimizer=self.optimizer,
            )

        if self._is_rollout:
            self._build_rollout()

        if self._is_ref:
            # No optimizer is passed: the reference policy is only used for inference.
            self.ref_policy = DataParallelPPOActor(
                config=self.config.ref,
                actor_module=self.fsdp_module,
            )

        if self._is_actor or self._is_critic:
            self.flops_counter = FlopsCounter(self.model_config)
            self.checkpoint_manager = FSDPCheckpointManager(
                model=self.fsdp_module,
                optimizer=self.optimizer,
                lr_scheduler=self.lr_scheduler,
                # Prefer the processor when one exists, else fall back to the tokenizer.
                processing_class=self.processor if self.processor is not None else self.tokenizer,
            )

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def save_checkpoint(self, path: str):
        """Save a checkpoint to ``path`` (actor or critic workers only).

        Offloaded parameters are loaded before saving and offloaded again
        afterwards; all ranks synchronise on a barrier once the save returns.
        """
        assert self._is_actor or self._is_critic

        params_offloaded = self._use_param_offload
        if params_offloaded:
            load_fsdp_model(self.fsdp_module)

        self.checkpoint_manager.save_checkpoint(path)
        dist.barrier()

        if params_offloaded:
            offload_fsdp_model(self.fsdp_module)

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def load_checkpoint(self, path: str):
        """Restore a checkpoint from ``path`` (actor or critic workers only).

        Offloaded parameters are loaded before restoring and offloaded again
        afterwards; all ranks synchronise on a barrier once the load returns.
        When optimizer offload is enabled, the optimizer state is offloaded
        after the restore.
        """
        # Same guard as save_checkpoint: the checkpoint manager is only created
        # for actor/critic roles, so fail early with a clear assertion.
        assert self._is_actor or self._is_critic
        if self._use_param_offload:
            load_fsdp_model(self.fsdp_module)

        self.checkpoint_manager.load_checkpoint(path)
        dist.barrier()
        if self._use_param_offload:
            offload_fsdp_model(self.fsdp_module)

        if self._use_optimizer_offload:  # avoid OOM in resuming
            offload_fsdp_optimizer(optimizer=self.optimizer)
chenych's avatar
chenych committed
422
423
424
425

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def update_actor(self, data: DataProto):
        """Run one policy update on ``data`` and return the resulting metrics.

        Args:
            data: Training batch; ``data.meta_info["global_token_num"]`` is read
                for the MFU estimate.

        Returns:
            A ``DataProto`` on CPU whose ``non_tensor_batch`` holds the metrics
            (scalars wrapped into one-element arrays).
        """
        assert self._is_actor
        data = data.to(torch.cuda.current_device())

        if self._use_param_offload:
            load_fsdp_model(self.fsdp_module)

        if self._use_optimizer_offload:
            load_fsdp_optimizer(optimizer=self.optimizer)

        with self.ulysses_sharding_manager:
            data = self.ulysses_sharding_manager.preprocess_data(data=data)
            with Timer(name="update_policy", logger=None) as timer:
                metrics = self.actor.update_policy(data=data)

            delta_time = timer.last
            global_num_tokens = data.meta_info["global_token_num"]
            estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time)
            metrics["perf/mfu_actor"] = (
                estimated_flops * self.config.actor.ppo_epochs / (promised_flops * self.world_size)
            )

            # The rollout sharding manager only exists on workers that also host
            # the rollout; a plain "actor" worker has none, so subtract nothing.
            rollout_manager = getattr(self, "rollout_sharding_manager", None)
            freed_bytes = rollout_manager.freed_bytes if rollout_manager is not None else 0
            metrics["perf/max_memory_allocated_gb"] = (torch.cuda.max_memory_allocated() - freed_bytes) / (
                1024**3
            )
            metrics["perf/max_memory_reserved_gb"] = (torch.cuda.max_memory_reserved() - freed_bytes) / (
                1024**3
            )
            metrics["perf/cpu_memory_used_gb"] = psutil.virtual_memory().used / (1024**3)

            self.lr_scheduler.step()
            lr = self.lr_scheduler.get_last_lr()[0]
            metrics["actor/lr"] = lr

            # Metrics should be in non_tensor_batch instead of meta_info, as DataProto not concat meta_info.
            output = DataProto(
                non_tensor_batch={
                    key: np.array([value] if np.isscalar(value) else value) for key, value in metrics.items()
                }
            )

        if self._use_param_offload:
            offload_fsdp_model(self.fsdp_module)

        if self._use_optimizer_offload:
            offload_fsdp_optimizer(optimizer=self.optimizer)

        output = output.to("cpu")
        return output

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def generate_sequences(self, prompts: DataProto):
        """Generate responses for ``prompts`` with the rollout engine.

        The EOS/PAD token ids are written into ``prompts.meta_info`` before
        generation. Returns the post-processed rollout output moved to CPU.
        """
        assert self._is_rollout

        if self._use_param_offload:
            load_fsdp_model(self.fsdp_module)

        # Prefer ids from the generation config; fall back to the tokenizer's.
        if self.generation_config is not None:
            eos_token_id = self.generation_config.eos_token_id
            pad_token_id = self.generation_config.pad_token_id
        else:
            eos_token_id = self.tokenizer.eos_token_id
            pad_token_id = self.tokenizer.pad_token_id

        prompts.meta_info.update({"eos_token_id": eos_token_id, "pad_token_id": pad_token_id})

        with self.rollout_sharding_manager:
            # after parameters sync with rollout, offload actor model to CPU
            if self._use_param_offload:
                offload_fsdp_model(self.fsdp_module)

            if self._use_optimizer_offload:
                offload_fsdp_optimizer(optimizer=self.optimizer)

            sharded_prompts = self.rollout_sharding_manager.preprocess_data(prompts)
            generated = self.rollout.generate_sequences(prompts=sharded_prompts)
            output = self.rollout_sharding_manager.postprocess_data(generated)

        return output.to("cpu")

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def compute_log_probs(self, data: DataProto):
        """Recompute log-probabilities of ``data`` under the current actor.

        Returns a CPU ``DataProto`` with tensor ``old_log_probs`` and the
        rollout temperature recorded in ``meta_info``.
        """
        assert self._is_actor
        data = data.to(torch.cuda.current_device())
        if self._use_param_offload:
            load_fsdp_model(self.fsdp_module)

        # we should always recompute old_log_probs when it is HybridEngine
        temperature = self.config.rollout.temperature
        data.meta_info["temperature"] = temperature

        # perform recompute log_prob
        with self.ulysses_sharding_manager:
            data = self.ulysses_sharding_manager.preprocess_data(data)
            old_log_probs = self.actor.compute_log_prob(data=data)
            output = self.ulysses_sharding_manager.postprocess_data(
                DataProto.from_dict(tensors={"old_log_probs": old_log_probs}, meta_info={"temperature": temperature})
            )

        # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes
        # Explicitly reshard the root FSDP module after the forward-only pass.
        if self.world_size > 1:
            self.fsdp_module._handle.reshard(True)

        if self._use_param_offload:
            offload_fsdp_model(self.fsdp_module)

        return output.to("cpu")

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def compute_ref_log_probs(self, data: DataProto):
        """Compute log-probabilities of ``data`` under the reference policy.

        Returns a CPU ``DataProto`` with tensor ``ref_log_probs``.
        """
        assert self._is_ref
        data = data.to(torch.cuda.current_device())
        if self._use_param_offload:
            load_fsdp_model(self.fsdp_module)

        data.meta_info["temperature"] = self.config.rollout.temperature
        with self.ulysses_sharding_manager:
            data = self.ulysses_sharding_manager.preprocess_data(data)
            ref_log_probs = self.ref_policy.compute_log_prob(data=data)
            output = self.ulysses_sharding_manager.postprocess_data(
                DataProto.from_dict(tensors={"ref_log_probs": ref_log_probs})
            )

        # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes
        # Explicitly reshard the root FSDP module after the forward-only pass.
        if self.world_size > 1:
            self.fsdp_module._handle.reshard(True)

        if self._use_param_offload:
            offload_fsdp_model(self.fsdp_module)

        return output.to("cpu")

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def compute_values(self, data: DataProto):
        """Compute critic values for ``data``.

        Returns a CPU ``DataProto`` with tensor ``values``.
        """
        assert self._is_critic
        data = data.to(torch.cuda.current_device())
        if self._use_param_offload:
            load_fsdp_model(self.fsdp_module)

        with self.ulysses_sharding_manager:
            data = self.ulysses_sharding_manager.preprocess_data(data=data)
            values = self.critic.compute_values(data=data)
            output = self.ulysses_sharding_manager.postprocess_data(
                data=DataProto.from_dict(tensors={"values": values})
            )

        if self._use_param_offload:
            offload_fsdp_model(self.fsdp_module)

        return output.to("cpu")

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def update_critic(self, data: DataProto):
        """Run one critic update on ``data`` and return the resulting metrics.

        Args:
            data: Training batch; ``data.meta_info["global_token_num"]`` is read
                for the MFU estimate.

        Returns:
            A ``DataProto`` on CPU whose ``non_tensor_batch`` holds the metrics
            (scalars wrapped into one-element arrays).
        """
        # Role guard, matching update_actor / compute_values: ``self.critic`` only
        # exists on critic workers.
        assert self._is_critic
        data = data.to(torch.cuda.current_device())
        if self._use_param_offload:
            load_fsdp_model(self.fsdp_module)

        if self._use_optimizer_offload:
            load_fsdp_optimizer(optimizer=self.optimizer)

        with self.ulysses_sharding_manager:
            data = self.ulysses_sharding_manager.preprocess_data(data=data)
            with Timer(name="update_critic", logger=None) as timer:
                metrics = self.critic.update_critic(data=data)

            delta_time = timer.last
            global_num_tokens = data.meta_info["global_token_num"]
            estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time)
            # NOTE(review): this scales by the ACTOR's ppo_epochs; confirm whether
            # the critic's own epoch count should be used here instead.
            metrics["perf/mfu_critic"] = (
                estimated_flops * self.config.actor.ppo_epochs / (promised_flops * self.world_size)
            )

            self.lr_scheduler.step()
            lr = self.lr_scheduler.get_last_lr()[0]
            metrics["critic/lr"] = lr

            # Metrics should be in non_tensor_batch instead of meta_info, as DataProto not concat meta_info.
            output = DataProto(
                non_tensor_batch={
                    metric: np.array([value] if np.isscalar(value) else value) for metric, value in metrics.items()
                }
            )

        if self._use_param_offload:
            offload_fsdp_model(self.fsdp_module)

        if self._use_optimizer_offload:
            offload_fsdp_optimizer(optimizer=self.optimizer)

        output = output.to("cpu")
        return output