llm_engine.py 44.5 KB
Newer Older
Fang li's avatar
Fang li committed
1
import copy
2
from collections import defaultdict
3
import os
Antoni Baum's avatar
Antoni Baum committed
4
import time
5
import pickle
6
7
from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple,
                    Union)
8

9
from vllm.lora.request import LoRARequest
10
11
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig, LoRAConfig)
Antoni Baum's avatar
Antoni Baum committed
12
from vllm.core.scheduler import Scheduler, SchedulerOutputs
Woosuk Kwon's avatar
Woosuk Kwon committed
13
from vllm.engine.arg_utils import EngineArgs
14
from vllm.engine.metrics import StatLogger, Stats
15
from vllm.engine.ray_utils import RayWorkerVllm, initialize_cluster, ray
Woosuk Kwon's avatar
Woosuk Kwon committed
16
17
18
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
19
from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
20
                           SequenceGroupOutput, SequenceOutput, SequenceStatus)
21
from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
22
                                               TokenizerGroup)
23
from vllm.utils import Counter, set_cuda_visible_devices, get_ip, get_open_port, get_distributed_init_method
24
25
26
27
28
29

if ray:
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup
30
31

logger = init_logger(__name__)
_LOCAL_LOGGING_INTERVAL_SEC = 5

# If the env var is set to a truthy value ("1", "true", "yes" or "on",
# case-insensitive), vLLM uses Ray's compiled DAG API, which reduces the
# control-plane overhead. Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to
# enable it.
# NOTE: the value is parsed explicitly because bool() of any non-empty string
# is True, so a plain bool(os.getenv(...)) would treat
# VLLM_USE_RAY_COMPILED_DAG=0 as "enabled".
USE_RAY_COMPILED_DAG = (os.getenv("VLLM_USE_RAY_COMPILED_DAG",
                                  "0").strip().lower()
                        in ("1", "true", "yes", "on"))

39

40
class LLMEngine:
Zhuohan Li's avatar
Zhuohan Li committed
41
    """An LLM engine that receives requests and generates texts.

    This is the main class for the vLLM engine. It receives requests
    from clients and generates texts from the LLM. It includes a tokenizer, a
    language model (possibly distributed across multiple GPUs), and GPU memory
    space allocated for intermediate states (aka KV cache). This class utilizes
    iteration-level scheduling and efficient memory management to maximize the
    serving throughput.

    The `LLM` class wraps this class for offline batched inference and the
    `AsyncLLMEngine` class wraps this class for online serving.

    NOTE: The config arguments are derived from the `EngineArgs` class. For the
    comprehensive list of arguments, see `EngineArgs`.

    Args:
        model_config: The configuration related to the LLM model.
        cache_config: The configuration related to the KV cache memory
            management.
        parallel_config: The configuration related to distributed execution.
        scheduler_config: The configuration related to the request scheduler.
        device_config: The configuration related to the device.
        lora_config: The configuration related to LoRA adapters, or None if
            LoRA is disabled.
        placement_group: Ray placement group for distributed execution.
            Required for distributed execution.
        log_stats: Whether to log statistics.
    """
67
68
69
70
71
72
73

    def __init__(
        self,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        lora_config: Optional[LoRAConfig],
        placement_group: Optional["PlacementGroup"],
        log_stats: bool,
    ) -> None:
        """Build every engine component from the given configs.

        Construction happens in a fixed order:
          1. Validate the configs against each other.
          2. Create the tokenizer group.
          3. Start the workers, through Ray when
             ``parallel_config.worker_use_ray`` is set, else in-process.
          4. Profile memory and allocate the KV cache.
          5. Create the scheduler, which reads the block counts written to
             ``cache_config`` in step 4.
          6. Create the optional stats logger.
          7. Build the compiled Ray DAG if ``USE_RAY_COMPILED_DAG`` is set.
        """
        logger.info(
            "Initializing an LLM engine with config: "
            f"model={model_config.model!r}, "
            f"tokenizer={model_config.tokenizer!r}, "
            f"tokenizer_mode={model_config.tokenizer_mode}, "
            f"revision={model_config.revision}, "
            f"tokenizer_revision={model_config.tokenizer_revision}, "
            f"trust_remote_code={model_config.trust_remote_code}, "
            f"dtype={model_config.dtype}, "
            f"max_seq_len={model_config.max_model_len}, "
            f"download_dir={model_config.download_dir!r}, "
            f"load_format={model_config.load_format}, "
            f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
            "disable_custom_all_reduce="
            f"{parallel_config.disable_custom_all_reduce}, "
            f"quantization={model_config.quantization}, "
            f"enforce_eager={model_config.enforce_eager}, "
            f"kv_cache_dtype={cache_config.cache_dtype}, "
            f"device_config={device_config.device}, "
            f"seed={model_config.seed})")
        # TODO(woosuk): Print more configs in debug mode.

        # Keep every config on the instance; the init helpers below read
        # them back through ``self``.
        self.model_config = model_config
        self.cache_config = cache_config
        self.lora_config = lora_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.log_stats = log_stats

        self._verify_args()
        self._init_tokenizer()
        self.seq_counter = Counter()

        # Create the parallel GPU workers.
        if not self.parallel_config.worker_use_ray:
            self._init_workers()
        else:
            # Opt out of Ray usage-stats collection unless the user has
            # explicitly enabled it.
            if os.environ.get("RAY_USAGE_STATS_ENABLED", "0") != "1":
                os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
            self._init_workers_ray(placement_group)

        # Size and allocate the KV cache first; the scheduler is built on
        # top of the resulting block budget.
        self._init_cache()
        self.scheduler = Scheduler(scheduler_config, cache_config, lora_config)

        if self.log_stats:
            self.stat_logger = StatLogger(
                local_interval=_LOCAL_LOGGING_INTERVAL_SEC)

        self.forward_dag = None
        if USE_RAY_COMPILED_DAG:
            self.forward_dag = self._compiled_ray_dag()

137
138
139
    def get_tokenizer_for_seq(self, sequence: Sequence):
        return self.tokenizer.get_lora_tokenizer(sequence.lora_request)

140
    def _init_workers(self):
        """Create the single in-process driver worker (non-Ray path).

        Only valid when ``parallel_config.world_size == 1``. The driver
        worker is rank 0 / local rank 0 and is then asked to initialize and
        load the model.
        """
        # Deferred import: importing the worker pulls in torch.cuda/xformers,
        # which must not happen before CUDA_VISIBLE_DEVICES is set.
        from vllm.worker.worker import Worker

        assert self.parallel_config.world_size == 1, (
            "Ray is required if parallel_config.world_size > 1.")

        self.workers: List[Worker] = []
        init_method = get_distributed_init_method(get_ip(), get_open_port())
        self.driver_worker = Worker(
            self.model_config,
            self.parallel_config,
            self.scheduler_config,
            self.device_config,
            local_rank=0,
            rank=0,
            distributed_init_method=init_method,
            lora_config=self.lora_config,
            kv_cache_dtype=self.cache_config.cache_dtype,
            is_driver_worker=True,
        )

        for stage in ("init_model", "load_model"):
            self._run_workers(stage)
165

166
167
168
169
170
171
172
173
174
175
176
177
    def _init_tokenizer(self, **tokenizer_init_kwargs):
        """Create ``self.tokenizer`` as a ``TokenizerGroup`` for the model.

        Defaults are derived from the engine configs; any keyword passed in
        ``tokenizer_init_kwargs`` overrides the matching default.
        """
        defaults = {
            "enable_lora": bool(self.lora_config),
            "max_num_seqs": self.scheduler_config.max_num_seqs,
            "max_input_length": None,
            "tokenizer_mode": self.model_config.tokenizer_mode,
            "trust_remote_code": self.model_config.trust_remote_code,
            "revision": self.model_config.tokenizer_revision,
        }
        merged_kwargs = {**defaults, **tokenizer_init_kwargs}
        self.tokenizer: TokenizerGroup = TokenizerGroup(
            self.model_config.tokenizer, **merged_kwargs)

Antoni Baum's avatar
Antoni Baum committed
178
179
    def _init_workers_ray(self, placement_group: "PlacementGroup",
                          **ray_remote_kwargs):
        """Start the Ray worker actors and the in-process driver worker.

        One ``RayWorkerVllm`` actor is created for every bundle of
        ``placement_group`` that reserves a GPU. The first actor placed on the
        driver's node becomes ``self.driver_dummy_worker``; it only holds that
        GPU on behalf of the driver process. All other actors are stored in
        ``self.workers``. The method then:

          * sets CUDA_VISIBLE_DEVICES on each node,
          * constructs a ``Worker`` inside each actor, and
          * builds the driver's own ``Worker`` locally as rank 0.

        Args:
            placement_group: Ray placement group whose GPU bundles host the
                workers.
            **ray_remote_kwargs: Extra options forwarded to ``ray.remote``
                when creating each worker actor.

        Raises:
            ValueError: If none of the GPU bundles lands on the driver's node.
        """
        if self.parallel_config.tensor_parallel_size == 1:
            # Without tensor parallelism each actor requests a fractional GPU
            # equal to the configured memory utilization.
            num_gpus = self.cache_config.gpu_memory_utilization
        else:
            num_gpus = 1

        self.driver_dummy_worker: Optional[RayWorkerVllm] = None
        self.workers: List[RayWorkerVllm] = []

        driver_ip = get_ip()
        for bundle_id, bundle in enumerate(placement_group.bundle_specs):
            if not bundle.get("GPU", 0):
                continue
            scheduling_strategy = PlacementGroupSchedulingStrategy(
                placement_group=placement_group,
                placement_group_capture_child_tasks=True,
                placement_group_bundle_index=bundle_id,
            )
            worker = ray.remote(
                num_cpus=0,
                num_gpus=num_gpus,
                scheduling_strategy=scheduling_strategy,
                **ray_remote_kwargs,
            )(RayWorkerVllm).remote(self.model_config.trust_remote_code)

            worker_ip = ray.get(worker.get_node_ip.remote())
            if worker_ip == driver_ip and self.driver_dummy_worker is None:
                # If the worker is on the same node as the driver, we use it
                # as the resource holder for the driver process.
                self.driver_dummy_worker = worker
            else:
                self.workers.append(worker)

        if self.driver_dummy_worker is None:
            raise ValueError(
                "Ray does not allocate any GPUs on the driver node. Consider "
                "adjusting the Ray placement group or running the driver on a "
                "GPU node.")

        driver_node_id, driver_gpu_ids = ray.get(
            self.driver_dummy_worker.get_node_and_gpu_ids.remote())
        worker_node_and_gpu_ids = ray.get(
            [worker.get_node_and_gpu_ids.remote() for worker in self.workers])

        # Group global ranks and GPU ids by node. The driver is rank 0; the
        # remote workers follow in ``self.workers`` order starting at rank 1.
        node_workers = defaultdict(list)
        node_gpus = defaultdict(list)

        node_workers[driver_node_id].append(0)
        node_gpus[driver_node_id].extend(driver_gpu_ids)
        for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids,
                                               start=1):
            node_workers[node_id].append(i)
            node_gpus[node_id].extend(gpu_ids)
        for node_id, gpu_ids in node_gpus.items():
            node_gpus[node_id] = sorted(gpu_ids)

        # Each process sees all GPUs used on its node: the driver directly,
        # the remote workers via an actor call.
        set_cuda_visible_devices(node_gpus[driver_node_id])
        for worker, (node_id, _) in zip(self.workers, worker_node_and_gpu_ids):
            worker.set_cuda_visible_devices.remote(node_gpus[node_id])

        distributed_init_method = get_distributed_init_method(
            driver_ip, get_open_port())

        # Lazy import the Worker to avoid importing torch.cuda/xformers
        # before CUDA_VISIBLE_DEVICES is set in the Worker
        from vllm.worker.worker import Worker

        # Snapshot everything the remote Worker factory needs into locals.
        # The lambda below is serialized and shipped to each Ray actor;
        # referencing ``self`` inside it would serialize the entire engine
        # (tokenizer, actor handles, ...) together with the closure.
        model_config = copy.deepcopy(self.model_config)
        parallel_config = copy.deepcopy(self.parallel_config)
        scheduler_config = copy.deepcopy(self.scheduler_config)
        device_config = copy.deepcopy(self.device_config)
        lora_config = copy.deepcopy(self.lora_config)
        kv_cache_dtype = self.cache_config.cache_dtype

        # Initialize torch distributed process group for the workers.
        for rank, (worker, (node_id, _)) in enumerate(
                zip(self.workers, worker_node_and_gpu_ids), start=1):
            local_rank = node_workers[node_id].index(rank)
            # Default arguments bind the loop's current rank/local_rank.
            worker.init_worker.remote(
                lambda rank=rank, local_rank=local_rank: Worker(
                    model_config,
                    parallel_config,
                    scheduler_config,
                    device_config,
                    local_rank,
                    rank,
                    distributed_init_method,
                    lora_config=lora_config,
                    kv_cache_dtype=kv_cache_dtype,
                ))

        driver_rank = 0
        driver_local_rank = node_workers[driver_node_id].index(driver_rank)
        self.driver_worker = Worker(
            model_config,
            parallel_config,
            scheduler_config,
            device_config,
            driver_local_rank,
            driver_rank,
            distributed_init_method,
            lora_config=self.lora_config,
            kv_cache_dtype=kv_cache_dtype,
            is_driver_worker=True,
        )

        self._run_workers("init_model", cupy_port=get_open_port())
        self._run_workers(
            "load_model",
            max_concurrent_workers=self.parallel_config.
            max_parallel_loading_workers,
        )
292

293
294
    def _verify_args(self) -> None:
        self.model_config.verify_with_parallel_config(self.parallel_config)
295
        self.cache_config.verify_with_parallel_config(self.parallel_config)
296
297
298
299
        if self.lora_config:
            self.lora_config.verify_with_model_config(self.model_config)
            self.lora_config.verify_with_scheduler_config(
                self.scheduler_config)
300
301

    def _init_cache(self) -> None:
        """Profile free memory, size the KV cache, and allocate it.

        Steps:
          1. Every worker runs ``profile_num_available_blocks``. Each reports
             how many GPU and CPU blocks fit in the memory left after
             profiling (see
             :meth:`~vllm.worker.worker.Worker.profile_num_available_blocks`).
          2. The engine keeps the minimum of each count across all workers,
             so the same block budget is valid on every worker.
          3. The budget is validated and stored on ``self.cache_config``.
             Each worker then allocates its cache engine and warms up the
             model.

        Raises:
            ValueError: If no GPU blocks are available, or if the GPU blocks
                cannot hold ``max_model_len`` tokens.

        .. tip::
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.
        """
        cache_cfg = self.cache_config
        per_worker_blocks = self._run_workers(
            "profile_num_available_blocks",
            block_size=cache_cfg.block_size,
            gpu_memory_utilization=cache_cfg.gpu_memory_utilization,
            cpu_swap_space=cache_cfg.swap_space_bytes,
            cache_dtype=cache_cfg.cache_dtype,
        )

        # A single centralized scheduler drives all workers, so the block
        # budget must be one every worker can honor: take the minimum of each.
        num_gpu_blocks = min(blocks[0] for blocks in per_worker_blocks)
        num_cpu_blocks = min(blocks[1] for blocks in per_worker_blocks)
        # FIXME(woosuk): Change to debug log.
        logger.info(f"# GPU blocks: {num_gpu_blocks}, "
                    f"# CPU blocks: {num_cpu_blocks}")

        if num_gpu_blocks <= 0:
            raise ValueError("No available memory for the cache blocks. "
                             "Try increasing `gpu_memory_utilization` when "
                             "initializing the engine.")

        # The GPU cache must be able to hold max_model_len tokens.
        max_seq_len = cache_cfg.block_size * num_gpu_blocks
        max_model_len = self.model_config.max_model_len
        if max_model_len > max_seq_len:
            raise ValueError(
                f"The model's max seq len ({max_model_len}) "
                "is larger than the maximum number of tokens that can be "
                f"stored in KV cache ({max_seq_len}). Try increasing "
                "`gpu_memory_utilization` or decreasing `max_model_len` when "
                "initializing the engine.")

        cache_cfg.num_gpu_blocks = num_gpu_blocks
        cache_cfg.num_cpu_blocks = num_cpu_blocks

        # Allocate the cache on every worker, then warm up the model. Warm-up
        # includes CUDA graph capture when enforce_eager is False.
        self._run_workers("init_cache_engine", cache_config=cache_cfg)
        self._run_workers("warm_up_model")
361

362
    @classmethod
    def from_engine_args(cls, engine_args: EngineArgs) -> "LLMEngine":
        """Build an engine from ``EngineArgs``.

        The configs are derived from the arguments, the cluster is initialized
        for the parallel config, and the engine is constructed with the
        resulting placement group.
        """
        configs = engine_args.create_engine_configs()
        # The configs are ordered like the constructor's leading positional
        # parameters, so the parallel config sits at index 2.
        placement_group = initialize_cluster(configs[2])
        return cls(*configs,
                   placement_group,
                   log_stats=not engine_args.disable_log_stats)
375

376
377
378
379
380
381
382
383
384
385
386
387
388
389
    def encode_request(
        self,
        request_id: str,  # pylint: disable=unused-argument
        prompt: Optional[str],
        prompt_token_ids: Optional[List[int]] = None,
        lora_request: Optional[LoRARequest] = None,
    ):
        if prompt_token_ids is None:
            assert prompt is not None
            prompt_token_ids = self.tokenizer.encode(request_id=request_id,
                                                     prompt=prompt,
                                                     lora_request=lora_request)
        return prompt_token_ids

390
391
392
    def add_request(
        self,
        request_id: str,
        prompt: Optional[str],
        sampling_params: SamplingParams,
        prompt_token_ids: Optional[List[int]] = None,
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        prefix_pos: Optional[int] = None,
    ) -> None:
        """Add a request to the engine's request pool.

        The request is added to the request pool and will be processed by the
        scheduler as `engine.step()` is called. The exact scheduling policy is
        determined by the scheduler.

        Args:
            request_id: The unique ID of the request.
            prompt: The prompt string. Can be None if prompt_token_ids is
                provided.
            sampling_params: The sampling parameters for text generation.
            prompt_token_ids: The token IDs of the prompt. If None, we
                use the tokenizer to convert the prompts to token IDs.
            arrival_time: The arrival time of the request. If None, we use
                the current monotonic time.
            lora_request: The LoRA request to apply to this request, if any.
                Requires the engine to have been created with a LoRA config.
            prefix_pos: If not None, we use the given position as the prefix
                position for each prompt. We will cache the prefix's KV
                cache and reuse it for the next request with the same prefix.
                This is an experimental feature, and may be replaced with
                automatic prefix caching in the future.

        Raises:
            ValueError: If ``lora_request`` is given but LoRA is not enabled.

        Details:
            - Set arrival_time to the current time if it is None.
            - Set prompt_token_ids to the encoded prompt if it is None.
            - Create a single :class:`~vllm.Sequence` for the prompt. Further
              sequences are forked from it later, while outputs are processed.
            - If ``prefix_pos`` is set, add (or look up) the prompt prefix in
              the scheduler's prefix pool.
            - Create a :class:`~vllm.SequenceGroup` object
              from the :class:`~vllm.Sequence`.
            - Add the :class:`~vllm.SequenceGroup` object to the scheduler.

        Example:
            >>> # initialize engine
            >>> engine = LLMEngine.from_engine_args(engine_args)
            >>> # set request arguments
            >>> example_prompt = "Who is the president of the United States?"
            >>> sampling_params = SamplingParams(temperature=0.0)
            >>> request_id = 0
            >>>
            >>> # add the request to the engine
            >>> engine.add_request(
            ...    str(request_id),
            ...    example_prompt,
            ...    sampling_params)
            >>> # continue the request processing
            >>> ...
        """
        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")
        if arrival_time is None:
            arrival_time = time.monotonic()
        prompt_token_ids = self.encode_request(
            request_id=request_id,
            prompt=prompt,
            prompt_token_ids=prompt_token_ids,
            lora_request=lora_request)

        # Create the (single) initial sequence.
        block_size = self.cache_config.block_size
        seq_id = next(self.seq_counter)
        seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
                       lora_request)

        # Register the shared prefix, if the request specifies one. Prefixes
        # are keyed by LoRA id, with 0 meaning "no LoRA".
        prefix = None
        if prefix_pos is not None:
            lora_int_id = lora_request.lora_int_id if lora_request else 0
            prefix = self.scheduler.prefix_pool.add_or_get_prefix(
                prompt_token_ids[:prefix_pos], lora_int_id)

        # Create the sequence group.
        seq_group = SequenceGroup(request_id, [seq], sampling_params,
                                  arrival_time, lora_request, prefix)

        # Add the sequence group to the scheduler.
        self.scheduler.add_seq_group(seq_group)

Antoni Baum's avatar
Antoni Baum committed
474
475
    def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
        """Aborts a request(s) with the given ID.
476
477

        Args:
Antoni Baum's avatar
Antoni Baum committed
478
            request_id: The ID(s) of the request to abort.
479
480
481
482
483
484
485
486
487
488
489

        Details:
            - Refer to the
              :meth:`~vllm.core.scheduler.Scheduler.abort_seq_group`
              from class :class:`~vllm.core.scheduler.Scheduler`.

        Example:
            >>> # initialize engine and add a request with request_id
            >>> request_id = str(0)
            >>> # abort the request
            >>> engine.abort_request(request_id)
490
        """
491
492
        self.scheduler.abort_seq_group(request_id)

493
494
495
496
    def get_model_config(self) -> ModelConfig:
        """Gets the model configuration."""
        return self.model_config

497
    def get_num_unfinished_requests(self) -> int:
498
        """Gets the number of unfinished requests."""
499
500
        return self.scheduler.get_num_unfinished_seq_groups()

501
    def has_unfinished_requests(self) -> bool:
502
        """Returns True if there are unfinished requests."""
503
504
        return self.scheduler.has_unfinished_seqs()

505
506
507
508
509
510
511
512
513
514
515
516
517
518
    def _check_beam_search_early_stopping(
        self,
        early_stopping: Union[bool, str],
        sampling_params: SamplingParams,
        best_running_seq: Sequence,
        current_worst_seq: Sequence,
    ) -> bool:
        assert sampling_params.use_beam_search
        length_penalty = sampling_params.length_penalty
        if early_stopping is True:
            return True

        current_worst_score = (current_worst_seq.get_beam_search_score(
            length_penalty=length_penalty,
519
520
            eos_token_id=self.get_tokenizer_for_seq(
                current_worst_seq).eos_token_id))
521
522
523
        if early_stopping is False:
            highest_attainable_score = (best_running_seq.get_beam_search_score(
                length_penalty=length_penalty,
524
525
                eos_token_id=self.get_tokenizer_for_seq(
                    best_running_seq).eos_token_id))
526
527
528
529
530
531
532
533
534
535
536
537
538
        else:
            assert early_stopping == "never"
            if length_penalty > 0.0:
                # If length_penalty > 0.0, beam search will prefer longer
                # sequences. The highest attainable score calculation is
                # based on the longest possible sequence length in this case.
                max_possible_length = max(
                    best_running_seq.get_prompt_len() +
                    sampling_params.max_tokens,
                    self.scheduler_config.max_model_len)
                highest_attainable_score = (
                    best_running_seq.get_beam_search_score(
                        length_penalty=length_penalty,
539
540
                        eos_token_id=self.get_tokenizer_for_seq(
                            best_running_seq).eos_token_id,
541
542
543
544
545
546
547
548
                        seq_len=max_possible_length))
            else:
                # Otherwise, beam search will prefer shorter sequences. The
                # highest attainable score calculation is based on the current
                # sequence length.
                highest_attainable_score = (
                    best_running_seq.get_beam_search_score(
                        length_penalty=length_penalty,
549
550
                        eos_token_id=self.get_tokenizer_for_seq(
                            best_running_seq).eos_token_id))
551
552
        return current_worst_score >= highest_attainable_score

553
    def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
Zhuohan Li's avatar
Zhuohan Li committed
554
                                        outputs: SequenceGroupOutput) -> None:
555

556
557
558
559
560
561
562
        # Process prompt logprobs
        prompt_logprobs = outputs.prompt_logprobs
        if prompt_logprobs is not None:
            seq_group.prompt_logprobs = prompt_logprobs

        # Process samples
        samples = outputs.samples
563
564
565
566
567
568
569
570
571
572
573
574
575
        parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
        existing_finished_seqs = seq_group.get_finished_seqs()
        parent_child_dict = {
            parent_seq.seq_id: []
            for parent_seq in parent_seqs
        }
        for sample in samples:
            parent_child_dict[sample.parent_seq_id].append(sample)
        # List of (child, parent)
        child_seqs: List[Tuple[Sequence, Sequence]] = []

        # Process the child samples for each parent sequence
        for parent in parent_seqs:
Zhuohan Li's avatar
Zhuohan Li committed
576
            child_samples: List[SequenceOutput] = parent_child_dict[
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
                parent.seq_id]
            if len(child_samples) == 0:
                # This parent sequence has no children samples. Remove
                # the parent sequence from the sequence group since it will
                # not be used in the future iterations.
                parent.status = SequenceStatus.FINISHED_ABORTED
                seq_group.remove(parent.seq_id)
                self.scheduler.free_seq(parent)
                continue
            # Fork the parent sequence if there are multiple child samples.
            for child_sample in child_samples[:-1]:
                new_child_seq_id = next(self.seq_counter)
                child = parent.fork(new_child_seq_id)
                child.append_token_id(child_sample.output_token,
                                      child_sample.logprobs)
                child_seqs.append((child, parent))
            # Continue the parent sequence for the last child sample.
            # We reuse the parent sequence here to reduce redundant memory
            # copies, especially when using non-beam search sampling methods.
            last_child_sample = child_samples[-1]
            parent.append_token_id(last_child_sample.output_token,
                                   last_child_sample.logprobs)
            child_seqs.append((parent, parent))

        for seq, _ in child_seqs:
602
            self._decode_sequence(seq, seq_group.sampling_params)
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
            self._check_stop(seq, seq_group.sampling_params)

        # Non-beam search case
        if not seq_group.sampling_params.use_beam_search:
            # For newly created child sequences, add them to the sequence group
            # and fork them in block manager if they are not finished.
            for seq, parent in child_seqs:
                if seq is not parent:
                    seq_group.add(seq)
                    if not seq.is_finished():
                        self.scheduler.fork_seq(parent, seq)

            # Free the finished and selected parent sequences' memory in block
            # manager. Keep them in the sequence group as candidate output.
            # NOTE: we need to fork the new sequences before freeing the
            # old sequences.
            for seq, parent in child_seqs:
                if seq is parent and seq.is_finished():
                    self.scheduler.free_seq(seq)
            return

        # Beam search case
        # Select the child sequences to keep in the sequence group.
        selected_child_seqs = []
        unselected_child_seqs = []
        beam_width = seq_group.sampling_params.best_of
        length_penalty = seq_group.sampling_params.length_penalty

        # Select the newly finished sequences with the highest scores
        # to replace existing finished sequences.
        # Tuple of (seq, parent, is_new)
        existing_finished_seqs = [(seq, None, False)
                                  for seq in existing_finished_seqs]
        new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs
                             if seq.is_finished()]
        all_finished_seqs = existing_finished_seqs + new_finished_seqs
        # Sort the finished sequences by their scores.
        all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score(
            length_penalty=length_penalty,
642
            eos_token_id=self.get_tokenizer_for_seq(x[0]).eos_token_id),
                               reverse=True)
        for seq, parent, is_new in all_finished_seqs[:beam_width]:
            if is_new:
                # A newly generated child sequence finishes and has a high
                # score, so we will add it into the sequence group.
                selected_child_seqs.append((seq, parent))
        for seq, parent, is_new in all_finished_seqs[beam_width:]:
            if is_new:
                # A newly generated child sequence finishes but has a low
                # score, so we will not add it into the sequence group.
                # Additionally, if this sequence is a continuation of a
                # parent sequence, we will need remove the parent sequence
                # from the sequence group.
                unselected_child_seqs.append((seq, parent))
            else:
                # An existing finished sequence has a low score, so we will
                # remove it from the sequence group.
                seq_group.remove(seq.seq_id)

        # select the top beam_width sequences from the running
        # sequences for the next iteration to continue the beam
        # search.
        running_child_seqs = [(seq, parent) for seq, parent in child_seqs
                              if not seq.is_finished()]
        # Sort the running sequences by their scores.
        running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score(
            length_penalty=length_penalty,
            eos_token_id=self.get_tokenizer_for_seq(x[0]).eos_token_id),
                                reverse=True)

        # Check if we can stop the beam search.
        if len(running_child_seqs) == 0:
            # No running sequences, stop the beam search.
            stop_beam_search = True
        elif len(all_finished_seqs) < beam_width:
            # Not enough finished sequences, continue the beam search.
            stop_beam_search = False
        else:
            # Check the early stopping criteria
            best_running_seq = running_child_seqs[0][0]
            current_worst_seq = all_finished_seqs[beam_width - 1][0]
            stop_beam_search = self._check_beam_search_early_stopping(
                seq_group.sampling_params.early_stopping,
                seq_group.sampling_params, best_running_seq, current_worst_seq)

        if stop_beam_search:
            # Stop the beam search and remove all the running sequences from
            # the sequence group.
            unselected_child_seqs.extend(running_child_seqs)
        else:
            # Continue the beam search and select the top beam_width sequences
            # to continue the beam search.
            selected_child_seqs.extend(running_child_seqs[:beam_width])
            # The remaining running sequences will not be used in the next
            # iteration. Again, if these sequences are continuations of
            # parent sequences, we will need to remove the parent sequences
            # from the sequence group.
            unselected_child_seqs.extend(running_child_seqs[beam_width:])

        # For newly created child sequences, add them to the sequence group
        # and fork them in block manager if they are not finished.
        for seq, parent in selected_child_seqs:
            if seq is not parent:
                seq_group.add(seq)
                if not seq.is_finished():
                    self.scheduler.fork_seq(parent, seq)

        # Free the finished and selected parent sequences' memory in block
        # manager. Keep them in the sequence group as candidate output.
        for seq, parent in selected_child_seqs:
            if seq is parent and seq.is_finished():
                self.scheduler.free_seq(seq)

        # Remove the unselected parent sequences from the sequence group and
        # free their memory in block manager.
        for seq, parent in unselected_child_seqs:
            if seq is parent:
                # Remove the parent sequence if it is not selected for next
                # iteration
                seq_group.remove(seq.seq_id)
                self.scheduler.free_seq(seq)

    def _process_model_outputs(
            self, output: SamplerOutput,
            scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]:
        """Apply one step of model output and build the per-request results.

        Args:
            output: The sampler output of the driver worker. It is zipped
                with ``scheduler_outputs.scheduled_seq_groups``, so entry
                ``i`` is applied to the ``i``-th scheduled group. ``step``
                passes an empty list when nothing was scheduled.
            scheduler_outputs: The scheduling decision that produced
                ``output``.

        Returns:
            One ``RequestOutput`` per scheduled sequence group, followed by
            one per ignored sequence group.
        """
        scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups

        # Update the scheduled sequence groups with the model outputs.
        for seq_group, outputs in zip(scheduled_seq_groups, output):
            self._process_sequence_group_outputs(seq_group, outputs)

        # Free the finished sequence groups.
        self.scheduler.free_finished_seq_groups()

        # Create the outputs: scheduled groups first, then ignored ones.
        request_outputs: List[RequestOutput] = [
            RequestOutput.from_seq_group(seq_group)
            for seq_group in scheduled_seq_groups
        ]
        request_outputs.extend(
            RequestOutput.from_seq_group(seq_group)
            for seq_group in scheduler_outputs.ignored_seq_groups)

        # Update prefix state, now all the uncomputed prefixes are computed.
        for seq_group in scheduled_seq_groups:
            prefix = seq_group.prefix
            if (prefix is not None and prefix.allocated
                    and not prefix.computed):
                prefix.computed = True

        # Log stats.
        if self.log_stats:
            self.stat_logger.log(self._get_stats(scheduler_outputs))

        return request_outputs

    def step(self) -> List[RequestOutput]:
        """Performs one decoding iteration and returns newly generated results.

        .. figure:: https://i.imgur.com/sv2HssD.png
            :alt: Overview of the step function
            :align: center

            Overview of the step function.

        Details:
            - Step 1: Schedules the sequences to be executed in the next
              iteration and the token blocks to be swapped in/out/copy.

                - Depending on the scheduling policy,
                  sequences may be `preempted/reordered`.
                - A Sequence Group (SG) refers to a group of sequences
                  that are generated from the same prompt.

            - Step 2: Calls the workers to execute the model.
            - Step 3: Processes the model output. This mainly includes:

                - Decodes the relevant outputs.
                - Updates the scheduled sequence groups with model outputs
                  based on its `sampling parameters` (`use_beam_search` or not).
                - Frees the finished sequence groups.

            - Finally, it creates and returns the newly generated results.

        Returns:
            The ``RequestOutput`` objects built for this iteration (scheduled
            groups first, then ignored groups).

        Example:
            >>> # Please see the example/ folder for more detailed examples.
            >>>
            >>> # initialize engine and request arguments
            >>> engine = LLMEngine.from_engine_args(engine_args)
            >>> example_inputs = [(0, "What is LLM?",
            >>>    SamplingParams(temperature=0.0))]
            >>>
            >>> # Start the engine with an event loop
            >>> while True:
            >>>     if example_inputs:
            >>>         req_id, prompt, sampling_params = example_inputs.pop(0)
            >>>         engine.add_request(str(req_id), prompt, sampling_params)
            >>>
            >>>     # continue the request processing
            >>>     request_outputs = engine.step()
            >>>     for request_output in request_outputs:
            >>>         if request_output.finished:
            >>>             # return or show the request output
            >>>
            >>>     if not (engine.has_unfinished_requests() or example_inputs):
            >>>         break
        """
        seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()

        if not scheduler_outputs.is_empty():
            # Execute the model. The per-step arguments go to the driver
            # worker only (``driver_kwargs``); ``_run_workers`` invokes the
            # Ray workers' ``execute_model`` without them.
            all_outputs = self._run_workers(
                "execute_model",
                driver_kwargs={
                    "seq_group_metadata_list": seq_group_metadata_list,
                    "blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in,
                    "blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out,
                    "blocks_to_copy": scheduler_outputs.blocks_to_copy,
                },
                use_ray_compiled_dag=USE_RAY_COMPILED_DAG)

            # Only the driver worker returns the sampling results.
            output = all_outputs[0]
        else:
            # Nothing to run, but the outputs still have to be processed so
            # finished groups are freed and ignored requests are reported.
            output = []

        return self._process_model_outputs(output, scheduler_outputs)

    def do_log_stats(self) -> None:
        """Forced log when no requests active.

        No scheduler output is available here, so ``_get_stats`` is called
        with ``scheduler_outputs=None``: only engine-wide state (cache usage,
        queue sizes) is reported, while per-iteration token counts and
        latencies stay empty. Does nothing when ``self.log_stats`` is false.
        """
        if self.log_stats:
            self.stat_logger.log(self._get_stats(scheduler_outputs=None))

    def _get_stats(self,
                   scheduler_outputs: Optional[SchedulerOutputs]) -> Stats:
        """Get Stats to be Logged to Prometheus.

        Args:
            scheduler_outputs: The scheduling result of the step being
                logged, or ``None`` when no step ran (as in
                ``do_log_stats``). With ``None`` the token counts stay 0 and
                the latency lists stay empty.

        Returns:
            A ``Stats`` snapshot stamped with ``time.monotonic()``.

        Note:
            ``seq_group.get_last_latency(now)`` also updates each scheduled
            group's ``last_token_time``, so this call has a side effect on the
            scheduled sequence groups.
        """
        now = time.monotonic()

        # KV cache usage, as the fraction of blocks in use. Both pools are
        # guarded the same way so that a pool with no blocks reports 0.0
        # instead of raising ZeroDivisionError.
        num_total_gpu = self.cache_config.num_gpu_blocks
        gpu_cache_usage = 0.
        if num_total_gpu > 0:
            num_free_gpu = (
                self.scheduler.block_manager.get_num_free_gpu_blocks())
            gpu_cache_usage = 1.0 - (num_free_gpu / num_total_gpu)

        num_total_cpu = self.cache_config.num_cpu_blocks
        cpu_cache_usage = 0.
        if num_total_cpu > 0:
            num_free_cpu = (
                self.scheduler.block_manager.get_num_free_cpu_blocks())
            cpu_cache_usage = 1.0 - (num_free_cpu / num_total_cpu)

        # Scheduler State
        num_running = len(self.scheduler.running)
        num_swapped = len(self.scheduler.swapped)
        num_waiting = len(self.scheduler.waiting)

        # Iteration stats if we have scheduler output.
        num_prompt_tokens = 0
        num_generation_tokens = 0
        time_to_first_tokens: List[float] = []
        time_per_output_tokens: List[float] = []
        time_e2e_requests: List[float] = []
        if scheduler_outputs is not None:
            prompt_run = scheduler_outputs.prompt_run

            # Number of Tokens: the batch counts toward prompt tokens on a
            # prompt run and toward generation tokens otherwise.
            if prompt_run:
                num_prompt_tokens = scheduler_outputs.num_batched_tokens
            else:
                num_generation_tokens = scheduler_outputs.num_batched_tokens

            # Latency Timings.
            time_last_iters = []
            for seq_group in scheduler_outputs.scheduled_seq_groups:
                # Time since last token.
                # (n.b. updates seq_group.last_token_time)
                time_last_iters.append(seq_group.get_last_latency(now))
                # Time since arrival for all finished requests.
                if seq_group.is_finished():
                    time_e2e_requests.append(now - seq_group.arrival_time)

            # On a prompt run the latency is the time to first token;
            # otherwise it is the time per output token.
            if prompt_run:
                time_to_first_tokens = time_last_iters
            else:
                time_per_output_tokens = time_last_iters

        return Stats(
            now=now,
            num_running=num_running,
            num_swapped=num_swapped,
            num_waiting=num_waiting,
            gpu_cache_usage=gpu_cache_usage,
            cpu_cache_usage=cpu_cache_usage,
            num_prompt_tokens=num_prompt_tokens,
            num_generation_tokens=num_generation_tokens,
            time_to_first_tokens=time_to_first_tokens,
            time_per_output_tokens=time_per_output_tokens,
            time_e2e_requests=time_e2e_requests,
        )

    def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None:
        """Decodes the new token for a sequence.

        Runs incremental detokenization over ``seq``'s token ids, stores the
        updated detokenizer state (``tokens``, ``prefix_offset``,
        ``read_offset``) back on ``seq`` and appends the newly produced text
        to ``seq.output_text``.

        Args:
            seq: The sequence to detokenize; mutated in place.
            prms: Sampling parameters that control special-token handling
                (``skip_special_tokens``, ``spaces_between_special_tokens``).
        """
        (new_tokens, new_output_text, prefix_offset,
         read_offset) = detokenize_incrementally(
             self.get_tokenizer_for_seq(seq),
             all_input_ids=seq.get_token_ids(),
             prev_tokens=seq.tokens,
             prefix_offset=seq.prefix_offset,
             read_offset=seq.read_offset,
             skip_special_tokens=prms.skip_special_tokens,
             spaces_between_special_tokens=prms.spaces_between_special_tokens,
         )
        # ``seq.tokens`` may be None (no cached tokens yet); in that case the
        # returned list becomes the cache, otherwise it is extended.
        if seq.tokens is None:
            seq.tokens = new_tokens
        else:
            seq.tokens.extend(new_tokens)
        seq.prefix_offset = prefix_offset
        seq.read_offset = read_offset
        seq.output_text += new_output_text

    def _check_stop(self, seq: Sequence,
                    sampling_params: SamplingParams) -> None:
        """Stop the finished sequences.

        Checks, in order: the stop strings, the stop token ids, the model's
        maximum length, ``max_tokens`` and the EOS token. It sets
        ``seq.status`` for the first criterion that matches. Stop-string and
        stop-token matches also go through ``_finalize_sequence``, which may
        trim the matched text from ``seq.output_text``. If nothing matches,
        ``seq`` is left unchanged.
        """
        for stop_str in sampling_params.stop:
            if seq.output_text.endswith(stop_str):
                self._finalize_sequence(seq, sampling_params, stop_str)
                seq.status = SequenceStatus.FINISHED_STOPPED
                return

        last_token_id = seq.get_last_token_id()
        if last_token_id in sampling_params.stop_token_ids:
            # NOTE(review): convert_ids_to_tokens yields the raw vocabulary
            # token (e.g. with a tokenizer space marker, or a special token),
            # which is not necessarily the text detokenization appended to
            # output_text. Trimming len(stop_str) characters may then remove
            # the wrong amount of text -- confirm.
            stop_str = self.get_tokenizer_for_seq(seq).convert_ids_to_tokens(
                last_token_id)
            self._finalize_sequence(seq, sampling_params, stop_str)
            seq.status = SequenceStatus.FINISHED_STOPPED
            return

        # Check if the sequence has exceeded max_model_len.
        if seq.get_len() > self.scheduler_config.max_model_len:
            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
            return

        # Check if the sequence has reached max_tokens.
        if seq.get_output_len() == sampling_params.max_tokens:
            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
            return

        # Check if the sequence has generated the EOS token.
        if ((not sampling_params.ignore_eos) and last_token_id
                == self.get_tokenizer_for_seq(seq).eos_token_id):
            seq.status = SequenceStatus.FINISHED_STOPPED
            return

    def _finalize_sequence(self, seq: Sequence,
                           sampling_params: SamplingParams,
                           stop_string: Optional[str]) -> None:
        """Trim the matched stop string from ``seq.output_text`` if required.

        Args:
            seq: The sequence that just matched a stop condition; its
                ``output_text`` is modified in place.
            sampling_params: When ``include_stop_str_in_output`` is set, the
                text is kept as is.
            stop_string: The matched stop text. A falsy value (``None`` or
                ``""``) leaves the text untouched.
        """
        if not sampling_params.include_stop_str_in_output and stop_string:
            # Truncate the output text so that the stop string is
            # not included in the output. This assumes output_text ends
            # with stop_string.
            seq.output_text = seq.output_text[:-len(stop_string)]

    def add_lora(self, lora_request: LoRARequest) -> List[bool]:
        """Ask every worker to load ``lora_request``.

        Returns:
            The per-worker results from ``_run_workers``, driver worker
            first. This is always a non-empty list and therefore always
            truthy, so inspect its elements rather than the list itself.
        """
        assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
        return self._run_workers(
            "add_lora",
            lora_request=lora_request,
        )

    def remove_lora(self, lora_id: int) -> List[bool]:
        """Ask every worker to unload the LoRA with id ``lora_id``.

        Returns:
            The per-worker results from ``_run_workers``, driver worker
            first. This is always a non-empty list and therefore always
            truthy, so inspect its elements rather than the list itself.
        """
        assert lora_id > 0, "lora_id must be greater than 0."
        return self._run_workers(
            "remove_lora",
            lora_id=lora_id,
        )

    def list_loras(self) -> List[List[int]]:
        """Query every worker for its loaded LoRAs.

        Returns:
            The per-worker results from ``_run_workers``, driver worker
            first. Each element is what that worker's ``list_loras``
            returns (presumably a collection of LoRA ids -- confirm).
        """
        return self._run_workers("list_loras")

    def _run_workers(
        self,
        method: str,
        *args,
        driver_args: Optional[List[Any]] = None,
        driver_kwargs: Optional[Dict[str, Any]] = None,
        max_concurrent_workers: Optional[int] = None,
        use_ray_compiled_dag: bool = False,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers.

        The Ray workers are started first via ``.remote``. Then ``method`` is
        called directly on ``self.driver_worker``. Finally, the Ray results
        are collected.

        Args:
            method: Name of the worker method to call.
            *args: Positional arguments for the Ray workers, and for the
                driver unless ``driver_args`` is given.
            driver_args: Positional arguments used only for the driver.
            driver_kwargs: Keyword arguments used only for the driver.
            max_concurrent_workers: Not supported; a truthy value raises.
            use_ray_compiled_dag: Drive the Ray workers through the
                precompiled ``self.forward_dag``. That DAG is executed with a
                fixed dummy input, so ``method``, ``args`` and ``kwargs`` are
                NOT forwarded to the Ray workers on this path.
            **kwargs: Keyword arguments for the Ray workers, and for the
                driver unless ``driver_kwargs`` is given.

        Returns:
            ``[driver_output] + ray_worker_outputs``; the driver's result is
            always first.

        Raises:
            NotImplementedError: If ``max_concurrent_workers`` is set.
        """
        if max_concurrent_workers:
            raise NotImplementedError(
                "max_concurrent_workers is not supported yet.")

        # Bound up front so the final return is well defined even when no
        # Ray results are collected (e.g. compiled DAG with no Ray workers).
        ray_worker_outputs: List[Any] = []
        if use_ray_compiled_dag:
            # Right now, compiled DAG can only accept a single
            # input. TODO(sang): Fix it.
            output_channels = self.forward_dag.execute(1)
        else:
            # Start the ray workers first.
            ray_worker_outputs = [
                worker.execute_method.remote(method, *args, **kwargs)
                for worker in self.workers
            ]

        if driver_args is None:
            driver_args = args
        if driver_kwargs is None:
            driver_kwargs = kwargs

        # Start the driver worker after all the ray workers.
        driver_worker_output = getattr(self.driver_worker,
                                       method)(*driver_args, **driver_kwargs)

        # Get the results of the ray workers.
        if self.workers:
            if use_ray_compiled_dag:
                try:
                    # The channels carry pickled results from our own Ray
                    # workers; never route untrusted data through here.
                    ray_worker_outputs = [
                        pickle.loads(chan.begin_read())
                        for chan in output_channels
                    ]
                finally:
                    # Has to call end_read in order to reuse the DAG.
                    for chan in output_channels:
                        chan.end_read()
            else:
                ray_worker_outputs = ray.get(ray_worker_outputs)

        return [driver_worker_output] + ray_worker_outputs

    def _compiled_ray_dag(self):
        """Build and compile the Ray DAG that drives the Ray workers.

        Each worker's ``execute_model_compiled_dag_remote`` is bound to one
        shared input node and gathered into a ``MultiOutputNode``. The
        result of ``experimental_compile()`` is returned.

        Raises:
            ValueError: If the installed Ray is older than 2.9.
        """
        import pkg_resources
        required_version = "2.9"
        current_version = pkg_resources.get_distribution("ray").version
        # Compare parsed versions: a plain string comparison is
        # lexicographic and would wrongly treat "2.10.0" as older than "2.9".
        if (pkg_resources.parse_version(current_version) <
                pkg_resources.parse_version(required_version)):
            raise ValueError(f"Ray version {required_version} or greater is "
                             f"required, but found {current_version}")

        from ray.dag import MultiOutputNode, InputNode
        assert self.parallel_config.worker_use_ray

        # Right now, compiled DAG requires at least 1 arg. We send
        # a dummy value for now. It will be fixed soon.
        with InputNode() as input_data:
            forward_dag = MultiOutputNode([
                worker.execute_model_compiled_dag_remote.bind(input_data)
                for worker in self.workers
            ])
        return forward_dag.experimental_compile()