Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3ce8285d
Unverified
Commit
3ce8285d
authored
Aug 27, 2025
by
Nick Hill
Committed by
GitHub
Aug 27, 2025
Browse files
[LogitsProcs] Deduplicate built-in LP implementation logic (#23362)
Signed-off-by:
Nick Hill
<
nhill@redhat.com
>
parent
83f555f6
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
95 additions
and
143 deletions
+95
-143
examples/offline_inference/logits_processor.py
examples/offline_inference/logits_processor.py
+11
-27
tests/v1/logits_processors/utils.py
tests/v1/logits_processors/utils.py
+9
-28
vllm/v1/sample/logits_processor/builtin.py
vllm/v1/sample/logits_processor/builtin.py
+62
-86
vllm/v1/sample/logits_processor/interface.py
vllm/v1/sample/logits_processor/interface.py
+13
-2
No files found.
examples/offline_inference/logits_processor.py
View file @
3ce8285d
...
@@ -42,8 +42,8 @@ from vllm.config import VllmConfig
...
@@ -42,8 +42,8 @@ from vllm.config import VllmConfig
from
vllm.v1.sample.logits_processor
import
(
from
vllm.v1.sample.logits_processor
import
(
BatchUpdate
,
BatchUpdate
,
LogitsProcessor
,
LogitsProcessor
,
MoveDirectionality
,
)
)
from
vllm.v1.sample.logits_processor.builtin
import
process_dict_updates
# Hypothetical custom logits processor
# Hypothetical custom logits processor
...
@@ -53,38 +53,22 @@ class DummyLogitsProcessor(LogitsProcessor):
...
@@ -53,38 +53,22 @@ class DummyLogitsProcessor(LogitsProcessor):
def
__init__
(
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
device
:
torch
.
device
,
is_pin_memory
:
bool
self
,
vllm_config
:
VllmConfig
,
device
:
torch
.
device
,
is_pin_memory
:
bool
):
):
self
.
req_info
:
dict
[
int
,
SamplingParams
]
=
{}
self
.
req_info
:
dict
[
int
,
int
]
=
{}
def
is_argmax_invariant
(
self
)
->
bool
:
def
is_argmax_invariant
(
self
)
->
bool
:
"""Never impacts greedy sampling"""
"""Never impacts greedy sampling"""
return
False
return
False
def
update_state
(
self
,
batch_update
:
Optional
[
BatchUpdate
]):
def
update_state
(
self
,
batch_update
:
Optional
[
BatchUpdate
]):
if
not
batch_update
:
process_dict_updates
(
return
self
.
req_info
,
batch_update
,
# Process added requests.
# This function returns the LP's per-request state based on the
for
index
,
params
,
_
,
_
in
batch_update
.
added
:
# request details, or None if this LP does not apply to the
assert
params
is
not
None
# request.
if
params
.
extra_args
and
(
lambda
params
,
_
,
__
:
params
.
extra_args
target_token
:
=
params
.
extra_args
.
get
(
"target_token"
)
and
(
params
.
extra_args
.
get
(
"target_token"
)),
):
)
self
.
req_info
[
index
]
=
target_token
if
self
.
req_info
:
# Process removed requests.
for
index
in
batch_update
.
removed
:
self
.
req_info
.
pop
(
index
,
None
)
# Process moved requests, unidirectional move (a->b) and swap
# (a<->b)
for
adx
,
bdx
,
direct
in
batch_update
.
moved
:
a_val
=
self
.
req_info
.
pop
(
adx
,
None
)
b_val
=
self
.
req_info
.
pop
(
bdx
,
None
)
if
a_val
is
not
None
:
self
.
req_info
[
bdx
]
=
a_val
if
direct
==
MoveDirectionality
.
SWAP
and
b_val
is
not
None
:
self
.
req_info
[
adx
]
=
b_val
def
apply
(
self
,
logits
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
apply
(
self
,
logits
:
torch
.
Tensor
)
->
torch
.
Tensor
:
if
not
self
.
req_info
:
if
not
self
.
req_info
:
...
...
tests/v1/logits_processors/utils.py
View file @
3ce8285d
...
@@ -8,10 +8,9 @@ from typing import Optional
...
@@ -8,10 +8,9 @@ from typing import Optional
import
torch
import
torch
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.sampling_params
import
SamplingParams
from
vllm.v1.sample.logits_processor
import
(
LOGITSPROCS_GROUP
,
BatchUpdate
,
from
vllm.v1.sample.logits_processor
import
(
LOGITSPROCS_GROUP
,
BatchUpdate
,
LogitsProcessor
,
LogitsProcessor
)
MoveDirectionality
)
from
vllm.v1.sample.logits_processor.builtin
import
process_dict_updates
MODEL_NAME
=
"facebook/opt-125m"
MODEL_NAME
=
"facebook/opt-125m"
POOLING_MODEL_NAME
=
"BAAI/bge-base-en-v1.5"
POOLING_MODEL_NAME
=
"BAAI/bge-base-en-v1.5"
...
@@ -45,37 +44,19 @@ class DummyLogitsProcessor(LogitsProcessor):
...
@@ -45,37 +44,19 @@ class DummyLogitsProcessor(LogitsProcessor):
def
__init__
(
self
,
vllm_config
:
"VllmConfig"
,
device
:
torch
.
device
,
def
__init__
(
self
,
vllm_config
:
"VllmConfig"
,
device
:
torch
.
device
,
is_pin_memory
:
bool
):
is_pin_memory
:
bool
):
self
.
req_info
:
dict
[
int
,
SamplingParams
]
=
{}
self
.
req_info
:
dict
[
int
,
int
]
=
{}
def
is_argmax_invariant
(
self
)
->
bool
:
def
is_argmax_invariant
(
self
)
->
bool
:
"""Never impacts greedy sampling"""
"""Never impacts greedy sampling"""
return
False
return
False
def
update_state
(
self
,
batch_update
:
Optional
[
BatchUpdate
]):
def
update_state
(
self
,
batch_update
:
Optional
[
BatchUpdate
]):
if
not
batch_update
:
process_dict_updates
(
return
self
.
req_info
,
batch_update
,
# Process added requests.
lambda
params
,
_
,
__
:
params
.
extra_args
and
for
index
,
params
,
_
,
_
in
batch_update
.
added
:
(
params
.
extra_args
.
get
(
"target_token"
)),
assert
params
is
not
None
)
if
params
.
extra_args
and
(
target_token
:
=
params
.
extra_args
.
get
(
"target_token"
)):
self
.
req_info
[
index
]
=
target_token
if
self
.
req_info
:
# Process removed requests.
for
index
in
batch_update
.
removed
:
self
.
req_info
.
pop
(
index
,
None
)
# Process moved requests, unidirectional move (a->b) and swap
# (a<->b)
for
adx
,
bdx
,
direct
in
batch_update
.
moved
:
a_val
=
self
.
req_info
.
pop
(
adx
,
None
)
b_val
=
self
.
req_info
.
pop
(
bdx
,
None
)
if
a_val
is
not
None
:
self
.
req_info
[
bdx
]
=
a_val
if
direct
==
MoveDirectionality
.
SWAP
and
b_val
is
not
None
:
self
.
req_info
[
adx
]
=
b_val
def
apply
(
self
,
logits
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
apply
(
self
,
logits
:
torch
.
Tensor
)
->
torch
.
Tensor
:
if
not
self
.
req_info
:
if
not
self
.
req_info
:
...
...
vllm/v1/sample/logits_processor/builtin.py
View file @
3ce8285d
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Sequence
from
collections.abc
import
Sequence
from
typing
import
TYPE_CHECKING
,
Optional
from
typing
import
TYPE_CHECKING
,
Callable
,
Optional
,
TypeVar
import
torch
import
torch
from
vllm
import
SamplingParams
from
vllm.v1.sample.logits_processor.interface
import
(
BatchUpdate
,
from
vllm.v1.sample.logits_processor.interface
import
(
BatchUpdate
,
LogitsProcessor
,
LogitsProcessor
,
MoveDirectionality
)
MoveDirectionality
)
...
@@ -12,6 +13,8 @@ from vllm.v1.sample.logits_processor.interface import (BatchUpdate,
...
@@ -12,6 +13,8 @@ from vllm.v1.sample.logits_processor.interface import (BatchUpdate,
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
T
=
TypeVar
(
"T"
)
class
MinPLogitsProcessor
(
LogitsProcessor
):
class
MinPLogitsProcessor
(
LogitsProcessor
):
...
@@ -130,49 +133,15 @@ class LogitBiasLogitsProcessor(LogitsProcessor):
...
@@ -130,49 +133,15 @@ class LogitBiasLogitsProcessor(LogitsProcessor):
return
False
return
False
def
update_state
(
self
,
batch_update
:
Optional
[
BatchUpdate
]):
def
update_state
(
self
,
batch_update
:
Optional
[
BatchUpdate
]):
if
not
batch_update
:
needs_update
=
process_dict_updates
(
return
self
.
biases
,
batch_update
,
lambda
params
,
_
,
__
:
params
.
logit_bias
or
None
)
needs_update
:
bool
=
False
# Process added requests.
for
index
,
params
,
_
,
_
in
batch_update
.
added
:
if
lb
:
=
params
.
logit_bias
:
self
.
biases
[
index
]
=
lb
needs_update
=
True
else
:
# Drop biases metadata at batch index
if
self
.
biases
.
pop
(
index
,
None
)
is
not
None
:
# If a new request replaces an old request which
# specified biases, we should update processor tensors
needs_update
=
True
if
self
.
biases
:
# Process removed requests.
for
index
in
batch_update
.
removed
:
if
self
.
biases
.
pop
(
index
,
None
):
needs_update
=
True
# Process moved requests, unidirectional (a->b) and swap (a<->b)
for
a_index
,
b_index
,
direct
in
batch_update
.
moved
:
if
direct
==
MoveDirectionality
.
UNIDIRECTIONAL
:
if
(
a_entry
:
=
self
.
biases
.
pop
(
a_index
,
None
))
is
None
:
if
self
.
biases
.
pop
(
b_index
,
None
)
is
not
None
:
needs_update
=
True
else
:
self
.
biases
[
b_index
]
=
a_entry
needs_update
=
True
else
:
a_entry
=
self
.
biases
.
pop
(
a_index
,
None
)
if
(
b_entry
:
=
self
.
biases
.
pop
(
b_index
,
None
))
is
not
None
:
self
.
biases
[
a_index
]
=
b_entry
needs_update
=
True
if
a_entry
is
not
None
:
self
.
biases
[
b_index
]
=
a_entry
needs_update
=
True
# Update tensors if needed.
# Update tensors if needed.
if
needs_update
:
if
needs_update
:
reqs
,
tok_ids
,
biases
=
[],
[],
[]
reqs
:
list
[
int
]
=
[]
tok_ids
:
list
[
int
]
=
[]
biases
:
list
[
float
]
=
[]
for
req
,
lb
in
self
.
biases
.
items
():
for
req
,
lb
in
self
.
biases
.
items
():
reqs
.
extend
([
req
]
*
len
(
lb
))
reqs
.
extend
([
req
]
*
len
(
lb
))
tok_ids
.
extend
(
lb
.
keys
())
tok_ids
.
extend
(
lb
.
keys
())
...
@@ -216,52 +185,18 @@ class MinTokensLogitsProcessor(LogitsProcessor):
...
@@ -216,52 +185,18 @@ class MinTokensLogitsProcessor(LogitsProcessor):
of the argmax operation in greedy sampling."""
of the argmax operation in greedy sampling."""
return
False
return
False
def
update_state
(
self
,
batch_update
:
Optional
[
BatchUpdate
]):
@
staticmethod
needs_update
=
False
def
add_request
(
params
:
SamplingParams
,
_
:
list
[
int
],
output_tok_ids
:
list
[
int
]
if
batch_update
:
)
->
Optional
[
tuple
[
int
,
Sequence
[
int
],
set
[
int
]]]:
# Process added requests.
min_tokens
=
params
.
min_tokens
for
index
,
params
,
_
,
output_tok_ids
in
batch_update
.
added
:
if
not
min_tokens
or
len
(
output_tok_ids
)
>=
min_tokens
:
if
((
min_tokens
:
=
params
.
min_tokens
)
return
None
and
len
(
output_tok_ids
)
<
min_tokens
):
return
min_tokens
,
output_tok_ids
,
params
.
all_stop_token_ids
# Replace request metadata at batch index
self
.
min_toks
[
index
]
=
(
min_tokens
,
output_tok_ids
,
params
.
all_stop_token_ids
)
needs_update
=
True
else
:
# Drop min_toks metadata at batch index
if
self
.
min_toks
.
pop
(
index
,
None
)
is
not
None
:
# If a new request replaces an old request which
# specified min_toks, we should update processor tensors
needs_update
=
True
if
self
.
min_toks
:
# Process removed requests.
for
index
in
batch_update
.
removed
:
if
self
.
min_toks
.
pop
(
index
,
None
):
needs_update
=
True
# Process moved requests, unidirectional (a->b) and
# swapped (a<->b)
for
a_index
,
b_index
,
direct
in
batch_update
.
moved
:
if
direct
==
MoveDirectionality
.
UNIDIRECTIONAL
:
if
(
a_entry
:
=
self
.
min_toks
.
pop
(
a_index
,
None
))
is
None
:
if
self
.
min_toks
.
pop
(
b_index
,
None
)
is
not
None
:
needs_update
=
True
else
:
self
.
min_toks
[
b_index
]
=
a_entry
needs_update
=
True
else
:
a_entry
=
self
.
min_toks
.
pop
(
a_index
,
None
)
if
(
b_entry
:
=
self
.
min_toks
.
pop
(
b_index
,
None
))
is
not
None
:
self
.
min_toks
[
a_index
]
=
b_entry
needs_update
=
True
if
a_entry
is
not
None
:
self
.
min_toks
[
b_index
]
=
a_entry
needs_update
=
True
def
update_state
(
self
,
batch_update
:
Optional
[
BatchUpdate
]):
needs_update
=
process_dict_updates
(
self
.
min_toks
,
batch_update
,
self
.
add_request
)
if
self
.
min_toks
:
if
self
.
min_toks
:
# Check for any requests that have attained their min tokens.
# Check for any requests that have attained their min tokens.
to_remove
=
tuple
(
index
for
index
,
(
min_toks
,
out_tok_ids
,
to_remove
=
tuple
(
index
for
index
,
(
min_toks
,
out_tok_ids
,
...
@@ -295,3 +230,44 @@ class MinTokensLogitsProcessor(LogitsProcessor):
...
@@ -295,3 +230,44 @@ class MinTokensLogitsProcessor(LogitsProcessor):
# Inhibit EOS token for requests which have not reached min length
# Inhibit EOS token for requests which have not reached min length
logits
[
self
.
logits_slice
]
=
-
float
(
"inf"
)
logits
[
self
.
logits_slice
]
=
-
float
(
"inf"
)
return
logits
return
logits
def
process_dict_updates
(
req_entries
:
dict
[
int
,
T
],
batch_update
:
Optional
[
BatchUpdate
],
new_state
:
Callable
[[
SamplingParams
,
list
[
int
],
list
[
int
]],
Optional
[
T
]]
)
->
bool
:
"""Utility function to update dict state for sparse LogitsProcessors."""
if
not
batch_update
:
# Nothing to do.
return
False
updated
=
False
for
index
,
params
,
prompt_tok_ids
,
output_tok_ids
in
batch_update
.
added
:
if
(
state
:
=
new_state
(
params
,
prompt_tok_ids
,
output_tok_ids
))
is
not
None
:
req_entries
[
index
]
=
state
updated
=
True
elif
req_entries
.
pop
(
index
,
None
)
is
not
None
:
updated
=
True
if
req_entries
:
# Process removed requests.
for
index
in
batch_update
.
removed
:
if
req_entries
.
pop
(
index
,
None
):
updated
=
True
# Process moved requests, unidirectional (a->b) and
# swapped (a<->b)
for
a_index
,
b_index
,
direct
in
batch_update
.
moved
:
a_entry
=
req_entries
.
pop
(
a_index
,
None
)
b_entry
=
req_entries
.
pop
(
b_index
,
None
)
if
a_entry
is
not
None
:
req_entries
[
b_index
]
=
a_entry
updated
=
True
if
b_entry
is
not
None
:
updated
=
True
if
direct
==
MoveDirectionality
.
SWAP
:
req_entries
[
a_index
]
=
b_entry
return
updated
vllm/v1/sample/logits_processor/interface.py
View file @
3ce8285d
...
@@ -44,10 +44,16 @@ class BatchUpdate:
...
@@ -44,10 +44,16 @@ class BatchUpdate:
# Key assumption: the `output_tok_ids` list (which is an element of each
# Key assumption: the `output_tok_ids` list (which is an element of each
# tuple in `added`) is a reference to the request's running output tokens
# tuple in `added`) is a reference to the request's running output tokens
# list; via this reference, the logits processors always see the latest
# list; via this reference, the logits processors always see the latest
# list of generated output tokens
# list of generated output tokens.
#
# NOTE:
# * Added or moved requests may replace existing requests with the same
# index.
# * Operations should be processed in the following order:
# - removed, added, moved
removed
:
Sequence
[
RemovedRequest
]
removed
:
Sequence
[
RemovedRequest
]
moved
:
Sequence
[
MovedRequest
]
added
:
Sequence
[
AddedRequest
]
added
:
Sequence
[
AddedRequest
]
moved
:
Sequence
[
MovedRequest
]
class
LogitsProcessor
(
ABC
):
class
LogitsProcessor
(
ABC
):
...
@@ -59,6 +65,11 @@ class LogitsProcessor(ABC):
...
@@ -59,6 +65,11 @@ class LogitsProcessor(ABC):
@
abstractmethod
@
abstractmethod
def
apply
(
self
,
logits
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
apply
(
self
,
logits
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""Apply LogitsProcessor to batch logits tensor.
The updated tensor must be returned but may be
modified in-place.
"""
raise
NotImplementedError
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment