Unverified Commit fa7e254a authored by Yongye Zhu's avatar Yongye Zhu Committed by GitHub
Browse files
parent e23cacda
......@@ -35,7 +35,6 @@ def test_chunked_local_attention_possible_cached_prefix():
head_size=1,
dtype=torch.float32,
attention_chunk_size=4,
use_mla=False,
)
block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
......@@ -100,7 +99,6 @@ def test_sliding_window_possible_cached_prefix():
head_size=1,
dtype=torch.float32,
sliding_window=4,
use_mla=False,
)
block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
......@@ -165,7 +163,6 @@ def test_chunked_local_attention_remove_skipped_blocks():
head_size=1,
dtype=torch.float32,
attention_chunk_size=4,
use_mla=False,
)
block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True)
......@@ -217,7 +214,6 @@ def test_sliding_window_remove_skipped_blocks():
head_size=1,
dtype=torch.float32,
sliding_window=4,
use_mla=False,
)
block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True)
......@@ -285,7 +281,6 @@ def test_get_num_blocks_to_allocate():
head_size=1,
dtype=torch.float32,
sliding_window=4, # Placeholder value, not related to test result
use_mla=False,
)
block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
......@@ -308,7 +303,6 @@ def test_chunked_local_attention_get_num_blocks_to_allocate():
head_size=1,
dtype=torch.float32,
attention_chunk_size=4, # Placeholder value, not related to test result
use_mla=False,
)
block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
......
......@@ -836,8 +836,7 @@ def test_engine_core_proc_instantiation_cuda_empty(
mock_spec = FullAttentionSpec(block_size=16,
num_kv_heads=1,
head_size=64,
dtype=torch.float16,
use_mla=False)
dtype=torch.float16)
mock_executor.get_kv_cache_specs.return_value = [{
"default": mock_spec
......
......@@ -39,7 +39,6 @@ def initialize_kv_cache(runner: GPUModelRunner):
runner.parallel_config),
head_size=runner.model_config.get_head_size(),
dtype=runner.kv_cache_dtype,
use_mla=False,
)
tensor_size = attn_spec.page_size_bytes * NUM_BLOCKS
kv_cache_config = KVCacheConfig(
......
......@@ -1678,6 +1678,15 @@ def cp_gather_cache(src_cache: torch.Tensor,
cu_seq_lens, batch_size, seq_starts)
def indexer_k_quant_and_cache(k: torch.Tensor, kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
quant_block_size: int,
kv_cache_dtype: str) -> None:
torch.ops._C_cache_ops.indexer_k_quant_and_cache(k, kv_cache, slot_mapping,
quant_block_size,
kv_cache_dtype)
def get_device_attribute(attribute: int, device: int) -> int:
return torch.ops._C_cuda_utils.get_device_attribute(attribute, device)
......
......@@ -70,6 +70,7 @@ class AttentionBackend(ABC):
block_size: int,
num_kv_heads: int,
head_size: int,
cache_dtype_str: str = "auto",
) -> Tuple[int, ...]:
raise NotImplementedError
......
......@@ -94,6 +94,7 @@ class Attention(nn.Module, AttentionLayerBase):
logits_soft_cap: Optional[float] = None,
per_layer_sliding_window: Optional[int] = None,
use_mla: bool = False,
use_sparse: bool = False,
prefix: str = "",
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
......@@ -154,6 +155,7 @@ class Attention(nn.Module, AttentionLayerBase):
self._o_scale_float: Optional[float] = None
self.use_mla = use_mla
self.use_sparse = use_sparse
self.num_heads = num_heads
self.head_size = head_size
self.num_kv_heads = num_kv_heads
......@@ -186,7 +188,8 @@ class Attention(nn.Module, AttentionLayerBase):
kv_cache_dtype,
block_size,
use_mla=use_mla,
has_sink=self.has_sink)
has_sink=self.has_sink,
use_sparse=use_sparse)
else:
self.attn_backend = attn_backend
......
......@@ -138,3 +138,208 @@ def cp_lse_ag_out_rs(cp_attn_out: torch.Tensor,
out, _ = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
out = cp_group.reduce_scatter(out, dim=1)
return out
@triton.jit
def _pack_seq_kernel(
x_ptr, # [N, D]
out_ptr, # [B, Lmax, D]
lengths_ptr, # *i32, [B]
N: tl.constexpr,
D: tl.constexpr,
Lmax: tl.constexpr,
PAD_VALUE: tl.constexpr,
BLOCK_T: tl.constexpr, # timesteps per program
BLOCK_D: tl.constexpr # features per program
):
pid_b = tl.program_id(0) # batch id
pid_t = tl.program_id(1) # block over time dimension
pid_d = tl.program_id(2) # block over feature dimension
off_t = pid_t * BLOCK_T + tl.arange(0, BLOCK_T) # [BLOCK_T]
off_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D) # [BLOCK_D]
# Compute start index and sequence length from cumulative lengths
in_start = 0
for i in range(pid_b):
in_start += tl.load(lengths_ptr + i)
seq_len = tl.load(lengths_ptr + pid_b)
# valid time positions for this block
t_mask = off_t < Lmax
# compute input row indices for valid (b, t)
in_row = in_start + off_t
valid_row = (off_t < seq_len) & t_mask
# Pointers
# x_ptr: row-major [N, D]
x_row_ptr = x_ptr + in_row[:, None] * D + off_d[None, :]
# out_ptr: row-major [B, Lmax, D]
out_row_ptr = out_ptr + (pid_b * Lmax + off_t)[:,
None] * D + off_d[None, :]
# Initialize with PAD (cast will occur as needed based on out_ptr dtype)
d_mask = off_d[None, :] < D
pad_vals = tl.full([BLOCK_T, BLOCK_D], PAD_VALUE, tl.float32)
tl.store(out_row_ptr, pad_vals, mask=t_mask[:, None] & d_mask)
# Load & write only where within seq_len
x_vals = tl.load(x_row_ptr, mask=valid_row[:, None] & d_mask)
tl.store(out_row_ptr, x_vals, mask=valid_row[:, None] & d_mask)
def pack_seq_triton(x: torch.Tensor,
lengths: torch.Tensor,
pad_value: float = -float('inf'),
block_t: int = 64,
block_d: int = 64) -> torch.Tensor:
"""
Pack sequences of different lengths into a batched tensor.
Args:
x: [N, ...] - input tensor where N is total number of tokens
lengths: [B] - sequence lengths for each batch
pad_value: value to use for padding
block_t: block size for time dimension
block_d: block size for feature dimension
Returns:
packed: [B, Lmax, ...] - packed tensor
"""
# Handle multi-dimensional input by reshaping to (N, -1)
original_shape = x.shape
if len(original_shape) > 2:
N = original_shape[0]
x_reshaped = x.reshape(N, -1)
D = x_reshaped.shape[1]
else:
N, D = x.shape
x_reshaped = x
B = lengths.numel()
Lmax = int(lengths.max().item())
# Starts are computed inside the kernel from lengths
out = torch.empty((B, Lmax, D), device=x.device, dtype=x.dtype)
grid = (B, triton.cdiv(Lmax, block_t), triton.cdiv(D, block_d))
_pack_seq_kernel[grid](x_reshaped,
out,
lengths.int(),
N,
D,
Lmax,
PAD_VALUE=float(pad_value),
BLOCK_T=block_t,
BLOCK_D=block_d,
num_warps=4,
num_stages=2)
# Reshape output back to original dimensions (except first dimension)
if len(original_shape) > 2:
output_shape = (B, Lmax) + original_shape[1:]
out = out.reshape(output_shape)
return out
@triton.jit
def _unpack_seq_triton_kernel(
packed_ptr, # [B, Lmax, D]
out_ptr, # [N, D]
lengths_ptr, # *i32, [B]
B: tl.constexpr,
Lmax: tl.constexpr,
D: tl.constexpr,
BLOCK_T: tl.constexpr, # timesteps per program
BLOCK_D: tl.constexpr # features per program
):
pid_b = tl.program_id(0) # batch id
pid_t = tl.program_id(1) # block over time dimension
pid_d = tl.program_id(2) # block over feature dimension
off_t = pid_t * BLOCK_T + tl.arange(0, BLOCK_T) # [BLOCK_T]
off_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D) # [BLOCK_D]
# bounds: compute start from cumulative lengths
in_start = 0
for i in range(pid_b):
in_start += tl.load(lengths_ptr + i)
seq_len = tl.load(lengths_ptr + pid_b)
# valid time positions for this block
t_mask = off_t < Lmax
valid_row = (off_t < seq_len) & t_mask
# compute output row indices for valid (b, t)
out_row = in_start + off_t
# Pointers
# packed_ptr: row-major [B, Lmax, D]
packed_row_ptr = packed_ptr + (pid_b * Lmax +
off_t)[:, None] * D + off_d[None, :]
# out_ptr: row-major [N, D]
out_row_ptr = out_ptr + out_row[:, None] * D + off_d[None, :]
# Load from packed tensor and store to output
d_mask = off_d[None, :] < D
packed_vals = tl.load(packed_row_ptr, mask=valid_row[:, None] & d_mask)
tl.store(out_row_ptr, packed_vals, mask=valid_row[:, None] & d_mask)
def unpack_seq_triton(packed_tensor: torch.Tensor,
lengths: torch.Tensor,
block_t: int = 64,
block_d: int = 64) -> torch.Tensor:
"""
Unpack a packed decode query tensor back to the original format.
Efficient Triton implementation.
Args:
packed_tensor: [B, Lmax, ...] - packed tensor from pack_seq_triton
lengths: [B] - sequence lengths for each batch
block_t: block size for time dimension
block_d: block size for feature dimension
Returns:
unpacked_tensor: [N, ...] where N = sum(lengths)
"""
# Handle multi-dimensional input by reshaping to (B, Lmax, -1)
original_shape = packed_tensor.shape
if len(original_shape) > 3:
B, Lmax = original_shape[:2]
packed_reshaped = packed_tensor.reshape(B, Lmax, -1)
D = packed_reshaped.shape[2]
else:
B, Lmax, D = packed_tensor.shape
packed_reshaped = packed_tensor
# Calculate total number of elements
N = int(lengths.sum().item())
out = torch.empty((N, D),
device=packed_tensor.device,
dtype=packed_tensor.dtype)
grid = (B, triton.cdiv(Lmax, block_t), triton.cdiv(D, block_d))
_unpack_seq_triton_kernel[grid](packed_reshaped,
out,
lengths.int(),
B,
Lmax,
D,
BLOCK_T=block_t,
BLOCK_D=block_d,
num_warps=4,
num_stages=2)
# Reshape output back to original dimensions (except first dimension)
if len(original_shape) > 3:
output_shape = (N, ) + original_shape[2:]
out = out.reshape(output_shape)
return out
......@@ -19,6 +19,15 @@ if current_platform.is_cuda():
else:
_flashmla_C_AVAILABLE = False
if current_platform.is_cuda():
try:
import vllm._flashmla_extension_C # noqa: F401
_flashmla_extension_C_AVAILABLE = True
except ImportError:
_flashmla_extension_C_AVAILABLE = False
else:
_flashmla_extension_C_AVAILABLE = False
def is_flashmla_supported() -> Tuple[bool, Optional[str]]:
"""
......@@ -37,24 +46,34 @@ def is_flashmla_supported() -> Tuple[bool, Optional[str]]:
def get_mla_metadata(
cache_seqlens: torch.Tensor,
num_heads_per_head_k: int,
num_heads_k: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
cache_seqlens: torch.Tensor,
num_q_tokens_per_head_k: int,
num_heads_k: int,
num_heads_q: Optional[int] = None,
is_fp8_kvcache: bool = False,
topk: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
cache_seqlens: (batch_size), dtype torch.int32.
num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k.
num_heads_k: num_heads_k.
Return:
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
dtype torch.int32.
num_splits: (batch_size + 1), dtype torch.int32.
- cache_seqlens: (batch_size), dtype torch.int32.
- num_q_tokens_per_head_k:
Equals to num_q_tokens_per_q_seq * num_heads_q // num_heads_k.
- num_heads_k: The number of k heads.
- num_heads_q:
The number of q heads.
This argument is optional when sparse attention is not enabled
- is_fp8_kvcache: Whether the k_cache and v_cache are in fp8 format.
- topk: If not None, sparse attention will be enabled,
and only tokens in the `indices` array
passed to `flash_mla_with_kvcache_sm90` will be attended to.
Returns:
- tile_scheduler_metadata:
(num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
- num_splits: (batch_size + 1), dtype torch.int32.
"""
return torch.ops._flashmla_C.get_mla_metadata(cache_seqlens,
num_heads_per_head_k,
num_heads_k)
return torch.ops._flashmla_C.get_mla_decoding_metadata(
cache_seqlens, num_q_tokens_per_head_k, num_heads_k, num_heads_q,
is_fp8_kvcache, topk)
def flash_mla_with_kvcache(
......@@ -69,45 +88,95 @@ def flash_mla_with_kvcache(
causal: bool = False,
descale_q: Optional[torch.Tensor] = None,
descale_k: Optional[torch.Tensor] = None,
is_fp8_kvcache: bool = False,
indices: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
q: (batch_size, seq_len_q, num_heads_q, head_dim).
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
cache_seqlens: (batch_size), torch.int32.
head_dim_v: Head_dim of v.
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
torch.int32, return by get_mla_metadata.
num_splits: (batch_size + 1), torch.int32, return by get_mla_metadata.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(head_dim).
causal: bool. Whether to apply causal attention mask.
descale_q: (batch_size), torch.float32. Descaling factors for Q.
descale_k: (batch_size), torch.float32. Descaling factors for K.
Return:
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
- q: (batch_size, seq_len_q, num_heads_q, head_dim).
- k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
- block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
- cache_seqlens: (batch_size), torch.int32.
- head_dim_v: Head dimension of v.
- tile_scheduler_metadata:
(num_sm_parts, TileSchedulerMetaDataSize), torch.int32,
returned by get_mla_metadata.
- num_splits:
(batch_size + 1), torch.int32, returned by get_mla_metadata.
- softmax_scale: float.
The scale of QK^T before applying softmax.
Default to 1 / sqrt(head_dim).
- causal: bool. Whether to apply causal attention mask.
- descale_q: (batch_size),
torch.float32. Descaling factors for Q, used for fp8 quantization.
- descale_k: (batch_size),
torch.float32. Descaling factors for K, used for fp8 quantization.
- is_fp8_kvcache: bool.
Whether the k_cache and v_cache are in fp8 format.
For the format of FP8 KV cache, please refer to README.md
- indices: (batch_size, seq_len_q, topk), torch.int32.
If not None, sparse attention will be enabled,
and only tokens in the `indices` array will be attended to.
Invalid indices should be set to -1 or numbers >= total_seq_len_kv.
For details about how to set up `indices`, please refer to README.md.
Returns:
- out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
- softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
"""
if softmax_scale is None:
softmax_scale = q.shape[-1]**(-0.5)
out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla(
q,
k_cache,
head_dim_v,
cache_seqlens,
block_table,
softmax_scale,
causal,
tile_scheduler_metadata,
num_splits,
descale_q,
descale_k,
)
# Note(hc): need revisit when we support DCP with decode query_len > 1.
return out.squeeze(1), softmax_lse.squeeze(-1)
if indices is not None:
# NOTE (zyongye): sparse attention is also causal
# since it only attend to the tokens before
# but here `causal` should not be specified
assert not causal, \
"causal must be `false` if sparse attention is enabled."
assert (descale_q is None) == (
descale_k is None
), "descale_q and descale_k should be both None or both not None"
if (descale_q is not None) and (descale_k is not None):
out, softmax_lse = torch.ops._flashmla_extension_C.fwd_kvcache_mla_fp8(
q, k_cache, head_dim_v, cache_seqlens, block_table, softmax_scale,
causal, tile_scheduler_metadata, num_splits, descale_q, descale_k)
else:
out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla(
q, k_cache, head_dim_v, cache_seqlens, block_table, softmax_scale,
causal, tile_scheduler_metadata, num_splits, is_fp8_kvcache,
indices)
return out, softmax_lse
def flash_mla_sparse_prefill(
q: torch.Tensor,
kv: torch.Tensor,
indices: torch.Tensor,
sm_scale: float,
d_v: int = 512,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Sparse attention prefill kernel
Args:
- q: [s_q, h_q, d_qk], bfloat16
- kv: [s_kv, h_kv, d_qk], bfloat16
- indices: [s_q, h_kv, topk], int32.
Invalid indices should be set to -1 or numbers >= s_kv
- sm_scale: float
- d_v: The dimension of value vectors. Can only be 512
Returns:
- (output, max_logits, lse)
About the definition of output,
max_logits and lse, please refer to README.md
- output: [s_q, h_q, d_v], bfloat16
- max_logits: [s_q, h_q], float
- lse: [s_q, h_q], float, 2-based log-sum-exp
"""
results = torch.ops._flashmla_C.sparse_prefill_fwd(q, kv, indices,
sm_scale, d_v)
return results
#
......
......@@ -50,6 +50,7 @@ class PagedAttention:
block_size: int,
num_kv_heads: int,
head_size: int,
cache_dtype_str: str = "auto",
) -> Tuple[int, ...]:
return (2, num_blocks, block_size * num_kv_heads * head_size)
......
......@@ -144,6 +144,7 @@ def get_attn_backend(
block_size: int,
use_mla: bool = False,
has_sink: bool = False,
use_sparse: bool = False,
) -> type[AttentionBackend]:
"""Selects which attention backend to use and lazily imports it."""
# Accessing envs.* behind an @lru_cache decorator can cause the wrong
......@@ -158,6 +159,7 @@ def get_attn_backend(
use_v1=envs.VLLM_USE_V1,
use_mla=use_mla,
has_sink=has_sink,
use_sparse=use_sparse,
)
......@@ -170,6 +172,7 @@ def _cached_get_attn_backend(
use_v1: bool = False,
use_mla: bool = False,
has_sink: bool = False,
use_sparse: bool = False,
) -> type[AttentionBackend]:
# Check whether a particular choice of backend was
......@@ -203,7 +206,7 @@ def _cached_get_attn_backend(
# get device-specific attn_backend
attention_cls = current_platform.get_attn_backend_cls(
selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1,
use_mla, has_sink)
use_mla, has_sink, use_sparse)
if not attention_cls:
raise ValueError(
f"Invalid attention backend for {current_platform.device_name}")
......
......@@ -22,7 +22,8 @@ else:
logger = init_logger(__name__)
BlockSize = Literal[1, 8, 16, 32, 64, 128]
CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"]
CacheDType = Literal["auto", "bfloat16", "fp8", "fp8_e4m3", "fp8_e5m2",
"fp8_inc"]
MambaDType = Literal["auto", "float32"]
PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor"]
......@@ -52,7 +53,11 @@ class CacheConfig:
cache_dtype: CacheDType = "auto"
"""Data type for kv cache storage. If "auto", will use model data type.
CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ROCm (AMD GPU) supports
fp8 (=fp8_e4m3). Intel Gaudi (HPU) supports fp8 (using fp8_inc)."""
fp8 (=fp8_e4m3). Intel Gaudi (HPU) supports fp8 (using fp8_inc).
Some models (namely DeepSeekV3.2) default to fp8, set to bfloat16 to use
bfloat16 instead, this is an invalid option for models that do not default
to fp8.
"""
is_attention_free: bool = False
"""Whether the model is attention-free. This is primarily set in
`ModelConfig` and that value should be manually duplicated here."""
......@@ -171,11 +176,12 @@ class CacheConfig:
if self.cache_dtype == "auto":
pass
elif self.cache_dtype in get_args(CacheDType):
logger.info(
"Using fp8 data type to store kv cache. It reduces the GPU "
"memory footprint and boosts the performance. "
"Meanwhile, it may cause accuracy drop without a proper "
"scaling factor.")
if self.cache_dtype.startswith("fp8"):
logger.info(
"Using fp8 data type to store kv cache. It reduces the GPU "
"memory footprint and boosts the performance. "
"Meanwhile, it may cause accuracy drop without a proper "
"scaling factor.")
else:
raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")
......
......@@ -374,6 +374,7 @@ class CompilationConfig:
"vllm.linear_attention",
"vllm.plamo2_mamba_mixer",
"vllm.gdn_attention",
"vllm.sparse_attn_indexer",
]
def compute_hash(self) -> str:
......
......@@ -1082,14 +1082,14 @@ class ModelConfig:
if not hasattr(self.hf_text_config, "model_type"):
return False
elif self.hf_text_config.model_type in \
('deepseek_v2', 'deepseek_v3', 'deepseek_mtp',
('deepseek_v2', 'deepseek_v3', 'deepseek_v32', 'deepseek_mtp',
'kimi_k2', 'longcat_flash'):
return self.hf_text_config.kv_lora_rank is not None
elif self.hf_text_config.model_type == 'eagle':
# if the model is an EAGLE module, check for the
# underlying architecture
return self.hf_text_config.model.model_type in \
('deepseek_v2', 'deepseek_v3') \
('deepseek_v2', 'deepseek_v3', 'deepseek_v32') \
and self.hf_text_config.kv_lora_rank is not None
return False
......
......@@ -145,7 +145,7 @@ class SpeculativeConfig:
@staticmethod
def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
if hf_config.model_type == "deepseek_v3":
if hf_config.model_type in ("deepseek_v3", "deepseek_v32"):
hf_config.model_type = "deepseek_mtp"
if hf_config.model_type == "deepseek_mtp":
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
......
......@@ -5,6 +5,7 @@ from typing import Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
import vllm.envs as envs
from vllm.model_executor.custom_op import CustomOp
......@@ -375,3 +376,20 @@ class PolyNorm(CustomOp):
x: torch.Tensor,
) -> torch.Tensor:
return poly_norm(x, self.weight, self.bias, self.variance_epsilon)
class LayerNorm(nn.Module):
"""
Layer Normalization.
"""
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.dim = dim
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32))
def forward(self, x: torch.Tensor):
return F.layer_norm(x.float(), (self.dim, ), self.weight, self.bias,
self.eps).type_as(x)
......@@ -24,6 +24,9 @@ class MLAModules:
q_a_layernorm: Optional[torch.nn.Module]
q_b_proj: Optional[torch.nn.Module]
q_proj: Optional[torch.nn.Module]
indexer: Optional[torch.nn.Module]
is_sparse: bool
topk_indices_buffer: Optional[torch.Tensor]
@CustomOp.register("multi_head_latent_attention")
......@@ -76,6 +79,13 @@ class MultiHeadLatentAttention(CustomOp):
self.kv_b_proj = mla_modules.kv_b_proj
self.rotary_emb = mla_modules.rotary_emb
self.o_proj = mla_modules.o_proj
self.indexer = mla_modules.indexer
self.is_sparse = mla_modules.is_sparse
if self.indexer is not None:
assert hasattr(self.indexer, "topk_tokens")
self.topk_tokens = self.indexer.topk_tokens
self.topk_indices_buffer = mla_modules.topk_indices_buffer
# In the MLA backend, kv_cache includes both k_c and
# pe (i.e. decoupled position embeddings). In particular,
......@@ -92,6 +102,7 @@ class MultiHeadLatentAttention(CustomOp):
quant_config=quant_config,
prefix=f"{prefix}.attn",
use_mla=True,
use_sparse=mla_modules.is_sparse,
# MLA Args
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.kv_lora_rank,
......@@ -100,6 +111,7 @@ class MultiHeadLatentAttention(CustomOp):
qk_head_dim=self.qk_head_dim,
v_head_dim=self.v_head_dim,
kv_b_proj=self.kv_b_proj,
indexer=self.indexer,
)
self.prefix = prefix
......@@ -145,6 +157,10 @@ class MultiHeadLatentAttention(CustomOp):
q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb(
positions, q[..., self.qk_nope_head_dim:], k_pe)
if self.indexer and self.is_sparse:
_topk_indices = self.indexer(hidden_states, q_c, positions,
self.rotary_emb)
attn_out = self.mla_attn(
q,
kv_c_normed,
......
......@@ -346,8 +346,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
block_size=1,
num_kv_heads=model_config.get_num_kv_heads(parallel_config),
head_size=model_config.get_head_size(),
dtype=kv_cache_dtype,
use_mla=model_config.use_mla).page_size_bytes
dtype=kv_cache_dtype).page_size_bytes
model_cls, _ = ModelRegistry.resolve_model_cls(
model_config.architecture,
......@@ -401,6 +400,31 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
"exactly equal.", mamba_padding_pct)
class DeepseekV3ForCausalLM(VerifyAndUpdateConfig):
@classmethod
def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
"""
Updated fp8 cache to custom "fp8_ds_mla" format for DeepSeekV32
"""
hf_config = vllm_config.model_config.hf_config
is_v32 = hasattr(hf_config, "index_topk")
if is_v32:
# For DeepSeekV3.2, we use a custom fp8 format as default (i.e.
# "auto")
cache_config = vllm_config.cache_config
if cache_config.cache_dtype == "auto" or \
cache_config.cache_dtype.startswith("fp8"):
cache_config.cache_dtype = "fp8_ds_mla"
logger.info(
"Using custom fp8 kv-cache format for DeepSeekV3.2")
if cache_config.cache_dtype == "bfloat16":
cache_config.cache_dtype = "auto"
logger.info("Using bfloat16 kv-cache for DeepSeekV3.2")
MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
"GteModel": SnowflakeGteNewModelConfig,
"GteNewModel": GteNewModelConfig,
......@@ -417,4 +441,5 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
"MambaForCausalLM": MambaModelConfig,
"Mamba2ForCausalLM": MambaModelConfig,
"FalconMambaForCausalLM": MambaModelConfig,
"DeepseekV3ForCausalLM": DeepseekV3ForCausalLM,
}
......@@ -53,8 +53,20 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
self.eh_proj = nn.Linear(config.hidden_size * 2,
config.hidden_size,
bias=False)
self.is_v32 = hasattr(config, "index_topk")
if self.is_v32:
topk_tokens = config.index_topk
topk_indices_buffer = torch.empty(
vllm_config.scheduler_config.max_num_batched_tokens,
topk_tokens,
dtype=torch.int32,
device="cuda")
else:
topk_indices_buffer = None
self.shared_head = SharedHead(config=config, quant_config=quant_config)
self.mtp_block = DeepseekV2DecoderLayer(vllm_config, prefix)
self.mtp_block = DeepseekV2DecoderLayer(vllm_config, prefix,
topk_indices_buffer)
def forward(
self,
......
......@@ -33,15 +33,21 @@ from torch import nn
from transformers import DeepseekV2Config, DeepseekV3Config
from vllm.attention import Attention
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ParallelConfig, VllmConfig
from vllm.config import (CacheConfig, ParallelConfig, VllmConfig,
get_current_vllm_config)
from vllm.distributed import (get_ep_group, get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather)
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
ReplicatedLinear,
......@@ -49,6 +55,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mla import MLAModules, MultiHeadLatentAttention
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
from vllm.model_executor.layers.vocab_parallel_embedding import (
......@@ -56,13 +64,26 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils import cdiv, direct_register_custom_op
from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits
from vllm.v1.attention.backends.mla.indexer import (DeepseekV32IndexerBackend,
DeepseekV32IndexerMetadata)
from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec
from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
if current_platform.is_cuda_alike():
from vllm import _custom_ops as ops
elif current_platform.is_xpu():
from vllm._ipex_ops import ipex_ops as ops
logger = init_logger(__name__)
class DeepseekV2MLP(nn.Module):
......@@ -276,6 +297,7 @@ class DeepseekV2Attention(nn.Module):
def __init__(
self,
vllm_config: VllmConfig,
config: Union[DeepseekV2Config, DeepseekV3Config],
hidden_size: int,
num_heads: int,
......@@ -289,6 +311,7 @@ class DeepseekV2Attention(nn.Module):
max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
topk_indices_buffer: Optional[torch.Tensor] = None,
prefix: str = "",
) -> None:
super().__init__()
......@@ -306,6 +329,8 @@ class DeepseekV2Attention(nn.Module):
self.scaling = self.qk_head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
assert topk_indices_buffer is None, "topk_indices_buffer is not \
supported for DeepseekV2Attention"
if self.q_lora_rank is not None:
self.q_a_proj = ReplicatedLinear(self.hidden_size,
......@@ -418,6 +443,391 @@ class DeepseekV2Attention(nn.Module):
return output
class DeepseekV32IndexerCache(torch.nn.Module, AttentionLayerBase):
def __init__(self, head_dim: int, dtype: torch.dtype, prefix: str,
cache_config: CacheConfig):
super().__init__()
self.kv_cache = [torch.tensor([])]
self.head_dim = head_dim
self.prefix = prefix
self.cache_config = cache_config
self.dtype = dtype
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
def get_kv_cache_spec(self) -> KVCacheSpec:
return MLAAttentionSpec( # Only has one vector instead of K + V
block_size=self.cache_config.block_size,
num_kv_heads=1,
head_size=self.head_dim,
dtype=self.dtype,
)
def forward(self):
...
def get_attn_backend(self) -> AttentionBackend:
return DeepseekV32IndexerBackend
@torch.inference_mode()
def cp_gather_indexer_k_quant_cache(
kv_cache, # [num_blocks, block_size, head_dim + 1]
dst_value, # [cu_seq_lens[-1], head_dim]
dst_scale, # [cu_seq_lens[-1], 4]
block_table, # [batch_size, num_blocks]
cu_seq_lens, # [batch_size + 1, ]
batch_size,
):
num_blocks, block_size, _ = kv_cache.shape
head_dim = dst_value.shape[-1]
kv_cache = kv_cache.view(num_blocks, -1)
expected_value = []
expected_scale = []
for b in range(batch_size):
s = cu_seq_lens[b + 1] - cu_seq_lens[b]
if s == 0:
continue
tot = cdiv(s, block_size)
blocks = block_table[b, :tot]
value = []
scale = []
full_block = torch.arange(tot - 1,
device=kv_cache.device,
dtype=torch.int32)
non_remaining_value = kv_cache[blocks[full_block], :block_size *
head_dim].view(-1, head_dim)
non_remaining_scale = kv_cache[blocks[full_block],
block_size * head_dim:].view(-1, 4)
remaining = s - (tot - 1) * block_size
value = torch.cat([
non_remaining_value,
kv_cache[blocks[-1], :remaining * head_dim].view(-1, head_dim)
],
dim=0)
scale = torch.cat([
non_remaining_scale,
kv_cache[blocks[-1], block_size * head_dim:block_size * head_dim +
remaining * 4].view(-1, 4)
],
dim=0)
expected_value.append(value)
expected_scale.append(scale)
gather_value = torch.cat(expected_value, dim=0).view(-1, head_dim)
gather_scale = torch.cat(expected_scale, dim=0).view(-1, 4)
gather_value = gather_value.view(torch.float8_e4m3fn)
gather_scale = gather_scale.view(torch.float32)
dst_value.copy_(gather_value)
dst_scale.copy_(gather_scale)
def sparse_attn_indexer(
hidden_states: torch.Tensor,
k_cache_prefix: str,
kv_cache: torch.Tensor,
q_fp8: torch.Tensor,
k: torch.Tensor,
weights: torch.Tensor,
quant_block_size: int,
scale_fmt: Optional[str],
topk_tokens: int,
head_dim: int,
max_model_len: int,
total_seq_lens: int,
topk_indices_buffer: Optional[torch.Tensor],
) -> torch.Tensor:
# careful! this will be None in dummy run
attn_metadata = get_forward_context().attn_metadata
# assert isinstance(attn_metadata, dict)
if not isinstance(attn_metadata, dict):
return sparse_attn_indexer_fake(
hidden_states,
k_cache_prefix,
kv_cache,
q_fp8,
k,
weights,
quant_block_size,
scale_fmt,
topk_tokens,
head_dim,
max_model_len,
total_seq_lens,
topk_indices_buffer,
)
attn_metadata = attn_metadata[k_cache_prefix]
assert isinstance(attn_metadata, DeepseekV32IndexerMetadata)
slot_mapping = attn_metadata.slot_mapping
has_decode = attn_metadata.num_decodes > 0
has_prefill = attn_metadata.num_prefills > 0
num_decode_tokens = attn_metadata.num_decode_tokens
ops.indexer_k_quant_and_cache(
k,
kv_cache,
slot_mapping,
quant_block_size,
scale_fmt,
)
topk_indices_buffer[:hidden_states.shape[0]] = -1
if has_prefill:
prefill_metadata = attn_metadata.prefill
num_prefills = attn_metadata.num_prefills
k_fp8 = torch.empty([prefill_metadata.total_seq_lens, head_dim],
device=k.device,
dtype=torch.float8_e4m3fn)
k_scale = torch.empty([prefill_metadata.total_seq_lens, 1],
device=k.device,
dtype=torch.float32)
cp_gather_indexer_k_quant_cache(
kv_cache,
k_fp8,
k_scale,
prefill_metadata.block_table,
prefill_metadata.cu_seq_lens,
num_prefills,
)
cu_seqlen_ks = prefill_metadata.cu_seqlen_ks
cu_seqlen_ke = prefill_metadata.cu_seqlen_ke
num_tokens = attn_metadata.num_actual_tokens
logits = fp8_mqa_logits(
q_fp8[num_decode_tokens:num_tokens],
(k_fp8, k_scale),
weights[num_decode_tokens:num_tokens],
cu_seqlen_ks,
cu_seqlen_ke,
)
topk_indices = logits.topk(min(topk_tokens, logits.shape[-1]),
dim=-1)[1]
topk_indices -= cu_seqlen_ks[:, None]
mask_lo = topk_indices >= 0
mask_hi = topk_indices - (cu_seqlen_ke - cu_seqlen_ks)[:, None] < 0
mask = torch.full_like(topk_indices,
False,
dtype=torch.bool,
device=topk_indices.device)
mask = mask_lo & mask_hi
topk_indices = topk_indices.masked_fill(~mask, -1)
topk_indices_buffer[num_decode_tokens:num_tokens, :topk_indices.
shape[-1]] = topk_indices.to(dtype=torch.int32)
if has_decode:
decode_metadata = attn_metadata.decode
# kv_cache size requirement [num_block, block_size, n_head, head_dim],
# we only have [num_block, block_size, head_dim],
kv_cache = kv_cache.unsqueeze(-2)
decode_lens = decode_metadata.decode_lens
if decode_metadata.requires_padding:
# pad in edge case where we have short chunked prefill length <
# decode_threshold since we unstrictly split
# prefill and decode by decode_threshold
# (currently set to 1 + speculative tokens)
padded_q_fp8_decode_tokens = pack_seq_triton(
q_fp8[:num_decode_tokens], decode_lens)
else:
padded_q_fp8_decode_tokens = q_fp8[:num_decode_tokens].reshape(
decode_lens.shape[0], -1, *q_fp8.shape[1:])
# TODO: move and optimize below logic with triton kernels
batch_size = padded_q_fp8_decode_tokens.shape[0]
next_n = padded_q_fp8_decode_tokens.shape[1]
assert batch_size == decode_metadata.seq_lens.shape[0]
num_padded_tokens = batch_size * next_n
logits = fp8_paged_mqa_logits(
padded_q_fp8_decode_tokens,
kv_cache,
weights[:num_padded_tokens],
decode_metadata.seq_lens,
decode_metadata.block_table,
decode_metadata.schedule_metadata,
max_model_len=max_model_len,
)
# padded query len
current_device = padded_q_fp8_decode_tokens.device
padded_num_tokens = batch_size * next_n
positions = torch.arange(max_model_len,
device=current_device).unsqueeze(0).expand(
batch_size * next_n, -1)
row_indices = torch.arange(padded_num_tokens,
device=current_device) // next_n
next_n_offset = torch.arange(
padded_num_tokens,
device=padded_q_fp8_decode_tokens.device) % next_n
index_end_pos = (decode_metadata.seq_lens[row_indices] - next_n +
next_n_offset).unsqueeze(1)
# index_end_pos: [B * N, 1]
mask = positions <= index_end_pos
# mask: [B * N, L]
logits = logits.masked_fill(~mask, float('-inf'))
topk_indices = logits.topk(topk_tokens,
dim=-1)[1].to(torch.int32) # [B * N, K]
# ensure we don't set indices for the top k
# that is out of range(masked already)
# this will happen if context length is shorter than K
topk_indices[topk_indices > index_end_pos] = -1
if decode_metadata.requires_padding:
# if padded, we need to unpack
# the topk indices removing padded tokens
topk_indices = unpack_seq_triton(
topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]),
decode_lens)
topk_indices_buffer[:num_decode_tokens, :topk_indices.
shape[-1]] = topk_indices.to(dtype=torch.int32)
return topk_indices_buffer
def sparse_attn_indexer_fake(
hidden_states: torch.Tensor,
k_cache_prefix: str,
kv_cache: torch.Tensor,
q_fp8: torch.Tensor,
k: torch.Tensor,
weights: torch.Tensor,
quant_block_size: int,
scale_fmt: Optional[str],
topk_tokens: int,
head_dim: int,
max_model_len: int,
total_seq_lens: int,
topk_indices_buffer: Optional[torch.Tensor],
) -> torch.Tensor:
# profile run
# NOTE(Chen): create the max possible flattened_kv. So that
# profile_run can get correct memory usage.
_flattened_kv = torch.empty([total_seq_lens, head_dim + 4],
device=k.device,
dtype=torch.uint8)
_k_fp8 = _flattened_kv[..., :head_dim].view(
torch.float8_e4m3fn).contiguous()
_k_scale = _flattened_kv[..., head_dim:].view(torch.float32).contiguous()
return topk_indices_buffer
direct_register_custom_op(
op_name="sparse_attn_indexer",
op_func=sparse_attn_indexer,
mutates_args=["topk_indices_buffer"],
fake_impl=sparse_attn_indexer_fake,
dispatch_key=current_platform.dispatch_key,
)
class Indexer(nn.Module):
def __init__(self,
vllm_config: VllmConfig,
config: Union[DeepseekV2Config, DeepseekV3Config],
hidden_size: int,
q_lora_rank: int,
quant_config: Optional[QuantizationConfig],
cache_config: Optional[CacheConfig],
topk_indices_buffer: Optional[torch.Tensor],
prefix: str = ""):
super().__init__()
self.vllm_config = vllm_config
self.config = config
# self.indexer_cfg = config.attn_module_list_cfg[0]["attn_index"]
self.topk_tokens = config.index_topk
self.n_head = config.index_n_heads # 64
self.head_dim = config.index_head_dim # 128
self.rope_dim = config.qk_rope_head_dim # 64
self.q_lora_rank = q_lora_rank # 1536
# no tensor parallel, just replicated
self.wq_b = ReplicatedLinear(self.q_lora_rank,
self.head_dim * self.n_head,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.wq_b")
self.wk = ReplicatedLinear(hidden_size,
self.head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.wk")
self.k_norm = LayerNorm(self.head_dim, eps=1e-6)
self.weights_proj = ReplicatedLinear(hidden_size,
self.n_head,
quant_config=None,
prefix=f"{prefix}.weights_proj")
self.softmax_scale = self.head_dim**-0.5
self.scale_fmt = "ue8m0"
self.quant_block_size = 128 # TODO: get from config
self.topk_indices_buffer = topk_indices_buffer
# NOTE: (zyongye) we use fp8 naive cache,
# where we store value in fp8 and scale in fp32
# per self.quant_block_size element
self.k_cache = DeepseekV32IndexerCache(
head_dim=self.head_dim +
self.head_dim // self.quant_block_size * 4,
dtype=torch.uint8,
prefix=f"{prefix}.k_cache",
cache_config=cache_config)
self.max_model_len = vllm_config.model_config.max_model_len
self.prefix = prefix
from vllm.v1.attention.backends.mla.indexer import (
get_max_prefill_buffer_size)
self.max_total_seq_len = get_max_prefill_buffer_size(vllm_config)
def forward(self, hidden_states: torch.Tensor, qr: torch.Tensor, positions,
rotary_emb) -> torch.Tensor:
q, _ = self.wq_b(qr)
q = q.view(-1, self.n_head, self.head_dim)
q_pe, q_nope = torch.split(
q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1)
k, _ = self.wk(hidden_states)
k = self.k_norm(k)
k_pe, k_nope = torch.split(
k, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1)
q_pe, k_pe = rotary_emb(positions, q_pe, k_pe.unsqueeze(1))
q = torch.cat([q_pe, q_nope], dim=-1)
k = torch.cat([k_pe.squeeze(1), k_nope], dim=-1)
# we only quant q here since k quant is fused with cache insertion
q = q.view(-1, self.head_dim)
q_fp8, q_scale = per_token_group_quant_fp8(q,
self.quant_block_size,
column_major_scales=False,
use_ue8m0=self.scale_fmt
is not None)
q_fp8 = q_fp8.view(-1, self.n_head, self.head_dim)
q_scale = q_scale.view(-1, self.n_head, 1)
weights, _ = self.weights_proj(hidden_states)
weights = weights.unsqueeze(
-1) * q_scale * self.softmax_scale * self.n_head**-0.5
weights = weights.squeeze(-1)
return torch.ops.vllm.sparse_attn_indexer(
hidden_states,
self.k_cache.prefix,
self.k_cache.kv_cache[0],
q_fp8,
k,
weights,
self.quant_block_size,
self.scale_fmt,
self.topk_tokens,
self.head_dim,
self.max_model_len,
self.max_total_seq_len,
self.topk_indices_buffer,
)
class DeepseekV2MLAAttention(nn.Module):
"""
Main reference: DeepseekV2 paper, and FlashInfer Implementation
......@@ -429,6 +839,7 @@ class DeepseekV2MLAAttention(nn.Module):
def __init__(
self,
vllm_config: VllmConfig,
config: Union[DeepseekV2Config, DeepseekV3Config],
hidden_size: int,
num_heads: int,
......@@ -443,6 +854,7 @@ class DeepseekV2MLAAttention(nn.Module):
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
topk_indices_buffer: Optional[torch.Tensor] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
......@@ -523,6 +935,15 @@ class DeepseekV2MLAAttention(nn.Module):
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
self.scaling = self.scaling * mscale * mscale
self.is_v32 = hasattr(config, "index_topk")
if self.is_v32:
self.indexer = Indexer(vllm_config, config, hidden_size,
q_lora_rank, quant_config, cache_config,
topk_indices_buffer, f"{prefix}.indexer")
else:
self.indexer = None
mla_modules = MLAModules(
kv_a_layernorm=self.kv_a_layernorm,
kv_b_proj=self.kv_b_proj,
......@@ -536,7 +957,11 @@ class DeepseekV2MLAAttention(nn.Module):
if self.q_lora_rank is not None else None,
q_b_proj=self.q_b_proj if self.q_lora_rank is not None else None,
q_proj=self.q_proj if self.q_lora_rank is None else None,
indexer=self.indexer,
is_sparse=self.is_v32,
topk_indices_buffer=topk_indices_buffer,
)
self.mla_attn = MultiHeadLatentAttention(
self.hidden_size,
self.num_local_heads,
......@@ -562,7 +987,10 @@ class DeepseekV2MLAAttention(nn.Module):
class DeepseekV2DecoderLayer(nn.Module):
def __init__(self, vllm_config: VllmConfig, prefix: str) -> None:
def __init__(self,
vllm_config: VllmConfig,
prefix: str,
topk_indices_buffer: Optional[torch.Tensor] = None) -> None:
super().__init__()
config = vllm_config.model_config.hf_config
......@@ -585,6 +1013,7 @@ class DeepseekV2DecoderLayer(nn.Module):
else:
attn_cls = DeepseekV2Attention
self.self_attn = attn_cls(
vllm_config=vllm_config,
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
......@@ -600,6 +1029,7 @@ class DeepseekV2DecoderLayer(nn.Module):
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
topk_indices_buffer=topk_indices_buffer,
)
if (config.n_routed_experts is not None
......@@ -683,6 +1113,16 @@ class DeepseekV2Model(nn.Module):
self.config = config
self.vocab_size = config.vocab_size
self.is_v32 = hasattr(config, "index_topk")
if self.is_v32:
topk_tokens = config.index_topk
topk_indices_buffer = torch.empty(
vllm_config.scheduler_config.max_num_batched_tokens,
topk_tokens,
dtype=torch.int32,
device="cuda")
else:
topk_indices_buffer = None
if get_pp_group().is_first_rank:
self.embed_tokens = VocabParallelEmbedding(
......@@ -695,7 +1135,8 @@ class DeepseekV2Model(nn.Module):
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: DeepseekV2DecoderLayer(vllm_config, prefix),
lambda prefix: DeepseekV2DecoderLayer(vllm_config, prefix,
topk_indices_buffer),
prefix=f"{prefix}.layers")
if get_pp_group().is_last_rank:
......
......@@ -308,6 +308,7 @@ class FlashDecoderLayer(nn.Module):
def __init__(
self,
vllm_config: VllmConfig,
config: FlashConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
......@@ -329,6 +330,7 @@ class FlashDecoderLayer(nn.Module):
# Dual attention structure
self.self_attn = nn.ModuleList([
DeepseekV2MLAAttention(
vllm_config=vllm_config,
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
......@@ -454,6 +456,7 @@ class FlashModel(nn.Module):
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: FlashDecoderLayer(
vllm_config,
config,
cache_config=cache_config,
quant_config=quant_config,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment