import time
import copy
from functools import partial
from typing import Any, List, Optional, TYPE_CHECKING

from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                         SchedulerConfig)
from vllm.core.scheduler import Scheduler
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.ray_utils import initialize_cluster, ray, RayWorker
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
                                               get_tokenizer)
from vllm.utils import Counter

if ray:
    from ray.air.util.torch_dist import init_torch_dist_process_group
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup

logger = init_logger(__name__)


class LLMEngine:
    """An LLM engine that receives requests and generates texts.

    This is the main class for the vLLM engine. It receives requests
    from clients and generates texts from the LLM. It includes a tokenizer, a
    language model (possibly distributed across multiple GPUs), and GPU memory
    space allocated for intermediate states (aka KV cache). This class utilizes
    iteration-level scheduling and efficient memory management to maximize the
    serving throughput.

    The `LLM` class wraps this class for offline batched inference and the
    `AsyncLLMEngine` class wraps this class for online serving.

    NOTE: The config arguments are derived from the `EngineArgs` class. For the
    comprehensive list of arguments, see `EngineArgs`.

    Args:
        model_config: The configuration related to the LLM model.
        cache_config: The configuration related to the KV cache memory
            management.
        parallel_config: The configuration related to distributed execution.
        scheduler_config: The configuration related to the request scheduler.
        distributed_init_method: The initialization method for distributed
            execution. See `torch.distributed.init_process_group` for details.
        placement_group: The Ray placement group used to schedule the
            distributed workers. May be None when Ray is not used.
        log_stats: Whether to log statistics.
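
    Example (an illustrative sketch; the specific `EngineArgs` and
    `SamplingParams` arguments below are assumptions, see those classes for
    the full argument sets):

        engine = LLMEngine.from_engine_args(
            EngineArgs(model="facebook/opt-125m"))
        engine.add_request("request-0", "Hello, my name is",
                           SamplingParams(max_tokens=16))
        while engine.has_unfinished_requests():
            for request_output in engine.step():
                if request_output.finished:
                    print(request_output.outputs[0].text)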
    """

    def __init__(
        self,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        distributed_init_method: str,
        placement_group: Optional["PlacementGroup"],
        log_stats: bool,
    ) -> None:
        logger.info(
            "Initializing an LLM engine with config: "
            f"model={model_config.model!r}, "
            f"tokenizer={model_config.tokenizer!r}, "
            f"tokenizer_mode={model_config.tokenizer_mode}, "
            f"trust_remote_code={model_config.trust_remote_code}, "
            f"dtype={model_config.dtype}, "
            f"use_dummy_weights={model_config.use_dummy_weights}, "
            f"download_dir={model_config.download_dir!r}, "
            f"use_np_weights={model_config.use_np_weights}, "
            f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
            f"seed={model_config.seed})")
        # TODO(woosuk): Print more configs in debug mode.

        self.model_config = model_config
        self.cache_config = cache_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.log_stats = log_stats
        self._verify_args()

        self.tokenizer = get_tokenizer(
            model_config.tokenizer,
            tokenizer_mode=model_config.tokenizer_mode,
            trust_remote_code=model_config.trust_remote_code)
        self.seq_counter = Counter()

        # Create the parallel GPU workers.
        if self.parallel_config.worker_use_ray:
            self._init_workers_ray(placement_group)
        else:
            self._init_workers(distributed_init_method)

        # Profile the memory usage and initialize the cache.
        self._init_cache()

        # Create the scheduler.
        self.scheduler = Scheduler(scheduler_config, cache_config, log_stats)

    def _init_workers(self, distributed_init_method: str):
        # Lazy import the Worker to avoid importing torch.cuda/xformers
        # before CUDA_VISIBLE_DEVICES is set in the Worker
        from vllm.worker.worker import Worker  # pylint: disable=import-outside-toplevel

        assert self.parallel_config.world_size == 1, (
            "Ray is required if parallel_config.world_size > 1.")

        self.workers: List[Worker] = []
        worker = Worker(
            self.model_config,
            self.parallel_config,
            self.scheduler_config,
            0,
            distributed_init_method,
        )
        self.workers.append(worker)
        self._run_workers(
            "init_model",
            get_all_outputs=True,
        )

    def _init_workers_ray(self, placement_group: "PlacementGroup"):
        # Lazy import the Worker to avoid importing torch.cuda/xformers
        # before CUDA_VISIBLE_DEVICES is set in the Worker
        from vllm.worker.worker import Worker  # pylint: disable=import-outside-toplevel

        self.workers: List[Worker] = []
        for bundle in placement_group.bundle_specs:
            if not bundle.get("GPU", 0):
                continue
            worker = ray.remote(
                num_cpus=0,
                num_gpus=1,
                scheduling_strategy=PlacementGroupSchedulingStrategy(
                    placement_group=placement_group,
                    placement_group_capture_child_tasks=True),
            )(RayWorker).remote()
            self.workers.append(worker)

        # Initialize torch distributed process group for the workers.
        init_torch_dist_process_group(self.workers, backend="nccl")
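        # Deep-copy the configs so that each remote worker is initialized
        # from its own copy rather than from references to the driver's
        # live config objects.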
        model_config = copy.deepcopy(self.model_config)
        parallel_config = copy.deepcopy(self.parallel_config)
        scheduler_config = copy.deepcopy(self.scheduler_config)
        self._run_workers("init_worker",
                          get_all_outputs=True,
                          worker_init_fn=lambda: Worker(
                              model_config,
                              parallel_config,
                              scheduler_config,
                              None,
                              None,
                          ))
        self._run_workers(
            "init_model",
            get_all_outputs=True,
        )

    def _verify_args(self) -> None:
        self.model_config.verify_with_parallel_config(self.parallel_config)
        self.cache_config.verify_with_parallel_config(self.parallel_config)

    def _init_cache(self) -> None:
        """Profiles the memory usage and initializes the KV cache."""
        # Get the maximum number of blocks that can be allocated on GPU and CPU.
        num_blocks = self._run_workers(
            "profile_num_available_blocks",
            get_all_outputs=True,
            block_size=self.cache_config.block_size,
            gpu_memory_utilization=self.cache_config.gpu_memory_utilization,
            cpu_swap_space=self.cache_config.swap_space_bytes,
        )

        # Since we use a shared centralized controller, we take the minimum
        # number of blocks across all workers to make sure all the memory
        # operations can be applied to all workers.
        num_gpu_blocks = min(b[0] for b in num_blocks)
        num_cpu_blocks = min(b[1] for b in num_blocks)
        # FIXME(woosuk): Change to debug log.
        logger.info(f"# GPU blocks: {num_gpu_blocks}, "
                    f"# CPU blocks: {num_cpu_blocks}")

        if num_gpu_blocks <= 0:
            raise ValueError("No available memory for the cache blocks. "
                             "Try increasing `gpu_memory_utilization` when "
                             "initializing the engine.")

        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

        # Initialize the cache.
        self._run_workers("init_cache_engine", cache_config=self.cache_config)

    @classmethod
    def from_engine_args(cls, engine_args: EngineArgs) -> "LLMEngine":
        """Creates an LLM engine from the engine arguments."""
        # Create the engine configs.
        engine_configs = engine_args.create_engine_configs()
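        # The configs are returned in the same positional order as
        # `LLMEngine.__init__` (model, cache, parallel, scheduler), so
        # index 2 is the ParallelConfig.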
        parallel_config = engine_configs[2]
        # Initialize the cluster.
        distributed_init_method, placement_group = initialize_cluster(
            parallel_config)
        # Create the LLM engine.
        engine = cls(*engine_configs,
                     distributed_init_method,
                     placement_group,
                     log_stats=not engine_args.disable_log_stats)
        return engine

    def add_request(
        self,
        request_id: str,
        prompt: Optional[str],
        sampling_params: SamplingParams,
        prompt_token_ids: Optional[List[int]] = None,
        arrival_time: Optional[float] = None,
    ) -> None:
        """Add a request to the engine's request pool.

        The request is added to the request pool and will be processed by the
        scheduler as `engine.step()` is called. The exact scheduling policy is
        determined by the scheduler.

        Args:
            request_id: The unique ID of the request.
            prompt: The prompt string. Can be None if prompt_token_ids is
                provided.
            sampling_params: The sampling parameters for text generation.
            prompt_token_ids: The token IDs of the prompt. If None, we
                use the tokenizer to convert the prompts to token IDs.
            arrival_time: The arrival time of the request. If None, we use
                the current time.
        """
        if arrival_time is None:
            arrival_time = time.time()
        if prompt_token_ids is None:
            assert prompt is not None
            prompt_token_ids = self.tokenizer.encode(prompt)

        # Create the sequences.
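        # One candidate sequence is created per `best_of` sample; all
        # candidates start from the same prompt tokens.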
        block_size = self.cache_config.block_size
        seqs: List[Sequence] = []
        for _ in range(sampling_params.best_of):
            seq_id = next(self.seq_counter)
            seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)
            seqs.append(seq)

        # Create the sequence group.
        seq_group = SequenceGroup(request_id, seqs, sampling_params,
                                  arrival_time)

        # Add the sequence group to the scheduler.
        self.scheduler.add_seq_group(seq_group)

    def abort_request(self, request_id: str) -> None:
        """Aborts a request with the given ID.

        Args:
            request_id: The ID of the request to abort.
        """
        self.scheduler.abort_seq_group(request_id)

    def get_model_config(self) -> ModelConfig:
        """Gets the model configuration."""
        return self.model_config

    def get_num_unfinished_requests(self) -> int:
        """Gets the number of unfinished requests."""
        return self.scheduler.get_num_unfinished_seq_groups()

    def has_unfinished_requests(self) -> bool:
        """Returns True if there are unfinished requests."""
        return self.scheduler.has_unfinished_seqs()

    def step(self) -> List[RequestOutput]:
        """Performs one decoding iteration and returns newly generated results.

        This function performs one decoding iteration of the engine. It first
        schedules the sequences to be executed in the next iteration and the
        token blocks to be swapped in/out/copied. Then, it executes the model
        and updates the scheduler with the model outputs. Finally, it decodes
        the sequences and returns the newly generated results.
        """
        (seq_group_metadata_list, scheduler_outputs,
         ignored_seq_groups) = self.scheduler.schedule()
        if ((not seq_group_metadata_list) and scheduler_outputs.is_empty()
                and (not ignored_seq_groups)):
            # Nothing to do.
            return []

        # Execute the model.
        output = self._run_workers(
            "execute_model",
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
            blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
            blocks_to_copy=scheduler_outputs.blocks_to_copy,
        )
        # Update the scheduler with the model outputs.
        seq_groups = self.scheduler.update(output)

        # Decode the sequences.
        self._decode_sequences(seq_groups)
        # Stop the sequences that meet the stopping criteria.
        self._stop_sequences(seq_groups)
        # Free the finished sequence groups.
        self.scheduler.free_finished_seq_groups()

        # Create the outputs.
        request_outputs: List[RequestOutput] = []
        for seq_group in seq_groups + ignored_seq_groups:
            request_output = RequestOutput.from_seq_group(seq_group)
            request_outputs.append(request_output)
        return request_outputs

    def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None:
        """Decodes the sequence outputs."""
        for seq_group in seq_groups:
            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                new_token, new_output_text = detokenize_incrementally(
                    self.tokenizer,
                    seq.output_tokens,
                    seq.get_last_token_id(),
                    skip_special_tokens=True,
                )
                if new_token is not None:
                    seq.output_tokens.append(new_token)
                    seq.output_text = new_output_text

    def _stop_sequences(self, seq_groups: List[SequenceGroup]) -> None:
        """Stop the finished sequences."""
        for seq_group in seq_groups:
            sampling_params = seq_group.sampling_params
            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                # Check if the sequence has generated a stop string.
                stopped = False
                for stop_str in sampling_params.stop:
                    if seq.output_text.endswith(stop_str):
                        # Truncate the output text so that the stop string is
                        # not included in the output.
                        seq.output_text = seq.output_text[:-len(stop_str)]
                        self.scheduler.free_seq(
                            seq, SequenceStatus.FINISHED_STOPPED)
                        stopped = True
                        break
                if stopped:
                    continue

                # Check if the sequence has reached max_model_len.
                if seq.get_len() > self.scheduler_config.max_model_len:
                    self.scheduler.free_seq(
                        seq, SequenceStatus.FINISHED_LENGTH_CAPPED)
                    continue
                # Check if the sequence has reached max_tokens.
                if seq.get_output_len() == sampling_params.max_tokens:
                    self.scheduler.free_seq(
                        seq, SequenceStatus.FINISHED_LENGTH_CAPPED)
                    continue
                # Check if the sequence has generated the EOS token.
                if not sampling_params.ignore_eos:
                    if seq.get_last_token_id() == self.tokenizer.eos_token_id:
                        self.scheduler.free_seq(
                            seq, SequenceStatus.FINISHED_STOPPED)
                        continue

    def _run_workers(
        self,
        method: str,
        *args,
        get_all_outputs: bool = False,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers."""
        all_outputs = []
        for worker in self.workers:
            if self.parallel_config.worker_use_ray:
                executor = partial(worker.execute_method.remote, method)
            else:
                executor = getattr(worker, method)

            output = executor(*args, **kwargs)
            all_outputs.append(output)

        if self.parallel_config.worker_use_ray:
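            # Block until every Ray worker has finished and replace the
            # ObjectRefs with the actual outputs.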
            all_outputs = ray.get(all_outputs)

        if get_all_outputs:
            return all_outputs

        # Make sure all workers have the same results.
        output = all_outputs[0]
        for other_output in all_outputs[1:]:
            assert output == other_output
        return output