Commit 1ae8f58c authored by 王敏
Browse files

[feat]支持deepep低延迟与共享专家overlap

parent bca29c66
......@@ -4780,8 +4780,9 @@ class VllmConfig:
# add for spec decode
if self.speculative_config is not None and self.speculative_config.num_lookahead_slots > 0:
batch_size_capture_list = list(map(lambda x: x * (1 + self.speculative_config.num_lookahead_slots),
mtp_batch_size_capture_list = list(map(lambda x: x * (1 + self.speculative_config.num_lookahead_slots),
batch_size_capture_list))
batch_size_capture_list = sorted(set(batch_size_capture_list + mtp_batch_size_capture_list))
self.compilation_config.init_with_cudagraph_sizes(
batch_size_capture_list)
......
......@@ -192,8 +192,7 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
num_rdma_bytes=num_rdma_bytes,
low_latency_mode=False,
num_qps_per_rank=num_qps_per_rank,
explicitly_destroy=False,
use_default_stream_as_comm_stream=False)
explicitly_destroy=False)
def get_handle(self, kwargs):
......
......@@ -237,10 +237,11 @@ class DeviceCommunicatorBase:
moe_modules = [
module for module in model.modules()
if module.__class__.__name__ == "FusedMoE"
if (module.__class__.__name__ == "FusedMoE"
or module.__class__.__name__ == "SharedFusedMoE")
]
for module in moe_modules:
module.quant_method.init_prepare_finalize(module.moe_config,
module.quant_method.init_prepare_finalize(module, module.moe_config,
module.quant_config)
def dispatch(
......
......@@ -10,6 +10,7 @@ from vllm.model_executor.layers.fused_moe.layer import (
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEActivationFormat, FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
from vllm.triton_utils import HAS_TRITON
_config: Optional[dict[str, Any]] = None
......@@ -38,6 +39,7 @@ __all__ = [
"FusedMoEPrepareAndFinalize",
"override_config",
"get_config",
"SharedFusedMoE",
]
if HAS_TRITON:
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional, Union
from collections.abc import Callable
import deep_ep
import torch
......@@ -44,12 +45,14 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
buffer: deep_ep.Buffer,
max_tokens_per_rank: int,
num_dispatchers: int,
use_fp8_dispatch: bool = False):
use_fp8_dispatch: bool = False,
use_int8_dispatch: bool = False):
super().__init__()
self.buffer = buffer
self.max_tokens_per_rank = max_tokens_per_rank
self.use_fp8_dispatch = use_fp8_dispatch
self.use_int8_dispatch = use_int8_dispatch
# The dispatch function returns a handle that the combine function
# requires. We store the handle here so it is available to the
# combine function.
......@@ -154,7 +157,8 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_ids,
self.max_tokens_per_rank,
num_experts,
use_fp8=self.use_fp8_dispatch,
use_fp8=self.use_fp8_dispatch or self.use_int8_dispatch,
use_int8=self.use_int8_dispatch,
async_finish=False,
return_recv_hook=False)
......@@ -164,11 +168,17 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return (expert_x, expert_x_scale, expert_num_tokens, None, None)
def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
def _finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
apply_weights_and_reduce: bool = True) -> None:
apply_weights_and_reduce: bool,
do_async: bool,
) -> Callable:
do_recv_hook = do_async
assert self.handle is not None
combine_topk_weights = topk_weights
......@@ -177,12 +187,45 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
combine_topk_weights = torch.ones_like(topk_weights)
# TODO (varun) : Enable zero copy mode
_, event, hook = self.buffer.low_latency_combine(
_, _, recv_hook = self.buffer.low_latency_combine(
fused_expert_output,
topk_ids,
combine_topk_weights,
self.handle,
async_finish=False,
zero_copy=False,
return_recv_hook=False,
out=output)
return_recv_hook=do_recv_hook,
out=output,
)
return recv_hook
def finalize_async(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
                   topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                   apply_router_weight_on_input: bool,
                   apply_weights_and_reduce: bool = True) -> Optional[Callable]:
    """Launch the combine step and return its receive hook.

    Forwards every argument to ``_finalize`` with ``do_async=True``, which
    makes ``_finalize`` request a receive hook from the low-latency combine.

    Returns:
        Whatever ``_finalize`` returns: the receive hook, which the caller
        is expected to invoke once it needs ``output``. Callers should be
        prepared for ``None``.
    """
    # The return annotation used to be ``None`` although the hook is
    # returned and consumed by the caller.
    return self._finalize(
        output,
        fused_expert_output,
        topk_weights,
        topk_ids,
        apply_router_weight_on_input,
        apply_weights_and_reduce,
        do_async=True,
    )
def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
             topk_weights: torch.Tensor, topk_ids: torch.Tensor,
             apply_router_weight_on_input: bool,
             apply_weights_and_reduce: bool = True) -> None:
    """Run the combine step in blocking mode.

    Forwards every argument to ``_finalize`` with ``do_async=False``; the
    value ``_finalize`` returns is dropped, so this method returns nothing.
    """
    combine_args = (
        output,
        fused_expert_output,
        topk_weights,
        topk_ids,
        apply_router_weight_on_input,
        apply_weights_and_reduce,
    )
    self._finalize(*combine_args, do_async=False)
......@@ -92,7 +92,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
params_dtype: torch.dtype, **extra_weight_attrs):
raise NotImplementedError
def init_prepare_finalize(self, moe: FusedMoEConfig,
def init_prepare_finalize(self, layer, moe: FusedMoEConfig,
quant_config: Optional[QuantizationConfig]):
all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None
......@@ -171,6 +171,8 @@ class FusedMoEMethodBase(QuantizeMethodBase):
and moe.quant_config.block_shape
== DEEPEP_QUANT_BLOCK_SHAPE)
use_int8_dispatch = False#moe.quant_config.quant_dtype == torch.int8
# Note (varun): Whether to use FP8 dispatch or not needs some
# profiling. Turning it off for now.
prepare_finalize = DeepEPLLPrepareAndFinalize(
......@@ -178,6 +180,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
max_tokens_per_rank=moe.max_num_tokens,
num_dispatchers=all2all_manager.world_size,
use_fp8_dispatch=use_fp8_dispatch,
use_int8_dispatch=use_int8_dispatch,
)
self.topk_indices_dtype = None
......@@ -195,6 +198,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
self.fused_experts = DeepGemmDisabledFusedMoEModularKernel(
prepare_finalize,
experts,
shared_experts=layer.shared_experts if hasattr(layer, "shared_experts") else None,
)
def select_gemm_impl(
......@@ -913,6 +917,10 @@ class FusedMoE(torch.nn.Module):
def use_deepep_ll_kernels(self):
return self.moe_parallel_config.use_deepep_ll_kernels
# Shared-experts module attached to this layer. The base layer has none, so
# this always evaluates to None; forward() checks this value to choose between
# the moe_forward and moe_forward_shared custom ops.
@property
def shared_experts(self) -> Optional[torch.nn.Module]:
return None
def _load_per_tensor_weight_scale(self, shard_id: str,
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
......@@ -1456,8 +1464,12 @@ class FusedMoE(torch.nn.Module):
if current_platform.is_tpu():
return self.forward_impl(hidden_states, router_logits)
else:
if self.shared_experts is None:
return torch.ops.vllm.moe_forward(hidden_states, router_logits,
self.layer_name, shared_output)
else:
return torch.ops.vllm.moe_forward_shared(hidden_states, router_logits,
self.layer_name, shared_output)
def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
full_router_logits: torch.Tensor):
......@@ -1668,3 +1680,34 @@ direct_register_custom_op(
dispatch_key=current_platform.dispatch_key,
tags=(torch.Tag.needs_fixed_stride_order, ),
)
def moe_forward_shared(
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    layer_name: str,
    shared_output: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor]:
    """Custom-op body for MoE layers that carry shared experts.

    Resolves the layer registered under ``layer_name`` in the current
    forward context's ``no_compile_layers`` mapping, checks that it exposes
    shared experts, and returns the result of its ``forward_impl``.
    """
    layer = get_forward_context().no_compile_layers[layer_name]
    assert layer.shared_experts is not None
    return layer.forward_impl(hidden_states, router_logits, shared_output)
def moe_forward_shared_fake(
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    layer_name: str,
    shared_output: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fake implementation paired with ``moe_forward_shared``.

    Performs no computation: it returns two freshly allocated, uninitialized
    tensors shaped like ``hidden_states``, ordered (shared output, fused
    output).
    """
    # Two separate allocations so the outputs never alias each other.
    return torch.empty_like(hidden_states), torch.empty_like(hidden_states)
# Register the shared-experts MoE forward as a custom op, with
# moe_forward_shared_fake as its fake implementation.
# dispatch_key is passed explicitly so this registration matches the
# preceding direct_register_custom_op call in this file instead of relying
# on the helper's default key.
direct_register_custom_op(
    op_name="moe_forward_shared",
    op_func=moe_forward_shared,
    mutates_args=["hidden_states"],
    fake_impl=moe_forward_shared_fake,
    dispatch_key=current_platform.dispatch_key,
    tags=(torch.Tag.needs_fixed_stride_order, ),
)
\ No newline at end of file
......@@ -761,6 +761,20 @@ class FusedMoEModularKernel(torch.nn.Module):
return output
_aux_stream: torch.cuda.Stream | None = None
def aux_stream() -> torch.cuda.Stream | None:
"""
Ensures aux_stream is initialized only once
"""
global _aux_stream
# TODO: validate this works properly on ROCm platform.
if _aux_stream is None:
_aux_stream = torch.cuda.Stream()
return _aux_stream
@final
class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
"""
......@@ -779,10 +793,17 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
fused_experts: CustomizedFusedMoEPermuteExpertsUnpermute,
shared_experts: Optional[torch.nn.Module] = None,
):
super().__init__()
self.prepare_finalize = prepare_finalize
self.fused_experts = fused_experts
self.shared_experts = shared_experts
if self.shared_experts is not None:
self.shared_experts_stream = aux_stream()
self.shared_experts_overlap_event = torch.cuda.Event()
# assert prepare_finalize.activation_format == \
# fused_experts.activation_formats[0], (
# f"{prepare_finalize.__class__.__name__}."
......@@ -849,7 +870,11 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
"""
a1 = hidden_states
output = a1 if inplace else torch.zeros_like(a1)
if inplace and self.shared_experts is None:
output = hidden_states
else:
output = torch.zeros_like(hidden_states)
local_num_experts = w1.size(0)
if global_num_experts == -1:
......@@ -898,7 +923,21 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
routed_scaling_factor=routed_scaling_factor,
)
shared_output = None
if self.shared_experts is None:
self.prepare_finalize.finalize(output, fused_out, topk_weights,
topk_ids, apply_router_weight_on_input, apply_weights_and_reduce=False)
else:
hook = self.prepare_finalize.finalize_async(output, fused_out, topk_weights,
topk_ids, apply_router_weight_on_input, apply_weights_and_reduce=False)
if self.shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
if hook is not None:
hook()
if self.shared_experts is not None:
return (shared_output, output)
return output
......@@ -189,6 +189,8 @@ class MoriMoE(FusedMoE):
enable_eplb: bool = False,
num_redundant_experts: int = 0,
moe_permute_fusion: bool = False,
shared_experts: Optional[torch.nn.Module] = None,
use_overlapped: bool = False,
):
super().__init__(num_experts, top_k, hidden_size,
intermediate_size, params_dtype,
......@@ -214,6 +216,7 @@ class MoriMoE(FusedMoE):
routed_scaling_factor=self.routed_scaling_factor,
apply_router_weight_on_input=self.apply_router_weight_on_input
)
self.shared_experts = shared_experts
local_expert_indices_offset = (
self.ep_rank * self.local_num_experts
......@@ -222,8 +225,6 @@ class MoriMoE(FusedMoE):
local_expert_indices_offset + i for i in range(self.local_num_experts)
]
self.shared_experts = None
self.scales = None
self.use_int8_dispatch = True
......@@ -267,10 +268,6 @@ class MoriMoE(FusedMoE):
return _MORI_OP
def set_shared_experts(self, shared_experts: torch.nn.Module):
if self.shared_experts is None:
self.shared_experts = shared_experts
def create_quant_method(self, moe, quant_config, prefix):
# Note: get_quant_method will look at the layer's local_num_experts
# for heuristic purposes, so it must be initialized first.
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
# TODO(bnell): Add shared + fused combo function? e.g. +
class SharedFusedMoE(FusedMoE):
    """
    A FusedMoE operation that also computes the results of shared experts.
    If an all2all communicator is being used the shared expert computation
    can be interleaved with the fused all2all dispatch communication step.
    """

    def __init__(
        self,
        shared_experts: torch.nn.Module,
        use_overlapped: bool = True,
        **kwargs,
    ):
        """
        Args:
            shared_experts: Module applied to the hidden states to produce
                the shared-expert output.
            use_overlapped: When True the shared experts are exposed through
                the ``shared_experts`` property; when False they are run
                directly in ``forward``.
            **kwargs: Forwarded unchanged to ``FusedMoE.__init__``.
        """
        super().__init__(**kwargs)
        self._shared_experts = shared_experts
        self.use_overlapped = use_overlapped

    @property
    def shared_experts(self) -> Optional[torch.nn.Module]:
        """The shared-experts module when overlap is enabled, else None."""
        return self._shared_experts if self.use_overlapped else None

    def forward(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute the shared-expert and routed-expert outputs.

        Returns:
            ``(shared_out, fused_out)``. In overlapped mode the value of the
            parent ``forward`` is returned as-is. In non-overlapped mode the
            pair is assembled here.
        """
        if self.use_overlapped:
            # The parent forward sees a non-None ``shared_experts`` and is
            # responsible for producing both outputs.
            return super().forward(
                hidden_states=hidden_states,
                router_logits=router_logits,
            )

        shared_out = self._shared_experts(hidden_states)
        # Reduce outputs if necessary, since the MLP should
        # have been created with reduce_results=False.
        if (self.reduce_results and self.tp_size > 1
                and self.must_reduce_shared_expert_outputs()):
            shared_out = tensor_model_parallel_all_reduce(shared_out)

        fused_out = super().forward(
            hidden_states=hidden_states,
            router_logits=router_logits,
        )
        # Previously ``shared_out`` was computed and then dropped, so the
        # shared-expert contribution was lost on this path and callers that
        # unpack a pair received a single tensor.
        return shared_out, fused_out
......@@ -82,11 +82,13 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
self.fused_experts = self.fused_moe_forward
vllm_config = get_current_vllm_config()
parallel_config = vllm_config.parallel_config
dp_size = get_dp_group().world_size
self.use_deepep = dp_size > 1 and parallel_config.enable_expert_parallel and \
self.dp_size = get_dp_group().world_size
self.use_deepep = self.dp_size > 1 and parallel_config.enable_expert_parallel and \
(envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency")
self.use_deepep_ll = self.use_deepep and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency"
if self.use_deepep:
all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None
......@@ -97,7 +99,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
if self.use_deepep:
if self.use_deepep_ll:
self.N = 2 * intermediate_size_per_partition
self.K = hidden_size
......@@ -151,7 +153,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
w1_marlin_list = []
for ii in range(layer.w13_weight.shape[0]):
if not self.use_deepep:
if not self.use_deepep_ll:
w1_marlin_in = get_w8a8_int8_marlin_weights(layer.w13_weight[ii])
else:
w1_marlin_in = w8a8_nt_kpack2_marlin_weight(layer.w13_weight[ii])
......@@ -162,7 +164,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
del w1_marlin_list
w2_marlin_list = []
for ii in range(layer.w2_weight.shape[0]):
if not self.use_deepep:
if not self.use_deepep_ll:
w2_marlin_in = get_w8a8_int8_marlin_weights(layer.w2_weight[ii])
else:
w2_marlin_in = w8a8_nt_kpack2_marlin_weight(layer.w2_weight[ii])
......@@ -236,7 +238,14 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
# (from deepgemm docs) : A value hint (which is a value on CPU)
# for the M expectation of each batch, correctly setting this value
# may lead to better performance.
expected_m = max_num_tokens
#expected_m = max_num_tokens
ori_bs = x.shape[0]
expected_m = ori_bs * self.dp_size
# expected_m = (
# x.shape[0] * self.dp_size * topk_ids.shape[1]
# + global_num_experts
# ) // global_num_experts
m_grouped_w8a8_gemm_nt_masked((q_x, a1_scale),
(w1, w1_scale),
......
......@@ -466,9 +466,9 @@ def apply_int8_linear(
m_=m
#best_config=W8A8_TRITONJSON.triton_json_dict[0][f"{m}_{n}_{k}"]
elif m<=64:
m_= (m + 3) & -4 #取值到最近的4的倍数
m_= 64#(m + 3) & -4 #取值到最近的4的倍数
elif m<=160:
m_=(m + 7) & -8
m_=160#(m + 7) & -8
elif m<200: #256
m_=160
......
......@@ -42,7 +42,8 @@ from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
from vllm.distributed import (get_ep_group, get_pp_group, get_dp_group,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe import FusedMoE, SharedFusedMoE
from vllm.model_executor.layers.fused_moe.mori_moe.layer import MoriMoE
from vllm.model_executor.layers.fused_moe.mori_moe.ep_moe_utlis import EPSharedExperts
from vllm.model_executor.layers.layernorm import RMSNorm
......@@ -169,7 +170,10 @@ class DeepseekV2MoE(nn.Module):
dp_size = get_dp_group().world_size
self.use_mori_ep = parallel_config.enable_expert_parallel and dp_size > 1 and envs.VLLM_ALL2ALL_BACKEND == 'mori'
self.enable_expert_parallel = parallel_config.enable_expert_parallel
self.use_deepep_ll = dp_size > 1 and parallel_config.enable_expert_parallel and \
envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency"
if not self.use_deepep_ll:
moe_cls = FusedMoE if not self.use_mori_ep else MoriMoE
self.experts = moe_cls(
num_experts=config.n_routed_experts,
......@@ -198,12 +202,40 @@ class DeepseekV2MoE(nn.Module):
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
reduce_results=self.experts.must_reduce_shared_expert_outputs(
),
reduce_results=self.experts.must_reduce_shared_expert_outputs(),
prefix=f"{prefix}.shared_experts",
)
if self.use_mori_ep:
self.experts.set_shared_experts(self.shared_experts)
else:
if config.n_shared_experts is not None:
intermediate_size = (config.moe_intermediate_size *
config.n_shared_experts)
shared_expert_cls = DeepseekV2MLP if not self.use_mori_ep else EPSharedExperts
self.shared_experts = shared_expert_cls(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
reduce_results=dp_size != self.ep_size,
prefix=f"{prefix}.shared_experts",
)
self.experts = SharedFusedMoE(
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
use_grouped_topk=True,
num_expert_group=config.n_group,
topk_group=config.topk_group,
prefix=f"{prefix}.experts",
scoring_func=config.scoring_func,
e_score_correction_bias=self.gate.e_score_correction_bias,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
routed_scaling_factor=self.routed_scaling_factor,
shared_experts=self.shared_experts)
from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce
self.tbo_all_reduce = tbo_all_reduce
......@@ -215,7 +247,7 @@ class DeepseekV2MoE(nn.Module):
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
if not self.use_mori_ep:
if not self.use_mori_ep and not self.use_deepep_ll:
if self.n_shared_experts is not None:
if envs.USE_FUSED_RMS_QUANT:
shared_output, new_resi = self.shared_experts(hidden_states, rms_weight, residual, update_hd=True)
......@@ -250,9 +282,8 @@ class DeepseekV2MoE(nn.Module):
final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor)
else:
if not self.use_mori_ep:
final_hidden_states = self.experts(
hidden_states=hidden_states,
if self.use_deepep_ll:
shared_output, final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
if shared_output is not None:
......@@ -263,9 +294,22 @@ class DeepseekV2MoE(nn.Module):
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor)
else:
elif self.use_mori_ep:
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
else:
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits)
if shared_output is not None:
if hidden_states.dtype != torch.float16:
final_hidden_states = final_hidden_states + shared_output
else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor)
if not self.use_mori_ep:
if self.tp_size > 1:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment