Unverified Commit 302ef403 authored by Mengqing Cao's avatar Mengqing Cao Committed by GitHub
Browse files

[DSA][MLA] Tiny refactor on DeepSeek to make it reusable for different backends (#26656)


Signed-off-by: default avatarMengqingCao <cmq0113@163.com>
parent 8865da15
...@@ -587,6 +587,7 @@ class MLAAttention(nn.Module, AttentionLayerBase): ...@@ -587,6 +587,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
prefix: str = "", prefix: str = "",
use_sparse: bool = False, use_sparse: bool = False,
indexer: object | None = None, indexer: object | None = None,
**extra_impl_args,
): ):
super().__init__() super().__init__()
self.num_heads = num_heads self.num_heads = num_heads
...@@ -639,6 +640,7 @@ class MLAAttention(nn.Module, AttentionLayerBase): ...@@ -639,6 +640,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
v_head_dim=self.v_head_dim, v_head_dim=self.v_head_dim,
kv_b_proj=kv_b_proj, kv_b_proj=kv_b_proj,
indexer=indexer, indexer=indexer,
**extra_impl_args,
) )
self.use_direct_call = not current_platform.opaque_attention_op() self.use_direct_call = not current_platform.opaque_attention_op()
......
...@@ -17,9 +17,13 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -17,9 +17,13 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, VocabParallelEmbedding,
) )
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .deepseek_v2 import DeepseekV2DecoderLayer, get_spec_layer_idx_from_weight_name from .deepseek_v2 import (
DeepseekV2DecoderLayer,
get_spec_layer_idx_from_weight_name,
)
from .interfaces import SupportsPP from .interfaces import SupportsPP
from .utils import maybe_prefix from .utils import maybe_prefix
...@@ -56,6 +60,8 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module): ...@@ -56,6 +60,8 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.eh_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False) self.eh_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False)
self.device = current_platform.device_type
self.is_v32 = hasattr(config, "index_topk") self.is_v32 = hasattr(config, "index_topk")
if self.is_v32: if self.is_v32:
topk_tokens = config.index_topk topk_tokens = config.index_topk
...@@ -63,7 +69,7 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module): ...@@ -63,7 +69,7 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
vllm_config.scheduler_config.max_num_batched_tokens, vllm_config.scheduler_config.max_num_batched_tokens,
topk_tokens, topk_tokens,
dtype=torch.int32, dtype=torch.int32,
device="cuda", device=self.device,
) )
else: else:
topk_indices_buffer = None topk_indices_buffer = None
......
...@@ -1165,6 +1165,7 @@ class DeepseekV2Model(nn.Module): ...@@ -1165,6 +1165,7 @@ class DeepseekV2Model(nn.Module):
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
self.config = config self.config = config
self.device = current_platform.device_type
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.is_v32 = hasattr(config, "index_topk") self.is_v32 = hasattr(config, "index_topk")
...@@ -1174,7 +1175,7 @@ class DeepseekV2Model(nn.Module): ...@@ -1174,7 +1175,7 @@ class DeepseekV2Model(nn.Module):
vllm_config.scheduler_config.max_num_batched_tokens, vllm_config.scheduler_config.max_num_batched_tokens,
topk_tokens, topk_tokens,
dtype=torch.int32, dtype=torch.int32,
device="cuda", device=self.device,
) )
else: else:
topk_indices_buffer = None topk_indices_buffer = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment