Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1e96c334
Unverified
Commit
1e96c334
authored
Apr 11, 2024
by
Antoni Baum
Committed by
GitHub
Apr 11, 2024
Browse files
Add extra punica sizes to support bigger vocabs (#4015)
parent
95e7d4a9
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
109 additions
and
48 deletions
+109
-48
csrc/punica/bgmv/bgmv_config.h
csrc/punica/bgmv/bgmv_config.h
+11
-1
csrc/punica/punica_ops.cc
csrc/punica/punica_ops.cc
+7
-7
tests/lora/test_layers.py
tests/lora/test_layers.py
+44
-34
tests/lora/test_punica.py
tests/lora/test_punica.py
+45
-4
vllm/lora/layers.py
vllm/lora/layers.py
+2
-2
No files found.
csrc/punica/bgmv/bgmv_config.h
View file @
1e96c334
...
...
@@ -60,7 +60,17 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, narrow, 33024) \
f(in_T, out_T, W_T, narrow, 36864) \
f(in_T, out_T, W_T, narrow, 49152) \
// Keep above in sync with vllm/lora/layers::SamplerWithLoRA
f(in_T, out_T, W_T, narrow, 64000) \
f(in_T, out_T, W_T, narrow, 64256) \
f(in_T, out_T, W_T, narrow, 64512) \
f(in_T, out_T, W_T, narrow, 102400) \
f(in_T, out_T, W_T, narrow, 102656) \
f(in_T, out_T, W_T, narrow, 102912) \
f(in_T, out_T, W_T, narrow, 128000) \
f(in_T, out_T, W_T, narrow, 128256) \
f(in_T, out_T, W_T, narrow, 128512) \
// Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA
// and vllm/tests/lora/test_punica.py
// Keep this in sync with vllm/config::LoRAConfig
#define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
...
...
csrc/punica/punica_ops.cc
View file @
1e96c334
...
...
@@ -20,8 +20,8 @@ inline void check_shape(const torch::Tensor &a, const torch::Tensor &b,
}
}
inline
constexpr
uint
32
_t
pack_u
16
(
uint
16
_t
a
,
uint
16
_t
b
)
{
return
(
uint
32
_t
(
a
)
<<
16
)
|
uint
32
_t
(
b
);
inline
constexpr
uint
64
_t
pack_u
32
(
uint
32
_t
a
,
uint
32
_t
b
)
{
return
(
uint
64
_t
(
a
)
<<
32
)
|
uint
64
_t
(
b
);
}
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
...
...
@@ -46,13 +46,13 @@ inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) {
template
<
typename
in_T
,
typename
out_T
,
typename
W_T
>
inline
bool
launch_bgmv_kernel
(
out_T
*
Y
,
const
in_T
*
X
,
const
W_T
*
W
,
const
int64_t
*
lora_indices
,
uint
16
_t
in_features
,
uint
16
_t
out_features
,
uint
32
_t
in_features
,
uint
32
_t
out_features
,
int64_t
y_offset
,
int64_t
full_y_size
,
int64_t
batch_size
,
int64_t
num_layers
,
int64_t
layer_idx
,
float
scale
)
{
switch
(
pack_u
16
(
in_features
,
out_features
))
{
switch
(
pack_u
32
(
in_features
,
out_features
))
{
#define CASE_ONESIDE(_in_T, _out_T, _W_T, feat_in, feat_out) \
case pack_u
16
(feat_in, feat_out): \
case pack_u
32
(feat_in, feat_out): \
bgmv_kernel<feat_in, feat_out>(Y, X, W, lora_indices, y_offset, \
full_y_size, batch_size, num_layers, \
layer_idx, scale); \
...
...
@@ -93,7 +93,7 @@ void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
CHECK_EQ
(
y
.
size
(
0
),
x
.
size
(
0
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
x
));
bool
ok
=
false
;
if
(
h_in
<
65536
&&
h_out
<
65536
)
{
if
(
h_in
<
=
128512
&&
h_out
<
=
128512
)
{
// TODO: See if we can get rid of this massive nested switch
switch
(
x
.
scalar_type
())
{
case
at
::
ScalarType
::
Half
:
...
...
@@ -325,7 +325,7 @@ void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
CHECK_EQ
(
y
.
size
(
0
),
x
.
size
(
0
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
x
));
bool
ok
=
false
;
if
(
h_in
<
65536
&&
h_out
<
65536
)
{
if
(
h_in
<
=
128512
&&
h_out
<
=
128512
)
{
// TODO: See if we can get rid of this massive nested switch
switch
(
x
.
scalar_type
())
{
case
at
::
ScalarType
::
Half
:
...
...
tests/lora/test_layers.py
View file @
1e96c334
...
...
@@ -170,7 +170,8 @@ def create_random_inputs(
@
torch
.
inference_mode
()
@
pytest
.
mark
.
parametrize
(
"num_loras"
,
[
1
,
2
,
4
,
8
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
def
test_embeddings
(
dist_init
,
num_loras
,
device
)
->
None
:
@
pytest
.
mark
.
parametrize
(
"vocab_size"
,
[
512
,
32000
,
64000
,
128000
])
def
test_embeddings
(
dist_init
,
num_loras
,
device
,
vocab_size
)
->
None
:
torch
.
set_default_device
(
device
)
max_loras
=
8
...
...
@@ -179,9 +180,9 @@ def test_embeddings(dist_init, num_loras, device) -> None:
lora_dtype
=
torch
.
float16
)
def
create_random_embedding_layer
():
embedding
=
VocabParallelEmbedding
(
512
,
256
)
embedding
=
VocabParallelEmbedding
(
vocab_size
,
256
)
embedding
.
weight
.
data
=
torch
.
rand_like
(
embedding
.
weight
.
data
)
embedding
.
weight
.
data
[
512
:,
:]
=
0
embedding
.
weight
.
data
[
vocab_size
:,
:]
=
0
lora_embedding
=
VocabParallelEmbeddingWithLoRA
(
embedding
)
lora_embedding
.
create_lora_weights
(
max_loras
,
lora_config
)
...
...
@@ -203,12 +204,13 @@ def test_embeddings(dist_init, num_loras, device) -> None:
active_lora_ids
=
list
(
lora_dict
.
keys
()),
num_inputs
=
num_loras
*
3
,
input_size
=
(
200
,
),
input_range
=
(
1
,
512
),
input_range
=
(
1
,
vocab_size
),
)
lora_mapping
=
LoRAMapping
(
index_mapping
,
prompt_mapping
)
mapping_info
=
convert_mapping
(
lora_mapping
,
id_to_index
,
max_loras
,
512
,
lora_config
.
lora_extra_vocab_size
)
vocab_size
,
lora_config
.
lora_extra_vocab_size
)
lora_embedding
.
set_mapping
(
*
mapping_info
)
lora_result
=
lora_embedding
(
torch
.
cat
(
inputs
))
...
...
@@ -240,12 +242,13 @@ def test_embeddings(dist_init, num_loras, device) -> None:
active_lora_ids
=
[
0
],
num_inputs
=
num_loras
*
3
,
input_size
=
(
200
,
),
input_range
=
(
1
,
512
),
input_range
=
(
1
,
vocab_size
),
)
lora_mapping
=
LoRAMapping
(
index_mapping
,
prompt_mapping
)
mapping_info
=
convert_mapping
(
lora_mapping
,
id_to_index
,
max_loras
,
512
,
lora_config
.
lora_extra_vocab_size
)
vocab_size
,
lora_config
.
lora_extra_vocab_size
)
lora_embedding
.
set_mapping
(
*
mapping_info
,
)
lora_result
=
lora_embedding
(
torch
.
cat
(
inputs
))
...
...
@@ -263,7 +266,9 @@ def test_embeddings(dist_init, num_loras, device) -> None:
# reason="Fails when loras are in any slot other than the first.")
@
pytest
.
mark
.
parametrize
(
"num_loras"
,
[
1
,
2
,
4
,
8
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
def
test_embeddings_with_new_embeddings
(
dist_init
,
num_loras
,
device
)
->
None
:
@
pytest
.
mark
.
parametrize
(
"vocab_size"
,
[
512
,
32000
,
64000
,
128000
])
def
test_embeddings_with_new_embeddings
(
dist_init
,
num_loras
,
device
,
vocab_size
)
->
None
:
torch
.
set_default_device
(
device
)
max_loras
=
8
...
...
@@ -272,15 +277,15 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
lora_dtype
=
torch
.
float16
)
def
create_random_embedding_layer
():
embedding
=
VocabParallelEmbedding
(
512
,
256
)
embedding
=
VocabParallelEmbedding
(
vocab_size
,
256
)
embedding_data
=
torch
.
rand_like
(
embedding
.
weight
.
data
)
embedding
.
weight
.
data
=
embedding_data
embedding
.
weight
.
data
[
512
:,
:]
=
0
embedding
.
weight
.
data
[
vocab_size
:,
:]
=
0
expanded_embedding
=
VocabParallelEmbedding
(
512
+
lora_config
.
lora_extra_vocab_size
*
max_loras
,
vocab_size
+
lora_config
.
lora_extra_vocab_size
*
max_loras
,
256
,
org_num_embeddings
=
512
)
expanded_embedding
.
weight
.
data
[:
512
,
:]
=
embedding_data
org_num_embeddings
=
vocab_size
)
expanded_embedding
.
weight
.
data
[:
vocab_size
,
:]
=
embedding_data
# We need to deepcopy the embedding as it will be modified
# in place
lora_embedding
=
VocabParallelEmbeddingWithLoRA
(
...
...
@@ -298,7 +303,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
id_to_index
,
layer
=
lora_embedding
,
layer_weights
=
torch
.
zeros
(
(
256
,
512
+
lora_config
.
lora_extra_vocab_size
)),
(
256
,
vocab_size
+
lora_config
.
lora_extra_vocab_size
)),
generate_embeddings_tensor
=
256
,
)
...
...
@@ -316,7 +321,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
active_lora_ids
=
list
(
lora_dict
.
keys
()),
num_inputs
=
num_loras
*
3
,
input_size
=
(
200
,
),
input_range
=
(
1
,
512
),
input_range
=
(
1
,
vocab_size
),
)
lora_mapping
=
LoRAMapping
(
index_mapping
,
prompt_mapping
)
...
...
@@ -327,16 +332,18 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
for
input_
,
original_input_
,
lora_id
in
zip
(
inputs
,
original_inputs
,
prompt_mapping
):
embedding_id
=
lora_id
-
1
input_
[
-
1
]
=
512
+
(
embedding_id
*
embeddings_tensor_len
)
original_input_
[
-
1
]
=
512
input_
[
-
2
]
=
512
+
((
embedding_id
+
1
)
*
embeddings_tensor_len
-
1
)
original_input_
[
-
2
]
=
512
+
embeddings_tensor_len
-
1
input_
[
-
1
]
=
vocab_size
+
(
embedding_id
*
embeddings_tensor_len
)
original_input_
[
-
1
]
=
vocab_size
input_
[
-
2
]
=
vocab_size
+
(
(
embedding_id
+
1
)
*
embeddings_tensor_len
-
1
)
original_input_
[
-
2
]
=
vocab_size
+
embeddings_tensor_len
-
1
mapping_info
=
convert_mapping
(
lora_mapping
,
id_to_index
,
max_loras
,
512
,
lora_config
.
lora_extra_vocab_size
)
vocab_size
,
lora_config
.
lora_extra_vocab_size
)
lora_embedding
.
set_mapping
(
*
mapping_info
,
)
expanded_embedding
.
weight
[
512
:
512
+
expanded_embedding
.
weight
[
vocab_size
:
vocab_size
+
(
embeddings_tensor_len
*
max_loras
)]
=
torch
.
cat
(
embeddings_tensors
)
...
...
@@ -370,14 +377,15 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
active_lora_ids
=
[
0
],
num_inputs
=
num_loras
*
3
,
input_size
=
(
200
,
),
input_range
=
(
1
,
512
),
input_range
=
(
1
,
vocab_size
),
)
lora_mapping
=
LoRAMapping
(
index_mapping
,
prompt_mapping
)
original_inputs
=
deepcopy
(
inputs
)
mapping_info
=
convert_mapping
(
lora_mapping
,
id_to_index
,
max_loras
,
512
,
lora_config
.
lora_extra_vocab_size
)
vocab_size
,
lora_config
.
lora_extra_vocab_size
)
lora_embedding
.
set_mapping
(
*
mapping_info
,
)
lora_result
=
lora_embedding
(
torch
.
cat
(
original_inputs
))
...
...
@@ -393,7 +401,9 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
@
torch
.
inference_mode
()
@
pytest
.
mark
.
parametrize
(
"num_loras"
,
[
1
,
2
,
4
,
8
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
def
test_lm_head_logits_processor
(
dist_init
,
num_loras
,
device
)
->
None
:
@
pytest
.
mark
.
parametrize
(
"vocab_size"
,
[
512
,
32000
,
64000
,
128000
])
def
test_lm_head_logits_processor
(
dist_init
,
num_loras
,
device
,
vocab_size
)
->
None
:
torch
.
set_default_device
(
device
)
max_loras
=
8
...
...
@@ -402,12 +412,12 @@ def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
lora_dtype
=
torch
.
float16
)
def
_pretest
():
linear
=
ParallelLMHead
(
32000
+
lora_config
.
lora_extra_vocab_size
,
1024
,
32000
)
linear
=
ParallelLMHead
(
vocab_size
+
lora_config
.
lora_extra_vocab_size
,
1024
,
vocab_size
)
linear
.
weight
.
data
=
torch
.
rand_like
(
linear
.
weight
.
data
)
linear
.
weight
.
data
[:,
32000
:]
=
0
linear
.
weight
.
data
[:,
vocab_size
:]
=
0
logits_processor
=
LogitsProcessor
(
32000
+
lora_config
.
lora_extra_vocab_size
,
32000
)
vocab_size
+
lora_config
.
lora_extra_vocab_size
,
vocab_size
)
lora_logits_processor
=
LogitsProcessorWithLoRA
(
logits_processor
,
1024
,
linear
.
weight
.
dtype
,
linear
.
weight
.
device
)
lora_logits_processor
.
create_lora_weights
(
max_loras
,
lora_config
)
...
...
@@ -444,7 +454,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
lora_mapping
,
id_to_index
,
max_loras
,
32000
,
vocab_size
,
lora_config
.
lora_extra_vocab_size
,
)
lora_logits_processor
.
set_mapping
(
*
mapping_info
,
)
...
...
@@ -460,7 +470,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
org_vocab_size
:
logits_processor
.
org_vocab_size
+
embeddings_tensor_len
]
=
embeddings_tensor
logits_processor
.
org_vocab_size
=
(
32000
+
logits_processor
.
org_vocab_size
=
(
vocab_size
+
lora_config
.
lora_extra_vocab_size
)
expected_results
=
[]
for
input_
,
lora_id
in
zip
(
inputs
,
prompt_mapping
):
...
...
@@ -468,11 +478,11 @@ def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
result
=
logits_processor
.
_get_logits
(
hidden_states
=
input_
,
embedding
=
linear
.
weight
,
embedding_bias
=
None
)
result
[:,
32000
+
embeddings_tensor_len
:]
=
float
(
"-inf"
)
result
[:,
vocab_size
+
embeddings_tensor_len
:]
=
float
(
"-inf"
)
result
+=
input_
@
lora
.
lora_a
@
lora
.
lora_b
*
lora
.
scaling
expected_results
.
append
(
result
)
expected_result
=
torch
.
cat
(
expected_results
)
logits_processor
.
org_vocab_size
=
32000
logits_processor
.
org_vocab_size
=
vocab_size
# Check that resetting the lora weights succeeds
...
...
@@ -489,14 +499,14 @@ def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
lora_mapping
=
LoRAMapping
(
index_mapping
,
prompt_mapping
)
mapping_info
=
convert_mapping
(
lora_mapping
,
id_to_index
,
max_loras
,
32000
,
vocab_size
,
lora_config
.
lora_extra_vocab_size
)
lora_logits_processor
.
set_mapping
(
*
mapping_info
,
)
lora_result
=
lora_logits_processor
.
_get_logits
(
hidden_states
=
torch
.
cat
(
inputs
),
embedding
=
original_weight
,
embedding_bias
=
None
)[:,
:
32000
]
embedding_bias
=
None
)[:,
:
vocab_size
]
expected_result
=
logits_processor
.
_get_logits
(
hidden_states
=
torch
.
cat
(
inputs
),
embedding
=
original_weight
,
...
...
tests/lora/test_punica.py
View file @
1e96c334
...
...
@@ -43,10 +43,51 @@ def _lora_ref_impl(
H1
=
H2
=
[
128
,
256
,
512
,
1024
,
1152
,
1280
,
1536
,
2048
,
2304
,
2560
,
2752
,
3072
,
3456
,
3584
,
4096
,
4608
,
5120
,
5504
,
5632
,
6144
,
6848
,
6912
,
7168
,
8192
,
9216
,
10240
,
11008
,
13824
,
14336
,
22016
,
24576
,
27392
,
32000
,
32256
,
32512
,
32768
,
33024
128
,
256
,
512
,
1024
,
1152
,
1280
,
1536
,
2048
,
2304
,
2560
,
2752
,
3072
,
3456
,
3584
,
4096
,
4608
,
5120
,
5504
,
5632
,
6144
,
6848
,
6912
,
7168
,
8192
,
9216
,
10240
,
11008
,
13824
,
14336
,
22016
,
24576
,
27392
,
32000
,
32256
,
32512
,
32768
,
33024
,
36864
,
49152
,
64000
,
64256
,
102400
,
102656
,
128000
,
128256
,
]
SEED
=
[
0xabcdabcd987
]
CUDA_DEVICES
=
[
...
...
vllm/lora/layers.py
View file @
1e96c334
...
...
@@ -935,9 +935,9 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
model_config
:
Optional
[
PretrainedConfig
]
=
None
,
)
->
None
:
# Keep this in sync with csrc/punica/bgmv/bgmv_config.h
if
32000
<
self
.
base_layer
.
vocab_size
>
33024
:
if
32000
<
self
.
base_layer
.
vocab_size
>
128512
:
raise
ValueError
(
"When using LoRA, vocab size must be "
"32000 >= vocab_size <=
33024
"
)
"32000 >= vocab_size <=
128512
"
)
self
.
lora_a_stacked
=
torch
.
zeros
(
(
max_loras
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment