Commit 44d451ff authored by 王敏
Browse files

优化ep moe (Optimize expert-parallel MoE)

parent 36e35fac
...@@ -307,8 +307,9 @@ __global__ void ep_moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, ...@@ -307,8 +307,9 @@ __global__ void ep_moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
* assigned to expert expert_index. * assigned to expert expert_index.
*/ */
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
if (topk_ids[i] >= start_expert && topk_ids[i] < end_expert) { expert_id = topk_ids[i];
++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i] - start_expert)]; if (expert_id >= start_expert && expert_id < end_expert) {
++tokens_cnts[index(num_experts, threadIdx.x + 1, expert_id - start_expert)];
} }
} }
......
...@@ -420,6 +420,7 @@ class EngineArgs: ...@@ -420,6 +420,7 @@ class EngineArgs:
default=EngineArgs.tensor_parallel_size, default=EngineArgs.tensor_parallel_size,
help='Number of tensor parallel replicas.') help='Number of tensor parallel replicas.')
parser.add_argument('--moe-ep-size', parser.add_argument('--moe-ep-size',
'-ep',
type=int, type=int,
default=EngineArgs.moe_ep_size, default=EngineArgs.moe_ep_size,
help='Number of moe expert parallel replicas.') help='Number of moe expert parallel replicas.')
......
...@@ -1345,9 +1345,6 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1345,9 +1345,6 @@ def fused_experts_impl(hidden_states: torch.Tensor,
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
sorted_token_ids, expert_ids, num_tokens_post_padded = (
moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], E))
if moe_ep_size == 1: if moe_ep_size == 1:
sorted_token_ids, expert_ids, num_tokens_post_padded = ( sorted_token_ids, expert_ids, num_tokens_post_padded = (
moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], E)) moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], E))
......
...@@ -234,6 +234,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -234,6 +234,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False,
moe_ep_size: Optional[int] = 1, moe_ep_size: Optional[int] = 1,
start_expert: Optional[int] = -1, start_expert: Optional[int] = -1,
end_expert: Optional[int] = -1 end_expert: Optional[int] = -1
......
...@@ -22,6 +22,8 @@ ...@@ -22,6 +22,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only DeepseekV2/DeepseekV3 model.""" """Inference-only DeepseekV2/DeepseekV3 model."""
import os
import re
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
import torch import torch
...@@ -56,6 +58,7 @@ from .interfaces import SupportsPP ...@@ -56,6 +58,7 @@ from .interfaces import SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter, from .utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix) maybe_prefix)
from vllm import _custom_ops as ops
class DeepseekV2MLP(nn.Module): class DeepseekV2MLP(nn.Module):
...@@ -100,6 +103,7 @@ class DeepseekV2MoE(nn.Module): ...@@ -100,6 +103,7 @@ class DeepseekV2MoE(nn.Module):
config: PretrainedConfig, config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "", prefix: str = "",
moe_ep_size: int = 1
): ):
super().__init__() super().__init__()
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
...@@ -139,7 +143,8 @@ class DeepseekV2MoE(nn.Module): ...@@ -139,7 +143,8 @@ class DeepseekV2MoE(nn.Module):
topk_group=config.topk_group, topk_group=config.topk_group,
prefix=f"{prefix}.experts", prefix=f"{prefix}.experts",
scoring_func=config.scoring_func, scoring_func=config.scoring_func,
e_score_correction_bias=self.gate.e_score_correction_bias) e_score_correction_bias=self.gate.e_score_correction_bias,
moe_ep_size=moe_ep_size)
if config.n_shared_experts is not None: if config.n_shared_experts is not None:
intermediate_size = (config.moe_intermediate_size * intermediate_size = (config.moe_intermediate_size *
...@@ -490,6 +495,7 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -490,6 +495,7 @@ class DeepseekV2DecoderLayer(nn.Module):
model_config: ModelConfig, model_config: ModelConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
moe_ep_size : int = 1,
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -529,6 +535,7 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -529,6 +535,7 @@ class DeepseekV2DecoderLayer(nn.Module):
config=config, config=config,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.mlp", prefix=f"{prefix}.mlp",
moe_ep_size=moe_ep_size
) )
else: else:
self.mlp = DeepseekV2MLP( self.mlp = DeepseekV2MLP(
...@@ -577,7 +584,7 @@ class DeepseekV2Model(nn.Module): ...@@ -577,7 +584,7 @@ class DeepseekV2Model(nn.Module):
fall_back_to_pt_during_load = False fall_back_to_pt_during_load = False
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", moe_ep_size: int = 1):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
...@@ -604,6 +611,7 @@ class DeepseekV2Model(nn.Module): ...@@ -604,6 +611,7 @@ class DeepseekV2Model(nn.Module):
model_config=model_config, model_config=model_config,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
moe_ep_size=moe_ep_size
), ),
prefix=f"{prefix}.layers") prefix=f"{prefix}.layers")
...@@ -662,8 +670,12 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP): ...@@ -662,8 +670,12 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.parallel_config = vllm_config.parallel_config
self.moe_ep_size = self.parallel_config.moe_ep_size
self.model = DeepseekV2Model(vllm_config=vllm_config, self.model = DeepseekV2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model")) prefix=maybe_prefix(prefix, "model"),
moe_ep_size=self.moe_ep_size)
self.lm_head = ParallelLMHead(config.vocab_size, self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size, config.hidden_size,
quant_config=quant_config) quant_config=quant_config)
...@@ -672,6 +684,12 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP): ...@@ -672,6 +684,12 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.get_input_embeddings(input_ids)
...@@ -730,11 +748,19 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP): ...@@ -730,11 +748,19 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
# Params for weights, fp8 weight scales, fp8 activation scales # Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id) # (param_name, weight_name, expert_id, shard_id)
if self.moe_ep_size == 1:
expert_params_mapping = FusedMoE.make_expert_params_mapping( expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj", ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts) num_experts=self.config.n_routed_experts)
else:
expert_params_mapping = FusedMoE.make_expert_params_mapping_ep(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts,
moe_ep_size=self.moe_ep_size)
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: Set[str] = set()
...@@ -803,6 +829,10 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP): ...@@ -803,6 +829,10 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
continue continue
if is_pp_missing_parameter(name, self): if is_pp_missing_parameter(name, self):
continue
# Skip loading extra expert weights for ep moe mode
if name not in params_dict:
continue continue
param = params_dict[name] param = params_dict[name]
...@@ -810,6 +840,37 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP): ...@@ -810,6 +840,37 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name) loaded_params.add(name)
if self.use_llama_nn and self.quant_method is None:
lay_key_words = [
"self_attn.q_a_proj.weight",
"self_attn.kv_a_proj_with_mqa.weight",
"mlp.gate.weight",
"mlp.gate_up_proj.weight",
"mlp.down_proj",
"shared_experts.gate_up_proj",
"shared_experts.down_proj",
"self_attn.q_proj.weight",
"self_attn.q_b_proj.weight",
"self_attn.kv_b_proj.weight",
"self_attn.o_proj.weight",
"lm_head.weight"
]
combined_words = "|".join(lay_key_words)
for layername in loaded_params:
weight = params_dict[layername]
matches = re.findall(combined_words, layername)
if matches:
_weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape
ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1])
weight.data.copy_(_weight)
weight.data=weight.data.reshape(ori_shape[1],-1)
return loaded_params return loaded_params
......
...@@ -683,9 +683,6 @@ class DeepseekV3ForCausalLM(nn.Module, SupportsPP): ...@@ -683,9 +683,6 @@ class DeepseekV3ForCausalLM(nn.Module, SupportsPP):
if quant_config is not None: if quant_config is not None:
self.quant_method=quant_config.get_name() self.quant_method=quant_config.get_name()
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_mla = False
if hasattr(vllm_config.model_config, "use_mla"):
self.use_mla = vllm_config.model_config.use_mla
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
...@@ -753,7 +750,7 @@ class DeepseekV3ForCausalLM(nn.Module, SupportsPP): ...@@ -753,7 +750,7 @@ class DeepseekV3ForCausalLM(nn.Module, SupportsPP):
ckpt_up_proj_name="up_proj", ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts) num_experts=self.config.n_routed_experts)
else: else:
expert_params_mapping = FusedMoE.make_expert_params_mapping( expert_params_mapping = FusedMoE.make_expert_params_mapping_ep(
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj", ckpt_up_proj_name="up_proj",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment