Unverified Commit 2003cc35 authored by Jee Jee Li's avatar Jee Jee Li Committed by GitHub
Browse files

[Model][LoRA]LoRA support added for LlamaEmbeddingModel (#10071)


Signed-off-by: default avatarJee Jee Li <pandaleefree@gmail.com>
parent 6a585a23
......@@ -333,7 +333,7 @@ Text Embedding
* - :code:`MistralModel`
- Mistral-based
- :code:`intfloat/e5-mistral-7b-instruct`, etc.
-
- ✅︎
- ✅︎
.. important::
......
......@@ -627,7 +627,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
return name, loaded_weight
class LlamaEmbeddingModel(nn.Module, SupportsPP):
class LlamaEmbeddingModel(nn.Module, SupportsLoRA, SupportsPP):
"""
A model that uses Llama with additional embedding functionalities.
......@@ -638,6 +638,19 @@ class LlamaEmbeddingModel(nn.Module, SupportsPP):
model: An instance of LlamaModel used for forward operations.
_pooler: An instance of Pooler used for pooling operations.
"""
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"]
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens"
]
embedding_modules = {
"embed_tokens": "input_embeddings",
}
embedding_padding_modules = []
def __init__(
self,
......@@ -679,3 +692,8 @@ class LlamaEmbeddingModel(nn.Module, SupportsPP):
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
self.model.load_kv_cache_scales(quantization_param_path)
# LRUCacheWorkerLoRAManager instantiation requires model config.
@property
def config(self):
return self.model.config
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment