mla.py 6.42 KB
Newer Older
1
2
3
4
5
6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass

import torch

7
from vllm.attention.layer import MLAAttention
8
from vllm.config import CacheConfig
9
from vllm.model_executor.custom_op import PluggableLayer
10
from vllm.model_executor.layers.quantization import QuantizationConfig
11
from vllm import envs
12
13
14
15


@dataclass
class MLAModules:
16
17
    """Modules used in MLA."""

18
19
20
21
    kv_a_layernorm: torch.nn.Module
    kv_b_proj: torch.nn.Module
    rotary_emb: torch.nn.Module
    o_proj: torch.nn.Module
22
23
24
25
26
27
    fused_qkv_a_proj: torch.nn.Module | None
    kv_a_proj_with_mqa: torch.nn.Module | None
    q_a_layernorm: torch.nn.Module | None
    q_b_proj: torch.nn.Module | None
    q_proj: torch.nn.Module | None
    indexer: torch.nn.Module | None
28
    is_sparse: bool
29
    topk_indices_buffer: torch.Tensor | None
30
    indexer_rotary_emb: torch.nn.Module | None = None
31
32


33
# --8<-- [start:multi_head_latent_attention]
34
35
36
@PluggableLayer.register("multi_head_latent_attention")
class MultiHeadLatentAttentionWrapper(PluggableLayer):
    """Pluggable MLA layer which allows OOT backends to add
    custom implementations of the outer MLA layer (including rope & o_proj).

    Note that currently OOT platforms can still use CustomOp.register_oot to
    replace the MLA layer entirely, although we use PluggableLayer to register
    this layer now.

    This class takes positions and hidden_states as input (plus the optional
    llama_4_scaling tensor and the keyword-only iqis pair; see ``forward``).
    The input tensors can either contain prefill tokens or decode tokens.
    The class does the following:

    1. MLA Preprocess.
    2. Perform multi-head attention to prefill tokens and
       multi-query attention to decode tokens separately.
    3. Return the output tensor.
    """

    # --8<-- [end:multi_head_latent_attention]

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        scale: float,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: int | None,
        kv_lora_rank: int,
        mla_modules: MLAModules,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Store the MLA submodules and build the inner ``MLAAttention``.

        Args:
            hidden_size: Model hidden size; stored on the layer.
            num_heads: Number of attention heads; ``q`` is reshaped to
                ``(-1, num_heads, qk_nope_head_dim + qk_rope_head_dim)``.
            scale: Attention scale forwarded to ``MLAAttention``.
            qk_nope_head_dim: Per-head width of the part of q that does not
                receive the rotary embedding.
            qk_rope_head_dim: Per-head width of the rotary part of q/k.
            v_head_dim: Per-head value width; the attention output is
                ``num_heads * v_head_dim`` wide.
            q_lora_rank: Width of the compressed query. When None, q is
                produced directly by ``mla_modules.q_proj``.
            kv_lora_rank: Width of the compressed KV latent.
            mla_modules: Submodules (projections, norms, rope, indexer).
            cache_config: Forwarded to ``MLAAttention``.
            quant_config: Forwarded to ``MLAAttention``.
            prefix: Module prefix; the inner attention uses ``f"{prefix}.attn"``.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.num_heads = num_heads
        self.fused_qkv_a_proj = mla_modules.fused_qkv_a_proj
        self.kv_a_proj_with_mqa = mla_modules.kv_a_proj_with_mqa
        self.q_a_layernorm = mla_modules.q_a_layernorm
        self.q_b_proj = mla_modules.q_b_proj
        self.q_proj = mla_modules.q_proj
        self.kv_a_layernorm = mla_modules.kv_a_layernorm
        self.kv_b_proj = mla_modules.kv_b_proj
        self.rotary_emb = mla_modules.rotary_emb
        self.o_proj = mla_modules.o_proj
        self.indexer = mla_modules.indexer
        self.indexer_rope_emb = mla_modules.indexer_rotary_emb
        self.is_sparse = mla_modules.is_sparse

        # topk_tokens / topk_indices_buffer only exist when an indexer is given.
        if self.indexer is not None:
            assert hasattr(self.indexer, "topk_tokens")
            self.topk_tokens = self.indexer.topk_tokens
            self.topk_indices_buffer = mla_modules.topk_indices_buffer

        self.mla_attn = MLAAttention(
            num_heads=self.num_heads,
            scale=scale,
            qk_nope_head_dim=self.qk_nope_head_dim,
            qk_rope_head_dim=self.qk_rope_head_dim,
            v_head_dim=self.v_head_dim,
            q_lora_rank=self.q_lora_rank,
            kv_lora_rank=self.kv_lora_rank,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            kv_b_proj=self.kv_b_proj,
            use_sparse=self.is_sparse,
            indexer=self.indexer,
        )

        self.prefix = prefix

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        llama_4_scaling: torch.Tensor | None = None,
        *,
        iqis: tuple[torch.Tensor, torch.Tensor] | None = None,
    ) -> torch.Tensor:
        """Run the full MLA layer: input projections, rope, attention, o_proj.

        Args:
            positions: Token positions passed to the rotary embedding and
                the indexer.
            hidden_states: Layer input; its leading dim is used as the token
                count of the attention output.
            llama_4_scaling: Optional tensor multiplied into q (in place)
                after the rotary embedding.
            iqis: Optional tensor pair forwarded to ``fused_qkv_a_proj`` when
                ``envs.USE_FUSED_RMS_QUANT`` is enabled. It is ignored on the
                ``q_lora_rank is None`` path. NOTE(review): presumably
                pre-quantized input plus scales; confirm against the fused
                projection.

        Returns:
            The first element of ``o_proj``'s output.
        """
        q_c = None
        kv_lora = None

        if self.q_lora_rank is not None:
            assert self.fused_qkv_a_proj is not None, (
                "fused_qkv_a_proj is required when q_lora_rank is not None"
            )
            assert self.q_a_layernorm is not None, (
                "q_a_layernorm is required when q_lora_rank is not None"
            )
            assert self.q_b_proj is not None, (
                "q_b_proj is required when q_lora_rank is not None"
            )
            if envs.USE_FUSED_RMS_QUANT and iqis is not None:
                qkv_lora = self.fused_qkv_a_proj(hidden_states, iqis=iqis)[0]
            else:
                qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
            # A single fused projection yields [q latent | kv latent | k rope].
            q_c, kv_lora = qkv_lora.split(
                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
                dim=-1,
            )
            q_c = self.q_a_layernorm(q_c)
            q = self.q_b_proj(q_c)[0]
        else:
            assert self.kv_a_proj_with_mqa is not None, (
                "kv_a_proj_with_mqa is required when q_lora_rank is None"
            )
            assert self.q_proj is not None, (
                "q_proj is required when q_lora_rank is None"
            )
            kv_lora = self.kv_a_proj_with_mqa(hidden_states)[0]
            q = self.q_proj(hidden_states)[0]

        kv_c, k_pe = kv_lora.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        kv_c_normed = self.kv_a_layernorm(kv_c)

        q = q.view(-1, self.num_heads, self.qk_head_dim)
        # Add head dim of 1 to k_pe
        k_pe = k_pe.unsqueeze(1)

        if self.rotary_emb is not None:
            # Rope is applied only to the trailing qk_rope_head_dim slice of q.
            q[..., self.qk_nope_head_dim :], k_pe = self.rotary_emb(
                positions, q[..., self.qk_nope_head_dim :], k_pe
            )

        # Explicit None check (matches __init__) rather than relying on
        # nn.Module truthiness, which is False for modules with __len__() == 0.
        if self.indexer is not None and self.is_sparse:
            # NOTE(review): q_c is None when q_lora_rank is None; confirm the
            # indexer never runs on that path. The return value is discarded,
            # so the indexer is presumably used for its side effects (e.g.
            # filling topk_indices_buffer) — verify.
            _topk_indices = self.indexer(
                hidden_states, q_c, positions, self.indexer_rope_emb
            )

        if llama_4_scaling is not None:
            q *= llama_4_scaling

        attn_out = self.mla_attn(
            q,
            kv_c_normed,
            k_pe,
            output_shape=(hidden_states.shape[0], self.num_heads * self.v_head_dim),
        )

        return self.o_proj(attn_out)[0]