Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ff7ec82c
Unverified
Commit
ff7ec82c
authored
Aug 18, 2024
by
SangBin Cho
Committed by
GitHub
Aug 18, 2024
Browse files
[Core] Optimize SPMD architecture with delta + serialization optimization (#7109)
parent
200a2ffa
Changes
36
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
256 additions
and
84 deletions
+256
-84
requirements-common.txt
requirements-common.txt
+1
-0
tests/basic_correctness/test_preemption.py
tests/basic_correctness/test_preemption.py
+18
-0
tests/core/test_serialization.py
tests/core/test_serialization.py
+33
-0
tests/distributed/test_basic_distributed_correctness.py
tests/distributed/test_basic_distributed_correctness.py
+2
-1
tests/distributed/test_chunked_prefill_distributed.py
tests/distributed/test_chunked_prefill_distributed.py
+7
-0
tests/samplers/test_sampler.py
tests/samplers/test_sampler.py
+19
-6
tests/spec_decode/utils.py
tests/spec_decode/utils.py
+6
-3
tests/test_logits_processor.py
tests/test_logits_processor.py
+6
-2
tests/test_sequence.py
tests/test_sequence.py
+5
-2
tests/worker/test_encoder_decoder_model_runner.py
tests/worker/test_encoder_decoder_model_runner.py
+11
-5
tests/worker/test_model_runner.py
tests/worker/test_model_runner.py
+10
-5
vllm/adapter_commons/request.py
vllm/adapter_commons/request.py
+0
-2
vllm/config.py
vllm/config.py
+9
-2
vllm/core/scheduler.py
vllm/core/scheduler.py
+50
-38
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+3
-0
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+0
-1
vllm/executor/msgspec_utils.py
vllm/executor/msgspec_utils.py
+27
-0
vllm/executor/ray_gpu_executor.py
vllm/executor/ray_gpu_executor.py
+15
-4
vllm/executor/ray_utils.py
vllm/executor/ray_utils.py
+27
-12
vllm/inputs/registry.py
vllm/inputs/registry.py
+7
-1
No files found.
requirements-common.txt
View file @
ff7ec82c
...
@@ -21,6 +21,7 @@ outlines >= 0.0.43, < 0.1 # Requires torch >= 2.1.0
...
@@ -21,6 +21,7 @@ outlines >= 0.0.43, < 0.1 # Requires torch >= 2.1.0
typing_extensions >= 4.10
typing_extensions >= 4.10
filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
pyzmq
pyzmq
msgspec
librosa # Required for audio processing
librosa # Required for audio processing
soundfile # Required for audio processing
soundfile # Required for audio processing
gguf == 0.9.1
gguf == 0.9.1
...
...
tests/basic_correctness/test_preemption.py
View file @
ff7ec82c
...
@@ -8,6 +8,7 @@ pytest tests/basic_correctness/test_preemption.py`.
...
@@ -8,6 +8,7 @@ pytest tests/basic_correctness/test_preemption.py`.
import
pytest
import
pytest
from
prometheus_client
import
REGISTRY
from
prometheus_client
import
REGISTRY
import
vllm.envs
as
envs
from
vllm
import
SamplingParams
from
vllm
import
SamplingParams
from
vllm.core.scheduler
import
(
ARTIFICIAL_PREEMPTION_MAX_CNT
,
from
vllm.core.scheduler
import
(
ARTIFICIAL_PREEMPTION_MAX_CNT
,
ENABLE_ARTIFICIAL_PREEMPT
)
ENABLE_ARTIFICIAL_PREEMPT
)
...
@@ -24,6 +25,13 @@ assert ENABLE_ARTIFICIAL_PREEMPT is True, (
...
@@ -24,6 +25,13 @@ assert ENABLE_ARTIFICIAL_PREEMPT is True, (
"tests/basic_correctness/test_preemption.py`"
)
"tests/basic_correctness/test_preemption.py`"
)
@
pytest
.
fixture
def
worker_use_ray
()
->
bool
:
# When SPMD worker is used, use ray_use_worker=True
# to test delta input optimization works with preemption.
return
envs
.
VLLM_USE_RAY_SPMD_WORKER
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
96
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
96
])
...
@@ -36,6 +44,7 @@ def test_chunked_prefill_recompute(
...
@@ -36,6 +44,7 @@ def test_chunked_prefill_recompute(
dtype
:
str
,
dtype
:
str
,
max_tokens
:
int
,
max_tokens
:
int
,
chunked_prefill_token_size
:
int
,
chunked_prefill_token_size
:
int
,
worker_use_ray
:
bool
,
)
->
None
:
)
->
None
:
"""Ensure that chunked prefill works with preemption."""
"""Ensure that chunked prefill works with preemption."""
max_num_seqs
=
min
(
chunked_prefill_token_size
,
256
)
max_num_seqs
=
min
(
chunked_prefill_token_size
,
256
)
...
@@ -54,6 +63,7 @@ def test_chunked_prefill_recompute(
...
@@ -54,6 +63,7 @@ def test_chunked_prefill_recompute(
max_num_batched_tokens
=
max_num_batched_tokens
,
max_num_batched_tokens
=
max_num_batched_tokens
,
enable_chunked_prefill
=
enable_chunked_prefill
,
enable_chunked_prefill
=
enable_chunked_prefill
,
max_num_seqs
=
max_num_seqs
,
max_num_seqs
=
max_num_seqs
,
worker_use_ray
=
worker_use_ray
,
)
as
vllm_model
:
)
as
vllm_model
:
vllm_outputs
=
vllm_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
vllm_outputs
=
vllm_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
assert
(
vllm_model
.
model
.
llm_engine
.
scheduler
[
0
].
artificial_preempt_cnt
assert
(
vllm_model
.
model
.
llm_engine
.
scheduler
[
0
].
artificial_preempt_cnt
...
@@ -79,6 +89,7 @@ def test_preemption(
...
@@ -79,6 +89,7 @@ def test_preemption(
model
:
str
,
model
:
str
,
dtype
:
str
,
dtype
:
str
,
max_tokens
:
int
,
max_tokens
:
int
,
worker_use_ray
:
bool
,
)
->
None
:
)
->
None
:
"""By default, recompute preemption is enabled"""
"""By default, recompute preemption is enabled"""
...
@@ -89,6 +100,7 @@ def test_preemption(
...
@@ -89,6 +100,7 @@ def test_preemption(
model
,
model
,
dtype
=
dtype
,
dtype
=
dtype
,
disable_log_stats
=
False
,
disable_log_stats
=
False
,
worker_use_ray
=
worker_use_ray
,
)
as
vllm_model
:
)
as
vllm_model
:
vllm_outputs
=
vllm_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
vllm_outputs
=
vllm_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
assert
(
vllm_model
.
model
.
llm_engine
.
scheduler
[
0
].
artificial_preempt_cnt
assert
(
vllm_model
.
model
.
llm_engine
.
scheduler
[
0
].
artificial_preempt_cnt
...
@@ -132,6 +144,7 @@ def test_swap(
...
@@ -132,6 +144,7 @@ def test_swap(
dtype
:
str
,
dtype
:
str
,
max_tokens
:
int
,
max_tokens
:
int
,
beam_width
:
int
,
beam_width
:
int
,
worker_use_ray
:
bool
,
)
->
None
:
)
->
None
:
"""Use beam search enables swapping."""
"""Use beam search enables swapping."""
example_prompts
=
example_prompts
[:
1
]
example_prompts
=
example_prompts
[:
1
]
...
@@ -144,6 +157,7 @@ def test_swap(
...
@@ -144,6 +157,7 @@ def test_swap(
dtype
=
dtype
,
dtype
=
dtype
,
swap_space
=
10
,
swap_space
=
10
,
disable_log_stats
=
False
,
disable_log_stats
=
False
,
worker_use_ray
=
worker_use_ray
,
)
as
vllm_model
:
)
as
vllm_model
:
vllm_outputs
=
vllm_model
.
generate_beam_search
(
example_prompts
,
vllm_outputs
=
vllm_model
.
generate_beam_search
(
example_prompts
,
beam_width
,
max_tokens
)
beam_width
,
max_tokens
)
...
@@ -188,6 +202,7 @@ def test_swap_infeasible(
...
@@ -188,6 +202,7 @@ def test_swap_infeasible(
dtype
:
str
,
dtype
:
str
,
max_tokens
:
int
,
max_tokens
:
int
,
beam_width
:
int
,
beam_width
:
int
,
worker_use_ray
:
bool
,
)
->
None
:
)
->
None
:
"""Verify infeasible swap request will be ignored."""
"""Verify infeasible swap request will be ignored."""
BLOCK_SIZE
=
16
BLOCK_SIZE
=
16
...
@@ -204,6 +219,7 @@ def test_swap_infeasible(
...
@@ -204,6 +219,7 @@ def test_swap_infeasible(
# decode blocks are not enough to finish.
# decode blocks are not enough to finish.
num_gpu_blocks_override
=
prefill_blocks
+
decode_blocks
,
num_gpu_blocks_override
=
prefill_blocks
+
decode_blocks
,
max_model_len
=
(
prefill_blocks
+
decode_blocks
)
*
BLOCK_SIZE
,
max_model_len
=
(
prefill_blocks
+
decode_blocks
)
*
BLOCK_SIZE
,
worker_use_ray
=
worker_use_ray
,
)
as
vllm_model
:
)
as
vllm_model
:
sampling_params
=
SamplingParams
(
n
=
beam_width
,
sampling_params
=
SamplingParams
(
n
=
beam_width
,
use_beam_search
=
True
,
use_beam_search
=
True
,
...
@@ -230,6 +246,7 @@ def test_preemption_infeasible(
...
@@ -230,6 +246,7 @@ def test_preemption_infeasible(
model
:
str
,
model
:
str
,
dtype
:
str
,
dtype
:
str
,
max_tokens
:
int
,
max_tokens
:
int
,
worker_use_ray
:
bool
,
)
->
None
:
)
->
None
:
"""Verify infeasible preemption request will be ignored."""
"""Verify infeasible preemption request will be ignored."""
BLOCK_SIZE
=
16
BLOCK_SIZE
=
16
...
@@ -244,6 +261,7 @@ def test_preemption_infeasible(
...
@@ -244,6 +261,7 @@ def test_preemption_infeasible(
# ignored instead of hanging forever.
# ignored instead of hanging forever.
num_gpu_blocks_override
=
prefill_blocks
+
decode_blocks
//
2
,
num_gpu_blocks_override
=
prefill_blocks
+
decode_blocks
//
2
,
max_model_len
=
((
prefill_blocks
+
decode_blocks
//
2
)
*
BLOCK_SIZE
),
max_model_len
=
((
prefill_blocks
+
decode_blocks
//
2
)
*
BLOCK_SIZE
),
worker_use_ray
=
worker_use_ray
,
)
as
vllm_model
:
)
as
vllm_model
:
sampling_params
=
SamplingParams
(
max_tokens
=
max_tokens
,
sampling_params
=
SamplingParams
(
max_tokens
=
max_tokens
,
ignore_eos
=
True
)
ignore_eos
=
True
)
...
...
tests/core/test_serialization.py
0 → 100644
View file @
ff7ec82c
import
msgspec
from
vllm.executor.msgspec_utils
import
decode_hook
,
encode_hook
from
vllm.sequence
import
ExecuteModelRequest
from
..spec_decode.utils
import
create_batch
def
test_msgspec_serialization
():
num_lookahead_slots
=
4
seq_group_metadata_list
,
_
,
_
=
create_batch
(
16
,
num_lookahead_slots
)
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
num_lookahead_slots
,
running_queue_size
=
4
)
encoder
=
msgspec
.
msgpack
.
Encoder
(
enc_hook
=
encode_hook
)
decoder
=
msgspec
.
msgpack
.
Decoder
(
ExecuteModelRequest
,
dec_hook
=
decode_hook
)
req
=
decoder
.
decode
(
encoder
.
encode
(
execute_model_req
))
expected
=
execute_model_req
.
seq_group_metadata_list
actual
=
req
.
seq_group_metadata_list
assert
(
len
(
expected
)
==
len
(
actual
))
expected
=
expected
[
0
]
actual
=
actual
[
0
]
assert
expected
.
block_tables
==
actual
.
block_tables
assert
expected
.
is_prompt
==
actual
.
is_prompt
assert
expected
.
request_id
==
actual
.
request_id
assert
(
expected
.
seq_data
[
0
].
prompt_token_ids
==
actual
.
seq_data
[
0
].
prompt_token_ids
)
assert
(
expected
.
seq_data
[
0
].
output_token_ids
==
actual
.
seq_data
[
0
].
output_token_ids
)
tests/distributed/test_basic_distributed_correctness.py
View file @
ff7ec82c
...
@@ -22,7 +22,8 @@ TARGET_TEST_SUITE = os.environ.get("TARGET_TEST_SUITE", "L4")
...
@@ -22,7 +22,8 @@ TARGET_TEST_SUITE = os.environ.get("TARGET_TEST_SUITE", "L4")
@
pytest
.
mark
.
skipif
(
cuda_device_count_stateless
()
<
2
,
@
pytest
.
mark
.
skipif
(
cuda_device_count_stateless
()
<
2
,
reason
=
"Need at least 2 GPUs to run the test."
)
reason
=
"Need at least 2 GPUs to run the test."
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"model, distributed_executor_backend, attention_backend, test_suite"
,
[
"model, distributed_executor_backend, attention_backend, "
"test_suite"
,
[
(
"facebook/opt-125m"
,
"ray"
,
""
,
"L4"
),
(
"facebook/opt-125m"
,
"ray"
,
""
,
"L4"
),
(
"facebook/opt-125m"
,
"mp"
,
""
,
"L4"
),
(
"facebook/opt-125m"
,
"mp"
,
""
,
"L4"
),
(
"meta-llama/Llama-2-7b-hf"
,
"ray"
,
""
,
"L4"
),
(
"meta-llama/Llama-2-7b-hf"
,
"ray"
,
""
,
"L4"
),
...
...
tests/distributed/test_chunked_prefill_distributed.py
View file @
ff7ec82c
...
@@ -6,6 +6,8 @@ pytest test_chunked_prefill_distributed.py
...
@@ -6,6 +6,8 @@ pytest test_chunked_prefill_distributed.py
```
```
"""
"""
import
os
import
pytest
import
pytest
from
vllm.utils
import
cuda_device_count_stateless
from
vllm.utils
import
cuda_device_count_stateless
...
@@ -30,6 +32,11 @@ def test_models(
...
@@ -30,6 +32,11 @@ def test_models(
model
:
str
,
model
:
str
,
distributed_executor_backend
:
str
,
distributed_executor_backend
:
str
,
)
->
None
:
)
->
None
:
if
model
==
"meta-llama/Llama-2-7b-hf"
and
distributed_executor_backend
==
"ray"
:
# noqa
assert
distributed_executor_backend
==
"ray"
# test ray adag
os
.
environ
[
'VLLM_USE_RAY_SPMD_WORKER'
]
=
"1"
os
.
environ
[
'VLLM_USE_RAY_COMPILED_DAG'
]
=
"1"
dtype
=
"half"
dtype
=
"half"
max_tokens
=
5
max_tokens
=
5
...
...
tests/samplers/test_sampler.py
View file @
ff7ec82c
import
itertools
import
itertools
import
random
import
random
from
array
import
array
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
unittest.mock
import
Mock
,
patch
from
unittest.mock
import
Mock
,
patch
...
@@ -10,7 +11,8 @@ from transformers import GenerationConfig, GenerationMixin
...
@@ -10,7 +11,8 @@ from transformers import GenerationConfig, GenerationMixin
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.sequence
import
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
)
from
vllm.utils
import
Counter
,
is_pin_memory_available
from
vllm.utils
import
Counter
,
is_pin_memory_available
...
@@ -56,7 +58,9 @@ def _do_sample(
...
@@ -56,7 +58,9 @@ def _do_sample(
SequenceGroupMetadata
(
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
True
,
is_prompt
=
True
,
seq_data
=
{
0
:
SequenceData
([
1
,
2
,
3
])},
seq_data
=
{
0
:
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
1
,
2
,
3
]))
},
sampling_params
=
sampling_params
,
sampling_params
=
sampling_params
,
block_tables
=
{
0
:
[
1
]},
block_tables
=
{
0
:
[
1
]},
))
))
...
@@ -201,7 +205,8 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
...
@@ -201,7 +205,8 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
def
create_sequence_data
(
num_input
=
3
,
num_generated
=
0
):
def
create_sequence_data
(
num_input
=
3
,
num_generated
=
0
):
seq_data
=
SequenceData
(
seq_data
=
SequenceData
(
random
.
choices
(
range
(
0
,
VOCAB_SIZE
),
k
=
num_input
))
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
random
.
choices
(
range
(
0
,
VOCAB_SIZE
),
k
=
num_input
)))
if
num_generated
>
0
:
if
num_generated
>
0
:
seq_data
.
output_token_ids
=
random
.
choices
(
range
(
0
,
VOCAB_SIZE
),
seq_data
.
output_token_ids
=
random
.
choices
(
range
(
0
,
VOCAB_SIZE
),
k
=
num_generated
)
k
=
num_generated
)
...
@@ -504,7 +509,9 @@ def test_sampler_mixed(seed: int, device: str):
...
@@ -504,7 +509,9 @@ def test_sampler_mixed(seed: int, device: str):
SequenceGroupMetadata
(
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
True
,
is_prompt
=
True
,
seq_data
=
{
0
:
SequenceData
([
1
,
2
,
3
])},
seq_data
=
{
0
:
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
1
,
2
,
3
]))
},
sampling_params
=
sampling_params
,
sampling_params
=
sampling_params
,
block_tables
=
{
0
:
[
1
]},
block_tables
=
{
0
:
[
1
]},
))
))
...
@@ -600,7 +607,9 @@ def test_sampler_top_k_top_p(seed: int, device: str):
...
@@ -600,7 +607,9 @@ def test_sampler_top_k_top_p(seed: int, device: str):
SequenceGroupMetadata
(
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
True
,
is_prompt
=
True
,
seq_data
=
{
0
:
SequenceData
([
1
,
2
,
3
])},
seq_data
=
{
0
:
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
1
,
2
,
3
]))
},
sampling_params
=
SamplingParams
(
sampling_params
=
SamplingParams
(
temperature
=
1
,
temperature
=
1
,
top_k
=
top_k
,
top_k
=
top_k
,
...
@@ -650,7 +659,11 @@ def test_sampler_repetition_penalty_mixed(device: str):
...
@@ -650,7 +659,11 @@ def test_sampler_repetition_penalty_mixed(device: str):
SequenceGroupMetadata
(
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
True
,
is_prompt
=
True
,
seq_data
=
{
0
:
SequenceData
([
1
,
2
,
3
])},
seq_data
=
{
0
:
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
1
,
2
,
3
]))
},
sampling_params
=
sampling_params
[
i
],
sampling_params
=
sampling_params
[
i
],
block_tables
=
{
0
:
[
1
]},
block_tables
=
{
0
:
[
1
]},
))
))
...
...
tests/spec_decode/utils.py
View file @
ff7ec82c
from
array
import
array
from
itertools
import
count
from
itertools
import
count
from
typing
import
Callable
,
Dict
,
List
,
Optional
from
typing
import
Callable
,
Dict
,
List
,
Optional
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Sequence
as
GenericSequence
...
@@ -9,7 +10,8 @@ import torch
...
@@ -9,7 +10,8 @@ import torch
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
Logprob
,
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
CompletionSequenceGroupOutput
,
Logprob
,
SamplerOutput
,
SequenceData
,
SequenceGroupMetadata
,
SamplerOutput
,
SequenceData
,
SequenceGroupMetadata
,
SequenceOutput
)
SequenceOutput
)
from
vllm.utils
import
get_distributed_init_method
,
get_ip
,
get_open_port
from
vllm.utils
import
get_distributed_init_method
,
get_ip
,
get_open_port
...
@@ -138,8 +140,9 @@ def create_seq_group_metadata_from_prompts(
...
@@ -138,8 +140,9 @@ def create_seq_group_metadata_from_prompts(
seq_data
=
{
seq_data
=
{
i
:
i
:
SequenceData
(
SequenceData
(
prompt_token_ids
=
prompt_token_ids
[:],
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
prompt_token_ids
[:]),
output_token_ids
=
cont_token_ids
[:],
_output_token_ids
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
cont_token_ids
[:]),
),
),
},
},
sampling_params
=
SamplingParams
(
temperature
=
0.0
,
),
sampling_params
=
SamplingParams
(
temperature
=
0.0
,
),
...
...
tests/test_logits_processor.py
View file @
ff7ec82c
import
random
import
random
from
array
import
array
from
typing
import
Tuple
from
typing
import
Tuple
from
unittest.mock
import
patch
from
unittest.mock
import
patch
...
@@ -8,7 +9,8 @@ import torch
...
@@ -8,7 +9,8 @@ import torch
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.sequence
import
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
)
from
vllm.utils
import
is_pin_memory_available
from
vllm.utils
import
is_pin_memory_available
...
@@ -69,7 +71,9 @@ def test_logits_processors(seed: int, device: str):
...
@@ -69,7 +71,9 @@ def test_logits_processors(seed: int, device: str):
SequenceGroupMetadata
(
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
True
,
is_prompt
=
True
,
seq_data
=
{
0
:
SequenceData
([
1
,
2
,
3
])},
seq_data
=
{
0
:
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
1
,
2
,
3
]))
},
sampling_params
=
SamplingParams
(
temperature
=
0
,
sampling_params
=
SamplingParams
(
temperature
=
0
,
logits_processors
=
[
pick_ith
]),
logits_processors
=
[
pick_ith
]),
block_tables
=
{
0
:
[
1
]},
block_tables
=
{
0
:
[
1
]},
...
...
tests/test_sequence.py
View file @
ff7ec82c
from
array
import
array
import
pytest
import
pytest
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
SamplerOutput
,
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
CompletionSequenceGroupOutput
,
SamplerOutput
,
SequenceData
,
SequenceOutput
)
SequenceData
,
SequenceOutput
)
from
.core.utils
import
create_dummy_prompt
from
.core.utils
import
create_dummy_prompt
...
@@ -54,7 +57,7 @@ def test_sampler_output_eq(sample_outputs):
...
@@ -54,7 +57,7 @@ def test_sampler_output_eq(sample_outputs):
def
test_sequence_data_prefill
():
def
test_sequence_data_prefill
():
seq_data
=
SequenceData
(
prompt_token_ids
=
[
1
,
2
,
3
,
4
])
seq_data
=
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
1
,
2
,
3
,
4
])
)
assert
seq_data
.
get_num_uncomputed_tokens
()
==
4
assert
seq_data
.
get_num_uncomputed_tokens
()
==
4
assert
seq_data
.
get_num_computed_tokens
()
==
0
assert
seq_data
.
get_num_computed_tokens
()
==
0
# advance by 2
# advance by 2
...
...
tests/worker/test_encoder_decoder_model_runner.py
View file @
ff7ec82c
from
array
import
array
from
typing
import
List
from
typing
import
List
import
pytest
import
pytest
import
torch
import
torch
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.sequence
import
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
)
from
vllm.utils
import
is_cpu
from
vllm.utils
import
is_cpu
from
vllm.worker.enc_dec_model_runner
import
EncoderDecoderModelRunner
from
vllm.worker.enc_dec_model_runner
import
EncoderDecoderModelRunner
...
@@ -125,10 +127,12 @@ def test_prepare_prompt(
...
@@ -125,10 +127,12 @@ def test_prepare_prompt(
# make sure all tokens fit into one block
# make sure all tokens fit into one block
seq_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
seq_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
seq_lens
.
append
(
seq_len
)
seq_lens
.
append
(
seq_len
)
seq_data
=
SequenceData
(
list
(
range
(
seq_len
)))
seq_data
=
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
range
(
seq_len
)))
encoder_seq_len
=
(
i
+
1
)
%
(
model_runner
.
block_size
-
1
)
+
1
encoder_seq_len
=
(
i
+
1
)
%
(
model_runner
.
block_size
-
1
)
+
1
encoder_seq_lens
.
append
(
encoder_seq_len
)
encoder_seq_lens
.
append
(
encoder_seq_len
)
encoder_seq_data
=
SequenceData
(
list
(
range
(
encoder_seq_len
)))
encoder_seq_data
=
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
range
(
encoder_seq_len
)))
seq_group_metadata
=
SequenceGroupMetadata
(
seq_group_metadata
=
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
True
,
is_prompt
=
True
,
...
@@ -319,10 +323,12 @@ def test_prepare_decode(
...
@@ -319,10 +323,12 @@ def test_prepare_decode(
# make sure all tokens fit into one block
# make sure all tokens fit into one block
seq_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
seq_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
seq_lens
.
append
(
seq_len
)
seq_lens
.
append
(
seq_len
)
seq_data
=
SequenceData
(
list
(
range
(
seq_len
)))
seq_data
=
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
(
range
(
seq_len
))))
encoder_seq_len
=
(
i
+
1
)
%
(
model_runner
.
block_size
-
1
)
+
1
encoder_seq_len
=
(
i
+
1
)
%
(
model_runner
.
block_size
-
1
)
+
1
encoder_seq_lens
.
append
(
encoder_seq_len
)
encoder_seq_lens
.
append
(
encoder_seq_len
)
encoder_seq_data
=
SequenceData
(
list
(
range
(
encoder_seq_len
)))
encoder_seq_data
=
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
(
range
(
encoder_seq_len
))))
seq_group_metadata
=
SequenceGroupMetadata
(
seq_group_metadata
=
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
False
,
is_prompt
=
False
,
...
...
tests/worker/test_model_runner.py
View file @
ff7ec82c
from
array
import
array
from
typing
import
List
from
typing
import
List
import
pytest
import
pytest
...
@@ -7,7 +8,8 @@ from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
...
@@ -7,7 +8,8 @@ from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
init_distributed_environment
)
init_distributed_environment
)
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
)
from
vllm.utils
import
get_open_port
from
vllm.utils
import
get_open_port
from
vllm.worker.model_runner
import
ModelRunner
,
_get_graph_batch_size
from
vllm.worker.model_runner
import
ModelRunner
,
_get_graph_batch_size
...
@@ -46,7 +48,8 @@ def test_prepare_prompt(batch_size):
...
@@ -46,7 +48,8 @@ def test_prepare_prompt(batch_size):
# make sure all tokens fit into one block
# make sure all tokens fit into one block
seq_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
seq_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
seq_lens
.
append
(
seq_len
)
seq_lens
.
append
(
seq_len
)
seq_data
=
SequenceData
(
list
(
range
(
seq_len
)))
seq_data
=
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
range
(
seq_len
)))
seq_group_metadata
=
SequenceGroupMetadata
(
seq_group_metadata
=
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
True
,
is_prompt
=
True
,
...
@@ -163,7 +166,8 @@ def test_prepare_decode_cuda_graph(batch_size):
...
@@ -163,7 +166,8 @@ def test_prepare_decode_cuda_graph(batch_size):
# make sure all tokens fit into one block
# make sure all tokens fit into one block
context_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
context_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
context_lens
.
append
(
context_len
)
context_lens
.
append
(
context_len
)
seq_data
=
SequenceData
(
list
(
range
(
context_len
)))
seq_data
=
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
range
(
context_len
)))
seq_data
.
update_num_computed_tokens
(
context_len
)
seq_data
.
update_num_computed_tokens
(
context_len
)
# Append one token ID since prefill is finished.
# Append one token ID since prefill is finished.
seq_data
.
append_token_id
(
1
,
0
)
seq_data
.
append_token_id
(
1
,
0
)
...
@@ -324,7 +328,8 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
...
@@ -324,7 +328,8 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
# make sure all tokens fit into one block
# make sure all tokens fit into one block
seq_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
seq_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
seq_lens
.
append
(
seq_len
)
seq_lens
.
append
(
seq_len
)
seq_data
=
SequenceData
(
list
(
range
(
seq_len
)))
seq_data
=
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
range
(
seq_len
)))
seq_group_metadata
=
SequenceGroupMetadata
(
seq_group_metadata
=
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
True
,
is_prompt
=
True
,
...
@@ -340,7 +345,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
...
@@ -340,7 +345,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
for
i
in
range
(
prefill_batch_size
,
batch_size
):
for
i
in
range
(
prefill_batch_size
,
batch_size
):
# make sure all tokens fit into one block
# make sure all tokens fit into one block
context_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
context_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
prompt_toks
=
list
(
range
(
context_len
))
prompt_toks
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
range
(
context_len
))
seq_data
=
SequenceData
(
prompt_toks
)
seq_data
=
SequenceData
(
prompt_toks
)
seq_data
.
append_token_id
(
1
,
0
)
seq_data
.
append_token_id
(
1
,
0
)
seq_data
.
update_num_computed_tokens
(
context_len
)
seq_data
.
update_num_computed_tokens
(
context_len
)
...
...
vllm/adapter_commons/request.py
View file @
ff7ec82c
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
dataclasses
import
dataclass
@
dataclass
class
AdapterRequest
(
ABC
):
class
AdapterRequest
(
ABC
):
"""
"""
Base class for adapter requests.
Base class for adapter requests.
...
...
vllm/config.py
View file @
ff7ec82c
...
@@ -770,8 +770,8 @@ class ParallelConfig:
...
@@ -770,8 +770,8 @@ class ParallelConfig:
self
.
tokenizer_pool_config
=
tokenizer_pool_config
self
.
tokenizer_pool_config
=
tokenizer_pool_config
self
.
ray_workers_use_nsight
=
ray_workers_use_nsight
self
.
ray_workers_use_nsight
=
ray_workers_use_nsight
self
.
placement_group
=
placement_group
self
.
placement_group
=
placement_group
self
.
world_size
=
pipeline_parallel_size
*
self
.
tensor_parallel_size
self
.
world_size
=
pipeline_parallel_size
*
self
.
tensor_parallel_size
if
worker_use_ray
:
if
worker_use_ray
:
if
self
.
distributed_executor_backend
is
None
:
if
self
.
distributed_executor_backend
is
None
:
self
.
distributed_executor_backend
=
"ray"
self
.
distributed_executor_backend
=
"ray"
...
@@ -867,6 +867,11 @@ class SchedulerConfig:
...
@@ -867,6 +867,11 @@ class SchedulerConfig:
swapping. However, when the sequence group has multiple sequences
swapping. However, when the sequence group has multiple sequences
(e.g., beam search), recomputation is not currently supported. In
(e.g., beam search), recomputation is not currently supported. In
such a case, we use swapping instead.
such a case, we use swapping instead.
send_delta_data: Private API. If used, scheduler sends delta data to
workers instead of an entire data. It should be enabled only
when SPMD worker architecture is enabled. I.e.,
VLLM_USE_RAY_SPMD_WORKER=1
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -879,7 +884,8 @@ class SchedulerConfig:
...
@@ -879,7 +884,8 @@ class SchedulerConfig:
enable_chunked_prefill
:
bool
=
False
,
enable_chunked_prefill
:
bool
=
False
,
embedding_mode
:
Optional
[
bool
]
=
False
,
embedding_mode
:
Optional
[
bool
]
=
False
,
preemption_mode
:
Optional
[
str
]
=
None
,
preemption_mode
:
Optional
[
str
]
=
None
,
num_scheduler_steps
:
int
=
1
)
->
None
:
num_scheduler_steps
:
int
=
1
,
send_delta_data
:
bool
=
False
)
->
None
:
if
max_num_batched_tokens
is
not
None
:
if
max_num_batched_tokens
is
not
None
:
self
.
max_num_batched_tokens
=
max_num_batched_tokens
self
.
max_num_batched_tokens
=
max_num_batched_tokens
else
:
else
:
...
@@ -909,6 +915,7 @@ class SchedulerConfig:
...
@@ -909,6 +915,7 @@ class SchedulerConfig:
self
.
embedding_mode
=
embedding_mode
self
.
embedding_mode
=
embedding_mode
self
.
preemption_mode
=
preemption_mode
self
.
preemption_mode
=
preemption_mode
self
.
num_scheduler_steps
=
num_scheduler_steps
self
.
num_scheduler_steps
=
num_scheduler_steps
self
.
send_delta_data
=
send_delta_data
self
.
_verify_args
()
self
.
_verify_args
()
def
_verify_args
(
self
)
->
None
:
def
_verify_args
(
self
)
->
None
:
...
...
vllm/core/scheduler.py
View file @
ff7ec82c
...
@@ -12,7 +12,8 @@ from vllm.logger import init_logger
...
@@ -12,7 +12,8 @@ from vllm.logger import init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sequence
import
(
Sequence
,
SequenceData
,
SequenceGroup
,
from
vllm.sequence
import
(
Sequence
,
SequenceData
,
SequenceGroup
,
SequenceGroupMetadata
,
SequenceStatus
)
SequenceGroupMetadata
,
SequenceGroupMetadataDelta
,
SequenceStatus
)
from
vllm.utils
import
PyObjectCache
from
vllm.utils
import
PyObjectCache
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -363,8 +364,6 @@ class Scheduler:
...
@@ -363,8 +364,6 @@ class Scheduler:
self
.
num_cumulative_preemption
:
int
=
0
self
.
num_cumulative_preemption
:
int
=
0
# Used to cache python objects
# Used to cache python objects
self
.
_seq_group_metadata_cache
:
PyObjectCache
=
PyObjectCache
(
seq_group_metadata_builder
)
self
.
_scheduler_running_outputs_cache
:
PyObjectCache
=
PyObjectCache
(
self
.
_scheduler_running_outputs_cache
:
PyObjectCache
=
PyObjectCache
(
scheduler_running_outputs_builder
)
scheduler_running_outputs_builder
)
self
.
_scheduled_seq_group_cache
:
PyObjectCache
=
PyObjectCache
(
self
.
_scheduled_seq_group_cache
:
PyObjectCache
=
PyObjectCache
(
...
@@ -1048,15 +1047,10 @@ class Scheduler:
...
@@ -1048,15 +1047,10 @@ class Scheduler:
token_chunk_size
=
scheduled_seq_group
.
token_chunk_size
token_chunk_size
=
scheduled_seq_group
.
token_chunk_size
seq_group
.
maybe_set_first_scheduled_time
(
now
)
seq_group
.
maybe_set_first_scheduled_time
(
now
)
seq_group_metadata
=
self
.
_seq_group_metadata_cache
.
get_object
()
seq_group_metadata
.
seq_data
.
clear
()
seq_group_metadata
.
block_tables
.
clear
()
# seq_id -> SequenceData
# seq_id -> SequenceData
seq_data
:
Dict
[
int
,
SequenceData
]
=
seq_group_metadata
.
seq_data
seq_data
:
Dict
[
int
,
SequenceData
]
=
{}
# seq_id -> physical block numbers
# seq_id -> physical block numbers
block_tables
:
Dict
[
int
,
block_tables
:
Dict
[
int
,
List
[
int
]]
=
{}
List
[
int
]]
=
seq_group_metadata
.
block_tables
if
seq_group
.
is_encoder_decoder
():
if
seq_group
.
is_encoder_decoder
():
# Encoder associated with SequenceGroup
# Encoder associated with SequenceGroup
...
@@ -1081,45 +1075,65 @@ class Scheduler:
...
@@ -1081,45 +1075,65 @@ class Scheduler:
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
RUNNING
)))
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
RUNNING
)))
do_sample
=
True
do_sample
=
True
if
seq_group
.
is_prefill
():
is_prompt
=
seq_group
.
is_prefill
()
# We should send the metadata to workers when the first prefill
# is sent. Subsequent requests could be chunked prefill or decode.
is_first_prefill
=
False
if
is_prompt
:
seqs
=
seq_group
.
get_seqs
()
seqs
=
seq_group
.
get_seqs
()
# Prefill has only 1 sequence.
# Prefill has only 1 sequence.
assert
len
(
seqs
)
==
1
assert
len
(
seqs
)
==
1
num_computed_tokens
=
seqs
[
0
].
data
.
get_num_computed_tokens
()
is_first_prefill
=
num_computed_tokens
==
0
# In the next iteration, all prompt tokens are not computed.
# In the next iteration, all prompt tokens are not computed.
# It means the prefill is chunked, and we don't need sampling.
# It means the prefill is chunked, and we don't need sampling.
# NOTE: We use get_len instead of get_prompt_len because when
# NOTE: We use get_len instead of get_prompt_len because when
# a sequence is preempted, prefill includes previous generated
# a sequence is preempted, prefill includes previous generated
# output tokens.
# output tokens.
if
(
token_chunk_size
+
seqs
[
0
].
data
.
get_
num_computed_tokens
()
<
if
(
token_chunk_size
+
num_computed_tokens
<
seqs
[
0
].
data
.
get_len
()):
seqs
[
0
].
data
.
get_len
()):
do_sample
=
False
do_sample
=
False
# It assumes the scheduled_seq_groups is ordered by
# It assumes the scheduled_seq_groups is ordered by
# prefill < decoding.
# prefill < decoding.
is_prompt
=
seq_group
.
is_prefill
()
if
is_first_prefill
or
not
self
.
scheduler_config
.
send_delta_data
:
seq_group_metadata
=
SequenceGroupMetadata
(
seq_group_metadata
.
__init__
(
request_id
=
seq_group
.
request_id
,
request_id
=
seq_group
.
request_id
,
is_prompt
=
is_prompt
,
is_prompt
=
is_prompt
,
seq_data
=
seq_data
,
seq_data
=
seq_data
,
sampling_params
=
seq_group
.
sampling_params
,
sampling_params
=
seq_group
.
sampling_params
,
block_tables
=
block_tables
,
block_tables
=
block_tables
,
do_sample
=
do_sample
,
do_sample
=
do_sample
,
pooling_params
=
seq_group
.
pooling_params
,
pooling_params
=
seq_group
.
pooling_params
,
token_chunk_size
=
token_chunk_size
,
token_chunk_size
=
token_chunk_size
,
lora_request
=
seq_group
.
lora_request
,
lora_request
=
seq_group
.
lora_request
,
computed_block_nums
=
common_computed_block_nums
,
computed_block_nums
=
common_computed_block_nums
,
encoder_seq_data
=
encoder_seq_data
,
encoder_seq_data
=
encoder_seq_data
,
cross_block_table
=
cross_block_table
,
cross_block_table
=
cross_block_table
,
state
=
seq_group
.
state
,
state
=
seq_group
.
state
,
# `multi_modal_data` will only be present for the 1st comm
# `multi_modal_data` will only be present for the 1st comm
# between engine and worker.
# between engine and worker.
# the subsequent comms can still use delta, but
# the subsequent comms can still use delta, but
# `multi_modal_data` will be None.
# `multi_modal_data` will be None.
multi_modal_data
=
seq_group
.
multi_modal_data
multi_modal_data
=
seq_group
.
multi_modal_data
if
scheduler_outputs
.
num_prefill_groups
>
0
else
None
,
if
scheduler_outputs
.
num_prefill_groups
>
0
else
None
,
prompt_adapter_request
=
seq_group
.
prompt_adapter_request
,
prompt_adapter_request
=
seq_group
.
prompt_adapter_request
,
)
)
else
:
# When SPMD mode is enabled, we only send delta data except for
# the first request to reduce serialization cost.
seq_data_delta
=
{}
for
id
,
data
in
seq_data
.
items
():
seq_data_delta
[
id
]
=
data
.
get_delta_and_reset
()
seq_group_metadata
=
SequenceGroupMetadataDelta
(
seq_data_delta
,
seq_group
.
request_id
,
block_tables
,
is_prompt
,
do_sample
=
do_sample
,
token_chunk_size
=
token_chunk_size
,
computed_block_nums
=
common_computed_block_nums
,
)
seq_group_metadata_list
.
append
(
seq_group_metadata
)
seq_group_metadata_list
.
append
(
seq_group_metadata
)
# Now that the batch has been created, we can assume all blocks in the
# Now that the batch has been created, we can assume all blocks in the
...
@@ -1130,8 +1144,6 @@ class Scheduler:
...
@@ -1130,8 +1144,6 @@ class Scheduler:
self
.
block_manager
.
mark_blocks_as_computed
(
self
.
block_manager
.
mark_blocks_as_computed
(
scheduled_seq_group
.
seq_group
)
scheduled_seq_group
.
seq_group
)
self
.
_seq_group_metadata_cache
.
reset
()
scheduler_time
=
time
.
perf_counter
()
-
scheduler_start_time
scheduler_time
=
time
.
perf_counter
()
-
scheduler_start_time
# Add this to scheduler time to all the sequences that are currently
# Add this to scheduler time to all the sequences that are currently
# running. This will help estimate if the scheduler is a significant
# running. This will help estimate if the scheduler is a significant
...
...
vllm/engine/arg_utils.py
View file @
ff7ec82c
...
@@ -5,6 +5,7 @@ from dataclasses import dataclass
...
@@ -5,6 +5,7 @@ from dataclasses import dataclass
from
typing
import
(
TYPE_CHECKING
,
Dict
,
List
,
Mapping
,
Optional
,
Tuple
,
Type
,
from
typing
import
(
TYPE_CHECKING
,
Dict
,
List
,
Mapping
,
Optional
,
Tuple
,
Type
,
Union
)
Union
)
import
vllm.envs
as
envs
from
vllm.config
import
(
CacheConfig
,
DecodingConfig
,
DeviceConfig
,
from
vllm.config
import
(
CacheConfig
,
DecodingConfig
,
DeviceConfig
,
EngineConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
EngineConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
ObservabilityConfig
,
ParallelConfig
,
ObservabilityConfig
,
ParallelConfig
,
...
@@ -905,6 +906,8 @@ class EngineArgs:
...
@@ -905,6 +906,8 @@ class EngineArgs:
embedding_mode
=
model_config
.
embedding_mode
,
embedding_mode
=
model_config
.
embedding_mode
,
preemption_mode
=
self
.
preemption_mode
,
preemption_mode
=
self
.
preemption_mode
,
num_scheduler_steps
=
self
.
num_scheduler_steps
,
num_scheduler_steps
=
self
.
num_scheduler_steps
,
send_delta_data
=
(
envs
.
VLLM_USE_RAY_SPMD_WORKER
and
parallel_config
.
use_ray
),
)
)
lora_config
=
LoRAConfig
(
lora_config
=
LoRAConfig
(
max_lora_rank
=
self
.
max_lora_rank
,
max_lora_rank
=
self
.
max_lora_rank
,
...
...
vllm/engine/llm_engine.py
View file @
ff7ec82c
...
@@ -224,7 +224,6 @@ class LLMEngine:
...
@@ -224,7 +224,6 @@ class LLMEngine:
cache_config
.
enable_prefix_caching
,
cache_config
.
enable_prefix_caching
,
)
)
# TODO(woosuk): Print more configs in debug mode.
# TODO(woosuk): Print more configs in debug mode.
from
vllm.plugins
import
load_general_plugins
from
vllm.plugins
import
load_general_plugins
load_general_plugins
()
load_general_plugins
()
...
...
vllm/executor/msgspec_utils.py
0 → 100644
View file @
ff7ec82c
from
array
import
array
from
typing
import
Any
,
Type
from
vllm.sequence
import
VLLM_TOKEN_ID_ARRAY_TYPE
def
encode_hook
(
obj
:
Any
)
->
Any
:
"""Custom msgspec enc hook that supports array types.
See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder
"""
if
isinstance
(
obj
,
array
):
assert
obj
.
typecode
==
VLLM_TOKEN_ID_ARRAY_TYPE
,
(
f
"vLLM array type should use '
{
VLLM_TOKEN_ID_ARRAY_TYPE
}
' type. "
f
"Given array has a type code of
{
obj
.
typecode
}
."
)
return
obj
.
tobytes
()
def
decode_hook
(
type
:
Type
,
obj
:
Any
)
->
Any
:
"""Custom msgspec dec hook that supports array types.
See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder
"""
if
type
is
array
:
deserialized
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
)
deserialized
.
frombytes
(
obj
)
return
deserialized
vllm/executor/ray_gpu_executor.py
View file @
ff7ec82c
...
@@ -4,9 +4,12 @@ from collections import defaultdict
...
@@ -4,9 +4,12 @@ from collections import defaultdict
from
itertools
import
islice
,
repeat
from
itertools
import
islice
,
repeat
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
import
msgspec
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.executor.distributed_gpu_executor
import
(
# yapf: disable
from
vllm.executor.distributed_gpu_executor
import
(
# yapf: disable
DistributedGPUExecutor
,
DistributedGPUExecutorAsync
)
DistributedGPUExecutor
,
DistributedGPUExecutorAsync
)
from
vllm.executor.msgspec_utils
import
encode_hook
from
vllm.executor.ray_utils
import
RayWorkerWrapper
,
ray
from
vllm.executor.ray_utils
import
RayWorkerWrapper
,
ray
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
...
@@ -60,6 +63,10 @@ class RayGPUExecutor(DistributedGPUExecutor):
...
@@ -60,6 +63,10 @@ class RayGPUExecutor(DistributedGPUExecutor):
# Create the parallel GPU workers.
# Create the parallel GPU workers.
self
.
_init_workers_ray
(
placement_group
)
self
.
_init_workers_ray
(
placement_group
)
self
.
input_encoder
=
msgspec
.
msgpack
.
Encoder
(
enc_hook
=
encode_hook
)
self
.
output_decoder
=
msgspec
.
msgpack
.
Decoder
(
Optional
[
List
[
SamplerOutput
]])
def
shutdown
(
self
)
->
None
:
def
shutdown
(
self
)
->
None
:
if
hasattr
(
self
,
"forward_dag"
)
and
self
.
forward_dag
is
not
None
:
if
hasattr
(
self
,
"forward_dag"
)
and
self
.
forward_dag
is
not
None
:
self
.
forward_dag
.
teardown
()
self
.
forward_dag
.
teardown
()
...
@@ -123,6 +130,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
...
@@ -123,6 +130,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
ray_remote_kwargs
)
ray_remote_kwargs
)
logger
.
info
(
"use_ray_spmd_worker: %s"
,
self
.
use_ray_spmd_worker
)
logger
.
info
(
"use_ray_spmd_worker: %s"
,
self
.
use_ray_spmd_worker
)
# Create the workers.
# Create the workers.
driver_ip
=
get_ip
()
driver_ip
=
get_ip
()
worker_wrapper_kwargs
=
self
.
_get_worker_wrapper_args
()
worker_wrapper_kwargs
=
self
.
_get_worker_wrapper_args
()
...
@@ -304,8 +312,10 @@ class RayGPUExecutor(DistributedGPUExecutor):
...
@@ -304,8 +312,10 @@ class RayGPUExecutor(DistributedGPUExecutor):
if
self
.
forward_dag
is
None
:
if
self
.
forward_dag
is
None
:
self
.
forward_dag
=
self
.
_compiled_ray_dag
(
enable_asyncio
=
False
)
self
.
forward_dag
=
self
.
_compiled_ray_dag
(
enable_asyncio
=
False
)
outputs
=
ray
.
get
(
self
.
forward_dag
.
execute
(
execute_model_req
))
serialized_data
=
self
.
input_encoder
.
encode
(
execute_model_req
)
return
outputs
[
0
]
outputs
=
ray
.
get
(
self
.
forward_dag
.
execute
(
serialized_data
))
output
=
self
.
output_decoder
.
decode
(
outputs
[
0
])
return
output
def
_run_workers
(
def
_run_workers
(
self
,
self
,
...
@@ -475,9 +485,10 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
...
@@ -475,9 +485,10 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
if
self
.
forward_dag
is
None
:
if
self
.
forward_dag
is
None
:
self
.
forward_dag
=
self
.
_compiled_ray_dag
(
enable_asyncio
=
True
)
self
.
forward_dag
=
self
.
_compiled_ray_dag
(
enable_asyncio
=
True
)
dag_future
=
await
self
.
forward_dag
.
execute_async
(
execute_model_req
)
serialized_data
=
self
.
input_encoder
.
encode
(
execute_model_req
)
dag_future
=
await
self
.
forward_dag
.
execute_async
(
serialized_data
)
outputs
=
await
dag_future
outputs
=
await
dag_future
return
outputs
[
0
]
return
self
.
output_decoder
.
decode
(
outputs
[
0
]
)
async
def
_driver_execute_model_async
(
async
def
_driver_execute_model_async
(
self
,
self
,
...
...
vllm/executor/ray_utils.py
View file @
ff7ec82c
from
typing
import
List
,
Optional
,
Tuple
,
Union
from
typing
import
List
,
Optional
,
Tuple
,
Union
import
msgspec
from
vllm.config
import
ParallelConfig
from
vllm.config
import
ParallelConfig
from
vllm.executor.msgspec_utils
import
decode_hook
,
encode_hook
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
ExecuteModelRequest
,
IntermediateTensors
from
vllm.sequence
import
ExecuteModelRequest
,
IntermediateTensors
...
@@ -24,6 +27,10 @@ try:
...
@@ -24,6 +27,10 @@ try:
# that thread.
# that thread.
self
.
compiled_dag_cuda_device_set
=
False
self
.
compiled_dag_cuda_device_set
=
False
self
.
input_decoder
=
msgspec
.
msgpack
.
Decoder
(
ExecuteModelRequest
,
dec_hook
=
decode_hook
)
self
.
output_encoder
=
msgspec
.
msgpack
.
Encoder
(
enc_hook
=
encode_hook
)
def
get_node_ip
(
self
)
->
str
:
def
get_node_ip
(
self
)
->
str
:
return
get_ip
()
return
get_ip
()
...
@@ -33,16 +40,26 @@ try:
...
@@ -33,16 +40,26 @@ try:
return
node_id
,
gpu_ids
return
node_id
,
gpu_ids
def
execute_model_spmd
(
def
execute_model_spmd
(
self
,
req_or_tuple
:
Union
[
ExecuteModelRequest
,
self
,
req_or_tuple
:
Union
[
bytes
,
Tuple
[
ExecuteModelRequest
,
Tuple
[
bytes
,
IntermediateTensors
]]):
Optional
[
IntermediateTensors
]]]
)
->
bytes
:
"""Execute model in SPMD fashion: used only when SPMD worker and
"""Execute model in SPMD fashion: used only when SPMD worker and
compiled DAG are both enabled.
compiled DAG are both enabled.
Args:
Args:
req_or_tuple: The request to execute the model, or a tuple
req_or_tuple: A request or a tuple containing the
containing the request and intermediate tensors.
request and intermediate tensors. Intermediate tensors are
None unless if it is provided because it is > 0 pipeline
stage. The request is serialized by msgspec.
"""
"""
if
isinstance
(
req_or_tuple
,
bytes
):
serialized_req
,
intermediate_tensors
=
req_or_tuple
,
None
else
:
serialized_req
,
intermediate_tensors
=
req_or_tuple
execute_model_req
=
self
.
input_decoder
.
decode
(
serialized_req
)
# TODO(swang): This is needed right now because Ray aDAG executes
# TODO(swang): This is needed right now because Ray aDAG executes
# on a background thread, so we need to reset torch's current
# on a background thread, so we need to reset torch's current
# device.
# device.
...
@@ -51,16 +68,14 @@ try:
...
@@ -51,16 +68,14 @@ try:
torch
.
cuda
.
set_device
(
self
.
worker
.
device
)
torch
.
cuda
.
set_device
(
self
.
worker
.
device
)
self
.
compiled_dag_cuda_device_set
=
True
self
.
compiled_dag_cuda_device_set
=
True
if
isinstance
(
req_or_tuple
,
tuple
):
execute_model_req
,
intermediate_tensors
=
req_or_tuple
else
:
execute_model_req
=
req_or_tuple
intermediate_tensors
=
None
output
=
self
.
worker
.
_execute_model_spmd
(
execute_model_req
,
output
=
self
.
worker
.
_execute_model_spmd
(
execute_model_req
,
intermediate_tensors
)
intermediate_tensors
)
# Pipeline model request and output to the next pipeline stage.
if
isinstance
(
output
,
IntermediateTensors
):
if
isinstance
(
output
,
IntermediateTensors
):
return
execute_model_req
,
output
output
=
serialized_req
,
output
else
:
output
=
self
.
output_encoder
.
encode
(
output
)
return
output
return
output
ray_import_err
=
None
ray_import_err
=
None
...
...
vllm/inputs/registry.py
View file @
ff7ec82c
import
functools
import
functools
from
array
import
array
from
collections
import
UserDict
from
collections
import
UserDict
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
(
TYPE_CHECKING
,
Callable
,
Dict
,
Mapping
,
Optional
,
Protocol
,
from
typing
import
(
TYPE_CHECKING
,
Callable
,
Dict
,
Mapping
,
Optional
,
Protocol
,
...
@@ -21,6 +22,10 @@ logger = init_logger(__name__)
...
@@ -21,6 +22,10 @@ logger = init_logger(__name__)
C
=
TypeVar
(
"C"
,
bound
=
PretrainedConfig
,
default
=
PretrainedConfig
)
C
=
TypeVar
(
"C"
,
bound
=
PretrainedConfig
,
default
=
PretrainedConfig
)
# NOTE: This has to match with sequence.py's VLLM_TOKEN_ID_ARRAY_TYPE.
# We cannot import it here because of circular dependencies.
VLLM_TOKEN_ID_ARRAY_TYPE
=
"l"
@
dataclass
(
frozen
=
True
)
@
dataclass
(
frozen
=
True
)
class
InputContext
:
class
InputContext
:
...
@@ -118,7 +123,8 @@ class InputRegistry:
...
@@ -118,7 +123,8 @@ class InputRegistry:
# Avoid circular import
# Avoid circular import
from
vllm.sequence
import
SequenceData
from
vllm.sequence
import
SequenceData
dummy_seq_data
=
SequenceData
([
0
]
*
seq_len
)
dummy_seq_data
=
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
0
])
*
seq_len
)
dummy_multi_modal_data
=
None
dummy_multi_modal_data
=
None
return
dummy_seq_data
,
dummy_multi_modal_data
return
dummy_seq_data
,
dummy_multi_modal_data
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment