Unverified Commit 8d0a01a5 authored by Lu Fang's avatar Lu Fang Committed by GitHub
Browse files

[v1][sampler] Inplace logprobs comparison to get the token rank (#21283)


Signed-off-by: default avatarLu Fang <lufang@fb.com>
parent 0ec82edd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Some utilities for logprobs, including logits."""
import torch
@torch.compile(dynamic=True)
def batched_count_greater_than(x: torch.Tensor,
values: torch.Tensor) -> torch.Tensor:
"""
Counts elements in each row of x that are greater than the corresponding
value in values. Use torch.compile to generate an optimized kernel for
this function. otherwise, it will create additional copies of the input
tensors and cause memory issues.
Args:
x (torch.Tensor): A 2D tensor of shape (batch_size, n_elements).
values (torch.Tensor): A 2D tensor of shape (batch_size, 1).
Returns:
torch.Tensor: A 1D tensor of shape (batch_size,) with the counts.
"""
return (x >= values).sum(-1)
......@@ -9,6 +9,7 @@ from vllm.utils import is_pin_memory_available
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.bad_words import apply_bad_words
from vllm.v1.sample.ops.logprobs import batched_count_greater_than
from vllm.v1.sample.ops.penalties import apply_all_penalties
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
......@@ -174,7 +175,7 @@ class Sampler(nn.Module):
token_logprobs = logprobs.gather(-1, token_ids)
# Compute the ranks of the actual token.
token_ranks = (logprobs >= token_logprobs).sum(-1)
token_ranks = batched_count_greater_than(logprobs, token_logprobs)
# Concatenate together with the topk.
indices = torch.cat((token_ids, topk_indices), dim=1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment