Commit c132cbcb authored by chenych

0402 update

parent f92481f0
@@ -20,8 +20,8 @@ from typing import Any, Dict
 import torch

-from verl import DataProto
-from verl.workers.critic.config import CriticConfig
+from ...protocol import DataProto
+from .config import CriticConfig

 __all__ = ["BasePPOCritic"]
...
@@ -17,17 +17,18 @@ Critic config
 from dataclasses import dataclass, field

-from verl.workers.actor.config import FSDPConfig, ModelConfig, OffloadConfig, OptimConfig
+from ..actor.config import FSDPConfig, ModelConfig, OffloadConfig, OptimConfig


 @dataclass
 class CriticConfig:
     strategy: str = "fsdp"
     global_batch_size: int = 256
-    micro_batch_size_per_device_for_update: int = field(default=-1, init=False)
-    micro_batch_size_per_device_for_experience: int = field(default=-1, init=False)
+    micro_batch_size_per_device_for_update: int = 4
+    micro_batch_size_per_device_for_experience: int = 16
     max_grad_norm: float = 1.0
     cliprange_value: float = 0.5
+    ppo_epochs: int = 1
     padding_free: bool = False
     ulysses_sequence_parallel_size: int = 1
     model: ModelConfig = field(default_factory=ModelConfig)
...
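Note: a small worked example (illustrative numbers, not taken from this commit) of how the batch-size fields above combine downstream: each device receives global_batch_size / num_devices sequences per PPO step and accumulates gradients over micro batches of micro_batch_size_per_device_for_update, as the critic update loop later in this diff does.

global_batch_size = 256                       # CriticConfig.global_batch_size
num_devices = 8                               # assumed world size; rollout.n = 1, no sequence parallel
micro_batch_size_per_device_for_update = 4    # CriticConfig default above

global_batch_size_per_device = global_batch_size // num_devices  # 32
gradient_accumulation = global_batch_size_per_device // micro_batch_size_per_device_for_update  # 8 backward passes per optimizer step
print(global_batch_size_per_device, gradient_accumulation)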
@@ -20,17 +20,23 @@ from collections import defaultdict
 from typing import Any, Dict

 import torch
-import torch.distributed
+from ray.experimental.tqdm_ray import tqdm
 from torch import nn
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-from tqdm import tqdm

-from verl import DataProto
-from verl.trainer import core_algos
-from verl.utils.py_functional import append_to_dict
-from verl.utils.torch_functional import masked_mean
-from verl.workers.critic.base import BasePPOCritic
-from verl.workers.critic.config import CriticConfig
+from ...protocol import DataProto
+from ...trainer import core_algos
+from ...utils import torch_functional as VF
+from ...utils.py_functional import append_to_dict
+from ...utils.ulysses import gather_outputs_and_unpad, ulysses_pad_and_slice_inputs
+from .base import BasePPOCritic
+from .config import CriticConfig
+
+try:
+    from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input
+except ImportError:
+    pass


 __all__ = ["DataParallelPPOCritic"]

@@ -45,6 +51,7 @@ class DataParallelPPOCritic(BasePPOCritic):
     def _forward_micro_batch(self, micro_batch: Dict[str, torch.Tensor]) -> torch.Tensor:
         input_ids = micro_batch["input_ids"]
+        batch_size, seqlen = input_ids.shape
         attention_mask = micro_batch["attention_mask"]
         position_ids = micro_batch["position_ids"]
         responses = micro_batch["responses"]
@@ -52,20 +59,61 @@ class DataParallelPPOCritic(BasePPOCritic):
         if position_ids.dim() == 3:  # qwen2vl mrope
             position_ids = position_ids.transpose(0, 1)  # (bsz, 3, seqlen) -> (3, bsz, seqlen)

-        vision_inputs = {}
-        if "pixel_values" in micro_batch:
-            vision_inputs["pixel_values"] = torch.cat(micro_batch["pixel_values"], dim=0)
-            vision_inputs["image_grid_thw"] = torch.cat(micro_batch["image_grid_thw"], dim=0)
+        multi_modal_inputs = {}
+        if "multi_modal_inputs" in micro_batch:
+            for key in micro_batch["multi_modal_inputs"][0].keys():
+                multi_modal_inputs[key] = torch.cat(
+                    [inputs[key] for inputs in micro_batch["multi_modal_inputs"]], dim=0
+                )

         if self.config.padding_free:
-            # TODO (yaowei): preprocess data for padding_free and ulysses
-            raise NotImplementedError
+            input_ids_rmpad, indices, *_ = unpad_input(
+                input_ids.unsqueeze(-1), attention_mask
+            )  # input_ids_rmpad (total_nnz, ...)
+            input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)
+
+            # unpad the position_ids to align the rotary
+            if position_ids.dim() == 3:
+                position_ids_rmpad = (
+                    index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), indices)
+                    .transpose(0, 1)
+                    .unsqueeze(1)
+                )  # (3, bsz, seqlen) -> (3, 1, bsz * seqlen)
+            else:
+                position_ids_rmpad = index_first_axis(
+                    rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices
+                ).transpose(0, 1)
+
+            # pad and slice the inputs if sp > 1
+            if self.config.ulysses_sequence_parallel_size > 1:
+                input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(
+                    input_ids_rmpad, position_ids_rmpad, sp_size=self.config.ulysses_sequence_parallel_size
+                )
+
+            # only pass input_ids and position_ids to enable flash_attn_varlen
+            output = self.critic_module(
+                input_ids=input_ids_rmpad,
+                attention_mask=None,
+                position_ids=position_ids_rmpad,
+                **multi_modal_inputs,
+                use_cache=False,
+            )  # prevent model thinks we are generating
+            values_rmpad = output.logits
+            values_rmpad = values_rmpad.squeeze(0)  # (total_nnz)
+
+            # gather output if sp > 1
+            if self.config.ulysses_sequence_parallel_size > 1:
+                values_rmpad = gather_outputs_and_unpad(values_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size)
+
+            # pad it back
+            values = pad_input(values_rmpad, indices=indices, batch=batch_size, seqlen=seqlen).squeeze(-1)
+            values = values[:, -response_length - 1 : -1]
         else:
             output = self.critic_module(
                 input_ids=input_ids,
                 attention_mask=attention_mask,
                 position_ids=position_ids,
-                **vision_inputs,
+                **multi_modal_inputs,
                 use_cache=False,
             )
             values: torch.Tensor = output.logits
@@ -81,7 +129,12 @@ class DataParallelPPOCritic(BasePPOCritic):
             self.critic_module.parameters(), max_norm=self.config.max_grad_norm
         )

-        self.critic_optimizer.step()
+        if not torch.isfinite(grad_norm):
+            print("Gradient norm is not finite. Skip update.")
+        else:
+            self.critic_optimizer.step()
+
+        self.critic_optimizer.zero_grad()
         return grad_norm

     @torch.no_grad()
@@ -89,18 +142,21 @@
         self.critic_module.eval()

         select_keys = ["responses", "input_ids", "attention_mask", "position_ids"]
-        if "pixel_values" in data.non_tensor_batch.keys():
-            non_tensor_select_keys = ["pixel_values", "image_grid_thw"]
+        if "multi_modal_inputs" in data.non_tensor_batch.keys():
+            non_tensor_select_keys = ["multi_modal_inputs"]
         else:
-            non_tensor_select_keys = None
+            non_tensor_select_keys = []

         micro_batches = data.select(select_keys, non_tensor_select_keys).split(
             self.config.micro_batch_size_per_device_for_experience
         )
         values_lst = []
-        for micro_batch in tqdm(micro_batches, "Compute values", disable=(self.rank != 0)):
-            micro_batch.to("cuda")
-            values = self._forward_micro_batch(micro_batch)
+        if self.rank == 0:
+            micro_batches = tqdm(micro_batches, desc="Compute values", position=2)
+
+        for micro_batch in micro_batches:
+            model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
+            values = self._forward_micro_batch(model_inputs)
             values_lst.append(values)

         values = torch.concat(values_lst, dim=0)
@@ -114,55 +170,56 @@
         self.critic_module.train()

         select_keys = ["input_ids", "responses", "attention_mask", "position_ids", "values", "returns"]
-        if "pixel_values" in data.non_tensor_batch.keys():
-            non_tensor_select_keys = ["pixel_values", "image_grid_thw"]
+        if "multi_modal_inputs" in data.non_tensor_batch.keys():
+            non_tensor_select_keys = ["multi_modal_inputs"]
         else:
-            non_tensor_select_keys = None
+            non_tensor_select_keys = []

-        # TODO (yaowei): support ppo epochs
         # Split to make minibatch iterator for updating the actor
         # See PPO paper for details. https://arxiv.org/abs/1707.06347
         mini_batches = data.select(select_keys, non_tensor_select_keys).split(self.config.global_batch_size_per_device)

         metrics = defaultdict(list)
-        n = len(mini_batches)
-        for i, mini_batch in enumerate(mini_batches):
-            gradient_accumulation = (
-                self.config.global_batch_size_per_device // self.config.micro_batch_size_per_device_for_update
-            )
-            micro_batches = mini_batch.split(self.config.micro_batch_size_per_device_for_update)
-            self.critic_optimizer.zero_grad()
-            for micro_batch in tqdm(micro_batches, desc=f"Update critic [{i + 1}/{n}]", disable=(self.rank != 0)):
-                micro_batch.to("cuda")
-                model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
-                responses = model_inputs["responses"]
-                attention_mask = model_inputs["attention_mask"]
-                values = model_inputs["values"]
-                returns = model_inputs["returns"]
-                response_length = responses.size(1)
-                eos_mask = attention_mask[:, -response_length - 1 : -1]
-                vpreds = self._forward_micro_batch(data)
-                vf_loss, vf_clipfrac = core_algos.compute_value_loss(
-                    vpreds=vpreds,
-                    values=values,
-                    returns=returns,
-                    eos_mask=eos_mask,
-                    cliprange_value=self.config.cliprange_value,
-                )
-                loss = vf_loss / gradient_accumulation
-                loss.backward()
-                batch_metrics = {
-                    "critic/vf_loss": vf_loss.detach().item(),
-                    "critic/vf_clipfrac": vf_clipfrac.detach().item(),
-                    "critic/vpred_mean": masked_mean(vpreds, eos_mask).detach().item(),
-                }
-                append_to_dict(metrics, batch_metrics)
-
-            grad_norm = self._optimizer_step()
-            append_to_dict(metrics, {"critic/grad_norm": grad_norm.detach().item()})
+        for _ in range(self.config.ppo_epochs):
+            if self.rank == 0:
+                mini_batches = tqdm(mini_batches, desc="Train mini-batches", position=2)
+
+            for mini_batch in mini_batches:
+                gradient_accumulation = (
+                    self.config.global_batch_size_per_device // self.config.micro_batch_size_per_device_for_update
+                )
+                micro_batches = mini_batch.split(self.config.micro_batch_size_per_device_for_update)
+                if self.rank == 0:
+                    micro_batches = tqdm(micro_batches, desc="Update critic", position=3)
+
+                for micro_batch in micro_batches:
+                    model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
+                    responses = model_inputs["responses"]
+                    attention_mask = model_inputs["attention_mask"]
+                    values = model_inputs["values"]
+                    returns = model_inputs["returns"]
+                    response_length = responses.size(1)
+                    eos_mask = attention_mask[:, -response_length - 1 : -1]  # shift left for value computation
+                    vpreds = self._forward_micro_batch(model_inputs)
+                    vf_loss, vf_clipfrac = core_algos.compute_value_loss(
+                        vpreds=vpreds,
+                        returns=returns,
+                        values=values,
+                        eos_mask=eos_mask,
+                        cliprange_value=self.config.cliprange_value,
+                    )
+                    loss = vf_loss / gradient_accumulation
+                    loss.backward()
+                    batch_metrics = {
+                        "critic/vf_loss": vf_loss.detach().item(),
+                        "critic/vf_clipfrac": vf_clipfrac.detach().item(),
+                        "critic/vpred_mean": VF.masked_mean(vpreds, eos_mask).detach().item(),
+                    }
+                    append_to_dict(metrics, batch_metrics)
+
+                grad_norm = self._optimizer_step()
+                append_to_dict(metrics, {"critic/grad_norm": grad_norm.detach().item()})
+
+        self.critic_optimizer.zero_grad()
         return metrics
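For reference, a minimal sketch (not part of this commit) of the unpad/pad round trip that the padding-free branch above relies on. It assumes flash_attn is installed and uses small illustrative tensors; the real code feeds the flattened tokens through the critic model instead of the stand-in below.

import torch
from flash_attn.bert_padding import pad_input, unpad_input

batch_size, seqlen = 2, 6
input_ids = torch.randint(0, 100, (batch_size, seqlen))
attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1, 1]])

# drop padded positions: (bsz, seqlen, 1) -> (total_nnz, 1)
input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask)

# stand-in for the per-token values the critic would produce on the flattened tokens
values_rmpad = input_ids_rmpad.float()

# scatter back to (bsz, seqlen, 1) so the usual [:, -response_length - 1 : -1] slicing applies
values = pad_input(values_rmpad, indices=indices, batch=batch_size, seqlen=seqlen).squeeze(-1)
assert values.shape == (batch_size, seqlen)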
@@ -15,8 +15,10 @@
 The main entry point to run the PPO algorithm
 """

-from typing import Literal
+from typing import Literal, Optional, Union

+import numpy as np
+import psutil
 import torch
 import torch.distributed as dist
 from accelerate import init_empty_weights

@@ -34,13 +36,13 @@ from transformers import (
 )
 from transformers.modeling_utils import no_init_weights

-from verl import DataProto
-from verl.single_controller.base import Worker
-from verl.single_controller.base.decorator import Dispatch, register
-from verl.utils import get_tokenizer, get_processor
-from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
-from verl.utils.flops_counter import FlopsCounter
-from verl.utils.fsdp_utils import (
+from ..models.monkey_patch import apply_ulysses_patch
+from ..protocol import DataProto
+from ..single_controller.base import Worker
+from ..single_controller.base.decorator import Dispatch, register
+from ..utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
+from ..utils.flops_counter import FlopsCounter
+from ..utils.fsdp_utils import (
     get_fsdp_wrap_policy,
     get_init_fn,
     load_fsdp_model,
@@ -48,16 +50,16 @@ from ..utils.fsdp_utils import (
     offload_fsdp_model,
     offload_fsdp_optimizer,
 )
-from verl.utils.model_utils import print_model_size
-from verl.utils.performance import log_gpu_memory_usage
-from verl.utils.torch_dtypes import PrecisionType
-from verl.utils.torch_functional import get_constant_schedule_with_warmup
-from verl.workers.actor import DataParallelPPOActor
-from verl.workers.config import FSDPConfig, ModelConfig, OptimConfig, WorkerConfig
-from verl.workers.critic import DataParallelPPOCritic
-from verl.workers.rollout.vllm_rollout import vLLMRollout
-from verl.workers.sharding_manager import FSDPVLLMShardingManager
-from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager
+from ..utils.model_utils import print_gpu_memory_usage, print_model_size
+from ..utils.tokenizer import get_processor, get_tokenizer
+from ..utils.torch_dtypes import PrecisionType
+from ..utils.torch_functional import AnyPrecisionAdamW, get_constant_schedule_with_warmup
+from .actor import DataParallelPPOActor
+from .config import ActorConfig, CriticConfig, FSDPConfig, ModelConfig, OptimConfig, RefConfig, WorkerConfig
+from .critic import DataParallelPPOCritic
+from .rollout.vllm_rollout import vLLMRollout
+from .sharding_manager import FSDPVLLMShardingManager
+from .sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager


 class FSDPWorker(Worker):
@@ -68,77 +70,95 @@ class FSDPWorker(Worker):
     ):
         super().__init__()
         self.config = config
+        self.role = role

         if not dist.is_initialized():
             dist.init_process_group(backend="nccl")

-        # build device mesh for FSDP
-        # TODO: support FSDP hybrid shard for larger model
-        world_size = dist.get_world_size()
-        self.device_mesh = init_device_mesh("cuda", mesh_shape=(world_size,), mesh_dim_names=["fsdp"])
-
-        # build device mesh for Ulysses Sequence Parallel
-        self.ulysses_sequence_parallel_size = self.config.actor.ulysses_sequence_parallel_size
-        if self.ulysses_sequence_parallel_size > 1:
-            self.ulysses_device_mesh = init_device_mesh(
-                "cuda",
-                mesh_shape=(world_size // self.ulysses_sequence_parallel_size, self.ulysses_sequence_parallel_size),
-                mesh_dim_names=["dp", "sp"],
-            )
-        else:
-            self.ulysses_device_mesh = None
-
-        self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)
-
-        self.role = role
-        self._is_actor = self.role in ["actor", "actor_rollout", "actor_rollout_ref"]
-        self._is_critic = self.role == "critic"
-        self._is_rollout = self.role in ["rollout", "actor_rollout", "actor_rollout_ref"]
-        self._is_ref = self.role in ["ref", "actor_rollout_ref"]
-
-        self._use_param_offload = False
-        self._use_optimizer_offload = False
-        if self._is_actor:
-            self._use_param_offload = self.config.actor.offload.param_offload
-            self._use_optimizer_offload = self.config.actor.offload.optimizer_offload
-        elif self._is_critic:
-            self._use_param_offload = self.config.critic.offload.param_offload
-            self._use_optimizer_offload = self.config.critic.offload.optimizer_offload
-        elif self._is_ref:
-            # NOTE: it seems that manual offload is slowly than FSDP offload
-            self._use_param_offload = self.config.ref.offload.param_offload
-
-        # normalize config
-        if self._is_actor:
-            self.config.actor.global_batch_size *= self.config.rollout.n
-            self.config.actor.global_batch_size_per_device = (
-                self.config.actor.global_batch_size // self.device_mesh.shape[0] * self.ulysses_sequence_parallel_size
-            )
-            assert (
-                self.config.actor.global_batch_size_per_device
-                % self.config.actor.micro_batch_size_per_device_for_update
-                == 0
-            )
-        elif self._is_critic:
-            self.config.critic.global_batch_size *= self.config.rollout.n
-            self.config.critic.global_batch_size_per_device = (
-                self.config.critic.global_batch_size // self.device_mesh.shape[0] * self.ulysses_sequence_parallel_size
-            )
-            assert (
-                self.config.critic.global_batch_size_per_device
-                % self.config.critic.micro_batch_size_per_device_for_update
-                == 0
-            )
+        self._is_actor = self.role in ["actor", "actor_rollout", "actor_rollout_ref"]
+        self._is_critic = self.role == "critic"
+        self._is_rollout = self.role in ["rollout", "actor_rollout", "actor_rollout_ref"]
+        self._is_ref = self.role in ["ref", "actor_rollout_ref"]
+
+        self._use_param_offload = False
+        self._use_optimizer_offload = False
+        if self._is_actor:
+            self._use_param_offload = self.config.actor.offload.offload_params
+            self._use_optimizer_offload = self.config.actor.offload.offload_optimizer
+            self._init_config(self.config.actor, "actor")
+        elif self._is_critic:
+            self._use_param_offload = self.config.critic.offload.offload_params
+            self._use_optimizer_offload = self.config.critic.offload.offload_optimizer
+            self._init_config(self.config.critic, "critic")
+        elif self._is_ref:  # NOTE: it seems that manual offload is slower than FSDP offload
+            self._use_param_offload = self.config.ref.offload.offload_params
+            self._init_config(self.config.ref, "ref")
+
+    def _init_config(
+        self, config: Union[ActorConfig, CriticConfig, RefConfig], role: Literal["actor", "critic", "ref"]
+    ):
+        world_size = dist.get_world_size()
+        fsdp_size = config.fsdp.fsdp_size
+        if fsdp_size <= 0 or fsdp_size >= world_size:
+            self.device_mesh = init_device_mesh("cuda", mesh_shape=(world_size,), mesh_dim_names=("fsdp",))
+        else:  # hsdp
+            self.device_mesh = init_device_mesh(
+                "cuda", mesh_shape=(world_size // fsdp_size, fsdp_size), mesh_dim_names=("ddp", "fsdp")
+            )
+
+        if config.ulysses_sequence_parallel_size > 1:
+            self.ulysses_device_mesh = init_device_mesh(
+                "cuda",
+                mesh_shape=(
+                    world_size // config.ulysses_sequence_parallel_size,
+                    config.ulysses_sequence_parallel_size,
+                ),
+                mesh_dim_names=("dp", "sp"),
+            )
+        else:
+            self.ulysses_device_mesh = None
+
+        self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)
+
+        if not hasattr(config, "global_batch_size"):  # ref model
+            return
+
+        if self.config.rollout.n > 1:
+            config.global_batch_size *= self.config.rollout.n
+            self.print_rank0(f"{role} will use global batch size {config.global_batch_size}.")
+
+        config.global_batch_size_per_device = (
+            config.global_batch_size // self.device_mesh.size() * config.ulysses_sequence_parallel_size
+        )
+        if config.global_batch_size_per_device == 0:
+            raise ValueError(f"{role} global batch size must be larger than num gpus.")
+
+        if config.global_batch_size_per_device % config.micro_batch_size_per_device_for_update != 0:
+            raise ValueError(f"{role} global batch size per device must be divisible by the micro batch size.")
+
+        if (
+            config.fsdp.enable_cpu_offload
+            and config.global_batch_size_per_device != config.micro_batch_size_per_device_for_update
+        ):
+            raise ValueError(f"{role} cannot use FSDP's CPU offload when gradient accumulation is enabled.")
     def _build_model_optimizer(
         self,
         model_config: ModelConfig,
         fsdp_config: FSDPConfig,
-        optim_config: OptimConfig,
+        optim_config: Optional[OptimConfig],
         padding_free: bool = False,
     ) -> None:
-        self.tokenizer = get_tokenizer(model_config.tokenizer_path, trust_remote_code=model_config.trust_remote_code)
-        self.processor = get_processor(model_config.tokenizer_path)
+        self.tokenizer = get_tokenizer(
+            model_config.tokenizer_path,
+            trust_remote_code=model_config.trust_remote_code,
+            use_fast=True,
+        )
+        self.processor = get_processor(
+            model_config.tokenizer_path,
+            trust_remote_code=model_config.trust_remote_code,
+            use_fast=True,
+        )
         self.model_config = AutoConfig.from_pretrained(
             model_config.model_path,
             trust_remote_code=model_config.trust_remote_code,
@@ -156,7 +176,8 @@ class FSDPWorker(Worker):
         self.print_rank0(f"Model config: {self.model_config}")

         if padding_free:
-            raise NotImplementedError("Padding free is not implemented yet.")
+            apply_ulysses_patch(self.model_config.model_type)
+            self.print_rank0("Ulysses patch applied!")

         if fsdp_config.torch_dtype is None:
             torch_dtype = torch.float32 if self._is_actor or self._is_critic else torch.bfloat16
@@ -170,13 +191,13 @@
         else:
             auto_class = AutoModelForCausalLM

-        if self.rank == 0:
+        if (not fsdp_config.enable_rank0_init) or self.device_mesh.get_local_rank("fsdp") == 0:
             model = auto_class.from_pretrained(
                 model_config.model_path,
                 config=self.model_config,
                 torch_dtype=torch_dtype,
                 attn_implementation="flash_attention_2",
-                device_map="cpu",
+                device_map="cpu" if fsdp_config.enable_rank0_init else "cuda",
                 low_cpu_mem_usage=True,
                 trust_remote_code=model_config.trust_remote_code,
             )
@@ -195,29 +216,50 @@
         if model_config.enable_gradient_checkpointing:
             model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

-        dist.barrier()
-        if self.rank == 0:
-            print_model_size(model)
-
-        log_gpu_memory_usage("After init from huggingface model")
+        if not (self._is_actor or self._is_critic):
+            model.requires_grad_(False)
+
+        if model_config.freeze_vision_tower:
+            if hasattr(model, "visual"):
+                model.visual.requires_grad_(False)
+                fsdp_config.use_orig_params = True
+                self.print_rank0("Vision tower is set to not trainable.")
+            else:
+                self.print_rank0("No vision tower found.")
+
+        dist.barrier()
+        print_model_size(model)
+        print_gpu_memory_usage("After huggingface model init")

         mixed_precision = MixedPrecision(
             param_dtype=PrecisionType.to_dtype(fsdp_config.mp_param_dtype),
             reduce_dtype=PrecisionType.to_dtype(fsdp_config.mp_reduce_dtype),
             buffer_dtype=PrecisionType.to_dtype(fsdp_config.mp_buffer_dtype),
         )
         auto_wrap_policy = get_fsdp_wrap_policy(model)
-        if fsdp_config.enable_full_shard:
-            sharding_strategy = ShardingStrategy.FULL_SHARD
-        else:
-            sharding_strategy = ShardingStrategy.SHARD_GRAD_OP
+        self.print_rank0(f"FSDP wrap policy: {auto_wrap_policy}.")
+
+        if self.device_mesh.ndim == 2:
+            if fsdp_config.enable_full_shard:
+                sharding_strategy = ShardingStrategy.HYBRID_SHARD
+            else:
+                sharding_strategy = ShardingStrategy._HYBRID_SHARD_ZERO2
+        else:
+            if fsdp_config.enable_full_shard:
+                sharding_strategy = ShardingStrategy.FULL_SHARD
+            else:
+                sharding_strategy = ShardingStrategy.SHARD_GRAD_OP

-        if fsdp_config.param_offload or fsdp_config.optimizer_offload:
-            cpu_offload = CPUOffload(offload_params=fsdp_config.param_offload)
+        if fsdp_config.enable_cpu_offload:
+            cpu_offload = CPUOffload(offload_params=True)
         else:
             cpu_offload = None

-        if self.rank == 0:
-            print(f"FSDP wrap policy: {auto_wrap_policy}.")
+        if fsdp_config.enable_rank0_init:
+            sync_module_states = True
+            param_init_fn = get_init_fn(model, device="cuda") if self.rank != 0 else None
+        else:
+            sync_module_states = False
+            param_init_fn = None

         self.fsdp_module = FSDP(
             model,
@@ -225,53 +267,60 @@
             cpu_offload=cpu_offload,
             auto_wrap_policy=auto_wrap_policy,
             mixed_precision=mixed_precision,
-            param_init_fn=get_init_fn(model, device="cuda") if self.rank != 0 else None,
+            param_init_fn=param_init_fn,
             device_id=torch.cuda.current_device(),
-            sync_module_states=True,
+            sync_module_states=sync_module_states,
             forward_prefetch=False,
-            use_orig_params=False,
+            use_orig_params=fsdp_config.use_orig_params,
             device_mesh=self.device_mesh,
         )
-        log_gpu_memory_usage("After Actor FSDP init")
+        print_gpu_memory_usage("After FSDP module init")

         if self._is_actor or self._is_critic:
-            self.optimizer = torch.optim.AdamW(
-                self.fsdp_module.parameters(),
-                lr=optim_config.lr,
-                betas=optim_config.betas,
-                weight_decay=optim_config.weight_decay,
-            )
-            num_warmup_steps = int(optim_config.lr_warmup_steps_ratio * optim_config.training_steps)
+            if optim_config.strategy == "adamw":
+                self.optimizer = torch.optim.AdamW(
+                    self.fsdp_module.parameters(),
+                    lr=optim_config.lr,
+                    betas=optim_config.betas,
+                    weight_decay=optim_config.weight_decay,
+                    fused=True,
+                )
+            elif optim_config.strategy == "adamw_bf16":
+                self.optimizer = AnyPrecisionAdamW(
+                    self.fsdp_module.parameters(),
+                    lr=optim_config.lr,
+                    betas=optim_config.betas,
+                    weight_decay=optim_config.weight_decay,
+                )
+            else:
+                raise NotImplementedError(f"Optimizer {optim_config.strategy} not supported.")
+
+            num_warmup_steps = int(optim_config.lr_warmup_ratio * optim_config.training_steps)
             self.lr_scheduler = get_constant_schedule_with_warmup(
                 optimizer=self.optimizer, num_warmup_steps=num_warmup_steps
             )
+            print_gpu_memory_usage("After optimizer init")
         else:
             self.optimizer, self.lr_scheduler = None, None
-
-        log_gpu_memory_usage("After actor optimizer init")
     def _build_rollout(self) -> None:
-        # TODO(sgm): support FSDP hybrid shard for larger model
         tp_size = self.config.rollout.tensor_parallel_size
         dp_size = self.world_size // tp_size
         assert self.world_size % tp_size == 0, (
-            f"rollout world_size: {self.world_size} is not divisible by tp_size: {tp_size}"
+            f"rollout world size: {self.world_size} is not divisible by tp size: {tp_size}"
         )
-        rollout_device_mesh = init_device_mesh("cuda", mesh_shape=(dp_size, tp_size), mesh_dim_names=["dp", "tp"])
-        log_gpu_memory_usage("Before building vllm rollout")
+        rollout_device_mesh = init_device_mesh("cuda", mesh_shape=(dp_size, tp_size), mesh_dim_names=("dp", "tp"))
         self.rollout = vLLMRollout(
             model_path=self.config.actor.model.model_path,
             config=self.config.rollout,
             tokenizer=self.tokenizer,
         )
-        log_gpu_memory_usage("After building vllm rollout")
         self.rollout_sharding_manager = FSDPVLLMShardingManager(
             module=self.fsdp_module,
             inference_engine=self.rollout.inference_engine,
             device_mesh=rollout_device_mesh,
         )
-        log_gpu_memory_usage("After building sharding manager")
+        print_gpu_memory_usage("After vllm init")
     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
     def init_model(self):
@@ -280,11 +329,21 @@
             fsdp_config = self.config.critic.fsdp
             optim_config = self.config.critic.optim
             padding_free = self.config.critic.padding_free
-        else:
+            role = "critic"
+        elif self._is_actor:
             model_config = self.config.actor.model
             fsdp_config = self.config.actor.fsdp
             optim_config = self.config.actor.optim
             padding_free = self.config.actor.padding_free
+            role = "actor"
+        elif self._is_ref:
+            model_config = self.config.actor.model
+            fsdp_config = self.config.ref.fsdp
+            optim_config = None
+            padding_free = self.config.ref.padding_free
+            role = "ref"
+        else:
+            raise ValueError(f"Unknown role {role}.")

         if self._is_actor or self._is_critic or self._is_ref:
             self._build_model_optimizer(
@@ -293,11 +352,13 @@
                 optim_config=optim_config,
                 padding_free=padding_free,
             )
-            # get the original unwrapped module
-            self.unwrapped_model = self.fsdp_module._fsdp_wrapped_module
+            if self._use_param_offload:
+                offload_fsdp_model(self.fsdp_module)
+                print_gpu_memory_usage(f"After offload {role} model during init")

-            if self._use_optimizer_offload and not self._is_critic:
+            if self._use_optimizer_offload:
                 offload_fsdp_optimizer(optimizer=self.optimizer)
-                log_gpu_memory_usage("After offload actor optimizer during init")
+                print_gpu_memory_usage(f"After offload {role} optimizer during init")

         if self._is_actor:
             self.actor = DataParallelPPOActor(
@@ -317,7 +378,10 @@
             self._build_rollout()

         if self._is_ref:
-            self.ref_policy = DataParallelPPOActor(config=self.config.ref, actor_module=self.fsdp_module)
+            self.ref_policy = DataParallelPPOActor(
+                config=self.config.ref,
+                actor_module=self.fsdp_module,
+            )

         if self._is_actor or self._is_critic:
             self.flops_counter = FlopsCounter(self.model_config)
@@ -325,42 +389,37 @@
                 model=self.fsdp_module,
                 optimizer=self.optimizer,
                 lr_scheduler=self.lr_scheduler,
-                tokenizer=self.tokenizer,
-                processor=self.processor
+                processing_class=self.processor if self.processor is not None else self.tokenizer,
             )

-        torch.cuda.empty_cache()
-
     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
-    def save_checkpoint(self, path: str, global_step: int = 0, remove_previous_ckpt: bool = False):
+    def save_checkpoint(self, path: str):
         assert self._is_actor or self._is_critic
         if self._use_param_offload:
             load_fsdp_model(self.fsdp_module)

-        self.checkpoint_manager.save_checkpoint(
-            local_path=path,
-            global_step=global_step,
-            remove_previous_ckpt=remove_previous_ckpt,
-        )
+        self.checkpoint_manager.save_checkpoint(path)
         dist.barrier()
         if self._use_param_offload:
             offload_fsdp_model(self.fsdp_module)

     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
-    def load_checkpoint(self, path: str, del_local_after_load: bool = True):
+    def load_checkpoint(self, path: str):
         if self._use_param_offload:
             load_fsdp_model(self.fsdp_module)

-        self.checkpoint_manager.load_checkpoint(path=path, del_local_after_load=del_local_after_load)
+        self.checkpoint_manager.load_checkpoint(path)
         dist.barrier()
         if self._use_param_offload:
             offload_fsdp_model(self.fsdp_module)

-    """ActorRolloutRefWorker"""
+        if self._use_optimizer_offload:
+            offload_fsdp_optimizer(self.optimizer)

     @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
     def update_actor(self, data: DataProto):
         assert self._is_actor
-        data = data.to(torch.cuda.current_device())
         if self._use_param_offload:
             load_fsdp_model(self.fsdp_module)
@@ -368,7 +427,6 @@
         if self._use_optimizer_offload:
             load_fsdp_optimizer(optimizer=self.optimizer)

-        log_gpu_memory_usage("Before update policy")
         with self.ulysses_sharding_manager:
             data = self.ulysses_sharding_manager.preprocess_data(data=data)
             with Timer(name="update_policy", logger=None) as timer:
@@ -377,17 +435,27 @@
             delta_time = timer.last
             global_num_tokens = data.meta_info["global_token_num"]
             estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time)
-            metrics["mfu/actor"] = estimated_flops * self.config.actor.ppo_epochs / promised_flops / self.world_size
+            metrics["perf/mfu_actor"] = (
+                estimated_flops * self.config.actor.ppo_epochs / (promised_flops * self.world_size)
+            )
+            metrics["perf/max_memory_allocated_gb"] = (
+                torch.cuda.max_memory_allocated() - self.rollout_sharding_manager.freed_bytes
+            ) / (1024**3)
+            metrics["perf/max_memory_reserved_gb"] = (
+                torch.cuda.max_memory_reserved() - self.rollout_sharding_manager.freed_bytes
+            ) / (1024**3)
+            metrics["perf/cpu_memory_used_gb"] = psutil.virtual_memory().used / (1024**3)

             self.lr_scheduler.step()
             lr = self.lr_scheduler.get_last_lr()[0]
             metrics["actor/lr"] = lr
-            log_gpu_memory_usage("After update policy")

-            # TODO: here, we should return all metrics
-            output = DataProto(meta_info={"metrics": metrics})
-            output = self.ulysses_sharding_manager.postprocess_data(data=output)
-            output = output.to("cpu")
+            # Metrics should be in non_tensor_batch instead of meta_info, as DataProto not concat meta_info.
+            output = DataProto(
+                non_tensor_batch={
+                    key: np.array([value] if np.isscalar(value) else value) for key, value in metrics.items()
+                }
+            )

         if self._use_param_offload:
             offload_fsdp_model(self.fsdp_module)
@@ -395,7 +463,7 @@
         if self._use_optimizer_offload:
             offload_fsdp_optimizer(optimizer=self.optimizer)

-        torch.cuda.empty_cache()
+        output = output.to("cpu")
         return output

     @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
@@ -422,22 +490,17 @@
             if self._use_optimizer_offload:
                 offload_fsdp_optimizer(optimizer=self.optimizer)

-            log_gpu_memory_usage("After entering rollout sharding manager")
             prompts = self.rollout_sharding_manager.preprocess_data(prompts)
             output = self.rollout.generate_sequences(prompts=prompts)
-            log_gpu_memory_usage("After rollout generation")
             output = self.rollout_sharding_manager.postprocess_data(output)

         output = output.to("cpu")
-        torch.cuda.empty_cache()  # clear kv cache
-        log_gpu_memory_usage("After recompute log prob")
         return output

     @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
-    def compute_log_prob(self, data: DataProto):
+    def compute_log_probs(self, data: DataProto):
         assert self._is_actor
-        data = data.to(torch.cuda.current_device())
         if self._use_param_offload:
             load_fsdp_model(self.fsdp_module)
@@ -452,8 +515,6 @@
             )
             output = self.ulysses_sharding_manager.postprocess_data(output)

-        output = output.to("cpu")
-
         # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes
         # unshard the root FSDP module
         if self.world_size > 1:
@@ -462,13 +523,13 @@
         if self._use_param_offload:
             offload_fsdp_model(self.fsdp_module)

-        torch.cuda.empty_cache()
-        log_gpu_memory_usage("After compute_log_prob")
+        output = output.to("cpu")
         return output

     @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
-    def compute_ref_log_prob(self, data: DataProto):
+    def compute_ref_log_probs(self, data: DataProto):
         assert self._is_ref
-        data = data.to(torch.cuda.current_device())
         if self._use_param_offload:
             load_fsdp_model(self.fsdp_module)

@@ -476,11 +537,9 @@
         with self.ulysses_sharding_manager:
             data = self.ulysses_sharding_manager.preprocess_data(data)
             output = self.ref_policy.compute_log_prob(data=data)
-            output = DataProto.from_dict(tensors={"ref_log_prob": output})
+            output = DataProto.from_dict(tensors={"ref_log_probs": output})
             output = self.ulysses_sharding_manager.postprocess_data(output)

-        output = output.to("cpu")
-
         # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes
         # unshard the root FSDP module
         if self.world_size > 1:
@@ -489,15 +548,13 @@
         if self._use_param_offload:
             offload_fsdp_model(self.fsdp_module)

-        torch.cuda.empty_cache()
-        log_gpu_memory_usage("After compute_ref_log_prob")
+        output = output.to("cpu")
         return output
"""CriticWorker"""
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def compute_values(self, data: DataProto): def compute_values(self, data: DataProto):
assert self._is_critic assert self._is_critic
data = data.to(torch.cuda.current_device())
if self._use_param_offload: if self._use_param_offload:
load_fsdp_model(self.fsdp_module) load_fsdp_model(self.fsdp_module)
...@@ -507,15 +564,15 @@ class FSDPWorker(Worker): ...@@ -507,15 +564,15 @@ class FSDPWorker(Worker):
output = DataProto.from_dict(tensors={"values": values}) output = DataProto.from_dict(tensors={"values": values})
output = self.ulysses_sharding_manager.postprocess_data(data=output) output = self.ulysses_sharding_manager.postprocess_data(data=output)
output = output.to("cpu")
if self._use_param_offload: if self._use_param_offload:
offload_fsdp_model(self.fsdp_module) offload_fsdp_model(self.fsdp_module)
torch.cuda.empty_cache() output = output.to("cpu")
return output return output
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def update_critic(self, data: DataProto): def update_critic(self, data: DataProto):
data = data.to(torch.cuda.current_device())
if self._use_param_offload: if self._use_param_offload:
load_fsdp_model(self.fsdp_module) load_fsdp_model(self.fsdp_module)
...@@ -530,21 +587,26 @@ class FSDPWorker(Worker): ...@@ -530,21 +587,26 @@ class FSDPWorker(Worker):
delta_time = timer.last delta_time = timer.last
global_num_tokens = data.meta_info["global_token_num"] global_num_tokens = data.meta_info["global_token_num"]
estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time) estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time)
metrics["mfu/critic"] = estimated_flops * self.config.actor.ppo_epochs / promised_flops / self.world_size metrics["perf/mfu_critic"] = (
estimated_flops * self.config.actor.ppo_epochs / (promised_flops * self.world_size)
)
self.lr_scheduler.step() self.lr_scheduler.step()
lr = self.lr_scheduler.get_last_lr()[0] lr = self.lr_scheduler.get_last_lr()[0]
metrics["critic/lr"] = lr metrics["critic/lr"] = lr
output = DataProto(batch=None, meta_info={"metrics": metrics}) # Metrics should be in non_tensor_batch instead of meta_info, as DataProto not concat meta_info.
output = self.ulysses_sharding_manager.postprocess_data(data=output) output = DataProto(
non_tensor_batch={
metric: np.array([value] if np.isscalar(value) else value) for metric, value in metrics.items()
}
)
output = output.to("cpu")
if self._use_param_offload: if self._use_param_offload:
offload_fsdp_model(self.fsdp_module) offload_fsdp_model(self.fsdp_module)
if self._use_optimizer_offload: if self._use_optimizer_offload:
offload_fsdp_optimizer(optimizer=self.optimizer) offload_fsdp_optimizer(optimizer=self.optimizer)
torch.cuda.empty_cache() output = output.to("cpu")
return output return output
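One detail worth calling out from update_actor/update_critic above: training metrics now travel in non_tensor_batch, which DataProto can concatenate across workers, rather than in meta_info, which it cannot. A tiny illustration of the wrapping used there, with made-up metric values:

import numpy as np

metrics = {"critic/lr": 1e-6, "perf/mfu_critic": 0.31}  # hypothetical values
non_tensor_batch = {
    key: np.array([value] if np.isscalar(value) else value) for key, value in metrics.items()
}
# every scalar becomes a length-1 numpy array, so concatenating DataProto outputs
# from several workers simply stacks these arrays
print(non_tensor_batch)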
@@ -13,55 +13,48 @@
 # limitations under the License.

+from collections import defaultdict
+from typing import Any, Callable, Dict, Tuple, TypedDict
+
 import torch
 from transformers import PreTrainedTokenizer

-from verl import DataProto
-from verl.utils.reward_score import math_compute_score, r1v_compute_score
+from ...protocol import DataProto
+from ...utils.reward_score import math_compute_score, r1v_compute_score


+class RewardScore(TypedDict):
+    overall: float
+    format: float
+    accuracy: float
+
+
 class CustomRewardManager:
-    def __init__(self, tokenizer: PreTrainedTokenizer, num_examine: int, compute_score: str):
+    def __init__(self, tokenizer: PreTrainedTokenizer, compute_score: str):
         self.tokenizer = tokenizer
-        self.num_examine = num_examine
         if compute_score == "math":
-            self.compute_score = math_compute_score
+            self.compute_score: Callable[[str, str], RewardScore] = math_compute_score
         elif compute_score == "r1v":
-            self.compute_score = r1v_compute_score
+            self.compute_score: Callable[[str, str], RewardScore] = r1v_compute_score
         else:
             raise NotImplementedError()

-    def __call__(self, data: DataProto) -> torch.Tensor:
+    def __call__(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, Any]]:
         reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
-        already_print = 0
+        reward_metrics = defaultdict(list)
         for i in range(len(data)):
             data_item = data[i]  # DataProtoItem
-            prompt_ids = data_item.batch["prompts"]
-            prompt_length = prompt_ids.shape[-1]
-            valid_prompt_length = data_item.batch["attention_mask"][:prompt_length].sum()
-            valid_prompt_ids = prompt_ids[-valid_prompt_length:]
             response_ids = data_item.batch["responses"]
-            valid_response_length = data_item.batch["attention_mask"][prompt_length:].sum()
+            response_mask = data_item.batch["response_mask"]
+            valid_response_length = response_mask.sum()
             valid_response_ids = response_ids[:valid_response_length]

-            # decode
-            prompt_str = self.tokenizer.decode(valid_prompt_ids, skip_special_tokens=True)
             response_str = self.tokenizer.decode(valid_response_ids, skip_special_tokens=True)
-            ground_truth = data_item.non_tensor_batch["ground_truth"]
+            ground_truth = data_item.non_tensor_batch["answer"]

             score = self.compute_score(response_str, ground_truth)
-            reward_tensor[i, valid_response_length - 1] = score
-
-            if already_print < self.num_examine:
-                already_print += 1
-                print("[prompt]", prompt_str)
-                print("[response]", response_str)
-                print("[ground_truth]", ground_truth)
-                print("[score]", score)
+            reward_tensor[i, valid_response_length - 1] = score["overall"]
+            for key, value in score.items():
+                reward_metrics[key].append(value)

-        return reward_tensor
+        return reward_tensor, reward_metrics
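For illustration only (the values are made up), this is the shape of data the reworked reward manager produces: the per-key score dict follows the RewardScore TypedDict above, the overall score lands on the last valid response token, and every key is also accumulated into reward_metrics.

from collections import defaultdict

import torch

response_mask = torch.tensor([[1, 1, 1, 0, 0]])           # one response with 3 valid tokens
reward_tensor = torch.zeros((1, 5), dtype=torch.float32)
reward_metrics = defaultdict(list)
score = {"overall": 1.0, "format": 1.0, "accuracy": 1.0}  # hypothetical RewardScore

valid_response_length = response_mask[0].sum()
reward_tensor[0, valid_response_length - 1] = score["overall"]
for key, value in score.items():
    reward_metrics[key].append(value)

print(reward_tensor)      # reward placed at index 2, zeros elsewhere
print(dict(reward_metrics))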
@@ -14,7 +14,7 @@
 from abc import ABC, abstractmethod

-from verl import DataProto
+from ...protocol import DataProto

 __all__ = ["BaseRollout"]
...
@@ -16,15 +16,18 @@ Rollout config
 """

 from dataclasses import asdict, dataclass, field
+from typing import Any, Dict


 @dataclass
 class RolloutConfig:
     name: str = "vllm"
+    n: int = 1
     temperature: float = 1.0
-    top_k: int = -1
     top_p: float = 1.0
-    dtype: str = "bfloat16"
+    top_k: int = -1
+    limit_images: int = 0
+    dtype: str = "bf16"
     gpu_memory_utilization: float = 0.5
     ignore_eos: bool = False
     enforce_eager: bool = False
@@ -34,9 +37,7 @@ class RolloutConfig:
     max_num_batched_tokens: int = 8192
     max_num_seqs: int = 1024
     disable_log_stats: bool = True
-    do_sample: bool = True
-    n: int = 1
-    limit_images: int = 0
+    val_override_config: Dict[str, Any] = field(default_factory=dict)
     """auto keys"""
     prompt_length: int = field(default=-1, init=False)
     response_length: int = field(default=-1, init=False)
...
@@ -12,8 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from .dtensor_weight_loaders import load_dtensor_weights
 from .vllm_rollout_spmd import vLLMRollout

-__all__ = ["vLLMRollout", "load_dtensor_weights"]
+__all__ = ["vLLMRollout"]
@@ -18,26 +18,29 @@ When working with FSDP:
 - Utilize state_dict from the FSDP to synchronize the weights among tp ranks in vLLM
 """

+import os
 from contextlib import contextmanager
 from typing import Any, List, Union

+import numpy as np
 import torch
 import torch.distributed
 from tensordict import TensorDict
 from transformers import PreTrainedTokenizer
 from vllm import LLM, RequestOutput, SamplingParams

-from verl import DataProto
-from verl.utils.torch_functional import get_eos_mask, pad_2d_list_to_length
-from verl.workers.rollout.base import BaseRollout
-from verl.workers.rollout.config import RolloutConfig
+from ....protocol import DataProto
+from ....utils import torch_functional as VF
+from ....utils.torch_dtypes import PrecisionType
+from ..base import BaseRollout
+from ..config import RolloutConfig


-def _repeat_interleave(features: Union[torch.Tensor, List[Any]], repeats: int) -> Union[torch.Tensor, List[Any]]:
-    if isinstance(features, torch.Tensor):
-        return features.repeat_interleave(repeats, dim=0)
+def _repeat_interleave(value: Union[torch.Tensor, np.ndarray], repeats: int) -> Union[torch.Tensor, List[Any]]:
+    if isinstance(value, torch.Tensor):
+        return value.repeat_interleave(repeats, dim=0)
     else:
-        return [feature for feature in features for _ in range(repeats)]
+        return np.repeat(value, repeats, axis=0)


 class vLLMRollout(BaseRollout):

@@ -50,6 +53,7 @@ class vLLMRollout(BaseRollout):
             tokenizer: the task/model tokenizer
         """
         super().__init__()
+        self.rank = int(os.getenv("RANK", "0"))
         self.config = config
         self.pad_token_id = tokenizer.pad_token_id
         if config.tensor_parallel_size > torch.distributed.get_world_size():
@@ -69,7 +73,7 @@ class vLLMRollout(BaseRollout):
             model=model_path,
             skip_tokenizer_init=False,
             tensor_parallel_size=config.tensor_parallel_size,
-            dtype=config.dtype,
+            dtype=PrecisionType.to_str(PrecisionType.to_dtype(config.dtype)),
             gpu_memory_utilization=config.gpu_memory_utilization,
             enforce_eager=config.enforce_eager,
             max_model_len=config.prompt_length + config.response_length,
@@ -77,6 +81,7 @@
             enable_sleep_mode=True,
             distributed_executor_backend="external_launcher",
             disable_custom_all_reduce=True,
+            disable_mm_preprocessor_cache=True,
             disable_log_stats=config.disable_log_stats,
             enable_chunked_prefill=config.enable_chunked_prefill,
             **vllm_init_kwargs,
@@ -111,7 +116,7 @@
             setattr(self.sampling_params, key, value)

     @torch.no_grad()
-    def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
+    def generate_sequences(self, prompts: DataProto) -> DataProto:
         # left-padded attention_mask
         input_ids: torch.Tensor = prompts.batch["input_ids"]  # (bs, prompt_length)
         attention_mask: torch.Tensor = prompts.batch["attention_mask"]
...@@ -119,54 +124,40 @@ class vLLMRollout(BaseRollout): ...@@ -119,54 +124,40 @@ class vLLMRollout(BaseRollout):
eos_token_id: int = prompts.meta_info["eos_token_id"] eos_token_id: int = prompts.meta_info["eos_token_id"]
batch_size = input_ids.size(0) batch_size = input_ids.size(0)
do_sample = prompts.meta_info.get("do_sample", True)
if not do_sample:
kwargs = {
"n": 1,
"temperature": 0.0,
"top_p": 1.0,
"top_k": -1,
"min_p": 0.0,
}
non_tensor_batch = prompts.non_tensor_batch non_tensor_batch = prompts.non_tensor_batch
if batch_size != len(non_tensor_batch["raw_prompt_ids"]): if batch_size != len(non_tensor_batch["raw_prompt_ids"]):
raise RuntimeError("vllm sharding manager is not work properly.") raise RuntimeError("vllm sharding manager is not work properly.")
if "images" in non_tensor_batch: if "multi_modal_data" in non_tensor_batch:
vllm_inputs = [] vllm_inputs = []
for raw_prompt_ids, images in zip(non_tensor_batch.pop("raw_prompt_ids"), non_tensor_batch.pop("images")): for raw_prompt_ids, multi_modal_data in zip(
vllm_inputs.append({"prompt_token_ids": raw_prompt_ids, "multi_modal_data": {"image": images}}) non_tensor_batch.pop("raw_prompt_ids"), non_tensor_batch.pop("multi_modal_data")
):
vllm_inputs.append({"prompt_token_ids": list(raw_prompt_ids), "multi_modal_data": multi_modal_data})
else: else:
vllm_inputs = [ vllm_inputs = [
{"prompt_token_ids": raw_prompt_ids} for raw_prompt_ids in non_tensor_batch.pop("raw_prompt_ids") {"prompt_token_ids": list(raw_prompt_ids)} for raw_prompt_ids in non_tensor_batch.pop("raw_prompt_ids")
] ]
# users can customize different sampling_params at different run # users can customize different sampling_params at different run
with self.update_sampling_params(**kwargs): with self.update_sampling_params(**prompts.meta_info):
completions: List[RequestOutput] = self.inference_engine.generate( completions: List[RequestOutput] = self.inference_engine.generate(
prompts=vllm_inputs, sampling_params=self.sampling_params prompts=vllm_inputs, sampling_params=self.sampling_params, use_tqdm=(self.rank == 0)
) )
response_ids = [output.token_ids for completion in completions for output in completion.outputs]
response_ids = [] response_ids = VF.pad_2d_list_to_length(
for completion in completions: response_ids, self.pad_token_id, max_length=self.config.response_length
for output in completion.outputs: ).to(input_ids.device)
response_ids.append(output.token_ids)
if self.sampling_params.n > 1:
response_ids = pad_2d_list_to_length( batch_size = batch_size * self.sampling_params.n
response_ids, self.pad_token_id, max_length=self.config.response_length input_ids = _repeat_interleave(input_ids, self.sampling_params.n)
).to(input_ids.device) attention_mask = _repeat_interleave(attention_mask, self.sampling_params.n)
position_ids = _repeat_interleave(position_ids, self.sampling_params.n)
if self.config.n > 1 and do_sample: if "multi_modal_inputs" in non_tensor_batch.keys():
batch_size = batch_size * self.config.n non_tensor_batch["multi_modal_inputs"] = _repeat_interleave(
input_ids = _repeat_interleave(input_ids, self.config.n) non_tensor_batch["multi_modal_inputs"], self.sampling_params.n
attention_mask = _repeat_interleave(attention_mask, self.config.n) )
position_ids = _repeat_interleave(position_ids, self.config.n)
if "pixel_values" in non_tensor_batch.keys():
non_tensor_batch["pixel_values"] = _repeat_interleave(non_tensor_batch["pixel_values"], self.config.n)
non_tensor_batch["image_grid_thw"] = _repeat_interleave(
non_tensor_batch["image_grid_thw"], self.config.n
)
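_repeat_interleave is referenced above but defined elsewhere in the rollout module. A sketch, under the assumption that it simply repeats each batch element n times for both tensors and the numpy object arrays in the non-tensor batch:

# Sketch (assumption): align prompt-side data with the n sampled responses per prompt.
from typing import Union
import numpy as np
import torch

def _repeat_interleave(value: Union[torch.Tensor, np.ndarray], repeats: int) -> Union[torch.Tensor, np.ndarray]:
    if isinstance(value, torch.Tensor):
        return value.repeat_interleave(repeats, dim=0)
    return np.repeat(value, repeats, axis=0)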
sequence_ids = torch.cat([input_ids, response_ids], dim=-1) sequence_ids = torch.cat([input_ids, response_ids], dim=-1)
response_length = response_ids.size(1) response_length = response_ids.size(1)
...@@ -180,10 +171,10 @@ class vLLMRollout(BaseRollout): ...@@ -180,10 +171,10 @@ class vLLMRollout(BaseRollout):
# position_ids: [0,0,0,0,0,1,2,3 | 4,5,6,7,8,9,10,11] # position_ids: [0,0,0,0,0,1,2,3 | 4,5,6,7,8,9,10,11]
response_position_ids = position_ids[..., -1:] + delta_position_id response_position_ids = position_ids[..., -1:] + delta_position_id
position_ids = torch.cat([position_ids, response_position_ids], dim=-1) position_ids = torch.cat([position_ids, response_position_ids], dim=-1)
response_attention_mask = get_eos_mask( response_mask = VF.get_eos_mask(
response_ids=response_ids, eos_token=eos_token_id, dtype=attention_mask.dtype response_ids=response_ids, eos_token_id=eos_token_id, dtype=attention_mask.dtype
) )
attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1) attention_mask = torch.cat((attention_mask, response_mask), dim=-1)
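VF.get_eos_mask is imported from the shared utils and not shown in this diff. Assuming it keeps every response position up to and including the first EOS token and zeros out the rest, a compact sketch is:

# Sketch (assumption): 1 for positions up to and including the first eos, 0 afterwards.
import torch

def get_eos_mask(response_ids: torch.Tensor, eos_token_id: int, dtype: torch.dtype = torch.long) -> torch.Tensor:
    is_eos = response_ids.eq(eos_token_id).to(torch.long)    # (bs, response_length)
    after_first_eos = torch.cumsum(is_eos, dim=-1) - is_eos  # > 0 only strictly after the first eos
    return (after_first_eos == 0).to(dtype)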
# all the tp ranks should contain the same data here; the data in all ranks is valid # all the tp ranks should contain the same data here; the data in all ranks is valid
batch = TensorDict( batch = TensorDict(
...@@ -192,6 +183,7 @@ class vLLMRollout(BaseRollout): ...@@ -192,6 +183,7 @@ class vLLMRollout(BaseRollout):
"responses": response_ids, "responses": response_ids,
"input_ids": sequence_ids, # here input_ids become the whole sentences "input_ids": sequence_ids, # here input_ids become the whole sentences
"attention_mask": attention_mask, "attention_mask": attention_mask,
"response_mask": response_mask,
"position_ids": position_ids, "position_ids": position_ids,
}, },
batch_size=batch_size, batch_size=batch_size,
......
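VF.pad_2d_list_to_length, used when assembling response_ids above, is also defined outside this diff. A simplified sketch, assuming no response exceeds max_length:

# Sketch (assumption): right-pad variable-length token-id lists into a LongTensor.
from typing import Sequence
import torch

def pad_2d_list_to_length(rows: Sequence[Sequence[int]], pad_token_id: int, max_length: int) -> torch.Tensor:
    padded = [list(row) + [pad_token_id] * (max_length - len(row)) for row in rows]
    return torch.tensor(padded, dtype=torch.long)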
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
Sharding manager to implement HybridEngine Sharding manager to implement HybridEngine
""" """
from verl import DataProto from ...protocol import DataProto
class BaseShardingManager: class BaseShardingManager:
......
...@@ -17,9 +17,8 @@ Contains a resharding manager that binds weights from FSDP zero3 to XPerfGPT ...@@ -17,9 +17,8 @@ Contains a resharding manager that binds weights from FSDP zero3 to XPerfGPT
from torch.distributed.device_mesh import DeviceMesh from torch.distributed.device_mesh import DeviceMesh
from verl import DataProto from ...protocol import DataProto, all_gather_data_proto
from verl.utils.ulysses import get_ulysses_sequence_parallel_group, set_ulysses_sequence_parallel_group from ...utils.ulysses import get_ulysses_sequence_parallel_group, set_ulysses_sequence_parallel_group
from .base import BaseShardingManager from .base import BaseShardingManager
...@@ -48,9 +47,9 @@ class FSDPUlyssesShardingManager(BaseShardingManager): ...@@ -48,9 +47,9 @@ class FSDPUlyssesShardingManager(BaseShardingManager):
In Ulysses, we need to make sure the same data is used across an SP group In Ulysses, we need to make sure the same data is used across an SP group
""" """
if self.device_mesh is not None: if self.device_mesh is not None:
sp_size = self.device_mesh["sp"].size()
sp_group = self.device_mesh["sp"].get_group() sp_group = self.device_mesh["sp"].get_group()
data = data.to("cuda") all_gather_data_proto(data, size=sp_size, group=sp_group)
data.all_gather(sp_group)
return data return data
......
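all_gather_data_proto is introduced by this commit but lives in the protocol module, outside this diff. At the tensor level, the assumed behavior is an all-gather along the batch dimension so that every SP rank ends up with the full batch:

# Sketch (assumption): gather each rank's chunk and concatenate along dim 0.
import torch
import torch.distributed as dist

def all_gather_along_batch(tensor: torch.Tensor, size: int, group: dist.ProcessGroup) -> torch.Tensor:
    buffers = [torch.empty_like(tensor) for _ in range(size)]
    dist.all_gather(buffers, tensor, group=group)
    return torch.cat(buffers, dim=0)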
...@@ -12,19 +12,20 @@ ...@@ -12,19 +12,20 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import warnings
from typing import Dict, Iterable, Tuple, Union
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from torch.distributed._tensor import DTensor
from torch.distributed.device_mesh import DeviceMesh from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp.api import ShardedStateDictConfig, StateDictType from torch.distributed.fsdp.api import ShardedStateDictConfig, StateDictType
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
from vllm import LLM from vllm import LLM
from vllm.distributed import parallel_state as vllm_ps from vllm.distributed import parallel_state as vllm_ps
from verl import DataProto from ...protocol import DataProto, all_gather_data_proto
from verl.utils.performance import log_gpu_memory_usage from ...utils.model_utils import print_gpu_memory_usage
from verl.workers.rollout.vllm_rollout import load_dtensor_weights
from .base import BaseShardingManager from .base import BaseShardingManager
...@@ -38,11 +39,22 @@ class FSDPVLLMShardingManager(BaseShardingManager): ...@@ -38,11 +39,22 @@ class FSDPVLLMShardingManager(BaseShardingManager):
self.module = module self.module = module
self.inference_engine = inference_engine self.inference_engine = inference_engine
self.device_mesh = device_mesh self.device_mesh = device_mesh
FSDP.set_state_dict_type( with warnings.catch_warnings():
self.module, warnings.simplefilter("ignore")
state_dict_type=StateDictType.SHARDED_STATE_DICT, FSDP.set_state_dict_type(
state_dict_config=ShardedStateDictConfig(), self.module,
) state_dict_type=StateDictType.SHARDED_STATE_DICT,
state_dict_config=ShardedStateDictConfig(),
)
self.world_size = dist.get_world_size()
self.tp_size = vllm_ps.get_tensor_model_parallel_world_size()
self.tp_rank = vllm_ps.get_tensor_model_parallel_rank()
self.tp_group = vllm_ps.get_tensor_model_parallel_group().device_group
# Record freed bytes to estimate memory usage correctly
# https://github.com/vllm-project/vllm/pull/11743#issuecomment-2754338119
self.freed_bytes = 0
# Note that torch_random_states may be different on each dp rank # Note that torch_random_states may be different on each dp rank
self.torch_random_states = torch.cuda.get_rng_state() self.torch_random_states = torch.cuda.get_rng_state()
...@@ -55,29 +67,45 @@ class FSDPVLLMShardingManager(BaseShardingManager): ...@@ -55,29 +67,45 @@ class FSDPVLLMShardingManager(BaseShardingManager):
else: else:
self.gen_random_states = None self.gen_random_states = None
def _make_weight_iterator(
self, actor_weights: Dict[str, Union[torch.Tensor, DTensor]]
) -> Iterable[Tuple[str, torch.Tensor]]:
for name, tensor in actor_weights.items():
yield name, tensor.full_tensor() if self.world_size != 1 else tensor
def __enter__(self): def __enter__(self):
log_gpu_memory_usage("Before state_dict() in sharding manager") # NOTE: Basically, we only need `torch.cuda.empty_cache()` before vllm wake_up and
# after vllm sleep, since vllm has its own caching memory allocator CuMemAllocator.
# Outside the vllm scope, we should avoid emptying the cache so that pytorch can keep
# using its caching allocator to speed up memory allocations.
#
# pytorch: https://pytorch.org/docs/stable/notes/cuda.html#memory-management
# vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/device_allocator/cumem.py#L103
torch.cuda.empty_cache()
print_gpu_memory_usage("Before state_dict() in sharding manager")
actor_weights = self.module.state_dict() actor_weights = self.module.state_dict()
log_gpu_memory_usage("After state_dict() in sharding manager") print_gpu_memory_usage("After state_dict() in sharding manager")
self.inference_engine.wake_up() self.inference_engine.wake_up()
load_dtensor_weights( model = self.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model
actor_weights, self.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model model.load_weights(self._make_weight_iterator(actor_weights))
) print_gpu_memory_usage("After sync model weights in sharding manager")
log_gpu_memory_usage("After sync model weights in sharding manager")
del actor_weights del actor_weights
torch.cuda.empty_cache() torch.cuda.empty_cache()
log_gpu_memory_usage("After del state_dict and empty_cache in sharding manager") print_gpu_memory_usage("After del state_dict and empty_cache in sharding manager")
# important: we need to manually set the random states of each tp rank to be identical. # important: we need to manually set the random states of each tp rank to be identical.
if self.device_mesh is not None: if self.device_mesh is not None:
self.torch_random_states = torch.cuda.get_rng_state() self.torch_random_states = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(self.gen_random_states) torch.cuda.set_rng_state(self.gen_random_states)
def __exit__(self, exc_type, exc_value, traceback): def __exit__(self, exc_type, exc_value, traceback):
log_gpu_memory_usage("Before vllm offload in sharding manager") print_gpu_memory_usage("Before vllm offload in sharding manager")
free_bytes_before_sleep = torch.cuda.mem_get_info()[0]
self.inference_engine.sleep(level=1) self.inference_engine.sleep(level=1)
log_gpu_memory_usage("After vllm offload in sharding manager") free_bytes_after_sleep = torch.cuda.mem_get_info()[0]
self.freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
print_gpu_memory_usage("After vllm offload in sharding manager")
self.module.train() self.module.train()
torch.cuda.empty_cache() # empty the cache after each compute step torch.cuda.empty_cache() # empty the cache after each compute step
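print_gpu_memory_usage, the replacement for log_gpu_memory_usage, comes from utils.model_utils and is not part of this hunk; a plausible minimal version, stated as an assumption, is:

# Sketch (assumption): log allocated/reserved CUDA memory with a message prefix.
import torch

def print_gpu_memory_usage(prefix: str) -> None:
    if not torch.cuda.is_available():
        return
    allocated_gb = torch.cuda.memory_allocated() / 1024**3
    reserved_gb = torch.cuda.memory_reserved() / 1024**3
    print(f"{prefix}: allocated {allocated_gb:.2f} GiB, reserved {reserved_gb:.2f} GiB")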
...@@ -88,15 +116,13 @@ class FSDPVLLMShardingManager(BaseShardingManager): ...@@ -88,15 +116,13 @@ class FSDPVLLMShardingManager(BaseShardingManager):
torch.cuda.set_rng_state(self.torch_random_states) torch.cuda.set_rng_state(self.torch_random_states)
def preprocess_data(self, data: DataProto) -> DataProto: def preprocess_data(self, data: DataProto) -> DataProto:
tp_group = vllm_ps.get_tensor_model_parallel_group().device_group """All gather across the tp group so that each rank has identical input."""
data = data.to("cuda") all_gather_data_proto(data, size=self.tp_size, group=self.tp_group)
data.all_gather(tp_group)
return data return data
def postprocess_data(self, data: DataProto) -> DataProto: def postprocess_data(self, data: DataProto) -> DataProto:
dp_rank = dist.get_rank() """Get this tp rank's chunk of the data, since we all-gather in preprocess_data."""
tp_size = vllm_ps.get_tensor_model_parallel_world_size() if self.tp_size > 1:
if tp_size > 1: data = data.chunk(chunks=self.tp_size)[self.tp_rank]
data = data.chunk(chunks=tp_size)[dp_rank % tp_size]
return data return data
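The preprocess_data/postprocess_data pair forms a round trip: gather every TP rank's batch before generation, then keep only this rank's slice afterwards. Reusing the gather sketch above, the slice-back step reduces to:

# Sketch (assumption): invert the pre-generation all-gather by chunking on dim 0.
import torch

def chunk_for_tp_rank(batch: torch.Tensor, tp_size: int, tp_rank: int) -> torch.Tensor:
    return batch.chunk(tp_size, dim=0)[tp_rank]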