Unverified Commit ee93f4f9 authored by Qubitium-ModelCloud's avatar Qubitium-ModelCloud Committed by GitHub
Browse files

[CORE] Quantized lm-head Framework (#4442)


Co-authored-by: default avatarRobert Shaw <rshaw@neuralmagic.com>
Co-authored-by: default avatarZX <zx@lbx.dev>
parent 7c008c51
...@@ -366,6 +366,7 @@ class Phi3SmallForCausalLM(nn.Module): ...@@ -366,6 +366,7 @@ class Phi3SmallForCausalLM(nn.Module):
config.hidden_size, config.hidden_size,
org_num_embeddings=config.vocab_size, org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE, padding_size=DEFAULT_VOCAB_PADDING_SIZE,
quant_config=quant_config,
) )
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
...@@ -400,7 +401,7 @@ class Phi3SmallForCausalLM(nn.Module): ...@@ -400,7 +401,7 @@ class Phi3SmallForCausalLM(nn.Module):
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor: sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata) sampling_metadata)
if self.dummy_token_indices is not None and logits is not None: if self.dummy_token_indices is not None and logits is not None:
logits.index_fill_(-1, self.dummy_token_indices, -torch.inf) logits.index_fill_(-1, self.dummy_token_indices, -torch.inf)
......
...@@ -365,7 +365,9 @@ class Phi3VForCausalLM(nn.Module, SupportsVision): ...@@ -365,7 +365,9 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
self.model = LlamaModel(config, cache_config, quant_config) self.model = LlamaModel(config, cache_config, quant_config)
self.vision_embed_tokens = Phi3HDImageEmbedding( self.vision_embed_tokens = Phi3HDImageEmbedding(
vlm_config, config, self.model.embed_tokens) vlm_config, config, self.model.embed_tokens)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
...@@ -409,7 +411,7 @@ class Phi3VForCausalLM(nn.Module, SupportsVision): ...@@ -409,7 +411,7 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor: sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
......
...@@ -235,7 +235,9 @@ class QWenLMHeadModel(nn.Module): ...@@ -235,7 +235,9 @@ class QWenLMHeadModel(nn.Module):
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.transformer = QWenModel(config, cache_config, quant_config) self.transformer = QWenModel(config, cache_config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
...@@ -253,7 +255,7 @@ class QWenLMHeadModel(nn.Module): ...@@ -253,7 +255,7 @@ class QWenLMHeadModel(nn.Module):
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor: sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
......
...@@ -316,11 +316,11 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA): ...@@ -316,11 +316,11 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
self.model = Qwen2Model(config, cache_config, quant_config) self.model = Qwen2Model(config, cache_config, quant_config)
if config.tie_word_embeddings: if config.tie_word_embeddings:
self.lm_head_weight = self.model.embed_tokens.weight self.lm_head = self.model.embed_tokens
else: else:
self.lm_head = ParallelLMHead(config.vocab_size, self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size) config.hidden_size,
self.lm_head_weight = self.lm_head.weight quant_config=quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
...@@ -339,7 +339,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA): ...@@ -339,7 +339,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor: sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head_weight, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
......
...@@ -362,7 +362,9 @@ class Qwen2MoeForCausalLM(nn.Module): ...@@ -362,7 +362,9 @@ class Qwen2MoeForCausalLM(nn.Module):
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.model = Qwen2MoeModel(config, cache_config, quant_config) self.model = Qwen2MoeModel(config, cache_config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
...@@ -380,7 +382,7 @@ class Qwen2MoeForCausalLM(nn.Module): ...@@ -380,7 +382,7 @@ class Qwen2MoeForCausalLM(nn.Module):
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor: sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
......
...@@ -240,7 +240,9 @@ class StablelmForCausalLM(nn.Module): ...@@ -240,7 +240,9 @@ class StablelmForCausalLM(nn.Module):
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.model = StableLMEpochModel(config, cache_config, quant_config) self.model = StableLMEpochModel(config, cache_config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
...@@ -258,7 +260,7 @@ class StablelmForCausalLM(nn.Module): ...@@ -258,7 +260,7 @@ class StablelmForCausalLM(nn.Module):
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor: sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
......
...@@ -242,7 +242,7 @@ class Starcoder2ForCausalLM(nn.Module): ...@@ -242,7 +242,7 @@ class Starcoder2ForCausalLM(nn.Module):
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.unpadded_vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size
if config.tie_word_embeddings: if config.tie_word_embeddings:
self.lm_head_weight = self.model.embed_tokens.weight self.lm_head = self.model.embed_tokens
else: else:
self.unpadded_vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size
self.lm_head = ParallelLMHead( self.lm_head = ParallelLMHead(
...@@ -250,8 +250,8 @@ class Starcoder2ForCausalLM(nn.Module): ...@@ -250,8 +250,8 @@ class Starcoder2ForCausalLM(nn.Module):
config.hidden_size, config.hidden_size,
org_num_embeddings=config.vocab_size, org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE, padding_size=DEFAULT_VOCAB_PADDING_SIZE,
quant_config=quant_config,
) )
self.lm_head_weight = self.lm_head.weight
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size) config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
...@@ -270,7 +270,7 @@ class Starcoder2ForCausalLM(nn.Module): ...@@ -270,7 +270,7 @@ class Starcoder2ForCausalLM(nn.Module):
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor: sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head_weight, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
......
...@@ -310,7 +310,9 @@ class XverseForCausalLM(nn.Module, SupportsLoRA): ...@@ -310,7 +310,9 @@ class XverseForCausalLM(nn.Module, SupportsLoRA):
self.quant_config = quant_config self.quant_config = quant_config
self.model = XverseModel(config, cache_config, quant_config) self.model = XverseModel(config, cache_config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
...@@ -328,7 +330,7 @@ class XverseForCausalLM(nn.Module, SupportsLoRA): ...@@ -328,7 +330,7 @@ class XverseForCausalLM(nn.Module, SupportsLoRA):
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor: sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment