Unverified Commit c71880f8 authored by Ying Sheng's avatar Ying Sheng Committed by GitHub
Browse files

Vectorize logprobs computation (#787)

parent bcb6611a
...@@ -77,33 +77,46 @@ class LogitsProcessor(nn.Module): ...@@ -77,33 +77,46 @@ class LogitsProcessor(nn.Module):
@staticmethod @staticmethod
def get_top_logprobs(all_logprobs, logits_metadata: LogitsMetadata): def get_top_logprobs(all_logprobs, logits_metadata: LogitsMetadata):
# TODO: vectorize the code below
if logits_metadata.forward_mode == ForwardMode.DECODE: if logits_metadata.forward_mode == ForwardMode.DECODE:
output_top_logprobs = [] output_top_logprobs = []
for i in range(all_logprobs.shape[0]): max_k = max(logits_metadata.top_logprobs_nums)
k = logits_metadata.top_logprobs_nums[i] ret = all_logprobs.topk(max_k, dim=1)
t = all_logprobs[i].topk(k) values = ret.values.tolist()
v_cpu = t.values.tolist() indices = ret.indices.tolist()
p_cpu = t.indices.tolist() for i, k in enumerate(logits_metadata.top_logprobs_nums):
output_top_logprobs.append(list(zip(v_cpu, p_cpu))) output_top_logprobs.append(list(zip(values[i][:k], indices[i][:k])))
return None, output_top_logprobs return None, output_top_logprobs
else: else:
# TODO: vectorize the code below
input_top_logprobs, output_top_logprobs = [], [] input_top_logprobs, output_top_logprobs = [], []
pt = 0 pt = 0
extend_seq_lens_cpu = logits_metadata.extend_seq_lens.tolist() extend_seq_lens_cpu = logits_metadata.extend_seq_lens.tolist()
max_k = max(logits_metadata.top_logprobs_nums)
ret = all_logprobs.topk(max_k, dim=1)
values = ret.values.tolist()
indices = ret.indices.tolist()
for i, extend_seq_len in enumerate(extend_seq_lens_cpu): for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
if extend_seq_len == 0: if extend_seq_len == 0:
input_top_logprobs.append([]) input_top_logprobs.append([])
output_top_logprobs.append([]) output_top_logprobs.append([])
continue continue
k = logits_metadata.top_logprobs_nums[i] k = logits_metadata.top_logprobs_nums[i]
t = all_logprobs[pt : pt + extend_seq_len].topk(k)
vs_cpu = t.values.tolist()
ps_cpu = t.indices.tolist()
input_top_logprobs.append( input_top_logprobs.append(
[list(zip(vs_cpu[j], ps_cpu[j])) for j in range(len(vs_cpu) - 1)] [
list(zip(values[pt + j][:k], indices[pt + j][:k]))
for j in range(extend_seq_len - 1)
]
)
output_top_logprobs.append(
list(
zip(
values[pt + extend_seq_len - 1][:k],
indices[pt + extend_seq_len - 1][:k],
)
)
) )
output_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1])))
pt += extend_seq_len pt += extend_seq_len
return input_top_logprobs, output_top_logprobs return input_top_logprobs, output_top_logprobs
......
...@@ -6,7 +6,7 @@ import dataclasses ...@@ -6,7 +6,7 @@ import dataclasses
import logging import logging
import multiprocessing as mp import multiprocessing as mp
import os import os
from typing import Dict, List from typing import Dict, List, Tuple
import numpy as np import numpy as np
import transformers import transformers
...@@ -469,7 +469,9 @@ class TokenizerManager: ...@@ -469,7 +469,9 @@ class TokenizerManager:
) )
return ret return ret
def detokenize_logprob_tokens(self, token_logprobs, decode_to_text: bool): def detokenize_logprob_tokens(
self, token_logprobs: List[Tuple[float, int]], decode_to_text: bool
):
if not decode_to_text: if not decode_to_text:
return [(logprob, token_id, None) for logprob, token_id in token_logprobs] return [(logprob, token_id, None) for logprob, token_id in token_logprobs]
...@@ -481,9 +483,13 @@ class TokenizerManager: ...@@ -481,9 +483,13 @@ class TokenizerManager:
] ]
def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool): def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool):
for i, t in enumerate(top_logprobs): # TODO: The current implementation only batches the detokenization for top-k tokens per single position.
if t: # We should batch all top-k tokens in all positions.
top_logprobs[i] = self.detokenize_logprob_tokens(t, decode_to_text) for i, token_top_logprobs in enumerate(top_logprobs):
if token_top_logprobs:
top_logprobs[i] = self.detokenize_logprob_tokens(
token_top_logprobs, decode_to_text
)
return top_logprobs return top_logprobs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment