Unverified Commit 845adb3e authored by XuruiYang's avatar XuruiYang Committed by GitHub
Browse files

[Model] Add LongCat-Flash (#23991)


Signed-off-by: yangxurui <yangxurui@meituan.com>
Co-authored-by: yangxurui <yangxurui@meituan.com>
parent 90b139cf
...@@ -718,7 +718,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -718,7 +718,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
assert isinstance(self.fused_experts, mk.FusedMoEModularKernel) assert isinstance(self.fused_experts, mk.FusedMoEModularKernel)
topk_weights, topk_ids = FusedMoE.select_experts( topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
use_grouped_topk=use_grouped_topk, use_grouped_topk=use_grouped_topk,
...@@ -783,7 +783,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -783,7 +783,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
raise NotImplementedError("EPLB is not supported for mxfp4") raise NotImplementedError("EPLB is not supported for mxfp4")
if self.mxfp4_backend == Mxfp4Backend.MARLIN: if self.mxfp4_backend == Mxfp4Backend.MARLIN:
topk_weights, topk_ids = FusedMoE.select_experts( topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
use_grouped_topk=use_grouped_topk, use_grouped_topk=use_grouped_topk,
...@@ -894,7 +894,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -894,7 +894,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16): or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16):
from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe
topk_weights, topk_ids = FusedMoE.select_experts( topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
use_grouped_topk=use_grouped_topk, use_grouped_topk=use_grouped_topk,
......
...@@ -329,7 +329,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): ...@@ -329,7 +329,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
raise NotImplementedError( raise NotImplementedError(
"EPLB not supported for `QuarkW8A8Fp8MoEMethod` yet.") "EPLB not supported for `QuarkW8A8Fp8MoEMethod` yet.")
topk_weights, topk_ids = FusedMoE.select_experts( topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
use_grouped_topk=use_grouped_topk, use_grouped_topk=use_grouped_topk,
...@@ -531,7 +531,7 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod): ...@@ -531,7 +531,7 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids = FusedMoE.select_experts( topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
use_grouped_topk=use_grouped_topk, use_grouped_topk=use_grouped_topk,
......
...@@ -318,7 +318,7 @@ class RTNMoEMethod(FusedMoEMethodBase): ...@@ -318,7 +318,7 @@ class RTNMoEMethod(FusedMoEMethodBase):
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids = FusedMoE.select_experts( topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
use_grouped_topk=use_grouped_topk, use_grouped_topk=use_grouped_topk,
......
...@@ -292,6 +292,11 @@ def is_layer_skipped( ...@@ -292,6 +292,11 @@ def is_layer_skipped(
f"Detected some but not all shards of {prefix} " f"Detected some but not all shards of {prefix} "
"are quantized. All shards of fused layers " "are quantized. All shards of fused layers "
"to have the same precision.") "to have the same precision.")
elif "experts" in prefix:
return any([
prefix in layer_name for layer_name in ignored_layers
if "experts" in layer_name
])
else: else:
is_skipped = prefix in ignored_layers is_skipped = prefix in ignored_layers
......
This diff is collapsed.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from
# https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/models/deepseek_mtp.py
from collections.abc import Iterable
from typing import Optional
import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.config import VllmConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.utils.int8_utils import (
block_dequant)
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.longcat_flash import FlashConfig
from vllm.sequence import IntermediateTensors
from .deepseek_v2 import DeepseekV2DecoderLayer
from .interfaces import SupportsPP
from .utils import maybe_prefix
class LongCatMultiTokenPredictorLayer(nn.Module):
    """One multi-token-prediction (MTP) layer.

    The layer RMS-normalizes the token embeddings and the previous hidden
    states separately, concatenates them along the last dimension, projects
    the result from ``2 * hidden_size`` to ``hidden_size``, and feeds it
    through a ``DeepseekV2DecoderLayer`` followed by a final RMSNorm.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        prefix: str,
        vllm_config: VllmConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Build the sub-modules of the layer.

        Args:
            config: Supplies ``hidden_size`` and ``rms_norm_eps``.
            prefix: Module prefix forwarded to the decoder block.
            vllm_config: Forwarded to ``DeepseekV2DecoderLayer``.
            quant_config: Optional quantization config for ``eh_proj``.
        """
        super().__init__()
        width = config.hidden_size
        norm_eps = config.rms_norm_eps

        # Registration order is kept stable: enorm, hnorm, eh_proj,
        # mtp_block, final_layernorm.
        self.enorm = RMSNorm(width, eps=norm_eps)
        self.hnorm = RMSNorm(width, eps=norm_eps)
        self.eh_proj = ReplicatedLinear(
            2 * width,
            width,
            bias=False,
            quant_config=quant_config,
            prefix="eh_proj",
        )
        self.mtp_block = DeepseekV2DecoderLayer(vllm_config, prefix)
        self.final_layernorm = RMSNorm(width, eps=norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
        spec_step_index: int = 0,
    ) -> torch.Tensor:
        """Run the MTP layer.

        ``input_ids`` and ``spec_step_index`` are accepted but not read
        here; ``inputs_embeds`` must be provided by the caller.

        Returns:
            The hidden states after the final layer norm.
        """
        assert inputs_embeds is not None

        normed_embeds = self.enorm(inputs_embeds)
        normed_prev = self.hnorm(previous_hidden_states)

        # Fuse both streams along the feature axis before projecting.
        fused = torch.cat([normed_embeds, normed_prev], dim=-1)
        projected, _ = self.eh_proj(fused)

        block_out, residual = self.mtp_block(
            positions=positions,
            hidden_states=projected,
            residual=None,
        )
        normed_out, _ = self.final_layernorm(block_out, residual)
        return normed_out
class LongCatMultiTokenPredictor(nn.Module):
    """Container for the MTP layers plus the token embedding table.

    Layers are stored in a ``ModuleDict`` keyed by their stringified layer
    index, starting at ``2 * config.num_hidden_layers``.
    """

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        """Create the MTP layers and the embedding.

        Note that ``intermediate_size`` on the hf_config held by
        ``vllm_config`` is overwritten with the value from ``FlashConfig``.
        """
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        config = FlashConfig(**hf_config.__dict__)
        vllm_config.model_config.hf_config.intermediate_size = (
            config.intermediate_size)

        self.mtp_start_layer_idx = config.num_hidden_layers * 2
        self.num_mtp_layers = 1

        first_idx = self.mtp_start_layer_idx
        last_idx = self.mtp_start_layer_idx + self.num_mtp_layers
        mtp_layers = {}
        for layer_idx in range(first_idx, last_idx):
            mtp_layers[str(layer_idx)] = LongCatMultiTokenPredictorLayer(
                config,
                prefix=f"{prefix}.layers.{layer_idx}",
                vllm_config=vllm_config,
                quant_config=quant_config,
            )
        self.layers = torch.nn.ModuleDict(mtp_layers)

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        """Dispatch to the MTP layer selected by ``spec_step_idx``.

        When ``inputs_embeds`` is ``None`` the embeddings are computed from
        ``input_ids``. The step index is wrapped modulo the number of MTP
        layers before picking the layer.
        """
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        current_step_idx = spec_step_idx % self.num_mtp_layers
        layer_key = str(self.mtp_start_layer_idx + current_step_idx)
        mtp_layer = self.layers[layer_key]
        return mtp_layer(
            input_ids,
            positions,
            previous_hidden_states,
            inputs_embeds,
            current_step_idx,
        )
class LongCatFlashMTP(nn.Module, SupportsPP):
    """LongCat-Flash multi-token-prediction (MTP) draft model.

    Bundles a ``LongCatMultiTokenPredictor`` with an LM head and a logits
    processor, and loads the ``model.mtp.*`` weights of a checkpoint into
    that module layout.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the predictor, LM head and logits processor.

        Note: ``n_routed_experts`` on the hf_config held by ``vllm_config``
        is set to ``None`` as a side effect.
        """
        super().__init__()
        # LongCat MTP without MoE layers
        vllm_config.model_config.hf_config.n_routed_experts = None
        self.config = FlashConfig(
            **vllm_config.model_config.hf_config.__dict__)
        # Quantization is disabled for this module when "mtp" appears in the
        # config's ``disable_quant_module`` list (treated as [] if absent).
        self.quant_config = None if "mtp" in getattr(
            self.config, "disable_quant_module",
            []) else vllm_config.quant_config

        self.model = LongCatMultiTokenPredictor(vllm_config=vllm_config,
                                                quant_config=self.quant_config,
                                                prefix=maybe_prefix(
                                                    prefix, "model"))
        self.lm_head = ParallelLMHead(
            self.config.vocab_size,
            self.config.hidden_size,
            quant_config=self.quant_config,
        )
        self.logits_processor = LogitsProcessor(self.config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        """Run the MTP predictor.

        ``intermediate_tensors`` is accepted but not used here.
        """
        hidden_states = self.model(input_ids, positions, hidden_states,
                                   inputs_embeds, spec_step_idx)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        spec_step_idx: int = 0,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to logits through the LM head.

        ``spec_step_idx`` is accepted but not used here.
        """
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load the MTP weights from ``weights``.

        Only weights whose name contains ``model.mtp`` are considered; they
        are renamed to this module's parameter names and loaded, with
        gate/up and q_a/kv_a projections routed into their stacked
        parameters. After loading, ``w_kc``/``w_vc`` are derived from
        ``kv_b_proj`` and the q/kv LoRA layer norms are rescaled when the
        config asks for it.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The set of parameter names that were loaded.
        """
        # (stacked parameter, checkpoint shard name, shard id)
        stacked_params_mapping = [
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
            ("fused_qkv_a_proj", "q_a_proj", 0),
            ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
        ]

        # Checkpoint name -> intermediate name understood by
        # ``_rewrite_spec_layer_name``.
        new_to_old_names_mapping = {
            "model.mtp.embed_tokens.weight":
            "model.layers.0.embed_tokens.weight",
            "model.mtp.layers.0.eh_proj.weight": "eh_proj.weight",
            "model.mtp.layers.0.eh_proj.weight_scale_inv":
            "eh_proj.weight_scale_inv",
            "model.mtp.layers.0.enorm.m.weight": "enorm.weight",
            "model.mtp.layers.0.hnorm.m.weight": "hnorm.weight",
            "model.mtp.layers.0.input_layernorm.weight":
            "model.layers.0.input_layernorm.weight",
            "model.mtp.layers.0.post_attention_layernorm.weight":
            "model.layers.0.post_attention_layernorm.weight",
            "model.mtp.layers.0.self_attn.kv_a_layernorm.weight":
            "model.layers.0.self_attn.kv_a_layernorm.weight",
            "model.mtp.layers.0.self_attn.kv_a_proj_with_mqa.weight":
            "model.layers.0.self_attn.kv_a_proj_with_mqa.weight",
            "model.mtp.layers.0.self_attn.kv_a_proj_with_mqa.weight_scale_inv":
            "model.layers.0.self_attn.kv_a_proj_with_mqa.weight_scale_inv",
            "model.mtp.layers.0.self_attn.kv_b_proj.weight":
            "model.layers.0.self_attn.kv_b_proj.weight",
            "model.mtp.layers.0.self_attn.kv_b_proj.weight_scale_inv":
            "model.layers.0.self_attn.kv_b_proj.weight_scale_inv",
            "model.mtp.layers.0.self_attn.o_proj.weight":
            "model.layers.0.self_attn.o_proj.weight",
            "model.mtp.layers.0.self_attn.o_proj.weight_scale_inv":
            "model.layers.0.self_attn.o_proj.weight_scale_inv",
            "model.mtp.layers.0.self_attn.q_a_layernorm.weight":
            "model.layers.0.self_attn.q_a_layernorm.weight",
            "model.mtp.layers.0.self_attn.q_a_proj.weight":
            "model.layers.0.self_attn.q_a_proj.weight",
            "model.mtp.layers.0.self_attn.q_a_proj.weight_scale_inv":
            "model.layers.0.self_attn.q_a_proj.weight_scale_inv",
            "model.mtp.layers.0.self_attn.q_b_proj.weight":
            "model.layers.0.self_attn.q_b_proj.weight",
            "model.mtp.layers.0.self_attn.q_b_proj.weight_scale_inv":
            "model.layers.0.self_attn.q_b_proj.weight_scale_inv",
            "model.mtp.layers.0.transformer_layer.mlp.down_proj.weight":
            "model.layers.0.mlp.down_proj.weight",
            "model.mtp.layers.0.transformer_layer.mlp.down_proj.weight_scale_inv":
            "model.layers.0.mlp.down_proj.weight_scale_inv",
            "model.mtp.layers.0.transformer_layer.mlp.gate_proj.weight":
            "model.layers.0.mlp.gate_proj.weight",
            "model.mtp.layers.0.transformer_layer.mlp.gate_proj.weight_scale_inv":
            "model.layers.0.mlp.gate_proj.weight_scale_inv",
            "model.mtp.layers.0.transformer_layer.mlp.up_proj.weight":
            "model.layers.0.mlp.up_proj.weight",
            "model.mtp.layers.0.transformer_layer.mlp.up_proj.weight_scale_inv":
            "model.layers.0.mlp.up_proj.weight_scale_inv",
            "model.mtp.norm.weight": "final_layernorm.weight",
        }

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            spec_layer = self.get_spec_layer_idx_from_weight_name(
                self.config, name)
            if spec_layer is None:
                # Not an MTP weight.
                continue
            name = self._rewrite_spec_layer_name(spec_layer, name,
                                                 new_to_old_names_mapping)
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                # Skip mapping entries that do not apply to this weight.
                if weight_name not in name:
                    continue
                # Expert weights without a matching parameter are not
                # stacked here; skip BEFORE rewriting the name.
                if (("mlp.experts." in name) and name not in params_dict):
                    continue
                name_mapped = name.replace(weight_name, param_name)
                # QKV fusion is optional, fall back to normal weight
                # loading if it's not enabled. The original ``name`` must
                # stay untouched in that case, otherwise the fallback in
                # the ``else`` branch below would look up the fused name
                # and fail with a KeyError.
                if ((param_name == "fused_qkv_a_proj")
                        and name_mapped not in params_dict):
                    continue
                name = name_mapped
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # No stacked mapping consumed the weight: load it as-is.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # According to the DeepSeek-V3 Technical Report, MTP modules
                # share the embedding layer. We only load the first weights.
                if (spec_layer != self.model.mtp_start_layer_idx
                        and ".layers" not in name):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)

        # Post-processing on the attention module of the MTP block.
        spec_layer_id = self.config.num_hidden_layers * 2
        self_attn = self.model.layers[str(spec_layer_id)].mtp_block.self_attn

        # Use kv_b_proj's weight directly unless it is stored as fp8 with a
        # known block size, in which case it is block-dequantized first.
        w = self_attn.kv_b_proj.weight
        if hasattr(self.quant_config,
                   "weight_block_size") and w.dtype in (
                       torch.float8_e4m3fn,
                       torch.float8_e4m3fnuz,
                   ):
            weight_block_size = self.quant_config.weight_block_size
            if weight_block_size is not None:
                dtype = torch.get_default_dtype()
                w = block_dequant(self_attn.kv_b_proj.weight,
                                  self_attn.kv_b_proj.weight_scale_inv,
                                  weight_block_size).to(dtype)

        # Split the (possibly dequantized) weight into its qk_nope and v
        # parts and cache them on the attention module.
        w_kc, w_vc = w.unflatten(
            0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)).split(
                [self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
        self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
        self_attn.w_vc = w_vc.contiguous().transpose(1, 2)

        # Optional in-place rescaling of the LoRA layer-norm weights by
        # sqrt(hidden_size / lora_rank).
        if self.config.mla_scale_q_lora:
            self_attn.q_a_layernorm.weight.data *= (
                self.config.hidden_size / self.config.q_lora_rank)**0.5
        if self.config.mla_scale_kv_lora:
            self_attn.kv_a_layernorm.weight.data *= (
                self.config.hidden_size / self.config.kv_lora_rank)**0.5
        return loaded_params

    def _rewrite_spec_layer_name(self, spec_layer: int, name: str,
                                 new_to_old_names_mapping: dict) -> str:
        """
        Rewrite the weight name to match the format of the original model.
        Add .mtp_block for modules in transformer layer block for spec layer
        and rename shared layer weights to be top level.

        Args:
            spec_layer: Index of the MTP layer inside ``model.layers``.
            name: Checkpoint weight name.
            new_to_old_names_mapping: Exact-match renames applied first.

        Returns:
            The rewritten parameter name.
        """
        if name in new_to_old_names_mapping:
            name = new_to_old_names_mapping[name]

        # Bare per-layer names get the spec layer prefix.
        if name.startswith(("enorm", "hnorm", "eh_proj", "final_layernorm")):
            name = "model.layers." + str(spec_layer) + "." + name

        spec_layer_weight_names = [
            "embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head"
        ]
        shared_weight_names = ["embed_tokens"]

        # First spec-layer marker contained in the name, if any.
        matched = next(
            (marker for marker in spec_layer_weight_names if marker in name),
            None)
        if matched is None:
            # treat rest weights as weights for transformer layer block
            name = name.replace("model.layers.0.",
                                f"model.layers.{spec_layer}.mtp_block.")
        elif matched in shared_weight_names:
            # treat shared weights as top level weights
            name = name.replace("model.layers.0.", "model.")
        return name

    def get_spec_layer_idx_from_weight_name(self, config: PretrainedConfig,
                                            weight_name: str) -> Optional[int]:
        """Return the MTP layer index for an MTP weight, else ``None``.

        A weight counts as an MTP weight when its name contains
        ``model.mtp``; the index is ``2 * config.num_hidden_layers``.
        """
        if "model.mtp" in weight_name:
            return config.num_hidden_layers * 2
        return None
...@@ -109,6 +109,7 @@ _TEXT_GENERATION_MODELS = { ...@@ -109,6 +109,7 @@ _TEXT_GENERATION_MODELS = {
"Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"), # noqa: E501 "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"), # noqa: E501
# For decapoda-research/llama-* # For decapoda-research/llama-*
"LLaMAForCausalLM": ("llama", "LlamaForCausalLM"), "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
"LongcatFlashForCausalLM": ("longcat_flash", "LongcatFlashForCausalLM"),
"MambaForCausalLM": ("mamba", "MambaForCausalLM"), "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
"FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"), "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
"FalconH1ForCausalLM":("falcon_h1", "FalconH1ForCausalLM"), "FalconH1ForCausalLM":("falcon_h1", "FalconH1ForCausalLM"),
...@@ -287,6 +288,7 @@ _SPECULATIVE_DECODING_MODELS = { ...@@ -287,6 +288,7 @@ _SPECULATIVE_DECODING_MODELS = {
"EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"), "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
"DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"), "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
"ErnieMTPModel": ("ernie_mtp", "ErnieMTP"), "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
"LongCatFlashMTPModel": ("longcat_flash_mtp", "LongCatFlashMTP"),
"Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"), "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
"MedusaModel": ("medusa", "Medusa"), "MedusaModel": ("medusa", "Medusa"),
"Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"), "Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"),
......
...@@ -691,14 +691,14 @@ def maybe_prefix(prefix: str, name: str) -> str: ...@@ -691,14 +691,14 @@ def maybe_prefix(prefix: str, name: str) -> str:
return name if not prefix else f"{prefix}.{name}" return name if not prefix else f"{prefix}.{name}"
def extract_layer_index(layer_name: str) -> int: def extract_layer_index(layer_name: str, num_attn_module: int = 1) -> int:
""" """
Extract the layer index from the module name. Extract the layer index from the module name.
Examples: Examples:
- "encoder.layers.0" -> 0 - "encoder.layers.0" -> 0
- "encoder.layers.1.self_attn" -> 1 - "encoder.layers.1.self_attn" -> 1
- "2.self_attn" -> 2 - "2.self_attn" -> 2
- "model.encoder.layers.0.sub.1" -> ValueError - "model.encoder.layers.0.sub.1" -> ValueError if num_attn_module == 1
""" """
subnames = layer_name.split(".") subnames = layer_name.split(".")
int_vals: list[int] = [] int_vals: list[int] = []
...@@ -707,9 +707,17 @@ def extract_layer_index(layer_name: str) -> int: ...@@ -707,9 +707,17 @@ def extract_layer_index(layer_name: str) -> int:
int_vals.append(int(subname)) int_vals.append(int(subname))
except ValueError: except ValueError:
continue continue
if num_attn_module == 1 or "attn" not in layer_name:
assert len(int_vals) == 1, (f"layer name {layer_name} should" assert len(int_vals) == 1, (f"layer name {layer_name} should"
" only contain one integer") " only contain one integer")
return int_vals[0] return int_vals[0]
else:
assert len(int_vals) <= 2, (f"layer name {layer_name} should"
" contain most two integers")
layer_index = int_vals[0] * num_attn_module + int_vals[1] if len(
int_vals) == 2 else int_vals[0]
return layer_index
def cast_overflow_tensors( def cast_overflow_tensors(
......
...@@ -169,7 +169,6 @@ class EagleProposer: ...@@ -169,7 +169,6 @@ class EagleProposer:
target_hidden_states = self.model.combine_hidden_states( target_hidden_states = self.model.combine_hidden_states(
target_hidden_states) target_hidden_states)
assert target_hidden_states.shape[-1] == self.hidden_size assert target_hidden_states.shape[-1] == self.hidden_size
# Shift the input ids by one token. # Shift the input ids by one token.
# E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3] # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
self.input_ids[:num_tokens - 1] = target_token_ids[1:] self.input_ids[:num_tokens - 1] = target_token_ids[1:]
...@@ -223,7 +222,8 @@ class EagleProposer: ...@@ -223,7 +222,8 @@ class EagleProposer:
hidden_states=self.hidden_states[:num_input_tokens], hidden_states=self.hidden_states[:num_input_tokens],
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
) )
if self.method in ("deepseek_mtp", "ernie_mtp", "qwen3_next_mtp"): if self.method in ("deepseek_mtp", "ernie_mtp", "qwen3_next_mtp",
"longcat_flash_mtp"):
last_hidden_states = ret_hidden_states last_hidden_states = ret_hidden_states
hidden_states = last_hidden_states hidden_states = last_hidden_states
else: else:
...@@ -237,6 +237,9 @@ class EagleProposer: ...@@ -237,6 +237,9 @@ class EagleProposer:
return draft_token_ids.view(-1, 1) return draft_token_ids.view(-1, 1)
positions = target_positions[last_token_indices] positions = target_positions[last_token_indices]
if self.method in ("deepseek_mtp", "ernie_mtp", "longcat_flash_mtp"):
hidden_states = self.hidden_states[last_token_indices]
else:
hidden_states = hidden_states[last_token_indices] hidden_states = hidden_states[last_token_indices]
if isinstance(attn_metadata, TreeAttentionMetadata): if isinstance(attn_metadata, TreeAttentionMetadata):
...@@ -350,7 +353,7 @@ class EagleProposer: ...@@ -350,7 +353,7 @@ class EagleProposer:
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
) )
if self.method in ("deepseek_mtp", "ernie_mtp", if self.method in ("deepseek_mtp", "ernie_mtp",
"qwen3_next_mtp"): "qwen3_next_mtp", "longcat_flash_mtp"):
last_hidden_states = ret_hidden_states last_hidden_states = ret_hidden_states
hidden_states = ret_hidden_states hidden_states = ret_hidden_states
else: else:
......
...@@ -3840,9 +3840,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -3840,9 +3840,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
target_layer_name) target_layer_name)
kv_caches[layer_name] = kv_caches[target_layer_name] kv_caches[layer_name] = kv_caches[target_layer_name]
num_attn_module = 2 \
if self.model_config.hf_config.model_type == "longcat_flash" else 1
bind_kv_cache(kv_caches, bind_kv_cache(kv_caches,
self.compilation_config.static_forward_context, self.compilation_config.static_forward_context,
self.kv_caches) self.kv_caches, num_attn_module)
return kv_caches return kv_caches
def maybe_add_kv_sharing_layers_to_kv_cache_groups( def maybe_add_kv_sharing_layers_to_kv_cache_groups(
......
...@@ -266,6 +266,7 @@ def bind_kv_cache( ...@@ -266,6 +266,7 @@ def bind_kv_cache(
kv_caches: dict[str, torch.Tensor], kv_caches: dict[str, torch.Tensor],
forward_context: dict[str, "Attention"], forward_context: dict[str, "Attention"],
runner_kv_caches: list[torch.Tensor], runner_kv_caches: list[torch.Tensor],
num_attn_module: Optional[int] = 1,
) -> None: ) -> None:
""" """
Bind the allocated KV cache to both ModelRunner and forward context so Bind the allocated KV cache to both ModelRunner and forward context so
...@@ -289,7 +290,8 @@ def bind_kv_cache( ...@@ -289,7 +290,8 @@ def bind_kv_cache(
# Convert kv_caches dict to a list of tensors in the order of layer_index. # Convert kv_caches dict to a list of tensors in the order of layer_index.
index2name = defaultdict(list) index2name = defaultdict(list)
for layer_name in kv_caches: for layer_name in kv_caches:
index2name[extract_layer_index(layer_name)].append(layer_name) index2name[extract_layer_index(layer_name,
num_attn_module)].append(layer_name)
for layer_index in sorted(index2name.keys()): for layer_index in sorted(index2name.keys()):
layer_names = index2name[layer_index] layer_names = index2name[layer_index]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment