Commit bbe4df8b authored by wanglong3's avatar wanglong3
Browse files

feat: Support shared experts fusion.

feat: support moe sum when topk==9

bugfix: Fix mtp model load error when eable shared experts fusion.
parent 3a58da2c
...@@ -173,6 +173,39 @@ __global__ void moe_sum_kernel( ...@@ -173,6 +173,39 @@ __global__ void moe_sum_kernel(
} }
} }
template <typename scalar_t, int TOPK, int SPLIT_D, int BLOCK_DIM>
__global__ void moe_sum_sharedmem(
scalar_t* __restrict__ out,
const scalar_t* __restrict__ input,
const int d) {
const int token_idx = blockIdx.x / SPLIT_D;
const int sub_block = blockIdx.x % SPLIT_D;
const int d_per_block = (d + SPLIT_D - 1) / SPLIT_D;
const int64_t d_start = sub_block * d_per_block;
const int64_t d_end = min(d_start + d_per_block, d);
const int64_t token_offset = token_idx * TOPK * d;
__shared__ __align__(16) scalar_t sem_input[TOPK][BLOCK_DIM];
for (int64_t idx = d_start + threadIdx.x; idx < d_end; idx += blockDim.x) {
#pragma unroll
for (int k = 0; k < TOPK; ++k) {
sem_input[k][threadIdx.x] =
input[token_offset + k * d + idx];
}
__syncthreads();
scalar_t x = 0;
#pragma unroll
for (int k = 0; k < TOPK; ++k) {
x += sem_input[k][threadIdx.x];
}
out[token_idx * d + idx] = x;
}
}
template <typename scalar_t, int TOPK, int SPLIT_D, int BLOCK_DIM> template <typename scalar_t, int TOPK, int SPLIT_D, int BLOCK_DIM>
__global__ void moe_sum_sharedmem_topk8( __global__ void moe_sum_sharedmem_topk8(
scalar_t* __restrict__ out, scalar_t* __restrict__ out,
...@@ -440,7 +473,14 @@ void moe_sum_opt1(torch::Tensor& input, // [num_tokens, topk, hidden_size] ...@@ -440,7 +473,14 @@ void moe_sum_opt1(torch::Tensor& input, // [num_tokens, topk, hidden_size]
output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
hidden_size); hidden_size);
}); });
break;
case 9:
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_sharedmem", [&]{
vllm::moe::moe_sum_sharedmem<scalar_t, 9, 9, 256><<<num_tokens * 9, 256, 0, stream>>>(
output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
hidden_size);
});
break; break;
default: default:
......
...@@ -216,6 +216,7 @@ if TYPE_CHECKING: ...@@ -216,6 +216,7 @@ if TYPE_CHECKING:
VLLM_MOE_ROUTER_CAPTURE_MAX_LAYERS: int = 0 VLLM_MOE_ROUTER_CAPTURE_MAX_LAYERS: int = 0
VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_GT: int = -1 VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_GT: int = -1
VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_LT: int = -1 VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_LT: int = -1
VLLM_ENABLE_SHARED_EXPERTS_FUSION: bool = False
VLLM_USE_MOE_W16A16_TRITON: bool = False VLLM_USE_MOE_W16A16_TRITON: bool = False
def get_default_cache_root(): def get_default_cache_root():
...@@ -1352,9 +1353,18 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1352,9 +1353,18 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: (os.environ.get("VLLM_V1_FAST_TOKEN_ID_COPY", "False").lower() in lambda: (os.environ.get("VLLM_V1_FAST_TOKEN_ID_COPY", "False").lower() in
("true", "1")), ("true", "1")),
# shared experts overlap with routed experts
# VLLM_DISABLE_SHARED_EXPERTS_STREAM = 1 disable shared experts overlap
# VLLM_DISABLE_SHARED_EXPERTS_STREAM = 0 enable shared experts overlap
"VLLM_DISABLE_SHARED_EXPERTS_STREAM": lambda: bool( "VLLM_DISABLE_SHARED_EXPERTS_STREAM": lambda: bool(
int(os.getenv("VLLM_DISABLE_SHARED_EXPERTS_STREAM", "1")) int(os.getenv("VLLM_DISABLE_SHARED_EXPERTS_STREAM", "1"))
), ),
# shared experts fusion
# VLLM_ENABLE_SHARED_EXPERTS_FUSION = 1 enable shared experts fusion
# VLLM_ENABLE_SHARED_EXPERTS_FUSION = 0 disable shared experts fusion
"VLLM_ENABLE_SHARED_EXPERTS_FUSION": lambda: bool(
int(os.getenv("VLLM_ENABLE_SHARED_EXPERTS_FUSION", "0"))
),
# W8A8 GEMM backend selection for vLLM quantized models. # W8A8 GEMM backend selection for vLLM quantized models.
# lightop/triton: 1 # lightop/triton: 1
......
...@@ -1304,7 +1304,9 @@ def grouped_topk( ...@@ -1304,7 +1304,9 @@ def grouped_topk(
num_expert_group: int = 0, num_expert_group: int = 0,
topk_group: int = 0, topk_group: int = 0,
scoring_func: str = "softmax", scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None e_score_correction_bias: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = 1.0,
num_fused_shared_experts: Optional[int] = 0
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
assert hidden_states.size(0) == gating_output.size(0), ( assert hidden_states.size(0) == gating_output.size(0), (
...@@ -1317,7 +1319,7 @@ def grouped_topk( ...@@ -1317,7 +1319,7 @@ def grouped_topk(
else: else:
raise ValueError(f"Unsupported scoring function: {scoring_func}") raise ValueError(f"Unsupported scoring function: {scoring_func}")
num_token = scores.size(0) num_token, num_experts = scores.shape
if e_score_correction_bias is not None: if e_score_correction_bias is not None:
# Store original scores before applying correction bias. We use biased # Store original scores before applying correction bias. We use biased
# scores for expert selection but original scores for routing weights # scores for expert selection but original scores for routing weights
...@@ -1348,8 +1350,20 @@ def grouped_topk( ...@@ -1348,8 +1350,20 @@ def grouped_topk(
dim=-1, dim=-1,
sorted=False) sorted=False)
if num_fused_shared_experts != 0:
topk_ids[:, -1] = num_experts
if routed_scaling_factor is not None:
topk_weights[:, -1] = (
topk_weights[:, :-1].sum(dim=-1) / routed_scaling_factor
)
if renormalize: if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) topk_weights_sum = (
topk_weights.sum(dim=-1, keepdim=True)
if num_fused_shared_experts == 0
else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
)
topk_weights = topk_weights / topk_weights_sum
return topk_weights.to(torch.float32), topk_ids.to(torch.int32) return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
...@@ -2140,12 +2154,17 @@ def fused_moe( ...@@ -2140,12 +2154,17 @@ def fused_moe(
Returns: Returns:
- torch.Tensor: The output tensor after applying the MoE layer. - torch.Tensor: The output tensor after applying the MoE layer.
""" """
if use_grouped_topk: if use_grouped_topk:
assert num_expert_group is not None and topk_group is not None assert num_expert_group is not None and topk_group is not None
topk_weights, topk_ids = grouped_topk(hidden_states, gating_output, topk_weights, topk_ids = grouped_topk(
topk, renormalize, hidden_states = hidden_states,
num_expert_group, topk_group) gating_output = gating_output,
topk = topk,
renormalize = renormalize,
num_expert_group = num_expert_group,
topk_group = topk_group,
routed_scaling_factor = routed_scaling_factor,
num_fused_shared_experts = 1)
elif custom_routing_function is None: elif custom_routing_function is None:
topk_weights, topk_ids, token_expert_indices = fused_topk( topk_weights, topk_ids, token_expert_indices = fused_topk(
hidden_states, gating_output, topk, renormalize) hidden_states, gating_output, topk, renormalize)
......
...@@ -1559,8 +1559,9 @@ class FusedMoE(torch.nn.Module): ...@@ -1559,8 +1559,9 @@ class FusedMoE(torch.nn.Module):
num_expert_group, num_expert_group,
topk_group, topk_group,
top_k, top_k,
0, # TODO also required num of shared expert is not None
routed_scaling_factor, (1 if envs.VLLM_ENABLE_SHARED_EXPERTS_FUSION else 0),
routed_scaling_factor
) )
else: else:
topk_weights, topk_ids = ops.moe_fused_gate( topk_weights, topk_ids = ops.moe_fused_gate(
...@@ -1581,7 +1582,11 @@ class FusedMoE(torch.nn.Module): ...@@ -1581,7 +1582,11 @@ class FusedMoE(torch.nn.Module):
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
topk_group=topk_group, topk_group=topk_group,
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias) e_score_correction_bias = e_score_correction_bias,
routed_scaling_factor = routed_scaling_factor,
# TODO also required num of shared expert is not None
num_fused_shared_experts = (1 if envs.VLLM_ENABLE_SHARED_EXPERTS_FUSION else 0)
)
if indices_type is not None: if indices_type is not None:
topk_ids = topk_ids.to(dtype=indices_type) topk_ids = topk_ids.to(dtype=indices_type)
elif custom_routing_function is None: elif custom_routing_function is None:
......
...@@ -252,7 +252,8 @@ def get_model_architecture( ...@@ -252,7 +252,8 @@ def get_model_architecture(
if architectures in [['DeepseekV3ForCausalLM'], ['DeepSeekMTPModel']]: if architectures in [['DeepseekV3ForCausalLM'], ['DeepSeekMTPModel']]:
if not envs.is_set("VLLM_USE_LIGHTOP"): if not envs.is_set("VLLM_USE_LIGHTOP"):
os.environ['VLLM_USE_LIGHTOP'] = '1' os.environ['VLLM_USE_LIGHTOP'] = '1'
if not envs.is_set("VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD"): if not envs.is_set("VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD") \
and not envs.VLLM_ENABLE_SHARED_EXPERTS_FUSION:
os.environ['VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD'] = '1' os.environ['VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD'] = '1'
if not envs.is_set("VLLM_USE_OPT_CAT"): if not envs.is_set("VLLM_USE_OPT_CAT"):
os.environ['VLLM_USE_OPT_CAT'] = '1' os.environ['VLLM_USE_OPT_CAT'] = '1'
...@@ -301,7 +302,8 @@ def get_model_architecture( ...@@ -301,7 +302,8 @@ def get_model_architecture(
if architectures in [['DeepseekV3ForCausalLM'], ['DeepSeekMTPModel']]: if architectures in [['DeepseekV3ForCausalLM'], ['DeepSeekMTPModel']]:
if not envs.is_set("VLLM_USE_LIGHTOP"): if not envs.is_set("VLLM_USE_LIGHTOP"):
os.environ['VLLM_USE_LIGHTOP'] = '1' os.environ['VLLM_USE_LIGHTOP'] = '1'
if not envs.is_set("VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD"): if not envs.is_set("VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD") \
and not envs.VLLM_ENABLE_SHARED_EXPERTS_FUSION:
os.environ['VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD'] = '1' os.environ['VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD'] = '1'
if not envs.is_set("VLLM_USE_OPT_CAT"): if not envs.is_set("VLLM_USE_OPT_CAT"):
os.environ['VLLM_USE_OPT_CAT'] = '1' os.environ['VLLM_USE_OPT_CAT'] = '1'
...@@ -370,7 +372,6 @@ def get_model_architecture( ...@@ -370,7 +372,6 @@ def get_model_architecture(
mixtral_supported = [ mixtral_supported = [
"fp8", "compressed-tensors", "gptq_marlin", "awq_marlin", "quark" "fp8", "compressed-tensors", "gptq_marlin", "awq_marlin", "quark"
] ]
vllm_supported_archs = ModelRegistry.get_supported_archs() vllm_supported_archs = ModelRegistry.get_supported_archs()
vllm_not_supported = not any(arch in vllm_supported_archs vllm_not_supported = not any(arch in vllm_supported_archs
for arch in architectures) for arch in architectures)
......
...@@ -235,12 +235,15 @@ class DeepSeekMTP(nn.Module, SupportsPP): ...@@ -235,12 +235,15 @@ class DeepSeekMTP(nn.Module, SupportsPP):
("qa_kva_proj", "kv_a_proj_with_mqa", 1) ("qa_kva_proj", "kv_a_proj_with_mqa", 1)
] ]
stacked_params_mapping += fused_params_mapping stacked_params_mapping += fused_params_mapping
enable_shared_experts_fusion = envs.VLLM_ENABLE_SHARED_EXPERTS_FUSION and self.config.n_shared_experts > 0
expert_params_mapping = FusedMoE.make_expert_params_mapping( expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj", ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts) num_experts=self.config.n_routed_experts + (
self.config.n_shared_experts
if enable_shared_experts_fusion else 0
))
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: set[str] = set() loaded_params: set[str] = set()
...@@ -251,6 +254,16 @@ class DeepSeekMTP(nn.Module, SupportsPP): ...@@ -251,6 +254,16 @@ class DeepSeekMTP(nn.Module, SupportsPP):
if spec_layer is None: if spec_layer is None:
continue continue
name = self._rewrite_spec_layer_name(spec_layer, name) name = self._rewrite_spec_layer_name(spec_layer, name)
# Assuame num of shared experts is only one.
if envs.VLLM_ENABLE_SHARED_EXPERTS_FUSION \
and self.config.n_shared_experts > 0 \
and "mlp.shared_experts" in name:
name = name.replace(
"mlp.shared_experts",
f"mlp.experts.{self.config.n_routed_experts}",
)
for (param_name, weight_name, shard_id) in stacked_params_mapping: for (param_name, weight_name, shard_id) in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below). # Skip non-stacked layers and experts (experts handled below).
if weight_name not in name: if weight_name not in name:
...@@ -273,7 +286,9 @@ class DeepSeekMTP(nn.Module, SupportsPP): ...@@ -273,7 +286,9 @@ class DeepSeekMTP(nn.Module, SupportsPP):
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
if envs.USE_FUSED_RMS_QUANT and envs.VLLM_USE_FUSED_QA_KVA_GEMM and (("q_a_proj" in old_weight_name) or ("kv_a_proj_with_mqa" in old_weight_name)): if envs.USE_FUSED_RMS_QUANT \
and envs.VLLM_USE_FUSED_QA_KVA_GEMM \
and (("q_a_proj" in old_weight_name) or ("kv_a_proj_with_mqa" in old_weight_name)):
weight_loader(param, loaded_weight, old_weight_name) weight_loader(param, loaded_weight, old_weight_name)
else: else:
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
...@@ -361,340 +376,4 @@ class DeepSeekMTP(nn.Module, SupportsPP): ...@@ -361,340 +376,4 @@ class DeepSeekMTP(nn.Module, SupportsPP):
# treat rest weights as weights for transformer layer block # treat rest weights as weights for transformer layer block
name = name.replace(f"model.layers.{spec_layer}.", name = name.replace(f"model.layers.{spec_layer}.",
f"model.layers.{spec_layer}.mtp_block.") f"model.layers.{spec_layer}.mtp_block.")
return name return name
\ No newline at end of file
# # SPDX-License-Identifier: Apache-2.0
# # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# import os
# import re
# from collections.abc import Iterable
# from typing import Iterable, Optional
# import torch
# import torch.nn as nn
# from transformers import PretrainedConfig
# from vllm.config import CacheConfig, ModelConfig, VllmConfig
# from vllm.model_executor.layers.fused_moe import FusedMoE
# from vllm.model_executor.layers.layernorm import RMSNorm
# from vllm.model_executor.layers.logits_processor import LogitsProcessor
# from vllm.model_executor.layers.quantization import QuantizationConfig
# from vllm.model_executor.layers.vocab_parallel_embedding import (
# ParallelLMHead, VocabParallelEmbedding)
# from vllm.model_executor.model_loader.weight_utils import default_weight_loader
# from vllm.model_executor.sampling_metadata import SamplingMetadata
# from vllm.sequence import IntermediateTensors
# from vllm.compilation.decorators import support_torch_compile
# from .deepseek_v2 import (DeepseekV2DecoderLayer,
# get_spec_layer_idx_from_weight_name)
# from .interfaces import SupportsPP
# from .utils import maybe_prefix
# from vllm import _custom_ops as ops
# from vllm.model_executor.layers.quantization.blockwise_int8 import BlockInt8Config
# class SharedHead(nn.Module):
# def __init__(
# self,
# config: PretrainedConfig,
# quant_config: Optional[QuantizationConfig] = None,
# ) -> None:
# super().__init__()
# self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# self.head = ParallelLMHead(config.vocab_size,
# config.hidden_size,
# quant_config=quant_config)
# def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# return self.norm(hidden_states)
# class DeepSeekMultiTokenPredictorLayer(nn.Module):
# def __init__(
# self,
# config: PretrainedConfig,
# prefix: str,
# model_config: ModelConfig,
# cache_config: Optional[CacheConfig] = None,
# quant_config: Optional[QuantizationConfig] = None,
# ) -> None:
# super().__init__()
# self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# self.eh_proj = nn.Linear(config.hidden_size * 2,
# config.hidden_size,
# bias=False)
# self.shared_head = SharedHead(config=config, quant_config=quant_config)
# self.mtp_block = DeepseekV2DecoderLayer(config, prefix, model_config,
# cache_config, quant_config)
# def forward(
# self,
# input_ids: torch.Tensor,
# positions: torch.Tensor,
# previous_hidden_states: torch.Tensor,
# inputs_embeds: Optional[torch.Tensor] = None,
# spec_step_index: int = 0,
# ) -> torch.Tensor:
# assert inputs_embeds is not None
# # masking inputs at position 0, as not needed by MTP
# inputs_embeds[positions == 0] = 0
# inputs_embeds = self.enorm(inputs_embeds)
# previous_hidden_states = self.hnorm(previous_hidden_states)
# hidden_states = self.eh_proj(
# torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
# hidden_states, residual = self.mtp_block(positions=positions,
# hidden_states=hidden_states,
# residual=None)
# hidden_states = residual + hidden_states
# return hidden_states
# class DeepSeekMultiTokenPredictor(nn.Module):
# def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
# super().__init__()
# config = vllm_config.model_config.hf_config
# self.mtp_start_layer_idx = config.num_hidden_layers
# self.num_mtp_layers = config.num_nextn_predict_layers
# # to map the exact layer index from weights
# self.layers = torch.nn.ModuleDict({
# str(idx):
# DeepSeekMultiTokenPredictorLayer(
# config,
# f"{prefix}.layers.{idx}",
# model_config=vllm_config.model_config,
# cache_config=vllm_config.cache_config,
# quant_config=vllm_config.quant_config,
# )
# for idx in range(self.mtp_start_layer_idx,
# self.mtp_start_layer_idx + self.num_mtp_layers)
# })
# self.embed_tokens = VocabParallelEmbedding(
# config.vocab_size,
# config.hidden_size,
# )
# self.logits_processor = LogitsProcessor(config.vocab_size)
# def forward(
# self,
# input_ids: torch.Tensor,
# positions: torch.Tensor,
# previous_hidden_states: torch.Tensor,
# inputs_embeds: Optional[torch.Tensor] = None,
# spec_step_idx: int = 0,
# ) -> torch.Tensor:
# if inputs_embeds is None:
# inputs_embeds = self.embed_tokens(input_ids)
# current_step_idx = (spec_step_idx % self.num_mtp_layers)
# return self.layers[str(self.mtp_start_layer_idx + current_step_idx)](
# input_ids,
# positions,
# previous_hidden_states,
# inputs_embeds,
# current_step_idx,
# )
# def compute_logits(
# self,
# hidden_states: torch.Tensor,
# sampling_metadata: SamplingMetadata,
# spec_step_idx: int = 0,
# ) -> torch.Tensor:
# current_step_idx = (spec_step_idx % self.num_mtp_layers)
# mtp_layer = self.layers[str(self.mtp_start_layer_idx +
# current_step_idx)]
# logits = self.logits_processor(mtp_layer.shared_head.head,
# mtp_layer.shared_head(hidden_states),
# sampling_metadata)
# return logits
# @support_torch_compile
# class DeepSeekMTP(nn.Module, SupportsPP):
# def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
# super().__init__()
# self.config = vllm_config.model_config.hf_config
# quant_config = vllm_config.quant_config
# self.quant_method = None
# if quant_config is not None:
# self.quant_method = quant_config.get_name()
# os.environ['LLAMA_NN'] = '0'
# os.environ['LM_NN'] = '0'
# # The AWQ layer of MTP uses BlockInt8W8A8.
# if self.quant_method == "moe_wna16" or self.quant_method == "awq_marlin":
# vllm_config.quant_config = BlockInt8Config(is_checkpoint_int8_serialized=True, weight_block_size=[128,128])
# self.model = DeepSeekMultiTokenPredictor(vllm_config=vllm_config,
# prefix=maybe_prefix(
# prefix, "model"))
# self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
# def forward(
# self,
# input_ids: torch.Tensor,
# positions: torch.Tensor,
# previous_hidden_states: torch.Tensor,
# intermediate_tensors: Optional[IntermediateTensors] = None,
# inputs_embeds: Optional[torch.Tensor] = None,
# spec_step_idx: int = 0,
# ) -> torch.Tensor:
# hidden_states = self.model(input_ids, positions,
# previous_hidden_states, inputs_embeds,
# spec_step_idx)
# return hidden_states
# def compute_logits(
# self,
# hidden_states: torch.Tensor,
# sampling_metadata: SamplingMetadata,
# spec_step_idx: int = 0,
# ) -> Optional[torch.Tensor]:
# return self.model.compute_logits(hidden_states, sampling_metadata,
# spec_step_idx)
# def load_weights(self, weights: Iterable[tuple[str,
# torch.Tensor]]) -> set[str]:
# stacked_params_mapping = [
# ("gate_up_proj", "gate_proj", 0),
# ("gate_up_proj", "up_proj", 1),
# ]
# expert_params_mapping = FusedMoE.make_expert_params_mapping(
# ckpt_gate_proj_name="gate_proj",
# ckpt_down_proj_name="down_proj",
# ckpt_up_proj_name="up_proj",
# num_experts=self.config.n_routed_experts)
# params_dict = dict(self.named_parameters())
# loaded_params: set[str] = set()
# for name, loaded_weight in weights:
# if "rotary_emb.inv_freq" in name:
# continue
# spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
# if spec_layer is None:
# continue
# name = self._rewrite_spec_layer_name(spec_layer, name)
# for (param_name, weight_name, shard_id) in stacked_params_mapping:
# # Skip non-stacked layers and experts (experts handled below).
# if weight_name not in name:
# continue
# # We have mlp.experts[0].gate_proj in the checkpoint.
# # Since we handle the experts below in expert_params_mapping,
# # we need to skip here BEFORE we update the name, otherwise
# # name will be updated to mlp.experts[0].gate_up_proj, which
# # will then be updated below in expert_params_mapping
# # for mlp.experts[0].gate_gate_up_proj, which breaks load.
# if (("mlp.experts." in name) and name not in params_dict):
# continue
# name = name.replace(weight_name, param_name)
# # Skip loading extra bias for GPTQ models.
# if name.endswith(".bias") and name not in params_dict:
# continue
# param = params_dict[name]
# weight_loader = param.weight_loader
# weight_loader(param, loaded_weight, shard_id)
# break
# else:
# for mapping in expert_params_mapping:
# param_name, weight_name, expert_id, shard_id = mapping
# if weight_name not in name:
# continue
# name = name.replace(weight_name, param_name)
# param = params_dict[name]
# weight_loader = param.weight_loader
# weight_loader(param,
# loaded_weight,
# name,
# shard_id=shard_id,
# expert_id=expert_id)
# break
# else:
# # Skip loading extra bias for GPTQ models.
# if name.endswith(".bias") and name not in params_dict:
# continue
# # According to DeepSeek-V3 Technical Report, MTP modules
# # shares embedding layer. We only load the first weights.
# if (spec_layer != self.model.mtp_start_layer_idx
# and ".layers" not in name):
# continue
# param = params_dict[name]
# weight_loader = getattr(param, "weight_loader",
# default_weight_loader)
# weight_loader(param, loaded_weight)
# loaded_params.add(name)
# if self.use_llama_nn and self.quant_method is None:
# lay_key_words = [
# "self_attn.eh_proj.weight",
# "self_attn.q_proj.weight",
# "self_attn.q_a_proj.weight",
# "self_attn.q_b_proj.weight",
# "self_attn.kv_a_proj_with_mqa.weight",
# "self_attn.kv_b_proj.weight",
# "self_attn.o_proj.weight",
# "mlp.gate_up_proj.weight",
# "mlp.down_proj.weight",
# "mlp.gate.weight",
# "shared_experts.gate_up_proj.weight",
# "shared_experts.down_proj.weight",
# "shared_head.head.weight",
# ]
# combined_words = "|".join(lay_key_words)
# for layername in loaded_params:
# weight = params_dict[layername]
# matches = re.findall(combined_words, layername)
# if matches:
# _weight = torch.zeros_like(weight.data)
# ori_shape =_weight.shape
# ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1])
# weight.data.copy_(_weight)
# weight.data=weight.data.reshape(ori_shape[1],-1)
# return loaded_params
# def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
# """
# Rewrite the weight name to match the format of the original model.
# Add .mtp_block for modules in transformer layer block for spec layer
# and rename shared layer weights to be top level.
# """
# spec_layer_weight_names = [
# "embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head"
# ]
# shared_weight_names = ["embed_tokens"]
# spec_layer_weight = False
# shared_weight = False
# for weight_name in spec_layer_weight_names:
# if weight_name in name:
# spec_layer_weight = True
# if weight_name in shared_weight_names:
# shared_weight = True
# break
# if not spec_layer_weight:
# # treat rest weights as weights for transformer layer block
# name = name.replace(f"model.layers.{spec_layer}.",
# f"model.layers.{spec_layer}.mtp_block.")
# elif shared_weight:
# # treat shared weights as top level weights
# name = name.replace(f"model.layers.{spec_layer}.", "model.")
# return name
...@@ -74,7 +74,6 @@ from vllm import _custom_ops as ops ...@@ -74,7 +74,6 @@ from vllm import _custom_ops as ops
from vllm.utils import W8a8GetCacheJSON from vllm.utils import W8a8GetCacheJSON
from lmslim.quantize.quant_ops import lm_faster_rmsquant from lmslim.quantize.quant_ops import lm_faster_rmsquant
class DeepseekV2MLP(nn.Module): class DeepseekV2MLP(nn.Module):
def __init__( def __init__(
...@@ -194,8 +193,10 @@ class DeepseekV2MoE(nn.Module): ...@@ -194,8 +193,10 @@ class DeepseekV2MoE(nn.Module):
envs.VLLM_ALL2ALL_BACKEND == "deepep_auto") envs.VLLM_ALL2ALL_BACKEND == "deepep_auto")
self.enable_shared_experts_overlap = False self.enable_shared_experts_overlap = False
self.enable_shared_experts_fusion = (self.n_shared_experts != 0 and envs.VLLM_ENABLE_SHARED_EXPERTS_FUSION)
if not self.use_deepep: if not self.use_deepep:
if config.n_shared_experts is not None: if config.n_shared_experts is not None and not envs.VLLM_ENABLE_SHARED_EXPERTS_FUSION:
intermediate_size = (config.moe_intermediate_size * intermediate_size = (config.moe_intermediate_size *
config.n_shared_experts) config.n_shared_experts)
self.shared_experts = DeepseekV2MLP( self.shared_experts = DeepseekV2MLP(
...@@ -230,9 +231,13 @@ class DeepseekV2MoE(nn.Module): ...@@ -230,9 +231,13 @@ class DeepseekV2MoE(nn.Module):
num_redundant_experts=self.n_redundant_experts, num_redundant_experts=self.n_redundant_experts,
routed_scaling_factor=self.routed_scaling_factor) routed_scaling_factor=self.routed_scaling_factor)
else: else:
num_fake_experts = (
config.n_shared_experts
if self.enable_shared_experts_fusion
else 0)
self.experts = FusedMoE( self.experts = FusedMoE(
num_experts=config.n_routed_experts, num_experts=config.n_routed_experts + (num_fake_experts),
top_k=config.num_experts_per_tok, top_k=config.num_experts_per_tok + min(num_fake_experts, 1),
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size, intermediate_size=config.moe_intermediate_size,
reduce_results=False, reduce_results=False,
...@@ -278,7 +283,11 @@ class DeepseekV2MoE(nn.Module): ...@@ -278,7 +283,11 @@ class DeepseekV2MoE(nn.Module):
routed_scaling_factor=self.routed_scaling_factor, routed_scaling_factor=self.routed_scaling_factor,
shared_experts=self.shared_experts) shared_experts=self.shared_experts)
self.run_shared_expert_singlely = (self.n_shared_experts is not None and not self.enable_shared_experts_overlap) self.run_shared_expert_singlely = (
self.n_shared_experts is not None
and not self.enable_shared_experts_overlap
and not envs.VLLM_ENABLE_SHARED_EXPERTS_FUSION
)
from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce
self.tbo_all_reduce = tbo_all_reduce self.tbo_all_reduce = tbo_all_reduce
...@@ -292,6 +301,7 @@ class DeepseekV2MoE(nn.Module): ...@@ -292,6 +301,7 @@ class DeepseekV2MoE(nn.Module):
num_tokens, hidden_dim = hidden_states.shape num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim) hidden_states = hidden_states.view(-1, hidden_dim)
# For shared experts overlap optimization.
def shared_exprts_overlap_pass( def shared_exprts_overlap_pass(
hidden_states: torch.Tensor, router_logits: torch.Tensor, hidden_states: torch.Tensor, router_logits: torch.Tensor,
iqis: Optional[tuple[torch.Tensor, torch.Tensor]] = None iqis: Optional[tuple[torch.Tensor, torch.Tensor]] = None
...@@ -307,6 +317,21 @@ class DeepseekV2MoE(nn.Module): ...@@ -307,6 +317,21 @@ class DeepseekV2MoE(nn.Module):
hidden_states_copy = hidden_states_copy, hidden_states_copy = hidden_states_copy,
i_q = i_q, i_q = i_q,
i_s = i_s) i_s = i_s)
# For shared experts fusion optimization.
def shared_exprts_fusion_pass(
hidden_states: torch.Tensor, router_logits: torch.Tensor,
iqis: Optional[tuple[torch.Tensor, torch.Tensor]] = None
) -> torch.Tensor:
i_q, i_s = None, None
if envs.USE_FUSED_RMS_QUANT:
assert iqis is not None
i_q, i_s = iqis[0], iqis[1]
return self.experts(
hidden_states = hidden_states,
router_logits = router_logits,
i_q = i_q,
i_s = i_s)
if envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and xqxs is not None: if envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and xqxs is not None:
...@@ -321,7 +346,9 @@ class DeepseekV2MoE(nn.Module): ...@@ -321,7 +346,9 @@ class DeepseekV2MoE(nn.Module):
router_logits=router_logits, router_logits=router_logits,
shared_output=shared_output) shared_output=shared_output)
else: else:
if self.enable_shared_experts_overlap: if self.enable_shared_experts_fusion:
final_hidden_states = shared_exprts_fusion_pass(hidden_states, router_logits)
elif self.enable_shared_experts_overlap:
assert self.shared_experts is not None assert self.shared_experts is not None
shared_output, final_hidden_states = shared_exprts_overlap_pass(hidden_states, router_logits) shared_output, final_hidden_states = shared_exprts_overlap_pass(hidden_states, router_logits)
# Fix FP16 overflow # Fix FP16 overflow
...@@ -364,7 +391,9 @@ class DeepseekV2MoE(nn.Module): ...@@ -364,7 +391,9 @@ class DeepseekV2MoE(nn.Module):
router_logits, _ = self.gate(hidden_states) router_logits, _ = self.gate(hidden_states)
if self.enable_shared_experts_overlap: if self.enable_shared_experts_fusion:
final_hidden_states = shared_exprts_fusion_pass(hidden_states, router_logits, iqis = iqis)
elif self.enable_shared_experts_overlap:
assert self.shared_experts is not None assert self.shared_experts is not None
shared_output, final_hidden_states = shared_exprts_overlap_pass(hidden_states, router_logits, iqis = iqis) shared_output, final_hidden_states = shared_exprts_overlap_pass(hidden_states, router_logits, iqis = iqis)
# Fix FP16 overflow # Fix FP16 overflow
...@@ -1516,7 +1545,12 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts): ...@@ -1516,7 +1545,12 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj", ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts, num_experts=self.config.n_routed_experts
+ (
self.num_shared_experts
if (envs.VLLM_ENABLE_SHARED_EXPERTS_FUSION and self.num_shared_experts > 0)
else 0
),
num_redundant_experts=self.num_redundant_experts) num_redundant_experts=self.num_redundant_experts)
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
...@@ -1529,6 +1563,15 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts): ...@@ -1529,6 +1563,15 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
if spec_layer is not None: if spec_layer is not None:
continue # skip spec decode layers for main model continue # skip spec decode layers for main model
# Assuame num of shared experts is only one.
if envs.VLLM_ENABLE_SHARED_EXPERTS_FUSION \
and self.num_shared_experts > 0 \
and "mlp.shared_experts" in name:
name = name.replace(
"mlp.shared_experts",
f"mlp.experts.{self.config.n_routed_experts}",
)
for (param_name, weight_name, shard_id) in stacked_params_mapping: for (param_name, weight_name, shard_id) in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below). # Skip non-stacked layers and experts (experts handled below).
if weight_name not in name: if weight_name not in name:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment