"""A GPU worker class."""
from typing import Dict, List, Optional, Tuple

import torch

from cacheflow.model_executor import (get_model, get_cache_block_size,
                                      InputMetadata, set_random_seed)
from cacheflow.model_executor.parallel_utils.parallel_state import (
    initialize_model_parallel,
    initialize_all_reduce_launcher,
    get_tensor_model_parallel_world_size)
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import (SequenceData, SequenceGroupMetadata,
                                SequenceOutputs)
from cacheflow.worker.cache_engine import CacheEngine
from cacheflow.utils import get_gpu_memory


class Worker:
    """A worker class that executes (a partition of) the model on a GPU.

    Each worker is associated with a single GPU. The worker is responsible for
    maintaining the KV cache and executing the model on the GPU. In case of
    distributed inference, each worker is assigned a partition of the model.
    """

    def __init__(
        self,
        model_name: str,
        dtype: str,
        seed: int,
        distributed_init_method: str,
        rank: int,
        world_size: int,
        cache_dir: Optional[str],
        use_dummy_weights: bool,
        use_np_cache: bool,
        max_num_batched_tokens: int,
        max_num_sequences: int,
        tensor_parallel_size: int = 1,
        pipeline_parallel_size: int = 1,
    ) -> None:
        self.init_distributed_environment(distributed_init_method,
                                          rank,
                                          world_size,
                                          tensor_parallel_size,
                                          pipeline_parallel_size)
        self.worker_id = rank
        self.seed = seed
        set_random_seed(self.seed)

        # Initialize the model.
        self.model, self.dtype = get_model(
            model_name, dtype=dtype, cache_dir=cache_dir,
            use_dummy_weights=use_dummy_weights, use_np_cache=use_np_cache)
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        self.max_num_batched_tokens = max_num_batched_tokens
        initialize_all_reduce_launcher(
            self.max_num_batched_tokens, self.model.config.hidden_size,
            self.dtype)
        self.max_num_sequences = max_num_sequences
        self.num_layers = self.model.config.num_hidden_layers
        assert (self.model.config.num_attention_heads %
                tensor_model_parallel_world_size == 0)
        self.num_heads = (self.model.config.num_attention_heads //
                          tensor_model_parallel_world_size)
        self.head_size = self.model.config.hidden_size // (
            self.num_heads * tensor_model_parallel_world_size)

        # We reset the seed after initializing the model to ensure that
        # the random state is not affected by the model initialization.
        set_random_seed(seed)

        # Uninitialized cache engine. Will be initialized with
        # self.init_cache_engine().
        self.block_size = None
        self.cache_engine = None
        self.cache_events = None
        self.gpu_cache = None

    @torch.inference_mode()
    def get_num_available_blocks(
        self, block_size: int, cpu_swap_space: int,
        gpu_memory_utilization: float) -> Tuple[int, int]:
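        """Profile the peak memory of a maximum-size dummy batch and return
        how many KV cache blocks fit on the GPU (within the utilization
        budget) and in the CPU swap space."""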
        # Profile the memory usage of the model and get the maximum number of
        # cache blocks that can be allocated with the remaining free memory.
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        # Profile memory usage with max_num_sequences sequences and the total
        # number of tokens equal to max_num_batched_tokens.

        # Enable top-k sampling to reflect the accurate memory usage.
        sampling_params = SamplingParams(top_p=0.99,
                                         top_k=self.model.config.vocab_size - 1)
        seqs = []
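        # Spread max_num_batched_tokens as evenly as possible over
        # max_num_sequences dummy sequences; the first (remainder) sequences
        # receive one extra token.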
        for group_id in range(self.max_num_sequences):
            seq_len = (self.max_num_batched_tokens // self.max_num_sequences +
                       (group_id < self.max_num_batched_tokens %
                        self.max_num_sequences))
            seq_data = SequenceData([0] * seq_len)
            seq = SequenceGroupMetadata(
                group_id=group_id,
                is_prompt=True,
                seq_data={group_id: seq_data},
                sampling_params=sampling_params,
                block_tables=None,
            )
            seqs.append(seq)

        input_tokens, input_positions, input_metadata = self.prepare_inputs(seqs)

        # Execute the model.
        self.model(
            input_ids=input_tokens,
            positions=input_positions,
            kv_caches=[(None, None)] * self.num_layers,
            input_metadata=input_metadata,
            cache_events=None,
        )

        # Calculate the number of blocks that can be allocated with the
        # profiled peak memory.
        torch.cuda.synchronize()
        peak_memory = torch.cuda.max_memory_allocated()
        total_gpu_memory = get_gpu_memory()
        cache_block_size = get_cache_block_size(block_size, self.num_heads,
                                                self.head_size, self.num_layers,
                                                self.dtype)
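        # The KV cache may use whatever remains of the memory budget
        # (total memory * utilization) after the profiled peak usage; CPU
        # blocks are bounded only by the configured swap space.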
        num_gpu_blocks = int((total_gpu_memory * gpu_memory_utilization
                              - peak_memory) // cache_block_size)
        num_cpu_blocks = int(cpu_swap_space // cache_block_size)
        torch.cuda.empty_cache()
        # Reset the seed to ensure that the model output is not affected by
        # the profiling.
        set_random_seed(self.seed)
        return num_gpu_blocks, num_cpu_blocks

    def init_cache_engine(self, block_size: int, num_gpu_blocks: int,
                          num_cpu_blocks: int):
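        """Allocate the KV cache through the cache engine and keep references
        to its GPU cache and its synchronization events."""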
        self.block_size = block_size
        self.cache_engine = CacheEngine(
            worker_id=self.worker_id,
            num_layers=self.num_layers,
            num_heads=self.num_heads,
            head_size=self.head_size,
            block_size=self.block_size,
            num_gpu_blocks=num_gpu_blocks,
            num_cpu_blocks=num_cpu_blocks,
            dtype=self.dtype,
        )
        self.cache_events = self.cache_engine.events
        self.gpu_cache = self.cache_engine.gpu_cache

    def init_distributed_environment(self,
                                     distributed_init_method: str,
                                     rank: int,
                                     world_size: int,
                                     tensor_parallel_size: int = 1,
                                     pipeline_parallel_size: int = 1) -> None:
        """Initialize the distributed environment."""
        torch.distributed.init_process_group(
            backend='nccl',
            init_method=distributed_init_method,
            world_size=world_size,
            rank=rank,
        )
        # A small all_reduce for warmup.
        torch.distributed.all_reduce(torch.zeros(1).cuda())
        initialize_model_parallel(tensor_parallel_size,
                                  pipeline_parallel_size)

    def prepare_inputs(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.LongTensor, torch.LongTensor, InputMetadata]:
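        """Flatten the prompt and generation tokens of the scheduled sequence
        groups into padded CUDA tensors and build the InputMetadata (slot
        mapping, context lengths, block tables) needed by the model."""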
        seq_groups: List[Tuple[List[int], SamplingParams]] = []
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []

        # Add prompt tokens.
        prompt_lens: List[int] = []
        for seq_group_metadata in seq_group_metadata_list:
            if not seq_group_metadata.is_prompt:
                continue

            seq_ids = list(seq_group_metadata.seq_data.keys())
            sampling_params = seq_group_metadata.sampling_params
            seq_groups.append((seq_ids, sampling_params))

            # All sequences in a prompt group share the same prompt, so it is
            # enough to look at any one of them.
            seq_id = seq_ids[0]

            seq_data = seq_group_metadata.seq_data[seq_id]
            prompt_tokens = seq_data.get_token_ids()
            prompt_len = len(prompt_tokens)
            prompt_lens.append(prompt_len)

            input_tokens.extend(prompt_tokens)
            # NOTE(woosuk): Here we assume that the first token in the prompt
            # is always the first token in the sequence.
            input_positions.extend(range(len(prompt_tokens)))

            if seq_group_metadata.block_tables is None:
                # During memory profiling, the block tables are not initialized
                # yet. In this case, we just use a dummy slot mapping.
                slot_mapping.extend([0] * prompt_len)
                continue

            # Compute the slot mapping.
            block_table = seq_group_metadata.block_tables[seq_id]
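            # Token i of the prompt lives in block block_table[i // block_size]
            # at offset i % block_size.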
            for i in range(prompt_len):
                block_number = block_table[i // self.block_size]
                block_offset = i % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

        # Add generation tokens.
        max_context_len = 0
        max_num_blocks_per_seq = 0
        context_lens: List[int] = []
        generation_block_tables: List[List[int]] = []
        for seq_group_metadata in seq_group_metadata_list:
            if seq_group_metadata.is_prompt:
                continue

            seq_ids = list(seq_group_metadata.seq_data.keys())
            sampling_params = seq_group_metadata.sampling_params
            seq_groups.append((seq_ids, sampling_params))

            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
                input_tokens.append(generation_token)

                context_len = seq_data.get_len()
                position = context_len - 1
                input_positions.append(position)

                block_table = seq_group_metadata.block_tables[seq_id]
                generation_block_tables.append(block_table)

                max_context_len = max(max_context_len, context_len)
                max_num_blocks_per_seq = max(
                    max_num_blocks_per_seq, len(block_table))
                context_lens.append(context_len)

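                # Only the newly generated token needs a new cache slot; it
                # sits at the last position of the sequence.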
                block_number = block_table[position // self.block_size]
                block_offset = position % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

        # Optimization: Pad the input length to be a multiple of 8.
        # This is required for utilizing the Tensor Cores in NVIDIA GPUs.
        input_tokens = _pad_to_alignment(input_tokens, multiple_of=8)
        input_positions = _pad_to_alignment(input_positions, multiple_of=8)

        # Convert to tensors.
        tokens_tensor = torch.tensor(
            input_tokens, dtype=torch.long, device='cuda')
        positions_tensor = torch.tensor(
            input_positions, dtype=torch.long, device='cuda')
        slot_mapping_tensor = torch.tensor(
            slot_mapping, dtype=torch.int, device='cuda')
        context_lens_tensor = torch.tensor(
            context_lens, dtype=torch.int, device='cuda')
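        # Pad the block tables to the same length so they can be stacked
        # into a single rectangular tensor.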
        padded_block_tables = [
            _pad_to_max(block_table, max_num_blocks_per_seq)
            for block_table in generation_block_tables]
        block_tables_tensor = torch.tensor(
            padded_block_tables, dtype=torch.int, device='cuda')

        seq_data: Dict[int, SequenceData] = {}
        for seq_group_metadata in seq_group_metadata_list:
            seq_data.update(seq_group_metadata.seq_data)

        input_metadata = InputMetadata(
            seq_groups=seq_groups,
            seq_data=seq_data,
            prompt_lens=prompt_lens,
            slot_mapping=slot_mapping_tensor,
            context_lens=context_lens_tensor,
            max_context_len=max_context_len,
            block_tables=block_tables_tensor,
        )
        return tokens_tensor, positions_tensor, input_metadata

    @torch.inference_mode()
    def execute_stage(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> Dict[int, SequenceOutputs]:
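        """Issue the scheduled cache operations, run the model on the given
        sequence groups, and return the sampled outputs."""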
        # Issue cache operations.
        issued_cache_op = False
        if blocks_to_swap_in:
            self.cache_engine.swap_in(blocks_to_swap_in)
            issued_cache_op = True
        if blocks_to_swap_out:
            self.cache_engine.swap_out(blocks_to_swap_out)
            issued_cache_op = True
        if blocks_to_copy:
            self.cache_engine.copy(blocks_to_copy)
            issued_cache_op = True

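        # If any cache operation was issued, pass its events to the model so
        # the computation waits for them before using the affected blocks.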
        if issued_cache_op:
            cache_events = self.cache_events
        else:
            cache_events = None

        # If there is no input, we don't need to execute the model.
        if not seq_group_metadata_list:
            if cache_events is not None:
                for event in cache_events:
                    event.wait()
            return {}

        # Prepare input tensors.
        input_tokens, input_positions, input_metadata = self.prepare_inputs(
            seq_group_metadata_list)

        # Execute the model.
        output = self.model(
            input_ids=input_tokens,
            positions=input_positions,
            kv_caches=self.gpu_cache,
            input_metadata=input_metadata,
            cache_events=cache_events,
        )
        return output


def _pad_to_alignment(x: List[int], multiple_of: int) -> List[int]:
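    # (-len(x)) % multiple_of is the number of zeros needed to reach the next
    # multiple of multiple_of (0 if the length is already aligned).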
    return x + [0] * ((-len(x)) % multiple_of)


def _pad_to_max(x: List[int], max_len: int) -> List[int]:
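    # Right-pad with zeros up to max_len (assumes len(x) <= max_len).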
    return x + [0] * (max_len - len(x))