import copy
import time
from functools import partial
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union

from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                         SchedulerConfig)
from vllm.core.scheduler import Scheduler, SchedulerOutputs
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.ray_utils import RayWorker, initialize_cluster, ray
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
                           SequenceGroupMetadata, SequenceOutputs,
                           SequenceStatus)
from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
                                               get_tokenizer)
from vllm.utils import Counter

# `ray` is presumably falsy when Ray is not installed (it is re-exported
# from vllm.engine.ray_utils) -- these names are only needed on the Ray
# worker path (`LLMEngine._init_workers_ray`).
if ray:
    from ray.air.util.torch_dist import init_torch_dist_process_group
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup

logger = init_logger(__name__)

# Minimum number of seconds between two stats log lines. It is also the
# length of the sliding window over which throughput is averaged
# (see `LLMEngine._log_system_stats`).
_LOGGING_INTERVAL_SEC = 5

class LLMEngine:
    """An LLM engine that receives requests and generates texts.

    This is the main class for the vLLM engine. It receives requests
    from clients and generates texts from the LLM. It includes a tokenizer, a
    language model (possibly distributed across multiple GPUs), and GPU memory
    space allocated for intermediate states (aka KV cache). This class utilizes
    iteration-level scheduling and efficient memory management to maximize the
    serving throughput.

    The `LLM` class wraps this class for offline batched inference and the
    `AsyncLLMEngine` class wraps this class for online serving.

    NOTE: The config arguments are derived from the `EngineArgs` class. For the
    comprehensive list of arguments, see `EngineArgs`.

    Args:
        model_config: The configuration related to the LLM model.
        cache_config: The configuration related to the KV cache memory
            management.
        parallel_config: The configuration related to distributed execution.
        scheduler_config: The configuration related to the request scheduler.
        distributed_init_method: The initialization method for distributed
            execution. See `torch.distributed.init_process_group` for details.
        placement_group: Ray placement group for distributed execution.
            Required for distributed execution.
        log_stats: Whether to log statistics.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        distributed_init_method: str,
        placement_group: Optional["PlacementGroup"],
        log_stats: bool,
    ) -> None:
        logger.info(
            "Initializing an LLM engine with config: "
            f"model={model_config.model!r}, "
            f"tokenizer={model_config.tokenizer!r}, "
            f"tokenizer_mode={model_config.tokenizer_mode}, "
            f"revision={model_config.revision}, "
            f"trust_remote_code={model_config.trust_remote_code}, "
            f"dtype={model_config.dtype}, "
            f"max_seq_len={model_config.max_model_len}, "
            f"download_dir={model_config.download_dir!r}, "
            f"load_format={model_config.load_format}, "
            f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
            f"quantization={model_config.quantization}, "
            f"seed={model_config.seed})")
        # TODO(woosuk): Print more configs in debug mode.

        self.model_config = model_config
        self.cache_config = cache_config
        # The cache config's sliding window must agree with the model's
        # HF config (None when the model defines no sliding window).
        assert self.cache_config.sliding_window == getattr(
            self.model_config.hf_config, "sliding_window", None)
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.log_stats = log_stats
        self._verify_args()

        self.tokenizer = get_tokenizer(
            model_config.tokenizer,
            tokenizer_mode=model_config.tokenizer_mode,
            trust_remote_code=model_config.trust_remote_code,
            revision=model_config.revision)
        # Source of unique sequence IDs (used for new and forked sequences).
        self.seq_counter = Counter()

        # Create the parallel GPU workers.
        if self.parallel_config.worker_use_ray:
            self._init_workers_ray(placement_group)
        else:
            self._init_workers(distributed_init_method)

        # Profile the memory usage and initialize the cache.
        self._init_cache()

        # Create the scheduler.
        self.scheduler = Scheduler(scheduler_config, cache_config)

        # Logging.
        self.last_logging_time = 0.0
        # List of (timestamp, num_tokens)
        self.num_prompt_tokens: List[Tuple[float, int]] = []
        # List of (timestamp, num_tokens)
        self.num_generation_tokens: List[Tuple[float, int]] = []

    def _init_workers(self, distributed_init_method: str):
        """Creates a single in-process worker (no Ray) and loads the model.

        Args:
            distributed_init_method: Forwarded to the `Worker` constructor.

        Raises:
            AssertionError: If `parallel_config.world_size` is not 1; more
                than one worker requires Ray.
        """
        # Lazy import the Worker to avoid importing torch.cuda/xformers
        # before CUDA_VISIBLE_DEVICES is set in the Worker
        from vllm.worker.worker import Worker  # pylint: disable=import-outside-toplevel

        assert self.parallel_config.world_size == 1, (
            "Ray is required if parallel_config.world_size > 1.")

        self.workers: List[Worker] = []
        worker = Worker(
            self.model_config,
            self.parallel_config,
            self.scheduler_config,
            0,
            distributed_init_method,
        )
        self.workers.append(worker)
        self._run_workers(
            "init_model",
            get_all_outputs=True,
        )

    def _init_workers_ray(self, placement_group: "PlacementGroup",
                          **ray_remote_kwargs):
        """Starts one Ray `RayWorker` actor per GPU bundle and loads the model.

        Args:
            placement_group: Ray placement group. One actor is started for
                every bundle that reserves a GPU.
            **ray_remote_kwargs: Extra options forwarded to `ray.remote`.
        """
        # Lazy import the Worker to avoid importing torch.cuda/xformers
        # before CUDA_VISIBLE_DEVICES is set in the Worker
        from vllm.worker.worker import Worker  # pylint: disable=import-outside-toplevel

        # NOTE: on this path the list holds Ray actor handles wrapping
        # `RayWorker`, not `Worker` instances; the annotation is nominal.
        self.workers: List[Worker] = []
        for bundle in placement_group.bundle_specs:
            # Skip bundles that do not reserve a GPU.
            if not bundle.get("GPU", 0):
                continue
            worker = ray.remote(
                num_cpus=0,
                num_gpus=1,
                scheduling_strategy=PlacementGroupSchedulingStrategy(
                    placement_group=placement_group,
                    placement_group_capture_child_tasks=True),
                **ray_remote_kwargs,
            )(RayWorker).remote(self.model_config.trust_remote_code)
            self.workers.append(worker)

        # Initialize torch distributed process group for the workers.
        init_torch_dist_process_group(self.workers, backend="nccl")
        # Bind config copies to locals so the lambda below does not close
        # over `self` (presumably so the whole engine is not serialized when
        # the init function is shipped to the actors).
        model_config = copy.deepcopy(self.model_config)
        parallel_config = copy.deepcopy(self.parallel_config)
        scheduler_config = copy.deepcopy(self.scheduler_config)
        self._run_workers("init_worker",
                          get_all_outputs=True,
                          worker_init_fn=lambda: Worker(
                              model_config,
                              parallel_config,
                              scheduler_config,
                              None,
                              None,
                          ))
        self._run_workers(
            "init_model",
            get_all_outputs=True,
        )

    def _verify_args(self) -> None:
        """Cross-checks the model and cache configs against the parallel config."""
        self.model_config.verify_with_parallel_config(self.parallel_config)
        self.cache_config.verify_with_parallel_config(self.parallel_config)

    def _init_cache(self) -> None:
        """Profiles the memory usage and initializes the KV cache.

        The GPU/CPU block counts are stored on `self.cache_config`, and every
        worker then allocates its cache engine from that config.

        Raises:
            ValueError: If no GPU blocks are available for the cache.
        """
        # Get the maximum number of blocks that can be allocated on GPU and CPU.
        num_blocks = self._run_workers(
            "profile_num_available_blocks",
            get_all_outputs=True,
            block_size=self.cache_config.block_size,
            gpu_memory_utilization=self.cache_config.gpu_memory_utilization,
            cpu_swap_space=self.cache_config.swap_space_bytes,
        )

        # Since we use a shared centralized controller, we take the minimum
        # number of blocks across all workers to make sure all the memory
        # operators can be applied to all workers.
        num_gpu_blocks = min(b[0] for b in num_blocks)
        num_cpu_blocks = min(b[1] for b in num_blocks)
        # FIXME(woosuk): Change to debug log.
        logger.info(f"# GPU blocks: {num_gpu_blocks}, "
                    f"# CPU blocks: {num_cpu_blocks}")

        if num_gpu_blocks <= 0:
            raise ValueError("No available memory for the cache blocks. "
                             "Try increasing `gpu_memory_utilization` when "
                             "initializing the engine.")

        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

        # Initialize the cache.
        self._run_workers("init_cache_engine", cache_config=self.cache_config)

    @classmethod
    def from_engine_args(cls, engine_args: EngineArgs) -> "LLMEngine":
        """Creates an LLM engine from the engine arguments."""
        # Create the engine configs.
        engine_configs = engine_args.create_engine_configs()
        parallel_config = engine_configs[2]
        # Initialize the cluster.
        distributed_init_method, placement_group = initialize_cluster(
            parallel_config)
        # Create the LLM engine.
        engine = cls(*engine_configs,
                     distributed_init_method,
                     placement_group,
                     log_stats=not engine_args.disable_log_stats)
        return engine

    def add_request(
        self,
        request_id: str,
        prompt: Optional[str],
        sampling_params: SamplingParams,
        prompt_token_ids: Optional[List[int]] = None,
        arrival_time: Optional[float] = None,
    ) -> None:
        """Add a request to the engine's request pool.

        The request is added to the request pool and will be processed by the
        scheduler as `engine.step()` is called. The exact scheduling policy is
        determined by the scheduler.

        Args:
            request_id: The unique ID of the request.
            prompt: The prompt string. Can be None if prompt_token_ids is
                provided.
            sampling_params: The sampling parameters for text generation.
            prompt_token_ids: The token IDs of the prompt. If None, we
                use the tokenizer to convert the prompts to token IDs.
            arrival_time: The arrival time of the request. If None, we use
                the current time.
        """
        if arrival_time is None:
            arrival_time = time.time()
        if prompt_token_ids is None:
            assert prompt is not None
            prompt_token_ids = self.tokenizer.encode(prompt)

        # Create the sequences.
        block_size = self.cache_config.block_size
        seq_id = next(self.seq_counter)
        seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)

        # Create the sequence group.
        seq_group = SequenceGroup(request_id, [seq], sampling_params,
                                  arrival_time)

        # Add the sequence group to the scheduler.
        self.scheduler.add_seq_group(seq_group)

    def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
        """Aborts a request(s) with the given ID.

        Args:
            request_id: The ID(s) of the request to abort.
        """
        self.scheduler.abort_seq_group(request_id)

    def get_model_config(self) -> ModelConfig:
        """Gets the model configuration."""
        return self.model_config

    def get_num_unfinished_requests(self) -> int:
        """Gets the number of unfinished requests."""
        return self.scheduler.get_num_unfinished_seq_groups()

    def has_unfinished_requests(self) -> bool:
        """Returns True if there are unfinished requests."""
        return self.scheduler.has_unfinished_seqs()

    def _schedule(
        self
    ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs,
               List[RequestOutput]]:
        """Runs the scheduler for one iteration.

        Returns:
            A tuple of (sequence-group metadata for the model, the scheduler
            outputs, and `RequestOutput`s for the groups the scheduler put in
            `ignored_seq_groups`).
        """
        seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
        return seq_group_metadata_list, scheduler_outputs, [
            RequestOutput.from_seq_group(seq_group)
            for seq_group in scheduler_outputs.ignored_seq_groups
        ]

    def _check_beam_search_early_stopping(
        self,
        early_stopping: Union[bool, str],
        sampling_params: SamplingParams,
        best_running_seq: Sequence,
        current_worst_seq: Sequence,
    ) -> bool:
        """Decides whether beam search can stop for a sequence group.

        Args:
            early_stopping: `True`, `False`, or `"never"` (any other value
                trips the assertion below).
            sampling_params: Beam-search sampling parameters.
            best_running_seq: Highest-scoring still-running candidate.
            current_worst_seq: Lowest-scoring finished sequence among the
                top `beam_width` finished ones.

        Returns:
            True when `early_stopping` is True, or when the worst kept
            finished score is at least the estimated highest score the best
            running sequence can still attain.
        """
        assert sampling_params.use_beam_search
        length_penalty = sampling_params.length_penalty
        if early_stopping is True:
            return True

        current_worst_score = (current_worst_seq.get_beam_search_score(
            length_penalty=length_penalty,
            eos_token_id=self.tokenizer.eos_token_id))
        if early_stopping is False:
            highest_attainable_score = (best_running_seq.get_beam_search_score(
                length_penalty=length_penalty,
                eos_token_id=self.tokenizer.eos_token_id))
        else:
            assert early_stopping == "never"
            if length_penalty > 0.0:
                # If length_penalty > 0.0, beam search will prefer longer
                # sequences. The highest attainable score calculation is
                # based on the longest possible sequence length in this case.
                # NOTE(review): a sequence is bounded by BOTH terms, so `min`
                # would be the tight bound; `max` over-estimates the attainable
                # score, which only makes stopping more conservative. Confirm
                # intent before changing (it would alter beam-search outputs).
                max_possible_length = max(
                    best_running_seq.get_prompt_len() +
                    sampling_params.max_tokens,
                    self.scheduler_config.max_model_len)
                highest_attainable_score = (
                    best_running_seq.get_beam_search_score(
                        length_penalty=length_penalty,
                        eos_token_id=self.tokenizer.eos_token_id,
                        seq_len=max_possible_length))
            else:
                # Otherwise, beam search will prefer shorter sequences. The
                # highest attainable score calculation is based on the current
                # sequence length.
                highest_attainable_score = (
                    best_running_seq.get_beam_search_score(
                        length_penalty=length_penalty,
                        eos_token_id=self.tokenizer.eos_token_id))
        return current_worst_score >= highest_attainable_score

    def _process_sequence_group_samples(
            self, seq_group: SequenceGroup,
            samples: List[SequenceOutputs]) -> None:
        """Applies one step of sampled tokens to a sequence group.

        Each sample is attached to its parent running sequence. Extra samples
        fork new child sequences. New tokens are detokenized and checked for
        stop conditions. Block-manager state is then updated. With beam
        search, only the top `best_of` candidates (finished or running) are
        kept and the rest are removed and freed.

        Args:
            seq_group: The scheduled sequence group.
            samples: Sampler outputs for this group; each carries the
                `parent_seq_id` of the running sequence it extends.
        """
        parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
        existing_finished_seqs = seq_group.get_finished_seqs()
        parent_child_dict = {
            parent_seq.seq_id: []
            for parent_seq in parent_seqs
        }
        for sample in samples:
            parent_child_dict[sample.parent_seq_id].append(sample)
        # List of (child, parent)
        child_seqs: List[Tuple[Sequence, Sequence]] = []

        # Process the child samples for each parent sequence
        for parent in parent_seqs:
            child_samples: List[SequenceOutputs] = parent_child_dict[
                parent.seq_id]
            if len(child_samples) == 0:
                # This parent sequence has no children samples. Remove
                # the parent sequence from the sequence group since it will
                # not be used in the future iterations.
                parent.status = SequenceStatus.FINISHED_ABORTED
                seq_group.remove(parent.seq_id)
                self.scheduler.free_seq(parent)
                continue
            # Fork the parent sequence if there are multiple child samples.
            for child_sample in child_samples[:-1]:
                new_child_seq_id = next(self.seq_counter)
                child = parent.fork(new_child_seq_id)
                child.append_token_id(child_sample.output_token,
                                      child_sample.logprobs)
                child_seqs.append((child, parent))
            # Continue the parent sequence for the last child sample.
            # We reuse the parent sequence here to reduce redundant memory
            # copies, especially when using non-beam search sampling methods.
            last_child_sample = child_samples[-1]
            parent.append_token_id(last_child_sample.output_token,
                                   last_child_sample.logprobs)
            child_seqs.append((parent, parent))

        for seq, _ in child_seqs:
            self._decode_sequence(seq, seq_group.sampling_params)
            self._check_stop(seq, seq_group.sampling_params)

        # Non-beam search case
        if not seq_group.sampling_params.use_beam_search:
            # For newly created child sequences, add them to the sequence group
            # and fork them in block manager if they are not finished.
            for seq, parent in child_seqs:
                if seq is not parent:
                    seq_group.add(seq)
                    if not seq.is_finished():
                        self.scheduler.fork_seq(parent, seq)

            # Free the finished and selected parent sequences' memory in block
            # manager. Keep them in the sequence group as candidate output.
            # NOTE: we need to fork the new sequences before freeing the
            # old sequences.
            for seq, parent in child_seqs:
                if seq is parent and seq.is_finished():
                    self.scheduler.free_seq(seq)
            return

        # Beam search case
        # Select the child sequences to keep in the sequence group.
        selected_child_seqs = []
        unselected_child_seqs = []
        beam_width = seq_group.sampling_params.best_of
        length_penalty = seq_group.sampling_params.length_penalty

        # Select the newly finished sequences with the highest scores
        # to replace existing finished sequences.
        # Tuple of (seq, parent, is_new)
        existing_finished_entries = [(seq, None, False)
                                     for seq in existing_finished_seqs]
        new_finished_entries = [(seq, parent, True)
                                for seq, parent in child_seqs
                                if seq.is_finished()]
        all_finished_seqs = existing_finished_entries + new_finished_entries
        # Sort the finished sequences by their scores.
        all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score(
            length_penalty=length_penalty,
            eos_token_id=self.tokenizer.eos_token_id),
                               reverse=True)
        for seq, parent, is_new in all_finished_seqs[:beam_width]:
            if is_new:
                # A newly generated child sequence finishes and has a high
                # score, so we will add it into the sequence group.
                selected_child_seqs.append((seq, parent))
        for seq, parent, is_new in all_finished_seqs[beam_width:]:
            if is_new:
                # A newly generated child sequence finishes but has a low
                # score, so we will not add it into the sequence group.
                # Additionally, if this sequence is a continuation of a
                # parent sequence, we will need remove the parent sequence
                # from the sequence group.
                unselected_child_seqs.append((seq, parent))
            else:
                # An existing finished sequence has a low score, so we will
                # remove it from the sequence group.
                seq_group.remove(seq.seq_id)

        # select the top beam_width sequences from the running
        # sequences for the next iteration to continue the beam
        # search.
        running_child_seqs = [(seq, parent) for seq, parent in child_seqs
                              if not seq.is_finished()]
        # Sort the running sequences by their scores.
        running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score(
            length_penalty=length_penalty,
            eos_token_id=self.tokenizer.eos_token_id),
                                reverse=True)

        # Check if we can stop the beam search.
        if len(running_child_seqs) == 0:
            # No running sequences, stop the beam search.
            stop_beam_search = True
        elif len(all_finished_seqs) < beam_width:
            # Not enough finished sequences, continue the beam search.
            stop_beam_search = False
        else:
            # Check the early stopping criteria
            best_running_seq = running_child_seqs[0][0]
            current_worst_seq = all_finished_seqs[beam_width - 1][0]
            stop_beam_search = self._check_beam_search_early_stopping(
                seq_group.sampling_params.early_stopping,
                seq_group.sampling_params, best_running_seq, current_worst_seq)

        if stop_beam_search:
            # Stop the beam search and remove all the running sequences from
            # the sequence group.
            unselected_child_seqs.extend(running_child_seqs)
        else:
            # Continue the beam search and select the top beam_width sequences
            # to continue the beam search.
            selected_child_seqs.extend(running_child_seqs[:beam_width])
            # The remaining running sequences will not be used in the next
            # iteration. Again, if these sequences are continuations of
            # parent sequences, we will need to remove the parent sequences
            # from the sequence group.
            unselected_child_seqs.extend(running_child_seqs[beam_width:])

        # For newly created child sequences, add them to the sequence group
        # and fork them in block manager if they are not finished.
        for seq, parent in selected_child_seqs:
            if seq is not parent:
                seq_group.add(seq)
                if not seq.is_finished():
                    self.scheduler.fork_seq(parent, seq)

        # Free the finished and selected parent sequences' memory in block
        # manager. Keep them in the sequence group as candidate output.
        for seq, parent in selected_child_seqs:
            if seq is parent and seq.is_finished():
                self.scheduler.free_seq(seq)

        # Remove the unselected parent sequences from the sequence group and
        # free their memory in block manager.
        for seq, parent in unselected_child_seqs:
            if seq is parent:
                # Remove the parent sequence if it is not selected for next
                # iteration
                seq_group.remove(seq.seq_id)
                self.scheduler.free_seq(seq)

    def _process_model_outputs(
            self, output: SamplerOutput,
            scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]:
        """Applies model outputs and builds the per-request results.

        Returns:
            One `RequestOutput` per scheduled sequence group, followed by one
            per group in `scheduler_outputs.ignored_seq_groups`.
        """
        # Update the scheduled sequence groups with the model outputs.
        scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups
        for seq_group, samples in zip(scheduled_seq_groups, output):
            self._process_sequence_group_samples(seq_group, samples)

        # Free the finished sequence groups.
        self.scheduler.free_finished_seq_groups()

        # Create the outputs.
        request_outputs: List[RequestOutput] = []
        for seq_group in (scheduled_seq_groups +
                          scheduler_outputs.ignored_seq_groups):
            request_output = RequestOutput.from_seq_group(seq_group)
            request_outputs.append(request_output)

        if self.log_stats:
            # Log the system stats.
            self._log_system_stats(scheduler_outputs.prompt_run,
                                   scheduler_outputs.num_batched_tokens)
        return request_outputs

    def step(self) -> List[RequestOutput]:
        """Performs one decoding iteration and returns newly generated results.

        This function performs one decoding iteration of the engine. It first
        schedules the sequences to be executed in the next iteration and the
        token blocks to be swapped in/out/copy. Then, it executes the model
        and updates the scheduler with the model outputs. Finally, it decodes
        the sequences and returns the newly generated results.

        Each ignored request (`scheduler_outputs.ignored_seq_groups`) appears
        exactly once in the returned list.
        """
        seq_group_metadata_list, scheduler_outputs, ignored = self._schedule()
        if scheduler_outputs.is_empty():
            return ignored

        # Execute the model.
        output = self._run_workers(
            "execute_model",
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
            blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
            blocks_to_copy=scheduler_outputs.blocks_to_copy,
        )

        # `_process_model_outputs` already emits outputs for the ignored
        # groups; appending `ignored` here again would return every ignored
        # request twice.
        return self._process_model_outputs(output, scheduler_outputs)

    def _log_system_stats(
        self,
        prompt_run: bool,
        num_batched_tokens: int,
    ) -> None:
        """Records this step's batch size and periodically logs engine stats.

        A log line is emitted at most once every `_LOGGING_INTERVAL_SEC`
        seconds. Throughput is averaged over entries recorded within that
        window; cache usage is read from the scheduler's block manager.

        Args:
            prompt_run: True if this step processed prompts, False if it
                generated tokens.
            num_batched_tokens: Number of tokens in this step's batch.
        """
        now = time.time()
        # Log the number of batched input tokens.
        if prompt_run:
            self.num_prompt_tokens.append((now, num_batched_tokens))
        else:
            self.num_generation_tokens.append((now, num_batched_tokens))

        elapsed_time = now - self.last_logging_time
        if elapsed_time < _LOGGING_INTERVAL_SEC:
            return

        # Discard the old stats.
        self.num_prompt_tokens = [(t, n) for t, n in self.num_prompt_tokens
                                  if now - t < _LOGGING_INTERVAL_SEC]
        self.num_generation_tokens = [(t, n)
                                      for t, n in self.num_generation_tokens
                                      if now - t < _LOGGING_INTERVAL_SEC]

        if len(self.num_prompt_tokens) > 1:
            total_num_tokens = sum(n for _, n in self.num_prompt_tokens[:-1])
            window = now - self.num_prompt_tokens[0][0]
            avg_prompt_throughput = total_num_tokens / window
        else:
            avg_prompt_throughput = 0.0
        if len(self.num_generation_tokens) > 1:
            total_num_tokens = sum(n
                                   for _, n in self.num_generation_tokens[:-1])
            window = now - self.num_generation_tokens[0][0]
            avg_generation_throughput = total_num_tokens / window
        else:
            avg_generation_throughput = 0.0

        # `_init_cache` guarantees num_gpu_blocks > 0, so this cannot divide
        # by zero.
        total_num_gpu_blocks = self.cache_config.num_gpu_blocks
        num_free_gpu_blocks = (
            self.scheduler.block_manager.get_num_free_gpu_blocks())
        num_used_gpu_blocks = total_num_gpu_blocks - num_free_gpu_blocks
        gpu_cache_usage = num_used_gpu_blocks / total_num_gpu_blocks

        total_num_cpu_blocks = self.cache_config.num_cpu_blocks
        if total_num_cpu_blocks > 0:
            num_free_cpu_blocks = (
                self.scheduler.block_manager.get_num_free_cpu_blocks())
            num_used_cpu_blocks = total_num_cpu_blocks - num_free_cpu_blocks
            cpu_cache_usage = num_used_cpu_blocks / total_num_cpu_blocks
        else:
            cpu_cache_usage = 0.0

        logger.info("Avg prompt throughput: "
                    f"{avg_prompt_throughput:.1f} tokens/s, "
                    "Avg generation throughput: "
                    f"{avg_generation_throughput:.1f} tokens/s, "
                    f"Running: {len(self.scheduler.running)} reqs, "
                    f"Swapped: {len(self.scheduler.swapped)} reqs, "
                    f"Pending: {len(self.scheduler.waiting)} reqs, "
                    f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, "
                    f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
        self.last_logging_time = now

    def _decode_sequence(self, seq: Sequence,
                         sampling_params: SamplingParams) -> None:
        """Decodes the new token for a sequence.

        Updates the sequence's cached tokens, detokenization offsets, and
        appends the newly produced text to `seq.output_text`.
        """
        (new_tokens, new_output_text, prefix_offset,
         read_offset) = detokenize_incrementally(
             self.tokenizer,
             all_input_ids=seq.get_token_ids(),
             prev_tokens=seq.tokens,
             prefix_offset=seq.prefix_offset,
             read_offset=seq.read_offset,
             skip_special_tokens=sampling_params.skip_special_tokens,
         )
        if seq.tokens is None:
            seq.tokens = new_tokens
        else:
            seq.tokens.extend(new_tokens)
        seq.prefix_offset = prefix_offset
        seq.read_offset = read_offset
        seq.output_text += new_output_text

    def _check_stop(self, seq: Sequence,
                    sampling_params: SamplingParams) -> None:
        """Marks a sequence as finished if any stop condition is met.

        Conditions are checked in order: stop string (which is also stripped
        from the output text), stop token id, `max_model_len`, `max_tokens`,
        and finally EOS (unless `ignore_eos`).
        """
        for stop_str in sampling_params.stop:
            if seq.output_text.endswith(stop_str):
                # Truncate the output text so that the stop string is
                # not included in the output.
                seq.output_text = seq.output_text[:-len(stop_str)]
                seq.status = SequenceStatus.FINISHED_STOPPED
                return
        if seq.get_last_token_id() in sampling_params.stop_token_ids:
            seq.status = SequenceStatus.FINISHED_STOPPED
            return

        # Check if the sequence has reached max_model_len.
        if seq.get_len() > self.scheduler_config.max_model_len:
            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
            return

        # Check if the sequence has reached max_tokens. `>=` (rather than
        # `==`) still caps a sequence that has somehow overshot the limit.
        if seq.get_output_len() >= sampling_params.max_tokens:
            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
            return

        # Check if the sequence has generated the EOS token.
        if ((not sampling_params.ignore_eos)
                and seq.get_last_token_id() == self.tokenizer.eos_token_id):
            seq.status = SequenceStatus.FINISHED_STOPPED
            return

    def _run_workers(
        self,
        method: str,
        *args,
        get_all_outputs: bool = False,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers.

        With Ray, the method is invoked via `execute_method.remote` on every
        actor and the results are collected with `ray.get`; otherwise it is
        called directly on each local worker.

        Args:
            method: Name of the worker method to call.
            *args: Positional arguments forwarded to the method.
            get_all_outputs: If True, return every worker's result as a list.
            **kwargs: Keyword arguments forwarded to the method.

        Returns:
            The list of all results if `get_all_outputs`, else the first
            worker's result (asserted equal to every other worker's).
        """
        all_outputs = []
        for worker in self.workers:
            if self.parallel_config.worker_use_ray:
                executor = partial(worker.execute_method.remote, method)
            else:
                executor = getattr(worker, method)

            output = executor(*args, **kwargs)
            all_outputs.append(output)

        if self.parallel_config.worker_use_ray:
            all_outputs = ray.get(all_outputs)

        if get_all_outputs:
            return all_outputs

        # Make sure all workers have the same results.
        output = all_outputs[0]
        for other_output in all_outputs[1:]:
            assert output == other_output
        return output