Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ccdc490d
Unverified
Commit
ccdc490d
authored
Jun 06, 2024
by
Antoni Baum
Committed by
GitHub
Jun 06, 2024
Browse files
[Core] Change LoRA embedding sharding to support loading methods (#5038)
parent
a31cab75
Changes
11
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
661 additions
and
129 deletions
+661
-129
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+2
-8
tests/conftest.py
tests/conftest.py
+21
-0
tests/lora/conftest.py
tests/lora/conftest.py
+16
-2
tests/lora/test_layers.py
tests/lora/test_layers.py
+217
-2
tests/lora/test_llama.py
tests/lora/test_llama.py
+7
-10
tests/lora/test_long_context.py
tests/lora/test_long_context.py
+13
-10
tests/test_sharded_state_loader.py
tests/test_sharded_state_loader.py
+80
-44
vllm/lora/layers.py
vllm/lora/layers.py
+57
-19
vllm/lora/utils.py
vllm/lora/utils.py
+2
-1
vllm/model_executor/layers/vocab_parallel_embedding.py
vllm/model_executor/layers/vocab_parallel_embedding.py
+235
-25
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+11
-8
No files found.
.buildkite/test-pipeline.yaml
View file @
ccdc490d
...
...
@@ -46,6 +46,7 @@ steps:
-
TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
-
TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
-
pytest -v -s spec_decode/e2e/test_integration_dist.py
-
CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
-
label
:
Distributed Tests (Multiple Groups)
#mirror_hardwares: [amd]
...
...
@@ -138,14 +139,7 @@ steps:
num_gpus
:
4
# This test runs llama 13B, so it is required to run on 4 GPUs.
commands
:
# Temporarily run this way because we cannot clean up GPU mem usage
# for multi GPU tests.
# TODO(sang): Fix it.
-
pytest -v -s lora/test_long_context.py::test_rotary_emb_replaced
-
pytest -v -s lora/test_long_context.py::test_batched_rope_kernel
-
pytest -v -s lora/test_long_context.py::test_self_consistency
-
pytest -v -s lora/test_long_context.py::test_quality
-
pytest -v -s lora/test_long_context.py::test_max_len
-
pytest -v -s -x lora/test_long_context.py
-
label
:
Tensorizer Test
#mirror_hardwares: [amd]
...
...
tests/conftest.py
View file @
ccdc490d
import
contextlib
import
gc
import
os
import
subprocess
import
sys
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
TypeVar
import
pytest
...
...
@@ -522,3 +524,22 @@ def caplog_vllm(temporary_enable_log_propagate, caplog):
# To capture vllm log, we should enable propagate=True temporarily
# because caplog depends on logs propagated to the root logger.
yield
caplog
@
pytest
.
fixture
(
scope
=
"session"
)
def
num_gpus_available
():
"""Get number of GPUs without initializing the CUDA context
in current process."""
try
:
out
=
subprocess
.
run
([
sys
.
executable
,
"-c"
,
"import torch; print(torch.cuda.device_count())"
],
capture_output
=
True
,
check
=
True
,
text
=
True
)
except
subprocess
.
CalledProcessError
as
e
:
logger
.
warning
(
"Failed to get number of GPUs."
,
exc_info
=
e
)
return
0
return
int
(
out
.
stdout
.
strip
())
tests/lora/conftest.py
View file @
ccdc490d
...
...
@@ -42,9 +42,23 @@ def cleanup():
ray
.
shutdown
()
@
pytest
.
fixture
()
def
should_do_global_cleanup_after_test
(
request
)
->
bool
:
"""Allow subdirectories to skip global cleanup by overriding this fixture.
This can provide a ~10x speedup for non-GPU unit tests since they don't need
to initialize torch.
"""
if
request
.
node
.
get_closest_marker
(
"skip_global_cleanup"
):
return
False
return
True
@
pytest
.
fixture
(
autouse
=
True
)
def
cleanup_fixture
():
def
cleanup_fixture
(
should_do_global_cleanup_after_test
:
bool
):
yield
if
should_do_global_cleanup_after_test
:
cleanup
()
...
...
tests/lora/test_layers.py
View file @
ccdc490d
...
...
@@ -2,6 +2,7 @@ import random
from
copy
import
deepcopy
from
dataclasses
import
dataclass
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
unittest.mock
import
patch
import
pytest
import
torch
...
...
@@ -32,7 +33,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
ParallelLMHead
,
VocabParallelEmbedding
,
get_masked_input_and_mask
)
from
vllm.model_executor.utils
import
set_random_seed
from
.utils
import
DummyLoRAManager
...
...
@@ -427,7 +428,8 @@ def test_lm_head_logits_processor(dist_init, num_loras, device,
logits_processor
=
LogitsProcessor
(
vocab_size
+
lora_config
.
lora_extra_vocab_size
,
vocab_size
)
lora_logits_processor
=
LogitsProcessorWithLoRA
(
logits_processor
,
1024
,
linear
.
weight
.
dtype
,
linear
.
weight
.
device
)
logits_processor
,
1024
,
linear
.
weight
.
dtype
,
linear
.
weight
.
device
,
None
)
lora_logits_processor
.
create_lora_weights
(
max_loras
,
lora_config
)
return
linear
,
logits_processor
,
lora_logits_processor
...
...
@@ -867,3 +869,216 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
torch
.
allclose
(
ref_q
,
actual_q
)
torch
.
allclose
(
ref_k
,
actual_k
)
@
pytest
.
mark
.
parametrize
(
"tp_size"
,
[
1
,
2
,
4
,
8
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
list
(
range
(
256
)))
def
test_vocab_parallel_embedding_indices
(
tp_size
,
seed
):
random
.
seed
(
seed
)
vocab_size
=
random
.
randint
(
4000
,
64000
)
added_vocab_size
=
random
.
randint
(
0
,
1024
)
org_vocab_size
=
vocab_size
-
added_vocab_size
last_org_vocab_end_index
=
0
last_added_vocab_end_index
=
org_vocab_size
computed_vocab_size
=
0
computed_org_vocab_size
=
0
computed_added_vocab_size
=
0
vocab_size_padded
=
-
1
all_org_tokens
=
[]
all_added_tokens
=
[]
token_ids
=
[]
for
tp_rank
in
range
(
tp_size
):
with
patch
(
"vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_rank"
,
return_value
=
tp_rank
),
patch
(
"vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_world_size"
,
return_value
=
tp_size
):
vocab_embedding
=
VocabParallelEmbedding
(
vocab_size
,
1
,
org_num_embeddings
=
org_vocab_size
)
vocab_size_padded
=
vocab_embedding
.
num_embeddings_padded
shard_indices
=
vocab_embedding
.
shard_indices
# Assert that the ranges are contiguous
assert
shard_indices
.
org_vocab_start_index
==
last_org_vocab_end_index
assert
(
shard_indices
.
added_vocab_start_index
==
last_added_vocab_end_index
)
# Ensure that we are not exceeding the vocab size
computed_vocab_size
+=
shard_indices
.
num_elements_padded
computed_org_vocab_size
+=
shard_indices
.
num_org_elements
computed_added_vocab_size
+=
shard_indices
.
num_added_elements
# Ensure that the ranges are not overlapping
all_org_tokens
.
extend
(
range
(
shard_indices
.
org_vocab_start_index
,
shard_indices
.
org_vocab_end_index
))
all_added_tokens
.
extend
(
range
(
shard_indices
.
added_vocab_start_index
,
shard_indices
.
added_vocab_end_index
))
token_ids
.
extend
(
range
(
shard_indices
.
org_vocab_start_index
,
shard_indices
.
org_vocab_end_index
))
token_ids
.
extend
([
-
1
]
*
(
shard_indices
.
num_org_elements_padded
-
shard_indices
.
num_org_elements
))
token_ids
.
extend
(
range
(
shard_indices
.
added_vocab_start_index
,
shard_indices
.
added_vocab_end_index
))
token_ids
.
extend
([
-
1
]
*
(
shard_indices
.
num_added_elements_padded
-
shard_indices
.
num_added_elements
))
last_org_vocab_end_index
=
shard_indices
.
org_vocab_end_index
last_added_vocab_end_index
=
shard_indices
.
added_vocab_end_index
assert
computed_vocab_size
==
vocab_size_padded
assert
computed_org_vocab_size
==
org_vocab_size
assert
computed_added_vocab_size
==
added_vocab_size
# Ensure that the ranges are not overlapping
assert
len
(
all_org_tokens
)
==
len
(
set
(
all_org_tokens
))
assert
len
(
all_added_tokens
)
==
len
(
set
(
all_added_tokens
))
assert
not
set
(
all_org_tokens
).
intersection
(
set
(
all_added_tokens
))
token_ids_tensor
=
torch
.
tensor
(
token_ids
,
dtype
=
torch
.
long
)
reindex_mapping
=
vocab_embedding
.
get_sharded_to_full_mapping
()
assert
reindex_mapping
is
not
None
or
tp_size
==
1
if
reindex_mapping
is
not
None
:
reindexed_token_ids
=
token_ids_tensor
[
reindex_mapping
]
expected
=
torch
.
tensor
(
list
(
range
(
0
,
vocab_size
)))
assert
reindexed_token_ids
[:
vocab_size
].
equal
(
expected
)
assert
torch
.
all
(
reindexed_token_ids
[
vocab_size
:]
==
-
1
)
def
test_get_masked_input_and_mask
():
x
=
torch
.
tensor
([
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
])
# base tp 1 case, no padding
modified_x
,
_
=
get_masked_input_and_mask
(
x
,
org_vocab_start_index
=
0
,
org_vocab_end_index
=
8
,
added_vocab_start_index
=
8
,
added_vocab_end_index
=
12
,
num_org_vocab_padding
=
0
)
assert
torch
.
equal
(
x
,
modified_x
)
# tp 2 case, no padding
modified_x_rank_0
,
_
=
get_masked_input_and_mask
(
x
,
org_vocab_start_index
=
0
,
org_vocab_end_index
=
4
,
added_vocab_start_index
=
8
,
added_vocab_end_index
=
10
,
num_org_vocab_padding
=
0
)
modified_x_rank_1
,
_
=
get_masked_input_and_mask
(
x
,
org_vocab_start_index
=
4
,
org_vocab_end_index
=
8
,
added_vocab_start_index
=
10
,
added_vocab_end_index
=
12
,
num_org_vocab_padding
=
0
)
assert
torch
.
equal
(
modified_x_rank_0
,
torch
.
tensor
([
0
,
1
,
2
,
3
,
0
,
0
,
0
,
0
,
4
,
5
,
0
,
0
]))
assert
torch
.
equal
(
modified_x_rank_1
,
torch
.
tensor
([
0
,
0
,
0
,
0
,
0
,
1
,
2
,
3
,
0
,
0
,
4
,
5
]))
# tp 4 case, no padding
modified_x_rank_0
,
_
=
get_masked_input_and_mask
(
x
,
org_vocab_start_index
=
0
,
org_vocab_end_index
=
2
,
added_vocab_start_index
=
8
,
added_vocab_end_index
=
9
,
num_org_vocab_padding
=
0
)
modified_x_rank_1
,
_
=
get_masked_input_and_mask
(
x
,
org_vocab_start_index
=
2
,
org_vocab_end_index
=
4
,
added_vocab_start_index
=
9
,
added_vocab_end_index
=
10
,
num_org_vocab_padding
=
0
)
modified_x_rank_2
,
_
=
get_masked_input_and_mask
(
x
,
org_vocab_start_index
=
4
,
org_vocab_end_index
=
6
,
added_vocab_start_index
=
10
,
added_vocab_end_index
=
11
,
num_org_vocab_padding
=
0
)
modified_x_rank_3
,
_
=
get_masked_input_and_mask
(
x
,
org_vocab_start_index
=
6
,
org_vocab_end_index
=
8
,
added_vocab_start_index
=
11
,
added_vocab_end_index
=
12
,
num_org_vocab_padding
=
0
)
assert
torch
.
equal
(
modified_x_rank_0
,
torch
.
tensor
([
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
]))
assert
torch
.
equal
(
modified_x_rank_1
,
torch
.
tensor
([
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
]))
assert
torch
.
equal
(
modified_x_rank_2
,
torch
.
tensor
([
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
2
,
0
]))
assert
torch
.
equal
(
modified_x_rank_3
,
torch
.
tensor
([
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
2
]))
# base tp 1 case, with padding
modified_x
,
_
=
get_masked_input_and_mask
(
x
,
org_vocab_start_index
=
0
,
org_vocab_end_index
=
8
,
added_vocab_start_index
=
8
,
added_vocab_end_index
=
12
,
num_org_vocab_padding
=
2
)
assert
torch
.
equal
(
modified_x
,
torch
.
tensor
([
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
10
,
11
,
12
,
13
]))
# tp 2 case, with padding
modified_x_rank_0
,
_
=
get_masked_input_and_mask
(
x
,
org_vocab_start_index
=
0
,
org_vocab_end_index
=
4
,
added_vocab_start_index
=
8
,
added_vocab_end_index
=
10
,
num_org_vocab_padding
=
2
)
modified_x_rank_1
,
_
=
get_masked_input_and_mask
(
x
,
org_vocab_start_index
=
4
,
org_vocab_end_index
=
8
,
added_vocab_start_index
=
10
,
added_vocab_end_index
=
12
,
num_org_vocab_padding
=
2
)
assert
torch
.
equal
(
modified_x_rank_0
,
torch
.
tensor
([
0
,
1
,
2
,
3
,
0
,
0
,
0
,
0
,
6
,
7
,
0
,
0
]))
assert
torch
.
equal
(
modified_x_rank_1
,
torch
.
tensor
([
0
,
0
,
0
,
0
,
0
,
1
,
2
,
3
,
0
,
0
,
6
,
7
]))
# tp 4 case, with padding
modified_x_rank_0
,
_
=
get_masked_input_and_mask
(
x
,
org_vocab_start_index
=
0
,
org_vocab_end_index
=
2
,
added_vocab_start_index
=
8
,
added_vocab_end_index
=
9
,
num_org_vocab_padding
=
2
)
modified_x_rank_1
,
_
=
get_masked_input_and_mask
(
x
,
org_vocab_start_index
=
2
,
org_vocab_end_index
=
4
,
added_vocab_start_index
=
9
,
added_vocab_end_index
=
10
,
num_org_vocab_padding
=
2
)
modified_x_rank_2
,
_
=
get_masked_input_and_mask
(
x
,
org_vocab_start_index
=
4
,
org_vocab_end_index
=
6
,
added_vocab_start_index
=
10
,
added_vocab_end_index
=
11
,
num_org_vocab_padding
=
2
)
modified_x_rank_3
,
_
=
get_masked_input_and_mask
(
x
,
org_vocab_start_index
=
6
,
org_vocab_end_index
=
8
,
added_vocab_start_index
=
11
,
added_vocab_end_index
=
12
,
num_org_vocab_padding
=
2
)
assert
torch
.
equal
(
modified_x_rank_0
,
torch
.
tensor
([
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
4
,
0
,
0
,
0
]))
assert
torch
.
equal
(
modified_x_rank_1
,
torch
.
tensor
([
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
4
,
0
,
0
]))
assert
torch
.
equal
(
modified_x_rank_2
,
torch
.
tensor
([
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
4
,
0
]))
assert
torch
.
equal
(
modified_x_rank_3
,
torch
.
tensor
([
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
4
]))
tests/lora/test_llama.py
View file @
ccdc490d
...
...
@@ -36,11 +36,10 @@ def do_sample(llm, lora_path: str, lora_id: int):
return
generated_texts
@
pytest
.
mark
.
parametrize
(
"tp_size"
,
[
1
])
def
test_llama_lora
(
sql_lora_files
,
tp_size
):
# Cannot use as it will initialize torch.cuda too early...
# if torch.cuda.device_count() < tp_size:
# pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
@
pytest
.
mark
.
parametrize
(
"tp_size"
,
[
1
,
2
,
4
])
def
test_llama_lora
(
sql_lora_files
,
tp_size
,
num_gpus_available
):
if
num_gpus_available
<
tp_size
:
pytest
.
skip
(
f
"Not enough GPUs for tensor parallelism
{
tp_size
}
"
)
llm
=
vllm
.
LLM
(
MODEL_PATH
,
enable_lora
=
True
,
...
...
@@ -80,11 +79,9 @@ def test_llama_lora(sql_lora_files, tp_size):
print
(
"removing lora"
)
@
pytest
.
mark
.
skip
(
"Requires multiple GPUs"
)
def
test_llama_tensor_parallel_equality
(
sql_lora_files
):
# Cannot use as it will initialize torch.cuda too early...
# if torch.cuda.device_count() < 4:
# pytest.skip(f"Not enough GPUs for tensor parallelism {4}")
def
test_llama_tensor_parallel_equality
(
sql_lora_files
,
num_gpus_available
):
if
num_gpus_available
<
4
:
pytest
.
skip
(
"Not enough GPUs for tensor parallelism 4"
)
llm_tp1
=
vllm
.
LLM
(
MODEL_PATH
,
enable_lora
=
True
,
...
...
tests/lora/test_long_context.py
View file @
ccdc490d
...
...
@@ -102,22 +102,21 @@ def batched_generate(
return
[
outputs
[
i
].
outputs
[
0
].
text
.
strip
()
for
i
in
range
(
len
(
outputs
))]
@
pytest
.
fixture
@
pytest
.
fixture
(
scope
=
"module"
)
def
lora_llm
(
long_context_infos
):
scaling_factors
=
[
context_len_to_scaling_factor
[
info
[
"context_length"
]]
for
info
in
long_context_infos
.
values
()
]
llm
=
vllm
.
LLM
(
"meta-llama/Llama-2-13b-chat-hf"
,
llm
=
vllm
.
LLM
(
"meta-llama/Llama-2-13b-chat-hf"
,
enable_lora
=
True
,
max_num_seqs
=
16
,
max_loras
=
2
,
long_lora_scaling_factors
=
tuple
(
scaling_factors
),
max_num_batched_tokens
=
4096
*
8
,
tensor_parallel_size
=
4
,
)
distributed_executor_backend
=
"mp"
)
yield
llm
del
llm
...
...
@@ -154,6 +153,7 @@ def test_rotary_emb_replaced(dist_init):
assert
rotary_emb_count
==
32
@
pytest
.
mark
.
skip_global_cleanup
def
test_batched_rope_kernel
(
lora_llm
,
long_context_infos
):
"""We test the batched kernel by comparing the results of batched an
non-batched generation.
...
...
@@ -188,6 +188,7 @@ def test_batched_rope_kernel(lora_llm, long_context_infos):
f
"same:
\n
{
batched
}
\n
{
non_batched
}
"
)
@
pytest
.
mark
.
skip_global_cleanup
def
test_self_consistency
(
lora_llm
,
long_context_infos
):
"""We test consistency of the batched kernel by permuting batched
inputs and comparing the results to the non-permuted batched results.
...
...
@@ -227,6 +228,7 @@ def test_self_consistency(lora_llm, long_context_infos):
f
"
\n
{
permutated_batched_results
[
permutation
[
i
]]
}
"
)
@
pytest
.
mark
.
skip_global_cleanup
def
test_quality
(
lora_llm
,
long_context_infos
):
"""We test the quality of the answers given by the LoRA model by
comparing the generated text to the merged model's outputs.
...
...
@@ -257,6 +259,7 @@ def test_quality(lora_llm, long_context_infos):
assert
np
.
mean
(
scores
)
>
0.5
@
pytest
.
mark
.
skip_global_cleanup
def
test_max_len
(
lora_llm
,
long_context_infos
):
"""Test that we raise an ValueError when the input of a given LoRA
model exceeds the maximum length."""
...
...
tests/test_sharded_state_loader.py
View file @
ccdc490d
import
multiprocessing
as
mp
import
os
import
shutil
from
tempfile
import
TemporaryDirectory
...
...
@@ -18,9 +19,7 @@ prompts = [
# Create a sampling params object.
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
top_p
=
0.95
,
seed
=
0
,
temperature
=
0
,
max_tokens
=
256
,
ignore_eos
=
True
,
)
...
...
@@ -43,48 +42,85 @@ def test_filter_subtensors():
assert
tensor
.
equal
(
state_dict
[
key
])
@
pytest
.
mark
.
parametrize
(
"enable_lora"
,
[
False
,
True
])
def
test_sharded_state_loader
(
enable_lora
):
weights_patterns
=
(
"*.bin"
,
"*.pt"
,
"*.safetensors"
)
with
TemporaryDirectory
()
as
cache_dir
,
TemporaryDirectory
()
as
output_dir
:
@
pytest
.
fixture
(
scope
=
"module"
)
def
llama_2_7b_files
():
with
TemporaryDirectory
()
as
cache_dir
:
input_dir
=
snapshot_download
(
"meta-llama/Llama-2-7b-hf"
,
cache_dir
=
cache_dir
)
cache_dir
=
cache_dir
,
ignore_patterns
=
"*.bin*"
)
yield
input_dir
llm
=
LLM
(
model
=
input_dir
,
worker_use_ray
=
True
,
gpu_memory_utilization
=
0.3
,
)
def
_run_writer
(
input_dir
,
output_dir
,
weights_patterns
,
**
kwargs
):
llm_sharded_writer
=
LLM
(
model
=
input_dir
,
**
kwargs
)
# Dump worker states to output directory
model_executor
=
llm
.
llm_engine
.
model_executor
model_executor
.
save_sharded_state
(
path
=
output_dir
)
llm_sharded_writer
.
llm_engine
.
model_executor
.
save_sharded_state
(
path
=
output_dir
)
# Copy metadata files to output directory
for
file
in
os
.
listdir
(
input_dir
):
if
not
any
(
file
.
endswith
(
ext
)
for
ext
in
weights_patterns
):
shutil
.
copy
(
f
"
{
input_dir
}
/
{
file
}
"
,
output_dir
)
del
llm
.
llm_engine
.
model_executor
llm_before
=
LLM
(
model
=
input_dir
,
worker_use_ray
=
True
,
def
_run_generate
(
input_dir
,
queue
:
mp
.
Queue
,
**
kwargs
):
llm
=
LLM
(
model
=
input_dir
,
**
kwargs
)
gen
=
llm
.
generate
(
prompts
,
sampling_params
)
queue
.
put
([
g
.
outputs
[
0
].
__dict__
for
g
in
gen
])
queue
.
close
()
queue
.
join_thread
()
@
pytest
.
mark
.
parametrize
(
"enable_lora"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"tp_size"
,
[
1
,
2
])
def
test_sharded_state_loader
(
enable_lora
,
tp_size
,
num_gpus_available
,
llama_2_7b_files
):
if
num_gpus_available
<
tp_size
:
pytest
.
skip
(
f
"Not enough GPUs for tensor parallelism
{
tp_size
}
"
)
weights_patterns
=
(
"*.safetensors"
,
)
gpu_memory_utilization
=
0.8
input_dir
=
llama_2_7b_files
ctx
=
mp
.
get_context
(
"spawn"
)
# Run in separate processes for memory & CUDA isolation
with
TemporaryDirectory
()
as
output_dir
:
p
=
ctx
.
Process
(
target
=
_run_writer
,
args
=
(
input_dir
,
output_dir
,
weights_patterns
),
kwargs
=
dict
(
tensor_parallel_size
=
tp_size
,
distributed_executor_backend
=
"mp"
,
gpu_memory_utilization
=
gpu_memory_utilization
,
enforce_eager
=
True
,
))
p
.
start
()
p
.
join
()
queue
=
ctx
.
Queue
()
p
=
ctx
.
Process
(
target
=
_run_generate
,
args
=
(
input_dir
,
queue
),
kwargs
=
dict
(
distributed_executor_backend
=
"mp"
,
enable_lora
=
enable_lora
,
gpu_memory_utilization
=
0.3
,
)
gen_before
=
llm_before
.
generate
(
prompts
,
sampling_params
)
out_before
=
[
gen
.
outputs
[
0
].
__dict__
for
gen
in
gen_before
]
del
llm_before
.
llm_engine
.
model_executor
llm_after
=
LLM
(
model
=
output_dir
,
worker_use_ray
=
True
,
gpu_memory_utilization
=
gpu_memory_utilization
,
tensor_parallel_size
=
tp_size
,
))
p
.
start
()
p
.
join
()
out_before
=
queue
.
get
()
p
=
ctx
.
Process
(
target
=
_run_generate
,
args
=
(
output_dir
,
queue
),
kwargs
=
dict
(
distributed_executor_backend
=
"mp"
,
enable_lora
=
enable_lora
,
gpu_memory_utilization
=
0.3
,
gpu_memory_utilization
=
gpu_memory_utilization
,
tensor_parallel_size
=
tp_size
,
load_format
=
"sharded_state"
,
)
gen_after
=
llm_after
.
generate
(
prompts
,
sampling_params
)
out_after
=
[
gen
.
outputs
[
0
].
__dict__
for
gen
in
gen_after
]
del
llm_after
.
llm_engine
.
model_executor
)
)
p
.
start
(
)
p
.
join
()
out_after
=
queue
.
get
()
assert
out_before
==
out_after
vllm/lora/layers.py
View file @
ccdc490d
...
...
@@ -215,19 +215,19 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
lora_config
:
LoRAConfig
,
model_config
:
Optional
[
PretrainedConfig
]
=
None
)
->
None
:
lora_vocab_start_idx
=
self
.
base_layer
.
org_vocab_size
weights_idx
=
None
if
self
.
base_layer
.
vocab_end_index
>
lora_vocab_start_idx
:
if
self
.
base_layer
.
num_added_embeddings_per_partition
>
0
:
# We can start adding lora weights
weights_idx
=
max
(
lora_vocab_start_idx
-
self
.
base_layer
.
vocab_start_index
,
0
)
self
.
embeddings_slice
=
(
self
.
base_layer
.
vocab_start_index
-
self
.
base_layer
.
org_vocab_size
+
weights_idx
,
self
.
base_layer
.
vocab_end_index
-
self
.
embeddings_weights
=
self
.
base_layer
.
weight
.
data
[
self
.
base_layer
.
num_org_embeddings_per_partition
:
self
.
base_layer
.
num_org_embeddings_per_partition
+
self
.
base_layer
.
num_added_embeddings_per_partition
]
self
.
embeddings_slice
=
(
self
.
base_layer
.
shard_indices
.
added_vocab_start_index
-
self
.
base_layer
.
org_vocab_size
,
self
.
base_layer
.
shard_indices
.
added_vocab_end_index
-
self
.
base_layer
.
org_vocab_size
)
self
.
embeddings_weights
=
self
.
base_layer
.
weight
.
data
[
weights_idx
:]
self
.
embeddings_weights
.
fill_
(
0
)
self
.
base_layer
.
weight
.
data
[
self
.
base_layer
.
num_org_embeddings_per_partition
:]
.
fill_
(
0
)
else
:
self
.
embeddings_slice
=
None
self
.
embeddings_weights
=
None
...
...
@@ -1025,19 +1025,31 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
class
LogitsProcessorWithLoRA
(
BaseLayerWithLoRA
):
"""
LoRA wrapper for LogitsProcessor, with extra logic to handle the
application of the LoRA adapter and added LoRA vocabulary.
def
__init__
(
self
,
base_layer
:
LogitsProcessor
,
hidden_size
:
int
,
dtype
:
torch
.
dtype
,
device
:
torch
.
device
,
)
->
None
:
Args:
base_layer: LogitsProcessor layer
hidden_size: hidden size of the model
dtype: data type of the model
device: device of the model
sharded_to_full_mapping: index mapping from sharded vocab to full vocab
received from base_layer.get_sharded_to_full_mapping(). If None,
no reindexing will be done.
"""
def
__init__
(
self
,
base_layer
:
LogitsProcessor
,
hidden_size
:
int
,
dtype
:
torch
.
dtype
,
device
:
torch
.
device
,
sharded_to_full_mapping
:
Optional
[
List
[
int
]])
->
None
:
super
().
__init__
()
self
.
base_layer
=
base_layer
self
.
hidden_size
=
hidden_size
self
.
dtype
=
dtype
self
.
device
=
device
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
sharded_to_full_mapping
=
sharded_to_full_mapping
@
property
def
logits_as_input
(
self
):
...
...
@@ -1098,6 +1110,13 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
dtype
=
self
.
dtype
,
device
=
self
.
device
,
)
if
self
.
sharded_to_full_mapping
is
not
None
:
self
.
sharded_to_full_mapping_gpu
=
torch
.
tensor
(
self
.
sharded_to_full_mapping
,
device
=
self
.
device
,
dtype
=
torch
.
long
)
else
:
self
.
sharded_to_full_mapping_gpu
=
None
# Lazily initialized.
self
.
indices
:
torch
.
Tensor
self
.
indices_len
:
List
[
int
]
...
...
@@ -1154,6 +1173,25 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
if
logits
is
None
:
return
None
if
self
.
sharded_to_full_mapping_gpu
is
not
None
:
# Reindex full logits tensor to ensure 1:1 mapping between
# index and token_id
# Example for:
# org_vocab_size = 4
# added_vocab_size = 2
# pad_to_size = 8
# tp_size = 2
# indices: [0, 1, 2, 3, 4, 5, 6, 7]
# token_id: [0, 1, 4, -1, 2, 3, 5, -1]
# Therefore, the mapping is expected to be:
# [0, 1, 4, 6, 2, 3, 5, 7] so that when we reindex,
# we get:
# indices: [0, 1, 2, 3, 4, 5, 6, 7]
# token_id: [0, 1, 2, 3, 4, 5, -1, -1]
logits
=
logits
[:,
self
.
sharded_to_full_mapping_gpu
]
lora_logits
=
torch
.
empty
(
self
.
embeddings_tensors
.
shape
[
0
]
+
1
,
self
.
embeddings_tensors
.
shape
[
1
],
...
...
vllm/lora/utils.py
View file @
ccdc490d
...
...
@@ -67,7 +67,8 @@ def from_layer_logits_processor(
model_config
:
Optional
[
PretrainedConfig
]
=
None
,
)
->
LogitsProcessorWithLoRA
:
ret
=
LogitsProcessorWithLoRA
(
layer
,
lm_head
.
embedding_dim
,
lm_head
.
weight
.
dtype
,
lm_head
.
weight
.
device
)
lm_head
.
weight
.
dtype
,
lm_head
.
weight
.
device
,
lm_head
.
get_sharded_to_full_mapping
())
ret
.
create_lora_weights
(
max_loras
,
lora_config
,
model_config
)
return
ret
...
...
vllm/model_executor/layers/vocab_parallel_embedding.py
View file @
ccdc490d
from
typing
import
Optional
,
Sequence
from
dataclasses
import
dataclass
from
typing
import
List
,
Optional
,
Sequence
,
Tuple
import
torch
import
torch.nn.functional
as
F
...
...
@@ -18,18 +19,107 @@ def pad_vocab_size(vocab_size: int,
return
((
vocab_size
+
pad_to
-
1
)
//
pad_to
)
*
pad_to
def
vocab_range_from_per_partition_vocab_size
(
per_partition_vocab_size
:
int
,
rank
:
int
)
->
Sequence
[
int
]:
def
vocab_range_from_per_partition_vocab_size
(
per_partition_vocab_size
:
int
,
rank
:
int
,
offset
:
int
=
0
)
->
Sequence
[
int
]:
index_f
=
rank
*
per_partition_vocab_size
index_l
=
index_f
+
per_partition_vocab_size
return
index_f
,
index_l
return
index_f
+
offset
,
index_l
+
offset
def
vocab_range_from_global_vocab_size
(
global_vocab_size
:
int
,
rank
:
int
,
world_size
:
int
)
->
Sequence
[
int
]:
def
vocab_range_from_global_vocab_size
(
global_vocab_size
:
int
,
rank
:
int
,
world_size
:
int
,
offset
:
int
=
0
)
->
Sequence
[
int
]:
per_partition_vocab_size
=
divide
(
global_vocab_size
,
world_size
)
return
vocab_range_from_per_partition_vocab_size
(
per_partition_vocab_size
,
rank
)
rank
,
offset
=
offset
)
@
dataclass
class
VocabParallelEmbeddingShardIndices
:
"""Indices for a shard of a vocab parallel embedding."""
padded_org_vocab_start_index
:
int
padded_org_vocab_end_index
:
int
padded_added_vocab_start_index
:
int
padded_added_vocab_end_index
:
int
org_vocab_start_index
:
int
org_vocab_end_index
:
int
added_vocab_start_index
:
int
added_vocab_end_index
:
int
@
property
def
num_org_elements
(
self
)
->
int
:
return
self
.
org_vocab_end_index
-
self
.
org_vocab_start_index
@
property
def
num_added_elements
(
self
)
->
int
:
return
self
.
added_vocab_end_index
-
self
.
added_vocab_start_index
@
property
def
num_org_elements_padded
(
self
)
->
int
:
return
(
self
.
padded_org_vocab_end_index
-
self
.
padded_org_vocab_start_index
)
@
property
def
num_added_elements_padded
(
self
)
->
int
:
return
(
self
.
padded_added_vocab_end_index
-
self
.
padded_added_vocab_start_index
)
@
property
def
num_org_vocab_padding
(
self
)
->
int
:
return
self
.
num_org_elements_padded
-
self
.
num_org_elements
@
property
def
num_added_vocab_padding
(
self
)
->
int
:
return
self
.
num_added_elements_padded
-
self
.
num_added_elements
@
property
def
num_elements_padded
(
self
)
->
int
:
return
self
.
num_org_elements_padded
+
self
.
num_added_elements_padded
def
__post_init__
(
self
):
# sanity checks
assert
(
self
.
padded_org_vocab_start_index
<=
self
.
padded_org_vocab_end_index
)
assert
(
self
.
padded_added_vocab_start_index
<=
self
.
padded_added_vocab_end_index
)
assert
self
.
org_vocab_start_index
<=
self
.
org_vocab_end_index
assert
self
.
added_vocab_start_index
<=
self
.
added_vocab_end_index
assert
self
.
org_vocab_start_index
<=
self
.
padded_org_vocab_start_index
assert
(
self
.
added_vocab_start_index
<=
self
.
padded_added_vocab_start_index
)
assert
self
.
org_vocab_end_index
<=
self
.
padded_org_vocab_end_index
assert
self
.
added_vocab_end_index
<=
self
.
padded_added_vocab_end_index
assert
self
.
num_org_elements
<=
self
.
num_org_elements_padded
assert
self
.
num_added_elements
<=
self
.
num_added_elements_padded
@
torch
.
jit
.
script
def
get_masked_input_and_mask
(
input_
:
torch
.
Tensor
,
org_vocab_start_index
:
int
,
org_vocab_end_index
:
int
,
num_org_vocab_padding
:
int
,
added_vocab_start_index
:
int
,
added_vocab_end_index
:
int
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# torch.jit.script will fuse all of the pointwise ops below
# into a single kernel, making it very fast
org_vocab_mask
=
(
input_
>=
org_vocab_start_index
)
&
(
input_
<
org_vocab_end_index
)
added_vocab_mask
=
(
input_
>=
added_vocab_start_index
)
&
(
input_
<
added_vocab_end_index
)
added_offset
=
added_vocab_start_index
-
(
org_vocab_end_index
-
org_vocab_start_index
)
-
num_org_vocab_padding
valid_offset
=
(
org_vocab_start_index
*
org_vocab_mask
)
+
(
added_offset
*
added_vocab_mask
)
vocab_mask
=
org_vocab_mask
|
added_vocab_mask
input_
=
vocab_mask
*
(
input_
-
valid_offset
)
return
input_
,
~
vocab_mask
class
VocabParallelEmbedding
(
torch
.
nn
.
Module
):
...
...
@@ -38,13 +128,36 @@ class VocabParallelEmbedding(torch.nn.Module):
Adapted from torch.nn.Embedding, note that we pad the vocabulary size to
make sure it is divisible by the number of model parallel GPUs.
In order to support various loading methods, we ensure that LoRA-added
embeddings are always at the end of TP-sharded tensors. In other words,
we shard base embeddings and LoRA embeddings separately (both padded),
and place them in the same tensor.
In this example, we will have the original vocab size = 1010,
added vocab size = 16 and padding to 64. Therefore, the total
vocab size with padding will be 1088 (because we first pad 1010 to
1024, add 16, and then pad to 1088).
Therefore, the tensor format looks like the following:
TP1, rank 0 (no sharding):
|< --------BASE-------- >|< -BASE PADDING-- >|< -----LORA------ >|< -LORA PADDING-- >|
corresponding token_id: | 0 | 1 | ... | 1009 | -1 | ... | -1 | 1010 | ... | 1015 | -1 | ... | -1 |
index: | 0 | 1 | ... | 1009 | 1010 | ... | 1023 | 1024 | ... | 1039 | 1040 | ... | 1087 |
TP2, rank 0:
|< --------------------BASE--------------------- >|< -----LORA------ >|< -LORA PADDING- >|
corresponding token_id: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 1000 | ... | 1015 | -1 | ... | -1 |
index: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 512 | ... | 527 | 520 | ... | 543 |
TP2, rank 1:
|< -----------BASE----------- >|< -BASE PADDING- >|< -----------LORA PADDING----------- >|
corresponding token_id: | 512 | 513 | 514 | ... | 1009 | -1 | ... | -1 | -1 | ... | -1 | -1 | ... | -1 |
index: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 512 | ... | 519 | 520 | ... | 543 |
Args:
num_embeddings: vocabulary size.
embedding_dim: size of hidden state.
params_dtype: type of the parameters.
org_num_embeddings: original vocabulary size (without LoRA).
padding_size: padding size for the vocabulary.
"""
"""
# noqa: E501
def
__init__
(
self
,
num_embeddings
:
int
,
...
...
@@ -55,21 +168,39 @@ class VocabParallelEmbedding(torch.nn.Module):
super
().
__init__
()
# Keep the input dimensions.
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
num_embeddings
=
num_embeddings
self
.
padding_size
=
padding_size
self
.
org_vocab_size
=
org_num_embeddings
or
num_embeddings
self
.
num_embeddings_padded
=
pad_vocab_size
(
num_embeddings
,
padding_size
)
num_added_embeddings
=
num_embeddings
-
self
.
org_vocab_size
self
.
org_vocab_size_padded
=
pad_vocab_size
(
self
.
org_vocab_size
,
self
.
padding_size
)
self
.
num_embeddings_padded
=
pad_vocab_size
(
self
.
org_vocab_size_padded
+
num_added_embeddings
,
self
.
padding_size
)
assert
self
.
org_vocab_size_padded
<=
self
.
num_embeddings_padded
self
.
shard_indices
=
self
.
_get_indices
(
self
.
num_embeddings_padded
,
self
.
org_vocab_size_padded
,
self
.
num_embeddings
,
self
.
org_vocab_size
,
tp_rank
,
self
.
tp_size
)
self
.
embedding_dim
=
embedding_dim
if
params_dtype
is
None
:
params_dtype
=
torch
.
get_default_dtype
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
# Divide the weight matrix along the vocaburaly dimension.
self
.
vocab_start_index
,
self
.
vocab_end_index
=
(
vocab_range_from_global_vocab_size
(
self
.
num_embeddings_padded
,
get_tensor_model_parallel_rank
(),
self
.
tp_size
))
self
.
num_embeddings_per_partition
=
(
self
.
vocab_end_index
-
self
.
vocab_start_index
)
self
.
num_added_embeddings
=
self
.
num_embeddings
-
self
.
org_vocab_size
self
.
num_embeddings_per_partition
=
divide
(
self
.
num_embeddings_padded
,
self
.
tp_size
)
assert
(
self
.
shard_indices
.
num_elements_padded
==
self
.
num_embeddings_per_partition
)
self
.
num_org_embeddings_per_partition
=
(
self
.
shard_indices
.
org_vocab_end_index
-
self
.
shard_indices
.
org_vocab_start_index
)
self
.
num_added_embeddings_per_partition
=
(
self
.
shard_indices
.
added_vocab_end_index
-
self
.
shard_indices
.
added_vocab_start_index
)
self
.
weight
=
Parameter
(
torch
.
empty
(
self
.
num_embeddings_per_partition
,
self
.
embedding_dim
,
...
...
@@ -79,28 +210,107 @@ class VocabParallelEmbedding(torch.nn.Module):
"weight_loader"
:
self
.
weight_loader
})
@
classmethod
def
_get_indices
(
cls
,
vocab_size_padded
:
int
,
org_vocab_size_padded
:
int
,
vocab_size
:
int
,
org_vocab_size
:
int
,
tp_rank
:
int
,
tp_size
:
int
)
->
VocabParallelEmbeddingShardIndices
:
"""Get start and end indices for vocab parallel embedding, following the
layout outlined in the class docstring, based on the given tp_rank and
tp_size."""
num_added_embeddings_padded
=
vocab_size_padded
-
org_vocab_size_padded
padded_org_vocab_start_index
,
padded_org_vocab_end_index
=
(
vocab_range_from_global_vocab_size
(
org_vocab_size_padded
,
tp_rank
,
tp_size
))
padded_added_vocab_start_index
,
padded_added_vocab_end_index
=
(
vocab_range_from_global_vocab_size
(
num_added_embeddings_padded
,
tp_rank
,
tp_size
,
offset
=
org_vocab_size
))
# remove padding
org_vocab_start_index
=
min
(
padded_org_vocab_start_index
,
org_vocab_size
)
org_vocab_end_index
=
min
(
padded_org_vocab_end_index
,
org_vocab_size
)
added_vocab_start_index
=
min
(
padded_added_vocab_start_index
,
vocab_size
)
added_vocab_end_index
=
min
(
padded_added_vocab_end_index
,
vocab_size
)
return
VocabParallelEmbeddingShardIndices
(
padded_org_vocab_start_index
,
padded_org_vocab_end_index
,
padded_added_vocab_start_index
,
padded_added_vocab_end_index
,
org_vocab_start_index
,
org_vocab_end_index
,
added_vocab_start_index
,
added_vocab_end_index
)
def
get_sharded_to_full_mapping
(
self
)
->
Optional
[
List
[
int
]]:
"""Get a mapping that can be used to reindex the gathered
logits for sampling.
During sampling, we gather logits from all ranks. The relationship
of index->token_id will follow the same format as outlined in the class
docstring. However, after the gather, we want to reindex the final
logits tensor to map index->token_id one-to-one (the index is always
equal the token_id it corresponds to). The indices returned by this
method allow us to do that.
"""
if
self
.
tp_size
<
2
:
return
None
base_embeddings
:
List
[
int
]
=
[]
added_embeddings
:
List
[
int
]
=
[]
padding
:
List
[
int
]
=
[]
for
tp_rank
in
range
(
self
.
tp_size
):
shard_indices
=
self
.
_get_indices
(
self
.
num_embeddings_padded
,
self
.
org_vocab_size_padded
,
self
.
num_embeddings
,
self
.
org_vocab_size
,
tp_rank
,
self
.
tp_size
)
range_start
=
self
.
num_embeddings_per_partition
*
tp_rank
range_end
=
self
.
num_embeddings_per_partition
*
(
tp_rank
+
1
)
base_embeddings
.
extend
(
range
(
range_start
,
range_start
+
shard_indices
.
num_org_elements
))
padding
.
extend
(
range
(
range_start
+
shard_indices
.
num_org_elements
,
range_start
+
shard_indices
.
num_org_elements_padded
))
added_embeddings
.
extend
(
range
(
range_start
+
shard_indices
.
num_org_elements_padded
,
range_start
+
shard_indices
.
num_org_elements_padded
+
shard_indices
.
num_added_elements
))
padding
.
extend
(
range
(
range_start
+
shard_indices
.
num_org_elements_padded
+
shard_indices
.
num_added_elements
,
range_start
+
shard_indices
.
num_org_elements_padded
+
shard_indices
.
num_added_elements_padded
))
assert
(
range_start
+
shard_indices
.
num_org_elements_padded
+
shard_indices
.
num_added_elements_padded
==
range_end
)
ret
=
base_embeddings
+
added_embeddings
+
padding
assert
len
(
ret
)
==
self
.
num_embeddings_padded
return
ret
def
weight_loader
(
self
,
param
:
Parameter
,
loaded_weight
:
torch
.
Tensor
):
parallel_dim
=
param
.
parallel_dim
assert
loaded_weight
.
shape
[
parallel_dim
]
==
self
.
org_vocab_size
loaded_weight
=
loaded_weight
[
self
.
vocab_start_index
:
self
.
vocab_end_index
]
loaded_weight
=
loaded_weight
[
self
.
shard_indices
.
org_
vocab_start_index
:
self
.
shard_indices
.
org_
vocab_end_index
]
param
[:
loaded_weight
.
shape
[
0
]].
data
.
copy_
(
loaded_weight
)
param
[
loaded_weight
.
shape
[
0
]:].
data
.
fill_
(
0
)
def
forward
(
self
,
input_
):
if
self
.
tp_size
>
1
:
# Build the mask.
input_mask
=
((
input_
<
self
.
vocab_start_index
)
|
(
input_
>=
self
.
vocab_end_index
))
# Mask the input.
masked_input
=
input_
.
clone
()
-
self
.
vocab_start_index
masked_input
[
input_mask
]
=
0
masked_input
,
input_mask
=
get_masked_input_and_mask
(
input_
,
self
.
shard_indices
.
org_vocab_start_index
,
self
.
shard_indices
.
org_vocab_end_index
,
self
.
shard_indices
.
num_org_vocab_padding
,
self
.
shard_indices
.
added_vocab_start_index
,
self
.
shard_indices
.
added_vocab_end_index
)
else
:
masked_input
=
input_
# Get the embeddings.
output_parallel
=
F
.
embedding
(
masked_input
,
self
.
weight
)
# Mask the output embedding.
if
self
.
tp_size
>
1
:
output_parallel
[
input_mask
,
:]
=
0.0
output_parallel
.
masked_fill_
(
input_mask
.
unsqueeze
(
1
),
0
)
# Reduce across all the model parallel GPUs.
output
=
tensor_model_parallel_all_reduce
(
output_parallel
)
return
output
...
...
vllm/worker/model_runner.py
View file @
ccdc490d
...
...
@@ -35,6 +35,7 @@ _BATCH_SIZE_ALIGNMENT = 8
_BATCH_SIZES_TO_CAPTURE
=
[
1
,
2
,
4
]
+
[
_BATCH_SIZE_ALIGNMENT
*
i
for
i
in
range
(
1
,
33
)
]
_NUM_WARMUP_ITERS
=
2
class
ModelInput
(
NamedTuple
):
...
...
@@ -975,9 +976,11 @@ class CUDAGraphRunner:
**
kwargs
,
)
->
None
:
assert
self
.
_graph
is
None
# Run the model
once
without capturing the graph.
# Run the model
a few times
without capturing the graph.
# This is to make sure that the captured graph does not include the
# kernel launches for initial benchmarking (e.g., Triton autotune).
# Note one iteration is not enough for torch.jit.script
for
_
in
range
(
_NUM_WARMUP_ITERS
):
self
.
model
(
input_ids
,
positions
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment