"""A GPU worker class."""
from typing import Dict, List, Optional, Tuple

import torch

from cacheflow.model_executor import get_model, InputMetadata, set_random_seed
from cacheflow.model_executor.parallel_utils.parallel_state import (
    initialize_model_parallel,
    initialize_all_reduce_launcher,
    get_tensor_model_parallel_world_size)
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import (SequenceData, SequenceGroupMetadata,
                                SequenceOutputs)
from cacheflow.worker.cache_engine import CacheEngine


class Worker:
    """A worker class that executes (a partition of) the model on a GPU.

    Each worker is associated with a single GPU. The worker is responsible for
    maintaining the KV cache and executing the model on the GPU. In case of
    distributed inference, each worker is assigned a partition of the model.
    """

    def __init__(
        self,
        model_name: str,
        block_size: int,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
        dtype: str,
        seed: int,
        distributed_init_method: str,
        rank: int,
        world_size: int,
        cache_dir: Optional[str],
        use_dummy_weights: bool,
        use_np_cache: bool,
        max_num_batched_tokens: int,
        tensor_parallel_size: int = 1,
        pipeline_parallel_size: int = 1,
    ) -> None:
        self.init_distributed_environment(distributed_init_method,
                                          rank,
                                          world_size,
                                          tensor_parallel_size,
                                          pipeline_parallel_size)
        self.worker_id = rank
        self.block_size = block_size
        set_random_seed(seed)

        # Initialize the model.
        self.model, self.dtype = get_model(
            model_name, dtype=dtype, cache_dir=cache_dir,
            use_dummy_weights=use_dummy_weights, use_np_cache=use_np_cache)
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        initialize_all_reduce_launcher(
            max_num_batched_tokens, self.model.config.hidden_size, self.dtype)
        self.num_layers = self.model.config.num_hidden_layers
        # The attention heads are partitioned across the tensor-parallel
        # ranks: each worker holds num_heads = total_num_heads / tp_size of
        # them, while head_size is a property of the whole model.
        total_num_heads = self.model.config.num_attention_heads
        assert total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = total_num_heads // tensor_model_parallel_world_size
        self.head_size = self.model.config.hidden_size // total_num_heads

        # We reset the seed after initializing the model to ensure that
        # the random state is not affected by the model initialization.
        set_random_seed(seed)

        self.cache_engine = CacheEngine(
            worker_id=self.worker_id,
            num_layers=self.num_layers,
            num_heads=self.num_heads,
            head_size=self.head_size,
            block_size=block_size,
            num_gpu_blocks=num_gpu_blocks,
            num_cpu_blocks=num_cpu_blocks,
            dtype=self.dtype,
        )
        self.cache_events = self.cache_engine.events
        self.gpu_cache = self.cache_engine.gpu_cache

    def init_distributed_environment(self,
                                     distributed_init_method: str,
                                     rank: int,
                                     world_size: int,
                                     tensor_parallel_size: int = 1,
                                     pipeline_parallel_size: int = 1) -> None:
        """Initialize the distributed environment."""
        torch.distributed.init_process_group(
            backend='nccl',
            init_method=distributed_init_method,
            world_size=world_size,
            rank=rank,
        )
        # A small all_reduce for warmup.
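        # (NCCL sets up its communicators lazily, so this dummy all_reduce
        # pays the one-time initialization cost here instead of during the
        # first real model execution.)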
        torch.distributed.all_reduce(torch.zeros(1).cuda())
        initialize_model_parallel(tensor_parallel_size,
                                  pipeline_parallel_size)

    def prepare_inputs(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.LongTensor, torch.LongTensor, InputMetadata]:
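        """Prepare model inputs for the scheduled sequence groups.

        Prompt tokens are packed first, followed by one generation token per
        running sequence. Returns the flattened token ids, their positions,
        and an InputMetadata that carries the per-sequence bookkeeping
        (sampling parameters, slot mapping, context lengths, block tables).
        """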
        seq_groups: List[Tuple[List[int], SamplingParams]] = []
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []

        # Add prompt tokens.
        prompt_lens: List[int] = []
        for seq_group_metadata in seq_group_metadata_list:
            if not seq_group_metadata.is_prompt:
                continue

            seq_ids = list(seq_group_metadata.seq_data.keys())
            sampling_params = seq_group_metadata.sampling_params
            seq_groups.append((seq_ids, sampling_params))

            # Use any sequence in the group.
            seq_id = seq_ids[0]

            seq_data = seq_group_metadata.seq_data[seq_id]
            prompt_tokens = seq_data.get_token_ids()
            prompt_len = len(prompt_tokens)
            prompt_lens.append(prompt_len)

            input_tokens.extend(prompt_tokens)
            # NOTE(woosuk): Here we assume that the first token in the prompt
            # is always the first token in the sequence.
            input_positions.extend(range(len(prompt_tokens)))

            # Compute the slot mapping.
            block_table = seq_group_metadata.block_tables[seq_id]
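            # Each position i lives in logical block i // block_size and is
            # stored at physical slot block_number * block_size + offset.
            # For example, with block_size=16 and block_table=[7, 3],
            # position 20 maps to slot 3 * 16 + 4 = 52.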
            for i in range(prompt_len):
                block_number = block_table[i // self.block_size]
                block_offset = i % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

        # Add generation tokens.
        max_context_len = 0
        max_num_blocks_per_seq = 0
        context_lens: List[int] = []
        generation_block_tables: List[List[int]] = []
        for seq_group_metadata in seq_group_metadata_list:
            if seq_group_metadata.is_prompt:
                continue

            seq_ids = list(seq_group_metadata.seq_data.keys())
            sampling_params = seq_group_metadata.sampling_params
            seq_groups.append((seq_ids, sampling_params))

            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
                input_tokens.append(generation_token)

                context_len = seq_data.get_len()
                position = context_len - 1
                input_positions.append(position)

                block_table = seq_group_metadata.block_tables[seq_id]
                generation_block_tables.append(block_table)

                max_context_len = max(max_context_len, context_len)
                max_num_blocks_per_seq = max(
                    max_num_blocks_per_seq, len(block_table))
                context_lens.append(context_len)

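                # Map the newly generated token's position to its physical
                # cache slot, using the same block arithmetic as for prompts.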
                block_number = block_table[position // self.block_size]
                block_offset = position % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

        # Optimization: Pad the input length to be a multiple of 8.
        # This is required for utilizing the Tensor Cores in NVIDIA GPUs.
        input_tokens = _pad_to_alignment(input_tokens, multiple_of=8)
        input_positions = _pad_to_alignment(input_positions, multiple_of=8)
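        # For example, a batch of 13 tokens is padded with 3 zeros to 16.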

        # Convert to tensors.
        tokens_tensor = torch.tensor(
            input_tokens, dtype=torch.long, device='cuda')
        positions_tensor = torch.tensor(
            input_positions, dtype=torch.long, device='cuda')
        slot_mapping_tensor = torch.tensor(
            slot_mapping, dtype=torch.int, device='cuda')
        context_lens_tensor = torch.tensor(
            context_lens, dtype=torch.int, device='cuda')
        padded_block_tables = [
            _pad_to_max(block_table, max_num_blocks_per_seq)
            for block_table in generation_block_tables]
        block_tables_tensor = torch.tensor(
            padded_block_tables, dtype=torch.int, device='cuda')

        seq_data: Dict[int, SequenceData] = {}
        for seq_group_metadata in seq_group_metadata_list:
            seq_data.update(seq_group_metadata.seq_data)

        input_metadata = InputMetadata(
            seq_groups=seq_groups,
            seq_data=seq_data,
            prompt_lens=prompt_lens,
            slot_mapping=slot_mapping_tensor,
            context_lens=context_lens_tensor,
            max_context_len=max_context_len,
            block_tables=block_tables_tensor,
        )
        return tokens_tensor, positions_tensor, input_metadata

    @torch.inference_mode()
    def execute_stage(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> Dict[int, SequenceOutputs]:
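        """Run one scheduler step: issue cache moves, then run the model.

        Swap-in, swap-out, and copy operations are issued on the cache
        engine first; the resulting CUDA events are passed to the model so
        that execution can wait on the pending cache operations.
        """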
        # Issue cache operations.
        issued_cache_op = False
        if blocks_to_swap_in:
            self.cache_engine.swap_in(blocks_to_swap_in)
            issued_cache_op = True
        if blocks_to_swap_out:
            self.cache_engine.swap_out(blocks_to_swap_out)
            issued_cache_op = True
        if blocks_to_copy:
            self.cache_engine.copy(blocks_to_copy)
            issued_cache_op = True

        if issued_cache_op:
            cache_events = self.cache_events
        else:
            cache_events = None

        # If there is no input, we don't need to execute the model.
        if not seq_group_metadata_list:
            if cache_events is not None:
                for event in cache_events:
                    event.wait()
            return {}

        # Prepare input tensors.
        input_tokens, input_positions, input_metadata = self.prepare_inputs(
            seq_group_metadata_list)

        # Execute the model.
        output = self.model(
            input_ids=input_tokens,
            positions=input_positions,
            kv_caches=self.gpu_cache,
            input_metadata=input_metadata,
            cache_events=cache_events,
        )
        return output


def _pad_to_alignment(x: List[int], multiple_of: int) -> List[int]:
    """Pad x with zeros so that its length is a multiple of multiple_of."""
    return x + [0] * ((-len(x)) % multiple_of)


def _pad_to_max(x: List[int], max_len: int) -> List[int]:
    """Pad x with zeros to exactly max_len elements."""
    return x + [0] * (max_len - len(x))
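

# A minimal usage sketch (not part of the original module): how a
# single-GPU, single-process Worker might be driven. The model name, block
# counts, and init method below are illustrative assumptions, not values
# taken from this repository.
#
#     worker = Worker(
#         model_name='facebook/opt-125m',  # assumed example model
#         block_size=16,
#         num_gpu_blocks=1024,
#         num_cpu_blocks=256,
#         dtype='half',
#         seed=0,
#         distributed_init_method='tcp://localhost:29500',  # assumed address
#         rank=0,
#         world_size=1,
#         cache_dir=None,
#         use_dummy_weights=True,
#         use_np_cache=False,
#         max_num_batched_tokens=2560,
#     )
#     outputs = worker.execute_stage(
#         seq_group_metadata_list=scheduled_seq_groups,  # from the scheduler
#         blocks_to_swap_in={},
#         blocks_to_swap_out={},
#         blocks_to_copy={},
#     )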