llm_engine.py 40.4 KB
Newer Older
Antoni Baum's avatar
Antoni Baum committed
1
import time
2
3
4
5
from contextlib import contextmanager
from typing import TYPE_CHECKING, ClassVar, Iterable, List, Optional
from typing import Sequence as GenericSequence
from typing import Type, TypeVar, Union
6

7
from transformers import GenerationConfig, PreTrainedTokenizer
8

9
import vllm
10
11
12
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
                         LoRAConfig, ModelConfig, ParallelConfig,
                         SchedulerConfig, SpeculativeConfig,
13
                         VisionLanguageConfig)
14
15
from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler,
                                 SchedulerOutputs)
Woosuk Kwon's avatar
Woosuk Kwon committed
16
from vllm.engine.arg_utils import EngineArgs
17
from vllm.engine.metrics import StatLogger, Stats
18
19
20
21
from vllm.engine.output_processor.interfaces import (
    SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.engine.output_processor.util import create_output_by_sequence_group
22
from vllm.executor.executor_base import ExecutorBase
23
from vllm.executor.ray_utils import initialize_ray_cluster
24
from vllm.inputs import LLMInputs, PromptInputs
Woosuk Kwon's avatar
Woosuk Kwon committed
25
from vllm.logger import init_logger
26
from vllm.lora.request import LoRARequest
27
28
29
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
                          RequestOutputFactory)
from vllm.pooling_params import PoolingParams
Woosuk Kwon's avatar
Woosuk Kwon committed
30
from vllm.sampling_params import SamplingParams
31
from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
32
33
                           PoolerOutput, SamplerOutput, Sequence,
                           SequenceGroup, SequenceGroupMetadata,
34
                           SequenceStatus)
35
from vllm.transformers_utils.detokenizer import Detokenizer
36
37
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
                                                     get_tokenizer_group)
yhu422's avatar
yhu422 committed
38
39
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                  usage_message)
40
from vllm.utils import Counter
41
42

# Module-level logger for the engine.
logger = init_logger(__name__)

# Passed to StatLogger as ``local_interval`` when stat logging is enabled;
# the _SEC suffix marks the unit as seconds.
_LOCAL_LOGGING_INTERVAL_SEC = 5
Woosuk Kwon's avatar
Woosuk Kwon committed
44

45

46
47
48
49
50
51
52
53
54
55
56
def _load_generation_config_dict(model_config: ModelConfig):
    """Load the model's generation config and return it as a plain dict.

    The config is fetched with ``GenerationConfig.from_pretrained`` for the
    model name/path and revision in ``model_config`` and serialized with
    ``to_diff_dict()``. When loading raises ``OSError`` (the config was not
    found), an empty dict is returned instead.
    """
    try:
        generation_config = GenerationConfig.from_pretrained(
            model_config.model,
            revision=model_config.revision,
        )
        config_fields = generation_config.to_diff_dict()
    except OSError:
        # Not found.
        config_fields = {}
    return config_fields


57
58
59
# Output type accepted by LLMEngine.validate_output(s): either a generation
# result (RequestOutput) or an embedding result (EmbeddingRequestOutput).
_O = TypeVar(
    "_O",
    RequestOutput,
    EmbeddingRequestOutput,
)


60
class LLMEngine:
    """An LLM engine that receives requests and generates texts.

    This is the main class for the vLLM engine. It receives requests
    from clients and generates texts from the LLM. It includes a tokenizer, a
    language model (possibly distributed across multiple GPUs), and GPU memory
    space allocated for intermediate states (aka KV cache). This class utilizes
    iteration-level scheduling and efficient memory management to maximize the
    serving throughput.

    The :class:`~vllm.LLM` class wraps this class for offline batched inference
    and the :class:`AsyncLLMEngine` class wraps this class for online serving.

    The config arguments are derived from :class:`~vllm.EngineArgs`. (See
    :ref:`engine_args`)

    Args:
        model_config: The configuration related to the LLM model.
        cache_config: The configuration related to the KV cache memory
            management.
        parallel_config: The configuration related to distributed execution.
        scheduler_config: The configuration related to the request scheduler.
        device_config: The configuration related to the device.
        load_config: The configuration related to loading the model (e.g.
            download directory and load format).
        lora_config (Optional): The configuration related to serving multi-LoRA.
        vision_language_config (Optional): The configuration related to vision
            language models.
        speculative_config (Optional): The configuration related to speculative
            decoding.
        decoding_config (Optional): The configuration related to decoding. A
            default :class:`DecodingConfig` is used when None.
        executor_class: The model executor class for managing distributed
            execution.
        log_stats: Whether to log statistics.
        usage_context: Specified entry point, used for usage info collection.
    """
93

94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
    # Class-wide switch read by validate_output / validate_outputs: when True
    # they raise TypeError for outputs of an unexpected type. Off by default.
    DO_VALIDATE_OUTPUT: ClassVar[bool] = False

    @classmethod
    @contextmanager
    def enable_output_validation(cls):
        """Temporarily turn on output type validation for this class.

        The value of ``DO_VALIDATE_OUTPUT`` that was in effect on entry is
        restored on exit. Restoration happens in ``finally``, so an exception
        raised inside the ``with`` body cannot leave validation switched on.
        Nested uses do not disable validation for the outer block.
        """
        previous = cls.DO_VALIDATE_OUTPUT
        cls.DO_VALIDATE_OUTPUT = True
        try:
            yield
        finally:
            cls.DO_VALIDATE_OUTPUT = previous

    @classmethod
    def validate_output(
        cls,
        output: object,
        output_type: Type[_O],
    ) -> _O:
        """Return *output*, typed as ``output_type``.

        The isinstance check only runs while ``DO_VALIDATE_OUTPUT`` is set;
        otherwise *output* is returned unchecked. ``TYPE_CHECKING`` is False
        at runtime; including it in the condition lets static type checkers
        see the isinstance narrowing.

        Raises:
            TypeError: if validation is on and *output* is not an instance of
                ``output_type``.
        """
        if (TYPE_CHECKING or cls.DO_VALIDATE_OUTPUT) and not isinstance(
                output, output_type):
            raise TypeError(f"Expected output of type {output_type}, "
                            f"but found type {type(output)}")
        return output

    @classmethod
    def validate_outputs(
        cls,
        outputs: GenericSequence[object],
        output_type: Type[_O],
    ) -> List[_O]:
        """Element-wise counterpart of ``validate_output``.

        With validation on, a new list is built and the first element that is
        not an ``output_type`` raises TypeError. With validation off, the
        ``outputs`` object itself is returned as-is (not copied).
        """
        validated: List[_O]
        if TYPE_CHECKING or cls.DO_VALIDATE_OUTPUT:
            validated = []
            for item in outputs:
                if isinstance(item, output_type):
                    validated.append(item)
                else:
                    raise TypeError(f"Expected output of type {output_type}, "
                                    f"but found type {type(item)}")
        else:
            validated = outputs
        return validated

    # Tokenizer group used to encode prompts; None when the engine is built
    # with ``model_config.skip_tokenizer_init`` set.
    tokenizer: Optional[BaseTokenizerGroup]

    def __init__(
        self,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        vision_language_config: Optional[VisionLanguageConfig],
        speculative_config: Optional[SpeculativeConfig],
        decoding_config: Optional[DecodingConfig],
        executor_class: Type[ExecutorBase],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
    ) -> None:
        """Build the tokenizer, model executor, KV cache and scheduler.

        If any step after the model executor is created fails, the executor
        is shut down before the exception propagates.
        """
        logger.info(
            "Initializing an LLM engine (v%s) with config: "
            "model=%r, speculative_config=%r, tokenizer=%r, "
            "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
            "rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, "
            "trust_remote_code=%s, dtype=%s, max_seq_len=%d, "
            "download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
            "disable_custom_all_reduce=%s, quantization=%s, "
            "enforce_eager=%s, kv_cache_dtype=%s, "
            "quantization_param_path=%s, device_config=%s, "
            "decoding_config=%r, seed=%d, served_model_name=%s)",
            vllm.__version__,
            model_config.model,
            speculative_config,
            model_config.tokenizer,
            model_config.skip_tokenizer_init,
            model_config.tokenizer_mode,
            model_config.revision,
            model_config.rope_scaling,
            model_config.rope_theta,
            model_config.tokenizer_revision,
            model_config.trust_remote_code,
            model_config.dtype,
            model_config.max_model_len,
            load_config.download_dir,
            load_config.load_format,
            parallel_config.tensor_parallel_size,
            parallel_config.disable_custom_all_reduce,
            model_config.quantization,
            model_config.enforce_eager,
            cache_config.cache_dtype,
            model_config.quantization_param_path,
            device_config.device,
            decoding_config,
            model_config.seed,
            model_config.served_model_name,
        )
        # TODO(woosuk): Print more configs in debug mode.

        self.model_config = model_config
        self.cache_config = cache_config
        self.lora_config = lora_config
        self.vision_language_config = vision_language_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.speculative_config = speculative_config
        self.load_config = load_config
        self.decoding_config = decoding_config or DecodingConfig()
        self.log_stats = log_stats

        if not self.model_config.skip_tokenizer_init:
            self.tokenizer = self._init_tokenizer()
            self.detokenizer = Detokenizer(self.tokenizer)
        else:
            self.tokenizer = None
            self.detokenizer = None

        self.seq_counter = Counter()
        self.generation_config_fields = _load_generation_config_dict(
            model_config)

        self.model_executor = executor_class(
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            lora_config=lora_config,
            vision_language_config=vision_language_config,
            speculative_config=speculative_config,
            load_config=load_config,
        )

        init_success = False
        try:
            if not self.model_config.embedding_mode:
                self._initialize_kv_caches()

            # If usage stat is enabled, collect relevant info.
            if is_usage_stats_enabled():
                from vllm.model_executor.model_loader import (
                    get_architecture_class_name)
                usage_message.report_usage(
                    get_architecture_class_name(model_config),
                    usage_context,
                    extra_kvs={
                        # Common configuration
                        "dtype":
                        str(model_config.dtype),
                        "tensor_parallel_size":
                        parallel_config.tensor_parallel_size,
                        "block_size":
                        cache_config.block_size,
                        "gpu_memory_utilization":
                        cache_config.gpu_memory_utilization,

                        # Quantization
                        "quantization":
                        model_config.quantization,
                        "kv_cache_dtype":
                        cache_config.cache_dtype,

                        # Feature flags
                        "enable_lora":
                        bool(lora_config),
                        "enable_prefix_caching":
                        cache_config.enable_prefix_caching,
                        "enforce_eager":
                        model_config.enforce_eager,
                        "disable_custom_all_reduce":
                        parallel_config.disable_custom_all_reduce,
                    })

            if self.tokenizer:
                # Ping the tokenizer to ensure liveness if it runs in a
                # different process.
                self.tokenizer.ping()

            # Create the scheduler.
            # NOTE: the cache_config here have been updated with the numbers of
            # GPU and CPU blocks, which are profiled in the distributed executor.
            self.scheduler = Scheduler(scheduler_config, cache_config,
                                       lora_config)

            # Metric Logging.
            if self.log_stats:
                self.stat_logger = StatLogger(
                    local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
                    labels=dict(model_name=model_config.served_model_name),
                    max_model_len=self.model_config.max_model_len)
                self.stat_logger.info("cache_config", self.cache_config)

            # Capture the tokenizer group in locals rather than calling
            # get_tokenizer_group(). The group is None when
            # skip_tokenizer_init is set, and get_tokenizer_group() would raise
            # here, making such engines impossible to construct.
            # The closure below also holds no reference to ``self``. The output
            # processor stored on ``self`` keeps this function alive, so a
            # bound method would form a self-referencing cycle. In that case
            # ``__del__`` (executor shutdown) waits for the cyclic GC.
            tokenizer_group = self.tokenizer
            missing_tokenizer_msg = self.MISSING_TOKENIZER_GROUP_MSG

            def get_tokenizer_for_seq(
                    sequence: Sequence) -> "PreTrainedTokenizer":
                # Resolve the (possibly LoRA-specific) tokenizer for a
                # sequence. Called with the sequence as its only argument.
                if tokenizer_group is None:
                    raise ValueError(missing_tokenizer_msg)
                return tokenizer_group.get_lora_tokenizer(
                    sequence.lora_request)

            # Create sequence output processor, e.g. for beam search or
            # speculative decoding.
            self.output_processor = (
                SequenceGroupOutputProcessor.create_output_processor(
                    self.scheduler_config,
                    self.detokenizer,
                    self.scheduler,
                    self.seq_counter,
                    get_tokenizer_for_seq,
                    stop_checker=StopChecker(
                        self.scheduler_config.max_model_len,
                        get_tokenizer_for_seq,
                    ),
                ))
            init_success = True
        finally:
            if not init_success:
                # Ensure that model_executor is shut down if LLMEngine init
                # failed
                self.model_executor.shutdown()
320

321
322
323
324
325
326
327
328
329
330
331
    def _initialize_kv_caches(self) -> None:
        """Determine KV cache block counts and initialize the cache.

        The executor reports how many GPU and CPU blocks are available. A
        configured ``num_gpu_blocks_override`` replaces the reported GPU count
        (and is logged). The final counts are written back into
        ``cache_config`` before the executor initializes the cache with them.
        """
        reported_gpu_blocks, num_cpu_blocks = (
            self.model_executor.determine_num_available_blocks())

        gpu_override = self.cache_config.num_gpu_blocks_override
        if gpu_override is None:
            num_gpu_blocks = reported_gpu_blocks
        else:
            logger.info(
                "Overriding num_gpu_blocks=%d with "
                "num_gpu_blocks_override=%d", reported_gpu_blocks,
                gpu_override)
            num_gpu_blocks = gpu_override

        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks
        self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)

343
    @classmethod
    def from_engine_args(
        cls,
        engine_args: EngineArgs,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
    ) -> "LLMEngine":
        """Build an engine from :class:`EngineArgs`.

        The executor class is chosen by device type first (``"neuron"``,
        ``"cpu"``), then by ``parallel_config.distributed_executor_backend``
        (``"ray"``, ``"mp"``), falling back to ``GPUExecutor``. Each executor
        module is imported only in its own branch.
        """
        engine_config = engine_args.create_engine_config()
        backend = engine_config.parallel_config.distributed_executor_backend
        device_type = engine_config.device_config.device_type

        executor_class: Type[ExecutorBase]
        if device_type == "neuron":
            from vllm.executor.neuron_executor import NeuronExecutor
            executor_class = NeuronExecutor
        elif device_type == "cpu":
            from vllm.executor.cpu_executor import CPUExecutor
            executor_class = CPUExecutor
        elif backend == "ray":
            # The Ray cluster is initialized before the Ray executor is chosen.
            initialize_ray_cluster(engine_config.parallel_config)
            from vllm.executor.ray_gpu_executor import RayGPUExecutor
            executor_class = RayGPUExecutor
        elif backend == "mp":
            from vllm.executor.multiproc_gpu_executor import (
                MultiprocessingGPUExecutor)
            executor_class = MultiprocessingGPUExecutor
        else:
            from vllm.executor.gpu_executor import GPUExecutor
            executor_class = GPUExecutor

        return cls(
            **engine_config.to_dict(),
            executor_class=executor_class,
            log_stats=not engine_args.disable_log_stats,
            usage_context=usage_context,
        )
382

383
384
385
386
387
    def __reduce__(self):
        """Refuse to pickle the engine.

        This keeps an LLMEngine from being captured by the closure that is
        used to initialize Ray worker actors.
        """
        raise RuntimeError("LLMEngine should not be pickled!")

388
389
390
391
392
393
    def __del__(self):
        # On garbage collection, shut the model executor down. getattr is
        # used because __init__ may have failed before ``model_executor``
        # was assigned.
        model_executor = getattr(self, "model_executor", None)
        if model_executor:
            model_executor.shutdown()

394
395
396
397
398
399
400
401
402
403
404
    # Default error text used when a tokenizer is requested from an engine
    # whose tokenizer was not initialized (``self.tokenizer is None``).
    MISSING_TOKENIZER_GROUP_MSG = ("Unable to get tokenizer because "
                                   "skip_tokenizer_init is True")

    def get_tokenizer_group(
            self,
            fail_msg: str = MISSING_TOKENIZER_GROUP_MSG) -> BaseTokenizerGroup:
        """Return the engine's tokenizer group.

        Raises:
            ValueError: carrying ``fail_msg`` when no tokenizer was set up.
        """
        group = self.tokenizer
        if group is not None:
            return group
        raise ValueError(fail_msg)

405
    def get_tokenizer(self) -> "PreTrainedTokenizer":
        """Return the base (non-LoRA) tokenizer; raises if none is set up."""
        group = self.get_tokenizer_group()
        return group.get_lora_tokenizer(None)
407

408
409
410
411
    # def get_tokenizer_for_seq(self,
    #                           sequence: Sequence) -> "PreTrainedTokenizer":
    #     return self.get_tokenizer_group().get_lora_tokenizer(
    #         sequence.lora_request)
412

413
    def _init_tokenizer(self, **tokenizer_init_kwargs) -> BaseTokenizerGroup:
        """Create the tokenizer group from the model/scheduler/LoRA configs.

        Keyword arguments override the config-derived defaults before they
        are handed to ``get_tokenizer_group`` together with the parallel
        config's tokenizer pool config.
        """
        model_cfg = self.model_config
        defaults = {
            "tokenizer_id": model_cfg.tokenizer,
            "enable_lora": bool(self.lora_config),
            "max_num_seqs": self.scheduler_config.max_num_seqs,
            "max_input_length": None,
            "tokenizer_mode": model_cfg.tokenizer_mode,
            "trust_remote_code": model_cfg.trust_remote_code,
            "revision": model_cfg.tokenizer_revision,
        }
        merged_kwargs = {**defaults, **tokenizer_init_kwargs}
        return get_tokenizer_group(self.parallel_config.tokenizer_pool_config,
                                   **merged_kwargs)
426

427
428
    def _verify_args(self) -> None:
        """Cross-check configs against each other (LoRA checks only if set)."""
        parallel_cfg = self.parallel_config
        self.model_config.verify_with_parallel_config(parallel_cfg)
        self.cache_config.verify_with_parallel_config(parallel_cfg)
        if not self.lora_config:
            return
        self.lora_config.verify_with_model_config(self.model_config)
        self.lora_config.verify_with_scheduler_config(self.scheduler_config)
434

435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
    def _get_eos_token_id(
            self, lora_request: Optional[LoRARequest]) -> Optional[int]:
        """Return the EOS id of the (LoRA-specific) tokenizer, or None.

        None is returned, with a warning, when no tokenizer was initialized.
        """
        if self.tokenizer is not None:
            lora_tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
            return lora_tokenizer.eos_token_id
        logger.warning("Using None for EOS token id because tokenizer "
                       "is not initialized")
        return None

    def _add_processed_request(
        self,
        request_id: str,
        processed_inputs: LLMInputs,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: float,
        lora_request: Optional[LoRARequest],
    ) -> None:
        """Wrap tokenized inputs in a SequenceGroup and hand it to the scheduler.

        One :class:`Sequence` is created (taking the next id from
        ``seq_counter``) before ``params`` is inspected. The group builder is
        then picked by whether ``params`` is SamplingParams or PoolingParams.

        Raises:
            ValueError: if ``params`` is neither type.
        """
        seq = Sequence(next(self.seq_counter), processed_inputs,
                       self.cache_config.block_size,
                       self._get_eos_token_id(lora_request), lora_request)

        if isinstance(params, SamplingParams):
            build_group = self._create_sequence_group_with_sampling
        elif isinstance(params, PoolingParams):
            build_group = self._create_sequence_group_with_pooling
        else:
            raise ValueError(
                "Either SamplingParams or PoolingParams must be provided.")

        seq_group = build_group(
            request_id,
            seq,
            params,
            arrival_time=arrival_time,
            lora_request=lora_request,
        )
        self.scheduler.add_seq_group(seq_group)

    def process_model_inputs(
        self,
        request_id: str,
        inputs: PromptInputs,
        lora_request: Optional[LoRARequest] = None,
    ) -> LLMInputs:
        """Normalize *inputs* into :class:`LLMInputs` carrying token ids.

        A bare string is treated as ``{"prompt": inputs}``. The prompt is
        tokenized only when ``prompt_token_ids`` is absent. That path needs a
        tokenizer and raises ValueError if the engine skipped tokenizer init.
        """
        normalized = {"prompt": inputs} if isinstance(inputs, str) else inputs

        if "prompt_token_ids" in normalized:
            token_ids = normalized["prompt_token_ids"]
        else:
            tokenizer_group = self.get_tokenizer_group(
                "prompts must be None if "
                "skip_tokenizer_init is True")
            token_ids = tokenizer_group.encode(request_id=request_id,
                                               prompt=normalized["prompt"],
                                               lora_request=lora_request)

        return LLMInputs(prompt_token_ids=token_ids,
                         prompt=normalized.get("prompt"),
                         multi_modal_data=normalized.get("multi_modal_data"))
506

507
508
509
    def add_request(
        self,
        request_id: str,
        inputs: PromptInputs,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
    ) -> None:
        """Add a request to the engine's request pool.

        The request is added to the request pool and will be processed by the
        scheduler as `engine.step()` is called. The exact scheduling policy is
        determined by the scheduler.

        Args:
            request_id: The unique ID of the request.
            inputs: The inputs to the LLM. See
                :class:`~vllm.inputs.PromptInputs`
                for more details about the format of each input.
            params: Parameters for sampling or pooling.
                :class:`~vllm.SamplingParams` for text generation.
                :class:`~vllm.PoolingParams` for pooling.
            arrival_time: The arrival time of the request. If None, the
                current wall-clock time from :func:`time.time` is used.
            lora_request: The LoRA adapter to serve this request with, if any.
                Only accepted when the engine was created with a LoRA config.

        Raises:
            ValueError: If ``lora_request`` is given but LoRA is not enabled;
                if the prompt must be tokenized but the engine was created
                with ``skip_tokenizer_init``; or if ``params`` is neither
                SamplingParams nor PoolingParams.

        Details:
            - Set arrival_time to the current time if it is None.
            - Tokenize the prompt if ``prompt_token_ids`` is not provided.
            - Create a single :class:`~vllm.Sequence` object.
            - Create a :class:`~vllm.SequenceGroup` object
              containing that :class:`~vllm.Sequence`.
            - Add the :class:`~vllm.SequenceGroup` object to the scheduler.

        Example:
            >>> # initialize engine
            >>> engine = LLMEngine.from_engine_args(engine_args)
            >>> # set request arguments
            >>> example_prompt = "Who is the president of the United States?"
            >>> sampling_params = SamplingParams(temperature=0.0)
            >>> request_id = 0
            >>>
            >>> # add the request to the engine
            >>> engine.add_request(
            ...    str(request_id),
            ...    example_prompt,
            ...    SamplingParams(temperature=0.0))
            >>> # continue the request processing
            >>> ...
        """
        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")
        if arrival_time is None:
            # Wall-clock time, matching the timestamps taken elsewhere in the
            # engine with time.time().
            arrival_time = time.time()

        processed_inputs = self.process_model_inputs(request_id=request_id,
                                                     inputs=inputs,
                                                     lora_request=lora_request)

        self._add_processed_request(
            request_id=request_id,
            processed_inputs=processed_inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
        )
573
574
575
576
577
578

    def _create_sequence_group_with_sampling(
        self,
        request_id: str,
        seq: Sequence,
        sampling_params: SamplingParams,
        arrival_time: float,
        lora_request: Optional[LoRARequest],
    ) -> SequenceGroup:
        """Build a SequenceGroup for a generation request.

        Rejects requests asking for more (prompt) logprobs than the model
        config's ``max_logprobs``. Works on a clone of ``sampling_params``,
        to which the sequence's EOS id and the model's generation-config
        fields are applied.

        Raises:
            ValueError: if too many logprobs are requested.
        """
        limit = self.get_model_config().max_logprobs

        def exceeds_limit(requested) -> bool:
            # Unset/zero values never exceed the limit.
            return bool(requested) and requested > limit

        if (exceeds_limit(sampling_params.logprobs)
                or exceeds_limit(sampling_params.prompt_logprobs)):
            raise ValueError(f"Cannot request more than "
                             f"{limit} logprobs.")

        # Defensive copy of SamplingParams, which are used by the sampler,
        # this doesn't deep-copy LogitsProcessor objects
        params = sampling_params.clone()
        # Register the EOS token as a stop token so min_tokens processing
        # can take it into account.
        if seq.eos_token_id is not None:
            params.all_stop_token_ids.add(seq.eos_token_id)
        params.update_from_generation_config(self.generation_config_fields)

        return SequenceGroup(request_id=request_id,
                             seqs=[seq],
                             arrival_time=arrival_time,
                             sampling_params=params,
                             lora_request=lora_request)

    def _create_sequence_group_with_pooling(
        self,
        request_id: str,
        seq: Sequence,
        pooling_params: PoolingParams,
        arrival_time: float,
        lora_request: Optional[LoRARequest],
    ) -> SequenceGroup:
        """Build a single-sequence SequenceGroup for a pooling request.

        The group receives a clone of ``pooling_params`` so later changes by
        the pooler do not leak back to the caller's object.
        """
        return SequenceGroup(request_id=request_id,
                             seqs=[seq],
                             arrival_time=arrival_time,
                             lora_request=lora_request,
                             pooling_params=pooling_params.clone())
628

Antoni Baum's avatar
Antoni Baum committed
629
630
    def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
        """Abort one or more requests by ID.

        Args:
            request_id: A single request ID or an iterable of IDs to abort.

        Details:
            - Delegates to
              :meth:`~vllm.core.scheduler.Scheduler.abort_seq_group`
              of class :class:`~vllm.core.scheduler.Scheduler`.

        Example:
            >>> # initialize engine and add a request with request_id
            >>> request_id = str(0)
            >>> # abort the request
            >>> engine.abort_request(request_id)
        """
        scheduler = self.scheduler
        scheduler.abort_seq_group(request_id)

648
649
650
651
    def get_model_config(self) -> ModelConfig:
        """Return the ModelConfig this engine was constructed with."""
        config = self.model_config
        return config

652
653
654
655
    def get_decoding_config(self) -> DecodingConfig:
        """Return the engine's DecodingConfig (a default one if none given)."""
        config = self.decoding_config
        return config

656
    def get_num_unfinished_requests(self) -> int:
        """Return how many sequence groups the scheduler has not finished."""
        scheduler = self.scheduler
        return scheduler.get_num_unfinished_seq_groups()

660
    def has_unfinished_requests(self) -> bool:
        """Return whether the scheduler still holds unfinished sequences."""
        scheduler = self.scheduler
        return scheduler.has_unfinished_seqs()

664
665
666
667
668
669
670
671
672
673
674
675
    def _process_sequence_group_outputs(
        self,
        seq_group: SequenceGroup,
        outputs: List[EmbeddingSequenceGroupOutput],
    ) -> None:
        """Store the first output's embeddings and finish every sequence.

        Used in embedding mode: each sequence in the group is marked
        ``FINISHED_STOPPED`` after this single pass.
        """
        first_output = outputs[0]
        seq_group.embeddings = first_output.embeddings
        for member in seq_group.get_seqs():
            member.status = SequenceStatus.FINISHED_STOPPED

676
    def _process_model_outputs(
        self,
        output: GenericSequence[Union[SamplerOutput, PoolerOutput]],
        scheduled_seq_groups: List[ScheduledSequenceGroup],
        ignored_seq_groups: List[SequenceGroup],
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Apply model outputs to the scheduled groups and build responses.

        Each scheduled group advances its computed-token count. In embedding
        mode its sequences are finished with the pooled result. Otherwise
        prompt logprobs are processed, plus sampled outputs when the group's
        metadata has ``do_sample`` set. Finished groups are then freed. One
        RequestOutput is produced per scheduled group, followed by one per
        ignored group.
        """
        now = time.time()

        # Regroup outputs from [step][sequence group] to
        # [sequence group][step].
        per_group_outputs = create_output_by_sequence_group(
            output, num_seq_groups=len(scheduled_seq_groups))

        for scheduled, group_outputs, group_meta in zip(
                scheduled_seq_groups, per_group_outputs,
                seq_group_metadata_list):
            seq_group = scheduled.seq_group
            seq_group.update_num_computed_tokens(scheduled.token_chunk_size)

            if self.model_config.embedding_mode:
                self._process_sequence_group_outputs(seq_group, group_outputs)
            else:
                self.output_processor.process_prompt_logprob(
                    seq_group, group_outputs)
                if group_meta.do_sample:
                    self.output_processor.process_outputs(
                        seq_group, group_outputs)

        # Free the finished sequence groups.
        self.scheduler.free_finished_seq_groups()

        request_outputs: List[Union[RequestOutput,
                                    EmbeddingRequestOutput]] = []
        for scheduled in scheduled_seq_groups:
            seq_group = scheduled.seq_group
            seq_group.maybe_set_first_token_time(now)
            request_outputs.append(RequestOutputFactory.create(seq_group))
        request_outputs.extend(
            RequestOutputFactory.create(ignored)
            for ignored in ignored_seq_groups)
        return request_outputs

726
    def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Performs one decoding iteration and returns newly generated results.

        .. figure:: https://i.imgur.com/sv2HssD.png
            :alt: Overview of the step function
            :align: center

            Overview of the step function.

        Details:
            - Step 1: Schedules the sequences to be executed in the next
              iteration and the token blocks to be swapped in/out/copy.

                - Depending on the scheduling policy,
                  sequences may be `preempted/reordered`.
                - A Sequence Group (SG) refers to a group of sequences
                  that are generated from the same prompt.

            - Step 2: Calls the distributed executor to execute the model.
              If the scheduler produced nothing to run, the model is not
              executed and the output is treated as empty.
            - Step 3: Processes the model output. This mainly includes:

                - Decodes the relevant outputs.
                - Updates the scheduled sequence groups with model outputs
                  based on its `sampling parameters` (`use_beam_search` or not).
                - Frees the finished sequence groups.

            - Finally, it logs stats and creates and returns the newly
              generated results. If no unfinished requests remain, the
              remote worker execution loop is stopped before returning.

        Example:
            >>> # Please see the example/ folder for more detailed examples.
            >>>
            >>> # initialize engine and request arguments
            >>> engine = LLMEngine.from_engine_args(engine_args)
            >>> example_inputs = [(0, "What is LLM?",
            >>>    SamplingParams(temperature=0.0))]
            >>>
            >>> # Start the engine with an event loop
            >>> while True:
            >>>     if example_inputs:
            >>>         req_id, prompt, sampling_params = example_inputs.pop(0)
            >>>         engine.add_request(str(req_id), prompt, sampling_params)
            >>>
            >>>     # continue the request processing
            >>>     request_outputs = engine.step()
            >>>     for request_output in request_outputs:
            >>>         if request_output.finished:
            >>>             # return or show the request output
            >>>             print(request_output)
            >>>
            >>>     if not (engine.has_unfinished_requests() or example_inputs):
            >>>         break
        """
        seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()

        if not scheduler_outputs.is_empty():
            execute_model_req = ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                running_queue_size=scheduler_outputs.running_queue_size,
            )
            output = self.model_executor.execute_model(
                execute_model_req=execute_model_req)
        else:
            # Nothing was scheduled this iteration, so skip the model call.
            # Ignored groups still get request outputs below.
            output = []

        request_outputs = self._process_model_outputs(
            output, scheduler_outputs.scheduled_seq_groups,
            scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)

        # Log stats.
        self.do_log_stats(scheduler_outputs, output)

        if not self.has_unfinished_requests():
            # Stop the execute model loop in parallel workers until there are
            # more requests to process. This avoids waiting indefinitely in
            # torch.distributed ops which may otherwise timeout, and unblocks
            # the RPC thread in the workers so that they can process any other
            # queued control plane messages, such as add/remove lora adapters.
            self.model_executor.stop_remote_worker_execution_loop()

        return request_outputs
Antoni Baum's avatar
Antoni Baum committed
810

811
812
813
814
    def do_log_stats(
            self,
            scheduler_outputs: Optional[SchedulerOutputs] = None,
            model_output: Optional[List[SamplerOutput]] = None) -> None:
        """Gather engine stats and pass them to the stat logger.

        This is a no-op unless ``self.log_stats`` is truthy. ``step`` calls
        it once per iteration with that iteration's outputs. Calling it
        with no arguments forces a log when no requests are active; only
        system-level stats (queue lengths, cache usage) are populated then.

        Args:
            scheduler_outputs: The scheduler's outputs for this iteration,
                if any.
            model_output: The model executor's outputs for this iteration,
                if any. Used for speculative-decoding metrics.
        """
        if not self.log_stats:
            return
        stats = self._get_stats(scheduler_outputs, model_output)
        self.stat_logger.log(stats)
819

820
821
822
823
824
825
826
827
828
829
830
831
    def _get_stats(
            self,
            scheduler_outputs: Optional[SchedulerOutputs],
            model_output: Optional[List[SamplerOutput]] = None) -> Stats:
        """Get Stats to be Logged to Prometheus.

        Args:
            scheduler_outputs: Optional, used to populate metrics related to
                the scheduled batch.
            model_output: Optional, used to emit speculative decoding metrics
                which are created by the workers.

        Returns:
            A ``Stats`` snapshot stamped with ``time.time()`` at call time.
            Per-iteration token counts and latencies, and request-level
            stats for finished groups, are only populated when
            ``scheduler_outputs`` is given.
        """
        now = time.time()

        # System State
        #   Scheduler State
        num_running_sys = len(self.scheduler.running)
        num_swapped_sys = len(self.scheduler.swapped)
        num_waiting_sys = len(self.scheduler.waiting)

        # KV Cache Usage, as the fraction of blocks in use (1 - free / total).
        # GPU and CPU are treated the same: the usage stays 0 when the block
        # count is unknown (None) or zero, so neither division can hit zero.
        num_total_gpu = self.cache_config.num_gpu_blocks
        gpu_cache_usage_sys = 0.
        if num_total_gpu is not None and num_total_gpu > 0:
            num_free_gpu = (
                self.scheduler.block_manager.get_num_free_gpu_blocks())
            gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu)

        num_total_cpu = self.cache_config.num_cpu_blocks
        cpu_cache_usage_sys = 0.
        if num_total_cpu is not None and num_total_cpu > 0:
            num_free_cpu = (
                self.scheduler.block_manager.get_num_free_cpu_blocks())
            cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu)

        # Iteration stats
        num_prompt_tokens_iter = 0
        num_generation_tokens_iter = 0
        time_to_first_tokens_iter: List[float] = []
        time_per_output_tokens_iter: List[float] = []
        num_preemption_iter = (0 if scheduler_outputs is None else
                               scheduler_outputs.preempted)

        # Request stats
        #   Latency
        time_e2e_requests: List[float] = []
        #   Metadata
        num_prompt_tokens_requests: List[int] = []
        num_generation_tokens_requests: List[int] = []
        best_of_requests: List[int] = []
        n_requests: List[int] = []
        finished_reason_requests: List[str] = []

        # NOTE: This loop assumes prefill seq_groups are before
        # decode seq_groups in scheduled_seq_groups.
        if scheduler_outputs is not None:
            # An integer count of sequences. It is kept as an int so that
            # num_generation_tokens_iter below stays an int as well.
            num_generation_tokens_from_prefill_groups = 0
            # NOTE: if scheduler_outputs.num_prefill_groups > 0 and
            # the len of scheduler_outputs.scheduled_seq_groups is !=
            # scheduler_outputs.num_prefill_groups, this means that
            # chunked prefills have been detected.

            for idx, scheduled_seq_group in enumerate(
                    scheduler_outputs.scheduled_seq_groups):
                group_was_prefill = idx < scheduler_outputs.num_prefill_groups
                seq_group = scheduled_seq_group.seq_group

                # NOTE: a seq_group that completed all of its prefill tokens
                # in the last iteration will have seq_group.is_prefill() = False
                # with group_was_prefill = True
                if group_was_prefill:
                    # Number of prompt tokens.
                    num_prompt_tokens_iter += (
                        scheduled_seq_group.token_chunk_size)

                    # If the seq_group just finished the prefill state
                    # get TTFT.
                    if not seq_group.is_prefill():
                        latency = seq_group.get_last_latency(now)
                        time_to_first_tokens_iter.append(latency)

                        # One generation token per finished prefill.
                        num_generation_tokens_from_prefill_groups += (
                            seq_group.num_seqs())
                else:
                    # TPOTs.
                    latency = seq_group.get_last_latency(now)
                    time_per_output_tokens_iter.append(latency)

                # Because of chunked prefill, we can have a single sequence
                # group that does multiple prompt_runs. To prevent logging
                # the same metadata more than once per request, we standardize
                # on logging request level information for finished requests,
                # which can only happen once.
                if seq_group.is_finished():
                    # Latency timings
                    time_e2e_requests.append(now -
                                             seq_group.metrics.arrival_time)

                    # Metadata
                    num_prompt_tokens_requests.append(
                        len(seq_group.prompt_token_ids))
                    num_generation_tokens_requests.extend([
                        seq.get_output_len()
                        for seq in seq_group.get_finished_seqs()
                    ])
                    # Groups without sampling params (e.g. embedding
                    # requests) contribute no best_of / n entries.
                    if seq_group.sampling_params is not None:
                        best_of_requests.append(
                            seq_group.sampling_params.best_of)
                        n_requests.append(seq_group.sampling_params.n)
                    finished_reason_requests.extend([
                        SequenceStatus.get_finished_reason(seq.status)
                        for seq in seq_group.get_finished_seqs()
                    ])

            # Number of generation tokens.
            #   num_batched_tokens equals the number of prompt_tokens plus the
            #   number of decode_tokens in a single iteration. So,
            #   num_generation_tokens = num_batched_tokens - num_prompt_tokens
            #   + num_generation_tokens_from_prefill_groups (since we generate
            #   one token on prefills on iters where the prefill finishes).
            num_generation_tokens_iter = (
                scheduler_outputs.num_batched_tokens - num_prompt_tokens_iter +
                num_generation_tokens_from_prefill_groups)

        # Spec decode, if enabled, emits specialized metrics from the worker in
        # sampler output.
        if model_output and (model_output[0].spec_decode_worker_metrics
                             is not None):
            spec_decode_metrics = model_output[0].spec_decode_worker_metrics
        else:
            spec_decode_metrics = None

        return Stats(
            now=now,

            # System stats
            #   Scheduler State
            num_running_sys=num_running_sys,
            num_swapped_sys=num_swapped_sys,
            num_waiting_sys=num_waiting_sys,
            #   KV Cache Usage (fraction of blocks in use)
            gpu_cache_usage_sys=gpu_cache_usage_sys,
            cpu_cache_usage_sys=cpu_cache_usage_sys,

            # Iteration stats
            num_prompt_tokens_iter=num_prompt_tokens_iter,
            num_generation_tokens_iter=num_generation_tokens_iter,
            time_to_first_tokens_iter=time_to_first_tokens_iter,
            time_per_output_tokens_iter=time_per_output_tokens_iter,
            spec_decode_metrics=spec_decode_metrics,
            num_preemption_iter=num_preemption_iter,

            # Request stats
            #   Latency
            time_e2e_requests=time_e2e_requests,
            #   Metadata
            num_prompt_tokens_requests=num_prompt_tokens_requests,
            num_generation_tokens_requests=num_generation_tokens_requests,
            best_of_requests=best_of_requests,
            n_requests=n_requests,
            finished_reason_requests=finished_reason_requests,
        )

983
    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Forward a LoRA adapter registration to the model executor.

        Args:
            lora_request: The LoRA request, passed through unchanged.

        Returns:
            The executor's ``add_lora`` result, annotated as ``bool``.
        """
        executor = self.model_executor
        return executor.add_lora(lora_request)
985
986

    def remove_lora(self, lora_id: int) -> bool:
        """Forward a LoRA adapter removal to the model executor.

        Args:
            lora_id: Integer identifier of the adapter, passed through
                unchanged.

        Returns:
            The executor's ``remove_lora`` result, annotated as ``bool``.
        """
        executor = self.model_executor
        return executor.remove_lora(lora_id)
988
989

    def list_loras(self) -> List[int]:
        """Return what the model executor's ``list_loras`` reports.

        Presumably these are the integer ids of the currently loaded LoRA
        adapters (the same ids ``remove_lora`` accepts). Confirm against the
        executor implementation.
        """
        executor = self.model_executor
        return executor.list_loras()
991
992

    def check_health(self) -> None:
        """Delegate a health check to the model executor.

        Nothing is caught here: any exception raised by the executor's
        ``check_health`` propagates to the caller.
        """
        executor = self.model_executor
        executor.check_health()