# SPDX-License-Identifier: Apache-2.0

import os
from typing import Iterable, List, Optional, Set, Tuple, Any, Dict

import torch
import torch.nn as nn

from vllm.config import VllmConfig
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.transformers_utils.configs.medusa import MedusaConfig
from vllm import _custom_ops as ops
from vllm.distributed import (tensor_model_parallel_all_gather,
                              tensor_model_parallel_gather)

# Number of top-k candidates taken from every Medusa head when building the
# sparse candidate tree (10 is a placeholder and it is sufficient).
TOPK = 10


class ResidualBlock(nn.Module):
    """Stack of residual layers computing ``x = x + SiLU(Linear(x))``.

    Args:
        config: Object that is queried for the optional ``medusa_fc_bias``
            attribute (defaults to ``False``). NOTE: ``Medusa`` passes the
            HF config (``vllm_config.model_config.hf_config``) here, even
            though the parameter is annotated as ``VllmConfig``.
        hidden_size: Input and output width of every linear layer.
        num_layers: Number of residual layers in the block.
    """

    def __init__(self, config: VllmConfig, hidden_size: int,
                 num_layers: int) -> None:
        super().__init__()

        use_bias = getattr(config, "medusa_fc_bias", False)
        self.layers = nn.ModuleList([
            nn.Linear(hidden_size, hidden_size, bias=use_bias)
            for _ in range(num_layers)
        ])
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply every residual layer in order and return the result."""
        for layer in self.layers:
            x = x + self.act(layer(x))
        return x


class Medusa(nn.Module):
    """This class implements the Medusa draft model from the paper:
    https://arxiv.org/abs/2401.10774
    Reference implementation: https://github.com/FasterDecoding/Medusa

    Differences from reference implementation:
    1. Currently this only supports generating proposals from top-1 tokens.
    2. We have an optional token_map which reduces draft vocab to most
       frequently used tokens to give some additional speed-up by reducing
       sampling overhead. This is disabled unless the checkpoint file has
       explicit token_map tensor and config has an optional attribute
       truncated_vocab_size < vocab_size. To use this technique, one has to
       find the top-k most frequent tokens in target dataset and add that as
       a tensor in the draft checkpoint (using key token_map). Also, the
       draft config needs to have truncated_vocab_size (=k) as an attribute.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        config = vllm_config.model_config.hf_config
        super().__init__()

        # Opt-in switch (environment variable LLAMA_NN=1) for the custom
        # LM-head weight layout transform applied in ``load_weights``.
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
        self.config = config

        # One residual block per Medusa head.
        self.blocks = nn.ModuleList([
            ResidualBlock(config=config,
                          hidden_size=self.config.hidden_size,
                          num_layers=self.config.num_hidden_layers)
            for _ in range(self.config.num_heads)
        ])

        self.orig_vocab_size = config.vocab_size
        self.truncated_vocab_size = config.truncated_vocab_size
        self.unpadded_vocab_size = self.truncated_vocab_size
        self.medusa_choices = config.medusa_choices

        if getattr(config, "original_lm_head", False):
            # A single LM head shared by all Medusa heads. It is registered
            # once as ``self.lm_head``; ``self.lm_heads`` is a plain list
            # that repeats the same module.
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=self.truncated_vocab_size,
                padding_size=DEFAULT_VOCAB_PADDING_SIZE,
            )
            self.lm_heads = [
                self.lm_head for _ in range(self.config.num_heads)
            ]
        else:
            # One independent LM head per Medusa head.
            self.lm_heads = nn.ModuleList([
                ParallelLMHead(
                    self.unpadded_vocab_size,
                    config.hidden_size,
                    org_num_embeddings=self.truncated_vocab_size,
                    padding_size=DEFAULT_VOCAB_PADDING_SIZE,
                ) for _ in range(self.config.num_heads)
            ])

        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                self.truncated_vocab_size,
                                                logit_scale)

        # Token map is a idx to token mapping to reduce the vocab size for
        # the draft model. Using smaller vocab size for draft, containing
        # only most frequent tokens reduces the speculation overhead. This
        # doesn't affect the acceptance rate much and thus gives more speed
        # -up. By default, this is disabled and is only used if the Medusa
        # checkpoint file has token_map tensor.
        self.token_map = None

    def forward(self, hidden_states: torch.Tensor) -> List[torch.Tensor]:
        """Run every Medusa residual block on the same hidden states.

        Returns:
            One tensor per Medusa head, in head order.
        """
        return [block(hidden_states) for block in self.blocks]

    def compute_logits(
            self, hidden_states: List[torch.Tensor]) -> List[torch.Tensor]:
        """Project per-head hidden states to vocabulary logits.

        Args:
            hidden_states: Per-head hidden states as returned by
                ``forward``.

        Returns:
            One logits tensor per head. When a token map is loaded, the
            truncated-vocab logits are scattered into a tensor whose last
            dimension is ``orig_vocab_size``, with ``-inf`` everywhere else.
        """
        logits_lst: List[torch.Tensor] = []

        for hs, lm_head in zip(hidden_states, self.lm_heads):
            # The LM head is applied directly through its quant method and
            # all-gathered across tensor-parallel ranks;
            # ``self.logits_processor`` (and therefore ``logit_scale``) is
            # not used on this path.
            _logits = lm_head.quant_method.apply(lm_head, hs, bias=None)
            _logits = tensor_model_parallel_all_gather(_logits)

            if _logits is None:
                # _logits should only be None on rank > 0, in which case
                # it should remain true for every lm_head
                assert len(logits_lst) == 0
                continue

            # Because the logits processor is bypassed, vocab padding added
            # by ParallelLMHead (if any) has to be dropped here. Otherwise
            # the token_map scatter below fails with a shape mismatch and
            # padding columns could win argmax/topk. This is a no-op when
            # the head is not padded.
            _logits = _logits[..., :self.truncated_vocab_size]

            if self.token_map is None:
                logits_lst.append(_logits)
            else:
                # Expand to the full vocabulary: tokens missing from the
                # draft vocabulary get -inf so they are never selected.
                full_logits = torch.full(
                    (*_logits.shape[:-1], self.orig_vocab_size),
                    float("-inf"),
                    device=_logits.device,
                    dtype=_logits.dtype)
                full_logits[..., self.token_map] = _logits
                logits_lst.append(full_logits)

        return logits_lst

    def sample(
        self,
        logits: List[torch.Tensor],
        sample_indices_list: List[List[int]],
    ) -> List[SamplerOutput]:
        """Greedy (top-1) sampling from every Medusa head.

        Args:
            logits: Per-head logits as returned by ``compute_logits``.
            sample_indices_list: For every output, the indices (along the
                second dimension of the stacked logits) to select.

        Returns:
            One ``SamplerOutput`` per entry of ``sample_indices_list``
            carrying the selected token ids, probabilities and
            log-probabilities.
        """
        logits = torch.stack(logits, dim=0).float()
        logprobs = torch.log_softmax(logits, dim=-1)
        token_ids = logits.argmax(-1)  # support only top-1 for now
        probs = torch.softmax(logits, dim=-1)

        outputs: List[SamplerOutput] = []
        for sample_indices in sample_indices_list:
            outputs.append(
                SamplerOutput(
                    outputs=None,
                    sampled_token_probs=probs[:, sample_indices].squeeze(1),
                    logprobs=logprobs[:, sample_indices].squeeze(1),
                    sampled_token_ids=token_ids[:,
                                                sample_indices].squeeze(1),
                ))

        return outputs

    def medusa_sample(
        self,
        medusa_logits: List[torch.Tensor],
        sample_indices_list: List[List[int]],
        logits: torch.Tensor,
        medusa_buffers: Dict[str, Any]
    ) -> List[SamplerOutput]:
        """Build tree candidates from target logits and Medusa head logits.

        Args:
            medusa_logits: Per-head logits as returned by
                ``compute_logits``.
            sample_indices_list: For every output, the batch rows to select.
            logits: Logits whose argmax provides the first candidate of
                every batch row.
            medusa_buffers: Must provide the index tensors
                ``'tree_indices'`` and ``'retrieve_indices'``.

        Returns:
            One ``SamplerOutput`` per entry of ``sample_indices_list``
            carrying the tree candidates (``sampled_token_ids``) and the
            cartesian candidates (``cart_candidates``).
        """
        batch_size = logits.shape[0]

        # [batch_size, 1]
        candidates_logit = torch.argmax(logits, dim=-1).view(batch_size, -1)

        # [medusa_heads, batch_size, vocab_size]
        medusa_logits = torch.stack(medusa_logits, dim=0)

        # Extract the TOPK candidates from the medusa logits.
        # [medusa_heads, batch_size, TOPK]
        candidates_medusa_logits = torch.topk(medusa_logits, TOPK,
                                              dim=-1).indices
        # [batch_size, medusa_heads, TOPK]
        candidates_medusa_logits = candidates_medusa_logits.permute(1, 0, 2)
        candidates_medusa_logits = candidates_medusa_logits.reshape(
            batch_size, -1)

        # Combine the selected candidate from the original logits with the
        # topk medusa logits.
        # [batch_size, 1 + medusa_heads * TOPK]
        candidates = torch.cat([candidates_logit, candidates_medusa_logits],
                               dim=-1)

        # Map the combined candidates to the tree indices to get tree
        # candidates.
        # [batch_size, choices]
        tree_candidates = torch.index_select(
            candidates, dim=-1, index=medusa_buffers['tree_indices'])

        # Extend the tree candidates by appending a zero.
        # [batch_size, choices + 1]
        zero_column = torch.zeros((batch_size, 1),
                                  dtype=torch.long,
                                  device=tree_candidates.device)
        tree_candidates_ext = torch.cat([tree_candidates, zero_column],
                                        dim=-1)

        # Retrieve the cartesian candidates using the retrieve indices.
        # [batch_size, retrieve_size, max_depth]
        cart_candidates = tree_candidates_ext[
            :, medusa_buffers['retrieve_indices']]

        outputs: List[SamplerOutput] = []
        for sample_indices in sample_indices_list:
            outputs.append(
                SamplerOutput(
                    outputs=None,
                    sampled_token_ids=tree_candidates[
                        sample_indices, :].squeeze(1),
                    cart_candidates=cart_candidates[sample_indices, :],
                ))

        return outputs

    def generate_proposals(
        self,
        previous_hidden_states: torch.Tensor,
        sample_indices_list: List[List[int]],
        previous_logits: Optional[torch.Tensor] = None,
        medusa_buffers: Optional[Dict[str, Any]] = None
    ) -> List[SamplerOutput]:
        """Generate draft proposals from the previous hidden states.

        Uses plain top-1 sampling (``sample``) when ``medusa_buffers`` is
        ``None``; otherwise builds tree candidates (``medusa_sample``), in
        which case ``previous_logits`` must be provided.
        """
        medusa_logits = self.compute_logits(
            hidden_states=self.forward(previous_hidden_states))

        if medusa_buffers is None:
            return self.sample(
                logits=medusa_logits,
                sample_indices_list=sample_indices_list,
            )

        return self.medusa_sample(medusa_logits=medusa_logits,
                                  sample_indices_list=sample_indices_list,
                                  logits=previous_logits,
                                  medusa_buffers=medusa_buffers)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights into the model.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint. A
                ``"medusa_heads."`` substring in a name is removed. The
                special key ``"token_map"`` is stored as the token map when
                ``truncated_vocab_size < orig_vocab_size``.

        Returns:
            The names of the parameters that were loaded.

        Raises:
            AssertionError: If the vocabulary is truncated but the
                checkpoint did not provide a ``token_map``.
        """
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()

        # First pass: collect the weights. Loading is deferred so that the
        # token map is known before any LM-head weight is processed,
        # whatever the order of the checkpoint entries.
        weights_map = {}
        for name, loaded_weight in weights:
            name = name.replace("medusa_heads.", "")

            if name == "token_map":
                if self.truncated_vocab_size < self.orig_vocab_size:
                    self.token_map = nn.Parameter(loaded_weight,
                                                  requires_grad=False)
            elif name in params_dict:
                weights_map[name] = loaded_weight
            elif (getattr(self.config, "original_lm_head", False)
                  and name == "lm_heads.0.weight"):
                # With a shared LM head, the first head's checkpoint weight
                # is loaded into the single ``lm_head`` parameter.
                weights_map["lm_head.weight"] = loaded_weight

        # Second pass: actually load the collected weights.
        for name, loaded_weight in weights_map.items():
            if ("lm_head" in name and self.token_map is not None
                    and loaded_weight.shape[0] > self.token_map.shape[0]):
                # Keep only the rows of the tokens in the draft vocabulary.
                loaded_weight = loaded_weight[self.token_map]

            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)

            # Optional custom layout transform of the LM-head weight. Use
            # .get() so that an unset LM_NN variable simply disables the
            # transform instead of raising KeyError.
            if (self.use_llama_nn and os.environ.get('LM_NN') == '1'
                    and "lm_head" in name):
                _weight = torch.zeros_like(param.data)
                ori_shape = _weight.shape
                ops.trans_w16_gemm(_weight, param.data, _weight.shape[0],
                                   _weight.shape[1])
                param.data.copy_(_weight)
                param.data = param.data.reshape(ori_shape[1], -1)

        if self.token_map is not None:
            # Tensor.to() is not in-place: the result has to be assigned,
            # otherwise the token map stays on its original device.
            self.token_map.data = self.token_map.data.to(
                device=self.lm_heads[0].weight.device)

        assert (self.truncated_vocab_size == self.orig_vocab_size) or (
            self.token_map is not None
        ), "truncated_vocab_size < vocab_size requires a token_map tensor"

        return loaded_params