Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
97b03000
Unverified
Commit
97b03000
authored
May 23, 2024
by
raywanb
Committed by
GitHub
May 22, 2024
Browse files
[Model] LoRA gptbigcode implementation (#3949)
parent
a3a73ab0
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
34 additions
and
5 deletions
+34
-5
csrc/punica/bgmv/bgmv_config.h
csrc/punica/bgmv/bgmv_config.h
+4
-0
tests/lora/test_punica.py
tests/lora/test_punica.py
+2
-0
vllm/lora/models.py
vllm/lora/models.py
+2
-0
vllm/model_executor/models/gpt_bigcode.py
vllm/model_executor/models/gpt_bigcode.py
+26
-5
No files found.
csrc/punica/bgmv/bgmv_config.h
View file @
97b03000
...
@@ -28,6 +28,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
...
@@ -28,6 +28,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, narrow, 2752) \
f(in_T, out_T, W_T, narrow, 2752) \
f(in_T, out_T, W_T, narrow, 2816) \
f(in_T, out_T, W_T, narrow, 2816) \
f(in_T, out_T, W_T, narrow, 3072) \
f(in_T, out_T, W_T, narrow, 3072) \
f(in_T, out_T, W_T, narrow, 3328) \
f(in_T, out_T, W_T, narrow, 3456) \
f(in_T, out_T, W_T, narrow, 3456) \
f(in_T, out_T, W_T, narrow, 3584) \
f(in_T, out_T, W_T, narrow, 3584) \
f(in_T, out_T, W_T, narrow, 4096) \
f(in_T, out_T, W_T, narrow, 4096) \
...
@@ -36,6 +37,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
...
@@ -36,6 +37,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, narrow, 5504) \
f(in_T, out_T, W_T, narrow, 5504) \
f(in_T, out_T, W_T, narrow, 5632) \
f(in_T, out_T, W_T, narrow, 5632) \
f(in_T, out_T, W_T, narrow, 6144) \
f(in_T, out_T, W_T, narrow, 6144) \
f(in_T, out_T, W_T, narrow, 6400) \
f(in_T, out_T, W_T, narrow, 6848) \
f(in_T, out_T, W_T, narrow, 6848) \
f(in_T, out_T, W_T, narrow, 6912) \
f(in_T, out_T, W_T, narrow, 6912) \
f(in_T, out_T, W_T, narrow, 7168) \
f(in_T, out_T, W_T, narrow, 7168) \
...
@@ -97,6 +99,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
...
@@ -97,6 +99,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, 2752, narrow) \
f(in_T, out_T, W_T, 2752, narrow) \
f(in_T, out_T, W_T, 2816, narrow) \
f(in_T, out_T, W_T, 2816, narrow) \
f(in_T, out_T, W_T, 3072, narrow) \
f(in_T, out_T, W_T, 3072, narrow) \
f(in_T, out_T, W_T, 3328, narrow) \
f(in_T, out_T, W_T, 3456, narrow) \
f(in_T, out_T, W_T, 3456, narrow) \
f(in_T, out_T, W_T, 3584, narrow) \
f(in_T, out_T, W_T, 3584, narrow) \
f(in_T, out_T, W_T, 4096, narrow) \
f(in_T, out_T, W_T, 4096, narrow) \
...
@@ -105,6 +108,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
...
@@ -105,6 +108,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, 5504, narrow) \
f(in_T, out_T, W_T, 5504, narrow) \
f(in_T, out_T, W_T, 5632, narrow) \
f(in_T, out_T, W_T, 5632, narrow) \
f(in_T, out_T, W_T, 6144, narrow) \
f(in_T, out_T, W_T, 6144, narrow) \
f(in_T, out_T, W_T, 6400, narrow) \
f(in_T, out_T, W_T, 6848, narrow) \
f(in_T, out_T, W_T, 6848, narrow) \
f(in_T, out_T, W_T, 6912, narrow) \
f(in_T, out_T, W_T, 6912, narrow) \
f(in_T, out_T, W_T, 7168, narrow) \
f(in_T, out_T, W_T, 7168, narrow) \
...
...
tests/lora/test_punica.py
View file @
97b03000
...
@@ -58,6 +58,7 @@ H1 = H2 = [
...
@@ -58,6 +58,7 @@ H1 = H2 = [
2560
,
2560
,
2752
,
2752
,
3072
,
3072
,
3328
,
3456
,
3456
,
3584
,
3584
,
4096
,
4096
,
...
@@ -66,6 +67,7 @@ H1 = H2 = [
...
@@ -66,6 +67,7 @@ H1 = H2 = [
5504
,
5504
,
5632
,
5632
,
6144
,
6144
,
6400
,
6848
,
6848
,
6912
,
6912
,
7168
,
7168
,
...
...
vllm/lora/models.py
View file @
97b03000
...
@@ -310,7 +310,9 @@ class LoRAModel:
...
@@ -310,7 +310,9 @@ class LoRAModel:
if
part_name
not
in
expected_lora_modules
:
if
part_name
not
in
expected_lora_modules
:
unexpected_modules
.
append
(
module
)
unexpected_modules
.
append
(
module
)
# loaded lora's target modules must be a subset of expected_lora_modules
# loaded lora's target modules must be a subset of expected_lora_modules
if
unexpected_modules
:
if
unexpected_modules
:
print
(
unexpected_modules
,
"modules"
)
raise
ValueError
(
raise
ValueError
(
f
"While loading
{
lora_dir
}
, expected"
f
"While loading
{
lora_dir
}
, expected"
f
" target modules in
{
expected_lora_modules
}
"
f
" target modules in
{
expected_lora_modules
}
"
...
...
vllm/model_executor/models/gpt_bigcode.py
View file @
97b03000
...
@@ -25,7 +25,7 @@ from torch import nn
...
@@ -25,7 +25,7 @@ from torch import nn
from
transformers
import
GPTBigCodeConfig
from
transformers
import
GPTBigCodeConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
from
vllm.config
import
CacheConfig
,
LoRAConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
...
@@ -191,14 +191,19 @@ class GPTBigCodeModel(nn.Module):
...
@@ -191,14 +191,19 @@ class GPTBigCodeModel(nn.Module):
config
:
GPTBigCodeConfig
,
config
:
GPTBigCodeConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
assert
not
config
.
add_cross_attention
assert
not
config
.
add_cross_attention
self
.
embed_dim
=
config
.
hidden_size
self
.
embed_dim
=
config
.
hidden_size
lora_vocab
=
(
lora_config
.
lora_extra_vocab_size
*
self
.
wte
=
VocabParallelEmbedding
(
config
.
vocab_size
,
self
.
embed_dim
)
(
lora_config
.
max_loras
or
1
))
if
lora_config
else
0
self
.
vocab_size
=
config
.
vocab_size
+
lora_vocab
self
.
wte
=
VocabParallelEmbedding
(
self
.
vocab_size
,
self
.
embed_dim
,
org_num_embeddings
=
config
.
vocab_size
)
self
.
wpe
=
nn
.
Embedding
(
config
.
max_position_embeddings
,
self
.
embed_dim
)
self
.
wpe
=
nn
.
Embedding
(
config
.
max_position_embeddings
,
self
.
embed_dim
)
self
.
h
=
nn
.
ModuleList
([
self
.
h
=
nn
.
ModuleList
([
GPTBigCodeBlock
(
config
,
cache_config
,
quant_config
)
GPTBigCodeBlock
(
config
,
cache_config
,
quant_config
)
...
@@ -226,19 +231,35 @@ class GPTBigCodeModel(nn.Module):
...
@@ -226,19 +231,35 @@ class GPTBigCodeModel(nn.Module):
class
GPTBigCodeForCausalLM
(
nn
.
Module
):
class
GPTBigCodeForCausalLM
(
nn
.
Module
):
packed_modules_mapping
=
{
"c_attn"
:
[
"c_attn"
]}
supported_lora_modules
=
[
"c_fc"
,
"c_proj"
,
"wte"
,
"lm_head"
,
"c_attn"
]
embedding_modules
=
{
"wte"
:
"input_embeddings"
,
"lm_head"
:
"output_embeddings"
,
}
embedding_padding_modules
=
[]
def
__init__
(
def
__init__
(
self
,
self
,
config
:
GPTBigCodeConfig
,
config
:
GPTBigCodeConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
transformer
=
GPTBigCodeModel
(
config
,
cache_config
,
quant_config
)
self
.
transformer
=
GPTBigCodeModel
(
config
,
cache_config
,
quant_config
,
lora_config
)
self
.
lm_head_weight
=
self
.
transformer
.
wte
.
weight
self
.
lm_head_weight
=
self
.
transformer
.
wte
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
unpadded_vocab_size
=
config
.
vocab_size
if
lora_config
:
self
.
unpadded_vocab_size
+=
lora_config
.
lora_extra_vocab_size
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
def
forward
(
def
forward
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment