# medusa_worker.py
import weakref
2
from typing import List, Optional, Set, Tuple, Dict
3
4

import torch
5
import torch.nn.functional as F
6
7

from vllm.model_executor import SamplingMetadata
8
9
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
10
from vllm.spec_decode.interfaces import SpeculativeProposals, SpeculativeProposer
11
12
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
from vllm.spec_decode.top1_proposer import Top1Proposer
13
from vllm.spec_decode.tree_style_proposer import TreeStyleProposer
14
from vllm.worker.worker import Worker
15
from vllm.distributed import broadcast_tensor_dict
16

# Per-depth stride used by MedusaWorker.generate_medusa_buffers when computing
# "tree_indices": a tree node at depth d (1-based) whose last choice is t maps
# to index t + TOPK * (d - 1) + 1 (index 0 is the tree root).
# 10 is a placeholder value that is sufficient for the supported trees.
TOPK = 10

class MedusaWorker(NonLLMProposerWorkerBase, Worker):
    """Worker for Medusa speculative decoding.

    Proposals come from the loaded model's ``generate_proposals`` method.
    When the model exposes non-empty ``medusa_choices``, the worker builds
    sparse-tree buffers (see ``generate_medusa_buffers``) and proposes via
    ``TreeStyleProposer``; otherwise it falls back to ``Top1Proposer``.
    """

    def __init__(self, *args, **kwargs):
        # LoRA is not used by the Medusa worker: force lora_config to None.
        # Work on a copy so the caller's kwargs dict is left untouched.
        kwargs_copy = kwargs.copy()
        kwargs_copy['lora_config'] = None

        super().__init__(*args, **kwargs_copy)

        # Assigned in load_model(), once it is known whether tree buffers
        # exist (the choice of proposer depends on them).
        self._proposer: SpeculativeProposer

    def init_device(self):
        """Initialize the device via the base worker (no Medusa-specific work)."""
        super().init_device()

    def load_model(self):
        """Load the model, build Medusa tree buffers, and create the proposer.

        Sets ``self.medusa_choices`` (``None`` if the model has no such
        attribute), ``self.medusa_buffers`` (``None`` when there is no usable
        tree) and ``self._proposer``.
        """
        super().load_model()

        # A model without `medusa_choices`, or with None / an empty list,
        # gets no tree buffers and uses plain top-1 proposals.
        self.medusa_choices = getattr(self.model_runner.model,
                                      'medusa_choices', None)
        self.medusa_buffers = None
        if self.medusa_choices is not None and len(self.medusa_choices) > 0:
            self.medusa_buffers = self.generate_medusa_buffers(
                self.medusa_choices, device=self.device)

        if self.medusa_buffers is None:
            self._proposer = Top1Proposer(
                weakref.proxy(self),  # type: ignore[arg-type]
                self.device,
                self.vocab_size,
                max_proposal_len=self.max_model_len,
            )
        else:
            self._proposer = TreeStyleProposer(
                weakref.proxy(self),  # type: ignore[arg-type]
                self.device,
                self.vocab_size,
                self.medusa_buffers,
                max_proposal_len=self.max_model_len,
            )

    def set_include_gpu_probs_tensor(self):
        """No-op for the Medusa worker."""
        pass

    def set_should_modify_greedy_probs_inplace(self):
        """No-op for the Medusa worker."""
        pass

    def _get_driver_input_and_broadcast(
        self, execute_model_req: ExecuteModelRequest
    ) -> Dict[str, torch.Tensor]:
        """Build the proposal inputs on the driver and broadcast them.

        Returns a dict with keys ``previous_hidden_states``,
        ``previous_logits`` (may be None), ``sample_indices_list`` (one entry
        per sampling seq group) and ``seq_lens`` (list of ints). Note that
        despite the annotation, not every value is a tensor.

        Raises:
            AttributeError: if ``execute_model_req.previous_hidden_states``
                is None (its ``.hidden_states`` is read unconditionally).
        """
        seq_group_metadata_list = execute_model_req.seq_group_metadata_list

        seq_lens, query_lens = self._prepare_input_tensors(
            seq_group_metadata_list)

        generators = self.model_runner.get_generators(
            execute_model_req.finished_requests_ids)
        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list, seq_lens, query_lens, self.device,
            self.model_runner.pin_memory, generators)

        sample_indices_list = [
            seq_group.sample_indices
            for seq_group in sampling_metadata.seq_groups
        ]

        previous_hidden_states = (
            execute_model_req.previous_hidden_states.hidden_states)
        previous_logits = (execute_model_req.previous_logits.logits
                           if execute_model_req.previous_logits is not None
                           else None)

        tensor_dict = {
            "previous_hidden_states": previous_hidden_states,
            "previous_logits": previous_logits,
            "sample_indices_list": sample_indices_list,
            "seq_lens": seq_lens,
        }

        if self.do_metadata_broadcast:
            broadcast_tensor_dict(tensor_dict, src=0)

        return tensor_dict

    def _get_worker_input_from_broadcast(
        self
    ) -> Optional[Dict[str, torch.Tensor]]:
        """Receive the dict broadcast by the driver (non-driver workers only)."""
        assert self.do_metadata_broadcast
        assert not self.is_driver_worker
        broadcast_data = broadcast_tensor_dict(src=0)

        return broadcast_data

    @torch.inference_mode()
    def sampler_output(
        self,
        execute_model_req: ExecuteModelRequest,
        sample_len: int,
        # Unused parameter.
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ) -> Tuple[List[SamplerOutput], bool]:
        """Run the Medusa heads to generate future-token proposals.

        Returns the list of sampler outputs produced by the model's
        ``generate_proposals``, along with an indicator of whether the torch
        tensors in those outputs need to be transposed in the later
        sampler_output_to_torch logic.

        For the medusa worker, this indicator shall be False.

        In tree mode, each output additionally gets a ``tree_attn_masks``
        tensor laid out as [ones(context_len) | tree mask | zeros(pad)],
        where the zero padding brings every row to the batch's longest
        context.
        """
        self._raise_if_unsupported(execute_model_req)

        if self.is_driver_worker:
            tensor_dict = self._get_driver_input_and_broadcast(
                execute_model_req)
        else:
            tensor_dict = self._get_worker_input_from_broadcast()
            if tensor_dict is None:
                raise ValueError("Can not get inputs of medusa worker!!!")

        model_outputs = self.model_runner.model.generate_proposals(
            previous_hidden_states=tensor_dict["previous_hidden_states"],
            sample_indices_list=tensor_dict["sample_indices_list"],
            previous_logits=tensor_dict["previous_logits"],
            medusa_buffers=self.medusa_buffers)

        # Create per-sequence tree attention masks.
        if self.medusa_buffers is not None:
            seq_lens = tensor_dict["seq_lens"]
            tree_attn_masks = self.medusa_buffers['tree_attn_masks']
            num_tree_nodes = tree_attn_masks.shape[0]
            # default=0 keeps an empty batch from raising in max().
            max_context_len = max(seq_lens, default=0)
            # NOTE(review): zip pairs outputs with seq_lens positionally and
            # silently truncates on a length mismatch; this assumes one
            # SamplerOutput per sequence from generate_proposals — confirm.
            for output, context_len in zip(model_outputs, seq_lens):
                # Every tree node attends to the full context prefix.
                left_mask = torch.ones(num_tree_nodes, context_len,
                                       dtype=tree_attn_masks.dtype,
                                       device=tree_attn_masks.device)
                attn_masks = torch.cat([left_mask, tree_attn_masks], dim=-1)
                right_pad = max_context_len - context_len
                if right_pad > 0:
                    # Zero-pad on the right so all masks share one width.
                    attn_masks = F.pad(attn_masks, (0, right_pad),
                                       "constant", 0)
                output.tree_attn_masks = attn_masks

        return model_outputs, False

    def _prepare_input_tensors(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
    ) -> Tuple[List[int], List[int]]:
        """Compute per-sequence ``seq_lens`` and ``query_lens``.

        Prompt sequences are limited to the current chunk
        (``num_computed_tokens + token_chunk_size``, capped at the sequence
        length). Decode sequences contribute their full length and a query
        length of 1.
        """
        if not seq_group_metadata_list:
            return [], []

        seq_lens: List[int] = []
        query_lens: List[int] = []

        for seq_group_metadata in seq_group_metadata_list:
            is_prompt = seq_group_metadata.is_prompt

            for seq_data in seq_group_metadata.seq_data.values():
                seq_data_len = seq_data.get_len()
                if is_prompt:
                    context_len = seq_data.get_num_computed_tokens()
                    seq_len = min(
                        seq_data_len,
                        context_len + seq_group_metadata.token_chunk_size)
                    seq_lens.append(seq_len)
                    query_lens.append(seq_len - context_len)
                else:
                    # On the first step of tree decoding, the first token
                    # must be ignored.
                    if (self.medusa_buffers is not None
                            and seq_data.get_first_step_flag()):
                        seq_data_len -= 1
                    seq_lens.append(seq_data_len)
                    query_lens.append(1)

        return seq_lens, query_lens

    def get_spec_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ) -> SpeculativeProposals:
        """Produce speculations given an input batch of sequences. The number of
        speculative tokens per sequence is determined by max_proposal_len.
        """
        return self._proposer.get_spec_proposals(
            execute_model_req, seq_ids_with_bonus_token_in_last_step)

    def _raise_if_unsupported(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> None:
        """MedusaWorker does not yet implement support for cache swap
        operations or beam search.

        A None request (which non-driver workers may pass) is accepted.

        Raises:
            NotImplementedError: on swap/copy block operations, or when any
                sequence group holds more than one sequence.
        """
        if execute_model_req is None:
            return

        if any((
                execute_model_req.blocks_to_swap_in,
                execute_model_req.blocks_to_swap_out,
                execute_model_req.blocks_to_copy,
        )):
            raise NotImplementedError(
                "MedusaWorker does not support cache operations")

        if any(
                len(seq_group_metadata.seq_data.keys()) != 1
                for seq_group_metadata in
                execute_model_req.seq_group_metadata_list):
            raise NotImplementedError(
                "MedusaWorker does not support beam search.")

    def pad_path(self, path, length, pad_value=-2):
        """
        Pad the given path list with a specific value up to a specified length.

        Parameters:
        - path (list or tuple): The original sequence that needs padding.
        - length (int): The desired length of the padded list.
        - pad_value (optional, default=-2): The value to use for padding.

        Returns:
        - list: A new list with the elements of ``path``, padded to ``length``.

        Example:
        >>> pad_path([1,2,3], 5)
        [1, 2, 3, -2, -2]

        Note:
        If the given path is already at least ``length`` long, no padding
        occurs and an unpadded copy of the path is returned.
        """
        # A non-positive repeat count yields [], so no padding is added.
        return list(path) + [pad_value] * (length - len(path))

    def generate_medusa_buffers(self, medusa_choices, device="cuda"):
        """
        Generate buffers for the Medusa structure based on the provided choices.

        Parameters:
        - medusa_choices (list): A nested list of tree paths. Every path must
          be non-empty, and every proper prefix of a path must itself be a
          path in the list (prefix-closed).
        - device (str): Device to which the tensors should be moved. Default is "cuda".

        Returns:
        - dict: Tensors keyed by "tree_attn_masks" (int), "tree_indices",
          "tree_position_ids" and "retrieve_indices" (long). Node 0 is the
          tree root and node k + 1 is the k-th path in sorted order.

        Raises:
        - ValueError: if medusa_choices is empty, contains an empty path, or
          is not prefix-closed.
        """
        if not medusa_choices:
            raise ValueError("medusa_choices must contain at least one path.")

        # Sort by depth, then lexicographically: all nodes of depth d precede
        # nodes of depth d + 1, so every node follows its ancestors.
        sorted_medusa_choices = sorted(medusa_choices, key=lambda x: (len(x), x))
        medusa_len = len(sorted_medusa_choices) + 1

        # Sorted position of each path. setdefault keeps the first occurrence,
        # matching list.index() semantics if a path is duplicated.
        choice_index: Dict[tuple, int] = {}
        for idx, path in enumerate(sorted_medusa_choices):
            choice_index.setdefault(tuple(path), idx)

        # Validate: the construction below requires every ancestor of every
        # node to be present (which also makes depths contiguous from 1).
        for path in sorted_medusa_choices:
            if len(path) == 0:
                raise ValueError("medusa_choices must not contain empty paths.")
            for c in range(1, len(path)):
                if tuple(path[:c]) not in choice_index:
                    raise ValueError(
                        f"medusa_choices is not prefix-closed: "
                        f"{list(path[:c])} (a prefix of {list(path)}) is missing.")

        medusa_attn_mask = torch.eye(medusa_len, medusa_len)
        medusa_attn_mask[:, 0] = 1  # every node attends to the root
        medusa_tree_indices = torch.zeros(medusa_len, dtype=torch.long)
        medusa_position_ids = torch.zeros(medusa_len, dtype=torch.long)

        for idx, path in enumerate(sorted_medusa_choices):
            node = idx + 1
            depth = len(path)

            # Attention: each node also attends to all of its ancestors.
            ancestor_idx = [choice_index[tuple(path[:c])] + 1
                            for c in range(1, depth)]
            if ancestor_idx:
                medusa_attn_mask[node, ancestor_idx] = 1

            # Tree index: depth d occupies a block of TOPK slots after the root.
            # NOTE(review): assumes path[-1] < TOPK; larger values would alias
            # into the next depth's block — confirm against the consumer.
            medusa_tree_indices[node] = path[-1] + TOPK * (depth - 1) + 1

            # Position id of a node is its depth (root stays 0).
            medusa_position_ids[node] = depth

        # Retrieval indices: one row per maximal path (leaf), built from the
        # deepest paths first, skipping any path already covered by a row.
        retrieve_indices_nest = []
        covered_paths = set()
        for path in reversed(sorted_medusa_choices):
            if tuple(path) in covered_paths:
                continue
            retrieve_indice = []
            for c in range(1, len(path) + 1):
                prefix = tuple(path[:c])
                retrieve_indice.append(choice_index[prefix])
                covered_paths.add(prefix)
            retrieve_indices_nest.append(retrieve_indice)

        max_length = max(len(x) for x in retrieve_indices_nest)
        retrieve_indices = torch.tensor(
            [self.pad_path(path, max_length) for path in retrieve_indices_nest],
            dtype=torch.long)
        # Shift to node numbering (root is node 0); padding -2 becomes -1.
        retrieve_indices = retrieve_indices + 1
        # Prepend the root column to every row.
        retrieve_indices = torch.cat(
            [torch.zeros((retrieve_indices.shape[0], 1), dtype=torch.long),
             retrieve_indices],
            dim=1)

        # Aggregate the generated buffers into a dictionary
        medusa_buffers = {
            "tree_attn_masks": medusa_attn_mask.int(),
            "tree_indices": medusa_tree_indices,
            "tree_position_ids": medusa_position_ids,
            "retrieve_indices": retrieve_indices,
        }

        # Move the tensors in the dictionary to the specified device
        return {
            k: v.clone().to(device)
            if isinstance(v, torch.Tensor)
            else torch.tensor(v, device=device)
            for k, v in medusa_buffers.items()
        }