Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5cf9254a
Unverified
Commit
5cf9254a
authored
Jul 30, 2024
by
Nick Hill
Committed by
GitHub
Jul 30, 2024
Browse files
[BugFix] Fix use of per-request seed with pipeline parallel (#6698)
parent
f0584036
Changes
21
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
220 additions
and
136 deletions
+220
-136
tests/samplers/test_rejection_sampler.py
tests/samplers/test_rejection_sampler.py
+8
-15
tests/samplers/test_sampler.py
tests/samplers/test_sampler.py
+4
-1
tests/spec_decode/e2e/test_mlp_correctness.py
tests/spec_decode/e2e/test_mlp_correctness.py
+53
-1
tests/spec_decode/e2e/test_seed.py
tests/spec_decode/e2e/test_seed.py
+1
-1
tests/spec_decode/test_batch_expansion.py
tests/spec_decode/test_batch_expansion.py
+1
-0
tests/utils.py
tests/utils.py
+31
-0
vllm/core/scheduler.py
vllm/core/scheduler.py
+0
-1
vllm/model_executor/layers/rejection_sampler.py
vllm/model_executor/layers/rejection_sampler.py
+39
-56
vllm/model_executor/layers/spec_decode_base_sampler.py
vllm/model_executor/layers/spec_decode_base_sampler.py
+2
-2
vllm/model_executor/sampling_metadata.py
vllm/model_executor/sampling_metadata.py
+13
-7
vllm/sequence.py
vllm/sequence.py
+0
-12
vllm/spec_decode/batch_expansion.py
vllm/spec_decode/batch_expansion.py
+22
-15
vllm/spec_decode/medusa_worker.py
vllm/spec_decode/medusa_worker.py
+3
-1
vllm/spec_decode/mlp_speculator_worker.py
vllm/spec_decode/mlp_speculator_worker.py
+3
-1
vllm/spec_decode/ngram_worker.py
vllm/spec_decode/ngram_worker.py
+1
-2
vllm/spec_decode/spec_decode_worker.py
vllm/spec_decode/spec_decode_worker.py
+11
-14
vllm/worker/cpu_model_runner.py
vllm/worker/cpu_model_runner.py
+2
-1
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+9
-5
vllm/worker/model_runner_base.py
vllm/worker/model_runner_base.py
+15
-0
vllm/worker/neuron_model_runner.py
vllm/worker/neuron_model_runner.py
+2
-1
No files found.
tests/samplers/test_rejection_sampler.py
View file @
5cf9254a
...
...
@@ -150,10 +150,9 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
high
=
vocab_size
,
size
=
(
batch_size
,
k
),
dtype
=
torch
.
int64
)
generators
=
[
None
]
*
batch_size
rejection_sampler
(
target_probs
,
bonus_token_ids
,
draft_probs
,
draft_token_ids
,
generators
)
draft_token_ids
)
@
pytest
.
mark
.
parametrize
(
"frac_seeded"
,
[
0.0
,
0.25
,
0.5
,
1.0
])
...
...
@@ -185,14 +184,13 @@ def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
results
=
[]
for
_
in
range
(
n_rep
):
generators
=
[
torch
.
Generator
(
device
=
device
).
manual_seed
(
i
)
if
seeded_mask
[
i
]
else
None
for
i
in
range
(
batch_size
)
]
seeded_seqs
=
{
i
:
torch
.
Generator
(
device
=
device
).
manual_seed
(
i
)
for
i
in
range
(
batch_size
)
if
seeded_mask
[
i
]
}
results
.
append
(
rejection_sampler
(
target_probs
,
bonus_token_ids
,
draft_probs
,
draft_token_ids
,
generator
s
))
draft_token_ids
,
seeded_seq
s
))
for
i
in
range
(
batch_size
):
if
seeded_mask
[
i
]:
...
...
@@ -242,11 +240,10 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
raise
AssertionError
()
oob_token_ids
[
0
][
0
]
=
rogue_token_id
generators
=
[
None
]
*
batch_size
with
pytest
.
raises
(
AssertionError
):
rejection_sampler
(
target_probs
,
bonus_token_ids
,
draft_probs
,
draft_token_ids
,
generators
)
draft_token_ids
)
@
pytest
.
mark
.
parametrize
(
"draft_and_target_probs_equal"
,
[
True
,
False
])
...
...
@@ -417,15 +414,11 @@ class _CorrectnessTestHelper:
dtype
=
torch
.
int64
,
device
=
"cuda"
).
repeat
(
num_samples
,
1
)
# unseeded
generators
=
[
None
]
# Get output tokens via rejection sampling.
output_token_ids
=
self
.
rejection_sampler
(
target_probs
.
to
(
"cuda"
),
bonus_token_ids
.
to
(
"cuda"
),
draft_probs
.
to
(
"cuda"
),
draft_token_ids
.
to
(
"cuda"
),
generators
)
draft_token_ids
.
to
(
"cuda"
))
# Remove bonus tokens
output_token_ids
=
output_token_ids
[:,
:
-
1
].
flatten
()
...
...
tests/samplers/test_sampler.py
View file @
5cf9254a
...
...
@@ -510,13 +510,16 @@ def test_sampler_mixed(seed: int, device: str):
))
seq_lens
.
append
(
seq_group_metadata_list
[
-
1
].
seq_data
[
0
].
get_len
())
generators
:
Dict
[
str
,
torch
.
Generator
]
=
{}
def
test_sampling
():
sampling_metadata
=
SamplingMetadata
.
prepare
(
seq_group_metadata_list
,
seq_lens
,
query_lens
=
seq_lens
,
device
=
device
,
pin_memory
=
is_pin_memory_available
())
pin_memory
=
is_pin_memory_available
(),
generators
=
generators
)
sampler_output
=
sampler
(
logits
=
fake_logits
,
sampling_metadata
=
sampling_metadata
)
...
...
tests/spec_decode/e2e/test_mlp_correctness.py
View file @
5cf9254a
...
...
@@ -21,7 +21,8 @@ correctess for the target model outputs.
import
pytest
from
.conftest
import
run_greedy_equality_correctness_test
from
.conftest
import
(
run_equality_correctness_test
,
run_greedy_equality_correctness_test
)
# main model
MAIN_MODEL
=
"JackFram/llama-160m"
...
...
@@ -77,6 +78,57 @@ def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator,
force_output_len
=
True
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Required for spec decode.
"use_v2_block_manager"
:
True
,
# Print spec metrics.
"disable_log_stats"
:
False
,
# Precision
"dtype"
:
PRECISION
,
# Main model
"model"
:
MAIN_MODEL
,
# Speculative model
"speculative_model"
:
SPEC_MODEL
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{
"seed"
:
1
}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{
"seed"
:
5
}])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
64
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
32
])
@
pytest
.
mark
.
parametrize
(
"temperature"
,
[
0.1
,
1.0
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
None
])
def
test_mlp_e2e_seeded_correctness
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
,
temperature
:
float
):
"""Verify seeded runs produce the same output."""
run_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
temperature
=
temperature
,
seeded
=
True
,
force_output_len
=
True
)
# Ensure this same test does fail if we _don't_ include per-request seeds
with
pytest
.
raises
(
AssertionError
):
run_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
temperature
=
temperature
,
seeded
=
False
,
force_output_len
=
True
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
...
...
tests/spec_decode/e2e/test_seed.py
View file @
5cf9254a
...
...
@@ -29,7 +29,7 @@ from .conftest import run_equality_correctness_test
"output_len"
,
[
# Use smaller output len for fast test.
1
0
,
2
0
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
None
])
def
test_seeded_consistency
(
baseline_llm_generator
,
test_llm_generator
,
...
...
tests/spec_decode/test_batch_expansion.py
View file @
5cf9254a
...
...
@@ -86,6 +86,7 @@ def test_create_single_target_seq_group_metadata(k: int):
input_seq_id
,
target_seq_id
,
token_ids
,
input_seq_group_metadata
.
sampling_params
,
)
assert
output
.
request_id
==
input_seq_group_metadata
.
request_id
...
...
tests/utils.py
View file @
5cf9254a
...
...
@@ -178,6 +178,37 @@ def compare_two_settings(model: str, arg1: List[str], arg2: List[str]):
"usage"
:
completion
.
usage
,
})
# test seeded random sampling
completion
=
client
.
completions
.
create
(
model
=
model
,
prompt
=
prompt
,
max_tokens
=
5
,
seed
=
33
,
temperature
=
1.0
)
results
.
append
({
"test"
:
"seeded_sampling"
,
"text"
:
completion
.
choices
[
0
].
text
,
"finish_reason"
:
completion
.
choices
[
0
].
finish_reason
,
"usage"
:
completion
.
usage
,
})
# test seeded random sampling with multiple prompts
completion
=
client
.
completions
.
create
(
model
=
model
,
prompt
=
[
prompt
,
prompt
],
max_tokens
=
5
,
seed
=
33
,
temperature
=
1.0
)
results
.
append
({
"test"
:
"seeded_sampling"
,
"text"
:
[
choice
.
text
for
choice
in
completion
.
choices
],
"finish_reason"
:
[
choice
.
finish_reason
for
choice
in
completion
.
choices
],
"usage"
:
completion
.
usage
,
})
# test simple list
batch
=
client
.
completions
.
create
(
model
=
model
,
...
...
vllm/core/scheduler.py
View file @
5cf9254a
...
...
@@ -1029,7 +1029,6 @@ class Scheduler:
token_chunk_size
=
token_chunk_size
,
lora_request
=
seq_group
.
lora_request
,
computed_block_nums
=
common_computed_block_nums
,
state
=
seq_group
.
state
,
# `multi_modal_data` will only be present for the 1st comm
# between engine and worker.
# the subsequent comms can still use delta, but
...
...
vllm/model_executor/layers/rejection_sampler.py
View file @
5cf9254a
from
functools
import
cached_property
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
Dict
,
List
,
Optional
,
Tuple
import
torch
import
torch.jit
...
...
@@ -36,7 +36,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
bonus_token_ids
:
torch
.
Tensor
,
draft_probs
:
torch
.
Tensor
,
draft_token_ids
:
torch
.
Tensor
,
generators
:
List
[
Optional
[
torch
.
Generator
]],
seeded_seqs
:
Optional
[
Dict
[
int
,
torch
.
Generator
]]
=
None
,
)
->
torch
.
Tensor
:
"""Sample token ids using rejection sampling. This accepts or rejects
tokens proposed by the draft model using the probability of each token
...
...
@@ -66,6 +66,9 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
probabilities.
shape = [batch_size, num_speculative_tokens]
seeded_seqs: Dict of batch row index to torch generator, for
sequences using seeded generation.
Returns:
output_token_ids: The token ids sampled via rejection sampling,
or -1 if unable to sample a token because the previous token
...
...
@@ -83,7 +86,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
target_probs
,
draft_probs
,
draft_token_ids
,
generator
s
,
seeded_seq
s
,
))
output_token_ids
=
self
.
_create_output
(
...
...
@@ -100,7 +103,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
target_probs
:
torch
.
Tensor
,
# [batch_size, k, vocab_size]
draft_probs
:
torch
.
Tensor
,
# [batch_size, k, vocab_size]
draft_token_ids
:
torch
.
Tensor
,
# [batch_size, k]
generators
:
List
[
Optional
[
torch
.
Generator
]],
seeded_seqs
:
Optional
[
Dict
[
int
,
torch
.
Generator
]],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""Perform modified rejection sampling on each sequence.
...
...
@@ -117,23 +120,17 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
# shape [batch_size, k]
accepted
=
self
.
_get_accepted
(
target_probs
,
draft_probs
,
draft_token_ids
,
generator
s
)
draft_token_ids
,
seeded_seq
s
)
recovered_probs
=
self
.
_get_recovered_probs
(
target_probs
,
draft_probs
).
reshape
(
batch_size
*
k
,
vocab_size
)
seed_indices
,
non_seed_indices
=
self
.
_split_batch_by_seeded
(
generators
,
k
=
k
)
# NOTE: the recovered_probs are overwritten by this method.
recovered_token_ids
=
_multinomial
(
recovered_probs
,
num_samples
=
1
,
k
=
k
,
generators
=
generators
,
seed_indices
=
seed_indices
,
# this arg is unused when None but torch.jit requires a list
non_seed_indices
=
non_seed_indices
or
[],
seeded_seqs
=
seeded_seqs
or
{},
).
reshape
(
batch_size
,
k
)
return
accepted
,
recovered_token_ids
...
...
@@ -143,7 +140,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
target_probs
:
torch
.
Tensor
,
# [batch_size, k, vocab_size]
draft_probs
:
torch
.
Tensor
,
# [batch_size, k, vocab_size]
draft_token_ids
:
torch
.
Tensor
,
# [batch_size, k]
generators
:
List
[
Optional
[
torch
.
Generator
]],
seeded_seqs
:
Optional
[
Dict
[
int
,
torch
.
Generator
]],
)
->
torch
.
Tensor
:
r
"""Create bool matrix over the proposed draft tokens. If
True, then a token can be accepted, else it should be
...
...
@@ -178,24 +175,26 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
selected_target_probs
=
target_probs
[
batch_indices
,
probs_indicies
,
draft_token_ids
]
seed_indices
,
non_seed_indices
=
self
.
_split_batch_by_seeded
(
generators
)
if
len
(
seed_indices
)
==
0
:
if
not
seeded_seqs
:
uniform_rand
=
torch
.
rand_like
(
selected_target_probs
)
else
:
uniform_rand
=
torch
.
empty_like
(
selected_target_probs
)
for
idx
in
seed_indices
:
uniform_rand
[
idx
,
:]
=
torch
.
rand
(
1
,
non_seeded_indices
=
[]
for
idx
in
range
(
batch_size
):
generator
=
seeded_seqs
.
get
(
idx
)
if
generator
is
None
:
non_seeded_indices
.
append
(
idx
)
else
:
uniform_rand
[
idx
,
:]
=
torch
.
rand
(
1
,
k
,
dtype
=
self
.
probs_dtype
,
device
=
target_probs
.
device
,
generator
=
generators
[
idx
])
if
non_seed_indices
:
uniform_rand
[
non_seed_indices
,
:]
=
torch
.
rand
(
len
(
non_seed_indices
),
generator
=
generator
)
if
non_seeded_indices
:
uniform_rand
[
non_seeded_indices
,
:]
=
torch
.
rand
(
len
(
non_seeded_indices
),
k
,
dtype
=
self
.
probs_dtype
,
device
=
target_probs
.
device
)
...
...
@@ -272,27 +271,6 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
"""
return
torch
.
finfo
(
self
.
probs_dtype
).
tiny
# partition batch into indices for which a generator is provided
# and indicies for which no generator is provided
@
staticmethod
def
_split_batch_by_seeded
(
generators
:
List
[
Optional
[
torch
.
Generator
]],
k
:
int
=
1
,
)
->
Tuple
[
List
[
int
],
Optional
[
List
[
int
]]]:
if
all
(
generator
is
None
for
generator
in
generators
):
seed_indices
:
List
[
int
]
=
[]
non_seed_indices
:
Optional
[
List
[
int
]]
=
None
else
:
seed_indices
,
non_seed_indices
=
[],
[]
for
i
,
generator
in
enumerate
(
generators
):
if
generator
is
None
:
non_seed_indices
.
extend
(
range
(
k
*
i
,
k
*
(
i
+
1
)))
else
:
seed_indices
.
extend
(
range
(
k
*
i
,
k
*
(
i
+
1
)))
return
seed_indices
,
non_seed_indices
# torch.multinomial forces a GPU<->CPU sync.
# Therefore, we use an optimized implementation instead that skips the sync.
...
...
@@ -304,9 +282,7 @@ def _multinomial(
probs
:
torch
.
Tensor
,
num_samples
:
int
,
k
:
int
,
generators
:
List
[
Optional
[
torch
.
Generator
]],
seed_indices
:
List
[
int
],
non_seed_indices
:
List
[
int
],
seeded_seqs
:
Dict
[
int
,
torch
.
Generator
],
)
->
torch
.
Tensor
:
if
num_samples
>
1
:
...
...
@@ -315,13 +291,20 @@ def _multinomial(
probs
=
probs
[:,
None
,
:].
expand
(
probs
.
shape
[
0
],
num_samples
,
probs
.
shape
[
1
]).
contiguous
().
view
(
-
1
,
probs
.
shape
[
1
])
q
=
torch
.
empty_like
(
probs
)
if
len
(
seed_indices
)
==
0
:
if
not
seeded_seqs
:
q
.
exponential_
(
1.0
)
else
:
q
[
non_seed_indices
].
exponential_
(
1.0
)
for
idx
in
seed_indices
:
q
[
idx
].
exponential_
(
1.0
,
generator
=
generators
[
idx
//
k
])
non_seeded_indices
:
List
[
int
]
=
[]
start
=
0
for
idx
in
range
(
len
(
q
)
//
k
):
end
=
start
+
k
generator
=
seeded_seqs
.
get
(
idx
)
if
generator
is
None
:
non_seeded_indices
.
extend
(
list
(
range
(
start
,
end
)))
else
:
q
[
start
:
end
].
exponential_
(
1.0
,
generator
=
generator
)
start
=
end
q
[
non_seeded_indices
].
exponential_
(
1.0
)
return
probs
.
div_
(
q
).
argmax
(
dim
=
1
).
view
(
-
1
,
num_samples
)
vllm/model_executor/layers/spec_decode_base_sampler.py
View file @
5cf9254a
from
abc
import
abstractmethod
from
typing
import
Lis
t
,
Optional
from
typing
import
Dic
t
,
Optional
import
torch
import
torch.jit
...
...
@@ -237,6 +237,6 @@ class SpecDecodeStochasticBaseSampler(SpecDecodeBaseSampler):
bonus_token_ids
:
torch
.
Tensor
,
draft_probs
:
torch
.
Tensor
,
draft_token_ids
:
torch
.
Tensor
,
generators
:
List
[
Optional
[
torch
.
Generator
]],
seeded_seqs
:
Optional
[
Dict
[
int
,
torch
.
Generator
]]
=
None
,
)
->
torch
.
Tensor
:
raise
NotImplementedError
vllm/model_executor/sampling_metadata.py
View file @
5cf9254a
...
...
@@ -118,6 +118,7 @@ class SamplingMetadata:
query_lens
:
Optional
[
List
[
int
]],
device
:
str
,
pin_memory
:
bool
,
generators
:
Optional
[
Dict
[
str
,
torch
.
Generator
]]
=
None
,
)
->
"SamplingMetadata"
:
(
seq_groups
,
...
...
@@ -125,7 +126,7 @@ class SamplingMetadata:
categorized_sample_indices
,
num_prompts
,
)
=
_prepare_seq_groups
(
seq_group_metadata_list
,
seq_lens
,
query_lens
,
device
)
device
,
generators
)
selected_token_indices
=
async_tensor_h2d
(
selected_token_indices
,
dtype
=
torch
.
long
,
target_device
=
device
,
...
...
@@ -160,6 +161,7 @@ def _prepare_seq_groups(
seq_lens
:
List
[
int
],
query_lens
:
Optional
[
List
[
int
]],
device
:
str
,
generators
:
Optional
[
Dict
[
str
,
torch
.
Generator
]]
=
None
,
)
->
Tuple
[
List
[
SequenceGroupToSample
],
List
[
int
],
Dict
[
SamplingType
,
List
[
Tuple
[
int
,
int
]]],
int
]:
"""Prepare sequence groups and indices for sampling.
...
...
@@ -170,8 +172,10 @@ def _prepare_seq_groups(
Index of prompt len should match with seq_group_metadata_list.
query_lens: A list of query lengths. Prompt lens include the length
of entire prompt tokens, and it could be shorter.
device: A device to use for random number generator,
device: A device to use for random number generator
s
,
`SequenceGroupToSample.generator`.
generators: A store of per-request random number generators used
for seeded requests.
Returns:
seq_groups: A list of sequence group to sample.
...
...
@@ -217,8 +221,10 @@ def _prepare_seq_groups(
if
seq_group_metadata
.
is_prompt
:
if
sampling_params
.
seed
is
not
None
:
seq_group_metadata
.
state
.
generator
=
torch
.
Generator
(
device
=
device
).
manual_seed
(
sampling_params
.
seed
)
generator
=
torch
.
Generator
(
device
=
device
).
manual_seed
(
sampling_params
.
seed
)
if
generators
is
not
None
:
generators
[
seq_group_metadata
.
request_id
]
=
generator
num_prompts
+=
1
num_prefill_sample
=
len
(
seq_ids
)
...
...
@@ -235,6 +241,9 @@ def _prepare_seq_groups(
prompt_logprob_len
=
0
sample_len
=
len
(
seq_ids
)
if
do_sample
else
0
if
sampling_params
.
seed
is
not
None
and
generators
is
not
None
:
generator
=
generators
.
get
(
seq_group_metadata
.
request_id
)
# Update indices to select from the model output.
"""
This blocks computes selected_token_indices which is used in the
...
...
@@ -279,9 +288,6 @@ def _prepare_seq_groups(
logit_idx
+=
sample_len
sample_idx
+=
sample_len
if
sampling_params
.
seed
is
not
None
:
generator
=
seq_group_metadata
.
state
.
generator
seq_groups
.
append
(
SequenceGroupToSample
(
seq_ids
=
seq_ids
,
...
...
vllm/sequence.py
View file @
5cf9254a
...
...
@@ -411,14 +411,6 @@ class Sequence:
f
"num_blocks=
{
self
.
n_blocks
}
, "
)
@
dataclass
class
SequenceGroupState
:
"""Mutable state tied to a specific sequence group"""
# torch.Generator used in seeded sampling
generator
:
Optional
=
None
# type: ignore
class
SequenceGroup
:
"""A group of sequences that are generated from the same prompt.
...
...
@@ -461,7 +453,6 @@ class SequenceGroup:
time_in_queue
=
None
)
self
.
lora_request
=
lora_request
self
.
prompt_logprobs
:
Optional
[
PromptLogprobs
]
=
None
self
.
state
=
SequenceGroupState
()
self
.
embeddings
=
embeddings
self
.
pooling_params
=
pooling_params
self
.
prompt_adapter_request
=
prompt_adapter_request
...
...
@@ -648,7 +639,6 @@ class SequenceGroupMetadata:
lora_request: LoRA request.
computed_block_nums: The block numbers that are already computed,
used in prefix caching.
state: Internal state tied to this sequence group.
multi_modal_data: Multi modal data.
encoder_seq_data: Optional sequence data for encoder prompt
(SequenceGroup.encoder_seq). Should be None
...
...
@@ -674,7 +664,6 @@ class SequenceGroupMetadata:
token_chunk_size
:
Optional
[
int
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
computed_block_nums
:
Optional
[
List
[
int
]]
=
None
,
state
:
Optional
[
SequenceGroupState
]
=
None
,
multi_modal_data
:
Optional
[
"MultiModalDataDict"
]
=
None
,
encoder_seq_data
:
Optional
[
SequenceData
]
=
None
,
cross_block_table
:
Optional
[
List
[
int
]]
=
None
,
...
...
@@ -690,7 +679,6 @@ class SequenceGroupMetadata:
self
.
prompt_adapter_request
=
prompt_adapter_request
self
.
computed_block_nums
=
computed_block_nums
self
.
multi_modal_data
=
multi_modal_data
self
.
state
=
SequenceGroupState
()
if
state
is
None
else
state
self
.
encoder_seq_data
=
encoder_seq_data
self
.
cross_block_table
=
cross_block_table
self
.
_token_chunk_size
=
token_chunk_size
...
...
vllm/spec_decode/batch_expansion.py
View file @
5cf9254a
...
...
@@ -3,9 +3,9 @@ from typing import Iterator, List, Tuple
import
torch
from
vllm
import
SamplingParams
from
vllm.sequence
import
(
ExecuteModelRequest
,
SamplerOutput
,
SequenceData
,
SequenceGroupMetadata
,
SequenceGroupState
,
get_all_seq_ids
)
SequenceGroupMetadata
,
get_all_seq_ids
)
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeScorer
,
SpeculativeScores
)
from
vllm.spec_decode.util
import
(
nvtx_range
,
sampler_output_to_torch
,
...
...
@@ -16,6 +16,8 @@ SeqId = int
TargetSeqId
=
int
TokenId
=
int
DEFAULT_SIMPLE_SAMPLING_PARAMS
=
SamplingParams
()
class
BatchExpansionTop1Scorer
(
SpeculativeScorer
):
"""Implements a speculative scorer that uses batch expansion to get
...
...
@@ -247,24 +249,39 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
token_ids_to_score
=
self
.
_get_token_ids_to_score
(
proposal_token_ids
[
batch_index
])
# Use simpler sampling parameters apart from for final token
# (in particular don't do seeded sampling) since those sampled tokens
# aren't used.
# We don't replace the sampling_params in the greedy case because
# this also controls whether the probs get modified in the sampler
# (see use of _modify_greedy_probs_inplace there).
sampling_params
=
input_seq_group_metadata
.
sampling_params
non_bonus_sampling_params
=
DEFAULT_SIMPLE_SAMPLING_PARAMS
\
if
sampling_params
.
temperature
else
sampling_params
target_seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
=
[]
for
token_ids
in
token_ids_to_score
:
last_index
=
len
(
token_ids_to_score
)
-
1
for
i
,
token_ids
in
enumerate
(
token_ids_to_score
):
target_sampling_params
=
sampling_params
if
i
==
last_index
\
else
non_bonus_sampling_params
target_seq_group_metadata_list
.
append
(
self
.
_create_single_target_seq_group_metadata
(
input_seq_group_metadata
,
input_seq_id
,
next
(
target_seq_ids_iter
),
token_ids
,
sampling_params
=
target_sampling_params
,
))
return
target_seq_group_metadata_list
@
staticmethod
def
_create_single_target_seq_group_metadata
(
self
,
seq_group_metadata
:
SequenceGroupMetadata
,
seq_id
:
SeqId
,
target_seq_id
:
TargetSeqId
,
token_ids
:
List
[
TokenId
],
sampling_params
:
SamplingParams
,
)
->
SequenceGroupMetadata
:
"""Create a single target SequenceGroupMetadata.
...
...
@@ -293,26 +310,16 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
for
data
in
new_seq_data_dict
.
values
():
data
.
update_num_computed_tokens
(
data
.
get_len
()
-
1
)
if
(
seq_group_metadata
.
state
is
not
None
and
seq_group_metadata
.
state
.
generator
is
not
None
):
generator
=
torch
.
Generator
(
device
=
seq_group_metadata
.
state
.
generator
.
device
)
generator
.
set_state
(
seq_group_metadata
.
state
.
generator
.
get_state
())
state
=
SequenceGroupState
(
generator
=
generator
)
else
:
state
=
None
return
SequenceGroupMetadata
(
request_id
=
seq_group_metadata
.
request_id
,
is_prompt
=
seq_group_metadata
.
is_prompt
,
seq_data
=
new_seq_data_dict
,
sampling_params
=
seq_group_metadata
.
sampling_params
,
sampling_params
=
sampling_params
,
block_tables
=
{
target_seq_id
:
seq_group_metadata
.
block_tables
[
seq_id
],
},
lora_request
=
None
,
token_chunk_size
=
1
,
state
=
state
,
)
def
_split_scoring_output
(
...
...
vllm/spec_decode/medusa_worker.py
View file @
5cf9254a
...
...
@@ -57,9 +57,11 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
seq_lens
,
query_lens
=
self
.
_prepare_input_tensors
(
seq_group_metadata_list
)
generators
=
self
.
model_runner
.
get_generators
(
execute_model_req
.
finished_requests_ids
)
sampling_metadata
=
SamplingMetadata
.
prepare
(
seq_group_metadata_list
,
seq_lens
,
query_lens
,
self
.
device
,
self
.
model_runner
.
pin_memory
)
self
.
model_runner
.
pin_memory
,
generators
)
model_outputs
=
self
.
model_runner
.
model
.
generate_proposals
(
previous_hidden_states
=
execute_model_req
.
previous_hidden_states
.
...
...
vllm/spec_decode/mlp_speculator_worker.py
View file @
5cf9254a
...
...
@@ -38,9 +38,11 @@ class MLPSpeculatorWorker(NonLLMProposerWorkerBase, MultiStepWorker):
(
input_tokens
,
seq_lens
,
query_lens
)
=
self
.
_prepare_input_tensors
(
seq_group_metadata_list
)
generators
=
self
.
model_runner
.
get_generators
(
execute_model_req
.
finished_requests_ids
)
sampling_metadata
=
SamplingMetadata
.
prepare
(
seq_group_metadata_list
,
seq_lens
,
query_lens
,
self
.
device
,
self
.
model_runner
.
pin_memory
)
self
.
model_runner
.
pin_memory
,
generators
)
model_outputs
=
self
.
model_runner
.
model
.
generate_proposals
(
input_ids
=
input_tokens
,
...
...
vllm/spec_decode/ngram_worker.py
View file @
5cf9254a
...
...
@@ -7,10 +7,9 @@ from vllm.sequence import ExecuteModelRequest, SamplerOutput
from
vllm.spec_decode.interfaces
import
SpeculativeProposals
from
vllm.spec_decode.proposer_worker_base
import
NonLLMProposerWorkerBase
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
vllm.worker.worker_base
import
LoraNotSupportedWorkerBase
class
NGramWorker
(
NonLLMProposerWorkerBase
,
LoraNotSupportedWorkerBase
):
class
NGramWorker
(
NonLLMProposerWorkerBase
):
"""NGramWorker provides a light drafter without need for model.
Current NGramWorker only implements prompt lookup decoding,
...
...
vllm/spec_decode/spec_decode_worker.py
View file @
5cf9254a
...
...
@@ -213,6 +213,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
"""
self
.
proposer_worker
=
proposer_worker
self
.
scorer_worker
=
scorer_worker
scorer_runner
=
getattr
(
self
.
scorer_worker
,
"model_runner"
,
None
)
self
.
generators
=
scorer_runner
.
get_generators
(
)
if
scorer_runner
else
None
self
.
disable_by_batch_size
=
disable_by_batch_size
or
float
(
"inf"
)
self
.
spec_decode_sampler
=
spec_decode_sampler
self
.
_allow_zero_draft_token_step
=
allow_zero_draft_token_step
...
...
@@ -591,20 +594,14 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
proposal_token_ids
=
proposals
.
proposal_token_ids
[
spec_indices
]
# Sampler arguments
sampler_extra_kwargs
=
{}
if
isinstance
(
self
.
spec_decode_sampler
,
sampler_extra_kwargs
:
Dict
[
str
,
Any
]
=
{}
if
self
.
generators
and
isinstance
(
self
.
spec_decode_sampler
,
SpecDecodeStochasticBaseSampler
):
# Get sequence group state
generators
=
[]
for
seq_group_metadata
in
seq_group_metadata_list
:
if
(
seq_group_metadata
.
state
is
not
None
and
seq_group_metadata
.
state
.
generator
is
not
None
):
generators
.
append
(
seq_group_metadata
.
state
.
generator
)
else
:
generators
.
append
(
None
)
sampler_extra_kwargs
[
"generators"
]
=
generators
sampler_extra_kwargs
[
"seeded_seqs"
]
=
{
idx
:
self
.
generators
[
sgm
.
request_id
]
for
idx
,
sgm
in
enumerate
(
seq_group_metadata_list
)
if
sgm
.
sampling_params
.
seed
is
not
None
}
accepted_token_ids
=
self
.
spec_decode_sampler
(
target_probs
=
proposal_verifier_probs
,
...
...
vllm/worker/cpu_model_runner.py
View file @
5cf9254a
...
...
@@ -337,7 +337,8 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
# just use seq_lens instead.
seq_lens
,
self
.
device
,
pin_memory
=
False
)
pin_memory
=
False
,
generators
=
self
.
get_generators
(
finished_requests_ids
))
return
CPUModelInput
(
input_tokens
=
input_tokens
,
input_positions
=
input_positions
,
...
...
vllm/worker/model_runner.py
View file @
5cf9254a
...
...
@@ -1264,11 +1264,15 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
"""
model_input
=
self
.
_prepare_model_input_tensors
(
seq_group_metadata_list
,
finished_requests_ids
)
sampling_metadata
=
SamplingMetadata
.
prepare
(
seq_group_metadata_list
,
model_input
.
seq_lens
,
model_input
.
query_lens
,
self
.
device
,
self
.
pin_memory
)
if
get_pp_group
().
is_last_rank
:
# Sampling metadata is only required for the final pp group
generators
=
self
.
get_generators
(
finished_requests_ids
)
sampling_metadata
=
SamplingMetadata
.
prepare
(
seq_group_metadata_list
,
model_input
.
seq_lens
,
model_input
.
query_lens
,
self
.
device
,
self
.
pin_memory
,
generators
)
else
:
sampling_metadata
=
None
is_prompt
=
(
seq_group_metadata_list
[
0
].
is_prompt
if
seq_group_metadata_list
else
None
)
return
dataclasses
.
replace
(
model_input
,
...
...
vllm/worker/model_runner_base.py
View file @
5cf9254a
...
...
@@ -139,6 +139,9 @@ class ModelRunnerBase(ABC, Generic[T]):
ModelRunnerInputBase subclass.
"""
# Map of request_id -> generator used for seeded random sampling
generators
:
Dict
[
str
,
torch
.
Generator
]
=
{}
@
abstractmethod
def
make_model_input_from_broadcasted_tensor_dict
(
self
,
...
...
@@ -176,3 +179,15 @@ class ModelRunnerBase(ABC, Generic[T]):
Execute the model on the given input.
"""
raise
NotImplementedError
def
get_generators
(
self
,
finished_request_ids
:
Optional
[
List
[
str
]]
=
None
):
"""
Return dict of per-request generators used for random sampling.
"""
# Clean up generators from completed requests
if
finished_request_ids
:
for
request_id
in
finished_request_ids
:
self
.
generators
.
pop
(
request_id
,
None
)
return
self
.
generators
vllm/worker/neuron_model_runner.py
View file @
5cf9254a
...
...
@@ -219,7 +219,8 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
# just use seq_lens instead.
seq_lens
,
self
.
device
,
self
.
pin_memory
)
self
.
pin_memory
,
generators
=
self
.
get_generators
(
finished_requests_ids
))
return
ModelInputForNeuron
(
input_tokens
=
input_tokens
,
input_positions
=
input_positions
,
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment