"""A GPU worker class."""
from typing import Dict, List, Tuple

import torch

from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig,
                              SchedulerConfig)
from cacheflow.model_executor import get_model, InputMetadata, set_random_seed
from cacheflow.model_executor.parallel_utils.parallel_state import (
    initialize_model_parallel, initialize_all_reduce_launcher)
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import (SequenceData, SequenceGroupMetadata,
                                SequenceOutputs)
from cacheflow.worker.cache_engine import CacheEngine
from cacheflow.utils import get_gpu_memory


class Worker:
    """A worker class that executes (a partition of) the model on a GPU.

    Each worker is associated with a single GPU. The worker is responsible for
    maintaining the KV cache and executing the model on the GPU. In case of
    distributed inference, each worker is assigned a partition of the model.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        rank: int,
        distributed_init_method: str,
    ) -> None:
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.rank = rank
        self.distributed_init_method = distributed_init_method

        # Initialize the distributed environment.
        _init_distributed_environment(parallel_config, rank,
                                      distributed_init_method)

        # Initialize the model.
        set_random_seed(self.model_config.seed)
        self.model = get_model(model_config)
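        # (Assumption based on the argument list rather than this module's
        # code: initialize_all_reduce_launcher appears to pre-allocate the
        # tensor-parallel all-reduce buffer for the largest possible batch,
        # i.e. max_num_batched_tokens activations of hidden_size in the
        # model dtype.)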
        initialize_all_reduce_launcher(
            self.scheduler_config.max_num_batched_tokens,
            self.model_config.get_hidden_size(),
            self.model_config.dtype,
        )

        # Uninitialized cache engine. Will be initialized by
        # self.init_cache_engine().
        self.cache_config = None
        self.block_size = None
        self.cache_engine = None
        self.cache_events = None
        self.gpu_cache = None

    @torch.inference_mode()
    def profile_num_available_blocks(
        self,
        block_size: int,
        gpu_memory_utilization: float,
        cpu_swap_space: int,
    ) -> Tuple[int, int]:
        # Profile the memory usage of the model and get the maximum number of
        # cache blocks that can be allocated with the remaining free memory.
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        # Profile memory usage with max_num_seqs sequences and the total
        # number of tokens equal to max_num_batched_tokens.

        # Enable top-k sampling to reflect the accurate memory usage.
        sampling_params = SamplingParams(top_p=0.99,
                                         top_k=self.model.config.vocab_size - 1)
        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
        max_num_seqs = self.scheduler_config.max_num_seqs
        seqs = []
        for group_id in range(max_num_seqs):
            seq_len = (max_num_batched_tokens // max_num_seqs +
                       (group_id < max_num_batched_tokens % max_num_seqs))
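            # Example with illustrative numbers: 2570 batched tokens across
            # 256 sequences gives seq_len 11 for the first 10 groups
            # (2570 % 256 == 10) and 10 for the remaining ones, so the
            # lengths sum to exactly max_num_batched_tokens.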
            seq_data = SequenceData([0] * seq_len)
            seq = SequenceGroupMetadata(
                request_id=str(group_id),
                is_prompt=True,
                seq_data={group_id: seq_data},
                sampling_params=sampling_params,
                block_tables=None,
            )
            seqs.append(seq)

        input_tokens, input_positions, input_metadata = self._prepare_inputs(
            seqs)

        # Execute the model.
        num_layers = self.model_config.get_num_layers(self.parallel_config)
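        # The KV cache has not been allocated yet, so a dummy (None, None)
        # pair is passed for every layer; this forward pass therefore
        # measures only weight and activation memory.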
        self.model(
            input_ids=input_tokens,
            positions=input_positions,
            kv_caches=[(None, None)] * num_layers,
            input_metadata=input_metadata,
            cache_events=None,
        )

        # Calculate the number of blocks that can be allocated with the
        # profiled peak memory.
        torch.cuda.synchronize()
        peak_memory = torch.cuda.max_memory_allocated()
        total_gpu_memory = get_gpu_memory()
        cache_block_size = CacheEngine.get_cache_block_size(
            block_size, self.model_config, self.parallel_config)
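        # Illustrative numbers (not measured here): with a 16 GiB GPU,
        # gpu_memory_utilization=0.9, a 12 GiB profiled peak, and a 2 MiB
        # cache block, this yields int(2.4 GiB // 2 MiB) == 1228 GPU blocks;
        # a 4 GiB cpu_swap_space then gives 2048 CPU blocks.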
        num_gpu_blocks = int((total_gpu_memory * gpu_memory_utilization
                              - peak_memory) // cache_block_size)
        num_cpu_blocks = int(cpu_swap_space // cache_block_size)
        torch.cuda.empty_cache()

        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)
        return num_gpu_blocks, num_cpu_blocks

    def init_cache_engine(self, cache_config: CacheConfig) -> None:
        self.cache_config = cache_config
        self.block_size = cache_config.block_size
        self.cache_engine = CacheEngine(
            self.cache_config, self.model_config, self.parallel_config)
        self.cache_events = self.cache_engine.events
        self.gpu_cache = self.cache_engine.gpu_cache

    def _prepare_inputs(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]:
        seq_groups: List[Tuple[List[int], SamplingParams]] = []
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []

        # Add prompt tokens.
        prompt_lens: List[int] = []
        for seq_group_metadata in seq_group_metadata_list:
            if not seq_group_metadata.is_prompt:
                continue

            seq_ids = list(seq_group_metadata.seq_data.keys())
            sampling_params = seq_group_metadata.sampling_params
            seq_groups.append((seq_ids, sampling_params))

            # Use any sequence in the group.
            seq_id = seq_ids[0]

            seq_data = seq_group_metadata.seq_data[seq_id]
            prompt_tokens = seq_data.get_token_ids()
            prompt_len = len(prompt_tokens)
            prompt_lens.append(prompt_len)

            input_tokens.extend(prompt_tokens)
            # NOTE(woosuk): Here we assume that the first token in the prompt
            # is always the first token in the sequence.
            input_positions.extend(range(len(prompt_tokens)))

            if seq_group_metadata.block_tables is None:
                # During memory profiling, the block tables are not initialized
                # yet. In this case, we just use a dummy slot mapping.
                slot_mapping.extend([0] * prompt_len)
                continue

            # Compute the slot mapping.
            block_table = seq_group_metadata.block_tables[seq_id]
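            # Worked example (illustrative block_size=16): prompt position
            # i=35 falls in physical block block_table[35 // 16] ==
            # block_table[2] at offset 35 % 16 == 3, so its slot is
            # block_table[2] * 16 + 3.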
            for i in range(prompt_len):
                block_number = block_table[i // self.block_size]
                block_offset = i % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

        # Add generation tokens.
        max_context_len = 0
        max_num_blocks_per_seq = 0
        context_lens: List[int] = []
        generation_block_tables: List[List[int]] = []
        for seq_group_metadata in seq_group_metadata_list:
            if seq_group_metadata.is_prompt:
                continue

            seq_ids = list(seq_group_metadata.seq_data.keys())
            sampling_params = seq_group_metadata.sampling_params
            seq_groups.append((seq_ids, sampling_params))

            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
                input_tokens.append(generation_token)

                context_len = seq_data.get_len()
                position = context_len - 1
                input_positions.append(position)

                block_table = seq_group_metadata.block_tables[seq_id]
                generation_block_tables.append(block_table)

                max_context_len = max(max_context_len, context_len)
                max_num_blocks_per_seq = max(
                    max_num_blocks_per_seq, len(block_table))
                context_lens.append(context_len)

                block_number = block_table[position // self.block_size]
                block_offset = position % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

        # Optimization: Pad the input length to be a multiple of 8.
        # This is required for utilizing the Tensor Cores in NVIDIA GPUs.
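        # For example, a batch of 13 tokens gets 3 zero tokens appended,
        # since (-13) % 8 == 3, giving a padded length of 16.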
        input_tokens = _pad_to_alignment(input_tokens, multiple_of=8)
        input_positions = _pad_to_alignment(input_positions, multiple_of=8)

        # Convert to tensors.
        tokens_tensor = torch.cuda.LongTensor(input_tokens)
        positions_tensor = torch.cuda.LongTensor(input_positions)
        slot_mapping_tensor = torch.cuda.IntTensor(slot_mapping)
        context_lens_tensor = torch.cuda.IntTensor(context_lens)
        padded_block_tables = [
            _pad_to_max(block_table, max_num_blocks_per_seq)
            for block_table in generation_block_tables]
        block_tables_tensor = torch.cuda.IntTensor(padded_block_tables)

        seq_data: Dict[int, SequenceData] = {}
        for seq_group_metadata in seq_group_metadata_list:
            seq_data.update(seq_group_metadata.seq_data)

        input_metadata = InputMetadata(
            seq_groups=seq_groups,
            seq_data=seq_data,
            prompt_lens=prompt_lens,
            slot_mapping=slot_mapping_tensor,
            context_lens=context_lens_tensor,
            max_context_len=max_context_len,
            block_tables=block_tables_tensor,
        )
        return tokens_tensor, positions_tensor, input_metadata

    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> Dict[int, SequenceOutputs]:
        # Issue cache operations.
        issued_cache_op = False
        if blocks_to_swap_in:
            self.cache_engine.swap_in(blocks_to_swap_in)
            issued_cache_op = True
        if blocks_to_swap_out:
            self.cache_engine.swap_out(blocks_to_swap_out)
            issued_cache_op = True
        if blocks_to_copy:
            self.cache_engine.copy(blocks_to_copy)
            issued_cache_op = True

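        # cache_events are the CUDA events the CacheEngine records for its
        # swap/copy operations. If any cache op was issued, they are handed
        # to the model so execution can wait for the cache data to be in
        # place (the exact synchronization point lives in the model code,
        # not here).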
        if issued_cache_op:
            cache_events = self.cache_events
        else:
            cache_events = None

        # If there is no input, we don't need to execute the model.
        if not seq_group_metadata_list:
            if cache_events is not None:
                for event in cache_events:
                    event.wait()
            return {}

        # Prepare input tensors.
        input_tokens, input_positions, input_metadata = self._prepare_inputs(
            seq_group_metadata_list)

        # Execute the model.
        output = self.model(
            input_ids=input_tokens,
            positions=input_positions,
            kv_caches=self.gpu_cache,
            input_metadata=input_metadata,
            cache_events=cache_events,
        )
        return output


def _init_distributed_environment(
    parallel_config: ParallelConfig,
    rank: int,
    distributed_init_method: str,
) -> None:
    """Initialize the distributed environment."""
    torch.distributed.init_process_group(
        backend="nccl",
        world_size=parallel_config.world_size,
        rank=rank,
        init_method=distributed_init_method,
    )
    # A small all_reduce for warmup.
    torch.distributed.all_reduce(torch.zeros(1).cuda())
    initialize_model_parallel(parallel_config.tensor_parallel_size,
                              parallel_config.pipeline_parallel_size)


def _pad_to_alignment(x: List[int], multiple_of: int) -> List[int]:
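    # e.g., _pad_to_alignment([1, 2, 3], multiple_of=8)
    #       -> [1, 2, 3, 0, 0, 0, 0, 0]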
    return x + [0] * ((-len(x)) % multiple_of)


def _pad_to_max(x: List[int], max_len: int) -> List[int]:
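    # e.g., _pad_to_max([3, 7], max_len=4) -> [3, 7, 0, 0]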
    return x + [0] * (max_len - len(x))