Unverified commit ee93f4f9, authored by Qubitium-ModelCloud, committed by GitHub
Browse files

[CORE] Quantized lm-head Framework (#4442)


Co-authored-by: Robert Shaw <rshaw@neuralmagic.com>
Co-authored-by: ZX <zx@lbx.dev>
parent 7c008c51
...@@ -347,8 +347,8 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA): ...@@ -347,8 +347,8 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA):
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor: sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.model.embed_tokens.weight, logits = self.logits_processor(self.model.embed_tokens, hidden_states,
hidden_states, sampling_metadata) sampling_metadata)
return logits return logits
def sample( def sample(
......
...@@ -346,8 +346,8 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA): ...@@ -346,8 +346,8 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA):
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor: sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.model.embed_tokens.weight, logits = self.logits_processor(self.model.embed_tokens, hidden_states,
hidden_states, sampling_metadata) sampling_metadata)
return logits return logits
def sample( def sample(
......
...@@ -238,7 +238,7 @@ class GPT2LMHeadModel(nn.Module): ...@@ -238,7 +238,7 @@ class GPT2LMHeadModel(nn.Module):
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.transformer = GPT2Model(config, cache_config, quant_config) self.transformer = GPT2Model(config, cache_config, quant_config)
self.lm_head_weight = self.transformer.wte.weight self.lm_head = self.transformer.wte
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
...@@ -256,7 +256,7 @@ class GPT2LMHeadModel(nn.Module): ...@@ -256,7 +256,7 @@ class GPT2LMHeadModel(nn.Module):
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor: sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head_weight, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
......
...@@ -259,7 +259,7 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA): ...@@ -259,7 +259,7 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA):
self.quant_config = quant_config self.quant_config = quant_config
self.transformer = GPTBigCodeModel(config, cache_config, quant_config, self.transformer = GPTBigCodeModel(config, cache_config, quant_config,
lora_config) lora_config)
self.lm_head_weight = self.transformer.wte.weight self.lm_head = self.transformer.wte
self.unpadded_vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size
if lora_config: if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
...@@ -281,7 +281,7 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA): ...@@ -281,7 +281,7 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA):
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor: sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head_weight, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
......
...@@ -229,6 +229,7 @@ class GPTJForCausalLM(nn.Module): ...@@ -229,6 +229,7 @@ class GPTJForCausalLM(nn.Module):
config.vocab_size, config.vocab_size,
config.n_embd, config.n_embd,
bias=True, bias=True,
quant_config=quant_config,
) )
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
...@@ -247,7 +248,7 @@ class GPTJForCausalLM(nn.Module): ...@@ -247,7 +248,7 @@ class GPTJForCausalLM(nn.Module):
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor: sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata, self.lm_head.bias) sampling_metadata, self.lm_head.bias)
return logits return logits
......
...@@ -241,6 +241,7 @@ class GPTNeoXForCausalLM(nn.Module): ...@@ -241,6 +241,7 @@ class GPTNeoXForCausalLM(nn.Module):
self.embed_out = ParallelLMHead( self.embed_out = ParallelLMHead(
config.vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
quant_config=quant_config,
) )
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
...@@ -259,7 +260,7 @@ class GPTNeoXForCausalLM(nn.Module): ...@@ -259,7 +260,7 @@ class GPTNeoXForCausalLM(nn.Module):
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor: sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.embed_out.weight, hidden_states, logits = self.logits_processor(self.embed_out, hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
......
...@@ -253,7 +253,9 @@ class InternLM2ForCausalLM(nn.Module): ...@@ -253,7 +253,9 @@ class InternLM2ForCausalLM(nn.Module):
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.model = InternLM2Model(config, cache_config, quant_config) self.model = InternLM2Model(config, cache_config, quant_config)
self.output = ParallelLMHead(config.vocab_size, config.hidden_size) self.output = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
...@@ -271,7 +273,7 @@ class InternLM2ForCausalLM(nn.Module): ...@@ -271,7 +273,7 @@ class InternLM2ForCausalLM(nn.Module):
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor: sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.output.weight, hidden_states, logits = self.logits_processor(self.output, hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
......
...@@ -273,7 +273,7 @@ class JAISLMHeadModel(nn.Module): ...@@ -273,7 +273,7 @@ class JAISLMHeadModel(nn.Module):
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.transformer = JAISModel(config, cache_config, quant_config) self.transformer = JAISModel(config, cache_config, quant_config)
self.lm_head_weight = self.transformer.wte.weight self.lm_head = self.transformer.wte
if hasattr(config, "width_scale"): if hasattr(config, "width_scale"):
self.output_logits_scale = config.width_scale self.output_logits_scale = config.width_scale
else: else:
...@@ -297,7 +297,7 @@ class JAISLMHeadModel(nn.Module): ...@@ -297,7 +297,7 @@ class JAISLMHeadModel(nn.Module):
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor: sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head_weight, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
......
...@@ -380,6 +380,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA): ...@@ -380,6 +380,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
# We need bigger padding if using lora for kernel # We need bigger padding if using lora for kernel
# compatibility # compatibility
if not lora_config else lora_config.lora_vocab_padding_size, if not lora_config else lora_config.lora_vocab_padding_size,
quant_config=quant_config,
) )
if config.tie_word_embeddings: if config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight self.lm_head.weight = self.model.embed_tokens.weight
...@@ -403,7 +404,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA): ...@@ -403,7 +404,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor: sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
......
...@@ -125,7 +125,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision): ...@@ -125,7 +125,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
self.lm_head = ParallelLMHead( self.lm_head = ParallelLMHead(
self.unpadded_vocab_size, self.unpadded_vocab_size,
config.text_config.hidden_size, config.text_config.hidden_size,
org_num_embeddings=self.language_model.org_vocab_size) org_num_embeddings=self.language_model.org_vocab_size,
quant_config=quant_config)
logit_scale = getattr(config, "logit_scale", 1.0) logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size, logit_scale) config.vocab_size, logit_scale)
...@@ -255,7 +256,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision): ...@@ -255,7 +256,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor: sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
......
...@@ -186,7 +186,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision): ...@@ -186,7 +186,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
self.lm_head = ParallelLMHead( self.lm_head = ParallelLMHead(
self.unpadded_vocab_size, self.unpadded_vocab_size,
config.text_config.hidden_size, config.text_config.hidden_size,
org_num_embeddings=self.language_model.org_vocab_size) org_num_embeddings=self.language_model.org_vocab_size,
quant_config=quant_config)
logit_scale = getattr(config, "logit_scale", 1.0) logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size, logit_scale) config.vocab_size, logit_scale)
...@@ -438,7 +439,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision): ...@@ -438,7 +439,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor: sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
......
...@@ -449,6 +449,7 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA): ...@@ -449,6 +449,7 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
# We need bigger padding if using lora for kernel # We need bigger padding if using lora for kernel
# compatibility # compatibility
if not lora_config else lora_config.lora_vocab_padding_size, if not lora_config else lora_config.lora_vocab_padding_size,
quant_config=quant_config,
) )
self.scale_width = self.config.hidden_size / self.config.dim_model_base self.scale_width = self.config.hidden_size / self.config.dim_model_base
...@@ -472,10 +473,10 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA): ...@@ -472,10 +473,10 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
sampling_metadata: SamplingMetadata) -> torch.Tensor: sampling_metadata: SamplingMetadata) -> torch.Tensor:
hidden_states = hidden_states / self.scale_width hidden_states = hidden_states / self.scale_width
if self.config.tie_word_embeddings: if self.config.tie_word_embeddings:
lm_head_weight = self.model.embed_tokens.weight lm_head = self.model.embed_tokens
else: else:
lm_head_weight = self.lm_head.weight lm_head = self.lm_head
logits = self.logits_processor(lm_head_weight, hidden_states, logits = self.logits_processor(lm_head, hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
......
...@@ -331,6 +331,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA): ...@@ -331,6 +331,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
# We need bigger padding if using lora for kernel # We need bigger padding if using lora for kernel
# compatibility # compatibility
if not lora_config else lora_config.lora_vocab_padding_size, if not lora_config else lora_config.lora_vocab_padding_size,
quant_config=quant_config,
) )
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size) config.vocab_size)
...@@ -350,7 +351,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA): ...@@ -350,7 +351,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor: sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
......
...@@ -344,7 +344,9 @@ class MixtralForCausalLM(nn.Module): ...@@ -344,7 +344,9 @@ class MixtralForCausalLM(nn.Module):
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.model = MixtralModel(config, cache_config, quant_config) self.model = MixtralModel(config, cache_config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
...@@ -362,7 +364,7 @@ class MixtralForCausalLM(nn.Module): ...@@ -362,7 +364,7 @@ class MixtralForCausalLM(nn.Module):
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor: sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
......
...@@ -8,7 +8,7 @@ from vllm.model_executor import SamplingMetadata ...@@ -8,7 +8,7 @@ from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs import MLPSpeculatorConfig from vllm.transformers_utils.configs import MLPSpeculatorConfig
...@@ -87,7 +87,7 @@ class MLPSpeculator(nn.Module): ...@@ -87,7 +87,7 @@ class MLPSpeculator(nn.Module):
self.proj = nn.ModuleList([proj_first] + [proj_tied] * self.proj = nn.ModuleList([proj_first] + [proj_tied] *
(self.max_speculative_tokens - 1)) (self.max_speculative_tokens - 1))
head = nn.Linear(self.inner_dim, self.vocab_size, bias=False) head = ParallelLMHead(self.vocab_size, self.inner_dim, bias=False)
self.head = nn.ModuleList([head] * self.max_speculative_tokens) self.head = nn.ModuleList([head] * self.max_speculative_tokens)
ln = MLPSpeculatorLayerNorm(self.inner_dim, ln = MLPSpeculatorLayerNorm(self.inner_dim,
...@@ -169,8 +169,8 @@ class MLPSpeculator(nn.Module): ...@@ -169,8 +169,8 @@ class MLPSpeculator(nn.Module):
# TODO: not yet supporting top_k_tokens_per_head # TODO: not yet supporting top_k_tokens_per_head
previous_hidden_states = states previous_hidden_states = states
logits = self.logits_processor(self.head[head_index].weight, logits = self.logits_processor(self.head[head_index], states,
states, sampling_metadata) sampling_metadata)
output = self.sampler(logits.flatten(0, 1), sampling_metadata) output = self.sampler(logits.flatten(0, 1), sampling_metadata)
last_tokens = output.sampled_token_ids last_tokens = output.sampled_token_ids
......
...@@ -263,7 +263,7 @@ class MPTForCausalLM(nn.Module): ...@@ -263,7 +263,7 @@ class MPTForCausalLM(nn.Module):
self.quant_config = quant_config self.quant_config = quant_config
self.transformer = MPTModel(config, cache_config, quant_config) self.transformer = MPTModel(config, cache_config, quant_config)
self.lm_head_weight = self.transformer.wte.weight self.lm_head = self.transformer.wte
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
...@@ -281,7 +281,7 @@ class MPTForCausalLM(nn.Module): ...@@ -281,7 +281,7 @@ class MPTForCausalLM(nn.Module):
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor: sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head_weight, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
......
...@@ -283,15 +283,15 @@ class OlmoForCausalLM(nn.Module): ...@@ -283,15 +283,15 @@ class OlmoForCausalLM(nn.Module):
self.config = config self.config = config
self.model = OlmoModel(config, cache_config, quant_config) self.model = OlmoModel(config, cache_config, quant_config)
if config.tie_word_embeddings: if config.tie_word_embeddings:
self.lm_head_weight = self.model.embed_tokens.weight self.lm_head = self.model.embed_tokens
else: else:
self.unpadded_vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size
self.lm_head = ParallelLMHead( self.lm_head = ParallelLMHead(
self.unpadded_vocab_size, self.unpadded_vocab_size,
config.hidden_size, config.hidden_size,
org_num_embeddings=config.vocab_size, org_num_embeddings=config.vocab_size,
quant_config=quant_config,
) )
self.lm_head_weight = self.lm_head.weight
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
...@@ -313,7 +313,7 @@ class OlmoForCausalLM(nn.Module): ...@@ -313,7 +313,7 @@ class OlmoForCausalLM(nn.Module):
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor: sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head_weight, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
......
...@@ -294,7 +294,7 @@ class OPTForCausalLM(nn.Module): ...@@ -294,7 +294,7 @@ class OPTForCausalLM(nn.Module):
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.model = OPTModel(config, cache_config, quant_config) self.model = OPTModel(config, cache_config, quant_config)
self.lm_head_weight = self.model.decoder.embed_tokens.weight self.lm_head = self.model.decoder.embed_tokens
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
...@@ -312,7 +312,7 @@ class OPTForCausalLM(nn.Module): ...@@ -312,7 +312,7 @@ class OPTForCausalLM(nn.Module):
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor: sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head_weight, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
......
...@@ -259,7 +259,9 @@ class OrionForCausalLM(nn.Module): ...@@ -259,7 +259,9 @@ class OrionForCausalLM(nn.Module):
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.model = OrionModel(config, cache_config, quant_config) self.model = OrionModel(config, cache_config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
...@@ -277,7 +279,7 @@ class OrionForCausalLM(nn.Module): ...@@ -277,7 +279,7 @@ class OrionForCausalLM(nn.Module):
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor: sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
......
...@@ -268,7 +268,8 @@ class PhiForCausalLM(nn.Module, SupportsLoRA): ...@@ -268,7 +268,8 @@ class PhiForCausalLM(nn.Module, SupportsLoRA):
self.lm_head = ParallelLMHead(config.vocab_size, self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size, config.hidden_size,
bias=True) bias=True,
quant_config=quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
...@@ -287,7 +288,7 @@ class PhiForCausalLM(nn.Module, SupportsLoRA): ...@@ -287,7 +288,7 @@ class PhiForCausalLM(nn.Module, SupportsLoRA):
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor: sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata, self.lm_head.bias) sampling_metadata, self.lm_head.bias)
return logits return logits
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment