llm_engine.py 36.2 KB
Newer Older
Antoni Baum's avatar
Antoni Baum committed
1
import time
2
from typing import Iterable, List, Optional, Tuple, Type, Union
3

4
5
from transformers import PreTrainedTokenizer

6
import vllm
7
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
8
                         ParallelConfig, SchedulerConfig, VisionLanguageConfig)
Antoni Baum's avatar
Antoni Baum committed
9
from vllm.core.scheduler import Scheduler, SchedulerOutputs
Woosuk Kwon's avatar
Woosuk Kwon committed
10
from vllm.engine.arg_utils import EngineArgs
11
from vllm.engine.metrics import StatLogger, Stats
12
from vllm.engine.ray_utils import initialize_ray_cluster
13
from vllm.executor.executor_base import ExecutorBase
Woosuk Kwon's avatar
Woosuk Kwon committed
14
from vllm.logger import init_logger
15
from vllm.lora.request import LoRARequest
yhu422's avatar
yhu422 committed
16
from vllm.model_executor.model_loader import get_architecture_class_name
Woosuk Kwon's avatar
Woosuk Kwon committed
17
18
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
19
20
21
from vllm.sequence import (MultiModalData, SamplerOutput, Sequence,
                           SequenceGroup, SequenceGroupOutput, SequenceOutput,
                           SequenceStatus)
22
from vllm.transformers_utils.detokenizer import Detokenizer
23
24
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
                                                     get_tokenizer_group)
yhu422's avatar
yhu422 committed
25
26
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                  usage_message)
27
from vllm.utils import Counter
28
29

logger = init_logger(__name__)
# Seconds between local stat-log emissions; handed to StatLogger below.
_LOCAL_LOGGING_INTERVAL_SEC = 5
Woosuk Kwon's avatar
Woosuk Kwon committed
31

32

33
class LLMEngine:
Zhuohan Li's avatar
Zhuohan Li committed
34
    """An LLM engine that receives requests and generates texts.
35

Woosuk Kwon's avatar
Woosuk Kwon committed
36
    This is the main class for the vLLM engine. It receives requests
37
38
39
40
41
42
43
    from clients and generates texts from the LLM. It includes a tokenizer, a
    language model (possibly distributed across multiple GPUs), and GPU memory
    space allocated for intermediate states (aka KV cache). This class utilizes
    iteration-level scheduling and efficient memory management to maximize the
    serving throughput.

    The `LLM` class wraps this class for offline batched inference and the
44
    `AsyncLLMEngine` class wraps this class for online serving.
45

Zhuohan Li's avatar
Zhuohan Li committed
46
47
    NOTE: The config arguments are derived from the `EngineArgs` class. For the
    comprehensive list of arguments, see `EngineArgs`.
48
49
50
51
52
53
54

    Args:
        model_config: The configuration related to the LLM model.
        cache_config: The configuration related to the KV cache memory
            management.
        parallel_config: The configuration related to distributed execution.
        scheduler_config: The configuration related to the request scheduler.
55
        device_config: The configuration related to the device.
56
57
        executor_class: The model executor class for managing distributed
            execution.
58
        log_stats: Whether to log statistics.
yhu422's avatar
yhu422 committed
59
        usage_context: Specified entry point, used for usage info collection
60
    """
61
62
63
64
65
66
67

    def __init__(
        self,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        lora_config: Optional[LoRAConfig],
        vision_language_config: Optional["VisionLanguageConfig"],
        executor_class: Type[ExecutorBase],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
    ) -> None:
        """Set up the engine from already-built configs.

        Logs the effective configuration, stores and cross-checks the
        configs, creates the tokenizer group, detokenizer, model executor
        and scheduler, and optionally reports usage info and enables stat
        logging. The arguments are described in the class docstring.
        """
        # One "key=value" entry per logged setting; joined with ", " below.
        config_items = [
            f"model={model_config.model!r}",
            f"tokenizer={model_config.tokenizer!r}",
            f"tokenizer_mode={model_config.tokenizer_mode}",
            f"revision={model_config.revision}",
            f"tokenizer_revision={model_config.tokenizer_revision}",
            f"trust_remote_code={model_config.trust_remote_code}",
            f"dtype={model_config.dtype}",
            f"max_seq_len={model_config.max_model_len}",
            f"download_dir={model_config.download_dir!r}",
            f"load_format={model_config.load_format}",
            f"tensor_parallel_size={parallel_config.tensor_parallel_size}",
            "disable_custom_all_reduce="
            f"{parallel_config.disable_custom_all_reduce}",
            f"quantization={model_config.quantization}",
            f"enforce_eager={model_config.enforce_eager}",
            f"kv_cache_dtype={cache_config.cache_dtype}",
            f"device_config={device_config.device}",
            f"seed={model_config.seed}",
        ]
        logger.info(
            f"Initializing an LLM engine (v{vllm.__version__}) with config: "
            + ", ".join(config_items) + ")")
        # TODO(woosuk): Print more configs in debug mode.

        # Keep every config on the instance; helpers below read them via self.
        self.model_config = model_config
        self.cache_config = cache_config
        self.lora_config = lora_config
        self.vision_language_config = vision_language_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.log_stats = log_stats
        self._verify_args()

        self._init_tokenizer()
        self.detokenizer = Detokenizer(self.tokenizer)
        self.seq_counter = Counter()

        self.model_executor = executor_class(
            model_config, cache_config, parallel_config, scheduler_config,
            device_config, lora_config, vision_language_config)

        # Usage reporting is opt-in; only gather the payload when enabled.
        if is_usage_stats_enabled():
            architecture = get_architecture_class_name(model_config)
            usage_extras = {
                # Common configuration
                "dtype": str(model_config.dtype),
                "tensor_parallel_size": parallel_config.tensor_parallel_size,
                "block_size": cache_config.block_size,
                "gpu_memory_utilization":
                cache_config.gpu_memory_utilization,

                # Quantization
                "quantization": model_config.quantization,
                "kv_cache_dtype": cache_config.cache_dtype,

                # Feature flags
                "enable_lora": bool(lora_config),
                "enable_prefix_caching": cache_config.enable_prefix_caching,
                "enforce_eager": model_config.enforce_eager,
                "disable_custom_all_reduce":
                parallel_config.disable_custom_all_reduce,
            }
            usage_message.report_usage(architecture,
                                       usage_context,
                                       extra_kvs=usage_extras)

        # The tokenizer group may run in a separate process; ping it so its
        # liveness is confirmed here.
        self.tokenizer.ping()

        # NOTE: by this point cache_config carries the GPU/CPU block counts
        # profiled by the distributed executor created above.
        self.scheduler = Scheduler(scheduler_config, cache_config, lora_config)

        # Metric logging.
        if self.log_stats:
            self.stat_logger = StatLogger(
                local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
                labels={"model_name": model_config.model})
            self.stat_logger.info("cache_config", self.cache_config)
164

165
    @classmethod
    def from_engine_args(
        cls,
        engine_args: EngineArgs,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
    ) -> "LLMEngine":
        """Build an engine from ``EngineArgs``.

        Picks the executor class from the device and parallel configs:
        Neuron, Ray-based GPU (initializing the Ray cluster first), or
        single-process GPU.
        """
        engine_configs = engine_args.create_engine_configs()
        # The configs are passed positionally to __init__, so slot 2 is the
        # parallel config and slot 4 the device config.
        parallel_config, device_config = engine_configs[2], engine_configs[4]

        # Executor modules are imported lazily so only the chosen backend's
        # dependencies are loaded.
        if device_config.device_type == "neuron":
            from vllm.executor.neuron_executor import (
                NeuronExecutor as executor_class)
        elif parallel_config.worker_use_ray:
            initialize_ray_cluster(parallel_config)
            from vllm.executor.ray_gpu_executor import (
                RayGPUExecutor as executor_class)
        else:
            assert parallel_config.world_size == 1, (
                "Ray is required if parallel_config.world_size > 1.")
            from vllm.executor.gpu_executor import (
                GPUExecutor as executor_class)

        return cls(
            *engine_configs,
            executor_class=executor_class,
            log_stats=not engine_args.disable_log_stats,
            usage_context=usage_context,
        )
199

200
201
202
203
204
    def __reduce__(self):
        """Refuse to be pickled.

        Pickling is blocked deliberately so that the engine never ends up
        referenced in the closure used to initialize Ray worker actors.
        """
        raise RuntimeError("LLMEngine should not be pickled!")

205
    def get_tokenizer(self) -> "PreTrainedTokenizer":
        """Return the tokenizer for ``lora_request=None`` (no LoRA adapter)."""
        no_lora = None
        return self.tokenizer.get_lora_tokenizer(no_lora)
207
208
209

    def get_tokenizer_for_seq(self,
                              sequence: Sequence) -> "PreTrainedTokenizer":
        """Return the tokenizer matching ``sequence.lora_request``."""
        seq_lora_request = sequence.lora_request
        return self.tokenizer.get_lora_tokenizer(seq_lora_request)

    def _init_tokenizer(self, **tokenizer_init_kwargs):
        """Create ``self.tokenizer`` (a tokenizer group) from the configs.

        Args:
            **tokenizer_init_kwargs: Overrides merged on top of the defaults
                derived from the model/scheduler/LoRA configs before they are
                forwarded to ``get_tokenizer_group``.

        Logs a warning when the tokenizer's vocabulary size differs from the
        model's vocabulary size.
        """
        init_kwargs = dict(
            tokenizer_id=self.model_config.tokenizer,
            enable_lora=bool(self.lora_config),
            max_num_seqs=self.scheduler_config.max_num_seqs,
            max_input_length=None,
            tokenizer_mode=self.model_config.tokenizer_mode,
            trust_remote_code=self.model_config.trust_remote_code,
            revision=self.model_config.tokenizer_revision)
        init_kwargs.update(tokenizer_init_kwargs)

        self.tokenizer: BaseTokenizerGroup = get_tokenizer_group(
            self.parallel_config.tokenizer_pool_config, **init_kwargs)

        # Evaluate both sizes once: get_tokenizer() goes through the tokenizer
        # group, and the warning message needs the same values again.
        tokenizer_vocab_size = len(self.get_tokenizer())
        model_vocab_size = self.model_config.get_vocab_size()
        if tokenizer_vocab_size != model_vocab_size:
            logger.warning(
                f"The tokenizer's vocabulary size {tokenizer_vocab_size}"
                f" does not match the model's vocabulary size "
                f"{model_vocab_size}. This might "
                f"cause an error in decoding. Please change config.json "
                "to match the tokenizer's vocabulary size.")

233
234
    def _verify_args(self) -> None:
        """Cross-validate the configs via each config's ``verify_with_*``."""
        parallel_config = self.parallel_config
        self.model_config.verify_with_parallel_config(parallel_config)
        self.cache_config.verify_with_parallel_config(parallel_config)

        lora_config = self.lora_config
        if not lora_config:
            # LoRA disabled: nothing more to check.
            return
        lora_config.verify_with_model_config(self.model_config)
        lora_config.verify_with_scheduler_config(self.scheduler_config)
240

241
242
243
244
245
246
247
248
249
250
251
252
253
254
    def encode_request(
        self,
        request_id: str,
        prompt: Optional[str],
        prompt_token_ids: Optional[List[int]] = None,
        lora_request: Optional[LoRARequest] = None,
    ) -> List[int]:
        """Return the prompt's token IDs, tokenizing ``prompt`` if needed.

        Args:
            request_id: ID of the request; forwarded to the tokenizer group.
            prompt: The prompt text. Only used when ``prompt_token_ids`` is
                None.
            prompt_token_ids: Pre-tokenized prompt. Returned unchanged when
                given.
            lora_request: Forwarded to the tokenizer group's ``encode``.

        Returns:
            The prompt's token IDs.

        Raises:
            ValueError: If both ``prompt`` and ``prompt_token_ids`` are None.
        """
        if prompt_token_ids is not None:
            return prompt_token_ids
        if prompt is None:
            # Explicit check rather than ``assert`` so the validation is not
            # stripped when Python runs with -O.
            raise ValueError(
                "Either prompt or prompt_token_ids must be provided.")
        return self.tokenizer.encode(request_id=request_id,
                                     prompt=prompt,
                                     lora_request=lora_request)

255
256
257
    def add_request(
        self,
        request_id: str,
        prompt: Optional[str],
        sampling_params: SamplingParams,
        prompt_token_ids: Optional[List[int]] = None,
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalData] = None,
    ) -> None:
        """Add a request to the engine's request pool.

        The request is added to the request pool and will be processed by the
        scheduler as `engine.step()` is called. The exact scheduling policy is
        determined by the scheduler.

        Args:
            request_id: The unique ID of the request.
            prompt: The prompt string. Can be None if prompt_token_ids is
                provided.
            sampling_params: The sampling parameters for text generation.
                The engine stores a clone, not the caller's object.
            prompt_token_ids: The token IDs of the prompt. If None, we
                use the tokenizer to convert the prompts to token IDs.
            arrival_time: The arrival time of the request. If None, we use
                the current wall-clock time (``time.time()``).
            lora_request: The LoRA request to serve this prompt with, if
                any. Only allowed when the engine has a LoRA config.
            multi_modal_data: Multi modal data per request.

        Raises:
            ValueError: If ``lora_request`` is given but LoRA is not enabled,
                or if ``sampling_params.logprobs`` or
                ``sampling_params.prompt_logprobs`` exceeds the model
                config's ``max_logprobs``.

        Details:
            - Set arrival_time to the current time if it is None.
            - Set prompt_token_ids to the encoded prompt if it is None.
            - Create a single :class:`~vllm.Sequence` for the prompt; further
              sequences are forked from it while outputs are processed.
            - Clone the sampling params and set their ``eos_token_id``.
            - Create a :class:`~vllm.SequenceGroup` object
              from that :class:`~vllm.Sequence`.
            - Add the :class:`~vllm.SequenceGroup` object to the scheduler.

        Example:
            >>> # initialize engine
            >>> engine = LLMEngine.from_engine_args(engine_args)
            >>> # set request arguments
            >>> example_prompt = "Who is the president of the United States?"
            >>> sampling_params = SamplingParams(temperature=0.0)
            >>> request_id = 0
            >>>
            >>> # add the request to the engine
            >>> engine.add_request(
            >>>    str(request_id),
            >>>    example_prompt,
            >>>    sampling_params)
            >>> # continue the request processing
            >>> ...
        """
        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")
        max_logprobs = self.get_model_config().max_logprobs
        if (sampling_params.logprobs
                and sampling_params.logprobs > max_logprobs) or (
                    sampling_params.prompt_logprobs
                    and sampling_params.prompt_logprobs > max_logprobs):
            raise ValueError(f"Cannot request more than "
                             f"{max_logprobs} logprobs.")
        if arrival_time is None:
            # Wall-clock time, matching the time.time() used when the first
            # token time is recorded in _process_model_outputs.
            arrival_time = time.time()
        prompt_token_ids = self.encode_request(
            request_id=request_id,
            prompt=prompt,
            prompt_token_ids=prompt_token_ids,
            lora_request=lora_request)

        # Create the sequences.
        block_size = self.cache_config.block_size
        seq_id = next(self.seq_counter)
        eos_token_id = self.tokenizer.get_lora_tokenizer(
            lora_request).eos_token_id
        seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
                       eos_token_id, lora_request)

        # Defensive copy of SamplingParams, which are used by the sampler,
        # this doesn't deep-copy LogitsProcessor objects
        sampling_params = sampling_params.clone()
        # inject the eos token id into the sampling_params to support min_tokens
        # processing
        sampling_params.eos_token_id = seq.eos_token_id

        # Create the sequence group.
        seq_group = SequenceGroup(request_id, [seq], sampling_params,
                                  arrival_time, lora_request, multi_modal_data)

        # Add the sequence group to the scheduler.
        self.scheduler.add_seq_group(seq_group)

Antoni Baum's avatar
Antoni Baum committed
346
347
    def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
        """Abort one request, or several, by ID.

        Args:
            request_id: A single request ID or an iterable of request IDs.

        Details:
            - Delegates to
              :meth:`~vllm.core.scheduler.Scheduler.abort_seq_group`
              from class :class:`~vllm.core.scheduler.Scheduler`.

        Example:
            >>> # initialize engine and add a request with request_id
            >>> request_id = str(0)
            >>> # abort the request
            >>> engine.abort_request(request_id)
        """
        scheduler = self.scheduler
        scheduler.abort_seq_group(request_id)

365
366
367
368
    def get_model_config(self) -> ModelConfig:
        """Return the model config this engine was constructed with."""
        config = self.model_config
        return config

369
    def get_num_unfinished_requests(self) -> int:
        """Return the scheduler's count of unfinished sequence groups."""
        pending = self.scheduler.get_num_unfinished_seq_groups()
        return pending

373
    def has_unfinished_requests(self) -> bool:
        """Report whether the scheduler still holds unfinished sequences."""
        any_pending = self.scheduler.has_unfinished_seqs()
        return any_pending

377
378
379
380
381
382
383
384
385
386
387
388
    def _check_beam_search_early_stopping(
        self,
        early_stopping: Union[bool, str],
        sampling_params: SamplingParams,
        best_running_seq: Sequence,
        current_worst_seq: Sequence,
    ) -> bool:
        """Decide whether beam search for a sequence group may stop.

        Args:
            early_stopping: ``True`` stops right away. ``False`` compares
                against the best running sequence's current score.
                ``"never"`` compares against an estimate of the best score
                that sequence could still reach.
            sampling_params: The group's sampling params (must use beam
                search).
            best_running_seq: Highest-scoring sequence still running.
            current_worst_seq: Worst finished sequence still kept in the beam.

        Returns:
            True when ``current_worst_seq`` scores at least as high as the
            attainable score computed for ``best_running_seq``.
        """
        assert sampling_params.use_beam_search
        penalty = sampling_params.length_penalty
        if early_stopping is True:
            return True

        worst_finished_score = current_worst_seq.get_beam_search_score(
            length_penalty=penalty,
            eos_token_id=current_worst_seq.eos_token_id)

        if early_stopping is False:
            best_possible_score = best_running_seq.get_beam_search_score(
                length_penalty=penalty,
                eos_token_id=best_running_seq.eos_token_id)
            return worst_finished_score >= best_possible_score

        assert early_stopping == "never"
        extra_score_kwargs = {}
        if penalty > 0.0:
            # A positive penalty favours longer sequences, so score the best
            # running sequence at the longest length considered here.
            # Non-positive penalties favour shorter sequences, so the current
            # length is used (no seq_len override).
            # NOTE(review): max() yields a loose upper bound; min() would be
            # the tightest reachable length. Kept as-is — confirm intent.
            extra_score_kwargs["seq_len"] = max(
                best_running_seq.get_prompt_len() + sampling_params.max_tokens,
                self.scheduler_config.max_model_len)
        best_possible_score = best_running_seq.get_beam_search_score(
            length_penalty=penalty,
            eos_token_id=best_running_seq.eos_token_id,
            **extra_score_kwargs)
        return worst_finished_score >= best_possible_score

421
    def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
                                        outputs: SequenceGroupOutput) -> None:
        """Fold one step of sampler output into ``seq_group``.

        Attaches prompt logprobs, extends or forks each running sequence with
        its new samples, then detokenizes and stop-checks them. Finally it
        reconciles group membership and block-manager state: directly for
        regular sampling, or by keeping the top ``best_of`` candidates for
        beam search.
        """
        sampling_params = seq_group.sampling_params

        # Prompt logprobs (if returned) are detokenized and stored on the group.
        prompt_logprobs = outputs.prompt_logprobs
        if prompt_logprobs is not None:
            self.detokenizer.decode_prompt_logprobs_inplace(
                seq_group, prompt_logprobs)
            seq_group.prompt_logprobs = prompt_logprobs

        running_parents = seq_group.get_seqs(status=SequenceStatus.RUNNING)
        # Snapshot taken before this step's sequences are processed.
        already_finished = seq_group.get_finished_seqs()

        # Bucket the new samples by the running sequence they extend.
        samples_by_parent = {parent.seq_id: [] for parent in running_parents}
        for sample in outputs.samples:
            samples_by_parent[sample.parent_seq_id].append(sample)

        # (child, parent) pairs; child is parent when extended in place.
        candidates: List[Tuple[Sequence, Sequence]] = []
        for parent in running_parents:
            child_samples: List[SequenceOutput] = samples_by_parent[
                parent.seq_id]
            if not child_samples:
                # No sample continues this parent: drop it from the group and
                # release it in the scheduler.
                parent.status = SequenceStatus.FINISHED_ABORTED
                seq_group.remove(parent.seq_id)
                self.scheduler.free_seq(parent)
                continue
            *fork_samples, last_sample = child_samples
            for sample in fork_samples:
                child = parent.fork(next(self.seq_counter))
                child.append_token_id(sample.output_token, sample.logprobs)
                candidates.append((child, parent))
            # Reuse the parent for the last sample to avoid an extra copy.
            parent.append_token_id(last_sample.output_token,
                                   last_sample.logprobs)
            candidates.append((parent, parent))

        for seq, _ in candidates:
            self.detokenizer.decode_sequence_inplace(seq, sampling_params)
            self._check_stop(seq, sampling_params)

        def register_selected(pairs):
            # Add newly forked children to the group (forking their blocks
            # in the scheduler if still running), then free finished
            # in-place parents. Forking must happen before any freeing.
            for seq, parent in pairs:
                if seq is not parent:
                    seq_group.add(seq)
                    if not seq.is_finished():
                        self.scheduler.fork_seq(parent, seq)
            for seq, parent in pairs:
                if seq is parent and seq.is_finished():
                    self.scheduler.free_seq(seq)

        if not sampling_params.use_beam_search:
            register_selected(candidates)
            return

        # ---- Beam search ----
        beam_width = sampling_params.best_of
        length_penalty = sampling_params.length_penalty

        def beam_score(entry):
            seq = entry[0]
            return seq.get_beam_search_score(length_penalty=length_penalty,
                                             eos_token_id=seq.eos_token_id)

        kept: List[Tuple[Sequence, Sequence]] = []
        dropped: List[Tuple[Sequence, Sequence]] = []

        # Entries are (seq, parent, is_new); old finished seqs have no parent.
        finished_pool: List[Tuple[Sequence, Optional[Sequence], bool]] = [
            (seq, None, False) for seq in already_finished
        ]
        finished_pool += [(seq, parent, True) for seq, parent in candidates
                          if seq.is_finished()]
        finished_pool.sort(key=beam_score, reverse=True)

        for seq, parent, is_new in finished_pool[:beam_width]:
            if is_new:
                # Newly finished and good enough: admit it to the group.
                kept.append((seq, parent))
        for seq, parent, is_new in finished_pool[beam_width:]:
            if is_new:
                # Newly finished but outscored; its in-place parent (if any)
                # is removed further below.
                dropped.append((seq, parent))
            else:
                # Previously finished but now outscored: evict it.
                seq_group.remove(seq.seq_id)

        running = [(seq, parent) for seq, parent in candidates
                   if not seq.is_finished()]
        running.sort(key=beam_score, reverse=True)

        if not running:
            stop = True
        elif len(finished_pool) < beam_width:
            stop = False
        else:
            stop = self._check_beam_search_early_stopping(
                sampling_params.early_stopping, sampling_params,
                running[0][0], finished_pool[beam_width - 1][0])

        if stop:
            dropped.extend(running)
        else:
            kept.extend(running[:beam_width])
            dropped.extend(running[beam_width:])

        register_selected(kept)

        # Remove unselected in-place parents and free their blocks.
        for seq, parent in dropped:
            if seq is parent:
                seq_group.remove(seq.seq_id)
                self.scheduler.free_seq(seq)

    def _process_model_outputs(
            self, output: SamplerOutput,
            scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]:
        """Apply sampler output to the scheduled groups and build results.

        Returns one RequestOutput per scheduled sequence group, in scheduling
        order, followed by one per ignored sequence group.
        """
        now = time.time()
        scheduled = scheduler_outputs.scheduled_seq_groups

        # Advance each group's computed-token count and fold in its samples.
        # NOTE(review): zip() stops at the shorter input if the lengths of
        # `scheduled` and `output` ever differ.
        for entry, group_output in zip(scheduled, output):
            group = entry.seq_group
            group.update_num_computed_tokens(entry.token_chunk_size)
            self._process_sequence_group_outputs(group, group_output)

        self.scheduler.free_finished_seq_groups()

        request_outputs: List[RequestOutput] = []
        for entry in scheduled:
            group = entry.seq_group
            group.maybe_set_first_token_time(now)
            request_outputs.append(RequestOutput.from_seq_group(group))
        request_outputs.extend(
            RequestOutput.from_seq_group(group)
            for group in scheduler_outputs.ignored_seq_groups)

        if self.log_stats:
            self.stat_logger.log(self._get_stats(scheduler_outputs))

        return request_outputs

Antoni Baum's avatar
Antoni Baum committed
626
627
628
    def step(self) -> List[RequestOutput]:
        """Run one engine iteration and return the newly produced results.

        .. figure:: https://i.imgur.com/sv2HssD.png
            :alt: Overview of the step function
            :align: center

            Overview of the step function.

        Details:
            - Step 1: Ask the scheduler which sequences run in the next
              iteration and which token blocks must be swapped in, swapped
              out, or copied.

                - Depending on the scheduling policy,
                  sequences may be `preempted/reordered`.
                - A Sequence Group (SG) refers to a group of sequences
                  that are generated from the same prompt.

            - Step 2: Unless the schedule is empty, call the model executor
              to execute the model.
            - Step 3: Process the model output. This mainly includes:

                - Decoding the relevant outputs.
                - Updating the scheduled sequence groups with model outputs
                  based on their `sampling parameters` (`use_beam_search` or
                  not).
                - Freeing the finished sequence groups.

            - Finally, create and return the newly generated results.

        Example:
            >>> # Please see the example/ folder for more detailed examples.
            >>>
            >>> # initialize engine and request arguments
            >>> engine = LLMEngine.from_engine_args(engine_args)
            >>> example_inputs = [(0, "What is LLM?",
            ...    SamplingParams(temperature=0.0))]
            >>>
            >>> # Start the engine with an event loop
            >>> while True:
            ...     if example_inputs:
            ...         req_id, prompt, sampling_params = example_inputs.pop(0)
            ...         engine.add_request(str(req_id), prompt, sampling_params)
            ...
            ...     # continue the request processing
            ...     request_outputs = engine.step()
            ...     for request_output in request_outputs:
            ...         if request_output.finished:
            ...             # return or show the request output
            ...             pass
            ...
            ...     if not (engine.has_unfinished_requests() or example_inputs):
            ...         break

        Returns:
            One ``RequestOutput`` per scheduled sequence group, followed by
            one per sequence group the scheduler ignored this iteration.
        """
        metadata_list, sched_outputs = self.scheduler.schedule()

        if sched_outputs.is_empty():
            # The scheduler reports nothing to do; skip the model call and
            # process an empty output.
            output = []
        else:
            output = self.model_executor.execute_model(
                metadata_list,
                sched_outputs.blocks_to_swap_in,
                sched_outputs.blocks_to_swap_out,
                sched_outputs.blocks_to_copy,
            )

        return self._process_model_outputs(output, sched_outputs)
Antoni Baum's avatar
Antoni Baum committed
688

689
    def do_log_stats(self) -> None:
        """Force a stats log even when no requests are active.

        Does nothing unless ``self.log_stats`` is enabled. The stats are built
        without scheduler output, so only the non-iteration fields are
        populated.
        """
        if not self.log_stats:
            return
        stats = self._get_stats(scheduler_outputs=None)
        self.stat_logger.log(stats)
693

694
695
696
    def _get_stats(self,
                   scheduler_outputs: Optional[SchedulerOutputs]) -> Stats:
        """Build a ``Stats`` snapshot to be logged (e.g. to Prometheus).

        Args:
            scheduler_outputs: The scheduling decision of the iteration that
                just ran, or ``None`` for a forced log with no iteration (as
                ``do_log_stats`` does). With ``None`` only the cache-usage and
                queue-length fields are filled in; token counts stay 0 and the
                latency lists stay empty.

        Returns:
            A ``Stats`` record stamped with the current ``time.time()``.
        """
        now = time.time()
        block_manager = self.scheduler.block_manager

        # Fraction of KV-cache blocks in use on the GPU.
        total_gpu_blocks = self.cache_config.num_gpu_blocks
        free_gpu_blocks = block_manager.get_num_free_gpu_blocks()
        gpu_cache_usage = 1.0 - free_gpu_blocks / total_gpu_blocks

        # There may be no CPU blocks at all; report 0 usage rather than
        # dividing by zero.
        total_cpu_blocks = self.cache_config.num_cpu_blocks
        if total_cpu_blocks > 0:
            free_cpu_blocks = block_manager.get_num_free_cpu_blocks()
            cpu_cache_usage = 1.0 - free_cpu_blocks / total_cpu_blocks
        else:
            cpu_cache_usage = 0.

        # Queue lengths as currently held by the scheduler.
        num_running = len(self.scheduler.running)
        num_swapped = len(self.scheduler.swapped)
        num_waiting = len(self.scheduler.waiting)

        # Per-iteration stats keep these zero/empty defaults when there is no
        # scheduler output.
        num_prompt_tokens = 0
        num_generation_tokens = 0
        time_to_first_tokens = []
        time_per_output_tokens = []
        time_e2e_requests = []

        if scheduler_outputs is not None:
            is_prompt_run = scheduler_outputs.prompt_run
            seq_groups = [
                scheduled.seq_group
                for scheduled in scheduler_outputs.scheduled_seq_groups
            ]

            if is_prompt_run:
                # Prompt iteration: count prompt lengths, plus one generated
                # token per sequence in each group.
                num_prompt_tokens = sum(
                    len(group.prompt_token_ids) for group in seq_groups)
                num_generation_tokens = sum(
                    group.num_seqs() for group in seq_groups)
            else:
                num_generation_tokens = scheduler_outputs.num_batched_tokens

            # Time since each group's previous token.
            # (n.b. get_last_latency updates seq_group.metrics.last_token_time)
            latencies = []
            for group in seq_groups:
                latencies.append(group.get_last_latency(now))
                if group.is_finished():
                    # End-to-end latency of requests that are now finished.
                    time_e2e_requests.append(now - group.metrics.arrival_time)

            # On a prompt run the latency is a time-to-first-token; otherwise
            # it is a time-per-output-token.
            if is_prompt_run:
                time_to_first_tokens = latencies
            else:
                time_per_output_tokens = latencies

        return Stats(
            now=now,
            num_running=num_running,
            num_swapped=num_swapped,
            num_waiting=num_waiting,
            gpu_cache_usage=gpu_cache_usage,
            cpu_cache_usage=cpu_cache_usage,
            num_prompt_tokens=num_prompt_tokens,
            num_generation_tokens=num_generation_tokens,
            time_to_first_tokens=time_to_first_tokens,
            time_per_output_tokens=time_per_output_tokens,
            time_e2e_requests=time_e2e_requests,
        )

767
768
    def _check_stop(self, seq: Sequence,
                    sampling_params: SamplingParams) -> None:
        """Mark ``seq`` as finished if it hit a stop or length condition.

        Stop conditions are checked before the length caps. The stop
        conditions are, in order:

        - a stop string at the end of ``output_text``;
        - a stop token id as the last token;
        - EOS, unless ``ignore_eos`` is set.

        As a result, a sequence whose final allowed token is also a stop
        condition is reported as ``FINISHED_STOPPED`` (with its stop string
        handled by ``_finalize_sequence``). It is no longer reported as
        ``FINISHED_LENGTH_CAPPED`` with the stop string left in its text.

        Stop conditions are only honored once ``sampling_params.min_tokens``
        output tokens exist. The length caps (``max_model_len`` and
        ``max_tokens``) always apply.

        Args:
            seq: The sequence to check. Its ``status`` and, when stopped by a
                stop string/token, its ``stop_reason`` (and possibly
                ``output_text``) are updated in place.
            sampling_params: Sampling parameters of the owning request.
        """
        output_len = seq.get_output_len()

        if output_len >= sampling_params.min_tokens:
            for stop_str in sampling_params.stop:
                if seq.output_text.endswith(stop_str):
                    self._finalize_sequence(seq, sampling_params, stop_str)
                    seq.status = SequenceStatus.FINISHED_STOPPED
                    seq.stop_reason = stop_str
                    return

            last_token_id = seq.get_last_token_id()
            if last_token_id in sampling_params.stop_token_ids:
                stop_str = self.get_tokenizer_for_seq(
                    seq).convert_ids_to_tokens(last_token_id)
                self._finalize_sequence(seq, sampling_params, stop_str)
                seq.status = SequenceStatus.FINISHED_STOPPED
                seq.stop_reason = last_token_id
                return

            # Check if the sequence has generated the EOS token.
            if ((not sampling_params.ignore_eos)
                    and last_token_id == seq.eos_token_id):
                seq.status = SequenceStatus.FINISHED_STOPPED
                return

        # Length caps apply regardless of min_tokens.
        # Check if the sequence has reached max_model_len.
        if seq.get_len() > self.scheduler_config.max_model_len:
            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
            return

        # Check if the sequence has reached max_tokens.
        if output_len == sampling_params.max_tokens:
            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
805

806
807
808
    def _finalize_sequence(self, seq: Sequence,
                           sampling_params: SamplingParams,
                           stop_string: str) -> None:
        """Strip a trailing stop string from ``seq.output_text`` if required.

        The text is left alone in any of these cases:

        - ``sampling_params.include_stop_str_in_output`` is set;
        - ``stop_string`` is falsy (e.g. empty or ``None``);
        - the output text does not end with ``stop_string``.
        """
        should_strip = (not sampling_params.include_stop_str_in_output
                        and stop_string
                        and seq.output_text.endswith(stop_string))
        if should_strip:
            seq.output_text = seq.output_text[:-len(stop_string)]

817
    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Forward ``lora_request`` to the model executor's ``add_lora``.

        Returns:
            Whatever boolean the executor's ``add_lora`` returns.
        """
        executor = self.model_executor
        return executor.add_lora(lora_request)
819
820

    def remove_lora(self, lora_id: int) -> bool:
        """Forward ``lora_id`` to the model executor's ``remove_lora``.

        Returns:
            Whatever boolean the executor's ``remove_lora`` returns.
        """
        executor = self.model_executor
        return executor.remove_lora(lora_id)
822
823

    def list_loras(self) -> List[int]:
        """Return the LoRA ids reported by the model executor's ``list_loras``."""
        executor = self.model_executor
        return executor.list_loras()
825
826

    def check_health(self) -> None:
        """Delegate the health check to the model executor.

        Any exception raised by the executor's ``check_health`` propagates to
        the caller unchanged.
        """
        executor = self.model_executor
        executor.check_health()