sequence.py 5.93 KB
Newer Older
1
import copy
Woosuk Kwon's avatar
Woosuk Kwon committed
2
import enum
3
from typing import Dict, List, Optional
Woosuk Kwon's avatar
Woosuk Kwon committed
4
5

from cacheflow.block import LogicalTokenBlock
6
from cacheflow.sampling_params import SamplingParams
Woosuk Kwon's avatar
Woosuk Kwon committed
7
8
9


class SequenceStatus(enum.Enum):
    """Possible values of ``Sequence.status``."""

    # Explicit values, identical to what ``enum.auto()`` assigns (1, 2, 3, 4).
    WAITING = 1
    RUNNING = 2
    SWAPPED = 3
    FINISHED = 4


16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
class SequenceData:
    """Token ids of one sequence: the prompt tokens plus generated tokens.

    Also holds ``cumulative_logprobs``, a float that starts at 0.0 and is
    accumulated by the owner (``Sequence.append_token`` adds the logprob of
    each appended token).
    """

    def __init__(
        self,
        prompt_token_ids: List[int],
    ) -> None:
        self.prompt_token_ids = prompt_token_ids

        self.output_token_ids: List[int] = []
        # Running sum of log-probabilities of the generated tokens.
        self.cumulative_logprobs = 0.0

    def get_len(self) -> int:
        """Return the total number of tokens (prompt + output)."""
        return len(self.output_token_ids) + len(self.prompt_token_ids)

    def get_token_ids(self) -> List[int]:
        """Return a new list: prompt token ids followed by output token ids."""
        return self.prompt_token_ids + self.output_token_ids

    def get_last_token_id(self) -> int:
        """Return the most recent token id.

        This is the last output token if any have been generated, otherwise
        the last prompt token. Raises IndexError if both lists are empty.
        """
        if not self.output_token_ids:
            return self.prompt_token_ids[-1]
        return self.output_token_ids[-1]

    def __repr__(self) -> str:
        # SequenceData has no ``prompt`` attribute (the prompt string lives on
        # Sequence); referencing it here made repr() raise AttributeError.
        return (f"SequenceData("
                f"prompt_token_ids={self.prompt_token_ids}, "
                f"output_token_ids={self.output_token_ids})")


Woosuk Kwon's avatar
Woosuk Kwon committed
45
46
47
48
49
class Sequence:
    """A single sequence: its prompt, token data, and logical token blocks.

    Token ids (prompt and generated) are stored in ``self.data``; the same
    tokens are also laid out, in order, into ``LogicalTokenBlock`` objects of
    ``block_size`` slots each (``self.logical_token_blocks``).
    """

    def __init__(
        self,
        seq_id: int,
        prompt: str,
        prompt_token_ids: List[int],
        block_size: int,
    ) -> None:
        self.seq_id = seq_id
        self.prompt = prompt
        self.block_size = block_size

        self.data = SequenceData(prompt_token_ids)
        # One dict (token id -> logprob) per generated token, in order, as
        # passed to append_token.
        self.output_logprobs: List[Dict[int, float]] = []

        self.logical_token_blocks: List[LogicalTokenBlock] = []
        # Initialize the logical token blocks with the prompt token ids.
        self._append_tokens_to_blocks(prompt_token_ids)
        self.status = SequenceStatus.WAITING

    def _append_logical_block(self) -> None:
        """Append a new logical block numbered after the existing ones."""
        block = LogicalTokenBlock(
            block_number=len(self.logical_token_blocks),
            block_size=self.block_size,
        )
        self.logical_token_blocks.append(block)

    def _append_tokens_to_blocks(self, token_ids: List[int]) -> None:
        """Fill the last logical block, opening new blocks as needed.

        ``token_ids`` itself is not mutated; the local name is rebound to the
        remaining suffix on each iteration.
        """
        while token_ids:
            if not self.logical_token_blocks:
                self._append_logical_block()

            last_block = self.logical_token_blocks[-1]
            if last_block.is_full():
                self._append_logical_block()
                last_block = self.logical_token_blocks[-1]

            num_empty_slots = last_block.get_num_empty_slots()
            last_block.append_tokens(token_ids[:num_empty_slots])
            token_ids = token_ids[num_empty_slots:]

    def append_token(self, token_id: int, logprobs: Dict[int, float]) -> None:
        """Append a generated token and its logprob distribution.

        ``logprobs`` must contain ``token_id``; its value is added to
        ``self.data.cumulative_logprobs``.
        """
        # Checked before any state is mutated, so a failure leaves the
        # sequence unchanged.
        assert token_id in logprobs
        self._append_tokens_to_blocks([token_id])
        self.output_logprobs.append(logprobs)
        self.data.output_token_ids.append(token_id)
        self.data.cumulative_logprobs += logprobs[token_id]

    def get_len(self) -> int:
        """Return the total number of tokens (prompt + output)."""
        return self.data.get_len()

    def get_token_ids(self) -> List[int]:
        """Return prompt token ids followed by output token ids."""
        return self.data.get_token_ids()

    def get_last_token_id(self) -> int:
        """Return the most recently appended token id."""
        return self.data.get_last_token_id()

    def fork(self, child_seq: 'Sequence') -> 'Sequence':
        """Deep-copy this sequence's blocks, logprobs, and data into a child.

        ``child_seq``'s ``seq_id``, ``prompt``, ``block_size`` and ``status``
        are left untouched. Returns ``child_seq`` (previously the method
        returned None despite its ``-> 'Sequence'`` annotation).
        """
        child_seq.logical_token_blocks = copy.deepcopy(
            self.logical_token_blocks)
        child_seq.output_logprobs = copy.deepcopy(self.output_logprobs)
        child_seq.data = copy.deepcopy(self.data)
        return child_seq

    def __repr__(self) -> str:
        return (f'Sequence(seq_id={self.seq_id}, '
                f'status={self.status.name}, '
                f'num_blocks={len(self.logical_token_blocks)})')

Woosuk Kwon's avatar
Woosuk Kwon committed
113
114
115
116
117
118
119

class SequenceGroup:
    """A collection of sequences identified by ``group_id``.

    Also records ``arrival_time`` as given by the caller.
    """

    def __init__(
        self,
        group_id: int,
        seqs: List[Sequence],
        arrival_time: float,
    ) -> None:
        self.group_id = group_id
        self.seqs = seqs
        self.arrival_time = arrival_time

    def get_seqs(
        self,
        status: Optional[SequenceStatus] = None,
    ) -> List[Sequence]:
        """Return the group's sequences, optionally filtered by ``status``.

        With no status the group's own ``seqs`` list is returned (not a copy);
        with a status a new filtered list is built.
        """
        if status is None:
            return self.seqs
        return [candidate for candidate in self.seqs
                if candidate.status == status]

    def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
        """Count the sequences, optionally only those with ``status``."""
        matching = self.get_seqs(status)
        return len(matching)

    def find(self, seq_id: int) -> Sequence:
        """Return the sequence whose ``seq_id`` matches.

        Raises ValueError when no sequence in the group has that id.
        """
        found = next(
            (candidate for candidate in self.seqs
             if candidate.seq_id == seq_id),
            None,
        )
        if found is None:
            raise ValueError(f'Sequence {seq_id} not found.')
        return found

    def is_finished(self) -> bool:
        """True when every sequence is FINISHED (vacuously True if empty)."""
        return not any(candidate.status != SequenceStatus.FINISHED
                       for candidate in self.seqs)

    def __repr__(self) -> str:
        group_id = self.group_id
        count = len(self.seqs)
        return (f'SequenceGroup(group_id={group_id}, '
                f'num_seqs={count})')
150
151


152
class SequenceGroupMetadata:
    """Plain container of per-group fields.

    Holds a group id, an ``is_prompt`` flag, per-sequence data, the group's
    sampling parameters, and per-sequence block tables. No validation is
    performed; arguments are stored as given.
    """

    def __init__(
        self,
        group_id: int,
        is_prompt: bool,
        seq_data: Dict[int, SequenceData],
        sampling_params: SamplingParams,
        block_tables: Dict[int, List[int]],
    ) -> None:
        self.group_id = group_id
        self.is_prompt = is_prompt
        self.sampling_params = sampling_params
        # Seq id -> sequence data.
        self.seq_data = seq_data
        # Seq id -> list of physical block numbers.
        self.block_tables = block_tables


class SequenceOutputs:
    """An output token for a sequence, with its parent id and logprobs.

    ``logprobs`` maps token id -> logP(x_i+1 | x_0, ..., x_i).
    """

    def __init__(
        self,
        seq_id: int,
        parent_seq_id: int,
        output_token: int,
        logprobs: Dict[int, float],         # Token id -> logP(x_i+1 | x_0, ..., x_i).
    ) -> None:
        self.seq_id = seq_id
        self.parent_seq_id = parent_seq_id
        self.output_token = output_token
        self.logprobs = logprobs

    def __repr__(self) -> str:
        # The closing parenthesis belongs after ``logprobs``; it was previously
        # emitted after ``output_token``, producing a malformed repr.
        return (f'SequenceOutputs(seq_id={self.seq_id}, '
                f'parent_seq_id={self.parent_seq_id}, '
                f'output_token={self.output_token}, '
                f'logprobs={self.logprobs})')

    def __eq__(self, other: object) -> bool:
        """Field-wise equality with another SequenceOutputs.

        Returns NotImplemented for other types so Python falls back to its
        default comparison instead of raising AttributeError. Because
        ``__eq__`` is defined without ``__hash__``, instances are unhashable.
        """
        if not isinstance(other, SequenceOutputs):
            return NotImplemented
        return (self.seq_id == other.seq_id and
                self.parent_seq_id == other.parent_seq_id and
                self.output_token == other.output_token and
                self.logprobs == other.logprobs)