llm_engine.py 39.1 KB
Newer Older
Antoni Baum's avatar
Antoni Baum committed
1
import time
2
from typing import Iterable, List, Optional, Tuple, Type, Union
3

4
5
from transformers import PreTrainedTokenizer

6
import vllm
7
8
9
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
                         LoRAConfig, ModelConfig, ParallelConfig,
                         SchedulerConfig, SpeculativeConfig,
10
                         VisionLanguageConfig)
Antoni Baum's avatar
Antoni Baum committed
11
from vllm.core.scheduler import Scheduler, SchedulerOutputs
Woosuk Kwon's avatar
Woosuk Kwon committed
12
from vllm.engine.arg_utils import EngineArgs
13
from vllm.engine.metrics import StatLogger, Stats
14
from vllm.engine.ray_utils import initialize_ray_cluster
15
from vllm.executor.executor_base import ExecutorBase
Woosuk Kwon's avatar
Woosuk Kwon committed
16
from vllm.logger import init_logger
17
from vllm.lora.request import LoRARequest
Woosuk Kwon's avatar
Woosuk Kwon committed
18
19
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
20
21
22
from vllm.sequence import (MultiModalData, SamplerOutput, Sequence,
                           SequenceGroup, SequenceGroupOutput, SequenceOutput,
                           SequenceStatus)
23
from vllm.transformers_utils.detokenizer import Detokenizer
24
25
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
                                                     get_tokenizer_group)
yhu422's avatar
yhu422 committed
26
27
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                  usage_message)
28
from vllm.utils import Counter
29
30

# Module-wide logger for the engine.
logger = init_logger(__name__)
# Interval, in seconds, passed to StatLogger as `local_interval`.
_LOCAL_LOGGING_INTERVAL_SEC: int = 5
Woosuk Kwon's avatar
Woosuk Kwon committed
32

33

34
class LLMEngine:
Zhuohan Li's avatar
Zhuohan Li committed
35
    """An LLM engine that receives requests and generates texts.
36

Woosuk Kwon's avatar
Woosuk Kwon committed
37
    This is the main class for the vLLM engine. It receives requests
38
39
40
41
42
43
44
    from clients and generates texts from the LLM. It includes a tokenizer, a
    language model (possibly distributed across multiple GPUs), and GPU memory
    space allocated for intermediate states (aka KV cache). This class utilizes
    iteration-level scheduling and efficient memory management to maximize the
    serving throughput.

    The `LLM` class wraps this class for offline batched inference and the
45
    `AsyncLLMEngine` class wraps this class for online serving.
46

Zhuohan Li's avatar
Zhuohan Li committed
47
48
    NOTE: The config arguments are derived from the `EngineArgs` class. For the
    comprehensive list of arguments, see `EngineArgs`.
49
50
51
52
53
54
55

    Args:
        model_config: The configuration related to the LLM model.
        cache_config: The configuration related to the KV cache memory
            management.
        parallel_config: The configuration related to distributed execution.
        scheduler_config: The configuration related to the request scheduler.
56
        device_config: The configuration related to the device.
57
58
59
60
61
        lora_config (Optional): The configuration related to serving multi-LoRA.
        vision_language_config (Optional): The configuration related to vision
            language models.
        speculative_config (Optional): The configuration related to speculative
            decoding.
62
63
        executor_class: The model executor class for managing distributed
            execution.
64
        log_stats: Whether to log statistics.
yhu422's avatar
yhu422 committed
65
        usage_context: Specified entry point, used for usage info collection
66
    """
67
68
69
70
71
72
73

    def __init__(
        self,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        vision_language_config: Optional[VisionLanguageConfig],
        speculative_config: Optional[SpeculativeConfig],
        decoding_config: Optional[DecodingConfig],
        executor_class: Type[ExecutorBase],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
    ) -> None:
        """Set up the tokenizer, model executor, KV cache and scheduler.

        See :meth:`from_engine_args` for building these arguments from an
        :class:`EngineArgs` instance.

        Args:
            model_config: Model settings (model, tokenizer, dtype, ...).
            cache_config: KV cache settings. Its ``num_gpu_blocks`` and
                ``num_cpu_blocks`` are filled in by ``_initialize_kv_caches``
                during construction, before the scheduler is created.
            parallel_config: Distributed-execution settings.
            scheduler_config: Request scheduler settings.
            device_config: Device settings.
            load_config: Model weight loading settings.
            lora_config: Multi-LoRA settings; None disables LoRA.
            vision_language_config: Optional vision-language settings.
            speculative_config: Optional speculative decoding settings.
            decoding_config: Optional decoding settings; a default
                ``DecodingConfig()`` is used when it is falsy.
            executor_class: Executor class instantiated with the configs to
                run the model.
            log_stats: If True, a ``StatLogger`` is created.
            usage_context: Entry point reported along with usage statistics
                when usage collection is enabled.
        """
        logger.info(
            f"Initializing an LLM engine (v{vllm.__version__}) with config: "
            f"model={model_config.model!r}, "
            f"speculative_config={speculative_config!r}, "
            f"tokenizer={model_config.tokenizer!r}, "
            f"tokenizer_mode={model_config.tokenizer_mode}, "
            f"revision={model_config.revision}, "
            f"tokenizer_revision={model_config.tokenizer_revision}, "
            f"trust_remote_code={model_config.trust_remote_code}, "
            f"dtype={model_config.dtype}, "
            f"max_seq_len={model_config.max_model_len}, "
            f"download_dir={load_config.download_dir!r}, "
            f"load_format={load_config.load_format}, "
            f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
            f"disable_custom_all_reduce="
            f"{parallel_config.disable_custom_all_reduce}, "
            f"quantization={model_config.quantization}, "
            f"enforce_eager={model_config.enforce_eager}, "
            f"kv_cache_dtype={cache_config.cache_dtype}, "
            f"quantization_param_path={model_config.quantization_param_path}, "
            f"device_config={device_config.device}, "
            f"decoding_config={decoding_config!r}, "
            f"seed={model_config.seed})")
        # TODO(woosuk): Print more configs in debug mode.

        # NOTE(review): `_verify_args` is not invoked here; confirm that the
        # configs are validated upstream before reaching this constructor.
        self.model_config = model_config
        self.cache_config = cache_config
        self.lora_config = lora_config
        self.vision_language_config = vision_language_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.speculative_config = speculative_config
        self.load_config = load_config
        self.decoding_config = decoding_config or DecodingConfig()
        self.log_stats = log_stats

        # The detokenizer wraps the tokenizer group, so the group must exist
        # first.
        self._init_tokenizer()
        self.detokenizer = Detokenizer(self.tokenizer)
        # Source of unique sequence IDs for new (and forked) sequences.
        self.seq_counter = Counter()

        # Note: decoding_config is kept on the engine but is not passed to
        # the executor.
        self.model_executor = executor_class(
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            lora_config=lora_config,
            vision_language_config=vision_language_config,
            speculative_config=speculative_config,
            load_config=load_config,
        )

        # Profiles the available blocks and writes num_gpu_blocks /
        # num_cpu_blocks onto self.cache_config.
        self._initialize_kv_caches()

        # If usage stat is enabled, collect relevant info.
        if is_usage_stats_enabled():
            from vllm.model_executor.model_loader import (
                get_architecture_class_name)
            usage_message.report_usage(
                get_architecture_class_name(model_config),
                usage_context,
                extra_kvs={
                    # Common configuration
                    "dtype":
                    str(model_config.dtype),
                    "tensor_parallel_size":
                    parallel_config.tensor_parallel_size,
                    "block_size":
                    cache_config.block_size,
                    "gpu_memory_utilization":
                    cache_config.gpu_memory_utilization,

                    # Quantization
                    "quantization":
                    model_config.quantization,
                    "kv_cache_dtype":
                    cache_config.cache_dtype,

                    # Feature flags
                    "enable_lora":
                    bool(lora_config),
                    "enable_prefix_caching":
                    cache_config.enable_prefix_caching,
                    "enforce_eager":
                    model_config.enforce_eager,
                    "disable_custom_all_reduce":
                    parallel_config.disable_custom_all_reduce,
                })

        # Ping the tokenizer to ensure liveness if it runs in a
        # different process.
        self.tokenizer.ping()

        # Create the scheduler.
        # NOTE: the cache_config here has already been updated with the
        # numbers of GPU and CPU blocks profiled by the executor (see
        # `_initialize_kv_caches` above), so this must come after it.
        self.scheduler = Scheduler(scheduler_config, cache_config, lora_config)

        # Metric Logging.
        if self.log_stats:
            self.stat_logger = StatLogger(
                local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
                labels=dict(model_name=model_config.model))
            self.stat_logger.info("cache_config", self.cache_config)
189

190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
    def _initialize_kv_caches(self) -> None:
        """Size the KV cache on the worker(s) and allocate it.

        The executor reports how many GPU blocks and CPU swap blocks fit.
        A non-None ``cache_config.num_gpu_blocks_override`` replaces the
        profiled GPU count. The final counts are written back onto
        ``self.cache_config`` before the executor allocates the cache.
        """
        cfg = self.cache_config
        gpu_blocks, cpu_blocks = (
            self.model_executor.determine_num_available_blocks())

        override = cfg.num_gpu_blocks_override
        if override is not None:
            # Spelled out explicitly (rather than with `{name=}`) so the log
            # text does not depend on local variable names.
            logger.info(f"Overriding num_gpu_blocks={gpu_blocks!r} with "
                        f"num_gpu_blocks_override={override!r}")
            gpu_blocks = override

        cfg.num_gpu_blocks, cfg.num_cpu_blocks = gpu_blocks, cpu_blocks
        self.model_executor.initialize_cache(gpu_blocks, cpu_blocks)

210
    @classmethod
    def from_engine_args(
        cls,
        engine_args: EngineArgs,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
    ) -> "LLMEngine":
        """Build an LLM engine from :class:`EngineArgs`.

        The executor is chosen from the device type and parallel settings:
        Neuron, CPU, Ray-backed GPU (the Ray cluster is initialized first),
        or single-process GPU (which requires ``world_size == 1``).
        """
        engine_config = engine_args.create_engine_config()
        device_type = engine_config.device_config.device_type
        parallel_config = engine_config.parallel_config

        # Executor modules are imported inside each branch, so only the
        # selected backend's module is imported here.
        if device_type == "neuron":
            from vllm.executor.neuron_executor import NeuronExecutor
            executor_class = NeuronExecutor
        elif device_type == "cpu":
            from vllm.executor.cpu_executor import CPUExecutor
            executor_class = CPUExecutor
        elif parallel_config.worker_use_ray:
            initialize_ray_cluster(parallel_config)
            from vllm.executor.ray_gpu_executor import RayGPUExecutor
            executor_class = RayGPUExecutor
        else:
            assert parallel_config.world_size == 1, (
                "Ray is required if parallel_config.world_size > 1.")
            from vllm.executor.gpu_executor import GPUExecutor
            executor_class = GPUExecutor

        return cls(
            **engine_config.to_dict(),
            executor_class=executor_class,
            log_stats=not engine_args.disable_log_stats,
            usage_context=usage_context,
        )
245

246
247
248
249
250
    def __reduce__(self):
        # Pickling is refused on purpose: failing loudly here guards against
        # the engine being captured in the closure used to initialize Ray
        # worker actors.
        message = "LLMEngine should not be pickled!"
        raise RuntimeError(message)

251
    def get_tokenizer(self) -> "PreTrainedTokenizer":
        """Return the tokenizer the group resolves for no LoRA request."""
        no_lora_request = None
        return self.tokenizer.get_lora_tokenizer(no_lora_request)
253
254
255

    def get_tokenizer_for_seq(self,
                              sequence: Sequence) -> "PreTrainedTokenizer":
        """Return the tokenizer matching ``sequence``'s LoRA request."""
        lora_request = sequence.lora_request
        return self.tokenizer.get_lora_tokenizer(lora_request)

    def _init_tokenizer(self, **tokenizer_init_kwargs):
        """Create ``self.tokenizer`` (a tokenizer group) from the configs.

        Defaults come from the model, scheduler and LoRA configs. Any keyword
        given in ``tokenizer_init_kwargs`` overrides the default of the same
        name.
        """
        model_cfg = self.model_config
        init_kwargs = {
            "tokenizer_id": model_cfg.tokenizer,
            "enable_lora": bool(self.lora_config),
            "max_num_seqs": self.scheduler_config.max_num_seqs,
            "max_input_length": None,
            "tokenizer_mode": model_cfg.tokenizer_mode,
            "trust_remote_code": model_cfg.trust_remote_code,
            "revision": model_cfg.tokenizer_revision,
            # Caller-supplied values win over the defaults above.
            **tokenizer_init_kwargs,
        }
        self.tokenizer: BaseTokenizerGroup = get_tokenizer_group(
            self.parallel_config.tokenizer_pool_config, **init_kwargs)
270

271
272
    def _verify_args(self) -> None:
        """Cross-check the configs against each other.

        Model and cache configs are checked against the parallel config.
        When LoRA is enabled, its config is also checked against the model
        and scheduler configs.
        """
        parallel = self.parallel_config
        self.model_config.verify_with_parallel_config(parallel)
        self.cache_config.verify_with_parallel_config(parallel)

        lora = self.lora_config
        if not lora:
            return
        lora.verify_with_model_config(self.model_config)
        lora.verify_with_scheduler_config(self.scheduler_config)
278

279
280
281
282
283
284
285
286
287
288
289
290
291
292
    def encode_request(
        self,
        request_id: str,
        prompt: Optional[str],
        prompt_token_ids: Optional[List[int]] = None,
        lora_request: Optional[LoRARequest] = None,
    ):
        """Return the prompt token IDs for a request, tokenizing if needed.

        Args:
            request_id: ID of the request; forwarded to the tokenizer
                group's ``encode``.
            prompt: The prompt text. Only used when ``prompt_token_ids`` is
                None.
            prompt_token_ids: Pre-tokenized prompt. Returned unchanged when
                given.
            lora_request: Forwarded to the tokenizer group's ``encode``.

        Returns:
            ``prompt_token_ids`` if given, otherwise the encoded ``prompt``.

        Raises:
            ValueError: If both ``prompt`` and ``prompt_token_ids`` are None.
        """
        if prompt_token_ids is not None:
            return prompt_token_ids
        # An explicit check instead of `assert`: asserts are stripped under
        # `python -O`, which would let a None prompt reach the tokenizer.
        if prompt is None:
            raise ValueError(
                f"Request {request_id}: either `prompt` or "
                "`prompt_token_ids` must be provided.")
        return self.tokenizer.encode(request_id=request_id,
                                     prompt=prompt,
                                     lora_request=lora_request)

293
294
295
    def add_request(
        self,
        request_id: str,
        prompt: Optional[str],
        sampling_params: SamplingParams,
        prompt_token_ids: Optional[List[int]] = None,
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalData] = None,
    ) -> None:
        """Add a request to the engine's request pool.

        The request is added to the request pool and will be processed by the
        scheduler as `engine.step()` is called. The exact scheduling policy is
        determined by the scheduler.

        Args:
            request_id: The unique ID of the request.
            prompt: The prompt string. Can be None if prompt_token_ids is
                provided.
            sampling_params: The sampling parameters for text generation.
                The engine stores a clone of this object (logits processors
                are not deep-copied) with its ``eos_token_id`` set from the
                tokenizer.
            prompt_token_ids: The token IDs of the prompt. If None, we
                use the tokenizer to convert the prompts to token IDs.
            arrival_time: The arrival time of the request. If None, we use
                the current wall-clock time (``time.time()``).
            lora_request: The LoRA request to serve this request with, if
                any. Only allowed when LoRA is enabled on the engine.
            multi_modal_data: Multi modal data per request.

        Raises:
            ValueError: If ``lora_request`` is given while LoRA is not
                enabled, or if ``sampling_params`` asks for more logprobs or
                prompt logprobs than the model config's ``max_logprobs``.

        Details:
            - Set arrival_time to the current time if it is None.
            - Set prompt_token_ids to the encoded prompt if it is None.
            - Create a single :class:`~vllm.Sequence` object.
            - Create a :class:`~vllm.SequenceGroup` object
              containing that :class:`~vllm.Sequence`.
            - Add the :class:`~vllm.SequenceGroup` object to the scheduler.

        Example:
            >>> # initialize engine
            >>> engine = LLMEngine.from_engine_args(engine_args)
            >>> # set request arguments
            >>> example_prompt = "Who is the president of the United States?"
            >>> sampling_params = SamplingParams(temperature=0.0)
            >>> request_id = 0
            >>>
            >>> # add the request to the engine
            >>> engine.add_request(
            ...    str(request_id),
            ...    example_prompt,
            ...    sampling_params)
            >>> # continue the request processing
            >>> ...
        """
        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")
        max_logprobs = self.get_model_config().max_logprobs
        if (sampling_params.logprobs
                and sampling_params.logprobs > max_logprobs) or (
                    sampling_params.prompt_logprobs
                    and sampling_params.prompt_logprobs > max_logprobs):
            raise ValueError(f"Cannot request more than "
                             f"{max_logprobs} logprobs.")
        if arrival_time is None:
            # Wall-clock time, so it is comparable with the `time.time()`
            # timestamps taken in `_process_model_outputs`.
            arrival_time = time.time()
        prompt_token_ids = self.encode_request(
            request_id=request_id,
            prompt=prompt,
            prompt_token_ids=prompt_token_ids,
            lora_request=lora_request)

        # Create the sequence.
        block_size = self.cache_config.block_size
        seq_id = next(self.seq_counter)
        eos_token_id = self.tokenizer.get_lora_tokenizer(
            lora_request).eos_token_id
        seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
                       eos_token_id, lora_request)

        # Defensive copy of SamplingParams, which are used by the sampler,
        # this doesn't deep-copy LogitsProcessor objects
        sampling_params = sampling_params.clone()
        # inject the eos token id into the sampling_params to support min_tokens
        # processing
        sampling_params.eos_token_id = seq.eos_token_id

        # Create the sequence group.
        seq_group = SequenceGroup(request_id, [seq], sampling_params,
                                  arrival_time, lora_request, multi_modal_data)

        # Add the sequence group to the scheduler.
        self.scheduler.add_seq_group(seq_group)

Antoni Baum's avatar
Antoni Baum committed
384
385
    def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
        """Abort one request, or several, identified by ID.

        Args:
            request_id: A single request ID or an iterable of request IDs.

        Details:
            - Delegates to
              :meth:`~vllm.core.scheduler.Scheduler.abort_seq_group`
              of :class:`~vllm.core.scheduler.Scheduler`.

        Example:
            >>> # initialize engine and add a request with request_id
            >>> request_id = str(0)
            >>> # abort the request
            >>> engine.abort_request(request_id)
        """
        scheduler = self.scheduler
        scheduler.abort_seq_group(request_id)

403
404
405
406
    def get_model_config(self) -> ModelConfig:
        """Return the model configuration stored on this engine."""
        config = self.model_config
        return config

407
    def get_num_unfinished_requests(self) -> int:
        """Return how many sequence groups the scheduler has not finished."""
        scheduler = self.scheduler
        return scheduler.get_num_unfinished_seq_groups()

411
    def has_unfinished_requests(self) -> bool:
        """Report whether the scheduler still holds unfinished sequences."""
        scheduler = self.scheduler
        return scheduler.has_unfinished_seqs()

415
416
417
418
419
420
421
422
423
424
425
426
    def _check_beam_search_early_stopping(
        self,
        early_stopping: Union[bool, str],
        sampling_params: SamplingParams,
        best_running_seq: Sequence,
        current_worst_seq: Sequence,
    ) -> bool:
        """Decide whether beam search for a sequence group can stop.

        Args:
            early_stopping: ``True`` always stops. ``False`` compares against
                the best running beam's current score. ``"never"`` compares
                against the best score the running beam could still reach.
            sampling_params: The group's (beam search) sampling params.
            best_running_seq: The highest-scoring still-running beam.
            current_worst_seq: The finished sequence whose score a running
                beam must beat to make it into the kept set.

        Returns:
            True if ``current_worst_seq`` scores at least as high as the
            relevant bound for ``best_running_seq``.
        """
        assert sampling_params.use_beam_search
        length_penalty = sampling_params.length_penalty
        if early_stopping is True:
            return True

        current_worst_score = current_worst_seq.get_beam_search_score(
            length_penalty=length_penalty,
            eos_token_id=current_worst_seq.eos_token_id)
        if early_stopping is False:
            highest_attainable_score = best_running_seq.get_beam_search_score(
                length_penalty=length_penalty,
                eos_token_id=best_running_seq.eos_token_id)
        else:
            assert early_stopping == "never"
            if length_penalty > 0.0:
                # If length_penalty > 0.0, beam search will prefer longer
                # sequences, so the highest attainable score is computed at
                # the longest length the sequence can still reach. That
                # length is capped both by the model context window and by
                # the generation budget (prompt length + max_tokens), so the
                # bound is the smaller of the two; `max` overstated it.
                max_possible_length = self.scheduler_config.max_model_len
                # NOTE(review): guards against `max_tokens` being None
                # (presumably "no explicit limit"), which previously raised
                # TypeError on `int + None`; then only the window applies.
                if sampling_params.max_tokens is not None:
                    max_possible_length = min(
                        best_running_seq.get_prompt_len() +
                        sampling_params.max_tokens, max_possible_length)
                highest_attainable_score = (
                    best_running_seq.get_beam_search_score(
                        length_penalty=length_penalty,
                        eos_token_id=best_running_seq.eos_token_id,
                        seq_len=max_possible_length))
            else:
                # Otherwise, beam search will prefer shorter sequences. The
                # highest attainable score calculation is based on the current
                # sequence length.
                highest_attainable_score = (
                    best_running_seq.get_beam_search_score(
                        length_penalty=length_penalty,
                        eos_token_id=best_running_seq.eos_token_id))
        return current_worst_score >= highest_attainable_score

459
    def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
                                        outputs: SequenceGroupOutput) -> None:
        """Apply one step of sampler output to ``seq_group``.

        Steps:
            1. Store the prompt logprobs, decoding them in place first when
               ``sampling_params.detokenize`` is set.
            2. Route each sample to its parent sequence. A parent with no
               samples is aborted, removed from the group and freed. A parent
               with several samples is forked once per extra sample.
            3. Detokenize each child (if enabled) and call ``_check_stop``.
            4. Without beam search: add forked children to the group (forking
               their blocks if unfinished) and free finished parents. With
               beam search: keep the top ``best_of`` finished sequences,
               decide whether to stop, keep up to ``best_of`` running
               candidates, and remove/free the unselected parents.

        Args:
            seq_group: The sequence group the outputs belong to.
            outputs: Sampler output for this group.
        """
        # Process prompt logprobs. They are stored whether or not
        # detokenization is enabled; only the in-place text decoding depends
        # on `detokenize`. Previously the assignment sat inside the
        # detokenize branch, so `detokenize=False` dropped the logprobs.
        prompt_logprobs = outputs.prompt_logprobs
        if prompt_logprobs is not None:
            if seq_group.sampling_params.detokenize:
                self.detokenizer.decode_prompt_logprobs_inplace(
                    seq_group, prompt_logprobs)
            seq_group.prompt_logprobs = prompt_logprobs

        # Process samples
        samples = outputs.samples
        parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
        existing_finished_seqs = seq_group.get_finished_seqs()
        parent_child_dict = {
            parent_seq.seq_id: []
            for parent_seq in parent_seqs
        }
        for sample in samples:
            parent_child_dict[sample.parent_seq_id].append(sample)
        # List of (child, parent)
        child_seqs: List[Tuple[Sequence, Sequence]] = []

        # Process the child samples for each parent sequence
        for parent in parent_seqs:
            child_samples: List[SequenceOutput] = parent_child_dict[
                parent.seq_id]
            if len(child_samples) == 0:
                # This parent sequence has no children samples. Remove
                # the parent sequence from the sequence group since it will
                # not be used in the future iterations.
                parent.status = SequenceStatus.FINISHED_ABORTED
                seq_group.remove(parent.seq_id)
                self.scheduler.free_seq(parent)
                continue
            # Fork the parent sequence if there are multiple child samples.
            for child_sample in child_samples[:-1]:
                new_child_seq_id = next(self.seq_counter)
                child = parent.fork(new_child_seq_id)
                child.append_token_id(child_sample.output_token,
                                      child_sample.logprobs)
                child_seqs.append((child, parent))
            # Continue the parent sequence for the last child sample.
            # We reuse the parent sequence here to reduce redundant memory
            # copies, especially when using non-beam search sampling methods.
            last_child_sample = child_samples[-1]
            parent.append_token_id(last_child_sample.output_token,
                                   last_child_sample.logprobs)
            child_seqs.append((parent, parent))

        for seq, _ in child_seqs:
            if seq_group.sampling_params.detokenize:
                new_char_count = self.detokenizer.decode_sequence_inplace(
                    seq, seq_group.sampling_params)
            else:
                new_char_count = 0
            self._check_stop(seq, new_char_count, seq_group.sampling_params)

        # Non-beam search case
        if not seq_group.sampling_params.use_beam_search:
            # For newly created child sequences, add them to the sequence group
            # and fork them in block manager if they are not finished.
            for seq, parent in child_seqs:
                if seq is not parent:
                    seq_group.add(seq)
                    if not seq.is_finished():
                        self.scheduler.fork_seq(parent, seq)

            # Free the finished and selected parent sequences' memory in block
            # manager. Keep them in the sequence group as candidate output.
            # NOTE: we need to fork the new sequences before freeing the
            # old sequences.
            for seq, parent in child_seqs:
                if seq is parent and seq.is_finished():
                    self.scheduler.free_seq(seq)
            return

        # Beam search case
        # Select the child sequences to keep in the sequence group.
        selected_child_seqs = []
        unselected_child_seqs = []
        beam_width = seq_group.sampling_params.best_of
        length_penalty = seq_group.sampling_params.length_penalty

        # Select the newly finished sequences with the highest scores
        # to replace existing finished sequences.
        # Tuple of (seq, parent, is_new)
        existing_finished_seqs = [(seq, None, False)
                                  for seq in existing_finished_seqs]
        new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs
                             if seq.is_finished()]
        all_finished_seqs = existing_finished_seqs + new_finished_seqs
        # Sort the finished sequences by their scores.
        all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score(
            length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
                               reverse=True)
        for seq, parent, is_new in all_finished_seqs[:beam_width]:
            if is_new:
                # A newly generated child sequence finishes and has a high
                # score, so we will add it into the sequence group.
                selected_child_seqs.append((seq, parent))
        for seq, parent, is_new in all_finished_seqs[beam_width:]:
            if is_new:
                # A newly generated child sequence finishes but has a low
                # score, so we will not add it into the sequence group.
                # Additionally, if this sequence is a continuation of a
                # parent sequence, we will need remove the parent sequence
                # from the sequence group.
                unselected_child_seqs.append((seq, parent))
            else:
                # An existing finished sequence has a low score, so we will
                # remove it from the sequence group.
                seq_group.remove(seq.seq_id)

        # select the top beam_width sequences from the running
        # sequences for the next iteration to continue the beam
        # search.
        running_child_seqs = [(seq, parent) for seq, parent in child_seqs
                              if not seq.is_finished()]
        # Sort the running sequences by their scores.
        running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score(
            length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
                                reverse=True)

        # Check if we can stop the beam search.
        if len(running_child_seqs) == 0:
            # No running sequences, stop the beam search.
            stop_beam_search = True
        elif len(all_finished_seqs) < beam_width:
            # Not enough finished sequences, continue the beam search.
            stop_beam_search = False
        else:
            # Check the early stopping criteria
            best_running_seq = running_child_seqs[0][0]
            current_worst_seq = all_finished_seqs[beam_width - 1][0]
            stop_beam_search = self._check_beam_search_early_stopping(
                seq_group.sampling_params.early_stopping,
                seq_group.sampling_params, best_running_seq, current_worst_seq)

        if stop_beam_search:
            # Stop the beam search and remove all the running sequences from
            # the sequence group.
            unselected_child_seqs.extend(running_child_seqs)
        else:
            # Continue the beam search and select the top beam_width sequences
            # to continue the beam search.
            selected_child_seqs.extend(running_child_seqs[:beam_width])
            # The remaining running sequences will not be used in the next
            # iteration. Again, if these sequences are continuations of
            # parent sequences, we will need to remove the parent sequences
            # from the sequence group.
            unselected_child_seqs.extend(running_child_seqs[beam_width:])

        # For newly created child sequences, add them to the sequence group
        # and fork them in block manager if they are not finished.
        for seq, parent in selected_child_seqs:
            if seq is not parent:
                seq_group.add(seq)
                if not seq.is_finished():
                    self.scheduler.fork_seq(parent, seq)

        # Free the finished and selected parent sequences' memory in block
        # manager. Keep them in the sequence group as candidate output.
        for seq, parent in selected_child_seqs:
            if seq is parent and seq.is_finished():
                self.scheduler.free_seq(seq)

        # Remove the unselected parent sequences from the sequence group and
        # free their memory in block manager.
        for seq, parent in unselected_child_seqs:
            if seq is parent:
                # Remove the parent sequence if it is not selected for next
                # iteration
                seq_group.remove(seq.seq_id)
                self.scheduler.free_seq(seq)

    def _process_model_outputs(
            self, output: SamplerOutput,
Antoni Baum's avatar
Antoni Baum committed
637
            scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]:
638
        now = time.time()
639
640
        # Update the scheduled sequence groups with the model outputs.
        scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups
641
642
        for scheduled_seq_group, outputs in zip(scheduled_seq_groups, output):
            seq_group = scheduled_seq_group.seq_group
643
644
            seq_group.update_num_computed_tokens(
                scheduled_seq_group.token_chunk_size)
645
646
647
648
            # If uncomputed tokens > 0, it means prefill is chunked.
            # We don't need to process outputs in that case.
            if seq_group.get_num_uncomputed_tokens() == 0:
                self._process_sequence_group_outputs(seq_group, outputs)
649
650
651

        # Free the finished sequence groups.
        self.scheduler.free_finished_seq_groups()
652
653
654

        # Create the outputs.
        request_outputs: List[RequestOutput] = []
655
656
        for scheduled_seq_group in scheduled_seq_groups:
            seq_group = scheduled_seq_group.seq_group
657
            seq_group.maybe_set_first_token_time(now)
658
659
660
            request_output = RequestOutput.from_seq_group(seq_group)
            request_outputs.append(request_output)
        for seq_group in scheduler_outputs.ignored_seq_groups:
661
            request_output = RequestOutput.from_seq_group(seq_group)
662
            request_outputs.append(request_output)
Woosuk Kwon's avatar
Woosuk Kwon committed
663

664
        # Log stats.
Woosuk Kwon's avatar
Woosuk Kwon committed
665
        if self.log_stats:
666
            self.stat_logger.log(self._get_stats(scheduler_outputs))
667
668
        return request_outputs

Antoni Baum's avatar
Antoni Baum committed
669
670
671
    def step(self) -> List[RequestOutput]:
        """Performs one decoding iteration and returns newly generated results.

        .. figure:: https://i.imgur.com/sv2HssD.png
            :alt: Overview of the step function
            :align: center

            Overview of the step function.

        Details:
            - Step 1: Schedules the sequences to be executed in the next
              iteration and the token blocks to be swapped in/out/copy.

                - Depending on the scheduling policy,
                  sequences may be `preempted/reordered`.
                - A Sequence Group (SG) refers to a group of sequences
                  that are generated from the same prompt.

            - Step 2: Calls the distributed executor to execute the model.
              The model is not executed when the scheduler produced an
              empty schedule; an empty output list is used instead.
            - Step 3: Processes the model output. This mainly includes:

                - Decodes the relevant outputs.
                - Updates the scheduled sequence groups with model outputs
                  based on its `sampling parameters` (`use_beam_search` or not).
                - Frees the finished sequence groups.

            - Finally, it creates and returns the newly generated results.

        Example:
            >>> # Please see the example/ folder for more detailed examples.
            >>>
            >>> # initialize engine and request arguments
            >>> engine = LLMEngine.from_engine_args(engine_args)
            >>> example_inputs = [(0, "What is LLM?",
            >>>    SamplingParams(temperature=0.0))]
            >>>
            >>> # Start the engine with an event loop
            >>> while True:
            >>>     if example_inputs:
            >>>         req_id, prompt, sampling_params = example_inputs.pop(0)
            >>>         engine.add_request(str(req_id), prompt, sampling_params)
            >>>
            >>>     # continue the request processing
            >>>     request_outputs = engine.step()
            >>>     for request_output in request_outputs:
            >>>         if request_output.finished:
            >>>             # return or show the request output
            >>>             print(request_output)
            >>>
            >>>     if not (engine.has_unfinished_requests() or example_inputs):
            >>>         break
        """
        seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()

        if scheduler_outputs.is_empty():
            # Nothing to run this iteration; skip the model call entirely.
            output = []
        else:
            output = self.model_executor.execute_model(
                seq_group_metadata_list,
                scheduler_outputs.blocks_to_swap_in,
                scheduler_outputs.blocks_to_swap_out,
                scheduler_outputs.blocks_to_copy,
            )

        return self._process_model_outputs(output, scheduler_outputs)

    def do_log_stats(self) -> None:
        """Force a stats log outside of a scheduling iteration.

        Intended for when no requests are active. Stats are gathered with
        ``scheduler_outputs=None``, so only state-based metrics are
        populated. Does nothing when ``self.log_stats`` is false.
        """
        if not self.log_stats:
            return
        stats = self._get_stats(scheduler_outputs=None)
        self.stat_logger.log(stats)

    def _get_stats(self,
                   scheduler_outputs: Optional[SchedulerOutputs]) -> Stats:
        """Get Stats to be Logged to Prometheus.

        Args:
            scheduler_outputs: Outputs of the scheduling step that just ran,
                or None for a forced log outside an iteration. When None,
                only the cache-usage and queue-length fields are filled in;
                token counts are 0 and the latency lists are empty.

        Returns:
            A ``Stats`` snapshot stamped with the current wall-clock time.
        """
        now = time.time()
        block_manager = self.scheduler.block_manager

        # KV cache usage as a fraction of all blocks: GPU first, then CPU
        # (swap) blocks, of which there may be none.
        gpu_total = self.cache_config.num_gpu_blocks
        gpu_cache_usage = 1.0 - (block_manager.get_num_free_gpu_blocks() /
                                 gpu_total)
        cpu_total = self.cache_config.num_cpu_blocks
        if cpu_total > 0:
            cpu_cache_usage = 1.0 - (block_manager.get_num_free_cpu_blocks() /
                                     cpu_total)
        else:
            cpu_cache_usage = 0.

        # Scheduler queue lengths.
        num_running = len(self.scheduler.running)
        num_swapped = len(self.scheduler.swapped)
        num_waiting = len(self.scheduler.waiting)

        # Per-iteration stats exist only when a scheduling step ran.
        num_prompt_tokens = 0
        num_generation_tokens = 0
        time_to_first_tokens: List[float] = []
        time_per_output_tokens: List[float] = []
        time_e2e_requests: List[float] = []
        if scheduler_outputs is not None:
            prompt_run = scheduler_outputs.num_prefill_groups > 0
            groups = [
                scheduled.seq_group
                for scheduled in scheduler_outputs.scheduled_seq_groups
            ]

            # Token counts: a prefill iteration counts whole prompts plus
            # one generated token per sequence; otherwise every batched
            # token is a generated one.
            if prompt_run:
                num_prompt_tokens = sum(
                    len(group.prompt_token_ids) for group in groups)
                num_generation_tokens = sum(
                    group.num_seqs() for group in groups)
            else:
                num_generation_tokens = scheduler_outputs.num_batched_tokens

            # Latency timings. get_last_latency() also updates
            # seq_group.metrics.last_token_time, so it is called exactly
            # once per group.
            latencies: List[float] = []
            for group in groups:
                latencies.append(group.get_last_latency(now))
                # Time since arrival for all finished requests.
                if group.is_finished():
                    time_e2e_requests.append(now - group.metrics.arrival_time)

            # The latency is a time-to-first-token in a prefill iteration
            # and a time-per-output-token otherwise.
            if prompt_run:
                time_to_first_tokens = latencies
            else:
                time_per_output_tokens = latencies

        return Stats(
            now=now,
            num_running=num_running,
            num_swapped=num_swapped,
            num_waiting=num_waiting,
            gpu_cache_usage=gpu_cache_usage,
            cpu_cache_usage=cpu_cache_usage,
            num_prompt_tokens=num_prompt_tokens,
            num_generation_tokens=num_generation_tokens,
            time_to_first_tokens=time_to_first_tokens,
            time_per_output_tokens=time_per_output_tokens,
            time_e2e_requests=time_e2e_requests,
        )

    def _check_stop(self, seq: Sequence, new_char_count: int,
811
                    sampling_params: SamplingParams) -> None:
812
        """Stop the finished sequences.
813

814
815
816
       new_char_count is the number of chars added to the
           sequence's output text for the newly generated token
        """
817
818
819
820
821
822

        # Check if the minimum number of tokens has been generated yet;
        # skip the stop string/token checks if not
        if seq.get_output_len() < sampling_params.min_tokens:
            return

823
824
825
826
827
828
829
830
        # Check if the sequence has generated the EOS token.
        if ((not sampling_params.ignore_eos)
                and seq.get_last_token_id() == seq.eos_token_id):
            seq.status = SequenceStatus.FINISHED_STOPPED
            return

        # Check if a stop token was encountered.
        # This assumes a single token produced per step.
831
832
        last_token_id = seq.get_last_token_id()
        if last_token_id in sampling_params.stop_token_ids:
833
834
835
836
            if new_char_count and (
                    not sampling_params.include_stop_str_in_output):
                # Remove last token
                seq.output_text = seq.output_text[:-new_char_count]
837
            seq.status = SequenceStatus.FINISHED_STOPPED
838
            seq.stop_reason = last_token_id
839
            return
840

841
842
843
844
        # Check if any stop strings are matched.
        stop_str = self._check_stop_strings(seq, new_char_count,
                                            sampling_params)
        if stop_str is not None:
845
            seq.status = SequenceStatus.FINISHED_STOPPED
846
            seq.stop_reason = stop_str
847
            return
848

849
850
851
        # Check if the sequence has reached max_model_len.
        if seq.get_len() > self.scheduler_config.max_model_len:
            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
852
853
            return

854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
        # Check if the sequence has reached max_tokens.
        if seq.get_output_len() == sampling_params.max_tokens:
            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
            return

    @staticmethod
    def _check_stop_strings(seq: Sequence, new_char_count: int,
                            sampling_params: SamplingParams) -> Optional[str]:
        """Check if any stop strings are matched and truncate sequence
        output text accordingly.

        Returns the stop string if matched or else None.
        """
        if not new_char_count:
            return None

        for stop_str in sampling_params.stop:
            stop_string_len = len(stop_str)
            # Avoid searching already-searched text.
            stop_index = seq.output_text.find(
                stop_str, -new_char_count - stop_string_len)
            if stop_index == -1:
                continue

            if sampling_params.include_stop_str_in_output:
                # Truncate to end of stop string.
                stop_index += stop_string_len
                if stop_index >= len(seq.output_text):
                    # No truncation required.
                    return stop_str

            # Truncate the output text to either the beginning
            # or end of the stop string.
            seq.output_text = seq.output_text[:stop_index]
            return stop_str
        return None
890

891
    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Forward a LoRA add request to the model executor.

        Returns:
            Whatever ``model_executor.add_lora`` returns.
        """
        executor = self.model_executor
        return executor.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Forward a LoRA removal request for ``lora_id`` to the executor.

        Returns:
            Whatever ``model_executor.remove_lora`` returns.
        """
        executor = self.model_executor
        return executor.remove_lora(lora_id)

    def list_loras(self) -> List[int]:
        """Return the LoRA ids reported by the model executor."""
        executor = self.model_executor
        return executor.list_loras()

    def check_health(self) -> None:
        """Delegate the health check to the model executor.

        Any exception raised by the executor's check propagates unchanged.
        """
        executor = self.model_executor
        executor.check_health()