Commit b91b3028 authored by wanglong3's avatar wanglong3
Browse files

feat: Support shared experts fusion.

feat: support moe sum when topk==9

bugfix: Fix mtp model load error when eable shared experts fusion.
parent 9f68733a
......@@ -173,6 +173,39 @@ __global__ void moe_sum_kernel(
}
}
template <typename scalar_t, int TOPK, int SPLIT_D, int BLOCK_DIM>
__global__ void moe_sum_sharedmem(
scalar_t* __restrict__ out,
const scalar_t* __restrict__ input,
const int d) {
const int token_idx = blockIdx.x / SPLIT_D;
const int sub_block = blockIdx.x % SPLIT_D;
const int d_per_block = (d + SPLIT_D - 1) / SPLIT_D;
const int64_t d_start = sub_block * d_per_block;
const int64_t d_end = min(d_start + d_per_block, d);
const int64_t token_offset = token_idx * TOPK * d;
__shared__ __align__(16) scalar_t sem_input[TOPK][BLOCK_DIM];
for (int64_t idx = d_start + threadIdx.x; idx < d_end; idx += blockDim.x) {
#pragma unroll
for (int k = 0; k < TOPK; ++k) {
sem_input[k][threadIdx.x] =
input[token_offset + k * d + idx];
}
__syncthreads();
scalar_t x = 0;
#pragma unroll
for (int k = 0; k < TOPK; ++k) {
x += sem_input[k][threadIdx.x];
}
out[token_idx * d + idx] = x;
}
}
template <typename scalar_t, int TOPK, int SPLIT_D, int BLOCK_DIM>
__global__ void moe_sum_sharedmem_topk8(
scalar_t* __restrict__ out,
......@@ -440,7 +473,14 @@ void moe_sum_opt1(torch::Tensor& input, // [num_tokens, topk, hidden_size]
output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
hidden_size);
});
break;
case 9:
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_sharedmem", [&]{
vllm::moe::moe_sum_sharedmem<scalar_t, 9, 9, 256><<<num_tokens * 9, 256, 0, stream>>>(
output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
hidden_size);
});
break;
default:
......
......@@ -216,6 +216,7 @@ if TYPE_CHECKING:
VLLM_MOE_ROUTER_CAPTURE_MAX_LAYERS: int = 0
VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_GT: int = -1
VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_LT: int = -1
VLLM_ENABLE_SHARED_EXPERTS_FUSION: bool = False
def get_default_cache_root():
return os.getenv(
......@@ -1351,9 +1352,18 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: (os.environ.get("VLLM_V1_FAST_TOKEN_ID_COPY", "False").lower() in
("true", "1")),
# shared experts overlap with routed experts
# VLLM_DISABLE_SHARED_EXPERTS_STREAM = 1 disable shared experts overlap
# VLLM_DISABLE_SHARED_EXPERTS_STREAM = 0 enable shared experts overlap
"VLLM_DISABLE_SHARED_EXPERTS_STREAM": lambda: bool(
int(os.getenv("VLLM_DISABLE_SHARED_EXPERTS_STREAM", "1"))
),
# shared experts fusion
# VLLM_ENABLE_SHARED_EXPERTS_FUSION = 1 enable shared experts fusion
# VLLM_ENABLE_SHARED_EXPERTS_FUSION = 0 disable shared experts fusion
"VLLM_ENABLE_SHARED_EXPERTS_FUSION": lambda: bool(
int(os.getenv("VLLM_ENABLE_SHARED_EXPERTS_FUSION", "0"))
),
# W8A8 GEMM backend selection for vLLM quantized models.
# lightop/triton: 1
......
......@@ -1304,7 +1304,9 @@ def grouped_topk(
num_expert_group: int = 0,
topk_group: int = 0,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None
e_score_correction_bias: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = 1.0,
num_fused_shared_experts: Optional[int] = 0
) -> tuple[torch.Tensor, torch.Tensor]:
assert hidden_states.size(0) == gating_output.size(0), (
......@@ -1317,7 +1319,7 @@ def grouped_topk(
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")
num_token = scores.size(0)
num_token, num_experts = scores.shape
if e_score_correction_bias is not None:
# Store original scores before applying correction bias. We use biased
# scores for expert selection but original scores for routing weights
......@@ -1348,8 +1350,20 @@ def grouped_topk(
dim=-1,
sorted=False)
if num_fused_shared_experts != 0:
topk_ids[:, -1] = num_experts
if routed_scaling_factor is not None:
topk_weights[:, -1] = (
topk_weights[:, :-1].sum(dim=-1) / routed_scaling_factor
)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
topk_weights_sum = (
topk_weights.sum(dim=-1, keepdim=True)
if num_fused_shared_experts == 0
else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
)
topk_weights = topk_weights / topk_weights_sum
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
......@@ -2135,12 +2149,17 @@ def fused_moe(
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
if use_grouped_topk:
assert num_expert_group is not None and topk_group is not None
topk_weights, topk_ids = grouped_topk(hidden_states, gating_output,
topk, renormalize,
num_expert_group, topk_group)
topk_weights, topk_ids = grouped_topk(
hidden_states = hidden_states,
gating_output = gating_output,
topk = topk,
renormalize = renormalize,
num_expert_group = num_expert_group,
topk_group = topk_group,
routed_scaling_factor = routed_scaling_factor,
num_fused_shared_experts = 1)
elif custom_routing_function is None:
topk_weights, topk_ids, token_expert_indices = fused_topk(
hidden_states, gating_output, topk, renormalize)
......
......@@ -1560,8 +1560,9 @@ class FusedMoE(torch.nn.Module):
num_expert_group,
topk_group,
top_k,
0,
routed_scaling_factor,
# TODO also required num of shared expert is not None
(1 if envs.VLLM_ENABLE_SHARED_EXPERTS_FUSION else 0),
routed_scaling_factor
)
else:
topk_weights, topk_ids = ops.moe_fused_gate(
......@@ -1582,7 +1583,11 @@ class FusedMoE(torch.nn.Module):
num_expert_group=num_expert_group,
topk_group=topk_group,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
e_score_correction_bias = e_score_correction_bias,
routed_scaling_factor = routed_scaling_factor,
# TODO also required num of shared expert is not None
num_fused_shared_experts = (1 if envs.VLLM_ENABLE_SHARED_EXPERTS_FUSION else 0)
)
if indices_type is not None:
topk_ids = topk_ids.to(dtype=indices_type)
elif custom_routing_function is None:
......
......@@ -252,7 +252,8 @@ def get_model_architecture(
if architectures in [['DeepseekV3ForCausalLM'], ['DeepSeekMTPModel']]:
if not envs.is_set("VLLM_USE_LIGHTOP"):
os.environ['VLLM_USE_LIGHTOP'] = '1'
if not envs.is_set("VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD"):
if not envs.is_set("VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD") \
and not envs.VLLM_ENABLE_SHARED_EXPERTS_FUSION:
os.environ['VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD'] = '1'
if not envs.is_set("VLLM_USE_OPT_CAT"):
os.environ['VLLM_USE_OPT_CAT'] = '1'
......@@ -301,7 +302,8 @@ def get_model_architecture(
if architectures in [['DeepseekV3ForCausalLM'], ['DeepSeekMTPModel']]:
if not envs.is_set("VLLM_USE_LIGHTOP"):
os.environ['VLLM_USE_LIGHTOP'] = '1'
if not envs.is_set("VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD"):
if not envs.is_set("VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD") \
and not envs.VLLM_ENABLE_SHARED_EXPERTS_FUSION:
os.environ['VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD'] = '1'
if not envs.is_set("VLLM_USE_OPT_CAT"):
os.environ['VLLM_USE_OPT_CAT'] = '1'
......@@ -370,7 +372,6 @@ def get_model_architecture(
mixtral_supported = [
"fp8", "compressed-tensors", "gptq_marlin", "awq_marlin", "quark"
]
vllm_supported_archs = ModelRegistry.get_supported_archs()
vllm_not_supported = not any(arch in vllm_supported_archs
for arch in architectures)
......
......@@ -235,12 +235,15 @@ class DeepSeekMTP(nn.Module, SupportsPP):
("qa_kva_proj", "kv_a_proj_with_mqa", 1)
]
stacked_params_mapping += fused_params_mapping
enable_shared_experts_fusion = envs.VLLM_ENABLE_SHARED_EXPERTS_FUSION and self.config.n_shared_experts > 0
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts)
num_experts=self.config.n_routed_experts + (
self.config.n_shared_experts
if enable_shared_experts_fusion else 0
))
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
......@@ -251,6 +254,16 @@ class DeepSeekMTP(nn.Module, SupportsPP):
if spec_layer is None:
continue
name = self._rewrite_spec_layer_name(spec_layer, name)
# Assuame num of shared experts is only one.
if envs.VLLM_ENABLE_SHARED_EXPERTS_FUSION \
and self.config.n_shared_experts > 0 \
and "mlp.shared_experts" in name:
name = name.replace(
"mlp.shared_experts",
f"mlp.experts.{self.config.n_routed_experts}",
)
for (param_name, weight_name, shard_id) in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
......@@ -273,7 +286,9 @@ class DeepSeekMTP(nn.Module, SupportsPP):
param = params_dict[name]
weight_loader = param.weight_loader
if envs.USE_FUSED_RMS_QUANT and envs.VLLM_USE_FUSED_QA_KVA_GEMM and (("q_a_proj" in old_weight_name) or ("kv_a_proj_with_mqa" in old_weight_name)):
if envs.USE_FUSED_RMS_QUANT \
and envs.VLLM_USE_FUSED_QA_KVA_GEMM \
and (("q_a_proj" in old_weight_name) or ("kv_a_proj_with_mqa" in old_weight_name)):
weight_loader(param, loaded_weight, old_weight_name)
else:
weight_loader(param, loaded_weight, shard_id)
......@@ -362,339 +377,3 @@ class DeepSeekMTP(nn.Module, SupportsPP):
name = name.replace(f"model.layers.{spec_layer}.",
f"model.layers.{spec_layer}.mtp_block.")
return name
\ No newline at end of file
# # SPDX-License-Identifier: Apache-2.0
# # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# import os
# import re
# from collections.abc import Iterable
# from typing import Iterable, Optional
# import torch
# import torch.nn as nn
# from transformers import PretrainedConfig
# from vllm.config import CacheConfig, ModelConfig, VllmConfig
# from vllm.model_executor.layers.fused_moe import FusedMoE
# from vllm.model_executor.layers.layernorm import RMSNorm
# from vllm.model_executor.layers.logits_processor import LogitsProcessor
# from vllm.model_executor.layers.quantization import QuantizationConfig
# from vllm.model_executor.layers.vocab_parallel_embedding import (
# ParallelLMHead, VocabParallelEmbedding)
# from vllm.model_executor.model_loader.weight_utils import default_weight_loader
# from vllm.model_executor.sampling_metadata import SamplingMetadata
# from vllm.sequence import IntermediateTensors
# from vllm.compilation.decorators import support_torch_compile
# from .deepseek_v2 import (DeepseekV2DecoderLayer,
# get_spec_layer_idx_from_weight_name)
# from .interfaces import SupportsPP
# from .utils import maybe_prefix
# from vllm import _custom_ops as ops
# from vllm.model_executor.layers.quantization.blockwise_int8 import BlockInt8Config
# class SharedHead(nn.Module):
# def __init__(
# self,
# config: PretrainedConfig,
# quant_config: Optional[QuantizationConfig] = None,
# ) -> None:
# super().__init__()
# self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# self.head = ParallelLMHead(config.vocab_size,
# config.hidden_size,
# quant_config=quant_config)
# def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# return self.norm(hidden_states)
# class DeepSeekMultiTokenPredictorLayer(nn.Module):
# def __init__(
# self,
# config: PretrainedConfig,
# prefix: str,
# model_config: ModelConfig,
# cache_config: Optional[CacheConfig] = None,
# quant_config: Optional[QuantizationConfig] = None,
# ) -> None:
# super().__init__()
# self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# self.eh_proj = nn.Linear(config.hidden_size * 2,
# config.hidden_size,
# bias=False)
# self.shared_head = SharedHead(config=config, quant_config=quant_config)
# self.mtp_block = DeepseekV2DecoderLayer(config, prefix, model_config,
# cache_config, quant_config)
# def forward(
# self,
# input_ids: torch.Tensor,
# positions: torch.Tensor,
# previous_hidden_states: torch.Tensor,
# inputs_embeds: Optional[torch.Tensor] = None,
# spec_step_index: int = 0,
# ) -> torch.Tensor:
# assert inputs_embeds is not None
# # masking inputs at position 0, as not needed by MTP
# inputs_embeds[positions == 0] = 0
# inputs_embeds = self.enorm(inputs_embeds)
# previous_hidden_states = self.hnorm(previous_hidden_states)
# hidden_states = self.eh_proj(
# torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
# hidden_states, residual = self.mtp_block(positions=positions,
# hidden_states=hidden_states,
# residual=None)
# hidden_states = residual + hidden_states
# return hidden_states
# class DeepSeekMultiTokenPredictor(nn.Module):
# def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
# super().__init__()
# config = vllm_config.model_config.hf_config
# self.mtp_start_layer_idx = config.num_hidden_layers
# self.num_mtp_layers = config.num_nextn_predict_layers
# # to map the exact layer index from weights
# self.layers = torch.nn.ModuleDict({
# str(idx):
# DeepSeekMultiTokenPredictorLayer(
# config,
# f"{prefix}.layers.{idx}",
# model_config=vllm_config.model_config,
# cache_config=vllm_config.cache_config,
# quant_config=vllm_config.quant_config,
# )
# for idx in range(self.mtp_start_layer_idx,
# self.mtp_start_layer_idx + self.num_mtp_layers)
# })
# self.embed_tokens = VocabParallelEmbedding(
# config.vocab_size,
# config.hidden_size,
# )
# self.logits_processor = LogitsProcessor(config.vocab_size)
# def forward(
# self,
# input_ids: torch.Tensor,
# positions: torch.Tensor,
# previous_hidden_states: torch.Tensor,
# inputs_embeds: Optional[torch.Tensor] = None,
# spec_step_idx: int = 0,
# ) -> torch.Tensor:
# if inputs_embeds is None:
# inputs_embeds = self.embed_tokens(input_ids)
# current_step_idx = (spec_step_idx % self.num_mtp_layers)
# return self.layers[str(self.mtp_start_layer_idx + current_step_idx)](
# input_ids,
# positions,
# previous_hidden_states,
# inputs_embeds,
# current_step_idx,
# )
# def compute_logits(
# self,
# hidden_states: torch.Tensor,
# sampling_metadata: SamplingMetadata,
# spec_step_idx: int = 0,
# ) -> torch.Tensor:
# current_step_idx = (spec_step_idx % self.num_mtp_layers)
# mtp_layer = self.layers[str(self.mtp_start_layer_idx +
# current_step_idx)]
# logits = self.logits_processor(mtp_layer.shared_head.head,
# mtp_layer.shared_head(hidden_states),
# sampling_metadata)
# return logits
# @support_torch_compile
# class DeepSeekMTP(nn.Module, SupportsPP):
# def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
# super().__init__()
# self.config = vllm_config.model_config.hf_config
# quant_config = vllm_config.quant_config
# self.quant_method = None
# if quant_config is not None:
# self.quant_method = quant_config.get_name()
# os.environ['LLAMA_NN'] = '0'
# os.environ['LM_NN'] = '0'
# # The AWQ layer of MTP uses BlockInt8W8A8.
# if self.quant_method == "moe_wna16" or self.quant_method == "awq_marlin":
# vllm_config.quant_config = BlockInt8Config(is_checkpoint_int8_serialized=True, weight_block_size=[128,128])
# self.model = DeepSeekMultiTokenPredictor(vllm_config=vllm_config,
# prefix=maybe_prefix(
# prefix, "model"))
# self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
# def forward(
# self,
# input_ids: torch.Tensor,
# positions: torch.Tensor,
# previous_hidden_states: torch.Tensor,
# intermediate_tensors: Optional[IntermediateTensors] = None,
# inputs_embeds: Optional[torch.Tensor] = None,
# spec_step_idx: int = 0,
# ) -> torch.Tensor:
# hidden_states = self.model(input_ids, positions,
# previous_hidden_states, inputs_embeds,
# spec_step_idx)
# return hidden_states
# def compute_logits(
# self,
# hidden_states: torch.Tensor,
# sampling_metadata: SamplingMetadata,
# spec_step_idx: int = 0,
# ) -> Optional[torch.Tensor]:
# return self.model.compute_logits(hidden_states, sampling_metadata,
# spec_step_idx)
# def load_weights(self, weights: Iterable[tuple[str,
# torch.Tensor]]) -> set[str]:
# stacked_params_mapping = [
# ("gate_up_proj", "gate_proj", 0),
# ("gate_up_proj", "up_proj", 1),
# ]
# expert_params_mapping = FusedMoE.make_expert_params_mapping(
# ckpt_gate_proj_name="gate_proj",
# ckpt_down_proj_name="down_proj",
# ckpt_up_proj_name="up_proj",
# num_experts=self.config.n_routed_experts)
# params_dict = dict(self.named_parameters())
# loaded_params: set[str] = set()
# for name, loaded_weight in weights:
# if "rotary_emb.inv_freq" in name:
# continue
# spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
# if spec_layer is None:
# continue
# name = self._rewrite_spec_layer_name(spec_layer, name)
# for (param_name, weight_name, shard_id) in stacked_params_mapping:
# # Skip non-stacked layers and experts (experts handled below).
# if weight_name not in name:
# continue
# # We have mlp.experts[0].gate_proj in the checkpoint.
# # Since we handle the experts below in expert_params_mapping,
# # we need to skip here BEFORE we update the name, otherwise
# # name will be updated to mlp.experts[0].gate_up_proj, which
# # will then be updated below in expert_params_mapping
# # for mlp.experts[0].gate_gate_up_proj, which breaks load.
# if (("mlp.experts." in name) and name not in params_dict):
# continue
# name = name.replace(weight_name, param_name)
# # Skip loading extra bias for GPTQ models.
# if name.endswith(".bias") and name not in params_dict:
# continue
# param = params_dict[name]
# weight_loader = param.weight_loader
# weight_loader(param, loaded_weight, shard_id)
# break
# else:
# for mapping in expert_params_mapping:
# param_name, weight_name, expert_id, shard_id = mapping
# if weight_name not in name:
# continue
# name = name.replace(weight_name, param_name)
# param = params_dict[name]
# weight_loader = param.weight_loader
# weight_loader(param,
# loaded_weight,
# name,
# shard_id=shard_id,
# expert_id=expert_id)
# break
# else:
# # Skip loading extra bias for GPTQ models.
# if name.endswith(".bias") and name not in params_dict:
# continue
# # According to DeepSeek-V3 Technical Report, MTP modules
# # shares embedding layer. We only load the first weights.
# if (spec_layer != self.model.mtp_start_layer_idx
# and ".layers" not in name):
# continue
# param = params_dict[name]
# weight_loader = getattr(param, "weight_loader",
# default_weight_loader)
# weight_loader(param, loaded_weight)
# loaded_params.add(name)
# if self.use_llama_nn and self.quant_method is None:
# lay_key_words = [
# "self_attn.eh_proj.weight",
# "self_attn.q_proj.weight",
# "self_attn.q_a_proj.weight",
# "self_attn.q_b_proj.weight",
# "self_attn.kv_a_proj_with_mqa.weight",
# "self_attn.kv_b_proj.weight",
# "self_attn.o_proj.weight",
# "mlp.gate_up_proj.weight",
# "mlp.down_proj.weight",
# "mlp.gate.weight",
# "shared_experts.gate_up_proj.weight",
# "shared_experts.down_proj.weight",
# "shared_head.head.weight",
# ]
# combined_words = "|".join(lay_key_words)
# for layername in loaded_params:
# weight = params_dict[layername]
# matches = re.findall(combined_words, layername)
# if matches:
# _weight = torch.zeros_like(weight.data)
# ori_shape =_weight.shape
# ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1])
# weight.data.copy_(_weight)
# weight.data=weight.data.reshape(ori_shape[1],-1)
# return loaded_params
# def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
# """
# Rewrite the weight name to match the format of the original model.
# Add .mtp_block for modules in transformer layer block for spec layer
# and rename shared layer weights to be top level.
# """
# spec_layer_weight_names = [
# "embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head"
# ]
# shared_weight_names = ["embed_tokens"]
# spec_layer_weight = False
# shared_weight = False
# for weight_name in spec_layer_weight_names:
# if weight_name in name:
# spec_layer_weight = True
# if weight_name in shared_weight_names:
# shared_weight = True
# break
# if not spec_layer_weight:
# # treat rest weights as weights for transformer layer block
# name = name.replace(f"model.layers.{spec_layer}.",
# f"model.layers.{spec_layer}.mtp_block.")
# elif shared_weight:
# # treat shared weights as top level weights
# name = name.replace(f"model.layers.{spec_layer}.", "model.")
# return name
......@@ -74,7 +74,6 @@ from vllm import _custom_ops as ops
from vllm.utils import W8a8GetCacheJSON
from lmslim.quantize.quant_ops import lm_faster_rmsquant
class DeepseekV2MLP(nn.Module):
def __init__(
......@@ -194,8 +193,10 @@ class DeepseekV2MoE(nn.Module):
envs.VLLM_ALL2ALL_BACKEND == "deepep_auto")
self.enable_shared_experts_overlap = False
self.enable_shared_experts_fusion = (self.n_shared_experts != 0 and envs.VLLM_ENABLE_SHARED_EXPERTS_FUSION)
if not self.use_deepep:
if config.n_shared_experts is not None:
if config.n_shared_experts is not None and not envs.VLLM_ENABLE_SHARED_EXPERTS_FUSION:
intermediate_size = (config.moe_intermediate_size *
config.n_shared_experts)
self.shared_experts = DeepseekV2MLP(
......@@ -230,9 +231,13 @@ class DeepseekV2MoE(nn.Module):
num_redundant_experts=self.n_redundant_experts,
routed_scaling_factor=self.routed_scaling_factor)
else:
num_fake_experts = (
config.n_shared_experts
if self.enable_shared_experts_fusion
else 0)
self.experts = FusedMoE(
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
num_experts=config.n_routed_experts + (num_fake_experts),
top_k=config.num_experts_per_tok + min(num_fake_experts, 1),
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
......@@ -278,7 +283,11 @@ class DeepseekV2MoE(nn.Module):
routed_scaling_factor=self.routed_scaling_factor,
shared_experts=self.shared_experts)
self.run_shared_expert_singlely = (self.n_shared_experts is not None and not self.enable_shared_experts_overlap)
self.run_shared_expert_singlely = (
self.n_shared_experts is not None
and not self.enable_shared_experts_overlap
and not envs.VLLM_ENABLE_SHARED_EXPERTS_FUSION
)
from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce
self.tbo_all_reduce = tbo_all_reduce
......@@ -292,6 +301,7 @@ class DeepseekV2MoE(nn.Module):
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# For shared experts overlap optimization.
def shared_exprts_overlap_pass(
hidden_states: torch.Tensor, router_logits: torch.Tensor,
iqis: Optional[tuple[torch.Tensor, torch.Tensor]] = None
......@@ -308,6 +318,21 @@ class DeepseekV2MoE(nn.Module):
i_q = i_q,
i_s = i_s)
# For shared experts fusion optimization.
def shared_exprts_fusion_pass(
hidden_states: torch.Tensor, router_logits: torch.Tensor,
iqis: Optional[tuple[torch.Tensor, torch.Tensor]] = None
) -> torch.Tensor:
i_q, i_s = None, None
if envs.USE_FUSED_RMS_QUANT:
assert iqis is not None
i_q, i_s = iqis[0], iqis[1]
return self.experts(
hidden_states = hidden_states,
router_logits = router_logits,
i_q = i_q,
i_s = i_s)
if envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and xqxs is not None:
if self.n_shared_experts is not None and not self.enable_shared_experts_overlap:
......@@ -321,7 +346,9 @@ class DeepseekV2MoE(nn.Module):
router_logits=router_logits,
shared_output=shared_output)
else:
if self.enable_shared_experts_overlap:
if self.enable_shared_experts_fusion:
final_hidden_states = shared_exprts_fusion_pass(hidden_states, router_logits)
elif self.enable_shared_experts_overlap:
assert self.shared_experts is not None
shared_output, final_hidden_states = shared_exprts_overlap_pass(hidden_states, router_logits)
# Fix FP16 overflow
......@@ -364,7 +391,9 @@ class DeepseekV2MoE(nn.Module):
router_logits, _ = self.gate(hidden_states)
if self.enable_shared_experts_overlap:
if self.enable_shared_experts_fusion:
final_hidden_states = shared_exprts_fusion_pass(hidden_states, router_logits, iqis = iqis)
elif self.enable_shared_experts_overlap:
assert self.shared_experts is not None
shared_output, final_hidden_states = shared_exprts_overlap_pass(hidden_states, router_logits, iqis = iqis)
# Fix FP16 overflow
......@@ -1516,7 +1545,12 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts,
num_experts=self.config.n_routed_experts
+ (
self.num_shared_experts
if (envs.VLLM_ENABLE_SHARED_EXPERTS_FUSION and self.num_shared_experts > 0)
else 0
),
num_redundant_experts=self.num_redundant_experts)
params_dict = dict(self.named_parameters())
......@@ -1529,6 +1563,15 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
if spec_layer is not None:
continue # skip spec decode layers for main model
# Assuame num of shared experts is only one.
if envs.VLLM_ENABLE_SHARED_EXPERTS_FUSION \
and self.num_shared_experts > 0 \
and "mlp.shared_experts" in name:
name = name.replace(
"mlp.shared_experts",
f"mlp.experts.{self.config.n_routed_experts}",
)
for (param_name, weight_name, shard_id) in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
try:
from ._version import __version__, __version_tuple__
__version__ = "0.9.2"
__version_tuple__ = (0, 9, 2)
__hcu_version__ = f'0.9.2+das.opt5.test.dtk2604'
from vllm.version import __version__, __version_tuple__, __hcu_version__
except Exception as e:
import warnings
warnings.warn(f"Failed to read commit hash:\n{e}",
warnings.warn(f"Failed to read commit hash:\n + str(e)",
RuntimeWarning,
stacklevel=2)
__version__ = "dev"
__version_tuple__ = (0, 0, __version__)
def _prev_minor_version_was(version_str):
"""Check whether a given version matches the previous minor version.
'''Check whether a given version matches the previous minor version.
Return True if version_str matches the previous minor version.
......@@ -23,19 +24,19 @@ def _prev_minor_version_was(version_str):
supplied version_str is '0.6'.
Used for --show-hidden-metrics-for-version.
"""
'''
# Match anything if this is a dev tree
if __version_tuple__[0:2] == (0, 0):
return True
# Note - this won't do the right thing when we release 1.0!
assert __version_tuple__[0] == 0
# assert __version_tuple__[0] == 0
assert isinstance(__version_tuple__[1], int)
return version_str == f"{__version_tuple__[0]}.{__version_tuple__[1] - 1}"
def _prev_minor_version():
"""For the purpose of testing, return a previous minor version number."""
'''For the purpose of testing, return a previous minor version number.'''
# In dev tree, this will return "0.-1", but that will work fine"
assert isinstance(__version_tuple__[1], int)
return f"{__version_tuple__[0]}.{__version_tuple__[1] - 1}"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment