vllm @ commit 929b4f29 (unverified)

Add LoRA support for Gemma (#3050)

Authored Feb 28, 2024 by Woosuk Kwon; committed by GitHub on Feb 28, 2024.
Parent: 3b7178cf

Showing 7 changed files with 82 additions and 7 deletions (+82 / -7).
  .buildkite/test-pipeline.yaml         +1   -1
  csrc/punica/bgmv/bgmv_config.h        +2   -0
  tests/lora/conftest.py                +5   -0
  tests/lora/test_gemma.py              +46  -0
  tests/lora/test_punica.py             +2   -2
  vllm/model_executor/models/gemma.py   +25  -3
  vllm/model_executor/models/llama.py   +1   -1
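Taken together, these changes let a Gemma checkpoint be served with LoRA adapters the same way Llama already is. A minimal usage sketch, modeled on the new tests/lora/test_gemma.py shown below (the adapter path and request name are placeholders, not part of this commit):

import vllm
from vllm.lora.request import LoRARequest

# Placeholder adapter path -- any Gemma LoRA adapter directory works here.
lora_path = "/path/to/gemma-lora-adapter"

llm = vllm.LLM("google/gemma-7b",
               max_model_len=1024,
               enable_lora=True,  # enable LoRA support (new for Gemma)
               max_loras=4)       # max number of adapters used in one batch

sampling_params = vllm.SamplingParams(temperature=0, max_tokens=32)
outputs = llm.generate(
    ["Quote: Imagination is"],
    sampling_params,
    # LoRARequest(name, int_id, local_path); pass None to run the base model.
    lora_request=LoRARequest("gemma-test-lora", 1, lora_path))
print(outputs[0].outputs[0].text.strip())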
.buildkite/test-pipeline.yaml

@@ -50,7 +50,7 @@ steps:
   command: pytest -v -s worker

 - label: LoRA Test
-  command: pytest -v -s lora
+  command: pytest -v -s lora --forked

 - label: Metrics Test
   command: pytest -v -s metrics
csrc/punica/bgmv/bgmv_config.h

@@ -28,6 +28,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
     f(in_T, out_T, W_T, narrow, 5120) \
     f(in_T, out_T, W_T, narrow, 5504) \
     f(in_T, out_T, W_T, narrow, 5632) \
+    f(in_T, out_T, W_T, narrow, 6144) \
     f(in_T, out_T, W_T, narrow, 6912) \
     f(in_T, out_T, W_T, narrow, 7168) \
     f(in_T, out_T, W_T, narrow, 8192) \
@@ -39,6 +40,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
     f(in_T, out_T, W_T, narrow, 14336) \
     f(in_T, out_T, W_T, narrow, 16384) \
     f(in_T, out_T, W_T, narrow, 20480) \
+    f(in_T, out_T, W_T, narrow, 24576) \
     f(in_T, out_T, W_T, narrow, 28672) \
     f(in_T, out_T, W_T, narrow, 32000) \
     f(in_T, out_T, W_T, narrow, 32256) \
tests/lora/conftest.py

@@ -126,6 +126,11 @@ def mixtral_lora_files():
     return snapshot_download(repo_id="terrysun/mixtral-lora-adapter")


+@pytest.fixture(scope="session")
+def gemma_lora_files():
+    return snapshot_download(repo_id="wskwon/gemma-7b-test-lora")
+
+
 @pytest.fixture
 def llama_2_7b_engine_extra_embeddings() -> nn.Module:
     cleanup()
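The new fixture simply caches the published test adapter locally; fetching it outside of pytest looks the same (a sketch, assuming huggingface_hub is installed and the repo stays publicly accessible):

# Sketch: download the test adapter the new fixture points at.
from huggingface_hub import snapshot_download

gemma_lora_path = snapshot_download(repo_id="wskwon/gemma-7b-test-lora")
print(gemma_lora_path)  # local directory usable as a LoRARequest path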
tests/lora/test_gemma.py (new file, mode 100644)

import vllm
from vllm.lora.request import LoRARequest

MODEL_PATH = "google/gemma-7b"


def do_sample(llm, lora_path: str, lora_id: int) -> str:
    prompts = [
        "Quote: Imagination is",
        "Quote: Be yourself;",
        "Quote: So many books,",
    ]
    sampling_params = vllm.SamplingParams(temperature=0, max_tokens=32)
    outputs = llm.generate(
        prompts,
        sampling_params,
        lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
        if lora_id else None)
    # Print the outputs.
    generated_texts = []
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text.strip()
        generated_texts.append(generated_text)
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    return generated_texts


def test_gemma_lora(gemma_lora_files):
    llm = vllm.LLM(MODEL_PATH,
                   max_model_len=1024,
                   enable_lora=True,
                   max_loras=4)

    expected_lora_output = [
        "more important than knowledge.\nAuthor: Albert Einstein\n",
        "everyone else is already taken.\nAuthor: Oscar Wilde\n",
        "so little time\nAuthor: Frank Zappa\n",
    ]

    output1 = do_sample(llm, gemma_lora_files, lora_id=1)
    for i in range(len(expected_lora_output)):
        assert output1[i].startswith(expected_lora_output[i])
    output2 = do_sample(llm, gemma_lora_files, lora_id=2)
    for i in range(len(expected_lora_output)):
        assert output2[i].startswith(expected_lora_output[i])
tests/lora/test_punica.py

@@ -44,8 +44,8 @@ def _lora_ref_impl(

 H1 = H2 = [
     128, 256, 512, 1024, 1280, 2048, 2560, 2752, 3072, 3456, 3584, 4096, 5120,
-    5504, 5632, 6912, 7168, 8192, 9216, 10240, 11008, 13824, 14336, 32000,
-    32256, 32512, 32768, 33024
+    5504, 5632, 6144, 6912, 7168, 8192, 9216, 10240, 11008, 13824, 14336,
+    24576, 32000, 32256, 32512, 32768, 33024
 ]
 SEED = [0xabcdabcd987]
vllm/model_executor/models/gemma.py

@@ -20,6 +20,7 @@ import torch
 from torch import nn
 from transformers import GemmaConfig

+from vllm.config import LoRAConfig
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import GeluAndMul
 from vllm.model_executor.layers.attention import PagedAttention
@@ -246,12 +247,36 @@ class GemmaModel(nn.Module):


 class GemmaForCausalLM(nn.Module):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "qkv_proj",
+        "o_proj",
+        "gate_up_proj",
+        "down_proj",
+    ]
+    # Gemma does not apply LoRA to the embedding layer.
+    embedding_modules = {}
+    embedding_padding_modules = []

     def __init__(
         self,
         config: GemmaConfig,
         linear_method: Optional[LinearMethodBase] = None,
+        lora_config: Optional[LoRAConfig] = None,
     ) -> None:
+        del lora_config  # Unused.
         super().__init__()
         self.config = config
         self.linear_method = linear_method
@@ -305,9 +330,6 @@ class GemmaForCausalLM(nn.Module):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                # Skip loading extra layer for lora models.
-                if "lm_head" in name:
-                    continue
                 # GemmaRMSNorm is different from Llama's in that it multiplies
                 # (1 + weight) to the output, instead of just weight.
                 if "norm.weight" in name:
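For readers unfamiliar with these class attributes: packed_modules_mapping tells vLLM's LoRA machinery which per-projection adapter weights fold into the fused linear layers the Gemma implementation uses, and supported_lora_modules lists where adapters may attach. A toy sketch of the name mapping this expresses (illustration only, not vLLM's actual loader):

# Toy sketch, not vLLM's loader: map per-projection LoRA module names from an
# adapter checkpoint onto the fused module names used by GemmaForCausalLM.
packed_modules_mapping = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}

def fused_module_name(adapter_module: str) -> str:
    """Return the fused vLLM module name an adapter module folds into."""
    for fused, parts in packed_modules_mapping.items():
        if adapter_module in parts:
            return fused
    return adapter_module  # o_proj, down_proj, etc. map one-to-one

assert fused_module_name("q_proj") == "qkv_proj"
assert fused_module_name("up_proj") == "gate_up_proj"
assert fused_module_name("o_proj") == "o_proj"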
vllm/model_executor/models/llama.py

@@ -27,6 +27,7 @@ import torch
 from torch import nn
 from transformers import LlamaConfig

+from vllm.config import LoRAConfig
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.attention import PagedAttention
@@ -45,7 +46,6 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
-from vllm.config import LoRAConfig

 KVCache = Tuple[torch.Tensor, torch.Tensor]