Commit 8943d3db authored by yangql's avatar yangql
Browse files

解决deep的auto冲突

parents 0d3ae2fc ab1acdce
......@@ -173,6 +173,7 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
if self.internode:
num_rdma_bytes = int(1e9/2) #1024 * 1024 * 1024
num_qps_per_rank = 30 #self.num_sms // 2
self.num_sms = 30
# import deep_ep
# num_nvl_bytes, num_rdma_bytes = 0, 0
......@@ -184,6 +185,7 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
else:
num_rdma_bytes = 0
num_qps_per_rank = 1
self.num_sms = 60
assert num_rdma_bytes is not None
assert num_qps_per_rank is not None
......@@ -192,6 +194,7 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
num_rdma_bytes=num_rdma_bytes,
low_latency_mode=False,
num_qps_per_rank=num_qps_per_rank,
allow_mnnvl=envs.VLLM_ALLOW_MNNVL,
explicitly_destroy=False)
def get_handle(self, kwargs):
......
......@@ -180,6 +180,7 @@ if TYPE_CHECKING:
VLLM_USE_PD_SPLIT: bool = False
VLLM_USE_PP_BALANCE: bool = False
VLLM_USE_ZERO_MTP: bool = False
VLLM_ENABLE_DEEPEP_HT_DEEPGEMM: bool = True
def get_default_cache_root():
return os.getenv(
......@@ -1181,6 +1182,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_ZERO_MTP":
lambda: (os.getenv('VLLM_USE_ZERO_MTP', '1').lower() in
("true", "1")),
"VLLM_ENABLE_DEEPEP_HT_DEEPGEMM":
lambda: (os.getenv('VLLM_ENABLE_DEEPEP_HT_DEEPGEMM', '1').lower() in
("true", "1")),
}
# --8<-- [end:env-vars-definition]
......
......@@ -21,11 +21,13 @@ class DeepEPAutoPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
super().__init__()
self.ht_prepare_finalize = ht_prepare_finalize
self.ll_prepare_finalize = ll_prepare_finalize
self._current_phase = "decode" # default to prefill (HT)
self._current_phase = "decode" # default to decode (LL)
def _get_current_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize:
"""Get the appropriate prepare_finalize based on current phase."""
# Try to infer phase from forward_context if available
# Try to infer phase from forward_context if available:
# - 有 decode tokens -> 使用 LL (decode)
# - 否则默认 HT (prefill)
try:
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
......@@ -36,44 +38,60 @@ class DeepEPAutoPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
else:
attn_metadata = None
if attn_metadata is not None and hasattr(attn_metadata, 'num_prefill_tokens') and hasattr(attn_metadata, 'num_decode_tokens'):
# Only use prefill mode when BOTH conditions are met:
# 1. There are prefill tokens and no decode tokens
# 2. skip_cuda_graphs is True
is_prefill_tokens = attn_metadata.num_prefill_tokens > 0 and attn_metadata.num_decode_tokens == 0
skip_cuda_graphs = forward_context.skip_cuda_graphs
# Only use prefill (HT) when both conditions are satisfied
self._current_phase = "prefill" if (is_prefill_tokens and skip_cuda_graphs) else "decode"
if attn_metadata is not None and hasattr(attn_metadata,
"num_decode_tokens"):
# 只根据 decode tokens 判定:有 decode -> decode,否则 prefill
self._current_phase = ("decode"
if attn_metadata.num_decode_tokens > 0
else "prefill")
except Exception:
# If forward_context is not available, use stored phase
pass
# Prefill uses HT, decode uses LL
# print("self._current_phase",self._current_phase)
# if self._current_phase == "prefill":
if self._current_phase == "prefill":
print("************prefill***********")
# return self.ht_prepare_finalize
# else:
return self.ll_prepare_finalize
# return self.ll_prepare_finalize
return self.ht_prepare_finalize
@property
def activation_format(self) -> mk.FusedMoEActivationFormat:
# Use the current prepare_finalize's activation format
# Note: HT uses Standard, LL uses BatchedExperts
# Dynamically return based on current phase
prepare_finalize = self._get_current_prepare_finalize()
return prepare_finalize.activation_format
pf = self._get_current_prepare_finalize()
try:
return pf.activation_format
except NotImplementedError:
# Fallback to standard format if underlying impl does not provide it.
return mk.FusedMoEActivationFormat.Standard
def topk_indices_dtype(self) -> Optional[torch.dtype]:
# Both HT and LL return int64
return torch.int64
pf = self._get_current_prepare_finalize()
return pf.topk_indices_dtype()
def max_num_tokens_per_rank(self) -> Optional[int]:
# LL has a limit, HT returns None
return self.ll_prepare_finalize.max_num_tokens_per_rank()
pf = self._get_current_prepare_finalize()
return pf.max_num_tokens_per_rank()
def num_dispatchers(self) -> int:
# Both should return the same value
return self.ht_prepare_finalize.num_dispatchers()
pf = self._get_current_prepare_finalize()
return pf.num_dispatchers()
def prepare_async(
    self,
    a1: torch.Tensor,
    a1_scale: Optional[torch.Tensor],
    a2_scale: Optional[torch.Tensor],
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    num_experts: int,
    expert_map: Optional[torch.Tensor],
    apply_router_weight_on_input: bool,
    quant_config: FusedMoEQuantConfig,
):
    """Start an asynchronous prepare on the implementation for this phase.

    The implementation is whatever ``_get_current_prepare_finalize()``
    returns. Every argument is forwarded positionally, in order, to that
    object's ``prepare_async`` and its return value is handed back as-is.
    """
    active_impl = self._get_current_prepare_finalize()
    forwarded = (
        a1,
        a1_scale,
        a2_scale,
        topk_weights,
        topk_ids,
        num_experts,
        expert_map,
        apply_router_weight_on_input,
        quant_config,
    )
    return active_impl.prepare_async(*forwarded)
def prepare(
self,
......@@ -88,9 +106,8 @@ class DeepEPAutoPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
quant_config: FusedMoEQuantConfig,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
Optional[torch.Tensor], Optional[torch.Tensor]]:
"""Route prepare call to the appropriate implementation."""
prepare_finalize = self._get_current_prepare_finalize()
return prepare_finalize.prepare(
pf = self._get_current_prepare_finalize()
return pf.prepare(
a1, a1_scale, a2_scale, topk_weights, topk_ids,
num_experts, expert_map, apply_router_weight_on_input, quant_config)
......@@ -103,9 +120,8 @@ class DeepEPAutoPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
apply_router_weight_on_input: bool,
apply_weights_and_reduce: bool = True
) -> None:
"""Route finalize call to the appropriate implementation."""
prepare_finalize = self._get_current_prepare_finalize()
return prepare_finalize.finalize(
pf = self._get_current_prepare_finalize()
return pf.finalize(
output, fused_expert_output, topk_weights, topk_ids,
apply_router_weight_on_input, apply_weights_and_reduce)
......@@ -118,15 +134,11 @@ class DeepEPAutoPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
apply_router_weight_on_input: bool,
apply_weights_and_reduce: bool = True
):
"""Route finalize_async call to the appropriate implementation if available."""
prepare_finalize = self._get_current_prepare_finalize()
if hasattr(prepare_finalize, 'finalize_async'):
return prepare_finalize.finalize_async(
pf = self._get_current_prepare_finalize()
if hasattr(pf, "finalize_async"):
return pf.finalize_async(
output, fused_expert_output, topk_weights, topk_ids,
apply_router_weight_on_input, apply_weights_and_reduce)
else:
# Fallback to synchronous finalize
return prepare_finalize.finalize(
return pf.finalize(
output, fused_expert_output, topk_weights, topk_ids,
apply_router_weight_on_input, apply_weights_and_reduce)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
from collections.abc import Callable
import deep_ep
import torch
......@@ -58,39 +59,49 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return None
return deep_ep.Buffer.get_combine_config(self.dp_size)
def sync(self):
    """Wait at a ``dist.barrier()`` call; returns nothing.

    NOTE(review): ``dist`` is presumably ``torch.distributed`` -- confirm
    against the file's imports, which are not visible in this view.
    """
    # torch.cuda.synchronize()
    dist.barrier()
def _do_dispatch(self, tokens: torch.Tensor,
token_scales: Optional[torch.Tensor],
def _do_dispatch(
self,
tokens: torch.Tensor,
token_scales: torch.Tensor | None,
rank_topk_ids: torch.Tensor,
rank_topk_weights: torch.Tensor, num_experts: int):
rank_topk_weights: torch.Tensor,
num_experts: int,
quant_config: FusedMoEQuantConfig,
) -> Callable:
has_scales = token_scales is not None
(num_tokens_per_rank, num_tokens_per_rdma_rank, expert_num_tokens,
is_token_in_rank, event) = self.buffer.get_dispatch_layout(
(
num_tokens_per_rank,
num_tokens_per_rdma_rank,
dispatch_expert_num_tokens,
is_token_in_rank,
event,
) = self.buffer.get_dispatch_layout(
topk_idx=rank_topk_ids,
num_experts=num_experts,
previous_event=None,
async_finish=False,
allocate_on_comm_stream=False)
allocate_on_comm_stream=False,
)
token_data = tokens
if has_scales:
token_data = (tokens, token_scales)
(
token_data, expert_topk_ids, expert_topk_weights,
expert_num_tokens_per_expert_list, self.handle, event
token_data,
expert_topk_ids,
expert_topk_weights,
expert_num_tokens_per_expert_list,
self.handle,
event,
) = self.buffer.dispatch(
x=token_data,
handle=None,
num_tokens_per_rank=num_tokens_per_rank,
num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
is_token_in_rank=is_token_in_rank,
num_tokens_per_expert=expert_num_tokens,
num_tokens_per_expert=dispatch_expert_num_tokens,
topk_idx=rank_topk_ids,
topk_weights=rank_topk_weights,
# expert_alignment rounds the number of tokens per expert
......@@ -98,8 +109,36 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_alignment=1,
config=self._get_dispatch_config(),
previous_event=None,
async_finish=False,
allocate_on_comm_stream=False)
async_finish=True,
allocate_on_comm_stream=False,
)
return lambda: self._receiver(
event,
has_scales,
token_data,
expert_topk_ids,
num_experts,
expert_num_tokens_per_expert_list,
expert_topk_weights,
token_scales,
quant_config,
)
def _receiver(
self,
event: deep_ep.EventOverlap,
has_scales: bool,
token_data: tuple[torch.Tensor, torch.Tensor] | torch.Tensor,
expert_topk_ids: torch.Tensor | None,
num_experts: int,
expert_num_tokens_per_expert_list: list[int],
expert_topk_weights: torch.Tensor | None,
a1_scale: torch.Tensor | None,
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
if event.event is not None:
event.current_stream_wait()
if has_scales:
expert_x, expert_x_scale = token_data
......@@ -117,15 +156,45 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
# DeepEP's topk_ids output refers to the local experts directly. Offset
# the topk_ids to move it back to the global experts space so it aligns
# with existing vLLM interfaces.
assert expert_topk_ids is not None
expert_topk_ids = torch.where(
expert_topk_ids == -1,
num_experts - 1 if self.rank_expert_offset == 0 else 0,
expert_topk_ids + self.rank_expert_offset)
expert_topk_ids + self.rank_expert_offset,
)
return (expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids,
expert_topk_weights)
# Makes a GPU-CPU copy.
# TODO (varun): Maybe it is better to re-compute the expert_num_tokens
# on GPU.
expert_tokens_meta = mk.ExpertTokensMetadata.make_from_list(
expert_num_tokens_per_expert_list, device=expert_x.device
)
def prepare(
# Dispatch and Quant
# DeepEP kernels only support dispatching block-quantized
# activation scales.
# Dispatch in bfloat16 and quantize afterwards
if not quant_config.per_act_token_quant:
# Quantize after dispatch.
expert_x_scale = None
if expert_x.numel() != 0:
expert_x, expert_x_scale = moe_kernel_quantize_input(
expert_x,
a1_scale,
quant_dtype=quant_config.quant_dtype,
per_act_token_quant=False,
block_shape=quant_config.block_shape,
)
return (
expert_x,
expert_x_scale,
expert_tokens_meta,
expert_topk_ids,
expert_topk_weights,
)
def prepare_async(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
......@@ -136,14 +205,13 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
Optional[torch.Tensor], Optional[torch.Tensor]]:
) -> mk.ReceiverType:
if apply_router_weight_on_input:
topk = topk_ids.size(1)
# TODO: this only works for topK=1, will need to update for topK>1
assert topk == 1, (
"apply_router_weight_on_input is only implemented for topk=1")
"apply_router_weight_on_input is only implemented for topk=1"
)
a1 = a1 * topk_weights.to(a1.dtype)
if quant_config.per_act_token_quant:
......@@ -156,35 +224,43 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
)
if a1q_scale is not None and a1q_scale.numel() == 1:
a1q_scale = a1q_scale.view(1, 1)
(expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids,
expert_topk_weights) = self._do_dispatch(
else:
a1q = a1
a1q_scale = None
return self._do_dispatch(
tokens=a1q,
token_scales=a1q_scale,
rank_topk_ids=topk_ids,
rank_topk_weights=topk_weights,
num_experts=num_experts)
else:
# DeepEP kernels only support dispatching per-token-quant
# quantization. dispatch in bfloat16.
(expert_x, _, expert_num_tokens, expert_topk_ids,
expert_topk_weights) = self._do_dispatch(
tokens=a1,
token_scales=None,
rank_topk_ids=topk_ids,
rank_topk_weights=topk_weights,
num_experts=num_experts)
# quantize now
expert_x_scale = None
if expert_x.numel() != 0:
expert_x, expert_x_scale = moe_kernel_quantize_input(
expert_x,
a1_scale,
quant_dtype=quant_config.quant_dtype,
per_act_token_quant=False,
block_shape=quant_config.block_shape)
num_experts=num_experts,
quant_config=quant_config,
)
return (expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids,
expert_topk_weights)
def prepare(
    self,
    a1: torch.Tensor,
    a1_scale: Optional[torch.Tensor],
    a2_scale: Optional[torch.Tensor],
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    num_experts: int,
    expert_map: Optional[torch.Tensor],
    apply_router_weight_on_input: bool,
    quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
    """Synchronous prepare, built on top of ``prepare_async``.

    Forwards all arguments unchanged to ``prepare_async`` and immediately
    invokes the receiver callable it returns, so the caller gets the
    finished prepare result rather than a pending handle.
    """
    wait_for_dispatch = self.prepare_async(
        a1,
        a1_scale,
        a2_scale,
        topk_weights,
        topk_ids,
        num_experts,
        expert_map,
        apply_router_weight_on_input,
        quant_config,
    )
    prepared = wait_for_dispatch()
    return prepared
def _apply_weights_and_reduce(self, num_tokens: int,
fused_expert_output: torch.Tensor,
......@@ -210,31 +286,88 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return out
def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
def _finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
apply_weights_and_reduce: bool = True) -> None:
do_async: bool,
apply_weights_and_reduce: bool = True,
) -> Callable | None:
assert self.handle is not None
# fused_expert_output can have 0 tokens - This happens when none of the
# tokens from the all2all reach this EP rank.
if fused_expert_output.numel() != 0 and apply_weights_and_reduce:
fused_expert_output = self._apply_weights_and_reduce(
num_tokens=topk_ids.size(0),
fused_expert_output=fused_expert_output,
topk_weights=topk_weights,
apply_router_weight_on_input=apply_router_weight_on_input,
output_dtype=output.dtype)
# if fused_expert_output.numel() != 0 and apply_weights_and_reduce:
# fused_expert_output = self._apply_weights_and_reduce(
# num_tokens=topk_ids.size(0),
# fused_expert_output=fused_expert_output,
# topk_weights=topk_weights,
# apply_router_weight_on_input=apply_router_weight_on_input,
# output_dtype=output.dtype)
combined_x, _, event = self.buffer.combine(
# HT combine only supports BF16
x=fused_expert_output,
handle=self.handle,
topk_weights=None,
config=self._get_combine_config(),
previous_event=None,
async_finish=False,
allocate_on_comm_stream=False)
async_finish=do_async,
allocate_on_comm_stream=False,
)
if do_async:
def _receiver():
if event.event is not None:
event.current_stream_wait()
# Respect inplace outputs.
output.copy_(combined_x, non_blocking=True)
return _receiver
else:
# Respect inplace outputs.
output.copy_(combined_x, non_blocking=True)
return None
def finalize_async(
    self,
    output: torch.Tensor,
    fused_expert_output: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    apply_router_weight_on_input: bool,
    apply_weights_and_reduce: bool = True,
) -> Callable:
    """Launch the finalize step asynchronously and return its receiver.

    Delegates to ``_finalize`` with ``do_async=True``. The callable that
    ``_finalize`` returns is passed back to the caller, who is expected to
    invoke it later; a ``None`` result here is treated as a programming
    error and trips the assertion.
    """
    pending = self._finalize(
        output,
        fused_expert_output,
        topk_weights,
        topk_ids,
        apply_router_weight_on_input,
        do_async=True,
        apply_weights_and_reduce=apply_weights_and_reduce,
    )
    assert pending is not None
    return pending
def finalize(
    self,
    output: torch.Tensor,
    fused_expert_output: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    apply_router_weight_on_input: bool,
    apply_weights_and_reduce: bool = True,
) -> None:
    """Run the finalize step synchronously.

    Delegates to ``_finalize`` with ``do_async=False`` and discards whatever
    it returns, so this method always returns ``None``.
    """
    positional = (
        output,
        fused_expert_output,
        topk_weights,
        topk_ids,
        apply_router_weight_on_input,
    )
    self._finalize(
        *positional,
        do_async=False,
        apply_weights_and_reduce=apply_weights_and_reduce,
    )
......@@ -115,7 +115,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return x, x_scales
def prepare(
def prepare_async(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
......@@ -126,9 +126,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
Optional[torch.Tensor], Optional[torch.Tensor]]:
) -> tuple[Callable, mk.ReceiverType]:
hidden_size = a1.size(1)
assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, \
(f"Hidden Size {hidden_size} not in supported list of hidden sizes"
......@@ -148,25 +146,74 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk = topk_ids.size(1)
# TODO: this only works for topK=1, will need to update for topK>1
assert topk == 1, (
"apply_router_weight_on_input is only implemented for topk=1")
"apply_router_weight_on_input is only implemented for topk=1"
)
a1 = a1 * topk_weights.to(a1.dtype)
# Dispatch
expert_x, expert_num_tokens, self.handle, event, hook = \
self.buffer.low_latency_dispatch(a1,
expert_x, expert_num_tokens, self.handles, _, hook = self.buffer.low_latency_dispatch(
a1,
topk_ids,
self.max_tokens_per_rank,
num_experts,
use_fp8=self.use_fp8_dispatch or self.use_int8_dispatch,
use_int8=self.use_int8_dispatch,
async_finish=False,
return_recv_hook=False)
return_recv_hook=True,
)
return (
hook,
lambda: self._receiver(
expert_x,
expert_num_tokens,
a1_scale,
a1.dtype,
quant_config,
),
)
def _receiver(
self,
expert_x: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
expert_num_tokens: torch.Tensor,
a1_scale: torch.Tensor | None,
a1_dtype: torch.dtype,
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
expert_x, expert_x_scale = self._do_quant(expert_x, a1_dtype, quant_config)
expert_x, expert_x_scale = self._do_quant(
expert_x, a1_scale, a2_scale, a1.dtype, quant_config.quant_dtype,
quant_config.per_act_token_quant, quant_config.block_shape, expert_num_tokens)
expert_tokens_meta = mk.ExpertTokensMetadata(
expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None
)
return (expert_x, expert_x_scale, expert_num_tokens, None, None)
return expert_x, expert_x_scale, expert_tokens_meta, None, None
def prepare(
    self,
    a1: torch.Tensor,
    a1_scale: Optional[torch.Tensor],
    a2_scale: Optional[torch.Tensor],
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    num_experts: int,
    expert_map: Optional[torch.Tensor],
    apply_router_weight_on_input: bool,
    quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
    """Synchronous prepare, built on top of ``prepare_async``.

    ``prepare_async`` returns a ``(hook, receiver)`` pair. This method calls
    the hook first, then the receiver, and returns the receiver's result.
    """
    pending = self.prepare_async(
        a1,
        a1_scale,
        a2_scale,
        topk_weights,
        topk_ids,
        num_experts,
        expert_map,
        apply_router_weight_on_input,
        quant_config,
    )
    recv_hook, finish = pending
    # The hook must run before the receiver, matching the async call order.
    recv_hook()
    return finish()
def _finalize(
self,
......
......@@ -4,13 +4,15 @@ from abc import ABC, abstractmethod
from enum import Enum
from math import prod
from typing import Optional, final
from dataclasses import dataclass
from collections.abc import Callable
import torch
import vllm.envs as envs
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.utils import cdiv
from vllm.utils import cdiv, async_tensor_h2d
#
# This file defines a set of base classes used to make MoE kernels more modular.
......@@ -95,6 +97,57 @@ class FusedMoEActivationFormat(Enum):
BatchedExperts = "batched_experts",
@dataclass
class ExpertTokensMetadata:
"""
Metadata regarding expert-token routing.
"""
expert_num_tokens: torch.Tensor
expert_num_tokens_cpu: torch.Tensor | None
@staticmethod
def make_from_list(
expert_num_tokens_list: list[int], device: str
) -> "ExpertTokensMetadata":
# expert_num_tokens_cpu = torch.tensor(
# expert_num_tokens_list, device="cpu", dtype=torch.int32
# )
expert_num_tokens_cpu = torch.tensor(
expert_num_tokens_list, device="cpu", dtype=torch.int32, pin_memory=True
)
expert_num_tokens = expert_num_tokens_cpu.to(device=device, non_blocking=True)
return ExpertTokensMetadata(
expert_num_tokens=expert_num_tokens,
expert_num_tokens_cpu=expert_num_tokens_cpu,
)
#
# PrepareResultType is a tuple of:
# - quantized + dispatched a.
# - quantized + dispatched a1_scales.
# - Optional ExpertTokensMetadata containing gpu/cpu tensors
# as big as the number of local experts with the information about the
# number of tokens assigned to each local expert.
# - Optional dispatched expert topk IDs
# - Optional dispatched expert topk weight
#
# See `prepare` method below.
#
PrepareResultType = tuple[
torch.Tensor,
torch.Tensor | None,
ExpertTokensMetadata | None,
torch.Tensor | None,
torch.Tensor | None,
]
ReceiverType = Callable[[], PrepareResultType]
# TODO: pass FusedMoEParallelConfig in as ctor parameter?
class FusedMoEPrepareAndFinalize(ABC):
"""
......@@ -880,8 +933,19 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
if global_num_experts == -1:
global_num_experts = local_num_experts
(a1q, a1q_scale, expert_num_tokens, _expert_topk_ids,
_expert_topk_weights) = self.prepare_finalize.prepare(
# (a1q, a1q_scale, expert_num_tokens, _expert_topk_ids,
# _expert_topk_weights) = self.prepare_finalize.prepare(
# a1,
# a1_scale,
# a2_scale,
# topk_weights,
# topk_ids,
# global_num_experts,
# expert_map,
# apply_router_weight_on_input,
# self.fused_experts.quant_config,
# )
prepare_ret = self.prepare_finalize.prepare_async(
a1,
a1_scale,
a2_scale,
......@@ -892,12 +956,35 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
apply_router_weight_on_input,
self.fused_experts.quant_config,
)
hook, receiver = (
prepare_ret if isinstance(prepare_ret, tuple) else (None, prepare_ret)
)
if hook is not None:
hook()
(
a1q,
a1q_scale,
expert_tokens_meta,
_expert_topk_ids,
_expert_topk_weights,
) = receiver()
# Maybe prepare gathered topk_ids and topk_weights from other EP ranks.
topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids
topk_weights = (topk_weights if _expert_topk_weights is None else
_expert_topk_weights)
if a1q.numel() == 0:
# This happens when none of the tokens from the all2all reach this
# EP rank. Also, note that this is only relevant for CUDAGraph
# incompatible all2all kernels like the DeepEP high-throughput
# kernels. CUDAGraph compatible all2all kernels like the pplx
# kernels and the DeepEP low-latency kernels are always batched
# and can never run into the tensor.numel() == 0 case.
fused_out = torch.empty_like(a1q).to(dtype=a1.dtype)
else:
fused_out = self.fused_experts.apply(
None,
a1,
......@@ -918,18 +1005,15 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
workspace13=None,
workspace2=None,
use_nn_moe=use_nn_moe,
expert_num_tokens=expert_num_tokens,
expert_num_tokens=expert_tokens_meta.expert_num_tokens,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
expert_num_tokens_cpu=expert_tokens_meta.expert_num_tokens_cpu
)
shared_output = None
if self.shared_experts is None:
self.prepare_finalize.finalize(output, fused_out, topk_weights,
topk_ids, apply_router_weight_on_input, apply_weights_and_reduce=False)
else:
hook = self.prepare_finalize.finalize_async(output, fused_out, topk_weights,
topk_ids, apply_router_weight_on_input, apply_weights_and_reduce=False)
topk_ids, apply_router_weight_on_input, apply_weights_and_reduce=True)
if self.shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
......
......@@ -85,6 +85,7 @@ class TritonOrGroupGemmExperts(mk.CustomizedFusedMoEPermuteExpertsUnpermute):
use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None,
expert_num_tokens_cpu: torch.Tensor = None,
):
assert self.fused_experts is not None
......@@ -107,4 +108,5 @@ class TritonOrGroupGemmExperts(mk.CustomizedFusedMoEPermuteExpertsUnpermute):
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
q_x=q_hidden_states,
expert_num_tokens_cpu=expert_num_tokens_cpu
)
......@@ -11,6 +11,7 @@ from triton.language.extra import libdevice
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
from vllm.utils import round_up
try:
from lmslim.layers.gemm.int8_utils import (
per_token_group_quant_int8, per_token_quant_int8)
......@@ -276,8 +277,8 @@ def _int8_quantize(
# activations apply per-token quantization. Otherwise, assume
# activation tensor-wise fp8/int8 quantization, dynamic or static
if block_shape is None:
assert per_act_token, \
"int8 quantization only supports block or channel-wise"
# assert per_act_token, \
# "int8 quantization only supports block or channel-wise"
if expert_num_tokens is None:
A, A_scale = per_token_quant_int8(A)
else:
......@@ -361,3 +362,502 @@ def _validate_scale_shape(
assert block_shape is not None
expected = (a.shape[0], cdiv(a.shape[1], block_shape[1]))
assert a_scale.shape == expected, f"{a_scale.shape} == {expected}"
@triton.jit
def _count_expert_num_tokens(
    topk_ids_ptr,
    expert_num_tokens_ptr,
    num_experts,
    topk_numel,
    expert_map,
    HAS_EXPERT_MAP: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Count how many entries of topk_ids refer to this program's expert.

    Each program handles one expert id (its index along grid axis 0), scans
    all ``topk_numel`` ids in chunks of ``BLOCK_SIZE`` and writes the match
    count to ``expert_num_tokens_ptr[expert]``.
    """
    curr_expert = tl.program_id(0)
    offsets = tl.arange(0, BLOCK_SIZE)
    topk_ids_ptrs = topk_ids_ptr + offsets
    # Per-lane match counters; reduced to a single value at the end.
    acc = tl.zeros((BLOCK_SIZE,), dtype=tl.int32)
    for x in range(tl.cdiv(topk_numel, BLOCK_SIZE)):
        # Lanes past the end of the input read as -1, which never matches a
        # program index.
        mask = offsets < (topk_numel - x * BLOCK_SIZE)
        expert_ids = tl.load(topk_ids_ptrs, mask=mask, other=-1)
        if HAS_EXPERT_MAP:
            # Translate ids through expert_map. Negative ids are not looked
            # up and stay -1.
            expert_map_ptrs = expert_map + expert_ids
            expert_map_mask = expert_ids >= 0
            expert_ids = tl.load(expert_map_ptrs, mask=expert_map_mask, other=-1)
        has_curr_expert = tl.where(expert_ids == curr_expert, 1, 0)
        acc = acc + has_curr_expert
        topk_ids_ptrs += BLOCK_SIZE
    # Only programs that map to a valid expert slot write a result.
    if curr_expert < num_experts:
        tl.store(expert_num_tokens_ptr + curr_expert, tl.sum(acc))
def count_expert_num_tokens(
topk_ids: torch.Tensor, num_local_experts: int, expert_map: torch.Tensor | None
) -> torch.Tensor:
"""
Count the number to tokens assigned to each expert.
Parameters:
- topk_ids (torch.Tensor): Tensor mapping each token to its
list of experts.
- num_local_experts (int): Number of experts in this rank.
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
from the global expert space to the local expert space of the expert
parallel shard.
Returns:
A tensor of size num_local_experts, where tensor[i] holds the number
of tokens assigned to the ith expert.
"""
assert topk_ids.dtype.is_signed, "The kernel uses -1 to represent invalid topk_ids"
expert_num_tokens = torch.empty(
(num_local_experts), device=topk_ids.device, dtype=torch.int32
)
grid = num_local_experts
BLOCK_SIZE = min(topk_ids.numel(), 1024)
BLOCK_SIZE = triton.next_power_of_2(BLOCK_SIZE)
_count_expert_num_tokens[(grid,)](
topk_ids,
expert_num_tokens,
num_local_experts,
topk_ids.numel(),
expert_map,
HAS_EXPERT_MAP=expert_map is not None,
BLOCK_SIZE=BLOCK_SIZE,
)
return expert_num_tokens
def expert_num_tokens_round_up_and_sum(
    expert_num_tokens: torch.Tensor, alignment: int
) -> int:
    """Return the sum of all counts after padding each to ``alignment``.

    Each element of ``expert_num_tokens`` is rounded up to the nearest
    multiple of ``alignment`` (in int64) before summing. The result is
    returned as a Python scalar via ``.item()``.
    """
    counts = expert_num_tokens.to(torch.int64)
    # Number of alignment-sized blocks needed to hold each count.
    blocks_needed = (counts + (alignment - 1)) // alignment
    padded_counts = blocks_needed * alignment
    return padded_counts.sum().item()
def compute_aligned_M(
    M: int,
    num_topk: int,
    local_num_experts: int,
    alignment: int,
    expert_num_tokens_cpu: Optional[torch.Tensor] = None,
):
    """Compute the row count after per-expert padding to ``alignment``.

    When ``expert_num_tokens_cpu`` is given, the exact padded total is
    computed from it with ``expert_num_tokens_round_up_and_sum``. Otherwise
    an upper bound is returned: ``M * num_topk`` rows plus up to
    ``alignment - 1`` padding rows per local expert, rounded up to a
    multiple of ``alignment``.
    """
    if expert_num_tokens_cpu is None:
        # expert_num_tokens information is not available on the cpu.
        # compute the max required size.
        worst_case_rows = (M * num_topk) + local_num_experts * (alignment - 1)
        return round_up(worst_case_rows, alignment)

    return expert_num_tokens_round_up_and_sum(
        expert_num_tokens_cpu, alignment=alignment
    )
@triton.jit
def apply_expert_map(expert_id, expert_map):
    """Translate ``expert_id`` through ``expert_map``; -1 passes through.

    For any id other than -1 the value stored at ``expert_map + expert_id``
    is loaded and cast back to the dtype of ``expert_id``.
    """
    if expert_id != -1:
        expert_id = tl.load(expert_map + expert_id).to(expert_id.dtype)
    return expert_id
@triton.jit
def round_up_256(x: int) -> int:
    """Round ``x`` up to the nearest multiple of 256.

    Despite the ``int`` annotation, ``_fwd_kernel_ep_scatter_1`` calls this
    on a block of per-expert token counts.
    """
    y = 256
    return ((x + y - 1) // y) * y
@triton.jit
def round_up_128(x: int) -> int:
    """Round ``x`` up to the nearest multiple of 128."""
    y = 128
    return ((x + y - 1) // y) * y
@triton.jit
def _fwd_kernel_ep_scatter_1(
    num_recv_tokens_per_expert,
    expert_start_loc,
    m_indices,
    num_experts: tl.constexpr,
    BLOCK_E: tl.constexpr,
    BLOCK_EXPERT_NUM: tl.constexpr,
):
    """Compute each expert's start row and label its rows in ``m_indices``.

    Each program handles one expert (its index along grid axis 0). It
    writes the per-expert start offsets into ``expert_start_loc`` and then
    fills that expert's real (unpadded) rows of ``m_indices`` with the
    expert id.
    """
    cur_expert = tl.program_id(0)
    offset_cumsum = tl.arange(0, BLOCK_EXPERT_NUM)
    # Load the token counts of all experts; lanes beyond num_experts read 0.
    tokens_per_expert = tl.load(
        num_recv_tokens_per_expert + offset_cumsum,
        mask=offset_cumsum < num_experts,
        other=0,
    )
    #tokens_per_expert = round_up_128(tokens_per_expert)
    # Pad every expert's count to a multiple of 256 before computing offsets.
    tokens_per_expert = round_up_256(tokens_per_expert)
    # Exclusive prefix sum: start offset of each expert's padded region.
    cumsum = tl.cumsum(tokens_per_expert) - tokens_per_expert
    #if cur_expert == 0:
    # Every program stores the full offset table (the guard above is
    # disabled), so all programs write the same values.
    tl.store(expert_start_loc + offset_cumsum, cumsum, mask=offset_cumsum < num_experts)
    # NOTE(review): presumably makes the store above visible within this
    # program before the reload below -- confirm.
    tl.debug_barrier()
    #cur_expert_start = cumsum[cur_expert]
    cur_expert_start = tl.load(expert_start_loc + cur_expert)
    # Unpadded token count of this program's expert.
    cur_expert_token_num = tl.load(num_recv_tokens_per_expert + cur_expert)
    m_indices_start_ptr = m_indices + cur_expert_start
    off_expert = tl.arange(0, BLOCK_E)
    for start_m in tl.range(0, cur_expert_token_num, BLOCK_E, num_stages=4):
        # Only real token rows are labelled; padding rows in this expert's
        # region are not written by this kernel.
        tl.store(
            m_indices_start_ptr + start_m + off_expert,
            cur_expert,
            mask=start_m + off_expert < cur_expert_token_num
        )
@triton.jit
def _fwd_kernel_ep_scatter_2(
    total_token_num,
    expert_start_loc,
    recv_x,
    recv_x_stride0,
    recv_x_stride1,
    recv_x_scale,
    recv_x_scale_stride0,
    recv_x_scale_stride1,
    recv_topk,
    recv_topk_stride0,
    recv_topk_stride1,
    output_tensor,
    output_tensor_stride0,
    output_tensor_stride1,
    output_tensor_scale,
    output_tensor_scale_stride0,
    output_tensor_scale_stride1,
    output_index,
    output_index_stride0,
    output_index_stride1,
    topk_num: tl.constexpr,
    expert_map,
    HAS_EXPERT_MAP: tl.constexpr,
    HIDDEN_SIZE: tl.constexpr,
    HIDDEN_SIZE_PAD: tl.constexpr,
    SCALE_HIDDEN_SIZE: tl.constexpr,
    SCALE_HIDDEN_SIZE_PAD: tl.constexpr,
):
    """Copy each token row and its scales into its experts' output regions.

    For every (token, top-k slot) pair whose expert id is >= 0, the kernel
    claims the next free row of that expert by atomically incrementing
    ``expert_start_loc[expert]``, records the claimed row in
    ``output_index[token, slot]`` and copies the token's data and scales
    there. ``expert_start_loc`` is therefore modified by this kernel.

    NOTE(review): ``recv_x_stride1``, ``recv_topk_stride1``,
    ``output_tensor_stride1`` and ``output_index_stride1`` are accepted but
    not used. The corresponding accesses add the element offset directly,
    which assumes a unit stride along the last dimension -- confirm against
    callers.
    """
    start_token_id = tl.program_id(0)
    grid_num = tl.num_programs(0)
    offset_in = tl.arange(0, HIDDEN_SIZE_PAD)
    mask = offset_in < HIDDEN_SIZE
    index_in_s = tl.arange(0, SCALE_HIDDEN_SIZE_PAD)
    mask_s = index_in_s < SCALE_HIDDEN_SIZE
    # Each program handles tokens start_token_id, start_token_id + grid_num, ...
    for token_id_int32 in range(start_token_id, total_token_num, grid_num):
        # Use 64-bit indices for the pointer arithmetic below.
        token_id = token_id_int32.to(tl.int64)
        # Load the token's row and scales once; reused for every top-k slot.
        to_copy = tl.load(recv_x + token_id * recv_x_stride0 + offset_in, mask=mask)
        to_copy_s = tl.load(
            recv_x_scale
            + token_id * recv_x_scale_stride0
            + index_in_s * recv_x_scale_stride1,
            mask=mask_s,
        )
        for topk_idx_int32 in tl.range(0, topk_num, 1, num_stages=4):
            topk_index = topk_idx_int32.to(tl.int64)
            expert_id = tl.load(recv_topk + token_id * recv_topk_stride0 + topk_index)
            if HAS_EXPERT_MAP:
                expert_id = apply_expert_map(expert_id, expert_map)
            # Negative ids are skipped; output_index is left untouched for
            # those slots.
            if expert_id >= 0:
                # Claim the next free destination row for this expert.
                dest_token_index_int32 = tl.atomic_add(expert_start_loc + expert_id, 1)
                dest_token_index = dest_token_index_int32.to(tl.int64)
                tl.store(
                    output_index + token_id * output_index_stride0 + topk_index,
                    dest_token_index_int32,
                )
                output_tensor_ptr = (
                    output_tensor + dest_token_index * output_tensor_stride0
                )
                output_tensor_scale_ptr = (
                    output_tensor_scale + dest_token_index * output_tensor_scale_stride0
                )
                tl.store(output_tensor_ptr + offset_in, to_copy, mask=mask)
                tl.store(
                    output_tensor_scale_ptr + index_in_s * output_tensor_scale_stride1,
                    to_copy_s,
                    mask=mask_s,
                )
@torch.no_grad()
def ep_scatter(
    recv_x: torch.Tensor,
    recv_x_scale: torch.Tensor,
    recv_topk: torch.Tensor,
    num_recv_tokens_per_expert: torch.Tensor,
    expert_map: torch.Tensor | None,
    expert_start_loc: torch.Tensor,
    output_tensor: torch.Tensor,
    output_tensor_scale: torch.Tensor,
    m_indices: torch.Tensor,
    output_index: torch.Tensor,
):
    """Launch the two scatter kernels that group received tokens by expert.

    ``_fwd_kernel_ep_scatter_1`` runs with one program per expert and fills
    ``expert_start_loc`` and ``m_indices``. ``_fwd_kernel_ep_scatter_2`` then
    copies rows of ``recv_x`` / ``recv_x_scale`` into ``output_tensor`` /
    ``output_tensor_scale`` and records the destinations in ``output_index``.
    All results are written into the caller-provided tensors; nothing is
    returned.

    Raises:
        AssertionError: if ``m_indices.shape[0]`` is not a multiple of 256.
    """
    # Per-expert token counts are aligned to 256 rows.
    rows_per_expert_alignment = 256
    warps_per_program = 8

    expert_count = num_recv_tokens_per_expert.shape[0]
    hidden_dim = recv_x.shape[1]
    scale_dim = recv_x_scale.shape[-1]

    assert m_indices.shape[0] % rows_per_expert_alignment == 0

    # Stage 1: one program per expert.
    _fwd_kernel_ep_scatter_1[(expert_count,)](
        num_recv_tokens_per_expert,
        expert_start_loc,
        m_indices,
        num_experts=expert_count,
        num_warps=warps_per_program,
        BLOCK_E=rows_per_expert_alignment,
        BLOCK_EXPERT_NUM=triton.next_power_of_2(expert_count),
    )

    # Stage 2: at most 8192 programs, each handling several tokens if needed.
    received_tokens = recv_topk.shape[0]
    scatter_programs = min(received_tokens, 1024 * 8)
    _fwd_kernel_ep_scatter_2[(scatter_programs,)](
        received_tokens,
        expert_start_loc,
        recv_x,
        recv_x.stride(0),
        recv_x.stride(1),
        recv_x_scale,
        recv_x_scale.stride(0),
        recv_x_scale.stride(1),
        recv_topk,
        recv_topk.stride(0),
        recv_topk.stride(1),
        output_tensor,
        output_tensor.stride(0),
        output_tensor.stride(1),
        output_tensor_scale,
        output_tensor_scale.stride(0),
        output_tensor_scale.stride(1),
        output_index,
        output_index.stride(0),
        output_index.stride(1),
        topk_num=recv_topk.shape[1],
        expert_map=expert_map,
        HAS_EXPERT_MAP=expert_map is not None,
        num_warps=warps_per_program,
        HIDDEN_SIZE=hidden_dim,
        HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_dim),
        SCALE_HIDDEN_SIZE=scale_dim,
        SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(scale_dim),
    )
@triton.jit
def _fwd_kernel_ep_gather(
    total_token_num,
    input_tensor,
    input_tensor_stride0,
    input_tensor_stride1,
    recv_topk_ids,
    recv_topk_ids_stride0,
    recv_topk_ids_stride1,
    recv_topk_weight,
    recv_topk_weight_stride0,
    recv_topk_weight_stride1,
    input_index,
    input_index_stride0,
    input_index_stride1,
    output_tensor,
    output_tensor_stride0,
    output_tensor_stride1,
    topk_num: tl.constexpr,
    expert_map,
    HAS_EXPERT_MAP: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """Weighted gather-and-sum of expert output rows back into token order.

    Program axis 0 selects a column block of width ``BLOCK_D``; program axis 1
    selects the first token, after which the program advances by the number of
    programs on axis 1 until ``total_token_num`` is reached.

    For every token, each of its ``topk_num`` slots is visited. Slots whose
    (optionally remapped) expert id is negative are skipped. For the others,
    the row index stored in ``input_index`` is used to load a ``BLOCK_D`` slice
    of ``input_tensor``, which is multiplied by the matching entry of
    ``recv_topk_weight`` and accumulated in float32. The sum is cast to the
    element type of ``output_tensor`` and stored.

    Notes:
        * Only the ``*_stride0`` arguments are used; the ``*_stride1``
          arguments are accepted but never read, so the last dimension is
          addressed with unit stride.
        * Loads and stores carry no mask, so every ``BLOCK_D`` slice touched
          must lie fully inside a row.
    """
    # Column block handled by this program, widened to int64 for addressing.
    cur_block_int32 = tl.program_id(0)
    cur_block = cur_block_int32.to(tl.int64)
    start_cur_token_int32 = tl.program_id(1)
    grid_num = tl.num_programs(1)
    for cur_token_int32 in range(start_cur_token_int32, total_token_num, grid_num):
        cur_token = cur_token_int32.to(tl.int64)
        off_d = tl.arange(0, BLOCK_D)
        # Accumulate in float32 regardless of the input element type.
        accumulator = tl.zeros([BLOCK_D], dtype=tl.float32)
        for topk_index_int32 in range(0, topk_num):
            topk_index = topk_index_int32.to(tl.int64)
            expert_id = tl.load(
                recv_topk_ids + cur_token * recv_topk_ids_stride0 + topk_index
            )
            if HAS_EXPERT_MAP:
                expert_id = apply_expert_map(expert_id, expert_map)
            # Negative ids mark slots that contribute nothing to this token.
            if expert_id >= 0:
                # Row of input_tensor holding this (token, slot) result.
                source_token_index_int32 = tl.load(
                    input_index + cur_token * input_index_stride0 + topk_index
                )
                source_token_index = source_token_index_int32.to(tl.int64)
                acc_weight = tl.load(
                    recv_topk_weight + cur_token * recv_topk_weight_stride0 + topk_index
                )
                tmp = tl.load(
                    input_tensor
                    + source_token_index * input_tensor_stride0
                    + cur_block * BLOCK_D
                    + off_d
                )
                accumulator += tmp.to(tl.float32) * acc_weight
        # Written even when every slot was skipped (the accumulator is zero).
        tl.store(
            output_tensor
            + cur_token * output_tensor_stride0
            + cur_block * BLOCK_D
            + off_d,
            accumulator.to(output_tensor.dtype.element_ty),
        )
@torch.no_grad()
def ep_gather(
input_tensor: torch.Tensor,
recv_topk_ids: torch.Tensor,
recv_topk_weight: torch.Tensor,
input_index: torch.Tensor,
expert_map: torch.Tensor | None,
output_tensor: torch.Tensor,
):
num_warps = 2
num_tokens = output_tensor.shape[0]
hidden_size = input_tensor.shape[1]
BLOCK_D = min(hidden_size, 1024)
assert hidden_size % BLOCK_D == 0
grid = (triton.cdiv(hidden_size, BLOCK_D), min(num_tokens, 1024))
_fwd_kernel_ep_gather[grid](
num_tokens,
input_tensor,
input_tensor.stride(0),
input_tensor.stride(1),
recv_topk_ids,
recv_topk_ids.stride(0),
recv_topk_ids.stride(1),
recv_topk_weight,
recv_topk_weight.stride(0),
recv_topk_weight.stride(1),
input_index,
input_index.stride(0),
input_index.stride(1),
output_tensor,
output_tensor.stride(0),
output_tensor.stride(1),
topk_num=recv_topk_ids.shape[1],
expert_map=expert_map,
HAS_EXPERT_MAP=expert_map is not None,
num_warps=num_warps,
BLOCK_D=BLOCK_D,
)
return
def deepgemm_moe_permute(
aq: torch.Tensor,
aq_scale: torch.Tensor,
topk_ids: torch.Tensor,
local_num_experts: int,
expert_map: torch.Tensor | None,
block_shape: list[int],
expert_num_tokens: Optional[torch.Tensor] = None,
expert_num_tokens_cpu: Optional[torch.Tensor] = None,
aq_out: torch.Tensor | None = None,
M_sum: int | None = None,
):
assert aq.ndim == 2
assert topk_ids.dtype.is_signed, "The kernel uses -1 to represent invalid topk_ids"
H = aq.size(1)
device = aq.device
block_m = block_shape[0]
if M_sum is None:
M_sum = compute_aligned_M(
M=topk_ids.size(0),
num_topk=topk_ids.size(1),
local_num_experts=local_num_experts,
alignment=block_m,
expert_num_tokens_cpu=expert_num_tokens_cpu,
)
expert_start_loc = torch.empty(
(local_num_experts), device=device, dtype=torch.int32
)
assert aq_out is None or aq_out.shape == (M_sum, H)
if aq_out is None:
aq_out = torch.empty((M_sum, H), device=device, dtype=aq.dtype)
aq_scale_out = torch.empty(
(M_sum, aq_scale.shape[-1]), device=device, dtype=torch.float32
#(M_sum, H // block_k), device=device, dtype=torch.float32
)
# maybe_has_empty_blocks = expert_num_tokens_cpu is None
# expert_ids_init = torch.zeros# if maybe_has_empty_blocks else torch.empty
# expert_ids = expert_ids_init((M_sum), device=device, dtype=torch.int32)
expert_ids = torch.full(
(M_sum,), -1, dtype=torch.int32, device=device
)
inv_perm = torch.empty(topk_ids.shape, device=device, dtype=torch.int32)
if expert_num_tokens is None:
expert_num_tokens = count_expert_num_tokens(
topk_ids, local_num_experts, expert_map
)
ep_scatter(
recv_x=aq,
recv_x_scale=aq_scale,
recv_topk=topk_ids,
num_recv_tokens_per_expert=expert_num_tokens,
expert_start_loc=expert_start_loc,
expert_map=expert_map,
output_tensor=aq_out,
output_tensor_scale=aq_scale_out,
m_indices=expert_ids,
output_index=inv_perm,
)
return aq_out, aq_scale_out, expert_ids, inv_perm
def deepgemm_unpermute_and_reduce(
    a: torch.Tensor,  # Grouped gemm output
    topk_ids: torch.Tensor,
    topk_weights: torch.Tensor,
    inv_perm: torch.Tensor,
    expert_map: torch.Tensor | None,
    output: torch.Tensor,
):
    """Undo the expert permutation and reduce over top-k into ``output``.

    Thin adapter around ``ep_gather`` that maps this function's argument names
    onto the gather kernel's wrapper.

    Args:
        a: Grouped GEMM output in expert-contiguous row order.
        topk_ids: Expert ids per token.
        topk_weights: Routing weights per token.
        inv_perm: Row index in ``a`` for each (token, slot) pair.
        expert_map: Optional expert remapping table.
        output: Destination tensor written in place.

    Returns:
        Whatever ``ep_gather`` returns.
    """
    gather_kwargs = dict(
        input_tensor=a,
        recv_topk_ids=topk_ids,
        recv_topk_weight=topk_weights,
        input_index=inv_perm,
        expert_map=expert_map,
        output_tensor=output,
    )
    result = ep_gather(**gather_kwargs)
    return result
......@@ -19,12 +19,15 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoEConfig, FusedMoeWeightScaleSupported,
FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize,)
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.model_executor.layers.fused_moe.utils import _resize_cache, compute_aligned_M, deepgemm_moe_permute, deepgemm_unpermute_and_reduce
from vllm.model_executor.layers.quantization.utils.w8a8_utils import(
get_w8a8_int8_marlin_weights, w8a8_nt_kpack2_marlin_weight)
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
_moe_permute)
from vllm.utils import round_up
try:
from lightop import m_grouped_w8a8_gemm_nt_masked, fuse_silu_mul_quant_ep
from lightop import m_grouped_w8a8_gemm_nt_masked, m_grouped_w8a8_gemm_nt_contig_asm, fuse_silu_mul_quant_ep, fuse_silu_mul_quant
from lmslim.layers.fused_moe.fuse_moe_int8_marlin import fused_experts_impl_int8_marlin
except Exception:
print("INFO: Please install lmslim if you want to infer the quantitative model of moe.\n")
......@@ -84,26 +87,27 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
parallel_config = vllm_config.parallel_config
self.dp_size = get_dp_group().world_size
self.ep_size = get_ep_group().world_size
backend = envs.VLLM_ALL2ALL_BACKEND
self.use_deepep = self.dp_size > 1 and parallel_config.enable_expert_parallel and \
(backend == "deepep_high_throughput" or \
backend == "deepep_low_latency" or \
backend == "deepep_auto")
(envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_auto")
self.use_deepep_ll = self.use_deepep and (backend == "deepep_low_latency" or \
(backend == "deepep_auto"))
#self.use_deepep_ll = self.use_deepep and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency"
if self.use_deepep:
all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None
self.num_dispatchers = all2all_manager.world_size
self.block_shape = [256, 256]
self.use_deepgemm = envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency" or envs.VLLM_ENABLE_DEEPEP_HT_DEEPGEMM or envs.VLLM_ALL2ALL_BACKEND == "deepep_auto"
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
if self.use_deepep_ll:
if self.use_deepep:
self.N = 2 * intermediate_size_per_partition
self.K = hidden_size
......@@ -157,7 +161,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
w1_marlin_list = []
for ii in range(layer.w13_weight.shape[0]):
if not self.use_deepep_ll:
if not self.use_deepgemm:
w1_marlin_in = get_w8a8_int8_marlin_weights(layer.w13_weight[ii])
else:
w1_marlin_in = w8a8_nt_kpack2_marlin_weight(layer.w13_weight[ii])
......@@ -168,7 +172,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
del w1_marlin_list
w2_marlin_list = []
for ii in range(layer.w2_weight.shape[0]):
if not self.use_deepep_ll:
if not self.use_deepgemm:
w2_marlin_in = get_w8a8_int8_marlin_weights(layer.w2_weight[ii])
else:
w2_marlin_in = w8a8_nt_kpack2_marlin_weight(layer.w2_weight[ii])
......@@ -178,7 +182,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
layer.w13_weight = Parameter(w1_marlin, requires_grad=False)
layer.w2_weight = Parameter(w2_marlin, requires_grad=False)
def groupgemm_workspace_shapes(self,
def masked_groupgemm_workspace_shapes(self,
a: torch.Tensor,
aq: torch.Tensor,
M: int,
......@@ -201,7 +205,26 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
output = (num_experts, max_num_tokens * num_dispatchers, K)
return (workspace13, workspace2, output, a.dtype)
def w8a8_groupgemm_forward(self,
def contiguous_groupgemm_workspace_shapes(
    self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int,
    topk: int, global_num_experts: int, local_num_experts: int,
    expert_num_tokens_cpu: Optional[torch.Tensor]
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype, int]:
    """Compute buffer shapes for the contiguous grouped-GEMM forward pass.

    Args:
        a: Unquantized activations; only its dtype is used.
        aq: Quantized activations (unused here).
        M: Number of tokens.
        N: Output width of the first GEMM.
        K: Hidden size.
        topk: Number of experts selected per token.
        global_num_experts: Unused here; kept for signature parity.
        local_num_experts: Number of experts handled locally.
        expert_num_tokens_cpu: Optional per-expert token counts, forwarded to
            ``compute_aligned_M``.

    Returns:
        ``(workspace1, workspace2, output, dtype, M_sum)``: three shape tuples,
        the workspace dtype, and the padded row count ``M_sum``.
    """
    assert self.block_shape is not None
    # Row alignment comes from the first block dimension; the padded row
    # count is derived from the number of *local* experts.
    block_m = self.block_shape[0]
    M_sum = compute_aligned_M(
        M, topk, local_num_experts, block_m, expert_num_tokens_cpu
    )
    assert M_sum % block_m == 0

    workspace1 = (M_sum, max(N, K))
    workspace2 = (M_sum, max(N // 2, K))
    output = (M, K)
    return (workspace1, workspace2, output, a.dtype, M_sum)
def w8a8_groupgemm_masked_forward(self,
x: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
......@@ -220,6 +243,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
routed_scaling_factor: Optional[float] = None,
shared_output: Optional[torch.Tensor] = None,
q_x: Optional[torch.Tensor] = None,
expert_num_tokens_cpu: Optional[torch.Tensor] = None,
**_ ):
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
......@@ -230,7 +254,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
N, K = self.N, self.K
(workspace13_shape, workspace2_shape, fused_out_shape,
workspace_dtype) = self.groupgemm_workspace_shapes(
workspace_dtype) = self.masked_groupgemm_workspace_shapes(
x, q_x, max_num_tokens, N, K, top_k, global_num_experts,
local_num_experts)
workspace13 = torch.empty(prod(workspace13_shape),
......@@ -269,6 +293,94 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
return fused_out
def w8a8_groupgemm_contiguous_forward(self,
                                      x: torch.Tensor,
                                      w1: torch.Tensor,
                                      w2: torch.Tensor,
                                      topk_ids: torch.Tensor,
                                      topk_weights: torch.Tensor,
                                      global_num_experts: int = -1,
                                      expert_map: Optional[torch.Tensor] = None,
                                      apply_router_weight_on_input: bool = False,
                                      activation: str = "silu",
                                      w1_scale: Optional[torch.Tensor] = None,
                                      w2_scale: Optional[torch.Tensor] = None,
                                      a1_scale: Optional[torch.Tensor] = None,
                                      a2_scale: Optional[torch.Tensor] = None,
                                      expert_num_tokens: Optional[torch.Tensor] = None,
                                      use_nn_moe: Optional[bool] = False,
                                      routed_scaling_factor: Optional[float] = None,
                                      shared_output: Optional[torch.Tensor] = None,
                                      q_x: Optional[torch.Tensor] = None,
                                      expert_num_tokens_cpu: Optional[torch.Tensor] = None,
                                      **_):
    """Run the W8A8 MoE experts with the contiguous grouped-GEMM kernels.

    Pipeline: permute the quantized activations into expert-contiguous order,
    run the first grouped GEMM, apply the fused SiLU-mul-quant step, run the
    second grouped GEMM, then un-permute and reduce over top-k.

    Args:
        x: Unquantized activations; used for device and workspace dtype.
        w1, w2: Expert weights; ``w1.size(0)`` is the local expert count.
        topk_ids, topk_weights: Routing decision per token.
        global_num_experts: Forwarded to the workspace-shape helper.
        expert_map: Optional expert remapping table.
        apply_router_weight_on_input: When true, the reduction uses weights
            of one instead of ``topk_weights``.
        activation: Accepted for interface parity; not consulted here.
        w1_scale, w2_scale: Weight scales passed to the GEMM kernels.
        a1_scale: Scales of ``q_x``; required.
        expert_num_tokens, expert_num_tokens_cpu: Optional per-expert counts.
        q_x: Quantized activations; required.
        a2_scale, use_nn_moe, routed_scaling_factor, shared_output: Accepted
            for interface parity; not used by this path.

    Returns:
        Tensor of shape ``(num_tokens, K)`` holding the combined expert output.

    Raises:
        ValueError: If ``q_x`` or ``a1_scale`` is missing.
    """
    # Both are dereferenced unconditionally below; fail with a clear message
    # instead of an AttributeError on None.
    if q_x is None or a1_scale is None:
        raise ValueError(
            "w8a8_groupgemm_contiguous_forward requires quantized activations "
            "(q_x) and their scales (a1_scale)")

    local_num_experts = w1.size(0)
    a1q = q_x
    N, K = self.N, self.K

    (workspace13_shape, workspace2_shape, fused_out_shape,
     workspace_dtype, M_sum) = self.contiguous_groupgemm_workspace_shapes(
         x, q_x, topk_ids.size(0), N, K, topk_ids.size(1), global_num_experts,
         local_num_experts, expert_num_tokens_cpu)

    workspace13 = torch.empty(prod(workspace13_shape),
                              device=x.device,
                              dtype=workspace_dtype)
    workspace2 = torch.empty(prod(workspace2_shape),
                             device=x.device,
                             dtype=workspace_dtype)

    # Buffers are deliberately shared: mm1_out and fused_out live in
    # workspace13, mm2_out and a1q_perm live in workspace2. Each earlier
    # tenant is consumed before the later one is written (statement order
    # below).
    mm1_out = _resize_cache(workspace13, (M_sum, N))
    mm2_out = _resize_cache(workspace2, (M_sum, K))
    fused_out = _resize_cache(workspace13, fused_out_shape)
    a1q_perm = _resize_cache(workspace2.view(dtype=a1q.dtype), (M_sum, K))

    a1q, a1q_scale, expert_ids, inv_perm = deepgemm_moe_permute(
        aq=a1q,
        aq_scale=a1_scale,
        topk_ids=topk_ids,
        local_num_experts=local_num_experts,
        expert_map=expert_map,
        block_shape=self.block_shape,
        expert_num_tokens=expert_num_tokens,
        expert_num_tokens_cpu=expert_num_tokens_cpu,
        aq_out=a1q_perm,
        M_sum=M_sum,
    )

    m_grouped_w8a8_gemm_nt_contig_asm(
        (a1q, a1q_scale), (w1, w1_scale), mm1_out, expert_ids)

    a2q, a2q_scale = fuse_silu_mul_quant(mm1_out)

    m_grouped_w8a8_gemm_nt_contig_asm(
        (a2q, a2q_scale), (w2, w2_scale), mm2_out, expert_ids)

    # Router weights were already applied upstream; reduce with unit weights.
    if apply_router_weight_on_input:
        topk_weights = torch.ones_like(topk_weights)

    deepgemm_unpermute_and_reduce(
        a=mm2_out,
        topk_ids=topk_ids,
        topk_weights=topk_weights,
        inv_perm=inv_perm,
        expert_map=expert_map,
        output=fused_out,
    )
    return fused_out
def fused_moe_forward(self,
x: torch.Tensor,
......@@ -289,6 +401,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
routed_scaling_factor: Optional[float] = None,
shared_output: Optional[torch.Tensor] = None,
q_x: Optional[torch.Tensor] = None,
expert_num_tokens_cpu: Optional[torch.Tensor] = None,
**_ ):
return fused_experts_impl_int8_marlin(
hidden_states=x if q_x is None else q_x,
......@@ -401,7 +514,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
return TritonOrGroupGemmExperts(
use_int8_w8a8=True,
per_act_token_quant=True,
fused_experts=self.w8a8_groupgemm_forward
fused_experts=self.w8a8_groupgemm_masked_forward
)
else:
logger.debug(
......@@ -410,5 +523,6 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
False)
return TritonOrGroupGemmExperts(
fused_experts=self.fused_moe_forward
use_int8_w8a8=envs.VLLM_ENABLE_DEEPEP_HT_DEEPGEMM,
fused_experts=self.w8a8_groupgemm_contiguous_forward if envs.VLLM_ENABLE_DEEPEP_HT_DEEPGEMM else self.fused_moe_forward
)
\ No newline at end of file
......@@ -168,6 +168,8 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
(envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency")
self.ep_size = get_ep_group().world_size
if self.use_deepep:
all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None
......@@ -352,7 +354,9 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
# (from deepgemm docs) : A value hint (which is a value on CPU)
# for the M expectation of each batch, correctly setting this value
# may lead to better performance.
expected_m = max_num_tokens
#expected_m = max_num_tokens
ori_bs = x.shape[0]
expected_m = ori_bs * self.ep_size
m_grouped_w4a8_gemm_nt_masked((q_x, a1_scale),
(w1, w1_scale),
......
......@@ -174,14 +174,12 @@ class DeepseekV2MoE(nn.Module):
dp_size = get_dp_group().world_size
self.use_mori_ep = parallel_config.enable_expert_parallel and dp_size > 1 and envs.VLLM_ALL2ALL_BACKEND == 'mori'
self.enable_expert_parallel = parallel_config.enable_expert_parallel
backend = envs.VLLM_ALL2ALL_BACKEND
self.use_deepep_ll = (
dp_size > 1
and parallel_config.enable_expert_parallel
and (backend == "deepep_low_latency" or backend == "deepep_auto")
)
self.use_deepep = dp_size > 1 and parallel_config.enable_expert_parallel and \
(envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_auto")
if not self.use_deepep_ll:
if not self.use_deepep:
moe_cls = FusedMoE if not self.use_mori_ep else MoriMoE
self.experts = moe_cls(
num_experts=config.n_routed_experts,
......@@ -254,7 +252,7 @@ class DeepseekV2MoE(nn.Module):
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
if not self.use_mori_ep and not self.use_deepep_ll:
if not self.use_mori_ep and not self.use_deepep:
if self.n_shared_experts is not None:
if envs.USE_FUSED_RMS_QUANT:
shared_output, new_resi = self.shared_experts(hidden_states, rms_weight, residual, update_hd=True)
......@@ -289,7 +287,7 @@ class DeepseekV2MoE(nn.Module):
final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor)
else:
if self.use_deepep_ll:
if self.use_deepep:
shared_output, final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
......@@ -721,12 +719,10 @@ class DeepseekV2DecoderLayer(nn.Module):
self.dp_size = get_dp_group().world_size
vllm_config = get_current_vllm_config()
parallel_config = vllm_config.parallel_config
backend = envs.VLLM_ALL2ALL_BACKEND
self.use_deepep_ll = (
self.dp_size > 1
and parallel_config.enable_expert_parallel
and (backend == "deepep_low_latency" or backend == "deepep_auto")
)
self.use_deepep = self.dp_size > 1 and parallel_config.enable_expert_parallel and \
(envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_auto")
self.tp_size = get_tensor_model_parallel_world_size()
if (config.n_routed_experts is not None
......@@ -855,7 +851,7 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states, residual)
if isinstance(self.mlp,
DeepseekV2MoE) and self.use_deepep_ll and self.tp_size > 1:
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1:
self.tp_rank = get_tensor_model_parallel_rank()
ori_bs = hidden_states.shape[0]
......@@ -868,7 +864,7 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states = self.mlp(hidden_states)
if isinstance(self.mlp,
DeepseekV2MoE) and self.use_deepep_ll and self.tp_size > 1:
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1:
hidden_states = tensor_model_parallel_all_gather(hidden_states, dim=0).contiguous()
hidden_states = hidden_states[:ori_bs, :].contiguous()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment