Commit 347fc09c authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev-nmz' into v0.9.2-dev

parents ffcc47b7 3e191138
......@@ -16,7 +16,10 @@ from vllm.attention.backends.mla.common import (MLACommonBackend,
MLACommonState)
from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
get_mla_metadata,
flash_mla_with_kvcache_fp8,
get_mla_decoding_metadata_dense_fp8,
is_flashmla_supported)
from vllm import envs
if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
......@@ -93,7 +96,6 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
self.num_q_heads,
1, # MQA for the decode path
)
return m
......@@ -222,6 +224,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: FlashMLAMetadata,
q_scale = None,
k_scale = None,
kv_cache_dtype = "auto",
) -> torch.Tensor:
......@@ -233,18 +236,32 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
q = torch.cat([q_nope, q_pe], dim=-1)\
.unsqueeze(1) # Add seqlen dim of 1 (decode)
o, _ = flash_mla_with_kvcache(
q=q,
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
block_table=decode_meta.block_tables,
cache_seqlens=decode_meta.seq_lens_tensor,
head_dim_v=self.kv_lora_rank,
tile_scheduler_metadata=decode_meta.decode_tile_scheduler_metadata,
num_splits=decode_meta.decode_num_splits,
softmax_scale=self.scale,
causal=True,
k_scale = k_scale,
kv_cache_dtype = kv_cache_dtype,
)
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and kv_cache_dtype == "fp8_e4m3" and envs.VLLM_USE_FLASH_MLA_FP8:
o, _ = flash_mla_with_kvcache_fp8(
q=q.to(torch.float8_e4m3fn),
k_cache=kv_c_and_k_pe_cache.view(torch.float8_e4m3fn).unsqueeze(-2), # Add head dim of 1
block_table=decode_meta.block_tables,
cache_seqlens=decode_meta.seq_lens_tensor,
head_dim_v=self.kv_lora_rank,
tile_scheduler_metadata=decode_meta.decode_tile_scheduler_metadata,
num_splits=decode_meta.decode_num_splits,
softmax_scale=self.scale,
causal=True,
descale_q=q_scale,
descale_k=k_scale,
)
else:
o, _ = flash_mla_with_kvcache(
q=q,
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
block_table=decode_meta.block_tables,
cache_seqlens=decode_meta.seq_lens_tensor,
head_dim_v=self.kv_lora_rank,
tile_scheduler_metadata=decode_meta.decode_tile_scheduler_metadata,
num_splits=decode_meta.decode_num_splits,
softmax_scale=self.scale,
causal=True,
k_scale = k_scale,
kv_cache_dtype = kv_cache_dtype,
)
return self._v_up_proj(o)
......@@ -1404,6 +1404,5 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
decode_ql_nope = decode_ql_nope.transpose(0, 1)
output[num_prefill_tokens:] = self._forward_decode(
decode_ql_nope, decode_q_pe, kv_cache, attn_metadata, layer._k_scale, self.kv_cache_dtype)
decode_ql_nope, decode_q_pe, kv_cache, attn_metadata, layer._q_scale, layer._k_scale, self.kv_cache_dtype)
return output
\ No newline at end of file
......@@ -69,6 +69,27 @@ def get_mla_metadata(
num_heads_k)
def get_mla_decoding_metadata_dense_fp8(
cache_seqlens: torch.Tensor,
num_heads_per_head_k: int,
num_heads_k: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
cache_seqlens: (batch_size), dtype torch.int32.
num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k.
num_heads_k: num_heads_k.
Return:
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
dtype torch.int32.
num_splits: (batch_size + 1), dtype torch.int32.
"""
return flash_mla_cuda.get_mla_decoding_metadata_dense_fp8(cache_seqlens,
num_heads_per_head_k,
num_heads_k)
def flash_mla_with_kvcache(
q: torch.Tensor,
k_cache: torch.Tensor,
......@@ -199,6 +220,59 @@ def flash_mla_with_kvcache_q_nope_pe(
return out, softmax_lse
def flash_mla_with_kvcache_fp8(
q: torch.Tensor,
k_cache: torch.Tensor,
block_table: torch.Tensor,
cache_seqlens: torch.Tensor,
head_dim_v: int,
tile_scheduler_metadata: torch.Tensor,
num_splits: torch.Tensor,
softmax_scale: Optional[float] = None,
causal: bool = False,
descale_q: Optional[torch.Tensor] = None,
descale_k: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
q: (batch_size, seq_len_q, num_heads_q, head_dim).
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
cache_seqlens: (batch_size), torch.int32.
head_dim_v: Head_dim of v.
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
torch.int32, return by get_mla_decoding_metadata_dense_fp8.
num_splits: (batch_size + 1), torch.int32, return by get_mla_decoding_metadata_dense_fp8.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(head_dim).
causal: bool. Whether to apply causal attention mask.
descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization.
descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization.
Return:
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
"""
if softmax_scale is None:
softmax_scale = q.shape[-1]**(-0.5)
out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla_fp8(
q,
k_cache,
None,
head_dim_v,
cache_seqlens,
block_table,
softmax_scale,
causal,
tile_scheduler_metadata,
num_splits,
descale_q,
descale_k,
)
return out, softmax_lse
#
# TODO: Add fake functions
#
......
......@@ -146,6 +146,7 @@ if TYPE_CHECKING:
VLLM_USE_TRITON_PREFIX_FLASH_ATTN: bool = False
VLLM_USE_TRITON_OPT_MLA: bool = False
VLLM_USE_FLASH_MLA: bool = False
VLLM_USE_FLASH_MLA_FP8: bool = False
VLLM_USE_OPT_OP: bool = False
VLLM_USE_TC_PAGED_ATTN: bool = False
VLLM_USE_PA_PRINT_PARAM: bool = False
......@@ -1038,6 +1039,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_FLASH_MLA":
lambda: bool(int(os.getenv("VLLM_USE_FLASH_MLA", "1"))),
# If set, vLLM will use FLASH MLA fp8 attention optimizations.
"VLLM_USE_FLASH_MLA_FP8":
lambda: bool(int(os.getenv("VLLM_USE_FLASH_MLA_FP8", "0"))),
# flag to control vllm to use optimized kernels
"VLLM_USE_OPT_OP":
lambda: (os.environ.get("VLLM_USE_OPT_OP", "True").lower() in
......
......@@ -255,8 +255,8 @@ def get_model_architecture(
os.environ['VLLM_USE_CAT_MLA'] = '1'
if not envs.is_set("VLLM_REJECT_SAMPLE_OPT"):
os.environ['VLLM_REJECT_SAMPLE_OPT'] = '1'
# if not envs.is_set("VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"):
# os.environ['VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'] = '1'
if not envs.is_set("VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"):
os.environ['VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'] = '1'
if not envs.is_set("VLLM_SCHED_ENABLE_MINIMAL_INJECTION"):
os.environ['VLLM_SCHED_ENABLE_MINIMAL_INJECTION'] = '1'
if model_config.quantization in {"slimquant_w4a8", "slimquant_w4a8_marlin", "slimquant_compressed_tensors_marlin", "compressed-tensors"}:
......@@ -300,8 +300,8 @@ def get_model_architecture(
os.environ['VLLM_USE_CAT_MLA'] = '1'
if not envs.is_set("VLLM_REJECT_SAMPLE_OPT"):
os.environ['VLLM_REJECT_SAMPLE_OPT'] = '1'
# if not envs.is_set("VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"):
# os.environ['VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'] = '1'
if not envs.is_set("VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"):
os.environ['VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'] = '1'
if not envs.is_set("VLLM_SCHED_ENABLE_MINIMAL_INJECTION"):
os.environ['VLLM_SCHED_ENABLE_MINIMAL_INJECTION'] = '1'
if model_config.quantization in {"slimquant_w4a8", "slimquant_w4a8_marlin", "slimquant_compressed_tensors_marlin", "compressed-tensors"}:
......
......@@ -22,19 +22,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen3MoE model compatible with HuggingFace weights."""
from collections.abc import Iterable
import typing
from collections.abc import Callable, Iterable
from itertools import islice
from typing import Any, Optional, Union
import os
import re
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
from vllm.distributed import (get_ep_group, get_pp_group,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather)
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
......@@ -48,17 +51,17 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (AutoWeightsLoader, extract_layer_index,
from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
import vllm.envs as envs
from vllm.utils import direct_register_custom_op
from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
from vllm.utils import W8a8GetCacheJSON
......@@ -105,49 +108,86 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
vllm_config: VllmConfig,
prefix: str = "",
):
super().__init__()
config = vllm_config.model_config.hf_text_config
parallel_config = vllm_config.parallel_config
quant_config = vllm_config.quant_config
self.tp_size = get_tensor_model_parallel_world_size()
self.ep_group = get_ep_group().device_group
self.ep_rank = self.ep_group.rank()
self.ep_size = self.ep_group.size()
self.n_routed_experts = config.num_experts
self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
if self.tp_size > config.num_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {config.num_experts}.")
self.experts = FusedMoE(num_experts=config.num_experts,
# Load balancing settings.
vllm_config = get_current_vllm_config()
eplb_config = vllm_config.parallel_config.eplb_config
self.enable_eplb = parallel_config.enable_eplb
self.n_logical_experts = self.n_routed_experts
self.n_redundant_experts = eplb_config.num_redundant_experts
self.n_physical_experts = (self.n_logical_experts +
self.n_redundant_experts)
self.n_local_physical_experts = self.n_physical_experts // self.ep_size
self.physical_expert_start = (self.ep_rank *
self.n_local_physical_experts)
self.physical_expert_end = (self.physical_expert_start +
self.n_local_physical_experts)
self.experts = FusedMoE(num_experts=self.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
reduce_results=True,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
prefix=f"{prefix}.experts")
prefix=f"{prefix}.experts",
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
is_sequence_parallel=self.is_sequence_parallel)
self.gate = ReplicatedLinear(config.hidden_size,
config.num_experts,
bias=False,
quant_config=None,
quant_config=quant_config,
prefix=f"{prefix}.gate")
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# NOTE: hidden_states can have either 1D or 2D shape.
orig_shape = hidden_states.shape
hidden_dim = hidden_states.shape[-1]
assert hidden_states.dim(
) <= 2, "Qwen3MoeSparseMoeBlock only supports 1D or 2D inputs"
is_input_1d = hidden_states.dim() == 1
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
if self.is_sequence_parallel:
hidden_states = sequence_parallel_chunk(hidden_states)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
if self.tp_size > 1:
final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501
final_hidden_states)
if self.is_sequence_parallel:
final_hidden_states = tensor_model_parallel_all_gather(
final_hidden_states, 0)
final_hidden_states = final_hidden_states[:num_tokens]
return final_hidden_states.view(orig_shape)
# return to 1d if input is 1d
return final_hidden_states.squeeze(0) if is_input_1d else \
final_hidden_states
class Qwen3MoeAttention(nn.Module):
......@@ -166,6 +206,7 @@ class Qwen3MoeAttention(nn.Module):
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
dual_chunk_attention_config: Optional[dict[str, Any]] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
......@@ -189,6 +230,7 @@ class Qwen3MoeAttention(nn.Module):
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.dual_chunk_attention_config = dual_chunk_attention_config
self.qkv_proj = QKVParallelLinear(hidden_size,
self.head_dim,
......@@ -210,72 +252,25 @@ class Qwen3MoeAttention(nn.Module):
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
dual_chunk_attention_config=dual_chunk_attention_config,
)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
**{
"layer_idx": extract_layer_index(prefix),
"dual_chunk_attention_config": dual_chunk_attention_config,
} if dual_chunk_attention_config else {},
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn")
self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
def rms_rotary_embedding_fuse(
positions: torch.Tensor,
query: torch.Tensor,
key: Optional[torch.Tensor],
head_size: int,
cos_sin_cache: torch.Tensor,
is_neox_style: bool,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
q_bias: Optional[torch.Tensor],
k_bias: Optional[torch.Tensor],
epsilon: float,
) -> None:
from lightop import rms_rotary_embedding_fuse as fused_kernel
fused_kernel(
positions,
query,
key,
head_size,
cos_sin_cache,
is_neox_style,
q_weight,
k_weight,
q_bias,
k_bias,
epsilon,
)
def rms_rotary_embedding_fuse_fake(
# q_out:torch.Tensor,
# k_out:torch.Tensor,
positions: torch.Tensor,
query: torch.Tensor,
key: Optional[torch.Tensor],
head_size: int,
cos_sin_cache: torch.Tensor,
is_neox_style: bool,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
q_bias: Optional[torch.Tensor],
k_bias: Optional[torch.Tensor],
epsilon: float,
) -> None:
# Fake impl intentionally left as no-op for graph tracing modes.
pass
direct_register_custom_op(
op_name="rms_rotary_embedding_fuse",
op_func=rms_rotary_embedding_fuse,
mutates_args=["query", "key"],
fake_impl=rms_rotary_embedding_fuse_fake,
)
def forward(
self,
positions: torch.Tensor,
......@@ -283,52 +278,23 @@ class Qwen3MoeAttention(nn.Module):
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
if envs.VLLM_USE_FUSED_RMS_ROPE :
# Fused RMSNorm + RoPE path through custom op.
cos_sin_cache = self.rotary_emb.cos_sin_cache
if (cos_sin_cache.device != q.device
or cos_sin_cache.dtype != q.dtype):
cos_sin_cache = cos_sin_cache.to(q.device,
dtype=q.dtype,
non_blocking=True)
# Persist the converted cache so we don't re-copy/re-allocate
# on every forward when the original buffer starts on CPU.
self.rotary_emb.cos_sin_cache = cos_sin_cache
# # q, k 使用 continuous
q = q.contiguous()
k = k.contiguous()
torch.ops.vllm.rms_rotary_embedding_fuse(
positions,
q,
k,
self.head_dim,
cos_sin_cache,
self.rotary_emb.is_neox_style,
self.q_norm.weight,
self.k_norm.weight,
None,
None,
self.q_norm.variance_epsilon,
)
# Add qk-norm
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
self.head_dim)
if envs.VLLM_USE_APEX_RN:
q_by_head = self.q_norm.forward_apex(q_by_head)
else:
# Add qk-norm then RoPE (original path).
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
self.head_dim)
if envs.VLLM_USE_APEX_RN:
q_by_head = self.q_norm.forward_apex(q_by_head)
else:
q_by_head = self.q_norm.forward_cuda(q_by_head)
q = q_by_head.view(q.shape)
q_by_head = self.q_norm.forward_cuda(q_by_head)
q = q_by_head.view(q.shape)
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
self.head_dim)
if envs.VLLM_USE_APEX_RN:
k_by_head = self.k_norm.forward_apex(k_by_head)
else:
k_by_head = self.k_norm.forward_cuda(k_by_head)
k = k_by_head.view(k.shape)
q, k = self.rotary_emb(positions, q, k)
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
self.head_dim)
if envs.VLLM_USE_APEX_RN:
k_by_head = self.k_norm.forward_apex(k_by_head)
else:
k_by_head = self.k_norm.forward_cuda(k_by_head)
k = k_by_head.view(k.shape)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
......@@ -336,19 +302,21 @@ class Qwen3MoeAttention(nn.Module):
class Qwen3MoeDecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
config = vllm_config.model_config.hf_text_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
dual_chunk_attention_config = getattr(config,
"dual_chunk_attention_config",
None)
self.self_attn = Qwen3MoeAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
......@@ -362,6 +330,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
dual_chunk_attention_config=dual_chunk_attention_config,
)
# `mlp_only_layers` in the config.
......@@ -371,8 +340,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
if (layer_idx not in mlp_only_layers) and (
config.num_experts > 0 and
(layer_idx + 1) % config.decoder_sparse_step == 0):
self.mlp = Qwen3MoeSparseMoeBlock(config=config,
quant_config=quant_config,
self.mlp = Qwen3MoeSparseMoeBlock(vllm_config=vllm_config,
prefix=f"{prefix}.mlp")
else:
self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size,
......@@ -416,9 +384,11 @@ class Qwen3MoeModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
config = vllm_config.model_config.hf_text_config
quant_config = vllm_config.quant_config
parallel_config = vllm_config.parallel_config
eplb_config = parallel_config.eplb_config
self.num_redundant_experts = eplb_config.num_redundant_experts
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
......@@ -433,12 +403,11 @@ class Qwen3MoeModel(nn.Module):
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.embed_tokens")
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: Qwen3MoeDecoderLayer(config=config,
cache_config=cache_config,
quant_config=quant_config,
lambda prefix: Qwen3MoeDecoderLayer(vllm_config=vllm_config,
prefix=prefix),
prefix=f"{prefix}.layers",
)
......@@ -475,8 +444,7 @@ class Qwen3MoeModel(nn.Module):
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer(positions, hidden_states, residual)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
......@@ -486,6 +454,16 @@ class Qwen3MoeModel(nn.Module):
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
return FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.num_experts,
num_redundant_experts=self.num_redundant_experts)
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
......@@ -502,16 +480,9 @@ class Qwen3MoeModel(nn.Module):
".v_scale", "_v_scale", ".weight_scale",
"_weight_scale", ".input_scale", "_input_scale")
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.num_experts)
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
expert_params_mapping = self.get_expert_mapping()
for name, loaded_weight in weights:
if self.use_llama_nn:
current_count = loaded_weight.current_count
......@@ -537,35 +508,68 @@ class Qwen3MoeModel(nn.Module):
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
if name.endswith("scale"):
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
if weight_loader == default_weight_loader:
weight_loader(param, loaded_weight)
else:
weight_loader(param, loaded_weight, shard_id)
break
else:
is_expert_weight = False
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
# Anyway, this is an expert weight and should not be
# attempted to load as other weights later
is_expert_weight = True
# Do not modify `name` since the loop may continue here
# Instead, create a new variable
name_mapped = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name_mapped, self):
continue
# Skip loading extra parameters for GPTQ/modelopt models.
if name.endswith(
ignore_suffixes) and name not in params_dict:
if name_mapped.endswith(
ignore_suffixes
) and name_mapped not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=expert_id)
break
param = params_dict[name_mapped]
# We should ask the weight loader to return success or not
# here since otherwise we may skip experts with other
# available replicas.
weight_loader = typing.cast(Callable[..., bool],
param.weight_loader)
success = weight_loader(param,
loaded_weight,
name_mapped,
shard_id=shard_id,
expert_id=expert_id,
return_success=True)
if success:
name = name_mapped
break
else:
if is_expert_weight:
# We've checked that this is an expert weight
# However it's not mapped locally to this rank
# So we simply skip it
continue
# Skip loading extra parameters for GPTQ/modelopt models.
if name.endswith(
ignore_suffixes) and name not in params_dict:
......@@ -635,7 +639,8 @@ class Qwen3MoeModel(nn.Module):
return loaded_params
class Qwen3MoeForCausalLM(nn.Module, SupportsPP):
class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA,
MixtureOfExperts):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
......@@ -652,7 +657,7 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
config = vllm_config.model_config.hf_text_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
......@@ -660,13 +665,74 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP):
prefix=maybe_prefix(prefix, "model"))
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"))
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
# Set MoE hyperparameters
self.expert_weights = []
self.moe_layers: list[FusedMoE] = []
example_layer = None
for layer in self.model.layers:
if isinstance(layer, PPMissingLayer):
continue
assert isinstance(layer, Qwen3MoeDecoderLayer)
if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
example_layer = layer.mlp
self.moe_layers.append(layer.mlp.experts)
if example_layer is None:
raise RuntimeError("No Qwen3MoE layer found in the model.layers.")
self.num_moe_layers = len(self.moe_layers)
self.num_expert_groups = 1
self.num_shared_experts = 0
self.num_logical_experts = example_layer.n_logical_experts
self.num_physical_experts = example_layer.n_physical_experts
self.num_local_physical_experts = example_layer.n_local_physical_experts
self.num_routed_experts = example_layer.n_routed_experts
self.num_redundant_experts = example_layer.n_redundant_experts
def set_eplb_state(
self,
expert_load_view: torch.Tensor,
logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor,
) -> None:
for layer_idx, layer in enumerate(self.moe_layers):
# Register the expert weights.
self.expert_weights.append(layer.get_expert_weights())
layer.set_eplb_state(
moe_layer_idx=layer_idx,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
)
def update_physical_experts_metadata(
self,
num_physical_experts: int,
num_local_physical_experts: int,
) -> None:
assert self.num_local_physical_experts == num_local_physical_experts
self.num_physical_experts = num_physical_experts
self.num_local_physical_experts = num_local_physical_experts
self.num_redundant_experts = (num_physical_experts -
self.num_logical_experts)
for layer in self.model.layers:
if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
moe = layer.mlp
moe.n_local_physical_experts = num_local_physical_experts
moe.n_physical_experts = num_physical_experts
moe.n_redundant_experts = self.num_redundant_experts
moe.experts.update_expert_map()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
......@@ -684,13 +750,14 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP):
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
logits = self.logits_processor(self.lm_head, hidden_states)
return logits
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
return self.model.get_expert_mapping()
......@@ -1095,6 +1095,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
attn_metadata: M,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
query_nope: Optional[torch.Tensor] = None,
num_local_heads: Optional[int] = None,
q_ori: Optional[torch.Tensor] = None,
key_normed: Optional[torch.Tensor] = None,
positions: Optional[torch.Tensor] = None,
......@@ -1154,7 +1156,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
scale=layer._k_scale,
)
else:
from lightop import fused_rms_norm_rope_contiguous
if self.kv_cache_dtype == "auto":
if q.dtype == torch.float16:
kv_cache_dtype_str = "fp16"
......@@ -1162,7 +1163,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
kv_cache_dtype_str = "bf16"
else:
kv_cache_dtype_str = self.kv_cache_dtype
from lightop import fused_rms_norm_rope_contiguous
fused_rms_norm_rope_contiguous(
positions[:num_actual_toks, ...],
q,
......@@ -1199,6 +1200,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
decode_ql_nope = decode_ql_nope.transpose(0, 1)
output[:num_decode_tokens] = self._forward_decode(
decode_ql_nope, decode_q_pe, kv_cache, attn_metadata, layer._k_scale, self.kv_cache_dtype)
decode_ql_nope, decode_q_pe, kv_cache, attn_metadata, layer._q_scale, layer._k_scale, self.kv_cache_dtype)
return output_padded
\ No newline at end of file
......@@ -11,6 +11,8 @@ from vllm.attention.backends.abstract import (AttentionType,
from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
flash_mla_with_kvcache_q_nope_pe,
get_mla_metadata,
flash_mla_with_kvcache_fp8,
get_mla_decoding_metadata_dense_fp8,
is_flashmla_supported)
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
......@@ -162,13 +164,14 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: FlashMLAMetadata,
q_scale = None,
k_scale = None,
kv_cache_dtype = "auto",
) -> torch.Tensor:
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None
if not envs.VLLM_USE_CAT_MLA or kv_cache_dtype == "fp8_e4m3":
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and kv_cache_dtype == "fp8_e4m3" and envs.VLLM_USE_FLASH_MLA_FP8:
if envs.VLLM_USE_OPT_CAT:
if q_nope.shape[0] < 1024:
from vllm.v1.attention.backends.mla.test_concat import concat_helper_decode
......@@ -180,11 +183,10 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
else:
q = torch.cat([q_nope, q_pe], dim=-1)\
.unsqueeze(1) # Add seqlen dim of 1 (decode)
if not envs.VLLM_USE_CAT_MLA or kv_cache_dtype == "fp8_e4m3":
o, _ = flash_mla_with_kvcache(
q=q,
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
o, _ = flash_mla_with_kvcache_fp8(
q=q.to(torch.float8_e4m3fn),
k_cache=kv_c_and_k_pe_cache.view(torch.float8_e4m3fn).unsqueeze(-2), # Add head dim of 1
block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens,
head_dim_v=self.kv_lora_rank,
......@@ -193,24 +195,54 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
num_splits=attn_metadata.decode.num_splits,
softmax_scale=self.scale,
causal=True,
k_scale = k_scale,
kv_cache_dtype = kv_cache_dtype,
descale_q=q_scale,
descale_k=k_scale,
)
else:
o, _ = flash_mla_with_kvcache_q_nope_pe(
q_nope=q_nope.unsqueeze(1),
q_pe=q_pe.unsqueeze(1),
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens,
head_dim_v=self.kv_lora_rank,
tile_scheduler_metadata=attn_metadata.decode.
tile_scheduler_metadata,
num_splits=attn_metadata.decode.num_splits,
softmax_scale=self.scale,
causal=True,
k_scale = k_scale,
kv_cache_dtype = kv_cache_dtype,
)
if not envs.VLLM_USE_CAT_MLA or kv_cache_dtype == "fp8_e4m3":
if envs.VLLM_USE_OPT_CAT:
if q_nope.shape[0] < 1024:
from vllm.v1.attention.backends.mla.test_concat import concat_helper_decode
q = concat_helper_decode(q_nope, q_pe, dim=2)\
.unsqueeze(1)
else:
q = torch.cat([q_nope, q_pe], dim=-1)\
.unsqueeze(1) # Add seqlen dim of 1 (decode)
else:
q = torch.cat([q_nope, q_pe], dim=-1)\
.unsqueeze(1) # Add seqlen dim of 1 (decode)
if not envs.VLLM_USE_CAT_MLA or kv_cache_dtype == "fp8_e4m3":
o, _ = flash_mla_with_kvcache(
q=q,
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens,
head_dim_v=self.kv_lora_rank,
tile_scheduler_metadata=attn_metadata.decode.
tile_scheduler_metadata,
num_splits=attn_metadata.decode.num_splits,
softmax_scale=self.scale,
causal=True,
k_scale = k_scale,
kv_cache_dtype = kv_cache_dtype,
)
else:
o, _ = flash_mla_with_kvcache_q_nope_pe(
q_nope=q_nope.unsqueeze(1),
q_pe=q_pe.unsqueeze(1),
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens,
head_dim_v=self.kv_lora_rank,
tile_scheduler_metadata=attn_metadata.decode.
tile_scheduler_metadata,
num_splits=attn_metadata.decode.num_splits,
softmax_scale=self.scale,
causal=True,
k_scale = k_scale,
kv_cache_dtype = kv_cache_dtype,
)
return self._v_up_proj(o)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment