Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ff7ec82c
Unverified
Commit
ff7ec82c
authored
Aug 18, 2024
by
SangBin Cho
Committed by
GitHub
Aug 18, 2024
Browse files
[Core] Optimize SPMD architecture with delta + serialization optimization (#7109)
parent
200a2ffa
Changes
36
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
256 additions
and
84 deletions
+256
-84
requirements-common.txt
requirements-common.txt
+1
-0
tests/basic_correctness/test_preemption.py
tests/basic_correctness/test_preemption.py
+18
-0
tests/core/test_serialization.py
tests/core/test_serialization.py
+33
-0
tests/distributed/test_basic_distributed_correctness.py
tests/distributed/test_basic_distributed_correctness.py
+2
-1
tests/distributed/test_chunked_prefill_distributed.py
tests/distributed/test_chunked_prefill_distributed.py
+7
-0
tests/samplers/test_sampler.py
tests/samplers/test_sampler.py
+19
-6
tests/spec_decode/utils.py
tests/spec_decode/utils.py
+6
-3
tests/test_logits_processor.py
tests/test_logits_processor.py
+6
-2
tests/test_sequence.py
tests/test_sequence.py
+5
-2
tests/worker/test_encoder_decoder_model_runner.py
tests/worker/test_encoder_decoder_model_runner.py
+11
-5
tests/worker/test_model_runner.py
tests/worker/test_model_runner.py
+10
-5
vllm/adapter_commons/request.py
vllm/adapter_commons/request.py
+0
-2
vllm/config.py
vllm/config.py
+9
-2
vllm/core/scheduler.py
vllm/core/scheduler.py
+50
-38
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+3
-0
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+0
-1
vllm/executor/msgspec_utils.py
vllm/executor/msgspec_utils.py
+27
-0
vllm/executor/ray_gpu_executor.py
vllm/executor/ray_gpu_executor.py
+15
-4
vllm/executor/ray_utils.py
vllm/executor/ray_utils.py
+27
-12
vllm/inputs/registry.py
vllm/inputs/registry.py
+7
-1
No files found.
requirements-common.txt
View file @
ff7ec82c
...
...
@@ -21,6 +21,7 @@ outlines >= 0.0.43, < 0.1 # Requires torch >= 2.1.0
typing_extensions >= 4.10
filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
pyzmq
msgspec
librosa # Required for audio processing
soundfile # Required for audio processing
gguf == 0.9.1
...
...
tests/basic_correctness/test_preemption.py
View file @
ff7ec82c
...
...
@@ -8,6 +8,7 @@ pytest tests/basic_correctness/test_preemption.py`.
import
pytest
from
prometheus_client
import
REGISTRY
import
vllm.envs
as
envs
from
vllm
import
SamplingParams
from
vllm.core.scheduler
import
(
ARTIFICIAL_PREEMPTION_MAX_CNT
,
ENABLE_ARTIFICIAL_PREEMPT
)
...
...
@@ -24,6 +25,13 @@ assert ENABLE_ARTIFICIAL_PREEMPT is True, (
"tests/basic_correctness/test_preemption.py`"
)
@pytest.fixture
def worker_use_ray() -> bool:
    """Return the ``worker_use_ray`` value the preemption tests should use.

    The value mirrors ``envs.VLLM_USE_RAY_SPMD_WORKER``: when the SPMD worker
    is used, the tests run with ``worker_use_ray=True`` so that the delta
    input optimization is exercised together with preemption.
    """
    spmd_worker_enabled = envs.VLLM_USE_RAY_SPMD_WORKER
    return spmd_worker_enabled
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
96
])
...
...
@@ -36,6 +44,7 @@ def test_chunked_prefill_recompute(
dtype
:
str
,
max_tokens
:
int
,
chunked_prefill_token_size
:
int
,
worker_use_ray
:
bool
,
)
->
None
:
"""Ensure that chunked prefill works with preemption."""
max_num_seqs
=
min
(
chunked_prefill_token_size
,
256
)
...
...
@@ -54,6 +63,7 @@ def test_chunked_prefill_recompute(
max_num_batched_tokens
=
max_num_batched_tokens
,
enable_chunked_prefill
=
enable_chunked_prefill
,
max_num_seqs
=
max_num_seqs
,
worker_use_ray
=
worker_use_ray
,
)
as
vllm_model
:
vllm_outputs
=
vllm_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
assert
(
vllm_model
.
model
.
llm_engine
.
scheduler
[
0
].
artificial_preempt_cnt
...
...
@@ -79,6 +89,7 @@ def test_preemption(
model
:
str
,
dtype
:
str
,
max_tokens
:
int
,
worker_use_ray
:
bool
,
)
->
None
:
"""By default, recompute preemption is enabled"""
...
...
@@ -89,6 +100,7 @@ def test_preemption(
model
,
dtype
=
dtype
,
disable_log_stats
=
False
,
worker_use_ray
=
worker_use_ray
,
)
as
vllm_model
:
vllm_outputs
=
vllm_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
assert
(
vllm_model
.
model
.
llm_engine
.
scheduler
[
0
].
artificial_preempt_cnt
...
...
@@ -132,6 +144,7 @@ def test_swap(
dtype
:
str
,
max_tokens
:
int
,
beam_width
:
int
,
worker_use_ray
:
bool
,
)
->
None
:
"""Use beam search enables swapping."""
example_prompts
=
example_prompts
[:
1
]
...
...
@@ -144,6 +157,7 @@ def test_swap(
dtype
=
dtype
,
swap_space
=
10
,
disable_log_stats
=
False
,
worker_use_ray
=
worker_use_ray
,
)
as
vllm_model
:
vllm_outputs
=
vllm_model
.
generate_beam_search
(
example_prompts
,
beam_width
,
max_tokens
)
...
...
@@ -188,6 +202,7 @@ def test_swap_infeasible(
dtype
:
str
,
max_tokens
:
int
,
beam_width
:
int
,
worker_use_ray
:
bool
,
)
->
None
:
"""Verify infeasible swap request will be ignored."""
BLOCK_SIZE
=
16
...
...
@@ -204,6 +219,7 @@ def test_swap_infeasible(
# decode blocks are not enough to finish.
num_gpu_blocks_override
=
prefill_blocks
+
decode_blocks
,
max_model_len
=
(
prefill_blocks
+
decode_blocks
)
*
BLOCK_SIZE
,
worker_use_ray
=
worker_use_ray
,
)
as
vllm_model
:
sampling_params
=
SamplingParams
(
n
=
beam_width
,
use_beam_search
=
True
,
...
...
@@ -230,6 +246,7 @@ def test_preemption_infeasible(
model
:
str
,
dtype
:
str
,
max_tokens
:
int
,
worker_use_ray
:
bool
,
)
->
None
:
"""Verify infeasible preemption request will be ignored."""
BLOCK_SIZE
=
16
...
...
@@ -244,6 +261,7 @@ def test_preemption_infeasible(
# ignored instead of hanging forever.
num_gpu_blocks_override
=
prefill_blocks
+
decode_blocks
//
2
,
max_model_len
=
((
prefill_blocks
+
decode_blocks
//
2
)
*
BLOCK_SIZE
),
worker_use_ray
=
worker_use_ray
,
)
as
vllm_model
:
sampling_params
=
SamplingParams
(
max_tokens
=
max_tokens
,
ignore_eos
=
True
)
...
...
tests/core/test_serialization.py
0 → 100644
View file @
ff7ec82c
import
msgspec
from
vllm.executor.msgspec_utils
import
decode_hook
,
encode_hook
from
vllm.sequence
import
ExecuteModelRequest
from
..spec_decode.utils
import
create_batch
def test_msgspec_serialization():
    """Check that an ExecuteModelRequest survives a msgpack round trip.

    The request is encoded and decoded with the vLLM msgspec hooks, then the
    first sequence group of the decoded request is compared field by field
    against the first sequence group of the original request.
    """
    num_lookahead_slots = 4
    seq_group_metadata_list, _, _ = create_batch(16, num_lookahead_slots)
    original_req = ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list,
        num_lookahead_slots=num_lookahead_slots,
        running_queue_size=4,
    )

    encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)
    decoder = msgspec.msgpack.Decoder(ExecuteModelRequest,
                                      dec_hook=decode_hook)
    payload = encoder.encode(original_req)
    restored_req = decoder.decode(payload)

    sent_groups = original_req.seq_group_metadata_list
    received_groups = restored_req.seq_group_metadata_list
    assert len(sent_groups) == len(received_groups)

    # Only the first sequence group is inspected in detail.
    sent, received = sent_groups[0], received_groups[0]
    for field_name in ("block_tables", "is_prompt", "request_id"):
        assert getattr(sent, field_name) == getattr(received, field_name)

    sent_seq, received_seq = sent.seq_data[0], received.seq_data[0]
    assert sent_seq.prompt_token_ids == received_seq.prompt_token_ids
    assert sent_seq.output_token_ids == received_seq.output_token_ids
tests/distributed/test_basic_distributed_correctness.py
View file @
ff7ec82c
...
...
@@ -22,7 +22,8 @@ TARGET_TEST_SUITE = os.environ.get("TARGET_TEST_SUITE", "L4")
@
pytest
.
mark
.
skipif
(
cuda_device_count_stateless
()
<
2
,
reason
=
"Need at least 2 GPUs to run the test."
)
@
pytest
.
mark
.
parametrize
(
"model, distributed_executor_backend, attention_backend, test_suite"
,
[
"model, distributed_executor_backend, attention_backend, "
"test_suite"
,
[
(
"facebook/opt-125m"
,
"ray"
,
""
,
"L4"
),
(
"facebook/opt-125m"
,
"mp"
,
""
,
"L4"
),
(
"meta-llama/Llama-2-7b-hf"
,
"ray"
,
""
,
"L4"
),
...
...
tests/distributed/test_chunked_prefill_distributed.py
View file @
ff7ec82c
...
...
@@ -6,6 +6,8 @@ pytest test_chunked_prefill_distributed.py
```
"""
import
os
import
pytest
from
vllm.utils
import
cuda_device_count_stateless
...
...
@@ -30,6 +32,11 @@ def test_models(
model
:
str
,
distributed_executor_backend
:
str
,
)
->
None
:
if
model
==
"meta-llama/Llama-2-7b-hf"
and
distributed_executor_backend
==
"ray"
:
# noqa
assert
distributed_executor_backend
==
"ray"
# test ray adag
os
.
environ
[
'VLLM_USE_RAY_SPMD_WORKER'
]
=
"1"
os
.
environ
[
'VLLM_USE_RAY_COMPILED_DAG'
]
=
"1"
dtype
=
"half"
max_tokens
=
5
...
...
tests/samplers/test_sampler.py
View file @
ff7ec82c
import
itertools
import
random
from
array
import
array
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
unittest.mock
import
Mock
,
patch
...
...
@@ -10,7 +11,8 @@ from transformers import GenerationConfig, GenerationMixin
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.sequence
import
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
)
from
vllm.utils
import
Counter
,
is_pin_memory_available
...
...
@@ -56,7 +58,9 @@ def _do_sample(
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
True
,
seq_data
=
{
0
:
SequenceData
([
1
,
2
,
3
])},
seq_data
=
{
0
:
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
1
,
2
,
3
]))
},
sampling_params
=
sampling_params
,
block_tables
=
{
0
:
[
1
]},
))
...
...
@@ -201,7 +205,8 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
def
create_sequence_data
(
num_input
=
3
,
num_generated
=
0
):
seq_data
=
SequenceData
(
random
.
choices
(
range
(
0
,
VOCAB_SIZE
),
k
=
num_input
))
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
random
.
choices
(
range
(
0
,
VOCAB_SIZE
),
k
=
num_input
)))
if
num_generated
>
0
:
seq_data
.
output_token_ids
=
random
.
choices
(
range
(
0
,
VOCAB_SIZE
),
k
=
num_generated
)
...
...
@@ -504,7 +509,9 @@ def test_sampler_mixed(seed: int, device: str):
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
True
,
seq_data
=
{
0
:
SequenceData
([
1
,
2
,
3
])},
seq_data
=
{
0
:
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
1
,
2
,
3
]))
},
sampling_params
=
sampling_params
,
block_tables
=
{
0
:
[
1
]},
))
...
...
@@ -600,7 +607,9 @@ def test_sampler_top_k_top_p(seed: int, device: str):
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
True
,
seq_data
=
{
0
:
SequenceData
([
1
,
2
,
3
])},
seq_data
=
{
0
:
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
1
,
2
,
3
]))
},
sampling_params
=
SamplingParams
(
temperature
=
1
,
top_k
=
top_k
,
...
...
@@ -650,7 +659,11 @@ def test_sampler_repetition_penalty_mixed(device: str):
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
True
,
seq_data
=
{
0
:
SequenceData
([
1
,
2
,
3
])},
seq_data
=
{
0
:
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
1
,
2
,
3
]))
},
sampling_params
=
sampling_params
[
i
],
block_tables
=
{
0
:
[
1
]},
))
...
...
tests/spec_decode/utils.py
View file @
ff7ec82c
from
array
import
array
from
itertools
import
count
from
typing
import
Callable
,
Dict
,
List
,
Optional
from
typing
import
Sequence
as
GenericSequence
...
...
@@ -9,7 +10,8 @@ import torch
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
Logprob
,
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
CompletionSequenceGroupOutput
,
Logprob
,
SamplerOutput
,
SequenceData
,
SequenceGroupMetadata
,
SequenceOutput
)
from
vllm.utils
import
get_distributed_init_method
,
get_ip
,
get_open_port
...
...
@@ -138,8 +140,9 @@ def create_seq_group_metadata_from_prompts(
seq_data
=
{
i
:
SequenceData
(
prompt_token_ids
=
prompt_token_ids
[:],
output_token_ids
=
cont_token_ids
[:],
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
prompt_token_ids
[:]),
_output_token_ids
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
cont_token_ids
[:]),
),
},
sampling_params
=
SamplingParams
(
temperature
=
0.0
,
),
...
...
tests/test_logits_processor.py
View file @
ff7ec82c
import
random
from
array
import
array
from
typing
import
Tuple
from
unittest.mock
import
patch
...
...
@@ -8,7 +9,8 @@ import torch
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.sequence
import
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
)
from
vllm.utils
import
is_pin_memory_available
...
...
@@ -69,7 +71,9 @@ def test_logits_processors(seed: int, device: str):
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
True
,
seq_data
=
{
0
:
SequenceData
([
1
,
2
,
3
])},
seq_data
=
{
0
:
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
1
,
2
,
3
]))
},
sampling_params
=
SamplingParams
(
temperature
=
0
,
logits_processors
=
[
pick_ith
]),
block_tables
=
{
0
:
[
1
]},
...
...
tests/test_sequence.py
View file @
ff7ec82c
from
array
import
array
import
pytest
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
SamplerOutput
,
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
CompletionSequenceGroupOutput
,
SamplerOutput
,
SequenceData
,
SequenceOutput
)
from
.core.utils
import
create_dummy_prompt
...
...
@@ -54,7 +57,7 @@ def test_sampler_output_eq(sample_outputs):
def
test_sequence_data_prefill
():
seq_data
=
SequenceData
(
prompt_token_ids
=
[
1
,
2
,
3
,
4
])
seq_data
=
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
1
,
2
,
3
,
4
])
)
assert
seq_data
.
get_num_uncomputed_tokens
()
==
4
assert
seq_data
.
get_num_computed_tokens
()
==
0
# advance by 2
...
...
tests/worker/test_encoder_decoder_model_runner.py
View file @
ff7ec82c
from
array
import
array
from
typing
import
List
import
pytest
import
torch
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.sequence
import
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
)
from
vllm.utils
import
is_cpu
from
vllm.worker.enc_dec_model_runner
import
EncoderDecoderModelRunner
...
...
@@ -125,10 +127,12 @@ def test_prepare_prompt(
# make sure all tokens fit into one block
seq_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
seq_lens
.
append
(
seq_len
)
seq_data
=
SequenceData
(
list
(
range
(
seq_len
)))
seq_data
=
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
range
(
seq_len
)))
encoder_seq_len
=
(
i
+
1
)
%
(
model_runner
.
block_size
-
1
)
+
1
encoder_seq_lens
.
append
(
encoder_seq_len
)
encoder_seq_data
=
SequenceData
(
list
(
range
(
encoder_seq_len
)))
encoder_seq_data
=
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
range
(
encoder_seq_len
)))
seq_group_metadata
=
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
True
,
...
...
@@ -319,10 +323,12 @@ def test_prepare_decode(
# make sure all tokens fit into one block
seq_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
seq_lens
.
append
(
seq_len
)
seq_data
=
SequenceData
(
list
(
range
(
seq_len
)))
seq_data
=
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
(
range
(
seq_len
))))
encoder_seq_len
=
(
i
+
1
)
%
(
model_runner
.
block_size
-
1
)
+
1
encoder_seq_lens
.
append
(
encoder_seq_len
)
encoder_seq_data
=
SequenceData
(
list
(
range
(
encoder_seq_len
)))
encoder_seq_data
=
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
(
range
(
encoder_seq_len
))))
seq_group_metadata
=
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
False
,
...
...
tests/worker/test_model_runner.py
View file @
ff7ec82c
from
array
import
array
from
typing
import
List
import
pytest
...
...
@@ -7,7 +8,8 @@ from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
init_distributed_environment
)
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
)
from
vllm.utils
import
get_open_port
from
vllm.worker.model_runner
import
ModelRunner
,
_get_graph_batch_size
...
...
@@ -46,7 +48,8 @@ def test_prepare_prompt(batch_size):
# make sure all tokens fit into one block
seq_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
seq_lens
.
append
(
seq_len
)
seq_data
=
SequenceData
(
list
(
range
(
seq_len
)))
seq_data
=
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
range
(
seq_len
)))
seq_group_metadata
=
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
True
,
...
...
@@ -163,7 +166,8 @@ def test_prepare_decode_cuda_graph(batch_size):
# make sure all tokens fit into one block
context_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
context_lens
.
append
(
context_len
)
seq_data
=
SequenceData
(
list
(
range
(
context_len
)))
seq_data
=
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
range
(
context_len
)))
seq_data
.
update_num_computed_tokens
(
context_len
)
# Append one token ID since prefill is finished.
seq_data
.
append_token_id
(
1
,
0
)
...
...
@@ -324,7 +328,8 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
# make sure all tokens fit into one block
seq_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
seq_lens
.
append
(
seq_len
)
seq_data
=
SequenceData
(
list
(
range
(
seq_len
)))
seq_data
=
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
range
(
seq_len
)))
seq_group_metadata
=
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
True
,
...
...
@@ -340,7 +345,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
for
i
in
range
(
prefill_batch_size
,
batch_size
):
# make sure all tokens fit into one block
context_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
prompt_toks
=
list
(
range
(
context_len
))
prompt_toks
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
range
(
context_len
))
seq_data
=
SequenceData
(
prompt_toks
)
seq_data
.
append_token_id
(
1
,
0
)
seq_data
.
update_num_computed_tokens
(
context_len
)
...
...
vllm/adapter_commons/request.py
View file @
ff7ec82c
from
abc
import
ABC
,
abstractmethod
from
dataclasses
import
dataclass
@
dataclass
class
AdapterRequest
(
ABC
):
"""
Base class for adapter requests.
...
...
vllm/config.py
View file @
ff7ec82c
...
...
@@ -770,8 +770,8 @@ class ParallelConfig:
self
.
tokenizer_pool_config
=
tokenizer_pool_config
self
.
ray_workers_use_nsight
=
ray_workers_use_nsight
self
.
placement_group
=
placement_group
self
.
world_size
=
pipeline_parallel_size
*
self
.
tensor_parallel_size
if
worker_use_ray
:
if
self
.
distributed_executor_backend
is
None
:
self
.
distributed_executor_backend
=
"ray"
...
...
@@ -867,6 +867,11 @@ class SchedulerConfig:
swapping. However, when the sequence group has multiple sequences
(e.g., beam search), recomputation is not currently supported. In
such a case, we use swapping instead.
send_delta_data: Private API. If set, the scheduler sends delta data to
workers instead of the entire data. It should be enabled only
when the SPMD worker architecture is enabled, i.e.,
VLLM_USE_RAY_SPMD_WORKER=1.
"""
def
__init__
(
self
,
...
...
@@ -879,7 +884,8 @@ class SchedulerConfig:
enable_chunked_prefill
:
bool
=
False
,
embedding_mode
:
Optional
[
bool
]
=
False
,
preemption_mode
:
Optional
[
str
]
=
None
,
num_scheduler_steps
:
int
=
1
)
->
None
:
num_scheduler_steps
:
int
=
1
,
send_delta_data
:
bool
=
False
)
->
None
:
if
max_num_batched_tokens
is
not
None
:
self
.
max_num_batched_tokens
=
max_num_batched_tokens
else
:
...
...
@@ -909,6 +915,7 @@ class SchedulerConfig:
self
.
embedding_mode
=
embedding_mode
self
.
preemption_mode
=
preemption_mode
self
.
num_scheduler_steps
=
num_scheduler_steps
self
.
send_delta_data
=
send_delta_data
self
.
_verify_args
()
def
_verify_args
(
self
)
->
None
:
...
...
vllm/core/scheduler.py
View file @
ff7ec82c
...
...
@@ -12,7 +12,8 @@ from vllm.logger import init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sequence
import
(
Sequence
,
SequenceData
,
SequenceGroup
,
SequenceGroupMetadata
,
SequenceStatus
)
SequenceGroupMetadata
,
SequenceGroupMetadataDelta
,
SequenceStatus
)
from
vllm.utils
import
PyObjectCache
logger
=
init_logger
(
__name__
)
...
...
@@ -363,8 +364,6 @@ class Scheduler:
self
.
num_cumulative_preemption
:
int
=
0
# Used to cache python objects
self
.
_seq_group_metadata_cache
:
PyObjectCache
=
PyObjectCache
(
seq_group_metadata_builder
)
self
.
_scheduler_running_outputs_cache
:
PyObjectCache
=
PyObjectCache
(
scheduler_running_outputs_builder
)
self
.
_scheduled_seq_group_cache
:
PyObjectCache
=
PyObjectCache
(
...
...
@@ -1048,15 +1047,10 @@ class Scheduler:
token_chunk_size
=
scheduled_seq_group
.
token_chunk_size
seq_group
.
maybe_set_first_scheduled_time
(
now
)
seq_group_metadata
=
self
.
_seq_group_metadata_cache
.
get_object
()
seq_group_metadata
.
seq_data
.
clear
()
seq_group_metadata
.
block_tables
.
clear
()
# seq_id -> SequenceData
seq_data
:
Dict
[
int
,
SequenceData
]
=
seq_group_metadata
.
seq_data
seq_data
:
Dict
[
int
,
SequenceData
]
=
{}
# seq_id -> physical block numbers
block_tables
:
Dict
[
int
,
List
[
int
]]
=
seq_group_metadata
.
block_tables
block_tables
:
Dict
[
int
,
List
[
int
]]
=
{}
if
seq_group
.
is_encoder_decoder
():
# Encoder associated with SequenceGroup
...
...
@@ -1081,24 +1075,29 @@ class Scheduler:
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
RUNNING
)))
do_sample
=
True
if
seq_group
.
is_prefill
():
is_prompt
=
seq_group
.
is_prefill
()
# We should send the metadata to workers when the first prefill
# is sent. Subsequent requests could be chunked prefill or decode.
is_first_prefill
=
False
if
is_prompt
:
seqs
=
seq_group
.
get_seqs
()
# Prefill has only 1 sequence.
assert
len
(
seqs
)
==
1
num_computed_tokens
=
seqs
[
0
].
data
.
get_num_computed_tokens
()
is_first_prefill
=
num_computed_tokens
==
0
# In the next iteration, all prompt tokens are not computed.
# It means the prefill is chunked, and we don't need sampling.
# NOTE: We use get_len instead of get_prompt_len because when
# a sequence is preempted, prefill includes previous generated
# output tokens.
if
(
token_chunk_size
+
seqs
[
0
].
data
.
get_
num_computed_tokens
()
<
if
(
token_chunk_size
+
num_computed_tokens
<
seqs
[
0
].
data
.
get_len
()):
do_sample
=
False
# It assumes the scheduled_seq_groups is ordered by
# prefill < decoding.
is_prompt
=
seq_group
.
is_prefill
()
seq_group_metadata
.
__init__
(
if
is_first_prefill
or
not
self
.
scheduler_config
.
send_delta_data
:
seq_group_metadata
=
SequenceGroupMetadata
(
request_id
=
seq_group
.
request_id
,
is_prompt
=
is_prompt
,
seq_data
=
seq_data
,
...
...
@@ -1120,6 +1119,21 @@ class Scheduler:
if
scheduler_outputs
.
num_prefill_groups
>
0
else
None
,
prompt_adapter_request
=
seq_group
.
prompt_adapter_request
,
)
else
:
# When SPMD mode is enabled, we only send delta data except for
# the first request to reduce serialization cost.
seq_data_delta
=
{}
for
id
,
data
in
seq_data
.
items
():
seq_data_delta
[
id
]
=
data
.
get_delta_and_reset
()
seq_group_metadata
=
SequenceGroupMetadataDelta
(
seq_data_delta
,
seq_group
.
request_id
,
block_tables
,
is_prompt
,
do_sample
=
do_sample
,
token_chunk_size
=
token_chunk_size
,
computed_block_nums
=
common_computed_block_nums
,
)
seq_group_metadata_list
.
append
(
seq_group_metadata
)
# Now that the batch has been created, we can assume all blocks in the
...
...
@@ -1130,8 +1144,6 @@ class Scheduler:
self
.
block_manager
.
mark_blocks_as_computed
(
scheduled_seq_group
.
seq_group
)
self
.
_seq_group_metadata_cache
.
reset
()
scheduler_time
=
time
.
perf_counter
()
-
scheduler_start_time
# Add this to scheduler time to all the sequences that are currently
# running. This will help estimate if the scheduler is a significant
...
...
vllm/engine/arg_utils.py
View file @
ff7ec82c
...
...
@@ -5,6 +5,7 @@ from dataclasses import dataclass
from
typing
import
(
TYPE_CHECKING
,
Dict
,
List
,
Mapping
,
Optional
,
Tuple
,
Type
,
Union
)
import
vllm.envs
as
envs
from
vllm.config
import
(
CacheConfig
,
DecodingConfig
,
DeviceConfig
,
EngineConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
ObservabilityConfig
,
ParallelConfig
,
...
...
@@ -905,6 +906,8 @@ class EngineArgs:
embedding_mode
=
model_config
.
embedding_mode
,
preemption_mode
=
self
.
preemption_mode
,
num_scheduler_steps
=
self
.
num_scheduler_steps
,
send_delta_data
=
(
envs
.
VLLM_USE_RAY_SPMD_WORKER
and
parallel_config
.
use_ray
),
)
lora_config
=
LoRAConfig
(
max_lora_rank
=
self
.
max_lora_rank
,
...
...
vllm/engine/llm_engine.py
View file @
ff7ec82c
...
...
@@ -224,7 +224,6 @@ class LLMEngine:
cache_config
.
enable_prefix_caching
,
)
# TODO(woosuk): Print more configs in debug mode.
from
vllm.plugins
import
load_general_plugins
load_general_plugins
()
...
...
vllm/executor/msgspec_utils.py
0 → 100644
View file @
ff7ec82c
from
array
import
array
from
typing
import
Any
,
Type
from
vllm.sequence
import
VLLM_TOKEN_ID_ARRAY_TYPE
def encode_hook(obj: Any) -> Any:
    """msgspec ``enc_hook`` that serializes ``array.array`` objects as bytes.

    Only arrays whose typecode equals ``VLLM_TOKEN_ID_ARRAY_TYPE`` are
    accepted; any other typecode trips the assertion. Objects that are not
    arrays fall through and produce ``None``.

    See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder
    """
    # Anything that is not an array is not handled by this hook.
    if not isinstance(obj, array):
        return None

    assert obj.typecode == VLLM_TOKEN_ID_ARRAY_TYPE, (
        f"vLLM array type should use '{VLLM_TOKEN_ID_ARRAY_TYPE}' type. "
        f"Given array has a type code of {obj.typecode}.")
    return obj.tobytes()
def decode_hook(type: Type, obj: Any) -> Any:
    """msgspec ``dec_hook`` that rebuilds ``array.array`` objects from bytes.

    When the requested type is ``array``, a new array with typecode
    ``VLLM_TOKEN_ID_ARRAY_TYPE`` is filled from ``obj`` via ``frombytes``.
    Any other requested type falls through and produces ``None``.

    See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Decoder
    """
    # Only the exact ``array`` type is handled by this hook.
    if type is not array:
        return None

    token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE)
    token_ids.frombytes(obj)
    return token_ids
vllm/executor/ray_gpu_executor.py
View file @
ff7ec82c
...
...
@@ -4,9 +4,12 @@ from collections import defaultdict
from
itertools
import
islice
,
repeat
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
import
msgspec
import
vllm.envs
as
envs
from
vllm.executor.distributed_gpu_executor
import
(
# yapf: disable
DistributedGPUExecutor
,
DistributedGPUExecutorAsync
)
from
vllm.executor.msgspec_utils
import
encode_hook
from
vllm.executor.ray_utils
import
RayWorkerWrapper
,
ray
from
vllm.logger
import
init_logger
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
...
...
@@ -60,6 +63,10 @@ class RayGPUExecutor(DistributedGPUExecutor):
# Create the parallel GPU workers.
self
.
_init_workers_ray
(
placement_group
)
self
.
input_encoder
=
msgspec
.
msgpack
.
Encoder
(
enc_hook
=
encode_hook
)
self
.
output_decoder
=
msgspec
.
msgpack
.
Decoder
(
Optional
[
List
[
SamplerOutput
]])
def
shutdown
(
self
)
->
None
:
if
hasattr
(
self
,
"forward_dag"
)
and
self
.
forward_dag
is
not
None
:
self
.
forward_dag
.
teardown
()
...
...
@@ -123,6 +130,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
ray_remote_kwargs
)
logger
.
info
(
"use_ray_spmd_worker: %s"
,
self
.
use_ray_spmd_worker
)
# Create the workers.
driver_ip
=
get_ip
()
worker_wrapper_kwargs
=
self
.
_get_worker_wrapper_args
()
...
...
@@ -304,8 +312,10 @@ class RayGPUExecutor(DistributedGPUExecutor):
if
self
.
forward_dag
is
None
:
self
.
forward_dag
=
self
.
_compiled_ray_dag
(
enable_asyncio
=
False
)
outputs
=
ray
.
get
(
self
.
forward_dag
.
execute
(
execute_model_req
))
return
outputs
[
0
]
serialized_data
=
self
.
input_encoder
.
encode
(
execute_model_req
)
outputs
=
ray
.
get
(
self
.
forward_dag
.
execute
(
serialized_data
))
output
=
self
.
output_decoder
.
decode
(
outputs
[
0
])
return
output
def
_run_workers
(
self
,
...
...
@@ -475,9 +485,10 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
if
self
.
forward_dag
is
None
:
self
.
forward_dag
=
self
.
_compiled_ray_dag
(
enable_asyncio
=
True
)
dag_future
=
await
self
.
forward_dag
.
execute_async
(
execute_model_req
)
serialized_data
=
self
.
input_encoder
.
encode
(
execute_model_req
)
dag_future
=
await
self
.
forward_dag
.
execute_async
(
serialized_data
)
outputs
=
await
dag_future
return
outputs
[
0
]
return
self
.
output_decoder
.
decode
(
outputs
[
0
]
)
async
def
_driver_execute_model_async
(
self
,
...
...
vllm/executor/ray_utils.py
View file @
ff7ec82c
from
typing
import
List
,
Optional
,
Tuple
,
Union
import
msgspec
from
vllm.config
import
ParallelConfig
from
vllm.executor.msgspec_utils
import
decode_hook
,
encode_hook
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
ExecuteModelRequest
,
IntermediateTensors
...
...
@@ -24,6 +27,10 @@ try:
# that thread.
self
.
compiled_dag_cuda_device_set
=
False
self
.
input_decoder
=
msgspec
.
msgpack
.
Decoder
(
ExecuteModelRequest
,
dec_hook
=
decode_hook
)
self
.
output_encoder
=
msgspec
.
msgpack
.
Encoder
(
enc_hook
=
encode_hook
)
def
get_node_ip
(
self
)
->
str
:
return
get_ip
()
...
...
@@ -33,16 +40,26 @@ try:
return
node_id
,
gpu_ids
def
execute_model_spmd
(
self
,
req_or_tuple
:
Union
[
ExecuteModelRequest
,
Tuple
[
ExecuteModelRequest
,
IntermediateTensors
]]):
self
,
req_or_tuple
:
Union
[
bytes
,
Tuple
[
bytes
,
Optional
[
IntermediateTensors
]]]
)
->
bytes
:
"""Execute model in SPMD fashion: used only when SPMD worker and
compiled DAG are both enabled.
Args:
req_or_tuple: The request to execute the model, or a tuple
containing the request and intermediate tensors.
req_or_tuple: A request or a tuple containing the
request and intermediate tensors. Intermediate tensors are
None unless they are provided, which happens only for pipeline
stages > 0. The request is serialized by msgspec.
"""
if
isinstance
(
req_or_tuple
,
bytes
):
serialized_req
,
intermediate_tensors
=
req_or_tuple
,
None
else
:
serialized_req
,
intermediate_tensors
=
req_or_tuple
execute_model_req
=
self
.
input_decoder
.
decode
(
serialized_req
)
# TODO(swang): This is needed right now because Ray aDAG executes
# on a background thread, so we need to reset torch's current
# device.
...
...
@@ -51,16 +68,14 @@ try:
torch
.
cuda
.
set_device
(
self
.
worker
.
device
)
self
.
compiled_dag_cuda_device_set
=
True
if
isinstance
(
req_or_tuple
,
tuple
):
execute_model_req
,
intermediate_tensors
=
req_or_tuple
else
:
execute_model_req
=
req_or_tuple
intermediate_tensors
=
None
output
=
self
.
worker
.
_execute_model_spmd
(
execute_model_req
,
intermediate_tensors
)
# Pipeline model request and output to the next pipeline stage.
if
isinstance
(
output
,
IntermediateTensors
):
return
execute_model_req
,
output
output
=
serialized_req
,
output
else
:
output
=
self
.
output_encoder
.
encode
(
output
)
return
output
ray_import_err
=
None
...
...
vllm/inputs/registry.py
View file @
ff7ec82c
import
functools
from
array
import
array
from
collections
import
UserDict
from
dataclasses
import
dataclass
from
typing
import
(
TYPE_CHECKING
,
Callable
,
Dict
,
Mapping
,
Optional
,
Protocol
,
...
...
@@ -21,6 +22,10 @@ logger = init_logger(__name__)
C
=
TypeVar
(
"C"
,
bound
=
PretrainedConfig
,
default
=
PretrainedConfig
)
# NOTE: This has to match with sequence.py's VLLM_TOKEN_ID_ARRAY_TYPE.
# We cannot import it here because of circular dependencies.
VLLM_TOKEN_ID_ARRAY_TYPE
=
"l"
@
dataclass
(
frozen
=
True
)
class
InputContext
:
...
...
@@ -118,7 +123,8 @@ class InputRegistry:
# Avoid circular import
from
vllm.sequence
import
SequenceData
dummy_seq_data
=
SequenceData
([
0
]
*
seq_len
)
dummy_seq_data
=
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
0
])
*
seq_len
)
dummy_multi_modal_data
=
None
return
dummy_seq_data
,
dummy_multi_modal_data
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment