sequence.py 6.67 KB
Newer Older
1
import copy
Woosuk Kwon's avatar
Woosuk Kwon committed
2
import enum
3
from typing import Dict, List, Optional
Woosuk Kwon's avatar
Woosuk Kwon committed
4
5

from cacheflow.block import LogicalTokenBlock
6
from cacheflow.sampling_params import SamplingParams
Woosuk Kwon's avatar
Woosuk Kwon committed
7
8
9


class SequenceStatus(enum.Enum):
    """States a ``Sequence`` can be in.

    Values are produced by ``enum.auto()``, i.e. consecutive integers starting
    at 1 in declaration order.
    """

    # Every new Sequence starts out in this state (see Sequence.__init__).
    WAITING = enum.auto()
    RUNNING = enum.auto()
    SWAPPED = enum.auto()
    # SequenceGroup.is_finished() is true once every member is in this state.
    FINISHED = enum.auto()


class SequenceData:
    """Token ids and running log-probability of a single sequence.

    The prompt token list passed to the constructor is stored as-is (not
    copied). Generated token ids are accumulated separately, together with
    the sum of the logprobs reported for them.
    """

    def __init__(
        self,
        prompt_token_ids: List[int],
    ) -> None:
        self.prompt_token_ids = prompt_token_ids
        # Generated token ids, in the order they were appended.
        self.output_token_ids: List[int] = []
        # Sum of every logprob passed to append_token(); starts at 0.0.
        self.cumulative_logprob = 0.0

    def append_token(self, token_id: int, logprob: float) -> None:
        """Record one generated token and add its logprob to the running sum."""
        self.output_token_ids.append(token_id)
        self.cumulative_logprob += logprob

    def get_len(self) -> int:
        """Return the total token count: prompt tokens plus generated tokens."""
        num_prompt = len(self.prompt_token_ids)
        num_output = len(self.output_token_ids)
        return num_prompt + num_output

    def get_output_len(self) -> int:
        """Return how many tokens have been generated so far."""
        return len(self.output_token_ids)

    def get_token_ids(self) -> List[int]:
        """Return a new list holding the prompt ids followed by the generated ids."""
        return self.prompt_token_ids + self.output_token_ids

    def get_last_token_id(self) -> int:
        """Return the newest token id, or the last prompt id if nothing was generated."""
        # An empty output list is falsy, so this falls back to the prompt.
        newest_source = self.output_token_ids or self.prompt_token_ids
        return newest_source[-1]

    def __repr__(self) -> str:
        return ("SequenceData("
                f"prompt_token_ids={self.prompt_token_ids}, "
                f"output_token_ids={self.output_token_ids}, "
                f"cumulative_logprob={self.cumulative_logprob})")


class Sequence:
    """One sequence of a request: its token data plus logical token blocks.

    Prompt and generated token ids live in ``self.data`` (a ``SequenceData``).
    The same ids are also laid out across ``LogicalTokenBlock`` objects created
    with ``block_size``.
    """

    def __init__(
        self,
        seq_id: int,
        prompt: str,
        prompt_token_ids: List[int],
        block_size: int,
    ) -> None:
        self.seq_id = seq_id
        self.prompt = prompt
        self.block_size = block_size

        self.data = SequenceData(prompt_token_ids)
        # One token id -> logprob dict per generated token, as given to append_token().
        self.output_logprobs: List[Dict[int, float]] = []
        # Not updated by any method of this class; presumably set by callers
        # (e.g. a detokenizer) -- NOTE(review): confirm where it is written.
        self.output_text = ""

        self.logical_token_blocks: List[LogicalTokenBlock] = []
        # Initialize the logical token blocks with the prompt token ids.
        self._append_tokens_to_blocks(prompt_token_ids)
        self.status = SequenceStatus.WAITING

    def _append_logical_block(self) -> None:
        """Append a new logical block whose number is its index in the list."""
        block = LogicalTokenBlock(
            block_number=len(self.logical_token_blocks),
            block_size=self.block_size,
        )
        self.logical_token_blocks.append(block)

    def _append_tokens_to_blocks(self, token_ids: List[int]) -> None:
        """Spread ``token_ids`` over the logical blocks, opening new ones as needed.

        Tokens fill the empty slots of the last block first. A fresh block is
        appended whenever there is no block yet or the last one is full.
        """
        # Advance an index instead of re-slicing the remaining list on every
        # iteration, which copied the whole tail once per block.
        start = 0
        while start < len(token_ids):
            blocks = self.logical_token_blocks
            if not blocks or blocks[-1].is_full():
                self._append_logical_block()
            last_block = self.logical_token_blocks[-1]
            num_empty_slots = last_block.get_num_empty_slots()
            last_block.append_tokens(token_ids[start:start + num_empty_slots])
            start += num_empty_slots

    def append_token(self, token_id: int, logprobs: Dict[int, float]) -> None:
        """Append one generated token.

        Args:
            token_id: The generated token id.
            logprobs: Token id -> logprob for this step. It must contain
                ``token_id``, whose logprob is added to the cumulative sum.
        """
        assert token_id in logprobs
        self._append_tokens_to_blocks([token_id])
        self.output_logprobs.append(logprobs)
        self.data.append_token(token_id, logprobs[token_id])

    def get_len(self) -> int:
        """Return the number of prompt plus generated tokens."""
        return self.data.get_len()

    def get_output_len(self) -> int:
        """Return the number of generated tokens."""
        return self.data.get_output_len()

    def get_token_ids(self) -> List[int]:
        """Return prompt token ids followed by generated token ids."""
        return self.data.get_token_ids()

    def get_last_token_id(self) -> int:
        """Return the most recent token id (last prompt token if none generated)."""
        return self.data.get_last_token_id()

    def get_output_token_ids(self) -> List[int]:
        """Return the generated token ids (the underlying list, not a copy)."""
        return self.data.output_token_ids

    def get_cumulative_logprob(self) -> float:
        """Return the sum of logprobs of the generated tokens."""
        return self.data.cumulative_logprob

    def fork(self, child_seq: 'Sequence') -> None:
        """Overwrite ``child_seq``'s generated state with a copy of this one.

        The logical token blocks, per-token logprobs, token data and output
        text are copied. The child's ``seq_id``, ``prompt``, ``block_size`` and
        ``status`` are left untouched.
        """
        child_seq.logical_token_blocks = copy.deepcopy(self.logical_token_blocks)
        child_seq.output_logprobs = copy.deepcopy(self.output_logprobs)
        child_seq.data = copy.deepcopy(self.data)
        # Keep the text in step with the copied tokens; previously the child
        # kept its own stale output_text. str is immutable, so no copy needed.
        child_seq.output_text = self.output_text

    def __repr__(self) -> str:
        return (f'Sequence(seq_id={self.seq_id}, '
                f'status={self.status.name}, '
                f'num_blocks={len(self.logical_token_blocks)})')


class SequenceGroup:
    """The sequences of one request, with their shared sampling params.

    Besides the sequences, the group records the request id, the sampling
    parameters and the arrival time it was constructed with.
    """

    def __init__(
        self,
        request_id: str,
        seqs: List[Sequence],
        sampling_params: SamplingParams,
        arrival_time: float,
    ) -> None:
        self.request_id = request_id
        self.seqs = seqs
        self.sampling_params = sampling_params
        self.arrival_time = arrival_time

    def get_seqs(
        self,
        status: Optional[SequenceStatus] = None,
    ) -> List[Sequence]:
        """Return the sequences, optionally only those with ``status``.

        With no status, the group's own list object is returned (not a copy);
        otherwise a new filtered list is built.
        """
        if status is not None:
            return [member for member in self.seqs if member.status == status]
        return self.seqs

    def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
        """Count the sequences, optionally only those with ``status``."""
        return len(self.get_seqs(status))

    def find(self, seq_id: int) -> Sequence:
        """Return the sequence whose ``seq_id`` matches.

        Raises:
            ValueError: If no sequence in the group has that id.
        """
        match = next(
            (member for member in self.seqs if member.seq_id == seq_id), None)
        if match is None:
            raise ValueError(f'Sequence {seq_id} not found.')
        return match

    def is_finished(self) -> bool:
        """Return True when every sequence is FINISHED (also True if empty)."""
        return all(
            member.status == SequenceStatus.FINISHED for member in self.seqs)

    def __repr__(self) -> str:
        return (f"SequenceGroup(request_id={self.request_id}, "
                f"sampling_params={self.sampling_params}, "
                f"num_seqs={len(self.seqs)})")


class SequenceGroupMetadata:
    """Plain container for a request's per-sequence data and block tables.

    Every constructor argument is stored unchanged as an attribute of the
    same name. No copies are made.
    """

    def __init__(
        self,
        request_id: str,
        is_prompt: bool,
        seq_data: Dict[int, SequenceData],      # Seq id -> sequence data.
        sampling_params: SamplingParams,
        block_tables: Dict[int, List[int]],     # Seq id -> list of physical block numbers.
    ) -> None:
        # Identification and phase flag.
        self.request_id = request_id
        self.is_prompt = is_prompt
        # Per-sequence mappings, both keyed by sequence id.
        self.seq_data = seq_data
        self.block_tables = block_tables
        # Sampling configuration shared by the whole group.
        self.sampling_params = sampling_params


class SequenceOutputs:
    """Sampling result for one sequence: the chosen token and its logprobs.

    ``parent_seq_id`` records which sequence this output was produced from.
    It may differ from ``seq_id`` -- presumably when a sequence is forked;
    NOTE(review): confirm against the sampler.
    """

    def __init__(
        self,
        seq_id: int,
        parent_seq_id: int,
        output_token: int,
        logprobs: Dict[int, float],         # Token id -> logP(x_i+1 | x_0, ..., x_i).
    ) -> None:
        self.seq_id = seq_id
        self.parent_seq_id = parent_seq_id
        self.output_token = output_token
        self.logprobs = logprobs

    def __repr__(self) -> str:
        # The closing parenthesis used to follow output_token, which left
        # "logprobs=..." dangling outside the "SequenceOutputs(...)" text.
        return (f'SequenceOutputs(seq_id={self.seq_id}, '
                f'parent_seq_id={self.parent_seq_id}, '
                f'output_token={self.output_token}, '
                f'logprobs={self.logprobs})')

    def __eq__(self, other: object) -> bool:
        """Field-wise equality; defers to the other operand for foreign types.

        Defining ``__eq__`` without ``__hash__`` makes instances unhashable
        (Python sets ``__hash__`` to None).
        """
        if not isinstance(other, SequenceOutputs):
            return NotImplemented
        return (self.seq_id == other.seq_id and
                self.parent_seq_id == other.parent_seq_id and
                self.output_token == other.output_token and
                self.logprobs == other.logprobs)