sequence.py 2.16 KB
Newer Older
Woosuk Kwon's avatar
Woosuk Kwon committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import enum
from typing import List, Optional

from cacheflow.block import LogicalTokenBlock
from cacheflow.decoding import DecodingParams


class SequenceStatus(enum.Enum):
    PENDING = enum.auto()
    SERVING = enum.auto()
    SWAPPED = enum.auto()
    FINISHED = enum.auto()


class Sequence:

    def __init__(
        self,
        seq_id: int,
        token_ids: List[int],
        max_context_len: int,
        block_size: int,
    ) -> None:
        self.seq_id = seq_id
        self.max_context_len = max_context_len
        self.block_size = block_size

        self.logical_token_blocks: List[LogicalTokenBlock] = []
        # Initialize the logical token blocks with the given token ids.
        self.append(token_ids)

        self.status = SequenceStatus.PENDING

    def add_block(self) -> None:
        block = LogicalTokenBlock(
            block_number=len(self.logical_token_blocks),
            block_size=self.block_size,
        )
        self.logical_token_blocks.append(block)

    def append(self, token_ids: List[int]) -> None:
        while token_ids:
            if not self.logical_token_blocks:
                self.add_block()

            last_block = self.logical_token_blocks[-1]
            if last_block.is_full():
                self.add_block()
                last_block = self.logical_token_blocks[-1]

            num_empty_slots = last_block.get_num_empty_slots()
            last_block.append(token_ids[:num_empty_slots])
            token_ids = token_ids[num_empty_slots:]

    def get_token_ids(self) -> List[int]:
        token_ids: List[int] = []
        for block in self.logical_token_blocks:
            token_ids.extend(block.get_token_ids())
        return token_ids


class SequenceGroup:

    def __init__(
        self,
        group_id: int,
        seqs: List[Sequence],
        decoding_params: DecodingParams,
    ) -> None:
        self.group_id = group_id
        self.seqs = seqs
        self.decoding_params = decoding_params

    def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
        if status is None:
            return len(self.seqs)
        else:
            return len([seq for seq in self.seqs if seq.status == status])