Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1e96c334
"vllm/worker/cache_engine.py" did not exist on "bb59a3e7302ad6892e097eee4040e3f516e9f4ea"
Unverified
Commit
1e96c334
authored
Apr 11, 2024
by
Antoni Baum
Committed by
GitHub
Apr 11, 2024
Browse files
Add extra punica sizes to support bigger vocabs (#4015)
parent
95e7d4a9
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
109 additions
and
48 deletions
+109
-48
csrc/punica/bgmv/bgmv_config.h
csrc/punica/bgmv/bgmv_config.h
+11
-1
csrc/punica/punica_ops.cc
csrc/punica/punica_ops.cc
+7
-7
tests/lora/test_layers.py
tests/lora/test_layers.py
+44
-34
tests/lora/test_punica.py
tests/lora/test_punica.py
+45
-4
vllm/lora/layers.py
vllm/lora/layers.py
+2
-2
No files found.
csrc/punica/bgmv/bgmv_config.h
View file @
1e96c334
...
...
@@ -60,7 +60,17 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, narrow, 33024) \
f(in_T, out_T, W_T, narrow, 36864) \
f(in_T, out_T, W_T, narrow, 49152) \
// Keep above in sync with vllm/lora/layers::SamplerWithLoRA
f(in_T, out_T, W_T, narrow, 64000) \
f(in_T, out_T, W_T, narrow, 64256) \
f(in_T, out_T, W_T, narrow, 64512) \
f(in_T, out_T, W_T, narrow, 102400) \
f(in_T, out_T, W_T, narrow, 102656) \
f(in_T, out_T, W_T, narrow, 102912) \
f(in_T, out_T, W_T, narrow, 128000) \
f(in_T, out_T, W_T, narrow, 128256) \
f(in_T, out_T, W_T, narrow, 128512) \
// Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA
// and vllm/tests/lora/test_punica.py
// Keep this in sync with vllm/config::LoRAConfig
#define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
...
...
csrc/punica/punica_ops.cc
View file @
1e96c334
...
...
@@ -20,8 +20,8 @@ inline void check_shape(const torch::Tensor &a, const torch::Tensor &b,
}
}
inline
constexpr
uint
32
_t
pack_u
16
(
uint
16
_t
a
,
uint
16
_t
b
)
{
return
(
uint
32
_t
(
a
)
<<
16
)
|
uint
32
_t
(
b
);
inline
constexpr
uint
64
_t
pack_u
32
(
uint
32
_t
a
,
uint
32
_t
b
)
{
return
(
uint
64
_t
(
a
)
<<
32
)
|
uint
64
_t
(
b
);
}
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
...
...
@@ -46,13 +46,13 @@ inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) {
template
<
typename
in_T
,
typename
out_T
,
typename
W_T
>
inline
bool
launch_bgmv_kernel
(
out_T
*
Y
,
const
in_T
*
X
,
const
W_T
*
W
,
const
int64_t
*
lora_indices
,
uint
16
_t
in_features
,
uint
16
_t
out_features
,
uint
32
_t
in_features
,
uint
32
_t
out_features
,
int64_t
y_offset
,
int64_t
full_y_size
,
int64_t
batch_size
,
int64_t
num_layers
,
int64_t
layer_idx
,
float
scale
)
{
switch
(
pack_u
16
(
in_features
,
out_features
))
{
switch
(
pack_u
32
(
in_features
,
out_features
))
{
#define CASE_ONESIDE(_in_T, _out_T, _W_T, feat_in, feat_out) \
case pack_u
16
(feat_in, feat_out): \
case pack_u
32
(feat_in, feat_out): \
bgmv_kernel<feat_in, feat_out>(Y, X, W, lora_indices, y_offset, \
full_y_size, batch_size, num_layers, \
layer_idx, scale); \
...
...
@@ -93,7 +93,7 @@ void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
CHECK_EQ
(
y
.
size
(
0
),
x
.
size
(
0
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
x
));
bool
ok
=
false
;
if
(
h_in
<
65536
&&
h_out
<
65536
)
{
if
(
h_in
<
=
128512
&&
h_out
<
=
128512
)
{
// TODO: See if we can get rid of this massive nested switch
switch
(
x
.
scalar_type
())
{
case
at
::
ScalarType
::
Half
:
...
...
@@ -325,7 +325,7 @@ void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
CHECK_EQ
(
y
.
size
(
0
),
x
.
size
(
0
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
x
));
bool
ok
=
false
;
if
(
h_in
<
65536
&&
h_out
<
65536
)
{
if
(
h_in
<
=
128512
&&
h_out
<
=
128512
)
{
// TODO: See if we can get rid of this massive nested switch
switch
(
x
.
scalar_type
())
{
case
at
::
ScalarType
::
Half
:
...
...
tests/lora/test_layers.py
View file @
1e96c334
...
...
@@ -170,7 +170,8 @@ def create_random_inputs(
@
torch
.
inference_mode
()
@
pytest
.
mark
.
parametrize
(
"num_loras"
,
[
1
,
2
,
4
,
8
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
def
test_embeddings
(
dist_init
,
num_loras
,
device
)
->
None
:
@
pytest
.
mark
.
parametrize
(
"vocab_size"
,
[
512
,
32000
,
64000
,
128000
])
def
test_embeddings
(
dist_init
,
num_loras
,
device
,
vocab_size
)
->
None
:
torch
.
set_default_device
(
device
)
max_loras
=
8
...
...
@@ -179,9 +180,9 @@ def test_embeddings(dist_init, num_loras, device) -> None:
lora_dtype
=
torch
.
float16
)
def
create_random_embedding_layer
():
embedding
=
VocabParallelEmbedding
(
512
,
256
)
embedding
=
VocabParallelEmbedding
(
vocab_size
,
256
)
embedding
.
weight
.
data
=
torch
.
rand_like
(
embedding
.
weight
.
data
)
embedding
.
weight
.
data
[
512
:,
:]
=
0
embedding
.
weight
.
data
[
vocab_size
:,
:]
=
0
lora_embedding
=
VocabParallelEmbeddingWithLoRA
(
embedding
)
lora_embedding
.
create_lora_weights
(
max_loras
,
lora_config
)
...
...
@@ -203,12 +204,13 @@ def test_embeddings(dist_init, num_loras, device) -> None:
active_lora_ids
=
list
(
lora_dict
.
keys
()),
num_inputs
=
num_loras
*
3
,
input_size
=
(
200
,
),
input_range
=
(
1
,
512
),
input_range
=
(
1
,
vocab_size
),
)
lora_mapping
=
LoRAMapping
(
index_mapping
,
prompt_mapping
)
mapping_info
=
convert_mapping
(
lora_mapping
,
id_to_index
,
max_loras
,
512
,
lora_config
.
lora_extra_vocab_size
)
vocab_size
,
lora_config
.
lora_extra_vocab_size
)
lora_embedding
.
set_mapping
(
*
mapping_info
)
lora_result
=
lora_embedding
(
torch
.
cat
(
inputs
))
...
...
@@ -240,12 +242,13 @@ def test_embeddings(dist_init, num_loras, device) -> None:
active_lora_ids
=
[
0
],
num_inputs
=
num_loras
*
3
,
input_size
=
(
200
,
),
input_range
=
(
1
,
512
),
input_range
=
(
1
,
vocab_size
),
)
lora_mapping
=
LoRAMapping
(
index_mapping
,
prompt_mapping
)
mapping_info
=
convert_mapping
(
lora_mapping
,
id_to_index
,
max_loras
,
512
,
lora_config
.
lora_extra_vocab_size
)
vocab_size
,
lora_config
.
lora_extra_vocab_size
)
lora_embedding
.
set_mapping
(
*
mapping_info
,
)
lora_result
=
lora_embedding
(
torch
.
cat
(
inputs
))
...
...
@@ -263,7 +266,9 @@ def test_embeddings(dist_init, num_loras, device) -> None:
# reason="Fails when loras are in any slot other than the first.")
@
pytest
.
mark
.
parametrize
(
"num_loras"
,
[
1
,
2
,
4
,
8
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
def
test_embeddings_with_new_embeddings
(
dist_init
,
num_loras
,
device
)
->
None
:
@
pytest
.
mark
.
parametrize
(
"vocab_size"
,
[
512
,
32000
,
64000
,
128000
])
def
test_embeddings_with_new_embeddings
(
dist_init
,
num_loras
,
device
,
vocab_size
)
->
None
:
torch
.
set_default_device
(
device
)
max_loras
=
8
...
...
@@ -272,15 +277,15 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
lora_dtype
=
torch
.
float16
)
def
create_random_embedding_layer
():
embedding
=
VocabParallelEmbedding
(
512
,
256
)
embedding
=
VocabParallelEmbedding
(
vocab_size
,
256
)
embedding_data
=
torch
.
rand_like
(
embedding
.
weight
.
data
)
embedding
.
weight
.
data
=
embedding_data
embedding
.
weight
.
data
[
512
:,
:]
=
0
embedding
.
weight
.
data
[
vocab_size
:,
:]
=
0
expanded_embedding
=
VocabParallelEmbedding
(
512
+
lora_config
.
lora_extra_vocab_size
*
max_loras
,
vocab_size
+
lora_config
.
lora_extra_vocab_size
*
max_loras
,
256
,
org_num_embeddings
=
512
)
expanded_embedding
.
weight
.
data
[:
512
,
:]
=
embedding_data
org_num_embeddings
=
vocab_size
)
expanded_embedding
.
weight
.
data
[:
vocab_size
,
:]
=
embedding_data
# We need to deepcopy the embedding as it will be modified
# in place
lora_embedding
=
VocabParallelEmbeddingWithLoRA
(
...
...
@@ -298,7 +303,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
id_to_index
,
layer
=
lora_embedding
,
layer_weights
=
torch
.
zeros
(
(
256
,
512
+
lora_config
.
lora_extra_vocab_size
)),
(
256
,
vocab_size
+
lora_config
.
lora_extra_vocab_size
)),
generate_embeddings_tensor
=
256
,
)
...
...
@@ -316,7 +321,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
active_lora_ids
=
list
(
lora_dict
.
keys
()),
num_inputs
=
num_loras
*
3
,
input_size
=
(
200
,
),
input_range
=
(
1
,
512
),
input_range
=
(
1
,
vocab_size
),
)
lora_mapping
=
LoRAMapping
(
index_mapping
,
prompt_mapping
)
...
...
@@ -327,16 +332,18 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
for
input_
,
original_input_
,
lora_id
in
zip
(
inputs
,
original_inputs
,
prompt_mapping
):
embedding_id
=
lora_id
-
1
input_
[
-
1
]
=
512
+
(
embedding_id
*
embeddings_tensor_len
)
original_input_
[
-
1
]
=
512
input_
[
-
2
]
=
512
+
((
embedding_id
+
1
)
*
embeddings_tensor_len
-
1
)
original_input_
[
-
2
]
=
512
+
embeddings_tensor_len
-
1
input_
[
-
1
]
=
vocab_size
+
(
embedding_id
*
embeddings_tensor_len
)
original_input_
[
-
1
]
=
vocab_size
input_
[
-
2
]
=
vocab_size
+
(
(
embedding_id
+
1
)
*
embeddings_tensor_len
-
1
)
original_input_
[
-
2
]
=
vocab_size
+
embeddings_tensor_len
-
1
mapping_info
=
convert_mapping
(
lora_mapping
,
id_to_index
,
max_loras
,
512
,
lora_config
.
lora_extra_vocab_size
)
vocab_size
,
lora_config
.
lora_extra_vocab_size
)
lora_embedding
.
set_mapping
(
*
mapping_info
,
)
expanded_embedding
.
weight
[
512
:
512
+
expanded_embedding
.
weight
[
vocab_size
:
vocab_size
+
(
embeddings_tensor_len
*
max_loras
)]
=
torch
.
cat
(
embeddings_tensors
)
...
...
@@ -370,14 +377,15 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
active_lora_ids
=
[
0
],
num_inputs
=
num_loras
*
3
,
input_size
=
(
200
,
),
input_range
=
(
1
,
512
),
input_range
=
(
1
,
vocab_size
),
)
lora_mapping
=
LoRAMapping
(
index_mapping
,
prompt_mapping
)
original_inputs
=
deepcopy
(
inputs
)
mapping_info
=
convert_mapping
(
lora_mapping
,
id_to_index
,
max_loras
,
512
,
lora_config
.
lora_extra_vocab_size
)
vocab_size
,
lora_config
.
lora_extra_vocab_size
)
lora_embedding
.
set_mapping
(
*
mapping_info
,
)
lora_result
=
lora_embedding
(
torch
.
cat
(
original_inputs
))
...
...
@@ -393,7 +401,9 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
@
torch
.
inference_mode
()
@
pytest
.
mark
.
parametrize
(
"num_loras"
,
[
1
,
2
,
4
,
8
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
def
test_lm_head_logits_processor
(
dist_init
,
num_loras
,
device
)
->
None
:
@
pytest
.
mark
.
parametrize
(
"vocab_size"
,
[
512
,
32000
,
64000
,
128000
])
def
test_lm_head_logits_processor
(
dist_init
,
num_loras
,
device
,
vocab_size
)
->
None
:
torch
.
set_default_device
(
device
)
max_loras
=
8
...
...
@@ -402,12 +412,12 @@ def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
lora_dtype
=
torch
.
float16
)
def
_pretest
():
linear
=
ParallelLMHead
(
32000
+
lora_config
.
lora_extra_vocab_size
,
1024
,
32000
)
linear
=
ParallelLMHead
(
vocab_size
+
lora_config
.
lora_extra_vocab_size
,
1024
,
vocab_size
)
linear
.
weight
.
data
=
torch
.
rand_like
(
linear
.
weight
.
data
)
linear
.
weight
.
data
[:,
32000
:]
=
0
linear
.
weight
.
data
[:,
vocab_size
:]
=
0
logits_processor
=
LogitsProcessor
(
32000
+
lora_config
.
lora_extra_vocab_size
,
32000
)
vocab_size
+
lora_config
.
lora_extra_vocab_size
,
vocab_size
)
lora_logits_processor
=
LogitsProcessorWithLoRA
(
logits_processor
,
1024
,
linear
.
weight
.
dtype
,
linear
.
weight
.
device
)
lora_logits_processor
.
create_lora_weights
(
max_loras
,
lora_config
)
...
...
@@ -444,7 +454,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
lora_mapping
,
id_to_index
,
max_loras
,
32000
,
vocab_size
,
lora_config
.
lora_extra_vocab_size
,
)
lora_logits_processor
.
set_mapping
(
*
mapping_info
,
)
...
...
@@ -460,7 +470,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
org_vocab_size
:
logits_processor
.
org_vocab_size
+
embeddings_tensor_len
]
=
embeddings_tensor
logits_processor
.
org_vocab_size
=
(
32000
+
logits_processor
.
org_vocab_size
=
(
vocab_size
+
lora_config
.
lora_extra_vocab_size
)
expected_results
=
[]
for
input_
,
lora_id
in
zip
(
inputs
,
prompt_mapping
):
...
...
@@ -468,11 +478,11 @@ def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
result
=
logits_processor
.
_get_logits
(
hidden_states
=
input_
,
embedding
=
linear
.
weight
,
embedding_bias
=
None
)
result
[:,
32000
+
embeddings_tensor_len
:]
=
float
(
"-inf"
)
result
[:,
vocab_size
+
embeddings_tensor_len
:]
=
float
(
"-inf"
)
result
+=
input_
@
lora
.
lora_a
@
lora
.
lora_b
*
lora
.
scaling
expected_results
.
append
(
result
)
expected_result
=
torch
.
cat
(
expected_results
)
logits_processor
.
org_vocab_size
=
32000
logits_processor
.
org_vocab_size
=
vocab_size
# Check that resetting the lora weights succeeds
...
...
@@ -489,14 +499,14 @@ def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
lora_mapping
=
LoRAMapping
(
index_mapping
,
prompt_mapping
)
mapping_info
=
convert_mapping
(
lora_mapping
,
id_to_index
,
max_loras
,
32000
,
vocab_size
,
lora_config
.
lora_extra_vocab_size
)
lora_logits_processor
.
set_mapping
(
*
mapping_info
,
)
lora_result
=
lora_logits_processor
.
_get_logits
(
hidden_states
=
torch
.
cat
(
inputs
),
embedding
=
original_weight
,
embedding_bias
=
None
)[:,
:
32000
]
embedding_bias
=
None
)[:,
:
vocab_size
]
expected_result
=
logits_processor
.
_get_logits
(
hidden_states
=
torch
.
cat
(
inputs
),
embedding
=
original_weight
,
...
...
tests/lora/test_punica.py
View file @
1e96c334
...
...
@@ -43,10 +43,51 @@ def _lora_ref_impl(
H1
=
H2
=
[
128
,
256
,
512
,
1024
,
1152
,
1280
,
1536
,
2048
,
2304
,
2560
,
2752
,
3072
,
3456
,
3584
,
4096
,
4608
,
5120
,
5504
,
5632
,
6144
,
6848
,
6912
,
7168
,
8192
,
9216
,
10240
,
11008
,
13824
,
14336
,
22016
,
24576
,
27392
,
32000
,
32256
,
32512
,
32768
,
33024
128
,
256
,
512
,
1024
,
1152
,
1280
,
1536
,
2048
,
2304
,
2560
,
2752
,
3072
,
3456
,
3584
,
4096
,
4608
,
5120
,
5504
,
5632
,
6144
,
6848
,
6912
,
7168
,
8192
,
9216
,
10240
,
11008
,
13824
,
14336
,
22016
,
24576
,
27392
,
32000
,
32256
,
32512
,
32768
,
33024
,
36864
,
49152
,
64000
,
64256
,
102400
,
102656
,
128000
,
128256
,
]
SEED
=
[
0xabcdabcd987
]
CUDA_DEVICES
=
[
...
...
vllm/lora/layers.py
View file @
1e96c334
...
...
@@ -935,9 +935,9 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
model_config
:
Optional
[
PretrainedConfig
]
=
None
,
)
->
None
:
# Keep this in sync with csrc/punica/bgmv/bgmv_config.h
if
32000
<
self
.
base_layer
.
vocab_size
>
33024
:
if
32000
<
self
.
base_layer
.
vocab_size
>
128512
:
raise
ValueError
(
"When using LoRA, vocab size must be "
"32000 >= vocab_size <=
33024
"
)
"32000 >= vocab_size <=
128512
"
)
self
.
lora_a_stacked
=
torch
.
zeros
(
(
max_loras
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment