llm_engine.py 35.5 KB
Newer Older
Antoni Baum's avatar
Antoni Baum committed
1
import time
2
from typing import Dict, Iterable, List, Optional, Tuple, Type, Union
3

4
5
from transformers import PreTrainedTokenizer

6
import vllm
7
from vllm.lora.request import LoRARequest
8
9
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig, LoRAConfig)
Antoni Baum's avatar
Antoni Baum committed
10
from vllm.core.scheduler import Scheduler, SchedulerOutputs
Woosuk Kwon's avatar
Woosuk Kwon committed
11
from vllm.engine.arg_utils import EngineArgs
12
from vllm.executor.executor_base import ExecutorBase
13
from vllm.engine.metrics import StatLogger, Stats
14
from vllm.engine.ray_utils import initialize_ray_cluster
Woosuk Kwon's avatar
Woosuk Kwon committed
15
16
17
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
18
from vllm.sequence import (Logprob, SamplerOutput, Sequence, SequenceGroup,
19
                           SequenceGroupOutput, SequenceOutput, SequenceStatus)
20
21
22
from vllm.transformers_utils.tokenizer import detokenize_incrementally
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
                                                     get_tokenizer_group)
23
from vllm.utils import Counter
24
25

logger = init_logger(__name__)
26
_LOCAL_LOGGING_INTERVAL_SEC = 5
Woosuk Kwon's avatar
Woosuk Kwon committed
27

28

29
class LLMEngine:
Zhuohan Li's avatar
Zhuohan Li committed
30
    """An LLM engine that receives requests and generates texts.
31

Woosuk Kwon's avatar
Woosuk Kwon committed
32
    This is the main class for the vLLM engine. It receives requests
33
34
35
36
37
38
39
    from clients and generates texts from the LLM. It includes a tokenizer, a
    language model (possibly distributed across multiple GPUs), and GPU memory
    space allocated for intermediate states (aka KV cache). This class utilizes
    iteration-level scheduling and efficient memory management to maximize the
    serving throughput.

    The `LLM` class wraps this class for offline batched inference and the
40
    `AsyncLLMEngine` class wraps this class for online serving.
41

Zhuohan Li's avatar
Zhuohan Li committed
42
43
    NOTE: The config arguments are derived from the `EngineArgs` class. For the
    comprehensive list of arguments, see `EngineArgs`.
44
45
46
47
48
49
50

    Args:
        model_config: The configuration related to the LLM model.
        cache_config: The configuration related to the KV cache memory
            management.
        parallel_config: The configuration related to distributed execution.
        scheduler_config: The configuration related to the request scheduler.
        device_config: The configuration related to the device.
        lora_config: The configuration related to LoRA, or None when LoRA
            is not enabled.
        executor_class: The model executor class for managing distributed
            execution.
        log_stats: Whether to log statistics.
    """
56
57
58
59
60
61
62

    def __init__(
        self,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        lora_config: Optional[LoRAConfig],
        executor_class: Type[ExecutorBase],
        log_stats: bool,
    ) -> None:
        """Store the configs and build the tokenizer, executor and scheduler.

        Args:
            model_config: The configuration related to the LLM model.
            cache_config: The configuration related to the KV cache.
            parallel_config: The configuration related to distributed
                execution.
            scheduler_config: The configuration related to the scheduler.
            device_config: The configuration related to the device.
            lora_config: The LoRA configuration, or None.
            executor_class: Class instantiated with the configs above to
                run the model.
            log_stats: When True, a `StatLogger` is created and stored as
                `self.stat_logger`; otherwise that attribute is not set.
        """
        logger.info(
            f"Initializing an LLM engine (v{vllm.__version__}) with config: "
            f"model={model_config.model!r}, "
            f"tokenizer={model_config.tokenizer!r}, "
            f"tokenizer_mode={model_config.tokenizer_mode}, "
            f"revision={model_config.revision}, "
            f"tokenizer_revision={model_config.tokenizer_revision}, "
            f"trust_remote_code={model_config.trust_remote_code}, "
            f"dtype={model_config.dtype}, "
            f"max_seq_len={model_config.max_model_len}, "
            f"download_dir={model_config.download_dir!r}, "
            f"load_format={model_config.load_format}, "
            f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
            f"disable_custom_all_reduce="
            f"{parallel_config.disable_custom_all_reduce}, "
            f"quantization={model_config.quantization}, "
            f"enforce_eager={model_config.enforce_eager}, "
            f"kv_cache_dtype={cache_config.cache_dtype}, "
            f"device_config={device_config.device}, "
            f"seed={model_config.seed})")
        # TODO(woosuk): Print more configs in debug mode.

        self.model_config = model_config
        self.cache_config = cache_config
        self.lora_config = lora_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.log_stats = log_stats
        # Cross-check the stored configs against each other before any
        # component is constructed.
        self._verify_args()

        # Sets `self.tokenizer` (a tokenizer group).
        self._init_tokenizer()
        # Source of sequence IDs; consumed with `next(...)` when sequences
        # are created or forked.
        self.seq_counter = Counter()

        self.model_executor = executor_class(model_config, cache_config,
                                             parallel_config, scheduler_config,
                                             device_config, lora_config)

        # Ping the tokenizer to ensure liveness if it runs in a
        # different process.
        self.tokenizer.ping()

        # Create the scheduler.
        # NOTE: the cache_config here have been updated with the numbers of
        # GPU and CPU blocks, which are profiled in the distributed executor.
        self.scheduler = Scheduler(scheduler_config, cache_config, lora_config)

        # Metric Logging.
        if self.log_stats:
            self.stat_logger = StatLogger(
                local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
                labels=dict(model_name=model_config.model))
            self.stat_logger.info("cache_config", self.cache_config)
121

122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
    @classmethod
    def from_engine_args(cls, engine_args: EngineArgs) -> "LLMEngine":
        """Build an engine from an :class:`EngineArgs` instance.

        The executor class is chosen from the parallel config: the Ray
        executor when `worker_use_ray` is set, the single-process GPU
        executor otherwise (which requires `world_size == 1`).
        """
        configs = engine_args.create_engine_configs()
        # The parallel config is the third entry of the returned configs.
        parallel_cfg = configs[2]

        if not parallel_cfg.worker_use_ray:
            assert parallel_cfg.world_size == 1, (
                "Ray is required if parallel_config.world_size > 1.")
            from vllm.executor.gpu_executor import GPUExecutor
            chosen_executor = GPUExecutor
        else:
            # The cluster must be initialized before the Ray executor is
            # imported and used.
            initialize_ray_cluster(parallel_cfg)
            from vllm.executor.ray_gpu_executor import RayGPUExecutor
            chosen_executor = RayGPUExecutor

        return cls(*configs,
                   executor_class=chosen_executor,
                   log_stats=not engine_args.disable_log_stats)
145

146
147
148
149
150
    def __reduce__(self):
        # This is to ensure that the LLMEngine is not referenced in
        # the closure used to initialize Ray worker actors
        raise RuntimeError("LLMEngine should not be pickled!")

151
152
153
154
155
    def get_tokenizer(self) -> "PreTrainedTokenizer":
        return self.tokenizer.get_lora_tokenizer()

    def get_tokenizer_for_seq(self,
                              sequence: Sequence) -> "PreTrainedTokenizer":
156
157
158
159
        return self.tokenizer.get_lora_tokenizer(sequence.lora_request)

    def _init_tokenizer(self, **tokenizer_init_kwargs):
        """Create the tokenizer group and store it as `self.tokenizer`.

        Default keyword arguments are derived from the model, LoRA and
        scheduler configs; anything passed in `tokenizer_init_kwargs`
        overrides them. A warning is logged when the tokenizer length
        differs from the model's vocabulary size.
        """
        model_cfg = self.model_config
        group_kwargs = {
            "tokenizer_id": model_cfg.tokenizer,
            "enable_lora": bool(self.lora_config),
            "max_num_seqs": self.scheduler_config.max_num_seqs,
            "max_input_length": None,
            "tokenizer_mode": model_cfg.tokenizer_mode,
            "trust_remote_code": model_cfg.trust_remote_code,
            "revision": model_cfg.tokenizer_revision,
        }
        # Caller-supplied values take precedence over the defaults.
        group_kwargs.update(tokenizer_init_kwargs)

        self.tokenizer: BaseTokenizerGroup = get_tokenizer_group(
            self.parallel_config.tokenizer_pool_config, **group_kwargs)

        if len(self.get_tokenizer()) == self.model_config.get_vocab_size():
            return
        logger.warning(
            f"The tokenizer's vocabulary size {len(self.get_tokenizer())}"
            f" does not match the model's vocabulary size "
            f"{self.model_config.get_vocab_size()}. This might "
            f"cause an error in decoding. Please change config.json "
            "to match the tokenizer's vocabulary size.")

180
181
    def _verify_args(self) -> None:
        self.model_config.verify_with_parallel_config(self.parallel_config)
182
        self.cache_config.verify_with_parallel_config(self.parallel_config)
183
184
185
186
        if self.lora_config:
            self.lora_config.verify_with_model_config(self.model_config)
            self.lora_config.verify_with_scheduler_config(
                self.scheduler_config)
187

188
189
190
191
192
193
194
195
196
197
198
199
200
201
    def encode_request(
        self,
        request_id: str,  # pylint: disable=unused-argument
        prompt: Optional[str],
        prompt_token_ids: Optional[List[int]] = None,
        lora_request: Optional[LoRARequest] = None,
    ):
        if prompt_token_ids is None:
            assert prompt is not None
            prompt_token_ids = self.tokenizer.encode(request_id=request_id,
                                                     prompt=prompt,
                                                     lora_request=lora_request)
        return prompt_token_ids

202
203
204
    def add_request(
        self,
        request_id: str,
Woosuk Kwon's avatar
Woosuk Kwon committed
205
        prompt: Optional[str],
206
207
208
        sampling_params: SamplingParams,
        prompt_token_ids: Optional[List[int]] = None,
        arrival_time: Optional[float] = None,
209
        lora_request: Optional[LoRARequest] = None,
210
    ) -> None:
Zhuohan Li's avatar
Zhuohan Li committed
211
        """Add a request to the engine's request pool.
212
213

        The request is added to the request pool and will be processed by the
Zhuohan Li's avatar
Zhuohan Li committed
214
        scheduler as `engine.step()` is called. The exact scheduling policy is
215
216
217
218
219
220
221
222
223
224
        determined by the scheduler.

        Args:
            request_id: The unique ID of the request.
            prompt: The prompt string. Can be None if prompt_token_ids is
                provided.
            sampling_params: The sampling parameters for text generation.
            prompt_token_ids: The token IDs of the prompt. If None, we
                use the tokenizer to convert the prompts to token IDs.
            arrival_time: The arrival time of the request. If None, we use
225
                the current monotonic time.
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249

        Details:
            - Set arrival_time to the current time if it is None.
            - Set prompt_token_ids to the encoded prompt if it is None.
            - Create `best_of` number of :class:`~vllm.Sequence` objects.
            - Create a :class:`~vllm.SequenceGroup` object
              from the list of :class:`~vllm.Sequence`.
            - Add the :class:`~vllm.SequenceGroup` object to the scheduler.

        Example:
            >>> # initialize engine
            >>> engine = LLMEngine.from_engine_args(engine_args)
            >>> # set request arguments
            >>> example_prompt = "Who is the president of the United States?"
            >>> sampling_params = SamplingParams(temperature=0.0)
            >>> request_id = 0
            >>>
            >>> # add the request to the engine
            >>> engine.add_request(
            >>>    str(request_id),
            >>>    example_prompt,
            >>>    SamplingParams(temperature=0.0))
            >>> # continue the request processing
            >>> ...
250
        """
251
252
253
        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")
254
255
256
257
258
259
260
        max_logprobs = self.get_model_config().max_logprobs
        if (sampling_params.logprobs
                and sampling_params.logprobs > max_logprobs) or (
                    sampling_params.prompt_logprobs
                    and sampling_params.prompt_logprobs > max_logprobs):
            raise ValueError(f"Cannot request more than "
                             f"{max_logprobs} logprobs.")
261
        if arrival_time is None:
262
            arrival_time = time.time()
263
264
265
266
267
        prompt_token_ids = self.encode_request(
            request_id=request_id,
            prompt=prompt,
            prompt_token_ids=prompt_token_ids,
            lora_request=lora_request)
268
269
270

        # Create the sequences.
        block_size = self.cache_config.block_size
271
        seq_id = next(self.seq_counter)
272
273
        eos_token_id = self.tokenizer.get_lora_tokenizer(
            lora_request).eos_token_id
274
        seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
275
                       eos_token_id, lora_request)
276

277
278
279
        # Defensive copy of SamplingParams, which are used by the sampler,
        # this doesn't deep-copy LogitsProcessor objects
        sampling_params = sampling_params.clone()
280

281
        # Create the sequence group.
282
        seq_group = SequenceGroup(request_id, [seq], sampling_params,
283
                                  arrival_time, lora_request)
284
285
286
287

        # Add the sequence group to the scheduler.
        self.scheduler.add_seq_group(seq_group)

Antoni Baum's avatar
Antoni Baum committed
288
289
    def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
        """Aborts a request(s) with the given ID.
290
291

        Args:
Antoni Baum's avatar
Antoni Baum committed
292
            request_id: The ID(s) of the request to abort.
293
294
295
296
297
298
299
300
301
302
303

        Details:
            - Refer to the
              :meth:`~vllm.core.scheduler.Scheduler.abort_seq_group`
              from class :class:`~vllm.core.scheduler.Scheduler`.

        Example:
            >>> # initialize engine and add a request with request_id
            >>> request_id = str(0)
            >>> # abort the request
            >>> engine.abort_request(request_id)
304
        """
305
306
        self.scheduler.abort_seq_group(request_id)

307
308
309
310
    def get_model_config(self) -> ModelConfig:
        """Gets the model configuration."""
        return self.model_config

311
    def get_num_unfinished_requests(self) -> int:
312
        """Gets the number of unfinished requests."""
313
314
        return self.scheduler.get_num_unfinished_seq_groups()

315
    def has_unfinished_requests(self) -> bool:
316
        """Returns True if there are unfinished requests."""
317
318
        return self.scheduler.has_unfinished_seqs()

319
320
321
322
323
324
325
326
327
328
329
330
    def _check_beam_search_early_stopping(
        self,
        early_stopping: Union[bool, str],
        sampling_params: SamplingParams,
        best_running_seq: Sequence,
        current_worst_seq: Sequence,
    ) -> bool:
        assert sampling_params.use_beam_search
        length_penalty = sampling_params.length_penalty
        if early_stopping is True:
            return True

331
        current_worst_score = current_worst_seq.get_beam_search_score(
332
            length_penalty=length_penalty,
333
            eos_token_id=current_worst_seq.eos_token_id)
334
        if early_stopping is False:
335
            highest_attainable_score = best_running_seq.get_beam_search_score(
336
                length_penalty=length_penalty,
337
                eos_token_id=best_running_seq.eos_token_id)
338
339
340
341
342
343
344
345
346
347
348
349
350
        else:
            assert early_stopping == "never"
            if length_penalty > 0.0:
                # If length_penalty > 0.0, beam search will prefer longer
                # sequences. The highest attainable score calculation is
                # based on the longest possible sequence length in this case.
                max_possible_length = max(
                    best_running_seq.get_prompt_len() +
                    sampling_params.max_tokens,
                    self.scheduler_config.max_model_len)
                highest_attainable_score = (
                    best_running_seq.get_beam_search_score(
                        length_penalty=length_penalty,
351
                        eos_token_id=best_running_seq.eos_token_id,
352
353
354
355
356
357
358
359
                        seq_len=max_possible_length))
            else:
                # Otherwise, beam search will prefer shorter sequences. The
                # highest attainable score calculation is based on the current
                # sequence length.
                highest_attainable_score = (
                    best_running_seq.get_beam_search_score(
                        length_penalty=length_penalty,
360
                        eos_token_id=best_running_seq.eos_token_id))
361
362
        return current_worst_score >= highest_attainable_score

363
    def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
                                        outputs: SequenceGroupOutput) -> None:
        """Apply one step of sampler output to a sequence group.

        Decodes prompt logprobs when present, appends every sampled token
        to its parent sequence (forking the parent when it produced more
        than one sample), then decodes and stop-checks the resulting
        sequences. Without beam search the new sequences are simply added
        to the group; with beam search the finished and running candidates
        are ranked and at most `best_of` of each are kept.

        Args:
            seq_group: The sequence group that was scheduled this step.
            outputs: The sampler output for that group; its
                `prompt_logprobs` and `samples` attributes are read.
        """

        # Process prompt logprobs
        prompt_logprobs = outputs.prompt_logprobs
        if prompt_logprobs is not None:
            # We can pick any sequence for the prompt.
            seq = next(iter(seq_group.seqs_dict.values()))
            all_token_ids = seq.get_token_ids()
            for i, prompt_logprobs_for_token in enumerate(prompt_logprobs):
                # The logprobs at position i are decoded with only the
                # tokens before i as context.
                self._decode_logprobs(seq, seq_group.sampling_params,
                                      prompt_logprobs_for_token,
                                      all_token_ids[:i])
            seq_group.prompt_logprobs = prompt_logprobs

        # Process samples
        samples = outputs.samples
        parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
        # Captured before this step's samples are applied, so it holds
        # only sequences that finished in earlier steps.
        existing_finished_seqs = seq_group.get_finished_seqs()
        # Group the samples by the running sequence they extend.
        parent_child_dict = {
            parent_seq.seq_id: []
            for parent_seq in parent_seqs
        }
        for sample in samples:
            parent_child_dict[sample.parent_seq_id].append(sample)
        # List of (child, parent)
        child_seqs: List[Tuple[Sequence, Sequence]] = []

        # Process the child samples for each parent sequence
        for parent in parent_seqs:
            child_samples: List[SequenceOutput] = parent_child_dict[
                parent.seq_id]
            if len(child_samples) == 0:
                # This parent sequence has no children samples. Remove
                # the parent sequence from the sequence group since it will
                # not be used in the future iterations.
                parent.status = SequenceStatus.FINISHED_ABORTED
                seq_group.remove(parent.seq_id)
                self.scheduler.free_seq(parent)
                continue
            # Fork the parent sequence if there are multiple child samples.
            for child_sample in child_samples[:-1]:
                new_child_seq_id = next(self.seq_counter)
                child = parent.fork(new_child_seq_id)
                child.append_token_id(child_sample.output_token,
                                      child_sample.logprobs)
                child_seqs.append((child, parent))
            # Continue the parent sequence for the last child sample.
            # We reuse the parent sequence here to reduce redundant memory
            # copies, especially when using non-beam search sampling methods.
            last_child_sample = child_samples[-1]
            parent.append_token_id(last_child_sample.output_token,
                                   last_child_sample.logprobs)
            # A (parent, parent) pair marks the reused parent; later loops
            # tell it apart from forked children with `seq is parent`.
            child_seqs.append((parent, parent))

        for seq, _ in child_seqs:
            self._decode_sequence(seq, seq_group.sampling_params)
            self._check_stop(seq, seq_group.sampling_params)

        # Non-beam search case
        if not seq_group.sampling_params.use_beam_search:
            # For newly created child sequences, add them to the sequence group
            # and fork them in block manager if they are not finished.
            for seq, parent in child_seqs:
                if seq is not parent:
                    seq_group.add(seq)
                    if not seq.is_finished():
                        self.scheduler.fork_seq(parent, seq)

            # Free the finished and selected parent sequences' memory in block
            # manager. Keep them in the sequence group as candidate output.
            # NOTE: we need to fork the new sequences before freeing the
            # old sequences.
            for seq, parent in child_seqs:
                if seq is parent and seq.is_finished():
                    self.scheduler.free_seq(seq)
            return

        # Beam search case
        # Select the child sequences to keep in the sequence group.
        selected_child_seqs = []
        unselected_child_seqs = []
        beam_width = seq_group.sampling_params.best_of
        length_penalty = seq_group.sampling_params.length_penalty

        # Select the newly finished sequences with the highest scores
        # to replace existing finished sequences.
        # Tuple of (seq, parent, is_new)
        existing_finished_seqs = [(seq, None, False)
                                  for seq in existing_finished_seqs]
        new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs
                             if seq.is_finished()]
        all_finished_seqs = existing_finished_seqs + new_finished_seqs
        # Sort the finished sequences by their scores.
        all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score(
            length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
                               reverse=True)
        for seq, parent, is_new in all_finished_seqs[:beam_width]:
            if is_new:
                # A newly generated child sequence finishes and has a high
                # score, so we will add it into the sequence group.
                selected_child_seqs.append((seq, parent))
        for seq, parent, is_new in all_finished_seqs[beam_width:]:
            if is_new:
                # A newly generated child sequence finishes but has a low
                # score, so we will not add it into the sequence group.
                # Additionally, if this sequence is a continuation of a
                # parent sequence, we will need remove the parent sequence
                # from the sequence group.
                unselected_child_seqs.append((seq, parent))
            else:
                # An existing finished sequence has a low score, so we will
                # remove it from the sequence group.
                seq_group.remove(seq.seq_id)

        # select the top beam_width sequences from the running
        # sequences for the next iteration to continue the beam
        # search.
        running_child_seqs = [(seq, parent) for seq, parent in child_seqs
                              if not seq.is_finished()]
        # Sort the running sequences by their scores.
        running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score(
            length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
                                reverse=True)

        # Check if we can stop the beam search.
        if len(running_child_seqs) == 0:
            # No running sequences, stop the beam search.
            stop_beam_search = True
        elif len(all_finished_seqs) < beam_width:
            # Not enough finished sequences, continue the beam search.
            stop_beam_search = False
        else:
            # Check the early stopping criteria
            best_running_seq = running_child_seqs[0][0]
            # Both lists are sorted best-first, so index beam_width - 1 is
            # the lowest-scoring finished sequence that is still kept.
            current_worst_seq = all_finished_seqs[beam_width - 1][0]
            stop_beam_search = self._check_beam_search_early_stopping(
                seq_group.sampling_params.early_stopping,
                seq_group.sampling_params, best_running_seq, current_worst_seq)

        if stop_beam_search:
            # Stop the beam search and remove all the running sequences from
            # the sequence group.
            unselected_child_seqs.extend(running_child_seqs)
        else:
            # Continue the beam search and select the top beam_width sequences
            # to continue the beam search.
            selected_child_seqs.extend(running_child_seqs[:beam_width])
            # The remaining running sequences will not be used in the next
            # iteration. Again, if these sequences are continuations of
            # parent sequences, we will need to remove the parent sequences
            # from the sequence group.
            unselected_child_seqs.extend(running_child_seqs[beam_width:])

        # For newly created child sequences, add them to the sequence group
        # and fork them in block manager if they are not finished.
        for seq, parent in selected_child_seqs:
            if seq is not parent:
                seq_group.add(seq)
                if not seq.is_finished():
                    self.scheduler.fork_seq(parent, seq)

        # Free the finished and selected parent sequences' memory in block
        # manager. Keep them in the sequence group as candidate output.
        for seq, parent in selected_child_seqs:
            if seq is parent and seq.is_finished():
                self.scheduler.free_seq(seq)

        # Remove the unselected parent sequences from the sequence group and
        # free their memory in block manager.
        for seq, parent in unselected_child_seqs:
            if seq is parent:
                # Remove the parent sequence if it is not selected for next
                # iteration
                seq_group.remove(seq.seq_id)
                self.scheduler.free_seq(seq)

    def _process_model_outputs(
            self, output: SamplerOutput,
            scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]:
        """Fold the model output into the scheduled groups and build results.

        Args:
            output: Per-group sampler outputs, paired positionally with
                `scheduler_outputs.scheduled_seq_groups`.
            scheduler_outputs: The scheduler decision for this step.

        Returns:
            One `RequestOutput` per scheduled sequence group followed by
            one per ignored sequence group.
        """
        timestamp = time.time()
        scheduled = scheduler_outputs.scheduled_seq_groups

        # With prefix caching on, flag every scheduled group's blocks as
        # computed so that future requests don't attempt to recompute them.
        if self.cache_config.enable_prefix_caching:
            for group in scheduled:
                self.scheduler.mark_blocks_as_computed(group)

        # Update each scheduled group with its slice of the model output.
        for group, group_output in zip(scheduled, output):
            self._process_sequence_group_outputs(group, group_output)

        # Release groups that have finished.
        self.scheduler.free_finished_seq_groups()

        # Assemble the outputs: scheduled groups first, then ignored ones.
        results: List[RequestOutput] = []
        for group in scheduled:
            group.maybe_set_first_token_time(timestamp)
            results.append(RequestOutput.from_seq_group(group))
        results.extend(
            RequestOutput.from_seq_group(group)
            for group in scheduler_outputs.ignored_seq_groups)

        if self.log_stats:
            self.stat_logger.log(self._get_stats(scheduler_outputs))
        return results

Antoni Baum's avatar
Antoni Baum committed
574
575
576
    def step(self) -> List[RequestOutput]:
        """Performs one decoding iteration and returns newly generated results.

577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
        .. figure:: https://i.imgur.com/sv2HssD.png
            :alt: Overview of the step function
            :align: center

            Overview of the step function.

        Details:
            - Step 1: Schedules the sequences to be executed in the next
              iteration and the token blocks to be swapped in/out/copy.

                - Depending on the scheduling policy,
                  sequences may be `preempted/reordered`.
                - A Sequence Group (SG) refer to a group of sequences
                  that are generated from the same prompt.

592
            - Step 2: Calls the distributed executor to execute the model.
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
            - Step 3: Processes the model output. This mainly includes:

                - Decodes the relevant outputs.
                - Updates the scheduled sequence groups with model outputs
                  based on its `sampling parameters` (`use_beam_search` or not).
                - Frees the finished sequence groups.

            - Finally, it creates and returns the newly generated results.

        Example:
            >>> # Please see the example/ folder for more detailed examples.
            >>>
            >>> # initialize engine and request arguments
            >>> engine = LLMEngine.from_engine_args(engine_args)
            >>> example_inputs = [(0, "What is LLM?",
            >>>    SamplingParams(temperature=0.0))]
            >>>
            >>> # Start the engine with an event loop
            >>> while True:
            >>>     if example_inputs:
            >>>         req_id, prompt, sampling_params = example_inputs.pop(0)
            >>>         engine.add_request(str(req_id), prompt, sampling_params)
            >>>
            >>>     # continue the request processing
            >>>     request_outputs = engine.step()
            >>>     for request_output in request_outputs:
            >>>         if request_output.finished:
            >>>             # return or show the request output
            >>>
            >>>     if not (engine.has_unfinished_requests() or example_inputs):
            >>>         break
Antoni Baum's avatar
Antoni Baum committed
624
        """
625
        seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
Antoni Baum's avatar
Antoni Baum committed
626

627
        if not scheduler_outputs.is_empty():
628
629
630
631
            output = self.model_executor.execute_model(
                seq_group_metadata_list, scheduler_outputs.blocks_to_swap_in,
                scheduler_outputs.blocks_to_swap_out,
                scheduler_outputs.blocks_to_copy)
632
633
        else:
            output = []
Antoni Baum's avatar
Antoni Baum committed
634

635
        return self._process_model_outputs(output, scheduler_outputs)
Antoni Baum's avatar
Antoni Baum committed
636

637
    def do_log_stats(self) -> None:
        """Log a stats snapshot on demand, outside of a scheduling step.

        Does nothing unless `self.log_stats` is truthy. The snapshot is built
        without scheduler output, so it carries no per-iteration data.
        """
        if not self.log_stats:
            return
        snapshot = self._get_stats(scheduler_outputs=None)
        self.stat_logger.log(snapshot)
641

642
643
644
    def _get_stats(self,
                   scheduler_outputs: Optional[SchedulerOutputs]) -> Stats:
        """Get Stats to be Logged to Prometheus.

        Args:
            scheduler_outputs: Output of the scheduling step being reported.
                When None, only the cache usage and the scheduler queue sizes
                are filled in; per-iteration fields stay at zero / empty.

        Returns:
            A Stats object holding the current timestamp, the scheduler queue
            sizes, the GPU/CPU KV cache usage and the per-iteration token
            counts and latency samples.
        """
        now = time.time()
        block_manager = self.scheduler.block_manager

        # KV cache usage, expressed as the fraction of blocks in use
        # (1 - free / total). Both pools are guarded the same way: an unset
        # or zero block count reports 0 usage instead of failing on the
        # division.
        gpu_cache_usage = 0.
        num_total_gpu = self.cache_config.num_gpu_blocks
        if num_total_gpu is not None and num_total_gpu > 0:
            num_free_gpu = block_manager.get_num_free_gpu_blocks()
            gpu_cache_usage = 1.0 - (num_free_gpu / num_total_gpu)

        cpu_cache_usage = 0.
        num_total_cpu = self.cache_config.num_cpu_blocks
        if num_total_cpu is not None and num_total_cpu > 0:
            num_free_cpu = block_manager.get_num_free_cpu_blocks()
            cpu_cache_usage = 1.0 - (num_free_cpu / num_total_cpu)

        # Scheduler State
        num_running = len(self.scheduler.running)
        num_swapped = len(self.scheduler.swapped)
        num_waiting = len(self.scheduler.waiting)

        # Iteration stats if we have scheduler output.
        num_prompt_tokens = 0
        num_generation_tokens = 0
        time_to_first_tokens = []
        time_per_output_tokens = []
        time_e2e_requests = []
        if scheduler_outputs is not None:
            prompt_run = scheduler_outputs.prompt_run
            scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups

            # Number of Tokens.
            if prompt_run:
                # A prompt run counts every prompt token of each scheduled
                # group, plus one generated token per sequence in the group.
                num_prompt_tokens = sum(
                    len(seq_group.prompt_token_ids)
                    for seq_group in scheduled_seq_groups)
                num_generation_tokens = sum(
                    seq_group.num_seqs()
                    for seq_group in scheduled_seq_groups)
            else:
                num_generation_tokens = scheduler_outputs.num_batched_tokens

            # Latency Timings.
            time_last_iters = []
            for seq_group in scheduled_seq_groups:
                # Time since last token.
                # (n.b. updates seq_group.metrics.last_token_time)
                time_last_iters.append(seq_group.get_last_latency(now))
                # Time since arrival for all finished requests.
                if seq_group.is_finished():
                    time_e2e_requests.append(now -
                                             seq_group.metrics.arrival_time)

            # The same per-group latency sample is reported as time-to-first-
            # token on a prompt run and as time-per-output-token otherwise.
            time_to_first_tokens = time_last_iters if prompt_run else []
            time_per_output_tokens = [] if prompt_run else time_last_iters

        return Stats(
            now=now,
            num_running=num_running,
            num_swapped=num_swapped,
            num_waiting=num_waiting,
            gpu_cache_usage=gpu_cache_usage,
            cpu_cache_usage=cpu_cache_usage,
            num_prompt_tokens=num_prompt_tokens,
            num_generation_tokens=num_generation_tokens,
            time_to_first_tokens=time_to_first_tokens,
            time_per_output_tokens=time_per_output_tokens,
            time_e2e_requests=time_e2e_requests,
        )

712
713
714
715
716
717
718
719
    def _decode_logprobs(self, seq: Sequence, prms: SamplingParams,
                         logprobs: Dict[int, Logprob],
                         all_input_ids: List[int]) -> None:
        if not logprobs:
            return
        for token_id, sample_logprob in logprobs.items():
            if (sample_logprob.decoded_token is None and token_id != -1):
                all_input_ids_with_logprob = all_input_ids[:-1] + [token_id]
720
721
722
723
724
725
726
727
728
729
730
                (_, new_text, prefix_offset,
                 read_offset) = detokenize_incrementally(
                     self.get_tokenizer_for_seq(seq),
                     all_input_ids=all_input_ids_with_logprob,
                     prev_tokens=seq.tokens,
                     prefix_offset=seq.prefix_offset,
                     read_offset=seq.read_offset,
                     skip_special_tokens=prms.skip_special_tokens,
                     spaces_between_special_tokens=prms.
                     spaces_between_special_tokens,
                 )
731
732
                sample_logprob.decoded_token = new_text

733
    def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None:
        """Decodes the new token for a sequence.

        First decodes the most recent logprob entries of `seq`, then runs
        incremental detokenization and updates `seq.tokens`,
        `seq.prefix_offset`, `seq.read_offset` and `seq.output_text` in place.
        """
        token_ids = seq.get_token_ids()
        latest_logprobs = seq.output_logprobs[-1]
        self._decode_logprobs(seq, prms, latest_logprobs, token_ids)

        fresh_tokens, fresh_text, next_prefix, next_read = (
            detokenize_incrementally(
                self.get_tokenizer_for_seq(seq),
                all_input_ids=token_ids,
                prev_tokens=seq.tokens,
                prefix_offset=seq.prefix_offset,
                read_offset=seq.read_offset,
                skip_special_tokens=prms.skip_special_tokens,
                spaces_between_special_tokens=prms.
                spaces_between_special_tokens,
            ))

        # The token list is created on the first decode, extended afterwards.
        if seq.tokens is not None:
            seq.tokens.extend(fresh_tokens)
        else:
            seq.tokens = fresh_tokens
        seq.prefix_offset = next_prefix
        seq.read_offset = next_read
        seq.output_text += fresh_text
756
757
758

    def _check_stop(self, seq: Sequence,
                    sampling_params: SamplingParams) -> None:
759
        """Stop the finished sequences."""
760
761
        for stop_str in sampling_params.stop:
            if seq.output_text.endswith(stop_str):
762
                self._finalize_sequence(seq, sampling_params, stop_str)
763
764
                seq.status = SequenceStatus.FINISHED_STOPPED
                return
765
        if seq.get_last_token_id() in sampling_params.stop_token_ids:
766
767
768
            stop_str = self.get_tokenizer_for_seq(seq).convert_ids_to_tokens(
                seq.get_last_token_id())
            self._finalize_sequence(seq, sampling_params, stop_str)
769
770
            seq.status = SequenceStatus.FINISHED_STOPPED
            return
771
772
773
774
775
776
777
778
779
780
781
782

        # Check if the sequence has reached max_model_len.
        if seq.get_len() > self.scheduler_config.max_model_len:
            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
            return

        # Check if the sequence has reached max_tokens.
        if seq.get_output_len() == sampling_params.max_tokens:
            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
            return

        # Check if the sequence has generated the EOS token.
783
784
        if ((not sampling_params.ignore_eos)
                and seq.get_last_token_id() == seq.eos_token_id):
785
786
            seq.status = SequenceStatus.FINISHED_STOPPED
            return
787

788
789
790
    def _finalize_sequence(self, seq: Sequence,
                           sampling_params: SamplingParams,
                           stop_string: str) -> None:
791
792
793
794
        if sampling_params.include_stop_str_in_output:
            return

        if stop_string and seq.output_text.endswith(stop_string):
795
796
797
798
            # Truncate the output text so that the stop string is
            # not included in the output.
            seq.output_text = seq.output_text[:-len(stop_string)]

799
    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Forward a LoRA request to the model executor's `add_lora`.

        Returns:
            Whatever the model executor returns for the request.
        """
        executor = self.model_executor
        return executor.add_lora(lora_request)
801
802

    def remove_lora(self, lora_id: int) -> bool:
        """Forward a LoRA removal to the model executor's `remove_lora`.

        Returns:
            Whatever the model executor returns for `lora_id`.
        """
        executor = self.model_executor
        return executor.remove_lora(lora_id)
804
805

    def list_loras(self) -> List[int]:
        """Return the result of the model executor's `list_loras`."""
        executor = self.model_executor
        return executor.list_loras()
807
808

    def check_health(self) -> None:
        """Delegate the health check to the model executor."""
        executor = self.model_executor
        executor.check_health()