Unverified Commit 31348dff authored by Philipp Moritz, committed by GitHub

Align LoRA code between Mistral and Mixtral (fixes #2875) (#2880)



* Fix AttributeError: MixtralModel object has no attribute org_vocab_size.

* Make LoRA logic for Mistral and Mixtral the same

---------
Co-authored-by: Pernekhan Utemuratov <pernekhan@deepinfra.com>
parent 25e86b6a
@@ -285,15 +285,19 @@ class MixtralModel(nn.Module):
         self,
         config: MixtralConfig,
         linear_method: Optional[LinearMethodBase] = None,
+        lora_config: Optional[LoRAConfig] = None,
     ) -> None:
         super().__init__()
         self.padding_idx = config.pad_token_id
-        self.vocab_size = config.vocab_size
+        lora_vocab = (lora_config.lora_extra_vocab_size *
+                      (lora_config.max_loras or 1)) if lora_config else 0
+        self.vocab_size = config.vocab_size + lora_vocab
+        self.org_vocab_size = config.vocab_size
         self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
+            self.vocab_size,
             config.hidden_size,
-            org_num_embeddings=self.org_vocab_size,
+            org_num_embeddings=config.vocab_size,
         )
         self.layers = nn.ModuleList([
             MixtralDecoderLayer(config, linear_method=linear_method)
@@ -350,7 +354,9 @@ class MixtralForCausalLM(nn.Module):
         super().__init__()
         self.config = config
         self.linear_method = linear_method
-        self.model = MixtralModel(config, linear_method)
+        self.model = MixtralModel(config,
+                                  linear_method,
+                                  lora_config=lora_config)
         self.unpadded_vocab_size = config.vocab_size
         if lora_config:
             self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
...
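For reference, here is a minimal standalone sketch of the vocabulary-size arithmetic introduced in the first hunk. The helper name and the example numbers (a 32000-token base vocabulary, 256 extra LoRA tokens, up to 4 adapters) are illustrative assumptions, not values taken from the commit or from Mixtral's config.

# Illustrative sketch only; mirrors the lora_vocab computation in the diff above.
def padded_vocab_sizes(base_vocab_size: int,
                       lora_extra_vocab_size: int = 0,
                       max_loras: int = 0) -> tuple:
    # Reserve lora_extra_vocab_size embedding rows per LoRA slot; the "or 1"
    # matches the diff, which reserves at least one slot's worth whenever a
    # lora_config is present.
    lora_vocab = lora_extra_vocab_size * (max_loras or 1)
    padded = base_vocab_size + lora_vocab   # passed as the embedding size
    org = base_vocab_size                   # passed as org_num_embeddings
    return padded, org

# Hypothetical example: 32000 + 256 * 4 = 33024 embedding rows, of which the
# first 32000 hold the original checkpoint weights.
print(padded_vocab_sizes(32000, lora_extra_vocab_size=256, max_loras=4))

Keeping org_num_embeddings at the base vocabulary size lets the checkpoint weights fill the original rows while the padded tail is reserved for LoRA-added tokens, which is the behaviour the Mistral model already had and which this commit replicates for Mixtral.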