Commit 52f895ab authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev-ds-wm-1202' into 'v0.9.2-dev-ds'

[feat]支持deepep低延迟与共享专家overlap

See merge request dcutoolkit/deeplearing/vllm!281
parents bca29c66 1ae8f58c
...@@ -4780,8 +4780,9 @@ class VllmConfig: ...@@ -4780,8 +4780,9 @@ class VllmConfig:
# add for spec decode # add for spec decode
if self.speculative_config is not None and self.speculative_config.num_lookahead_slots > 0: if self.speculative_config is not None and self.speculative_config.num_lookahead_slots > 0:
batch_size_capture_list = list(map(lambda x: x * (1 + self.speculative_config.num_lookahead_slots), mtp_batch_size_capture_list = list(map(lambda x: x * (1 + self.speculative_config.num_lookahead_slots),
batch_size_capture_list)) batch_size_capture_list))
batch_size_capture_list = sorted(set(batch_size_capture_list + mtp_batch_size_capture_list))
self.compilation_config.init_with_cudagraph_sizes( self.compilation_config.init_with_cudagraph_sizes(
batch_size_capture_list) batch_size_capture_list)
......
...@@ -192,8 +192,7 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase): ...@@ -192,8 +192,7 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
num_rdma_bytes=num_rdma_bytes, num_rdma_bytes=num_rdma_bytes,
low_latency_mode=False, low_latency_mode=False,
num_qps_per_rank=num_qps_per_rank, num_qps_per_rank=num_qps_per_rank,
explicitly_destroy=False, explicitly_destroy=False)
use_default_stream_as_comm_stream=False)
def get_handle(self, kwargs): def get_handle(self, kwargs):
......
...@@ -237,10 +237,11 @@ class DeviceCommunicatorBase: ...@@ -237,10 +237,11 @@ class DeviceCommunicatorBase:
moe_modules = [ moe_modules = [
module for module in model.modules() module for module in model.modules()
if module.__class__.__name__ == "FusedMoE" if (module.__class__.__name__ == "FusedMoE"
or module.__class__.__name__ == "SharedFusedMoE")
] ]
for module in moe_modules: for module in moe_modules:
module.quant_method.init_prepare_finalize(module.moe_config, module.quant_method.init_prepare_finalize(module, module.moe_config,
module.quant_config) module.quant_config)
def dispatch( def dispatch(
......
...@@ -10,6 +10,7 @@ from vllm.model_executor.layers.fused_moe.layer import ( ...@@ -10,6 +10,7 @@ from vllm.model_executor.layers.fused_moe.layer import (
from vllm.model_executor.layers.fused_moe.modular_kernel import ( from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEActivationFormat, FusedMoEPermuteExpertsUnpermute, FusedMoEActivationFormat, FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize) FusedMoEPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
from vllm.triton_utils import HAS_TRITON from vllm.triton_utils import HAS_TRITON
_config: Optional[dict[str, Any]] = None _config: Optional[dict[str, Any]] = None
...@@ -38,6 +39,7 @@ __all__ = [ ...@@ -38,6 +39,7 @@ __all__ = [
"FusedMoEPrepareAndFinalize", "FusedMoEPrepareAndFinalize",
"override_config", "override_config",
"get_config", "get_config",
"SharedFusedMoE",
] ]
if HAS_TRITON: if HAS_TRITON:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional, Union from typing import Optional, Union
from collections.abc import Callable
import deep_ep import deep_ep
import torch import torch
...@@ -44,12 +45,14 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -44,12 +45,14 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
buffer: deep_ep.Buffer, buffer: deep_ep.Buffer,
max_tokens_per_rank: int, max_tokens_per_rank: int,
num_dispatchers: int, num_dispatchers: int,
use_fp8_dispatch: bool = False): use_fp8_dispatch: bool = False,
use_int8_dispatch: bool = False):
super().__init__() super().__init__()
self.buffer = buffer self.buffer = buffer
self.max_tokens_per_rank = max_tokens_per_rank self.max_tokens_per_rank = max_tokens_per_rank
self.use_fp8_dispatch = use_fp8_dispatch self.use_fp8_dispatch = use_fp8_dispatch
self.use_int8_dispatch = use_int8_dispatch
# The dispatch function returns a handle that the combine function # The dispatch function returns a handle that the combine function
# requires. We store the handle here so it is available to the # requires. We store the handle here so it is available to the
# combine function. # combine function.
...@@ -154,7 +157,8 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -154,7 +157,8 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_ids, topk_ids,
self.max_tokens_per_rank, self.max_tokens_per_rank,
num_experts, num_experts,
use_fp8=self.use_fp8_dispatch, use_fp8=self.use_fp8_dispatch or self.use_int8_dispatch,
use_int8=self.use_int8_dispatch,
async_finish=False, async_finish=False,
return_recv_hook=False) return_recv_hook=False)
...@@ -164,11 +168,17 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -164,11 +168,17 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return (expert_x, expert_x_scale, expert_num_tokens, None, None) return (expert_x, expert_x_scale, expert_num_tokens, None, None)
def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, def _finalize(
topk_weights: torch.Tensor, topk_ids: torch.Tensor, self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
apply_weights_and_reduce: bool = True) -> None: apply_weights_and_reduce: bool,
do_async: bool,
) -> Callable:
do_recv_hook = do_async
assert self.handle is not None assert self.handle is not None
combine_topk_weights = topk_weights combine_topk_weights = topk_weights
...@@ -177,12 +187,45 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -177,12 +187,45 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
combine_topk_weights = torch.ones_like(topk_weights) combine_topk_weights = torch.ones_like(topk_weights)
# TODO (varun) : Enable zero copy mode # TODO (varun) : Enable zero copy mode
_, event, hook = self.buffer.low_latency_combine( _, _, recv_hook = self.buffer.low_latency_combine(
fused_expert_output, fused_expert_output,
topk_ids, topk_ids,
combine_topk_weights, combine_topk_weights,
self.handle, self.handle,
async_finish=False, async_finish=False,
zero_copy=False, zero_copy=False,
return_recv_hook=False, return_recv_hook=do_recv_hook,
out=output) out=output,
)
return recv_hook
def finalize_async(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
apply_weights_and_reduce: bool = True) -> None:
return self._finalize(
output,
fused_expert_output,
topk_weights,
topk_ids,
apply_router_weight_on_input,
apply_weights_and_reduce,
do_async=True,
)
def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
apply_weights_and_reduce: bool = True) -> None:
self._finalize(
output,
fused_expert_output,
topk_weights,
topk_ids,
apply_router_weight_on_input,
apply_weights_and_reduce,
do_async=False,
)
...@@ -92,7 +92,7 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -92,7 +92,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
params_dtype: torch.dtype, **extra_weight_attrs): params_dtype: torch.dtype, **extra_weight_attrs):
raise NotImplementedError raise NotImplementedError
def init_prepare_finalize(self, moe: FusedMoEConfig, def init_prepare_finalize(self, layer, moe: FusedMoEConfig,
quant_config: Optional[QuantizationConfig]): quant_config: Optional[QuantizationConfig]):
all2all_manager = get_ep_group().device_communicator.all2all_manager all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None assert all2all_manager is not None
...@@ -171,6 +171,8 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -171,6 +171,8 @@ class FusedMoEMethodBase(QuantizeMethodBase):
and moe.quant_config.block_shape and moe.quant_config.block_shape
== DEEPEP_QUANT_BLOCK_SHAPE) == DEEPEP_QUANT_BLOCK_SHAPE)
use_int8_dispatch = False#moe.quant_config.quant_dtype == torch.int8
# Note (varun): Whether to use FP8 dispatch or not needs some # Note (varun): Whether to use FP8 dispatch or not needs some
# profiling. Turning it off for now. # profiling. Turning it off for now.
prepare_finalize = DeepEPLLPrepareAndFinalize( prepare_finalize = DeepEPLLPrepareAndFinalize(
...@@ -178,6 +180,7 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -178,6 +180,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
max_tokens_per_rank=moe.max_num_tokens, max_tokens_per_rank=moe.max_num_tokens,
num_dispatchers=all2all_manager.world_size, num_dispatchers=all2all_manager.world_size,
use_fp8_dispatch=use_fp8_dispatch, use_fp8_dispatch=use_fp8_dispatch,
use_int8_dispatch=use_int8_dispatch,
) )
self.topk_indices_dtype = None self.topk_indices_dtype = None
...@@ -195,6 +198,7 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -195,6 +198,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
self.fused_experts = DeepGemmDisabledFusedMoEModularKernel( self.fused_experts = DeepGemmDisabledFusedMoEModularKernel(
prepare_finalize, prepare_finalize,
experts, experts,
shared_experts=layer.shared_experts if hasattr(layer, "shared_experts") else None,
) )
def select_gemm_impl( def select_gemm_impl(
...@@ -913,6 +917,10 @@ class FusedMoE(torch.nn.Module): ...@@ -913,6 +917,10 @@ class FusedMoE(torch.nn.Module):
def use_deepep_ll_kernels(self): def use_deepep_ll_kernels(self):
return self.moe_parallel_config.use_deepep_ll_kernels return self.moe_parallel_config.use_deepep_ll_kernels
@property
def shared_experts(self) -> Optional[torch.nn.Module]:
return None
def _load_per_tensor_weight_scale(self, shard_id: str, def _load_per_tensor_weight_scale(self, shard_id: str,
param: torch.nn.Parameter, param: torch.nn.Parameter,
loaded_weight: torch.Tensor, loaded_weight: torch.Tensor,
...@@ -1456,8 +1464,12 @@ class FusedMoE(torch.nn.Module): ...@@ -1456,8 +1464,12 @@ class FusedMoE(torch.nn.Module):
if current_platform.is_tpu(): if current_platform.is_tpu():
return self.forward_impl(hidden_states, router_logits) return self.forward_impl(hidden_states, router_logits)
else: else:
if self.shared_experts is None:
return torch.ops.vllm.moe_forward(hidden_states, router_logits, return torch.ops.vllm.moe_forward(hidden_states, router_logits,
self.layer_name, shared_output) self.layer_name, shared_output)
else:
return torch.ops.vllm.moe_forward_shared(hidden_states, router_logits,
self.layer_name, shared_output)
def forward_impl_chunked(self, full_hidden_states: torch.Tensor, def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
full_router_logits: torch.Tensor): full_router_logits: torch.Tensor):
...@@ -1668,3 +1680,34 @@ direct_register_custom_op( ...@@ -1668,3 +1680,34 @@ direct_register_custom_op(
dispatch_key=current_platform.dispatch_key, dispatch_key=current_platform.dispatch_key,
tags=(torch.Tag.needs_fixed_stride_order, ), tags=(torch.Tag.needs_fixed_stride_order, ),
) )
def moe_forward_shared(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
layer_name: str,
shared_output: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor]:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
assert self.shared_experts is not None
return self.forward_impl(hidden_states, router_logits, shared_output)
def moe_forward_shared_fake(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
layer_name: str,
shared_output: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor]:
shared_out = torch.empty_like(hidden_states)
fused_out = torch.empty_like(hidden_states)
return shared_out, fused_out
direct_register_custom_op(
op_name="moe_forward_shared",
op_func=moe_forward_shared,
mutates_args=["hidden_states"],
fake_impl=moe_forward_shared_fake,
tags=(torch.Tag.needs_fixed_stride_order,),
)
\ No newline at end of file
...@@ -761,6 +761,20 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -761,6 +761,20 @@ class FusedMoEModularKernel(torch.nn.Module):
return output return output
_aux_stream: torch.cuda.Stream | None = None
def aux_stream() -> torch.cuda.Stream | None:
"""
Ensures aux_stream is initialized only once
"""
global _aux_stream
# TODO: validate this works properly on ROCm platform.
if _aux_stream is None:
_aux_stream = torch.cuda.Stream()
return _aux_stream
@final @final
class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module): class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
""" """
...@@ -779,10 +793,17 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module): ...@@ -779,10 +793,17 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
self, self,
prepare_finalize: FusedMoEPrepareAndFinalize, prepare_finalize: FusedMoEPrepareAndFinalize,
fused_experts: CustomizedFusedMoEPermuteExpertsUnpermute, fused_experts: CustomizedFusedMoEPermuteExpertsUnpermute,
shared_experts: Optional[torch.nn.Module] = None,
): ):
super().__init__() super().__init__()
self.prepare_finalize = prepare_finalize self.prepare_finalize = prepare_finalize
self.fused_experts = fused_experts self.fused_experts = fused_experts
self.shared_experts = shared_experts
if self.shared_experts is not None:
self.shared_experts_stream = aux_stream()
self.shared_experts_overlap_event = torch.cuda.Event()
# assert prepare_finalize.activation_format == \ # assert prepare_finalize.activation_format == \
# fused_experts.activation_formats[0], ( # fused_experts.activation_formats[0], (
# f"{prepare_finalize.__class__.__name__}." # f"{prepare_finalize.__class__.__name__}."
...@@ -849,7 +870,11 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module): ...@@ -849,7 +870,11 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
""" """
a1 = hidden_states a1 = hidden_states
output = a1 if inplace else torch.zeros_like(a1)
if inplace and self.shared_experts is None:
output = hidden_states
else:
output = torch.zeros_like(hidden_states)
local_num_experts = w1.size(0) local_num_experts = w1.size(0)
if global_num_experts == -1: if global_num_experts == -1:
...@@ -898,7 +923,21 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module): ...@@ -898,7 +923,21 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
routed_scaling_factor=routed_scaling_factor, routed_scaling_factor=routed_scaling_factor,
) )
shared_output = None
if self.shared_experts is None:
self.prepare_finalize.finalize(output, fused_out, topk_weights, self.prepare_finalize.finalize(output, fused_out, topk_weights,
topk_ids, apply_router_weight_on_input, apply_weights_and_reduce=False) topk_ids, apply_router_weight_on_input, apply_weights_and_reduce=False)
else:
hook = self.prepare_finalize.finalize_async(output, fused_out, topk_weights,
topk_ids, apply_router_weight_on_input, apply_weights_and_reduce=False)
if self.shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
if hook is not None:
hook()
if self.shared_experts is not None:
return (shared_output, output)
return output return output
...@@ -189,6 +189,8 @@ class MoriMoE(FusedMoE): ...@@ -189,6 +189,8 @@ class MoriMoE(FusedMoE):
enable_eplb: bool = False, enable_eplb: bool = False,
num_redundant_experts: int = 0, num_redundant_experts: int = 0,
moe_permute_fusion: bool = False, moe_permute_fusion: bool = False,
shared_experts: Optional[torch.nn.Module] = None,
use_overlapped: bool = False,
): ):
super().__init__(num_experts, top_k, hidden_size, super().__init__(num_experts, top_k, hidden_size,
intermediate_size, params_dtype, intermediate_size, params_dtype,
...@@ -214,6 +216,7 @@ class MoriMoE(FusedMoE): ...@@ -214,6 +216,7 @@ class MoriMoE(FusedMoE):
routed_scaling_factor=self.routed_scaling_factor, routed_scaling_factor=self.routed_scaling_factor,
apply_router_weight_on_input=self.apply_router_weight_on_input apply_router_weight_on_input=self.apply_router_weight_on_input
) )
self.shared_experts = shared_experts
local_expert_indices_offset = ( local_expert_indices_offset = (
self.ep_rank * self.local_num_experts self.ep_rank * self.local_num_experts
...@@ -222,8 +225,6 @@ class MoriMoE(FusedMoE): ...@@ -222,8 +225,6 @@ class MoriMoE(FusedMoE):
local_expert_indices_offset + i for i in range(self.local_num_experts) local_expert_indices_offset + i for i in range(self.local_num_experts)
] ]
self.shared_experts = None
self.scales = None self.scales = None
self.use_int8_dispatch = True self.use_int8_dispatch = True
...@@ -267,10 +268,6 @@ class MoriMoE(FusedMoE): ...@@ -267,10 +268,6 @@ class MoriMoE(FusedMoE):
return _MORI_OP return _MORI_OP
def set_shared_experts(self, shared_experts: torch.nn.Module):
if self.shared_experts is None:
self.shared_experts = shared_experts
def create_quant_method(self, moe, quant_config, prefix): def create_quant_method(self, moe, quant_config, prefix):
# Note: get_quant_method will look at the layer's local_num_experts # Note: get_quant_method will look at the layer's local_num_experts
# for heuristic purposes, so it must be initialized first. # for heuristic purposes, so it must be initialized first.
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
# TODO(bnell): Add shared + fused combo function? e.g. +
class SharedFusedMoE(FusedMoE):
"""
A FusedMoE operation that also computes the results of shared experts.
If an all2all communicator is being used the shared expert computation
can be interleaved with the fused all2all dispatch communication step.
"""
def __init__(
self,
shared_experts: torch.nn.Module,
use_overlapped: bool = True,
**kwargs,
):
super().__init__(**kwargs)
self._shared_experts = shared_experts
self.use_overlapped = use_overlapped
@property
def shared_experts(self) -> Optional[torch.nn.Module]:
return self._shared_experts if self.use_overlapped else None
def forward(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor:
if not self.use_overlapped:
shared_out = self._shared_experts(hidden_states)
# Reduce outputs if necessary, since the MLP should
# have been created with reduce_results=False.
if (self.reduce_results and self.tp_size > 1
and self.must_reduce_shared_expert_outputs()):
shared_out = tensor_model_parallel_all_reduce(shared_out)
fused_out = super().forward(
hidden_states=hidden_states,
router_logits=router_logits,
)
else:
# Matrix multiply.
fused_out = super().forward(
hidden_states=hidden_states,
router_logits=router_logits,
)
return fused_out
...@@ -82,11 +82,13 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod) ...@@ -82,11 +82,13 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
self.fused_experts = self.fused_moe_forward self.fused_experts = self.fused_moe_forward
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
dp_size = get_dp_group().world_size self.dp_size = get_dp_group().world_size
self.use_deepep = dp_size > 1 and parallel_config.enable_expert_parallel and \ self.use_deepep = self.dp_size > 1 and parallel_config.enable_expert_parallel and \
(envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \ (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency") envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency")
self.use_deepep_ll = self.use_deepep and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency"
if self.use_deepep: if self.use_deepep:
all2all_manager = get_ep_group().device_communicator.all2all_manager all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None assert all2all_manager is not None
...@@ -97,7 +99,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod) ...@@ -97,7 +99,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
hidden_size: int, intermediate_size_per_partition: int, hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs): params_dtype: torch.dtype, **extra_weight_attrs):
if self.use_deepep: if self.use_deepep_ll:
self.N = 2 * intermediate_size_per_partition self.N = 2 * intermediate_size_per_partition
self.K = hidden_size self.K = hidden_size
...@@ -151,7 +153,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod) ...@@ -151,7 +153,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
w1_marlin_list = [] w1_marlin_list = []
for ii in range(layer.w13_weight.shape[0]): for ii in range(layer.w13_weight.shape[0]):
if not self.use_deepep: if not self.use_deepep_ll:
w1_marlin_in = get_w8a8_int8_marlin_weights(layer.w13_weight[ii]) w1_marlin_in = get_w8a8_int8_marlin_weights(layer.w13_weight[ii])
else: else:
w1_marlin_in = w8a8_nt_kpack2_marlin_weight(layer.w13_weight[ii]) w1_marlin_in = w8a8_nt_kpack2_marlin_weight(layer.w13_weight[ii])
...@@ -162,7 +164,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod) ...@@ -162,7 +164,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
del w1_marlin_list del w1_marlin_list
w2_marlin_list = [] w2_marlin_list = []
for ii in range(layer.w2_weight.shape[0]): for ii in range(layer.w2_weight.shape[0]):
if not self.use_deepep: if not self.use_deepep_ll:
w2_marlin_in = get_w8a8_int8_marlin_weights(layer.w2_weight[ii]) w2_marlin_in = get_w8a8_int8_marlin_weights(layer.w2_weight[ii])
else: else:
w2_marlin_in = w8a8_nt_kpack2_marlin_weight(layer.w2_weight[ii]) w2_marlin_in = w8a8_nt_kpack2_marlin_weight(layer.w2_weight[ii])
...@@ -236,7 +238,14 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod) ...@@ -236,7 +238,14 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
# (from deepgemm docs) : A value hint (which is a value on CPU) # (from deepgemm docs) : A value hint (which is a value on CPU)
# for the M expectation of each batch, correctly setting this value # for the M expectation of each batch, correctly setting this value
# may lead to better performance. # may lead to better performance.
expected_m = max_num_tokens #expected_m = max_num_tokens
ori_bs = x.shape[0]
expected_m = ori_bs * self.dp_size
# expected_m = (
# x.shape[0] * self.dp_size * topk_ids.shape[1]
# + global_num_experts
# ) // global_num_experts
m_grouped_w8a8_gemm_nt_masked((q_x, a1_scale), m_grouped_w8a8_gemm_nt_masked((q_x, a1_scale),
(w1, w1_scale), (w1, w1_scale),
......
...@@ -466,9 +466,9 @@ def apply_int8_linear( ...@@ -466,9 +466,9 @@ def apply_int8_linear(
m_=m m_=m
#best_config=W8A8_TRITONJSON.triton_json_dict[0][f"{m}_{n}_{k}"] #best_config=W8A8_TRITONJSON.triton_json_dict[0][f"{m}_{n}_{k}"]
elif m<=64: elif m<=64:
m_= (m + 3) & -4 #取值到最近的4的倍数 m_= 64#(m + 3) & -4 #取值到最近的4的倍数
elif m<=160: elif m<=160:
m_=(m + 7) & -8 m_=160#(m + 7) & -8
elif m<200: #256 elif m<200: #256
m_=160 m_=160
......
...@@ -42,7 +42,8 @@ from vllm.config import (CacheConfig, ModelConfig, VllmConfig, ...@@ -42,7 +42,8 @@ from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
from vllm.distributed import (get_ep_group, get_pp_group, get_dp_group, from vllm.distributed import (get_ep_group, get_pp_group, get_dp_group,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE, SharedFusedMoE
from vllm.model_executor.layers.fused_moe.mori_moe.layer import MoriMoE from vllm.model_executor.layers.fused_moe.mori_moe.layer import MoriMoE
from vllm.model_executor.layers.fused_moe.mori_moe.ep_moe_utlis import EPSharedExperts from vllm.model_executor.layers.fused_moe.mori_moe.ep_moe_utlis import EPSharedExperts
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
...@@ -169,7 +170,10 @@ class DeepseekV2MoE(nn.Module): ...@@ -169,7 +170,10 @@ class DeepseekV2MoE(nn.Module):
dp_size = get_dp_group().world_size dp_size = get_dp_group().world_size
self.use_mori_ep = parallel_config.enable_expert_parallel and dp_size > 1 and envs.VLLM_ALL2ALL_BACKEND == 'mori' self.use_mori_ep = parallel_config.enable_expert_parallel and dp_size > 1 and envs.VLLM_ALL2ALL_BACKEND == 'mori'
self.enable_expert_parallel = parallel_config.enable_expert_parallel self.enable_expert_parallel = parallel_config.enable_expert_parallel
self.use_deepep_ll = dp_size > 1 and parallel_config.enable_expert_parallel and \
envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency"
if not self.use_deepep_ll:
moe_cls = FusedMoE if not self.use_mori_ep else MoriMoE moe_cls = FusedMoE if not self.use_mori_ep else MoriMoE
self.experts = moe_cls( self.experts = moe_cls(
num_experts=config.n_routed_experts, num_experts=config.n_routed_experts,
...@@ -198,12 +202,40 @@ class DeepseekV2MoE(nn.Module): ...@@ -198,12 +202,40 @@ class DeepseekV2MoE(nn.Module):
intermediate_size=intermediate_size, intermediate_size=intermediate_size,
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
quant_config=quant_config, quant_config=quant_config,
reduce_results=self.experts.must_reduce_shared_expert_outputs( reduce_results=self.experts.must_reduce_shared_expert_outputs(),
),
prefix=f"{prefix}.shared_experts", prefix=f"{prefix}.shared_experts",
) )
if self.use_mori_ep: else:
self.experts.set_shared_experts(self.shared_experts) if config.n_shared_experts is not None:
intermediate_size = (config.moe_intermediate_size *
config.n_shared_experts)
shared_expert_cls = DeepseekV2MLP if not self.use_mori_ep else EPSharedExperts
self.shared_experts = shared_expert_cls(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
reduce_results=dp_size != self.ep_size,
prefix=f"{prefix}.shared_experts",
)
self.experts = SharedFusedMoE(
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
use_grouped_topk=True,
num_expert_group=config.n_group,
topk_group=config.topk_group,
prefix=f"{prefix}.experts",
scoring_func=config.scoring_func,
e_score_correction_bias=self.gate.e_score_correction_bias,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
routed_scaling_factor=self.routed_scaling_factor,
shared_experts=self.shared_experts)
from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce
self.tbo_all_reduce = tbo_all_reduce self.tbo_all_reduce = tbo_all_reduce
...@@ -215,7 +247,7 @@ class DeepseekV2MoE(nn.Module): ...@@ -215,7 +247,7 @@ class DeepseekV2MoE(nn.Module):
num_tokens, hidden_dim = hidden_states.shape num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim) hidden_states = hidden_states.view(-1, hidden_dim)
if not self.use_mori_ep: if not self.use_mori_ep and not self.use_deepep_ll:
if self.n_shared_experts is not None: if self.n_shared_experts is not None:
if envs.USE_FUSED_RMS_QUANT: if envs.USE_FUSED_RMS_QUANT:
shared_output, new_resi = self.shared_experts(hidden_states, rms_weight, residual, update_hd=True) shared_output, new_resi = self.shared_experts(hidden_states, rms_weight, residual, update_hd=True)
...@@ -250,9 +282,8 @@ class DeepseekV2MoE(nn.Module): ...@@ -250,9 +282,8 @@ class DeepseekV2MoE(nn.Module):
final_hidden_states = final_hidden_states + shared_output \ final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor) * (1. / self.routed_scaling_factor)
else: else:
if not self.use_mori_ep: if self.use_deepep_ll:
final_hidden_states = self.experts( shared_output, final_hidden_states = self.experts(hidden_states=hidden_states,
hidden_states=hidden_states,
router_logits=router_logits) router_logits=router_logits)
if shared_output is not None: if shared_output is not None:
...@@ -263,9 +294,22 @@ class DeepseekV2MoE(nn.Module): ...@@ -263,9 +294,22 @@ class DeepseekV2MoE(nn.Module):
# See DeepseekV2DecoderLayer for more details. # See DeepseekV2DecoderLayer for more details.
final_hidden_states = final_hidden_states + shared_output \ final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor) * (1. / self.routed_scaling_factor)
else: elif self.use_mori_ep:
final_hidden_states = self.experts(hidden_states=hidden_states, final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits) router_logits=router_logits)
else:
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits)
if shared_output is not None:
if hidden_states.dtype != torch.float16:
final_hidden_states = final_hidden_states + shared_output
else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor)
if not self.use_mori_ep: if not self.use_mori_ep:
if self.tp_size > 1: if self.tp_size > 1:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment