Commit 8943d3db authored by yangql's avatar yangql
Browse files

解决deep的auto冲突

parents 0d3ae2fc ab1acdce
...@@ -173,6 +173,7 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase): ...@@ -173,6 +173,7 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
if self.internode: if self.internode:
num_rdma_bytes = int(1e9/2) #1024 * 1024 * 1024 num_rdma_bytes = int(1e9/2) #1024 * 1024 * 1024
num_qps_per_rank = 30 #self.num_sms // 2 num_qps_per_rank = 30 #self.num_sms // 2
self.num_sms = 30
# import deep_ep # import deep_ep
# num_nvl_bytes, num_rdma_bytes = 0, 0 # num_nvl_bytes, num_rdma_bytes = 0, 0
...@@ -184,6 +185,7 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase): ...@@ -184,6 +185,7 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
else: else:
num_rdma_bytes = 0 num_rdma_bytes = 0
num_qps_per_rank = 1 num_qps_per_rank = 1
self.num_sms = 60
assert num_rdma_bytes is not None assert num_rdma_bytes is not None
assert num_qps_per_rank is not None assert num_qps_per_rank is not None
...@@ -192,6 +194,7 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase): ...@@ -192,6 +194,7 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
num_rdma_bytes=num_rdma_bytes, num_rdma_bytes=num_rdma_bytes,
low_latency_mode=False, low_latency_mode=False,
num_qps_per_rank=num_qps_per_rank, num_qps_per_rank=num_qps_per_rank,
allow_mnnvl=envs.VLLM_ALLOW_MNNVL,
explicitly_destroy=False) explicitly_destroy=False)
def get_handle(self, kwargs): def get_handle(self, kwargs):
......
...@@ -180,6 +180,7 @@ if TYPE_CHECKING: ...@@ -180,6 +180,7 @@ if TYPE_CHECKING:
VLLM_USE_PD_SPLIT: bool = False VLLM_USE_PD_SPLIT: bool = False
VLLM_USE_PP_BALANCE: bool = False VLLM_USE_PP_BALANCE: bool = False
VLLM_USE_ZERO_MTP: bool = False VLLM_USE_ZERO_MTP: bool = False
VLLM_ENABLE_DEEPEP_HT_DEEPGEMM: bool = True
def get_default_cache_root(): def get_default_cache_root():
return os.getenv( return os.getenv(
...@@ -1181,6 +1182,9 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1181,6 +1182,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_ZERO_MTP": "VLLM_USE_ZERO_MTP":
lambda: (os.getenv('VLLM_USE_ZERO_MTP', '1').lower() in lambda: (os.getenv('VLLM_USE_ZERO_MTP', '1').lower() in
("true", "1")), ("true", "1")),
"VLLM_ENABLE_DEEPEP_HT_DEEPGEMM":
lambda: (os.getenv('VLLM_ENABLE_DEEPEP_HT_DEEPGEMM', '1').lower() in
("true", "1")),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
...@@ -21,11 +21,13 @@ class DeepEPAutoPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -21,11 +21,13 @@ class DeepEPAutoPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
super().__init__() super().__init__()
self.ht_prepare_finalize = ht_prepare_finalize self.ht_prepare_finalize = ht_prepare_finalize
self.ll_prepare_finalize = ll_prepare_finalize self.ll_prepare_finalize = ll_prepare_finalize
self._current_phase = "decode" # default to prefill (HT) self._current_phase = "decode" # default to decode (LL)
def _get_current_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize: def _get_current_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize:
"""Get the appropriate prepare_finalize based on current phase.""" """Get the appropriate prepare_finalize based on current phase."""
# Try to infer phase from forward_context if available # Try to infer phase from forward_context if available:
# - 有 decode tokens -> 使用 LL (decode)
# - 否则默认 HT (prefill)
try: try:
forward_context = get_forward_context() forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata attn_metadata = forward_context.attn_metadata
...@@ -36,44 +38,60 @@ class DeepEPAutoPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -36,44 +38,60 @@ class DeepEPAutoPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
else: else:
attn_metadata = None attn_metadata = None
if attn_metadata is not None and hasattr(attn_metadata, 'num_prefill_tokens') and hasattr(attn_metadata, 'num_decode_tokens'): if attn_metadata is not None and hasattr(attn_metadata,
# Only use prefill mode when BOTH conditions are met: "num_decode_tokens"):
# 1. There are prefill tokens and no decode tokens # 只根据 decode tokens 判定:有 decode -> decode,否则 prefill
# 2. skip_cuda_graphs is True self._current_phase = ("decode"
is_prefill_tokens = attn_metadata.num_prefill_tokens > 0 and attn_metadata.num_decode_tokens == 0 if attn_metadata.num_decode_tokens > 0
skip_cuda_graphs = forward_context.skip_cuda_graphs else "prefill")
# Only use prefill (HT) when both conditions are satisfied
self._current_phase = "prefill" if (is_prefill_tokens and skip_cuda_graphs) else "decode"
except Exception: except Exception:
# If forward_context is not available, use stored phase # If forward_context is not available, use stored phase
pass pass
# Prefill uses HT, decode uses LL # Prefill uses HT, decode uses LL
# print("self._current_phase",self._current_phase) if self._current_phase == "prefill":
# if self._current_phase == "prefill": print("************prefill***********")
# return self.ht_prepare_finalize # return self.ht_prepare_finalize
# else: # else:
return self.ll_prepare_finalize # return self.ll_prepare_finalize
return self.ht_prepare_finalize
@property @property
def activation_format(self) -> mk.FusedMoEActivationFormat: def activation_format(self) -> mk.FusedMoEActivationFormat:
# Use the current prepare_finalize's activation format pf = self._get_current_prepare_finalize()
# Note: HT uses Standard, LL uses BatchedExperts try:
# Dynamically return based on current phase return pf.activation_format
prepare_finalize = self._get_current_prepare_finalize() except NotImplementedError:
return prepare_finalize.activation_format # Fallback to standard format if underlying impl does not provide it.
return mk.FusedMoEActivationFormat.Standard
def topk_indices_dtype(self) -> Optional[torch.dtype]: def topk_indices_dtype(self) -> Optional[torch.dtype]:
# Both HT and LL return int64 pf = self._get_current_prepare_finalize()
return torch.int64 return pf.topk_indices_dtype()
def max_num_tokens_per_rank(self) -> Optional[int]: def max_num_tokens_per_rank(self) -> Optional[int]:
# LL has a limit, HT returns None pf = self._get_current_prepare_finalize()
return self.ll_prepare_finalize.max_num_tokens_per_rank() return pf.max_num_tokens_per_rank()
def num_dispatchers(self) -> int: def num_dispatchers(self) -> int:
# Both should return the same value pf = self._get_current_prepare_finalize()
return self.ht_prepare_finalize.num_dispatchers() return pf.num_dispatchers()
def prepare_async(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
):
pf = self._get_current_prepare_finalize()
return pf.prepare_async(
a1, a1_scale, a2_scale, topk_weights, topk_ids,
num_experts, expert_map, apply_router_weight_on_input, quant_config)
def prepare( def prepare(
self, self,
...@@ -88,9 +106,8 @@ class DeepEPAutoPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -88,9 +106,8 @@ class DeepEPAutoPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
Optional[torch.Tensor], Optional[torch.Tensor]]: Optional[torch.Tensor], Optional[torch.Tensor]]:
"""Route prepare call to the appropriate implementation.""" pf = self._get_current_prepare_finalize()
prepare_finalize = self._get_current_prepare_finalize() return pf.prepare(
return prepare_finalize.prepare(
a1, a1_scale, a2_scale, topk_weights, topk_ids, a1, a1_scale, a2_scale, topk_weights, topk_ids,
num_experts, expert_map, apply_router_weight_on_input, quant_config) num_experts, expert_map, apply_router_weight_on_input, quant_config)
...@@ -103,9 +120,8 @@ class DeepEPAutoPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -103,9 +120,8 @@ class DeepEPAutoPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
apply_weights_and_reduce: bool = True apply_weights_and_reduce: bool = True
) -> None: ) -> None:
"""Route finalize call to the appropriate implementation.""" pf = self._get_current_prepare_finalize()
prepare_finalize = self._get_current_prepare_finalize() return pf.finalize(
return prepare_finalize.finalize(
output, fused_expert_output, topk_weights, topk_ids, output, fused_expert_output, topk_weights, topk_ids,
apply_router_weight_on_input, apply_weights_and_reduce) apply_router_weight_on_input, apply_weights_and_reduce)
...@@ -118,15 +134,11 @@ class DeepEPAutoPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -118,15 +134,11 @@ class DeepEPAutoPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
apply_weights_and_reduce: bool = True apply_weights_and_reduce: bool = True
): ):
"""Route finalize_async call to the appropriate implementation if available.""" pf = self._get_current_prepare_finalize()
prepare_finalize = self._get_current_prepare_finalize() if hasattr(pf, "finalize_async"):
if hasattr(prepare_finalize, 'finalize_async'): return pf.finalize_async(
return prepare_finalize.finalize_async(
output, fused_expert_output, topk_weights, topk_ids, output, fused_expert_output, topk_weights, topk_ids,
apply_router_weight_on_input, apply_weights_and_reduce) apply_router_weight_on_input, apply_weights_and_reduce)
else: return pf.finalize(
# Fallback to synchronous finalize output, fused_expert_output, topk_weights, topk_ids,
return prepare_finalize.finalize( apply_router_weight_on_input, apply_weights_and_reduce)
output, fused_expert_output, topk_weights, topk_ids,
apply_router_weight_on_input, apply_weights_and_reduce)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional from typing import Optional
from collections.abc import Callable
import deep_ep import deep_ep
import torch import torch
...@@ -58,39 +59,49 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -58,39 +59,49 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return None return None
return deep_ep.Buffer.get_combine_config(self.dp_size) return deep_ep.Buffer.get_combine_config(self.dp_size)
def sync(self): def _do_dispatch(
# torch.cuda.synchronize() self,
dist.barrier() tokens: torch.Tensor,
token_scales: torch.Tensor | None,
def _do_dispatch(self, tokens: torch.Tensor, rank_topk_ids: torch.Tensor,
token_scales: Optional[torch.Tensor], rank_topk_weights: torch.Tensor,
rank_topk_ids: torch.Tensor, num_experts: int,
rank_topk_weights: torch.Tensor, num_experts: int): quant_config: FusedMoEQuantConfig,
) -> Callable:
has_scales = token_scales is not None has_scales = token_scales is not None
(num_tokens_per_rank, num_tokens_per_rdma_rank, expert_num_tokens, (
is_token_in_rank, event) = self.buffer.get_dispatch_layout( num_tokens_per_rank,
topk_idx=rank_topk_ids, num_tokens_per_rdma_rank,
num_experts=num_experts, dispatch_expert_num_tokens,
previous_event=None, is_token_in_rank,
async_finish=False, event,
allocate_on_comm_stream=False) ) = self.buffer.get_dispatch_layout(
topk_idx=rank_topk_ids,
num_experts=num_experts,
previous_event=None,
async_finish=False,
allocate_on_comm_stream=False,
)
token_data = tokens token_data = tokens
if has_scales: if has_scales:
token_data = (tokens, token_scales) token_data = (tokens, token_scales)
( (
token_data, expert_topk_ids, expert_topk_weights, token_data,
expert_num_tokens_per_expert_list, self.handle, event expert_topk_ids,
expert_topk_weights,
expert_num_tokens_per_expert_list,
self.handle,
event,
) = self.buffer.dispatch( ) = self.buffer.dispatch(
x=token_data, x=token_data,
handle=None, handle=None,
num_tokens_per_rank=num_tokens_per_rank, num_tokens_per_rank=num_tokens_per_rank,
num_tokens_per_rdma_rank=num_tokens_per_rdma_rank, num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
is_token_in_rank=is_token_in_rank, is_token_in_rank=is_token_in_rank,
num_tokens_per_expert=expert_num_tokens, num_tokens_per_expert=dispatch_expert_num_tokens,
topk_idx=rank_topk_ids, topk_idx=rank_topk_ids,
topk_weights=rank_topk_weights, topk_weights=rank_topk_weights,
# expert_alignment rounds the number of tokens per expert # expert_alignment rounds the number of tokens per expert
...@@ -98,8 +109,36 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -98,8 +109,36 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_alignment=1, expert_alignment=1,
config=self._get_dispatch_config(), config=self._get_dispatch_config(),
previous_event=None, previous_event=None,
async_finish=False, async_finish=True,
allocate_on_comm_stream=False) allocate_on_comm_stream=False,
)
return lambda: self._receiver(
event,
has_scales,
token_data,
expert_topk_ids,
num_experts,
expert_num_tokens_per_expert_list,
expert_topk_weights,
token_scales,
quant_config,
)
def _receiver(
self,
event: deep_ep.EventOverlap,
has_scales: bool,
token_data: tuple[torch.Tensor, torch.Tensor] | torch.Tensor,
expert_topk_ids: torch.Tensor | None,
num_experts: int,
expert_num_tokens_per_expert_list: list[int],
expert_topk_weights: torch.Tensor | None,
a1_scale: torch.Tensor | None,
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
if event.event is not None:
event.current_stream_wait()
if has_scales: if has_scales:
expert_x, expert_x_scale = token_data expert_x, expert_x_scale = token_data
...@@ -117,15 +156,45 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -117,15 +156,45 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
# DeepEP's topk_ids output refers to the local experts directly. Offset # DeepEP's topk_ids output refers to the local experts directly. Offset
# the topk_ids to move it back to the global experts space so it aligns # the topk_ids to move it back to the global experts space so it aligns
# with existing vLLM interfaces. # with existing vLLM interfaces.
assert expert_topk_ids is not None
expert_topk_ids = torch.where( expert_topk_ids = torch.where(
expert_topk_ids == -1, expert_topk_ids == -1,
num_experts - 1 if self.rank_expert_offset == 0 else 0, num_experts - 1 if self.rank_expert_offset == 0 else 0,
expert_topk_ids + self.rank_expert_offset) expert_topk_ids + self.rank_expert_offset,
)
return (expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids,
expert_topk_weights) # Makes a GPU-CPU copy.
# TODO (varun): Maybe it is better to re-compute the expert_num_tokens
def prepare( # on GPU.
expert_tokens_meta = mk.ExpertTokensMetadata.make_from_list(
expert_num_tokens_per_expert_list, device=expert_x.device
)
# Dispatch and Quant
# DeepEP kernels only support dispatching block-quantized
# activation scales.
# Dispatch in bfloat16 and quantize afterwards
if not quant_config.per_act_token_quant:
# Quantize after dispatch.
expert_x_scale = None
if expert_x.numel() != 0:
expert_x, expert_x_scale = moe_kernel_quantize_input(
expert_x,
a1_scale,
quant_dtype=quant_config.quant_dtype,
per_act_token_quant=False,
block_shape=quant_config.block_shape,
)
return (
expert_x,
expert_x_scale,
expert_tokens_meta,
expert_topk_ids,
expert_topk_weights,
)
def prepare_async(
self, self,
a1: torch.Tensor, a1: torch.Tensor,
a1_scale: Optional[torch.Tensor], a1_scale: Optional[torch.Tensor],
...@@ -136,14 +205,13 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -136,14 +205,13 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: Optional[torch.Tensor], expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], ) -> mk.ReceiverType:
Optional[torch.Tensor], Optional[torch.Tensor]]:
if apply_router_weight_on_input: if apply_router_weight_on_input:
topk = topk_ids.size(1) topk = topk_ids.size(1)
# TODO: this only works for topK=1, will need to update for topK>1 # TODO: this only works for topK=1, will need to update for topK>1
assert topk == 1, ( assert topk == 1, (
"apply_router_weight_on_input is only implemented for topk=1") "apply_router_weight_on_input is only implemented for topk=1"
)
a1 = a1 * topk_weights.to(a1.dtype) a1 = a1 * topk_weights.to(a1.dtype)
if quant_config.per_act_token_quant: if quant_config.per_act_token_quant:
...@@ -156,35 +224,43 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -156,35 +224,43 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
) )
if a1q_scale is not None and a1q_scale.numel() == 1: if a1q_scale is not None and a1q_scale.numel() == 1:
a1q_scale = a1q_scale.view(1, 1) a1q_scale = a1q_scale.view(1, 1)
(expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids,
expert_topk_weights) = self._do_dispatch(
tokens=a1q,
token_scales=a1q_scale,
rank_topk_ids=topk_ids,
rank_topk_weights=topk_weights,
num_experts=num_experts)
else: else:
# DeepEP kernels only support dispatching per-token-quant a1q = a1
# quantization. dispatch in bfloat16. a1q_scale = None
(expert_x, _, expert_num_tokens, expert_topk_ids,
expert_topk_weights) = self._do_dispatch( return self._do_dispatch(
tokens=a1, tokens=a1q,
token_scales=None, token_scales=a1q_scale,
rank_topk_ids=topk_ids, rank_topk_ids=topk_ids,
rank_topk_weights=topk_weights, rank_topk_weights=topk_weights,
num_experts=num_experts) num_experts=num_experts,
# quantize now quant_config=quant_config,
expert_x_scale = None )
if expert_x.numel() != 0:
expert_x, expert_x_scale = moe_kernel_quantize_input( def prepare(
expert_x, self,
a1_scale, a1: torch.Tensor,
quant_dtype=quant_config.quant_dtype, a1_scale: Optional[torch.Tensor],
per_act_token_quant=False, a2_scale: Optional[torch.Tensor],
block_shape=quant_config.block_shape) topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
return (expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids, num_experts: int,
expert_topk_weights) expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
receiver = self.prepare_async(
a1,
a1_scale,
a2_scale,
topk_weights,
topk_ids,
num_experts,
expert_map,
apply_router_weight_on_input,
quant_config,
)
return receiver()
def _apply_weights_and_reduce(self, num_tokens: int, def _apply_weights_and_reduce(self, num_tokens: int,
fused_expert_output: torch.Tensor, fused_expert_output: torch.Tensor,
...@@ -210,31 +286,88 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -210,31 +286,88 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return out return out
def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, def _finalize(
topk_weights: torch.Tensor, topk_ids: torch.Tensor, self,
apply_router_weight_on_input: bool, output: torch.Tensor,
apply_weights_and_reduce: bool = True) -> None: fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
do_async: bool,
apply_weights_and_reduce: bool = True,
) -> Callable | None:
assert self.handle is not None assert self.handle is not None
# fused_expert_output can have 0 tokens - This happens when none of the # fused_expert_output can have 0 tokens - This happens when none of the
# tokens from the all2all reach this EP rank. # tokens from the all2all reach this EP rank.
if fused_expert_output.numel() != 0 and apply_weights_and_reduce: # if fused_expert_output.numel() != 0 and apply_weights_and_reduce:
fused_expert_output = self._apply_weights_and_reduce( # fused_expert_output = self._apply_weights_and_reduce(
num_tokens=topk_ids.size(0), # num_tokens=topk_ids.size(0),
fused_expert_output=fused_expert_output, # fused_expert_output=fused_expert_output,
topk_weights=topk_weights, # topk_weights=topk_weights,
apply_router_weight_on_input=apply_router_weight_on_input, # apply_router_weight_on_input=apply_router_weight_on_input,
output_dtype=output.dtype) # output_dtype=output.dtype)
combined_x, _, event = self.buffer.combine( combined_x, _, event = self.buffer.combine(
# HT combine only supports BF16
x=fused_expert_output, x=fused_expert_output,
handle=self.handle, handle=self.handle,
topk_weights=None, topk_weights=None,
config=self._get_combine_config(), config=self._get_combine_config(),
previous_event=None, previous_event=None,
async_finish=False, async_finish=do_async,
allocate_on_comm_stream=False) allocate_on_comm_stream=False,
)
if do_async:
def _receiver():
if event.event is not None:
event.current_stream_wait()
# Respect inplace outputs.
output.copy_(combined_x, non_blocking=True)
# Respect inplace outputs. return _receiver
output.copy_(combined_x, non_blocking=True) else:
# Respect inplace outputs.
output.copy_(combined_x, non_blocking=True)
return None
def finalize_async(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
apply_weights_and_reduce: bool = True,
) -> Callable:
receiver = self._finalize(
output,
fused_expert_output,
topk_weights,
topk_ids,
apply_router_weight_on_input,
do_async=True,
apply_weights_and_reduce=apply_weights_and_reduce,
)
assert receiver is not None
return receiver
def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
apply_weights_and_reduce: bool = True,
) -> None:
self._finalize(
output,
fused_expert_output,
topk_weights,
topk_ids,
apply_router_weight_on_input,
do_async=False,
apply_weights_and_reduce=apply_weights_and_reduce,
)
...@@ -114,8 +114,8 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -114,8 +114,8 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
x_scales = normalize_batched_scales_shape(x_scales, num_experts) x_scales = normalize_batched_scales_shape(x_scales, num_experts)
return x, x_scales return x, x_scales
def prepare( def prepare_async(
self, self,
a1: torch.Tensor, a1: torch.Tensor,
a1_scale: Optional[torch.Tensor], a1_scale: Optional[torch.Tensor],
...@@ -126,9 +126,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -126,9 +126,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: Optional[torch.Tensor], expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], ) -> tuple[Callable, mk.ReceiverType]:
Optional[torch.Tensor], Optional[torch.Tensor]]:
hidden_size = a1.size(1) hidden_size = a1.size(1)
assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, \ assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, \
(f"Hidden Size {hidden_size} not in supported list of hidden sizes" (f"Hidden Size {hidden_size} not in supported list of hidden sizes"
...@@ -148,25 +146,74 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -148,25 +146,74 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk = topk_ids.size(1) topk = topk_ids.size(1)
# TODO: this only works for topK=1, will need to update for topK>1 # TODO: this only works for topK=1, will need to update for topK>1
assert topk == 1, ( assert topk == 1, (
"apply_router_weight_on_input is only implemented for topk=1") "apply_router_weight_on_input is only implemented for topk=1"
)
a1 = a1 * topk_weights.to(a1.dtype) a1 = a1 * topk_weights.to(a1.dtype)
# Dispatch # Dispatch
expert_x, expert_num_tokens, self.handle, event, hook = \ expert_x, expert_num_tokens, self.handles, _, hook = self.buffer.low_latency_dispatch(
self.buffer.low_latency_dispatch(a1, a1,
topk_ids, topk_ids,
self.max_tokens_per_rank, self.max_tokens_per_rank,
num_experts, num_experts,
use_fp8=self.use_fp8_dispatch or self.use_int8_dispatch, use_fp8=self.use_fp8_dispatch or self.use_int8_dispatch,
use_int8=self.use_int8_dispatch, use_int8=self.use_int8_dispatch,
async_finish=False, async_finish=False,
return_recv_hook=False) return_recv_hook=True,
)
expert_x, expert_x_scale = self._do_quant(
expert_x, a1_scale, a2_scale, a1.dtype, quant_config.quant_dtype, return (
quant_config.per_act_token_quant, quant_config.block_shape, expert_num_tokens) hook,
lambda: self._receiver(
return (expert_x, expert_x_scale, expert_num_tokens, None, None) expert_x,
expert_num_tokens,
a1_scale,
a1.dtype,
quant_config,
),
)
def _receiver(
self,
expert_x: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
expert_num_tokens: torch.Tensor,
a1_scale: torch.Tensor | None,
a1_dtype: torch.dtype,
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
expert_x, expert_x_scale = self._do_quant(expert_x, a1_dtype, quant_config)
expert_tokens_meta = mk.ExpertTokensMetadata(
expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None
)
return expert_x, expert_x_scale, expert_tokens_meta, None, None
def prepare(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
hook, receiver = self.prepare_async(
a1,
a1_scale,
a2_scale,
topk_weights,
topk_ids,
num_experts,
expert_map,
apply_router_weight_on_input,
quant_config,
)
hook()
return receiver()
def _finalize( def _finalize(
self, self,
......
...@@ -4,13 +4,15 @@ from abc import ABC, abstractmethod ...@@ -4,13 +4,15 @@ from abc import ABC, abstractmethod
from enum import Enum from enum import Enum
from math import prod from math import prod
from typing import Optional, final from typing import Optional, final
from dataclasses import dataclass
from collections.abc import Callable
import torch import torch
import vllm.envs as envs import vllm.envs as envs
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.utils import _resize_cache from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.utils import cdiv from vllm.utils import cdiv, async_tensor_h2d
# #
# This file defines a set of base classes used to make MoE kernels more modular. # This file defines a set of base classes used to make MoE kernels more modular.
...@@ -95,6 +97,57 @@ class FusedMoEActivationFormat(Enum): ...@@ -95,6 +97,57 @@ class FusedMoEActivationFormat(Enum):
BatchedExperts = "batched_experts", BatchedExperts = "batched_experts",
@dataclass
class ExpertTokensMetadata:
"""
Metadata regarding expert-token routing.
"""
expert_num_tokens: torch.Tensor
expert_num_tokens_cpu: torch.Tensor | None
@staticmethod
def make_from_list(
expert_num_tokens_list: list[int], device: str
) -> "ExpertTokensMetadata":
# expert_num_tokens_cpu = torch.tensor(
# expert_num_tokens_list, device="cpu", dtype=torch.int32
# )
expert_num_tokens_cpu = torch.tensor(
expert_num_tokens_list, device="cpu", dtype=torch.int32, pin_memory=True
)
expert_num_tokens = expert_num_tokens_cpu.to(device=device, non_blocking=True)
return ExpertTokensMetadata(
expert_num_tokens=expert_num_tokens,
expert_num_tokens_cpu=expert_num_tokens_cpu,
)
#
# PrepareResultType is a tuple of:
# - quantized + dispatched a.
# - quantized + dispatched a1_scales.
# - Optional ExpertTokensMetadata containing gpu/cpu tensors
# as big as the number of local experts with the information about the
# number of tokens assigned to each local expert.
# - Optional dispatched expert topk IDs
# - Optional dispatched expert topk weight
#
# See `prepare` method below.
#
PrepareResultType = tuple[
torch.Tensor,
torch.Tensor | None,
ExpertTokensMetadata | None,
torch.Tensor | None,
torch.Tensor | None,
]
ReceiverType = Callable[[], PrepareResultType]
# TODO: pass FusedMoEParallelConfig in as ctor parameter? # TODO: pass FusedMoEParallelConfig in as ctor parameter?
class FusedMoEPrepareAndFinalize(ABC): class FusedMoEPrepareAndFinalize(ABC):
""" """
...@@ -880,62 +933,93 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module): ...@@ -880,62 +933,93 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
if global_num_experts == -1: if global_num_experts == -1:
global_num_experts = local_num_experts global_num_experts = local_num_experts
(a1q, a1q_scale, expert_num_tokens, _expert_topk_ids, # (a1q, a1q_scale, expert_num_tokens, _expert_topk_ids,
_expert_topk_weights) = self.prepare_finalize.prepare( # _expert_topk_weights) = self.prepare_finalize.prepare(
a1, # a1,
a1_scale, # a1_scale,
a2_scale, # a2_scale,
topk_weights, # topk_weights,
topk_ids, # topk_ids,
global_num_experts, # global_num_experts,
expert_map, # expert_map,
apply_router_weight_on_input, # apply_router_weight_on_input,
self.fused_experts.quant_config, # self.fused_experts.quant_config,
) # )
prepare_ret = self.prepare_finalize.prepare_async(
a1,
a1_scale,
a2_scale,
topk_weights,
topk_ids,
global_num_experts,
expert_map,
apply_router_weight_on_input,
self.fused_experts.quant_config,
)
hook, receiver = (
prepare_ret if isinstance(prepare_ret, tuple) else (None, prepare_ret)
)
if hook is not None:
hook()
(
a1q,
a1q_scale,
expert_tokens_meta,
_expert_topk_ids,
_expert_topk_weights,
) = receiver()
# Maybe prepare gathered topk_ids and topk_weights from other EP ranks. # Maybe prepare gathered topk_ids and topk_weights from other EP ranks.
topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids
topk_weights = (topk_weights if _expert_topk_weights is None else topk_weights = (topk_weights if _expert_topk_weights is None else
_expert_topk_weights) _expert_topk_weights)
fused_out = self.fused_experts.apply( if a1q.numel() == 0:
None, # This happens when none of the tokens from the all2all reach this
a1, # EP rank. Also, note that this is only relevant for CUDAGraph
a1q, # incompatible all2all kernels like the DeepEP high-throughput
w1, # kernels. CUDAGraph compatible all2all kernels like the pplx
w2, # kernels and the DeepEP low-latency kernels are always batched
topk_ids, # and can never run into the tensor.numel() == 0 case.
topk_weights=topk_weights, fused_out = torch.empty_like(a1q).to(dtype=a1.dtype)
activation=activation, else:
global_num_experts=global_num_experts, fused_out = self.fused_experts.apply(
expert_map=expert_map, None,
w1_scale=w1_scale, a1,
w2_scale=w2_scale, a1q,
w1_zp=w1_zp, w1,
w2_zp=w2_zp, w2,
a1q_scale=a1q_scale, topk_ids,
a2_scale=a2_scale, topk_weights=topk_weights,
workspace13=None, activation=activation,
workspace2=None, global_num_experts=global_num_experts,
use_nn_moe=use_nn_moe, expert_map=expert_map,
expert_num_tokens=expert_num_tokens, w1_scale=w1_scale,
shared_output=shared_output, w2_scale=w2_scale,
routed_scaling_factor=routed_scaling_factor, w1_zp=w1_zp,
) w2_zp=w2_zp,
a1q_scale=a1q_scale,
a2_scale=a2_scale,
workspace13=None,
workspace2=None,
use_nn_moe=use_nn_moe,
expert_num_tokens=expert_tokens_meta.expert_num_tokens,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
expert_num_tokens_cpu=expert_tokens_meta.expert_num_tokens_cpu
)
shared_output = None shared_output = None
if self.shared_experts is None: hook = self.prepare_finalize.finalize_async(output, fused_out, topk_weights,
self.prepare_finalize.finalize(output, fused_out, topk_weights, topk_ids, apply_router_weight_on_input, apply_weights_and_reduce=True)
topk_ids, apply_router_weight_on_input, apply_weights_and_reduce=False)
else:
hook = self.prepare_finalize.finalize_async(output, fused_out, topk_weights,
topk_ids, apply_router_weight_on_input, apply_weights_and_reduce=False)
if self.shared_experts is not None: if self.shared_experts is not None:
shared_output = self.shared_experts(hidden_states) shared_output = self.shared_experts(hidden_states)
if hook is not None: if hook is not None:
hook() hook()
if self.shared_experts is not None: if self.shared_experts is not None:
return (shared_output, output) return (shared_output, output)
......
...@@ -85,6 +85,7 @@ class TritonOrGroupGemmExperts(mk.CustomizedFusedMoEPermuteExpertsUnpermute): ...@@ -85,6 +85,7 @@ class TritonOrGroupGemmExperts(mk.CustomizedFusedMoEPermuteExpertsUnpermute):
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None, shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = None,
expert_num_tokens_cpu: torch.Tensor = None,
): ):
assert self.fused_experts is not None assert self.fused_experts is not None
...@@ -107,4 +108,5 @@ class TritonOrGroupGemmExperts(mk.CustomizedFusedMoEPermuteExpertsUnpermute): ...@@ -107,4 +108,5 @@ class TritonOrGroupGemmExperts(mk.CustomizedFusedMoEPermuteExpertsUnpermute):
shared_output=shared_output, shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor, routed_scaling_factor=routed_scaling_factor,
q_x=q_hidden_states, q_x=q_hidden_states,
expert_num_tokens_cpu=expert_num_tokens_cpu
) )
...@@ -11,6 +11,7 @@ from triton.language.extra import libdevice ...@@ -11,6 +11,7 @@ from triton.language.extra import libdevice
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8) per_token_group_quant_fp8)
from vllm.utils import round_up
try: try:
from lmslim.layers.gemm.int8_utils import ( from lmslim.layers.gemm.int8_utils import (
per_token_group_quant_int8, per_token_quant_int8) per_token_group_quant_int8, per_token_quant_int8)
...@@ -276,8 +277,8 @@ def _int8_quantize( ...@@ -276,8 +277,8 @@ def _int8_quantize(
# activations apply per-token quantization. Otherwise, assume # activations apply per-token quantization. Otherwise, assume
# activation tensor-wise fp8/int8 quantization, dynamic or static # activation tensor-wise fp8/int8 quantization, dynamic or static
if block_shape is None: if block_shape is None:
assert per_act_token, \ # assert per_act_token, \
"int8 quantization only supports block or channel-wise" # "int8 quantization only supports block or channel-wise"
if expert_num_tokens is None: if expert_num_tokens is None:
A, A_scale = per_token_quant_int8(A) A, A_scale = per_token_quant_int8(A)
else: else:
...@@ -361,3 +362,502 @@ def _validate_scale_shape( ...@@ -361,3 +362,502 @@ def _validate_scale_shape(
assert block_shape is not None assert block_shape is not None
expected = (a.shape[0], cdiv(a.shape[1], block_shape[1])) expected = (a.shape[0], cdiv(a.shape[1], block_shape[1]))
assert a_scale.shape == expected, f"{a_scale.shape} == {expected}" assert a_scale.shape == expected, f"{a_scale.shape} == {expected}"
@triton.jit
def _count_expert_num_tokens(
    topk_ids_ptr,
    expert_num_tokens_ptr,
    num_experts,
    topk_numel,
    expert_map,
    HAS_EXPERT_MAP: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Counts how many entries of the flattened topk_ids buffer refer to the
    # expert handled by this program, and writes that count to
    # expert_num_tokens_ptr[curr_expert].
    #
    # Each program handles one expert id (program_id(0)) and scans all
    # `topk_numel` ids in chunks of BLOCK_SIZE.
    curr_expert = tl.program_id(0)

    offsets = tl.arange(0, BLOCK_SIZE)
    topk_ids_ptrs = topk_ids_ptr + offsets

    # Per-lane partial counts; reduced with tl.sum at the end.
    acc = tl.zeros((BLOCK_SIZE,), dtype=tl.int32)
    for x in range(tl.cdiv(topk_numel, BLOCK_SIZE)):
        # Lanes past the end of the buffer read -1, which never matches a
        # valid expert id.
        mask = offsets < (topk_numel - x * BLOCK_SIZE)
        expert_ids = tl.load(topk_ids_ptrs, mask=mask, other=-1)
        if HAS_EXPERT_MAP:
            # Translate ids through expert_map; negative ids are not looked up
            # and stay -1.
            expert_map_ptrs = expert_map + expert_ids
            expert_map_mask = expert_ids >= 0
            expert_ids = tl.load(expert_map_ptrs, mask=expert_map_mask, other=-1)

        has_curr_expert = tl.where(expert_ids == curr_expert, 1, 0)
        acc = acc + has_curr_expert
        topk_ids_ptrs += BLOCK_SIZE

    # Guard the store in case the grid is larger than num_experts.
    if curr_expert < num_experts:
        tl.store(expert_num_tokens_ptr + curr_expert, tl.sum(acc))
def count_expert_num_tokens(
topk_ids: torch.Tensor, num_local_experts: int, expert_map: torch.Tensor | None
) -> torch.Tensor:
"""
Count the number to tokens assigned to each expert.
Parameters:
- topk_ids (torch.Tensor): Tensor mapping each token to its
list of experts.
- num_local_experts (int): Number of experts in this rank.
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
from the global expert space to the local expert space of the expert
parallel shard.
Returns:
A tensor of size num_local_experts, where tensor[i] holds the number
of tokens assigned to the ith expert.
"""
assert topk_ids.dtype.is_signed, "The kernel uses -1 to represent invalid topk_ids"
expert_num_tokens = torch.empty(
(num_local_experts), device=topk_ids.device, dtype=torch.int32
)
grid = num_local_experts
BLOCK_SIZE = min(topk_ids.numel(), 1024)
BLOCK_SIZE = triton.next_power_of_2(BLOCK_SIZE)
_count_expert_num_tokens[(grid,)](
topk_ids,
expert_num_tokens,
num_local_experts,
topk_ids.numel(),
expert_map,
HAS_EXPERT_MAP=expert_map is not None,
BLOCK_SIZE=BLOCK_SIZE,
)
return expert_num_tokens
def expert_num_tokens_round_up_and_sum(
    expert_num_tokens: torch.Tensor, alignment: int
) -> int:
    """Return the sum of per-expert token counts after padding each count
    up to the nearest multiple of `alignment`."""
    # Work in int64 so the padded total cannot overflow a narrower dtype.
    counts = expert_num_tokens.to(torch.int64)
    padded = (counts + (alignment - 1)) // alignment * alignment
    total = torch.sum(padded)
    return total.item()
def compute_aligned_M(
    M: int,
    num_topk: int,
    local_num_experts: int,
    alignment: int,
    expert_num_tokens_cpu: Optional[torch.Tensor] = None,
):
    """Return the number of rows needed to hold every expert's tokens when
    each expert's row range is padded to a multiple of `alignment`.

    When per-expert counts are supplied the exact padded total is returned;
    otherwise an upper bound derived from M, num_topk and the number of local
    experts is returned.
    """
    if expert_num_tokens_cpu is None:
        # Per-expert counts are unknown here, so size for the worst case:
        # all M * num_topk assignments land locally and every expert wastes
        # (alignment - 1) padding rows.
        worst_case_rows = (M * num_topk) + local_num_experts * (alignment - 1)
        return round_up(worst_case_rows, alignment)

    return expert_num_tokens_round_up_and_sum(
        expert_num_tokens_cpu, alignment=alignment
    )
@triton.jit
def apply_expert_map(expert_id, expert_map):
    # Translate `expert_id` through the `expert_map` lookup table, keeping the
    # id's original dtype. An id of -1 (invalid) is returned unchanged and is
    # never used as an index.
    if expert_id != -1:
        expert_id = tl.load(expert_map + expert_id).to(expert_id.dtype)
    return expert_id
@triton.jit
def round_up_256(x: int) -> int:
    # Round x up to the next multiple of 256.
    alignment = 256
    num_blocks = (x + alignment - 1) // alignment
    return num_blocks * alignment
@triton.jit
def round_up_128(x: int) -> int:
    # Round x up to the next multiple of 128.
    alignment = 128
    num_blocks = (x + alignment - 1) // alignment
    return num_blocks * alignment
@triton.jit
def _fwd_kernel_ep_scatter_1(
    num_recv_tokens_per_expert,
    expert_start_loc,
    m_indices,
    num_experts: tl.constexpr,
    BLOCK_E: tl.constexpr,
    BLOCK_EXPERT_NUM: tl.constexpr,
):
    # First scatter pass. Computes the starting row of every expert in the
    # permuted layout (each expert's row range is padded to a multiple of 256)
    # and labels this program's expert rows in m_indices with the expert id.
    #
    # program_id(0) is the expert handled by this program.
    cur_expert = tl.program_id(0)

    # Load all per-expert token counts; lanes beyond num_experts read 0.
    offset_cumsum = tl.arange(0, BLOCK_EXPERT_NUM)
    tokens_per_expert = tl.load(
        num_recv_tokens_per_expert + offset_cumsum,
        mask=offset_cumsum < num_experts,
        other=0,
    )
    # Pad each expert's count to 256 rows (a 128-row variant exists as
    # round_up_128 but is not used here).
    tokens_per_expert = round_up_256(tokens_per_expert)
    # Exclusive prefix sum: start row of each expert.
    cumsum = tl.cumsum(tokens_per_expert) - tokens_per_expert
    # Every program stores the full start-offset table (a guard restricting
    # this to program 0 is disabled); all programs compute the same values.
    tl.store(expert_start_loc + offset_cumsum, cumsum, mask=offset_cumsum < num_experts)
    tl.debug_barrier()

    # Read this expert's start row back from memory rather than indexing the
    # cumsum tensor directly.
    cur_expert_start = tl.load(expert_start_loc + cur_expert)
    cur_expert_token_num = tl.load(num_recv_tokens_per_expert + cur_expert)

    m_indices_start_ptr = m_indices + cur_expert_start
    off_expert = tl.arange(0, BLOCK_E)

    # Tag only the rows actually occupied by tokens; padding rows keep
    # whatever value m_indices was initialised with by the caller.
    for start_m in tl.range(0, cur_expert_token_num, BLOCK_E, num_stages=4):
        tl.store(
            m_indices_start_ptr + start_m + off_expert,
            cur_expert,
            mask=start_m + off_expert < cur_expert_token_num
        )
@triton.jit
def _fwd_kernel_ep_scatter_2(
    total_token_num,
    expert_start_loc,
    recv_x,
    recv_x_stride0,
    recv_x_stride1,
    recv_x_scale,
    recv_x_scale_stride0,
    recv_x_scale_stride1,
    recv_topk,
    recv_topk_stride0,
    recv_topk_stride1,
    output_tensor,
    output_tensor_stride0,
    output_tensor_stride1,
    output_tensor_scale,
    output_tensor_scale_stride0,
    output_tensor_scale_stride1,
    output_index,
    output_index_stride0,
    output_index_stride1,
    topk_num: tl.constexpr,
    expert_map,
    HAS_EXPERT_MAP: tl.constexpr,
    HIDDEN_SIZE: tl.constexpr,
    HIDDEN_SIZE_PAD: tl.constexpr,
    SCALE_HIDDEN_SIZE: tl.constexpr,
    SCALE_HIDDEN_SIZE_PAD: tl.constexpr,
):
    # Second scatter pass. For every (token, top-k slot) with a valid expert,
    # reserves the next free row of that expert via an atomic increment of
    # expert_start_loc, copies the token's row and scales there, and records
    # the destination row in output_index[token, slot].
    #
    # Side effect: expert_start_loc is advanced by the atomic adds, so after
    # this kernel it no longer holds the experts' start rows.
    #
    # NOTE: recv_x_stride1, recv_topk_stride1, output_tensor_stride1 and
    # output_index_stride1 are accepted but not applied; those tensors are
    # addressed as if their last dimension had unit stride.
    start_token_id = tl.program_id(0)
    grid_num = tl.num_programs(0)

    # Column offsets are padded to a power of two; masks trim the excess.
    offset_in = tl.arange(0, HIDDEN_SIZE_PAD)
    mask = offset_in < HIDDEN_SIZE

    index_in_s = tl.arange(0, SCALE_HIDDEN_SIZE_PAD)
    mask_s = index_in_s < SCALE_HIDDEN_SIZE

    # Programs stride over tokens with a step equal to the grid size.
    for token_id_int32 in range(start_token_id, total_token_num, grid_num):
        # 64-bit index to avoid overflow in the row-offset multiplications.
        token_id = token_id_int32.to(tl.int64)
        # The source row and its scales are loaded once and reused for every
        # top-k slot of this token.
        to_copy = tl.load(recv_x + token_id * recv_x_stride0 + offset_in, mask=mask)
        to_copy_s = tl.load(
            recv_x_scale
            + token_id * recv_x_scale_stride0
            + index_in_s * recv_x_scale_stride1,
            mask=mask_s,
        )

        for topk_idx_int32 in tl.range(0, topk_num, 1, num_stages=4):
            topk_index = topk_idx_int32.to(tl.int64)
            expert_id = tl.load(recv_topk + token_id * recv_topk_stride0 + topk_index)

            if HAS_EXPERT_MAP:
                expert_id = apply_expert_map(expert_id, expert_map)

            # Negative ids are skipped; their output_index entry is left
            # unwritten.
            if expert_id >= 0:
                # atomic_add returns the previous value, i.e. the row reserved
                # for this (token, slot) pair.
                dest_token_index_int32 = tl.atomic_add(expert_start_loc + expert_id, 1)
                dest_token_index = dest_token_index_int32.to(tl.int64)

                tl.store(
                    output_index + token_id * output_index_stride0 + topk_index,
                    dest_token_index_int32,
                )
                output_tensor_ptr = (
                    output_tensor + dest_token_index * output_tensor_stride0
                )
                output_tensor_scale_ptr = (
                    output_tensor_scale + dest_token_index * output_tensor_scale_stride0
                )
                tl.store(output_tensor_ptr + offset_in, to_copy, mask=mask)
                tl.store(
                    output_tensor_scale_ptr + index_in_s * output_tensor_scale_stride1,
                    to_copy_s,
                    mask=mask_s,
                )
@torch.no_grad()
def ep_scatter(
    recv_x: torch.Tensor,
    recv_x_scale: torch.Tensor,
    recv_topk: torch.Tensor,
    num_recv_tokens_per_expert: torch.Tensor,
    expert_map: torch.Tensor | None,
    expert_start_loc: torch.Tensor,
    output_tensor: torch.Tensor,
    output_tensor_scale: torch.Tensor,
    m_indices: torch.Tensor,
    output_index: torch.Tensor,
):
    """Scatter token rows (and their scales) into an expert-grouped layout.

    Launches two kernels: the first fills `expert_start_loc` with each
    expert's start row and tags occupied rows of `m_indices` with the expert
    id; the second copies every (token, top-k slot) row of `recv_x` /
    `recv_x_scale` into `output_tensor` / `output_tensor_scale` and records
    the destination row in `output_index`.

    All outputs are written in place; nothing is returned. `m_indices` must
    have a length that is a multiple of 256.
    """
    # Each expert's row range is aligned to 256 rows.
    rows_per_block = 256
    warps = 8

    expert_count = num_recv_tokens_per_expert.shape[0]
    hidden_dim = recv_x.shape[1]
    scale_dim = recv_x_scale.shape[-1]

    assert m_indices.shape[0] % rows_per_block == 0

    # Pass 1: one program per expert.
    _fwd_kernel_ep_scatter_1[(expert_count,)](
        num_recv_tokens_per_expert,
        expert_start_loc,
        m_indices,
        num_experts=expert_count,
        num_warps=warps,
        BLOCK_E=rows_per_block,
        BLOCK_EXPERT_NUM=triton.next_power_of_2(expert_count),
    )

    # Pass 2: at most 8192 programs, each striding over the tokens.
    scatter_grid = (min(recv_topk.shape[0], 1024 * 8),)
    _fwd_kernel_ep_scatter_2[scatter_grid](
        recv_topk.shape[0],
        expert_start_loc,
        recv_x, recv_x.stride(0), recv_x.stride(1),
        recv_x_scale, recv_x_scale.stride(0), recv_x_scale.stride(1),
        recv_topk, recv_topk.stride(0), recv_topk.stride(1),
        output_tensor, output_tensor.stride(0), output_tensor.stride(1),
        output_tensor_scale,
        output_tensor_scale.stride(0),
        output_tensor_scale.stride(1),
        output_index, output_index.stride(0), output_index.stride(1),
        topk_num=recv_topk.shape[1],
        expert_map=expert_map,
        HAS_EXPERT_MAP=expert_map is not None,
        num_warps=warps,
        HIDDEN_SIZE=hidden_dim,
        HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_dim),
        SCALE_HIDDEN_SIZE=scale_dim,
        SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(scale_dim),
    )
@triton.jit
def _fwd_kernel_ep_gather(
    total_token_num,
    input_tensor,
    input_tensor_stride0,
    input_tensor_stride1,
    recv_topk_ids,
    recv_topk_ids_stride0,
    recv_topk_ids_stride1,
    recv_topk_weight,
    recv_topk_weight_stride0,
    recv_topk_weight_stride1,
    input_index,
    input_index_stride0,
    input_index_stride1,
    output_tensor,
    output_tensor_stride0,
    output_tensor_stride1,
    topk_num: tl.constexpr,
    expert_map,
    HAS_EXPERT_MAP: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    # Weighted gather: for each token, sums over its top-k slots the row of
    # input_tensor selected by input_index[token, slot], scaled by
    # recv_topk_weight[token, slot], and writes the sum to output_tensor.
    #
    # Grid axis 0 selects a BLOCK_D-wide column block; grid axis 1 strides
    # over tokens. Loads and stores of the column block are unmasked, so the
    # row width must be a multiple of BLOCK_D.
    #
    # NOTE: the *_stride1 arguments are accepted but not applied; last
    # dimensions are addressed as if they had unit stride.
    cur_block_int32 = tl.program_id(0)
    cur_block = cur_block_int32.to(tl.int64)

    start_cur_token_int32 = tl.program_id(1)

    grid_num = tl.num_programs(1)

    for cur_token_int32 in range(start_cur_token_int32, total_token_num, grid_num):
        # 64-bit indices to avoid overflow in the row-offset multiplications.
        cur_token = cur_token_int32.to(tl.int64)

        off_d = tl.arange(0, BLOCK_D)
        # Accumulate in float32 regardless of the input/output dtype.
        accumulator = tl.zeros([BLOCK_D], dtype=tl.float32)

        for topk_index_int32 in range(0, topk_num):
            topk_index = topk_index_int32.to(tl.int64)

            expert_id = tl.load(
                recv_topk_ids + cur_token * recv_topk_ids_stride0 + topk_index
            )
            if HAS_EXPERT_MAP:
                expert_id = apply_expert_map(expert_id, expert_map)

            # Slots with a negative expert id contribute nothing.
            if expert_id >= 0:
                source_token_index_int32 = tl.load(
                    input_index + cur_token * input_index_stride0 + topk_index
                )
                source_token_index = source_token_index_int32.to(tl.int64)

                acc_weight = tl.load(
                    recv_topk_weight + cur_token * recv_topk_weight_stride0 + topk_index
                )
                tmp = tl.load(
                    input_tensor
                    + source_token_index * input_tensor_stride0
                    + cur_block * BLOCK_D
                    + off_d
                )
                accumulator += tmp.to(tl.float32) * acc_weight

        # Cast the float32 sum back to the output tensor's element type.
        tl.store(
            output_tensor
            + cur_token * output_tensor_stride0
            + cur_block * BLOCK_D
            + off_d,
            accumulator.to(output_tensor.dtype.element_ty),
        )
@torch.no_grad()
def ep_gather(
    input_tensor: torch.Tensor,
    recv_topk_ids: torch.Tensor,
    recv_topk_weight: torch.Tensor,
    input_index: torch.Tensor,
    expert_map: torch.Tensor | None,
    output_tensor: torch.Tensor,
):
    """Gather expert-grouped rows back into token order with top-k weighting.

    For each token, the rows of `input_tensor` referenced by `input_index`
    are multiplied by the matching `recv_topk_weight` and summed into
    `output_tensor` (written in place; nothing is returned). The row width
    must be a multiple of the column block size, min(hidden_size, 1024).
    """
    warps = 2
    token_count = output_tensor.shape[0]
    hidden_dim = input_tensor.shape[1]

    # The kernel uses unmasked column accesses, so the block must tile the
    # row exactly.
    col_block = min(hidden_dim, 1024)
    assert hidden_dim % col_block == 0

    # Axis 0: column blocks. Axis 1: up to 1024 programs striding over tokens.
    launch_grid = (triton.cdiv(hidden_dim, col_block), min(token_count, 1024))

    _fwd_kernel_ep_gather[launch_grid](
        token_count,
        input_tensor, input_tensor.stride(0), input_tensor.stride(1),
        recv_topk_ids, recv_topk_ids.stride(0), recv_topk_ids.stride(1),
        recv_topk_weight,
        recv_topk_weight.stride(0),
        recv_topk_weight.stride(1),
        input_index, input_index.stride(0), input_index.stride(1),
        output_tensor, output_tensor.stride(0), output_tensor.stride(1),
        topk_num=recv_topk_ids.shape[1],
        expert_map=expert_map,
        HAS_EXPERT_MAP=expert_map is not None,
        num_warps=warps,
        BLOCK_D=col_block,
    )
def deepgemm_moe_permute(
    aq: torch.Tensor,
    aq_scale: torch.Tensor,
    topk_ids: torch.Tensor,
    local_num_experts: int,
    expert_map: torch.Tensor | None,
    block_shape: list[int],
    expert_num_tokens: Optional[torch.Tensor] = None,
    expert_num_tokens_cpu: Optional[torch.Tensor] = None,
    aq_out: torch.Tensor | None = None,
    M_sum: int | None = None,
):
    """Permute quantized activations into an expert-grouped, padded layout.

    Parameters:
    - aq: 2-D quantized activations, one row per token.
    - aq_scale: scales for `aq`; its last dimension sets the scale width of
      the output.
    - topk_ids: signed-dtype expert ids per token (-1 marks invalid entries).
    - local_num_experts: number of experts on this rank.
    - expert_map: optional lookup table applied to the expert ids.
    - block_shape: block_shape[0] is the row alignment used to size the
      output when `M_sum` is not given.
    - expert_num_tokens: optional per-expert token counts on device; counted
      from `topk_ids` when omitted.
    - expert_num_tokens_cpu: optional per-expert counts used to size the
      output exactly.
    - aq_out: optional preallocated (M_sum, H) output buffer.
    - M_sum: optional precomputed number of output rows.

    Returns:
    (aq_out, aq_scale_out, expert_ids, inv_perm), where expert_ids holds the
    expert id of every occupied output row (-1 for padding rows) and
    inv_perm has topk_ids' shape and receives the output row index of each
    valid (token, slot) pair.

    NOTE: ep_scatter pads every expert's row range to a multiple of 256, so
    the row count derived from block_shape[0] (or passed as M_sum) has to be
    consistent with that alignment.
    """
    assert aq.ndim == 2
    assert topk_ids.dtype.is_signed, "The kernel uses -1 to represent invalid topk_ids"

    hidden = aq.size(1)
    dev = aq.device
    row_alignment = block_shape[0]

    if M_sum is None:
        M_sum = compute_aligned_M(
            M=topk_ids.size(0),
            num_topk=topk_ids.size(1),
            local_num_experts=local_num_experts,
            alignment=row_alignment,
            expert_num_tokens_cpu=expert_num_tokens_cpu,
        )

    # Filled by ep_scatter with each expert's start row.
    expert_start_loc = torch.empty(
        (local_num_experts), device=dev, dtype=torch.int32
    )

    assert aq_out is None or aq_out.shape == (M_sum, hidden)
    if aq_out is None:
        aq_out = torch.empty((M_sum, hidden), device=dev, dtype=aq.dtype)

    aq_scale_out = torch.empty(
        (M_sum, aq_scale.shape[-1]), device=dev, dtype=torch.float32
    )

    # Start from -1 everywhere so that padding rows, which the scatter never
    # touches, are distinguishable from rows owned by an expert.
    expert_ids = torch.full((M_sum,), -1, dtype=torch.int32, device=dev)
    inv_perm = torch.empty(topk_ids.shape, device=dev, dtype=torch.int32)

    if expert_num_tokens is None:
        expert_num_tokens = count_expert_num_tokens(
            topk_ids, local_num_experts, expert_map
        )

    ep_scatter(
        recv_x=aq,
        recv_x_scale=aq_scale,
        recv_topk=topk_ids,
        num_recv_tokens_per_expert=expert_num_tokens,
        expert_start_loc=expert_start_loc,
        expert_map=expert_map,
        output_tensor=aq_out,
        output_tensor_scale=aq_scale_out,
        m_indices=expert_ids,
        output_index=inv_perm,
    )

    return aq_out, aq_scale_out, expert_ids, inv_perm
def deepgemm_unpermute_and_reduce(
    a: torch.Tensor,  # Grouped gemm output
    topk_ids: torch.Tensor,
    topk_weights: torch.Tensor,
    inv_perm: torch.Tensor,
    expert_map: torch.Tensor | None,
    output: torch.Tensor,
):
    """Undo the expert-grouped permutation and reduce over top-k slots.

    Thin wrapper around ep_gather: rows of `a` addressed by `inv_perm` are
    weighted by `topk_weights` and summed per token into `output`. Returns
    whatever ep_gather returns.
    """
    # Arguments are passed in ep_gather's positional order:
    # (input_tensor, recv_topk_ids, recv_topk_weight, input_index,
    #  expert_map, output_tensor).
    return ep_gather(a, topk_ids, topk_weights, inv_perm, expert_map, output)
...@@ -19,12 +19,15 @@ from vllm.model_executor.layers.fused_moe import ( ...@@ -19,12 +19,15 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoEConfig, FusedMoeWeightScaleSupported, FusedMoEConfig, FusedMoeWeightScaleSupported,
FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize,) FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize,)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.layers.fused_moe.utils import _resize_cache from vllm.model_executor.layers.fused_moe.utils import _resize_cache, compute_aligned_M, deepgemm_moe_permute, deepgemm_unpermute_and_reduce
from vllm.model_executor.layers.quantization.utils.w8a8_utils import( from vllm.model_executor.layers.quantization.utils.w8a8_utils import(
get_w8a8_int8_marlin_weights, w8a8_nt_kpack2_marlin_weight) get_w8a8_int8_marlin_weights, w8a8_nt_kpack2_marlin_weight)
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
_moe_permute)
from vllm.utils import round_up
try: try:
from lightop import m_grouped_w8a8_gemm_nt_masked, fuse_silu_mul_quant_ep from lightop import m_grouped_w8a8_gemm_nt_masked, m_grouped_w8a8_gemm_nt_contig_asm, fuse_silu_mul_quant_ep, fuse_silu_mul_quant
from lmslim.layers.fused_moe.fuse_moe_int8_marlin import fused_experts_impl_int8_marlin from lmslim.layers.fused_moe.fuse_moe_int8_marlin import fused_experts_impl_int8_marlin
except Exception: except Exception:
print("INFO: Please install lmslim if you want to infer the quantitative model of moe.\n") print("INFO: Please install lmslim if you want to infer the quantitative model of moe.\n")
...@@ -84,26 +87,27 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod) ...@@ -84,26 +87,27 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
self.dp_size = get_dp_group().world_size self.dp_size = get_dp_group().world_size
self.ep_size = get_ep_group().world_size self.ep_size = get_ep_group().world_size
backend = envs.VLLM_ALL2ALL_BACKEND
self.use_deepep = self.dp_size > 1 and parallel_config.enable_expert_parallel and \ self.use_deepep = self.dp_size > 1 and parallel_config.enable_expert_parallel and \
(backend == "deepep_high_throughput" or \ (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \
backend == "deepep_low_latency" or \ envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency" or \
backend == "deepep_auto") envs.VLLM_ALL2ALL_BACKEND == "deepep_auto")
self.use_deepep_ll = self.use_deepep and (backend == "deepep_low_latency" or \ #self.use_deepep_ll = self.use_deepep and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency"
(backend == "deepep_auto"))
if self.use_deepep: if self.use_deepep:
all2all_manager = get_ep_group().device_communicator.all2all_manager all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None assert all2all_manager is not None
self.num_dispatchers = all2all_manager.world_size self.num_dispatchers = all2all_manager.world_size
self.block_shape = [256, 256]
self.use_deepgemm = envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency" or envs.VLLM_ENABLE_DEEPEP_HT_DEEPGEMM or envs.VLLM_ALL2ALL_BACKEND == "deepep_auto"
def create_weights(self, layer: torch.nn.Module, num_experts: int, def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int, hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs): params_dtype: torch.dtype, **extra_weight_attrs):
if self.use_deepep_ll: if self.use_deepep:
self.N = 2 * intermediate_size_per_partition self.N = 2 * intermediate_size_per_partition
self.K = hidden_size self.K = hidden_size
...@@ -157,7 +161,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod) ...@@ -157,7 +161,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
w1_marlin_list = [] w1_marlin_list = []
for ii in range(layer.w13_weight.shape[0]): for ii in range(layer.w13_weight.shape[0]):
if not self.use_deepep_ll: if not self.use_deepgemm:
w1_marlin_in = get_w8a8_int8_marlin_weights(layer.w13_weight[ii]) w1_marlin_in = get_w8a8_int8_marlin_weights(layer.w13_weight[ii])
else: else:
w1_marlin_in = w8a8_nt_kpack2_marlin_weight(layer.w13_weight[ii]) w1_marlin_in = w8a8_nt_kpack2_marlin_weight(layer.w13_weight[ii])
...@@ -168,7 +172,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod) ...@@ -168,7 +172,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
del w1_marlin_list del w1_marlin_list
w2_marlin_list = [] w2_marlin_list = []
for ii in range(layer.w2_weight.shape[0]): for ii in range(layer.w2_weight.shape[0]):
if not self.use_deepep_ll: if not self.use_deepgemm:
w2_marlin_in = get_w8a8_int8_marlin_weights(layer.w2_weight[ii]) w2_marlin_in = get_w8a8_int8_marlin_weights(layer.w2_weight[ii])
else: else:
w2_marlin_in = w8a8_nt_kpack2_marlin_weight(layer.w2_weight[ii]) w2_marlin_in = w8a8_nt_kpack2_marlin_weight(layer.w2_weight[ii])
...@@ -178,7 +182,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod) ...@@ -178,7 +182,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
layer.w13_weight = Parameter(w1_marlin, requires_grad=False) layer.w13_weight = Parameter(w1_marlin, requires_grad=False)
layer.w2_weight = Parameter(w2_marlin, requires_grad=False) layer.w2_weight = Parameter(w2_marlin, requires_grad=False)
def groupgemm_workspace_shapes(self, def masked_groupgemm_workspace_shapes(self,
a: torch.Tensor, a: torch.Tensor,
aq: torch.Tensor, aq: torch.Tensor,
M: int, M: int,
...@@ -200,8 +204,27 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod) ...@@ -200,8 +204,27 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
workspace2 = (num_experts, max_num_tokens * num_dispatchers, (N // 2)) workspace2 = (num_experts, max_num_tokens * num_dispatchers, (N // 2))
output = (num_experts, max_num_tokens * num_dispatchers, K) output = (num_experts, max_num_tokens * num_dispatchers, K)
return (workspace13, workspace2, output, a.dtype) return (workspace13, workspace2, output, a.dtype)
def contiguous_groupgemm_workspace_shapes(
    self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int,
    topk: int, global_num_experts: int, local_num_experts: int,
    expert_num_tokens_cpu: Optional[torch.Tensor]
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype, int]:
    """Compute workspace shapes for the contiguous grouped-GEMM path.

    Parameters:
    - a: unquantized activations; only its dtype is used.
    - aq: quantized activations (unused here).
    - M: number of tokens.
    - N, K: GEMM dimensions; workspaces are sized from max(N, K) and
      max(N // 2, K).
    - topk: number of experts selected per token.
    - global_num_experts: unused here.
    - local_num_experts: number of experts on this rank; used to bound the
      padding when per-expert counts are unavailable.
    - expert_num_tokens_cpu: optional per-expert token counts used to size
      the workspaces exactly; may be None.

    Returns:
    (workspace1_shape, workspace2_shape, output_shape, dtype, M_sum), where
    M_sum is the padded row count of the expert-grouped layout and is a
    multiple of self.block_shape[0].
    """
    assert self.block_shape is not None
    block_m = self.block_shape[0]
    # Rows needed once every local expert's range is padded to block_m.
    M_sum = compute_aligned_M(
        M, topk, local_num_experts, block_m, expert_num_tokens_cpu
    )
    assert M_sum % block_m == 0

    workspace1 = (M_sum, max(N, K))
    workspace2 = (M_sum, max(N // 2, K))
    output = (M, K)
    return (workspace1, workspace2, output, a.dtype, M_sum)
def w8a8_groupgemm_forward(self, def w8a8_groupgemm_masked_forward(self,
x: torch.Tensor, x: torch.Tensor,
w1: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
...@@ -220,6 +243,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod) ...@@ -220,6 +243,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = None,
shared_output: Optional[torch.Tensor] = None, shared_output: Optional[torch.Tensor] = None,
q_x: Optional[torch.Tensor] = None, q_x: Optional[torch.Tensor] = None,
expert_num_tokens_cpu: Optional[torch.Tensor] = None,
**_ ): **_ ):
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
...@@ -230,7 +254,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod) ...@@ -230,7 +254,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
N, K = self.N, self.K N, K = self.N, self.K
(workspace13_shape, workspace2_shape, fused_out_shape, (workspace13_shape, workspace2_shape, fused_out_shape,
workspace_dtype) = self.groupgemm_workspace_shapes( workspace_dtype) = self.masked_groupgemm_workspace_shapes(
x, q_x, max_num_tokens, N, K, top_k, global_num_experts, x, q_x, max_num_tokens, N, K, top_k, global_num_experts,
local_num_experts) local_num_experts)
workspace13 = torch.empty(prod(workspace13_shape), workspace13 = torch.empty(prod(workspace13_shape),
...@@ -269,6 +293,94 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod) ...@@ -269,6 +293,94 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
return fused_out return fused_out
def w8a8_groupgemm_contiguous_forward(self,
        x: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        w1_scale: Optional[torch.Tensor] = None,
        w2_scale: Optional[torch.Tensor] = None,
        a1_scale: Optional[torch.Tensor] = None,
        a2_scale: Optional[torch.Tensor] = None,
        expert_num_tokens: Optional[torch.Tensor] = None,
        use_nn_moe: Optional[bool] = False,
        routed_scaling_factor: Optional[float] = None,
        shared_output: Optional[torch.Tensor] = None,
        q_x: Optional[torch.Tensor] = None,
        expert_num_tokens_cpu: Optional[torch.Tensor] = None,
        **_ ):
    """Run the MoE experts with the contiguous grouped W8A8 GEMM path.

    Pipeline: permute the quantized activations (`q_x` with scales
    `a1_scale`) into an expert-grouped layout, run the first grouped GEMM
    against `w1`, apply the fused activation + re-quantization, run the
    second grouped GEMM against `w2`, then gather the rows back into token
    order with top-k weighting.

    `q_x` and `a1_scale` are dereferenced unconditionally, so they must be
    provided even though their defaults are None. `activation`, `a2_scale`,
    `use_nn_moe`, `routed_scaling_factor`, `shared_output` and any extra
    keyword arguments are accepted but not used in this body.

    Returns the (num_tokens, K) output, which is a view into an internal
    workspace buffer.
    """
    import vllm.model_executor.layers.fused_moe.modular_kernel as mk

    local_num_experts = w1.size(0)
    a1q = q_x
    # N and K are set on self elsewhere in the class.
    N, K = self.N, self.K
    (workspace13_shape, workspace2_shape, fused_out_shape,
     workspace_dtype, M_sum) = self.contiguous_groupgemm_workspace_shapes(
        x, q_x, topk_ids.size(0), N, K, topk_ids.size(1), global_num_experts,
        local_num_experts, expert_num_tokens_cpu)

    # Two flat scratch buffers are reused for several intermediates below;
    # the statement order matters because of this aliasing.
    workspace13 = torch.empty(prod(workspace13_shape),
                              device=x.device,
                              dtype=workspace_dtype)
    workspace2 = torch.empty(prod(workspace2_shape),
                             device=x.device,
                             dtype=workspace_dtype)

    # workspace13 backs both mm1_out and the final fused_out;
    # workspace2 backs both the permuted input (a1q_perm) and mm2_out.
    mm1_out = _resize_cache(workspace13, (M_sum, N))
    mm2_out = _resize_cache(workspace2, (M_sum, K))
    fused_out = _resize_cache(workspace13, fused_out_shape)
    # Reinterpret workspace2 with the quantized dtype to hold the permuted
    # activations.
    a1q_perm = _resize_cache(workspace2.view(dtype=a1q.dtype), (M_sum, K))

    # expert_ids labels each permuted row with its expert (-1 for padding);
    # inv_perm maps each (token, slot) to its permuted row.
    a1q, a1q_scale, expert_ids, inv_perm = deepgemm_moe_permute(
        aq=a1q,
        aq_scale=a1_scale,
        topk_ids=topk_ids,
        local_num_experts=local_num_experts,
        expert_map=expert_map,
        block_shape=self.block_shape,
        expert_num_tokens=expert_num_tokens,
        expert_num_tokens_cpu=expert_num_tokens_cpu,
        aq_out=a1q_perm,
        M_sum=M_sum
    )

    # NOTE(review): a disabled variant remapped expert_ids == -1 to expert 0
    # when expert_map is set, so that every row had a valid weight index
    # (relying on the matching top-k weights being 0). It is not applied here;
    # presumably the contiguous kernel tolerates -1 rows — confirm.

    # First grouped GEMM: permuted activations x w1 -> mm1_out (workspace13).
    m_grouped_w8a8_gemm_nt_contig_asm(
        (a1q, a1q_scale), (w1, w1_scale), mm1_out, expert_ids)

    # Presumably SiLU-and-mul followed by re-quantization, judging by the
    # helper's name and the (tensor, scale) pair it returns — confirm.
    a2q, a2q_scale = fuse_silu_mul_quant(mm1_out)

    # Second grouped GEMM writes into workspace2, overwriting the permuted
    # input that is no longer needed.
    m_grouped_w8a8_gemm_nt_contig_asm(
        (a2q, a2q_scale), (w2, w2_scale), mm2_out, expert_ids)

    # When router weights were already applied to the input, use unit weights
    # so the gather below only sums.
    if apply_router_weight_on_input:
        topk_weights = torch.ones_like(topk_weights)

    # Gather reads mm2_out (workspace2) and writes fused_out (workspace13).
    deepgemm_unpermute_and_reduce(
        a=mm2_out,
        topk_ids=topk_ids,
        topk_weights=topk_weights,
        inv_perm=inv_perm,
        expert_map=expert_map,
        output=fused_out,
    )

    return fused_out
def fused_moe_forward(self, def fused_moe_forward(self,
x: torch.Tensor, x: torch.Tensor,
...@@ -289,6 +401,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod) ...@@ -289,6 +401,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = None,
shared_output: Optional[torch.Tensor] = None, shared_output: Optional[torch.Tensor] = None,
q_x: Optional[torch.Tensor] = None, q_x: Optional[torch.Tensor] = None,
expert_num_tokens_cpu: Optional[torch.Tensor] = None,
**_ ): **_ ):
return fused_experts_impl_int8_marlin( return fused_experts_impl_int8_marlin(
hidden_states=x if q_x is None else q_x, hidden_states=x if q_x is None else q_x,
...@@ -401,7 +514,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod) ...@@ -401,7 +514,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
return TritonOrGroupGemmExperts( return TritonOrGroupGemmExperts(
use_int8_w8a8=True, use_int8_w8a8=True,
per_act_token_quant=True, per_act_token_quant=True,
fused_experts=self.w8a8_groupgemm_forward fused_experts=self.w8a8_groupgemm_masked_forward
) )
else: else:
logger.debug( logger.debug(
...@@ -410,5 +523,6 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod) ...@@ -410,5 +523,6 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
False) False)
return TritonOrGroupGemmExperts( return TritonOrGroupGemmExperts(
fused_experts=self.fused_moe_forward use_int8_w8a8=envs.VLLM_ENABLE_DEEPEP_HT_DEEPGEMM,
fused_experts=self.w8a8_groupgemm_contiguous_forward if envs.VLLM_ENABLE_DEEPEP_HT_DEEPGEMM else self.fused_moe_forward
) )
\ No newline at end of file
...@@ -167,6 +167,8 @@ class SlimQuantW4A8Int8MarlinMoEMethod: ...@@ -167,6 +167,8 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
self.use_deepep = dp_size > 1 and parallel_config.enable_expert_parallel and \ self.use_deepep = dp_size > 1 and parallel_config.enable_expert_parallel and \
(envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \ (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency") envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency")
self.ep_size = get_ep_group().world_size
if self.use_deepep: if self.use_deepep:
all2all_manager = get_ep_group().device_communicator.all2all_manager all2all_manager = get_ep_group().device_communicator.all2all_manager
...@@ -352,7 +354,9 @@ class SlimQuantW4A8Int8MarlinMoEMethod: ...@@ -352,7 +354,9 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
# (from deepgemm docs) : A value hint (which is a value on CPU) # (from deepgemm docs) : A value hint (which is a value on CPU)
# for the M expectation of each batch, correctly setting this value # for the M expectation of each batch, correctly setting this value
# may lead to better performance. # may lead to better performance.
expected_m = max_num_tokens #expected_m = max_num_tokens
ori_bs = x.shape[0]
expected_m = ori_bs * self.ep_size
m_grouped_w4a8_gemm_nt_masked((q_x, a1_scale), m_grouped_w4a8_gemm_nt_masked((q_x, a1_scale),
(w1, w1_scale), (w1, w1_scale),
......
...@@ -174,14 +174,12 @@ class DeepseekV2MoE(nn.Module): ...@@ -174,14 +174,12 @@ class DeepseekV2MoE(nn.Module):
dp_size = get_dp_group().world_size dp_size = get_dp_group().world_size
self.use_mori_ep = parallel_config.enable_expert_parallel and dp_size > 1 and envs.VLLM_ALL2ALL_BACKEND == 'mori' self.use_mori_ep = parallel_config.enable_expert_parallel and dp_size > 1 and envs.VLLM_ALL2ALL_BACKEND == 'mori'
self.enable_expert_parallel = parallel_config.enable_expert_parallel self.enable_expert_parallel = parallel_config.enable_expert_parallel
backend = envs.VLLM_ALL2ALL_BACKEND self.use_deepep = dp_size > 1 and parallel_config.enable_expert_parallel and \
self.use_deepep_ll = ( (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \
dp_size > 1 envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency" or \
and parallel_config.enable_expert_parallel envs.VLLM_ALL2ALL_BACKEND == "deepep_auto")
and (backend == "deepep_low_latency" or backend == "deepep_auto")
)
if not self.use_deepep_ll: if not self.use_deepep:
moe_cls = FusedMoE if not self.use_mori_ep else MoriMoE moe_cls = FusedMoE if not self.use_mori_ep else MoriMoE
self.experts = moe_cls( self.experts = moe_cls(
num_experts=config.n_routed_experts, num_experts=config.n_routed_experts,
...@@ -254,7 +252,7 @@ class DeepseekV2MoE(nn.Module): ...@@ -254,7 +252,7 @@ class DeepseekV2MoE(nn.Module):
num_tokens, hidden_dim = hidden_states.shape num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim) hidden_states = hidden_states.view(-1, hidden_dim)
if not self.use_mori_ep and not self.use_deepep_ll: if not self.use_mori_ep and not self.use_deepep:
if self.n_shared_experts is not None: if self.n_shared_experts is not None:
if envs.USE_FUSED_RMS_QUANT: if envs.USE_FUSED_RMS_QUANT:
shared_output, new_resi = self.shared_experts(hidden_states, rms_weight, residual, update_hd=True) shared_output, new_resi = self.shared_experts(hidden_states, rms_weight, residual, update_hd=True)
...@@ -289,7 +287,7 @@ class DeepseekV2MoE(nn.Module): ...@@ -289,7 +287,7 @@ class DeepseekV2MoE(nn.Module):
final_hidden_states = final_hidden_states + shared_output \ final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor) * (1. / self.routed_scaling_factor)
else: else:
if self.use_deepep_ll: if self.use_deepep:
shared_output, final_hidden_states = self.experts(hidden_states=hidden_states, shared_output, final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits) router_logits=router_logits)
...@@ -721,12 +719,10 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -721,12 +719,10 @@ class DeepseekV2DecoderLayer(nn.Module):
self.dp_size = get_dp_group().world_size self.dp_size = get_dp_group().world_size
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
backend = envs.VLLM_ALL2ALL_BACKEND self.use_deepep = self.dp_size > 1 and parallel_config.enable_expert_parallel and \
self.use_deepep_ll = ( (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \
self.dp_size > 1 envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency" or \
and parallel_config.enable_expert_parallel envs.VLLM_ALL2ALL_BACKEND == "deepep_auto")
and (backend == "deepep_low_latency" or backend == "deepep_auto")
)
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
if (config.n_routed_experts is not None if (config.n_routed_experts is not None
...@@ -855,7 +851,7 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -855,7 +851,7 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states, residual) hidden_states, residual)
if isinstance(self.mlp, if isinstance(self.mlp,
DeepseekV2MoE) and self.use_deepep_ll and self.tp_size > 1: DeepseekV2MoE) and self.use_deepep and self.tp_size > 1:
self.tp_rank = get_tensor_model_parallel_rank() self.tp_rank = get_tensor_model_parallel_rank()
ori_bs = hidden_states.shape[0] ori_bs = hidden_states.shape[0]
...@@ -868,7 +864,7 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -868,7 +864,7 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states = self.mlp(hidden_states) hidden_states = self.mlp(hidden_states)
if isinstance(self.mlp, if isinstance(self.mlp,
DeepseekV2MoE) and self.use_deepep_ll and self.tp_size > 1: DeepseekV2MoE) and self.use_deepep and self.tp_size > 1:
hidden_states = tensor_model_parallel_all_gather(hidden_states, dim=0).contiguous() hidden_states = tensor_model_parallel_all_gather(hidden_states, dim=0).contiguous()
hidden_states = hidden_states[:ori_bs, :].contiguous() hidden_states = hidden_states[:ori_bs, :].contiguous()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment