# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023-2024 SGLang Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import warnings

import psutil
import torch
import torch.distributed
from codetiming import Timer
from omegaconf import OmegaConf, open_dict
from torch.distributed.device_mesh import init_device_mesh

import verl.utils.torch_functional as verl_F
from verl import DataProto
from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import Dispatch, register
from verl.utils import hf_tokenizer
from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
from verl.utils.device import get_device_id, get_device_name, get_nccl_backend, get_torch_device
from verl.utils.flops_counter import FlopsCounter
from verl.utils.fs import copy_to_local
from verl.utils.fsdp_utils import (
    get_fsdp_wrap_policy,
    get_init_weight_context_manager,
    init_fn,
    load_fsdp_model_to_gpu,
    load_fsdp_optimizer,
    offload_fsdp_model_to_cpu,
    offload_fsdp_optimizer,
)
from verl.utils.import_utils import import_external_libs
from verl.utils.model import compute_position_id_with_mask
from verl.utils.profiler import log_gpu_memory_usage
from verl.workers.fsdp_workers import ActorRolloutRefWorker
from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager

logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_PPO_LOGGING_LEVEL", "WARN"))


def create_device_mesh(world_size, fsdp_size):
    if fsdp_size < 0 or fsdp_size >= world_size:
        device_mesh = init_device_mesh(get_device_name(), mesh_shape=(world_size,), mesh_dim_names=["fsdp"])
    else:
        device_mesh = init_device_mesh(
            get_device_name(), mesh_shape=(world_size // fsdp_size, fsdp_size), mesh_dim_names=["ddp", "fsdp"]
        )
    return device_mesh


def get_sharding_strategy(device_mesh):
    from torch.distributed.fsdp import ShardingStrategy

    if device_mesh.ndim == 1:
        sharding_strategy = ShardingStrategy.FULL_SHARD
    elif device_mesh.ndim == 2:
        sharding_strategy = ShardingStrategy.HYBRID_SHARD
    else:
        raise NotImplementedError(f"Get device mesh ndim={device_mesh.ndim}, but only support 1 or 2")
    return sharding_strategy


class SPINRolloutRefWorker(ActorRolloutRefWorker):
    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def init_model(self):
        from recipe.spin.dp_actor import SPINDataParallelPPOActor as DataParallelPPOActor

        # This is used to import external_lib into the huggingface systems
        import_external_libs(self.config.model.get("external_lib", None))

        override_model_config = OmegaConf.to_container(OmegaConf.create(self.config.model.get("override_config", {})))

        use_remove_padding = self.config.model.get("use_remove_padding", False)
        use_fused_kernels = self.config.model.get("use_fused_kernels", False)

        if self._is_actor or self._is_rollout or self._is_ref:
            # we need the model for actor and rollout
            if self._is_actor or self._is_ref:
                optim_config = self.config.actor.optim
                fsdp_config = self.config.actor.fsdp_config
            else:
                optim_config = None
                fsdp_config = OmegaConf.create()
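
            # _build_model_optimizer (inherited from ActorRolloutRefWorker) returns the FSDP-wrapped
            # module, its optimizer, the LR scheduler, and the HuggingFace model config.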
            self.actor_module_fsdp, self.actor_optimizer, self.actor_lr_scheduler, self.actor_model_config = (
                self._build_model_optimizer(
                    model_path=self.config.model.path,
                    fsdp_config=fsdp_config,
                    optim_config=optim_config,
                    override_model_config=override_model_config,
                    use_remove_padding=use_remove_padding,
                    use_fused_kernels=use_fused_kernels,
                    enable_gradient_checkpointing=self.config.model.get("enable_gradient_checkpointing", False),
                    trust_remote_code=self.config.model.get("trust_remote_code", False),
                    use_liger=self.config.model.get("use_liger", False),
                    role="actor",
                )
            )

            # get the original unwrapped module
            self.actor_module = self.actor_module_fsdp._fsdp_wrapped_module

            if self._is_offload_optimizer:
                offload_fsdp_optimizer(optimizer=self.actor_optimizer)
                log_gpu_memory_usage("After offload actor optimizer during init", logger=logger)

        # load from checkpoint
        if self._is_actor or self._is_ref:
            OmegaConf.set_struct(self.config.actor, True)
            with open_dict(self.config.actor):
                self.config.actor.use_remove_padding = use_remove_padding
                self.config.actor.use_fused_kernels = use_fused_kernels
            self.actor = DataParallelPPOActor(
                config=self.config.actor, actor_module=self.actor_module_fsdp, actor_optimizer=self.actor_optimizer
            )

        if self._is_rollout:
            self.rollout, self.rollout_sharding_manager = self._build_rollout(
                trust_remote_code=self.config.model.get("trust_remote_code", False)
            )

        if self._is_ref:
            # self.ref_module_fsdp = self._build_model_optimizer(
            #     model_path=self.config.model.path,
            #     fsdp_config=self.config.ref.fsdp_config,
            #     optim_config=None,
            #     override_model_config=override_model_config,
            #     use_remove_padding=use_remove_padding,
            #     use_fused_kernels=use_fused_kernels,
            #     trust_remote_code=self.config.model.get("trust_remote_code", False),
            #     use_liger=self.config.model.get("use_liger", False),
            #     role="ref",
            # )[0]
            OmegaConf.set_struct(self.config.ref, True)
            with open_dict(self.config.ref):
                self.config.ref.use_remove_padding = use_remove_padding
                self.config.ref.use_fused_kernels = use_fused_kernels
            self.ref_policy = DataParallelPPOActor(config=self.config.ref, actor_module=self.ref_module_fsdp)
            self.checkpoint_manager = FSDPCheckpointManager(
                model=self.actor_module_fsdp,
                optimizer=self.actor.actor_optimizer,
                lr_scheduler=self.actor_lr_scheduler,
                processing_class=self.processor if self.processor is not None else self.tokenizer,
                checkpoint_config=self.config.actor.checkpoint,
            )

        if self._is_actor:
            self.flops_counter = FlopsCounter(self.actor_model_config)
            self.checkpoint_manager = FSDPCheckpointManager(
                model=self.actor_module_fsdp,
                optimizer=self.actor.actor_optimizer,
                lr_scheduler=self.actor_lr_scheduler,
                processing_class=self.processor if self.processor is not None else self.tokenizer,
                checkpoint_config=self.config.actor.checkpoint,
            )

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def compute_ref_log_prob(self, data: DataProto):
        assert self._is_ref

        # Support all hardware
        data = data.to(get_device_id())

        micro_batch_size = self.config.ref.log_prob_micro_batch_size_per_gpu
        data.meta_info["micro_batch_size"] = micro_batch_size
        data.meta_info["temperature"] = self.config.rollout.temperature
        data.meta_info["max_token_len"] = self.config.ref.log_prob_max_token_len_per_gpu
        data.meta_info["use_dynamic_bsz"] = self.config.ref.log_prob_use_dynamic_bsz
        with self.ulysses_sharding_manager:
            data = self.ulysses_sharding_manager.preprocess_data(data)
            output = self.ref_policy.compute_log_prob(data=data)
            output = DataProto.from_dict(tensors={"ref_log_prob": output})
            output = self.ulysses_sharding_manager.postprocess_data(output)
        output = output.to("cpu")

        # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes
        # unshard the root FSDP module
        if self.world_size > 1:
            self.ref_policy.actor_module._handle.reshard(True)

        return output

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def compute_log_prob(self, data: DataProto):
        assert self._is_actor
        if self._is_offload_param:
            load_fsdp_model_to_gpu(self.actor_module_fsdp)

        # Support all hardware
        data = data.to(get_device_id())
        # we should always recompute old_log_probs when it is HybridEngine
        data.meta_info["micro_batch_size"] = self.config.rollout.log_prob_micro_batch_size_per_gpu
        data.meta_info["max_token_len"] = self.config.rollout.log_prob_max_token_len_per_gpu
        data.meta_info["use_dynamic_bsz"] = self.config.rollout.log_prob_use_dynamic_bsz
        data.meta_info["temperature"] = self.config.rollout.temperature

        # perform recompute log_prob
        with self.ulysses_sharding_manager:
            data = self.ulysses_sharding_manager.preprocess_data(data)
            output = self.actor.compute_log_prob(data=data)
            output = DataProto.from_dict(
                tensors={"old_log_probs": output}, meta_info={"temperature": self.config.rollout.temperature}
            )
            output = self.ulysses_sharding_manager.postprocess_data(output)

        output = output.to("cpu")

        # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes
        # unshard the root FSDP module
        if self.world_size > 1:
            self.actor.actor_module._handle.reshard(True)

        if self._is_offload_param:
            offload_fsdp_model_to_cpu(self.actor_module_fsdp)

        log_gpu_memory_usage("After compute_log_prob", logger=logger)
        return output

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def update_actor_dpo(self, data: DataProto):
        """
        Wrapper for the actor update step. Handles FSDP state management.
        Calls self.actor.update_policy_dpo_with_ref, which contains the DPO logic based on
        pre-calculated log probabilities.
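
        Expects ``global_token_num`` in ``data.meta_info`` for the MFU estimate and returns the
        training metrics in ``output.meta_info["metrics"]``.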
""" # Support all hardwares data = data.to(get_device_id()) assert self._is_actor # Make sure this worker has the actor role if self.actor is None: raise RuntimeError("Actor instance (self.actor) not initialized in worker.") # --- FSDP State Management --- if self._is_offload_param: load_fsdp_model_to_gpu(self.actor_module_fsdp) if self._is_offload_optimizer: load_fsdp_optimizer(optimizer=self.actor_optimizer, device_id=get_device_id()) log_gpu_memory_usage("Before update policy (DPO via PPO path)", logger=logger) # --- Ulysses Sharding (if used) --- with self.ulysses_sharding_manager: data = self.ulysses_sharding_manager.preprocess_data(data=data) # --- Call the core update method (now containing DPO logic) --- with Timer(name="update_policy_dpo_via_ppo", logger=None) as timer: # Use a distinct timer name # Calls the modified update_policy method metrics = self.actor.update_policy_dpo_with_ref(data=data) # <-- THIS CALLS THE MODIFIED FUNCTION delta_time = timer.last # --- Add Performance Metrics --- # MFU calculation might be less accurate/meaningful here for DPO metrics["perf/approx_tokens_processed"] = torch.sum( data.batch.get("attention_mask", torch.tensor(0)) ).item() # Approx tokens metrics["perf/max_memory_allocated_gb"] = get_torch_device().max_memory_allocated() / (1024**3) metrics["perf/max_memory_reserved_gb"] = get_torch_device().max_memory_reserved() / (1024**3) metrics["perf/cpu_memory_used_gb"] = psutil.virtual_memory().used / (1024**3) global_num_tokens = data.meta_info["global_token_num"] estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time) metrics["perf/mfu/actor"] = estimated_flops * self.config.ppo_epochs / promised_flops / self.world_size # --- LR Scheduler Step --- lr = self.actor_lr_scheduler.get_last_lr()[0] metrics["actor/lr"] = lr self.actor_lr_scheduler.step() log_gpu_memory_usage("After update policy (DPO via PPO path)", logger=logger) # --- Prepare Output --- output = DataProto(meta_info={"metrics": metrics}) output = self.ulysses_sharding_manager.postprocess_data(data=output) output = output.to("cpu") # --- FSDP State Management (Offload) --- if self._is_offload_param: offload_fsdp_model_to_cpu(self.actor_module_fsdp) if self._is_offload_optimizer: offload_fsdp_optimizer(optimizer=self.actor_optimizer) return output # TODO(sgm): we may need to extract it to dp_reward_model.py class RewardModelWorker(Worker): """ Note that we only implement the reward model that is subclass of AutoModelForTokenClassification. 
""" def __init__(self, config): super().__init__() import torch.distributed if not torch.distributed.is_initialized(): torch.distributed.init_process_group(backend=get_nccl_backend()) self.config = config # build device mesh for Ulysses Sequence Parallel world_size = torch.distributed.get_world_size() from torch.distributed.device_mesh import init_device_mesh fsdp_size = self.config.model.fsdp_config.fsdp_size self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size) self.ulysses_device_mesh = None self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1) dp = world_size // self.ulysses_sequence_parallel_size if self.ulysses_sequence_parallel_size > 1: self.ulysses_device_mesh = init_device_mesh( get_device_name(), mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"] ) self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) self.use_remove_padding = self.config.model.get("use_remove_padding", False) # normalize config if self.config.micro_batch_size is not None: self.config.micro_batch_size //= torch.distributed.get_world_size() self.config.micro_batch_size_per_gpu = self.config.micro_batch_size def _build_model(self, config): # the following line is necessary from torch.distributed.fsdp import CPUOffload from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from transformers import AutoConfig, AutoModelForTokenClassification # download the checkpoint from hdfs local_path = copy_to_local(config.model.path) if self.config.model.input_tokenizer is None: self._do_switch_chat_template = False else: self._do_switch_chat_template = True input_tokenizer_local_path = copy_to_local(config.model.input_tokenizer) self.input_tokenizer = hf_tokenizer( input_tokenizer_local_path, trust_remote_code=config.model.get("trust_remote_code", False) ) self.tokenizer = hf_tokenizer(local_path, trust_remote_code=config.model.get("trust_remote_code", False)) trust_remote_code = config.model.get("trust_remote_code", False) model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code) model_config.num_labels = 1 # note that we have to create model in fp32. 
        init_context = get_init_weight_context_manager(
            use_meta_tensor=not model_config.tie_word_embeddings, mesh=self.device_mesh
        )

        with init_context(), warnings.catch_warnings():
            warnings.simplefilter("ignore")
            model_config.classifier_dropout = 0.0
            reward_module = AutoModelForTokenClassification.from_pretrained(
                pretrained_model_name_or_path=local_path,
                config=model_config,
                torch_dtype=torch.bfloat16,
                attn_implementation="flash_attention_2",
                trust_remote_code=trust_remote_code,
            )

            if config.model.get("use_remove_padding", False) or self.ulysses_sequence_parallel_size > 1:
                from verl.models.transformers.monkey_patch import apply_monkey_patch

                apply_monkey_patch(model=reward_module, ulysses_sp_size=self.ulysses_sequence_parallel_size)

            reward_module.to(torch.bfloat16)

        auto_wrap_policy = get_fsdp_wrap_policy(module=reward_module, config=self.config.model.fsdp_config)

        fsdp_mesh = self.device_mesh
        sharding_strategy = get_sharding_strategy(fsdp_mesh)

        reward_module = FSDP(
            reward_module,
            param_init_fn=init_fn,
            use_orig_params=False,
            auto_wrap_policy=auto_wrap_policy,
            device_id=get_device_id(),
            sharding_strategy=sharding_strategy,  # zero3
            sync_module_states=True,
            cpu_offload=CPUOffload(offload_params=True),
            forward_prefetch=False,
            device_mesh=self.device_mesh,
        )

        return reward_module

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def init_model(self):
        # This is used to import external_lib into the huggingface systems
        import_external_libs(self.config.model.get("external_lib", None))
        self.reward_module = self._build_model(config=self.config)

    def _forward_micro_batch(self, micro_batch):
        from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input

        from verl.utils.ulysses import gather_outputs_and_unpad, ulysses_pad_and_slice_inputs

        with torch.no_grad(), torch.autocast(device_type=get_device_name(), dtype=torch.bfloat16):
            input_ids = micro_batch["input_ids"]
            batch_size, seqlen = input_ids.shape
            attention_mask = micro_batch["attention_mask"]
            position_ids = micro_batch["position_ids"]

            if self.use_remove_padding:
                input_ids_rmpad, indices, *_ = unpad_input(
                    input_ids.unsqueeze(-1), attention_mask
                )  # input_ids_rmpad (total_nnz, ...)
                input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)

                # unpad the position_ids to align the rotary
                position_ids_rmpad = index_first_axis(
                    rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices
                ).transpose(0, 1)
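
                # At this point input_ids_rmpad and position_ids_rmpad hold only the non-padded
                # tokens of the micro-batch, flattened to shape (1, total_nnz).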
-> (b s) ..."), indices ).transpose(0, 1) # pad and slice the inputs if sp > 1 if self.ulysses_sequence_parallel_size > 1: input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs( input_ids_rmpad, position_ids_rmpad, sp_size=self.ulysses_sequence_parallel_size ) # only pass input_ids and position_ids to enable flash_attn_varlen output = self.reward_module( input_ids=input_ids_rmpad, attention_mask=None, position_ids=position_ids_rmpad, use_cache=False ) # prevent model thinks we are generating reward_rmpad = output.logits reward_rmpad = reward_rmpad.squeeze(0) # (total_nnz) # gather output if sp > 1 if self.ulysses_sequence_parallel_size > 1: reward_rmpad = gather_outputs_and_unpad( reward_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size ) # pad it back rm_score = pad_input(reward_rmpad, indices=indices, batch=batch_size, seqlen=seqlen).squeeze(-1) else: output = self.reward_module( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False ) rm_score = output.logits # (batch_size, seq_len, 1) rm_score = rm_score.squeeze(-1) # extract the result of the last valid token eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1) # (bsz,) rm_score = rm_score[torch.arange(batch_size), eos_mask_idx] return rm_score def _expand_to_token_level(self, data: DataProto, scores: torch.Tensor): batch_size = data.batch.batch_size[0] # expand as token_level_reward attention_mask = data.batch["attention_mask"] position_ids = data.batch["position_ids"] response_length = data.batch["responses"].shape[-1] eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1) # (bsz,) token_level_scores = torch.zeros_like(attention_mask, dtype=scores.dtype) # (bsz, seqlen) token_level_scores[torch.arange(batch_size), eos_mask_idx] = scores # select the response part token_level_scores = token_level_scores[:, -response_length:] return token_level_scores def _switch_chat_template(self, data: DataProto): src_max_length = data.batch["attention_mask"].shape[-1] src_tokenizer = self.input_tokenizer target_tokenizer = self.tokenizer rm_input_ids = [] rm_attention_mask = [] for i in range(data.batch.batch_size[0]): # extract raw prompt if isinstance(data.non_tensor_batch["raw_prompt"][i], list): chat: list = data.non_tensor_batch["raw_prompt"][i] else: chat: list = data.non_tensor_batch["raw_prompt"][i].tolist() # extract response response_ids = data.batch["responses"][i] response_length = response_ids.shape[-1] valid_response_length = data.batch["attention_mask"][i][-response_length:].sum() valid_response_ids = response_ids[:valid_response_length] # decode response = src_tokenizer.decode(valid_response_ids) # remove bos and eos response = response.replace(src_tokenizer.eos_token, "") chat.append({"role": "assistant", "content": response}) prompt_with_chat_template = target_tokenizer.apply_chat_template( chat, add_generation_prompt=False, tokenize=False ) if self.rank == 0 and i == 0: # for debugging purpose print(f"Switch template. 
chat: {prompt_with_chat_template}") # the maximum length is actually determined by the reward model itself max_length = self.config.get("max_length", src_max_length) if max_length is None: max_length = src_max_length model_inputs = target_tokenizer(prompt_with_chat_template, return_tensors="pt", add_special_tokens=False) input_ids, attention_mask = verl_F.postprocess_data( input_ids=model_inputs["input_ids"], attention_mask=model_inputs["attention_mask"], max_length=max_length, pad_token_id=target_tokenizer.pad_token_id, left_pad=False, # right padding truncation=self.config.get("truncation", "right"), ) # truncate from the right rm_input_ids.append(input_ids) rm_attention_mask.append(attention_mask) rm_input_ids = torch.cat(rm_input_ids, dim=0) rm_attention_mask = torch.cat(rm_attention_mask, dim=0) rm_position_ids = compute_position_id_with_mask(rm_attention_mask) rm_inputs = {"input_ids": rm_input_ids, "attention_mask": rm_attention_mask, "position_ids": rm_position_ids} return DataProto.from_dict(rm_inputs) @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def compute_rm_score(self, data: DataProto): import itertools from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches # Support all hardwares data = data.to(get_device_id()) if self._do_switch_chat_template: rm_data = self._switch_chat_template(data) else: rm_input_ids = data.batch["input_ids"] rm_attention_mask = data.batch["attention_mask"] rm_position_ids = data.batch["position_ids"] rm_inputs = { "input_ids": rm_input_ids, "attention_mask": rm_attention_mask, "position_ids": rm_position_ids, } rm_data = DataProto.from_dict(rm_inputs) # Support all hardwares rm_data.batch = rm_data.batch.to(get_device_id()) # perform forward computation with self.ulysses_sharding_manager: rm_data = self.ulysses_sharding_manager.preprocess_data(data=rm_data) data = self.ulysses_sharding_manager.preprocess_data(data=data) use_dynamic_bsz = self.config.use_dynamic_bsz if use_dynamic_bsz: max_token_len = self.config.forward_max_token_len_per_gpu * self.ulysses_sequence_parallel_size micro_batches, indices = rearrange_micro_batches(batch=rm_data.batch, max_token_len=max_token_len) else: micro_batches = rm_data.batch.split(self.config.micro_batch_size_per_gpu) output = [] for micro_batch in micro_batches: rm_score = self._forward_micro_batch(micro_batch) output.append(rm_score) scores = torch.cat(output, dim=0) # (batch_size) if use_dynamic_bsz: indices = list(itertools.chain.from_iterable(indices)) assert len(indices) == scores.size(0), f"{len(indices)} vs. {scores.size()}" revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long) scores = scores[revert_indices] token_level_scores = self._expand_to_token_level(data, scores) # Note that this is only the scores, may not be the final rewards used to train RL output = DataProto.from_dict(tensors={"rm_scores": token_level_scores}) output = self.ulysses_sharding_manager.postprocess_data(data=output) # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes # unshard the root FSDP module self.reward_module._handle.reshard(True) output = output.to("cpu") return output