Commit cd42bf87 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge remote-tracking branch 'origin/v0.9.2-dev-wm-1218' into v0.9.2-dev

parents 43546076 9925dd0e
...@@ -52,7 +52,7 @@ from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS, ...@@ -52,7 +52,7 @@ from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS,
LayerBlockType, LazyLoader, common_broadcastable_dtype, LayerBlockType, LazyLoader, common_broadcastable_dtype,
cuda_device_count_stateless, get_cpu_memory, cuda_device_count_stateless, get_cpu_memory,
get_open_port, is_torch_equal_or_newer, random_uuid, get_open_port, is_torch_equal_or_newer, random_uuid,
resolve_obj_by_qualname) resolve_obj_by_qualname, round_up)
from vllm.utils import SUPPORT_TC from vllm.utils import SUPPORT_TC
# yapf: enable # yapf: enable
...@@ -4783,10 +4783,23 @@ class VllmConfig: ...@@ -4783,10 +4783,23 @@ class VllmConfig:
if size <= max_num_tokens if size <= max_num_tokens
] ]
# add for ep sp
dp_size = self.parallel_config.data_parallel_size
tp_size = self.parallel_config.tensor_parallel_size
ep_sp = self.parallel_config.enable_expert_parallel and dp_size > 1 and tp_size > 1
# add for spec decode # add for spec decode
if self.speculative_config is not None and self.speculative_config.num_lookahead_slots > 0: if self.speculative_config is not None and self.speculative_config.num_lookahead_slots > 0:
batch_size_capture_list = list(map(lambda x: x * (1 + self.speculative_config.num_lookahead_slots), mtp_batch_size_capture_list = list(map(lambda x: x * (1 + self.speculative_config.num_lookahead_slots),
batch_size_capture_list)) batch_size_capture_list))
batch_size_capture_list = sorted(set(batch_size_capture_list + mtp_batch_size_capture_list))
batch_size_capture_list = [i for i in batch_size_capture_list if i == 1 or i % (1 + self.speculative_config.num_lookahead_slots) == 0]
if ep_sp:
batch_size_capture_list = sorted(set([round_up(i, tp_size) for i in batch_size_capture_list]))
else:
if ep_sp:
batch_size_capture_list = sorted(set([round_up(i, tp_size) for i in batch_size_capture_list]))
self.compilation_config.init_with_cudagraph_sizes( self.compilation_config.init_with_cudagraph_sizes(
batch_size_capture_list) batch_size_capture_list)
......
...@@ -8,7 +8,7 @@ import torch.distributed as dist ...@@ -8,7 +8,7 @@ import torch.distributed as dist
from vllm.forward_context import get_forward_context from vllm.forward_context import get_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import has_deep_ep, has_pplx from vllm.utils import has_deep_ep, has_pplx
import vllm.envs as envs
from .base_device_communicator import All2AllManagerBase, Cache from .base_device_communicator import All2AllManagerBase, Cache
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -140,7 +140,7 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase): ...@@ -140,7 +140,7 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
# This is the DeepEP default. Stick to it till we can establish # This is the DeepEP default. Stick to it till we can establish
# reasonable defaults based on profiling. # reasonable defaults based on profiling.
self.num_sms = 20 self.num_sms = 30
def get_handle(self, kwargs): def get_handle(self, kwargs):
raise NotImplementedError raise NotImplementedError
...@@ -166,16 +166,26 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase): ...@@ -166,16 +166,26 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
def _make_all2all_kwargs(self) -> dict[Any, Any]: def _make_all2all_kwargs(self) -> dict[Any, Any]:
# Defaults for internode and intranode are taken from DeepEP tests. # Defaults for internode and intranode are taken from DeepEP tests.
num_nvl_bytes = 1024 * 1024 * 1024 num_nvl_bytes = int(2e9/2)#1024 * 1024 * 1024
num_rdma_bytes = None num_rdma_bytes = None
num_qps_per_rank = None num_qps_per_rank = None
if self.internode: if self.internode:
num_rdma_bytes = 1024 * 1024 * 1024 num_rdma_bytes = int(1e9/2) #1024 * 1024 * 1024
num_qps_per_rank = self.num_sms // 2 num_qps_per_rank = 30 #self.num_sms // 2
self.num_sms = 30
# import deep_ep
# num_nvl_bytes, num_rdma_bytes = 0, 0
# hidden_size = 7168
# hidden_bytes = hidden_size * 2
# for config in (deep_ep.Buffer.get_dispatch_config(self.cpu_group.size()), deep_ep.Buffer.get_combine_config(self.cpu_group.size())):
# num_nvl_bytes = max(config.get_nvl_buffer_size_hint(hidden_bytes, self.cpu_group.size()), num_nvl_bytes)
# num_rdma_bytes = max(config.get_rdma_buffer_size_hint(hidden_bytes, self.cpu_group.size()), num_rdma_bytes)
else: else:
num_rdma_bytes = 0 num_rdma_bytes = 0
num_qps_per_rank = 1 num_qps_per_rank = 1
self.num_sms = 60
assert num_rdma_bytes is not None assert num_rdma_bytes is not None
assert num_qps_per_rank is not None assert num_qps_per_rank is not None
...@@ -183,7 +193,8 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase): ...@@ -183,7 +193,8 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
num_nvl_bytes=num_nvl_bytes, num_nvl_bytes=num_nvl_bytes,
num_rdma_bytes=num_rdma_bytes, num_rdma_bytes=num_rdma_bytes,
low_latency_mode=False, low_latency_mode=False,
num_qps_per_rank=num_qps_per_rank) num_qps_per_rank=num_qps_per_rank,
explicitly_destroy=False)
def get_handle(self, kwargs): def get_handle(self, kwargs):
...@@ -244,7 +255,8 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase): ...@@ -244,7 +255,8 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
num_nvl_bytes=num_nvl_bytes, num_nvl_bytes=num_nvl_bytes,
num_rdma_bytes=num_rdma_bytes, num_rdma_bytes=num_rdma_bytes,
low_latency_mode=True, low_latency_mode=True,
num_qps_per_rank=num_qps_per_rank) num_qps_per_rank=num_qps_per_rank,
)
def get_handle(self, kwargs): def get_handle(self, kwargs):
""" """
......
...@@ -237,10 +237,11 @@ class DeviceCommunicatorBase: ...@@ -237,10 +237,11 @@ class DeviceCommunicatorBase:
moe_modules = [ moe_modules = [
module for module in model.modules() module for module in model.modules()
if module.__class__.__name__ == "FusedMoE" if (module.__class__.__name__ == "FusedMoE"
or module.__class__.__name__ == "SharedFusedMoE")
] ]
for module in moe_modules: for module in moe_modules:
module.quant_method.init_prepare_finalize(module.moe_config, module.quant_method.init_prepare_finalize(module, module.moe_config,
module.quant_config) module.quant_config)
def dispatch( def dispatch(
......
...@@ -199,6 +199,7 @@ if TYPE_CHECKING: ...@@ -199,6 +199,7 @@ if TYPE_CHECKING:
VLLM_USE_MARLIN_W16A16_MOE:bool = False VLLM_USE_MARLIN_W16A16_MOE:bool = False
VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER: bool = False VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER: bool = False
VLLM_USE_FUSED_FILL_RMS_CAT:bool = False VLLM_USE_FUSED_FILL_RMS_CAT:bool = False
VLLM_ENABLE_DEEPEP_HT_DEEPGEMM: bool = True
def get_default_cache_root(): def get_default_cache_root():
return os.getenv( return os.getenv(
...@@ -1286,11 +1287,17 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1286,11 +1287,17 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_FUSED_FILL_RMS_CAT": "VLLM_USE_FUSED_FILL_RMS_CAT":
lambda: (os.environ.get("VLLM_USE_FUSED_FILL_RMS_CAT", "False").lower() in lambda: (os.environ.get("VLLM_USE_FUSED_FILL_RMS_CAT", "False").lower() in
("true", "1")), ("true", "1")),
# If set to 1/True, enable the reduced top-k/top-p sampling path in the # If set to 1/True, enable the reduced top-k/top-p sampling path in the
# V1 PyTorch-native sampler. # V1 PyTorch-native sampler.
"VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER": "VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER":
lambda: (os.getenv("VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER", lambda: (os.getenv("VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER",
"0").lower() in ("true", "1")), "0").lower() in ("true", "1")),
# vLLM will use deepgemm kernel for deepep ht mode
"VLLM_ENABLE_DEEPEP_HT_DEEPGEMM":
lambda: (os.getenv('VLLM_ENABLE_DEEPEP_HT_DEEPGEMM', '1').lower() in
("true", "1")),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
...@@ -136,7 +136,8 @@ def set_forward_context( ...@@ -136,7 +136,8 @@ def set_forward_context(
forward_start_time = time.perf_counter() forward_start_time = time.perf_counter()
dp_metadata: Optional[DPMetadata] = None dp_metadata: Optional[DPMetadata] = None
dp_size = vllm_config.parallel_config.data_parallel_size dp_size = vllm_config.parallel_config.data_parallel_size
if dp_size > 1 and ( use_navie_ep = envs.VLLM_ALL2ALL_BACKEND == 'naive' and dp_size > 1 and vllm_config.parallel_config.enable_expert_parallel
if use_navie_ep and dp_size > 1 and (
attn_metadata is not None or num_tokens is not None): attn_metadata is not None or num_tokens is not None):
dp_metadata = DPMetadata.make(vllm_config.parallel_config, dp_metadata = DPMetadata.make(vllm_config.parallel_config,
attn_metadata, num_tokens or 0, attn_metadata, num_tokens or 0,
......
...@@ -10,6 +10,7 @@ from vllm.model_executor.layers.fused_moe.layer import ( ...@@ -10,6 +10,7 @@ from vllm.model_executor.layers.fused_moe.layer import (
from vllm.model_executor.layers.fused_moe.modular_kernel import ( from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEActivationFormat, FusedMoEPermuteExpertsUnpermute, FusedMoEActivationFormat, FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize) FusedMoEPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
from vllm.triton_utils import HAS_TRITON from vllm.triton_utils import HAS_TRITON
_config: Optional[dict[str, Any]] = None _config: Optional[dict[str, Any]] = None
...@@ -38,6 +39,7 @@ __all__ = [ ...@@ -38,6 +39,7 @@ __all__ = [
"FusedMoEPrepareAndFinalize", "FusedMoEPrepareAndFinalize",
"override_config", "override_config",
"get_config", "get_config",
"SharedFusedMoE",
] ]
if HAS_TRITON: if HAS_TRITON:
...@@ -59,6 +61,8 @@ if HAS_TRITON: ...@@ -59,6 +61,8 @@ if HAS_TRITON:
get_config_file_name, grouped_topk) get_config_file_name, grouped_topk)
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts) TritonOrDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.triton_group_gemm_moe import (
TritonOrGroupGemmExperts)
__all__ += [ __all__ += [
"fused_moe", "fused_moe",
...@@ -75,4 +79,5 @@ if HAS_TRITON: ...@@ -75,4 +79,5 @@ if HAS_TRITON:
"BatchedDeepGemmExperts", "BatchedDeepGemmExperts",
"TritonOrDeepGemmExperts", "TritonOrDeepGemmExperts",
"BatchedTritonOrDeepGemmExperts", "BatchedTritonOrDeepGemmExperts",
"TritonOrGroupGemmExperts",
] ]
...@@ -54,7 +54,7 @@ def get_config_quant_dtype( ...@@ -54,7 +54,7 @@ def get_config_quant_dtype(
) -> Optional[torch.dtype]: ) -> Optional[torch.dtype]:
if use_fp8_w8a8: if use_fp8_w8a8:
return torch.float8_e4m3fn return torch.float8_e4m3fn
elif use_int8_w8a8: elif use_int8_w8a8 or use_int4_w4a8:
return torch.int8 return torch.int8
return None return None
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional from typing import Optional
from collections.abc import Callable
import deep_ep import deep_ep
import torch import torch
import torch.distributed as dist
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.utils import ( from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input) moe_kernel_quantize_input)
from vllm.distributed.parallel_state import get_ep_group
class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
...@@ -55,35 +59,49 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -55,35 +59,49 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return None return None
return deep_ep.Buffer.get_combine_config(self.dp_size) return deep_ep.Buffer.get_combine_config(self.dp_size)
def _do_dispatch(self, tokens: torch.Tensor, def _do_dispatch(
token_scales: Optional[torch.Tensor], self,
tokens: torch.Tensor,
token_scales: torch.Tensor | None,
rank_topk_ids: torch.Tensor, rank_topk_ids: torch.Tensor,
rank_topk_weights: torch.Tensor, num_experts: int): rank_topk_weights: torch.Tensor,
num_experts: int,
quant_config: FusedMoEQuantConfig,
) -> Callable:
has_scales = token_scales is not None has_scales = token_scales is not None
(num_tokens_per_rank, num_tokens_per_rdma_rank, expert_num_tokens, (
is_token_in_rank, event) = self.buffer.get_dispatch_layout( num_tokens_per_rank,
num_tokens_per_rdma_rank,
dispatch_expert_num_tokens,
is_token_in_rank,
event,
) = self.buffer.get_dispatch_layout(
topk_idx=rank_topk_ids, topk_idx=rank_topk_ids,
num_experts=num_experts, num_experts=num_experts,
previous_event=None, previous_event=None,
async_finish=False, async_finish=False,
allocate_on_comm_stream=False) allocate_on_comm_stream=False,
)
token_data = tokens token_data = tokens
if has_scales: if has_scales:
token_data = (tokens, token_scales) token_data = (tokens, token_scales)
( (
token_data, expert_topk_ids, expert_topk_weights, token_data,
expert_num_tokens_per_expert_list, self.handle, event expert_topk_ids,
expert_topk_weights,
expert_num_tokens_per_expert_list,
self.handle,
event,
) = self.buffer.dispatch( ) = self.buffer.dispatch(
x=token_data, x=token_data,
handle=None, handle=None,
num_tokens_per_rank=num_tokens_per_rank, num_tokens_per_rank=num_tokens_per_rank,
num_tokens_per_rdma_rank=num_tokens_per_rdma_rank, num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
is_token_in_rank=is_token_in_rank, is_token_in_rank=is_token_in_rank,
num_tokens_per_expert=expert_num_tokens, num_tokens_per_expert=dispatch_expert_num_tokens,
topk_idx=rank_topk_ids, topk_idx=rank_topk_ids,
topk_weights=rank_topk_weights, topk_weights=rank_topk_weights,
# expert_alignment rounds the number of tokens per expert # expert_alignment rounds the number of tokens per expert
...@@ -91,8 +109,36 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -91,8 +109,36 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_alignment=1, expert_alignment=1,
config=self._get_dispatch_config(), config=self._get_dispatch_config(),
previous_event=None, previous_event=None,
async_finish=False, async_finish=True,
allocate_on_comm_stream=False) allocate_on_comm_stream=False,
)
return lambda: self._receiver(
event,
has_scales,
token_data,
expert_topk_ids,
num_experts,
expert_num_tokens_per_expert_list,
expert_topk_weights,
token_scales,
quant_config,
)
def _receiver(
self,
event: deep_ep.EventOverlap,
has_scales: bool,
token_data: tuple[torch.Tensor, torch.Tensor] | torch.Tensor,
expert_topk_ids: torch.Tensor | None,
num_experts: int,
expert_num_tokens_per_expert_list: list[int],
expert_topk_weights: torch.Tensor | None,
a1_scale: torch.Tensor | None,
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
if event.event is not None:
event.current_stream_wait()
if has_scales: if has_scales:
expert_x, expert_x_scale = token_data expert_x, expert_x_scale = token_data
...@@ -110,15 +156,45 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -110,15 +156,45 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
# DeepEP's topk_ids output refers to the local experts directly. Offset # DeepEP's topk_ids output refers to the local experts directly. Offset
# the topk_ids to move it back to the global experts space so it aligns # the topk_ids to move it back to the global experts space so it aligns
# with existing vLLM interfaces. # with existing vLLM interfaces.
assert expert_topk_ids is not None
expert_topk_ids = torch.where( expert_topk_ids = torch.where(
expert_topk_ids == -1, expert_topk_ids == -1,
num_experts - 1 if self.rank_expert_offset == 0 else 0, num_experts - 1 if self.rank_expert_offset == 0 else 0,
expert_topk_ids + self.rank_expert_offset) expert_topk_ids + self.rank_expert_offset,
)
return (expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids, # Makes a GPU-CPU copy.
expert_topk_weights) # TODO (varun): Maybe it is better to re-compute the expert_num_tokens
# on GPU.
expert_tokens_meta = mk.ExpertTokensMetadata.make_from_list(
expert_num_tokens_per_expert_list, device=expert_x.device
)
def prepare( # Dispatch and Quant
# DeepEP kernels only support dispatching block-quantized
# activation scales.
# Dispatch in bfloat16 and quantize afterwards
if not quant_config.per_act_token_quant:
# Quantize after dispatch.
expert_x_scale = None
if expert_x.numel() != 0:
expert_x, expert_x_scale = moe_kernel_quantize_input(
expert_x,
a1_scale,
quant_dtype=quant_config.quant_dtype,
per_act_token_quant=False,
block_shape=quant_config.block_shape,
)
return (
expert_x,
expert_x_scale,
expert_tokens_meta,
expert_topk_ids,
expert_topk_weights,
)
def prepare_async(
self, self,
a1: torch.Tensor, a1: torch.Tensor,
a1_scale: Optional[torch.Tensor], a1_scale: Optional[torch.Tensor],
...@@ -129,14 +205,13 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -129,14 +205,13 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: Optional[torch.Tensor], expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], ) -> mk.ReceiverType:
Optional[torch.Tensor], Optional[torch.Tensor]]:
if apply_router_weight_on_input: if apply_router_weight_on_input:
topk = topk_ids.size(1) topk = topk_ids.size(1)
# TODO: this only works for topK=1, will need to update for topK>1 # TODO: this only works for topK=1, will need to update for topK>1
assert topk == 1, ( assert topk == 1, (
"apply_router_weight_on_input is only implemented for topk=1") "apply_router_weight_on_input is only implemented for topk=1"
)
a1 = a1 * topk_weights.to(a1.dtype) a1 = a1 * topk_weights.to(a1.dtype)
if quant_config.per_act_token_quant: if quant_config.per_act_token_quant:
...@@ -149,35 +224,43 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -149,35 +224,43 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
) )
if a1q_scale is not None and a1q_scale.numel() == 1: if a1q_scale is not None and a1q_scale.numel() == 1:
a1q_scale = a1q_scale.view(1, 1) a1q_scale = a1q_scale.view(1, 1)
(expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids, else:
expert_topk_weights) = self._do_dispatch( a1q = a1
a1q_scale = None
return self._do_dispatch(
tokens=a1q, tokens=a1q,
token_scales=a1q_scale, token_scales=a1q_scale,
rank_topk_ids=topk_ids, rank_topk_ids=topk_ids,
rank_topk_weights=topk_weights, rank_topk_weights=topk_weights,
num_experts=num_experts) num_experts=num_experts,
else: quant_config=quant_config,
# DeepEP kernels only support dispatching per-token-quant )
# quantization. dispatch in bfloat16.
(expert_x, _, expert_num_tokens, expert_topk_ids,
expert_topk_weights) = self._do_dispatch(
tokens=a1,
token_scales=None,
rank_topk_ids=topk_ids,
rank_topk_weights=topk_weights,
num_experts=num_experts)
# quantize now
expert_x_scale = None
if expert_x.numel() != 0:
expert_x, expert_x_scale = moe_kernel_quantize_input(
expert_x,
a1_scale,
quant_dtype=quant_config.quant_dtype,
per_act_token_quant=False,
block_shape=quant_config.block_shape)
return (expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids, def prepare(
expert_topk_weights) self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
receiver = self.prepare_async(
a1,
a1_scale,
a2_scale,
topk_weights,
topk_ids,
num_experts,
expert_map,
apply_router_weight_on_input,
quant_config,
)
return receiver()
def _apply_weights_and_reduce(self, num_tokens: int, def _apply_weights_and_reduce(self, num_tokens: int,
fused_expert_output: torch.Tensor, fused_expert_output: torch.Tensor,
...@@ -203,29 +286,79 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -203,29 +286,79 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return out return out
def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, def _finalize(
topk_weights: torch.Tensor, topk_ids: torch.Tensor, self,
apply_router_weight_on_input: bool) -> None: output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
do_async: bool,
apply_weights_and_reduce: bool = True,
) -> Callable | None:
assert self.handle is not None assert self.handle is not None
# fused_expert_output can have 0 tokens - This happens when none of the
# tokens from the all2all reach this EP rank.
if fused_expert_output.numel() != 0:
fused_expert_output = self._apply_weights_and_reduce(
num_tokens=topk_ids.size(0),
fused_expert_output=fused_expert_output,
topk_weights=topk_weights,
apply_router_weight_on_input=apply_router_weight_on_input,
output_dtype=output.dtype)
combined_x, _, event = self.buffer.combine( combined_x, _, event = self.buffer.combine(
# HT combine only supports BF16
x=fused_expert_output, x=fused_expert_output,
handle=self.handle, handle=self.handle,
topk_weights=None, topk_weights=None,
config=self._get_combine_config(), config=self._get_combine_config(),
previous_event=None, previous_event=None,
async_finish=False, async_finish=do_async,
allocate_on_comm_stream=False) allocate_on_comm_stream=False,
)
if do_async:
def _receiver():
if event.event is not None:
event.current_stream_wait()
# Respect inplace outputs. # Respect inplace outputs.
output.copy_(combined_x, non_blocking=True) output.copy_(combined_x, non_blocking=True)
return _receiver
else:
# Respect inplace outputs.
output.copy_(combined_x, non_blocking=True)
return None
def finalize_async(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
apply_weights_and_reduce: bool = True,
) -> Callable:
receiver = self._finalize(
output,
fused_expert_output,
topk_weights,
topk_ids,
apply_router_weight_on_input,
do_async=True,
apply_weights_and_reduce=apply_weights_and_reduce,
)
assert receiver is not None
return receiver
def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
apply_weights_and_reduce: bool = True,
) -> None:
self._finalize(
output,
fused_expert_output,
topk_weights,
topk_ids,
apply_router_weight_on_input,
do_async=False,
apply_weights_and_reduce=apply_weights_and_reduce,
)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional, Union from typing import Optional, Union
from collections.abc import Callable
import deep_ep import deep_ep
import torch import torch
...@@ -44,12 +45,14 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -44,12 +45,14 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
buffer: deep_ep.Buffer, buffer: deep_ep.Buffer,
max_tokens_per_rank: int, max_tokens_per_rank: int,
num_dispatchers: int, num_dispatchers: int,
use_fp8_dispatch: bool = False): use_fp8_dispatch: bool = False,
use_int8_dispatch: bool = False):
super().__init__() super().__init__()
self.buffer = buffer self.buffer = buffer
self.max_tokens_per_rank = max_tokens_per_rank self.max_tokens_per_rank = max_tokens_per_rank
self.use_fp8_dispatch = use_fp8_dispatch self.use_fp8_dispatch = use_fp8_dispatch
self.use_int8_dispatch = use_int8_dispatch
# The dispatch function returns a handle that the combine function # The dispatch function returns a handle that the combine function
# requires. We store the handle here so it is available to the # requires. We store the handle here so it is available to the
# combine function. # combine function.
...@@ -71,17 +74,19 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -71,17 +74,19 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def _do_quant( def _do_quant(
self, self,
x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], x: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
a1_scale: Optional[torch.Tensor], a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
a1_dtype: torch.dtype, a1_dtype: torch.dtype,
quant_dtype: Optional[torch.dtype], quant_config: FusedMoEQuantConfig,
per_act_token_quant: bool, expert_num_tokens: Optional[torch.Tensor]= None,
block_shape: Optional[list[int]],
) -> tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
block_k = block_shape[1] if block_shape is not None else None
if self.use_fp8_dispatch: if self.use_fp8_dispatch:
block_k = (
quant_config.block_shape[1]
if quant_config.block_shape is not None
else None
)
if block_k == DEEPEP_QUANT_BLOCK_SIZE: if block_k == DEEPEP_QUANT_BLOCK_SIZE:
# DeepEP kernels did the quantization for us. # DeepEP kernels did the quantization for us.
x, x_scales = x x, x_scales = x
...@@ -96,19 +101,25 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -96,19 +101,25 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
num_experts, max_tokens, hidden_dim = x.size() num_experts, max_tokens, hidden_dim = x.size()
# TODO (varun): Optimization - Use a batched version of quant # TODO (varun): Optimization - Use a batched version of quant
if expert_num_tokens is None:
x = x.view((-1, hidden_dim)) x = x.view((-1, hidden_dim))
x, x_scales = moe_kernel_quantize_input(x, a1_scale, quant_dtype, x, x_scales = moe_kernel_quantize_input(
per_act_token_quant, x,
block_shape) a1_scale,
quant_config.quant_dtype,
quant_config.per_act_token_quant,
quant_config.block_shape,
expert_num_tokens
)
x = x.view((num_experts, -1, hidden_dim)) x = x.view((num_experts, -1, hidden_dim))
if quant_dtype is not None: if quant_config.quant_dtype is not None:
assert x_scales is not None assert x_scales is not None
x_scales = normalize_batched_scales_shape(x_scales, num_experts) x_scales = normalize_batched_scales_shape(x_scales, num_experts)
return x, x_scales return x, x_scales
def prepare( def prepare_async(
self, self,
a1: torch.Tensor, a1: torch.Tensor,
a1_scale: Optional[torch.Tensor], a1_scale: Optional[torch.Tensor],
...@@ -119,9 +130,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -119,9 +130,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: Optional[torch.Tensor], expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], ) -> tuple[Callable, mk.ReceiverType]:
Optional[torch.Tensor], Optional[torch.Tensor]]:
hidden_size = a1.size(1) hidden_size = a1.size(1)
assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, \ assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, \
(f"Hidden Size {hidden_size} not in supported list of hidden sizes" (f"Hidden Size {hidden_size} not in supported list of hidden sizes"
...@@ -141,29 +150,89 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -141,29 +150,89 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk = topk_ids.size(1) topk = topk_ids.size(1)
# TODO: this only works for topK=1, will need to update for topK>1 # TODO: this only works for topK=1, will need to update for topK>1
assert topk == 1, ( assert topk == 1, (
"apply_router_weight_on_input is only implemented for topk=1") "apply_router_weight_on_input is only implemented for topk=1"
)
a1 = a1 * topk_weights.to(a1.dtype) a1 = a1 * topk_weights.to(a1.dtype)
# Dispatch # Dispatch
expert_x, expert_num_tokens, self.handle, event, hook = \ expert_x, expert_num_tokens, self.handle, _, hook = self.buffer.low_latency_dispatch(
self.buffer.low_latency_dispatch(a1, a1,
topk_ids, topk_ids,
self.max_tokens_per_rank, self.max_tokens_per_rank,
num_experts, num_experts,
use_fp8=self.use_fp8_dispatch, use_fp8=self.use_fp8_dispatch or self.use_int8_dispatch,
use_int8=self.use_int8_dispatch,
async_finish=False, async_finish=False,
return_recv_hook=False) return_recv_hook=True,
)
expert_x, expert_x_scale = self._do_quant(
expert_x, a1_scale, a2_scale, a1.dtype, quant_config.quant_dtype, return (
quant_config.per_act_token_quant, quant_config.block_shape) hook,
lambda: self._receiver(
expert_x,
expert_num_tokens,
a1_scale,
a1.dtype,
quant_config,
),
)
def _receiver(
self,
expert_x: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
expert_num_tokens: torch.Tensor,
a1_scale: torch.Tensor | None,
a1_dtype: torch.dtype,
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
if not self.use_int8_dispatch:
expert_x, expert_x_scale = self._do_quant(expert_x, a1_scale, a1_dtype, quant_config, expert_num_tokens)
else:
expert_x, expert_x_scale = expert_x
return (expert_x, expert_x_scale, expert_num_tokens, None, None) expert_tokens_meta = mk.ExpertTokensMetadata(
expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None
)
def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, return expert_x, expert_x_scale, expert_tokens_meta, None, None
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
apply_router_weight_on_input: bool) -> None:
def prepare(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
hook, receiver = self.prepare_async(
a1,
a1_scale,
a2_scale,
topk_weights,
topk_ids,
num_experts,
expert_map,
apply_router_weight_on_input,
quant_config,
)
hook()
return receiver()
def _finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
apply_weights_and_reduce: bool,
do_async: bool,
) -> Callable:
do_recv_hook = do_async
assert self.handle is not None assert self.handle is not None
combine_topk_weights = topk_weights combine_topk_weights = topk_weights
...@@ -172,12 +241,45 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -172,12 +241,45 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
combine_topk_weights = torch.ones_like(topk_weights) combine_topk_weights = torch.ones_like(topk_weights)
# TODO (varun) : Enable zero copy mode # TODO (varun) : Enable zero copy mode
_, event, hook = self.buffer.low_latency_combine( _, _, recv_hook = self.buffer.low_latency_combine(
fused_expert_output, fused_expert_output,
topk_ids, topk_ids,
combine_topk_weights, combine_topk_weights,
self.handle, self.handle,
async_finish=False, async_finish=False,
zero_copy=False, zero_copy=False,
return_recv_hook=False, return_recv_hook=do_recv_hook,
out=output) out=output,
)
return recv_hook
def finalize_async(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
apply_weights_and_reduce: bool = True) -> None:
return self._finalize(
output,
fused_expert_output,
topk_weights,
topk_ids,
apply_router_weight_on_input,
apply_weights_and_reduce,
do_async=True,
)
def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
apply_weights_and_reduce: bool = True) -> None:
self._finalize(
output,
fused_expert_output,
topk_weights,
topk_ids,
apply_router_weight_on_input,
apply_weights_and_reduce,
do_async=False,
)
...@@ -596,6 +596,7 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -596,6 +596,7 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
apply_weights_and_reduce: bool = True
) -> None: ) -> None:
num_tokens = topk_ids.size(0) num_tokens = topk_ids.size(0)
num_local_experts = fused_expert_output.size(0) num_local_experts = fused_expert_output.size(0)
......
...@@ -29,7 +29,8 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -29,7 +29,8 @@ from vllm.model_executor.layers.fused_moe.config import (
# yapf: enable # yapf: enable
from vllm.model_executor.layers.fused_moe.modular_kernel import ( from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEActivationFormat, FusedMoEModularKernel, FusedMoEActivationFormat, FusedMoEModularKernel,
FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize) DeepGemmDisabledFusedMoEModularKernel, FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize)
# from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
# is_rocm_aiter_moe_enabled) # is_rocm_aiter_moe_enabled)
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
...@@ -40,7 +41,7 @@ from vllm.model_executor.utils import set_weight_attrs ...@@ -40,7 +41,7 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum from vllm.platforms.interface import CpuArchEnum
from vllm.utils import direct_register_custom_op, has_deep_ep, has_pplx from vllm.utils import direct_register_custom_op, has_deep_ep, has_pplx, has_deep_gemm
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
...@@ -91,7 +92,7 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -91,7 +92,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
params_dtype: torch.dtype, **extra_weight_attrs): params_dtype: torch.dtype, **extra_weight_attrs):
raise NotImplementedError raise NotImplementedError
def init_prepare_finalize(self, moe: FusedMoEConfig, def init_prepare_finalize(self, layer, moe: FusedMoEConfig,
quant_config: Optional[QuantizationConfig]): quant_config: Optional[QuantizationConfig]):
all2all_manager = get_ep_group().device_communicator.all2all_manager all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None assert all2all_manager is not None
...@@ -170,6 +171,8 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -170,6 +171,8 @@ class FusedMoEMethodBase(QuantizeMethodBase):
and moe.quant_config.block_shape and moe.quant_config.block_shape
== DEEPEP_QUANT_BLOCK_SHAPE) == DEEPEP_QUANT_BLOCK_SHAPE)
use_int8_dispatch = moe.quant_config.quant_dtype == torch.int8
# Note (varun): Whether to use FP8 dispatch or not needs some # Note (varun): Whether to use FP8 dispatch or not needs some
# profiling. Turning it off for now. # profiling. Turning it off for now.
prepare_finalize = DeepEPLLPrepareAndFinalize( prepare_finalize = DeepEPLLPrepareAndFinalize(
...@@ -177,6 +180,7 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -177,6 +180,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
max_tokens_per_rank=moe.max_num_tokens, max_tokens_per_rank=moe.max_num_tokens,
num_dispatchers=all2all_manager.world_size, num_dispatchers=all2all_manager.world_size,
use_fp8_dispatch=use_fp8_dispatch, use_fp8_dispatch=use_fp8_dispatch,
use_int8_dispatch=use_int8_dispatch,
) )
self.topk_indices_dtype = None self.topk_indices_dtype = None
...@@ -184,10 +188,18 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -184,10 +188,18 @@ class FusedMoEMethodBase(QuantizeMethodBase):
logger.debug("%s", prepare_finalize.__class__.__name__) logger.debug("%s", prepare_finalize.__class__.__name__)
self.topk_indices_dtype = prepare_finalize.topk_indices_dtype() self.topk_indices_dtype = prepare_finalize.topk_indices_dtype()
experts = self.select_gemm_impl(prepare_finalize, moe) experts = self.select_gemm_impl(prepare_finalize, moe)
if has_deep_gemm():
self.fused_experts = FusedMoEModularKernel( self.fused_experts = FusedMoEModularKernel(
prepare_finalize, prepare_finalize,
experts, experts,
) )
else:
self.fused_experts = DeepGemmDisabledFusedMoEModularKernel(
prepare_finalize,
experts,
shared_experts=layer.shared_experts if hasattr(layer, "shared_experts") else None,
)
def select_gemm_impl( def select_gemm_impl(
self, self,
...@@ -898,6 +910,10 @@ class FusedMoE(torch.nn.Module): ...@@ -898,6 +910,10 @@ class FusedMoE(torch.nn.Module):
def use_deepep_ll_kernels(self): def use_deepep_ll_kernels(self):
return self.moe_parallel_config.use_deepep_ll_kernels return self.moe_parallel_config.use_deepep_ll_kernels
@property
def shared_experts(self) -> Optional[torch.nn.Module]:
return None
def _load_per_tensor_weight_scale(self, shard_id: str, def _load_per_tensor_weight_scale(self, shard_id: str,
param: torch.nn.Parameter, param: torch.nn.Parameter,
loaded_weight: torch.Tensor, loaded_weight: torch.Tensor,
...@@ -1445,9 +1461,13 @@ class FusedMoE(torch.nn.Module): ...@@ -1445,9 +1461,13 @@ class FusedMoE(torch.nn.Module):
assert i_q is None and i_s is None, "moe.quant fused not support TPU now" assert i_q is None and i_s is None, "moe.quant fused not support TPU now"
return self.forward_impl(hidden_states, router_logits) return self.forward_impl(hidden_states, router_logits)
else: else:
if self.shared_experts is None:
return torch.ops.vllm.moe_forward(hidden_states, router_logits, return torch.ops.vllm.moe_forward(hidden_states, router_logits,
self.layer_name, shared_output, self.layer_name, shared_output,
i_q, i_s) i_q, i_s)
else:
return torch.ops.vllm.moe_forward_shared(hidden_states, router_logits,
self.layer_name, shared_output)
def forward_impl_chunked(self, full_hidden_states: torch.Tensor, def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
full_router_logits: torch.Tensor): full_router_logits: torch.Tensor):
...@@ -1531,13 +1551,14 @@ class FusedMoE(torch.nn.Module): ...@@ -1531,13 +1551,14 @@ class FusedMoE(torch.nn.Module):
i_q: Optional[torch.Tensor] = None, i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None, **_): i_s: Optional[torch.Tensor] = None, **_):
assert self.quant_method is not None assert self.quant_method is not None
if (self.moe_parallel_config.use_pplx_kernels if (self.moe_parallel_config.use_pplx_kernels):
or self.moe_parallel_config.use_deepep_ll_kernels): #or self.moe_parallel_config.use_deepep_ll_kernels):
return self.forward_impl_chunked(hidden_states, router_logits) return self.forward_impl_chunked(hidden_states, router_logits)
do_naive_dispatch_combine: bool = ( do_naive_dispatch_combine: bool = (
self.dp_size > 1 self.dp_size > 1
and not self.moe_parallel_config.use_deepep_ht_kernels) and not self.moe_parallel_config.use_deepep_ht_kernels
and not self.moe_parallel_config.use_deepep_ll_kernels)
if do_naive_dispatch_combine: if do_naive_dispatch_combine:
hidden_states, router_logits = get_ep_group().dispatch( hidden_states, router_logits = get_ep_group().dispatch(
hidden_states, router_logits) hidden_states, router_logits)
...@@ -1694,3 +1715,34 @@ direct_register_custom_op( ...@@ -1694,3 +1715,34 @@ direct_register_custom_op(
dispatch_key=current_platform.dispatch_key, dispatch_key=current_platform.dispatch_key,
tags=(torch.Tag.needs_fixed_stride_order, ), tags=(torch.Tag.needs_fixed_stride_order, ),
) )
def moe_forward_shared(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
layer_name: str,
shared_output: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor]:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
assert self.shared_experts is not None
return self.forward_impl(hidden_states, router_logits, shared_output)
def moe_forward_shared_fake(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
layer_name: str,
shared_output: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor]:
shared_out = torch.empty_like(hidden_states)
fused_out = torch.empty_like(hidden_states)
return shared_out, fused_out
direct_register_custom_op(
op_name="moe_forward_shared",
op_func=moe_forward_shared,
mutates_args=["hidden_states"],
fake_impl=moe_forward_shared_fake,
tags=(torch.Tag.needs_fixed_stride_order,),
)
\ No newline at end of file
...@@ -4,6 +4,8 @@ from abc import ABC, abstractmethod ...@@ -4,6 +4,8 @@ from abc import ABC, abstractmethod
from enum import Enum from enum import Enum
from math import prod from math import prod
from typing import Optional, final from typing import Optional, final
from dataclasses import dataclass
from collections.abc import Callable
import torch import torch
...@@ -95,6 +97,54 @@ class FusedMoEActivationFormat(Enum): ...@@ -95,6 +97,54 @@ class FusedMoEActivationFormat(Enum):
BatchedExperts = "batched_experts", BatchedExperts = "batched_experts",
@dataclass
class ExpertTokensMetadata:
"""
Metadata regarding expert-token routing.
"""
expert_num_tokens: torch.Tensor
expert_num_tokens_cpu: torch.Tensor | None
@staticmethod
def make_from_list(
expert_num_tokens_list: list[int], device: str
) -> "ExpertTokensMetadata":
expert_num_tokens_cpu = torch.tensor(
expert_num_tokens_list, device="cpu", dtype=torch.int32, pin_memory=True
)
expert_num_tokens = expert_num_tokens_cpu.to(device=device, non_blocking=True)
return ExpertTokensMetadata(
expert_num_tokens=expert_num_tokens,
expert_num_tokens_cpu=expert_num_tokens_cpu,
)
#
# PrepareResultType is a tuple of:
# - quantized + dispatched a.
# - quantized + dispatched a1_scales.
# - Optional ExpertTokensMetadata containing gpu/cpu tensors
# as big as the number of local experts with the information about the
# number of tokens assigned to each local expert.
# - Optional dispatched expert topk IDs
# - Optional dispatched expert topk weight
#
# See `prepare` method below.
#
PrepareResultType = tuple[
torch.Tensor,
torch.Tensor | None,
ExpertTokensMetadata | None,
torch.Tensor | None,
torch.Tensor | None,
]
ReceiverType = Callable[[], PrepareResultType]
# TODO: pass FusedMoEParallelConfig in as ctor parameter? # TODO: pass FusedMoEParallelConfig in as ctor parameter?
class FusedMoEPrepareAndFinalize(ABC): class FusedMoEPrepareAndFinalize(ABC):
""" """
...@@ -149,6 +199,7 @@ class FusedMoEPrepareAndFinalize(ABC): ...@@ -149,6 +199,7 @@ class FusedMoEPrepareAndFinalize(ABC):
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
apply_weights_and_reduce: bool = True
) -> None: ) -> None:
""" """
Perform any combine plus apply weights and perform a reduction on the Perform any combine plus apply weights and perform a reduction on the
...@@ -357,6 +408,168 @@ class FusedMoEPermuteExpertsUnpermute(ABC): ...@@ -357,6 +408,168 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
raise NotImplementedError raise NotImplementedError
class CustomizedFusedMoEPermuteExpertsUnpermute(ABC):
"""
An abstract base class for the [Permute-Experts-Unpermute] step described
above.
"""
def __init__(
self,
quant_config: Optional[FusedMoEQuantConfig],
):
if quant_config is not None:
self.quant_config = quant_config
else:
self.quant_config = FusedMoEQuantConfig()
@property
@abstractmethod
def activation_formats(
self) -> tuple[FusedMoEActivationFormat, FusedMoEActivationFormat]:
"""
A property which is a tuple of the input and output activation formats
for the 'apply' method.
"""
raise NotImplementedError
@property
def quant_dtype(self) -> Optional[torch.dtype]:
return self.quant_config.quant_dtype
@property
def block_shape(self) -> Optional[list[int]]:
return self.quant_config.block_shape
@property
def per_act_token_quant(self) -> bool:
return self.quant_config.per_act_token_quant
@property
def per_out_ch_quant(self) -> bool:
return self.quant_config.per_out_ch_quant
# TODO (bnell): make this return a CHUNK_SIZE or None instead?
@abstractmethod
def supports_chunking(self) -> bool:
"""
A flag indicating whether or not this class supports activation
chunking.
"""
raise NotImplementedError
@abstractmethod
def supports_expert_map(self) -> bool:
"""
A flag indicating whether or not this class supports expert maps
"""
raise NotImplementedError
@abstractmethod
def workspace_shapes(
self,
a: torch.Tensor,
aq: torch.Tensor,
M: int,
N: int,
K: int,
topk: int,
global_num_experts: int,
local_num_experts: int,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
"""
Compute the shapes for the temporary and final outputs of the two gemms
and activation in the fused expert function. Since the gemms are
independent, the workspace for the first gemm can be shared with the
workspace for the last gemm.
Returns a tuple of:
- workspace13 shape tuple: must be large enough to hold the
result of either expert gemm.
- workspace2 shape tuple: must be large enough to hold the
result of the activation function.
- output shape tuple: must be exact size of the final gemm output.
- Workspace type: The dtype to use for the workspace tensors.
- Note: in order for activation chunking to work, the first dimension
of each tuple must be the number of tokens.
"""
raise NotImplementedError
def activation(self, activation: str, output: torch.Tensor,
input: torch.Tensor) -> None:
assert output.size(-1) * 2 == input.size(-1)
if activation == "silu":
torch.ops._C.silu_and_mul(output, input)
elif activation == "gelu":
torch.ops._C.gelu_and_mul(output, input)
else:
raise ValueError(f"Unsupported FusedMoe activation: {activation}")
def enable_chunking(self):
return envs.VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING and \
self.supports_chunking()
@abstractmethod
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None,
):
"""
This function computes the intermediate result of a Mixture of Experts
(MoE) layer using two sets of weights, w1 and w2.
Parameters:
- output: (torch.Tensor): The unweighted, unreduced output tensor.
- hidden_states: (torch.Tensor): The (quantized) input tensor to the MoE
layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- topk_ids (torch.Tensor): A map of row to expert id.
- activation (str): The activation function to apply after the first
MoE layer.
- global_num_experts (int): The total number of experts in the global
expert space.
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
from the global expert space to the local expert space of the expert
parallel shard.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2.
- w1_zp (Optional[torch.Tensor]): Optional zero points to be used for
w1.
- w2_zp (Optional[torch.Tensor]): Optional zero points to be used for
w2.
- a1q_scale (Optional[torch.Tensor]): Optional quantized scale to be
used for a1.
- a2_scale (Optional[torch.Tensor]): Optional scale to be used for a2.
- workspace13 (torch.Tensor): A scratch tensor used for gemm outputs
must be large enough to hold output of either MoE gemm.
- workspace2 (torch.Tensor): A scratch tensor used for the activation
function.
- expert_num_tokens: An optional tensor containing the number of tokens
assigned to each expert when using batched experts format input.
"""
raise NotImplementedError
def _chunk_scales(scales: Optional[torch.Tensor], start: int, def _chunk_scales(scales: Optional[torch.Tensor], start: int,
end: int) -> Optional[torch.Tensor]: end: int) -> Optional[torch.Tensor]:
if scales is not None: if scales is not None:
...@@ -596,3 +809,218 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -596,3 +809,218 @@ class FusedMoEModularKernel(torch.nn.Module):
topk_ids, apply_router_weight_on_input) topk_ids, apply_router_weight_on_input)
return output return output
_alt_stream: torch.cuda.Stream | None = None
def alt_stream() -> torch.cuda.Stream | None:
"""
Ensures aux_stream is initialized only once
"""
global _alt_stream
# TODO: validate this works properly on ROCm platform.
if _alt_stream is None:
_alt_stream = torch.cuda.Stream()
return _alt_stream
@final
class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
"""
This class combines a FusedMoEPrepareAndFinalize instance and
a FusedMoEPermuteExpertsUnpermute to provide an interface that
is compatible with the `fused_experts` function in fused_moe.py.
It takes care of managing any required scratch space.
Note: Instances of this class should only be used for a single model
layer due to any layer specific state that may be used by the component
objects.
"""
def __init__(
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
fused_experts: CustomizedFusedMoEPermuteExpertsUnpermute,
shared_experts: Optional[torch.nn.Module] = None,
):
super().__init__()
self.prepare_finalize = prepare_finalize
self.fused_experts = fused_experts
self.shared_experts = shared_experts
if self.shared_experts is not None:
self.alt_stream = alt_stream()
self.alt_event = torch.cuda.Event()
# assert prepare_finalize.activation_format == \
# fused_experts.activation_formats[0], (
# f"{prepare_finalize.__class__.__name__}."
# f"{prepare_finalize.activation_format} == "
# f"{fused_experts.__class__.__name__}."
# f"{fused_experts.activation_formats[0]}")
def forward(
self,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False,
apply_router_weight_on_input: bool = False,
shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None,
**_
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets
of weights, w1 and w2, and top-k gating mechanism.
Parameters:
- hidden_states: (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- topk_weights (torch.Tensor): The topk weights applied at the end of
the layer.
- topk_ids (torch.Tensor): A map of row to expert id.
- inplace (bool): If True, perform the operation in-place.
Defaults to False.
- activation (str): The activation function to apply after the first
MoE layer.
- global_num_experts (int): The total number of experts in the global
expert space.
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
from the global expert space to the local expert space of the expert
parallel shard.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2.
- w1_zp (Optional[torch.Tensor]): Optional zero points to be used for
w1.
- w2_zp (Optional[torch.Tensor]): Optional zero points to be used for
w2.
- a1_scale (Optional[torch.Tensor]): Optional scale to be used for a1.
- a2_scale (Optional[torch.Tensor]): Optional scale to be used for a2.
- apply_router_weight_on_input (bool): When true, the topk weights are
applied directly on the inputs. This is only applicable when topk is
1.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
a1 = hidden_states
if inplace and self.shared_experts is None:
output = hidden_states
else:
output = torch.zeros_like(hidden_states)
local_num_experts = w1.size(0)
if global_num_experts == -1:
global_num_experts = local_num_experts
prepare_ret = self.prepare_finalize.prepare_async(
a1,
a1_scale,
a2_scale,
topk_weights,
topk_ids,
global_num_experts,
expert_map,
apply_router_weight_on_input,
self.fused_experts.quant_config,
)
hook, receiver = (
prepare_ret if isinstance(prepare_ret, tuple) else (None, prepare_ret)
)
if hook is not None:
hook()
(
a1q,
a1q_scale,
expert_tokens_meta,
_expert_topk_ids,
_expert_topk_weights,
) = receiver()
# Maybe prepare gathered topk_ids and topk_weights from other EP ranks.
topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids
topk_weights = (topk_weights if _expert_topk_weights is None else
_expert_topk_weights)
if a1q.numel() == 0:
# This happens when none of the tokens from the all2all reach this
# EP rank. Also, note that this is only relevant for CUDAGraph
# incompatible all2all kernels like the DeepEP high-throughput
# kernels. CUDAGraph compatible all2all kernels like the pplx
# kernels and the DeepEP low-latency kernels are always batched
# and can never run into the tensor.numel() == 0 case.
fused_out = torch.empty_like(a1q).to(dtype=a1.dtype)
else:
fused_out = self.fused_experts.apply(
None,
a1,
a1q,
w1,
w2,
topk_ids,
topk_weights=topk_weights,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1q_scale=a1q_scale,
a2_scale=a2_scale,
workspace13=None,
workspace2=None,
use_nn_moe=use_nn_moe,
expert_num_tokens=expert_tokens_meta.expert_num_tokens,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
expert_num_tokens_cpu=expert_tokens_meta.expert_num_tokens_cpu
)
shared_output = None
self.alt_event.record()
current_stream = torch.cuda.current_stream()
if self.shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
with torch.cuda.stream(self.alt_stream):
self.alt_stream.wait_event(self.alt_event)
hook = None
if self.prepare_finalize.activation_format == \
FusedMoEActivationFormat.BatchedExperts:
self.prepare_finalize.finalize(output, fused_out, topk_weights,
topk_ids, apply_router_weight_on_input, apply_weights_and_reduce=True)
else:
hook = self.prepare_finalize.finalize_async(output, fused_out, topk_weights,
topk_ids, apply_router_weight_on_input, apply_weights_and_reduce=True)
if hook is not None:
hook()
self.alt_event.record()
current_stream.wait_event(self.alt_event)
if self.shared_experts is not None:
return (shared_output, output)
return output
...@@ -207,6 +207,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -207,6 +207,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
apply_weights_and_reduce: bool = True
) -> None: ) -> None:
# This argument is optional # This argument is optional
# There's not much point setting this unless it is != topk_ids.size(0) # There's not much point setting this unless it is != topk_ids.size(0)
......
...@@ -61,6 +61,7 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): ...@@ -61,6 +61,7 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
apply_weights_and_reduce: bool = True
) -> None: ) -> None:
_moe_unpermute_and_reduce(output, fused_expert_output, None, _moe_unpermute_and_reduce(output, fused_expert_output, None,
topk_weights, apply_router_weight_on_input) topk_weights, apply_router_weight_on_input)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
# TODO(bnell): Add shared + fused combo function? e.g. +
class SharedFusedMoE(FusedMoE):
"""
A FusedMoE operation that also computes the results of shared experts.
If an all2all communicator is being used the shared expert computation
can be interleaved with the fused all2all dispatch communication step.
"""
def __init__(
self,
shared_experts: torch.nn.Module,
use_overlapped: bool = True,
**kwargs,
):
super().__init__(**kwargs)
self._shared_experts = shared_experts
self.use_overlapped = use_overlapped
@property
def shared_experts(self) -> Optional[torch.nn.Module]:
return self._shared_experts if self.use_overlapped else None
def forward(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor:
if not self.use_overlapped:
shared_out = self._shared_experts(hidden_states)
# Reduce outputs if necessary, since the MLP should
# have been created with reduce_results=False.
if (self.reduce_results and self.tp_size > 1
and self.must_reduce_shared_expert_outputs()):
shared_out = tensor_model_parallel_all_reduce(shared_out)
fused_out = super().forward(
hidden_states=hidden_states,
router_logits=router_logits,
)
else:
# Matrix multiply.
fused_out = super().forward(
hidden_states=hidden_states,
router_logits=router_logits,
)
return fused_out
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape)
class TritonOrGroupGemmExperts(mk.CustomizedFusedMoEPermuteExpertsUnpermute):
def __init__(
self,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_int4_w4a8: bool = False,
per_act_token_quant: bool = False,
block_shape: Optional[list[int]] = None,
fused_experts = None
):
super().__init__(
FusedMoEQuantConfig.make(
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int4_w4a8=use_int4_w4a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
))
self.fused_experts = fused_experts
@property
def activation_formats(
self
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
return (mk.FusedMoEActivationFormat.Standard,
mk.FusedMoEActivationFormat.Standard)
def supports_chunking(self) -> bool:
return True
def supports_expert_map(self) -> bool:
return True
def workspace_shapes(
self,
a: torch.Tensor,
aq: torch.Tensor,
M: int,
N: int,
K: int,
topk: int,
global_num_experts: int,
local_num_experts: int,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
raise NotImplementedError
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
q_hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
topk_weights: Optional[torch.Tensor] = None,
expert_num_tokens: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None,
expert_num_tokens_cpu: torch.Tensor = None,
):
assert self.fused_experts is not None
return self.fused_experts(
x=hidden_states,
w1=w1,
w2=w2,
topk_ids=topk_ids,
topk_weights=topk_weights,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=False,
activation=activation,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1q_scale,
a2_scale=a2_scale,
expert_num_tokens=expert_num_tokens,
use_nn_moe=use_nn_moe,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
q_x=q_hidden_states,
expert_num_tokens_cpu=expert_num_tokens_cpu
)
...@@ -4,10 +4,21 @@ from math import prod ...@@ -4,10 +4,21 @@ from math import prod
from typing import Optional from typing import Optional
import torch import torch
from torch import nn
import triton
import triton.language as tl
from triton.language.extra import libdevice
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8) per_token_group_quant_fp8)
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.utils import round_up
try: try:
from lmslim.layers.gemm.int8_utils import ( from lmslim.layers.gemm.int8_utils import (
per_token_group_quant_int8, per_token_quant_int8) per_token_group_quant_int8, per_token_quant_int8)
...@@ -48,12 +59,221 @@ def _fp8_quantize( ...@@ -48,12 +59,221 @@ def _fp8_quantize(
return A, A_scale return A, A_scale
# @triton.jit
# def _per_token_quant_int8_kernel_opt(
# x_ptr,
# xq_ptr,
# scale_ptr,
# stride_x,
# stride_xq,
# N,
# T_dim,
# has_tokens_per_expert: tl.constexpr,
# tokens_per_expert_ptr,
# BLOCK: tl.constexpr
# ):
# row_id = tl.program_id(0)
# if has_tokens_per_expert:
# e = row_id // T_dim
# t = row_id % T_dim
# num_valid_tokens_for_e = tl.load(tokens_per_expert_ptr + e)
# if t >= num_valid_tokens_for_e:
# return
# cols = tl.arange(0, BLOCK)
# mask = cols < N
# x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask,
# other=0.0).to(tl.float32)
# absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
# scale_x = absmax / 127
# x_q = x * (127 / absmax)
# x_q = libdevice.nearbyint(x_q).to(tl.int8)
# tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask)
# tl.store(scale_ptr + row_id, scale_x)
# def per_token_quant_int8_triton_opt(x: torch.Tensor,
# tokens_per_expert: Optional[torch.Tensor] = None):
# """
# Python wrapper for the Triton kernel.
# """
# if x.dim() != 3:
# raise ValueError(f"Input must be 3D [E, T, H], but got {x.shape}")
# E, T, H = x.shape
# M = E * T
# N = H
# x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
# scales = torch.empty(x.shape[:-1] + (1, ),
# device=x.device,
# dtype=torch.float32)
# BLOCK = triton.next_power_of_2(N)
# num_warps = min(max(BLOCK // 256, 1), 8)
# grid_opt = M
# _per_token_quant_int8_kernel_opt[(grid_opt, )](
# x,
# x_q,
# scales,
# stride_x=x.stride(-2),
# stride_xq=x_q.stride(-2),
# N=N,
# T_dim=T,
# has_tokens_per_expert= tokens_per_expert is not None,
# tokens_per_expert_ptr=tokens_per_expert,
# BLOCK=BLOCK,
# num_warps=num_warps,
# num_stages=1,
# )
# return x_q, scales
@triton.jit
def _per_token_quant_int8_one_kernel_opt(
x_ptr,
xq_ptr,
scale_ptr,
stride_x,
stride_xq,
N,
T_dim,
has_tokens_per_expert: tl.constexpr,
tokens_per_expert_ptr,
BLOCK: tl.constexpr
):
row_id = tl.program_id(0)
if has_tokens_per_expert:
e = row_id // T_dim
t = row_id % T_dim
num_valid_tokens_for_e = tl.load(tokens_per_expert_ptr + e)
if t >= num_valid_tokens_for_e:
return
cols = tl.arange(0, BLOCK)
mask = cols < N
x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask,
other=0.0).to(tl.float32)
absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
scale_x = absmax / 127
x_q = x * (127 / absmax)
x_q = libdevice.nearbyint(x_q).to(tl.int8)
tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask)
tl.store(scale_ptr + row_id, scale_x)
@triton.jit
def _per_token_quant_int8_kernel_opt(
x_ptr,
xq_ptr,
scale_ptr,
stride_x,
stride_xq,
N,
E_dim,
T_dim,
has_tokens_per_expert: tl.constexpr,
tokens_per_expert_ptr,
BLOCK: tl.constexpr
):
token_idx_start = tl.program_id(0)
grid_size = tl.num_programs(0)
num_total_tokens = E_dim * T_dim
for token_idx in range(token_idx_start, num_total_tokens, grid_size):
is_valid_token = True
if has_tokens_per_expert:
e = token_idx // T_dim
t = token_idx % T_dim
num_valid_tokens_for_e = tl.load(tokens_per_expert_ptr + e)
if t >= num_valid_tokens_for_e:
is_valid_token = False
if is_valid_token:
cols = tl.arange(0, BLOCK)
mask = cols < N
x = tl.load(x_ptr + token_idx * stride_x + cols, mask=mask,
other=0.0).to(tl.float32)
absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
scale_x = absmax / 127
x_q = x * (127 / absmax)
x_q = libdevice.nearbyint(x_q).to(tl.int8)
tl.store(xq_ptr + token_idx * stride_xq + cols, x_q, mask=mask)
tl.store(scale_ptr + token_idx, scale_x)
def per_token_quant_int8_triton_opt(x: torch.Tensor,
tokens_per_expert: Optional[torch.Tensor] = None):
if x.dim() != 3:
raise ValueError(f"Input must be 3D [E, T, H], but got {x.shape}")
E, T, H = x.shape
N = H
x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
scales = torch.empty(x.shape[:-1] + (1, ),
device=x.device,
dtype=torch.float32)
BLOCK = triton.next_power_of_2(N)
num_warps = min(max(BLOCK // 256, 1), 8)
if T >= 4096:
num_warps = 1
num_tokens = E * T
grid_opt = num_tokens
if E == 16 and T >= 1024 :
grid_opt = max(1, num_tokens // (T // 256))
_per_token_quant_int8_kernel_opt[(grid_opt, )](
x,
x_q,
scales,
stride_x=x.stride(-2),
stride_xq=x_q.stride(-2),
N=N,
E_dim=E,
T_dim=T,
has_tokens_per_expert=tokens_per_expert is not None,
tokens_per_expert_ptr=tokens_per_expert,
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=1,
)
else:
_per_token_quant_int8_one_kernel_opt[(grid_opt, )](
x,
x_q,
scales,
stride_x=x.stride(-2),
stride_xq=x_q.stride(-2),
N=N,
T_dim=T,
has_tokens_per_expert=tokens_per_expert is not None,
tokens_per_expert_ptr=tokens_per_expert,
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=1,
)
return x_q, scales
def _int8_quantize( def _int8_quantize(
A: torch.Tensor, A: torch.Tensor,
A_scale: Optional[torch.Tensor], A_scale: Optional[torch.Tensor],
per_act_token: bool, per_act_token: bool,
block_shape: Optional[list[int]] = None, block_shape: Optional[list[int]] = None,
expert_num_tokens: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
""" """
Perform int8 quantization on the inputs. If a block_shape Perform int8 quantization on the inputs. If a block_shape
...@@ -64,9 +284,12 @@ def _int8_quantize( ...@@ -64,9 +284,12 @@ def _int8_quantize(
# activations apply per-token quantization. Otherwise, assume # activations apply per-token quantization. Otherwise, assume
# activation tensor-wise fp8/int8 quantization, dynamic or static # activation tensor-wise fp8/int8 quantization, dynamic or static
if block_shape is None: if block_shape is None:
assert per_act_token, \ # assert per_act_token, \
"int8 quantization only supports block or channel-wise" # "int8 quantization only supports block or channel-wise"
if expert_num_tokens is None:
A, A_scale = per_token_quant_int8(A) A, A_scale = per_token_quant_int8(A)
else:
A, A_scale = per_token_quant_int8_triton_opt(A, expert_num_tokens)
else: else:
assert not per_act_token assert not per_act_token
assert len(block_shape) == 2 assert len(block_shape) == 2
...@@ -83,11 +306,12 @@ def moe_kernel_quantize_input( ...@@ -83,11 +306,12 @@ def moe_kernel_quantize_input(
quant_dtype: Optional[torch.dtype], quant_dtype: Optional[torch.dtype],
per_act_token_quant: bool, per_act_token_quant: bool,
block_shape: Optional[list[int]] = None, block_shape: Optional[list[int]] = None,
expert_num_tokens: Optional[torch.Tensor]= None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
if quant_dtype == torch.float8_e4m3fn: if quant_dtype == torch.float8_e4m3fn:
return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape) return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape)
elif quant_dtype == torch.int8: elif quant_dtype == torch.int8:
return _int8_quantize(A, A_scale, per_act_token_quant, block_shape) return _int8_quantize(A, A_scale, per_act_token_quant, block_shape, expert_num_tokens)
else: else:
return A, A_scale return A, A_scale
...@@ -145,3 +369,536 @@ def _validate_scale_shape( ...@@ -145,3 +369,536 @@ def _validate_scale_shape(
assert block_shape is not None assert block_shape is not None
expected = (a.shape[0], cdiv(a.shape[1], block_shape[1])) expected = (a.shape[0], cdiv(a.shape[1], block_shape[1]))
assert a_scale.shape == expected, f"{a_scale.shape} == {expected}" assert a_scale.shape == expected, f"{a_scale.shape} == {expected}"
@triton.jit
def _count_expert_num_tokens(
topk_ids_ptr,
expert_num_tokens_ptr,
num_experts,
topk_numel,
expert_map,
HAS_EXPERT_MAP: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
curr_expert = tl.program_id(0)
offsets = tl.arange(0, BLOCK_SIZE)
topk_ids_ptrs = topk_ids_ptr + offsets
acc = tl.zeros((BLOCK_SIZE,), dtype=tl.int32)
for x in range(tl.cdiv(topk_numel, BLOCK_SIZE)):
mask = offsets < (topk_numel - x * BLOCK_SIZE)
expert_ids = tl.load(topk_ids_ptrs, mask=mask, other=-1)
if HAS_EXPERT_MAP:
expert_map_ptrs = expert_map + expert_ids
expert_map_mask = expert_ids >= 0
expert_ids = tl.load(expert_map_ptrs, mask=expert_map_mask, other=-1)
has_curr_expert = tl.where(expert_ids == curr_expert, 1, 0)
acc = acc + has_curr_expert
topk_ids_ptrs += BLOCK_SIZE
if curr_expert < num_experts:
tl.store(expert_num_tokens_ptr + curr_expert, tl.sum(acc))
def count_expert_num_tokens(
topk_ids: torch.Tensor, num_local_experts: int, expert_map: torch.Tensor | None
) -> torch.Tensor:
"""
Count the number to tokens assigned to each expert.
Parameters:
- topk_ids (torch.Tensor): Tensor mapping each token to its
list of experts.
- num_local_experts (int): Number of experts in this rank.
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
from the global expert space to the local expert space of the expert
parallel shard.
Returns:
A tensor of size num_local_experts, where tensor[i] holds the number
of tokens assigned to the ith expert.
"""
assert topk_ids.dtype.is_signed, "The kernel uses -1 to represent invalid topk_ids"
expert_num_tokens = torch.empty(
(num_local_experts), device=topk_ids.device, dtype=torch.int32
)
grid = num_local_experts
BLOCK_SIZE = min(topk_ids.numel(), 1024)
BLOCK_SIZE = triton.next_power_of_2(BLOCK_SIZE)
_count_expert_num_tokens[(grid,)](
topk_ids,
expert_num_tokens,
num_local_experts,
topk_ids.numel(),
expert_map,
HAS_EXPERT_MAP=expert_map is not None,
BLOCK_SIZE=BLOCK_SIZE,
)
return expert_num_tokens
def expert_num_tokens_round_up_and_sum(
expert_num_tokens: torch.Tensor, alignment: int
) -> int:
# Round up each element in expert_num_tokens to the nearest multiple of
# alignment.
ent = (expert_num_tokens.to(torch.int64) + (alignment - 1)) // alignment * alignment
return torch.sum(ent).item()
def compute_aligned_M(
M: int,
num_topk: int,
local_num_experts: int,
alignment: int,
expert_num_tokens_cpu: Optional[torch.Tensor] = None,
):
if expert_num_tokens_cpu is not None:
return expert_num_tokens_round_up_and_sum(
expert_num_tokens_cpu, alignment=alignment
)
# expert_num_tokens information is not available on the cpu.
# compute the max required size.
M_sum = (M * num_topk) + local_num_experts * (alignment - 1)
M_sum = round_up(M_sum, alignment)
return M_sum
@triton.jit
def apply_expert_map(expert_id, expert_map):
if expert_id != -1:
expert_id = tl.load(expert_map + expert_id).to(expert_id.dtype)
return expert_id
@triton.jit
def round_up_256(x: int) -> int:
y = 256
return ((x + y - 1) // y) * y
@triton.jit
def round_up_128(x: int) -> int:
y = 128
return ((x + y - 1) // y) * y
@triton.jit
def _fwd_kernel_ep_scatter_1(
num_recv_tokens_per_expert,
expert_start_loc,
m_indices,
num_experts: tl.constexpr,
BLOCK_E: tl.constexpr,
BLOCK_EXPERT_NUM: tl.constexpr,
):
cur_expert = tl.program_id(0)
offset_cumsum = tl.arange(0, BLOCK_EXPERT_NUM)
tokens_per_expert = tl.load(
num_recv_tokens_per_expert + offset_cumsum,
mask=offset_cumsum < num_experts,
other=0,
)
#tokens_per_expert = round_up_128(tokens_per_expert)
tokens_per_expert = round_up_256(tokens_per_expert)
cumsum = tl.cumsum(tokens_per_expert) - tokens_per_expert
#if cur_expert == 0:
tl.store(expert_start_loc + offset_cumsum, cumsum, mask=offset_cumsum < num_experts)
tl.debug_barrier()
#cur_expert_start = cumsum[cur_expert]
cur_expert_start = tl.load(expert_start_loc + cur_expert)
cur_expert_token_num = tl.load(num_recv_tokens_per_expert + cur_expert)
m_indices_start_ptr = m_indices + cur_expert_start
off_expert = tl.arange(0, BLOCK_E)
for start_m in tl.range(0, cur_expert_token_num, BLOCK_E, num_stages=4):
tl.store(
m_indices_start_ptr + start_m + off_expert,
cur_expert,
mask=start_m + off_expert < cur_expert_token_num
)
@triton.jit
def _fwd_kernel_ep_scatter_2(
total_token_num,
expert_start_loc,
recv_x,
recv_x_stride0,
recv_x_stride1,
recv_x_scale,
recv_x_scale_stride0,
recv_x_scale_stride1,
recv_topk,
recv_topk_stride0,
recv_topk_stride1,
output_tensor,
output_tensor_stride0,
output_tensor_stride1,
output_tensor_scale,
output_tensor_scale_stride0,
output_tensor_scale_stride1,
output_index,
output_index_stride0,
output_index_stride1,
topk_num: tl.constexpr,
expert_map,
HAS_EXPERT_MAP: tl.constexpr,
HIDDEN_SIZE: tl.constexpr,
HIDDEN_SIZE_PAD: tl.constexpr,
SCALE_HIDDEN_SIZE: tl.constexpr,
SCALE_HIDDEN_SIZE_PAD: tl.constexpr,
):
start_token_id = tl.program_id(0)
grid_num = tl.num_programs(0)
offset_in = tl.arange(0, HIDDEN_SIZE_PAD)
mask = offset_in < HIDDEN_SIZE
index_in_s = tl.arange(0, SCALE_HIDDEN_SIZE_PAD)
mask_s = index_in_s < SCALE_HIDDEN_SIZE
for token_id_int32 in range(start_token_id, total_token_num, grid_num):
token_id = token_id_int32.to(tl.int64)
to_copy = tl.load(recv_x + token_id * recv_x_stride0 + offset_in, mask=mask)
to_copy_s = tl.load(
recv_x_scale
+ token_id * recv_x_scale_stride0
+ index_in_s * recv_x_scale_stride1,
mask=mask_s,
)
for topk_idx_int32 in tl.range(0, topk_num, 1, num_stages=4):
topk_index = topk_idx_int32.to(tl.int64)
expert_id = tl.load(recv_topk + token_id * recv_topk_stride0 + topk_index)
if HAS_EXPERT_MAP:
expert_id = apply_expert_map(expert_id, expert_map)
if expert_id >= 0:
dest_token_index_int32 = tl.atomic_add(expert_start_loc + expert_id, 1)
dest_token_index = dest_token_index_int32.to(tl.int64)
tl.store(
output_index + token_id * output_index_stride0 + topk_index,
dest_token_index_int32,
)
output_tensor_ptr = (
output_tensor + dest_token_index * output_tensor_stride0
)
output_tensor_scale_ptr = (
output_tensor_scale + dest_token_index * output_tensor_scale_stride0
)
tl.store(output_tensor_ptr + offset_in, to_copy, mask=mask)
tl.store(
output_tensor_scale_ptr + index_in_s * output_tensor_scale_stride1,
to_copy_s,
mask=mask_s,
)
@torch.no_grad()
def ep_scatter(
recv_x: torch.Tensor,
recv_x_scale: torch.Tensor,
recv_topk: torch.Tensor,
num_recv_tokens_per_expert: torch.Tensor,
expert_map: torch.Tensor | None,
expert_start_loc: torch.Tensor,
output_tensor: torch.Tensor,
output_tensor_scale: torch.Tensor,
m_indices: torch.Tensor,
output_index: torch.Tensor,
):
#BLOCK_E = 128 # token num of per expert is aligned to 128
#BLOCK_D = 128 # block size of quantization
BLOCK_E = 256 # token num of per expert is aligned to 256
num_warps = 8
num_experts = num_recv_tokens_per_expert.shape[0]
hidden_size = recv_x.shape[1]
scale_hidden_size = recv_x_scale.shape[-1]
# grid = (triton.cdiv(hidden_size, BLOCK_D), num_experts)
grid = num_experts
assert m_indices.shape[0] % BLOCK_E == 0
_fwd_kernel_ep_scatter_1[(grid,)](
num_recv_tokens_per_expert,
expert_start_loc,
m_indices,
num_experts=num_experts,
num_warps=num_warps,
BLOCK_E=BLOCK_E,
BLOCK_EXPERT_NUM=triton.next_power_of_2(num_experts),
)
grid = min(recv_topk.shape[0], 1024 * 8)
_fwd_kernel_ep_scatter_2[(grid,)](
recv_topk.shape[0],
expert_start_loc,
recv_x,
recv_x.stride(0),
recv_x.stride(1),
recv_x_scale,
recv_x_scale.stride(0),
recv_x_scale.stride(1),
recv_topk,
recv_topk.stride(0),
recv_topk.stride(1),
output_tensor,
output_tensor.stride(0),
output_tensor.stride(1),
output_tensor_scale,
output_tensor_scale.stride(0),
output_tensor_scale.stride(1),
output_index,
output_index.stride(0),
output_index.stride(1),
topk_num=recv_topk.shape[1],
expert_map=expert_map,
HAS_EXPERT_MAP=expert_map is not None,
num_warps=num_warps,
HIDDEN_SIZE=hidden_size,
HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size),
SCALE_HIDDEN_SIZE=scale_hidden_size,#hidden_size // BLOCK_D,
SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(scale_hidden_size)#triton.next_power_of_2(hidden_size // BLOCK_D),
)
return
@triton.jit
def _fwd_kernel_ep_gather(
total_token_num,
input_tensor,
input_tensor_stride0,
input_tensor_stride1,
recv_topk_ids,
recv_topk_ids_stride0,
recv_topk_ids_stride1,
recv_topk_weight,
recv_topk_weight_stride0,
recv_topk_weight_stride1,
input_index,
input_index_stride0,
input_index_stride1,
output_tensor,
output_tensor_stride0,
output_tensor_stride1,
topk_num: tl.constexpr,
expert_map,
HAS_EXPERT_MAP: tl.constexpr,
BLOCK_D: tl.constexpr,
):
cur_block_int32 = tl.program_id(0)
cur_block = cur_block_int32.to(tl.int64)
start_cur_token_int32 = tl.program_id(1)
grid_num = tl.num_programs(1)
for cur_token_int32 in range(start_cur_token_int32, total_token_num, grid_num):
cur_token = cur_token_int32.to(tl.int64)
off_d = tl.arange(0, BLOCK_D)
accumulator = tl.zeros([BLOCK_D], dtype=tl.float32)
for topk_index_int32 in range(0, topk_num):
topk_index = topk_index_int32.to(tl.int64)
expert_id = tl.load(
recv_topk_ids + cur_token * recv_topk_ids_stride0 + topk_index
)
if HAS_EXPERT_MAP:
expert_id = apply_expert_map(expert_id, expert_map)
if expert_id >= 0:
source_token_index_int32 = tl.load(
input_index + cur_token * input_index_stride0 + topk_index
)
source_token_index = source_token_index_int32.to(tl.int64)
acc_weight = tl.load(
recv_topk_weight + cur_token * recv_topk_weight_stride0 + topk_index
)
tmp = tl.load(
input_tensor
+ source_token_index * input_tensor_stride0
+ cur_block * BLOCK_D
+ off_d
)
accumulator += tmp.to(tl.float32) * acc_weight
tl.store(
output_tensor
+ cur_token * output_tensor_stride0
+ cur_block * BLOCK_D
+ off_d,
accumulator.to(output_tensor.dtype.element_ty),
)
@torch.no_grad()
def ep_gather(
input_tensor: torch.Tensor,
recv_topk_ids: torch.Tensor,
recv_topk_weight: torch.Tensor,
input_index: torch.Tensor,
expert_map: torch.Tensor | None,
output_tensor: torch.Tensor,
):
num_warps = 2
num_tokens = output_tensor.shape[0]
hidden_size = input_tensor.shape[1]
BLOCK_D = min(hidden_size, 1024)
assert hidden_size % BLOCK_D == 0
grid = (triton.cdiv(hidden_size, BLOCK_D), min(num_tokens, 1024))
_fwd_kernel_ep_gather[grid](
num_tokens,
input_tensor,
input_tensor.stride(0),
input_tensor.stride(1),
recv_topk_ids,
recv_topk_ids.stride(0),
recv_topk_ids.stride(1),
recv_topk_weight,
recv_topk_weight.stride(0),
recv_topk_weight.stride(1),
input_index,
input_index.stride(0),
input_index.stride(1),
output_tensor,
output_tensor.stride(0),
output_tensor.stride(1),
topk_num=recv_topk_ids.shape[1],
expert_map=expert_map,
HAS_EXPERT_MAP=expert_map is not None,
num_warps=num_warps,
BLOCK_D=BLOCK_D,
)
return
def deepgemm_moe_permute(
aq: torch.Tensor,
aq_scale: torch.Tensor,
topk_ids: torch.Tensor,
local_num_experts: int,
expert_map: torch.Tensor | None,
block_shape: list[int],
expert_num_tokens: Optional[torch.Tensor] = None,
expert_num_tokens_cpu: Optional[torch.Tensor] = None,
aq_out: torch.Tensor | None = None,
M_sum: int | None = None,
):
assert aq.ndim == 2
assert topk_ids.dtype.is_signed, "The kernel uses -1 to represent invalid topk_ids"
H = aq.size(1)
device = aq.device
block_m = block_shape[0]
if M_sum is None:
M_sum = compute_aligned_M(
M=topk_ids.size(0),
num_topk=topk_ids.size(1),
local_num_experts=local_num_experts,
alignment=block_m,
expert_num_tokens_cpu=expert_num_tokens_cpu,
)
expert_start_loc = torch.empty(
(local_num_experts), device=device, dtype=torch.int32
)
assert aq_out is None or aq_out.shape == (M_sum, H)
if aq_out is None:
aq_out = torch.empty((M_sum, H), device=device, dtype=aq.dtype)
aq_scale_out = torch.empty(
(M_sum, aq_scale.shape[-1]), device=device, dtype=torch.float32
)
expert_ids = torch.full(
(M_sum,), -1, dtype=torch.int32, device=device
)
inv_perm = torch.empty(topk_ids.shape, device=device, dtype=torch.int32)
if expert_num_tokens is None:
expert_num_tokens = count_expert_num_tokens(
topk_ids, local_num_experts, expert_map
)
ep_scatter(
recv_x=aq,
recv_x_scale=aq_scale,
recv_topk=topk_ids,
num_recv_tokens_per_expert=expert_num_tokens,
expert_start_loc=expert_start_loc,
expert_map=expert_map,
output_tensor=aq_out,
output_tensor_scale=aq_scale_out,
m_indices=expert_ids,
output_index=inv_perm,
)
return aq_out, aq_scale_out, expert_ids, inv_perm
def deepgemm_unpermute_and_reduce(
a: torch.Tensor, # Grouped gemm output
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
inv_perm: torch.Tensor,
expert_map: torch.Tensor | None,
output: torch.Tensor,
):
return ep_gather(
input_tensor=a,
recv_topk_ids=topk_ids,
recv_topk_weight=topk_weights,
input_index=inv_perm,
expert_map=expert_map,
output_tensor=output,
)
class EPSharedExperts(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
reduce_results: bool = True,
prefix: str = "",
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
expect_tp_size=1)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
reduce_results=reduce_results,
prefix=f"{prefix}.down_proj",
expect_tp_size=1)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
...@@ -485,9 +485,13 @@ class ColumnParallelLinear(LinearBase): ...@@ -485,9 +485,13 @@ class ColumnParallelLinear(LinearBase):
prefix: str = "", prefix: str = "",
*, *,
return_bias: bool = True, return_bias: bool = True,
expect_tp_size: Optional[int] = None,
): ):
# Divide the weight matrix along the last dimension. # Divide the weight matrix along the last dimension.
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
if expect_tp_size is not None:
self.expect_tp_size = expect_tp_size
self.tp_size = self.expect_tp_size
self.input_size_per_partition = input_size self.input_size_per_partition = input_size
self.output_size_per_partition = divide(output_size, self.tp_size) self.output_size_per_partition = divide(output_size, self.tp_size)
self.output_partition_sizes = [self.output_size_per_partition] self.output_partition_sizes = [self.output_size_per_partition]
...@@ -749,10 +753,18 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -749,10 +753,18 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
prefix: str = "", prefix: str = "",
*, *,
return_bias: bool = True, return_bias: bool = True,
expect_tp_size: Optional[int] = None,
): ):
self.eps = eps self.eps = eps
self.output_sizes = output_sizes self.output_sizes = output_sizes
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
if expect_tp_size is not None:
tp_size = expect_tp_size
self.expect_tp_size = expect_tp_size
self.expect_tp_size = expect_tp_size
assert all(output_size % tp_size == 0 for output_size in output_sizes) assert all(output_size % tp_size == 0 for output_size in output_sizes)
super().__init__(input_size=input_size, super().__init__(input_size=input_size,
output_size=sum(output_sizes), output_size=sum(output_sizes),
...@@ -762,7 +774,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -762,7 +774,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
params_dtype=params_dtype, params_dtype=params_dtype,
quant_config=quant_config, quant_config=quant_config,
prefix=prefix, prefix=prefix,
return_bias=return_bias) return_bias=return_bias,
expect_tp_size=expect_tp_size)
def weight_loader(self, def weight_loader(self,
param: Parameter, param: Parameter,
...@@ -859,6 +872,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -859,6 +872,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
assert loaded_shard_id < len(self.output_sizes) assert loaded_shard_id < len(self.output_sizes)
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
if self.expect_tp_size is not None and self.expect_tp_size == 1:
tp_rank = 0
tp_size = 1
if output_dim is not None: if output_dim is not None:
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
shard_size = self.output_sizes[loaded_shard_id] // tp_size shard_size = self.output_sizes[loaded_shard_id] // tp_size
...@@ -975,6 +993,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -975,6 +993,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
if self.expect_tp_size is not None and self.expect_tp_size == 1:
tp_size = 1
if hasattr(param, "expect_tp_size"):
param.expect_tp_size = self.expect_tp_size
if isinstance(param, BlockQuantScaleParameter): if isinstance(param, BlockQuantScaleParameter):
from vllm.model_executor.layers.quantization.fp8 import ( from vllm.model_executor.layers.quantization.fp8 import (
Fp8LinearMethod, Fp8MoEMethod) Fp8LinearMethod, Fp8MoEMethod)
...@@ -1405,10 +1428,16 @@ class RowParallelLinear(LinearBase): ...@@ -1405,10 +1428,16 @@ class RowParallelLinear(LinearBase):
prefix: str = "", prefix: str = "",
*, *,
return_bias: bool = True, return_bias: bool = True,
expect_tp_size: Optional[int] = None,
): ):
# Divide the weight matrix along the first dimension. # Divide the weight matrix along the first dimension.
self.tp_rank = get_tensor_model_parallel_rank() self.tp_rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
if expect_tp_size is not None:
self.tp_rank = 0
self.tp_size = 1
self.expect_tp_size = expect_tp_size
self.input_size_per_partition = divide(input_size, self.tp_size) self.input_size_per_partition = divide(input_size, self.tp_size)
self.output_size_per_partition = output_size self.output_size_per_partition = output_size
self.output_partition_sizes = [output_size] self.output_partition_sizes = [output_size]
...@@ -1454,6 +1483,10 @@ class RowParallelLinear(LinearBase): ...@@ -1454,6 +1483,10 @@ class RowParallelLinear(LinearBase):
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
if self.expect_tp_size is not None:
tp_rank = 0
tp_size = 1
input_dim = getattr(param, "input_dim", None) input_dim = getattr(param, "input_dim", None)
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False) use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
is_sharded_weight = getattr(param, "is_sharded_weight", False) is_sharded_weight = getattr(param, "is_sharded_weight", False)
...@@ -1506,6 +1539,8 @@ class RowParallelLinear(LinearBase): ...@@ -1506,6 +1539,8 @@ class RowParallelLinear(LinearBase):
assert loaded_weight.numel() == 1 assert loaded_weight.numel() == 1
loaded_weight = loaded_weight.reshape(1) loaded_weight = loaded_weight.reshape(1)
if self.expect_tp_size is not None and hasattr(param, "expect_tp_size"):
param.expect_tp_size = self.expect_tp_size
param.load_row_parallel_weight(loaded_weight=loaded_weight) param.load_row_parallel_weight(loaded_weight=loaded_weight)
def forward( def forward(
......
...@@ -4,18 +4,32 @@ ...@@ -4,18 +4,32 @@
import enum import enum
from enum import Enum from enum import Enum
from typing import Callable, Optional, List from typing import Callable, Optional, List
from math import prod
import torch import torch
from compressed_tensors.quantization import (QuantizationStrategy) from compressed_tensors.quantization import (QuantizationStrategy)
import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.config import get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_world_size, get_ep_group, get_dp_group
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from vllm.model_executor.layers.fused_moe import ( from vllm.model_executor.layers.fused_moe import (
FusedMoE, FusedMoEActivationFormat, FusedMoEMethodBase, FusedMoE, FusedMoEActivationFormat, FusedMoEMethodBase,
FusedMoeWeightScaleSupported) FusedMoEConfig, FusedMoeWeightScaleSupported,
FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize,)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.layers.fused_moe.utils import _resize_cache, compute_aligned_M, deepgemm_moe_permute, deepgemm_unpermute_and_reduce
from vllm.model_executor.layers.quantization.utils.w8a8_utils import( from vllm.model_executor.layers.quantization.utils.w8a8_utils import(
get_w8a8_int8_marlin_weights) get_w8a8_int8_marlin_weights, w8a8_nt_kpack2_marlin_weight)
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
_moe_permute)
from vllm.utils import round_up
try: try:
from lightop import m_grouped_w8a8_gemm_nt_masked, fuse_silu_mul_quant_ep, fuse_silu_mul_quant
from lmslim.layers.fused_moe.fuse_moe_int8_marlin import fused_experts_impl_int8_marlin from lmslim.layers.fused_moe.fuse_moe_int8_marlin import fused_experts_impl_int8_marlin
from lightop import m_grouped_w8a8_gemm_nt_contig_asm
except Exception: except Exception:
print("INFO: Please install lmslim if you want to infer the quantitative model of moe.\n") print("INFO: Please install lmslim if you want to infer the quantitative model of moe.\n")
...@@ -69,11 +83,32 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod) ...@@ -69,11 +83,32 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
"For INT8 Fused MoE layers, we require channelwise, " "For INT8 Fused MoE layers, we require channelwise, "
"dynamic per token quantization. Found static input scales.") "dynamic per token quantization. Found static input scales.")
self.fused_experts = self.fused_moe_forward
vllm_config = get_current_vllm_config()
parallel_config = vllm_config.parallel_config
self.dp_size = get_dp_group().world_size
self.ep_size = get_ep_group().world_size
self.use_deepep = self.dp_size > 1 and parallel_config.enable_expert_parallel and \
(envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency")
self.use_deepgemm = False
if self.use_deepep:
all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None
self.num_dispatchers = all2all_manager.world_size
self.block_shape = [256, 256]
self.use_deepgemm = envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency" or envs.VLLM_ENABLE_DEEPEP_HT_DEEPGEMM
def create_weights(self, layer: torch.nn.Module, num_experts: int, def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int, hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs): params_dtype: torch.dtype, **extra_weight_attrs):
if self.use_deepep:
self.N = 2 * intermediate_size_per_partition
self.K = hidden_size
params_dtype = torch.int8 params_dtype = torch.int8
# WEIGHTS # WEIGHTS
...@@ -124,19 +159,270 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod) ...@@ -124,19 +159,270 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
w1_marlin_list = [] w1_marlin_list = []
for ii in range(layer.w13_weight.shape[0]): for ii in range(layer.w13_weight.shape[0]):
if not self.use_deepgemm:
w1_marlin_in = get_w8a8_int8_marlin_weights(layer.w13_weight[ii]) w1_marlin_in = get_w8a8_int8_marlin_weights(layer.w13_weight[ii])
else:
w1_marlin_in = w8a8_nt_kpack2_marlin_weight(layer.w13_weight[ii])
w1_marlin_list.append(w1_marlin_in) w1_marlin_list.append(w1_marlin_in)
w1_marlin = torch.stack(w1_marlin_list, dim=0) w1_marlin = torch.stack(w1_marlin_list, dim=0)
del w1_marlin_list del w1_marlin_list
w2_marlin_list = [] w2_marlin_list = []
for ii in range(layer.w2_weight.shape[0]): for ii in range(layer.w2_weight.shape[0]):
if not self.use_deepgemm:
w2_marlin_in = get_w8a8_int8_marlin_weights(layer.w2_weight[ii]) w2_marlin_in = get_w8a8_int8_marlin_weights(layer.w2_weight[ii])
else:
w2_marlin_in = w8a8_nt_kpack2_marlin_weight(layer.w2_weight[ii])
w2_marlin_list.append(w2_marlin_in) w2_marlin_list.append(w2_marlin_in)
w2_marlin = torch.stack(w2_marlin_list, dim=0)
w2_marlin = torch.stack(w2_marlin_list, dim=0)
layer.w13_weight = Parameter(w1_marlin, requires_grad=False) layer.w13_weight = Parameter(w1_marlin, requires_grad=False)
layer.w2_weight = Parameter(w2_marlin, requires_grad=False) layer.w2_weight = Parameter(w2_marlin, requires_grad=False)
def masked_groupgemm_workspace_shapes(self,
a: torch.Tensor,
aq: torch.Tensor,
M: int,
N: int,
K: int,
topk: int,
global_num_experts: int,
local_num_experts: int,):
assert a.dim() == 2
# FIXME (varun): We should be able to dispatch only from the leader
# DP ranks in the case of TP > 1. At the moment, all the Ranks
# end up sending their tokens. This needs to be fixed.
num_dispatchers = self.num_dispatchers
num_experts = local_num_experts
max_num_tokens = a.size(
0) if self.max_num_tokens_per_rank is None else self.max_num_tokens_per_rank
workspace13 = (num_experts, max_num_tokens * num_dispatchers,
max(K, N))
workspace2 = (num_experts, max_num_tokens * num_dispatchers, (N // 2))
output = (num_experts, max_num_tokens * num_dispatchers, K)
return (workspace13, workspace2, output, a.dtype)
def contiguous_groupgemm_workspace_shapes(
self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int,
topk: int, global_num_experts: int, local_num_experts: int,
expert_num_tokens_cpu: torch.Tensor
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
assert self.block_shape is not None
# We use global_num_experts due to how moe_align_block_size handles
# expert_maps.
block_m = self.block_shape[0]
M_sum = compute_aligned_M(
M, topk, local_num_experts, block_m, expert_num_tokens_cpu
)
assert M_sum % block_m == 0
workspace1 = (M_sum, max(N, K))
workspace2 = (M_sum, max(N // 2, K))
output = (M, K)
return (workspace1, workspace2, output, a.dtype, M_sum)
def w8a8_groupgemm_masked_forward(self,
x: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
expert_num_tokens: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
shared_output: Optional[torch.Tensor] = None,
q_x: Optional[torch.Tensor] = None,
expert_num_tokens_cpu: Optional[torch.Tensor] = None,
**_ ):
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
local_num_experts = w1.size(0)
E, max_num_tokens, _, _, top_k = mk._moe_problem_size(
q_x, w1, w2, topk_ids)
N, K = self.N, self.K
(workspace13_shape, workspace2_shape, fused_out_shape,
workspace_dtype) = self.masked_groupgemm_workspace_shapes(
x, q_x, max_num_tokens, N, K, top_k, global_num_experts,
local_num_experts)
workspace13 = torch.empty(prod(workspace13_shape),
device=x.device,
dtype=workspace_dtype)
workspace1 = _resize_cache(workspace13, (E, max_num_tokens, N))
fused_out = _resize_cache(workspace13, fused_out_shape)
# (from deepgemm docs) : A value hint (which is a value on CPU)
# for the M expectation of each batch, correctly setting this value
# may lead to better performance.
#expected_m = max_num_tokens
ori_bs = x.shape[0]
expected_m = ori_bs * self.ep_size
# expected_m = (
# x.shape[0] * self.dp_size * topk_ids.shape[1]
# + global_num_experts
# ) // global_num_experts
m_grouped_w8a8_gemm_nt_masked((q_x, a1_scale),
(w1, w1_scale),
workspace1,
expert_num_tokens,
expected_m,
)
assert expert_num_tokens is not None
a2q, a2q_scale = fuse_silu_mul_quant_ep(workspace1, expert_num_tokens)
m_grouped_w8a8_gemm_nt_masked((a2q, a2q_scale),
(w2, w2_scale),
fused_out,
expert_num_tokens,
expected_m)
return fused_out
def w8a8_groupgemm_contiguous_forward(self,
x: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
expert_num_tokens: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
shared_output: Optional[torch.Tensor] = None,
q_x: Optional[torch.Tensor] = None,
expert_num_tokens_cpu: Optional[torch.Tensor] = None,
**_ ):
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
local_num_experts = w1.size(0)
a1q = q_x
N, K = self.N, self.K
(workspace13_shape, workspace2_shape, fused_out_shape,
workspace_dtype, M_sum) = self.contiguous_groupgemm_workspace_shapes(
x, q_x, topk_ids.size(0), N, K, topk_ids.size(1), global_num_experts,
local_num_experts, expert_num_tokens_cpu)
workspace13 = torch.empty(prod(workspace13_shape),
device=x.device,
dtype=workspace_dtype)
workspace2 = torch.empty(prod(workspace2_shape),
device=x.device,
dtype=workspace_dtype)
mm1_out = _resize_cache(workspace13, (M_sum, N))
mm2_out = _resize_cache(workspace2, (M_sum, K))
act_out = _resize_cache(workspace2, (M_sum, N // 2))
quant_out = _resize_cache(
workspace13.view(dtype=a1q.dtype), (M_sum, N // 2)
)
fused_out = _resize_cache(workspace13, fused_out_shape)
a1q_perm = _resize_cache(workspace2.view(dtype=a1q.dtype), (M_sum, K))
a1q, a1q_scale, expert_ids, inv_perm = deepgemm_moe_permute(
aq=a1q,
aq_scale=a1_scale,
topk_ids=topk_ids,
local_num_experts=local_num_experts,
expert_map=expert_map,
block_shape=self.block_shape,
expert_num_tokens=expert_num_tokens,
expert_num_tokens_cpu=expert_num_tokens_cpu,
aq_out=a1q_perm,
M_sum=M_sum
)
m_grouped_w8a8_gemm_nt_contig_asm(
(a1q, a1q_scale), (w1, w1_scale), mm1_out, expert_ids)
a2q, a2q_scale = fuse_silu_mul_quant(mm1_out)
#a2q, a2q_scale = fuse_silu_mul_quant(input=mm1_out, expert_ids=expert_ids)
m_grouped_w8a8_gemm_nt_contig_asm(
(a2q, a2q_scale), (w2, w2_scale), mm2_out, expert_ids)
if apply_router_weight_on_input:
topk_weights = torch.ones_like(topk_weights)
deepgemm_unpermute_and_reduce(
a=mm2_out,
topk_ids=topk_ids,
topk_weights=topk_weights,
inv_perm=inv_perm,
expert_map=expert_map,
output=fused_out,
)
return fused_out
def fused_moe_forward(self,
x: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
expert_num_tokens: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
shared_output: Optional[torch.Tensor] = None,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None,
q_x: Optional[torch.Tensor] = None,
expert_num_tokens_cpu: Optional[torch.Tensor] = None,
**_ ):
return fused_experts_impl_int8_marlin(
hidden_states=x if q_x is None else q_x,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_int8_w8a8=True,
per_channel_quant=True,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
use_nn_moe=False,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
i_q=i_q,
i_s=i_s)
def apply( def apply(
self, self,
...@@ -164,14 +450,14 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod) ...@@ -164,14 +450,14 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
shared_output: Optional[torch.Tensor] = None, shared_output: Optional[torch.Tensor] = None,
i_q: Optional[torch.Tensor] = None, i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None, **_ i_s: Optional[torch.Tensor] = None,
**_
) -> torch.Tensor: ) -> torch.Tensor:
if enable_eplb: if enable_eplb:
raise NotImplementedError( raise NotImplementedError(
"EPLB not supported for " "EPLB not supported for "
"`CompressedTensorsW8A8Int8MoEMethod` yet.") "`CompressedTensorsW8A8Int8MoEMethod` yet.")
topk_weights, topk_ids = FusedMoE.select_experts( topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
...@@ -184,27 +470,64 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod) ...@@ -184,27 +470,64 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
scoring_func=scoring_func, scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor, routed_scaling_factor=routed_scaling_factor,
use_fused_gate=use_fused_gate, use_fused_gate=use_fused_gate,
e_score_correction_bias=e_score_correction_bias) e_score_correction_bias=e_score_correction_bias,
indices_type=torch.int64 if self.use_deepep else None,)
return fused_experts_impl_int8_marlin( return self.fused_experts(
hidden_states=x, x,
w1=layer.w13_weight, layer.w13_weight,
w2=layer.w2_weight, layer.w2_weight,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=True,
activation=activation, activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
use_int8_w8a8=True,
per_channel_quant=True,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
w1_scale=layer.w13_weight_scale, w1_scale=(layer.w13_weight_scale),
w2_scale=layer.w2_weight_scale, w2_scale=(layer.w2_weight_scale),
a1_scale=layer.w13_input_scale, a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale, a2_scale=layer.w2_input_scale,
use_nn_moe=False, use_nn_moe=False,
shared_output=shared_output, shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor, routed_scaling_factor=routed_scaling_factor,
i_q=i_q, i_q=i_q,
i_s=i_s) i_s=i_s
)
def select_gemm_impl(
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
) -> FusedMoEPermuteExpertsUnpermute:
from vllm.model_executor.layers.fused_moe import (
TritonOrGroupGemmExperts)
if (prepare_finalize.activation_format ==
FusedMoEActivationFormat.BatchedExperts):
max_num_tokens_per_rank = (
prepare_finalize.max_num_tokens_per_rank())
assert max_num_tokens_per_rank is not None
self.max_num_tokens_per_rank = max_num_tokens_per_rank
logger.debug(
"TritonOrGroupGemmExperts(%s): "
"max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
self.__class__.__name__, max_num_tokens_per_rank,
None, True)
return TritonOrGroupGemmExperts(
use_int8_w8a8=True,
per_act_token_quant=True,
fused_experts=self.w8a8_groupgemm_masked_forward
)
else:
logger.debug(
"TritonOrGroupGemmExperts(%s): block_size=%s, per_act_token=%s",
self.__class__.__name__, None,
False)
return TritonOrGroupGemmExperts(
use_int8_w8a8=True if envs.VLLM_ENABLE_DEEPEP_HT_DEEPGEMM else False,
per_act_token_quant=True if envs.VLLM_ENABLE_DEEPEP_HT_DEEPGEMM else False,
fused_experts=self.w8a8_groupgemm_contiguous_forward if envs.VLLM_ENABLE_DEEPEP_HT_DEEPGEMM else self.fused_moe_forward
)
from typing import Any, Callable, Dict, List, Optional from typing import Any, Callable, Dict, List, Optional
import os import os
from math import prod
import torch import torch
from torch.nn.parameter import Parameter
import vllm.envs as envs import vllm.envs as envs
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size, get_ep_group, get_dp_group
from torch.nn.parameter import Parameter from vllm.logger import init_logger
from vllm.config import get_current_vllm_config
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.linear import (LinearBase,LinearMethodBase) from vllm.model_executor.layers.linear import (LinearBase,LinearMethodBase)
from vllm.model_executor.layers.quantization.base_config import (QuantizationConfig, from vllm.model_executor.layers.quantization.base_config import (QuantizationConfig,
...@@ -13,16 +19,25 @@ from vllm.model_executor.layers.quantization.base_config import (QuantizationCon ...@@ -13,16 +19,25 @@ from vllm.model_executor.layers.quantization.base_config import (QuantizationCon
from vllm.model_executor.layers.quantization.utils.w4a8_utils import w4a8_weight_repack_impl from vllm.model_executor.layers.quantization.utils.w4a8_utils import w4a8_weight_repack_impl
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
FusedMoeWeightScaleSupported) FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.model_executor.parameter import (ChannelQuantScaleParameter, from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
ModelWeightParameter) ModelWeightParameter)
from vllm.model_executor.layers.quantization.slimquant_w4a8 import SlimQuantW4A8Int8LinearMethod from vllm.model_executor.layers.quantization.slimquant_w4a8 import SlimQuantW4A8Int8LinearMethod
from vllm.model_executor.layers.fused_moe import (
FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase,
FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize,
FusedMoeWeightScaleSupported)
try: try:
from lmslim.layers.fused_moe.fuse_moe_w4a8_marlin import fused_experts_impl_w4a8_marlin from lmslim.layers.fused_moe.fuse_moe_w4a8_marlin import fused_experts_impl_w4a8_marlin
from lightop import m_grouped_w4a8_gemm_nt_masked, fuse_silu_mul_quant, fuse_silu_mul_quant_ep
except Exception: except Exception:
print("INFO: Please install lmslim if you want to infer the quantitative model of moe.\n") print("INFO: Please install lmslim if you want to infer the quantitative model of moe.\n")
logger = init_logger(__name__)
class MarlinMoeWorkspace: class MarlinMoeWorkspace:
""" """
Singleton manager for device-specific workspace buffers used by w4a8 Marlin-MoE. Singleton manager for device-specific workspace buffers used by w4a8 Marlin-MoE.
...@@ -145,6 +160,21 @@ class SlimQuantW4A8Int8MarlinMoEMethod: ...@@ -145,6 +160,21 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
def __init__(self, quant_config): def __init__(self, quant_config):
self.quant_config = quant_config self.quant_config = quant_config
self.fused_experts = self.w4a8_fused_moe_marlin_forward
vllm_config = get_current_vllm_config()
parallel_config = vllm_config.parallel_config
dp_size = get_dp_group().world_size
self.use_deepep = dp_size > 1 and parallel_config.enable_expert_parallel and \
(envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency")
self.ep_size = get_ep_group().world_size
if self.use_deepep:
all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None
self.num_dispatchers = all2all_manager.world_size
def create_weights( def create_weights(
self, self,
...@@ -157,6 +187,10 @@ class SlimQuantW4A8Int8MarlinMoEMethod: ...@@ -157,6 +187,10 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
): ):
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
if self.use_deepep:
self.N = 2 * intermediate_size
self.K = hidden_size
# WEIGHTS # WEIGHTS
w13_weight = torch.nn.Parameter( w13_weight = torch.nn.Parameter(
torch.empty( torch.empty(
...@@ -209,6 +243,139 @@ class SlimQuantW4A8Int8MarlinMoEMethod: ...@@ -209,6 +243,139 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
layer.w13_weight = Parameter(w4a8_weight_repack_impl(layer.w13_weight), requires_grad=False) layer.w13_weight = Parameter(w4a8_weight_repack_impl(layer.w13_weight), requires_grad=False)
layer.w2_weight = Parameter(w4a8_weight_repack_impl(layer.w2_weight), requires_grad=False) layer.w2_weight = Parameter(w4a8_weight_repack_impl(layer.w2_weight), requires_grad=False)
def w4a8_fused_moe_marlin_forward(self,
x: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
expert_num_tokens: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
shared_output: Optional[torch.Tensor] = None,
q_x: Optional[torch.Tensor] = None,
**_ ):
workspace, global_reduce_buffer = MarlinMoeWorkspace(x.device).get_buffers()
return fused_experts_impl_w4a8_marlin(
x if q_x is None else q_x,
w1,
w2,
topk_ids=topk_ids,
topk_weights=topk_weights,
workspace=workspace,
global_reduce_buffer=global_reduce_buffer,
inplace=True,
use_int4_w4a8=True,
per_channel_quant=True,
activation=activation,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
use_nn_moe=use_nn_moe,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
)
def groupgemm_workspace_shapes(self,
a: torch.Tensor,
aq: torch.Tensor,
M: int,
N: int,
K: int,
topk: int,
global_num_experts: int,
local_num_experts: int,):
assert a.dim() == 2
# FIXME (varun): We should be able to dispatch only from the leader
# DP ranks in the case of TP > 1. At the moment, all the Ranks
# end up sending their tokens. This needs to be fixed.
num_dispatchers = self.num_dispatchers
num_experts = local_num_experts
max_num_tokens = a.size(
0) if self.max_num_tokens_per_rank is None else self.max_num_tokens_per_rank
workspace13 = (num_experts, max_num_tokens * num_dispatchers,
max(K, N))
workspace2 = (num_experts, max_num_tokens * num_dispatchers, (N // 2))
output = (num_experts, max_num_tokens * num_dispatchers, K)
return (workspace13, workspace2, output, a.dtype)
def w4a8_groupgemm_marlin_forward(self,
x: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
expert_num_tokens: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
shared_output: Optional[torch.Tensor] = None,
q_x: Optional[torch.Tensor] = None,
**_ ):
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
local_num_experts = w1.size(0)
E, max_num_tokens, _, _, top_k = mk._moe_problem_size(
q_x, w1, w2, topk_ids)
N, K = self.N, self.K
(workspace13_shape, workspace2_shape, fused_out_shape,
workspace_dtype) = self.groupgemm_workspace_shapes(
x, q_x, max_num_tokens, N, K, top_k, global_num_experts,
local_num_experts)
workspace13 = torch.empty(prod(workspace13_shape),
device=x.device,
dtype=workspace_dtype)
workspace1 = _resize_cache(workspace13, (E, max_num_tokens, N))
fused_out = _resize_cache(workspace13, fused_out_shape)
# (from deepgemm docs) : A value hint (which is a value on CPU)
# for the M expectation of each batch, correctly setting this value
# may lead to better performance.
#expected_m = max_num_tokens
ori_bs = x.shape[0]
expected_m = ori_bs * self.ep_size
m_grouped_w4a8_gemm_nt_masked((q_x, a1_scale),
(w1, w1_scale),
workspace1,
expert_num_tokens,
expected_m,
)
assert expert_num_tokens is not None
a2q, a2q_scale = fuse_silu_mul_quant_ep(workspace1, expert_num_tokens)
m_grouped_w4a8_gemm_nt_masked((a2q, a2q_scale),
(w2, w2_scale),
fused_out,
expert_num_tokens,
expected_m)
return fused_out
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
...@@ -233,9 +400,9 @@ class SlimQuantW4A8Int8MarlinMoEMethod: ...@@ -233,9 +400,9 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
shared_output: Optional[torch.Tensor] = None, shared_output: Optional[torch.Tensor] = None,
**_ **_
) -> torch.Tensor: ) -> torch.Tensor:
if enable_eplb: # if enable_eplb:
raise NotImplementedError( # raise NotImplementedError(
"EPLB not supported for `SlimQuantW4A8Int8MarlinMoEMethod` yet.") # "EPLB not supported for `SlimQuantW4A8Int8MarlinMoEMethod` yet.")
# Expert selection # Expert selection
topk_weights, topk_ids = FusedMoE.select_experts( topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x, hidden_states=x,
...@@ -248,30 +415,62 @@ class SlimQuantW4A8Int8MarlinMoEMethod: ...@@ -248,30 +415,62 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias, e_score_correction_bias=e_score_correction_bias,
indices_type=torch.int64 if self.use_deepep else None,
routed_scaling_factor=routed_scaling_factor, routed_scaling_factor=routed_scaling_factor,
use_fused_gate=use_fused_gate use_fused_gate=use_fused_gate
) )
workspace, global_reduce_buffer = MarlinMoeWorkspace(x.device).get_buffers() return self.fused_experts(
return fused_experts_impl_w4a8_marlin(
x, x,
layer.w13_weight, layer.w13_weight,
layer.w2_weight, layer.w2_weight,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
workspace=workspace,
global_reduce_buffer=global_reduce_buffer,
inplace=True,
use_int4_w4a8=True,
per_channel_quant=True,
activation=activation, activation=activation,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=(layer.w13_weight_scale), w1_scale=(layer.w13_weight_scale),
w2_scale=(layer.w2_weight_scale), w2_scale=(layer.w2_weight_scale),
a1_scale=layer.w13_input_scale, a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale, a2_scale=layer.w2_input_scale,
apply_router_weight_on_input=apply_router_weight_on_input,
use_nn_moe=use_nn_moe, use_nn_moe=use_nn_moe,
shared_output=shared_output, shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor, routed_scaling_factor=routed_scaling_factor,
) )
def select_gemm_impl(
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
) -> FusedMoEPermuteExpertsUnpermute:
from vllm.model_executor.layers.fused_moe import (
TritonOrGroupGemmExperts)
if (prepare_finalize.activation_format ==
FusedMoEActivationFormat.BatchedExperts):
max_num_tokens_per_rank = (
prepare_finalize.max_num_tokens_per_rank())
self.max_num_tokens_per_rank = max_num_tokens_per_rank
assert max_num_tokens_per_rank is not None
logger.info(
"TritonOrGroupGemmExperts(%s): "
"max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
self.__class__.__name__, max_num_tokens_per_rank,
None, True)
return TritonOrGroupGemmExperts(
use_int4_w4a8=True,
per_act_token_quant=True,
fused_experts=self.w4a8_groupgemm_marlin_forward
)
else:
logger.info(
"TritonOrGroupGemmExperts(%s): block_size=%s, per_act_token=%s",
self.__class__.__name__, None,
False)
return TritonOrGroupGemmExperts(
# use_int4_w4a8=True,
# per_act_token_quant=True,
fused_experts=self.w4a8_fused_moe_marlin_forward
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment