layer.py 4.59 KB
Newer Older
1
"""Attention layer."""
2
from typing import Any, Dict, List, Optional
3
4
5
6

import torch
import torch.nn as nn

7
from vllm.attention.backends.abstract import AttentionMetadata, AttentionType
8
from vllm.attention.selector import get_attn_backend
9
from vllm.config import CacheConfig
10
11
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
12
from vllm.model_executor.layers.quantization.fp8 import Fp8KVCacheMethod
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33


class Attention(nn.Module):
    """Attention layer.

    This class takes query, key, and value tensors as input. The input tensors
    can either contain prompt tokens or generation tokens.
    The class does the following:

    1. Store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention.
    3. Return the output tensor.

    The computation itself is delegated to a backend implementation
    (``self.impl``) selected via ``get_attn_backend`` at construction time.
    This module resolves cache/quantization settings, optionally sets up FP8
    KV-cache scale weights, and forwards calls to the backend.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        prefix: str = "",
    ) -> None:
        """Resolve cache/quantization settings and build the backend impl.

        Args:
            num_heads: Number of query heads.
            head_size: Size of each attention head.
            scale: Attention scale passed to the backend implementation.
            num_kv_heads: Number of key/value heads; defaults to
                ``num_heads`` when ``None``.
            alibi_slopes: Optional ALiBi slopes, forwarded to the backend.
            cache_config: Source of the KV-cache dtype, block size and
                sliding window. When ``None``, ``"auto"``, ``16`` and
                ``None`` are used respectively.
            quant_config: Optional quantization config. If it yields a quant
                method for this layer, that method must be an
                ``Fp8KVCacheMethod``.
            blocksparse_params: Optional block-sparse parameters, forwarded
                to the backend impl; whether they are present is also passed
                to backend selection.
            prefix: Layer name prefix passed to
                ``quant_config.get_quant_method``.

        Raises:
            ValueError: If a quant method is active and the KV-cache dtype
                is ``"fp8_e5m2"``.
        """
        super().__init__()
        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
            sliding_window = cache_config.sliding_window
        else:
            kv_cache_dtype = "auto"
            block_size = 16
            sliding_window = None
        if num_kv_heads is None:
            num_kv_heads = num_heads

        # The default k/v_scale is set to 1.0. This is ignored
        # when kv-cache is not fp8, and should be used with
        # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
        # expect the pre-quantized k/v_scale to be loaded along
        # with the model weights.
        self.kv_cache_dtype = kv_cache_dtype
        self._k_scale = 1.0
        self._v_scale = 1.0
        quant_method = quant_config.get_quant_method(
            self, prefix=prefix) if quant_config else None
        if quant_method is not None:
            assert isinstance(quant_method, Fp8KVCacheMethod)
            # TODO (mgoin): kv cache dtype should be specified in the FP8
            # checkpoint config and become the "auto" behavior
            # NOTE: if the KV cache is not fp8, the quant method is ignored
            # here and no scale weights are created on this layer.
            if "fp8" in self.kv_cache_dtype:
                if self.kv_cache_dtype == "fp8_e5m2":
                    raise ValueError("fp8_e5m2 kv-cache is not supported with "
                                     "fp8 checkpoints.")
                # When FP8 quantization is enabled, create_weights(self) is
                # expected to register the k/v scale parameters on this layer
                # so they can be loaded from the FP8 checkpoint. They are then
                # converted back to the native float values self._k_scale /
                # self._v_scale after weight loading (outside this class).
                self.quant_method = quant_method
                self.quant_method.create_weights(self)

        # During model initialization, the default dtype is set as the model
        # weight and activation dtype.
        dtype = torch.get_default_dtype()
        attn_backend = get_attn_backend(num_heads, head_size, num_kv_heads,
                                        sliding_window, dtype, kv_cache_dtype,
                                        block_size, blocksparse_params
                                        is not None)
        impl_cls = attn_backend.get_impl_cls()
        self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
                             alibi_slopes, sliding_window, kv_cache_dtype,
                             blocksparse_params)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Optional[torch.Tensor],
        attn_metadata: AttentionMetadata,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> torch.Tensor:
        """Run attention by delegating to the backend implementation.

        Args:
            query: Query tensor.
            key: Key tensor.
            value: Value tensor.
            kv_cache: KV-cache tensor for this layer, or ``None``; how
                ``None`` is handled is up to the backend impl.
            attn_metadata: Attention metadata passed through to the backend.
            attn_type: Kind of attention to run; defaults to
                ``AttentionType.DECODER``.

        Returns:
            The tensor returned by ``self.impl.forward``.
        """
        # The current float k/v scales are passed on every call; they stay at
        # 1.0 unless updated after weight loading.
        return self.impl.forward(query,
                                 key,
                                 value,
                                 kv_cache,
                                 attn_metadata,
                                 self._k_scale,
                                 self._v_scale,
                                 attn_type=attn_type)

    def extra_repr(self) -> str:
        """Summarize the backend's head configuration for the module repr."""
        s = f"head_size={self.impl.head_size}"  # type: ignore
        s += f", num_heads={self.impl.num_heads}"  # type: ignore
        s += f", num_kv_heads={self.impl.num_kv_heads}"  # type: ignore
        s += f", scale={self.impl.scale}"  # type: ignore
        s += f", backend={self.impl.__class__.__name__}"
        return s