from typing import Dict, List, Tuple

import torch

from cacheflow.models import get_model
from cacheflow.models import set_seed
from cacheflow.models import InputMetadata
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import SequenceGroupInputs
from cacheflow.sequence import SequenceOutputs
from cacheflow.worker.cache_engine import CacheEngine


class Worker:

    def __init__(
        self,
        worker_id: int,
        gpu_id: int,
        model_name: str,
        block_size: int,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
        dtype: str,
        seed: int,
    ) -> None:
        self.worker_id = worker_id
        self.gpu_id = gpu_id
        self.block_size = block_size

        self.device = torch.device('cuda', index=gpu_id)

        # Initialize the model.
        # FIXME(woosuk): This is a hack.
        self.model = get_model(model_name, dtype=dtype).to(device=self.device)
        self.num_layers = self.model.config.num_hidden_layers
        self.num_heads = self.model.config.num_attention_heads
        self.head_size = self.model.config.hidden_size // self.num_heads
        self.dtype = self.model.dtype

        # Set the seed.
        # We set the seed after initializing the model to ensure that
        # the random state is not affected by the model initialization.
        set_seed(seed)

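        # Initialize the cache engine, which manages the paged KV cache blocks
        # on the GPU and CPU and the events used to synchronize block swaps
        # and copies.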
        self.cache_engine = CacheEngine(
            worker_id=worker_id,
            gpu_id=gpu_id,
            num_layers=self.num_layers,
            num_heads=self.num_heads,
            head_size=self.head_size,
            block_size=block_size,
            num_gpu_blocks=num_gpu_blocks,
            num_cpu_blocks=num_cpu_blocks,
            dtype=self.dtype,
        )
        self.cache_events = self.cache_engine.events
        self.gpu_cache = self.cache_engine.gpu_cache

    def prepare_inputs(
        self,
        input_seq_groups: List[SequenceGroupInputs],
    ) -> Tuple[torch.LongTensor, torch.LongTensor, InputMetadata]:
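        # Each entry of seq_groups pairs the sequence IDs of one group with its
        # sampling parameters; seq_logprobs carries the per-sequence log
        # probabilities passed in with the inputs.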
        seq_groups: List[Tuple[List[int], SamplingParams]] = []
        seq_logprobs: Dict[int, float] = {}
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []

        # Add prompt tokens.
        prompt_lens: List[int] = []
        for input_seq_group in input_seq_groups:
            if not input_seq_group.is_prompt:
                continue

            seq_ids = list(input_seq_group.input_tokens.keys())
            sampling_params = input_seq_group.sampling_params
            seq_groups.append((seq_ids, sampling_params))
            seq_logprobs.update(input_seq_group.seq_logprobs)

            # Use any sequence in the group.
            seq_id = seq_ids[0]

            prompt_tokens = input_seq_group.input_tokens[seq_id]
            prompt_len = len(prompt_tokens)
            prompt_lens.append(prompt_len)

            input_tokens.extend(prompt_tokens)
            # NOTE(woosuk): Here we assume that the first token in the prompt
            # is always the first token in the sequence.
            input_positions.extend(range(len(prompt_tokens)))

            # Compute the slot mapping.
            block_table = input_seq_group.block_tables[seq_id]
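            # Each prompt token occupies one slot of the paged KV cache: the
            # block table gives the physical block number and the offset is the
            # position within that block. For example (illustrative values),
            # with block_size=8 and block_table=[3, 7], the token at index 10
            # falls in block 7 at offset 2, i.e. slot 7 * 8 + 2 = 58.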
            for i in range(prompt_len):
                block_number = block_table[i // self.block_size]
                block_offset = i % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

        # Add generation tokens.
        max_context_len = 0
        max_num_blocks_per_seq = 0
        context_lens: List[int] = []
        generation_block_tables: List[List[int]] = []
        for input_seq_group in input_seq_groups:
            if input_seq_group.is_prompt:
                continue

            seq_ids = list(input_seq_group.input_tokens.keys())
            sampling_params = input_seq_group.sampling_params
            seq_groups.append((seq_ids, sampling_params))
            seq_logprobs.update(input_seq_group.seq_logprobs)

            for seq_id in seq_ids:
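                # In the generation phase, each sequence contributes exactly
                # one new token per step (hence the assertion below).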
                assert len(input_seq_group.input_tokens[seq_id]) == 1
                generation_token = input_seq_group.input_tokens[seq_id][0]
                input_tokens.append(generation_token)

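                # The new token is appended at the end of the context, so its
                # (0-based) position index is context_len - 1.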
                position = input_seq_group.context_len - 1
                input_positions.append(position)

                block_table = input_seq_group.block_tables[seq_id]
                generation_block_tables.append(block_table)

                max_context_len = max(
                    max_context_len, input_seq_group.context_len)
                max_num_blocks_per_seq = max(
                    max_num_blocks_per_seq, len(block_table))
                context_lens.append(input_seq_group.context_len)

                block_number = block_table[position // self.block_size]
                block_offset = position % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

        # Optimization: Pad the input length to be a multiple of 8.
        # This is required for utilizing the Tensor Cores in NVIDIA GPUs.
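        # For example, a batch of 13 tokens is padded with zeros to length 16.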
        input_tokens = _pad_to_alignment(input_tokens, multiple_of=8)
        input_positions = _pad_to_alignment(input_positions, multiple_of=8)

        # Convert to tensors.
        tokens_tensor = torch.tensor(
            input_tokens, dtype=torch.long, device=self.device)
        positions_tensor = torch.tensor(
            input_positions, dtype=torch.long, device=self.device)
        slot_mapping_tensor = torch.tensor(
            slot_mapping, dtype=torch.int, device=self.device)
        context_lens_tensor = torch.tensor(
            context_lens, dtype=torch.int, device=self.device)
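        # Pad every block table to the same length so the tables can be stored
        # in a single rectangular integer tensor.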
        padded_block_tables = [
            _pad_to_max(block_table, max_num_blocks_per_seq)
            for block_table in generation_block_tables]
        block_tables_tensor = torch.tensor(
            padded_block_tables, dtype=torch.int, device=self.device)

        input_metadata = InputMetadata(
            seq_groups=seq_groups,
            seq_logprobs=seq_logprobs,
            prompt_lens=prompt_lens,
            slot_mapping=slot_mapping_tensor,
            context_lens=context_lens_tensor,
            max_context_len=max_context_len,
            block_tables=block_tables_tensor,
        )
        return tokens_tensor, positions_tensor, input_metadata

    @torch.inference_mode()
    def execute_stage(
        self,
        input_seq_groups: List[SequenceGroupInputs],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> Dict[int, SequenceOutputs]:
        # Issue cache operations.
        command_issued = False
        if blocks_to_swap_in:
            self.cache_engine.swap_in(blocks_to_swap_in)
            command_issued = True
        if blocks_to_swap_out:
            self.cache_engine.swap_out(blocks_to_swap_out)
            command_issued = True
        if blocks_to_copy:
            self.cache_engine.copy(blocks_to_copy)
            command_issued = True

        if command_issued:
            cache_events = self.cache_events
        else:
            cache_events = None
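        # The events (if any) are passed to the model so the forward pass can
        # wait for the asynchronous block transfers to finish before touching
        # the KV cache.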

        # Prepare input tensors.
        input_tokens, input_positions, input_metadata = self.prepare_inputs(
            input_seq_groups)

        # Execute the model.
        output = self.model(
            input_ids=input_tokens,
            positions=input_positions,
            kv_caches=self.gpu_cache,
            input_metadata=input_metadata,
            cache_events=cache_events,
        )
        return output


def _pad_to_alignment(x: List[int], multiple_of: int) -> List[int]:
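    # Pad x with zeros so that its length becomes a multiple of `multiple_of`,
    # e.g. _pad_to_alignment([1, 2, 3], multiple_of=8) == [1, 2, 3] + [0] * 5.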
    return x + [0] * ((-len(x)) % multiple_of)


def _pad_to_max(x: List[int], max_len: int) -> List[int]:
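    # Pad x with zeros up to max_len, e.g. _pad_to_max([3, 7], 4) == [3, 7, 0, 0].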
    return x + [0] * (max_len - len(x))