# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass

import torch

7
from vllm.attention.layer import MLAAttention
8
9
10
11
12
13
14
from vllm.config import CacheConfig
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization import QuantizationConfig


@dataclass
class MLAModules:
15
16
    """Modules used in MLA."""

17
18
19
20
    kv_a_layernorm: torch.nn.Module
    kv_b_proj: torch.nn.Module
    rotary_emb: torch.nn.Module
    o_proj: torch.nn.Module
21
22
23
24
25
26
    fused_qkv_a_proj: torch.nn.Module | None
    kv_a_proj_with_mqa: torch.nn.Module | None
    q_a_layernorm: torch.nn.Module | None
    q_b_proj: torch.nn.Module | None
    q_proj: torch.nn.Module | None
    indexer: torch.nn.Module | None
27
    indexer_rotary_emb: torch.nn.Module | None
28
    is_sparse: bool
29
    topk_indices_buffer: torch.Tensor | None
30
31
32


@CustomOp.register("multi_head_latent_attention")
class MultiHeadLatentAttentionWrapper(CustomOp):
    """MLA layer registered as CustomOp to allow OOT backends to add
    custom implementations of the outer MLA layer (including rope & o_proj).

    Note that currently MLA ignores the enable/disable mechanism of CustomOp
    because there is only one in-tree implementation in forward_native.
    TODO: implement this with a new PluggableLayer mechanism.

    This class takes positions and hidden_states as input.
    The input tensors can either contain prefill tokens or decode tokens.
    The class does the following:

    1. MLA Preprocess.
    2. Perform multi-head attention to prefill tokens and
       multi-query attention to decode tokens separately.
    3. Return the output tensor.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        scale: float,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: int | None,
        kv_lora_rank: int,
        mla_modules: MLAModules,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Store the MLA sub-modules and build the inner MLAAttention layer.

        Args:
            hidden_size: Model hidden size (stored; not otherwise used here).
            num_heads: Number of attention heads handled by this layer.
            scale: Attention scale forwarded to MLAAttention.
            qk_nope_head_dim: Per-head q/k width that does NOT receive the
                rotary embedding.
            qk_rope_head_dim: Per-head q/k width that receives the rotary
                embedding.
            v_head_dim: Per-head value width.
            q_lora_rank: Width of the low-rank query latent, or None to use
                the direct ``q_proj`` path in ``forward_native``.
            kv_lora_rank: Width of the compressed KV latent (``kv_c``).
            mla_modules: Projection / norm / rope / indexer modules.
            cache_config: Forwarded to MLAAttention.
            quant_config: Forwarded to MLAAttention.
            prefix: Module name prefix; the inner layer is ``{prefix}.attn``.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.num_heads = num_heads
        self.fused_qkv_a_proj = mla_modules.fused_qkv_a_proj
        self.kv_a_proj_with_mqa = mla_modules.kv_a_proj_with_mqa
        self.q_a_layernorm = mla_modules.q_a_layernorm
        self.q_b_proj = mla_modules.q_b_proj
        self.q_proj = mla_modules.q_proj
        self.kv_a_layernorm = mla_modules.kv_a_layernorm
        self.kv_b_proj = mla_modules.kv_b_proj
        self.rotary_emb = mla_modules.rotary_emb
        self.o_proj = mla_modules.o_proj
        self.indexer = mla_modules.indexer
        self.indexer_rope_emb = mla_modules.indexer_rotary_emb
        self.is_sparse = mla_modules.is_sparse

        if self.indexer is not None:
            # The indexer must expose how many tokens it selects.
            assert hasattr(self.indexer, "topk_tokens")
            self.topk_tokens = self.indexer.topk_tokens
            self.topk_indices_buffer = mla_modules.topk_indices_buffer

        self.mla_attn = MLAAttention(
            num_heads=self.num_heads,
            scale=scale,
            qk_nope_head_dim=self.qk_nope_head_dim,
            qk_rope_head_dim=self.qk_rope_head_dim,
            v_head_dim=self.v_head_dim,
            q_lora_rank=self.q_lora_rank,
            kv_lora_rank=self.kv_lora_rank,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            kv_b_proj=self.kv_b_proj,
            use_sparse=self.is_sparse,
            indexer=self.indexer,
        )

        self.prefix = prefix

    def forward_native(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        llama_4_scaling: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run MLA preprocessing, attention and the output projection.

        Args:
            positions: Token positions passed to the rotary embedding and
                to the indexer.
            hidden_states: Input activations. ``hidden_states.shape[0]`` is
                used as the token count (presumably shape
                ``(num_tokens, hidden_size)`` — confirm against callers).
            llama_4_scaling: Optional multiplier applied in place to the
                query tensor just before attention.

        Returns:
            Element 0 of ``o_proj`` applied to the attention output.
        """
        q_c = None
        kv_lora = None

        # The projection modules are indexed with [0] below — presumably
        # they return (output, bias)-style tuples; confirm in the layer code.
        if self.q_lora_rank is not None:
            assert self.fused_qkv_a_proj is not None, (
                "fused_qkv_a_proj is required when q_lora_rank is not None"
            )
            assert self.q_a_layernorm is not None, (
                "q_a_layernorm is required when q_lora_rank is not None"
            )
            assert self.q_b_proj is not None, (
                "q_b_proj is required when q_lora_rank is not None"
            )
            # One fused projection yields the query latent and the KV
            # latent (+ rope part) side by side along the last dim.
            qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
            q_c, kv_lora = qkv_lora.split(
                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
                dim=-1,
            )
            q_c = self.q_a_layernorm(q_c)
            q = self.q_b_proj(q_c)[0]
        else:
            assert self.kv_a_proj_with_mqa is not None, (
                "kv_a_proj_with_mqa is required when q_lora_rank is None"
            )
            assert self.q_proj is not None, (
                "q_proj is required when q_lora_rank is None"
            )
            kv_lora = self.kv_a_proj_with_mqa(hidden_states)[0]
            q = self.q_proj(hidden_states)[0]

        # Separate the compressed KV latent from its rotary component.
        kv_c, k_pe = kv_lora.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        kv_c_normed = self.kv_a_layernorm(kv_c)

        q = q.view(-1, self.num_heads, self.qk_head_dim)
        # Add head dim of 1 to k_pe
        k_pe = k_pe.unsqueeze(1)

        if self.rotary_emb is not None:
            # Rope is applied only to the trailing qk_rope_head_dim slice of
            # each query head; the result is written back into q in place.
            q[..., self.qk_nope_head_dim :], k_pe = self.rotary_emb(
                positions, q[..., self.qk_nope_head_dim :], k_pe
            )

        if self.indexer is not None and self.is_sparse:
            # The return value is not used here; presumably the indexer
            # records its top-k selection as a side effect (e.g. into
            # topk_indices_buffer) — confirm in the indexer implementation.
            self.indexer(hidden_states, q_c, positions, self.indexer_rope_emb)

        if llama_4_scaling is not None:
            q *= llama_4_scaling

        attn_out = self.mla_attn(
            q,
            kv_c_normed,
            k_pe,
            output_shape=(hidden_states.shape[0], self.num_heads * self.v_head_dim),
        )

        return self.o_proj(attn_out)[0]

    def forward_cuda(self, *args, **kwargs):
        """CUDA path: reuses the native implementation unchanged."""
        return self.forward_native(*args, **kwargs)