import time
from functools import partial
from typing import Any, List, Optional, TYPE_CHECKING

from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                         SchedulerConfig)
from vllm.core.scheduler import Scheduler
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.ray_utils import initialize_cluster, ray, RayWorker
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
                                               get_tokenizer)
from vllm.utils import Counter

if ray:
    from ray.air.util.torch_dist import init_torch_dist_process_group
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup

logger = init_logger(__name__)


class LLMEngine:
    """An LLM engine that receives requests and generates texts.

    This is the main class for the vLLM engine. It receives requests
    from clients and generates texts from the LLM. It includes a tokenizer, a
    language model (possibly distributed across multiple GPUs), and GPU memory
    space allocated for intermediate states (aka KV cache). This class utilizes
    iteration-level scheduling and efficient memory management to maximize the
    serving throughput.

    The `LLM` class wraps this class for offline batched inference and the
    `AsyncLLMEngine` class wraps this class for online serving.

    NOTE: The config arguments are derived from the `EngineArgs` class. For the
    comprehensive list of arguments, see `EngineArgs`.

    Args:
        model_config: The configuration related to the LLM model.
        cache_config: The configuration related to the KV cache memory
            management.
        parallel_config: The configuration related to distributed execution.
        scheduler_config: The configuration related to the request scheduler.
        distributed_init_method: The initialization method for distributed
            execution. See `torch.distributed.init_process_group` for details.
        placement_group: The Ray placement group for the workers. Required
            when `worker_use_ray` is enabled in the parallel config.
        log_stats: Whether to log statistics.
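
    Example (a minimal offline sketch; the model name, prompt, and request ID
    below are illustrative, not defaults of this class):

        engine_args = EngineArgs(model="facebook/opt-125m")
        engine = LLMEngine.from_engine_args(engine_args)
        engine.add_request("0", "Hello, my name is",
                           SamplingParams(max_tokens=16))
        while engine.has_unfinished_requests():
            request_outputs = engine.step()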
    """

    def __init__(
        self,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        distributed_init_method: str,
        placement_group: Optional["PlacementGroup"],
        log_stats: bool,
    ) -> None:
        logger.info(
            "Initializing an LLM engine with config: "
            f"model={model_config.model!r}, "
            f"tokenizer={model_config.tokenizer!r}, "
            f"tokenizer_mode={model_config.tokenizer_mode}, "
            f"trust_remote_code={model_config.trust_remote_code}, "
            f"dtype={model_config.dtype}, "
            f"use_dummy_weights={model_config.use_dummy_weights}, "
            f"download_dir={model_config.download_dir!r}, "
            f"use_np_weights={model_config.use_np_weights}, "
            f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
            f"seed={model_config.seed})")
        # TODO(woosuk): Print more configs in debug mode.

        self.model_config = model_config
        self.cache_config = cache_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.log_stats = log_stats
        self._verify_args()

        self.tokenizer = get_tokenizer(
            model_config.tokenizer,
            tokenizer_mode=model_config.tokenizer_mode,
            trust_remote_code=model_config.trust_remote_code)
        self.seq_counter = Counter()

        # Create the parallel GPU workers.
        if self.parallel_config.worker_use_ray:
            self._init_workers_ray(placement_group)
        else:
            self._init_workers(distributed_init_method)

        # Profile the memory usage and initialize the cache.
        self._init_cache()

        # Create the scheduler.
        self.scheduler = Scheduler(scheduler_config, cache_config, log_stats)

    def _init_workers(self, distributed_init_method: str):
        # Lazy import the Worker to avoid importing torch.cuda/xformers
        # before CUDA_VISIBLE_DEVICES is set in the Worker
        from vllm.worker.worker import Worker  # pylint: disable=import-outside-toplevel

        assert self.parallel_config.world_size == 1, (
            "Ray is required if parallel_config.world_size > 1.")

        self.workers: List[Worker] = []
        worker = Worker(
            self.model_config,
            self.parallel_config,
            self.scheduler_config,
            0,
            distributed_init_method,
        )
        self.workers.append(worker)
        self._run_workers(
            "init_model",
            get_all_outputs=True,
        )

    def _init_workers_ray(self, placement_group: "PlacementGroup"):
        # Lazy import the Worker to avoid importing torch.cuda/xformers
        # before CUDA_VISIBLE_DEVICES is set in the Worker
        from vllm.worker.worker import Worker  # pylint: disable=import-outside-toplevel

        self.workers: List[Worker] = []
        for bundle in placement_group.bundle_specs:
            if not bundle.get("GPU", 0):
                continue
            worker = ray.remote(
                num_cpus=0,
                num_gpus=1,
                scheduling_strategy=PlacementGroupSchedulingStrategy(
                    placement_group=placement_group,
                    placement_group_capture_child_tasks=True),
            )(RayWorker).remote()
            self.workers.append(worker)

        # Initialize torch distributed process group for the workers.
        init_torch_dist_process_group(self.workers, backend="nccl")
        self._run_workers("init_worker",
                          get_all_outputs=True,
                          worker_init_fn=lambda: Worker(
                              self.model_config,
                              self.parallel_config,
                              self.scheduler_config,
                              None,
                              None,
                          ))
        self._run_workers(
            "init_model",
            get_all_outputs=True,
        )

    def _verify_args(self) -> None:
        self.model_config.verify_with_parallel_config(self.parallel_config)
        self.cache_config.verify_with_parallel_config(self.parallel_config)

    def _init_cache(self) -> None:
        """Profiles the memory usage and initializes the KV cache."""
        # Get the maximum number of blocks that can be allocated on GPU and CPU.
        num_blocks = self._run_workers(
            "profile_num_available_blocks",
            get_all_outputs=True,
            block_size=self.cache_config.block_size,
            gpu_memory_utilization=self.cache_config.gpu_memory_utilization,
            cpu_swap_space=self.cache_config.swap_space_bytes,
        )

        # Since we use a shared centralized controller, we take the minimum
        # number of blocks across all workers to make sure all the memory
        # operators can be applied to all workers.
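        # For example, if two workers report (1024, 2048) and (1000, 2048)
        # available (GPU, CPU) blocks, the engine uses min(1024, 1000) = 1000
        # GPU blocks and min(2048, 2048) = 2048 CPU blocks everywhere.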
        num_gpu_blocks = min(b[0] for b in num_blocks)
        num_cpu_blocks = min(b[1] for b in num_blocks)
        # FIXME(woosuk): Change to debug log.
        logger.info(f"# GPU blocks: {num_gpu_blocks}, "
                    f"# CPU blocks: {num_cpu_blocks}")

        if num_gpu_blocks <= 0:
            raise ValueError("No available memory for the cache blocks. "
                             "Try increasing `gpu_memory_utilization` when "
                             "initializing the engine.")

        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

        # Initialize the cache.
        self._run_workers("init_cache_engine", cache_config=self.cache_config)

    @classmethod
    def from_engine_args(cls, engine_args: EngineArgs) -> "LLMEngine":
        """Creates an LLM engine from the engine arguments."""
        # Create the engine configs.
        engine_configs = engine_args.create_engine_configs()
        parallel_config = engine_configs[2]
        # Initialize the cluster.
        distributed_init_method, placement_group = initialize_cluster(
            parallel_config)
        # Create the LLM engine.
        engine = cls(*engine_configs,
                     distributed_init_method,
                     placement_group,
                     log_stats=not engine_args.disable_log_stats)
        return engine

    def add_request(
        self,
        request_id: str,
        prompt: Optional[str],
        sampling_params: SamplingParams,
        prompt_token_ids: Optional[List[int]] = None,
        arrival_time: Optional[float] = None,
    ) -> None:
        """Add a request to the engine's request pool.

        The request is added to the request pool and will be processed by the
        scheduler as `engine.step()` is called. The exact scheduling policy is
        determined by the scheduler.

        Args:
            request_id: The unique ID of the request.
            prompt: The prompt string. Can be None if prompt_token_ids is
                provided.
            sampling_params: The sampling parameters for text generation.
            prompt_token_ids: The token IDs of the prompt. If None, we
                use the tokenizer to convert the prompts to token IDs.
            arrival_time: The arrival time of the request. If None, we use
                the current time.
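
        Example (a sketch; the request IDs, prompt, and token IDs are made up
        for illustration):

            # Pass a prompt string and let the engine tokenize it...
            engine.add_request("req-0", "The capital of France is",
                               SamplingParams(max_tokens=8))
            # ...or pass pre-tokenized IDs and set prompt to None.
            engine.add_request("req-1", None, SamplingParams(),
                               prompt_token_ids=[1, 450, 7483])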
        """
        if arrival_time is None:
            arrival_time = time.time()
        if prompt_token_ids is None:
            assert prompt is not None
            prompt_token_ids = self.tokenizer.encode(prompt)

        # Create the sequences.
        block_size = self.cache_config.block_size
        seqs: List[Sequence] = []
        for _ in range(sampling_params.best_of):
            seq_id = next(self.seq_counter)
            seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)
            seqs.append(seq)

        # Create the sequence group.
        seq_group = SequenceGroup(request_id, seqs, sampling_params,
                                  arrival_time)

        # Add the sequence group to the scheduler.
        self.scheduler.add_seq_group(seq_group)

    def abort_request(self, request_id: str) -> None:
        """Aborts a request with the given ID.

        Args:
            request_id: The ID of the request to abort.
        """
        self.scheduler.abort_seq_group(request_id)

    def get_model_config(self) -> ModelConfig:
        """Gets the model configuration."""
        return self.model_config

    def get_num_unfinished_requests(self) -> int:
        """Gets the number of unfinished requests."""
        return self.scheduler.get_num_unfinished_seq_groups()

    def has_unfinished_requests(self) -> bool:
        """Returns True if there are unfinished requests."""
        return self.scheduler.has_unfinished_seqs()

    def step(self) -> List[RequestOutput]:
        """Performs one decoding iteration and returns newly generated results.

        This function performs one decoding iteration of the engine. It first
        schedules the sequences to be executed in the next iteration and the
        token blocks to be swapped in/out/copied. Then, it executes the model
        and updates the scheduler with the model outputs. Finally, it decodes
        the sequences and returns the newly generated results.
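
        Example driver loop (a sketch using only this class's public API):

            while engine.has_unfinished_requests():
                for request_output in engine.step():
                    ...  # consume the newly generated RequestOutputs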
        """
        (seq_group_metadata_list, scheduler_outputs,
         ignored_seq_groups) = self.scheduler.schedule()
        if ((not seq_group_metadata_list) and scheduler_outputs.is_empty()
                and (not ignored_seq_groups)):
            # Nothing to do.
            return []

        # Execute the model.
        output = self._run_workers(
            "execute_model",
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
            blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
            blocks_to_copy=scheduler_outputs.blocks_to_copy,
        )
        # Update the scheduler with the model outputs.
        seq_groups = self.scheduler.update(output)

        # Decode the sequences.
        self._decode_sequences(seq_groups)
        # Stop the sequences that meet the stopping criteria.
        self._stop_sequences(seq_groups)
        # Free the finished sequence groups.
        self.scheduler.free_finished_seq_groups()

        # Create the outputs.
        request_outputs: List[RequestOutput] = []
        for seq_group in seq_groups + ignored_seq_groups:
            request_output = RequestOutput.from_seq_group(seq_group)
            request_outputs.append(request_output)
        return request_outputs

    def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None:
        """Decodes the sequence outputs."""
        for seq_group in seq_groups:
            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                new_token, new_output_text = detokenize_incrementally(
                    self.tokenizer,
                    seq.output_tokens,
                    seq.get_last_token_id(),
                    skip_special_tokens=True,
                )
                if new_token is not None:
                    seq.output_tokens.append(new_token)
                    seq.output_text = new_output_text

    def _stop_sequences(self, seq_groups: List[SequenceGroup]) -> None:
        """Stop the finished sequences."""
        for seq_group in seq_groups:
            sampling_params = seq_group.sampling_params
            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                # Check if the sequence has generated a stop string.
                stopped = False
                for stop_str in sampling_params.stop:
                    if seq.output_text.endswith(stop_str):
                        # Truncate the output text so that the stop string is
                        # not included in the output.
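                        # E.g. with stop=["\n"] and output_text "Paris\n",
                        # the truncated text is "Paris".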
                        seq.output_text = seq.output_text[:-len(stop_str)]
                        self.scheduler.free_seq(
                            seq, SequenceStatus.FINISHED_STOPPED)
                        stopped = True
                        break
                if stopped:
                    continue

                # Check if the sequence has reached max_model_len.
                if seq.get_len() > self.scheduler_config.max_model_len:
                    self.scheduler.free_seq(
                        seq, SequenceStatus.FINISHED_LENGTH_CAPPED)
                    continue
                # Check if the sequence has reached max_tokens.
                if seq.get_output_len() == sampling_params.max_tokens:
                    self.scheduler.free_seq(
                        seq, SequenceStatus.FINISHED_LENGTH_CAPPED)
                    continue
                # Check if the sequence has generated the EOS token.
                if not sampling_params.ignore_eos:
                    if seq.get_last_token_id() == self.tokenizer.eos_token_id:
                        self.scheduler.free_seq(
                            seq, SequenceStatus.FINISHED_STOPPED)
                        continue

    def _run_workers(
        self,
        method: str,
        *args,
        get_all_outputs: bool = False,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers."""
        all_outputs = []
        for worker in self.workers:
            if self.parallel_config.worker_use_ray:
                executor = partial(worker.execute_method.remote, method)
            else:
                executor = getattr(worker, method)

            output = executor(*args, **kwargs)
            all_outputs.append(output)

        if self.parallel_config.worker_use_ray:
            all_outputs = ray.get(all_outputs)

        if get_all_outputs:
            return all_outputs

        # Make sure all workers have the same results.
        output = all_outputs[0]
        for other_output in all_outputs[1:]:
            assert output == other_output
        return output