worker.py 10.1 KB
Newer Older
1
from typing import Dict, List, Tuple
Woosuk Kwon's avatar
Woosuk Kwon committed
2
3
4
5
6

import torch

from cacheflow.models import get_model
from cacheflow.models import InputMetadata
7
8
9
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import SequenceGroupInputs
from cacheflow.sequence import SequenceOutputs
Woosuk Kwon's avatar
Woosuk Kwon committed
10
from cacheflow.worker.cache_engine import CacheEngine
Zhuohan Li's avatar
Zhuohan Li committed
11
from cacheflow.parallel_utils.parallel_state import (
12
13
14
    initialize_model_parallel,
    initialize_all_reduce_launcher,
    get_tensor_model_parallel_world_size)
Zhuohan Li's avatar
Zhuohan Li committed
15
from cacheflow.utils import set_random_seed
Woosuk Kwon's avatar
Woosuk Kwon committed
16
17
18
19
20
21
22
23
24
25


class Worker:
    """Runs the model for a single rank and owns that rank's KV cache.

    On construction the worker joins the NCCL process group, loads the
    model onto the current CUDA device and creates a `CacheEngine`.
    `execute_stage` then issues the requested cache operations and runs
    one forward pass over the given sequence groups.
    """

    def __init__(
        self,
        model_name: str,
        block_size: int,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
        dtype: str,
        seed: int,
        distributed_init_method: str,
        rank: int,
        world_size: int,
        model_path: str,
        max_num_batched_tokens: int,
        tensor_parallel_size: int = 1,
        pipeline_parallel_size: int = 1,
    ) -> None:
        """Set up the distributed state, the model and the cache engine.

        Args:
            model_name: Model identifier forwarded to `get_model`.
            block_size: Number of token slots per cache block.
            num_gpu_blocks: Forwarded to `CacheEngine`.
            num_cpu_blocks: Forwarded to `CacheEngine`.
            dtype: Requested dtype string; the dtype actually returned by
                `get_model` is stored in `self.dtype`.
            seed: Random seed, applied before and after model creation.
            distributed_init_method: `init_method` for
                `torch.distributed.init_process_group`.
            rank: Rank of this process; also used as the worker id.
            world_size: Total number of processes in the group.
            model_path: Forwarded to `get_model` as `path`.
            max_num_batched_tokens: Forwarded to
                `initialize_all_reduce_launcher`.
            tensor_parallel_size: Tensor-parallel group size.
            pipeline_parallel_size: Pipeline-parallel group size.
        """
        self.init_distributed_environment(distributed_init_method,
                                          rank,
                                          world_size,
                                          tensor_parallel_size,
                                          pipeline_parallel_size)
        self.worker_id = rank
        self.block_size = block_size
        set_random_seed(seed)

        # Initialize the model.
        self.model, self.dtype = get_model(
            model_name, dtype=dtype, path=model_path)
        self.model = self.model.cuda()
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        config = self.model.config
        initialize_all_reduce_launcher(
            max_num_batched_tokens, config.hidden_size, self.dtype)
        self.num_layers = config.num_hidden_layers
        # The attention heads are split evenly across tensor-parallel ranks.
        assert (config.num_attention_heads
                % tensor_model_parallel_world_size == 0)
        self.num_heads = (config.num_attention_heads
                          // tensor_model_parallel_world_size)
        self.head_size = config.hidden_size // (
            self.num_heads * tensor_model_parallel_world_size)

        # We reset the seed after initializing the model to ensure that
        # the random state is not affected by the model initialization.
        set_random_seed(seed)

        self.cache_engine = CacheEngine(
            worker_id=self.worker_id,
            num_layers=self.num_layers,
            num_heads=self.num_heads,
            head_size=self.head_size,
            block_size=block_size,
            num_gpu_blocks=num_gpu_blocks,
            num_cpu_blocks=num_cpu_blocks,
            dtype=self.dtype,
        )
        self.cache_events = self.cache_engine.events
        self.gpu_cache = self.cache_engine.gpu_cache

    def init_distributed_environment(self,
                                     distributed_init_method: str,
                                     rank: int,
                                     world_size: int,
                                     tensor_parallel_size: int = 1,
                                     pipeline_parallel_size: int = 1) -> None:
        """Initialize the distributed environment.

        Joins the NCCL process group, performs one tiny all-reduce as a
        warmup, and then sets up the model-parallel groups.
        """
        torch.distributed.init_process_group(
            backend='nccl',
            init_method=distributed_init_method,
            world_size=world_size,
            rank=rank,
        )
        # A small all_reduce for warmup.
        torch.distributed.all_reduce(torch.zeros(1).cuda())
        initialize_model_parallel(tensor_parallel_size,
                                  pipeline_parallel_size)

    def _slot_index(self, block_table: List[int], position: int) -> int:
        """Map a token position to its flat slot index in the cache.

        The block table gives the block number for each run of
        `block_size` positions; the slot is that block's base offset plus
        the position's offset inside the block.
        """
        block_number = block_table[position // self.block_size]
        block_offset = position % self.block_size
        return block_number * self.block_size + block_offset

    def prepare_inputs(
        self,
        input_seq_groups: List[SequenceGroupInputs],
    ) -> Tuple[torch.LongTensor, torch.LongTensor, InputMetadata]:
        """Flatten the sequence groups into model input tensors.

        All prompt groups are laid out first, followed by all generation
        groups, so that prompt tokens form a contiguous prefix of the
        flattened input.

        Args:
            input_seq_groups: The sequence groups scheduled for this step.

        Returns:
            A tuple `(tokens, positions, input_metadata)`. `tokens` and
            `positions` are 1-D CUDA long tensors zero-padded to a
            multiple of 8 entries; the slot mapping stored in
            `input_metadata` is NOT padded.
        """
        seq_groups: List[Tuple[List[int], SamplingParams]] = []
        seq_logprobs: Dict[int, float] = {}
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []

        # Add prompt tokens.
        prompt_lens: List[int] = []
        for input_seq_group in input_seq_groups:
            if not input_seq_group.is_prompt:
                continue

            seq_ids = list(input_seq_group.input_tokens.keys())
            seq_groups.append((seq_ids, input_seq_group.sampling_params))
            seq_logprobs.update(input_seq_group.seq_logprobs)

            # Use any sequence in the group.
            seq_id = seq_ids[0]

            prompt_tokens = input_seq_group.input_tokens[seq_id]
            prompt_len = len(prompt_tokens)
            prompt_lens.append(prompt_len)

            input_tokens.extend(prompt_tokens)
            # NOTE(woosuk): Here we assume that the first token in the prompt
            # is always the first token in the sequence.
            input_positions.extend(range(prompt_len))

            # Compute the slot mapping.
            block_table = input_seq_group.block_tables[seq_id]
            slot_mapping.extend(
                self._slot_index(block_table, i) for i in range(prompt_len))

        # Prefix sums of the prompt lengths, starting at 0.
        cumulative_prompt_lens: List[int] = [0]
        for prompt_len in prompt_lens:
            cumulative_prompt_lens.append(
                cumulative_prompt_lens[-1] + prompt_len)

        # Add generation tokens.
        max_context_len = 0
        max_num_blocks_per_seq = 0
        context_lens: List[int] = []
        generation_block_tables: List[List[int]] = []
        for input_seq_group in input_seq_groups:
            if input_seq_group.is_prompt:
                continue

            seq_ids = list(input_seq_group.input_tokens.keys())
            seq_groups.append((seq_ids, input_seq_group.sampling_params))
            seq_logprobs.update(input_seq_group.seq_logprobs)

            for seq_id in seq_ids:
                # A generation step feeds exactly one new token per sequence.
                assert len(input_seq_group.input_tokens[seq_id]) == 1
                generation_token = input_seq_group.input_tokens[seq_id][0]
                input_tokens.append(generation_token)

                # The new token occupies the last position of the context.
                context_len = input_seq_group.context_len
                position = context_len - 1
                input_positions.append(position)

                block_table = input_seq_group.block_tables[seq_id]
                generation_block_tables.append(block_table)

                max_context_len = max(max_context_len, context_len)
                max_num_blocks_per_seq = max(
                    max_num_blocks_per_seq, len(block_table))
                context_lens.append(context_len)

                slot_mapping.append(self._slot_index(block_table, position))

        # Optimization: Pad the input length to be a multiple of 8.
        # This is required for utilizing the Tensor Cores in NVIDIA GPUs.
        input_tokens = _pad_to_alignment(input_tokens, multiple_of=8)
        input_positions = _pad_to_alignment(input_positions, multiple_of=8)

        # Convert to tensors.
        tokens_tensor = torch.tensor(
            input_tokens, dtype=torch.long, device='cuda')
        positions_tensor = torch.tensor(
            input_positions, dtype=torch.long, device='cuda')
        slot_mapping_tensor = torch.tensor(
            slot_mapping, dtype=torch.int, device='cuda')
        context_lens_tensor = torch.tensor(
            context_lens, dtype=torch.int, device='cuda')
        # Block tables are ragged; pad them to a rectangle before stacking.
        padded_block_tables = [
            _pad_to_max(block_table, max_num_blocks_per_seq)
            for block_table in generation_block_tables]
        block_tables_tensor = torch.tensor(
            padded_block_tables, dtype=torch.int, device='cuda')
        cumulative_prompt_lens_tensor = torch.tensor(
            cumulative_prompt_lens, dtype=torch.int, device='cuda')

        input_metadata = InputMetadata(
            seq_groups=seq_groups,
            seq_logprobs=seq_logprobs,
            prompt_lens=prompt_lens,
            cumulative_prompt_lens=cumulative_prompt_lens_tensor,
            slot_mapping=slot_mapping_tensor,
            context_lens=context_lens_tensor,
            max_context_len=max_context_len,
            block_tables=block_tables_tensor,
        )
        return tokens_tensor, positions_tensor, input_metadata

    @torch.inference_mode()
    def execute_stage(
        self,
        input_seq_groups: List[SequenceGroupInputs],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> Dict[int, SequenceOutputs]:
        """Issue the cache operations and run one model step.

        Args:
            input_seq_groups: Sequence groups to run; may be empty.
            blocks_to_swap_in: Passed to `CacheEngine.swap_in` when
                non-empty (presumably source -> destination block
                numbers; confirm against `CacheEngine`).
            blocks_to_swap_out: Passed to `CacheEngine.swap_out` when
                non-empty.
            blocks_to_copy: Passed to `CacheEngine.copy` when non-empty.

        Returns:
            The model output, or an empty dict when there is no input.
        """
        # Issue cache operations.
        command_issued = False
        if blocks_to_swap_in:
            self.cache_engine.swap_in(blocks_to_swap_in)
            command_issued = True
        if blocks_to_swap_out:
            self.cache_engine.swap_out(blocks_to_swap_out)
            command_issued = True
        if blocks_to_copy:
            self.cache_engine.copy(blocks_to_copy)
            command_issued = True

        # Only hand the events on when some cache operation was issued.
        cache_events = self.cache_events if command_issued else None

        # If there is no input, we don't need to execute the model.
        if not input_seq_groups:
            if cache_events is not None:
                # Nothing downstream will wait on the events, so block
                # here until the cache operations have finished.
                for event in cache_events:
                    event.wait()
            return {}

        # Prepare input tensors.
        input_tokens, input_positions, input_metadata = self.prepare_inputs(
            input_seq_groups)

        # Execute the model.
        output = self.model(
            input_ids=input_tokens,
            positions=input_positions,
            kv_caches=self.gpu_cache,
            input_metadata=input_metadata,
            cache_events=cache_events,
        )
        return output


def _pad_to_alignment(x: List[int], multiple_of: int) -> List[int]:
    return x + [0] * ((-len(x)) % multiple_of)


def _pad_to_max(x: List[int], max_len: int) -> List[int]:
    return x + [0] * (max_len - len(x))