Commit ff7fb65e authored by chenych

Update

parent c132cbcb
@@ -72,7 +72,9 @@ class ActorConfig:
     micro_batch_size_per_device_for_update: int = 4
     micro_batch_size_per_device_for_experience: int = 16
     max_grad_norm: float = 1.0
-    clip_ratio: float = 0.2
+    clip_ratio_low: float = 0.2
+    clip_ratio_high: float = 0.3
+    clip_ratio_dual: float = 3.0
     ppo_epochs: int = 1
     padding_free: bool = False
     ulysses_sequence_parallel_size: int = 1
...
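For context: splitting clip_ratio into low/high bounds and adding clip_ratio_dual matches a dual-clip PPO objective. A minimal sketch of such a loss, assuming these names and semantics (the commit's actual core_algos.compute_policy_loss may differ):

import torch

def masked_mean(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    mask = mask.float()
    return (x * mask).sum() / mask.sum().clamp(min=1.0)

def compute_policy_loss_sketch(old_log_probs, log_probs, advantages, response_mask,
                               clip_ratio_low=0.2, clip_ratio_high=0.3, clip_ratio_dual=3.0):
    # per-token probability ratio between the new and old policies
    ratio = torch.exp(log_probs - old_log_probs)
    pg_loss1 = -advantages * ratio
    # asymmetric clipping: 1 - clip_ratio_low below, 1 + clip_ratio_high above
    pg_loss2 = -advantages * torch.clamp(ratio, 1.0 - clip_ratio_low, 1.0 + clip_ratio_high)
    clipped = torch.max(pg_loss1, pg_loss2)
    # dual clip: additionally bound the loss where advantages are negative
    pg_loss3 = -advantages * clip_ratio_dual
    final_loss = torch.where(advantages < 0, torch.min(clipped, pg_loss3), clipped)
    pg_loss = masked_mean(final_loss, response_mask)
    pg_clipfrac_higher = masked_mean((pg_loss2 > pg_loss1).float(), response_mask)
    pg_clipfrac_lower = masked_mean(((pg_loss3 < clipped) & (advantages < 0)).float(), response_mask)
    ppo_kl = masked_mean(old_log_probs - log_probs, response_mask)  # k1 estimate of policy drift
    return pg_loss, pg_clipfrac_higher, pg_clipfrac_lower, ppo_kl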
@@ -250,18 +250,21 @@ class DataParallelPPOActor(BasePPOActor):
                 # all return: (bsz, response_length)
                 log_probs = self._forward_micro_batch(model_inputs, temperature=temperature)
-                pg_loss, pg_clipfrac, ppo_kl = core_algos.compute_policy_loss(
+                entropy_loss = -VF.masked_mean(log_probs, response_mask)  # estimator of entropy loss
+                pg_loss, pg_clipfrac_higher, pg_clipfrac_lower, ppo_kl = core_algos.compute_policy_loss(
                     old_log_probs=old_log_probs,
                     log_probs=log_probs,
                     advantages=advantages,
-                    eos_mask=response_mask,
-                    cliprange=self.config.clip_ratio,
+                    response_mask=response_mask,
+                    clip_ratio_low=self.config.clip_ratio_low,
+                    clip_ratio_high=self.config.clip_ratio_high,
+                    clip_ratio_dual=self.config.clip_ratio_dual,
                 )
                 if "ref_log_probs" in model_inputs:
                     ref_log_probs = model_inputs["ref_log_probs"]
                     # compute kl loss
-                    kld = core_algos.kl_penalty(
+                    kld = core_algos.compute_kl(
                         log_probs=log_probs,
                         ref_log_probs=ref_log_probs,
                         kl_penalty=self.config.kl_penalty,
@@ -276,7 +279,9 @@ class DataParallelPPOActor(BasePPOActor):
                 batch_metrics = {
                     "actor/pg_loss": pg_loss.detach().item(),
-                    "actor/pg_clipfrac": pg_clipfrac.detach().item(),
+                    "actor/pg_clipfrac_higher": pg_clipfrac_higher.detach().item(),
+                    "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(),
+                    "actor/entropy_loss": entropy_loss.detach().item(),
                     "actor/ppo_kl": ppo_kl.detach().item(),
                 }
                 append_to_dict(metrics, batch_metrics)
...
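The rename kl_penalty to compute_kl keeps a kl_penalty string argument; the usual per-token estimators such a function selects between look like this (variant names are assumptions, not taken from the commit):

import torch

def compute_kl_sketch(log_probs: torch.Tensor, ref_log_probs: torch.Tensor, kl_penalty: str = "kl") -> torch.Tensor:
    # all inputs and outputs have shape (bsz, response_length)
    log_ratio = log_probs - ref_log_probs
    if kl_penalty == "kl":  # k1: plain log-ratio, unbiased but high variance
        return log_ratio
    if kl_penalty == "abs":
        return log_ratio.abs()
    if kl_penalty == "mse":  # k2: half squared log-ratio
        return 0.5 * log_ratio.square()
    if kl_penalty == "low_var_kl":  # k3: exp(-x) - 1 + x, non-negative and low variance
        return torch.expm1(-log_ratio) + log_ratio
    raise NotImplementedError(f"Unknown kl_penalty: {kl_penalty}")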
@@ -199,14 +199,14 @@ class DataParallelPPOCritic(BasePPOCritic):
                 values = model_inputs["values"]
                 returns = model_inputs["returns"]
                 response_length = responses.size(1)
-                eos_mask = attention_mask[:, -response_length - 1 : -1]  # shift left for value computation
+                action_mask = attention_mask[:, -response_length - 1 : -1]  # shift left for value computation
                 vpreds = self._forward_micro_batch(model_inputs)
                 vf_loss, vf_clipfrac = core_algos.compute_value_loss(
                     vpreds=vpreds,
                     returns=returns,
                     values=values,
-                    eos_mask=eos_mask,
+                    action_mask=action_mask,
                     cliprange_value=self.config.cliprange_value,
                 )
                 loss = vf_loss / gradient_accumulation
@@ -215,7 +215,7 @@ class DataParallelPPOCritic(BasePPOCritic):
                 batch_metrics = {
                     "critic/vf_loss": vf_loss.detach().item(),
                     "critic/vf_clipfrac": vf_clipfrac.detach().item(),
-                    "critic/vpred_mean": VF.masked_mean(vpreds, eos_mask).detach().item(),
+                    "critic/vpred_mean": VF.masked_mean(vpreds, action_mask).detach().item(),
                 }
                 append_to_dict(metrics, batch_metrics)
...
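For reference, a clipped value loss consistent with the renamed action_mask argument; a minimal sketch (the actual core_algos.compute_value_loss may differ):

import torch

def masked_mean(x, mask):
    mask = mask.float()
    return (x * mask).sum() / mask.sum().clamp(min=1.0)

def compute_value_loss_sketch(vpreds, returns, values, action_mask, cliprange_value=0.5):
    # clip new value predictions to stay close to the old values
    vpreds_clipped = values + (vpreds - values).clamp(-cliprange_value, cliprange_value)
    vf_loss1 = (vpreds - returns).square()
    vf_loss2 = (vpreds_clipped - returns).square()
    vf_loss = 0.5 * masked_mean(torch.max(vf_loss1, vf_loss2), action_mask)
    vf_clipfrac = masked_mean((vf_loss2 > vf_loss1).float(), action_mask)
    return vf_loss, vf_clipfrac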
@@ -57,7 +57,7 @@ from ..utils.torch_functional import AnyPrecisionAdamW, get_constant_schedule_wi
 from .actor import DataParallelPPOActor
 from .config import ActorConfig, CriticConfig, FSDPConfig, ModelConfig, OptimConfig, RefConfig, WorkerConfig
 from .critic import DataParallelPPOCritic
-from .rollout.vllm_rollout import vLLMRollout
+from .rollout import vLLMRollout
 from .sharding_manager import FSDPVLLMShardingManager
 from .sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager
@@ -75,6 +75,10 @@ class FSDPWorker(Worker):
         if not dist.is_initialized():
             dist.init_process_group(backend="nccl")

+        # improve numerical stability
+        torch.backends.cuda.matmul.allow_tf32 = False
+        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
+
         self._is_actor = self.role in ["actor", "actor_rollout", "actor_rollout_ref"]
         self._is_critic = self.role == "critic"
         self._is_rollout = self.role in ["rollout", "actor_rollout", "actor_rollout_ref"]
@@ -131,7 +135,7 @@ class FSDPWorker(Worker):
                 config.global_batch_size // self.device_mesh.size() * config.ulysses_sequence_parallel_size
             )
             if config.global_batch_size_per_device == 0:
-                raise ValueError(f"{role} global batch size must be larger than num gpus.")
+                raise ValueError(f"{role} global batch size * ulysses size must be larger than num gpus.")
             if config.global_batch_size_per_device % config.micro_batch_size_per_device_for_update != 0:
                 raise ValueError(f"{role} global batch size per device must be divisible by the micro batch size.")
@@ -413,7 +417,7 @@ class FSDPWorker(Worker):
         if self._use_param_offload:
             offload_fsdp_model(self.fsdp_module)

-        if self._use_optimizer_offload:
+        if self._use_optimizer_offload:  # avoid OOM in resuming
             offload_fsdp_optimizer(self.optimizer)

     @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
...
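A worked example of the per-device batch computation checked above (numbers are illustrative):

# 128 global batch, 16 GPUs, ulysses_sequence_parallel_size = 2.
# Each ulysses group of 2 GPUs shares the same sequences, so the effective
# data-parallel size is 16 / 2 = 8 and each device handles 128 // 16 * 2 = 16 samples.
global_batch_size, world_size, ulysses_size = 128, 16, 2
global_batch_size_per_device = global_batch_size // world_size * ulysses_size  # 16
assert global_batch_size_per_device % 4 == 0  # divisible by micro_batch_size_per_device_for_update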
@@ -21,4 +21,5 @@ from dataclasses import dataclass
 @dataclass
 class RewardConfig:
     reward_type: str = "function"
-    compute_score: str = "math"
+    score_function: str = "math"
+    skip_special_tokens: bool = True
@@ -14,13 +14,14 @@
 from collections import defaultdict
-from typing import Any, Callable, Dict, Tuple, TypedDict
+from typing import Callable, Dict, List, Tuple, TypedDict

 import torch
 from transformers import PreTrainedTokenizer

 from ...protocol import DataProto
 from ...utils.reward_score import math_compute_score, r1v_compute_score
+from .config import RewardConfig


 class RewardScore(TypedDict):
@@ -30,16 +31,17 @@ class RewardScore(TypedDict):
 class CustomRewardManager:
-    def __init__(self, tokenizer: PreTrainedTokenizer, compute_score: str):
+    def __init__(self, tokenizer: PreTrainedTokenizer, config: RewardConfig):
+        self.config = config
         self.tokenizer = tokenizer
-        if compute_score == "math":
+        if config.score_function == "math":
             self.compute_score: Callable[[str, str], RewardScore] = math_compute_score
-        elif compute_score == "r1v":
+        elif config.score_function == "r1v":
             self.compute_score: Callable[[str, str], RewardScore] = r1v_compute_score
         else:
-            raise NotImplementedError()
+            raise NotImplementedError(f"Unknown score function {config.score_function}.")

-    def __call__(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, Any]]:
+    def __call__(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, List[float]]]:
         reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
         reward_metrics = defaultdict(list)
         for i in range(len(data)):
@@ -49,7 +51,9 @@ class CustomRewardManager:
             valid_response_length = response_mask.sum()
             valid_response_ids = response_ids[:valid_response_length]
-            response_str = self.tokenizer.decode(valid_response_ids, skip_special_tokens=True)
+            response_str = self.tokenizer.decode(
+                valid_response_ids, skip_special_tokens=self.config.skip_special_tokens
+            )
             ground_truth = data_item.non_tensor_batch["ground_truth"]
             score = self.compute_score(response_str, ground_truth)
...
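Putting the two reward changes together, a hypothetical wiring of RewardConfig and CustomRewardManager (the tokenizer name is a placeholder and `data` is assumed to be a DataProto produced by the rollout):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")  # placeholder model
config = RewardConfig(reward_type="function", score_function="r1v", skip_special_tokens=True)
reward_fn = CustomRewardManager(tokenizer=tokenizer, config=config)
reward_tensor, reward_metrics = reward_fn(data)  # (bsz, response_length) rewards + per-sample metrics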
@@ -14,6 +14,7 @@
 from .config import RolloutConfig
+from .vllm_rollout_spmd import vLLMRollout

-__all__ = ["RolloutConfig"]
+__all__ = ["RolloutConfig", "vLLMRollout"]
@@ -28,11 +28,10 @@ class RolloutConfig:
     top_k: int = -1
     limit_images: int = 0
     dtype: str = "bf16"
-    gpu_memory_utilization: float = 0.5
+    gpu_memory_utilization: float = 0.6
     ignore_eos: bool = False
     enforce_eager: bool = False
-    free_cache_engine: bool = False
-    enable_chunked_prefill: bool = False
+    enable_chunked_prefill: bool = False  # only for v0 engine
     tensor_parallel_size: int = 2
     max_num_batched_tokens: int = 8192
     max_num_seqs: int = 1024
...
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .vllm_rollout_spmd import vLLMRollout
__all__ = ["vLLMRollout"]
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/model_loader
from typing import Dict
import torch
import torch.nn as nn
from torch.distributed._tensor import DTensor
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import is_pp_missing_parameter
def gemma_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]
    params_dict = dict(vllm_model.named_parameters())
    for name, loaded_weight in actor_weights.items():
        for param_name, shard_name, shard_id in stacked_params_mapping:
            if shard_name not in name:
                continue
            stacked_name = name.replace(shard_name, param_name)
            # Skip loading extra bias for GPTQ models.
            if stacked_name.endswith(".bias") and stacked_name not in params_dict:
                continue
            local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
            param = params_dict[stacked_name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id)
            break
        else:
            # lm_head is not used in vllm as it is tied with embed_token.
            # To prevent errors, skip loading lm_head.weight.
            if "lm_head.weight" in name:
                continue
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, local_loaded_weight.to(dtype=param.dtype))
def llama_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        (".qkv_proj", ".q_proj", "q"),
        (".qkv_proj", ".k_proj", "k"),
        (".qkv_proj", ".v_proj", "v"),
        (".gate_up_proj", ".gate_proj", 0),
        (".gate_up_proj", ".up_proj", 1),
    ]
    params_dict = dict(vllm_model.named_parameters())
    for name, loaded_weight in actor_weights.items():
        if "rotary_emb.inv_freq" in name:
            continue
        if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
            # Models trained using ColossalAI may include these tensors in
            # the checkpoint. Skip them.
            continue
        # With tie_word_embeddings, we can skip lm_head.weight.
        # The weight might appear unnecessarily in the files if the model is
        # processed with quantization, LoRA, fine-tuning, etc.
        if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name:
            continue
        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id)
            break
        else:
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, local_loaded_weight)
def qwen2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]
    params_dict = dict(vllm_model.named_parameters(remove_duplicate=False))
    for name, loaded_weight in actor_weights.items():
        if "rotary_emb.inv_freq" in name:
            continue
        if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name:
            continue
        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id)
            break
        else:
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            param = params_dict[name]
            local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, local_loaded_weight.to(dtype=param.dtype))
def qwen2vl_dtensor_weight_loader(actor_weights: Dict[str, torch.Tensor], vllm_model: nn.Module) -> nn.Module:
    stacked_params_mapping = [
        # (vllm_substr, hf_substr, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]
    vllm_params = dict(vllm_model.named_parameters(remove_duplicate=False))
    for actor_name, actor_weight in actor_weights.items():
        if "rotary_emb.inv_freq" in actor_name:
            continue
        if vllm_model.config.tie_word_embeddings and "lm_head.weight" in actor_name:
            continue
        for vllm_substr, hf_substr, shard_id in stacked_params_mapping:
            if hf_substr not in actor_name:
                continue
            if "visual" in actor_name:
                continue
            vllm_name = "language_model." + actor_name.replace(hf_substr, vllm_substr)
            if actor_name.endswith(".bias") and actor_name not in vllm_params:
                continue  # skip loading extra bias for GPTQ models
            local_actor_weight = redistribute_dtensor(param_name=actor_name, loaded_weights=actor_weight)
            vllm_param = vllm_params[vllm_name]
            weight_loader = vllm_param.weight_loader
            weight_loader(vllm_param, local_actor_weight.to(dtype=vllm_param.dtype), shard_id)
            break
        else:
            if actor_name.endswith(".bias") and actor_name not in vllm_params:
                continue  # skip loading extra bias for GPTQ models
            if "visual" in actor_name:
                vllm_name = actor_name
            else:
                vllm_name = "language_model." + actor_name
            vllm_param = vllm_params[vllm_name]
            local_actor_weight = redistribute_dtensor(param_name=actor_name, loaded_weights=actor_weight)
            weight_loader = getattr(vllm_param, "weight_loader", default_weight_loader)
            weight_loader(vllm_param, local_actor_weight.to(dtype=vllm_param.dtype))
def deepseekv2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]
    # Params for weights, fp8 weight scales, fp8 activation scales
    # (param_name, weight_name, expert_id, shard_id)
    expert_params_mapping = FusedMoE.make_expert_params_mapping(
        ckpt_gate_proj_name="gate_proj",
        ckpt_down_proj_name="down_proj",
        ckpt_up_proj_name="up_proj",
        num_experts=vllm_model.config.n_routed_experts,
    )
    params_dict = dict(vllm_model.named_parameters(remove_duplicate=False))
    for name, loaded_weight in actor_weights.items():
        if "rotary_emb.inv_freq" in name:
            continue
        for param_name, weight_name, shard_id in stacked_params_mapping:
            # Skip non-stacked layers and experts (experts handled below).
            if weight_name not in name:
                continue
            # We have mlp.experts[0].gate_proj in the checkpoint.
            # Since we handle the experts below in expert_params_mapping,
            # we need to skip here BEFORE we update the name, otherwise
            # name will be updated to mlp.experts[0].gate_up_proj, which
            # will then be updated below in expert_params_mapping
            # for mlp.experts[0].gate_gate_up_proj, which breaks load.
            if ("mlp.experts." in name) and name not in params_dict:
                continue
            name = name.replace(weight_name, param_name)
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, vllm_model):
                continue
            param = params_dict[name]
            local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id)
            break
        else:
            for mapping in expert_params_mapping:
                param_name, weight_name, expert_id, shard_id = mapping
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                if is_pp_missing_parameter(name, vllm_model):
                    continue
                param = params_dict[name]
                local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(
                    param,
                    local_loaded_weight.to(dtype=param.dtype),
                    weight_name,
                    shard_id=shard_id,
                    expert_id=expert_id,
                )
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, vllm_model):
                    continue
                param = params_dict[name]
                local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, local_loaded_weight.to(dtype=param.dtype))
def redistribute_dtensor(param_name: str, loaded_weights: DTensor, parallelize_plan: Dict = None):
    param_name = _process_parameter_names(name=param_name)
    if parallelize_plan is not None:
        assert param_name in parallelize_plan.keys(), (
            f"param name: {param_name} not in parallelize_plan: {parallelize_plan.keys()}"
        )
        placement = parallelize_plan[param_name]
        local_loaded_weights = loaded_weights.redistribute(
            device_mesh=loaded_weights.device_mesh, placements=placement
        ).to_local()
    else:
        local_loaded_weights = loaded_weights.full_tensor()
    return local_loaded_weights
def _process_parameter_names(name):
    # Remove '.weight' if it exists at the end of the string
    if name.endswith(".weight"):
        name = name[:-7]
    # Remove 'model.layers.x.' or 'model.' prefix
    if "model.layers" in name:
        parts = name.split(".")
        # Reconstruct the string without 'model.layers.x.'
        name = ".".join(parts[3:])  # parts[0] is 'model', parts[1] is 'layers', parts[2] is 'x'
    elif name.startswith("model."):
        name = name[6:]  # Remove 'model.'
    return name
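# Illustrative examples of the normalization above (added for exposition):
#   "model.layers.0.self_attn.q_proj.weight" -> "self_attn.q_proj"
#   "model.embed_tokens.weight"              -> "embed_tokens"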
__MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__ = {
    "LlamaForCausalLM": llama_dtensor_weight_loader,
    "LLaMAForCausalLM": llama_dtensor_weight_loader,
    "MistralForCausalLM": llama_dtensor_weight_loader,  # mistral is the same as llama in vLLM
    "InternLMForCausalLM": llama_dtensor_weight_loader,
    "Phi3ForCausalLM": llama_dtensor_weight_loader,
    "GemmaForCausalLM": gemma_dtensor_weight_loader,
    "Gemma2ForCausalLM": gemma_dtensor_weight_loader,
    "Qwen2ForCausalLM": qwen2_dtensor_weight_loader,
    "DeepseekV2ForCausalLM": deepseekv2_dtensor_weight_loader,
    "Qwen2VLForConditionalGeneration": qwen2vl_dtensor_weight_loader,
    "Qwen2_5_VLForConditionalGeneration": qwen2vl_dtensor_weight_loader,
}
# The actor passes its weights via .state_dict(); load those DTensor weights
# into the vLLM model.
def load_dtensor_weights(actor_weights: Dict, vllm_model: nn.Module):
    weight_loader = _get_model_weight_loader(vllm_model.__class__.__name__)
    weight_loader(actor_weights, vllm_model)
    # NOTE(sgm): to reduce peak memory usage, we offload the vllm model to cpu
    # after init, so we must move it back to GPU after syncing weights in the first iteration.
    vllm_model = vllm_model.cuda()
def _get_model_weight_loader(arch: str):
    if arch in __MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__:
        return __MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__[arch]
    raise ValueError(
        f"Model architecture {arch} is not supported for now. "
        f"Supported architectures: {__MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__.keys()}"
    )
# NOTE(sgm): we use per-parameter weight loader in each vllm sub
def update_dtensor_weight_loader():
    pass
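# Typical call site (an illustrative sketch, not part of this commit): after FSDP
# materializes the actor's DTensor state dict, push it into the colocated vLLM
# engine. The attribute path to the vLLM model object varies across vLLM versions.
#
#   actor_weights = fsdp_module.state_dict()  # Dict[str, DTensor]
#   load_dtensor_weights(actor_weights, vllm_model)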
@@ -29,11 +29,11 @@ from tensordict import TensorDict
 from transformers import PreTrainedTokenizer
 from vllm import LLM, RequestOutput, SamplingParams

-from ....protocol import DataProto
-from ....utils import torch_functional as VF
-from ....utils.torch_dtypes import PrecisionType
-from ..base import BaseRollout
-from ..config import RolloutConfig
+from ...protocol import DataProto
+from ...utils import torch_functional as VF
+from ...utils.torch_dtypes import PrecisionType
+from .base import BaseRollout
+from .config import RolloutConfig


 def _repeat_interleave(value: Union[torch.Tensor, np.ndarray], repeats: int) -> Union[torch.Tensor, List[Any]]:
@@ -59,9 +59,6 @@ class vLLMRollout(BaseRollout):
         if config.tensor_parallel_size > torch.distributed.get_world_size():
             raise ValueError("Tensor parallelism size should be less than world size.")

-        if not config.enforce_eager and config.free_cache_engine:
-            raise ValueError("CUDA graph should be disabled when `free_cache_engine` is True.")
-
         if config.max_num_batched_tokens < config.prompt_length + config.response_length:
             raise ValueError("max_num_batched_tokens should be greater than prompt_length + response_length.")
@@ -84,6 +81,7 @@ class vLLMRollout(BaseRollout):
             disable_mm_preprocessor_cache=True,
             disable_log_stats=config.disable_log_stats,
             enable_chunked_prefill=config.enable_chunked_prefill,
+            seed=self.rank // config.tensor_parallel_size,  # dp rank
             **vllm_init_kwargs,
         )
@@ -171,7 +169,7 @@ class vLLMRollout(BaseRollout):
         # position_ids: [0,0,0,0,0,1,2,3 | 4,5,6,7,8,9,10,11]
         response_position_ids = position_ids[..., -1:] + delta_position_id
         position_ids = torch.cat([position_ids, response_position_ids], dim=-1)
-        response_mask = VF.get_eos_mask(
+        response_mask = VF.get_response_mask(
             response_ids=response_ids, eos_token_id=eos_token_id, dtype=attention_mask.dtype
         )
         attention_mask = torch.cat((attention_mask, response_mask), dim=-1)
...
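The renamed VF.get_response_mask presumably keeps the EOS-aware semantics of get_eos_mask. A minimal sketch, assuming a single integer eos_token_id:

import torch

def get_response_mask_sketch(response_ids: torch.Tensor, eos_token_id: int, dtype=torch.long) -> torch.Tensor:
    # 1 for every token up to and including the first EOS, 0 afterwards
    is_eos = response_ids == eos_token_id
    after_first_eos = (is_eos.cumsum(dim=-1) - is_eos.long()) > 0
    return (~after_first_eos).to(dtype)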