Commit 02689420 authored by xuxz's avatar xuxz
Browse files

Merge branch 'v0.9.2-dev' into 'v0.9.2-dev-add_connector'

# Conflicts:
#   vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py
parents ef362942 fa683b07
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"2": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"num_ldmatrixes": 1
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"num_ldmatrixes": 1
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"num_ldmatrixes": 1
},
"48": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 16,
"num_warps": 2,
"num_stages": 2,
"num_ldmatrixes": 1
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"num_ldmatrixes": 1
},
"256": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"512": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"2": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"512": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
}
}
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.forward_context import get_forward_context
class DeepEPAutoPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
"""
Auto Prepare/Finalize that wraps both DeepEP High-Throughput and
Low-Latency implementations and selects one based on prefill/decode phase.
"""
def __init__(self,
ht_prepare_finalize: mk.FusedMoEPrepareAndFinalize,
ll_prepare_finalize: mk.FusedMoEPrepareAndFinalize):
super().__init__()
self.ht_prepare_finalize = ht_prepare_finalize
self.ll_prepare_finalize = ll_prepare_finalize
self._current_phase = "decode" # default to decode (LL)
def _get_current_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize:
"""Get the appropriate prepare_finalize based on current phase."""
# Try to infer phase from forward_context if available:
# try:
# forward_context = get_forward_context()
# attn_metadata = forward_context.attn_metadata
# # Handle both v0 (single AttentionMetadata) and v1 (dict) formats
# if isinstance(attn_metadata, dict):
# if attn_metadata:
# attn_metadata = next(iter(attn_metadata.values()))
# else:
# attn_metadata = None
# if attn_metadata is not None and hasattr(attn_metadata,
# "num_decode_tokens"):
# # 只根据 decode tokens 判定:有 decode -> decode,否则 prefill
# self._current_phase = ("decode"
# if attn_metadata.num_decode_tokens > 0
# else "prefill")
# except Exception:
# # If forward_context is not available, use stored phase
# pass
# Prefill uses HT, decode uses LL
if self._current_phase == "prefill":
#rint("************prefill***********")
return self.ll_prepare_finalize
else:
# print("attn_metadata.num_decode_tokens",attn_metadata.num_decode_tokens)
return self.ht_prepare_finalize
#return self.ht_prepare_finalize
@property
def activation_format(self) -> mk.FusedMoEActivationFormat:
pf = self._get_current_prepare_finalize()
try:
return pf.activation_format
except NotImplementedError:
# Fallback to standard format if underlying impl does not provide it.
return mk.FusedMoEActivationFormat.Standard
def topk_indices_dtype(self) -> Optional[torch.dtype]:
pf = self._get_current_prepare_finalize()
return pf.topk_indices_dtype()
def max_num_tokens_per_rank(self) -> Optional[int]:
pf = self._get_current_prepare_finalize()
return pf.max_num_tokens_per_rank()
def num_dispatchers(self) -> int:
pf = self._get_current_prepare_finalize()
return pf.num_dispatchers()
def prepare_async(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
):
pf = self._get_current_prepare_finalize()
return pf.prepare_async(
a1, a1_scale, a2_scale, topk_weights, topk_ids,
num_experts, expert_map, apply_router_weight_on_input, quant_config)
def prepare(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
Optional[torch.Tensor], Optional[torch.Tensor]]:
pf = self._get_current_prepare_finalize()
return pf.prepare(
a1, a1_scale, a2_scale, topk_weights, topk_ids,
num_experts, expert_map, apply_router_weight_on_input, quant_config)
def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
apply_weights_and_reduce: bool = True
) -> None:
pf = self._get_current_prepare_finalize()
return pf.finalize(
output, fused_expert_output, topk_weights, topk_ids,
apply_router_weight_on_input, apply_weights_and_reduce)
def finalize_async(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
apply_weights_and_reduce: bool = True
):
pf = self._get_current_prepare_finalize()
if hasattr(pf, "finalize_async"):
return pf.finalize_async(
output, fused_expert_output, topk_weights, topk_ids,
apply_router_weight_on_input, apply_weights_and_reduce)
return pf.finalize(
output, fused_expert_output, topk_weights, topk_ids,
apply_router_weight_on_input, apply_weights_and_reduce)
......@@ -214,8 +214,6 @@ def moe_align_block_size_lightop(
def fused_experts_impl_w16a16_marlin(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_marlin: torch.Tensor,
w2_marlin: torch.Tensor,
topk_weights: torch.Tensor,
......@@ -234,21 +232,34 @@ def fused_experts_impl_w16a16_marlin(hidden_states: torch.Tensor,
):
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
assert w1_marlin.is_contiguous(), "Packed weights1 must be contiguous"
assert w2_marlin.is_contiguous(), "Packed weights2 must be contiguous"
# 当前只支持 bf16 fp16
assert hidden_states.dtype in [torch.bfloat16,torch.float16]
compute_type = hidden_states.dtype
assert use_lightop, (
"only BW and set LMSLIM_USE_LIGHTOP=1 support Marlin W16A16 MoE")
"only BW and set VLLM_USE_LIGHTOP=1 support Marlin W16A16 MoE")
num_tokens, K = hidden_states.shape
E, twoN, K_w1 = w1.shape
# Packed weights store the same number of elements as the original layout,
# but reshaped/reordered for Marlin kernels:
# - w1_marlin: [E, K/16, (2N)*16]
# - w2_marlin: [E, N/16, K*16]
E, k_div16, twoN_times16 = w1_marlin.shape
K_w1 = k_div16 * 16
assert K_w1 == K, f"w1_marlin K mismatch: {K_w1} vs {K}"
assert twoN_times16 % 16 == 0
twoN = twoN_times16 // 16
assert twoN % 2 == 0
N = twoN // 2
E2, K_w2, N2_w2 = w2.shape
E2, n_div16, k_times16 = w2_marlin.shape
assert E2 == E, f"w2_marlin E mismatch: {E2} vs {E}"
K_w2 = k_times16 // 16
assert K_w2 == K, f"w2_marlin K mismatch: {K_w2} vs {K}"
assert n_div16 * 16 == N, f"w2_marlin N mismatch: {n_div16 * 16} vs {N}"
if global_num_experts == -1:
global_num_experts = E
......
......@@ -5,7 +5,7 @@ import functools
import json
import os
import math
from typing import Any, Callable, Dict, Optional, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional
import torch
......@@ -13,6 +13,7 @@ import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.logger import init_logger
logger = init_logger(__name__)
# yapf: disable
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, get_config_quant_dtype)
......@@ -31,7 +32,7 @@ try:
from lmslim.layers.fused_moe.fuse_moe_int8 import (fused_experts_impl_int8, get_w8a8moe_json)
from lmslim.layers.fused_moe.fuse_moe_w4a8 import fused_experts_impl_w4a8
except Exception:
print("INFO: Please install lmslim if you want to infer the quantitative model of moe.\n")
logger.warning_once("Please install lmslim if you want to infer the quantitative model of moe.")
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP)
......@@ -43,32 +44,9 @@ from vllm.utils import direct_register_custom_op
# from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
logger = init_logger(__name__)
if envs.VLLM_USE_GLOBAL_CACHE13:
moe_cache_singleton = None
# Cache Marlin-packed weights so we only reorder once per weight tensor.
_marlin_weight_cache: Dict[Tuple[int, torch.device, torch.dtype, torch.Size], torch.Tensor] = {}
# Cache packed W16A16 Marlin weights by parameter identity so we can offload
# original layouts from GPU without losing the packed copies.
_w16a16_marlin_weight_cache: Dict[int, Tuple[torch.Tensor, torch.Tensor]] = {}
def _get_marlin_packed_weight(weight: torch.Tensor,
pack_fn: Callable[[torch.Tensor], torch.Tensor]
) -> torch.Tensor:
key = (weight.data_ptr(), weight.device, weight.dtype, weight.shape)
cached = _marlin_weight_cache.get(key)
if cached is not None:
return cached
# Marlin packing is done per expert and reshaped back to original dims.
packed = torch.stack([pack_fn(weight[i]).contiguous()
for i in range(weight.shape[0])],
dim=0)
_marlin_weight_cache[key] = packed
return packed
arch_name = torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0]
arch_cu = torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count
......@@ -1247,14 +1225,14 @@ def vllm_topk_softmax(topk_weights: torch.Tensor, topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor,
gating_output: torch.Tensor,
renormalize: bool) -> tuple[torch.Tensor, ...]:
if envs.VLLM_USE_TOPK_RENORM:
if envs.VLLM_USE_TOPK_RENORM and renormalize is True:
from lightop import op as op
op.topk_softmax(
topk_weights,
topk_indices,
token_expert_indices,
gating_output,
True,
renormalize,
)
else:
ops.topk_softmax(
......@@ -1326,7 +1304,9 @@ def grouped_topk(
num_expert_group: int = 0,
topk_group: int = 0,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None
e_score_correction_bias: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = 1.0,
num_fused_shared_experts: Optional[int] = 0
) -> tuple[torch.Tensor, torch.Tensor]:
assert hidden_states.size(0) == gating_output.size(0), (
......@@ -1339,7 +1319,7 @@ def grouped_topk(
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")
num_token = scores.size(0)
num_token, num_experts = scores.shape
if e_score_correction_bias is not None:
# Store original scores before applying correction bias. We use biased
# scores for expert selection but original scores for routing weights
......@@ -1370,8 +1350,20 @@ def grouped_topk(
dim=-1,
sorted=False)
if num_fused_shared_experts != 0:
topk_ids[:, -1] = num_experts
if routed_scaling_factor is not None:
topk_weights[:, -1] = (
topk_weights[:, :-1].sum(dim=-1) / routed_scaling_factor
)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
topk_weights_sum = (
topk_weights.sum(dim=-1, keepdim=True)
if num_fused_shared_experts == 0
else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
)
topk_weights = topk_weights / topk_weights_sum
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
......@@ -1696,6 +1688,102 @@ def fused_experts_impl(
i_s: Optional[torch.Tensor] = None, **_
) -> torch.Tensor:
num_tokens = hidden_states.size(0)
top_k_num = topk_ids.size(1)
# We execute the fused_moe kernel in chunks to circumvent this issue:
# https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
M = min(num_tokens, CHUNK_SIZE)
# Optional fast path: use Marlin W16A16 fused MoE implementation when the
# expert weights are already packed in Marlin layout.
if not use_nn_moe:
K = hidden_states.size(1)
def _is_marlin_w16a16_packed(w1: torch.Tensor,
w2: torch.Tensor) -> bool:
if w1.dim() != 3 or w2.dim() != 3:
return False
if w1.size(0) != w2.size(0):
return False
k_div16 = w1.size(1)
if k_div16 * 16 != K:
return False
if w1.size(2) % 16 != 0:
return False
twoN = w1.size(2) // 16
if twoN % 2 != 0:
return False
N = twoN // 2
if w2.size(2) != K * 16:
return False
if w2.size(1) * 16 != N:
return False
return True
is_packed = (getattr(w1, "marlin_w16a16_packed", False)
or getattr(w2, "marlin_w16a16_packed", False)
or _is_marlin_w16a16_packed(w1, w2))
if is_packed:
if envs.VLLM_USE_MOE_W16A16_TRITON:
raise RuntimeError(
"VLLM_USE_MOE_W16A16_TRITON=1 forces Triton MoE, but the MoE weights are "
"packed in Marlin W16A16 layout. Please load unpacked weights or set "
"VLLM_USE_MOE_W16A16_TRITON=0.")
try:
from vllm.model_executor.layers.fused_moe.fuse_moe_w16a16_marlin import ( # noqa: E501
fused_experts_impl_w16a16_marlin)
except Exception:
fused_experts_impl_w16a16_marlin = None # type: ignore
if fused_experts_impl_w16a16_marlin is None:
raise RuntimeError(
"Marlin W16A16 MoE weights are packed, but the Marlin kernel is unavailable. "
"Ensure lightop is installed and VLLM_USE_LIGHTOP=1."
)
if activation != "silu":
raise RuntimeError(
"Marlin W16A16 MoE only supports activation='silu'.")
if apply_router_weight_on_input:
raise RuntimeError(
"Marlin W16A16 MoE does not support apply_router_weight_on_input=True."
)
E = w1.size(0)
if global_num_experts == -1:
global_num_experts = E
twoN = w1.size(2) // 16
if envs.VLLM_USE_GLOBAL_CACHE13:
cache13 = get_moe_cache(top_k_num,
twoN,
K,
device=hidden_states.device,
dtype=hidden_states.dtype)
else:
cache13 = torch.empty(M * top_k_num * max(twoN, K),
device=hidden_states.device,
dtype=hidden_states.dtype)
return fused_experts_impl_w16a16_marlin(
hidden_states=hidden_states,
w1_marlin=w1,
w2_marlin=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
cache13=cache13,
inplace=inplace,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
use_nn_moe=False,
routed_scaling_factor=routed_scaling_factor,
shared_output=shared_output,
)
# Non-Marlin paths need the original weight shapes.
if use_nn_moe:
E, _, N = w1.size()
else:
......@@ -1704,69 +1792,20 @@ def fused_experts_impl(
if global_num_experts == -1:
global_num_experts = E
top_k_num = topk_ids.size(1)
# We execute the fused_moe kernel in chunks to circumvent this issue:
# https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
M = min(num_tokens, CHUNK_SIZE)
if envs.VLLM_USE_GLOBAL_CACHE13:
cache13 = get_moe_cache(top_k_num, N,K if not use_nn_moe else w2.shape[2], device=hidden_states.device, dtype=hidden_states.dtype)
cache13 = get_moe_cache(top_k_num,
N,
K if not use_nn_moe else w2.shape[2],
device=hidden_states.device,
dtype=hidden_states.dtype)
else:
cache13 = torch.empty(M * top_k_num * max(N, K if not use_nn_moe else w2.shape[2]), device=hidden_states.device, dtype=hidden_states.dtype)
cache13 = torch.empty(
M * top_k_num * max(N, K if not use_nn_moe else w2.shape[2]),
device=hidden_states.device,
dtype=hidden_states.dtype)
# Optional fast path: use lmslim's Marlin W16A16 fused MoE implementation
# when explicitly requested. This reuses the same cache13 buffer as other
# fused paths for consistency.
from vllm.model_executor.layers.fused_moe.fuse_moe_w16a16_marlin import fused_experts_impl_w16a16_marlin
if (envs.VLLM_USE_MARLIN_W16A16_MOE
and fused_experts_impl_w16a16_marlin is not None):
# Only pack when shapes match the expected [E, 2N, K] / [E, K, N/2] contract.
# If shapes are unexpected, skip packing and fall back to non-Marlin paths below.
from vllm.model_executor.layers.fused_moe.marlin_quant import w16a16_marlin_weight
cache_key = id(w1)
cached_marlin = _w16a16_marlin_weight_cache.get(cache_key)
if cached_marlin is None:
w1_marlin = _get_marlin_packed_weight(w1, w16a16_marlin_weight)
w2_marlin = _get_marlin_packed_weight(w2, w16a16_marlin_weight)
# Offload original layout weights from GPU to avoid double residency.
with torch.no_grad():
w1_cpu = w1.detach().to("cpu")
w2_cpu = w2.detach().to("cpu")
if hasattr(w1, "data"):
w1.data = w1_cpu # type: ignore[attr-defined]
else:
w1 = w1_cpu
if hasattr(w2, "data"):
w2.data = w2_cpu # type: ignore[attr-defined]
else:
w2 = w2_cpu
_w16a16_marlin_weight_cache[cache_key] = (w1_marlin, w2_marlin)
else:
w1_marlin, w2_marlin = cached_marlin
return fused_experts_impl_w16a16_marlin(
hidden_states=hidden_states,
w1=w1,
w2=w2,
w1_marlin=w1_marlin,
w2_marlin=w2_marlin,
topk_weights=topk_weights,
topk_ids=topk_ids,
cache13=cache13,
inplace=inplace,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
use_nn_moe=False,
routed_scaling_factor=routed_scaling_factor,
shared_output=shared_output
)
if use_int8_w8a8 is True:
if use_int8_w8a8 or use_fp8_w8a8:
return fused_experts_impl_int8(hidden_states=hidden_states,
w1=w1,
w2=w2,
......@@ -1776,8 +1815,8 @@ def fused_experts_impl(
inplace=inplace,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=False,
use_int8_w8a8=True,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=False,
use_int4_w4a16=False,
per_channel_quant=per_channel_quant,
......@@ -2115,12 +2154,17 @@ def fused_moe(
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
if use_grouped_topk:
assert num_expert_group is not None and topk_group is not None
topk_weights, topk_ids = grouped_topk(hidden_states, gating_output,
topk, renormalize,
num_expert_group, topk_group)
topk_weights, topk_ids = grouped_topk(
hidden_states = hidden_states,
gating_output = gating_output,
topk = topk,
renormalize = renormalize,
num_expert_group = num_expert_group,
topk_group = topk_group,
routed_scaling_factor = routed_scaling_factor,
num_fused_shared_experts = 1)
elif custom_routing_function is None:
topk_weights, topk_ids, token_expert_indices = fused_topk(
hidden_states, gating_output, topk, renormalize)
......
......@@ -6,7 +6,9 @@ from math import prod
from typing import Optional, final
from dataclasses import dataclass
from collections.abc import Callable
from vllm.logger import init_logger
logger = init_logger(__name__)
import torch
import vllm.envs as envs
......@@ -843,11 +845,16 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
fused_experts: CustomizedFusedMoEPermuteExpertsUnpermute,
experts_ht: CustomizedFusedMoEPermuteExpertsUnpermute = None,
experts_ll: CustomizedFusedMoEPermuteExpertsUnpermute = None,
shared_experts: Optional[torch.nn.Module] = None,
):
super().__init__()
self.prepare_finalize = prepare_finalize
self.fused_experts = fused_experts
self.fused_experts_ht = experts_ht
self.fused_experts_ll = experts_ll
self.shared_experts = shared_experts
if self.shared_experts is not None:
......@@ -919,7 +926,21 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
prepare_finalize = self.prepare_finalize
fused_experts = self.fused_experts
if envs.VLLM_ALL2ALL_BACKEND == "deepep_auto":
num_ht_ll_tokens = envs.VLLM_MOE_HT_THRESHOLD
num_tokens = hidden_states.size(0)
# logger.info("num_tokens=%d", num_tokens)
if num_tokens > num_ht_ll_tokens:
prepare_finalize = self.prepare_finalize.ht_prepare_finalize
fused_experts = self.fused_experts_ht
else:
prepare_finalize = self.prepare_finalize.ll_prepare_finalize
fused_experts = self.fused_experts_ll
a1 = hidden_states
if inplace and self.shared_experts is None:
......@@ -931,7 +952,7 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
if global_num_experts == -1:
global_num_experts = local_num_experts
prepare_ret = self.prepare_finalize.prepare_async(
prepare_ret = prepare_finalize.prepare_async(
a1,
a1_scale,
a2_scale,
......@@ -940,7 +961,7 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
global_num_experts,
expert_map,
apply_router_weight_on_input,
self.fused_experts.quant_config,
fused_experts.quant_config,
)
hook, receiver = (
prepare_ret if isinstance(prepare_ret, tuple) else (None, prepare_ret)
......@@ -971,7 +992,7 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
# and can never run into the tensor.numel() == 0 case.
fused_out = torch.empty_like(a1q).to(dtype=a1.dtype)
else:
fused_out = self.fused_experts.apply(
fused_out = fused_experts.apply(
None,
a1,
a1q,
......@@ -1008,12 +1029,12 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
self.alt_stream.wait_event(self.alt_event)
hook = None
if self.prepare_finalize.activation_format == \
if prepare_finalize.activation_format == \
FusedMoEActivationFormat.BatchedExperts:
self.prepare_finalize.finalize(output, fused_out, topk_weights,
prepare_finalize.finalize(output, fused_out, topk_weights,
topk_ids, apply_router_weight_on_input, apply_weights_and_reduce=True)
else:
hook = self.prepare_finalize.finalize_async(output, fused_out, topk_weights,
hook = prepare_finalize.finalize_async(output, fused_out, topk_weights,
topk_ids, apply_router_weight_on_input, apply_weights_and_reduce=True)
if hook is not None:
hook()
......
"""
Utilities for capturing MoE router distributions from real workloads.
This is intentionally lightweight and gated behind env vars so it has zero
runtime impact unless explicitly enabled.
Env vars (defaults from vllm.envs):
- VLLM_MOE_ROUTER_CAPTURE=0/1: enable capture (default: 0).
- VLLM_MOE_ROUTER_CAPTURE_DIR=/path: output directory for per-process dumps
(default: /tmp).
- VLLM_MOE_ROUTER_CAPTURE_RANK=N: only capture on the given torch.distributed
rank (default: -1; set to -1 to capture all ranks).
- VLLM_MOE_ROUTER_CAPTURE_MAX_LAYERS=N: max number of layers to record per
process (default: 0; 0 = unlimited).
- VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_GT=A: only record calls where router_logits
has num_tokens > A (default: -1; <0 = disabled).
- VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_LT=B: only record calls where router_logits
has num_tokens < B (default: -1; 0 = disabled).
Output format:
- A single `.pt` per captured num_tokens (and per rank if torch.distributed is
initialized).
- Payload includes `layers_by_num_tokens: dict[str, dict[layer_name, layer_state]]`.
- A convenience `layers` field is also included (same as
`layers_by_num_tokens[str(num_tokens)]`) for easy loading.
- For each captured MoE layer, stores a list of 2D tensors
`router_logits_chunks: list[Tensor[num_tokens_i, num_experts]]` on CPU,
typically in fp16 for space efficiency.
"""
from __future__ import annotations
import atexit
import inspect
import os
import socket
import threading
import time
from dataclasses import dataclass
from typing import Optional
import torch
import vllm.envs as envs
_DEFAULT_SKIP_STACK_FUNCS = ("profile_run", "_dummy_run",
"determine_available_memory")
@dataclass(frozen=True)
class RouterCaptureConfig:
enabled: bool = False
out_dir: str = "/tmp"
skip_profile: bool = True
skip_stack_funcs: tuple[str, ...] = _DEFAULT_SKIP_STACK_FUNCS
only_rank: Optional[int] = 0
max_layers: int = 0
num_tokens_gt: Optional[int] = None
num_tokens_lt: Optional[int] = None
@staticmethod
def from_env() -> "RouterCaptureConfig":
enabled = envs.VLLM_MOE_ROUTER_CAPTURE
out_dir = envs.VLLM_MOE_ROUTER_CAPTURE_DIR
skip_profile = True
skip_stack_funcs = _DEFAULT_SKIP_STACK_FUNCS
only_rank: Optional[int] = None
if envs.VLLM_MOE_ROUTER_CAPTURE_RANK >= 0:
only_rank = envs.VLLM_MOE_ROUTER_CAPTURE_RANK
max_layers = envs.VLLM_MOE_ROUTER_CAPTURE_MAX_LAYERS
num_tokens_gt_opt = (envs.VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_GT
if envs.VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_GT >= 0
else None)
num_tokens_lt_opt = (envs.VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_LT
if envs.VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_LT > 0
else None)
# Per-size mode requires an explicit token-count filter to avoid
# unbounded captures by default.
if num_tokens_gt_opt is None and num_tokens_lt_opt is None:
enabled = False
if (num_tokens_gt_opt is not None and num_tokens_lt_opt is not None
and num_tokens_gt_opt >= num_tokens_lt_opt):
enabled = False
return RouterCaptureConfig(enabled=enabled,
out_dir=out_dir,
skip_profile=skip_profile,
skip_stack_funcs=skip_stack_funcs,
only_rank=only_rank,
max_layers=max_layers,
num_tokens_gt=num_tokens_gt_opt,
num_tokens_lt=num_tokens_lt_opt)
def _in_profile_run(skip_stack_funcs: tuple[str, ...]) -> bool:
"""
Best-effort detection for vLLM startup profiling/warmup runs.
Startup warmups often execute MoE kernels with synthetic shapes. When
enabled, skip captures from these stacks so the first capture comes from a
real request.
"""
if not skip_stack_funcs:
return False
frame = inspect.currentframe()
try:
while frame is not None:
name = frame.f_code.co_name
if name in skip_stack_funcs:
return True
frame = frame.f_back
finally:
# Avoid reference cycles.
del frame
return False
class _RouterCapture:
def __init__(self, cfg: RouterCaptureConfig) -> None:
self.cfg = cfg
# Bucket captures by token count.
self._layers_by_num_tokens: dict[int, dict[str, dict[str, object]]] = {}
self._layer_names: set[str] = set()
self._completed_num_tokens: set[int] = set()
self._lock = threading.Lock()
self._flush_counter = 0
self._pid = os.getpid()
self._host = socket.gethostname()
self._start_time = time.time()
os.makedirs(cfg.out_dir, exist_ok=True)
atexit.register(self.flush)
def _bucket_for_num_tokens(self, num_tokens: int) -> Optional[int]:
"""Return the per-size bucket key for this record call, or None if filtered."""
if self.cfg.num_tokens_gt is None and self.cfg.num_tokens_lt is None:
return None
if self.cfg.num_tokens_gt is not None:
if int(num_tokens) <= int(self.cfg.num_tokens_gt):
return None
if self.cfg.num_tokens_lt is not None:
if int(num_tokens) >= int(self.cfg.num_tokens_lt):
return None
bucket_num_tokens = int(num_tokens)
if bucket_num_tokens != 0 and bucket_num_tokens in self._completed_num_tokens:
return None
return bucket_num_tokens
def _snapshot_layers_by_num_tokens(
self,
layers_by_num_tokens: dict[int, dict[str, dict[str, object]]],
) -> dict[int, dict[str, dict[str, object]]]:
snapshot: dict[int, dict[str, dict[str, object]]] = {}
for num_tokens, bucket in layers_by_num_tokens.items():
bucket_snapshot: dict[str, dict[str, object]] = {}
for layer_name, state in bucket.items():
chunks = state.get("router_logits_chunks", [])
bucket_snapshot[layer_name] = {
"num_experts": int(state.get("num_experts", 0)),
"num_tokens": int(state.get("num_tokens", 0)),
"router_logits_chunks": list(chunks),
}
snapshot[int(num_tokens)] = bucket_snapshot
return snapshot
@torch.no_grad()
def record(self, layer_name: str, router_logits: torch.Tensor,
top_k: int) -> None:
if self.cfg.skip_profile and _in_profile_run(self.cfg.skip_stack_funcs):
return
if self.cfg.only_rank is not None:
rank = _get_rank()
if rank is not None and rank != self.cfg.only_rank:
return
if router_logits.dim() != 2:
return
num_tokens, num_experts = router_logits.shape
if num_tokens == 0 or num_experts == 0:
return
bucket_num_tokens = self._bucket_for_num_tokens(int(num_tokens))
if bucket_num_tokens is None:
return
# Limit the number of recorded layers to avoid unbounded dumps.
if layer_name not in self._layer_names:
if self.cfg.max_layers != 0 and len(self._layer_names) >= self.cfg.max_layers:
return
self._layer_names.add(layer_name)
# Store on CPU to avoid consuming GPU memory during long runs.
# fp16 is typically sufficient because we primarily care about
# distribution and relative ordering (top-k), not exact values.
router_logits_cpu = router_logits.detach()
if router_logits_cpu.is_cuda:
router_logits_cpu = router_logits_cpu.to(device="cpu",
dtype=torch.float16)
else:
router_logits_cpu = router_logits_cpu.to(dtype=torch.float16)
bucket_snapshot: Optional[dict[str, dict[str, object]]] = None
should_flush = False
with self._lock:
bucket = self._layers_by_num_tokens.setdefault(bucket_num_tokens, {})
if layer_name in bucket:
return
bucket[layer_name] = {
"num_experts": int(num_experts),
"num_tokens": int(num_tokens),
"router_logits_chunks": [router_logits_cpu],
}
if self.cfg.max_layers != 0 and len(bucket) >= int(self.cfg.max_layers):
should_flush = True
bucket_snapshot = self._snapshot_layers_by_num_tokens(
{int(bucket_num_tokens): bucket})[int(bucket_num_tokens)]
self._completed_num_tokens.add(int(bucket_num_tokens))
self._layers_by_num_tokens.pop(int(bucket_num_tokens), None)
if should_flush and bucket_snapshot is not None:
self._flush_payload(
layers_by_num_tokens={int(bucket_num_tokens): bucket_snapshot},
file_tag=f"nt{int(bucket_num_tokens)}",
)
def _flush_payload(
self,
*,
layers_by_num_tokens: dict[int, dict[str, dict[str, object]]],
file_tag: Optional[str] = None,
) -> Optional[str]:
if not self.cfg.enabled:
return None
if self.cfg.only_rank is not None:
rank = _get_rank()
if rank is not None and rank != self.cfg.only_rank:
return None
rank = _get_rank()
now = time.time()
ts = time.strftime("%Y%m%d_%H%M%S", time.localtime(now))
ts_us = int(now * 1_000_000)
with self._lock:
flush_idx = self._flush_counter
self._flush_counter += 1
rank_str = f"rank{rank}" if rank is not None else "rankNA"
tag = f"{file_tag}_" if file_tag else ""
out_path = os.path.join(
self.cfg.out_dir,
f"moe_router_stats_{tag}{ts_us}_{self._host}_{rank_str}_pid{self._pid}_flush{flush_idx}.pt",
)
layers_by_num_tokens_out: dict[str, object] = {}
for num_tokens, bucket in layers_by_num_tokens.items():
bucket_out: dict[str, object] = {}
for layer_name, state in bucket.items():
bucket_out[layer_name] = {
"num_experts": int(state["num_experts"]),
"num_tokens": int(state["num_tokens"]),
"router_logits_chunks":
state["router_logits_chunks"], # type: ignore[typeddict-item]
}
layers_by_num_tokens_out[str(int(num_tokens))] = bucket_out
payload: dict[str, object] = {
"meta": {
"timestamp": ts,
"timestamp_us": ts_us,
"flush_index": int(flush_idx),
"host": self._host,
"pid": self._pid,
"rank": rank,
"wall_time_s": float(now - self._start_time),
},
"layers_by_num_tokens": layers_by_num_tokens_out,
}
# Backward-compatible convenience field when there is a single bucket.
if len(layers_by_num_tokens) == 1:
(only_bucket_key, ) = layers_by_num_tokens.keys()
payload["layers"] = layers_by_num_tokens_out[str(int(only_bucket_key))]
try:
torch.save(payload, out_path)
except Exception:
return None
return out_path
def flush(self) -> Optional[str]:
with self._lock:
if not self._layers_by_num_tokens:
return None
snapshot = self._snapshot_layers_by_num_tokens(self._layers_by_num_tokens)
return self._flush_payload(layers_by_num_tokens=snapshot)
def reset(self) -> None:
with self._lock:
self._layers_by_num_tokens.clear()
self._layer_names.clear()
self._completed_num_tokens.clear()
self._start_time = time.time()
_CAPTURE: Optional[_RouterCapture] = None
_CAPTURE_DISABLED: bool = False
def _disable_global_capture() -> None:
global _CAPTURE, _CAPTURE_DISABLED
_CAPTURE = None
_CAPTURE_DISABLED = True
def _get_rank() -> Optional[int]:
if torch.distributed.is_available() and torch.distributed.is_initialized():
try:
return torch.distributed.get_rank()
except Exception:
return None
return None
def _get_capture() -> Optional[_RouterCapture]:
global _CAPTURE, _CAPTURE_DISABLED
if _CAPTURE_DISABLED:
return None
if _CAPTURE is not None:
return _CAPTURE
cfg = RouterCaptureConfig.from_env()
if not cfg.enabled:
_disable_global_capture()
return None
if cfg.only_rank is not None:
rank = _get_rank()
if rank is not None and rank != cfg.only_rank:
_disable_global_capture()
return None
_CAPTURE = _RouterCapture(cfg)
return _CAPTURE
@torch.no_grad()
def maybe_record_router_logits(*, layer_name: str, router_logits: torch.Tensor,
top_k: int) -> None:
capture = _get_capture()
if capture is None:
return
capture.record(layer_name=layer_name, router_logits=router_logits, top_k=top_k)
def maybe_flush_router_capture(*, reset: bool = False) -> Optional[str]:
"""Flush capture buffers to disk without exiting the process."""
capture = _get_capture()
if capture is None:
return None
out_path = capture.flush()
if out_path is not None and reset:
capture.reset()
return out_path
......@@ -34,7 +34,11 @@ class SharedFusedMoE(FusedMoE):
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor:
hidden_states_copy: Optional[torch.Tensor] = None,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None,
**_
) -> tuple[torch.Tensor, torch.Tensor]|torch.Tensor:
if not self.use_overlapped:
shared_out = self._shared_experts(hidden_states)
......@@ -53,6 +57,8 @@ class SharedFusedMoE(FusedMoE):
fused_out = super().forward(
hidden_states=hidden_states,
router_logits=router_logits,
hidden_states_copy = hidden_states_copy,
i_s = i_s,
i_q = i_q,
)
return fused_out
......@@ -22,6 +22,7 @@ from vllm.utils import round_up
try:
from lmslim.layers.gemm.int8_utils import (
per_token_group_quant_int8, per_token_quant_int8)
from lightop import op
except Exception:
print("INFO: Please install lmslim if you want to use int utils.\n")
from vllm.utils import cdiv
......@@ -622,52 +623,62 @@ def ep_scatter(
num_experts = num_recv_tokens_per_expert.shape[0]
hidden_size = recv_x.shape[1]
scale_hidden_size = recv_x_scale.shape[-1]
# grid = (triton.cdiv(hidden_size, BLOCK_D), num_experts)
grid = num_experts
assert m_indices.shape[0] % BLOCK_E == 0
if hasattr(op, "ep_scatter"):
op.ep_scatter(
recv_x, recv_x_scale,
recv_topk, expert_map,
num_recv_tokens_per_expert,
output_tensor, output_tensor_scale, m_indices, output_index,
num_experts, BLOCK_E
)
else:
# grid = (triton.cdiv(hidden_size, BLOCK_D), num_experts)
grid = num_experts
_fwd_kernel_ep_scatter_1[(grid,)](
num_recv_tokens_per_expert,
expert_start_loc,
m_indices,
num_experts=num_experts,
num_warps=num_warps,
BLOCK_E=BLOCK_E,
BLOCK_EXPERT_NUM=triton.next_power_of_2(num_experts),
)
assert m_indices.shape[0] % BLOCK_E == 0
grid = min(recv_topk.shape[0], 1024 * 8)
_fwd_kernel_ep_scatter_2[(grid,)](
recv_topk.shape[0],
expert_start_loc,
recv_x,
recv_x.stride(0),
recv_x.stride(1),
recv_x_scale,
recv_x_scale.stride(0),
recv_x_scale.stride(1),
recv_topk,
recv_topk.stride(0),
recv_topk.stride(1),
output_tensor,
output_tensor.stride(0),
output_tensor.stride(1),
output_tensor_scale,
output_tensor_scale.stride(0),
output_tensor_scale.stride(1),
output_index,
output_index.stride(0),
output_index.stride(1),
topk_num=recv_topk.shape[1],
expert_map=expert_map,
HAS_EXPERT_MAP=expert_map is not None,
num_warps=num_warps,
HIDDEN_SIZE=hidden_size,
HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size),
SCALE_HIDDEN_SIZE=scale_hidden_size,#hidden_size // BLOCK_D,
SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(scale_hidden_size)#triton.next_power_of_2(hidden_size // BLOCK_D),
)
_fwd_kernel_ep_scatter_1[(grid,)](
num_recv_tokens_per_expert,
expert_start_loc,
m_indices,
num_experts=num_experts,
num_warps=num_warps,
BLOCK_E=BLOCK_E,
BLOCK_EXPERT_NUM=triton.next_power_of_2(num_experts),
)
grid = min(recv_topk.shape[0], 1024 * 8)
_fwd_kernel_ep_scatter_2[(grid,)](
recv_topk.shape[0],
expert_start_loc,
recv_x,
recv_x.stride(0),
recv_x.stride(1),
recv_x_scale,
recv_x_scale.stride(0),
recv_x_scale.stride(1),
recv_topk,
recv_topk.stride(0),
recv_topk.stride(1),
output_tensor,
output_tensor.stride(0),
output_tensor.stride(1),
output_tensor_scale,
output_tensor_scale.stride(0),
output_tensor_scale.stride(1),
output_index,
output_index.stride(0),
output_index.stride(1),
topk_num=recv_topk.shape[1],
expert_map=expert_map,
HAS_EXPERT_MAP=expert_map is not None,
num_warps=num_warps,
HIDDEN_SIZE=hidden_size,
HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size),
SCALE_HIDDEN_SIZE=scale_hidden_size,
SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(scale_hidden_size),
)
return
......@@ -897,7 +908,7 @@ class EPSharedExperts(nn.Module):
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x):
def forward(self, x, **_):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
......
......@@ -27,6 +27,7 @@ from vllm.model_executor.parameter import (BasevLLMParameter,
PackedvLLMParameter,
PerTensorScaleParameter,
RowvLLMParameter)
from vllm.model_executor.layers.dp_attention import get_moe_tp_rank, get_moe_tp_size
# yapf: enable
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
......@@ -331,7 +332,6 @@ class ReplicatedLinear(LinearBase):
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
eps: Optional[float] = 1e-6,
prefix: str = "",
*,
return_bias: bool = True,
......@@ -343,7 +343,6 @@ class ReplicatedLinear(LinearBase):
quant_config,
prefix=prefix,
return_bias=return_bias)
self.eps = eps
# All the linear layer supports quant method.
assert self.quant_method is not None
......@@ -393,44 +392,18 @@ class ReplicatedLinear(LinearBase):
def forward(
self,
input_: torch.Tensor,
rms_weight: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None,
quant_args: Optional[list] = None,
update_hd: Optional[bool] = True
iqis: Optional[tuple] = None, **_
) -> Union[torch.Tensor,
tuple[torch.Tensor, Optional[Parameter]],
tuple[torch.Tensor, torch.Tensor, Optional[Parameter], list[torch.Tensor]]]:
if envs.USE_FUSED_RMS_QUANT and (rms_weight is not None or quant_args is not None):
if quant_args is not None:
input_quant_args = quant_args
bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None
output = self.quant_method.apply(self, input_, bias, input_quant_args)
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
else:
i_q, _scales = lm_faster_rmsquant(input=input_,
rms_weight=rms_weight,
epsilon=self.eps,
quant_dtype=torch.int8,
residual=residual,
update_input=update_hd
)
new_residual = residual
input_quant_args = [i_q, _scales]
bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None
output = self.quant_method.apply(self, input_, bias, input_quant_args)
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, new_residual, output_bias, input_quant_args
tuple[torch.Tensor, Optional[Parameter]]]:
if envs.USE_FUSED_RMS_QUANT and iqis is not None:
bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None
output = self.quant_method.apply(self, input_, bias, input_quant_args=iqis)
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
else:
bias = self.bias if not self.skip_bias_add else None
......@@ -459,7 +432,6 @@ class FusedQuantedReplicatedLinear(LinearBase):
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
eps: Optional[float] = 1e-6,
prefix: str = "",
*,
return_bias: bool = True,
......@@ -473,7 +445,6 @@ class FusedQuantedReplicatedLinear(LinearBase):
quant_config,
prefix=prefix,
return_bias=return_bias)
self.eps = eps
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
self.qk_rope_head_dim = qk_rope_head_dim
......@@ -588,7 +559,6 @@ class FusedQuantedReplicatedLinear(LinearBase):
assert len(self.kv_a_weight.shape) == 2
fused_weight = torch.cat([self.q_a_weight, self.kv_a_weight], dim=0) # TN
param.data.copy_(fused_weight)
#TODO: wjl 删掉无用的显存tensor
else:
raise ValueError(f"Unexpected weight: {source}")
......@@ -596,31 +566,17 @@ class FusedQuantedReplicatedLinear(LinearBase):
def forward(
self,
input_: torch.Tensor,
rms_weight: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None,
update_hd: Optional[bool] = True
) -> Union[torch.Tensor,
tuple[torch.Tensor, torch.Tensor, Optional[Parameter], list[torch.Tensor]]]:
if envs.USE_FUSED_RMS_QUANT and rms_weight is not None:
i_q, _scales = lm_faster_rmsquant(input=input_,
rms_weight=rms_weight,
epsilon=self.eps,
quant_dtype=torch.int8,
residual=residual,
update_input=update_hd
)
new_residual = residual
input_quant_args = [i_q, _scales]
iqis: Optional[tuple] = None, **_
) -> tuple[torch.Tensor, Optional[Parameter]]:
if envs.USE_FUSED_RMS_QUANT and iqis is not None:
bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None
output = self.quant_method.apply(self, input_, bias, input_quant_args)
output = self.quant_method.apply(self, input_, bias, input_quant_args=iqis)
output_bias = self.bias if self.skip_bias_add else None
assert self.return_bias is True
if not self.return_bias:
raise RuntimeError("Not return bias. Unexpected Error.")
return output, new_residual, output_bias
return output, output_bias
else:
raise RuntimeError("Unexpected Error.")
......@@ -670,12 +626,18 @@ class ColumnParallelLinear(LinearBase):
*,
return_bias: bool = True,
expect_tp_size: Optional[int] = None,
enable_dp_attn_moe: bool = False,
):
# Divide the weight matrix along the last dimension.
self.tp_size = get_tensor_model_parallel_world_size()
if expect_tp_size is not None:
self.expect_tp_size = expect_tp_size
self.tp_size = self.expect_tp_size
self.enable_dp_attn_moe = enable_dp_attn_moe
if enable_dp_attn_moe:
self.tp_size = get_moe_tp_size()
self.input_size_per_partition = input_size
self.output_size_per_partition = divide(output_size, self.tp_size)
self.output_partition_sizes = [self.output_size_per_partition]
......@@ -858,31 +820,17 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
def forward(
self, input_,
rms_weight: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None,
update_hd: Optional[bool] = True,
xqxs: Optional[tuple] = None
xqxs: Optional[tuple] = None,
iqis: Optional[tuple] = None, **_
) -> Union[torch.Tensor,
tuple[torch.Tensor, Optional[Parameter]],
tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[Parameter]],
]:
if envs.USE_FUSED_RMS_QUANT and rms_weight is not None:
input_quant_args = None
assert residual is not None and rms_weight is not None
i_q, _scales = lm_faster_rmsquant(input=input_,
rms_weight=rms_weight,
epsilon=self.eps,
quant_dtype=torch.int8,
residual=residual,
update_input=update_hd)
new_residual = residual
input_quant_args = [i_q, _scales]
if envs.USE_FUSED_RMS_QUANT and iqis is not None:
bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None
output_parallel = self.quant_method.apply(self, input_, bias, input_quant_args)
output_parallel = self.quant_method.apply(self, input_, bias, input_quant_args=iqis)
if self.gather_output:
# All-gather across the partitions.
......@@ -892,7 +840,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, new_residual, i_q, _scales, output_bias
return output, output_bias
elif envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and xqxs is not None:
bias = self.bias if not self.skip_bias_add else None
......@@ -933,13 +881,12 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
eps: Optional[float] = 1e-6,
prefix: str = "",
*,
return_bias: bool = True,
expect_tp_size: Optional[int] = None,
enable_dp_attn_moe: bool = False,
):
self.eps = eps
self.output_sizes = output_sizes
tp_size = get_tensor_model_parallel_world_size()
......@@ -949,6 +896,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
self.expect_tp_size = expect_tp_size
self.enable_dp_attn_moe = enable_dp_attn_moe
if enable_dp_attn_moe:
tp_size = get_moe_tp_size()
assert all(output_size % tp_size == 0 for output_size in output_sizes)
super().__init__(input_size=input_size,
output_size=sum(output_sizes),
......@@ -959,7 +910,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
quant_config=quant_config,
prefix=prefix,
return_bias=return_bias,
expect_tp_size=expect_tp_size)
expect_tp_size=expect_tp_size,
enable_dp_attn_moe=enable_dp_attn_moe)
def weight_loader(self,
param: Parameter,
......@@ -1060,6 +1012,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
if self.expect_tp_size is not None and self.expect_tp_size == 1:
tp_rank = 0
tp_size = 1
if self.enable_dp_attn_moe:
tp_rank = get_moe_tp_rank()
tp_size = get_moe_tp_size()
if output_dim is not None:
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
......@@ -1182,6 +1138,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
if hasattr(param, "expect_tp_size"):
param.expect_tp_size = self.expect_tp_size
if self.enable_dp_attn_moe and hasattr(param, "enable_dp_attn_moe"):
tp_size = get_moe_tp_size()
param.enable_dp_attn_moe = self.enable_dp_attn_moe
if isinstance(param, BlockQuantScaleParameter):
from vllm.model_executor.layers.quantization.fp8 import (
Fp8LinearMethod, Fp8MoEMethod)
......@@ -1613,6 +1573,7 @@ class RowParallelLinear(LinearBase):
*,
return_bias: bool = True,
expect_tp_size: Optional[int] = None,
enable_dp_attn_moe: bool = False,
):
# Divide the weight matrix along the first dimension.
self.tp_rank = get_tensor_model_parallel_rank()
......@@ -1621,7 +1582,13 @@ class RowParallelLinear(LinearBase):
if expect_tp_size is not None:
self.tp_rank = 0
self.tp_size = 1
self.expect_tp_size = expect_tp_size
self.enable_dp_attn_moe = enable_dp_attn_moe
if enable_dp_attn_moe:
self.tp_rank = get_moe_tp_rank()
self.tp_size = get_moe_tp_size()
self.input_size_per_partition = divide(input_size, self.tp_size)
self.output_size_per_partition = output_size
self.output_partition_sizes = [output_size]
......@@ -1671,6 +1638,11 @@ class RowParallelLinear(LinearBase):
if self.expect_tp_size is not None:
tp_rank = 0
tp_size = 1
if self.enable_dp_attn_moe:
tp_rank = get_moe_tp_rank()
tp_size = get_moe_tp_size()
input_dim = getattr(param, "input_dim", None)
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
is_sharded_weight = getattr(param, "is_sharded_weight", False)
......@@ -1725,6 +1697,9 @@ class RowParallelLinear(LinearBase):
if self.expect_tp_size is not None and hasattr(param, "expect_tp_size"):
param.expect_tp_size = self.expect_tp_size
if self.enable_dp_attn_moe is not None and hasattr(param, "enable_dp_attn_moe"):
param.enable_dp_attn_moe = self.enable_dp_attn_moe
param.load_row_parallel_weight(loaded_weight=loaded_weight)
def forward(
......
......@@ -654,6 +654,12 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
shared_output: Optional[torch.Tensor] = None,
use_nn_moe=False,
use_fused_gate: Optional[bool] = False,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None,
**_,
) -> torch.Tensor:
if enable_eplb:
raise NotImplementedError(
......
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, Optional
from vllm import envs
import torch
from compressed_tensors.quantization import QuantizationStrategy
from torch.nn import Parameter
......@@ -61,6 +61,8 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
# If channelwise, scales are already lined up, so just transpose.
elif self.strategy == QuantizationStrategy.CHANNEL:
weight = layer.weight
if envs.VLLM_W8A8_BACKEND == 3:
weight = weight.t()
if current_platform.is_fp8_fnuz():
input_scale = getattr(layer, 'input_scale', None)
......@@ -140,7 +142,9 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
bias: Optional[torch.Tensor] = None,input_quant_args: Optional[list[torch.Tensor]] = None,
silu_quant_args: Optional[list[torch.Tensor]] = None, **_,
) -> torch.Tensor:
return self.fp8_linear.apply(input=x,
weight=layer.weight,
......
......@@ -20,6 +20,7 @@ from vllm.model_executor.parameter import (BasevLLMParameter,
PerTensorScaleParameter)
from vllm.utils import W8a8GetCacheJSON
from vllm import _custom_ops as ops
import vllm.envs as envs
logger = init_logger(__name__)
......@@ -31,8 +32,8 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
input_symmetric: bool):
self.strategy = strategy
self.is_static_input_scheme = is_static_input_scheme
self.tritonsingleton= W8a8GetCacheJSON()
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1'))
self.tritonsingleton = W8a8GetCacheJSON()
self.w8a8_strategy = envs.VLLM_W8A8_BACKEND
self.input_symmetric = input_symmetric
@classmethod
......
......@@ -331,8 +331,12 @@ class Fp8LinearMethod(LinearMethodBase):
weight = layer.weight.data
weight_scale_inv = layer.weight_scale_inv.data
weight = self._maybe_pad_weight(weight)
if envs.VLLM_W8A8_BACKEND == 3:
weight = weight.T.contiguous()
weight_scale_inv = weight_scale_inv.T.contiguous()
else:
weight = self._maybe_pad_weight(weight)
# Torch.compile cannot use Parameter subclasses.
layer.weight = Parameter(weight, requires_grad=False)
layer.weight_scale_inv = Parameter(weight_scale_inv,
......@@ -854,10 +858,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
use_fused_gate: Optional[bool] = False,
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,**_,
) -> torch.Tensor:
if enable_eplb:
assert expert_load_view is not None
......@@ -882,6 +887,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
use_fused_gate=use_fused_gate,
)
if self.rocm_aiter_moe_enabled:
......
......@@ -92,8 +92,8 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
def __init__(self, quantization_config: SlimQuantW4A8Int8Config):
self.quantization_config = quantization_config
self.tritonsingleton= W8a8GetCacheJSON()
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1'))
self.tritonsingleton = W8a8GetCacheJSON()
self.w8a8_strategy = envs.VLLM_W8A8_BACKEND
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
n=layer.weight.shape[0]
......
......@@ -11,7 +11,12 @@ from vllm.config import CompilationLevel, get_current_vllm_config
from vllm.platforms import current_platform
from vllm.utils import W8a8GetCacheJSON
from lmslim.layers.gemm.int8_utils import per_token_quant_int8
from lmslim.quantize import quant_ops
try:
from lmslim.layers.gemm.fp8_utils import triton_scaled_mm_fp8
from lmslim.quantize.quant_ops import hipblaslt_w8a8_channelwise_gemm
except Exception:
print("INFO: Please updata lmslim if you want to use fp8_utils.\n")
# Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
TORCH_DEVICE_IDENTITY = None
......@@ -252,6 +257,39 @@ def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor,
output = output.view(*output_shape)
return output
def hipblaslt_w8a8_channelwise_scaled_mm(
qinput: torch.Tensor,
input_2d: torch.Tensor,
weight: torch.Tensor,
out_dtype: torch.dtype,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
bias: torch.Tensor,
output_shape: list,
**kwargs
) -> torch.Tensor:
assert qinput.is_contiguous() and weight.is_contiguous()
assert qinput.shape[-1] == weight.shape[-1]
assert qinput.dtype == weight.dtype
m = qinput.shape[0]
k = qinput.shape[1]
n = weight.shape[0]
success, output = quant_ops.hipblaslt_w8a8_channelwise_gemm(
a = qinput,
b = weight,
scale_a = scale_a,
scale_b = scale_b,
m = m,
n = n,
k = k,
transpose_flag = "NT",
out_dtype = out_dtype,
bias = bias,
)
return output.view(m, n)
def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor,
weight: torch.Tensor,
......@@ -278,25 +316,27 @@ def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor,
# GEMM
# This computes C = (X * W).
# Output in fp32 to allow subsequent ops to happen in-place
output = torch._scaled_mm(qinput,
qinput = qinput.view(-1,qinput.shape[-1])
output = triton_scaled_mm_fp8(qinput,
weight,
scale_a=TORCH_DEVICE_IDENTITY,
scale_b=TORCH_DEVICE_IDENTITY,
out_dtype=torch.float32)
scale_a=scale_a,
scale_b=scale_b,
out_dtype=out_dtype,
bias=bias)
# A fix for discrepancy in scaled_mm which returns tuple
# for torch < 2.5 and a single value in torch >= 2.5
if type(output) is tuple and len(output) == 2:
output = output[0]
# Unpad (undo num_token_padding)
output = torch.narrow(output, 0, 0, input_2d.shape[0])
x_scale = torch.narrow(scale_a, 0, 0, input_2d.shape[0])
# DQ
# C = sw * sx * (X * W) + bias
output = output * x_scale * scale_b.t()
if bias is not None:
output = output + bias
return output.to(out_dtype).view(*output_shape)
# if type(output) is tuple and len(output) == 2:
# output = output[0]
# # Unpad (undo num_token_padding)
# output = torch.narrow(output, 0, 0, input_2d.shape[0])
# x_scale = torch.narrow(scale_a, 0, 0, input_2d.shape[0])
#
# # DQ
# # C = sw * sx * (X * W) + bias
# output = output * x_scale * scale_b.t()
# if bias is not None:
# output = output + bias
return output.view(*output_shape)
def dispatch_w8a8_scaled_mm(
......@@ -310,6 +350,8 @@ def dispatch_w8a8_scaled_mm(
if current_platform.is_rocm():
return rocm_per_tensor_w8a8_scaled_mm
return torch_per_tensor_w8a8_scaled_mm
if envs.VLLM_W8A8_BACKEND == 3:
return hipblaslt_w8a8_channelwise_scaled_mm
# torch.scaled_mm supports per tensor weights + activations only
# so fallback to naive if per channel or per token
if (use_per_token_if_dynamic and not per_tensor_weights
......
......@@ -232,6 +232,11 @@ def get_model_architecture(
'ChatGLMModel', 'Glm4ForCausalLM', 'ChatGLMForConditionalGeneration', 'BaichuanForCausalLM', 'BloomForCausalLM', 'TeleChat2ForCausalLM', 'MixtralForCausalLM', 'FalconForCausalLM',
'MedusaModel', 'MLPSpeculatorPreTrainedModel', 'DeepseekV2ForCausalLM', 'DeepseekV3ForCausalLM', 'DeepSeekMTPModel']
if any(arch in architectures for arch in support_nn_architectures):
#针对使用dtype为fp16的情况的量化默认关闭"VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD"
if model_config.quantization in {"awq", "awq_marlin", "moe_wna16"}:
if not envs.is_set("VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD"):
os.environ['VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD'] = '0'
if not envs.VLLM_USE_NN:
if os.getenv('LLAMA_NN') != '0':
if (architectures == ['QWenLMHeadModel'] or architectures == ['ChatGLMModel'] ) and visions != []:
......@@ -247,7 +252,8 @@ def get_model_architecture(
if architectures in [['DeepseekV3ForCausalLM'], ['DeepSeekMTPModel']]:
if not envs.is_set("VLLM_USE_LIGHTOP"):
os.environ['VLLM_USE_LIGHTOP'] = '1'
if not envs.is_set("VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD"):
if not envs.is_set("VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD") \
and not envs.VLLM_ENABLE_SHARED_EXPERTS_FUSION:
os.environ['VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD'] = '1'
if not envs.is_set("VLLM_USE_OPT_CAT"):
os.environ['VLLM_USE_OPT_CAT'] = '1'
......@@ -255,8 +261,12 @@ def get_model_architecture(
os.environ['VLLM_USE_CAT_MLA'] = '1'
if not envs.is_set("VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"):
os.environ['VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'] = '1'
# if not envs.is_set("VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA"):
# os.environ['VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA'] = '1'
if not envs.is_set("VLLM_SCHED_ENABLE_MINIMAL_INJECTION"):
os.environ['VLLM_SCHED_ENABLE_MINIMAL_INJECTION'] = '1'
if not envs.is_set("VLLM_USE_FUSED_FILL_RMS_CAT"):
os.environ['VLLM_USE_FUSED_FILL_RMS_CAT'] = '1'
if model_config.quantization in {"slimquant_w4a8", "slimquant_w4a8_marlin", "slimquant_compressed_tensors_marlin", "compressed-tensors"}:
if not envs.is_set("USE_FUSED_RMS_QUANT"):
os.environ['USE_FUSED_RMS_QUANT'] = '1'
......@@ -278,6 +288,8 @@ def get_model_architecture(
os.environ['VLLM_USE_FUSE_SILU_AND_MUL'] = '1'
if not envs.is_set("VLLM_USE_OPT_RESHAPE_AND_CACHE"):
os.environ['VLLM_USE_OPT_RESHAPE_AND_CACHE'] = '1'
if not envs.is_set("VLLM_USE_FUSED_RMS_ROPE"):
os.environ['VLLM_USE_FUSED_RMS_ROPE'] = '1'
if architectures in [['DeepseekV32ForCausalLM']]:
if not envs.is_set("VLLM_USE_V32_ENCODE"):
......@@ -287,10 +299,11 @@ def get_model_architecture(
if os.getenv('FA_PAD') != '1':
os.environ['FA_PAD'] = '0'
else:
if architectures in [['DeepseekV3ForCausalLM'], ['DeepSeekMTPModel']]:
if architectures in [['DeepseekV3ForCausalLM'], ['DeepSeekMTPModel']]:
if not envs.is_set("VLLM_USE_LIGHTOP"):
os.environ['VLLM_USE_LIGHTOP'] = '1'
if not envs.is_set("VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD"):
if not envs.is_set("VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD") \
and not envs.VLLM_ENABLE_SHARED_EXPERTS_FUSION:
os.environ['VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD'] = '1'
if not envs.is_set("VLLM_USE_OPT_CAT"):
os.environ['VLLM_USE_OPT_CAT'] = '1'
......@@ -298,8 +311,12 @@ def get_model_architecture(
os.environ['VLLM_USE_CAT_MLA'] = '1'
if not envs.is_set("VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"):
os.environ['VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'] = '1'
# if not envs.is_set("VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA"):
# os.environ['VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA'] = '1'
if not envs.is_set("VLLM_SCHED_ENABLE_MINIMAL_INJECTION"):
os.environ['VLLM_SCHED_ENABLE_MINIMAL_INJECTION'] = '1'
if not envs.is_set("VLLM_USE_FUSED_FILL_RMS_CAT"):
os.environ['VLLM_USE_FUSED_FILL_RMS_CAT'] = '1'
if model_config.quantization in {"slimquant_w4a8", "slimquant_w4a8_marlin", "slimquant_compressed_tensors_marlin", "compressed-tensors"}:
if not envs.is_set("USE_FUSED_RMS_QUANT"):
os.environ['USE_FUSED_RMS_QUANT'] = '1'
......@@ -321,6 +338,8 @@ def get_model_architecture(
os.environ['VLLM_USE_FUSE_SILU_AND_MUL'] = '1'
if not envs.is_set("VLLM_USE_OPT_RESHAPE_AND_CACHE"):
os.environ['VLLM_USE_OPT_RESHAPE_AND_CACHE'] = '1'
if not envs.is_set("VLLM_USE_FUSED_RMS_ROPE"):
os.environ['VLLM_USE_FUSED_RMS_ROPE'] = '1'
if architectures in [['DeepseekV32ForCausalLM']]:
if not envs.is_set("VLLM_USE_V32_ENCODE"):
......@@ -353,7 +372,6 @@ def get_model_architecture(
mixtral_supported = [
"fp8", "compressed-tensors", "gptq_marlin", "awq_marlin", "quark"
]
vllm_supported_archs = ModelRegistry.get_supported_archs()
vllm_not_supported = not any(arch in vllm_supported_archs
for arch in architectures)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment