# medusa_worker.py
import os
import weakref
from typing import Any, Dict, List, Optional, Set, Tuple

import torch
import torch.nn.functional as F

from vllm.distributed import broadcast_tensor_dict
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
from vllm.spec_decode.interfaces import SpeculativeProposals, SpeculativeProposer
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.spec_decode.tree_style_proposer import TreeStyleProposer
from vllm.worker.worker import Worker

# Top-k used for the sparse Medusa tree: the per-depth stride of the tree
# indices built by MedusaWorker.generate_medusa_buffers.
# 10 is a placeholder value and it is sufficient.
TOPK = 10

class MedusaWorker(NonLLMProposerWorkerBase, Worker):
    """Worker for Medusa.

    Proposes speculative tokens with the Medusa heads of the draft model.
    If the environment variable ``VLLM_TREE_DECODING`` equals ``"1"`` and the
    loaded model exposes non-None ``medusa_choices``, tree buffers are built
    by :meth:`generate_medusa_buffers` and a ``TreeStyleProposer`` is used;
    otherwise proposals come from a ``Top1Proposer``.
    """

    def __init__(self, *args, **kwargs):
        # Tree buffers are (re)built in `load_model`; define the attribute up
        # front so methods that consult it never hit an AttributeError.
        self.medusa_buffers: Optional[Dict[str, torch.Tensor]] = None

        # Skip the LoRA config for the Medusa worker. Copy kwargs so the
        # caller's dict is not mutated.
        kwargs_copy = kwargs.copy()
        kwargs_copy['lora_config'] = None

        super().__init__(*args, **kwargs_copy)

        # Lazy initialization: the proposer is created in `load_model`.
        self._proposer: SpeculativeProposer

        self.tree_decoding = (os.environ.get('VLLM_TREE_DECODING') == '1')

    def init_device(self):
        """Initialize the device through the base worker.

        The proposer is not created here; `load_model` builds it because the
        choice of proposer depends on the loaded model.
        """
        super().init_device()

    def load_model(self):
        """Load the model and create the matching proposer.

        With tree decoding enabled and a model that provides non-None
        ``medusa_choices``, the tree buffers are generated and a
        ``TreeStyleProposer`` is used; otherwise a ``Top1Proposer``.
        """
        super().load_model()

        # Get medusa choices and generate the medusa buffers.
        self.medusa_buffers = None
        if self.tree_decoding and hasattr(self.model_runner.model,
                                          'medusa_choices'):
            self.medusa_choices = self.model_runner.model.medusa_choices
            if self.medusa_choices is not None:
                self.medusa_buffers = self.generate_medusa_buffers(
                    self.medusa_choices, device=self.device)

        if self.medusa_buffers is None:
            self._proposer = Top1Proposer(
                weakref.proxy(self),  # type: ignore[arg-type]
                self.device,
                self.vocab_size,
                max_proposal_len=self.max_model_len,
            )
        else:
            self._proposer = TreeStyleProposer(
                weakref.proxy(self),  # type: ignore[arg-type]
                self.device,
                self.vocab_size,
                self.medusa_buffers,
                max_proposal_len=self.max_model_len,
            )

    def set_include_gpu_probs_tensor(self):
        """Intentionally a no-op for the Medusa worker."""
        pass

    def set_should_modify_greedy_probs_inplace(self):
        """Intentionally a no-op for the Medusa worker."""
        pass

    def _get_driver_input_and_broadcast(
        self, execute_model_req: ExecuteModelRequest
    ) -> Dict[str, Any]:
        """Build the proposal inputs on the driver and broadcast them.

        Returns a dict with the keys ``previous_hidden_states``,
        ``previous_logits`` (None when the request carries none),
        ``sample_indices_list`` (one entry per sequence group in the sampling
        metadata) and ``seq_lens``. The dict is broadcast from rank 0 when
        ``self.do_metadata_broadcast`` is set.
        """
        seq_group_metadata_list = execute_model_req.seq_group_metadata_list

        seq_lens, query_lens = self._prepare_input_tensors(
            seq_group_metadata_list)

        generators = self.model_runner.get_generators(
            execute_model_req.finished_requests_ids)
        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list, seq_lens, query_lens, self.device,
            self.model_runner.pin_memory, generators)

        # Only the per-group sample indices are forwarded to the model.
        sample_indices_list = [
            seq_group.sample_indices
            for seq_group in sampling_metadata.seq_groups
        ]

        previous_hidden_states = (
            execute_model_req.previous_hidden_states.hidden_states)
        previous_logits = (
            execute_model_req.previous_logits.logits
            if execute_model_req.previous_logits is not None else None)

        tensor_dict = {
            "previous_hidden_states": previous_hidden_states,
            "previous_logits": previous_logits,
            "sample_indices_list": sample_indices_list,
            "seq_lens": seq_lens
        }

        if self.do_metadata_broadcast:
            broadcast_tensor_dict(tensor_dict, src=0)

        return tensor_dict

    def _get_worker_input_from_broadcast(
        self
    ) -> Optional[Dict[str, Any]]:
        """ Get the worker input from the broadcasted tensor dict. """
        assert self.do_metadata_broadcast
        assert not self.is_driver_worker
        broadcast_data = broadcast_tensor_dict(src=0)

        return broadcast_data

    @torch.inference_mode()
    def sampler_output(
        self,
        execute_model_req: ExecuteModelRequest,
        sample_len: int,
        # Unused parameter.
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ) -> Tuple[List[SamplerOutput], bool]:
        """Run the model forward pass to generate sample_len future tokens.
        Returns the list of sampler output, one per layer, along with indicator
        of whether torch tensor in sampler output need to be transposed in
        latter sampler_output_to_torch logic.

        For medusa worker, this indicator shall be False.

        The driver builds the inputs and broadcasts them; other workers
        receive them from the broadcast. In tree-decoding mode the driver
        additionally attaches ``tree_attn_masks`` to each output.

        Raises:
            ValueError: if a non-driver worker receives no broadcast input.
        """
        self._raise_if_unsupported(execute_model_req)

        if self.is_driver_worker:
            tensor_dict = self._get_driver_input_and_broadcast(
                execute_model_req)
        else:
            tensor_dict = self._get_worker_input_from_broadcast()
            if tensor_dict is None:
                raise ValueError("Can not get inputs of medusa worker!!!")

        model_outputs = self.model_runner.model.generate_proposals(
            previous_hidden_states=tensor_dict["previous_hidden_states"],
            sample_indices_list=tensor_dict["sample_indices_list"],
            previous_logits=tensor_dict["previous_logits"],
            medusa_buffers=self.medusa_buffers)

        # Create tree attention masks (driver only, tree-decoding mode only).
        if self.is_driver_worker and self.medusa_buffers is not None:
            self._attach_tree_attn_masks(model_outputs,
                                         tensor_dict["seq_lens"])

        return model_outputs, False

    def _attach_tree_attn_masks(
        self,
        model_outputs: List[SamplerOutput],
        seq_lens: List[int],
    ) -> None:
        """Set ``tree_attn_masks`` on each sampler output.

        Each output is paired positionally with a sequence length. Its mask is
        the tree attention mask prefixed by ``seq_len`` columns of ones (the
        whole context is visible) and right-padded with zeros, so every mask
        has width ``max(seq_lens)`` plus the tree mask width.
        """
        # Loop-invariant: the same tree mask is reused for every sequence.
        tree_masks = self.medusa_buffers['tree_attn_masks']
        # `default=0` keeps an empty batch from raising in max(); the loop
        # below is then a no-op.
        max_context_len = max(seq_lens, default=0)
        # NOTE(review): zip silently truncates if the number of outputs and
        # sequences differ - confirm generate_proposals returns one output
        # per sequence in tree mode.
        for output, context_len in zip(model_outputs, seq_lens):
            left_mask = torch.ones(tree_masks.shape[0], context_len,
                                   dtype=tree_masks.dtype,
                                   device=tree_masks.device)
            attn_masks = torch.cat([left_mask, tree_masks], dim=-1)
            right_pad = max_context_len - context_len
            if right_pad > 0:
                attn_masks = F.pad(attn_masks, (0, right_pad), "constant", 0)
            output.tree_attn_masks = attn_masks

    def _prepare_input_tensors(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
    ) -> Tuple[List[int], List[int]]:
        """Compute per-sequence ``(seq_lens, query_lens)``.

        Prompt sequences: the sequence length is capped at the computed
        tokens plus ``token_chunk_size``, and the query length is the
        uncomputed part of that span. Decode sequences: the query length is 1.
        On the first tree-decoding step the sequence length is reduced by
        one. Returns two empty lists for an empty or None input.
        """
        if not seq_group_metadata_list:
            return [], []

        seq_lens: List[int] = []
        query_lens: List[int] = []

        for seq_group_metadata in seq_group_metadata_list:
            is_prompt = seq_group_metadata.is_prompt

            for seq_data in seq_group_metadata.seq_data.values():
                seq_data_len = seq_data.get_len()
                if is_prompt:
                    context_len = seq_data.get_num_computed_tokens()
                    seq_len = min(
                        seq_data_len,
                        context_len + seq_group_metadata.token_chunk_size)
                    seq_lens.append(seq_len)
                    query_lens.append(seq_len - context_len)
                else:
                    # The first step of tree decoding needs to ignore the
                    # first token.
                    if (self.medusa_buffers is not None
                            and seq_data.get_first_step_flag()):
                        seq_data_len -= 1
                    seq_lens.append(seq_data_len)
                    query_lens.append(1)

        return seq_lens, query_lens

    def get_spec_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ) -> SpeculativeProposals:
        """Produce speculations given an input batch of sequences. The number of
        speculative tokens per sequence is determined by max_proposal_len.

        Delegates to the proposer created in `load_model`.
        """
        return self._proposer.get_spec_proposals(
            execute_model_req, seq_ids_with_bonus_token_in_last_step)

    def _raise_if_unsupported(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> None:
        """MedusaWorker does not yet implement support for cache swap
        operations or beam search.

        A None request (as passed on non-driver workers) is accepted.

        Raises:
            NotImplementedError: for swap/copy block operations, or for any
                sequence group holding more than one sequence.
        """
        if execute_model_req is None:
            return None

        if any((
                execute_model_req.blocks_to_swap_in,
                execute_model_req.blocks_to_swap_out,
                execute_model_req.blocks_to_copy,
        )):
            raise NotImplementedError(
                "MedusaWorker does not support cache operations")

        if any(
                len(seq_group_metadata.seq_data.keys()) != 1
                for seq_group_metadata in
                execute_model_req.seq_group_metadata_list):
            raise NotImplementedError(
                "MedusaWorker does not support beam search.")

    def pad_path(self, path, length, pad_value=-2):
        """
        Pad the given path list with a specific value up to a specified length.

        Parameters:
        - path (list): The original list that needs padding.
        - length (int): The desired length of the padded list.
        - pad_value (optional, default=-2): The value to use for padding.

        Returns:
        - list: A new list based on the original path but padded to the desired length.

        Example:
        >>> pad_path([1,2,3], 5)
        [1, 2, 3, -2, -2]

        Note:
        If the given path is already longer than the specified length,
        then no padding occurs, and the original path is returned.
        """
        # A negative repeat count yields an empty list, so longer paths are
        # returned unchanged (as a new list).
        return path + [pad_value] * (length - len(path))

    def generate_medusa_buffers(self, medusa_choices, device="cuda",
                                topk=TOPK):
        """
        Generate buffers for the Medusa structure based on the provided choices.

        Parameters:
        - medusa_choices (list): A nested list representing the tree in the
          Medusa structure. Each entry is a path from the root: its length is
          the node's depth and its last element is the node's candidate index
          at that depth. Every prefix of a path must itself be a choice.
        - device (str): Device to which the tensors should be moved. Default is "cuda".
        - topk (int): Stride between depths in ``tree_indices`` (presumably
          the number of candidates kept per Medusa head - confirm against the
          model). Defaults to the module-level ``TOPK``.

        Returns:
        - dict: Tensors ``tree_attn_masks`` (int, ``(N+1, N+1)`` with N the
          number of choices; index 0 is the root), ``tree_indices``,
          ``tree_position_ids`` (``(N+1,)``) and ``retrieve_indices``.

        Raises:
        - ValueError: if ``medusa_choices`` is empty or a path's prefix is
          not itself among the choices.
        """
        # Sort the medusa_choices based on their lengths and then their values.
        sorted_medusa_choices = sorted(medusa_choices,
                                       key=lambda x: (len(x), x))
        if not sorted_medusa_choices:
            raise ValueError("medusa_choices must not be empty")
        medusa_len = len(sorted_medusa_choices) + 1

        # Position of each path in the sorted list. setdefault keeps the
        # first occurrence (as list.index would) while avoiding an O(n) scan
        # per lookup.
        choice_positions = {}
        for pos, choice in enumerate(sorted_medusa_choices):
            choice_positions.setdefault(tuple(choice), pos)

        def position_of(path):
            try:
                return choice_positions[tuple(path)]
            except KeyError:
                raise ValueError(
                    f"medusa_choices is missing the prefix path {list(path)}"
                ) from None

        # Row/column r >= 1 corresponds to sorted_medusa_choices[r - 1];
        # row/column 0 is the root.

        # Create the attention mask for Medusa: each node attends to the root
        # (column 0), to itself (diagonal) and to all of its ancestors.
        medusa_attn_mask = torch.eye(medusa_len, medusa_len)
        medusa_attn_mask[:, 0] = 1
        for row, choice in enumerate(sorted_medusa_choices, start=1):
            ancestor_idx = [
                position_of(choice[:c + 1]) + 1
                for c in range(len(choice) - 1)
            ]
            if ancestor_idx:
                medusa_attn_mask[row, ancestor_idx] = 1

        # Generate tree indices: a node at depth d (1-based) with candidate
        # index k maps to k + topk * (d - 1) + 1; index 0 is the root.
        medusa_tree_indices = torch.zeros(medusa_len, dtype=torch.long)
        for row, choice in enumerate(sorted_medusa_choices, start=1):
            medusa_tree_indices[row] = (
                choice[-1] + topk * (len(choice) - 1) + 1)

        # Generate position IDs: each node's position is its depth; the root
        # stays at 0.
        medusa_position_ids = torch.tensor(
            [0] + [len(choice) for choice in sorted_medusa_choices],
            dtype=torch.long)

        # Generate retrieval indices for Medusa structure verification: walk
        # the sorted choices from the end and emit one root-to-node path for
        # every node not already covered by a previously emitted path.
        retrieve_indices_nest = []
        covered_paths = set()
        for choice in reversed(sorted_medusa_choices):
            if tuple(choice) in covered_paths:
                continue
            prefixes = [choice[:c + 1] for c in range(len(choice))]
            retrieve_indices_nest.append([position_of(p) for p in prefixes])
            covered_paths.update(tuple(p) for p in prefixes)
        max_length = max(len(x) for x in retrieve_indices_nest)
        retrieve_indices = [
            self.pad_path(path, max_length) for path in retrieve_indices_nest
        ]
        retrieve_indices = torch.tensor(retrieve_indices, dtype=torch.long)
        # Shift by one so index 0 denotes the root (the -2 padding becomes
        # -1), then prepend the root column.
        retrieve_indices = retrieve_indices + 1
        retrieve_indices = torch.cat(
            [
                torch.zeros((retrieve_indices.shape[0], 1), dtype=torch.long),
                retrieve_indices,
            ],
            dim=1)

        # Aggregate the generated buffers into a dictionary.
        medusa_buffers = {
            "tree_attn_masks": medusa_attn_mask.int(),
            "tree_indices": medusa_tree_indices,
            "tree_position_ids": medusa_position_ids,
            "retrieve_indices": retrieve_indices,
        }

        # Move the tensors in the dictionary to the specified device.
        return {k: v.clone().to(device) for k, v in medusa_buffers.items()}