encoder_only_attention.py 3.13 KB
Newer Older
1
2
3
4
5
6
7
8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from copy import copy

import torch

from vllm import envs
9
10
11
12
13
from vllm.attention.backends.abstract import (
    AttentionBackend,
    AttentionMetadata,
    AttentionType,
)
14
15
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend
16
from vllm.config import CacheConfig
17
from vllm.config.vllm import VllmConfig
18
19
20
21
from vllm.v1.attention.backends.utils import (
    CommonAttentionMetadata,
    subclass_attention_backend,
)
22
from vllm.v1.kv_cache_interface import KVCacheSpec
23
24
25
26


@functools.lru_cache
def create_encoder_only_attention_backend(
    underlying_attn_backend: type[AttentionBackend],
) -> type[AttentionBackend]:
    """Wrap an attention backend class so the metadata it builds is non-causal.

    A new metadata-builder class is derived from the underlying backend's
    builder (``underlying_attn_backend.get_builder_cls()``). Its ``build``
    forces ``causal = False`` on a copy of the common metadata before
    delegating. A new backend class is then produced via
    ``subclass_attention_backend`` with that builder plugged in.

    The result is memoized with ``functools.lru_cache``. Repeated calls with
    the same backend class return the same generated class instead of minting
    a fresh type each time.

    Args:
        underlying_attn_backend: The attention backend *class* to wrap.

    Returns:
        The backend class returned by ``subclass_attention_backend`` when
        called with ``name_prefix="EncoderOnlyAttention_"`` and the
        non-causal builder.
    """
    prefix = "EncoderOnlyAttention_"
    underlying_builder = underlying_attn_backend.get_builder_cls()

    class EncoderOnlyAttentionBuilder(underlying_builder):  # type: ignore
        """Metadata builder that disables causal masking, then delegates."""

        def build(
            self,
            common_prefix_len: int,
            common_attn_metadata: CommonAttentionMetadata,
            fast_build: bool = False,
        ) -> AttentionMetadata:
            # Shallow-copy so the caller's metadata object is not mutated.
            # Only the ``causal`` flag differs from what the caller passed.
            new_common_attn_metadata = copy(common_attn_metadata)
            new_common_attn_metadata.causal = False
            return super().build(
                common_prefix_len, new_common_attn_metadata, fast_build
            )

    attn_backend = subclass_attention_backend(
        name_prefix=prefix,
        attention_backend_cls=underlying_attn_backend,
        builder_cls=EncoderOnlyAttentionBuilder,
    )

    return attn_backend


class EncoderOnlyAttention(Attention):
    """
    Encoder attention is a special case that doesn't need a KV Cache.

    Under V1 the layer uses a backend wrapped by
    ``create_encoder_only_attention_backend``, whose metadata builder forces
    non-causal attention. ``get_kv_cache_spec`` returns ``None`` because no
    KV cache is allocated for this layer.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        cache_config: CacheConfig | None = None,
        attn_type: str | None = None,
        **kwargs,
    ):
        """Build an encoder-only attention layer.

        Args:
            num_heads: Number of attention heads; forwarded to ``Attention``.
            head_size: Size of each head. Used for backend selection and
                forwarded to ``Attention``.
            scale: Attention scaling factor; forwarded to ``Attention``.
            cache_config: Optional cache configuration. Its ``cache_dtype``
                and ``block_size`` are used to select the underlying backend.
                When it is ``None``, ``"auto"`` and ``16`` are used instead.
            attn_type: Accepted for signature compatibility only. If given,
                it must equal ``AttentionType.ENCODER_ONLY``. The parent is
                always constructed with ``AttentionType.ENCODER_ONLY``.
            **kwargs: Forwarded unchanged to ``Attention.__init__``.

        Raises:
            AssertionError: If ``attn_type`` is given and is not
                ``AttentionType.ENCODER_ONLY``.
        """
        # Validate arguments before doing any backend-selection work.
        if attn_type is not None:
            assert attn_type == AttentionType.ENCODER_ONLY, (
                "EncoderOnlyAttention only supports AttentionType.ENCODER_ONLY"
            )

        if envs.VLLM_USE_V1:
            dtype = torch.get_default_dtype()

            if cache_config is not None:
                kv_cache_dtype = cache_config.cache_dtype
                block_size = cache_config.block_size
            else:
                kv_cache_dtype = "auto"
                block_size = 16

            underlying_attn_backend = get_attn_backend(
                head_size, dtype, kv_cache_dtype, block_size
            )
            # Wrap the selected backend so the metadata it builds is
            # non-causal (memoized per backend class).
            attn_backend = create_encoder_only_attention_backend(
                underlying_attn_backend
            )
        else:
            # in v0 encoder only attention is handled inside the backends
            attn_backend = None

        super().__init__(
            num_heads=num_heads,
            head_size=head_size,
            scale=scale,
            cache_config=cache_config,
            attn_backend=attn_backend,
            attn_type=AttentionType.ENCODER_ONLY,
            **kwargs,
        )

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
        """Return ``None``: this layer does not need a KV cache.

        Args:
            vllm_config: Unused here. Presumably accepted to match the
                overridden ``Attention.get_kv_cache_spec`` signature; confirm.
        """
        # Does not need KV cache
        return None