Unverified Commit 43a26958 authored by Bhuvan Agrawal's avatar Bhuvan Agrawal Committed by GitHub
Browse files

feat: add logits processor support for trtllm backend (#2702)


Signed-off-by: default avatarBhuvan Agrawal <11240550+bhuvan002@users.noreply.github.com>
Co-authored-by: default avatarcoderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
parent 699996e4
......@@ -43,6 +43,7 @@ git checkout $(git describe --tags $(git rev-list --tags --max-count=1))
- [Client](#client)
- [Benchmarking](#benchmarking)
- [Multimodal Support](#multimodal-support)
- [Logits Processing](#logits-processing)
- [Performance Sweep](#performance-sweep)
## Feature Support Matrix
......@@ -242,6 +243,63 @@ To benchmark your deployment with GenAI-Perf, see this utility script, configuri
Dynamo with the TensorRT-LLM backend supports multimodal models, enabling you to process both text and images (or pre-computed embeddings) in a single request. For detailed setup instructions, example requests, and best practices, see the [Multimodal Support Guide](./multimodal_support.md).
## Logits Processing
Logits processors let you modify the next-token logits at every decoding step (e.g., to apply custom constraints or sampling transforms). Dynamo provides a backend-agnostic interface and an adapter for TensorRT-LLM so you can plug in custom processors.
### How it works
- **Interface**: Implement `dynamo.logits_processing.BaseLogitsProcessor` which defines `__call__(input_ids, logits)` and modifies `logits` in-place.
- **TRT-LLM adapter**: Use `dynamo.trtllm.logits_processing.adapter.create_trtllm_adapters(...)` to convert Dynamo processors into TRT-LLM-compatible processors and assign them to `SamplingParams.logits_processor`.
- **Examples**: See example processors in `lib/bindings/python/src/dynamo/logits_processing/examples/` ([temperature](../../../lib/bindings/python/src/dynamo/logits_processing/examples/temperature.py), [hello_world](../../../lib/bindings/python/src/dynamo/logits_processing/examples/hello_world.py)).
### Quick test: HelloWorld processor
You can enable a test-only processor that forces the model to respond with "Hello world!". This is useful to verify the wiring without modifying your model or engine code.
```bash
cd $DYNAMO_HOME/components/backends/trtllm
export DYNAMO_ENABLE_TEST_LOGITS_PROCESSOR=1
./launch/agg.sh
```
Notes:
- When enabled, Dynamo initializes the tokenizer so the HelloWorld processor can map text to token IDs.
- Expected chat response contains "Hello world".
### Bring your own processor
Implement a processor by conforming to `BaseLogitsProcessor` and modify logits in-place. For example, temperature scaling:
```python
from typing import Sequence
import torch
from dynamo.logits_processing import BaseLogitsProcessor
class TemperatureProcessor(BaseLogitsProcessor):
def __init__(self, temperature: float = 1.0):
if temperature <= 0:
raise ValueError("Temperature must be positive")
self.temperature = temperature
def __call__(self, input_ids: Sequence[int], logits: torch.Tensor):
if self.temperature == 1.0:
return
logits.div_(self.temperature)
```
Wire it into TRT-LLM by adapting and attaching to `SamplingParams`:
```python
from dynamo.trtllm.logits_processing.adapter import create_trtllm_adapters
from dynamo.logits_processing.examples import TemperatureProcessor
processors = [TemperatureProcessor(temperature=0.7)]
sampling_params.logits_processor = create_trtllm_adapters(processors)
```
### Current limitations
- Per-request processing only (batch size must be 1); beam width > 1 is not supported.
- Processors must modify logits in-place and not return a new tensor.
- If your processor needs tokenization, ensure the tokenizer is initialized (do not skip tokenizer init).
## Performance Sweep
For detailed instructions on running comprehensive performance sweeps across both aggregated and disaggregated serving configurations, see the [TensorRT-LLM Benchmark Scripts for DeepSeek R1 model](./performance_sweeps/README.md). This guide covers recommended benchmarking setups, usage of provided scripts, and best practices for evaluating system performance.
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
from typing import List, Optional
import torch
from tensorrt_llm.sampling_params import LogitsProcessor
from dynamo.logits_processing import BaseLogitsProcessor
logger = logging.getLogger(__name__)
class TrtllmDynamoLogitsAdapter(LogitsProcessor):
"""
Adapter that wraps Dynamo BaseLogitsProcessor instances to work with TensorRT-LLM's logits processor interface.
Inherits from tensorrt_llm.LogitsProcessor and implements the required interface:
__call__(self, req_ids: int, logits: torch.Tensor, ids: List[List[int]], stream_ptr, client_id: Optional[int])
This adapter maintains per-request state and converts between the interfaces.
"""
def __init__(self, processor: BaseLogitsProcessor):
super().__init__()
self.processor = processor
def __call__(
self,
req_ids: int,
logits: torch.Tensor,
ids: List[List[int]],
stream_ptr,
client_id: Optional[int] = None,
):
"""
TensorRT-LLM logits processor interface.
Args:
req_ids: Request identifier
logits: Logits tensor for current step
ids: List of token sequences (batch of sequences)
stream_ptr: CUDA stream pointer
client_id: Optional client identifier
Returns:
Modified logits tensor (in-place modification expected by TRT-LLM)
"""
stream = None if stream_ptr is None else torch.cuda.ExternalStream(stream_ptr)
try:
with torch.cuda.stream(stream):
if logits.shape[0] != 1:
raise ValueError(
f"This logits adapter only supports per-request logits processing. "
f"Received logits with batch size {logits.shape[0]} expected 1"
)
if logits.shape[1] != 1:
raise ValueError(
"Logits processing with beam width > 1 is not supported"
)
# Call the processor which modifies the logits in-place
self.processor(ids[0], logits[0, 0, :])
except Exception as e:
logger.error(f"Error in logits processor for request {req_ids}: {e}")
# Don't modify logits on error
# TRT-LLM expects void return (in-place modification)
def create_trtllm_adapters(
processors: List[BaseLogitsProcessor],
) -> List[TrtllmDynamoLogitsAdapter]:
"""
Create TensorRT-LLM compatible adapters from Dynamo logits processors.
Args:
processors: List of Dynamo BaseLogitsProcessor instances
Returns:
List of TensorRT-LLM compatible logits processor adapters
"""
adapters = []
for processor in processors:
adapter = TrtllmDynamoLogitsAdapter(processor)
adapters.append(adapter)
return adapters
......@@ -3,6 +3,7 @@
import asyncio
import logging
import os
import signal
import sys
......@@ -225,6 +226,12 @@ async def init(runtime: DistributedRuntime, config: Config):
modelType = ModelType.Backend
multimodal_processor = None
if os.getenv("DYNAMO_ENABLE_TEST_LOGITS_PROCESSOR") == "1":
# We need to initialize the tokenizer for the test logits processor
# But detokenizing still happens in the rust engine, so we do _not_ want
# to set default_sampling_params.detokenize to True.
engine_args["skip_tokenizer_init"] = False
if modality == "multimodal":
engine_args["skip_tokenizer_init"] = False
modelType = ModelType.Chat
......
......@@ -15,6 +15,7 @@
import copy
import logging
import os
from dataclasses import asdict, dataclass
from enum import Enum
from typing import Optional, Union
......@@ -23,9 +24,11 @@ import torch
from tensorrt_llm import SamplingParams
from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
from dynamo.logits_processing.examples import HelloWorldLogitsProcessor
from dynamo.nixl_connect import Connector
from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.trtllm.engine import TensorRTLLMEngine
from dynamo.trtllm.logits_processing.adapter import create_trtllm_adapters
from dynamo.trtllm.multimodal_processor import MultimodalRequestProcessor
from dynamo.trtllm.publisher import Publisher
from dynamo.trtllm.utils.disagg_utils import (
......@@ -182,6 +185,12 @@ class HandlerBase:
request_id = request.get("id") or request.get("request_id", "unknown-id")
model_name = request.get("model", "unknown_model")
# Optional test-only logits processing (enable with DYNAMO_ENABLE_TEST_LOGITS_PROCESSOR=1)
if os.getenv("DYNAMO_ENABLE_TEST_LOGITS_PROCESSOR") == "1":
processors = [HelloWorldLogitsProcessor(self.engine.llm.tokenizer)]
adapters = create_trtllm_adapters(processors)
sampling_params.logits_processor = adapters
# NEW: Updated engine call to include multimodal data
async for res in self.engine.llm.generate_async(
inputs=processed_input, # Use the correctly extracted inputs
......
......@@ -8,11 +8,12 @@ This module defines the core BaseLogitsProcessor interface that all
logits processors must implement.
"""
from typing import Protocol, Sequence
from typing import Protocol, Sequence, runtime_checkable
import torch
@runtime_checkable
class BaseLogitsProcessor(Protocol):
"""
Protocol for logits processors in Dynamo.
......@@ -25,7 +26,7 @@ class BaseLogitsProcessor(Protocol):
self,
input_ids: Sequence[int],
logits: torch.Tensor,
) -> torch.Tensor:
) -> None:
"""
Process the logits for the next token prediction.
......@@ -33,7 +34,6 @@ class BaseLogitsProcessor(Protocol):
input_ids: The input token IDs generated so far.
logits: The raw logits for the next token. Shape: (vocab_size,)
Returns:
A tensor with the same shape, dtype, and device as `logits`.
The processor is expected to modify the logits in-place.
"""
...
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from .hello_world import HelloWorldLogitsProcessor
from .temperature import TemperatureProcessor
__all__ = ["TemperatureProcessor", "HelloWorldLogitsProcessor"]
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Sequence
import torch
from transformers import PreTrainedTokenizerBase
from dynamo.logits_processing import BaseLogitsProcessor
RESPONSE = "Hello world!"
class HelloWorldLogitsProcessor(BaseLogitsProcessor):
"""
Sample Logits Processor that always outputs a hardcoded
response (`RESPONSE`), no matter the input
"""
def __init__(self, tokenizer: PreTrainedTokenizerBase):
self.tokenizer = tokenizer
self.token_ids = tokenizer.encode(RESPONSE, add_special_tokens=False)
self.eos_id = tokenizer.eos_token_id
if self.eos_id is None:
raise ValueError(
"Tokenizer has no eos_token_id; HelloWorldLogitsProcessor requires one."
)
self.state = 0
def __call__(self, input_ids: Sequence[int], scores: torch.Tensor):
mask = torch.full_like(scores, float("-inf"))
if self.state < len(self.token_ids):
token_idx = self.token_ids[self.state]
else:
token_idx = self.eos_id
# Allow only a single token to be output
mask[token_idx] = 0.0
# The `scores` tensor *must* also be modified in-place
scores.add_(mask)
self.state += 1
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Sequence
import torch
from dynamo.logits_processing import BaseLogitsProcessor
class TemperatureProcessor(BaseLogitsProcessor):
"""
Example logits processor that applies temperature scaling.
This is a simple demonstration of how to implement a logits processor
that can be used with any Dynamo backend.
"""
def __init__(self, temperature: float = 1.0):
"""
Args:
temperature: Scaling factor. Higher values make distribution more uniform,
lower values make it more peaked. Must be positive.
"""
if temperature <= 0:
raise ValueError("Temperature must be positive")
self.temperature = temperature
def __call__(self, input_ids: Sequence[int], logits: torch.Tensor):
"""
Apply temperature scaling to logits.
Args:
input_ids: Token IDs generated so far (unused in this simple example)
logits: Raw logits tensor from model
The processor is expected to modify the logits in-place.
"""
if self.temperature == 1.0:
return
logits.div_(self.temperature)
......@@ -3,6 +3,7 @@
"""Common base classes and utilities for engine tests (vLLM, TRT-LLM, etc.)"""
import os
from dataclasses import dataclass
from typing import Any, Callable, List
......@@ -32,6 +33,11 @@ def create_payload_for_config(config: EngineConfig) -> Payload:
This provides the default implementation for text-only models.
"""
expected_response = (
["Hello world"]
if os.getenv("DYNAMO_ENABLE_TEST_LOGITS_PROCESSOR") == "1"
else ["AI"]
)
return Payload(
payload_chat={
"model": config.model,
......@@ -54,5 +60,5 @@ def create_payload_for_config(config: EngineConfig) -> Payload:
},
repeat_count=3,
expected_log=[],
expected_response=["AI"],
expected_response=expected_response,
)
......@@ -57,6 +57,33 @@ class TRTLLMProcess(EngineProcess):
)
def run_trtllm_test_case(config: TRTLLMConfig, request) -> None:
payload = create_payload_for_config(config)
with TRTLLMProcess(config, request) as server_process:
assert len(config.endpoints) == len(config.response_handlers)
for endpoint, response_handler in zip(
config.endpoints, config.response_handlers
):
url = f"http://localhost:{server_process.port}/{endpoint}"
start_time = time.time()
elapsed = 0.0
request_body = (
payload.payload_chat
if endpoint == "v1/chat/completions"
else payload.payload_completions
)
for _ in range(payload.repeat_count):
elapsed = time.time() - start_time
response = server_process.send_request(
url, payload=request_body, timeout=config.timeout - elapsed
)
server_process.check_response(payload, response, response_handler)
# trtllm test configurations
trtllm_configs = {
"aggregated": TRTLLMConfig(
......@@ -137,33 +164,9 @@ def test_deployment(trtllm_config_test, request, runtime_services):
logger.info("Starting test_deployment")
config = trtllm_config_test
payload = create_payload_for_config(config)
logger.info(f"Using model: {config.model}")
logger.info(f"Script: {config.script_name}")
with TRTLLMProcess(config, request) as server_process:
assert len(config.endpoints) == len(config.response_handlers)
for endpoint, response_handler in zip(
config.endpoints, config.response_handlers
):
url = f"http://localhost:{server_process.port}/{endpoint}"
start_time = time.time()
elapsed = 0.0
request_body = (
payload.payload_chat
if endpoint == "v1/chat/completions"
else payload.payload_completions
)
for _ in range(payload.repeat_count):
elapsed = time.time() - start_time
response = server_process.send_request(
url, payload=request_body, timeout=config.timeout - elapsed
)
server_process.check_response(payload, response, response_handler)
run_trtllm_test_case(config, request)
@pytest.mark.e2e
......@@ -331,3 +334,34 @@ def test_metrics_labels(request, runtime_services):
except subprocess.TimeoutExpired:
process.kill()
process.wait()
@pytest.mark.e2e
@pytest.mark.gpu_1
@pytest.mark.trtllm_marker
@pytest.mark.slow
def test_chat_only_aggregated_with_test_logits_processor(
request, runtime_services, monkeypatch
):
"""
Run a single aggregated chat-completions test using Qwen 0.6B with the
test logits processor enabled, and expect "Hello world" in the response.
"""
# Enable HelloWorld logits processor only for this test
monkeypatch.setenv("DYNAMO_ENABLE_TEST_LOGITS_PROCESSOR", "1")
base = trtllm_configs["aggregated"]
config = TRTLLMConfig(
name="aggregated_qwen_chatonly",
directory=base.directory,
script_name=base.script_name, # agg.sh
marks=[], # not used by this direct test
endpoints=["v1/chat/completions"],
response_handlers=[chat_completions_response_handler],
model="Qwen/Qwen3-0.6B",
delayed_start=base.delayed_start,
timeout=base.timeout,
)
run_trtllm_test_case(config, request)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment