"""Convert trained Medusa head weights into vLLM's Medusa model format.

Maps the parameter names produced by the Medusa training code
(``medusa_head.{head}.{layer}.linear.*`` / ``medusa_head.{head}.1.weight``)
onto the names expected by vLLM (``blocks.{head}.layers.{layer}.*`` /
``lm_heads.{head}.weight``), then saves the result as ``model.safetensors``
plus a ``config.json``.
"""

import argparse
import ast
import os
from pathlib import Path
from typing import Iterable, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from transformers import PretrainedConfig
from safetensors.torch import save_model, safe_open

from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs

DEFAULT_VOCAB_PADDING_SIZE = 64

# Parameter-name templates in the trained Medusa checkpoint.
TRAINED_BLOCK_WEIGHT_NAME_TEMPLATE = 'medusa_head.{}.{}.linear.weight'
TRAINED_MEDUSA_HEADS_NAME_TEMPLATE = 'medusa_head.{}.1.weight'
TRAINED_BLOCK_BIAS_NAME_TEMPLATE = 'medusa_head.{}.{}.linear.bias'
# Parameter-name templates expected by the vLLM Medusa model.
VLLM_BLOCK_WEIGHT_NAME_TEMPLATE = 'blocks.{}.layers.{}.weight'
VLLM_BLOCK_BIAS_NAME_TEMPLATE = 'blocks.{}.layers.{}.bias'
VLLM_MEDUSA_HEADS_WEIGHT_NAME_TEMPLATE = 'lm_heads.{}.weight'


def default_weight_loader(param: torch.Tensor,
                          loaded_weight: torch.Tensor) -> None:
    """Default weight loader."""
    assert param.size() == loaded_weight.size()
    param.data.copy_(loaded_weight)


def pad_vocab_size(vocab_size: int,
                   pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
    """Pad the vocab size to the given value."""
    return ((vocab_size + pad_to - 1) // pad_to) * pad_to


class MedusaConfig(PretrainedConfig):
    model_type = "medusa"

    def __init__(self,
                 hidden_size: int = 4096,
                 vocab_size: int = 32001,
                 num_heads: int = 5,
                 num_hidden_layers: int = 1,
                 max_paths: int = 64,
                 topk: int = 10,
                 truncated_vocab_size: Optional[int] = None,
                 **kwargs):
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.num_heads = num_heads
        self.num_hidden_layers = num_hidden_layers
        self.max_paths = max_paths
        self.topk = topk
        self.max_seq_len = int(2**20)
        self.truncated_vocab_size = vocab_size if truncated_vocab_size is None \
            else truncated_vocab_size
        if "architectures" not in kwargs:
            kwargs["architectures"] = ["MedusaModel"]
        super().__init__(**kwargs)

    @property
    def num_attention_heads(self):
        return 0

    @property
    def num_lookahead_tokens(self):
        return self.num_heads

    @num_lookahead_tokens.setter
    def num_lookahead_tokens(self, num_lookahead_tokens: int):
        self.num_heads = num_lookahead_tokens


class VocabParallelEmbedding(torch.nn.Module):
    """Embedding parallelized in the vocabulary dimension.

    Adapted from torch.nn.Embedding, note that we pad the vocabulary size to
    make sure it is divisible by the number of model parallel GPUs.

    In order to support various loading methods, we ensure that LoRA-added
    embeddings are always at the end of TP-sharded tensors. In other words,
    we shard base embeddings and LoRA embeddings separately (both padded),
    and place them in the same tensor.
    In this example, we will have the original vocab size = 1010,
    added vocab size = 16 and padding to 64. Therefore, the total
    vocab size with padding will be 1088 (because we first pad 1010 to
    1024, add 16, and then pad to 1088).
    Therefore, the tensor format looks like the following:

    TP1, rank 0 (no sharding):
                            |< --------BASE-------- >|< -BASE PADDING-- >|< -----LORA------ >|< -LORA PADDING-- >|
    corresponding token_id: |  0  |  1  | ... | 1009 |  -1  | ... |  -1  | 1010 | ... | 1015 |  -1  | ... |  -1  |
                     index: |  0  |  1  | ... | 1009 | 1010 | ... | 1023 | 1024 | ... | 1039 | 1040 | ... | 1087 |

    TP2, rank 0:
                            |< --------------------BASE--------------------- >|< -----LORA------ >|< -LORA PADDING- >|
    corresponding token_id: |  0  |  1  |  2  | ... | 497 | 498 | ... | 511 | 1000 | ... | 1015 |  -1  | ... |  -1 |
                     index: |  0  |  1  |  2  | ... | 497 | 498 | ... | 511 | 512  | ... | 527  | 520  | ... | 543 |

    TP2, rank 1:
                            |< -----------BASE----------- >|< -BASE PADDING- >|< -----------LORA PADDING----------- >|
    corresponding token_id: | 512 | 513 | 514 | ... | 1009 | -1  | ... | -1  |  -1  | ... |  -1  | -1  | ... |  -1 |
                     index: |  0  |  1  |  2  | ... | 497  | 498 | ... | 511 | 512  | ... | 519  | 520 | ... | 543 |

    Args:
        num_embeddings: vocabulary size.
        embedding_dim: size of hidden state.
        params_dtype: type of the parameters.
        org_num_embeddings: original vocabulary size (without LoRA).
        padding_size: padding size for the vocabulary.
        quant_config: quant config for the layer
        prefix: full name of the layer in the state dict
    """  # noqa: E501

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 params_dtype: Optional[torch.dtype] = None,
                 org_num_embeddings: Optional[int] = None,
                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        self.num_embeddings = num_embeddings
        self.padding_size = padding_size
        self.org_vocab_size = org_num_embeddings or num_embeddings
        num_added_embeddings = num_embeddings - self.org_vocab_size
        self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size,
                                                    self.padding_size)
        self.num_embeddings_padded = pad_vocab_size(
            self.org_vocab_size_padded + num_added_embeddings,
            self.padding_size)
        assert self.org_vocab_size_padded <= self.num_embeddings_padded

        self.embedding_dim = embedding_dim

        linear_method = None
        if quant_config is not None:
            linear_method = quant_config.get_quant_method(self, prefix=prefix)
        if linear_method is None:
            linear_method = UnquantizedLinearMethod()
        self.linear_method: QuantizeMethodBase = linear_method

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()

        self.linear_method.create_weights(self,
                                          self.embedding_dim,
                                          [self.num_embeddings_padded],
                                          self.embedding_dim,
                                          self.num_embeddings_padded,
                                          params_dtype=params_dtype,
                                          weight_loader=self.weight_loader)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        assert param.data.shape == loaded_weight.shape
        param.data.copy_(loaded_weight)

    def forward(self, input_):
        masked_input = input_
        # Get the embeddings.
        output = F.embedding(masked_input.long(), self.weight)
        return output


class ParallelLMHead(VocabParallelEmbedding):
    """Parallelized LM head.

    Output logits weight matrices used in the Sampler. The weight and bias
    tensors are padded to make sure they are divisible by the number of
    model parallel GPUs.

    Args:
        num_embeddings: vocabulary size.
        embedding_dim: size of hidden state.
        bias: whether to use bias.
        params_dtype: type of the parameters.
        org_num_embeddings: original vocabulary size (without LoRA).
        padding_size: padding size for the vocabulary.
""" def __init__(self, num_embeddings: int, embedding_dim: int, bias: bool = False, params_dtype: Optional[torch.dtype] = None, org_num_embeddings: Optional[int] = None, padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, quant_config: Optional[QuantizationConfig] = None, prefix: str = ""): super().__init__(num_embeddings, embedding_dim, params_dtype, org_num_embeddings, padding_size, quant_config, prefix) if bias: self.bias = Parameter( torch.empty(self.num_embeddings_per_partition, dtype=params_dtype)) set_weight_attrs(self.bias, { "output_dim": 0, "weight_loader": self.weight_loader, }) else: self.register_parameter("bias", None) def forward(self, input_): del input_ raise RuntimeError("LMHead's weights should be used in the sampler.") class ResidualBlock(nn.Module): def __init__(self, hidden_size: int, num_layers: int) -> None: super().__init__() self.layers = nn.ModuleList([ nn.Linear(hidden_size, hidden_size) for _ in range(num_layers) ]) self.act = nn.SiLU() def forward(self, x: torch.Tensor) -> torch.Tensor: for layer in self.layers: x = x + self.act(layer(x)) return x class Medusa(nn.Module): def __init__(self, config: MedusaConfig, **_) -> None: super().__init__() self.config = config self.blocks = nn.ModuleList([ ResidualBlock(hidden_size=self.config.hidden_size, num_layers=self.config.num_hidden_layers) for _ in range(self.config.num_heads) ]) self.orig_vocab_size = config.vocab_size self.truncated_vocab_size = config.truncated_vocab_size self.unpadded_vocab_size = self.truncated_vocab_size self.lm_heads = nn.ModuleList([ ParallelLMHead( self.unpadded_vocab_size, config.hidden_size, org_num_embeddings=self.truncated_vocab_size, padding_size=DEFAULT_VOCAB_PADDING_SIZE, ) for _ in range(self.config.num_heads) ]) logit_scale = getattr(config, "logit_scale", 1.0) self.token_map = None def forward(self, hidden_states: torch.Tensor) -> List[torch.Tensor]: return [block(hidden_states) for block in self.blocks] def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters()) weights_map = {} for name, loaded_weight in weights: name = name.replace("medusa_heads.", "") if name == "token_map": if self.truncated_vocab_size < self.orig_vocab_size: self.token_map = nn.Parameter(loaded_weight, requires_grad=False) elif name in params_dict: weights_map[name] = loaded_weight for name, loaded_weight in weights_map.items(): if "lm_head" in name and self.token_map is not None and\ loaded_weight.shape[0] > self.token_map.shape[0]: loaded_weight = loaded_weight[self.token_map] param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) if self.token_map is not None: self.token_map.to(device=self.lm_heads[0].weight.device) assert (self.truncated_vocab_size == self.orig_vocab_size) or (self.token_map is not None) class CustomMedusaConfig(PretrainedConfig): model_type = "medusa" def __init__(self, name_or_path: str = "S-3000/vllm-medusa-qwen1.5-7b-chat", architectures: list[str] = ["MedusaModel"], hidden_size: int = 4096, model_type: str = "medusa", num_heads: int = 5, num_hidden_layers: int = 1, transformers_version: str = "4.41.2", truncated_vocab_size: Optional[int] = None, vocab_size: int = 151936, medusa_choices:List[List[int]] = None, **kwargs): super().__init__(**kwargs) self._name_or_path = name_or_path self.architectures = architectures self.hidden_size = hidden_size self.model_type = model_type self.num_heads = num_heads self.num_hidden_layers = num_hidden_layers 
        self.transformers_version = transformers_version
        self.truncated_vocab_size = truncated_vocab_size
        self.vocab_size = vocab_size
        self.medusa_choices = medusa_choices


