Commit 02689420 authored by xuxz's avatar xuxz
Browse files

Merge branch 'v0.9.2-dev' into 'v0.9.2-dev-add_connector'

# Conflicts:
#   vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py
parents ef362942 fa683b07
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"2": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"num_ldmatrixes": 1
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"num_ldmatrixes": 1
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"num_ldmatrixes": 1
},
"48": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 16,
"num_warps": 2,
"num_stages": 2,
"num_ldmatrixes": 1
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"num_ldmatrixes": 1
},
"256": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"512": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"2": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"512": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
}
}
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.forward_context import get_forward_context
class DeepEPAutoPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
"""
Auto Prepare/Finalize that wraps both DeepEP High-Throughput and
Low-Latency implementations and selects one based on prefill/decode phase.
"""
def __init__(self,
ht_prepare_finalize: mk.FusedMoEPrepareAndFinalize,
ll_prepare_finalize: mk.FusedMoEPrepareAndFinalize):
super().__init__()
self.ht_prepare_finalize = ht_prepare_finalize
self.ll_prepare_finalize = ll_prepare_finalize
self._current_phase = "decode" # default to decode (LL)
def _get_current_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize:
"""Get the appropriate prepare_finalize based on current phase."""
# Try to infer phase from forward_context if available:
# try:
# forward_context = get_forward_context()
# attn_metadata = forward_context.attn_metadata
# # Handle both v0 (single AttentionMetadata) and v1 (dict) formats
# if isinstance(attn_metadata, dict):
# if attn_metadata:
# attn_metadata = next(iter(attn_metadata.values()))
# else:
# attn_metadata = None
# if attn_metadata is not None and hasattr(attn_metadata,
# "num_decode_tokens"):
# # 只根据 decode tokens 判定:有 decode -> decode,否则 prefill
# self._current_phase = ("decode"
# if attn_metadata.num_decode_tokens > 0
# else "prefill")
# except Exception:
# # If forward_context is not available, use stored phase
# pass
# Prefill uses HT, decode uses LL
if self._current_phase == "prefill":
#rint("************prefill***********")
return self.ll_prepare_finalize
else:
# print("attn_metadata.num_decode_tokens",attn_metadata.num_decode_tokens)
return self.ht_prepare_finalize
#return self.ht_prepare_finalize
@property
def activation_format(self) -> mk.FusedMoEActivationFormat:
pf = self._get_current_prepare_finalize()
try:
return pf.activation_format
except NotImplementedError:
# Fallback to standard format if underlying impl does not provide it.
return mk.FusedMoEActivationFormat.Standard
def topk_indices_dtype(self) -> Optional[torch.dtype]:
pf = self._get_current_prepare_finalize()
return pf.topk_indices_dtype()
def max_num_tokens_per_rank(self) -> Optional[int]:
pf = self._get_current_prepare_finalize()
return pf.max_num_tokens_per_rank()
def num_dispatchers(self) -> int:
pf = self._get_current_prepare_finalize()
return pf.num_dispatchers()
def prepare_async(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
):
pf = self._get_current_prepare_finalize()
return pf.prepare_async(
a1, a1_scale, a2_scale, topk_weights, topk_ids,
num_experts, expert_map, apply_router_weight_on_input, quant_config)
def prepare(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
Optional[torch.Tensor], Optional[torch.Tensor]]:
pf = self._get_current_prepare_finalize()
return pf.prepare(
a1, a1_scale, a2_scale, topk_weights, topk_ids,
num_experts, expert_map, apply_router_weight_on_input, quant_config)
def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
apply_weights_and_reduce: bool = True
) -> None:
pf = self._get_current_prepare_finalize()
return pf.finalize(
output, fused_expert_output, topk_weights, topk_ids,
apply_router_weight_on_input, apply_weights_and_reduce)
def finalize_async(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
apply_weights_and_reduce: bool = True
):
pf = self._get_current_prepare_finalize()
if hasattr(pf, "finalize_async"):
return pf.finalize_async(
output, fused_expert_output, topk_weights, topk_ids,
apply_router_weight_on_input, apply_weights_and_reduce)
return pf.finalize(
output, fused_expert_output, topk_weights, topk_ids,
apply_router_weight_on_input, apply_weights_and_reduce)
...@@ -214,8 +214,6 @@ def moe_align_block_size_lightop( ...@@ -214,8 +214,6 @@ def moe_align_block_size_lightop(
def fused_experts_impl_w16a16_marlin(hidden_states: torch.Tensor, def fused_experts_impl_w16a16_marlin(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_marlin: torch.Tensor, w1_marlin: torch.Tensor,
w2_marlin: torch.Tensor, w2_marlin: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
...@@ -234,21 +232,34 @@ def fused_experts_impl_w16a16_marlin(hidden_states: torch.Tensor, ...@@ -234,21 +232,34 @@ def fused_experts_impl_w16a16_marlin(hidden_states: torch.Tensor,
): ):
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous" assert w1_marlin.is_contiguous(), "Packed weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous" assert w2_marlin.is_contiguous(), "Packed weights2 must be contiguous"
# 当前只支持 bf16 fp16 # 当前只支持 bf16 fp16
assert hidden_states.dtype in [torch.bfloat16,torch.float16] assert hidden_states.dtype in [torch.bfloat16,torch.float16]
compute_type = hidden_states.dtype compute_type = hidden_states.dtype
assert use_lightop, ( assert use_lightop, (
"only BW and set LMSLIM_USE_LIGHTOP=1 support Marlin W16A16 MoE") "only BW and set VLLM_USE_LIGHTOP=1 support Marlin W16A16 MoE")
num_tokens, K = hidden_states.shape num_tokens, K = hidden_states.shape
E, twoN, K_w1 = w1.shape
# Packed weights store the same number of elements as the original layout,
# but reshaped/reordered for Marlin kernels:
# - w1_marlin: [E, K/16, (2N)*16]
# - w2_marlin: [E, N/16, K*16]
E, k_div16, twoN_times16 = w1_marlin.shape
K_w1 = k_div16 * 16
assert K_w1 == K, f"w1_marlin K mismatch: {K_w1} vs {K}"
assert twoN_times16 % 16 == 0
twoN = twoN_times16 // 16
assert twoN % 2 == 0
N = twoN // 2 N = twoN // 2
E2, K_w2, N2_w2 = w2.shape E2, n_div16, k_times16 = w2_marlin.shape
assert E2 == E, f"w2_marlin E mismatch: {E2} vs {E}"
K_w2 = k_times16 // 16
assert K_w2 == K, f"w2_marlin K mismatch: {K_w2} vs {K}"
assert n_div16 * 16 == N, f"w2_marlin N mismatch: {n_div16 * 16} vs {N}"
if global_num_experts == -1: if global_num_experts == -1:
global_num_experts = E global_num_experts = E
......
...@@ -5,7 +5,7 @@ import functools ...@@ -5,7 +5,7 @@ import functools
import json import json
import os import os
import math import math
from typing import Any, Callable, Dict, Optional, List, Optional, Tuple from typing import Any, Callable, Dict, List, Optional
import torch import torch
...@@ -13,6 +13,7 @@ import vllm.envs as envs ...@@ -13,6 +13,7 @@ import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
logger = init_logger(__name__)
# yapf: disable # yapf: disable
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, get_config_quant_dtype) FusedMoEQuantConfig, get_config_quant_dtype)
...@@ -31,7 +32,7 @@ try: ...@@ -31,7 +32,7 @@ try:
from lmslim.layers.fused_moe.fuse_moe_int8 import (fused_experts_impl_int8, get_w8a8moe_json) from lmslim.layers.fused_moe.fuse_moe_int8 import (fused_experts_impl_int8, get_w8a8moe_json)
from lmslim.layers.fused_moe.fuse_moe_w4a8 import fused_experts_impl_w4a8 from lmslim.layers.fused_moe.fuse_moe_w4a8 import fused_experts_impl_w4a8
except Exception: except Exception:
print("INFO: Please install lmslim if you want to infer the quantitative model of moe.\n") logger.warning_once("Please install lmslim if you want to infer the quantitative model of moe.")
from vllm.model_executor.layers.fused_moe.prepare_finalize import ( from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP) MoEPrepareAndFinalizeNoEP)
...@@ -43,32 +44,9 @@ from vllm.utils import direct_register_custom_op ...@@ -43,32 +44,9 @@ from vllm.utils import direct_register_custom_op
# from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled # from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
logger = init_logger(__name__)
if envs.VLLM_USE_GLOBAL_CACHE13: if envs.VLLM_USE_GLOBAL_CACHE13:
moe_cache_singleton = None moe_cache_singleton = None
# Cache Marlin-packed weights so we only reorder once per weight tensor.
_marlin_weight_cache: Dict[Tuple[int, torch.device, torch.dtype, torch.Size], torch.Tensor] = {}
# Cache packed W16A16 Marlin weights by parameter identity so we can offload
# original layouts from GPU without losing the packed copies.
_w16a16_marlin_weight_cache: Dict[int, Tuple[torch.Tensor, torch.Tensor]] = {}
def _get_marlin_packed_weight(weight: torch.Tensor,
pack_fn: Callable[[torch.Tensor], torch.Tensor]
) -> torch.Tensor:
key = (weight.data_ptr(), weight.device, weight.dtype, weight.shape)
cached = _marlin_weight_cache.get(key)
if cached is not None:
return cached
# Marlin packing is done per expert and reshaped back to original dims.
packed = torch.stack([pack_fn(weight[i]).contiguous()
for i in range(weight.shape[0])],
dim=0)
_marlin_weight_cache[key] = packed
return packed
arch_name = torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] arch_name = torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0]
arch_cu = torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count arch_cu = torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count
...@@ -1247,14 +1225,14 @@ def vllm_topk_softmax(topk_weights: torch.Tensor, topk_indices: torch.Tensor, ...@@ -1247,14 +1225,14 @@ def vllm_topk_softmax(topk_weights: torch.Tensor, topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor, token_expert_indices: torch.Tensor,
gating_output: torch.Tensor, gating_output: torch.Tensor,
renormalize: bool) -> tuple[torch.Tensor, ...]: renormalize: bool) -> tuple[torch.Tensor, ...]:
if envs.VLLM_USE_TOPK_RENORM: if envs.VLLM_USE_TOPK_RENORM and renormalize is True:
from lightop import op as op from lightop import op as op
op.topk_softmax( op.topk_softmax(
topk_weights, topk_weights,
topk_indices, topk_indices,
token_expert_indices, token_expert_indices,
gating_output, gating_output,
True, renormalize,
) )
else: else:
ops.topk_softmax( ops.topk_softmax(
...@@ -1326,7 +1304,9 @@ def grouped_topk( ...@@ -1326,7 +1304,9 @@ def grouped_topk(
num_expert_group: int = 0, num_expert_group: int = 0,
topk_group: int = 0, topk_group: int = 0,
scoring_func: str = "softmax", scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None e_score_correction_bias: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = 1.0,
num_fused_shared_experts: Optional[int] = 0
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
assert hidden_states.size(0) == gating_output.size(0), ( assert hidden_states.size(0) == gating_output.size(0), (
...@@ -1339,7 +1319,7 @@ def grouped_topk( ...@@ -1339,7 +1319,7 @@ def grouped_topk(
else: else:
raise ValueError(f"Unsupported scoring function: {scoring_func}") raise ValueError(f"Unsupported scoring function: {scoring_func}")
num_token = scores.size(0) num_token, num_experts = scores.shape
if e_score_correction_bias is not None: if e_score_correction_bias is not None:
# Store original scores before applying correction bias. We use biased # Store original scores before applying correction bias. We use biased
# scores for expert selection but original scores for routing weights # scores for expert selection but original scores for routing weights
...@@ -1370,8 +1350,20 @@ def grouped_topk( ...@@ -1370,8 +1350,20 @@ def grouped_topk(
dim=-1, dim=-1,
sorted=False) sorted=False)
if num_fused_shared_experts != 0:
topk_ids[:, -1] = num_experts
if routed_scaling_factor is not None:
topk_weights[:, -1] = (
topk_weights[:, :-1].sum(dim=-1) / routed_scaling_factor
)
if renormalize: if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) topk_weights_sum = (
topk_weights.sum(dim=-1, keepdim=True)
if num_fused_shared_experts == 0
else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
)
topk_weights = topk_weights / topk_weights_sum
return topk_weights.to(torch.float32), topk_ids.to(torch.int32) return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
...@@ -1696,6 +1688,102 @@ def fused_experts_impl( ...@@ -1696,6 +1688,102 @@ def fused_experts_impl(
i_s: Optional[torch.Tensor] = None, **_ i_s: Optional[torch.Tensor] = None, **_
) -> torch.Tensor: ) -> torch.Tensor:
num_tokens = hidden_states.size(0) num_tokens = hidden_states.size(0)
top_k_num = topk_ids.size(1)
# We execute the fused_moe kernel in chunks to circumvent this issue:
# https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
M = min(num_tokens, CHUNK_SIZE)
# Optional fast path: use Marlin W16A16 fused MoE implementation when the
# expert weights are already packed in Marlin layout.
if not use_nn_moe:
K = hidden_states.size(1)
def _is_marlin_w16a16_packed(w1: torch.Tensor,
w2: torch.Tensor) -> bool:
if w1.dim() != 3 or w2.dim() != 3:
return False
if w1.size(0) != w2.size(0):
return False
k_div16 = w1.size(1)
if k_div16 * 16 != K:
return False
if w1.size(2) % 16 != 0:
return False
twoN = w1.size(2) // 16
if twoN % 2 != 0:
return False
N = twoN // 2
if w2.size(2) != K * 16:
return False
if w2.size(1) * 16 != N:
return False
return True
is_packed = (getattr(w1, "marlin_w16a16_packed", False)
or getattr(w2, "marlin_w16a16_packed", False)
or _is_marlin_w16a16_packed(w1, w2))
if is_packed:
if envs.VLLM_USE_MOE_W16A16_TRITON:
raise RuntimeError(
"VLLM_USE_MOE_W16A16_TRITON=1 forces Triton MoE, but the MoE weights are "
"packed in Marlin W16A16 layout. Please load unpacked weights or set "
"VLLM_USE_MOE_W16A16_TRITON=0.")
try:
from vllm.model_executor.layers.fused_moe.fuse_moe_w16a16_marlin import ( # noqa: E501
fused_experts_impl_w16a16_marlin)
except Exception:
fused_experts_impl_w16a16_marlin = None # type: ignore
if fused_experts_impl_w16a16_marlin is None:
raise RuntimeError(
"Marlin W16A16 MoE weights are packed, but the Marlin kernel is unavailable. "
"Ensure lightop is installed and VLLM_USE_LIGHTOP=1."
)
if activation != "silu":
raise RuntimeError(
"Marlin W16A16 MoE only supports activation='silu'.")
if apply_router_weight_on_input:
raise RuntimeError(
"Marlin W16A16 MoE does not support apply_router_weight_on_input=True."
)
E = w1.size(0)
if global_num_experts == -1:
global_num_experts = E
twoN = w1.size(2) // 16
if envs.VLLM_USE_GLOBAL_CACHE13:
cache13 = get_moe_cache(top_k_num,
twoN,
K,
device=hidden_states.device,
dtype=hidden_states.dtype)
else:
cache13 = torch.empty(M * top_k_num * max(twoN, K),
device=hidden_states.device,
dtype=hidden_states.dtype)
return fused_experts_impl_w16a16_marlin(
hidden_states=hidden_states,
w1_marlin=w1,
w2_marlin=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
cache13=cache13,
inplace=inplace,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
use_nn_moe=False,
routed_scaling_factor=routed_scaling_factor,
shared_output=shared_output,
)
# Non-Marlin paths need the original weight shapes.
if use_nn_moe: if use_nn_moe:
E, _, N = w1.size() E, _, N = w1.size()
else: else:
...@@ -1704,69 +1792,20 @@ def fused_experts_impl( ...@@ -1704,69 +1792,20 @@ def fused_experts_impl(
if global_num_experts == -1: if global_num_experts == -1:
global_num_experts = E global_num_experts = E
top_k_num = topk_ids.size(1)
# We execute the fused_moe kernel in chunks to circumvent this issue:
# https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
M = min(num_tokens, CHUNK_SIZE)
if envs.VLLM_USE_GLOBAL_CACHE13: if envs.VLLM_USE_GLOBAL_CACHE13:
cache13 = get_moe_cache(top_k_num, N,K if not use_nn_moe else w2.shape[2], device=hidden_states.device, dtype=hidden_states.dtype) cache13 = get_moe_cache(top_k_num,
N,
K if not use_nn_moe else w2.shape[2],
device=hidden_states.device,
dtype=hidden_states.dtype)
else: else:
cache13 = torch.empty(M * top_k_num * max(N, K if not use_nn_moe else w2.shape[2]), device=hidden_states.device, dtype=hidden_states.dtype) cache13 = torch.empty(
M * top_k_num * max(N, K if not use_nn_moe else w2.shape[2]),
device=hidden_states.device,
dtype=hidden_states.dtype)
# Optional fast path: use lmslim's Marlin W16A16 fused MoE implementation if use_int8_w8a8 or use_fp8_w8a8:
# when explicitly requested. This reuses the same cache13 buffer as other
# fused paths for consistency.
from vllm.model_executor.layers.fused_moe.fuse_moe_w16a16_marlin import fused_experts_impl_w16a16_marlin
if (envs.VLLM_USE_MARLIN_W16A16_MOE
and fused_experts_impl_w16a16_marlin is not None):
# Only pack when shapes match the expected [E, 2N, K] / [E, K, N/2] contract.
# If shapes are unexpected, skip packing and fall back to non-Marlin paths below.
from vllm.model_executor.layers.fused_moe.marlin_quant import w16a16_marlin_weight
cache_key = id(w1)
cached_marlin = _w16a16_marlin_weight_cache.get(cache_key)
if cached_marlin is None:
w1_marlin = _get_marlin_packed_weight(w1, w16a16_marlin_weight)
w2_marlin = _get_marlin_packed_weight(w2, w16a16_marlin_weight)
# Offload original layout weights from GPU to avoid double residency.
with torch.no_grad():
w1_cpu = w1.detach().to("cpu")
w2_cpu = w2.detach().to("cpu")
if hasattr(w1, "data"):
w1.data = w1_cpu # type: ignore[attr-defined]
else:
w1 = w1_cpu
if hasattr(w2, "data"):
w2.data = w2_cpu # type: ignore[attr-defined]
else:
w2 = w2_cpu
_w16a16_marlin_weight_cache[cache_key] = (w1_marlin, w2_marlin)
else:
w1_marlin, w2_marlin = cached_marlin
return fused_experts_impl_w16a16_marlin(
hidden_states=hidden_states,
w1=w1,
w2=w2,
w1_marlin=w1_marlin,
w2_marlin=w2_marlin,
topk_weights=topk_weights,
topk_ids=topk_ids,
cache13=cache13,
inplace=inplace,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
use_nn_moe=False,
routed_scaling_factor=routed_scaling_factor,
shared_output=shared_output
)
if use_int8_w8a8 is True:
return fused_experts_impl_int8(hidden_states=hidden_states, return fused_experts_impl_int8(hidden_states=hidden_states,
w1=w1, w1=w1,
w2=w2, w2=w2,
...@@ -1776,8 +1815,8 @@ def fused_experts_impl( ...@@ -1776,8 +1815,8 @@ def fused_experts_impl(
inplace=inplace, inplace=inplace,
activation=activation, activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=False, use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=True, use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=False, use_int8_w8a16=False,
use_int4_w4a16=False, use_int4_w4a16=False,
per_channel_quant=per_channel_quant, per_channel_quant=per_channel_quant,
...@@ -2115,12 +2154,17 @@ def fused_moe( ...@@ -2115,12 +2154,17 @@ def fused_moe(
Returns: Returns:
- torch.Tensor: The output tensor after applying the MoE layer. - torch.Tensor: The output tensor after applying the MoE layer.
""" """
if use_grouped_topk: if use_grouped_topk:
assert num_expert_group is not None and topk_group is not None assert num_expert_group is not None and topk_group is not None
topk_weights, topk_ids = grouped_topk(hidden_states, gating_output, topk_weights, topk_ids = grouped_topk(
topk, renormalize, hidden_states = hidden_states,
num_expert_group, topk_group) gating_output = gating_output,
topk = topk,
renormalize = renormalize,
num_expert_group = num_expert_group,
topk_group = topk_group,
routed_scaling_factor = routed_scaling_factor,
num_fused_shared_experts = 1)
elif custom_routing_function is None: elif custom_routing_function is None:
topk_weights, topk_ids, token_expert_indices = fused_topk( topk_weights, topk_ids, token_expert_indices = fused_topk(
hidden_states, gating_output, topk, renormalize) hidden_states, gating_output, topk, renormalize)
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import os import os
import importlib import importlib
...@@ -28,8 +29,8 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -28,8 +29,8 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEParallelConfig) FusedMoEConfig, FusedMoEParallelConfig)
# yapf: enable # yapf: enable
from vllm.model_executor.layers.fused_moe.modular_kernel import ( from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEActivationFormat, FusedMoEModularKernel, FusedMoEActivationFormat, FusedMoEModularKernel,
DeepGemmDisabledFusedMoEModularKernel, FusedMoEPermuteExpertsUnpermute, DeepGemmDisabledFusedMoEModularKernel, FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize) FusedMoEPrepareAndFinalize)
# from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
# is_rocm_aiter_moe_enabled) # is_rocm_aiter_moe_enabled)
...@@ -37,6 +38,7 @@ from vllm.model_executor.layers.quantization.base_config import ( ...@@ -37,6 +38,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.fused_moe.fused_moe import ( from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, grouped_topk, is_power_of_two) fused_topk, grouped_topk, is_power_of_two)
from vllm.model_executor.layers.dp_attention import get_moe_tp_rank
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum from vllm.platforms.interface import CpuArchEnum
...@@ -55,6 +57,7 @@ if current_platform.is_cuda_alike(): ...@@ -55,6 +57,7 @@ if current_platform.is_cuda_alike():
from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize
from .deepep_ll_prepare_finalize import (DEEPEP_QUANT_BLOCK_SHAPE, from .deepep_ll_prepare_finalize import (DEEPEP_QUANT_BLOCK_SHAPE,
DeepEPLLPrepareAndFinalize) DeepEPLLPrepareAndFinalize)
from .deepep_auto_prepare_finalize import DeepEPAutoPrepareAndFinalize
else: else:
fused_experts = None # type: ignore fused_experts = None # type: ignore
FusedMoEPermuteExpertsUnpermute = None # type: ignore FusedMoEPermuteExpertsUnpermute = None # type: ignore
...@@ -74,6 +77,84 @@ else: ...@@ -74,6 +77,84 @@ else:
logger = init_logger(__name__) logger = init_logger(__name__)
_MARLIN_W16A16_MOE_PROBE_BATCH_SIZES: tuple[int, ...] = (1, 128)
@functools.lru_cache
def _is_marlin_w16a16_moe_supported(
E: int,
N: int,
K: int,
top_k: int,
dtype: torch.dtype,
) -> bool:
"""Return True if lightop reports Marlin W16A16 MoE is supported.
This is a best-effort probe used to decide whether we can safely pre-pack
weights into Marlin layout (which would otherwise prevent fallback).
"""
if not (current_platform.is_cuda_alike() and torch.cuda.is_available()):
return False
if dtype not in (torch.float16, torch.bfloat16):
return False
if K % 32 != 0 or N % 16 != 0:
return False
if E <= 0 or N <= 0 or K <= 0 or top_k <= 0:
return False
try:
from lightop import get_moe_cuda_marlin_config_w16a16
props = torch.cuda.get_device_properties(torch.cuda.current_device())
arch_name = getattr(props, "gcnArchName", None)
if isinstance(arch_name, str) and arch_name:
arch_name = arch_name.split(":")[0]
else:
arch_name = getattr(props, "name", None)
if not isinstance(arch_name, str) or not arch_name:
return False
arch_cu = props.multi_processor_count
twoN = 2 * N
for bs in _MARLIN_W16A16_MOE_PROBE_BATCH_SIZES:
_, _, status = get_moe_cuda_marlin_config_w16a16(
E,
bs,
twoN,
K,
K,
N,
top_k,
arch_name,
arch_cu,
dtype,
)
if not status:
return False
return True
except Exception:
return False
# Global auxilary stream for running operations in background streams.
# We have single global auxilary stream to avoid an explosion of streams
# for every layer (and make profiling look sane).
#
# aux_stream() is currently used for:
# - MoE shared_expert overlap with router
_aux_stream: torch.cuda.Stream | None = None
def aux_stream() -> torch.cuda.Stream | None:
"""
Ensures aux_stream is initialized only once
"""
global _aux_stream
from vllm.platforms import current_platform
if _aux_stream is None and current_platform.is_cuda_alike():
_aux_stream = torch.cuda.Stream()
return _aux_stream
class FusedMoeWeightScaleSupported(Enum): class FusedMoeWeightScaleSupported(Enum):
TENSOR = "tensor" TENSOR = "tensor"
...@@ -140,6 +221,62 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -140,6 +221,62 @@ class FusedMoEMethodBase(QuantizeMethodBase):
num_local_experts=moe.num_local_experts, num_local_experts=moe.num_local_experts,
num_dispatchers=num_dispatchers, num_dispatchers=num_dispatchers,
) )
elif moe.use_deepep_auto_kernels:
# Initialize both HT and LL prepare_finalize but reuse the single
# LL handle for both (sglang-style single handle)
assert moe.dp_size == all2all_manager.dp_world_size
ll_all_to_all_args = dict(
max_num_tokens_per_dp_rank=moe.max_num_tokens,
token_hidden_size=moe.hidden_dim,
num_ep_ranks=all2all_manager.world_size,
num_global_experts=moe.num_experts,
num_local_experts=moe.num_experts //
all2all_manager.world_size,
)
ll_handle = all2all_manager.get_handle(ll_all_to_all_args)
# HT prepare/finalize built on the same LL handle per request
ht_prepare_finalize = DeepEPHTPrepareAndFinalize(
ll_handle,
num_dispatchers=all2all_manager.world_size,
dp_size=all2all_manager.dp_world_size,
rank_expert_offset=all2all_manager.rank *
moe.num_local_experts,
)
use_fp8_dispatch = (moe.quant_config is not None
and moe.quant_config.quant_dtype
== current_platform.fp8_dtype()
and moe.quant_config.block_shape
== DEEPEP_QUANT_BLOCK_SHAPE)
use_int8_dispatch = moe.quant_config.quant_dtype == torch.int8 and envs.VLLM_ENABLE_DEEPEP_HT_DEEPGEMM
ll_prepare_finalize = DeepEPLLPrepareAndFinalize(
ll_handle,
max_tokens_per_rank=moe.max_num_tokens,
num_dispatchers=all2all_manager.world_size,
use_fp8_dispatch=use_fp8_dispatch,
use_int8_dispatch=use_int8_dispatch,
)
prepare_finalize = DeepEPAutoPrepareAndFinalize(
ht_prepare_finalize, ll_prepare_finalize)
experts_ht = self.select_gemm_impl(ht_prepare_finalize, moe)
experts_ll = self.select_gemm_impl(ll_prepare_finalize, moe)
self.topk_indices_dtype = ll_prepare_finalize.topk_indices_dtype()
self.fused_experts = DeepGemmDisabledFusedMoEModularKernel(
prepare_finalize,
experts_ll,
experts_ht=experts_ht,
experts_ll=experts_ll,
shared_experts=layer.shared_experts if hasattr(layer, "shared_experts") else None,
)
return
elif moe.use_deepep_ht_kernels: elif moe.use_deepep_ht_kernels:
assert moe.dp_size == all2all_manager.dp_world_size assert moe.dp_size == all2all_manager.dp_world_size
...@@ -170,8 +307,8 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -170,8 +307,8 @@ class FusedMoEMethodBase(QuantizeMethodBase):
== current_platform.fp8_dtype() == current_platform.fp8_dtype()
and moe.quant_config.block_shape and moe.quant_config.block_shape
== DEEPEP_QUANT_BLOCK_SHAPE) == DEEPEP_QUANT_BLOCK_SHAPE)
use_int8_dispatch = moe.quant_config.quant_dtype == torch.int8 use_int8_dispatch = moe.quant_config.quant_dtype == torch.int8 and envs.VLLM_ENABLE_DEEPEP_HT_DEEPGEMM
# Note (varun): Whether to use FP8 dispatch or not needs some # Note (varun): Whether to use FP8 dispatch or not needs some
# profiling. Turning it off for now. # profiling. Turning it off for now.
...@@ -329,6 +466,81 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -329,6 +466,81 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
super().process_weights_after_loading(layer) super().process_weights_after_loading(layer)
# If Marlin W16A16 MoE is supported, pre-pack weights once during the
# post-load hook and replace parameters with the packed layout.
#
# This avoids first-run packing peaks during KV cache profiling and
# keeps only one copy of weights resident on GPU in steady state.
if (getattr(layer, "_marlin_w16a16_moe_enabled", False)
and current_platform.is_cuda_alike()
and not getattr(layer, "use_nn_moe", False)
and not getattr(layer, "_marlin_w16a16_moe_packed", False)):
w1 = layer.w13_weight
w2 = layer.w2_weight
if (w1.is_cuda and w2.is_cuda
and w1.dtype in (torch.float16, torch.bfloat16)
and w2.dtype in (torch.float16, torch.bfloat16)):
try:
if w1.dim() != 3 or w2.dim() != 3 or w1.size(0) != w2.size(
0):
raise RuntimeError("Unexpected MoE weight shapes")
twoN, K = w1.size(1), w1.size(2)
if w2.size(1) != K:
raise RuntimeError("Unexpected MoE w2 layout")
N = w2.size(2)
if twoN != 2 * N:
raise RuntimeError("Unexpected MoE hidden dims")
if (K % 16 != 0 or K % 32 != 0 or N % 16 != 0
or twoN % 32 != 0):
raise RuntimeError("Marlin packing requires alignment")
from vllm.model_executor.layers.fused_moe.marlin_quant import (
w16a16_marlin_weight)
from torch.nn.parameter import Parameter
def _pack_per_expert(weight: torch.Tensor) -> torch.Tensor:
num_experts = weight.shape[0]
packed0 = w16a16_marlin_weight(
weight[0]).contiguous()
packed = packed0.new_empty((num_experts, ) +
packed0.shape)
packed[0].copy_(packed0)
del packed0
for i in range(1, num_experts):
tmp = w16a16_marlin_weight(
weight[i]).contiguous()
packed[i].copy_(tmp)
del tmp
return packed
with torch.no_grad():
w1_packed = _pack_per_expert(w1)
w2_packed = _pack_per_expert(w2)
new_w1 = Parameter(w1_packed, requires_grad=False)
new_w2 = Parameter(w2_packed, requires_grad=False)
# Preserve any custom weight attributes (e.g. loaders).
if hasattr(w1, "__dict__"):
for k, v in w1.__dict__.items():
setattr(new_w1, k, v)
if hasattr(w2, "__dict__"):
for k, v in w2.__dict__.items():
setattr(new_w2, k, v)
setattr(new_w1, "marlin_w16a16_packed", True)
setattr(new_w2, "marlin_w16a16_packed", True)
layer.w13_weight = new_w1
layer.w2_weight = new_w2
layer._marlin_w16a16_moe_packed = True
return
except Exception:
# If packing dependencies are unavailable, fall back to the
# standard (non-Marlin) layouts.
pass
# Padding the weight for better performance on ROCm # Padding the weight for better performance on ROCm
layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data) layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data)
layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data) layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data)
...@@ -698,6 +910,21 @@ class FusedMoE(torch.nn.Module): ...@@ -698,6 +910,21 @@ class FusedMoE(torch.nn.Module):
routed_scaling_factor: Optional[float] = 1.0, routed_scaling_factor: Optional[float] = 1.0,
): ):
super().__init__() super().__init__()
# Allow disabling of the separate shared experts stream for
# debug purposes.
# TODO: Remove this after more extensive testings with TP/DP
# and other execution modes
if envs.VLLM_DISABLE_SHARED_EXPERTS_STREAM:
logger.info_once("Disabling MoE shared_experts cuda stream")
self.shared_experts_stream = None
else:
# TODO(rob): enable shared expert overlap with non-cuda-alike.
# aux_stream() returns None on non-cuda-alike platforms.
self.shared_experts_stream = aux_stream()
if self.shared_experts_stream is not None:
logger.info_once("Enabled separate cuda stream for MoE shared_experts")
if params_dtype is None: if params_dtype is None:
params_dtype = torch.get_default_dtype() params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype self.params_dtype = params_dtype
...@@ -728,6 +955,8 @@ class FusedMoE(torch.nn.Module): ...@@ -728,6 +955,8 @@ class FusedMoE(torch.nn.Module):
self.logical_to_physical_map: Optional[torch.Tensor] = None self.logical_to_physical_map: Optional[torch.Tensor] = None
self.logical_replica_count: Optional[torch.Tensor] = None self.logical_replica_count: Optional[torch.Tensor] = None
self.enable_dp_attention = vllm_config.parallel_config.enable_dp_attention
# Determine expert maps # Determine expert maps
if self.use_ep: if self.use_ep:
if self.enable_eplb: if self.enable_eplb:
...@@ -814,12 +1043,30 @@ class FusedMoE(torch.nn.Module): ...@@ -814,12 +1043,30 @@ class FusedMoE(torch.nn.Module):
# please refer to the implementation in `Fp8MoEMethod`. # please refer to the implementation in `Fp8MoEMethod`.
raise NotImplementedError("EPLB is only supported for FP8 " raise NotImplementedError("EPLB is only supported for FP8 "
"quantization for now.") "quantization for now.")
if quant_config is None: if quant_config is None:
# Not considering quant for now, temporarily # Not considering quant for now, temporarily
self.use_nn_moe = int(os.environ.get('MOE_NN', 1)) == 1 moe_in_dtype = model_dtype
self._marlin_w16a16_moe_enabled = (
not envs.VLLM_USE_MOE_W16A16_TRITON
and params_dtype == moe_in_dtype
and self.activation == "silu"
and not self.apply_router_weight_on_input
and _is_marlin_w16a16_moe_supported(
E=self.local_num_experts,
N=self.intermediate_size_per_partition,
K=self.hidden_size,
top_k=self.top_k,
dtype=moe_in_dtype,
))
self.use_nn_moe = int(os.environ.get("MOE_NN", 1)) == 1
# Marlin W16A16 MoE requires the non-NN weight layout.
if self._marlin_w16a16_moe_enabled:
self.use_nn_moe = False
else: else:
self.use_nn_moe = False self.use_nn_moe = False
self._marlin_w16a16_moe_enabled = False
moe_quant_params = { moe_quant_params = {
"num_experts": self.local_num_experts, "num_experts": self.local_num_experts,
...@@ -858,7 +1105,7 @@ class FusedMoE(torch.nn.Module): ...@@ -858,7 +1105,7 @@ class FusedMoE(torch.nn.Module):
self.batched_hidden_states: Optional[torch.Tensor] = None self.batched_hidden_states: Optional[torch.Tensor] = None
self.batched_router_logits: Optional[torch.Tensor] = None self.batched_router_logits: Optional[torch.Tensor] = None
if (self.moe_parallel_config.use_pplx_kernels if (self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels): or self.moe_parallel_config.use_deepep_ll_kernels or self.moe_parallel_config.use_deepep_auto_kernels):
self.batched_hidden_states = torch.zeros( self.batched_hidden_states = torch.zeros(
(moe.max_num_tokens, self.hidden_size), (moe.max_num_tokens, self.hidden_size),
dtype=moe.in_dtype, dtype=moe.in_dtype,
...@@ -909,9 +1156,13 @@ class FusedMoE(torch.nn.Module): ...@@ -909,9 +1156,13 @@ class FusedMoE(torch.nn.Module):
@property @property
def use_deepep_ll_kernels(self): def use_deepep_ll_kernels(self):
return self.moe_parallel_config.use_deepep_ll_kernels return self.moe_parallel_config.use_deepep_ll_kernels
@property
def use_deepep_auto_kernels(self):
return self.moe_parallel_config.use_deepep_auto_kernels
@property @property
def shared_experts(self) -> Optional[torch.nn.Module]: def shared_experts(self) -> torch.nn.Module | None:
return None return None
def _load_per_tensor_weight_scale(self, shard_id: str, def _load_per_tensor_weight_scale(self, shard_id: str,
...@@ -1308,8 +1559,9 @@ class FusedMoE(torch.nn.Module): ...@@ -1308,8 +1559,9 @@ class FusedMoE(torch.nn.Module):
num_expert_group, num_expert_group,
topk_group, topk_group,
top_k, top_k,
0, # TODO also required num of shared expert is not None
routed_scaling_factor, (1 if envs.VLLM_ENABLE_SHARED_EXPERTS_FUSION else 0),
routed_scaling_factor
) )
else: else:
topk_weights, topk_ids = ops.moe_fused_gate( topk_weights, topk_ids = ops.moe_fused_gate(
...@@ -1330,7 +1582,11 @@ class FusedMoE(torch.nn.Module): ...@@ -1330,7 +1582,11 @@ class FusedMoE(torch.nn.Module):
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
topk_group=topk_group, topk_group=topk_group,
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias) e_score_correction_bias = e_score_correction_bias,
routed_scaling_factor = routed_scaling_factor,
# TODO also required num of shared expert is not None
num_fused_shared_experts = (1 if envs.VLLM_ENABLE_SHARED_EXPERTS_FUSION else 0)
)
if indices_type is not None: if indices_type is not None:
topk_ids = topk_ids.to(dtype=indices_type) topk_ids = topk_ids.to(dtype=indices_type)
elif custom_routing_function is None: elif custom_routing_function is None:
...@@ -1436,7 +1692,7 @@ class FusedMoE(torch.nn.Module): ...@@ -1436,7 +1692,7 @@ class FusedMoE(torch.nn.Module):
early. early.
""" """
return (self.use_pplx_kernels or self.use_deepep_ht_kernels return (self.use_pplx_kernels or self.use_deepep_ht_kernels
or self.use_deepep_ll_kernels) or self.use_deepep_ll_kernels or self.use_deepep_auto_kernels)
def maybe_all_reduce_tensor_model_parallel( def maybe_all_reduce_tensor_model_parallel(
self, final_hidden_states: torch.Tensor): self, final_hidden_states: torch.Tensor):
...@@ -1444,13 +1700,15 @@ class FusedMoE(torch.nn.Module): ...@@ -1444,13 +1700,15 @@ class FusedMoE(torch.nn.Module):
The pplx combine kernel reduces across GPU ranks by default. The pplx combine kernel reduces across GPU ranks by default.
""" """
if (self.use_pplx_kernels or self.use_deepep_ht_kernels if (self.use_pplx_kernels or self.use_deepep_ht_kernels
or self.use_deepep_ll_kernels): or self.use_deepep_ll_kernels or self.use_deepep_auto_kernels
or self.enable_dp_attention):
return final_hidden_states return final_hidden_states
else: else:
return tensor_model_parallel_all_reduce(final_hidden_states) return tensor_model_parallel_all_reduce(final_hidden_states)
def forward(self, hidden_states: torch.Tensor, def forward(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
hidden_states_copy: Optional[torch.Tensor] = None, # for shared expert overlap
shared_output: Optional[torch.Tensor] = None, shared_output: Optional[torch.Tensor] = None,
i_q: Optional[torch.Tensor] = None, i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None, **_ i_s: Optional[torch.Tensor] = None, **_
...@@ -1458,16 +1716,26 @@ class FusedMoE(torch.nn.Module): ...@@ -1458,16 +1716,26 @@ class FusedMoE(torch.nn.Module):
# TODO: Once the OOM issue for the TPU backend is resolved, we will # TODO: Once the OOM issue for the TPU backend is resolved, we will
# switch to using the moe_forward custom op. # switch to using the moe_forward custom op.
if current_platform.is_tpu(): if current_platform.is_tpu():
assert i_q is None and i_s is None, "moe.quant fused not support TPU now" assert i_q is None and i_s is None, "moe.quant fused not support TPU now"
return self.forward_impl(hidden_states, router_logits) return self.forward_impl(hidden_states, router_logits)
else: else:
if self.shared_experts is None: if self.shared_experts is None:
return torch.ops.vllm.moe_forward(hidden_states, router_logits, return torch.ops.vllm.moe_forward(
self.layer_name, shared_output, hidden_states = hidden_states,
i_q, i_s) router_logits = router_logits,
layer_name = self.layer_name,
shared_output = shared_output,
i_q = i_q,
i_s = i_s)
else: else:
return torch.ops.vllm.moe_forward_shared(hidden_states, router_logits, return torch.ops.vllm.moe_forward_shared(
self.layer_name, shared_output) hidden_states = hidden_states,
router_logits = router_logits,
layer_name = self.layer_name,
hidden_states_copy = hidden_states_copy,
shared_output = shared_output,
i_q = i_q,
i_s = i_s)
def forward_impl_chunked(self, full_hidden_states: torch.Tensor, def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
full_router_logits: torch.Tensor): full_router_logits: torch.Tensor):
...@@ -1547,10 +1815,22 @@ class FusedMoE(torch.nn.Module): ...@@ -1547,10 +1815,22 @@ class FusedMoE(torch.nn.Module):
def forward_impl(self, hidden_states: torch.Tensor, def forward_impl(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
hidden_states_copy: Optional[torch.Tensor] = None,
shared_output: Optional[torch.Tensor] = None, shared_output: Optional[torch.Tensor] = None,
i_q: Optional[torch.Tensor] = None, i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None, **_): i_s: Optional[torch.Tensor] = None, **_)-> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.quant_method is not None assert self.quant_method is not None
enable_shared_experts_overlap = False
if (self.shared_experts_stream is not None
and hidden_states_copy is not None
and self.shared_experts is not None
and not self.moe_parallel_config.use_pplx_kernels):
enable_shared_experts_overlap = True
hidden_states_copy.record_stream(self.shared_experts_stream)
self.shared_experts_stream.wait_stream(torch.cuda.current_stream())
if (self.moe_parallel_config.use_pplx_kernels): if (self.moe_parallel_config.use_pplx_kernels):
#or self.moe_parallel_config.use_deepep_ll_kernels): #or self.moe_parallel_config.use_deepep_ll_kernels):
return self.forward_impl_chunked(hidden_states, router_logits) return self.forward_impl_chunked(hidden_states, router_logits)
...@@ -1558,7 +1838,10 @@ class FusedMoE(torch.nn.Module): ...@@ -1558,7 +1838,10 @@ class FusedMoE(torch.nn.Module):
do_naive_dispatch_combine: bool = ( do_naive_dispatch_combine: bool = (
self.dp_size > 1 self.dp_size > 1
and not self.moe_parallel_config.use_deepep_ht_kernels and not self.moe_parallel_config.use_deepep_ht_kernels
and not self.moe_parallel_config.use_deepep_ll_kernels) and not self.moe_parallel_config.use_deepep_ll_kernels
and not self.moe_parallel_config.use_deepep_auto_kernels
and not self.enable_dp_attention
)
if do_naive_dispatch_combine: if do_naive_dispatch_combine:
hidden_states, router_logits = get_ep_group().dispatch( hidden_states, router_logits = get_ep_group().dispatch(
hidden_states, router_logits) hidden_states, router_logits)
...@@ -1619,18 +1902,48 @@ class FusedMoE(torch.nn.Module): ...@@ -1619,18 +1902,48 @@ class FusedMoE(torch.nn.Module):
use_fused_gate=self.use_fused_gate, use_fused_gate=self.use_fused_gate,
) )
if do_naive_dispatch_combine: if enable_shared_experts_overlap:
final_hidden_states = get_ep_group().combine(final_hidden_states) assert self.shared_experts is not None
# Run shared experts in parallel on a separate stream
# NOTE: We start the separate stream here and mark the
# sync end point immediately after it is done. This is
# important to avoid excessive stream allocations by the cuda
# graph replay later.
with torch.cuda.stream(self.shared_experts_stream):
# Note that hidden_states clone() is necessary here to avoid
# conflict with the main stream
assert hidden_states_copy is not None
if envs.USE_FUSED_RMS_QUANT:
shared_output = self.shared_experts(hidden_states_copy, iqis=(i_q, i_s))
else:
shared_output = self.shared_experts(hidden_states_copy)
torch.cuda.current_stream().wait_stream(self.shared_experts_stream)
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): final_hidden_states = (
# Default set to False. (May have to add shared expert outputs. shared_output,
if envs.VLLM_ENABLE_TBO: final_hidden_states,
final_hidden_states = self.tbo_all_reduce(final_hidden_states) )
else:
final_hidden_states = self.maybe_all_reduce_tensor_model_parallel(
final_hidden_states)
return final_hidden_states def combine_output(states: torch.Tensor) -> torch.Tensor:
if do_naive_dispatch_combine:
states = get_ep_group().combine(states)
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
# Default set to False. (May have to add shared expert outputs.
if envs.VLLM_ENABLE_TBO:
states = self.tbo_all_reduce(states)
else:
states = self.maybe_all_reduce_tensor_model_parallel(
states)
return states
if enable_shared_experts_overlap:
return (
final_hidden_states[0],
combine_output(final_hidden_states[1]),
)
else:
return combine_output(final_hidden_states)
@classmethod @classmethod
def make_expert_params_mapping( def make_expert_params_mapping(
...@@ -1686,7 +1999,7 @@ class FusedMoE(torch.nn.Module): ...@@ -1686,7 +1999,7 @@ class FusedMoE(torch.nn.Module):
return s return s
def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor, def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor,
layer_name: str, shared_output: Optional[torch.Tensor] = None, layer_name: str, shared_output: Optional[torch.Tensor] = None,
i_q: Optional[torch.Tensor] = None, i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None) -> torch.Tensor: i_s: Optional[torch.Tensor] = None) -> torch.Tensor:
...@@ -1694,10 +2007,10 @@ def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor, ...@@ -1694,10 +2007,10 @@ def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor,
self = forward_context.no_compile_layers[layer_name] self = forward_context.no_compile_layers[layer_name]
assert self.quant_method is not None assert self.quant_method is not None
if envs.USE_FUSED_RMS_QUANT: if envs.USE_FUSED_RMS_QUANT:
return self.forward_impl(hidden_states, router_logits, shared_output, i_q, i_s) return self.forward_impl(hidden_states, router_logits, shared_output = shared_output, i_q = i_q, i_s = i_s)
else: else:
return self.forward_impl(hidden_states, router_logits, shared_output) return self.forward_impl(hidden_states, router_logits, shared_output = shared_output)
def moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor, def moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor,
...@@ -1717,22 +2030,31 @@ direct_register_custom_op( ...@@ -1717,22 +2030,31 @@ direct_register_custom_op(
) )
def moe_forward_shared( def moe_forward_shared(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
layer_name: str, layer_name: str,
shared_output: Optional[torch.Tensor] = None hidden_states_copy: Optional[torch.Tensor] = None,
shared_output: Optional[torch.Tensor] = None,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
forward_context: ForwardContext = get_forward_context() forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name] self = forward_context.no_compile_layers[layer_name]
assert self.shared_experts is not None assert self.shared_experts is not None
return self.forward_impl(hidden_states, router_logits, shared_output) if envs.USE_FUSED_RMS_QUANT:
return self.forward_impl(hidden_states, router_logits, hidden_states_copy = hidden_states_copy, i_q = i_q, i_s = i_s)
else:
return self.forward_impl(hidden_states, router_logits, hidden_states_copy = hidden_states_copy)
def moe_forward_shared_fake( def moe_forward_shared_fake(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
layer_name: str, layer_name: str,
shared_output: Optional[torch.Tensor] = None hidden_states_copy: Optional[torch.Tensor] = None,
shared_output: Optional[torch.Tensor] = None,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
shared_out = torch.empty_like(hidden_states) shared_out = torch.empty_like(hidden_states)
fused_out = torch.empty_like(hidden_states) fused_out = torch.empty_like(hidden_states)
...@@ -1742,7 +2064,7 @@ def moe_forward_shared_fake( ...@@ -1742,7 +2064,7 @@ def moe_forward_shared_fake(
direct_register_custom_op( direct_register_custom_op(
op_name="moe_forward_shared", op_name="moe_forward_shared",
op_func=moe_forward_shared, op_func=moe_forward_shared,
mutates_args=["hidden_states"], mutates_args=["hidden_states", "hidden_states_copy"],
fake_impl=moe_forward_shared_fake, fake_impl=moe_forward_shared_fake,
tags=(torch.Tag.needs_fixed_stride_order,), tags=(torch.Tag.needs_fixed_stride_order,),
) )
\ No newline at end of file
...@@ -6,7 +6,9 @@ from math import prod ...@@ -6,7 +6,9 @@ from math import prod
from typing import Optional, final from typing import Optional, final
from dataclasses import dataclass from dataclasses import dataclass
from collections.abc import Callable from collections.abc import Callable
from vllm.logger import init_logger
logger = init_logger(__name__)
import torch import torch
import vllm.envs as envs import vllm.envs as envs
...@@ -843,11 +845,16 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module): ...@@ -843,11 +845,16 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
self, self,
prepare_finalize: FusedMoEPrepareAndFinalize, prepare_finalize: FusedMoEPrepareAndFinalize,
fused_experts: CustomizedFusedMoEPermuteExpertsUnpermute, fused_experts: CustomizedFusedMoEPermuteExpertsUnpermute,
experts_ht: CustomizedFusedMoEPermuteExpertsUnpermute = None,
experts_ll: CustomizedFusedMoEPermuteExpertsUnpermute = None,
shared_experts: Optional[torch.nn.Module] = None, shared_experts: Optional[torch.nn.Module] = None,
): ):
super().__init__() super().__init__()
self.prepare_finalize = prepare_finalize self.prepare_finalize = prepare_finalize
self.fused_experts = fused_experts self.fused_experts = fused_experts
self.fused_experts_ht = experts_ht
self.fused_experts_ll = experts_ll
self.shared_experts = shared_experts self.shared_experts = shared_experts
if self.shared_experts is not None: if self.shared_experts is not None:
...@@ -919,7 +926,21 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module): ...@@ -919,7 +926,21 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
Returns: Returns:
- torch.Tensor: The output tensor after applying the MoE layer. - torch.Tensor: The output tensor after applying the MoE layer.
""" """
prepare_finalize = self.prepare_finalize
fused_experts = self.fused_experts
if envs.VLLM_ALL2ALL_BACKEND == "deepep_auto":
num_ht_ll_tokens = envs.VLLM_MOE_HT_THRESHOLD
num_tokens = hidden_states.size(0)
# logger.info("num_tokens=%d", num_tokens)
if num_tokens > num_ht_ll_tokens:
prepare_finalize = self.prepare_finalize.ht_prepare_finalize
fused_experts = self.fused_experts_ht
else:
prepare_finalize = self.prepare_finalize.ll_prepare_finalize
fused_experts = self.fused_experts_ll
a1 = hidden_states a1 = hidden_states
if inplace and self.shared_experts is None: if inplace and self.shared_experts is None:
...@@ -931,7 +952,7 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module): ...@@ -931,7 +952,7 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
if global_num_experts == -1: if global_num_experts == -1:
global_num_experts = local_num_experts global_num_experts = local_num_experts
prepare_ret = self.prepare_finalize.prepare_async( prepare_ret = prepare_finalize.prepare_async(
a1, a1,
a1_scale, a1_scale,
a2_scale, a2_scale,
...@@ -940,7 +961,7 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module): ...@@ -940,7 +961,7 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
global_num_experts, global_num_experts,
expert_map, expert_map,
apply_router_weight_on_input, apply_router_weight_on_input,
self.fused_experts.quant_config, fused_experts.quant_config,
) )
hook, receiver = ( hook, receiver = (
prepare_ret if isinstance(prepare_ret, tuple) else (None, prepare_ret) prepare_ret if isinstance(prepare_ret, tuple) else (None, prepare_ret)
...@@ -971,7 +992,7 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module): ...@@ -971,7 +992,7 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
# and can never run into the tensor.numel() == 0 case. # and can never run into the tensor.numel() == 0 case.
fused_out = torch.empty_like(a1q).to(dtype=a1.dtype) fused_out = torch.empty_like(a1q).to(dtype=a1.dtype)
else: else:
fused_out = self.fused_experts.apply( fused_out = fused_experts.apply(
None, None,
a1, a1,
a1q, a1q,
...@@ -1008,12 +1029,12 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module): ...@@ -1008,12 +1029,12 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
self.alt_stream.wait_event(self.alt_event) self.alt_stream.wait_event(self.alt_event)
hook = None hook = None
if self.prepare_finalize.activation_format == \ if prepare_finalize.activation_format == \
FusedMoEActivationFormat.BatchedExperts: FusedMoEActivationFormat.BatchedExperts:
self.prepare_finalize.finalize(output, fused_out, topk_weights, prepare_finalize.finalize(output, fused_out, topk_weights,
topk_ids, apply_router_weight_on_input, apply_weights_and_reduce=True) topk_ids, apply_router_weight_on_input, apply_weights_and_reduce=True)
else: else:
hook = self.prepare_finalize.finalize_async(output, fused_out, topk_weights, hook = prepare_finalize.finalize_async(output, fused_out, topk_weights,
topk_ids, apply_router_weight_on_input, apply_weights_and_reduce=True) topk_ids, apply_router_weight_on_input, apply_weights_and_reduce=True)
if hook is not None: if hook is not None:
hook() hook()
......
"""
Utilities for capturing MoE router distributions from real workloads.
This is intentionally lightweight and gated behind env vars so it has zero
runtime impact unless explicitly enabled.
Env vars (defaults from vllm.envs):
- VLLM_MOE_ROUTER_CAPTURE=0/1: enable capture (default: 0).
- VLLM_MOE_ROUTER_CAPTURE_DIR=/path: output directory for per-process dumps
(default: /tmp).
- VLLM_MOE_ROUTER_CAPTURE_RANK=N: only capture on the given torch.distributed
rank (default: -1; set to -1 to capture all ranks).
- VLLM_MOE_ROUTER_CAPTURE_MAX_LAYERS=N: max number of layers to record per
process (default: 0; 0 = unlimited).
- VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_GT=A: only record calls where router_logits
has num_tokens > A (default: -1; <0 = disabled).
- VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_LT=B: only record calls where router_logits
has num_tokens < B (default: -1; 0 = disabled).
Output format:
- A single `.pt` per captured num_tokens (and per rank if torch.distributed is
initialized).
- Payload includes `layers_by_num_tokens: dict[str, dict[layer_name, layer_state]]`.
- A convenience `layers` field is also included (same as
`layers_by_num_tokens[str(num_tokens)]`) for easy loading.
- For each captured MoE layer, stores a list of 2D tensors
`router_logits_chunks: list[Tensor[num_tokens_i, num_experts]]` on CPU,
typically in fp16 for space efficiency.
"""
from __future__ import annotations
import atexit
import inspect
import os
import socket
import threading
import time
from dataclasses import dataclass
from typing import Optional
import torch
import vllm.envs as envs
_DEFAULT_SKIP_STACK_FUNCS = ("profile_run", "_dummy_run",
"determine_available_memory")
@dataclass(frozen=True)
class RouterCaptureConfig:
enabled: bool = False
out_dir: str = "/tmp"
skip_profile: bool = True
skip_stack_funcs: tuple[str, ...] = _DEFAULT_SKIP_STACK_FUNCS
only_rank: Optional[int] = 0
max_layers: int = 0
num_tokens_gt: Optional[int] = None
num_tokens_lt: Optional[int] = None
@staticmethod
def from_env() -> "RouterCaptureConfig":
enabled = envs.VLLM_MOE_ROUTER_CAPTURE
out_dir = envs.VLLM_MOE_ROUTER_CAPTURE_DIR
skip_profile = True
skip_stack_funcs = _DEFAULT_SKIP_STACK_FUNCS
only_rank: Optional[int] = None
if envs.VLLM_MOE_ROUTER_CAPTURE_RANK >= 0:
only_rank = envs.VLLM_MOE_ROUTER_CAPTURE_RANK
max_layers = envs.VLLM_MOE_ROUTER_CAPTURE_MAX_LAYERS
num_tokens_gt_opt = (envs.VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_GT
if envs.VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_GT >= 0
else None)
num_tokens_lt_opt = (envs.VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_LT
if envs.VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_LT > 0
else None)
# Per-size mode requires an explicit token-count filter to avoid
# unbounded captures by default.
if num_tokens_gt_opt is None and num_tokens_lt_opt is None:
enabled = False
if (num_tokens_gt_opt is not None and num_tokens_lt_opt is not None
and num_tokens_gt_opt >= num_tokens_lt_opt):
enabled = False
return RouterCaptureConfig(enabled=enabled,
out_dir=out_dir,
skip_profile=skip_profile,
skip_stack_funcs=skip_stack_funcs,
only_rank=only_rank,
max_layers=max_layers,
num_tokens_gt=num_tokens_gt_opt,
num_tokens_lt=num_tokens_lt_opt)
def _in_profile_run(skip_stack_funcs: tuple[str, ...]) -> bool:
"""
Best-effort detection for vLLM startup profiling/warmup runs.
Startup warmups often execute MoE kernels with synthetic shapes. When
enabled, skip captures from these stacks so the first capture comes from a
real request.
"""
if not skip_stack_funcs:
return False
frame = inspect.currentframe()
try:
while frame is not None:
name = frame.f_code.co_name
if name in skip_stack_funcs:
return True
frame = frame.f_back
finally:
# Avoid reference cycles.
del frame
return False
class _RouterCapture:
def __init__(self, cfg: RouterCaptureConfig) -> None:
self.cfg = cfg
# Bucket captures by token count.
self._layers_by_num_tokens: dict[int, dict[str, dict[str, object]]] = {}
self._layer_names: set[str] = set()
self._completed_num_tokens: set[int] = set()
self._lock = threading.Lock()
self._flush_counter = 0
self._pid = os.getpid()
self._host = socket.gethostname()
self._start_time = time.time()
os.makedirs(cfg.out_dir, exist_ok=True)
atexit.register(self.flush)
def _bucket_for_num_tokens(self, num_tokens: int) -> Optional[int]:
"""Return the per-size bucket key for this record call, or None if filtered."""
if self.cfg.num_tokens_gt is None and self.cfg.num_tokens_lt is None:
return None
if self.cfg.num_tokens_gt is not None:
if int(num_tokens) <= int(self.cfg.num_tokens_gt):
return None
if self.cfg.num_tokens_lt is not None:
if int(num_tokens) >= int(self.cfg.num_tokens_lt):
return None
bucket_num_tokens = int(num_tokens)
if bucket_num_tokens != 0 and bucket_num_tokens in self._completed_num_tokens:
return None
return bucket_num_tokens
def _snapshot_layers_by_num_tokens(
self,
layers_by_num_tokens: dict[int, dict[str, dict[str, object]]],
) -> dict[int, dict[str, dict[str, object]]]:
snapshot: dict[int, dict[str, dict[str, object]]] = {}
for num_tokens, bucket in layers_by_num_tokens.items():
bucket_snapshot: dict[str, dict[str, object]] = {}
for layer_name, state in bucket.items():
chunks = state.get("router_logits_chunks", [])
bucket_snapshot[layer_name] = {
"num_experts": int(state.get("num_experts", 0)),
"num_tokens": int(state.get("num_tokens", 0)),
"router_logits_chunks": list(chunks),
}
snapshot[int(num_tokens)] = bucket_snapshot
return snapshot
@torch.no_grad()
def record(self, layer_name: str, router_logits: torch.Tensor,
top_k: int) -> None:
if self.cfg.skip_profile and _in_profile_run(self.cfg.skip_stack_funcs):
return
if self.cfg.only_rank is not None:
rank = _get_rank()
if rank is not None and rank != self.cfg.only_rank:
return
if router_logits.dim() != 2:
return
num_tokens, num_experts = router_logits.shape
if num_tokens == 0 or num_experts == 0:
return
bucket_num_tokens = self._bucket_for_num_tokens(int(num_tokens))
if bucket_num_tokens is None:
return
# Limit the number of recorded layers to avoid unbounded dumps.
if layer_name not in self._layer_names:
if self.cfg.max_layers != 0 and len(self._layer_names) >= self.cfg.max_layers:
return
self._layer_names.add(layer_name)
# Store on CPU to avoid consuming GPU memory during long runs.
# fp16 is typically sufficient because we primarily care about
# distribution and relative ordering (top-k), not exact values.
router_logits_cpu = router_logits.detach()
if router_logits_cpu.is_cuda:
router_logits_cpu = router_logits_cpu.to(device="cpu",
dtype=torch.float16)
else:
router_logits_cpu = router_logits_cpu.to(dtype=torch.float16)
bucket_snapshot: Optional[dict[str, dict[str, object]]] = None
should_flush = False
with self._lock:
bucket = self._layers_by_num_tokens.setdefault(bucket_num_tokens, {})
if layer_name in bucket:
return
bucket[layer_name] = {
"num_experts": int(num_experts),
"num_tokens": int(num_tokens),
"router_logits_chunks": [router_logits_cpu],
}
if self.cfg.max_layers != 0 and len(bucket) >= int(self.cfg.max_layers):
should_flush = True
bucket_snapshot = self._snapshot_layers_by_num_tokens(
{int(bucket_num_tokens): bucket})[int(bucket_num_tokens)]
self._completed_num_tokens.add(int(bucket_num_tokens))
self._layers_by_num_tokens.pop(int(bucket_num_tokens), None)
if should_flush and bucket_snapshot is not None:
self._flush_payload(
layers_by_num_tokens={int(bucket_num_tokens): bucket_snapshot},
file_tag=f"nt{int(bucket_num_tokens)}",
)
def _flush_payload(
self,
*,
layers_by_num_tokens: dict[int, dict[str, dict[str, object]]],
file_tag: Optional[str] = None,
) -> Optional[str]:
if not self.cfg.enabled:
return None
if self.cfg.only_rank is not None:
rank = _get_rank()
if rank is not None and rank != self.cfg.only_rank:
return None
rank = _get_rank()
now = time.time()
ts = time.strftime("%Y%m%d_%H%M%S", time.localtime(now))
ts_us = int(now * 1_000_000)
with self._lock:
flush_idx = self._flush_counter
self._flush_counter += 1
rank_str = f"rank{rank}" if rank is not None else "rankNA"
tag = f"{file_tag}_" if file_tag else ""
out_path = os.path.join(
self.cfg.out_dir,
f"moe_router_stats_{tag}{ts_us}_{self._host}_{rank_str}_pid{self._pid}_flush{flush_idx}.pt",
)
layers_by_num_tokens_out: dict[str, object] = {}
for num_tokens, bucket in layers_by_num_tokens.items():
bucket_out: dict[str, object] = {}
for layer_name, state in bucket.items():
bucket_out[layer_name] = {
"num_experts": int(state["num_experts"]),
"num_tokens": int(state["num_tokens"]),
"router_logits_chunks":
state["router_logits_chunks"], # type: ignore[typeddict-item]
}
layers_by_num_tokens_out[str(int(num_tokens))] = bucket_out
payload: dict[str, object] = {
"meta": {
"timestamp": ts,
"timestamp_us": ts_us,
"flush_index": int(flush_idx),
"host": self._host,
"pid": self._pid,
"rank": rank,
"wall_time_s": float(now - self._start_time),
},
"layers_by_num_tokens": layers_by_num_tokens_out,
}
# Backward-compatible convenience field when there is a single bucket.
if len(layers_by_num_tokens) == 1:
(only_bucket_key, ) = layers_by_num_tokens.keys()
payload["layers"] = layers_by_num_tokens_out[str(int(only_bucket_key))]
try:
torch.save(payload, out_path)
except Exception:
return None
return out_path
def flush(self) -> Optional[str]:
with self._lock:
if not self._layers_by_num_tokens:
return None
snapshot = self._snapshot_layers_by_num_tokens(self._layers_by_num_tokens)
return self._flush_payload(layers_by_num_tokens=snapshot)
def reset(self) -> None:
with self._lock:
self._layers_by_num_tokens.clear()
self._layer_names.clear()
self._completed_num_tokens.clear()
self._start_time = time.time()
_CAPTURE: Optional[_RouterCapture] = None
_CAPTURE_DISABLED: bool = False
def _disable_global_capture() -> None:
global _CAPTURE, _CAPTURE_DISABLED
_CAPTURE = None
_CAPTURE_DISABLED = True
def _get_rank() -> Optional[int]:
if torch.distributed.is_available() and torch.distributed.is_initialized():
try:
return torch.distributed.get_rank()
except Exception:
return None
return None
def _get_capture() -> Optional[_RouterCapture]:
global _CAPTURE, _CAPTURE_DISABLED
if _CAPTURE_DISABLED:
return None
if _CAPTURE is not None:
return _CAPTURE
cfg = RouterCaptureConfig.from_env()
if not cfg.enabled:
_disable_global_capture()
return None
if cfg.only_rank is not None:
rank = _get_rank()
if rank is not None and rank != cfg.only_rank:
_disable_global_capture()
return None
_CAPTURE = _RouterCapture(cfg)
return _CAPTURE
@torch.no_grad()
def maybe_record_router_logits(*, layer_name: str, router_logits: torch.Tensor,
top_k: int) -> None:
capture = _get_capture()
if capture is None:
return
capture.record(layer_name=layer_name, router_logits=router_logits, top_k=top_k)
def maybe_flush_router_capture(*, reset: bool = False) -> Optional[str]:
"""Flush capture buffers to disk without exiting the process."""
capture = _get_capture()
if capture is None:
return None
out_path = capture.flush()
if out_path is not None and reset:
capture.reset()
return out_path
...@@ -34,7 +34,11 @@ class SharedFusedMoE(FusedMoE): ...@@ -34,7 +34,11 @@ class SharedFusedMoE(FusedMoE):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor: hidden_states_copy: Optional[torch.Tensor] = None,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None,
**_
) -> tuple[torch.Tensor, torch.Tensor]|torch.Tensor:
if not self.use_overlapped: if not self.use_overlapped:
shared_out = self._shared_experts(hidden_states) shared_out = self._shared_experts(hidden_states)
...@@ -53,6 +57,8 @@ class SharedFusedMoE(FusedMoE): ...@@ -53,6 +57,8 @@ class SharedFusedMoE(FusedMoE):
fused_out = super().forward( fused_out = super().forward(
hidden_states=hidden_states, hidden_states=hidden_states,
router_logits=router_logits, router_logits=router_logits,
hidden_states_copy = hidden_states_copy,
i_s = i_s,
i_q = i_q,
) )
return fused_out return fused_out
...@@ -22,6 +22,7 @@ from vllm.utils import round_up ...@@ -22,6 +22,7 @@ from vllm.utils import round_up
try: try:
from lmslim.layers.gemm.int8_utils import ( from lmslim.layers.gemm.int8_utils import (
per_token_group_quant_int8, per_token_quant_int8) per_token_group_quant_int8, per_token_quant_int8)
from lightop import op
except Exception: except Exception:
print("INFO: Please install lmslim if you want to use int utils.\n") print("INFO: Please install lmslim if you want to use int utils.\n")
from vllm.utils import cdiv from vllm.utils import cdiv
...@@ -622,52 +623,62 @@ def ep_scatter( ...@@ -622,52 +623,62 @@ def ep_scatter(
num_experts = num_recv_tokens_per_expert.shape[0] num_experts = num_recv_tokens_per_expert.shape[0]
hidden_size = recv_x.shape[1] hidden_size = recv_x.shape[1]
scale_hidden_size = recv_x_scale.shape[-1] scale_hidden_size = recv_x_scale.shape[-1]
# grid = (triton.cdiv(hidden_size, BLOCK_D), num_experts)
grid = num_experts
assert m_indices.shape[0] % BLOCK_E == 0 if hasattr(op, "ep_scatter"):
op.ep_scatter(
recv_x, recv_x_scale,
recv_topk, expert_map,
num_recv_tokens_per_expert,
output_tensor, output_tensor_scale, m_indices, output_index,
num_experts, BLOCK_E
)
else:
# grid = (triton.cdiv(hidden_size, BLOCK_D), num_experts)
grid = num_experts
_fwd_kernel_ep_scatter_1[(grid,)]( assert m_indices.shape[0] % BLOCK_E == 0
num_recv_tokens_per_expert,
expert_start_loc,
m_indices,
num_experts=num_experts,
num_warps=num_warps,
BLOCK_E=BLOCK_E,
BLOCK_EXPERT_NUM=triton.next_power_of_2(num_experts),
)
grid = min(recv_topk.shape[0], 1024 * 8) _fwd_kernel_ep_scatter_1[(grid,)](
_fwd_kernel_ep_scatter_2[(grid,)]( num_recv_tokens_per_expert,
recv_topk.shape[0], expert_start_loc,
expert_start_loc, m_indices,
recv_x, num_experts=num_experts,
recv_x.stride(0), num_warps=num_warps,
recv_x.stride(1), BLOCK_E=BLOCK_E,
recv_x_scale, BLOCK_EXPERT_NUM=triton.next_power_of_2(num_experts),
recv_x_scale.stride(0), )
recv_x_scale.stride(1),
recv_topk, grid = min(recv_topk.shape[0], 1024 * 8)
recv_topk.stride(0), _fwd_kernel_ep_scatter_2[(grid,)](
recv_topk.stride(1), recv_topk.shape[0],
output_tensor, expert_start_loc,
output_tensor.stride(0), recv_x,
output_tensor.stride(1), recv_x.stride(0),
output_tensor_scale, recv_x.stride(1),
output_tensor_scale.stride(0), recv_x_scale,
output_tensor_scale.stride(1), recv_x_scale.stride(0),
output_index, recv_x_scale.stride(1),
output_index.stride(0), recv_topk,
output_index.stride(1), recv_topk.stride(0),
topk_num=recv_topk.shape[1], recv_topk.stride(1),
expert_map=expert_map, output_tensor,
HAS_EXPERT_MAP=expert_map is not None, output_tensor.stride(0),
num_warps=num_warps, output_tensor.stride(1),
HIDDEN_SIZE=hidden_size, output_tensor_scale,
HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size), output_tensor_scale.stride(0),
SCALE_HIDDEN_SIZE=scale_hidden_size,#hidden_size // BLOCK_D, output_tensor_scale.stride(1),
SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(scale_hidden_size)#triton.next_power_of_2(hidden_size // BLOCK_D), output_index,
) output_index.stride(0),
output_index.stride(1),
topk_num=recv_topk.shape[1],
expert_map=expert_map,
HAS_EXPERT_MAP=expert_map is not None,
num_warps=num_warps,
HIDDEN_SIZE=hidden_size,
HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size),
SCALE_HIDDEN_SIZE=scale_hidden_size,
SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(scale_hidden_size),
)
return return
...@@ -897,7 +908,7 @@ class EPSharedExperts(nn.Module): ...@@ -897,7 +908,7 @@ class EPSharedExperts(nn.Module):
"Only silu is supported for now.") "Only silu is supported for now.")
self.act_fn = SiluAndMul() self.act_fn = SiluAndMul()
def forward(self, x): def forward(self, x, **_):
gate_up, _ = self.gate_up_proj(x) gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up) x = self.act_fn(gate_up)
x, _ = self.down_proj(x) x, _ = self.down_proj(x)
......
...@@ -27,6 +27,7 @@ from vllm.model_executor.parameter import (BasevLLMParameter, ...@@ -27,6 +27,7 @@ from vllm.model_executor.parameter import (BasevLLMParameter,
PackedvLLMParameter, PackedvLLMParameter,
PerTensorScaleParameter, PerTensorScaleParameter,
RowvLLMParameter) RowvLLMParameter)
from vllm.model_executor.layers.dp_attention import get_moe_tp_rank, get_moe_tp_size
# yapf: enable # yapf: enable
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -331,7 +332,6 @@ class ReplicatedLinear(LinearBase): ...@@ -331,7 +332,6 @@ class ReplicatedLinear(LinearBase):
skip_bias_add: bool = False, skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
eps: Optional[float] = 1e-6,
prefix: str = "", prefix: str = "",
*, *,
return_bias: bool = True, return_bias: bool = True,
...@@ -343,7 +343,6 @@ class ReplicatedLinear(LinearBase): ...@@ -343,7 +343,6 @@ class ReplicatedLinear(LinearBase):
quant_config, quant_config,
prefix=prefix, prefix=prefix,
return_bias=return_bias) return_bias=return_bias)
self.eps = eps
# All the linear layer supports quant method. # All the linear layer supports quant method.
assert self.quant_method is not None assert self.quant_method is not None
...@@ -393,44 +392,18 @@ class ReplicatedLinear(LinearBase): ...@@ -393,44 +392,18 @@ class ReplicatedLinear(LinearBase):
def forward( def forward(
self, self,
input_: torch.Tensor, input_: torch.Tensor,
rms_weight: Optional[torch.Tensor] = None, iqis: Optional[tuple] = None, **_
residual: Optional[torch.Tensor] = None,
quant_args: Optional[list] = None,
update_hd: Optional[bool] = True
) -> Union[torch.Tensor, ) -> Union[torch.Tensor,
tuple[torch.Tensor, Optional[Parameter]], tuple[torch.Tensor, Optional[Parameter]]]:
tuple[torch.Tensor, torch.Tensor, Optional[Parameter], list[torch.Tensor]]]: if envs.USE_FUSED_RMS_QUANT and iqis is not None:
if envs.USE_FUSED_RMS_QUANT and (rms_weight is not None or quant_args is not None): bias = self.bias if not self.skip_bias_add else None
if quant_args is not None: assert self.quant_method is not None
input_quant_args = quant_args output = self.quant_method.apply(self, input_, bias, input_quant_args=iqis)
output_bias = self.bias if self.skip_bias_add else None
bias = self.bias if not self.skip_bias_add else None if not self.return_bias:
assert self.quant_method is not None return output
output = self.quant_method.apply(self, input_, bias, input_quant_args) return output, output_bias
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
else:
i_q, _scales = lm_faster_rmsquant(input=input_,
rms_weight=rms_weight,
epsilon=self.eps,
quant_dtype=torch.int8,
residual=residual,
update_input=update_hd
)
new_residual = residual
input_quant_args = [i_q, _scales]
bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None
output = self.quant_method.apply(self, input_, bias, input_quant_args)
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, new_residual, output_bias, input_quant_args
else: else:
bias = self.bias if not self.skip_bias_add else None bias = self.bias if not self.skip_bias_add else None
...@@ -459,7 +432,6 @@ class FusedQuantedReplicatedLinear(LinearBase): ...@@ -459,7 +432,6 @@ class FusedQuantedReplicatedLinear(LinearBase):
skip_bias_add: bool = False, skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
eps: Optional[float] = 1e-6,
prefix: str = "", prefix: str = "",
*, *,
return_bias: bool = True, return_bias: bool = True,
...@@ -473,7 +445,6 @@ class FusedQuantedReplicatedLinear(LinearBase): ...@@ -473,7 +445,6 @@ class FusedQuantedReplicatedLinear(LinearBase):
quant_config, quant_config,
prefix=prefix, prefix=prefix,
return_bias=return_bias) return_bias=return_bias)
self.eps = eps
self.q_lora_rank = q_lora_rank self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank self.kv_lora_rank = kv_lora_rank
self.qk_rope_head_dim = qk_rope_head_dim self.qk_rope_head_dim = qk_rope_head_dim
...@@ -588,7 +559,6 @@ class FusedQuantedReplicatedLinear(LinearBase): ...@@ -588,7 +559,6 @@ class FusedQuantedReplicatedLinear(LinearBase):
assert len(self.kv_a_weight.shape) == 2 assert len(self.kv_a_weight.shape) == 2
fused_weight = torch.cat([self.q_a_weight, self.kv_a_weight], dim=0) # TN fused_weight = torch.cat([self.q_a_weight, self.kv_a_weight], dim=0) # TN
param.data.copy_(fused_weight) param.data.copy_(fused_weight)
#TODO: wjl 删掉无用的显存tensor
else: else:
raise ValueError(f"Unexpected weight: {source}") raise ValueError(f"Unexpected weight: {source}")
...@@ -596,31 +566,17 @@ class FusedQuantedReplicatedLinear(LinearBase): ...@@ -596,31 +566,17 @@ class FusedQuantedReplicatedLinear(LinearBase):
def forward( def forward(
self, self,
input_: torch.Tensor, input_: torch.Tensor,
rms_weight: Optional[torch.Tensor] = None, iqis: Optional[tuple] = None, **_
residual: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, Optional[Parameter]]:
update_hd: Optional[bool] = True if envs.USE_FUSED_RMS_QUANT and iqis is not None:
) -> Union[torch.Tensor,
tuple[torch.Tensor, torch.Tensor, Optional[Parameter], list[torch.Tensor]]]:
if envs.USE_FUSED_RMS_QUANT and rms_weight is not None:
i_q, _scales = lm_faster_rmsquant(input=input_,
rms_weight=rms_weight,
epsilon=self.eps,
quant_dtype=torch.int8,
residual=residual,
update_input=update_hd
)
new_residual = residual
input_quant_args = [i_q, _scales]
bias = self.bias if not self.skip_bias_add else None bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None assert self.quant_method is not None
output = self.quant_method.apply(self, input_, bias, input_quant_args) output = self.quant_method.apply(self, input_, bias, input_quant_args=iqis)
output_bias = self.bias if self.skip_bias_add else None output_bias = self.bias if self.skip_bias_add else None
assert self.return_bias is True assert self.return_bias is True
if not self.return_bias: if not self.return_bias:
raise RuntimeError("Not return bias. Unexpected Error.") raise RuntimeError("Not return bias. Unexpected Error.")
return output, new_residual, output_bias return output, output_bias
else: else:
raise RuntimeError("Unexpected Error.") raise RuntimeError("Unexpected Error.")
...@@ -670,12 +626,18 @@ class ColumnParallelLinear(LinearBase): ...@@ -670,12 +626,18 @@ class ColumnParallelLinear(LinearBase):
*, *,
return_bias: bool = True, return_bias: bool = True,
expect_tp_size: Optional[int] = None, expect_tp_size: Optional[int] = None,
enable_dp_attn_moe: bool = False,
): ):
# Divide the weight matrix along the last dimension. # Divide the weight matrix along the last dimension.
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
if expect_tp_size is not None: if expect_tp_size is not None:
self.expect_tp_size = expect_tp_size self.expect_tp_size = expect_tp_size
self.tp_size = self.expect_tp_size self.tp_size = self.expect_tp_size
self.enable_dp_attn_moe = enable_dp_attn_moe
if enable_dp_attn_moe:
self.tp_size = get_moe_tp_size()
self.input_size_per_partition = input_size self.input_size_per_partition = input_size
self.output_size_per_partition = divide(output_size, self.tp_size) self.output_size_per_partition = divide(output_size, self.tp_size)
self.output_partition_sizes = [self.output_size_per_partition] self.output_partition_sizes = [self.output_size_per_partition]
...@@ -858,31 +820,17 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -858,31 +820,17 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
def forward( def forward(
self, input_, self, input_,
rms_weight: Optional[torch.Tensor] = None, xqxs: Optional[tuple] = None,
residual: Optional[torch.Tensor] = None, iqis: Optional[tuple] = None, **_
update_hd: Optional[bool] = True,
xqxs: Optional[tuple] = None
) -> Union[torch.Tensor, ) -> Union[torch.Tensor,
tuple[torch.Tensor, Optional[Parameter]], tuple[torch.Tensor, Optional[Parameter]],
tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[Parameter]],
]: ]:
if envs.USE_FUSED_RMS_QUANT and rms_weight is not None: if envs.USE_FUSED_RMS_QUANT and iqis is not None:
input_quant_args = None
assert residual is not None and rms_weight is not None
i_q, _scales = lm_faster_rmsquant(input=input_,
rms_weight=rms_weight,
epsilon=self.eps,
quant_dtype=torch.int8,
residual=residual,
update_input=update_hd)
new_residual = residual
input_quant_args = [i_q, _scales]
bias = self.bias if not self.skip_bias_add else None bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None assert self.quant_method is not None
output_parallel = self.quant_method.apply(self, input_, bias, input_quant_args) output_parallel = self.quant_method.apply(self, input_, bias, input_quant_args=iqis)
if self.gather_output: if self.gather_output:
# All-gather across the partitions. # All-gather across the partitions.
...@@ -892,7 +840,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -892,7 +840,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
output_bias = self.bias if self.skip_bias_add else None output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias: if not self.return_bias:
return output return output
return output, new_residual, i_q, _scales, output_bias return output, output_bias
elif envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and xqxs is not None: elif envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and xqxs is not None:
bias = self.bias if not self.skip_bias_add else None bias = self.bias if not self.skip_bias_add else None
...@@ -933,13 +881,12 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -933,13 +881,12 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
skip_bias_add: bool = False, skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
eps: Optional[float] = 1e-6,
prefix: str = "", prefix: str = "",
*, *,
return_bias: bool = True, return_bias: bool = True,
expect_tp_size: Optional[int] = None, expect_tp_size: Optional[int] = None,
enable_dp_attn_moe: bool = False,
): ):
self.eps = eps
self.output_sizes = output_sizes self.output_sizes = output_sizes
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
...@@ -949,6 +896,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -949,6 +896,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
self.expect_tp_size = expect_tp_size self.expect_tp_size = expect_tp_size
self.enable_dp_attn_moe = enable_dp_attn_moe
if enable_dp_attn_moe:
tp_size = get_moe_tp_size()
assert all(output_size % tp_size == 0 for output_size in output_sizes) assert all(output_size % tp_size == 0 for output_size in output_sizes)
super().__init__(input_size=input_size, super().__init__(input_size=input_size,
output_size=sum(output_sizes), output_size=sum(output_sizes),
...@@ -959,7 +910,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -959,7 +910,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
quant_config=quant_config, quant_config=quant_config,
prefix=prefix, prefix=prefix,
return_bias=return_bias, return_bias=return_bias,
expect_tp_size=expect_tp_size) expect_tp_size=expect_tp_size,
enable_dp_attn_moe=enable_dp_attn_moe)
def weight_loader(self, def weight_loader(self,
param: Parameter, param: Parameter,
...@@ -1060,6 +1012,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -1060,6 +1012,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
if self.expect_tp_size is not None and self.expect_tp_size == 1: if self.expect_tp_size is not None and self.expect_tp_size == 1:
tp_rank = 0 tp_rank = 0
tp_size = 1 tp_size = 1
if self.enable_dp_attn_moe:
tp_rank = get_moe_tp_rank()
tp_size = get_moe_tp_size()
if output_dim is not None: if output_dim is not None:
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
...@@ -1182,6 +1138,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -1182,6 +1138,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
if hasattr(param, "expect_tp_size"): if hasattr(param, "expect_tp_size"):
param.expect_tp_size = self.expect_tp_size param.expect_tp_size = self.expect_tp_size
if self.enable_dp_attn_moe and hasattr(param, "enable_dp_attn_moe"):
tp_size = get_moe_tp_size()
param.enable_dp_attn_moe = self.enable_dp_attn_moe
if isinstance(param, BlockQuantScaleParameter): if isinstance(param, BlockQuantScaleParameter):
from vllm.model_executor.layers.quantization.fp8 import ( from vllm.model_executor.layers.quantization.fp8 import (
Fp8LinearMethod, Fp8MoEMethod) Fp8LinearMethod, Fp8MoEMethod)
...@@ -1613,6 +1573,7 @@ class RowParallelLinear(LinearBase): ...@@ -1613,6 +1573,7 @@ class RowParallelLinear(LinearBase):
*, *,
return_bias: bool = True, return_bias: bool = True,
expect_tp_size: Optional[int] = None, expect_tp_size: Optional[int] = None,
enable_dp_attn_moe: bool = False,
): ):
# Divide the weight matrix along the first dimension. # Divide the weight matrix along the first dimension.
self.tp_rank = get_tensor_model_parallel_rank() self.tp_rank = get_tensor_model_parallel_rank()
...@@ -1621,7 +1582,13 @@ class RowParallelLinear(LinearBase): ...@@ -1621,7 +1582,13 @@ class RowParallelLinear(LinearBase):
if expect_tp_size is not None: if expect_tp_size is not None:
self.tp_rank = 0 self.tp_rank = 0
self.tp_size = 1 self.tp_size = 1
self.expect_tp_size = expect_tp_size self.expect_tp_size = expect_tp_size
self.enable_dp_attn_moe = enable_dp_attn_moe
if enable_dp_attn_moe:
self.tp_rank = get_moe_tp_rank()
self.tp_size = get_moe_tp_size()
self.input_size_per_partition = divide(input_size, self.tp_size) self.input_size_per_partition = divide(input_size, self.tp_size)
self.output_size_per_partition = output_size self.output_size_per_partition = output_size
self.output_partition_sizes = [output_size] self.output_partition_sizes = [output_size]
...@@ -1671,6 +1638,11 @@ class RowParallelLinear(LinearBase): ...@@ -1671,6 +1638,11 @@ class RowParallelLinear(LinearBase):
if self.expect_tp_size is not None: if self.expect_tp_size is not None:
tp_rank = 0 tp_rank = 0
tp_size = 1 tp_size = 1
if self.enable_dp_attn_moe:
tp_rank = get_moe_tp_rank()
tp_size = get_moe_tp_size()
input_dim = getattr(param, "input_dim", None) input_dim = getattr(param, "input_dim", None)
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False) use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
is_sharded_weight = getattr(param, "is_sharded_weight", False) is_sharded_weight = getattr(param, "is_sharded_weight", False)
...@@ -1725,6 +1697,9 @@ class RowParallelLinear(LinearBase): ...@@ -1725,6 +1697,9 @@ class RowParallelLinear(LinearBase):
if self.expect_tp_size is not None and hasattr(param, "expect_tp_size"): if self.expect_tp_size is not None and hasattr(param, "expect_tp_size"):
param.expect_tp_size = self.expect_tp_size param.expect_tp_size = self.expect_tp_size
if self.enable_dp_attn_moe is not None and hasattr(param, "enable_dp_attn_moe"):
param.enable_dp_attn_moe = self.enable_dp_attn_moe
param.load_row_parallel_weight(loaded_weight=loaded_weight) param.load_row_parallel_weight(loaded_weight=loaded_weight)
def forward( def forward(
......
...@@ -654,6 +654,12 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -654,6 +654,12 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
expert_load_view: Optional[torch.Tensor] = None, expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
shared_output: Optional[torch.Tensor] = None,
use_nn_moe=False,
use_fused_gate: Optional[bool] = False,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None,
**_,
) -> torch.Tensor: ) -> torch.Tensor:
if enable_eplb: if enable_eplb:
raise NotImplementedError( raise NotImplementedError(
......
...@@ -29,6 +29,7 @@ from vllm.utils import round_up ...@@ -29,6 +29,7 @@ from vllm.utils import round_up
try: try:
from lightop import m_grouped_w8a8_gemm_nt_masked, fuse_silu_mul_quant_ep, fuse_silu_mul_quant from lightop import m_grouped_w8a8_gemm_nt_masked, fuse_silu_mul_quant_ep, fuse_silu_mul_quant
from lmslim.layers.fused_moe.fuse_moe_int8_marlin import fused_experts_impl_int8_marlin from lmslim.layers.fused_moe.fuse_moe_int8_marlin import fused_experts_impl_int8_marlin
from lmslim.layers.fused_moe.fuse_moe_fp8_marlin import fused_experts_impl_fp8_marlin
from lightop import m_grouped_w8a8_gemm_nt_contig_asm from lightop import m_grouped_w8a8_gemm_nt_contig_asm
except Exception: except Exception:
print("INFO: Please install lmslim if you want to infer the quantitative model of moe.\n") print("INFO: Please install lmslim if you want to infer the quantitative model of moe.\n")
...@@ -37,8 +38,27 @@ logger = init_logger(__name__) ...@@ -37,8 +38,27 @@ logger = init_logger(__name__)
__all__ = [ __all__ = [
"CompressedTensorsW8A8Int8MarlinMoEMethod", "CompressedTensorsW8A8Int8MarlinMoEMethod",
"CompressedTensorsW8A8FP8MarlinMoEMethod",
] ]
def fp32_to_fp8_e4m3fn(t: torch.Tensor) -> torch.Tensor:
"""更合理的FP32到Float8_e4m3fn转换,使用最近值而不是简单舍弃尾数"""
# torch.float8_e4m3fn的数值范围约[-448, 448]
fp8_min, fp8_max = -448.0, 448.0
t_clamped = t.clamp(min=fp8_min, max=fp8_max)
# 保证不会下溢到0
# 转换前到float16再转fp8可能提升精度(float8实现本身通常通过float16做rounding)
t_fp16 = t_clamped.to(torch.float16)
return t_fp16.to(torch.float8_e4m3fn)
def w8a8_fp8_nt_kpack2_marlin_weight(w8a8_w, # [size_n, size_k// 2 ]
k_tile=16,
n_tile=16, ):
size_n, size_k = w8a8_w.shape
assert size_n % k_tile == 0 and size_k % n_tile == 0, "k_tile / n_tile 必须能整除对应维度"
w8a8_w = w8a8_w.reshape((size_n // n_tile, n_tile, size_k // k_tile, k_tile))
w8a8_w = w8a8_w.permute((0, 2, 1, 3)).contiguous()
w8a8_w = w8a8_w.reshape((size_n // k_tile, size_k * k_tile))
return w8a8_w
class CompressedTensorsMarlinMoEMethod(FusedMoEMethodBase): class CompressedTensorsMarlinMoEMethod(FusedMoEMethodBase):
@staticmethod @staticmethod
...@@ -46,17 +66,492 @@ class CompressedTensorsMarlinMoEMethod(FusedMoEMethodBase): ...@@ -46,17 +66,492 @@ class CompressedTensorsMarlinMoEMethod(FusedMoEMethodBase):
quant_config: "SlimQuantCompressedTensorsMarlinConfig", # type: ignore # noqa E501 quant_config: "SlimQuantCompressedTensorsMarlinConfig", # type: ignore # noqa E501
layer: torch.nn.Module, layer: torch.nn.Module,
) -> "CompressedTensorsMarlinMoEMethod": ) -> "CompressedTensorsMarlinMoEMethod":
# are supported + check if the layer is being ignored. # are supported + check if the layer is being ignored.
weight_quant = quant_config.target_scheme_map["Linear"].get("weights") weight_quant = quant_config.target_scheme_map["Linear"].get("weights")
input_quant = quant_config.target_scheme_map["Linear"].get( input_quant = quant_config.target_scheme_map["Linear"].get(
"input_activations") "input_activations")
if quant_config._is_dynamic_token_w8a8(weight_quant, input_quant): if quant_config._is_fp8_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8FP8MarlinMoEMethod(quant_config)
elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Int8MarlinMoEMethod(quant_config) return CompressedTensorsW8A8Int8MarlinMoEMethod(quant_config)
else: else:
raise RuntimeError( raise RuntimeError(
f"Slimquant_marlin does not support the FusedMoe scheme: {weight_quant}, {input_quant}") f"Slimquant_marlin does not support the FusedMoe scheme: {weight_quant}, {input_quant}")
class CompressedTensorsW8A8FP8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod):
def __init__(
self,
quant_config: "CompressedTensorsMarlinConfig" # type: ignore # noqa E501
):
self.quant_config = quant_config
self.weight_quant = self.quant_config.target_scheme_map["Linear"].get(
"weights")
self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
"input_activations")
per_channel = (
self.weight_quant.strategy == QuantizationStrategy.CHANNEL
and self.input_quant.strategy == QuantizationStrategy.TOKEN)
if not per_channel:
raise ValueError(
"For FP8 Fused MoE layers, we require channelwise, "
"dynamic per token quantization. Found "
f"{self.weight_quant}, {self.input_quant}")
self.static_input_scales = not self.input_quant.dynamic
if self.static_input_scales:
raise ValueError(
"For FP8 Fused MoE layers, we require channelwise, "
"dynamic per token quantization. Found static input scales.")
self.fused_experts = self.fused_moe_forward
vllm_config = get_current_vllm_config()
parallel_config = vllm_config.parallel_config
self.dp_size = get_dp_group().world_size
self.ep_size = get_ep_group().world_size
self.use_deepep = self.dp_size > 1 and parallel_config.enable_expert_parallel and \
(envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_auto")
self.use_deepgemm = False
if self.use_deepep:
all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None
self.num_dispatchers = all2all_manager.world_size
self.block_shape = [256, 256]
self.use_deepgemm = envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency" or envs.VLLM_ENABLE_DEEPEP_HT_DEEPGEMM or envs.VLLM_ALL2ALL_BACKEND == "deepep_auto"
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
if self.use_deepep:
self.N = 2 * intermediate_size_per_partition
self.K = hidden_size
params_dtype = torch.float8_e4m3fn
# WEIGHTS
w13_weight = torch.nn.Parameter(torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_size,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
# WEIGHT_SCALES
assert self.weight_quant.strategy == QuantizationStrategy.CHANNEL
w13_weight_scale = torch.nn.Parameter(torch.ones(
num_experts,
2 * intermediate_size_per_partition,
1,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
hidden_size,
1,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
# Add PER-CHANNEL quantization for FusedMoE.weight_loader.
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
# INPUT_SCALES
assert not self.static_input_scales
layer.w13_input_scale = None
layer.w2_input_scale = None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
w1_marlin_list = []
for ii in range(layer.w13_weight.shape[0]):
if not self.use_deepgemm:
w1_marlin_in = get_w8a8_int8_marlin_weights(layer.w13_weight[ii])
else:
w1_marlin_in = w8a8_fp8_nt_kpack2_marlin_weight(layer.w13_weight[ii])
w1_marlin_list.append(w1_marlin_in.float() if w1_marlin_in.dtype == torch.float8_e4m3fn else w1_marlin_in)
w1_marlin = torch.stack(w1_marlin_list, dim=0)
w1_marlin = fp32_to_fp8_e4m3fn(w1_marlin)
del w1_marlin_list
w2_marlin_list = []
for ii in range(layer.w2_weight.shape[0]):
if not self.use_deepgemm:
w2_marlin_in = get_w8a8_int8_marlin_weights(layer.w2_weight[ii])
else:
w2_marlin_in = w8a8_fp8_nt_kpack2_marlin_weight(layer.w2_weight[ii])
w2_marlin_list.append(w2_marlin_in.float() if w2_marlin_in.dtype == torch.float8_e4m3fn else w2_marlin_in)
w2_marlin = torch.stack(w2_marlin_list, dim=0)
w2_marlin = fp32_to_fp8_e4m3fn(w2_marlin)
layer.w13_weight = Parameter(w1_marlin, requires_grad=False)
layer.w2_weight = Parameter(w2_marlin, requires_grad=False)
def masked_groupgemm_workspace_shapes(self,
a: torch.Tensor,
aq: torch.Tensor,
M: int,
N: int,
K: int,
topk: int,
global_num_experts: int,
local_num_experts: int, ):
assert a.dim() == 2
# FIXME (varun): We should be able to dispatch only from the leader
# DP ranks in the case of TP > 1. At the moment, all the Ranks
# end up sending their tokens. This needs to be fixed.
num_dispatchers = self.num_dispatchers
num_experts = local_num_experts
max_num_tokens = a.size(
0) if self.max_num_tokens_per_rank is None else self.max_num_tokens_per_rank
workspace13 = (num_experts, max_num_tokens * num_dispatchers,
max(K, N))
workspace2 = (num_experts, max_num_tokens * num_dispatchers, (N // 2))
output = (num_experts, max_num_tokens * num_dispatchers, K)
return (workspace13, workspace2, output, a.dtype)
def contiguous_groupgemm_workspace_shapes(
self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int,
topk: int, global_num_experts: int, local_num_experts: int,
expert_num_tokens_cpu: torch.Tensor
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
assert self.block_shape is not None
# We use global_num_experts due to how moe_align_block_size handles
# expert_maps.
block_m = self.block_shape[0]
M_sum = compute_aligned_M(
M, topk, local_num_experts, block_m, expert_num_tokens_cpu
)
assert M_sum % block_m == 0
workspace1 = (M_sum, max(N, K))
workspace2 = (M_sum, max(N // 2, K))
output = (M, K)
return (workspace1, workspace2, output, a.dtype, M_sum)
def w8a8_groupgemm_masked_forward(self,
x: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
expert_num_tokens: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
shared_output: Optional[torch.Tensor] = None,
q_x: Optional[torch.Tensor] = None,
expert_num_tokens_cpu: Optional[torch.Tensor] = None,
**_):
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
local_num_experts = w1.size(0)
E, max_num_tokens, _, _, top_k = mk._moe_problem_size(
q_x, w1, w2, topk_ids)
N, K = self.N, self.K
(workspace13_shape, workspace2_shape, fused_out_shape,
workspace_dtype) = self.masked_groupgemm_workspace_shapes(
x, q_x, max_num_tokens, N, K, top_k, global_num_experts,
local_num_experts)
workspace13 = torch.empty(prod(workspace13_shape),
device=x.device,
dtype=workspace_dtype)
workspace1 = _resize_cache(workspace13, (E, max_num_tokens, N))
fused_out = _resize_cache(workspace13, fused_out_shape)
# (from deepgemm docs) : A value hint (which is a value on CPU)
# for the M expectation of each batch, correctly setting this value
# may lead to better performance.
# expected_m = max_num_tokens
ori_bs = x.shape[0]
expected_m = ori_bs * self.ep_size
# expected_m = (
# x.shape[0] * self.dp_size * topk_ids.shape[1]
# + global_num_experts
# ) // global_num_experts
m_grouped_w8a8_gemm_nt_masked((q_x, a1_scale),
(w1, w1_scale),
workspace1,
expert_num_tokens,
expected_m,
)
assert expert_num_tokens is not None
a2q, a2q_scale = fuse_silu_mul_quant_ep(workspace1, expert_num_tokens)
m_grouped_w8a8_gemm_nt_masked((a2q, a2q_scale),
(w2, w2_scale),
fused_out,
expert_num_tokens,
expected_m)
return fused_out
def w8a8_groupgemm_contiguous_forward(self,
x: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
expert_num_tokens: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
shared_output: Optional[torch.Tensor] = None,
q_x: Optional[torch.Tensor] = None,
expert_num_tokens_cpu: Optional[torch.Tensor] = None,
**_):
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
local_num_experts = w1.size(0)
a1q = q_x
N, K = self.N, self.K
(workspace13_shape, workspace2_shape, fused_out_shape,
workspace_dtype, M_sum) = self.contiguous_groupgemm_workspace_shapes(
x, q_x, topk_ids.size(0), N, K, topk_ids.size(1), global_num_experts,
local_num_experts, expert_num_tokens_cpu)
workspace13 = torch.empty(prod(workspace13_shape),
device=x.device,
dtype=workspace_dtype)
workspace2 = torch.empty(prod(workspace2_shape),
device=x.device,
dtype=workspace_dtype)
mm1_out = _resize_cache(workspace13, (M_sum, N))
mm2_out = _resize_cache(workspace2, (M_sum, K))
act_out = _resize_cache(workspace2, (M_sum, N // 2))
quant_out = _resize_cache(
workspace13.view(dtype=a1q.dtype), (M_sum, N // 2)
)
fused_out = _resize_cache(workspace13, fused_out_shape)
a1q_perm = _resize_cache(workspace2.view(dtype=a1q.dtype), (M_sum, K))
a1q, a1q_scale, expert_ids, inv_perm = deepgemm_moe_permute(
aq=a1q,
aq_scale=a1_scale,
topk_ids=topk_ids,
local_num_experts=local_num_experts,
expert_map=expert_map,
block_shape=self.block_shape,
expert_num_tokens=expert_num_tokens,
expert_num_tokens_cpu=expert_num_tokens_cpu,
aq_out=a1q_perm,
M_sum=M_sum
)
m_grouped_w8a8_gemm_nt_contig_asm(
(a1q, a1q_scale), (w1, w1_scale), mm1_out, expert_ids)
a2q, a2q_scale = fuse_silu_mul_quant(mm1_out)
# a2q, a2q_scale = fuse_silu_mul_quant(input=mm1_out, expert_ids=expert_ids)
m_grouped_w8a8_gemm_nt_contig_asm(
(a2q, a2q_scale), (w2, w2_scale), mm2_out, expert_ids)
if apply_router_weight_on_input:
topk_weights = torch.ones_like(topk_weights)
deepgemm_unpermute_and_reduce(
a=mm2_out,
topk_ids=topk_ids,
topk_weights=topk_weights,
inv_perm=inv_perm,
expert_map=expert_map,
output=fused_out,
)
return fused_out
def fused_moe_forward(self,
x: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
expert_num_tokens: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
shared_output: Optional[torch.Tensor] = None,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None,
q_x: Optional[torch.Tensor] = None,
expert_num_tokens_cpu: Optional[torch.Tensor] = None,
**_):
return fused_experts_impl_fp8_marlin(
hidden_states=x if q_x is None else q_x,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=True,
per_channel_quant=True,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
use_nn_moe=False,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
i_q=i_q,
i_s=i_s)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
shared_output: Optional[torch.Tensor] = None,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None,
**_
) -> torch.Tensor:
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for "
"`CompressedTensorsW8A8FP8MoEMethod` yet.")
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
use_fused_gate=use_fused_gate,
e_score_correction_bias=e_score_correction_bias,
indices_type=torch.int64 if self.use_deepep else None, )
return self.fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=(layer.w13_weight_scale),
w2_scale=(layer.w2_weight_scale),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
use_nn_moe=False,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
i_q=i_q,
i_s=i_s
)
def select_gemm_impl(
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
) -> FusedMoEPermuteExpertsUnpermute:
from vllm.model_executor.layers.fused_moe import (
TritonOrGroupGemmExperts)
if (prepare_finalize.activation_format ==
FusedMoEActivationFormat.BatchedExperts):
max_num_tokens_per_rank = (
prepare_finalize.max_num_tokens_per_rank())
assert max_num_tokens_per_rank is not None
self.max_num_tokens_per_rank = max_num_tokens_per_rank
logger.debug(
"TritonOrGroupGemmExperts(%s): "
"max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
self.__class__.__name__, max_num_tokens_per_rank,
None, True)
return TritonOrGroupGemmExperts(
use_fp8_w8a8=True,
per_act_token_quant=True,
fused_experts=self.w8a8_groupgemm_masked_forward
)
else:
logger.debug(
"TritonOrGroupGemmExperts(%s): block_size=%s, per_act_token=%s",
self.__class__.__name__, None,
False)
return TritonOrGroupGemmExperts(
use_fp8_w8a8=True if envs.VLLM_ENABLE_DEEPEP_HT_DEEPGEMM else False,
per_act_token_quant=True if envs.VLLM_ENABLE_DEEPEP_HT_DEEPGEMM else False,
fused_experts=self.w8a8_groupgemm_contiguous_forward if envs.VLLM_ENABLE_DEEPEP_HT_DEEPGEMM else self.fused_moe_forward
)
class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod): class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod):
def __init__( def __init__(
self, self,
...@@ -90,7 +585,8 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod) ...@@ -90,7 +585,8 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
self.ep_size = get_ep_group().world_size self.ep_size = get_ep_group().world_size
self.use_deepep = self.dp_size > 1 and parallel_config.enable_expert_parallel and \ self.use_deepep = self.dp_size > 1 and parallel_config.enable_expert_parallel and \
(envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \ (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency") envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_auto")
self.use_deepgemm = False self.use_deepgemm = False
if self.use_deepep: if self.use_deepep:
...@@ -98,7 +594,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod) ...@@ -98,7 +594,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
assert all2all_manager is not None assert all2all_manager is not None
self.num_dispatchers = all2all_manager.world_size self.num_dispatchers = all2all_manager.world_size
self.block_shape = [256, 256] self.block_shape = [256, 256]
self.use_deepgemm = envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency" or envs.VLLM_ENABLE_DEEPEP_HT_DEEPGEMM self.use_deepgemm = envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency" or envs.VLLM_ENABLE_DEEPEP_HT_DEEPGEMM or envs.VLLM_ALL2ALL_BACKEND == "deepep_auto"
def create_weights(self, layer: torch.nn.Module, num_experts: int, def create_weights(self, layer: torch.nn.Module, num_experts: int,
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, Optional from typing import Callable, Optional
from vllm import envs
import torch import torch
from compressed_tensors.quantization import QuantizationStrategy from compressed_tensors.quantization import QuantizationStrategy
from torch.nn import Parameter from torch.nn import Parameter
...@@ -61,6 +61,8 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): ...@@ -61,6 +61,8 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
# If channelwise, scales are already lined up, so just transpose. # If channelwise, scales are already lined up, so just transpose.
elif self.strategy == QuantizationStrategy.CHANNEL: elif self.strategy == QuantizationStrategy.CHANNEL:
weight = layer.weight weight = layer.weight
if envs.VLLM_W8A8_BACKEND == 3:
weight = weight.t()
if current_platform.is_fp8_fnuz(): if current_platform.is_fp8_fnuz():
input_scale = getattr(layer, 'input_scale', None) input_scale = getattr(layer, 'input_scale', None)
...@@ -140,7 +142,9 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): ...@@ -140,7 +142,9 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
def apply_weights(self, def apply_weights(self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor: bias: Optional[torch.Tensor] = None,input_quant_args: Optional[list[torch.Tensor]] = None,
silu_quant_args: Optional[list[torch.Tensor]] = None, **_,
) -> torch.Tensor:
return self.fp8_linear.apply(input=x, return self.fp8_linear.apply(input=x,
weight=layer.weight, weight=layer.weight,
......
...@@ -20,6 +20,7 @@ from vllm.model_executor.parameter import (BasevLLMParameter, ...@@ -20,6 +20,7 @@ from vllm.model_executor.parameter import (BasevLLMParameter,
PerTensorScaleParameter) PerTensorScaleParameter)
from vllm.utils import W8a8GetCacheJSON from vllm.utils import W8a8GetCacheJSON
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
import vllm.envs as envs
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -31,8 +32,8 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme): ...@@ -31,8 +32,8 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
input_symmetric: bool): input_symmetric: bool):
self.strategy = strategy self.strategy = strategy
self.is_static_input_scheme = is_static_input_scheme self.is_static_input_scheme = is_static_input_scheme
self.tritonsingleton= W8a8GetCacheJSON() self.tritonsingleton = W8a8GetCacheJSON()
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1')) self.w8a8_strategy = envs.VLLM_W8A8_BACKEND
self.input_symmetric = input_symmetric self.input_symmetric = input_symmetric
@classmethod @classmethod
......
...@@ -331,8 +331,12 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -331,8 +331,12 @@ class Fp8LinearMethod(LinearMethodBase):
weight = layer.weight.data weight = layer.weight.data
weight_scale_inv = layer.weight_scale_inv.data weight_scale_inv = layer.weight_scale_inv.data
weight = self._maybe_pad_weight(weight) if envs.VLLM_W8A8_BACKEND == 3:
weight = weight.T.contiguous()
weight_scale_inv = weight_scale_inv.T.contiguous()
else:
weight = self._maybe_pad_weight(weight)
# Torch.compile cannot use Parameter subclasses. # Torch.compile cannot use Parameter subclasses.
layer.weight = Parameter(weight, requires_grad=False) layer.weight = Parameter(weight, requires_grad=False)
layer.weight_scale_inv = Parameter(weight_scale_inv, layer.weight_scale_inv = Parameter(weight_scale_inv,
...@@ -854,10 +858,11 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -854,10 +858,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
use_fused_gate: Optional[bool] = False,
enable_eplb: bool = False, enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None, expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,**_,
) -> torch.Tensor: ) -> torch.Tensor:
if enable_eplb: if enable_eplb:
assert expert_load_view is not None assert expert_load_view is not None
...@@ -882,6 +887,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -882,6 +887,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
expert_load_view=expert_load_view, expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map, logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count, logical_replica_count=logical_replica_count,
use_fused_gate=use_fused_gate,
) )
if self.rocm_aiter_moe_enabled: if self.rocm_aiter_moe_enabled:
......
...@@ -92,8 +92,8 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase): ...@@ -92,8 +92,8 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
def __init__(self, quantization_config: SlimQuantW4A8Int8Config): def __init__(self, quantization_config: SlimQuantW4A8Int8Config):
self.quantization_config = quantization_config self.quantization_config = quantization_config
self.tritonsingleton= W8a8GetCacheJSON() self.tritonsingleton = W8a8GetCacheJSON()
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1')) self.w8a8_strategy = envs.VLLM_W8A8_BACKEND
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
n=layer.weight.shape[0] n=layer.weight.shape[0]
......
...@@ -6,6 +6,8 @@ import functools ...@@ -6,6 +6,8 @@ import functools
import json import json
import os import os
from typing import Any, Callable, Optional, Union, List from typing import Any, Callable, Optional, Union, List
from lmslim import quant_ops
from lmslim.quantize.quant_ops import BlockSize
import torch import torch
...@@ -19,6 +21,10 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( ...@@ -19,6 +21,10 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils import cdiv, direct_register_custom_op, has_deep_gemm from vllm.utils import cdiv, direct_register_custom_op, has_deep_gemm
try:
from lmslim.layers.gemm.fp8_utils import per_token_group_quant_fp8,w8a8_block_fp8_matmul
except Exception:
print("INFO: Please updata lmslim if you want to use fp8_utils.\n")
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -83,7 +89,7 @@ if current_platform.is_rocm(): ...@@ -83,7 +89,7 @@ if current_platform.is_rocm():
def dispatch_w8a8_blockscale_func( def dispatch_w8a8_blockscale_func(
use_cutlass: bool, use_aiter_and_is_supported: bool use_cutlass: bool, use_aiter_and_is_supported: bool, use_blaslt: bool
) -> Callable[[ ) -> Callable[[
torch.Tensor, torch.Tensor,
torch.Tensor, torch.Tensor,
...@@ -96,6 +102,9 @@ def dispatch_w8a8_blockscale_func( ...@@ -96,6 +102,9 @@ def dispatch_w8a8_blockscale_func(
return cutlass_scaled_mm return cutlass_scaled_mm
if (use_aiter_and_is_supported): if (use_aiter_and_is_supported):
return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale
if use_blaslt:
return hipblaslt_w8a8_block_fp8_matmul
return w8a8_block_fp8_matmul return w8a8_block_fp8_matmul
...@@ -127,7 +136,11 @@ def apply_w8a8_block_fp8_linear( ...@@ -127,7 +136,11 @@ def apply_w8a8_block_fp8_linear(
assert input_scale is None assert input_scale is None
# View input as 2D matrix for fp8 methods # View input as 2D matrix for fp8 methods
input_2d = input.view(-1, input.shape[-1]) input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[0]] output_shape = []
if envs.VLLM_W8A8_BACKEND == 3:
output_shape = [*input.shape[:-1], weight.shape[-1]]
else:
output_shape = [*input.shape[:-1], weight.shape[0]]
output_dtype = input.dtype output_dtype = input.dtype
if should_use_deepgemm(output_dtype, weight): if should_use_deepgemm(output_dtype, weight):
...@@ -166,9 +179,12 @@ def apply_w8a8_block_fp8_linear( ...@@ -166,9 +179,12 @@ def apply_w8a8_block_fp8_linear(
weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0) weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0)
else: else:
use_cutlass = False use_cutlass = False
use_blaslt = False
if envs.VLLM_W8A8_BACKEND == 3:
use_blaslt = True
w8a8_blockscale_func = dispatch_w8a8_blockscale_func( w8a8_blockscale_func = dispatch_w8a8_blockscale_func(
use_cutlass, use_aiter_and_is_supported) use_cutlass, use_aiter_and_is_supported, use_blaslt)
if use_cutlass: if use_cutlass:
q_input, x_scale = per_token_group_quant_fp8( q_input, x_scale = per_token_group_quant_fp8(
input_2d, block_size[1], column_major_scales=use_cutlass) input_2d, block_size[1], column_major_scales=use_cutlass)
...@@ -197,7 +213,11 @@ def apply_w8a8_block_fp8_linear_fake( ...@@ -197,7 +213,11 @@ def apply_w8a8_block_fp8_linear_fake(
cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED, cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
use_aiter_and_is_supported: bool = False, use_aiter_and_is_supported: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
output_shape = [*input.shape[:-1], weight.shape[0]] output_shape = []
if envs.VLLM_W8A8_BACKEND == 3:
output_shape = [*input.shape[:-1], weight.shape[-1]]
else:
output_shape = [*input.shape[:-1], weight.shape[0]]
return torch.empty(output_shape, dtype=input.dtype, device=input.device) return torch.empty(output_shape, dtype=input.dtype, device=input.device)
...@@ -240,333 +260,9 @@ def block_quant_to_tensor_quant( ...@@ -240,333 +260,9 @@ def block_quant_to_tensor_quant(
return x_q_tensor, scale return x_q_tensor, scale
@triton.jit
def _per_token_group_quant_fp8(
# Pointers to inputs and output
y_ptr,
y_q_ptr,
y_s_ptr,
group_size,
# Num columns of y
y_num_columns,
y_row_stride,
# Avoid to divide zero
eps,
# Information for float8
fp8_min,
fp8_max,
# Meta-parameters
BLOCK: tl.constexpr,
):
"""A Triton-accelerated function to perform per-token-group
quantization on a tensor.
This function converts the tensor values into float8 values.
"""
groups_per_row = y_num_columns // group_size
# Map the program id to the row of X and Y it should compute.
g_id = tl.program_id(0)
row = g_id // groups_per_row
row_g_id = g_id % groups_per_row
# Ensure offset calculations use int64 to prevent overflow
y_ptr_offset = (row.to(tl.int64) * y_row_stride) + (row_g_id.to(tl.int64) *
group_size)
y_ptr += y_ptr_offset
y_q_ptr_offset = g_id.to(tl.int64) * group_size
y_q_ptr += y_q_ptr_offset
y_s_ptr += g_id
cols = tl.arange(0, BLOCK) # N <= BLOCK
mask = cols < group_size
y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
# Quant
_absmax = tl.maximum(tl.max(tl.abs(y)), eps)
y_s = _absmax / fp8_max
y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
tl.store(y_q_ptr + cols, y_q, mask=mask)
tl.store(y_s_ptr, y_s)
@triton.jit
def _per_token_group_quant_fp8_colmajor(
# Pointers to inputs and output
y_ptr,
y_q_ptr,
y_s_ptr,
group_size,
# Num columns of y
y_num_columns,
y_row_stride,
# Stride from one column to the next of y_s
y_s_col_stride,
# Avoid to divide zero
eps,
# Information for float8
fp8_min,
fp8_max,
# Meta-parameters
BLOCK: tl.constexpr,
):
"""A Triton-accelerated function to perform per-token-group
quantization on a tensor.
This function converts the tensor values into float8 values.
"""
groups_per_row = y_num_columns // group_size
# Map the program id to the row of X and Y it should compute.
g_id = tl.program_id(0)
row = g_id // groups_per_row
row_g_id = g_id % groups_per_row
# Ensure offset calculations use int64 to prevent overflow
y_ptr_offset = (row.to(tl.int64) * y_row_stride) + (row_g_id.to(tl.int64) *
group_size)
y_ptr += y_ptr_offset
y_q_ptr_offset = g_id.to(tl.int64) * group_size
y_q_ptr += y_q_ptr_offset
# Convert g_id the flattened block coordinate to 2D so we can index
# into the output y_scales matrix
blocks_per_row = y_num_columns // group_size
scale_col = g_id % blocks_per_row
scale_row = g_id // blocks_per_row
# Ensure offset calculation uses int64 for y_s_ptr
y_s_ptr_offset = (scale_col.to(tl.int64) * y_s_col_stride) + scale_row.to(
tl.int64)
y_s_ptr += y_s_ptr_offset
cols = tl.arange(0, BLOCK) # group_size <= BLOCK
mask = cols < group_size
y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
# Quant
_absmax = tl.maximum(tl.max(tl.abs(y)), eps)
y_s = _absmax / fp8_max
y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
tl.store(y_q_ptr + cols, y_q, mask=mask)
tl.store(y_s_ptr, y_s)
def per_token_group_quant_fp8(
x: torch.Tensor,
group_size: int,
eps: float = 1e-10,
dtype: Optional[torch.dtype] = None,
column_major_scales: bool = False,
out_q: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Function to perform per-token-group quantization on an input tensor `x`.
It converts the tensor values into signed float8 values and returns the
quantized tensor along with the scaling factor used for quantization.
Args:
x: The input tensor with ndim >= 2.
group_size: The group size used for quantization.
eps: The minimum to avoid dividing zero.
dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn`
is supported for now.
column_major_scales: Outputs scales in column major.
out_q: Optional output tensor. If not provided, function will create.
Returns:
tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
scaling factor for quantization.
"""
dtype = current_platform.fp8_dtype() if dtype is None else dtype
assert (x.shape[-1] % group_size == 0), (
f"the last dimension of `x` {x.shape[-1]} must be divisible "
f"by `group_size` {group_size}")
assert x.stride(-1) == 1, "`x` groups must be contiguous"
finfo = torch.finfo(dtype)
fp8_min = finfo.min
fp8_max = finfo.max
assert out_q is None or out_q.shape == x.shape
x_q = out_q
if x_q is None:
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
M = x.numel() // group_size
N = group_size
if column_major_scales:
shape = (x.shape[-1] // group_size, ) + x.shape[:-1]
x_s = torch.empty(shape, device=x.device,
dtype=torch.float32).permute(-1, -2)
else:
shape = x.shape[:-1] + (x.shape[-1] // group_size, )
x_s = torch.empty(shape, device=x.device, dtype=torch.float32)
BLOCK = triton.next_power_of_2(N)
# heuristics for number of warps
num_warps = min(max(BLOCK // 256, 1), 8)
num_stages = 1
if column_major_scales:
_per_token_group_quant_fp8_colmajor[(M, )](
x,
x_q,
x_s,
group_size,
x.shape[1],
x.stride(0),
x_s.stride(1),
eps,
fp8_min=fp8_min,
fp8_max=fp8_max,
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=num_stages,
)
else:
_per_token_group_quant_fp8[(M, )](
x,
x_q,
x_s,
group_size,
x.shape[1],
x.stride(0),
eps,
fp8_min=fp8_min,
fp8_max=fp8_max,
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=num_stages,
)
return x_q, x_s def hipblaslt_w8a8_block_fp8_matmul(
@triton.jit
def _w8a8_block_fp8_matmul(
# Pointers to inputs and output
A,
B,
C,
As,
Bs,
# Shape for matmul
M,
N,
K,
# Block size for block-wise quantization
group_n,
group_k,
# Stride for inputs and output
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_As_m,
stride_As_k,
stride_Bs_k,
stride_Bs_n,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
):
"""Triton-accelerated function used to perform linear operations (dot
product) on input tensors `A` and `B` with block-wise quantization, and
store the result in output tensor `C`.
"""
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
As_ptrs = As + offs_am * stride_As_m
offs_bsn = offs_bn // group_n
Bs_ptrs = Bs + offs_bsn * stride_Bs_n
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
a = tl.load(a_ptrs,
mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
other=0.0)
b = tl.load(b_ptrs,
mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
other=0.0)
k_start = k * BLOCK_SIZE_K
offs_ks = k_start // group_k
a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
if C.dtype.element_ty == tl.bfloat16:
c = accumulator.to(tl.bfloat16)
elif C.dtype.element_ty == tl.float16:
c = accumulator.to(tl.float16)
else:
c = accumulator.to(tl.float32)
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)
@functools.lru_cache
def get_w8a8_block_fp8_configs(N: int, K: int, block_n: int,
block_k: int) -> Optional[dict[int, Any]]:
"""
Return optimized configurations for the w8a8 block fp8 kernel.
The return value will be a dictionary that maps an irregular grid of
batch sizes to configurations of the w8a8 block fp8 kernel. To evaluate the
kernel on a given batch size bs, the closest batch size in the grid should
be picked and the associated configuration chosen to invoke the kernel.
"""
# First look up if an optimized configuration is available in the configs
# directory
device_name = current_platform.get_device_name().replace(" ", "_")
json_file_name = f"N={N},K={K},device_name={device_name},dtype=fp8_w8a8,block_shape=[{block_n},{block_k}].json" # noqa: E501
config_file_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
if os.path.exists(config_file_path):
with open(config_file_path) as f:
logger.info(
"Using configuration from %s for W8A8 Block FP8 kernel.",
config_file_path,
)
# If a configuration has been found, return it
return {int(key): val for key, val in json.load(f).items()}
# If no optimized configuration is available, we will use the default
# configuration
logger.warning(
"Using default W8A8 Block FP8 kernel config. Performance might "
"be sub-optimal! Config file not found at %s",
config_file_path,
)
return None
def w8a8_block_fp8_matmul(
A: torch.Tensor, A: torch.Tensor,
B: torch.Tensor, B: torch.Tensor,
As: torch.Tensor, As: torch.Tensor,
...@@ -574,80 +270,19 @@ def w8a8_block_fp8_matmul( ...@@ -574,80 +270,19 @@ def w8a8_block_fp8_matmul(
block_size: list[int], block_size: list[int],
output_dtype: torch.dtype = torch.float16, output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor: ) -> torch.Tensor:
"""This function performs matrix multiplication with block-wise m, k = A.shape
quantization. _, n = B.shape
It takes two input tensors `A` and `B` with scales `As` and `Bs`. enum_block_size = BlockSize.block_128x128
The output is returned in the specified `output_dtype`. if block_size[0] == 64:
Args: enum_block_size = BlockSize.block_64x64
A: The input tensor, e.g., activation. elif block_size[0] == 128:
B: The input tensor, e.g., weight. enum_block_size = BlockSize.block_128x128
As: The per-token-group quantization scale for `A`.
Bs: The per-block quantization scale for `B`.
block_size: The block size for per-block quantization. It should
be 2-dim, e.g., [128, 128].
output_dytpe: The dtype of the returned tensor.
Returns:
torch.Tensor: The result of matmul.
"""
assert len(block_size) == 2
block_n, block_k = block_size[0], block_size[1]
assert A.shape[-1] == B.shape[-1]
assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
M = A.numel() // A.shape[-1]
assert B.ndim == 2 and Bs.ndim == 2
N, K = B.shape
assert triton.cdiv(N, block_n) == Bs.shape[0]
assert triton.cdiv(K, block_k) == Bs.shape[1]
C_shape = A.shape[:-1] + (N, )
C = A.new_empty(C_shape, dtype=output_dtype)
configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1])
if configs:
# Get the optimal config if there is one
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
else: else:
# Default config print(f"[WARN] Unsupported block_size: {block_size}. Falling back to BlockSize.block_128x128")
# Block-wise quant: BLOCK_SIZE_N must be divisible by block_size[0]
# BLOCK_SIZE_K must be divisible by block_size[1] _, d = quant_ops.hipblaslt_w8a8_blockwise_gemm(A, B, As, Bs,
config = { m, n, k, 'NN', output_dtype,
"BLOCK_SIZE_M": 64, enum_block_size, None)
"BLOCK_SIZE_N": block_size[0], return d
"BLOCK_SIZE_K": block_size[1],
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2,
}
def grid(META):
return (triton.cdiv(M, META["BLOCK_SIZE_M"]) *
triton.cdiv(N, META["BLOCK_SIZE_N"]), )
_w8a8_block_fp8_matmul[grid](
A,
B,
C,
As,
Bs,
M,
N,
K,
block_n,
block_k,
A.stride(-2),
A.stride(-1),
B.stride(1),
B.stride(0),
C.stride(-2),
C.stride(-1),
As.stride(-2),
As.stride(-1),
Bs.stride(1),
Bs.stride(0),
**config,
)
return C
...@@ -11,7 +11,12 @@ from vllm.config import CompilationLevel, get_current_vllm_config ...@@ -11,7 +11,12 @@ from vllm.config import CompilationLevel, get_current_vllm_config
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import W8a8GetCacheJSON from vllm.utils import W8a8GetCacheJSON
from lmslim.layers.gemm.int8_utils import per_token_quant_int8 from lmslim.layers.gemm.int8_utils import per_token_quant_int8
from lmslim.quantize import quant_ops
try:
from lmslim.layers.gemm.fp8_utils import triton_scaled_mm_fp8
from lmslim.quantize.quant_ops import hipblaslt_w8a8_channelwise_gemm
except Exception:
print("INFO: Please updata lmslim if you want to use fp8_utils.\n")
# Input scaling factors are no longer optional in _scaled_mm starting # Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
TORCH_DEVICE_IDENTITY = None TORCH_DEVICE_IDENTITY = None
...@@ -252,6 +257,39 @@ def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor, ...@@ -252,6 +257,39 @@ def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor,
output = output.view(*output_shape) output = output.view(*output_shape)
return output return output
def hipblaslt_w8a8_channelwise_scaled_mm(
qinput: torch.Tensor,
input_2d: torch.Tensor,
weight: torch.Tensor,
out_dtype: torch.dtype,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
bias: torch.Tensor,
output_shape: list,
**kwargs
) -> torch.Tensor:
assert qinput.is_contiguous() and weight.is_contiguous()
assert qinput.shape[-1] == weight.shape[-1]
assert qinput.dtype == weight.dtype
m = qinput.shape[0]
k = qinput.shape[1]
n = weight.shape[0]
success, output = quant_ops.hipblaslt_w8a8_channelwise_gemm(
a = qinput,
b = weight,
scale_a = scale_a,
scale_b = scale_b,
m = m,
n = n,
k = k,
transpose_flag = "NT",
out_dtype = out_dtype,
bias = bias,
)
return output.view(m, n)
def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor, def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor,
...@@ -278,25 +316,27 @@ def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor, ...@@ -278,25 +316,27 @@ def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor,
# GEMM # GEMM
# This computes C = (X * W). # This computes C = (X * W).
# Output in fp32 to allow subsequent ops to happen in-place # Output in fp32 to allow subsequent ops to happen in-place
output = torch._scaled_mm(qinput, qinput = qinput.view(-1,qinput.shape[-1])
output = triton_scaled_mm_fp8(qinput,
weight, weight,
scale_a=TORCH_DEVICE_IDENTITY, scale_a=scale_a,
scale_b=TORCH_DEVICE_IDENTITY, scale_b=scale_b,
out_dtype=torch.float32) out_dtype=out_dtype,
bias=bias)
# A fix for discrepancy in scaled_mm which returns tuple # A fix for discrepancy in scaled_mm which returns tuple
# for torch < 2.5 and a single value in torch >= 2.5 # for torch < 2.5 and a single value in torch >= 2.5
if type(output) is tuple and len(output) == 2: # if type(output) is tuple and len(output) == 2:
output = output[0] # output = output[0]
# Unpad (undo num_token_padding) # # Unpad (undo num_token_padding)
output = torch.narrow(output, 0, 0, input_2d.shape[0]) # output = torch.narrow(output, 0, 0, input_2d.shape[0])
x_scale = torch.narrow(scale_a, 0, 0, input_2d.shape[0]) # x_scale = torch.narrow(scale_a, 0, 0, input_2d.shape[0])
#
# DQ # # DQ
# C = sw * sx * (X * W) + bias # # C = sw * sx * (X * W) + bias
output = output * x_scale * scale_b.t() # output = output * x_scale * scale_b.t()
if bias is not None: # if bias is not None:
output = output + bias # output = output + bias
return output.to(out_dtype).view(*output_shape) return output.view(*output_shape)
def dispatch_w8a8_scaled_mm( def dispatch_w8a8_scaled_mm(
...@@ -310,6 +350,8 @@ def dispatch_w8a8_scaled_mm( ...@@ -310,6 +350,8 @@ def dispatch_w8a8_scaled_mm(
if current_platform.is_rocm(): if current_platform.is_rocm():
return rocm_per_tensor_w8a8_scaled_mm return rocm_per_tensor_w8a8_scaled_mm
return torch_per_tensor_w8a8_scaled_mm return torch_per_tensor_w8a8_scaled_mm
if envs.VLLM_W8A8_BACKEND == 3:
return hipblaslt_w8a8_channelwise_scaled_mm
# torch.scaled_mm supports per tensor weights + activations only # torch.scaled_mm supports per tensor weights + activations only
# so fallback to naive if per channel or per token # so fallback to naive if per channel or per token
if (use_per_token_if_dynamic and not per_tensor_weights if (use_per_token_if_dynamic and not per_tensor_weights
......
...@@ -232,6 +232,11 @@ def get_model_architecture( ...@@ -232,6 +232,11 @@ def get_model_architecture(
'ChatGLMModel', 'Glm4ForCausalLM', 'ChatGLMForConditionalGeneration', 'BaichuanForCausalLM', 'BloomForCausalLM', 'TeleChat2ForCausalLM', 'MixtralForCausalLM', 'FalconForCausalLM', 'ChatGLMModel', 'Glm4ForCausalLM', 'ChatGLMForConditionalGeneration', 'BaichuanForCausalLM', 'BloomForCausalLM', 'TeleChat2ForCausalLM', 'MixtralForCausalLM', 'FalconForCausalLM',
'MedusaModel', 'MLPSpeculatorPreTrainedModel', 'DeepseekV2ForCausalLM', 'DeepseekV3ForCausalLM', 'DeepSeekMTPModel'] 'MedusaModel', 'MLPSpeculatorPreTrainedModel', 'DeepseekV2ForCausalLM', 'DeepseekV3ForCausalLM', 'DeepSeekMTPModel']
if any(arch in architectures for arch in support_nn_architectures): if any(arch in architectures for arch in support_nn_architectures):
#针对使用dtype为fp16的情况的量化默认关闭"VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD"
if model_config.quantization in {"awq", "awq_marlin", "moe_wna16"}:
if not envs.is_set("VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD"):
os.environ['VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD'] = '0'
if not envs.VLLM_USE_NN: if not envs.VLLM_USE_NN:
if os.getenv('LLAMA_NN') != '0': if os.getenv('LLAMA_NN') != '0':
if (architectures == ['QWenLMHeadModel'] or architectures == ['ChatGLMModel'] ) and visions != []: if (architectures == ['QWenLMHeadModel'] or architectures == ['ChatGLMModel'] ) and visions != []:
...@@ -247,7 +252,8 @@ def get_model_architecture( ...@@ -247,7 +252,8 @@ def get_model_architecture(
if architectures in [['DeepseekV3ForCausalLM'], ['DeepSeekMTPModel']]: if architectures in [['DeepseekV3ForCausalLM'], ['DeepSeekMTPModel']]:
if not envs.is_set("VLLM_USE_LIGHTOP"): if not envs.is_set("VLLM_USE_LIGHTOP"):
os.environ['VLLM_USE_LIGHTOP'] = '1' os.environ['VLLM_USE_LIGHTOP'] = '1'
if not envs.is_set("VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD"): if not envs.is_set("VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD") \
and not envs.VLLM_ENABLE_SHARED_EXPERTS_FUSION:
os.environ['VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD'] = '1' os.environ['VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD'] = '1'
if not envs.is_set("VLLM_USE_OPT_CAT"): if not envs.is_set("VLLM_USE_OPT_CAT"):
os.environ['VLLM_USE_OPT_CAT'] = '1' os.environ['VLLM_USE_OPT_CAT'] = '1'
...@@ -255,8 +261,12 @@ def get_model_architecture( ...@@ -255,8 +261,12 @@ def get_model_architecture(
os.environ['VLLM_USE_CAT_MLA'] = '1' os.environ['VLLM_USE_CAT_MLA'] = '1'
if not envs.is_set("VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"): if not envs.is_set("VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"):
os.environ['VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'] = '1' os.environ['VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'] = '1'
# if not envs.is_set("VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA"):
# os.environ['VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA'] = '1'
if not envs.is_set("VLLM_SCHED_ENABLE_MINIMAL_INJECTION"): if not envs.is_set("VLLM_SCHED_ENABLE_MINIMAL_INJECTION"):
os.environ['VLLM_SCHED_ENABLE_MINIMAL_INJECTION'] = '1' os.environ['VLLM_SCHED_ENABLE_MINIMAL_INJECTION'] = '1'
if not envs.is_set("VLLM_USE_FUSED_FILL_RMS_CAT"):
os.environ['VLLM_USE_FUSED_FILL_RMS_CAT'] = '1'
if model_config.quantization in {"slimquant_w4a8", "slimquant_w4a8_marlin", "slimquant_compressed_tensors_marlin", "compressed-tensors"}: if model_config.quantization in {"slimquant_w4a8", "slimquant_w4a8_marlin", "slimquant_compressed_tensors_marlin", "compressed-tensors"}:
if not envs.is_set("USE_FUSED_RMS_QUANT"): if not envs.is_set("USE_FUSED_RMS_QUANT"):
os.environ['USE_FUSED_RMS_QUANT'] = '1' os.environ['USE_FUSED_RMS_QUANT'] = '1'
...@@ -278,6 +288,8 @@ def get_model_architecture( ...@@ -278,6 +288,8 @@ def get_model_architecture(
os.environ['VLLM_USE_FUSE_SILU_AND_MUL'] = '1' os.environ['VLLM_USE_FUSE_SILU_AND_MUL'] = '1'
if not envs.is_set("VLLM_USE_OPT_RESHAPE_AND_CACHE"): if not envs.is_set("VLLM_USE_OPT_RESHAPE_AND_CACHE"):
os.environ['VLLM_USE_OPT_RESHAPE_AND_CACHE'] = '1' os.environ['VLLM_USE_OPT_RESHAPE_AND_CACHE'] = '1'
if not envs.is_set("VLLM_USE_FUSED_RMS_ROPE"):
os.environ['VLLM_USE_FUSED_RMS_ROPE'] = '1'
if architectures in [['DeepseekV32ForCausalLM']]: if architectures in [['DeepseekV32ForCausalLM']]:
if not envs.is_set("VLLM_USE_V32_ENCODE"): if not envs.is_set("VLLM_USE_V32_ENCODE"):
...@@ -287,10 +299,11 @@ def get_model_architecture( ...@@ -287,10 +299,11 @@ def get_model_architecture(
if os.getenv('FA_PAD') != '1': if os.getenv('FA_PAD') != '1':
os.environ['FA_PAD'] = '0' os.environ['FA_PAD'] = '0'
else: else:
if architectures in [['DeepseekV3ForCausalLM'], ['DeepSeekMTPModel']]: if architectures in [['DeepseekV3ForCausalLM'], ['DeepSeekMTPModel']]:
if not envs.is_set("VLLM_USE_LIGHTOP"): if not envs.is_set("VLLM_USE_LIGHTOP"):
os.environ['VLLM_USE_LIGHTOP'] = '1' os.environ['VLLM_USE_LIGHTOP'] = '1'
if not envs.is_set("VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD"): if not envs.is_set("VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD") \
and not envs.VLLM_ENABLE_SHARED_EXPERTS_FUSION:
os.environ['VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD'] = '1' os.environ['VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD'] = '1'
if not envs.is_set("VLLM_USE_OPT_CAT"): if not envs.is_set("VLLM_USE_OPT_CAT"):
os.environ['VLLM_USE_OPT_CAT'] = '1' os.environ['VLLM_USE_OPT_CAT'] = '1'
...@@ -298,8 +311,12 @@ def get_model_architecture( ...@@ -298,8 +311,12 @@ def get_model_architecture(
os.environ['VLLM_USE_CAT_MLA'] = '1' os.environ['VLLM_USE_CAT_MLA'] = '1'
if not envs.is_set("VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"): if not envs.is_set("VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"):
os.environ['VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'] = '1' os.environ['VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'] = '1'
# if not envs.is_set("VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA"):
# os.environ['VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA'] = '1'
if not envs.is_set("VLLM_SCHED_ENABLE_MINIMAL_INJECTION"): if not envs.is_set("VLLM_SCHED_ENABLE_MINIMAL_INJECTION"):
os.environ['VLLM_SCHED_ENABLE_MINIMAL_INJECTION'] = '1' os.environ['VLLM_SCHED_ENABLE_MINIMAL_INJECTION'] = '1'
if not envs.is_set("VLLM_USE_FUSED_FILL_RMS_CAT"):
os.environ['VLLM_USE_FUSED_FILL_RMS_CAT'] = '1'
if model_config.quantization in {"slimquant_w4a8", "slimquant_w4a8_marlin", "slimquant_compressed_tensors_marlin", "compressed-tensors"}: if model_config.quantization in {"slimquant_w4a8", "slimquant_w4a8_marlin", "slimquant_compressed_tensors_marlin", "compressed-tensors"}:
if not envs.is_set("USE_FUSED_RMS_QUANT"): if not envs.is_set("USE_FUSED_RMS_QUANT"):
os.environ['USE_FUSED_RMS_QUANT'] = '1' os.environ['USE_FUSED_RMS_QUANT'] = '1'
...@@ -321,6 +338,8 @@ def get_model_architecture( ...@@ -321,6 +338,8 @@ def get_model_architecture(
os.environ['VLLM_USE_FUSE_SILU_AND_MUL'] = '1' os.environ['VLLM_USE_FUSE_SILU_AND_MUL'] = '1'
if not envs.is_set("VLLM_USE_OPT_RESHAPE_AND_CACHE"): if not envs.is_set("VLLM_USE_OPT_RESHAPE_AND_CACHE"):
os.environ['VLLM_USE_OPT_RESHAPE_AND_CACHE'] = '1' os.environ['VLLM_USE_OPT_RESHAPE_AND_CACHE'] = '1'
if not envs.is_set("VLLM_USE_FUSED_RMS_ROPE"):
os.environ['VLLM_USE_FUSED_RMS_ROPE'] = '1'
if architectures in [['DeepseekV32ForCausalLM']]: if architectures in [['DeepseekV32ForCausalLM']]:
if not envs.is_set("VLLM_USE_V32_ENCODE"): if not envs.is_set("VLLM_USE_V32_ENCODE"):
...@@ -353,7 +372,6 @@ def get_model_architecture( ...@@ -353,7 +372,6 @@ def get_model_architecture(
mixtral_supported = [ mixtral_supported = [
"fp8", "compressed-tensors", "gptq_marlin", "awq_marlin", "quark" "fp8", "compressed-tensors", "gptq_marlin", "awq_marlin", "quark"
] ]
vllm_supported_archs = ModelRegistry.get_supported_archs() vllm_supported_archs = ModelRegistry.get_supported_archs()
vllm_not_supported = not any(arch in vllm_supported_archs vllm_not_supported = not any(arch in vllm_supported_archs
for arch in architectures) for arch in architectures)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment