flashmla.py 5.18 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4

from dataclasses import dataclass
5
from typing import Any, Optional
6
7
8

import torch

9
10
from vllm.attention.backends.abstract import (AttentionType,
                                              is_quantized_kv_cache)
11
12
13
14
15
from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
                                         get_mla_metadata,
                                         is_flashmla_supported)
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
16
                                                   MLACommonDecodeMetadata,
17
18
19
                                                   MLACommonImpl,
                                                   MLACommonMetadata,
                                                   MLACommonMetadataBuilder)
20
21
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
22
23
24
25
26
27
28
29
30
31
32

logger = init_logger(__name__)


class FlashMLABackend(MLACommonBackend):
    """MLA attention backend whose decode path is served by FlashMLA.

    Only the backend name and the metadata / builder / implementation
    classes are overridden; everything else comes from MLACommonBackend.
    """

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "FLASHMLA_VLLM_V1"

    @staticmethod
    def get_metadata_cls() -> type["FlashMLAMetadata"]:
        """Return the per-batch attention metadata class."""
        return FlashMLAMetadata

    @staticmethod
    def get_builder_cls() -> type["FlashMLAMetadataBuilder"]:
        """Return the class that builds FlashMLAMetadata instances."""
        return FlashMLAMetadataBuilder

    @staticmethod
    def get_impl_cls() -> type["FlashMLAImpl"]:
        """Return the attention implementation class (FlashMLA decode)."""
        return FlashMLAImpl


@dataclass
class FlashMLADecodeMetadata(MLACommonDecodeMetadata):
    """Decode metadata extended with FlashMLA tile-scheduling outputs.

    Both extra fields are the two values returned by ``get_mla_metadata``
    (computed in ``FlashMLAMetadataBuilder._build_decode``) and are passed
    straight through to ``flash_mla_with_kvcache`` in
    ``FlashMLAImpl._forward_decode``.
    """
    # First output of get_mla_metadata. The builder unpacks exactly two
    # values from that call, and FlashMLA documents this one as a single
    # int32 tensor, so it is annotated as a Tensor rather than a tuple.
    tile_scheduler_metadata: torch.Tensor
    # Second output of get_mla_metadata, forwarded to the kernel as-is.
    num_splits: torch.Tensor


@dataclass
class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
    """Common MLA metadata specialised for FlashMLA decode metadata.

    No new fields are added; the subclass only binds the decode-metadata
    type parameter to FlashMLADecodeMetadata.
    """
54
55
56
57


class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
    """Metadata builder that adds FlashMLA scheduling data to decodes."""

    def __init__(self, runner, kv_cache_spec: AttentionSpec,
                 block_table: BlockTable):
        super().__init__(runner, kv_cache_spec, block_table)

        # Query-head count fed to get_mla_metadata for every decode batch;
        # resolved once from the runner's model and parallel configs.
        model_config = self.runner.model_config
        self.num_q_heads = model_config.get_num_attention_heads(
            self.runner.parallel_config)

    def _build_decode(self, block_table_tensor: torch.Tensor,
                      seq_lens: torch.Tensor) -> FlashMLADecodeMetadata:
        """Build decode metadata, including FlashMLA tile scheduling.

        Args:
            block_table_tensor: Block table for the decode requests.
            seq_lens: Sequence lengths for the decode requests; also used
                to compute the scheduling metadata.

        Returns:
            FlashMLADecodeMetadata holding the inputs plus both outputs of
            ``get_mla_metadata``.
        """
        # The decode path is MQA: all query heads share one latent KV head.
        num_kv_heads = 1
        scheduler_meta, splits = get_mla_metadata(seq_lens, self.num_q_heads,
                                                  num_kv_heads)

        return FlashMLADecodeMetadata(
            block_table=block_table_tensor,
            seq_lens=seq_lens,
            tile_scheduler_metadata=scheduler_meta,
            num_splits=splits,
        )
80
81
82
83
84
85
86
87
88
89


class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
    """MLA attention implementation whose decode path runs FlashMLA.

    Construction rejects configurations this path does not handle; decode
    attention is computed by ``flash_mla_with_kvcache`` in
    ``_forward_decode``. Everything else is inherited from MLACommonImpl.
    """

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[list[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            blocksparse_params: Optional[dict[str, Any]],
            logits_soft_cap: Optional[float],
            attn_type: str,
            # MLA Specific Arguments
            **mla_args) -> None:
        """Initialise the common MLA impl, then validate FlashMLA support.

        All arguments are forwarded unchanged to ``MLACommonImpl``.

        Raises:
            AssertionError: FlashMLA reports it is not supported on the
                current device.
            NotImplementedError: any of alibi_slopes, sliding_window,
                blocksparse_params or logits_soft_cap is truthy;
                ``attn_type`` is not ``AttentionType.DECODER``; or the KV
                cache dtype is quantized.
        """
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         blocksparse_params, logits_soft_cap, attn_type,
                         **mla_args)

        # is_flashmla_supported() returns an (is_supported, reason) pair.
        # A non-empty tuple is always truthy, so asserting on the raw return
        # value could never fail. Unpack it so unsupported devices are
        # actually rejected, and surface the reason. A bare bool is still
        # accepted in case the helper's return type differs.
        support = is_flashmla_supported()
        if isinstance(support, tuple):
            is_supported, reason = support
        else:
            is_supported, reason = support, None
        assert is_supported, (reason
                              or "FlashMLA is not supported on this device")

        # Truthiness check: None (and falsy values) mean "feature unused".
        unsupported_features = [
            alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
        ]
        if any(unsupported_features):
            raise NotImplementedError(
                "FlashMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, blocksparse_params, "
                "logits_soft_cap")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "FlashMLAImpl")

        # self.kv_cache_dtype is populated by MLACommonImpl.__init__ above.
        if is_quantized_kv_cache(self.kv_cache_dtype):
            raise NotImplementedError(
                "FlashMLA V1 with FP8 KV cache not yet supported")

    def _forward_decode(
        self,
        q_nope: torch.Tensor,
        q_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: FlashMLAMetadata,
    ) -> torch.Tensor:
        """Compute decode attention with FlashMLA over the paged cache.

        Args:
            q_nope: First query part. It is concatenated with ``q_pe`` along
                the last dimension to form the kernel query (presumably the
                non-rotary part, per MLA naming).
            q_pe: Second query part (presumably the rotary part).
            kv_c_and_k_pe_cache: Non-empty paged KV cache. A singleton head
                dimension is inserted at -2 before it is passed to the
                kernel.
            attn_metadata: Batch metadata; its ``decode`` section must be
                populated.

        Returns:
            ``self._v_up_proj`` applied to the kernel output, which has a
            value head dim of ``self.kv_lora_rank``.
        """
        assert kv_c_and_k_pe_cache.numel() > 0
        assert attn_metadata.decode is not None

        q = torch.cat([q_nope, q_pe], dim=-1)\
            .unsqueeze(1) # Add seqlen dim of 1 (decode)

        o, _ = flash_mla_with_kvcache(
            q=q,
            k_cache=kv_c_and_k_pe_cache.unsqueeze(-2),  # Add head dim of 1
            block_table=attn_metadata.decode.block_table,
            cache_seqlens=attn_metadata.decode.seq_lens,
            head_dim_v=self.kv_lora_rank,
            tile_scheduler_metadata=attn_metadata.decode.
            tile_scheduler_metadata,
            num_splits=attn_metadata.decode.num_splits,
            softmax_scale=self.scale,
            causal=True,
        )

        return self._v_up_proj(o)