mla.py 11 KB
Newer Older
1
2
3
4
5
6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass

import torch

7
from vllm.attention.layer import MLAAttention
8
from vllm.config import CacheConfig
9
import vllm.envs as envs
10
from vllm.forward_context import get_forward_context
11
from vllm.model_executor.custom_op import PluggableLayer
12
from vllm.model_executor.layers.quantization import QuantizationConfig
13
14
15
from vllm.distributed import (
    tensor_model_parallel_all_gather,
)
16
17
18
19


@dataclass
class MLAModules:
20
21
    """Modules used in MLA."""

22
23
24
25
    kv_a_layernorm: torch.nn.Module
    kv_b_proj: torch.nn.Module
    rotary_emb: torch.nn.Module
    o_proj: torch.nn.Module
26
27
28
29
30
31
    fused_qkv_a_proj: torch.nn.Module | None
    kv_a_proj_with_mqa: torch.nn.Module | None
    q_a_layernorm: torch.nn.Module | None
    q_b_proj: torch.nn.Module | None
    q_proj: torch.nn.Module | None
    indexer: torch.nn.Module | None
32
    is_sparse: bool
33
    topk_indices_buffer: torch.Tensor | None
34
    indexer_rotary_emb: torch.nn.Module | None = None
35
36


# --8<-- [start:multi_head_latent_attention]
@PluggableLayer.register("multi_head_latent_attention")
class MultiHeadLatentAttentionWrapper(PluggableLayer):
    """Pluggable MLA layer which allows OOT backends to add
    custom implementations of the outer MLA layer (including rope & o_proj).

    Note that currently OOT platforms can still use CustomOp.register_oot to
    replace the MLA layer entirely, although we use PluggableLayer to register
    this layer now.

    This class takes positions and hidden_states as input.
    The input tensors can either contain prefill tokens or decode tokens.
    The class does the following:

    1. MLA Preprocess.
    2. Perform multi-head attention to prefill tokens and
       multi-query attention to decode tokens separately.
    3. Return the output tensor.
    """

    # --8<-- [end:multi_head_latent_attention]

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        scale: float,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: int | None,
        kv_lora_rank: int,
        mla_modules: MLAModules,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        # Per-head q/k width: non-rotary part followed by the rotary part.
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.num_heads = num_heads
        self.fused_qkv_a_proj = mla_modules.fused_qkv_a_proj
        self.kv_a_proj_with_mqa = mla_modules.kv_a_proj_with_mqa
        self.q_a_layernorm = mla_modules.q_a_layernorm
        self.q_b_proj = mla_modules.q_b_proj
        self.q_proj = mla_modules.q_proj
        self.kv_a_layernorm = mla_modules.kv_a_layernorm
        self.kv_b_proj = mla_modules.kv_b_proj
        self.rotary_emb = mla_modules.rotary_emb
        self.o_proj = mla_modules.o_proj
        self.indexer = mla_modules.indexer
        self.indexer_rope_emb = mla_modules.indexer_rotary_emb
        self.is_sparse = mla_modules.is_sparse

        # topk_tokens / topk_indices_buffer exist only when an indexer is set.
        if self.indexer is not None:
            assert hasattr(self.indexer, "topk_tokens")
            self.topk_tokens = self.indexer.topk_tokens
            self.topk_indices_buffer = mla_modules.topk_indices_buffer

        self.mla_attn = MLAAttention(
            num_heads=self.num_heads,
            scale=scale,
            qk_nope_head_dim=self.qk_nope_head_dim,
            qk_rope_head_dim=self.qk_rope_head_dim,
            v_head_dim=self.v_head_dim,
            q_lora_rank=self.q_lora_rank,
            kv_lora_rank=self.kv_lora_rank,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            kv_b_proj=self.kv_b_proj,
            use_sparse=self.is_sparse,
            indexer=self.indexer,
        )

        self.prefix = prefix

    def _project_q_and_kv(
        self,
        hidden_states: torch.Tensor,
        iqis: tuple[torch.Tensor, torch.Tensor] | None,
        use_fused_rms_quant: bool,
    ) -> tuple[torch.Tensor | None, torch.Tensor, torch.Tensor]:
        """Return ``(q_c, kv_lora, q)`` computed from ``hidden_states``.

        ``q_c`` is ``None`` when ``q_lora_rank`` is None (no low-rank query
        projection). ``kv_lora`` holds ``kv_lora_rank + qk_rope_head_dim``
        columns in its last dim; the caller splits it.
        """
        if self.q_lora_rank is None:
            assert self.kv_a_proj_with_mqa is not None, (
                "kv_a_proj_with_mqa is required when q_lora_rank is None"
            )
            assert self.q_proj is not None, (
                "q_proj is required when q_lora_rank is None"
            )
            kv_lora = self.kv_a_proj_with_mqa(hidden_states)[0]
            q = self.q_proj(hidden_states)[0]
            return None, kv_lora, q

        assert self.fused_qkv_a_proj is not None, (
            "fused_qkv_a_proj is required when q_lora_rank is not None"
        )
        assert self.q_a_layernorm is not None, (
            "q_a_layernorm is required when q_lora_rank is not None"
        )
        assert self.q_b_proj is not None, (
            "q_b_proj is required when q_lora_rank is not None"
        )

        if use_fused_rms_quant and iqis is not None:
            qkv_lora = self.fused_qkv_a_proj(hidden_states, iqis=iqis)[0]
        else:
            qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
        q_c, kv_lora = qkv_lora.split(
            [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
            dim=-1,
        )

        if use_fused_rms_quant:
            # The layernorm returns a quantized (int8) representation that is
            # fed straight into q_b_proj via ``iqis``.
            # NOTE(review): with update_input=False, q_c is NOT replaced by
            # its normalized value here, so the indexer later receives the
            # pre-layernorm q_c, unlike the unfused branch below. Confirm
            # this is intended.
            qa_iq, qa_is, _ = self.q_a_layernorm(
                x=q_c,
                residual=None,
                quant_dtype=torch.int8,
                update_input=False,
            )
            q = self.q_b_proj(q_c, iqis=(qa_iq, qa_is))[0]
        else:
            q_c = self.q_a_layernorm(q_c)
            q = self.q_b_proj(q_c)[0]
        return q_c, kv_lora, q

    @staticmethod
    def _gather_cp_tokens(
        x: torch.Tensor, reorder_index: torch.Tensor | None
    ) -> torch.Tensor:
        """All-gather ``x`` along dim 0, then optionally reorder its rows.

        ``reorder_index`` (if given) is applied with ``index_select`` on
        dim 0 after the gather.
        """
        gathered = tensor_model_parallel_all_gather(x.contiguous(), 0)
        if reorder_index is not None:
            # Reorder kv after pcp allgather.
            gathered = torch.index_select(gathered, 0, reorder_index)
        return gathered

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        llama_4_scaling: torch.Tensor | None = None,
        *,
        iqis: tuple[torch.Tensor, torch.Tensor] | None = None,
    ) -> torch.Tensor:
        """Run projections, rope, attention and ``o_proj`` for one MLA layer.

        Args:
            positions: Token positions, passed to the rotary embedding, the
                indexer and (lightop path) the fused attention call.
            hidden_states: Layer input; ``hidden_states.shape[0]`` is used as
                the number of output rows.
            llama_4_scaling: Optional factor multiplied into ``q`` in place
                before attention.
            iqis: Optional pair forwarded as ``iqis=`` to ``fused_qkv_a_proj``
                and the indexer when ``envs.USE_FUSED_RMS_QUANT`` is enabled
                (presumably a pre-quantized input and its scales — confirm).

        Returns:
            ``self.o_proj(attn_out)[0]``.
        """
        # Read the env switches once so every branch below agrees.
        use_fused_rms_quant = envs.USE_FUSED_RMS_QUANT
        use_lightop = envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT

        q_c, kv_lora, q = self._project_q_and_kv(
            hidden_states, iqis, use_fused_rms_quant
        )

        kv_c, k_pe = kv_lora.split(
            [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
        )
        if not use_lightop:
            kv_c_normed = self.kv_a_layernorm(kv_c)

        q = q.view(-1, self.num_heads, self.qk_head_dim)
        # Add head dim of 1 to k_pe
        k_pe = k_pe.unsqueeze(1)

        # On the lightop path rope is applied inside the fused op instead.
        if not use_lightop and self.rotary_emb is not None:
            q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb(
                positions, q[..., self.qk_nope_head_dim:], k_pe
            )

        if self.indexer is not None and self.is_sparse:
            # Called for its effect; the returned indices are not used here.
            if use_fused_rms_quant and iqis is not None:
                _topk_indices = self.indexer(
                    hidden_states, q_c, positions, self.indexer_rope_emb,
                    iqis=iqis,
                )
            else:
                _topk_indices = self.indexer(
                    hidden_states, q_c, positions, self.indexer_rope_emb
                )

        if llama_4_scaling is not None:
            q *= llama_4_scaling

        forward_context = get_forward_context()
        enable_lightly_cp = forward_context.enable_lightly_cp
        enable_lightly_cplb = forward_context.enable_lightly_cplb
        reorder_index = None
        if enable_lightly_cp and enable_lightly_cplb:
            reorder_index = forward_context.gather_indexes_tensor

        output_shape = (hidden_states.shape[0],
                        self.num_heads * self.v_head_dim)

        if not use_lightop:
            if enable_lightly_cp:
                kv_c_normed = self._gather_cp_tokens(kv_c_normed, reorder_index)
                k_pe = self._gather_cp_tokens(k_pe, reorder_index)
            attn_out = self.mla_attn(
                q,
                kv_c_normed,
                k_pe,
                output_shape=output_shape,
            )
            return self.o_proj(attn_out)[0]

        # Lightop fused path:
        # - kv_c is passed as "unnormed" and written to kv_cache by the backend.
        # - key_normed is an output buffer filled by the fused op and then
        #   used for the prefill path.
        # - kv_c/k_pe must stay views into one kv_lora buffer so they share the
        #   same row stride: the fused op requires
        #   `kv_c.stride(0) == k_pe.stride(0)`, which is not true if kv_c is
        #   made individually contiguous.
        weight = getattr(self.kv_a_layernorm, "weight", None)
        epsilon = getattr(self.kv_a_layernorm, "variance_epsilon", 1e-6)
        if weight is None:
            raise RuntimeError(
                "VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT requires kv_a_layernorm "
                "to have a 'weight' parameter."
            )
        # Keep cos_sin_cache on the same device/dtype as q.
        if hasattr(self.rotary_emb, "_match_cos_sin_cache_dtype"):
            self.rotary_emb._match_cos_sin_cache_dtype(q)  # type: ignore[attr-defined]
        cos_sin_cache = getattr(self.rotary_emb, "cos_sin_cache", None)
        if cos_sin_cache is None:
            raise RuntimeError(
                "VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT requires rotary_emb to "
                "expose 'cos_sin_cache'."
            )

        if enable_lightly_cp:
            # Gather the packed kv_lora rows once and re-split. Gathering
            # kv_c and k_pe separately yields two contiguous tensors whose
            # row strides differ (kv_lora_rank vs qk_rope_head_dim), breaking
            # the shared-stride requirement above. k_pe is un-roped on this
            # path, so the gathered values are identical either way.
            gathered_kv_lora = self._gather_cp_tokens(kv_lora, reorder_index)
            kv_c, k_pe = gathered_kv_lora.split(
                [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
            )
            k_pe = k_pe.unsqueeze(1)

        # Allocate after the optional gather so key_normed has exactly one
        # row per (possibly gathered) kv_c row.
        key_normed = torch.empty_like(kv_c,
                                      memory_format=torch.contiguous_format)

        attn_out = self.mla_attn(
            q[..., self.qk_nope_head_dim:],
            kv_c,
            k_pe,
            output_shape=output_shape,
            q_ori=q,
            key_normed=key_normed,
            positions=positions,
            weight=weight,
            cos_sin_cache=cos_sin_cache,
            epsilon=epsilon,
        )
        return self.o_proj(attn_out)[0]