Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
80f63a39
Unverified
Commit
80f63a39
authored
Feb 15, 2025
by
Lily Liu
Committed by
GitHub
Feb 15, 2025
Browse files
[V1][Spec Decode] Ngram Spec Decode (#12193)
Signed-off-by:
LiuXiaoxuanPKU
<
lilyliupku@gmail.com
>
parent
367cb8ce
Changes
21
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1024 additions
and
83 deletions
+1024
-83
tests/v1/core/test_scheduler.py
tests/v1/core/test_scheduler.py
+196
-6
tests/v1/e2e/test_ngram_spec_decode.py
tests/v1/e2e/test_ngram_spec_decode.py
+49
-0
tests/v1/sample/test_rejection_sampler.py
tests/v1/sample/test_rejection_sampler.py
+173
-0
tests/v1/sample/test_sampler.py
tests/v1/sample/test_sampler.py
+2
-0
tests/v1/spec_decode/test_ngram.py
tests/v1/spec_decode/test_ngram.py
+32
-0
tests/v1/worker/test_gpu_input_batch.py
tests/v1/worker/test_gpu_input_batch.py
+3
-1
tests/v1/worker/test_gpu_model_runner.py
tests/v1/worker/test_gpu_model_runner.py
+6
-0
vllm/platforms/cuda.py
vllm/platforms/cuda.py
+2
-3
vllm/v1/core/kv_cache_manager.py
vllm/v1/core/kv_cache_manager.py
+21
-12
vllm/v1/core/scheduler.py
vllm/v1/core/scheduler.py
+66
-31
vllm/v1/core/scheduler_output.py
vllm/v1/core/scheduler_output.py
+4
-0
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+31
-0
vllm/v1/outputs.py
vllm/v1/outputs.py
+9
-3
vllm/v1/request.py
vllm/v1/request.py
+17
-0
vllm/v1/sample/metadata.py
vllm/v1/sample/metadata.py
+2
-0
vllm/v1/sample/rejection_sampler.py
vllm/v1/sample/rejection_sampler.py
+160
-0
vllm/v1/sample/sampler.py
vllm/v1/sample/sampler.py
+14
-1
vllm/v1/spec_decode/ngram_proposer.py
vllm/v1/spec_decode/ngram_proposer.py
+99
-0
vllm/v1/worker/gpu_input_batch.py
vllm/v1/worker/gpu_input_batch.py
+11
-1
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+127
-25
No files found.
tests/v1/core/test_scheduler.py
View file @
80f63a39
...
@@ -4,10 +4,12 @@ from typing import List, Optional
...
@@ -4,10 +4,12 @@ from typing import List, Optional
from
vllm.config
import
CacheConfig
,
ModelConfig
,
SchedulerConfig
from
vllm.config
import
CacheConfig
,
ModelConfig
,
SchedulerConfig
from
vllm.multimodal.inputs
import
MultiModalKwargs
,
PlaceholderRange
from
vllm.multimodal.inputs
import
MultiModalKwargs
,
PlaceholderRange
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.v1.core.scheduler
import
Scheduler
from
vllm.v1.core.scheduler
import
Scheduler
,
SchedulerOutput
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.request
import
Request
,
RequestStatus
from
vllm.v1.request
import
Request
,
RequestStatus
EOS_TOKEN_ID
=
50256
def
create_scheduler
(
def
create_scheduler
(
model
:
str
=
"facebook/opt-125m"
,
model
:
str
=
"facebook/opt-125m"
,
...
@@ -38,6 +40,7 @@ def create_scheduler(
...
@@ -38,6 +40,7 @@ def create_scheduler(
return
Scheduler
(
scheduler_config
,
return
Scheduler
(
scheduler_config
,
model_config
,
model_config
,
cache_config
,
cache_config
,
speculative_config
=
None
,
lora_config
=
None
,
lora_config
=
None
,
log_stats
=
True
)
log_stats
=
True
)
...
@@ -46,8 +49,12 @@ def create_requests(
...
@@ -46,8 +49,12 @@ def create_requests(
num_requests
:
int
,
num_requests
:
int
,
num_tokens
:
int
=
10
,
num_tokens
:
int
=
10
,
mm_positions
:
Optional
[
List
[
PlaceholderRange
]]
=
None
,
mm_positions
:
Optional
[
List
[
PlaceholderRange
]]
=
None
,
max_tokens
:
int
=
16
,
stop_token_ids
:
Optional
[
List
[
int
]]
=
None
,
):
):
sampling_params
=
SamplingParams
()
sampling_params
=
SamplingParams
(
ignore_eos
=
False
,
max_tokens
=
max_tokens
,
stop_token_ids
=
stop_token_ids
)
requests
=
[]
requests
=
[]
for
i
in
range
(
num_requests
):
for
i
in
range
(
num_requests
):
if
mm_positions
is
not
None
:
if
mm_positions
is
not
None
:
...
@@ -64,7 +71,7 @@ def create_requests(
...
@@ -64,7 +71,7 @@ def create_requests(
multi_modal_inputs
=
mm_inputs
,
multi_modal_inputs
=
mm_inputs
,
multi_modal_placeholders
=
mm_position
,
multi_modal_placeholders
=
mm_position
,
multi_modal_hashes
=
None
,
multi_modal_hashes
=
None
,
eos_token_id
=
None
,
eos_token_id
=
EOS_TOKEN_ID
,
arrival_time
=
0
,
arrival_time
=
0
,
)
)
requests
.
append
(
request
)
requests
.
append
(
request
)
...
@@ -195,7 +202,7 @@ def test_schedule_partial_requests():
...
@@ -195,7 +202,7 @@ def test_schedule_partial_requests():
model_runner_output
=
ModelRunnerOutput
(
model_runner_output
=
ModelRunnerOutput
(
req_ids
=
[
request
.
request_id
for
request
in
requests
],
req_ids
=
[
request
.
request_id
for
request
in
requests
],
req_id_to_index
=
req_to_index
,
req_id_to_index
=
req_to_index
,
sampled_token_ids
=
[
0
]
*
len
(
requests
),
sampled_token_ids
=
[
[
0
]
for
_
in
range
(
len
(
requests
)
)]
,
logprobs
=
None
,
logprobs
=
None
,
prompt_logprobs_dict
=
{},
prompt_logprobs_dict
=
{},
)
)
...
@@ -215,6 +222,189 @@ def test_schedule_partial_requests():
...
@@ -215,6 +222,189 @@ def test_schedule_partial_requests():
assert
requests
[
2
].
request_id
not
in
output
.
num_scheduled_tokens
assert
requests
[
2
].
request_id
not
in
output
.
num_scheduled_tokens
def test_stop_via_update_from_output():
    """Test stopping behavior through update_from_output"""

    def _mark_running(sched, reqs):
        # Register every request with the scheduler as already running,
        # with all of its current tokens counted as computed.
        for r in reqs:
            r.num_computed_tokens = r.num_tokens
            sched.requests[r.request_id] = r
            sched.running.append(r)
            sched.scheduled_req_ids.add(r.request_id)

    def _step(sched, reqs, num_scheduled, spec_tokens, sampled):
        # Build the scheduler/model-runner outputs for one step and feed
        # them to the scheduler. The per-request lists are aligned with
        # `reqs` by position.
        ids = [r.request_id for r in reqs]
        scheduler_output = SchedulerOutput(
            scheduled_new_reqs=[],
            scheduled_cached_reqs=[],
            num_scheduled_tokens=dict(zip(ids, num_scheduled)),
            total_num_scheduled_tokens=sum(num_scheduled),
            scheduled_encoder_inputs={},
            scheduled_spec_decode_tokens=dict(zip(ids, spec_tokens)),
            num_common_prefix_blocks=0,
            finished_req_ids=set(),
            free_encoder_input_ids=[])
        model_output = ModelRunnerOutput(
            req_ids=ids,
            req_id_to_index={rid: idx
                             for idx, rid in enumerate(ids)},
            sampled_token_ids=sampled,
            logprobs=None,
            prompt_logprobs_dict={})
        sched.update_from_output(scheduler_output, model_output)

    # Test case 1: Stop on EOS token
    scheduler = create_scheduler()
    requests = create_requests(num_requests=2, max_tokens=10)
    _mark_running(scheduler, requests)
    # First request hits EOS, second continues
    _step(scheduler,
          requests,
          num_scheduled=[1, 2],
          spec_tokens=[[], [10]],
          sampled=[[EOS_TOKEN_ID], [10, 11]])

    # Verify first request stopped, second continues
    first, second = requests
    assert len(scheduler.running) == 1
    assert scheduler.running[0].request_id == second.request_id
    assert first.status == RequestStatus.FINISHED_STOPPED
    assert first.request_id in scheduler.finished_req_ids
    assert list(first.output_token_ids) == [EOS_TOKEN_ID]
    assert list(second.output_token_ids) == [10, 11]

    # Test case 2: Stop on custom stop token
    scheduler = create_scheduler()
    requests = create_requests(num_requests=2,
                               max_tokens=10,
                               stop_token_ids=[42, 43])
    _mark_running(scheduler, requests)
    # First request hits stop token
    _step(scheduler,
          requests,
          num_scheduled=[3, 2],
          spec_tokens=[[10, 42], [13]],
          sampled=[[10, 42, 12], [13, 14]])

    # Verify first request stopped on custom token
    first, second = requests
    assert len(scheduler.running) == 1
    assert scheduler.running[0].request_id == second.request_id
    assert first.status == RequestStatus.FINISHED_STOPPED
    assert first.stop_reason == 42
    assert first.request_id in scheduler.finished_req_ids
    assert list(first.output_token_ids) == [10, 42]
    assert list(second.output_token_ids) == [13, 14]

    # Test case 3: Stop on max tokens
    scheduler = create_scheduler()
    requests = create_requests(num_requests=2, max_tokens=2)
    _mark_running(scheduler, requests)
    # First request exceeds max_tokens
    _step(scheduler,
          requests,
          num_scheduled=[3, 1],
          spec_tokens=[[10, 11], []],
          sampled=[[10, 11, 12], [13]])

    # Verify first request stopped due to length
    first, second = requests
    assert len(scheduler.running) == 1
    assert scheduler.running[0].request_id == second.request_id
    assert first.status == RequestStatus.FINISHED_LENGTH_CAPPED
    assert first.request_id in scheduler.finished_req_ids
    # Truncated to max_tokens
    assert list(first.output_token_ids) == [10, 11]
    assert list(second.output_token_ids) == [13]

    # Test case 4: Ignore EOS flag
    scheduler = create_scheduler()
    requests = create_requests(num_requests=1, max_tokens=10)
    requests[0].sampling_params.ignore_eos = True
    _mark_running(scheduler, requests)
    _step(scheduler,
          requests,
          num_scheduled=[3],
          spec_tokens=[[EOS_TOKEN_ID, 10]],
          sampled=[[EOS_TOKEN_ID, 10, 11]])

    # Verify request continues past EOS
    assert len(scheduler.running) == 1
    assert not requests[0].is_finished()
    assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID, 10, 11]
def
test_schedule_concurrent_batches
():
def
test_schedule_concurrent_batches
():
scheduler
=
create_scheduler
(
scheduler
=
create_scheduler
(
max_num_batched_tokens
=
1024
,
max_num_batched_tokens
=
1024
,
...
@@ -243,7 +433,7 @@ def test_schedule_concurrent_batches():
...
@@ -243,7 +433,7 @@ def test_schedule_concurrent_batches():
model_runner_output
=
ModelRunnerOutput
(
model_runner_output
=
ModelRunnerOutput
(
req_ids
=
[
requests
[
0
].
request_id
],
req_ids
=
[
requests
[
0
].
request_id
],
req_id_to_index
=
{
requests
[
0
].
request_id
:
0
},
req_id_to_index
=
{
requests
[
0
].
request_id
:
0
},
sampled_token_ids
=
[
0
],
sampled_token_ids
=
[
[
0
]
],
logprobs
=
None
,
logprobs
=
None
,
prompt_logprobs_dict
=
{},
prompt_logprobs_dict
=
{},
)
)
...
@@ -259,7 +449,7 @@ def test_schedule_concurrent_batches():
...
@@ -259,7 +449,7 @@ def test_schedule_concurrent_batches():
model_runner_output
=
ModelRunnerOutput
(
model_runner_output
=
ModelRunnerOutput
(
req_ids
=
[
requests
[
1
].
request_id
],
req_ids
=
[
requests
[
1
].
request_id
],
req_id_to_index
=
{
requests
[
1
].
request_id
:
0
},
req_id_to_index
=
{
requests
[
1
].
request_id
:
0
},
sampled_token_ids
=
[
0
],
sampled_token_ids
=
[
[
0
]
],
logprobs
=
None
,
logprobs
=
None
,
prompt_logprobs_dict
=
{},
prompt_logprobs_dict
=
{},
)
)
...
...
tests/v1/e2e/test_ngram_spec_decode.py
0 → 100644
View file @
80f63a39
# SPDX-License-Identifier: Apache-2.0
import
pytest
from
vllm
import
LLM
,
SamplingParams
@pytest.fixture
def test_prompts():
    """Provide the two prompts used by the ngram correctness test."""
    prompts = []
    prompts.append("Can you repeat the sentence ten times, this is a sentence.")
    prompts.append("Can you repeat the sentence ten times, this is a test.")
    return prompts
@pytest.fixture
def sampling_config():
    """Provide the sampling parameters shared by both LLM runs."""
    # Only support greedy for now
    greedy_params = SamplingParams(
        temperature=0,
        max_tokens=30,
        ignore_eos=False,
    )
    return greedy_params
@pytest.fixture
def model_name():
    """Provide the name of the model loaded by the correctness test."""
    name = "meta-llama/Meta-Llama-3-8B-Instruct"
    return name
def test_ngram_correctness(monkeypatch, test_prompts, sampling_config,
                           model_name):
    '''
    Compare the outputs of an original LLM and a speculative LLM;
    they should be the same when using ngram speculative decoding.
    '''
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        # Reference run without speculative decoding. The engine is
        # dropped before the second one is created.
        ref_llm = LLM(model=model_name)
        ref_outputs = ref_llm.generate(test_prompts, sampling_config)
        del ref_llm

        # Same model, with ngram speculative decoding enabled.
        spec_llm = LLM(model=model_name,
                       speculative_model='[ngram]',
                       ngram_prompt_lookup_max=5,
                       ngram_prompt_lookup_min=3,
                       num_speculative_tokens=3)
        spec_outputs = spec_llm.generate(test_prompts, sampling_config)

        # zip() stops at the shorter input, so a missing output would
        # otherwise go unnoticed; check the lengths explicitly.
        assert len(ref_outputs) == len(spec_outputs), \
            (f"ref_outputs has {len(ref_outputs)} entries, "
             f"spec_outputs has {len(spec_outputs)}")
        for ref_output, spec_output in zip(ref_outputs, spec_outputs):
            assert ref_output.outputs[0].text == spec_output.outputs[0].text, \
                (f"ref_output: {ref_output.outputs[0].text}, "
                 f"spec_output: {spec_output.outputs[0].text}")
        del spec_llm
tests/v1/sample/test_rejection_sampler.py
0 → 100644
View file @
80f63a39
# SPDX-License-Identifier: Apache-2.0
from
typing
import
List
import
pytest
import
torch
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.rejection_sampler
import
INVALID_TOKEN_ID
,
RejectionSampler
@pytest.fixture
def sampler():
    """Provide a freshly constructed RejectionSampler."""
    rejection_sampler = RejectionSampler()
    return rejection_sampler
def create_logits_tensor(token_ids: List[int],
                         vocab_size: int = 100,
                         device: str = "cuda") -> torch.Tensor:
    """Helper function to create logits tensor that
    will produce desired token ids on argmax

    Args:
        token_ids: Desired argmax token id for each row.
        vocab_size: Size of the last dimension of the returned tensor.
        device: Device on which the tensor is allocated. Defaults to
            "cuda", matching the previous hard-coded behavior.

    Returns:
        A tensor of shape (len(token_ids), vocab_size) filled with -100.0
        except for 100.0 at [i, token_ids[i]].
    """
    # Allocate directly on the target device instead of building the
    # tensor on the CPU and copying it over afterwards.
    logits = torch.full((len(token_ids), vocab_size), -100.0, device=device)
    for i, token_id in enumerate(token_ids):
        logits[i, token_id] = 100.0
    return logits
def create_sampling_metadata(spec_tokens: List[List[int]]) -> SamplingMetadata:
    """Build a greedy, rejection-sampling SamplingMetadata whose
    speculated tokens are `spec_tokens` (one list per request)."""
    num_reqs = len(spec_tokens)
    metadata = SamplingMetadata(
        temperature=0.0,
        all_greedy=True,
        all_random=False,
        rejection_sampling=True,
        spec_token_ids=spec_tokens,
        top_p=None,
        top_k=None,
        no_top_p=False,
        no_top_k=False,
        min_p=torch.empty(num_reqs, ),
        no_min_p=True,
        generators={},
        max_num_logprobs=0,
        no_penalties=False,
        prompt_token_ids=None,
        frequency_penalties=torch.tensor([]),
        presence_penalties=torch.tensor([]),
        repetition_penalties=torch.tensor([]),
        output_token_ids=[],
        min_tokens=[],
        stop_token_ids=[],
        logit_bias=[None] * num_reqs,
    )
    return metadata
def test_perfect_match(sampler):
    """Test when output tokens perfectly match speculated tokens"""
    drafts = [[1, 2, 3]]
    # The trailing 4 is the bonus token
    target_tokens = [1, 2, 3, 4]

    logits = create_logits_tensor(target_tokens)
    result = sampler(logits, create_sampling_metadata(drafts))

    want = torch.tensor([[1, 2, 3, 4]],
                        dtype=torch.int,
                        device=logits.device)
    assert torch.equal(result.sampled_token_ids, want)
def test_early_mismatch(sampler):
    """Test when there's an early mismatch in tokens"""
    drafts = [[1, 2, 3]]
    # Position 1 disagrees with the draft (5 instead of 2)
    target_tokens = [1, 5, 3, 4]

    logits = create_logits_tensor(target_tokens)
    result = sampler(logits, create_sampling_metadata(drafts))

    want_row = [1, 5, INVALID_TOKEN_ID, INVALID_TOKEN_ID]
    want = torch.tensor([want_row], dtype=torch.int, device=logits.device)
    assert torch.equal(result.sampled_token_ids, want)
def test_multiple_sequences(sampler):
    """Test handling multiple sequences of speculated tokens"""
    drafts = [[1, 2], [3]]
    # Two sequences laid out back to back, with bonus tokens 5 and 4
    target_tokens = [1, 2, 5, 3, 4]

    logits = create_logits_tensor(target_tokens)
    result = sampler(logits, create_sampling_metadata(drafts))

    want_rows = [[1, 2, 5], [3, 4, INVALID_TOKEN_ID]]
    want = torch.tensor(want_rows, dtype=torch.int, device=logits.device)
    assert torch.equal(result.sampled_token_ids, want)
def test_single_token_sequence(sampler):
    """Test handling sequences with single token"""
    drafts = [[1]]
    # One speculated token followed by bonus token 2
    target_tokens = [1, 2]

    logits = create_logits_tensor(target_tokens)
    result = sampler(logits, create_sampling_metadata(drafts))

    want = torch.tensor([[1, 2]], dtype=torch.int, device=logits.device)
    assert torch.equal(result.sampled_token_ids, want)
def test_empty_sequence(sampler):
    """Test handling empty sequence of speculated tokens"""
    drafts: List[List[int]] = [[]]
    # Only the bonus token is present
    target_tokens = [5]

    logits = create_logits_tensor(target_tokens)
    result = sampler(logits, create_sampling_metadata(drafts))

    want = torch.tensor([[5]], dtype=torch.int, device=logits.device)
    assert torch.equal(result.sampled_token_ids, want)
def test_multiple_mismatches(sampler):
    """Test handling multiple sequences with mismatches"""
    drafts = [[1, 2, 3], [4, 5, 6]]
    # Both sequences diverge from their drafts
    target_tokens = [1, 2, 7, 6, 4, 8, 6, 9]

    logits = create_logits_tensor(target_tokens)
    result = sampler(logits, create_sampling_metadata(drafts))

    want_rows = [
        [1, 2, 7, INVALID_TOKEN_ID],
        [4, 8, INVALID_TOKEN_ID, INVALID_TOKEN_ID],
    ]
    want = torch.tensor(want_rows, dtype=torch.int, device=logits.device)
    assert torch.equal(result.sampled_token_ids, want)
@pytest.mark.parametrize(
    "spec_tokens,output_tokens,expected",
    [
        # Perfect match with bonus
        ([[1, 2]], [1, 2, 3], [[1, 2, 3]]),
        # First mismatch
        ([[1]], [2, 3], [[2, INVALID_TOKEN_ID]]),
        # Mixed matches
        (
            [[1, 2], [3, 4]],
            [1, 5, 6, 3, 4, 7],
            [[1, 5, INVALID_TOKEN_ID], [3, 4, 7]],
        ),
    ])
def test_parametrized_cases(sampler, spec_tokens, output_tokens, expected):
    """Parametrized test for various matching scenarios"""
    logits = create_logits_tensor(output_tokens)
    result = sampler(logits, create_sampling_metadata(spec_tokens))

    want = torch.tensor(expected, dtype=torch.int, device=logits.device)
    assert torch.equal(result.sampled_token_ids, want)
def test_logits_shape_handling(sampler):
    """Test handling of different logits tensor shapes"""
    vocab_size = 1000
    drafts = [[1, 2]]
    target_tokens = [1, 2, 3]

    # Use a vocabulary larger than the helper's default
    logits = create_logits_tensor(target_tokens, vocab_size)
    result = sampler(logits, create_sampling_metadata(drafts))

    want = torch.tensor([[1, 2, 3]], dtype=torch.int, device=logits.device)
    assert torch.equal(result.sampled_token_ids, want)
    assert logits.shape[-1] == vocab_size
tests/v1/sample/test_sampler.py
View file @
80f63a39
...
@@ -77,6 +77,7 @@ def _create_default_sampling_metadata(
...
@@ -77,6 +77,7 @@ def _create_default_sampling_metadata(
temperature
=
torch
.
full
((
batch_size
,
),
0.0
),
temperature
=
torch
.
full
((
batch_size
,
),
0.0
),
all_greedy
=
True
,
all_greedy
=
True
,
all_random
=
False
,
all_random
=
False
,
rejection_sampling
=
False
,
top_p
=
torch
.
empty
(
batch_size
,
),
top_p
=
torch
.
empty
(
batch_size
,
),
top_k
=
torch
.
empty
(
batch_size
,
),
top_k
=
torch
.
empty
(
batch_size
,
),
no_top_p
=
True
,
no_top_p
=
True
,
...
@@ -88,6 +89,7 @@ def _create_default_sampling_metadata(
...
@@ -88,6 +89,7 @@ def _create_default_sampling_metadata(
prompt_token_ids
=
_create_prompt_tokens_tensor
(
prompt_token_ids
,
prompt_token_ids
=
_create_prompt_tokens_tensor
(
prompt_token_ids
,
vocab_size
,
device
),
vocab_size
,
device
),
output_token_ids
=
output_token_ids
,
output_token_ids
=
output_token_ids
,
spec_token_ids
=
[],
frequency_penalties
=
_create_penalty_tensor
(
batch_size
,
0.0
,
device
),
frequency_penalties
=
_create_penalty_tensor
(
batch_size
,
0.0
,
device
),
presence_penalties
=
_create_penalty_tensor
(
batch_size
,
0.0
,
device
),
presence_penalties
=
_create_penalty_tensor
(
batch_size
,
0.0
,
device
),
repetition_penalties
=
_create_penalty_tensor
(
batch_size
,
1.0
,
device
),
repetition_penalties
=
_create_penalty_tensor
(
batch_size
,
1.0
,
device
),
...
...
tests/v1/spec_decode/test_ngram.py
0 → 100644
View file @
80f63a39
# SPDX-License-Identifier: Apache-2.0
import
pytest
from
vllm.v1.spec_decode.ngram_proposer
import
NgramProposer
from
vllm.v1.utils
import
ConstantList
@pytest.fixture
def proposer():
    """Provide a freshly constructed NgramProposer."""
    ngram_proposer = NgramProposer()
    return ngram_proposer
def test_kmp_lps_array(proposer):
    """Check _kmp_lps_array against hand-written (pattern, result) pairs."""
    cases = [
        ([], []),
        ([1], [0]),
        ([1, 1, 1], [0, 1, 2]),
        ([1, 2, 3, 4], [0, 0, 0, 0]),
        ([1, 2, 1, 2, 3], [0, 0, 1, 2, 0]),
    ]
    for pattern, want in cases:
        assert proposer._kmp_lps_array(pattern) == want
def test_find_subarray_kmp(proposer):
    """Check _find_subarray_kmp on three fixed token sequences."""
    find = proposer._find_subarray_kmp

    tokens = ConstantList([1, 2, 3, 4, 1, 2, 3, 5, 6])
    assert find(tokens, 2, 2) is None

    tokens = ConstantList([1, 2, 3, 4, 1, 2, 3])
    for second_arg, third_arg, want in (
        (2, 3, [4, 1, 2]),
        (2, 2, [4, 1]),
        (1, 3, [4, 1, 2]),
        (1, 2, [4, 1]),
    ):
        assert find(tokens, second_arg, third_arg) == want

    tokens = ConstantList([1, 3, 6, 2, 3, 4, 1, 2, 3])
    assert find(tokens, 2, 3) == [4, 1, 2]
    # Return on the first match
    assert find(tokens, 1, 3) == [6, 2, 3]
\ No newline at end of file
tests/v1/worker/test_gpu_input_batch.py
View file @
80f63a39
...
@@ -92,6 +92,7 @@ def _construct_expected_sampling_metadata(
...
@@ -92,6 +92,7 @@ def _construct_expected_sampling_metadata(
device
=
device
),
device
=
device
),
all_greedy
=
False
,
all_greedy
=
False
,
all_random
=
True
,
all_random
=
True
,
rejection_sampling
=
False
,
top_p
=
torch
.
tensor
(
top_p
,
dtype
=
torch
.
float
,
device
=
device
),
top_p
=
torch
.
tensor
(
top_p
,
dtype
=
torch
.
float
,
device
=
device
),
top_k
=
torch
.
tensor
(
top_k
,
dtype
=
torch
.
int
,
device
=
device
),
top_k
=
torch
.
tensor
(
top_k
,
dtype
=
torch
.
int
,
device
=
device
),
no_top_p
=
all
(
x
==
1.0
for
x
in
top_p
),
no_top_p
=
all
(
x
==
1.0
for
x
in
top_p
),
...
@@ -116,6 +117,7 @@ def _construct_expected_sampling_metadata(
...
@@ -116,6 +117,7 @@ def _construct_expected_sampling_metadata(
dtype
=
torch
.
float
,
dtype
=
torch
.
float
,
device
=
device
),
device
=
device
),
output_token_ids
=
output_token_ids
,
output_token_ids
=
output_token_ids
,
spec_token_ids
=
[],
min_tokens
=
min_tokens
,
min_tokens
=
min_tokens
,
stop_token_ids
=
stop_token_ids
,
stop_token_ids
=
stop_token_ids
,
no_penalties
=
(
all
(
x
==
0
for
x
in
presence_penalties
)
no_penalties
=
(
all
(
x
==
0
for
x
in
presence_penalties
)
...
@@ -205,7 +207,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
...
@@ -205,7 +207,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
# Generate the sampling metadata
# Generate the sampling metadata
sampling_metadata
=
input_batch
.
make_sampling_metadata
(
sampling_metadata
=
input_batch
.
make_sampling_metadata
(
req_id_output_token_ids
,
skip_copy
=
False
)
req_id_output_token_ids
,
req_id_to_spec_token_ids
=
{},
skip_copy
=
False
)
# Create expected output.
# Create expected output.
expected_sampling_metadata
=
_construct_expected_sampling_metadata
(
expected_sampling_metadata
=
_construct_expected_sampling_metadata
(
...
...
tests/v1/worker/test_gpu_model_runner.py
View file @
80f63a39
...
@@ -66,6 +66,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
...
@@ -66,6 +66,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
scheduled_cached_reqs
=
[],
scheduled_cached_reqs
=
[],
num_scheduled_tokens
=
num_scheduled_tokens
,
num_scheduled_tokens
=
num_scheduled_tokens
,
total_num_scheduled_tokens
=
total_num_scheduled_tokens
,
total_num_scheduled_tokens
=
total_num_scheduled_tokens
,
scheduled_spec_decode_tokens
=
{},
scheduled_encoder_inputs
=
{},
scheduled_encoder_inputs
=
{},
num_common_prefix_blocks
=
0
,
num_common_prefix_blocks
=
0
,
finished_req_ids
=
set
(),
finished_req_ids
=
set
(),
...
@@ -109,6 +110,7 @@ def test_update_states_request_finished(model_runner):
...
@@ -109,6 +110,7 @@ def test_update_states_request_finished(model_runner):
scheduled_cached_reqs
=
[],
scheduled_cached_reqs
=
[],
num_scheduled_tokens
=
{},
num_scheduled_tokens
=
{},
total_num_scheduled_tokens
=
0
,
total_num_scheduled_tokens
=
0
,
scheduled_spec_decode_tokens
=
{},
scheduled_encoder_inputs
=
{},
scheduled_encoder_inputs
=
{},
num_common_prefix_blocks
=
0
,
num_common_prefix_blocks
=
0
,
finished_req_ids
=
{
req_id
},
finished_req_ids
=
{
req_id
},
...
@@ -137,6 +139,7 @@ def test_update_states_request_resumed(model_runner):
...
@@ -137,6 +139,7 @@ def test_update_states_request_resumed(model_runner):
scheduled_cached_reqs
=
[],
scheduled_cached_reqs
=
[],
num_scheduled_tokens
=
{},
num_scheduled_tokens
=
{},
total_num_scheduled_tokens
=
0
,
total_num_scheduled_tokens
=
0
,
scheduled_spec_decode_tokens
=
{},
scheduled_encoder_inputs
=
{},
scheduled_encoder_inputs
=
{},
num_common_prefix_blocks
=
0
,
num_common_prefix_blocks
=
0
,
finished_req_ids
=
{},
finished_req_ids
=
{},
...
@@ -160,6 +163,7 @@ def test_update_states_request_resumed(model_runner):
...
@@ -160,6 +163,7 @@ def test_update_states_request_resumed(model_runner):
scheduled_cached_reqs
=
[
cached_req_data
],
scheduled_cached_reqs
=
[
cached_req_data
],
num_scheduled_tokens
=
{
req_id
:
1
},
num_scheduled_tokens
=
{
req_id
:
1
},
total_num_scheduled_tokens
=
1
,
total_num_scheduled_tokens
=
1
,
scheduled_spec_decode_tokens
=
{},
scheduled_encoder_inputs
=
{},
scheduled_encoder_inputs
=
{},
num_common_prefix_blocks
=
0
,
num_common_prefix_blocks
=
0
,
finished_req_ids
=
set
(),
finished_req_ids
=
set
(),
...
@@ -188,6 +192,7 @@ def test_update_states_no_changes(model_runner):
...
@@ -188,6 +192,7 @@ def test_update_states_no_changes(model_runner):
scheduled_cached_reqs
=
[],
scheduled_cached_reqs
=
[],
num_scheduled_tokens
=
{
req_id
:
1
},
num_scheduled_tokens
=
{
req_id
:
1
},
total_num_scheduled_tokens
=
1
,
total_num_scheduled_tokens
=
1
,
scheduled_spec_decode_tokens
=
{},
scheduled_encoder_inputs
=
{},
scheduled_encoder_inputs
=
{},
num_common_prefix_blocks
=
0
,
num_common_prefix_blocks
=
0
,
finished_req_ids
=
set
(),
finished_req_ids
=
set
(),
...
@@ -220,6 +225,7 @@ def test_update_states_request_unscheduled(model_runner):
...
@@ -220,6 +225,7 @@ def test_update_states_request_unscheduled(model_runner):
scheduled_cached_reqs
=
[],
scheduled_cached_reqs
=
[],
num_scheduled_tokens
=
{
req_ids
[
0
]:
1
},
num_scheduled_tokens
=
{
req_ids
[
0
]:
1
},
total_num_scheduled_tokens
=
1
,
total_num_scheduled_tokens
=
1
,
scheduled_spec_decode_tokens
=
{},
scheduled_encoder_inputs
=
{},
scheduled_encoder_inputs
=
{},
num_common_prefix_blocks
=
0
,
num_common_prefix_blocks
=
0
,
finished_req_ids
=
set
(),
finished_req_ids
=
set
(),
...
...
vllm/platforms/cuda.py
View file @
80f63a39
...
@@ -124,9 +124,8 @@ class CudaPlatformBase(Platform):
...
@@ -124,9 +124,8 @@ class CudaPlatformBase(Platform):
"vllm.worker.multi_step_worker.MultiStepWorker"
"vllm.worker.multi_step_worker.MultiStepWorker"
elif
vllm_config
.
speculative_config
:
elif
vllm_config
.
speculative_config
:
if
envs
.
VLLM_USE_V1
:
if
envs
.
VLLM_USE_V1
:
raise
NotImplementedError
(
parallel_config
.
worker_cls
=
\
"Speculative decoding is not yet supported on VLLM V1."
"vllm.v1.worker.gpu_worker.Worker"
)
else
:
else
:
parallel_config
.
worker_cls
=
\
parallel_config
.
worker_cls
=
\
"vllm.spec_decode.spec_decode_worker.create_spec_worker"
"vllm.spec_decode.spec_decode_worker.create_spec_worker"
...
...
vllm/v1/core/kv_cache_manager.py
View file @
80f63a39
...
@@ -82,6 +82,11 @@ class KVCacheManager:
...
@@ -82,6 +82,11 @@ class KVCacheManager:
self
.
req_to_block_hashes
:
DefaultDict
[
self
.
req_to_block_hashes
:
DefaultDict
[
str
,
List
[
BlockHashType
]]
=
defaultdict
(
list
)
str
,
List
[
BlockHashType
]]
=
defaultdict
(
list
)
# {req_id: The number of cached blocks for this given request}
# This is used to track the number of cached blocks for each request.
# This is only used to track the RUNNING requests, we do not track the
# data for preempted ones.
self
.
num_cached_block
:
Dict
[
str
,
int
]
=
defaultdict
(
int
)
self
.
prefix_cache_stats
=
PrefixCacheStats
()
self
.
prefix_cache_stats
=
PrefixCacheStats
()
@
property
@
property
...
@@ -241,23 +246,25 @@ class KVCacheManager:
...
@@ -241,23 +246,25 @@ class KVCacheManager:
if
not
self
.
enable_caching
:
if
not
self
.
enable_caching
:
return
new_blocks
return
new_blocks
# NOTE(rickyx): We are assuming the `num_tokens` are actual
num_cached_blocks
=
self
.
num_cached_block
[
request
.
request_id
]
# tokens rather than lookahead slots (e.g. for speculative decoding).
# Speculated tokens might be rejected in the future, so we do
# TODO(rickyx): When supporting speculative decoding, we will need to
# not cache any speculated tokens. We only cache blocks with
# differentiate between them so that we can know how many blocks are
# generated (accepted) tokens.
# full after appending the actual tokens.
num_full_blocks_after_append
=
(
num_computed_tokens
+
num_tokens
-
len
(
num_full_blocks
=
(
num_computed_tokens
+
num_tokens
)
//
self
.
block_size
request
.
spec_token_ids
))
//
self
.
block_size
num_computed_full_blocks
=
num_computed_tokens
//
self
.
block_size
new_full_blocks
=
req_blocks
[
new_full_blocks
=
req_blocks
[
num_computed_full_blocks
:
num_full_blocks
]
num_cached_blocks
:
num_full_blocks_after_append
]
if
new_full_blocks
:
if
new_full_blocks
:
self
.
_cache_full_blocks
(
self
.
_cache_full_blocks
(
request
=
request
,
request
=
request
,
blk_start_idx
=
num_c
omputed_full
_blocks
,
blk_start_idx
=
num_c
ached
_blocks
,
# The new full blocks are the full blocks that are not computed.
# The new full blocks are the full blocks that are not computed.
full_blocks
=
new_full_blocks
,
full_blocks
=
new_full_blocks
,
prev_block
=
(
req_blocks
[
num_computed_full_blocks
-
1
]
prev_block
=
(
req_blocks
[
num_cached_blocks
-
if
num_computed_full_blocks
>
0
else
None
))
1
]
if
num_cached_blocks
>
0
else
None
))
self
.
num_cached_block
[
request
.
request_id
]
=
num_full_blocks_after_append
return
new_blocks
return
new_blocks
def
free
(
self
,
request
:
Request
)
->
None
:
def
free
(
self
,
request
:
Request
)
->
None
:
...
@@ -281,6 +288,8 @@ class KVCacheManager:
...
@@ -281,6 +288,8 @@ class KVCacheManager:
if
block
.
ref_cnt
==
0
:
if
block
.
ref_cnt
==
0
:
self
.
free_block_queue
.
append
(
block
)
self
.
free_block_queue
.
append
(
block
)
self
.
num_cached_block
.
pop
(
request
.
request_id
,
None
)
def
reset_prefix_cache
(
self
)
->
bool
:
def
reset_prefix_cache
(
self
)
->
bool
:
"""Reset prefix cache. This function may be used in RLHF
"""Reset prefix cache. This function may be used in RLHF
flows to invalid prefix caching after the weights are updated,
flows to invalid prefix caching after the weights are updated,
...
...
vllm/v1/core/scheduler.py
View file @
80f63a39
...
@@ -4,7 +4,8 @@ import time
...
@@ -4,7 +4,8 @@ import time
from
collections
import
deque
from
collections
import
deque
from
typing
import
Deque
,
Dict
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Deque
,
Dict
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
ModelConfig
,
SchedulerConfig
from
vllm.config
import
(
CacheConfig
,
LoRAConfig
,
ModelConfig
,
SchedulerConfig
,
SpeculativeConfig
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.v1.core.encoder_cache_manager
import
(
EncoderCacheManager
,
from
vllm.v1.core.encoder_cache_manager
import
(
EncoderCacheManager
,
compute_encoder_budget
)
compute_encoder_budget
)
...
@@ -28,11 +29,13 @@ class Scheduler:
...
@@ -28,11 +29,13 @@ class Scheduler:
model_config
:
ModelConfig
,
model_config
:
ModelConfig
,
cache_config
:
CacheConfig
,
cache_config
:
CacheConfig
,
lora_config
:
Optional
[
LoRAConfig
],
lora_config
:
Optional
[
LoRAConfig
],
speculative_config
:
Optional
[
SpeculativeConfig
],
log_stats
:
bool
,
log_stats
:
bool
,
)
->
None
:
)
->
None
:
self
.
scheduler_config
=
scheduler_config
self
.
scheduler_config
=
scheduler_config
self
.
cache_config
=
cache_config
self
.
cache_config
=
cache_config
self
.
lora_config
=
lora_config
self
.
lora_config
=
lora_config
self
.
speculative_config
=
speculative_config
self
.
log_stats
=
log_stats
self
.
log_stats
=
log_stats
# Scheduling constraints.
# Scheduling constraints.
...
@@ -96,12 +99,14 @@ class Scheduler:
...
@@ -96,12 +99,14 @@ class Scheduler:
def
schedule
(
self
)
->
"SchedulerOutput"
:
def
schedule
(
self
)
->
"SchedulerOutput"
:
# NOTE(woosuk) on the scheduling algorithm:
# NOTE(woosuk) on the scheduling algorithm:
# There's no "decoding phase" nor "prefill phase" in the scheduler.
# There's no "decoding phase" nor "prefill phase" in the scheduler.
# Each request just has the num_computed_tokens and num_tokens,
# Each request just has the num_computed_tokens and
# which is equal to len(prompt_token_ids) + len(output_token_ids).
# num_tokens_with_spec. num_tokens_with_spec =
# len(prompt_token_ids) + len(output_token_ids) + len(spec_token_ids).
# At each step, the scheduler tries to assign tokens to the requests
# At each step, the scheduler tries to assign tokens to the requests
# so that each request's num_computed_tokens can catch up its
# so that each request's num_computed_tokens can catch up its
# num_tokens. This is general enough to cover chunked prefills,
# num_tokens_with_spec. This is general enough to cover
# prefix caching, and the "jump decoding" optimization in the future.
# chunked prefills, prefix caching, speculative decoding,
# and the "jump decoding" optimization in the future.
scheduled_new_reqs
:
List
[
Request
]
=
[]
scheduled_new_reqs
:
List
[
Request
]
=
[]
scheduled_resumed_reqs
:
List
[
Request
]
=
[]
scheduled_resumed_reqs
:
List
[
Request
]
=
[]
...
@@ -114,7 +119,8 @@ class Scheduler:
...
@@ -114,7 +119,8 @@ class Scheduler:
# Encoder-related.
# Encoder-related.
scheduled_encoder_inputs
:
Dict
[
str
,
List
[
int
]]
=
{}
scheduled_encoder_inputs
:
Dict
[
str
,
List
[
int
]]
=
{}
encoder_budget
=
self
.
max_num_encoder_input_tokens
encoder_budget
=
self
.
max_num_encoder_input_tokens
# Spec decode-related.
scheduled_spec_decode_tokens
:
Dict
[
str
,
List
[
int
]]
=
{}
scheduled_timestamp
=
time
.
monotonic
()
scheduled_timestamp
=
time
.
monotonic
()
# First, schedule the RUNNING requests.
# First, schedule the RUNNING requests.
...
@@ -126,7 +132,8 @@ class Scheduler:
...
@@ -126,7 +132,8 @@ class Scheduler:
req_index
+=
1
req_index
+=
1
continue
continue
num_new_tokens
=
request
.
num_tokens
-
request
.
num_computed_tokens
num_new_tokens
=
(
request
.
num_tokens_with_spec
-
request
.
num_computed_tokens
)
num_new_tokens
=
min
(
num_new_tokens
,
token_budget
)
num_new_tokens
=
min
(
num_new_tokens
,
token_budget
)
assert
num_new_tokens
>
0
assert
num_new_tokens
>
0
...
@@ -189,6 +196,11 @@ class Scheduler:
...
@@ -189,6 +196,11 @@ class Scheduler:
self
.
encoder_cache_manager
.
allocate
(
request
,
i
)
self
.
encoder_cache_manager
.
allocate
(
request
,
i
)
encoder_budget
=
new_encoder_budget
encoder_budget
=
new_encoder_budget
# Speculative decode related.
if
request
.
spec_token_ids
:
scheduled_spec_decode_tokens
[
request
.
request_id
]
=
request
.
spec_token_ids
# Record the LoRAs in scheduled_running_reqs
# Record the LoRAs in scheduled_running_reqs
requested_loras
:
Set
[
int
]
=
set
()
requested_loras
:
Set
[
int
]
=
set
()
if
self
.
lora_config
:
if
self
.
lora_config
:
...
@@ -338,6 +350,7 @@ class Scheduler:
...
@@ -338,6 +350,7 @@ class Scheduler:
num_scheduled_tokens
=
num_scheduled_tokens
,
num_scheduled_tokens
=
num_scheduled_tokens
,
total_num_scheduled_tokens
=
total_num_scheduled_tokens
,
total_num_scheduled_tokens
=
total_num_scheduled_tokens
,
scheduled_encoder_inputs
=
scheduled_encoder_inputs
,
scheduled_encoder_inputs
=
scheduled_encoder_inputs
,
scheduled_spec_decode_tokens
=
scheduled_spec_decode_tokens
,
num_common_prefix_blocks
=
num_common_prefix_blocks
,
num_common_prefix_blocks
=
num_common_prefix_blocks
,
# finished_req_ids is an existing state in the scheduler,
# finished_req_ids is an existing state in the scheduler,
# instead of being newly scheduled in this step.
# instead of being newly scheduled in this step.
...
@@ -447,11 +460,11 @@ class Scheduler:
...
@@ -447,11 +460,11 @@ class Scheduler:
scheduler_output
:
"SchedulerOutput"
,
scheduler_output
:
"SchedulerOutput"
,
model_runner_output
:
"ModelRunnerOutput"
,
model_runner_output
:
"ModelRunnerOutput"
,
)
->
EngineCoreOutputs
:
)
->
EngineCoreOutputs
:
# NOTE(woosuk): This method doesn't consider speculative decoding.
sampled_token_ids
=
model_runner_output
.
sampled_token_ids
sampled_token_ids
=
model_runner_output
.
sampled_token_ids
logprobs
=
model_runner_output
.
logprobs
logprobs
=
model_runner_output
.
logprobs
prompt_logprobs_dict
=
model_runner_output
.
prompt_logprobs_dict
prompt_logprobs_dict
=
model_runner_output
.
prompt_logprobs_dict
num_scheduled_tokens
=
scheduler_output
.
num_scheduled_tokens
num_scheduled_tokens
=
scheduler_output
.
num_scheduled_tokens
new_running
:
List
[
Request
]
=
[]
new_running
:
List
[
Request
]
=
[]
outputs
:
List
[
EngineCoreOutput
]
=
[]
outputs
:
List
[
EngineCoreOutput
]
=
[]
...
@@ -466,11 +479,30 @@ class Scheduler:
...
@@ -466,11 +479,30 @@ class Scheduler:
new_running
.
append
(
request
)
new_running
.
append
(
request
)
continue
continue
req_index
=
model_runner_output
.
req_id_to_index
[
req_id
]
generated_token_ids
=
sampled_token_ids
[
req_index
]
if
req_id
not
in
scheduler_output
.
scheduled_spec_decode_tokens
:
# When the request's num_computed_tokens catches up
# its num_tokens, the request generates output tokens.
# Otherwise, we ignore the sampler output for the request.
request
.
num_computed_tokens
+=
num_tokens_scheduled
request
.
num_computed_tokens
+=
num_tokens_scheduled
# When the request's num_computed_tokens catches up its num_tokens,
# the request generates output tokens. Otherwise, we ignore the
# sampler output for the request.
assert
request
.
num_computed_tokens
<=
request
.
num_tokens
assert
request
.
num_computed_tokens
<=
request
.
num_tokens
else
:
# num_computed_tokens_step represents the number of tokens
# processed in the current step, considering scheduled
# tokens and rejections.
# It is calculated as:
# num_computed_tokens_step = num_scheduled_tokens -
# num_tokens_rejected,
# where num_tokens_rejected is given by:
# len(scheduled_spec_token_ids) + 1 - len(generated_token_ids).
scheduled_spec_token_ids
=
(
scheduler_output
.
scheduled_spec_decode_tokens
[
req_id
])
num_computed_tokens_step
=
num_scheduled_tokens
[
req_id
]
-
(
len
(
scheduled_spec_token_ids
)
+
1
-
len
(
generated_token_ids
))
request
.
num_computed_tokens
+=
num_computed_tokens_step
cached_encoder_input_ids
=
(
cached_encoder_input_ids
=
(
self
.
encoder_cache_manager
.
get_cached_input_ids
(
request
))
self
.
encoder_cache_manager
.
get_cached_input_ids
(
request
))
...
@@ -485,27 +517,32 @@ class Scheduler:
...
@@ -485,27 +517,32 @@ class Scheduler:
self
.
encoder_cache_manager
.
free_encoder_input
(
self
.
encoder_cache_manager
.
free_encoder_input
(
request
,
input_id
)
request
,
input_id
)
if
request
.
num_computed_tokens
>=
request
.
num_tokens
:
# Clear the spec tokens as the request has generated
# a new token. Here, we assume all spec tokens are verified
# if we perform speculative decoding for this request.
# Therefore, we can clear all spec tokens after
# the generation step.
request
.
clear_spec_tokens
()
# Get prompt logprobs for this request.
# Get prompt logprobs for this request.
prompt_logprobs_tensors
=
prompt_logprobs_dict
.
get
(
req_id
)
prompt_logprobs_tensors
=
prompt_logprobs_dict
.
get
(
req_id
)
stopped
=
False
stopped
=
False
new_logprobs
=
None
new_logprobs
=
None
new_token_ids
=
None
new_token_ids
:
List
[
int
]
=
[]
if
request
.
num_computed_tokens
==
request
.
num_tokens
:
if
request
.
num_computed_tokens
>=
request
.
num_tokens
:
req_index
=
model_runner_output
.
req_id_to_index
[
req_id
]
for
output_token_id
in
generated_token_ids
:
# NOTE(woosuk): Currently, we assume that each request
request
.
append_output_token_ids
(
output_token_id
)
# generates at most one token at each step.
new_token_ids
.
append
(
output_token_id
)
token_id
=
sampled_token_ids
[
req_index
]
request
.
append_output_token_ids
(
token_id
)
num_new_tokens
=
1
# TODO: Update the KV cache manager for prefix caching.
# Check for stop and update request state.
# Check for stop and update request state.
# This must be called before we make the EngineCoreOutput.
# This must be called before we make the EngineCoreOutput.
stopped
=
self
.
_check_stop
(
request
)
stopped
=
self
.
_check_stop
(
request
)
if
stopped
:
if
stopped
:
self
.
_free_request
(
request
)
self
.
_free_request
(
request
)
break
# Extract sample logprobs if needed.
# Extract sample logprobs if needed.
if
request
.
sampling_params
.
logprobs
is
not
None
:
if
request
.
sampling_params
.
logprobs
is
not
None
:
...
@@ -514,8 +551,6 @@ class Scheduler:
...
@@ -514,8 +551,6 @@ class Scheduler:
# the outer lists can be of length > 1.
# the outer lists can be of length > 1.
new_logprobs
=
logprobs
.
slice
(
req_index
,
req_index
+
1
)
new_logprobs
=
logprobs
.
slice
(
req_index
,
req_index
+
1
)
new_token_ids
=
request
.
output_token_ids
[
-
num_new_tokens
:]
# Transmit partial if chunked prefill & prompt logprobs is enabled
# Transmit partial if chunked prefill & prompt logprobs is enabled
if
new_token_ids
or
prompt_logprobs_tensors
is
not
None
:
if
new_token_ids
or
prompt_logprobs_tensors
is
not
None
:
# Add EngineCoreOutput for this Request.
# Add EngineCoreOutput for this Request.
...
...
vllm/v1/core/scheduler_output.py
View file @
80f63a39
...
@@ -91,6 +91,10 @@ class SchedulerOutput:
...
@@ -91,6 +91,10 @@ class SchedulerOutput:
# Total number of tokens scheduled for all requests.
# Total number of tokens scheduled for all requests.
# Equal to sum(num_scheduled_tokens.values())
# Equal to sum(num_scheduled_tokens.values())
total_num_scheduled_tokens
:
int
total_num_scheduled_tokens
:
int
# req_id -> spec_decode_tokens
# If a request does not have any spec decode tokens, it will
# not be included in the dictionary.
scheduled_spec_decode_tokens
:
Dict
[
str
,
List
[
int
]]
# req_id -> encoder input indices that need processing.
# req_id -> encoder input indices that need processing.
# E.g., if a request has [0, 1], it could mean the vision encoder needs
# E.g., if a request has [0, 1], it could mean the vision encoder needs
# to process that the request's 0-th and 1-th images in the current step.
# to process that the request's 0-th and 1-th images in the current step.
...
...
vllm/v1/engine/core.py
View file @
80f63a39
...
@@ -27,6 +27,7 @@ from vllm.v1.executor.abstract import Executor
...
@@ -27,6 +27,7 @@ from vllm.v1.executor.abstract import Executor
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.request
import
Request
,
RequestStatus
from
vllm.v1.request
import
Request
,
RequestStatus
from
vllm.v1.serial_utils
import
MsgpackDecoder
,
MsgpackEncoder
from
vllm.v1.serial_utils
import
MsgpackDecoder
,
MsgpackEncoder
from
vllm.v1.spec_decode.ngram_proposer
import
NgramProposer
from
vllm.version
import
__version__
as
VLLM_VERSION
from
vllm.version
import
__version__
as
VLLM_VERSION
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -65,6 +66,7 @@ class EngineCore:
...
@@ -65,6 +66,7 @@ class EngineCore:
model_config
=
vllm_config
.
model_config
,
model_config
=
vllm_config
.
model_config
,
cache_config
=
vllm_config
.
cache_config
,
cache_config
=
vllm_config
.
cache_config
,
lora_config
=
vllm_config
.
lora_config
,
lora_config
=
vllm_config
.
lora_config
,
speculative_config
=
vllm_config
.
speculative_config
,
log_stats
=
self
.
log_stats
,
log_stats
=
self
.
log_stats
,
)
)
...
@@ -84,6 +86,15 @@ class EngineCore:
...
@@ -84,6 +86,15 @@ class EngineCore:
self
.
batch_queue_size
)
self
.
batch_queue_size
)
self
.
batch_queue
=
queue
.
Queue
(
self
.
batch_queue_size
)
self
.
batch_queue
=
queue
.
Queue
(
self
.
batch_queue_size
)
# Setup speculative decode.
# TODO: find a better way to check if we are using ngram.
self
.
use_spec_decode
=
False
if
self
.
scheduler
.
speculative_config
:
assert
self
.
scheduler
.
speculative_config
.
ngram_prompt_lookup_min
\
,
"Only ngram spec decode is supported in V1."
self
.
proposer
=
NgramProposer
()
self
.
use_spec_decode
=
True
def
_initialize_kv_caches
(
self
,
def
_initialize_kv_caches
(
self
,
vllm_config
:
VllmConfig
)
->
Tuple
[
int
,
int
]:
vllm_config
:
VllmConfig
)
->
Tuple
[
int
,
int
]:
start
=
time
.
time
()
start
=
time
.
time
()
...
@@ -147,6 +158,9 @@ class EngineCore:
...
@@ -147,6 +158,9 @@ class EngineCore:
return
EngineCoreOutputs
(
return
EngineCoreOutputs
(
outputs
=
[],
scheduler_stats
=
self
.
scheduler
.
make_stats
())
outputs
=
[],
scheduler_stats
=
self
.
scheduler
.
make_stats
())
if
self
.
use_spec_decode
:
self
.
propose_tokens
()
scheduler_output
=
self
.
scheduler
.
schedule
()
scheduler_output
=
self
.
scheduler
.
schedule
()
output
=
self
.
model_executor
.
execute_model
(
scheduler_output
)
output
=
self
.
model_executor
.
execute_model
(
scheduler_output
)
engine_core_outputs
=
self
.
scheduler
.
update_from_output
(
engine_core_outputs
=
self
.
scheduler
.
update_from_output
(
...
@@ -207,6 +221,23 @@ class EngineCore:
...
@@ -207,6 +221,23 @@ class EngineCore:
def
profile
(
self
,
is_start
:
bool
=
True
):
def
profile
(
self
,
is_start
:
bool
=
True
):
self
.
model_executor
.
profile
(
is_start
)
self
.
model_executor
.
profile
(
is_start
)
def propose_tokens(self):
    """Attach n-gram draft tokens to every eligible running request.

    A request is skipped when it is still in chunked prefill (more than
    one of its tokens is not yet computed) or when it already carries
    speculative tokens. For the rest, the proposer is queried with the
    request's full token history; a non-empty proposal is stored on the
    request via ``append_spec_token_ids``.
    """
    spec_config = self.scheduler.speculative_config
    assert spec_config is not None
    for request in self.scheduler.running:
        # Chunked prefill in progress: nothing to speculate on yet.
        in_prefill = (request.num_computed_tokens
                      < request.num_tokens - 1)
        # Requests that already hold draft tokens are left untouched.
        if in_prefill or request.spec_token_ids:
            continue
        proposal = self.proposer.propose(
            request.all_token_ids,
            spec_config.ngram_prompt_lookup_min,
            spec_config.num_speculative_tokens,
        )
        # The proposer yields None (or nothing) when no n-gram matched.
        if not proposal:
            continue
        request.append_spec_token_ids(proposal)
def
reset_prefix_cache
(
self
):
def
reset_prefix_cache
(
self
):
self
.
scheduler
.
reset_prefix_cache
()
self
.
scheduler
.
reset_prefix_cache
()
...
...
vllm/v1/outputs.py
View file @
80f63a39
...
@@ -43,7 +43,10 @@ class LogprobsTensors(NamedTuple):
...
@@ -43,7 +43,10 @@ class LogprobsTensors(NamedTuple):
@
dataclass
@
dataclass
class
SamplerOutput
:
class
SamplerOutput
:
# [num_reqs]
# [num_reqs, max_num_generated_tokens]
# Different requests can have different number of generated tokens.
# All requests are padded to max_num_generated_tokens.
# INVALID_TOKEN_ID (-1 by default) is used for padding.
sampled_token_ids
:
torch
.
Tensor
sampled_token_ids
:
torch
.
Tensor
logprobs_tensors
:
Optional
[
LogprobsTensors
]
logprobs_tensors
:
Optional
[
LogprobsTensors
]
...
@@ -58,8 +61,11 @@ class ModelRunnerOutput:
...
@@ -58,8 +61,11 @@ class ModelRunnerOutput:
# req_id -> index
# req_id -> index
req_id_to_index
:
Dict
[
str
,
int
]
req_id_to_index
:
Dict
[
str
,
int
]
# [num_reqs]
# num_reqs x num_generated_tokens
sampled_token_ids
:
List
[
int
]
# num_generated_tokens is the number of tokens
# generated in the current step. It can be different for
# each request due to speculative/jump decoding.
sampled_token_ids
:
List
[
List
[
int
]]
# [num_reqs, max_num_logprobs + 1]
# [num_reqs, max_num_logprobs + 1]
# [num_reqs, max_num_logprobs + 1]
# [num_reqs, max_num_logprobs + 1]
...
...
vllm/v1/request.py
View file @
80f63a39
...
@@ -46,6 +46,7 @@ class Request:
...
@@ -46,6 +46,7 @@ class Request:
self
.
num_prompt_tokens
=
len
(
self
.
prompt_token_ids
)
self
.
num_prompt_tokens
=
len
(
self
.
prompt_token_ids
)
self
.
_output_token_ids
:
List
[
int
]
=
[]
self
.
_output_token_ids
:
List
[
int
]
=
[]
self
.
_all_token_ids
:
List
[
int
]
=
self
.
prompt_token_ids
.
copy
()
self
.
_all_token_ids
:
List
[
int
]
=
self
.
prompt_token_ids
.
copy
()
self
.
spec_token_ids
:
List
[
int
]
=
[]
self
.
num_computed_tokens
=
0
self
.
num_computed_tokens
=
0
# Multi-modal related
# Multi-modal related
...
@@ -103,10 +104,26 @@ class Request:
...
@@ -103,10 +104,26 @@ class Request:
self
.
_output_token_ids
.
extend
(
token_ids
)
self
.
_output_token_ids
.
extend
(
token_ids
)
self
.
_all_token_ids
.
extend
(
token_ids
)
self
.
_all_token_ids
.
extend
(
token_ids
)
def append_spec_token_ids(
    self,
    token_ids: Union[int, List[int]],
) -> None:
    """Add one or more speculative (draft) token IDs to this request.

    Args:
        token_ids: A single token ID, or a list of token IDs, appended
            in order to ``self.spec_token_ids``.
    """
    # Normalise the scalar case so a single in-place extend covers both.
    incoming = [token_ids] if isinstance(token_ids, int) else token_ids
    self.spec_token_ids.extend(incoming)
def clear_spec_tokens(self) -> None:
    """Drop all speculative token IDs, emptying the list in place."""
    del self.spec_token_ids[:]
@
property
@
property
def
num_tokens
(
self
)
->
int
:
def
num_tokens
(
self
)
->
int
:
return
len
(
self
.
_all_token_ids
)
return
len
(
self
.
_all_token_ids
)
@property
def num_tokens_with_spec(self) -> int:
    """Total token count including pending speculative tokens.

    This is the length of ``_all_token_ids`` plus the length of
    ``spec_token_ids``.
    """
    num_regular = len(self._all_token_ids)
    num_speculative = len(self.spec_token_ids)
    return num_regular + num_speculative
@
property
@
property
def
num_output_tokens
(
self
)
->
int
:
def
num_output_tokens
(
self
)
->
int
:
return
len
(
self
.
_output_token_ids
)
return
len
(
self
.
_output_token_ids
)
...
...
vllm/v1/sample/metadata.py
View file @
80f63a39
...
@@ -12,6 +12,8 @@ class SamplingMetadata:
...
@@ -12,6 +12,8 @@ class SamplingMetadata:
temperature
:
torch
.
Tensor
temperature
:
torch
.
Tensor
all_greedy
:
bool
all_greedy
:
bool
all_random
:
bool
all_random
:
bool
rejection_sampling
:
bool
spec_token_ids
:
List
[
List
[
int
]]
top_p
:
torch
.
Tensor
top_p
:
torch
.
Tensor
top_k
:
torch
.
Tensor
top_k
:
torch
.
Tensor
...
...
vllm/v1/sample/rejection_sampler.py
0 → 100644
View file @
80f63a39
# SPDX-License-Identifier: Apache-2.0
import
torch
import
torch.nn
as
nn
from
torch.nn.utils.rnn
import
pad_sequence
from
vllm.logger
import
init_logger
from
vllm.v1.outputs
import
SamplerOutput
from
vllm.v1.sample.metadata
import
SamplingMetadata
# FlashInfer is optional: record whether its sampling module could be
# imported so the sampler can fall back to the PyTorch-native path.
try:
    import flashinfer.sampling as fs
    is_flashinfer_available = True
except ImportError:
    is_flashinfer_available = False

logger = init_logger(__name__)
# Sentinel token ID used by the rejection sampler to pad token-ID tensors
# and to mark positions that hold no generated token.
INVALID_TOKEN_ID = -1
class RejectionSampler(nn.Module):
    """Greedy verifier for speculative (draft) tokens.

    For each request, the target model's greedy (argmax) tokens are
    compared with the proposed draft tokens. Only greedy sampling is
    supported; any other sampling mode raises ``NotImplementedError``.
    """

    def forward(self, logits: torch.Tensor,
                sampling_metadata: SamplingMetadata) -> SamplerOutput:
        """Run rejection sampling, preferring FlashInfer when importable.

        Args:
            logits: Target-model logits. Both implementations consume
                ``len(spec_token_ids[i]) + 1`` consecutive rows per
                request, with the vocabulary on the last dimension.
            sampling_metadata: Supplies ``all_greedy`` and the per-request
                ``spec_token_ids``.

        Returns:
            A ``SamplerOutput`` whose ``logprobs_tensors`` is ``None``.

        Raises:
            NotImplementedError: If ``sampling_metadata.all_greedy`` is
                false.
        """
        if not sampling_metadata.all_greedy:
            raise NotImplementedError(
                "Only greedy sampling is supported by rejection sampler.")

        # NOTE(review): both branches log on every forward call; consider
        # logging once if this turns out to be noisy.
        if is_flashinfer_available:
            logger.info("Using FlashInfer for rejection sampling.")
            return RejectionSampler.flashinfer_sample(logits,
                                                      sampling_metadata)
        else:
            logger.warning(
                "FlashInfer is not available. Falling back to the PyTorch-"
                "native implementation of rejection sampling.")
            return RejectionSampler.greedy_sample_native(
                logits, sampling_metadata)

    @staticmethod
    def flashinfer_sample(
            logits: torch.Tensor,
            sampling_metadata: SamplingMetadata) -> SamplerOutput:
        """Verify draft tokens with FlashInfer's chain speculative sampling.

        Draft and target token IDs are padded to a common length with
        ``INVALID_TOKEN_ID``, turned into one-hot probability tensors,
        and passed to ``fs.chain_speculative_sampling`` together with
        all-zero uniform samples.

        Returns:
            A ``SamplerOutput`` wrapping the first value returned by
            ``fs.chain_speculative_sampling``; ``logprobs_tensors`` is
            ``None``.
        """
        # NOTE: The following input preparation can be moved
        # to the model runner with a persistent manner for better
        # performance.
        spec_token_ids = sampling_metadata.spec_token_ids
        max_spec_len = max(len(s) for s in spec_token_ids)
        batch_size = len(spec_token_ids)
        # Draft IDs are assembled on CPU and moved to the logits device
        # in a single transfer further below.
        draft_token_ids = torch.full((batch_size, max_spec_len),
                                     INVALID_TOKEN_ID,
                                     device="cpu",
                                     dtype=torch.long)

        # One extra column for the token that follows the last draft
        # token of each request.
        target_token_ids = torch.full((batch_size, max_spec_len + 1),
                                      fill_value=INVALID_TOKEN_ID,
                                      device=logits.device,
                                      dtype=torch.long)

        # TODO: Vectorize the following loop for better performance.
        start_loc = 0
        for i in range(batch_size):
            num_spec_tokens = len(spec_token_ids[i])
            draft_token_ids[i, :num_spec_tokens] = torch.tensor(
                spec_token_ids[i], device="cpu", dtype=torch.long)
            # Request i owns the next num_spec_tokens + 1 logits rows.
            end_loc = start_loc + num_spec_tokens + 1
            # Assume greedy sampling.
            target_token_ids[i, :num_spec_tokens + 1] = torch.argmax(
                logits[start_loc:end_loc], dim=-1)
            start_loc = end_loc

        vocab_size = logits.size(-1)
        # NOTE: CPU <-> GPU synchronization happens here.
        draft_token_ids = draft_token_ids.to(logits.device)
        draft_probs = RejectionSampler._create_greedy_token_probs(
            draft_token_ids, vocab_size, logits.device)
        target_probs = RejectionSampler._create_greedy_token_probs(
            target_token_ids, vocab_size, logits.device)
        # All-zero "random" samples. NOTE(review): presumably this makes
        # the accept/reject decision deterministic for the one-hot
        # distributions above -- confirm against the FlashInfer docs.
        uniform_samples = torch.zeros(batch_size,
                                      max_spec_len + 1,
                                      device=logits.device)

        sampled_token_ids, _, _ = fs.chain_speculative_sampling(
            draft_probs,
            draft_token_ids,
            uniform_samples,
            target_probs,
        )
        return SamplerOutput(sampled_token_ids=sampled_token_ids,
                             logprobs_tensors=None)

    # TODO: The following method can be optimized for better performance.
    @staticmethod
    def greedy_sample_native(
            logits: torch.Tensor,
            sampling_metadata: SamplingMetadata) -> SamplerOutput:
        """Verify draft tokens with plain PyTorch ops (greedy only).

        For each request, the leading draft tokens that equal the target
        argmax tokens are kept, followed by one more target token: the
        correction at the first mismatch, or the bonus token when every
        draft token matched. All other positions are set to
        ``INVALID_TOKEN_ID``.

        Returns:
            A ``SamplerOutput`` whose ``sampled_token_ids`` has shape
            ``[batch_size, max_spec_len + 1]``; ``logprobs_tensors`` is
            ``None``.
        """
        spec_lens = [len(x) for x in sampling_metadata.spec_token_ids]
        # Add 1 to include the 'bonus' token.
        sample_lens = [x + 1 for x in spec_lens]

        # Greedy target tokens, regrouped per request and padded into a
        # rectangular [batch_size, max(sample_lens)] tensor.
        output_token_ids = logits.argmax(dim=-1).view(-1)
        output_token_ids = output_token_ids.split(sample_lens)
        output_token_ids = pad_sequence(output_token_ids,
                                        batch_first=True,
                                        padding_value=INVALID_TOKEN_ID)

        # Convert each request's spec token IDs to a tensor, then pad.
        spec_token_ids = [
            torch.tensor(x,
                         dtype=output_token_ids.dtype,
                         device=output_token_ids.device)
            for x in sampling_metadata.spec_token_ids
        ]
        spec_token_ids = pad_sequence(spec_token_ids,
                                      batch_first=True,
                                      padding_value=INVALID_TOKEN_ID)

        # Produce a mask that remains 1 (True) until the first
        # mismatch (cumprod turns 0 after a mismatch).
        accept_mask = (output_token_ids[:, :-1] == spec_token_ids).cumprod(
            dim=1)
        # Identify valid positions (non-padding).
        valid_mask = output_token_ids != INVALID_TOKEN_ID
        # Generate mask with bonus token. The appended zero column is
        # switched on below for rows where every draft token matched.
        generate_mask = torch.cat([
            accept_mask,
            torch.zeros(accept_mask.size(0), 1, device=accept_mask.device)
        ],
                                  dim=1).to(torch.bool) & valid_mask
        zeros_mask = (generate_mask == 0)
        first_zero_idx = zeros_mask.float().argmax(dim=1)
        # Figure out which rows actually contain at least one zero.
        rows_with_zero = zeros_mask.any(dim=1)
        # Use indexing to set the first zero in each of those rows to 1,
        # which keeps the target token right after the accepted prefix.
        generate_mask[rows_with_zero, first_zero_idx[rows_with_zero]] = 1

        output_token_ids[~generate_mask] = INVALID_TOKEN_ID
        return SamplerOutput(sampled_token_ids=output_token_ids,
                             logprobs_tensors=None)

    @staticmethod
    def _create_greedy_token_probs(token_ids: torch.Tensor, vocab_size: int,
                                   out_device: torch.device) -> torch.Tensor:
        """Build one-hot probability rows for a 2-D tensor of token IDs.

        Args:
            token_ids: ``[batch_size, num_tokens]`` token IDs, where
                ``INVALID_TOKEN_ID`` marks padding.
            vocab_size: Size of the last dimension of the result.
            out_device: Device on which the result is allocated.

        Returns:
            A float tensor of shape ``[batch_size, num_tokens, vocab_size]``
            with 1.0 at each valid token's index; rows for padding
            positions are all zeros.
        """
        batch_size, num_tokens = token_ids.shape

        token_probs = torch.zeros(batch_size,
                                  num_tokens,
                                  vocab_size,
                                  dtype=torch.float,
                                  device=out_device)

        # Ignore INVALID_TOKEN_ID: padding is redirected to index 0 so the
        # scatter index is in range, and the scattered value there is 0.
        valid_mask = (token_ids != INVALID_TOKEN_ID)
        valid_indices = token_ids.clone()
        valid_indices[~valid_mask] = 0

        token_probs.scatter_(dim=2,
                             index=valid_indices.unsqueeze(-1),
                             src=valid_mask.unsqueeze(-1).float())

        return token_probs
vllm/v1/sample/sampler.py
View file @
80f63a39
...
@@ -9,6 +9,7 @@ from vllm.v1.sample.metadata import SamplingMetadata
...
@@ -9,6 +9,7 @@ from vllm.v1.sample.metadata import SamplingMetadata
from
vllm.v1.sample.ops.penalties
import
(
apply_all_penalties
,
from
vllm.v1.sample.ops.penalties
import
(
apply_all_penalties
,
apply_min_token_penalties
)
apply_min_token_penalties
)
from
vllm.v1.sample.ops.topk_topp_sampler
import
TopKTopPSampler
from
vllm.v1.sample.ops.topk_topp_sampler
import
TopKTopPSampler
from
vllm.v1.sample.rejection_sampler
import
RejectionSampler
_SAMPLING_EPS
=
1e-5
_SAMPLING_EPS
=
1e-5
...
@@ -18,12 +19,21 @@ class Sampler(nn.Module):
...
@@ -18,12 +19,21 @@ class Sampler(nn.Module):
def
__init__
(
self
):
def
__init__
(
self
):
super
().
__init__
()
super
().
__init__
()
self
.
topk_topp_sampler
=
TopKTopPSampler
()
self
.
topk_topp_sampler
=
TopKTopPSampler
()
self
.
rejection_sampler
=
RejectionSampler
()
def
forward
(
def
forward
(
self
,
self
,
logits
:
torch
.
Tensor
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
)
->
SamplerOutput
:
)
->
SamplerOutput
:
if
sampling_metadata
.
rejection_sampling
:
if
sampling_metadata
.
max_num_logprobs
:
raise
NotImplementedError
(
"Rejection sampling does not support logprobs."
)
return
self
.
rejection_sampler
(
logits
,
sampling_metadata
,
)
# NOTE(woosuk): Use the original logits (before any penalties or
# NOTE(woosuk): Use the original logits (before any penalties or
# temperature scaling) for the top-k logprobs.
# temperature scaling) for the top-k logprobs.
...
@@ -54,7 +64,10 @@ class Sampler(nn.Module):
...
@@ -54,7 +64,10 @@ class Sampler(nn.Module):
# These are GPU tensors.
# These are GPU tensors.
sampler_output
=
SamplerOutput
(
sampler_output
=
SamplerOutput
(
sampled_token_ids
=
sampled
,
# The sampled tokens are expanded to 2D tensor with shape
# [num_requests, 1], where each row represents one generated
# token per request.
sampled_token_ids
=
sampled
.
unsqueeze
(
-
1
),
logprobs_tensors
=
logprobs_tensors
,
logprobs_tensors
=
logprobs_tensors
,
)
)
return
sampler_output
return
sampler_output
...
...
vllm/v1/spec_decode/ngram_proposer.py
0 → 100644
View file @
80f63a39
# SPDX-License-Identifier: Apache-2.0
from
typing
import
List
,
Optional
from
vllm.v1.utils
import
ConstantList
class NgramProposer:
    """Proposes draft tokens by n-gram lookup in a request's own context."""

    def __init__(self):
        pass

    def propose(self, context_token_ids: ConstantList[int], n: int,
                k: int) -> Optional[List[int]]:
        """Propose the tokens that followed an earlier copy of the last
        n tokens of the context.

        The last n tokens form the pattern. The part of the context that
        precedes them is searched for the first occurrence of that pattern,
        and the tokens right after the occurrence are returned.

        Args:
            context_token_ids: Token IDs of the context sequence.
            n: Length of the n-gram to match. Must be positive.
            k: Maximum number of tokens to return after the match. If
                fewer than k tokens follow the match, all tokens up to
                the end of the context are returned.

        Returns:
            List[int]: The tokens that followed the matched n-gram.
            None: If the n-gram does not occur earlier in the context.

        Example:
            With context_token_ids = [1,2,3,4,2,3], n = 2 and k = 4:
            - The pattern is the last 2 tokens, [2,3]; it is searched
              for in the preceding 4 tokens, [1,2,3,4].
            - [2,3] is found at positions 1-2, so the tokens after it
              are returned: [4,2,3] (only three tokens are available).
        """
        # TODO: Use c++ to implement the _find_subarray_kmp to
        # improve the efficiency
        return self._find_subarray_kmp(context_token_ids, n, k)

    @staticmethod
    def _kmp_lps_array(pattern: List[int]) -> List[int]:
        """
        Build the KMP failure table for the pattern: entry i is the
        length of the longest proper prefix of pattern[:i + 1] that is
        also a suffix of it.
        """
        size = len(pattern)
        table = [0] * size
        # Length of the prefix currently matched against the suffix.
        matched = 0
        for pos in range(1, size):
            # Shrink the candidate prefix until it can be extended (or is
            # empty).
            while matched > 0 and pattern[pos] != pattern[matched]:
                matched = table[matched - 1]
            if pattern[pos] == pattern[matched]:
                matched += 1
            table[pos] = matched
        return table

    @staticmethod
    def _find_subarray_kmp(context_token_ids: ConstantList[int], n: int,
                           k: int) -> Optional[List[int]]:
        """
        Find the first occurrence of the last n context tokens within the
        context that precedes them, using KMP, and return up to k tokens
        that follow the occurrence (None when there is no occurrence).
        """
        assert n > 0

        suffix = context_token_ids[-n:]
        fallback = NgramProposer._kmp_lps_array(suffix)

        # The last n tokens are the pattern itself, so only the tokens
        # before them are scanned.
        search_end = len(context_token_ids) - n
        matched = 0
        for pos in range(search_end):
            token = context_token_ids[pos]
            # On a mismatch, reuse the failure table instead of
            # re-checking tokens already known to match.
            while matched > 0 and token != suffix[matched]:
                matched = fallback[matched - 1]
            if token == suffix[matched]:
                matched += 1
                if matched == n:
                    # Whole pattern matched; gather the next k tokens.
                    start = pos + 1
                    return context_token_ids[start:start + k]

        # Pattern not found.
        return None
vllm/v1/worker/gpu_input_batch.py
View file @
80f63a39
...
@@ -390,6 +390,7 @@ class InputBatch:
...
@@ -390,6 +390,7 @@ class InputBatch:
def
make_sampling_metadata
(
def
make_sampling_metadata
(
self
,
self
,
req_id_output_token_ids
:
Dict
[
str
,
List
[
int
]],
req_id_output_token_ids
:
Dict
[
str
,
List
[
int
]],
req_id_to_spec_token_ids
:
Dict
[
str
,
List
[
int
]],
skip_copy
:
bool
=
False
,
skip_copy
:
bool
=
False
,
)
->
SamplingMetadata
:
)
->
SamplingMetadata
:
if
not
skip_copy
:
if
not
skip_copy
:
...
@@ -423,7 +424,8 @@ class InputBatch:
...
@@ -423,7 +424,8 @@ class InputBatch:
self
.
prompt_token_ids
=
self
.
_make_prompt_token_ids_tensor
()
self
.
prompt_token_ids
=
self
.
_make_prompt_token_ids_tensor
()
output_token_ids
:
List
[
List
[
int
]]
=
[]
output_token_ids
:
List
[
List
[
int
]]
=
[]
spec_token_ids
:
List
[
List
[
int
]]
=
[]
rejection_sampling
=
False
for
req_id
in
self
.
req_ids
[:
self
.
num_reqs
]:
for
req_id
in
self
.
req_ids
[:
self
.
num_reqs
]:
assert
req_id
is
not
None
assert
req_id
is
not
None
# Currently we create a tensor for output_token_ids from scratch
# Currently we create a tensor for output_token_ids from scratch
...
@@ -434,11 +436,18 @@ class InputBatch:
...
@@ -434,11 +436,18 @@ class InputBatch:
# TODO - Replace this with incremental update to output token
# TODO - Replace this with incremental update to output token
# statistics.
# statistics.
output_token_ids
.
append
(
req_id_output_token_ids
[
req_id
])
output_token_ids
.
append
(
req_id_output_token_ids
[
req_id
])
req_spec_token_ids
=
req_id_to_spec_token_ids
.
get
(
req_id
,
[])
spec_token_ids
.
append
(
req_spec_token_ids
)
if
req_spec_token_ids
:
# If any of the requests require speculative decoding, set the
# flag to True.
rejection_sampling
=
True
return
SamplingMetadata
(
return
SamplingMetadata
(
temperature
=
self
.
temperature
[:
self
.
num_reqs
],
temperature
=
self
.
temperature
[:
self
.
num_reqs
],
all_greedy
=
self
.
all_greedy
,
all_greedy
=
self
.
all_greedy
,
all_random
=
self
.
all_random
,
all_random
=
self
.
all_random
,
rejection_sampling
=
rejection_sampling
,
top_p
=
self
.
top_p
[:
self
.
num_reqs
],
top_p
=
self
.
top_p
[:
self
.
num_reqs
],
top_k
=
self
.
top_k
[:
self
.
num_reqs
],
top_k
=
self
.
top_k
[:
self
.
num_reqs
],
min_p
=
self
.
min_p
[:
self
.
num_reqs
],
min_p
=
self
.
min_p
[:
self
.
num_reqs
],
...
@@ -452,6 +461,7 @@ class InputBatch:
...
@@ -452,6 +461,7 @@ class InputBatch:
presence_penalties
=
self
.
presence_penalties
[:
self
.
num_reqs
],
presence_penalties
=
self
.
presence_penalties
[:
self
.
num_reqs
],
repetition_penalties
=
self
.
repetition_penalties
[:
self
.
num_reqs
],
repetition_penalties
=
self
.
repetition_penalties
[:
self
.
num_reqs
],
output_token_ids
=
output_token_ids
,
output_token_ids
=
output_token_ids
,
spec_token_ids
=
spec_token_ids
,
min_tokens
=
self
.
min_tokens
[:
self
.
num_reqs
],
min_tokens
=
self
.
min_tokens
[:
self
.
num_reqs
],
stop_token_ids
=
self
.
stop_token_ids
[:
self
.
num_reqs
],
stop_token_ids
=
self
.
stop_token_ids
[:
self
.
num_reqs
],
no_penalties
=
self
.
no_penalties
,
no_penalties
=
self
.
no_penalties
,
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
80f63a39
...
@@ -32,6 +32,7 @@ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
...
@@ -32,6 +32,7 @@ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec
)
KVCacheSpec
)
from
vllm.v1.outputs
import
LogprobsTensors
,
ModelRunnerOutput
from
vllm.v1.outputs
import
LogprobsTensors
,
ModelRunnerOutput
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.rejection_sampler
import
INVALID_TOKEN_ID
from
vllm.v1.utils
import
bind_kv_cache
from
vllm.v1.utils
import
bind_kv_cache
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
from
vllm.v1.worker.lora_model_runner_mixin
import
LoRAModelRunnerMixin
from
vllm.v1.worker.lora_model_runner_mixin
import
LoRAModelRunnerMixin
...
@@ -180,6 +181,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -180,6 +181,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
max_model_len
,
self
.
max_model_len
,
self
.
max_num_tokens
),
self
.
max_num_tokens
),
dtype
=
np
.
int32
)
dtype
=
np
.
int32
)
self
.
arange_cpu
=
torch
.
from_numpy
(
self
.
arange_np
)
# NOTE(woosuk): These tensors are "stateless", i.e., they are literally
# NOTE(woosuk): These tensors are "stateless", i.e., they are literally
# a faster version of creating a new tensor every time. Thus, we should
# a faster version of creating a new tensor every time. Thus, we should
# not make any assumptions about the values in these tensors.
# not make any assumptions about the values in these tensors.
...
@@ -368,7 +370,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -368,7 +370,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
return
batch_changed
return
batch_changed
def
_prepare_inputs
(
self
,
scheduler_output
:
"SchedulerOutput"
):
def
_prepare_inputs
(
self
,
scheduler_output
:
"SchedulerOutput"
)
->
Tuple
[
FlashAttentionMetadata
,
torch
.
Tensor
]:
total_num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
total_num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
assert
total_num_scheduled_tokens
>
0
assert
total_num_scheduled_tokens
>
0
num_reqs
=
self
.
input_batch
.
num_reqs
num_reqs
=
self
.
input_batch
.
num_reqs
...
@@ -382,12 +386,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -382,12 +386,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# TODO: The Python loop can be slow. Optimize.
# TODO: The Python loop can be slow. Optimize.
num_scheduled_tokens_list
:
List
[
int
]
=
[]
num_scheduled_tokens_list
:
List
[
int
]
=
[]
max_num_scheduled_tokens
=
0
max_num_scheduled_tokens
=
0
for
req_id
in
self
.
input_batch
.
req_ids
[:
num_reqs
]:
all_spec_token_ids
:
List
[
int
]
=
[]
num_spec_tokens_list
:
List
[
int
]
=
[]
for
i
,
req_id
in
zip
(
range
(
num_reqs
),
self
.
input_batch
.
req_ids
):
assert
req_id
is
not
None
assert
req_id
is
not
None
num_tokens
=
scheduler_output
.
num_scheduled_tokens
[
req_id
]
num_tokens
=
scheduler_output
.
num_scheduled_tokens
[
req_id
]
num_scheduled_tokens_list
.
append
(
num_tokens
)
num_scheduled_tokens_list
.
append
(
num_tokens
)
max_num_scheduled_tokens
=
max
(
max_num_scheduled_tokens
,
max_num_scheduled_tokens
=
max
(
max_num_scheduled_tokens
,
num_tokens
)
num_tokens
)
spec_token_ids
=
scheduler_output
.
scheduled_spec_decode_tokens
.
get
(
req_id
,
[])
all_spec_token_ids
.
extend
(
spec_token_ids
)
num_spec_tokens_list
.
append
(
len
(
spec_token_ids
))
num_scheduled_tokens
:
np
.
ndarray
=
np
.
array
(
num_scheduled_tokens_list
,
num_scheduled_tokens
:
np
.
ndarray
=
np
.
array
(
num_scheduled_tokens_list
,
dtype
=
np
.
int32
)
dtype
=
np
.
int32
)
assert
max_num_scheduled_tokens
>
0
assert
max_num_scheduled_tokens
>
0
...
@@ -426,6 +437,79 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -426,6 +437,79 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# where M is the max_model_len.
# where M is the max_model_len.
token_indices
=
(
positions_np
+
token_indices
=
(
positions_np
+
req_indices
*
self
.
input_batch
.
token_ids_cpu
.
shape
[
1
])
req_indices
*
self
.
input_batch
.
token_ids_cpu
.
shape
[
1
])
use_spec_decode
=
len
(
all_spec_token_ids
)
>
0
if
use_spec_decode
:
# 1. Write spec_token_ids to input batch.
# Step 1. Get req indices that perform spec decode and repeat
# the req indices by the number of spec tokens. Note
# for requests that don't perform spec decode, the
# number of spec tokens is 0 and the req index is
# repeated 0 times.
# E.g., num_spec_tokens_list: [3, 0, 2, 0, 1]
# spec_req_indices: [0, 0, 0, 2, 2, 4]
spec_req_indices
=
np
.
repeat
(
self
.
arange_np
[:
num_reqs
],
num_spec_tokens_list
)
# spec_offsets: offsets within each spec token list.
# E.g., [1, 2, 3, 1, 2, 1], TODO: avoid the for loop here
spec_offsets
=
np
.
concatenate
(
[
self
.
arange_np
[
1
:
val
+
1
]
for
val
in
num_spec_tokens_list
])
# spec_seq_offsets: offsets within each sequence.
# E.g., num_computed_tokens_cpu: [1, 4, 3, 6, 2]
# after repeating: [1, 1, 1, 3, 3, 2]
# spec_seq_offsets: [1, 1, 1, 3, 3, 2] + [1, 2, 3, 1, 2, 1]
# = [2, 3, 4, 4, 5, 3]
spec_seq_offsets
=
np
.
repeat
(
self
.
input_batch
.
num_computed_tokens_cpu
[:
num_reqs
],
num_spec_tokens_list
)
+
spec_offsets
# cumsums_spec_offsets: [0, 0, 0, 2M, 2M, 4M] + [2, 3, 4, 4, 5, 3]
cumsums_spec_offsets
=
(
spec_seq_offsets
+
spec_req_indices
*
self
.
input_batch
.
token_ids_cpu
.
shape
[
1
])
cumsums_spec_offsets
=
torch
.
from_numpy
(
cumsums_spec_offsets
).
to
(
torch
.
int64
)
all_spec_token_ids
=
torch
.
tensor
(
all_spec_token_ids
,
device
=
"cpu"
,
dtype
=
self
.
input_ids_cpu
.
dtype
)
# Step 2. Write spec token ids to input_ids_cpu.
self
.
input_batch
.
token_ids_cpu_tensor
.
flatten
().
scatter_
(
0
,
cumsums_spec_offsets
,
all_spec_token_ids
)
# 2. Get spec decode logits indices.
# E.g., num_scheduled_tokens: [4, 100, 3, 100, 2]
# cu_num_tokens: [4, 104, 107, 207, 209]
# num_spec_tokens_list: [3, 0, 2, 0, 1]
# num_sampled_tokens: [4, 1, 3, 1, 2]
# spec_decode_logits_indices:
# [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
num_spec_tokens_np
=
np
.
array
(
num_spec_tokens_list
,
dtype
=
np
.
int32
)
num_sampled_tokens
=
num_spec_tokens_np
+
1
# logits_start_loc: [0, 103, 104, 206, 207]
logits_start_loc
=
cu_num_tokens
-
num_sampled_tokens
# [0, 103, 104, 206, 207] ->
# [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207]
logits_start_loc
=
np
.
repeat
(
logits_start_loc
,
num_sampled_tokens
)
# The following three lines:
# [4, 1, 3, 1, 2] -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
# Step 1. [4, 1, 3, 1, 2] -> [4, 5, 8, 9, 11]
cu_num_sampled_tokens
=
np
.
cumsum
(
num_sampled_tokens
)
# Step 2. [4, 5, 8, 9, 11] -> [0, 4, 5, 8, 9]
# -> [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
cumsums_sampled_offsets
=
np
.
repeat
(
cu_num_sampled_tokens
-
num_sampled_tokens
,
num_sampled_tokens
)
# Step 3. [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
# - [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
# -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
total_num_sampled_tokens
=
num_sampled_tokens
.
sum
()
sampled_arange
=
(
self
.
arange_np
[:
total_num_sampled_tokens
]
-
cumsums_sampled_offsets
)
# [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] ->
# [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
spec_decode_logits_indices
=
logits_start_loc
+
sampled_arange
# NOTE(woosuk): We use torch.index_select instead of np.take here
# NOTE(woosuk): We use torch.index_select instead of np.take here
# because torch.index_select is much faster than np.take for large
# because torch.index_select is much faster than np.take for large
# tensors.
# tensors.
...
@@ -519,16 +603,21 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -519,16 +603,21 @@ class GPUModelRunner(LoRAModelRunnerMixin):
suffix_kv_lens
=
suffix_kv_lens
,
suffix_kv_lens
=
suffix_kv_lens
,
)
)
if
use_spec_decode
:
logits_indices
=
torch
.
from_numpy
(
spec_decode_logits_indices
).
to
(
self
.
device
,
non_blocking
=
True
)
else
:
# NOTE(woosuk): Due to chunked prefills, the batch may contain
# partial requests. While we should not sample any token
# from these partial requests, we do so for simplicity.
# We will ignore the sampled tokens from the partial requests.
# TODO: Support prompt logprobs.
logits_indices
=
query_start_loc
[
1
:]
-
1
# Hot-Swap lora model
# Hot-Swap lora model
if
self
.
lora_config
:
if
self
.
lora_config
:
self
.
set_active_loras
(
self
.
input_batch
,
num_scheduled_tokens
)
self
.
set_active_loras
(
self
.
input_batch
,
num_scheduled_tokens
)
# NOTE(woosuk): Due to chunked prefills, the batch may contain partial
# requests. While we should not sample any token from these partial
# requests, we do so for simplicity. We will ignore the sampled
# tokens from the partial requests.
# TODO: Support prompt logprobs.
logits_indices
=
query_start_loc
[
1
:]
-
1
return
attn_metadata
,
logits_indices
return
attn_metadata
,
logits_indices
def
_compute_cascade_attn_prefix_len
(
def
_compute_cascade_attn_prefix_len
(
...
@@ -673,6 +762,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -673,6 +762,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def
_prepare_sampling
(
def
_prepare_sampling
(
self
,
self
,
batch_changed
:
bool
,
batch_changed
:
bool
,
req_to_spec_token_ids
:
Dict
[
str
,
List
[
int
]],
)
->
SamplingMetadata
:
)
->
SamplingMetadata
:
# Create the sampling metadata.
# Create the sampling metadata.
req_id_output_token_ids
:
Dict
[
str
,
List
[
int
]]
=
\
req_id_output_token_ids
:
Dict
[
str
,
List
[
int
]]
=
\
...
@@ -680,7 +770,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -680,7 +770,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
for
req_id
,
req
in
self
.
requests
.
items
()}
for
req_id
,
req
in
self
.
requests
.
items
()}
sampling_metadata
=
self
.
input_batch
.
make_sampling_metadata
(
sampling_metadata
=
self
.
input_batch
.
make_sampling_metadata
(
req_id_output_token_ids
,
skip_copy
=
not
batch_changed
)
req_id_output_token_ids
,
req_to_spec_token_ids
,
not
batch_changed
)
return
sampling_metadata
return
sampling_metadata
def
_execute_encoder
(
self
,
scheduler_output
:
"SchedulerOutput"
):
def
_execute_encoder
(
self
,
scheduler_output
:
"SchedulerOutput"
):
...
@@ -847,7 +937,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -847,7 +937,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
,
None
)
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
,
None
)
# Sample the next token and get logprobs if needed.
# Sample the next token and get logprobs if needed.
sampling_metadata
=
self
.
_prepare_sampling
(
batch_changed
)
sampling_metadata
=
self
.
_prepare_sampling
(
batch_changed
,
scheduler_output
.
scheduled_spec_decode_tokens
)
sampler_output
=
self
.
model
.
sample
(
sampler_output
=
self
.
model
.
sample
(
logits
=
logits
,
logits
=
logits
,
sampling_metadata
=
sampling_metadata
,
sampling_metadata
=
sampling_metadata
,
...
@@ -857,18 +948,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -857,18 +948,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# the requests one by one. Optimize.
# the requests one by one. Optimize.
num_reqs
=
self
.
input_batch
.
num_reqs
num_reqs
=
self
.
input_batch
.
num_reqs
request_seq_lens
:
List
[
Tuple
[
int
,
CachedRequestState
,
int
]]
=
[]
request_seq_lens
:
List
[
Tuple
[
int
,
CachedRequestState
,
int
]]
=
[]
for
i
,
req_id
in
enumerate
(
# type: ignore[assignment]
for
i
,
req_id
in
zip
(
range
(
num_reqs
),
self
.
input_batch
.
req_ids
):
self
.
input_batch
.
req_ids
[:
num_reqs
]):
assert
req_id
is
not
None
assert
req_id
is
not
None
req_state
=
self
.
requests
[
req_id
]
req_state
=
self
.
requests
[
req_id
]
seq_len
=
(
req_state
.
num_computed_tokens
+
seq_len
=
(
req_state
.
num_computed_tokens
+
scheduler_output
.
num_scheduled_tokens
[
req_id
])
scheduler_output
.
num_scheduled_tokens
[
req_id
])
assert
seq_len
<=
req_state
.
num_tokens
if
seq_len
>=
req_state
.
num_tokens
:
if
seq_len
==
req_state
.
num_tokens
:
# Append the sampled token to the output token ids.
self
.
input_batch
.
num_tokens
[
i
]
+=
1
# OPTIMIZATION: Priming the state updates for later updates.
req_state
.
output_token_ids
.
append
(
0
)
request_seq_lens
.
append
((
i
,
req_state
,
seq_len
))
request_seq_lens
.
append
((
i
,
req_state
,
seq_len
))
else
:
else
:
# Ignore the sampled token from the partial request.
# Ignore the sampled token from the partial request.
...
@@ -886,7 +971,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -886,7 +971,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# NOTE: GPU -> CPU Sync happens here.
# NOTE: GPU -> CPU Sync happens here.
# Move as many CPU operations as possible before this sync point.
# Move as many CPU operations as possible before this sync point.
sampled_token_ids
=
sampler_output
.
sampled_token_ids
.
tolist
()
logprobs_tensors
=
sampler_output
.
logprobs_tensors
logprobs_tensors
=
sampler_output
.
logprobs_tensors
logprobs_lists
=
logprobs_tensors
.
tolists
()
\
logprobs_lists
=
logprobs_tensors
.
tolists
()
\
if
logprobs_tensors
is
not
None
else
None
if
logprobs_tensors
is
not
None
else
None
...
@@ -897,16 +981,34 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -897,16 +981,34 @@ class GPUModelRunner(LoRAModelRunnerMixin):
scheduler_output
,
scheduler_output
,
)
)
# Update with the actual token ids
# Update batch with the valid generated tokens.
sampled_token_ids
=
sampler_output
.
sampled_token_ids
max_gen_len
=
sampled_token_ids
.
shape
[
-
1
]
if
max_gen_len
==
1
:
valid_sampled_token_ids
=
sampled_token_ids
.
tolist
()
for
i
,
req_state
,
seq_len
in
request_seq_lens
:
for
i
,
req_state
,
seq_len
in
request_seq_lens
:
token_id
=
sampled_token_ids
[
i
]
token_id
=
valid_
sampled_token_ids
[
i
]
[
0
]
self
.
input_batch
.
token_ids_cpu
[
i
,
seq_len
]
=
token_id
self
.
input_batch
.
token_ids_cpu
[
i
,
seq_len
]
=
token_id
req_state
.
output_token_ids
[
-
1
]
=
token_id
req_state
.
output_token_ids
.
append
(
token_id
)
self
.
input_batch
.
num_tokens
[
i
]
+=
1
else
:
valid_mask
=
sampled_token_ids
!=
INVALID_TOKEN_ID
gen_lens
=
valid_mask
.
sum
(
dim
=
1
).
tolist
()
valid_sampled_token_ids
=
[
seq
.
tolist
()
for
seq
in
sampled_token_ids
[
valid_mask
].
split
(
gen_lens
)
]
self
.
input_batch
.
num_tokens
[:
num_reqs
]
+=
gen_lens
for
i
,
req_state
,
seq_len
in
request_seq_lens
:
target_slice
=
slice
(
seq_len
-
gen_lens
[
i
]
+
1
,
seq_len
+
1
)
self
.
input_batch
.
token_ids_cpu
[
i
,
target_slice
]
=
valid_sampled_token_ids
[
i
]
req_state
.
output_token_ids
.
extend
(
valid_sampled_token_ids
[
i
])
model_runner_output
=
ModelRunnerOutput
(
model_runner_output
=
ModelRunnerOutput
(
req_ids
=
req_ids
,
req_ids
=
req_ids
,
req_id_to_index
=
self
.
input_batch
.
req_id_to_index
,
req_id_to_index
=
self
.
input_batch
.
req_id_to_index
,
sampled_token_ids
=
sampled_token_ids
,
sampled_token_ids
=
valid_
sampled_token_ids
,
logprobs
=
logprobs_lists
,
logprobs
=
logprobs_lists
,
prompt_logprobs_dict
=
prompt_logprobs_dict
,
prompt_logprobs_dict
=
prompt_logprobs_dict
,
)
)
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment