"""Attention layer."""
from functools import lru_cache
3
4
5
6
7
from typing import List, Optional

import torch
import torch.nn as nn

8
from vllm.logger import init_logger
9
10
11
from vllm.model_executor.input_metadata import InputMetadata
from vllm.utils import is_hip

12
13
logger = init_logger(__name__)

14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36

class Attention(nn.Module):
    """Attention layer backed by a backend selected at construction time.

    The layer receives query, key, and value tensors, which may hold either
    prompt tokens or generation tokens, and does the following:

    1. Store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention.
    3. Return the output tensor.

    All of that work is delegated to ``self.backend``: a
    ``FlashAttentionBackend`` when ``_use_flash_attn()`` returns True, and an
    ``XFormersBackend`` otherwise.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        """Pick the attention backend and construct it.

        Every argument is forwarded positionally, unchanged and in the same
        order, to the constructor of the selected backend class.
        """
        super().__init__()

        # The backend modules are imported lazily so that only the backend
        # that is actually selected gets imported.
        if _use_flash_attn():
            from vllm.model_executor.layers.attention.backends.flash_attn import FlashAttentionBackend  # noqa: E501
            backend_cls = FlashAttentionBackend
        else:
            from vllm.model_executor.layers.attention.backends.xformers import XFormersBackend  # noqa: E501
            backend_cls = XFormersBackend

        self.backend = backend_cls(
            num_heads,
            head_size,
            scale,
            num_kv_heads,
            alibi_slopes,
            sliding_window,
        )

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: Optional[torch.Tensor],
        value_cache: Optional[torch.Tensor],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Run attention through the selected backend and return its output.

        The arguments are handed to ``self.backend.forward`` as-is; the caches
        may be ``None`` per the signature.
        """
        output = self.backend.forward(
            query,
            key,
            value,
            key_cache,
            value_cache,
            input_metadata,
        )
        return output
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84


@lru_cache(maxsize=1)
def _use_flash_attn() -> bool:
    """Decide whether the flash_attn backend should be used.

    Returns True only when all of the following hold, checked in this order:

    1. The ``flash_attn`` package can be imported.
    2. ``is_hip()`` is false (i.e. not an AMD GPU build).
    3. The major CUDA compute capability of the current device is >= 8.
    4. The torch default dtype is ``torch.float16`` or ``torch.bfloat16``.

    Otherwise returns False, meaning the xformers backend is used. The result
    is memoized by ``lru_cache``, so the checks and the log message are
    produced on the first successful call only.
    """
    try:
        import flash_attn  # noqa: F401
    except ImportError:
        logger.info("flash_attn is not found. Using xformers backend.")
        return False

    # AMD GPUs: fall back without emitting a log message.
    if is_hip():
        return False

    fallback_reason = None
    major_capability = torch.cuda.get_device_capability()[0]
    if major_capability < 8:
        # Volta and Turing NVIDIA GPUs.
        fallback_reason = (
            "flash_attn is not supported on Turing or older GPUs. "
            "Using xformers backend.")
    elif torch.get_default_dtype() not in (torch.float16, torch.bfloat16):
        fallback_reason = (
            "flash_attn only supports torch.float16 or torch.bfloat16. "
            "Using xformers backend.")

    use_flash = fallback_reason is None
    logger.info("Using flash_attn backend." if use_flash else fallback_reason)
    return use_flash