common.py
from dataclasses import dataclass
from text_generation_server.models.globals import ATTENTION
import torch
from typing import Optional


if ATTENTION in {"flashinfer", "flashdecoding"}:

    @dataclass
    class Seqlen:
        input_lengths: torch.Tensor
        prefix_lengths: torch.Tensor
        cu_seqlen_q: Optional[torch.Tensor]
        cu_seqlen_k: Optional[torch.Tensor]
        max_q: int
        max_k: int

        def __init__(
            self,
            input_lengths,
            prefix_lengths,
            cu_seqlen_q=None,
            max_q=None,
            max_k=None,
        ):
            self.input_lengths = input_lengths
            self.prefix_lengths = prefix_lengths
            device = self.input_lengths.device
            shape = self.input_lengths.shape
            if cu_seqlen_q is None:
                # Decode-only batch: each sequence contributes exactly one
                # query token, so the query offsets are simply 0..batch_size.
                cu_seqlen_q = torch.arange(
                    shape[0] + 1,
                    device=device,
                    dtype=torch.int32,
                )
                max_q = 1
            else:
                assert max_q is not None
            assert max_k is not None
            cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)

            # cuda graphs don't like this and this is necessary to clamp within mistral
            # Although FA2 might not want the clamping
            # cu_seqlen_k[0] = 0
            # Keys cover both the freshly provided tokens and the prefix-cached
            # ones; the cumulative sum yields the per-sequence key offsets.
            total = self.input_lengths + self.prefix_lengths
            torch.cumsum(total, -1, out=cu_seqlen_k[1:])

            self.cu_seqlen_q = cu_seqlen_q
            self.cu_seqlen_k = cu_seqlen_k
            self.max_q = max_q
            self.max_k = max_k

        def clamp(self, max):
            # Flash decoding doesn't need to clamp
            return self

else:

    @dataclass
    class Seqlen:
        input_lengths: torch.Tensor
        prefix_lengths: torch.Tensor
        cu_seqlen_q: torch.Tensor
        max_q: int
        max_k: int

        def clamp(self, max):
            raise NotImplementedError("Not implemented seqlen for paged")
            return Seqlen(torch.clamp(self.input_lengths, max=max))
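
A minimal usage sketch (not part of the file above), assuming ATTENTION resolves to "flashdecoding" or "flashinfer" so the first Seqlen definition is active; the lengths below are made-up illustrative values:

    import torch

    # Hypothetical decode-only batch of 3 sequences with prefix-cached tokens.
    input_lengths = torch.tensor([1, 1, 1], dtype=torch.int32)
    prefix_lengths = torch.tensor([7, 15, 3], dtype=torch.int32)

    seqlen = Seqlen(
        input_lengths=input_lengths,
        prefix_lengths=prefix_lengths,
        cu_seqlen_q=None,  # decode path: __init__ builds offsets 0..batch and sets max_q = 1
        max_k=16,          # longest prefix + input length in the batch
    )

    # seqlen.cu_seqlen_q -> tensor([0, 1, 2, 3])
    # seqlen.cu_seqlen_k -> tensor([0, 8, 24, 28])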