Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b004c004
Unverified
Commit
b004c004
authored
Nov 23, 2025
by
Woosuk Kwon
Committed by
GitHub
Nov 23, 2025
Browse files
[Model Runner V2] Support spec decoding [1/N] (#29274)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
7f12c82f
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
350 additions
and
29 deletions
+350
-29
vllm/v1/worker/gpu/input_batch.py
vllm/v1/worker/gpu/input_batch.py
+112
-16
vllm/v1/worker/gpu/model_runner.py
vllm/v1/worker/gpu/model_runner.py
+73
-13
vllm/v1/worker/gpu/spec_decode/__init__.py
vllm/v1/worker/gpu/spec_decode/__init__.py
+0
-0
vllm/v1/worker/gpu/spec_decode/rejection_sample.py
vllm/v1/worker/gpu/spec_decode/rejection_sample.py
+71
-0
vllm/v1/worker/gpu/states.py
vllm/v1/worker/gpu/states.py
+94
-0
No files found.
vllm/v1/worker/gpu/input_batch.py
View file @
b004c004
...
@@ -35,6 +35,7 @@ class InputBuffers:
...
@@ -35,6 +35,7 @@ class InputBuffers:
self
.
positions
=
torch
.
zeros
(
max_num_tokens
,
dtype
=
torch
.
int64
,
device
=
device
)
self
.
positions
=
torch
.
zeros
(
max_num_tokens
,
dtype
=
torch
.
int64
,
device
=
device
)
self
.
query_start_loc
=
self
.
_make_buffer
(
max_num_reqs
+
1
,
dtype
=
torch
.
int32
)
self
.
query_start_loc
=
self
.
_make_buffer
(
max_num_reqs
+
1
,
dtype
=
torch
.
int32
)
self
.
seq_lens
=
torch
.
zeros
(
max_num_reqs
,
dtype
=
torch
.
int32
,
device
=
device
)
self
.
seq_lens
=
torch
.
zeros
(
max_num_reqs
,
dtype
=
torch
.
int32
,
device
=
device
)
self
.
cu_num_logits
=
self
.
_make_buffer
(
max_num_reqs
+
1
,
dtype
=
torch
.
int32
)
# Structured outputs.
# Structured outputs.
self
.
bitmask_indices
=
self
.
_make_buffer
(
max_num_reqs
,
dtype
=
torch
.
int32
)
self
.
bitmask_indices
=
self
.
_make_buffer
(
max_num_reqs
,
dtype
=
torch
.
int32
)
...
@@ -64,6 +65,7 @@ class InputBatch:
...
@@ -64,6 +65,7 @@ class InputBatch:
# sum(num_scheduled_tokens)
# sum(num_scheduled_tokens)
num_tokens
:
int
num_tokens
:
int
num_tokens_after_padding
:
int
num_tokens_after_padding
:
int
num_draft_tokens
:
int
# [num_reqs + 1]
# [num_reqs + 1]
query_start_loc
:
torch
.
Tensor
query_start_loc
:
torch
.
Tensor
...
@@ -80,8 +82,10 @@ class InputBatch:
...
@@ -80,8 +82,10 @@ class InputBatch:
# layer_name -> Metadata
# layer_name -> Metadata
attn_metadata
:
dict
[
str
,
Any
]
attn_metadata
:
dict
[
str
,
Any
]
# [
num_req
s]
# [
total_num_logit
s]
logits_indices
:
torch
.
Tensor
logits_indices
:
torch
.
Tensor
# [num_reqs + 1]
cu_num_logits
:
torch
.
Tensor
@
classmethod
@
classmethod
def
make_dummy
(
def
make_dummy
(
...
@@ -118,6 +122,7 @@ class InputBatch:
...
@@ -118,6 +122,7 @@ class InputBatch:
positions
=
input_buffers
.
positions
[:
num_tokens
]
positions
=
input_buffers
.
positions
[:
num_tokens
]
# attn_metadata = defaultdict(lambda: None)
# attn_metadata = defaultdict(lambda: None)
logits_indices
=
query_start_loc
[
1
:]
-
1
logits_indices
=
query_start_loc
[
1
:]
-
1
cu_num_logits
=
torch
.
arange
(
num_reqs
+
1
,
device
=
device
,
dtype
=
torch
.
int32
)
return
cls
(
return
cls
(
req_ids
=
req_ids
,
req_ids
=
req_ids
,
num_reqs
=
num_reqs
,
num_reqs
=
num_reqs
,
...
@@ -126,6 +131,7 @@ class InputBatch:
...
@@ -126,6 +131,7 @@ class InputBatch:
num_scheduled_tokens
=
num_scheduled_tokens
,
num_scheduled_tokens
=
num_scheduled_tokens
,
num_tokens
=
num_tokens
,
num_tokens
=
num_tokens
,
num_tokens_after_padding
=
num_tokens
,
num_tokens_after_padding
=
num_tokens
,
num_draft_tokens
=
0
,
query_start_loc
=
query_start_loc
,
query_start_loc
=
query_start_loc
,
query_start_loc_np
=
query_start_loc_np
,
query_start_loc_np
=
query_start_loc_np
,
seq_lens
=
seq_lens
,
seq_lens
=
seq_lens
,
...
@@ -134,6 +140,7 @@ class InputBatch:
...
@@ -134,6 +140,7 @@ class InputBatch:
positions
=
positions
,
positions
=
positions
,
attn_metadata
=
None
,
# type: ignore
attn_metadata
=
None
,
# type: ignore
logits_indices
=
logits_indices
,
logits_indices
=
logits_indices
,
cu_num_logits
=
cu_num_logits
,
)
)
...
@@ -279,19 +286,53 @@ def _combine_sampled_and_draft_tokens_kernel(
...
@@ -279,19 +286,53 @@ def _combine_sampled_and_draft_tokens_kernel(
query_start_loc_ptr
,
query_start_loc_ptr
,
seq_lens_ptr
,
seq_lens_ptr
,
prefill_len_ptr
,
prefill_len_ptr
,
draft_tokens_ptr
,
draft_tokens_stride
,
cu_num_logits_ptr
,
logits_indices_ptr
,
BLOCK_SIZE
:
tl
.
constexpr
,
):
):
batch_idx
=
tl
.
program_id
(
0
)
batch_idx
=
tl
.
program_id
(
0
)
req_state_idx
=
tl
.
load
(
idx_mapping_ptr
+
batch_idx
)
req_state_idx
=
tl
.
load
(
idx_mapping_ptr
+
batch_idx
)
# Get the number of logits and draft tokens.
cu_num_logits_start
=
tl
.
load
(
cu_num_logits_ptr
+
batch_idx
)
cu_num_logits_end
=
tl
.
load
(
cu_num_logits_ptr
+
batch_idx
+
1
)
num_logits
=
cu_num_logits_end
-
cu_num_logits_start
num_draft_tokens
=
num_logits
-
1
# Compute the logits indices.
block
=
tl
.
arange
(
0
,
BLOCK_SIZE
)
query_end
=
tl
.
load
(
query_start_loc_ptr
+
batch_idx
+
1
)
logits_start
=
query_end
-
num_logits
tl
.
store
(
logits_indices_ptr
+
cu_num_logits_start
+
block
,
logits_start
+
block
,
mask
=
block
<
num_logits
,
)
seq_len
=
tl
.
load
(
seq_lens_ptr
+
batch_idx
)
seq_len
=
tl
.
load
(
seq_lens_ptr
+
batch_idx
)
prefill_len
=
tl
.
load
(
prefill_len_ptr
+
req_state_idx
)
prefill_len
=
tl
.
load
(
prefill_len_ptr
+
req_state_idx
)
if
seq_len
<=
prefill_len
:
if
seq_len
<=
prefill_len
:
# Handling prefill tokens.
# Handling prefill tokens.
No sampled or draft tokens.
return
return
# Write the last sampled token ID to input_ids.
last_token_id
=
tl
.
load
(
last_sampled_tokens_ptr
+
req_state_idx
)
last_token_id
=
tl
.
load
(
last_sampled_tokens_ptr
+
req_state_idx
)
end
=
tl
.
load
(
query_start_loc_ptr
+
batch_idx
+
1
)
tl
.
store
(
input_ids_ptr
+
query_end
-
num_logits
,
last_token_id
)
tl
.
store
(
input_ids_ptr
+
end
-
1
,
last_token_id
)
# Write the draft tokens (if any) to input_ids.
if
num_draft_tokens
>
0
:
mask
=
block
<
num_draft_tokens
draft_tokens
=
tl
.
load
(
draft_tokens_ptr
+
req_state_idx
*
draft_tokens_stride
+
block
,
mask
=
mask
,
)
tl
.
store
(
input_ids_ptr
+
query_end
-
num_draft_tokens
+
block
,
draft_tokens
,
mask
=
mask
,
)
def
combine_sampled_and_draft_tokens
(
def
combine_sampled_and_draft_tokens
(
...
@@ -301,8 +342,18 @@ def combine_sampled_and_draft_tokens(
...
@@ -301,8 +342,18 @@ def combine_sampled_and_draft_tokens(
query_start_loc
:
torch
.
Tensor
,
query_start_loc
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
prefill_len
:
torch
.
Tensor
,
prefill_len
:
torch
.
Tensor
,
draft_tokens
:
torch
.
Tensor
,
cu_num_logits
:
torch
.
Tensor
,
num_logits
:
int
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
num_reqs
=
seq_lens
.
shape
[
0
]
num_reqs
=
seq_lens
.
shape
[
0
]
num_speculative_steps
=
draft_tokens
.
shape
[
-
1
]
logits_indices
=
torch
.
empty
(
num_logits
,
dtype
=
torch
.
int64
,
device
=
input_ids
.
device
,
)
_combine_sampled_and_draft_tokens_kernel
[(
num_reqs
,)](
_combine_sampled_and_draft_tokens_kernel
[(
num_reqs
,)](
input_ids
,
input_ids
,
idx_mapping
,
idx_mapping
,
...
@@ -310,35 +361,80 @@ def combine_sampled_and_draft_tokens(
...
@@ -310,35 +361,80 @@ def combine_sampled_and_draft_tokens(
query_start_loc
,
query_start_loc
,
seq_lens
,
seq_lens
,
prefill_len
,
prefill_len
,
draft_tokens
,
draft_tokens
.
stride
(
0
),
cu_num_logits
,
logits_indices
,
# NOTE(woosuk): Add 1 to ensure the block can cover the last sampled token
# in addition to all draft tokens.
BLOCK_SIZE
=
triton
.
next_power_of_2
(
num_speculative_steps
+
1
),
)
)
return
input_id
s
return
logits_indice
s
@
triton
.
jit
@
triton
.
jit
def
_update
_num_computed_tokens
_kernel
(
def
_
post_
update_kernel
(
idx_mapping_ptr
,
idx_mapping_ptr
,
num_computed_tokens_ptr
,
num_computed_tokens_ptr
,
last_sampled_tokens_ptr
,
sampled_tokens_ptr
,
sampled_tokens_stride
,
num_sampled_ptr
,
query_start_loc_ptr
,
query_start_loc_ptr
,
cu_num_logits_ptr
,
):
):
req_id
=
tl
.
program_id
(
0
)
req_id
=
tl
.
program_id
(
0
)
req_state_idx
=
tl
.
load
(
idx_mapping_ptr
+
req_id
)
req_state_idx
=
tl
.
load
(
idx_mapping_ptr
+
req_id
)
start
=
tl
.
load
(
query_start_loc_ptr
+
req_id
)
num_sampled
=
tl
.
load
(
num_sampled_ptr
+
req_id
)
end
=
tl
.
load
(
query_start_loc_ptr
+
req_id
+
1
)
if
num_sampled
>
0
:
query_len
=
end
-
start
token_id
=
tl
.
load
(
sampled_tokens_ptr
+
req_id
*
sampled_tokens_stride
+
num_sampled
-
1
n
=
tl
.
load
(
num_computed_tokens_ptr
+
req_state_idx
)
)
tl
.
store
(
num_computed_tokens_ptr
+
req_state_idx
,
n
+
query_len
)
tl
.
store
(
last_sampled_tokens_ptr
+
req_state_idx
,
token_id
)
query_start
=
tl
.
load
(
query_start_loc_ptr
+
req_id
)
def
update_num_computed_tokens
(
query_end
=
tl
.
load
(
query_start_loc_ptr
+
req_id
+
1
)
query_len
=
query_end
-
query_start
num_computed
=
tl
.
load
(
num_computed_tokens_ptr
+
req_state_idx
)
num_computed
+=
query_len
# Consider the rejected tokens in spec decoding.
if
num_sampled
>
0
:
# NOTE(woosuk): We must skip num_sampled == 0 to account for chunked prefills.
logits_start
=
tl
.
load
(
cu_num_logits_ptr
+
req_id
)
logits_end
=
tl
.
load
(
cu_num_logits_ptr
+
req_id
+
1
)
num_logits
=
logits_end
-
logits_start
num_rejected
=
num_logits
-
num_sampled
num_computed
-=
num_rejected
tl
.
store
(
num_computed_tokens_ptr
+
req_state_idx
,
num_computed
)
def
post_update
(
# [num_reqs]
idx_mapping
:
torch
.
Tensor
,
idx_mapping
:
torch
.
Tensor
,
# [max_num_reqs]
num_computed_tokens
:
torch
.
Tensor
,
num_computed_tokens
:
torch
.
Tensor
,
# [max_num_reqs]
last_sampled_tokens
:
torch
.
Tensor
,
# [num_reqs, num_speculative_steps + 1]
sampled_tokens
:
torch
.
Tensor
,
# [num_reqs]
num_sampled
:
torch
.
Tensor
,
# [num_reqs + 1]
query_start_loc
:
torch
.
Tensor
,
query_start_loc
:
torch
.
Tensor
,
# [num_reqs + 1]
cu_num_logits
:
torch
.
Tensor
,
)
->
None
:
)
->
None
:
num_reqs
=
idx_mapping
.
shape
[
0
]
num_reqs
=
idx_mapping
.
shape
[
0
]
_update
_num_computed_tokens
_kernel
[(
num_reqs
,)](
_
post_
update_kernel
[(
num_reqs
,)](
idx_mapping
,
idx_mapping
,
num_computed_tokens
,
num_computed_tokens
,
last_sampled_tokens
,
sampled_tokens
,
sampled_tokens
.
stride
(
0
),
num_sampled
,
query_start_loc
,
query_start_loc
,
cu_num_logits
,
num_warps
=
1
,
)
)
vllm/v1/worker/gpu/model_runner.py
View file @
b004c004
...
@@ -40,11 +40,12 @@ from vllm.v1.worker.gpu.input_batch import (
...
@@ -40,11 +40,12 @@ from vllm.v1.worker.gpu.input_batch import (
InputBatch
,
InputBatch
,
InputBuffers
,
InputBuffers
,
combine_sampled_and_draft_tokens
,
combine_sampled_and_draft_tokens
,
post_update
,
prepare_pos_seq_lens
,
prepare_pos_seq_lens
,
prepare_prefill_inputs
,
prepare_prefill_inputs
,
update_num_computed_tokens
,
)
)
from
vllm.v1.worker.gpu.sampler
import
Sampler
,
compute_prompt_logprobs
from
vllm.v1.worker.gpu.sampler
import
Sampler
,
compute_prompt_logprobs
from
vllm.v1.worker.gpu.spec_decode.rejection_sample
import
rejection_sample
from
vllm.v1.worker.gpu.states
import
RequestState
,
SamplingMetadata
from
vllm.v1.worker.gpu.states
import
RequestState
,
SamplingMetadata
from
vllm.v1.worker.gpu.structured_outputs
import
apply_grammar_bitmask
from
vllm.v1.worker.gpu.structured_outputs
import
apply_grammar_bitmask
from
vllm.v1.worker.kv_connector_model_runner_mixin
import
KVConnectorModelRunnerMixin
from
vllm.v1.worker.kv_connector_model_runner_mixin
import
KVConnectorModelRunnerMixin
...
@@ -100,10 +101,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -100,10 +101,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self
.
input_prep_event
=
None
self
.
input_prep_event
=
None
self
.
structured_outputs_event
=
None
self
.
structured_outputs_event
=
None
if
self
.
speculative_config
is
not
None
:
self
.
do_spec_decode
=
True
self
.
num_speculative_steps
=
self
.
speculative_config
.
num_speculative_tokens
else
:
self
.
do_spec_decode
=
False
self
.
num_speculative_steps
=
0
self
.
req_states
=
RequestState
(
self
.
req_states
=
RequestState
(
max_num_reqs
=
self
.
max_num_reqs
,
max_num_reqs
=
self
.
max_num_reqs
,
max_model_len
=
self
.
max_model_len
,
max_model_len
=
self
.
max_model_len
,
max_num_batched_tokens
=
self
.
max_num_tokens
,
max_num_batched_tokens
=
self
.
max_num_tokens
,
num_speculative_steps
=
self
.
num_speculative_steps
,
vocab_size
=
self
.
vocab_size
,
vocab_size
=
self
.
vocab_size
,
device
=
self
.
device
,
device
=
self
.
device
,
pin_memory
=
self
.
pin_memory
,
pin_memory
=
self
.
pin_memory
,
...
@@ -427,6 +436,32 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -427,6 +436,32 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
idx_mapping_np
=
idx_mapping
.
np
[:
num_reqs
]
idx_mapping_np
=
idx_mapping
.
np
[:
num_reqs
]
idx_mapping
=
idx_mapping
.
copy_to_gpu
(
num_reqs
)
idx_mapping
=
idx_mapping
.
copy_to_gpu
(
num_reqs
)
# Get the number of draft tokens for each request.
if
not
scheduler_output
.
scheduled_spec_decode_tokens
:
# No draft token scheduled (common case).
total_num_draft_tokens
=
0
total_num_logits
=
num_reqs
cu_num_logits
=
torch
.
arange
(
num_reqs
+
1
,
device
=
self
.
device
,
dtype
=
torch
.
int32
)
else
:
draft_tokens
=
scheduler_output
.
scheduled_spec_decode_tokens
num_draft_tokens
=
np
.
array
(
[
len
(
draft_tokens
[
req_id
])
if
req_id
in
draft_tokens
else
0
for
req_id
in
req_ids
],
dtype
=
np
.
int32
,
)
total_num_draft_tokens
=
int
(
num_draft_tokens
.
sum
())
total_num_logits
=
num_reqs
+
total_num_draft_tokens
np
.
cumsum
(
num_draft_tokens
+
1
,
out
=
self
.
input_buffers
.
cu_num_logits
.
np
[
1
:
num_reqs
+
1
],
)
cu_num_logits
=
self
.
input_buffers
.
cu_num_logits
.
copy_to_gpu
(
num_reqs
+
1
)
# Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
# Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
block_tables
=
self
.
block_tables
.
gather_block_tables
(
idx_mapping
)
block_tables
=
self
.
block_tables
.
gather_block_tables
(
idx_mapping
)
...
@@ -456,14 +491,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -456,14 +491,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
seq_lens
=
self
.
input_buffers
.
seq_lens
[:
num_reqs
]
seq_lens
=
self
.
input_buffers
.
seq_lens
[:
num_reqs
]
# Some input token ids are directly read from the last sampled tokens
# Some input token ids are directly read from the last sampled tokens
# and draft tokens.
# and draft tokens.
Also, get the logits indices to sample tokens from.
combine_sampled_and_draft_tokens
(
logits_indices
=
combine_sampled_and_draft_tokens
(
self
.
input_buffers
.
input_ids
.
gpu
,
self
.
input_buffers
.
input_ids
.
gpu
,
idx_mapping
,
idx_mapping
,
self
.
req_states
.
last_sampled_tokens
,
self
.
req_states
.
last_sampled_tokens
,
query_start_loc_gpu
,
query_start_loc_gpu
,
seq_lens
,
seq_lens
,
self
.
req_states
.
prefill_len
.
gpu
,
self
.
req_states
.
prefill_len
.
gpu
,
self
.
req_states
.
draft_tokens
,
cu_num_logits
,
total_num_logits
,
)
)
# Compute slot mappings: [num_kv_cache_groups, num_tokens]
# Compute slot mappings: [num_kv_cache_groups, num_tokens]
...
@@ -471,9 +509,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -471,9 +509,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
query_start_loc_gpu
,
self
.
input_buffers
.
positions
[:
num_tokens
]
query_start_loc_gpu
,
self
.
input_buffers
.
positions
[:
num_tokens
]
)
)
# Logits indices to sample next token from.
logits_indices
=
query_start_loc_gpu
[
1
:]
-
1
# Get num_computed_tokens.
# Get num_computed_tokens.
# HACK(woosuk): Here, we use num_computed_tokens on GPU instead of
# HACK(woosuk): Here, we use num_computed_tokens on GPU instead of
# num_computed_tokens_cpu. This works for most cases.
# num_computed_tokens_cpu. This works for most cases.
...
@@ -508,6 +543,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -508,6 +543,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_scheduled_tokens
=
num_scheduled_tokens
,
num_scheduled_tokens
=
num_scheduled_tokens
,
num_tokens
=
num_tokens
,
num_tokens
=
num_tokens
,
num_tokens_after_padding
=
num_tokens_after_padding
,
num_tokens_after_padding
=
num_tokens_after_padding
,
num_draft_tokens
=
total_num_draft_tokens
,
query_start_loc
=
query_start_loc_gpu
,
query_start_loc
=
query_start_loc_gpu
,
query_start_loc_np
=
query_start_loc_np
,
query_start_loc_np
=
query_start_loc_np
,
seq_lens
=
seq_lens
,
seq_lens
=
seq_lens
,
...
@@ -516,6 +552,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -516,6 +552,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
positions
=
positions
,
positions
=
positions
,
attn_metadata
=
attn_metadata
,
attn_metadata
=
attn_metadata
,
logits_indices
=
logits_indices
,
logits_indices
=
logits_indices
,
cu_num_logits
=
cu_num_logits
,
)
)
def
sample
(
def
sample
(
...
@@ -530,6 +567,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -530,6 +567,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if
grammar_output
is
not
None
:
if
grammar_output
is
not
None
:
# Apply grammar bitmask to the logits in-place.
# Apply grammar bitmask to the logits in-place.
# TODO(woosuk): Make compatible with spec decoding.
# TODO(woosuk): Make compatible with spec decoding.
assert
input_batch
.
num_draft_tokens
==
0
with
async_barrier
(
self
.
structured_outputs_event
):
with
async_barrier
(
self
.
structured_outputs_event
):
apply_grammar_bitmask
(
apply_grammar_bitmask
(
logits
,
logits
,
...
@@ -539,12 +577,28 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -539,12 +577,28 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self
.
input_buffers
,
self
.
input_buffers
,
)
)
# Sample tokens and compute logprobs (if needed).
sampler_output
=
self
.
sampler
(
logits
,
sampling_metadata
)
sampler_output
=
self
.
sampler
(
logits
,
sampling_metadata
)
# Get the number of sampled tokens.
# Get the number of sampled tokens.
# 0 if chunked-prefilling, 1 if not.
prefill_len
=
self
.
req_states
.
prefill_len
.
gpu
[
input_batch
.
idx_mapping
]
prefill_len
=
self
.
req_states
.
prefill_len
.
gpu
[
input_batch
.
idx_mapping
]
is_chunked_prefilling
=
input_batch
.
seq_lens
<
prefill_len
is_chunked_prefilling
=
input_batch
.
seq_lens
<
prefill_len
if
input_batch
.
num_draft_tokens
==
0
:
# No draft tokens (common case).
# 0 if chunked-prefilling, 1 if not.
num_sampled
=
(
~
is_chunked_prefilling
).
int
()
num_sampled
=
(
~
is_chunked_prefilling
).
int
()
else
:
# Draft tokens for spec decoding.
input_ids
=
input_batch
.
input_ids
[
input_batch
.
logits_indices
]
sampled_tokens
,
num_sampled
=
rejection_sample
(
sampler_output
.
sampled_token_ids
,
input_ids
,
input_batch
.
cu_num_logits
,
self
.
num_speculative_steps
,
)
num_sampled
*=
~
is_chunked_prefilling
sampler_output
.
sampled_token_ids
=
sampled_tokens
# TODO(woosuk): Support logprobs with spec decoding.
return
sampler_output
,
num_sampled
return
sampler_output
,
num_sampled
def
compute_prompt_logprobs
(
def
compute_prompt_logprobs
(
...
@@ -653,11 +707,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -653,11 +707,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_sampled
:
torch
.
Tensor
,
num_sampled
:
torch
.
Tensor
,
)
->
None
:
)
->
None
:
# Update the number of computed tokens.
# Update the number of computed tokens.
update
_num_computed_tokens
(
post_
update
(
input_batch
.
idx_mapping
,
input_batch
.
idx_mapping
,
self
.
req_states
.
num_computed_tokens
,
self
.
req_states
.
num_computed_tokens
,
self
.
req_states
.
last_sampled_tokens
,
sampled_tokens
,
num_sampled
,
input_batch
.
query_start_loc
,
input_batch
.
query_start_loc
,
input_batch
.
cu_num_logits
,
)
)
# Update the number of computed prefill tokens.
idx_mapping_np
=
input_batch
.
idx_mapping_np
idx_mapping_np
=
input_batch
.
idx_mapping_np
computed_prefill
=
self
.
req_states
.
num_computed_prefill_tokens
computed_prefill
=
self
.
req_states
.
num_computed_prefill_tokens
# TODO(woosuk): Simplify this.
# TODO(woosuk): Simplify this.
...
@@ -666,10 +726,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -666,10 +726,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self
.
req_states
.
prefill_len
.
np
[
idx_mapping_np
],
self
.
req_states
.
prefill_len
.
np
[
idx_mapping_np
],
)
)
# Store the last sampled token ids.
last_sampled
=
sampled_tokens
self
.
req_states
.
last_sampled_tokens
[
input_batch
.
idx_mapping
]
=
last_sampled
def
get_cudagraph_and_dp_padding
(
def
get_cudagraph_and_dp_padding
(
self
,
self
,
scheduler_output
:
SchedulerOutput
,
scheduler_output
:
SchedulerOutput
,
...
@@ -761,6 +817,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -761,6 +817,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
sampling_metadata
=
self
.
req_states
.
make_sampling_metadata
(
sampling_metadata
=
self
.
req_states
.
make_sampling_metadata
(
input_batch
.
idx_mapping_np
,
pos
input_batch
.
idx_mapping_np
,
pos
)
)
if
input_batch
.
num_draft_tokens
>
0
:
sampling_metadata
=
self
.
req_states
.
expand_sampling_metadata
(
sampling_metadata
,
input_batch
.
cu_num_logits
)
if
self
.
lora_config
:
if
self
.
lora_config
:
# Activate LoRA adapters.
# Activate LoRA adapters.
...
...
vllm/v1/worker/gpu/spec_decode/__init__.py
0 → 100644
View file @
b004c004
vllm/v1/worker/gpu/spec_decode/rejection_sample.py
0 → 100644
View file @
b004c004
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
from
vllm.triton_utils
import
tl
,
triton
# Greedy token-matching verification of draft tokens, one Triton program per
# request. The host launcher in this file starts it with a grid of
# (num_reqs,).
#
# For each request, the slice [start_idx, end_idx) given by cu_num_logits is
# walked left to right. The target model's token at position i is always
# emitted, and it is then compared with the draft token that was fed as input
# at position i + 1. The first mismatch stops further emission.
@triton.jit
def _rejection_sample_kernel(
    sampled_ptr,  # [num_reqs, num_speculative_steps + 1]
    sampled_stride,
    num_sampled_ptr,  # [num_reqs]
    target_sampled_ptr,  # [num_draft_tokens + num_reqs]
    input_ids_ptr,  # [num_draft_tokens + num_reqs]
    cu_num_logits_ptr,  # [num_reqs + 1]
):
    req_idx = tl.program_id(0)
    # This request's span in the flattened per-logit arrays.
    start_idx = tl.load(cu_num_logits_ptr + req_idx)
    end_idx = tl.load(cu_num_logits_ptr + req_idx + 1)
    num_tokens = end_idx - start_idx

    num_sampled = 0
    rejected = False
    # The first num_tokens - 1 positions each have a following input token
    # (read at offset i + 1) to check the target's sample against.
    for i in range(num_tokens - 1):
        # No early exit is used: once `rejected` is set, the remaining
        # iterations simply do nothing.
        if not rejected:
            target_sampled = tl.load(target_sampled_ptr + start_idx + i)
            draft_sampled = tl.load(input_ids_ptr + start_idx + i + 1)
            # The target's token is stored and counted *before* the comparison,
            # so on a mismatch the target's token is still emitted at column i.
            tl.store(sampled_ptr + req_idx * sampled_stride + i, target_sampled)
            num_sampled += 1
            if target_sampled != draft_sampled:
                rejected = True
    if not rejected:
        # Every checked token matched (or there were none): also emit the
        # target's token sampled at the last position of the span.
        target_sampled = tl.load(target_sampled_ptr + start_idx + num_tokens - 1)
        tl.store(
            sampled_ptr + req_idx * sampled_stride + num_tokens - 1, target_sampled
        )
        num_sampled += 1
    # Columns >= num_sampled of this request's row in `sampled` are never
    # written by this kernel.
    tl.store(num_sampled_ptr + req_idx, num_sampled)
def rejection_sample(
    # [num_draft_tokens + num_reqs]
    target_sampled: torch.Tensor,
    # [num_draft_tokens + num_reqs]
    input_ids: torch.Tensor,
    # [num_reqs + 1]
    cu_num_logits: torch.Tensor,
    num_speculative_steps: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Verify draft tokens against the target model's sampled tokens.

    Launches ``_rejection_sample_kernel`` with one program per request.

    Args:
        target_sampled: Flattened tokens sampled by the target model, one per
            logit position.
        input_ids: Flattened input token ids at the same logit positions.
        cu_num_logits: Cumulative number of logits per request; its length
            minus one is taken as the number of requests.
        num_speculative_steps: Sets the output width to
            ``num_speculative_steps + 1`` columns per request.

    Returns:
        A ``(sampled, num_sampled)`` pair. ``sampled`` has shape
        ``[num_reqs, num_speculative_steps + 1]`` with the dtype and device of
        ``target_sampled``. ``num_sampled`` is an int32 tensor of shape
        ``[num_reqs]``. Both are allocated uninitialized, so only the first
        ``num_sampled[i]`` entries of row ``i`` hold meaningful values.
    """
    batch_size = cu_num_logits.shape[0] - 1
    device = target_sampled.device

    # Output buffers are left uninitialized; the kernel fills the valid part.
    accepted_tokens = torch.empty(
        batch_size,
        num_speculative_steps + 1,
        dtype=target_sampled.dtype,
        device=device,
    )
    accepted_counts = torch.empty(batch_size, dtype=torch.int32, device=device)

    launch_grid = (batch_size,)
    _rejection_sample_kernel[launch_grid](
        accepted_tokens,
        accepted_tokens.stride(0),
        accepted_counts,
        target_sampled,
        input_ids,
        cu_num_logits,
        num_warps=1,
    )
    return accepted_tokens, accepted_counts
vllm/v1/worker/gpu/states.py
View file @
b004c004
...
@@ -7,6 +7,7 @@ import torch
...
@@ -7,6 +7,7 @@ import torch
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.triton_utils
import
tl
,
triton
from
vllm.v1.outputs
import
LogprobsTensors
from
vllm.v1.outputs
import
LogprobsTensors
from
vllm.v1.utils
import
CpuGpuBuffer
from
vllm.v1.utils
import
CpuGpuBuffer
...
@@ -63,6 +64,7 @@ class RequestState:
...
@@ -63,6 +64,7 @@ class RequestState:
max_num_reqs
:
int
,
max_num_reqs
:
int
,
max_model_len
:
int
,
max_model_len
:
int
,
max_num_batched_tokens
:
int
,
max_num_batched_tokens
:
int
,
num_speculative_steps
:
int
,
vocab_size
:
int
,
vocab_size
:
int
,
device
:
torch
.
device
,
device
:
torch
.
device
,
pin_memory
:
bool
,
pin_memory
:
bool
,
...
@@ -70,6 +72,7 @@ class RequestState:
...
@@ -70,6 +72,7 @@ class RequestState:
self
.
max_num_reqs
=
max_num_reqs
self
.
max_num_reqs
=
max_num_reqs
self
.
max_model_len
=
max_model_len
self
.
max_model_len
=
max_model_len
self
.
max_num_batched_tokens
=
max_num_batched_tokens
self
.
max_num_batched_tokens
=
max_num_batched_tokens
self
.
num_speculative_steps
=
num_speculative_steps
self
.
vocab_size
=
vocab_size
self
.
vocab_size
=
vocab_size
self
.
device
=
device
self
.
device
=
device
self
.
pin_memory
=
pin_memory
self
.
pin_memory
=
pin_memory
...
@@ -100,6 +103,14 @@ class RequestState:
...
@@ -100,6 +103,14 @@ class RequestState:
device
=
device
,
device
=
device
,
)
)
# Draft tokens.
self
.
draft_tokens
=
torch
.
zeros
(
self
.
max_num_reqs
,
self
.
num_speculative_steps
,
dtype
=
torch
.
int64
,
device
=
device
,
)
# LoRA.
# LoRA.
self
.
lora_ids
=
np
.
zeros
(
self
.
max_num_reqs
,
dtype
=
np
.
int32
)
self
.
lora_ids
=
np
.
zeros
(
self
.
max_num_reqs
,
dtype
=
np
.
int32
)
self
.
lora_ids
.
fill
(
NO_LORA_ID
)
self
.
lora_ids
.
fill
(
NO_LORA_ID
)
...
@@ -226,6 +237,17 @@ class RequestState:
...
@@ -226,6 +237,17 @@ class RequestState:
max_num_logprobs
=
max_num_logprobs
,
max_num_logprobs
=
max_num_logprobs
,
)
)
def expand_sampling_metadata(
    self,
    sampling_metadata: SamplingMetadata,
    cu_num_logits: torch.Tensor,
) -> SamplingMetadata:
    """Expand per-request sampling params to one entry per logit.

    With draft tokens, each request samples multiple tokens in a step, so
    the sampling param tensors need to be expanded accordingly. This
    delegates to the module-level ``expand_sampling_metadata`` using this
    state's ``num_speculative_steps``.
    """
    num_steps = self.num_speculative_steps
    expanded = expand_sampling_metadata(sampling_metadata, cu_num_logits, num_steps)
    return expanded
def
make_lora_inputs
(
def
make_lora_inputs
(
self
,
self
,
req_ids
:
list
[
str
],
req_ids
:
list
[
str
],
...
@@ -270,3 +292,75 @@ class Param:
...
@@ -270,3 +292,75 @@ class Param:
class
ExtraData
:
class
ExtraData
:
lora_request
:
LoRARequest
|
None
lora_request
:
LoRARequest
|
None
in_progress_prompt_logprobs
:
list
[
LogprobsTensors
]
=
field
(
default_factory
=
list
)
in_progress_prompt_logprobs
:
list
[
LogprobsTensors
]
=
field
(
default_factory
=
list
)
# NOTE(woosuk): Re-compilation can happen at runtime since top_p and top_k can be None.
# Broadcasts each request's scalar sampling parameters (temperature, top_p,
# top_k, seed) to every logit position owned by that request. One Triton
# program handles one request; its output range is
# [start_idx, start_idx + num_tokens) as given by cu_num_logits.
@triton.jit
def _expand_sampling_metadata_kernel(
    temp_ptr,
    expanded_temp_ptr,
    top_p_ptr,
    expanded_top_p_ptr,
    top_k_ptr,
    expanded_top_k_ptr,
    seeds_ptr,
    expanded_seeds_ptr,
    cu_num_logits_ptr,
    BLOCK_SIZE: tl.constexpr,
):
    req_idx = tl.program_id(0)
    start_idx = tl.load(cu_num_logits_ptr + req_idx)
    end_idx = tl.load(cu_num_logits_ptr + req_idx + 1)
    num_tokens = end_idx - start_idx

    # A single block of BLOCK_SIZE lanes covers the request; lanes at or past
    # num_tokens are masked off so neighbouring requests are not overwritten.
    # The host launcher in this file passes
    # next_power_of_2(num_speculative_steps + 1) as BLOCK_SIZE.
    block = tl.arange(0, BLOCK_SIZE)
    mask = block < num_tokens

    # Temperature is loaded unconditionally (no None check, unlike top_p/top_k).
    temp = tl.load(temp_ptr + req_idx)
    tl.store(expanded_temp_ptr + start_idx + block, temp, mask=mask)

    # top_p / top_k may be passed as None, in which case the copy is skipped.
    if top_p_ptr is not None:
        top_p = tl.load(top_p_ptr + req_idx)
        tl.store(expanded_top_p_ptr + start_idx + block, top_p, mask=mask)

    if top_k_ptr is not None:
        top_k = tl.load(top_k_ptr + req_idx)
        tl.store(expanded_top_k_ptr + start_idx + block, top_k, mask=mask)

    # Seeds are loaded unconditionally; every position of a request receives
    # the same seed value.
    seed = tl.load(seeds_ptr + req_idx)
    tl.store(expanded_seeds_ptr + start_idx + block, seed, mask=mask)
def expand_sampling_metadata(
    sampling_metadata: SamplingMetadata,
    cu_num_logits: torch.Tensor,
    num_speculative_steps: int,
) -> SamplingMetadata:
    """Return sampling metadata with per-request params repeated per logit.

    Allocates one output tensor per non-None parameter tensor, sized by
    ``sampling_metadata.pos.shape[0]``, and fills them with
    ``_expand_sampling_metadata_kernel`` (one program per request).
    ``pos`` and ``max_num_logprobs`` are carried over unchanged.

    Args:
        sampling_metadata: Metadata holding one entry per request.
        cu_num_logits: Cumulative number of logits per request; its length
            minus one is taken as the number of requests.
        num_speculative_steps: Used to size the kernel block as
            ``next_power_of_2(num_speculative_steps + 1)``.
    """
    total_num_logits = sampling_metadata.pos.shape[0]

    def _alloc_like(src):
        # Keep absent parameters absent; otherwise allocate an uninitialized
        # tensor with the source's dtype and device.
        if src is None:
            return None
        return src.new_empty(total_num_logits)

    temperature = sampling_metadata.temperature
    top_p = sampling_metadata.top_p
    top_k = sampling_metadata.top_k
    seeds = sampling_metadata.seeds

    out_temperature = _alloc_like(temperature)
    out_top_p = _alloc_like(top_p)
    out_top_k = _alloc_like(top_k)
    out_seeds = _alloc_like(seeds)

    batch_size = cu_num_logits.shape[0] - 1
    launch_grid = (batch_size,)
    _expand_sampling_metadata_kernel[launch_grid](
        temperature,
        out_temperature,
        top_p,
        out_top_p,
        top_k,
        out_top_k,
        seeds,
        out_seeds,
        cu_num_logits,
        BLOCK_SIZE=triton.next_power_of_2(num_speculative_steps + 1),
    )
    return SamplingMetadata(
        temperature=out_temperature,
        top_p=out_top_p,
        top_k=out_top_k,
        seeds=out_seeds,
        pos=sampling_metadata.pos,
        max_num_logprobs=sampling_metadata.max_num_logprobs,
    )
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment