Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
90c08369
Unverified
Commit
90c08369
authored
Jan 13, 2026
by
Woosuk Kwon
Committed by
GitHub
Jan 13, 2026
Browse files
[Model Runner V2] Refactor Sampler (#32245)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
8ef50d9a
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
289 additions
and
269 deletions
+289
-269
vllm/v1/worker/gpu/model_runner.py
vllm/v1/worker/gpu/model_runner.py
+36
-25
vllm/v1/worker/gpu/sample/metadata.py
vllm/v1/worker/gpu/sample/metadata.py
+0
-79
vllm/v1/worker/gpu/sample/penalties.py
vllm/v1/worker/gpu/sample/penalties.py
+103
-11
vllm/v1/worker/gpu/sample/sampler.py
vllm/v1/worker/gpu/sample/sampler.py
+70
-22
vllm/v1/worker/gpu/sample/states.py
vllm/v1/worker/gpu/sample/states.py
+74
-0
vllm/v1/worker/gpu/spec_decode/eagle.py
vllm/v1/worker/gpu/spec_decode/eagle.py
+6
-4
vllm/v1/worker/gpu/states.py
vllm/v1/worker/gpu/states.py
+0
-128
No files found.
vllm/v1/worker/gpu/model_runner.py
View file @
90c08369
...
...
@@ -49,7 +49,6 @@ from vllm.v1.worker.gpu.input_batch import (
)
from
vllm.v1.worker.gpu.mm.mrope_utils
import
MRopeState
from
vllm.v1.worker.gpu.sample.logprob
import
compute_prompt_logprobs
from
vllm.v1.worker.gpu.sample.metadata
import
SamplingMetadata
from
vllm.v1.worker.gpu.sample.output
import
SamplerOutput
from
vllm.v1.worker.gpu.sample.sampler
import
Sampler
from
vllm.v1.worker.gpu.spec_decode
import
init_speculator
...
...
@@ -139,7 +138,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
dtype
=
self
.
dtype
,
device
=
self
.
device
,
)
self
.
sampler
=
Sampler
(
logprobs_mode
=
self
.
model_config
.
logprobs_mode
)
self
.
sampler
=
Sampler
(
max_num_reqs
=
self
.
max_num_reqs
,
vocab_size
=
self
.
vocab_size
,
device
=
self
.
device
,
logprobs_mode
=
self
.
model_config
.
logprobs_mode
,
)
# CUDA graphs.
self
.
cudagraph_manager
=
CudaGraphManager
(
...
...
@@ -310,12 +314,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
hidden_states
:
torch
.
Tensor
,
)
->
None
:
num_reqs
=
hidden_states
.
shape
[
0
]
sampling_metadata
=
SamplingMetadata
.
make_dummy
(
num_reqs
=
num_reqs
,
device
=
self
.
device
,
)
logits
=
self
.
model
.
compute_logits
(
hidden_states
)
self
.
sampler
(
logits
,
sampling_metadata
)
idx_mapping
=
torch
.
arange
(
num_reqs
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
idx_mapping_np
=
np
.
arange
(
num_reqs
,
dtype
=
np
.
int32
)
pos
=
torch
.
zeros
(
num_reqs
,
dtype
=
torch
.
int64
,
device
=
self
.
device
)
# NOTE(woosuk): During the initial memory profiling, the sampler may skip
# top_k, top_p, and logprobs, using less GPU memory than what is possible
# during actual execution.
self
.
sampler
(
logits
,
idx_mapping
,
idx_mapping_np
,
pos
)
@
torch
.
inference_mode
()
def
profile_run
(
self
)
->
None
:
...
...
@@ -401,9 +407,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
assert
new_req_data
.
prefill_token_ids
is
not
None
assert
new_req_data
.
sampling_params
is
not
None
req_id
=
new_req_data
.
req_id
prompt_len
=
len
(
new_req_data
.
prompt_token_ids
)
self
.
req_states
.
add_request
(
req_id
=
req_id
,
prompt_len
=
len
(
new_req_data
.
prompt_token_ids
)
,
prompt_len
=
prompt_len
,
prefill_token_ids
=
new_req_data
.
prefill_token_ids
,
num_computed_tokens
=
new_req_data
.
num_computed_tokens
,
sampling_params
=
new_req_data
.
sampling_params
,
...
...
@@ -423,6 +430,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self
.
block_tables
.
append_block_ids
(
req_index
,
new_req_data
.
block_ids
,
overwrite
=
True
)
self
.
sampler
.
add_request
(
req_index
,
prompt_len
,
new_req_data
.
sampling_params
)
# Add new blocks for the existing requests.
cached_reqs
=
scheduler_output
.
scheduled_cached_reqs
...
...
@@ -436,6 +446,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self
.
req_states
.
apply_staged_writes
()
self
.
block_tables
.
apply_staged_writes
()
self
.
sampler
.
apply_staged_writes
(
self
.
req_states
.
prefill_token_ids
.
gpu
,
self
.
req_states
.
prefill_len
.
np
,
self
.
req_states
.
prompt_len
,
)
if
self
.
uses_mrope
:
self
.
mrope_states
.
apply_staged_writes
()
...
...
@@ -612,10 +627,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self
,
hidden_states
:
torch
.
Tensor
,
input_batch
:
InputBatch
,
sampling_metadata
:
SamplingMetadata
,
grammar_output
:
GrammarOutput
|
None
,
)
->
tuple
[
SamplerOutput
,
torch
.
Tensor
,
torch
.
Tensor
]:
sample_hidden_states
=
hidden_states
[
input_batch
.
logits_indices
]
sample_pos
=
input_batch
.
positions
[
input_batch
.
logits_indices
]
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
)
if
grammar_output
is
not
None
:
# Apply grammar bitmask to the logits in-place.
...
...
@@ -627,7 +642,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
)
# Sample tokens and compute logprobs (if needed).
sampler_output
=
self
.
sampler
(
logits
,
sampling_metadata
)
sampler_output
=
self
.
sampler
(
logits
,
input_batch
.
expanded_idx_mapping
,
input_batch
.
idx_mapping_np
,
sample_pos
,
)
if
input_batch
.
num_draft_tokens
==
0
:
# No draft tokens (common case).
...
...
@@ -766,7 +786,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
input_batch
.
idx_mapping
,
self
.
req_states
.
num_computed_tokens
.
gpu
,
self
.
req_states
.
last_sampled_tokens
,
self
.
req
_state
s
.
output_bin_counts
,
self
.
sampler
.
penalties
_state
.
output_bin_counts
,
sampled_tokens
,
num_sampled
,
num_rejected
,
...
...
@@ -786,7 +806,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def
propose_draft
(
self
,
input_batch
:
InputBatch
,
sampling_metadata
:
SamplingMetadata
,
last_hidden_states
:
torch
.
Tensor
,
aux_hidden_states
:
list
[
torch
.
Tensor
]
|
None
,
num_sampled
:
torch
.
Tensor
,
...
...
@@ -801,13 +820,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
]
draft_tokens
=
self
.
speculator
.
propose
(
input_batch
,
sampling_metadata
,
last_hidden_states
,
aux_hidden_states
,
num_sampled
,
num_rejected
,
last_sampled_tokens
,
next_prefill_tokens
,
self
.
sampler
.
sampling_states
.
temperature
.
gpu
,
self
.
sampler
.
sampling_states
.
seeds
.
gpu
,
)
return
draft_tokens
...
...
@@ -893,12 +913,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
scheduler_output
,
num_tokens_after_padding
,
)
pos
=
input_batch
.
positions
[
input_batch
.
logits_indices
]
sampling_metadata
=
self
.
req_states
.
make_sampling_metadata
(
input_batch
.
expanded_idx_mapping
,
input_batch
.
idx_mapping_np
,
pos
)
if
self
.
lora_config
:
# Activate LoRA adapters.
lora_inputs
=
self
.
req_states
.
make_lora_inputs
(
...
...
@@ -917,7 +931,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
device
=
self
.
device
,
)
self
.
prepare_dummy_attn_metadata
(
input_batch
)
sampling_metadata
=
None
# Run model.
if
cudagraph_mode
==
CUDAGraphMode
.
FULL
:
...
...
@@ -946,7 +959,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
positions
=
positions
,
)
self
.
execute_model_state
=
hidden_states
,
input_batch
,
sampling_metadata
self
.
execute_model_state
=
hidden_states
,
input_batch
return
None
@
torch
.
inference_mode
()
...
...
@@ -955,12 +968,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
grammar_output
:
GrammarOutput
|
None
,
)
->
AsyncOutput
|
ModelRunnerOutput
:
assert
self
.
execute_model_state
is
not
None
hidden_states
,
input_batch
,
sampling_metadata
=
self
.
execute_model_state
hidden_states
,
input_batch
=
self
.
execute_model_state
self
.
execute_model_state
=
None
# type: ignore
assert
sampling_metadata
is
not
None
sampler_output
,
num_sampled
,
num_rejected
=
self
.
sample
(
hidden_states
,
input_batch
,
sampling_metadata
,
grammar_output
hidden_states
,
input_batch
,
grammar_output
)
prompt_logprobs_dict
=
self
.
compute_prompt_logprobs
(
hidden_states
,
input_batch
)
...
...
@@ -992,7 +1004,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if
self
.
do_spec_decode
:
draft_tokens
=
self
.
propose_draft
(
input_batch
,
sampling_metadata
,
hidden_states
,
None
,
# aux_hidden_states
num_sampled
,
...
...
vllm/v1/worker/gpu/sample/metadata.py
deleted
100644 → 0
View file @
8ef50d9a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
import
torch
@
dataclass
class
SamplingMetadata
:
idx_mapping
:
torch
.
Tensor
temperature
:
torch
.
Tensor
top_p
:
torch
.
Tensor
|
None
top_k
:
torch
.
Tensor
|
None
min_p
:
torch
.
Tensor
|
None
# For penalties
repetition_penalty
:
torch
.
Tensor
frequency_penalty
:
torch
.
Tensor
presence_penalty
:
torch
.
Tensor
prompt_bin_mask
:
torch
.
Tensor
output_bin_counts
:
torch
.
Tensor
seeds
:
torch
.
Tensor
pos
:
torch
.
Tensor
# None means no logprobs, 0 means sampled token logprobs only
max_num_logprobs
:
int
|
None
@
classmethod
def
make_dummy
(
cls
,
num_reqs
:
int
,
device
:
torch
.
device
,
)
->
"SamplingMetadata"
:
assert
num_reqs
>
0
idx_mapping
=
torch
.
arange
(
num_reqs
,
dtype
=
torch
.
int32
,
device
=
device
)
temperature
=
torch
.
zeros
(
num_reqs
,
dtype
=
torch
.
float32
,
device
=
device
)
temperature
[
0
]
=
0.5
# TODO(woosuk): Use top-p and top-k for dummy sampler.
# Currently, they are disabled because of memory usage.
# top_p = torch.full((num_reqs,), 0.95, dtype=torch.float32, device=device)
# top_k = torch.full((num_reqs,), 20, dtype=torch.int32, device=device)
top_p
=
None
top_k
=
None
min_p
=
torch
.
zeros
(
num_reqs
,
dtype
=
torch
.
float32
,
device
=
device
)
# NOTE(woosuk): We must set penalties to their default values to make sure
# the penalties kernel does not touch the placeholder bin_counts tensors.
repetition_penalty
=
torch
.
ones
(
num_reqs
,
dtype
=
torch
.
float32
,
device
=
device
)
frequency_penalty
=
torch
.
zeros
(
num_reqs
,
dtype
=
torch
.
float32
,
device
=
device
)
presence_penalty
=
torch
.
zeros
(
num_reqs
,
dtype
=
torch
.
float32
,
device
=
device
)
# NOTE(woosuk): These are placeholder tensors to avoid None checks in the
# penalties kernel. We use 2 instead of 1 as vocab_size to avoid Triton
# specialization and re-compilation at runtime.
prompt_bin_mask
=
torch
.
zeros
(
num_reqs
,
2
,
dtype
=
torch
.
int32
,
device
=
device
)
output_bin_counts
=
torch
.
zeros
(
num_reqs
,
2
,
dtype
=
torch
.
int32
,
device
=
device
)
seeds
=
torch
.
zeros
(
num_reqs
,
dtype
=
torch
.
int64
,
device
=
device
)
pos
=
torch
.
zeros
(
num_reqs
,
dtype
=
torch
.
int64
,
device
=
device
)
max_num_logprobs
=
20
return
cls
(
idx_mapping
=
idx_mapping
,
temperature
=
temperature
,
top_p
=
top_p
,
top_k
=
top_k
,
min_p
=
min_p
,
repetition_penalty
=
repetition_penalty
,
frequency_penalty
=
frequency_penalty
,
presence_penalty
=
presence_penalty
,
prompt_bin_mask
=
prompt_bin_mask
,
output_bin_counts
=
output_bin_counts
,
seeds
=
seeds
,
pos
=
pos
,
max_num_logprobs
=
max_num_logprobs
,
)
vllm/v1/worker/gpu/sample/penalties.py
View file @
90c08369
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
numpy
as
np
import
torch
from
vllm.sampling_params
import
SamplingParams
from
vllm.triton_utils
import
tl
,
triton
from
vllm.v1.worker.gpu.sample.metadata
import
SamplingMetadata
from
vllm.utils.math_utils
import
cdiv
from
vllm.v1.worker.gpu.buffer_utils
import
UvaBackedTensor
class PenaltiesState:
    """Per-request penalty parameters plus the token statistics they use.

    Holds the repetition/frequency/presence penalties for every request slot
    together with the prompt bit-mask and output token counts tensors.
    """

    def __init__(self, max_num_reqs: int, vocab_size: int, device: torch.device):
        """Allocate penalty buffers for ``max_num_reqs`` request slots."""
        self.max_num_reqs = max_num_reqs
        self.vocab_size = vocab_size
        self.device = device

        # One float32 entry per request slot for each penalty kind.
        self.repetition_penalty = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
        self.frequency_penalty = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
        self.presence_penalty = UvaBackedTensor(max_num_reqs, dtype=torch.float32)

        # 0 is an invalid repetition penalty, so start every slot at 1.0.
        self.repetition_penalty.np.fill(1.0)
        self.repetition_penalty.copy_to_uva()

        # Statistics for penalties.
        # The mask holds one int32 word per 32 vocabulary entries.
        mask_shape = (self.max_num_reqs, cdiv(self.vocab_size, 32))
        self.prompt_bin_mask = torch.zeros(
            mask_shape, dtype=torch.int32, device=self.device
        )
        # TODO(woosuk): This tensor is rarely used but can be very large, taking up
        # GBs of GPU memory. Optimize the memory usage.
        counts_shape = (self.max_num_reqs, self.vocab_size)
        self.output_bin_counts = torch.zeros(
            counts_shape, dtype=torch.int32, device=self.device
        )

        # Request slots whose statistics still need to be computed.
        self._penalties_reqs: list[int] = []

    def add_request(self, req_idx: int, sampling_params: SamplingParams) -> None:
        """Stage the penalty parameters of a request in slot ``req_idx``.

        The slot is queued for a statistics update only when at least one
        penalty is active according to ``use_penalty``.
        """
        self.repetition_penalty.np[req_idx] = sampling_params.repetition_penalty
        self.frequency_penalty.np[req_idx] = sampling_params.frequency_penalty
        self.presence_penalty.np[req_idx] = sampling_params.presence_penalty

        if not use_penalty(sampling_params):
            return
        self._penalties_reqs.append(req_idx)

    def apply_staged_writes(
        self,
        prefill_token_ids: torch.Tensor,
        prefill_lens: np.ndarray,
        prompt_lens: np.ndarray,
    ) -> None:
        """Process queued requests and publish the staged penalty values.

        Calls ``bincount`` once per queued request slot (presumably to refresh
        that slot's prompt mask and output counts -- confirm against
        ``bincount``), empties the queue, and then copies the three penalty
        buffers via ``copy_to_uva``.
        """
        # TODO(woosuk): Optimize this.
        queued = self._penalties_reqs
        for slot in queued:
            bincount(
                prefill_token_ids[slot],
                int(prefill_lens[slot]),
                int(prompt_lens[slot]),
                self.prompt_bin_mask[slot],
                self.output_bin_counts[slot],
            )
        queued.clear()

        for staged in (
            self.repetition_penalty,
            self.frequency_penalty,
            self.presence_penalty,
        ):
            staged.copy_to_uva()

    def apply_penalties_and_temperature(
        self,
        logits: torch.Tensor,
        idx_mapping: torch.Tensor,
        temperature: torch.Tensor,
    ) -> None:
        """Forward to the module-level ``apply_penalties_and_temperature``.

        Supplies this state's penalty tensors and statistics alongside the
        caller's ``logits``, ``idx_mapping`` and ``temperature``.
        """
        # Inside a method this name resolves to the module-level function,
        # not to this method.
        apply_penalties_and_temperature(
            logits,
            idx_mapping,
            temperature,
            self.repetition_penalty.gpu,
            self.frequency_penalty.gpu,
            self.presence_penalty.gpu,
            self.prompt_bin_mask,
            self.output_bin_counts,
        )
@
triton
.
jit
...
...
@@ -84,7 +162,13 @@ def _penalties_and_temperature_kernel(
def
apply_penalties_and_temperature
(
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
idx_mapping
:
torch
.
Tensor
,
temperature
:
torch
.
Tensor
,
repetition_penalty
:
torch
.
Tensor
,
frequency_penalty
:
torch
.
Tensor
,
presence_penalty
:
torch
.
Tensor
,
prompt_bin_mask
:
torch
.
Tensor
,
output_bin_counts
:
torch
.
Tensor
,
)
->
None
:
num_reqs
,
vocab_size
=
logits
.
shape
BLOCK_SIZE
=
8192
...
...
@@ -92,15 +176,15 @@ def apply_penalties_and_temperature(
_penalties_and_temperature_kernel
[(
num_reqs
,
num_blocks
)](
logits
,
logits
.
stride
(
0
),
sampling_metadata
.
idx_mapping
,
sampling_metadata
.
repetition_penalty
,
sampling_metadata
.
frequency_penalty
,
sampling_metadata
.
presence_penalty
,
sampling_metadata
.
temperature
,
sampling_metadata
.
prompt_bin_mask
,
sampling_metadata
.
prompt_bin_mask
.
stride
(
0
),
sampling_metadata
.
output_bin_counts
,
sampling_metadata
.
output_bin_counts
.
stride
(
0
),
idx_mapping
,
repetition_penalty
,
frequency_penalty
,
presence_penalty
,
temperature
,
prompt_bin_mask
,
prompt_bin_mask
.
stride
(
0
),
output_bin_counts
,
output_bin_counts
.
stride
(
0
),
vocab_size
,
BLOCK_SIZE
=
BLOCK_SIZE
,
)
...
...
@@ -153,3 +237,11 @@ def bincount(
output_bin_counts
,
BLOCK_SIZE
=
BLOCK_SIZE
,
)
def use_penalty(sampling_params: SamplingParams) -> bool:
    """Return True when at least one penalty differs from its neutral value.

    Neutral values are 1.0 for the repetition penalty and 0.0 for the
    frequency and presence penalties.
    """
    all_neutral = (
        sampling_params.repetition_penalty == 1.0
        and sampling_params.frequency_penalty == 0.0
        and sampling_params.presence_penalty == 0.0
    )
    return not all_neutral
vllm/v1/worker/gpu/sample/sampler.py
View file @
90c08369
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
numpy
as
np
import
torch
import
vllm.envs
as
envs
from
vllm.config.model
import
LogprobsMode
from
vllm.sampling_params
import
SamplingParams
from
vllm.v1.sample.ops.topk_topp_sampler
import
apply_top_k_top_p
from
vllm.v1.worker.gpu.metrics.logits
import
get_num_nans
from
vllm.v1.worker.gpu.sample.gumbel
import
gumbel_sample
from
vllm.v1.worker.gpu.sample.logit_bias
import
LogitBiasState
from
vllm.v1.worker.gpu.sample.logprob
import
compute_topk_logprobs
from
vllm.v1.worker.gpu.sample.metadata
import
SamplingMetadata
from
vllm.v1.worker.gpu.sample.min_p
import
apply_min_p
from
vllm.v1.worker.gpu.sample.output
import
SamplerOutput
from
vllm.v1.worker.gpu.sample.penalties
import
apply_penalties_and_temperature
from
vllm.v1.worker.gpu.sample.penalties
import
PenaltiesState
from
vllm.v1.worker.gpu.sample.states
import
NO_LOGPROBS
,
SamplingStates
class
Sampler
:
def
__init__
(
self
,
max_num_reqs
:
int
,
vocab_size
:
int
,
device
:
torch
.
device
,
logprobs_mode
:
LogprobsMode
=
"raw_logprobs"
,
):
if
logprobs_mode
not
in
[
"processed_logprobs"
,
"raw_logprobs"
]:
...
...
@@ -25,26 +31,54 @@ class Sampler:
self
.
logprobs_mode
=
logprobs_mode
self
.
compute_nans
=
envs
.
VLLM_COMPUTE_NANS_IN_LOGITS
# False by default.
self
.
sampling_states
=
SamplingStates
(
max_num_reqs
,
vocab_size
)
self
.
penalties_state
=
PenaltiesState
(
max_num_reqs
,
vocab_size
,
device
)
self
.
logit_bias_state
=
LogitBiasState
(
max_num_reqs
,
device
)
def
add_request
(
self
,
req_idx
:
int
,
prompt_len
:
int
,
sampling_params
:
SamplingParams
,
)
->
None
:
self
.
sampling_states
.
add_request
(
req_idx
,
sampling_params
)
self
.
penalties_state
.
add_request
(
req_idx
,
sampling_params
)
self
.
logit_bias_state
.
add_request
(
req_idx
,
prompt_len
,
sampling_params
)
def
apply_staged_writes
(
self
,
prefill_token_ids
:
torch
.
Tensor
,
prefill_lens
:
np
.
ndarray
,
prompt_lens
:
np
.
ndarray
,
)
->
None
:
self
.
sampling_states
.
apply_staged_writes
()
self
.
penalties_state
.
apply_staged_writes
(
prefill_token_ids
,
prefill_lens
,
prompt_lens
)
self
.
logit_bias_state
.
apply_staged_writes
()
def
__call__
(
self
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
idx_mapping
:
torch
.
Tensor
,
idx_mapping_np
:
np
.
ndarray
,
pos
:
torch
.
Tensor
,
)
->
SamplerOutput
:
# NOTE(woosuk): We intentionally compute num_nans before sampling to make clear
# that num_nans is computed before applying penalties and temperature.
num_nans
=
get_num_nans
(
logits
)
if
self
.
compute_nans
else
None
sampled
,
processed_logits
=
self
.
sample
(
logits
,
sampling_metadata
)
if
sampling_metadata
.
max_num_logprobs
is
not
None
:
sampled
,
processed_logits
=
self
.
sample
(
logits
,
idx_mapping
,
idx_mapping_np
,
pos
)
max_num_logprobs
=
self
.
sampling_states
.
max_num_logprobs
(
idx_mapping_np
)
if
max_num_logprobs
!=
NO_LOGPROBS
:
logits
=
(
processed_logits
if
self
.
logprobs_mode
==
"processed_logprobs"
else
logits
)
logprobs_tensors
=
compute_topk_logprobs
(
logits
,
sampling_metadata
.
max_num_logprobs
,
sampled
,
)
logprobs_tensors
=
compute_topk_logprobs
(
logits
,
max_num_logprobs
,
sampled
)
else
:
logprobs_tensors
=
None
...
...
@@ -62,27 +96,41 @@ class Sampler:
def
sample
(
self
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
idx_mapping
:
torch
.
Tensor
,
idx_mapping_np
:
np
.
ndarray
,
pos
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# Copy logits to a new FP32 tensor.
logits
=
torch
.
empty_like
(
logits
,
dtype
=
torch
.
float32
).
copy_
(
logits
)
# Apply logit bias (e.g., allowed_token_ids, min_tokens) in place.
self
.
logit_bias_state
.
apply_logit_bias
(
logits
,
idx_mapping
,
pos
)
# Apply penalties and temperature in place.
apply_penalties_and_temperature
(
logits
,
sampling_metadata
)
# Apply min_p in place.
if
sampling_metadata
.
min_p
is
not
None
:
apply_min_p
(
logits
,
sampling_metadata
.
idx_mapping
,
sampling_metadata
.
min_p
)
# Apply top_k and/or top_p. This might return a new tensor.
logits
=
apply_top_k_top_p
(
logits
,
sampling_metadata
.
top_k
,
sampling_metadata
.
top_p
self
.
penalties_state
.
apply_penalties_and_temperature
(
logits
,
idx_mapping
,
self
.
sampling_states
.
temperature
.
gpu
)
# Apply min_p in place if any request has a non-zero min_p.
do_min_p
=
self
.
sampling_states
.
do_min_p
(
idx_mapping_np
)
if
do_min_p
:
apply_min_p
(
logits
,
idx_mapping
,
self
.
sampling_states
.
min_p
.
gpu
)
# Apply top_k and/or top_p. This might return a new tensor.
do_top_k
=
self
.
sampling_states
.
do_top_k
(
idx_mapping_np
)
top_k
=
self
.
sampling_states
.
top_k
.
gpu
[
idx_mapping
]
if
do_top_k
else
None
do_top_p
=
self
.
sampling_states
.
do_top_p
(
idx_mapping_np
)
top_p
=
self
.
sampling_states
.
top_p
.
gpu
[
idx_mapping
]
if
do_top_p
else
None
if
do_top_k
or
do_top_p
:
logits
=
apply_top_k_top_p
(
logits
,
top_k
,
top_p
)
# Sample the next token.
sampled
=
gumbel_sample
(
logits
,
sampling_metadata
.
idx_mapping
,
sampling_
metadata
.
temperature
,
sampling_
metadata
.
seeds
,
sampling_metadata
.
pos
,
idx_mapping
,
self
.
sampling_
states
.
temperature
.
gpu
,
self
.
sampling_
states
.
seeds
.
gpu
,
pos
,
apply_temperature
=
False
,
)
return
sampled
,
logits
vllm/v1/worker/gpu/sample/states.py
0 → 100644
View file @
90c08369
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
numpy
as
np
import
torch
from
vllm.sampling_params
import
SamplingParams
from
vllm.v1.worker.gpu.buffer_utils
import
UvaBackedTensor
# Sentinel stored in ``num_logprobs`` when a request asks for no logprobs.
NO_LOGPROBS = -1

# Smallest and largest values representable as a signed 64-bit integer.
_NP_INT64_MIN = np.iinfo(np.int64).min
_NP_INT64_MAX = np.iinfo(np.int64).max
class SamplingStates:
    """Per-request sampling parameters, indexed by request slot.

    Temperature, top-k, top-p, min-p and seeds are kept in ``UvaBackedTensor``
    buffers: values are written through ``.np`` and published with
    ``copy_to_uva``. The requested number of logprobs is kept in a plain
    NumPy array.
    """

    def __init__(self, max_num_reqs: int, vocab_size: int):
        """Allocate buffers for ``max_num_reqs`` request slots."""
        self.max_num_reqs = max_num_reqs
        self.vocab_size = vocab_size

        self.temperature = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
        self.top_k = UvaBackedTensor(max_num_reqs, dtype=torch.int32)
        self.top_p = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
        self.min_p = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
        self.seeds = UvaBackedTensor(max_num_reqs, dtype=torch.int64)

        # Initialize top_k and top_p manually because 0 is an invalid value for them.
        self.top_k.np.fill(self.vocab_size)
        self.top_k.copy_to_uva()
        self.top_p.np.fill(1.0)
        self.top_p.copy_to_uva()

        # NO_LOGPROBS (-1) means no logprobs are requested.
        self.num_logprobs = np.full(self.max_num_reqs, NO_LOGPROBS, dtype=np.int32)

    def add_request(self, req_idx: int, sampling_params: SamplingParams) -> None:
        """Stage the sampling parameters of a request in slot ``req_idx``.

        ``top_k`` values outside the open interval ``(0, vocab_size)`` are
        stored as ``vocab_size``, which ``do_top_k`` treats as "disabled".
        A request without an explicit seed gets a random int64 seed.
        """
        self.temperature.np[req_idx] = sampling_params.temperature
        self.top_p.np[req_idx] = sampling_params.top_p

        top_k = sampling_params.top_k
        if not 0 < top_k < self.vocab_size:
            top_k = self.vocab_size
        self.top_k.np[req_idx] = top_k

        self.min_p.np[req_idx] = sampling_params.min_p

        seed = sampling_params.seed
        if seed is None:
            # Request int64 explicitly: the bounds span the whole int64 range,
            # which does not fit NumPy's default integer on platforms where
            # that default is 32-bit.
            seed = np.random.randint(_NP_INT64_MIN, _NP_INT64_MAX, dtype=np.int64)
        self.seeds.np[req_idx] = seed

        num_logprobs = sampling_params.logprobs
        if num_logprobs is None:
            num_logprobs = NO_LOGPROBS
        self.num_logprobs[req_idx] = num_logprobs

    def apply_staged_writes(self) -> None:
        """Publish all staged parameter values via ``copy_to_uva``."""
        self.temperature.copy_to_uva()
        self.top_p.copy_to_uva()
        self.top_k.copy_to_uva()
        self.min_p.copy_to_uva()
        self.seeds.copy_to_uva()

    def do_min_p(self, idx_mapping_np: np.ndarray) -> bool:
        """Return True if any selected request has a non-zero ``min_p``."""
        # bool() converts NumPy's np.bool_ into the annotated builtin bool.
        return bool(np.any(self.min_p.np[idx_mapping_np] != 0.0))

    def do_top_k(self, idx_mapping_np: np.ndarray) -> bool:
        """Return True if any selected request has ``top_k != vocab_size``."""
        return bool(np.any(self.top_k.np[idx_mapping_np] != self.vocab_size))

    def do_top_p(self, idx_mapping_np: np.ndarray) -> bool:
        """Return True if any selected request has ``top_p != 1.0``."""
        return bool(np.any(self.top_p.np[idx_mapping_np] != 1.0))

    def max_num_logprobs(self, idx_mapping_np: np.ndarray) -> int:
        """Return the largest logprobs count among the selected requests.

        Returns ``NO_LOGPROBS`` when no selected request asks for logprobs,
        including when ``idx_mapping_np`` is empty.
        """
        requested = self.num_logprobs[idx_mapping_np]
        if requested.size == 0:
            # np.max raises on an empty array; nothing selected means no logprobs.
            return NO_LOGPROBS
        return int(np.max(requested))
vllm/v1/worker/gpu/spec_decode/eagle.py
View file @
90c08369
...
...
@@ -17,7 +17,6 @@ from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
from
vllm.v1.worker.gpu.block_table
import
BlockTables
from
vllm.v1.worker.gpu.input_batch
import
InputBatch
,
InputBuffers
from
vllm.v1.worker.gpu.sample.gumbel
import
gumbel_sample
from
vllm.v1.worker.gpu.sample.metadata
import
SamplingMetadata
from
vllm.v1.worker.gpu.spec_decode.eagle_cudagraph
import
EagleCudaGraphManager
logger
=
init_logger
(
__name__
)
...
...
@@ -188,7 +187,6 @@ class EagleSpeculator:
def
propose
(
self
,
input_batch
:
InputBatch
,
sampling_metadata
:
SamplingMetadata
,
# [num_tokens, hidden_size]
last_hidden_states
:
torch
.
Tensor
,
# num_layers x [num_tokens, hidden_size]
...
...
@@ -201,6 +199,10 @@ class EagleSpeculator:
last_sampled
:
torch
.
Tensor
,
# [num_reqs]
next_prefill_tokens
:
torch
.
Tensor
,
# [max_num_reqs]
temperature
:
torch
.
Tensor
,
# [max_num_reqs]
seeds
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
# NOTE(woosuk): To avoid CPU-GPU synchronization without CPU knowing the
# number of rejected tokens, we maintain the size of eagle's input_ids and
...
...
@@ -246,8 +248,8 @@ class EagleSpeculator:
# affect the output distribution after rejection sampling.
idx_mapping
=
self
.
idx_mapping
[:
num_reqs
]
idx_mapping
.
copy_
(
input_batch
.
idx_mapping
)
self
.
temperature
.
copy_
(
sampling_metadata
.
temperature
)
self
.
seeds
.
copy_
(
sampling_metadata
.
seeds
)
self
.
temperature
.
copy_
(
temperature
)
self
.
seeds
.
copy_
(
seeds
)
# Gather the values and copy them to the pre-allocated buffers.
pos
=
self
.
input_buffers
.
positions
[:
num_reqs
]
torch
.
gather
(
input_batch
.
positions
,
0
,
last_token_indices
,
out
=
pos
)
...
...
vllm/v1/worker/gpu/states.py
View file @
90c08369
...
...
@@ -7,14 +7,9 @@ import torch
from
vllm.lora.request
import
LoRARequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.utils.math_utils
import
cdiv
from
vllm.v1.outputs
import
LogprobsTensors
from
vllm.v1.worker.gpu.buffer_utils
import
StagedWriteTensor
,
UvaBackedTensor
from
vllm.v1.worker.gpu.sample.metadata
import
SamplingMetadata
from
vllm.v1.worker.gpu.sample.penalties
import
bincount
_NP_INT64_MIN
=
np
.
iinfo
(
np
.
int64
).
min
_NP_INT64_MAX
=
np
.
iinfo
(
np
.
int64
).
max
NO_LORA_ID
=
0
...
...
@@ -81,38 +76,8 @@ class RequestState:
self
.
lora_ids
=
np
.
zeros
(
self
.
max_num_reqs
,
dtype
=
np
.
int32
)
self
.
lora_ids
.
fill
(
NO_LORA_ID
)
# Sampling parameters.
self
.
temperature
=
UvaBackedTensor
(
self
.
max_num_reqs
,
dtype
=
torch
.
float32
)
self
.
top_p
=
UvaBackedTensor
(
self
.
max_num_reqs
,
dtype
=
torch
.
float32
)
self
.
top_k
=
UvaBackedTensor
(
self
.
max_num_reqs
,
dtype
=
torch
.
int32
)
self
.
min_p
=
UvaBackedTensor
(
self
.
max_num_reqs
,
dtype
=
torch
.
float32
)
self
.
repetition_penalty
=
UvaBackedTensor
(
self
.
max_num_reqs
,
dtype
=
torch
.
float32
)
self
.
frequency_penalty
=
UvaBackedTensor
(
self
.
max_num_reqs
,
dtype
=
torch
.
float32
)
self
.
presence_penalty
=
UvaBackedTensor
(
self
.
max_num_reqs
,
dtype
=
torch
.
float32
)
self
.
seeds
=
UvaBackedTensor
(
self
.
max_num_reqs
,
dtype
=
torch
.
int64
)
self
.
num_logprobs
=
np
.
empty
(
self
.
max_num_reqs
,
dtype
=
np
.
int32
)
# -1 means no logprobs are requested.
self
.
num_logprobs
.
fill
(
-
1
)
self
.
needs_prompt_logprobs
=
np
.
zeros
(
self
.
max_num_reqs
,
dtype
=
bool
)
# Statistics for penalties.
self
.
prompt_bin_mask
=
torch
.
zeros
(
self
.
max_num_reqs
,
cdiv
(
self
.
vocab_size
,
32
),
dtype
=
torch
.
int32
,
device
=
self
.
device
,
)
# TODO(woosuk): This tensor is rarely used but can be extremely large.
# Optimize the memory usage.
self
.
output_bin_counts
=
torch
.
zeros
(
self
.
max_num_reqs
,
self
.
vocab_size
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
self
.
_penalties_reqs
:
list
[
int
]
=
[]
@
property
def
num_reqs
(
self
)
->
int
:
return
len
(
self
.
req_id_to_index
)
...
...
@@ -147,33 +112,6 @@ class RequestState:
else
:
self
.
lora_ids
[
req_idx
]
=
NO_LORA_ID
self
.
temperature
.
np
[
req_idx
]
=
sampling_params
.
temperature
self
.
top_p
.
np
[
req_idx
]
=
sampling_params
.
top_p
if
0
<
sampling_params
.
top_k
<
self
.
vocab_size
:
top_k
=
sampling_params
.
top_k
else
:
top_k
=
self
.
vocab_size
self
.
top_k
.
np
[
req_idx
]
=
top_k
self
.
min_p
.
np
[
req_idx
]
=
sampling_params
.
min_p
self
.
repetition_penalty
.
np
[
req_idx
]
=
sampling_params
.
repetition_penalty
self
.
frequency_penalty
.
np
[
req_idx
]
=
sampling_params
.
frequency_penalty
self
.
presence_penalty
.
np
[
req_idx
]
=
sampling_params
.
presence_penalty
if
use_penalty
(
sampling_params
):
self
.
_penalties_reqs
.
append
(
req_idx
)
if
sampling_params
.
seed
is
not
None
:
seed
=
sampling_params
.
seed
else
:
seed
=
np
.
random
.
randint
(
_NP_INT64_MIN
,
_NP_INT64_MAX
)
self
.
seeds
.
np
[
req_idx
]
=
seed
if
sampling_params
.
logprobs
is
not
None
:
num_logprobs
=
sampling_params
.
logprobs
else
:
num_logprobs
=
-
1
self
.
num_logprobs
[
req_idx
]
=
num_logprobs
# For now, only support prompt logprobs for the prompt tokens.
needs_prompt_logprobs
=
sampling_params
.
prompt_logprobs
is
not
None
self
.
needs_prompt_logprobs
[
req_idx
]
=
needs_prompt_logprobs
...
...
@@ -183,17 +121,6 @@ class RequestState:
self
.
prefill_token_ids
.
apply_write
()
self
.
num_computed_tokens
.
apply_write
()
# TODO(woosuk): Optimize this.
for
req_idx
in
self
.
_penalties_reqs
:
bincount
(
self
.
prefill_token_ids
.
gpu
[
req_idx
],
int
(
self
.
prefill_len
.
np
[
req_idx
]),
int
(
self
.
prompt_len
[
req_idx
]),
self
.
prompt_bin_mask
[
req_idx
],
self
.
output_bin_counts
[
req_idx
],
)
self
.
_penalties_reqs
.
clear
()
def
remove_request
(
self
,
req_id
:
str
)
->
None
:
self
.
extra_data
.
pop
(
req_id
,
None
)
req_idx
=
self
.
req_id_to_index
.
pop
(
req_id
,
None
)
...
...
@@ -203,53 +130,6 @@ class RequestState:
self
.
index_to_req_id
.
pop
(
req_idx
,
None
)
self
.
free_indices
.
append
(
req_idx
)
def
make_sampling_metadata
(
self
,
idx_mapping
:
torch
.
Tensor
,
idx_mapping_np
:
np
.
ndarray
,
pos
:
torch
.
Tensor
,
)
->
SamplingMetadata
:
temperature
=
self
.
temperature
.
copy_to_uva
()
top_p
=
self
.
top_p
.
np
[
idx_mapping_np
]
no_top_p
=
np
.
all
(
top_p
==
1.0
)
top_p
=
self
.
top_p
.
copy_to_uva
()[
idx_mapping
]
if
not
no_top_p
else
None
top_k
=
self
.
top_k
.
np
[
idx_mapping_np
]
no_top_k
=
np
.
all
(
top_k
==
self
.
vocab_size
)
top_k
=
self
.
top_k
.
copy_to_uva
()[
idx_mapping
]
if
not
no_top_k
else
None
min_p
=
self
.
min_p
.
np
[
idx_mapping_np
]
no_min_p
=
np
.
all
(
min_p
==
0.0
)
min_p
=
self
.
min_p
.
copy_to_uva
()
if
not
no_min_p
else
None
rep_penalty
=
self
.
repetition_penalty
.
copy_to_uva
()
freq_penalty
=
self
.
frequency_penalty
.
copy_to_uva
()
pres_penalty
=
self
.
presence_penalty
.
copy_to_uva
()
seeds
=
self
.
seeds
.
copy_to_uva
()
num_logprobs
=
self
.
num_logprobs
[
idx_mapping_np
]
max_num_logprobs
:
int
|
None
=
int
(
np
.
max
(
num_logprobs
))
if
max_num_logprobs
==
-
1
:
max_num_logprobs
=
None
return
SamplingMetadata
(
idx_mapping
=
idx_mapping
,
temperature
=
temperature
,
top_p
=
top_p
,
top_k
=
top_k
,
min_p
=
min_p
,
repetition_penalty
=
rep_penalty
,
frequency_penalty
=
freq_penalty
,
presence_penalty
=
pres_penalty
,
prompt_bin_mask
=
self
.
prompt_bin_mask
,
output_bin_counts
=
self
.
output_bin_counts
,
seeds
=
seeds
,
pos
=
pos
,
max_num_logprobs
=
max_num_logprobs
,
)
def
make_lora_inputs
(
self
,
req_ids
:
list
[
str
],
...
...
@@ -272,11 +152,3 @@ class RequestState:
class
ExtraData
:
lora_request
:
LoRARequest
|
None
in_progress_prompt_logprobs
:
list
[
LogprobsTensors
]
=
field
(
default_factory
=
list
)
def use_penalty(sampling_params: SamplingParams) -> bool:
    """Tell whether any penalty is set to a non-neutral value.

    The repetition penalty is neutral at 1.0; the frequency and presence
    penalties are neutral at 0.0.
    """
    if sampling_params.repetition_penalty != 1.0:
        return True
    if sampling_params.frequency_penalty != 0.0:
        return True
    return sampling_params.presence_penalty != 0.0
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment