Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
90c08369
Unverified
Commit
90c08369
authored
Jan 13, 2026
by
Woosuk Kwon
Committed by
GitHub
Jan 13, 2026
Browse files
[Model Runner V2] Refactor Sampler (#32245)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
8ef50d9a
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
289 additions
and
269 deletions
+289
-269
vllm/v1/worker/gpu/model_runner.py
vllm/v1/worker/gpu/model_runner.py
+36
-25
vllm/v1/worker/gpu/sample/metadata.py
vllm/v1/worker/gpu/sample/metadata.py
+0
-79
vllm/v1/worker/gpu/sample/penalties.py
vllm/v1/worker/gpu/sample/penalties.py
+103
-11
vllm/v1/worker/gpu/sample/sampler.py
vllm/v1/worker/gpu/sample/sampler.py
+70
-22
vllm/v1/worker/gpu/sample/states.py
vllm/v1/worker/gpu/sample/states.py
+74
-0
vllm/v1/worker/gpu/spec_decode/eagle.py
vllm/v1/worker/gpu/spec_decode/eagle.py
+6
-4
vllm/v1/worker/gpu/states.py
vllm/v1/worker/gpu/states.py
+0
-128
No files found.
vllm/v1/worker/gpu/model_runner.py
View file @
90c08369
...
@@ -49,7 +49,6 @@ from vllm.v1.worker.gpu.input_batch import (
...
@@ -49,7 +49,6 @@ from vllm.v1.worker.gpu.input_batch import (
)
)
from
vllm.v1.worker.gpu.mm.mrope_utils
import
MRopeState
from
vllm.v1.worker.gpu.mm.mrope_utils
import
MRopeState
from
vllm.v1.worker.gpu.sample.logprob
import
compute_prompt_logprobs
from
vllm.v1.worker.gpu.sample.logprob
import
compute_prompt_logprobs
from
vllm.v1.worker.gpu.sample.metadata
import
SamplingMetadata
from
vllm.v1.worker.gpu.sample.output
import
SamplerOutput
from
vllm.v1.worker.gpu.sample.output
import
SamplerOutput
from
vllm.v1.worker.gpu.sample.sampler
import
Sampler
from
vllm.v1.worker.gpu.sample.sampler
import
Sampler
from
vllm.v1.worker.gpu.spec_decode
import
init_speculator
from
vllm.v1.worker.gpu.spec_decode
import
init_speculator
...
@@ -139,7 +138,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -139,7 +138,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
dtype
=
self
.
dtype
,
dtype
=
self
.
dtype
,
device
=
self
.
device
,
device
=
self
.
device
,
)
)
self
.
sampler
=
Sampler
(
logprobs_mode
=
self
.
model_config
.
logprobs_mode
)
self
.
sampler
=
Sampler
(
max_num_reqs
=
self
.
max_num_reqs
,
vocab_size
=
self
.
vocab_size
,
device
=
self
.
device
,
logprobs_mode
=
self
.
model_config
.
logprobs_mode
,
)
# CUDA graphs.
# CUDA graphs.
self
.
cudagraph_manager
=
CudaGraphManager
(
self
.
cudagraph_manager
=
CudaGraphManager
(
...
@@ -310,12 +314,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -310,12 +314,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
)
->
None
:
)
->
None
:
num_reqs
=
hidden_states
.
shape
[
0
]
num_reqs
=
hidden_states
.
shape
[
0
]
sampling_metadata
=
SamplingMetadata
.
make_dummy
(
num_reqs
=
num_reqs
,
device
=
self
.
device
,
)
logits
=
self
.
model
.
compute_logits
(
hidden_states
)
logits
=
self
.
model
.
compute_logits
(
hidden_states
)
self
.
sampler
(
logits
,
sampling_metadata
)
idx_mapping
=
torch
.
arange
(
num_reqs
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
idx_mapping_np
=
np
.
arange
(
num_reqs
,
dtype
=
np
.
int32
)
pos
=
torch
.
zeros
(
num_reqs
,
dtype
=
torch
.
int64
,
device
=
self
.
device
)
# NOTE(woosuk): During the initial memory profiling, the sampler may skip
# top_k, top_p, and logprobs, using less GPU memory than what is possible
# during actual execution.
self
.
sampler
(
logits
,
idx_mapping
,
idx_mapping_np
,
pos
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
profile_run
(
self
)
->
None
:
def
profile_run
(
self
)
->
None
:
...
@@ -401,9 +407,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -401,9 +407,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
assert
new_req_data
.
prefill_token_ids
is
not
None
assert
new_req_data
.
prefill_token_ids
is
not
None
assert
new_req_data
.
sampling_params
is
not
None
assert
new_req_data
.
sampling_params
is
not
None
req_id
=
new_req_data
.
req_id
req_id
=
new_req_data
.
req_id
prompt_len
=
len
(
new_req_data
.
prompt_token_ids
)
self
.
req_states
.
add_request
(
self
.
req_states
.
add_request
(
req_id
=
req_id
,
req_id
=
req_id
,
prompt_len
=
len
(
new_req_data
.
prompt_token_ids
)
,
prompt_len
=
prompt_len
,
prefill_token_ids
=
new_req_data
.
prefill_token_ids
,
prefill_token_ids
=
new_req_data
.
prefill_token_ids
,
num_computed_tokens
=
new_req_data
.
num_computed_tokens
,
num_computed_tokens
=
new_req_data
.
num_computed_tokens
,
sampling_params
=
new_req_data
.
sampling_params
,
sampling_params
=
new_req_data
.
sampling_params
,
...
@@ -423,6 +430,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -423,6 +430,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self
.
block_tables
.
append_block_ids
(
self
.
block_tables
.
append_block_ids
(
req_index
,
new_req_data
.
block_ids
,
overwrite
=
True
req_index
,
new_req_data
.
block_ids
,
overwrite
=
True
)
)
self
.
sampler
.
add_request
(
req_index
,
prompt_len
,
new_req_data
.
sampling_params
)
# Add new blocks for the existing requests.
# Add new blocks for the existing requests.
cached_reqs
=
scheduler_output
.
scheduled_cached_reqs
cached_reqs
=
scheduler_output
.
scheduled_cached_reqs
...
@@ -436,6 +446,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -436,6 +446,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self
.
req_states
.
apply_staged_writes
()
self
.
req_states
.
apply_staged_writes
()
self
.
block_tables
.
apply_staged_writes
()
self
.
block_tables
.
apply_staged_writes
()
self
.
sampler
.
apply_staged_writes
(
self
.
req_states
.
prefill_token_ids
.
gpu
,
self
.
req_states
.
prefill_len
.
np
,
self
.
req_states
.
prompt_len
,
)
if
self
.
uses_mrope
:
if
self
.
uses_mrope
:
self
.
mrope_states
.
apply_staged_writes
()
self
.
mrope_states
.
apply_staged_writes
()
...
@@ -612,10 +627,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -612,10 +627,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
input_batch
:
InputBatch
,
input_batch
:
InputBatch
,
sampling_metadata
:
SamplingMetadata
,
grammar_output
:
GrammarOutput
|
None
,
grammar_output
:
GrammarOutput
|
None
,
)
->
tuple
[
SamplerOutput
,
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
SamplerOutput
,
torch
.
Tensor
,
torch
.
Tensor
]:
sample_hidden_states
=
hidden_states
[
input_batch
.
logits_indices
]
sample_hidden_states
=
hidden_states
[
input_batch
.
logits_indices
]
sample_pos
=
input_batch
.
positions
[
input_batch
.
logits_indices
]
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
)
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
)
if
grammar_output
is
not
None
:
if
grammar_output
is
not
None
:
# Apply grammar bitmask to the logits in-place.
# Apply grammar bitmask to the logits in-place.
...
@@ -627,7 +642,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -627,7 +642,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
)
)
# Sample tokens and compute logprobs (if needed).
# Sample tokens and compute logprobs (if needed).
sampler_output
=
self
.
sampler
(
logits
,
sampling_metadata
)
sampler_output
=
self
.
sampler
(
logits
,
input_batch
.
expanded_idx_mapping
,
input_batch
.
idx_mapping_np
,
sample_pos
,
)
if
input_batch
.
num_draft_tokens
==
0
:
if
input_batch
.
num_draft_tokens
==
0
:
# No draft tokens (common case).
# No draft tokens (common case).
...
@@ -766,7 +786,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -766,7 +786,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
input_batch
.
idx_mapping
,
input_batch
.
idx_mapping
,
self
.
req_states
.
num_computed_tokens
.
gpu
,
self
.
req_states
.
num_computed_tokens
.
gpu
,
self
.
req_states
.
last_sampled_tokens
,
self
.
req_states
.
last_sampled_tokens
,
self
.
req
_state
s
.
output_bin_counts
,
self
.
sampler
.
penalties
_state
.
output_bin_counts
,
sampled_tokens
,
sampled_tokens
,
num_sampled
,
num_sampled
,
num_rejected
,
num_rejected
,
...
@@ -786,7 +806,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -786,7 +806,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def
propose_draft
(
def
propose_draft
(
self
,
self
,
input_batch
:
InputBatch
,
input_batch
:
InputBatch
,
sampling_metadata
:
SamplingMetadata
,
last_hidden_states
:
torch
.
Tensor
,
last_hidden_states
:
torch
.
Tensor
,
aux_hidden_states
:
list
[
torch
.
Tensor
]
|
None
,
aux_hidden_states
:
list
[
torch
.
Tensor
]
|
None
,
num_sampled
:
torch
.
Tensor
,
num_sampled
:
torch
.
Tensor
,
...
@@ -801,13 +820,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -801,13 +820,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
]
]
draft_tokens
=
self
.
speculator
.
propose
(
draft_tokens
=
self
.
speculator
.
propose
(
input_batch
,
input_batch
,
sampling_metadata
,
last_hidden_states
,
last_hidden_states
,
aux_hidden_states
,
aux_hidden_states
,
num_sampled
,
num_sampled
,
num_rejected
,
num_rejected
,
last_sampled_tokens
,
last_sampled_tokens
,
next_prefill_tokens
,
next_prefill_tokens
,
self
.
sampler
.
sampling_states
.
temperature
.
gpu
,
self
.
sampler
.
sampling_states
.
seeds
.
gpu
,
)
)
return
draft_tokens
return
draft_tokens
...
@@ -893,12 +913,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -893,12 +913,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
scheduler_output
,
scheduler_output
,
num_tokens_after_padding
,
num_tokens_after_padding
,
)
)
pos
=
input_batch
.
positions
[
input_batch
.
logits_indices
]
sampling_metadata
=
self
.
req_states
.
make_sampling_metadata
(
input_batch
.
expanded_idx_mapping
,
input_batch
.
idx_mapping_np
,
pos
)
if
self
.
lora_config
:
if
self
.
lora_config
:
# Activate LoRA adapters.
# Activate LoRA adapters.
lora_inputs
=
self
.
req_states
.
make_lora_inputs
(
lora_inputs
=
self
.
req_states
.
make_lora_inputs
(
...
@@ -917,7 +931,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -917,7 +931,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
device
=
self
.
device
,
device
=
self
.
device
,
)
)
self
.
prepare_dummy_attn_metadata
(
input_batch
)
self
.
prepare_dummy_attn_metadata
(
input_batch
)
sampling_metadata
=
None
# Run model.
# Run model.
if
cudagraph_mode
==
CUDAGraphMode
.
FULL
:
if
cudagraph_mode
==
CUDAGraphMode
.
FULL
:
...
@@ -946,7 +959,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -946,7 +959,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
positions
=
positions
,
positions
=
positions
,
)
)
self
.
execute_model_state
=
hidden_states
,
input_batch
,
sampling_metadata
self
.
execute_model_state
=
hidden_states
,
input_batch
return
None
return
None
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
...
@@ -955,12 +968,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -955,12 +968,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
grammar_output
:
GrammarOutput
|
None
,
grammar_output
:
GrammarOutput
|
None
,
)
->
AsyncOutput
|
ModelRunnerOutput
:
)
->
AsyncOutput
|
ModelRunnerOutput
:
assert
self
.
execute_model_state
is
not
None
assert
self
.
execute_model_state
is
not
None
hidden_states
,
input_batch
,
sampling_metadata
=
self
.
execute_model_state
hidden_states
,
input_batch
=
self
.
execute_model_state
self
.
execute_model_state
=
None
# type: ignore
self
.
execute_model_state
=
None
# type: ignore
assert
sampling_metadata
is
not
None
sampler_output
,
num_sampled
,
num_rejected
=
self
.
sample
(
sampler_output
,
num_sampled
,
num_rejected
=
self
.
sample
(
hidden_states
,
input_batch
,
sampling_metadata
,
grammar_output
hidden_states
,
input_batch
,
grammar_output
)
)
prompt_logprobs_dict
=
self
.
compute_prompt_logprobs
(
hidden_states
,
input_batch
)
prompt_logprobs_dict
=
self
.
compute_prompt_logprobs
(
hidden_states
,
input_batch
)
...
@@ -992,7 +1004,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -992,7 +1004,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if
self
.
do_spec_decode
:
if
self
.
do_spec_decode
:
draft_tokens
=
self
.
propose_draft
(
draft_tokens
=
self
.
propose_draft
(
input_batch
,
input_batch
,
sampling_metadata
,
hidden_states
,
hidden_states
,
None
,
# aux_hidden_states
None
,
# aux_hidden_states
num_sampled
,
num_sampled
,
...
...
vllm/v1/worker/gpu/sample/metadata.py
deleted
100644 → 0
View file @
8ef50d9a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
import
torch
@
dataclass
class
SamplingMetadata
:
idx_mapping
:
torch
.
Tensor
temperature
:
torch
.
Tensor
top_p
:
torch
.
Tensor
|
None
top_k
:
torch
.
Tensor
|
None
min_p
:
torch
.
Tensor
|
None
# For penalties
repetition_penalty
:
torch
.
Tensor
frequency_penalty
:
torch
.
Tensor
presence_penalty
:
torch
.
Tensor
prompt_bin_mask
:
torch
.
Tensor
output_bin_counts
:
torch
.
Tensor
seeds
:
torch
.
Tensor
pos
:
torch
.
Tensor
# None means no logprobs, 0 means sampled token logprobs only
max_num_logprobs
:
int
|
None
@
classmethod
def
make_dummy
(
cls
,
num_reqs
:
int
,
device
:
torch
.
device
,
)
->
"SamplingMetadata"
:
assert
num_reqs
>
0
idx_mapping
=
torch
.
arange
(
num_reqs
,
dtype
=
torch
.
int32
,
device
=
device
)
temperature
=
torch
.
zeros
(
num_reqs
,
dtype
=
torch
.
float32
,
device
=
device
)
temperature
[
0
]
=
0.5
# TODO(woosuk): Use top-p and top-k for dummy sampler.
# Currently, they are disabled because of memory usage.
# top_p = torch.full((num_reqs,), 0.95, dtype=torch.float32, device=device)
# top_k = torch.full((num_reqs,), 20, dtype=torch.int32, device=device)
top_p
=
None
top_k
=
None
min_p
=
torch
.
zeros
(
num_reqs
,
dtype
=
torch
.
float32
,
device
=
device
)
# NOTE(woosuk): We must set penalties to their default values to make sure
# the penalties kernel does not touch the placeholder bin_counts tensors.
repetition_penalty
=
torch
.
ones
(
num_reqs
,
dtype
=
torch
.
float32
,
device
=
device
)
frequency_penalty
=
torch
.
zeros
(
num_reqs
,
dtype
=
torch
.
float32
,
device
=
device
)
presence_penalty
=
torch
.
zeros
(
num_reqs
,
dtype
=
torch
.
float32
,
device
=
device
)
# NOTE(woosuk): These are placeholder tensors to avoid None checks in the
# penalties kernel. We use 2 instead of 1 as vocab_size to avoid Triton
# specialization and re-compilation at runtime.
prompt_bin_mask
=
torch
.
zeros
(
num_reqs
,
2
,
dtype
=
torch
.
int32
,
device
=
device
)
output_bin_counts
=
torch
.
zeros
(
num_reqs
,
2
,
dtype
=
torch
.
int32
,
device
=
device
)
seeds
=
torch
.
zeros
(
num_reqs
,
dtype
=
torch
.
int64
,
device
=
device
)
pos
=
torch
.
zeros
(
num_reqs
,
dtype
=
torch
.
int64
,
device
=
device
)
max_num_logprobs
=
20
return
cls
(
idx_mapping
=
idx_mapping
,
temperature
=
temperature
,
top_p
=
top_p
,
top_k
=
top_k
,
min_p
=
min_p
,
repetition_penalty
=
repetition_penalty
,
frequency_penalty
=
frequency_penalty
,
presence_penalty
=
presence_penalty
,
prompt_bin_mask
=
prompt_bin_mask
,
output_bin_counts
=
output_bin_counts
,
seeds
=
seeds
,
pos
=
pos
,
max_num_logprobs
=
max_num_logprobs
,
)
vllm/v1/worker/gpu/sample/penalties.py
View file @
90c08369
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
numpy
as
np
import
torch
import
torch
from
vllm.sampling_params
import
SamplingParams
from
vllm.triton_utils
import
tl
,
triton
from
vllm.triton_utils
import
tl
,
triton
from
vllm.v1.worker.gpu.sample.metadata
import
SamplingMetadata
from
vllm.utils.math_utils
import
cdiv
from
vllm.v1.worker.gpu.buffer_utils
import
UvaBackedTensor
class PenaltiesState:
    """Per-slot penalty parameters and token statistics for the sampler.

    Penalty values are staged into the ``.np`` side of each
    ``UvaBackedTensor`` by ``add_request`` and published with
    ``copy_to_uva()`` in ``apply_staged_writes``; the kernels read the
    ``.gpu`` side. (NOTE(review): the exact host/device semantics of
    ``UvaBackedTensor`` live in ``buffer_utils`` -- confirm there.)
    """

    def __init__(self, max_num_reqs: int, vocab_size: int, device: torch.device):
        """Allocate penalty buffers for ``max_num_reqs`` request slots.

        Args:
            max_num_reqs: Number of request slots to allocate.
            vocab_size: Vocabulary size used to size the statistics tensors.
            device: Device holding the statistics tensors.
        """
        self.max_num_reqs = max_num_reqs
        self.vocab_size = vocab_size
        self.device = device

        self.repetition_penalty = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
        self.frequency_penalty = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
        self.presence_penalty = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
        # Initialize repetition penalty manually because 0 is an invalid value for it.
        self.repetition_penalty.np.fill(1.0)
        self.repetition_penalty.copy_to_uva()

        # Statistics for penalties.
        # One int32 word per 32 vocabulary entries (presumably a packed
        # bitmask of tokens seen in the prompt -- confirm against `bincount`).
        self.prompt_bin_mask = torch.zeros(
            self.max_num_reqs,
            cdiv(self.vocab_size, 32),
            dtype=torch.int32,
            device=self.device,
        )
        # TODO(woosuk): This tensor is rarely used but can be very large, taking up
        # GBs of GPU memory. Optimize the memory usage.
        self.output_bin_counts = torch.zeros(
            self.max_num_reqs, self.vocab_size, dtype=torch.int32, device=self.device
        )

        # Slots added since the last `apply_staged_writes` that use a penalty.
        self._penalties_reqs: list[int] = []

    def add_request(self, req_idx: int, sampling_params: SamplingParams) -> None:
        """Stage the penalty settings of a new request into slot ``req_idx``.

        The slot is also queued for statistics computation when
        ``use_penalty(sampling_params)`` is true.
        """
        self.repetition_penalty.np[req_idx] = sampling_params.repetition_penalty
        self.frequency_penalty.np[req_idx] = sampling_params.frequency_penalty
        self.presence_penalty.np[req_idx] = sampling_params.presence_penalty
        if use_penalty(sampling_params):
            self._penalties_reqs.append(req_idx)

    def apply_staged_writes(
        self,
        prefill_token_ids: torch.Tensor,
        prefill_lens: np.ndarray,
        prompt_lens: np.ndarray,
    ) -> None:
        """Compute statistics for queued slots and publish staged penalties.

        Args:
            prefill_token_ids: Tensor indexed by request slot; each row is
                handed to ``bincount``.
            prefill_lens: Per-slot prefill lengths, indexed by request slot.
            prompt_lens: Per-slot prompt lengths, indexed by request slot.
        """
        # TODO(woosuk): Optimize this.
        for req_idx in self._penalties_reqs:
            bincount(
                prefill_token_ids[req_idx],
                int(prefill_lens[req_idx]),
                int(prompt_lens[req_idx]),
                self.prompt_bin_mask[req_idx],
                self.output_bin_counts[req_idx],
            )
        self._penalties_reqs.clear()

        self.repetition_penalty.copy_to_uva()
        self.frequency_penalty.copy_to_uva()
        self.presence_penalty.copy_to_uva()

    def apply_penalties_and_temperature(
        self,
        logits: torch.Tensor,
        idx_mapping: torch.Tensor,
        temperature: torch.Tensor,
    ) -> None:
        """Forward ``logits`` together with this state's buffers to the kernel wrapper."""
        # The bare name resolves to the module-level function of the same
        # name, not to this method.
        apply_penalties_and_temperature(
            logits,
            idx_mapping,
            temperature,
            self.repetition_penalty.gpu,
            self.frequency_penalty.gpu,
            self.presence_penalty.gpu,
            self.prompt_bin_mask,
            self.output_bin_counts,
        )
@
triton
.
jit
@
triton
.
jit
...
@@ -84,7 +162,13 @@ def _penalties_and_temperature_kernel(
...
@@ -84,7 +162,13 @@ def _penalties_and_temperature_kernel(
def
apply_penalties_and_temperature
(
def
apply_penalties_and_temperature
(
logits
:
torch
.
Tensor
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
idx_mapping
:
torch
.
Tensor
,
temperature
:
torch
.
Tensor
,
repetition_penalty
:
torch
.
Tensor
,
frequency_penalty
:
torch
.
Tensor
,
presence_penalty
:
torch
.
Tensor
,
prompt_bin_mask
:
torch
.
Tensor
,
output_bin_counts
:
torch
.
Tensor
,
)
->
None
:
)
->
None
:
num_reqs
,
vocab_size
=
logits
.
shape
num_reqs
,
vocab_size
=
logits
.
shape
BLOCK_SIZE
=
8192
BLOCK_SIZE
=
8192
...
@@ -92,15 +176,15 @@ def apply_penalties_and_temperature(
...
@@ -92,15 +176,15 @@ def apply_penalties_and_temperature(
_penalties_and_temperature_kernel
[(
num_reqs
,
num_blocks
)](
_penalties_and_temperature_kernel
[(
num_reqs
,
num_blocks
)](
logits
,
logits
,
logits
.
stride
(
0
),
logits
.
stride
(
0
),
sampling_metadata
.
idx_mapping
,
idx_mapping
,
sampling_metadata
.
repetition_penalty
,
repetition_penalty
,
sampling_metadata
.
frequency_penalty
,
frequency_penalty
,
sampling_metadata
.
presence_penalty
,
presence_penalty
,
sampling_metadata
.
temperature
,
temperature
,
sampling_metadata
.
prompt_bin_mask
,
prompt_bin_mask
,
sampling_metadata
.
prompt_bin_mask
.
stride
(
0
),
prompt_bin_mask
.
stride
(
0
),
sampling_metadata
.
output_bin_counts
,
output_bin_counts
,
sampling_metadata
.
output_bin_counts
.
stride
(
0
),
output_bin_counts
.
stride
(
0
),
vocab_size
,
vocab_size
,
BLOCK_SIZE
=
BLOCK_SIZE
,
BLOCK_SIZE
=
BLOCK_SIZE
,
)
)
...
@@ -153,3 +237,11 @@ def bincount(
...
@@ -153,3 +237,11 @@ def bincount(
output_bin_counts
,
output_bin_counts
,
BLOCK_SIZE
=
BLOCK_SIZE
,
BLOCK_SIZE
=
BLOCK_SIZE
,
)
)
def use_penalty(sampling_params: SamplingParams) -> bool:
    """Return whether any penalty in ``sampling_params`` is non-default.

    The defaults checked here are 1.0 for the repetition penalty and 0.0 for
    the frequency and presence penalties.
    """
    return (
        sampling_params.repetition_penalty != 1.0
        or sampling_params.frequency_penalty != 0.0
        or sampling_params.presence_penalty != 0.0
    )
vllm/v1/worker/gpu/sample/sampler.py
View file @
90c08369
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
numpy
as
np
import
torch
import
torch
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.config.model
import
LogprobsMode
from
vllm.config.model
import
LogprobsMode
from
vllm.sampling_params
import
SamplingParams
from
vllm.v1.sample.ops.topk_topp_sampler
import
apply_top_k_top_p
from
vllm.v1.sample.ops.topk_topp_sampler
import
apply_top_k_top_p
from
vllm.v1.worker.gpu.metrics.logits
import
get_num_nans
from
vllm.v1.worker.gpu.metrics.logits
import
get_num_nans
from
vllm.v1.worker.gpu.sample.gumbel
import
gumbel_sample
from
vllm.v1.worker.gpu.sample.gumbel
import
gumbel_sample
from
vllm.v1.worker.gpu.sample.logit_bias
import
LogitBiasState
from
vllm.v1.worker.gpu.sample.logprob
import
compute_topk_logprobs
from
vllm.v1.worker.gpu.sample.logprob
import
compute_topk_logprobs
from
vllm.v1.worker.gpu.sample.metadata
import
SamplingMetadata
from
vllm.v1.worker.gpu.sample.min_p
import
apply_min_p
from
vllm.v1.worker.gpu.sample.min_p
import
apply_min_p
from
vllm.v1.worker.gpu.sample.output
import
SamplerOutput
from
vllm.v1.worker.gpu.sample.output
import
SamplerOutput
from
vllm.v1.worker.gpu.sample.penalties
import
apply_penalties_and_temperature
from
vllm.v1.worker.gpu.sample.penalties
import
PenaltiesState
from
vllm.v1.worker.gpu.sample.states
import
NO_LOGPROBS
,
SamplingStates
class
Sampler
:
class
Sampler
:
def
__init__
(
def
__init__
(
self
,
self
,
max_num_reqs
:
int
,
vocab_size
:
int
,
device
:
torch
.
device
,
logprobs_mode
:
LogprobsMode
=
"raw_logprobs"
,
logprobs_mode
:
LogprobsMode
=
"raw_logprobs"
,
):
):
if
logprobs_mode
not
in
[
"processed_logprobs"
,
"raw_logprobs"
]:
if
logprobs_mode
not
in
[
"processed_logprobs"
,
"raw_logprobs"
]:
...
@@ -25,26 +31,54 @@ class Sampler:
...
@@ -25,26 +31,54 @@ class Sampler:
self
.
logprobs_mode
=
logprobs_mode
self
.
logprobs_mode
=
logprobs_mode
self
.
compute_nans
=
envs
.
VLLM_COMPUTE_NANS_IN_LOGITS
# False by default.
self
.
compute_nans
=
envs
.
VLLM_COMPUTE_NANS_IN_LOGITS
# False by default.
self
.
sampling_states
=
SamplingStates
(
max_num_reqs
,
vocab_size
)
self
.
penalties_state
=
PenaltiesState
(
max_num_reqs
,
vocab_size
,
device
)
self
.
logit_bias_state
=
LogitBiasState
(
max_num_reqs
,
device
)
def
add_request
(
self
,
req_idx
:
int
,
prompt_len
:
int
,
sampling_params
:
SamplingParams
,
)
->
None
:
self
.
sampling_states
.
add_request
(
req_idx
,
sampling_params
)
self
.
penalties_state
.
add_request
(
req_idx
,
sampling_params
)
self
.
logit_bias_state
.
add_request
(
req_idx
,
prompt_len
,
sampling_params
)
def
apply_staged_writes
(
self
,
prefill_token_ids
:
torch
.
Tensor
,
prefill_lens
:
np
.
ndarray
,
prompt_lens
:
np
.
ndarray
,
)
->
None
:
self
.
sampling_states
.
apply_staged_writes
()
self
.
penalties_state
.
apply_staged_writes
(
prefill_token_ids
,
prefill_lens
,
prompt_lens
)
self
.
logit_bias_state
.
apply_staged_writes
()
def
__call__
(
def
__call__
(
self
,
self
,
logits
:
torch
.
Tensor
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
idx_mapping
:
torch
.
Tensor
,
idx_mapping_np
:
np
.
ndarray
,
pos
:
torch
.
Tensor
,
)
->
SamplerOutput
:
)
->
SamplerOutput
:
# NOTE(woosuk): We intentionally compute num_nans before sampling to make clear
# NOTE(woosuk): We intentionally compute num_nans before sampling to make clear
# that num_nans is computed before applying penalties and temperature.
# that num_nans is computed before applying penalties and temperature.
num_nans
=
get_num_nans
(
logits
)
if
self
.
compute_nans
else
None
num_nans
=
get_num_nans
(
logits
)
if
self
.
compute_nans
else
None
sampled
,
processed_logits
=
self
.
sample
(
logits
,
sampling_metadata
)
sampled
,
processed_logits
=
self
.
sample
(
if
sampling_metadata
.
max_num_logprobs
is
not
None
:
logits
,
idx_mapping
,
idx_mapping_np
,
pos
)
max_num_logprobs
=
self
.
sampling_states
.
max_num_logprobs
(
idx_mapping_np
)
if
max_num_logprobs
!=
NO_LOGPROBS
:
logits
=
(
logits
=
(
processed_logits
processed_logits
if
self
.
logprobs_mode
==
"processed_logprobs"
if
self
.
logprobs_mode
==
"processed_logprobs"
else
logits
else
logits
)
)
logprobs_tensors
=
compute_topk_logprobs
(
logprobs_tensors
=
compute_topk_logprobs
(
logits
,
max_num_logprobs
,
sampled
)
logits
,
sampling_metadata
.
max_num_logprobs
,
sampled
,
)
else
:
else
:
logprobs_tensors
=
None
logprobs_tensors
=
None
...
@@ -62,27 +96,41 @@ class Sampler:
...
@@ -62,27 +96,41 @@ class Sampler:
def
sample
(
def
sample
(
self
,
self
,
logits
:
torch
.
Tensor
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
idx_mapping
:
torch
.
Tensor
,
idx_mapping_np
:
np
.
ndarray
,
pos
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# Copy logits to a new FP32 tensor.
# Copy logits to a new FP32 tensor.
logits
=
torch
.
empty_like
(
logits
,
dtype
=
torch
.
float32
).
copy_
(
logits
)
logits
=
torch
.
empty_like
(
logits
,
dtype
=
torch
.
float32
).
copy_
(
logits
)
# Apply logit bias (e.g., allowed_token_ids, min_tokens) in place.
self
.
logit_bias_state
.
apply_logit_bias
(
logits
,
idx_mapping
,
pos
)
# Apply penalties and temperature in place.
# Apply penalties and temperature in place.
apply_penalties_and_temperature
(
logits
,
sampling_metadata
)
self
.
penalties_state
.
apply_penalties_and_temperature
(
# Apply min_p in place.
logits
,
idx_mapping
,
self
.
sampling_states
.
temperature
.
gpu
if
sampling_metadata
.
min_p
is
not
None
:
apply_min_p
(
logits
,
sampling_metadata
.
idx_mapping
,
sampling_metadata
.
min_p
)
# Apply top_k and/or top_p. This might return a new tensor.
logits
=
apply_top_k_top_p
(
logits
,
sampling_metadata
.
top_k
,
sampling_metadata
.
top_p
)
)
# Apply min_p in place if any request has a non-zero min_p.
do_min_p
=
self
.
sampling_states
.
do_min_p
(
idx_mapping_np
)
if
do_min_p
:
apply_min_p
(
logits
,
idx_mapping
,
self
.
sampling_states
.
min_p
.
gpu
)
# Apply top_k and/or top_p. This might return a new tensor.
do_top_k
=
self
.
sampling_states
.
do_top_k
(
idx_mapping_np
)
top_k
=
self
.
sampling_states
.
top_k
.
gpu
[
idx_mapping
]
if
do_top_k
else
None
do_top_p
=
self
.
sampling_states
.
do_top_p
(
idx_mapping_np
)
top_p
=
self
.
sampling_states
.
top_p
.
gpu
[
idx_mapping
]
if
do_top_p
else
None
if
do_top_k
or
do_top_p
:
logits
=
apply_top_k_top_p
(
logits
,
top_k
,
top_p
)
# Sample the next token.
sampled
=
gumbel_sample
(
sampled
=
gumbel_sample
(
logits
,
logits
,
sampling_metadata
.
idx_mapping
,
idx_mapping
,
sampling_
metadata
.
temperature
,
self
.
sampling_
states
.
temperature
.
gpu
,
sampling_
metadata
.
seeds
,
self
.
sampling_
states
.
seeds
.
gpu
,
sampling_metadata
.
pos
,
pos
,
apply_temperature
=
False
,
apply_temperature
=
False
,
)
)
return
sampled
,
logits
return
sampled
,
logits
vllm/v1/worker/gpu/sample/states.py
0 → 100644
View file @
90c08369
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
numpy
as
np
import
torch
from
vllm.sampling_params
import
SamplingParams
from
vllm.v1.worker.gpu.buffer_utils
import
UvaBackedTensor
# Sentinel stored in `SamplingStates.num_logprobs` when no logprobs are requested.
NO_LOGPROBS = -1
# Bounds used when drawing a random seed for requests that do not supply one.
_NP_INT64_MIN = np.iinfo(np.int64).min
_NP_INT64_MAX = np.iinfo(np.int64).max


class SamplingStates:
    """Per-slot sampling parameters for up to ``max_num_reqs`` requests.

    Values are staged into the ``.np`` side of each ``UvaBackedTensor`` by
    ``add_request`` and published with ``copy_to_uva()`` in
    ``apply_staged_writes``. ``num_logprobs`` is a plain NumPy array.
    (NOTE(review): the host/device semantics of ``UvaBackedTensor`` live in
    ``buffer_utils`` -- confirm there.)
    """

    def __init__(self, max_num_reqs: int, vocab_size: int):
        """Allocate the buffers and fill in the "disabled" defaults."""
        self.max_num_reqs = max_num_reqs
        self.vocab_size = vocab_size

        self.temperature = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
        self.top_k = UvaBackedTensor(max_num_reqs, dtype=torch.int32)
        self.top_p = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
        self.min_p = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
        self.seeds = UvaBackedTensor(max_num_reqs, dtype=torch.int64)

        # Initialize top_k and top_p manually because 0 is an invalid value for them.
        self.top_k.np.fill(self.vocab_size)
        self.top_k.copy_to_uva()
        self.top_p.np.fill(1.0)
        self.top_p.copy_to_uva()

        self.num_logprobs = np.empty(self.max_num_reqs, dtype=np.int32)
        # -1 means no logprobs are requested.
        self.num_logprobs.fill(NO_LOGPROBS)

    def add_request(self, req_idx: int, sampling_params: SamplingParams) -> None:
        """Stage the sampling settings of a new request into slot ``req_idx``.

        A ``top_k`` outside ``(0, vocab_size)`` is stored as ``vocab_size``,
        the value ``do_top_k`` treats as "no top-k". A missing seed is
        replaced by a random int64, and missing ``logprobs`` by
        ``NO_LOGPROBS``.
        """
        self.temperature.np[req_idx] = sampling_params.temperature
        self.top_p.np[req_idx] = sampling_params.top_p
        if 0 < sampling_params.top_k < self.vocab_size:
            top_k = sampling_params.top_k
        else:
            top_k = self.vocab_size
        self.top_k.np[req_idx] = top_k
        self.min_p.np[req_idx] = sampling_params.min_p

        if sampling_params.seed is not None:
            seed = sampling_params.seed
        else:
            # Request int64 explicitly: the bounds span the whole int64 range
            # and are rejected where NumPy's default integer type is 32-bit.
            seed = np.random.randint(_NP_INT64_MIN, _NP_INT64_MAX, dtype=np.int64)
        self.seeds.np[req_idx] = seed

        if sampling_params.logprobs is not None:
            num_logprobs = sampling_params.logprobs
        else:
            num_logprobs = NO_LOGPROBS
        self.num_logprobs[req_idx] = num_logprobs

    def apply_staged_writes(self) -> None:
        """Publish every staged per-slot value via ``copy_to_uva()``."""
        self.temperature.copy_to_uva()
        self.top_p.copy_to_uva()
        self.top_k.copy_to_uva()
        self.min_p.copy_to_uva()
        self.seeds.copy_to_uva()

    def do_min_p(self, idx_mapping_np: np.ndarray) -> bool:
        """Return True if any selected slot has a non-zero ``min_p``."""
        # bool() converts NumPy's np.bool_ to the annotated built-in type.
        return bool(np.any(self.min_p.np[idx_mapping_np] != 0.0))

    def do_top_k(self, idx_mapping_np: np.ndarray) -> bool:
        """Return True if any selected slot has a ``top_k`` other than ``vocab_size``."""
        return bool(np.any(self.top_k.np[idx_mapping_np] != self.vocab_size))

    def do_top_p(self, idx_mapping_np: np.ndarray) -> bool:
        """Return True if any selected slot has a ``top_p`` other than 1.0."""
        return bool(np.any(self.top_p.np[idx_mapping_np] != 1.0))

    def max_num_logprobs(self, idx_mapping_np: np.ndarray) -> int:
        """Return the largest ``num_logprobs`` among the selected slots.

        Returns ``NO_LOGPROBS`` when no slot is selected; ``np.max`` would
        raise ``ValueError`` on the empty selection.
        """
        selected = self.num_logprobs[idx_mapping_np]
        if selected.size == 0:
            return NO_LOGPROBS
        return int(np.max(selected))
vllm/v1/worker/gpu/spec_decode/eagle.py
View file @
90c08369
...
@@ -17,7 +17,6 @@ from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
...
@@ -17,7 +17,6 @@ from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
from
vllm.v1.worker.gpu.block_table
import
BlockTables
from
vllm.v1.worker.gpu.block_table
import
BlockTables
from
vllm.v1.worker.gpu.input_batch
import
InputBatch
,
InputBuffers
from
vllm.v1.worker.gpu.input_batch
import
InputBatch
,
InputBuffers
from
vllm.v1.worker.gpu.sample.gumbel
import
gumbel_sample
from
vllm.v1.worker.gpu.sample.gumbel
import
gumbel_sample
from
vllm.v1.worker.gpu.sample.metadata
import
SamplingMetadata
from
vllm.v1.worker.gpu.spec_decode.eagle_cudagraph
import
EagleCudaGraphManager
from
vllm.v1.worker.gpu.spec_decode.eagle_cudagraph
import
EagleCudaGraphManager
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -188,7 +187,6 @@ class EagleSpeculator:
...
@@ -188,7 +187,6 @@ class EagleSpeculator:
def
propose
(
def
propose
(
self
,
self
,
input_batch
:
InputBatch
,
input_batch
:
InputBatch
,
sampling_metadata
:
SamplingMetadata
,
# [num_tokens, hidden_size]
# [num_tokens, hidden_size]
last_hidden_states
:
torch
.
Tensor
,
last_hidden_states
:
torch
.
Tensor
,
# num_layers x [num_tokens, hidden_size]
# num_layers x [num_tokens, hidden_size]
...
@@ -201,6 +199,10 @@ class EagleSpeculator:
...
@@ -201,6 +199,10 @@ class EagleSpeculator:
last_sampled
:
torch
.
Tensor
,
last_sampled
:
torch
.
Tensor
,
# [num_reqs]
# [num_reqs]
next_prefill_tokens
:
torch
.
Tensor
,
next_prefill_tokens
:
torch
.
Tensor
,
# [max_num_reqs]
temperature
:
torch
.
Tensor
,
# [max_num_reqs]
seeds
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# NOTE(woosuk): To avoid CPU-GPU synchronization without CPU knowing the
# NOTE(woosuk): To avoid CPU-GPU synchronization without CPU knowing the
# number of rejected tokens, we maintain the size of eagle's input_ids and
# number of rejected tokens, we maintain the size of eagle's input_ids and
...
@@ -246,8 +248,8 @@ class EagleSpeculator:
...
@@ -246,8 +248,8 @@ class EagleSpeculator:
# affect the output distribution after rejection sampling.
# affect the output distribution after rejection sampling.
idx_mapping
=
self
.
idx_mapping
[:
num_reqs
]
idx_mapping
=
self
.
idx_mapping
[:
num_reqs
]
idx_mapping
.
copy_
(
input_batch
.
idx_mapping
)
idx_mapping
.
copy_
(
input_batch
.
idx_mapping
)
self
.
temperature
.
copy_
(
sampling_metadata
.
temperature
)
self
.
temperature
.
copy_
(
temperature
)
self
.
seeds
.
copy_
(
sampling_metadata
.
seeds
)
self
.
seeds
.
copy_
(
seeds
)
# Gather the values and copy them to the pre-allocated buffers.
# Gather the values and copy them to the pre-allocated buffers.
pos
=
self
.
input_buffers
.
positions
[:
num_reqs
]
pos
=
self
.
input_buffers
.
positions
[:
num_reqs
]
torch
.
gather
(
input_batch
.
positions
,
0
,
last_token_indices
,
out
=
pos
)
torch
.
gather
(
input_batch
.
positions
,
0
,
last_token_indices
,
out
=
pos
)
...
...
vllm/v1/worker/gpu/states.py
View file @
90c08369
...
@@ -7,14 +7,9 @@ import torch
...
@@ -7,14 +7,9 @@ import torch
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.utils.math_utils
import
cdiv
from
vllm.v1.outputs
import
LogprobsTensors
from
vllm.v1.outputs
import
LogprobsTensors
from
vllm.v1.worker.gpu.buffer_utils
import
StagedWriteTensor
,
UvaBackedTensor
from
vllm.v1.worker.gpu.buffer_utils
import
StagedWriteTensor
,
UvaBackedTensor
from
vllm.v1.worker.gpu.sample.metadata
import
SamplingMetadata
from
vllm.v1.worker.gpu.sample.penalties
import
bincount
_NP_INT64_MIN
=
np
.
iinfo
(
np
.
int64
).
min
_NP_INT64_MAX
=
np
.
iinfo
(
np
.
int64
).
max
NO_LORA_ID
=
0
NO_LORA_ID
=
0
...
@@ -81,38 +76,8 @@ class RequestState:
...
@@ -81,38 +76,8 @@ class RequestState:
self
.
lora_ids
=
np
.
zeros
(
self
.
max_num_reqs
,
dtype
=
np
.
int32
)
self
.
lora_ids
=
np
.
zeros
(
self
.
max_num_reqs
,
dtype
=
np
.
int32
)
self
.
lora_ids
.
fill
(
NO_LORA_ID
)
self
.
lora_ids
.
fill
(
NO_LORA_ID
)
# Sampling parameters.
self
.
temperature
=
UvaBackedTensor
(
self
.
max_num_reqs
,
dtype
=
torch
.
float32
)
self
.
top_p
=
UvaBackedTensor
(
self
.
max_num_reqs
,
dtype
=
torch
.
float32
)
self
.
top_k
=
UvaBackedTensor
(
self
.
max_num_reqs
,
dtype
=
torch
.
int32
)
self
.
min_p
=
UvaBackedTensor
(
self
.
max_num_reqs
,
dtype
=
torch
.
float32
)
self
.
repetition_penalty
=
UvaBackedTensor
(
self
.
max_num_reqs
,
dtype
=
torch
.
float32
)
self
.
frequency_penalty
=
UvaBackedTensor
(
self
.
max_num_reqs
,
dtype
=
torch
.
float32
)
self
.
presence_penalty
=
UvaBackedTensor
(
self
.
max_num_reqs
,
dtype
=
torch
.
float32
)
self
.
seeds
=
UvaBackedTensor
(
self
.
max_num_reqs
,
dtype
=
torch
.
int64
)
self
.
num_logprobs
=
np
.
empty
(
self
.
max_num_reqs
,
dtype
=
np
.
int32
)
# -1 means no logprobs are requested.
self
.
num_logprobs
.
fill
(
-
1
)
self
.
needs_prompt_logprobs
=
np
.
zeros
(
self
.
max_num_reqs
,
dtype
=
bool
)
self
.
needs_prompt_logprobs
=
np
.
zeros
(
self
.
max_num_reqs
,
dtype
=
bool
)
# Statistics for penalties.
self
.
prompt_bin_mask
=
torch
.
zeros
(
self
.
max_num_reqs
,
cdiv
(
self
.
vocab_size
,
32
),
dtype
=
torch
.
int32
,
device
=
self
.
device
,
)
# TODO(woosuk): This tensor is rarely used but can be extremely large.
# Optimize the memory usage.
self
.
output_bin_counts
=
torch
.
zeros
(
self
.
max_num_reqs
,
self
.
vocab_size
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
self
.
_penalties_reqs
:
list
[
int
]
=
[]
@
property
@
property
def
num_reqs
(
self
)
->
int
:
def
num_reqs
(
self
)
->
int
:
return
len
(
self
.
req_id_to_index
)
return
len
(
self
.
req_id_to_index
)
...
@@ -147,33 +112,6 @@ class RequestState:
...
@@ -147,33 +112,6 @@ class RequestState:
else
:
else
:
self
.
lora_ids
[
req_idx
]
=
NO_LORA_ID
self
.
lora_ids
[
req_idx
]
=
NO_LORA_ID
self
.
temperature
.
np
[
req_idx
]
=
sampling_params
.
temperature
self
.
top_p
.
np
[
req_idx
]
=
sampling_params
.
top_p
if
0
<
sampling_params
.
top_k
<
self
.
vocab_size
:
top_k
=
sampling_params
.
top_k
else
:
top_k
=
self
.
vocab_size
self
.
top_k
.
np
[
req_idx
]
=
top_k
self
.
min_p
.
np
[
req_idx
]
=
sampling_params
.
min_p
self
.
repetition_penalty
.
np
[
req_idx
]
=
sampling_params
.
repetition_penalty
self
.
frequency_penalty
.
np
[
req_idx
]
=
sampling_params
.
frequency_penalty
self
.
presence_penalty
.
np
[
req_idx
]
=
sampling_params
.
presence_penalty
if
use_penalty
(
sampling_params
):
self
.
_penalties_reqs
.
append
(
req_idx
)
if
sampling_params
.
seed
is
not
None
:
seed
=
sampling_params
.
seed
else
:
seed
=
np
.
random
.
randint
(
_NP_INT64_MIN
,
_NP_INT64_MAX
)
self
.
seeds
.
np
[
req_idx
]
=
seed
if
sampling_params
.
logprobs
is
not
None
:
num_logprobs
=
sampling_params
.
logprobs
else
:
num_logprobs
=
-
1
self
.
num_logprobs
[
req_idx
]
=
num_logprobs
# For now, only support prompt logprobs for the prompt tokens.
# For now, only support prompt logprobs for the prompt tokens.
needs_prompt_logprobs
=
sampling_params
.
prompt_logprobs
is
not
None
needs_prompt_logprobs
=
sampling_params
.
prompt_logprobs
is
not
None
self
.
needs_prompt_logprobs
[
req_idx
]
=
needs_prompt_logprobs
self
.
needs_prompt_logprobs
[
req_idx
]
=
needs_prompt_logprobs
...
@@ -183,17 +121,6 @@ class RequestState:
...
@@ -183,17 +121,6 @@ class RequestState:
self
.
prefill_token_ids
.
apply_write
()
self
.
prefill_token_ids
.
apply_write
()
self
.
num_computed_tokens
.
apply_write
()
self
.
num_computed_tokens
.
apply_write
()
# TODO(woosuk): Optimize this.
for
req_idx
in
self
.
_penalties_reqs
:
bincount
(
self
.
prefill_token_ids
.
gpu
[
req_idx
],
int
(
self
.
prefill_len
.
np
[
req_idx
]),
int
(
self
.
prompt_len
[
req_idx
]),
self
.
prompt_bin_mask
[
req_idx
],
self
.
output_bin_counts
[
req_idx
],
)
self
.
_penalties_reqs
.
clear
()
def
remove_request
(
self
,
req_id
:
str
)
->
None
:
def
remove_request
(
self
,
req_id
:
str
)
->
None
:
self
.
extra_data
.
pop
(
req_id
,
None
)
self
.
extra_data
.
pop
(
req_id
,
None
)
req_idx
=
self
.
req_id_to_index
.
pop
(
req_id
,
None
)
req_idx
=
self
.
req_id_to_index
.
pop
(
req_id
,
None
)
...
@@ -203,53 +130,6 @@ class RequestState:
...
@@ -203,53 +130,6 @@ class RequestState:
self
.
index_to_req_id
.
pop
(
req_idx
,
None
)
self
.
index_to_req_id
.
pop
(
req_idx
,
None
)
self
.
free_indices
.
append
(
req_idx
)
self
.
free_indices
.
append
(
req_idx
)
def make_sampling_metadata(
    self,
    idx_mapping: torch.Tensor,
    idx_mapping_np: np.ndarray,
    pos: torch.Tensor,
) -> SamplingMetadata:
    """Assemble the ``SamplingMetadata`` for the selected requests.

    Args:
        idx_mapping: Tensor of indices used to gather ``top_p`` / ``top_k``
            from their copied buffers.
        idx_mapping_np: NumPy indices used for the CPU-side checks on
            ``top_p``, ``top_k``, ``min_p`` and ``num_logprobs``.
        pos: Passed through unchanged as ``SamplingMetadata.pos``.

    Returns:
        A ``SamplingMetadata`` in which ``top_p``, ``top_k`` and ``min_p``
        are ``None`` when every selected request holds the respective
        neutral value (1.0, ``self.vocab_size``, 0.0), and
        ``max_num_logprobs`` is ``None`` when the selected maximum is -1.
    """
    temperature = self.temperature.copy_to_uva()

    # Each optional parameter is only copied when at least one selected
    # request deviates from its neutral value.
    if np.any(self.top_p.np[idx_mapping_np] != 1.0):
        top_p = self.top_p.copy_to_uva()[idx_mapping]
    else:
        top_p = None

    if np.any(self.top_k.np[idx_mapping_np] != self.vocab_size):
        top_k = self.top_k.copy_to_uva()[idx_mapping]
    else:
        top_k = None

    # Unlike top_p / top_k, min_p is handed over without gathering by
    # idx_mapping.
    if np.any(self.min_p.np[idx_mapping_np] != 0.0):
        min_p = self.min_p.copy_to_uva()
    else:
        min_p = None

    repetition = self.repetition_penalty.copy_to_uva()
    frequency = self.frequency_penalty.copy_to_uva()
    presence = self.presence_penalty.copy_to_uva()
    seeds = self.seeds.copy_to_uva()

    # -1 marks "no logprobs requested"; if that is the maximum, none of the
    # selected requests asked for logprobs.
    largest = int(np.max(self.num_logprobs[idx_mapping_np]))
    max_num_logprobs: int | None = None if largest == -1 else largest

    return SamplingMetadata(
        idx_mapping=idx_mapping,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        min_p=min_p,
        repetition_penalty=repetition,
        frequency_penalty=frequency,
        presence_penalty=presence,
        prompt_bin_mask=self.prompt_bin_mask,
        output_bin_counts=self.output_bin_counts,
        seeds=seeds,
        pos=pos,
        max_num_logprobs=max_num_logprobs,
    )
def
make_lora_inputs
(
def
make_lora_inputs
(
self
,
self
,
req_ids
:
list
[
str
],
req_ids
:
list
[
str
],
...
@@ -272,11 +152,3 @@ class RequestState:
...
@@ -272,11 +152,3 @@ class RequestState:
class
ExtraData
:
class
ExtraData
:
lora_request
:
LoRARequest
|
None
lora_request
:
LoRARequest
|
None
in_progress_prompt_logprobs
:
list
[
LogprobsTensors
]
=
field
(
default_factory
=
list
)
in_progress_prompt_logprobs
:
list
[
LogprobsTensors
]
=
field
(
default_factory
=
list
)
def use_penalty(sampling_params: SamplingParams) -> bool:
    """Return whether any penalty in ``sampling_params`` is non-neutral.

    A repetition penalty of 1.0 and frequency / presence penalties of 0.0
    are the neutral values; any other setting makes this return ``True``.
    """
    if sampling_params.repetition_penalty != 1.0:
        return True
    if sampling_params.frequency_penalty != 0.0:
        return True
    return sampling_params.presence_penalty != 0.0
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment