Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ab502751
Unverified
Commit
ab502751
authored
May 03, 2024
by
Cade Daniel
Committed by
GitHub
May 03, 2024
Browse files
[Speculative decoding] Support target-model logprobs (#4378)
parent
43c413ec
Changes
15
Show whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
728 additions
and
87 deletions
+728
-87
tests/spec_decode/e2e/conftest.py
tests/spec_decode/e2e/conftest.py
+63
-3
tests/spec_decode/e2e/test_logprobs.py
tests/spec_decode/e2e/test_logprobs.py
+335
-0
tests/spec_decode/e2e/test_multistep_correctness.py
tests/spec_decode/e2e/test_multistep_correctness.py
+47
-16
tests/spec_decode/test_multi_step_worker.py
tests/spec_decode/test_multi_step_worker.py
+8
-0
tests/spec_decode/test_spec_decode_worker.py
tests/spec_decode/test_spec_decode_worker.py
+24
-5
tests/spec_decode/utils.py
tests/spec_decode/utils.py
+2
-0
vllm/engine/output_processor/multi_step.py
vllm/engine/output_processor/multi_step.py
+12
-6
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+9
-7
vllm/sequence.py
vllm/sequence.py
+3
-0
vllm/spec_decode/batch_expansion.py
vllm/spec_decode/batch_expansion.py
+41
-18
vllm/spec_decode/interfaces.py
vllm/spec_decode/interfaces.py
+5
-0
vllm/spec_decode/ngram_worker.py
vllm/spec_decode/ngram_worker.py
+6
-0
vllm/spec_decode/spec_decode_worker.py
vllm/spec_decode/spec_decode_worker.py
+74
-26
vllm/spec_decode/top1_proposer.py
vllm/spec_decode/top1_proposer.py
+1
-1
vllm/spec_decode/util.py
vllm/spec_decode/util.py
+98
-5
No files found.
tests/spec_decode/e2e/conftest.py
View file @
ab502751
import
asyncio
import
asyncio
import
time
from
itertools
import
cycle
from
itertools
import
cycle
from
typing
import
List
,
Optional
,
Tuple
,
Union
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Union
import
pytest
import
pytest
import
ray
import
ray
import
torch
from
pynvml
import
(
nvmlDeviceGetHandleByIndex
,
nvmlDeviceGetMemoryInfo
,
nvmlInit
)
from
tests.conftest
import
cleanup
from
tests.conftest
import
cleanup
from
vllm
import
LLM
from
vllm
import
LLM
...
@@ -13,7 +17,7 @@ from vllm.lora.request import LoRARequest
...
@@ -13,7 +17,7 @@ from vllm.lora.request import LoRARequest
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.outputs
import
RequestOutput
from
vllm.outputs
import
RequestOutput
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
MultiModalData
from
vllm.sequence
import
Logprob
,
MultiModalData
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
Counter
,
random_uuid
from
vllm.utils
import
Counter
,
random_uuid
...
@@ -153,12 +157,19 @@ def create_llm_generator(baseline_or_test, request, common_llm_kwargs,
...
@@ -153,12 +157,19 @@ def create_llm_generator(baseline_or_test, request, common_llm_kwargs,
test_name
=
request
.
node
.
name
test_name
=
request
.
node
.
name
def
generator_inner
():
def
generator_inner
():
print
(
f
'Creating
{
baseline_or_test
=
}
LLM for
{
test_name
=
}
.
{
kwargs
=
}
'
)
wait_for_gpu_memory_to_clear
(
devices
=
list
(
range
(
torch
.
cuda
.
device_count
())),
threshold_bytes
=
2
*
2
**
30
,
timeout_s
=
60
,
)
use_async
=
False
use_async
=
False
if
"use_async"
in
kwargs
:
if
"use_async"
in
kwargs
:
use_async
=
kwargs
.
pop
(
"use_async"
)
use_async
=
kwargs
.
pop
(
"use_async"
)
print
(
f
'
{
use_async
=
}
'
)
print
(
f
'Creating
{
baseline_or_test
=
}
LLM for
{
test_name
=
}
.
{
kwargs
=
}
'
)
llm
=
AsyncLLM
(
**
kwargs
)
if
use_async
else
LLM
(
**
kwargs
)
llm
=
AsyncLLM
(
**
kwargs
)
if
use_async
else
LLM
(
**
kwargs
)
set_random_seed
(
seed
)
set_random_seed
(
seed
)
...
@@ -188,6 +199,20 @@ def get_output_from_llm_generator(
...
@@ -188,6 +199,20 @@ def get_output_from_llm_generator(
return
tokens
,
token_ids
return
tokens
,
token_ids
def
get_logprobs_from_llm_generator
(
llm_generator
,
prompts
,
sampling_params
)
->
List
[
List
[
Dict
[
int
,
Logprob
]]]:
"""Returns a dict of (token_id: Logprob) for each generated position, for
each sequence in the batch.
"""
for
llm
in
llm_generator
():
outputs
=
llm
.
generate
(
prompts
,
sampling_params
,
use_tqdm
=
True
)
logprobs
=
[
output
.
outputs
[
0
].
logprobs
[:]
for
output
in
outputs
]
del
llm
return
logprobs
def
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
def
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
test_llm_generator
,
batch_size
,
batch_size
,
...
@@ -243,3 +268,38 @@ def run_greedy_equality_correctness_test(baseline_llm_generator,
...
@@ -243,3 +268,38 @@ def run_greedy_equality_correctness_test(baseline_llm_generator,
print
(
f
'
{
i
=
}
{
baseline_token_ids
=
}
'
)
print
(
f
'
{
i
=
}
{
baseline_token_ids
=
}
'
)
print
(
f
'
{
i
=
}
{
spec_token_ids
=
}
'
)
print
(
f
'
{
i
=
}
{
spec_token_ids
=
}
'
)
assert
baseline_token_ids
==
spec_token_ids
assert
baseline_token_ids
==
spec_token_ids
def
wait_for_gpu_memory_to_clear
(
devices
:
List
[
int
],
threshold_bytes
:
int
,
timeout_s
:
float
=
120
)
->
None
:
# Use nvml instead of pytorch to reduce measurement error from torch cuda
# context.
nvmlInit
()
start_time
=
time
.
time
()
while
True
:
output
=
{}
output_raw
=
{}
for
device
in
devices
:
dev_handle
=
nvmlDeviceGetHandleByIndex
(
device
)
mem_info
=
nvmlDeviceGetMemoryInfo
(
dev_handle
)
gb_used
=
mem_info
.
used
/
2
**
30
output_raw
[
device
]
=
gb_used
output
[
device
]
=
f
'
{
gb_used
:.
02
f
}
'
print
(
'gpu memory used (GB): '
,
end
=
''
)
for
k
,
v
in
output
.
items
():
print
(
f
'
{
k
}
=
{
v
}
; '
,
end
=
''
)
print
(
''
)
dur_s
=
time
.
time
()
-
start_time
if
all
(
v
<=
(
threshold_bytes
/
2
**
30
)
for
v
in
output_raw
.
values
()):
print
(
f
'Done waiting for free GPU memory on devices
{
devices
=
}
'
f
'(
{
threshold_bytes
/
2
**
30
=
}
)
{
dur_s
=
:.
02
f
}
'
)
break
if
dur_s
>=
timeout_s
:
raise
ValueError
(
f
'Memory of devices
{
devices
=
}
not free after '
f
'
{
dur_s
=
:.
02
f
}
(
{
threshold_bytes
/
2
**
30
=
}
)'
)
time
.
sleep
(
5
)
tests/spec_decode/e2e/test_logprobs.py
0 → 100644
View file @
ab502751
import
math
from
itertools
import
cycle
import
pytest
from
vllm
import
SamplingParams
from
.conftest
import
get_logprobs_from_llm_generator
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model"
:
"JackFram/llama-68m"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Required for spec decode.
"use_v2_block_manager"
:
True
,
"max_logprobs"
:
6
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{
"speculative_model"
:
"JackFram/llama-160m"
,
"num_speculative_tokens"
:
3
,
}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
8
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use smaller output len for fast test.
7
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_logprobs_equality
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
"""Verify output logprobs are equal with and without speculative decoding.
"""
run_greedy_logprobs_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model"
:
"JackFram/llama-68m"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Required for spec decode.
"use_v2_block_manager"
:
True
,
"max_logprobs"
:
6
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{
"speculative_model"
:
"JackFram/llama-160m"
,
"num_speculative_tokens"
:
3
,
}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
6
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use smaller output len for fast test.
7
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_diff_num_logprobs
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
,
num_logprobs
:
int
):
"""Verify output logprobs are equal with and without spec decode.
This specifies a number of logprobs >1.
"""
run_greedy_logprobs_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
,
logprob_rank
=
num_logprobs
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model"
:
"JackFram/llama-68m"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Required for spec decode.
"use_v2_block_manager"
:
True
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{
"speculative_model"
:
"JackFram/llama-160m"
,
"num_speculative_tokens"
:
3
,
},
{
"speculative_model"
:
"JackFram/llama-160m"
,
"num_speculative_tokens"
:
6
,
}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
8
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use smaller output len for fast test.
32
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_logprobs_different_k
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
"""Veriy logprob greedy equality with different speculation lens.
"""
run_greedy_logprobs_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model"
:
"JackFram/llama-68m"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Required for spec decode.
"use_v2_block_manager"
:
True
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{
"speculative_model"
:
"JackFram/llama-160m"
,
"num_speculative_tokens"
:
3
,
# Artificially limit the draft model max model len; this forces vLLM
# to skip speculation once the sequences grow beyond 32-k tokens.
"speculative_max_model_len"
:
32
,
}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
8
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use smaller output len for fast test.
32
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_logprobs_when_skip_speculation
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
"""Verify logprobs greedy equality when some sequences skip speculation.
"""
run_greedy_logprobs_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model"
:
"JackFram/llama-68m"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Required for spec decode.
"use_v2_block_manager"
:
True
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{
"speculative_model"
:
"JackFram/llama-160m"
,
"num_speculative_tokens"
:
3
,
}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use smaller output len for fast test.
32
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_logprobs_temp_1
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
"""Verify at least one logprob result has num_logprobs+1, which tests the
case where the sampled token is not in top-k logprobs.
Ideally, this test should validate equality with non-spec by getting
logprobs. This is left as future improvement.
"""
batch_size
=
8
max_output_len
=
output_len
force_output_len
=
True
logprob_rank
=
5
temperature
=
1.0
prompts
=
[
"Hello, my name is"
,
"The president of the United States is"
,
"The capital of France is"
,
"The future of AI is"
,
"San Francisco is know for its"
,
"Facebook was created in 2004 by"
,
"Curious George is a"
,
"Python 3.11 brings improvements to its"
,
]
prompts
=
[
prompt
for
prompt
,
_
in
zip
(
cycle
(
prompts
),
range
(
batch_size
))]
# If the test requires that we generated max_output_len tokens, then set the
# sampling params to ignore eos token.
ignore_eos
=
force_output_len
sampling_params
=
SamplingParams
(
max_tokens
=
max_output_len
,
ignore_eos
=
ignore_eos
,
temperature
=
temperature
,
logprobs
=
logprob_rank
,
)
spec_batch_logprobs
=
get_logprobs_from_llm_generator
(
test_llm_generator
,
prompts
,
sampling_params
)
num_returned_logprobs
=
[
len
(
logprob_dict
)
for
seq_logprobs
in
spec_batch_logprobs
for
logprob_dict
in
seq_logprobs
]
# Assert one of the returned logprobs has > num_logprobs (indicating the
# sampled token is not in top-k).
assert
any
([
num_returned
>
logprob_rank
for
num_returned
in
num_returned_logprobs
])
def
run_greedy_logprobs_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
,
force_output_len
:
bool
,
logprob_rank
:
int
=
1
):
"""Helper method that compares the logprobs outputs of both the baseline LLM
and the test LLM. It asserts greedy equality of the logprobs when the
temperature is zero.
"""
temperature
=
0.0
prompts
=
[
"Hello, my name is"
,
"The president of the United States is"
,
"The capital of France is"
,
"The future of AI is"
,
"San Francisco is know for its"
,
"Facebook was created in 2004 by"
,
"Curious George is a"
,
"Python 3.11 brings improvements to its"
,
]
prompts
=
[
prompt
for
prompt
,
_
in
zip
(
cycle
(
prompts
),
range
(
batch_size
))]
# If the test requires that we generated max_output_len tokens, then set the
# sampling params to ignore eos token.
ignore_eos
=
force_output_len
sampling_params
=
SamplingParams
(
max_tokens
=
max_output_len
,
ignore_eos
=
ignore_eos
,
temperature
=
temperature
,
logprobs
=
logprob_rank
,
)
spec_batch_logprobs
=
get_logprobs_from_llm_generator
(
test_llm_generator
,
prompts
,
sampling_params
)
baseline_batch_logprobs
=
get_logprobs_from_llm_generator
(
baseline_llm_generator
,
prompts
,
sampling_params
)
assert
len
(
baseline_batch_logprobs
)
==
len
(
prompts
)
assert
len
(
spec_batch_logprobs
)
==
len
(
prompts
)
# For each sequence in the batch.
for
i
,
(
baseline_logprobs
,
spec_logprobs
)
in
enumerate
(
zip
(
baseline_batch_logprobs
,
spec_batch_logprobs
)):
assert
len
(
spec_logprobs
)
==
len
(
baseline_logprobs
)
# For each generated position of the sequence.
for
pos
,
(
spec_pos_logprobs
,
baseline_pos_logprobs
)
in
enumerate
(
zip
(
spec_logprobs
,
baseline_logprobs
)):
# Map rank to token/logprob in spec output.
spec_rank_to_token_id
=
{
value
.
rank
:
key
for
key
,
value
in
spec_pos_logprobs
.
items
()
}
spec_rank_to_logprob
=
{
value
.
rank
:
value
.
logprob
for
key
,
value
in
spec_pos_logprobs
.
items
()
}
# Map rank to token/logprob in baseline output.
baseline_rank_to_token_id
=
{
value
.
rank
:
key
for
key
,
value
in
baseline_pos_logprobs
.
items
()
}
baseline_rank_to_logprob
=
{
value
.
rank
:
value
.
logprob
for
key
,
value
in
baseline_pos_logprobs
.
items
()
}
# Assert set of ranks returned is equal.
assert
set
(
spec_rank_to_token_id
.
keys
())
==
set
(
baseline_rank_to_token_id
.
keys
())
# Assert each logprob/token id is correct, keyed by rank.
for
rank
in
sorted
(
set
(
spec_rank_to_token_id
.
keys
())):
assert
spec_rank_to_token_id
[
rank
]
==
baseline_rank_to_token_id
[
rank
],
f
"
{
rank
}
"
assert
math
.
isclose
(
a
=
spec_rank_to_logprob
[
rank
],
b
=
baseline_rank_to_logprob
[
rank
],
abs_tol
=
1e-1
,
)
tests/spec_decode/e2e/test_multistep_correctness.py
View file @
ab502751
...
@@ -41,8 +41,7 @@ from .conftest import (get_output_from_llm_generator,
...
@@ -41,8 +41,7 @@ from .conftest import (get_output_from_llm_generator,
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
"common_llm_kwargs"
,
[
[{
{
# Use a small model for a fast test.
# Use a small model for a fast test.
# Note this is repeated in the test body; to initialize a tokenizer.
# Note this is repeated in the test body; to initialize a tokenizer.
"model"
:
"JackFram/llama-68m"
,
"model"
:
"JackFram/llama-68m"
,
...
@@ -52,13 +51,7 @@ from .conftest import (get_output_from_llm_generator,
...
@@ -52,13 +51,7 @@ from .conftest import (get_output_from_llm_generator,
# Required for spec decode.
# Required for spec decode.
"use_v2_block_manager"
:
True
,
"use_v2_block_manager"
:
True
,
}])
# whether use AsyncLLM engine
"use_async"
:
async_mode
,
}
# Try both async and sync engine execution
for
async_mode
in
[
True
,
False
]
])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
"per_test_common_llm_kwargs"
,
[
[
...
@@ -117,6 +110,44 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator,
...
@@ -117,6 +110,44 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator,
assert
actual_tokens
.
strip
()
==
expected_tokens
.
strip
()
assert
actual_tokens
.
strip
()
==
expected_tokens
.
strip
()
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
# Use a small model for a fast test.
# Note this is repeated in the test body; to initialize a tokenizer.
"model"
:
"JackFram/llama-68m"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Required for spec decode.
"use_v2_block_manager"
:
True
,
# Use AsyncLLM engine
"use_async"
:
True
,
}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[
{
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
},
])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
2
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_spec_decode_e2e_with_async_engine
(
test_llm_generator
,
baseline_llm_generator
,
batch_size
:
int
):
"""Verify spec decode works well with async LLM engine.
"""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
32
,
force_output_len
=
True
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
"common_llm_kwargs"
,
[{
[{
...
...
tests/spec_decode/test_multi_step_worker.py
View file @
ab502751
...
@@ -292,6 +292,10 @@ def test_draft_proposals_full_speculation_len():
...
@@ -292,6 +292,10 @@ def test_draft_proposals_full_speculation_len():
vocab_size
,
vocab_size
,
device
=
device
,
device
=
device
,
dtype
=
torch
.
float32
),
dtype
=
torch
.
float32
),
logprobs
=
torch
.
rand
(
batch_size
,
vocab_size
,
device
=
device
,
dtype
=
torch
.
float32
),
sampled_token_ids
=
torch
.
randint
(
low
=
0
,
sampled_token_ids
=
torch
.
randint
(
low
=
0
,
high
=
vocab_size
,
high
=
vocab_size
,
size
=
(
batch_size
,
),
size
=
(
batch_size
,
),
...
@@ -392,6 +396,10 @@ def test_draft_proposals_mixed_k():
...
@@ -392,6 +396,10 @@ def test_draft_proposals_mixed_k():
vocab_size
,
vocab_size
,
device
=
device
,
device
=
device
,
dtype
=
torch
.
float32
),
dtype
=
torch
.
float32
),
logprobs
=
torch
.
rand
(
expected_num_proposal_seqs
,
vocab_size
,
device
=
device
,
dtype
=
torch
.
float32
),
sampled_token_ids
=
torch
.
randint
(
sampled_token_ids
=
torch
.
randint
(
low
=
0
,
low
=
0
,
high
=
vocab_size
,
high
=
vocab_size
,
...
...
tests/spec_decode/test_spec_decode_worker.py
View file @
ab502751
...
@@ -192,8 +192,14 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
...
@@ -192,8 +192,14 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
vocab_size
,
vocab_size
,
dtype
=
torch
.
float32
,
dtype
=
torch
.
float32
,
device
=
'cuda'
)
device
=
'cuda'
)
target_token_logprobs
=
torch
.
rand
(
1
,
batch_size
*
(
k
+
1
),
vocab_size
,
dtype
=
torch
.
float32
,
device
=
'cuda'
)
target_output
=
create_sampler_output_list
(
target_token_ids
,
target_output
=
create_sampler_output_list
(
target_token_ids
,
target_token_probs
)
target_token_probs
,
target_token_logprobs
)
target_worker
.
execute_model
.
return_value
=
[
target_output
[
0
]]
target_worker
.
execute_model
.
return_value
=
[
target_output
[
0
]]
...
@@ -273,8 +279,14 @@ def test_correctly_formats_output(k: int, batch_size: int):
...
@@ -273,8 +279,14 @@ def test_correctly_formats_output(k: int, batch_size: int):
vocab_size
,
vocab_size
,
dtype
=
torch
.
float32
,
dtype
=
torch
.
float32
,
device
=
'cuda'
)
device
=
'cuda'
)
target_token_logprobs
=
torch
.
rand
(
1
,
batch_size
*
(
k
+
1
),
vocab_size
,
dtype
=
torch
.
float32
,
device
=
'cuda'
)
target_output
=
create_sampler_output_list
(
target_token_ids
,
target_output
=
create_sampler_output_list
(
target_token_ids
,
target_token_probs
)
target_token_probs
,
target_token_logprobs
)
target_worker
.
execute_model
.
return_value
=
[
target_output
[
0
]]
target_worker
.
execute_model
.
return_value
=
[
target_output
[
0
]]
...
@@ -294,7 +306,9 @@ def test_correctly_formats_output(k: int, batch_size: int):
...
@@ -294,7 +306,9 @@ def test_correctly_formats_output(k: int, batch_size: int):
num_lookahead_slots
=
k
)
num_lookahead_slots
=
k
)
expected_output
=
create_sampler_output_list
(
expected_output
=
create_sampler_output_list
(
rejection_sampler_output
.
transpose
(
0
,
1
),
[
None
for
_
in
range
(
k
+
1
)])
token_ids
=
rejection_sampler_output
.
transpose
(
0
,
1
),
probs
=
[
None
for
_
in
range
(
k
+
1
)],
logprobs
=
[
None
for
_
in
range
(
k
+
1
)])
seq_ids
=
[
seq_ids
=
[
next
(
iter
(
seq_group_metadata
.
seq_data
.
keys
()))
next
(
iter
(
seq_group_metadata
.
seq_data
.
keys
()))
...
@@ -328,7 +342,6 @@ def test_correctly_formats_output(k: int, batch_size: int):
...
@@ -328,7 +342,6 @@ def test_correctly_formats_output(k: int, batch_size: int):
continue
continue
assert
actual_by_step
[
i
].
output_token
==
expected_by_step
[
assert
actual_by_step
[
i
].
output_token
==
expected_by_step
[
i
].
output_token
i
].
output_token
assert
actual_by_step
[
i
].
logprobs
==
expected_by_step
[
i
].
logprobs
@
pytest
.
mark
.
parametrize
(
'k'
,
[
1
,
2
])
@
pytest
.
mark
.
parametrize
(
'k'
,
[
1
,
2
])
...
@@ -387,8 +400,14 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
...
@@ -387,8 +400,14 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
vocab_size
,
vocab_size
,
dtype
=
torch
.
float32
,
dtype
=
torch
.
float32
,
device
=
'cuda'
)
device
=
'cuda'
)
target_token_logprobs
=
torch
.
rand
(
1
,
batch_size
*
(
k
+
1
),
vocab_size
,
dtype
=
torch
.
float32
,
device
=
'cuda'
)
target_output
=
create_sampler_output_list
(
target_token_ids
,
target_output
=
create_sampler_output_list
(
target_token_ids
,
target_token_probs
)
target_token_probs
,
target_token_logprobs
)
target_worker
.
execute_model
.
return_value
=
[
target_output
[
0
]]
target_worker
.
execute_model
.
return_value
=
[
target_output
[
0
]]
...
...
tests/spec_decode/utils.py
View file @
ab502751
...
@@ -201,6 +201,7 @@ def assert_logprobs_dict_allclose(
...
@@ -201,6 +201,7 @@ def assert_logprobs_dict_allclose(
def
create_sampler_output_list
(
def
create_sampler_output_list
(
token_ids
:
torch
.
Tensor
,
token_ids
:
torch
.
Tensor
,
probs
:
Iterable
[
Optional
[
torch
.
Tensor
]],
probs
:
Iterable
[
Optional
[
torch
.
Tensor
]],
logprobs
:
Iterable
[
Optional
[
torch
.
Tensor
]],
seq_ids
:
Optional
[
List
[
int
]]
=
None
)
->
List
[
SamplerOutput
]:
seq_ids
:
Optional
[
List
[
int
]]
=
None
)
->
List
[
SamplerOutput
]:
num_steps
,
batch_size
=
token_ids
.
shape
num_steps
,
batch_size
=
token_ids
.
shape
token_ids_by_step
=
token_ids
.
tolist
()
token_ids_by_step
=
token_ids
.
tolist
()
...
@@ -222,6 +223,7 @@ def create_sampler_output_list(
...
@@ -222,6 +223,7 @@ def create_sampler_output_list(
)
for
seq_index
,
token_id
in
enumerate
(
token_ids_by_step
[
step
])
)
for
seq_index
,
token_id
in
enumerate
(
token_ids_by_step
[
step
])
],
],
sampled_token_probs
=
probs
[
step
],
sampled_token_probs
=
probs
[
step
],
logprobs
=
logprobs
[
step
],
sampled_token_ids
=
token_ids
[
step
])
sampled_token_ids
=
token_ids
[
step
])
for
step
in
range
(
num_steps
)
for
step
in
range
(
num_steps
)
]
]
...
...
vllm/engine/output_processor/multi_step.py
View file @
ab502751
import
functools
from
typing
import
Callable
,
List
from
typing
import
Callable
,
List
from
transformers
import
PreTrainedTokenizer
from
transformers
import
PreTrainedTokenizer
...
@@ -8,8 +9,8 @@ from vllm.engine.output_processor.interfaces import (
...
@@ -8,8 +9,8 @@ from vllm.engine.output_processor.interfaces import (
from
vllm.engine.output_processor.stop_checker
import
StopChecker
from
vllm.engine.output_processor.stop_checker
import
StopChecker
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
(
Logprob
,
Sequence
,
SequenceGroup
,
from
vllm.sequence
import
(
Sequence
,
SequenceGroup
,
SequenceGroupOutput
,
SequenceGroupOutput
,
SequenceOutput
,
SequenceStatus
)
SequenceOutput
,
SequenceStatus
)
from
vllm.transformers_utils.detokenizer
import
Detokenizer
from
vllm.transformers_utils.detokenizer
import
Detokenizer
from
vllm.utils
import
Counter
from
vllm.utils
import
Counter
...
@@ -48,10 +49,14 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
...
@@ -48,10 +49,14 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
outputs
:
List
[
SequenceGroupOutput
])
->
None
:
outputs
:
List
[
SequenceGroupOutput
])
->
None
:
# TODO(sang): Prompt logprob currently not implemented in multi step
# TODO(sang): Prompt logprob currently not implemented in multi step
# workers.
# workers.
self
.
_log_prompt_logprob_unsupported_warning_once
()
@
staticmethod
@
functools
.
lru_cache
()
def
_log_prompt_logprob_unsupported_warning_once
():
logger
.
warning
(
logger
.
warning
(
"Prompt logprob is not supported by multi step workers. "
"Prompt logprob is not supported by multi step workers. "
"(e.g., speculative decode uses multi step workers)."
)
"(e.g., speculative decode uses multi step workers)."
)
pass
def
process_outputs
(
self
,
sequence_group
:
SequenceGroup
,
def
process_outputs
(
self
,
sequence_group
:
SequenceGroup
,
outputs
:
List
[
SequenceGroupOutput
])
->
None
:
outputs
:
List
[
SequenceGroupOutput
])
->
None
:
...
@@ -89,6 +94,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
...
@@ -89,6 +94,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
valid_samples
:
List
[
SequenceOutput
],
valid_samples
:
List
[
SequenceOutput
],
sampling_params
:
SamplingParams
)
->
None
:
sampling_params
:
SamplingParams
)
->
None
:
output_token_ids
=
[
sample
.
output_token
for
sample
in
valid_samples
]
output_token_ids
=
[
sample
.
output_token
for
sample
in
valid_samples
]
output_logprobs
=
[
sample
.
logprobs
for
sample
in
valid_samples
]
# Truncate to max_tokens if necessary.
# Truncate to max_tokens if necessary.
remaining_tokens
=
sampling_params
.
max_tokens
-
(
seq
.
get_output_len
()
+
remaining_tokens
=
sampling_params
.
max_tokens
-
(
seq
.
get_output_len
()
+
...
@@ -113,11 +119,11 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
...
@@ -113,11 +119,11 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
# Incrementally append tokens to the sequence, as if we had only one new
# Incrementally append tokens to the sequence, as if we had only one new
# token.
# token.
for
output_token_id
in
output_token_ids
:
for
output_token_id
,
output_logprob
in
zip
(
output_token_ids
,
output_logprobs
):
seq
.
append_token_id
(
seq
.
append_token_id
(
token_id
=
output_token_id
,
token_id
=
output_token_id
,
# TODO emit logprobs in multi-step decoding.
logprobs
=
output_logprob
,
logprobs
=
{
output_token_id
:
Logprob
(
0.0
)},
)
)
new_char_count
=
0
new_char_count
=
0
...
...
vllm/model_executor/layers/sampler.py
View file @
ab502751
...
@@ -103,8 +103,7 @@ class Sampler(nn.Module):
...
@@ -103,8 +103,7 @@ class Sampler(nn.Module):
if
self
.
include_gpu_probs_tensor
:
if
self
.
include_gpu_probs_tensor
:
assert
maybe_sampled_tokens_tensor
is
not
None
assert
maybe_sampled_tokens_tensor
is
not
None
sampled_tokens_tensor
=
maybe_sampled_tokens_tensor
on_device_tensors
=
(
probs
,
logprobs
,
maybe_sampled_tokens_tensor
)
on_device_tensors
=
(
probs
,
sampled_tokens_tensor
)
else
:
else
:
on_device_tensors
=
None
on_device_tensors
=
None
...
@@ -965,8 +964,7 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
...
@@ -965,8 +964,7 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
has implications on the overall design of the sampler, e.g. how to record
has implications on the overall design of the sampler, e.g. how to record
accurate logprobs for the user, so this improvement is deferred to later.
accurate logprobs for the user, so this improvement is deferred to later.
"""
"""
logprobs
[
sample_indices
,
:]
=
-
float
(
'inf'
)
# NOTE: logprobs are not modified so they can be returned to the user.
logprobs
[
sample_indices
,
greedy_samples
]
=
0.0
probs
[
sample_indices
,
:]
=
0
probs
[
sample_indices
,
:]
=
0
probs
[
sample_indices
,
greedy_samples
]
=
1.0
probs
[
sample_indices
,
greedy_samples
]
=
1.0
...
@@ -976,7 +974,8 @@ def _build_sampler_output(
...
@@ -976,7 +974,8 @@ def _build_sampler_output(
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
prompt_logprobs
:
List
[
Optional
[
PromptLogprobs
]],
prompt_logprobs
:
List
[
Optional
[
PromptLogprobs
]],
sample_logprobs
:
List
[
SampleLogprobs
],
sample_logprobs
:
List
[
SampleLogprobs
],
on_device_tensors
:
Optional
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
on_device_tensors
:
Optional
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]],
)
->
SamplerOutput
:
)
->
SamplerOutput
:
"""Construct Python objects with the output of sampling.
"""Construct Python objects with the output of sampling.
...
@@ -1005,14 +1004,17 @@ def _build_sampler_output(
...
@@ -1005,14 +1004,17 @@ def _build_sampler_output(
# If not specified, store None values in SamplerOutput.
# If not specified, store None values in SamplerOutput.
if
on_device_tensors
is
not
None
:
if
on_device_tensors
is
not
None
:
sampled_token_probs
,
sampled_token_ids
=
on_device_tensors
(
sampled_token_probs
,
logprobs_tensor
,
sampled_token_ids
)
=
on_device_tensors
else
:
else
:
sampled_token_probs
,
sampled_token_ids
=
(
None
,
None
)
sampled_token_probs
,
logprobs_tensor
,
sampled_token_ids
=
(
None
,
None
,
None
)
return
SamplerOutput
(
return
SamplerOutput
(
outputs
=
sampler_output
,
outputs
=
sampler_output
,
sampled_token_probs
=
sampled_token_probs
,
sampled_token_probs
=
sampled_token_probs
,
sampled_token_ids
=
sampled_token_ids
,
sampled_token_ids
=
sampled_token_ids
,
logprobs
=
logprobs_tensor
,
)
)
...
...
vllm/sequence.py
View file @
ab502751
...
@@ -700,6 +700,9 @@ class SamplerOutput:
...
@@ -700,6 +700,9 @@ class SamplerOutput:
# On-device tensor containing probabilities of each token.
# On-device tensor containing probabilities of each token.
sampled_token_probs
:
Optional
[
"torch.Tensor"
]
=
None
sampled_token_probs
:
Optional
[
"torch.Tensor"
]
=
None
# On-device tensor containing the logprobs of each token.
logprobs
:
Optional
[
"torch.Tensor"
]
=
None
# On-device tensor containing the sampled token ids.
# On-device tensor containing the sampled token ids.
sampled_token_ids
:
Optional
[
"torch.Tensor"
]
=
None
sampled_token_ids
:
Optional
[
"torch.Tensor"
]
=
None
...
...
vllm/spec_decode/batch_expansion.py
View file @
ab502751
...
@@ -94,7 +94,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -94,7 +94,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
assert
len
(
target_sampler_output
)
==
1
,
"expected single-step output"
assert
len
(
target_sampler_output
)
==
1
,
"expected single-step output"
target_sampler_output
=
target_sampler_output
[
0
]
target_sampler_output
=
target_sampler_output
[
0
]
all_tokens
,
all_probs
=
self
.
_contract_batch
(
all_tokens
,
all_probs
,
spec_logprobs
=
self
.
_contract_batch
(
contracted_bs
=
len
(
seq_group_metadata_list
),
contracted_bs
=
len
(
seq_group_metadata_list
),
target_sampler_output
=
target_sampler_output
,
target_sampler_output
=
target_sampler_output
,
proposals
=
proposals
,
proposals
=
proposals
,
...
@@ -107,6 +107,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -107,6 +107,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
return
SpeculativeScores
(
return
SpeculativeScores
(
probs
=
all_probs
,
probs
=
all_probs
,
token_ids
=
all_tokens
,
token_ids
=
all_tokens
,
logprobs
=
spec_logprobs
,
)
)
def
_expand_batch
(
def
_expand_batch
(
...
@@ -148,12 +149,12 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -148,12 +149,12 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
return
(
spec_indices
,
non_spec_indices
,
target_seq_group_metadata_list
,
return
(
spec_indices
,
non_spec_indices
,
target_seq_group_metadata_list
,
num_scoring_tokens
)
num_scoring_tokens
)
def
_contract_batch
(
self
,
contracted_bs
:
int
,
def
_contract_batch
(
self
,
contracted_bs
:
int
,
target_sampler_output
:
List
[
SamplerOutput
],
target_sampler_output
:
List
[
SamplerOutput
],
proposals
:
SpeculativeProposals
,
proposals
:
SpeculativeProposals
,
num_scoring_tokens
:
int
,
num_scoring_tokens
:
int
,
non_spec_indices
:
List
[
int
],
non_spec_indices
:
List
[
int
],
spec_indices
:
List
[
int
],
spec_indices
:
List
[
int
],
k
:
int
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
k
:
int
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""Contract the expanded batch back into its original size.
"""Contract the expanded batch back into its original size.
This maps the scores of speculative tokens back to their original
This maps the scores of speculative tokens back to their original
sequences.
sequences.
...
@@ -161,8 +162,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -161,8 +162,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
contracted_bs is the original batch size, and the batch size that the
contracted_bs is the original batch size, and the batch size that the
target_sampler_output will be contracted to.
target_sampler_output will be contracted to.
"""
"""
(
target_token_ids
,
target_probs
,
non_spec_target_token_ids
,
(
target_token_ids
,
target_probs
,
target_logprobs
,
non_spec_target_probs
)
=
self
.
_split_scoring_output
(
non_spec_target_token_ids
,
non_spec_target_probs
,
non_spec_target_logprobs
)
=
self
.
_split_scoring_output
(
target_sampler_output
,
num_scoring_tokens
)
target_sampler_output
,
num_scoring_tokens
)
# Map distinct sequences used to score each token
# Map distinct sequences used to score each token
...
@@ -179,6 +181,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -179,6 +181,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
spec_expanded_bs
,
k
+
1
)
spec_expanded_bs
,
k
+
1
)
target_probs
=
target_probs
.
squeeze
().
reshape
(
spec_expanded_bs
,
k
+
1
,
target_probs
=
target_probs
.
squeeze
().
reshape
(
spec_expanded_bs
,
k
+
1
,
self
.
_vocab_size
)
self
.
_vocab_size
)
target_logprobs
=
target_logprobs
.
squeeze
().
reshape
(
spec_expanded_bs
,
k
+
1
,
self
.
_vocab_size
)
all_tokens
=
torch
.
full
(
size
=
(
contracted_bs
,
k
+
1
),
all_tokens
=
torch
.
full
(
size
=
(
contracted_bs
,
k
+
1
),
fill_value
=-
1
,
fill_value
=-
1
,
...
@@ -189,16 +193,26 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -189,16 +193,26 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
self
.
_vocab_size
,
self
.
_vocab_size
,
device
=
self
.
_device
,
device
=
self
.
_device
,
dtype
=
torch
.
float32
)
dtype
=
torch
.
float32
)
all_logprobs
=
torch
.
full
(
size
=
(
contracted_bs
,
k
+
1
,
self
.
_vocab_size
,
),
fill_value
=-
float
(
"inf"
),
device
=
self
.
_device
,
dtype
=
torch
.
float32
)
if
non_spec_indices
:
if
non_spec_indices
:
all_tokens
[
non_spec_indices
,
:
1
]
=
non_spec_target_token_ids
all_tokens
[
non_spec_indices
,
:
1
]
=
non_spec_target_token_ids
all_probs
[
non_spec_indices
,
:
1
,
:]
=
non_spec_target_probs
all_probs
[
non_spec_indices
,
:
1
,
:]
=
non_spec_target_probs
all_logprobs
[
non_spec_indices
,
:
1
,
:]
=
non_spec_target_logprobs
if
spec_indices
:
if
spec_indices
:
all_tokens
[
spec_indices
]
=
target_token_ids
all_tokens
[
spec_indices
]
=
target_token_ids
all_probs
[
spec_indices
]
=
target_probs
all_probs
[
spec_indices
]
=
target_probs
all_logprobs
[
spec_indices
]
=
target_logprobs
return
all_tokens
,
all_probs
return
all_tokens
,
all_probs
,
all_logprobs
def
_create_scoring_model_input
(
def
_create_scoring_model_input
(
self
,
self
,
...
@@ -308,7 +322,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -308,7 +322,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
def
_split_scoring_output
(
def
_split_scoring_output
(
self
,
sampler_output
:
SamplerOutput
,
num_scoring_tokens
:
int
self
,
sampler_output
:
SamplerOutput
,
num_scoring_tokens
:
int
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""Split the target model output into speculative and non-speculative
"""Split the target model output into speculative and non-speculative
output.
output.
"""
"""
...
@@ -328,21 +343,29 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -328,21 +343,29 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
)
=
sampler_output
.
sampled_token_probs
.
split
(
split_sizes
)
)
=
sampler_output
.
sampled_token_probs
.
split
(
split_sizes
)
(
spec_sampled_tokens
,
non_spec_sampled_tokens
(
spec_sampled_tokens
,
non_spec_sampled_tokens
)
=
sampler_output
.
sampled_token_ids
.
flatten
().
split
(
split_sizes
)
)
=
sampler_output
.
sampled_token_ids
.
flatten
().
split
(
split_sizes
)
(
spec_logprobs
,
non_spec_logprobs
,
)
=
sampler_output
.
logprobs
.
split
(
split_sizes
)
# Convert scores to tensors.
# Convert scores to tensors.
sampler_output
.
sampled_token_probs
=
spec_probs
sampler_output
.
sampled_token_probs
=
spec_probs
sampler_output
.
sampled_token_ids
=
spec_sampled_tokens
sampler_output
.
sampled_token_ids
=
spec_sampled_tokens
target_token_ids
,
target_probs
=
sampler_output_to_torch
(
sampler_output
.
logprobs
=
spec_logprobs
[
sampler_output
],
True
)
(
target_token_ids
,
target_probs
,
target_logprobs
)
=
sampler_output_to_torch
([
sampler_output
],
True
)
# Convert non-speculative output tokens to tensors.
# Convert non-speculative output tokens to tensors.
sampler_output
.
sampled_token_probs
=
non_spec_probs
sampler_output
.
sampled_token_probs
=
non_spec_probs
sampler_output
.
sampled_token_ids
=
non_spec_sampled_tokens
sampler_output
.
sampled_token_ids
=
non_spec_sampled_tokens
non_spec_target_token_ids
,
non_spec_target_probs
=
(
sampler_output
.
logprobs
=
non_spec_logprobs
sampler_output_to_torch
([
sampler_output
],
True
))
(
non_spec_target_token_ids
,
non_spec_target_probs
,
non_spec_target_logprobs
)
=
sampler_output_to_torch
([
sampler_output
],
return
(
target_token_ids
,
target_probs
,
non_spec_target_token_ids
,
True
)
non_spec_target_probs
)
return
(
target_token_ids
,
target_probs
,
target_logprobs
,
non_spec_target_token_ids
,
non_spec_target_probs
,
non_spec_target_logprobs
)
def
_create_target_seq_id_iterator
(
def
_create_target_seq_id_iterator
(
self
,
seq_ids
:
List
[
SeqId
])
->
Iterator
[
TargetSeqId
]:
self
,
seq_ids
:
List
[
SeqId
])
->
Iterator
[
TargetSeqId
]:
...
...
vllm/spec_decode/interfaces.py
View file @
ab502751
...
@@ -38,6 +38,11 @@ class SpeculativeScores:
...
@@ -38,6 +38,11 @@ class SpeculativeScores:
# Probabilities of the speculative tokens according to the scoring model.
# Probabilities of the speculative tokens according to the scoring model.
probs
:
torch
.
Tensor
probs
:
torch
.
Tensor
# Log-probabilities of the speculative tokens according to the scoring
# model. These values can be used to generate Logprob objects that are
# returned to the user.
logprobs
:
torch
.
Tensor
# Token ids sampled from the scoring model. Used for speculative bonus
# Token ids sampled from the scoring model. Used for speculative bonus
# tokens and also non-speculative normal decoding.
# tokens and also non-speculative normal decoding.
token_ids
:
torch
.
Tensor
token_ids
:
torch
.
Tensor
...
...
vllm/spec_decode/ngram_worker.py
View file @
ab502751
...
@@ -140,11 +140,17 @@ class NGramWorker(LoraNotSupportedWorkerBase):
...
@@ -140,11 +140,17 @@ class NGramWorker(LoraNotSupportedWorkerBase):
device
=
self
.
device
,
device
=
self
.
device
,
)
)
token_probs
.
scatter_
(
2
,
indices
,
1
)
token_probs
.
scatter_
(
2
,
indices
,
1
)
token_logprobs
=
torch
.
zeros
(
(
len
(
seq_group_metadata_list
),
sample_len
,
self
.
vocab_size
),
dtype
=
torch
.
float32
,
device
=
self
.
device
,
)
for
i
in
range
(
len
(
seq_group_metadata_list
)):
for
i
in
range
(
len
(
seq_group_metadata_list
)):
outputs
.
append
(
outputs
.
append
(
SamplerOutput
(
SamplerOutput
(
outputs
=
None
,
outputs
=
None
,
sampled_token_probs
=
token_probs
[
i
],
sampled_token_probs
=
token_probs
[
i
],
logprobs
=
token_logprobs
,
sampled_token_ids
=
token_ids
[
i
],
sampled_token_ids
=
token_ids
[
i
],
))
))
return
outputs
,
False
return
outputs
,
False
...
...
vllm/spec_decode/spec_decode_worker.py
View file @
ab502751
...
@@ -5,15 +5,16 @@ import torch
...
@@ -5,15 +5,16 @@ import torch
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.rejection_sampler
import
RejectionSampler
from
vllm.model_executor.layers.rejection_sampler
import
RejectionSampler
from
vllm.sequence
import
(
Logprob
,
SamplerOutput
,
SequenceGroupMetadata
,
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
SequenceGroupOutput
,
SequenceOutput
)
from
vllm.spec_decode.batch_expansion
import
BatchExpansionTop1Scorer
from
vllm.spec_decode.batch_expansion
import
BatchExpansionTop1Scorer
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeScorer
,
SpeculativeScores
)
SpeculativeScorer
,
SpeculativeScores
)
from
vllm.spec_decode.metrics
import
AsyncMetricsCollector
from
vllm.spec_decode.metrics
import
AsyncMetricsCollector
from
vllm.spec_decode.multi_step_worker
import
MultiStepWorker
from
vllm.spec_decode.multi_step_worker
import
MultiStepWorker
from
vllm.spec_decode.ngram_worker
import
NGramWorker
from
vllm.spec_decode.ngram_worker
import
NGramWorker
from
vllm.spec_decode.util
import
(
get_all_seq_ids
,
nvtx_range
,
from
vllm.spec_decode.util
import
(
create_sequence_group_output
,
get_all_num_logprobs
,
get_all_seq_ids
,
get_sampled_token_logprobs
,
nvtx_range
,
split_batch_by_proposal_len
)
split_batch_by_proposal_len
)
from
vllm.worker.worker_base
import
LoraNotSupportedWorkerBase
,
WorkerBase
from
vllm.worker.worker_base
import
LoraNotSupportedWorkerBase
,
WorkerBase
...
@@ -258,6 +259,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -258,6 +259,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# overhead when the engine runs in a different process than the workers.
# overhead when the engine runs in a different process than the workers.
sampler_output
.
probs
=
None
sampler_output
.
probs
=
None
sampler_output
.
sampled_tokens
=
None
sampler_output
.
sampled_tokens
=
None
sampler_output
.
logprobs
=
None
return
[
sampler_output
]
return
[
sampler_output
]
@
nvtx_range
(
"spec_decode_worker._run_speculative_decoding_step"
)
@
nvtx_range
(
"spec_decode_worker._run_speculative_decoding_step"
)
...
@@ -298,12 +300,15 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -298,12 +300,15 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
)
)
#logger.info("verify proposals")
#logger.info("verify proposals")
accepted_token_ids
=
self
.
_verify_tokens
(
seq_group_metadata_list
,
accepted_token_ids
,
target_logprobs
=
self
.
_verify_tokens
(
proposal_scores
,
proposals
,
k
)
seq_group_metadata_list
,
proposal_scores
,
proposals
,
k
)
#logger.info("create output list")
#logger.info("create output list")
return
self
.
_create_output_sampler_list
(
seq_group_metadata_list
,
return
self
.
_create_output_sampler_list
(
accepted_token_ids
,
k
)
seq_group_metadata_list
,
accepted_token_ids
,
target_logprobs
=
target_logprobs
,
k
=
k
)
@
nvtx_range
(
"spec_decode_worker._verify_tokens"
)
@
nvtx_range
(
"spec_decode_worker._verify_tokens"
)
def
_verify_tokens
(
def
_verify_tokens
(
...
@@ -312,9 +317,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -312,9 +317,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
proposal_scores
:
SpeculativeScores
,
proposal_scores
:
SpeculativeScores
,
proposals
:
SpeculativeProposals
,
proposals
:
SpeculativeProposals
,
max_proposal_len
:
int
,
max_proposal_len
:
int
,
)
->
torch
.
Tensor
:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
"""Determine which speculative tokens are accepted using the
"""Determine which speculative tokens are accepted using the
probabilities of each token according to the proposer and scorer models.
probabilities of each token according to the proposer and scorer models.
Returns a tuple of Tensors, one for the accepted token ids and one for
the logprobs according to the scoring model.
"""
"""
proposal_lens_list
=
proposals
.
proposal_lens
.
tolist
()
proposal_lens_list
=
proposals
.
proposal_lens
.
tolist
()
...
@@ -361,17 +369,19 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -361,17 +369,19 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
non_spec_token_ids
[:,
1
:]
=
-
1
non_spec_token_ids
[:,
1
:]
=
-
1
accepted_token_ids
=
torch
.
cat
(
accepted_token_ids
=
torch
.
cat
(
[
accepted_token_ids
,
non_spec_token_ids
])
[
accepted_token_ids
,
non_spec_token_ids
])
logprobs
=
proposal_scores
.
logprobs
# Rearrange so that results are in the order of the original seq group
# Rearrange so that results are in the order of the original seq group
# metadata.
# metadata.
accepted_token_ids
[
original_indices
]
=
accepted_token_ids
.
clone
()
accepted_token_ids
[
original_indices
]
=
accepted_token_ids
.
clone
()
return
accepted_token_ids
return
accepted_token_ids
,
logprobs
def
_create_output_sampler_list
(
def
_create_output_sampler_list
(
self
,
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
accepted_token_ids
:
torch
.
Tensor
,
# shape: [batch_size, k+1]
accepted_token_ids
:
torch
.
Tensor
,
# shape: [batch_size, k+1]
target_logprobs
:
torch
.
Tensor
,
# shape: [batch_size, k+1, vocab_size]
k
:
int
,
k
:
int
,
)
->
List
[
SamplerOutput
]:
)
->
List
[
SamplerOutput
]:
"""Given the accepted token ids, create a list of SamplerOutput.
"""Given the accepted token ids, create a list of SamplerOutput.
...
@@ -379,30 +389,68 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -379,30 +389,68 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
The output is padded with -1 tokens such that each sequence has
The output is padded with -1 tokens such that each sequence has
the same number of outputs.
the same number of outputs.
"""
"""
seq_ids
=
get_all_seq_ids
(
seq_group_metadata_list
)
batch_size
,
num_steps
=
accepted_token_ids
.
shape
# Organize input tensors by step instead of by sequence.
target_logprobs_by_step
=
target_logprobs
.
transpose
(
0
,
1
)
accepted_token_ids_by_step
=
accepted_token_ids
.
transpose
(
0
,
1
)
# Get the logprobs/rank of the accepted tokens.
(
accepted_token_id_ranks_by_step
,
accepted_token_id_logprobs_by_step
)
=
get_sampled_token_logprobs
(
logprob_tensor
=
target_logprobs_by_step
,
sampled_token_ids
=
accepted_token_ids_by_step
,
)
# shape: [k+1, batch_size]
# Get the top-k logprobs (which may or may not include the logprob of
accepted_token_ids_by_step
=
accepted_token_ids
.
transpose
(
0
,
# the accepted token).
1
).
tolist
()
(
topk_logprobs_by_step
,
topk_indices_by_step
)
=
target_logprobs_by_step
.
topk
(
k
=
self
.
scorer_worker
.
model_config
.
max_logprobs
,
dim
=-
1
,
)
# Get the sequence ids and num_logprobs (sampling parameter) in the
# batch.
seq_ids
=
get_all_seq_ids
(
seq_group_metadata_list
)
num_logprobs_per_seq
=
get_all_num_logprobs
(
seq_group_metadata_list
)
# Serialize all tensors to CPU Python lists.
accepted_token_ids_by_step
=
accepted_token_ids_by_step
.
tolist
()
accepted_token_id_ranks_by_step
=
(
accepted_token_id_ranks_by_step
.
tolist
())
accepted_token_id_logprobs_by_step
=
(
accepted_token_id_logprobs_by_step
.
tolist
())
topk_logprobs_by_step
=
topk_logprobs_by_step
.
tolist
()
topk_indices_by_step
=
topk_indices_by_step
.
tolist
()
# Construct the output on a per-step, per-sequence basis.
sampler_output_list
=
[]
sampler_output_list
=
[]
for
token_ids_by_step
in
accepted_token_ids_by_step
:
for
step_index
in
range
(
num_steps
):
if
all
(
token_id
==
-
1
for
token_id
in
token_ids_by_step
):
if
all
(
token_id
==
-
1
for
token_id
in
accepted_token_ids_by_step
[
step_index
]):
break
break
step_output_token_ids
=
[]
step_output_token_ids
=
[]
for
token_id
,
seq_id
in
zip
(
token_ids_by_step
,
seq_ids
):
for
sequence_index
in
range
(
batch_size
):
# Each sequence may have a different num_logprobs; retrieve it.
num_logprobs
=
num_logprobs_per_seq
[
sequence_index
]
step_output_token_ids
.
append
(
step_output_token_ids
.
append
(
SequenceGroupOutput
(
create_sequence_group_output
(
samples
=
[
token_id
=
accepted_token_ids_by_step
[
step_index
]
SequenceOutput
(
[
sequence_index
],
parent_seq_id
=
seq_id
,
token_id_logprob_rank
=
accepted_token_id_ranks_by_step
[
output_token
=
token_id
,
step_index
][
sequence_index
],
# TODO Add verifier logprobs.
token_id_logprob
=
accepted_token_id_logprobs_by_step
[
logprobs
=
{
token_id
:
Logprob
(
0.0
)},
step_index
][
sequence_index
],
)
seq_id
=
seq_ids
[
sequence_index
],
],
topk_token_ids
=
topk_indices_by_step
[
step_index
]
prompt_logprobs
=
None
,
[
sequence_index
][:
num_logprobs
],
topk_logprobs
=
topk_logprobs_by_step
[
step_index
]
[
sequence_index
][:
num_logprobs
],
))
))
sampler_output_list
.
append
(
sampler_output_list
.
append
(
SamplerOutput
(
outputs
=
step_output_token_ids
))
SamplerOutput
(
outputs
=
step_output_token_ids
))
...
...
vllm/spec_decode/top1_proposer.py
View file @
ab502751
...
@@ -166,7 +166,7 @@ class Top1Proposer(SpeculativeProposer):
...
@@ -166,7 +166,7 @@ class Top1Proposer(SpeculativeProposer):
return
proposal_tokens
,
proposal_probs
,
proposal_lens_tensor
return
proposal_tokens
,
proposal_probs
,
proposal_lens_tensor
sampler_output
=
maybe_sampler_output
sampler_output
=
maybe_sampler_output
proposal_tokens
,
proposal_probs
=
sampler_output_to_torch
(
proposal_tokens
,
proposal_probs
,
_
=
sampler_output_to_torch
(
sampler_output
,
sampler_transposed
)
sampler_output
,
sampler_transposed
)
# Now, reformat the output GPU tensors such that each sequence has
# Now, reformat the output GPU tensors such that each sequence has
...
...
vllm/spec_decode/util.py
View file @
ab502751
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
itertools
import
chain
from
itertools
import
chain
from
typing
import
List
,
Tuple
from
typing
import
Dict
,
List
,
Tuple
import
torch
import
torch
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
from
vllm.sequence
import
(
Logprob
,
SamplerOutput
,
SequenceGroupMetadata
,
SequenceGroupOutput
,
SequenceOutput
)
SeqId
=
int
SeqId
=
int
...
@@ -21,6 +22,89 @@ def get_all_seq_ids(
...
@@ -21,6 +22,89 @@ def get_all_seq_ids(
]))
]))
def
get_all_num_logprobs
(
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
])
->
List
[
int
]:
"""Given a list of SequenceGroupMetadata, create a list of all num_logprobs.
If the sampling params do not call for any logprobs, return 0 for that
sequence.
"""
all_num_logprobs
=
[]
for
seq_group_metadata
in
seq_group_metadata_list
:
num_logprobs
=
seq_group_metadata
.
sampling_params
.
logprobs
if
seq_group_metadata
.
sampling_params
.
logprobs
is
None
:
num_logprobs
=
0
all_num_logprobs
.
append
(
num_logprobs
)
return
all_num_logprobs
def
get_sampled_token_logprobs
(
# shape [num_steps, batch_size, vocab_size]
logprob_tensor
:
torch
.
Tensor
,
sampled_token_ids
:
torch
.
Tensor
,
# shape [num_steps, batch_size]
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""Get the logprobs for the sampled tokens. Returns the ranks and logprobs.
"""
num_steps
,
batch_size
,
vocab_size
=
logprob_tensor
.
shape
selected_logprobs
=
logprob_tensor
[
torch
.
arange
(
num_steps
).
unsqueeze
(
1
),
torch
.
arange
(
batch_size
),
sampled_token_ids
,
]
expanded_selected_logprobs
=
selected_logprobs
.
unsqueeze
(
-
1
).
expand
(
-
1
,
-
1
,
vocab_size
)
sampled_token_ids_ranks
=
(
logprob_tensor
>=
expanded_selected_logprobs
).
sum
(
-
1
)
return
sampled_token_ids_ranks
,
selected_logprobs
def
create_sequence_group_output
(
token_id
:
int
,
token_id_logprob_rank
:
int
,
token_id_logprob
:
float
,
seq_id
:
SeqId
,
topk_token_ids
:
List
[
int
],
topk_logprobs
:
List
[
float
],
)
->
SequenceGroupOutput
:
"""Create a SequenceGroupOutput given the sampling results.
Args:
token_id (int): The sampled token for the sequence.
token_id_logprob_rank (int): The logprob rank of the sampled token.
token_id_logprob (float): The logprob value of the sampled token.
seq_id (int): The sequence id.
topk_token_ids (List[int]): The list of top-k token ids.
topk_logprobs (List[float]): The list of top-k logprobs.
"""
# vLLM logprobs always include the sampled token. In addition, the user may
# request topk-logprobs (where top-k varies per user up to max_logprobs).
logprobs
:
Dict
[
int
,
Logprob
]
=
{
token_id
:
Logprob
(
logprob
=
token_id_logprob
,
rank
=
token_id_logprob_rank
,
),
}
logprobs
.
update
({
topk_token_ids
[
topk_logprob_index
]:
Logprob
(
logprob
=
topk_logprobs
[
topk_logprob_index
],
rank
=
topk_logprob_index
+
1
,
)
for
topk_logprob_index
,
_
in
enumerate
(
topk_token_ids
)
})
return
SequenceGroupOutput
(
samples
=
[
SequenceOutput
(
parent_seq_id
=
seq_id
,
output_token
=
token_id
,
logprobs
=
logprobs
)
],
# TODO add prompt logprobs support.
prompt_logprobs
=
None
,
)
def
split_batch_by_proposal_len
(
def
split_batch_by_proposal_len
(
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
proposal_lens
:
List
[
int
],
select_proposal_len_zero
:
bool
proposal_lens
:
List
[
int
],
select_proposal_len_zero
:
bool
...
@@ -49,8 +133,8 @@ def split_batch_by_proposal_len(
...
@@ -49,8 +133,8 @@ def split_batch_by_proposal_len(
def
sampler_output_to_torch
(
def
sampler_output_to_torch
(
sampler_output_list
:
List
[
SamplerOutput
],
sampler_output_list
:
List
[
SamplerOutput
],
sampler_transposed
:
bool
sampler_transposed
:
bool
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""Utility function which converts a list of SamplerOutput to tensors.
"""Utility function which converts a list of SamplerOutput to tensors.
sampler_transposed here is used as the indicator for whether
sampler_transposed here is used as the indicator for whether
...
@@ -76,6 +160,15 @@ def sampler_output_to_torch(
...
@@ -76,6 +160,15 @@ def sampler_output_to_torch(
if
sampler_transposed
:
if
sampler_transposed
:
sampled_token_probs
=
sampled_token_probs
.
transpose
(
0
,
1
)
sampled_token_probs
=
sampled_token_probs
.
transpose
(
0
,
1
)
# shape: [batch_size, num_sampler_output, vocab_size]
sampled_token_logprobs
=
torch
.
stack
(
[
sampler_output
.
logprobs
for
sampler_output
in
sampler_output_list
],
dim
=
0
,
)
if
sampler_transposed
:
sampled_token_logprobs
=
sampled_token_logprobs
.
transpose
(
0
,
1
)
# shape: [batch_size, num_sampler_output]
# shape: [batch_size, num_sampler_output]
sampled_token_ids
=
torch
.
stack
(
sampled_token_ids
=
torch
.
stack
(
[
[
...
@@ -87,7 +180,7 @@ def sampler_output_to_torch(
...
@@ -87,7 +180,7 @@ def sampler_output_to_torch(
if
sampler_transposed
:
if
sampler_transposed
:
sampled_token_ids
=
sampled_token_ids
.
transpose
(
0
,
1
)
sampled_token_ids
=
sampled_token_ids
.
transpose
(
0
,
1
)
return
sampled_token_ids
,
sampled_token_probs
return
sampled_token_ids
,
sampled_token_probs
,
sampled_token_logprobs
def
maybe_mock_device_tensors
(
sampler_output
:
SamplerOutput
,
batch_size
:
int
,
def
maybe_mock_device_tensors
(
sampler_output
:
SamplerOutput
,
batch_size
:
int
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment