# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023-2024 SGLang Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os

from omegaconf import OmegaConf, open_dict

from verl.single_controller.base.decorator import Dispatch, register
from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
from verl.utils.flops_counter import FlopsCounter
from verl.utils.fsdp_utils import offload_fsdp_model_to_cpu, offload_fsdp_optimizer
from verl.utils.import_utils import import_external_libs
from verl.utils.profiler import log_gpu_memory_usage
from verl.workers.fsdp_workers import ActorRolloutRefWorker

logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_PPO_LOGGING_LEVEL", "WARN"))


class SPPOActorRolloutRefWorker(ActorRolloutRefWorker):
    """
    This worker can be instantiated as a standalone actor, a standalone rollout,
    a standalone reference policy, or a hybrid engine, depending on config.rollout.
    """

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def init_model(self):
        from .dp_actor import DataParallelSPPOActor

        # This is used to import external_lib into the huggingface systems
        import_external_libs(self.config.model.get("external_lib", None))

        override_model_config = OmegaConf.to_container(OmegaConf.create(self.config.model.get("override_config", {})))

        use_remove_padding = self.config.model.get("use_remove_padding", False)
        use_fused_kernels = self.config.model.get("use_fused_kernels", False)

        if self._is_actor or self._is_rollout:
            # we need the model for actor and rollout
            if self._is_actor:
                optim_config = self.config.actor.optim
                fsdp_config = self.config.actor.fsdp_config
            else:
                optim_config = None
                fsdp_config = OmegaConf.create()

            self.actor_module_fsdp, self.actor_optimizer, self.actor_lr_scheduler, self.actor_model_config = (
                self._build_model_optimizer(
                    model_path=self.config.model.path,
                    fsdp_config=fsdp_config,
                    optim_config=optim_config,
                    override_model_config=override_model_config,
                    use_remove_padding=use_remove_padding,
                    use_fused_kernels=use_fused_kernels,
                    enable_gradient_checkpointing=self.config.model.get("enable_gradient_checkpointing", False),
                    trust_remote_code=self.config.model.get("trust_remote_code", False),
                    use_liger=self.config.model.get("use_liger", False),
                    role="actor",
                )
            )

            # get the original unwrapped module
            self.actor_module = self.actor_module_fsdp._fsdp_wrapped_module

            if self._is_offload_param:
                offload_fsdp_model_to_cpu(self.actor_module_fsdp)
                log_gpu_memory_usage("After offload actor model during init", logger=logger)

            if self._is_offload_optimizer:
                offload_fsdp_optimizer(optimizer=self.actor_optimizer)
                log_gpu_memory_usage("After offload actor optimizer during init", logger=logger)

        if self._is_actor:
            # build the SPPO actor around the FSDP-wrapped module and its optimizer
            OmegaConf.set_struct(self.config.actor, True)
            with open_dict(self.config.actor):
                self.config.actor.use_remove_padding = use_remove_padding
                self.config.actor.use_fused_kernels = use_fused_kernels
            self.actor = DataParallelSPPOActor(
                config=self.config.actor,
                actor_module=self.actor_module_fsdp,
                actor_optimizer=self.actor_optimizer,
            )

        if self._is_rollout:
            self.rollout, self.rollout_sharding_manager = self._build_rollout(
                trust_remote_code=self.config.model.get("trust_remote_code", False)
            )

        if self._is_ref:
            # the reference policy only needs the model, so the optimizer and lr scheduler are discarded
            self.ref_module_fsdp = self._build_model_optimizer(
                model_path=self.config.model.path,
                fsdp_config=self.config.ref.fsdp_config,
                optim_config=None,
                override_model_config=override_model_config,
                use_remove_padding=use_remove_padding,
                use_fused_kernels=use_fused_kernels,
                trust_remote_code=self.config.model.get("trust_remote_code", False),
                use_liger=self.config.model.get("use_liger", False),
                role="ref",
            )[0]
            OmegaConf.set_struct(self.config.ref, True)
            with open_dict(self.config.ref):
                self.config.ref.use_remove_padding = use_remove_padding
                self.config.ref.use_fused_kernels = use_fused_kernels
            self.ref_policy = DataParallelSPPOActor(config=self.config.ref, actor_module=self.ref_module_fsdp)

        if self._is_actor:
            self.flops_counter = FlopsCounter(self.actor_model_config)
            self.checkpoint_manager = FSDPCheckpointManager(
                model=self.actor_module_fsdp,
                optimizer=self.actor.actor_optimizer,
                lr_scheduler=self.actor_lr_scheduler,
                processing_class=self.processor if self.processor is not None else self.tokenizer,
                checkpoint_config=self.config.actor.checkpoint,
            )
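# Example wiring (a minimal sketch, kept as a comment so importing this module has
# no side effects): in verl's stock PPO recipe a worker class like this one is plugged
# into a role-worker mapping and driven by the Ray trainer. The import path and the
# Role members below follow that recipe and are assumptions for the SPPO recipe,
# not something defined in this file:
#
#   import ray
#   from verl.trainer.ppo.ray_trainer import Role
#
#   role_worker_mapping = {
#       # one hybrid worker can serve actor, rollout, and reference policy,
#       # with the active roles selected by the config it receives
#       Role.ActorRollout: ray.remote(SPPOActorRolloutRefWorker),
#       Role.RefPolicy: ray.remote(SPPOActorRolloutRefWorker),
#   }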