medusa_worker.py 14.1 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import os
4
import weakref
5
from typing import List, Optional, Set, Tuple, Dict
6
7

import torch
8
import torch.nn.functional as F
9
10

from vllm.model_executor import SamplingMetadata
11
12
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
13
from vllm.spec_decode.interfaces import SpeculativeProposals, SpeculativeProposer
14
15
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
from vllm.spec_decode.top1_proposer import Top1Proposer
16
from vllm.worker.worker_base import DelegateWorkerBase
17

18
from vllm.spec_decode.tree_style_proposer import TreeStyleProposer
19
from vllm.distributed import broadcast_tensor_dict
20
from vllm.worker.worker_base import WorkerWrapperBase
21
22


23
# Top-k width used when computing sparse-tree indices
# (10 is a placeholder and it is sufficient).
TOPK = 10
24
25


26
class MedusaWorker(NonLLMProposerWorkerBase, DelegateWorkerBase):
27
28
29
30
    """Worker for Medusa.
    """

    def __init__(self, *args, **kwargs):
        """Construct the worker with LoRA disabled.

        All positional and keyword arguments are forwarded to
        ``DelegateWorkerBase.__init__``, except that ``lora_config`` is
        always overridden with ``None``.

        Tree decoding is enabled only when the ``VLLM_TREE_DECODING``
        environment variable is exactly ``'1'``.
        """
        # skip lora config in medusa
        delegate_kwargs = {**kwargs, 'lora_config': None}
        DelegateWorkerBase.__init__(self, *args, **delegate_kwargs)

        # Lazy initialization list.
        # The proposer is only declared here; load_model() assigns it.
        self._proposer: SpeculativeProposer

        self.tree_decoding = os.environ.get('VLLM_TREE_DECODING') == '1'

    def init_device(self):
        """Initialize the device by delegating to the wrapped worker."""
        self.worker.init_device()
    
    def load_model(self):
        """Load the model, build tree buffers if enabled, and create the proposer.

        After the parent ``load_model`` runs, ``self.medusa_buffers`` is set
        to ``None`` unless tree decoding is enabled and the loaded model
        exposes a non-``None`` ``medusa_choices`` attribute, in which case
        the buffers are generated from those choices.

        A ``Top1Proposer`` is used when no buffers were generated; otherwise
        a ``TreeStyleProposer`` is created and given the buffers.
        """
        super().load_model()

        # get medusa choices and generate medusa_buffers
        self.medusa_buffers = None
        if self.tree_decoding and hasattr(self.model_runner.model,
                                          'medusa_choices'):
            self.medusa_choices = self.model_runner.model.medusa_choices
            if self.medusa_choices is not None:
                self.medusa_buffers = self.generate_medusa_buffers(
                    self.medusa_choices, device=self.device)

        # The proposer receives a weak proxy of this worker rather than a
        # strong reference.
        if self.medusa_buffers is None:
            self._proposer = Top1Proposer(
                weakref.proxy(self),  # type: ignore[arg-type]
                self.device,
                self.vocab_size,
                max_proposal_len=self.max_model_len,
            )
        else:
            self._proposer = TreeStyleProposer(
                weakref.proxy(self),  # type: ignore[arg-type]
                self.device,
                self.vocab_size,
                self.medusa_buffers,
                max_proposal_len=self.max_model_len,
            )


    def set_include_gpu_probs_tensor(self):
        """No-op: this worker takes no action for this hook."""
        return None

75
76
77
    def set_should_modify_greedy_probs_inplace(self):
        """No-op: this worker takes no action for this hook."""
        return None

78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
    def _get_driver_input_and_broadcast(
        self, execute_model_req: ExecuteModelRequest
    ) -> Dict[str, torch.Tensor]:
        """Assemble the proposal inputs on the driver and optionally broadcast.

        The returned dict holds the previous hidden states, the previous
        logits (``None`` when the request carries none), the sample indices
        of every sequence group, and the sequence lengths. When
        ``self.do_metadata_broadcast`` is set, the dict is also broadcast
        with ``src=0``.
        """
        metadata_list = execute_model_req.seq_group_metadata_list
        seq_lens, query_lens = self._prepare_input_tensors(metadata_list)

        generators = self.model_runner.get_generators(
            execute_model_req.finished_requests_ids)
        sampling_metadata = SamplingMetadata.prepare(
            metadata_list,
            seq_lens,
            query_lens,
            self.device,
            self.model_runner.pin_memory,
            generators,
        )

        # One entry per sequence group, in the order SamplingMetadata lists them.
        sample_indices_list = [
            group.sample_indices for group in sampling_metadata.seq_groups
        ]

        previous_hidden_states = (
            execute_model_req.previous_hidden_states.hidden_states)
        if execute_model_req.previous_logits is not None:
            previous_logits = execute_model_req.previous_logits.logits
        else:
            previous_logits = None

        tensor_dict = {
            "previous_hidden_states": previous_hidden_states,
            "previous_logits": previous_logits,
            "sample_indices_list": sample_indices_list,
            "seq_lens": seq_lens,
        }

        if self.do_metadata_broadcast:
            broadcast_tensor_dict(tensor_dict, src=0)

        return tensor_dict
    
    def _get_worker_input_from_broadcast(
        self
    ) -> Optional[Dict[str, torch.Tensor]]:
        """Receive the input dict that was broadcast with ``src=0``.

        Only valid on a non-driver worker with metadata broadcast enabled;
        both conditions are asserted.
        """
        assert self.do_metadata_broadcast
        assert not self.is_driver_worker
        return broadcast_tensor_dict(src=0)

123
124
125
126
127
    @torch.inference_mode()
    def sampler_output(
        self,
        execute_model_req: ExecuteModelRequest,
        sample_len: int,
        # Unused parameter.
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ) -> Tuple[List[SamplerOutput], bool]:
        """Run the model forward pass to generate sample_len future tokens.
        Returns the list of sampler output, one per layer, along with indicator
        of whether torch tensor in sampler output need to be transposed in
        latter sampler_output_to_torch logic.

        For medusa worker, this indicator shall be False.

        ``sample_len`` is accepted but not read by this implementation.

        Raises:
            NotImplementedError: propagated from ``_raise_if_unsupported``.
            ValueError: if a non-driver worker receives no broadcast input.
        """
        self._raise_if_unsupported(execute_model_req)

        # The driver builds the inputs (and broadcasts them when enabled);
        # every other worker receives them from the broadcast.
        if self.is_driver_worker:
            tensor_dict = self._get_driver_input_and_broadcast(
                execute_model_req)
        else:
            tensor_dict = self._get_worker_input_from_broadcast()
            if tensor_dict is None:
                raise ValueError("Can not get inputs of medusa worker!!!")

        model_outputs = self.model_runner.model.generate_proposals(
            previous_hidden_states=tensor_dict["previous_hidden_states"],
            sample_indices_list=tensor_dict["sample_indices_list"],
            previous_logits=tensor_dict["previous_logits"],
            medusa_buffers=self.medusa_buffers)

        # create tree attn masks
        if self.is_driver_worker and self.medusa_buffers is not None:
            seq_lens = tensor_dict["seq_lens"]
            # The tree mask is the same for every sequence, so look it up once.
            tree_mask = self.medusa_buffers['tree_attn_masks']
            # default=0 keeps an empty batch from raising in max().
            max_context_len = max(seq_lens, default=0)
            for sampler_output, context_len in zip(model_outputs, seq_lens):
                # Columns for the existing context are all ones.
                left_mask = torch.ones(tree_mask.shape[0], context_len,
                                       dtype=tree_mask.dtype,
                                       device=tree_mask.device)
                attn_masks = torch.cat([left_mask, tree_mask], dim=-1)
                # Zero-pad on the right so every mask has the same width.
                right_pad = max_context_len - context_len
                if right_pad > 0:
                    attn_masks = F.pad(attn_masks, (0, right_pad),
                                       "constant", 0)
                sampler_output.tree_attn_masks = attn_masks

        return model_outputs, False

    def _prepare_input_tensors(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
    ) -> Tuple[List[int], List[int]]:
        """Compute the sequence length and query length of every sequence.

        For prompt groups the sequence length is the total length capped at
        ``computed tokens + token_chunk_size``, and the query length is the
        part of it that is not yet computed. For other groups the query
        length is 1 and the sequence length is the total length, reduced by
        one when tree buffers exist and the sequence's first-step flag is set.

        Returns:
            ``(seq_lens, query_lens)``; both are empty when the input list is
            ``None`` or empty.
        """
        if not seq_group_metadata_list:
            return [], []

        seq_lens: List[int] = []
        query_lens: List[int] = []

        for seq_group_metadata in seq_group_metadata_list:
            is_prompt = seq_group_metadata.is_prompt

            for seq_data in seq_group_metadata.seq_data.values():
                seq_data_len = seq_data.get_len()
                if is_prompt:
                    context_len = seq_data.get_num_computed_tokens()
                    seq_len = min(
                        seq_data_len,
                        context_len + seq_group_metadata.token_chunk_size)
                    seq_lens.append(seq_len)
                    query_lens.append(seq_len - context_len)
                else:
                    # first step of tree decoding need to ignore first token
                    if (self.medusa_buffers is not None
                            and seq_data.get_first_step_flag()):
                        seq_data_len -= 1
                    seq_lens.append(seq_data_len)
                    query_lens.append(1)

        return seq_lens, query_lens

    def get_spec_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ) -> SpeculativeProposals:
        """Produce speculations given an input batch of sequences. The number of
        speculative tokens per sequence is determined by max_proposal_len.

        Both arguments are passed unchanged to the proposer created in
        ``load_model``.
        """
        return self._proposer.get_spec_proposals(
            execute_model_req, seq_ids_with_bonus_token_in_last_step)

    def _raise_if_unsupported(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> None:
        """MedusaWorker does not yet implement support for cache swap
        operations or beam search.

        A ``None`` request is accepted and skips all checks.

        Raises:
            NotImplementedError: if the request carries any blocks to swap
                in, swap out or copy, or if any sequence group does not hold
                exactly one sequence.
        """
        if execute_model_req is None:
            return None

        if any((
                execute_model_req.blocks_to_swap_in,
                execute_model_req.blocks_to_swap_out,
                execute_model_req.blocks_to_copy,
        )):
            raise NotImplementedError(
                "MedusaWorker does not support cache operations")

        # More than one sequence in a group is treated as beam search.
        if any(
                len(seq_group_metadata.seq_data) != 1
                for seq_group_metadata in
                execute_model_req.seq_group_metadata_list):
            raise NotImplementedError(
                "MedusaWorker does not support beam search.")
        
    def pad_path(self, path, length, pad_value=-2):
        """
        Return a copy of ``path`` right-padded with ``pad_value`` to ``length``.

        Parameters:
        - path (list): The list to pad; it is not modified.
        - length (int): The target length of the result.
        - pad_value (optional, default=-2): The filler appended to the end.

        Returns:
        - list: A new list that starts with the items of ``path`` and is
          followed by as many ``pad_value`` entries as are needed.

        Example:
        >>> pad_path([1,2,3], 5)
        [1, 2, 3, -2, -2]

        Note:
        When ``path`` already has ``length`` or more items nothing is added,
        and the result is an unpadded copy of ``path``.
        """
        # A non-positive count multiplies out to an empty list, so longer
        # paths come back unpadded.
        missing = length - len(path)
        filler = [pad_value] * missing
        return path + filler

    def generate_medusa_buffers(self, medusa_choices, device="cuda"):
        """
        Generate buffers for the Medusa structure based on the provided choices.
        
        Parameters:
        - medusa_choices (list): A nested list representing tree in the Medusa structure.
        - device (str): Device to which the tensors should be moved. Default is "cuda".
        
        Returns:
        - dict: A dictionary containing buffers related to the Medusa structure.
        """

        # Sort the medusa_choices based on their lengths and then their values
        sorted_medusa_choices = sorted(medusa_choices, key=lambda x: (len(x), x))
        medusa_len = len(sorted_medusa_choices) + 1

        # Initialize depth_counts to keep track of how many choices have a particular depth
        depth_counts = []
        prev_depth = 0
        for path in sorted_medusa_choices:
            depth = len(path)
            if depth != prev_depth:
                depth_counts.append(0)
            depth_counts[depth - 1] += 1
            prev_depth = depth
        
        # Create the attention mask for Medusa
        medusa_attn_mask = torch.eye(medusa_len, medusa_len)
        medusa_attn_mask[:, 0] = 1
        start = 0
        for i in range(len(depth_counts)):
            for j in range(depth_counts[i]):
                cur_medusa_choice = sorted_medusa_choices[start + j]
                # retrieve ancestor position
                if len(cur_medusa_choice) == 1:
                    continue
                ancestor_idx = []
                for c in range(len(cur_medusa_choice) - 1):
                    ancestor_idx.append(sorted_medusa_choices.index(cur_medusa_choice[:c+1]) + 1)
                medusa_attn_mask[j + start + 1, ancestor_idx] = 1
            start += depth_counts[i]

        # Generate tree indices for the Medusa structure
        medusa_tree_indices = torch.zeros(medusa_len, dtype=torch.long)
        medusa_tree_indices[0] = 0
        start = 0
        for i in range(len(depth_counts)):
            for j in range(depth_counts[i]):
                cur_medusa_choice = sorted_medusa_choices[start + j]
                medusa_tree_indices[start + j + 1] = cur_medusa_choice[-1] + TOPK * i + 1
            start += depth_counts[i]

        # Generate position IDs for the Medusa structure
        medusa_position_ids = torch.zeros(medusa_len, dtype=torch.long)
        start = 0
        for i in range(len(depth_counts)):
            medusa_position_ids[start + 1: start + depth_counts[i] + 1] = i + 1
            start += depth_counts[i]

        # Generate retrieval indices for Medusa structure verification
        retrieve_indices_nest = []
        retrieve_paths = []
        for i in range(len(sorted_medusa_choices)):
            cur_medusa_choice = sorted_medusa_choices[-i-1]
            retrieve_indice = []
            if cur_medusa_choice in retrieve_paths:
                continue
            else:
                for c in range(len(cur_medusa_choice)):
                    retrieve_indice.append(sorted_medusa_choices.index(cur_medusa_choice[:c+1]))
                    retrieve_paths.append(cur_medusa_choice[:c+1])
            retrieve_indices_nest.append(retrieve_indice)
        max_length = max([len(x) for x in retrieve_indices_nest])
        retrieve_indices = [self.pad_path(path, max_length) for path in retrieve_indices_nest]
        retrieve_indices = torch.tensor(retrieve_indices, dtype=torch.long)
        retrieve_indices = retrieve_indices + 1
        retrieve_indices = torch.cat([torch.zeros((retrieve_indices.shape[0], 1), dtype=torch.long), retrieve_indices], dim=1)

        # Aggregate the generated buffers into a dictionary
        medusa_buffers = {
            "tree_attn_masks": medusa_attn_mask.int(),
            "tree_indices": medusa_tree_indices,
            "tree_position_ids": medusa_position_ids,
            "retrieve_indices": retrieve_indices,
            }
        
        # Move the tensors in the dictionary to the specified device
        medusa_buffers = {
353
            k: v.clone().to(device)
354
355
356
357
358
            if isinstance(v, torch.Tensor)
            else torch.tensor(v,  device=device)
            for k, v in medusa_buffers.items()
        }
        return medusa_buffers