"vscode:/vscode.git/clone" did not exist on "8e340b4fa4efc3428b2f2fa0deb9a3140dbe255e"
llm_engine.py 32.6 KB
Newer Older
Antoni Baum's avatar
Antoni Baum committed
1
import time
2
from typing import Iterable, List, Optional, Type, Union
3

4
from transformers import GenerationConfig, PreTrainedTokenizer
5

6
import vllm
7
8
9
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
                         LoRAConfig, ModelConfig, ParallelConfig,
                         SchedulerConfig, SpeculativeConfig,
10
                         VisionLanguageConfig)
11
12
from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler,
                                 SchedulerOutputs)
Woosuk Kwon's avatar
Woosuk Kwon committed
13
from vllm.engine.arg_utils import EngineArgs
14
from vllm.engine.metrics import StatLogger, Stats
15
16
17
18
from vllm.engine.output_processor.interfaces import (
    SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.engine.output_processor.util import create_output_by_sequence_group
19
from vllm.executor.executor_base import ExecutorBase
20
from vllm.executor.ray_utils import initialize_ray_cluster
Woosuk Kwon's avatar
Woosuk Kwon committed
21
from vllm.logger import init_logger
22
from vllm.lora.request import LoRARequest
Woosuk Kwon's avatar
Woosuk Kwon committed
23
24
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
25
26
from vllm.sequence import (ExecuteModelRequest, MultiModalData, SamplerOutput,
                           Sequence, SequenceGroup, SequenceGroupMetadata,
27
                           SequenceStatus)
28
from vllm.transformers_utils.detokenizer import Detokenizer
29
30
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
                                                     get_tokenizer_group)
yhu422's avatar
yhu422 committed
31
32
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                  usage_message)
33
from vllm.utils import Counter
34
35

# Module-level logger for the engine.
logger = init_logger(__name__)
# Value handed to StatLogger as `local_interval` (seconds, per the name).
_LOCAL_LOGGING_INTERVAL_SEC = 5
Woosuk Kwon's avatar
Woosuk Kwon committed
37

38

39
40
41
42
43
44
45
46
47
48
49
def _load_generation_config_dict(model_config: ModelConfig):
    """Return the generation-config overrides shipped with the model.

    Loads a ``GenerationConfig`` for ``model_config.model`` at
    ``model_config.revision`` and returns its ``to_diff_dict()`` result.
    An ``OSError`` while doing so is treated as "no generation config
    available" and yields an empty dict.
    """
    try:
        generation_config = GenerationConfig.from_pretrained(
            model_config.model,
            revision=model_config.revision,
        )
        overrides = generation_config.to_diff_dict()
    except OSError:
        # Not found.
        return {}
    return overrides


50
class LLMEngine:
Zhuohan Li's avatar
Zhuohan Li committed
51
    """An LLM engine that receives requests and generates texts.
52

Woosuk Kwon's avatar
Woosuk Kwon committed
53
    This is the main class for the vLLM engine. It receives requests
    from clients and generates texts from the LLM. It includes a tokenizer, a
    language model (possibly distributed across multiple GPUs), and GPU memory
    space allocated for intermediate states (aka KV cache). This class utilizes
    iteration-level scheduling and efficient memory management to maximize the
    serving throughput.

    The `LLM` class wraps this class for offline batched inference and the
    `AsyncLLMEngine` class wraps this class for online serving.

    NOTE: The config arguments are derived from the `EngineArgs` class. For the
    comprehensive list of arguments, see `EngineArgs`.

    Args:
        model_config: The configuration related to the LLM model.
        cache_config: The configuration related to the KV cache memory
            management.
        parallel_config: The configuration related to distributed execution.
        scheduler_config: The configuration related to the request scheduler.
        device_config: The configuration related to the device.
        load_config: The configuration related to loading the model
            (e.g. download directory and load format).
        lora_config (Optional): The configuration related to serving multi-LoRA.
        vision_language_config (Optional): The configuration related to vision
            language models.
        speculative_config (Optional): The configuration related to speculative
            decoding.
        decoding_config (Optional): The configuration related to decoding. A
            default `DecodingConfig()` is used when this is None.
        executor_class: The model executor class for managing distributed
            execution.
        log_stats: Whether to log statistics.
        usage_context: Specified entry point, used for usage info collection.
82
    """
83
84
85
86
87
88
89

    def __init__(
        self,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        vision_language_config: Optional[VisionLanguageConfig],
        speculative_config: Optional[SpeculativeConfig],
        decoding_config: Optional[DecodingConfig],
        executor_class: Type[ExecutorBase],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
    ) -> None:
        """Set up the tokenizer, model executor, KV caches and scheduler.

        The config arguments are stored on the instance as given, except
        that a missing ``decoding_config`` is replaced by a default
        ``DecodingConfig()``. When ``model_config.skip_tokenizer_init`` is
        set, both ``self.tokenizer`` and ``self.detokenizer`` are ``None``.
        ``self.stat_logger`` only exists when ``log_stats`` is true.
        """
        logger.info(
            "Initializing an LLM engine (v%s) with config: "
            "model=%r, speculative_config=%r, tokenizer=%r, "
            "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
            "tokenizer_revision=%s, trust_remote_code=%s, dtype=%s, "
            "max_seq_len=%d, download_dir=%r, load_format=%s, "
            "tensor_parallel_size=%d, disable_custom_all_reduce=%s, "
            "quantization=%s, enforce_eager=%s, kv_cache_dtype=%s, "
            "quantization_param_path=%s, device_config=%s, "
            "decoding_config=%r, seed=%d)",
            vllm.__version__,
            model_config.model,
            speculative_config,
            model_config.tokenizer,
            model_config.skip_tokenizer_init,
            model_config.tokenizer_mode,
            model_config.revision,
            model_config.tokenizer_revision,
            model_config.trust_remote_code,
            model_config.dtype,
            model_config.max_model_len,
            load_config.download_dir,
            load_config.load_format,
            parallel_config.tensor_parallel_size,
            parallel_config.disable_custom_all_reduce,
            model_config.quantization,
            model_config.enforce_eager,
            cache_config.cache_dtype,
            model_config.quantization_param_path,
            device_config.device,
            decoding_config,
            model_config.seed,
        )
        # TODO(woosuk): Print more configs in debug mode.

        self.model_config = model_config
        self.cache_config = cache_config
        self.lora_config = lora_config
        self.vision_language_config = vision_language_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.speculative_config = speculative_config
        self.load_config = load_config
        self.decoding_config = decoding_config or DecodingConfig()
        self.log_stats = log_stats

        if not self.model_config.skip_tokenizer_init:
            self.tokenizer: BaseTokenizerGroup
            self._init_tokenizer()
            self.detokenizer = Detokenizer(self.tokenizer)
        else:
            self.detokenizer = None
            self.tokenizer = None

        self.seq_counter = Counter()
        self.generation_config_fields = _load_generation_config_dict(
            model_config)

        self.model_executor = executor_class(
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            lora_config=lora_config,
            vision_language_config=vision_language_config,
            speculative_config=speculative_config,
            load_config=load_config,
        )

        # Must run before the scheduler is created: it fills in
        # cache_config.num_gpu_blocks / num_cpu_blocks.
        self._initialize_kv_caches()

        # If usage stat is enabled, collect relevant info.
        if is_usage_stats_enabled():
            from vllm.model_executor.model_loader import (
                get_architecture_class_name)
            usage_message.report_usage(
                get_architecture_class_name(model_config),
                usage_context,
                extra_kvs={
                    # Common configuration
                    "dtype":
                    str(model_config.dtype),
                    "tensor_parallel_size":
                    parallel_config.tensor_parallel_size,
                    "block_size":
                    cache_config.block_size,
                    "gpu_memory_utilization":
                    cache_config.gpu_memory_utilization,

                    # Quantization
                    "quantization":
                    model_config.quantization,
                    "kv_cache_dtype":
                    cache_config.cache_dtype,

                    # Feature flags
                    "enable_lora":
                    bool(lora_config),
                    "enable_prefix_caching":
                    cache_config.enable_prefix_caching,
                    "enforce_eager":
                    model_config.enforce_eager,
                    "disable_custom_all_reduce":
                    parallel_config.disable_custom_all_reduce,
                })

        if self.tokenizer:
            # Ping the tokenizer to ensure liveness if it runs in a
            # different process.
            self.tokenizer.ping()

        # Create the scheduler.
        # NOTE: the cache_config here have been updated with the numbers of
        # GPU and CPU blocks, which are profiled in the distributed executor.
        self.scheduler = Scheduler(scheduler_config, cache_config, lora_config)

        # Metric Logging.
        if self.log_stats:
            self.stat_logger = StatLogger(
                local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
                labels=dict(model_name=model_config.model),
                max_model_len=self.model_config.max_model_len)
            self.stat_logger.info("cache_config", self.cache_config)

        # Create sequence output processor, e.g. for beam search or
        # speculative decoding.
        self.output_processor = (
            SequenceGroupOutputProcessor.create_output_processor(
                self.scheduler_config,
                self.detokenizer,
                self.scheduler,
                self.seq_counter,
                self.get_tokenizer_for_seq,
                stop_checker=StopChecker(
                    self.scheduler_config.max_model_len,
                    self.get_tokenizer_for_seq,
                ),
            ))

241
242
243
244
245
246
247
248
249
250
251
    def _initialize_kv_caches(self) -> None:
        """Initialize the KV cache in the worker(s).

        The workers will determine the number of blocks in both the GPU cache
        and the swap CPU cache.
        """
        num_gpu_blocks, num_cpu_blocks = (
            self.model_executor.determine_num_available_blocks())

        if self.cache_config.num_gpu_blocks_override is not None:
            num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override
252
253
254
255
            logger.info(
                "Overriding num_gpu_blocks=%d with "
                "num_gpu_blocks_override=%d", num_gpu_blocks,
                num_gpu_blocks_override)
256
257
258
259
260
261
262
            num_gpu_blocks = num_gpu_blocks_override

        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

        self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)

263
    @classmethod
    def from_engine_args(
        cls,
        engine_args: EngineArgs,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
    ) -> "LLMEngine":
        """Creates an LLM engine from the engine arguments.

        The executor class is chosen from the engine config: the Neuron or
        CPU executor by device type, otherwise the Ray GPU executor when
        ``parallel_config.worker_use_ray`` is set, otherwise the
        single-process GPU executor (which requires ``world_size == 1``).
        Executor modules are imported lazily so only the selected backend
        is loaded.
        """
        # Create the engine configs.
        engine_config = engine_args.create_engine_config()

        # Initialize the cluster and specify the executor class.
        if engine_config.device_config.device_type == "neuron":
            from vllm.executor.neuron_executor import NeuronExecutor
            executor_class = NeuronExecutor
        elif engine_config.device_config.device_type == "cpu":
            from vllm.executor.cpu_executor import CPUExecutor
            executor_class = CPUExecutor
        elif engine_config.parallel_config.worker_use_ray:
            initialize_ray_cluster(engine_config.parallel_config)
            from vllm.executor.ray_gpu_executor import RayGPUExecutor
            executor_class = RayGPUExecutor
        else:
            assert engine_config.parallel_config.world_size == 1, (
                "Ray is required if parallel_config.world_size > 1.")
            from vllm.executor.gpu_executor import GPUExecutor
            executor_class = GPUExecutor

        # Create the LLM engine.
        engine = cls(
            **engine_config.to_dict(),
            executor_class=executor_class,
            log_stats=not engine_args.disable_log_stats,
            usage_context=usage_context,
        )
        return engine
298

299
300
301
302
303
    def __reduce__(self):
        # This is to ensure that the LLMEngine is not referenced in
        # the closure used to initialize Ray worker actors
        raise RuntimeError("LLMEngine should not be pickled!")

304
305
306
307
308
309
    def __del__(self):
        # Shutdown model executor when engine is garbage collected
        # Use getattr since __init__ can fail before the field is set
        if model_executor := getattr(self, "model_executor", None):
            model_executor.shutdown()

310
    def get_tokenizer(self) -> "PreTrainedTokenizer":
        """Return the tokenizer obtained with no LoRA request.

        NOTE(review): ``self.tokenizer`` is ``None`` when the engine was
        built with ``skip_tokenizer_init``; calling this then raises
        ``AttributeError``.
        """
        return self.tokenizer.get_lora_tokenizer(None)
312
313
314

    def get_tokenizer_for_seq(self,
                              sequence: Sequence) -> "PreTrainedTokenizer":
        """Return the tokenizer matching ``sequence.lora_request``.

        NOTE(review): ``self.tokenizer`` is ``None`` when the engine was
        built with ``skip_tokenizer_init``; calling this then raises
        ``AttributeError``.
        """
        return self.tokenizer.get_lora_tokenizer(sequence.lora_request)

    def _init_tokenizer(self, **tokenizer_init_kwargs):
        """Create the tokenizer group and store it in ``self.tokenizer``.

        Default init arguments are derived from the model, LoRA and
        scheduler configs; entries in ``tokenizer_init_kwargs`` override
        defaults with the same key.
        """
        init_kwargs = dict(
            tokenizer_id=self.model_config.tokenizer,
            enable_lora=bool(self.lora_config),
            max_num_seqs=self.scheduler_config.max_num_seqs,
            max_input_length=None,
            tokenizer_mode=self.model_config.tokenizer_mode,
            trust_remote_code=self.model_config.trust_remote_code,
            revision=self.model_config.tokenizer_revision)
        init_kwargs.update(tokenizer_init_kwargs)
        self.tokenizer = get_tokenizer_group(
            self.parallel_config.tokenizer_pool_config, **init_kwargs)
329

330
331
    def _verify_args(self) -> None:
        self.model_config.verify_with_parallel_config(self.parallel_config)
332
        self.cache_config.verify_with_parallel_config(self.parallel_config)
333
334
335
336
        if self.lora_config:
            self.lora_config.verify_with_model_config(self.model_config)
            self.lora_config.verify_with_scheduler_config(
                self.scheduler_config)
337

338
339
340
341
342
343
344
345
346
347
348
349
350
351
    def encode_request(
        self,
        request_id: str,
        prompt: Optional[str],
        prompt_token_ids: Optional[List[int]] = None,
        lora_request: Optional[LoRARequest] = None,
    ):
        """Return the prompt token IDs for a request.

        Token IDs supplied by the caller are returned untouched. Otherwise
        ``prompt`` must not be None and is encoded with the engine's
        tokenizer, forwarding ``request_id`` and ``lora_request``.
        """
        if prompt_token_ids is not None:
            return prompt_token_ids

        assert prompt is not None
        return self.tokenizer.encode(request_id=request_id,
                                     prompt=prompt,
                                     lora_request=lora_request)

352
353
354
    def add_request(
        self,
        request_id: str,
        prompt: Optional[str],
        sampling_params: SamplingParams,
        prompt_token_ids: Optional[List[int]] = None,
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalData] = None,
    ) -> None:
        """Add a request to the engine's request pool.

        The request is added to the request pool and will be processed by the
        scheduler as `engine.step()` is called. The exact scheduling policy is
        determined by the scheduler.

        Args:
            request_id: The unique ID of the request.
            prompt: The prompt string. Can be None if prompt_token_ids is
                provided.
            sampling_params: The sampling parameters for text generation.
                The engine works on a clone; the caller's object is not
                modified.
            prompt_token_ids: The token IDs of the prompt. If None, we
                use the tokenizer to convert the prompts to token IDs.
            arrival_time: The arrival time of the request. If None, we use
                the current time as returned by `time.time()`.
            lora_request: The LoRA request to use for this request, if any.
                Only allowed when the engine has a LoRA config.
            multi_modal_data: Multi modal data per request.

        Raises:
            ValueError: If `lora_request` is given but LoRA is not enabled,
                or if more logprobs / prompt logprobs are requested than
                the model config's `max_logprobs`.

        Details:
            - Set arrival_time to the current time if it is None.
            - Set prompt_token_ids to the encoded prompt if it is None.
            - Create a single :class:`~vllm.Sequence` object for the prompt.
            - Create a :class:`~vllm.SequenceGroup` object
              containing that :class:`~vllm.Sequence`.
            - Add the :class:`~vllm.SequenceGroup` object to the scheduler.

        Example:
            >>> # initialize engine
            >>> engine = LLMEngine.from_engine_args(engine_args)
            >>> # set request arguments
            >>> example_prompt = "Who is the president of the United States?"
            >>> sampling_params = SamplingParams(temperature=0.0)
            >>> request_id = 0
            >>>
            >>> # add the request to the engine
            >>> engine.add_request(
            >>>    str(request_id),
            >>>    example_prompt,
            >>>    SamplingParams(temperature=0.0))
            >>> # continue the request processing
            >>> ...
        """
        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")
        max_logprobs = self.get_model_config().max_logprobs
        if (sampling_params.logprobs
                and sampling_params.logprobs > max_logprobs) or (
                    sampling_params.prompt_logprobs
                    and sampling_params.prompt_logprobs > max_logprobs):
            raise ValueError(f"Cannot request more than "
                             f"{max_logprobs} logprobs.")
        if arrival_time is None:
            arrival_time = time.time()
        prompt_token_ids = self.encode_request(
            request_id=request_id,
            prompt=prompt,
            prompt_token_ids=prompt_token_ids,
            lora_request=lora_request)

        # Create the sequences.
        block_size = self.cache_config.block_size
        seq_id = next(self.seq_counter)
        eos_token_id = None
        if self.tokenizer:
            eos_token_id = self.tokenizer.get_lora_tokenizer(
                lora_request).eos_token_id
        else:
            logger.warning("Use None for EOS token id because tokenizer is "
                           "not initialized")
        seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
                       eos_token_id, lora_request)

        # Defensive copy of SamplingParams, which are used by the sampler,
        # this doesn't deep-copy LogitsProcessor objects
        sampling_params = sampling_params.clone()
        # Add the eos token id into the sampling_params to support min_tokens
        # processing
        if seq.eos_token_id is not None:
            sampling_params.all_stop_token_ids.add(seq.eos_token_id)
        sampling_params.update_from_generation_config(
            self.generation_config_fields)

        # Create the sequence group.
        seq_group = SequenceGroup(request_id, [seq], sampling_params,
                                  arrival_time, lora_request, multi_modal_data)

        # Add the sequence group to the scheduler.
        self.scheduler.add_seq_group(seq_group)

Antoni Baum's avatar
Antoni Baum committed
451
452
    def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
        """Aborts a request(s) with the given ID.

        Args:
            request_id: The ID(s) of the request to abort.

        Details:
            - Refer to the
              :meth:`~vllm.core.scheduler.Scheduler.abort_seq_group`
              from class :class:`~vllm.core.scheduler.Scheduler`.

        Example:
            >>> # initialize engine and add a request with request_id
            >>> request_id = str(0)
            >>> # abort the request
            >>> engine.abort_request(request_id)
        """
        self.scheduler.abort_seq_group(request_id)

470
471
472
473
    def get_model_config(self) -> ModelConfig:
        """Return the model configuration held by this engine."""
        model_config = self.model_config
        return model_config

474
475
476
477
    def get_decoding_config(self) -> DecodingConfig:
        """Return the decoding configuration held by this engine."""
        decoding_config = self.decoding_config
        return decoding_config

478
    def get_num_unfinished_requests(self) -> int:
        """Gets the number of unfinished requests.

        This is the scheduler's count of unfinished sequence groups.
        """
        return self.scheduler.get_num_unfinished_seq_groups()

482
    def has_unfinished_requests(self) -> bool:
        """Returns True if there are unfinished requests.

        The answer is taken from the scheduler's ``has_unfinished_seqs()``.
        """
        return self.scheduler.has_unfinished_seqs()

486
    def _process_model_outputs(
        self,
        output: List[SamplerOutput],
        scheduled_seq_groups: List[ScheduledSequenceGroup],
        ignored_seq_groups: List[SequenceGroup],
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> List[RequestOutput]:
        """Apply the model output to the sequences in the scheduled seq groups.

        Returns RequestOutputs that can be returned to the client: one per
        scheduled sequence group, followed by one per ignored sequence group.
        """

        now = time.time()

        # Organize outputs by [sequence group][step] instead of
        # [step][sequence group].
        output_by_sequence_group = create_output_by_sequence_group(
            sampler_outputs=output, num_seq_groups=len(scheduled_seq_groups))

        # Update the scheduled sequence groups with the model outputs.
        for scheduled_seq_group, outputs, seq_group_meta in zip(
                scheduled_seq_groups, output_by_sequence_group,
                seq_group_metadata_list):
            seq_group = scheduled_seq_group.seq_group
            seq_group.update_num_computed_tokens(
                scheduled_seq_group.token_chunk_size)

            self.output_processor.process_prompt_logprob(seq_group, outputs)
            # Sampled tokens are only applied for groups marked do_sample.
            if seq_group_meta.do_sample:
                self.output_processor.process_outputs(seq_group, outputs)

        # Free the finished sequence groups.
        self.scheduler.free_finished_seq_groups()

        # Create the outputs.
        request_outputs: List[RequestOutput] = []
        for scheduled_seq_group in scheduled_seq_groups:
            seq_group = scheduled_seq_group.seq_group
            seq_group.maybe_set_first_token_time(now)
            request_output = RequestOutput.from_seq_group(seq_group)
            request_outputs.append(request_output)
        for seq_group in ignored_seq_groups:
            request_output = RequestOutput.from_seq_group(seq_group)
            request_outputs.append(request_output)
        return request_outputs

Antoni Baum's avatar
Antoni Baum committed
532
533
534
    def step(self) -> List[RequestOutput]:
        """Performs one decoding iteration and returns newly generated results.

        .. figure:: https://i.imgur.com/sv2HssD.png
            :alt: Overview of the step function
            :align: center

            Overview of the step function.

        Details:
            - Step 1: Schedules the sequences to be executed in the next
              iteration and the token blocks to be swapped in/out/copy.

                - Depending on the scheduling policy,
                  sequences may be `preempted/reordered`.
                - A Sequence Group (SG) refer to a group of sequences
                  that are generated from the same prompt.

            - Step 2: Calls the distributed executor to execute the model.
            - Step 3: Processes the model output. This mainly includes:

                - Decodes the relevant outputs.
                - Updates the scheduled sequence groups with model outputs
                  based on its `sampling parameters` (`use_beam_search` or not).
                - Frees the finished sequence groups.

            - Finally, it creates and returns the newly generated results.

        Example:
            >>> # Please see the example/ folder for more detailed examples.
            >>>
            >>> # initialize engine and request arguments
            >>> engine = LLMEngine.from_engine_args(engine_args)
            >>> example_inputs = [(0, "What is LLM?",
            >>>    SamplingParams(temperature=0.0))]
            >>>
            >>> # Start the engine with an event loop
            >>> while True:
            >>>     if example_inputs:
            >>>         req_id, prompt, sampling_params = example_inputs.pop(0)
            >>>         engine.add_request(str(req_id), prompt, sampling_params)
            >>>
            >>>     # continue the request processing
            >>>     request_outputs = engine.step()
            >>>     for request_output in request_outputs:
            >>>         if request_output.finished:
            >>>             # return or show the request output
            >>>
            >>>     if not (engine.has_unfinished_requests() or example_inputs):
            >>>         break
        """
        seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()

        if not scheduler_outputs.is_empty():
            execute_model_req = ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                running_queue_size=scheduler_outputs.running_queue_size,
            )
            output = self.model_executor.execute_model(
                execute_model_req=execute_model_req)
        else:
            # Nothing was scheduled, so the executor is not invoked.
            output = []

        request_outputs = self._process_model_outputs(
            output, scheduler_outputs.scheduled_seq_groups,
            scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)

        # Log stats.
        self.do_log_stats(scheduler_outputs, output)

        return request_outputs
Antoni Baum's avatar
Antoni Baum committed
607

608
609
610
611
    def do_log_stats(
            self,
            scheduler_outputs: Optional[SchedulerOutputs] = None,
            model_output: Optional[List[SamplerOutput]] = None) -> None:
        """Forced log when no requests active.

        Does nothing unless the engine was created with ``log_stats``
        enabled; otherwise the stats built by ``_get_stats`` are handed to
        the stat logger.
        """
        if self.log_stats:
            self.stat_logger.log(
                self._get_stats(scheduler_outputs, model_output))
616

617
618
619
620
621
622
623
624
625
626
627
628
    def _get_stats(
            self,
            scheduler_outputs: Optional[SchedulerOutputs],
            model_output: Optional[List[SamplerOutput]] = None) -> Stats:
        """Get Stats to be Logged to Prometheus.

        Args:
            scheduler_outputs: Optional, used to populate metrics related to
                the scheduled batch,
            model_output: Optional, used to emit speculative decoding metrics
                which are created by the workers.
        """
629
        now = time.time()
Woosuk Kwon's avatar
Woosuk Kwon committed
630

631
632
633
634
635
636
637
        # System State
        #   Scheduler State
        num_running_sys = len(self.scheduler.running)
        num_swapped_sys = len(self.scheduler.swapped)
        num_waiting_sys = len(self.scheduler.waiting)

        # KV Cache Usage in %
638
639
        num_total_gpu = self.cache_config.num_gpu_blocks
        num_free_gpu = self.scheduler.block_manager.get_num_free_gpu_blocks()
640
        gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu)
Woosuk Kwon's avatar
Woosuk Kwon committed
641

642
        num_total_cpu = self.cache_config.num_cpu_blocks
643
        cpu_cache_usage_sys = 0.
644
645
646
        if num_total_cpu > 0:
            num_free_cpu = self.scheduler.block_manager.get_num_free_cpu_blocks(
            )
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
            cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu)

        # Iteration stats
        num_prompt_tokens_iter = 0
        num_generation_tokens_iter = 0
        time_to_first_tokens_iter: List[float] = []
        time_per_output_tokens_iter: List[float] = []

        # Request stats
        #   Latency
        time_e2e_requests: List[float] = []
        #   Metadata
        num_prompt_tokens_requests: List[int] = []
        num_generation_tokens_requests: List[int] = []
        best_of_requests: List[int] = []
        n_requests: List[int] = []
        finished_reason_requests: List[str] = []

        # NOTE: This loop assumes prefill seq_groups are before
        # decode seq_groups in scheduled_seq_groups.
667
        if scheduler_outputs is not None:
668
            num_generation_tokens_from_prefill_groups = 0.
669
670
671
672
            # NOTE: if scheduler_outputs.num_prefill_groups > 0 and
            # the len of scheduler_outputs.scheduled_seq_groups is !=
            # scheduler_outputs.num_prefill_groups, this means that
            # chunked prefills have been detected.
673
674
675
676

            for idx, scheduled_seq_group in enumerate(
                    scheduler_outputs.scheduled_seq_groups):
                group_was_prefill = idx < scheduler_outputs.num_prefill_groups
677
                seq_group = scheduled_seq_group.seq_group
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705

                # NOTE: a seq_group that completed all of its prefill tokens
                # in the last iteration will have seq_group.is_prefill() = False
                # with group_was_prefill = True
                if group_was_prefill:
                    # Number of prompt tokens.
                    num_prompt_tokens_iter += (
                        scheduled_seq_group.token_chunk_size)

                    # If the seq_group just finished the prefill state
                    # get TTFT.
                    if not seq_group.is_prefill():
                        latency = seq_group.get_last_latency(now)
                        time_to_first_tokens_iter.append(latency)

                        # One generation token per finished prefill.
                        num_generation_tokens_from_prefill_groups += (
                            seq_group.num_seqs())
                else:
                    # TPOTs.
                    latency = seq_group.get_last_latency(now)
                    time_per_output_tokens_iter.append(latency)

                # Because of chunked prefill, we can have a single sequence
                # group that does multiple prompt_runs. To prevent logging
                # the same metadata more than once per request, we standardize
                # on logging request level information for finished requests,
                # which can only happen once.
706
                if seq_group.is_finished():
707
                    # Latency timings
708
709
                    time_e2e_requests.append(now -
                                             seq_group.metrics.arrival_time)
710

711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
                    # Metadata
                    num_prompt_tokens_requests.append(
                        len(seq_group.prompt_token_ids))
                    num_generation_tokens_requests.extend([
                        seq.get_output_len()
                        for seq in seq_group.get_finished_seqs()
                    ])
                    best_of_requests.append(seq_group.sampling_params.best_of)
                    n_requests.append(seq_group.sampling_params.n)
                    finished_reason_requests.extend([
                        SequenceStatus.get_finished_reason(seq.status)
                        for seq in seq_group.get_finished_seqs()
                    ])

            # Number of generation tokens.
            #   num_batched_tokens equals the number of prompt_tokens plus the
            #   number of decode_tokens in a single iteration. So,
            #   num_generation_tokens = num_batched_tokens - num_prompt_tokens
            #   + num_generation_tokens_from_prefill_groups (since we generate
            #   one token on prefills on iters where the prefill finishes).
            num_generation_tokens_iter = (
                scheduler_outputs.num_batched_tokens - num_prompt_tokens_iter +
                num_generation_tokens_from_prefill_groups)
734

735
736
737
738
739
740
741
742
        # Spec decode, if enabled, emits specialized metrics from the worker in
        # sampler output.
        if model_output and (model_output[0].spec_decode_worker_metrics
                             is not None):
            spec_decode_metrics = model_output[0].spec_decode_worker_metrics
        else:
            spec_decode_metrics = None

743
744
        return Stats(
            now=now,
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759

            # System stats
            #   Scheduler State
            num_running_sys=num_running_sys,
            num_swapped_sys=num_swapped_sys,
            num_waiting_sys=num_waiting_sys,
            #   KV Cache Usage in %
            gpu_cache_usage_sys=gpu_cache_usage_sys,
            cpu_cache_usage_sys=cpu_cache_usage_sys,

            # Iteration stats
            num_prompt_tokens_iter=num_prompt_tokens_iter,
            num_generation_tokens_iter=num_generation_tokens_iter,
            time_to_first_tokens_iter=time_to_first_tokens_iter,
            time_per_output_tokens_iter=time_per_output_tokens_iter,
760
            spec_decode_metrics=spec_decode_metrics,
761
762
763
764
765
766
767
768
769
770

            # Request stats
            #   Latency
            time_e2e_requests=time_e2e_requests,
            #   Metadata
            num_prompt_tokens_requests=num_prompt_tokens_requests,
            num_generation_tokens_requests=num_generation_tokens_requests,
            best_of_requests=best_of_requests,
            n_requests=n_requests,
            finished_reason_requests=finished_reason_requests,
771
772
        )

773
    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Forward a LoRA add request to the model executor.

        Args:
            lora_request: Passed unchanged to
                ``self.model_executor.add_lora``.

        Returns:
            The value returned by the executor's ``add_lora`` (annotated
            ``bool``; presumably a success flag -- confirm against the
            executor implementation).
        """
        # Pure delegation: the engine performs no validation or bookkeeping
        # of its own here.
        return self.model_executor.add_lora(lora_request)
775
776

    def remove_lora(self, lora_id: int) -> bool:
        """Forward a LoRA removal request to the model executor.

        Args:
            lora_id: Integer identifier passed unchanged to
                ``self.model_executor.remove_lora``.

        Returns:
            The value returned by the executor's ``remove_lora`` (annotated
            ``bool``; presumably a success flag -- confirm against the
            executor implementation).
        """
        # Pure delegation: no engine-side state is touched here.
        return self.model_executor.remove_lora(lora_id)
778
779

    def list_loras(self) -> List[int]:
        """Return the model executor's list of LoRA ids.

        Returns:
            The value returned by ``self.model_executor.list_loras()``
            (annotated ``List[int]``), passed through without copying or
            filtering.
        """
        # Pure delegation: the executor is the single source of this list.
        return self.model_executor.list_loras()
781
782

    def check_health(self) -> None:
        """Run the model executor's health check.

        Delegates to ``self.model_executor.check_health()`` and discards
        its return value. No exceptions are caught here, so anything the
        executor raises propagates to the caller.
        """
        self.model_executor.check_health()