Unverified Commit ee93f4f9 authored by Qubitium-ModelCloud's avatar Qubitium-ModelCloud Committed by GitHub
Browse files

[CORE] Quantized lm-head Framework (#4442)


Co-authored-by: default avatarRobert Shaw <rshaw@neuralmagic.com>
Co-authored-by: default avatarZX <zx@lbx.dev>
parent 7c008c51
......@@ -347,8 +347,8 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA):
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.model.embed_tokens.weight,
hidden_states, sampling_metadata)
logits = self.logits_processor(self.model.embed_tokens, hidden_states,
sampling_metadata)
return logits
def sample(
......
......@@ -346,8 +346,8 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA):
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.model.embed_tokens.weight,
hidden_states, sampling_metadata)
logits = self.logits_processor(self.model.embed_tokens, hidden_states,
sampling_metadata)
return logits
def sample(
......
......@@ -238,7 +238,7 @@ class GPT2LMHeadModel(nn.Module):
self.config = config
self.quant_config = quant_config
self.transformer = GPT2Model(config, cache_config, quant_config)
self.lm_head_weight = self.transformer.wte.weight
self.lm_head = self.transformer.wte
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
......@@ -256,7 +256,7 @@ class GPT2LMHeadModel(nn.Module):
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head_weight, hidden_states,
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
......
......@@ -259,7 +259,7 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA):
self.quant_config = quant_config
self.transformer = GPTBigCodeModel(config, cache_config, quant_config,
lora_config)
self.lm_head_weight = self.transformer.wte.weight
self.lm_head = self.transformer.wte
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
......@@ -281,7 +281,7 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA):
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head_weight, hidden_states,
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
......
......@@ -229,6 +229,7 @@ class GPTJForCausalLM(nn.Module):
config.vocab_size,
config.n_embd,
bias=True,
quant_config=quant_config,
)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
......@@ -247,7 +248,7 @@ class GPTJForCausalLM(nn.Module):
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata, self.lm_head.bias)
return logits
......
......@@ -241,6 +241,7 @@ class GPTNeoXForCausalLM(nn.Module):
self.embed_out = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
......@@ -259,7 +260,7 @@ class GPTNeoXForCausalLM(nn.Module):
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.embed_out.weight, hidden_states,
logits = self.logits_processor(self.embed_out, hidden_states,
sampling_metadata)
return logits
......
......@@ -253,7 +253,9 @@ class InternLM2ForCausalLM(nn.Module):
self.config = config
self.quant_config = quant_config
self.model = InternLM2Model(config, cache_config, quant_config)
self.output = ParallelLMHead(config.vocab_size, config.hidden_size)
self.output = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
......@@ -271,7 +273,7 @@ class InternLM2ForCausalLM(nn.Module):
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.output.weight, hidden_states,
logits = self.logits_processor(self.output, hidden_states,
sampling_metadata)
return logits
......
......@@ -273,7 +273,7 @@ class JAISLMHeadModel(nn.Module):
self.config = config
self.quant_config = quant_config
self.transformer = JAISModel(config, cache_config, quant_config)
self.lm_head_weight = self.transformer.wte.weight
self.lm_head = self.transformer.wte
if hasattr(config, "width_scale"):
self.output_logits_scale = config.width_scale
else:
......@@ -297,7 +297,7 @@ class JAISLMHeadModel(nn.Module):
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head_weight, hidden_states,
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
......
......@@ -380,6 +380,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config else lora_config.lora_vocab_padding_size,
quant_config=quant_config,
)
if config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
......@@ -403,7 +404,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
......
......@@ -125,7 +125,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.text_config.hidden_size,
org_num_embeddings=self.language_model.org_vocab_size)
org_num_embeddings=self.language_model.org_vocab_size,
quant_config=quant_config)
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size, logit_scale)
......@@ -255,7 +256,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
......
......@@ -186,7 +186,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.text_config.hidden_size,
org_num_embeddings=self.language_model.org_vocab_size)
org_num_embeddings=self.language_model.org_vocab_size,
quant_config=quant_config)
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size, logit_scale)
......@@ -438,7 +439,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
......
......@@ -449,6 +449,7 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config else lora_config.lora_vocab_padding_size,
quant_config=quant_config,
)
self.scale_width = self.config.hidden_size / self.config.dim_model_base
......@@ -472,10 +473,10 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
sampling_metadata: SamplingMetadata) -> torch.Tensor:
hidden_states = hidden_states / self.scale_width
if self.config.tie_word_embeddings:
lm_head_weight = self.model.embed_tokens.weight
lm_head = self.model.embed_tokens
else:
lm_head_weight = self.lm_head.weight
logits = self.logits_processor(lm_head_weight, hidden_states,
lm_head = self.lm_head
logits = self.logits_processor(lm_head, hidden_states,
sampling_metadata)
return logits
......
......@@ -331,6 +331,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config else lora_config.lora_vocab_padding_size,
quant_config=quant_config,
)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
......@@ -350,7 +351,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
......
......@@ -344,7 +344,9 @@ class MixtralForCausalLM(nn.Module):
self.config = config
self.quant_config = quant_config
self.model = MixtralModel(config, cache_config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
......@@ -362,7 +364,7 @@ class MixtralForCausalLM(nn.Module):
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
......
......@@ -8,7 +8,7 @@ from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs import MLPSpeculatorConfig
......@@ -87,7 +87,7 @@ class MLPSpeculator(nn.Module):
self.proj = nn.ModuleList([proj_first] + [proj_tied] *
(self.max_speculative_tokens - 1))
head = nn.Linear(self.inner_dim, self.vocab_size, bias=False)
head = ParallelLMHead(self.vocab_size, self.inner_dim, bias=False)
self.head = nn.ModuleList([head] * self.max_speculative_tokens)
ln = MLPSpeculatorLayerNorm(self.inner_dim,
......@@ -169,8 +169,8 @@ class MLPSpeculator(nn.Module):
# TODO: not yet supporting top_k_tokens_per_head
previous_hidden_states = states
logits = self.logits_processor(self.head[head_index].weight,
states, sampling_metadata)
logits = self.logits_processor(self.head[head_index], states,
sampling_metadata)
output = self.sampler(logits.flatten(0, 1), sampling_metadata)
last_tokens = output.sampled_token_ids
......
......@@ -263,7 +263,7 @@ class MPTForCausalLM(nn.Module):
self.quant_config = quant_config
self.transformer = MPTModel(config, cache_config, quant_config)
self.lm_head_weight = self.transformer.wte.weight
self.lm_head = self.transformer.wte
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
......@@ -281,7 +281,7 @@ class MPTForCausalLM(nn.Module):
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head_weight, hidden_states,
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
......
......@@ -283,15 +283,15 @@ class OlmoForCausalLM(nn.Module):
self.config = config
self.model = OlmoModel(config, cache_config, quant_config)
if config.tie_word_embeddings:
self.lm_head_weight = self.model.embed_tokens.weight
self.lm_head = self.model.embed_tokens
else:
self.unpadded_vocab_size = config.vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
quant_config=quant_config,
)
self.lm_head_weight = self.lm_head.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
......@@ -313,7 +313,7 @@ class OlmoForCausalLM(nn.Module):
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head_weight, hidden_states,
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
......
......@@ -294,7 +294,7 @@ class OPTForCausalLM(nn.Module):
self.config = config
self.quant_config = quant_config
self.model = OPTModel(config, cache_config, quant_config)
self.lm_head_weight = self.model.decoder.embed_tokens.weight
self.lm_head = self.model.decoder.embed_tokens
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
......@@ -312,7 +312,7 @@ class OPTForCausalLM(nn.Module):
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head_weight, hidden_states,
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
......
......@@ -259,7 +259,9 @@ class OrionForCausalLM(nn.Module):
self.config = config
self.quant_config = quant_config
self.model = OrionModel(config, cache_config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
......@@ -277,7 +279,7 @@ class OrionForCausalLM(nn.Module):
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
......
......@@ -268,7 +268,8 @@ class PhiForCausalLM(nn.Module, SupportsLoRA):
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
bias=True)
bias=True,
quant_config=quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
......@@ -287,7 +288,7 @@ class PhiForCausalLM(nn.Module, SupportsLoRA):
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata, self.lm_head.bias)
return logits
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment