Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
07feecde
Unverified
Commit
07feecde
authored
Jun 18, 2024
by
sergey-tinkoff
Committed by
GitHub
Jun 18, 2024
Browse files
[Model] LoRA support added for command-r (#5178)
parent
19091efc
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
50 additions
and
6 deletions
+50
-6
csrc/punica/bgmv/bgmv_config.h
csrc/punica/bgmv/bgmv_config.h
+6
-0
tests/lora/test_punica.py
tests/lora/test_punica.py
+2
-0
vllm/model_executor/models/commandr.py
vllm/model_executor/models/commandr.py
+42
-6
No files found.
csrc/punica/bgmv/bgmv_config.h
100644 → 100755
View file @
07feecde
...
@@ -69,6 +69,8 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
...
@@ -69,6 +69,8 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, narrow, 36864) \
f(in_T, out_T, W_T, narrow, 36864) \
f(in_T, out_T, W_T, narrow, 43264) \
f(in_T, out_T, W_T, narrow, 43264) \
f(in_T, out_T, W_T, narrow, 49152) \
f(in_T, out_T, W_T, narrow, 49152) \
f(in_T, out_T, W_T, narrow, 60544) \
f(in_T, out_T, W_T, narrow, 60672) \
f(in_T, out_T, W_T, narrow, 64000) \
f(in_T, out_T, W_T, narrow, 64000) \
f(in_T, out_T, W_T, narrow, 64256) \
f(in_T, out_T, W_T, narrow, 64256) \
f(in_T, out_T, W_T, narrow, 64512) \
f(in_T, out_T, W_T, narrow, 64512) \
...
@@ -78,6 +80,8 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
...
@@ -78,6 +80,8 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, narrow, 128000) \
f(in_T, out_T, W_T, narrow, 128000) \
f(in_T, out_T, W_T, narrow, 128256) \
f(in_T, out_T, W_T, narrow, 128256) \
f(in_T, out_T, W_T, narrow, 128512) \
f(in_T, out_T, W_T, narrow, 128512) \
// Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA
// Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA
// and vllm/tests/lora/test_punica.py
// and vllm/tests/lora/test_punica.py
...
@@ -144,6 +148,8 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
...
@@ -144,6 +148,8 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, 36864, narrow) \
f(in_T, out_T, W_T, 36864, narrow) \
f(in_T, out_T, W_T, 43264, narrow) \
f(in_T, out_T, W_T, 43264, narrow) \
f(in_T, out_T, W_T, 49152, narrow) \
f(in_T, out_T, W_T, 49152, narrow) \
f(in_T, out_T, W_T, 60544, narrow) \
f(in_T, out_T, W_T, 60672, narrow) \
f(in_T, out_T, W_T, 64000, narrow) \
f(in_T, out_T, W_T, 64000, narrow) \
f(in_T, out_T, W_T, 64256, narrow) \
f(in_T, out_T, W_T, 64256, narrow) \
f(in_T, out_T, W_T, 64512, narrow) \
f(in_T, out_T, W_T, 64512, narrow) \
...
...
tests/lora/test_punica.py
View file @
07feecde
...
@@ -94,6 +94,8 @@ H1 = H2 = [
...
@@ -94,6 +94,8 @@ H1 = H2 = [
36864
,
36864
,
43264
,
43264
,
49152
,
49152
,
60544
,
60672
,
64000
,
64000
,
64256
,
64256
,
102400
,
102400
,
...
...
vllm/model_executor/models/commandr.py
View file @
07feecde
...
@@ -29,7 +29,7 @@ from torch.nn.parameter import Parameter
...
@@ -29,7 +29,7 @@ from torch.nn.parameter import Parameter
from
transformers
import
CohereConfig
from
transformers
import
CohereConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
from
vllm.config
import
CacheConfig
,
LoRAConfig
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
...
@@ -265,10 +265,14 @@ class CohereModel(nn.Module):
...
@@ -265,10 +265,14 @@ class CohereModel(nn.Module):
config
:
CohereConfig
,
config
:
CohereConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
self
.
vocab_size
=
config
.
vocab_size
lora_vocab
=
(
lora_config
.
lora_extra_vocab_size
*
(
lora_config
.
max_loras
or
1
))
if
lora_config
else
0
self
.
vocab_size
=
config
.
vocab_size
+
lora_vocab
self
.
org_vocab_size
=
config
.
vocab_size
self
.
embed_tokens
=
VocabParallelEmbedding
(
config
.
vocab_size
,
self
.
embed_tokens
=
VocabParallelEmbedding
(
config
.
vocab_size
,
config
.
hidden_size
)
config
.
hidden_size
)
self
.
layers
=
nn
.
ModuleList
([
self
.
layers
=
nn
.
ModuleList
([
...
@@ -302,18 +306,44 @@ class CohereModel(nn.Module):
...
@@ -302,18 +306,44 @@ class CohereModel(nn.Module):
class
CohereForCausalLM
(
nn
.
Module
):
class
CohereForCausalLM
(
nn
.
Module
):
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
,
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
,
],
}
# LoRA specific attributes
supported_lora_modules
=
[
"qkv_proj"
,
"o_proj"
,
"gate_up_proj"
,
"down_proj"
,
"embed_tokens"
]
embedding_modules
=
{
"embed_tokens"
:
"input_embeddings"
}
embedding_padding_modules
=
[]
def
__init__
(
def
__init__
(
self
,
self
,
config
:
CohereConfig
,
config
:
CohereConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
self
.
unpadded_vocab_size
=
config
.
vocab_size
if
lora_config
:
self
.
unpadded_vocab_size
+=
lora_config
.
lora_extra_vocab_size
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
,
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
,
scale
=
config
.
logit_scale
)
scale
=
config
.
logit_scale
)
self
.
model
=
CohereModel
(
config
,
cache_config
,
quant_config
)
self
.
model
=
CohereModel
(
config
,
cache_config
,
quant_config
,
lora_config
=
lora_config
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
@
torch
.
no_grad
()
@
torch
.
no_grad
()
...
@@ -330,8 +360,14 @@ class CohereForCausalLM(nn.Module):
...
@@ -330,8 +360,14 @@ class CohereForCausalLM(nn.Module):
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
model
.
embed_tokens
.
weight
,
is_not_lora
=
hasattr
(
self
.
model
.
embed_tokens
,
'weight'
)
hidden_states
,
sampling_metadata
)
if
is_not_lora
:
embedding_weights
=
self
.
model
.
embed_tokens
.
weight
else
:
embedding_weights
=
self
.
model
.
embed_tokens
.
base_layer
.
weight
logits
=
self
.
logits_processor
(
embedding_weights
,
hidden_states
,
sampling_metadata
)
return
logits
return
logits
def
sample
(
def
sample
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment