llm_engine.py 28.7 KB
Newer Older
Antoni Baum's avatar
Antoni Baum committed
1
import time
2
from typing import Iterable, List, Optional, Type, Union
3

4
from transformers import GenerationConfig, PreTrainedTokenizer
5

6
import vllm
7
8
9
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
                         LoRAConfig, ModelConfig, ParallelConfig,
                         SchedulerConfig, SpeculativeConfig,
10
                         VisionLanguageConfig)
Antoni Baum's avatar
Antoni Baum committed
11
from vllm.core.scheduler import Scheduler, SchedulerOutputs
Woosuk Kwon's avatar
Woosuk Kwon committed
12
from vllm.engine.arg_utils import EngineArgs
13
from vllm.engine.metrics import StatLogger, Stats
14
15
16
17
from vllm.engine.output_processor.interfaces import (
    SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.engine.output_processor.util import create_output_by_sequence_group
18
from vllm.executor.executor_base import ExecutorBase
19
from vllm.executor.ray_utils import initialize_ray_cluster
Woosuk Kwon's avatar
Woosuk Kwon committed
20
from vllm.logger import init_logger
21
from vllm.lora.request import LoRARequest
Woosuk Kwon's avatar
Woosuk Kwon committed
22
23
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
24
from vllm.sequence import (MultiModalData, SamplerOutput, Sequence,
25
                           SequenceGroup, SequenceStage)
26
from vllm.transformers_utils.detokenizer import Detokenizer
27
28
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
                                                     get_tokenizer_group)
yhu422's avatar
yhu422 committed
29
30
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                  usage_message)
31
from vllm.utils import Counter
32
33

logger = init_logger(__name__)

# Passed to StatLogger as ``local_interval`` when stats logging is enabled;
# the unit is seconds, per the name.
_LOCAL_LOGGING_INTERVAL_SEC = 5

def _load_generation_config_dict(model_config: ModelConfig):
    """Return the model's generation-config overrides as a plain dict.

    Loads ``GenerationConfig`` for ``model_config.model`` at
    ``model_config.revision`` and returns its ``to_diff_dict()``. If loading
    raises ``OSError`` (no generation config found), an empty dict is
    returned instead, meaning "no overrides".
    """
    try:
        overrides = GenerationConfig.from_pretrained(
            model_config.model,
            revision=model_config.revision,
        ).to_diff_dict()
    except OSError:
        # No generation config available for this model/revision.
        overrides = {}
    return overrides


class LLMEngine:
    """An LLM engine that receives requests and generates text.

    This is the main class for the vLLM engine. It receives requests
    from clients and generates text from the LLM. It includes a tokenizer, a
    language model (possibly distributed across multiple GPUs), and GPU memory
    space allocated for intermediate states (aka KV cache). This class utilizes
    iteration-level scheduling and efficient memory management to maximize the
    serving throughput.

    The `LLM` class wraps this class for offline batched inference and the
    `AsyncLLMEngine` class wraps this class for online serving.

    NOTE: The config arguments are derived from the `EngineArgs` class. For the
    comprehensive list of arguments, see `EngineArgs`.

    Args:
        model_config: The configuration related to the LLM model.
        cache_config: The configuration related to the KV cache memory
            management.
        parallel_config: The configuration related to distributed execution.
        scheduler_config: The configuration related to the request scheduler.
        device_config: The configuration related to the device.
        load_config: The configuration related to loading the model weights
            (e.g. download directory and load format).
        lora_config (Optional): The configuration related to serving multi-LoRA.
        vision_language_config (Optional): The configuration related to vision
            language models.
        speculative_config (Optional): The configuration related to speculative
            decoding.
        decoding_config (Optional): The configuration related to decoding. A
            default-constructed `DecodingConfig` is used when None.
        executor_class: The model executor class for managing distributed
            execution.
        log_stats: Whether to log statistics.
        usage_context: The entry point the engine was created from; used for
            usage info collection.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        vision_language_config: Optional[VisionLanguageConfig],
        speculative_config: Optional[SpeculativeConfig],
        decoding_config: Optional[DecodingConfig],
        executor_class: Type[ExecutorBase],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
    ) -> None:
        # Setup order matters: the executor is built from the configs, the KV
        # cache is sized by the executor (writing block counts into
        # cache_config) before the Scheduler reads cache_config, and the
        # output processor is wired to the scheduler, detokenizer and counter.
        logger.info(
            f"Initializing an LLM engine (v{vllm.__version__}) with config: "
            f"model={model_config.model!r}, "
            f"speculative_config={speculative_config!r}, "
            f"tokenizer={model_config.tokenizer!r}, "
            f"skip_tokenizer_init={model_config.skip_tokenizer_init}, "
            f"tokenizer_mode={model_config.tokenizer_mode}, "
            f"revision={model_config.revision}, "
            f"tokenizer_revision={model_config.tokenizer_revision}, "
            f"trust_remote_code={model_config.trust_remote_code}, "
            f"dtype={model_config.dtype}, "
            f"max_seq_len={model_config.max_model_len}, "
            f"download_dir={load_config.download_dir!r}, "
            f"load_format={load_config.load_format}, "
            f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
            f"disable_custom_all_reduce="
            f"{parallel_config.disable_custom_all_reduce}, "
            f"quantization={model_config.quantization}, "
            f"enforce_eager={model_config.enforce_eager}, "
            f"kv_cache_dtype={cache_config.cache_dtype}, "
            f"quantization_param_path={model_config.quantization_param_path}, "
            f"device_config={device_config.device}, "
            f"decoding_config={decoding_config!r}, "
            f"seed={model_config.seed})")
        # TODO(woosuk): Print more configs in debug mode.

        # --- Configuration ------------------------------------------------
        self.model_config = model_config
        self.cache_config = cache_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.load_config = load_config
        self.lora_config = lora_config
        self.vision_language_config = vision_language_config
        self.speculative_config = speculative_config
        # A missing decoding config is replaced by a default-constructed one.
        self.decoding_config = decoding_config or DecodingConfig()
        self.log_stats = log_stats

        # --- Tokenization -------------------------------------------------
        # With skip_tokenizer_init, no tokenizer group or detokenizer is
        # created and both attributes are None.
        if not self.model_config.skip_tokenizer_init:
            self.tokenizer: BaseTokenizerGroup
            self._init_tokenizer()
            self.detokenizer = Detokenizer(self.tokenizer)
        else:
            self.tokenizer = None
            self.detokenizer = None

        self.seq_counter = Counter()
        self.generation_config_fields = _load_generation_config_dict(
            model_config)

        # --- Model execution and KV cache ---------------------------------
        self.model_executor = executor_class(
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            lora_config=lora_config,
            vision_language_config=vision_language_config,
            speculative_config=speculative_config,
            load_config=load_config,
        )
        self._initialize_kv_caches()

        # --- Usage reporting (only when usage stats are enabled) ----------
        if is_usage_stats_enabled():
            from vllm.model_executor.model_loader import (
                get_architecture_class_name)
            architecture = get_architecture_class_name(model_config)
            usage_kvs = {
                # Common configuration
                "dtype": str(model_config.dtype),
                "tensor_parallel_size": parallel_config.tensor_parallel_size,
                "block_size": cache_config.block_size,
                "gpu_memory_utilization": cache_config.gpu_memory_utilization,
                # Quantization
                "quantization": model_config.quantization,
                "kv_cache_dtype": cache_config.cache_dtype,
                # Feature flags
                "enable_lora": bool(lora_config),
                "enable_prefix_caching": cache_config.enable_prefix_caching,
                "enforce_eager": model_config.enforce_eager,
                "disable_custom_all_reduce":
                parallel_config.disable_custom_all_reduce,
            }
            usage_message.report_usage(architecture,
                                       usage_context,
                                       extra_kvs=usage_kvs)

        if self.tokenizer:
            # The tokenizer group may run in a different process; ping it to
            # make sure it is alive.
            self.tokenizer.ping()

        # --- Scheduling ---------------------------------------------------
        # cache_config now carries the GPU/CPU block counts that were
        # profiled by the executor in _initialize_kv_caches().
        self.scheduler = Scheduler(scheduler_config, cache_config, lora_config)

        # --- Metrics ------------------------------------------------------
        if self.log_stats:
            self.stat_logger = StatLogger(
                local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
                labels={"model_name": model_config.model})
            self.stat_logger.info("cache_config", self.cache_config)

        # --- Output processing --------------------------------------------
        # Applies sampler output to sequence groups, e.g. for beam search or
        # speculative decoding.
        self.output_processor = (
            SequenceGroupOutputProcessor.create_output_processor(
                self.scheduler_config,
                self.detokenizer,
                self.scheduler,
                self.seq_counter,
                self.get_tokenizer_for_seq,
                stop_checker=StopChecker(
                    self.scheduler_config.max_model_len,
                    self.get_tokenizer_for_seq,
                ),
            ))

    def _initialize_kv_caches(self) -> None:
        """Size and allocate the KV cache on the worker(s).

        The executor reports how many GPU blocks and swap (CPU) blocks fit.
        An explicit ``cache_config.num_gpu_blocks_override`` replaces the GPU
        figure. The final counts are recorded on ``self.cache_config`` and
        then handed to the executor to build the cache.
        """
        num_gpu_blocks, num_cpu_blocks = (
            self.model_executor.determine_num_available_blocks())

        num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override
        if num_gpu_blocks_override is not None:
            # NOTE: the `=` f-string specifiers print the variable names, so
            # these names are part of the log output.
            logger.info(f"Overriding {num_gpu_blocks=} with "
                        f"{num_gpu_blocks_override=}")
            num_gpu_blocks = num_gpu_blocks_override

        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks
        self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)

    @classmethod
    def from_engine_args(
        cls,
        engine_args: EngineArgs,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
    ) -> "LLMEngine":
        """Build an engine from ``EngineArgs``.

        The executor class is chosen from the device type (neuron, cpu) and,
        for other devices, from whether Ray workers are requested. Executor
        modules are imported inside their branch, so only the selected one
        is loaded.
        """
        engine_config = engine_args.create_engine_config()
        device_type = engine_config.device_config.device_type
        parallel_config = engine_config.parallel_config

        if device_type == "neuron":
            from vllm.executor.neuron_executor import NeuronExecutor
            executor_class = NeuronExecutor
        elif device_type == "cpu":
            from vllm.executor.cpu_executor import CPUExecutor
            executor_class = CPUExecutor
        elif parallel_config.worker_use_ray:
            # The Ray cluster must exist before the Ray executor is created.
            initialize_ray_cluster(parallel_config)
            from vllm.executor.ray_gpu_executor import RayGPUExecutor
            executor_class = RayGPUExecutor
        else:
            assert parallel_config.world_size == 1, (
                "Ray is required if parallel_config.world_size > 1.")
            from vllm.executor.gpu_executor import GPUExecutor
            executor_class = GPUExecutor

        return cls(
            **engine_config.to_dict(),
            executor_class=executor_class,
            log_stats=not engine_args.disable_log_stats,
            usage_context=usage_context,
        )

    def __reduce__(self):
        """Refuse to be pickled.

        Raising here guarantees the engine is never captured in the closure
        used to initialize Ray worker actors.
        """
        raise RuntimeError("LLMEngine should not be pickled!")

    def get_tokenizer(self) -> "PreTrainedTokenizer":
        """Return the tokenizer the tokenizer group uses without a LoRA
        request."""
        tokenizer_group = self.tokenizer
        return tokenizer_group.get_lora_tokenizer(None)

    def get_tokenizer_for_seq(self,
                              sequence: Sequence) -> "PreTrainedTokenizer":
        """Return the tokenizer matching ``sequence``'s LoRA request."""
        lora_request = sequence.lora_request
        return self.tokenizer.get_lora_tokenizer(lora_request)

    def _init_tokenizer(self, **tokenizer_init_kwargs):
        """Create ``self.tokenizer`` via ``get_tokenizer_group``.

        Defaults come from the model, LoRA and scheduler configs. Any keyword
        given in ``tokenizer_init_kwargs`` overrides the matching default.
        """
        default_kwargs = {
            "tokenizer_id": self.model_config.tokenizer,
            "enable_lora": bool(self.lora_config),
            "max_num_seqs": self.scheduler_config.max_num_seqs,
            "max_input_length": None,
            "tokenizer_mode": self.model_config.tokenizer_mode,
            "trust_remote_code": self.model_config.trust_remote_code,
            "revision": self.model_config.tokenizer_revision,
        }
        self.tokenizer = get_tokenizer_group(
            self.parallel_config.tokenizer_pool_config,
            **{
                **default_kwargs,
                **tokenizer_init_kwargs
            })

    def _verify_args(self) -> None:
        """Cross-check the stored configs against each other.

        NOTE(review): nothing in the visible engine code calls this method;
        confirm where (or whether) config verification actually happens.
        """
        parallel_config = self.parallel_config
        self.model_config.verify_with_parallel_config(parallel_config)
        self.cache_config.verify_with_parallel_config(parallel_config)
        if not self.lora_config:
            return
        self.lora_config.verify_with_model_config(self.model_config)
        self.lora_config.verify_with_scheduler_config(self.scheduler_config)

    def encode_request(
        self,
        request_id: str,
        prompt: Optional[str],
        prompt_token_ids: Optional[List[int]] = None,
        lora_request: Optional[LoRARequest] = None,
    ) -> List[int]:
        """Return the prompt's token IDs, tokenizing ``prompt`` if needed.

        Args:
            request_id: ID of the request; forwarded to the tokenizer
                group's ``encode``.
            prompt: The prompt text. Only used when ``prompt_token_ids`` is
                None.
            prompt_token_ids: A pre-tokenized prompt. Returned unchanged
                when given.
            lora_request: LoRA request whose tokenizer should be used.

        Returns:
            The prompt token IDs.

        Raises:
            ValueError: If neither ``prompt`` nor ``prompt_token_ids`` is
                given, or if ``prompt`` must be tokenized but the engine has
                no tokenizer (it was created with ``skip_tokenizer_init``).
        """
        if prompt_token_ids is not None:
            return prompt_token_ids
        # Explicit checks instead of `assert` (stripped under -O) and instead
        # of an opaque AttributeError on `None.encode`.
        if prompt is None:
            raise ValueError(
                "Either prompt or prompt_token_ids must be provided.")
        if self.tokenizer is None:
            raise ValueError(
                "Cannot tokenize a text prompt because the tokenizer was not "
                "initialized (skip_tokenizer_init=True); pass "
                "prompt_token_ids instead.")
        return self.tokenizer.encode(request_id=request_id,
                                     prompt=prompt,
                                     lora_request=lora_request)

    def add_request(
        self,
        request_id: str,
        prompt: Optional[str],
        sampling_params: SamplingParams,
        prompt_token_ids: Optional[List[int]] = None,
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalData] = None,
    ) -> None:
        """Add a request to the engine's request pool.

        The request is added to the request pool and will be processed by the
        scheduler as `engine.step()` is called. The exact scheduling policy is
        determined by the scheduler.

        Args:
            request_id: The unique ID of the request.
            prompt: The prompt string. Can be None if prompt_token_ids is
                provided.
            sampling_params: The sampling parameters for text generation. A
                copy is stored; the caller's object is not modified.
            prompt_token_ids: The token IDs of the prompt. If None, we
                use the tokenizer to convert the prompts to token IDs.
            arrival_time: The arrival time of the request. If None, the
                current wall-clock time (``time.time()``) is used. The
                engine's own timestamps (e.g. in stats) also come from
                ``time.time()``, so a caller-supplied value should use the
                same clock.
            lora_request: The LoRA adapter to use for this request, if any.
                Only allowed when the engine was created with a LoRA config.
            multi_modal_data: Multi modal data per request.

        Raises:
            ValueError: If ``lora_request`` is given while LoRA is not
                enabled, or if ``logprobs`` / ``prompt_logprobs`` exceed the
                model config's ``max_logprobs``.

        Details:
            - Set arrival_time to the current time if it is None.
            - Set prompt_token_ids to the encoded prompt if it is None.
            - Create a single :class:`~vllm.Sequence` for the prompt.
            - Create a :class:`~vllm.SequenceGroup` object
              from that :class:`~vllm.Sequence`.
            - Add the :class:`~vllm.SequenceGroup` object to the scheduler.

        Example:
            >>> # initialize engine
            >>> engine = LLMEngine.from_engine_args(engine_args)
            >>> # set request arguments
            >>> example_prompt = "Who is the president of the United States?"
            >>> sampling_params = SamplingParams(temperature=0.0)
            >>> request_id = 0
            >>>
            >>> # add the request to the engine
            >>> engine.add_request(
            >>>    str(request_id),
            >>>    example_prompt,
            >>>    SamplingParams(temperature=0.0))
            >>> # continue the request processing
            >>> ...
        """
        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")
        max_logprobs = self.get_model_config().max_logprobs
        if (sampling_params.logprobs
                and sampling_params.logprobs > max_logprobs) or (
                    sampling_params.prompt_logprobs
                    and sampling_params.prompt_logprobs > max_logprobs):
            raise ValueError(f"Cannot request more than "
                             f"{max_logprobs} logprobs.")
        if arrival_time is None:
            # Wall-clock time, not a monotonic clock.
            arrival_time = time.time()
        prompt_token_ids = self.encode_request(
            request_id=request_id,
            prompt=prompt,
            prompt_token_ids=prompt_token_ids,
            lora_request=lora_request)

        # Create the (single) sequence for this prompt.
        block_size = self.cache_config.block_size
        seq_id = next(self.seq_counter)
        eos_token_id = None
        if self.tokenizer:
            eos_token_id = self.tokenizer.get_lora_tokenizer(
                lora_request).eos_token_id
        else:
            logger.warning("Use None for EOS token id because tokenizer is "
                           "not initialized")
        seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
                       eos_token_id, lora_request)

        # Defensive copy of SamplingParams, which are used by the sampler,
        # this doesn't deep-copy LogitsProcessor objects
        sampling_params = sampling_params.clone()
        # inject the eos token id into the sampling_params to support min_tokens
        # processing
        sampling_params.eos_token_id = seq.eos_token_id
        sampling_params.update_from_generation_config(
            self.generation_config_fields)

        # Create the sequence group.
        seq_group = SequenceGroup(request_id, [seq], sampling_params,
                                  arrival_time, lora_request, multi_modal_data)

        # Add the sequence group to the scheduler.
        self.scheduler.add_seq_group(seq_group)

    def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
        """Abort one or more requests by ID.

        Args:
            request_id: A single request ID, or an iterable of request IDs.

        Details:
            - Delegates to
              :meth:`~vllm.core.scheduler.Scheduler.abort_seq_group`
              from class :class:`~vllm.core.scheduler.Scheduler`.

        Example:
            >>> # initialize engine and add a request with request_id
            >>> request_id = str(0)
            >>> # abort the request
            >>> engine.abort_request(request_id)
        """
        scheduler = self.scheduler
        scheduler.abort_seq_group(request_id)

    def get_model_config(self) -> ModelConfig:
        """Return the ModelConfig this engine was constructed with."""
        model_config = self.model_config
        return model_config

    def get_num_unfinished_requests(self) -> int:
        """Return the scheduler's count of unfinished sequence groups."""
        scheduler = self.scheduler
        return scheduler.get_num_unfinished_seq_groups()

    def has_unfinished_requests(self) -> bool:
        """Tell whether the scheduler still has unfinished sequences."""
        scheduler = self.scheduler
        return scheduler.has_unfinished_seqs()

    def _process_model_outputs(
            self, output: List[SamplerOutput],
            scheduled_seq_groups: List[SequenceGroup],
            ignored_seq_groups: List[SequenceGroup]) -> List[RequestOutput]:
        """Apply one step's sampler output and build client-facing results.

        Args:
            output: Sampler outputs for this step, organized as
                [step][sequence group].
            scheduled_seq_groups: The scheduler's entries for this step. Each
                element is used via ``.seq_group`` and ``.token_chunk_size``,
                i.e. it wraps a SequenceGroup rather than being one (the
                annotation is looser than the actual usage).
            ignored_seq_groups: Sequence groups that were not run this step;
                they are reported back unchanged.

        Returns:
            One RequestOutput per scheduled group, followed by one per
            ignored group.
        """
        now = time.time()

        # Regroup the outputs from [step][seq group] to [seq group][step].
        outputs_per_group = create_output_by_sequence_group(
            sampler_outputs=output, num_seq_groups=len(scheduled_seq_groups))

        for scheduled, group_outputs in zip(scheduled_seq_groups,
                                            outputs_per_group):
            seq_group = scheduled.seq_group
            seq_group.update_num_computed_tokens(scheduled.token_chunk_size)
            # Only groups whose sequences are all in DECODE carry output
            # tokens; (chunked) prefill samples must not be processed.
            if all(seq.data._stage == SequenceStage.DECODE
                   for seq in seq_group.seqs_dict.values()):
                self.output_processor.process_outputs(seq_group, group_outputs)

        self.scheduler.free_finished_seq_groups()

        request_outputs: List[RequestOutput] = []
        for scheduled in scheduled_seq_groups:
            seq_group = scheduled.seq_group
            seq_group.maybe_set_first_token_time(now)
            request_outputs.append(RequestOutput.from_seq_group(seq_group))
        request_outputs.extend(
            RequestOutput.from_seq_group(seq_group)
            for seq_group in ignored_seq_groups)
        return request_outputs

    def step(self) -> List[RequestOutput]:
        """Run one decoding iteration and return the newly generated results.

        .. figure:: https://i.imgur.com/sv2HssD.png
            :alt: Overview of the step function
            :align: center

            Overview of the step function.

        Details:
            - Step 1: Ask the scheduler which sequences run in this iteration
              and which KV-cache blocks must be swapped in, swapped out, or
              copied.

                - Depending on the scheduling policy,
                  sequences may be `preempted/reordered`.
                - A Sequence Group (SG) is a group of sequences
                  that are generated from the same prompt.

            - Step 2: If anything was scheduled, have the model executor run
              the model; otherwise use an empty output.
            - Step 3: Post-process the model output. This mainly includes:

                - Applying it to the scheduled sequence groups through the
                  output processor (which holds the detokenizer and the stop
                  checker).
                - Freeing the finished sequence groups.

            - Finally, build and return the newly generated results, logging
              stats first when stats logging is enabled.

        Example:
            >>> # Please see the example/ folder for more detailed examples.
            >>>
            >>> # initialize engine and request arguments
            >>> engine = LLMEngine.from_engine_args(engine_args)
            >>> example_inputs = [(0, "What is LLM?",
            >>>    SamplingParams(temperature=0.0))]
            >>>
            >>> # Start the engine with an event loop
            >>> while True:
            >>>     if example_inputs:
            >>>         req_id, prompt, sampling_params = example_inputs.pop(0)
            >>>         engine.add_request(str(req_id), prompt, sampling_params)
            >>>
            >>>     # continue the request processing
            >>>     request_outputs = engine.step()
            >>>     for request_output in request_outputs:
            >>>         if request_output.finished:
            >>>             # return or show the request output
            >>>
            >>>     if not (engine.has_unfinished_requests() or example_inputs):
            >>>         break
        """
        seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()

        if scheduler_outputs.is_empty():
            # Nothing to run on the model; an empty output still goes through
            # the post-processing below.
            output = []
        else:
            output = self.model_executor.execute_model(
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots)

        request_outputs = self._process_model_outputs(
            output, scheduler_outputs.scheduled_seq_groups,
            scheduler_outputs.ignored_seq_groups)

        if self.log_stats:
            stats = self._get_stats(scheduler_outputs, model_output=output)
            self.stat_logger.log(stats)

        return request_outputs

    def do_log_stats(self) -> None:
        """Force a stats log entry without scheduler output (e.g. when no
        requests are active). Does nothing if stats logging is disabled."""
        if not self.log_stats:
            return
        self.stat_logger.log(self._get_stats(scheduler_outputs=None))

    def _get_stats(
            self,
            scheduler_outputs: Optional[SchedulerOutputs],
            model_output: Optional[List[SamplerOutput]] = None) -> Stats:
        """Build the Stats snapshot that is logged to Prometheus.

        Args:
            scheduler_outputs: Scheduler output for the batch that just ran.
                When None (e.g. a forced log while idle), token counts stay 0
                and the latency lists stay empty.
            model_output: Sampler outputs from the workers. If present, the
                first entry's ``spec_decode_worker_metrics`` is forwarded as
                the speculative-decoding metrics.
        """
        now = time.time()
        block_manager = self.scheduler.block_manager

        # KV cache usage as the fraction of blocks that are not free.
        num_total_gpu = self.cache_config.num_gpu_blocks
        num_free_gpu = block_manager.get_num_free_gpu_blocks()
        gpu_cache_usage = 1.0 - (num_free_gpu / num_total_gpu)

        num_total_cpu = self.cache_config.num_cpu_blocks
        if num_total_cpu > 0:
            num_free_cpu = block_manager.get_num_free_cpu_blocks()
            cpu_cache_usage = 1.0 - (num_free_cpu / num_total_cpu)
        else:
            cpu_cache_usage = 0.

        # Scheduler queue sizes.
        num_running = len(self.scheduler.running)
        num_swapped = len(self.scheduler.swapped)
        num_waiting = len(self.scheduler.waiting)

        # Iteration stats; only available when scheduler output is given.
        num_prompt_tokens = 0
        num_generation_tokens = 0
        time_to_first_tokens: List[float] = []
        time_per_output_tokens: List[float] = []
        time_e2e_requests: List[float] = []
        if scheduler_outputs is not None:
            scheduled = scheduler_outputs.scheduled_seq_groups
            prompt_run = scheduler_outputs.num_prefill_groups > 0

            if prompt_run:
                # NOTE(review): this counts each scheduled group's full prompt
                # length, even when only a chunk of it was computed this step
                # (chunked prefill) — confirm this is the intended metric.
                num_prompt_tokens = sum(
                    len(entry.seq_group.prompt_token_ids)
                    for entry in scheduled)
                num_generation_tokens = sum(entry.seq_group.num_seqs()
                                            for entry in scheduled)
            else:
                num_generation_tokens = scheduler_outputs.num_batched_tokens

            time_last_iters = []
            for entry in scheduled:
                seq_group = entry.seq_group
                # Latency since the previous token
                # (n.b. updates seq_group.metrics.last_token_time).
                time_last_iters.append(seq_group.get_last_latency(now))
                # End-to-end latency, recorded for finished requests only.
                if seq_group.is_finished():
                    time_e2e_requests.append(now -
                                             seq_group.metrics.arrival_time)

            # A prefill step yields first tokens; otherwise per-token latency.
            if prompt_run:
                time_to_first_tokens = time_last_iters
            else:
                time_per_output_tokens = time_last_iters

        # Spec decode, if enabled, emits specialized metrics from the worker
        # in the sampler output (None when it is not enabled).
        spec_decode_metrics = None
        if model_output:
            spec_decode_metrics = model_output[0].spec_decode_worker_metrics

        return Stats(
            now=now,
            num_running=num_running,
            num_swapped=num_swapped,
            num_waiting=num_waiting,
            gpu_cache_usage=gpu_cache_usage,
            cpu_cache_usage=cpu_cache_usage,
            num_prompt_tokens=num_prompt_tokens,
            num_generation_tokens=num_generation_tokens,
            time_to_first_tokens=time_to_first_tokens,
            time_per_output_tokens=time_per_output_tokens,
            time_e2e_requests=time_e2e_requests,
            spec_decode_metrics=spec_decode_metrics,
        )

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Register a LoRA adapter with the model executor and return the
        executor's boolean result."""
        executor = self.model_executor
        return executor.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Remove the LoRA adapter with ``lora_id`` from the model executor
        and return the executor's boolean result."""
        executor = self.model_executor
        return executor.remove_lora(lora_id)

    def list_loras(self) -> List[int]:
        """Return the LoRA IDs reported by the model executor."""
        executor = self.model_executor
        return executor.list_loras()

    def check_health(self) -> None:
        """Delegate the health check to the model executor."""
        executor = self.model_executor
        executor.check_health()