vllm (xdb4_94051) · commit 929b4f29 (unverified)

Add LoRA support for Gemma (#3050)

Authored Feb 28, 2024 by Woosuk Kwon; committed via GitHub on Feb 28, 2024
Parent: 3b7178cf
Showing 7 changed files with 82 additions and 7 deletions (+82, -7):
.buildkite/test-pipeline.yaml          +1  -1
csrc/punica/bgmv/bgmv_config.h         +2  -0
tests/lora/conftest.py                 +5  -0
tests/lora/test_gemma.py               +46 -0
tests/lora/test_punica.py              +2  -2
vllm/model_executor/models/gemma.py    +25 -3
vllm/model_executor/models/llama.py    +1  -1
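For orientation, the change lets Gemma models be served with enable_lora=True and per-request LoRA adapters, exactly as exercised by the new test added in this commit. A minimal usage sketch of that API (the adapter path is a placeholder and the sampling settings are illustrative; they are not part of this commit):

import vllm
from vllm.lora.request import LoRARequest

# Placeholder path to a downloaded Gemma LoRA adapter; the new test pulls the
# "wskwon/gemma-7b-test-lora" repo via snapshot_download instead.
ADAPTER_PATH = "/path/to/gemma-7b-lora"

llm = vllm.LLM("google/gemma-7b", enable_lora=True, max_loras=4, max_model_len=1024)
outputs = llm.generate(
    ["Quote: Imagination is"],
    vllm.SamplingParams(temperature=0, max_tokens=32),
    lora_request=LoRARequest("gemma-lora", 1, ADAPTER_PATH),
)
print(outputs[0].outputs[0].text)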
.buildkite/test-pipeline.yaml

@@ -50,7 +50,7 @@ steps:
   command: pytest -v -s worker

 - label: LoRA Test
-  command: pytest -v -s lora
+  command: pytest -v -s lora --forked

 - label: Metrics Test
   command: pytest -v -s metrics
csrc/punica/bgmv/bgmv_config.h

@@ -28,6 +28,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
     f(in_T, out_T, W_T, narrow, 5120) \
     f(in_T, out_T, W_T, narrow, 5504) \
     f(in_T, out_T, W_T, narrow, 5632) \
+    f(in_T, out_T, W_T, narrow, 6144) \
     f(in_T, out_T, W_T, narrow, 6912) \
     f(in_T, out_T, W_T, narrow, 7168) \
     f(in_T, out_T, W_T, narrow, 8192) \
@@ -39,6 +40,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
     f(in_T, out_T, W_T, narrow, 14336) \
     f(in_T, out_T, W_T, narrow, 16384) \
     f(in_T, out_T, W_T, narrow, 20480) \
+    f(in_T, out_T, W_T, narrow, 24576) \
     f(in_T, out_T, W_T, narrow, 28672) \
     f(in_T, out_T, W_T, narrow, 32000) \
     f(in_T, out_T, W_T, narrow, 32256) \
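The two kernel sizes added above presumably cover Gemma's projection shapes, which the Punica BGMV kernels must be instantiated for; in particular, 24576 is Gemma-7B's MLP intermediate size. A small, hedged sanity check (assuming the transformers GemmaConfig defaults still correspond to Gemma-7B, and guessing that 6144 is that intermediate size sharded over 4 tensor-parallel ranks):

from transformers import GemmaConfig

cfg = GemmaConfig()  # default values match Gemma-7B at the time of this commit
print(cfg.hidden_size)             # 3072
print(cfg.intermediate_size)       # 24576 -> one of the new bgmv sizes
print(cfg.intermediate_size // 4)  # 6144  -> plausibly the TP=4 shard (assumption)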
tests/lora/conftest.py

@@ -126,6 +126,11 @@ def mixtral_lora_files():
     return snapshot_download(repo_id="terrysun/mixtral-lora-adapter")


+@pytest.fixture(scope="session")
+def gemma_lora_files():
+    return snapshot_download(repo_id="wskwon/gemma-7b-test-lora")
+
+
 @pytest.fixture
 def llama_2_7b_engine_extra_embeddings() -> nn.Module:
     cleanup()
tests/lora/test_gemma.py (new file, mode 100644)

import vllm
from vllm.lora.request import LoRARequest

MODEL_PATH = "google/gemma-7b"


def do_sample(llm, lora_path: str, lora_id: int) -> str:
    prompts = [
        "Quote: Imagination is",
        "Quote: Be yourself;",
        "Quote: So many books,",
    ]
    sampling_params = vllm.SamplingParams(temperature=0, max_tokens=32)
    outputs = llm.generate(
        prompts,
        sampling_params,
        lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
        if lora_id else None)
    # Print the outputs.
    generated_texts = []
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text.strip()
        generated_texts.append(generated_text)
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    return generated_texts


def test_gemma_lora(gemma_lora_files):
    llm = vllm.LLM(MODEL_PATH,
                   max_model_len=1024,
                   enable_lora=True,
                   max_loras=4)

    expected_lora_output = [
        "more important than knowledge.\nAuthor: Albert Einstein\n",
        "everyone else is already taken.\nAuthor: Oscar Wilde\n",
        "so little time\nAuthor: Frank Zappa\n",
    ]

    output1 = do_sample(llm, gemma_lora_files, lora_id=1)
    for i in range(len(expected_lora_output)):
        assert output1[i].startswith(expected_lora_output[i])
    output2 = do_sample(llm, gemma_lora_files, lora_id=2)
    for i in range(len(expected_lora_output)):
        assert output2[i].startswith(expected_lora_output[i])
tests/lora/test_punica.py

@@ -44,8 +44,8 @@ def _lora_ref_impl(
 H1 = H2 = [
     128, 256, 512, 1024, 1280, 2048, 2560, 2752, 3072, 3456, 3584, 4096, 5120,
-    5504, 5632, 6912, 7168, 8192, 9216, 10240, 11008, 13824, 14336, 32000,
-    32256, 32512, 32768, 33024
+    5504, 5632, 6144, 6912, 7168, 8192, 9216, 10240, 11008, 13824, 14336,
+    24576, 32000, 32256, 32512, 32768, 33024
 ]
 SEED = [0xabcdabcd987]
vllm/model_executor/models/gemma.py

@@ -20,6 +20,7 @@ import torch
 from torch import nn
 from transformers import GemmaConfig

+from vllm.config import LoRAConfig
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import GeluAndMul
 from vllm.model_executor.layers.attention import PagedAttention
@@ -246,12 +247,36 @@ class GemmaModel(nn.Module):
 class GemmaForCausalLM(nn.Module):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "qkv_proj",
+        "o_proj",
+        "gate_up_proj",
+        "down_proj",
+    ]
+    # Gemma does not apply LoRA to the embedding layer.
+    embedding_modules = {}
+    embedding_padding_modules = []

     def __init__(
         self,
         config: GemmaConfig,
         linear_method: Optional[LinearMethodBase] = None,
+        lora_config: Optional[LoRAConfig] = None,
     ) -> None:
+        del lora_config  # Unused.
         super().__init__()
         self.config = config
         self.linear_method = linear_method
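As a side note on the mapping above: LoRA adapter checkpoints carry separate matrices for q_proj, k_proj, and v_proj (and for gate_proj/up_proj), while vLLM fuses those base projections into qkv_proj and gate_up_proj, so the per-module LoRA outputs have to be concatenated along the output dimension. A rough, self-contained sketch of that idea (shapes are made up for illustration; this is not vLLM's actual LoRA layer code):

import torch

hidden, rank = 3072, 8
out_dims = {"q_proj": 4096, "k_proj": 4096, "v_proj": 4096}  # illustrative only

def lora_delta(x, lora_a, lora_b):
    # Classic LoRA update applied on top of the frozen base projection.
    return x @ lora_a.t() @ lora_b.t()

x = torch.randn(2, hidden)
lora_a = {m: torch.randn(rank, hidden) for m in out_dims}
lora_b = {m: torch.randn(d, rank) for m, d in out_dims.items()}

# Concatenate the per-module deltas so they line up with the fused qkv output.
qkv_delta = torch.cat([lora_delta(x, lora_a[m], lora_b[m]) for m in out_dims], dim=-1)
print(qkv_delta.shape)  # torch.Size([2, 12288])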
@@ -305,9 +330,6 @@ class GemmaForCausalLM(nn.Module):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                # Skip loading extra layer for lora models.
-                if "lm_head" in name:
-                    continue
                 # GemmaRMSNorm is different from Llama's in that it multiplies
                 # (1 + weight) to the output, instead of just weight.
                 if "norm.weight" in name:
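The GemmaRMSNorm comment in the hunk above is worth unpacking: Gemma's reference norm scales the normalized activations by (1 + weight) rather than weight, which is presumably why the loader special-cases "norm.weight" so the shared RMSNorm kernel can be reused. A small sketch of that equivalence (plain PyTorch, not the vLLM kernel):

import torch

def rms_norm(x, weight, eps=1e-6):
    # Shared Llama-style RMSNorm: normalize, then scale by weight.
    x_hat = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x_hat * weight

def gemma_rms_norm(x, weight, eps=1e-6):
    # Gemma's reference norm scales by (1 + weight) instead.
    x_hat = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x_hat * (1.0 + weight)

x = torch.randn(2, 3072)
w = torch.randn(3072)

# Folding the "+1" into the stored weight lets the shared kernel match Gemma.
assert torch.allclose(gemma_rms_norm(x, w), rms_norm(x, w + 1.0))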
vllm/model_executor/models/llama.py

@@ -27,6 +27,7 @@ import torch
 from torch import nn
 from transformers import LlamaConfig

+from vllm.config import LoRAConfig
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.attention import PagedAttention
@@ -45,7 +46,6 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
-from vllm.config import LoRAConfig

 KVCache = Tuple[torch.Tensor, torch.Tensor]
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment