import copy
import time
from functools import partial
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union

from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                         SchedulerConfig)
from vllm.core.scheduler import Scheduler, SchedulerOutputs
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.ray_utils import RayWorker, initialize_cluster, ray
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupMetadata,
                           SequenceStatus)
from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
                                               get_tokenizer)
from vllm.utils import Counter

if ray:
    from ray.air.util.torch_dist import init_torch_dist_process_group
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup

logger = init_logger(__name__)

_LOGGING_INTERVAL_SEC = 5


class LLMEngine:
    """An LLM engine that receives requests and generates texts.

    This is the main class for the vLLM engine. It receives requests
    from clients and generates texts from the LLM. It includes a tokenizer, a
    language model (possibly distributed across multiple GPUs), and GPU memory
    space allocated for intermediate states (aka KV cache). This class utilizes
    iteration-level scheduling and efficient memory management to maximize the
    serving throughput.

    The `LLM` class wraps this class for offline batched inference and the
    `AsyncLLMEngine` class wraps this class for online serving.

    NOTE: The config arguments are derived from the `EngineArgs` class. For the
    comprehensive list of arguments, see `EngineArgs`.

    Args:
        model_config: The configuration related to the LLM model.
        cache_config: The configuration related to the KV cache memory
            management.
        parallel_config: The configuration related to distributed execution.
        scheduler_config: The configuration related to the request scheduler.
        distributed_init_method: The initialization method for distributed
            execution. See `torch.distributed.init_process_group` for details.
        placement_group: The Ray placement group to use for distributed
            execution. Required when `worker_use_ray` is enabled.
        log_stats: Whether to log statistics.
    """
    def __init__(
        self,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        distributed_init_method: str,
        placement_group: Optional["PlacementGroup"],
        log_stats: bool,
    ) -> None:
        logger.info(
            "Initializing an LLM engine with config: "
            f"model={model_config.model!r}, "
            f"tokenizer={model_config.tokenizer!r}, "
            f"tokenizer_mode={model_config.tokenizer_mode}, "
            f"trust_remote_code={model_config.trust_remote_code}, "
            f"dtype={model_config.dtype}, "
            f"use_dummy_weights={model_config.use_dummy_weights}, "
            f"download_dir={model_config.download_dir!r}, "
            f"use_np_weights={model_config.use_np_weights}, "
            f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
            f"seed={model_config.seed})")
        # TODO(woosuk): Print more configs in debug mode.

        self.model_config = model_config
        self.cache_config = cache_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.log_stats = log_stats
        self._verify_args()

        self.tokenizer = get_tokenizer(
            model_config.tokenizer,
            tokenizer_mode=model_config.tokenizer_mode,
            trust_remote_code=model_config.trust_remote_code)
        self.seq_counter = Counter()

        # Create the parallel GPU workers.
        if self.parallel_config.worker_use_ray:
            self._init_workers_ray(placement_group)
        else:
            self._init_workers(distributed_init_method)

        # Profile the memory usage and initialize the cache.
        self._init_cache()

        # Create the scheduler.
        self.scheduler = Scheduler(scheduler_config, cache_config)

        # Logging.
        self.last_logging_time = 0.0
        # List of (timestamp, num_tokens)
        self.num_prompt_tokens: List[Tuple[float, int]] = []
        # List of (timestamp, num_tokens)
        self.num_generation_tokens: List[Tuple[float, int]] = []

    def _init_workers(self, distributed_init_method: str):
        # Lazy import the Worker to avoid importing torch.cuda/xformers
        # before CUDA_VISIBLE_DEVICES is set in the Worker
        from vllm.worker.worker import Worker  # pylint: disable=import-outside-toplevel

        assert self.parallel_config.world_size == 1, (
            "Ray is required if parallel_config.world_size > 1.")

        self.workers: List[Worker] = []
        worker = Worker(
            self.model_config,
            self.parallel_config,
            self.scheduler_config,
            0,
            distributed_init_method,
        )
        self.workers.append(worker)
        self._run_workers(
            "init_model",
            get_all_outputs=True,
        )

    def _init_workers_ray(self, placement_group: "PlacementGroup",
                          **ray_remote_kwargs):
        # Lazy import the Worker to avoid importing torch.cuda/xformers
        # before CUDA_VISIBLE_DEVICES is set in the Worker
        from vllm.worker.worker import Worker  # pylint: disable=import-outside-toplevel

        self.workers: List[Worker] = []
        for bundle in placement_group.bundle_specs:
            if not bundle.get("GPU", 0):
                continue
            worker = ray.remote(
                num_cpus=0,
                num_gpus=1,
                scheduling_strategy=PlacementGroupSchedulingStrategy(
                    placement_group=placement_group,
                    placement_group_capture_child_tasks=True),
                **ray_remote_kwargs,
            )(RayWorker).remote()
            self.workers.append(worker)

        # Initialize torch distributed process group for the workers.
        init_torch_dist_process_group(self.workers, backend="nccl")
        model_config = copy.deepcopy(self.model_config)
        parallel_config = copy.deepcopy(self.parallel_config)
        scheduler_config = copy.deepcopy(self.scheduler_config)
        self._run_workers("init_worker",
                          get_all_outputs=True,
                          worker_init_fn=lambda: Worker(
                              model_config,
                              parallel_config,
                              scheduler_config,
                              None,  # rank
                              None,  # distributed_init_method
                          ))
        self._run_workers(
            "init_model",
            get_all_outputs=True,
        )

    def _verify_args(self) -> None:
        self.model_config.verify_with_parallel_config(self.parallel_config)
        self.cache_config.verify_with_parallel_config(self.parallel_config)

    def _init_cache(self) -> None:
        """Profiles the memory usage and initializes the KV cache."""
        # Get the maximum number of blocks that can be allocated on GPU and CPU.
        num_blocks = self._run_workers(
            "profile_num_available_blocks",
            get_all_outputs=True,
            block_size=self.cache_config.block_size,
            gpu_memory_utilization=self.cache_config.gpu_memory_utilization,
            cpu_swap_space=self.cache_config.swap_space_bytes,
        )

        # Since we use a shared centralized controller, we take the minimum
        # number of blocks across all workers to make sure all the memory
        # operations can be applied to all workers.
        num_gpu_blocks = min(b[0] for b in num_blocks)
        num_cpu_blocks = min(b[1] for b in num_blocks)
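        # For example, if two workers report (1000, 2048) and (900, 2300)
        # available (GPU, CPU) blocks, the engine uses 900 GPU blocks and
        # 2048 CPU blocks (numbers here are illustrative only).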
        # FIXME(woosuk): Change to debug log.
        logger.info(f"# GPU blocks: {num_gpu_blocks}, "
                    f"# CPU blocks: {num_cpu_blocks}")

        if num_gpu_blocks <= 0:
            raise ValueError("No available memory for the cache blocks. "
                             "Try increasing `gpu_memory_utilization` when "
                             "initializing the engine.")

        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

        # Initialize the cache.
        self._run_workers("init_cache_engine", cache_config=self.cache_config)

    @classmethod
    def from_engine_args(cls, engine_args: EngineArgs) -> "LLMEngine":
        """Creates an LLM engine from the engine arguments."""
        # Create the engine configs.
        engine_configs = engine_args.create_engine_configs()
        parallel_config = engine_configs[2]
        # Initialize the cluster.
        distributed_init_method, placement_group = initialize_cluster(
            parallel_config)
        # Create the LLM engine.
        engine = cls(*engine_configs,
                     distributed_init_method,
                     placement_group,
                     log_stats=not engine_args.disable_log_stats)
        return engine

    def add_request(
        self,
        request_id: str,
        prompt: Optional[str],
        sampling_params: SamplingParams,
        prompt_token_ids: Optional[List[int]] = None,
        arrival_time: Optional[float] = None,
    ) -> None:
        """Add a request to the engine's request pool.

        The request is added to the request pool and will be processed by the
        scheduler as `engine.step()` is called. The exact scheduling policy is
        determined by the scheduler.

        Args:
            request_id: The unique ID of the request.
            prompt: The prompt string. Can be None if prompt_token_ids is
                provided.
            sampling_params: The sampling parameters for text generation.
            prompt_token_ids: The token IDs of the prompt. If None, we
                use the tokenizer to convert the prompts to token IDs.
            arrival_time: The arrival time of the request. If None, we use
                the current time.
        """
        if arrival_time is None:
            arrival_time = time.time()
        if prompt_token_ids is None:
            assert prompt is not None
            prompt_token_ids = self.tokenizer.encode(prompt)

        # Create the sequences.
        block_size = self.cache_config.block_size
        seqs: List[Sequence] = []
        for _ in range(sampling_params.best_of):
            seq_id = next(self.seq_counter)
            seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)
            seqs.append(seq)

        # Create the sequence group.
        seq_group = SequenceGroup(request_id, seqs, sampling_params,
                                  arrival_time)

        # Add the sequence group to the scheduler.
        self.scheduler.add_seq_group(seq_group)

    def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
        """Aborts a request(s) with the given ID.

        Args:
            request_id: The ID(s) of the request to abort.
        """
        self.scheduler.abort_seq_group(request_id)

    def get_model_config(self) -> ModelConfig:
        """Gets the model configuration."""
        return self.model_config

    def get_num_unfinished_requests(self) -> int:
        """Gets the number of unfinished requests."""
        return self.scheduler.get_num_unfinished_seq_groups()

    def has_unfinished_requests(self) -> bool:
        """Returns True if there are unfinished requests."""
        return self.scheduler.has_unfinished_seqs()

    def _schedule(
        self
    ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs,
               Optional[List[RequestOutput]]]:
        seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
        if scheduler_outputs.is_empty():
            return seq_group_metadata_list, scheduler_outputs, [
                RequestOutput.from_seq_group(seq_group)
                for seq_group in scheduler_outputs.ignored_seq_groups
            ]
        return seq_group_metadata_list, scheduler_outputs, None

    def _process_worker_outputs(
            self, output,
            scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]:
        # Update the scheduler with the model outputs.
        seq_groups = self.scheduler.update(output)

        # Decode the sequences.
        self._decode_sequences(seq_groups)
        # Stop the sequences that meet the stopping criteria.
        self._stop_sequences(seq_groups)
        # Free the finished sequence groups.
        self.scheduler.free_finished_seq_groups()

        # Create the outputs.
        request_outputs: List[RequestOutput] = []
        for seq_group in seq_groups + scheduler_outputs.ignored_seq_groups:
            request_output = RequestOutput.from_seq_group(seq_group)
            request_outputs.append(request_output)

        if self.log_stats:
            # Log the system stats.
            self._log_system_stats(scheduler_outputs.prompt_run,
                                   scheduler_outputs.num_batched_tokens)
        return request_outputs

    def step(self) -> List[RequestOutput]:
        """Performs one decoding iteration and returns newly generated results.

        This function performs one decoding iteration of the engine. It first
        schedules the sequences to be executed in the next iteration and the
        token blocks to be swapped in/out/copied. Then, it executes the model
        and updates the scheduler with the model outputs. Finally, it decodes
        the sequences and returns the newly generated results.
        """
        (seq_group_metadata_list, scheduler_outputs,
         early_return) = self._schedule()
        if early_return is not None:
            return early_return

        # Execute the model.
        output = self._run_workers(
            "execute_model",
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
            blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
            blocks_to_copy=scheduler_outputs.blocks_to_copy,
        )

        return self._process_worker_outputs(output, scheduler_outputs)

    def _log_system_stats(
        self,
        prompt_run: bool,
        num_batched_tokens: int,
    ) -> None:
        now = time.time()
        # Log the number of batched input tokens.
        if prompt_run:
            self.num_prompt_tokens.append((now, num_batched_tokens))
        else:
            self.num_generation_tokens.append((now, num_batched_tokens))

        elapsed_time = now - self.last_logging_time
        if elapsed_time < _LOGGING_INTERVAL_SEC:
            return

        # Discard the old stats.
        self.num_prompt_tokens = [(t, n) for t, n in self.num_prompt_tokens
                                  if now - t < _LOGGING_INTERVAL_SEC]
        self.num_generation_tokens = [(t, n)
                                      for t, n in self.num_generation_tokens
                                      if now - t < _LOGGING_INTERVAL_SEC]

        if len(self.num_prompt_tokens) > 1:
            total_num_tokens = sum(n for _, n in self.num_prompt_tokens[:-1])
            window = now - self.num_prompt_tokens[0][0]
            avg_prompt_throughput = total_num_tokens / window
        else:
            avg_prompt_throughput = 0.0
        if len(self.num_generation_tokens) > 1:
            total_num_tokens = sum(n
                                   for _, n in self.num_generation_tokens[:-1])
            window = now - self.num_generation_tokens[0][0]
            avg_generation_throughput = total_num_tokens / window
        else:
            avg_generation_throughput = 0.0

        total_num_gpu_blocks = self.cache_config.num_gpu_blocks
        num_free_gpu_blocks = (
            self.scheduler.block_manager.get_num_free_gpu_blocks())
        num_used_gpu_blocks = total_num_gpu_blocks - num_free_gpu_blocks
        gpu_cache_usage = num_used_gpu_blocks / total_num_gpu_blocks

        total_num_cpu_blocks = self.cache_config.num_cpu_blocks
        if total_num_cpu_blocks > 0:
            num_free_cpu_blocks = (
                self.scheduler.block_manager.get_num_free_cpu_blocks())
            num_used_cpu_blocks = total_num_cpu_blocks - num_free_cpu_blocks
            cpu_cache_usage = num_used_cpu_blocks / total_num_cpu_blocks
        else:
            cpu_cache_usage = 0.0

        logger.info("Avg prompt throughput: "
                    f"{avg_prompt_throughput:.1f} tokens/s, "
                    "Avg generation throughput: "
                    f"{avg_generation_throughput:.1f} tokens/s, "
                    f"Running: {len(self.scheduler.running)} reqs, "
                    f"Swapped: {len(self.scheduler.swapped)} reqs, "
                    f"Pending: {len(self.scheduler.waiting)} reqs, "
                    f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, "
                    f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
        self.last_logging_time = now

    def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None:
        """Decodes the sequence outputs."""
        for seq_group in seq_groups:
            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                new_token, new_output_text = detokenize_incrementally(
                    self.tokenizer,
                    seq.output_tokens,
                    seq.get_last_token_id(),
                    skip_special_tokens=True,
                )
                if new_token is not None:
                    seq.output_tokens.append(new_token)
                    seq.output_text = new_output_text

    def _stop_sequences(self, seq_groups: List[SequenceGroup]) -> None:
        """Stop the finished sequences."""
        for seq_group in seq_groups:
            sampling_params = seq_group.sampling_params
            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                # Check if the sequence has generated a stop string.
                stopped = False
                for stop_str in sampling_params.stop:
                    if seq.output_text.endswith(stop_str):
                        # Truncate the output text so that the stop string is
                        # not included in the output.
                        seq.output_text = seq.output_text[:-len(stop_str)]
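                        # (e.g. "Hi there###" with stop_str "###" becomes
                        # "Hi there"; values are illustrative.)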
                        self.scheduler.free_seq(
                            seq, SequenceStatus.FINISHED_STOPPED)
                        stopped = True
                        break
                if stopped:
                    continue

                # Check if the sequence has reached max_model_len.
                if seq.get_len() > self.scheduler_config.max_model_len:
                    self.scheduler.free_seq(
                        seq, SequenceStatus.FINISHED_LENGTH_CAPPED)
                    continue
                # Check if the sequence has reached max_tokens.
                if seq.get_output_len() == sampling_params.max_tokens:
                    self.scheduler.free_seq(
                        seq, SequenceStatus.FINISHED_LENGTH_CAPPED)
                    continue
                # Check if the sequence has generated the EOS token.
                if not sampling_params.ignore_eos:
                    if seq.get_last_token_id() == self.tokenizer.eos_token_id:
                        self.scheduler.free_seq(
                            seq, SequenceStatus.FINISHED_STOPPED)
                        continue

    def _run_workers(
        self,
        method: str,
        *args,
        get_all_outputs: bool = False,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers."""
        all_outputs = []
        for worker in self.workers:
            if self.parallel_config.worker_use_ray:
                executor = partial(worker.execute_method.remote, method)
            else:
                executor = getattr(worker, method)

            output = executor(*args, **kwargs)
            all_outputs.append(output)

        if self.parallel_config.worker_use_ray:
            all_outputs = ray.get(all_outputs)

        if get_all_outputs:
            return all_outputs

        # Make sure all workers have the same results.
        output = all_outputs[0]
        for other_output in all_outputs[1:]:
            assert output == other_output
        return output