Commit 02689420 authored by xuxz's avatar xuxz
Browse files

Merge branch 'v0.9.2-dev' into 'v0.9.2-dev-add_connector'

# Conflicts:
#   vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py
parents ef362942 fa683b07
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"2": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"num_ldmatrixes": 1
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"num_ldmatrixes": 1
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"num_ldmatrixes": 1
},
"48": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 16,
"num_warps": 2,
"num_stages": 2,
"num_ldmatrixes": 1
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"num_ldmatrixes": 1
},
"256": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"512": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"2": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"512": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
}
}
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.forward_context import get_forward_context
class DeepEPAutoPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
"""
Auto Prepare/Finalize that wraps both DeepEP High-Throughput and
Low-Latency implementations and selects one based on prefill/decode phase.
"""
def __init__(self,
ht_prepare_finalize: mk.FusedMoEPrepareAndFinalize,
ll_prepare_finalize: mk.FusedMoEPrepareAndFinalize):
super().__init__()
self.ht_prepare_finalize = ht_prepare_finalize
self.ll_prepare_finalize = ll_prepare_finalize
self._current_phase = "decode" # default to decode (LL)
def _get_current_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize:
"""Get the appropriate prepare_finalize based on current phase."""
# Try to infer phase from forward_context if available:
# try:
# forward_context = get_forward_context()
# attn_metadata = forward_context.attn_metadata
# # Handle both v0 (single AttentionMetadata) and v1 (dict) formats
# if isinstance(attn_metadata, dict):
# if attn_metadata:
# attn_metadata = next(iter(attn_metadata.values()))
# else:
# attn_metadata = None
# if attn_metadata is not None and hasattr(attn_metadata,
# "num_decode_tokens"):
# # 只根据 decode tokens 判定:有 decode -> decode,否则 prefill
# self._current_phase = ("decode"
# if attn_metadata.num_decode_tokens > 0
# else "prefill")
# except Exception:
# # If forward_context is not available, use stored phase
# pass
# Prefill uses HT, decode uses LL
if self._current_phase == "prefill":
#rint("************prefill***********")
return self.ll_prepare_finalize
else:
# print("attn_metadata.num_decode_tokens",attn_metadata.num_decode_tokens)
return self.ht_prepare_finalize
#return self.ht_prepare_finalize
@property
def activation_format(self) -> mk.FusedMoEActivationFormat:
pf = self._get_current_prepare_finalize()
try:
return pf.activation_format
except NotImplementedError:
# Fallback to standard format if underlying impl does not provide it.
return mk.FusedMoEActivationFormat.Standard
def topk_indices_dtype(self) -> Optional[torch.dtype]:
pf = self._get_current_prepare_finalize()
return pf.topk_indices_dtype()
def max_num_tokens_per_rank(self) -> Optional[int]:
pf = self._get_current_prepare_finalize()
return pf.max_num_tokens_per_rank()
def num_dispatchers(self) -> int:
pf = self._get_current_prepare_finalize()
return pf.num_dispatchers()
def prepare_async(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
):
pf = self._get_current_prepare_finalize()
return pf.prepare_async(
a1, a1_scale, a2_scale, topk_weights, topk_ids,
num_experts, expert_map, apply_router_weight_on_input, quant_config)
def prepare(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
Optional[torch.Tensor], Optional[torch.Tensor]]:
pf = self._get_current_prepare_finalize()
return pf.prepare(
a1, a1_scale, a2_scale, topk_weights, topk_ids,
num_experts, expert_map, apply_router_weight_on_input, quant_config)
def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
apply_weights_and_reduce: bool = True
) -> None:
pf = self._get_current_prepare_finalize()
return pf.finalize(
output, fused_expert_output, topk_weights, topk_ids,
apply_router_weight_on_input, apply_weights_and_reduce)
def finalize_async(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
apply_weights_and_reduce: bool = True
):
pf = self._get_current_prepare_finalize()
if hasattr(pf, "finalize_async"):
return pf.finalize_async(
output, fused_expert_output, topk_weights, topk_ids,
apply_router_weight_on_input, apply_weights_and_reduce)
return pf.finalize(
output, fused_expert_output, topk_weights, topk_ids,
apply_router_weight_on_input, apply_weights_and_reduce)
...@@ -214,8 +214,6 @@ def moe_align_block_size_lightop( ...@@ -214,8 +214,6 @@ def moe_align_block_size_lightop(
def fused_experts_impl_w16a16_marlin(hidden_states: torch.Tensor, def fused_experts_impl_w16a16_marlin(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_marlin: torch.Tensor, w1_marlin: torch.Tensor,
w2_marlin: torch.Tensor, w2_marlin: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
...@@ -234,20 +232,33 @@ def fused_experts_impl_w16a16_marlin(hidden_states: torch.Tensor, ...@@ -234,20 +232,33 @@ def fused_experts_impl_w16a16_marlin(hidden_states: torch.Tensor,
): ):
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous" assert w1_marlin.is_contiguous(), "Packed weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous" assert w2_marlin.is_contiguous(), "Packed weights2 must be contiguous"
# 当前只支持 bf16 fp16 # 当前只支持 bf16 fp16
assert hidden_states.dtype in [torch.bfloat16,torch.float16] assert hidden_states.dtype in [torch.bfloat16,torch.float16]
compute_type = hidden_states.dtype compute_type = hidden_states.dtype
assert use_lightop, ( assert use_lightop, (
"only BW and set LMSLIM_USE_LIGHTOP=1 support Marlin W16A16 MoE") "only BW and set VLLM_USE_LIGHTOP=1 support Marlin W16A16 MoE")
num_tokens, K = hidden_states.shape num_tokens, K = hidden_states.shape
E, twoN, K_w1 = w1.shape
# Packed weights store the same number of elements as the original layout,
# but reshaped/reordered for Marlin kernels:
# - w1_marlin: [E, K/16, (2N)*16]
# - w2_marlin: [E, N/16, K*16]
E, k_div16, twoN_times16 = w1_marlin.shape
K_w1 = k_div16 * 16
assert K_w1 == K, f"w1_marlin K mismatch: {K_w1} vs {K}"
assert twoN_times16 % 16 == 0
twoN = twoN_times16 // 16
assert twoN % 2 == 0
N = twoN // 2 N = twoN // 2
E2, K_w2, N2_w2 = w2.shape E2, n_div16, k_times16 = w2_marlin.shape
assert E2 == E, f"w2_marlin E mismatch: {E2} vs {E}"
K_w2 = k_times16 // 16
assert K_w2 == K, f"w2_marlin K mismatch: {K_w2} vs {K}"
assert n_div16 * 16 == N, f"w2_marlin N mismatch: {n_div16 * 16} vs {N}"
if global_num_experts == -1: if global_num_experts == -1:
global_num_experts = E global_num_experts = E
......
...@@ -5,7 +5,7 @@ import functools ...@@ -5,7 +5,7 @@ import functools
import json import json
import os import os
import math import math
from typing import Any, Callable, Dict, Optional, List, Optional, Tuple from typing import Any, Callable, Dict, List, Optional
import torch import torch
...@@ -13,6 +13,7 @@ import vllm.envs as envs ...@@ -13,6 +13,7 @@ import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
logger = init_logger(__name__)
# yapf: disable # yapf: disable
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, get_config_quant_dtype) FusedMoEQuantConfig, get_config_quant_dtype)
...@@ -31,7 +32,7 @@ try: ...@@ -31,7 +32,7 @@ try:
from lmslim.layers.fused_moe.fuse_moe_int8 import (fused_experts_impl_int8, get_w8a8moe_json) from lmslim.layers.fused_moe.fuse_moe_int8 import (fused_experts_impl_int8, get_w8a8moe_json)
from lmslim.layers.fused_moe.fuse_moe_w4a8 import fused_experts_impl_w4a8 from lmslim.layers.fused_moe.fuse_moe_w4a8 import fused_experts_impl_w4a8
except Exception: except Exception:
print("INFO: Please install lmslim if you want to infer the quantitative model of moe.\n") logger.warning_once("Please install lmslim if you want to infer the quantitative model of moe.")
from vllm.model_executor.layers.fused_moe.prepare_finalize import ( from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP) MoEPrepareAndFinalizeNoEP)
...@@ -43,32 +44,9 @@ from vllm.utils import direct_register_custom_op ...@@ -43,32 +44,9 @@ from vllm.utils import direct_register_custom_op
# from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled # from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
logger = init_logger(__name__)
if envs.VLLM_USE_GLOBAL_CACHE13: if envs.VLLM_USE_GLOBAL_CACHE13:
moe_cache_singleton = None moe_cache_singleton = None
# Cache Marlin-packed weights so we only reorder once per weight tensor.
_marlin_weight_cache: Dict[Tuple[int, torch.device, torch.dtype, torch.Size], torch.Tensor] = {}
# Cache packed W16A16 Marlin weights by parameter identity so we can offload
# original layouts from GPU without losing the packed copies.
_w16a16_marlin_weight_cache: Dict[int, Tuple[torch.Tensor, torch.Tensor]] = {}
def _get_marlin_packed_weight(weight: torch.Tensor,
pack_fn: Callable[[torch.Tensor], torch.Tensor]
) -> torch.Tensor:
key = (weight.data_ptr(), weight.device, weight.dtype, weight.shape)
cached = _marlin_weight_cache.get(key)
if cached is not None:
return cached
# Marlin packing is done per expert and reshaped back to original dims.
packed = torch.stack([pack_fn(weight[i]).contiguous()
for i in range(weight.shape[0])],
dim=0)
_marlin_weight_cache[key] = packed
return packed
arch_name = torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] arch_name = torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0]
arch_cu = torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count arch_cu = torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count
...@@ -1247,14 +1225,14 @@ def vllm_topk_softmax(topk_weights: torch.Tensor, topk_indices: torch.Tensor, ...@@ -1247,14 +1225,14 @@ def vllm_topk_softmax(topk_weights: torch.Tensor, topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor, token_expert_indices: torch.Tensor,
gating_output: torch.Tensor, gating_output: torch.Tensor,
renormalize: bool) -> tuple[torch.Tensor, ...]: renormalize: bool) -> tuple[torch.Tensor, ...]:
if envs.VLLM_USE_TOPK_RENORM: if envs.VLLM_USE_TOPK_RENORM and renormalize is True:
from lightop import op as op from lightop import op as op
op.topk_softmax( op.topk_softmax(
topk_weights, topk_weights,
topk_indices, topk_indices,
token_expert_indices, token_expert_indices,
gating_output, gating_output,
True, renormalize,
) )
else: else:
ops.topk_softmax( ops.topk_softmax(
...@@ -1326,7 +1304,9 @@ def grouped_topk( ...@@ -1326,7 +1304,9 @@ def grouped_topk(
num_expert_group: int = 0, num_expert_group: int = 0,
topk_group: int = 0, topk_group: int = 0,
scoring_func: str = "softmax", scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None e_score_correction_bias: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = 1.0,
num_fused_shared_experts: Optional[int] = 0
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
assert hidden_states.size(0) == gating_output.size(0), ( assert hidden_states.size(0) == gating_output.size(0), (
...@@ -1339,7 +1319,7 @@ def grouped_topk( ...@@ -1339,7 +1319,7 @@ def grouped_topk(
else: else:
raise ValueError(f"Unsupported scoring function: {scoring_func}") raise ValueError(f"Unsupported scoring function: {scoring_func}")
num_token = scores.size(0) num_token, num_experts = scores.shape
if e_score_correction_bias is not None: if e_score_correction_bias is not None:
# Store original scores before applying correction bias. We use biased # Store original scores before applying correction bias. We use biased
# scores for expert selection but original scores for routing weights # scores for expert selection but original scores for routing weights
...@@ -1370,8 +1350,20 @@ def grouped_topk( ...@@ -1370,8 +1350,20 @@ def grouped_topk(
dim=-1, dim=-1,
sorted=False) sorted=False)
if num_fused_shared_experts != 0:
topk_ids[:, -1] = num_experts
if routed_scaling_factor is not None:
topk_weights[:, -1] = (
topk_weights[:, :-1].sum(dim=-1) / routed_scaling_factor
)
if renormalize: if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) topk_weights_sum = (
topk_weights.sum(dim=-1, keepdim=True)
if num_fused_shared_experts == 0
else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
)
topk_weights = topk_weights / topk_weights_sum
return topk_weights.to(torch.float32), topk_ids.to(torch.int32) return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
...@@ -1696,63 +1688,88 @@ def fused_experts_impl( ...@@ -1696,63 +1688,88 @@ def fused_experts_impl(
i_s: Optional[torch.Tensor] = None, **_ i_s: Optional[torch.Tensor] = None, **_
) -> torch.Tensor: ) -> torch.Tensor:
num_tokens = hidden_states.size(0) num_tokens = hidden_states.size(0)
if use_nn_moe:
E, _, N = w1.size()
else:
E, N, _ = w1.size()
K = w2.size(1)
if global_num_experts == -1:
global_num_experts = E
top_k_num = topk_ids.size(1) top_k_num = topk_ids.size(1)
# We execute the fused_moe kernel in chunks to circumvent this issue: # We execute the fused_moe kernel in chunks to circumvent this issue:
# https://github.com/vllm-project/vllm/issues/5938 # https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
M = min(num_tokens, CHUNK_SIZE) M = min(num_tokens, CHUNK_SIZE)
# Optional fast path: use Marlin W16A16 fused MoE implementation when the
# expert weights are already packed in Marlin layout.
if not use_nn_moe:
K = hidden_states.size(1)
def _is_marlin_w16a16_packed(w1: torch.Tensor,
w2: torch.Tensor) -> bool:
if w1.dim() != 3 or w2.dim() != 3:
return False
if w1.size(0) != w2.size(0):
return False
k_div16 = w1.size(1)
if k_div16 * 16 != K:
return False
if w1.size(2) % 16 != 0:
return False
twoN = w1.size(2) // 16
if twoN % 2 != 0:
return False
N = twoN // 2
if w2.size(2) != K * 16:
return False
if w2.size(1) * 16 != N:
return False
return True
is_packed = (getattr(w1, "marlin_w16a16_packed", False)
or getattr(w2, "marlin_w16a16_packed", False)
or _is_marlin_w16a16_packed(w1, w2))
if is_packed:
if envs.VLLM_USE_MOE_W16A16_TRITON:
raise RuntimeError(
"VLLM_USE_MOE_W16A16_TRITON=1 forces Triton MoE, but the MoE weights are "
"packed in Marlin W16A16 layout. Please load unpacked weights or set "
"VLLM_USE_MOE_W16A16_TRITON=0.")
try:
from vllm.model_executor.layers.fused_moe.fuse_moe_w16a16_marlin import ( # noqa: E501
fused_experts_impl_w16a16_marlin)
except Exception:
fused_experts_impl_w16a16_marlin = None # type: ignore
if fused_experts_impl_w16a16_marlin is None:
raise RuntimeError(
"Marlin W16A16 MoE weights are packed, but the Marlin kernel is unavailable. "
"Ensure lightop is installed and VLLM_USE_LIGHTOP=1."
)
if activation != "silu":
raise RuntimeError(
"Marlin W16A16 MoE only supports activation='silu'.")
if apply_router_weight_on_input:
raise RuntimeError(
"Marlin W16A16 MoE does not support apply_router_weight_on_input=True."
)
E = w1.size(0)
if global_num_experts == -1:
global_num_experts = E
twoN = w1.size(2) // 16
if envs.VLLM_USE_GLOBAL_CACHE13: if envs.VLLM_USE_GLOBAL_CACHE13:
cache13 = get_moe_cache(top_k_num, N,K if not use_nn_moe else w2.shape[2], device=hidden_states.device, dtype=hidden_states.dtype) cache13 = get_moe_cache(top_k_num,
else: twoN,
cache13 = torch.empty(M * top_k_num * max(N, K if not use_nn_moe else w2.shape[2]), device=hidden_states.device, dtype=hidden_states.dtype) K,
device=hidden_states.device,
# Optional fast path: use lmslim's Marlin W16A16 fused MoE implementation dtype=hidden_states.dtype)
# when explicitly requested. This reuses the same cache13 buffer as other
# fused paths for consistency.
from vllm.model_executor.layers.fused_moe.fuse_moe_w16a16_marlin import fused_experts_impl_w16a16_marlin
if (envs.VLLM_USE_MARLIN_W16A16_MOE
and fused_experts_impl_w16a16_marlin is not None):
# Only pack when shapes match the expected [E, 2N, K] / [E, K, N/2] contract.
# If shapes are unexpected, skip packing and fall back to non-Marlin paths below.
from vllm.model_executor.layers.fused_moe.marlin_quant import w16a16_marlin_weight
cache_key = id(w1)
cached_marlin = _w16a16_marlin_weight_cache.get(cache_key)
if cached_marlin is None:
w1_marlin = _get_marlin_packed_weight(w1, w16a16_marlin_weight)
w2_marlin = _get_marlin_packed_weight(w2, w16a16_marlin_weight)
# Offload original layout weights from GPU to avoid double residency.
with torch.no_grad():
w1_cpu = w1.detach().to("cpu")
w2_cpu = w2.detach().to("cpu")
if hasattr(w1, "data"):
w1.data = w1_cpu # type: ignore[attr-defined]
else:
w1 = w1_cpu
if hasattr(w2, "data"):
w2.data = w2_cpu # type: ignore[attr-defined]
else: else:
w2 = w2_cpu cache13 = torch.empty(M * top_k_num * max(twoN, K),
device=hidden_states.device,
dtype=hidden_states.dtype)
_w16a16_marlin_weight_cache[cache_key] = (w1_marlin, w2_marlin)
else:
w1_marlin, w2_marlin = cached_marlin
return fused_experts_impl_w16a16_marlin( return fused_experts_impl_w16a16_marlin(
hidden_states=hidden_states, hidden_states=hidden_states,
w1=w1, w1_marlin=w1,
w2=w2, w2_marlin=w2,
w1_marlin=w1_marlin,
w2_marlin=w2_marlin,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
cache13=cache13, cache13=cache13,
...@@ -1763,10 +1780,32 @@ def fused_experts_impl( ...@@ -1763,10 +1780,32 @@ def fused_experts_impl(
expert_map=expert_map, expert_map=expert_map,
use_nn_moe=False, use_nn_moe=False,
routed_scaling_factor=routed_scaling_factor, routed_scaling_factor=routed_scaling_factor,
shared_output=shared_output shared_output=shared_output,
) )
if use_int8_w8a8 is True: # Non-Marlin paths need the original weight shapes.
if use_nn_moe:
E, _, N = w1.size()
else:
E, N, _ = w1.size()
K = w2.size(1)
if global_num_experts == -1:
global_num_experts = E
if envs.VLLM_USE_GLOBAL_CACHE13:
cache13 = get_moe_cache(top_k_num,
N,
K if not use_nn_moe else w2.shape[2],
device=hidden_states.device,
dtype=hidden_states.dtype)
else:
cache13 = torch.empty(
M * top_k_num * max(N, K if not use_nn_moe else w2.shape[2]),
device=hidden_states.device,
dtype=hidden_states.dtype)
if use_int8_w8a8 or use_fp8_w8a8:
return fused_experts_impl_int8(hidden_states=hidden_states, return fused_experts_impl_int8(hidden_states=hidden_states,
w1=w1, w1=w1,
w2=w2, w2=w2,
...@@ -1776,8 +1815,8 @@ def fused_experts_impl( ...@@ -1776,8 +1815,8 @@ def fused_experts_impl(
inplace=inplace, inplace=inplace,
activation=activation, activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=False, use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=True, use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=False, use_int8_w8a16=False,
use_int4_w4a16=False, use_int4_w4a16=False,
per_channel_quant=per_channel_quant, per_channel_quant=per_channel_quant,
...@@ -2115,12 +2154,17 @@ def fused_moe( ...@@ -2115,12 +2154,17 @@ def fused_moe(
Returns: Returns:
- torch.Tensor: The output tensor after applying the MoE layer. - torch.Tensor: The output tensor after applying the MoE layer.
""" """
if use_grouped_topk: if use_grouped_topk:
assert num_expert_group is not None and topk_group is not None assert num_expert_group is not None and topk_group is not None
topk_weights, topk_ids = grouped_topk(hidden_states, gating_output, topk_weights, topk_ids = grouped_topk(
topk, renormalize, hidden_states = hidden_states,
num_expert_group, topk_group) gating_output = gating_output,
topk = topk,
renormalize = renormalize,
num_expert_group = num_expert_group,
topk_group = topk_group,
routed_scaling_factor = routed_scaling_factor,
num_fused_shared_experts = 1)
elif custom_routing_function is None: elif custom_routing_function is None:
topk_weights, topk_ids, token_expert_indices = fused_topk( topk_weights, topk_ids, token_expert_indices = fused_topk(
hidden_states, gating_output, topk, renormalize) hidden_states, gating_output, topk, renormalize)
......
...@@ -6,7 +6,9 @@ from math import prod ...@@ -6,7 +6,9 @@ from math import prod
from typing import Optional, final from typing import Optional, final
from dataclasses import dataclass from dataclasses import dataclass
from collections.abc import Callable from collections.abc import Callable
from vllm.logger import init_logger
logger = init_logger(__name__)
import torch import torch
import vllm.envs as envs import vllm.envs as envs
...@@ -843,11 +845,16 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module): ...@@ -843,11 +845,16 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
self, self,
prepare_finalize: FusedMoEPrepareAndFinalize, prepare_finalize: FusedMoEPrepareAndFinalize,
fused_experts: CustomizedFusedMoEPermuteExpertsUnpermute, fused_experts: CustomizedFusedMoEPermuteExpertsUnpermute,
experts_ht: CustomizedFusedMoEPermuteExpertsUnpermute = None,
experts_ll: CustomizedFusedMoEPermuteExpertsUnpermute = None,
shared_experts: Optional[torch.nn.Module] = None, shared_experts: Optional[torch.nn.Module] = None,
): ):
super().__init__() super().__init__()
self.prepare_finalize = prepare_finalize self.prepare_finalize = prepare_finalize
self.fused_experts = fused_experts self.fused_experts = fused_experts
self.fused_experts_ht = experts_ht
self.fused_experts_ll = experts_ll
self.shared_experts = shared_experts self.shared_experts = shared_experts
if self.shared_experts is not None: if self.shared_experts is not None:
...@@ -919,7 +926,21 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module): ...@@ -919,7 +926,21 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
Returns: Returns:
- torch.Tensor: The output tensor after applying the MoE layer. - torch.Tensor: The output tensor after applying the MoE layer.
""" """
prepare_finalize = self.prepare_finalize
fused_experts = self.fused_experts
if envs.VLLM_ALL2ALL_BACKEND == "deepep_auto":
num_ht_ll_tokens = envs.VLLM_MOE_HT_THRESHOLD
num_tokens = hidden_states.size(0)
# logger.info("num_tokens=%d", num_tokens)
if num_tokens > num_ht_ll_tokens:
prepare_finalize = self.prepare_finalize.ht_prepare_finalize
fused_experts = self.fused_experts_ht
else:
prepare_finalize = self.prepare_finalize.ll_prepare_finalize
fused_experts = self.fused_experts_ll
a1 = hidden_states a1 = hidden_states
if inplace and self.shared_experts is None: if inplace and self.shared_experts is None:
...@@ -931,7 +952,7 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module): ...@@ -931,7 +952,7 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
if global_num_experts == -1: if global_num_experts == -1:
global_num_experts = local_num_experts global_num_experts = local_num_experts
prepare_ret = self.prepare_finalize.prepare_async( prepare_ret = prepare_finalize.prepare_async(
a1, a1,
a1_scale, a1_scale,
a2_scale, a2_scale,
...@@ -940,7 +961,7 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module): ...@@ -940,7 +961,7 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
global_num_experts, global_num_experts,
expert_map, expert_map,
apply_router_weight_on_input, apply_router_weight_on_input,
self.fused_experts.quant_config, fused_experts.quant_config,
) )
hook, receiver = ( hook, receiver = (
prepare_ret if isinstance(prepare_ret, tuple) else (None, prepare_ret) prepare_ret if isinstance(prepare_ret, tuple) else (None, prepare_ret)
...@@ -971,7 +992,7 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module): ...@@ -971,7 +992,7 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
# and can never run into the tensor.numel() == 0 case. # and can never run into the tensor.numel() == 0 case.
fused_out = torch.empty_like(a1q).to(dtype=a1.dtype) fused_out = torch.empty_like(a1q).to(dtype=a1.dtype)
else: else:
fused_out = self.fused_experts.apply( fused_out = fused_experts.apply(
None, None,
a1, a1,
a1q, a1q,
...@@ -1008,12 +1029,12 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module): ...@@ -1008,12 +1029,12 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
self.alt_stream.wait_event(self.alt_event) self.alt_stream.wait_event(self.alt_event)
hook = None hook = None
if self.prepare_finalize.activation_format == \ if prepare_finalize.activation_format == \
FusedMoEActivationFormat.BatchedExperts: FusedMoEActivationFormat.BatchedExperts:
self.prepare_finalize.finalize(output, fused_out, topk_weights, prepare_finalize.finalize(output, fused_out, topk_weights,
topk_ids, apply_router_weight_on_input, apply_weights_and_reduce=True) topk_ids, apply_router_weight_on_input, apply_weights_and_reduce=True)
else: else:
hook = self.prepare_finalize.finalize_async(output, fused_out, topk_weights, hook = prepare_finalize.finalize_async(output, fused_out, topk_weights,
topk_ids, apply_router_weight_on_input, apply_weights_and_reduce=True) topk_ids, apply_router_weight_on_input, apply_weights_and_reduce=True)
if hook is not None: if hook is not None:
hook() hook()
......
"""
Utilities for capturing MoE router distributions from real workloads.
This is intentionally lightweight and gated behind env vars so it has zero
runtime impact unless explicitly enabled.
Env vars (defaults from vllm.envs):
- VLLM_MOE_ROUTER_CAPTURE=0/1: enable capture (default: 0).
- VLLM_MOE_ROUTER_CAPTURE_DIR=/path: output directory for per-process dumps
(default: /tmp).
- VLLM_MOE_ROUTER_CAPTURE_RANK=N: only capture on the given torch.distributed
rank (default: -1; set to -1 to capture all ranks).
- VLLM_MOE_ROUTER_CAPTURE_MAX_LAYERS=N: max number of layers to record per
process (default: 0; 0 = unlimited).
- VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_GT=A: only record calls where router_logits
has num_tokens > A (default: -1; <0 = disabled).
- VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_LT=B: only record calls where router_logits
has num_tokens < B (default: -1; 0 = disabled).
Output format:
- A single `.pt` per captured num_tokens (and per rank if torch.distributed is
initialized).
- Payload includes `layers_by_num_tokens: dict[str, dict[layer_name, layer_state]]`.
- A convenience `layers` field is also included (same as
`layers_by_num_tokens[str(num_tokens)]`) for easy loading.
- For each captured MoE layer, stores a list of 2D tensors
`router_logits_chunks: list[Tensor[num_tokens_i, num_experts]]` on CPU,
typically in fp16 for space efficiency.
"""
from __future__ import annotations
import atexit
import inspect
import os
import socket
import threading
import time
from dataclasses import dataclass
from typing import Optional
import torch
import vllm.envs as envs
_DEFAULT_SKIP_STACK_FUNCS = ("profile_run", "_dummy_run",
"determine_available_memory")
@dataclass(frozen=True)
class RouterCaptureConfig:
enabled: bool = False
out_dir: str = "/tmp"
skip_profile: bool = True
skip_stack_funcs: tuple[str, ...] = _DEFAULT_SKIP_STACK_FUNCS
only_rank: Optional[int] = 0
max_layers: int = 0
num_tokens_gt: Optional[int] = None
num_tokens_lt: Optional[int] = None
@staticmethod
def from_env() -> "RouterCaptureConfig":
enabled = envs.VLLM_MOE_ROUTER_CAPTURE
out_dir = envs.VLLM_MOE_ROUTER_CAPTURE_DIR
skip_profile = True
skip_stack_funcs = _DEFAULT_SKIP_STACK_FUNCS
only_rank: Optional[int] = None
if envs.VLLM_MOE_ROUTER_CAPTURE_RANK >= 0:
only_rank = envs.VLLM_MOE_ROUTER_CAPTURE_RANK
max_layers = envs.VLLM_MOE_ROUTER_CAPTURE_MAX_LAYERS
num_tokens_gt_opt = (envs.VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_GT
if envs.VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_GT >= 0
else None)
num_tokens_lt_opt = (envs.VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_LT
if envs.VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_LT > 0
else None)
# Per-size mode requires an explicit token-count filter to avoid
# unbounded captures by default.
if num_tokens_gt_opt is None and num_tokens_lt_opt is None:
enabled = False
if (num_tokens_gt_opt is not None and num_tokens_lt_opt is not None
and num_tokens_gt_opt >= num_tokens_lt_opt):
enabled = False
return RouterCaptureConfig(enabled=enabled,
out_dir=out_dir,
skip_profile=skip_profile,
skip_stack_funcs=skip_stack_funcs,
only_rank=only_rank,
max_layers=max_layers,
num_tokens_gt=num_tokens_gt_opt,
num_tokens_lt=num_tokens_lt_opt)
def _in_profile_run(skip_stack_funcs: tuple[str, ...]) -> bool:
"""
Best-effort detection for vLLM startup profiling/warmup runs.
Startup warmups often execute MoE kernels with synthetic shapes. When
enabled, skip captures from these stacks so the first capture comes from a
real request.
"""
if not skip_stack_funcs:
return False
frame = inspect.currentframe()
try:
while frame is not None:
name = frame.f_code.co_name
if name in skip_stack_funcs:
return True
frame = frame.f_back
finally:
# Avoid reference cycles.
del frame
return False
class _RouterCapture:
def __init__(self, cfg: RouterCaptureConfig) -> None:
self.cfg = cfg
# Bucket captures by token count.
self._layers_by_num_tokens: dict[int, dict[str, dict[str, object]]] = {}
self._layer_names: set[str] = set()
self._completed_num_tokens: set[int] = set()
self._lock = threading.Lock()
self._flush_counter = 0
self._pid = os.getpid()
self._host = socket.gethostname()
self._start_time = time.time()
os.makedirs(cfg.out_dir, exist_ok=True)
atexit.register(self.flush)
def _bucket_for_num_tokens(self, num_tokens: int) -> Optional[int]:
"""Return the per-size bucket key for this record call, or None if filtered."""
if self.cfg.num_tokens_gt is None and self.cfg.num_tokens_lt is None:
return None
if self.cfg.num_tokens_gt is not None:
if int(num_tokens) <= int(self.cfg.num_tokens_gt):
return None
if self.cfg.num_tokens_lt is not None:
if int(num_tokens) >= int(self.cfg.num_tokens_lt):
return None
bucket_num_tokens = int(num_tokens)
if bucket_num_tokens != 0 and bucket_num_tokens in self._completed_num_tokens:
return None
return bucket_num_tokens
def _snapshot_layers_by_num_tokens(
self,
layers_by_num_tokens: dict[int, dict[str, dict[str, object]]],
) -> dict[int, dict[str, dict[str, object]]]:
snapshot: dict[int, dict[str, dict[str, object]]] = {}
for num_tokens, bucket in layers_by_num_tokens.items():
bucket_snapshot: dict[str, dict[str, object]] = {}
for layer_name, state in bucket.items():
chunks = state.get("router_logits_chunks", [])
bucket_snapshot[layer_name] = {
"num_experts": int(state.get("num_experts", 0)),
"num_tokens": int(state.get("num_tokens", 0)),
"router_logits_chunks": list(chunks),
}
snapshot[int(num_tokens)] = bucket_snapshot
return snapshot
@torch.no_grad()
def record(self, layer_name: str, router_logits: torch.Tensor,
top_k: int) -> None:
if self.cfg.skip_profile and _in_profile_run(self.cfg.skip_stack_funcs):
return
if self.cfg.only_rank is not None:
rank = _get_rank()
if rank is not None and rank != self.cfg.only_rank:
return
if router_logits.dim() != 2:
return
num_tokens, num_experts = router_logits.shape
if num_tokens == 0 or num_experts == 0:
return
bucket_num_tokens = self._bucket_for_num_tokens(int(num_tokens))
if bucket_num_tokens is None:
return
# Limit the number of recorded layers to avoid unbounded dumps.
if layer_name not in self._layer_names:
if self.cfg.max_layers != 0 and len(self._layer_names) >= self.cfg.max_layers:
return
self._layer_names.add(layer_name)
# Store on CPU to avoid consuming GPU memory during long runs.
# fp16 is typically sufficient because we primarily care about
# distribution and relative ordering (top-k), not exact values.
router_logits_cpu = router_logits.detach()
if router_logits_cpu.is_cuda:
router_logits_cpu = router_logits_cpu.to(device="cpu",
dtype=torch.float16)
else:
router_logits_cpu = router_logits_cpu.to(dtype=torch.float16)
bucket_snapshot: Optional[dict[str, dict[str, object]]] = None
should_flush = False
with self._lock:
bucket = self._layers_by_num_tokens.setdefault(bucket_num_tokens, {})
if layer_name in bucket:
return
bucket[layer_name] = {
"num_experts": int(num_experts),
"num_tokens": int(num_tokens),
"router_logits_chunks": [router_logits_cpu],
}
if self.cfg.max_layers != 0 and len(bucket) >= int(self.cfg.max_layers):
should_flush = True
bucket_snapshot = self._snapshot_layers_by_num_tokens(
{int(bucket_num_tokens): bucket})[int(bucket_num_tokens)]
self._completed_num_tokens.add(int(bucket_num_tokens))
self._layers_by_num_tokens.pop(int(bucket_num_tokens), None)
if should_flush and bucket_snapshot is not None:
self._flush_payload(
layers_by_num_tokens={int(bucket_num_tokens): bucket_snapshot},
file_tag=f"nt{int(bucket_num_tokens)}",
)
def _flush_payload(
self,
*,
layers_by_num_tokens: dict[int, dict[str, dict[str, object]]],
file_tag: Optional[str] = None,
) -> Optional[str]:
if not self.cfg.enabled:
return None
if self.cfg.only_rank is not None:
rank = _get_rank()
if rank is not None and rank != self.cfg.only_rank:
return None
rank = _get_rank()
now = time.time()
ts = time.strftime("%Y%m%d_%H%M%S", time.localtime(now))
ts_us = int(now * 1_000_000)
with self._lock:
flush_idx = self._flush_counter
self._flush_counter += 1
rank_str = f"rank{rank}" if rank is not None else "rankNA"
tag = f"{file_tag}_" if file_tag else ""
out_path = os.path.join(
self.cfg.out_dir,
f"moe_router_stats_{tag}{ts_us}_{self._host}_{rank_str}_pid{self._pid}_flush{flush_idx}.pt",
)
layers_by_num_tokens_out: dict[str, object] = {}
for num_tokens, bucket in layers_by_num_tokens.items():
bucket_out: dict[str, object] = {}
for layer_name, state in bucket.items():
bucket_out[layer_name] = {
"num_experts": int(state["num_experts"]),
"num_tokens": int(state["num_tokens"]),
"router_logits_chunks":
state["router_logits_chunks"], # type: ignore[typeddict-item]
}
layers_by_num_tokens_out[str(int(num_tokens))] = bucket_out
payload: dict[str, object] = {
"meta": {
"timestamp": ts,
"timestamp_us": ts_us,
"flush_index": int(flush_idx),
"host": self._host,
"pid": self._pid,
"rank": rank,
"wall_time_s": float(now - self._start_time),
},
"layers_by_num_tokens": layers_by_num_tokens_out,
}
# Backward-compatible convenience field when there is a single bucket.
if len(layers_by_num_tokens) == 1:
(only_bucket_key, ) = layers_by_num_tokens.keys()
payload["layers"] = layers_by_num_tokens_out[str(int(only_bucket_key))]
try:
torch.save(payload, out_path)
except Exception:
return None
return out_path
def flush(self) -> Optional[str]:
with self._lock:
if not self._layers_by_num_tokens:
return None
snapshot = self._snapshot_layers_by_num_tokens(self._layers_by_num_tokens)
return self._flush_payload(layers_by_num_tokens=snapshot)
def reset(self) -> None:
with self._lock:
self._layers_by_num_tokens.clear()
self._layer_names.clear()
self._completed_num_tokens.clear()
self._start_time = time.time()
_CAPTURE: Optional[_RouterCapture] = None
_CAPTURE_DISABLED: bool = False
def _disable_global_capture() -> None:
global _CAPTURE, _CAPTURE_DISABLED
_CAPTURE = None
_CAPTURE_DISABLED = True
def _get_rank() -> Optional[int]:
if torch.distributed.is_available() and torch.distributed.is_initialized():
try:
return torch.distributed.get_rank()
except Exception:
return None
return None
def _get_capture() -> Optional[_RouterCapture]:
global _CAPTURE, _CAPTURE_DISABLED
if _CAPTURE_DISABLED:
return None
if _CAPTURE is not None:
return _CAPTURE
cfg = RouterCaptureConfig.from_env()
if not cfg.enabled:
_disable_global_capture()
return None
if cfg.only_rank is not None:
rank = _get_rank()
if rank is not None and rank != cfg.only_rank:
_disable_global_capture()
return None
_CAPTURE = _RouterCapture(cfg)
return _CAPTURE
@torch.no_grad()
def maybe_record_router_logits(*, layer_name: str, router_logits: torch.Tensor,
top_k: int) -> None:
capture = _get_capture()
if capture is None:
return
capture.record(layer_name=layer_name, router_logits=router_logits, top_k=top_k)
def maybe_flush_router_capture(*, reset: bool = False) -> Optional[str]:
"""Flush capture buffers to disk without exiting the process."""
capture = _get_capture()
if capture is None:
return None
out_path = capture.flush()
if out_path is not None and reset:
capture.reset()
return out_path
...@@ -34,7 +34,11 @@ class SharedFusedMoE(FusedMoE): ...@@ -34,7 +34,11 @@ class SharedFusedMoE(FusedMoE):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor: hidden_states_copy: Optional[torch.Tensor] = None,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None,
**_
) -> tuple[torch.Tensor, torch.Tensor]|torch.Tensor:
if not self.use_overlapped: if not self.use_overlapped:
shared_out = self._shared_experts(hidden_states) shared_out = self._shared_experts(hidden_states)
...@@ -53,6 +57,8 @@ class SharedFusedMoE(FusedMoE): ...@@ -53,6 +57,8 @@ class SharedFusedMoE(FusedMoE):
fused_out = super().forward( fused_out = super().forward(
hidden_states=hidden_states, hidden_states=hidden_states,
router_logits=router_logits, router_logits=router_logits,
hidden_states_copy = hidden_states_copy,
i_s = i_s,
i_q = i_q,
) )
return fused_out return fused_out
...@@ -22,6 +22,7 @@ from vllm.utils import round_up ...@@ -22,6 +22,7 @@ from vllm.utils import round_up
try: try:
from lmslim.layers.gemm.int8_utils import ( from lmslim.layers.gemm.int8_utils import (
per_token_group_quant_int8, per_token_quant_int8) per_token_group_quant_int8, per_token_quant_int8)
from lightop import op
except Exception: except Exception:
print("INFO: Please install lmslim if you want to use int utils.\n") print("INFO: Please install lmslim if you want to use int utils.\n")
from vllm.utils import cdiv from vllm.utils import cdiv
...@@ -622,6 +623,16 @@ def ep_scatter( ...@@ -622,6 +623,16 @@ def ep_scatter(
num_experts = num_recv_tokens_per_expert.shape[0] num_experts = num_recv_tokens_per_expert.shape[0]
hidden_size = recv_x.shape[1] hidden_size = recv_x.shape[1]
scale_hidden_size = recv_x_scale.shape[-1] scale_hidden_size = recv_x_scale.shape[-1]
if hasattr(op, "ep_scatter"):
op.ep_scatter(
recv_x, recv_x_scale,
recv_topk, expert_map,
num_recv_tokens_per_expert,
output_tensor, output_tensor_scale, m_indices, output_index,
num_experts, BLOCK_E
)
else:
# grid = (triton.cdiv(hidden_size, BLOCK_D), num_experts) # grid = (triton.cdiv(hidden_size, BLOCK_D), num_experts)
grid = num_experts grid = num_experts
...@@ -665,8 +676,8 @@ def ep_scatter( ...@@ -665,8 +676,8 @@ def ep_scatter(
num_warps=num_warps, num_warps=num_warps,
HIDDEN_SIZE=hidden_size, HIDDEN_SIZE=hidden_size,
HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size), HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size),
SCALE_HIDDEN_SIZE=scale_hidden_size,#hidden_size // BLOCK_D, SCALE_HIDDEN_SIZE=scale_hidden_size,
SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(scale_hidden_size)#triton.next_power_of_2(hidden_size // BLOCK_D), SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(scale_hidden_size),
) )
return return
...@@ -897,7 +908,7 @@ class EPSharedExperts(nn.Module): ...@@ -897,7 +908,7 @@ class EPSharedExperts(nn.Module):
"Only silu is supported for now.") "Only silu is supported for now.")
self.act_fn = SiluAndMul() self.act_fn = SiluAndMul()
def forward(self, x): def forward(self, x, **_):
gate_up, _ = self.gate_up_proj(x) gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up) x = self.act_fn(gate_up)
x, _ = self.down_proj(x) x, _ = self.down_proj(x)
......
...@@ -27,6 +27,7 @@ from vllm.model_executor.parameter import (BasevLLMParameter, ...@@ -27,6 +27,7 @@ from vllm.model_executor.parameter import (BasevLLMParameter,
PackedvLLMParameter, PackedvLLMParameter,
PerTensorScaleParameter, PerTensorScaleParameter,
RowvLLMParameter) RowvLLMParameter)
from vllm.model_executor.layers.dp_attention import get_moe_tp_rank, get_moe_tp_size
# yapf: enable # yapf: enable
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -331,7 +332,6 @@ class ReplicatedLinear(LinearBase): ...@@ -331,7 +332,6 @@ class ReplicatedLinear(LinearBase):
skip_bias_add: bool = False, skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
eps: Optional[float] = 1e-6,
prefix: str = "", prefix: str = "",
*, *,
return_bias: bool = True, return_bias: bool = True,
...@@ -343,7 +343,6 @@ class ReplicatedLinear(LinearBase): ...@@ -343,7 +343,6 @@ class ReplicatedLinear(LinearBase):
quant_config, quant_config,
prefix=prefix, prefix=prefix,
return_bias=return_bias) return_bias=return_bias)
self.eps = eps
# All the linear layer supports quant method. # All the linear layer supports quant method.
assert self.quant_method is not None assert self.quant_method is not None
...@@ -393,45 +392,19 @@ class ReplicatedLinear(LinearBase): ...@@ -393,45 +392,19 @@ class ReplicatedLinear(LinearBase):
def forward( def forward(
self, self,
input_: torch.Tensor, input_: torch.Tensor,
rms_weight: Optional[torch.Tensor] = None, iqis: Optional[tuple] = None, **_
residual: Optional[torch.Tensor] = None,
quant_args: Optional[list] = None,
update_hd: Optional[bool] = True
) -> Union[torch.Tensor,
tuple[torch.Tensor, Optional[Parameter]],
tuple[torch.Tensor, torch.Tensor, Optional[Parameter], list[torch.Tensor]]]:
if envs.USE_FUSED_RMS_QUANT and (rms_weight is not None or quant_args is not None):
if quant_args is not None:
input_quant_args = quant_args
) -> Union[torch.Tensor,
tuple[torch.Tensor, Optional[Parameter]]]:
if envs.USE_FUSED_RMS_QUANT and iqis is not None:
bias = self.bias if not self.skip_bias_add else None bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None assert self.quant_method is not None
output = self.quant_method.apply(self, input_, bias, input_quant_args) output = self.quant_method.apply(self, input_, bias, input_quant_args=iqis)
output_bias = self.bias if self.skip_bias_add else None output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias: if not self.return_bias:
return output return output
return output, output_bias return output, output_bias
else:
i_q, _scales = lm_faster_rmsquant(input=input_,
rms_weight=rms_weight,
epsilon=self.eps,
quant_dtype=torch.int8,
residual=residual,
update_input=update_hd
)
new_residual = residual
input_quant_args = [i_q, _scales]
bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None
output = self.quant_method.apply(self, input_, bias, input_quant_args)
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, new_residual, output_bias, input_quant_args
else: else:
bias = self.bias if not self.skip_bias_add else None bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None assert self.quant_method is not None
...@@ -459,7 +432,6 @@ class FusedQuantedReplicatedLinear(LinearBase): ...@@ -459,7 +432,6 @@ class FusedQuantedReplicatedLinear(LinearBase):
skip_bias_add: bool = False, skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
eps: Optional[float] = 1e-6,
prefix: str = "", prefix: str = "",
*, *,
return_bias: bool = True, return_bias: bool = True,
...@@ -473,7 +445,6 @@ class FusedQuantedReplicatedLinear(LinearBase): ...@@ -473,7 +445,6 @@ class FusedQuantedReplicatedLinear(LinearBase):
quant_config, quant_config,
prefix=prefix, prefix=prefix,
return_bias=return_bias) return_bias=return_bias)
self.eps = eps
self.q_lora_rank = q_lora_rank self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank self.kv_lora_rank = kv_lora_rank
self.qk_rope_head_dim = qk_rope_head_dim self.qk_rope_head_dim = qk_rope_head_dim
...@@ -588,7 +559,6 @@ class FusedQuantedReplicatedLinear(LinearBase): ...@@ -588,7 +559,6 @@ class FusedQuantedReplicatedLinear(LinearBase):
assert len(self.kv_a_weight.shape) == 2 assert len(self.kv_a_weight.shape) == 2
fused_weight = torch.cat([self.q_a_weight, self.kv_a_weight], dim=0) # TN fused_weight = torch.cat([self.q_a_weight, self.kv_a_weight], dim=0) # TN
param.data.copy_(fused_weight) param.data.copy_(fused_weight)
#TODO: wjl 删掉无用的显存tensor
else: else:
raise ValueError(f"Unexpected weight: {source}") raise ValueError(f"Unexpected weight: {source}")
...@@ -596,31 +566,17 @@ class FusedQuantedReplicatedLinear(LinearBase): ...@@ -596,31 +566,17 @@ class FusedQuantedReplicatedLinear(LinearBase):
def forward( def forward(
self, self,
input_: torch.Tensor, input_: torch.Tensor,
rms_weight: Optional[torch.Tensor] = None, iqis: Optional[tuple] = None, **_
residual: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, Optional[Parameter]]:
update_hd: Optional[bool] = True if envs.USE_FUSED_RMS_QUANT and iqis is not None:
) -> Union[torch.Tensor,
tuple[torch.Tensor, torch.Tensor, Optional[Parameter], list[torch.Tensor]]]:
if envs.USE_FUSED_RMS_QUANT and rms_weight is not None:
i_q, _scales = lm_faster_rmsquant(input=input_,
rms_weight=rms_weight,
epsilon=self.eps,
quant_dtype=torch.int8,
residual=residual,
update_input=update_hd
)
new_residual = residual
input_quant_args = [i_q, _scales]
bias = self.bias if not self.skip_bias_add else None bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None assert self.quant_method is not None
output = self.quant_method.apply(self, input_, bias, input_quant_args) output = self.quant_method.apply(self, input_, bias, input_quant_args=iqis)
output_bias = self.bias if self.skip_bias_add else None output_bias = self.bias if self.skip_bias_add else None
assert self.return_bias is True assert self.return_bias is True
if not self.return_bias: if not self.return_bias:
raise RuntimeError("Not return bias. Unexpected Error.") raise RuntimeError("Not return bias. Unexpected Error.")
return output, new_residual, output_bias return output, output_bias
else: else:
raise RuntimeError("Unexpected Error.") raise RuntimeError("Unexpected Error.")
...@@ -670,12 +626,18 @@ class ColumnParallelLinear(LinearBase): ...@@ -670,12 +626,18 @@ class ColumnParallelLinear(LinearBase):
*, *,
return_bias: bool = True, return_bias: bool = True,
expect_tp_size: Optional[int] = None, expect_tp_size: Optional[int] = None,
enable_dp_attn_moe: bool = False,
): ):
# Divide the weight matrix along the last dimension. # Divide the weight matrix along the last dimension.
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
if expect_tp_size is not None: if expect_tp_size is not None:
self.expect_tp_size = expect_tp_size self.expect_tp_size = expect_tp_size
self.tp_size = self.expect_tp_size self.tp_size = self.expect_tp_size
self.enable_dp_attn_moe = enable_dp_attn_moe
if enable_dp_attn_moe:
self.tp_size = get_moe_tp_size()
self.input_size_per_partition = input_size self.input_size_per_partition = input_size
self.output_size_per_partition = divide(output_size, self.tp_size) self.output_size_per_partition = divide(output_size, self.tp_size)
self.output_partition_sizes = [self.output_size_per_partition] self.output_partition_sizes = [self.output_size_per_partition]
...@@ -858,31 +820,17 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -858,31 +820,17 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
def forward( def forward(
self, input_, self, input_,
rms_weight: Optional[torch.Tensor] = None, xqxs: Optional[tuple] = None,
residual: Optional[torch.Tensor] = None, iqis: Optional[tuple] = None, **_
update_hd: Optional[bool] = True,
xqxs: Optional[tuple] = None
) -> Union[torch.Tensor, ) -> Union[torch.Tensor,
tuple[torch.Tensor, Optional[Parameter]], tuple[torch.Tensor, Optional[Parameter]],
tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[Parameter]],
]: ]:
if envs.USE_FUSED_RMS_QUANT and rms_weight is not None: if envs.USE_FUSED_RMS_QUANT and iqis is not None:
input_quant_args = None
assert residual is not None and rms_weight is not None
i_q, _scales = lm_faster_rmsquant(input=input_,
rms_weight=rms_weight,
epsilon=self.eps,
quant_dtype=torch.int8,
residual=residual,
update_input=update_hd)
new_residual = residual
input_quant_args = [i_q, _scales]
bias = self.bias if not self.skip_bias_add else None bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None assert self.quant_method is not None
output_parallel = self.quant_method.apply(self, input_, bias, input_quant_args) output_parallel = self.quant_method.apply(self, input_, bias, input_quant_args=iqis)
if self.gather_output: if self.gather_output:
# All-gather across the partitions. # All-gather across the partitions.
...@@ -892,7 +840,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -892,7 +840,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
output_bias = self.bias if self.skip_bias_add else None output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias: if not self.return_bias:
return output return output
return output, new_residual, i_q, _scales, output_bias return output, output_bias
elif envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and xqxs is not None: elif envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and xqxs is not None:
bias = self.bias if not self.skip_bias_add else None bias = self.bias if not self.skip_bias_add else None
...@@ -933,13 +881,12 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -933,13 +881,12 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
skip_bias_add: bool = False, skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
eps: Optional[float] = 1e-6,
prefix: str = "", prefix: str = "",
*, *,
return_bias: bool = True, return_bias: bool = True,
expect_tp_size: Optional[int] = None, expect_tp_size: Optional[int] = None,
enable_dp_attn_moe: bool = False,
): ):
self.eps = eps
self.output_sizes = output_sizes self.output_sizes = output_sizes
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
...@@ -949,6 +896,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -949,6 +896,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
self.expect_tp_size = expect_tp_size self.expect_tp_size = expect_tp_size
self.enable_dp_attn_moe = enable_dp_attn_moe
if enable_dp_attn_moe:
tp_size = get_moe_tp_size()
assert all(output_size % tp_size == 0 for output_size in output_sizes) assert all(output_size % tp_size == 0 for output_size in output_sizes)
super().__init__(input_size=input_size, super().__init__(input_size=input_size,
output_size=sum(output_sizes), output_size=sum(output_sizes),
...@@ -959,7 +910,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -959,7 +910,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
quant_config=quant_config, quant_config=quant_config,
prefix=prefix, prefix=prefix,
return_bias=return_bias, return_bias=return_bias,
expect_tp_size=expect_tp_size) expect_tp_size=expect_tp_size,
enable_dp_attn_moe=enable_dp_attn_moe)
def weight_loader(self, def weight_loader(self,
param: Parameter, param: Parameter,
...@@ -1061,6 +1013,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -1061,6 +1013,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
tp_rank = 0 tp_rank = 0
tp_size = 1 tp_size = 1
if self.enable_dp_attn_moe:
tp_rank = get_moe_tp_rank()
tp_size = get_moe_tp_size()
if output_dim is not None: if output_dim is not None:
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
shard_size = self.output_sizes[loaded_shard_id] // tp_size shard_size = self.output_sizes[loaded_shard_id] // tp_size
...@@ -1182,6 +1138,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -1182,6 +1138,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
if hasattr(param, "expect_tp_size"): if hasattr(param, "expect_tp_size"):
param.expect_tp_size = self.expect_tp_size param.expect_tp_size = self.expect_tp_size
if self.enable_dp_attn_moe and hasattr(param, "enable_dp_attn_moe"):
tp_size = get_moe_tp_size()
param.enable_dp_attn_moe = self.enable_dp_attn_moe
if isinstance(param, BlockQuantScaleParameter): if isinstance(param, BlockQuantScaleParameter):
from vllm.model_executor.layers.quantization.fp8 import ( from vllm.model_executor.layers.quantization.fp8 import (
Fp8LinearMethod, Fp8MoEMethod) Fp8LinearMethod, Fp8MoEMethod)
...@@ -1613,6 +1573,7 @@ class RowParallelLinear(LinearBase): ...@@ -1613,6 +1573,7 @@ class RowParallelLinear(LinearBase):
*, *,
return_bias: bool = True, return_bias: bool = True,
expect_tp_size: Optional[int] = None, expect_tp_size: Optional[int] = None,
enable_dp_attn_moe: bool = False,
): ):
# Divide the weight matrix along the first dimension. # Divide the weight matrix along the first dimension.
self.tp_rank = get_tensor_model_parallel_rank() self.tp_rank = get_tensor_model_parallel_rank()
...@@ -1621,7 +1582,13 @@ class RowParallelLinear(LinearBase): ...@@ -1621,7 +1582,13 @@ class RowParallelLinear(LinearBase):
if expect_tp_size is not None: if expect_tp_size is not None:
self.tp_rank = 0 self.tp_rank = 0
self.tp_size = 1 self.tp_size = 1
self.expect_tp_size = expect_tp_size self.expect_tp_size = expect_tp_size
self.enable_dp_attn_moe = enable_dp_attn_moe
if enable_dp_attn_moe:
self.tp_rank = get_moe_tp_rank()
self.tp_size = get_moe_tp_size()
self.input_size_per_partition = divide(input_size, self.tp_size) self.input_size_per_partition = divide(input_size, self.tp_size)
self.output_size_per_partition = output_size self.output_size_per_partition = output_size
self.output_partition_sizes = [output_size] self.output_partition_sizes = [output_size]
...@@ -1671,6 +1638,11 @@ class RowParallelLinear(LinearBase): ...@@ -1671,6 +1638,11 @@ class RowParallelLinear(LinearBase):
if self.expect_tp_size is not None: if self.expect_tp_size is not None:
tp_rank = 0 tp_rank = 0
tp_size = 1 tp_size = 1
if self.enable_dp_attn_moe:
tp_rank = get_moe_tp_rank()
tp_size = get_moe_tp_size()
input_dim = getattr(param, "input_dim", None) input_dim = getattr(param, "input_dim", None)
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False) use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
is_sharded_weight = getattr(param, "is_sharded_weight", False) is_sharded_weight = getattr(param, "is_sharded_weight", False)
...@@ -1725,6 +1697,9 @@ class RowParallelLinear(LinearBase): ...@@ -1725,6 +1697,9 @@ class RowParallelLinear(LinearBase):
if self.expect_tp_size is not None and hasattr(param, "expect_tp_size"): if self.expect_tp_size is not None and hasattr(param, "expect_tp_size"):
param.expect_tp_size = self.expect_tp_size param.expect_tp_size = self.expect_tp_size
if self.enable_dp_attn_moe is not None and hasattr(param, "enable_dp_attn_moe"):
param.enable_dp_attn_moe = self.enable_dp_attn_moe
param.load_row_parallel_weight(loaded_weight=loaded_weight) param.load_row_parallel_weight(loaded_weight=loaded_weight)
def forward( def forward(
......
...@@ -654,6 +654,12 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -654,6 +654,12 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
expert_load_view: Optional[torch.Tensor] = None, expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
shared_output: Optional[torch.Tensor] = None,
use_nn_moe=False,
use_fused_gate: Optional[bool] = False,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None,
**_,
) -> torch.Tensor: ) -> torch.Tensor:
if enable_eplb: if enable_eplb:
raise NotImplementedError( raise NotImplementedError(
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, Optional from typing import Callable, Optional
from vllm import envs
import torch import torch
from compressed_tensors.quantization import QuantizationStrategy from compressed_tensors.quantization import QuantizationStrategy
from torch.nn import Parameter from torch.nn import Parameter
...@@ -61,6 +61,8 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): ...@@ -61,6 +61,8 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
# If channelwise, scales are already lined up, so just transpose. # If channelwise, scales are already lined up, so just transpose.
elif self.strategy == QuantizationStrategy.CHANNEL: elif self.strategy == QuantizationStrategy.CHANNEL:
weight = layer.weight weight = layer.weight
if envs.VLLM_W8A8_BACKEND == 3:
weight = weight.t()
if current_platform.is_fp8_fnuz(): if current_platform.is_fp8_fnuz():
input_scale = getattr(layer, 'input_scale', None) input_scale = getattr(layer, 'input_scale', None)
...@@ -140,7 +142,9 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): ...@@ -140,7 +142,9 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
def apply_weights(self, def apply_weights(self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor: bias: Optional[torch.Tensor] = None,input_quant_args: Optional[list[torch.Tensor]] = None,
silu_quant_args: Optional[list[torch.Tensor]] = None, **_,
) -> torch.Tensor:
return self.fp8_linear.apply(input=x, return self.fp8_linear.apply(input=x,
weight=layer.weight, weight=layer.weight,
......
...@@ -20,6 +20,7 @@ from vllm.model_executor.parameter import (BasevLLMParameter, ...@@ -20,6 +20,7 @@ from vllm.model_executor.parameter import (BasevLLMParameter,
PerTensorScaleParameter) PerTensorScaleParameter)
from vllm.utils import W8a8GetCacheJSON from vllm.utils import W8a8GetCacheJSON
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
import vllm.envs as envs
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -31,8 +32,8 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme): ...@@ -31,8 +32,8 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
input_symmetric: bool): input_symmetric: bool):
self.strategy = strategy self.strategy = strategy
self.is_static_input_scheme = is_static_input_scheme self.is_static_input_scheme = is_static_input_scheme
self.tritonsingleton= W8a8GetCacheJSON() self.tritonsingleton = W8a8GetCacheJSON()
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1')) self.w8a8_strategy = envs.VLLM_W8A8_BACKEND
self.input_symmetric = input_symmetric self.input_symmetric = input_symmetric
@classmethod @classmethod
......
...@@ -331,6 +331,10 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -331,6 +331,10 @@ class Fp8LinearMethod(LinearMethodBase):
weight = layer.weight.data weight = layer.weight.data
weight_scale_inv = layer.weight_scale_inv.data weight_scale_inv = layer.weight_scale_inv.data
if envs.VLLM_W8A8_BACKEND == 3:
weight = weight.T.contiguous()
weight_scale_inv = weight_scale_inv.T.contiguous()
else:
weight = self._maybe_pad_weight(weight) weight = self._maybe_pad_weight(weight)
# Torch.compile cannot use Parameter subclasses. # Torch.compile cannot use Parameter subclasses.
...@@ -854,10 +858,11 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -854,10 +858,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
use_fused_gate: Optional[bool] = False,
enable_eplb: bool = False, enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None, expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,**_,
) -> torch.Tensor: ) -> torch.Tensor:
if enable_eplb: if enable_eplb:
assert expert_load_view is not None assert expert_load_view is not None
...@@ -882,6 +887,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -882,6 +887,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
expert_load_view=expert_load_view, expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map, logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count, logical_replica_count=logical_replica_count,
use_fused_gate=use_fused_gate,
) )
if self.rocm_aiter_moe_enabled: if self.rocm_aiter_moe_enabled:
......
...@@ -92,8 +92,8 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase): ...@@ -92,8 +92,8 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
def __init__(self, quantization_config: SlimQuantW4A8Int8Config): def __init__(self, quantization_config: SlimQuantW4A8Int8Config):
self.quantization_config = quantization_config self.quantization_config = quantization_config
self.tritonsingleton= W8a8GetCacheJSON() self.tritonsingleton = W8a8GetCacheJSON()
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1')) self.w8a8_strategy = envs.VLLM_W8A8_BACKEND
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
n=layer.weight.shape[0] n=layer.weight.shape[0]
......
...@@ -6,6 +6,8 @@ import functools ...@@ -6,6 +6,8 @@ import functools
import json import json
import os import os
from typing import Any, Callable, Optional, Union, List from typing import Any, Callable, Optional, Union, List
from lmslim import quant_ops
from lmslim.quantize.quant_ops import BlockSize
import torch import torch
...@@ -19,6 +21,10 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( ...@@ -19,6 +21,10 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils import cdiv, direct_register_custom_op, has_deep_gemm from vllm.utils import cdiv, direct_register_custom_op, has_deep_gemm
try:
from lmslim.layers.gemm.fp8_utils import per_token_group_quant_fp8,w8a8_block_fp8_matmul
except Exception:
print("INFO: Please updata lmslim if you want to use fp8_utils.\n")
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -83,7 +89,7 @@ if current_platform.is_rocm(): ...@@ -83,7 +89,7 @@ if current_platform.is_rocm():
def dispatch_w8a8_blockscale_func( def dispatch_w8a8_blockscale_func(
use_cutlass: bool, use_aiter_and_is_supported: bool use_cutlass: bool, use_aiter_and_is_supported: bool, use_blaslt: bool
) -> Callable[[ ) -> Callable[[
torch.Tensor, torch.Tensor,
torch.Tensor, torch.Tensor,
...@@ -96,6 +102,9 @@ def dispatch_w8a8_blockscale_func( ...@@ -96,6 +102,9 @@ def dispatch_w8a8_blockscale_func(
return cutlass_scaled_mm return cutlass_scaled_mm
if (use_aiter_and_is_supported): if (use_aiter_and_is_supported):
return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale
if use_blaslt:
return hipblaslt_w8a8_block_fp8_matmul
return w8a8_block_fp8_matmul return w8a8_block_fp8_matmul
...@@ -127,6 +136,10 @@ def apply_w8a8_block_fp8_linear( ...@@ -127,6 +136,10 @@ def apply_w8a8_block_fp8_linear(
assert input_scale is None assert input_scale is None
# View input as 2D matrix for fp8 methods # View input as 2D matrix for fp8 methods
input_2d = input.view(-1, input.shape[-1]) input_2d = input.view(-1, input.shape[-1])
output_shape = []
if envs.VLLM_W8A8_BACKEND == 3:
output_shape = [*input.shape[:-1], weight.shape[-1]]
else:
output_shape = [*input.shape[:-1], weight.shape[0]] output_shape = [*input.shape[:-1], weight.shape[0]]
output_dtype = input.dtype output_dtype = input.dtype
...@@ -166,9 +179,12 @@ def apply_w8a8_block_fp8_linear( ...@@ -166,9 +179,12 @@ def apply_w8a8_block_fp8_linear(
weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0) weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0)
else: else:
use_cutlass = False use_cutlass = False
use_blaslt = False
if envs.VLLM_W8A8_BACKEND == 3:
use_blaslt = True
w8a8_blockscale_func = dispatch_w8a8_blockscale_func( w8a8_blockscale_func = dispatch_w8a8_blockscale_func(
use_cutlass, use_aiter_and_is_supported) use_cutlass, use_aiter_and_is_supported, use_blaslt)
if use_cutlass: if use_cutlass:
q_input, x_scale = per_token_group_quant_fp8( q_input, x_scale = per_token_group_quant_fp8(
input_2d, block_size[1], column_major_scales=use_cutlass) input_2d, block_size[1], column_major_scales=use_cutlass)
...@@ -197,6 +213,10 @@ def apply_w8a8_block_fp8_linear_fake( ...@@ -197,6 +213,10 @@ def apply_w8a8_block_fp8_linear_fake(
cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED, cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
use_aiter_and_is_supported: bool = False, use_aiter_and_is_supported: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
output_shape = []
if envs.VLLM_W8A8_BACKEND == 3:
output_shape = [*input.shape[:-1], weight.shape[-1]]
else:
output_shape = [*input.shape[:-1], weight.shape[0]] output_shape = [*input.shape[:-1], weight.shape[0]]
return torch.empty(output_shape, dtype=input.dtype, device=input.device) return torch.empty(output_shape, dtype=input.dtype, device=input.device)
...@@ -240,333 +260,9 @@ def block_quant_to_tensor_quant( ...@@ -240,333 +260,9 @@ def block_quant_to_tensor_quant(
return x_q_tensor, scale return x_q_tensor, scale
@triton.jit
def _per_token_group_quant_fp8(
# Pointers to inputs and output
y_ptr,
y_q_ptr,
y_s_ptr,
group_size,
# Num columns of y
y_num_columns,
y_row_stride,
# Avoid to divide zero
eps,
# Information for float8
fp8_min,
fp8_max,
# Meta-parameters
BLOCK: tl.constexpr,
):
"""A Triton-accelerated function to perform per-token-group
quantization on a tensor.
This function converts the tensor values into float8 values.
"""
groups_per_row = y_num_columns // group_size
# Map the program id to the row of X and Y it should compute.
g_id = tl.program_id(0)
row = g_id // groups_per_row
row_g_id = g_id % groups_per_row
# Ensure offset calculations use int64 to prevent overflow
y_ptr_offset = (row.to(tl.int64) * y_row_stride) + (row_g_id.to(tl.int64) *
group_size)
y_ptr += y_ptr_offset
y_q_ptr_offset = g_id.to(tl.int64) * group_size
y_q_ptr += y_q_ptr_offset
y_s_ptr += g_id
cols = tl.arange(0, BLOCK) # N <= BLOCK
mask = cols < group_size
y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
# Quant
_absmax = tl.maximum(tl.max(tl.abs(y)), eps)
y_s = _absmax / fp8_max
y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
tl.store(y_q_ptr + cols, y_q, mask=mask)
tl.store(y_s_ptr, y_s)
@triton.jit
def _per_token_group_quant_fp8_colmajor(
# Pointers to inputs and output
y_ptr,
y_q_ptr,
y_s_ptr,
group_size,
# Num columns of y
y_num_columns,
y_row_stride,
# Stride from one column to the next of y_s
y_s_col_stride,
# Avoid to divide zero
eps,
# Information for float8
fp8_min,
fp8_max,
# Meta-parameters
BLOCK: tl.constexpr,
):
"""A Triton-accelerated function to perform per-token-group
quantization on a tensor.
This function converts the tensor values into float8 values.
"""
groups_per_row = y_num_columns // group_size
# Map the program id to the row of X and Y it should compute.
g_id = tl.program_id(0)
row = g_id // groups_per_row
row_g_id = g_id % groups_per_row
# Ensure offset calculations use int64 to prevent overflow
y_ptr_offset = (row.to(tl.int64) * y_row_stride) + (row_g_id.to(tl.int64) *
group_size)
y_ptr += y_ptr_offset
y_q_ptr_offset = g_id.to(tl.int64) * group_size
y_q_ptr += y_q_ptr_offset
# Convert g_id the flattened block coordinate to 2D so we can index
# into the output y_scales matrix
blocks_per_row = y_num_columns // group_size
scale_col = g_id % blocks_per_row
scale_row = g_id // blocks_per_row
# Ensure offset calculation uses int64 for y_s_ptr
y_s_ptr_offset = (scale_col.to(tl.int64) * y_s_col_stride) + scale_row.to(
tl.int64)
y_s_ptr += y_s_ptr_offset
cols = tl.arange(0, BLOCK) # group_size <= BLOCK
mask = cols < group_size
y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
# Quant
_absmax = tl.maximum(tl.max(tl.abs(y)), eps)
y_s = _absmax / fp8_max
y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
tl.store(y_q_ptr + cols, y_q, mask=mask)
tl.store(y_s_ptr, y_s)
def per_token_group_quant_fp8(
x: torch.Tensor,
group_size: int,
eps: float = 1e-10,
dtype: Optional[torch.dtype] = None,
column_major_scales: bool = False,
out_q: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Function to perform per-token-group quantization on an input tensor `x`.
It converts the tensor values into signed float8 values and returns the
quantized tensor along with the scaling factor used for quantization.
Args:
x: The input tensor with ndim >= 2.
group_size: The group size used for quantization.
eps: The minimum to avoid dividing zero.
dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn`
is supported for now.
column_major_scales: Outputs scales in column major.
out_q: Optional output tensor. If not provided, function will create.
Returns:
tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
scaling factor for quantization.
"""
dtype = current_platform.fp8_dtype() if dtype is None else dtype
assert (x.shape[-1] % group_size == 0), (
f"the last dimension of `x` {x.shape[-1]} must be divisible "
f"by `group_size` {group_size}")
assert x.stride(-1) == 1, "`x` groups must be contiguous"
finfo = torch.finfo(dtype)
fp8_min = finfo.min
fp8_max = finfo.max
assert out_q is None or out_q.shape == x.shape
x_q = out_q
if x_q is None:
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
M = x.numel() // group_size
N = group_size
if column_major_scales:
shape = (x.shape[-1] // group_size, ) + x.shape[:-1]
x_s = torch.empty(shape, device=x.device,
dtype=torch.float32).permute(-1, -2)
else:
shape = x.shape[:-1] + (x.shape[-1] // group_size, )
x_s = torch.empty(shape, device=x.device, dtype=torch.float32)
BLOCK = triton.next_power_of_2(N)
# heuristics for number of warps
num_warps = min(max(BLOCK // 256, 1), 8)
num_stages = 1
if column_major_scales:
_per_token_group_quant_fp8_colmajor[(M, )](
x,
x_q,
x_s,
group_size,
x.shape[1],
x.stride(0),
x_s.stride(1),
eps,
fp8_min=fp8_min,
fp8_max=fp8_max,
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=num_stages,
)
else:
_per_token_group_quant_fp8[(M, )](
x,
x_q,
x_s,
group_size,
x.shape[1],
x.stride(0),
eps,
fp8_min=fp8_min,
fp8_max=fp8_max,
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=num_stages,
)
return x_q, x_s
@triton.jit
def _w8a8_block_fp8_matmul(
# Pointers to inputs and output
A,
B,
C,
As,
Bs,
# Shape for matmul
M,
N,
K,
# Block size for block-wise quantization
group_n,
group_k,
# Stride for inputs and output
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_As_m,
stride_As_k,
stride_Bs_k,
stride_Bs_n,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
):
"""Triton-accelerated function used to perform linear operations (dot
product) on input tensors `A` and `B` with block-wise quantization, and
store the result in output tensor `C`.
"""
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
As_ptrs = As + offs_am * stride_As_m
offs_bsn = offs_bn // group_n
Bs_ptrs = Bs + offs_bsn * stride_Bs_n
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
a = tl.load(a_ptrs,
mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
other=0.0)
b = tl.load(b_ptrs,
mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
other=0.0)
k_start = k * BLOCK_SIZE_K
offs_ks = k_start // group_k
a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
if C.dtype.element_ty == tl.bfloat16:
c = accumulator.to(tl.bfloat16)
elif C.dtype.element_ty == tl.float16:
c = accumulator.to(tl.float16)
else:
c = accumulator.to(tl.float32)
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)
@functools.lru_cache
def get_w8a8_block_fp8_configs(N: int, K: int, block_n: int,
block_k: int) -> Optional[dict[int, Any]]:
"""
Return optimized configurations for the w8a8 block fp8 kernel.
The return value will be a dictionary that maps an irregular grid of
batch sizes to configurations of the w8a8 block fp8 kernel. To evaluate the
kernel on a given batch size bs, the closest batch size in the grid should
be picked and the associated configuration chosen to invoke the kernel.
"""
# First look up if an optimized configuration is available in the configs
# directory
device_name = current_platform.get_device_name().replace(" ", "_")
json_file_name = f"N={N},K={K},device_name={device_name},dtype=fp8_w8a8,block_shape=[{block_n},{block_k}].json" # noqa: E501
config_file_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
if os.path.exists(config_file_path):
with open(config_file_path) as f:
logger.info(
"Using configuration from %s for W8A8 Block FP8 kernel.",
config_file_path,
)
# If a configuration has been found, return it
return {int(key): val for key, val in json.load(f).items()}
# If no optimized configuration is available, we will use the default
# configuration
logger.warning(
"Using default W8A8 Block FP8 kernel config. Performance might "
"be sub-optimal! Config file not found at %s",
config_file_path,
)
return None
def hipblaslt_w8a8_block_fp8_matmul(
def w8a8_block_fp8_matmul(
A: torch.Tensor, A: torch.Tensor,
B: torch.Tensor, B: torch.Tensor,
As: torch.Tensor, As: torch.Tensor,
...@@ -574,80 +270,19 @@ def w8a8_block_fp8_matmul( ...@@ -574,80 +270,19 @@ def w8a8_block_fp8_matmul(
block_size: list[int], block_size: list[int],
output_dtype: torch.dtype = torch.float16, output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor: ) -> torch.Tensor:
"""This function performs matrix multiplication with block-wise m, k = A.shape
quantization. _, n = B.shape
It takes two input tensors `A` and `B` with scales `As` and `Bs`. enum_block_size = BlockSize.block_128x128
The output is returned in the specified `output_dtype`. if block_size[0] == 64:
Args: enum_block_size = BlockSize.block_64x64
A: The input tensor, e.g., activation. elif block_size[0] == 128:
B: The input tensor, e.g., weight. enum_block_size = BlockSize.block_128x128
As: The per-token-group quantization scale for `A`.
Bs: The per-block quantization scale for `B`.
block_size: The block size for per-block quantization. It should
be 2-dim, e.g., [128, 128].
output_dytpe: The dtype of the returned tensor.
Returns:
torch.Tensor: The result of matmul.
"""
assert len(block_size) == 2
block_n, block_k = block_size[0], block_size[1]
assert A.shape[-1] == B.shape[-1]
assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
M = A.numel() // A.shape[-1]
assert B.ndim == 2 and Bs.ndim == 2
N, K = B.shape
assert triton.cdiv(N, block_n) == Bs.shape[0]
assert triton.cdiv(K, block_k) == Bs.shape[1]
C_shape = A.shape[:-1] + (N, )
C = A.new_empty(C_shape, dtype=output_dtype)
configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1])
if configs:
# Get the optimal config if there is one
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
else: else:
# Default config print(f"[WARN] Unsupported block_size: {block_size}. Falling back to BlockSize.block_128x128")
# Block-wise quant: BLOCK_SIZE_N must be divisible by block_size[0]
# BLOCK_SIZE_K must be divisible by block_size[1] _, d = quant_ops.hipblaslt_w8a8_blockwise_gemm(A, B, As, Bs,
config = { m, n, k, 'NN', output_dtype,
"BLOCK_SIZE_M": 64, enum_block_size, None)
"BLOCK_SIZE_N": block_size[0], return d
"BLOCK_SIZE_K": block_size[1],
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2,
}
def grid(META):
return (triton.cdiv(M, META["BLOCK_SIZE_M"]) *
triton.cdiv(N, META["BLOCK_SIZE_N"]), )
_w8a8_block_fp8_matmul[grid](
A,
B,
C,
As,
Bs,
M,
N,
K,
block_n,
block_k,
A.stride(-2),
A.stride(-1),
B.stride(1),
B.stride(0),
C.stride(-2),
C.stride(-1),
As.stride(-2),
As.stride(-1),
Bs.stride(1),
Bs.stride(0),
**config,
)
return C
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment