Commit 2b7b1a31 authored by zhuwenwen's avatar zhuwenwen
Browse files

[feat] 适配block和channel fp8, 添加dp attention功能

[pd] 修改非堆成切分的判断, pd分离使用默认调度,set VLLM_USE_PD_SPLIT=1
[perf](fused-moe): 预打包 Marlin W16A16 MoE 权重,降低 warmup 显存峰值
[fix]修复丢弃MTP代码报错,pp+chunkprefill多并发input ids更新bug, 修复不开启融合图的断言错误, PP 场景 decode 阶段 token 被误丢弃导致卡住
parent 2eda94c6
......@@ -14,6 +14,7 @@ from vllm.utils import direct_register_custom_op
try:
from lmslim import quant_ops
from lmslim import quant_tools
from lmslim.layers.gemm.fp8_utils import per_token_quant_fp8
except Exception:
print("INFO: Please install lmslim if you want to infer gptq or awq or w8a8 model.\n")
try:
......@@ -1691,67 +1692,67 @@ def scaled_fp4_experts_quant(
return output, output_scales
# fp8
# def scaled_fp8_quant(
# input: torch.Tensor,
# scale: Optional[torch.Tensor] = None,
# num_token_padding: Optional[int] = None,
# scale_ub: Optional[torch.Tensor] = None,
# use_per_token_if_dynamic: bool = False,
# output: Optional[torch.Tensor] = None,
# ) -> tuple[torch.Tensor, torch.Tensor]:
# """
# Quantize input tensor to FP8 and return quantized tensor and scale.
# This function supports both static and dynamic quantization: If you
# provide the scale, it will use static scaling and if you omit it,
# the scale will be determined dynamically. The function also allows
# optional padding of the output tensors for downstream kernels that
# will benefit from padding.
# Args:
# input: The input tensor to be quantized to FP8
# scale: Optional scaling factor for the FP8 quantization
# scale_ub: Optional upper bound for scaling factor in dynamic
# per token case
# num_token_padding: If specified, pad the first dimension
# of the output to at least this value.
# use_per_token_if_dynamic: Whether to do per_tensor or per_token
# in the dynamic quantization case.
# Returns:
# tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
# scaling factor.
# """
# # This code assumes batch_dim and num_tokens are flattened
# assert (input.ndim == 2)
# shape: Union[tuple[int, int], torch.Size] = input.shape
# # For ROCm on MI300, the output fp8 dtype is torch.float_e3m3fnuz
# out_dtype: torch.dtype = current_platform.fp8_dtype()
# if num_token_padding:
# shape = (max(num_token_padding, input.shape[0]), shape[1])
# if output is None:
# output = torch.empty(shape, device=input.device, dtype=out_dtype)
# else:
# assert num_token_padding is None, \
# "padding not supported if output passed in"
# assert output.dtype == out_dtype
# if scale is None:
# if use_per_token_if_dynamic:
# scale = torch.empty((shape[0], 1),
# device=input.device,
# dtype=torch.float32)
# torch.ops._C.dynamic_per_token_scaled_fp8_quant(
# output, input.contiguous(), scale, scale_ub)
# else:
# scale = torch.zeros(1, device=input.device, dtype=torch.float32)
# torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
# else:
# assert scale.numel() == 1, f"{scale.shape}"
# torch.ops._C.static_scaled_fp8_quant(output, input, scale)
# return output, scale
def scaled_fp8_quant(
input: torch.Tensor,
scale: Optional[torch.Tensor] = None,
num_token_padding: Optional[int] = None,
scale_ub: Optional[torch.Tensor] = None,
use_per_token_if_dynamic: bool = False,
output: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Quantize input tensor to FP8 and return quantized tensor and scale.
This function supports both static and dynamic quantization: If you
provide the scale, it will use static scaling and if you omit it,
the scale will be determined dynamically. The function also allows
optional padding of the output tensors for downstream kernels that
will benefit from padding.
Args:
input: The input tensor to be quantized to FP8
scale: Optional scaling factor for the FP8 quantization
scale_ub: Optional upper bound for scaling factor in dynamic
per token case
num_token_padding: If specified, pad the first dimension
of the output to at least this value.
use_per_token_if_dynamic: Whether to do per_tensor or per_token
in the dynamic quantization case.
Returns:
tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
scaling factor.
"""
# This code assumes batch_dim and num_tokens are flattened
assert (input.ndim == 2)
shape: Union[tuple[int, int], torch.Size] = input.shape
# For ROCm on MI300, the output fp8 dtype is torch.float_e3m3fnuz
out_dtype: torch.dtype = current_platform.fp8_dtype()
if num_token_padding:
shape = (max(num_token_padding, input.shape[0]), shape[1])
if output is None:
output = torch.empty(shape, device=input.device, dtype=out_dtype)
else:
assert num_token_padding is None, \
"padding not supported if output passed in"
assert output.dtype == out_dtype
if scale is None:
if use_per_token_if_dynamic:
scale = torch.empty((shape[0], 1),
device=input.device,
dtype=torch.float32)
# torch.ops._C.dynamic_per_token_scaled_fp8_quant(
# output, input.contiguous(), scale, scale_ub)
output, scale = per_token_quant_fp8(input.contiguous())
else:
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
else:
assert scale.numel() == 1, f"{scale.shape}"
torch.ops._C.static_scaled_fp8_quant(output, input, scale)
return output, scale
# gptq allspark
......
......@@ -1883,6 +1883,9 @@ class ParallelConfig:
""" Use data parallelism instead of tensor parallelism for vision encoder.
Only support LLama4 for now"""
enable_dp_attention: bool = False
"""Enable dp attention"""
@property
def world_size_across_dp(self) -> int:
"""world_size_across_dp is TPxPPxDP, it is the size of the world
......@@ -2109,6 +2112,9 @@ class ParallelConfig:
raise ValueError("Unable to use nsight profiling unless workers "
"run with Ray.")
if self.enable_dp_attention and self.enable_expert_parallel:
raise ValueError("Dp attention and expert parallel can not enable together.")
return self
......@@ -4805,6 +4811,7 @@ class VllmConfig:
dp_size = self.parallel_config.data_parallel_size
tp_size = self.parallel_config.tensor_parallel_size
ep_sp = self.parallel_config.enable_expert_parallel and dp_size > 1 and tp_size > 1
enable_dp_attention = self.parallel_config.enable_dp_attention
# add for spec decode
if self.speculative_config is not None and self.speculative_config.num_lookahead_slots > 0:
......@@ -4813,11 +4820,15 @@ class VllmConfig:
batch_size_capture_list = sorted(set(batch_size_capture_list + mtp_batch_size_capture_list))
batch_size_capture_list = [i for i in batch_size_capture_list if i == 1 or i % (1 + self.speculative_config.num_lookahead_slots) == 0]
if ep_sp:
if ep_sp or enable_dp_attention:
batch_size_capture_list = sorted(set([round_up(i, tp_size) for i in batch_size_capture_list]))
if 1 not in batch_size_capture_list:
batch_size_capture_list.insert(0, 1)
else:
if ep_sp:
if ep_sp or enable_dp_attention:
batch_size_capture_list = sorted(set([round_up(i, tp_size) for i in batch_size_capture_list]))
if 1 not in batch_size_capture_list:
batch_size_capture_list.insert(0, 1)
self.compilation_config.init_with_cudagraph_sizes(
batch_size_capture_list)
......
......@@ -120,9 +120,11 @@ class P2pNcclEngine:
self.enable_asymmetric_p2p = self.config.get_from_extra_config(
"enable_asymmetric_p2p", False)
if self.enable_asymmetric_p2p == True:
if self.remote_tp_size % self.tp_size != 0:
logger.error(" the Prefill TP size must be less than or equal to the Decode TP size!!!!")
self.multp = int(self.remote_tp_size / self.tp_size)
self.multp = int(self.remote_tp_size / self.tp_size)
self.multiple_machines = self.config.get_from_extra_config(
"enable_multiple_machines", False)
port = int(self.config.kv_port) + port_offset
......@@ -742,7 +744,7 @@ class P2pNcclEngine:
"pd_pair_id": remote_address.pd_pair_id,
"comm_rank": rank
}
logger.info(f"""_send_sync_new:{data}""")
# logger.info(f"""_send_sync_new:{data}""")
sock.send(msgpack.dumps(data))
response = sock.recv()
......
......@@ -477,6 +477,9 @@ class EngineArgs:
enable_multimodal_encoder_data_parallel: bool = \
ParallelConfig.enable_multimodal_encoder_data_parallel
enable_dp_attention: bool = \
ParallelConfig.enable_dp_attention
def __post_init__(self):
# support `EngineArgs(compilation_config={...})`
# without having to manually construct a
......@@ -718,6 +721,9 @@ class EngineArgs:
parallel_group.add_argument(
"--enable-multimodal-encoder-data-parallel",
**parallel_kwargs["enable_multimodal_encoder_data_parallel"])
parallel_group.add_argument(
"--enable-dp-attention",
**parallel_kwargs["enable_dp_attention"])
# KV cache arguments
cache_kwargs = get_kwargs(CacheConfig)
......@@ -1204,6 +1210,7 @@ class EngineArgs:
worker_extension_cls=self.worker_extension_cls,
enable_multimodal_encoder_data_parallel=self.
enable_multimodal_encoder_data_parallel,
enable_dp_attention=self.enable_dp_attention,
)
speculative_config = self.create_speculative_config(
......
......@@ -1223,7 +1223,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# vLLM will split prefill and decode, not mix up
"VLLM_USE_PD_SPLIT":
lambda: (os.environ.get("VLLM_USE_PD_SPLIT", "False").lower() in
lambda: (os.environ.get("VLLM_USE_PD_SPLIT", "True").lower() in
("true", "1")),
# vLLM will sync to avoid pp vmfault
......
from typing import TYPE_CHECKING, List, Optional, Tuple
import logging
import torch
import vllm.envs as envs
from vllm.distributed.parallel_state import GroupCoordinator, init_model_parallel_group, get_world_group
from vllm.distributed import (get_ep_group, get_pp_group, get_dp_group,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
get_tensor_model_parallel_rank,
tensor_model_parallel_reduce_scatter,
get_tp_group)
_ENABLE_DP_ATTENTION_FLAG: bool = False
_MOE_TP: Optional[GroupCoordinator] = None
_ATTN_DP_SIZE = 0
_ATTN_TP_SIZE = 0
_ATTN_TP_RANK = 0
_ATTN_DP_RANK = 0
_MOT_TP_SIZE = 0
_MOT_TP_RANK = 0
def initialize_dp_attention(vllm_config, backend: Optional[str] = None):
from vllm.config import VllmConfig
assert isinstance(vllm_config, VllmConfig)
global _ENABLE_DP_ATTENTION_FLAG, _ATTN_DP_SIZE, _ATTN_TP_SIZE, _ATTN_TP_RANK, _ATTN_DP_RANK, _MOT_TP_SIZE, _MOT_TP_RANK
enable_dp_attention = vllm_config.parallel_config.enable_dp_attention
_ENABLE_DP_ATTENTION_FLAG = enable_dp_attention
# Build the moe tensor model-parallel groups.
world_size: int = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
data_parallel_size = vllm_config.parallel_config.data_parallel_size
pipeline_model_parallel_size = vllm_config.parallel_config.pipeline_parallel_size
tensor_model_parallel_size = vllm_config.parallel_config.tensor_parallel_size
moe_tp_size = world_size // pipeline_model_parallel_size
moe_ep_size = moe_tp_size if vllm_config.parallel_config.enable_expert_parallel else 1
_ATTN_DP_SIZE = data_parallel_size
_ATTN_TP_SIZE = tensor_model_parallel_size
_ATTN_TP_RANK = get_tensor_model_parallel_rank()
_ATTN_DP_RANK = vllm_config.parallel_config.data_parallel_rank
_MOT_TP_SIZE = moe_tp_size
_MOT_TP_RANK = rank % _MOT_TP_SIZE
global _MOE_TP
assert _MOE_TP is None, ("moe tensor model parallel group is already initialized")
backend = backend or torch.distributed.get_backend(
get_world_group().device_group)
group_ranks = []
for i in range(pipeline_model_parallel_size):
ranks = list(
range(i * moe_tp_size, (i + 1) * moe_tp_size)
)
group_ranks.append(ranks)
# message queue broadcaster is only used in tensor model parallel group
_MOE_TP = init_model_parallel_group(group_ranks,
get_world_group().local_rank,
backend,
use_message_queue_broadcaster=True,
group_name="moe_tp")
def get_attention_tp_size() -> int:
assert _ATTN_TP_SIZE is not None, "dp attention not initialized!"
return _ATTN_TP_SIZE
def get_attention_tp_rank() -> int:
assert _ATTN_TP_RANK is not None, "dp attention not initialized!"
return _ATTN_TP_RANK
def get_moe_tp_group() -> GroupCoordinator:
assert _MOE_TP is not None, ("tensor model parallel group is not initialized")
return _MOE_TP
def get_attention_dp_size() -> int:
assert _ATTN_DP_SIZE is not None, "dp attention not initialized!"
return _ATTN_DP_SIZE
def get_moe_tp_rank() -> int:
assert _MOT_TP_RANK is not None, "dp attention not initialized!"
return _MOT_TP_RANK
def get_moe_tp_size() -> int:
assert _MOT_TP_SIZE is not None, "dp attention not initialized!"
return _MOT_TP_SIZE
def get_attention_tp_group() -> GroupCoordinator:
return get_tp_group()
def moe_tensor_model_parallel_all_gather(input_: torch.Tensor,
dim: int = -1) -> torch.Tensor:
"""All-gather the input tensor across model parallel group."""
return get_moe_tp_group().all_gather(input_, dim)
def moe_tensor_model_parallel_reduce_scatter(input_: torch.Tensor,
dim: int = -1) -> torch.Tensor:
"""Reduce-Scatter the input tensor across model parallel group."""
return get_moe_tp_group().reduce_scatter(input_, dim)
def dp_gather(
hidden_states: torch.Tensor,)-> torch.Tensor:
if get_attention_tp_size() == 1:
hidden_states = moe_tensor_model_parallel_all_gather(hidden_states, dim=0)
return hidden_states
hidden_states = tensor_model_parallel_reduce_scatter(hidden_states, dim=0)
hidden_states = moe_tensor_model_parallel_all_gather(hidden_states, dim=0)
return hidden_states
def dp_reduce_scatter_tensor(hidden_states: torch.Tensor)-> torch.Tensor:
if get_moe_tp_group().world_size == get_attention_dp_size():
hidden_states = moe_tensor_model_parallel_reduce_scatter(hidden_states, dim=0)
else:
hidden_states = moe_tensor_model_parallel_reduce_scatter(hidden_states, dim=0)
hidden_states = tensor_model_parallel_all_gather(hidden_states, dim=0)
return hidden_states
\ No newline at end of file
......@@ -214,8 +214,6 @@ def moe_align_block_size_lightop(
def fused_experts_impl_w16a16_marlin(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_marlin: torch.Tensor,
w2_marlin: torch.Tensor,
topk_weights: torch.Tensor,
......@@ -234,8 +232,8 @@ def fused_experts_impl_w16a16_marlin(hidden_states: torch.Tensor,
):
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
assert w1_marlin.is_contiguous(), "Packed weights1 must be contiguous"
assert w2_marlin.is_contiguous(), "Packed weights2 must be contiguous"
# 当前只支持 bf16 fp16
assert hidden_states.dtype in [torch.bfloat16,torch.float16]
compute_type = hidden_states.dtype
......@@ -243,11 +241,24 @@ def fused_experts_impl_w16a16_marlin(hidden_states: torch.Tensor,
"only BW and set VLLM_USE_LIGHTOP=1 support Marlin W16A16 MoE")
num_tokens, K = hidden_states.shape
E, twoN, K_w1 = w1.shape
# Packed weights store the same number of elements as the original layout,
# but reshaped/reordered for Marlin kernels:
# - w1_marlin: [E, K/16, (2N)*16]
# - w2_marlin: [E, N/16, K*16]
E, k_div16, twoN_times16 = w1_marlin.shape
K_w1 = k_div16 * 16
assert K_w1 == K, f"w1_marlin K mismatch: {K_w1} vs {K}"
assert twoN_times16 % 16 == 0
twoN = twoN_times16 // 16
assert twoN % 2 == 0
N = twoN // 2
E2, K_w2, N2_w2 = w2.shape
E2, n_div16, k_times16 = w2_marlin.shape
assert E2 == E, f"w2_marlin E mismatch: {E2} vs {E}"
K_w2 = k_times16 // 16
assert K_w2 == K, f"w2_marlin K mismatch: {K_w2} vs {K}"
assert n_div16 * 16 == N, f"w2_marlin N mismatch: {n_div16 * 16} vs {N}"
if global_num_experts == -1:
global_num_experts = E
......
......@@ -5,7 +5,7 @@ import functools
import json
import os
import math
from typing import Any, Callable, Dict, Optional, List, Optional, Tuple
from typing import Any, Callable, Dict, Optional, List
import torch
......@@ -13,6 +13,7 @@ import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.logger import init_logger
logger = init_logger(__name__)
# yapf: disable
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, get_config_quant_dtype)
......@@ -31,7 +32,7 @@ try:
from lmslim.layers.fused_moe.fuse_moe_int8 import (fused_experts_impl_int8, get_w8a8moe_json)
from lmslim.layers.fused_moe.fuse_moe_w4a8 import fused_experts_impl_w4a8
except Exception:
print("INFO: Please install lmslim if you want to infer the quantitative model of moe.\n")
logger.warning_once("Please install lmslim if you want to infer the quantitative model of moe.")
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP)
......@@ -44,31 +45,9 @@ from vllm.utils import direct_register_custom_op
# from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
logger = init_logger(__name__)
if envs.VLLM_USE_GLOBAL_CACHE13:
moe_cache_singleton = None
# Cache Marlin-packed weights so we only reorder once per weight tensor.
_marlin_weight_cache: Dict[Tuple[int, torch.device, torch.dtype, torch.Size], torch.Tensor] = {}
# Cache packed W16A16 Marlin weights by parameter identity so we can offload
# original layouts from GPU without losing the packed copies.
_w16a16_marlin_weight_cache: Dict[int, Tuple[torch.Tensor, torch.Tensor]] = {}
def _get_marlin_packed_weight(weight: torch.Tensor,
pack_fn: Callable[[torch.Tensor], torch.Tensor]
) -> torch.Tensor:
key = (weight.data_ptr(), weight.device, weight.dtype, weight.shape)
cached = _marlin_weight_cache.get(key)
if cached is not None:
return cached
# Marlin packing is done per expert and reshaped back to original dims.
packed = torch.stack([pack_fn(weight[i]).contiguous()
for i in range(weight.shape[0])],
dim=0)
_marlin_weight_cache[key] = packed
return packed
arch_name = torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0]
arch_cu = torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count
......@@ -1710,63 +1689,71 @@ def fused_experts_impl(
i_s: Optional[torch.Tensor] = None, **_
) -> torch.Tensor:
num_tokens = hidden_states.size(0)
if use_nn_moe:
E, _, N = w1.size()
else:
E, N, _ = w1.size()
K = w2.size(1)
if global_num_experts == -1:
global_num_experts = E
top_k_num = topk_ids.size(1)
# We execute the fused_moe kernel in chunks to circumvent this issue:
# https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
M = min(num_tokens, CHUNK_SIZE)
# Optional fast path: use Marlin W16A16 fused MoE implementation when
# explicitly requested. When weights are pre-packed in the post-load hook,
# w1/w2 are already in Marlin layout and we can avoid first-run packing
# peaks during KV cache profiling.
if envs.VLLM_USE_MARLIN_W16A16_MOE and not use_nn_moe:
try:
from vllm.model_executor.layers.fused_moe.fuse_moe_w16a16_marlin import ( # noqa: E501
fused_experts_impl_w16a16_marlin)
except Exception:
fused_experts_impl_w16a16_marlin = None # type: ignore
if fused_experts_impl_w16a16_marlin is not None:
K = hidden_states.size(1)
def _is_marlin_w16a16_packed(w1: torch.Tensor,
w2: torch.Tensor) -> bool:
if w1.dim() != 3 or w2.dim() != 3:
return False
if w1.size(0) != w2.size(0):
return False
k_div16 = w1.size(1)
if k_div16 * 16 != K:
return False
if w1.size(2) % 16 != 0:
return False
twoN = w1.size(2) // 16
if twoN % 2 != 0:
return False
N = twoN // 2
if w2.size(2) != K * 16:
return False
if w2.size(1) * 16 != N:
return False
return True
if (getattr(w1, "marlin_w16a16_packed", False)
or getattr(w2, "marlin_w16a16_packed", False)
or _is_marlin_w16a16_packed(w1, w2)):
E = w1.size(0)
if global_num_experts == -1:
global_num_experts = E
twoN = w1.size(2) // 16
if envs.VLLM_USE_GLOBAL_CACHE13:
cache13 = get_moe_cache(top_k_num, N,K if not use_nn_moe else w2.shape[2], device=hidden_states.device, dtype=hidden_states.dtype)
else:
cache13 = torch.empty(M * top_k_num * max(N, K if not use_nn_moe else w2.shape[2]), device=hidden_states.device, dtype=hidden_states.dtype)
# Optional fast path: use lmslim's Marlin W16A16 fused MoE implementation
# when explicitly requested. This reuses the same cache13 buffer as other
# fused paths for consistency.
from vllm.model_executor.layers.fused_moe.fuse_moe_w16a16_marlin import fused_experts_impl_w16a16_marlin
if (envs.VLLM_USE_MARLIN_W16A16_MOE
and fused_experts_impl_w16a16_marlin is not None):
# Only pack when shapes match the expected [E, 2N, K] / [E, K, N/2] contract.
# If shapes are unexpected, skip packing and fall back to non-Marlin paths below.
from vllm.model_executor.layers.fused_moe.marlin_quant import w16a16_marlin_weight
cache_key = id(w1)
cached_marlin = _w16a16_marlin_weight_cache.get(cache_key)
if cached_marlin is None:
w1_marlin = _get_marlin_packed_weight(w1, w16a16_marlin_weight)
w2_marlin = _get_marlin_packed_weight(w2, w16a16_marlin_weight)
# Offload original layout weights from GPU to avoid double residency.
with torch.no_grad():
w1_cpu = w1.detach().to("cpu")
w2_cpu = w2.detach().to("cpu")
if hasattr(w1, "data"):
w1.data = w1_cpu # type: ignore[attr-defined]
else:
w1 = w1_cpu
if hasattr(w2, "data"):
w2.data = w2_cpu # type: ignore[attr-defined]
cache13 = get_moe_cache(top_k_num,
twoN,
K,
device=hidden_states.device,
dtype=hidden_states.dtype)
else:
w2 = w2_cpu
cache13 = torch.empty(M * top_k_num * max(twoN, K),
device=hidden_states.device,
dtype=hidden_states.dtype)
_w16a16_marlin_weight_cache[cache_key] = (w1_marlin, w2_marlin)
else:
w1_marlin, w2_marlin = cached_marlin
return fused_experts_impl_w16a16_marlin(
hidden_states=hidden_states,
w1=w1,
w2=w2,
w1_marlin=w1_marlin,
w2_marlin=w2_marlin,
w1_marlin=w1,
w2_marlin=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
cache13=cache13,
......@@ -1777,9 +1764,48 @@ def fused_experts_impl(
expert_map=expert_map,
use_nn_moe=False,
routed_scaling_factor=routed_scaling_factor,
shared_output=shared_output
shared_output=shared_output,
)
# No fallback packing: require pre-packed weights when Marlin W16A16
# MoE is enabled. If weights are still in the original layout, fail
# fast to avoid packing-induced peak memory and unpredictable
# warmup/profiling behavior.
if (w1.dim() == 3 and w2.dim() == 3 and w1.size(0) == w2.size(0)
and w2.size(1) == K):
twoN = w1.size(1)
N = w2.size(2)
if (twoN == 2 * N and (K % 32 == 0) and (N % 16 == 0)
and (twoN % 32 == 0)):
raise RuntimeError(
"VLLM_USE_MARLIN_W16A16_MOE is enabled, but MoE weights "
"are not pre-packed in Marlin layout. Pre-pack weights "
"during the post-load hook or disable "
"VLLM_USE_MARLIN_W16A16_MOE."
)
# Non-Marlin paths need the original weight shapes.
if use_nn_moe:
E, _, N = w1.size()
else:
E, N, _ = w1.size()
K = w2.size(1)
if global_num_experts == -1:
global_num_experts = E
if envs.VLLM_USE_GLOBAL_CACHE13:
cache13 = get_moe_cache(top_k_num,
N,
K if not use_nn_moe else w2.shape[2],
device=hidden_states.device,
dtype=hidden_states.dtype)
else:
cache13 = torch.empty(
M * top_k_num * max(N, K if not use_nn_moe else w2.shape[2]),
device=hidden_states.device,
dtype=hidden_states.dtype)
if use_int8_w8a8 is True:
return fused_experts_impl_int8(hidden_states=hidden_states,
w1=w1,
......
......@@ -37,6 +37,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, grouped_topk, is_power_of_two)
from vllm.model_executor.layers.dp_attention import get_moe_tp_rank
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
......@@ -406,6 +407,86 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
super().process_weights_after_loading(layer)
# If Marlin W16A16 MoE is enabled, pre-pack weights once during the
# post-load hook and replace parameters with the packed layout.
#
# This avoids first-run packing peaks during KV cache profiling and
# keeps only one copy of weights resident on GPU in steady state.
if (envs.VLLM_USE_MARLIN_W16A16_MOE and current_platform.is_cuda_alike()
and not getattr(layer, "use_nn_moe", False)
and not getattr(layer, "_marlin_w16a16_moe_packed", False)):
w1 = layer.w13_weight
w2 = layer.w2_weight
if (w1.is_cuda and w2.is_cuda
and w1.dtype in (torch.float16, torch.bfloat16)
and w2.dtype in (torch.float16, torch.bfloat16)):
try:
from vllm.model_executor.layers.fused_moe.fuse_moe_w16a16_marlin import ( # noqa: E501
use_lightop as _use_lightop)
if not _use_lightop:
raise RuntimeError(
"Marlin W16A16 MoE kernel is disabled")
if w1.dim() != 3 or w2.dim() != 3 or w1.size(0) != w2.size(
0):
raise RuntimeError("Unexpected MoE weight shapes")
twoN, K = w1.size(1), w1.size(2)
if w2.size(1) != K:
raise RuntimeError("Unexpected MoE w2 layout")
N = w2.size(2)
if twoN != 2 * N:
raise RuntimeError("Unexpected MoE hidden dims")
if (K % 16 != 0 or K % 32 != 0 or N % 16 != 0
or twoN % 32 != 0):
raise RuntimeError("Marlin packing requires alignment")
from vllm.model_executor.layers.fused_moe.marlin_quant import (
w16a16_marlin_weight)
from torch.nn.parameter import Parameter
def _pack_per_expert(weight: torch.Tensor) -> torch.Tensor:
num_experts = weight.shape[0]
packed0 = w16a16_marlin_weight(
weight[0]).contiguous()
packed = packed0.new_empty((num_experts, ) +
packed0.shape)
packed[0].copy_(packed0)
del packed0
for i in range(1, num_experts):
tmp = w16a16_marlin_weight(
weight[i]).contiguous()
packed[i].copy_(tmp)
del tmp
return packed
with torch.no_grad():
w1_packed = _pack_per_expert(w1)
w2_packed = _pack_per_expert(w2)
new_w1 = Parameter(w1_packed, requires_grad=False)
new_w2 = Parameter(w2_packed, requires_grad=False)
# Preserve any custom weight attributes (e.g. loaders).
if hasattr(w1, "__dict__"):
for k, v in w1.__dict__.items():
setattr(new_w1, k, v)
if hasattr(w2, "__dict__"):
for k, v in w2.__dict__.items():
setattr(new_w2, k, v)
setattr(new_w1, "marlin_w16a16_packed", True)
setattr(new_w2, "marlin_w16a16_packed", True)
layer.w13_weight = new_w1
layer.w2_weight = new_w2
layer._marlin_w16a16_moe_packed = True
return
except Exception:
# If packing dependencies are unavailable, fall back to the
# standard (non-Marlin) layouts.
pass
# Padding the weight for better performance on ROCm
layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data)
layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data)
......@@ -819,6 +900,7 @@ class FusedMoE(torch.nn.Module):
self.expert_load_view: Optional[torch.Tensor] = None
self.logical_to_physical_map: Optional[torch.Tensor] = None
self.logical_replica_count: Optional[torch.Tensor] = None
self.enable_dp_attention = vllm_config.parallel_config.enable_dp_attention
# Determine expert maps
if self.use_ep:
......@@ -1545,7 +1627,8 @@ class FusedMoE(torch.nn.Module):
The pplx combine kernel reduces across GPU ranks by default.
"""
if (self.use_pplx_kernels or self.use_deepep_ht_kernels
or self.use_deepep_ll_kernels or self.use_deepep_auto_kernels):
or self.use_deepep_ll_kernels or self.use_deepep_auto_kernels
or self.enable_dp_attention):
return final_hidden_states
else:
return tensor_model_parallel_all_reduce(final_hidden_states)
......@@ -1684,6 +1767,7 @@ class FusedMoE(torch.nn.Module):
and not self.moe_parallel_config.use_deepep_ht_kernels
and not self.moe_parallel_config.use_deepep_ll_kernels
and not self.moe_parallel_config.use_deepep_auto_kernels
and not self.enable_dp_attention
)
if do_naive_dispatch_combine:
hidden_states, router_logits = get_ep_group().dispatch(
......
......@@ -22,6 +22,7 @@ from vllm.utils import round_up
try:
from lmslim.layers.gemm.int8_utils import (
per_token_group_quant_int8, per_token_quant_int8)
from lightop import op
except Exception:
print("INFO: Please install lmslim if you want to use int utils.\n")
from vllm.utils import cdiv
......@@ -622,9 +623,17 @@ def ep_scatter(
num_experts = num_recv_tokens_per_expert.shape[0]
hidden_size = recv_x.shape[1]
scale_hidden_size = recv_x_scale.shape[-1]
if hasattr(op, "ep_scatter"):
op.ep_scatter(
recv_x, recv_x_scale,
recv_topk, expert_map,
num_recv_tokens_per_expert,
output_tensor, output_tensor_scale, m_indices, output_index,
num_experts, BLOCK_E
)
else:
# grid = (triton.cdiv(hidden_size, BLOCK_D), num_experts)
grid = num_experts
assert m_indices.shape[0] % BLOCK_E == 0
_fwd_kernel_ep_scatter_1[(grid,)](
......@@ -665,8 +674,8 @@ def ep_scatter(
num_warps=num_warps,
HIDDEN_SIZE=hidden_size,
HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size),
SCALE_HIDDEN_SIZE=scale_hidden_size,#hidden_size // BLOCK_D,
SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(scale_hidden_size)#triton.next_power_of_2(hidden_size // BLOCK_D),
SCALE_HIDDEN_SIZE=scale_hidden_size,
SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(scale_hidden_size),
)
return
......
......@@ -27,6 +27,7 @@ from vllm.model_executor.parameter import (BasevLLMParameter,
PackedvLLMParameter,
PerTensorScaleParameter,
RowvLLMParameter)
from vllm.model_executor.layers.dp_attention import get_moe_tp_rank, get_moe_tp_size
# yapf: enable
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
......@@ -625,12 +626,18 @@ class ColumnParallelLinear(LinearBase):
*,
return_bias: bool = True,
expect_tp_size: Optional[int] = None,
enable_dp_attn_moe: bool = False,
):
# Divide the weight matrix along the last dimension.
self.tp_size = get_tensor_model_parallel_world_size()
if expect_tp_size is not None:
self.expect_tp_size = expect_tp_size
self.tp_size = self.expect_tp_size
self.enable_dp_attn_moe = enable_dp_attn_moe
if enable_dp_attn_moe:
self.tp_size = get_moe_tp_size()
self.input_size_per_partition = input_size
self.output_size_per_partition = divide(output_size, self.tp_size)
self.output_partition_sizes = [self.output_size_per_partition]
......@@ -878,6 +885,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
*,
return_bias: bool = True,
expect_tp_size: Optional[int] = None,
enable_dp_attn_moe: bool = False,
):
self.output_sizes = output_sizes
tp_size = get_tensor_model_parallel_world_size()
......@@ -888,6 +896,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
self.expect_tp_size = expect_tp_size
self.enable_dp_attn_moe = enable_dp_attn_moe
if enable_dp_attn_moe:
tp_size = get_moe_tp_size()
assert all(output_size % tp_size == 0 for output_size in output_sizes)
super().__init__(input_size=input_size,
output_size=sum(output_sizes),
......@@ -898,7 +910,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
quant_config=quant_config,
prefix=prefix,
return_bias=return_bias,
expect_tp_size=expect_tp_size)
expect_tp_size=expect_tp_size,
enable_dp_attn_moe=enable_dp_attn_moe)
def weight_loader(self,
param: Parameter,
......@@ -1000,6 +1013,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
tp_rank = 0
tp_size = 1
if self.enable_dp_attn_moe:
tp_rank = get_moe_tp_rank()
tp_size = get_moe_tp_size()
if output_dim is not None:
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
shard_size = self.output_sizes[loaded_shard_id] // tp_size
......@@ -1121,6 +1138,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
if hasattr(param, "expect_tp_size"):
param.expect_tp_size = self.expect_tp_size
if self.enable_dp_attn_moe and hasattr(param, "enable_dp_attn_moe"):
tp_size = get_moe_tp_size()
param.enable_dp_attn_moe = self.enable_dp_attn_moe
if isinstance(param, BlockQuantScaleParameter):
from vllm.model_executor.layers.quantization.fp8 import (
Fp8LinearMethod, Fp8MoEMethod)
......@@ -1552,6 +1573,7 @@ class RowParallelLinear(LinearBase):
*,
return_bias: bool = True,
expect_tp_size: Optional[int] = None,
enable_dp_attn_moe: bool = False,
):
# Divide the weight matrix along the first dimension.
self.tp_rank = get_tensor_model_parallel_rank()
......@@ -1610,6 +1632,11 @@ class RowParallelLinear(LinearBase):
if self.expect_tp_size is not None:
tp_rank = 0
tp_size = 1
if self.enable_dp_attn_moe:
tp_rank = get_moe_tp_rank()
tp_size = get_moe_tp_size()
input_dim = getattr(param, "input_dim", None)
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
is_sharded_weight = getattr(param, "is_sharded_weight", False)
......@@ -1664,6 +1691,9 @@ class RowParallelLinear(LinearBase):
if self.expect_tp_size is not None and hasattr(param, "expect_tp_size"):
param.expect_tp_size = self.expect_tp_size
if self.enable_dp_attn_moe is not None and hasattr(param, "enable_dp_attn_moe"):
param.enable_dp_attn_moe = self.enable_dp_attn_moe
param.load_row_parallel_weight(loaded_weight=loaded_weight)
def forward(
......
......@@ -654,6 +654,12 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
shared_output: Optional[torch.Tensor] = None,
use_nn_moe=False,
use_fused_gate: Optional[bool] = False,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None,
**_,
) -> torch.Tensor:
if enable_eplb:
raise NotImplementedError(
......
......@@ -140,7 +140,9 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
bias: Optional[torch.Tensor] = None,input_quant_args: Optional[list[torch.Tensor]] = None,
silu_quant_args: Optional[list[torch.Tensor]] = None, **_,
) -> torch.Tensor:
return self.fp8_linear.apply(input=x,
weight=layer.weight,
......
......@@ -857,7 +857,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,**_,
) -> torch.Tensor:
if enable_eplb:
assert expert_load_view is not None
......
......@@ -12,6 +12,7 @@ from vllm.platforms import current_platform
from vllm.utils import W8a8GetCacheJSON
from lmslim.layers.gemm.int8_utils import per_token_quant_int8
from lmslim.layers.gemm.fp8_utils import triton_scaled_mm_fp8
# Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
TORCH_DEVICE_IDENTITY = None
......@@ -278,25 +279,27 @@ def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor,
# GEMM
# This computes C = (X * W).
# Output in fp32 to allow subsequent ops to happen in-place
output = torch._scaled_mm(qinput,
qinput = qinput.view(-1,qinput.shape[-1])
output = triton_scaled_mm_fp8(qinput,
weight,
scale_a=TORCH_DEVICE_IDENTITY,
scale_b=TORCH_DEVICE_IDENTITY,
out_dtype=torch.float32)
scale_a=scale_a,
scale_b=scale_b,
out_dtype=out_dtype,
bias=bias)
# A fix for discrepancy in scaled_mm which returns tuple
# for torch < 2.5 and a single value in torch >= 2.5
if type(output) is tuple and len(output) == 2:
output = output[0]
# Unpad (undo num_token_padding)
output = torch.narrow(output, 0, 0, input_2d.shape[0])
x_scale = torch.narrow(scale_a, 0, 0, input_2d.shape[0])
# DQ
# C = sw * sx * (X * W) + bias
output = output * x_scale * scale_b.t()
if bias is not None:
output = output + bias
return output.to(out_dtype).view(*output_shape)
# if type(output) is tuple and len(output) == 2:
# output = output[0]
# # Unpad (undo num_token_padding)
# output = torch.narrow(output, 0, 0, input_2d.shape[0])
# x_scale = torch.narrow(scale_a, 0, 0, input_2d.shape[0])
#
# # DQ
# # C = sw * sx * (X * W) + bias
# output = output * x_scale * scale_b.t()
# if bias is not None:
# output = output + bias
return output.view(*output_shape)
def dispatch_w8a8_scaled_mm(
......
......@@ -59,6 +59,8 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.layers.dp_attention import (dp_gather, dp_reduce_scatter_tensor,
get_moe_tp_size, get_moe_tp_rank, get_attention_tp_size)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
......@@ -84,17 +86,22 @@ class DeepseekV2MLP(nn.Module):
prefix: str = "",
) -> None:
super().__init__()
vllm_config = get_current_vllm_config()
enable_dp_attention = vllm_config.parallel_config.enable_dp_attention
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj")
prefix=f"{prefix}.gate_up_proj",
enable_dp_attn_moe=enable_dp_attention)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
reduce_results=reduce_results,
prefix=f"{prefix}.down_proj")
prefix=f"{prefix}.down_proj",
enable_dp_attn_moe=enable_dp_attention)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
......@@ -105,7 +112,9 @@ class DeepseekV2MLP(nn.Module):
iqis: Optional[tuple[torch.Tensor, torch.Tensor]] = None
) -> torch.Tensor:
if envs.USE_FUSED_RMS_QUANT:
assert iqis is not None
i_q, i_s = None, None
if iqis is not None:
i_q, i_s = iqis
gate_up, _ = self.gate_up_proj(x, iqis=iqis)
if envs.USE_FUSED_SILU_MUL_QUANT:
x, _ = self.down_proj(gate_up, use_fused_silu_mul_quant=True)
......@@ -1007,6 +1016,8 @@ class DeepseekV2DecoderLayer(nn.Module):
self.tp_size = get_tensor_model_parallel_world_size()
self.config = config
self.tp_rank = get_tensor_model_parallel_rank()
self.enable_dp_attention = vllm_config.parallel_config.enable_dp_attention
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
if (config.n_routed_experts is not None
and layer_idx >= config.first_k_dense_replace
......@@ -1034,6 +1045,9 @@ class DeepseekV2DecoderLayer(nn.Module):
DeepseekV2MoE) and self.use_deepep and \
self.tp_size > 1 and not self.is_mtp_layer:
reduce_results = False
else:
if self.enable_dp_attention:
reduce_results = False
if model_config.use_mla:
attn_cls = DeepseekV2MLAAttention
......@@ -1204,6 +1218,10 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states = tensor_model_parallel_reduce_scatter(hidden_states, dim=0)
if self.enable_dp_attention:
if self.tp_rank == 0:
hidden_states += residual
hidden_states = dp_gather(hidden_states)
if hidden_states.dtype == torch.float16:
# Fix FP16 overflow
......@@ -1216,8 +1234,14 @@ class DeepseekV2DecoderLayer(nn.Module):
residual *= 1. / self.routed_scaling_factor
# Fully Connected
if not self.enable_dp_attention:
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
else:
num_tokens = hidden_states.shape[0]
new_bs = num_tokens // get_moe_tp_size() * get_attention_tp_size()
residual = hidden_states[self.dp_rank*new_bs: (self.dp_rank+1)*new_bs, :]
hidden_states = self.post_attention_layernorm(hidden_states)
if self.is_mtp_layer:
if isinstance(self.mlp,
......@@ -1232,6 +1256,9 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states = self.mlp(hidden_states)
if self.enable_dp_attention:
hidden_states = dp_reduce_scatter_tensor(hidden_states)
if self.is_mtp_layer:
if isinstance(self.mlp,
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1:
......
......@@ -10,6 +10,7 @@ from torch.nn import Parameter
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.logger import init_logger
from vllm.model_executor.utils import _make_synced_weight_loader
from vllm.model_executor.layers.dp_attention import get_moe_tp_rank, get_moe_tp_size
__all__ = [
"BasevLLMParameter", "PackedvLLMParameter", "PerTensorScaleParameter",
......@@ -97,6 +98,7 @@ class _ColumnvLLMParameter(BasevLLMParameter):
self._output_dim = output_dim
super().__init__(**kwargs)
self.expect_tp_size = -1
self.enable_dp_attn_moe = False
@property
......@@ -107,6 +109,10 @@ class _ColumnvLLMParameter(BasevLLMParameter):
tp_rank = get_tensor_model_parallel_rank()
if self.expect_tp_size == 1:
tp_rank = 0
if self.enable_dp_attn_moe:
tp_rank = get_moe_tp_rank()
shard_size = self.data.shape[self.output_dim]
loaded_weight = loaded_weight.narrow(self.output_dim,
tp_rank * shard_size, shard_size)
......@@ -129,6 +135,10 @@ class _ColumnvLLMParameter(BasevLLMParameter):
tp_rank = get_tensor_model_parallel_rank()
if self.expect_tp_size == 1:
tp_rank = 0
if self.enable_dp_attn_moe:
tp_rank = get_moe_tp_rank()
param_data = param_data.narrow(self.output_dim, shard_offset,
shard_size)
loaded_weight = loaded_weight.narrow(self.output_dim,
......@@ -174,6 +184,7 @@ class RowvLLMParameter(BasevLLMParameter):
self._input_dim = input_dim
super().__init__(**kwargs)
self.expect_tp_size = -1
self.enable_dp_attn_moe = False
@property
def input_dim(self):
......@@ -183,6 +194,10 @@ class RowvLLMParameter(BasevLLMParameter):
tp_rank = get_tensor_model_parallel_rank()
if self.expect_tp_size == 1:
tp_rank = 0
if self.enable_dp_attn_moe:
tp_rank = get_moe_tp_rank()
shard_size = self.data.shape[self.input_dim]
loaded_weight = loaded_weight.narrow(self.input_dim,
tp_rank * shard_size, shard_size)
......
......@@ -1051,11 +1051,14 @@ class Scheduler(SchedulerInterface):
def schedule(self) -> SchedulerOutput:
if envs.VLLM_USE_PD_SPLIT:
return self.schedule_split_pd()
else:
if self.use_mla:
if self.connector is not None:
return self.schedule_default()
if self.full_cuda_graph and self.use_mla and self.num_spec_tokens > 0 :
if self.full_cuda_graph and self.num_spec_tokens > 0:
return self.schedule_split_pd()
else:
self.schedule_default()
else:
return self.schedule_split_pd()
else:
return self.schedule_default()
......@@ -1107,7 +1110,7 @@ class Scheduler(SchedulerInterface):
# stage worker and the last-stage worker. Otherwise, we don't
# need to send the sampled tokens back because the model runner
# will cache them.
token_ids = req.all_token_ids[-num_tokens:]
token_ids = req.all_token_ids[-num_tokens:] if num_tokens > 0 else []
new_token_ids.append(token_ids)
new_block_ids.append(req_to_new_block_ids[req_id])
num_computed_tokens.append(req.num_computed_tokens)
......@@ -1241,7 +1244,7 @@ class Scheduler(SchedulerInterface):
scheduled_spec_token_ids = (
scheduler_output.scheduled_spec_decode_tokens.get(req_id))
request.num_generated_token_ids = 1
request.num_generated_token_ids = len(generated_token_ids)
if scheduled_spec_token_ids:
# num_computed_tokens represents the number of tokens
# processed in the current step, considering scheduled
......@@ -1253,7 +1256,6 @@ class Scheduler(SchedulerInterface):
num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 -
len(generated_token_ids))
request.num_computed_tokens -= num_tokens_rejected
request.num_generated_token_ids = len(generated_token_ids)
spec_decoding_stats = self.make_spec_decoding_stats(
spec_decoding_stats,
num_draft_tokens=len(scheduled_spec_token_ids),
......
......@@ -12,7 +12,7 @@ from vllm.attention.layer import Attention
from vllm.config import (CompilationLevel, VllmConfig,
get_layers_from_vllm_config)
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import DPMetadata, set_forward_context
from vllm.forward_context import DPMetadata, set_forward_context, get_warming_up
from vllm.logger import init_logger
import vllm.envs as envs
from vllm.model_executor.model_loader import get_model
......@@ -93,6 +93,8 @@ class EagleProposer:
self.dp_size = vllm_config.parallel_config.data_parallel_size
self.enable_expert_parallel = vllm_config.parallel_config.enable_expert_parallel
self.enable_dp_attention = vllm_config.parallel_config.enable_dp_attention
self.attn_tp_size = vllm_config.parallel_config.tensor_parallel_size
def propose(
self,
......@@ -189,6 +191,8 @@ class EagleProposer:
else:
num_input_tokens = num_tokens
if self.enable_dp_attention:
num_input_tokens = round_up(num_input_tokens, self.attn_tp_size)
num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
num_input_tokens += num_pad
......@@ -279,6 +283,13 @@ class EagleProposer:
input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
else:
input_batch_size = batch_size
# dp attention need all dp rank process same number tokens
if self.enable_dp_attention:
input_batch_size = round_up(input_batch_size, self.attn_tp_size)
num_pad, _ = self.get_dp_padding(input_batch_size)
input_batch_size += num_pad
attn_metadata.num_actual_tokens = batch_size
attn_metadata.max_query_len = 1
attn_metadata.query_start_loc = self.arange[:batch_size + 1]
......@@ -532,9 +543,8 @@ class EagleProposer:
# TODO(tms) : There are many cases where padding is enabled for
# prefills, causing unnecessary and excessive padding of activations.
if not self.enable_dp_attention and not envs.VLLM_ALL2ALL_BACKEND == "deepep_auto":
if dp_size == 1 or self.vllm_config.model_config.enforce_eager or envs.VLLM_ALL2ALL_BACKEND != 'naive':
# auto
if not envs.VLLM_ALL2ALL_BACKEND == "deepep_auto":
# Early exit.
return 0, None
......@@ -566,7 +576,7 @@ class EagleProposer:
# Padding for DP
num_input_tokens = num_tokens
num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
num_pad, _ = self.get_dp_padding(num_tokens)
num_input_tokens += num_pad
with set_forward_context(attn_metadata,
......@@ -578,9 +588,49 @@ class EagleProposer:
self.hidden_states[:num_input_tokens],
)
if self.dp_size > 1 and self.enable_expert_parallel and self.num_speculative_tokens > 1:
num_token = 1
for _ in range(self.num_speculative_tokens - 1):
if self.dp_size > 1 and (self.enable_expert_parallel or self.enable_dp_attention) and self.num_speculative_tokens > 1:
num_tokens = 1
# dp attention need all dp rank process same number tokens
if self.enable_dp_attention:
num_tokens = round_up(num_tokens, self.attn_tp_size)
num_pad, _ = self.get_dp_padding(num_tokens)
num_tokens += num_pad
if not get_warming_up():
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=self.runner.query_start_loc[:num_tokens + 1],
seq_lens=self.runner.seq_lens[:num_tokens],
num_reqs=num_tokens,
num_actual_tokens=num_tokens,
max_query_len=num_tokens,
slot_mapping=self.runner.slot_mapping[:num_tokens],
spec_layer_decoding=True
)
assert self.runner is not None
# FIXME: need to consider multiple kv_cache_groups
attn_metadata = self.runner.attn_metadata_builders[0].build_for_cudagraph_capture(
common_attn_metadata=common_attn_metadata
)
for i in range(self.num_speculative_tokens - 1):
if self.attn_metadata_cudagraph is not None:
if i == 0:
attn_metadata_cudagraph = self.attn_metadata_cudagraph
attn_metadata_cudagraph.num_actual_tokens = num_tokens
attn_metadata_cudagraph.num_decodes = num_tokens
attn_metadata_cudagraph.num_decode_tokens = num_tokens
self.attn_metadata_cudagraph.slot_mapping[:num_tokens] = (
attn_metadata.slot_mapping)
attn_metadata_cudagraph.decode.seq_lens[:num_tokens] = (
attn_metadata.decode.seq_lens)
self.attn_metadata_cudagraph.query_start_loc[:num_tokens + 1] = (
attn_metadata.query_start_loc)
self.attn_metadata_cudagraph.decode.block_table[:num_tokens] = (
attn_metadata.decode.block_table)
with set_forward_context(attn_metadata,
self.vllm_config,
num_tokens=num_tokens):
......
......@@ -31,7 +31,7 @@ from vllm.distributed.parallel_state import (
prepare_communication_buffer_for_model,
get_tensor_model_parallel_world_size)
from vllm.forward_context import (DPMetadata, get_forward_context,
set_forward_context, set_profilling)
set_forward_context)
from vllm.logger import init_logger
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
......@@ -339,6 +339,8 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
if self.enable_expert_parallel and self.dp_size > 1 and self.tp_size > 1:
self.ep_sp = True
self.enable_dp_attention = self.parallel_config.enable_dp_attention
def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
"""
Update the order of requests in the batch based on the attention
......@@ -512,14 +514,10 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
new_token_ids = req_data.new_token_ids[i]
# Add the sampled token(s) from the previous step (if any).
# This doesn't include "unverified" tokens like spec tokens.
num_new_tokens = (num_computed_tokens + len(new_token_ids) -
req_state.num_tokens)
if num_new_tokens == 1:
# Avoid slicing list in most common case.
req_state.output_token_ids.append(new_token_ids[-1])
elif num_new_tokens > 0:
num_new_tokens = len(new_token_ids)
if num_new_tokens > 0:
req_state.output_token_ids.extend(
new_token_ids[-num_new_tokens:])
new_token_ids)
if len(spec_token_ids) > 0:
req_state.spec_token_ids = spec_token_ids
......@@ -539,6 +537,11 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# The request is not in the persistent batch.
# The request was either preempted and resumed later, or was not
# scheduled in the previous step and needs to be added again.
if not is_last_rank:
req_state = self.requests[req_id]
self.input_batch.add_request(req_state)
req_index = self.input_batch.req_id_to_index.get(req_id)
else:
req_ids_to_add.append(req_id)
continue
......@@ -552,6 +555,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
if not is_last_rank:
# Add new_token_ids to token_ids_cpu.
start_token_index = num_computed_tokens
if len(new_token_ids) > 0:
end_token_index = num_computed_tokens + 1
self.input_batch.token_ids_cpu[
req_index,
......@@ -1275,9 +1279,8 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# TODO(tms) : There are many cases where padding is enabled for
# prefills, causing unnecessary and excessive padding of activations.
if dp_size == 1 or self.vllm_config.model_config.enforce_eager:
# auto
if not envs.VLLM_ALL2ALL_BACKEND == "deepep_auto":
if not self.enable_dp_attention and not envs.VLLM_ALL2ALL_BACKEND == "deepep_auto":
if dp_size == 1 or self.vllm_config.model_config.enforce_eager or envs.VLLM_ALL2ALL_BACKEND != 'naive':
# Early exit.
return 0, None
......@@ -1357,7 +1360,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
if self.ep_sp:
if self.ep_sp or self.enable_dp_attention:
num_input_tokens = round_up(num_scheduled_tokens, self.tp_size)
if (self.use_cuda_graph
and num_input_tokens <= self.cudagraph_batch_sizes[-1]):
......@@ -1600,6 +1603,11 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
seq_len = (req_state.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])
if seq_len < req_state.num_tokens:
# If we have already started decoding, seeing a "partial prefill"
# condition is suspicious and can lead to discarding the sampled
# token forever (PP stall).
if req_state.output_token_ids:
continue
# Ignore the sampled token for partial prefills.
# Rewind the generator state as if the token was not sampled.
# This relies on cuda-specific torch-internal impl details
......@@ -1638,9 +1646,6 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# Mask out the sampled tokens that should not be sampled.
for i in discard_sampled_tokens_req_indices:
valid_sampled_token_ids[i].clear()
if spec_token_ids is not None:
for i in discard_sampled_tokens_req_indices:
spec_token_ids[i].clear()
# Cache the sampled tokens in the model runner, so that the scheduler
# doesn't need to send them back.
......@@ -1681,6 +1686,10 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
attn_metadata,
)
if spec_token_ids is not None:
for i in discard_sampled_tokens_req_indices:
spec_token_ids[i].clear()
# Clear KVConnector state after all KVs are generated.
if has_kv_transfer_group():
get_kv_transfer_group().clear_connector_metadata()
......@@ -2123,11 +2132,11 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
) -> tuple[torch.Tensor, torch.Tensor]:
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
if self.ep_sp:
if self.ep_sp or self.enable_dp_attention:
if num_tokens < self.tp_size:
num_tokens = self.tp_size
# Padding for DP
num_tokens_across_dp = 0
num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
num_tokens += num_pad
......@@ -2148,13 +2157,13 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
min_tokens_per_req = (1 + self.speculative_config.num_lookahead_slots)
num_reqs = num_tokens // min_tokens_per_req
if self.ep_sp:
if self.ep_sp or self.enable_dp_attention:
num_actual_tokens = round_down(num_tokens, 1 + self.speculative_config.num_lookahead_slots)
num_reqs = num_actual_tokens // min_tokens_per_req
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
if not self.ep_sp:
if not (self.ep_sp or self.enable_dp_attention):
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
else:
if self.speculative_config is not None:
......@@ -3219,7 +3228,7 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
if self.ep_sp:
if self.ep_sp or self.enable_dp_attention:
num_input_tokens = round_up(num_scheduled_tokens, self.tp_size)
if (self.use_cuda_graph
and num_input_tokens <= self.cudagraph_batch_sizes[-1]):
......@@ -3461,6 +3470,11 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
seq_len = (req_state.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])
if seq_len < req_state.num_tokens:
# If we have already started decoding, seeing a "partial prefill"
# condition is suspicious and can lead to discarding the sampled
# token forever (PP stall).
if req_state.output_token_ids:
continue
# Ignore the sampled token for partial prefills.
# Rewind the generator state as if the token was not sampled.
# This relies on cuda-specific torch-internal impl details
......@@ -3482,7 +3496,6 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
hidden_states[:num_scheduled_tokens],
scheduler_output,
)
#-----------------------------------
# Get the valid generated tokens.
sampled_token_ids = sampler_output.sampled_token_ids
max_gen_len = sampled_token_ids.shape[-1]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment