Unverified Commit 3a243095 authored by Antoni Baum's avatar Antoni Baum Committed by GitHub
Browse files

Optimize `_get_ranks` in Sampler (#3623)

parent 64172a97
...@@ -506,22 +506,23 @@ def _sample( ...@@ -506,22 +506,23 @@ def _sample(
# sampling_tensors) # sampling_tensors)
def _get_ranks(x: torch.Tensor, indices: List[int]) -> torch.Tensor: def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
""" """
This function calculates the ranks of the chosen tokens in a logprob tensor. This function calculates the ranks of the chosen tokens in a logprob tensor.
Args: Args:
x (torch.Tensor): 2D logprob tensor of shape (N, M) x (torch.Tensor): 2D logprob tensor of shape (N, M)
where N is the no. of tokens and M is the vocab dim. where N is the no. of tokens and M is the vocab dim.
indices (List[int]): List of chosen token indices. indices (torch.Tensor): List of chosen token indices.
Returns: Returns:
torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens. torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens.
Each element in the returned tensor represents the rank Each element in the returned tensor represents the rank
of the chosen token in the input logprob tensor. of the chosen token in the input logprob tensor.
""" """
vals = x[range(len(x)), indices] vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype),
return (x > vals[:, None]).long().sum(1) + 1 indices]
return (x > vals[:, None]).long().sum(1).add_(1)
def _get_logprobs( def _get_logprobs(
...@@ -561,12 +562,21 @@ def _get_logprobs( ...@@ -561,12 +562,21 @@ def _get_logprobs(
sample_idx += num_parent_seqs sample_idx += num_parent_seqs
assert sample_idx == logprobs.size(0) assert sample_idx == logprobs.size(0)
batched_logprobs_query_seq_indices_gpu = torch.tensor(
batched_logprobs_query_seq_indices, device=logprobs.device)
batched_logprobs_query_token_indices_gpu = torch.tensor(
batched_logprobs_query_token_indices, device=logprobs.device)
# Batched query for logprobs of selected token # Batched query for logprobs of selected token
batched_logprobs_query_result = logprobs[[ batched_logprobs_query_result = logprobs[[
batched_logprobs_query_seq_indices, batched_logprobs_query_seq_indices_gpu,
batched_logprobs_query_token_indices batched_logprobs_query_token_indices_gpu
]] ]]
batched_ranks_query_result = _get_ranks(
logprobs[batched_logprobs_query_seq_indices_gpu],
batched_logprobs_query_token_indices_gpu)
# Batched query for logprobs of topk tokens # Batched query for logprobs of topk tokens
if largest_num_logprobs > 0: if largest_num_logprobs > 0:
top_logprobs, top_token_ids = torch.topk(logprobs, top_logprobs, top_token_ids = torch.topk(logprobs,
...@@ -578,10 +588,7 @@ def _get_logprobs( ...@@ -578,10 +588,7 @@ def _get_logprobs(
top_logprobs, top_token_ids = None, None top_logprobs, top_token_ids = None, None
batched_logprobs_query_result = batched_logprobs_query_result.cpu() batched_logprobs_query_result = batched_logprobs_query_result.cpu()
batched_ranks_query_result = batched_ranks_query_result.cpu()
batched_ranks_query_result = _get_ranks(
logprobs[batched_logprobs_query_seq_indices],
batched_logprobs_query_token_indices)
# Gather results # Gather results
result_prompt_logprobs: List[Optional[PromptLogprobs]] = [] result_prompt_logprobs: List[Optional[PromptLogprobs]] = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment