llm_engine.py 13.2 KB
Newer Older
1
2
3
4
5
6
import time
from typing import Any, List, Optional

from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig,
                              SchedulerConfig)
from cacheflow.core.scheduler import Scheduler
Zhuohan Li's avatar
Zhuohan Li committed
7
8
9
10
from cacheflow.engine.arg_utils import EngineArgs
from cacheflow.engine.ray_utils import DeviceID, initialize_cluster, ray
from cacheflow.engine.tokenizer_utils import (detokenize_incrementally,
                                              get_tokenizer)
11
12
13
from cacheflow.logger import init_logger
from cacheflow.outputs import RequestOutput
from cacheflow.sampling_params import SamplingParams
14
from cacheflow.sequence import Sequence, SequenceGroup, SequenceStatus
15
16
17
18
19
20
from cacheflow.utils import Counter
from cacheflow.worker.worker import Worker

# Module-level logger, named after this module via cacheflow's init_logger.
logger = init_logger(__name__)


21
class LLMEngine:
    """An LLM engine that receives requests and generates texts.

    This is the main class for the CacheFlow LLM engine. It receives requests
    from clients and generates texts from the LLM. It includes a tokenizer, a
    language model (possibly distributed across multiple GPUs), and GPU memory
    space allocated for intermediate states (aka KV cache). This class utilizes
    iteration-level scheduling and efficient memory management to maximize the
    serving throughput.

    The `LLM` class wraps this class for offline batched inference and the
    `AsyncLLMEngine` class wraps this class for online serving.

    NOTE: The config arguments are derived from the `EngineArgs` class. For the
    comprehensive list of arguments, see `EngineArgs`.

    Args:
        model_config: The configuration related to the LLM model.
        cache_config: The configuration related to the KV cache memory
            management.
        parallel_config: The configuration related to distributed execution.
        scheduler_config: The configuration related to the request scheduler.
        distributed_init_method: The initialization method for distributed
            execution. See `torch.distributed.init_process_group` for details.
        stage_devices: The list of devices for each stage. Each stage is a list
            of (rank, node_resource, device) tuples.
        log_stats: Whether to log statistics.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        distributed_init_method: str,
        stage_devices: List[List[DeviceID]],
        log_stats: bool,
    ) -> None:
        logger.info(
            "Initializing an LLM engine with config: "
            f"model={model_config.model!r}, "
            f"dtype={model_config.dtype}, "
            f"use_dummy_weights={model_config.use_dummy_weights}, "
            f"download_dir={model_config.download_dir!r}, "
            f"use_np_weights={model_config.use_np_weights}, "
            f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
            f"seed={model_config.seed}"
        )
        # TODO(woosuk): Print more configs in debug mode.

        self.model_config = model_config
        self.cache_config = cache_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.log_stats = log_stats
        self._verify_args()

        self.tokenizer = get_tokenizer(model_config.model)
        # Source of globally unique sequence IDs across all requests.
        self.seq_counter = Counter()

        # Create the parallel GPU workers.
        self.workers: List[Worker] = []
        assert len(stage_devices) == 1, "Only support one stage for now."
        for rank, node_resource, _ in stage_devices[0]:
            worker_cls = Worker
            if self.parallel_config.worker_use_ray:
                # Pin each Ray actor to one GPU on the node that owns
                # `node_resource`; the tiny resource amount only serves as a
                # placement constraint.
                worker_cls = ray.remote(
                    num_cpus=0,
                    num_gpus=1,
                    resources={node_resource: 1e-5},
                )(worker_cls).remote

            worker = worker_cls(
                model_config,
                parallel_config,
                scheduler_config,
                rank,
                distributed_init_method,
            )
            self.workers.append(worker)
        # Profile the memory usage and initialize the cache.
        self._init_cache()

        # Create the scheduler.
        self.scheduler = Scheduler(scheduler_config, cache_config, log_stats)

    def _verify_args(self) -> None:
        """Cross-checks the model and cache configs against the parallel config."""
        self.model_config.verify_with_parallel_config(self.parallel_config)
        self.cache_config.verify_with_parallel_config(self.parallel_config)

    def _init_cache(self) -> None:
        """Profiles the memory usage and initializes the KV cache.

        Raises:
            ValueError: If profiling leaves no room for any GPU cache block.
        """
        # Get the maximum number of blocks that can be allocated on GPU and CPU.
        num_blocks = self._run_workers(
            "profile_num_available_blocks",
            get_all_outputs=True,
            block_size=self.cache_config.block_size,
            gpu_memory_utilization=self.cache_config.gpu_memory_utilization,
            cpu_swap_space=self.cache_config.swap_space_bytes,
        )

        # Since we use a shared centralized controller, we take the minimum
        # number of blocks across all workers to make sure all the memory
        # operators can be applied to all workers.
        num_gpu_blocks = min(b[0] for b in num_blocks)
        num_cpu_blocks = min(b[1] for b in num_blocks)
        # FIXME(woosuk): Change to debug log.
        logger.info(f'# GPU blocks: {num_gpu_blocks}, '
                    f'# CPU blocks: {num_cpu_blocks}')
        if num_gpu_blocks <= 0:
            # Fail fast with an actionable message instead of letting the
            # scheduler run with an empty KV cache.
            raise ValueError(
                "No available memory for the cache blocks. Try increasing "
                "`gpu_memory_utilization` when initializing the engine.")
        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

        # Initialize the cache.
        self._run_workers("init_cache_engine", cache_config=self.cache_config)

    @classmethod
    def from_engine_args(cls, engine_args: EngineArgs) -> "LLMEngine":
        """Creates an LLM engine from the engine arguments."""
        # Create the engine configs. They are unpacked positionally into
        # __init__ below, so they are ordered (model, cache, parallel,
        # scheduler) and index 2 is the parallel config.
        engine_configs = engine_args.create_engine_configs()
        parallel_config = engine_configs[2]
        # Initialize the cluster.
        distributed_init_method, devices = initialize_cluster(parallel_config)
        # Create the LLM engine.
        engine = cls(*engine_configs, distributed_init_method, devices,
                     log_stats=not engine_args.disable_log_stats)
        return engine

    def add_request(
        self,
        request_id: str,
        prompt: Optional[str],
        sampling_params: SamplingParams,
        prompt_token_ids: Optional[List[int]] = None,
        arrival_time: Optional[float] = None,
    ) -> None:
        """Add a request to the engine's request pool.

        The request is added to the request pool and will be processed by the
        scheduler as `engine.step()` is called. The exact scheduling policy is
        determined by the scheduler.

        Args:
            request_id: The unique ID of the request.
            prompt: The prompt string. Can be None if prompt_token_ids is
                provided.
            sampling_params: The sampling parameters for text generation.
            prompt_token_ids: The token IDs of the prompt. If None, we
                use the tokenizer to convert the prompts to token IDs.
            arrival_time: The arrival time of the request. If None, we use
                the current time.

        Raises:
            ValueError: If both `prompt` and `prompt_token_ids` are None.
        """
        if arrival_time is None:
            arrival_time = time.time()
        if prompt_token_ids is None:
            if prompt is None:
                raise ValueError(
                    "Either prompt or prompt_token_ids must be provided.")
            prompt_token_ids = self.tokenizer.encode(prompt)

        # Create the sequences: one per candidate (`best_of`), all sharing
        # the same prompt.
        block_size = self.cache_config.block_size
        seqs: List[Sequence] = []
        for _ in range(sampling_params.best_of):
            seq_id = next(self.seq_counter)
            seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)
            seqs.append(seq)

        # Create the sequence group.
        seq_group = SequenceGroup(request_id, seqs, sampling_params,
                                  arrival_time)

        # Add the sequence group to the scheduler.
        self.scheduler.add_seq_group(seq_group)

    def abort_request(self, request_id: str) -> None:
        """Aborts a request with the given ID.

        Args:
            request_id: The ID of the request to abort.
        """
        self.scheduler.abort_seq_group(request_id)

    def get_num_unfinished_requests(self) -> int:
        """Gets the number of unfinished requests."""
        return self.scheduler.get_num_unfinished_seq_groups()

    def has_unfinished_requests(self) -> bool:
        """Returns True if there are unfinished requests."""
        return self.scheduler.has_unfinished_seqs()

    def step(self) -> List[RequestOutput]:
        """Performs one decoding iteration and returns newly generated results.

        This function performs one decoding iteration of the engine. It first
        schedules the sequences to be executed in the next iteration and the
        token blocks to be swapped in/out/copy. Then, it executes the model
        and updates the scheduler with the model outputs. Finally, it decodes
        the sequences and returns the newly generated results.
        """
        seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
        if (not seq_group_metadata_list) and scheduler_outputs.is_empty():
            # Nothing to do.
            return []

        # Execute the model.
        output = self._run_workers(
            "execute_model",
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
            blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
            blocks_to_copy=scheduler_outputs.blocks_to_copy,
        )
        # Update the scheduler with the model outputs.
        seq_groups = self.scheduler.update(output)

        # Decode the sequences.
        self._decode_sequences(seq_groups)
        # Stop the sequences that meet the stopping criteria.
        self._stop_sequences(seq_groups)
        # Free the finished sequence groups.
        self.scheduler.free_finished_seq_groups()

        # Create the outputs.
        return [
            RequestOutput.from_seq_group(seq_group) for seq_group in seq_groups
        ]

    def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None:
        """Decodes the newest token of every running sequence.

        The new token string is appended to `seq.output_tokens` and
        `seq.output_text` is replaced with the text returned by
        `detokenize_incrementally`.
        """
        for seq_group in seq_groups:
            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                new_token, new_output_text = detokenize_incrementally(
                    self.tokenizer,
                    seq.output_tokens,
                    seq.get_last_token_id(),
                    skip_special_tokens=True,
                )
                seq.output_tokens.append(new_token)
                seq.output_text = new_output_text

    def _stop_sequences(self, seq_groups: List[SequenceGroup]) -> None:
        """Stop the finished sequences.

        A running sequence is freed when, checked in this order, its output
        text ends with one of `sampling_params.stop` (the stop string is
        stripped from the text), it has produced at least `max_tokens`
        tokens, or (unless `ignore_eos`) its last token is the tokenizer's
        EOS token.
        """
        for seq_group in seq_groups:
            sampling_params = seq_group.sampling_params
            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                # Check if the sequence has generated a stop string.
                stopped = False
                for stop_str in sampling_params.stop:
                    if seq.output_text.endswith(stop_str):
                        # Truncate the output text so that the stop string is
                        # not included in the output.
                        seq.output_text = seq.output_text[:-len(stop_str)]
                        self.scheduler.free_seq(seq,
                                                SequenceStatus.FINISHED_STOPPED)
                        stopped = True
                        break
                if stopped:
                    continue

                # Check if the sequence has reached max_tokens. `>=` (rather
                # than `==`) ensures a sequence that ever overshoots is still
                # capped instead of running forever.
                if seq.get_output_len() >= sampling_params.max_tokens:
                    self.scheduler.free_seq(
                        seq, SequenceStatus.FINISHED_LENGTH_CAPPED)
                    continue
                # Check if the sequence has generated the EOS token.
                if not sampling_params.ignore_eos:
                    if seq.get_last_token_id() == self.tokenizer.eos_token_id:
                        self.scheduler.free_seq(seq,
                                                SequenceStatus.FINISHED_STOPPED)
                        continue

    def _run_workers(
        self,
        method: str,
        *args,
        get_all_outputs: bool = False,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers.

        Args:
            method: Name of the worker method to invoke.
            *args: Positional arguments forwarded to every worker call.
            get_all_outputs: Keyword-only. If True, return the list of every
                worker's output. Otherwise, assert that all workers returned
                equal outputs and return the first one.
            **kwargs: Keyword arguments forwarded to every worker call.

        Returns:
            Either the list of per-worker outputs or the single shared output.
        """
        all_outputs = []
        for worker in self.workers:
            executor = getattr(worker, method)
            if self.parallel_config.worker_use_ray:
                # Ray actors expose methods through `.remote`, which returns
                # an object reference instead of the result.
                executor = executor.remote

            output = executor(*args, **kwargs)
            all_outputs.append(output)

        if self.parallel_config.worker_use_ray:
            # Resolve all object references at once.
            all_outputs = ray.get(all_outputs)

        if get_all_outputs:
            return all_outputs

        # Make sure all workers have the same results.
        output = all_outputs[0]
        for other_output in all_outputs[1:]:
            assert output == other_output
        return output