sequence.py 7.36 KB
Newer Older
1
import copy
Woosuk Kwon's avatar
Woosuk Kwon committed
2
import enum
Zhuohan Li's avatar
Zhuohan Li committed
3
from typing import Dict, List, Optional, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
4
5

from cacheflow.block import LogicalTokenBlock
6
from cacheflow.sampling_params import SamplingParams
Woosuk Kwon's avatar
Woosuk Kwon committed
7
8
9


class SequenceStatus(enum.Enum):
    """Lifecycle states of a ``Sequence``.

    ``is_finished`` treats the two ``FINISHED_*`` members as finished, and
    ``get_finished_reason`` maps each of them to a reason string.
    """
    WAITING = enum.auto()
    RUNNING = enum.auto()
    SWAPPED = enum.auto()
    FINISHED_STOPPED = enum.auto()
    FINISHED_LENGTH_CAPPED = enum.auto()

    @staticmethod
    def is_finished(status: "SequenceStatus") -> bool:
        """Return True when ``status`` is one of the ``FINISHED_*`` states."""
        finished_states = (
            SequenceStatus.FINISHED_STOPPED,
            SequenceStatus.FINISHED_LENGTH_CAPPED,
        )
        return status in finished_states

    @staticmethod
    def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
        """Translate a status into a finish-reason string.

        Returns ``"stop"`` for ``FINISHED_STOPPED``, ``"length"`` for
        ``FINISHED_LENGTH_CAPPED`` and ``None`` for any other status.
        """
        if status == SequenceStatus.FINISHED_STOPPED:
            return "stop"
        if status == SequenceStatus.FINISHED_LENGTH_CAPPED:
            return "length"
        return None
Woosuk Kwon's avatar
Woosuk Kwon committed
32

33
34
35
36
37
38
39
40
class SequenceData:
    """Token-level state of one sequence.

    Tracks the prompt token ids, the generated token ids, and the running
    sum of the log-probabilities passed to ``append_token_id``.

    Args:
        prompt_token_ids: Prompt token ids; stored by reference, not copied.
    """

    def __init__(
        self,
        prompt_token_ids: List[int],
    ) -> None:
        self.prompt_token_ids = prompt_token_ids
        self.output_token_ids: List[int] = []
        self.cumulative_logprob = 0.0

    def append_token_id(self, token_id: int, logprob: float) -> None:
        """Record a generated token and add ``logprob`` to the running total."""
        self.output_token_ids.append(token_id)
        self.cumulative_logprob += logprob

    def get_len(self) -> int:
        """Return the number of prompt tokens plus generated tokens."""
        num_generated = len(self.output_token_ids)
        num_prompt = len(self.prompt_token_ids)
        return num_generated + num_prompt

    def get_output_len(self) -> int:
        """Return how many tokens have been generated so far."""
        num_generated = len(self.output_token_ids)
        return num_generated

    def get_token_ids(self) -> List[int]:
        """Return a new list holding the prompt ids followed by generated ids."""
        all_ids = self.prompt_token_ids + self.output_token_ids
        return all_ids

    def get_last_token_id(self) -> int:
        """Return the newest generated id, or the last prompt id if none yet."""
        source = (self.output_token_ids
                  if self.output_token_ids else self.prompt_token_ids)
        return source[-1]

    def __repr__(self) -> str:
        return (
            f"SequenceData(prompt_token_ids={self.prompt_token_ids}, "
            f"output_token_ids={self.output_token_ids}, "
            f"cumulative_logprob={self.cumulative_logprob})"
        )
66
67


Woosuk Kwon's avatar
Woosuk Kwon committed
68
69
70
71
72
class Sequence:
    """One generated sequence belonging to a request.

    Holds the prompt text, the token-level ``SequenceData``, one logprob dict
    per generated token, detokenized output, the logical token blocks that
    cover every token of the sequence, and a ``status`` (initially WAITING).

    Args:
        seq_id: Identifier of this sequence.
        prompt: Prompt text.
        prompt_token_ids: Prompt token ids; also used to fill the initial
            logical blocks.
        block_size: Number of token slots per logical block.
    """

    def __init__(
        self,
        seq_id: int,
        prompt: str,
        prompt_token_ids: List[int],
        block_size: int,
    ) -> None:
        self.seq_id = seq_id
        self.prompt = prompt
        self.block_size = block_size

        self.data = SequenceData(prompt_token_ids)
        # One dict per generated token, mapping token id -> logprob.
        self.output_logprobs: List[Dict[int, float]] = []
        # Detokenized output. This class never appends to these itself;
        # presumably the caller maintains them — verify against callers.
        self.output_tokens: List[str] = []
        self.output_text = ""

        self.logical_token_blocks: List[LogicalTokenBlock] = []
        # Initialize the logical token blocks with the prompt token ids.
        self._append_tokens_to_blocks(prompt_token_ids)
        self.status = SequenceStatus.WAITING

    def _append_logical_block(self) -> None:
        """Append a new logical block numbered by the current block count."""
        block = LogicalTokenBlock(
            block_number=len(self.logical_token_blocks),
            block_size=self.block_size,
        )
        self.logical_token_blocks.append(block)

    def _append_tokens_to_blocks(self, token_ids: List[int]) -> None:
        """Distribute ``token_ids`` across the logical blocks.

        The last block is filled first; a new block is appended whenever the
        last one is full. ``token_ids`` itself is not modified.
        """
        cursor = 0
        num_tokens = len(token_ids)
        while cursor < num_tokens:
            if not self.logical_token_blocks:
                self._append_logical_block()

            last_block = self.logical_token_blocks[-1]
            if last_block.is_full():
                self._append_logical_block()
                last_block = self.logical_token_blocks[-1]

            num_empty_slots = last_block.get_num_empty_slots()
            # Advance a cursor rather than re-slicing the remaining tail on
            # every iteration, which copied the rest of the list per block.
            chunk = token_ids[cursor:cursor + num_empty_slots]
            last_block.append_tokens(chunk)
            cursor += len(chunk)

    def append_token_id(
        self,
        token_id: int,
        logprobs: Dict[int, float],
    ) -> None:
        """Append one generated token to the blocks and the sequence data.

        Args:
            token_id: The sampled token id; must be a key of ``logprobs``.
            logprobs: Token id -> logprob for this step. The entry for
                ``token_id`` is added to the cumulative logprob.
        """
        # Internal invariant check (stripped under ``python -O``).
        assert token_id in logprobs
        self._append_tokens_to_blocks([token_id])
        self.output_logprobs.append(logprobs)
        self.data.append_token_id(token_id, logprobs[token_id])

    def get_len(self) -> int:
        """Return the total number of tokens (prompt + generated)."""
        return self.data.get_len()

    def get_output_len(self) -> int:
        """Return the number of generated tokens."""
        return self.data.get_output_len()

    def get_token_ids(self) -> List[int]:
        """Return prompt ids followed by generated ids."""
        return self.data.get_token_ids()

    def get_last_token_id(self) -> int:
        """Return the most recent token id (generated, else last prompt id)."""
        return self.data.get_last_token_id()

    def get_output_token_ids(self) -> List[int]:
        """Return the generated token ids (the underlying list, not a copy)."""
        return self.data.output_token_ids

    def get_cumulative_logprob(self) -> float:
        """Return the sum of logprobs of the generated tokens."""
        return self.data.cumulative_logprob

    def fork(self, child_seq: 'Sequence') -> None:
        """Copy this sequence's generation state into ``child_seq``.

        Deep copies are used, so later appends to either sequence do not
        affect the other. The child's ``seq_id``, ``prompt``, ``block_size``
        and ``status`` are left unchanged.
        """
        child_seq.logical_token_blocks = copy.deepcopy(self.logical_token_blocks)
        child_seq.output_logprobs = copy.deepcopy(self.output_logprobs)
        # Copy the detokenized output too, so the child's text stays
        # consistent with the token ids copied from the parent below.
        child_seq.output_tokens = copy.deepcopy(self.output_tokens)
        child_seq.output_text = self.output_text  # str is immutable
        child_seq.data = copy.deepcopy(self.data)

    def __repr__(self) -> str:
        return (f'Sequence(seq_id={self.seq_id}, '
                f'status={self.status.name}, '
                f'num_blocks={len(self.logical_token_blocks)})')

Woosuk Kwon's avatar
Woosuk Kwon committed
151
152
153
154
155

class SequenceGroup:
    """All sequences generated for a single request.

    Args:
        request_id: Identifier of the originating request.
        seqs: The sequences of this request (stored by reference).
        sampling_params: Sampling parameters of the request.
        arrival_time: Arrival time of the request; units are not fixed here.
    """

    def __init__(
        self,
        request_id: str,
        seqs: List[Sequence],
        sampling_params: SamplingParams,
        arrival_time: float,
    ) -> None:
        self.request_id = request_id
        self.seqs = seqs
        self.sampling_params = sampling_params
        self.arrival_time = arrival_time

    def get_seqs(
        self,
        status: Optional[SequenceStatus] = None,
    ) -> List[Sequence]:
        """Return the sequences, optionally only those with ``status``.

        Without a filter the internal list itself is returned, not a copy.
        """
        if status is not None:
            return list(filter(lambda s: s.status == status, self.seqs))
        return self.seqs

    def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
        """Count the sequences, optionally only those with ``status``."""
        selected = self.get_seqs(status)
        return len(selected)

    def find(self, seq_id: int) -> Sequence:
        """Return the first sequence whose ``seq_id`` matches.

        Raises:
            ValueError: If no sequence has that id.
        """
        for candidate in self.seqs:
            if candidate.seq_id == seq_id:
                break
        else:
            raise ValueError(f'Sequence {seq_id} not found.')
        return candidate

    def is_finished(self) -> bool:
        """Return True when every sequence is in a finished state.

        An empty group counts as finished.
        """
        for seq in self.seqs:
            if not SequenceStatus.is_finished(seq.status):
                return False
        return True

    def __repr__(self) -> str:
        return (
            f"SequenceGroup(request_id={self.request_id}, "
            f"sampling_params={self.sampling_params}, "
            f"num_seqs={len(self.seqs)})"
        )
191
192


193
class SequenceGroupMetadata:
    """Plain record of one request's per-sequence state.

    All arguments are stored verbatim; nothing is copied or validated.

    Args:
        request_id: Identifier of the originating request.
        is_prompt: Flag stored as given; presumably True when the group is
            processed as a prompt rather than a decode step — confirm.
        seq_data: Sequence id -> ``SequenceData``.
        sampling_params: Sampling parameters of the request.
        block_tables: Sequence id -> list of physical block numbers.
    """

    def __init__(
        self,
        request_id: str,
        is_prompt: bool,
        seq_data: Dict[int, SequenceData],
        sampling_params: SamplingParams,
        block_tables: Dict[int, List[int]],
    ) -> None:
        self.request_id, self.is_prompt = request_id, is_prompt
        self.seq_data = seq_data
        self.sampling_params = sampling_params
        self.block_tables = block_tables


class SequenceOutputs:
    """Result of one sampling step for a single sequence.

    Args:
        seq_id: Id of the sequence this output belongs to.
        parent_seq_id: Id of the sequence it was sampled from; stored
            separately from ``seq_id`` (presumably they differ when a
            sequence is forked — confirm against callers).
        output_token: The sampled token id.
        logprobs: Token id -> logP(x_i+1 | x_0, ..., x_i).
    """

    def __init__(
        self,
        seq_id: int,
        parent_seq_id: int,
        output_token: int,
        logprobs: Dict[int, float],         # Token id -> logP(x_i+1 | x_0, ..., x_i).
    ) -> None:
        self.seq_id = seq_id
        self.parent_seq_id = parent_seq_id
        self.output_token = output_token
        self.logprobs = logprobs

    def __repr__(self) -> str:
        # The closing parenthesis goes after the last field so that
        # ``logprobs`` appears inside ``SequenceOutputs(...)``.
        return (f'SequenceOutputs(seq_id={self.seq_id}, '
                f'parent_seq_id={self.parent_seq_id}, '
                f'output_token={self.output_token}, '
                f'logprobs={self.logprobs})')

    # Defining __eq__ without __hash__ makes instances unhashable (Python
    # sets __hash__ to None), which suits a record holding a mutable dict.
    def __eq__(self, other: object) -> bool:
        """Field-wise equality; defers to the other operand for non-outputs."""
        if not isinstance(other, SequenceOutputs):
            return NotImplemented
        return (self.seq_id == other.seq_id and
                self.parent_seq_id == other.parent_seq_id and
                self.output_token == other.output_token and
                self.logprobs == other.logprobs)