Unverified commit 8267f991, authored by Yu Guo and committed by GitHub
Browse files

improve logits bias (#19041)

parent 7353492a
......@@ -5,6 +5,7 @@
import torch
import torch.nn as nn
from vllm.utils import async_tensor_h2d, is_pin_memory_available
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.bad_words import apply_bad_words
......@@ -20,6 +21,7 @@ class Sampler(nn.Module):
# Constructor for Sampler (an nn.Module subclass): initializes the base
# module, builds the top-k/top-p sampling helper, and records the
# pinned-memory flag used later when staging logit-bias tensors.
def __init__(self):
super().__init__()
# Helper object constructed with no arguments; its behavior is defined
# outside this view.
self.topk_topp_sampler = TopKTopPSampler()
# Evaluated once at construction and reused: this flag is passed to
# async_tensor_h2d when the logit-bias indices and values are copied to
# logits.device.
# NOTE(review): presumably reports whether pinned host memory can be used
# on this platform -- confirm against vllm.utils.
self.pin_memory = is_pin_memory_available()
def forward(
self,
......@@ -232,6 +234,10 @@ class Sampler(nn.Module):
# One idea is implement this as a PyTorch C++ op, and we may
# even optimize the logit_bias layout.
rows: list[int] = []
cols: list[int] = []
vals: list[float] = []
# Get vocabulary size from logits
vocab_size = logits.shape[-1]
......@@ -244,7 +250,16 @@ class Sampler(nn.Module):
f"token_id {token_id} in logit_bias contains "
f"out-of-vocab token id. Vocabulary size: "
f"{vocab_size}")
logits[i, token_id] += bias
rows.append(i)
cols.append(token_id)
vals.append(bias)
if rows:
indices = async_tensor_h2d([rows, cols], torch.int64,
logits.device, self.pin_memory)
values = async_tensor_h2d(vals, torch.float, logits.device,
self.pin_memory)
logits.index_put_(tuple(indices), values=values, accumulate=True)
return logits
def apply_allowed_token_ids(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment