llm_engine.py 34.6 KB
Newer Older
Fang li's avatar
Fang li committed
1
import copy
2
from collections import defaultdict
3
import os
Antoni Baum's avatar
Antoni Baum committed
4
import time
5
6
from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple,
                    Union)
7

Woosuk Kwon's avatar
Woosuk Kwon committed
8
9
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                         SchedulerConfig)
Antoni Baum's avatar
Antoni Baum committed
10
from vllm.core.scheduler import Scheduler, SchedulerOutputs
Woosuk Kwon's avatar
Woosuk Kwon committed
11
from vllm.engine.arg_utils import EngineArgs
12
from vllm.engine.metrics import record_metrics
13
from vllm.engine.ray_utils import RayWorkerVllm, initialize_cluster, ray
Woosuk Kwon's avatar
Woosuk Kwon committed
14
15
16
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
17
from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
18
                           SequenceGroupOutput, SequenceOutput, SequenceStatus)
19
20
from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
                                               get_tokenizer)
21
from vllm.utils import Counter, set_cuda_visible_devices, get_ip, get_open_port
22
23
24
25
26
27

if ray:
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup
28
29
30

logger = init_logger(__name__)

# Minimum number of seconds between two stats log lines; it is also the
# sliding window over which `LLMEngine._log_system_stats` keeps
# (timestamp, num_tokens) samples.
_LOGGING_INTERVAL_SEC = 5

33

34
class LLMEngine:
    """An LLM engine that receives requests and generates texts.

    This is the main class for the vLLM engine. It receives requests
    from clients and generates texts from the LLM. It includes a tokenizer, a
    language model (possibly distributed across multiple GPUs), and GPU memory
    space allocated for intermediate states (aka KV cache). This class utilizes
    iteration-level scheduling and efficient memory management to maximize the
    serving throughput.

    The `LLM` class wraps this class for offline batched inference and the
    `AsyncLLMEngine` class wraps this class for online serving.

    NOTE: The config arguments are derived from the `EngineArgs` class. For the
    comprehensive list of arguments, see `EngineArgs`.

    Args:
        model_config: The configuration related to the LLM model.
        cache_config: The configuration related to the KV cache memory
            management.
        parallel_config: The configuration related to distributed execution.
        scheduler_config: The configuration related to the request scheduler.
        placement_group: Ray placement group for distributed execution.
            Required for distributed execution.
        log_stats: Whether to log statistics.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        placement_group: Optional["PlacementGroup"],
        log_stats: bool,
    ) -> None:
        """Build the tokenizer, the workers, the KV cache and the scheduler.

        Args:
            model_config: The configuration related to the LLM model.
            cache_config: The configuration related to the KV cache memory
                management.
            parallel_config: The configuration related to distributed
                execution.
            scheduler_config: The configuration related to the request
                scheduler.
            placement_group: Ray placement group; only passed on when
                ``parallel_config.worker_use_ray`` is set.
            log_stats: Whether `_process_model_outputs` records and logs
                system statistics.
        """
        logger.info(
            "Initializing an LLM engine with config: "
            f"model={model_config.model!r}, "
            f"tokenizer={model_config.tokenizer!r}, "
            f"tokenizer_mode={model_config.tokenizer_mode}, "
            f"revision={model_config.revision}, "
            f"tokenizer_revision={model_config.tokenizer_revision}, "
            f"trust_remote_code={model_config.trust_remote_code}, "
            f"dtype={model_config.dtype}, "
            f"max_seq_len={model_config.max_model_len}, "
            f"download_dir={model_config.download_dir!r}, "
            f"load_format={model_config.load_format}, "
            f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
            f"quantization={model_config.quantization}, "
            f"enforce_eager={model_config.enforce_eager}, "
            f"seed={model_config.seed})")
        # TODO(woosuk): Print more configs in debug mode.

        # The configs must be stored before `_verify_args`, which reads them.
        self.model_config = model_config
        self.cache_config = cache_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.log_stats = log_stats
        self._verify_args()

        tokenizer_kwargs = dict(
            tokenizer_mode=model_config.tokenizer_mode,
            trust_remote_code=model_config.trust_remote_code,
            tokenizer_revision=model_config.tokenizer_revision,
            revision=model_config.revision,
        )
        self.tokenizer = get_tokenizer(model_config.tokenizer,
                                       **tokenizer_kwargs)
        self.seq_counter = Counter()

        # Create the GPU workers: a single local worker, or Ray workers.
        if not self.parallel_config.worker_use_ray:
            self._init_workers()
        else:
            # Disable Ray usage stats collection unless the user opted in.
            if os.environ.get("RAY_USAGE_STATS_ENABLED", "0") != "1":
                os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
            self._init_workers_ray(placement_group)

        # `_init_cache` profiles memory and fills in the block counts on
        # `cache_config`; the scheduler is constructed after it.
        self._init_cache()
        self.scheduler = Scheduler(scheduler_config, cache_config)

        # Stats-logging state. Each list holds (timestamp, num_tokens).
        self.last_logging_time = 0.0
        self.num_prompt_tokens: List[Tuple[float, int]] = []
        self.num_generation_tokens: List[Tuple[float, int]] = []
125

126
    def _init_workers(self):
        """Create the single in-process driver worker (no Ray).

        Only valid for ``parallel_config.world_size == 1``. The worker
        initializes and loads the model before returning.
        """
        # Lazy import the Worker to avoid importing torch.cuda/xformers
        # before CUDA_VISIBLE_DEVICES is set in the Worker
        from vllm.worker.worker import Worker

        assert self.parallel_config.world_size == 1, (
            "Ray is required if parallel_config.world_size > 1.")

        # No Ray workers exist in this mode; only the driver worker does.
        self.workers: List[Worker] = []
        init_method = f"tcp://{get_ip()}:{get_open_port()}"
        self.driver_worker = Worker(
            self.model_config,
            self.parallel_config,
            self.scheduler_config,
            local_rank=0,
            rank=0,
            distributed_init_method=init_method,
            is_driver_worker=True,
        )
        for method_name in ("init_model", "load_model"):
            self._run_workers(method_name)
147

Antoni Baum's avatar
Antoni Baum committed
148
149
    def _init_workers_ray(self, placement_group: "PlacementGroup",
                          **ray_remote_kwargs):
        """Start one Ray worker per GPU bundle and initialize the model.

        A `RayWorkerVllm` actor is created for every bundle of
        ``placement_group`` that requests a GPU. The first actor that lands
        on the driver's node becomes ``driver_dummy_worker``, used as the
        resource holder for the driver process. The driver itself runs
        rank 0 through an in-process `Worker`. All other actors are stored
        in ``self.workers`` and get ranks 1..N.

        Args:
            placement_group: Ray placement group whose bundles host the
                workers. Bundles without a "GPU" entry are skipped.
            **ray_remote_kwargs: Extra options forwarded to `ray.remote`.

        Raises:
            ValueError: If no GPU bundle was placed on the driver's node.
        """
        if self.parallel_config.tensor_parallel_size == 1:
            # NOTE(review): presumably reserves only the fraction of the GPU
            # the engine will use; confirm intent.
            num_gpus = self.cache_config.gpu_memory_utilization
        else:
            num_gpus = 1

        # None until an actor on the driver's node is found.
        self.driver_dummy_worker: Optional[RayWorkerVllm] = None
        self.workers: List[RayWorkerVllm] = []

        # Launch one actor per GPU bundle.
        gpu_actors: List[RayWorkerVllm] = []
        for bundle_id, bundle in enumerate(placement_group.bundle_specs):
            if not bundle.get("GPU", 0):
                continue
            scheduling_strategy = PlacementGroupSchedulingStrategy(
                placement_group=placement_group,
                placement_group_capture_child_tasks=True,
                placement_group_bundle_index=bundle_id,
            )
            gpu_actors.append(
                ray.remote(
                    num_cpus=0,
                    num_gpus=num_gpus,
                    scheduling_strategy=scheduling_strategy,
                    **ray_remote_kwargs,
                )(RayWorkerVllm).remote(self.model_config.trust_remote_code))

        # Resolve every actor's node IP in a single round trip, instead of
        # blocking on each actor in turn while the others are not launched.
        driver_ip = get_ip()
        actor_ips = ray.get(
            [actor.get_node_ip.remote() for actor in gpu_actors])
        for actor, actor_ip in zip(gpu_actors, actor_ips):
            if actor_ip == driver_ip and self.driver_dummy_worker is None:
                # If the worker is on the same node as the driver, we use it
                # as the resource holder for the driver process.
                self.driver_dummy_worker = actor
            else:
                self.workers.append(actor)

        if self.driver_dummy_worker is None:
            raise ValueError(
                "Ray does not allocate any GPUs on the driver node. Consider "
                "adjusting the Ray placement group or running the driver on a "
                "GPU node.")

        driver_node_id, driver_gpu_ids = ray.get(
            self.driver_dummy_worker.get_node_and_gpu_ids.remote())
        worker_node_and_gpu_ids = ray.get(
            [worker.get_node_and_gpu_ids.remote() for worker in self.workers])

        # Group global ranks and GPU IDs by node. Rank 0 is the driver.
        node_workers = defaultdict(list)
        node_gpus = defaultdict(list)

        node_workers[driver_node_id].append(0)
        node_gpus[driver_node_id].extend(driver_gpu_ids)
        for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids,
                                               start=1):
            node_workers[node_id].append(i)
            node_gpus[node_id].extend(gpu_ids)
        for node_id, gpu_ids in node_gpus.items():
            node_gpus[node_id] = sorted(gpu_ids)

        # Each process sees all GPUs gathered for its node.
        set_cuda_visible_devices(node_gpus[driver_node_id])
        for worker, (node_id, _) in zip(self.workers, worker_node_and_gpu_ids):
            worker.set_cuda_visible_devices.remote(node_gpus[node_id])

        distributed_init_method = f"tcp://{driver_ip}:{get_open_port()}"

        # Lazy import the Worker to avoid importing torch.cuda/xformers
        # before CUDA_VISIBLE_DEVICES is set in the Worker
        from vllm.worker.worker import Worker

        # Initialize torch distributed process group for the workers.
        model_config = copy.deepcopy(self.model_config)
        parallel_config = copy.deepcopy(self.parallel_config)
        scheduler_config = copy.deepcopy(self.scheduler_config)

        for rank, (worker, (node_id, _)) in enumerate(
                zip(self.workers, worker_node_and_gpu_ids), start=1):
            # Local rank = position of this global rank among its node's ranks.
            local_rank = node_workers[node_id].index(rank)
            # Default arguments bind this iteration's ranks into the lambda.
            worker.init_worker.remote(
                lambda rank=rank, local_rank=local_rank: Worker(
                    model_config,
                    parallel_config,
                    scheduler_config,
                    local_rank,
                    rank,
                    distributed_init_method,
                ))

        driver_rank = 0
        driver_local_rank = node_workers[driver_node_id].index(driver_rank)
        self.driver_worker = Worker(
            model_config,
            parallel_config,
            scheduler_config,
            driver_local_rank,
            driver_rank,
            distributed_init_method,
            is_driver_worker=True,
        )

        self._run_workers("init_model")
        self._run_workers(
            "load_model",
            max_concurrent_workers=self.parallel_config.
            max_parallel_loading_workers,
        )
254

255
256
    def _verify_args(self) -> None:
        """Validate the model config, then the cache config, against the
        parallel config."""
        parallel_config = self.parallel_config
        for config in (self.model_config, self.cache_config):
            config.verify_with_parallel_config(parallel_config)
258
259

    def _init_cache(self) -> None:
        """Profiles the memory usage and initializes the KV cache.

        Each worker reports how many KV-cache blocks fit on GPU and CPU. The
        smallest counts are stored on ``self.cache_config``. The workers then
        create their cache engines and warm up the model.

        Raises:
            ValueError: If no GPU block is available, or if
                ``max_model_len`` tokens do not fit in the GPU cache.
        """
        cache_config = self.cache_config

        # One (num_gpu_blocks, num_cpu_blocks) entry per worker.
        per_worker_blocks = self._run_workers(
            "profile_num_available_blocks",
            block_size=cache_config.block_size,
            gpu_memory_utilization=cache_config.gpu_memory_utilization,
            cpu_swap_space=cache_config.swap_space_bytes,
        )

        # Since we use a shared centralized controller, we take the minimum
        # number of blocks across all workers to make sure all the memory
        # operators can be applied to all workers.
        num_gpu_blocks = min(blocks[0] for blocks in per_worker_blocks)
        num_cpu_blocks = min(blocks[1] for blocks in per_worker_blocks)
        # FIXME(woosuk): Change to debug log.
        logger.info(f"# GPU blocks: {num_gpu_blocks}, "
                    f"# CPU blocks: {num_cpu_blocks}")

        if num_gpu_blocks <= 0:
            raise ValueError("No available memory for the cache blocks. "
                             "Try increasing `gpu_memory_utilization` when "
                             "initializing the engine.")

        # Total number of tokens the GPU cache can hold.
        max_seq_len = cache_config.block_size * num_gpu_blocks
        if self.model_config.max_model_len > max_seq_len:
            raise ValueError(
                f"The model's max seq len ({self.model_config.max_model_len}) "
                "is larger than the maximum number of tokens that can be "
                f"stored in KV cache ({max_seq_len}). Try increasing "
                "`gpu_memory_utilization` or decreasing `max_model_len` when "
                "initializing the engine.")

        cache_config.num_gpu_blocks = num_gpu_blocks
        cache_config.num_cpu_blocks = num_cpu_blocks

        # Allocate the cache on every worker.
        self._run_workers("init_cache_engine", cache_config=cache_config)
        # Warm up the model. This includes capturing the model into CUDA graph
        # if enforce_eager is False.
        self._run_workers("warm_up_model")
299

300
    @classmethod
    def from_engine_args(cls, engine_args: EngineArgs) -> "LLMEngine":
        """Creates an LLM engine from the engine arguments.

        The configs from ``engine_args.create_engine_configs()`` are passed
        positionally to the constructor, followed by the placement group that
        `initialize_cluster` returns for the parallel config.
        """
        engine_configs = engine_args.create_engine_configs()
        # Position 2 holds the parallel config (constructor argument order).
        placement_group = initialize_cluster(engine_configs[2])
        return cls(*engine_configs,
                   placement_group,
                   log_stats=not engine_args.disable_log_stats)
313

314
315
316
    def add_request(
        self,
        request_id: str,
        prompt: Optional[str],
        sampling_params: SamplingParams,
        prompt_token_ids: Optional[List[int]] = None,
        arrival_time: Optional[float] = None,
    ) -> None:
        """Add a request to the engine's request pool.

        The request is wrapped in a single-sequence `SequenceGroup` and
        handed to the scheduler. It is processed on later calls to
        `engine.step()`, according to the scheduler's policy.

        Args:
            request_id: The unique ID of the request.
            prompt: The prompt string. May be None when ``prompt_token_ids``
                is given.
            sampling_params: The sampling parameters for text generation.
            prompt_token_ids: The token IDs of the prompt. When None, the
                prompt is encoded with the engine's tokenizer.
            arrival_time: The arrival time of the request. When None, the
                current `time.monotonic()` value is used.
        """
        # Timestamp first so tokenization time counts toward the request.
        if arrival_time is None:
            arrival_time = time.monotonic()
        if prompt_token_ids is None:
            assert prompt is not None
            prompt_token_ids = self.tokenizer.encode(prompt)

        seq = Sequence(next(self.seq_counter), prompt, prompt_token_ids,
                       self.cache_config.block_size)
        self.scheduler.add_seq_group(
            SequenceGroup(request_id, [seq], sampling_params, arrival_time))

Antoni Baum's avatar
Antoni Baum committed
356
357
    def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
        """Abort one or several requests.

        Args:
            request_id: A single request ID or an iterable of IDs. It is
                forwarded unchanged to `Scheduler.abort_seq_group`.
        """
        scheduler = self.scheduler
        scheduler.abort_seq_group(request_id)

364
365
366
367
    def get_model_config(self) -> ModelConfig:
        """Return the `ModelConfig` this engine was constructed with."""
        config = self.model_config
        return config

368
    def get_num_unfinished_requests(self) -> int:
        """Return the scheduler's count of unfinished sequence groups.

        `add_request` creates one sequence group per request.
        """
        num_unfinished = self.scheduler.get_num_unfinished_seq_groups()
        return num_unfinished

372
    def has_unfinished_requests(self) -> bool:
        """Return whether the scheduler still holds unfinished sequences."""
        pending = self.scheduler.has_unfinished_seqs()
        return pending

376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
    def _check_beam_search_early_stopping(
        self,
        early_stopping: Union[bool, str],
        sampling_params: SamplingParams,
        best_running_seq: Sequence,
        current_worst_seq: Sequence,
    ) -> bool:
        """Return whether a beam-search group may stop generating.

        The search stops when the worst kept finished beam scores at least as
        high as the score the best running beam is deemed able to reach.

        Args:
            early_stopping: ``True`` stops immediately. ``False`` compares
                against the running beam's current score. ``"never"``
                compares against an optimistic bound that, for a positive
                length penalty, is evaluated at a longer sequence length.
            sampling_params: Sampling parameters of the group. Must have
                ``use_beam_search`` set.
            best_running_seq: Highest-scoring unfinished beam.
            current_worst_seq: Lowest-scoring finished beam that is kept.
        """
        assert sampling_params.use_beam_search
        length_penalty = sampling_params.length_penalty
        if early_stopping is True:
            return True

        eos_token_id = self.tokenizer.eos_token_id

        def score(seq: Sequence, **extra) -> float:
            # Beam score under this group's length penalty.
            return seq.get_beam_search_score(length_penalty=length_penalty,
                                             eos_token_id=eos_token_id,
                                             **extra)

        worst_finished_score = score(current_worst_seq)

        if early_stopping is False:
            best_reachable_score = score(best_running_seq)
        else:
            assert early_stopping == "never"
            if length_penalty > 0.0:
                # A positive length penalty favours longer sequences, so the
                # bound is taken at a long sequence length.
                # NOTE(review): max() can exceed the length this sequence can
                # actually reach (it is capped by both max_tokens and
                # max_model_len); min() would give a tighter bound. Confirm
                # against the stop conditions in `_check_stop` before changing.
                longest_len = max(
                    best_running_seq.get_prompt_len() +
                    sampling_params.max_tokens,
                    self.scheduler_config.max_model_len)
                best_reachable_score = score(best_running_seq,
                                             seq_len=longest_len)
            else:
                # Otherwise shorter sequences are favoured, so the current
                # length gives the best score.
                best_reachable_score = score(best_running_seq)

        return worst_finished_score >= best_reachable_score

420
    def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
                                        outputs: SequenceGroupOutput) -> None:
        """Apply one step of sampler output to a sequence group.

        Sampled tokens are appended to the running sequences. A parent with
        several samples is forked, and the parent itself continues with its
        last sample. A running parent with no sample is aborted and removed.
        Each child is then detokenized and stop-checked, and the group and the
        scheduler's block tables are updated:

        * without beam search, every child is kept;
        * with beam search, only the best ``best_of`` finished and running
          candidates are kept, and the rest are removed and freed.

        Args:
            seq_group: The sequence group scheduled in this step.
            outputs: The sampler output for this group.
        """
        # Process prompt logprobs
        new_prompt_logprobs = outputs.prompt_logprobs
        if new_prompt_logprobs is not None:
            seq_group.prompt_logprobs = new_prompt_logprobs

        sampled = outputs.samples
        running_parents = seq_group.get_seqs(status=SequenceStatus.RUNNING)
        previously_finished = seq_group.get_finished_seqs()

        # Bucket the new samples by the sequence that produced them.
        samples_by_parent = {parent.seq_id: [] for parent in running_parents}
        for sample in sampled:
            samples_by_parent[sample.parent_seq_id].append(sample)

        # (child, parent) pairs. A parent that continues itself appears as
        # (parent, parent).
        children: List[Tuple[Sequence, Sequence]] = []
        for parent in running_parents:
            parent_samples: List[SequenceOutput] = samples_by_parent[
                parent.seq_id]
            if not parent_samples:
                # This parent sequence has no children samples. Remove
                # the parent sequence from the sequence group since it will
                # not be used in the future iterations.
                parent.status = SequenceStatus.FINISHED_ABORTED
                seq_group.remove(parent.seq_id)
                self.scheduler.free_seq(parent)
                continue
            *forked_samples, last_sample = parent_samples
            # Every sample but the last goes to a fresh fork of the parent.
            for sample in forked_samples:
                child = parent.fork(next(self.seq_counter))
                child.append_token_id(sample.output_token, sample.logprobs)
                children.append((child, parent))
            # Reuse the parent for the last sample to avoid a redundant copy.
            parent.append_token_id(last_sample.output_token,
                                   last_sample.logprobs)
            children.append((parent, parent))

        sampling_params = seq_group.sampling_params
        for seq, _ in children:
            self._decode_sequence(seq, sampling_params)
            self._check_stop(seq, sampling_params)

        if not sampling_params.use_beam_search:
            # Non-beam search: every child is kept.
            self._commit_child_seqs(seq_group, children)
            return

        # Beam search: keep the top `best_of` finished and running candidates.
        beam_width = sampling_params.best_of
        length_penalty = sampling_params.length_penalty
        eos_token_id = self.tokenizer.eos_token_id

        def by_beam_score(candidate) -> float:
            # The first element of each candidate tuple is the sequence.
            return candidate[0].get_beam_search_score(
                length_penalty=length_penalty, eos_token_id=eos_token_id)

        # Candidates are (seq, parent, is_new). Already-finished sequences
        # have no parent in this step.
        finished_candidates = (
            [(seq, None, False) for seq in previously_finished] +
            [(seq, parent, True)
             for seq, parent in children if seq.is_finished()])
        finished_candidates.sort(key=by_beam_score, reverse=True)

        # Newly finished children that rank in the top `beam_width` are kept.
        keep: List[Tuple[Sequence, Sequence]] = [
            (seq, parent)
            for seq, parent, is_new in finished_candidates[:beam_width]
            if is_new
        ]
        drop: List[Tuple[Sequence, Sequence]] = []
        for seq, parent, is_new in finished_candidates[beam_width:]:
            if is_new:
                # Low-scoring new child: never added to the group. If it
                # continues its parent, the parent is removed below.
                drop.append((seq, parent))
            else:
                # Low-scoring old finished sequence: drop it from the group.
                seq_group.remove(seq.seq_id)

        running_candidates = [(seq, parent) for seq, parent in children
                              if not seq.is_finished()]
        running_candidates.sort(key=by_beam_score, reverse=True)

        if not running_candidates:
            stop_beam_search = True
        elif len(finished_candidates) < beam_width:
            stop_beam_search = False
        else:
            stop_beam_search = self._check_beam_search_early_stopping(
                sampling_params.early_stopping, sampling_params,
                running_candidates[0][0],
                finished_candidates[beam_width - 1][0])

        if stop_beam_search:
            # Stop: every running candidate is discarded.
            drop.extend(running_candidates)
        else:
            # Continue with the best `beam_width` running candidates.
            keep.extend(running_candidates[:beam_width])
            drop.extend(running_candidates[beam_width:])

        self._commit_child_seqs(seq_group, keep)

        # Remove unselected parents that continued themselves, and free them.
        for seq, parent in drop:
            if seq is parent:
                seq_group.remove(seq.seq_id)
                self.scheduler.free_seq(seq)

    def _commit_child_seqs(
            self, seq_group: SequenceGroup,
            selected: List[Tuple[Sequence, Sequence]]) -> None:
        """Register selected children in the group and settle block tables.

        New (forked) children are added to ``seq_group`` and, if still
        running, forked in the scheduler. After that, selected parents that
        continued themselves and are now finished have their blocks freed;
        they stay in the group as candidate outputs. Forking must happen
        before freeing.
        """
        for seq, parent in selected:
            if seq is not parent:
                seq_group.add(seq)
                if not seq.is_finished():
                    self.scheduler.fork_seq(parent, seq)
        for seq, parent in selected:
            if seq is parent and seq.is_finished():
                self.scheduler.free_seq(seq)

    def _process_model_outputs(
            self, output: SamplerOutput,
            scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]:
        """Apply sampler output, free finished groups and build outputs.

        Returns:
            One `RequestOutput` per scheduled sequence group, followed by one
            per ignored sequence group.
        """
        scheduled = scheduler_outputs.scheduled_seq_groups
        for seq_group, group_output in zip(scheduled, output):
            self._process_sequence_group_outputs(seq_group, group_output)

        # Finished groups are freed before the outputs are built.
        self.scheduler.free_finished_seq_groups()

        request_outputs: List[RequestOutput] = (
            [RequestOutput.from_seq_group(group) for group in scheduled] +
            [
                RequestOutput.from_seq_group(group)
                for group in scheduler_outputs.ignored_seq_groups
            ])

        if self.log_stats:
            self._log_system_stats(scheduler_outputs.prompt_run,
                                   scheduler_outputs.num_batched_tokens)
        return request_outputs

Antoni Baum's avatar
Antoni Baum committed
617
618
619
620
621
622
623
624
625
    def step(self) -> List[RequestOutput]:
        """Performs one decoding iteration and returns newly generated results.

        The scheduler first picks the sequences to run and the KV-cache
        blocks to swap in, swap out or copy. If anything was scheduled, the
        model is executed on the workers. The sampler output is then applied
        to the sequence groups and turned into `RequestOutput` objects.
        """
        seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()

        if scheduler_outputs.is_empty():
            # Nothing scheduled: skip the model call entirely.
            output = []
        else:
            execute_kwargs = {
                "seq_group_metadata_list": seq_group_metadata_list,
                "blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in,
                "blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out,
                "blocks_to_copy": scheduler_outputs.blocks_to_copy,
            }
            all_outputs = self._run_workers("execute_model",
                                            driver_kwargs=execute_kwargs)
            # Only the driver worker returns the sampling results.
            output = all_outputs[0]

        return self._process_model_outputs(output, scheduler_outputs)
Antoni Baum's avatar
Antoni Baum committed
645

646
647
648
    def do_log_stats(self) -> None:
        """Run a stats update for a step that processed no tokens.

        Delegates to ``_log_system_stats`` as a non-prompt step with zero
        batched tokens, so the periodic stats report is still emitted when
        its logging interval has elapsed.
        """
        self._log_system_stats(prompt_run=False, num_batched_tokens=0)

Woosuk Kwon's avatar
Woosuk Kwon committed
649
650
651
652
653
    def _log_system_stats(
        self,
        prompt_run: bool,
        num_batched_tokens: int,
    ) -> None:
        """Record one step's token count and periodically report engine stats.

        Every call appends ``(now, num_batched_tokens)`` to either the prompt
        or the generation token history. When at least
        ``_LOGGING_INTERVAL_SEC`` seconds have passed since the last report,
        it also:

        - drops history entries older than that interval;
        - computes average prompt/generation throughput and GPU/CPU KV cache
          usage;
        - passes these, together with the scheduler queue sizes, to
          ``record_metrics``;
        - logs a one-line summary.

        Args:
            prompt_run: True if the step processed prompt tokens, False if it
                processed generation tokens.
            num_batched_tokens: Number of tokens processed by the step.
        """
        now = time.monotonic()

        # Log the number of batched input tokens.
        if prompt_run:
            self.num_prompt_tokens.append((now, num_batched_tokens))
        else:
            self.num_generation_tokens.append((now, num_batched_tokens))

        should_log = now - self.last_logging_time >= _LOGGING_INTERVAL_SEC
        if not should_log:
            return

        # Discard the old stats.
        self.num_prompt_tokens = [(t, n) for t, n in self.num_prompt_tokens
                                  if now - t < _LOGGING_INTERVAL_SEC]
        self.num_generation_tokens = [(t, n)
                                      for t, n in self.num_generation_tokens
                                      if now - t < _LOGGING_INTERVAL_SEC]

        avg_prompt_throughput = self._get_avg_throughput(
            self.num_prompt_tokens, now)
        avg_generation_throughput = self._get_avg_throughput(
            self.num_generation_tokens, now)

        # Guard both divisions the same way: a block count of 0 (or None,
        # e.g. before the cache is initialized) reports 0% usage instead of
        # raising.
        total_num_gpu_blocks = self.cache_config.num_gpu_blocks
        if total_num_gpu_blocks:
            num_free_gpu_blocks = (
                self.scheduler.block_manager.get_num_free_gpu_blocks())
            num_used_gpu_blocks = total_num_gpu_blocks - num_free_gpu_blocks
            gpu_cache_usage = num_used_gpu_blocks / total_num_gpu_blocks
        else:
            gpu_cache_usage = 0.0

        total_num_cpu_blocks = self.cache_config.num_cpu_blocks
        if total_num_cpu_blocks:
            num_free_cpu_blocks = (
                self.scheduler.block_manager.get_num_free_cpu_blocks())
            num_used_cpu_blocks = total_num_cpu_blocks - num_free_cpu_blocks
            cpu_cache_usage = num_used_cpu_blocks / total_num_cpu_blocks
        else:
            cpu_cache_usage = 0.0

        record_metrics(
            avg_prompt_throughput=avg_prompt_throughput,
            avg_generation_throughput=avg_generation_throughput,
            scheduler_running=len(self.scheduler.running),
            scheduler_swapped=len(self.scheduler.swapped),
            scheduler_waiting=len(self.scheduler.waiting),
            gpu_cache_usage=gpu_cache_usage,
            cpu_cache_usage=cpu_cache_usage,
        )

        logger.info("Avg prompt throughput: "
                    f"{avg_prompt_throughput:.1f} tokens/s, "
                    "Avg generation throughput: "
                    f"{avg_generation_throughput:.1f} tokens/s, "
                    f"Running: {len(self.scheduler.running)} reqs, "
                    f"Swapped: {len(self.scheduler.swapped)} reqs, "
                    f"Pending: {len(self.scheduler.waiting)} reqs, "
                    f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, "
                    f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
        self.last_logging_time = now

    @staticmethod
    def _get_avg_throughput(token_history: List[Tuple[float, int]],
                            now: float) -> float:
        """Return the average tokens/s over ``token_history``, or 0.0.

        Sums the token counts of every entry except the most recent one and
        divides by the time elapsed since the oldest entry.
        NOTE(review): the newest entry is excluded from the sum while the
        oldest is included; confirm this windowing is intended.

        Args:
            token_history: ``(timestamp, num_tokens)`` pairs, oldest first.
            now: Current ``time.monotonic()`` value.

        Returns:
            0.0 when there are fewer than two entries, or when the elapsed
            window is not positive (all entries on the same clock tick).
            Otherwise the average throughput in tokens per second.
        """
        if len(token_history) <= 1:
            return 0.0
        window = now - token_history[0][0]
        if window <= 0:
            return 0.0
        total_num_tokens = sum(n for _, n in token_history[:-1])
        return total_num_tokens / window

722
    def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None:
        """Decodes the new token for a sequence.

        Calls ``detokenize_incrementally`` with the sequence's token ids and
        its saved detokenization state (``tokens``, ``prefix_offset``,
        ``read_offset``). It then writes the returned state back onto
        ``seq`` and appends the newly decoded text to ``seq.output_text``.

        Args:
            seq: The sequence to update in place.
            prms: Sampling parameters supplying ``skip_special_tokens`` and
                ``spaces_between_special_tokens``.
        """
        (new_tokens, new_output_text, prefix_offset,
         read_offset) = detokenize_incrementally(
             self.tokenizer,
             all_input_ids=seq.get_token_ids(),
             prev_tokens=seq.tokens,
             prefix_offset=seq.prefix_offset,
             read_offset=seq.read_offset,
             skip_special_tokens=prms.skip_special_tokens,
             spaces_between_special_tokens=prms.spaces_between_special_tokens,
         )
        # ``seq.tokens`` may be None (nothing decoded yet); otherwise the
        # newly decoded tokens are appended to the existing list in place.
        if seq.tokens is None:
            seq.tokens = new_tokens
        else:
            seq.tokens.extend(new_tokens)
        seq.prefix_offset = prefix_offset
        seq.read_offset = read_offset
        seq.output_text += new_output_text
741
742
743

    def _check_stop(self,
                    seq: Sequence,
                    sampling_params: SamplingParams,
                    new_char_count: Optional[int] = None) -> None:
        """Stop the finished sequences.

        Marks ``seq`` as finished if any stopping criterion is met. Criteria
        are checked in this order, and the first match wins:

        1. a stop string (``sampling_params.stop``);
        2. a stop token id;
        3. exceeding ``max_model_len``;
        4. reaching ``max_tokens``;
        5. the tokenizer's EOS token (unless ``ignore_eos``).

        Args:
            seq: The sequence to check. Its ``status`` may be set and, for a
                stop-string match, its ``output_text`` truncated.
            sampling_params: Supplies ``stop``, ``stop_token_ids``,
                ``include_stop_str_in_output``, ``max_tokens`` and
                ``ignore_eos``.
            new_char_count: Optional number of characters appended to
                ``seq.output_text`` by the latest decoding step. When given,
                only the region that can contain a new match is searched.
                When None, the whole output text is searched.
        """
        output_text = seq.output_text
        for stop_str in sampling_params.stop:
            # An empty stop string "matches" everywhere, and slicing with
            # ``[:-0]`` would erase the whole output, so ignore it.
            if not stop_str:
                continue
            if new_char_count is None:
                search_start = 0
            else:
                # A new match must end inside the newly added text, so it
                # starts at most ``len(stop_str)`` chars before that text.
                search_start = max(
                    0, len(output_text) - new_char_count - len(stop_str))
            # One decoding step can append several characters, so the stop
            # string may end before the end of the text; ``endswith`` alone
            # would miss such matches.
            stop_index = output_text.find(stop_str, search_start)
            if stop_index == -1:
                continue
            if sampling_params.include_stop_str_in_output:
                # Keep the stop string but drop any text generated after it.
                stop_index += len(stop_str)
            # Otherwise truncate the output text so that the stop string is
            # not included in the output.
            seq.output_text = output_text[:stop_index]
            seq.status = SequenceStatus.FINISHED_STOPPED
            return

        if seq.get_last_token_id() in sampling_params.stop_token_ids:
            seq.status = SequenceStatus.FINISHED_STOPPED
            return

        # Check if the sequence has reached max_model_len.
        if seq.get_len() > self.scheduler_config.max_model_len:
            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
            return

        # Check if the sequence has reached max_tokens.
        if seq.get_output_len() == sampling_params.max_tokens:
            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
            return

        # Check if the sequence has generated the EOS token.
        if ((not sampling_params.ignore_eos)
                and seq.get_last_token_id() == self.tokenizer.eos_token_id):
            seq.status = SequenceStatus.FINISHED_STOPPED
            return
772

773
774
775
776
    def _run_workers(
        self,
        method: str,
        *args,
777
778
        driver_args: Optional[List[Any]] = None,
        driver_kwargs: Optional[Dict[str, Any]] = None,
779
780
781
782
        max_concurrent_workers: Optional[int] = None,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers."""
783

784
        if max_concurrent_workers:
785
786
787
788
789
790
791
792
793
794
795
796
797
            raise NotImplementedError(
                "max_concurrent_workers is not supported yet.")

        # Start the ray workers first.
        ray_worker_outputs = [
            worker.execute_method.remote(method, *args, **kwargs)
            for worker in self.workers
        ]

        if driver_args is None:
            driver_args = args
        if driver_kwargs is None:
            driver_kwargs = kwargs
798

799
800
801
        # Start the driver worker after all the ray workers.
        driver_worker_output = getattr(self.driver_worker,
                                       method)(*driver_args, **driver_kwargs)
802

803
804
805
        # Get the results of the ray workers.
        if self.workers:
            ray_worker_outputs = ray.get(ray_worker_outputs)
806

807
        return [driver_worker_output] + ray_worker_outputs