norm / vllm · Commits

Support per-request seed (#2514)

Unverified commit 7d2dcce1, authored by Nick Hill on Feb 21, 2024, committed by GitHub on Feb 21, 2024.
Parent: dc903e70

Showing 10 changed files with 289 additions and 84 deletions (+289 −84).
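For orientation before the per-file changes, here is a minimal usage sketch of the feature this commit adds. It is a hedged example, not part of the commit: it assumes the public `LLM` entry point and uses the `facebook/opt-125m` model that the new seeded-sampling test below also uses; the prompt and seed values are illustrative.

import torch  # noqa: F401 (vLLM requires a CUDA-capable torch install)
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")

# A per-request seed makes random sampling reproducible for that request;
# omitting `seed` keeps the previous (non-deterministic) behaviour.
params = SamplingParams(temperature=1.0, seed=42, max_tokens=8)
first = llm.generate(["The capital of France is"], params)
second = llm.generate(["The capital of France is"], params)
assert first[0].outputs[0].token_ids == second[0].outputs[0].token_ids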
tests/samplers/test_sampler.py            +147  −75
tests/samplers/test_seeded_generate.py     +82   −0
vllm/core/scheduler.py                      +1   −0
vllm/engine/arg_utils.py                    +0   −1
vllm/entrypoints/openai/protocol.py         +4   −0
vllm/model_executor/layers/sampler.py      +22   −7
vllm/model_executor/sampling_metadata.py    +3   −0
vllm/sampling_params.py                     +8   −1
vllm/sequence.py                           +12   −0
vllm/worker/model_runner.py                +10   −0
tests/samplers/test_sampler.py  (+147 −75)

 import random
-from typing import Tuple
+from typing import Tuple, List
 from unittest.mock import patch

 import pytest
 import torch
 from transformers import GenerationConfig, GenerationMixin
+from typing import Optional

 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.utils import set_random_seed

@@ -46,15 +47,13 @@ CUDA_DEVICES = [
 ]


-@pytest.mark.parametrize("seed", RANDOM_SEEDS)
-@pytest.mark.parametrize("device", CUDA_DEVICES)
-def test_sampler_all_greedy(seed: int, device: str):
-    set_random_seed(seed)
-    torch.set_default_device(device)
-    batch_size = random.randint(1, 256)
-    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
-        batch_size)
+def _do_sample(
+    batch_size: int,
+    input_tensor: torch.Tensor,
+    sampler: MockLogitsSampler,
+    model_runner: ModelRunner,
+    sampling_params: SamplingParams,
+):
     seq_group_metadata_list = []
     prompt_lens = []
     for i in range(batch_size):
@@ -63,7 +62,7 @@ def test_sampler_all_greedy(seed: int, device: str):
                 request_id=f"test_{i}",
                 is_prompt=True,
                 seq_data={0: SequenceData([1, 2, 3])},
-                sampling_params=SamplingParams(temperature=0, ),
+                sampling_params=sampling_params,
                 block_tables={0: [1]},
             ))
         prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
@@ -71,9 +70,23 @@ def test_sampler_all_greedy(seed: int, device: str):
     sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                      prompt_lens,
                                                      subquery_lens=prompt_lens)
-    sampler_output = sampler(embedding=None,
-                             hidden_states=input_tensor,
-                             sampling_metadata=sampling_metadata)
+    return sampler(embedding=None,
+                   hidden_states=input_tensor,
+                   sampling_metadata=sampling_metadata)
+
+
+@pytest.mark.parametrize("seed", RANDOM_SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+def test_sampler_all_greedy(seed: int, device: str):
+    set_random_seed(seed)
+    torch.set_default_device(device)
+    batch_size = random.randint(1, 256)
+    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
+        batch_size)
+
+    sampling_params = SamplingParams(temperature=0)
+    sampler_output = _do_sample(batch_size, input_tensor, sampler,
+                                model_runner, sampling_params)

     expected = torch.argmax(fake_logits, dim=-1)
     for i, sequence_output in enumerate(sampler_output):
         for nth_output in sequence_output.samples:
@@ -94,28 +107,40 @@ def test_sampler_all_random(seed: int, device: str):
     for i in range(batch_size):
         fake_logits[i, i] = 1e2

-    seq_group_metadata_list = []
-    prompt_lens = []
-    for i in range(batch_size):
-        seq_group_metadata_list.append(
-            SequenceGroupMetadata(
-                request_id=f"test_{i}",
-                is_prompt=True,
-                seq_data={0: SequenceData([1, 2, 3])},
-                sampling_params=SamplingParams(
-                    temperature=1.0,
-                    n=random.randint(1, 10),
-                ),
-                block_tables={0: [1]},
-            ))
-        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
-
-    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
-                                                     prompt_lens,
-                                                     subquery_lens=prompt_lens)
-    sampler_output = sampler(embedding=None,
-                             hidden_states=input_tensor,
-                             sampling_metadata=sampling_metadata)
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        n=random.randint(1, 10),
+    )
+    sampler_output = _do_sample(batch_size, input_tensor, sampler,
+                                model_runner, sampling_params)
+
+    for i, sequence_output in enumerate(sampler_output):
+        for nth_output in sequence_output.samples:
+            assert nth_output.output_token == i
+
+    del model_runner
+
+
+@pytest.mark.parametrize("seed", RANDOM_SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+def test_sampler_all_random_seed(seed: int, device: str):
+    set_random_seed(seed)
+    torch.set_default_device(device)
+    batch_size = random.randint(1, 256)
+    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
+        batch_size)
+
+    for i in range(batch_size):
+        fake_logits[i, i] = 1e2
+
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        n=random.randint(1, 10),
+        seed=random.randint(0, 10000),
+    )
+    sampler_output = _do_sample(batch_size, input_tensor, sampler,
+                                model_runner, sampling_params)

     for i, sequence_output in enumerate(sampler_output):
         for nth_output in sequence_output.samples:
             assert nth_output.output_token == i
@@ -123,6 +148,31 @@ def test_sampler_all_random(seed: int, device: str):
     del model_runner


+@pytest.mark.parametrize("seed", RANDOM_SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+def test_sampler_all_random_seed_deterministic(seed: int, device: str):
+    set_random_seed(seed)
+    torch.set_default_device(device)
+    batch_size = random.randint(1, 256)
+    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
+        batch_size)
+
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        n=random.randint(1, 10),
+        seed=random.randint(0, 10000),
+    )
+    first_sampler_output = _do_sample(batch_size, input_tensor, sampler,
+                                      model_runner, sampling_params)
+
+    second_sampler_output = _do_sample(batch_size, input_tensor, sampler,
+                                       model_runner, sampling_params)
+
+    assert first_sampler_output == second_sampler_output
+
+    del model_runner
+
+
 @pytest.mark.parametrize("seed", RANDOM_SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 def test_sampler_all_beam(seed: int, device: str):
@@ -131,29 +181,13 @@ def test_sampler_all_beam(seed: int, device: str):
     batch_size = random.randint(1, 256)
     input_tensor, _, sampler, model_runner = _prepare_test(batch_size)

-    seq_group_metadata_list = []
-    prompt_lens = []
-    for i in range(batch_size):
-        seq_group_metadata_list.append(
-            SequenceGroupMetadata(
-                request_id=f"test_{i}",
-                is_prompt=True,
-                seq_data={0: SequenceData([1, 2, 3])},
-                sampling_params=SamplingParams(
-                    temperature=0,
-                    best_of=2,
-                    use_beam_search=True,
-                ),
-                block_tables={0: [1]},
-            ))
-        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
-
-    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
-                                                     prompt_lens,
-                                                     subquery_lens=prompt_lens)
-    sampler(embedding=None,
-            hidden_states=input_tensor,
-            sampling_metadata=sampling_metadata)
+    sampling_params = SamplingParams(
+        temperature=0,
+        best_of=2,
+        use_beam_search=True,
+    )
+    _do_sample(batch_size, input_tensor, sampler, model_runner,
+               sampling_params)
     # no assertion here as I am not sure how to determine whether
     # the outputs are expected - in other words, this just tests
     # whether there are no exceptions in the sampler
@@ -171,14 +205,15 @@ def test_sampler_mixed(seed: int, device: str):
         batch_size)

     seq_group_metadata_list = []
-    expected_tokens = []
+    expected_tokens: List[Optional[List[int]]] = []
     prompt_lens = []
     for i in range(batch_size):
-        n = 1
-        sampling_type = random.randint(0, 2)
+        expected: Optional[List[int]] = None
+        sampling_type = random.randint(0, 3)
         if sampling_type == 0:
             sampling_params = SamplingParams(temperature=0)
-        elif sampling_type == 1:
+            expected = [torch.argmax(fake_logits[i], dim=-1).item()]
+        elif sampling_type in (1, 2):
             n = random.randint(1, 10)
             sampling_params = SamplingParams(
                 temperature=random.random() + 0.1,
@@ -187,13 +222,17 @@ def test_sampler_mixed(seed: int, device: str):
                 n=n,
                 presence_penalty=random.randint(0, 1),
             )
+            if sampling_type == 2:
+                sampling_params.seed = random.randint(0, 10000)
+            else:
+                for idx in range(n):
+                    fake_logits[i, i + idx] = 1e2
+                expected = list(range(i, i + n))
         else:
             sampling_params = SamplingParams(temperature=0,
                                              use_beam_search=True,
                                              best_of=2)
-        for idx in range(n):
-            fake_logits[i, i + idx] = 1e2
-            expected_tokens.append(i + idx)
+        expected_tokens.append(expected)
         seq_group_metadata_list.append(
             SequenceGroupMetadata(
                 request_id=f"test_{i}",
@@ -204,17 +243,50 @@ def test_sampler_mixed(seed: int, device: str):
             ))
         prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

-    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
-                                                     prompt_lens,
-                                                     subquery_lens=prompt_lens)
-    sampler_output = sampler(embedding=None,
-                             hidden_states=input_tensor,
-                             sampling_metadata=sampling_metadata)
-    for i, sequence_output in enumerate(sampler_output):
-        if seq_group_metadata_list[i].sampling_params.use_beam_search:
-            continue
-        for nth_output in sequence_output.samples:
-            assert nth_output.output_token in expected_tokens
+    def test_sampling(model_runner: ModelRunner):
+        sampling_metadata = model_runner._prepare_sample(
+            seq_group_metadata_list, prompt_lens, subquery_lens=prompt_lens)
+        sampler_output = sampler(embedding=None,
+                                 hidden_states=input_tensor,
+                                 sampling_metadata=sampling_metadata)
+
+        for i, (sequence_output, metadata) in enumerate(
+                zip(sampler_output, seq_group_metadata_list)):
+            if metadata.sampling_params.use_beam_search:
+                continue
+
+            if metadata.sampling_params.seed is not None \
+                    and expected_tokens[i] is None:
+                # Record seeded random result to compare with results of
+                # second invocation
+                expected_tokens[i] = [
+                    nth_output.output_token
+                    for nth_output in sequence_output.samples
+                ]
+                continue
+
+            for n, nth_output in enumerate(sequence_output.samples):
+                if metadata.sampling_params.temperature == 0 or \
+                        metadata.sampling_params.seed is not None:
+                    # Ensure exact matches for greedy or random with seed
+                    assert nth_output.output_token == expected_tokens[i][n]
+                else:
+                    # For non-seeded random check that one of the high-logit
+                    # tokens were chosen
+                    assert nth_output.output_token in expected_tokens[i]
+
+    # Test batch
+    test_sampling(model_runner)
+
+    # Shuffle the batch and resample
+    target_index = list(range(batch_size))
+    for list_to_shuffle in (target_index, seq_group_metadata_list,
+                            expected_tokens, prompt_lens):
+        random.Random(seed).shuffle(list_to_shuffle)
+    target_index = torch.tensor(target_index)
+    input_tensor.data = input_tensor.index_select(0, target_index)
+    fake_logits.data = fake_logits.index_select(0, target_index)
+
+    # This time, results of seeded random samples will be compared with
+    # the corresponding sample in the pre-shuffled batch
+    test_sampling(model_runner)

     del model_runner
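Design note on the reworked `test_sampler_mixed` above: the first call to `test_sampling` records the tokens produced by each seeded request, the batch (inputs, fake logits, metadata and expectations) is then shuffled with the same permutation and sampled again, and seeded requests are asserted to reproduce exactly the same tokens regardless of their new position in the batch. These tests can be run on their own, e.g. `pytest tests/samplers/test_sampler.py -k seed`, assuming a CUDA device since they are parametrized over `CUDA_DEVICES`.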
tests/samplers/test_seeded_generate.py  (new file, mode 100644, +82)

"""Verify that seeded random sampling is deterministic.

Run `pytest tests/samplers/test_seeded_generate.py --forked`.
"""
import copy
import random
from itertools import combinations

import pytest

from vllm.model_executor.utils import set_random_seed
from vllm import SamplingParams

MODEL = "facebook/opt-125m"
RANDOM_SEEDS = list(range(5))


@pytest.fixture
def vllm_model(vllm_runner):
    vllm_model = vllm_runner(MODEL, dtype="half")
    yield vllm_model
    del vllm_model


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_random_sample_with_seed(
    vllm_model,
    example_prompts,
    seed: int,
) -> None:
    set_random_seed(seed)

    sampling_params = SamplingParams(
        # Parameters to ensure sufficient randomness
        temperature=2.0,
        top_p=min(random.random() + 0.3, 1),
        top_k=random.randint(5, 20),
        n=random.randint(1, 10),
        presence_penalty=random.randint(0, 1),
        max_tokens=8,
        ignore_eos=True,
    )

    sampling_params_seed_1 = copy.deepcopy(sampling_params)
    sampling_params_seed_1.seed = 100
    sampling_params_seed_2 = copy.deepcopy(sampling_params)
    sampling_params_seed_2.seed = 200

    llm = vllm_model.model

    for prompt in example_prompts:
        for params in (
                sampling_params,
                sampling_params_seed_1,
                sampling_params_seed_2,
                sampling_params,
                sampling_params_seed_1,
                sampling_params_seed_2,
        ):
            llm._add_request(
                prompt=prompt,
                prompt_token_ids=None,
                sampling_params=params,
            )

    results = llm._run_engine(use_tqdm=False)
    all_outputs = [[out.token_ids for out in output.outputs]
                   for output in results]

    for i in range(0, len(example_prompts), 6):
        outputs = all_outputs[i:i + 6]

        # verify all non-seeded requests differ
        for output_a, output_b in combinations(
            (outputs[0], outputs[1], outputs[2], outputs[3]),
                2,
        ):
            assert output_a != output_b

        # verify requests with the same seed match
        assert outputs[1] == outputs[4]
        assert outputs[2] == outputs[5]
vllm/core/scheduler.py  (+1 −0)

@@ -387,6 +387,7 @@ class Scheduler:
                 block_tables=block_tables,
                 lora_request=seq_group.lora_request,
                 prefix=seq_group.prefix,
+                state=seq_group.state,
             )
             seq_group_metadata_list.append(seq_group_metadata)
         return seq_group_metadata_list, scheduler_outputs
vllm/engine/arg_utils.py  (+0 −1)

@@ -173,7 +173,6 @@ class EngineArgs:
                             default=EngineArgs.block_size,
                             choices=[8, 16, 32],
                             help='token block size')
-        # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
         parser.add_argument('--seed',
                             type=int,
                             default=EngineArgs.seed,
vllm/entrypoints/openai/protocol.py  (+4 −0)

@@ -60,6 +60,7 @@ class ChatCompletionRequest(BaseModel):
     top_p: Optional[float] = 1.0
     n: Optional[int] = 1
     max_tokens: Optional[int] = None
+    seed: Optional[int] = None
     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
     stream: Optional[bool] = False
     presence_penalty: Optional[float] = 0.0
@@ -90,6 +91,7 @@ class ChatCompletionRequest(BaseModel):
             temperature=self.temperature,
             top_p=self.top_p,
             min_p=self.min_p,
+            seed=self.seed,
             stop=self.stop,
             stop_token_ids=self.stop_token_ids,
             max_tokens=self.max_tokens,
@@ -117,6 +119,7 @@ class CompletionRequest(BaseModel):
     logprobs: Optional[int] = None
     echo: Optional[bool] = False
     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
+    seed: Optional[int] = None
     presence_penalty: Optional[float] = 0.0
     frequency_penalty: Optional[float] = 0.0
     best_of: Optional[int] = None
@@ -147,6 +150,7 @@ class CompletionRequest(BaseModel):
             top_p=self.top_p,
             top_k=self.top_k,
             min_p=self.min_p,
+            seed=self.seed,
             stop=self.stop,
             stop_token_ids=self.stop_token_ids,
             ignore_eos=self.ignore_eos,
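A hedged example of exercising the new field through the OpenAI-compatible server (not part of the commit): it assumes a vLLM API server is already running locally on port 8000 and serving `facebook/opt-125m`; only the `seed` field comes from the diff above, while the host, port, prompt and model name are illustrative.

import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "facebook/opt-125m",
        "prompt": "The capital of France is",
        "max_tokens": 8,
        "temperature": 1.0,
        # Per-request seed added by this commit; repeating the request with
        # the same seed should return the same completion.
        "seed": 42,
    },
)
print(resp.json()["choices"][0]["text"])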
vllm/model_executor/layers/sampler.py  (+22 −7)

@@ -342,7 +342,9 @@ def _beam_search_sample(
 def _multinomial(
     probs: torch.Tensor,
     num_samples: int,
-):
+    seq_groups: Optional[List[Tuple[List[int], SamplingParams]]] = None,
+    generators: Optional[List[torch.Generator]] = None,
+) -> torch.Tensor:
     if num_samples > 1:
         # This is equivalent to torch.repeat_interleaved (which also
         # forces a GPU<->CPU sync).
@@ -352,7 +354,15 @@ def _multinomial(
         probs = probs[:, None, :].expand(probs.shape[0], num_samples,
                                          probs.shape[1]).contiguous().view(
                                              -1, probs.shape[1])
-    q = torch.empty_like(probs).exponential_(1)
+    q = torch.empty_like(probs)
+    if seq_groups is None:
+        q.exponential_()
+    else:
+        sample_idx = 0
+        for (seq_ids, _), generator in zip(seq_groups, generators):
+            next_sample_idx = sample_idx + len(seq_ids) * num_samples
+            q[sample_idx:next_sample_idx].exponential_(generator=generator)
+            sample_idx = next_sample_idx
     return probs.div_(q).argmax(dim=1).view(-1, num_samples)
@@ -370,6 +380,7 @@ def _sample(
     sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
     sample_metadata = {}
+    multinomial_samples = {}
     # Counterintiutively, having two loops here is actually faster.
     # The first loop can run without waiting on GPU<->CPU sync.
@@ -385,14 +396,18 @@ def _sample(
                                                 is_prompts, sample_indices)
         if sampling_type == SamplingType.GREEDY:
             greedy_samples = torch.argmax(logprobs[sample_indices], dim=-1)
-        elif sampling_type == SamplingType.RANDOM:
+        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
             max_best_of = 1
             for seq_group, is_prompt in zip(seq_groups, is_prompts):
                 if is_prompt:
                     _, sampling_params = seq_group
                     max_best_of = max(max_best_of, sampling_params.best_of)
-            multinomial_samples = _multinomial(probs[sample_indices],
-                                               max_best_of)
+            seeded_args = {} if sampling_type == SamplingType.RANDOM else {
+                "seq_groups": seq_groups,
+                "generators": sampling_metadata.generators,
+            }
+            multinomial_samples[sampling_type] = _multinomial(
+                probs[sample_indices], max_best_of, **seeded_args)
         elif sampling_type == SamplingType.BEAM:
             beam_search_logprobs = logprobs[sample_indices]
         else:
@@ -407,9 +422,9 @@ def _sample(
             sampling_type]
         if sampling_type == SamplingType.GREEDY:
             sample_results = _greedy_sample(seq_groups, greedy_samples)
-        elif sampling_type == SamplingType.RANDOM:
+        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
             sample_results = _random_sample(seq_groups, is_prompts,
-                                            multinomial_samples)
+                                            multinomial_samples[sampling_type])
         elif sampling_type == SamplingType.BEAM:
             sample_results = _beam_search_sample(seq_groups, is_prompts,
                                                  sampling_metadata.seq_data,
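The `_multinomial` change keeps the existing exponential-race sampling trick (dividing the probabilities by i.i.d. Exponential(1) noise and taking the argmax draws from the categorical distribution, analogous to Gumbel-max) and only threads per-request `torch.Generator`s into the noise generation, so that rows belonging to seeded requests become reproducible. A standalone sketch of the idea, independent of vLLM internals (the function and variable names here are illustrative):

import torch

def sample_rows(probs: torch.Tensor, generators) -> torch.Tensor:
    # One Exponential(1) draw per probability; rows with a generator are
    # filled deterministically, rows with None use the global RNG.
    q = torch.empty_like(probs)
    for row, gen in enumerate(generators):
        if gen is None:
            q[row].exponential_()
        else:
            q[row].exponential_(generator=gen)
    # argmax of p/q samples index i with probability p_i (exponential race).
    return probs.div(q).argmax(dim=1)

probs = torch.tensor([[0.1, 0.7, 0.2],
                      [0.5, 0.25, 0.25]])
gen_a = torch.Generator().manual_seed(1234)
gen_b = torch.Generator().manual_seed(1234)
print(sample_rows(probs, [gen_a, None]))
print(sample_rows(probs, [gen_b, None]))  # row 0 repeats, row 1 may differ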
vllm/model_executor/sampling_metadata.py  (+3 −0)

@@ -19,6 +19,7 @@ class SamplingMetadata:
         prompt_lens: Lengths of prompts.
         selected_token_indices: Token indices selected for sampling.
         categorized_sample_indices: SamplingType -> token indices to sample.
+        generators: List of torch.Generators to use for seeded sampling
         perform_sampling: Whether to perform sampling. This option is used to
             make the sampling only happens in the driver worker, and disable
             sampling in other worker processes.
@@ -31,6 +32,7 @@ class SamplingMetadata:
         prompt_lens: Optional[List[int]],
         selected_token_indices: torch.Tensor,
         categorized_sample_indices: Optional[Dict[SamplingType, torch.Tensor]],
+        generators: Optional[List[torch.Generator]] = None,
         perform_sampling: bool = True,
     ) -> None:
         self.seq_groups = seq_groups
@@ -38,6 +40,7 @@ class SamplingMetadata:
         self.prompt_lens = prompt_lens
         self.selected_token_indices = selected_token_indices
         self.categorized_sample_indices = categorized_sample_indices
+        self.generators = generators
         self.perform_sampling = perform_sampling

         self.num_prompts = len(prompt_lens) if prompt_lens is not None else 0
vllm/sampling_params.py  (+8 −1)

@@ -11,7 +11,8 @@ _SAMPLING_EPS = 1e-5
 class SamplingType(IntEnum):
     GREEDY = 0
     RANDOM = 1
-    BEAM = 2
+    RANDOM_SEED = 2
+    BEAM = 3


 LogitsProcessor = Callable[[List[int], torch.Tensor], torch.Tensor]
@@ -56,6 +57,7 @@ class SamplingParams:
         min_p: Float that represents the minimum probability for a token to be
             considered, relative to the probability of the most likely token.
             Must be in [0, 1]. Set to 0 to disable this.
+        seed: Random seed to use for the generation.
         use_beam_search: Whether to use beam search instead of sampling.
         length_penalty: Float that penalizes sequences based on their length.
             Used in beam search.
@@ -101,6 +103,7 @@ class SamplingParams:
         top_p: float = 1.0,
         top_k: int = -1,
         min_p: float = 0.0,
+        seed: Optional[int] = None,
         use_beam_search: bool = False,
         length_penalty: float = 1.0,
         early_stopping: Union[bool, str] = False,
@@ -124,6 +127,7 @@ class SamplingParams:
         self.top_p = top_p
         self.top_k = top_k
         self.min_p = min_p
+        self.seed = seed
         self.use_beam_search = use_beam_search
         self.length_penalty = length_penalty
         self.early_stopping = early_stopping
@@ -229,6 +233,8 @@ class SamplingParams:
             return SamplingType.BEAM
         if self.temperature < _SAMPLING_EPS:
             return SamplingType.GREEDY
+        if self.seed is not None:
+            return SamplingType.RANDOM_SEED
         return SamplingType.RANDOM

     def __repr__(self) -> str:
@@ -242,6 +248,7 @@ class SamplingParams:
                 f"top_p={self.top_p}, "
                 f"top_k={self.top_k}, "
                 f"min_p={self.min_p}, "
+                f"seed={self.seed}, "
                 f"use_beam_search={self.use_beam_search}, "
                 f"length_penalty={self.length_penalty}, "
                 f"early_stopping={self.early_stopping}, "
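If the return statements in the `@@ -229,6 +233,8 @@` hunk above belong to the `sampling_type` property (an assumption; the property name is not shown in this excerpt), the resolution order after this commit is beam search, then greedy, then seeded random, then plain random. A hedged sketch of what that implies:

from vllm import SamplingParams
from vllm.sampling_params import SamplingType

assert SamplingParams(temperature=0.0).sampling_type == SamplingType.GREEDY
assert SamplingParams(temperature=1.0).sampling_type == SamplingType.RANDOM
# A seed selects the new RANDOM_SEED type, which _sample() routes through
# the generator-aware _multinomial path shown earlier.
assert (SamplingParams(temperature=1.0, seed=123).sampling_type
        == SamplingType.RANDOM_SEED)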
vllm/sequence.py  (+12 −0)

@@ -248,6 +248,14 @@ class Sequence:
                 f"num_blocks={len(self.logical_token_blocks)})")


+@dataclass
+class SequenceGroupState:
+    """Mutable state tied to a specific sequence group"""
+
+    # torch.Generator used in seeded sampling
+    generator: Optional = None
+
+
 class SequenceGroup:
     """A group of sequences that are generated from the same prompt.
@@ -280,6 +288,7 @@ class SequenceGroup:
         self.lora_request = lora_request
         self.prefix: Optional[Prefix] = prefix
         self.prompt_logprobs: Optional[PromptLogprobs] = None
+        self.state = SequenceGroupState()

     @property
     def prompt(self) -> str:
@@ -397,6 +406,7 @@ class SequenceGroupMetadata:
         sampling_params: The sampling parameters used to generate the outputs.
         block_tables: The block tables. (Seq id -> list of physical block
             numbers)
+        state: Internal state tied to this sequence group.
         lora_request: LoRA request.
         prefix: The prefix of the prompt of the sequence group.
     """
@@ -410,6 +420,7 @@ class SequenceGroupMetadata:
         block_tables: Dict[int, List[int]],
         lora_request: Optional[LoRARequest] = None,
         prefix: Optional[Prefix] = None,
+        state: Optional[SequenceGroupState] = None,
     ) -> None:
         self.request_id = request_id
         self.is_prompt = is_prompt
@@ -418,6 +429,7 @@ class SequenceGroupMetadata:
         self.block_tables = block_tables
         self.lora_request = lora_request
         self.prefix = prefix
+        self.state = SequenceGroupState() if state is None else state

     @property
     def lora_int_id(self) -> int:
vllm/worker/model_runner.py  (+10 −0)

@@ -389,6 +389,7 @@ class ModelRunner:
     ) -> SamplingMetadata:
         seq_groups: List[Tuple[List[int], SamplingParams]] = []
         selected_token_indices: List[int] = []
+        generators: List[torch.Generator] = []
         selected_token_start_idx = 0
         categorized_sample_indices = {t: [] for t in SamplingType}
         categorized_sample_indices_start_idx = 0
@@ -419,6 +420,10 @@ class ModelRunner:
                 selected_token_indices.append(selected_token_start_idx +
                                               subquery_len - 1)
                 selected_token_start_idx += max_subquery_len
+
+                if sampling_params.seed is not None:
+                    seq_group_metadata.state.generator = torch.Generator(
+                        device="cuda").manual_seed(sampling_params.seed)
             else:
                 num_seqs = len(seq_ids)
                 selected_token_indices.extend(
@@ -432,6 +437,9 @@ class ModelRunner:
                     categorized_sample_indices_start_idx + num_seqs))
                 categorized_sample_indices_start_idx += num_seqs

+            if sampling_params.seed is not None:
+                generators.append(seq_group_metadata.state.generator)
+
         selected_token_indices = _async_h2d(selected_token_indices,
                                             dtype=torch.long,
                                             target_device=self.device,
@@ -454,6 +462,7 @@ class ModelRunner:
             prompt_lens=prompt_lens,
             selected_token_indices=selected_token_indices,
             categorized_sample_indices=categorized_sample_indices,
+            generators=generators,
         )
         return sampling_metadata
@@ -536,6 +545,7 @@ class ModelRunner:
             prompt_lens=None,
             selected_token_indices=metadata_dict["selected_token_indices"],
             categorized_sample_indices=None,
+            generators=None,
             perform_sampling=False,
         )
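Worth noting about the changes above: the generator is created once per seeded sequence group (when its prompt is processed) and stored on the group's mutable `state`, so every subsequent decoding step continues the same random stream rather than re-seeding, while the per-call `generators` list only collects the live generator objects for `SamplingMetadata`. A small CPU-only sketch of why keeping one generator matters (names are illustrative, not vLLM code):

import torch

def draws(generator: torch.Generator, steps: int = 3):
    # Each call advances the generator's stream, mimicking successive
    # decoding steps of one seeded request.
    return [torch.empty(2).exponential_(generator=generator).tolist()
            for _ in range(steps)]

g1 = torch.Generator().manual_seed(42)
g2 = torch.Generator().manual_seed(42)

# Two requests with the same seed replay the same stream, step by step.
assert draws(g1) == draws(g2)

# Re-seeding a fresh generator at every step would just repeat the first draw.
repeated = [
    torch.empty(2).exponential_(
        generator=torch.Generator().manual_seed(42)).tolist()
    for _ in range(3)
]
assert repeated[0] == repeated[1] == repeated[2]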