flashmla.py 8.26 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5

from contextlib import contextmanager
from dataclasses import dataclass
6
from typing import List, Optional, Tuple, Type
7
8
9

import torch

10
11
from vllm.attention.backends.abstract import (AttentionType,
                                              is_quantized_kv_cache)
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
from vllm.attention.backends.mla.common import (MLACommonBackend,
                                                MLACommonImpl,
                                                MLACommonMetadata,
                                                MLACommonMetadataBuilder,
                                                MLACommonState)
from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
                                         get_mla_metadata,
                                         is_flashmla_supported)


class FlashMLABackend(MLACommonBackend):
    """Attention backend identified by the name ``"FLASHMLA"``.

    Every accessor below is a static method that returns the
    FlashMLA-specific class for one role of the backend (implementation,
    metadata, metadata builder, state).
    """

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "FLASHMLA"

    @staticmethod
    def get_state_cls() -> Type["FlashMLAState"]:
        """Return the attention-state class used by this backend."""
        return FlashMLAState

    @staticmethod
    def get_builder_cls() -> Type["FlashMLAMetadataBuilder"]:
        """Return the metadata-builder class used by this backend."""
        return FlashMLAMetadataBuilder

    @staticmethod
    def get_metadata_cls() -> Type["FlashMLAMetadata"]:
        """Return the metadata class used by this backend."""
        return FlashMLAMetadata

    @staticmethod
    def get_impl_cls() -> Type["FlashMLAImpl"]:
        """Return the attention implementation class used by this backend."""
        return FlashMLAImpl


@dataclass
class FlashMLAMetadata(MLACommonMetadata):
    """Common MLA metadata extended with two FlashMLA decode-path fields."""

    # Both fields default to None and are filled from the two values returned
    # by `get_mla_metadata`.
    # NOTE(review): this field is annotated as a tuple of tensors, but other
    # code in this file calls `.copy_()` on the stored value as if it were a
    # single tensor -- confirm the annotation.
    decode_tile_scheduler_metadata: Optional[Tuple[torch.Tensor,
                                                   torch.Tensor]] = None
    decode_num_splits: Optional[torch.Tensor] = None

    @property
    def decode_metadata(self):
        """Return the parent's decode metadata with the FlashMLA fields set.

        When the parent property yields ``None`` it is returned unchanged.
        Otherwise this object's ``decode_tile_scheduler_metadata`` and
        ``decode_num_splits`` are assigned onto the returned object on every
        access.
        """
        parent_view = super().decode_metadata
        if parent_view is None:
            return parent_view

        # TODO: cache assignment?
        parent_view.decode_tile_scheduler_metadata = (
            self.decode_tile_scheduler_metadata)
        parent_view.decode_num_splits = self.decode_num_splits
        return parent_view


class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
    """Metadata builder that adds FlashMLA decode scheduling information."""

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the parent and record the q-head count."""
        super().__init__(*args, **kwargs)

        model_config = self.runner.model_config
        parallel_config = self.runner.parallel_config
        self.num_q_heads = model_config.get_num_attention_heads(
            parallel_config)

    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int):
        """Build the parent metadata, then attach the FlashMLA decode fields.

        The extra fields are only computed when the built metadata reports
        at least one decode token; otherwise the parent result is returned
        as-is.
        """
        metadata = super().build(seq_lens, query_lens, cuda_graph_pad_size,
                                 batch_size)

        if metadata.num_decode_tokens > 0:
            # Only the entries after the prefill entries are passed on.
            decode_seq_lens = metadata.seq_lens_tensor[metadata.num_prefills:]
            tile_metadata, num_splits = get_mla_metadata(
                decode_seq_lens,
                self.num_q_heads,
                1,  # MQA for the decode path
            )
            metadata.decode_tile_scheduler_metadata = tile_metadata
            metadata.decode_num_splits = num_splits

        return metadata


class FlashMLAState(MLACommonState[FlashMLAMetadata]):
    """Attention state that also manages the FlashMLA decode buffers used
    during graph capture and replay."""

    def __init__(self, *args, **kwds):
        """Forward all arguments to the parent and record the q-head count."""
        super().__init__(*args, **kwds)

        self.num_q_heads = self.runner.model_config.get_num_attention_heads(
            self.runner.parallel_config)

    @contextmanager
    def graph_capture(self, max_batch_size: int):
        """Context manager wrapping the parent's ``graph_capture``.

        Before entering the parent context, two buffers are created from a
        dummy `get_mla_metadata` call on an all-ones int32 tensor of length
        ``max_batch_size``. Both buffers are removed from ``self`` when the
        context exits, including when the body raises.
        """
        # Run a dummy `get_mla_metadata` so we can get the right shapes
        self._graph_decoder_tile_scheduler_metadata, \
            self._graph_decode_num_splits = get_mla_metadata(
            torch.ones(
                max_batch_size, dtype=torch.int32, device=self.runner.device),
            self.num_q_heads,
            1, # MQA for the decode path
        )

        # `finally` guarantees the temporary buffers are released even if an
        # exception propagates out of the capture body; without it the `del`
        # statements after `yield` would be skipped on error.
        try:
            with super().graph_capture(max_batch_size):
                yield
        finally:
            del self._graph_decoder_tile_scheduler_metadata
            del self._graph_decode_num_splits

    def graph_capture_get_metadata_for_batch(
            self, batch_size: int, is_encoder_decoder_model: bool = False):
        """Return the parent's capture metadata with FlashMLA fields attached.

        Fresh values from `get_mla_metadata` are copied into the buffers
        created by ``graph_capture``; the returned metadata references those
        buffers (``decode_num_splits`` as a ``[:batch_size + 1]`` slice).
        """
        metadata = super().graph_capture_get_metadata_for_batch(
            batch_size, is_encoder_decoder_model)
        assert metadata.num_decode_tokens > 0

        # `_graph_seq_lens` is not defined in this class; presumably it is
        # provided by the parent state during capture -- confirm.
        decoder_tile_scheduler_metadata, decode_num_splits = get_mla_metadata(
            self._graph_seq_lens[:batch_size],
            self.num_q_heads,
            1,  # MQA for the decode path
        )

        # Copy in place so the metadata keeps pointing at the buffers
        # created in `graph_capture`.
        self._graph_decoder_tile_scheduler_metadata.copy_(
            decoder_tile_scheduler_metadata)
        self._graph_decode_num_splits[:batch_size + 1].copy_(decode_num_splits)

        metadata.decode_tile_scheduler_metadata=\
            self._graph_decoder_tile_scheduler_metadata
        metadata.decode_num_splits=\
            self._graph_decode_num_splits[:batch_size + 1]

        return metadata

    def get_graph_input_buffers(self,
                                attn_metadata,
                                is_encoder_decoder_model: bool = False):
        """Return the parent's input-buffer dict plus two FlashMLA entries.

        The added keys are ``"decode_tile_scheduler_metadata"`` and
        ``"decode_num_splits"``, both read from
        ``attn_metadata.decode_metadata``.
        """
        input_buffers = super().get_graph_input_buffers(
            attn_metadata, is_encoder_decoder_model)
        input_buffers["decode_tile_scheduler_metadata"] = \
                attn_metadata.decode_metadata.decode_tile_scheduler_metadata
        input_buffers["decode_num_splits"] = \
                attn_metadata.decode_metadata.decode_num_splits

        return input_buffers

    def prepare_graph_input_buffers(self,
                                    input_buffers,
                                    attn_metadata,
                                    is_encoder_decoder_model: bool = False):
        """Refresh the input buffers in place from ``attn_metadata``.

        After the parent has prepared its own buffers, the two FlashMLA
        entries are overwritten with ``copy_`` from
        ``attn_metadata.decode_metadata``.
        """
        super().prepare_graph_input_buffers(input_buffers, attn_metadata,
                                            is_encoder_decoder_model)

        input_buffers["decode_tile_scheduler_metadata"].copy_(
            attn_metadata.decode_metadata.decode_tile_scheduler_metadata)
        input_buffers["decode_num_splits"].copy_(
            attn_metadata.decode_metadata.decode_num_splits)


class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
    """MLA attention implementation whose decode path calls
    `flash_mla_with_kvcache`."""

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[List[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            logits_soft_cap: Optional[float],
            attn_type: str,
            kv_sharing_target_layer_name: Optional[str] = None,
            # MLA Specific Arguments
            **mla_args) -> None:
        """Initialise the parent implementation and reject unsupported setups.

        All arguments are forwarded unchanged to the parent constructor.

        Raises:
            AssertionError: if ``is_flashmla_supported()`` is falsy.
            NotImplementedError: if any of ``alibi_slopes``,
                ``sliding_window`` or ``logits_soft_cap`` is truthy, if
                ``attn_type`` is not ``AttentionType.DECODER``, or if the
                KV cache dtype is quantized.
        """
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         logits_soft_cap, attn_type,
                         kv_sharing_target_layer_name, **mla_args)

        # NOTE(review): this relies on `is_flashmla_supported()` returning a
        # plain bool; if it returns a (flag, reason) tuple the assert would
        # always pass -- confirm against the helper's definition.
        assert is_flashmla_supported(), \
            "FlashMLA is not supported on this device"

        unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
        if any(unsupported_features):
            raise NotImplementedError(
                "FlashMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, logits_soft_cap")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "FlashMLAImpl")

        if is_quantized_kv_cache(self.kv_cache_dtype):
            raise NotImplementedError(
                "FlashMLA with FP8 KV cache not yet supported")

    def _forward_decode(
        self,
        q_nope: torch.Tensor,
        q_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: FlashMLAMetadata,
    ) -> torch.Tensor:
        """Run decode attention through `flash_mla_with_kvcache`.

        ``q_nope`` and ``q_pe`` are concatenated along the last dimension to
        form the query. The KV cache must be non-empty and
        ``attn_metadata.decode_metadata`` must not be ``None``. The first
        output of the kernel call is passed through ``self._v_up_proj`` and
        returned.
        """
        assert kv_c_and_k_pe_cache.numel() > 0

        decode_meta = attn_metadata.decode_metadata
        assert decode_meta is not None

        q = torch.cat([q_nope, q_pe], dim=-1)\
            .unsqueeze(1) # Add seqlen dim of 1 (decode)

        # The second return value of the kernel call is unused here.
        o, _ = flash_mla_with_kvcache(
            q=q,
            k_cache=kv_c_and_k_pe_cache.unsqueeze(-2),  # Add head dim of 1
            block_table=decode_meta.block_tables,
            cache_seqlens=decode_meta.seq_lens_tensor,
            head_dim_v=self.kv_lora_rank,
            tile_scheduler_metadata=decode_meta.decode_tile_scheduler_metadata,
            num_splits=decode_meta.decode_num_splits,
            softmax_scale=self.scale,
            causal=True,
        )

        return self._v_up_proj(o)