Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ab502751
"tests/tokenization/test_cached_tokenizer.py" did not exist on "8fa7357f2d3171e3d373be865c8f9520e538c415"
Unverified
Commit
ab502751
authored
May 03, 2024
by
Cade Daniel
Committed by
GitHub
May 03, 2024
Browse files
[Speculative decoding] Support target-model logprobs (#4378)
parent
43c413ec
Changes
15
Show whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
728 additions
and
87 deletions
+728
-87
tests/spec_decode/e2e/conftest.py
tests/spec_decode/e2e/conftest.py
+63
-3
tests/spec_decode/e2e/test_logprobs.py
tests/spec_decode/e2e/test_logprobs.py
+335
-0
tests/spec_decode/e2e/test_multistep_correctness.py
tests/spec_decode/e2e/test_multistep_correctness.py
+47
-16
tests/spec_decode/test_multi_step_worker.py
tests/spec_decode/test_multi_step_worker.py
+8
-0
tests/spec_decode/test_spec_decode_worker.py
tests/spec_decode/test_spec_decode_worker.py
+24
-5
tests/spec_decode/utils.py
tests/spec_decode/utils.py
+2
-0
vllm/engine/output_processor/multi_step.py
vllm/engine/output_processor/multi_step.py
+12
-6
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+9
-7
vllm/sequence.py
vllm/sequence.py
+3
-0
vllm/spec_decode/batch_expansion.py
vllm/spec_decode/batch_expansion.py
+41
-18
vllm/spec_decode/interfaces.py
vllm/spec_decode/interfaces.py
+5
-0
vllm/spec_decode/ngram_worker.py
vllm/spec_decode/ngram_worker.py
+6
-0
vllm/spec_decode/spec_decode_worker.py
vllm/spec_decode/spec_decode_worker.py
+74
-26
vllm/spec_decode/top1_proposer.py
vllm/spec_decode/top1_proposer.py
+1
-1
vllm/spec_decode/util.py
vllm/spec_decode/util.py
+98
-5
No files found.
tests/spec_decode/e2e/conftest.py
View file @
ab502751
import
asyncio
import
time
from
itertools
import
cycle
from
typing
import
List
,
Optional
,
Tuple
,
Union
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Union
import
pytest
import
ray
import
torch
from
pynvml
import
(
nvmlDeviceGetHandleByIndex
,
nvmlDeviceGetMemoryInfo
,
nvmlInit
)
from
tests.conftest
import
cleanup
from
vllm
import
LLM
...
...
@@ -13,7 +17,7 @@ from vllm.lora.request import LoRARequest
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.outputs
import
RequestOutput
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
MultiModalData
from
vllm.sequence
import
Logprob
,
MultiModalData
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
Counter
,
random_uuid
...
...
@@ -153,12 +157,19 @@ def create_llm_generator(baseline_or_test, request, common_llm_kwargs,
test_name
=
request
.
node
.
name
def
generator_inner
():
print
(
f
'Creating
{
baseline_or_test
=
}
LLM for
{
test_name
=
}
.
{
kwargs
=
}
'
)
wait_for_gpu_memory_to_clear
(
devices
=
list
(
range
(
torch
.
cuda
.
device_count
())),
threshold_bytes
=
2
*
2
**
30
,
timeout_s
=
60
,
)
use_async
=
False
if
"use_async"
in
kwargs
:
use_async
=
kwargs
.
pop
(
"use_async"
)
print
(
f
'
{
use_async
=
}
'
)
print
(
f
'Creating
{
baseline_or_test
=
}
LLM for
{
test_name
=
}
.
{
kwargs
=
}
'
)
llm
=
AsyncLLM
(
**
kwargs
)
if
use_async
else
LLM
(
**
kwargs
)
set_random_seed
(
seed
)
...
...
@@ -188,6 +199,20 @@ def get_output_from_llm_generator(
return
tokens
,
token_ids
def get_logprobs_from_llm_generator(
        llm_generator, prompts,
        sampling_params) -> List[List[Dict[int, Logprob]]]:
    """Generate completions and collect the logprobs of every generated token.

    Args:
        llm_generator: Zero-argument callable returning an iterable of LLM
            objects; each one must expose ``generate(prompts,
            sampling_params, use_tqdm=...)``.
        prompts: Prompts forwarded unchanged to ``generate``.
        sampling_params: Sampling parameters forwarded unchanged to
            ``generate``.

    Returns:
        One entry per request output. Each entry is a list with one
        ``{token_id: Logprob}`` dict per generated position, taken from the
        first completion (``outputs[0]``) of that request. If the generator
        yields several LLMs, only the result of the last one is returned.

    Raises:
        RuntimeError: If ``llm_generator()`` yields no LLM at all.
    """
    logprobs = None
    for llm in llm_generator():
        outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
        # Copy each list so the result does not alias the request outputs.
        logprobs = [output.outputs[0].logprobs[:] for output in outputs]
        # Drop our reference to the LLM before asking the generator to
        # continue (presumably so it can be torn down -- confirm with the
        # generator fixture).
        del llm

    if logprobs is None:
        # Previously this surfaced as an UnboundLocalError on `return`.
        raise RuntimeError(
            "llm_generator() did not yield any LLM; no logprobs to return")

    return logprobs
def
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
...
...
@@ -243,3 +268,38 @@ def run_greedy_equality_correctness_test(baseline_llm_generator,
print
(
f
'
{
i
=
}
{
baseline_token_ids
=
}
'
)
print
(
f
'
{
i
=
}
{
spec_token_ids
=
}
'
)
assert
baseline_token_ids
==
spec_token_ids
def wait_for_gpu_memory_to_clear(devices: List[int],
                                 threshold_bytes: int,
                                 timeout_s: float = 120) -> None:
    """Poll GPU memory usage until every device is at or below a threshold.

    Usage is sampled for all ``devices`` and printed, then compared against
    ``threshold_bytes``. If any device is still above the threshold the
    function sleeps five seconds and samples again.

    Args:
        devices: Device indices passed to ``nvmlDeviceGetHandleByIndex``.
        threshold_bytes: Maximum used memory, in bytes, allowed per device.
        timeout_s: Give up once at least this many seconds have elapsed.

    Raises:
        ValueError: If some device is still above the threshold when the
            elapsed time reaches ``timeout_s``.
    """
    # Use nvml instead of pytorch to reduce measurement error from torch cuda
    # context.
    nvmlInit()
    began_at = time.time()

    while True:
        # Used memory per device, in GiB.
        used_gb = {
            idx:
            nvmlDeviceGetMemoryInfo(nvmlDeviceGetHandleByIndex(idx)).used /
            2**30
            for idx in devices
        }

        report = ''.join(f'{idx}={gb:.02f}; ' for idx, gb in used_gb.items())
        print('gpu memory used (GB): ' + report)

        dur_s = time.time() - began_at
        within_budget = all(gb <= (threshold_bytes / 2**30)
                            for gb in used_gb.values())
        if within_budget:
            print(f'Done waiting for free GPU memory on devices {devices=} '
                  f'({threshold_bytes/2**30=}) {dur_s=:.02f}')
            return

        # The timeout is only checked after a measurement, so at least one
        # sample is always taken.
        if dur_s >= timeout_s:
            raise ValueError(f'Memory of devices {devices=} not free after '
                             f'{dur_s=:.02f} ({threshold_bytes/2**30=})')

        time.sleep(5)
tests/spec_decode/e2e/test_logprobs.py
0 → 100644
View file @
ab502751
import
math
from
itertools
import
cycle
import
pytest
from
vllm
import
SamplingParams
from
.conftest
import
get_logprobs_from_llm_generator
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model": "JackFram/llama-68m",
        # Eager mode: no cuda graph recording, which keeps the test fast.
        "enforce_eager": True,
        # Spec decode requires the v2 block manager.
        "use_v2_block_manager": True,
        "max_logprobs": 6,
    }],
)
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [{
        "speculative_model": "JackFram/llama-160m",
        "num_speculative_tokens": 3,
    }],
)
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
    "output_len",
    [
        # A short output keeps the test fast.
        7,
    ],
)
@pytest.mark.parametrize("seed", [1])
def test_logprobs_equality(baseline_llm_generator, test_llm_generator,
                           batch_size: int, output_len: int):
    """Check that spec decode returns the same logprobs as the baseline.

    Uses the default logprob rank of the shared greedy comparison helper.
    """
    run_greedy_logprobs_correctness_test(
        baseline_llm_generator,
        test_llm_generator,
        batch_size,
        max_output_len=output_len,
        force_output_len=True,
    )
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model": "JackFram/llama-68m",
        # Eager mode: no cuda graph recording, which keeps the test fast.
        "enforce_eager": True,
        # Spec decode requires the v2 block manager.
        "use_v2_block_manager": True,
        "max_logprobs": 6,
    }],
)
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [{
        "speculative_model": "JackFram/llama-160m",
        "num_speculative_tokens": 3,
    }],
)
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize("num_logprobs", [6])
@pytest.mark.parametrize(
    "output_len",
    [
        # A short output keeps the test fast.
        7,
    ],
)
@pytest.mark.parametrize("seed", [1])
def test_diff_num_logprobs(baseline_llm_generator, test_llm_generator,
                           batch_size: int, output_len: int,
                           num_logprobs: int):
    """Check logprob equality with and without spec decode for >1 logprobs.

    The requested number of logprobs per position is ``num_logprobs``
    instead of the helper's default of one.
    """
    run_greedy_logprobs_correctness_test(
        baseline_llm_generator,
        test_llm_generator,
        batch_size,
        max_output_len=output_len,
        force_output_len=True,
        logprob_rank=num_logprobs,
    )
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model": "JackFram/llama-68m",
        # Eager mode: no cuda graph recording, which keeps the test fast.
        "enforce_eager": True,
        # Spec decode requires the v2 block manager.
        "use_v2_block_manager": True
    }],
)
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [
        # Two speculation lengths (k) are exercised: 3 and 6.
        {
            "speculative_model": "JackFram/llama-160m",
            "num_speculative_tokens": 3,
        },
        {
            "speculative_model": "JackFram/llama-160m",
            "num_speculative_tokens": 6,
        },
    ],
)
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
    "output_len",
    [
        # A short output keeps the test fast.
        32,
    ],
)
@pytest.mark.parametrize("seed", [1])
def test_logprobs_different_k(baseline_llm_generator, test_llm_generator,
                              batch_size: int, output_len: int):
    """Verify greedy logprob equality for different speculation lengths.
    """
    run_greedy_logprobs_correctness_test(
        baseline_llm_generator,
        test_llm_generator,
        batch_size,
        max_output_len=output_len,
        force_output_len=True,
    )
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model": "JackFram/llama-68m",
        # Eager mode: no cuda graph recording, which keeps the test fast.
        "enforce_eager": True,
        # Spec decode requires the v2 block manager.
        "use_v2_block_manager": True
    }],
)
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [{
        "speculative_model": "JackFram/llama-160m",
        "num_speculative_tokens": 3,
        # Cap the draft model's max model len artificially so that vLLM
        # has to skip speculation once sequences grow beyond 32-k tokens.
        "speculative_max_model_len": 32,
    }],
)
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
    "output_len",
    [
        # A short output keeps the test fast.
        32,
    ],
)
@pytest.mark.parametrize("seed", [1])
def test_logprobs_when_skip_speculation(baseline_llm_generator,
                                        test_llm_generator, batch_size: int,
                                        output_len: int):
    """Verify greedy logprob equality when some sequences skip speculation.
    """
    run_greedy_logprobs_correctness_test(
        baseline_llm_generator,
        test_llm_generator,
        batch_size,
        max_output_len=output_len,
        force_output_len=True,
    )
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model": "JackFram/llama-68m",
        # Eager mode: no cuda graph recording, which keeps the test fast.
        "enforce_eager": True,
        # Spec decode requires the v2 block manager.
        "use_v2_block_manager": True
    }],
)
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [{
        "speculative_model": "JackFram/llama-160m",
        "num_speculative_tokens": 3,
    }],
)
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize(
    "output_len",
    [
        # A short output keeps the test fast.
        32,
    ],
)
@pytest.mark.parametrize("seed", [1])
def test_logprobs_temp_1(baseline_llm_generator, test_llm_generator,
                         batch_size: int, output_len: int):
    """Check that some position returns more logprobs than were requested.

    Sampling at temperature 1.0 can pick a token outside the top-k, in
    which case that position carries num_logprobs+1 entries. Only the spec
    decode LLM is run; comparing against non-spec logprobs is left as a
    future improvement.
    """
    # NOTE: the parametrized batch_size is overridden here, so eight prompts
    # are always used regardless of the decorator value.
    batch_size = 8
    logprob_rank = 5

    seed_prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
        "San Francisco is know for its",
        "Facebook was created in 2004 by",
        "Curious George is a",
        "Python 3.11 brings improvements to its",
    ]
    # Repeat the seed prompts cyclically until batch_size prompts exist.
    prompts = [
        text for _, text in zip(range(batch_size), cycle(seed_prompts))
    ]

    # The output length is forced, so eos is ignored and generation always
    # runs for output_len tokens.
    sampling_params = SamplingParams(
        max_tokens=output_len,
        ignore_eos=True,
        temperature=1.0,
        logprobs=logprob_rank,
    )

    spec_batch_logprobs = get_logprobs_from_llm_generator(
        test_llm_generator, prompts, sampling_params)

    # At least one generated position must hold more than logprob_rank
    # entries (indicating the sampled token is not in top-k).
    assert any(
        len(position_logprobs) > logprob_rank
        for sequence_logprobs in spec_batch_logprobs
        for position_logprobs in sequence_logprobs)
def run_greedy_logprobs_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len,
                                         force_output_len: bool,
                                         logprob_rank: int = 1):
    """Compare per-position logprobs of the test LLM against the baseline.

    Both LLMs generate greedily (temperature zero) from the same prompts.
    For every sequence and every generated position, the set of returned
    ranks must match, the token id at each rank must match, and the logprob
    at each rank must agree within an absolute tolerance of 0.1.

    Args:
        baseline_llm_generator: Generator callable for the reference LLM.
        test_llm_generator: Generator callable for the LLM under test.
        batch_size: Number of prompts to run.
        max_output_len: Value used for ``max_tokens``.
        force_output_len: When true, eos is ignored so that exactly
            ``max_output_len`` tokens are generated.
        logprob_rank: Number of logprobs requested per position.
    """

    def index_by_rank(position_logprobs):
        # Re-key one position's {token_id: Logprob} dict by rank.
        token_id_at, logprob_at = {}, {}
        for token_id, entry in position_logprobs.items():
            token_id_at[entry.rank] = token_id
            logprob_at[entry.rank] = entry.logprob
        return token_id_at, logprob_at

    seed_prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
        "San Francisco is know for its",
        "Facebook was created in 2004 by",
        "Curious George is a",
        "Python 3.11 brings improvements to its",
    ]
    # Repeat the seed prompts cyclically until batch_size prompts exist.
    prompts = [
        text for _, text in zip(range(batch_size), cycle(seed_prompts))
    ]

    # If the caller requires max_output_len generated tokens, the eos token
    # is ignored.
    sampling_params = SamplingParams(
        max_tokens=max_output_len,
        ignore_eos=force_output_len,
        temperature=0.0,
        logprobs=logprob_rank,
    )

    # The test LLM runs first, then the baseline.
    spec_batch_logprobs = get_logprobs_from_llm_generator(
        test_llm_generator, prompts, sampling_params)
    baseline_batch_logprobs = get_logprobs_from_llm_generator(
        baseline_llm_generator, prompts, sampling_params)

    assert len(baseline_batch_logprobs) == len(prompts)
    assert len(spec_batch_logprobs) == len(prompts)

    # Walk the batch sequence by sequence.
    for baseline_logprobs, spec_logprobs in zip(baseline_batch_logprobs,
                                                spec_batch_logprobs):
        assert len(spec_logprobs) == len(baseline_logprobs)

        # Walk the generated positions of this sequence.
        for spec_pos_logprobs, baseline_pos_logprobs in zip(
                spec_logprobs, baseline_logprobs):
            spec_token_id_at, spec_logprob_at = index_by_rank(
                spec_pos_logprobs)
            baseline_token_id_at, baseline_logprob_at = index_by_rank(
                baseline_pos_logprobs)

            # Both outputs must report exactly the same ranks.
            assert set(spec_token_id_at.keys()) == set(
                baseline_token_id_at.keys())

            # Token id and logprob must agree at every rank.
            for rank in sorted(set(spec_token_id_at.keys())):
                assert spec_token_id_at[rank] == baseline_token_id_at[
                    rank], f"{rank}"
                assert math.isclose(
                    spec_logprob_at[rank],
                    baseline_logprob_at[rank],
                    abs_tol=1e-1,
                )
tests/spec_decode/e2e/test_multistep_correctness.py
View file @
ab502751
...
...
@@ -41,8 +41,7 @@ from .conftest import (get_output_from_llm_generator,
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[
{
[{
# Use a small model for a fast test.
# Note this is repeated in the test body; to initialize a tokenizer.
"model"
:
"JackFram/llama-68m"
,
...
...
@@ -52,13 +51,7 @@ from .conftest import (get_output_from_llm_generator,
# Required for spec decode.
"use_v2_block_manager"
:
True
,
# whether use AsyncLLM engine
"use_async"
:
async_mode
,
}
# Try both async and sync engine execution
for
async_mode
in
[
True
,
False
]
])
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[
...
...
@@ -117,6 +110,44 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator,
assert
actual_tokens
.
strip
()
==
expected_tokens
.
strip
()
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # A small model keeps the test fast.
        # Note this is repeated in the test body; to initialize a tokenizer.
        "model": "JackFram/llama-68m",
        # Eager mode: no cuda graph recording, which keeps the test fast.
        "enforce_eager": True,
        # Spec decode requires the v2 block manager.
        "use_v2_block_manager": True,
        # Run through the AsyncLLM engine.
        "use_async": True,
    }],
)
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "per_test_common_llm_kwargs",
    [
        {
            "speculative_model": "JackFram/llama-68m",
            "num_speculative_tokens": 5,
        },
    ],
)
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_e2e_with_async_engine(test_llm_generator,
                                           baseline_llm_generator,
                                           batch_size: int):
    """Check that spec decode works with the async LLM engine.

    Runs the greedy equality comparison for 32 forced output tokens.
    """
    run_greedy_equality_correctness_test(
        baseline_llm_generator,
        test_llm_generator,
        batch_size,
        max_output_len=32,
        force_output_len=True,
    )
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
...
...
tests/spec_decode/test_multi_step_worker.py
View file @
ab502751
...
...
@@ -292,6 +292,10 @@ def test_draft_proposals_full_speculation_len():
vocab_size
,
device
=
device
,
dtype
=
torch
.
float32
),
logprobs
=
torch
.
rand
(
batch_size
,
vocab_size
,
device
=
device
,
dtype
=
torch
.
float32
),
sampled_token_ids
=
torch
.
randint
(
low
=
0
,
high
=
vocab_size
,
size
=
(
batch_size
,
),
...
...
@@ -392,6 +396,10 @@ def test_draft_proposals_mixed_k():
vocab_size
,
device
=
device
,
dtype
=
torch
.
float32
),
logprobs
=
torch
.
rand
(
expected_num_proposal_seqs
,
vocab_size
,
device
=
device
,
dtype
=
torch
.
float32
),
sampled_token_ids
=
torch
.
randint
(
low
=
0
,
high
=
vocab_size
,
...
...
tests/spec_decode/test_spec_decode_worker.py
View file @
ab502751
...
...
@@ -192,8 +192,14 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
vocab_size
,
dtype
=
torch
.
float32
,
device
=
'cuda'
)
target_token_logprobs
=
torch
.
rand
(
1
,
batch_size
*
(
k
+
1
),
vocab_size
,
dtype
=
torch
.
float32
,
device
=
'cuda'
)
target_output
=
create_sampler_output_list
(
target_token_ids
,
target_token_probs
)
target_token_probs
,
target_token_logprobs
)
target_worker
.
execute_model
.
return_value
=
[
target_output
[
0
]]
...
...
@@ -273,8 +279,14 @@ def test_correctly_formats_output(k: int, batch_size: int):
vocab_size
,
dtype
=
torch
.
float32
,
device
=
'cuda'
)
target_token_logprobs
=
torch
.
rand
(
1
,
batch_size
*
(
k
+
1
),
vocab_size
,
dtype
=
torch
.
float32
,
device
=
'cuda'
)
target_output
=
create_sampler_output_list
(
target_token_ids
,
target_token_probs
)
target_token_probs
,
target_token_logprobs
)
target_worker
.
execute_model
.
return_value
=
[
target_output
[
0
]]
...
...
@@ -294,7 +306,9 @@ def test_correctly_formats_output(k: int, batch_size: int):
num_lookahead_slots
=
k
)
expected_output
=
create_sampler_output_list
(
rejection_sampler_output
.
transpose
(
0
,
1
),
[
None
for
_
in
range
(
k
+
1
)])
token_ids
=
rejection_sampler_output
.
transpose
(
0
,
1
),
probs
=
[
None
for
_
in
range
(
k
+
1
)],
logprobs
=
[
None
for
_
in
range
(
k
+
1
)])
seq_ids
=
[
next
(
iter
(
seq_group_metadata
.
seq_data
.
keys
()))
...
...
@@ -328,7 +342,6 @@ def test_correctly_formats_output(k: int, batch_size: int):
continue
assert
actual_by_step
[
i
].
output_token
==
expected_by_step
[
i
].
output_token
assert
actual_by_step
[
i
].
logprobs
==
expected_by_step
[
i
].
logprobs
@
pytest
.
mark
.
parametrize
(
'k'
,
[
1
,
2
])
...
...
@@ -387,8 +400,14 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
vocab_size
,
dtype
=
torch
.
float32
,
device
=
'cuda'
)
target_token_logprobs
=
torch
.
rand
(
1
,
batch_size
*
(
k
+
1
),
vocab_size
,
dtype
=
torch
.
float32
,
device
=
'cuda'
)
target_output
=
create_sampler_output_list
(
target_token_ids
,
target_token_probs
)
target_token_probs
,
target_token_logprobs
)
target_worker
.
execute_model
.
return_value
=
[
target_output
[
0
]]
...
...
tests/spec_decode/utils.py
View file @
ab502751
...
...
@@ -201,6 +201,7 @@ def assert_logprobs_dict_allclose(
def
create_sampler_output_list
(
token_ids
:
torch
.
Tensor
,
probs
:
Iterable
[
Optional
[
torch
.
Tensor
]],
logprobs
:
Iterable
[
Optional
[
torch
.
Tensor
]],
seq_ids
:
Optional
[
List
[
int
]]
=
None
)
->
List
[
SamplerOutput
]:
num_steps
,
batch_size
=
token_ids
.
shape
token_ids_by_step
=
token_ids
.
tolist
()
...
...
@@ -222,6 +223,7 @@ def create_sampler_output_list(
)
for
seq_index
,
token_id
in
enumerate
(
token_ids_by_step
[
step
])
],
sampled_token_probs
=
probs
[
step
],
logprobs
=
logprobs
[
step
],
sampled_token_ids
=
token_ids
[
step
])
for
step
in
range
(
num_steps
)
]
...
...
vllm/engine/output_processor/multi_step.py
View file @
ab502751
import
functools
from
typing
import
Callable
,
List
from
transformers
import
PreTrainedTokenizer
...
...
@@ -8,8 +9,8 @@ from vllm.engine.output_processor.interfaces import (
from
vllm.engine.output_processor.stop_checker
import
StopChecker
from
vllm.logger
import
init_logger
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
(
Logprob
,
Sequence
,
SequenceGroup
,
SequenceGroupOutput
,
SequenceOutput
,
SequenceStatus
)
from
vllm.sequence
import
(
Sequence
,
SequenceGroup
,
SequenceGroupOutput
,
SequenceOutput
,
SequenceStatus
)
from
vllm.transformers_utils.detokenizer
import
Detokenizer
from
vllm.utils
import
Counter
...
...
@@ -48,10 +49,14 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
outputs
:
List
[
SequenceGroupOutput
])
->
None
:
# TODO(sang): Prompt logprob currently not implemented in multi step
# workers.
self
.
_log_prompt_logprob_unsupported_warning_once
()
@staticmethod
@functools.lru_cache()
def _log_prompt_logprob_unsupported_warning_once():
    """Emit the "prompt logprob unsupported" warning.

    The function takes no arguments and is wrapped in ``lru_cache``, so the
    body runs on the first call only; later calls return the cached result
    without logging again.
    """
    message = ("Prompt logprob is not supported by multi step workers. "
               "(e.g., speculative decode uses multi step workers).")
    logger.warning(message)
pass
def
process_outputs
(
self
,
sequence_group
:
SequenceGroup
,
outputs
:
List
[
SequenceGroupOutput
])
->
None
:
...
...
@@ -89,6 +94,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
valid_samples
:
List
[
SequenceOutput
],
sampling_params
:
SamplingParams
)
->
None
:
output_token_ids
=
[
sample
.
output_token
for
sample
in
valid_samples
]
output_logprobs
=
[
sample
.
logprobs
for
sample
in
valid_samples
]
# Truncate to max_tokens if necessary.
remaining_tokens
=
sampling_params
.
max_tokens
-
(
seq
.
get_output_len
()
+
...
...
@@ -113,11 +119,11 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
# Incrementally append tokens to the sequence, as if we had only one new
# token.
for
output_token_id
in
output_token_ids
:
for
output_token_id
,
output_logprob
in
zip
(
output_token_ids
,
output_logprobs
):
seq
.
append_token_id
(
token_id
=
output_token_id
,
# TODO emit logprobs in multi-step decoding.
logprobs
=
{
output_token_id
:
Logprob
(
0.0
)},
logprobs
=
output_logprob
,
)
new_char_count
=
0
...
...
vllm/model_executor/layers/sampler.py
View file @
ab502751
...
...
@@ -103,8 +103,7 @@ class Sampler(nn.Module):
if
self
.
include_gpu_probs_tensor
:
assert
maybe_sampled_tokens_tensor
is
not
None
sampled_tokens_tensor
=
maybe_sampled_tokens_tensor
on_device_tensors
=
(
probs
,
sampled_tokens_tensor
)
on_device_tensors
=
(
probs
,
logprobs
,
maybe_sampled_tokens_tensor
)
else
:
on_device_tensors
=
None
...
...
@@ -965,8 +964,7 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
has implications on the overall design of the sampler, e.g. how to record
accurate logprobs for the user, so this improvement is deferred to later.
"""
logprobs
[
sample_indices
,
:]
=
-
float
(
'inf'
)
logprobs
[
sample_indices
,
greedy_samples
]
=
0.0
# NOTE: logprobs are not modified so they can be returned to the user.
probs
[
sample_indices
,
:]
=
0
probs
[
sample_indices
,
greedy_samples
]
=
1.0
...
...
@@ -976,7 +974,8 @@ def _build_sampler_output(
sampling_metadata
:
SamplingMetadata
,
prompt_logprobs
:
List
[
Optional
[
PromptLogprobs
]],
sample_logprobs
:
List
[
SampleLogprobs
],
on_device_tensors
:
Optional
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
on_device_tensors
:
Optional
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]],
)
->
SamplerOutput
:
"""Construct Python objects with the output of sampling.
...
...
@@ -1005,14 +1004,17 @@ def _build_sampler_output(
# If not specified, store None values in SamplerOutput.
if
on_device_tensors
is
not
None
:
sampled_token_probs
,
sampled_token_ids
=
on_device_tensors
(
sampled_token_probs
,
logprobs_tensor
,
sampled_token_ids
)
=
on_device_tensors
else
:
sampled_token_probs
,
sampled_token_ids
=
(
None
,
None
)
sampled_token_probs
,
logprobs_tensor
,
sampled_token_ids
=
(
None
,
None
,
None
)
return
SamplerOutput
(
outputs
=
sampler_output
,
sampled_token_probs
=
sampled_token_probs
,
sampled_token_ids
=
sampled_token_ids
,
logprobs
=
logprobs_tensor
,
)
...
...
vllm/sequence.py
View file @
ab502751
...
...
@@ -700,6 +700,9 @@ class SamplerOutput:
# On-device tensor containing probabilities of each token.
sampled_token_probs
:
Optional
[
"torch.Tensor"
]
=
None
# On-device tensor containing the logprobs of each token.
logprobs
:
Optional
[
"torch.Tensor"
]
=
None
# On-device tensor containing the sampled token ids.
sampled_token_ids
:
Optional
[
"torch.Tensor"
]
=
None
...
...
vllm/spec_decode/batch_expansion.py
View file @
ab502751
...
...
@@ -94,7 +94,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
assert
len
(
target_sampler_output
)
==
1
,
"expected single-step output"
target_sampler_output
=
target_sampler_output
[
0
]
all_tokens
,
all_probs
=
self
.
_contract_batch
(
all_tokens
,
all_probs
,
spec_logprobs
=
self
.
_contract_batch
(
contracted_bs
=
len
(
seq_group_metadata_list
),
target_sampler_output
=
target_sampler_output
,
proposals
=
proposals
,
...
...
@@ -107,6 +107,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
return
SpeculativeScores
(
probs
=
all_probs
,
token_ids
=
all_tokens
,
logprobs
=
spec_logprobs
,
)
def
_expand_batch
(
...
...
@@ -148,12 +149,12 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
return
(
spec_indices
,
non_spec_indices
,
target_seq_group_metadata_list
,
num_scoring_tokens
)
def
_contract_batch
(
self
,
contracted_bs
:
int
,
def
_contract_batch
(
self
,
contracted_bs
:
int
,
target_sampler_output
:
List
[
SamplerOutput
],
proposals
:
SpeculativeProposals
,
num_scoring_tokens
:
int
,
non_spec_indices
:
List
[
int
],
spec_indices
:
List
[
int
],
k
:
int
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
proposals
:
SpeculativeProposals
,
num_scoring_tokens
:
int
,
non_spec_indices
:
List
[
int
],
spec_indices
:
List
[
int
],
k
:
int
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""Contract the expanded batch back into its original size.
This maps the scores of speculative tokens back to their original
sequences.
...
...
@@ -161,8 +162,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
contracted_bs is the original batch size, and the batch size that the
target_sampler_output will be contracted to.
"""
(
target_token_ids
,
target_probs
,
non_spec_target_token_ids
,
non_spec_target_probs
)
=
self
.
_split_scoring_output
(
(
target_token_ids
,
target_probs
,
target_logprobs
,
non_spec_target_token_ids
,
non_spec_target_probs
,
non_spec_target_logprobs
)
=
self
.
_split_scoring_output
(
target_sampler_output
,
num_scoring_tokens
)
# Map distinct sequences used to score each token
...
...
@@ -179,6 +181,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
spec_expanded_bs
,
k
+
1
)
target_probs
=
target_probs
.
squeeze
().
reshape
(
spec_expanded_bs
,
k
+
1
,
self
.
_vocab_size
)
target_logprobs
=
target_logprobs
.
squeeze
().
reshape
(
spec_expanded_bs
,
k
+
1
,
self
.
_vocab_size
)
all_tokens
=
torch
.
full
(
size
=
(
contracted_bs
,
k
+
1
),
fill_value
=-
1
,
...
...
@@ -189,16 +193,26 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
self
.
_vocab_size
,
device
=
self
.
_device
,
dtype
=
torch
.
float32
)
all_logprobs
=
torch
.
full
(
size
=
(
contracted_bs
,
k
+
1
,
self
.
_vocab_size
,
),
fill_value
=-
float
(
"inf"
),
device
=
self
.
_device
,
dtype
=
torch
.
float32
)
if
non_spec_indices
:
all_tokens
[
non_spec_indices
,
:
1
]
=
non_spec_target_token_ids
all_probs
[
non_spec_indices
,
:
1
,
:]
=
non_spec_target_probs
all_logprobs
[
non_spec_indices
,
:
1
,
:]
=
non_spec_target_logprobs
if
spec_indices
:
all_tokens
[
spec_indices
]
=
target_token_ids
all_probs
[
spec_indices
]
=
target_probs
all_logprobs
[
spec_indices
]
=
target_logprobs
return
all_tokens
,
all_probs
return
all_tokens
,
all_probs
,
all_logprobs
def
_create_scoring_model_input
(
self
,
...
...
@@ -308,7 +322,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
def
_split_scoring_output
(
self
,
sampler_output
:
SamplerOutput
,
num_scoring_tokens
:
int
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""Split the target model output into speculative and non-speculative
output.
"""
...
...
@@ -328,21 +343,29 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
)
=
sampler_output
.
sampled_token_probs
.
split
(
split_sizes
)
(
spec_sampled_tokens
,
non_spec_sampled_tokens
)
=
sampler_output
.
sampled_token_ids
.
flatten
().
split
(
split_sizes
)
(
spec_logprobs
,
non_spec_logprobs
,
)
=
sampler_output
.
logprobs
.
split
(
split_sizes
)
# Convert scores to tensors.
sampler_output
.
sampled_token_probs
=
spec_probs
sampler_output
.
sampled_token_ids
=
spec_sampled_tokens
target_token_ids
,
target_probs
=
sampler_output_to_torch
(
[
sampler_output
],
True
)
sampler_output
.
logprobs
=
spec_logprobs
(
target_token_ids
,
target_probs
,
target_logprobs
)
=
sampler_output_to_torch
([
sampler_output
],
True
)
# Convert non-speculative output tokens to tensors.
sampler_output
.
sampled_token_probs
=
non_spec_probs
sampler_output
.
sampled_token_ids
=
non_spec_sampled_tokens
non_spec_target_token_ids
,
non_spec_target_probs
=
(
sampler_output_to_torch
([
sampler_output
],
True
))
return
(
target_token_ids
,
target_probs
,
non_spec_target_token_ids
,
non_spec_target_probs
)
sampler_output
.
logprobs
=
non_spec_logprobs
(
non_spec_target_token_ids
,
non_spec_target_probs
,
non_spec_target_logprobs
)
=
sampler_output_to_torch
([
sampler_output
],
True
)
return
(
target_token_ids
,
target_probs
,
target_logprobs
,
non_spec_target_token_ids
,
non_spec_target_probs
,
non_spec_target_logprobs
)
def
_create_target_seq_id_iterator
(
self
,
seq_ids
:
List
[
SeqId
])
->
Iterator
[
TargetSeqId
]:
...
...
vllm/spec_decode/interfaces.py
View file @
ab502751
...
...
@@ -38,6 +38,11 @@ class SpeculativeScores:
# Probabilities of the speculative tokens according to the scoring model.
probs
:
torch
.
Tensor
# Log-probabilities of the speculative tokens according to the scoring
# model. These values can be used to generate Logprob objects that are
# returned to the user.
logprobs
:
torch
.
Tensor
# Token ids sampled from the scoring model. Used for speculative bonus
# tokens and also non-speculative normal decoding.
token_ids
:
torch
.
Tensor
...
...
vllm/spec_decode/ngram_worker.py
View file @
ab502751
...
...
@@ -140,11 +140,17 @@ class NGramWorker(LoraNotSupportedWorkerBase):
device
=
self
.
device
,
)
token_probs
.
scatter_
(
2
,
indices
,
1
)
token_logprobs
=
torch
.
zeros
(
(
len
(
seq_group_metadata_list
),
sample_len
,
self
.
vocab_size
),
dtype
=
torch
.
float32
,
device
=
self
.
device
,
)
for
i
in
range
(
len
(
seq_group_metadata_list
)):
outputs
.
append
(
SamplerOutput
(
outputs
=
None
,
sampled_token_probs
=
token_probs
[
i
],
logprobs
=
token_logprobs
,
sampled_token_ids
=
token_ids
[
i
],
))
return
outputs
,
False
...
...
vllm/spec_decode/spec_decode_worker.py
View file @
ab502751
...
...
@@ -5,15 +5,16 @@ import torch
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.rejection_sampler
import
RejectionSampler
from
vllm.sequence
import
(
Logprob
,
SamplerOutput
,
SequenceGroupMetadata
,
SequenceGroupOutput
,
SequenceOutput
)
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
from
vllm.spec_decode.batch_expansion
import
BatchExpansionTop1Scorer
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeScorer
,
SpeculativeScores
)
from
vllm.spec_decode.metrics
import
AsyncMetricsCollector
from
vllm.spec_decode.multi_step_worker
import
MultiStepWorker
from
vllm.spec_decode.ngram_worker
import
NGramWorker
from
vllm.spec_decode.util
import
(
get_all_seq_ids
,
nvtx_range
,
from
vllm.spec_decode.util
import
(
create_sequence_group_output
,
get_all_num_logprobs
,
get_all_seq_ids
,
get_sampled_token_logprobs
,
nvtx_range
,
split_batch_by_proposal_len
)
from
vllm.worker.worker_base
import
LoraNotSupportedWorkerBase
,
WorkerBase
...
...
@@ -258,6 +259,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# overhead when the engine runs in a different process than the workers.
sampler_output
.
probs
=
None
sampler_output
.
sampled_tokens
=
None
sampler_output
.
logprobs
=
None
return
[
sampler_output
]
@
nvtx_range
(
"spec_decode_worker._run_speculative_decoding_step"
)
...
...
@@ -298,12 +300,15 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
)
#logger.info("verify proposals")
accepted_token_ids
=
self
.
_verify_tokens
(
seq_group_metadata_list
,
proposal_scores
,
proposals
,
k
)
accepted_token_ids
,
target_logprobs
=
self
.
_verify_tokens
(
seq_group_metadata_list
,
proposal_scores
,
proposals
,
k
)
#logger.info("create output list")
return
self
.
_create_output_sampler_list
(
seq_group_metadata_list
,
accepted_token_ids
,
k
)
return
self
.
_create_output_sampler_list
(
seq_group_metadata_list
,
accepted_token_ids
,
target_logprobs
=
target_logprobs
,
k
=
k
)
@
nvtx_range
(
"spec_decode_worker._verify_tokens"
)
def
_verify_tokens
(
...
...
@@ -312,9 +317,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
proposal_scores
:
SpeculativeScores
,
proposals
:
SpeculativeProposals
,
max_proposal_len
:
int
,
)
->
torch
.
Tensor
:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
"""Determine which speculative tokens are accepted using the
probabilities of each token according to the proposer and scorer models.
Returns a tuple of Tensors, one for the accepted token ids and one for
the logprobs according to the scoring model.
"""
proposal_lens_list
=
proposals
.
proposal_lens
.
tolist
()
...
...
@@ -361,17 +369,19 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
non_spec_token_ids
[:,
1
:]
=
-
1
accepted_token_ids
=
torch
.
cat
(
[
accepted_token_ids
,
non_spec_token_ids
])
logprobs
=
proposal_scores
.
logprobs
# Rearrange so that results are in the order of the original seq group
# metadata.
accepted_token_ids
[
original_indices
]
=
accepted_token_ids
.
clone
()
return
accepted_token_ids
return
accepted_token_ids
,
logprobs
def
_create_output_sampler_list
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
accepted_token_ids
:
torch
.
Tensor
,
# shape: [batch_size, k+1]
target_logprobs
:
torch
.
Tensor
,
# shape: [batch_size, k+1, vocab_size]
k
:
int
,
)
->
List
[
SamplerOutput
]:
"""Given the accepted token ids, create a list of SamplerOutput.
...
...
@@ -379,30 +389,68 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
The output is padded with -1 tokens such that each sequence has
the same number of outputs.
"""
seq_ids
=
get_all_seq_ids
(
seq_group_metadata_list
)
batch_size
,
num_steps
=
accepted_token_ids
.
shape
# Organize input tensors by step instead of by sequence.
target_logprobs_by_step
=
target_logprobs
.
transpose
(
0
,
1
)
accepted_token_ids_by_step
=
accepted_token_ids
.
transpose
(
0
,
1
)
# Get the logprobs/rank of the accepted tokens.
(
accepted_token_id_ranks_by_step
,
accepted_token_id_logprobs_by_step
)
=
get_sampled_token_logprobs
(
logprob_tensor
=
target_logprobs_by_step
,
sampled_token_ids
=
accepted_token_ids_by_step
,
)
# shape: [k+1, batch_size]
accepted_token_ids_by_step
=
accepted_token_ids
.
transpose
(
0
,
1
).
tolist
()
# Get the top-k logprobs (which may or may not include the logprob of
# the accepted token).
(
topk_logprobs_by_step
,
topk_indices_by_step
)
=
target_logprobs_by_step
.
topk
(
k
=
self
.
scorer_worker
.
model_config
.
max_logprobs
,
dim
=-
1
,
)
# Get the sequence ids and num_logprobs (sampling parameter) in the
# batch.
seq_ids
=
get_all_seq_ids
(
seq_group_metadata_list
)
num_logprobs_per_seq
=
get_all_num_logprobs
(
seq_group_metadata_list
)
# Serialize all tensors to CPU Python lists.
accepted_token_ids_by_step
=
accepted_token_ids_by_step
.
tolist
()
accepted_token_id_ranks_by_step
=
(
accepted_token_id_ranks_by_step
.
tolist
())
accepted_token_id_logprobs_by_step
=
(
accepted_token_id_logprobs_by_step
.
tolist
())
topk_logprobs_by_step
=
topk_logprobs_by_step
.
tolist
()
topk_indices_by_step
=
topk_indices_by_step
.
tolist
()
# Construct the output on a per-step, per-sequence basis.
sampler_output_list
=
[]
for
token_ids_by_step
in
accepted_token_ids_by_step
:
if
all
(
token_id
==
-
1
for
token_id
in
token_ids_by_step
):
for
step_index
in
range
(
num_steps
):
if
all
(
token_id
==
-
1
for
token_id
in
accepted_token_ids_by_step
[
step_index
]):
break
step_output_token_ids
=
[]
for
token_id
,
seq_id
in
zip
(
token_ids_by_step
,
seq_ids
):
for
sequence_index
in
range
(
batch_size
):
# Each sequence may have a different num_logprobs; retrieve it.
num_logprobs
=
num_logprobs_per_seq
[
sequence_index
]
step_output_token_ids
.
append
(
SequenceGroupOutput
(
samples
=
[
SequenceOutput
(
parent_seq_id
=
seq_id
,
output_token
=
token_id
,
# TODO Add verifier logprobs.
logprobs
=
{
token_id
:
Logprob
(
0.0
)},
)
],
prompt_logprobs
=
None
,
create_sequence_group_output
(
token_id
=
accepted_token_ids_by_step
[
step_index
]
[
sequence_index
],
token_id_logprob_rank
=
accepted_token_id_ranks_by_step
[
step_index
][
sequence_index
],
token_id_logprob
=
accepted_token_id_logprobs_by_step
[
step_index
][
sequence_index
],
seq_id
=
seq_ids
[
sequence_index
],
topk_token_ids
=
topk_indices_by_step
[
step_index
]
[
sequence_index
][:
num_logprobs
],
topk_logprobs
=
topk_logprobs_by_step
[
step_index
]
[
sequence_index
][:
num_logprobs
],
))
sampler_output_list
.
append
(
SamplerOutput
(
outputs
=
step_output_token_ids
))
...
...
vllm/spec_decode/top1_proposer.py
View file @
ab502751
...
...
@@ -166,7 +166,7 @@ class Top1Proposer(SpeculativeProposer):
return
proposal_tokens
,
proposal_probs
,
proposal_lens_tensor
sampler_output
=
maybe_sampler_output
proposal_tokens
,
proposal_probs
=
sampler_output_to_torch
(
proposal_tokens
,
proposal_probs
,
_
=
sampler_output_to_torch
(
sampler_output
,
sampler_transposed
)
# Now, reformat the output GPU tensors such that each sequence has
...
...
vllm/spec_decode/util.py
View file @
ab502751
from
contextlib
import
contextmanager
from
itertools
import
chain
from
typing
import
List
,
Tuple
from
typing
import
Dict
,
List
,
Tuple
import
torch
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
from
vllm.sequence
import
(
Logprob
,
SamplerOutput
,
SequenceGroupMetadata
,
SequenceGroupOutput
,
SequenceOutput
)
SeqId
=
int
...
...
@@ -21,6 +22,89 @@ def get_all_seq_ids(
]))
def get_all_num_logprobs(
        seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[int]:
    """Collect the requested number of logprobs for every sequence group.

    The value is read from ``sampling_params.logprobs`` of each entry, and
    the result keeps the order of the input list. Entries whose sampling
    params leave ``logprobs`` as ``None`` contribute 0.
    """
    requested_counts = (seq_group_metadata.sampling_params.logprobs
                        for seq_group_metadata in seq_group_metadata_list)

    # None means "no logprobs requested"; normalize it to 0 so callers can
    # use the value directly as a slice length.
    return [0 if count is None else count for count in requested_counts]
def get_sampled_token_logprobs(
        # shape [num_steps, batch_size, vocab_size]
        logprob_tensor: torch.Tensor,
        sampled_token_ids: torch.Tensor,  # shape [num_steps, batch_size]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Look up the logprob and the rank of every sampled token.

    Returns a ``(ranks, logprobs)`` tuple; both tensors have shape
    [num_steps, batch_size]. The rank of a token is the number of vocabulary
    entries whose logprob is greater than or equal to the token's own
    logprob, so a unique most-likely token has rank 1 and tied tokens share
    the same rank.
    """
    num_steps, batch_size, _ = logprob_tensor.shape

    # Pair each (step, sequence) position with its sampled token id. The
    # step index is a column vector so it broadcasts against the batch index.
    step_index = torch.arange(num_steps).unsqueeze(1)
    sequence_index = torch.arange(batch_size)
    selected_logprobs = logprob_tensor[step_index, sequence_index,
                                       sampled_token_ids]

    # Compare every vocabulary entry against the selected logprob; the
    # trailing singleton dimension broadcasts across the vocabulary.
    at_least_as_likely = logprob_tensor >= selected_logprobs.unsqueeze(-1)
    sampled_token_ids_ranks = at_least_as_likely.sum(-1)

    return sampled_token_ids_ranks, selected_logprobs
def create_sequence_group_output(
        token_id: int,
        token_id_logprob_rank: int,
        token_id_logprob: float,
        seq_id: SeqId,
        topk_token_ids: List[int],
        topk_logprobs: List[float],
) -> SequenceGroupOutput:
    """Build a single-sample SequenceGroupOutput from sampling results.

    Args:
        token_id (int): The sampled token for the sequence.
        token_id_logprob_rank (int): The logprob rank of the sampled token.
        token_id_logprob (float): The logprob value of the sampled token.
        seq_id (int): The sequence id.
        topk_token_ids (List[int]): The list of top-k token ids.
        topk_logprobs (List[float]): The list of top-k logprobs.

    Returns:
        A SequenceGroupOutput holding one SequenceOutput whose logprobs map
        contains the sampled token plus every top-k token. Prompt logprobs
        are always None.
    """
    # The sampled token is always reported, whether or not the user asked
    # for any top-k logprobs.
    token_logprobs: Dict[int, Logprob] = {
        token_id:
        Logprob(logprob=token_id_logprob, rank=token_id_logprob_rank),
    }

    # Top-k entries are given 1-based ranks by list position. If the sampled
    # token also appears in the top-k list, its top-k entry replaces the one
    # inserted above.
    for position, topk_token_id in enumerate(topk_token_ids):
        token_logprobs[topk_token_id] = Logprob(
            logprob=topk_logprobs[position],
            rank=position + 1,
        )

    sample = SequenceOutput(parent_seq_id=seq_id,
                            output_token=token_id,
                            logprobs=token_logprobs)

    return SequenceGroupOutput(
        samples=[sample],
        # TODO add prompt logprobs support.
        prompt_logprobs=None,
    )
def
split_batch_by_proposal_len
(
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
proposal_lens
:
List
[
int
],
select_proposal_len_zero
:
bool
...
...
@@ -49,8 +133,8 @@ def split_batch_by_proposal_len(
def
sampler_output_to_torch
(
sampler_output_list
:
List
[
SamplerOutput
],
sampler_transposed
:
bool
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
sampler_output_list
:
List
[
SamplerOutput
],
sampler_transposed
:
bool
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""Utility function which converts a list of SamplerOutput to tensors.
sampler_transposed here is used as the indicator for whether
...
...
@@ -76,6 +160,15 @@ def sampler_output_to_torch(
if
sampler_transposed
:
sampled_token_probs
=
sampled_token_probs
.
transpose
(
0
,
1
)
# shape: [batch_size, num_sampler_output, vocab_size]
sampled_token_logprobs
=
torch
.
stack
(
[
sampler_output
.
logprobs
for
sampler_output
in
sampler_output_list
],
dim
=
0
,
)
if
sampler_transposed
:
sampled_token_logprobs
=
sampled_token_logprobs
.
transpose
(
0
,
1
)
# shape: [batch_size, num_sampler_output]
sampled_token_ids
=
torch
.
stack
(
[
...
...
@@ -87,7 +180,7 @@ def sampler_output_to_torch(
if
sampler_transposed
:
sampled_token_ids
=
sampled_token_ids
.
transpose
(
0
,
1
)
return
sampled_token_ids
,
sampled_token_probs
return
sampled_token_ids
,
sampled_token_probs
,
sampled_token_logprobs
def
maybe_mock_device_tensors
(
sampler_output
:
SamplerOutput
,
batch_size
:
int
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment