import time
from typing import Any, List, Optional

from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                         SchedulerConfig)
from vllm.core.scheduler import Scheduler
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.ray_utils import DeviceID, initialize_cluster, ray
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
                                               get_tokenizer)
from vllm.utils import Counter
from vllm.worker.worker import Worker

logger = init_logger(__name__)


class LLMEngine:
    """An LLM engine that receives requests and generates texts.

    This is the main class for the vLLM engine. It receives requests
    from clients and generates texts from the LLM. It includes a tokenizer, a
    language model (possibly distributed across multiple GPUs), and GPU memory
    space allocated for intermediate states (aka KV cache). This class utilizes
    iteration-level scheduling and efficient memory management to maximize the
    serving throughput.

    The `LLM` class wraps this class for offline batched inference and the
    `AsyncLLMEngine` class wraps this class for online serving.

    NOTE: The config arguments are derived from the `EngineArgs` class. For the
    comprehensive list of arguments, see `EngineArgs`.

    Args:
        model_config: The configuration related to the LLM model.
        cache_config: The configuration related to the KV cache memory
            management.
        parallel_config: The configuration related to distributed execution.
        scheduler_config: The configuration related to the request scheduler.
        distributed_init_method: The initialization method for distributed
            execution. See `torch.distributed.init_process_group` for details.
        stage_devices: The list of devices for each stage. Each stage is a list
            of (rank, node_resource, device) tuples.
        log_stats: Whether to log statistics.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        distributed_init_method: str,
        stage_devices: List[List[DeviceID]],
        log_stats: bool,
    ) -> None:
        logger.info(
            "Initializing an LLM engine with config: "
            f"model={model_config.model!r}, "
            f"tokenizer={model_config.tokenizer!r}, "
            f"tokenizer_mode={model_config.tokenizer_mode}, "
            f"dtype={model_config.dtype}, "
            f"use_dummy_weights={model_config.use_dummy_weights}, "
            f"download_dir={model_config.download_dir!r}, "
            f"use_np_weights={model_config.use_np_weights}, "
            f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
            f"seed={model_config.seed})")
        # TODO(woosuk): Print more configs in debug mode.

        self.model_config = model_config
        self.cache_config = cache_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.log_stats = log_stats
        self._verify_args()

        self.tokenizer = get_tokenizer(
            model_config.tokenizer, tokenizer_mode=model_config.tokenizer_mode)
        # Shared counter from which every new Sequence draws its ID.
        self.seq_counter = Counter()

        # Create the parallel GPU workers.
        self.workers: List[Worker] = []
        assert len(stage_devices) == 1, "Only support one stage for now."
        for rank, node_resource, _ in stage_devices[0]:
            worker_cls = Worker
            if self.parallel_config.worker_use_ray:
                # Turn the worker into a Ray actor that reserves one GPU and
                # no CPUs. The tiny custom-resource request presumably pins
                # the actor to the node advertising `node_resource`.
                worker_cls = ray.remote(
                    num_cpus=0,
                    num_gpus=1,
                    resources={node_resource: 1e-3},
                )(worker_cls).remote

            worker = worker_cls(
                model_config,
                parallel_config,
                scheduler_config,
                rank,
                distributed_init_method,
            )
            self.workers.append(worker)
        # Profile the memory usage and initialize the cache. This must run
        # before the scheduler is built because it fills in
        # `cache_config.num_gpu_blocks` / `num_cpu_blocks`.
        self._init_cache()

        # Create the scheduler.
        self.scheduler = Scheduler(scheduler_config, cache_config, log_stats)

    def _verify_args(self) -> None:
        """Checks the model and cache configs against the parallel config."""
        self.model_config.verify_with_parallel_config(self.parallel_config)
        self.cache_config.verify_with_parallel_config(self.parallel_config)

    def _init_cache(self) -> None:
        """Profiles the memory usage and initializes the KV cache.

        Raises:
            ValueError: If no GPU cache blocks can be allocated.
        """
        # Get the maximum number of blocks that can be allocated on GPU and CPU.
        num_blocks = self._run_workers(
            "profile_num_available_blocks",
            get_all_outputs=True,
            block_size=self.cache_config.block_size,
            gpu_memory_utilization=self.cache_config.gpu_memory_utilization,
            cpu_swap_space=self.cache_config.swap_space_bytes,
        )

        # Since we use a shared centralized controller, we take the minimum
        # number of blocks across all workers to make sure all the memory
        # operators can be applied to all workers.
        num_gpu_blocks = min(b[0] for b in num_blocks)
        num_cpu_blocks = min(b[1] for b in num_blocks)
        # FIXME(woosuk): Change to debug log.
        logger.info(f"# GPU blocks: {num_gpu_blocks}, "
                    f"# CPU blocks: {num_cpu_blocks}")

        if num_gpu_blocks <= 0:
            raise ValueError("No available memory for the cache blocks. "
                             "Try increasing `gpu_memory_utilization` when "
                             "initializing the engine.")

        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

        # Initialize the cache.
        self._run_workers("init_cache_engine", cache_config=self.cache_config)

    @classmethod
    def from_engine_args(cls, engine_args: EngineArgs) -> "LLMEngine":
        """Creates an LLM engine from the engine arguments."""
        # Create the engine configs.
        engine_configs = engine_args.create_engine_configs()
        # The third config returned is the ParallelConfig (see the positional
        # order of __init__'s parameters).
        parallel_config = engine_configs[2]
        # Initialize the cluster.
        distributed_init_method, devices = initialize_cluster(parallel_config)
        # Create the LLM engine.
        engine = cls(*engine_configs,
                     distributed_init_method,
                     devices,
                     log_stats=not engine_args.disable_log_stats)
        return engine

    def add_request(
        self,
        request_id: str,
        prompt: Optional[str],
        sampling_params: SamplingParams,
        prompt_token_ids: Optional[List[int]] = None,
        arrival_time: Optional[float] = None,
    ) -> None:
        """Add a request to the engine's request pool.

        The request is added to the request pool and will be processed by the
        scheduler as `engine.step()` is called. The exact scheduling policy is
        determined by the scheduler.

        Args:
            request_id: The unique ID of the request.
            prompt: The prompt string. Can be None if prompt_token_ids is
                provided.
            sampling_params: The sampling parameters for text generation.
            prompt_token_ids: The token IDs of the prompt. If None, we
                use the tokenizer to convert the prompts to token IDs.
            arrival_time: The arrival time of the request. If None, we use
                the current time.

        Raises:
            ValueError: If both `prompt` and `prompt_token_ids` are None.
        """
        if arrival_time is None:
            arrival_time = time.time()
        if prompt_token_ids is None:
            # An explicit check (not `assert`) so the validation survives
            # `python -O`.
            if prompt is None:
                raise ValueError(
                    "Either `prompt` or `prompt_token_ids` must be provided.")
            prompt_token_ids = self.tokenizer.encode(prompt)

        # Create the sequences.
        block_size = self.cache_config.block_size
        seqs: List[Sequence] = []
        # One candidate sequence per `best_of` sample.
        for _ in range(sampling_params.best_of):
            seq_id = next(self.seq_counter)
            seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)
            seqs.append(seq)

        # Create the sequence group.
        seq_group = SequenceGroup(request_id, seqs, sampling_params,
                                  arrival_time)

        # Add the sequence group to the scheduler.
        self.scheduler.add_seq_group(seq_group)

    def abort_request(self, request_id: str) -> None:
        """Aborts a request with the given ID.

        Args:
            request_id: The ID of the request to abort.
        """
        self.scheduler.abort_seq_group(request_id)

    def get_num_unfinished_requests(self) -> int:
        """Gets the number of unfinished requests."""
        return self.scheduler.get_num_unfinished_seq_groups()

    def has_unfinished_requests(self) -> bool:
        """Returns True if there are unfinished requests."""
        return self.scheduler.has_unfinished_seqs()

    def step(self) -> List[RequestOutput]:
        """Performs one decoding iteration and returns newly generated results.

        This function performs one decoding iteration of the engine. It first
        schedules the sequences to be executed in the next iteration and the
        token blocks to be swapped in/out/copy. Then, it executes the model
        and updates the scheduler with the model outputs. Finally, it decodes
        the sequences and returns the newly generated results.
        """
        (seq_group_metadata_list, scheduler_outputs,
         ignored_seq_groups) = self.scheduler.schedule()
        if ((not seq_group_metadata_list) and scheduler_outputs.is_empty()
                and (not ignored_seq_groups)):
            # Nothing to do.
            return []

        # Execute the model.
        output = self._run_workers(
            "execute_model",
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
            blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
            blocks_to_copy=scheduler_outputs.blocks_to_copy,
        )
        # Update the scheduler with the model outputs.
        seq_groups = self.scheduler.update(output)

        # Decode the sequences.
        self._decode_sequences(seq_groups)
        # Stop the sequences that meet the stopping criteria.
        self._stop_sequences(seq_groups)
        # Free the finished sequence groups.
        self.scheduler.free_finished_seq_groups()

        # Create the outputs. Groups the scheduler ignored are reported too,
        # so callers still receive an output for them.
        request_outputs: List[RequestOutput] = []
        for seq_group in seq_groups + ignored_seq_groups:
            request_output = RequestOutput.from_seq_group(seq_group)
            request_outputs.append(request_output)
        return request_outputs

    def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None:
        """Decodes the sequence outputs.

        For every running sequence, detokenizes the last generated token,
        appends it to `seq.output_tokens`, and updates `seq.output_text`.
        """
        for seq_group in seq_groups:
            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                new_token, new_output_text = detokenize_incrementally(
                    self.tokenizer,
                    seq.output_tokens,
                    seq.get_last_token_id(),
                    skip_special_tokens=True,
                )
                seq.output_tokens.append(new_token)
                seq.output_text = new_output_text

    def _stop_sequences(self, seq_groups: List[SequenceGroup]) -> None:
        """Stop the finished sequences.

        A running sequence is freed when the first of these holds:
        its text ends with a stop string (the stop string is then stripped),
        it reaches the scheduler's `max_seq_len`, it reaches
        `sampling_params.max_tokens`, or it emits EOS (unless `ignore_eos`).
        """
        for seq_group in seq_groups:
            sampling_params = seq_group.sampling_params
            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                # Check if the sequence has generated a stop string.
                stopped = False
                for stop_str in sampling_params.stop:
                    # Skip empty stop strings: "".endswith("") is always True
                    # and output_text[:-0] would erase the whole output.
                    if stop_str and seq.output_text.endswith(stop_str):
                        # Truncate the output text so that the stop string is
                        # not included in the output.
                        seq.output_text = seq.output_text[:-len(stop_str)]
                        self.scheduler.free_seq(
                            seq, SequenceStatus.FINISHED_STOPPED)
                        stopped = True
                        break
                if stopped:
                    continue

                # Check if the sequence has reached max_seq_len.
                if (seq.get_len() >=
                        self.scheduler.scheduler_config.max_seq_len):
                    self.scheduler.free_seq(
                        seq, SequenceStatus.FINISHED_LENGTH_CAPPED)
                    continue

                # Check if the sequence has reached max_tokens. `>=` (rather
                # than `==`) guarantees a sequence can never run past the cap.
                if seq.get_output_len() >= sampling_params.max_tokens:
                    self.scheduler.free_seq(
                        seq, SequenceStatus.FINISHED_LENGTH_CAPPED)
                    continue

                # Check if the sequence has generated the EOS token.
                if not sampling_params.ignore_eos:
                    if seq.get_last_token_id() == self.tokenizer.eos_token_id:
                        self.scheduler.free_seq(
                            seq, SequenceStatus.FINISHED_STOPPED)
                        continue

    def _run_workers(
        self,
        method: str,
        *args,
        get_all_outputs: bool = False,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers.

        Args:
            method: Name of the worker method to invoke.
            *args: Positional arguments forwarded to every worker call.
            get_all_outputs: If True, return the list of per-worker outputs.
                Otherwise, assert that all workers returned equal results and
                return the first one.
            **kwargs: Keyword arguments forwarded to every worker call.

        Returns:
            Either the list of all worker outputs or the single shared output.
        """
        all_outputs = []
        for worker in self.workers:
            executor = getattr(worker, method)
            if self.parallel_config.worker_use_ray:
                # Ray actors are invoked via `.remote`, which returns a future.
                executor = executor.remote

            output = executor(*args, **kwargs)
            all_outputs.append(output)

        if self.parallel_config.worker_use_ray:
            # Block until every remote call has finished.
            all_outputs = ray.get(all_outputs)

        if get_all_outputs:
            return all_outputs

        # Make sure all workers have the same results.
        output = all_outputs[0]
        for other_output in all_outputs[1:]:
            assert output == other_output
        return output