Unverified Commit 07feecde authored by sergey-tinkoff's avatar sergey-tinkoff Committed by GitHub
Browse files

[Model] LoRA support added for command-r (#5178)

parent 19091efc
...@@ -69,6 +69,8 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, ...@@ -69,6 +69,8 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, narrow, 36864) \ f(in_T, out_T, W_T, narrow, 36864) \
f(in_T, out_T, W_T, narrow, 43264) \ f(in_T, out_T, W_T, narrow, 43264) \
f(in_T, out_T, W_T, narrow, 49152) \ f(in_T, out_T, W_T, narrow, 49152) \
f(in_T, out_T, W_T, narrow, 60544) \
f(in_T, out_T, W_T, narrow, 60672) \
f(in_T, out_T, W_T, narrow, 64000) \ f(in_T, out_T, W_T, narrow, 64000) \
f(in_T, out_T, W_T, narrow, 64256) \ f(in_T, out_T, W_T, narrow, 64256) \
f(in_T, out_T, W_T, narrow, 64512) \ f(in_T, out_T, W_T, narrow, 64512) \
...@@ -78,6 +80,8 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, ...@@ -78,6 +80,8 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, narrow, 128000) \ f(in_T, out_T, W_T, narrow, 128000) \
f(in_T, out_T, W_T, narrow, 128256) \ f(in_T, out_T, W_T, narrow, 128256) \
f(in_T, out_T, W_T, narrow, 128512) \ f(in_T, out_T, W_T, narrow, 128512) \
// Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA // Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA
// and vllm/tests/lora/test_punica.py // and vllm/tests/lora/test_punica.py
...@@ -144,6 +148,8 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, ...@@ -144,6 +148,8 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, 36864, narrow) \ f(in_T, out_T, W_T, 36864, narrow) \
f(in_T, out_T, W_T, 43264, narrow) \ f(in_T, out_T, W_T, 43264, narrow) \
f(in_T, out_T, W_T, 49152, narrow) \ f(in_T, out_T, W_T, 49152, narrow) \
f(in_T, out_T, W_T, 60544, narrow) \
f(in_T, out_T, W_T, 60672, narrow) \
f(in_T, out_T, W_T, 64000, narrow) \ f(in_T, out_T, W_T, 64000, narrow) \
f(in_T, out_T, W_T, 64256, narrow) \ f(in_T, out_T, W_T, 64256, narrow) \
f(in_T, out_T, W_T, 64512, narrow) \ f(in_T, out_T, W_T, 64512, narrow) \
......
...@@ -94,6 +94,8 @@ H1 = H2 = [ ...@@ -94,6 +94,8 @@ H1 = H2 = [
36864, 36864,
43264, 43264,
49152, 49152,
60544,
60672,
64000, 64000,
64256, 64256,
102400, 102400,
......
...@@ -29,7 +29,7 @@ from torch.nn.parameter import Parameter ...@@ -29,7 +29,7 @@ from torch.nn.parameter import Parameter
from transformers import CohereConfig from transformers import CohereConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank, from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
...@@ -265,10 +265,14 @@ class CohereModel(nn.Module): ...@@ -265,10 +265,14 @@ class CohereModel(nn.Module):
config: CohereConfig, config: CohereConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
self.vocab_size = config.vocab_size lora_vocab = (lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(config.vocab_size, self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.hidden_size) config.hidden_size)
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
...@@ -302,18 +306,44 @@ class CohereModel(nn.Module): ...@@ -302,18 +306,44 @@ class CohereModel(nn.Module):
class CohereForCausalLM(nn.Module): class CohereForCausalLM(nn.Module):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens"
]
embedding_modules = {"embed_tokens": "input_embeddings"}
embedding_padding_modules = []
def __init__( def __init__(
self, self,
config: CohereConfig, config: CohereConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.quant_config = quant_config self.quant_config = quant_config
self.logits_processor = LogitsProcessor(config.vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size,
scale=config.logit_scale) scale=config.logit_scale)
self.model = CohereModel(config, cache_config, quant_config) self.model = CohereModel(config,
cache_config,
quant_config,
lora_config=lora_config)
self.sampler = Sampler() self.sampler = Sampler()
@torch.no_grad() @torch.no_grad()
...@@ -330,8 +360,14 @@ class CohereForCausalLM(nn.Module): ...@@ -330,8 +360,14 @@ class CohereForCausalLM(nn.Module):
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor: sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.model.embed_tokens.weight, is_not_lora = hasattr(self.model.embed_tokens, 'weight')
hidden_states, sampling_metadata) if is_not_lora:
embedding_weights = self.model.embed_tokens.weight
else:
embedding_weights = self.model.embed_tokens.base_layer.weight
logits = self.logits_processor(embedding_weights, hidden_states,
sampling_metadata)
return logits return logits
def sample( def sample(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment