norm / vllm · Commits · 929b4f29

Unverified commit 929b4f29, authored Feb 28, 2024 by Woosuk Kwon, committed by GitHub on Feb 28, 2024
parent 3b7178cf

Add LoRA support for Gemma (#3050)

Showing 7 changed files with 82 additions and 7 deletions (+82 -7)
.buildkite/test-pipeline.yaml        +1  -1
csrc/punica/bgmv/bgmv_config.h       +2  -0
tests/lora/conftest.py               +5  -0
tests/lora/test_gemma.py             +46 -0
tests/lora/test_punica.py            +2  -2
vllm/model_executor/models/gemma.py  +25 -3
vllm/model_executor/models/llama.py  +1  -1
.buildkite/test-pipeline.yaml (+1 -1)

@@ -50,7 +50,7 @@ steps:
   command: pytest -v -s worker

 - label: LoRA Test
-  command: pytest -v -s lora
+  command: pytest -v -s lora --forked

 - label: Metrics Test
   command: pytest -v -s metrics
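The new --forked option is provided by the pytest-forked plugin, which runs each test in its own forked subprocess; presumably this keeps GPU memory and global engine state from one LoRA test from leaking into the next.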
csrc/punica/bgmv/bgmv_config.h (+2 -0)

@@ -28,6 +28,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
     f(in_T, out_T, W_T, narrow, 5120) \
     f(in_T, out_T, W_T, narrow, 5504) \
     f(in_T, out_T, W_T, narrow, 5632) \
+    f(in_T, out_T, W_T, narrow, 6144) \
     f(in_T, out_T, W_T, narrow, 6912) \
     f(in_T, out_T, W_T, narrow, 7168) \
     f(in_T, out_T, W_T, narrow, 8192) \

@@ -39,6 +40,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
     f(in_T, out_T, W_T, narrow, 14336) \
     f(in_T, out_T, W_T, narrow, 16384) \
     f(in_T, out_T, W_T, narrow, 20480) \
+    f(in_T, out_T, W_T, narrow, 24576) \
     f(in_T, out_T, W_T, narrow, 28672) \
     f(in_T, out_T, W_T, narrow, 32000) \
     f(in_T, out_T, W_T, narrow, 32256) \
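The two new widths presumably cover the weight shapes a Gemma-7B model exposes to the Punica BGMV kernels once its projections are fused and sharded. A rough back-of-the-envelope sketch, assuming typical Gemma-7B config values (hidden_size 3072, 16 query and 16 KV heads of head_dim 256, intermediate_size 24576); this is not part of the commit:

# Assumed Gemma-7B-like configuration; verify against the model's HF config.json.
hidden_size = 3072
num_heads = 16
num_kv_heads = 16
head_dim = 256
intermediate_size = 24576

qkv_out = (num_heads + 2 * num_kv_heads) * head_dim  # fused qkv_proj output dim
gate_up_out = 2 * intermediate_size                  # fused gate_up_proj output dim

for tp in (1, 2, 4):  # tensor-parallel degree
    # Per-shard output sizes under tensor parallelism.
    print(tp, qkv_out // tp, gate_up_out // tp, intermediate_size // tp)
# tp=2 yields a 6144-wide qkv shard; intermediate_size itself is 24576.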
tests/lora/conftest.py (+5 -0)

@@ -126,6 +126,11 @@ def mixtral_lora_files():
     return snapshot_download(repo_id="terrysun/mixtral-lora-adapter")


+@pytest.fixture(scope="session")
+def gemma_lora_files():
+    return snapshot_download(repo_id="wskwon/gemma-7b-test-lora")
+
+
 @pytest.fixture
 def llama_2_7b_engine_extra_embeddings() -> nn.Module:
     cleanup()
tests/lora/test_gemma.py (new file, mode 100644, +46 -0)

import vllm
from vllm.lora.request import LoRARequest

MODEL_PATH = "google/gemma-7b"


def do_sample(llm, lora_path: str, lora_id: int) -> str:
    prompts = [
        "Quote: Imagination is",
        "Quote: Be yourself;",
        "Quote: So many books,",
    ]
    sampling_params = vllm.SamplingParams(temperature=0, max_tokens=32)
    outputs = llm.generate(
        prompts,
        sampling_params,
        lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
        if lora_id else None)
    # Print the outputs.
    generated_texts = []
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text.strip()
        generated_texts.append(generated_text)
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    return generated_texts


def test_gemma_lora(gemma_lora_files):
    llm = vllm.LLM(MODEL_PATH,
                   max_model_len=1024,
                   enable_lora=True,
                   max_loras=4)

    expected_lora_output = [
        "more important than knowledge.\nAuthor: Albert Einstein\n",
        "everyone else is already taken.\nAuthor: Oscar Wilde\n",
        "so little time\nAuthor: Frank Zappa\n",
    ]

    output1 = do_sample(llm, gemma_lora_files, lora_id=1)
    for i in range(len(expected_lora_output)):
        assert output1[i].startswith(expected_lora_output[i])
    output2 = do_sample(llm, gemma_lora_files, lora_id=2)
    for i in range(len(expected_lora_output)):
        assert output2[i].startswith(expected_lora_output[i])
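Because of the "if lora_id else None" guard in do_sample, passing lora_id=0 samples from the base model with no adapter attached, which makes it easy to eyeball the difference the LoRA weights make. A minimal sketch reusing the objects defined above (not part of the commit):

base_outputs = do_sample(llm, gemma_lora_files, lora_id=0)  # lora_request=None
lora_outputs = do_sample(llm, gemma_lora_files, lora_id=1)  # adapter applied
for base, tuned in zip(base_outputs, lora_outputs):
    print("base :", base)
    print("tuned:", tuned)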
tests/lora/test_punica.py (+2 -2)

@@ -44,8 +44,8 @@ def _lora_ref_impl(
 H1 = H2 = [
     128, 256, 512, 1024, 1280, 2048, 2560, 2752, 3072, 3456, 3584, 4096, 5120,
-    5504, 5632, 6912, 7168, 8192, 9216, 10240, 11008, 13824, 14336, 32000,
-    32256, 32512, 32768, 33024
+    5504, 5632, 6144, 6912, 7168, 8192, 9216, 10240, 11008, 13824, 14336,
+    24576, 32000, 32256, 32512, 32768, 33024
 ]
 SEED = [0xabcdabcd987]
vllm/model_executor/models/gemma.py (+25 -3)

@@ -20,6 +20,7 @@ import torch
 from torch import nn
 from transformers import GemmaConfig

+from vllm.config import LoRAConfig
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import GeluAndMul
 from vllm.model_executor.layers.attention import PagedAttention

@@ -246,12 +247,36 @@ class GemmaModel(nn.Module):
 class GemmaForCausalLM(nn.Module):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "qkv_proj",
+        "o_proj",
+        "gate_up_proj",
+        "down_proj",
+    ]
+    # Gemma does not apply LoRA to the embedding layer.
+    embedding_modules = {}
+    embedding_padding_modules = []

     def __init__(
         self,
         config: GemmaConfig,
         linear_method: Optional[LinearMethodBase] = None,
+        lora_config: Optional[LoRAConfig] = None,
     ) -> None:
+        del lora_config  # Unused.
         super().__init__()
         self.config = config
         self.linear_method = linear_method

@@ -305,9 +330,6 @@ class GemmaForCausalLM(nn.Module):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                # Skip loading extra layer for lora models.
-                if "lm_head" in name:
-                    continue
                 # GemmaRMSNorm is different from Llama's in that it multiplies
                 # (1 + weight) to the output, instead of just weight.
                 if "norm.weight" in name:
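packed_modules_mapping tells vLLM's LoRA machinery which per-projection adapter weights belong to each fused linear layer: PEFT-style adapters name q_proj, k_proj and v_proj separately, while vLLM runs them as a single qkv_proj (and gate_proj/up_proj as gate_up_proj). A rough illustration of the grouping idea, using a hypothetical helper rather than vLLM's actual loader:

# Hypothetical helper (not vLLM code): map an adapter weight name to the
# fused vLLM module it targets, using the packed-modules mapping above.
packed_modules_mapping = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}

def fused_module_for(weight_name: str) -> str:
    for fused, parts in packed_modules_mapping.items():
        if any(f".{part}." in weight_name for part in parts):
            return fused
    # Unpacked modules (o_proj, down_proj) keep their own names.
    return weight_name.split(".")[-3] if ".lora_" in weight_name else weight_name

print(fused_module_for("model.layers.0.self_attn.k_proj.lora_A.weight"))  # qkv_proj
print(fused_module_for("model.layers.0.self_attn.o_proj.lora_A.weight"))  # o_proj

supported_lora_modules then lists the fused names that may carry LoRA deltas, and the empty embedding_modules mirrors the in-code comment that Gemma does not apply LoRA to the embedding layer.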
vllm/model_executor/models/llama.py (+1 -1)

@@ -27,6 +27,7 @@ import torch
 from torch import nn
 from transformers import LlamaConfig

+from vllm.config import LoRAConfig
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.attention import PagedAttention

@@ -45,7 +46,6 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
-from vllm.config import LoRAConfig

 KVCache = Tuple[torch.Tensor, torch.Tensor]