Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cc050558
Unverified
Commit
cc050558
authored
Dec 04, 2025
by
Woosuk Kwon
Committed by
GitHub
Dec 04, 2025
Browse files
[Model Runner V2] Implement get_num_sampled_and_rejected kernel (#30029)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
5c32a06a
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
65 additions
and
29 deletions
+65
-29
vllm/v1/worker/gpu/input_batch.py
vllm/v1/worker/gpu/input_batch.py
+49
-0
vllm/v1/worker/gpu/model_runner.py
vllm/v1/worker/gpu/model_runner.py
+16
-17
vllm/v1/worker/gpu/spec_decode/rejection_sample.py
vllm/v1/worker/gpu/spec_decode/rejection_sample.py
+0
-12
No files found.
vllm/v1/worker/gpu/input_batch.py
View file @
cc050558
...
@@ -354,6 +354,55 @@ def combine_sampled_and_draft_tokens(
...
@@ -354,6 +354,55 @@ def combine_sampled_and_draft_tokens(
return
logits_indices
return
logits_indices
@
triton
.
jit
def
_get_num_sampled_and_rejected_kernel
(
num_sampled_ptr
,
num_rejected_ptr
,
seq_lens_ptr
,
cu_num_logits_ptr
,
idx_mapping_ptr
,
prefill_len_ptr
,
):
batch_idx
=
tl
.
program_id
(
0
)
req_state_idx
=
tl
.
load
(
idx_mapping_ptr
+
batch_idx
)
seq_len
=
tl
.
load
(
seq_lens_ptr
+
batch_idx
)
prefill_len
=
tl
.
load
(
prefill_len_ptr
+
req_state_idx
)
is_chunked_prefilling
=
seq_len
<
prefill_len
num_sampled
=
tl
.
load
(
num_sampled_ptr
+
batch_idx
)
num_sampled
=
tl
.
where
(
is_chunked_prefilling
,
0
,
num_sampled
)
tl
.
store
(
num_sampled_ptr
+
batch_idx
,
num_sampled
)
logits_start
=
tl
.
load
(
cu_num_logits_ptr
+
batch_idx
)
logits_end
=
tl
.
load
(
cu_num_logits_ptr
+
batch_idx
+
1
)
num_logits
=
logits_end
-
logits_start
num_rejected
=
num_logits
-
num_sampled
num_rejected
=
tl
.
where
(
is_chunked_prefilling
,
0
,
num_rejected
)
tl
.
store
(
num_rejected_ptr
+
batch_idx
,
num_rejected
)
def
get_num_sampled_and_rejected
(
num_sampled
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
cu_num_logits
:
torch
.
Tensor
,
idx_mapping
:
torch
.
Tensor
,
prefill_len
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
num_reqs
=
idx_mapping
.
shape
[
0
]
num_rejected
=
torch
.
empty_like
(
num_sampled
)
_get_num_sampled_and_rejected_kernel
[(
num_reqs
,)](
num_sampled
,
num_rejected
,
seq_lens
,
cu_num_logits
,
idx_mapping
,
prefill_len
,
)
return
num_sampled
,
num_rejected
@
triton
.
jit
@
triton
.
jit
def
_post_update_kernel
(
def
_post_update_kernel
(
idx_mapping_ptr
,
idx_mapping_ptr
,
...
...
vllm/v1/worker/gpu/model_runner.py
View file @
cc050558
...
@@ -43,6 +43,7 @@ from vllm.v1.worker.gpu.input_batch import (
...
@@ -43,6 +43,7 @@ from vllm.v1.worker.gpu.input_batch import (
InputBatch
,
InputBatch
,
InputBuffers
,
InputBuffers
,
combine_sampled_and_draft_tokens
,
combine_sampled_and_draft_tokens
,
get_num_sampled_and_rejected
,
post_update
,
post_update
,
prepare_pos_seq_lens
,
prepare_pos_seq_lens
,
prepare_prefill_inputs
,
prepare_prefill_inputs
,
...
@@ -54,10 +55,7 @@ from vllm.v1.worker.gpu.sample.metadata import (
...
@@ -54,10 +55,7 @@ from vllm.v1.worker.gpu.sample.metadata import (
)
)
from
vllm.v1.worker.gpu.sample.sampler
import
Sampler
from
vllm.v1.worker.gpu.sample.sampler
import
Sampler
from
vllm.v1.worker.gpu.spec_decode
import
init_speculator
from
vllm.v1.worker.gpu.spec_decode
import
init_speculator
from
vllm.v1.worker.gpu.spec_decode.rejection_sample
import
(
from
vllm.v1.worker.gpu.spec_decode.rejection_sample
import
rejection_sample
get_num_rejected
,
rejection_sample
,
)
from
vllm.v1.worker.gpu.states
import
RequestState
from
vllm.v1.worker.gpu.states
import
RequestState
from
vllm.v1.worker.gpu.structured_outputs
import
apply_grammar_bitmask
from
vllm.v1.worker.gpu.structured_outputs
import
apply_grammar_bitmask
from
vllm.v1.worker.kv_connector_model_runner_mixin
import
KVConnectorModelRunnerMixin
from
vllm.v1.worker.kv_connector_model_runner_mixin
import
KVConnectorModelRunnerMixin
...
@@ -621,16 +619,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -621,16 +619,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Sample tokens and compute logprobs (if needed).
# Sample tokens and compute logprobs (if needed).
sampler_output
=
self
.
sampler
(
logits
,
sampling_metadata
)
sampler_output
=
self
.
sampler
(
logits
,
sampling_metadata
)
# Get the number of sampled tokens.
prefill_len
=
self
.
req_states
.
prefill_len
.
gpu
[
input_batch
.
idx_mapping
]
is_chunked_prefilling
=
input_batch
.
seq_lens
<
prefill_len
if
input_batch
.
num_draft_tokens
==
0
:
if
input_batch
.
num_draft_tokens
==
0
:
# No draft tokens (common case).
# No draft tokens (common case).
# 0 if chunked-prefilling, 1 if not.
num_sampled
=
torch
.
ones
(
num_sampled
=
(
~
is_chunked_prefilling
).
int
()
input_batch
.
num_reqs
,
dtype
=
torch
.
int32
,
device
=
self
.
device
num_rejected
=
torch
.
zeros_like
(
num_sampled
)
)
else
:
else
:
#
Draft tokens
for spec decoding.
#
Rejection sampling
for spec decoding.
input_ids
=
input_batch
.
input_ids
[
input_batch
.
logits_indices
]
input_ids
=
input_batch
.
input_ids
[
input_batch
.
logits_indices
]
sampled_tokens
,
num_sampled
=
rejection_sample
(
sampled_tokens
,
num_sampled
=
rejection_sample
(
sampler_output
.
sampled_token_ids
,
sampler_output
.
sampled_token_ids
,
...
@@ -638,13 +633,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -638,13 +633,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
input_batch
.
cu_num_logits
,
input_batch
.
cu_num_logits
,
self
.
num_speculative_steps
,
self
.
num_speculative_steps
,
)
)
num_sampled
*=
~
is_chunked_prefilling
sampler_output
.
sampled_token_ids
=
sampled_tokens
num_rejected
=
get_num_rejected
(
input_batch
.
cu_num_logits
,
# Get the number of sampled and rejected tokens.
# For chunked prefills, num_sampled and num_rejected are both 0.
num_sampled
,
num_rejected
=
get_num_sampled_and_rejected
(
num_sampled
,
num_sampled
,
input_batch
.
seq_lens
,
input_batch
.
cu_num_logits
,
input_batch
.
idx_mapping
,
self
.
req_states
.
prefill_len
.
gpu
,
)
)
sampler_output
.
sampled_token_ids
=
sampled_tokens
# TODO(woosuk): Support logprobs with spec decoding.
return
sampler_output
,
num_sampled
,
num_rejected
return
sampler_output
,
num_sampled
,
num_rejected
def
compute_prompt_logprobs
(
def
compute_prompt_logprobs
(
...
...
vllm/v1/worker/gpu/spec_decode/rejection_sample.py
View file @
cc050558
...
@@ -69,15 +69,3 @@ def rejection_sample(
...
@@ -69,15 +69,3 @@ def rejection_sample(
num_warps
=
1
,
num_warps
=
1
,
)
)
return
sampled
,
num_sampled
return
sampled
,
num_sampled
@
torch
.
compile
(
dynamic
=
True
)
def
get_num_rejected
(
cu_num_logits
:
torch
.
Tensor
,
num_sampled
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
num_logits
=
cu_num_logits
[
1
:]
-
cu_num_logits
[:
-
1
]
num_rejected
=
num_logits
-
num_sampled
# No token is rejected for chunked prefills.
num_rejected
*=
num_sampled
>
0
return
num_rejected
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment