Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5cf9254a
Unverified
Commit
5cf9254a
authored
Jul 30, 2024
by
Nick Hill
Committed by
GitHub
Jul 30, 2024
Browse files
[BugFix] Fix use of per-request seed with pipeline parallel (#6698)
parent
f0584036
Changes
21
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
220 additions
and
136 deletions
+220
-136
tests/samplers/test_rejection_sampler.py
tests/samplers/test_rejection_sampler.py
+8
-15
tests/samplers/test_sampler.py
tests/samplers/test_sampler.py
+4
-1
tests/spec_decode/e2e/test_mlp_correctness.py
tests/spec_decode/e2e/test_mlp_correctness.py
+53
-1
tests/spec_decode/e2e/test_seed.py
tests/spec_decode/e2e/test_seed.py
+1
-1
tests/spec_decode/test_batch_expansion.py
tests/spec_decode/test_batch_expansion.py
+1
-0
tests/utils.py
tests/utils.py
+31
-0
vllm/core/scheduler.py
vllm/core/scheduler.py
+0
-1
vllm/model_executor/layers/rejection_sampler.py
vllm/model_executor/layers/rejection_sampler.py
+39
-56
vllm/model_executor/layers/spec_decode_base_sampler.py
vllm/model_executor/layers/spec_decode_base_sampler.py
+2
-2
vllm/model_executor/sampling_metadata.py
vllm/model_executor/sampling_metadata.py
+13
-7
vllm/sequence.py
vllm/sequence.py
+0
-12
vllm/spec_decode/batch_expansion.py
vllm/spec_decode/batch_expansion.py
+22
-15
vllm/spec_decode/medusa_worker.py
vllm/spec_decode/medusa_worker.py
+3
-1
vllm/spec_decode/mlp_speculator_worker.py
vllm/spec_decode/mlp_speculator_worker.py
+3
-1
vllm/spec_decode/ngram_worker.py
vllm/spec_decode/ngram_worker.py
+1
-2
vllm/spec_decode/spec_decode_worker.py
vllm/spec_decode/spec_decode_worker.py
+11
-14
vllm/worker/cpu_model_runner.py
vllm/worker/cpu_model_runner.py
+2
-1
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+9
-5
vllm/worker/model_runner_base.py
vllm/worker/model_runner_base.py
+15
-0
vllm/worker/neuron_model_runner.py
vllm/worker/neuron_model_runner.py
+2
-1
No files found.
tests/samplers/test_rejection_sampler.py
View file @
5cf9254a
...
@@ -150,10 +150,9 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
...
@@ -150,10 +150,9 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
high
=
vocab_size
,
high
=
vocab_size
,
size
=
(
batch_size
,
k
),
size
=
(
batch_size
,
k
),
dtype
=
torch
.
int64
)
dtype
=
torch
.
int64
)
generators
=
[
None
]
*
batch_size
rejection_sampler
(
target_probs
,
bonus_token_ids
,
draft_probs
,
rejection_sampler
(
target_probs
,
bonus_token_ids
,
draft_probs
,
draft_token_ids
,
generators
)
draft_token_ids
)
@
pytest
.
mark
.
parametrize
(
"frac_seeded"
,
[
0.0
,
0.25
,
0.5
,
1.0
])
@
pytest
.
mark
.
parametrize
(
"frac_seeded"
,
[
0.0
,
0.25
,
0.5
,
1.0
])
...
@@ -185,14 +184,13 @@ def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
...
@@ -185,14 +184,13 @@ def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
results
=
[]
results
=
[]
for
_
in
range
(
n_rep
):
for
_
in
range
(
n_rep
):
generators
=
[
seeded_seqs
=
{
torch
.
Generator
(
i
:
torch
.
Generator
(
device
=
device
).
manual_seed
(
i
)
device
=
device
).
manual_seed
(
i
)
if
seeded_mask
[
i
]
else
None
for
i
in
range
(
batch_size
)
if
seeded_mask
[
i
]
for
i
in
range
(
batch_size
)
}
]
results
.
append
(
results
.
append
(
rejection_sampler
(
target_probs
,
bonus_token_ids
,
draft_probs
,
rejection_sampler
(
target_probs
,
bonus_token_ids
,
draft_probs
,
draft_token_ids
,
generator
s
))
draft_token_ids
,
seeded_seq
s
))
for
i
in
range
(
batch_size
):
for
i
in
range
(
batch_size
):
if
seeded_mask
[
i
]:
if
seeded_mask
[
i
]:
...
@@ -242,11 +240,10 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
...
@@ -242,11 +240,10 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
raise
AssertionError
()
raise
AssertionError
()
oob_token_ids
[
0
][
0
]
=
rogue_token_id
oob_token_ids
[
0
][
0
]
=
rogue_token_id
generators
=
[
None
]
*
batch_size
with
pytest
.
raises
(
AssertionError
):
with
pytest
.
raises
(
AssertionError
):
rejection_sampler
(
target_probs
,
bonus_token_ids
,
draft_probs
,
rejection_sampler
(
target_probs
,
bonus_token_ids
,
draft_probs
,
draft_token_ids
,
generators
)
draft_token_ids
)
@
pytest
.
mark
.
parametrize
(
"draft_and_target_probs_equal"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"draft_and_target_probs_equal"
,
[
True
,
False
])
...
@@ -417,15 +414,11 @@ class _CorrectnessTestHelper:
...
@@ -417,15 +414,11 @@ class _CorrectnessTestHelper:
dtype
=
torch
.
int64
,
dtype
=
torch
.
int64
,
device
=
"cuda"
).
repeat
(
num_samples
,
1
)
device
=
"cuda"
).
repeat
(
num_samples
,
1
)
# unseeded
generators
=
[
None
]
# Get output tokens via rejection sampling.
# Get output tokens via rejection sampling.
output_token_ids
=
self
.
rejection_sampler
(
target_probs
.
to
(
"cuda"
),
output_token_ids
=
self
.
rejection_sampler
(
target_probs
.
to
(
"cuda"
),
bonus_token_ids
.
to
(
"cuda"
),
bonus_token_ids
.
to
(
"cuda"
),
draft_probs
.
to
(
"cuda"
),
draft_probs
.
to
(
"cuda"
),
draft_token_ids
.
to
(
"cuda"
),
draft_token_ids
.
to
(
"cuda"
))
generators
)
# Remove bonus tokens
# Remove bonus tokens
output_token_ids
=
output_token_ids
[:,
:
-
1
].
flatten
()
output_token_ids
=
output_token_ids
[:,
:
-
1
].
flatten
()
...
...
tests/samplers/test_sampler.py
View file @
5cf9254a
...
@@ -510,13 +510,16 @@ def test_sampler_mixed(seed: int, device: str):
...
@@ -510,13 +510,16 @@ def test_sampler_mixed(seed: int, device: str):
))
))
seq_lens
.
append
(
seq_group_metadata_list
[
-
1
].
seq_data
[
0
].
get_len
())
seq_lens
.
append
(
seq_group_metadata_list
[
-
1
].
seq_data
[
0
].
get_len
())
generators
:
Dict
[
str
,
torch
.
Generator
]
=
{}
def
test_sampling
():
def
test_sampling
():
sampling_metadata
=
SamplingMetadata
.
prepare
(
sampling_metadata
=
SamplingMetadata
.
prepare
(
seq_group_metadata_list
,
seq_group_metadata_list
,
seq_lens
,
seq_lens
,
query_lens
=
seq_lens
,
query_lens
=
seq_lens
,
device
=
device
,
device
=
device
,
pin_memory
=
is_pin_memory_available
())
pin_memory
=
is_pin_memory_available
(),
generators
=
generators
)
sampler_output
=
sampler
(
logits
=
fake_logits
,
sampler_output
=
sampler
(
logits
=
fake_logits
,
sampling_metadata
=
sampling_metadata
)
sampling_metadata
=
sampling_metadata
)
...
...
tests/spec_decode/e2e/test_mlp_correctness.py
View file @
5cf9254a
...
@@ -21,7 +21,8 @@ correctess for the target model outputs.
...
@@ -21,7 +21,8 @@ correctess for the target model outputs.
import
pytest
import
pytest
from
.conftest
import
run_greedy_equality_correctness_test
from
.conftest
import
(
run_equality_correctness_test
,
run_greedy_equality_correctness_test
)
# main model
# main model
MAIN_MODEL
=
"JackFram/llama-160m"
MAIN_MODEL
=
"JackFram/llama-160m"
...
@@ -77,6 +78,57 @@ def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator,
...
@@ -77,6 +78,57 @@ def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator,
force_output_len
=
True
)
force_output_len
=
True
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Required for spec decode.
"use_v2_block_manager"
:
True
,
# Print spec metrics.
"disable_log_stats"
:
False
,
# Precision
"dtype"
:
PRECISION
,
# Main model
"model"
:
MAIN_MODEL
,
# Speculative model
"speculative_model"
:
SPEC_MODEL
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{
"seed"
:
1
}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{
"seed"
:
5
}])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
64
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
32
])
@
pytest
.
mark
.
parametrize
(
"temperature"
,
[
0.1
,
1.0
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
None
])
def
test_mlp_e2e_seeded_correctness
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
,
temperature
:
float
):
"""Verify seeded runs produce the same output."""
run_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
temperature
=
temperature
,
seeded
=
True
,
force_output_len
=
True
)
# Ensure this same test does fail if we _don't_ include per-request seeds
with
pytest
.
raises
(
AssertionError
):
run_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
temperature
=
temperature
,
seeded
=
False
,
force_output_len
=
True
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
"common_llm_kwargs"
,
[{
[{
...
...
tests/spec_decode/e2e/test_seed.py
View file @
5cf9254a
...
@@ -29,7 +29,7 @@ from .conftest import run_equality_correctness_test
...
@@ -29,7 +29,7 @@ from .conftest import run_equality_correctness_test
"output_len"
,
"output_len"
,
[
[
# Use smaller output len for fast test.
# Use smaller output len for fast test.
1
0
,
2
0
,
])
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
None
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
None
])
def
test_seeded_consistency
(
baseline_llm_generator
,
test_llm_generator
,
def
test_seeded_consistency
(
baseline_llm_generator
,
test_llm_generator
,
...
...
tests/spec_decode/test_batch_expansion.py
View file @
5cf9254a
...
@@ -86,6 +86,7 @@ def test_create_single_target_seq_group_metadata(k: int):
...
@@ -86,6 +86,7 @@ def test_create_single_target_seq_group_metadata(k: int):
input_seq_id
,
input_seq_id
,
target_seq_id
,
target_seq_id
,
token_ids
,
token_ids
,
input_seq_group_metadata
.
sampling_params
,
)
)
assert
output
.
request_id
==
input_seq_group_metadata
.
request_id
assert
output
.
request_id
==
input_seq_group_metadata
.
request_id
...
...
tests/utils.py
View file @
5cf9254a
...
@@ -178,6 +178,37 @@ def compare_two_settings(model: str, arg1: List[str], arg2: List[str]):
...
@@ -178,6 +178,37 @@ def compare_two_settings(model: str, arg1: List[str], arg2: List[str]):
"usage"
:
completion
.
usage
,
"usage"
:
completion
.
usage
,
})
})
# test seeded random sampling
completion
=
client
.
completions
.
create
(
model
=
model
,
prompt
=
prompt
,
max_tokens
=
5
,
seed
=
33
,
temperature
=
1.0
)
results
.
append
({
"test"
:
"seeded_sampling"
,
"text"
:
completion
.
choices
[
0
].
text
,
"finish_reason"
:
completion
.
choices
[
0
].
finish_reason
,
"usage"
:
completion
.
usage
,
})
# test seeded random sampling with multiple prompts
completion
=
client
.
completions
.
create
(
model
=
model
,
prompt
=
[
prompt
,
prompt
],
max_tokens
=
5
,
seed
=
33
,
temperature
=
1.0
)
results
.
append
({
"test"
:
"seeded_sampling"
,
"text"
:
[
choice
.
text
for
choice
in
completion
.
choices
],
"finish_reason"
:
[
choice
.
finish_reason
for
choice
in
completion
.
choices
],
"usage"
:
completion
.
usage
,
})
# test simple list
# test simple list
batch
=
client
.
completions
.
create
(
batch
=
client
.
completions
.
create
(
model
=
model
,
model
=
model
,
...
...
vllm/core/scheduler.py
View file @
5cf9254a
...
@@ -1029,7 +1029,6 @@ class Scheduler:
...
@@ -1029,7 +1029,6 @@ class Scheduler:
token_chunk_size
=
token_chunk_size
,
token_chunk_size
=
token_chunk_size
,
lora_request
=
seq_group
.
lora_request
,
lora_request
=
seq_group
.
lora_request
,
computed_block_nums
=
common_computed_block_nums
,
computed_block_nums
=
common_computed_block_nums
,
state
=
seq_group
.
state
,
# `multi_modal_data` will only be present for the 1st comm
# `multi_modal_data` will only be present for the 1st comm
# between engine and worker.
# between engine and worker.
# the subsequent comms can still use delta, but
# the subsequent comms can still use delta, but
...
...
vllm/model_executor/layers/rejection_sampler.py
View file @
5cf9254a
from
functools
import
cached_property
from
functools
import
cached_property
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
Dict
,
List
,
Optional
,
Tuple
import
torch
import
torch
import
torch.jit
import
torch.jit
...
@@ -36,7 +36,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
...
@@ -36,7 +36,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
bonus_token_ids
:
torch
.
Tensor
,
bonus_token_ids
:
torch
.
Tensor
,
draft_probs
:
torch
.
Tensor
,
draft_probs
:
torch
.
Tensor
,
draft_token_ids
:
torch
.
Tensor
,
draft_token_ids
:
torch
.
Tensor
,
generators
:
List
[
Optional
[
torch
.
Generator
]],
seeded_seqs
:
Optional
[
Dict
[
int
,
torch
.
Generator
]]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Sample token ids using rejection sampling. This accepts or rejects
"""Sample token ids using rejection sampling. This accepts or rejects
tokens proposed by the draft model using the probability of each token
tokens proposed by the draft model using the probability of each token
...
@@ -66,6 +66,9 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
...
@@ -66,6 +66,9 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
probabilities.
probabilities.
shape = [batch_size, num_speculative_tokens]
shape = [batch_size, num_speculative_tokens]
seeded_seqs: Dict of batch row index to torch generator, for
sequences using seeded generation.
Returns:
Returns:
output_token_ids: The token ids sampled via rejection sampling,
output_token_ids: The token ids sampled via rejection sampling,
or -1 if unable to sample a token because the previous token
or -1 if unable to sample a token because the previous token
...
@@ -83,7 +86,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
...
@@ -83,7 +86,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
target_probs
,
target_probs
,
draft_probs
,
draft_probs
,
draft_token_ids
,
draft_token_ids
,
generator
s
,
seeded_seq
s
,
))
))
output_token_ids
=
self
.
_create_output
(
output_token_ids
=
self
.
_create_output
(
...
@@ -100,7 +103,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
...
@@ -100,7 +103,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
target_probs
:
torch
.
Tensor
,
# [batch_size, k, vocab_size]
target_probs
:
torch
.
Tensor
,
# [batch_size, k, vocab_size]
draft_probs
:
torch
.
Tensor
,
# [batch_size, k, vocab_size]
draft_probs
:
torch
.
Tensor
,
# [batch_size, k, vocab_size]
draft_token_ids
:
torch
.
Tensor
,
# [batch_size, k]
draft_token_ids
:
torch
.
Tensor
,
# [batch_size, k]
generators
:
List
[
Optional
[
torch
.
Generator
]],
seeded_seqs
:
Optional
[
Dict
[
int
,
torch
.
Generator
]],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""Perform modified rejection sampling on each sequence.
"""Perform modified rejection sampling on each sequence.
...
@@ -117,23 +120,17 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
...
@@ -117,23 +120,17 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
# shape [batch_size, k]
# shape [batch_size, k]
accepted
=
self
.
_get_accepted
(
target_probs
,
draft_probs
,
accepted
=
self
.
_get_accepted
(
target_probs
,
draft_probs
,
draft_token_ids
,
generator
s
)
draft_token_ids
,
seeded_seq
s
)
recovered_probs
=
self
.
_get_recovered_probs
(
recovered_probs
=
self
.
_get_recovered_probs
(
target_probs
,
draft_probs
).
reshape
(
batch_size
*
k
,
vocab_size
)
target_probs
,
draft_probs
).
reshape
(
batch_size
*
k
,
vocab_size
)
seed_indices
,
non_seed_indices
=
self
.
_split_batch_by_seeded
(
generators
,
k
=
k
)
# NOTE: the recovered_probs are overwritten by this method.
# NOTE: the recovered_probs are overwritten by this method.
recovered_token_ids
=
_multinomial
(
recovered_token_ids
=
_multinomial
(
recovered_probs
,
recovered_probs
,
num_samples
=
1
,
num_samples
=
1
,
k
=
k
,
k
=
k
,
generators
=
generators
,
seeded_seqs
=
seeded_seqs
or
{},
seed_indices
=
seed_indices
,
# this arg is unused when None but torch.jit requires a list
non_seed_indices
=
non_seed_indices
or
[],
).
reshape
(
batch_size
,
k
)
).
reshape
(
batch_size
,
k
)
return
accepted
,
recovered_token_ids
return
accepted
,
recovered_token_ids
...
@@ -143,7 +140,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
...
@@ -143,7 +140,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
target_probs
:
torch
.
Tensor
,
# [batch_size, k, vocab_size]
target_probs
:
torch
.
Tensor
,
# [batch_size, k, vocab_size]
draft_probs
:
torch
.
Tensor
,
# [batch_size, k, vocab_size]
draft_probs
:
torch
.
Tensor
,
# [batch_size, k, vocab_size]
draft_token_ids
:
torch
.
Tensor
,
# [batch_size, k]
draft_token_ids
:
torch
.
Tensor
,
# [batch_size, k]
generators
:
List
[
Optional
[
torch
.
Generator
]],
seeded_seqs
:
Optional
[
Dict
[
int
,
torch
.
Generator
]],
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
r
"""Create bool matrix over the proposed draft tokens. If
r
"""Create bool matrix over the proposed draft tokens. If
True, then a token can be accepted, else it should be
True, then a token can be accepted, else it should be
...
@@ -178,24 +175,26 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
...
@@ -178,24 +175,26 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
selected_target_probs
=
target_probs
[
batch_indices
,
probs_indicies
,
selected_target_probs
=
target_probs
[
batch_indices
,
probs_indicies
,
draft_token_ids
]
draft_token_ids
]
seed_indices
,
non_seed_indices
=
self
.
_split_batch_by_seeded
(
if
not
seeded_seqs
:
generators
)
if
len
(
seed_indices
)
==
0
:
uniform_rand
=
torch
.
rand_like
(
selected_target_probs
)
uniform_rand
=
torch
.
rand_like
(
selected_target_probs
)
else
:
else
:
uniform_rand
=
torch
.
empty_like
(
selected_target_probs
)
uniform_rand
=
torch
.
empty_like
(
selected_target_probs
)
for
idx
in
seed_indices
:
non_seeded_indices
=
[]
uniform_rand
[
idx
,
:]
=
torch
.
rand
(
1
,
for
idx
in
range
(
batch_size
):
generator
=
seeded_seqs
.
get
(
idx
)
if
generator
is
None
:
non_seeded_indices
.
append
(
idx
)
else
:
uniform_rand
[
idx
,
:]
=
torch
.
rand
(
1
,
k
,
k
,
dtype
=
self
.
probs_dtype
,
dtype
=
self
.
probs_dtype
,
device
=
target_probs
.
device
,
device
=
target_probs
.
device
,
generator
=
generators
[
idx
])
generator
=
generator
)
if
non_seeded_indices
:
if
non_seed_indices
:
uniform_rand
[
non_seeded_indices
,
:]
=
torch
.
rand
(
uniform_rand
[
non_seed_indices
,
:]
=
torch
.
rand
(
len
(
non_seeded_indices
),
len
(
non_seed_indices
),
k
,
k
,
dtype
=
self
.
probs_dtype
,
dtype
=
self
.
probs_dtype
,
device
=
target_probs
.
device
)
device
=
target_probs
.
device
)
...
@@ -272,27 +271,6 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
...
@@ -272,27 +271,6 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
"""
"""
return
torch
.
finfo
(
self
.
probs_dtype
).
tiny
return
torch
.
finfo
(
self
.
probs_dtype
).
tiny
# partition batch into indices for which a generator is provided
# and indicies for which no generator is provided
@
staticmethod
def
_split_batch_by_seeded
(
generators
:
List
[
Optional
[
torch
.
Generator
]],
k
:
int
=
1
,
)
->
Tuple
[
List
[
int
],
Optional
[
List
[
int
]]]:
if
all
(
generator
is
None
for
generator
in
generators
):
seed_indices
:
List
[
int
]
=
[]
non_seed_indices
:
Optional
[
List
[
int
]]
=
None
else
:
seed_indices
,
non_seed_indices
=
[],
[]
for
i
,
generator
in
enumerate
(
generators
):
if
generator
is
None
:
non_seed_indices
.
extend
(
range
(
k
*
i
,
k
*
(
i
+
1
)))
else
:
seed_indices
.
extend
(
range
(
k
*
i
,
k
*
(
i
+
1
)))
return
seed_indices
,
non_seed_indices
# torch.multinomial forces a GPU<->CPU sync.
# torch.multinomial forces a GPU<->CPU sync.
# Therefore, we use an optimized implementation instead that skips the sync.
# Therefore, we use an optimized implementation instead that skips the sync.
...
@@ -304,9 +282,7 @@ def _multinomial(
...
@@ -304,9 +282,7 @@ def _multinomial(
probs
:
torch
.
Tensor
,
probs
:
torch
.
Tensor
,
num_samples
:
int
,
num_samples
:
int
,
k
:
int
,
k
:
int
,
generators
:
List
[
Optional
[
torch
.
Generator
]],
seeded_seqs
:
Dict
[
int
,
torch
.
Generator
],
seed_indices
:
List
[
int
],
non_seed_indices
:
List
[
int
],
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
num_samples
>
1
:
if
num_samples
>
1
:
...
@@ -315,13 +291,20 @@ def _multinomial(
...
@@ -315,13 +291,20 @@ def _multinomial(
probs
=
probs
[:,
None
,
:].
expand
(
probs
.
shape
[
0
],
num_samples
,
probs
=
probs
[:,
None
,
:].
expand
(
probs
.
shape
[
0
],
num_samples
,
probs
.
shape
[
1
]).
contiguous
().
view
(
probs
.
shape
[
1
]).
contiguous
().
view
(
-
1
,
probs
.
shape
[
1
])
-
1
,
probs
.
shape
[
1
])
q
=
torch
.
empty_like
(
probs
)
q
=
torch
.
empty_like
(
probs
)
if
len
(
seed_indices
)
==
0
:
if
not
seeded_seqs
:
q
.
exponential_
(
1.0
)
q
.
exponential_
(
1.0
)
else
:
else
:
q
[
non_seed_indices
].
exponential_
(
1.0
)
non_seeded_indices
:
List
[
int
]
=
[]
for
idx
in
seed_indices
:
start
=
0
q
[
idx
].
exponential_
(
1.0
,
generator
=
generators
[
idx
//
k
])
for
idx
in
range
(
len
(
q
)
//
k
):
end
=
start
+
k
generator
=
seeded_seqs
.
get
(
idx
)
if
generator
is
None
:
non_seeded_indices
.
extend
(
list
(
range
(
start
,
end
)))
else
:
q
[
start
:
end
].
exponential_
(
1.0
,
generator
=
generator
)
start
=
end
q
[
non_seeded_indices
].
exponential_
(
1.0
)
return
probs
.
div_
(
q
).
argmax
(
dim
=
1
).
view
(
-
1
,
num_samples
)
return
probs
.
div_
(
q
).
argmax
(
dim
=
1
).
view
(
-
1
,
num_samples
)
vllm/model_executor/layers/spec_decode_base_sampler.py
View file @
5cf9254a
from
abc
import
abstractmethod
from
abc
import
abstractmethod
from
typing
import
Lis
t
,
Optional
from
typing
import
Dic
t
,
Optional
import
torch
import
torch
import
torch.jit
import
torch.jit
...
@@ -237,6 +237,6 @@ class SpecDecodeStochasticBaseSampler(SpecDecodeBaseSampler):
...
@@ -237,6 +237,6 @@ class SpecDecodeStochasticBaseSampler(SpecDecodeBaseSampler):
bonus_token_ids
:
torch
.
Tensor
,
bonus_token_ids
:
torch
.
Tensor
,
draft_probs
:
torch
.
Tensor
,
draft_probs
:
torch
.
Tensor
,
draft_token_ids
:
torch
.
Tensor
,
draft_token_ids
:
torch
.
Tensor
,
generators
:
List
[
Optional
[
torch
.
Generator
]],
seeded_seqs
:
Optional
[
Dict
[
int
,
torch
.
Generator
]]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
raise
NotImplementedError
raise
NotImplementedError
vllm/model_executor/sampling_metadata.py
View file @
5cf9254a
...
@@ -118,6 +118,7 @@ class SamplingMetadata:
...
@@ -118,6 +118,7 @@ class SamplingMetadata:
query_lens
:
Optional
[
List
[
int
]],
query_lens
:
Optional
[
List
[
int
]],
device
:
str
,
device
:
str
,
pin_memory
:
bool
,
pin_memory
:
bool
,
generators
:
Optional
[
Dict
[
str
,
torch
.
Generator
]]
=
None
,
)
->
"SamplingMetadata"
:
)
->
"SamplingMetadata"
:
(
(
seq_groups
,
seq_groups
,
...
@@ -125,7 +126,7 @@ class SamplingMetadata:
...
@@ -125,7 +126,7 @@ class SamplingMetadata:
categorized_sample_indices
,
categorized_sample_indices
,
num_prompts
,
num_prompts
,
)
=
_prepare_seq_groups
(
seq_group_metadata_list
,
seq_lens
,
query_lens
,
)
=
_prepare_seq_groups
(
seq_group_metadata_list
,
seq_lens
,
query_lens
,
device
)
device
,
generators
)
selected_token_indices
=
async_tensor_h2d
(
selected_token_indices
,
selected_token_indices
=
async_tensor_h2d
(
selected_token_indices
,
dtype
=
torch
.
long
,
dtype
=
torch
.
long
,
target_device
=
device
,
target_device
=
device
,
...
@@ -160,6 +161,7 @@ def _prepare_seq_groups(
...
@@ -160,6 +161,7 @@ def _prepare_seq_groups(
seq_lens
:
List
[
int
],
seq_lens
:
List
[
int
],
query_lens
:
Optional
[
List
[
int
]],
query_lens
:
Optional
[
List
[
int
]],
device
:
str
,
device
:
str
,
generators
:
Optional
[
Dict
[
str
,
torch
.
Generator
]]
=
None
,
)
->
Tuple
[
List
[
SequenceGroupToSample
],
List
[
int
],
Dict
[
)
->
Tuple
[
List
[
SequenceGroupToSample
],
List
[
int
],
Dict
[
SamplingType
,
List
[
Tuple
[
int
,
int
]]],
int
]:
SamplingType
,
List
[
Tuple
[
int
,
int
]]],
int
]:
"""Prepare sequence groups and indices for sampling.
"""Prepare sequence groups and indices for sampling.
...
@@ -170,8 +172,10 @@ def _prepare_seq_groups(
...
@@ -170,8 +172,10 @@ def _prepare_seq_groups(
Index of prompt len should match with seq_group_metadata_list.
Index of prompt len should match with seq_group_metadata_list.
query_lens: A list of query lengths. Prompt lens include the length
query_lens: A list of query lengths. Prompt lens include the length
of entire prompt tokens, and it could be shorter.
of entire prompt tokens, and it could be shorter.
device: A device to use for random number generator,
device: A device to use for random number generator
s
,
`SequenceGroupToSample.generator`.
`SequenceGroupToSample.generator`.
generators: A store of per-request random number generators used
for seeded requests.
Returns:
Returns:
seq_groups: A list of sequence group to sample.
seq_groups: A list of sequence group to sample.
...
@@ -217,8 +221,10 @@ def _prepare_seq_groups(
...
@@ -217,8 +221,10 @@ def _prepare_seq_groups(
if
seq_group_metadata
.
is_prompt
:
if
seq_group_metadata
.
is_prompt
:
if
sampling_params
.
seed
is
not
None
:
if
sampling_params
.
seed
is
not
None
:
seq_group_metadata
.
state
.
generator
=
torch
.
Generator
(
generator
=
torch
.
Generator
(
device
=
device
).
manual_seed
(
device
=
device
).
manual_seed
(
sampling_params
.
seed
)
sampling_params
.
seed
)
if
generators
is
not
None
:
generators
[
seq_group_metadata
.
request_id
]
=
generator
num_prompts
+=
1
num_prompts
+=
1
num_prefill_sample
=
len
(
seq_ids
)
num_prefill_sample
=
len
(
seq_ids
)
...
@@ -235,6 +241,9 @@ def _prepare_seq_groups(
...
@@ -235,6 +241,9 @@ def _prepare_seq_groups(
prompt_logprob_len
=
0
prompt_logprob_len
=
0
sample_len
=
len
(
seq_ids
)
if
do_sample
else
0
sample_len
=
len
(
seq_ids
)
if
do_sample
else
0
if
sampling_params
.
seed
is
not
None
and
generators
is
not
None
:
generator
=
generators
.
get
(
seq_group_metadata
.
request_id
)
# Update indices to select from the model output.
# Update indices to select from the model output.
"""
"""
This blocks computes selected_token_indices which is used in the
This blocks computes selected_token_indices which is used in the
...
@@ -279,9 +288,6 @@ def _prepare_seq_groups(
...
@@ -279,9 +288,6 @@ def _prepare_seq_groups(
logit_idx
+=
sample_len
logit_idx
+=
sample_len
sample_idx
+=
sample_len
sample_idx
+=
sample_len
if
sampling_params
.
seed
is
not
None
:
generator
=
seq_group_metadata
.
state
.
generator
seq_groups
.
append
(
seq_groups
.
append
(
SequenceGroupToSample
(
SequenceGroupToSample
(
seq_ids
=
seq_ids
,
seq_ids
=
seq_ids
,
...
...
vllm/sequence.py
View file @
5cf9254a
...
@@ -411,14 +411,6 @@ class Sequence:
...
@@ -411,14 +411,6 @@ class Sequence:
f
"num_blocks=
{
self
.
n_blocks
}
, "
)
f
"num_blocks=
{
self
.
n_blocks
}
, "
)
@
dataclass
class
SequenceGroupState
:
"""Mutable state tied to a specific sequence group"""
# torch.Generator used in seeded sampling
generator
:
Optional
=
None
# type: ignore
class
SequenceGroup
:
class
SequenceGroup
:
"""A group of sequences that are generated from the same prompt.
"""A group of sequences that are generated from the same prompt.
...
@@ -461,7 +453,6 @@ class SequenceGroup:
...
@@ -461,7 +453,6 @@ class SequenceGroup:
time_in_queue
=
None
)
time_in_queue
=
None
)
self
.
lora_request
=
lora_request
self
.
lora_request
=
lora_request
self
.
prompt_logprobs
:
Optional
[
PromptLogprobs
]
=
None
self
.
prompt_logprobs
:
Optional
[
PromptLogprobs
]
=
None
self
.
state
=
SequenceGroupState
()
self
.
embeddings
=
embeddings
self
.
embeddings
=
embeddings
self
.
pooling_params
=
pooling_params
self
.
pooling_params
=
pooling_params
self
.
prompt_adapter_request
=
prompt_adapter_request
self
.
prompt_adapter_request
=
prompt_adapter_request
...
@@ -648,7 +639,6 @@ class SequenceGroupMetadata:
...
@@ -648,7 +639,6 @@ class SequenceGroupMetadata:
lora_request: LoRA request.
lora_request: LoRA request.
computed_block_nums: The block numbers that are already computed,
computed_block_nums: The block numbers that are already computed,
used in prefix caching.
used in prefix caching.
state: Internal state tied to this sequence group.
multi_modal_data: Multi modal data.
multi_modal_data: Multi modal data.
encoder_seq_data: Optional sequence data for encoder prompt
encoder_seq_data: Optional sequence data for encoder prompt
(SequenceGroup.encoder_seq). Should be None
(SequenceGroup.encoder_seq). Should be None
...
@@ -674,7 +664,6 @@ class SequenceGroupMetadata:
...
@@ -674,7 +664,6 @@ class SequenceGroupMetadata:
token_chunk_size
:
Optional
[
int
]
=
None
,
token_chunk_size
:
Optional
[
int
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
computed_block_nums
:
Optional
[
List
[
int
]]
=
None
,
computed_block_nums
:
Optional
[
List
[
int
]]
=
None
,
state
:
Optional
[
SequenceGroupState
]
=
None
,
multi_modal_data
:
Optional
[
"MultiModalDataDict"
]
=
None
,
multi_modal_data
:
Optional
[
"MultiModalDataDict"
]
=
None
,
encoder_seq_data
:
Optional
[
SequenceData
]
=
None
,
encoder_seq_data
:
Optional
[
SequenceData
]
=
None
,
cross_block_table
:
Optional
[
List
[
int
]]
=
None
,
cross_block_table
:
Optional
[
List
[
int
]]
=
None
,
...
@@ -690,7 +679,6 @@ class SequenceGroupMetadata:
...
@@ -690,7 +679,6 @@ class SequenceGroupMetadata:
self
.
prompt_adapter_request
=
prompt_adapter_request
self
.
prompt_adapter_request
=
prompt_adapter_request
self
.
computed_block_nums
=
computed_block_nums
self
.
computed_block_nums
=
computed_block_nums
self
.
multi_modal_data
=
multi_modal_data
self
.
multi_modal_data
=
multi_modal_data
self
.
state
=
SequenceGroupState
()
if
state
is
None
else
state
self
.
encoder_seq_data
=
encoder_seq_data
self
.
encoder_seq_data
=
encoder_seq_data
self
.
cross_block_table
=
cross_block_table
self
.
cross_block_table
=
cross_block_table
self
.
_token_chunk_size
=
token_chunk_size
self
.
_token_chunk_size
=
token_chunk_size
...
...
vllm/spec_decode/batch_expansion.py
View file @
5cf9254a
...
@@ -3,9 +3,9 @@ from typing import Iterator, List, Tuple
...
@@ -3,9 +3,9 @@ from typing import Iterator, List, Tuple
import
torch
import
torch
from
vllm
import
SamplingParams
from
vllm.sequence
import
(
ExecuteModelRequest
,
SamplerOutput
,
SequenceData
,
from
vllm.sequence
import
(
ExecuteModelRequest
,
SamplerOutput
,
SequenceData
,
SequenceGroupMetadata
,
SequenceGroupState
,
SequenceGroupMetadata
,
get_all_seq_ids
)
get_all_seq_ids
)
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeScorer
,
SpeculativeScores
)
SpeculativeScorer
,
SpeculativeScores
)
from
vllm.spec_decode.util
import
(
nvtx_range
,
sampler_output_to_torch
,
from
vllm.spec_decode.util
import
(
nvtx_range
,
sampler_output_to_torch
,
...
@@ -16,6 +16,8 @@ SeqId = int
...
@@ -16,6 +16,8 @@ SeqId = int
TargetSeqId
=
int
TargetSeqId
=
int
TokenId
=
int
TokenId
=
int
DEFAULT_SIMPLE_SAMPLING_PARAMS
=
SamplingParams
()
class
BatchExpansionTop1Scorer
(
SpeculativeScorer
):
class
BatchExpansionTop1Scorer
(
SpeculativeScorer
):
"""Implements a speculative scorer that uses batch expansion to get
"""Implements a speculative scorer that uses batch expansion to get
...
@@ -247,24 +249,39 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -247,24 +249,39 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
token_ids_to_score
=
self
.
_get_token_ids_to_score
(
token_ids_to_score
=
self
.
_get_token_ids_to_score
(
proposal_token_ids
[
batch_index
])
proposal_token_ids
[
batch_index
])
# Use simpler sampling parameters apart from for final token
# (in particular don't do seeded sampling) since those sampled tokens
# aren't used.
# We don't replace the sampling_params in the greedy case because
# this also controls whether the probs get modified in the sampler
# (see use of _modify_greedy_probs_inplace there).
sampling_params
=
input_seq_group_metadata
.
sampling_params
non_bonus_sampling_params
=
DEFAULT_SIMPLE_SAMPLING_PARAMS
\
if
sampling_params
.
temperature
else
sampling_params
target_seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
=
[]
target_seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
=
[]
for
token_ids
in
token_ids_to_score
:
last_index
=
len
(
token_ids_to_score
)
-
1
for
i
,
token_ids
in
enumerate
(
token_ids_to_score
):
target_sampling_params
=
sampling_params
if
i
==
last_index
\
else
non_bonus_sampling_params
target_seq_group_metadata_list
.
append
(
target_seq_group_metadata_list
.
append
(
self
.
_create_single_target_seq_group_metadata
(
self
.
_create_single_target_seq_group_metadata
(
input_seq_group_metadata
,
input_seq_group_metadata
,
input_seq_id
,
input_seq_id
,
next
(
target_seq_ids_iter
),
next
(
target_seq_ids_iter
),
token_ids
,
token_ids
,
sampling_params
=
target_sampling_params
,
))
))
return
target_seq_group_metadata_list
return
target_seq_group_metadata_list
@
staticmethod
def
_create_single_target_seq_group_metadata
(
def
_create_single_target_seq_group_metadata
(
self
,
seq_group_metadata
:
SequenceGroupMetadata
,
seq_group_metadata
:
SequenceGroupMetadata
,
seq_id
:
SeqId
,
seq_id
:
SeqId
,
target_seq_id
:
TargetSeqId
,
target_seq_id
:
TargetSeqId
,
token_ids
:
List
[
TokenId
],
token_ids
:
List
[
TokenId
],
sampling_params
:
SamplingParams
,
)
->
SequenceGroupMetadata
:
)
->
SequenceGroupMetadata
:
"""Create a single target SequenceGroupMetadata.
"""Create a single target SequenceGroupMetadata.
...
@@ -293,26 +310,16 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -293,26 +310,16 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
for
data
in
new_seq_data_dict
.
values
():
for
data
in
new_seq_data_dict
.
values
():
data
.
update_num_computed_tokens
(
data
.
get_len
()
-
1
)
data
.
update_num_computed_tokens
(
data
.
get_len
()
-
1
)
if
(
seq_group_metadata
.
state
is
not
None
and
seq_group_metadata
.
state
.
generator
is
not
None
):
generator
=
torch
.
Generator
(
device
=
seq_group_metadata
.
state
.
generator
.
device
)
generator
.
set_state
(
seq_group_metadata
.
state
.
generator
.
get_state
())
state
=
SequenceGroupState
(
generator
=
generator
)
else
:
state
=
None
return
SequenceGroupMetadata
(
return
SequenceGroupMetadata
(
request_id
=
seq_group_metadata
.
request_id
,
request_id
=
seq_group_metadata
.
request_id
,
is_prompt
=
seq_group_metadata
.
is_prompt
,
is_prompt
=
seq_group_metadata
.
is_prompt
,
seq_data
=
new_seq_data_dict
,
seq_data
=
new_seq_data_dict
,
sampling_params
=
seq_group_metadata
.
sampling_params
,
sampling_params
=
sampling_params
,
block_tables
=
{
block_tables
=
{
target_seq_id
:
seq_group_metadata
.
block_tables
[
seq_id
],
target_seq_id
:
seq_group_metadata
.
block_tables
[
seq_id
],
},
},
lora_request
=
None
,
lora_request
=
None
,
token_chunk_size
=
1
,
token_chunk_size
=
1
,
state
=
state
,
)
)
def
_split_scoring_output
(
def
_split_scoring_output
(
...
...
vllm/spec_decode/medusa_worker.py
View file @
5cf9254a
...
@@ -57,9 +57,11 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
...
@@ -57,9 +57,11 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
seq_lens
,
query_lens
=
self
.
_prepare_input_tensors
(
seq_lens
,
query_lens
=
self
.
_prepare_input_tensors
(
seq_group_metadata_list
)
seq_group_metadata_list
)
generators
=
self
.
model_runner
.
get_generators
(
execute_model_req
.
finished_requests_ids
)
sampling_metadata
=
SamplingMetadata
.
prepare
(
sampling_metadata
=
SamplingMetadata
.
prepare
(
seq_group_metadata_list
,
seq_lens
,
query_lens
,
self
.
device
,
seq_group_metadata_list
,
seq_lens
,
query_lens
,
self
.
device
,
self
.
model_runner
.
pin_memory
)
self
.
model_runner
.
pin_memory
,
generators
)
model_outputs
=
self
.
model_runner
.
model
.
generate_proposals
(
model_outputs
=
self
.
model_runner
.
model
.
generate_proposals
(
previous_hidden_states
=
execute_model_req
.
previous_hidden_states
.
previous_hidden_states
=
execute_model_req
.
previous_hidden_states
.
...
...
vllm/spec_decode/mlp_speculator_worker.py
View file @
5cf9254a
...
@@ -38,9 +38,11 @@ class MLPSpeculatorWorker(NonLLMProposerWorkerBase, MultiStepWorker):
...
@@ -38,9 +38,11 @@ class MLPSpeculatorWorker(NonLLMProposerWorkerBase, MultiStepWorker):
(
input_tokens
,
seq_lens
,
(
input_tokens
,
seq_lens
,
query_lens
)
=
self
.
_prepare_input_tensors
(
seq_group_metadata_list
)
query_lens
)
=
self
.
_prepare_input_tensors
(
seq_group_metadata_list
)
generators
=
self
.
model_runner
.
get_generators
(
execute_model_req
.
finished_requests_ids
)
sampling_metadata
=
SamplingMetadata
.
prepare
(
sampling_metadata
=
SamplingMetadata
.
prepare
(
seq_group_metadata_list
,
seq_lens
,
query_lens
,
self
.
device
,
seq_group_metadata_list
,
seq_lens
,
query_lens
,
self
.
device
,
self
.
model_runner
.
pin_memory
)
self
.
model_runner
.
pin_memory
,
generators
)
model_outputs
=
self
.
model_runner
.
model
.
generate_proposals
(
model_outputs
=
self
.
model_runner
.
model
.
generate_proposals
(
input_ids
=
input_tokens
,
input_ids
=
input_tokens
,
...
...
vllm/spec_decode/ngram_worker.py
View file @
5cf9254a
...
@@ -7,10 +7,9 @@ from vllm.sequence import ExecuteModelRequest, SamplerOutput
...
@@ -7,10 +7,9 @@ from vllm.sequence import ExecuteModelRequest, SamplerOutput
from
vllm.spec_decode.interfaces
import
SpeculativeProposals
from
vllm.spec_decode.interfaces
import
SpeculativeProposals
from
vllm.spec_decode.proposer_worker_base
import
NonLLMProposerWorkerBase
from
vllm.spec_decode.proposer_worker_base
import
NonLLMProposerWorkerBase
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
vllm.worker.worker_base
import
LoraNotSupportedWorkerBase
class
NGramWorker
(
NonLLMProposerWorkerBase
,
LoraNotSupportedWorkerBase
):
class
NGramWorker
(
NonLLMProposerWorkerBase
):
"""NGramWorker provides a light drafter without need for model.
"""NGramWorker provides a light drafter without need for model.
Current NGramWorker only implements prompt lookup decoding,
Current NGramWorker only implements prompt lookup decoding,
...
...
vllm/spec_decode/spec_decode_worker.py
View file @
5cf9254a
...
@@ -213,6 +213,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -213,6 +213,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
"""
"""
self
.
proposer_worker
=
proposer_worker
self
.
proposer_worker
=
proposer_worker
self
.
scorer_worker
=
scorer_worker
self
.
scorer_worker
=
scorer_worker
scorer_runner
=
getattr
(
self
.
scorer_worker
,
"model_runner"
,
None
)
self
.
generators
=
scorer_runner
.
get_generators
(
)
if
scorer_runner
else
None
self
.
disable_by_batch_size
=
disable_by_batch_size
or
float
(
"inf"
)
self
.
disable_by_batch_size
=
disable_by_batch_size
or
float
(
"inf"
)
self
.
spec_decode_sampler
=
spec_decode_sampler
self
.
spec_decode_sampler
=
spec_decode_sampler
self
.
_allow_zero_draft_token_step
=
allow_zero_draft_token_step
self
.
_allow_zero_draft_token_step
=
allow_zero_draft_token_step
...
@@ -591,20 +594,14 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -591,20 +594,14 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
proposal_token_ids
=
proposals
.
proposal_token_ids
[
spec_indices
]
proposal_token_ids
=
proposals
.
proposal_token_ids
[
spec_indices
]
# Sampler arguments
# Sampler arguments
sampler_extra_kwargs
=
{}
sampler_extra_kwargs
:
Dict
[
str
,
Any
]
=
{}
if
isinstance
(
self
.
spec_decode_sampler
,
if
self
.
generators
and
isinstance
(
self
.
spec_decode_sampler
,
SpecDecodeStochasticBaseSampler
):
SpecDecodeStochasticBaseSampler
):
sampler_extra_kwargs
[
"seeded_seqs"
]
=
{
# Get sequence group state
idx
:
self
.
generators
[
sgm
.
request_id
]
generators
=
[]
for
idx
,
sgm
in
enumerate
(
seq_group_metadata_list
)
for
seq_group_metadata
in
seq_group_metadata_list
:
if
sgm
.
sampling_params
.
seed
is
not
None
if
(
seq_group_metadata
.
state
is
not
None
}
and
seq_group_metadata
.
state
.
generator
is
not
None
):
generators
.
append
(
seq_group_metadata
.
state
.
generator
)
else
:
generators
.
append
(
None
)
sampler_extra_kwargs
[
"generators"
]
=
generators
accepted_token_ids
=
self
.
spec_decode_sampler
(
accepted_token_ids
=
self
.
spec_decode_sampler
(
target_probs
=
proposal_verifier_probs
,
target_probs
=
proposal_verifier_probs
,
...
...
vllm/worker/cpu_model_runner.py
View file @
5cf9254a
...
@@ -337,7 +337,8 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
...
@@ -337,7 +337,8 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
# just use seq_lens instead.
# just use seq_lens instead.
seq_lens
,
seq_lens
,
self
.
device
,
self
.
device
,
pin_memory
=
False
)
pin_memory
=
False
,
generators
=
self
.
get_generators
(
finished_requests_ids
))
return
CPUModelInput
(
return
CPUModelInput
(
input_tokens
=
input_tokens
,
input_tokens
=
input_tokens
,
input_positions
=
input_positions
,
input_positions
=
input_positions
,
...
...
vllm/worker/model_runner.py
View file @
5cf9254a
...
@@ -1264,11 +1264,15 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
...
@@ -1264,11 +1264,15 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
"""
"""
model_input
=
self
.
_prepare_model_input_tensors
(
model_input
=
self
.
_prepare_model_input_tensors
(
seq_group_metadata_list
,
finished_requests_ids
)
seq_group_metadata_list
,
finished_requests_ids
)
sampling_metadata
=
SamplingMetadata
.
prepare
(
seq_group_metadata_list
,
if
get_pp_group
().
is_last_rank
:
model_input
.
seq_lens
,
# Sampling metadata is only required for the final pp group
model_input
.
query_lens
,
generators
=
self
.
get_generators
(
finished_requests_ids
)
self
.
device
,
sampling_metadata
=
SamplingMetadata
.
prepare
(
self
.
pin_memory
)
seq_group_metadata_list
,
model_input
.
seq_lens
,
model_input
.
query_lens
,
self
.
device
,
self
.
pin_memory
,
generators
)
else
:
sampling_metadata
=
None
is_prompt
=
(
seq_group_metadata_list
[
0
].
is_prompt
is_prompt
=
(
seq_group_metadata_list
[
0
].
is_prompt
if
seq_group_metadata_list
else
None
)
if
seq_group_metadata_list
else
None
)
return
dataclasses
.
replace
(
model_input
,
return
dataclasses
.
replace
(
model_input
,
...
...
vllm/worker/model_runner_base.py
View file @
5cf9254a
...
@@ -139,6 +139,9 @@ class ModelRunnerBase(ABC, Generic[T]):
...
@@ -139,6 +139,9 @@ class ModelRunnerBase(ABC, Generic[T]):
ModelRunnerInputBase subclass.
ModelRunnerInputBase subclass.
"""
"""
# Map of request_id -> generator used for seeded random sampling
generators
:
Dict
[
str
,
torch
.
Generator
]
=
{}
@
abstractmethod
@
abstractmethod
def
make_model_input_from_broadcasted_tensor_dict
(
def
make_model_input_from_broadcasted_tensor_dict
(
self
,
self
,
...
@@ -176,3 +179,15 @@ class ModelRunnerBase(ABC, Generic[T]):
...
@@ -176,3 +179,15 @@ class ModelRunnerBase(ABC, Generic[T]):
Execute the model on the given input.
Execute the model on the given input.
"""
"""
raise
NotImplementedError
raise
NotImplementedError
def
get_generators
(
self
,
finished_request_ids
:
Optional
[
List
[
str
]]
=
None
):
"""
Return dict of per-request generators used for random sampling.
"""
# Clean up generators from completed requests
if
finished_request_ids
:
for
request_id
in
finished_request_ids
:
self
.
generators
.
pop
(
request_id
,
None
)
return
self
.
generators
vllm/worker/neuron_model_runner.py
View file @
5cf9254a
...
@@ -219,7 +219,8 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
...
@@ -219,7 +219,8 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
# just use seq_lens instead.
# just use seq_lens instead.
seq_lens
,
seq_lens
,
self
.
device
,
self
.
device
,
self
.
pin_memory
)
self
.
pin_memory
,
generators
=
self
.
get_generators
(
finished_requests_ids
))
return
ModelInputForNeuron
(
input_tokens
=
input_tokens
,
return
ModelInputForNeuron
(
input_tokens
=
input_tokens
,
input_positions
=
input_positions
,
input_positions
=
input_positions
,
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment