Commit 02689420 authored by xuxz's avatar xuxz
Browse files

Merge branch 'v0.9.2-dev' into 'v0.9.2-dev-add_connector'

# Conflicts:
#   vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py
parents ef362942 fa683b07
...@@ -235,12 +235,15 @@ class DeepSeekMTP(nn.Module, SupportsPP): ...@@ -235,12 +235,15 @@ class DeepSeekMTP(nn.Module, SupportsPP):
("qa_kva_proj", "kv_a_proj_with_mqa", 1) ("qa_kva_proj", "kv_a_proj_with_mqa", 1)
] ]
stacked_params_mapping += fused_params_mapping stacked_params_mapping += fused_params_mapping
enable_shared_experts_fusion = envs.VLLM_ENABLE_SHARED_EXPERTS_FUSION and self.config.n_shared_experts > 0
expert_params_mapping = FusedMoE.make_expert_params_mapping( expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj", ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts) num_experts=self.config.n_routed_experts + (
self.config.n_shared_experts
if enable_shared_experts_fusion else 0
))
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: set[str] = set() loaded_params: set[str] = set()
...@@ -251,6 +254,16 @@ class DeepSeekMTP(nn.Module, SupportsPP): ...@@ -251,6 +254,16 @@ class DeepSeekMTP(nn.Module, SupportsPP):
if spec_layer is None: if spec_layer is None:
continue continue
name = self._rewrite_spec_layer_name(spec_layer, name) name = self._rewrite_spec_layer_name(spec_layer, name)
# Assuame num of shared experts is only one.
if envs.VLLM_ENABLE_SHARED_EXPERTS_FUSION \
and self.config.n_shared_experts > 0 \
and "mlp.shared_experts" in name:
name = name.replace(
"mlp.shared_experts",
f"mlp.experts.{self.config.n_routed_experts}",
)
for (param_name, weight_name, shard_id) in stacked_params_mapping: for (param_name, weight_name, shard_id) in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below). # Skip non-stacked layers and experts (experts handled below).
if weight_name not in name: if weight_name not in name:
...@@ -273,7 +286,9 @@ class DeepSeekMTP(nn.Module, SupportsPP): ...@@ -273,7 +286,9 @@ class DeepSeekMTP(nn.Module, SupportsPP):
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
if envs.USE_FUSED_RMS_QUANT and envs.VLLM_USE_FUSED_QA_KVA_GEMM and (("q_a_proj" in old_weight_name) or ("kv_a_proj_with_mqa" in old_weight_name)): if envs.USE_FUSED_RMS_QUANT \
and envs.VLLM_USE_FUSED_QA_KVA_GEMM \
and (("q_a_proj" in old_weight_name) or ("kv_a_proj_with_mqa" in old_weight_name)):
weight_loader(param, loaded_weight, old_weight_name) weight_loader(param, loaded_weight, old_weight_name)
else: else:
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
...@@ -361,340 +376,4 @@ class DeepSeekMTP(nn.Module, SupportsPP): ...@@ -361,340 +376,4 @@ class DeepSeekMTP(nn.Module, SupportsPP):
# treat rest weights as weights for transformer layer block # treat rest weights as weights for transformer layer block
name = name.replace(f"model.layers.{spec_layer}.", name = name.replace(f"model.layers.{spec_layer}.",
f"model.layers.{spec_layer}.mtp_block.") f"model.layers.{spec_layer}.mtp_block.")
return name return name
\ No newline at end of file
# # SPDX-License-Identifier: Apache-2.0
# # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# import os
# import re
# from collections.abc import Iterable
# from typing import Iterable, Optional
# import torch
# import torch.nn as nn
# from transformers import PretrainedConfig
# from vllm.config import CacheConfig, ModelConfig, VllmConfig
# from vllm.model_executor.layers.fused_moe import FusedMoE
# from vllm.model_executor.layers.layernorm import RMSNorm
# from vllm.model_executor.layers.logits_processor import LogitsProcessor
# from vllm.model_executor.layers.quantization import QuantizationConfig
# from vllm.model_executor.layers.vocab_parallel_embedding import (
# ParallelLMHead, VocabParallelEmbedding)
# from vllm.model_executor.model_loader.weight_utils import default_weight_loader
# from vllm.model_executor.sampling_metadata import SamplingMetadata
# from vllm.sequence import IntermediateTensors
# from vllm.compilation.decorators import support_torch_compile
# from .deepseek_v2 import (DeepseekV2DecoderLayer,
# get_spec_layer_idx_from_weight_name)
# from .interfaces import SupportsPP
# from .utils import maybe_prefix
# from vllm import _custom_ops as ops
# from vllm.model_executor.layers.quantization.blockwise_int8 import BlockInt8Config
# class SharedHead(nn.Module):
# def __init__(
# self,
# config: PretrainedConfig,
# quant_config: Optional[QuantizationConfig] = None,
# ) -> None:
# super().__init__()
# self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# self.head = ParallelLMHead(config.vocab_size,
# config.hidden_size,
# quant_config=quant_config)
# def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# return self.norm(hidden_states)
# class DeepSeekMultiTokenPredictorLayer(nn.Module):
# def __init__(
# self,
# config: PretrainedConfig,
# prefix: str,
# model_config: ModelConfig,
# cache_config: Optional[CacheConfig] = None,
# quant_config: Optional[QuantizationConfig] = None,
# ) -> None:
# super().__init__()
# self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# self.eh_proj = nn.Linear(config.hidden_size * 2,
# config.hidden_size,
# bias=False)
# self.shared_head = SharedHead(config=config, quant_config=quant_config)
# self.mtp_block = DeepseekV2DecoderLayer(config, prefix, model_config,
# cache_config, quant_config)
# def forward(
# self,
# input_ids: torch.Tensor,
# positions: torch.Tensor,
# previous_hidden_states: torch.Tensor,
# inputs_embeds: Optional[torch.Tensor] = None,
# spec_step_index: int = 0,
# ) -> torch.Tensor:
# assert inputs_embeds is not None
# # masking inputs at position 0, as not needed by MTP
# inputs_embeds[positions == 0] = 0
# inputs_embeds = self.enorm(inputs_embeds)
# previous_hidden_states = self.hnorm(previous_hidden_states)
# hidden_states = self.eh_proj(
# torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
# hidden_states, residual = self.mtp_block(positions=positions,
# hidden_states=hidden_states,
# residual=None)
# hidden_states = residual + hidden_states
# return hidden_states
# class DeepSeekMultiTokenPredictor(nn.Module):
# def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
# super().__init__()
# config = vllm_config.model_config.hf_config
# self.mtp_start_layer_idx = config.num_hidden_layers
# self.num_mtp_layers = config.num_nextn_predict_layers
# # to map the exact layer index from weights
# self.layers = torch.nn.ModuleDict({
# str(idx):
# DeepSeekMultiTokenPredictorLayer(
# config,
# f"{prefix}.layers.{idx}",
# model_config=vllm_config.model_config,
# cache_config=vllm_config.cache_config,
# quant_config=vllm_config.quant_config,
# )
# for idx in range(self.mtp_start_layer_idx,
# self.mtp_start_layer_idx + self.num_mtp_layers)
# })
# self.embed_tokens = VocabParallelEmbedding(
# config.vocab_size,
# config.hidden_size,
# )
# self.logits_processor = LogitsProcessor(config.vocab_size)
# def forward(
# self,
# input_ids: torch.Tensor,
# positions: torch.Tensor,
# previous_hidden_states: torch.Tensor,
# inputs_embeds: Optional[torch.Tensor] = None,
# spec_step_idx: int = 0,
# ) -> torch.Tensor:
# if inputs_embeds is None:
# inputs_embeds = self.embed_tokens(input_ids)
# current_step_idx = (spec_step_idx % self.num_mtp_layers)
# return self.layers[str(self.mtp_start_layer_idx + current_step_idx)](
# input_ids,
# positions,
# previous_hidden_states,
# inputs_embeds,
# current_step_idx,
# )
# def compute_logits(
# self,
# hidden_states: torch.Tensor,
# sampling_metadata: SamplingMetadata,
# spec_step_idx: int = 0,
# ) -> torch.Tensor:
# current_step_idx = (spec_step_idx % self.num_mtp_layers)
# mtp_layer = self.layers[str(self.mtp_start_layer_idx +
# current_step_idx)]
# logits = self.logits_processor(mtp_layer.shared_head.head,
# mtp_layer.shared_head(hidden_states),
# sampling_metadata)
# return logits
# @support_torch_compile
# class DeepSeekMTP(nn.Module, SupportsPP):
# def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
# super().__init__()
# self.config = vllm_config.model_config.hf_config
# quant_config = vllm_config.quant_config
# self.quant_method = None
# if quant_config is not None:
# self.quant_method = quant_config.get_name()
# os.environ['LLAMA_NN'] = '0'
# os.environ['LM_NN'] = '0'
# # The AWQ layer of MTP uses BlockInt8W8A8.
# if self.quant_method == "moe_wna16" or self.quant_method == "awq_marlin":
# vllm_config.quant_config = BlockInt8Config(is_checkpoint_int8_serialized=True, weight_block_size=[128,128])
# self.model = DeepSeekMultiTokenPredictor(vllm_config=vllm_config,
# prefix=maybe_prefix(
# prefix, "model"))
# self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
# def forward(
# self,
# input_ids: torch.Tensor,
# positions: torch.Tensor,
# previous_hidden_states: torch.Tensor,
# intermediate_tensors: Optional[IntermediateTensors] = None,
# inputs_embeds: Optional[torch.Tensor] = None,
# spec_step_idx: int = 0,
# ) -> torch.Tensor:
# hidden_states = self.model(input_ids, positions,
# previous_hidden_states, inputs_embeds,
# spec_step_idx)
# return hidden_states
# def compute_logits(
# self,
# hidden_states: torch.Tensor,
# sampling_metadata: SamplingMetadata,
# spec_step_idx: int = 0,
# ) -> Optional[torch.Tensor]:
# return self.model.compute_logits(hidden_states, sampling_metadata,
# spec_step_idx)
# def load_weights(self, weights: Iterable[tuple[str,
# torch.Tensor]]) -> set[str]:
# stacked_params_mapping = [
# ("gate_up_proj", "gate_proj", 0),
# ("gate_up_proj", "up_proj", 1),
# ]
# expert_params_mapping = FusedMoE.make_expert_params_mapping(
# ckpt_gate_proj_name="gate_proj",
# ckpt_down_proj_name="down_proj",
# ckpt_up_proj_name="up_proj",
# num_experts=self.config.n_routed_experts)
# params_dict = dict(self.named_parameters())
# loaded_params: set[str] = set()
# for name, loaded_weight in weights:
# if "rotary_emb.inv_freq" in name:
# continue
# spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
# if spec_layer is None:
# continue
# name = self._rewrite_spec_layer_name(spec_layer, name)
# for (param_name, weight_name, shard_id) in stacked_params_mapping:
# # Skip non-stacked layers and experts (experts handled below).
# if weight_name not in name:
# continue
# # We have mlp.experts[0].gate_proj in the checkpoint.
# # Since we handle the experts below in expert_params_mapping,
# # we need to skip here BEFORE we update the name, otherwise
# # name will be updated to mlp.experts[0].gate_up_proj, which
# # will then be updated below in expert_params_mapping
# # for mlp.experts[0].gate_gate_up_proj, which breaks load.
# if (("mlp.experts." in name) and name not in params_dict):
# continue
# name = name.replace(weight_name, param_name)
# # Skip loading extra bias for GPTQ models.
# if name.endswith(".bias") and name not in params_dict:
# continue
# param = params_dict[name]
# weight_loader = param.weight_loader
# weight_loader(param, loaded_weight, shard_id)
# break
# else:
# for mapping in expert_params_mapping:
# param_name, weight_name, expert_id, shard_id = mapping
# if weight_name not in name:
# continue
# name = name.replace(weight_name, param_name)
# param = params_dict[name]
# weight_loader = param.weight_loader
# weight_loader(param,
# loaded_weight,
# name,
# shard_id=shard_id,
# expert_id=expert_id)
# break
# else:
# # Skip loading extra bias for GPTQ models.
# if name.endswith(".bias") and name not in params_dict:
# continue
# # According to DeepSeek-V3 Technical Report, MTP modules
# # shares embedding layer. We only load the first weights.
# if (spec_layer != self.model.mtp_start_layer_idx
# and ".layers" not in name):
# continue
# param = params_dict[name]
# weight_loader = getattr(param, "weight_loader",
# default_weight_loader)
# weight_loader(param, loaded_weight)
# loaded_params.add(name)
# if self.use_llama_nn and self.quant_method is None:
# lay_key_words = [
# "self_attn.eh_proj.weight",
# "self_attn.q_proj.weight",
# "self_attn.q_a_proj.weight",
# "self_attn.q_b_proj.weight",
# "self_attn.kv_a_proj_with_mqa.weight",
# "self_attn.kv_b_proj.weight",
# "self_attn.o_proj.weight",
# "mlp.gate_up_proj.weight",
# "mlp.down_proj.weight",
# "mlp.gate.weight",
# "shared_experts.gate_up_proj.weight",
# "shared_experts.down_proj.weight",
# "shared_head.head.weight",
# ]
# combined_words = "|".join(lay_key_words)
# for layername in loaded_params:
# weight = params_dict[layername]
# matches = re.findall(combined_words, layername)
# if matches:
# _weight = torch.zeros_like(weight.data)
# ori_shape =_weight.shape
# ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1])
# weight.data.copy_(_weight)
# weight.data=weight.data.reshape(ori_shape[1],-1)
# return loaded_params
# def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
# """
# Rewrite the weight name to match the format of the original model.
# Add .mtp_block for modules in transformer layer block for spec layer
# and rename shared layer weights to be top level.
# """
# spec_layer_weight_names = [
# "embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head"
# ]
# shared_weight_names = ["embed_tokens"]
# spec_layer_weight = False
# shared_weight = False
# for weight_name in spec_layer_weight_names:
# if weight_name in name:
# spec_layer_weight = True
# if weight_name in shared_weight_names:
# shared_weight = True
# break
# if not spec_layer_weight:
# # treat rest weights as weights for transformer layer block
# name = name.replace(f"model.layers.{spec_layer}.",
# f"model.layers.{spec_layer}.mtp_block.")
# elif shared_weight:
# # treat shared weights as top level weights
# name = name.replace(f"model.layers.{spec_layer}.", "model.")
# return name
...@@ -59,6 +59,8 @@ from vllm.model_executor.layers.quantization import QuantizationConfig ...@@ -59,6 +59,8 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.layers.dp_attention import (dp_gather, dp_reduce_scatter_tensor,
get_moe_tp_size, get_moe_tp_rank, get_attention_tp_size)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name) default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
...@@ -70,8 +72,8 @@ from .utils import (PPMissingLayer, is_pp_missing_parameter, ...@@ -70,8 +72,8 @@ from .utils import (PPMissingLayer, is_pp_missing_parameter,
maybe_prefix) maybe_prefix)
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.utils import W8a8GetCacheJSON from vllm.utils import W8a8GetCacheJSON
from lmslim.quantize.quant_ops import lm_faster_rmsquant
class DeepseekV2MLP(nn.Module): class DeepseekV2MLP(nn.Module):
def __init__( def __init__(
...@@ -84,38 +86,40 @@ class DeepseekV2MLP(nn.Module): ...@@ -84,38 +86,40 @@ class DeepseekV2MLP(nn.Module):
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
vllm_config = get_current_vllm_config()
enable_dp_attention = vllm_config.parallel_config.enable_dp_attention
self.gate_up_proj = MergedColumnParallelLinear( self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2, hidden_size, [intermediate_size] * 2,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj") prefix=f"{prefix}.gate_up_proj",
enable_dp_attn_moe=enable_dp_attention)
self.down_proj = RowParallelLinear(intermediate_size, self.down_proj = RowParallelLinear(intermediate_size,
hidden_size, hidden_size,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
reduce_results=reduce_results, reduce_results=reduce_results if not enable_dp_attention else False,
prefix=f"{prefix}.down_proj") prefix=f"{prefix}.down_proj",
enable_dp_attn_moe=enable_dp_attention)
if hidden_act != "silu": if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. " raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.") "Only silu is supported for now.")
self.act_fn = SiluAndMul() self.act_fn = SiluAndMul()
def forward(self, x, def forward(self, x,
rms_weight: Optional[torch.Tensor] = None, xqxs: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
residual: Optional[torch.Tensor] = None, iqis: Optional[tuple[torch.Tensor, torch.Tensor]] = None
update_hd: Optional[bool] = False, ) -> torch.Tensor:
xqxs: Optional[tuple[torch.Tensor, torch.Tensor]] = None
) -> Union[torch.Tensor,
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
if envs.USE_FUSED_RMS_QUANT: if envs.USE_FUSED_RMS_QUANT:
gate_up, new_resi, i_q, _scales, _ = self.gate_up_proj(x, rms_weight, residual, update_hd=update_hd) assert iqis is not None
gate_up, _ = self.gate_up_proj(x, iqis=iqis)
if envs.USE_FUSED_SILU_MUL_QUANT: if envs.USE_FUSED_SILU_MUL_QUANT:
x, _ = self.down_proj(gate_up, use_fused_silu_mul_quant=True) x, _ = self.down_proj(gate_up, use_fused_silu_mul_quant=True)
else: else:
x = self.act_fn(gate_up) x = self.act_fn(gate_up)
x, _ = self.down_proj(x) x, _ = self.down_proj(x)
return x
return x, new_resi, i_q, _scales
elif envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and xqxs is not None: elif envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and xqxs is not None:
gate_up, _ = self.gate_up_proj(x, xqxs=xqxs) gate_up, _ = self.gate_up_proj(x, xqxs=xqxs)
if envs.USE_FUSED_SILU_MUL_QUANT: if envs.USE_FUSED_SILU_MUL_QUANT:
...@@ -180,33 +184,19 @@ class DeepseekV2MoE(nn.Module): ...@@ -180,33 +184,19 @@ class DeepseekV2MoE(nn.Module):
self.n_local_physical_experts) self.n_local_physical_experts)
self.physical_expert_end = (self.physical_expert_start + self.physical_expert_end = (self.physical_expert_start +
self.n_local_physical_experts) self.n_local_physical_experts)
dp_size = get_dp_group().world_size dp_size = get_dp_group().world_size
self.enable_expert_parallel = parallel_config.enable_expert_parallel self.enable_expert_parallel = parallel_config.enable_expert_parallel
self.use_deepep = dp_size > 1 and parallel_config.enable_expert_parallel and \ self.use_deepep = dp_size > 1 and parallel_config.enable_expert_parallel and \
(envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \ (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency") envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_auto")
if not self.use_deepep: self.enable_shared_experts_overlap = False
self.experts = FusedMoE( self.enable_shared_experts_fusion = (self.n_shared_experts != 0 and envs.VLLM_ENABLE_SHARED_EXPERTS_FUSION)
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
use_grouped_topk=True,
num_expert_group=config.n_group,
topk_group=config.topk_group,
prefix=f"{prefix}.experts",
scoring_func=config.scoring_func,
e_score_correction_bias=self.gate.e_score_correction_bias,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
routed_scaling_factor=self.routed_scaling_factor)
if config.n_shared_experts is not None: if not self.use_deepep:
if config.n_shared_experts is not None and not envs.VLLM_ENABLE_SHARED_EXPERTS_FUSION:
intermediate_size = (config.moe_intermediate_size * intermediate_size = (config.moe_intermediate_size *
config.n_shared_experts) config.n_shared_experts)
self.shared_experts = DeepseekV2MLP( self.shared_experts = DeepseekV2MLP(
...@@ -214,10 +204,54 @@ class DeepseekV2MoE(nn.Module): ...@@ -214,10 +204,54 @@ class DeepseekV2MoE(nn.Module):
intermediate_size=intermediate_size, intermediate_size=intermediate_size,
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
quant_config=quant_config, quant_config=quant_config,
reduce_results=self.experts.must_reduce_shared_expert_outputs( reduce_results = False,
),
prefix=f"{prefix}.shared_experts", prefix=f"{prefix}.shared_experts",
) )
self.enable_shared_experts_overlap = (not envs.VLLM_DISABLE_SHARED_EXPERTS_STREAM
and not envs.VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD
and config.n_shared_experts is not None)
if self.enable_shared_experts_overlap:
self.experts = SharedFusedMoE(
shared_experts=self.shared_experts,
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
use_grouped_topk=True,
num_expert_group=config.n_group,
topk_group=config.topk_group,
prefix=f"{prefix}.experts",
scoring_func=config.scoring_func,
e_score_correction_bias=self.gate.e_score_correction_bias,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
routed_scaling_factor=self.routed_scaling_factor)
else:
num_fake_experts = (
config.n_shared_experts
if self.enable_shared_experts_fusion
else 0)
self.experts = FusedMoE(
num_experts=config.n_routed_experts + (num_fake_experts),
top_k=config.num_experts_per_tok + min(num_fake_experts, 1),
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
use_grouped_topk=True,
num_expert_group=config.n_group,
topk_group=config.topk_group,
prefix=f"{prefix}.experts",
scoring_func=config.scoring_func,
e_score_correction_bias=self.gate.e_score_correction_bias,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
routed_scaling_factor=self.routed_scaling_factor)
else: else:
if config.n_shared_experts is not None: if config.n_shared_experts is not None:
intermediate_size = (config.moe_intermediate_size * intermediate_size = (config.moe_intermediate_size *
...@@ -249,22 +283,61 @@ class DeepseekV2MoE(nn.Module): ...@@ -249,22 +283,61 @@ class DeepseekV2MoE(nn.Module):
routed_scaling_factor=self.routed_scaling_factor, routed_scaling_factor=self.routed_scaling_factor,
shared_experts=self.shared_experts) shared_experts=self.shared_experts)
self.run_shared_expert_singlely = (
self.n_shared_experts is not None
and not self.enable_shared_experts_overlap
and not envs.VLLM_ENABLE_SHARED_EXPERTS_FUSION
)
from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce
self.tbo_all_reduce = tbo_all_reduce self.tbo_all_reduce = tbo_all_reduce
def forward(self, hidden_states: torch.Tensor, def forward(self, hidden_states: torch.Tensor,
rms_weight: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None, residual: Optional[torch.Tensor] = None,
xqxs: Optional[tuple[torch.Tensor, torch.Tensor]] = None xqxs: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
iqis: Optional[tuple[torch.Tensor, torch.Tensor]] = None # iq = input quant, is = input scale
) -> Union[torch.Tensor, ) -> Union[torch.Tensor,
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
num_tokens, hidden_dim = hidden_states.shape num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim) hidden_states = hidden_states.view(-1, hidden_dim)
# For shared experts overlap optimization.
def shared_exprts_overlap_pass(
hidden_states: torch.Tensor, router_logits: torch.Tensor,
iqis: Optional[tuple[torch.Tensor, torch.Tensor]] = None
) -> tuple[torch.Tensor, torch.Tensor]:
i_q, i_s = None, None
if envs.USE_FUSED_RMS_QUANT:
assert iqis is not None
i_q, i_s = iqis[0], iqis[1]
hidden_states_copy = hidden_states.clone()
return self.experts(
hidden_states = hidden_states,
router_logits = router_logits,
hidden_states_copy = hidden_states_copy,
i_q = i_q,
i_s = i_s)
# For shared experts fusion optimization.
def shared_exprts_fusion_pass(
hidden_states: torch.Tensor, router_logits: torch.Tensor,
iqis: Optional[tuple[torch.Tensor, torch.Tensor]] = None
) -> torch.Tensor:
i_q, i_s = None, None
if envs.USE_FUSED_RMS_QUANT:
assert iqis is not None
i_q, i_s = iqis[0], iqis[1]
return self.experts(
hidden_states = hidden_states,
router_logits = router_logits,
i_q = i_q,
i_s = i_s)
if envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and xqxs is not None: if envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and xqxs is not None:
if self.n_shared_experts is not None: if self.n_shared_experts is not None and not self.enable_shared_experts_overlap:
shared_output = self.shared_experts(hidden_states, xqxs=xqxs) shared_output = self.shared_experts(hidden_states, xqxs=xqxs)
router_logits, _ = self.gate(hidden_states) router_logits, _ = self.gate(hidden_states)
if envs.VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD: if envs.VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD:
...@@ -273,63 +346,30 @@ class DeepseekV2MoE(nn.Module): ...@@ -273,63 +346,30 @@ class DeepseekV2MoE(nn.Module):
router_logits=router_logits, router_logits=router_logits,
shared_output=shared_output) shared_output=shared_output)
else: else:
if hidden_states.dtype != torch.float16: if self.enable_shared_experts_fusion:
final_hidden_states = self.experts( final_hidden_states = shared_exprts_fusion_pass(hidden_states, router_logits)
hidden_states=hidden_states, elif self.enable_shared_experts_overlap:
router_logits=router_logits) * self.routed_scaling_factor assert self.shared_experts is not None
else: shared_output, final_hidden_states = shared_exprts_overlap_pass(hidden_states, router_logits)
# Fix FP16 overflow # Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details. # See DeepseekV2DecoderLayer for more details.
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
if shared_output is not None:
if hidden_states.dtype != torch.float16: if hidden_states.dtype != torch.float16:
final_hidden_states = final_hidden_states + shared_output final_hidden_states *= self.routed_scaling_factor
else: final_hidden_states += shared_output
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor)
if self.tp_size > 1:
if envs.VLLM_ENABLE_TBO:
final_hidden_states = self.tbo_all_reduce(final_hidden_states)
else:
final_hidden_states = (
self.experts.maybe_all_reduce_tensor_model_parallel(
final_hidden_states))
return final_hidden_states.view(num_tokens, hidden_dim)
else:
if not self.enable_expert_parallel:
i_q, i_s = None, None
if self.n_shared_experts is not None:
if envs.USE_FUSED_RMS_QUANT:
shared_output, new_resi, i_q, i_s = self.shared_experts(hidden_states, rms_weight, residual, update_hd=True)
else: else:
shared_output = self.shared_experts(hidden_states) assert shared_output is not None
final_hidden_states += (shared_output * (1.0 / self.routed_scaling_factor))
router_logits, _ = self.gate(hidden_states)
if envs.VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD:
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
shared_output=shared_output,
i_q=i_q, i_s=i_s)
else: else:
if hidden_states.dtype != torch.float16: if hidden_states.dtype != torch.float16:
final_hidden_states = self.experts( final_hidden_states = self.experts(
hidden_states=hidden_states, hidden_states=hidden_states,
router_logits=router_logits, router_logits=router_logits) * self.routed_scaling_factor
i_q=i_q, i_s=i_s) * self.routed_scaling_factor
else: else:
# Fix FP16 overflow # Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details. # See DeepseekV2DecoderLayer for more details.
# fp16 mode not fused quant
final_hidden_states = self.experts(hidden_states=hidden_states, final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits) router_logits=router_logits)
if shared_output is not None: if shared_output is not None:
if hidden_states.dtype != torch.float16: if hidden_states.dtype != torch.float16:
final_hidden_states = final_hidden_states + shared_output final_hidden_states = final_hidden_states + shared_output
...@@ -338,11 +378,68 @@ class DeepseekV2MoE(nn.Module): ...@@ -338,11 +378,68 @@ class DeepseekV2MoE(nn.Module):
# See DeepseekV2DecoderLayer for more details. # See DeepseekV2DecoderLayer for more details.
final_hidden_states = final_hidden_states + shared_output \ final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor) * (1. / self.routed_scaling_factor)
else: else: # RQ
if not self.enable_expert_parallel:
i_q, i_s = None, None
if iqis is not None:
i_q, i_s = iqis
if self.run_shared_expert_singlely:
if envs.USE_FUSED_RMS_QUANT:
shared_output = self.shared_experts(hidden_states, iqis=iqis)
else:
shared_output = self.shared_experts(hidden_states)
router_logits, _ = self.gate(hidden_states) router_logits, _ = self.gate(hidden_states)
if self.enable_shared_experts_fusion:
final_hidden_states = shared_exprts_fusion_pass(hidden_states, router_logits, iqis = iqis)
elif self.enable_shared_experts_overlap:
assert self.shared_experts is not None
shared_output, final_hidden_states = shared_exprts_overlap_pass(hidden_states, router_logits, iqis = iqis)
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
if hidden_states.dtype != torch.float16:
final_hidden_states *= self.routed_scaling_factor
final_hidden_states += shared_output
else:
assert shared_output is not None
final_hidden_states += (shared_output * (1.0 / self.routed_scaling_factor))
else:
if envs.VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD:
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
shared_output=shared_output,
i_q=i_q, i_s=i_s)
else:
if hidden_states.dtype != torch.float16:
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
i_q=i_q, i_s=i_s) * self.routed_scaling_factor
else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
# fp16 mode not fused quant
if i_q is not None:
i_q=iqis[0]
i_s=iqis[1]
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits,
i_q=i_q, i_s=i_s)
if shared_output is not None:
if hidden_states.dtype != torch.float16:
final_hidden_states = final_hidden_states + shared_output
else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor)
else: # EP
router_logits, _ = self.gate(hidden_states)
if self.use_deepep: if self.use_deepep:
shared_output, final_hidden_states = self.experts(hidden_states=hidden_states, shared_output, final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits) router_logits=router_logits)
if shared_output is not None: if shared_output is not None:
...@@ -354,37 +451,48 @@ class DeepseekV2MoE(nn.Module): ...@@ -354,37 +451,48 @@ class DeepseekV2MoE(nn.Module):
final_hidden_states = final_hidden_states + shared_output \ final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor) * (1. / self.routed_scaling_factor)
else: else:
if self.n_shared_experts is not None: if self.run_shared_expert_singlely:
if envs.USE_FUSED_RMS_QUANT: if envs.USE_FUSED_RMS_QUANT:
shared_output, new_resi = self.shared_experts(hidden_states, rms_weight, residual, update_hd=True) shared_output = self.shared_experts(hidden_states, iqis=iqis)
else: else:
shared_output = self.shared_experts(hidden_states) shared_output = self.shared_experts(hidden_states)
final_hidden_states = self.experts( if self.enable_shared_experts_overlap:
hidden_states=hidden_states, assert self.shared_experts is not None
router_logits=router_logits) shared_output, final_hidden_states = shared_exprts_overlap_pass(hidden_states, router_logits)
# Fix FP16 overflow
if shared_output is not None: # See DeepseekV2DecoderLayer for more details.
if hidden_states.dtype != torch.float16: if hidden_states.dtype != torch.float16:
final_hidden_states = final_hidden_states + shared_output final_hidden_states += shared_output
else: else:
# Fix FP16 overflow assert shared_output is not None
# See DeepseekV2DecoderLayer for more details. final_hidden_states += (shared_output * (1. / self.routed_scaling_factor))
final_hidden_states = final_hidden_states + shared_output \ else:
* (1. / self.routed_scaling_factor) if i_q is not None:
i_q=iqis[0]
if self.tp_size > 1: i_s=iqis[1]
if envs.VLLM_ENABLE_TBO: final_hidden_states = self.experts(hidden_states=hidden_states,
final_hidden_states = self.tbo_all_reduce(final_hidden_states) router_logits=router_logits,
else: i_q=i_q, i_s=i_s)
final_hidden_states = (
self.experts.maybe_all_reduce_tensor_model_parallel( if shared_output is not None:
final_hidden_states)) if hidden_states.dtype != torch.float16:
final_hidden_states = final_hidden_states + shared_output
if envs.USE_FUSED_RMS_QUANT: else:
return final_hidden_states.view(num_tokens, hidden_dim), new_resi, i_q, i_s # Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor)
if self.tp_size > 1:
if envs.VLLM_ENABLE_TBO:
final_hidden_states = self.tbo_all_reduce(final_hidden_states)
else: else:
return final_hidden_states.view(num_tokens, hidden_dim) final_hidden_states = (
self.experts.maybe_all_reduce_tensor_model_parallel(
final_hidden_states))
return final_hidden_states.view(num_tokens, hidden_dim)
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
...@@ -546,7 +654,7 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -546,7 +654,7 @@ class DeepseekV2MLAAttention(nn.Module):
""" """
Main reference: DeepseekV2 paper, and FlashInfer Implementation Main reference: DeepseekV2 paper, and FlashInfer Implementation
(https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551). (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).
For more info see MLACommonImpl in: vllm/attention/backends/mla/utils.py For more info see MLACommonImpl in: vllm/attention/backends/mla/utils.py
""" """
...@@ -602,7 +710,6 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -602,7 +710,6 @@ class DeepseekV2MLAAttention(nn.Module):
self.q_lora_rank, self.q_lora_rank,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
eps=config.rms_norm_eps,
prefix=f"{prefix}.q_a_proj") prefix=f"{prefix}.q_a_proj")
self.q_b_proj = ColumnParallelLinear(q_lora_rank, self.q_b_proj = ColumnParallelLinear(q_lora_rank,
self.num_heads * self.num_heads *
...@@ -623,7 +730,7 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -623,7 +730,7 @@ class DeepseekV2MLAAttention(nn.Module):
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.q_b_proj") prefix=f"{prefix}.q_b_proj")
self.q_a_layernorm = RMSNorm(self.q_lora_rank, self.q_a_layernorm = RMSNorm(self.q_lora_rank,
eps=config.rms_norm_eps) eps=config.rms_norm_eps)
...@@ -697,27 +804,24 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -697,27 +804,24 @@ class DeepseekV2MLAAttention(nn.Module):
self.prefix = prefix self.prefix = prefix
self.debug_layer_idx = int(self.prefix.split(".")[-2]) self.debug_layer_idx = int(self.prefix.split(".")[-2])
# TODO wjl: 这里的forward拆了
def forward( def forward(
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
rms_weight: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None,
pa_rms_weight: Optional[torch.Tensor] = None, pa_rms_weight: Optional[torch.Tensor] = None,
pa_residual: Optional[torch.Tensor] = None, pa_residual: Optional[torch.Tensor] = None,
pa_rms_eps: Optional[float] = 1e-6, pa_rms_eps: Optional[float] = 1e-6,
pa_quant_dtype: Optional[torch.dtype] = torch.int8, pa_quant_dtype: Optional[torch.dtype] = torch.int8,
update_input: Optional[bool] = True update_input: Optional[bool] = True,
iqis: Optional[tuple[torch.Tensor, torch.Tensor]] = None
) -> Union[torch.Tensor, ) -> Union[torch.Tensor,
Tuple[torch.Tensor, torch.Tensor],
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
if envs.USE_FUSED_RMS_QUANT and rms_weight is not None: if envs.USE_FUSED_RMS_QUANT and iqis is not None:
if envs.VLLM_USE_FUSED_QA_KVA_GEMM: if envs.VLLM_USE_FUSED_QA_KVA_GEMM:
if self.q_lora_rank is not None: if self.q_lora_rank is not None:
qc_kvc_kpe, new_residual, _bias = self.qa_kva_proj(hidden_states, rms_weight=rms_weight, residual=residual, update_hd=False) qc_kvc_kpe, _bias = self.qa_kva_proj(hidden_states, iqis)
q_c = qc_kvc_kpe[:, :self.q_lora_rank] q_c = qc_kvc_kpe[:, :self.q_lora_rank]
kvc_kpe = qc_kvc_kpe[:, self.q_lora_rank:] kvc_kpe = qc_kvc_kpe[:, self.q_lora_rank:]
q, _, _ = self.q_b_proj(q_c, rms_weight=self.q_a_layernorm.weight.data, residual=None, update_hd=False) q, _, _ = self.q_b_proj(q_c, rms_weight=self.q_a_layernorm.weight.data, residual=None, update_hd=False)
...@@ -727,15 +831,15 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -727,15 +831,15 @@ class DeepseekV2MLAAttention(nn.Module):
kv_c, k_pe = kvc_kpe.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) kv_c, k_pe = kvc_kpe.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
else: else:
if self.q_lora_rank is not None: if self.q_lora_rank is not None:
q_c, new_residual, _, input_quant_args = self.q_a_proj(hidden_states, rms_weight=rms_weight, residual=residual, update_hd=False) q_c, _ = self.q_a_proj(hidden_states, iqis=iqis)
q, _, _ = self.q_b_proj(q_c, rms_weight=self.q_a_layernorm.weight.data, residual=None, update_hd=False) q, _, _ = self.q_b_proj(q_c, rms_weight=self.q_a_layernorm.weight.data, residual=None, update_hd=False)
else: else:
q = self.q_proj(hidden_states)[0] q = self.q_proj(hidden_states)[0]
kvc_kpe = self.kv_a_proj_with_mqa(hidden_states, quant_args=input_quant_args, update_hd=False)[0] kvc_kpe = self.kv_a_proj_with_mqa(hidden_states, iqis=iqis)[0]
kv_c, k_pe = kvc_kpe.split( kv_c, k_pe = kvc_kpe.split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT: if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
if envs.VLLM_USE_LIGHTOP: if envs.VLLM_USE_LIGHTOP:
kv_c_normed = self.kv_a_layernorm.forward_cuda_opt(kv_c) kv_c_normed = self.kv_a_layernorm.forward_cuda_opt(kv_c)
...@@ -763,19 +867,21 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -763,19 +867,21 @@ class DeepseekV2MLAAttention(nn.Module):
cos_sin_cache = self.rotary_emb.cos_sin_cache cos_sin_cache = self.rotary_emb.cos_sin_cache
if cos_sin_cache.device != positions.device or cos_sin_cache.device != q.dtype: if cos_sin_cache.device != positions.device or cos_sin_cache.device != q.dtype:
cos_sin_cache = cos_sin_cache.to(positions.device, dtype=q.dtype) cos_sin_cache = cos_sin_cache.to(positions.device, dtype=q.dtype)
kv_c_normed = torch.empty(kv_c.shape, dtype=kv_c.dtype, device=kv_c.device) kv_c_normed = torch.empty(kv_c.shape, dtype=kv_c.dtype, device=kv_c.device)
attn_out = self.mla_attn( attn_out = self.mla_attn(
q[..., self.qk_nope_head_dim:], q[..., self.qk_nope_head_dim:],
kv_c, kv_c,
k_pe, k_pe,
output_shape=(hidden_states.shape[0], output_shape=(hidden_states.shape[0],
self.num_local_heads * self.v_head_dim), self.num_local_heads * self.v_head_dim),
query_nope=q[..., :self.qk_nope_head_dim],
num_local_heads=self.num_local_heads,
q_ori=q, q_ori=q,
key_normed=kv_c_normed, key_normed=kv_c_normed,
positions=positions, positions=positions,
weight=weight, weight=weight,
cos_sin_cache=cos_sin_cache) cos_sin_cache=cos_sin_cache)
return self.o_proj(attn_out)[0], new_residual return self.o_proj(attn_out)[0]
elif envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and pa_rms_weight is not None and pa_residual is not None: elif envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and pa_rms_weight is not None and pa_residual is not None:
if self.q_lora_rank is not None: if self.q_lora_rank is not None:
q_c = self.q_a_proj(hidden_states)[0] q_c = self.q_a_proj(hidden_states)[0]
...@@ -788,7 +894,7 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -788,7 +894,7 @@ class DeepseekV2MLAAttention(nn.Module):
if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT: if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
if envs.VLLM_USE_LIGHTOP: if envs.VLLM_USE_LIGHTOP:
kv_c_normed = self.kv_a_layernorm.forward_cuda_opt(kv_c) kv_c_normed = self.kv_a_layernorm.forward_cuda_opt(kv_c)
else: else:
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
q = q.view(-1, self.num_local_heads, self.qk_head_dim) q = q.view(-1, self.num_local_heads, self.qk_head_dim)
...@@ -811,19 +917,21 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -811,19 +917,21 @@ class DeepseekV2MLAAttention(nn.Module):
cos_sin_cache = self.rotary_emb.cos_sin_cache cos_sin_cache = self.rotary_emb.cos_sin_cache
if cos_sin_cache.device != positions.device or cos_sin_cache.device != q.dtype: if cos_sin_cache.device != positions.device or cos_sin_cache.device != q.dtype:
cos_sin_cache = cos_sin_cache.to(positions.device, dtype=q.dtype) cos_sin_cache = cos_sin_cache.to(positions.device, dtype=q.dtype)
kv_c_normed = torch.empty(kv_c.shape, dtype=kv_c.dtype, device=kv_c.device) kv_c_normed = torch.empty(kv_c.shape, dtype=kv_c.dtype, device=kv_c.device)
attn_out = self.mla_attn( attn_out = self.mla_attn(
q[..., self.qk_nope_head_dim:], q[..., self.qk_nope_head_dim:],
kv_c, kv_c,
k_pe, k_pe,
output_shape=(hidden_states.shape[0], output_shape=(hidden_states.shape[0],
self.num_local_heads * self.v_head_dim), self.num_local_heads * self.v_head_dim),
query_nope=q[..., :self.qk_nope_head_dim],
num_local_heads=self.num_local_heads,
q_ori=q, q_ori=q,
key_normed=kv_c_normed, key_normed=kv_c_normed,
positions=positions, positions=positions,
weight=weight, weight=weight,
cos_sin_cache=cos_sin_cache) cos_sin_cache=cos_sin_cache)
packages_ = self.o_proj(attn_out, packages_ = self.o_proj(attn_out,
pa_rms_weight=pa_rms_weight, pa_rms_weight=pa_rms_weight,
pa_residual=pa_residual, pa_residual=pa_residual,
pa_rms_eps=pa_rms_eps, pa_rms_eps=pa_rms_eps,
...@@ -870,13 +978,15 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -870,13 +978,15 @@ class DeepseekV2MLAAttention(nn.Module):
cos_sin_cache = self.rotary_emb.cos_sin_cache cos_sin_cache = self.rotary_emb.cos_sin_cache
if cos_sin_cache.device != positions.device or cos_sin_cache.device != q.dtype: if cos_sin_cache.device != positions.device or cos_sin_cache.device != q.dtype:
cos_sin_cache = cos_sin_cache.to(positions.device, dtype=q.dtype) cos_sin_cache = cos_sin_cache.to(positions.device, dtype=q.dtype)
kv_c_normed = torch.empty(kv_c.shape, dtype=kv_c.dtype, device=kv_c.device) kv_c_normed = torch.empty(kv_c.shape, dtype=kv_c.dtype, device=kv_c.device)
attn_out = self.mla_attn( attn_out = self.mla_attn(
q[..., self.qk_nope_head_dim:], q[..., self.qk_nope_head_dim:],
kv_c, kv_c,
k_pe, k_pe,
output_shape=(hidden_states.shape[0], output_shape=(hidden_states.shape[0],
self.num_local_heads * self.v_head_dim), self.num_local_heads * self.v_head_dim),
query_nope=q[..., :self.qk_nope_head_dim],
num_local_heads=self.num_local_heads,
q_ori=q, q_ori=q,
key_normed=kv_c_normed, key_normed=kv_c_normed,
positions=positions, positions=positions,
...@@ -912,10 +1022,13 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -912,10 +1022,13 @@ class DeepseekV2DecoderLayer(nn.Module):
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
self.use_deepep = self.dp_size > 1 and parallel_config.enable_expert_parallel and \ self.use_deepep = self.dp_size > 1 and parallel_config.enable_expert_parallel and \
(envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \ (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency") envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_auto")
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.config = config self.config = config
self.tp_rank = get_tensor_model_parallel_rank() self.tp_rank = get_tensor_model_parallel_rank()
self.enable_dp_attention = vllm_config.parallel_config.enable_dp_attention
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
if (config.n_routed_experts is not None if (config.n_routed_experts is not None
and layer_idx >= config.first_k_dense_replace and layer_idx >= config.first_k_dense_replace
...@@ -935,6 +1048,8 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -935,6 +1048,8 @@ class DeepseekV2DecoderLayer(nn.Module):
prefix=f"{prefix}.mlp", prefix=f"{prefix}.mlp",
) )
self.enable_ep_sp = isinstance(self.mlp,
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1
self.is_mtp_layer = False self.is_mtp_layer = False
if self.layer_idx == config.num_hidden_layers: if self.layer_idx == config.num_hidden_layers:
self.is_mtp_layer = True self.is_mtp_layer = True
...@@ -943,6 +1058,9 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -943,6 +1058,9 @@ class DeepseekV2DecoderLayer(nn.Module):
DeepseekV2MoE) and self.use_deepep and \ DeepseekV2MoE) and self.use_deepep and \
self.tp_size > 1 and not self.is_mtp_layer: self.tp_size > 1 and not self.is_mtp_layer:
reduce_results = False reduce_results = False
else:
if self.enable_dp_attention:
reduce_results = False
if model_config.use_mla: if model_config.use_mla:
attn_cls = DeepseekV2MLAAttention attn_cls = DeepseekV2MLAAttention
...@@ -974,10 +1092,11 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -974,10 +1092,11 @@ class DeepseekV2DecoderLayer(nn.Module):
self.routed_scaling_factor = config.routed_scaling_factor self.routed_scaling_factor = config.routed_scaling_factor
self.use_fused_rms_quant = envs.USE_FUSED_RMS_QUANT self.use_fused_rms_quant = envs.USE_FUSED_RMS_QUANT
self.use_fused_custom_all_reduce = envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT self.use_fused_custom_all_reduce = envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT
self._eps = config.rms_norm_eps
def forward_fused_rmsquant( def forward_fused_RQ(
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
...@@ -985,25 +1104,28 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -985,25 +1104,28 @@ class DeepseekV2DecoderLayer(nn.Module):
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
# Fix residual FP16 overflow # Fix residual FP16 overflow
residual_fix_overflow = False residual_fix_overflow = False
assert self.input_layernorm.has_weight is True assert self.input_layernorm.has_weight is True
if residual is None: if residual is None:
residual = hidden_states residual = hidden_states
hidden_states, _ = self.self_attn(
positions = positions,
hidden_states = hidden_states,
rms_weight = self.input_layernorm.weight.data,
residual = None
)
residual_fix_overflow = True residual_fix_overflow = True
i_q, i_s = lm_faster_rmsquant(input=hidden_states,
rms_weight = self.input_layernorm.weight.data,
epsilon=self._eps,
quant_dtype=torch.int8,
residual=None,
update_input=False)
else: else:
hidden_states, new_residual = self.self_attn( i_q, i_s = lm_faster_rmsquant(input=hidden_states,
positions = positions, rms_weight = self.input_layernorm.weight.data,
hidden_states = hidden_states, epsilon=self._eps,
rms_weight = self.input_layernorm.weight.data, quant_dtype=torch.int8,
residual = residual residual=residual,
) update_input=False)
residual = new_residual
hidden_states = self.self_attn(positions=positions,
hidden_states = hidden_states, # get attr
iqis=(i_q, i_s))
if hidden_states.dtype == torch.float16: if hidden_states.dtype == torch.float16:
# rmsnorm, and rmsnorm result would not affect by scale. # rmsnorm, and rmsnorm result would not affect by scale.
...@@ -1012,11 +1134,17 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -1012,11 +1134,17 @@ class DeepseekV2DecoderLayer(nn.Module):
# The residual is shared by all layers, we only scale it on # The residual is shared by all layers, we only scale it on
# first layer. # first layer.
residual *= 1. / self.routed_scaling_factor residual *= 1. / self.routed_scaling_factor
hidden_states, new_resi, _i_q, _scales = self.mlp(hidden_states, update_hs = True if isinstance(self.mlp, DeepseekV2MoE) else False
rms_weight=self.post_attention_layernorm.weight.data, _i_q, _i_s = lm_faster_rmsquant(input=hidden_states,
residual=residual, rms_weight=self.post_attention_layernorm.weight.data,
) epsilon=self._eps,
quant_dtype=torch.int8,
residual=residual,
update_input=update_hs)
new_resi = residual
hidden_states = self.mlp(hidden_states,
iqis=(_i_q, _i_s))
if isinstance(self.mlp, if isinstance(self.mlp,
DeepseekV2MLP) and hidden_states.dtype == torch.float16: DeepseekV2MLP) and hidden_states.dtype == torch.float16:
...@@ -1029,9 +1157,9 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -1029,9 +1157,9 @@ class DeepseekV2DecoderLayer(nn.Module):
return hidden_states, new_resi return hidden_states, new_resi
def forward_fused_CRQ( def forward_fused_CRQ(
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor] residual: Optional[torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
residual_fix_overflow = False residual_fix_overflow = False
...@@ -1042,33 +1170,33 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -1042,33 +1170,33 @@ class DeepseekV2DecoderLayer(nn.Module):
else: else:
hidden_states, resi_new = self.input_layernorm( hidden_states, resi_new = self.input_layernorm(
hidden_states, residual) hidden_states, residual)
residual = resi_new residual = resi_new
new_hs, new_resi, xq, xs = self.self_attn( new_hs, new_resi, xq, xs = self.self_attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
pa_rms_weight=self.post_attention_layernorm.weight.data, pa_rms_weight=self.post_attention_layernorm.weight.data,
pa_residual=residual, pa_residual=residual,
pa_rms_eps=self.post_attention_layernorm.variance_epsilon, pa_rms_eps=self.post_attention_layernorm.variance_epsilon,
pa_quant_dtype = torch.int8, pa_quant_dtype = torch.int8,
update_input=True update_input=True
) )
assert xq is not None and xs is not None assert xq is not None and xs is not None
if new_hs.dtype == torch.float16: # overflow处理逻辑 if new_hs.dtype == torch.float16: # overflow处理逻辑
new_hs *= 1. / self.routed_scaling_factor new_hs *= 1. / self.routed_scaling_factor
if self.layer_idx == 0 or residual_fix_overflow: if self.layer_idx == 0 or residual_fix_overflow:
new_resi *= 1. / self.routed_scaling_factor new_resi *= 1. / self.routed_scaling_factor
hidden_states = self.mlp(new_hs, xqxs=(xq, xs)) hidden_states = self.mlp(new_hs, xqxs=(xq, xs))
if isinstance(self.mlp, if isinstance(self.mlp,
DeepseekV2MLP) and hidden_states.dtype == torch.float16: DeepseekV2MLP) and hidden_states.dtype == torch.float16:
hidden_states *= 1. / self.routed_scaling_factor hidden_states *= 1. / self.routed_scaling_factor
return hidden_states, new_resi return hidden_states, new_resi
def forward_default( def forward_default(
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor] residual: Optional[torch.Tensor]
...@@ -1083,26 +1211,26 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -1083,26 +1211,26 @@ class DeepseekV2DecoderLayer(nn.Module):
else: else:
hidden_states, residual = self.input_layernorm( hidden_states, residual = self.input_layernorm(
hidden_states, residual) hidden_states, residual)
if not self.is_mtp_layer: if not self.is_mtp_layer and self.enable_ep_sp and \
if isinstance(self.mlp, self.layer_idx > self.config.first_k_dense_replace:
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1 and \ hidden_states = tensor_model_parallel_all_gather(hidden_states, dim=0)
self.layer_idx > self.config.first_k_dense_replace:
hidden_states = tensor_model_parallel_all_gather(hidden_states, dim=0)
hidden_states = self.self_attn( hidden_states = self.self_attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
) )
if not self.is_mtp_layer: if not self.is_mtp_layer and self.enable_ep_sp:
if isinstance(self.mlp, if self.layer_idx == self.config.first_k_dense_replace:
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1: residual = residual.tensor_split(self.tp_size)[self.tp_rank]
if self.layer_idx == self.config.first_k_dense_replace:
residual = residual.tensor_split(self.tp_size)[self.tp_rank]
hidden_states = tensor_model_parallel_reduce_scatter(hidden_states, dim=0) hidden_states = tensor_model_parallel_reduce_scatter(hidden_states, dim=0)
if self.enable_dp_attention:
if self.tp_rank == 0:
hidden_states += residual
hidden_states = dp_gather(hidden_states)
if hidden_states.dtype == torch.float16: if hidden_states.dtype == torch.float16:
# Fix FP16 overflow # Fix FP16 overflow
...@@ -1115,27 +1243,31 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -1115,27 +1243,31 @@ class DeepseekV2DecoderLayer(nn.Module):
residual *= 1. / self.routed_scaling_factor residual *= 1. / self.routed_scaling_factor
# Fully Connected # Fully Connected
hidden_states, residual = self.post_attention_layernorm( if not self.enable_dp_attention:
hidden_states, residual) hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
if self.is_mtp_layer: else:
if isinstance(self.mlp, num_tokens = hidden_states.shape[0]
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1: new_bs = num_tokens // get_moe_tp_size() * get_attention_tp_size()
ori_bs = hidden_states.shape[0] residual = hidden_states[self.dp_rank*new_bs: (self.dp_rank+1)*new_bs, :]
pad_size = (ori_bs + self.tp_size - 1) // self.tp_size * self.tp_size - ori_bs hidden_states = self.post_attention_layernorm(hidden_states)
if pad_size > 0:
hidden_states = torch.nn.functional.pad(hidden_states.contiguous(), [0, 0, 0, pad_size], value=0).contiguous() if self.is_mtp_layer and self.enable_ep_sp:
new_bs = (ori_bs+pad_size) // self.tp_size ori_bs = hidden_states.shape[0]
hidden_states = hidden_states[self.tp_rank*new_bs: (self.tp_rank+1)*new_bs, :].contiguous() pad_size = (ori_bs + self.tp_size - 1) // self.tp_size * self.tp_size - ori_bs
if pad_size > 0:
hidden_states = torch.nn.functional.pad(hidden_states, [0, 0, 0, pad_size], value=0)
new_bs = (ori_bs+pad_size) // self.tp_size
hidden_states = hidden_states[self.tp_rank*new_bs: (self.tp_rank+1)*new_bs, :]
hidden_states = self.mlp(hidden_states) hidden_states = self.mlp(hidden_states)
if self.is_mtp_layer: if self.enable_dp_attention:
if isinstance(self.mlp, hidden_states = dp_reduce_scatter_tensor(hidden_states)
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1:
hidden_states = tensor_model_parallel_all_gather(hidden_states, dim=0) if self.is_mtp_layer and self.enable_ep_sp:
hidden_states = hidden_states[:ori_bs, :] hidden_states = tensor_model_parallel_all_gather(hidden_states, dim=0)
hidden_states = hidden_states[:ori_bs, :]
if isinstance(self.mlp, if isinstance(self.mlp,
DeepseekV2MLP) and hidden_states.dtype == torch.float16: DeepseekV2MLP) and hidden_states.dtype == torch.float16:
...@@ -1147,10 +1279,10 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -1147,10 +1279,10 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states *= 1. / self.routed_scaling_factor hidden_states *= 1. / self.routed_scaling_factor
return hidden_states, residual return hidden_states, residual
def choose_forward(self): def choose_forward(self):
if self.use_fused_rms_quant: if self.use_fused_rms_quant:
return self.forward_fused_rmsquant return self.forward_fused_RQ
elif self.use_fused_custom_all_reduce: elif self.use_fused_custom_all_reduce:
return self.forward_fused_CRQ return self.forward_fused_CRQ
...@@ -1212,13 +1344,14 @@ class DeepseekV2Model(nn.Module): ...@@ -1212,13 +1344,14 @@ class DeepseekV2Model(nn.Module):
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory( make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size)) ["hidden_states", "residual"], config.hidden_size))
self.dp_size = get_dp_group().world_size self.dp_size = get_dp_group().world_size
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
self.use_deepep = self.dp_size > 1 and parallel_config.enable_expert_parallel and \ self.use_deepep = self.dp_size > 1 and parallel_config.enable_expert_parallel and \
(envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \ (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency") envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_auto")
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
...@@ -1312,10 +1445,10 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts): ...@@ -1312,10 +1445,10 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
self.num_routed_experts = example_moe.n_routed_experts self.num_routed_experts = example_moe.n_routed_experts
self.num_shared_experts = example_moe.n_shared_experts self.num_shared_experts = example_moe.n_shared_experts
self.num_redundant_experts = example_moe.n_redundant_experts self.num_redundant_experts = example_moe.n_redundant_experts
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1' self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
self.tritonsingleton= W8a8GetCacheJSON() self.tritonsingleton= W8a8GetCacheJSON()
self.tritonsingleton.topk = config.num_experts_per_tok self.tritonsingleton.topk = config.num_experts_per_tok
self.tritonsingleton.quant_method=self.quant_method self.tritonsingleton.quant_method=self.quant_method
...@@ -1371,22 +1504,22 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts): ...@@ -1371,22 +1504,22 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
dtype=dtype, dtype=dtype,
device=device), device=device),
}) })
def restore_qzeros_tensor(self, qzeros, qscales): def restore_qzeros_tensor(self, qzeros, qscales):
low_bits = qzeros & 0x0F low_bits = qzeros & 0x0F
high_bits = qzeros >> 4 high_bits = qzeros >> 4
zeors_tensor = torch.stack([low_bits, high_bits], dim=2).view(qzeros.shape[0], -1 , qzeros.shape[-1]) zeors_tensor = torch.stack([low_bits, high_bits], dim=2).view(qzeros.shape[0], -1 , qzeros.shape[-1])
zeors_int16 = zeors_tensor.to(torch.int16) zeors_int16 = zeors_tensor.to(torch.int16)
assert zeors_int16.shape == qscales.shape assert zeors_int16.shape == qscales.shape
uint16_tensor1 = zeors_int16.view(torch.uint16) uint16_tensor1 = zeors_int16.view(torch.uint16)
uint16_tensor2 = qscales.view(torch.uint16) uint16_tensor2 = qscales.view(torch.uint16)
uint32_tensor1 = uint16_tensor1.to(torch.int32) << 16 uint32_tensor1 = uint16_tensor1.to(torch.int32) << 16
uint32_tensor2 = uint16_tensor2.to(torch.int32) uint32_tensor2 = uint16_tensor2.to(torch.int32)
result_tensor = uint32_tensor1 + uint32_tensor2 result_tensor = uint32_tensor1 + uint32_tensor2
result_tensor =result_tensor.view(torch.uint32) result_tensor =result_tensor.view(torch.uint32)
result_tensor = result_tensor.transpose(1, 2).contiguous() result_tensor = result_tensor.transpose(1, 2).contiguous()
...@@ -1412,7 +1545,12 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts): ...@@ -1412,7 +1545,12 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj", ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts, num_experts=self.config.n_routed_experts
+ (
self.num_shared_experts
if (envs.VLLM_ENABLE_SHARED_EXPERTS_FUSION and self.num_shared_experts > 0)
else 0
),
num_redundant_experts=self.num_redundant_experts) num_redundant_experts=self.num_redundant_experts)
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
...@@ -1425,6 +1563,15 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts): ...@@ -1425,6 +1563,15 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
if spec_layer is not None: if spec_layer is not None:
continue # skip spec decode layers for main model continue # skip spec decode layers for main model
# Assuame num of shared experts is only one.
if envs.VLLM_ENABLE_SHARED_EXPERTS_FUSION \
and self.num_shared_experts > 0 \
and "mlp.shared_experts" in name:
name = name.replace(
"mlp.shared_experts",
f"mlp.experts.{self.config.n_routed_experts}",
)
for (param_name, weight_name, shard_id) in stacked_params_mapping: for (param_name, weight_name, shard_id) in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below). # Skip non-stacked layers and experts (experts handled below).
if weight_name not in name: if weight_name not in name:
...@@ -1494,7 +1641,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts): ...@@ -1494,7 +1641,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
# However it's not mapped locally to this rank # However it's not mapped locally to this rank
# So we simply skip it # So we simply skip it
continue continue
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
...@@ -1515,7 +1662,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts): ...@@ -1515,7 +1662,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name) loaded_params.add(name)
if self.use_llama_nn and self.quant_method is None: if self.use_llama_nn and self.quant_method is None:
lay_key_words = [ lay_key_words = [
"self_attn.q_proj.weight", "self_attn.q_proj.weight",
...@@ -1533,19 +1680,19 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts): ...@@ -1533,19 +1680,19 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
] ]
combined_words = "|".join(lay_key_words) combined_words = "|".join(lay_key_words)
for layername in loaded_params: for layername in loaded_params:
weight = params_dict[layername] weight = params_dict[layername]
matches = re.findall(combined_words, layername) matches = re.findall(combined_words, layername)
if matches: if matches:
_weight = torch.zeros_like(weight.data) _weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape ori_shape =_weight.shape
ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1]) ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1])
weight.data.copy_(_weight) weight.data.copy_(_weight)
weight.data=weight.data.reshape(ori_shape[1],-1) weight.data=weight.data.reshape(ori_shape[1],-1)
return loaded_params return loaded_params
......
...@@ -58,6 +58,7 @@ from .utils import (AutoWeightsLoader, is_pp_missing_parameter, ...@@ -58,6 +58,7 @@ from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf from vllm.model_executor.utils import pad_weight, gemm_bank_conf
import vllm.envs as envs
FalconConfig = Union[HF_FalconConfig, RWConfig] FalconConfig = Union[HF_FalconConfig, RWConfig]
...@@ -393,7 +394,7 @@ class FalconModel(nn.Module): ...@@ -393,7 +394,7 @@ class FalconModel(nn.Module):
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1' self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1' self.use_fa_pad = os.environ.get('FA_PAD') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1' self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1')) self.w8a8_strategy = envs.VLLM_W8A8_BACKEND
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.word_embeddings(input_ids) return self.word_embeddings(input_ids)
......
...@@ -31,6 +31,47 @@ from typing import Optional, Union ...@@ -31,6 +31,47 @@ from typing import Optional, Union
import torch import torch
from torch import nn from torch import nn
from transformers import Glm4Config from transformers import Glm4Config
import vllm.envs as envs
class MultiModalConfigProxy:
"""
Proxy class to handle both flat configs (e.g., Glm4Config) and
nested multimodal configs (e.g., Glm4vConfig with text_config).
For multimodal configs where attributes are in text_config, this proxy
transparently delegates attribute access to text_config when needed.
"""
def __init__(self, config):
# Store original config (for attributes that do exist at top level)
object.__setattr__(self, '_config', config)
def __getattr__(self, name):
# First try to get from the original config (works for flat configs)
try:
return getattr(self._config, name)
except AttributeError:
# If not found and config has text_config, try there
if hasattr(self._config, 'text_config'):
return getattr(self._config.text_config, name)
# Re-raise the original error if text_config doesn't have it either
raise AttributeError(
f"'{type(self._config).__name__}' object has no attribute '{name}'"
)
def __setattr__(self, name, value):
# Allow setting attributes on the proxy itself
if name == '_config':
object.__setattr__(self, name, value)
else:
setattr(self._config, name, value)
def __hasattr__(self, name):
return hasattr(self._config, name) or (
hasattr(self._config, 'text_config') and
hasattr(self._config.text_config, name)
)
from vllm.attention import Attention, AttentionType from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
...@@ -151,6 +192,9 @@ class Glm4DecoderLayer(nn.Module): ...@@ -151,6 +192,9 @@ class Glm4DecoderLayer(nn.Module):
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
# Wrap config to handle both flat and nested multimodal configs
config = MultiModalConfigProxy(config)
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 1000000) rope_theta = getattr(config, "rope_theta", 1000000)
rope_scaling = getattr(config, "rope_scaling", None) rope_scaling = getattr(config, "rope_scaling", None)
...@@ -177,14 +221,11 @@ class Glm4DecoderLayer(nn.Module): ...@@ -177,14 +221,11 @@ class Glm4DecoderLayer(nn.Module):
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.mlp", prefix=f"{prefix}.mlp",
) )
self.input_layernorm = RMSNorm(config.hidden_size, rms_norm_eps = getattr(config, 'rms_norm_eps', 1e-5)
eps=config.rms_norm_eps) self.input_layernorm = RMSNorm(config.hidden_size, eps=rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size, self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=rms_norm_eps)
eps=config.rms_norm_eps) self.post_self_attn_layernorm = RMSNorm(self.hidden_size, eps=rms_norm_eps)
self.post_self_attn_layernorm = RMSNorm(config.hidden_size, self.post_mlp_layernorm = RMSNorm(self.hidden_size, eps=rms_norm_eps)
eps=config.rms_norm_eps)
self.post_mlp_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward( def forward(
self, self,
...@@ -254,6 +295,9 @@ class Glm4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -254,6 +295,9 @@ class Glm4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config lora_config = vllm_config.lora_config
# Wrap config to handle both flat and nested multimodal configs
config = MultiModalConfigProxy(config)
self.config = config self.config = config
self.lora_config = lora_config self.lora_config = lora_config
...@@ -289,7 +333,7 @@ class Glm4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -289,7 +333,7 @@ class Glm4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1' self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1' self.use_fa_pad = os.environ.get('FA_PAD') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1' self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1')) self.w8a8_strategy = envs.VLLM_W8A8_BACKEND
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.get_input_embeddings(input_ids)
......
...@@ -33,6 +33,46 @@ from transformers import LlamaConfig ...@@ -33,6 +33,46 @@ from transformers import LlamaConfig
import os import os
import re import re
class MultiModalConfigProxy:
"""
Proxy class to handle both flat configs (e.g., LlamaConfig) and
nested multimodal configs (e.g., Glm4vConfig with text_config).
For multimodal configs where attributes are in text_config, this proxy
transparently delegates attribute access to text_config when needed.
"""
def __init__(self, config):
# Store original config (for attributes that do exist at top level)
object.__setattr__(self, '_config', config)
def __getattr__(self, name):
# First try to get from the original config (works for flat configs)
try:
return getattr(self._config, name)
except AttributeError:
# If not found and config has text_config, try there
if hasattr(self._config, 'text_config'):
return getattr(self._config.text_config, name)
# Re-raise the original error if text_config doesn't have it either
raise AttributeError(
f"'{type(self._config).__name__}' object has no attribute '{name}'"
)
def __setattr__(self, name, value):
# Allow setting attributes on the proxy itself
if name == '_config':
object.__setattr__(self, name, value)
else:
setattr(self._config, name, value)
def __hasattr__(self, name):
return hasattr(self._config, name) or (
hasattr(self._config, 'text_config') and
hasattr(self._config.text_config, name)
)
from vllm.attention import Attention, AttentionType from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
...@@ -246,6 +286,9 @@ class LlamaDecoderLayer(nn.Module): ...@@ -246,6 +286,9 @@ class LlamaDecoderLayer(nn.Module):
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
# Wrap config to handle both flat and nested multimodal configs
config = MultiModalConfigProxy(config)
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000) rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None) rope_scaling = getattr(config, "rope_scaling", None)
...@@ -340,6 +383,9 @@ class LlamaModel(nn.Module): ...@@ -340,6 +383,9 @@ class LlamaModel(nn.Module):
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config lora_config = vllm_config.lora_config
# Wrap config to handle both flat and nested multimodal configs
config = MultiModalConfigProxy(config)
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
lora_vocab = (lora_config.lora_extra_vocab_size * lora_vocab = (lora_config.lora_extra_vocab_size *
...@@ -587,6 +633,10 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -587,6 +633,10 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config lora_config = vllm_config.lora_config
# Wrap config to handle both flat and nested multimodal configs
config = MultiModalConfigProxy(config)
self.config = config self.config = config
self.lora_config = lora_config self.lora_config = lora_config
self.model = self._init_model(vllm_config=vllm_config, self.model = self._init_model(vllm_config=vllm_config,
......
...@@ -52,6 +52,7 @@ from .qwen2 import Qwen2MLP as Qwen3MLP ...@@ -52,6 +52,7 @@ from .qwen2 import Qwen2MLP as Qwen3MLP
from .qwen2 import Qwen2Model from .qwen2 import Qwen2Model
from .utils import AutoWeightsLoader, PPMissingLayer, maybe_prefix from .utils import AutoWeightsLoader, PPMissingLayer, maybe_prefix
import vllm.envs as envs import vllm.envs as envs
from vllm.utils import direct_register_custom_op
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -129,6 +130,58 @@ class Qwen3Attention(nn.Module): ...@@ -129,6 +130,58 @@ class Qwen3Attention(nn.Module):
self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
def rms_rotary_embedding_fuse(
positions: torch.Tensor,
query: torch.Tensor,
key: Optional[torch.Tensor],
head_size: int,
cos_sin_cache: torch.Tensor,
is_neox_style: bool,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
q_bias: Optional[torch.Tensor],
k_bias: Optional[torch.Tensor],
epsilon: float,
) -> None:
from lightop import rms_rotary_embedding_fuse as fused_kernel
fused_kernel(
positions,
query,
key,
head_size,
cos_sin_cache,
is_neox_style,
q_weight,
k_weight,
q_bias,
k_bias,
epsilon,
)
def rms_rotary_embedding_fuse_fake(
positions: torch.Tensor,
query: torch.Tensor,
key: Optional[torch.Tensor],
head_size: int,
cos_sin_cache: torch.Tensor,
is_neox_style: bool,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
q_bias: Optional[torch.Tensor],
k_bias: Optional[torch.Tensor],
epsilon: float,
) -> None:
# Fake impl intentionally left as no-op for graph tracing modes.
pass
if not hasattr(torch.ops.vllm, "rms_rotary_embedding_fuse"):
direct_register_custom_op(
op_name="rms_rotary_embedding_fuse",
op_func=rms_rotary_embedding_fuse,
mutates_args=["query", "key"],
fake_impl=rms_rotary_embedding_fuse_fake,
)
def forward( def forward(
self, self,
positions: torch.Tensor, positions: torch.Tensor,
...@@ -136,22 +189,49 @@ class Qwen3Attention(nn.Module): ...@@ -136,22 +189,49 @@ class Qwen3Attention(nn.Module):
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
# Add qk-norm if envs.VLLM_USE_FUSED_RMS_ROPE:
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, # Fused RMSNorm + RoPE path through custom op.
self.head_dim) cos_sin_cache = self.rotary_emb.cos_sin_cache
if envs.VLLM_USE_APEX_RN: if (cos_sin_cache.device != q.device
q_by_head = self.q_norm.forward_apex(q_by_head) or cos_sin_cache.dtype != q.dtype):
else: cos_sin_cache = cos_sin_cache.to(q.device,
q_by_head = self.q_norm.forward_cuda(q_by_head) dtype=q.dtype,
q = q_by_head.view(q.shape) non_blocking=True)
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, # Persist the converted cache so we don't re-copy/re-allocate
self.head_dim) # on every forward when the original buffer starts on CPU.
if envs.VLLM_USE_APEX_RN: self.rotary_emb.cos_sin_cache = cos_sin_cache
k_by_head = self.k_norm.forward_apex(k_by_head) q = q.contiguous()
k = k.contiguous()
torch.ops.vllm.rms_rotary_embedding_fuse(
positions,
q,
k,
self.head_dim,
cos_sin_cache,
self.rotary_emb.is_neox_style,
self.q_norm.weight,
self.k_norm.weight,
None,
None,
self.q_norm.variance_epsilon,
)
else: else:
k_by_head = self.k_norm.forward_cuda(k_by_head) # Add qk-norm then RoPE (original path).
k = k_by_head.view(k.shape) q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
q, k = self.rotary_emb(positions, q, k) self.head_dim)
if envs.VLLM_USE_APEX_RN:
q_by_head = self.q_norm.forward_apex(q_by_head)
else:
q_by_head = self.q_norm.forward_cuda(q_by_head)
q = q_by_head.view(q.shape)
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
self.head_dim)
if envs.VLLM_USE_APEX_RN:
k_by_head = self.k_norm.forward_apex(k_by_head)
else:
k_by_head = self.k_norm.forward_cuda(k_by_head)
k = k_by_head.view(k.shape)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v) attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
......
...@@ -38,6 +38,19 @@ from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size ...@@ -38,6 +38,19 @@ from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
try:
from vllm.model_executor.layers.fused_moe.router_capture import (
maybe_record_router_logits,
)
except ImportError:
def maybe_record_router_logits(
*,
layer_name: str,
router_logits: torch.Tensor,
top_k: int,
) -> None:
return None
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
...@@ -111,6 +124,8 @@ class Qwen3MoeSparseMoeBlock(nn.Module): ...@@ -111,6 +124,8 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
): ):
super().__init__() super().__init__()
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self._router_top_k = int(config.num_experts_per_tok)
self._router_capture_layer_name = prefix
if self.tp_size > config.num_experts: if self.tp_size > config.num_experts:
raise ValueError( raise ValueError(
...@@ -140,6 +155,14 @@ class Qwen3MoeSparseMoeBlock(nn.Module): ...@@ -140,6 +155,14 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
# router_logits: (num_tokens, n_experts) # router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states) router_logits, _ = self.gate(hidden_states)
if not (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling()):
capture_enabled = envs.VLLM_MOE_ROUTER_CAPTURE
if capture_enabled:
maybe_record_router_logits(
layer_name=self._router_capture_layer_name,
router_logits=router_logits,
top_k=self._router_top_k,
)
final_hidden_states = self.experts(hidden_states=hidden_states, final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits) router_logits=router_logits)
...@@ -453,7 +476,7 @@ class Qwen3MoeModel(nn.Module): ...@@ -453,7 +476,7 @@ class Qwen3MoeModel(nn.Module):
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1' self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1' self.use_fa_pad = os.environ.get('FA_PAD') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1' self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1')) self.w8a8_strategy = envs.VLLM_W8A8_BACKEND
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
......
...@@ -37,6 +37,7 @@ from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper, ...@@ -37,6 +37,7 @@ from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
from vllm.utils import W8a8GetCacheJSON from vllm.utils import W8a8GetCacheJSON
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf from vllm.model_executor.utils import pad_weight, gemm_bank_conf
import vllm.envs as envs
class TeleChat2Model(LlamaModel): class TeleChat2Model(LlamaModel):
...@@ -66,8 +67,7 @@ class TeleChat2Model(LlamaModel): ...@@ -66,8 +67,7 @@ class TeleChat2Model(LlamaModel):
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1' self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1' self.use_fa_pad = os.environ.get('FA_PAD') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1' self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1')) self.w8a8_strategy = envs.VLLM_W8A8_BACKEND
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
......
...@@ -10,6 +10,7 @@ from torch.nn import Parameter ...@@ -10,6 +10,7 @@ from torch.nn import Parameter
from vllm.distributed import get_tensor_model_parallel_rank from vllm.distributed import get_tensor_model_parallel_rank
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.utils import _make_synced_weight_loader from vllm.model_executor.utils import _make_synced_weight_loader
from vllm.model_executor.layers.dp_attention import get_moe_tp_rank, get_moe_tp_size
__all__ = [ __all__ = [
"BasevLLMParameter", "PackedvLLMParameter", "PerTensorScaleParameter", "BasevLLMParameter", "PackedvLLMParameter", "PerTensorScaleParameter",
...@@ -97,6 +98,7 @@ class _ColumnvLLMParameter(BasevLLMParameter): ...@@ -97,6 +98,7 @@ class _ColumnvLLMParameter(BasevLLMParameter):
self._output_dim = output_dim self._output_dim = output_dim
super().__init__(**kwargs) super().__init__(**kwargs)
self.expect_tp_size = -1 self.expect_tp_size = -1
self.enable_dp_attn_moe = False
@property @property
...@@ -107,6 +109,10 @@ class _ColumnvLLMParameter(BasevLLMParameter): ...@@ -107,6 +109,10 @@ class _ColumnvLLMParameter(BasevLLMParameter):
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
if self.expect_tp_size == 1: if self.expect_tp_size == 1:
tp_rank = 0 tp_rank = 0
if self.enable_dp_attn_moe:
tp_rank = get_moe_tp_rank()
shard_size = self.data.shape[self.output_dim] shard_size = self.data.shape[self.output_dim]
loaded_weight = loaded_weight.narrow(self.output_dim, loaded_weight = loaded_weight.narrow(self.output_dim,
tp_rank * shard_size, shard_size) tp_rank * shard_size, shard_size)
...@@ -129,6 +135,10 @@ class _ColumnvLLMParameter(BasevLLMParameter): ...@@ -129,6 +135,10 @@ class _ColumnvLLMParameter(BasevLLMParameter):
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
if self.expect_tp_size == 1: if self.expect_tp_size == 1:
tp_rank = 0 tp_rank = 0
if self.enable_dp_attn_moe:
tp_rank = get_moe_tp_rank()
param_data = param_data.narrow(self.output_dim, shard_offset, param_data = param_data.narrow(self.output_dim, shard_offset,
shard_size) shard_size)
loaded_weight = loaded_weight.narrow(self.output_dim, loaded_weight = loaded_weight.narrow(self.output_dim,
...@@ -174,6 +184,7 @@ class RowvLLMParameter(BasevLLMParameter): ...@@ -174,6 +184,7 @@ class RowvLLMParameter(BasevLLMParameter):
self._input_dim = input_dim self._input_dim = input_dim
super().__init__(**kwargs) super().__init__(**kwargs)
self.expect_tp_size = -1 self.expect_tp_size = -1
self.enable_dp_attn_moe = False
@property @property
def input_dim(self): def input_dim(self):
...@@ -183,6 +194,9 @@ class RowvLLMParameter(BasevLLMParameter): ...@@ -183,6 +194,9 @@ class RowvLLMParameter(BasevLLMParameter):
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
if self.expect_tp_size == 1: if self.expect_tp_size == 1:
tp_rank = 0 tp_rank = 0
if self.enable_dp_attn_moe:
tp_rank = get_moe_tp_rank()
shard_size = self.data.shape[self.input_dim] shard_size = self.data.shape[self.input_dim]
loaded_weight = loaded_weight.narrow(self.input_dim, loaded_weight = loaded_weight.narrow(self.input_dim,
tp_rank * shard_size, shard_size) tp_rank * shard_size, shard_size)
......
...@@ -17,7 +17,7 @@ from vllm.utils import cuda_device_count_stateless ...@@ -17,7 +17,7 @@ from vllm.utils import cuda_device_count_stateless
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
from vllm.utils import SUPPORT_TC from vllm.utils import SUPPORT_TC
if not SUPPORT_TC: if not SUPPORT_TC:
os.environ['VLLM_USE_V1'] = '0' os.environ['VLLM_USE_V1'] = '0'
os.environ['VLLM_USE_FLASH_ATTN_PA'] = '0' os.environ['VLLM_USE_FLASH_ATTN_PA'] = '0'
......
...@@ -1958,7 +1958,7 @@ class W8a8GetCacheJSON: ...@@ -1958,7 +1958,7 @@ class W8a8GetCacheJSON:
self.moe_weight_shapes=[] self.moe_weight_shapes=[]
arch_name = torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] arch_name = torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0]
arch_cu = torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count arch_cu = torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count
self.cache_json_data = {}
device_name =arch_name+'_'+str(arch_cu)+'cu' device_name =arch_name+'_'+str(arch_cu)+'cu'
self.device_name=device_name self.device_name=device_name
self.topk=1 self.topk=1
...@@ -2060,19 +2060,27 @@ class W8a8GetCacheJSON: ...@@ -2060,19 +2060,27 @@ class W8a8GetCacheJSON:
return self.triton_json_dir+f"/linear_{n}_{k}_block[{block_n},{block_k}]_{self.device_name}.json" return self.triton_json_dir+f"/linear_{n}_{k}_block[{block_n},{block_k}]_{self.device_name}.json"
def get_moeint8json_name(self,E,N1,N2,K,TOPK, def get_moeint8json_name(self,E,N1,N2,K,TOPK,
block_size:Optional[list]=None,use_int4_w4a8:Optional[bool]=False): block_size:Optional[list]=None,use_int4_w4a8:Optional[bool]=False,use_int8_w8a8:Optional[bool]=False):
if use_int4_w4a8: if use_int4_w4a8:
if block_size is not None: if block_size is not None:
return self.triton_json_dir+f"/MOE_W4A8INT8[{block_size[0]},{block_size[1]}]_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json" return self.triton_json_dir+f"/MOE_W4A8INT8[{block_size[0]},{block_size[1]}]_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
else: else:
return self.triton_json_dir+f"/MOE_W4A8INT8_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json" return self.triton_json_dir+f"/MOE_W4A8INT8_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
elif use_int8_w8a8:
if block_size is not None:
return self.triton_json_dir + f"/MOE_BLOCKINT8[{block_size[0]},{block_size[1]}]_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
else:
return self.triton_json_dir + f"/MOE_W8A8INT8_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
else: else:
if block_size is not None: if block_size is not None:
return self.triton_json_dir+f"/MOE_BLOCKINT8[{block_size[0]},{block_size[1]}]_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json" return self.triton_json_dir + f"/MOE_BLOCKFP8[{block_size[0]},{block_size[1]}]_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
else: else:
return self.triton_json_dir+f"/MOE_W8A8INT8_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json" return self.triton_json_dir + f"/MOE_W8A8FP8_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
def get_moeint8_triton_cache(self,file_path,E,N1,N2,K,TOPK): def get_moeint8_triton_cache(self,file_path,E,N1,N2,K,TOPK):
if file_path in self.cache_json_data:
# 直接返回缓存数据,避免重复读取
return self.cache_json_data[file_path]
cache_json_file=file_path cache_json_file=file_path
if os.path.exists(file_path): if os.path.exists(file_path):
...@@ -2088,7 +2096,7 @@ class W8a8GetCacheJSON: ...@@ -2088,7 +2096,7 @@ class W8a8GetCacheJSON:
for sub_key, sub_value in value.items(): for sub_key, sub_value in value.items():
configs_key= f"{sub_key}_{key}" configs_key= f"{sub_key}_{key}"
configs_dict[configs_key]=sub_value configs_dict[configs_key]=sub_value
self.cache_json_data[file_path] = configs_dict
return configs_dict return configs_dict
# Adapted from: https://stackoverflow.com/a/47212782/5082708 # Adapted from: https://stackoverflow.com/a/47212782/5082708
......
...@@ -136,6 +136,17 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -136,6 +136,17 @@ class FlashAttentionBackend(AttentionBackend):
raise ValueError(f"Unknown cache layout format {cache_layout}.") raise ValueError(f"Unknown cache layout format {cache_layout}.")
return key_stride_order, value_stride_order return key_stride_order, value_stride_order
@staticmethod
def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype:
if kv_cache_dtype in ("fp8", "fp8_e4m3"):
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
return torch.float8_e4m3fn
else:
raise ValueError(f"Unsupported FP8 dtype: {kv_cache_dtype}")
elif kv_cache_dtype in ("fp8_e5m2"):
return torch.float8_e5m2
else:
raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
@dataclass @dataclass
class FlashAttentionMetadata: class FlashAttentionMetadata:
...@@ -589,14 +600,19 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -589,14 +600,19 @@ class FlashAttentionImpl(AttentionImpl):
) )
if self.kv_cache_dtype.startswith("fp8"): if self.kv_cache_dtype.startswith("fp8"):
key_cache = key_cache.view(torch.float8_e4m3fn) # key_cache = key_cache.view(torch.float8_e4m3fn)
value_cache = value_cache.view(torch.float8_e4m3fn) # value_cache = value_cache.view(torch.float8_e4m3fn)
num_tokens, num_heads, head_size = query.shape dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
query, _ = ops.scaled_fp8_quant( self.kv_cache_dtype)
query.reshape( key_cache = key_cache.view(dtype)
(num_tokens, num_heads * head_size)).contiguous(), value_cache = value_cache.view(dtype)
layer._q_scale) if envs.VLLM_USE_QUERY_QUANT:
query = query.reshape((num_tokens, num_heads, head_size)) num_tokens, num_heads, head_size = query.shape
query, _ = ops.scaled_fp8_quant(
query.reshape(
(num_tokens, num_heads * head_size)).contiguous(),
layer._q_scale)
query = query.reshape((num_tokens, num_heads, head_size))
# Compute attention and update output up to `num_actual_tokens`. # Compute attention and update output up to `num_actual_tokens`.
use_local_attn = \ use_local_attn = \
...@@ -620,9 +636,10 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -620,9 +636,10 @@ class FlashAttentionImpl(AttentionImpl):
block_table = attn_metadata.block_table block_table = attn_metadata.block_table
scheduler_metadata = attn_metadata.scheduler_metadata scheduler_metadata = attn_metadata.scheduler_metadata
descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1]) # descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
if not current_platform.is_rocm(): if not current_platform.is_rocm():
descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
flash_attn_varlen_func( flash_attn_varlen_func(
q=query[:num_actual_tokens], q=query[:num_actual_tokens],
k=key_cache, k=key_cache,
...@@ -672,6 +689,9 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -672,6 +689,9 @@ class FlashAttentionImpl(AttentionImpl):
# q_descale=layer._q_scale.expand(descale_shape), # q_descale=layer._q_scale.expand(descale_shape),
# k_descale=layer._k_scale.expand(descale_shape), # k_descale=layer._k_scale.expand(descale_shape),
# v_descale=layer._v_scale.expand(descale_shape), # v_descale=layer._v_scale.expand(descale_shape),
q_descale=None,
k_descale=layer._k_scale,
v_descale=layer._v_scale,
# num_splits=attn_metadata.max_num_splits, # num_splits=attn_metadata.max_num_splits,
is_prefix_cache=True, is_prefix_cache=True,
) )
...@@ -729,6 +749,9 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -729,6 +749,9 @@ class FlashAttentionImpl(AttentionImpl):
# q_descale=layer._q_scale, # q_descale=layer._q_scale,
# k_descale=layer._k_scale, # k_descale=layer._k_scale,
# v_descale=layer._v_scale, # v_descale=layer._v_scale,
q_descale=None,
k_descale=layer._k_scale,
v_descale=layer._v_scale,
) )
return output return output
...@@ -879,12 +902,12 @@ def cascade_attention( ...@@ -879,12 +902,12 @@ def cascade_attention(
return_softmax_lse=True, return_softmax_lse=True,
scheduler_metadata=prefix_scheduler_metadata, scheduler_metadata=prefix_scheduler_metadata,
# fa_version=fa_version, # fa_version=fa_version,
# q_descale=q_descale.expand(descale_shape) q_descale=q_descale.expand(descale_shape)
# if q_descale is not None else None, if q_descale is not None else None,
# k_descale=k_descale.expand(descale_shape) k_descale=k_descale.expand(descale_shape)
# if k_descale is not None else None, if k_descale is not None else None,
# v_descale=v_descale.expand(descale_shape) v_descale=v_descale.expand(descale_shape)
# if v_descale is not None else None, if v_descale is not None else None,
is_prefix_cache=True, is_prefix_cache=True,
) )
...@@ -932,12 +955,12 @@ def cascade_attention( ...@@ -932,12 +955,12 @@ def cascade_attention(
return_softmax_lse=True, return_softmax_lse=True,
scheduler_metadata=suffix_scheduler_metadata, scheduler_metadata=suffix_scheduler_metadata,
# fa_version=fa_version, # fa_version=fa_version,
# q_descale=q_descale.expand(descale_shape) q_descale=q_descale.expand(descale_shape)
# if q_descale is not None else None, if q_descale is not None else None,
# k_descale=k_descale.expand(descale_shape) k_descale=k_descale.expand(descale_shape)
# if k_descale is not None else None, if k_descale is not None else None,
# v_descale=v_descale.expand(descale_shape) v_descale=v_descale.expand(descale_shape)
# if v_descale is not None else None, if v_descale is not None else None,
is_prefix_cache=True, is_prefix_cache=True,
) )
......
...@@ -217,7 +217,7 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, ...@@ -217,7 +217,7 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata) CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.block_table import BlockTable
from lightop import fused_rms_norm_rope_contiguous from lightop import fused_rms_norm_rope_contiguous, fuse_rmsnorm_rope_quant_qkv
try: try:
from vllm.vllm_flash_attn import flash_attn_varlen_func from vllm.vllm_flash_attn import flash_attn_varlen_func
...@@ -1213,6 +1213,14 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1213,6 +1213,14 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
q_ori = q_ori[:num_actual_toks, ...] q_ori = q_ori[:num_actual_toks, ...]
decode_q = q_ori[:num_decode_tokens] decode_q = q_ori[:num_decode_tokens]
prefill_q = q_ori[num_decode_tokens:] prefill_q = q_ori[num_decode_tokens:]
if self.kv_cache_dtype == "auto":
if q.dtype == torch.float16:
kv_cache_dtype_str = "fp16"
elif q.dtype == torch.bfloat16:
kv_cache_dtype_str = "bf16"
else:
kv_cache_dtype_str = self.kv_cache_dtype
# write the latent and rope to kv cache # write the latent and rope to kv cache
if kv_cache.numel() > 0: if kv_cache.numel() > 0:
...@@ -1226,28 +1234,61 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1226,28 +1234,61 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
scale=layer._k_scale, scale=layer._k_scale,
) )
else: else:
if self.kv_cache_dtype == "auto": if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and kv_cache_dtype_str=="fp8_e4m3" and envs.VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA:
if q.dtype == torch.float16: if has_prefill:
kv_cache_dtype_str = "fp16" fused_rms_norm_rope_contiguous(
elif q.dtype == torch.bfloat16: positions[:num_actual_toks, ...],
kv_cache_dtype_str = "bf16" q,
k_pe.squeeze(1),
k_c_normed, # not normed
key_normed[:num_actual_toks, ...], # normed
weight,
cos_sin_cache,
attn_metadata.slot_mapping.flatten(),
kv_cache,
kv_cache_dtype_str,
1.0,
False,
1e-6,
)
else:
q_tensor = torch.randn(q.shape[0], num_local_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, dtype=q.dtype, device=q.device)
q_quant = torch.empty_like(q_tensor, dtype=torch.float8_e4m3fn, device=q.device)
q_scale = torch.empty(q.shape[0], dtype=torch.float32, device=q.device)
fuse_rmsnorm_rope_quant_qkv(
positions[:num_actual_toks, ...],
query_nope,
q,
q_quant,
q_scale,
k_pe.squeeze(1),
k_c_normed, # not normed
key_normed[:num_actual_toks, ...], # normed
weight,
cos_sin_cache,
attn_metadata.slot_mapping.flatten(),
kv_cache,
kv_cache_dtype_str,
1.0,
False,
1e-6,
)
else: else:
kv_cache_dtype_str = self.kv_cache_dtype fused_rms_norm_rope_contiguous(
fused_rms_norm_rope_contiguous( positions[:num_actual_toks, ...],
positions[:num_actual_toks, ...], q,
q, k_pe.squeeze(1),
k_pe.squeeze(1), k_c_normed, # not normed
k_c_normed, # not normed key_normed[:num_actual_toks, ...], # normed
key_normed[:num_actual_toks, ...], # normed weight,
weight, cos_sin_cache,
cos_sin_cache, attn_metadata.slot_mapping.flatten(),
attn_metadata.slot_mapping.flatten(), kv_cache,
kv_cache, kv_cache_dtype_str,
kv_cache_dtype_str, 1.0,
1.0, False,
False, 1e-6,
1e-6, )
)
if has_prefill: if has_prefill:
if envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT: if envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
...@@ -1259,12 +1300,15 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1259,12 +1300,15 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
if has_decode: if has_decode:
assert attn_metadata.decode is not None assert attn_metadata.decode is not None
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and self.kv_cache_dtype=="fp8_e4m3" and envs.VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA:
decode_q = q_quant[:num_decode_tokens]
decode_q_nope, decode_q_pe = decode_q.split( decode_q_nope, decode_q_pe = decode_q.split(
[self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
# Convert from (B, N, P) to (N, B, P) # Convert from (B, N, P) to (N, B, P)
decode_q_nope = decode_q_nope.transpose(0, 1) decode_q_nope = decode_q_nope.transpose(0, 1)
# Multiply (N, B, P) x (N, P, L) -> (N, B, L) # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T) # todo: bmm support
decode_ql_nope = torch.bmm(q_scale, decode_q_nope, self.W_UK_T) if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and self.kv_cache_dtype=="fp8_e4m3" and envs.VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA else torch.bmm(decode_q_nope, self.W_UK_T)
# Convert from (N, B, L) to (B, N, L) # Convert from (N, B, L) to (B, N, L)
decode_ql_nope = decode_ql_nope.transpose(0, 1) decode_ql_nope = decode_ql_nope.transpose(0, 1)
......
...@@ -12,6 +12,7 @@ from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, ...@@ -12,6 +12,7 @@ from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
flash_mla_with_kvcache_q_nope_pe, flash_mla_with_kvcache_q_nope_pe,
get_mla_metadata, get_mla_metadata,
flash_mla_with_kvcache_fp8, flash_mla_with_kvcache_fp8,
flash_mla_with_kvcache_fp8_with_cat,
get_mla_decoding_metadata_dense_fp8, get_mla_decoding_metadata_dense_fp8,
is_flashmla_supported) is_flashmla_supported)
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -181,31 +182,48 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -181,31 +182,48 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
assert attn_metadata.decode is not None assert attn_metadata.decode is not None
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and kv_cache_dtype == "fp8_e4m3" and envs.VLLM_USE_FLASH_MLA_FP8: if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and kv_cache_dtype == "fp8_e4m3" and envs.VLLM_USE_FLASH_MLA_FP8:
if envs.VLLM_USE_OPT_CAT: if envs.VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA:
if q_nope.shape[0] < 1024: o, _ = flash_mla_with_kvcache_fp8_with_cat(
from vllm.v1.attention.backends.mla.test_concat import concat_helper_decode q_nope=q_nope.unsqueeze(1),
q = concat_helper_decode(q_nope, q_pe, dim=2)\ q_pe=q_pe.unsqueeze(1),
.unsqueeze(1) k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens,
head_dim_v=self.kv_lora_rank,
tile_scheduler_metadata=attn_metadata.decode.
tile_scheduler_metadata,
num_splits=attn_metadata.decode.num_splits,
softmax_scale=self.scale,
causal=True,
descale_q = q_scale,
descale_k = k_scale,
)
else:
if envs.VLLM_USE_OPT_CAT:
if q_nope.shape[0] < 1024:
from vllm.v1.attention.backends.mla.test_concat import concat_helper_decode
q = concat_helper_decode(q_nope, q_pe, dim=2)\
.unsqueeze(1)
else:
q = torch.cat([q_nope, q_pe], dim=-1)\
.unsqueeze(1) # Add seqlen dim of 1 (decode)
else: else:
q = torch.cat([q_nope, q_pe], dim=-1)\ q = torch.cat([q_nope, q_pe], dim=-1)\
.unsqueeze(1) # Add seqlen dim of 1 (decode) .unsqueeze(1) # Add seqlen dim of 1 (decode)
else: o, _ = flash_mla_with_kvcache_fp8(
q = torch.cat([q_nope, q_pe], dim=-1)\ q=q.to(torch.float8_e4m3fn),
.unsqueeze(1) # Add seqlen dim of 1 (decode) k_cache=kv_c_and_k_pe_cache.unsqueeze(-2).view(torch.float8_e4m3fn), # Add head dim of 1
o, _ = flash_mla_with_kvcache_fp8( block_table=attn_metadata.decode.block_table,
q=q.to(torch.float8_e4m3fn), cache_seqlens=attn_metadata.decode.seq_lens,
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2).view(torch.float8_e4m3fn), # Add head dim of 1 head_dim_v=self.kv_lora_rank,
block_table=attn_metadata.decode.block_table, tile_scheduler_metadata=attn_metadata.decode.
cache_seqlens=attn_metadata.decode.seq_lens, tile_scheduler_metadata,
head_dim_v=self.kv_lora_rank, num_splits=attn_metadata.decode.num_splits,
tile_scheduler_metadata=attn_metadata.decode. softmax_scale=self.scale,
tile_scheduler_metadata, causal=True,
num_splits=attn_metadata.decode.num_splits, descale_q=q_scale,
softmax_scale=self.scale, descale_k=k_scale,
causal=True, )
descale_q=q_scale,
descale_k=k_scale,
)
else: else:
if not envs.VLLM_USE_CAT_MLA or kv_cache_dtype == "fp8_e4m3": if not envs.VLLM_USE_CAT_MLA or kv_cache_dtype == "fp8_e4m3":
......
...@@ -281,20 +281,27 @@ class Scheduler(SchedulerInterface): ...@@ -281,20 +281,27 @@ class Scheduler(SchedulerInterface):
num_draft_tokens=num_draft_tokens, num_draft_tokens=num_draft_tokens,
num_lookahead_tokens=self.num_lookahead_tokens) num_lookahead_tokens=self.num_lookahead_tokens)
if new_blocks is None: if new_blocks is None:
# The request cannot be scheduled. if new_blocks is None:
# Preempt the lowest-priority request. if self.use_pp:
if self.policy == SchedulingPolicy.PRIORITY: preemptable_reqs = [r for r in self.running if
preempted_req = max( r.num_tokens_with_spec != r.num_computed_tokens]
self.running, else:
key=lambda r: (r.priority, r.arrival_time), preemptable_reqs = self.running
) # The request cannot be scheduled.
# Preempt the lowest-priority request.
if self.policy == SchedulingPolicy.PRIORITY:
preempted_req = max(
preemptable_reqs,
key=lambda r: (r.priority, r.arrival_time),
)
else:
preempted_req = preemptable_reqs[-1]
self.running.remove(preempted_req) self.running.remove(preempted_req)
else:
preempted_req = self.running.pop()
self.kv_cache_manager.free(preempted_req) self.kv_cache_manager.free(preempted_req)
preempted_req.status = RequestStatus.PREEMPTED preempted_req.status = RequestStatus.PREEMPTED
preempted_req.num_computed_tokens = 0 preempted_req.num_computed_tokens = 0
preempted_req.spec_token_ids = []
if self.log_stats: if self.log_stats:
preempted_req.record_event( preempted_req.record_event(
EngineCoreEventType.PREEMPTED, scheduled_timestamp) EngineCoreEventType.PREEMPTED, scheduled_timestamp)
...@@ -901,20 +908,26 @@ class Scheduler(SchedulerInterface): ...@@ -901,20 +908,26 @@ class Scheduler(SchedulerInterface):
num_draft_tokens=num_draft_tokens, num_draft_tokens=num_draft_tokens,
num_lookahead_tokens=self.num_lookahead_tokens) num_lookahead_tokens=self.num_lookahead_tokens)
if new_blocks is None: if new_blocks is None:
if self.use_pp:
preemptable_reqs = [r for r in self.running if
r.num_tokens_with_spec != r.num_computed_tokens]
else:
preemptable_reqs = self.running
# The request cannot be scheduled. # The request cannot be scheduled.
# Preempt the lowest-priority request. # Preempt the lowest-priority request.
if self.policy == SchedulingPolicy.PRIORITY: if self.policy == SchedulingPolicy.PRIORITY:
preempted_req = max( preempted_req = max(
self.running, preemptable_reqs,
key=lambda r: (r.priority, r.arrival_time), key=lambda r: (r.priority, r.arrival_time),
) )
self.running.remove(preempted_req)
else: else:
preempted_req = self.running.pop() preempted_req = preemptable_reqs[-1]
self.running.remove(preempted_req)
self.kv_cache_manager.free(preempted_req) self.kv_cache_manager.free(preempted_req)
preempted_req.status = RequestStatus.PREEMPTED preempted_req.status = RequestStatus.PREEMPTED
preempted_req.num_computed_tokens = 0 preempted_req.num_computed_tokens = 0
preempted_req.spec_token_ids = []
if self.log_stats: if self.log_stats:
preempted_req.record_event( preempted_req.record_event(
EngineCoreEventType.PREEMPTED, scheduled_timestamp) EngineCoreEventType.PREEMPTED, scheduled_timestamp)
...@@ -1051,14 +1064,15 @@ class Scheduler(SchedulerInterface): ...@@ -1051,14 +1064,15 @@ class Scheduler(SchedulerInterface):
def schedule(self) -> SchedulerOutput: def schedule(self) -> SchedulerOutput:
if envs.VLLM_USE_PD_SPLIT: if envs.VLLM_USE_PD_SPLIT:
return self.schedule_split_pd() if self.use_mla:
else: if self.full_cuda_graph and self.num_spec_tokens > 0:
if self.connector is not None: return self.schedule_split_pd()
return self.schedule_default() else:
if self.full_cuda_graph and self.use_mla and self.num_spec_tokens > 0 : return self.schedule_default()
return self.schedule_split_pd()
else: else:
return self.schedule_default() return self.schedule_split_pd()
else:
return self.schedule_default()
def _update_after_schedule( def _update_after_schedule(
self, self,
...@@ -1101,13 +1115,14 @@ class Scheduler(SchedulerInterface): ...@@ -1101,13 +1115,14 @@ class Scheduler(SchedulerInterface):
req_id = req.request_id req_id = req.request_id
req_ids.append(req_id) req_ids.append(req_id)
num_tokens = req.num_generated_token_ids num_tokens = req.num_generated_token_ids
if self.use_pp: if self.use_pp:
# When using PP, the scheduler sends the sampled tokens back, # When using PP, the scheduler sends the sampled tokens back,
# because there's no direct communication between the first- # because there's no direct communication between the first-
# stage worker and the last-stage worker. Otherwise, we don't # stage worker and the last-stage worker. Otherwise, we don't
# need to send the sampled tokens back because the model runner # need to send the sampled tokens back because the model runner
# will cache them. # will cache them.
token_ids = req.all_token_ids[-num_tokens:] token_ids = req.all_token_ids[-num_tokens:] if num_tokens > 0 else []
new_token_ids.append(token_ids) new_token_ids.append(token_ids)
new_block_ids.append(req_to_new_block_ids[req_id]) new_block_ids.append(req_to_new_block_ids[req_id])
num_computed_tokens.append(req.num_computed_tokens) num_computed_tokens.append(req.num_computed_tokens)
...@@ -1241,7 +1256,7 @@ class Scheduler(SchedulerInterface): ...@@ -1241,7 +1256,7 @@ class Scheduler(SchedulerInterface):
scheduled_spec_token_ids = ( scheduled_spec_token_ids = (
scheduler_output.scheduled_spec_decode_tokens.get(req_id)) scheduler_output.scheduled_spec_decode_tokens.get(req_id))
request.num_generated_token_ids = 1 request.num_generated_token_ids = len(generated_token_ids)
if scheduled_spec_token_ids: if scheduled_spec_token_ids:
# num_computed_tokens represents the number of tokens # num_computed_tokens represents the number of tokens
# processed in the current step, considering scheduled # processed in the current step, considering scheduled
...@@ -1253,7 +1268,6 @@ class Scheduler(SchedulerInterface): ...@@ -1253,7 +1268,6 @@ class Scheduler(SchedulerInterface):
num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 - num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 -
len(generated_token_ids)) len(generated_token_ids))
request.num_computed_tokens -= num_tokens_rejected request.num_computed_tokens -= num_tokens_rejected
request.num_generated_token_ids = len(generated_token_ids)
spec_decoding_stats = self.make_spec_decoding_stats( spec_decoding_stats = self.make_spec_decoding_stats(
spec_decoding_stats, spec_decoding_stats,
num_draft_tokens=len(scheduled_spec_token_ids), num_draft_tokens=len(scheduled_spec_token_ids),
......
...@@ -12,8 +12,9 @@ from vllm.attention.layer import Attention ...@@ -12,8 +12,9 @@ from vllm.attention.layer import Attention
from vllm.config import (CompilationLevel, VllmConfig, from vllm.config import (CompilationLevel, VllmConfig,
get_layers_from_vllm_config) get_layers_from_vllm_config)
from vllm.distributed.parallel_state import get_pp_group from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import set_forward_context from vllm.forward_context import DPMetadata, set_forward_context, get_warming_up
from vllm.logger import init_logger from vllm.logger import init_logger
import vllm.envs as envs
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import supports_multimodal from vllm.model_executor.models import supports_multimodal
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
...@@ -92,6 +93,12 @@ class EagleProposer: ...@@ -92,6 +93,12 @@ class EagleProposer:
self.dp_size = vllm_config.parallel_config.data_parallel_size self.dp_size = vllm_config.parallel_config.data_parallel_size
self.enable_expert_parallel = vllm_config.parallel_config.enable_expert_parallel self.enable_expert_parallel = vllm_config.parallel_config.enable_expert_parallel
self.enable_dp_attention = vllm_config.parallel_config.enable_dp_attention
self.attn_tp_size = vllm_config.parallel_config.tensor_parallel_size
self.ep_sp = False
if self.enable_expert_parallel and self.dp_size > 1 and self.attn_tp_size > 1:
self.ep_sp = True
def propose( def propose(
self, self,
...@@ -187,6 +194,12 @@ class EagleProposer: ...@@ -187,6 +194,12 @@ class EagleProposer:
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
else: else:
num_input_tokens = num_tokens num_input_tokens = num_tokens
if self.enable_dp_attention:
num_input_tokens = round_up(num_input_tokens, self.attn_tp_size)
# num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
# num_input_tokens += num_pad
# copy inputs to buffer for cudagraph # copy inputs to buffer for cudagraph
self.positions[:num_tokens] = target_positions self.positions[:num_tokens] = target_positions
...@@ -275,6 +288,13 @@ class EagleProposer: ...@@ -275,6 +288,13 @@ class EagleProposer:
input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size) input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
else: else:
input_batch_size = batch_size input_batch_size = batch_size
# dp attention need all dp rank process same number tokens
if self.enable_dp_attention:
input_batch_size = round_up(input_batch_size, self.attn_tp_size)
num_pad, _ = self.get_dp_padding(input_batch_size)
input_batch_size += num_pad
attn_metadata.num_actual_tokens = batch_size attn_metadata.num_actual_tokens = batch_size
attn_metadata.max_query_len = 1 attn_metadata.max_query_len = 1
attn_metadata.query_start_loc = self.arange[:batch_size + 1] attn_metadata.query_start_loc = self.arange[:batch_size + 1]
...@@ -369,6 +389,7 @@ class EagleProposer: ...@@ -369,6 +389,7 @@ class EagleProposer:
attn_metadata.num_decode_tokens) attn_metadata.num_decode_tokens)
self.attn_metadata_cudagraph.num_prefills = ( self.attn_metadata_cudagraph.num_prefills = (
attn_metadata.num_prefills) attn_metadata.num_prefills)
self.attn_metadata_cudagraph.decode.seq_lens[:attn_metadata.num_decode_tokens] = ( self.attn_metadata_cudagraph.decode.seq_lens[:attn_metadata.num_decode_tokens] = (
attn_metadata.decode.seq_lens) attn_metadata.decode.seq_lens)
...@@ -516,34 +537,118 @@ class EagleProposer: ...@@ -516,34 +537,118 @@ class EagleProposer:
logger.info("Loading EAGLE LM head weights from the target model.") logger.info("Loading EAGLE LM head weights from the target model.")
self.model.lm_head = target_language_model.lm_head self.model.lm_head = target_language_model.lm_head
def get_dp_padding(self,
num_tokens: int) -> tuple[int, Optional[torch.Tensor]]:
dp_size = self.vllm_config.parallel_config.data_parallel_size
dp_rank = self.vllm_config.parallel_config.data_parallel_rank
# For DP: Don't pad when setting enforce_eager.
# This lets us set enforce_eager on the prefiller in a P/D setup and
# still use CUDA graphs (enabled by this padding) on the decoder.
#
# TODO(tms) : There are many cases where padding is enabled for
# prefills, causing unnecessary and excessive padding of activations.
if dp_size == 1 or self.vllm_config.model_config.enforce_eager:
# Early exit.
return 0, None
try:
num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
num_tokens, dp_size, dp_rank)
max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp).item()
num_tokens_after_padding = torch.tensor([max_tokens_across_dp_cpu] *
dp_size,
device="cpu",
dtype=torch.int32)
return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding
except (RuntimeError, AttributeError) as e:
# DP group may not be initialized yet during dummy run
# Skip padding in this case
logger.debug(
"Skipping DP padding in eagle get_dp_padding due to: %s", e)
return 0, None
@torch.inference_mode() @torch.inference_mode()
def dummy_run( def dummy_run(
self, self,
num_tokens: int, num_tokens: int,
attn_metadata: Optional[dict[str, Any]] = None, attn_metadata: Optional[dict[str, Any]] = None,
num_tokens_across_dp: Optional[torch.Tensor] = None,
) -> None: ) -> None:
if attn_metadata is not None and self.attn_metadata_cudagraph is None: if attn_metadata is not None and self.attn_metadata_cudagraph is None:
self.attn_metadata_cudagraph = attn_metadata[ self.attn_metadata_cudagraph = attn_metadata[
self.attn_layer_names[0]] self.attn_layer_names[0]]
# Padding for DP
num_input_tokens = num_tokens
# num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
# num_input_tokens += num_pad
with set_forward_context(attn_metadata, with set_forward_context(attn_metadata,
self.vllm_config, self.vllm_config,
num_tokens=num_tokens): num_tokens=num_tokens,
num_tokens_across_dp=num_tokens_across_dp):
self.model( self.model(
self.input_ids[:num_tokens], self.input_ids[:num_input_tokens],
self.positions[:num_tokens], self.positions[:num_input_tokens],
self.hidden_states[:num_tokens], self.hidden_states[:num_input_tokens],
) )
if self.dp_size > 1 and self.enable_expert_parallel and self.num_speculative_tokens > 1: if self.dp_size > 1 and (self.enable_expert_parallel or self.enable_dp_attention) and self.num_speculative_tokens > 1:
for _ in range(self.num_speculative_tokens - 1): num_tokens = 1
with set_forward_context(attn_metadata,
self.vllm_config, if self.enable_dp_attention or self.ep_sp:
num_tokens=num_tokens): num_tokens = round_up(num_tokens, self.attn_tp_size)
self.model( # dp attention need all dp rank process same number tokens
self.input_ids[:num_tokens], if self.enable_dp_attention:
self.positions[:num_tokens], num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
self.hidden_states[:num_tokens], num_tokens += num_pad
)
if not get_warming_up():
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=self.runner.query_start_loc[:num_tokens + 1],
seq_lens=self.runner.seq_lens[:num_tokens],
num_reqs=num_tokens,
num_actual_tokens=num_tokens,
max_query_len=num_tokens,
slot_mapping=self.runner.slot_mapping[:num_tokens],
spec_layer_decoding=True
)
assert self.runner is not None
# FIXME: need to consider multiple kv_cache_groups
attn_metadata = self.runner.attn_metadata_builders[0].build_for_cudagraph_capture(
common_attn_metadata=common_attn_metadata
)
for i in range(self.num_speculative_tokens - 1):
if self.attn_metadata_cudagraph is not None:
if i == 0:
attn_metadata_cudagraph = self.attn_metadata_cudagraph
attn_metadata_cudagraph.num_actual_tokens = num_tokens
attn_metadata_cudagraph.num_decodes = num_tokens
attn_metadata_cudagraph.num_decode_tokens = num_tokens
attn_metadata_cudagraph.slot_mapping[:num_tokens] = (
attn_metadata.slot_mapping)
attn_metadata_cudagraph.decode.seq_lens[:num_tokens] = (
attn_metadata.decode.seq_lens)
attn_metadata_cudagraph.query_start_loc[:num_tokens + 1] = (
attn_metadata.query_start_loc)
attn_metadata_cudagraph.decode.block_table[:num_tokens] = (
attn_metadata.decode.block_table)
with set_forward_context(attn_metadata,
self.vllm_config,
num_tokens=num_tokens,
num_tokens_across_dp=num_tokens_across_dp):
self.model(
self.input_ids[:num_tokens],
self.positions[:num_tokens],
self.hidden_states[:num_tokens],
)
def validate_same_kv_cache_group(self, def validate_same_kv_cache_group(self,
kv_cache_config: KVCacheConfig) -> None: kv_cache_config: KVCacheConfig) -> None:
......
...@@ -735,22 +735,35 @@ class InputBatch: ...@@ -735,22 +735,35 @@ class InputBatch:
self, repeat_counts: torch.Tensor self, repeat_counts: torch.Tensor
) -> SamplingMetadata: ) -> SamplingMetadata:
num_reqs = self.num_reqs num_reqs = self.num_reqs
repeat_counts_cpu = repeat_counts # `repeat_counts` is expected to be a CPU torch tensor, but some
# call sites may pass a NumPy array (or other array-likes). Normalize
# to a CPU tensor to keep downstream ops (e.g. repeat_interleave)
# consistent and avoid hard crashes.
if isinstance(repeat_counts, torch.Tensor):
repeat_counts_cpu = repeat_counts.to(device="cpu")
else:
repeat_counts_cpu = torch.as_tensor(repeat_counts, device="cpu")
all_greedy = self.all_greedy all_greedy = self.all_greedy
all_random = self.all_random all_random = self.all_random
# For reject-sampling optimization, force greedy sampling to keep # For reject-sampling optimization, force greedy sampling to keep
# rejection sampler assumptions (per-request shapes) intact. # rejection sampler assumptions (per-request shapes) intact.
def _expand_cpu_to_gpu( def _expand_cpu_to_gpu(
t: Optional[torch.Tensor], t: Optional[object],
*, *,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
if t is None: if t is None:
return None return None
base = t[:num_reqs] # `t` should be a CPU torch tensor, but can be a NumPy array view
if repeat_counts_cpu is not None: # (e.g. created via `tensor.numpy()`). Convert if needed.
base = base.repeat_interleave(repeat_counts_cpu, dim=0) if isinstance(t, torch.Tensor):
base = t[:num_reqs]
elif isinstance(t, np.ndarray):
base = torch.from_numpy(t[:num_reqs])
else:
base = torch.as_tensor(t, device="cpu")[:num_reqs]
base = base.repeat_interleave(repeat_counts_cpu, dim=0)
return base.to(device=self.device, return base.to(device=self.device,
dtype=dtype if dtype is not None else None, dtype=dtype if dtype is not None else None,
non_blocking=True) non_blocking=True)
......
...@@ -339,6 +339,8 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -339,6 +339,8 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
if self.enable_expert_parallel and self.dp_size > 1 and self.tp_size > 1: if self.enable_expert_parallel and self.dp_size > 1 and self.tp_size > 1:
self.ep_sp = True self.ep_sp = True
self.enable_dp_attention = self.parallel_config.enable_dp_attention
def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
""" """
Update the order of requests in the batch based on the attention Update the order of requests in the batch based on the attention
...@@ -512,14 +514,12 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -512,14 +514,12 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
new_token_ids = req_data.new_token_ids[i] new_token_ids = req_data.new_token_ids[i]
# Add the sampled token(s) from the previous step (if any). # Add the sampled token(s) from the previous step (if any).
# This doesn't include "unverified" tokens like spec tokens. # This doesn't include "unverified" tokens like spec tokens.
num_new_tokens = (num_computed_tokens + len(new_token_ids) - num_new_tokens = len(new_token_ids)
req_state.num_tokens)
if num_new_tokens == 1: if num_new_tokens == 1:
# Avoid slicing list in most common case.
req_state.output_token_ids.append(new_token_ids[-1]) req_state.output_token_ids.append(new_token_ids[-1])
elif num_new_tokens > 0: elif num_new_tokens > 0:
req_state.output_token_ids.extend( req_state.output_token_ids.extend(
new_token_ids[-num_new_tokens:]) new_token_ids)
if len(spec_token_ids) > 0: if len(spec_token_ids) > 0:
req_state.spec_token_ids = spec_token_ids req_state.spec_token_ids = spec_token_ids
...@@ -545,20 +545,23 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -545,20 +545,23 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# Update the persistent batch. # Update the persistent batch.
self.input_batch.num_computed_tokens_cpu[req_index] = ( self.input_batch.num_computed_tokens_cpu[req_index] = (
num_computed_tokens) num_computed_tokens)
self.input_batch.block_table.append_row(new_block_ids, req_index) if resumed_from_preemption:
self.input_batch.block_table.add_row(new_block_ids, req_index)
else:
self.input_batch.block_table.append_row(new_block_ids, req_index)
# For the last rank, we don't need to update the token_ids_cpu # For the last rank, we don't need to update the token_ids_cpu
# because the sampled tokens are already cached. # because the sampled tokens are already cached.
if not is_last_rank: if not is_last_rank:
# Add new_token_ids to token_ids_cpu. # Add new_token_ids to token_ids_cpu.
start_token_index = num_computed_tokens start_token_index = num_computed_tokens
end_token_index = num_computed_tokens + 1 if len(new_token_ids) > 0:
self.input_batch.token_ids_cpu[ end_token_index = num_computed_tokens + 1
req_index, self.input_batch.token_ids_cpu[
start_token_index:end_token_index] = new_token_ids[-1] req_index,
self.input_batch.num_tokens_no_spec[ start_token_index:end_token_index] = new_token_ids[-1]
req_index] = end_token_index self.input_batch.num_tokens_no_spec[
self.input_batch.num_tokens[req_index] = end_token_index req_index] = end_token_index
self.input_batch.num_tokens[req_index] = end_token_index
# Add spec_token_ids to token_ids_cpu. # Add spec_token_ids to token_ids_cpu.
if spec_token_ids: if spec_token_ids:
...@@ -1274,8 +1277,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -1274,8 +1277,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# #
# TODO(tms) : There are many cases where padding is enabled for # TODO(tms) : There are many cases where padding is enabled for
# prefills, causing unnecessary and excessive padding of activations. # prefills, causing unnecessary and excessive padding of activations.
if dp_size == 1 or self.vllm_config.model_config.enforce_eager:
if dp_size == 1 or self.vllm_config.model_config.enforce_eager or envs.VLLM_ALL2ALL_BACKEND != 'naive':
# Early exit. # Early exit.
return 0, None return 0, None
...@@ -1354,7 +1356,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -1354,7 +1356,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim. # make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
if self.ep_sp: if self.ep_sp or self.enable_dp_attention:
num_input_tokens = round_up(num_scheduled_tokens, self.tp_size) num_input_tokens = round_up(num_scheduled_tokens, self.tp_size)
if (self.use_cuda_graph if (self.use_cuda_graph
and num_input_tokens <= self.cudagraph_batch_sizes[-1]): and num_input_tokens <= self.cudagraph_batch_sizes[-1]):
...@@ -1597,6 +1599,11 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -1597,6 +1599,11 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
seq_len = (req_state.num_computed_tokens + seq_len = (req_state.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id]) scheduler_output.num_scheduled_tokens[req_id])
if seq_len < req_state.num_tokens: if seq_len < req_state.num_tokens:
# If we have already started decoding, seeing a "partial prefill"
# condition is suspicious and can lead to discarding the sampled
# token forever (PP stall).
if req_state.output_token_ids:
continue
# Ignore the sampled token for partial prefills. # Ignore the sampled token for partial prefills.
# Rewind the generator state as if the token was not sampled. # Rewind the generator state as if the token was not sampled.
# This relies on cuda-specific torch-internal impl details # This relies on cuda-specific torch-internal impl details
...@@ -1674,7 +1681,9 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -1674,7 +1681,9 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
spec_decode_metadata, spec_decode_metadata,
attn_metadata, attn_metadata,
) )
if spec_token_ids is not None:
for i in discard_sampled_tokens_req_indices:
spec_token_ids[i].clear()
# Clear KVConnector state after all KVs are generated. # Clear KVConnector state after all KVs are generated.
if has_kv_transfer_group(): if has_kv_transfer_group():
get_kv_transfer_group().clear_connector_metadata() get_kv_transfer_group().clear_connector_metadata()
...@@ -2084,7 +2093,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -2084,7 +2093,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
Randomize input_ids if VLLM_RANDOMIZE_DP_DUMMY_INPUTS is set. Randomize input_ids if VLLM_RANDOMIZE_DP_DUMMY_INPUTS is set.
This is to help balance expert-selection This is to help balance expert-selection
- during profile_run - during profile_run
- during DP rank dummy run - during DP rank dummy run
""" """
dp_size = self.vllm_config.parallel_config.data_parallel_size dp_size = self.vllm_config.parallel_config.data_parallel_size
randomize_inputs = envs.VLLM_RANDOMIZE_DP_DUMMY_INPUTS and dp_size > 1 randomize_inputs = envs.VLLM_RANDOMIZE_DP_DUMMY_INPUTS and dp_size > 1
...@@ -2115,13 +2124,12 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -2115,13 +2124,12 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
skip_eplb: bool = False, skip_eplb: bool = False,
is_profile: bool = False, is_profile: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim. # make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
if self.ep_sp: if self.ep_sp or self.enable_dp_attention:
if num_tokens < self.tp_size: if num_tokens < self.tp_size:
num_tokens = self.tp_size num_tokens = self.tp_size
# Padding for DP num_tokens_across_dp = 0
num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens) num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
num_tokens += num_pad num_tokens += num_pad
...@@ -2142,13 +2150,13 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -2142,13 +2150,13 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
min_tokens_per_req = (1 + self.speculative_config.num_lookahead_slots) min_tokens_per_req = (1 + self.speculative_config.num_lookahead_slots)
num_reqs = num_tokens // min_tokens_per_req num_reqs = num_tokens // min_tokens_per_req
if self.ep_sp: if self.ep_sp or self.enable_dp_attention:
num_actual_tokens = round_down(num_tokens, 1 + self.speculative_config.num_lookahead_slots) num_actual_tokens = round_down(num_tokens, 1 + self.speculative_config.num_lookahead_slots)
num_reqs = num_actual_tokens // min_tokens_per_req num_reqs = num_actual_tokens // min_tokens_per_req
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
if not self.ep_sp: if not (self.ep_sp or self.enable_dp_attention):
num_scheduled_tokens_list[-1] += num_tokens % num_reqs num_scheduled_tokens_list[-1] += num_tokens % num_reqs
else: else:
if self.speculative_config is not None: if self.speculative_config is not None:
...@@ -2240,7 +2248,8 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -2240,7 +2248,8 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
if self.speculative_config and self.speculative_config.use_eagle() and not is_profile: if self.speculative_config and self.speculative_config.use_eagle() and not is_profile:
#assert isinstance(self.drafter, EagleProposer) #assert isinstance(self.drafter, EagleProposer)
if hasattr(self, 'drafter') and isinstance(self.drafter, EagleProposer): if hasattr(self, 'drafter') and isinstance(self.drafter, EagleProposer):
self.drafter.dummy_run(num_tokens, attn_metadata) self.drafter.dummy_run(num_tokens, attn_metadata,
num_tokens_across_dp=num_tokens_across_dp)
# This is necessary to avoid blocking DP. # This is necessary to avoid blocking DP.
# For dummy runs, we typically skip EPLB since we don't have any real # For dummy runs, we typically skip EPLB since we don't have any real
...@@ -3213,7 +3222,7 @@ class GPUModelRunnerMTP(GPUModelRunnerBase): ...@@ -3213,7 +3222,7 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim. # make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
if self.ep_sp: if self.ep_sp or self.enable_dp_attention:
num_input_tokens = round_up(num_scheduled_tokens, self.tp_size) num_input_tokens = round_up(num_scheduled_tokens, self.tp_size)
if (self.use_cuda_graph if (self.use_cuda_graph
and num_input_tokens <= self.cudagraph_batch_sizes[-1]): and num_input_tokens <= self.cudagraph_batch_sizes[-1]):
...@@ -3455,6 +3464,11 @@ class GPUModelRunnerMTP(GPUModelRunnerBase): ...@@ -3455,6 +3464,11 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
seq_len = (req_state.num_computed_tokens + seq_len = (req_state.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id]) scheduler_output.num_scheduled_tokens[req_id])
if seq_len < req_state.num_tokens: if seq_len < req_state.num_tokens:
# If we have already started decoding, seeing a "partial prefill"
# condition is suspicious and can lead to discarding the sampled
# token forever (PP stall).
if req_state.output_token_ids:
continue
# Ignore the sampled token for partial prefills. # Ignore the sampled token for partial prefills.
# Rewind the generator state as if the token was not sampled. # Rewind the generator state as if the token was not sampled.
# This relies on cuda-specific torch-internal impl details # This relies on cuda-specific torch-internal impl details
...@@ -3476,7 +3490,6 @@ class GPUModelRunnerMTP(GPUModelRunnerBase): ...@@ -3476,7 +3490,6 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
hidden_states[:num_scheduled_tokens], hidden_states[:num_scheduled_tokens],
scheduler_output, scheduler_output,
) )
#-----------------------------------
# Get the valid generated tokens. # Get the valid generated tokens.
sampled_token_ids = sampler_output.sampled_token_ids sampled_token_ids = sampler_output.sampled_token_ids
max_gen_len = sampled_token_ids.shape[-1] max_gen_len = sampled_token_ids.shape[-1]
...@@ -3516,6 +3529,9 @@ class GPUModelRunnerMTP(GPUModelRunnerBase): ...@@ -3516,6 +3529,9 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
# Mask out the sampled tokens that should not be sampled. # Mask out the sampled tokens that should not be sampled.
for i in discard_sampled_tokens_req_indices: for i in discard_sampled_tokens_req_indices:
valid_sampled_token_ids[i].clear() valid_sampled_token_ids[i].clear()
if spec_token_ids is not None:
for i in discard_sampled_tokens_req_indices:
spec_token_ids[i].clear()
# Cache the sampled tokens in the model runner, so that the scheduler # Cache the sampled tokens in the model runner, so that the scheduler
# doesn't need to send them back. # doesn't need to send them back.
...@@ -3680,4 +3696,4 @@ class GPUModelRunnerMTP(GPUModelRunnerBase): ...@@ -3680,4 +3696,4 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
if envs.VLLM_USE_ZERO_MTP: if envs.VLLM_USE_ZERO_MTP:
GPUModelRunner=GPUModelRunnerMTP GPUModelRunner=GPUModelRunnerMTP
else: else:
GPUModelRunner=GPUModelRunnerBase GPUModelRunner=GPUModelRunnerBase
\ No newline at end of file
...@@ -17,6 +17,7 @@ from vllm.distributed import (ensure_model_parallel_initialized, ...@@ -17,6 +17,7 @@ from vllm.distributed import (ensure_model_parallel_initialized,
set_custom_all_reduce) set_custom_all_reduce)
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
from vllm.distributed.parallel_state import get_pp_group, get_tp_group from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.model_executor.layers.dp_attention import initialize_dp_attention
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed from vllm.model_executor import set_random_seed
...@@ -30,6 +31,7 @@ from vllm.v1.worker.gpu_model_runner import GPUModelRunner ...@@ -30,6 +31,7 @@ from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.v1.worker.worker_base import WorkerBase from vllm.v1.worker.worker_base import WorkerBase
from vllm.zero_overhead.utils import zero_overhead_stream from vllm.zero_overhead.utils import zero_overhead_stream
from vllm.zero_overhead.v1.gpu_model_runner import V1ZeroModelRunner from vllm.zero_overhead.v1.gpu_model_runner import V1ZeroModelRunner
from vllm.forward_context import (set_warming_up, get_warming_up)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -260,6 +262,7 @@ class Worker(WorkerBase): ...@@ -260,6 +262,7 @@ class Worker(WorkerBase):
# warm up sizes that are not in cudagraph capture sizes, # warm up sizes that are not in cudagraph capture sizes,
# but users still want to compile for better performance, # but users still want to compile for better performance,
# e.g. for the max-num-batched token size in chunked prefill. # e.g. for the max-num-batched token size in chunked prefill.
set_warming_up(True)
warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy() warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy()
if not self.model_config.enforce_eager: if not self.model_config.enforce_eager:
warmup_sizes = [ warmup_sizes = [
...@@ -297,6 +300,7 @@ class Worker(WorkerBase): ...@@ -297,6 +300,7 @@ class Worker(WorkerBase):
# Reset the seed to ensure that the random state is not affected by # Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling. # the model initialization and profiling.
set_random_seed(self.model_config.seed) set_random_seed(self.model_config.seed)
set_warming_up(False)
def get_model(self) -> nn.Module: def get_model(self) -> nn.Module:
return self.model_runner.get_model() return self.model_runner.get_model()
...@@ -399,6 +403,9 @@ def init_worker_distributed_environment( ...@@ -399,6 +403,9 @@ def init_worker_distributed_environment(
ensure_kv_transfer_initialized(vllm_config) ensure_kv_transfer_initialized(vllm_config)
if vllm_config.parallel_config.enable_dp_attention:
initialize_dp_attention(vllm_config, backend)
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
# Check if the GPU supports the dtype. # Check if the GPU supports the dtype.
......
...@@ -112,6 +112,12 @@ class V1ZeroEagleProposer(EagleProposer): ...@@ -112,6 +112,12 @@ class V1ZeroEagleProposer(EagleProposer):
else: else:
num_input_tokens = num_tokens num_input_tokens = num_tokens
if self.enable_dp_attention:
num_input_tokens = round_up(num_input_tokens, self.attn_tp_size)
# num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
# num_input_tokens += num_pad
# copy inputs to buffer for cudagraph # copy inputs to buffer for cudagraph
self.positions[:num_tokens] = target_positions self.positions[:num_tokens] = target_positions
self.hidden_states[:num_tokens] = target_hidden_states self.hidden_states[:num_tokens] = target_hidden_states
...@@ -199,6 +205,13 @@ class V1ZeroEagleProposer(EagleProposer): ...@@ -199,6 +205,13 @@ class V1ZeroEagleProposer(EagleProposer):
input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size) input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
else: else:
input_batch_size = batch_size input_batch_size = batch_size
# dp attention need all dp rank process same number tokens
if self.enable_dp_attention:
input_batch_size = round_up(input_batch_size, self.attn_tp_size)
num_pad, _ = self.get_dp_padding(input_batch_size)
input_batch_size += num_pad
attn_metadata.num_actual_tokens = batch_size attn_metadata.num_actual_tokens = batch_size
attn_metadata.max_query_len = 1 attn_metadata.max_query_len = 1
attn_metadata.query_start_loc = self.arange[:batch_size + 1] attn_metadata.query_start_loc = self.arange[:batch_size + 1]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment