Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6afc0ffa
Unverified
Commit
6afc0ffa
authored
Nov 29, 2025
by
Woosuk Kwon
Committed by
GitHub
Nov 29, 2025
Browse files
[Model Runner V2] Add sample/ directory and reorganize files (#29719)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
39e63dec
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
587 additions
and
237 deletions
+587
-237
vllm/v1/worker/gpu/model_runner.py
vllm/v1/worker/gpu/model_runner.py
+11
-4
vllm/v1/worker/gpu/sample/__init__.py
vllm/v1/worker/gpu/sample/__init__.py
+0
-0
vllm/v1/worker/gpu/sample/gumbel.py
vllm/v1/worker/gpu/sample/gumbel.py
+100
-0
vllm/v1/worker/gpu/sample/logprob.py
vllm/v1/worker/gpu/sample/logprob.py
+167
-0
vllm/v1/worker/gpu/sample/metadata.py
vllm/v1/worker/gpu/sample/metadata.py
+179
-0
vllm/v1/worker/gpu/sample/penalties.py
vllm/v1/worker/gpu/sample/penalties.py
+47
-1
vllm/v1/worker/gpu/sample/sampler.py
vllm/v1/worker/gpu/sample/sampler.py
+79
-0
vllm/v1/worker/gpu/spec_decode/eagle.py
vllm/v1/worker/gpu/spec_decode/eagle.py
+2
-2
vllm/v1/worker/gpu/states.py
vllm/v1/worker/gpu/states.py
+2
-230
No files found.
vllm/v1/worker/gpu/model_runner.py
View file @
6afc0ffa
...
...
@@ -47,13 +47,18 @@ from vllm.v1.worker.gpu.input_batch import (
prepare_pos_seq_lens
,
prepare_prefill_inputs
,
)
from
vllm.v1.worker.gpu.sampler
import
Sampler
,
compute_prompt_logprobs
from
vllm.v1.worker.gpu.sample.logprob
import
compute_prompt_logprobs
from
vllm.v1.worker.gpu.sample.metadata
import
(
SamplingMetadata
,
expand_sampling_metadata
,
)
from
vllm.v1.worker.gpu.sample.sampler
import
Sampler
from
vllm.v1.worker.gpu.spec_decode
import
init_speculator
from
vllm.v1.worker.gpu.spec_decode.rejection_sample
import
(
get_num_rejected
,
rejection_sample
,
)
from
vllm.v1.worker.gpu.states
import
RequestState
,
SamplingMetadata
from
vllm.v1.worker.gpu.states
import
RequestState
from
vllm.v1.worker.gpu.structured_outputs
import
apply_grammar_bitmask
from
vllm.v1.worker.kv_connector_model_runner_mixin
import
KVConnectorModelRunnerMixin
from
vllm.v1.worker.lora_model_runner_mixin
import
LoRAModelRunnerMixin
...
...
@@ -890,8 +895,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
input_batch
.
idx_mapping
,
input_batch
.
idx_mapping_np
,
pos
)
if
input_batch
.
num_draft_tokens
>
0
:
sampling_metadata
=
self
.
req_states
.
expand_sampling_metadata
(
sampling_metadata
,
input_batch
.
cu_num_logits
sampling_metadata
=
expand_sampling_metadata
(
sampling_metadata
,
input_batch
.
cu_num_logits
,
max_expand_len
=
self
.
num_speculative_steps
+
1
,
)
if
self
.
lora_config
:
...
...
vllm/v1/worker/gpu/sample/__init__.py
0 → 100644
View file @
6afc0ffa
vllm/v1/worker/gpu/sample/gumbel.py
0 → 100644
View file @
6afc0ffa
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
from
vllm.triton_utils
import
tl
,
triton
@triton.jit
def _gumbel_sample_kernel(
    local_argmax_ptr,
    local_argmax_stride,
    local_max_ptr,
    local_max_stride,
    logits_ptr,
    logits_stride,
    seeds_ptr,
    pos_ptr,
    temp_ptr,
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
    APPLY_TEMPERATURE: tl.constexpr,
):
    # Block-local stage of a two-stage arg-max over the vocabulary.
    #
    # Launch grid: axis 0 is the request (row of logits), axis 1 is the
    # BLOCK_SIZE-wide slice of the vocabulary handled by this program.
    # Each program writes two scalars at [req_idx, block_idx]:
    #   * local_argmax: vocabulary index of the largest (possibly perturbed)
    #     logit inside this slice,
    #   * local_max: that largest value, as float32.
    # The caller is expected to reduce across blocks afterwards.
    #
    # When the request's temperature is non-zero, Gumbel noise derived from
    # (seed, pos) is added before the arg-max; when it is exactly zero the
    # logits are left untouched, i.e. a plain arg-max.
    req_idx = tl.program_id(0)
    block_idx = tl.program_id(1)
    # Absolute vocabulary indices covered by this program.
    block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = block < vocab_size
    # Out-of-range lanes read as -inf so they can never win the arg-max.
    logits = tl.load(
        logits_ptr + req_idx * logits_stride + block,
        mask=mask,
        other=float("-inf"),
    )
    logits = logits.to(tl.float32)

    temp = tl.load(temp_ptr + req_idx).to(tl.float32)
    if temp != 0.0:
        # Calculate the seed for gumbel noise.
        seed = tl.load(seeds_ptr + req_idx)
        pos = tl.load(pos_ptr + req_idx)
        gumbel_seed = tl.randint(seed, pos)

        # Generate gumbel noise.
        # The random offsets are the absolute vocab indices, so the noise for
        # a given token does not depend on which block processes it.
        r = tl.rand(gumbel_seed, block).to(tl.float64)
        # The 1e-20 offsets keep both log() arguments away from zero.
        gumbel_noise = -tl.log(-tl.log(r + 1e-20) + 1e-20)
        gumbel_noise = gumbel_noise.to(tl.float32)

        # Apply temperature.
        if APPLY_TEMPERATURE:
            # NOTE(woosuk): Use div_rn to match the behavior of torch.
            logits = tl.div_rn(logits, temp)

        # Apply gumbel noise.
        # Re-mask so padded lanes stay at -inf after the addition.
        logits = tl.where(mask, logits + gumbel_noise, float("-inf"))

    idx = tl.argmax(logits, axis=0)
    # Convert the in-block offset back to a vocabulary index.
    token_id = block_idx * BLOCK_SIZE + idx
    value = tl.max(logits, axis=0)
    tl.store(local_argmax_ptr + req_idx * local_argmax_stride + block_idx, token_id)
    tl.store(local_max_ptr + req_idx * local_max_stride + block_idx, value)
def gumbel_sample(
    logits: torch.Tensor,  # [num_reqs, vocab_size]
    temperature: torch.Tensor,  # [num_reqs]
    seed: torch.Tensor,  # [num_reqs]
    pos: torch.Tensor,  # [num_reqs]
    apply_temperature: bool,
) -> torch.Tensor:
    """Pick one token per row of ``logits`` via a two-stage arg-max.

    Stage one (``_gumbel_sample_kernel``) splits the vocabulary into blocks
    of 1024 entries and records, for each (request, block) pair, the best
    token index and its score. Stage two, done here with torch ops, selects
    the block with the highest score per request and returns its token.

    Args:
        logits: Scores to sample from, one row per request.
        temperature: Per-request temperature; the kernel skips the Gumbel
            perturbation for rows whose temperature is exactly zero.
        seed: Per-request seed fed to the kernel's random generator.
        pos: Per-request value combined with ``seed`` inside the kernel.
        apply_temperature: Whether the kernel should divide the logits by
            the temperature before adding noise.

    Returns:
        A 1-D int64 tensor holding one vocabulary index per request.
    """
    block_size = 1024
    num_reqs, vocab_size = logits.shape
    num_blocks = triton.cdiv(vocab_size, block_size)

    # Per-block partial results produced by the kernel.
    device = logits.device
    block_token_ids = torch.empty(
        num_reqs, num_blocks, dtype=torch.int64, device=device
    )
    block_scores = torch.empty(
        num_reqs, num_blocks, dtype=torch.float32, device=device
    )

    grid = (num_reqs, num_blocks)
    _gumbel_sample_kernel[grid](
        block_token_ids,
        block_token_ids.stride(0),
        block_scores,
        block_scores.stride(0),
        logits,
        logits.stride(0),
        seed,
        pos,
        temperature,
        vocab_size,
        BLOCK_SIZE=block_size,
        APPLY_TEMPERATURE=apply_temperature,
    )

    # NOTE(woosuk): Use int64 for later indexing.
    winning_block = block_scores.argmax(dim=-1, keepdim=True)
    return block_token_ids.gather(dim=-1, index=winning_block).view(-1)
vllm/v1/worker/gpu/sample
r
.py
→
vllm/v1/worker/gpu/sample
/logprob
.py
View file @
6afc0ffa
...
...
@@ -4,174 +4,8 @@ from collections.abc import Callable
import
torch
from
vllm.config.model
import
LogprobsMode
from
vllm.triton_utils
import
tl
,
triton
from
vllm.v1.outputs
import
LogprobsTensors
,
SamplerOutput
from
vllm.v1.sample.ops.topk_topp_sampler
import
apply_top_k_top_p
from
vllm.v1.worker.gpu.penalties
import
apply_penalties
from
vllm.v1.worker.gpu.states
import
SamplingMetadata
class
Sampler
:
def
__init__
(
self
,
logprobs_mode
:
LogprobsMode
=
"raw_logprobs"
,
):
if
logprobs_mode
not
in
[
"processed_logprobs"
,
"raw_logprobs"
]:
raise
NotImplementedError
(
f
"Unsupported logprobs_mode:
{
logprobs_mode
}
"
)
self
.
logprobs_mode
=
logprobs_mode
def
__call__
(
self
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
SamplerOutput
:
if
sampling_metadata
.
max_num_logprobs
is
not
None
:
if
self
.
logprobs_mode
==
"processed_logprobs"
:
sampled
,
logits
=
self
.
sample
(
logits
,
sampling_metadata
,
return_logits
=
True
)
else
:
assert
self
.
logprobs_mode
==
"raw_logprobs"
sampled
,
_
=
self
.
sample
(
logits
,
sampling_metadata
,
return_logits
=
False
)
logprobs_tensors
=
compute_topk_logprobs
(
logits
,
sampling_metadata
.
max_num_logprobs
,
sampled
,
)
else
:
sampled
,
_
=
self
.
sample
(
logits
,
sampling_metadata
,
return_logits
=
False
)
logprobs_tensors
=
None
# These are GPU tensors.
sampler_output
=
SamplerOutput
(
# The sampled tokens are expanded to 2D tensor with shape
# [num_requests, 1], where each row represents one generated
# token per request.
sampled_token_ids
=
sampled
.
view
(
-
1
,
1
),
logprobs_tensors
=
logprobs_tensors
,
)
return
sampler_output
def
sample
(
self
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
return_logits
:
bool
=
False
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
|
None
]:
is_greedy
=
sampling_metadata
.
temperature
==
0
temp
=
torch
.
where
(
is_greedy
,
1.0
,
sampling_metadata
.
temperature
)
logits
=
logits
/
temp
.
view
(
-
1
,
1
)
logits
=
apply_top_k_top_p
(
logits
,
sampling_metadata
.
top_k
,
sampling_metadata
.
top_p
)
# Apply penalties in place.
apply_penalties
(
logits
,
sampling_metadata
)
sampled
=
gumbel_sample
(
logits
,
sampling_metadata
.
temperature
,
sampling_metadata
.
seeds
,
sampling_metadata
.
pos
,
apply_temperature
=
False
,
)
return
sampled
,
logits
if
return_logits
else
None
@
triton
.
jit
def
_gumbel_sample_kernel
(
local_argmax_ptr
,
local_argmax_stride
,
local_max_ptr
,
local_max_stride
,
logits_ptr
,
logits_stride
,
seeds_ptr
,
pos_ptr
,
temp_ptr
,
vocab_size
,
BLOCK_SIZE
:
tl
.
constexpr
,
APPLY_TEMPERATURE
:
tl
.
constexpr
,
):
req_idx
=
tl
.
program_id
(
0
)
block_idx
=
tl
.
program_id
(
1
)
block
=
block_idx
*
BLOCK_SIZE
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
mask
=
block
<
vocab_size
logits
=
tl
.
load
(
logits_ptr
+
req_idx
*
logits_stride
+
block
,
mask
=
mask
,
other
=
float
(
"-inf"
),
)
logits
=
logits
.
to
(
tl
.
float32
)
temp
=
tl
.
load
(
temp_ptr
+
req_idx
).
to
(
tl
.
float32
)
if
temp
!=
0.0
:
# Calculate the seed for gumbel noise.
seed
=
tl
.
load
(
seeds_ptr
+
req_idx
)
pos
=
tl
.
load
(
pos_ptr
+
req_idx
)
gumbel_seed
=
tl
.
randint
(
seed
,
pos
)
# Generate gumbel noise.
r
=
tl
.
rand
(
gumbel_seed
,
block
).
to
(
tl
.
float64
)
gumbel_noise
=
-
tl
.
log
(
-
tl
.
log
(
r
+
1e-20
)
+
1e-20
)
gumbel_noise
=
gumbel_noise
.
to
(
tl
.
float32
)
# Apply temperature.
if
APPLY_TEMPERATURE
:
# NOTE(woosuk): Use div_rn to match the behavior of torch.
logits
=
tl
.
div_rn
(
logits
,
temp
)
# Apply gumbel noise.
logits
=
tl
.
where
(
mask
,
logits
+
gumbel_noise
,
float
(
"-inf"
))
idx
=
tl
.
argmax
(
logits
,
axis
=
0
)
token_id
=
block_idx
*
BLOCK_SIZE
+
idx
value
=
tl
.
max
(
logits
,
axis
=
0
)
tl
.
store
(
local_argmax_ptr
+
req_idx
*
local_argmax_stride
+
block_idx
,
token_id
)
tl
.
store
(
local_max_ptr
+
req_idx
*
local_max_stride
+
block_idx
,
value
)
def
gumbel_sample
(
logits
:
torch
.
Tensor
,
# [num_reqs, vocab_size]
temperature
:
torch
.
Tensor
,
# [num_reqs]
seed
:
torch
.
Tensor
,
# [num_reqs]
pos
:
torch
.
Tensor
,
# [num_reqs]
apply_temperature
:
bool
,
)
->
torch
.
Tensor
:
num_reqs
,
vocab_size
=
logits
.
shape
BLOCK_SIZE
=
1024
num_blocks
=
triton
.
cdiv
(
vocab_size
,
BLOCK_SIZE
)
local_argmax
=
torch
.
empty
(
num_reqs
,
num_blocks
,
dtype
=
torch
.
int64
,
device
=
logits
.
device
,
)
local_max
=
torch
.
empty
(
num_reqs
,
num_blocks
,
dtype
=
torch
.
float32
,
device
=
logits
.
device
,
)
_gumbel_sample_kernel
[(
num_reqs
,
num_blocks
)](
local_argmax
,
local_argmax
.
stride
(
0
),
local_max
,
local_max
.
stride
(
0
),
logits
,
logits
.
stride
(
0
),
seed
,
pos
,
temperature
,
vocab_size
,
BLOCK_SIZE
=
BLOCK_SIZE
,
APPLY_TEMPERATURE
=
apply_temperature
,
)
# NOTE(woosuk): Use int64 for later indexing.
max_block_idx
=
local_max
.
argmax
(
dim
=-
1
,
keepdim
=
True
)
sampled
=
local_argmax
.
gather
(
dim
=-
1
,
index
=
max_block_idx
).
view
(
-
1
)
return
sampled
from
vllm.v1.outputs
import
LogprobsTensors
@
triton
.
jit
...
...
vllm/v1/worker/gpu/sample/metadata.py
0 → 100644
View file @
6afc0ffa
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
import
torch
from
vllm.triton_utils
import
tl
,
triton
@
dataclass
class
SamplingMetadata
:
temperature
:
torch
.
Tensor
top_p
:
torch
.
Tensor
|
None
top_k
:
torch
.
Tensor
|
None
repetition_penalty
:
torch
.
Tensor
frequency_penalty
:
torch
.
Tensor
presence_penalty
:
torch
.
Tensor
seeds
:
torch
.
Tensor
pos
:
torch
.
Tensor
# None means no logprobs, 0 means sampled token logprobs only
max_num_logprobs
:
int
|
None
# For penalties
idx_mapping
:
torch
.
Tensor
prompt_bin_counts
:
torch
.
Tensor
output_bin_counts
:
torch
.
Tensor
@
classmethod
def
make_dummy
(
cls
,
num_reqs
:
int
,
device
:
torch
.
device
,
)
->
"SamplingMetadata"
:
assert
num_reqs
>
0
temperature
=
torch
.
zeros
(
num_reqs
,
dtype
=
torch
.
float32
,
device
=
device
)
temperature
[
0
]
=
0.5
# TODO(woosuk): Use top-p and top-k for dummy sampler.
# Currently, they are disabled because of memory usage.
# top_p = torch.full((num_reqs,), 0.95, dtype=torch.float32, device=device)
# top_k = torch.full((num_reqs,), 20, dtype=torch.int32, device=device)
top_p
=
None
top_k
=
None
# NOTE(woosuk): We must set penalties to their default values to make sure
# the penalties kernel does not touch the placeholder bin_counts tensors.
repetition_penalty
=
torch
.
ones
(
num_reqs
,
dtype
=
torch
.
float32
,
device
=
device
)
frequency_penalty
=
torch
.
zeros
(
num_reqs
,
dtype
=
torch
.
float32
,
device
=
device
)
presence_penalty
=
torch
.
zeros
(
num_reqs
,
dtype
=
torch
.
float32
,
device
=
device
)
seeds
=
torch
.
zeros
(
num_reqs
,
dtype
=
torch
.
int64
,
device
=
device
)
pos
=
torch
.
zeros
(
num_reqs
,
dtype
=
torch
.
int64
,
device
=
device
)
max_num_logprobs
=
20
idx_mapping
=
torch
.
arange
(
num_reqs
,
dtype
=
torch
.
int32
,
device
=
device
)
# NOTE(woosuk): These are placeholder tensors to avoid None checks in the
# penalties kernel. We use 2 instead of 1 as vocab_size to avoid Triton
# specialization and re-compilation at runtime.
prompt_bin_counts
=
torch
.
zeros
(
num_reqs
,
2
,
dtype
=
torch
.
int32
,
device
=
device
)
output_bin_counts
=
torch
.
zeros
(
num_reqs
,
2
,
dtype
=
torch
.
int32
,
device
=
device
)
return
cls
(
temperature
=
temperature
,
top_p
=
top_p
,
top_k
=
top_k
,
repetition_penalty
=
repetition_penalty
,
frequency_penalty
=
frequency_penalty
,
presence_penalty
=
presence_penalty
,
seeds
=
seeds
,
pos
=
pos
,
max_num_logprobs
=
max_num_logprobs
,
idx_mapping
=
idx_mapping
,
prompt_bin_counts
=
prompt_bin_counts
,
output_bin_counts
=
output_bin_counts
,
)
# NOTE(woosuk): Re-compilation can happen at runtime since top_p and top_k can be None.
@triton.jit
def _expand_sampling_metadata_kernel(
    temp_ptr,
    expanded_temp_ptr,
    top_p_ptr,
    expanded_top_p_ptr,
    top_k_ptr,
    expanded_top_k_ptr,
    rep_penalty_ptr,
    expanded_rep_penalty_ptr,
    freq_penalty_ptr,
    expanded_freq_penalty_ptr,
    pres_penalty_ptr,
    expanded_pres_penalty_ptr,
    seeds_ptr,
    expanded_seeds_ptr,
    cu_num_logits_ptr,
    BLOCK_SIZE: tl.constexpr,
):
    # Broadcast each request's scalar sampling parameters to every logit
    # position that belongs to that request.
    #
    # One program per request. The request's output slots are
    # [cu_num_logits[req_idx], cu_num_logits[req_idx + 1]) in each expanded
    # tensor. Only the first BLOCK_SIZE slots of that range are written, so
    # BLOCK_SIZE must be at least the largest per-request logit count.
    req_idx = tl.program_id(0)
    start_idx = tl.load(cu_num_logits_ptr + req_idx)
    end_idx = tl.load(cu_num_logits_ptr + req_idx + 1)
    num_tokens = end_idx - start_idx

    block = tl.arange(0, BLOCK_SIZE)
    # Lanes past this request's logit count must not write.
    mask = block < num_tokens

    temp = tl.load(temp_ptr + req_idx)
    tl.store(expanded_temp_ptr + start_idx + block, temp, mask=mask)

    # top_p / top_k are optional; when the caller passes None the whole
    # branch is skipped.
    if top_p_ptr is not None:
        top_p = tl.load(top_p_ptr + req_idx)
        tl.store(expanded_top_p_ptr + start_idx + block, top_p, mask=mask)

    if top_k_ptr is not None:
        top_k = tl.load(top_k_ptr + req_idx)
        tl.store(expanded_top_k_ptr + start_idx + block, top_k, mask=mask)

    rep_penalty = tl.load(rep_penalty_ptr + req_idx)
    tl.store(expanded_rep_penalty_ptr + start_idx + block, rep_penalty, mask=mask)

    freq_penalty = tl.load(freq_penalty_ptr + req_idx)
    tl.store(expanded_freq_penalty_ptr + start_idx + block, freq_penalty, mask=mask)

    pres_penalty = tl.load(pres_penalty_ptr + req_idx)
    tl.store(expanded_pres_penalty_ptr + start_idx + block, pres_penalty, mask=mask)

    seed = tl.load(seeds_ptr + req_idx)
    tl.store(expanded_seeds_ptr + start_idx + block, seed, mask=mask)
def expand_sampling_metadata(
    sampling_metadata: SamplingMetadata,
    cu_num_logits: torch.Tensor,
    max_expand_len: int,
) -> SamplingMetadata:
    """Replicate per-request sampling parameters once per logit.

    Args:
        sampling_metadata: Metadata whose parameter tensors hold one entry
            per request. Its ``pos`` tensor already has one entry per logit
            and determines the size of the expanded tensors.
        cu_num_logits: Cumulative logit counts with ``num_reqs + 1`` entries;
            request ``i`` owns slots ``[cu[i], cu[i + 1])``.
        max_expand_len: Upper bound on the number of logits of any single
            request. It is rounded up to a power of two and used as the
            kernel's block size.

    Returns:
        A new ``SamplingMetadata`` with expanded temperature, top-p, top-k,
        penalties and seeds. ``pos``, ``max_num_logprobs``, ``idx_mapping``
        and both bin-count tensors are passed through unchanged.
    """
    src = sampling_metadata
    total_num_logits = src.pos.shape[0]

    def _alloc_like(param: torch.Tensor | None) -> torch.Tensor | None:
        # Optional parameters (top_p / top_k) stay None after expansion.
        if param is None:
            return None
        return param.new_empty(total_num_logits)

    expanded_temp = _alloc_like(src.temperature)
    expanded_top_p = _alloc_like(src.top_p)
    expanded_top_k = _alloc_like(src.top_k)
    expanded_repetition_penalty = _alloc_like(src.repetition_penalty)
    expanded_frequency_penalty = _alloc_like(src.frequency_penalty)
    expanded_presence_penalty = _alloc_like(src.presence_penalty)
    expanded_seeds = _alloc_like(src.seeds)

    # cu_num_logits carries a leading entry, hence the "- 1".
    num_reqs = cu_num_logits.shape[0] - 1
    grid = (num_reqs,)
    _expand_sampling_metadata_kernel[grid](
        src.temperature,
        expanded_temp,
        src.top_p,
        expanded_top_p,
        src.top_k,
        expanded_top_k,
        src.repetition_penalty,
        expanded_repetition_penalty,
        src.frequency_penalty,
        expanded_frequency_penalty,
        src.presence_penalty,
        expanded_presence_penalty,
        src.seeds,
        expanded_seeds,
        cu_num_logits,
        BLOCK_SIZE=triton.next_power_of_2(max_expand_len),
    )

    return SamplingMetadata(
        temperature=expanded_temp,
        top_p=expanded_top_p,
        top_k=expanded_top_k,
        seeds=expanded_seeds,
        repetition_penalty=expanded_repetition_penalty,
        frequency_penalty=expanded_frequency_penalty,
        presence_penalty=expanded_presence_penalty,
        pos=src.pos,
        max_num_logprobs=src.max_num_logprobs,
        # TODO(woosuk): Support penalties with spec decoding.
        idx_mapping=src.idx_mapping,
        prompt_bin_counts=src.prompt_bin_counts,
        output_bin_counts=src.output_bin_counts,
    )
vllm/v1/worker/gpu/penalties.py
→
vllm/v1/worker/gpu/
sample/
penalties.py
View file @
6afc0ffa
...
...
@@ -3,7 +3,7 @@
import
torch
from
vllm.triton_utils
import
tl
,
triton
from
vllm.v1.worker.gpu.s
tates
import
SamplingMetadata
from
vllm.v1.worker.gpu.s
ample.metadata
import
SamplingMetadata
@
triton
.
jit
...
...
@@ -83,3 +83,49 @@ def apply_penalties(logits: torch.Tensor, sampling_metadata: SamplingMetadata) -
vocab_size
,
BLOCK_SIZE
=
BLOCK_SIZE
,
)
@triton.jit(do_not_specialize=["prefill_len", "prompt_len"])
def _bincount_kernel(
    prefill_token_ids_ptr,
    prefill_len,
    prompt_len,
    prompt_bin_counts_ptr,
    output_bin_counts_ptr,
    BLOCK_SIZE: tl.constexpr,
):
    # Histogram the token ids of one request into two count arrays.
    #
    # Tokens at positions [0, prompt_len) are counted into
    # prompt_bin_counts; tokens at positions [prompt_len, prefill_len) are
    # counted into output_bin_counts. Both count arrays are indexed directly
    # by token id. The kernel only adds to the counts, so the caller must
    # zero them beforehand.
    #
    # One program per BLOCK_SIZE-wide slice of token positions.
    block_idx = tl.program_id(0)
    # Nothing to do for slices that start past the last token.
    if block_idx * BLOCK_SIZE >= prefill_len:
        return

    block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # The slice starts inside the prompt region: count its prompt tokens.
    if block_idx * BLOCK_SIZE < prompt_len:
        mask = block < prompt_len
        prefill_tokens = tl.load(prefill_token_ids_ptr + block, mask=mask)
        tl.atomic_add(prompt_bin_counts_ptr + prefill_tokens, 1, mask=mask)
    # The slice ends at or past prompt_len: it may hold non-prompt tokens.
    # A slice straddling prompt_len takes both branches.
    if (block_idx + 1) * BLOCK_SIZE >= prompt_len:
        mask = block < prefill_len
        mask &= block >= prompt_len
        prefill_tokens = tl.load(prefill_token_ids_ptr + block, mask=mask)
        tl.atomic_add(output_bin_counts_ptr + prefill_tokens, 1, mask=mask)
def bincount(
    prefill_token_ids: torch.Tensor,
    prefill_len: int,
    prompt_len: int,
    prompt_bin_counts: torch.Tensor,
    output_bin_counts: torch.Tensor,
) -> None:
    """Recompute token-frequency counts for a request, in place.

    Both count tensors are reset to zero and then filled by
    ``_bincount_kernel``: positions below ``prompt_len`` contribute to
    ``prompt_bin_counts``, positions from ``prompt_len`` up to (but not
    including) ``prefill_len`` contribute to ``output_bin_counts``.

    Args:
        prefill_token_ids: Token ids to histogram.
        prefill_len: Number of leading entries of ``prefill_token_ids`` to
            consider.
        prompt_len: Position separating prompt tokens from the rest.
        prompt_bin_counts: Destination counts for prompt tokens.
        output_bin_counts: Destination counts for the remaining tokens.
    """
    # The kernel accumulates with atomic adds, so stale counts must go first.
    for counts in (prompt_bin_counts, output_bin_counts):
        counts.zero_()

    block_size = 1024
    grid = (triton.cdiv(prefill_len, block_size),)
    _bincount_kernel[grid](
        prefill_token_ids,
        prefill_len,
        prompt_len,
        prompt_bin_counts,
        output_bin_counts,
        BLOCK_SIZE=block_size,
    )
vllm/v1/worker/gpu/sample/sampler.py
0 → 100644
View file @
6afc0ffa
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
from
vllm.config.model
import
LogprobsMode
from
vllm.v1.outputs
import
SamplerOutput
from
vllm.v1.sample.ops.topk_topp_sampler
import
apply_top_k_top_p
from
vllm.v1.worker.gpu.sample.gumbel
import
gumbel_sample
from
vllm.v1.worker.gpu.sample.logprob
import
compute_topk_logprobs
from
vllm.v1.worker.gpu.sample.metadata
import
SamplingMetadata
from
vllm.v1.worker.gpu.sample.penalties
import
apply_penalties
class Sampler:
    """Samples one token per logits row and optionally computes logprobs.

    ``logprobs_mode`` selects which logits the logprobs are computed from:
    ``"raw_logprobs"`` uses the logits as passed in, ``"processed_logprobs"``
    uses the logits after temperature, top-k/top-p and penalties.
    """

    def __init__(
        self,
        logprobs_mode: LogprobsMode = "raw_logprobs",
    ):
        """Store the logprobs mode.

        Raises:
            NotImplementedError: If ``logprobs_mode`` is neither
                ``"processed_logprobs"`` nor ``"raw_logprobs"``.
        """
        supported_modes = ("processed_logprobs", "raw_logprobs")
        if logprobs_mode not in supported_modes:
            raise NotImplementedError(f"Unsupported logprobs_mode: {logprobs_mode}")
        self.logprobs_mode = logprobs_mode

    def __call__(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> SamplerOutput:
        """Sample tokens and, when requested, their top-k logprobs.

        Logprobs are computed only if ``sampling_metadata.max_num_logprobs``
        is not ``None``.
        """
        num_logprobs = sampling_metadata.max_num_logprobs
        if num_logprobs is None:
            sampled, _ = self.sample(logits, sampling_metadata, return_logits=False)
            logprobs_tensors = None
        else:
            if self.logprobs_mode == "processed_logprobs":
                # Logprobs are taken from the post-processing logits.
                sampled, logits = self.sample(
                    logits, sampling_metadata, return_logits=True
                )
            else:
                assert self.logprobs_mode == "raw_logprobs"
                # Keep the caller's logits for the logprob computation.
                sampled, _ = self.sample(
                    logits, sampling_metadata, return_logits=False
                )
            logprobs_tensors = compute_topk_logprobs(
                logits,
                num_logprobs,
                sampled,
            )

        # These are GPU tensors.
        # The sampled tokens are expanded to 2D tensor with shape
        # [num_requests, 1], where each row represents one generated
        # token per request.
        return SamplerOutput(
            sampled_token_ids=sampled.view(-1, 1),
            logprobs_tensors=logprobs_tensors,
        )

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        return_logits: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Process the logits and draw one token per row.

        Processing order: temperature scaling, top-k/top-p filtering, then
        penalties. The division creates a new tensor, so the caller's
        ``logits`` are not modified by the scaling step.

        Returns:
            ``(sampled, processed_logits)`` when ``return_logits`` is true,
            otherwise ``(sampled, None)``.
        """
        temperature = sampling_metadata.temperature
        # Rows with temperature 0 are divided by 1 instead of by 0.
        divisor = torch.where(temperature == 0, 1.0, temperature)
        logits = logits / divisor.view(-1, 1)
        logits = apply_top_k_top_p(
            logits, sampling_metadata.top_k, sampling_metadata.top_p
        )
        # Apply penalties in place.
        apply_penalties(logits, sampling_metadata)

        # Temperature was already applied above, so the kernel must not
        # apply it again.
        sampled = gumbel_sample(
            logits,
            temperature,
            sampling_metadata.seeds,
            sampling_metadata.pos,
            apply_temperature=False,
        )
        if return_logits:
            return sampled, logits
        return sampled, None
vllm/v1/worker/gpu/spec_decode/eagle.py
View file @
6afc0ffa
...
...
@@ -18,9 +18,9 @@ from vllm.v1.kv_cache_interface import KVCacheConfig
from
vllm.v1.worker.gpu.attn_utils
import
build_attn_metadata
from
vllm.v1.worker.gpu.block_table
import
BlockTables
from
vllm.v1.worker.gpu.input_batch
import
InputBatch
,
InputBuffers
from
vllm.v1.worker.gpu.sampler
import
gumbel_sample
from
vllm.v1.worker.gpu.sample.gumbel
import
gumbel_sample
from
vllm.v1.worker.gpu.sample.metadata
import
SamplingMetadata
from
vllm.v1.worker.gpu.spec_decode.eagle_cudagraph
import
EagleCudaGraphManager
from
vllm.v1.worker.gpu.states
import
SamplingMetadata
logger
=
init_logger
(
__name__
)
...
...
vllm/v1/worker/gpu/states.py
View file @
6afc0ffa
...
...
@@ -7,86 +7,18 @@ import torch
from
vllm.lora.request
import
LoRARequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils.platform_utils
import
is_uva_available
from
vllm.utils.torch_utils
import
get_cuda_view_from_cpu_tensor
from
vllm.v1.outputs
import
LogprobsTensors
from
vllm.v1.utils
import
CpuGpuBuffer
from
vllm.v1.worker.gpu.sample.metadata
import
SamplingMetadata
from
vllm.v1.worker.gpu.sample.penalties
import
bincount
_NP_INT64_MIN
=
np
.
iinfo
(
np
.
int64
).
min
_NP_INT64_MAX
=
np
.
iinfo
(
np
.
int64
).
max
NO_LORA_ID
=
0
@
dataclass
class
SamplingMetadata
:
temperature
:
torch
.
Tensor
top_p
:
torch
.
Tensor
|
None
top_k
:
torch
.
Tensor
|
None
repetition_penalty
:
torch
.
Tensor
frequency_penalty
:
torch
.
Tensor
presence_penalty
:
torch
.
Tensor
seeds
:
torch
.
Tensor
pos
:
torch
.
Tensor
# None means no logprobs, 0 means sampled token logprobs only
max_num_logprobs
:
int
|
None
# For penalties
idx_mapping
:
torch
.
Tensor
prompt_bin_counts
:
torch
.
Tensor
output_bin_counts
:
torch
.
Tensor
@
classmethod
def
make_dummy
(
cls
,
num_reqs
:
int
,
device
:
torch
.
device
,
)
->
"SamplingMetadata"
:
assert
num_reqs
>
0
temperature
=
torch
.
zeros
(
num_reqs
,
dtype
=
torch
.
float32
,
device
=
device
)
temperature
[
0
]
=
0.5
# TODO(woosuk): Use top-p and top-k for dummy sampler.
# Currently, they are disabled because of memory usage.
# top_p = torch.full((num_reqs,), 0.95, dtype=torch.float32, device=device)
# top_k = torch.full((num_reqs,), 20, dtype=torch.int32, device=device)
top_p
=
None
top_k
=
None
# NOTE(woosuk): We must set penalties to their default values to make sure
# the penalties kernel does not touch the placeholder bin_counts tensors.
repetition_penalty
=
torch
.
ones
(
num_reqs
,
dtype
=
torch
.
float32
,
device
=
device
)
frequency_penalty
=
torch
.
zeros
(
num_reqs
,
dtype
=
torch
.
float32
,
device
=
device
)
presence_penalty
=
torch
.
zeros
(
num_reqs
,
dtype
=
torch
.
float32
,
device
=
device
)
seeds
=
torch
.
zeros
(
num_reqs
,
dtype
=
torch
.
int64
,
device
=
device
)
pos
=
torch
.
zeros
(
num_reqs
,
dtype
=
torch
.
int64
,
device
=
device
)
max_num_logprobs
=
20
idx_mapping
=
torch
.
arange
(
num_reqs
,
dtype
=
torch
.
int32
,
device
=
device
)
# NOTE(woosuk): These are placeholder tensors to avoid None checks in the
# penalties kernel. We use 2 instead of 1 as vocab_size to avoid Triton
# specialization and re-compilation at runtime.
prompt_bin_counts
=
torch
.
zeros
(
num_reqs
,
2
,
dtype
=
torch
.
int32
,
device
=
device
)
output_bin_counts
=
torch
.
zeros
(
num_reqs
,
2
,
dtype
=
torch
.
int32
,
device
=
device
)
return
cls
(
temperature
=
temperature
,
top_p
=
top_p
,
top_k
=
top_k
,
repetition_penalty
=
repetition_penalty
,
frequency_penalty
=
frequency_penalty
,
presence_penalty
=
presence_penalty
,
seeds
=
seeds
,
pos
=
pos
,
max_num_logprobs
=
max_num_logprobs
,
idx_mapping
=
idx_mapping
,
prompt_bin_counts
=
prompt_bin_counts
,
output_bin_counts
=
output_bin_counts
,
)
class
RequestState
:
def
__init__
(
self
,
...
...
@@ -311,17 +243,6 @@ class RequestState:
output_bin_counts
=
self
.
output_bin_counts
,
)
def
expand_sampling_metadata
(
self
,
sampling_metadata
:
SamplingMetadata
,
cu_num_logits
:
torch
.
Tensor
,
)
->
SamplingMetadata
:
# For draft tokens, we need to expand the sampling param tensors as
# each request samples multiple tokens in each step.
return
expand_sampling_metadata
(
sampling_metadata
,
cu_num_logits
,
self
.
num_speculative_steps
)
def
make_lora_inputs
(
self
,
req_ids
:
list
[
str
],
...
...
@@ -376,158 +297,9 @@ class UvaBuffer:
self
.
gpu
=
get_cuda_view_from_cpu_tensor
(
self
.
cpu
)
# NOTE(woosuk): Re-compilation can happen at runtime since top_p and top_k can be None.
@
triton
.
jit
def
_expand_sampling_metadata_kernel
(
temp_ptr
,
expanded_temp_ptr
,
top_p_ptr
,
expanded_top_p_ptr
,
top_k_ptr
,
expanded_top_k_ptr
,
rep_penalty_ptr
,
expanded_rep_penalty_ptr
,
freq_penalty_ptr
,
expanded_freq_penalty_ptr
,
pres_penalty_ptr
,
expanded_pres_penalty_ptr
,
seeds_ptr
,
expanded_seeds_ptr
,
cu_num_logits_ptr
,
BLOCK_SIZE
:
tl
.
constexpr
,
):
req_idx
=
tl
.
program_id
(
0
)
start_idx
=
tl
.
load
(
cu_num_logits_ptr
+
req_idx
)
end_idx
=
tl
.
load
(
cu_num_logits_ptr
+
req_idx
+
1
)
num_tokens
=
end_idx
-
start_idx
block
=
tl
.
arange
(
0
,
BLOCK_SIZE
)
mask
=
block
<
num_tokens
temp
=
tl
.
load
(
temp_ptr
+
req_idx
)
tl
.
store
(
expanded_temp_ptr
+
start_idx
+
block
,
temp
,
mask
=
mask
)
if
top_p_ptr
is
not
None
:
top_p
=
tl
.
load
(
top_p_ptr
+
req_idx
)
tl
.
store
(
expanded_top_p_ptr
+
start_idx
+
block
,
top_p
,
mask
=
mask
)
if
top_k_ptr
is
not
None
:
top_k
=
tl
.
load
(
top_k_ptr
+
req_idx
)
tl
.
store
(
expanded_top_k_ptr
+
start_idx
+
block
,
top_k
,
mask
=
mask
)
rep_penalty
=
tl
.
load
(
rep_penalty_ptr
+
req_idx
)
tl
.
store
(
expanded_rep_penalty_ptr
+
start_idx
+
block
,
rep_penalty
,
mask
=
mask
)
freq_penalty
=
tl
.
load
(
freq_penalty_ptr
+
req_idx
)
tl
.
store
(
expanded_freq_penalty_ptr
+
start_idx
+
block
,
freq_penalty
,
mask
=
mask
)
pres_penalty
=
tl
.
load
(
pres_penalty_ptr
+
req_idx
)
tl
.
store
(
expanded_pres_penalty_ptr
+
start_idx
+
block
,
pres_penalty
,
mask
=
mask
)
seed
=
tl
.
load
(
seeds_ptr
+
req_idx
)
tl
.
store
(
expanded_seeds_ptr
+
start_idx
+
block
,
seed
,
mask
=
mask
)
def
expand_sampling_metadata
(
sampling_metadata
:
SamplingMetadata
,
cu_num_logits
:
torch
.
Tensor
,
num_speculative_steps
:
int
,
)
->
SamplingMetadata
:
total_num_logits
=
sampling_metadata
.
pos
.
shape
[
0
]
create_empty
=
lambda
x
:
x
.
new_empty
(
total_num_logits
)
if
x
is
not
None
else
None
expanded_temp
=
create_empty
(
sampling_metadata
.
temperature
)
expanded_top_p
=
create_empty
(
sampling_metadata
.
top_p
)
expanded_top_k
=
create_empty
(
sampling_metadata
.
top_k
)
expanded_repetition_penalty
=
create_empty
(
sampling_metadata
.
repetition_penalty
)
expanded_frequency_penalty
=
create_empty
(
sampling_metadata
.
frequency_penalty
)
expanded_presence_penalty
=
create_empty
(
sampling_metadata
.
presence_penalty
)
expanded_seeds
=
create_empty
(
sampling_metadata
.
seeds
)
num_reqs
=
cu_num_logits
.
shape
[
0
]
-
1
_expand_sampling_metadata_kernel
[(
num_reqs
,)](
sampling_metadata
.
temperature
,
expanded_temp
,
sampling_metadata
.
top_p
,
expanded_top_p
,
sampling_metadata
.
top_k
,
expanded_top_k
,
sampling_metadata
.
repetition_penalty
,
expanded_repetition_penalty
,
sampling_metadata
.
frequency_penalty
,
expanded_frequency_penalty
,
sampling_metadata
.
presence_penalty
,
expanded_presence_penalty
,
sampling_metadata
.
seeds
,
expanded_seeds
,
cu_num_logits
,
BLOCK_SIZE
=
triton
.
next_power_of_2
(
num_speculative_steps
+
1
),
)
return
SamplingMetadata
(
temperature
=
expanded_temp
,
top_p
=
expanded_top_p
,
top_k
=
expanded_top_k
,
seeds
=
expanded_seeds
,
repetition_penalty
=
expanded_repetition_penalty
,
frequency_penalty
=
expanded_frequency_penalty
,
presence_penalty
=
expanded_presence_penalty
,
pos
=
sampling_metadata
.
pos
,
max_num_logprobs
=
sampling_metadata
.
max_num_logprobs
,
# TODO(woosuk): Support penalties with spec decoding.
idx_mapping
=
sampling_metadata
.
idx_mapping
,
prompt_bin_counts
=
sampling_metadata
.
prompt_bin_counts
,
output_bin_counts
=
sampling_metadata
.
output_bin_counts
,
)
def
use_penalty
(
sampling_params
:
SamplingParams
)
->
bool
:
return
(
sampling_params
.
repetition_penalty
!=
1.0
or
sampling_params
.
frequency_penalty
!=
0.0
or
sampling_params
.
presence_penalty
!=
0.0
)
@
triton
.
jit
(
do_not_specialize
=
[
"prefill_len"
,
"prompt_len"
])
def
_bincount_kernel
(
prefill_token_ids_ptr
,
prefill_len
,
prompt_len
,
prompt_bin_counts_ptr
,
output_bin_counts_ptr
,
BLOCK_SIZE
:
tl
.
constexpr
,
):
block_idx
=
tl
.
program_id
(
0
)
if
block_idx
*
BLOCK_SIZE
>=
prefill_len
:
return
block
=
block_idx
*
BLOCK_SIZE
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
if
block_idx
*
BLOCK_SIZE
<
prompt_len
:
mask
=
block
<
prompt_len
prefill_tokens
=
tl
.
load
(
prefill_token_ids_ptr
+
block
,
mask
=
mask
)
tl
.
atomic_add
(
prompt_bin_counts_ptr
+
prefill_tokens
,
1
,
mask
=
mask
)
if
(
block_idx
+
1
)
*
BLOCK_SIZE
>=
prompt_len
:
mask
=
block
<
prefill_len
mask
&=
block
>=
prompt_len
prefill_tokens
=
tl
.
load
(
prefill_token_ids_ptr
+
block
,
mask
=
mask
)
tl
.
atomic_add
(
output_bin_counts_ptr
+
prefill_tokens
,
1
,
mask
=
mask
)
def
bincount
(
prefill_token_ids
:
torch
.
Tensor
,
prefill_len
:
int
,
prompt_len
:
int
,
prompt_bin_counts
:
torch
.
Tensor
,
output_bin_counts
:
torch
.
Tensor
,
)
->
None
:
prompt_bin_counts
.
zero_
()
output_bin_counts
.
zero_
()
BLOCK_SIZE
=
1024
num_blocks
=
triton
.
cdiv
(
prefill_len
,
BLOCK_SIZE
)
_bincount_kernel
[(
num_blocks
,)](
prefill_token_ids
,
prefill_len
,
prompt_len
,
prompt_bin_counts
,
output_bin_counts
,
BLOCK_SIZE
=
BLOCK_SIZE
,
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment