scheduler.py 21.6 KB
Newer Older
1
import enum
2
3
import os
import pickle
4
import time
5
from typing import Any, Dict, List, Optional, Tuple
Woosuk Kwon's avatar
Woosuk Kwon committed
6
7

from cacheflow.master.block_manager import BlockSpaceManager
8
from cacheflow.master.policy import PolicyFactory
Woosuk Kwon's avatar
Woosuk Kwon committed
9
from cacheflow.sampling_params import SamplingParams
Woosuk Kwon's avatar
Woosuk Kwon committed
10
11
from cacheflow.sequence import Sequence
from cacheflow.sequence import SequenceGroup
12
13
from cacheflow.sequence import SequenceGroupInputs
from cacheflow.sequence import SequenceOutputs
Woosuk Kwon's avatar
Woosuk Kwon committed
14
15
16
from cacheflow.sequence import SequenceStatus


17
18
19
20
21
22
23
24
25
26
27
28
29
class PreemptionMode(enum.Enum):
    """How a preempted sequence group's memory blocks are handled.

    SWAP: move the group's blocks out to CPU memory, and move them back in
        when the group is resumed.
    RECOMPUTE: discard the blocks entirely and re-run the group as a fresh
        prompt when it is resumed.
    """
    SWAP = enum.auto()
    RECOMPUTE = enum.auto()


Woosuk Kwon's avatar
Woosuk Kwon committed
30
31
class Scheduler:

Woosuk Kwon's avatar
Woosuk Kwon committed
32
    def __init__(
        self,
        controllers: List,
        block_size: int,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
        max_num_batched_tokens: int,
        max_num_sequences: int,
        collect_stats: bool,
        do_memory_analysis: bool = False,
    ) -> None:
        """Build a scheduler over the given stage controllers and block pools.

        Args:
            controllers: pipeline stage controllers; stage 0 is driven by
                `step()`.
            block_size: number of tokens per memory block.
            num_gpu_blocks: size of the GPU block pool.
            num_cpu_blocks: size of the CPU (swap) block pool.
            max_num_batched_tokens: cap on tokens scheduled per step.
            max_num_sequences: cap on concurrently RUNNING sequences.
            collect_stats: record per-step statistics into `self.stats`.
            do_memory_analysis: additionally record block/token accounting
                (only meaningful when collect_stats is enabled).
        """
        # Plain configuration.
        self.controllers = controllers
        self.block_size = block_size
        self.num_gpu_blocks = num_gpu_blocks
        self.num_cpu_blocks = num_cpu_blocks
        self.max_num_batched_tokens = max_num_batched_tokens
        self.max_num_sequences = max_num_sequences
        self.collect_stats = collect_stats
        self.do_memory_analysis = do_memory_analysis

        # First-come-first-served scheduling policy.
        self.policy = PolicyFactory.get_policy(policy_name='fcfs')
        # Block space manager over the GPU and CPU block pools.
        self.block_manager = BlockSpaceManager(
            block_size=block_size,
            num_gpu_blocks=num_gpu_blocks,
            num_cpu_blocks=num_cpu_blocks,
        )

        # Queues of sequence groups, one per scheduling state.
        self.waiting: List[SequenceGroup] = []
        self.running: List[SequenceGroup] = []
        self.swapped: List[SequenceGroup] = []
        # Per-group bookkeeping: group_id -> decoding step count.
        self.num_steps: Dict[int, int] = {}
        # Per-group bookkeeping: group_id -> sampling parameters.
        self.sampling_params: Dict[int, SamplingParams] = {}

        # Performance-related statistics.
        self.stats = Stats(num_gpu_blocks, num_cpu_blocks)

75
76
    def add_sequence_groups(
        self,
        seq_groups: List[Tuple[SequenceGroup, SamplingParams]],
    ) -> None:
        """Enqueue new sequence groups into the WAITING queue, remembering
        each group's sampling parameters."""
        for group, params in seq_groups:
            self.sampling_params[group.group_id] = params
            self.waiting.append(group)

84
    def _schedule(
        self,
    ) -> Tuple[Dict[int, int], Dict[int, int], Dict[int, List[int]], List[int]]:
        """Make one scheduling decision, mutating the scheduler queues.

        Returns a 4-tuple (blocks_to_swap_in, blocks_to_swap_out,
        blocks_to_copy, prompt_group_ids): the first three describe the
        block-level data movement the workers must perform before model
        execution; the last lists the group ids scheduled as new prompts
        in this step.
        """
        # Blocks that need to be swapped or copied before model execution.
        blocks_to_swap_in: Dict[int, int] = {}
        blocks_to_swap_out: Dict[int, int] = {}
        blocks_to_copy: Dict[int, List[int]] = {}

        # Fix the current time.
        now = time.time()

        # NOTE(woosuk): We prioritize the sequence groups in the RUNNING state
        # in order to minimize the preemption overheads.
        # Preemption happens only when there is no available slot to keep all
        # the sequence groups in the RUNNING state.
        # In this case, the policy is responsible for deciding which sequence
        # groups to preempt.
        self.running = self.policy.sort_by_priority(now, self.running)

        # Reserve new token slots for the running sequence groups.
        running: List[SequenceGroup] = []
        preempted: List[SequenceGroup] = []
        while self.running:
            seq_group = self.running.pop(0)
            while not self.block_manager.can_append(seq_group):
                if self.running:
                    # Preempt the lowest-priority sequence groups.
                    victim_seq_group = self.running.pop(-1)
                    self._preempt(victim_seq_group, blocks_to_swap_out)
                    preempted.append(victim_seq_group)
                else:
                    # No other sequence groups can be preempted.
                    # Preempt the current sequence group.
                    self._preempt(seq_group, blocks_to_swap_out)
                    preempted.append(seq_group)
                    break
            else:
                # Append new slots to the sequence group.
                self._append(seq_group, blocks_to_copy)
                running.append(seq_group)
        self.running = running

        # Swap in the sequence groups in the SWAPPED state if possible.
        self.swapped = self.policy.sort_by_priority(now, self.swapped)
        while self.swapped:
            seq_group = self.swapped[0]
            # If the sequence group has been preempted in this step, stop.
            if seq_group in preempted:
                break
            # If the sequence group cannot be swapped in, stop.
            if not self.block_manager.can_swap_in(seq_group):
                break

            # The total number of sequences in the RUNNING state should not
            # exceed the maximum number of sequences.
            num_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED)
            if len(self.running) + num_seqs > self.max_num_sequences:
                break

            seq_group = self.swapped.pop(0)
            self._swap_in(seq_group, blocks_to_swap_in)
            self._append(seq_group, blocks_to_copy)
            self.running.append(seq_group)

        # Each RUNNING sequence contributes one token to the next batch
        # (decoding steps feed only the last sampled token; see step()).
        num_batched_tokens = sum(
            seq_group.num_seqs(status=SequenceStatus.RUNNING)
            for seq_group in self.running
        )

        # Join waiting sequences if possible.
        prompt_group_ids: List[int] = []
        # NOTE(woosuk): The sequence groups in the SWAPPED state are strictly
        # prioritized over the sequence groups in the WAITING state.
        # This is because we want to bound the amount of CPU memory taken by
        # the swapped sequence groups.
        if not self.swapped:
            self.waiting = self.policy.sort_by_priority(now, self.waiting)
            while self.waiting:
                seq_group = self.waiting[0]
                # If the sequence group has been preempted in this step, stop.
                if seq_group in preempted:
                    break
                # If the sequence group cannot be allocated, stop.
                if not self.block_manager.can_allocate(seq_group):
                    break

                # If the number of batched tokens exceeds the limit, stop.
                num_prompt_tokens = seq_group.seqs[0].get_len()
                if (num_batched_tokens + num_prompt_tokens
                    > self.max_num_batched_tokens):
                    break

                # The total number of sequences in the RUNNING state should not
                # exceed the maximum number of sequences.
                num_seqs = seq_group.num_seqs(status=SequenceStatus.WAITING)
                if len(self.running) + num_seqs > self.max_num_sequences:
                    break

                seq_group = self.waiting.pop(0)
                self._allocate(seq_group)
                self.running.append(seq_group)
                num_batched_tokens += num_prompt_tokens
                prompt_group_ids.append(seq_group.group_id)

        if self.collect_stats:
            # Only record steps that actually do work.
            if self.running or blocks_to_swap_in or blocks_to_swap_out:
                self.stats.timestamps.append(now - self.stats.start_time)
                self.stats.input_lens.append(num_batched_tokens)
                self.stats.swap_out_lens.append(len(blocks_to_swap_out) * self.block_size)
                self.stats.swap_in_lens.append(len(blocks_to_swap_in) * self.block_size)
                self.stats.num_preemption.append(len(preempted))
                self.stats.num_swapped.append(len(self.swapped))
                self.stats.num_running.append(len(self.running))
                self.stats.num_waiting.append(len(self.waiting))

                num_free_gpu_blocks = self.block_manager.get_num_free_gpu_blocks()
                num_used_gpu_blocks = self.num_gpu_blocks - num_free_gpu_blocks
                self.stats.gpu_cache_usage.append(num_used_gpu_blocks / self.num_gpu_blocks)
                num_free_cpu_blocks = self.block_manager.get_num_free_cpu_blocks()
                num_used_cpu_blocks = self.num_cpu_blocks - num_free_cpu_blocks
                self.stats.cpu_cache_usage.append(num_used_cpu_blocks / self.num_cpu_blocks)

                if self.do_memory_analysis:
                    block_tables = self.block_manager.block_tables
                    num_logical_blocks = 0
                    num_logical_tokens = 0
                    num_physical_blocks = 0
                    num_physical_tokens = 0
                    physical_block_numbers = set()
                    num_reserved_tokens = 0
                    for seq_group in self.running:
                        group_id = seq_group.group_id
                        # NOTE(review): sampling_params/max_num_steps are
                        # looked up but currently unused below.
                        sampling_params = self.sampling_params[group_id]
                        max_num_steps = sampling_params.max_num_steps
                        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                            num_logical_blocks += len(seq.logical_token_blocks)
                            num_logical_tokens += seq.get_len()

                            seq_id = seq.seq_id
                            block_table = block_tables[seq_id]
                            for i, block in enumerate(block_table):
                                # Count each physical block once even when it
                                # is shared by multiple sequences.
                                if block.block_number in physical_block_numbers:
                                    continue
                                physical_block_numbers.add(block.block_number)
                                num_physical_blocks += 1
                                num_physical_tokens += seq.logical_token_blocks[i].num_tokens

                    assert num_physical_blocks == num_used_gpu_blocks
                    self.stats.num_logical_blocks.append(num_logical_blocks)
                    self.stats.num_logical_tokens.append(num_logical_tokens)
                    self.stats.num_physical_blocks.append(num_physical_blocks)
                    self.stats.num_physical_tokens.append(num_physical_tokens)
                    self.stats.num_reserved_tokens.append(num_reserved_tokens)

        return (blocks_to_swap_in,
                blocks_to_swap_out,
                blocks_to_copy,
                prompt_group_ids)

    def step(self) -> List[SequenceGroup]:
        """Run one scheduler iteration and launch the first pipeline stage.

        Returns the sequence groups scheduled to run in this step.
        """
        # Scheduling mutates the internal queues
        # (self.running, self.swapped, self.waiting).
        (blocks_to_swap_in,
         blocks_to_swap_out,
         blocks_to_copy,
         prompt_group_ids) = self._schedule()

        updated_seq_groups: List[SequenceGroup] = self.running.copy()

        # Build the per-group model inputs.
        input_seq_groups: List[SequenceGroupInputs] = []
        for seq_group in self.running:
            group_id = seq_group.group_id
            is_prompt = group_id in prompt_group_ids

            input_tokens: Dict[int, List[int]] = {}
            seq_logprobs: Dict[int, float] = {}
            block_tables: Dict[int, List[int]] = {}
            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                seq_id = seq.seq_id
                block_tables[seq_id] = self.block_manager.get_block_table(seq)
                # A prompt feeds its whole token sequence; a decoding step
                # feeds only the most recently sampled token.
                if is_prompt:
                    input_tokens[seq_id] = seq.get_token_ids()
                else:
                    input_tokens[seq_id] = [seq.get_last_token_id()]
                seq_logprobs[seq_id] = seq.cumulative_logprobs
                # NOTE(woosuk): Sequences in the same group have the same
                # sequence length
                seq_len = seq.get_len()

            input_seq_groups.append(SequenceGroupInputs(
                group_id=group_id,
                is_prompt=is_prompt,
                input_tokens=input_tokens,
                context_len=seq_len,
                seq_logprobs=seq_logprobs,
                sampling_params=self.sampling_params[group_id],
                block_tables=block_tables,
            ))

        # Execute the first stage of the pipeline.
        if input_seq_groups or blocks_to_swap_in or blocks_to_swap_out:
            # Swap in and swap out should never happen at the same time.
            assert not (blocks_to_swap_in and blocks_to_swap_out)
            self.controllers[0].execute_stage(
                input_seq_groups,
                blocks_to_swap_in=blocks_to_swap_in,
                blocks_to_swap_out=blocks_to_swap_out,
                blocks_to_copy=blocks_to_copy,
            )

        return updated_seq_groups

Woosuk Kwon's avatar
Woosuk Kwon committed
300
301
    def post_step(
        self,
        seq_outputs: Dict[int, SequenceOutputs],
    ) -> None:
        """Apply the model outputs of the last step to the scheduler state.

        Handles beam-search forks, appends the sampled token to each live
        sequence, finishes sequences that emitted a stop token or reached
        their step limit, and retires fully finished groups.
        """
        # Update the running sequences and free blocks.
        for seq_group in self.running:
            group_id = seq_group.group_id
            self.num_steps[group_id] += 1
            stop_token_ids = self.sampling_params[group_id].stop_token_ids

            # Process beam search results before processing the next tokens.
            for seq in seq_group.seqs:
                if seq.status == SequenceStatus.FINISHED:
                    continue

                output = seq_outputs[seq.seq_id]
                if seq.seq_id != output.parent_seq_id:
                    # The sequence is a fork of the parent sequence (beam search).
                    # Free the current sequence.
                    self.block_manager.free(seq)
                    # Fork the parent sequence.
                    parent_seq = seq_group.find(output.parent_seq_id)
                    parent_seq.fork(seq)
                    self.block_manager.fork(parent_seq, seq)

            # Process the next tokens.
            for seq in seq_group.seqs:
                if seq.status == SequenceStatus.FINISHED:
                    continue

                # Append a new token to the sequence.
                output = seq_outputs[seq.seq_id]
                seq.append(output.output_token, output.logprobs)

                # Check if the sequence has generated a stop token.
                if output.output_token in stop_token_ids:
                    self._free_seq(seq)
                    continue

                # Check if the sequence has reached the maximum number of steps.
                # `==` suffices since the counter advances by exactly one per
                # call at the top of this loop.
                max_num_steps = self.sampling_params[group_id].max_num_steps
                if self.num_steps[group_id] == max_num_steps:
                    self._free_seq(seq)
                    continue

        # Update the running sequences: retire finished groups, keep the rest.
        running: List[SequenceGroup] = []
        for seq_group in self.running:
            if seq_group.is_finished():
                self._free_seq_group(seq_group)
            else:
                running.append(seq_group)
        self.running = running
Woosuk Kwon's avatar
Woosuk Kwon committed
353

354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
    def _allocate(self, seq_group: SequenceGroup) -> None:
        """Allocate blocks for a new group and mark its sequences RUNNING."""
        self.block_manager.allocate(seq_group)
        for sequence in seq_group.seqs:
            sequence.status = SequenceStatus.RUNNING
        # FIXME(woosuk): Support interactive generation.
        # Start the step counter only if the group is not already tracked.
        self.num_steps.setdefault(seq_group.group_id, 0)

    def _append(
        self,
        seq_group: SequenceGroup,
        blocks_to_copy: Dict[int, List[int]],
    ) -> None:
        """Reserve a slot for each RUNNING sequence's next token.

        When the block manager returns a (src_block, dst_block) pair —
        presumably a copy-on-write for a shared block; confirm against
        BlockSpaceManager.append — it is recorded in blocks_to_copy.

        Args:
            seq_group: group whose RUNNING sequences get a new slot.
            blocks_to_copy: out-param mapping src block -> dst blocks.
        """
        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
            ret = self.block_manager.append(seq)
            if ret is not None:
                src_block, dst_block = ret
                # setdefault replaces the manual insert-or-append branch.
                blocks_to_copy.setdefault(src_block, []).append(dst_block)

    def _preempt(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: Dict[int, int],
        preemption_mode: Optional[PreemptionMode] = None,
    ) -> None:
        """Preempt a running sequence group by recomputation or swapping.

        If preemption mode is not specified, we determine the mode as follows:
        We use recomputation by default since it incurs lower overhead than
        swapping. However, when the sequence group has multiple sequences
        (e.g., beam search), recomputation is not supported. In such a case,
        we use swapping instead.

        FIXME(woosuk): This makes our scheduling policy a bit bizarre.
        As swapped sequences are prioritized over waiting sequences,
        sequence groups with multiple sequences are implicitly prioritized
        over sequence groups with a single sequence.
        TODO(woosuk): Support recomputation for sequence groups with multiple
        sequences. This may require a more sophisticated CUDA kernel.

        Raises:
            ValueError: if an unknown preemption mode is passed.
        """
        if preemption_mode is None:
            seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
            if len(seqs) == 1:
                preemption_mode = PreemptionMode.RECOMPUTE
            else:
                preemption_mode = PreemptionMode.SWAP
        if preemption_mode == PreemptionMode.RECOMPUTE:
            self._preempt_by_recompute(seq_group)
        elif preemption_mode == PreemptionMode.SWAP:
            self._preempt_by_swap(seq_group, blocks_to_swap_out)
        else:
            # Raise instead of `assert False`: asserts vanish under -O.
            raise ValueError(f'Invalid preemption mode: {preemption_mode}')

    def _preempt_by_recompute(
        self,
        seq_group: SequenceGroup,
    ) -> None:
        """Free the group's blocks and requeue it as a fresh prompt."""
        running_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
        # Recomputation only works for single-sequence groups.
        assert len(running_seqs) == 1
        for sequence in running_seqs:
            sequence.status = SequenceStatus.WAITING
            self.block_manager.free(sequence)
        self.waiting.append(seq_group)

    def _preempt_by_swap(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: Dict[int, int],
    ) -> None:
        """Swap the group's blocks out to CPU and park it in the SWAPPED queue."""
        running_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
        for sequence in running_seqs:
            sequence.status = SequenceStatus.SWAPPED
        self._swap_out(seq_group, blocks_to_swap_out)
        self.swapped.append(seq_group)

    def _free_seq(self, seq: Sequence) -> None:
        """Mark a single sequence FINISHED and release its blocks."""
        seq.status = SequenceStatus.FINISHED
        self.block_manager.free(seq)

432
    def _free_seq_group(self, seq_group: SequenceGroup) -> None:
        """Drop per-group bookkeeping once the whole group has finished."""
        gid = seq_group.group_id
        # pop() without a default raises KeyError like `del` would.
        self.num_steps.pop(gid)
        self.sampling_params.pop(gid)
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456

    def _swap_in(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_in: Dict[int, int],
    ) -> None:
        """Bring the group's blocks back to GPU and mark its sequences RUNNING."""
        blocks_to_swap_in.update(self.block_manager.swap_in(seq_group))
        for sequence in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
            sequence.status = SequenceStatus.RUNNING

    def _swap_out(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: Dict[int, int],
    ) -> None:
        """Move the group's blocks out to CPU and mark its sequences SWAPPED."""
        assert self.block_manager.can_swap_out(seq_group)
        blocks_to_swap_out.update(self.block_manager.swap_out(seq_group))
        for sequence in seq_group.get_seqs(status=SequenceStatus.RUNNING):
            sequence.status = SequenceStatus.SWAPPED
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528

    def reset_stats(self) -> None:
        """Discard collected statistics and restart the measurement window."""
        self.stats.reset(self.num_gpu_blocks, self.num_cpu_blocks)

    def save_stats(
        self,
        output_dir: str,
    ) -> None:
        """Persist the collected statistics under `output_dir`.

        Raises:
            RuntimeError: if the scheduler was created with
                collect_stats=False.
        """
        # Explicit check instead of `assert`, which is stripped under -O.
        if not self.collect_stats:
            raise RuntimeError('Statistics collection is disabled.')
        self.stats.save(output_dir)


class Stats:
    """Accumulates per-step scheduler statistics for offline analysis."""

    # Per-step metric series, in the order they appear in to_dict().
    _SERIES = (
        'timestamps',
        'input_lens',
        'swap_out_lens',
        'swap_in_lens',
        'num_preemption',
        'num_waiting',
        'num_running',
        'num_swapped',
        'gpu_cache_usage',
        'cpu_cache_usage',
        'num_logical_blocks',
        'num_logical_tokens',
        'num_physical_blocks',
        'num_physical_tokens',
        'num_reserved_tokens',
    )

    def __init__(
        self,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
    ) -> None:
        self.start_time: float = time.time()
        self.num_gpu_blocks = num_gpu_blocks
        self.num_cpu_blocks = num_cpu_blocks
        # One empty list per metric series.
        for series_name in self._SERIES:
            setattr(self, series_name, [])

    def reset(
        self,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
    ) -> None:
        """Discard all recorded data and restart the clock."""
        self.__init__(num_gpu_blocks, num_cpu_blocks)

    def to_dict(self) -> Dict[str, Any]:
        """Return all recorded data as a plain, picklable dictionary."""
        result: Dict[str, Any] = {
            'start_time': self.start_time,
            'num_gpu_blocks': self.num_gpu_blocks,
            'num_cpu_blocks': self.num_cpu_blocks,
        }
        for series_name in self._SERIES:
            result[series_name] = getattr(self, series_name)
        return result

    def save(self, output_dir: str) -> None:
        """Pickle the statistics dictionary to <output_dir>/stats.pkl."""
        with open(os.path.join(output_dir, 'stats.pkl'), 'wb') as f:
            pickle.dump(self.to_dict(), f)