Commit 02689420 authored by xuxz's avatar xuxz
Browse files

Merge branch 'v0.9.2-dev' into 'v0.9.2-dev-add_connector'

# Conflicts:
#   vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py
parents ef362942 fa683b07
......@@ -235,12 +235,15 @@ class DeepSeekMTP(nn.Module, SupportsPP):
("qa_kva_proj", "kv_a_proj_with_mqa", 1)
]
stacked_params_mapping += fused_params_mapping
enable_shared_experts_fusion = envs.VLLM_ENABLE_SHARED_EXPERTS_FUSION and self.config.n_shared_experts > 0
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts)
num_experts=self.config.n_routed_experts + (
self.config.n_shared_experts
if enable_shared_experts_fusion else 0
))
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
......@@ -251,6 +254,16 @@ class DeepSeekMTP(nn.Module, SupportsPP):
if spec_layer is None:
continue
name = self._rewrite_spec_layer_name(spec_layer, name)
# Assuame num of shared experts is only one.
if envs.VLLM_ENABLE_SHARED_EXPERTS_FUSION \
and self.config.n_shared_experts > 0 \
and "mlp.shared_experts" in name:
name = name.replace(
"mlp.shared_experts",
f"mlp.experts.{self.config.n_routed_experts}",
)
for (param_name, weight_name, shard_id) in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
......@@ -273,7 +286,9 @@ class DeepSeekMTP(nn.Module, SupportsPP):
param = params_dict[name]
weight_loader = param.weight_loader
if envs.USE_FUSED_RMS_QUANT and envs.VLLM_USE_FUSED_QA_KVA_GEMM and (("q_a_proj" in old_weight_name) or ("kv_a_proj_with_mqa" in old_weight_name)):
if envs.USE_FUSED_RMS_QUANT \
and envs.VLLM_USE_FUSED_QA_KVA_GEMM \
and (("q_a_proj" in old_weight_name) or ("kv_a_proj_with_mqa" in old_weight_name)):
weight_loader(param, loaded_weight, old_weight_name)
else:
weight_loader(param, loaded_weight, shard_id)
......@@ -361,340 +376,4 @@ class DeepSeekMTP(nn.Module, SupportsPP):
# treat rest weights as weights for transformer layer block
name = name.replace(f"model.layers.{spec_layer}.",
f"model.layers.{spec_layer}.mtp_block.")
return name
# # SPDX-License-Identifier: Apache-2.0
# # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# import os
# import re
# from collections.abc import Iterable
# from typing import Iterable, Optional
# import torch
# import torch.nn as nn
# from transformers import PretrainedConfig
# from vllm.config import CacheConfig, ModelConfig, VllmConfig
# from vllm.model_executor.layers.fused_moe import FusedMoE
# from vllm.model_executor.layers.layernorm import RMSNorm
# from vllm.model_executor.layers.logits_processor import LogitsProcessor
# from vllm.model_executor.layers.quantization import QuantizationConfig
# from vllm.model_executor.layers.vocab_parallel_embedding import (
# ParallelLMHead, VocabParallelEmbedding)
# from vllm.model_executor.model_loader.weight_utils import default_weight_loader
# from vllm.model_executor.sampling_metadata import SamplingMetadata
# from vllm.sequence import IntermediateTensors
# from vllm.compilation.decorators import support_torch_compile
# from .deepseek_v2 import (DeepseekV2DecoderLayer,
# get_spec_layer_idx_from_weight_name)
# from .interfaces import SupportsPP
# from .utils import maybe_prefix
# from vllm import _custom_ops as ops
# from vllm.model_executor.layers.quantization.blockwise_int8 import BlockInt8Config
# class SharedHead(nn.Module):
# def __init__(
# self,
# config: PretrainedConfig,
# quant_config: Optional[QuantizationConfig] = None,
# ) -> None:
# super().__init__()
# self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# self.head = ParallelLMHead(config.vocab_size,
# config.hidden_size,
# quant_config=quant_config)
# def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# return self.norm(hidden_states)
# class DeepSeekMultiTokenPredictorLayer(nn.Module):
# def __init__(
# self,
# config: PretrainedConfig,
# prefix: str,
# model_config: ModelConfig,
# cache_config: Optional[CacheConfig] = None,
# quant_config: Optional[QuantizationConfig] = None,
# ) -> None:
# super().__init__()
# self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# self.eh_proj = nn.Linear(config.hidden_size * 2,
# config.hidden_size,
# bias=False)
# self.shared_head = SharedHead(config=config, quant_config=quant_config)
# self.mtp_block = DeepseekV2DecoderLayer(config, prefix, model_config,
# cache_config, quant_config)
# def forward(
# self,
# input_ids: torch.Tensor,
# positions: torch.Tensor,
# previous_hidden_states: torch.Tensor,
# inputs_embeds: Optional[torch.Tensor] = None,
# spec_step_index: int = 0,
# ) -> torch.Tensor:
# assert inputs_embeds is not None
# # masking inputs at position 0, as not needed by MTP
# inputs_embeds[positions == 0] = 0
# inputs_embeds = self.enorm(inputs_embeds)
# previous_hidden_states = self.hnorm(previous_hidden_states)
# hidden_states = self.eh_proj(
# torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
# hidden_states, residual = self.mtp_block(positions=positions,
# hidden_states=hidden_states,
# residual=None)
# hidden_states = residual + hidden_states
# return hidden_states
# class DeepSeekMultiTokenPredictor(nn.Module):
# def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
# super().__init__()
# config = vllm_config.model_config.hf_config
# self.mtp_start_layer_idx = config.num_hidden_layers
# self.num_mtp_layers = config.num_nextn_predict_layers
# # to map the exact layer index from weights
# self.layers = torch.nn.ModuleDict({
# str(idx):
# DeepSeekMultiTokenPredictorLayer(
# config,
# f"{prefix}.layers.{idx}",
# model_config=vllm_config.model_config,
# cache_config=vllm_config.cache_config,
# quant_config=vllm_config.quant_config,
# )
# for idx in range(self.mtp_start_layer_idx,
# self.mtp_start_layer_idx + self.num_mtp_layers)
# })
# self.embed_tokens = VocabParallelEmbedding(
# config.vocab_size,
# config.hidden_size,
# )
# self.logits_processor = LogitsProcessor(config.vocab_size)
# def forward(
# self,
# input_ids: torch.Tensor,
# positions: torch.Tensor,
# previous_hidden_states: torch.Tensor,
# inputs_embeds: Optional[torch.Tensor] = None,
# spec_step_idx: int = 0,
# ) -> torch.Tensor:
# if inputs_embeds is None:
# inputs_embeds = self.embed_tokens(input_ids)
# current_step_idx = (spec_step_idx % self.num_mtp_layers)
# return self.layers[str(self.mtp_start_layer_idx + current_step_idx)](
# input_ids,
# positions,
# previous_hidden_states,
# inputs_embeds,
# current_step_idx,
# )
# def compute_logits(
# self,
# hidden_states: torch.Tensor,
# sampling_metadata: SamplingMetadata,
# spec_step_idx: int = 0,
# ) -> torch.Tensor:
# current_step_idx = (spec_step_idx % self.num_mtp_layers)
# mtp_layer = self.layers[str(self.mtp_start_layer_idx +
# current_step_idx)]
# logits = self.logits_processor(mtp_layer.shared_head.head,
# mtp_layer.shared_head(hidden_states),
# sampling_metadata)
# return logits
# @support_torch_compile
# class DeepSeekMTP(nn.Module, SupportsPP):
# def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
# super().__init__()
# self.config = vllm_config.model_config.hf_config
# quant_config = vllm_config.quant_config
# self.quant_method = None
# if quant_config is not None:
# self.quant_method = quant_config.get_name()
# os.environ['LLAMA_NN'] = '0'
# os.environ['LM_NN'] = '0'
# # The AWQ layer of MTP uses BlockInt8W8A8.
# if self.quant_method == "moe_wna16" or self.quant_method == "awq_marlin":
# vllm_config.quant_config = BlockInt8Config(is_checkpoint_int8_serialized=True, weight_block_size=[128,128])
# self.model = DeepSeekMultiTokenPredictor(vllm_config=vllm_config,
# prefix=maybe_prefix(
# prefix, "model"))
# self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
# def forward(
# self,
# input_ids: torch.Tensor,
# positions: torch.Tensor,
# previous_hidden_states: torch.Tensor,
# intermediate_tensors: Optional[IntermediateTensors] = None,
# inputs_embeds: Optional[torch.Tensor] = None,
# spec_step_idx: int = 0,
# ) -> torch.Tensor:
# hidden_states = self.model(input_ids, positions,
# previous_hidden_states, inputs_embeds,
# spec_step_idx)
# return hidden_states
# def compute_logits(
# self,
# hidden_states: torch.Tensor,
# sampling_metadata: SamplingMetadata,
# spec_step_idx: int = 0,
# ) -> Optional[torch.Tensor]:
# return self.model.compute_logits(hidden_states, sampling_metadata,
# spec_step_idx)
# def load_weights(self, weights: Iterable[tuple[str,
# torch.Tensor]]) -> set[str]:
# stacked_params_mapping = [
# ("gate_up_proj", "gate_proj", 0),
# ("gate_up_proj", "up_proj", 1),
# ]
# expert_params_mapping = FusedMoE.make_expert_params_mapping(
# ckpt_gate_proj_name="gate_proj",
# ckpt_down_proj_name="down_proj",
# ckpt_up_proj_name="up_proj",
# num_experts=self.config.n_routed_experts)
# params_dict = dict(self.named_parameters())
# loaded_params: set[str] = set()
# for name, loaded_weight in weights:
# if "rotary_emb.inv_freq" in name:
# continue
# spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
# if spec_layer is None:
# continue
# name = self._rewrite_spec_layer_name(spec_layer, name)
# for (param_name, weight_name, shard_id) in stacked_params_mapping:
# # Skip non-stacked layers and experts (experts handled below).
# if weight_name not in name:
# continue
# # We have mlp.experts[0].gate_proj in the checkpoint.
# # Since we handle the experts below in expert_params_mapping,
# # we need to skip here BEFORE we update the name, otherwise
# # name will be updated to mlp.experts[0].gate_up_proj, which
# # will then be updated below in expert_params_mapping
# # for mlp.experts[0].gate_gate_up_proj, which breaks load.
# if (("mlp.experts." in name) and name not in params_dict):
# continue
# name = name.replace(weight_name, param_name)
# # Skip loading extra bias for GPTQ models.
# if name.endswith(".bias") and name not in params_dict:
# continue
# param = params_dict[name]
# weight_loader = param.weight_loader
# weight_loader(param, loaded_weight, shard_id)
# break
# else:
# for mapping in expert_params_mapping:
# param_name, weight_name, expert_id, shard_id = mapping
# if weight_name not in name:
# continue
# name = name.replace(weight_name, param_name)
# param = params_dict[name]
# weight_loader = param.weight_loader
# weight_loader(param,
# loaded_weight,
# name,
# shard_id=shard_id,
# expert_id=expert_id)
# break
# else:
# # Skip loading extra bias for GPTQ models.
# if name.endswith(".bias") and name not in params_dict:
# continue
# # According to DeepSeek-V3 Technical Report, MTP modules
# # shares embedding layer. We only load the first weights.
# if (spec_layer != self.model.mtp_start_layer_idx
# and ".layers" not in name):
# continue
# param = params_dict[name]
# weight_loader = getattr(param, "weight_loader",
# default_weight_loader)
# weight_loader(param, loaded_weight)
# loaded_params.add(name)
# if self.use_llama_nn and self.quant_method is None:
# lay_key_words = [
# "self_attn.eh_proj.weight",
# "self_attn.q_proj.weight",
# "self_attn.q_a_proj.weight",
# "self_attn.q_b_proj.weight",
# "self_attn.kv_a_proj_with_mqa.weight",
# "self_attn.kv_b_proj.weight",
# "self_attn.o_proj.weight",
# "mlp.gate_up_proj.weight",
# "mlp.down_proj.weight",
# "mlp.gate.weight",
# "shared_experts.gate_up_proj.weight",
# "shared_experts.down_proj.weight",
# "shared_head.head.weight",
# ]
# combined_words = "|".join(lay_key_words)
# for layername in loaded_params:
# weight = params_dict[layername]
# matches = re.findall(combined_words, layername)
# if matches:
# _weight = torch.zeros_like(weight.data)
# ori_shape =_weight.shape
# ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1])
# weight.data.copy_(_weight)
# weight.data=weight.data.reshape(ori_shape[1],-1)
# return loaded_params
# def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
# """
# Rewrite the weight name to match the format of the original model.
# Add .mtp_block for modules in transformer layer block for spec layer
# and rename shared layer weights to be top level.
# """
# spec_layer_weight_names = [
# "embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head"
# ]
# shared_weight_names = ["embed_tokens"]
# spec_layer_weight = False
# shared_weight = False
# for weight_name in spec_layer_weight_names:
# if weight_name in name:
# spec_layer_weight = True
# if weight_name in shared_weight_names:
# shared_weight = True
# break
# if not spec_layer_weight:
# # treat rest weights as weights for transformer layer block
# name = name.replace(f"model.layers.{spec_layer}.",
# f"model.layers.{spec_layer}.mtp_block.")
# elif shared_weight:
# # treat shared weights as top level weights
# name = name.replace(f"model.layers.{spec_layer}.", "model.")
# return name
return name
\ No newline at end of file
......@@ -59,6 +59,8 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.layers.dp_attention import (dp_gather, dp_reduce_scatter_tensor,
get_moe_tp_size, get_moe_tp_rank, get_attention_tp_size)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
......@@ -70,8 +72,8 @@ from .utils import (PPMissingLayer, is_pp_missing_parameter,
maybe_prefix)
from vllm import _custom_ops as ops
from vllm.utils import W8a8GetCacheJSON
from lmslim.quantize.quant_ops import lm_faster_rmsquant
class DeepseekV2MLP(nn.Module):
def __init__(
......@@ -84,38 +86,40 @@ class DeepseekV2MLP(nn.Module):
prefix: str = "",
) -> None:
super().__init__()
vllm_config = get_current_vllm_config()
enable_dp_attention = vllm_config.parallel_config.enable_dp_attention
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj")
prefix=f"{prefix}.gate_up_proj",
enable_dp_attn_moe=enable_dp_attention)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
reduce_results=reduce_results,
prefix=f"{prefix}.down_proj")
reduce_results=reduce_results if not enable_dp_attention else False,
prefix=f"{prefix}.down_proj",
enable_dp_attn_moe=enable_dp_attention)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x,
rms_weight: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None,
update_hd: Optional[bool] = False,
xqxs: Optional[tuple[torch.Tensor, torch.Tensor]] = None
) -> Union[torch.Tensor,
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
xqxs: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
iqis: Optional[tuple[torch.Tensor, torch.Tensor]] = None
) -> torch.Tensor:
if envs.USE_FUSED_RMS_QUANT:
gate_up, new_resi, i_q, _scales, _ = self.gate_up_proj(x, rms_weight, residual, update_hd=update_hd)
assert iqis is not None
gate_up, _ = self.gate_up_proj(x, iqis=iqis)
if envs.USE_FUSED_SILU_MUL_QUANT:
x, _ = self.down_proj(gate_up, use_fused_silu_mul_quant=True)
else:
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x, new_resi, i_q, _scales
return x
elif envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and xqxs is not None:
gate_up, _ = self.gate_up_proj(x, xqxs=xqxs)
if envs.USE_FUSED_SILU_MUL_QUANT:
......@@ -180,33 +184,19 @@ class DeepseekV2MoE(nn.Module):
self.n_local_physical_experts)
self.physical_expert_end = (self.physical_expert_start +
self.n_local_physical_experts)
dp_size = get_dp_group().world_size
self.enable_expert_parallel = parallel_config.enable_expert_parallel
self.use_deepep = dp_size > 1 and parallel_config.enable_expert_parallel and \
(envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency")
envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_auto")
if not self.use_deepep:
self.experts = FusedMoE(
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
use_grouped_topk=True,
num_expert_group=config.n_group,
topk_group=config.topk_group,
prefix=f"{prefix}.experts",
scoring_func=config.scoring_func,
e_score_correction_bias=self.gate.e_score_correction_bias,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
routed_scaling_factor=self.routed_scaling_factor)
self.enable_shared_experts_overlap = False
self.enable_shared_experts_fusion = (self.n_shared_experts != 0 and envs.VLLM_ENABLE_SHARED_EXPERTS_FUSION)
if config.n_shared_experts is not None:
if not self.use_deepep:
if config.n_shared_experts is not None and not envs.VLLM_ENABLE_SHARED_EXPERTS_FUSION:
intermediate_size = (config.moe_intermediate_size *
config.n_shared_experts)
self.shared_experts = DeepseekV2MLP(
......@@ -214,10 +204,54 @@ class DeepseekV2MoE(nn.Module):
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
reduce_results=self.experts.must_reduce_shared_expert_outputs(
),
reduce_results = False,
prefix=f"{prefix}.shared_experts",
)
self.enable_shared_experts_overlap = (not envs.VLLM_DISABLE_SHARED_EXPERTS_STREAM
and not envs.VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD
and config.n_shared_experts is not None)
if self.enable_shared_experts_overlap:
self.experts = SharedFusedMoE(
shared_experts=self.shared_experts,
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
use_grouped_topk=True,
num_expert_group=config.n_group,
topk_group=config.topk_group,
prefix=f"{prefix}.experts",
scoring_func=config.scoring_func,
e_score_correction_bias=self.gate.e_score_correction_bias,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
routed_scaling_factor=self.routed_scaling_factor)
else:
num_fake_experts = (
config.n_shared_experts
if self.enable_shared_experts_fusion
else 0)
self.experts = FusedMoE(
num_experts=config.n_routed_experts + (num_fake_experts),
top_k=config.num_experts_per_tok + min(num_fake_experts, 1),
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
use_grouped_topk=True,
num_expert_group=config.n_group,
topk_group=config.topk_group,
prefix=f"{prefix}.experts",
scoring_func=config.scoring_func,
e_score_correction_bias=self.gate.e_score_correction_bias,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
routed_scaling_factor=self.routed_scaling_factor)
else:
if config.n_shared_experts is not None:
intermediate_size = (config.moe_intermediate_size *
......@@ -249,22 +283,61 @@ class DeepseekV2MoE(nn.Module):
routed_scaling_factor=self.routed_scaling_factor,
shared_experts=self.shared_experts)
self.run_shared_expert_singlely = (
self.n_shared_experts is not None
and not self.enable_shared_experts_overlap
and not envs.VLLM_ENABLE_SHARED_EXPERTS_FUSION
)
from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce
self.tbo_all_reduce = tbo_all_reduce
def forward(self, hidden_states: torch.Tensor,
rms_weight: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None,
xqxs: Optional[tuple[torch.Tensor, torch.Tensor]] = None
xqxs: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
iqis: Optional[tuple[torch.Tensor, torch.Tensor]] = None # iq = input quant, is = input scale
) -> Union[torch.Tensor,
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# For shared experts overlap optimization.
def shared_exprts_overlap_pass(
hidden_states: torch.Tensor, router_logits: torch.Tensor,
iqis: Optional[tuple[torch.Tensor, torch.Tensor]] = None
) -> tuple[torch.Tensor, torch.Tensor]:
i_q, i_s = None, None
if envs.USE_FUSED_RMS_QUANT:
assert iqis is not None
i_q, i_s = iqis[0], iqis[1]
hidden_states_copy = hidden_states.clone()
return self.experts(
hidden_states = hidden_states,
router_logits = router_logits,
hidden_states_copy = hidden_states_copy,
i_q = i_q,
i_s = i_s)
# For shared experts fusion optimization.
def shared_exprts_fusion_pass(
hidden_states: torch.Tensor, router_logits: torch.Tensor,
iqis: Optional[tuple[torch.Tensor, torch.Tensor]] = None
) -> torch.Tensor:
i_q, i_s = None, None
if envs.USE_FUSED_RMS_QUANT:
assert iqis is not None
i_q, i_s = iqis[0], iqis[1]
return self.experts(
hidden_states = hidden_states,
router_logits = router_logits,
i_q = i_q,
i_s = i_s)
if envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and xqxs is not None:
if self.n_shared_experts is not None:
if self.n_shared_experts is not None and not self.enable_shared_experts_overlap:
shared_output = self.shared_experts(hidden_states, xqxs=xqxs)
router_logits, _ = self.gate(hidden_states)
if envs.VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD:
......@@ -273,63 +346,30 @@ class DeepseekV2MoE(nn.Module):
router_logits=router_logits,
shared_output=shared_output)
else:
if hidden_states.dtype != torch.float16:
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits) * self.routed_scaling_factor
else:
if self.enable_shared_experts_fusion:
final_hidden_states = shared_exprts_fusion_pass(hidden_states, router_logits)
elif self.enable_shared_experts_overlap:
assert self.shared_experts is not None
shared_output, final_hidden_states = shared_exprts_overlap_pass(hidden_states, router_logits)
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
if shared_output is not None:
if hidden_states.dtype != torch.float16:
final_hidden_states = final_hidden_states + shared_output
else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor)
if self.tp_size > 1:
if envs.VLLM_ENABLE_TBO:
final_hidden_states = self.tbo_all_reduce(final_hidden_states)
else:
final_hidden_states = (
self.experts.maybe_all_reduce_tensor_model_parallel(
final_hidden_states))
return final_hidden_states.view(num_tokens, hidden_dim)
else:
if not self.enable_expert_parallel:
i_q, i_s = None, None
if self.n_shared_experts is not None:
if envs.USE_FUSED_RMS_QUANT:
shared_output, new_resi, i_q, i_s = self.shared_experts(hidden_states, rms_weight, residual, update_hd=True)
final_hidden_states *= self.routed_scaling_factor
final_hidden_states += shared_output
else:
shared_output = self.shared_experts(hidden_states)
router_logits, _ = self.gate(hidden_states)
if envs.VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD:
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
shared_output=shared_output,
i_q=i_q, i_s=i_s)
assert shared_output is not None
final_hidden_states += (shared_output * (1.0 / self.routed_scaling_factor))
else:
if hidden_states.dtype != torch.float16:
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
i_q=i_q, i_s=i_s) * self.routed_scaling_factor
router_logits=router_logits) * self.routed_scaling_factor
else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
# fp16 mode not fused quant
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
if shared_output is not None:
if hidden_states.dtype != torch.float16:
final_hidden_states = final_hidden_states + shared_output
......@@ -338,11 +378,68 @@ class DeepseekV2MoE(nn.Module):
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor)
else:
else: # RQ
if not self.enable_expert_parallel:
i_q, i_s = None, None
if iqis is not None:
i_q, i_s = iqis
if self.run_shared_expert_singlely:
if envs.USE_FUSED_RMS_QUANT:
shared_output = self.shared_experts(hidden_states, iqis=iqis)
else:
shared_output = self.shared_experts(hidden_states)
router_logits, _ = self.gate(hidden_states)
if self.enable_shared_experts_fusion:
final_hidden_states = shared_exprts_fusion_pass(hidden_states, router_logits, iqis = iqis)
elif self.enable_shared_experts_overlap:
assert self.shared_experts is not None
shared_output, final_hidden_states = shared_exprts_overlap_pass(hidden_states, router_logits, iqis = iqis)
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
if hidden_states.dtype != torch.float16:
final_hidden_states *= self.routed_scaling_factor
final_hidden_states += shared_output
else:
assert shared_output is not None
final_hidden_states += (shared_output * (1.0 / self.routed_scaling_factor))
else:
if envs.VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD:
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
shared_output=shared_output,
i_q=i_q, i_s=i_s)
else:
if hidden_states.dtype != torch.float16:
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
i_q=i_q, i_s=i_s) * self.routed_scaling_factor
else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
# fp16 mode not fused quant
if i_q is not None:
i_q=iqis[0]
i_s=iqis[1]
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits,
i_q=i_q, i_s=i_s)
if shared_output is not None:
if hidden_states.dtype != torch.float16:
final_hidden_states = final_hidden_states + shared_output
else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor)
else: # EP
router_logits, _ = self.gate(hidden_states)
if self.use_deepep:
shared_output, final_hidden_states = self.experts(hidden_states=hidden_states,
shared_output, final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
if shared_output is not None:
......@@ -354,37 +451,48 @@ class DeepseekV2MoE(nn.Module):
final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor)
else:
if self.n_shared_experts is not None:
if self.run_shared_expert_singlely:
if envs.USE_FUSED_RMS_QUANT:
shared_output, new_resi = self.shared_experts(hidden_states, rms_weight, residual, update_hd=True)
shared_output = self.shared_experts(hidden_states, iqis=iqis)
else:
shared_output = self.shared_experts(hidden_states)
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits)
if shared_output is not None:
if self.enable_shared_experts_overlap:
assert self.shared_experts is not None
shared_output, final_hidden_states = shared_exprts_overlap_pass(hidden_states, router_logits)
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
if hidden_states.dtype != torch.float16:
final_hidden_states = final_hidden_states + shared_output
final_hidden_states += shared_output
else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor)
if self.tp_size > 1:
if envs.VLLM_ENABLE_TBO:
final_hidden_states = self.tbo_all_reduce(final_hidden_states)
else:
final_hidden_states = (
self.experts.maybe_all_reduce_tensor_model_parallel(
final_hidden_states))
if envs.USE_FUSED_RMS_QUANT:
return final_hidden_states.view(num_tokens, hidden_dim), new_resi, i_q, i_s
assert shared_output is not None
final_hidden_states += (shared_output * (1. / self.routed_scaling_factor))
else:
if i_q is not None:
i_q=iqis[0]
i_s=iqis[1]
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits,
i_q=i_q, i_s=i_s)
if shared_output is not None:
if hidden_states.dtype != torch.float16:
final_hidden_states = final_hidden_states + shared_output
else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor)
if self.tp_size > 1:
if envs.VLLM_ENABLE_TBO:
final_hidden_states = self.tbo_all_reduce(final_hidden_states)
else:
return final_hidden_states.view(num_tokens, hidden_dim)
final_hidden_states = (
self.experts.maybe_all_reduce_tensor_model_parallel(
final_hidden_states))
return final_hidden_states.view(num_tokens, hidden_dim)
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
......@@ -546,7 +654,7 @@ class DeepseekV2MLAAttention(nn.Module):
"""
Main reference: DeepseekV2 paper, and FlashInfer Implementation
(https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).
For more info see MLACommonImpl in: vllm/attention/backends/mla/utils.py
"""
......@@ -602,7 +710,6 @@ class DeepseekV2MLAAttention(nn.Module):
self.q_lora_rank,
bias=False,
quant_config=quant_config,
eps=config.rms_norm_eps,
prefix=f"{prefix}.q_a_proj")
self.q_b_proj = ColumnParallelLinear(q_lora_rank,
self.num_heads *
......@@ -623,7 +730,7 @@ class DeepseekV2MLAAttention(nn.Module):
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_b_proj")
self.q_a_layernorm = RMSNorm(self.q_lora_rank,
eps=config.rms_norm_eps)
......@@ -697,27 +804,24 @@ class DeepseekV2MLAAttention(nn.Module):
self.prefix = prefix
self.debug_layer_idx = int(self.prefix.split(".")[-2])
# TODO wjl: 这里的forward拆了
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
rms_weight: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None,
pa_rms_weight: Optional[torch.Tensor] = None,
pa_residual: Optional[torch.Tensor] = None,
pa_rms_eps: Optional[float] = 1e-6,
pa_quant_dtype: Optional[torch.dtype] = torch.int8,
update_input: Optional[bool] = True
update_input: Optional[bool] = True,
iqis: Optional[tuple[torch.Tensor, torch.Tensor]] = None
) -> Union[torch.Tensor,
Tuple[torch.Tensor, torch.Tensor],
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
if envs.USE_FUSED_RMS_QUANT and rms_weight is not None:
if envs.USE_FUSED_RMS_QUANT and iqis is not None:
if envs.VLLM_USE_FUSED_QA_KVA_GEMM:
if self.q_lora_rank is not None:
qc_kvc_kpe, new_residual, _bias = self.qa_kva_proj(hidden_states, rms_weight=rms_weight, residual=residual, update_hd=False)
qc_kvc_kpe, _bias = self.qa_kva_proj(hidden_states, iqis)
q_c = qc_kvc_kpe[:, :self.q_lora_rank]
kvc_kpe = qc_kvc_kpe[:, self.q_lora_rank:]
q, _, _ = self.q_b_proj(q_c, rms_weight=self.q_a_layernorm.weight.data, residual=None, update_hd=False)
......@@ -727,15 +831,15 @@ class DeepseekV2MLAAttention(nn.Module):
kv_c, k_pe = kvc_kpe.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
else:
if self.q_lora_rank is not None:
q_c, new_residual, _, input_quant_args = self.q_a_proj(hidden_states, rms_weight=rms_weight, residual=residual, update_hd=False)
q_c, _ = self.q_a_proj(hidden_states, iqis=iqis)
q, _, _ = self.q_b_proj(q_c, rms_weight=self.q_a_layernorm.weight.data, residual=None, update_hd=False)
else:
q = self.q_proj(hidden_states)[0]
kvc_kpe = self.kv_a_proj_with_mqa(hidden_states, quant_args=input_quant_args, update_hd=False)[0]
kvc_kpe = self.kv_a_proj_with_mqa(hidden_states, iqis=iqis)[0]
kv_c, k_pe = kvc_kpe.split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
if envs.VLLM_USE_LIGHTOP:
kv_c_normed = self.kv_a_layernorm.forward_cuda_opt(kv_c)
......@@ -763,19 +867,21 @@ class DeepseekV2MLAAttention(nn.Module):
cos_sin_cache = self.rotary_emb.cos_sin_cache
if cos_sin_cache.device != positions.device or cos_sin_cache.device != q.dtype:
cos_sin_cache = cos_sin_cache.to(positions.device, dtype=q.dtype)
kv_c_normed = torch.empty(kv_c.shape, dtype=kv_c.dtype, device=kv_c.device)
kv_c_normed = torch.empty(kv_c.shape, dtype=kv_c.dtype, device=kv_c.device)
attn_out = self.mla_attn(
q[..., self.qk_nope_head_dim:],
kv_c,
k_pe,
output_shape=(hidden_states.shape[0],
self.num_local_heads * self.v_head_dim),
query_nope=q[..., :self.qk_nope_head_dim],
num_local_heads=self.num_local_heads,
q_ori=q,
key_normed=kv_c_normed,
positions=positions,
weight=weight,
cos_sin_cache=cos_sin_cache)
return self.o_proj(attn_out)[0], new_residual
return self.o_proj(attn_out)[0]
elif envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and pa_rms_weight is not None and pa_residual is not None:
if self.q_lora_rank is not None:
q_c = self.q_a_proj(hidden_states)[0]
......@@ -788,7 +894,7 @@ class DeepseekV2MLAAttention(nn.Module):
if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
if envs.VLLM_USE_LIGHTOP:
kv_c_normed = self.kv_a_layernorm.forward_cuda_opt(kv_c)
else:
else:
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
q = q.view(-1, self.num_local_heads, self.qk_head_dim)
......@@ -811,19 +917,21 @@ class DeepseekV2MLAAttention(nn.Module):
cos_sin_cache = self.rotary_emb.cos_sin_cache
if cos_sin_cache.device != positions.device or cos_sin_cache.device != q.dtype:
cos_sin_cache = cos_sin_cache.to(positions.device, dtype=q.dtype)
kv_c_normed = torch.empty(kv_c.shape, dtype=kv_c.dtype, device=kv_c.device)
kv_c_normed = torch.empty(kv_c.shape, dtype=kv_c.dtype, device=kv_c.device)
attn_out = self.mla_attn(
q[..., self.qk_nope_head_dim:],
kv_c,
k_pe,
output_shape=(hidden_states.shape[0],
self.num_local_heads * self.v_head_dim),
query_nope=q[..., :self.qk_nope_head_dim],
num_local_heads=self.num_local_heads,
q_ori=q,
key_normed=kv_c_normed,
positions=positions,
weight=weight,
cos_sin_cache=cos_sin_cache)
packages_ = self.o_proj(attn_out,
packages_ = self.o_proj(attn_out,
pa_rms_weight=pa_rms_weight,
pa_residual=pa_residual,
pa_rms_eps=pa_rms_eps,
......@@ -870,13 +978,15 @@ class DeepseekV2MLAAttention(nn.Module):
cos_sin_cache = self.rotary_emb.cos_sin_cache
if cos_sin_cache.device != positions.device or cos_sin_cache.device != q.dtype:
cos_sin_cache = cos_sin_cache.to(positions.device, dtype=q.dtype)
kv_c_normed = torch.empty(kv_c.shape, dtype=kv_c.dtype, device=kv_c.device)
kv_c_normed = torch.empty(kv_c.shape, dtype=kv_c.dtype, device=kv_c.device)
attn_out = self.mla_attn(
q[..., self.qk_nope_head_dim:],
kv_c,
k_pe,
output_shape=(hidden_states.shape[0],
self.num_local_heads * self.v_head_dim),
query_nope=q[..., :self.qk_nope_head_dim],
num_local_heads=self.num_local_heads,
q_ori=q,
key_normed=kv_c_normed,
positions=positions,
......@@ -912,10 +1022,13 @@ class DeepseekV2DecoderLayer(nn.Module):
parallel_config = vllm_config.parallel_config
self.use_deepep = self.dp_size > 1 and parallel_config.enable_expert_parallel and \
(envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency")
envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_auto")
self.tp_size = get_tensor_model_parallel_world_size()
self.config = config
self.tp_rank = get_tensor_model_parallel_rank()
self.enable_dp_attention = vllm_config.parallel_config.enable_dp_attention
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
if (config.n_routed_experts is not None
and layer_idx >= config.first_k_dense_replace
......@@ -935,6 +1048,8 @@ class DeepseekV2DecoderLayer(nn.Module):
prefix=f"{prefix}.mlp",
)
self.enable_ep_sp = isinstance(self.mlp,
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1
self.is_mtp_layer = False
if self.layer_idx == config.num_hidden_layers:
self.is_mtp_layer = True
......@@ -943,6 +1058,9 @@ class DeepseekV2DecoderLayer(nn.Module):
DeepseekV2MoE) and self.use_deepep and \
self.tp_size > 1 and not self.is_mtp_layer:
reduce_results = False
else:
if self.enable_dp_attention:
reduce_results = False
if model_config.use_mla:
attn_cls = DeepseekV2MLAAttention
......@@ -974,10 +1092,11 @@ class DeepseekV2DecoderLayer(nn.Module):
self.routed_scaling_factor = config.routed_scaling_factor
self.use_fused_rms_quant = envs.USE_FUSED_RMS_QUANT
self.use_fused_custom_all_reduce = envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT
self._eps = config.rms_norm_eps
def forward_fused_rmsquant(
def forward_fused_RQ(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
......@@ -985,25 +1104,28 @@ class DeepseekV2DecoderLayer(nn.Module):
) -> Tuple[torch.Tensor, torch.Tensor]:
# Fix residual FP16 overflow
residual_fix_overflow = False
assert self.input_layernorm.has_weight is True
if residual is None:
residual = hidden_states
hidden_states, _ = self.self_attn(
positions = positions,
hidden_states = hidden_states,
rms_weight = self.input_layernorm.weight.data,
residual = None
)
residual_fix_overflow = True
i_q, i_s = lm_faster_rmsquant(input=hidden_states,
rms_weight = self.input_layernorm.weight.data,
epsilon=self._eps,
quant_dtype=torch.int8,
residual=None,
update_input=False)
else:
hidden_states, new_residual = self.self_attn(
positions = positions,
hidden_states = hidden_states,
rms_weight = self.input_layernorm.weight.data,
residual = residual
)
residual = new_residual
i_q, i_s = lm_faster_rmsquant(input=hidden_states,
rms_weight = self.input_layernorm.weight.data,
epsilon=self._eps,
quant_dtype=torch.int8,
residual=residual,
update_input=False)
hidden_states = self.self_attn(positions=positions,
hidden_states = hidden_states, # get attr
iqis=(i_q, i_s))
if hidden_states.dtype == torch.float16:
# rmsnorm, and rmsnorm result would not affect by scale.
......@@ -1012,11 +1134,17 @@ class DeepseekV2DecoderLayer(nn.Module):
# The residual is shared by all layers, we only scale it on
# first layer.
residual *= 1. / self.routed_scaling_factor
hidden_states, new_resi, _i_q, _scales = self.mlp(hidden_states,
rms_weight=self.post_attention_layernorm.weight.data,
residual=residual,
)
update_hs = True if isinstance(self.mlp, DeepseekV2MoE) else False
_i_q, _i_s = lm_faster_rmsquant(input=hidden_states,
rms_weight=self.post_attention_layernorm.weight.data,
epsilon=self._eps,
quant_dtype=torch.int8,
residual=residual,
update_input=update_hs)
new_resi = residual
hidden_states = self.mlp(hidden_states,
iqis=(_i_q, _i_s))
if isinstance(self.mlp,
DeepseekV2MLP) and hidden_states.dtype == torch.float16:
......@@ -1029,9 +1157,9 @@ class DeepseekV2DecoderLayer(nn.Module):
return hidden_states, new_resi
def forward_fused_CRQ(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor]:
residual_fix_overflow = False
......@@ -1042,33 +1170,33 @@ class DeepseekV2DecoderLayer(nn.Module):
else:
hidden_states, resi_new = self.input_layernorm(
hidden_states, residual)
residual = resi_new
residual = resi_new
new_hs, new_resi, xq, xs = self.self_attn(
positions=positions,
hidden_states=hidden_states,
pa_rms_weight=self.post_attention_layernorm.weight.data,
pa_rms_weight=self.post_attention_layernorm.weight.data,
pa_residual=residual,
pa_rms_eps=self.post_attention_layernorm.variance_epsilon,
pa_quant_dtype = torch.int8,
update_input=True
)
assert xq is not None and xs is not None
if new_hs.dtype == torch.float16: # overflow处理逻辑
new_hs *= 1. / self.routed_scaling_factor
if self.layer_idx == 0 or residual_fix_overflow:
new_resi *= 1. / self.routed_scaling_factor
hidden_states = self.mlp(new_hs, xqxs=(xq, xs))
if isinstance(self.mlp,
DeepseekV2MLP) and hidden_states.dtype == torch.float16:
hidden_states *= 1. / self.routed_scaling_factor
return hidden_states, new_resi
def forward_default(
self,
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor]
......@@ -1083,26 +1211,26 @@ class DeepseekV2DecoderLayer(nn.Module):
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
if not self.is_mtp_layer:
if isinstance(self.mlp,
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1 and \
self.layer_idx > self.config.first_k_dense_replace:
hidden_states = tensor_model_parallel_all_gather(hidden_states, dim=0)
if not self.is_mtp_layer and self.enable_ep_sp and \
self.layer_idx > self.config.first_k_dense_replace:
hidden_states = tensor_model_parallel_all_gather(hidden_states, dim=0)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
)
if not self.is_mtp_layer:
if isinstance(self.mlp,
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1:
if self.layer_idx == self.config.first_k_dense_replace:
residual = residual.tensor_split(self.tp_size)[self.tp_rank]
if not self.is_mtp_layer and self.enable_ep_sp:
if self.layer_idx == self.config.first_k_dense_replace:
residual = residual.tensor_split(self.tp_size)[self.tp_rank]
hidden_states = tensor_model_parallel_reduce_scatter(hidden_states, dim=0)
hidden_states = tensor_model_parallel_reduce_scatter(hidden_states, dim=0)
if self.enable_dp_attention:
if self.tp_rank == 0:
hidden_states += residual
hidden_states = dp_gather(hidden_states)
if hidden_states.dtype == torch.float16:
# Fix FP16 overflow
......@@ -1115,27 +1243,31 @@ class DeepseekV2DecoderLayer(nn.Module):
residual *= 1. / self.routed_scaling_factor
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
if self.is_mtp_layer:
if isinstance(self.mlp,
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1:
ori_bs = hidden_states.shape[0]
pad_size = (ori_bs + self.tp_size - 1) // self.tp_size * self.tp_size - ori_bs
if pad_size > 0:
hidden_states = torch.nn.functional.pad(hidden_states.contiguous(), [0, 0, 0, pad_size], value=0).contiguous()
new_bs = (ori_bs+pad_size) // self.tp_size
hidden_states = hidden_states[self.tp_rank*new_bs: (self.tp_rank+1)*new_bs, :].contiguous()
if not self.enable_dp_attention:
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
else:
num_tokens = hidden_states.shape[0]
new_bs = num_tokens // get_moe_tp_size() * get_attention_tp_size()
residual = hidden_states[self.dp_rank*new_bs: (self.dp_rank+1)*new_bs, :]
hidden_states = self.post_attention_layernorm(hidden_states)
if self.is_mtp_layer and self.enable_ep_sp:
ori_bs = hidden_states.shape[0]
pad_size = (ori_bs + self.tp_size - 1) // self.tp_size * self.tp_size - ori_bs
if pad_size > 0:
hidden_states = torch.nn.functional.pad(hidden_states, [0, 0, 0, pad_size], value=0)
new_bs = (ori_bs+pad_size) // self.tp_size
hidden_states = hidden_states[self.tp_rank*new_bs: (self.tp_rank+1)*new_bs, :]
hidden_states = self.mlp(hidden_states)
if self.is_mtp_layer:
if isinstance(self.mlp,
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1:
hidden_states = tensor_model_parallel_all_gather(hidden_states, dim=0)
hidden_states = hidden_states[:ori_bs, :]
if self.enable_dp_attention:
hidden_states = dp_reduce_scatter_tensor(hidden_states)
if self.is_mtp_layer and self.enable_ep_sp:
hidden_states = tensor_model_parallel_all_gather(hidden_states, dim=0)
hidden_states = hidden_states[:ori_bs, :]
if isinstance(self.mlp,
DeepseekV2MLP) and hidden_states.dtype == torch.float16:
......@@ -1147,10 +1279,10 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states *= 1. / self.routed_scaling_factor
return hidden_states, residual
def choose_forward(self):
if self.use_fused_rms_quant:
return self.forward_fused_rmsquant
return self.forward_fused_RQ
elif self.use_fused_custom_all_reduce:
return self.forward_fused_CRQ
......@@ -1212,13 +1344,14 @@ class DeepseekV2Model(nn.Module):
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
self.dp_size = get_dp_group().world_size
vllm_config = get_current_vllm_config()
parallel_config = vllm_config.parallel_config
self.use_deepep = self.dp_size > 1 and parallel_config.enable_expert_parallel and \
(envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency")
envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_auto")
self.tp_size = get_tensor_model_parallel_world_size()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
......@@ -1312,10 +1445,10 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
self.num_routed_experts = example_moe.n_routed_experts
self.num_shared_experts = example_moe.n_shared_experts
self.num_redundant_experts = example_moe.n_redundant_experts
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
self.tritonsingleton= W8a8GetCacheJSON()
self.tritonsingleton= W8a8GetCacheJSON()
self.tritonsingleton.topk = config.num_experts_per_tok
self.tritonsingleton.quant_method=self.quant_method
......@@ -1371,22 +1504,22 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
dtype=dtype,
device=device),
})
def restore_qzeros_tensor(self, qzeros, qscales):
low_bits = qzeros & 0x0F
high_bits = qzeros >> 4
zeors_tensor = torch.stack([low_bits, high_bits], dim=2).view(qzeros.shape[0], -1 , qzeros.shape[-1])
zeors_int16 = zeors_tensor.to(torch.int16)
assert zeors_int16.shape == qscales.shape
uint16_tensor1 = zeors_int16.view(torch.uint16)
uint16_tensor2 = qscales.view(torch.uint16)
uint32_tensor1 = uint16_tensor1.to(torch.int32) << 16
uint32_tensor2 = uint16_tensor2.to(torch.int32)
result_tensor = uint32_tensor1 + uint32_tensor2
result_tensor =result_tensor.view(torch.uint32)
result_tensor = result_tensor.transpose(1, 2).contiguous()
......@@ -1412,7 +1545,12 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts,
num_experts=self.config.n_routed_experts
+ (
self.num_shared_experts
if (envs.VLLM_ENABLE_SHARED_EXPERTS_FUSION and self.num_shared_experts > 0)
else 0
),
num_redundant_experts=self.num_redundant_experts)
params_dict = dict(self.named_parameters())
......@@ -1425,6 +1563,15 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
if spec_layer is not None:
continue # skip spec decode layers for main model
# Assuame num of shared experts is only one.
if envs.VLLM_ENABLE_SHARED_EXPERTS_FUSION \
and self.num_shared_experts > 0 \
and "mlp.shared_experts" in name:
name = name.replace(
"mlp.shared_experts",
f"mlp.experts.{self.config.n_routed_experts}",
)
for (param_name, weight_name, shard_id) in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
......@@ -1494,7 +1641,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
# However it's not mapped locally to this rank
# So we simply skip it
continue
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
......@@ -1515,7 +1662,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
if self.use_llama_nn and self.quant_method is None:
lay_key_words = [
"self_attn.q_proj.weight",
......@@ -1533,19 +1680,19 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
]
combined_words = "|".join(lay_key_words)
for layername in loaded_params:
weight = params_dict[layername]
matches = re.findall(combined_words, layername)
if matches:
_weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape
ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1])
weight.data.copy_(_weight)
weight.data=weight.data.reshape(ori_shape[1],-1)
return loaded_params
......
......@@ -58,6 +58,7 @@ from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
import vllm.envs as envs
FalconConfig = Union[HF_FalconConfig, RWConfig]
......@@ -393,7 +394,7 @@ class FalconModel(nn.Module):
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1'))
self.w8a8_strategy = envs.VLLM_W8A8_BACKEND
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.word_embeddings(input_ids)
......
......@@ -31,6 +31,47 @@ from typing import Optional, Union
import torch
from torch import nn
from transformers import Glm4Config
import vllm.envs as envs
class MultiModalConfigProxy:
"""
Proxy class to handle both flat configs (e.g., Glm4Config) and
nested multimodal configs (e.g., Glm4vConfig with text_config).
For multimodal configs where attributes are in text_config, this proxy
transparently delegates attribute access to text_config when needed.
"""
def __init__(self, config):
# Store original config (for attributes that do exist at top level)
object.__setattr__(self, '_config', config)
def __getattr__(self, name):
# First try to get from the original config (works for flat configs)
try:
return getattr(self._config, name)
except AttributeError:
# If not found and config has text_config, try there
if hasattr(self._config, 'text_config'):
return getattr(self._config.text_config, name)
# Re-raise the original error if text_config doesn't have it either
raise AttributeError(
f"'{type(self._config).__name__}' object has no attribute '{name}'"
)
def __setattr__(self, name, value):
# Allow setting attributes on the proxy itself
if name == '_config':
object.__setattr__(self, name, value)
else:
setattr(self._config, name, value)
def __hasattr__(self, name):
return hasattr(self._config, name) or (
hasattr(self._config, 'text_config') and
hasattr(self._config.text_config, name)
)
from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile
......@@ -151,6 +192,9 @@ class Glm4DecoderLayer(nn.Module):
prefix: str = "",
) -> None:
super().__init__()
# Wrap config to handle both flat and nested multimodal configs
config = MultiModalConfigProxy(config)
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 1000000)
rope_scaling = getattr(config, "rope_scaling", None)
......@@ -177,14 +221,11 @@ class Glm4DecoderLayer(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_self_attn_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_mlp_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
rms_norm_eps = getattr(config, 'rms_norm_eps', 1e-5)
self.input_layernorm = RMSNorm(config.hidden_size, eps=rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=rms_norm_eps)
self.post_self_attn_layernorm = RMSNorm(self.hidden_size, eps=rms_norm_eps)
self.post_mlp_layernorm = RMSNorm(self.hidden_size, eps=rms_norm_eps)
def forward(
self,
......@@ -254,6 +295,9 @@ class Glm4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
# Wrap config to handle both flat and nested multimodal configs
config = MultiModalConfigProxy(config)
self.config = config
self.lora_config = lora_config
......@@ -289,7 +333,7 @@ class Glm4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1'))
self.w8a8_strategy = envs.VLLM_W8A8_BACKEND
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
......
......@@ -33,6 +33,46 @@ from transformers import LlamaConfig
import os
import re
class MultiModalConfigProxy:
"""
Proxy class to handle both flat configs (e.g., LlamaConfig) and
nested multimodal configs (e.g., Glm4vConfig with text_config).
For multimodal configs where attributes are in text_config, this proxy
transparently delegates attribute access to text_config when needed.
"""
def __init__(self, config):
# Store original config (for attributes that do exist at top level)
object.__setattr__(self, '_config', config)
def __getattr__(self, name):
# First try to get from the original config (works for flat configs)
try:
return getattr(self._config, name)
except AttributeError:
# If not found and config has text_config, try there
if hasattr(self._config, 'text_config'):
return getattr(self._config.text_config, name)
# Re-raise the original error if text_config doesn't have it either
raise AttributeError(
f"'{type(self._config).__name__}' object has no attribute '{name}'"
)
def __setattr__(self, name, value):
# Allow setting attributes on the proxy itself
if name == '_config':
object.__setattr__(self, name, value)
else:
setattr(self._config, name, value)
def __hasattr__(self, name):
return hasattr(self._config, name) or (
hasattr(self._config, 'text_config') and
hasattr(self._config.text_config, name)
)
from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
......@@ -246,6 +286,9 @@ class LlamaDecoderLayer(nn.Module):
prefix: str = "",
) -> None:
super().__init__()
# Wrap config to handle both flat and nested multimodal configs
config = MultiModalConfigProxy(config)
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
......@@ -340,6 +383,9 @@ class LlamaModel(nn.Module):
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
# Wrap config to handle both flat and nested multimodal configs
config = MultiModalConfigProxy(config)
self.config = config
self.quant_config = quant_config
lora_vocab = (lora_config.lora_extra_vocab_size *
......@@ -587,6 +633,10 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
# Wrap config to handle both flat and nested multimodal configs
config = MultiModalConfigProxy(config)
self.config = config
self.lora_config = lora_config
self.model = self._init_model(vllm_config=vllm_config,
......
......@@ -52,6 +52,7 @@ from .qwen2 import Qwen2MLP as Qwen3MLP
from .qwen2 import Qwen2Model
from .utils import AutoWeightsLoader, PPMissingLayer, maybe_prefix
import vllm.envs as envs
from vllm.utils import direct_register_custom_op
logger = init_logger(__name__)
......@@ -129,6 +130,58 @@ class Qwen3Attention(nn.Module):
self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
def rms_rotary_embedding_fuse(
positions: torch.Tensor,
query: torch.Tensor,
key: Optional[torch.Tensor],
head_size: int,
cos_sin_cache: torch.Tensor,
is_neox_style: bool,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
q_bias: Optional[torch.Tensor],
k_bias: Optional[torch.Tensor],
epsilon: float,
) -> None:
from lightop import rms_rotary_embedding_fuse as fused_kernel
fused_kernel(
positions,
query,
key,
head_size,
cos_sin_cache,
is_neox_style,
q_weight,
k_weight,
q_bias,
k_bias,
epsilon,
)
def rms_rotary_embedding_fuse_fake(
positions: torch.Tensor,
query: torch.Tensor,
key: Optional[torch.Tensor],
head_size: int,
cos_sin_cache: torch.Tensor,
is_neox_style: bool,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
q_bias: Optional[torch.Tensor],
k_bias: Optional[torch.Tensor],
epsilon: float,
) -> None:
# Fake impl intentionally left as no-op for graph tracing modes.
pass
if not hasattr(torch.ops.vllm, "rms_rotary_embedding_fuse"):
direct_register_custom_op(
op_name="rms_rotary_embedding_fuse",
op_func=rms_rotary_embedding_fuse,
mutates_args=["query", "key"],
fake_impl=rms_rotary_embedding_fuse_fake,
)
def forward(
self,
positions: torch.Tensor,
......@@ -136,22 +189,49 @@ class Qwen3Attention(nn.Module):
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
# Add qk-norm
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
self.head_dim)
if envs.VLLM_USE_APEX_RN:
q_by_head = self.q_norm.forward_apex(q_by_head)
else:
q_by_head = self.q_norm.forward_cuda(q_by_head)
q = q_by_head.view(q.shape)
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
self.head_dim)
if envs.VLLM_USE_APEX_RN:
k_by_head = self.k_norm.forward_apex(k_by_head)
if envs.VLLM_USE_FUSED_RMS_ROPE:
# Fused RMSNorm + RoPE path through custom op.
cos_sin_cache = self.rotary_emb.cos_sin_cache
if (cos_sin_cache.device != q.device
or cos_sin_cache.dtype != q.dtype):
cos_sin_cache = cos_sin_cache.to(q.device,
dtype=q.dtype,
non_blocking=True)
# Persist the converted cache so we don't re-copy/re-allocate
# on every forward when the original buffer starts on CPU.
self.rotary_emb.cos_sin_cache = cos_sin_cache
q = q.contiguous()
k = k.contiguous()
torch.ops.vllm.rms_rotary_embedding_fuse(
positions,
q,
k,
self.head_dim,
cos_sin_cache,
self.rotary_emb.is_neox_style,
self.q_norm.weight,
self.k_norm.weight,
None,
None,
self.q_norm.variance_epsilon,
)
else:
k_by_head = self.k_norm.forward_cuda(k_by_head)
k = k_by_head.view(k.shape)
q, k = self.rotary_emb(positions, q, k)
# Add qk-norm then RoPE (original path).
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
self.head_dim)
if envs.VLLM_USE_APEX_RN:
q_by_head = self.q_norm.forward_apex(q_by_head)
else:
q_by_head = self.q_norm.forward_cuda(q_by_head)
q = q_by_head.view(q.shape)
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
self.head_dim)
if envs.VLLM_USE_APEX_RN:
k_by_head = self.k_norm.forward_apex(k_by_head)
else:
k_by_head = self.k_norm.forward_cuda(k_by_head)
k = k_by_head.view(k.shape)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
......
......@@ -38,6 +38,19 @@ from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
try:
from vllm.model_executor.layers.fused_moe.router_capture import (
maybe_record_router_logits,
)
except ImportError:
def maybe_record_router_logits(
*,
layer_name: str,
router_logits: torch.Tensor,
top_k: int,
) -> None:
return None
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
......@@ -111,6 +124,8 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
self._router_top_k = int(config.num_experts_per_tok)
self._router_capture_layer_name = prefix
if self.tp_size > config.num_experts:
raise ValueError(
......@@ -140,6 +155,14 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
if not (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling()):
capture_enabled = envs.VLLM_MOE_ROUTER_CAPTURE
if capture_enabled:
maybe_record_router_logits(
layer_name=self._router_capture_layer_name,
router_logits=router_logits,
top_k=self._router_top_k,
)
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
......@@ -453,7 +476,7 @@ class Qwen3MoeModel(nn.Module):
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1'))
self.w8a8_strategy = envs.VLLM_W8A8_BACKEND
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
......
......@@ -37,6 +37,7 @@ from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
from vllm.utils import W8a8GetCacheJSON
from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
import vllm.envs as envs
class TeleChat2Model(LlamaModel):
......@@ -66,8 +67,7 @@ class TeleChat2Model(LlamaModel):
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1'))
self.w8a8_strategy = envs.VLLM_W8A8_BACKEND
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
......
......@@ -10,6 +10,7 @@ from torch.nn import Parameter
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.logger import init_logger
from vllm.model_executor.utils import _make_synced_weight_loader
from vllm.model_executor.layers.dp_attention import get_moe_tp_rank, get_moe_tp_size
__all__ = [
"BasevLLMParameter", "PackedvLLMParameter", "PerTensorScaleParameter",
......@@ -97,6 +98,7 @@ class _ColumnvLLMParameter(BasevLLMParameter):
self._output_dim = output_dim
super().__init__(**kwargs)
self.expect_tp_size = -1
self.enable_dp_attn_moe = False
@property
......@@ -107,6 +109,10 @@ class _ColumnvLLMParameter(BasevLLMParameter):
tp_rank = get_tensor_model_parallel_rank()
if self.expect_tp_size == 1:
tp_rank = 0
if self.enable_dp_attn_moe:
tp_rank = get_moe_tp_rank()
shard_size = self.data.shape[self.output_dim]
loaded_weight = loaded_weight.narrow(self.output_dim,
tp_rank * shard_size, shard_size)
......@@ -129,6 +135,10 @@ class _ColumnvLLMParameter(BasevLLMParameter):
tp_rank = get_tensor_model_parallel_rank()
if self.expect_tp_size == 1:
tp_rank = 0
if self.enable_dp_attn_moe:
tp_rank = get_moe_tp_rank()
param_data = param_data.narrow(self.output_dim, shard_offset,
shard_size)
loaded_weight = loaded_weight.narrow(self.output_dim,
......@@ -174,6 +184,7 @@ class RowvLLMParameter(BasevLLMParameter):
self._input_dim = input_dim
super().__init__(**kwargs)
self.expect_tp_size = -1
self.enable_dp_attn_moe = False
@property
def input_dim(self):
......@@ -183,6 +194,9 @@ class RowvLLMParameter(BasevLLMParameter):
tp_rank = get_tensor_model_parallel_rank()
if self.expect_tp_size == 1:
tp_rank = 0
if self.enable_dp_attn_moe:
tp_rank = get_moe_tp_rank()
shard_size = self.data.shape[self.input_dim]
loaded_weight = loaded_weight.narrow(self.input_dim,
tp_rank * shard_size, shard_size)
......
......@@ -17,7 +17,7 @@ from vllm.utils import cuda_device_count_stateless
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
from vllm.utils import SUPPORT_TC
if not SUPPORT_TC:
os.environ['VLLM_USE_V1'] = '0'
os.environ['VLLM_USE_FLASH_ATTN_PA'] = '0'
......
......@@ -1958,7 +1958,7 @@ class W8a8GetCacheJSON:
self.moe_weight_shapes=[]
arch_name = torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0]
arch_cu = torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count
self.cache_json_data = {}
device_name =arch_name+'_'+str(arch_cu)+'cu'
self.device_name=device_name
self.topk=1
......@@ -2060,19 +2060,27 @@ class W8a8GetCacheJSON:
return self.triton_json_dir+f"/linear_{n}_{k}_block[{block_n},{block_k}]_{self.device_name}.json"
def get_moeint8json_name(self,E,N1,N2,K,TOPK,
block_size:Optional[list]=None,use_int4_w4a8:Optional[bool]=False):
block_size:Optional[list]=None,use_int4_w4a8:Optional[bool]=False,use_int8_w8a8:Optional[bool]=False):
if use_int4_w4a8:
if block_size is not None:
return self.triton_json_dir+f"/MOE_W4A8INT8[{block_size[0]},{block_size[1]}]_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
else:
return self.triton_json_dir+f"/MOE_W4A8INT8_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
return self.triton_json_dir+f"/MOE_W4A8INT8_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
elif use_int8_w8a8:
if block_size is not None:
return self.triton_json_dir + f"/MOE_BLOCKINT8[{block_size[0]},{block_size[1]}]_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
else:
return self.triton_json_dir + f"/MOE_W8A8INT8_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
else:
if block_size is not None:
return self.triton_json_dir+f"/MOE_BLOCKINT8[{block_size[0]},{block_size[1]}]_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
return self.triton_json_dir + f"/MOE_BLOCKFP8[{block_size[0]},{block_size[1]}]_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
else:
return self.triton_json_dir+f"/MOE_W8A8INT8_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
return self.triton_json_dir + f"/MOE_W8A8FP8_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
def get_moeint8_triton_cache(self,file_path,E,N1,N2,K,TOPK):
if file_path in self.cache_json_data:
# 直接返回缓存数据,避免重复读取
return self.cache_json_data[file_path]
cache_json_file=file_path
if os.path.exists(file_path):
......@@ -2088,7 +2096,7 @@ class W8a8GetCacheJSON:
for sub_key, sub_value in value.items():
configs_key= f"{sub_key}_{key}"
configs_dict[configs_key]=sub_value
self.cache_json_data[file_path] = configs_dict
return configs_dict
# Adapted from: https://stackoverflow.com/a/47212782/5082708
......
......@@ -136,6 +136,17 @@ class FlashAttentionBackend(AttentionBackend):
raise ValueError(f"Unknown cache layout format {cache_layout}.")
return key_stride_order, value_stride_order
@staticmethod
def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype:
if kv_cache_dtype in ("fp8", "fp8_e4m3"):
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
return torch.float8_e4m3fn
else:
raise ValueError(f"Unsupported FP8 dtype: {kv_cache_dtype}")
elif kv_cache_dtype in ("fp8_e5m2"):
return torch.float8_e5m2
else:
raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
@dataclass
class FlashAttentionMetadata:
......@@ -589,14 +600,19 @@ class FlashAttentionImpl(AttentionImpl):
)
if self.kv_cache_dtype.startswith("fp8"):
key_cache = key_cache.view(torch.float8_e4m3fn)
value_cache = value_cache.view(torch.float8_e4m3fn)
num_tokens, num_heads, head_size = query.shape
query, _ = ops.scaled_fp8_quant(
query.reshape(
(num_tokens, num_heads * head_size)).contiguous(),
layer._q_scale)
query = query.reshape((num_tokens, num_heads, head_size))
# key_cache = key_cache.view(torch.float8_e4m3fn)
# value_cache = value_cache.view(torch.float8_e4m3fn)
dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
self.kv_cache_dtype)
key_cache = key_cache.view(dtype)
value_cache = value_cache.view(dtype)
if envs.VLLM_USE_QUERY_QUANT:
num_tokens, num_heads, head_size = query.shape
query, _ = ops.scaled_fp8_quant(
query.reshape(
(num_tokens, num_heads * head_size)).contiguous(),
layer._q_scale)
query = query.reshape((num_tokens, num_heads, head_size))
# Compute attention and update output up to `num_actual_tokens`.
use_local_attn = \
......@@ -620,9 +636,10 @@ class FlashAttentionImpl(AttentionImpl):
block_table = attn_metadata.block_table
scheduler_metadata = attn_metadata.scheduler_metadata
descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
# descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
if not current_platform.is_rocm():
descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
flash_attn_varlen_func(
q=query[:num_actual_tokens],
k=key_cache,
......@@ -672,6 +689,9 @@ class FlashAttentionImpl(AttentionImpl):
# q_descale=layer._q_scale.expand(descale_shape),
# k_descale=layer._k_scale.expand(descale_shape),
# v_descale=layer._v_scale.expand(descale_shape),
q_descale=None,
k_descale=layer._k_scale,
v_descale=layer._v_scale,
# num_splits=attn_metadata.max_num_splits,
is_prefix_cache=True,
)
......@@ -729,6 +749,9 @@ class FlashAttentionImpl(AttentionImpl):
# q_descale=layer._q_scale,
# k_descale=layer._k_scale,
# v_descale=layer._v_scale,
q_descale=None,
k_descale=layer._k_scale,
v_descale=layer._v_scale,
)
return output
......@@ -879,12 +902,12 @@ def cascade_attention(
return_softmax_lse=True,
scheduler_metadata=prefix_scheduler_metadata,
# fa_version=fa_version,
# q_descale=q_descale.expand(descale_shape)
# if q_descale is not None else None,
# k_descale=k_descale.expand(descale_shape)
# if k_descale is not None else None,
# v_descale=v_descale.expand(descale_shape)
# if v_descale is not None else None,
q_descale=q_descale.expand(descale_shape)
if q_descale is not None else None,
k_descale=k_descale.expand(descale_shape)
if k_descale is not None else None,
v_descale=v_descale.expand(descale_shape)
if v_descale is not None else None,
is_prefix_cache=True,
)
......@@ -932,12 +955,12 @@ def cascade_attention(
return_softmax_lse=True,
scheduler_metadata=suffix_scheduler_metadata,
# fa_version=fa_version,
# q_descale=q_descale.expand(descale_shape)
# if q_descale is not None else None,
# k_descale=k_descale.expand(descale_shape)
# if k_descale is not None else None,
# v_descale=v_descale.expand(descale_shape)
# if v_descale is not None else None,
q_descale=q_descale.expand(descale_shape)
if q_descale is not None else None,
k_descale=k_descale.expand(descale_shape)
if k_descale is not None else None,
v_descale=v_descale.expand(descale_shape)
if v_descale is not None else None,
is_prefix_cache=True,
)
......
......@@ -217,7 +217,7 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
from lightop import fused_rms_norm_rope_contiguous
from lightop import fused_rms_norm_rope_contiguous, fuse_rmsnorm_rope_quant_qkv
try:
from vllm.vllm_flash_attn import flash_attn_varlen_func
......@@ -1213,6 +1213,14 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
q_ori = q_ori[:num_actual_toks, ...]
decode_q = q_ori[:num_decode_tokens]
prefill_q = q_ori[num_decode_tokens:]
if self.kv_cache_dtype == "auto":
if q.dtype == torch.float16:
kv_cache_dtype_str = "fp16"
elif q.dtype == torch.bfloat16:
kv_cache_dtype_str = "bf16"
else:
kv_cache_dtype_str = self.kv_cache_dtype
# write the latent and rope to kv cache
if kv_cache.numel() > 0:
......@@ -1226,28 +1234,61 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
scale=layer._k_scale,
)
else:
if self.kv_cache_dtype == "auto":
if q.dtype == torch.float16:
kv_cache_dtype_str = "fp16"
elif q.dtype == torch.bfloat16:
kv_cache_dtype_str = "bf16"
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and kv_cache_dtype_str=="fp8_e4m3" and envs.VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA:
if has_prefill:
fused_rms_norm_rope_contiguous(
positions[:num_actual_toks, ...],
q,
k_pe.squeeze(1),
k_c_normed, # not normed
key_normed[:num_actual_toks, ...], # normed
weight,
cos_sin_cache,
attn_metadata.slot_mapping.flatten(),
kv_cache,
kv_cache_dtype_str,
1.0,
False,
1e-6,
)
else:
q_tensor = torch.randn(q.shape[0], num_local_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, dtype=q.dtype, device=q.device)
q_quant = torch.empty_like(q_tensor, dtype=torch.float8_e4m3fn, device=q.device)
q_scale = torch.empty(q.shape[0], dtype=torch.float32, device=q.device)
fuse_rmsnorm_rope_quant_qkv(
positions[:num_actual_toks, ...],
query_nope,
q,
q_quant,
q_scale,
k_pe.squeeze(1),
k_c_normed, # not normed
key_normed[:num_actual_toks, ...], # normed
weight,
cos_sin_cache,
attn_metadata.slot_mapping.flatten(),
kv_cache,
kv_cache_dtype_str,
1.0,
False,
1e-6,
)
else:
kv_cache_dtype_str = self.kv_cache_dtype
fused_rms_norm_rope_contiguous(
positions[:num_actual_toks, ...],
q,
k_pe.squeeze(1),
k_c_normed, # not normed
key_normed[:num_actual_toks, ...], # normed
weight,
cos_sin_cache,
attn_metadata.slot_mapping.flatten(),
kv_cache,
kv_cache_dtype_str,
1.0,
False,
1e-6,
)
fused_rms_norm_rope_contiguous(
positions[:num_actual_toks, ...],
q,
k_pe.squeeze(1),
k_c_normed, # not normed
key_normed[:num_actual_toks, ...], # normed
weight,
cos_sin_cache,
attn_metadata.slot_mapping.flatten(),
kv_cache,
kv_cache_dtype_str,
1.0,
False,
1e-6,
)
if has_prefill:
if envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
......@@ -1259,12 +1300,15 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
if has_decode:
assert attn_metadata.decode is not None
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and self.kv_cache_dtype=="fp8_e4m3" and envs.VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA:
decode_q = q_quant[:num_decode_tokens]
decode_q_nope, decode_q_pe = decode_q.split(
[self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
# Convert from (B, N, P) to (N, B, P)
decode_q_nope = decode_q_nope.transpose(0, 1)
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T)
# todo: bmm support
decode_ql_nope = torch.bmm(q_scale, decode_q_nope, self.W_UK_T) if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and self.kv_cache_dtype=="fp8_e4m3" and envs.VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA else torch.bmm(decode_q_nope, self.W_UK_T)
# Convert from (N, B, L) to (B, N, L)
decode_ql_nope = decode_ql_nope.transpose(0, 1)
......
......@@ -12,6 +12,7 @@ from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
flash_mla_with_kvcache_q_nope_pe,
get_mla_metadata,
flash_mla_with_kvcache_fp8,
flash_mla_with_kvcache_fp8_with_cat,
get_mla_decoding_metadata_dense_fp8,
is_flashmla_supported)
from vllm.logger import init_logger
......@@ -181,31 +182,48 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
assert attn_metadata.decode is not None
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and kv_cache_dtype == "fp8_e4m3" and envs.VLLM_USE_FLASH_MLA_FP8:
if envs.VLLM_USE_OPT_CAT:
if q_nope.shape[0] < 1024:
from vllm.v1.attention.backends.mla.test_concat import concat_helper_decode
q = concat_helper_decode(q_nope, q_pe, dim=2)\
.unsqueeze(1)
if envs.VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA:
o, _ = flash_mla_with_kvcache_fp8_with_cat(
q_nope=q_nope.unsqueeze(1),
q_pe=q_pe.unsqueeze(1),
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens,
head_dim_v=self.kv_lora_rank,
tile_scheduler_metadata=attn_metadata.decode.
tile_scheduler_metadata,
num_splits=attn_metadata.decode.num_splits,
softmax_scale=self.scale,
causal=True,
descale_q = q_scale,
descale_k = k_scale,
)
else:
if envs.VLLM_USE_OPT_CAT:
if q_nope.shape[0] < 1024:
from vllm.v1.attention.backends.mla.test_concat import concat_helper_decode
q = concat_helper_decode(q_nope, q_pe, dim=2)\
.unsqueeze(1)
else:
q = torch.cat([q_nope, q_pe], dim=-1)\
.unsqueeze(1) # Add seqlen dim of 1 (decode)
else:
q = torch.cat([q_nope, q_pe], dim=-1)\
.unsqueeze(1) # Add seqlen dim of 1 (decode)
else:
q = torch.cat([q_nope, q_pe], dim=-1)\
.unsqueeze(1) # Add seqlen dim of 1 (decode)
o, _ = flash_mla_with_kvcache_fp8(
q=q.to(torch.float8_e4m3fn),
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2).view(torch.float8_e4m3fn), # Add head dim of 1
block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens,
head_dim_v=self.kv_lora_rank,
tile_scheduler_metadata=attn_metadata.decode.
tile_scheduler_metadata,
num_splits=attn_metadata.decode.num_splits,
softmax_scale=self.scale,
causal=True,
descale_q=q_scale,
descale_k=k_scale,
)
.unsqueeze(1) # Add seqlen dim of 1 (decode)
o, _ = flash_mla_with_kvcache_fp8(
q=q.to(torch.float8_e4m3fn),
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2).view(torch.float8_e4m3fn), # Add head dim of 1
block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens,
head_dim_v=self.kv_lora_rank,
tile_scheduler_metadata=attn_metadata.decode.
tile_scheduler_metadata,
num_splits=attn_metadata.decode.num_splits,
softmax_scale=self.scale,
causal=True,
descale_q=q_scale,
descale_k=k_scale,
)
else:
if not envs.VLLM_USE_CAT_MLA or kv_cache_dtype == "fp8_e4m3":
......
......@@ -281,20 +281,27 @@ class Scheduler(SchedulerInterface):
num_draft_tokens=num_draft_tokens,
num_lookahead_tokens=self.num_lookahead_tokens)
if new_blocks is None:
# The request cannot be scheduled.
# Preempt the lowest-priority request.
if self.policy == SchedulingPolicy.PRIORITY:
preempted_req = max(
self.running,
key=lambda r: (r.priority, r.arrival_time),
)
if new_blocks is None:
if self.use_pp:
preemptable_reqs = [r for r in self.running if
r.num_tokens_with_spec != r.num_computed_tokens]
else:
preemptable_reqs = self.running
# The request cannot be scheduled.
# Preempt the lowest-priority request.
if self.policy == SchedulingPolicy.PRIORITY:
preempted_req = max(
preemptable_reqs,
key=lambda r: (r.priority, r.arrival_time),
)
else:
preempted_req = preemptable_reqs[-1]
self.running.remove(preempted_req)
else:
preempted_req = self.running.pop()
self.kv_cache_manager.free(preempted_req)
preempted_req.status = RequestStatus.PREEMPTED
preempted_req.num_computed_tokens = 0
preempted_req.spec_token_ids = []
if self.log_stats:
preempted_req.record_event(
EngineCoreEventType.PREEMPTED, scheduled_timestamp)
......@@ -901,20 +908,26 @@ class Scheduler(SchedulerInterface):
num_draft_tokens=num_draft_tokens,
num_lookahead_tokens=self.num_lookahead_tokens)
if new_blocks is None:
if self.use_pp:
preemptable_reqs = [r for r in self.running if
r.num_tokens_with_spec != r.num_computed_tokens]
else:
preemptable_reqs = self.running
# The request cannot be scheduled.
# Preempt the lowest-priority request.
if self.policy == SchedulingPolicy.PRIORITY:
preempted_req = max(
self.running,
preemptable_reqs,
key=lambda r: (r.priority, r.arrival_time),
)
self.running.remove(preempted_req)
else:
preempted_req = self.running.pop()
preempted_req = preemptable_reqs[-1]
self.running.remove(preempted_req)
self.kv_cache_manager.free(preempted_req)
preempted_req.status = RequestStatus.PREEMPTED
preempted_req.num_computed_tokens = 0
preempted_req.spec_token_ids = []
if self.log_stats:
preempted_req.record_event(
EngineCoreEventType.PREEMPTED, scheduled_timestamp)
......@@ -1051,14 +1064,15 @@ class Scheduler(SchedulerInterface):
def schedule(self) -> SchedulerOutput:
if envs.VLLM_USE_PD_SPLIT:
return self.schedule_split_pd()
else:
if self.connector is not None:
return self.schedule_default()
if self.full_cuda_graph and self.use_mla and self.num_spec_tokens > 0 :
return self.schedule_split_pd()
if self.use_mla:
if self.full_cuda_graph and self.num_spec_tokens > 0:
return self.schedule_split_pd()
else:
return self.schedule_default()
else:
return self.schedule_default()
return self.schedule_split_pd()
else:
return self.schedule_default()
def _update_after_schedule(
self,
......@@ -1101,13 +1115,14 @@ class Scheduler(SchedulerInterface):
req_id = req.request_id
req_ids.append(req_id)
num_tokens = req.num_generated_token_ids
if self.use_pp:
# When using PP, the scheduler sends the sampled tokens back,
# because there's no direct communication between the first-
# stage worker and the last-stage worker. Otherwise, we don't
# need to send the sampled tokens back because the model runner
# will cache them.
token_ids = req.all_token_ids[-num_tokens:]
token_ids = req.all_token_ids[-num_tokens:] if num_tokens > 0 else []
new_token_ids.append(token_ids)
new_block_ids.append(req_to_new_block_ids[req_id])
num_computed_tokens.append(req.num_computed_tokens)
......@@ -1241,7 +1256,7 @@ class Scheduler(SchedulerInterface):
scheduled_spec_token_ids = (
scheduler_output.scheduled_spec_decode_tokens.get(req_id))
request.num_generated_token_ids = 1
request.num_generated_token_ids = len(generated_token_ids)
if scheduled_spec_token_ids:
# num_computed_tokens represents the number of tokens
# processed in the current step, considering scheduled
......@@ -1253,7 +1268,6 @@ class Scheduler(SchedulerInterface):
num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 -
len(generated_token_ids))
request.num_computed_tokens -= num_tokens_rejected
request.num_generated_token_ids = len(generated_token_ids)
spec_decoding_stats = self.make_spec_decoding_stats(
spec_decoding_stats,
num_draft_tokens=len(scheduled_spec_token_ids),
......
......@@ -12,8 +12,9 @@ from vllm.attention.layer import Attention
from vllm.config import (CompilationLevel, VllmConfig,
get_layers_from_vllm_config)
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import set_forward_context
from vllm.forward_context import DPMetadata, set_forward_context, get_warming_up
from vllm.logger import init_logger
import vllm.envs as envs
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import supports_multimodal
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
......@@ -92,6 +93,12 @@ class EagleProposer:
self.dp_size = vllm_config.parallel_config.data_parallel_size
self.enable_expert_parallel = vllm_config.parallel_config.enable_expert_parallel
self.enable_dp_attention = vllm_config.parallel_config.enable_dp_attention
self.attn_tp_size = vllm_config.parallel_config.tensor_parallel_size
self.ep_sp = False
if self.enable_expert_parallel and self.dp_size > 1 and self.attn_tp_size > 1:
self.ep_sp = True
def propose(
self,
......@@ -187,6 +194,12 @@ class EagleProposer:
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
else:
num_input_tokens = num_tokens
if self.enable_dp_attention:
num_input_tokens = round_up(num_input_tokens, self.attn_tp_size)
# num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
# num_input_tokens += num_pad
# copy inputs to buffer for cudagraph
self.positions[:num_tokens] = target_positions
......@@ -275,6 +288,13 @@ class EagleProposer:
input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
else:
input_batch_size = batch_size
# dp attention need all dp rank process same number tokens
if self.enable_dp_attention:
input_batch_size = round_up(input_batch_size, self.attn_tp_size)
num_pad, _ = self.get_dp_padding(input_batch_size)
input_batch_size += num_pad
attn_metadata.num_actual_tokens = batch_size
attn_metadata.max_query_len = 1
attn_metadata.query_start_loc = self.arange[:batch_size + 1]
......@@ -369,6 +389,7 @@ class EagleProposer:
attn_metadata.num_decode_tokens)
self.attn_metadata_cudagraph.num_prefills = (
attn_metadata.num_prefills)
self.attn_metadata_cudagraph.decode.seq_lens[:attn_metadata.num_decode_tokens] = (
attn_metadata.decode.seq_lens)
......@@ -516,34 +537,118 @@ class EagleProposer:
logger.info("Loading EAGLE LM head weights from the target model.")
self.model.lm_head = target_language_model.lm_head
def get_dp_padding(self,
num_tokens: int) -> tuple[int, Optional[torch.Tensor]]:
dp_size = self.vllm_config.parallel_config.data_parallel_size
dp_rank = self.vllm_config.parallel_config.data_parallel_rank
# For DP: Don't pad when setting enforce_eager.
# This lets us set enforce_eager on the prefiller in a P/D setup and
# still use CUDA graphs (enabled by this padding) on the decoder.
#
# TODO(tms) : There are many cases where padding is enabled for
# prefills, causing unnecessary and excessive padding of activations.
if dp_size == 1 or self.vllm_config.model_config.enforce_eager:
# Early exit.
return 0, None
try:
num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
num_tokens, dp_size, dp_rank)
max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp).item()
num_tokens_after_padding = torch.tensor([max_tokens_across_dp_cpu] *
dp_size,
device="cpu",
dtype=torch.int32)
return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding
except (RuntimeError, AttributeError) as e:
# DP group may not be initialized yet during dummy run
# Skip padding in this case
logger.debug(
"Skipping DP padding in eagle get_dp_padding due to: %s", e)
return 0, None
@torch.inference_mode()
def dummy_run(
self,
num_tokens: int,
attn_metadata: Optional[dict[str, Any]] = None,
num_tokens_across_dp: Optional[torch.Tensor] = None,
) -> None:
if attn_metadata is not None and self.attn_metadata_cudagraph is None:
self.attn_metadata_cudagraph = attn_metadata[
self.attn_layer_names[0]]
# Padding for DP
num_input_tokens = num_tokens
# num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
# num_input_tokens += num_pad
with set_forward_context(attn_metadata,
self.vllm_config,
num_tokens=num_tokens):
num_tokens=num_tokens,
num_tokens_across_dp=num_tokens_across_dp):
self.model(
self.input_ids[:num_tokens],
self.positions[:num_tokens],
self.hidden_states[:num_tokens],
self.input_ids[:num_input_tokens],
self.positions[:num_input_tokens],
self.hidden_states[:num_input_tokens],
)
if self.dp_size > 1 and self.enable_expert_parallel and self.num_speculative_tokens > 1:
for _ in range(self.num_speculative_tokens - 1):
with set_forward_context(attn_metadata,
self.vllm_config,
num_tokens=num_tokens):
self.model(
self.input_ids[:num_tokens],
self.positions[:num_tokens],
self.hidden_states[:num_tokens],
)
if self.dp_size > 1 and (self.enable_expert_parallel or self.enable_dp_attention) and self.num_speculative_tokens > 1:
num_tokens = 1
if self.enable_dp_attention or self.ep_sp:
num_tokens = round_up(num_tokens, self.attn_tp_size)
# dp attention need all dp rank process same number tokens
if self.enable_dp_attention:
num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
num_tokens += num_pad
if not get_warming_up():
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=self.runner.query_start_loc[:num_tokens + 1],
seq_lens=self.runner.seq_lens[:num_tokens],
num_reqs=num_tokens,
num_actual_tokens=num_tokens,
max_query_len=num_tokens,
slot_mapping=self.runner.slot_mapping[:num_tokens],
spec_layer_decoding=True
)
assert self.runner is not None
# FIXME: need to consider multiple kv_cache_groups
attn_metadata = self.runner.attn_metadata_builders[0].build_for_cudagraph_capture(
common_attn_metadata=common_attn_metadata
)
for i in range(self.num_speculative_tokens - 1):
if self.attn_metadata_cudagraph is not None:
if i == 0:
attn_metadata_cudagraph = self.attn_metadata_cudagraph
attn_metadata_cudagraph.num_actual_tokens = num_tokens
attn_metadata_cudagraph.num_decodes = num_tokens
attn_metadata_cudagraph.num_decode_tokens = num_tokens
attn_metadata_cudagraph.slot_mapping[:num_tokens] = (
attn_metadata.slot_mapping)
attn_metadata_cudagraph.decode.seq_lens[:num_tokens] = (
attn_metadata.decode.seq_lens)
attn_metadata_cudagraph.query_start_loc[:num_tokens + 1] = (
attn_metadata.query_start_loc)
attn_metadata_cudagraph.decode.block_table[:num_tokens] = (
attn_metadata.decode.block_table)
with set_forward_context(attn_metadata,
self.vllm_config,
num_tokens=num_tokens,
num_tokens_across_dp=num_tokens_across_dp):
self.model(
self.input_ids[:num_tokens],
self.positions[:num_tokens],
self.hidden_states[:num_tokens],
)
def validate_same_kv_cache_group(self,
kv_cache_config: KVCacheConfig) -> None:
......
......@@ -735,22 +735,35 @@ class InputBatch:
self, repeat_counts: torch.Tensor
) -> SamplingMetadata:
num_reqs = self.num_reqs
repeat_counts_cpu = repeat_counts
# `repeat_counts` is expected to be a CPU torch tensor, but some
# call sites may pass a NumPy array (or other array-likes). Normalize
# to a CPU tensor to keep downstream ops (e.g. repeat_interleave)
# consistent and avoid hard crashes.
if isinstance(repeat_counts, torch.Tensor):
repeat_counts_cpu = repeat_counts.to(device="cpu")
else:
repeat_counts_cpu = torch.as_tensor(repeat_counts, device="cpu")
all_greedy = self.all_greedy
all_random = self.all_random
# For reject-sampling optimization, force greedy sampling to keep
# rejection sampler assumptions (per-request shapes) intact.
def _expand_cpu_to_gpu(
t: Optional[torch.Tensor],
t: Optional[object],
*,
dtype: Optional[torch.dtype] = None,
) -> Optional[torch.Tensor]:
if t is None:
return None
base = t[:num_reqs]
if repeat_counts_cpu is not None:
base = base.repeat_interleave(repeat_counts_cpu, dim=0)
# `t` should be a CPU torch tensor, but can be a NumPy array view
# (e.g. created via `tensor.numpy()`). Convert if needed.
if isinstance(t, torch.Tensor):
base = t[:num_reqs]
elif isinstance(t, np.ndarray):
base = torch.from_numpy(t[:num_reqs])
else:
base = torch.as_tensor(t, device="cpu")[:num_reqs]
base = base.repeat_interleave(repeat_counts_cpu, dim=0)
return base.to(device=self.device,
dtype=dtype if dtype is not None else None,
non_blocking=True)
......
......@@ -339,6 +339,8 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
if self.enable_expert_parallel and self.dp_size > 1 and self.tp_size > 1:
self.ep_sp = True
self.enable_dp_attention = self.parallel_config.enable_dp_attention
def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
"""
Update the order of requests in the batch based on the attention
......@@ -512,14 +514,12 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
new_token_ids = req_data.new_token_ids[i]
# Add the sampled token(s) from the previous step (if any).
# This doesn't include "unverified" tokens like spec tokens.
num_new_tokens = (num_computed_tokens + len(new_token_ids) -
req_state.num_tokens)
num_new_tokens = len(new_token_ids)
if num_new_tokens == 1:
# Avoid slicing list in most common case.
req_state.output_token_ids.append(new_token_ids[-1])
elif num_new_tokens > 0:
req_state.output_token_ids.extend(
new_token_ids[-num_new_tokens:])
new_token_ids)
if len(spec_token_ids) > 0:
req_state.spec_token_ids = spec_token_ids
......@@ -545,20 +545,23 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# Update the persistent batch.
self.input_batch.num_computed_tokens_cpu[req_index] = (
num_computed_tokens)
self.input_batch.block_table.append_row(new_block_ids, req_index)
if resumed_from_preemption:
self.input_batch.block_table.add_row(new_block_ids, req_index)
else:
self.input_batch.block_table.append_row(new_block_ids, req_index)
# For the last rank, we don't need to update the token_ids_cpu
# because the sampled tokens are already cached.
if not is_last_rank:
# Add new_token_ids to token_ids_cpu.
start_token_index = num_computed_tokens
end_token_index = num_computed_tokens + 1
self.input_batch.token_ids_cpu[
req_index,
start_token_index:end_token_index] = new_token_ids[-1]
self.input_batch.num_tokens_no_spec[
req_index] = end_token_index
self.input_batch.num_tokens[req_index] = end_token_index
if len(new_token_ids) > 0:
end_token_index = num_computed_tokens + 1
self.input_batch.token_ids_cpu[
req_index,
start_token_index:end_token_index] = new_token_ids[-1]
self.input_batch.num_tokens_no_spec[
req_index] = end_token_index
self.input_batch.num_tokens[req_index] = end_token_index
# Add spec_token_ids to token_ids_cpu.
if spec_token_ids:
......@@ -1274,8 +1277,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
#
# TODO(tms) : There are many cases where padding is enabled for
# prefills, causing unnecessary and excessive padding of activations.
if dp_size == 1 or self.vllm_config.model_config.enforce_eager or envs.VLLM_ALL2ALL_BACKEND != 'naive':
if dp_size == 1 or self.vllm_config.model_config.enforce_eager:
# Early exit.
return 0, None
......@@ -1354,7 +1356,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
if self.ep_sp:
if self.ep_sp or self.enable_dp_attention:
num_input_tokens = round_up(num_scheduled_tokens, self.tp_size)
if (self.use_cuda_graph
and num_input_tokens <= self.cudagraph_batch_sizes[-1]):
......@@ -1597,6 +1599,11 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
seq_len = (req_state.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])
if seq_len < req_state.num_tokens:
# If we have already started decoding, seeing a "partial prefill"
# condition is suspicious and can lead to discarding the sampled
# token forever (PP stall).
if req_state.output_token_ids:
continue
# Ignore the sampled token for partial prefills.
# Rewind the generator state as if the token was not sampled.
# This relies on cuda-specific torch-internal impl details
......@@ -1674,7 +1681,9 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
spec_decode_metadata,
attn_metadata,
)
if spec_token_ids is not None:
for i in discard_sampled_tokens_req_indices:
spec_token_ids[i].clear()
# Clear KVConnector state after all KVs are generated.
if has_kv_transfer_group():
get_kv_transfer_group().clear_connector_metadata()
......@@ -2084,7 +2093,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
Randomize input_ids if VLLM_RANDOMIZE_DP_DUMMY_INPUTS is set.
This is to help balance expert-selection
- during profile_run
- during DP rank dummy run
- during DP rank dummy run
"""
dp_size = self.vllm_config.parallel_config.data_parallel_size
randomize_inputs = envs.VLLM_RANDOMIZE_DP_DUMMY_INPUTS and dp_size > 1
......@@ -2115,13 +2124,12 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
skip_eplb: bool = False,
is_profile: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
if self.ep_sp:
if self.ep_sp or self.enable_dp_attention:
if num_tokens < self.tp_size:
num_tokens = self.tp_size
# Padding for DP
num_tokens_across_dp = 0
num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
num_tokens += num_pad
......@@ -2142,13 +2150,13 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
min_tokens_per_req = (1 + self.speculative_config.num_lookahead_slots)
num_reqs = num_tokens // min_tokens_per_req
if self.ep_sp:
if self.ep_sp or self.enable_dp_attention:
num_actual_tokens = round_down(num_tokens, 1 + self.speculative_config.num_lookahead_slots)
num_reqs = num_actual_tokens // min_tokens_per_req
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
if not self.ep_sp:
if not (self.ep_sp or self.enable_dp_attention):
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
else:
if self.speculative_config is not None:
......@@ -2240,7 +2248,8 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
if self.speculative_config and self.speculative_config.use_eagle() and not is_profile:
#assert isinstance(self.drafter, EagleProposer)
if hasattr(self, 'drafter') and isinstance(self.drafter, EagleProposer):
self.drafter.dummy_run(num_tokens, attn_metadata)
self.drafter.dummy_run(num_tokens, attn_metadata,
num_tokens_across_dp=num_tokens_across_dp)
# This is necessary to avoid blocking DP.
# For dummy runs, we typically skip EPLB since we don't have any real
......@@ -3213,7 +3222,7 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
if self.ep_sp:
if self.ep_sp or self.enable_dp_attention:
num_input_tokens = round_up(num_scheduled_tokens, self.tp_size)
if (self.use_cuda_graph
and num_input_tokens <= self.cudagraph_batch_sizes[-1]):
......@@ -3455,6 +3464,11 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
seq_len = (req_state.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])
if seq_len < req_state.num_tokens:
# If we have already started decoding, seeing a "partial prefill"
# condition is suspicious and can lead to discarding the sampled
# token forever (PP stall).
if req_state.output_token_ids:
continue
# Ignore the sampled token for partial prefills.
# Rewind the generator state as if the token was not sampled.
# This relies on cuda-specific torch-internal impl details
......@@ -3476,7 +3490,6 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
hidden_states[:num_scheduled_tokens],
scheduler_output,
)
#-----------------------------------
# Get the valid generated tokens.
sampled_token_ids = sampler_output.sampled_token_ids
max_gen_len = sampled_token_ids.shape[-1]
......@@ -3516,6 +3529,9 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
# Mask out the sampled tokens that should not be sampled.
for i in discard_sampled_tokens_req_indices:
valid_sampled_token_ids[i].clear()
if spec_token_ids is not None:
for i in discard_sampled_tokens_req_indices:
spec_token_ids[i].clear()
# Cache the sampled tokens in the model runner, so that the scheduler
# doesn't need to send them back.
......@@ -3680,4 +3696,4 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
if envs.VLLM_USE_ZERO_MTP:
GPUModelRunner=GPUModelRunnerMTP
else:
GPUModelRunner=GPUModelRunnerBase
\ No newline at end of file
GPUModelRunner=GPUModelRunnerBase
......@@ -17,6 +17,7 @@ from vllm.distributed import (ensure_model_parallel_initialized,
set_custom_all_reduce)
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.model_executor.layers.dp_attention import initialize_dp_attention
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
......@@ -30,6 +31,7 @@ from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.v1.worker.worker_base import WorkerBase
from vllm.zero_overhead.utils import zero_overhead_stream
from vllm.zero_overhead.v1.gpu_model_runner import V1ZeroModelRunner
from vllm.forward_context import (set_warming_up, get_warming_up)
logger = init_logger(__name__)
......@@ -260,6 +262,7 @@ class Worker(WorkerBase):
# warm up sizes that are not in cudagraph capture sizes,
# but users still want to compile for better performance,
# e.g. for the max-num-batched token size in chunked prefill.
set_warming_up(True)
warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy()
if not self.model_config.enforce_eager:
warmup_sizes = [
......@@ -297,6 +300,7 @@ class Worker(WorkerBase):
# Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling.
set_random_seed(self.model_config.seed)
set_warming_up(False)
def get_model(self) -> nn.Module:
return self.model_runner.get_model()
......@@ -399,6 +403,9 @@ def init_worker_distributed_environment(
ensure_kv_transfer_initialized(vllm_config)
if vllm_config.parallel_config.enable_dp_attention:
initialize_dp_attention(vllm_config, backend)
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
# Check if the GPU supports the dtype.
......
......@@ -112,6 +112,12 @@ class V1ZeroEagleProposer(EagleProposer):
else:
num_input_tokens = num_tokens
if self.enable_dp_attention:
num_input_tokens = round_up(num_input_tokens, self.attn_tp_size)
# num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
# num_input_tokens += num_pad
# copy inputs to buffer for cudagraph
self.positions[:num_tokens] = target_positions
self.hidden_states[:num_tokens] = target_hidden_states
......@@ -199,6 +205,13 @@ class V1ZeroEagleProposer(EagleProposer):
input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
else:
input_batch_size = batch_size
# dp attention need all dp rank process same number tokens
if self.enable_dp_attention:
input_batch_size = round_up(input_batch_size, self.attn_tp_size)
num_pad, _ = self.get_dp_padding(input_batch_size)
input_batch_size += num_pad
attn_metadata.num_actual_tokens = batch_size
attn_metadata.max_query_len = 1
attn_metadata.query_start_loc = self.arange[:batch_size + 1]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment