Commit 302ef403 (unverified), authored by Mengqing Cao, committed by GitHub
Browse files

[DSA][MLA] Tiny refactor on DeepSeek to make it reusable for different backends (#26656)


Signed-off-by: MengqingCao <cmq0113@163.com>
parent 8865da15
......@@ -587,6 +587,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
prefix: str = "",
use_sparse: bool = False,
indexer: object | None = None,
+**extra_impl_args,
):
super().__init__()
self.num_heads = num_heads
......@@ -639,6 +640,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
v_head_dim=self.v_head_dim,
kv_b_proj=kv_b_proj,
indexer=indexer,
+**extra_impl_args,
)
self.use_direct_call = not current_platform.opaque_attention_op()
......
......@@ -17,9 +17,13 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
-from .deepseek_v2 import DeepseekV2DecoderLayer, get_spec_layer_idx_from_weight_name
+from .deepseek_v2 import (
+DeepseekV2DecoderLayer,
+get_spec_layer_idx_from_weight_name,
+)
from .interfaces import SupportsPP
from .utils import maybe_prefix
......@@ -56,6 +60,8 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.eh_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False)
+self.device = current_platform.device_type
self.is_v32 = hasattr(config, "index_topk")
if self.is_v32:
topk_tokens = config.index_topk
......@@ -63,7 +69,7 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
vllm_config.scheduler_config.max_num_batched_tokens,
topk_tokens,
dtype=torch.int32,
-device="cuda",
+device=self.device,
)
else:
topk_indices_buffer = None
......
......@@ -1165,6 +1165,7 @@ class DeepseekV2Model(nn.Module):
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
+self.device = current_platform.device_type
self.vocab_size = config.vocab_size
self.is_v32 = hasattr(config, "index_topk")
......@@ -1174,7 +1175,7 @@ class DeepseekV2Model(nn.Module):
vllm_config.scheduler_config.max_num_batched_tokens,
topk_tokens,
dtype=torch.int32,
-device="cuda",
+device=self.device,
)
else:
topk_indices_buffer = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment