Unverified Commit 43a26958 authored by Bhuvan Agrawal's avatar Bhuvan Agrawal Committed by GitHub
Browse files

feat: add logits processor support for trtllm backend (#2702)


Signed-off-by: default avatarBhuvan Agrawal <11240550+bhuvan002@users.noreply.github.com>
Co-authored-by: default avatarcoderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
parent 699996e4
...@@ -43,6 +43,7 @@ git checkout $(git describe --tags $(git rev-list --tags --max-count=1)) ...@@ -43,6 +43,7 @@ git checkout $(git describe --tags $(git rev-list --tags --max-count=1))
- [Client](#client) - [Client](#client)
- [Benchmarking](#benchmarking) - [Benchmarking](#benchmarking)
- [Multimodal Support](#multimodal-support) - [Multimodal Support](#multimodal-support)
- [Logits Processing](#logits-processing)
- [Performance Sweep](#performance-sweep) - [Performance Sweep](#performance-sweep)
## Feature Support Matrix ## Feature Support Matrix
...@@ -242,6 +243,63 @@ To benchmark your deployment with GenAI-Perf, see this utility script, configuri ...@@ -242,6 +243,63 @@ To benchmark your deployment with GenAI-Perf, see this utility script, configuri
Dynamo with the TensorRT-LLM backend supports multimodal models, enabling you to process both text and images (or pre-computed embeddings) in a single request. For detailed setup instructions, example requests, and best practices, see the [Multimodal Support Guide](./multimodal_support.md). Dynamo with the TensorRT-LLM backend supports multimodal models, enabling you to process both text and images (or pre-computed embeddings) in a single request. For detailed setup instructions, example requests, and best practices, see the [Multimodal Support Guide](./multimodal_support.md).
## Logits Processing
Logits processors let you modify the next-token logits at every decoding step (e.g., to apply custom constraints or sampling transforms). Dynamo provides a backend-agnostic interface and an adapter for TensorRT-LLM so you can plug in custom processors.
### How it works
- **Interface**: Implement `dynamo.logits_processing.BaseLogitsProcessor` which defines `__call__(input_ids, logits)` and modifies `logits` in-place.
- **TRT-LLM adapter**: Use `dynamo.trtllm.logits_processing.adapter.create_trtllm_adapters(...)` to convert Dynamo processors into TRT-LLM-compatible processors and assign them to `SamplingParams.logits_processor`.
- **Examples**: See example processors in `lib/bindings/python/src/dynamo/logits_processing/examples/` ([temperature](../../../lib/bindings/python/src/dynamo/logits_processing/examples/temperature.py), [hello_world](../../../lib/bindings/python/src/dynamo/logits_processing/examples/hello_world.py)).
### Quick test: HelloWorld processor
You can enable a test-only processor that forces the model to respond with "Hello world!". This is useful to verify the wiring without modifying your model or engine code.
```bash
cd $DYNAMO_HOME/components/backends/trtllm
export DYNAMO_ENABLE_TEST_LOGITS_PROCESSOR=1
./launch/agg.sh
```
Notes:
- When enabled, Dynamo initializes the tokenizer so the HelloWorld processor can map text to token IDs.
- Expected chat response contains "Hello world".
### Bring your own processor
Implement a processor by conforming to `BaseLogitsProcessor` and modify logits in-place. For example, temperature scaling:
```python
from typing import Sequence
import torch
from dynamo.logits_processing import BaseLogitsProcessor
class TemperatureProcessor(BaseLogitsProcessor):
def __init__(self, temperature: float = 1.0):
if temperature <= 0:
raise ValueError("Temperature must be positive")
self.temperature = temperature
def __call__(self, input_ids: Sequence[int], logits: torch.Tensor):
if self.temperature == 1.0:
return
logits.div_(self.temperature)
```
Wire it into TRT-LLM by adapting and attaching to `SamplingParams`:
```python
from dynamo.trtllm.logits_processing.adapter import create_trtllm_adapters
from dynamo.logits_processing.examples import TemperatureProcessor
processors = [TemperatureProcessor(temperature=0.7)]
sampling_params.logits_processor = create_trtllm_adapters(processors)
```
### Current limitations
- Per-request processing only (batch size must be 1); beam width > 1 is not supported.
- Processors must modify logits in-place and not return a new tensor.
- If your processor needs tokenization, ensure the tokenizer is initialized (do not skip tokenizer init).
## Performance Sweep ## Performance Sweep
For detailed instructions on running comprehensive performance sweeps across both aggregated and disaggregated serving configurations, see the [TensorRT-LLM Benchmark Scripts for DeepSeek R1 model](./performance_sweeps/README.md). This guide covers recommended benchmarking setups, usage of provided scripts, and best practices for evaluating system performance. For detailed instructions on running comprehensive performance sweeps across both aggregated and disaggregated serving configurations, see the [TensorRT-LLM Benchmark Scripts for DeepSeek R1 model](./performance_sweeps/README.md). This guide covers recommended benchmarking setups, usage of provided scripts, and best practices for evaluating system performance.
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
from typing import List, Optional
import torch
from tensorrt_llm.sampling_params import LogitsProcessor
from dynamo.logits_processing import BaseLogitsProcessor
logger = logging.getLogger(__name__)
class TrtllmDynamoLogitsAdapter(LogitsProcessor):
"""
Adapter that wraps Dynamo BaseLogitsProcessor instances to work with TensorRT-LLM's logits processor interface.
Inherits from tensorrt_llm.LogitsProcessor and implements the required interface:
__call__(self, req_ids: int, logits: torch.Tensor, ids: List[List[int]], stream_ptr, client_id: Optional[int])
This adapter maintains per-request state and converts between the interfaces.
"""
def __init__(self, processor: BaseLogitsProcessor):
super().__init__()
self.processor = processor
def __call__(
self,
req_ids: int,
logits: torch.Tensor,
ids: List[List[int]],
stream_ptr,
client_id: Optional[int] = None,
):
"""
TensorRT-LLM logits processor interface.
Args:
req_ids: Request identifier
logits: Logits tensor for current step
ids: List of token sequences (batch of sequences)
stream_ptr: CUDA stream pointer
client_id: Optional client identifier
Returns:
Modified logits tensor (in-place modification expected by TRT-LLM)
"""
stream = None if stream_ptr is None else torch.cuda.ExternalStream(stream_ptr)
try:
with torch.cuda.stream(stream):
if logits.shape[0] != 1:
raise ValueError(
f"This logits adapter only supports per-request logits processing. "
f"Received logits with batch size {logits.shape[0]} expected 1"
)
if logits.shape[1] != 1:
raise ValueError(
"Logits processing with beam width > 1 is not supported"
)
# Call the processor which modifies the logits in-place
self.processor(ids[0], logits[0, 0, :])
except Exception as e:
logger.error(f"Error in logits processor for request {req_ids}: {e}")
# Don't modify logits on error
# TRT-LLM expects void return (in-place modification)
def create_trtllm_adapters(
processors: List[BaseLogitsProcessor],
) -> List[TrtllmDynamoLogitsAdapter]:
"""
Create TensorRT-LLM compatible adapters from Dynamo logits processors.
Args:
processors: List of Dynamo BaseLogitsProcessor instances
Returns:
List of TensorRT-LLM compatible logits processor adapters
"""
adapters = []
for processor in processors:
adapter = TrtllmDynamoLogitsAdapter(processor)
adapters.append(adapter)
return adapters
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import asyncio import asyncio
import logging import logging
import os
import signal import signal
import sys import sys
...@@ -225,6 +226,12 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -225,6 +226,12 @@ async def init(runtime: DistributedRuntime, config: Config):
modelType = ModelType.Backend modelType = ModelType.Backend
multimodal_processor = None multimodal_processor = None
if os.getenv("DYNAMO_ENABLE_TEST_LOGITS_PROCESSOR") == "1":
# We need to initialize the tokenizer for the test logits processor
# But detokenizing still happens in the rust engine, so we do _not_ want
# to set default_sampling_params.detokenize to True.
engine_args["skip_tokenizer_init"] = False
if modality == "multimodal": if modality == "multimodal":
engine_args["skip_tokenizer_init"] = False engine_args["skip_tokenizer_init"] = False
modelType = ModelType.Chat modelType = ModelType.Chat
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import copy import copy
import logging import logging
import os
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass
from enum import Enum from enum import Enum
from typing import Optional, Union from typing import Optional, Union
...@@ -23,9 +24,11 @@ import torch ...@@ -23,9 +24,11 @@ import torch
from tensorrt_llm import SamplingParams from tensorrt_llm import SamplingParams
from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
from dynamo.logits_processing.examples import HelloWorldLogitsProcessor
from dynamo.nixl_connect import Connector from dynamo.nixl_connect import Connector
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.trtllm.engine import TensorRTLLMEngine from dynamo.trtllm.engine import TensorRTLLMEngine
from dynamo.trtllm.logits_processing.adapter import create_trtllm_adapters
from dynamo.trtllm.multimodal_processor import MultimodalRequestProcessor from dynamo.trtllm.multimodal_processor import MultimodalRequestProcessor
from dynamo.trtllm.publisher import Publisher from dynamo.trtllm.publisher import Publisher
from dynamo.trtllm.utils.disagg_utils import ( from dynamo.trtllm.utils.disagg_utils import (
...@@ -182,6 +185,12 @@ class HandlerBase: ...@@ -182,6 +185,12 @@ class HandlerBase:
request_id = request.get("id") or request.get("request_id", "unknown-id") request_id = request.get("id") or request.get("request_id", "unknown-id")
model_name = request.get("model", "unknown_model") model_name = request.get("model", "unknown_model")
# Optional test-only logits processing (enable with DYNAMO_ENABLE_TEST_LOGITS_PROCESSOR=1)
if os.getenv("DYNAMO_ENABLE_TEST_LOGITS_PROCESSOR") == "1":
processors = [HelloWorldLogitsProcessor(self.engine.llm.tokenizer)]
adapters = create_trtllm_adapters(processors)
sampling_params.logits_processor = adapters
# NEW: Updated engine call to include multimodal data # NEW: Updated engine call to include multimodal data
async for res in self.engine.llm.generate_async( async for res in self.engine.llm.generate_async(
inputs=processed_input, # Use the correctly extracted inputs inputs=processed_input, # Use the correctly extracted inputs
......
...@@ -8,11 +8,12 @@ This module defines the core BaseLogitsProcessor interface that all ...@@ -8,11 +8,12 @@ This module defines the core BaseLogitsProcessor interface that all
logits processors must implement. logits processors must implement.
""" """
from typing import Protocol, Sequence from typing import Protocol, Sequence, runtime_checkable
import torch import torch
@runtime_checkable
class BaseLogitsProcessor(Protocol): class BaseLogitsProcessor(Protocol):
""" """
Protocol for logits processors in Dynamo. Protocol for logits processors in Dynamo.
...@@ -25,7 +26,7 @@ class BaseLogitsProcessor(Protocol): ...@@ -25,7 +26,7 @@ class BaseLogitsProcessor(Protocol):
self, self,
input_ids: Sequence[int], input_ids: Sequence[int],
logits: torch.Tensor, logits: torch.Tensor,
) -> torch.Tensor: ) -> None:
""" """
Process the logits for the next token prediction. Process the logits for the next token prediction.
...@@ -33,7 +34,6 @@ class BaseLogitsProcessor(Protocol): ...@@ -33,7 +34,6 @@ class BaseLogitsProcessor(Protocol):
input_ids: The input token IDs generated so far. input_ids: The input token IDs generated so far.
logits: The raw logits for the next token. Shape: (vocab_size,) logits: The raw logits for the next token. Shape: (vocab_size,)
Returns: The processor is expected to modify the logits in-place.
A tensor with the same shape, dtype, and device as `logits`.
""" """
... ...
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from .hello_world import HelloWorldLogitsProcessor
from .temperature import TemperatureProcessor
__all__ = ["TemperatureProcessor", "HelloWorldLogitsProcessor"]
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Sequence
import torch
from transformers import PreTrainedTokenizerBase
from dynamo.logits_processing import BaseLogitsProcessor
RESPONSE = "Hello world!"
class HelloWorldLogitsProcessor(BaseLogitsProcessor):
"""
Sample Logits Processor that always outputs a hardcoded
response (`RESPONSE`), no matter the input
"""
def __init__(self, tokenizer: PreTrainedTokenizerBase):
self.tokenizer = tokenizer
self.token_ids = tokenizer.encode(RESPONSE, add_special_tokens=False)
self.eos_id = tokenizer.eos_token_id
if self.eos_id is None:
raise ValueError(
"Tokenizer has no eos_token_id; HelloWorldLogitsProcessor requires one."
)
self.state = 0
def __call__(self, input_ids: Sequence[int], scores: torch.Tensor):
mask = torch.full_like(scores, float("-inf"))
if self.state < len(self.token_ids):
token_idx = self.token_ids[self.state]
else:
token_idx = self.eos_id
# Allow only a single token to be output
mask[token_idx] = 0.0
# The `scores` tensor *must* also be modified in-place
scores.add_(mask)
self.state += 1
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Sequence
import torch
from dynamo.logits_processing import BaseLogitsProcessor
class TemperatureProcessor(BaseLogitsProcessor):
"""
Example logits processor that applies temperature scaling.
This is a simple demonstration of how to implement a logits processor
that can be used with any Dynamo backend.
"""
def __init__(self, temperature: float = 1.0):
"""
Args:
temperature: Scaling factor. Higher values make distribution more uniform,
lower values make it more peaked. Must be positive.
"""
if temperature <= 0:
raise ValueError("Temperature must be positive")
self.temperature = temperature
def __call__(self, input_ids: Sequence[int], logits: torch.Tensor):
"""
Apply temperature scaling to logits.
Args:
input_ids: Token IDs generated so far (unused in this simple example)
logits: Raw logits tensor from model
The processor is expected to modify the logits in-place.
"""
if self.temperature == 1.0:
return
logits.div_(self.temperature)
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
"""Common base classes and utilities for engine tests (vLLM, TRT-LLM, etc.)""" """Common base classes and utilities for engine tests (vLLM, TRT-LLM, etc.)"""
import os
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Callable, List from typing import Any, Callable, List
...@@ -32,6 +33,11 @@ def create_payload_for_config(config: EngineConfig) -> Payload: ...@@ -32,6 +33,11 @@ def create_payload_for_config(config: EngineConfig) -> Payload:
This provides the default implementation for text-only models. This provides the default implementation for text-only models.
""" """
expected_response = (
["Hello world"]
if os.getenv("DYNAMO_ENABLE_TEST_LOGITS_PROCESSOR") == "1"
else ["AI"]
)
return Payload( return Payload(
payload_chat={ payload_chat={
"model": config.model, "model": config.model,
...@@ -54,5 +60,5 @@ def create_payload_for_config(config: EngineConfig) -> Payload: ...@@ -54,5 +60,5 @@ def create_payload_for_config(config: EngineConfig) -> Payload:
}, },
repeat_count=3, repeat_count=3,
expected_log=[], expected_log=[],
expected_response=["AI"], expected_response=expected_response,
) )
...@@ -57,6 +57,33 @@ class TRTLLMProcess(EngineProcess): ...@@ -57,6 +57,33 @@ class TRTLLMProcess(EngineProcess):
) )
def run_trtllm_test_case(config: TRTLLMConfig, request) -> None:
payload = create_payload_for_config(config)
with TRTLLMProcess(config, request) as server_process:
assert len(config.endpoints) == len(config.response_handlers)
for endpoint, response_handler in zip(
config.endpoints, config.response_handlers
):
url = f"http://localhost:{server_process.port}/{endpoint}"
start_time = time.time()
elapsed = 0.0
request_body = (
payload.payload_chat
if endpoint == "v1/chat/completions"
else payload.payload_completions
)
for _ in range(payload.repeat_count):
elapsed = time.time() - start_time
response = server_process.send_request(
url, payload=request_body, timeout=config.timeout - elapsed
)
server_process.check_response(payload, response, response_handler)
# trtllm test configurations # trtllm test configurations
trtllm_configs = { trtllm_configs = {
"aggregated": TRTLLMConfig( "aggregated": TRTLLMConfig(
...@@ -137,33 +164,9 @@ def test_deployment(trtllm_config_test, request, runtime_services): ...@@ -137,33 +164,9 @@ def test_deployment(trtllm_config_test, request, runtime_services):
logger.info("Starting test_deployment") logger.info("Starting test_deployment")
config = trtllm_config_test config = trtllm_config_test
payload = create_payload_for_config(config)
logger.info(f"Using model: {config.model}") logger.info(f"Using model: {config.model}")
logger.info(f"Script: {config.script_name}") logger.info(f"Script: {config.script_name}")
run_trtllm_test_case(config, request)
with TRTLLMProcess(config, request) as server_process:
assert len(config.endpoints) == len(config.response_handlers)
for endpoint, response_handler in zip(
config.endpoints, config.response_handlers
):
url = f"http://localhost:{server_process.port}/{endpoint}"
start_time = time.time()
elapsed = 0.0
request_body = (
payload.payload_chat
if endpoint == "v1/chat/completions"
else payload.payload_completions
)
for _ in range(payload.repeat_count):
elapsed = time.time() - start_time
response = server_process.send_request(
url, payload=request_body, timeout=config.timeout - elapsed
)
server_process.check_response(payload, response, response_handler)
@pytest.mark.e2e @pytest.mark.e2e
...@@ -331,3 +334,34 @@ def test_metrics_labels(request, runtime_services): ...@@ -331,3 +334,34 @@ def test_metrics_labels(request, runtime_services):
except subprocess.TimeoutExpired: except subprocess.TimeoutExpired:
process.kill() process.kill()
process.wait() process.wait()
@pytest.mark.e2e
@pytest.mark.gpu_1
@pytest.mark.trtllm_marker
@pytest.mark.slow
def test_chat_only_aggregated_with_test_logits_processor(
request, runtime_services, monkeypatch
):
"""
Run a single aggregated chat-completions test using Qwen 0.6B with the
test logits processor enabled, and expect "Hello world" in the response.
"""
# Enable HelloWorld logits processor only for this test
monkeypatch.setenv("DYNAMO_ENABLE_TEST_LOGITS_PROCESSOR", "1")
base = trtllm_configs["aggregated"]
config = TRTLLMConfig(
name="aggregated_qwen_chatonly",
directory=base.directory,
script_name=base.script_name, # agg.sh
marks=[], # not used by this direct test
endpoints=["v1/chat/completions"],
response_handlers=[chat_completions_response_handler],
model="Qwen/Qwen3-0.6B",
delayed_start=base.delayed_start,
timeout=base.timeout,
)
run_trtllm_test_case(config, request)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment