# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch
import torch.nn as nn

7
from vllm.config import VllmConfig
8
9
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
10
from vllm.model_executor.model_loader import get_model
11
from vllm.model_executor.models.interfaces import is_mixture_of_experts
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from vllm.v1.sample.metadata import SamplingMetadata

# Module-level logger, created with vLLM's init_logger and named after this module.
logger = init_logger(__name__)


class MedusaProposer:
    """Propose draft tokens with a Medusa draft model.

    The draft model is loaded by :meth:`load_model` and stored on
    ``self.model``; :meth:`propose` and :meth:`dummy_run` both require that
    ``load_model`` has been called first.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Cache the configuration values the proposer needs.

        Args:
            vllm_config: Top-level config. The scheduler, speculative
                (draft model) and model sub-configs are read from it.
            device: Device on which :meth:`dummy_run` allocates its input.
        """
        # Save config parameters
        self.vllm_config = vllm_config
        self.device = device
        self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
        # Hidden size comes from the draft model's config, dtype from the
        # target model's config.
        self.hidden_size = (
            vllm_config.speculative_config.draft_model_config.get_hidden_size()
        )
        self.dtype = vllm_config.model_config.dtype

    def propose(
        self,
        target_hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        """Return the greedy (argmax) draft token of every Medusa head.

        Args:
            target_hidden_states: Hidden states fed to the draft model.
            sampling_metadata: Accepted for interface compatibility; this
                implementation does not read it.

        Returns:
            The per-head ``argmax(dim=-1)`` results stacked along ``dim=1``.
        """
        # Generate blocks and compute logits
        blocks = self.model(target_hidden_states)
        logits = self.model.compute_logits(blocks)

        # Compute argmax for each Medusa head and stack into a single tensor
        # Shape: [batch_size, num_heads]
        # (assumes `logits` yields one 2-D logits tensor per head -- TODO confirm)
        draft_tokens = torch.stack([logit.argmax(dim=-1) for logit in logits], dim=1)

        return draft_tokens

    def load_model(self, target_model: nn.Module) -> None:
        """Load the draft model described by the speculative config.

        Args:
            target_model: Accepted for interface compatibility; this
                implementation does not read it.

        Raises:
            AssertionError: If the loaded model is a mixture-of-experts model
                and ``parallel_config.enable_eplb`` is set.
        """
        # NOTE(review): imported locally rather than at module level --
        # presumably to avoid an import cycle; confirm before hoisting.
        from vllm.compilation.backends import set_model_tag

        with set_model_tag("medusa_head"):
            self.model = get_model(
                vllm_config=self.vllm_config,
                model_config=self.vllm_config.speculative_config.draft_model_config,
            )
        assert not (
            is_mixture_of_experts(self.model)
            and self.vllm_config.parallel_config.enable_eplb
        ), "EPLB for Medusa is not supported"

    @torch.inference_mode()
    def dummy_run(self, num_tokens: int) -> None:
        """Run the draft model once on an all-zero input.

        Args:
            num_tokens: Token count passed to the forward context.
        """
        # NOTE(review): the buffer is sized by max_num_tokens, not num_tokens;
        # num_tokens only reaches the forward context. Kept as in the original.
        hidden_states = torch.zeros(
            (self.max_num_tokens, self.hidden_size),
            dtype=self.dtype,
            device=self.device,
        )
        with set_forward_context(None, self.vllm_config, num_tokens=num_tokens):
            self.model(hidden_states)