Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
80f63a39
Unverified
Commit
80f63a39
authored
Feb 15, 2025
by
Lily Liu
Committed by
GitHub
Feb 15, 2025
Browse files
[V1][Spec Decode] Ngram Spec Decode (#12193)
Signed-off-by:
LiuXiaoxuanPKU
<
lilyliupku@gmail.com
>
parent
367cb8ce
Changes
21
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1024 additions
and
83 deletions
+1024
-83
tests/v1/core/test_scheduler.py
tests/v1/core/test_scheduler.py
+196
-6
tests/v1/e2e/test_ngram_spec_decode.py
tests/v1/e2e/test_ngram_spec_decode.py
+49
-0
tests/v1/sample/test_rejection_sampler.py
tests/v1/sample/test_rejection_sampler.py
+173
-0
tests/v1/sample/test_sampler.py
tests/v1/sample/test_sampler.py
+2
-0
tests/v1/spec_decode/test_ngram.py
tests/v1/spec_decode/test_ngram.py
+32
-0
tests/v1/worker/test_gpu_input_batch.py
tests/v1/worker/test_gpu_input_batch.py
+3
-1
tests/v1/worker/test_gpu_model_runner.py
tests/v1/worker/test_gpu_model_runner.py
+6
-0
vllm/platforms/cuda.py
vllm/platforms/cuda.py
+2
-3
vllm/v1/core/kv_cache_manager.py
vllm/v1/core/kv_cache_manager.py
+21
-12
vllm/v1/core/scheduler.py
vllm/v1/core/scheduler.py
+66
-31
vllm/v1/core/scheduler_output.py
vllm/v1/core/scheduler_output.py
+4
-0
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+31
-0
vllm/v1/outputs.py
vllm/v1/outputs.py
+9
-3
vllm/v1/request.py
vllm/v1/request.py
+17
-0
vllm/v1/sample/metadata.py
vllm/v1/sample/metadata.py
+2
-0
vllm/v1/sample/rejection_sampler.py
vllm/v1/sample/rejection_sampler.py
+160
-0
vllm/v1/sample/sampler.py
vllm/v1/sample/sampler.py
+14
-1
vllm/v1/spec_decode/ngram_proposer.py
vllm/v1/spec_decode/ngram_proposer.py
+99
-0
vllm/v1/worker/gpu_input_batch.py
vllm/v1/worker/gpu_input_batch.py
+11
-1
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+127
-25
No files found.
tests/v1/core/test_scheduler.py
View file @
80f63a39
...
...
@@ -4,10 +4,12 @@ from typing import List, Optional
from
vllm.config
import
CacheConfig
,
ModelConfig
,
SchedulerConfig
from
vllm.multimodal.inputs
import
MultiModalKwargs
,
PlaceholderRange
from
vllm.sampling_params
import
SamplingParams
from
vllm.v1.core.scheduler
import
Scheduler
from
vllm.v1.core.scheduler
import
Scheduler
,
SchedulerOutput
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.request
import
Request
,
RequestStatus
EOS_TOKEN_ID
=
50256
def
create_scheduler
(
model
:
str
=
"facebook/opt-125m"
,
...
...
@@ -38,6 +40,7 @@ def create_scheduler(
return
Scheduler
(
scheduler_config
,
model_config
,
cache_config
,
speculative_config
=
None
,
lora_config
=
None
,
log_stats
=
True
)
...
...
@@ -46,8 +49,12 @@ def create_requests(
num_requests
:
int
,
num_tokens
:
int
=
10
,
mm_positions
:
Optional
[
List
[
PlaceholderRange
]]
=
None
,
max_tokens
:
int
=
16
,
stop_token_ids
:
Optional
[
List
[
int
]]
=
None
,
):
sampling_params
=
SamplingParams
()
sampling_params
=
SamplingParams
(
ignore_eos
=
False
,
max_tokens
=
max_tokens
,
stop_token_ids
=
stop_token_ids
)
requests
=
[]
for
i
in
range
(
num_requests
):
if
mm_positions
is
not
None
:
...
...
@@ -64,7 +71,7 @@ def create_requests(
multi_modal_inputs
=
mm_inputs
,
multi_modal_placeholders
=
mm_position
,
multi_modal_hashes
=
None
,
eos_token_id
=
None
,
eos_token_id
=
EOS_TOKEN_ID
,
arrival_time
=
0
,
)
requests
.
append
(
request
)
...
...
@@ -195,7 +202,7 @@ def test_schedule_partial_requests():
model_runner_output
=
ModelRunnerOutput
(
req_ids
=
[
request
.
request_id
for
request
in
requests
],
req_id_to_index
=
req_to_index
,
sampled_token_ids
=
[
0
]
*
len
(
requests
),
sampled_token_ids
=
[
[
0
]
for
_
in
range
(
len
(
requests
)
)]
,
logprobs
=
None
,
prompt_logprobs_dict
=
{},
)
...
...
@@ -215,6 +222,189 @@ def test_schedule_partial_requests():
assert
requests
[
2
].
request_id
not
in
output
.
num_scheduled_tokens
def
test_stop_via_update_from_output
():
"""Test stopping behavior through update_from_output"""
scheduler
=
create_scheduler
()
# Test case 1: Stop on EOS token
requests
=
create_requests
(
num_requests
=
2
,
max_tokens
=
10
)
for
req
in
requests
:
req
.
num_computed_tokens
=
req
.
num_tokens
scheduler
.
requests
[
req
.
request_id
]
=
req
scheduler
.
running
.
append
(
req
)
scheduler
.
scheduled_req_ids
.
add
(
req
.
request_id
)
scheduler_output
=
SchedulerOutput
(
scheduled_new_reqs
=
[],
scheduled_cached_reqs
=
[],
num_scheduled_tokens
=
{
requests
[
0
].
request_id
:
1
,
requests
[
1
].
request_id
:
2
},
total_num_scheduled_tokens
=
3
,
scheduled_encoder_inputs
=
{},
scheduled_spec_decode_tokens
=
{
requests
[
0
].
request_id
:
[],
requests
[
1
].
request_id
:
[
10
]
},
num_common_prefix_blocks
=
0
,
finished_req_ids
=
set
(),
free_encoder_input_ids
=
[])
model_output
=
ModelRunnerOutput
(
req_ids
=
[
req
.
request_id
for
req
in
requests
],
req_id_to_index
=
{
req
.
request_id
:
i
for
i
,
req
in
enumerate
(
requests
)
},
sampled_token_ids
=
[[
EOS_TOKEN_ID
],
[
10
,
11
]],
# First request hits EOS, second continues
logprobs
=
None
,
prompt_logprobs_dict
=
{})
scheduler
.
update_from_output
(
scheduler_output
,
model_output
)
# Verify first request stopped, second continues
assert
len
(
scheduler
.
running
)
==
1
assert
scheduler
.
running
[
0
].
request_id
==
requests
[
1
].
request_id
assert
requests
[
0
].
status
==
RequestStatus
.
FINISHED_STOPPED
assert
requests
[
0
].
request_id
in
scheduler
.
finished_req_ids
assert
list
(
requests
[
0
].
output_token_ids
)
==
[
EOS_TOKEN_ID
]
assert
list
(
requests
[
1
].
output_token_ids
)
==
[
10
,
11
]
# Test case 2: Stop on custom stop token
scheduler
=
create_scheduler
()
requests
=
create_requests
(
num_requests
=
2
,
max_tokens
=
10
,
stop_token_ids
=
[
42
,
43
])
for
req
in
requests
:
req
.
num_computed_tokens
=
req
.
num_tokens
scheduler
.
requests
[
req
.
request_id
]
=
req
scheduler
.
running
.
append
(
req
)
scheduler
.
scheduled_req_ids
.
add
(
req
.
request_id
)
scheduler_output
=
SchedulerOutput
(
scheduled_new_reqs
=
[],
scheduled_cached_reqs
=
[],
num_scheduled_tokens
=
{
requests
[
0
].
request_id
:
3
,
requests
[
1
].
request_id
:
2
},
total_num_scheduled_tokens
=
5
,
scheduled_encoder_inputs
=
{},
scheduled_spec_decode_tokens
=
{
requests
[
0
].
request_id
:
[
10
,
42
],
requests
[
1
].
request_id
:
[
13
]
},
num_common_prefix_blocks
=
0
,
finished_req_ids
=
set
(),
free_encoder_input_ids
=
[])
model_output
=
ModelRunnerOutput
(
req_ids
=
[
req
.
request_id
for
req
in
requests
],
req_id_to_index
=
{
req
.
request_id
:
i
for
i
,
req
in
enumerate
(
requests
)
},
sampled_token_ids
=
[[
10
,
42
,
12
],
[
13
,
14
]],
# First request hits stop token
logprobs
=
None
,
prompt_logprobs_dict
=
{})
scheduler
.
update_from_output
(
scheduler_output
,
model_output
)
# Verify first request stopped on custom token
assert
len
(
scheduler
.
running
)
==
1
assert
scheduler
.
running
[
0
].
request_id
==
requests
[
1
].
request_id
assert
requests
[
0
].
status
==
RequestStatus
.
FINISHED_STOPPED
assert
requests
[
0
].
stop_reason
==
42
assert
requests
[
0
].
request_id
in
scheduler
.
finished_req_ids
assert
list
(
requests
[
0
].
output_token_ids
)
==
[
10
,
42
]
assert
list
(
requests
[
1
].
output_token_ids
)
==
[
13
,
14
]
# Test case 3: Stop on max tokens
scheduler
=
create_scheduler
()
requests
=
create_requests
(
num_requests
=
2
,
max_tokens
=
2
)
for
req
in
requests
:
req
.
num_computed_tokens
=
req
.
num_tokens
scheduler
.
requests
[
req
.
request_id
]
=
req
scheduler
.
running
.
append
(
req
)
scheduler
.
scheduled_req_ids
.
add
(
req
.
request_id
)
scheduler_output
=
SchedulerOutput
(
scheduled_new_reqs
=
[],
scheduled_cached_reqs
=
[],
num_scheduled_tokens
=
{
requests
[
0
].
request_id
:
3
,
requests
[
1
].
request_id
:
1
},
total_num_scheduled_tokens
=
4
,
scheduled_encoder_inputs
=
{},
scheduled_spec_decode_tokens
=
{
requests
[
0
].
request_id
:
[
10
,
11
],
requests
[
1
].
request_id
:
[]
},
num_common_prefix_blocks
=
0
,
finished_req_ids
=
set
(),
free_encoder_input_ids
=
[])
model_output
=
ModelRunnerOutput
(
req_ids
=
[
req
.
request_id
for
req
in
requests
],
req_id_to_index
=
{
req
.
request_id
:
i
for
i
,
req
in
enumerate
(
requests
)
},
sampled_token_ids
=
[[
10
,
11
,
12
],
[
13
]],
# First request exceeds max_tokens
logprobs
=
None
,
prompt_logprobs_dict
=
{})
scheduler
.
update_from_output
(
scheduler_output
,
model_output
)
# Verify first request stopped due to length
assert
len
(
scheduler
.
running
)
==
1
assert
scheduler
.
running
[
0
].
request_id
==
requests
[
1
].
request_id
assert
requests
[
0
].
status
==
RequestStatus
.
FINISHED_LENGTH_CAPPED
assert
requests
[
0
].
request_id
in
scheduler
.
finished_req_ids
assert
list
(
requests
[
0
].
output_token_ids
)
==
[
10
,
11
]
# Truncated to max_tokens
assert
list
(
requests
[
1
].
output_token_ids
)
==
[
13
]
# Test case 4: Ignore EOS flag
scheduler
=
create_scheduler
()
requests
=
create_requests
(
num_requests
=
1
,
max_tokens
=
10
)
requests
[
0
].
sampling_params
.
ignore_eos
=
True
requests
[
0
].
num_computed_tokens
=
requests
[
0
].
num_tokens
scheduler
.
requests
[
requests
[
0
].
request_id
]
=
requests
[
0
]
scheduler
.
running
.
append
(
requests
[
0
])
scheduler
.
scheduled_req_ids
.
add
(
requests
[
0
].
request_id
)
scheduler_output
=
SchedulerOutput
(
scheduled_new_reqs
=
[],
scheduled_cached_reqs
=
[],
num_scheduled_tokens
=
{
requests
[
0
].
request_id
:
3
},
total_num_scheduled_tokens
=
3
,
scheduled_encoder_inputs
=
{},
scheduled_spec_decode_tokens
=
{
requests
[
0
].
request_id
:
[
EOS_TOKEN_ID
,
10
]
},
num_common_prefix_blocks
=
0
,
finished_req_ids
=
set
(),
free_encoder_input_ids
=
[])
model_output
=
ModelRunnerOutput
(
req_ids
=
[
requests
[
0
].
request_id
],
req_id_to_index
=
{
requests
[
0
].
request_id
:
0
},
sampled_token_ids
=
[[
EOS_TOKEN_ID
,
10
,
11
]],
logprobs
=
None
,
prompt_logprobs_dict
=
{})
scheduler
.
update_from_output
(
scheduler_output
,
model_output
)
# Verify request continues past EOS
assert
len
(
scheduler
.
running
)
==
1
assert
not
requests
[
0
].
is_finished
()
assert
list
(
requests
[
0
].
output_token_ids
)
==
[
EOS_TOKEN_ID
,
10
,
11
]
def
test_schedule_concurrent_batches
():
scheduler
=
create_scheduler
(
max_num_batched_tokens
=
1024
,
...
...
@@ -243,7 +433,7 @@ def test_schedule_concurrent_batches():
model_runner_output
=
ModelRunnerOutput
(
req_ids
=
[
requests
[
0
].
request_id
],
req_id_to_index
=
{
requests
[
0
].
request_id
:
0
},
sampled_token_ids
=
[
0
],
sampled_token_ids
=
[
[
0
]
],
logprobs
=
None
,
prompt_logprobs_dict
=
{},
)
...
...
@@ -259,7 +449,7 @@ def test_schedule_concurrent_batches():
model_runner_output
=
ModelRunnerOutput
(
req_ids
=
[
requests
[
1
].
request_id
],
req_id_to_index
=
{
requests
[
1
].
request_id
:
0
},
sampled_token_ids
=
[
0
],
sampled_token_ids
=
[
[
0
]
],
logprobs
=
None
,
prompt_logprobs_dict
=
{},
)
...
...
tests/v1/e2e/test_ngram_spec_decode.py
0 → 100644
View file @
80f63a39
# SPDX-License-Identifier: Apache-2.0
import
pytest
from
vllm
import
LLM
,
SamplingParams
@
pytest
.
fixture
def
test_prompts
():
return
[
"Can you repeat the sentence ten times, this is a sentence."
,
"Can you repeat the sentence ten times, this is a test."
,
]
@
pytest
.
fixture
def
sampling_config
():
# Only support greedy for now
return
SamplingParams
(
temperature
=
0
,
max_tokens
=
30
,
ignore_eos
=
False
)
@
pytest
.
fixture
def
model_name
():
return
"meta-llama/Meta-Llama-3-8B-Instruct"
def
test_ngram_correctness
(
monkeypatch
,
test_prompts
,
sampling_config
,
model_name
):
'''
Compare the outputs of a original LLM and a speculative LLM
should be the same when using ngram speculative decoding.
'''
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
ref_llm
=
LLM
(
model
=
model_name
)
ref_outputs
=
ref_llm
.
generate
(
test_prompts
,
sampling_config
)
del
ref_llm
spec_llm
=
LLM
(
model
=
model_name
,
speculative_model
=
'[ngram]'
,
ngram_prompt_lookup_max
=
5
,
ngram_prompt_lookup_min
=
3
,
num_speculative_tokens
=
3
)
spec_outputs
=
spec_llm
.
generate
(
test_prompts
,
sampling_config
)
for
ref_output
,
spec_output
in
zip
(
ref_outputs
,
spec_outputs
):
assert
ref_output
.
outputs
[
0
].
text
==
spec_output
.
outputs
[
0
].
text
,
\
(
f
"ref_output:
{
ref_output
.
outputs
[
0
].
text
}
,"
f
"spec_output:
{
spec_output
.
outputs
[
0
].
text
}
"
)
del
spec_llm
tests/v1/sample/test_rejection_sampler.py
0 → 100644
View file @
80f63a39
# SPDX-License-Identifier: Apache-2.0
from
typing
import
List
import
pytest
import
torch
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.rejection_sampler
import
INVALID_TOKEN_ID
,
RejectionSampler
@
pytest
.
fixture
def
sampler
():
return
RejectionSampler
()
def
create_logits_tensor
(
token_ids
:
List
[
int
],
vocab_size
:
int
=
100
)
->
torch
.
Tensor
:
"""Helper function to create logits tensor that
will produce desired token ids on argmax"""
logits
=
torch
.
full
((
len
(
token_ids
),
vocab_size
),
-
100.0
).
cuda
()
for
i
,
token_id
in
enumerate
(
token_ids
):
logits
[
i
,
token_id
]
=
100.0
return
logits
def
create_sampling_metadata
(
spec_tokens
:
List
[
List
[
int
]])
->
SamplingMetadata
:
batch_size
=
len
(
spec_tokens
)
return
SamplingMetadata
(
temperature
=
0.0
,
all_greedy
=
True
,
all_random
=
False
,
rejection_sampling
=
True
,
spec_token_ids
=
spec_tokens
,
top_p
=
None
,
top_k
=
None
,
no_top_p
=
False
,
no_top_k
=
False
,
min_p
=
torch
.
empty
(
batch_size
,
),
no_min_p
=
True
,
generators
=
{},
max_num_logprobs
=
0
,
no_penalties
=
False
,
prompt_token_ids
=
None
,
frequency_penalties
=
torch
.
tensor
([]),
presence_penalties
=
torch
.
tensor
([]),
repetition_penalties
=
torch
.
tensor
([]),
output_token_ids
=
[],
min_tokens
=
[],
stop_token_ids
=
[],
logit_bias
=
[
None
]
*
batch_size
,
)
def
test_perfect_match
(
sampler
):
"""Test when output tokens perfectly match speculated tokens"""
spec_tokens
=
[[
1
,
2
,
3
]]
output_tokens
=
[
1
,
2
,
3
,
4
]
# 4 is the bonus token
metadata
=
create_sampling_metadata
(
spec_tokens
)
logits
=
create_logits_tensor
(
output_tokens
)
output
=
sampler
(
logits
,
metadata
)
expected
=
torch
.
tensor
([[
1
,
2
,
3
,
4
]],
dtype
=
torch
.
int
,
device
=
logits
.
device
)
assert
torch
.
equal
(
output
.
sampled_token_ids
,
expected
)
def
test_early_mismatch
(
sampler
):
"""Test when there's an early mismatch in tokens"""
spec_tokens
=
[[
1
,
2
,
3
]]
output_tokens
=
[
1
,
5
,
3
,
4
]
# Mismatch at position 1
metadata
=
create_sampling_metadata
(
spec_tokens
)
logits
=
create_logits_tensor
(
output_tokens
)
output
=
sampler
(
logits
,
metadata
)
expected
=
torch
.
tensor
([[
1
,
5
,
INVALID_TOKEN_ID
,
INVALID_TOKEN_ID
]],
dtype
=
torch
.
int
,
device
=
logits
.
device
)
assert
torch
.
equal
(
output
.
sampled_token_ids
,
expected
)
def
test_multiple_sequences
(
sampler
):
"""Test handling multiple sequences of speculated tokens"""
spec_tokens
=
[[
1
,
2
],
[
3
]]
output_tokens
=
[
1
,
2
,
5
,
3
,
4
]
# Two sequences with bonus tokens 5 and 4
metadata
=
create_sampling_metadata
(
spec_tokens
)
logits
=
create_logits_tensor
(
output_tokens
)
output
=
sampler
(
logits
,
metadata
)
expected
=
torch
.
tensor
([[
1
,
2
,
5
],
[
3
,
4
,
INVALID_TOKEN_ID
]],
dtype
=
torch
.
int
,
device
=
logits
.
device
)
assert
torch
.
equal
(
output
.
sampled_token_ids
,
expected
)
def
test_single_token_sequence
(
sampler
):
"""Test handling sequences with single token"""
spec_tokens
=
[[
1
]]
output_tokens
=
[
1
,
2
]
# Single token with bonus token 2
metadata
=
create_sampling_metadata
(
spec_tokens
)
logits
=
create_logits_tensor
(
output_tokens
)
output
=
sampler
(
logits
,
metadata
)
expected
=
torch
.
tensor
([[
1
,
2
]],
dtype
=
torch
.
int
,
device
=
logits
.
device
)
assert
torch
.
equal
(
output
.
sampled_token_ids
,
expected
)
def
test_empty_sequence
(
sampler
):
"""Test handling empty sequence of speculated tokens"""
spec_tokens
:
List
[
List
[
int
]]
=
[[]]
output_tokens
=
[
5
]
# Just the bonus token
metadata
=
create_sampling_metadata
(
spec_tokens
)
logits
=
create_logits_tensor
(
output_tokens
)
output
=
sampler
(
logits
,
metadata
)
expected
=
torch
.
tensor
([[
5
]],
dtype
=
torch
.
int
,
device
=
logits
.
device
)
assert
torch
.
equal
(
output
.
sampled_token_ids
,
expected
)
def
test_multiple_mismatches
(
sampler
):
"""Test handling multiple sequences with mismatches"""
spec_tokens
=
[[
1
,
2
,
3
],
[
4
,
5
,
6
]]
output_tokens
=
[
1
,
2
,
7
,
6
,
4
,
8
,
6
,
9
]
# Mismatches in both sequences
metadata
=
create_sampling_metadata
(
spec_tokens
)
logits
=
create_logits_tensor
(
output_tokens
)
output
=
sampler
(
logits
,
metadata
)
expected
=
torch
.
tensor
([[
1
,
2
,
7
,
INVALID_TOKEN_ID
],
[
4
,
8
,
INVALID_TOKEN_ID
,
INVALID_TOKEN_ID
]],
dtype
=
torch
.
int
,
device
=
logits
.
device
)
assert
torch
.
equal
(
output
.
sampled_token_ids
,
expected
)
@
pytest
.
mark
.
parametrize
(
"spec_tokens,output_tokens,expected"
,
[
([[
1
,
2
]],
[
1
,
2
,
3
],
[[
1
,
2
,
3
]]),
# Perfect match with bonus
([[
1
]],
[
2
,
3
],
[[
2
,
INVALID_TOKEN_ID
]]),
# First mismatch
([[
1
,
2
],
[
3
,
4
]],
[
1
,
5
,
6
,
3
,
4
,
7
],
[[
1
,
5
,
INVALID_TOKEN_ID
],
[
3
,
4
,
7
]]),
# Mixed matches
])
def
test_parametrized_cases
(
sampler
,
spec_tokens
,
output_tokens
,
expected
):
"""Parametrized test for various matching scenarios"""
metadata
=
create_sampling_metadata
(
spec_tokens
)
logits
=
create_logits_tensor
(
output_tokens
)
output
=
sampler
(
logits
,
metadata
)
expected_tensor
=
torch
.
tensor
(
expected
,
dtype
=
torch
.
int
,
device
=
logits
.
device
)
assert
torch
.
equal
(
output
.
sampled_token_ids
,
expected_tensor
)
def
test_logits_shape_handling
(
sampler
):
"""Test handling of different logits tensor shapes"""
spec_tokens
=
[[
1
,
2
]]
output_tokens
=
[
1
,
2
,
3
]
vocab_size
=
1000
metadata
=
create_sampling_metadata
(
spec_tokens
)
logits
=
create_logits_tensor
(
output_tokens
,
vocab_size
)
output
=
sampler
(
logits
,
metadata
)
expected
=
torch
.
tensor
([[
1
,
2
,
3
]],
dtype
=
torch
.
int
,
device
=
logits
.
device
)
assert
torch
.
equal
(
output
.
sampled_token_ids
,
expected
)
assert
logits
.
shape
[
-
1
]
==
vocab_size
tests/v1/sample/test_sampler.py
View file @
80f63a39
...
...
@@ -77,6 +77,7 @@ def _create_default_sampling_metadata(
temperature
=
torch
.
full
((
batch_size
,
),
0.0
),
all_greedy
=
True
,
all_random
=
False
,
rejection_sampling
=
False
,
top_p
=
torch
.
empty
(
batch_size
,
),
top_k
=
torch
.
empty
(
batch_size
,
),
no_top_p
=
True
,
...
...
@@ -88,6 +89,7 @@ def _create_default_sampling_metadata(
prompt_token_ids
=
_create_prompt_tokens_tensor
(
prompt_token_ids
,
vocab_size
,
device
),
output_token_ids
=
output_token_ids
,
spec_token_ids
=
[],
frequency_penalties
=
_create_penalty_tensor
(
batch_size
,
0.0
,
device
),
presence_penalties
=
_create_penalty_tensor
(
batch_size
,
0.0
,
device
),
repetition_penalties
=
_create_penalty_tensor
(
batch_size
,
1.0
,
device
),
...
...
tests/v1/spec_decode/test_ngram.py
0 → 100644
View file @
80f63a39
# SPDX-License-Identifier: Apache-2.0
import
pytest
from
vllm.v1.spec_decode.ngram_proposer
import
NgramProposer
from
vllm.v1.utils
import
ConstantList
@
pytest
.
fixture
def
proposer
():
return
NgramProposer
()
def
test_kmp_lps_array
(
proposer
):
assert
proposer
.
_kmp_lps_array
([])
==
[]
assert
proposer
.
_kmp_lps_array
([
1
])
==
[
0
]
assert
proposer
.
_kmp_lps_array
([
1
,
1
,
1
])
==
[
0
,
1
,
2
]
assert
proposer
.
_kmp_lps_array
([
1
,
2
,
3
,
4
])
==
[
0
,
0
,
0
,
0
]
assert
proposer
.
_kmp_lps_array
([
1
,
2
,
1
,
2
,
3
])
==
[
0
,
0
,
1
,
2
,
0
]
def
test_find_subarray_kmp
(
proposer
):
X
=
ConstantList
([
1
,
2
,
3
,
4
,
1
,
2
,
3
,
5
,
6
])
assert
proposer
.
_find_subarray_kmp
(
X
,
2
,
2
)
is
None
X
=
ConstantList
([
1
,
2
,
3
,
4
,
1
,
2
,
3
])
assert
proposer
.
_find_subarray_kmp
(
X
,
2
,
3
)
==
[
4
,
1
,
2
]
assert
proposer
.
_find_subarray_kmp
(
X
,
2
,
2
)
==
[
4
,
1
]
assert
proposer
.
_find_subarray_kmp
(
X
,
1
,
3
)
==
[
4
,
1
,
2
]
assert
proposer
.
_find_subarray_kmp
(
X
,
1
,
2
)
==
[
4
,
1
]
X
=
ConstantList
([
1
,
3
,
6
,
2
,
3
,
4
,
1
,
2
,
3
])
assert
proposer
.
_find_subarray_kmp
(
X
,
2
,
3
)
==
[
4
,
1
,
2
]
# Return on the first match
assert
proposer
.
_find_subarray_kmp
(
X
,
1
,
3
)
==
[
6
,
2
,
3
]
\ No newline at end of file
tests/v1/worker/test_gpu_input_batch.py
View file @
80f63a39
...
...
@@ -92,6 +92,7 @@ def _construct_expected_sampling_metadata(
device
=
device
),
all_greedy
=
False
,
all_random
=
True
,
rejection_sampling
=
False
,
top_p
=
torch
.
tensor
(
top_p
,
dtype
=
torch
.
float
,
device
=
device
),
top_k
=
torch
.
tensor
(
top_k
,
dtype
=
torch
.
int
,
device
=
device
),
no_top_p
=
all
(
x
==
1.0
for
x
in
top_p
),
...
...
@@ -116,6 +117,7 @@ def _construct_expected_sampling_metadata(
dtype
=
torch
.
float
,
device
=
device
),
output_token_ids
=
output_token_ids
,
spec_token_ids
=
[],
min_tokens
=
min_tokens
,
stop_token_ids
=
stop_token_ids
,
no_penalties
=
(
all
(
x
==
0
for
x
in
presence_penalties
)
...
...
@@ -205,7 +207,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
# Generate the sampling metadata
sampling_metadata
=
input_batch
.
make_sampling_metadata
(
req_id_output_token_ids
,
skip_copy
=
False
)
req_id_output_token_ids
,
req_id_to_spec_token_ids
=
{},
skip_copy
=
False
)
# Create expected output.
expected_sampling_metadata
=
_construct_expected_sampling_metadata
(
...
...
tests/v1/worker/test_gpu_model_runner.py
View file @
80f63a39
...
...
@@ -66,6 +66,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
scheduled_cached_reqs
=
[],
num_scheduled_tokens
=
num_scheduled_tokens
,
total_num_scheduled_tokens
=
total_num_scheduled_tokens
,
scheduled_spec_decode_tokens
=
{},
scheduled_encoder_inputs
=
{},
num_common_prefix_blocks
=
0
,
finished_req_ids
=
set
(),
...
...
@@ -109,6 +110,7 @@ def test_update_states_request_finished(model_runner):
scheduled_cached_reqs
=
[],
num_scheduled_tokens
=
{},
total_num_scheduled_tokens
=
0
,
scheduled_spec_decode_tokens
=
{},
scheduled_encoder_inputs
=
{},
num_common_prefix_blocks
=
0
,
finished_req_ids
=
{
req_id
},
...
...
@@ -137,6 +139,7 @@ def test_update_states_request_resumed(model_runner):
scheduled_cached_reqs
=
[],
num_scheduled_tokens
=
{},
total_num_scheduled_tokens
=
0
,
scheduled_spec_decode_tokens
=
{},
scheduled_encoder_inputs
=
{},
num_common_prefix_blocks
=
0
,
finished_req_ids
=
{},
...
...
@@ -160,6 +163,7 @@ def test_update_states_request_resumed(model_runner):
scheduled_cached_reqs
=
[
cached_req_data
],
num_scheduled_tokens
=
{
req_id
:
1
},
total_num_scheduled_tokens
=
1
,
scheduled_spec_decode_tokens
=
{},
scheduled_encoder_inputs
=
{},
num_common_prefix_blocks
=
0
,
finished_req_ids
=
set
(),
...
...
@@ -188,6 +192,7 @@ def test_update_states_no_changes(model_runner):
scheduled_cached_reqs
=
[],
num_scheduled_tokens
=
{
req_id
:
1
},
total_num_scheduled_tokens
=
1
,
scheduled_spec_decode_tokens
=
{},
scheduled_encoder_inputs
=
{},
num_common_prefix_blocks
=
0
,
finished_req_ids
=
set
(),
...
...
@@ -220,6 +225,7 @@ def test_update_states_request_unscheduled(model_runner):
scheduled_cached_reqs
=
[],
num_scheduled_tokens
=
{
req_ids
[
0
]:
1
},
total_num_scheduled_tokens
=
1
,
scheduled_spec_decode_tokens
=
{},
scheduled_encoder_inputs
=
{},
num_common_prefix_blocks
=
0
,
finished_req_ids
=
set
(),
...
...
vllm/platforms/cuda.py
View file @
80f63a39
...
...
@@ -124,9 +124,8 @@ class CudaPlatformBase(Platform):
"vllm.worker.multi_step_worker.MultiStepWorker"
elif
vllm_config
.
speculative_config
:
if
envs
.
VLLM_USE_V1
:
raise
NotImplementedError
(
"Speculative decoding is not yet supported on VLLM V1."
)
parallel_config
.
worker_cls
=
\
"vllm.v1.worker.gpu_worker.Worker"
else
:
parallel_config
.
worker_cls
=
\
"vllm.spec_decode.spec_decode_worker.create_spec_worker"
...
...
vllm/v1/core/kv_cache_manager.py
View file @
80f63a39
...
...
@@ -82,6 +82,11 @@ class KVCacheManager:
self
.
req_to_block_hashes
:
DefaultDict
[
str
,
List
[
BlockHashType
]]
=
defaultdict
(
list
)
# {req_id: The number of cached blocks for this given request}
# This is used to track the number of cached blocks for each request.
# This is only used to track the RUNNING requests, we do not track the
# data for reempted ones.
self
.
num_cached_block
:
Dict
[
str
,
int
]
=
defaultdict
(
int
)
self
.
prefix_cache_stats
=
PrefixCacheStats
()
@
property
...
...
@@ -241,23 +246,25 @@ class KVCacheManager:
if
not
self
.
enable_caching
:
return
new_blocks
# NOTE(rickyx): We are assuming the `num_tokens` are actual
# tokens rather than lookahead slots (e.g. for speculative decoding).
# TODO(rickyx): When supporting speculative decoding, we will need to
# differentiate between them so that we can know how many blocks are
# full after appending the actual tokens.
num_full_blocks
=
(
num_computed_tokens
+
num_tokens
)
//
self
.
block_size
num_computed_full_blocks
=
num_computed_tokens
//
self
.
block_size
new_full_blocks
=
req_blocks
[
num_computed_full_blocks
:
num_full_blocks
]
num_cached_blocks
=
self
.
num_cached_block
[
request
.
request_id
]
# Speculated tokens might be rejected in the future, so we does
# not cache any speculated tokens. We only cache blocks with
# generated (accepted) tokens.
num_full_blocks_after_append
=
(
num_computed_tokens
+
num_tokens
-
len
(
request
.
spec_token_ids
))
//
self
.
block_size
new_full_blocks
=
req_blocks
[
num_cached_blocks
:
num_full_blocks_after_append
]
if
new_full_blocks
:
self
.
_cache_full_blocks
(
request
=
request
,
blk_start_idx
=
num_c
omputed_full
_blocks
,
blk_start_idx
=
num_c
ached
_blocks
,
# The new full blocks are the full blocks that are not computed.
full_blocks
=
new_full_blocks
,
prev_block
=
(
req_blocks
[
num_computed_full_blocks
-
1
]
if
num_computed_full_blocks
>
0
else
None
))
prev_block
=
(
req_blocks
[
num_cached_blocks
-
1
]
if
num_cached_blocks
>
0
else
None
))
self
.
num_cached_block
[
request
.
request_id
]
=
num_full_blocks_after_append
return
new_blocks
def
free
(
self
,
request
:
Request
)
->
None
:
...
...
@@ -281,6 +288,8 @@ class KVCacheManager:
if
block
.
ref_cnt
==
0
:
self
.
free_block_queue
.
append
(
block
)
self
.
num_cached_block
.
pop
(
request
.
request_id
,
None
)
def
reset_prefix_cache
(
self
)
->
bool
:
"""Reset prefix cache. This function may be used in RLHF
flows to invalid prefix caching after the weights are updated,
...
...
vllm/v1/core/scheduler.py
View file @
80f63a39
...
...
@@ -4,7 +4,8 @@ import time
from
collections
import
deque
from
typing
import
Deque
,
Dict
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
ModelConfig
,
SchedulerConfig
from
vllm.config
import
(
CacheConfig
,
LoRAConfig
,
ModelConfig
,
SchedulerConfig
,
SpeculativeConfig
)
from
vllm.logger
import
init_logger
from
vllm.v1.core.encoder_cache_manager
import
(
EncoderCacheManager
,
compute_encoder_budget
)
...
...
@@ -28,11 +29,13 @@ class Scheduler:
model_config
:
ModelConfig
,
cache_config
:
CacheConfig
,
lora_config
:
Optional
[
LoRAConfig
],
speculative_config
:
Optional
[
SpeculativeConfig
],
log_stats
:
bool
,
)
->
None
:
self
.
scheduler_config
=
scheduler_config
self
.
cache_config
=
cache_config
self
.
lora_config
=
lora_config
self
.
speculative_config
=
speculative_config
self
.
log_stats
=
log_stats
# Scheduling constraints.
...
...
@@ -96,12 +99,14 @@ class Scheduler:
def
schedule
(
self
)
->
"SchedulerOutput"
:
# NOTE(woosuk) on the scheduling algorithm:
# There's no "decoding phase" nor "prefill phase" in the scheduler.
# Each request just has the num_computed_tokens and num_tokens,
# which is equal to len(prompt_token_ids) + len(output_token_ids).
# Each request just has the num_computed_tokens and
# num_tokens_with_spec. num_tokens_with_spec =
# len(prompt_token_ids) + len(output_token_ids) + len(spec_token_ids).
# At each step, the scheduler tries to assign tokens to the requests
# so that each request's num_computed_tokens can catch up its
# num_tokens. This is general enough to cover chunked prefills,
# prefix caching, and the "jump decoding" optimization in the future.
# num_tokens_with_spec. This is general enough to cover
# chunked prefills, prefix caching, speculative decoding,
# and the "jump decoding" optimization in the future.
scheduled_new_reqs
:
List
[
Request
]
=
[]
scheduled_resumed_reqs
:
List
[
Request
]
=
[]
...
...
@@ -114,7 +119,8 @@ class Scheduler:
# Encoder-related.
scheduled_encoder_inputs
:
Dict
[
str
,
List
[
int
]]
=
{}
encoder_budget
=
self
.
max_num_encoder_input_tokens
# Spec decode-related.
scheduled_spec_decode_tokens
:
Dict
[
str
,
List
[
int
]]
=
{}
scheduled_timestamp
=
time
.
monotonic
()
# First, schedule the RUNNING requests.
...
...
@@ -126,7 +132,8 @@ class Scheduler:
req_index
+=
1
continue
num_new_tokens
=
request
.
num_tokens
-
request
.
num_computed_tokens
num_new_tokens
=
(
request
.
num_tokens_with_spec
-
request
.
num_computed_tokens
)
num_new_tokens
=
min
(
num_new_tokens
,
token_budget
)
assert
num_new_tokens
>
0
...
...
@@ -189,6 +196,11 @@ class Scheduler:
self
.
encoder_cache_manager
.
allocate
(
request
,
i
)
encoder_budget
=
new_encoder_budget
# Speculative decode related.
if
request
.
spec_token_ids
:
scheduled_spec_decode_tokens
[
request
.
request_id
]
=
request
.
spec_token_ids
# Record the LoRAs in scheduled_running_reqs
requested_loras
:
Set
[
int
]
=
set
()
if
self
.
lora_config
:
...
...
@@ -338,6 +350,7 @@ class Scheduler:
num_scheduled_tokens
=
num_scheduled_tokens
,
total_num_scheduled_tokens
=
total_num_scheduled_tokens
,
scheduled_encoder_inputs
=
scheduled_encoder_inputs
,
scheduled_spec_decode_tokens
=
scheduled_spec_decode_tokens
,
num_common_prefix_blocks
=
num_common_prefix_blocks
,
# finished_req_ids is an existing state in the scheduler,
# instead of being newly scheduled in this step.
...
...
@@ -447,11 +460,11 @@ class Scheduler:
scheduler_output
:
"SchedulerOutput"
,
model_runner_output
:
"ModelRunnerOutput"
,
)
->
EngineCoreOutputs
:
# NOTE(woosuk): This method doesn't consider speculative decoding.
sampled_token_ids
=
model_runner_output
.
sampled_token_ids
logprobs
=
model_runner_output
.
logprobs
prompt_logprobs_dict
=
model_runner_output
.
prompt_logprobs_dict
num_scheduled_tokens
=
scheduler_output
.
num_scheduled_tokens
new_running
:
List
[
Request
]
=
[]
outputs
:
List
[
EngineCoreOutput
]
=
[]
...
...
@@ -466,11 +479,30 @@ class Scheduler:
new_running
.
append
(
request
)
continue
req_index
=
model_runner_output
.
req_id_to_index
[
req_id
]
generated_token_ids
=
sampled_token_ids
[
req_index
]
if
req_id
not
in
scheduler_output
.
scheduled_spec_decode_tokens
:
# When the request's num_computed_tokens catches up
# its num_tokens, the request generates output tokens.
# Otherwise, we ignore the sampler output for the request.
request
.
num_computed_tokens
+=
num_tokens_scheduled
# When the request's num_computed_tokens catches up its num_tokens,
# the request generates output tokens. Otherwise, we ignore the
# sampler output for the request.
assert
request
.
num_computed_tokens
<=
request
.
num_tokens
else
:
# num_computed_tokens_step represents the number of tokens
# processed in the current step, considering scheduled
# tokens and rejections.
# It is calculated as:
# num_computed_tokens_step = num_scheduled_tokens -
# num_tokens_rejected,
# where num_tokens_rejected is given by:
# len(scheduled_spec_token_ids) + 1 - len(generated_token_ids).
scheduled_spec_token_ids
=
(
scheduler_output
.
scheduled_spec_decode_tokens
[
req_id
])
num_computed_tokens_step
=
num_scheduled_tokens
[
req_id
]
-
(
len
(
scheduled_spec_token_ids
)
+
1
-
len
(
generated_token_ids
))
request
.
num_computed_tokens
+=
num_computed_tokens_step
cached_encoder_input_ids
=
(
self
.
encoder_cache_manager
.
get_cached_input_ids
(
request
))
...
...
@@ -485,27 +517,32 @@ class Scheduler:
self
.
encoder_cache_manager
.
free_encoder_input
(
request
,
input_id
)
if
request
.
num_computed_tokens
>=
request
.
num_tokens
:
# Clear the spec tokens as the request has generated
# a new token. Here, We assume all spec tokens are verified
# if we perform speculative decoding for this request.
# Therefore, we can clear all spec tokens after
# the generation step.
request
.
clear_spec_tokens
()
# Get prompt logprobs for this request.
prompt_logprobs_tensors
=
prompt_logprobs_dict
.
get
(
req_id
)
stopped
=
False
new_logprobs
=
None
new_token_ids
=
None
new_token_ids
:
List
[
int
]
=
[]
if
request
.
num_computed_tokens
==
request
.
num_tokens
:
req_index
=
model_runner_output
.
req_id_to_index
[
req_id
]
# NOTE(woosuk): Currently, we assume that each request
# generates at most one token at each step.
token_id
=
sampled_token_ids
[
req_index
]
request
.
append_output_token_ids
(
token_id
)
num_new_tokens
=
1
# TODO: Update the KV cache manager for prefix caching.
if
request
.
num_computed_tokens
>=
request
.
num_tokens
:
for
output_token_id
in
generated_token_ids
:
request
.
append_output_token_ids
(
output_token_id
)
new_token_ids
.
append
(
output_token_id
)
# Check for stop and update request state.
# This must be called before we make the EngineCoreOutput.
stopped
=
self
.
_check_stop
(
request
)
if
stopped
:
self
.
_free_request
(
request
)
break
# Extract sample logprobs if needed.
if
request
.
sampling_params
.
logprobs
is
not
None
:
...
...
@@ -514,8 +551,6 @@ class Scheduler:
# the outer lists can be of length > 1.
new_logprobs
=
logprobs
.
slice
(
req_index
,
req_index
+
1
)
new_token_ids
=
request
.
output_token_ids
[
-
num_new_tokens
:]
# Transmit partial if chunked prefill & prompt logprobs is enabled
if
new_token_ids
or
prompt_logprobs_tensors
is
not
None
:
# Add EngineCoreOutput for this Request.
...
...
vllm/v1/core/scheduler_output.py
View file @
80f63a39
...
...
@@ -91,6 +91,10 @@ class SchedulerOutput:
# Total number of tokens scheduled for all requests.
# Equal to sum(num_scheduled_tokens.values())
total_num_scheduled_tokens
:
int
# req_id -> spec_decode_tokens
# If a request does not have any spec decode tokens, it will
# not be included in the dictionary.
scheduled_spec_decode_tokens
:
Dict
[
str
,
List
[
int
]]
# req_id -> encoder input indices that need processing.
# E.g., if a request has [0, 1], it could mean the vision encoder needs
# to process that the request's 0-th and 1-th images in the current step.
...
...
vllm/v1/engine/core.py
View file @
80f63a39
...
...
@@ -27,6 +27,7 @@ from vllm.v1.executor.abstract import Executor
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.request
import
Request
,
RequestStatus
from
vllm.v1.serial_utils
import
MsgpackDecoder
,
MsgpackEncoder
from
vllm.v1.spec_decode.ngram_proposer
import
NgramProposer
from
vllm.version
import
__version__
as
VLLM_VERSION
logger
=
init_logger
(
__name__
)
...
...
@@ -65,6 +66,7 @@ class EngineCore:
model_config
=
vllm_config
.
model_config
,
cache_config
=
vllm_config
.
cache_config
,
lora_config
=
vllm_config
.
lora_config
,
speculative_config
=
vllm_config
.
speculative_config
,
log_stats
=
self
.
log_stats
,
)
...
...
@@ -84,6 +86,15 @@ class EngineCore:
self
.
batch_queue_size
)
self
.
batch_queue
=
queue
.
Queue
(
self
.
batch_queue_size
)
# Setup speculative decode.
# TODO: find a better way to check if we are using ngram.
self
.
use_spec_decode
=
False
if
self
.
scheduler
.
speculative_config
:
assert
self
.
scheduler
.
speculative_config
.
ngram_prompt_lookup_min
\
,
"Only ngram spec decode is supported in V1."
self
.
proposer
=
NgramProposer
()
self
.
use_spec_decode
=
True
def
_initialize_kv_caches
(
self
,
vllm_config
:
VllmConfig
)
->
Tuple
[
int
,
int
]:
start
=
time
.
time
()
...
...
@@ -147,6 +158,9 @@ class EngineCore:
return
EngineCoreOutputs
(
outputs
=
[],
scheduler_stats
=
self
.
scheduler
.
make_stats
())
if
self
.
use_spec_decode
:
self
.
propose_tokens
()
scheduler_output
=
self
.
scheduler
.
schedule
()
output
=
self
.
model_executor
.
execute_model
(
scheduler_output
)
engine_core_outputs
=
self
.
scheduler
.
update_from_output
(
...
...
@@ -207,6 +221,23 @@ class EngineCore:
def
profile
(
self
,
is_start
:
bool
=
True
):
self
.
model_executor
.
profile
(
is_start
)
def
propose_tokens
(
self
):
assert
self
.
scheduler
.
speculative_config
is
not
None
for
req
in
self
.
scheduler
.
running
:
# Ignore requests that are doing chunked prefill.
if
req
.
num_computed_tokens
<
req
.
num_tokens
-
1
:
continue
# Ignore requests that already have spec tokens.
if
req
.
spec_token_ids
:
continue
spec_tokens
=
self
.
proposer
.
propose
(
req
.
all_token_ids
,
self
.
scheduler
.
speculative_config
.
ngram_prompt_lookup_min
,
self
.
scheduler
.
speculative_config
.
num_speculative_tokens
,
)
if
spec_tokens
:
req
.
append_spec_token_ids
(
spec_tokens
)
def
reset_prefix_cache
(
self
):
self
.
scheduler
.
reset_prefix_cache
()
...
...
vllm/v1/outputs.py
View file @
80f63a39
...
...
@@ -43,7 +43,10 @@ class LogprobsTensors(NamedTuple):
@
dataclass
class
SamplerOutput
:
# [num_reqs]
# [num_reqs, max_num_generated_tokens]
# Different requests can have different number of generated tokens.
# All requests are padded to max_num_generated_tokens.
# INVALID_TOKEN_ID (-1 by default) is used for padding.
sampled_token_ids
:
torch
.
Tensor
logprobs_tensors
:
Optional
[
LogprobsTensors
]
...
...
@@ -58,8 +61,11 @@ class ModelRunnerOutput:
# req_id -> index
req_id_to_index
:
Dict
[
str
,
int
]
# [num_reqs]
sampled_token_ids
:
List
[
int
]
# num_reqs x num_generated_tokens
# num_generated_tokens is the number of tokens
# generated in the current step. It can be different for
# each request due to speculative/jump decoding.
sampled_token_ids
:
List
[
List
[
int
]]
# [num_reqs, max_num_logprobs + 1]
# [num_reqs, max_num_logprobs + 1]
...
...
vllm/v1/request.py
View file @
80f63a39
...
...
@@ -46,6 +46,7 @@ class Request:
self
.
num_prompt_tokens
=
len
(
self
.
prompt_token_ids
)
self
.
_output_token_ids
:
List
[
int
]
=
[]
self
.
_all_token_ids
:
List
[
int
]
=
self
.
prompt_token_ids
.
copy
()
self
.
spec_token_ids
:
List
[
int
]
=
[]
self
.
num_computed_tokens
=
0
# Multi-modal related
...
...
@@ -103,10 +104,26 @@ class Request:
self
.
_output_token_ids
.
extend
(
token_ids
)
self
.
_all_token_ids
.
extend
(
token_ids
)
def
append_spec_token_ids
(
self
,
token_ids
:
Union
[
int
,
List
[
int
]],
)
->
None
:
if
isinstance
(
token_ids
,
int
):
self
.
spec_token_ids
.
append
(
token_ids
)
else
:
self
.
spec_token_ids
.
extend
(
token_ids
)
def
clear_spec_tokens
(
self
)
->
None
:
self
.
spec_token_ids
.
clear
()
@
property
def
num_tokens
(
self
)
->
int
:
return
len
(
self
.
_all_token_ids
)
@
property
def
num_tokens_with_spec
(
self
)
->
int
:
return
len
(
self
.
_all_token_ids
)
+
len
(
self
.
spec_token_ids
)
@
property
def
num_output_tokens
(
self
)
->
int
:
return
len
(
self
.
_output_token_ids
)
...
...
vllm/v1/sample/metadata.py
View file @
80f63a39
...
...
@@ -12,6 +12,8 @@ class SamplingMetadata:
temperature
:
torch
.
Tensor
all_greedy
:
bool
all_random
:
bool
rejection_sampling
:
bool
spec_token_ids
:
List
[
List
[
int
]]
top_p
:
torch
.
Tensor
top_k
:
torch
.
Tensor
...
...
vllm/v1/sample/rejection_sampler.py
0 → 100644
View file @
80f63a39
# SPDX-License-Identifier: Apache-2.0
import
torch
import
torch.nn
as
nn
from
torch.nn.utils.rnn
import
pad_sequence
from
vllm.logger
import
init_logger
from
vllm.v1.outputs
import
SamplerOutput
from
vllm.v1.sample.metadata
import
SamplingMetadata
try
:
import
flashinfer.sampling
as
fs
is_flashinfer_available
=
True
except
ImportError
:
is_flashinfer_available
=
False
logger
=
init_logger
(
__name__
)
INVALID_TOKEN_ID
=
-
1
class
RejectionSampler
(
nn
.
Module
):
def
forward
(
self
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
SamplerOutput
:
if
not
sampling_metadata
.
all_greedy
:
raise
NotImplementedError
(
"Only greedy sampling is supported by rejection sampler."
)
if
is_flashinfer_available
:
logger
.
info
(
"User FlashInfer for rejection sampling."
)
return
RejectionSampler
.
flashinfer_sample
(
logits
,
sampling_metadata
)
else
:
logger
.
warning
(
"FlashInfer is not available. Falling back to the PyTorch-"
"native implementation of rejection sampling."
)
return
RejectionSampler
.
greedy_sample_native
(
logits
,
sampling_metadata
)
@
staticmethod
def
flashinfer_sample
(
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
SamplerOutput
:
# NOTE: The following input preparationg can be moved
# to the model runner with a persistent manner for better
# performance.
spec_token_ids
=
sampling_metadata
.
spec_token_ids
max_spec_len
=
max
(
len
(
s
)
for
s
in
spec_token_ids
)
batch_size
=
len
(
spec_token_ids
)
draft_token_ids
=
torch
.
full
((
batch_size
,
max_spec_len
),
INVALID_TOKEN_ID
,
device
=
"cpu"
,
dtype
=
torch
.
long
)
target_token_ids
=
torch
.
full
((
batch_size
,
max_spec_len
+
1
),
fill_value
=
INVALID_TOKEN_ID
,
device
=
logits
.
device
,
dtype
=
torch
.
long
)
# TODO: Vectorize the following loop for better performance.
start_loc
=
0
for
i
in
range
(
batch_size
):
num_spec_tokens
=
len
(
spec_token_ids
[
i
])
draft_token_ids
[
i
,
:
num_spec_tokens
]
=
torch
.
tensor
(
spec_token_ids
[
i
],
device
=
"cpu"
,
dtype
=
torch
.
long
)
end_loc
=
start_loc
+
num_spec_tokens
+
1
# Assume greedy sampling.
target_token_ids
[
i
,
:
num_spec_tokens
+
1
]
=
torch
.
argmax
(
logits
[
start_loc
:
end_loc
],
dim
=-
1
)
start_loc
=
end_loc
vocab_size
=
logits
.
size
(
-
1
)
# NOTE: CPU <-> GPU synchronization happens here.
draft_token_ids
=
draft_token_ids
.
to
(
logits
.
device
)
draft_probs
=
RejectionSampler
.
_create_greedy_token_probs
(
draft_token_ids
,
vocab_size
,
logits
.
device
)
target_probs
=
RejectionSampler
.
_create_greedy_token_probs
(
target_token_ids
,
vocab_size
,
logits
.
device
)
uniform_samples
=
torch
.
zeros
(
batch_size
,
max_spec_len
+
1
,
device
=
logits
.
device
)
sampled_token_ids
,
_
,
_
=
fs
.
chain_speculative_sampling
(
draft_probs
,
draft_token_ids
,
uniform_samples
,
target_probs
,
)
return
SamplerOutput
(
sampled_token_ids
=
sampled_token_ids
,
logprobs_tensors
=
None
)
# TODO: The following method can be optimized for better performance.
@
staticmethod
def
greedy_sample_native
(
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
SamplerOutput
:
spec_lens
=
[
len
(
x
)
for
x
in
sampling_metadata
.
spec_token_ids
]
# Add 1 to include the 'bonus' token.
sample_lens
=
[
x
+
1
for
x
in
spec_lens
]
output_token_ids
=
logits
.
argmax
(
dim
=-
1
).
view
(
-
1
)
output_token_ids
=
output_token_ids
.
split
(
sample_lens
)
output_token_ids
=
pad_sequence
(
output_token_ids
,
batch_first
=
True
,
padding_value
=
INVALID_TOKEN_ID
)
# Convert spec token IDs to a tensor, split by sample_lens, then pad.
spec_token_ids
=
[
torch
.
tensor
(
x
,
dtype
=
output_token_ids
.
dtype
,
device
=
output_token_ids
.
device
)
for
x
in
sampling_metadata
.
spec_token_ids
]
spec_token_ids
=
pad_sequence
(
spec_token_ids
,
batch_first
=
True
,
padding_value
=
INVALID_TOKEN_ID
)
# Produce a mask that remains 1 (True) until the first
# mismatch (cumprod turns 0 after a mismatch).
accept_mask
=
(
output_token_ids
[:,
:
-
1
]
==
spec_token_ids
).
cumprod
(
dim
=
1
)
# Identify valid positions (non-padding).
valid_mask
=
output_token_ids
!=
INVALID_TOKEN_ID
# Generate mask with bonus token.
generate_mask
=
torch
.
cat
([
accept_mask
,
torch
.
zeros
(
accept_mask
.
size
(
0
),
1
,
device
=
accept_mask
.
device
)
],
dim
=
1
).
to
(
torch
.
bool
)
&
valid_mask
zeros_mask
=
(
generate_mask
==
0
)
first_zero_idx
=
zeros_mask
.
float
().
argmax
(
dim
=
1
)
# Figure out which rows actually contain at least one zero.
rows_with_zero
=
zeros_mask
.
any
(
dim
=
1
)
# Use indexing to set the first zero in each of those rows to 1.
generate_mask
[
rows_with_zero
,
first_zero_idx
[
rows_with_zero
]]
=
1
output_token_ids
[
~
generate_mask
]
=
INVALID_TOKEN_ID
return
SamplerOutput
(
sampled_token_ids
=
output_token_ids
,
logprobs_tensors
=
None
)
@
staticmethod
def
_create_greedy_token_probs
(
token_ids
:
torch
.
Tensor
,
vocab_size
:
int
,
out_device
:
torch
.
device
)
->
torch
.
Tensor
:
batch_size
,
num_tokens
=
token_ids
.
shape
token_probs
=
torch
.
zeros
(
batch_size
,
num_tokens
,
vocab_size
,
dtype
=
torch
.
float
,
device
=
out_device
)
# Ignore INVALID_TOKEN_ID.
valid_mask
=
(
token_ids
!=
INVALID_TOKEN_ID
)
valid_indices
=
token_ids
.
clone
()
valid_indices
[
~
valid_mask
]
=
0
token_probs
.
scatter_
(
dim
=
2
,
index
=
valid_indices
.
unsqueeze
(
-
1
),
src
=
valid_mask
.
unsqueeze
(
-
1
).
float
())
return
token_probs
vllm/v1/sample/sampler.py
View file @
80f63a39
...
...
@@ -9,6 +9,7 @@ from vllm.v1.sample.metadata import SamplingMetadata
from
vllm.v1.sample.ops.penalties
import
(
apply_all_penalties
,
apply_min_token_penalties
)
from
vllm.v1.sample.ops.topk_topp_sampler
import
TopKTopPSampler
from
vllm.v1.sample.rejection_sampler
import
RejectionSampler
_SAMPLING_EPS
=
1e-5
...
...
@@ -18,12 +19,21 @@ class Sampler(nn.Module):
def
__init__
(
self
):
super
().
__init__
()
self
.
topk_topp_sampler
=
TopKTopPSampler
()
self
.
rejection_sampler
=
RejectionSampler
()
def
forward
(
self
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
SamplerOutput
:
if
sampling_metadata
.
rejection_sampling
:
if
sampling_metadata
.
max_num_logprobs
:
raise
NotImplementedError
(
"Rejection sampling does not support logprobs."
)
return
self
.
rejection_sampler
(
logits
,
sampling_metadata
,
)
# NOTE(woosuk): Use the original logits (before any penalties or
# temperature scaling) for the top-k logprobs.
...
...
@@ -54,7 +64,10 @@ class Sampler(nn.Module):
# These are GPU tensors.
sampler_output
=
SamplerOutput
(
sampled_token_ids
=
sampled
,
# The sampled tokens are expanded to 2D tensor with shape
# [num_requests, 1], where each row represents one generated
# token per request.
sampled_token_ids
=
sampled
.
unsqueeze
(
-
1
),
logprobs_tensors
=
logprobs_tensors
,
)
return
sampler_output
...
...
vllm/v1/spec_decode/ngram_proposer.py
0 → 100644
View file @
80f63a39
# SPDX-License-Identifier: Apache-2.0
from
typing
import
List
,
Optional
from
vllm.v1.utils
import
ConstantList
class
NgramProposer
:
def
__init__
(
self
):
pass
def
propose
(
self
,
context_token_ids
:
ConstantList
[
int
],
n
:
int
,
k
:
int
)
->
Optional
[
List
[
int
]]:
"""Proposes the next sequence of tokens based on n-gram pattern
matching in the context. The function finds matches of the last n
tokens in the previous context, and returns k tokens that followed
that match.
Args:
context_token_ids: List of token IDs representing the
context sequence.
n: Length of the n-gram to match.
k: Number of tokens follow the match. If there are less
than k tokens follow the match, we will return
the maximum amount of tokens until the end.
Returns:
List[int]: The sequence of tokens that followed
the matched n-gram in the context.
None: If no matching n-gram pattern is found.
Example:
If context_token_ids = [1,2,3,4,2,3], n = 2, and k = 4:
- The last 2 tokens [2,3] will be matched against the previous
4 tokens [1,2,3,4].
- Finding a match of [2,3] would return the tokens that
followed that pattern. Here we will return [4,2,3] because
we only have three tokens after the match.
"""
# TODO: Use c++ to implement the _find_subarray_kmp to
# improve the efficiency
return
self
.
_find_subarray_kmp
(
context_token_ids
,
n
,
k
)
@
staticmethod
def
_kmp_lps_array
(
pattern
:
List
[
int
])
->
List
[
int
]:
"""
Build the lps (longest proper prefix which is also suffix)
array for the pattern.
"""
lps
=
[
0
]
*
len
(
pattern
)
prev_lps
=
0
# length of the previous longest prefix suffix
i
=
1
while
i
<
len
(
pattern
):
if
pattern
[
i
]
==
pattern
[
prev_lps
]:
prev_lps
+=
1
lps
[
i
]
=
prev_lps
i
+=
1
else
:
if
prev_lps
!=
0
:
prev_lps
=
lps
[
prev_lps
-
1
]
else
:
lps
[
i
]
=
0
i
+=
1
return
lps
@
staticmethod
def
_find_subarray_kmp
(
context_token_ids
:
ConstantList
[
int
],
n
:
int
,
k
:
int
)
->
Optional
[
List
[
int
]]:
context_len
=
len
(
context_token_ids
)
assert
n
>
0
pattern
=
context_token_ids
[
-
n
:]
# Precompute lps array for Y
lps
=
NgramProposer
.
_kmp_lps_array
(
pattern
)
i
=
0
j
=
0
# -n because the last n tokens are used as pattern
while
i
<
context_len
-
n
:
if
context_token_ids
[
i
]
==
pattern
[
j
]:
i
+=
1
j
+=
1
# If we have matched the entire Y
if
j
==
n
:
# Found pattern in context, gather the next K elements
return
context_token_ids
[
i
:
i
+
k
]
else
:
# Mismatch
if
j
!=
0
:
# Use the lps array to avoid re-checking elements
j
=
lps
[
j
-
1
]
else
:
i
+=
1
# Y not found
return
None
vllm/v1/worker/gpu_input_batch.py
View file @
80f63a39
...
...
@@ -390,6 +390,7 @@ class InputBatch:
def
make_sampling_metadata
(
self
,
req_id_output_token_ids
:
Dict
[
str
,
List
[
int
]],
req_id_to_spec_token_ids
:
Dict
[
str
,
List
[
int
]],
skip_copy
:
bool
=
False
,
)
->
SamplingMetadata
:
if
not
skip_copy
:
...
...
@@ -423,7 +424,8 @@ class InputBatch:
self
.
prompt_token_ids
=
self
.
_make_prompt_token_ids_tensor
()
output_token_ids
:
List
[
List
[
int
]]
=
[]
spec_token_ids
:
List
[
List
[
int
]]
=
[]
rejection_sampling
=
False
for
req_id
in
self
.
req_ids
[:
self
.
num_reqs
]:
assert
req_id
is
not
None
# Currently we create a tensor for output_token_ids from scratch
...
...
@@ -434,11 +436,18 @@ class InputBatch:
# TODO - Replace this with incremental update to output token
# statistics.
output_token_ids
.
append
(
req_id_output_token_ids
[
req_id
])
req_spec_token_ids
=
req_id_to_spec_token_ids
.
get
(
req_id
,
[])
spec_token_ids
.
append
(
req_spec_token_ids
)
if
req_spec_token_ids
:
# If any of the requests require speculative decoding, set the
# flag to True.
rejection_sampling
=
True
return
SamplingMetadata
(
temperature
=
self
.
temperature
[:
self
.
num_reqs
],
all_greedy
=
self
.
all_greedy
,
all_random
=
self
.
all_random
,
rejection_sampling
=
rejection_sampling
,
top_p
=
self
.
top_p
[:
self
.
num_reqs
],
top_k
=
self
.
top_k
[:
self
.
num_reqs
],
min_p
=
self
.
min_p
[:
self
.
num_reqs
],
...
...
@@ -452,6 +461,7 @@ class InputBatch:
presence_penalties
=
self
.
presence_penalties
[:
self
.
num_reqs
],
repetition_penalties
=
self
.
repetition_penalties
[:
self
.
num_reqs
],
output_token_ids
=
output_token_ids
,
spec_token_ids
=
spec_token_ids
,
min_tokens
=
self
.
min_tokens
[:
self
.
num_reqs
],
stop_token_ids
=
self
.
stop_token_ids
[:
self
.
num_reqs
],
no_penalties
=
self
.
no_penalties
,
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
80f63a39
...
...
@@ -32,6 +32,7 @@ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec
)
from
vllm.v1.outputs
import
LogprobsTensors
,
ModelRunnerOutput
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.rejection_sampler
import
INVALID_TOKEN_ID
from
vllm.v1.utils
import
bind_kv_cache
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
from
vllm.v1.worker.lora_model_runner_mixin
import
LoRAModelRunnerMixin
...
...
@@ -180,6 +181,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
max_model_len
,
self
.
max_num_tokens
),
dtype
=
np
.
int32
)
self
.
arange_cpu
=
torch
.
from_numpy
(
self
.
arange_np
)
# NOTE(woosuk): These tensors are "stateless", i.e., they are literally
# a faster version of creating a new tensor every time. Thus, we should
# not make any assumptions about the values in these tensors.
...
...
@@ -368,7 +370,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
return
batch_changed
def
_prepare_inputs
(
self
,
scheduler_output
:
"SchedulerOutput"
):
def
_prepare_inputs
(
self
,
scheduler_output
:
"SchedulerOutput"
)
->
Tuple
[
FlashAttentionMetadata
,
torch
.
Tensor
]:
total_num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
assert
total_num_scheduled_tokens
>
0
num_reqs
=
self
.
input_batch
.
num_reqs
...
...
@@ -382,12 +386,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# TODO: The Python loop can be slow. Optimize.
num_scheduled_tokens_list
:
List
[
int
]
=
[]
max_num_scheduled_tokens
=
0
for
req_id
in
self
.
input_batch
.
req_ids
[:
num_reqs
]:
all_spec_token_ids
:
List
[
int
]
=
[]
num_spec_tokens_list
:
List
[
int
]
=
[]
for
i
,
req_id
in
zip
(
range
(
num_reqs
),
self
.
input_batch
.
req_ids
):
assert
req_id
is
not
None
num_tokens
=
scheduler_output
.
num_scheduled_tokens
[
req_id
]
num_scheduled_tokens_list
.
append
(
num_tokens
)
max_num_scheduled_tokens
=
max
(
max_num_scheduled_tokens
,
num_tokens
)
spec_token_ids
=
scheduler_output
.
scheduled_spec_decode_tokens
.
get
(
req_id
,
[])
all_spec_token_ids
.
extend
(
spec_token_ids
)
num_spec_tokens_list
.
append
(
len
(
spec_token_ids
))
num_scheduled_tokens
:
np
.
ndarray
=
np
.
array
(
num_scheduled_tokens_list
,
dtype
=
np
.
int32
)
assert
max_num_scheduled_tokens
>
0
...
...
@@ -426,6 +437,79 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# where M is the max_model_len.
token_indices
=
(
positions_np
+
req_indices
*
self
.
input_batch
.
token_ids_cpu
.
shape
[
1
])
use_spec_decode
=
len
(
all_spec_token_ids
)
>
0
if
use_spec_decode
:
# 1. Write spec_token_ids to input batch.
# Step 1. Get req indices that perform spec decode and repeat
# the req indices by the number of spec tokens. Note
# for requests that don't perform spec decode, the
# number of spec tokens is 0 and the req index is
# repeated 0 times.
# E.g., num_spec_tokens_list: [3, 0, 2, 0, 1]
# spec_req_indices: [0, 0, 0, 2, 2, 4]
spec_req_indices
=
np
.
repeat
(
self
.
arange_np
[:
num_reqs
],
num_spec_tokens_list
)
# spec_offsets: offsets within each spec token list.
# E.g., [1, 2, 3, 1, 2, 1], TODO: avoid the for loop here
spec_offsets
=
np
.
concatenate
(
[
self
.
arange_np
[
1
:
val
+
1
]
for
val
in
num_spec_tokens_list
])
# spec_seq_offsets: offsets within each sequence.
# E.g., num_computed_tokens_cpu: [1, 4, 3, 6, 2]
# after repeating: [1, 1, 1, 3, 3, 2]
# spec_seq_offsets: [1, 1, 1, 3, 3, 2] + [1, 2, 3, 1, 2, 1]
# = [2, 3, 4, 4, 5, 3]
spec_seq_offsets
=
np
.
repeat
(
self
.
input_batch
.
num_computed_tokens_cpu
[:
num_reqs
],
num_spec_tokens_list
)
+
spec_offsets
# cumsums_spec_offsets: [0, 0, 0, 2M, 2M, 4M] + [2, 3, 4, 4, 5, 3]
cumsums_spec_offsets
=
(
spec_seq_offsets
+
spec_req_indices
*
self
.
input_batch
.
token_ids_cpu
.
shape
[
1
])
cumsums_spec_offsets
=
torch
.
from_numpy
(
cumsums_spec_offsets
).
to
(
torch
.
int64
)
all_spec_token_ids
=
torch
.
tensor
(
all_spec_token_ids
,
device
=
"cpu"
,
dtype
=
self
.
input_ids_cpu
.
dtype
)
# Step 2. Write spec token ids to input_ids_cpu.
self
.
input_batch
.
token_ids_cpu_tensor
.
flatten
().
scatter_
(
0
,
cumsums_spec_offsets
,
all_spec_token_ids
)
# 2. Get spec decode logits indices.
# E.g., num_scheduled_tokens: [4, 100, 3, 100, 2]
# cu_num_tokens: [4, 104, 107, 207, 209]
# num_spec_tokens_list: [3, 0, 2, 0, 1]
# num_sampled_tokens: [4, 1, 3, 1, 2]
# spec_decode_logits_indices:
# [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
num_spec_tokens_np
=
np
.
array
(
num_spec_tokens_list
,
dtype
=
np
.
int32
)
num_sampled_tokens
=
num_spec_tokens_np
+
1
# logits_start_loc: [0, 103, 104, 206, 207]
logits_start_loc
=
cu_num_tokens
-
num_sampled_tokens
# [0, 103, 104, 206, 207] ->
# [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207]
logits_start_loc
=
np
.
repeat
(
logits_start_loc
,
num_sampled_tokens
)
# The following three lines:
# [4, 1, 3, 1, 2] -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
# Step 1. [4, 1, 3, 1, 2] -> [4, 5, 8, 9, 11]
cu_num_sampled_tokens
=
np
.
cumsum
(
num_sampled_tokens
)
# Step 2. [4, 5, 8, 9, 11] -> [0, 4, 5, 8, 9]
# -> [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
cumsums_sampled_offsets
=
np
.
repeat
(
cu_num_sampled_tokens
-
num_sampled_tokens
,
num_sampled_tokens
)
# Step 3. [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
# - [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
# -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
total_num_sampled_tokens
=
num_sampled_tokens
.
sum
()
sampled_arange
=
(
self
.
arange_np
[:
total_num_sampled_tokens
]
-
cumsums_sampled_offsets
)
# [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] ->
# [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
spec_decode_logits_indices
=
logits_start_loc
+
sampled_arange
# NOTE(woosuk): We use torch.index_select instead of np.take here
# because torch.index_select is much faster than np.take for large
# tensors.
...
...
@@ -519,16 +603,21 @@ class GPUModelRunner(LoRAModelRunnerMixin):
suffix_kv_lens
=
suffix_kv_lens
,
)
if
use_spec_decode
:
logits_indices
=
torch
.
from_numpy
(
spec_decode_logits_indices
).
to
(
self
.
device
,
non_blocking
=
True
)
else
:
# NOTE(woosuk): Due to chunked prefills, the batch may contain
# partial requests. While we should not sample any token
# from these partial requests, we do so for simplicity.
# We will ignore the sampled tokens from the partial requests.
# TODO: Support prompt logprobs.
logits_indices
=
query_start_loc
[
1
:]
-
1
# Hot-Swap lora model
if
self
.
lora_config
:
self
.
set_active_loras
(
self
.
input_batch
,
num_scheduled_tokens
)
# NOTE(woosuk): Due to chunked prefills, the batch may contain partial
# requests. While we should not sample any token from these partial
# requests, we do so for simplicity. We will ignore the sampled
# tokens from the partial requests.
# TODO: Support prompt logprobs.
logits_indices
=
query_start_loc
[
1
:]
-
1
return
attn_metadata
,
logits_indices
def
_compute_cascade_attn_prefix_len
(
...
...
@@ -673,6 +762,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def
_prepare_sampling
(
self
,
batch_changed
:
bool
,
req_to_spec_token_ids
:
Dict
[
str
,
List
[
int
]],
)
->
SamplingMetadata
:
# Create the sampling metadata.
req_id_output_token_ids
:
Dict
[
str
,
List
[
int
]]
=
\
...
...
@@ -680,7 +770,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
for
req_id
,
req
in
self
.
requests
.
items
()}
sampling_metadata
=
self
.
input_batch
.
make_sampling_metadata
(
req_id_output_token_ids
,
skip_copy
=
not
batch_changed
)
req_id_output_token_ids
,
req_to_spec_token_ids
,
not
batch_changed
)
return
sampling_metadata
def
_execute_encoder
(
self
,
scheduler_output
:
"SchedulerOutput"
):
...
...
@@ -847,7 +937,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
,
None
)
# Sample the next token and get logprobs if needed.
sampling_metadata
=
self
.
_prepare_sampling
(
batch_changed
)
sampling_metadata
=
self
.
_prepare_sampling
(
batch_changed
,
scheduler_output
.
scheduled_spec_decode_tokens
)
sampler_output
=
self
.
model
.
sample
(
logits
=
logits
,
sampling_metadata
=
sampling_metadata
,
...
...
@@ -857,18 +948,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# the requests one by one. Optimize.
num_reqs
=
self
.
input_batch
.
num_reqs
request_seq_lens
:
List
[
Tuple
[
int
,
CachedRequestState
,
int
]]
=
[]
for
i
,
req_id
in
enumerate
(
# type: ignore[assignment]
self
.
input_batch
.
req_ids
[:
num_reqs
]):
for
i
,
req_id
in
zip
(
range
(
num_reqs
),
self
.
input_batch
.
req_ids
):
assert
req_id
is
not
None
req_state
=
self
.
requests
[
req_id
]
seq_len
=
(
req_state
.
num_computed_tokens
+
scheduler_output
.
num_scheduled_tokens
[
req_id
])
assert
seq_len
<=
req_state
.
num_tokens
if
seq_len
==
req_state
.
num_tokens
:
# Append the sampled token to the output token ids.
self
.
input_batch
.
num_tokens
[
i
]
+=
1
# OPTIMIZATION: Priming the state updates for later updates.
req_state
.
output_token_ids
.
append
(
0
)
if
seq_len
>=
req_state
.
num_tokens
:
request_seq_lens
.
append
((
i
,
req_state
,
seq_len
))
else
:
# Ignore the sampled token from the partial request.
...
...
@@ -886,7 +971,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# NOTE: GPU -> CPU Sync happens here.
# Move as many CPU operations as possible before this sync point.
sampled_token_ids
=
sampler_output
.
sampled_token_ids
.
tolist
()
logprobs_tensors
=
sampler_output
.
logprobs_tensors
logprobs_lists
=
logprobs_tensors
.
tolists
()
\
if
logprobs_tensors
is
not
None
else
None
...
...
@@ -897,16 +981,34 @@ class GPUModelRunner(LoRAModelRunnerMixin):
scheduler_output
,
)
# Update with the actual token ids
# Update batch with the valid generated tokens.
sampled_token_ids
=
sampler_output
.
sampled_token_ids
max_gen_len
=
sampled_token_ids
.
shape
[
-
1
]
if
max_gen_len
==
1
:
valid_sampled_token_ids
=
sampled_token_ids
.
tolist
()
for
i
,
req_state
,
seq_len
in
request_seq_lens
:
token_id
=
sampled_token_ids
[
i
]
token_id
=
valid_
sampled_token_ids
[
i
]
[
0
]
self
.
input_batch
.
token_ids_cpu
[
i
,
seq_len
]
=
token_id
req_state
.
output_token_ids
[
-
1
]
=
token_id
req_state
.
output_token_ids
.
append
(
token_id
)
self
.
input_batch
.
num_tokens
[
i
]
+=
1
else
:
valid_mask
=
sampled_token_ids
!=
INVALID_TOKEN_ID
gen_lens
=
valid_mask
.
sum
(
dim
=
1
).
tolist
()
valid_sampled_token_ids
=
[
seq
.
tolist
()
for
seq
in
sampled_token_ids
[
valid_mask
].
split
(
gen_lens
)
]
self
.
input_batch
.
num_tokens
[:
num_reqs
]
+=
gen_lens
for
i
,
req_state
,
seq_len
in
request_seq_lens
:
target_slice
=
slice
(
seq_len
-
gen_lens
[
i
]
+
1
,
seq_len
+
1
)
self
.
input_batch
.
token_ids_cpu
[
i
,
target_slice
]
=
valid_sampled_token_ids
[
i
]
req_state
.
output_token_ids
.
extend
(
valid_sampled_token_ids
[
i
])
model_runner_output
=
ModelRunnerOutput
(
req_ids
=
req_ids
,
req_id_to_index
=
self
.
input_batch
.
req_id_to_index
,
sampled_token_ids
=
sampled_token_ids
,
sampled_token_ids
=
valid_
sampled_token_ids
,
logprobs
=
logprobs_lists
,
prompt_logprobs_dict
=
prompt_logprobs_dict
,
)
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment