# Copyright 2024 Bytedance Ltd. and/or its affiliates # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ The concrete Engine implementation using PyTorch FullyShardedDataParallel (FSDP) """ import gc import itertools import logging import os import warnings from typing import Callable import torch import torch.distributed from omegaconf import OmegaConf from peft import LoraConfig, TaskType, get_peft_model from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from verl import DataProto from verl.models.transformers.monkey_patch import apply_monkey_patch from verl.utils import hf_processor, hf_tokenizer from verl.utils.activation_offload import enable_activation_offloading from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager from verl.utils.debug import log_gpu_memory_usage from verl.utils.device import ( get_device_id, get_device_name, get_torch_device, is_cuda_available, is_npu_available, ) from verl.utils.flops_counter import FlopsCounter from verl.utils.fs import copy_to_local from verl.utils.fsdp_utils import ( CPUOffloadPolicy, FSDPModule, MixedPrecisionPolicy, apply_fsdp2, fsdp2_clip_grad_norm_, fsdp2_load_full_state_dict, get_fsdp_wrap_policy, get_init_weight_context_manager, init_fn, load_fsdp_model_to_gpu, load_fsdp_optimizer, offload_fsdp_model_to_cpu, offload_fsdp_optimizer, ) from verl.utils.import_utils import import_external_libs from verl.utils.py_functional import append_to_dict, convert_to_regular_types from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches from verl.utils.ulysses import gather_outpus_and_unpad, ulysses_pad_and_slice_inputs from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager if is_cuda_available: from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input elif is_npu_available: from transformers.integrations.npu_flash_attention import index_first_axis, pad_input, rearrange, unpad_input from ..base import BaseEngine, EngineRegistry from .utils import create_device_mesh, get_sharding_strategy logger = logging.getLogger(__file__) logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) device_name = get_device_name() @EngineRegistry.register("fsdp") class FSDPEngine(BaseEngine): """ Concrete Engine implementation using PyTorch FullyShardedDataParallel (FSDP). Supports model sharding, activation/optimizer offloading, LoRA, and sequence parallelism. """ def __init__(self, config): """ Initialize the FSDPEngine. Sets up distributed device meshes, LoRA, and offload policies based on config. Args: config: Configuration object with FSDP and model settings. """ self.config = config self.rank = torch.distributed.get_rank() # build device mesh for Ulysses Sequence Parallel world_size = torch.distributed.get_world_size() from torch.distributed.device_mesh import init_device_mesh fsdp_size = self.config.model.fsdp_config.fsdp_size self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size) self.use_remove_padding = config.model.get("use_remove_padding", False) self.ulysses_device_mesh = None self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1) dp = world_size // self.ulysses_sequence_parallel_size if self.ulysses_sequence_parallel_size > 1: self.ulysses_device_mesh = init_device_mesh( device_name, mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"] ) self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) # set FSDP offload params self._is_offload_param = self.config.model.fsdp_config.param_offload self._is_offload_optimizer = self.config.model.fsdp_config.optimizer_offload # normalize config self.config.ppo_mini_batch_size *= self.config.rollout_n self.config.ppo_mini_batch_size //= torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size if self.config.ppo_micro_batch_size is not None: self.config.ppo_micro_batch_size //= ( torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size ) self.config.forward_micro_batch_size //= ( torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size ) self.config.ppo_micro_batch_size_per_gpu = self.config.ppo_micro_batch_size self.config.forward_micro_batch_size_per_gpu = self.config.forward_micro_batch_size if self.config.ppo_micro_batch_size_per_gpu is not None: assert self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size_per_gpu == 0, ( f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be divisible by " f"ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}" ) assert self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu > 0, ( f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be larger than " f"ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}" ) self._is_lora = self.config.model.get("lora_rank", 0) > 0 def init_model(self): """ Build the model, optimizer, and learning rate scheduler under FSDP. Applies device, dtype, and precision configurations, including mixed precision. Sets up checkpoint manager and FLOPs counter. """ # This is used to import external_lib into the huggingface systems import_external_libs(self.config.model.get("external_lib", None)) self.module, self.optimizer, self.lr_scheduler = self._build_model_optimizer(self.config) if self._is_offload_param: offload_fsdp_model_to_cpu(self.module) log_gpu_memory_usage("After offload model during init", logger=logger) if self._is_offload_optimizer: offload_fsdp_optimizer(optimizer=self.optimizer) log_gpu_memory_usage("After offload optimizer during init", logger=logger) self.flops_counter = FlopsCounter(self.model_config) self.checkpoint_manager = FSDPCheckpointManager( model=self.module, optimizer=self.optimizer, lr_scheduler=self.lr_scheduler, processing_class=self.processor if self.processor is not None else self.tokenizer, checkpoint_contents=self.config.checkpoint, ) def _build_model_optimizer(self, config): # the following line is necessary from torch import optim from torch.distributed.fsdp import MixedPrecision from verl.utils.model import load_valuehead_model, print_model_size from verl.utils.torch_dtypes import PrecisionType use_shm = config.model.get("use_shm", False) local_path = copy_to_local(config.model.path, use_shm=use_shm) # note that the tokenizer between actor and critic may be different. So override tokenizer info with actor info # using random initialized model from any architecture. May not be the same as Actor. tokenizer_path = copy_to_local(config.model.tokenizer_path, use_shm=use_shm) self.tokenizer = hf_tokenizer(tokenizer_path, trust_remote_code=config.model.get("trust_remote_code", False)) self.processor = hf_processor(tokenizer_path, trust_remote_code=config.model.get("trust_remote_code", False)) if self.config.model.get("custom_chat_template", None) is not None: if self.processor is not None: self.processor.chat_template = self.config.model.custom_chat_template else: self.tokenizer.chat_template = self.config.model.custom_chat_template override_config = OmegaConf.to_container(OmegaConf.create(self.config.model.get("override_config", {}))) override_config_kwargs = { "bos_token_id": self.tokenizer.bos_token_id, "eos_token_id": self.tokenizer.eos_token_id, "pad_token_id": self.tokenizer.pad_token_id, } override_config_kwargs.update(override_config) if self.rank == 0: print(f"Engine overriding config {override_config_kwargs}") torch_dtype = self.config.model.fsdp_config.get("model_dtype", "fp32") torch_dtype = PrecisionType.to_dtype(torch_dtype) from transformers import AutoConfig model_config = AutoConfig.from_pretrained( local_path, attn_implementation="flash_attention_2", trust_remote_code=config.model.get("trust_remote_code", False), ) model_config.num_labels = 1 # patch for kimi-vl if getattr(model_config, "model_type", None) == "kimi_vl": model_config.text_config.topk_method = "greedy" init_context = get_init_weight_context_manager( use_meta_tensor=not model_config.tie_word_embeddings, mesh=self.device_mesh ) with init_context(), warnings.catch_warnings(): warnings.simplefilter("ignore") model_config.classifier_dropout = 0.0 model_config.hidden_dropout = "0" model_config.summary_dropout_prob = 0.0 module = load_valuehead_model( local_path, torch_dtype, model_config, config.model.get("trust_remote_code", False), ) apply_monkey_patch( model=module, use_remove_padding=self.use_remove_padding, ulysses_sp_size=self.ulysses_sequence_parallel_size, ) # some parameters may not in torch_dtype module.to(torch_dtype) if config.model.get("enable_gradient_checkpointing", False): module.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) if self._is_lora: print("Applying LoRA to the module") module.enable_input_require_grads() # Convert config to regular Python types before creating PEFT model lora_config = { "task_type": TaskType.CAUSAL_LM, "r": self.config.model.lora_rank, "lora_alpha": self.config.model.lora_alpha, "target_modules": convert_to_regular_types(self.config.model.target_modules), "bias": "none", } module = get_peft_model(module, LoraConfig(**lora_config)) if self.rank == 0: print_model_size(module) self.model_config = model_config fsdp_config = self.config.model.fsdp_config mixed_precision_config = fsdp_config.get("mixed_precision", None) if mixed_precision_config is not None: param_dtype = PrecisionType.to_dtype(mixed_precision_config.get("param_dtype", "bf16")) reduce_dtype = PrecisionType.to_dtype(mixed_precision_config.get("reduce_dtype", "fp32")) buffer_dtype = PrecisionType.to_dtype(mixed_precision_config.get("buffer_dtype", "fp32")) else: param_dtype = torch.bfloat16 reduce_dtype = torch.float32 buffer_dtype = torch.float32 mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype) auto_wrap_policy = get_fsdp_wrap_policy( module=module, config=self.config.model.fsdp_config.wrap_policy, is_lora=self.config.model.get("lora_rank", 0) > 0, ) log_gpu_memory_usage("Before FSDP", logger=None) fsdp_mesh = self.device_mesh sharding_strategy = get_sharding_strategy(fsdp_mesh) # Note: We force turn off CPUOffload because it causes incorrect results when using grad accumulation if config.strategy == "fsdp": module = FSDP( module, param_init_fn=init_fn, use_orig_params=False, auto_wrap_policy=auto_wrap_policy, device_id=get_device_id(), sharding_strategy=sharding_strategy, mixed_precision=mixed_precision, sync_module_states=True, forward_prefetch=self.config.model.fsdp_config.forward_prefetch, device_mesh=self.device_mesh, cpu_offload=None, ) elif config.strategy == "fsdp2": assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)" mp_policy = MixedPrecisionPolicy( param_dtype=param_dtype, reduce_dtype=reduce_dtype, cast_forward_inputs=True ) offload_policy = None if fsdp_config.offload_policy: self._is_offload_param = False self._is_offload_optimizer = False offload_policy = CPUOffloadPolicy(pin_memory=True) fsdp_kwargs = { "mesh": fsdp_mesh, "mp_policy": mp_policy, "offload_policy": offload_policy, "reshard_after_forward": fsdp_config.reshard_after_forward, } full_state = module.state_dict() apply_fsdp2(module, fsdp_kwargs, fsdp_config) fsdp2_load_full_state_dict(module, full_state, fsdp_mesh, offload_policy) else: raise NotImplementedError(f"Unknown strategy {config.strategy}") if config.model.get("enable_activation_offload", False): enable_gradient_checkpointing = config.model.get("enable_gradient_checkpointing", False) enable_activation_offloading(module, config.strategy, enable_gradient_checkpointing) log_gpu_memory_usage("After FSDP", logger=None) optimizer = optim.AdamW( module.parameters(), lr=config.optim.lr, betas=config.optim.get("betas", (0.9, 0.999)), weight_decay=config.optim.get("weight_decay", 1e-2), ) total_steps = config.optim.get("total_training_steps", 0) num_warmup_steps = int(config.optim.get("lr_warmup_steps", -1)) warmup_style = config.optim.get("warmup_style", "constant") if num_warmup_steps < 0: num_warmup_steps_ratio = config.optim.get("lr_warmup_steps_ratio", 0.0) num_warmup_steps = int(num_warmup_steps_ratio * total_steps) if self.rank == 0: print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}") from verl.utils.torch_functional import get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup if warmup_style == "constant": lr_scheduler = get_constant_schedule_with_warmup(optimizer=optimizer, num_warmup_steps=num_warmup_steps) elif warmup_style == "cosine": lr_scheduler = get_cosine_schedule_with_warmup( optimizer=optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=total_steps ) else: raise NotImplementedError(f"Warmup style {warmup_style} is not supported") return module, optimizer, lr_scheduler def train_mode(self): """ Return a context manager that switches to training mode with FSDP-specific handling. Includes parameter and optimizer offload entry/exit. """ return EngineTrainModeCtx(self) def eval_mode(self): """ Return a context manager that switches to evaluation mode with FSDP-specific handling. Includes activation offload entry/exit. """ return EngineEvalModeCtx(self) def shard_data(self, data): """ Preprocess data into sharded format via UlyssesShardingManager. """ return self.ulysses_sharding_manager.preprocess_data(data) def unshard_data(self, data): """ Postprocess data from sharded format back to full format. """ return self.ulysses_sharding_manager.postprocess_data(data) def get_default_ctx(self): use_value_head_model = hasattr(self.module, "v_head") ctx = { "use_value_head_model": use_value_head_model, "ulysses_sequence_parallel_size": self.ulysses_sequence_parallel_size, } return ctx def _forward_micro_batch(self, micro_batch): multi_modal_inputs = {} if "multi_modal_inputs" in micro_batch.keys(): for key in micro_batch["multi_modal_inputs"][0].keys(): multi_modal_inputs[key] = torch.cat( [inputs[key] for inputs in micro_batch["multi_modal_inputs"]], dim=0 ) with torch.autocast(device_type=device_name, dtype=torch.bfloat16): input_ids = micro_batch["input_ids"] batch, seqlen = input_ids.shape attention_mask = micro_batch["attention_mask"] position_ids = micro_batch["position_ids"] if position_ids.dim() == 3: # qwen2vl mrope position_ids = position_ids.transpose(0, 1) if self.use_remove_padding: input_ids_rmpad, indices, *_ = unpad_input( input_ids.unsqueeze(-1), attention_mask ) # input_ids_rmpad (total_nnz, ...) input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) # unpad the position_ids to align the rotary if position_ids.dim() == 3: position_ids_rmpad = ( index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), indices) .transpose(0, 1) .unsqueeze(1) ) # (3, bsz, seqlen) -> (3, 1, bsz * seqlen) else: position_ids_rmpad = index_first_axis( rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices ).transpose(0, 1) # pad and slice the inputs if sp > 1 if self.ulysses_sequence_parallel_size > 1: input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs( input_ids_rmpad, position_ids_rmpad, sp_size=self.ulysses_sequence_parallel_size ) # only pass input_ids and position_ids to enable flash_attn_varlen preds = self.module( input_ids=input_ids_rmpad, attention_mask=None, position_ids=position_ids_rmpad, **multi_modal_inputs, use_cache=False, ) # prevent model thinks we are generating if hasattr(self.module, "v_head"): # For trl.AutoModelForCausalLMWithValueHead preds_rmpad = preds[2].squeeze(0).unsqueeze(-1) else: preds_rmpad = preds.logits preds_rmpad = preds_rmpad.squeeze(0) # (total_nnz) # gather output if sp > 1 if self.ulysses_sequence_parallel_size > 1: preds_rmpad = gather_outpus_and_unpad(preds_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size) # pad it back preds = pad_input(preds_rmpad, indices=indices, batch=batch, seqlen=seqlen).squeeze(-1) else: preds = self.module( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, **multi_modal_inputs, use_cache=False, ) # prevent model thinks we are generating if hasattr(self.module, "v_head"): # For trl.AutoModelForCausalLMWithValueHead preds = preds[2] else: preds = preds.logits return preds def infer_batch( self, data: DataProto, post_fn: Callable[[DataProto, torch.Tensor], tuple[torch.Tensor, dict[str, torch.Tensor]]], ) -> dict[str, torch.Tensor]: """ Perform inference on a mini batch of data. Args: data: The input data for inference, typically containing tensors and metadata. post_fn: A post-processing function that takes a micro-batch and predictions as input, and returns a tuple containing processed predictions and a dictionary of outputs. Returns: dict[str, torch.Tensor]: A dictionary containing the predictions for the entire batch. """ assert self.mode == "eval" micro_batch_size = data.meta_info["micro_batch_size"] select_keys = ["responses", "input_ids", "attention_mask", "position_ids"] batch = data.select(batch_keys=select_keys).batch use_dynamic_bsz = data.meta_info["use_dynamic_bsz"] has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys() if has_multi_modal_inputs: num_micro_batches = data.batch.batch_size[0] // micro_batch_size non_tensor_select_keys = ["multi_modal_inputs"] micro_batches = data.select(select_keys, non_tensor_select_keys).chunk(num_micro_batches) elif use_dynamic_bsz: # split using dynamic bsz max_token_len = data.meta_info["max_token_len"] * self.ulysses_sequence_parallel_size micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len) else: micro_batches = batch.split(micro_batch_size) preds_list = {} for micro_batch in micro_batches: if isinstance(micro_batch, DataProto): micro_batch = {**micro_batch.batch, **micro_batch.non_tensor_batch} with torch.no_grad(): # micro_batch_preds would be a dict[str, torch.Tensor] preds = self._forward_micro_batch(micro_batch) _, outputs = post_fn(micro_batch, preds) assert isinstance(outputs, dict) # append micro batch preds to dict[str, List[torch.Tensor]] append_to_dict(preds_list, outputs) # reorganize mini batch preds from # dict[str, List[torch.Tensor]] to dict[str, torch.Tensor] mini_batch_preds = {} for key, t_list in preds_list.items(): t_concat = torch.concat(t_list, dim=0) if use_dynamic_bsz: indices = list(itertools.chain.from_iterable(indices)) assert len(indices) == t_concat.size(0), f"{len(indices)} vs. {t_concat.size()}" revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long) t_concat = t_concat[revert_indices] mini_batch_preds[key] = t_concat return mini_batch_preds def train_batch( self, data: DataProto, loss_fn: Callable[[DataProto, torch.Tensor], tuple[torch.Tensor, dict[str, torch.Tensor]]], ) -> dict[str, torch.Tensor]: """ Perform a training step on a mini-batch of data. Args: data (DataProto): The input data for training, typically containing tensors and metadata. loss_fn (Callable): A function that computes the loss and metrics given a micro-batch and predictions. Returns: dict[str, torch.Tensor]: A dictionary containing the aggregated training metrics for the mini-batch. """ assert self.mode == "train" # split batch into micro_batches mini_batch = data select_keys = ["input_ids", "responses", "response_mask", "attention_mask", "position_ids"] if "multi_modal_inputs" in mini_batch: non_tensor_select_keys = ["multi_modal_inputs"] num_micro_batches = mini_batch.batch.batch_size[0] // self.config.ppo_micro_batch_size_per_gpu micro_batches = mini_batch.select(select_keys, non_tensor_select_keys).chunk(num_micro_batches) self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu elif self.config.use_dynamic_bsz: max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len) else: micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu) self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu mini_batch_metrics = {} for micro_batch in micro_batches: # Support all devices if isinstance(micro_batch, DataProto): micro_batch = {**micro_batch.batch.to(get_device_id()), **micro_batch.non_tensor_batch} else: micro_batch = micro_batch.to(get_device_id()) # critic device is cpu when using offload preds = self._forward_micro_batch(micro_batch) loss, micro_batch_metrics = loss_fn(micro_batch, preds) append_to_dict(mini_batch_metrics, micro_batch_metrics) loss.backward() return mini_batch_metrics def optimizer_zero_grad(self): """ Zero gradients and enforce FSDP grad-clipping logic. """ self.optimizer.zero_grad() def optimizer_step(self): """ Clip gradients, skip update if non-finite, and step optimizer. Returns: grad_norm (float): Norm of gradients before clipping. """ assert self.config.grad_clip is not None if isinstance(self.module, FSDP): grad_norm = self.module.clip_grad_norm_(self.config.grad_clip) elif isinstance(self.module, FSDPModule): grad_norm = fsdp2_clip_grad_norm_(self.module.parameters(), max_norm=self.config.grad_clip) else: grad_norm = torch.nn.utils.clip_grad_norm_(self.module.parameters(), max_norm=self.config.grad_clip) # if grad_norm is not finite, skip the update if not torch.isfinite(grad_norm): print(f"WARN: grad_norm is not finite: {grad_norm}") self.optimizer.zero_grad() else: self.optimizer.step() return grad_norm def lr_scheduler_step(self): """ Advance FSDP scheduler and return updated learning rate. """ self.lr_scheduler.step() lr = self.lr_scheduler.get_last_lr() return lr def to(self, device: str, model: bool = True, optimizer: bool = True): """ Move FSDP model and/or optimizer to CPU or GPU with offload support. """ assert device in ("cuda", "cpu") if device == "cuda": if not self.config.model.fsdp_config.param_offload: if model: load_fsdp_model_to_gpu(self.model_module) if optimizer and self.optimizer is not None: load_fsdp_optimizer(self.optimizer, device) gc.collect() elif device == "cpu": if not self.config.model.fsdp_config.param_offload: if model: offload_fsdp_model_to_cpu(self.model_module) if optimizer and self.optimizer is not None: offload_fsdp_optimizer(self.optimizer) else: raise ValueError(f"Invalid device type: {device}") def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None): """ Save FSDP checkpoint, handling parameter offload as needed. """ if self._is_offload_param: load_fsdp_model_to_gpu(self.module) self.checkpoint_manager.save_checkpoint( local_path=local_path, hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep ) torch.distributed.barrier() if self._is_offload_param: offload_fsdp_model_to_cpu(self.module) def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=True): """ Load FSDP checkpoint, restoring parameters and optimizer state. """ import torch if self._is_offload_param: load_fsdp_model_to_gpu(self.module) self.checkpoint_manager.load_checkpoint( local_path=local_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load ) torch.distributed.barrier() if self._is_offload_param: offload_fsdp_model_to_cpu(self.module) if self._is_offload_optimizer: offload_fsdp_optimizer(self.optimizer) class EngineEvalModeCtx: def __init__(self, engine): self.engine = engine def __enter__(self): self.engine.mode = "eval" if self.engine._is_offload_param: load_fsdp_model_to_gpu(self.engine.module) self.engine.ulysses_sharding_manager.__enter__() self.engine.module.eval() def __exit__(self, exc_type, exc_value, traceback): self.engine.ulysses_sharding_manager.__exit__(exc_type, exc_value, traceback) if self.engine._is_offload_param: offload_fsdp_model_to_cpu(self.engine.module) self.engine.mode = None class EngineTrainModeCtx: def __init__(self, engine): self.engine = engine def __enter__(self): self.engine.mode = "train" if self.engine._is_offload_param: load_fsdp_model_to_gpu(self.engine.module) if self.engine._is_offload_optimizer: load_fsdp_optimizer(optimizer=self.engine.optimizer, device_id=get_torch_device().current_device()) self.engine.ulysses_sharding_manager.__enter__() self.engine.module.train() def __exit__(self, exc_type, exc_value, traceback): self.engine.ulysses_sharding_manager.__exit__(exc_type, exc_value, traceback) if self.engine._is_offload_param: offload_fsdp_model_to_cpu(self.engine.module) if self.engine._is_offload_optimizer: offload_fsdp_optimizer(optimizer=self.optimizer) self.engine.mode = None