def main(args):
    medusa_head_num = args.medusa_num_heads
    medusa_num_layers = args.medusa_num_layers

    config = MedusaConfig(hidden_size=args.hidden_size,
                          vocab_size=args.vocab_size,
                          num_heads=medusa_head_num)
    medusa_model = Medusa(config)
    params_dict = dict(medusa_model.named_parameters())

    trained_medusa_model = torch.load(args.medusa_model_path)

    # Load the per-head LM head weights.
    for i in range(medusa_head_num):
        vllm_medusa_head_weight_name = \
            VLLM_MEDUSA_HEADS_WEIGHT_NAME_TEMPLATE.format(i)
        trained_medusa_head_weight_name = \
            TRAINED_MEDUSA_HEADS_NAME_TEMPLATE.format(i)
        vllm_medusa_head_param = params_dict[vllm_medusa_head_weight_name]
        trained_medusa_head_param = \
            trained_medusa_model[trained_medusa_head_weight_name]
        weight_loader = getattr(vllm_medusa_head_param, "weight_loader",
                                default_weight_loader)
        weight_loader(vllm_medusa_head_param, trained_medusa_head_param)

    # Load the residual-block weights and biases for every head and layer.
    for i in range(medusa_head_num):
        for j in range(medusa_num_layers):
            # load linear weight
            vllm_medusa_block_weight_name = \
                VLLM_BLOCK_WEIGHT_NAME_TEMPLATE.format(i, j)
            trained_medusa_block_weight_name = \
                TRAINED_BLOCK_WEIGHT_NAME_TEMPLATE.format(i, j)
            vllm_medusa_block_param = \
                params_dict[vllm_medusa_block_weight_name]
            trained_medusa_block_param = \
                trained_medusa_model[trained_medusa_block_weight_name]
            weight_loader = getattr(vllm_medusa_block_param, "weight_loader",
                                    default_weight_loader)
            weight_loader(vllm_medusa_block_param,
                          trained_medusa_block_param)

            # load linear bias
            vllm_medusa_block_bias_name = \
                VLLM_BLOCK_BIAS_NAME_TEMPLATE.format(i, j)
            trained_medusa_block_bias_name = \
                TRAINED_BLOCK_BIAS_NAME_TEMPLATE.format(i, j)
            vllm_medusa_block_bias_param = \
                params_dict[vllm_medusa_block_bias_name]
            trained_medusa_block_bias_param = \
                trained_medusa_model[trained_medusa_block_bias_name]
            weight_loader = getattr(vllm_medusa_block_bias_param,
                                    "weight_loader", default_weight_loader)
            weight_loader(vllm_medusa_block_bias_param,
                          trained_medusa_block_bias_param)

    if not Path(args.output_dir).is_dir():
        os.makedirs(args.output_dir, exist_ok=True)

    save_model(medusa_model,
               os.path.join(args.output_dir, "model.safetensors"))

    medusa_choices = ast.literal_eval(args.medusa_choices) \
        if args.medusa_choices is not None else None
    to_save_config = CustomMedusaConfig(
        name_or_path=os.path.join(args.output_dir, "config.json"),
        hidden_size=args.hidden_size,
        num_heads=medusa_head_num,
        num_hidden_layers=medusa_num_layers,
        vocab_size=args.vocab_size,
        medusa_choices=medusa_choices)
    to_save_config.save_pretrained(args.output_dir)

    # validate weight
    # with safe_open(os.path.join(args.output_dir, "model.safetensors"),
    #                framework="pt") as f:
    #     param = f.get_tensor(VLLM_BLOCK_WEIGHT_NAME_TEMPLATE.format(3, 0))
    #     trained_param = \
    #         trained_medusa_model[TRAINED_BLOCK_WEIGHT_NAME_TEMPLATE.format(3, 0)]
    #     mse_value = torch.nn.functional.mse_loss(param.cpu(),
    #                                              trained_param.cpu())
    #     print("weight mse:", mse_value)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Convert a trained Medusa head checkpoint into the "
        "vLLM Medusa model format.")
    parser.add_argument("--medusa_model_path",
                        type=str,
                        required=True,
                        help="Path to the medusa model file.")
    parser.add_argument("--vocab_size",
                        type=int,
                        required=True,
                        help="Vocab size")
    parser.add_argument("--medusa_num_heads",
                        type=int,
                        required=True,
                        help="Number of Medusa heads")
    parser.add_argument("--medusa_num_layers",
                        type=int,
                        required=True,
                        help="Number of Medusa layers")
    parser.add_argument("--hidden_size",
                        type=int,
                        required=True,
                        help="Hidden size")
    parser.add_argument("--output_dir",
                        type=str,
                        required=True,
                        help="Output dir")
    parser.add_argument(
        '--medusa_choices',
        type=str,
        default=None,
        help="Medusa choice to use, if not none, will use Medusa decoding."
        " E.g.: [[0, 0, 0, 0], [0, 1, 0], [1, 0], [1, 1]] for 9 medusa tokens."
    )
    args = parser.parse_args()
    main(args)