Unverified Commit b92a805e authored by Bhuvan Agrawal's avatar Bhuvan Agrawal Committed by GitHub
Browse files

feat: add BaseLogitsProcessor core interface (#2613)


Signed-off-by: default avatarBhuvan Agrawal <11240550+bhuvan002@users.noreply.github.com>
parent 0a71aea6
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
Dynamo Logits Processing - Backend-agnostic logits processors.
This module provides the BaseLogitsProcessor protocol that can be used
across different backend adapters (TRT-LLM, vLLM, SGLang).
"""
from .base import BaseLogitsProcessor
__all__ = ["BaseLogitsProcessor"]
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
Base logits processor protocol for Dynamo.
This module defines the core BaseLogitsProcessor interface that all
logits processors must implement.
"""
from typing import Protocol, Sequence
import torch
class BaseLogitsProcessor(Protocol):
"""
Protocol for logits processors in Dynamo.
All logits processors must implement this interface to be compatible
with backend adapters (TRT-LLM, vLLM, SGLang).
"""
def __call__(
self,
input_ids: Sequence[int],
logits: torch.Tensor,
) -> torch.Tensor:
"""
Process the logits for the next token prediction.
Args:
input_ids: The input token IDs generated so far.
logits: The raw logits for the next token. Shape: (vocab_size,)
Returns:
A tensor with the same shape, dtype, and device as `logits`.
"""
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment