norm / vllm · commit 7d2dcce1

Support per-request seed (#2514)

Unverified commit, authored Feb 21, 2024 by Nick Hill and committed via GitHub on Feb 21, 2024. Parent: dc903e70.

Showing 10 changed files with 289 additions and 84 deletions:
tests/samplers/test_sampler.py             +147  -75
tests/samplers/test_seeded_generate.py      +82   -0
vllm/core/scheduler.py                       +1   -0
vllm/engine/arg_utils.py                     +0   -1
vllm/entrypoints/openai/protocol.py          +4   -0
vllm/model_executor/layers/sampler.py       +22   -7
vllm/model_executor/sampling_metadata.py     +3   -0
vllm/sampling_params.py                      +8   -1
vllm/sequence.py                            +12   -0
vllm/worker/model_runner.py                 +10   -0
tests/samplers/test_sampler.py

```diff
 import random
-from typing import Tuple
+from typing import Tuple, List
 from unittest.mock import patch

 import pytest
 import torch
 from transformers import GenerationConfig, GenerationMixin
+from typing import Optional

 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.utils import set_random_seed

@@ -46,15 +47,13 @@ CUDA_DEVICES = [
 ]


-@pytest.mark.parametrize("seed", RANDOM_SEEDS)
-@pytest.mark.parametrize("device", CUDA_DEVICES)
-def test_sampler_all_greedy(seed: int, device: str):
-    set_random_seed(seed)
-    torch.set_default_device(device)
-    batch_size = random.randint(1, 256)
-    input_tensor, fake_logits, sampler, model_runner = _prepare_test(batch_size)
-
+def _do_sample(
+    batch_size: int,
+    input_tensor: torch.Tensor,
+    sampler: MockLogitsSampler,
+    model_runner: ModelRunner,
+    sampling_params: SamplingParams,
+):
     seq_group_metadata_list = []
     prompt_lens = []
     for i in range(batch_size):

@@ -63,7 +62,7 @@ def test_sampler_all_greedy(seed: int, device: str):
                 request_id=f"test_{i}",
                 is_prompt=True,
                 seq_data={0: SequenceData([1, 2, 3])},
-                sampling_params=SamplingParams(temperature=0, ),
+                sampling_params=sampling_params,
                 block_tables={0: [1]},
             ))
         prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

@@ -71,9 +70,23 @@ def test_sampler_all_greedy(seed: int, device: str):
     sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                      prompt_lens,
                                                      subquery_lens=prompt_lens)
-    sampler_output = sampler(embedding=None,
-                             hidden_states=input_tensor,
-                             sampling_metadata=sampling_metadata)
+    return sampler(embedding=None,
+                   hidden_states=input_tensor,
+                   sampling_metadata=sampling_metadata)
+
+
+@pytest.mark.parametrize("seed", RANDOM_SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+def test_sampler_all_greedy(seed: int, device: str):
+    set_random_seed(seed)
+    torch.set_default_device(device)
+    batch_size = random.randint(1, 256)
+    input_tensor, fake_logits, sampler, model_runner = _prepare_test(batch_size)
+
+    sampling_params = SamplingParams(temperature=0)
+    sampler_output = _do_sample(batch_size, input_tensor, sampler, model_runner,
+                                sampling_params)
     expected = torch.argmax(fake_logits, dim=-1)
     for i, sequence_output in enumerate(sampler_output):
         for nth_output in sequence_output.samples:

@@ -94,28 +107,40 @@ def test_sampler_all_random(seed: int, device: str):
     for i in range(batch_size):
         fake_logits[i, i] = 1e2

-    seq_group_metadata_list = []
-    prompt_lens = []
-    for i in range(batch_size):
-        seq_group_metadata_list.append(
-            SequenceGroupMetadata(
-                request_id=f"test_{i}",
-                is_prompt=True,
-                seq_data={0: SequenceData([1, 2, 3])},
-                sampling_params=SamplingParams(
-                    temperature=1.0,
-                    n=random.randint(1, 10),
-                ),
-                block_tables={0: [1]},
-            ))
-        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
-
-    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
-                                                     prompt_lens,
-                                                     subquery_lens=prompt_lens)
-    sampler_output = sampler(embedding=None,
-                             hidden_states=input_tensor,
-                             sampling_metadata=sampling_metadata)
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        n=random.randint(1, 10),
+    )
+    sampler_output = _do_sample(batch_size, input_tensor, sampler, model_runner,
+                                sampling_params)
+
+    for i, sequence_output in enumerate(sampler_output):
+        for nth_output in sequence_output.samples:
+            assert nth_output.output_token == i
+
+    del model_runner
+
+
+@pytest.mark.parametrize("seed", RANDOM_SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+def test_sampler_all_random_seed(seed: int, device: str):
+    set_random_seed(seed)
+    torch.set_default_device(device)
+    batch_size = random.randint(1, 256)
+    input_tensor, fake_logits, sampler, model_runner = _prepare_test(batch_size)
+
+    for i in range(batch_size):
+        fake_logits[i, i] = 1e2
+
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        n=random.randint(1, 10),
+        seed=random.randint(0, 10000),
+    )
+    sampler_output = _do_sample(batch_size, input_tensor, sampler, model_runner,
+                                sampling_params)
+
     for i, sequence_output in enumerate(sampler_output):
         for nth_output in sequence_output.samples:
             assert nth_output.output_token == i

@@ -123,6 +148,31 @@ def test_sampler_all_random(seed: int, device: str):
     del model_runner


+@pytest.mark.parametrize("seed", RANDOM_SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+def test_sampler_all_random_seed_deterministic(seed: int, device: str):
+    set_random_seed(seed)
+    torch.set_default_device(device)
+    batch_size = random.randint(1, 256)
+    input_tensor, fake_logits, sampler, model_runner = _prepare_test(batch_size)
+
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        n=random.randint(1, 10),
+        seed=random.randint(0, 10000),
+    )
+    first_sampler_output = _do_sample(batch_size, input_tensor, sampler,
+                                      model_runner, sampling_params)
+
+    second_sampler_output = _do_sample(batch_size, input_tensor, sampler,
+                                       model_runner, sampling_params)
+
+    assert first_sampler_output == second_sampler_output
+
+    del model_runner
+
+
 @pytest.mark.parametrize("seed", RANDOM_SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 def test_sampler_all_beam(seed: int, device: str):

@@ -131,29 +181,13 @@ def test_sampler_all_beam(seed: int, device: str):
     batch_size = random.randint(1, 256)
     input_tensor, _, sampler, model_runner = _prepare_test(batch_size)

-    seq_group_metadata_list = []
-    prompt_lens = []
-    for i in range(batch_size):
-        seq_group_metadata_list.append(
-            SequenceGroupMetadata(
-                request_id=f"test_{i}",
-                is_prompt=True,
-                seq_data={0: SequenceData([1, 2, 3])},
-                sampling_params=SamplingParams(
-                    temperature=0,
-                    best_of=2,
-                    use_beam_search=True,
-                ),
-                block_tables={0: [1]},
-            ))
-        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
-
-    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
-                                                     prompt_lens,
-                                                     subquery_lens=prompt_lens)
-    sampler(embedding=None,
-            hidden_states=input_tensor,
-            sampling_metadata=sampling_metadata)
+    sampling_params = SamplingParams(
+        temperature=0,
+        best_of=2,
+        use_beam_search=True,
+    )
+    _do_sample(batch_size, input_tensor, sampler, model_runner, sampling_params)
     # no assertion here as I am not sure how to determine whether
     # the outputs are expected - in other words, this just tests
     # whether there are no exceptions in the sampler

@@ -171,14 +205,15 @@ def test_sampler_mixed(seed: int, device: str):
         batch_size)

     seq_group_metadata_list = []
-    expected_tokens = []
+    expected_tokens: List[Optional[List[int]]] = []
     prompt_lens = []
     for i in range(batch_size):
-        n = 1
-        sampling_type = random.randint(0, 2)
+        expected: Optional[List[int]] = None
+        sampling_type = random.randint(0, 3)
         if sampling_type == 0:
             sampling_params = SamplingParams(temperature=0)
-        elif sampling_type == 1:
+            expected = [torch.argmax(fake_logits[i], dim=-1).item()]
+        elif sampling_type in (1, 2):
             n = random.randint(1, 10)
             sampling_params = SamplingParams(
                 temperature=random.random() + 0.1,

@@ -187,13 +222,17 @@ def test_sampler_mixed(seed: int, device: str):
                 n=n,
                 presence_penalty=random.randint(0, 1),
             )
+            if sampling_type == 2:
+                sampling_params.seed = random.randint(0, 10000)
+            else:
+                for idx in range(n):
+                    fake_logits[i, i + idx] = 1e2
+                expected = list(range(i, i + n))
         else:
             sampling_params = SamplingParams(temperature=0,
                                              use_beam_search=True,
                                              best_of=2)
-            for idx in range(n):
-                fake_logits[i, i + idx] = 1e2
-                expected_tokens.append(i + idx)
+        expected_tokens.append(expected)
         seq_group_metadata_list.append(
             SequenceGroupMetadata(
                 request_id=f"test_{i}",

@@ -204,17 +243,50 @@ def test_sampler_mixed(seed: int, device: str):
             ))
         prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

-    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
-                                                     prompt_lens,
-                                                     subquery_lens=prompt_lens)
-    sampler_output = sampler(embedding=None,
-                             hidden_states=input_tensor,
-                             sampling_metadata=sampling_metadata)
-    for i, sequence_output in enumerate(sampler_output):
-        if seq_group_metadata_list[i].sampling_params.use_beam_search:
-            continue
-        for nth_output in sequence_output.samples:
-            assert nth_output.output_token in expected_tokens
+    def test_sampling(model_runner: ModelRunner):
+        sampling_metadata = model_runner._prepare_sample(
+            seq_group_metadata_list, prompt_lens, subquery_lens=prompt_lens)
+        sampler_output = sampler(embedding=None,
+                                 hidden_states=input_tensor,
+                                 sampling_metadata=sampling_metadata)
+
+        for i, (sequence_output, metadata) in enumerate(
+                zip(sampler_output, seq_group_metadata_list)):
+            if metadata.sampling_params.use_beam_search:
+                continue
+
+            if metadata.sampling_params.seed is not None \
+                    and expected_tokens[i] is None:
+                # Record seeded random result to compare with results of second invocation
+                expected_tokens[i] = [
+                    nth_output.output_token
+                    for nth_output in sequence_output.samples
+                ]
+                continue
+
+            for n, nth_output in enumerate(sequence_output.samples):
+                if metadata.sampling_params.temperature == 0 or metadata.sampling_params.seed is not None:
+                    # Ensure exact matches for greedy or random with seed
+                    assert nth_output.output_token == expected_tokens[i][n]
+                else:
+                    # For non-seeded random check that one of the high-logit tokens were chosen
+                    assert nth_output.output_token in expected_tokens[i]
+
+    # Test batch
+    test_sampling(model_runner)
+
+    # Shuffle the batch and resample
+    target_index = list(range(batch_size))
+    for list_to_shuffle in (target_index, seq_group_metadata_list,
+                            expected_tokens, prompt_lens):
+        random.Random(seed).shuffle(list_to_shuffle)
+    target_index = torch.tensor(target_index)
+    input_tensor.data = input_tensor.index_select(0, target_index)
+    fake_logits.data = fake_logits.index_select(0, target_index)
+
+    # This time, results of seeded random samples will be compared with the corresponding
+    # sample in the pre-shuffled batch
+    test_sampling(model_runner)

     del model_runner
```
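One subtlety in the reworked `test_sampler_mixed`: the batch is shuffled by calling `random.Random(seed).shuffle` separately on four parallel lists. That works because each call constructs a freshly seeded `Random`, so every list receives the identical permutation and stays index-aligned. A standalone check of that property (plain stdlib, not vLLM code):

```python
import random

seed = 0
a = list(range(8))
b = [chr(ord("a") + i) for i in range(8)]
for lst in (a, b):
    # A fresh Random(seed) yields the same permutation for each list.
    random.Random(seed).shuffle(lst)
# Position j in both lists still refers to the same original index.
assert all(b[j] == chr(ord("a") + a[j]) for j in range(8))
```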
tests/samplers/test_seeded_generate.py (new file, mode 100644)

```python
"""Verify that seeded random sampling is deterministic.

Run `pytest tests/samplers/test_seeded_generate.py --forked`.
"""
import copy
import random
from itertools import combinations

import pytest

from vllm.model_executor.utils import set_random_seed
from vllm import SamplingParams

MODEL = "facebook/opt-125m"
RANDOM_SEEDS = list(range(5))


@pytest.fixture
def vllm_model(vllm_runner):
    vllm_model = vllm_runner(MODEL, dtype="half")
    yield vllm_model
    del vllm_model


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_random_sample_with_seed(
    vllm_model,
    example_prompts,
    seed: int,
) -> None:
    set_random_seed(seed)

    sampling_params = SamplingParams(
        # Parameters to ensure sufficient randomness
        temperature=2.0,
        top_p=min(random.random() + 0.3, 1),
        top_k=random.randint(5, 20),
        n=random.randint(1, 10),
        presence_penalty=random.randint(0, 1),
        max_tokens=8,
        ignore_eos=True,
    )

    sampling_params_seed_1 = copy.deepcopy(sampling_params)
    sampling_params_seed_1.seed = 100
    sampling_params_seed_2 = copy.deepcopy(sampling_params)
    sampling_params_seed_2.seed = 200

    llm = vllm_model.model

    for prompt in example_prompts:
        for params in (
                sampling_params,
                sampling_params_seed_1,
                sampling_params_seed_2,
                sampling_params,
                sampling_params_seed_1,
                sampling_params_seed_2,
        ):
            llm._add_request(
                prompt=prompt,
                prompt_token_ids=None,
                sampling_params=params,
            )

    results = llm._run_engine(use_tqdm=False)
    all_outputs = [[out.token_ids for out in output.outputs]
                   for output in results]

    for i in range(0, len(example_prompts), 6):
        outputs = all_outputs[i:i + 6]

        # verify all non-seeded requests differ
        for output_a, output_b in combinations(
            (outputs[0], outputs[1], outputs[2], outputs[3]),
                2,
        ):
            assert output_a != output_b

        # verify requests with the same seed match
        assert outputs[1] == outputs[4]
        assert outputs[2] == outputs[5]
```
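At the user level, the behaviour this test pins down can be reproduced with the offline `LLM` entrypoint. A hedged sketch (it assumes this vLLM revision is installed and a CUDA device is available; the prompt string is arbitrary): two generations that share a `seed` return the same tokens, while omitting the seed keeps ordinary random sampling.

```python
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
seeded = SamplingParams(temperature=1.0, seed=1234, max_tokens=8)

out1 = llm.generate(["The capital of France is"], seeded)
out2 = llm.generate(["The capital of France is"], seeded)
# Same prompt + same seed -> identical token ids.
assert out1[0].outputs[0].token_ids == out2[0].outputs[0].token_ids
```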
vllm/core/scheduler.py

```diff
@@ -387,6 +387,7 @@ class Scheduler:
                 block_tables=block_tables,
                 lora_request=seq_group.lora_request,
                 prefix=seq_group.prefix,
+                state=seq_group.state,
             )
             seq_group_metadata_list.append(seq_group_metadata)
         return seq_group_metadata_list, scheduler_outputs
```
vllm/engine/arg_utils.py

```diff
@@ -173,7 +173,6 @@ class EngineArgs:
                             default=EngineArgs.block_size,
                             choices=[8, 16, 32],
                             help='token block size')
-        # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
         parser.add_argument('--seed',
                             type=int,
                             default=EngineArgs.seed,
```
vllm/entrypoints/openai/protocol.py

```diff
@@ -60,6 +60,7 @@ class ChatCompletionRequest(BaseModel):
     top_p: Optional[float] = 1.0
     n: Optional[int] = 1
     max_tokens: Optional[int] = None
+    seed: Optional[int] = None
     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
     stream: Optional[bool] = False
     presence_penalty: Optional[float] = 0.0

@@ -90,6 +91,7 @@ class ChatCompletionRequest(BaseModel):
             temperature=self.temperature,
             top_p=self.top_p,
             min_p=self.min_p,
+            seed=self.seed,
             stop=self.stop,
             stop_token_ids=self.stop_token_ids,
             max_tokens=self.max_tokens,

@@ -117,6 +119,7 @@ class CompletionRequest(BaseModel):
     logprobs: Optional[int] = None
     echo: Optional[bool] = False
     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
+    seed: Optional[int] = None
     presence_penalty: Optional[float] = 0.0
     frequency_penalty: Optional[float] = 0.0
     best_of: Optional[int] = None

@@ -147,6 +150,7 @@ class CompletionRequest(BaseModel):
             top_p=self.top_p,
             top_k=self.top_k,
             min_p=self.min_p,
+            seed=self.seed,
             stop=self.stop,
             stop_token_ids=self.stop_token_ids,
             ignore_eos=self.ignore_eos,
```
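The same field is now exposed through the OpenAI-compatible server for both chat and plain completions. A hedged sketch against a locally running server (the host, port, and model name here are assumptions, not part of the diff):

```python
import requests

payload = {
    "model": "facebook/opt-125m",
    "prompt": "The capital of France is",
    "max_tokens": 8,
    "temperature": 1.0,
    "seed": 1234,  # per-request field added by this commit
}
url = "http://localhost:8000/v1/completions"  # assumed local endpoint
r1 = requests.post(url, json=payload).json()
r2 = requests.post(url, json=payload).json()
# Repeating the request with the same seed should reproduce the completion.
assert r1["choices"][0]["text"] == r2["choices"][0]["text"]
```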
vllm/model_executor/layers/sampler.py

```diff
@@ -342,7 +342,9 @@ def _beam_search_sample(
 def _multinomial(
     probs: torch.Tensor,
     num_samples: int,
-):
+    seq_groups: Optional[List[Tuple[List[int], SamplingParams]]] = None,
+    generators: Optional[List[torch.Generator]] = None,
+) -> torch.Tensor:
     if num_samples > 1:
         # This is equivalent to torch.repeat_interleaved (which also
         # forces a GPU<->CPU sync).

@@ -352,7 +354,15 @@ def _multinomial(
         probs = probs[:, None, :].expand(probs.shape[0], num_samples,
                                          probs.shape[1]).contiguous().view(
                                              -1, probs.shape[1])
-    q = torch.empty_like(probs).exponential_(1)
+    q = torch.empty_like(probs)
+    if seq_groups is None:
+        q.exponential_()
+    else:
+        sample_idx = 0
+        for (seq_ids, _), generator in zip(seq_groups, generators):
+            next_sample_idx = sample_idx + len(seq_ids) * num_samples
+            q[sample_idx:next_sample_idx].exponential_(generator=generator)
+            sample_idx = next_sample_idx
     return probs.div_(q).argmax(dim=1).view(-1, num_samples)

@@ -370,6 +380,7 @@ def _sample(
     sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
     sample_metadata = {}
+    multinomial_samples = {}

     # Counterintiutively, having two loops here is actually faster.
     # The first loop can run without waiting on GPU<->CPU sync.

@@ -385,14 +396,18 @@ def _sample(
                                                  is_prompts, sample_indices)
         if sampling_type == SamplingType.GREEDY:
             greedy_samples = torch.argmax(logprobs[sample_indices], dim=-1)
-        elif sampling_type == SamplingType.RANDOM:
+        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
             max_best_of = 1
             for seq_group, is_prompt in zip(seq_groups, is_prompts):
                 if is_prompt:
                     _, sampling_params = seq_group
                     max_best_of = max(max_best_of, sampling_params.best_of)
-            multinomial_samples = _multinomial(probs[sample_indices],
-                                               max_best_of)
+            seeded_args = {} if sampling_type == SamplingType.RANDOM else {
+                "seq_groups": seq_groups,
+                "generators": sampling_metadata.generators,
+            }
+            multinomial_samples[sampling_type] = _multinomial(
+                probs[sample_indices], max_best_of, **seeded_args)
         elif sampling_type == SamplingType.BEAM:
             beam_search_logprobs = logprobs[sample_indices]
         else:

@@ -407,9 +422,9 @@ def _sample(
             sampling_type]
         if sampling_type == SamplingType.GREEDY:
             sample_results = _greedy_sample(seq_groups, greedy_samples)
-        elif sampling_type == SamplingType.RANDOM:
+        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
             sample_results = _random_sample(seq_groups, is_prompts,
-                                            multinomial_samples)
+                                            multinomial_samples[sampling_type])
         elif sampling_type == SamplingType.BEAM:
             sample_results = _beam_search_sample(seq_groups, is_prompts,
                                                  sampling_metadata.seq_data,
```
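The core mechanism is the change to `_multinomial`: the exponential noise tensor is now filled slice by slice, one slice per sequence group, using that group's `torch.Generator` when it has one and the global RNG otherwise. A CPU-only toy version of that partitioning (the function and variable names here are illustrative, not the vLLM API):

```python
from typing import List, Optional

import torch


def partitioned_multinomial(probs: torch.Tensor,
                            num_samples: int,
                            group_sizes: List[int],
                            generators: List[Optional[torch.Generator]]) -> torch.Tensor:
    # Exponential-race sampling: argmax of probs / Exp(1) noise is a
    # multinomial draw; seeded groups use their own noise stream.
    q = torch.empty_like(probs)
    sample_idx = 0
    for size, generator in zip(group_sizes, generators):
        next_sample_idx = sample_idx + size * num_samples
        q[sample_idx:next_sample_idx].exponential_(generator=generator)
        sample_idx = next_sample_idx
    return probs.div(q).argmax(dim=1).view(-1, num_samples)


row = torch.softmax(torch.randn(10), dim=-1)
probs = torch.stack([row, torch.rand(10), row, torch.rand(10)])
generators = [torch.Generator().manual_seed(0), None,
              torch.Generator().manual_seed(0), None]
out = partitioned_multinomial(probs, 1, [1, 1, 1, 1], generators)
# Groups 0 and 2 share a seed and identical probabilities -> identical draws.
assert out[0].item() == out[2].item()
```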
vllm/model_executor/sampling_metadata.py

```diff
@@ -19,6 +19,7 @@ class SamplingMetadata:
         prompt_lens: Lengths of prompts.
         selected_token_indices: Token indices selected for sampling.
         categorized_sample_indices: SamplingType -> token indices to sample.
+        generators: List of torch.Generators to use for seeded sampling
         perform_sampling: Whether to perform sampling. This option is used to
             make the sampling only happens in the driver worker, and disable
             sampling in other worker processes.

@@ -31,6 +32,7 @@ class SamplingMetadata:
         prompt_lens: Optional[List[int]],
         selected_token_indices: torch.Tensor,
         categorized_sample_indices: Optional[Dict[SamplingType, torch.Tensor]],
+        generators: Optional[List[torch.Generator]] = None,
         perform_sampling: bool = True,
     ) -> None:
         self.seq_groups = seq_groups

@@ -38,6 +40,7 @@ class SamplingMetadata:
         self.prompt_lens = prompt_lens
         self.selected_token_indices = selected_token_indices
         self.categorized_sample_indices = categorized_sample_indices
+        self.generators = generators
         self.perform_sampling = perform_sampling

         self.num_prompts = len(prompt_lens) if prompt_lens is not None else 0
```
vllm/sampling_params.py

```diff
@@ -11,7 +11,8 @@ _SAMPLING_EPS = 1e-5
 class SamplingType(IntEnum):
     GREEDY = 0
     RANDOM = 1
-    BEAM = 2
+    RANDOM_SEED = 2
+    BEAM = 3


 LogitsProcessor = Callable[[List[int], torch.Tensor], torch.Tensor]

@@ -56,6 +57,7 @@ class SamplingParams:
         min_p: Float that represents the minimum probability for a token to be
             considered, relative to the probability of the most likely token.
             Must be in [0, 1]. Set to 0 to disable this.
+        seed: Random seed to use for the generation.
         use_beam_search: Whether to use beam search instead of sampling.
         length_penalty: Float that penalizes sequences based on their length.
             Used in beam search.

@@ -101,6 +103,7 @@ class SamplingParams:
         top_p: float = 1.0,
         top_k: int = -1,
         min_p: float = 0.0,
+        seed: Optional[int] = None,
         use_beam_search: bool = False,
         length_penalty: float = 1.0,
         early_stopping: Union[bool, str] = False,

@@ -124,6 +127,7 @@ class SamplingParams:
         self.top_p = top_p
         self.top_k = top_k
         self.min_p = min_p
+        self.seed = seed
         self.use_beam_search = use_beam_search
         self.length_penalty = length_penalty
         self.early_stopping = early_stopping

@@ -229,6 +233,8 @@ class SamplingParams:
             return SamplingType.BEAM
         if self.temperature < _SAMPLING_EPS:
             return SamplingType.GREEDY
+        if self.seed is not None:
+            return SamplingType.RANDOM_SEED
         return SamplingType.RANDOM

     def __repr__(self) -> str:

@@ -242,6 +248,7 @@ class SamplingParams:
             f"top_p={self.top_p}, "
             f"top_k={self.top_k}, "
             f"min_p={self.min_p}, "
+            f"seed={self.seed}, "
             f"use_beam_search={self.use_beam_search}, "
             f"length_penalty={self.length_penalty}, "
             f"early_stopping={self.early_stopping}, "
```
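Those two added lines determine how a request is routed: beam search and greedy still take precedence, and only a temperature-based random request with a seed becomes `RANDOM_SEED`. A hedged check of that resolution order (assumes this vLLM revision is importable and that the property queried here is the params' resolved sampling type):

```python
from vllm import SamplingParams
from vllm.sampling_params import SamplingType

assert SamplingParams(temperature=0.0).sampling_type == SamplingType.GREEDY
assert SamplingParams(temperature=1.0).sampling_type == SamplingType.RANDOM
assert SamplingParams(temperature=1.0, seed=42).sampling_type == SamplingType.RANDOM_SEED
# Greedy takes precedence: a seed with temperature 0 stays GREEDY.
assert SamplingParams(temperature=0.0, seed=42).sampling_type == SamplingType.GREEDY
```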
vllm/sequence.py

```diff
@@ -248,6 +248,14 @@ class Sequence:
                 f"num_blocks={len(self.logical_token_blocks)})")


+@dataclass
+class SequenceGroupState:
+    """Mutable state tied to a specific sequence group"""
+
+    # torch.Generator used in seeded sampling
+    generator: Optional = None
+
+
 class SequenceGroup:
     """A group of sequences that are generated from the same prompt.

@@ -280,6 +288,7 @@ class SequenceGroup:
         self.lora_request = lora_request
         self.prefix: Optional[Prefix] = prefix
         self.prompt_logprobs: Optional[PromptLogprobs] = None
+        self.state = SequenceGroupState()

     @property
     def prompt(self) -> str:

@@ -397,6 +406,7 @@ class SequenceGroupMetadata:
         sampling_params: The sampling parameters used to generate the outputs.
         block_tables: The block tables. (Seq id -> list of physical block
             numbers)
+        state: Internal state tied to this sequence group.
         lora_request: LoRA request.
         prefix: The prefix of the prompt of the sequence group.
     """

@@ -410,6 +420,7 @@ class SequenceGroupMetadata:
         block_tables: Dict[int, List[int]],
         lora_request: Optional[LoRARequest] = None,
         prefix: Optional[Prefix] = None,
+        state: Optional[SequenceGroupState] = None,
     ) -> None:
         self.request_id = request_id
         self.is_prompt = is_prompt

@@ -418,6 +429,7 @@ class SequenceGroupMetadata:
         self.block_tables = block_tables
         self.lora_request = lora_request
         self.prefix = prefix
+        self.state = SequenceGroupState() if state is None else state

     @property
     def lora_int_id(self) -> int:
```
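`SequenceGroupState` is deliberately a tiny mutable dataclass: the scheduler passes the same object into each step's `SequenceGroupMetadata` (`state=seq_group.state`), so a generator created during the prompt step is still attached on later decode steps. A minimal stand-in illustrating the sharing (class names here are placeholders, not the vLLM types):

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class GroupState:  # placeholder for SequenceGroupState
    generator: Optional[torch.Generator] = None


@dataclass
class StepMetadata:  # placeholder for SequenceGroupMetadata
    state: GroupState


group_state = GroupState()
prompt_step = StepMetadata(state=group_state)
prompt_step.state.generator = torch.Generator().manual_seed(42)

decode_step = StepMetadata(state=group_state)  # same object on a later step
assert decode_step.state.generator is prompt_step.state.generator
```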
vllm/worker/model_runner.py

```diff
@@ -389,6 +389,7 @@ class ModelRunner:
     ) -> SamplingMetadata:
         seq_groups: List[Tuple[List[int], SamplingParams]] = []
         selected_token_indices: List[int] = []
+        generators: List[torch.Generator] = []
         selected_token_start_idx = 0
         categorized_sample_indices = {t: [] for t in SamplingType}
         categorized_sample_indices_start_idx = 0

@@ -419,6 +420,10 @@ class ModelRunner:
                 selected_token_indices.append(selected_token_start_idx +
                                               subquery_len - 1)
                 selected_token_start_idx += max_subquery_len
+
+                if sampling_params.seed is not None:
+                    seq_group_metadata.state.generator = torch.Generator(
+                        device="cuda").manual_seed(sampling_params.seed)
             else:
                 num_seqs = len(seq_ids)
                 selected_token_indices.extend(

@@ -432,6 +437,9 @@ class ModelRunner:
                         categorized_sample_indices_start_idx + num_seqs))
                 categorized_sample_indices_start_idx += num_seqs

+            if sampling_params.seed is not None:
+                generators.append(seq_group_metadata.state.generator)
+
         selected_token_indices = _async_h2d(selected_token_indices,
                                             dtype=torch.long,
                                             target_device=self.device,

@@ -454,6 +462,7 @@ class ModelRunner:
             prompt_lens=prompt_lens,
             selected_token_indices=selected_token_indices,
             categorized_sample_indices=categorized_sample_indices,
+            generators=generators,
         )
         return sampling_metadata

@@ -536,6 +545,7 @@ class ModelRunner:
             prompt_lens=None,
             selected_token_indices=metadata_dict["selected_token_indices"],
             categorized_sample_indices=None,
+            generators=None,
             perform_sampling=False,
         )
```
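The generator lifecycle set up in `_prepare_sample` matters for correctness: the generator is seeded once when the prompt is scheduled and then reused on every step, so the noise stream keeps advancing instead of replaying the first draw. A CPU sketch of the difference (the diff itself creates `torch.Generator(device="cuda")`):

```python
import torch

seed = 1234
generator = torch.Generator().manual_seed(seed)  # created once at prompt time

# Reusing the stored generator advances the stream on each decode step.
step1 = torch.empty(4).exponential_(generator=generator)
step2 = torch.empty(4).exponential_(generator=generator)
assert not torch.equal(step1, step2)

# Re-seeding every step (the wrong approach) would replay the same noise.
replay = torch.empty(4).exponential_(generator=torch.Generator().manual_seed(seed))
assert torch.equal(step1, replay)
```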