from dataclasses import dataclass
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import ATTENTION
import torch
from typing import Optional


if ATTENTION in {"flashinfer", "flashdecoding"}:
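    # Seqlen variant used when ATTENTION is "flashinfer" or "flashdecoding":
    # it carries cumulative query and key sequence lengths for the attention kernels.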

    @dataclass
    class Seqlen:
        input_lengths: torch.Tensor
        prefix_lengths: torch.Tensor
        cu_seqlen_q: Optional[torch.Tensor]
        cu_seqlen_k: Optional[torch.Tensor]
        max_q: int
        max_k: int

        def __init__(
            self,
            input_lengths,
            prefix_lengths,
            cu_seqlen_q=None,
            max_q=None,
            max_k=None,
        ):
            self.input_lengths = input_lengths
            self.prefix_lengths = prefix_lengths
            device = self.input_lengths.device
            shape = self.input_lengths.shape
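            # If cu_seqlen_q is not provided, default to one query token per
            # sequence (the decode case): cumulative query offsets become
            # 0, 1, ..., batch_size and max_q is 1.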
            if cu_seqlen_q is None:
                cu_seqlen_q = torch.arange(
                    shape[0] + 1,
                    device=device,
                    dtype=torch.int32,
                )
                max_q = 1
            else:
                assert max_q is not None
            assert max_k is not None
            cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)

            # cuda graphs don't like this and this is necessary to clamp within mistral
            # Although FA2 might not want the clamping
            # cu_seqlen_k[0] = 0
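            # Each key sequence spans the cached prefix plus the new input tokens,
            # so cu_seqlen_k is the cumulative sum of their combined lengths.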
            total = self.input_lengths + self.prefix_lengths
            torch.cumsum(total, -1, out=cu_seqlen_k[1:])

            self.cu_seqlen_q = cu_seqlen_q
            self.cu_seqlen_k = cu_seqlen_k
            self.max_q = max_q
            self.max_k = max_k

        def clamp(self, max):
            # Flash decoding doesn't need to clamp
            return self

else:
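    # Plain paged-attention variant: only cumulative query lengths are tracked,
    # and clamping is only supported (as a no-op) on ROCm.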

    @dataclass
    class Seqlen:
        input_lengths: torch.Tensor
        prefix_lengths: torch.Tensor
        cu_seqlen_q: torch.Tensor
        max_q: int
        max_k: int

        def clamp(self, max):
            if SYSTEM == "rocm":
                return self
            raise NotImplementedError("Not implemented seqlen for paged")
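
# Illustrative sketch (assumes the flashinfer/flashdecoding variant is active):
# for a two-sequence decode batch, a Seqlen could be built roughly as
#
#     input_lengths = torch.tensor([1, 1], dtype=torch.int32, device="cuda")
#     prefix_lengths = torch.tensor([17, 42], dtype=torch.int32, device="cuda")
#     seqlen = Seqlen(input_lengths, prefix_lengths, max_k=43)
#
# yielding cu_seqlen_q == [0, 1, 2] and cu_seqlen_k == [0, 18, 61].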