Commit c132cbcb authored by chenych

0402 update

parent f92481f0
......@@ -20,8 +20,8 @@ from typing import Any, Dict
import torch
from verl import DataProto
from verl.workers.critic.config import CriticConfig
from ...protocol import DataProto
from .config import CriticConfig
__all__ = ["BasePPOCritic"]
......
......@@ -17,17 +17,18 @@ Critic config
from dataclasses import dataclass, field
from verl.workers.actor.config import FSDPConfig, ModelConfig, OffloadConfig, OptimConfig
from ..actor.config import FSDPConfig, ModelConfig, OffloadConfig, OptimConfig
@dataclass
class CriticConfig:
strategy: str = "fsdp"
global_batch_size: int = 256
micro_batch_size_per_device_for_update: int = field(default=-1, init=False)
micro_batch_size_per_device_for_experience: int = field(default=-1, init=False)
micro_batch_size_per_device_for_update: int = 4
micro_batch_size_per_device_for_experience: int = 16
max_grad_norm: float = 1.0
cliprange_value: float = 0.5
ppo_epochs: int = 1
padding_free: bool = False
ulysses_sequence_parallel_size: int = 1
model: ModelConfig = field(default_factory=ModelConfig)
......
......@@ -20,17 +20,23 @@ from collections import defaultdict
from typing import Any, Dict
import torch
import torch.distributed
from ray.experimental.tqdm_ray import tqdm
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from tqdm import tqdm
from verl import DataProto
from verl.trainer import core_algos
from verl.utils.py_functional import append_to_dict
from verl.utils.torch_functional import masked_mean
from verl.workers.critic.base import BasePPOCritic
from verl.workers.critic.config import CriticConfig
from ...protocol import DataProto
from ...trainer import core_algos
from ...utils import torch_functional as VF
from ...utils.py_functional import append_to_dict
from ...utils.ulysses import gather_outputs_and_unpad, ulysses_pad_and_slice_inputs
from .base import BasePPOCritic
from .config import CriticConfig
try:
from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input
except ImportError:
pass
__all__ = ["DataParallelPPOCritic"]
......@@ -45,6 +51,7 @@ class DataParallelPPOCritic(BasePPOCritic):
def _forward_micro_batch(self, micro_batch: Dict[str, torch.Tensor]) -> torch.Tensor:
input_ids = micro_batch["input_ids"]
batch_size, seqlen = input_ids.shape
attention_mask = micro_batch["attention_mask"]
position_ids = micro_batch["position_ids"]
responses = micro_batch["responses"]
......@@ -52,20 +59,61 @@ class DataParallelPPOCritic(BasePPOCritic):
if position_ids.dim() == 3: # qwen2vl mrope
position_ids = position_ids.transpose(0, 1) # (bsz, 3, seqlen) -> (3, bsz, seqlen)
vision_inputs = {}
if "pixel_values" in micro_batch:
vision_inputs["pixel_values"] = torch.cat(micro_batch["pixel_values"], dim=0)
vision_inputs["image_grid_thw"] = torch.cat(micro_batch["image_grid_thw"], dim=0)
multi_modal_inputs = {}
if "multi_modal_inputs" in micro_batch:
for key in micro_batch["multi_modal_inputs"][0].keys():
multi_modal_inputs[key] = torch.cat(
[inputs[key] for inputs in micro_batch["multi_modal_inputs"]], dim=0
)
if self.config.padding_free:
# TODO (yaowei): preprocess data for padding_free and ulysses
raise NotImplementedError
input_ids_rmpad, indices, *_ = unpad_input(
input_ids.unsqueeze(-1), attention_mask
) # input_ids_rmpad (total_nnz, ...)
input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz)
# unpad the position_ids to align the rotary
if position_ids.dim() == 3:
position_ids_rmpad = (
index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), indices)
.transpose(0, 1)
.unsqueeze(1)
) # (3, bsz, seqlen) -> (3, 1, bsz * seqlen)
else:
position_ids_rmpad = index_first_axis(
rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices
).transpose(0, 1)
# pad and slice the inputs if sp > 1
if self.config.ulysses_sequence_parallel_size > 1:
input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(
input_ids_rmpad, position_ids_rmpad, sp_size=self.config.ulysses_sequence_parallel_size
)
# only pass input_ids and position_ids to enable flash_attn_varlen
output = self.critic_module(
input_ids=input_ids_rmpad,
attention_mask=None,
position_ids=position_ids_rmpad,
**multi_modal_inputs,
use_cache=False,
) # prevent the model from thinking we are generating
values_rmpad = output.logits
values_rmpad = values_rmpad.squeeze(0) # (total_nnz)
# gather output if sp > 1
if self.config.ulysses_sequence_parallel_size > 1:
values_rmpad = gather_outputs_and_unpad(values_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size)
# pad it back
values = pad_input(values_rmpad, indices=indices, batch=batch_size, seqlen=seqlen).squeeze(-1)
values = values[:, -response_length - 1 : -1]
else:
output = self.critic_module(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
**vision_inputs,
**multi_modal_inputs,
use_cache=False,
)
values: torch.Tensor = output.logits
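
The padding-free branch above removes padded positions before the forward pass and scatters the values back afterwards. A minimal pure-PyTorch sketch of that round trip (illustrative shapes only, no flash-attn dependency), including the one-token shift applied when slicing out the response values:

import torch

batch_size, seqlen, response_length = 2, 6, 3
attention_mask = torch.tensor([[0, 0, 1, 1, 1, 1],
                               [0, 1, 1, 1, 1, 1]])  # left-padded prompts + responses
values_full = torch.arange(batch_size * seqlen, dtype=torch.float32).reshape(batch_size, seqlen)

# "unpad": keep only non-pad positions, which is what unpad_input encodes via `indices`
indices = attention_mask.flatten().nonzero(as_tuple=True)[0]  # (total_nnz,)
values_rmpad = values_full.flatten()[indices]                 # (total_nnz,)

# "pad it back": scatter the flat values into a (batch, seqlen) tensor again
values = torch.zeros(batch_size * seqlen)
values[indices] = values_rmpad
values = values.reshape(batch_size, seqlen)

# the value at position t scores the response token generated at t + 1,
# hence the one-step shift in the slice below
values = values[:, -response_length - 1 : -1]
print(values.shape)  # torch.Size([2, 3])
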
......@@ -81,7 +129,12 @@ class DataParallelPPOCritic(BasePPOCritic):
self.critic_module.parameters(), max_norm=self.config.max_grad_norm
)
self.critic_optimizer.step()
if not torch.isfinite(grad_norm):
print("Gradient norm is not finite. Skip update.")
else:
self.critic_optimizer.step()
self.critic_optimizer.zero_grad()
return grad_norm
@torch.no_grad()
......@@ -89,18 +142,21 @@ class DataParallelPPOCritic(BasePPOCritic):
self.critic_module.eval()
select_keys = ["responses", "input_ids", "attention_mask", "position_ids"]
if "pixel_values" in data.non_tensor_batch.keys():
non_tensor_select_keys = ["pixel_values", "image_grid_thw"]
if "multi_modal_inputs" in data.non_tensor_batch.keys():
non_tensor_select_keys = ["multi_modal_inputs"]
else:
non_tensor_select_keys = None
non_tensor_select_keys = []
micro_batches = data.select(select_keys, non_tensor_select_keys).split(
self.config.micro_batch_size_per_device_for_experience
)
values_lst = []
for micro_batch in tqdm(micro_batches, "Compute values", disable=(self.rank != 0)):
micro_batch.to("cuda")
values = self._forward_micro_batch(micro_batch)
if self.rank == 0:
micro_batches = tqdm(micro_batches, desc="Compute values", position=2)
for micro_batch in micro_batches:
model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
values = self._forward_micro_batch(model_inputs)
values_lst.append(values)
values = torch.concat(values_lst, dim=0)
......@@ -114,55 +170,56 @@ class DataParallelPPOCritic(BasePPOCritic):
self.critic_module.train()
select_keys = ["input_ids", "responses", "attention_mask", "position_ids", "values", "returns"]
if "pixel_values" in data.non_tensor_batch.keys():
non_tensor_select_keys = ["pixel_values", "image_grid_thw"]
if "multi_modal_inputs" in data.non_tensor_batch.keys():
non_tensor_select_keys = ["multi_modal_inputs"]
else:
non_tensor_select_keys = None
non_tensor_select_keys = []
# TODO (yaowei): support ppo epochs
# Split to make minibatch iterator for updating the actor
# See PPO paper for details. https://arxiv.org/abs/1707.06347
mini_batches = data.select(select_keys, non_tensor_select_keys).split(self.config.global_batch_size_per_device)
metrics = defaultdict(list)
n = len(mini_batches)
for i, mini_batch in enumerate(mini_batches):
gradient_accumulation = (
self.config.global_batch_size_per_device // self.config.micro_batch_size_per_device_for_update
)
micro_batches = mini_batch.split(self.config.micro_batch_size_per_device_for_update)
self.critic_optimizer.zero_grad()
for micro_batch in tqdm(micro_batches, desc=f"Update critic [{i + 1}/{n}]", disable=(self.rank != 0)):
micro_batch.to("cuda")
model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
responses = model_inputs["responses"]
attention_mask = model_inputs["attention_mask"]
values = model_inputs["values"]
returns = model_inputs["returns"]
response_length = responses.size(1)
eos_mask = attention_mask[:, -response_length - 1 : -1]
vpreds = self._forward_micro_batch(data)
vf_loss, vf_clipfrac = core_algos.compute_value_loss(
vpreds=vpreds,
values=values,
returns=returns,
eos_mask=eos_mask,
cliprange_value=self.config.cliprange_value,
)
loss = vf_loss / gradient_accumulation
loss.backward()
batch_metrics = {
"critic/vf_loss": vf_loss.detach().item(),
"critic/vf_clipfrac": vf_clipfrac.detach().item(),
"critic/vpred_mean": masked_mean(vpreds, eos_mask).detach().item(),
}
append_to_dict(metrics, batch_metrics)
for _ in range(self.config.ppo_epochs):
if self.rank == 0:
mini_batches = tqdm(mini_batches, desc="Train mini-batches", position=2)
grad_norm = self._optimizer_step()
append_to_dict(metrics, {"critic/grad_norm": grad_norm.detach().item()})
for mini_batch in mini_batches:
gradient_accumulation = (
self.config.global_batch_size_per_device // self.config.micro_batch_size_per_device_for_update
)
micro_batches = mini_batch.split(self.config.micro_batch_size_per_device_for_update)
if self.rank == 0:
micro_batches = tqdm(micro_batches, desc="Update critic", position=3)
for micro_batch in micro_batches:
model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
responses = model_inputs["responses"]
attention_mask = model_inputs["attention_mask"]
values = model_inputs["values"]
returns = model_inputs["returns"]
response_length = responses.size(1)
eos_mask = attention_mask[:, -response_length - 1 : -1] # shift left for value computation
vpreds = self._forward_micro_batch(model_inputs)
vf_loss, vf_clipfrac = core_algos.compute_value_loss(
vpreds=vpreds,
returns=returns,
values=values,
eos_mask=eos_mask,
cliprange_value=self.config.cliprange_value,
)
loss = vf_loss / gradient_accumulation
loss.backward()
batch_metrics = {
"critic/vf_loss": vf_loss.detach().item(),
"critic/vf_clipfrac": vf_clipfrac.detach().item(),
"critic/vpred_mean": VF.masked_mean(vpreds, eos_mask).detach().item(),
}
append_to_dict(metrics, batch_metrics)
grad_norm = self._optimizer_step()
append_to_dict(metrics, {"critic/grad_norm": grad_norm.detach().item()})
self.critic_optimizer.zero_grad()
return metrics
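
core_algos.compute_value_loss is not part of this diff; the sketch below shows the standard PPO clipped value loss it presumably implements (clip the new prediction around the old value by cliprange_value, take the elementwise max of the two squared errors, and average over the response mask). The local helper and the 0.5 factor are assumptions, not the project's exact code:

import torch

def _masked_mean(x: torch.Tensor, mask: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    return (x * mask).sum() / (mask.sum() + eps)

def compute_value_loss_sketch(vpreds, values, returns, eos_mask, cliprange_value):
    # keep the new value prediction within `cliprange_value` of the old value
    vpreds_clipped = values + (vpreds - values).clamp(-cliprange_value, cliprange_value)
    vf_losses1 = (vpreds - returns) ** 2
    vf_losses2 = (vpreds_clipped - returns) ** 2
    vf_loss = 0.5 * _masked_mean(torch.max(vf_losses1, vf_losses2), eos_mask)
    # fraction of positions where the clipped branch dominates
    vf_clipfrac = _masked_mean((vf_losses2 > vf_losses1).float(), eos_mask)
    return vf_loss, vf_clipfrac
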
......@@ -15,8 +15,10 @@
The main entry point to run the PPO algorithm
"""
from typing import Literal
from typing import Literal, Optional, Union
import numpy as np
import psutil
import torch
import torch.distributed as dist
from accelerate import init_empty_weights
......@@ -34,13 +36,13 @@ from transformers import (
)
from transformers.modeling_utils import no_init_weights
from verl import DataProto
from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import Dispatch, register
from verl.utils import get_tokenizer, get_processor
from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
from verl.utils.flops_counter import FlopsCounter
from verl.utils.fsdp_utils import (
from ..models.monkey_patch import apply_ulysses_patch
from ..protocol import DataProto
from ..single_controller.base import Worker
from ..single_controller.base.decorator import Dispatch, register
from ..utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
from ..utils.flops_counter import FlopsCounter
from ..utils.fsdp_utils import (
get_fsdp_wrap_policy,
get_init_fn,
load_fsdp_model,
......@@ -48,16 +50,16 @@ from verl.utils.fsdp_utils import (
offload_fsdp_model,
offload_fsdp_optimizer,
)
from verl.utils.model_utils import print_model_size
from verl.utils.performance import log_gpu_memory_usage
from verl.utils.torch_dtypes import PrecisionType
from verl.utils.torch_functional import get_constant_schedule_with_warmup
from verl.workers.actor import DataParallelPPOActor
from verl.workers.config import FSDPConfig, ModelConfig, OptimConfig, WorkerConfig
from verl.workers.critic import DataParallelPPOCritic
from verl.workers.rollout.vllm_rollout import vLLMRollout
from verl.workers.sharding_manager import FSDPVLLMShardingManager
from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager
from ..utils.model_utils import print_gpu_memory_usage, print_model_size
from ..utils.tokenizer import get_processor, get_tokenizer
from ..utils.torch_dtypes import PrecisionType
from ..utils.torch_functional import AnyPrecisionAdamW, get_constant_schedule_with_warmup
from .actor import DataParallelPPOActor
from .config import ActorConfig, CriticConfig, FSDPConfig, ModelConfig, OptimConfig, RefConfig, WorkerConfig
from .critic import DataParallelPPOCritic
from .rollout.vllm_rollout import vLLMRollout
from .sharding_manager import FSDPVLLMShardingManager
from .sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager
class FSDPWorker(Worker):
......@@ -68,77 +70,95 @@ class FSDPWorker(Worker):
):
super().__init__()
self.config = config
self.role = role
if not dist.is_initialized():
dist.init_process_group(backend="nccl")
# build device mesh for FSDP
# TODO: support FSDP hybrid shard for larger model
self._is_actor = self.role in ["actor", "actor_rollout", "actor_rollout_ref"]
self._is_critic = self.role == "critic"
self._is_rollout = self.role in ["rollout", "actor_rollout", "actor_rollout_ref"]
self._is_ref = self.role in ["ref", "actor_rollout_ref"]
self._use_param_offload = False
self._use_optimizer_offload = False
if self._is_actor:
self._use_param_offload = self.config.actor.offload.offload_params
self._use_optimizer_offload = self.config.actor.offload.offload_optimizer
self._init_config(self.config.actor, "actor")
elif self._is_critic:
self._use_param_offload = self.config.critic.offload.offload_params
self._use_optimizer_offload = self.config.critic.offload.offload_optimizer
self._init_config(self.config.critic, "critic")
elif self._is_ref: # NOTE: it seems that manual offload is slower than FSDP offload
self._use_param_offload = self.config.ref.offload.offload_params
self._init_config(self.config.ref, "ref")
def _init_config(
self, config: Union[ActorConfig, CriticConfig, RefConfig], role: Literal["actor", "critic", "ref"]
):
world_size = dist.get_world_size()
self.device_mesh = init_device_mesh("cuda", mesh_shape=(world_size,), mesh_dim_names=["fsdp"])
fsdp_size = config.fsdp.fsdp_size
if fsdp_size <= 0 or fsdp_size >= world_size:
self.device_mesh = init_device_mesh("cuda", mesh_shape=(world_size,), mesh_dim_names=("fsdp",))
else: # hsdp
self.device_mesh = init_device_mesh(
"cuda", mesh_shape=(world_size // fsdp_size, fsdp_size), mesh_dim_names=("ddp", "fsdp")
)
# build device mesh for Ulysses Sequence Parallel
self.ulysses_sequence_parallel_size = self.config.actor.ulysses_sequence_parallel_size
if self.ulysses_sequence_parallel_size > 1:
if config.ulysses_sequence_parallel_size > 1:
self.ulysses_device_mesh = init_device_mesh(
"cuda",
mesh_shape=(world_size // self.ulysses_sequence_parallel_size, self.ulysses_sequence_parallel_size),
mesh_dim_names=["dp", "sp"],
mesh_shape=(
world_size // config.ulysses_sequence_parallel_size,
config.ulysses_sequence_parallel_size,
),
mesh_dim_names=("dp", "sp"),
)
else:
self.ulysses_device_mesh = None
self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)
self.role = role
self._is_actor = self.role in ["actor", "actor_rollout", "actor_rollout_ref"]
self._is_critic = self.role == "critic"
self._is_rollout = self.role in ["rollout", "actor_rollout", "actor_rollout_ref"]
self._is_ref = self.role in ["ref", "actor_rollout_ref"]
if not hasattr(config, "global_batch_size"): # ref model
return
self._use_param_offload = False
self._use_optimizer_offload = False
if self._is_actor:
self._use_param_offload = self.config.actor.offload.param_offload
self._use_optimizer_offload = self.config.actor.offload.optimizer_offload
elif self._is_critic:
self._use_param_offload = self.config.critic.offload.param_offload
self._use_optimizer_offload = self.config.critic.offload.optimizer_offload
elif self._is_ref:
# NOTE: it seems that manual offload is slower than FSDP offload
self._use_param_offload = self.config.ref.offload.param_offload
if self.config.rollout.n > 1:
config.global_batch_size *= self.config.rollout.n
self.print_rank0(f"{role} will use global batch size {config.global_batch_size}.")
# normalize config
if self._is_actor:
self.config.actor.global_batch_size *= self.config.rollout.n
self.config.actor.global_batch_size_per_device = (
self.config.actor.global_batch_size // self.device_mesh.shape[0] * self.ulysses_sequence_parallel_size
)
assert (
self.config.actor.global_batch_size_per_device
% self.config.actor.micro_batch_size_per_device_for_update
== 0
)
elif self._is_critic:
self.config.critic.global_batch_size *= self.config.rollout.n
self.config.critic.global_batch_size_per_device = (
self.config.critic.global_batch_size // self.device_mesh.shape[0] * self.ulysses_sequence_parallel_size
)
assert (
self.config.critic.global_batch_size_per_device
% self.config.critic.micro_batch_size_per_device_for_update
== 0
)
config.global_batch_size_per_device = (
config.global_batch_size // self.device_mesh.size() * config.ulysses_sequence_parallel_size
)
if config.global_batch_size_per_device == 0:
raise ValueError(f"{role} global batch size must be larger than num gpus.")
if config.global_batch_size_per_device % config.micro_batch_size_per_device_for_update != 0:
raise ValueError(f"{role} global batch size per device must be divisible by the micro batch size.")
if (
config.fsdp.enable_cpu_offload
and config.global_batch_size_per_device != config.micro_batch_size_per_device_for_update
):
raise ValueError(f"{role} cannot use FSDP's CPU offload when gradient accumulation is enabled.")
def _build_model_optimizer(
self,
model_config: ModelConfig,
fsdp_config: FSDPConfig,
optim_config: OptimConfig,
optim_config: Optional[OptimConfig],
padding_free: bool = False,
) -> None:
self.tokenizer = get_tokenizer(model_config.tokenizer_path, trust_remote_code=model_config.trust_remote_code)
self.processor = get_processor(model_config.tokenizer_path)
self.tokenizer = get_tokenizer(
model_config.tokenizer_path,
trust_remote_code=model_config.trust_remote_code,
use_fast=True,
)
self.processor = get_processor(
model_config.tokenizer_path,
trust_remote_code=model_config.trust_remote_code,
use_fast=True,
)
self.model_config = AutoConfig.from_pretrained(
model_config.model_path,
trust_remote_code=model_config.trust_remote_code,
......@@ -156,7 +176,8 @@ class FSDPWorker(Worker):
self.print_rank0(f"Model config: {self.model_config}")
if padding_free:
raise NotImplementedError("Padding free is not implemented yet.")
apply_ulysses_patch(self.model_config.model_type)
self.print_rank0("Ulysses patch applied!")
if fsdp_config.torch_dtype is None:
torch_dtype = torch.float32 if self._is_actor or self._is_critic else torch.bfloat16
......@@ -170,13 +191,13 @@ class FSDPWorker(Worker):
else:
auto_class = AutoModelForCausalLM
if self.rank == 0:
if (not fsdp_config.enable_rank0_init) or self.device_mesh.get_local_rank("fsdp") == 0:
model = auto_class.from_pretrained(
model_config.model_path,
config=self.model_config,
torch_dtype=torch_dtype,
attn_implementation="flash_attention_2",
device_map="cpu",
device_map="cpu" if fsdp_config.enable_rank0_init else "cuda",
low_cpu_mem_usage=True,
trust_remote_code=model_config.trust_remote_code,
)
......@@ -195,29 +216,50 @@ class FSDPWorker(Worker):
if model_config.enable_gradient_checkpointing:
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
dist.barrier()
if self.rank == 0:
print_model_size(model)
if not (self._is_actor or self._is_critic):
model.requires_grad_(False)
if model_config.freeze_vision_tower:
if hasattr(model, "visual"):
model.visual.requires_grad_(False)
fsdp_config.use_orig_params = True
self.print_rank0("Vision tower is set to not trainable.")
else:
self.print_rank0("No vision tower found.")
log_gpu_memory_usage("After init from huggingface model")
dist.barrier()
print_model_size(model)
print_gpu_memory_usage("After huggingface model init")
mixed_precision = MixedPrecision(
param_dtype=PrecisionType.to_dtype(fsdp_config.mp_param_dtype),
reduce_dtype=PrecisionType.to_dtype(fsdp_config.mp_reduce_dtype),
buffer_dtype=PrecisionType.to_dtype(fsdp_config.mp_buffer_dtype),
)
auto_wrap_policy = get_fsdp_wrap_policy(model)
if fsdp_config.enable_full_shard:
sharding_strategy = ShardingStrategy.FULL_SHARD
self.print_rank0(f"FSDP wrap policy: {auto_wrap_policy}.")
if self.device_mesh.ndim == 2:
if fsdp_config.enable_full_shard:
sharding_strategy = ShardingStrategy.HYBRID_SHARD
else:
sharding_strategy = ShardingStrategy._HYBRID_SHARD_ZERO2
else:
sharding_strategy = ShardingStrategy.SHARD_GRAD_OP
if fsdp_config.enable_full_shard:
sharding_strategy = ShardingStrategy.FULL_SHARD
else:
sharding_strategy = ShardingStrategy.SHARD_GRAD_OP
if fsdp_config.param_offload or fsdp_config.optimizer_offload:
cpu_offload = CPUOffload(offload_params=fsdp_config.param_offload)
if fsdp_config.enable_cpu_offload:
cpu_offload = CPUOffload(offload_params=True)
else:
cpu_offload = None
if self.rank == 0:
print(f"FSDP wrap policy: {auto_wrap_policy}.")
if fsdp_config.enable_rank0_init:
sync_module_states = True
param_init_fn = get_init_fn(model, device="cuda") if self.rank != 0 else None
else:
sync_module_states = False
param_init_fn = None
self.fsdp_module = FSDP(
model,
......@@ -225,53 +267,60 @@ class FSDPWorker(Worker):
cpu_offload=cpu_offload,
auto_wrap_policy=auto_wrap_policy,
mixed_precision=mixed_precision,
param_init_fn=get_init_fn(model, device="cuda") if self.rank != 0 else None,
param_init_fn=param_init_fn,
device_id=torch.cuda.current_device(),
sync_module_states=True,
sync_module_states=sync_module_states,
forward_prefetch=False,
use_orig_params=False,
use_orig_params=fsdp_config.use_orig_params,
device_mesh=self.device_mesh,
)
log_gpu_memory_usage("After Actor FSDP init")
print_gpu_memory_usage("After FSDP module init")
if self._is_actor or self._is_critic:
self.optimizer = torch.optim.AdamW(
self.fsdp_module.parameters(),
lr=optim_config.lr,
betas=optim_config.betas,
weight_decay=optim_config.weight_decay,
)
num_warmup_steps = int(optim_config.lr_warmup_steps_ratio * optim_config.training_steps)
if optim_config.strategy == "adamw":
self.optimizer = torch.optim.AdamW(
self.fsdp_module.parameters(),
lr=optim_config.lr,
betas=optim_config.betas,
weight_decay=optim_config.weight_decay,
fused=True,
)
elif optim_config.strategy == "adamw_bf16":
self.optimizer = AnyPrecisionAdamW(
self.fsdp_module.parameters(),
lr=optim_config.lr,
betas=optim_config.betas,
weight_decay=optim_config.weight_decay,
)
else:
raise NotImplementedError(f"Optimizer {optim_config.strategy} not supported.")
num_warmup_steps = int(optim_config.lr_warmup_ratio * optim_config.training_steps)
self.lr_scheduler = get_constant_schedule_with_warmup(
optimizer=self.optimizer, num_warmup_steps=num_warmup_steps
)
print_gpu_memory_usage("After optimizer init")
else:
self.optimizer, self.lr_scheduler = None, None
log_gpu_memory_usage("After actor optimizer init")
def _build_rollout(self) -> None:
# TODO(sgm): support FSDP hybrid shard for larger model
tp_size = self.config.rollout.tensor_parallel_size
dp_size = self.world_size // tp_size
assert self.world_size % tp_size == 0, (
f"rollout world_size: {self.world_size} is not divisible by tp_size: {tp_size}"
f"rollout world size: {self.world_size} is not divisible by tp size: {tp_size}"
)
rollout_device_mesh = init_device_mesh("cuda", mesh_shape=(dp_size, tp_size), mesh_dim_names=["dp", "tp"])
log_gpu_memory_usage("Before building vllm rollout")
rollout_device_mesh = init_device_mesh("cuda", mesh_shape=(dp_size, tp_size), mesh_dim_names=("dp", "tp"))
self.rollout = vLLMRollout(
model_path=self.config.actor.model.model_path,
config=self.config.rollout,
tokenizer=self.tokenizer,
)
log_gpu_memory_usage("After building vllm rollout")
self.rollout_sharding_manager = FSDPVLLMShardingManager(
module=self.fsdp_module,
inference_engine=self.rollout.inference_engine,
device_mesh=rollout_device_mesh,
)
log_gpu_memory_usage("After building sharding manager")
print_gpu_memory_usage("After vllm init")
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def init_model(self):
......@@ -280,11 +329,21 @@ class FSDPWorker(Worker):
fsdp_config = self.config.critic.fsdp
optim_config = self.config.critic.optim
padding_free = self.config.critic.padding_free
else:
role = "critic"
elif self._is_actor:
model_config = self.config.actor.model
fsdp_config = self.config.actor.fsdp
optim_config = self.config.actor.optim
padding_free = self.config.actor.padding_free
role = "actor"
elif self._is_ref:
model_config = self.config.actor.model
fsdp_config = self.config.ref.fsdp
optim_config = None
padding_free = self.config.ref.padding_free
role = "ref"
else:
raise ValueError(f"Unknown role {role}.")
if self._is_actor or self._is_critic or self._is_ref:
self._build_model_optimizer(
......@@ -293,11 +352,13 @@ class FSDPWorker(Worker):
optim_config=optim_config,
padding_free=padding_free,
)
# get the original unwrapped module
self.unwrapped_model = self.fsdp_module._fsdp_wrapped_module
if self._use_optimizer_offload and not self._is_critic:
if self._use_param_offload:
offload_fsdp_model(self.fsdp_module)
print_gpu_memory_usage(f"After offload {role} model during init")
if self._use_optimizer_offload:
offload_fsdp_optimizer(optimizer=self.optimizer)
log_gpu_memory_usage("After offload actor optimizer during init")
print_gpu_memory_usage(f"After offload {role} optimizer during init")
if self._is_actor:
self.actor = DataParallelPPOActor(
......@@ -317,7 +378,10 @@ class FSDPWorker(Worker):
self._build_rollout()
if self._is_ref:
self.ref_policy = DataParallelPPOActor(config=self.config.ref, actor_module=self.fsdp_module)
self.ref_policy = DataParallelPPOActor(
config=self.config.ref,
actor_module=self.fsdp_module,
)
if self._is_actor or self._is_critic:
self.flops_counter = FlopsCounter(self.model_config)
......@@ -325,42 +389,37 @@ class FSDPWorker(Worker):
model=self.fsdp_module,
optimizer=self.optimizer,
lr_scheduler=self.lr_scheduler,
tokenizer=self.tokenizer,
processor=self.processor
processing_class=self.processor if self.processor is not None else self.tokenizer,
)
torch.cuda.empty_cache()
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def save_checkpoint(self, path: str, global_step: int = 0, remove_previous_ckpt: bool = False):
def save_checkpoint(self, path: str):
assert self._is_actor or self._is_critic
if self._use_param_offload:
load_fsdp_model(self.fsdp_module)
self.checkpoint_manager.save_checkpoint(
local_path=path,
global_step=global_step,
remove_previous_ckpt=remove_previous_ckpt,
)
self.checkpoint_manager.save_checkpoint(path)
dist.barrier()
if self._use_param_offload:
offload_fsdp_model(self.fsdp_module)
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def load_checkpoint(self, path: str, del_local_after_load: bool = True):
def load_checkpoint(self, path: str):
if self._use_param_offload:
load_fsdp_model(self.fsdp_module)
self.checkpoint_manager.load_checkpoint(path=path, del_local_after_load=del_local_after_load)
self.checkpoint_manager.load_checkpoint(path)
dist.barrier()
if self._use_param_offload:
offload_fsdp_model(self.fsdp_module)
"""ActorRolloutRefWorker"""
if self._use_optimizer_offload:
offload_fsdp_optimizer(self.optimizer)
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def update_actor(self, data: DataProto):
assert self._is_actor
data = data.to(torch.cuda.current_device())
if self._use_param_offload:
load_fsdp_model(self.fsdp_module)
......@@ -368,7 +427,6 @@ class FSDPWorker(Worker):
if self._use_optimizer_offload:
load_fsdp_optimizer(optimizer=self.optimizer)
log_gpu_memory_usage("Before update policy")
with self.ulysses_sharding_manager:
data = self.ulysses_sharding_manager.preprocess_data(data=data)
with Timer(name="update_policy", logger=None) as timer:
......@@ -377,17 +435,27 @@ class FSDPWorker(Worker):
delta_time = timer.last
global_num_tokens = data.meta_info["global_token_num"]
estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time)
metrics["mfu/actor"] = estimated_flops * self.config.actor.ppo_epochs / promised_flops / self.world_size
metrics["perf/mfu_actor"] = (
estimated_flops * self.config.actor.ppo_epochs / (promised_flops * self.world_size)
)
metrics["perf/max_memory_allocated_gb"] = (
torch.cuda.max_memory_allocated() - self.rollout_sharding_manager.freed_bytes
) / (1024**3)
metrics["perf/max_memory_reserved_gb"] = (
torch.cuda.max_memory_reserved() - self.rollout_sharding_manager.freed_bytes
) / (1024**3)
metrics["perf/cpu_memory_used_gb"] = psutil.virtual_memory().used / (1024**3)
self.lr_scheduler.step()
lr = self.lr_scheduler.get_last_lr()[0]
metrics["actor/lr"] = lr
log_gpu_memory_usage("After update policy")
# TODO: here, we should return all metrics
output = DataProto(meta_info={"metrics": metrics})
output = self.ulysses_sharding_manager.postprocess_data(data=output)
output = output.to("cpu")
# Metrics should be in non_tensor_batch instead of meta_info, as DataProto does not concat meta_info.
output = DataProto(
non_tensor_batch={
key: np.array([value] if np.isscalar(value) else value) for key, value in metrics.items()
}
)
if self._use_param_offload:
offload_fsdp_model(self.fsdp_module)
......@@ -395,7 +463,7 @@ class FSDPWorker(Worker):
if self._use_optimizer_offload:
offload_fsdp_optimizer(optimizer=self.optimizer)
torch.cuda.empty_cache()
output = output.to("cpu")
return output
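
A worked example of the perf/mfu_actor metric above, assuming estimate_flops returns the achieved FLOPs per second for this update together with the per-GPU peak (hypothetical numbers):

# Hypothetical numbers; illustrates the MFU formula only.
estimated_flops = 1.2e15   # achieved FLOPs/s reported by FlopsCounter.estimate_flops
promised_flops = 3.12e14   # per-GPU peak FLOPs/s (e.g. ~312 TFLOPS for A100 BF16)
ppo_epochs = 1
world_size = 8

mfu = estimated_flops * ppo_epochs / (promised_flops * world_size)
print(f"perf/mfu_actor = {mfu:.3f}")  # ~0.481, i.e. roughly 48% of aggregate peak
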
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
......@@ -422,22 +490,17 @@ class FSDPWorker(Worker):
if self._use_optimizer_offload:
offload_fsdp_optimizer(optimizer=self.optimizer)
log_gpu_memory_usage("After entering rollout sharding manager")
prompts = self.rollout_sharding_manager.preprocess_data(prompts)
output = self.rollout.generate_sequences(prompts=prompts)
log_gpu_memory_usage("After rollout generation")
output = self.rollout_sharding_manager.postprocess_data(output)
output = output.to("cpu")
torch.cuda.empty_cache() # clear kv cache
log_gpu_memory_usage("After recompute log prob")
return output
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def compute_log_prob(self, data: DataProto):
def compute_log_probs(self, data: DataProto):
assert self._is_actor
data = data.to(torch.cuda.current_device())
if self._use_param_offload:
load_fsdp_model(self.fsdp_module)
......@@ -452,8 +515,6 @@ class FSDPWorker(Worker):
)
output = self.ulysses_sharding_manager.postprocess_data(output)
output = output.to("cpu")
# https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes
# unshard the root FSDP module
if self.world_size > 1:
......@@ -462,13 +523,13 @@ class FSDPWorker(Worker):
if self._use_param_offload:
offload_fsdp_model(self.fsdp_module)
torch.cuda.empty_cache()
log_gpu_memory_usage("After compute_log_prob")
output = output.to("cpu")
return output
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def compute_ref_log_prob(self, data: DataProto):
def compute_ref_log_probs(self, data: DataProto):
assert self._is_ref
data = data.to(torch.cuda.current_device())
if self._use_param_offload:
load_fsdp_model(self.fsdp_module)
......@@ -476,11 +537,9 @@ class FSDPWorker(Worker):
with self.ulysses_sharding_manager:
data = self.ulysses_sharding_manager.preprocess_data(data)
output = self.ref_policy.compute_log_prob(data=data)
output = DataProto.from_dict(tensors={"ref_log_prob": output})
output = DataProto.from_dict(tensors={"ref_log_probs": output})
output = self.ulysses_sharding_manager.postprocess_data(output)
output = output.to("cpu")
# https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes
# unshard the root FSDP module
if self.world_size > 1:
......@@ -489,15 +548,13 @@ class FSDPWorker(Worker):
if self._use_param_offload:
offload_fsdp_model(self.fsdp_module)
torch.cuda.empty_cache()
log_gpu_memory_usage("After compute_ref_log_prob")
output = output.to("cpu")
return output
"""CriticWorker"""
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def compute_values(self, data: DataProto):
assert self._is_critic
data = data.to(torch.cuda.current_device())
if self._use_param_offload:
load_fsdp_model(self.fsdp_module)
......@@ -507,15 +564,15 @@ class FSDPWorker(Worker):
output = DataProto.from_dict(tensors={"values": values})
output = self.ulysses_sharding_manager.postprocess_data(data=output)
output = output.to("cpu")
if self._use_param_offload:
offload_fsdp_model(self.fsdp_module)
torch.cuda.empty_cache()
output = output.to("cpu")
return output
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def update_critic(self, data: DataProto):
data = data.to(torch.cuda.current_device())
if self._use_param_offload:
load_fsdp_model(self.fsdp_module)
......@@ -530,21 +587,26 @@ class FSDPWorker(Worker):
delta_time = timer.last
global_num_tokens = data.meta_info["global_token_num"]
estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time)
metrics["mfu/critic"] = estimated_flops * self.config.actor.ppo_epochs / promised_flops / self.world_size
metrics["perf/mfu_critic"] = (
estimated_flops * self.config.actor.ppo_epochs / (promised_flops * self.world_size)
)
self.lr_scheduler.step()
lr = self.lr_scheduler.get_last_lr()[0]
metrics["critic/lr"] = lr
output = DataProto(batch=None, meta_info={"metrics": metrics})
output = self.ulysses_sharding_manager.postprocess_data(data=output)
# Metrics should be in non_tensor_batch instead of meta_info, as DataProto does not concat meta_info.
output = DataProto(
non_tensor_batch={
metric: np.array([value] if np.isscalar(value) else value) for metric, value in metrics.items()
}
)
output = output.to("cpu")
if self._use_param_offload:
offload_fsdp_model(self.fsdp_module)
if self._use_optimizer_offload:
offload_fsdp_optimizer(optimizer=self.optimizer)
torch.cuda.empty_cache()
output = output.to("cpu")
return output
......@@ -13,55 +13,48 @@
# limitations under the License.
from collections import defaultdict
from typing import Any, Callable, Dict, Tuple, TypedDict
import torch
from transformers import PreTrainedTokenizer
from verl import DataProto
from verl.utils.reward_score import math_compute_score, r1v_compute_score
from ...protocol import DataProto
from ...utils.reward_score import math_compute_score, r1v_compute_score
class RewardScore(TypedDict):
overall: float
format: float
accuracy: float
class CustomRewardManager:
def __init__(self, tokenizer: PreTrainedTokenizer, num_examine: int, compute_score: str):
def __init__(self, tokenizer: PreTrainedTokenizer, compute_score: str):
self.tokenizer = tokenizer
self.num_examine = num_examine
if compute_score == "math":
self.compute_score = math_compute_score
self.compute_score: Callable[[str, str], RewardScore] = math_compute_score
elif compute_score == "r1v":
self.compute_score = r1v_compute_score
self.compute_score: Callable[[str, str], RewardScore] = r1v_compute_score
else:
raise NotImplementedError()
def __call__(self, data: DataProto) -> torch.Tensor:
def __call__(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, Any]]:
reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
already_print = 0
reward_metrics = defaultdict(list)
for i in range(len(data)):
data_item = data[i] # DataProtoItem
prompt_ids = data_item.batch["prompts"]
prompt_length = prompt_ids.shape[-1]
valid_prompt_length = data_item.batch["attention_mask"][:prompt_length].sum()
valid_prompt_ids = prompt_ids[-valid_prompt_length:]
response_ids = data_item.batch["responses"]
valid_response_length = data_item.batch["attention_mask"][prompt_length:].sum()
response_mask = data_item.batch["response_mask"]
valid_response_length = response_mask.sum()
valid_response_ids = response_ids[:valid_response_length]
# decode
prompt_str = self.tokenizer.decode(valid_prompt_ids, skip_special_tokens=True)
response_str = self.tokenizer.decode(valid_response_ids, skip_special_tokens=True)
ground_truth = data_item.non_tensor_batch["answer"]
ground_truth = data_item.non_tensor_batch["ground_truth"]
score = self.compute_score(response_str, ground_truth)
reward_tensor[i, valid_response_length - 1] = score
if already_print < self.num_examine:
already_print += 1
print("[prompt]", prompt_str)
print("[response]", response_str)
print("[ground_truth]", ground_truth)
print("[score]", score)
reward_tensor[i, valid_response_length - 1] = score["overall"]
for key, value in score.items():
reward_metrics[key].append(value)
return reward_tensor
return reward_tensor, reward_metrics
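
A small usage sketch of the reward manager logic above with a toy scorer; the scorer itself is illustrative, not the project's math/r1v implementation:

import torch
from collections import defaultdict

def toy_compute_score(response: str, ground_truth: str) -> dict:
    fmt = 1.0 if "\\boxed{" in response else 0.0      # illustrative format check
    acc = 1.0 if ground_truth in response else 0.0    # illustrative accuracy check
    return {"overall": 0.1 * fmt + 0.9 * acc, "format": fmt, "accuracy": acc}

# The manager writes the overall score on the last valid response token only,
# leaving every other position at zero (a terminal reward).
responses = torch.zeros(1, 8, dtype=torch.long)       # (batch, response_length)
reward_tensor = torch.zeros_like(responses, dtype=torch.float32)
reward_metrics = defaultdict(list)

valid_response_length = 5                             # response_mask.sum() for this sample
score = toy_compute_score("\\boxed{42}", "42")
reward_tensor[0, valid_response_length - 1] = score["overall"]
for key, value in score.items():
    reward_metrics[key].append(value)
print(reward_tensor[0], dict(reward_metrics))
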
......@@ -14,7 +14,7 @@
from abc import ABC, abstractmethod
from verl import DataProto
from ...protocol import DataProto
__all__ = ["BaseRollout"]
......
......@@ -16,15 +16,18 @@ Rollout config
"""
from dataclasses import asdict, dataclass, field
from typing import Any, Dict
@dataclass
class RolloutConfig:
name: str = "vllm"
n: int = 1
temperature: float = 1.0
top_k: int = -1
top_p: float = 1.0
dtype: str = "bfloat16"
top_k: int = -1
limit_images: int = 0
dtype: str = "bf16"
gpu_memory_utilization: float = 0.5
ignore_eos: bool = False
enforce_eager: bool = False
......@@ -34,9 +37,7 @@ class RolloutConfig:
max_num_batched_tokens: int = 8192
max_num_seqs: int = 1024
disable_log_stats: bool = True
do_sample: bool = True
n: int = 1
limit_images: int = 0
val_override_config: Dict[str, Any] = field(default_factory=dict)
"""auto keys"""
prompt_length: int = field(default=-1, init=False)
response_length: int = field(default=-1, init=False)
......
......@@ -12,8 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .dtensor_weight_loaders import load_dtensor_weights
from .vllm_rollout_spmd import vLLMRollout
__all__ = ["vLLMRollout", "load_dtensor_weights"]
__all__ = ["vLLMRollout"]
......@@ -18,26 +18,29 @@ When working with FSDP:
- Utilize state_dict from the FSDP to synchronize the weights among tp ranks in vLLM
"""
import os
from contextlib import contextmanager
from typing import Any, List, Union
import numpy as np
import torch
import torch.distributed
from tensordict import TensorDict
from transformers import PreTrainedTokenizer
from vllm import LLM, RequestOutput, SamplingParams
from verl import DataProto
from verl.utils.torch_functional import get_eos_mask, pad_2d_list_to_length
from verl.workers.rollout.base import BaseRollout
from verl.workers.rollout.config import RolloutConfig
from ....protocol import DataProto
from ....utils import torch_functional as VF
from ....utils.torch_dtypes import PrecisionType
from ..base import BaseRollout
from ..config import RolloutConfig
def _repeat_interleave(features: Union[torch.Tensor, List[Any]], repeats: int) -> Union[torch.Tensor, List[Any]]:
if isinstance(features, torch.Tensor):
return features.repeat_interleave(repeats, dim=0)
def _repeat_interleave(value: Union[torch.Tensor, np.ndarray], repeats: int) -> Union[torch.Tensor, List[Any]]:
if isinstance(value, torch.Tensor):
return value.repeat_interleave(repeats, dim=0)
else:
return [feature for feature in features for _ in range(repeats)]
return np.repeat(value, repeats, axis=0)
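
Given the definition above, a quick usage check of _repeat_interleave: tensors expand along dim 0 and numpy object arrays (as used for multi_modal_inputs) behave the same way:

import numpy as np
import torch

ids = torch.tensor([[1, 2], [3, 4]])
print(_repeat_interleave(ids, 2))
# tensor([[1, 2], [1, 2], [3, 4], [3, 4]])

mm = np.array([{"id": 0}, {"id": 1}], dtype=object)
print(_repeat_interleave(mm, 2))
# [{'id': 0} {'id': 0} {'id': 1} {'id': 1}]
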
class vLLMRollout(BaseRollout):
......@@ -50,6 +53,7 @@ class vLLMRollout(BaseRollout):
tokenizer: the task/model tokenizer
"""
super().__init__()
self.rank = int(os.getenv("RANK", "0"))
self.config = config
self.pad_token_id = tokenizer.pad_token_id
if config.tensor_parallel_size > torch.distributed.get_world_size():
......@@ -69,7 +73,7 @@ class vLLMRollout(BaseRollout):
model=model_path,
skip_tokenizer_init=False,
tensor_parallel_size=config.tensor_parallel_size,
dtype=config.dtype,
dtype=PrecisionType.to_str(PrecisionType.to_dtype(config.dtype)),
gpu_memory_utilization=config.gpu_memory_utilization,
enforce_eager=config.enforce_eager,
max_model_len=config.prompt_length + config.response_length,
......@@ -77,6 +81,7 @@ class vLLMRollout(BaseRollout):
enable_sleep_mode=True,
distributed_executor_backend="external_launcher",
disable_custom_all_reduce=True,
disable_mm_preprocessor_cache=True,
disable_log_stats=config.disable_log_stats,
enable_chunked_prefill=config.enable_chunked_prefill,
**vllm_init_kwargs,
......@@ -111,7 +116,7 @@ class vLLMRollout(BaseRollout):
setattr(self.sampling_params, key, value)
@torch.no_grad()
def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
def generate_sequences(self, prompts: DataProto) -> DataProto:
# left-padded attention_mask
input_ids: torch.Tensor = prompts.batch["input_ids"] # (bs, prompt_length)
attention_mask: torch.Tensor = prompts.batch["attention_mask"]
......@@ -119,54 +124,40 @@ class vLLMRollout(BaseRollout):
eos_token_id: int = prompts.meta_info["eos_token_id"]
batch_size = input_ids.size(0)
do_sample = prompts.meta_info.get("do_sample", True)
if not do_sample:
kwargs = {
"n": 1,
"temperature": 0.0,
"top_p": 1.0,
"top_k": -1,
"min_p": 0.0,
}
non_tensor_batch = prompts.non_tensor_batch
if batch_size != len(non_tensor_batch["raw_prompt_ids"]):
raise RuntimeError("vllm sharding manager is not working properly.")
if "images" in non_tensor_batch:
if "multi_modal_data" in non_tensor_batch:
vllm_inputs = []
for raw_prompt_ids, images in zip(non_tensor_batch.pop("raw_prompt_ids"), non_tensor_batch.pop("images")):
vllm_inputs.append({"prompt_token_ids": raw_prompt_ids, "multi_modal_data": {"image": images}})
for raw_prompt_ids, multi_modal_data in zip(
non_tensor_batch.pop("raw_prompt_ids"), non_tensor_batch.pop("multi_modal_data")
):
vllm_inputs.append({"prompt_token_ids": list(raw_prompt_ids), "multi_modal_data": multi_modal_data})
else:
vllm_inputs = [
{"prompt_token_ids": raw_prompt_ids} for raw_prompt_ids in non_tensor_batch.pop("raw_prompt_ids")
{"prompt_token_ids": list(raw_prompt_ids)} for raw_prompt_ids in non_tensor_batch.pop("raw_prompt_ids")
]
# users can customize different sampling_params for each run
with self.update_sampling_params(**kwargs):
with self.update_sampling_params(**prompts.meta_info):
completions: List[RequestOutput] = self.inference_engine.generate(
prompts=vllm_inputs, sampling_params=self.sampling_params
prompts=vllm_inputs, sampling_params=self.sampling_params, use_tqdm=(self.rank == 0)
)
response_ids = []
for completion in completions:
for output in completion.outputs:
response_ids.append(output.token_ids)
response_ids = pad_2d_list_to_length(
response_ids, self.pad_token_id, max_length=self.config.response_length
).to(input_ids.device)
if self.config.n > 1 and do_sample:
batch_size = batch_size * self.config.n
input_ids = _repeat_interleave(input_ids, self.config.n)
attention_mask = _repeat_interleave(attention_mask, self.config.n)
position_ids = _repeat_interleave(position_ids, self.config.n)
if "pixel_values" in non_tensor_batch.keys():
non_tensor_batch["pixel_values"] = _repeat_interleave(non_tensor_batch["pixel_values"], self.config.n)
non_tensor_batch["image_grid_thw"] = _repeat_interleave(
non_tensor_batch["image_grid_thw"], self.config.n
)
response_ids = [output.token_ids for completion in completions for output in completion.outputs]
response_ids = VF.pad_2d_list_to_length(
response_ids, self.pad_token_id, max_length=self.config.response_length
).to(input_ids.device)
if self.sampling_params.n > 1:
batch_size = batch_size * self.sampling_params.n
input_ids = _repeat_interleave(input_ids, self.sampling_params.n)
attention_mask = _repeat_interleave(attention_mask, self.sampling_params.n)
position_ids = _repeat_interleave(position_ids, self.sampling_params.n)
if "multi_modal_inputs" in non_tensor_batch.keys():
non_tensor_batch["multi_modal_inputs"] = _repeat_interleave(
non_tensor_batch["multi_modal_inputs"], self.sampling_params.n
)
sequence_ids = torch.cat([input_ids, response_ids], dim=-1)
response_length = response_ids.size(1)
......@@ -180,10 +171,10 @@ class vLLMRollout(BaseRollout):
# position_ids: [0,0,0,0,0,1,2,3 | 4,5,6,7,8,9,10,11]
response_position_ids = position_ids[..., -1:] + delta_position_id
position_ids = torch.cat([position_ids, response_position_ids], dim=-1)
response_attention_mask = get_eos_mask(
response_ids=response_ids, eos_token=eos_token_id, dtype=attention_mask.dtype
response_mask = VF.get_eos_mask(
response_ids=response_ids, eos_token_id=eos_token_id, dtype=attention_mask.dtype
)
attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1)
attention_mask = torch.cat((attention_mask, response_mask), dim=-1)
# all tp ranks contain the same data here; data on all ranks is valid
batch = TensorDict(
......@@ -192,6 +183,7 @@ class vLLMRollout(BaseRollout):
"responses": response_ids,
"input_ids": sequence_ids, # here input_ids become the whole sentences
"attention_mask": attention_mask,
"response_mask": response_mask,
"position_ids": position_ids,
},
batch_size=batch_size,
......
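
VF.get_eos_mask is not shown in this diff; a minimal sketch of the behavior assumed for response_mask above (keep response tokens up to and including the first EOS, zero everything after it):

import torch

def get_eos_mask_sketch(response_ids: torch.Tensor, eos_token_id: int, dtype=torch.long) -> torch.Tensor:
    is_eos = (response_ids == eos_token_id).to(torch.long)
    after_eos = torch.cumsum(is_eos, dim=-1) - is_eos   # 0 up to and including the first EOS
    return (after_eos == 0).to(dtype)

response_ids = torch.tensor([[5, 6, 2, 0, 0]])          # eos_token_id = 2, then padding
print(get_eos_mask_sketch(response_ids, eos_token_id=2))
# tensor([[1, 1, 1, 0, 0]])
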
......@@ -15,7 +15,7 @@
Sharding manager to implement HybridEngine
"""
from verl import DataProto
from ...protocol import DataProto
class BaseShardingManager:
......
......@@ -17,9 +17,8 @@ Contains a resharding manager that binds weights from FSDP zero3 to XPerfGPT
from torch.distributed.device_mesh import DeviceMesh
from verl import DataProto
from verl.utils.ulysses import get_ulysses_sequence_parallel_group, set_ulysses_sequence_parallel_group
from ...protocol import DataProto, all_gather_data_proto
from ...utils.ulysses import get_ulysses_sequence_parallel_group, set_ulysses_sequence_parallel_group
from .base import BaseShardingManager
......@@ -48,9 +47,9 @@ class FSDPUlyssesShardingManager(BaseShardingManager):
In Ulysses, we need to make sure the same data is used across a SP group
"""
if self.device_mesh is not None:
sp_size = self.device_mesh["sp"].size()
sp_group = self.device_mesh["sp"].get_group()
data = data.to("cuda")
data.all_gather(sp_group)
all_gather_data_proto(data, size=sp_size, group=sp_group)
return data
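
all_gather_data_proto is expected to all-gather every tensor in the batch across the SP group so each rank sees identical data; a single-process stand-in showing the underlying collective (runs with world size 1 on gloo, and identically under torchrun with more ranks):

import os
import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

local_shard = torch.arange(4).reshape(2, 2)   # this rank's slice of the batch
gathered = [torch.empty_like(local_shard) for _ in range(dist.get_world_size())]
dist.all_gather(gathered, local_shard)
full_batch = torch.cat(gathered, dim=0)       # what the SP group would share per key
print(full_batch.shape)
dist.destroy_process_group()
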
......
......@@ -12,19 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import Dict, Iterable, Tuple, Union
import torch
import torch.distributed as dist
from torch.distributed._tensor import DTensor
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp.api import ShardedStateDictConfig, StateDictType
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
from vllm import LLM
from vllm.distributed import parallel_state as vllm_ps
from verl import DataProto
from verl.utils.performance import log_gpu_memory_usage
from verl.workers.rollout.vllm_rollout import load_dtensor_weights
from ...protocol import DataProto, all_gather_data_proto
from ...utils.model_utils import print_gpu_memory_usage
from .base import BaseShardingManager
......@@ -38,11 +39,22 @@ class FSDPVLLMShardingManager(BaseShardingManager):
self.module = module
self.inference_engine = inference_engine
self.device_mesh = device_mesh
FSDP.set_state_dict_type(
self.module,
state_dict_type=StateDictType.SHARDED_STATE_DICT,
state_dict_config=ShardedStateDictConfig(),
)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
FSDP.set_state_dict_type(
self.module,
state_dict_type=StateDictType.SHARDED_STATE_DICT,
state_dict_config=ShardedStateDictConfig(),
)
self.world_size = dist.get_world_size()
self.tp_size = vllm_ps.get_tensor_model_parallel_world_size()
self.tp_rank = vllm_ps.get_tensor_model_parallel_rank()
self.tp_group = vllm_ps.get_tensor_model_parallel_group().device_group
# Record freed bytes to estimate memory usage correctly
# https://github.com/vllm-project/vllm/pull/11743#issuecomment-2754338119
self.freed_bytes = 0
# Note that torch_random_states may be different on each dp rank
self.torch_random_states = torch.cuda.get_rng_state()
......@@ -55,29 +67,45 @@ class FSDPVLLMShardingManager(BaseShardingManager):
else:
self.gen_random_states = None
def _make_weight_iterator(
self, actor_weights: Dict[str, Union[torch.Tensor, DTensor]]
) -> Iterable[Tuple[str, torch.Tensor]]:
for name, tensor in actor_weights.items():
yield name, tensor.full_tensor() if self.world_size != 1 else tensor
def __enter__(self):
log_gpu_memory_usage("Before state_dict() in sharding manager")
# NOTE: Basically, we only need `torch.cuda.empty_cache()` before vllm wake_up and
# after vllm sleep, since vllm has its own caching memory allocator CuMemAllocator.
# Outside the vllm scope, we should avoid emptying the cache so that pytorch can use its
# caching allocator to speed up memory allocations.
#
# pytorch: https://pytorch.org/docs/stable/notes/cuda.html#memory-management
# vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/device_allocator/cumem.py#L103
torch.cuda.empty_cache()
print_gpu_memory_usage("Before state_dict() in sharding manager")
actor_weights = self.module.state_dict()
log_gpu_memory_usage("After state_dict() in sharding manager")
print_gpu_memory_usage("After state_dict() in sharding manager")
self.inference_engine.wake_up()
load_dtensor_weights(
actor_weights, self.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model
)
log_gpu_memory_usage("After sync model weights in sharding manager")
model = self.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model
model.load_weights(self._make_weight_iterator(actor_weights))
print_gpu_memory_usage("After sync model weights in sharding manager")
del actor_weights
torch.cuda.empty_cache()
log_gpu_memory_usage("After del state_dict and empty_cache in sharding manager")
print_gpu_memory_usage("After del state_dict and empty_cache in sharding manager")
# important: need to manually set the random states of each tp to be identical.
if self.device_mesh is not None:
self.torch_random_states = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(self.gen_random_states)
def __exit__(self, exc_type, exc_value, traceback):
log_gpu_memory_usage("Before vllm offload in sharding manager")
print_gpu_memory_usage("Before vllm offload in sharding manager")
free_bytes_before_sleep = torch.cuda.mem_get_info()[0]
self.inference_engine.sleep(level=1)
log_gpu_memory_usage("After vllm offload in sharding manager")
free_bytes_after_sleep = torch.cuda.mem_get_info()[0]
self.freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
print_gpu_memory_usage("After vllm offload in sharding manager")
self.module.train()
torch.cuda.empty_cache() # add empty cache after each compute
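
The freed_bytes bookkeeping above compensates for device memory that vLLM's sleep mode releases outside PyTorch's caching allocator; a small sketch of how it is measured and later subtracted in the perf/max_memory_* metrics (assumes a CUDA device):

import torch

free_before, _ = torch.cuda.mem_get_info()
# ... self.inference_engine.sleep(level=1) would run here and release vLLM's memory ...
free_after, _ = torch.cuda.mem_get_info()
freed_bytes = free_after - free_before        # bytes handed back to the device by vLLM

# later, the worker reports peaks net of that amount so they reflect training memory only
max_memory_allocated_gb = (torch.cuda.max_memory_allocated() - freed_bytes) / 1024**3
print(f"{max_memory_allocated_gb:.2f} GiB")
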
......@@ -88,15 +116,13 @@ class FSDPVLLMShardingManager(BaseShardingManager):
torch.cuda.set_rng_state(self.torch_random_states)
def preprocess_data(self, data: DataProto) -> DataProto:
tp_group = vllm_ps.get_tensor_model_parallel_group().device_group
data = data.to("cuda")
data.all_gather(tp_group)
"""All gather across tp group to make each rank has identical input."""
all_gather_data_proto(data, size=self.tp_size, group=self.tp_group)
return data
def postprocess_data(self, data: DataProto) -> DataProto:
dp_rank = dist.get_rank()
tp_size = vllm_ps.get_tensor_model_parallel_world_size()
if tp_size > 1:
data = data.chunk(chunks=tp_size)[dp_rank % tp_size]
"""Get chunk data of this tp rank since we do all gather in preprocess."""
if self.tp_size > 1:
data = data.chunk(chunks=self.tp_size)[self.tp_rank]
return data
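
The pre/postprocess pair above is a round trip: every tp rank first receives the full batch, generation runs identically on all of them, and each rank then keeps only its own contiguous slice. An index-only sketch of that chunking:

# Index-only sketch of the tp round trip; no distributed setup needed.
tp_size, tp_rank = 2, 1
full_batch = list(range(8))     # after preprocess, identical on every tp rank

chunk_size = len(full_batch) // tp_size
chunks = [full_batch[i * chunk_size : (i + 1) * chunk_size] for i in range(tp_size)]
local_out = chunks[tp_rank]
print(local_out)                # [4, 5, 6, 7] on tp_rank 1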