Commit e7ffe66a authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev-wm' into 'v0.9.2-dev'

[fix]修复mtp eager模式下显存占用增加问题

See merge request dcutoolkit/deeplearing/vllm!180
parents 693d5ed4 49559d79
......@@ -58,6 +58,11 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.eh_proj = nn.Linear(config.hidden_size * 2,
......@@ -75,6 +80,8 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
inputs_embeds: Optional[torch.Tensor] = None,
spec_step_index: int = 0,
) -> torch.Tensor:
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
assert inputs_embeds is not None
# masking inputs at position 0, as not needed by MTP
inputs_embeds[positions == 0] = 0
......@@ -111,10 +118,7 @@ class DeepSeekMultiTokenPredictor(nn.Module):
for idx in range(self.mtp_start_layer_idx,
self.mtp_start_layer_idx + self.num_mtp_layers)
})
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.logits_processor = LogitsProcessor(config.vocab_size)
def forward(
......@@ -125,8 +129,6 @@ class DeepSeekMultiTokenPredictor(nn.Module):
inputs_embeds: Optional[torch.Tensor] = None,
spec_step_idx: int = 0,
) -> torch.Tensor:
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
current_step_idx = (spec_step_idx % self.num_mtp_layers)
return self.layers[str(self.mtp_start_layer_idx + current_step_idx)](
input_ids,
......@@ -308,25 +310,353 @@ class DeepSeekMTP(nn.Module, SupportsPP):
"""
Rewrite the weight name to match the format of the original model.
Add .mtp_block for modules in transformer layer block for spec layer
and rename shared layer weights to be top level.
"""
spec_layer_weight_names = [
"embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head"
]
shared_weight_names = ["embed_tokens"]
spec_layer_weight = False
shared_weight = False
for weight_name in spec_layer_weight_names:
if weight_name in name:
spec_layer_weight = True
if weight_name in shared_weight_names:
shared_weight = True
break
if not spec_layer_weight:
# treat rest weights as weights for transformer layer block
name = name.replace(f"model.layers.{spec_layer}.",
f"model.layers.{spec_layer}.mtp_block.")
elif shared_weight:
# treat shared weights as top level weights
name = name.replace(f"model.layers.{spec_layer}.", "model.")
return name
# # SPDX-License-Identifier: Apache-2.0
# # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# import os
# import re
# from collections.abc import Iterable
# from typing import Iterable, Optional
# import torch
# import torch.nn as nn
# from transformers import PretrainedConfig
# from vllm.config import CacheConfig, ModelConfig, VllmConfig
# from vllm.model_executor.layers.fused_moe import FusedMoE
# from vllm.model_executor.layers.layernorm import RMSNorm
# from vllm.model_executor.layers.logits_processor import LogitsProcessor
# from vllm.model_executor.layers.quantization import QuantizationConfig
# from vllm.model_executor.layers.vocab_parallel_embedding import (
# ParallelLMHead, VocabParallelEmbedding)
# from vllm.model_executor.model_loader.weight_utils import default_weight_loader
# from vllm.model_executor.sampling_metadata import SamplingMetadata
# from vllm.sequence import IntermediateTensors
# from vllm.compilation.decorators import support_torch_compile
# from .deepseek_v2 import (DeepseekV2DecoderLayer,
# get_spec_layer_idx_from_weight_name)
# from .interfaces import SupportsPP
# from .utils import maybe_prefix
# from vllm import _custom_ops as ops
# from vllm.model_executor.layers.quantization.blockwise_int8 import BlockInt8Config
# class SharedHead(nn.Module):
# def __init__(
# self,
# config: PretrainedConfig,
# quant_config: Optional[QuantizationConfig] = None,
# ) -> None:
# super().__init__()
# self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# self.head = ParallelLMHead(config.vocab_size,
# config.hidden_size,
# quant_config=quant_config)
# def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# return self.norm(hidden_states)
# class DeepSeekMultiTokenPredictorLayer(nn.Module):
# def __init__(
# self,
# config: PretrainedConfig,
# prefix: str,
# model_config: ModelConfig,
# cache_config: Optional[CacheConfig] = None,
# quant_config: Optional[QuantizationConfig] = None,
# ) -> None:
# super().__init__()
# self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# self.eh_proj = nn.Linear(config.hidden_size * 2,
# config.hidden_size,
# bias=False)
# self.shared_head = SharedHead(config=config, quant_config=quant_config)
# self.mtp_block = DeepseekV2DecoderLayer(config, prefix, model_config,
# cache_config, quant_config)
# def forward(
# self,
# input_ids: torch.Tensor,
# positions: torch.Tensor,
# previous_hidden_states: torch.Tensor,
# inputs_embeds: Optional[torch.Tensor] = None,
# spec_step_index: int = 0,
# ) -> torch.Tensor:
# assert inputs_embeds is not None
# # masking inputs at position 0, as not needed by MTP
# inputs_embeds[positions == 0] = 0
# inputs_embeds = self.enorm(inputs_embeds)
# previous_hidden_states = self.hnorm(previous_hidden_states)
# hidden_states = self.eh_proj(
# torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
# hidden_states, residual = self.mtp_block(positions=positions,
# hidden_states=hidden_states,
# residual=None)
# hidden_states = residual + hidden_states
# return hidden_states
# class DeepSeekMultiTokenPredictor(nn.Module):
# def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
# super().__init__()
# config = vllm_config.model_config.hf_config
# self.mtp_start_layer_idx = config.num_hidden_layers
# self.num_mtp_layers = config.num_nextn_predict_layers
# # to map the exact layer index from weights
# self.layers = torch.nn.ModuleDict({
# str(idx):
# DeepSeekMultiTokenPredictorLayer(
# config,
# f"{prefix}.layers.{idx}",
# model_config=vllm_config.model_config,
# cache_config=vllm_config.cache_config,
# quant_config=vllm_config.quant_config,
# )
# for idx in range(self.mtp_start_layer_idx,
# self.mtp_start_layer_idx + self.num_mtp_layers)
# })
# self.embed_tokens = VocabParallelEmbedding(
# config.vocab_size,
# config.hidden_size,
# )
# self.logits_processor = LogitsProcessor(config.vocab_size)
# def forward(
# self,
# input_ids: torch.Tensor,
# positions: torch.Tensor,
# previous_hidden_states: torch.Tensor,
# inputs_embeds: Optional[torch.Tensor] = None,
# spec_step_idx: int = 0,
# ) -> torch.Tensor:
# if inputs_embeds is None:
# inputs_embeds = self.embed_tokens(input_ids)
# current_step_idx = (spec_step_idx % self.num_mtp_layers)
# return self.layers[str(self.mtp_start_layer_idx + current_step_idx)](
# input_ids,
# positions,
# previous_hidden_states,
# inputs_embeds,
# current_step_idx,
# )
# def compute_logits(
# self,
# hidden_states: torch.Tensor,
# sampling_metadata: SamplingMetadata,
# spec_step_idx: int = 0,
# ) -> torch.Tensor:
# current_step_idx = (spec_step_idx % self.num_mtp_layers)
# mtp_layer = self.layers[str(self.mtp_start_layer_idx +
# current_step_idx)]
# logits = self.logits_processor(mtp_layer.shared_head.head,
# mtp_layer.shared_head(hidden_states),
# sampling_metadata)
# return logits
# @support_torch_compile
# class DeepSeekMTP(nn.Module, SupportsPP):
# def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
# super().__init__()
# self.config = vllm_config.model_config.hf_config
# quant_config = vllm_config.quant_config
# self.quant_method = None
# if quant_config is not None:
# self.quant_method = quant_config.get_name()
# os.environ['LLAMA_NN'] = '0'
# os.environ['LM_NN'] = '0'
# # The AWQ layer of MTP uses BlockInt8W8A8.
# if self.quant_method == "moe_wna16" or self.quant_method == "awq_marlin":
# vllm_config.quant_config = BlockInt8Config(is_checkpoint_int8_serialized=True, weight_block_size=[128,128])
# self.model = DeepSeekMultiTokenPredictor(vllm_config=vllm_config,
# prefix=maybe_prefix(
# prefix, "model"))
# self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
# def forward(
# self,
# input_ids: torch.Tensor,
# positions: torch.Tensor,
# previous_hidden_states: torch.Tensor,
# intermediate_tensors: Optional[IntermediateTensors] = None,
# inputs_embeds: Optional[torch.Tensor] = None,
# spec_step_idx: int = 0,
# ) -> torch.Tensor:
# hidden_states = self.model(input_ids, positions,
# previous_hidden_states, inputs_embeds,
# spec_step_idx)
# return hidden_states
# def compute_logits(
# self,
# hidden_states: torch.Tensor,
# sampling_metadata: SamplingMetadata,
# spec_step_idx: int = 0,
# ) -> Optional[torch.Tensor]:
# return self.model.compute_logits(hidden_states, sampling_metadata,
# spec_step_idx)
# def load_weights(self, weights: Iterable[tuple[str,
# torch.Tensor]]) -> set[str]:
# stacked_params_mapping = [
# ("gate_up_proj", "gate_proj", 0),
# ("gate_up_proj", "up_proj", 1),
# ]
# expert_params_mapping = FusedMoE.make_expert_params_mapping(
# ckpt_gate_proj_name="gate_proj",
# ckpt_down_proj_name="down_proj",
# ckpt_up_proj_name="up_proj",
# num_experts=self.config.n_routed_experts)
# params_dict = dict(self.named_parameters())
# loaded_params: set[str] = set()
# for name, loaded_weight in weights:
# if "rotary_emb.inv_freq" in name:
# continue
# spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
# if spec_layer is None:
# continue
# name = self._rewrite_spec_layer_name(spec_layer, name)
# for (param_name, weight_name, shard_id) in stacked_params_mapping:
# # Skip non-stacked layers and experts (experts handled below).
# if weight_name not in name:
# continue
# # We have mlp.experts[0].gate_proj in the checkpoint.
# # Since we handle the experts below in expert_params_mapping,
# # we need to skip here BEFORE we update the name, otherwise
# # name will be updated to mlp.experts[0].gate_up_proj, which
# # will then be updated below in expert_params_mapping
# # for mlp.experts[0].gate_gate_up_proj, which breaks load.
# if (("mlp.experts." in name) and name not in params_dict):
# continue
# name = name.replace(weight_name, param_name)
# # Skip loading extra bias for GPTQ models.
# if name.endswith(".bias") and name not in params_dict:
# continue
# param = params_dict[name]
# weight_loader = param.weight_loader
# weight_loader(param, loaded_weight, shard_id)
# break
# else:
# for mapping in expert_params_mapping:
# param_name, weight_name, expert_id, shard_id = mapping
# if weight_name not in name:
# continue
# name = name.replace(weight_name, param_name)
# param = params_dict[name]
# weight_loader = param.weight_loader
# weight_loader(param,
# loaded_weight,
# name,
# shard_id=shard_id,
# expert_id=expert_id)
# break
# else:
# # Skip loading extra bias for GPTQ models.
# if name.endswith(".bias") and name not in params_dict:
# continue
# # According to DeepSeek-V3 Technical Report, MTP modules
# # shares embedding layer. We only load the first weights.
# if (spec_layer != self.model.mtp_start_layer_idx
# and ".layers" not in name):
# continue
# param = params_dict[name]
# weight_loader = getattr(param, "weight_loader",
# default_weight_loader)
# weight_loader(param, loaded_weight)
# loaded_params.add(name)
# if self.use_llama_nn and self.quant_method is None:
# lay_key_words = [
# "self_attn.eh_proj.weight",
# "self_attn.q_proj.weight",
# "self_attn.q_a_proj.weight",
# "self_attn.q_b_proj.weight",
# "self_attn.kv_a_proj_with_mqa.weight",
# "self_attn.kv_b_proj.weight",
# "self_attn.o_proj.weight",
# "mlp.gate_up_proj.weight",
# "mlp.down_proj.weight",
# "mlp.gate.weight",
# "shared_experts.gate_up_proj.weight",
# "shared_experts.down_proj.weight",
# "shared_head.head.weight",
# ]
# combined_words = "|".join(lay_key_words)
# for layername in loaded_params:
# weight = params_dict[layername]
# matches = re.findall(combined_words, layername)
# if matches:
# _weight = torch.zeros_like(weight.data)
# ori_shape =_weight.shape
# ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1])
# weight.data.copy_(_weight)
# weight.data=weight.data.reshape(ori_shape[1],-1)
# return loaded_params
# def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
# """
# Rewrite the weight name to match the format of the original model.
# Add .mtp_block for modules in transformer layer block for spec layer
# and rename shared layer weights to be top level.
# """
# spec_layer_weight_names = [
# "embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head"
# ]
# shared_weight_names = ["embed_tokens"]
# spec_layer_weight = False
# shared_weight = False
# for weight_name in spec_layer_weight_names:
# if weight_name in name:
# spec_layer_weight = True
# if weight_name in shared_weight_names:
# shared_weight = True
# break
# if not spec_layer_weight:
# # treat rest weights as weights for transformer layer block
# name = name.replace(f"model.layers.{spec_layer}.",
# f"model.layers.{spec_layer}.mtp_block.")
# elif shared_weight:
# # treat shared weights as top level weights
# name = name.replace(f"model.layers.{spec_layer}.", "model.")
# return name
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment