norm / vllm · Commit 51679bbd

Authored Feb 01, 2024 by zhuwenwen

    resolve merge conflicts

Parents: 4095d0db, 1af090b5

Changes: 170 files. The first 20 changed files are shown below, with 1654 additions and 121 deletions (+1654 -121) across these 20 files.
tests/samplers/test_logprobs.py                      +1    -0
tests/samplers/test_rejection_sampler.py             +392  -0
tests/samplers/test_sampler.py                       +91   -10
tests/test_regression.py                             +13   -0
tests/test_sampling_params.py                        +13   -0
tests/worker/__init__.py                             +0    -0
tests/worker/spec_decode/__init__.py                 +0    -0
tests/worker/spec_decode/test_multi_step_worker.py   +261  -0
tests/worker/spec_decode/utils.py                    +177  -0
tests/worker/test_model_runner.py                    +4    -3
vllm/__init__.py                                     +1    -1
vllm/block.py                                        +4    -0
vllm/config.py                                       +94   -2
vllm/core/block_manager.py                           +42   -5
vllm/core/policy.py                                  +10   -8
vllm/core/scheduler.py                               +114  -33
vllm/engine/arg_utils.py                             +67   -5
vllm/engine/async_llm_engine.py                      +143  -14
vllm/engine/llm_engine.py                            +222  -34
vllm/engine/ray_utils.py                             +5    -6
tests/samplers/test_logprobs.py

@@ -30,6 +30,7 @@ def test_get_prompt_logprobs(
         temperature=0.0)
     vllm_results = vllm_model.model.generate(
         example_prompts, sampling_params=vllm_sampling_params)
+    del vllm_model
 
     # Test whether logprobs are included in the results.
     for result in vllm_results:
tests/samplers/test_rejection_sampler.py  (new file, 0 → 100644)

"""Tests for rejection sampling."""
import pytest
from typing import List, Tuple

import torch
import torch.nn.functional as F

from vllm.model_executor.utils import set_random_seed
from vllm.model_executor.layers.rejection_sampler import RejectionSampler


def mock_causal_accepted_tensor(
        k: int, last_accepted_indices: torch.Tensor) -> torch.Tensor:
    """Generate an "accepted" tensor which should yield causally-accepted tokens
    up to last accepted indices.

    Tokens after last_accepted_indices+1 may also be accepted, although they
    will not be causally accepted.
    """
    batch_size = last_accepted_indices.shape[0]

    accepted = (torch.arange(k).expand(batch_size, k) <=
                last_accepted_indices.unsqueeze(-1).broadcast_to(
                    batch_size, k)).to(device="cuda")

    # Sprinkle accepted values after the contiguous initial accepted values.
    # This replicates the behavior of rejection sampling, which may "accept"
    # a token that cannot be accepted because of causality.
    sprinkle_candidates = (
        torch.arange(k).expand(batch_size, k) >
        last_accepted_indices.unsqueeze(-1).broadcast_to(batch_size, k) + 1)
    sprinkle = torch.rand(batch_size, k, device="cuda") > 0.5
    accepted[sprinkle_candidates] = sprinkle[sprinkle_candidates]
    return accepted


@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize(
    "which_tokens_accepted",
    ["all_tokens_accepted", "no_tokens_accepted", "some_tokens_accepted"])
@torch.inference_mode()
def test_correct_output_format(which_tokens_accepted: str, seed: int):
    """Verify the output has correct format given predetermined accepted matrix.
    """
    set_random_seed(seed)

    batch_size = 10
    k = 5
    vocab_size = 3000

    if which_tokens_accepted == "all_tokens_accepted":
        accepted = mock_causal_accepted_tensor(
            k, -1 + k * torch.ones((batch_size, ), dtype=torch.long))
    elif which_tokens_accepted == "no_tokens_accepted":
        accepted = mock_causal_accepted_tensor(
            k, -torch.ones((batch_size, ), dtype=torch.long))
    elif which_tokens_accepted == "some_tokens_accepted":
        last_accepted_indices = torch.randint(low=-1,
                                              high=k,
                                              size=(batch_size, ))
        accepted = mock_causal_accepted_tensor(k, last_accepted_indices)
    else:
        raise AssertionError()

    recovered_token_ids = torch.randint(low=0, high=vocab_size,
                                        size=(batch_size, k),
                                        dtype=torch.int64, device="cuda")
    draft_token_ids = torch.randint(low=0, high=vocab_size,
                                    size=(batch_size, k),
                                    dtype=torch.int64, device="cuda")
    bonus_token_ids = torch.randint(low=0, high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64, device="cuda")

    rejection_sampler = RejectionSampler()
    rejection_sampler.init_gpu_tensors(rank=0)
    output_token_ids = rejection_sampler._create_output(  # pylint: disable=protected-access
        accepted,
        recovered_token_ids,
        draft_token_ids,
        bonus_token_ids,
    )

    if which_tokens_accepted == "all_tokens_accepted":
        # Expect all tokens to be equal to draft tokens.
        assert torch.equal(output_token_ids[:, :-1], draft_token_ids)

        # Expect all bonus tokens to be included.
        assert torch.equal(output_token_ids[:, -1:], bonus_token_ids)
    elif which_tokens_accepted == "no_tokens_accepted":
        # Expect first token to be equal to recovered tokens.
        assert torch.equal(output_token_ids[:, 0], recovered_token_ids[:, 0])

        # Expect everything else to be -1.
        assert torch.equal(output_token_ids[:, 1:],
                           torch.ones_like(output_token_ids[:, 1:]) * -1)
    elif which_tokens_accepted == "some_tokens_accepted":
        recovered_plus_bonus = torch.cat(
            (recovered_token_ids, bonus_token_ids), dim=-1)
        # Assert first rejected token is a recovered token or bonus token.
        assert torch.equal(
            recovered_plus_bonus[torch.arange(0, batch_size),
                                 last_accepted_indices + 1],
            output_token_ids[torch.arange(0, batch_size),
                             last_accepted_indices + 1])

        # Assert every subsequent token is -1.
        subsequent_mask = torch.arange(0, k + 1).expand(
            batch_size, k + 1) >= (last_accepted_indices + 2).unsqueeze(-1)
        assert torch.all(output_token_ids[subsequent_mask] == -1)


@pytest.mark.parametrize("k", list(range(1, 6)))
@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
@pytest.mark.parametrize("batch_size", list(range(1, 32)))
@torch.inference_mode()
def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int):
    rejection_sampler = RejectionSampler()
    rejection_sampler.init_gpu_tensors(rank=0)

    draft_probs = torch.rand(batch_size, k, vocab_size,
                             dtype=torch.float32, device="cuda")
    target_probs = torch.rand(batch_size, k, vocab_size,
                              dtype=torch.float32, device="cuda")
    bonus_token_ids = torch.randint(low=0, high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64, device="cuda")
    draft_token_ids = torch.randint(low=0, high=vocab_size,
                                    size=(batch_size, k),
                                    dtype=torch.int64, device="cuda")

    rejection_sampler(target_probs, bonus_token_ids, draft_probs,
                      draft_token_ids)


@pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"])
@pytest.mark.parametrize("which_token_ids",
                         ["bonus_token_ids", "draft_token_ids"])
@torch.inference_mode()
def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
                               which_token_ids: str):
    k = 3
    batch_size = 5
    vocab_size = 30_000

    rejection_sampler = RejectionSampler(strict_mode=True)
    rejection_sampler.init_gpu_tensors(rank=0)

    draft_probs = torch.rand(batch_size, k, vocab_size,
                             dtype=torch.float32, device="cuda")
    target_probs = torch.rand(batch_size, k, vocab_size,
                              dtype=torch.float32, device="cuda")
    bonus_token_ids = torch.randint(low=0, high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64, device="cuda")
    draft_token_ids = torch.randint(low=0, high=vocab_size,
                                    size=(batch_size, k),
                                    dtype=torch.int64, device="cuda")

    oob_token_ids = None
    if which_token_ids == "bonus_token_ids":
        oob_token_ids = bonus_token_ids
    elif which_token_ids == "draft_token_ids":
        oob_token_ids = draft_token_ids
    else:
        raise AssertionError()

    if above_or_below_vocab_range == "above":
        rogue_token_id = vocab_size + 1
    elif above_or_below_vocab_range == "below":
        rogue_token_id = -1
    else:
        raise AssertionError()

    oob_token_ids[0][0] = rogue_token_id

    with pytest.raises(AssertionError):
        rejection_sampler(target_probs, bonus_token_ids, draft_probs,
                          draft_token_ids)


@pytest.mark.parametrize("draft_and_target_probs_equal", [True, False])
@pytest.mark.parametrize("seed", list(range(5)))
@torch.inference_mode()
def test_rejection_sampling_approximates_target_distribution(
        seed: int, draft_and_target_probs_equal: bool):
    """Verify rejection sampling approximates target distribution,
    despite sampling from a potentially distinct draft distribution.

    This is done by first creating a random target probability
    distribution and a random draft probability distribution. We then
    sample token ids from the rejection sampler using these draft
    and target distributions. The samples are used to estimate
    the output probability distribution, which we expect to approximate
    the target distribution.

    A basic distance metric is used to determine similarity between
    distributions.

    We expect that as we increase the number of samples,
    the distance between the observed distribution and the target
    distribution decreases. To measure this, we compare the distance
    of the observed distribution against both the target distribution
    and a uniform random distribution. We expect the distance between
    the observed distribution and the target distribution to improve
    much more than the distance improvement between the observed
    distribution and the random distribution.

    When draft_and_target_probs_equal=True, the draft and target
    probabilities are exactly equal. Rejection sampling should
    still work without any NaNs or exceptions.
    """
    set_random_seed(seed)

    helper = _CorrectnessTestHelper(
        vocab_size=10,
        rejection_sampler=RejectionSampler(),
    )

    draft_probs, target_probs, reference_probs = helper.generate_probs_for_test(
        draft_and_target_probs_equal)

    sample_sizes = [10, 100, 1_000, 10_000, 100_000]
    distance_wrt_reference = []
    distance_wrt_target = []

    for num_samples in sample_sizes:
        (reference_vs_rejsample_dist,
         target_vs_rejsample_dist) = helper.run_and_compare_distributions(
             draft_probs,
             target_probs,
             reference_probs,
             num_samples,
         )

        distance_wrt_reference.append(reference_vs_rejsample_dist)
        distance_wrt_target.append(target_vs_rejsample_dist)

        relative_change_in_distance_wrt_target = get_ratio_first_to_last(
            distance_wrt_target)
        relative_change_in_distance_wrt_reference = get_ratio_first_to_last(
            distance_wrt_reference)

        print(f"{num_samples=} {target_vs_rejsample_dist=:.05f} "
              f"{reference_vs_rejsample_dist=:.05f}")
        print(f"{num_samples=} {relative_change_in_distance_wrt_target=:.02f} "
              f"{relative_change_in_distance_wrt_reference=:.02f}")

    relative_change_in_distance_wrt_target = get_ratio_first_to_last(
        distance_wrt_target)
    relative_change_in_distance_wrt_reference = get_ratio_first_to_last(
        distance_wrt_reference)

    expected_improvement_multiplier = 20
    assert (relative_change_in_distance_wrt_target >
            relative_change_in_distance_wrt_reference *
            expected_improvement_multiplier)


def get_ratio_first_to_last(elements: List[float]) -> float:
    return elements[0] / elements[-1]


class _CorrectnessTestHelper:
    """Class that packages together logic required for the unit-level
    rejection sampling correctness test.
    """

    def __init__(self, vocab_size: int, rejection_sampler: RejectionSampler):
        self.rejection_sampler = rejection_sampler
        self.vocab_size = vocab_size
        self.vocab_range = (0, vocab_size)

        self.rejection_sampler.init_gpu_tensors(rank=0)

        # Keep test simple, use k=1
        self.k = 1

        # Bonus tokens not used, but rejection sampler requires
        # correct shape.
        self.num_bonus_tokens = 1

    def generate_probs_for_test(
        self, draft_and_target_probs_equal: bool
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        draft_probs, target_probs = [
            F.softmax(
                torch.rand(self.vocab_size, dtype=torch.float32),
                dim=-1,
            ) for _ in range(2)
        ]

        num_reference_probs = 100
        reference_probs = F.softmax(
            torch.rand(num_reference_probs, self.vocab_size,
                       dtype=torch.float32),
            dim=-1,
        )

        if draft_and_target_probs_equal:
            target_probs = draft_probs.clone()

        return draft_probs, target_probs, reference_probs

    def run_and_compare_distributions(self, draft_probs: torch.Tensor,
                                      target_probs: torch.Tensor,
                                      reference_probs: torch.Tensor,
                                      num_samples: int) -> Tuple[float, float]:
        # Sample using rejection sampling.
        rej_sample_probs = self._estimate_rejection_sampling_pdf(
            draft_probs, target_probs, num_samples)

        # Average distance from reference probs.
        reference_vs_rejsample_dist = torch.dist(
            reference_probs,
            rej_sample_probs).item() / reference_probs.shape[0]
        target_vs_rejsample_dist = torch.dist(target_probs,
                                              rej_sample_probs).item()

        return reference_vs_rejsample_dist, target_vs_rejsample_dist

    def _estimate_rejection_sampling_pdf(
        self,
        draft_probs: torch.Tensor,
        target_probs: torch.Tensor,
        num_samples: int,
    ) -> torch.Tensor:
        # Repeat draft probs num_samples times.
        draft_probs = draft_probs.reshape(1, self.k, self.vocab_size).repeat(
            num_samples, 1, 1)

        # Repeat target probs num_samples * k times.
        # Rejection sampler requires bonus token probs, but they aren't used.
        target_probs = target_probs.reshape(1, 1, self.vocab_size).repeat(
            num_samples, self.k, 1)

        # Randomly sample draft token ids from draft probs.
        draft_token_ids = torch.multinomial(draft_probs[:, 0, :],
                                            num_samples=1,
                                            replacement=True).reshape(
                                                num_samples, self.k)

        # Bonus tokens not used but required.
        bonus_token_ids = torch.zeros((1, self.num_bonus_tokens),
                                      dtype=torch.int64,
                                      device="cuda").repeat(num_samples, 1)

        # Get output tokens via rejection sampling.
        output_token_ids = self.rejection_sampler(target_probs.to("cuda"),
                                                  bonus_token_ids.to("cuda"),
                                                  draft_probs.to("cuda"),
                                                  draft_token_ids.to("cuda"))

        # Remove bonus tokens
        output_token_ids = output_token_ids[:, :-1].flatten()

        # Estimate probability density function
        hist = torch.histogram(output_token_ids.to(dtype=torch.float,
                                                   device="cpu"),
                               bins=self.vocab_size,
                               range=self.vocab_range,
                               density=True)

        return hist.hist
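
The correctness test above estimates an output PDF from the sampler and checks that it converges toward the target distribution. As background for what the tests expect RejectionSampler to compute, here is a minimal standalone sketch of the textbook speculative-decoding acceptance rule; it is plain PyTorch, not vLLM's RejectionSampler API, and the function name and shapes are illustrative assumptions:

import torch

def rejection_sample_once(draft_probs: torch.Tensor,
                          target_probs: torch.Tensor,
                          draft_token: int) -> int:
    # Illustrative sketch of the standard acceptance rule, not vLLM's kernel.
    # Accept the draft token x with probability min(1, p_target(x) / p_draft(x)).
    p_t = target_probs[draft_token]
    p_d = draft_probs[draft_token]
    if torch.rand(()) < torch.clamp(p_t / p_d, max=1.0):
        return draft_token
    # Otherwise resample from the normalized positive part of (target - draft).
    # (This branch is never reached when draft == target, since acceptance
    # probability is then 1.)
    residual = torch.clamp(target_probs - draft_probs, min=0.0)
    residual = residual / residual.sum()
    return int(torch.multinomial(residual, num_samples=1))

Averaged over many trials this rule reproduces the target distribution exactly, which is the property test_rejection_sampling_approximates_target_distribution checks empirically with its distance metric.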
tests/samplers/test_sampler.py

@@ -4,6 +4,7 @@ from unittest.mock import patch
 import pytest
 import torch
+from transformers import GenerationConfig, GenerationMixin
 
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.utils import set_random_seed

@@ -18,10 +19,11 @@ class MockLogitsSampler(Sampler):
         self.fake_logits = fake_logits
 
     def forward(self, *args, **kwargs):
-        with patch("vllm.model_executor.layers.sampler._prune_hidden_states",
-                   lambda x, y: x), patch(
-                       "vllm.model_executor.layers.sampler._get_logits",
-                       lambda *args, **kwargs: self.fake_logits):
+        with patch(
+                "vllm.model_executor.layers.sampler._prune_hidden_states",
+                lambda x, y: x), patch(
+                    "vllm.model_executor.layers.sampler.Sampler._get_logits",
+                    lambda *args, **kwargs: self.fake_logits):
             return super().forward(*args, **kwargs)

@@ -37,7 +39,7 @@ def _prepare_test(
                                device=input_tensor.device,
                                dtype=input_tensor.dtype)
     sampler = MockLogitsSampler(32000, fake_logits)
-    model_runner = ModelRunner(None, None, None)
+    model_runner = ModelRunner(None, None, None, None)
     return input_tensor, fake_logits, sampler, model_runner

@@ -65,7 +67,8 @@ def test_sampler_all_greedy(seed: int):
         prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
 
     sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
-                                                     prompt_lens)
+                                                     prompt_lens,
+                                                     subquery_lens=prompt_lens)
     sampler_output = sampler(embedding=None,
                              hidden_states=input_tensor,
                              sampling_metadata=sampling_metadata)

@@ -74,6 +77,8 @@ def test_sampler_all_greedy(seed: int):
         for nth_output in sequence_output.samples:
             assert nth_output.output_token == expected[i].item()
 
+    del model_runner
+
 
 @pytest.mark.parametrize("seed", RANDOM_SEEDS)
 def test_sampler_all_random(seed: int):

@@ -102,7 +107,8 @@ def test_sampler_all_random(seed: int):
         prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
 
     sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
-                                                     prompt_lens)
+                                                     prompt_lens,
+                                                     subquery_lens=prompt_lens)
     sampler_output = sampler(embedding=None,
                              hidden_states=input_tensor,
                              sampling_metadata=sampling_metadata)

@@ -110,6 +116,8 @@ def test_sampler_all_random(seed: int):
         for nth_output in sequence_output.samples:
             assert nth_output.output_token == i
 
+    del model_runner
+
 
 @pytest.mark.parametrize("seed", RANDOM_SEEDS)
 def test_sampler_all_beam(seed: int):

@@ -135,7 +143,8 @@ def test_sampler_all_beam(seed: int):
         prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
 
     sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
-                                                     prompt_lens)
+                                                     prompt_lens,
+                                                     subquery_lens=prompt_lens)
     sampler(embedding=None,
             hidden_states=input_tensor,
             sampling_metadata=sampling_metadata)

@@ -143,6 +152,7 @@ def test_sampler_all_beam(seed: int):
     # the outputs are expected - in other words, this just tests
     # whether there are no exceptions in the sampler
     # when handling an all-beam search case.
+    del model_runner
 
 
 @pytest.mark.parametrize("seed", RANDOM_SEEDS)

@@ -187,7 +197,8 @@ def test_sampler_mixed(seed: int):
         prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
 
     sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
-                                                     prompt_lens)
+                                                     prompt_lens,
+                                                     subquery_lens=prompt_lens)
     sampler_output = sampler(embedding=None,
                              hidden_states=input_tensor,
                              sampling_metadata=sampling_metadata)

@@ -197,6 +208,8 @@ def test_sampler_mixed(seed: int):
         for nth_output in sequence_output.samples:
             assert nth_output.output_token in expected_tokens
 
+    del model_runner
+
 
 @pytest.mark.parametrize("seed", RANDOM_SEEDS)
 def test_sampler_logits_processors(seed: int):

@@ -226,10 +239,78 @@ def test_sampler_logits_processors(seed: int):
         prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
 
     sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
-                                                     prompt_lens)
+                                                     prompt_lens,
+                                                     subquery_lens=prompt_lens)
     sampler_output = sampler(embedding=None,
                              hidden_states=input_tensor,
                              sampling_metadata=sampling_metadata)
     for _, sequence_output in enumerate(sampler_output):
         for idx, nth_output in enumerate(sequence_output.samples):
             assert nth_output.output_token == idx
 
+    del model_runner
+
+
+@pytest.mark.parametrize("seed", RANDOM_SEEDS)
+def test_sampler_top_k_top_p(seed: int):
+    set_random_seed(seed)
+    batch_size = random.randint(1, 256)
+    top_k = random.randint(100, 500)
+    top_p = random.random() * 0.1
+    vocab_size = 32000
+    input_tensor = torch.rand((batch_size, 1024),
+                              device="cuda",
+                              dtype=torch.float16)
+    fake_logits = torch.normal(0,
+                               5,
+                               size=(batch_size, vocab_size),
+                               device=input_tensor.device,
+                               dtype=input_tensor.dtype)
+    sampler = MockLogitsSampler(32000, fake_logits)
+    model_runner = ModelRunner(None, None, None, None)
+
+    generation_model = GenerationMixin()
+    generation_config = GenerationConfig(top_k=top_k,
+                                         top_p=top_p,
+                                         do_sample=True)
+    warpers = generation_model._get_logits_warper(generation_config)
+    assert len(warpers) == 2  # top_p and top_k
+
+    seq_group_metadata_list = []
+    prompt_lens = []
+    for i in range(batch_size):
+        seq_group_metadata_list.append(
+            SequenceGroupMetadata(
+                request_id=f"test_{i}",
+                is_prompt=True,
+                seq_data={0: SequenceData([1, 2, 3])},
+                sampling_params=SamplingParams(
+                    temperature=1,
+                    top_k=top_k,
+                    top_p=top_p,
+                ),
+                block_tables={0: [1]},
+            ))
+        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
+
+    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
+                                                     prompt_lens,
+                                                     subquery_lens=prompt_lens)
+
+    sample_probs = None
+
+    def mock_sample(probs, logprobs, sampling_metadata):
+        nonlocal sample_probs
+        sample_probs = probs
+        return [[prob.topk(1, dim=-1).indices.tolist(), [0]]
+                for prob in probs]
+
+    with patch("vllm.model_executor.layers.sampler._sample", mock_sample):
+        sampler(embedding=None,
+                hidden_states=input_tensor,
+                sampling_metadata=sampling_metadata)
+
+    hf_probs = warpers(torch.zeros_like(fake_logits), fake_logits.clone())
+    hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float)
+    assert torch.allclose(hf_probs, sample_probs, atol=1e-5)
+    assert torch.equal(hf_probs.eq(0), sample_probs.eq(0))
+
+    del model_runner
tests/test_regression.py

@@ -22,6 +22,19 @@ def test_duplicated_ignored_sequence_group():
     assert len(prompts) == len(outputs)
 
 
+def test_max_tokens_none():
+    sampling_params = SamplingParams(temperature=0.01,
+                                     top_p=0.1,
+                                     max_tokens=None)
+    llm = LLM(model="facebook/opt-125m",
+              max_num_batched_tokens=4096,
+              tensor_parallel_size=1)
+    prompts = ["Just say hello!"]
+    outputs = llm.generate(prompts, sampling_params=sampling_params)
+    assert len(prompts) == len(outputs)
+
+
 if __name__ == "__main__":
     import pytest
     pytest.main([__file__])
tests/test_sampling_params.py  (new file, 0 → 100644)

"""Tests for the SamplingParams class.
"""
from vllm import SamplingParams


def test_max_tokens_none():
    """max_tokens=None should be allowed"""
    SamplingParams(temperature=0.01, top_p=0.1, max_tokens=None)


if __name__ == "__main__":
    import pytest
    pytest.main([__file__])
tests/worker/__init__.py  (new empty file, 0 → 100644)

tests/worker/spec_decode/__init__.py  (new empty file, 0 → 100644)
tests/worker/spec_decode/test_multi_step_worker.py  (new file, 0 → 100644)

import torch
import random
import pytest
from unittest.mock import MagicMock

from vllm.worker.spec_decode.multi_step_worker import MultiStepWorker
from vllm.worker.worker import Worker
from vllm.model_executor.utils import set_random_seed

from .utils import (create_execute_model_data, create_worker,
                    create_seq_group_metadata_from_prompts, zero_kv_cache,
                    patch_execute_model_with_seeds,
                    assert_logprobs_dict_allclose)


@pytest.mark.parametrize('num_steps', list(range(1, 17)))
def test_assert_enough_kv_space(num_steps: int):
    """Test that the multi step worker checks for sufficient space in the KV
    cache. It should throw if it cannot run all the steps.
    """
    block_size = 16
    num_gpu_blocks = 2048 // block_size

    prompts = [
        list(range(block_size * 3)),
        list(range(block_size * 2)),
    ]

    prev_output_tokens = [
        list(range(block_size * 1)),
        list(range(block_size * 2)),
    ]

    final_seq_lens = [
        len(prompt + output) + num_steps
        for prompt, output in zip(prompts, prev_output_tokens)
    ]

    inputs = create_seq_group_metadata_from_prompts(
        prompts,
        num_gpu_blocks,
        block_size,
        final_seq_lens,
        continuations=prev_output_tokens)

    assert_enough_kv_space = MultiStepWorker._assert_enough_kv_space  # pylint: disable=protected-access
    worker = MagicMock()
    worker.model_runner.block_size = block_size

    for seq_group_metadata in inputs:
        original_block_tables = seq_group_metadata.block_tables

        # No exception.
        assert_enough_kv_space(worker, inputs, num_steps)

        seq_group_metadata.block_tables = {
            seq_id: []
            for seq_id, physical_blocks in original_block_tables.items()
        }

        # Expect exception.
        with pytest.raises(ValueError,
                           match='times but found insufficient KV space for'):
            assert_enough_kv_space(worker, inputs, num_steps)

        seq_group_metadata.block_tables = original_block_tables


@torch.inference_mode()
def test_same_output_for_single_step():
    """Verify the multi step worker produces the same output as the normal
    worker for num_steps=1.
    """
    seed = 100
    model_name = 'JackFram/llama-68m'

    block_size = 32
    num_gpu_blocks = 2048 // block_size
    multi_step_worker = create_worker(
        MultiStepWorker,
        model_name,
        block_size,
        num_gpu_blocks,
        seed,
    )
    worker = create_worker(
        Worker,
        model_name,
        block_size,
        num_gpu_blocks,
        seed,
    )
    multi_step_worker.model_runner = worker.model_runner
    multi_step_worker.cache_engine = worker.cache_engine

    num_steps = 1

    prompts = [
        [1, 2, 3, 4, 5],
        [6, 7, 8, 9, 10],
    ]

    final_seq_lens = [len(prompt) + num_steps for prompt in prompts]

    multi_step_execute_model_data = create_execute_model_data(
        seq_group_metadata_list=create_seq_group_metadata_from_prompts(
            prompts, num_gpu_blocks, block_size,
            final_seq_lens=final_seq_lens))

    single_step_execute_model_data = create_execute_model_data(
        seq_group_metadata_list=create_seq_group_metadata_from_prompts(
            prompts, num_gpu_blocks, block_size,
            final_seq_lens=final_seq_lens))

    zero_kv_cache(multi_step_worker.cache_engine)
    set_random_seed(seed)
    actual_output = multi_step_worker.execute_model_multi_step(
        **multi_step_execute_model_data.to_dict(), num_steps=num_steps)
    assert len(actual_output) == num_steps
    actual_output = actual_output[0]

    zero_kv_cache(worker.cache_engine)
    set_random_seed(seed)
    expected_output = worker.execute_model(
        **single_step_execute_model_data.to_dict(), )

    actual_token_ids = [
        output.samples[0].output_token for output in actual_output
    ]
    actual_logprobs = [output.samples[0].logprobs for output in actual_output]

    expected_token_ids = [
        output.samples[0].output_token for output in expected_output
    ]
    expected_logprobs = [
        output.samples[0].logprobs for output in expected_output
    ]

    assert actual_token_ids == expected_token_ids

    print(f'{actual_logprobs=}')
    print(f'{expected_logprobs=}')
    assert_logprobs_dict_allclose(actual_logprobs, expected_logprobs)


@torch.inference_mode()
def test_same_output_for_multi_step():
    """Verify the multi-step worker produces the same output as the normal
    worker when num_steps > 1. This test runs the multi-step worker once, and
    then runs the worker num_steps times, and compares the output.
    """
    seed = 100
    model_name = 'JackFram/llama-68m'

    block_size = 16
    num_gpu_blocks = 2048 // block_size
    multi_step_worker = create_worker(
        MultiStepWorker,
        model_name,
        block_size,
        num_gpu_blocks,
        seed,
    )

    worker = create_worker(
        Worker,
        model_name,
        block_size,
        num_gpu_blocks,
        seed,
    )

    # Make sure we go over the block boundary.
    num_steps = block_size + 1

    random.seed(seed)
    prompts = [[
        random.randint(0, 1000) for _ in range(random.randint(10, 20))
    ] for _ in range(10)]

    final_seq_lens = [len(prompt) + num_steps for prompt in prompts]

    rand_seeds = list(random.randint(0, 100) for _ in range(num_steps))
    multi_step_worker.execute_model = patch_execute_model_with_seeds(
        multi_step_worker, rand_seeds)
    worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds)

    continuations = [[1] for _ in prompts]
    execute_model_data = create_execute_model_data(
        create_seq_group_metadata_from_prompts(
            prompts,
            num_gpu_blocks,
            block_size,
            continuations=continuations,
            final_seq_lens=final_seq_lens), )

    # Run multi-step.
    zero_kv_cache(multi_step_worker.cache_engine)
    set_random_seed(seed)
    multi_step_output = multi_step_worker.execute_model_multi_step(
        **execute_model_data.to_dict(), num_steps=num_steps)

    # Run single-step repeatedly.
    zero_kv_cache(worker.cache_engine)
    single_step_output = []
    continuations = [[1] for _ in prompts]
    set_random_seed(seed)

    for _ in multi_step_output:

        execute_model_data = create_execute_model_data(
            create_seq_group_metadata_from_prompts(
                prompts,
                num_gpu_blocks,
                block_size,
                continuations=continuations,
                final_seq_lens=final_seq_lens))

        single_step_output.append(
            worker.execute_model(**execute_model_data.to_dict(), ))

        # Append output tokens to new sequence data.
        for i, seq_group_output in enumerate(single_step_output[-1]):
            continuations[i].append(seq_group_output.samples[0].output_token)

    # Get token ids and logprobs for comparison.
    multi_step_output_logprobs = [[] for _ in prompts]
    single_step_output_logprobs = [[] for _ in prompts]

    multi_step_output_token_ids = [[] for _ in prompts]
    single_step_output_token_ids = [[] for _ in prompts]
    for i, _ in enumerate(prompts):
        for multi_step, single_step in zip(multi_step_output,
                                           single_step_output):
            multi_step_output_token_ids[i].append(
                multi_step[i].samples[0].output_token)
            single_step_output_token_ids[i].append(
                single_step[i].samples[0].output_token)

            multi_step_output_logprobs[i].append(
                multi_step[i].samples[0].logprobs)
            single_step_output_logprobs[i].append(
                single_step[i].samples[0].logprobs)

    # Print per-sequence token ids
    for i, (multi_step_tokens, single_step_tokens) in enumerate(
            zip(multi_step_output_token_ids, single_step_output_token_ids)):
        print(f'{i=} {multi_step_tokens=}')
        print(f'{i=} {single_step_tokens=}')
        print(f'{i=} equal {multi_step_tokens == single_step_tokens}')

    # Assert token ids are equal.
    for multi_step_tokens, single_step_tokens in zip(
            multi_step_output_token_ids, single_step_output_token_ids):
        assert multi_step_tokens == single_step_tokens

    # Assert logprobs are equal.
    for multi_step_logprobs, single_step_logprobs in zip(
            multi_step_output_logprobs, single_step_output_logprobs):
        assert_logprobs_dict_allclose(multi_step_logprobs,
                                      single_step_logprobs)
tests/worker/spec_decode/utils.py  (new file, 0 → 100644)

import torch
from typing import List, Optional, Dict

from vllm.worker.worker import Worker
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.engine.arg_utils import EngineArgs
from vllm.sequence import SequenceGroupMetadata, SequenceData
from vllm.sampling_params import SamplingParams
from vllm.worker.cache_engine import CacheEngine
from vllm.model_executor.utils import set_random_seed
from dataclasses import dataclass, fields


@dataclass
class ExecuteModelData:
    """Helper data structure which facilitates cleaner tests.
    """
    seq_group_metadata_list: List[SequenceGroupMetadata]
    blocks_to_swap_in: Dict[int, int]
    blocks_to_swap_out: Dict[int, int]
    blocks_to_copy: Dict[int, List[int]]

    def to_dict(self):
        return dict(
            (field.name, getattr(self, field.name)) for field in fields(self))


def round_up_to_next_block(seq_len: int, block_size: int) -> int:
    return (seq_len + block_size - 1) // block_size


def create_execute_model_data(
    seq_group_metadata_list: List[SequenceGroupMetadata],
    blocks_to_swap_in: Optional[Dict[int, int]] = None,
    blocks_to_swap_out: Optional[Dict[int, int]] = None,
    blocks_to_copy: Optional[Dict[int, int]] = None,
) -> ExecuteModelData:
    if blocks_to_swap_in is None:
        blocks_to_swap_in = {}
    if blocks_to_swap_out is None:
        blocks_to_swap_out = {}
    if blocks_to_copy is None:
        blocks_to_copy = {}

    return ExecuteModelData(
        seq_group_metadata_list=seq_group_metadata_list,
        blocks_to_swap_in=blocks_to_swap_in,
        blocks_to_swap_out=blocks_to_swap_out,
        blocks_to_copy=blocks_to_copy,
    )


def patch_execute_model_with_seeds(worker: Worker, rand_seeds: List[int]):
    seed_iter = iter(rand_seeds)
    original_execute_model = worker.execute_model

    def new_execute_model(*args, **kwargs):
        result = original_execute_model(*args, **kwargs)
        set_random_seed(next(seed_iter))
        return result

    return new_execute_model


def zero_kv_cache(cache_engine: CacheEngine):
    assert cache_engine.gpu_cache
    for key_blocks, value_blocks in cache_engine.gpu_cache:
        key_blocks.zero_()
        value_blocks.zero_()


def create_worker(cls: type,
                  model_name: str,
                  block_size: int,
                  num_gpu_blocks: int,
                  seed: int,
                  is_driver_worker: bool = True,
                  enforce_eager: bool = True):
    engine_args = EngineArgs(
        model=model_name,
        seed=seed,
        block_size=block_size,
        enforce_eager=enforce_eager,
    )

    (model_config, cache_config, parallel_config, scheduler_config,
     _) = engine_args.create_engine_configs()

    distributed_init_method = get_distributed_init_method(
        get_ip(), get_open_port())

    worker = cls(
        model_config=model_config,
        parallel_config=parallel_config,
        scheduler_config=scheduler_config,
        local_rank=0,
        rank=0,
        distributed_init_method=distributed_init_method,
        is_driver_worker=is_driver_worker,
    )

    worker.init_model()
    worker.load_model()

    cache_config.num_gpu_blocks = num_gpu_blocks
    cache_config.num_cpu_blocks = 0
    worker.init_cache_engine(cache_config)
    worker.warm_up_model()

    return worker


def create_seq_group_metadata_from_prompts(
    prompts: List[List[int]],
    num_gpu_blocks: int,
    block_size: int,
    final_seq_lens: List[int],
    continuations: Optional[List[List[int]]] = None,
    num_tokens_processed: Optional[List[int]] = None,
    seq_ids: Optional[List[int]] = None,
) -> List[SequenceGroupMetadata]:

    if continuations is None:
        continuations = [[] for _ in prompts]

    if num_tokens_processed is None:
        # Default to 1 token missing from kv cache for generation sequences.
        num_tokens_processed = []
        for continuation, prompt in zip(continuations, prompts):
            # If prefill, then default to zero tokens processed.
            if not continuation:
                num_tokens_processed.append(0)
            else:
                # If generation, then default to all but one tokens processed.
                num_tokens_processed.append(
                    len(continuation) + len(prompt) - 1)

    if seq_ids is None:
        seq_ids = list(i for i, _ in enumerate(prompts))

    free_gpu_blocks = list(range(num_gpu_blocks))

    block_allocations = {
        i: [
            free_gpu_blocks.pop()
            for _ in range(round_up_to_next_block(final_len, block_size))
        ]
        for i, final_len in enumerate(final_seq_lens)
    }

    return [
        SequenceGroupMetadata(
            request_id=str(i),
            is_prompt=len(cont_token_ids) == 0,
            seq_data={
                i: SequenceData(prompt_token_ids=prompt_token_ids[:] +
                                cont_token_ids[:])
            },
            sampling_params=SamplingParams(temperature=0.0, ),
            block_tables={i: block_allocations[i][:]},
        ) for i, (prompt_token_ids, cont_token_ids, num_tokens_saved) in
        enumerate(zip(prompts, continuations, num_tokens_processed))
    ]


def assert_logprobs_dict_allclose(
        actual_logprobs: List[Dict[int, float]],
        expected_logprobs: List[Dict[int, float]]) -> None:
    for single_step_actual_logprobs, single_step_expected_logprobs in zip(
            actual_logprobs, expected_logprobs):
        assert set(single_step_actual_logprobs.keys()) == set(
            single_step_expected_logprobs.keys())
        for token_id in single_step_actual_logprobs:
            actual = torch.tensor(single_step_actual_logprobs[token_id])
            expected = torch.tensor(single_step_expected_logprobs[token_id])
            assert torch.allclose(actual, expected)
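
The block allocation in create_seq_group_metadata_from_prompts reserves one physical KV-cache block per block_size tokens of the sequence's final length, using round_up_to_next_block for the ceiling division. A tiny standalone check of the same arithmetic (the helper is duplicated here only so the snippet runs without vLLM installed):

def round_up_to_next_block(seq_len: int, block_size: int) -> int:
    # Same arithmetic as the helper above: ceiling division.
    return (seq_len + block_size - 1) // block_size

# With block_size=16 as used in the multi-step tests:
assert round_up_to_next_block(35, 16) == 3   # 32 < 35 <= 48
assert round_up_to_next_block(32, 16) == 2   # exact multiple needs no extra block
assert round_up_to_next_block(1, 16) == 1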
tests/worker/test_model_runner.py

@@ -6,7 +6,7 @@ from vllm.worker.model_runner import ModelRunner
 def test_prepare_prompt():
-    model_runner = ModelRunner(None, None, None)
+    model_runner = ModelRunner(None, None, None, None)
     model_runner.set_block_size(16)
 
     batch_size = random.randint(1, 256)

@@ -33,11 +33,12 @@ def test_prepare_prompt():
         expected_selected_token_indices.append(selected_token_start_idx +
                                                prompt_len - 1)
         selected_token_start_idx += max_seq_len
-    input_tokens, input_positions, _, return_prompt_lens = (
+    input_tokens, input_positions, _, return_prompt_lens, _, _, _, _ = (
         model_runner._prepare_prompt(seq_group_metadata_list))
     assert return_prompt_lens == prompt_lens
     sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
-                                                     prompt_lens)
+                                                     prompt_lens,
+                                                     subquery_lens=prompt_lens)
     assert input_tokens.shape == (batch_size, max_seq_len)
     assert input_positions.shape == (batch_size, max_seq_len)
     torch.testing.assert_close(input_tokens, input_positions)
vllm/__init__.py

@@ -9,7 +9,7 @@ from vllm.outputs import CompletionOutput, RequestOutput
 from vllm.sampling_params import SamplingParams
 from vllm.version import __dcu_version__
 
-__version__ = "0.2.7"
+__version__ = "0.3.0"
 
 __all__ = [
     "LLM",
vllm/block.py

@@ -66,3 +66,7 @@ class PhysicalTokenBlock:
         return (f'PhysicalTokenBlock(device={self.device}, '
                 f'block_number={self.block_number}, '
                 f'ref_count={self.ref_count})')
+
+
+# Mapping: logical block number -> physical block.
+BlockTable = List[PhysicalTokenBlock]
vllm/config.py

(top of file)
-from typing import Optional, Union
+from typing import Optional, Union, ClassVar
+from dataclasses import dataclass
 import os
+from packaging.version import Version
 
 import torch
 from transformers import PretrainedConfig
 
 from vllm.logger import init_logger
 from vllm.transformers_utils.config import get_config
-from vllm.utils import get_cpu_memory, is_hip
+from vllm.utils import get_cpu_memory, is_hip, get_nvcc_cuda_version
 
 logger = init_logger(__name__)

@@ -212,6 +214,8 @@ class ModelConfig:
         return self.hf_config.hidden_size
 
     def get_head_size(self) -> int:
+        if hasattr(self.hf_config, "head_dim"):
+            return self.hf_config.head_dim
         # FIXME(woosuk): This may not be true for all models.
         return self.hf_config.hidden_size // self.hf_config.num_attention_heads

@@ -272,6 +276,7 @@ class CacheConfig:
         gpu_memory_utilization: Fraction of GPU memory to use for the
             vLLM execution.
         swap_space: Size of the CPU swap space per GPU (in GiB).
+        cache_dtype: Data type for kv cache storage.
     """
 
     def __init__(

@@ -279,13 +284,16 @@ class CacheConfig:
         block_size: int,
         gpu_memory_utilization: float,
         swap_space: int,
+        cache_dtype: str,
         sliding_window: Optional[int] = None,
     ) -> None:
         self.block_size = block_size
         self.gpu_memory_utilization = gpu_memory_utilization
         self.swap_space_bytes = swap_space * _GB
+        self.cache_dtype = cache_dtype
         self.sliding_window = sliding_window
         self._verify_args()
+        self._verify_cache_dtype()
 
         # Will be set after profiling.
         self.num_gpu_blocks = None

@@ -297,6 +305,28 @@ class CacheConfig:
                 "GPU memory utilization must be less than 1.0. Got "
                 f"{self.gpu_memory_utilization}.")
 
+    def _verify_cache_dtype(self) -> None:
+        if self.cache_dtype == "auto":
+            pass
+        elif self.cache_dtype == "fp8_e5m2":
+            nvcc_cuda_version = get_nvcc_cuda_version()
+            if nvcc_cuda_version < Version("11.8"):
+                raise ValueError(
+                    "FP8 is not supported when cuda version is lower than 11.8."
+                )
+            device_name = torch.cuda.get_device_name()
+            if "AMD" in device_name:
+                raise NotImplementedError(
+                    "FP8_E5M2 KV Cache on AMD GPU has not been supported yet.")
+            logger.info(
+                "Using fp8_e5m2 data type to store kv cache. It reduces "
+                "the GPU memory footprint and boosts the performance. "
+                "But it may cause slight accuracy drop. "
+                "Currently we only support fp8 without scaling factors and "
+                "make e5m2 as a default format.")
+        else:
+            raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")
+
     def verify_with_parallel_config(
         self,
         parallel_config: "ParallelConfig",

@@ -325,6 +355,8 @@ class ParallelConfig:
         worker_use_ray: Whether to use Ray for model workers. Will be set to
             True if either pipeline_parallel_size or tensor_parallel_size is
             greater than 1.
+        disable_custom_all_reduce: Disable the custom all-reduce kernel and
+            fall back to NCCL.
     """
 
     def __init__(

@@ -333,11 +365,13 @@ class ParallelConfig:
         tensor_parallel_size: int,
         worker_use_ray: bool,
         max_parallel_loading_workers: Optional[int] = None,
+        disable_custom_all_reduce: bool = False,
     ) -> None:
         self.pipeline_parallel_size = pipeline_parallel_size
         self.tensor_parallel_size = tensor_parallel_size
         self.worker_use_ray = worker_use_ray
         self.max_parallel_loading_workers = max_parallel_loading_workers
+        self.disable_custom_all_reduce = disable_custom_all_reduce
 
         self.world_size = pipeline_parallel_size * tensor_parallel_size
         if self.world_size > 1:

@@ -348,6 +382,16 @@ class ParallelConfig:
         if self.pipeline_parallel_size > 1:
             raise NotImplementedError(
                 "Pipeline parallelism is not supported yet.")
+        if is_hip():
+            self.disable_custom_all_reduce = True
+            logger.info(
+                "Disabled the custom all-reduce kernel because it is not "
+                "supported on AMD GPUs.")
+        elif self.pipeline_parallel_size > 1:
+            self.disable_custom_all_reduce = True
+            logger.info(
+                "Disabled the custom all-reduce kernel because it is not "
+                "supported with pipeline parallelism.")
 
 
 class SchedulerConfig:

@@ -397,6 +441,54 @@ class SchedulerConfig:
             f"({self.max_num_seqs}).")
 
 
+@dataclass
+class LoRAConfig:
+    max_lora_rank: int
+    max_loras: int
+    max_cpu_loras: Optional[int] = None
+    lora_dtype: Optional[torch.dtype] = None
+    lora_extra_vocab_size: int = 256
+    # This is a constant.
+    lora_vocab_padding_size: ClassVar[int] = 256
+
+    def __post_init__(self):
+        # Keep this in sync with csrc/punica/bgmv/bgmv_config.h
+        possible_max_ranks = (8, 16, 32, 64)
+        possible_lora_extra_vocab_size = (0, 256, 512)
+        if self.max_lora_rank not in possible_max_ranks:
+            raise ValueError(
+                f"max_lora_rank ({self.max_lora_rank}) must be one of "
+                f"{possible_max_ranks}.")
+        if self.lora_extra_vocab_size not in possible_lora_extra_vocab_size:
+            raise ValueError(
+                f"lora_extra_vocab_size ({self.lora_extra_vocab_size}) "
+                f"must be one of {possible_lora_extra_vocab_size}.")
+        if self.max_loras < 1:
+            raise ValueError(f"max_loras ({self.max_loras}) must be >= 1.")
+        if self.max_cpu_loras is None:
+            self.max_cpu_loras = self.max_loras
+        elif self.max_cpu_loras < self.max_loras:
+            raise ValueError(
+                f"max_cpu_loras ({self.max_cpu_loras}) must be >= "
+                f"max_num_seqs ({self.max_loras})")
+
+    def verify_with_model_config(self, model_config: ModelConfig):
+        if self.lora_dtype in (None, "auto"):
+            self.lora_dtype = model_config.dtype
+        elif isinstance(self.lora_dtype, str):
+            self.lora_dtype = getattr(torch, self.lora_dtype)
+        if model_config.quantization is not None:
+            raise ValueError(
+                "LoRA is not supported with quantized models yet.")
+
+    def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig):
+        if scheduler_config.max_num_batched_tokens > 65528:
+            raise ValueError(
+                "Due to limitations of the custom LoRA CUDA kernel, "
+                "max_num_batched_tokens must be <= 65528 when "
+                "LoRA is enabled.")
+
+
 _STR_DTYPE_TO_TORCH_DTYPE = {
     "half": torch.float16,
     "float16": torch.float16,
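
Assuming the LoRAConfig dataclass lands exactly as shown in the hunk above, its __post_init__ validation can be exercised directly. A hedged usage sketch (the import path follows the diff and only works with this commit installed; the values are illustrative):

import pytest
from vllm.config import LoRAConfig  # path as introduced by this diff

# Valid: rank must be one of (8, 16, 32, 64) per the check above.
cfg = LoRAConfig(max_lora_rank=16, max_loras=4)
assert cfg.max_cpu_loras == 4  # defaults to max_loras when left unset

# An unsupported rank is rejected in __post_init__.
with pytest.raises(ValueError):
    LoRAConfig(max_lora_rank=48, max_loras=4)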
vllm/core/block_manager.py

@@ -2,13 +2,10 @@
 import enum
 from typing import Dict, List, Optional, Set, Tuple
 
-from vllm.block import PhysicalTokenBlock
+from vllm.block import BlockTable, PhysicalTokenBlock
 from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
 from vllm.utils import Device
 
-# Mapping: logical block number -> physical block.
-BlockTable = List[PhysicalTokenBlock]
-
 
 class BlockAllocator:
     """Manages free physical token blocks for a device.

@@ -105,6 +102,10 @@ class BlockSpaceManager:
         # the same prompt. This may not be true for preempted sequences.
         seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
         num_required_blocks = len(seq.logical_token_blocks)
+
+        if seq_group.prefix is not None and seq_group.prefix.allocated:
+            num_required_blocks -= seq_group.prefix.get_num_blocks()
+
         if self.block_sliding_window is not None:
             num_required_blocks = min(num_required_blocks,
                                       self.block_sliding_window)

@@ -125,8 +126,21 @@ class BlockSpaceManager:
         seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
 
         # Allocate new physical token blocks that will store the prompt tokens.
+        num_prompt_blocks = len(seq.logical_token_blocks)
+
         block_table: BlockTable = []
-        for logical_idx in range(len(seq.logical_token_blocks)):
+        prefix_block_table: BlockTable = []
+        num_prefix_blocks = 0
+
+        prefix = seq_group.prefix
+        if prefix is not None and prefix.allocated:
+            # Prefix has already been allocated. Use the existing block table.
+            num_prompt_blocks -= prefix.get_num_blocks()
+            for block in prefix.block_table:
+                block.ref_count += seq_group.num_seqs()
+                block_table.append(block)
+
+        for logical_idx in range(num_prompt_blocks):
             if (self.block_sliding_window is not None
                     and logical_idx >= self.block_sliding_window):
                 block = block_table[logical_idx % self.block_sliding_window]

@@ -136,6 +150,15 @@ class BlockSpaceManager:
                 block.ref_count = seq_group.num_seqs()
             block_table.append(block)
 
+        if prefix is not None and not prefix.allocated:
+            # Allocate blocks for the prefix, we will compute the prefix's
+            # KV cache in this run.
+            num_prefix_blocks = prefix.get_num_blocks()
+            prefix_block_table = block_table[:num_prefix_blocks]
+            for block in prefix_block_table:
+                block.ref_count += 1
+            prefix.set_block_table(prefix_block_table)
+
         # Assign the block table for each sequence.
         for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
             self.block_tables[seq.seq_id] = block_table.copy()

@@ -210,10 +233,18 @@ class BlockSpaceManager:
     def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]:
         # CPU block -> GPU block.
+        if seq_group.prefix is not None:
+            # make sure to swap in the prefix first
+            assert seq_group.prefix.allocated and seq_group.prefix.computed
+
         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
         for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
             new_block_table: BlockTable = []
             block_table = self.block_tables[seq.seq_id]
+            if seq_group.prefix is not None:
+                for block in seq_group.prefix.block_table:
+                    new_block_table.append(block)
+                    block.ref_count += 1
 
             for cpu_block in block_table:
                 if cpu_block in mapping:

@@ -245,6 +276,12 @@ class BlockSpaceManager:
             block_table = self.block_tables[seq.seq_id]
 
             for gpu_block in block_table:
+                if (seq_group.prefix is not None
+                        and gpu_block in seq_group.prefix.block_table):
+                    # NOTE: We do not swap out the prefix blocks for now.
+                    self.gpu_allocator.free(gpu_block)
+                    continue
+
                 if gpu_block in mapping:
                     cpu_block = mapping[gpu_block]
                     cpu_block.ref_count += 1
vllm/core/policy.py    View file @ 51679bbd
-from typing import List
+from collections import deque
+from typing import Deque

 from vllm.sequence import SequenceGroup
...
@@ -15,13 +16,14 @@ class Policy:
     def sort_by_priority(
         self,
         now: float,
-        seq_groups: List[SequenceGroup],
-    ) -> List[SequenceGroup]:
-        return sorted(
-            seq_groups,
-            key=lambda seq_group: self.get_priority(now, seq_group),
-            reverse=True,
-        )
+        seq_groups: Deque[SequenceGroup],
+    ) -> Deque[SequenceGroup]:
+        return deque(
+            sorted(
+                seq_groups,
+                key=lambda seq_group: self.get_priority(now, seq_group),
+                reverse=True,
+            ))


 class FCFS(Policy):
...
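Because the scheduler queues become deques rather than lists, sort_by_priority now has to rebuild a deque around the sorted result. A small hedged sketch of the FCFS flavour of that change; FakeGroup is a stand-in carrying only the arrival_time field the policy actually reads.

# Illustrative sketch of the deque-based FCFS sort; not vLLM's Policy class.
from collections import deque
from dataclasses import dataclass
from typing import Deque


@dataclass
class FakeGroup:
    request_id: str
    arrival_time: float


def fcfs_priority(now: float, group: FakeGroup) -> float:
    # FCFS: the longer a group has waited, the higher its priority.
    return now - group.arrival_time


def sort_by_priority(now: float, groups: Deque[FakeGroup]) -> Deque[FakeGroup]:
    return deque(
        sorted(groups, key=lambda g: fcfs_priority(now, g), reverse=True))


queue = deque([FakeGroup("b", arrival_time=2.0), FakeGroup("a", arrival_time=1.0)])
assert [g.request_id
        for g in sort_by_priority(now=3.0, groups=queue)] == ["a", "b"]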
vllm/core/scheduler.py    View file @ 51679bbd
+from collections import deque
 import enum
 import time
-from typing import Dict, Iterable, List, Optional, Tuple, Union
+from typing import Deque, Dict, Iterable, List, Optional, Tuple, Union, Set

-from vllm.config import CacheConfig, SchedulerConfig
+from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
 from vllm.core.block_manager import AllocStatus, BlockSpaceManager
 from vllm.core.policy import PolicyFactory
+from vllm.lora.request import LoRARequest
 from vllm.logger import init_logger
 from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
                            SequenceGroupMetadata, SequenceStatus)
+from vllm.prefix import PrefixPool

 logger = init_logger(__name__)
...
@@ -29,7 +32,7 @@ class SchedulerOutputs:
     def __init__(
         self,
-        scheduled_seq_groups: List[SequenceGroup],
+        scheduled_seq_groups: Iterable[SequenceGroup],
         prompt_run: bool,
         num_batched_tokens: int,
         blocks_to_swap_in: Dict[int, int],
...
@@ -47,11 +50,25 @@ class SchedulerOutputs:
         assert not (blocks_to_swap_in and blocks_to_swap_out)
         self.ignored_seq_groups = ignored_seq_groups

+        self.num_loras = len(self.lora_requests)
+        if self.num_loras > 0:
+            self._sort_by_lora_ids()
+
     def is_empty(self) -> bool:
         # NOTE: We do not consider the ignored sequence groups.
         return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
                 and not self.blocks_to_swap_out and not self.blocks_to_copy)

+    def _sort_by_lora_ids(self) -> bool:
+        self.scheduled_seq_groups = sorted(
+            self.scheduled_seq_groups,
+            key=lambda g: (g.lora_request.lora_int_id
+                           if g.lora_request else 0, g.request_id))
+
+    @property
+    def lora_requests(self) -> Set[LoRARequest]:
+        return {g.lora_request for g in self.scheduled_seq_groups}
+

 class Scheduler:
...
@@ -59,9 +76,14 @@ class Scheduler:
         self,
         scheduler_config: SchedulerConfig,
         cache_config: CacheConfig,
+        lora_config: Optional[LoRAConfig],
     ) -> None:
         self.scheduler_config = scheduler_config
         self.cache_config = cache_config
+        # Note for LoRA scheduling: the current policy is extremely
+        # simple and NOT fair. It can lead to starvation of some
+        # LoRAs. This should be improved in the future.
+        self.lora_config = lora_config

         self.prompt_limit = min(self.scheduler_config.max_model_len,
                                 self.scheduler_config.max_num_batched_tokens)
...
@@ -75,38 +97,59 @@ class Scheduler:
             num_cpu_blocks=self.cache_config.num_cpu_blocks,
             sliding_window=self.cache_config.sliding_window)

-        # TODO(zhuohan): Use deque instead of list for better performance.
+        # Create the prefix pool to cache the prefixes.
+        self.prefix_pool = PrefixPool(self.cache_config.block_size)
+
         # Sequence groups in the WAITING state.
-        self.waiting: List[SequenceGroup] = []
+        self.waiting: Deque[SequenceGroup] = deque()
         # Sequence groups in the RUNNING state.
-        self.running: List[SequenceGroup] = []
+        self.running: Deque[SequenceGroup] = deque()
         # Sequence groups in the SWAPPED state.
-        self.swapped: List[SequenceGroup] = []
+        self.swapped: Deque[SequenceGroup] = deque()
+
+    @property
+    def lora_enabled(self) -> bool:
+        return bool(self.lora_config)

     def add_seq_group(self, seq_group: SequenceGroup) -> None:
         # Add sequence groups to the waiting queue.
         self.waiting.append(seq_group)

     def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
+        """Aborts a sequence group with the given ID.
+
+        Check if the sequence group with the given ID
+            is present in any of the state queue.
+        If present, remove the sequence group from the state queue.
+        Also, if any of the sequences in the sequence group is not finished,
+            free the sequence with status `FINISHED_ABORTED`.
+        Otherwise, do nothing.
+
+        Args:
+            request_id: The ID(s) of the sequence group to abort.
+        """
         if isinstance(request_id, str):
             request_id = (request_id, )
         request_ids = set(request_id)
         for state_queue in [self.waiting, self.running, self.swapped]:
-            # We need to reverse the list as we are removing elements
-            # from it as we iterate over it. If we don't do it,
-            # indices will get messed up and we will skip over elements.
-            for seq_group in reversed(state_queue):
+            aborted_groups: List[SequenceGroup] = []
+            for seq_group in state_queue:
+                if not request_ids:
+                    # Using 'break' here may add two extra iterations,
+                    # but is acceptable to reduce complexity.
+                    break
                 if seq_group.request_id in request_ids:
-                    # Remove the sequence group from the state queue.
-                    state_queue.remove(seq_group)
-                    for seq in seq_group.get_seqs():
-                        if seq.is_finished():
-                            continue
-                        seq.status = SequenceStatus.FINISHED_ABORTED
-                        self.free_seq(seq)
+                    # Appending aborted group into pending list.
+                    aborted_groups.append(seq_group)
                     request_ids.remove(seq_group.request_id)
-                    if not request_ids:
-                        return
+            for aborted_group in aborted_groups:
+                # Remove the sequence group from the state queue.
+                state_queue.remove(aborted_group)
+                for seq in aborted_group.get_seqs():
+                    if seq.is_finished():
+                        continue
+                    seq.status = SequenceStatus.FINISHED_ABORTED
+                    self.free_seq(seq)

     def has_unfinished_seqs(self) -> bool:
         return self.waiting or self.running or self.swapped
...
@@ -131,14 +174,17 @@ class Scheduler:
         # requests in the generation phase.
         num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
                             for seq_group in self.running)
+        curr_loras = set(
+            seq_group.lora_int_id
+            for seq_group in self.running) if self.lora_enabled else None
         seq_lens: List[int] = []

         # Optimization: We do not sort the waiting queue since the preempted
         # sequence groups are added to the front and the new sequence groups
         # are added to the back.
+        leftover_waiting_sequences = deque()
         while self.waiting:
             seq_group = self.waiting[0]
             waiting_seqs = seq_group.get_seqs(
                 status=SequenceStatus.WAITING)
             assert len(waiting_seqs) == 1, (
...
@@ -152,7 +198,7 @@ class Scheduler:
                 for seq in waiting_seqs:
                     seq.status = SequenceStatus.FINISHED_IGNORED
                 ignored_seq_groups.append(seq_group)
-                self.waiting.pop(0)
+                self.waiting.popleft()
                 continue

             # If the sequence group cannot be allocated, stop.
...
@@ -166,9 +212,20 @@ class Scheduler:
                 for seq in waiting_seqs:
                     seq.status = SequenceStatus.FINISHED_IGNORED
                 ignored_seq_groups.append(seq_group)
-                self.waiting.pop(0)
+                self.waiting.popleft()
                 continue

+            lora_int_id = 0
+            if self.lora_enabled:
+                lora_int_id = seq_group.lora_int_id
+                if lora_int_id > 0 and lora_int_id not in curr_loras and len(
+                        curr_loras) >= self.lora_config.max_loras:
+                    # We don't have a space for another LoRA, so
+                    # we ignore this request for now.
+                    leftover_waiting_sequences.appendleft(seq_group)
+                    self.waiting.popleft()
+                    continue
+
             # If the number of batched tokens exceeds the limit, stop.
             new_seq_lens = seq_lens + [num_prompt_tokens]
             num_batched_tokens = len(new_seq_lens) * max(new_seq_lens)
...
@@ -188,12 +245,16 @@ class Scheduler:
                 break
             seq_lens = new_seq_lens

-            seq_group = self.waiting.pop(0)
+            if lora_int_id > 0:
+                curr_loras.add(lora_int_id)
+            self.waiting.popleft()
             self._allocate(seq_group)
             self.running.append(seq_group)
             num_curr_seqs += num_new_seqs
             scheduled.append(seq_group)

+        self.waiting.extendleft(leftover_waiting_sequences)
+
         if scheduled or ignored_seq_groups:
             scheduler_outputs = SchedulerOutputs(
                 scheduled_seq_groups=scheduled,
...
@@ -214,14 +275,14 @@ class Scheduler:
         self.running = self.policy.sort_by_priority(now, self.running)

         # Reserve new token slots for the running sequence groups.
-        running: List[SequenceGroup] = []
+        running: Deque[SequenceGroup] = deque()
         preempted: List[SequenceGroup] = []
         while self.running:
-            seq_group = self.running.pop(0)
+            seq_group = self.running.popleft()
             while not self.block_manager.can_append_slot(seq_group):
                 if self.running:
                     # Preempt the lowest-priority sequence groups.
-                    victim_seq_group = self.running.pop(-1)
+                    victim_seq_group = self.running.pop()
                     self._preempt(victim_seq_group, blocks_to_swap_out)
                     preempted.append(victim_seq_group)
                 else:
...
@@ -241,9 +302,25 @@ class Scheduler:
         if not preempted:
             num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
                                 for seq_group in self.running)
+            curr_loras = set(
+                seq_group.lora_int_id
+                for seq_group in self.running) if self.lora_enabled else None
+
+            leftover_swapped = deque()

             while self.swapped:
                 seq_group = self.swapped[0]
+                lora_int_id = 0
+                if self.lora_enabled:
+                    lora_int_id = seq_group.lora_int_id
+                    if lora_int_id > 0 and lora_int_id not in curr_loras and len(
+                            curr_loras) >= self.lora_config.max_loras:
+                        # We don't have a space for another LoRA, so
+                        # we ignore this request for now.
+                        leftover_swapped.appendleft(seq_group)
+                        self.swapped.popleft()
+                        continue
+
                 # If the sequence group cannot be swapped in, stop.
                 if not self.block_manager.can_swap_in(seq_group):
                     break
...
@@ -255,12 +332,16 @@ class Scheduler:
                         self.scheduler_config.max_num_seqs):
                     break

-                seq_group = self.swapped.pop(0)
+                if lora_int_id > 0:
+                    curr_loras.add(lora_int_id)
+                self.swapped.popleft()
                 self._swap_in(seq_group, blocks_to_swap_in)
                 self._append_slot(seq_group, blocks_to_copy)
                 num_curr_seqs += num_new_seqs
                 self.running.append(seq_group)

+            self.swapped.extendleft(leftover_swapped)
+
         # Each sequence in the generation phase only takes one token slot.
         # Therefore, the number of batched tokens is equal to the number of
         # sequences in the RUNNING state.
...
@@ -301,6 +382,8 @@ class Scheduler:
                 seq_data=seq_data,
                 sampling_params=seq_group.sampling_params,
                 block_tables=block_tables,
+                lora_request=seq_group.lora_request,
+                prefix=seq_group.prefix,
             )
             seq_group_metadata_list.append(seq_group_metadata)
         return seq_group_metadata_list, scheduler_outputs
...
@@ -312,10 +395,8 @@ class Scheduler:
         self.block_manager.free(seq)

     def free_finished_seq_groups(self) -> None:
-        self.running = [
-            seq_group for seq_group in self.running
-            if not seq_group.is_finished()
-        ]
+        self.running = deque(seq_group for seq_group in self.running
+                             if not seq_group.is_finished())

     def _allocate(self, seq_group: SequenceGroup) -> None:
         self.block_manager.allocate(seq_group)
...
@@ -376,7 +457,7 @@ class Scheduler:
             self.block_manager.free(seq)
         # NOTE: For FCFS, we insert the preempted sequence group to the front
         # of the waiting queue.
-        self.waiting.insert(0, seq_group)
+        self.waiting.appendleft(seq_group)

     def _preempt_by_swap(
         self,
...
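The waiting- and swapped-queue loops above add a LoRA admission check: once max_loras distinct adapters are already in the batch, a request needing a new adapter is parked in a leftover deque and pushed back to the front of the queue afterwards. Below is a standalone sketch of that check with plain integers standing in for sequence groups; max_loras mirrors LoRAConfig.max_loras and the numbers are arbitrary.

# Illustrative sketch of the LoRA admission check, not vLLM's Scheduler.
from collections import deque
from typing import Deque, Set

max_loras = 2
curr_loras: Set[int] = {1}                  # adapters already in the running batch
waiting: Deque[int] = deque([1, 3, 4, 1])   # lora_int_id of each waiting request
leftover: Deque[int] = deque()
scheduled = []

while waiting:
    lora_int_id = waiting[0]
    if (lora_int_id > 0 and lora_int_id not in curr_loras
            and len(curr_loras) >= max_loras):
        # No slot for another adapter this step: park the request for later.
        leftover.appendleft(waiting.popleft())
        continue
    curr_loras.add(lora_int_id)
    scheduled.append(waiting.popleft())

waiting.extendleft(leftover)                # parked requests return to the front
print(scheduled, list(waiting))             # [1, 3, 1] [4]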
vllm/engine/arg_utils.py    View file @ 51679bbd
...
@@ -4,7 +4,7 @@ from dataclasses import dataclass
 from typing import Optional, Tuple

 from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
-                         SchedulerConfig)
+                         SchedulerConfig, LoRAConfig)


 @dataclass
...
@@ -17,6 +17,7 @@ class EngineArgs:
     download_dir: Optional[str] = None
     load_format: str = 'auto'
     dtype: str = 'auto'
+    kv_cache_dtype: str = 'auto'
     seed: int = 0
     max_model_len: Optional[int] = None
     worker_use_ray: bool = False
...
@@ -35,6 +36,13 @@ class EngineArgs:
     quantization: Optional[str] = None
     enforce_eager: bool = False
     max_context_len_to_capture: int = 8192
+    disable_custom_all_reduce: bool = False
+    enable_lora: bool = False
+    max_loras: int = 1
+    max_lora_rank: int = 16
+    lora_extra_vocab_size: int = 256
+    lora_dtype = 'auto'
+    max_cpu_loras: Optional[int] = None

     def __post_init__(self):
         if self.tokenizer is None:
...
@@ -115,6 +123,14 @@ class EngineArgs:
             'The "auto" option will use FP16 precision '
             'for FP32 and FP16 models, and BF16 precision '
             'for BF16 models.')
+        parser.add_argument(
+            '--kv-cache-dtype',
+            type=str,
+            choices=['auto', 'fp8_e5m2'],
+            default='auto',
+            help='Data type for kv cache storage. If "auto", will use model '
+            'data type. Note FP8 is not supported when cuda version is '
+            'lower than 11.8.')
         parser.add_argument('--max-model-len',
                             type=int,
                             default=None,
...
@@ -202,6 +218,43 @@ class EngineArgs:
                             help='maximum context length covered by CUDA '
                             'graphs. When a sequence has context length '
                             'larger than this, we fall back to eager mode.')
+        parser.add_argument('--disable-custom-all-reduce',
+                            action='store_true',
+                            default=EngineArgs.disable_custom_all_reduce,
+                            help='See ParallelConfig')
+        # LoRA related configs
+        parser.add_argument('--enable-lora',
+                            action='store_true',
+                            help='If True, enable handling of LoRA adapters.')
+        parser.add_argument('--max-loras',
+                            type=int,
+                            default=EngineArgs.max_loras,
+                            help='Max number of LoRAs in a single batch.')
+        parser.add_argument('--max-lora-rank',
+                            type=int,
+                            default=EngineArgs.max_lora_rank,
+                            help='Max LoRA rank.')
+        parser.add_argument(
+            '--lora-extra-vocab-size',
+            type=int,
+            default=EngineArgs.lora_extra_vocab_size,
+            help=('Maximum size of extra vocabulary that can be '
+                  'present in a LoRA adapter (added to the base '
+                  'model vocabulary).'))
+        parser.add_argument(
+            '--lora-dtype',
+            type=str,
+            default=EngineArgs.lora_dtype,
+            choices=['auto', 'float16', 'bfloat16', 'float32'],
+            help=('Data type for LoRA. If auto, will default to '
+                  'base model dtype.'))
+        parser.add_argument(
+            '--max-cpu-loras',
+            type=int,
+            default=EngineArgs.max_cpu_loras,
+            help=('Maximum number of LoRAs to store in CPU memory. '
+                  'Must be >= than max_num_seqs. '
+                  'Defaults to max_num_seqs.'))
         return parser

     @classmethod
...
@@ -214,7 +267,8 @@ class EngineArgs:
     def create_engine_configs(
         self,
-    ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
+    ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig,
+               Optional[LoRAConfig]]:
         model_config = ModelConfig(self.model, self.tokenizer,
                                    self.tokenizer_mode, self.trust_remote_code,
                                    self.download_dir, self.load_format,
...
@@ -224,17 +278,25 @@ class EngineArgs:
                                    self.max_context_len_to_capture)
         cache_config = CacheConfig(self.block_size,
                                    self.gpu_memory_utilization,
                                    self.swap_space,
+                                   self.kv_cache_dtype,
                                    model_config.get_sliding_window())
         parallel_config = ParallelConfig(self.pipeline_parallel_size,
                                          self.tensor_parallel_size,
                                          self.worker_use_ray,
-                                         self.max_parallel_loading_workers)
+                                         self.max_parallel_loading_workers,
+                                         self.disable_custom_all_reduce)
         scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
                                            self.max_num_seqs,
                                            model_config.max_model_len,
                                            self.max_paddings)
-        return model_config, cache_config, parallel_config, scheduler_config
+        lora_config = LoRAConfig(
+            max_lora_rank=self.max_lora_rank,
+            max_loras=self.max_loras,
+            lora_extra_vocab_size=self.lora_extra_vocab_size,
+            lora_dtype=self.lora_dtype,
+            max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
+            and self.max_cpu_loras > 0 else None) if self.enable_lora else None
+        return (model_config, cache_config, parallel_config, scheduler_config,
+                lora_config)


 @dataclass
...
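A hedged sketch of how the new knobs surface through EngineArgs, based only on the fields and return value added above; the model name is a placeholder, and running this end to end requires the model to be resolvable.

# Sketch under stated assumptions; "my-org/my-model" is a placeholder.
from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(
    model="my-org/my-model",       # placeholder
    kv_cache_dtype="fp8_e5m2",     # new: compressed KV-cache storage dtype
    enable_lora=True,              # new: turn on LoRA adapter handling
    max_loras=2,                   # at most 2 adapters per batch
    max_lora_rank=16,
)

# create_engine_configs() now returns a 5-tuple that includes the LoRA config
# (None whenever enable_lora is left off).
(model_config, cache_config, parallel_config, scheduler_config,
 lora_config) = engine_args.create_engine_configs()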
vllm/engine/async_llm_engine.py    View file @ 51679bbd
...
@@ -4,6 +4,7 @@ from functools import partial
 from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type,
                     Union, AsyncIterator)

+from vllm.lora.request import LoRARequest
 from vllm.config import ModelConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.llm_engine import LLMEngine
...
@@ -52,7 +53,7 @@ class AsyncStream:
         self._queue.put_nowait(item)

     def finish(self) -> None:
-        self._queue.put_nowait(StopIteration)
+        self._queue.put_nowait(StopAsyncIteration())
         self._finished = True

     @property
...
@@ -64,9 +65,7 @@ class AsyncStream:
     async def __anext__(self) -> RequestOutput:
         result = await self._queue.get()
-        if result is StopIteration:
-            raise StopAsyncIteration
-        elif isinstance(result, Exception):
+        if isinstance(result, Exception):
             raise result
         return result
...
@@ -203,6 +202,52 @@ class _AsyncLLMEngine(LLMEngine):
         return self._process_model_outputs(output, scheduler_outputs)

+    async def encode_request_async(
+        self,
+        request_id: str,  # pylint: disable=unused-argument
+        prompt: Optional[str],
+        prompt_token_ids: Optional[List[int]] = None,
+        lora_request: Optional[LoRARequest] = None,
+    ):
+        if prompt_token_ids is None:
+            assert prompt is not None
+            prompt_token_ids = await self.tokenizer.encode_async(
+                request_id=request_id,
+                prompt=prompt,
+                lora_request=lora_request)
+        return prompt_token_ids
+
+    async def add_request_async(
+        self,
+        request_id: str,
+        prompt: Optional[str],
+        sampling_params: SamplingParams,
+        prompt_token_ids: Optional[List[int]] = None,
+        arrival_time: Optional[float] = None,
+        lora_request: Optional[LoRARequest] = None,
+        prefix_pos: Optional[int] = None,
+    ) -> None:
+        if lora_request is not None and not self.lora_config:
+            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
+                             "not enabled!")
+        if arrival_time is None:
+            arrival_time = time.time()
+        prompt_token_ids = await self.encode_request_async(
+            request_id=request_id,
+            prompt=prompt,
+            prompt_token_ids=prompt_token_ids,
+            lora_request=lora_request)
+
+        return self.add_request(
+            request_id,
+            prompt=prompt,
+            prompt_token_ids=prompt_token_ids,
+            sampling_params=sampling_params,
+            arrival_time=arrival_time,
+            lora_request=lora_request,
+            prefix_pos=prefix_pos,
+        )
+
     async def _run_workers_async(
         self,
         method: str,
...
@@ -253,7 +298,8 @@ class AsyncLLMEngine:
         log_requests: Whether to log the requests.
         start_engine_loop: If True, the background task to run the engine
             will be automatically started in the generate call.
-        *args, *kwargs: Arguments for LLMEngine.
+        *args: Arguments for LLMEngine.
+        *kwargs: Arguments for LLMEngine.
     """

     _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine
...
@@ -331,7 +377,7 @@ class AsyncLLMEngine:
             if self.engine_use_ray:
                 await self.engine.add_request.remote(**new_request)
             else:
-                self.engine.add_request(**new_request)
+                await self.engine.add_request_async(**new_request)

         if finished_requests:
             await self._engine_abort(finished_requests)
...
@@ -370,6 +416,8 @@ class AsyncLLMEngine:
         sampling_params: SamplingParams,
         prompt_token_ids: Optional[List[int]] = None,
         arrival_time: Optional[float] = None,
+        lora_request: Optional[LoRARequest] = None,
+        prefix_pos: Optional[int] = None,
     ) -> AsyncStream:
         if self.log_requests:
             shortened_prompt = prompt
...
@@ -382,8 +430,10 @@ class AsyncLLMEngine:
                                                            max_log_len]
             logger.info(f"Received request {request_id}: "
                         f"prompt: {shortened_prompt!r}, "
+                        f"prefix_pos: {prefix_pos},"
                         f"sampling params: {sampling_params}, "
-                        f"prompt token ids: {shortened_token_ids}.")
+                        f"prompt token ids: {shortened_token_ids}, "
+                        f"lora_request: {lora_request}.")

         if not self.is_running:
             if self.start_engine_loop:
...
@@ -395,12 +445,30 @@ class AsyncLLMEngine:
                     "error that caused the background loop to stop "
                     "(AsyncEngineDeadError).")

+        if arrival_time is None:
+            arrival_time = time.time()
+
+        if self.engine_use_ray:
+            prompt_token_ids = await self.engine.encode_request_async.remote(
+                request_id=request_id,
+                prompt=prompt,
+                prompt_token_ids=prompt_token_ids,
+                lora_request=lora_request)
+        else:
+            prompt_token_ids = await self.engine.encode_request_async(
+                request_id=request_id,
+                prompt=prompt,
+                prompt_token_ids=prompt_token_ids,
+                lora_request=lora_request)
+
         stream = self._request_tracker.add_request(
             request_id,
             prompt=prompt,
             sampling_params=sampling_params,
             prompt_token_ids=prompt_token_ids,
-            arrival_time=arrival_time)
+            arrival_time=arrival_time,
+            lora_request=lora_request,
+            prefix_pos=prefix_pos)

         return stream
...
@@ -409,7 +477,9 @@ class AsyncLLMEngine:
         prompt: Optional[str],
         sampling_params: SamplingParams,
         request_id: str,
-        prompt_token_ids: Optional[List[int]] = None
+        prompt_token_ids: Optional[List[int]] = None,
+        lora_request: Optional[LoRARequest] = None,
+        prefix_pos: Optional[int] = None,
     ) -> AsyncIterator[RequestOutput]:
         """Generate outputs for a request.
...
@@ -424,21 +494,74 @@ class AsyncLLMEngine:
             request_id: The unique id of the request.
             prompt_token_ids: The token IDs of the prompt. If None, we
                 use the tokenizer to convert the prompts to token IDs.
+            lora_request: LoRA request to use for generation, if any.
+            prefix_pos: If not None, we use the given position as the prefix
+                position for each prompt. We will cache the prefix's KV
+                cache and reuse it for the next request with the same prefix.
+                This is an experimental feature, and may be replaced with
+                automatic prefix caching in the future.

         Yields:
             The output `RequestOutput` objects from the LLMEngine for the
             request.
+
+        Details:
+            - If the engine is not running, start the background loop,
+              which iteratively invokes
+              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
+              to process the waiting requests.
+            - Add the request to the engine's `RequestTracker`.
+              On the next background loop, this request will be sent to
+              the underlying engine.
+              Also, a corresponding `AsyncStream` will be created.
+            - Wait for the request outputs from `AsyncStream` and yield them.
+
+        Example:
+            >>> # Please refer to entrypoints/api_server.py for
+            >>> # the complete example.
+            >>>
+            >>> # initialize the engine and the example input
+            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
+            >>> example_input = {
+            >>>     "prompt": "What is LLM?",
+            >>>     "stream": False, # assume the non-streaming case
+            >>>     "temperature": 0.0,
+            >>>     "request_id": 0,
+            >>> }
+            >>>
+            >>> # start the generation
+            >>> results_generator = engine.generate(
+            >>>    example_input["prompt"],
+            >>>    SamplingParams(temperature=example_input["temperature"]),
+            >>>    example_input["request_id"])
+            >>>
+            >>> # get the results
+            >>> final_output = None
+            >>> async for request_output in results_generator:
+            >>>     if await request.is_disconnected():
+            >>>         # Abort the request if the client disconnects.
+            >>>         await engine.abort(request_id)
+            >>>         # Return or raise an error
+            >>>         ...
+            >>>     final_output = request_output
+            >>>
+            >>> # Process and return the final output
+            >>> ...
         """
         # Preprocess the request.
         # This should not be used for logging, as it is monotonic time.
         arrival_time = time.monotonic()

         try:
-            stream = await self.add_request(request_id,
-                                            prompt,
-                                            sampling_params,
-                                            prompt_token_ids=prompt_token_ids,
-                                            arrival_time=arrival_time)
+            stream = await self.add_request(
+                request_id,
+                prompt,
+                sampling_params,
+                prompt_token_ids=prompt_token_ids,
+                arrival_time=arrival_time,
+                lora_request=lora_request,
+                prefix_pos=prefix_pos,
+            )

             async for request_output in stream:
                 yield request_output
...
@@ -506,3 +629,9 @@ class AsyncLLMEngine:
                      max_log_len=engine_args.max_log_len,
                      start_engine_loop=start_engine_loop)
         return engine
+
+    async def do_log_stats(self) -> None:
+        if self.engine_use_ray:
+            await self.engine.do_log_stats.remote()
+        else:
+            self.engine.do_log_stats()
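The AsyncStream hunks above swap the StopIteration class sentinel for a StopAsyncIteration instance, so __anext__ only needs one isinstance(result, Exception) branch. Below is a minimal, self-contained sketch of that pattern; TinyStream is illustrative, not vLLM's AsyncStream.

# Illustrative sketch of ending an async iterator via a queued sentinel exception.
import asyncio


class TinyStream:
    def __init__(self) -> None:
        self._queue: asyncio.Queue = asyncio.Queue()

    def put(self, item) -> None:
        self._queue.put_nowait(item)

    def finish(self) -> None:
        # An *instance*, not the class, so the single isinstance check suffices.
        self._queue.put_nowait(StopAsyncIteration())

    def __aiter__(self):
        return self

    async def __anext__(self):
        result = await self._queue.get()
        if isinstance(result, Exception):
            raise result          # StopAsyncIteration cleanly ends `async for`
        return result


async def main() -> None:
    stream = TinyStream()
    stream.put("token-1")
    stream.put("token-2")
    stream.finish()
    async for item in stream:
        print(item)               # prints token-1, token-2, then stops


asyncio.run(main())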
vllm/engine/llm_engine.py
View file @
51679bbd
...
@@ -5,8 +5,9 @@ import time
...
@@ -5,8 +5,9 @@ import time
from
typing
import
(
TYPE_CHECKING
,
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Tuple
,
from
typing
import
(
TYPE_CHECKING
,
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Tuple
,
Union
)
Union
)
from
vllm.lora.request
import
LoRARequest
from
vllm.config
import
(
CacheConfig
,
ModelConfig
,
ParallelConfig
,
from
vllm.config
import
(
CacheConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
)
SchedulerConfig
,
LoRAConfig
)
from
vllm.core.scheduler
import
Scheduler
,
SchedulerOutputs
from
vllm.core.scheduler
import
Scheduler
,
SchedulerOutputs
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.metrics
import
record_metrics
from
vllm.engine.metrics
import
record_metrics
...
@@ -17,8 +18,8 @@ from vllm.sampling_params import SamplingParams
...
@@ -17,8 +18,8 @@ from vllm.sampling_params import SamplingParams
from
vllm.sequence
import
(
SamplerOutput
,
Sequence
,
SequenceGroup
,
from
vllm.sequence
import
(
SamplerOutput
,
Sequence
,
SequenceGroup
,
SequenceGroupOutput
,
SequenceOutput
,
SequenceStatus
)
SequenceGroupOutput
,
SequenceOutput
,
SequenceStatus
)
from
vllm.transformers_utils.tokenizer
import
(
detokenize_incrementally
,
from
vllm.transformers_utils.tokenizer
import
(
detokenize_incrementally
,
get_t
okenizer
)
T
okenizer
Group
)
from
vllm.utils
import
Counter
,
set_cuda_visible_devices
,
get_ip
,
get_open_port
from
vllm.utils
import
Counter
,
set_cuda_visible_devices
,
get_ip
,
get_open_port
,
get_distributed_init_method
if
ray
:
if
ray
:
from
ray.util.scheduling_strategies
import
PlacementGroupSchedulingStrategy
from
ray.util.scheduling_strategies
import
PlacementGroupSchedulingStrategy
...
@@ -64,6 +65,7 @@ class LLMEngine:
...
@@ -64,6 +65,7 @@ class LLMEngine:
cache_config
:
CacheConfig
,
cache_config
:
CacheConfig
,
parallel_config
:
ParallelConfig
,
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
,
scheduler_config
:
SchedulerConfig
,
lora_config
:
Optional
[
LoRAConfig
],
placement_group
:
Optional
[
"PlacementGroup"
],
placement_group
:
Optional
[
"PlacementGroup"
],
log_stats
:
bool
,
log_stats
:
bool
,
)
->
None
:
)
->
None
:
...
@@ -80,24 +82,22 @@ class LLMEngine:
...
@@ -80,24 +82,22 @@ class LLMEngine:
f
"download_dir=
{
model_config
.
download_dir
!
r
}
, "
f
"download_dir=
{
model_config
.
download_dir
!
r
}
, "
f
"load_format=
{
model_config
.
load_format
}
, "
f
"load_format=
{
model_config
.
load_format
}
, "
f
"tensor_parallel_size=
{
parallel_config
.
tensor_parallel_size
}
, "
f
"tensor_parallel_size=
{
parallel_config
.
tensor_parallel_size
}
, "
f
"disable_custom_all_reduce=
{
parallel_config
.
disable_custom_all_reduce
}
, "
f
"quantization=
{
model_config
.
quantization
}
, "
f
"quantization=
{
model_config
.
quantization
}
, "
f
"enforce_eager=
{
model_config
.
enforce_eager
}
, "
f
"enforce_eager=
{
model_config
.
enforce_eager
}
, "
f
"kv_cache_dtype=
{
cache_config
.
cache_dtype
}
, "
f
"seed=
{
model_config
.
seed
}
)"
)
f
"seed=
{
model_config
.
seed
}
)"
)
# TODO(woosuk): Print more configs in debug mode.
# TODO(woosuk): Print more configs in debug mode.
self
.
model_config
=
model_config
self
.
model_config
=
model_config
self
.
cache_config
=
cache_config
self
.
cache_config
=
cache_config
self
.
lora_config
=
lora_config
self
.
parallel_config
=
parallel_config
self
.
parallel_config
=
parallel_config
self
.
scheduler_config
=
scheduler_config
self
.
scheduler_config
=
scheduler_config
self
.
log_stats
=
log_stats
self
.
log_stats
=
log_stats
self
.
_verify_args
()
self
.
_verify_args
()
self
.
tokenizer
=
get_tokenizer
(
self
.
_init_tokenizer
()
model_config
.
tokenizer
,
tokenizer_mode
=
model_config
.
tokenizer_mode
,
trust_remote_code
=
model_config
.
trust_remote_code
,
tokenizer_revision
=
model_config
.
tokenizer_revision
,
revision
=
model_config
.
revision
)
self
.
seq_counter
=
Counter
()
self
.
seq_counter
=
Counter
()
# Create the parallel GPU workers.
# Create the parallel GPU workers.
...
@@ -114,7 +114,7 @@ class LLMEngine:
...
@@ -114,7 +114,7 @@ class LLMEngine:
self
.
_init_cache
()
self
.
_init_cache
()
# Create the scheduler.
# Create the scheduler.
self
.
scheduler
=
Scheduler
(
scheduler_config
,
cache_config
)
self
.
scheduler
=
Scheduler
(
scheduler_config
,
cache_config
,
lora_config
)
# Logging.
# Logging.
self
.
last_logging_time
=
0.0
self
.
last_logging_time
=
0.0
...
@@ -123,6 +123,9 @@ class LLMEngine:
...
@@ -123,6 +123,9 @@ class LLMEngine:
# List of (timestamp, num_tokens)
# List of (timestamp, num_tokens)
self
.
num_generation_tokens
:
List
[
Tuple
[
float
,
int
]]
=
[]
self
.
num_generation_tokens
:
List
[
Tuple
[
float
,
int
]]
=
[]
def
get_tokenizer_for_seq
(
self
,
sequence
:
Sequence
):
return
self
.
tokenizer
.
get_lora_tokenizer
(
sequence
.
lora_request
)
def
_init_workers
(
self
):
def
_init_workers
(
self
):
# Lazy import the Worker to avoid importing torch.cuda/xformers
# Lazy import the Worker to avoid importing torch.cuda/xformers
# before CUDA_VISIBLE_DEVICES is set in the Worker
# before CUDA_VISIBLE_DEVICES is set in the Worker
...
@@ -132,7 +135,8 @@ class LLMEngine:
...
@@ -132,7 +135,8 @@ class LLMEngine:
"Ray is required if parallel_config.world_size > 1."
)
"Ray is required if parallel_config.world_size > 1."
)
self
.
workers
:
List
[
Worker
]
=
[]
self
.
workers
:
List
[
Worker
]
=
[]
distributed_init_method
=
f
"tcp://
{
get_ip
()
}
:
{
get_open_port
()
}
"
distributed_init_method
=
get_distributed_init_method
(
get_ip
(),
get_open_port
())
self
.
driver_worker
=
Worker
(
self
.
driver_worker
=
Worker
(
self
.
model_config
,
self
.
model_config
,
self
.
parallel_config
,
self
.
parallel_config
,
...
@@ -140,11 +144,25 @@ class LLMEngine:
...
@@ -140,11 +144,25 @@ class LLMEngine:
local_rank
=
0
,
local_rank
=
0
,
rank
=
0
,
rank
=
0
,
distributed_init_method
=
distributed_init_method
,
distributed_init_method
=
distributed_init_method
,
lora_config
=
self
.
lora_config
,
kv_cache_dtype
=
self
.
cache_config
.
cache_dtype
,
is_driver_worker
=
True
,
is_driver_worker
=
True
,
)
)
self
.
_run_workers
(
"init_model"
)
self
.
_run_workers
(
"init_model"
)
self
.
_run_workers
(
"load_model"
)
self
.
_run_workers
(
"load_model"
)
def
_init_tokenizer
(
self
,
**
tokenizer_init_kwargs
):
init_kwargs
=
dict
(
enable_lora
=
bool
(
self
.
lora_config
),
max_num_seqs
=
self
.
scheduler_config
.
max_num_seqs
,
max_input_length
=
None
,
tokenizer_mode
=
self
.
model_config
.
tokenizer_mode
,
trust_remote_code
=
self
.
model_config
.
trust_remote_code
,
revision
=
self
.
model_config
.
tokenizer_revision
)
init_kwargs
.
update
(
tokenizer_init_kwargs
)
self
.
tokenizer
:
TokenizerGroup
=
TokenizerGroup
(
self
.
model_config
.
tokenizer
,
**
init_kwargs
)
def
_init_workers_ray
(
self
,
placement_group
:
"PlacementGroup"
,
def
_init_workers_ray
(
self
,
placement_group
:
"PlacementGroup"
,
**
ray_remote_kwargs
):
**
ray_remote_kwargs
):
if
self
.
parallel_config
.
tensor_parallel_size
==
1
:
if
self
.
parallel_config
.
tensor_parallel_size
==
1
:
...
@@ -207,7 +225,8 @@ class LLMEngine:
...
@@ -207,7 +225,8 @@ class LLMEngine:
for
worker
,
(
node_id
,
_
)
in
zip
(
self
.
workers
,
worker_node_and_gpu_ids
):
for
worker
,
(
node_id
,
_
)
in
zip
(
self
.
workers
,
worker_node_and_gpu_ids
):
worker
.
set_cuda_visible_devices
.
remote
(
node_gpus
[
node_id
])
worker
.
set_cuda_visible_devices
.
remote
(
node_gpus
[
node_id
])
distributed_init_method
=
f
"tcp://
{
driver_ip
}
:
{
get_open_port
()
}
"
distributed_init_method
=
get_distributed_init_method
(
driver_ip
,
get_open_port
())
# Lazy import the Worker to avoid importing torch.cuda/xformers
# Lazy import the Worker to avoid importing torch.cuda/xformers
# before CUDA_VISIBLE_DEVICES is set in the Worker
# before CUDA_VISIBLE_DEVICES is set in the Worker
...
@@ -231,6 +250,8 @@ class LLMEngine:
...
@@ -231,6 +250,8 @@ class LLMEngine:
local_rank
,
local_rank
,
rank
,
rank
,
distributed_init_method
,
distributed_init_method
,
lora_config
=
self
.
lora_config
,
kv_cache_dtype
=
self
.
cache_config
.
cache_dtype
,
))
))
driver_rank
=
0
driver_rank
=
0
...
@@ -242,6 +263,8 @@ class LLMEngine:
...
@@ -242,6 +263,8 @@ class LLMEngine:
driver_local_rank
,
driver_local_rank
,
driver_rank
,
driver_rank
,
distributed_init_method
,
distributed_init_method
,
lora_config
=
self
.
lora_config
,
kv_cache_dtype
=
self
.
cache_config
.
cache_dtype
,
is_driver_worker
=
True
,
is_driver_worker
=
True
,
)
)
...
@@ -255,15 +278,39 @@ class LLMEngine:
...
@@ -255,15 +278,39 @@ class LLMEngine:
def
_verify_args
(
self
)
->
None
:
def
_verify_args
(
self
)
->
None
:
self
.
model_config
.
verify_with_parallel_config
(
self
.
parallel_config
)
self
.
model_config
.
verify_with_parallel_config
(
self
.
parallel_config
)
self
.
cache_config
.
verify_with_parallel_config
(
self
.
parallel_config
)
self
.
cache_config
.
verify_with_parallel_config
(
self
.
parallel_config
)
if
self
.
lora_config
:
self
.
lora_config
.
verify_with_model_config
(
self
.
model_config
)
self
.
lora_config
.
verify_with_scheduler_config
(
self
.
scheduler_config
)
def
_init_cache
(
self
)
->
None
:
def
_init_cache
(
self
)
->
None
:
"""Profiles the memory usage and initializes the KV cache."""
"""Profiles the memory usage and initializes the KV cache.
The engine will first conduct a profiling of the existing memory usage.
Then, it calculate the maximum possible number of GPU and CPU blocks
that can be allocated with the remaining free memory.
More details can be found in the
:meth:`~vllm.worker.worker.Worker.profile_num_available_blocks` method
from class :class:`~vllm.worker.Worker`.
Afterwards, as there may be multiple workers,
we take the minimum number of blocks across all workers
to ensure this can be applied to all of them.
Finally, the engine will initialize the KV cache
with the calculated number of blocks.
.. tip::
You may limit the usage of GPU memory
by adjusting the `gpu_memory_utilization` parameters.
"""
# Get the maximum number of blocks that can be allocated on GPU and CPU.
# Get the maximum number of blocks that can be allocated on GPU and CPU.
num_blocks
=
self
.
_run_workers
(
num_blocks
=
self
.
_run_workers
(
"profile_num_available_blocks"
,
"profile_num_available_blocks"
,
block_size
=
self
.
cache_config
.
block_size
,
block_size
=
self
.
cache_config
.
block_size
,
gpu_memory_utilization
=
self
.
cache_config
.
gpu_memory_utilization
,
gpu_memory_utilization
=
self
.
cache_config
.
gpu_memory_utilization
,
cpu_swap_space
=
self
.
cache_config
.
swap_space_bytes
,
cpu_swap_space
=
self
.
cache_config
.
swap_space_bytes
,
cache_dtype
=
self
.
cache_config
.
cache_dtype
,
)
)
# Since we use a shared centralized controller, we take the minimum
# Since we use a shared centralized controller, we take the minimum
...
@@ -311,6 +358,20 @@ class LLMEngine:
...
@@ -311,6 +358,20 @@ class LLMEngine:
log_stats
=
not
engine_args
.
disable_log_stats
)
log_stats
=
not
engine_args
.
disable_log_stats
)
return
engine
return
engine
def
encode_request
(
self
,
request_id
:
str
,
# pylint: disable=unused-argument
prompt
:
Optional
[
str
],
prompt_token_ids
:
Optional
[
List
[
int
]]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
):
if
prompt_token_ids
is
None
:
assert
prompt
is
not
None
prompt_token_ids
=
self
.
tokenizer
.
encode
(
request_id
=
request_id
,
prompt
=
prompt
,
lora_request
=
lora_request
)
return
prompt_token_ids
def
add_request
(
def
add_request
(
self
,
self
,
request_id
:
str
,
request_id
:
str
,
...
@@ -318,6 +379,8 @@ class LLMEngine:
...
@@ -318,6 +379,8 @@ class LLMEngine:
sampling_params
:
SamplingParams
,
sampling_params
:
SamplingParams
,
prompt_token_ids
:
Optional
[
List
[
int
]]
=
None
,
prompt_token_ids
:
Optional
[
List
[
int
]]
=
None
,
arrival_time
:
Optional
[
float
]
=
None
,
arrival_time
:
Optional
[
float
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
prefix_pos
:
Optional
[
int
]
=
None
,
)
->
None
:
)
->
None
:
"""Add a request to the engine's request pool.
"""Add a request to the engine's request pool.
...
@@ -334,21 +397,61 @@ class LLMEngine:
...
@@ -334,21 +397,61 @@ class LLMEngine:
use the tokenizer to convert the prompts to token IDs.
use the tokenizer to convert the prompts to token IDs.
arrival_time: The arrival time of the request. If None, we use
arrival_time: The arrival time of the request. If None, we use
the current monotonic time.
the current monotonic time.
prefix_pos: If not None, we use the given position as the prefix
position for each prompt. We will cache the prefix's KV
cache and reuse it for the next request with the same prefix.
This is an experimental feature, and may be replaced with
automatic prefix caching in the future.
Details:
- Set arrival_time to the current time if it is None.
- Set prompt_token_ids to the encoded prompt if it is None.
- Create `best_of` number of :class:`~vllm.Sequence` objects.
- Create a :class:`~vllm.SequenceGroup` object
from the list of :class:`~vllm.Sequence`.
- Add the :class:`~vllm.SequenceGroup` object to the scheduler.
Example:
>>> # initialize engine
>>> engine = LLMEngine.from_engine_args(engine_args)
>>> # set request arguments
>>> example_prompt = "Who is the president of the United States?"
>>> sampling_params = SamplingParams(temperature=0.0)
>>> request_id = 0
>>>
>>> # add the request to the engine
>>> engine.add_request(
>>> str(request_id),
>>> example_prompt,
>>> SamplingParams(temperature=0.0))
>>> # continue the request processing
>>> ...
"""
"""
if
lora_request
is
not
None
and
not
self
.
lora_config
:
raise
ValueError
(
f
"Got lora_request
{
lora_request
}
but LoRA is "
"not enabled!"
)
if
arrival_time
is
None
:
if
arrival_time
is
None
:
arrival_time
=
time
.
monotonic
()
arrival_time
=
time
.
monotonic
()
if
prompt_token_ids
is
None
:
prompt_token_ids
=
self
.
encode_request
(
assert
prompt
is
not
None
request_id
=
request_id
,
prompt_token_ids
=
self
.
tokenizer
.
encode
(
prompt
)
prompt
=
prompt
,
prompt_token_ids
=
prompt_token_ids
,
lora_request
=
lora_request
)
# Create the sequences.
# Create the sequences.
block_size
=
self
.
cache_config
.
block_size
block_size
=
self
.
cache_config
.
block_size
seq_id
=
next
(
self
.
seq_counter
)
seq_id
=
next
(
self
.
seq_counter
)
seq
=
Sequence
(
seq_id
,
prompt
,
prompt_token_ids
,
block_size
)
seq
=
Sequence
(
seq_id
,
prompt
,
prompt_token_ids
,
block_size
,
lora_request
)
# Check whether the input specifies prefix
prefix
=
self
.
scheduler
.
prefix_pool
.
add_or_get_prefix
(
prompt_token_ids
[:
prefix_pos
],
lora_request
.
lora_int_id
if
lora_request
else
0
)
if
prefix_pos
is
not
None
else
None
# Create the sequence group.
# Create the sequence group.
seq_group
=
SequenceGroup
(
request_id
,
[
seq
],
sampling_params
,
seq_group
=
SequenceGroup
(
request_id
,
[
seq
],
sampling_params
,
arrival_time
)
arrival_time
,
lora_request
,
prefix
)
# Add the sequence group to the scheduler.
# Add the sequence group to the scheduler.
self
.
scheduler
.
add_seq_group
(
seq_group
)
self
.
scheduler
.
add_seq_group
(
seq_group
)
...
@@ -358,6 +461,17 @@ class LLMEngine:
...
@@ -358,6 +461,17 @@ class LLMEngine:
Args:
Args:
request_id: The ID(s) of the request to abort.
request_id: The ID(s) of the request to abort.
Details:
- Refer to the
:meth:`~vllm.core.scheduler.Scheduler.abort_seq_group`
from class :class:`~vllm.core.scheduler.Scheduler`.
Example:
>>> # initialize engine and add a request with request_id
>>> request_id = str(0)
>>> # abort the request
>>> engine.abort_request(request_id)
"""
"""
self
.
scheduler
.
abort_seq_group
(
request_id
)
self
.
scheduler
.
abort_seq_group
(
request_id
)
...
@@ -387,11 +501,13 @@ class LLMEngine:
...
@@ -387,11 +501,13 @@ class LLMEngine:
current_worst_score
=
(
current_worst_seq
.
get_beam_search_score
(
current_worst_score
=
(
current_worst_seq
.
get_beam_search_score
(
length_penalty
=
length_penalty
,
length_penalty
=
length_penalty
,
eos_token_id
=
self
.
tokenizer
.
eos_token_id
))
eos_token_id
=
self
.
get_tokenizer_for_seq
(
current_worst_seq
).
eos_token_id
))
if
early_stopping
is
False
:
if
early_stopping
is
False
:
highest_attainable_score
=
(
best_running_seq
.
get_beam_search_score
(
highest_attainable_score
=
(
best_running_seq
.
get_beam_search_score
(
length_penalty
=
length_penalty
,
length_penalty
=
length_penalty
,
eos_token_id
=
self
.
tokenizer
.
eos_token_id
))
eos_token_id
=
self
.
get_tokenizer_for_seq
(
best_running_seq
).
eos_token_id
))
else
:
else
:
assert
early_stopping
==
"never"
assert
early_stopping
==
"never"
if
length_penalty
>
0.0
:
if
length_penalty
>
0.0
:
...
@@ -405,7 +521,8 @@ class LLMEngine:
...
@@ -405,7 +521,8 @@ class LLMEngine:
highest_attainable_score
=
(
highest_attainable_score
=
(
best_running_seq
.
get_beam_search_score
(
best_running_seq
.
get_beam_search_score
(
length_penalty
=
length_penalty
,
length_penalty
=
length_penalty
,
eos_token_id
=
self
.
tokenizer
.
eos_token_id
,
eos_token_id
=
self
.
get_tokenizer_for_seq
(
best_running_seq
).
eos_token_id
,
seq_len
=
max_possible_length
))
seq_len
=
max_possible_length
))
else
:
else
:
# Otherwise, beam search will prefer shorter sequences. The
# Otherwise, beam search will prefer shorter sequences. The
...
@@ -414,7 +531,8 @@ class LLMEngine:
...
@@ -414,7 +531,8 @@ class LLMEngine:
highest_attainable_score
=
(
highest_attainable_score
=
(
best_running_seq
.
get_beam_search_score
(
best_running_seq
.
get_beam_search_score
(
length_penalty
=
length_penalty
,
length_penalty
=
length_penalty
,
eos_token_id
=
self
.
tokenizer
.
eos_token_id
))
eos_token_id
=
self
.
get_tokenizer_for_seq
(
best_running_seq
).
eos_token_id
))
return
current_worst_score
>=
highest_attainable_score
return
current_worst_score
>=
highest_attainable_score
def
_process_sequence_group_outputs
(
self
,
seq_group
:
SequenceGroup
,
def
_process_sequence_group_outputs
(
self
,
seq_group
:
SequenceGroup
,
...
@@ -505,7 +623,7 @@ class LLMEngine:
...
@@ -505,7 +623,7 @@ class LLMEngine:
# Sort the finished sequences by their scores.
# Sort the finished sequences by their scores.
all_finished_seqs
.
sort
(
key
=
lambda
x
:
x
[
0
].
get_beam_search_score
(
all_finished_seqs
.
sort
(
key
=
lambda
x
:
x
[
0
].
get_beam_search_score
(
length_penalty
=
length_penalty
,
length_penalty
=
length_penalty
,
eos_token_id
=
self
.
tokenizer
.
eos_token_id
),
eos_token_id
=
self
.
get_
tokenizer
_for_seq
(
x
[
0
])
.
eos_token_id
),
reverse
=
True
)
reverse
=
True
)
for
seq
,
parent
,
is_new
in
all_finished_seqs
[:
beam_width
]:
for
seq
,
parent
,
is_new
in
all_finished_seqs
[:
beam_width
]:
if
is_new
:
if
is_new
:
...
@@ -533,7 +651,7 @@ class LLMEngine:
...
@@ -533,7 +651,7 @@ class LLMEngine:
# Sort the running sequences by their scores.
# Sort the running sequences by their scores.
running_child_seqs
.
sort
(
key
=
lambda
x
:
x
[
0
].
get_beam_search_score
(
running_child_seqs
.
sort
(
key
=
lambda
x
:
x
[
0
].
get_beam_search_score
(
length_penalty
=
length_penalty
,
length_penalty
=
length_penalty
,
eos_token_id
=
self
.
tokenizer
.
eos_token_id
),
eos_token_id
=
self
.
get_
tokenizer
_for_seq
(
x
[
0
])
.
eos_token_id
),
reverse
=
True
)
reverse
=
True
)
# Check if we can stop the beam search.
# Check if we can stop the beam search.
...
@@ -601,10 +719,18 @@ class LLMEngine:
...
@@ -601,10 +719,18 @@ class LLMEngine:
# Create the outputs.
# Create the outputs.
request_outputs
:
List
[
RequestOutput
]
=
[]
request_outputs
:
List
[
RequestOutput
]
=
[]
for
seq_group
in
(
scheduled_seq_groups
+
for
seq_group
in
scheduled_seq_groups
:
scheduler_outputs
.
ignored_seq_groups
):
request_output
=
RequestOutput
.
from_seq_group
(
seq_group
)
request_output
=
RequestOutput
.
from_seq_group
(
seq_group
)
request_outputs
.
append
(
request_output
)
request_outputs
.
append
(
request_output
)
for
seq_group
in
scheduler_outputs
.
ignored_seq_groups
:
request_output
=
RequestOutput
.
from_seq_group
(
seq_group
)
request_outputs
.
append
(
request_output
)
# Update prefix state, now all the uncomputed prefixes are computed.
for
seq_group
in
scheduled_seq_groups
:
if
(
seq_group
.
prefix
is
not
None
and
seq_group
.
prefix
.
allocated
and
not
seq_group
.
prefix
.
computed
):
seq_group
.
prefix
.
computed
=
True
if
self
.
log_stats
:
if
self
.
log_stats
:
# Log the system stats.
# Log the system stats.
...
@@ -615,11 +741,53 @@ class LLMEngine:
...
@@ -615,11 +741,53 @@ class LLMEngine:
def
step
(
self
)
->
List
[
RequestOutput
]:
def
step
(
self
)
->
List
[
RequestOutput
]:
"""Performs one decoding iteration and returns newly generated results.
"""Performs one decoding iteration and returns newly generated results.
This function performs one decoding iteration of the engine. It first
.. figure:: https://i.imgur.com/sv2HssD.png
schedules the sequences to be executed in the next iteration and the
:alt: Overview of the step function
token blocks to be swapped in/out/copy. Then, it executes the model
:align: center
and updates the scheduler with the model outputs. Finally, it decodes
the sequences and returns the newly generated results.
Overview of the step function.
Details:
- Step 1: Schedules the sequences to be executed in the next
iteration and the token blocks to be swapped in/out/copy.
- Depending on the scheduling policy,
sequences may be `preempted/reordered`.
- A Sequence Group (SG) refer to a group of sequences
that are generated from the same prompt.
- Step 2: Calls the workers to execute the model.
- Step 3: Processes the model output. This mainly includes:
- Decodes the relevant outputs.
- Updates the scheduled sequence groups with model outputs
based on its `sampling parameters` (`use_beam_search` or not).
- Frees the finished sequence groups.
- Finally, it creates and returns the newly generated results.
Example:
>>> # Please see the example/ folder for more detailed examples.
>>>
>>> # initialize engine and request arguments
>>> engine = LLMEngine.from_engine_args(engine_args)
>>> example_inputs = [(0, "What is LLM?",
>>> SamplingParams(temperature=0.0))]
>>>
>>> # Start the engine with an event loop
>>> while True:
>>> if example_inputs:
>>> req_id, prompt, sampling_params = example_inputs.pop(0)
>>> engine.add_request(str(req_id), prompt, sampling_params)
>>>
>>> # continue the request processing
>>> request_outputs = engine.step()
>>> for request_output in request_outputs:
>>> if request_output.finished:
>>> # return or show the request output
>>>
>>> if not (engine.has_unfinished_requests() or example_inputs):
>>> break
"""
"""
seq_group_metadata_list
,
scheduler_outputs
=
self
.
scheduler
.
schedule
()
seq_group_metadata_list
,
scheduler_outputs
=
self
.
scheduler
.
schedule
()
...
@@ -641,6 +809,9 @@ class LLMEngine:
...
@@ -641,6 +809,9 @@ class LLMEngine:
return
self
.
_process_model_outputs
(
output
,
scheduler_outputs
)
return
self
.
_process_model_outputs
(
output
,
scheduler_outputs
)
def
do_log_stats
(
self
)
->
None
:
self
.
_log_system_stats
(
False
,
0
)
def
_log_system_stats
(
def
_log_system_stats
(
self
,
self
,
prompt_run
:
bool
,
prompt_run
:
bool
,
...
@@ -718,7 +889,7 @@ class LLMEngine:
...
@@ -718,7 +889,7 @@ class LLMEngine:
"""Decodes the new token for a sequence."""
"""Decodes the new token for a sequence."""
        (new_tokens, new_output_text, prefix_offset,
         read_offset) = detokenize_incrementally(
-            self.tokenizer,
+            self.get_tokenizer_for_seq(seq),
            all_input_ids=seq.get_token_ids(),
            prev_tokens=seq.tokens,
            prefix_offset=seq.prefix_offset,
...
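The one-line change in this hunk (and the matching change to the EOS check in the next hunk) swaps the engine-wide `self.tokenizer` for `self.get_tokenizer_for_seq(seq)`, so detokenization can pick a tokenizer per sequence, which is useful once LoRA adapters that ship their own tokenizers are in play (this commit adds LoRA plumbing elsewhere). A rough sketch of that per-sequence lookup pattern follows; `ToyTokenizer` and `TokenizerRegistry` are hypothetical names, not vLLM's real classes.

from typing import Dict, Optional


class ToyTokenizer:
    """Hypothetical tokenizer exposing only the field the engine code reads."""

    def __init__(self, eos_token_id: int) -> None:
        self.eos_token_id = eos_token_id


class TokenizerRegistry:
    """Hypothetical per-adapter tokenizer lookup (not vLLM's real class)."""

    def __init__(self, base: ToyTokenizer) -> None:
        self._base = base
        self._per_lora: Dict[int, ToyTokenizer] = {}

    def register_lora_tokenizer(self, lora_int_id: int, tok: ToyTokenizer) -> None:
        self._per_lora[lora_int_id] = tok

    def get_tokenizer_for_seq(self, lora_int_id: Optional[int]) -> ToyTokenizer:
        # Fall back to the base tokenizer when the sequence has no adapter,
        # or when the adapter did not bring its own tokenizer.
        if lora_int_id is None:
            return self._base
        return self._per_lora.get(lora_int_id, self._base)


if __name__ == "__main__":
    registry = TokenizerRegistry(base=ToyTokenizer(eos_token_id=2))
    registry.register_lora_tokenizer(7, ToyTokenizer(eos_token_id=32000))
    print(registry.get_tokenizer_for_seq(None).eos_token_id)  # 2
    print(registry.get_tokenizer_for_seq(7).eos_token_id)     # 32000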
@@ -760,11 +931,28 @@ class LLMEngine:
            return

        # Check if the sequence has generated the EOS token.
        if ((not sampling_params.ignore_eos)
                and seq.get_last_token_id()
-                == self.tokenizer.eos_token_id):
+                == self.get_tokenizer_for_seq(seq).eos_token_id):
            seq.status = SequenceStatus.FINISHED_STOPPED
            return
+    def add_lora(self, lora_request: LoRARequest) -> bool:
+        assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
+        return self._run_workers(
+            "add_lora",
+            lora_request=lora_request,
+        )
+
+    def remove_lora(self, lora_id: int) -> bool:
+        assert lora_id > 0, "lora_id must be greater than 0."
+        return self._run_workers(
+            "remove_lora",
+            lora_id=lora_id,
+        )
+
+    def list_loras(self) -> List[int]:
+        return self._run_workers("list_loras")
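Together with `get_tokenizer_for_seq`, these three added methods give the engine a runtime LoRA management surface that simply fans out to the workers via `_run_workers`. A hedged usage sketch follows; the `LoRARequestStub` fields, the adapter path handling, and the id-rotation logic are illustrative assumptions, not vLLM API details taken from this diff.

from dataclasses import dataclass


@dataclass
class LoRARequestStub:
    """Hypothetical stand-in for vLLM's LoRARequest (field names assumed)."""
    lora_name: str
    lora_int_id: int
    lora_local_path: str


def rotate_adapter(engine, old_id: int, adapter_dir: str) -> int:
    """Swap one adapter for another on a running engine.

    `engine` is assumed to expose the add_lora / remove_lora / list_loras
    methods added in this hunk; everything else here is illustrative.
    """
    if old_id in engine.list_loras():
        engine.remove_lora(old_id)
    new_id = old_id + 1  # ids must stay > 0, per the asserts above
    engine.add_lora(LoRARequestStub("my-adapter", new_id, adapter_dir))
    return new_id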
    def _run_workers(
        self,
        method: str,
...
vllm/engine/ray_utils.py
View file @
51679bbd
...
@@ -43,7 +43,7 @@ try:
except ImportError as e:
    logger.warning(f"Failed to import Ray with {e!r}. "
                   "For distributed inference, please install Ray with "
-                  "`pip install ray pandas pyarrow`.")
+                  "`pip install ray`.")
    ray = None
    RayWorkerVllm = None
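The `except ImportError` branch keeps Ray strictly optional: the warning names the exact install command, and the module-level names are nulled out so later code can check for them before use. The same pattern in isolation looks like the sketch below; the `require_ray` helper is illustrative and not part of vLLM.

import logging

logger = logging.getLogger(__name__)

try:
    import ray  # optional dependency, only needed for distributed inference
except ImportError as e:
    logger.warning("Failed to import Ray with %r. "
                   "For distributed inference, please install Ray with "
                   "`pip install ray`.", e)
    ray = None  # callers test `if ray is None:` before touching Ray APIs


def require_ray() -> None:
    """Illustrative helper: fail clearly at the point of use, not at import."""
    if ray is None:
        raise ImportError("Ray is required for this code path; "
                          "install it with `pip install ray`.")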
...
@@ -55,7 +55,7 @@ def initialize_cluster(
    parallel_config: ParallelConfig,
    engine_use_ray: bool = False,
    ray_address: Optional[str] = None,
-) -> Tuple[str, Optional["PlacementGroup"]]:
+) -> Optional["PlacementGroup"]:
"""Initialize the distributed cluster probably with Ray.
"""Initialize the distributed cluster probably with Ray.
Args:
Args:
...
@@ -65,10 +65,9 @@ def initialize_cluster(
...
@@ -65,10 +65,9 @@ def initialize_cluster(
the default Ray cluster address.
the default Ray cluster address.
Returns:
Returns:
A tuple of (`distributed_init_method`, `placement_group`). The
An optional `PlacementGroup`. It includes the specification
`distributed_init_method` is the address for initializing the
of the resources for each distributed worker. None if Ray is
distributed backend. `placement_group` includes the specification
not used.
of the resources for each distributed worker.
"""
"""
if
parallel_config
.
worker_use_ray
or
engine_use_ray
:
if
parallel_config
.
worker_use_ray
or
engine_use_ray
:
if
ray
is
None
:
if
ray
is
None
:
...
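The signature change narrows `initialize_cluster` from a `(distributed_init_method, placement_group)` tuple to a single optional placement group, so call sites stop unpacking a pair and instead handle `None` for the single-node case. The before/after sketch below uses stub functions rather than the real vLLM call sites; the address string and the return values are placeholders.

from typing import Any, Optional, Tuple


def initialize_cluster_old(use_ray: bool) -> Tuple[str, Optional[Any]]:
    # Old contract: an init address for the distributed backend plus the
    # (possibly absent) Ray placement group.
    return "tcp://127.0.0.1:8000", ("placement-group" if use_ray else None)


def initialize_cluster_new(use_ray: bool) -> Optional[Any]:
    # New contract: only the placement group; None when Ray is not used.
    return "placement-group" if use_ray else None


# Old call site unpacked a tuple:
distributed_init_method, placement_group = initialize_cluster_old(use_ray=False)

# New call site handles a single optional value:
placement_group = initialize_cluster_new(use_ray=False)
if placement_group is None:
    pass  # single-node path: no Ray placement group to schedule workers onto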