# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2025 Meituan Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Entry point for one-step-off-policy PPO training.

Note that this main is intentionally not merged into ``ray_trainer``, because
``ray_trainer`` is also used by other entry points.
"""

import os
import socket

import hydra
import ray
from omegaconf import OmegaConf

from verl.trainer.constants_ppo import get_ppo_ray_runtime_env
from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler
from verl.trainer.ppo.reward import load_reward_manager

from .ray_trainer import OneStepOffRayTrainer


@hydra.main(config_path="config", config_name="one_step_off_ppo_trainer", version_base=None)
def main(config):
    """Hydra entry point: load the ``one_step_off_ppo_trainer`` config and start training."""
    run_ppo(config)


def run_ppo(config) -> None:
    """Run the PPO-like training process inside a remote ``TaskRunner`` actor.

    Initializes Ray if it is not already running, launches a ``TaskRunner``
    (under Nsight when profiling steps are configured), and blocks until its
    ``run`` method completes.

    Args:
        config: The Hydra/OmegaConf training configuration.
    """
    if not ray.is_initialized():
        # The runtime env sets environment variables controlling tokenizer parallelism,
        # NCCL debug level, vLLM logging level and runtime LoRA updating.
        # `num_cpus` is the number of CPU cores Ray may use, taken from the config.
        ray.init(
            runtime_env=get_ppo_ray_runtime_env(),
            num_cpus=config.ray_init.num_cpus,
        )

    # Launch the TaskRunner under Nsight (with the controller-side options) only
    # when profiling steps were requested; select the key once and reuse it.
    profile_steps = OmegaConf.select(config.trainer, "profile_steps")
    if profile_steps is not None and len(profile_steps) > 0:
        nsight_options = OmegaConf.to_container(config.trainer.controller_nsight_options)
        runner = TaskRunner.options(runtime_env={"nsight": nsight_options}).remote()
    else:
        runner = TaskRunner.remote()
    ray.get(runner.run.remote(config))

    # [Optional] dump a Ray timeline trace (for performance analysis) when a path is configured.
    timeline_json_file = config.ray_init.get("timeline_json_file", None)
    if timeline_json_file:
        ray.timeline(filename=timeline_json_file)


@ray.remote(num_cpus=1)  # please make sure main_task is not scheduled on head
class TaskRunner:
    """Ray actor that builds workers, reward functions, datasets and the trainer, then trains."""

    def _select_worker_classes(self, config):
        """Pick the worker implementations that match ``actor_rollout_ref.actor.strategy``.

        Args:
            config: The resolved training configuration.

        Returns:
            A tuple ``(actor_rollout_cls, ref_policy_cls, rollout_cls, critic_cls,
            ray_worker_group_cls)``. ``actor_rollout_cls`` is the async variant when
            ``actor_rollout_ref.rollout.mode == "async"``; ``ref_policy_cls`` is always
            the synchronous ``ActorRolloutRefWorker`` of the same backend.

        Raises:
            AssertionError: If the critic strategy differs from the actor strategy.
            NotImplementedError: If the actor strategy is neither "fsdp2" nor "megatron".
        """
        strategy = config.actor_rollout_ref.actor.strategy
        if strategy == "fsdp2":
            assert strategy == config.critic.strategy
            from verl.single_controller.ray import RayWorkerGroup

            from .fsdp_workers import (
                ActorRolloutRefWorker,
                AsyncActorRolloutRefWorker,
                CriticWorker,
                RolloutWorker,
            )

            ray_worker_group_cls = RayWorkerGroup
        elif strategy == "megatron":
            assert strategy == config.critic.strategy
            from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup

            from .megatron_workers import (
                ActorRolloutRefWorker,
                AsyncActorRolloutRefWorker,
                CriticWorker,
                RolloutWorker,
            )

            ray_worker_group_cls = NVMegatronRayWorkerGroup
        else:
            raise NotImplementedError(
                f"Unsupported actor strategy {strategy!r} for one-step-off PPO; expected 'fsdp2' or 'megatron'."
            )

        actor_rollout_cls = (
            AsyncActorRolloutRefWorker if config.actor_rollout_ref.rollout.mode == "async" else ActorRolloutRefWorker
        )
        return actor_rollout_cls, ActorRolloutRefWorker, RolloutWorker, CriticWorker, ray_worker_group_cls

    def _build_resource_pool_spec(self, config):
        """Validate GPU topology settings and build the training / rollout pool spec.

        Training (actor, critic) and rollout get disjoint pools, each a list with one
        GPU count per node.

        Args:
            config: The resolved training configuration.

        Returns:
            ``{"actor_pool": [...], "rollout_pool": [...]}``.

        Raises:
            AssertionError: If any GPU-per-node or node count is not positive.
        """
        assert config.trainer.n_gpus_per_node > 0, "config.trainer.n_gpus_per_node must be greater than 0"
        assert config.trainer.nnodes > 0, "config.trainer.nnodes must be greater than 0"
        assert config.rollout.n_gpus_per_node > 0, "config.rollout.n_gpus_per_node must be greater than 0"
        assert config.rollout.nnodes > 0, "config.rollout.nnodes must be greater than 0"

        return {
            "actor_pool": [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
            "rollout_pool": [config.rollout.n_gpus_per_node] * config.rollout.nnodes,
        }

    def run(self, config):
        """Build every training component from ``config`` and run ``OneStepOffRayTrainer.fit``.

        Args:
            config: The training configuration. It is printed with interpolations
                resolved, then resolved in place.

        Raises:
            AssertionError: On inconsistent strategies or non-positive GPU/node counts.
            NotImplementedError: On an unsupported actor or reward-model strategy.
        """
        from pprint import pprint

        from verl.utils.fs import copy_to_local

        # Print the initial configuration. `resolve=True` evaluates symbolic values.
        print(f"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}")
        pprint(OmegaConf.to_container(config, resolve=True))
        OmegaConf.resolve(config)

        # Download the checkpoint (e.g. from HDFS) to the local machine. `use_shm` selects
        # shared memory, which could lead to faster model loading if turned on.
        local_path = copy_to_local(
            config.actor_rollout_ref.model.path, use_shm=config.actor_rollout_ref.model.get("use_shm", False)
        )

        # Instantiate the tokenizer and processor (the processor is used for multimodal
        # LLMs and could be None).
        from verl.utils import hf_processor, hf_tokenizer

        trust_remote_code = config.data.get("trust_remote_code", False)
        tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
        processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True)

        # Worker classes depend on the actor strategy.
        actor_rollout_cls, ref_policy_cls, rollout_cls, critic_cls, ray_worker_group_cls = (
            self._select_worker_classes(config)
        )

        from .ray_trainer import ResourcePoolManager, Role

        role_worker_mapping = {
            Role.Actor: ray.remote(actor_rollout_cls),
            Role.Rollout: ray.remote(rollout_cls),
            Role.Critic: ray.remote(critic_cls),
        }

        # Rollout runs on its own pool; actor, critic and auxiliary roles share the training pool.
        global_pool_id = "actor_pool"
        resource_pool_spec = self._build_resource_pool_spec(config)
        mapping = {
            Role.Actor: "actor_pool",
            Role.Rollout: "rollout_pool",
            Role.Critic: "actor_pool",
        }
        print(f"resource_pool_spec: {resource_pool_spec}")

        # We should adopt a multi-source reward function here:
        # - for rule-based rm, we directly call a reward score
        # - for model-based rm, we call a model
        # - for code related prompt, we send to a sandbox if there are test cases
        # finally, we combine all the rewards together
        # The reward type depends on the tag of the data
        if config.reward_model.enable:
            rm_strategy = config.reward_model.strategy
            # The generic FSDP reward-model worker serves both FSDP variants.
            if rm_strategy in ("fsdp", "fsdp2"):
                from verl.workers.fsdp_workers import RewardModelWorker
            elif rm_strategy == "megatron":
                from verl.workers.megatron_workers import RewardModelWorker
            else:
                raise NotImplementedError(
                    f"Unsupported reward model strategy {rm_strategy!r}; expected 'fsdp', 'fsdp2' or 'megatron'."
                )
            role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)
            mapping[Role.RewardModel] = global_pool_id

        # Add a reference policy worker if KL loss or KL reward is used.
        if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:
            role_worker_mapping[Role.RefPolicy] = ray.remote(ref_policy_cls)
            mapping[Role.RefPolicy] = global_pool_id

        # Load the reward managers for training (num_examine=0) and validation (num_examine=1).
        reward_kwargs = config.reward_model.get("reward_kwargs", {})
        reward_fn = load_reward_manager(config, tokenizer, num_examine=0, **reward_kwargs)
        val_reward_fn = load_reward_manager(config, tokenizer, num_examine=1, **reward_kwargs)

        resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)

        from verl.utils.dataset.rl_dataset import collate_fn

        # Create training and validation datasets.
        train_dataset = create_rl_dataset(config.data.train_files, config.data, tokenizer, processor)
        val_dataset = create_rl_dataset(config.data.val_files, config.data, tokenizer, processor)
        train_sampler = create_rl_sampler(config.data, train_dataset)

        # Initialize the PPO trainer.
        trainer = OneStepOffRayTrainer(
            config=config,
            tokenizer=tokenizer,
            processor=processor,
            role_worker_mapping=role_worker_mapping,
            resource_pool_manager=resource_pool_manager,
            ray_worker_group_cls=ray_worker_group_cls,
            reward_fn=reward_fn,
            val_reward_fn=val_reward_fn,
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            collate_fn=collate_fn,
            train_sampler=train_sampler,
            device_name=config.trainer.device,
        )
        # Initialize the workers of the trainer, then start training.
        trainer.init_workers()
        trainer.fit()


if __name__ == "__main__":
    main()