Unverified Commit b3a0d01e authored by Aviv Keshet's avatar Aviv Keshet Committed by GitHub
Browse files

[Core] add and implement `VLLM_LOGITS_PROCESSOR_THREADS` (#12368)


Signed-off-by: default avatarAviv Keshet <akeshet@scaledcognition.com>
parent 75e94309
...@@ -31,6 +31,7 @@ if TYPE_CHECKING: ...@@ -31,6 +31,7 @@ if TYPE_CHECKING:
VLLM_LOGGING_LEVEL: str = "INFO" VLLM_LOGGING_LEVEL: str = "INFO"
VLLM_LOGGING_PREFIX: str = "" VLLM_LOGGING_PREFIX: str = ""
VLLM_LOGGING_CONFIG_PATH: Optional[str] = None VLLM_LOGGING_CONFIG_PATH: Optional[str] = None
VLLM_LOGITS_PROCESSOR_THREADS: Optional[int] = None
VLLM_TRACE_FUNCTION: int = 0 VLLM_TRACE_FUNCTION: int = 0
VLLM_ATTENTION_BACKEND: Optional[str] = None VLLM_ATTENTION_BACKEND: Optional[str] = None
VLLM_USE_FLASHINFER_SAMPLER: Optional[bool] = None VLLM_USE_FLASHINFER_SAMPLER: Optional[bool] = None
...@@ -282,6 +283,14 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -282,6 +283,14 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_LOGGING_PREFIX": "VLLM_LOGGING_PREFIX":
lambda: os.getenv("VLLM_LOGGING_PREFIX", ""), lambda: os.getenv("VLLM_LOGGING_PREFIX", ""),
# if set, vllm will call logits processors in a thread pool with this many
# threads. This is useful when using custom logits processors that either
# (a) launch additional CUDA kernels or (b) do significant CPU-bound work
# while not holding the python GIL, or both.
"VLLM_LOGITS_PROCESSOR_THREADS":
lambda: int(os.getenv("VLLM_LOGITS_PROCESSOR_THREADS", "0"))
if "VLLM_LOGITS_PROCESSOR_THREADS" in os.environ else None,
# Trace function calls # Trace function calls
# If set to 1, vllm will trace function calls # If set to 1, vllm will trace function calls
# Useful for debugging # Useful for debugging
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""A layer that compute logits from hidden_stats.""" """A layer that compute logits from hidden_stats."""
import inspect import inspect
from concurrent.futures import ThreadPoolExecutor
from typing import Optional from typing import Optional
import torch import torch
...@@ -15,6 +16,11 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -15,6 +16,11 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform from vllm.platforms import current_platform
_logits_processor_threadpool: Optional[ThreadPoolExecutor] = None
if envs.VLLM_LOGITS_PROCESSOR_THREADS is not None:
_logits_processor_threadpool = ThreadPoolExecutor(
envs.VLLM_LOGITS_PROCESSOR_THREADS)
class LogitsProcessor(nn.Module): class LogitsProcessor(nn.Module):
"""Process logits and apply logits processors from sampling metadata. """Process logits and apply logits processors from sampling metadata.
...@@ -135,6 +141,7 @@ def _apply_logits_processors( ...@@ -135,6 +141,7 @@ def _apply_logits_processors(
) -> torch.Tensor: ) -> torch.Tensor:
found_logits_processors = False found_logits_processors = False
logits_processed = 0 logits_processed = 0
logits_row_ids_and_logits_row_futures = []
for seq_group in sampling_metadata.seq_groups: for seq_group in sampling_metadata.seq_groups:
seq_ids = seq_group.seq_ids seq_ids = seq_group.seq_ids
sampling_params = seq_group.sampling_params sampling_params = seq_group.sampling_params
...@@ -148,22 +155,39 @@ def _apply_logits_processors( ...@@ -148,22 +155,39 @@ def _apply_logits_processors(
past_tokens_ids = seq_group.seq_data[seq_id].output_token_ids past_tokens_ids = seq_group.seq_data[seq_id].output_token_ids
prompt_tokens_ids = seq_group.seq_data[seq_id].prompt_token_ids prompt_tokens_ids = seq_group.seq_data[seq_id].prompt_token_ids
for logits_processor in logits_processors: if _logits_processor_threadpool is not None:
parameters = inspect.signature(logits_processor).parameters logits_row_ids_and_logits_row_futures.append(
if len(parameters) == 3: (logits_row_idx,
logits_row = logits_processor(prompt_tokens_ids, _logits_processor_threadpool.submit(
past_tokens_ids, _apply_logits_processors_single_seq, logits_row,
logits_row) logits_processors, past_tokens_ids,
prompt_tokens_ids)))
else: else:
logits_row = logits_processor(past_tokens_ids, logits[logits_row_idx] = \
logits_row) _apply_logits_processors_single_seq(
logits_row, logits_processors, past_tokens_ids,
logits[logits_row_idx] = logits_row prompt_tokens_ids)
logits_processed += len(seq_group.sample_indices) + len( logits_processed += len(seq_group.sample_indices) + len(
seq_group.prompt_logprob_indices) seq_group.prompt_logprob_indices)
for logits_row_idx, future in logits_row_ids_and_logits_row_futures:
logits[logits_row_idx] = future.result()
if found_logits_processors: if found_logits_processors:
# verifies that no rows in logits were missed unexpectedly # verifies that no rows in logits were missed unexpectedly
assert logits_processed == logits.shape[0] assert logits_processed == logits.shape[0]
return logits return logits
def _apply_logits_processors_single_seq(logits_row, logits_processors,
past_tokens_ids,
prompt_tokens_ids) -> torch.Tensor:
for logits_processor in logits_processors:
parameters = inspect.signature(logits_processor).parameters
if len(parameters) == 3:
logits_row = logits_processor(prompt_tokens_ids, past_tokens_ids,
logits_row)
else:
logits_row = logits_processor(past_tokens_ids, logits_row)
return logits_row
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment