Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ccdc490d
Unverified
Commit
ccdc490d
authored
Jun 06, 2024
by
Antoni Baum
Committed by
GitHub
Jun 06, 2024
Browse files
[Core] Change LoRA embedding sharding to support loading methods (#5038)
parent
a31cab75
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
661 additions
and
129 deletions
+661
-129
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+2
-8
tests/conftest.py
tests/conftest.py
+21
-0
tests/lora/conftest.py
tests/lora/conftest.py
+16
-2
tests/lora/test_layers.py
tests/lora/test_layers.py
+217
-2
tests/lora/test_llama.py
tests/lora/test_llama.py
+7
-10
tests/lora/test_long_context.py
tests/lora/test_long_context.py
+13
-10
tests/test_sharded_state_loader.py
tests/test_sharded_state_loader.py
+80
-44
vllm/lora/layers.py
vllm/lora/layers.py
+57
-19
vllm/lora/utils.py
vllm/lora/utils.py
+2
-1
vllm/model_executor/layers/vocab_parallel_embedding.py
vllm/model_executor/layers/vocab_parallel_embedding.py
+235
-25
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+11
-8
No files found.
.buildkite/test-pipeline.yaml
View file @
ccdc490d
...
@@ -46,6 +46,7 @@ steps:
...
@@ -46,6 +46,7 @@ steps:
-
TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
-
TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
-
TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
-
TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
-
pytest -v -s spec_decode/e2e/test_integration_dist.py
-
pytest -v -s spec_decode/e2e/test_integration_dist.py
-
CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
-
label
:
Distributed Tests (Multiple Groups)
-
label
:
Distributed Tests (Multiple Groups)
#mirror_hardwares: [amd]
#mirror_hardwares: [amd]
...
@@ -138,14 +139,7 @@ steps:
...
@@ -138,14 +139,7 @@ steps:
num_gpus
:
4
num_gpus
:
4
# This test runs llama 13B, so it is required to run on 4 GPUs.
# This test runs llama 13B, so it is required to run on 4 GPUs.
commands
:
commands
:
# Temporarily run this way because we cannot clean up GPU mem usage
-
pytest -v -s -x lora/test_long_context.py
# for multi GPU tests.
# TODO(sang): Fix it.
-
pytest -v -s lora/test_long_context.py::test_rotary_emb_replaced
-
pytest -v -s lora/test_long_context.py::test_batched_rope_kernel
-
pytest -v -s lora/test_long_context.py::test_self_consistency
-
pytest -v -s lora/test_long_context.py::test_quality
-
pytest -v -s lora/test_long_context.py::test_max_len
-
label
:
Tensorizer Test
-
label
:
Tensorizer Test
#mirror_hardwares: [amd]
#mirror_hardwares: [amd]
...
...
tests/conftest.py
View file @
ccdc490d
import
contextlib
import
contextlib
import
gc
import
gc
import
os
import
os
import
subprocess
import
sys
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
TypeVar
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
TypeVar
import
pytest
import
pytest
...
@@ -522,3 +524,22 @@ def caplog_vllm(temporary_enable_log_propagate, caplog):
...
@@ -522,3 +524,22 @@ def caplog_vllm(temporary_enable_log_propagate, caplog):
# To capture vllm log, we should enable propagate=True temporarily
# To capture vllm log, we should enable propagate=True temporarily
# because caplog depends on logs propagated to the root logger.
# because caplog depends on logs propagated to the root logger.
yield
caplog
yield
caplog
@pytest.fixture(scope="session")
def num_gpus_available():
    """Return the CUDA device count as seen by a child interpreter.

    The count is queried in a subprocess so that the CUDA context is not
    initialized in the current process. If the child exits with a non-zero
    status, a warning is logged and 0 is returned.
    """
    # Run with the same interpreter as the test session.
    probe_cmd = [
        sys.executable,
        "-c",
        "import torch; print(torch.cuda.device_count())",
    ]
    try:
        completed = subprocess.run(probe_cmd,
                                   capture_output=True,
                                   check=True,
                                   text=True)
    except subprocess.CalledProcessError as e:
        logger.warning("Failed to get number of GPUs.", exc_info=e)
        return 0
    # The child prints exactly the device count; strip the trailing newline.
    return int(completed.stdout.strip())
tests/lora/conftest.py
View file @
ccdc490d
...
@@ -42,10 +42,24 @@ def cleanup():
...
@@ -42,10 +42,24 @@ def cleanup():
ray
.
shutdown
()
ray
.
shutdown
()
@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
    """Decide whether the global cleanup should run after the current test.

    Returns False when the requesting test carries the
    ``skip_global_cleanup`` marker, True otherwise. Subdirectories may also
    override this fixture to opt out of the global cleanup entirely.
    """
    skip_marker = request.node.get_closest_marker("skip_global_cleanup")
    # A present marker means "skip the cleanup"; absence means "do it".
    return not skip_marker
@
pytest
.
fixture
(
autouse
=
True
)
@
pytest
.
fixture
(
autouse
=
True
)
def
cleanup_fixture
():
def
cleanup_fixture
(
should_do_global_cleanup_after_test
:
bool
):
yield
yield
cleanup
()
if
should_do_global_cleanup_after_test
:
cleanup
()
@
pytest
.
fixture
@
pytest
.
fixture
...
...
tests/lora/test_layers.py
View file @
ccdc490d
...
@@ -2,6 +2,7 @@ import random
...
@@ -2,6 +2,7 @@ import random
from
copy
import
deepcopy
from
copy
import
deepcopy
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
unittest.mock
import
patch
import
pytest
import
pytest
import
torch
import
torch
...
@@ -32,7 +33,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...
@@ -32,7 +33,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
ParallelLMHead
,
VocabParallelEmbedding
,
get_masked_input_and_mask
)
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.model_executor.utils
import
set_random_seed
from
.utils
import
DummyLoRAManager
from
.utils
import
DummyLoRAManager
...
@@ -427,7 +428,8 @@ def test_lm_head_logits_processor(dist_init, num_loras, device,
...
@@ -427,7 +428,8 @@ def test_lm_head_logits_processor(dist_init, num_loras, device,
logits_processor
=
LogitsProcessor
(
logits_processor
=
LogitsProcessor
(
vocab_size
+
lora_config
.
lora_extra_vocab_size
,
vocab_size
)
vocab_size
+
lora_config
.
lora_extra_vocab_size
,
vocab_size
)
lora_logits_processor
=
LogitsProcessorWithLoRA
(
lora_logits_processor
=
LogitsProcessorWithLoRA
(
logits_processor
,
1024
,
linear
.
weight
.
dtype
,
linear
.
weight
.
device
)
logits_processor
,
1024
,
linear
.
weight
.
dtype
,
linear
.
weight
.
device
,
None
)
lora_logits_processor
.
create_lora_weights
(
max_loras
,
lora_config
)
lora_logits_processor
.
create_lora_weights
(
max_loras
,
lora_config
)
return
linear
,
logits_processor
,
lora_logits_processor
return
linear
,
logits_processor
,
lora_logits_processor
...
@@ -867,3 +869,216 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
...
@@ -867,3 +869,216 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
torch
.
allclose
(
ref_q
,
actual_q
)
torch
.
allclose
(
ref_q
,
actual_q
)
torch
.
allclose
(
ref_k
,
actual_k
)
torch
.
allclose
(
ref_k
,
actual_k
)
@pytest.mark.parametrize("tp_size", [1, 2, 4, 8])
@pytest.mark.parametrize("seed", list(range(256)))
def test_vocab_parallel_embedding_indices(tp_size, seed):
    """Check the per-rank shard indices of VocabParallelEmbedding.

    A random vocab size and added-vocab size are drawn from ``seed``. One
    embedding is built per tensor-parallel rank, with the rank and world-size
    getters patched. The test then checks that the original and added ranges
    are contiguous across ranks, that the totals match the requested sizes,
    that no token id appears twice, and that the sharded-to-full mapping
    restores token ids ``0..vocab_size-1`` followed only by padding (-1).
    """
    random.seed(seed)
    vocab_size = random.randint(4000, 64000)
    added_vocab_size = random.randint(0, 1024)
    org_vocab_size = vocab_size - added_vocab_size

    # Where the previous rank's ranges ended; the next rank must start there.
    prev_org_end = 0
    prev_added_end = org_vocab_size

    # Totals accumulated over all ranks.
    total_padded = 0
    total_org = 0
    total_added = 0
    vocab_size_padded = -1

    org_tokens_seen = []
    added_tokens_seen = []
    # Token id held at each slot of the concatenated shards (-1 = padding).
    token_ids = []

    for tp_rank in range(tp_size):
        with patch(
                "vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_rank",
                return_value=tp_rank
        ), patch(
                "vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_world_size",
                return_value=tp_size):
            vocab_embedding = VocabParallelEmbedding(
                vocab_size, 1, org_num_embeddings=org_vocab_size)
        vocab_size_padded = vocab_embedding.num_embeddings_padded
        shard = vocab_embedding.shard_indices

        # Assert that the ranges are contiguous
        assert shard.org_vocab_start_index == prev_org_end
        assert shard.added_vocab_start_index == prev_added_end

        # Ensure that we are not exceeding the vocab size
        total_padded += shard.num_elements_padded
        total_org += shard.num_org_elements
        total_added += shard.num_added_elements

        org_ids = range(shard.org_vocab_start_index,
                        shard.org_vocab_end_index)
        added_ids = range(shard.added_vocab_start_index,
                          shard.added_vocab_end_index)
        org_pad = shard.num_org_elements_padded - shard.num_org_elements
        added_pad = (shard.num_added_elements_padded -
                     shard.num_added_elements)

        # Collected to check below that the ranges are not overlapping
        org_tokens_seen.extend(org_ids)
        added_tokens_seen.extend(added_ids)

        # Shard layout: org tokens, org padding, added tokens, added padding.
        token_ids.extend(org_ids)
        token_ids.extend([-1] * org_pad)
        token_ids.extend(added_ids)
        token_ids.extend([-1] * added_pad)

        prev_org_end = shard.org_vocab_end_index
        prev_added_end = shard.added_vocab_end_index

    assert total_padded == vocab_size_padded
    assert total_org == org_vocab_size
    assert total_added == added_vocab_size

    # Ensure that the ranges are not overlapping
    org_token_set = set(org_tokens_seen)
    added_token_set = set(added_tokens_seen)
    assert len(org_tokens_seen) == len(org_token_set)
    assert len(added_tokens_seen) == len(added_token_set)
    assert org_token_set.isdisjoint(added_token_set)

    token_ids_tensor = torch.tensor(token_ids, dtype=torch.long)
    reindex_mapping = vocab_embedding.get_sharded_to_full_mapping()
    # A mapping may only be absent in the single-rank case.
    assert reindex_mapping is not None or tp_size == 1
    if reindex_mapping is not None:
        reindexed_token_ids = token_ids_tensor[reindex_mapping]
        expected = torch.arange(vocab_size)
        assert reindexed_token_ids[:vocab_size].equal(expected)
        assert torch.all(reindexed_token_ids[vocab_size:] == -1)
def test_get_masked_input_and_mask():
    """Compare get_masked_input_and_mask with hand-written per-rank results.

    The same 12-token input (ids 0..11) is run through a table of shard
    configurations for tp sizes 1, 2 and 4, first with
    ``num_org_vocab_padding=0`` and then with ``num_org_vocab_padding=2``.
    Only the first return value (the masked input) is checked; the mask is
    ignored.
    """
    x = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])

    # Each row: (org_start, org_end, added_start, added_end, org_padding,
    #            expected masked input)
    cases = [
        # base tp 1 case, no padding: the input must come back unchanged
        (0, 8, 8, 12, 0, x),
        # tp 2 case, no padding
        (0, 4, 8, 10, 0,
         torch.tensor([0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 0, 0])),
        (4, 8, 10, 12, 0,
         torch.tensor([0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 4, 5])),
        # tp 4 case, no padding
        (0, 2, 8, 9, 0,
         torch.tensor([0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0])),
        (2, 4, 9, 10, 0,
         torch.tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 2, 0, 0])),
        (4, 6, 10, 11, 0,
         torch.tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 2, 0])),
        (6, 8, 11, 12, 0,
         torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2])),
        # base tp 1 case, with padding
        (0, 8, 8, 12, 2,
         torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13])),
        # tp 2 case, with padding
        (0, 4, 8, 10, 2,
         torch.tensor([0, 1, 2, 3, 0, 0, 0, 0, 6, 7, 0, 0])),
        (4, 8, 10, 12, 2,
         torch.tensor([0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 6, 7])),
        # tp 4 case, with padding
        (0, 2, 8, 9, 2,
         torch.tensor([0, 1, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0])),
        (2, 4, 9, 10, 2,
         torch.tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 4, 0, 0])),
        (4, 6, 10, 11, 2,
         torch.tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 4, 0])),
        (6, 8, 11, 12, 2,
         torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 4])),
    ]

    for (org_start, org_end, added_start, added_end, org_padding,
         expected) in cases:
        modified_x, _ = get_masked_input_and_mask(
            x,
            org_vocab_start_index=org_start,
            org_vocab_end_index=org_end,
            added_vocab_start_index=added_start,
            added_vocab_end_index=added_end,
            num_org_vocab_padding=org_padding)
        assert torch.equal(expected, modified_x)
tests/lora/test_llama.py
View file @
ccdc490d
...
@@ -36,11 +36,10 @@ def do_sample(llm, lora_path: str, lora_id: int):
...
@@ -36,11 +36,10 @@ def do_sample(llm, lora_path: str, lora_id: int):
return
generated_texts
return
generated_texts
@
pytest
.
mark
.
parametrize
(
"tp_size"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"tp_size"
,
[
1
,
2
,
4
])
def
test_llama_lora
(
sql_lora_files
,
tp_size
):
def
test_llama_lora
(
sql_lora_files
,
tp_size
,
num_gpus_available
):
# Cannot use as it will initialize torch.cuda too early...
if
num_gpus_available
<
tp_size
:
# if torch.cuda.device_count() < tp_size:
pytest
.
skip
(
f
"Not enough GPUs for tensor parallelism
{
tp_size
}
"
)
# pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
llm
=
vllm
.
LLM
(
MODEL_PATH
,
llm
=
vllm
.
LLM
(
MODEL_PATH
,
enable_lora
=
True
,
enable_lora
=
True
,
...
@@ -80,11 +79,9 @@ def test_llama_lora(sql_lora_files, tp_size):
...
@@ -80,11 +79,9 @@ def test_llama_lora(sql_lora_files, tp_size):
print
(
"removing lora"
)
print
(
"removing lora"
)
@
pytest
.
mark
.
skip
(
"Requires multiple GPUs"
)
def
test_llama_tensor_parallel_equality
(
sql_lora_files
,
num_gpus_available
):
def
test_llama_tensor_parallel_equality
(
sql_lora_files
):
if
num_gpus_available
<
4
:
# Cannot use as it will initialize torch.cuda too early...
pytest
.
skip
(
"Not enough GPUs for tensor parallelism 4"
)
# if torch.cuda.device_count() < 4:
# pytest.skip(f"Not enough GPUs for tensor parallelism {4}")
llm_tp1
=
vllm
.
LLM
(
MODEL_PATH
,
llm_tp1
=
vllm
.
LLM
(
MODEL_PATH
,
enable_lora
=
True
,
enable_lora
=
True
,
...
...
tests/lora/test_long_context.py
View file @
ccdc490d
...
@@ -102,22 +102,21 @@ def batched_generate(
...
@@ -102,22 +102,21 @@ def batched_generate(
return
[
outputs
[
i
].
outputs
[
0
].
text
.
strip
()
for
i
in
range
(
len
(
outputs
))]
return
[
outputs
[
i
].
outputs
[
0
].
text
.
strip
()
for
i
in
range
(
len
(
outputs
))]
@
pytest
.
fixture
@
pytest
.
fixture
(
scope
=
"module"
)
def
lora_llm
(
long_context_infos
):
def
lora_llm
(
long_context_infos
):
scaling_factors
=
[
scaling_factors
=
[
context_len_to_scaling_factor
[
info
[
"context_length"
]]
context_len_to_scaling_factor
[
info
[
"context_length"
]]
for
info
in
long_context_infos
.
values
()
for
info
in
long_context_infos
.
values
()
]
]
llm
=
vllm
.
LLM
(
llm
=
vllm
.
LLM
(
"meta-llama/Llama-2-13b-chat-hf"
,
"meta-llama/Llama-2-13b-chat-hf"
,
enable_lora
=
True
,
enable_lora
=
True
,
max_num_seqs
=
16
,
max_num_seqs
=
16
,
max_loras
=
2
,
max_loras
=
2
,
long_lora_scaling_factors
=
tuple
(
scaling_factors
),
long_lora_scaling_factors
=
tuple
(
scaling_factors
),
max_num_batched_tokens
=
4096
*
8
,
max_num_batched_tokens
=
4096
*
8
,
tensor_parallel_size
=
4
,
tensor_parallel_size
=
4
,
distributed_executor_backend
=
"mp"
)
)
yield
llm
yield
llm
del
llm
del
llm
...
@@ -154,6 +153,7 @@ def test_rotary_emb_replaced(dist_init):
...
@@ -154,6 +153,7 @@ def test_rotary_emb_replaced(dist_init):
assert
rotary_emb_count
==
32
assert
rotary_emb_count
==
32
@
pytest
.
mark
.
skip_global_cleanup
def
test_batched_rope_kernel
(
lora_llm
,
long_context_infos
):
def
test_batched_rope_kernel
(
lora_llm
,
long_context_infos
):
"""We test the batched kernel by comparing the results of batched an
"""We test the batched kernel by comparing the results of batched an
non-batched generation.
non-batched generation.
...
@@ -188,6 +188,7 @@ def test_batched_rope_kernel(lora_llm, long_context_infos):
...
@@ -188,6 +188,7 @@ def test_batched_rope_kernel(lora_llm, long_context_infos):
f
"same:
\n
{
batched
}
\n
{
non_batched
}
"
)
f
"same:
\n
{
batched
}
\n
{
non_batched
}
"
)
@
pytest
.
mark
.
skip_global_cleanup
def
test_self_consistency
(
lora_llm
,
long_context_infos
):
def
test_self_consistency
(
lora_llm
,
long_context_infos
):
"""We test consistency of the batched kernel by permuting batched
"""We test consistency of the batched kernel by permuting batched
inputs and comparing the results to the non-permuted batched results.
inputs and comparing the results to the non-permuted batched results.
...
@@ -227,6 +228,7 @@ def test_self_consistency(lora_llm, long_context_infos):
...
@@ -227,6 +228,7 @@ def test_self_consistency(lora_llm, long_context_infos):
f
"
\n
{
permutated_batched_results
[
permutation
[
i
]]
}
"
)
f
"
\n
{
permutated_batched_results
[
permutation
[
i
]]
}
"
)
@
pytest
.
mark
.
skip_global_cleanup
def
test_quality
(
lora_llm
,
long_context_infos
):
def
test_quality
(
lora_llm
,
long_context_infos
):
"""We test the quality of the answers given by the LoRA model by
"""We test the quality of the answers given by the LoRA model by
comparing the generated text to the merged model's outputs.
comparing the generated text to the merged model's outputs.
...
@@ -257,6 +259,7 @@ def test_quality(lora_llm, long_context_infos):
...
@@ -257,6 +259,7 @@ def test_quality(lora_llm, long_context_infos):
assert
np
.
mean
(
scores
)
>
0.5
assert
np
.
mean
(
scores
)
>
0.5
@
pytest
.
mark
.
skip_global_cleanup
def
test_max_len
(
lora_llm
,
long_context_infos
):
def
test_max_len
(
lora_llm
,
long_context_infos
):
"""Test that we raise an ValueError when the input of a given LoRA
"""Test that we raise an ValueError when the input of a given LoRA
model exceeds the maximum length."""
model exceeds the maximum length."""
...
...
tests/test_sharded_state_loader.py
View file @
ccdc490d
import
multiprocessing
as
mp
import
os
import
os
import
shutil
import
shutil
from
tempfile
import
TemporaryDirectory
from
tempfile
import
TemporaryDirectory
...
@@ -18,9 +19,7 @@ prompts = [
...
@@ -18,9 +19,7 @@ prompts = [
# Create a sampling params object.
# Create a sampling params object.
sampling_params
=
SamplingParams
(
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
temperature
=
0
,
top_p
=
0.95
,
seed
=
0
,
max_tokens
=
256
,
max_tokens
=
256
,
ignore_eos
=
True
,
ignore_eos
=
True
,
)
)
...
@@ -43,48 +42,85 @@ def test_filter_subtensors():
...
@@ -43,48 +42,85 @@ def test_filter_subtensors():
assert
tensor
.
equal
(
state_dict
[
key
])
assert
tensor
.
equal
(
state_dict
[
key
])
@pytest.fixture(scope="module")
def llama_2_7b_files():
    """Download the Llama-2-7b-hf snapshot once per test module.

    The files are fetched into a temporary cache directory, skipping anything
    matching ``*.bin*``. The snapshot path is yielded to the tests, and the
    temporary directory is removed when the fixture is torn down.
    """
    with TemporaryDirectory() as tmp_cache:
        # Yield inside the context so the download outlives every test
        # in the module and is deleted only afterwards.
        yield snapshot_download("meta-llama/Llama-2-7b-hf",
                                cache_dir=tmp_cache,
                                ignore_patterns="*.bin*")
def _run_writer(input_dir, output_dir, weights_patterns, **kwargs):
    """Load a model, dump its sharded state, and copy the metadata files.

    Args:
        input_dir: Directory holding the source checkpoint; passed to ``LLM``
            as ``model``.
        output_dir: Directory that receives the sharded state and the copied
            non-weight files.
        weights_patterns: Glob patterns (e.g. ``"*.safetensors"``) naming the
            weight files that must NOT be copied to ``output_dir``.
        **kwargs: Forwarded unchanged to the ``LLM`` constructor.
    """
    # Function-scope import keeps this block self-contained.
    import fnmatch

    llm_sharded_writer = LLM(model=input_dir, **kwargs)

    # Dump worker states to output directory
    llm_sharded_writer.llm_engine.model_executor.save_sharded_state(
        path=output_dir)

    # Copy metadata files to output directory.
    # The patterns are globs, so they must be matched with fnmatch: a plain
    # ``str.endswith("*.safetensors")`` never matches because of the literal
    # "*", which would copy the original weight files as well.
    for name in os.listdir(input_dir):
        is_weight_file = any(
            fnmatch.fnmatch(name, pattern) for pattern in weights_patterns)
        if not is_weight_file:
            shutil.copy(os.path.join(input_dir, name), output_dir)
def _run_generate(input_dir, queue: mp.Queue, **kwargs):
    """Generate completions for the module-level prompts and send them back.

    Builds an ``LLM`` from ``input_dir`` (extra ``kwargs`` are forwarded),
    runs generation with the module-level ``prompts`` and ``sampling_params``,
    and puts the ``__dict__`` of each request's first output on ``queue``.
    The queue is then closed and its feeder thread joined before returning.
    """
    engine = LLM(model=input_dir, **kwargs)
    request_outputs = engine.generate(prompts, sampling_params)

    # Only the first completion of every request is reported.
    first_outputs = [req.outputs[0].__dict__ for req in request_outputs]
    queue.put(first_outputs)

    queue.close()
    queue.join_thread()
@
pytest
.
mark
.
parametrize
(
"enable_lora"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"enable_lora"
,
[
False
,
True
])
def
test_sharded_state_loader
(
enable_lora
):
@
pytest
.
mark
.
parametrize
(
"tp_size"
,
[
1
,
2
])
weights_patterns
=
(
"*.bin"
,
"*.pt"
,
"*.safetensors"
)
def
test_sharded_state_loader
(
enable_lora
,
tp_size
,
num_gpus_available
,
llama_2_7b_files
):
if
num_gpus_available
<
tp_size
:
pytest
.
skip
(
f
"Not enough GPUs for tensor parallelism
{
tp_size
}
"
)
with
TemporaryDirectory
()
as
cache_dir
,
TemporaryDirectory
()
as
output_dir
:
weights_patterns
=
(
"*.safetensors"
,
)
input_dir
=
snapshot_download
(
"meta-llama/Llama-2-7b-hf"
,
gpu_memory_utilization
=
0.8
cache_dir
=
cache_dir
)
input_dir
=
llama_2_7b_files
ctx
=
mp
.
get_context
(
"spawn"
)
llm
=
LLM
(
model
=
input_dir
,
# Run in separate processes for memory & CUDA isolation
worker_use_ray
=
True
,
with
TemporaryDirectory
()
as
output_dir
:
gpu_memory_utilization
=
0.3
,
p
=
ctx
.
Process
(
target
=
_run_writer
,
)
args
=
(
input_dir
,
output_dir
,
weights_patterns
),
kwargs
=
dict
(
# Dump worker states to output directory
tensor_parallel_size
=
tp_size
,
model_executor
=
llm
.
llm_engine
.
model_executor
distributed_executor_backend
=
"mp"
,
model_executor
.
save_sharded_state
(
path
=
output_dir
)
gpu_memory_utilization
=
gpu_memory_utilization
,
# Copy metadata files to output directory
enforce_eager
=
True
,
for
file
in
os
.
listdir
(
input_dir
):
))
if
not
any
(
file
.
endswith
(
ext
)
for
ext
in
weights_patterns
):
p
.
start
()
shutil
.
copy
(
f
"
{
input_dir
}
/
{
file
}
"
,
output_dir
)
p
.
join
()
del
llm
.
llm_engine
.
model_executor
queue
=
ctx
.
Queue
()
llm_before
=
LLM
(
model
=
input_dir
,
p
=
ctx
.
Process
(
target
=
_run_generate
,
worker_use_ray
=
True
,
args
=
(
input_dir
,
queue
),
enable_lora
=
enable_lora
,
kwargs
=
dict
(
gpu_memory_utilization
=
0.3
,
distributed_executor_backend
=
"mp"
,
)
enable_lora
=
enable_lora
,
gen_before
=
llm_before
.
generate
(
prompts
,
sampling_params
)
gpu_memory_utilization
=
gpu_memory_utilization
,
out_before
=
[
gen
.
outputs
[
0
].
__dict__
for
gen
in
gen_before
]
tensor_parallel_size
=
tp_size
,
del
llm_before
.
llm_engine
.
model_executor
))
p
.
start
()
llm_after
=
LLM
(
p
.
join
()
model
=
output_dir
,
out_before
=
queue
.
get
()
worker_use_ray
=
True
,
enable_lora
=
enable_lora
,
p
=
ctx
.
Process
(
target
=
_run_generate
,
gpu_memory_utilization
=
0.3
,
args
=
(
output_dir
,
queue
),
load_format
=
"sharded_state"
,
kwargs
=
dict
(
)
distributed_executor_backend
=
"mp"
,
gen_after
=
llm_after
.
generate
(
prompts
,
sampling_params
)
enable_lora
=
enable_lora
,
out_after
=
[
gen
.
outputs
[
0
].
__dict__
for
gen
in
gen_after
]
gpu_memory_utilization
=
gpu_memory_utilization
,
del
llm_after
.
llm_engine
.
model_executor
tensor_parallel_size
=
tp_size
,
load_format
=
"sharded_state"
,
))
p
.
start
()
p
.
join
()
out_after
=
queue
.
get
()
assert
out_before
==
out_after
assert
out_before
==
out_after
vllm/lora/layers.py
View file @
ccdc490d
...
@@ -215,19 +215,19 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
...
@@ -215,19 +215,19 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
lora_config
:
LoRAConfig
,
lora_config
:
LoRAConfig
,
model_config
:
Optional
[
PretrainedConfig
]
=
None
)
->
None
:
model_config
:
Optional
[
PretrainedConfig
]
=
None
)
->
None
:
lora_vocab_start_idx
=
self
.
base_layer
.
org_vocab_size
if
self
.
base_layer
.
num_added_embeddings_per_partition
>
0
:
weights_idx
=
None
if
self
.
base_layer
.
vocab_end_index
>
lora_vocab_start_idx
:
# We can start adding lora weights
# We can start adding lora weights
weights_idx
=
max
(
self
.
embeddings_weights
=
self
.
base_layer
.
weight
.
data
[
lora_vocab_start_idx
-
self
.
base_layer
.
vocab_start_index
,
0
)
self
.
base_layer
.
num_org_embeddings_per_partition
:
self
.
self
.
embeddings_slice
=
(
self
.
base_layer
.
vocab_start_index
-
base_layer
.
num_org_embeddings_per_partition
+
self
.
base_layer
.
org_vocab_size
+
self
.
base_layer
.
num_added_embeddings_per_partition
]
weights_idx
,
self
.
embeddings_slice
=
(
self
.
base_layer
.
vocab_end_index
-
self
.
base_layer
.
shard_indices
.
added_vocab_start_index
-
self
.
base_layer
.
org_vocab_size
)
self
.
base_layer
.
org_vocab_size
,
self
.
embeddings_weights
=
self
.
base_layer
.
weight
.
data
[
weights_idx
:]
self
.
base_layer
.
shard_indices
.
added_vocab_end_index
-
self
.
embeddings_weights
.
fill_
(
0
)
self
.
base_layer
.
org_vocab_size
)
self
.
base_layer
.
weight
.
data
[
self
.
base_layer
.
num_org_embeddings_per_partition
:].
fill_
(
0
)
else
:
else
:
self
.
embeddings_slice
=
None
self
.
embeddings_slice
=
None
self
.
embeddings_weights
=
None
self
.
embeddings_weights
=
None
...
@@ -1025,19 +1025,31 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
...
@@ -1025,19 +1025,31 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
class
LogitsProcessorWithLoRA
(
BaseLayerWithLoRA
):
class
LogitsProcessorWithLoRA
(
BaseLayerWithLoRA
):
"""
LoRA wrapper for LogitsProcessor, with extra logic to handle the
application of the LoRA adapter and added LoRA vocabulary.
Args:
base_layer: LogitsProcessor layer
hidden_size: hidden size of the model
dtype: data type of the model
device: device of the model
sharded_to_full_mapping: index mapping from sharded vocab to full vocab
received from base_layer.get_sharded_to_full_mapping(). If None,
no reindexing will be done.
"""
def
__init__
(
def
__init__
(
self
,
base_layer
:
LogitsProcessor
,
hidden_size
:
int
,
self
,
dtype
:
torch
.
dtype
,
device
:
torch
.
device
,
base_layer
:
LogitsProcessor
,
sharded_to_full_mapping
:
Optional
[
List
[
int
]])
->
None
:
hidden_size
:
int
,
dtype
:
torch
.
dtype
,
device
:
torch
.
device
,
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
base_layer
=
base_layer
self
.
base_layer
=
base_layer
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
self
.
dtype
=
dtype
self
.
dtype
=
dtype
self
.
device
=
device
self
.
device
=
device
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
sharded_to_full_mapping
=
sharded_to_full_mapping
@
property
@
property
def
logits_as_input
(
self
):
def
logits_as_input
(
self
):
...
@@ -1098,6 +1110,13 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
...
@@ -1098,6 +1110,13 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
dtype
=
self
.
dtype
,
dtype
=
self
.
dtype
,
device
=
self
.
device
,
device
=
self
.
device
,
)
)
if
self
.
sharded_to_full_mapping
is
not
None
:
self
.
sharded_to_full_mapping_gpu
=
torch
.
tensor
(
self
.
sharded_to_full_mapping
,
device
=
self
.
device
,
dtype
=
torch
.
long
)
else
:
self
.
sharded_to_full_mapping_gpu
=
None
# Lazily initialized.
# Lazily initialized.
self
.
indices
:
torch
.
Tensor
self
.
indices
:
torch
.
Tensor
self
.
indices_len
:
List
[
int
]
self
.
indices_len
:
List
[
int
]
...
@@ -1154,6 +1173,25 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
...
@@ -1154,6 +1173,25 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
if
logits
is
None
:
if
logits
is
None
:
return
None
return
None
if
self
.
sharded_to_full_mapping_gpu
is
not
None
:
# Reindex full logits tensor to ensure 1:1 mapping between
# index and token_id
# Example for:
# org_vocab_size = 4
# added_vocab_size = 2
# pad_to_size = 8
# tp_size = 2
# indices: [0, 1, 2, 3, 4, 5, 6, 7]
# token_id: [0, 1, 4, -1, 2, 3, 5, -1]
# Therefore, the mapping is expected to be:
# [0, 1, 4, 6, 2, 3, 5, 7] so that when we reindex,
# we get:
# indices: [0, 1, 2, 3, 4, 5, 6, 7]
# token_id: [0, 1, 2, 3, 4, 5, -1, -1]
logits
=
logits
[:,
self
.
sharded_to_full_mapping_gpu
]
lora_logits
=
torch
.
empty
(
lora_logits
=
torch
.
empty
(
self
.
embeddings_tensors
.
shape
[
0
]
+
1
,
self
.
embeddings_tensors
.
shape
[
0
]
+
1
,
self
.
embeddings_tensors
.
shape
[
1
],
self
.
embeddings_tensors
.
shape
[
1
],
...
...
vllm/lora/utils.py
View file @
ccdc490d
...
@@ -67,7 +67,8 @@ def from_layer_logits_processor(
...
@@ -67,7 +67,8 @@ def from_layer_logits_processor(
model_config
:
Optional
[
PretrainedConfig
]
=
None
,
model_config
:
Optional
[
PretrainedConfig
]
=
None
,
)
->
LogitsProcessorWithLoRA
:
)
->
LogitsProcessorWithLoRA
:
ret
=
LogitsProcessorWithLoRA
(
layer
,
lm_head
.
embedding_dim
,
ret
=
LogitsProcessorWithLoRA
(
layer
,
lm_head
.
embedding_dim
,
lm_head
.
weight
.
dtype
,
lm_head
.
weight
.
device
)
lm_head
.
weight
.
dtype
,
lm_head
.
weight
.
device
,
lm_head
.
get_sharded_to_full_mapping
())
ret
.
create_lora_weights
(
max_loras
,
lora_config
,
model_config
)
ret
.
create_lora_weights
(
max_loras
,
lora_config
,
model_config
)
return
ret
return
ret
...
...
vllm/model_executor/layers/vocab_parallel_embedding.py
View file @
ccdc490d
from
typing
import
Optional
,
Sequence
from
dataclasses
import
dataclass
from
typing
import
List
,
Optional
,
Sequence
,
Tuple
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
...
@@ -18,18 +19,107 @@ def pad_vocab_size(vocab_size: int,
...
@@ -18,18 +19,107 @@ def pad_vocab_size(vocab_size: int,
return
((
vocab_size
+
pad_to
-
1
)
//
pad_to
)
*
pad_to
return
((
vocab_size
+
pad_to
-
1
)
//
pad_to
)
*
pad_to
def
vocab_range_from_per_partition_vocab_size
(
per_partition_vocab_size
:
int
,
def
vocab_range_from_per_partition_vocab_size
(
rank
:
int
)
->
Sequence
[
int
]:
per_partition_vocab_size
:
int
,
rank
:
int
,
offset
:
int
=
0
)
->
Sequence
[
int
]:
index_f
=
rank
*
per_partition_vocab_size
index_f
=
rank
*
per_partition_vocab_size
index_l
=
index_f
+
per_partition_vocab_size
index_l
=
index_f
+
per_partition_vocab_size
return
index_f
,
index_l
return
index_f
+
offset
,
index_l
+
offset
def
vocab_range_from_global_vocab_size
(
global_vocab_size
:
int
,
rank
:
int
,
def
vocab_range_from_global_vocab_size
(
global_vocab_size
:
int
,
world_size
:
int
)
->
Sequence
[
int
]:
rank
:
int
,
world_size
:
int
,
offset
:
int
=
0
)
->
Sequence
[
int
]:
per_partition_vocab_size
=
divide
(
global_vocab_size
,
world_size
)
per_partition_vocab_size
=
divide
(
global_vocab_size
,
world_size
)
return
vocab_range_from_per_partition_vocab_size
(
per_partition_vocab_size
,
return
vocab_range_from_per_partition_vocab_size
(
per_partition_vocab_size
,
rank
)
rank
,
offset
=
offset
)
@dataclass
class VocabParallelEmbeddingShardIndices:
    """Index ranges describing one shard of a vocab parallel embedding.

    Every range is stored as a ``(start, end)`` pair of fields and its
    element count is computed as ``end - start``. The ``padded_*`` fields
    describe the ranges including padding; the remaining fields describe
    the same ranges with the padding removed.
    """
    # Original (base) vocabulary range of this shard, padding included.
    padded_org_vocab_start_index: int
    padded_org_vocab_end_index: int
    # Added vocabulary range of this shard, padding included.
    padded_added_vocab_start_index: int
    padded_added_vocab_end_index: int

    # Original (base) vocabulary range of this shard, without padding.
    org_vocab_start_index: int
    org_vocab_end_index: int
    # Added vocabulary range of this shard, without padding.
    added_vocab_start_index: int
    added_vocab_end_index: int

    @property
    def num_org_elements(self) -> int:
        """Number of original-vocabulary entries, padding excluded."""
        first, last = self.org_vocab_start_index, self.org_vocab_end_index
        return last - first

    @property
    def num_added_elements(self) -> int:
        """Number of added-vocabulary entries, padding excluded."""
        first, last = self.added_vocab_start_index, self.added_vocab_end_index
        return last - first

    @property
    def num_org_elements_padded(self) -> int:
        """Number of original-vocabulary entries, padding included."""
        first = self.padded_org_vocab_start_index
        last = self.padded_org_vocab_end_index
        return last - first

    @property
    def num_added_elements_padded(self) -> int:
        """Number of added-vocabulary entries, padding included."""
        first = self.padded_added_vocab_start_index
        last = self.padded_added_vocab_end_index
        return last - first

    @property
    def num_org_vocab_padding(self) -> int:
        """Number of padding entries in the original-vocabulary range."""
        padded, real = self.num_org_elements_padded, self.num_org_elements
        return padded - real

    @property
    def num_added_vocab_padding(self) -> int:
        """Number of padding entries in the added-vocabulary range."""
        padded, real = self.num_added_elements_padded, self.num_added_elements
        return padded - real

    @property
    def num_elements_padded(self) -> int:
        """Total number of entries in the shard, padding included."""
        return self.num_org_elements_padded + self.num_added_elements_padded

    def __post_init__(self):
        # Sanity checks: every (lower, upper) pair must satisfy
        # lower <= upper, otherwise an AssertionError is raised.
        ordered_pairs = (
            # Each range is well formed.
            (self.padded_org_vocab_start_index,
             self.padded_org_vocab_end_index),
            (self.padded_added_vocab_start_index,
             self.padded_added_vocab_end_index),
            (self.org_vocab_start_index, self.org_vocab_end_index),
            (self.added_vocab_start_index, self.added_vocab_end_index),
            # Unpadded bounds never exceed their padded counterparts.
            (self.org_vocab_start_index, self.padded_org_vocab_start_index),
            (self.added_vocab_start_index,
             self.padded_added_vocab_start_index),
            (self.org_vocab_end_index, self.padded_org_vocab_end_index),
            (self.added_vocab_end_index, self.padded_added_vocab_end_index),
            # Element counts never exceed their padded counterparts.
            (self.num_org_elements, self.num_org_elements_padded),
            (self.num_added_elements, self.num_added_elements_padded),
        )
        for lower, upper in ordered_pairs:
            assert lower <= upper
@torch.jit.script
def get_masked_input_and_mask(
        input_: torch.Tensor, org_vocab_start_index: int,
        org_vocab_end_index: int, num_org_vocab_padding: int,
        added_vocab_start_index: int,
        added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Translate token ids into shard-local indices and build a mask.

    An id belongs to the shard when it lies in
    ``[org_vocab_start_index, org_vocab_end_index)`` or in
    ``[added_vocab_start_index, added_vocab_end_index)``.

    Args:
        input_: tensor of token ids.
        org_vocab_start_index: first original-vocab id held by the shard.
        org_vocab_end_index: one past the last original-vocab id held.
        num_org_vocab_padding: number of padding entries that follow the
            original-vocab entries in the shard.
        added_vocab_start_index: first added-vocab id held by the shard.
        added_vocab_end_index: one past the last added-vocab id held.

    Returns:
        A pair ``(local_ids, mask)``. For ids in the original range,
        ``local_ids`` is ``id - org_vocab_start_index``; for ids in the
        added range it is the offset into the added range, shifted past
        the original entries and their padding; for all other ids it is 0.
        ``mask`` is True exactly where the id does not belong to the shard.
    """
    # torch.jit.script will fuse all of the pointwise ops below
    # into a single kernel, making it very fast
    in_org = (input_ >= org_vocab_start_index) & (input_ <
                                                  org_vocab_end_index)
    in_added = (input_ >= added_vocab_start_index) & (
        input_ < added_vocab_end_index)
    in_shard = in_org | in_added

    # Added entries sit after the original entries and their padding.
    num_org_elements = org_vocab_end_index - org_vocab_start_index
    added_shift = (added_vocab_start_index - num_org_elements -
                   num_org_vocab_padding)

    # Per-element amount to subtract; zero for ids outside the shard.
    shift = (org_vocab_start_index * in_org) + (added_shift * in_added)
    local_ids = in_shard * (input_ - shift)
    return local_ids, ~in_shard
class
VocabParallelEmbedding
(
torch
.
nn
.
Module
):
class
VocabParallelEmbedding
(
torch
.
nn
.
Module
):
...
@@ -38,13 +128,36 @@ class VocabParallelEmbedding(torch.nn.Module):
...
@@ -38,13 +128,36 @@ class VocabParallelEmbedding(torch.nn.Module):
Adapted from torch.nn.Embedding, note that we pad the vocabulary size to
Adapted from torch.nn.Embedding, note that we pad the vocabulary size to
make sure it is divisible by the number of model parallel GPUs.
make sure it is divisible by the number of model parallel GPUs.
In order to support various loading methods, we ensure that LoRA-added
embeddings are always at the end of TP-sharded tensors. In other words,
we shard base embeddings and LoRA embeddings separately (both padded),
and place them in the same tensor.
In this example, we will have the original vocab size = 1010,
added vocab size = 16 and padding to 64. Therefore, the total
vocab size with padding will be 1088 (because we first pad 1010 to
1024, add 16, and then pad to 1088).
Therefore, the tensor format looks like the following:
TP1, rank 0 (no sharding):
|< --------BASE-------- >|< -BASE PADDING-- >|< -----LORA------ >|< -LORA PADDING-- >|
corresponding token_id: | 0 | 1 | ... | 1009 | -1 | ... | -1 | 1010 | ... | 1025 | -1 | ... | -1 |
index: | 0 | 1 | ... | 1009 | 1010 | ... | 1023 | 1024 | ... | 1039 | 1040 | ... | 1087 |
TP2, rank 0:
|< --------------------BASE--------------------- >|< -----LORA------ >|< -LORA PADDING- >|
corresponding token_id: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 1010 | ... | 1025 | -1 | ... | -1 |
index: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 512 | ... | 527 | 528 | ... | 543 |
TP2, rank 1:
|< -----------BASE----------- >|< -BASE PADDING- >|< -----------LORA PADDING----------- >|
corresponding token_id: | 512 | 513 | 514 | ... | 1009 | -1 | ... | -1 | -1 | ... | -1 | -1 | ... | -1 |
index: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 512 | ... | 527 | 528 | ... | 543 |
Args:
Args:
num_embeddings: vocabulary size.
num_embeddings: vocabulary size.
embedding_dim: size of hidden state.
embedding_dim: size of hidden state.
params_dtype: type of the parameters.
params_dtype: type of the parameters.
org_num_embeddings: original vocabulary size (without LoRA).
org_num_embeddings: original vocabulary size (without LoRA).
padding_size: padding size for the vocabulary.
padding_size: padding size for the vocabulary.
"""
"""
# noqa: E501
def
__init__
(
self
,
def
__init__
(
self
,
num_embeddings
:
int
,
num_embeddings
:
int
,
...
@@ -55,21 +168,39 @@ class VocabParallelEmbedding(torch.nn.Module):
...
@@ -55,21 +168,39 @@ class VocabParallelEmbedding(torch.nn.Module):
super
().
__init__
()
super
().
__init__
()
# Keep the input dimensions.
# Keep the input dimensions.
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
num_embeddings
=
num_embeddings
self
.
num_embeddings
=
num_embeddings
self
.
padding_size
=
padding_size
self
.
org_vocab_size
=
org_num_embeddings
or
num_embeddings
self
.
org_vocab_size
=
org_num_embeddings
or
num_embeddings
self
.
num_embeddings_padded
=
pad_vocab_size
(
num_embeddings
,
num_added_embeddings
=
num_embeddings
-
self
.
org_vocab_size
padding_size
)
self
.
org_vocab_size_padded
=
pad_vocab_size
(
self
.
org_vocab_size
,
self
.
padding_size
)
self
.
num_embeddings_padded
=
pad_vocab_size
(
self
.
org_vocab_size_padded
+
num_added_embeddings
,
self
.
padding_size
)
assert
self
.
org_vocab_size_padded
<=
self
.
num_embeddings_padded
self
.
shard_indices
=
self
.
_get_indices
(
self
.
num_embeddings_padded
,
self
.
org_vocab_size_padded
,
self
.
num_embeddings
,
self
.
org_vocab_size
,
tp_rank
,
self
.
tp_size
)
self
.
embedding_dim
=
embedding_dim
self
.
embedding_dim
=
embedding_dim
if
params_dtype
is
None
:
if
params_dtype
is
None
:
params_dtype
=
torch
.
get_default_dtype
()
params_dtype
=
torch
.
get_default_dtype
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
# Divide the weight matrix along the vocabulary dimension.
# Divide the weight matrix along the vocabulary dimension.
self
.
vocab_start_index
,
self
.
vocab_end_index
=
(
self
.
num_added_embeddings
=
self
.
num_embeddings
-
self
.
org_vocab_size
vocab_range_from_global_vocab_size
(
self
.
num_embeddings_per_partition
=
divide
(
self
.
num_embeddings_padded
,
self
.
num_embeddings_padded
,
get_tensor_model_parallel_rank
(),
self
.
tp_size
)
self
.
tp_size
))
assert
(
self
.
shard_indices
.
num_elements_padded
==
self
.
num_embeddings_per_partition
=
(
self
.
vocab_end_index
-
self
.
num_embeddings_per_partition
)
self
.
vocab_start_index
)
self
.
num_org_embeddings_per_partition
=
(
self
.
shard_indices
.
org_vocab_end_index
-
self
.
shard_indices
.
org_vocab_start_index
)
self
.
num_added_embeddings_per_partition
=
(
self
.
shard_indices
.
added_vocab_end_index
-
self
.
shard_indices
.
added_vocab_start_index
)
self
.
weight
=
Parameter
(
self
.
weight
=
Parameter
(
torch
.
empty
(
self
.
num_embeddings_per_partition
,
torch
.
empty
(
self
.
num_embeddings_per_partition
,
self
.
embedding_dim
,
self
.
embedding_dim
,
...
@@ -79,28 +210,107 @@ class VocabParallelEmbedding(torch.nn.Module):
...
@@ -79,28 +210,107 @@ class VocabParallelEmbedding(torch.nn.Module):
"weight_loader"
:
self
.
weight_loader
"weight_loader"
:
self
.
weight_loader
})
})
@classmethod
def _get_indices(cls, vocab_size_padded: int, org_vocab_size_padded: int,
                 vocab_size: int, org_vocab_size: int, tp_rank: int,
                 tp_size: int) -> VocabParallelEmbeddingShardIndices:
    """Compute the start and end indices of one tensor-parallel shard.

    The original and the added vocabulary are split across ranks
    separately, each over its padded size. The added range is offset by
    ``org_vocab_size``. The unpadded bounds are obtained by clipping the
    padded bounds to ``org_vocab_size`` (original range) and to
    ``vocab_size`` (added range).

    Args:
        vocab_size_padded: total vocabulary size including padding.
        org_vocab_size_padded: original vocabulary size including padding.
        vocab_size: total vocabulary size without padding.
        org_vocab_size: original vocabulary size without padding.
        tp_rank: rank whose shard is being described.
        tp_size: number of ranks the vocabulary is split across.

    Returns:
        The padded and unpadded index ranges for ``tp_rank``.
    """
    added_size_padded = vocab_size_padded - org_vocab_size_padded

    # Padded ranges owned by this rank.
    padded_org_first, padded_org_last = vocab_range_from_global_vocab_size(
        org_vocab_size_padded, tp_rank, tp_size)
    padded_added_first, padded_added_last = (
        vocab_range_from_global_vocab_size(added_size_padded,
                                           tp_rank,
                                           tp_size,
                                           offset=org_vocab_size))

    # remove padding
    return VocabParallelEmbeddingShardIndices(
        padded_org_vocab_start_index=padded_org_first,
        padded_org_vocab_end_index=padded_org_last,
        padded_added_vocab_start_index=padded_added_first,
        padded_added_vocab_end_index=padded_added_last,
        org_vocab_start_index=min(padded_org_first, org_vocab_size),
        org_vocab_end_index=min(padded_org_last, org_vocab_size),
        added_vocab_start_index=min(padded_added_first, vocab_size),
        added_vocab_end_index=min(padded_added_last, vocab_size),
    )
def get_sharded_to_full_mapping(self) -> Optional[List[int]]:
    """Build the index list used to reorder gathered logits.

    The returned list holds positions in the concatenation of all ranks'
    shards, ordered as: every original-vocabulary position (rank by
    rank), then every added-vocabulary position, then every padding
    position. Indexing the gathered tensor with this list therefore
    places the real entries first and the padding last.

    Returns:
        The list of positions, or None when ``tp_size`` is less than 2.
    """
    if self.tp_size < 2:
        return None

    base_slots: List[int] = []
    added_slots: List[int] = []
    padding_slots: List[int] = []
    shard_width = self.num_embeddings_per_partition

    for rank in range(self.tp_size):
        shard = self._get_indices(self.num_embeddings_padded,
                                  self.org_vocab_size_padded,
                                  self.num_embeddings, self.org_vocab_size,
                                  rank, self.tp_size)

        # Boundaries of the four consecutive regions inside this shard:
        # original entries, their padding, added entries, their padding.
        shard_begin = shard_width * rank
        org_stop = shard_begin + shard.num_org_elements
        org_padded_stop = shard_begin + shard.num_org_elements_padded
        added_stop = org_padded_stop + shard.num_added_elements
        added_padded_stop = (org_padded_stop +
                             shard.num_added_elements_padded)

        base_slots += range(shard_begin, org_stop)
        padding_slots += range(org_stop, org_padded_stop)
        added_slots += range(org_padded_stop, added_stop)
        padding_slots += range(added_stop, added_padded_stop)

        # The four regions must fill the shard exactly.
        assert added_padded_stop == shard_width * (rank + 1)

    mapping = base_slots + added_slots + padding_slots
    assert len(mapping) == self.num_embeddings_padded
    return mapping
def
weight_loader
(
self
,
param
:
Parameter
,
loaded_weight
:
torch
.
Tensor
):
def
weight_loader
(
self
,
param
:
Parameter
,
loaded_weight
:
torch
.
Tensor
):
parallel_dim
=
param
.
parallel_dim
parallel_dim
=
param
.
parallel_dim
assert
loaded_weight
.
shape
[
parallel_dim
]
==
self
.
org_vocab_size
assert
loaded_weight
.
shape
[
parallel_dim
]
==
self
.
org_vocab_size
loaded_weight
=
loaded_weight
[
self
.
vocab_start_index
:
self
.
loaded_weight
=
loaded_weight
[
self
.
shard_indices
.
org_
vocab_start_index
:
vocab_end_index
]
self
.
shard_indices
.
org_
vocab_end_index
]
param
[:
loaded_weight
.
shape
[
0
]].
data
.
copy_
(
loaded_weight
)
param
[:
loaded_weight
.
shape
[
0
]].
data
.
copy_
(
loaded_weight
)
param
[
loaded_weight
.
shape
[
0
]:].
data
.
fill_
(
0
)
def
forward
(
self
,
input_
):
def
forward
(
self
,
input_
):
if
self
.
tp_size
>
1
:
if
self
.
tp_size
>
1
:
# Build the mask.
# Build the mask.
input_mask
=
((
input_
<
self
.
vocab_start_index
)
|
masked_input
,
input_mask
=
get_masked_input_and_mask
(
(
input_
>=
self
.
vocab_end_index
))
input_
,
self
.
shard_indices
.
org_vocab_start_index
,
# Mask the input.
self
.
shard_indices
.
org_vocab_end_index
,
masked_input
=
input_
.
clone
()
-
self
.
vocab_start_index
self
.
shard_indices
.
num_org_vocab_padding
,
masked_input
[
input_mask
]
=
0
self
.
shard_indices
.
added_vocab_start_index
,
self
.
shard_indices
.
added_vocab_end_index
)
else
:
else
:
masked_input
=
input_
masked_input
=
input_
# Get the embeddings.
# Get the embeddings.
output_parallel
=
F
.
embedding
(
masked_input
,
self
.
weight
)
output_parallel
=
F
.
embedding
(
masked_input
,
self
.
weight
)
# Mask the output embedding.
# Mask the output embedding.
if
self
.
tp_size
>
1
:
if
self
.
tp_size
>
1
:
output_parallel
[
input_mask
,
:]
=
0.0
output_parallel
.
masked_fill_
(
input_mask
.
unsqueeze
(
1
),
0
)
# Reduce across all the model parallel GPUs.
# Reduce across all the model parallel GPUs.
output
=
tensor_model_parallel_all_reduce
(
output_parallel
)
output
=
tensor_model_parallel_all_reduce
(
output_parallel
)
return
output
return
output
...
...
vllm/worker/model_runner.py
View file @
ccdc490d
...
@@ -35,6 +35,7 @@ _BATCH_SIZE_ALIGNMENT = 8
...
@@ -35,6 +35,7 @@ _BATCH_SIZE_ALIGNMENT = 8
_BATCH_SIZES_TO_CAPTURE
=
[
1
,
2
,
4
]
+
[
_BATCH_SIZES_TO_CAPTURE
=
[
1
,
2
,
4
]
+
[
_BATCH_SIZE_ALIGNMENT
*
i
for
i
in
range
(
1
,
33
)
_BATCH_SIZE_ALIGNMENT
*
i
for
i
in
range
(
1
,
33
)
]
]
_NUM_WARMUP_ITERS
=
2
class
ModelInput
(
NamedTuple
):
class
ModelInput
(
NamedTuple
):
...
@@ -975,16 +976,18 @@ class CUDAGraphRunner:
...
@@ -975,16 +976,18 @@ class CUDAGraphRunner:
**
kwargs
,
**
kwargs
,
)
->
None
:
)
->
None
:
assert
self
.
_graph
is
None
assert
self
.
_graph
is
None
# Run the model
once
without capturing the graph.
# Run the model
a few times
without capturing the graph.
# This is to make sure that the captured graph does not include the
# This is to make sure that the captured graph does not include the
# kernel launches for initial benchmarking (e.g., Triton autotune).
# kernel launches for initial benchmarking (e.g., Triton autotune).
self
.
model
(
# Note one iteration is not enough for torch.jit.script
input_ids
,
for
_
in
range
(
_NUM_WARMUP_ITERS
):
positions
,
self
.
model
(
kv_caches
,
input_ids
,
attn_metadata
,
positions
,
**
kwargs
,
kv_caches
,
)
attn_metadata
,
**
kwargs
,
)
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
# Capture the graph.
# Capture the graph.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment