Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b3a0d01e
Unverified
Commit
b3a0d01e
authored
Feb 04, 2025
by
Aviv Keshet
Committed by
GitHub
Feb 04, 2025
Browse files
[Core] add and implement `VLLM_LOGITS_PROCESSOR_THREADS` (#12368)
Signed-off-by:
Aviv Keshet
<
akeshet@scaledcognition.com
>
parent
75e94309
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
44 additions
and
11 deletions
+44
-11
vllm/envs.py
vllm/envs.py
+9
-0
vllm/model_executor/layers/logits_processor.py
vllm/model_executor/layers/logits_processor.py
+35
-11
No files found.
vllm/envs.py
View file @
b3a0d01e
...
...
@@ -31,6 +31,7 @@ if TYPE_CHECKING:
VLLM_LOGGING_LEVEL
:
str
=
"INFO"
VLLM_LOGGING_PREFIX
:
str
=
""
VLLM_LOGGING_CONFIG_PATH
:
Optional
[
str
]
=
None
VLLM_LOGITS_PROCESSOR_THREADS
:
Optional
[
int
]
=
None
VLLM_TRACE_FUNCTION
:
int
=
0
VLLM_ATTENTION_BACKEND
:
Optional
[
str
]
=
None
VLLM_USE_FLASHINFER_SAMPLER
:
Optional
[
bool
]
=
None
...
...
@@ -282,6 +283,14 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_LOGGING_PREFIX"
:
lambda
:
os
.
getenv
(
"VLLM_LOGGING_PREFIX"
,
""
),
# if set, vllm will call logits processors in a thread pool with this many
# threads. This is useful when using custom logits processors that either
# (a) launch additional CUDA kernels or (b) do significant CPU-bound work
# while not holding the python GIL, or both.
"VLLM_LOGITS_PROCESSOR_THREADS"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_LOGITS_PROCESSOR_THREADS"
,
"0"
))
if
"VLLM_LOGITS_PROCESSOR_THREADS"
in
os
.
environ
else
None
,
# Trace function calls
# If set to 1, vllm will trace function calls
# Useful for debugging
...
...
vllm/model_executor/layers/logits_processor.py
View file @
b3a0d01e
# SPDX-License-Identifier: Apache-2.0
"""A layer that compute logits from hidden_stats."""
import
inspect
from
concurrent.futures
import
ThreadPoolExecutor
from
typing
import
Optional
import
torch
...
...
@@ -15,6 +16,11 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.platforms
import
current_platform
_logits_processor_threadpool
:
Optional
[
ThreadPoolExecutor
]
=
None
if
envs
.
VLLM_LOGITS_PROCESSOR_THREADS
is
not
None
:
_logits_processor_threadpool
=
ThreadPoolExecutor
(
envs
.
VLLM_LOGITS_PROCESSOR_THREADS
)
class
LogitsProcessor
(
nn
.
Module
):
"""Process logits and apply logits processors from sampling metadata.
...
...
@@ -135,6 +141,7 @@ def _apply_logits_processors(
)
->
torch
.
Tensor
:
found_logits_processors
=
False
logits_processed
=
0
logits_row_ids_and_logits_row_futures
=
[]
for
seq_group
in
sampling_metadata
.
seq_groups
:
seq_ids
=
seq_group
.
seq_ids
sampling_params
=
seq_group
.
sampling_params
...
...
@@ -148,22 +155,39 @@ def _apply_logits_processors(
past_tokens_ids
=
seq_group
.
seq_data
[
seq_id
].
output_token_ids
prompt_tokens_ids
=
seq_group
.
seq_data
[
seq_id
].
prompt_token_ids
for
logits_processor
in
logits_processors
:
parameters
=
inspect
.
signature
(
logits_processor
).
parameters
if
len
(
parameters
)
==
3
:
logits_row
=
logits_processor
(
prompt_tokens_ids
,
past_tokens_ids
,
logits_row
)
else
:
logits_row
=
logits_processor
(
past_tokens_ids
,
logits_row
)
logits
[
logits_row_idx
]
=
logits_row
if
_logits_processor_threadpool
is
not
None
:
logits_row_ids_and_logits_row_futures
.
append
(
(
logits_row_idx
,
_logits_processor_threadpool
.
submit
(
_apply_logits_processors_single_seq
,
logits_row
,
logits_processors
,
past_tokens_ids
,
prompt_tokens_ids
)))
else
:
logits
[
logits_row_idx
]
=
\
_apply_logits_processors_single_seq
(
logits_row
,
logits_processors
,
past_tokens_ids
,
prompt_tokens_ids
)
logits_processed
+=
len
(
seq_group
.
sample_indices
)
+
len
(
seq_group
.
prompt_logprob_indices
)
for
logits_row_idx
,
future
in
logits_row_ids_and_logits_row_futures
:
logits
[
logits_row_idx
]
=
future
.
result
()
if
found_logits_processors
:
# verifies that no rows in logits were missed unexpectedly
assert
logits_processed
==
logits
.
shape
[
0
]
return
logits
def
_apply_logits_processors_single_seq
(
logits_row
,
logits_processors
,
past_tokens_ids
,
prompt_tokens_ids
)
->
torch
.
Tensor
:
for
logits_processor
in
logits_processors
:
parameters
=
inspect
.
signature
(
logits_processor
).
parameters
if
len
(
parameters
)
==
3
:
logits_row
=
logits_processor
(
prompt_tokens_ids
,
past_tokens_ids
,
logits_row
)
else
:
logits_row
=
logits_processor
(
past_tokens_ids
,
logits_row
)
return
logits_row
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment