Unverified Commit 211b54ac authored by Nicolas Patry, committed by GitHub

Rebased #617 (#868)

# What does this PR do?



Fixes # (issue)


## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a GitHub issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the
[documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?


## Who can review?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @

@OlivierDehaene OR @Narsil
-->

---------
Co-authored-by: Vincent Brouwers <vincent.brouwers@ing.com>
parent 4486f78c
+from text_generation_server.utils.tokens import batch_top_tokens
 import torch
 from dataclasses import dataclass
@@ -11,6 +12,7 @@ from text_generation_server.models.types import (
     Batch,
     Generation,
     PrefillTokens,
+    TopTokens,
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
@@ -48,6 +50,8 @@ class Seq2SeqLMBatch(Batch):
     # Generation helpers
     next_token_choosers: List[NextTokenChooser]
     stopping_criterias: List[StoppingCriteria]
+    top_n_tokens: List[int]
+    top_n_tokens_tensor: torch.Tensor

     # Metadata used for padding
     max_input_length: int
@@ -78,7 +82,7 @@ class Seq2SeqLMBatch(Batch):
         inputs = []
         next_token_choosers = []
         stopping_criterias = []
+        top_n_tokens = []
         decoder_input_lengths = []
         prefix_offsets = []
         read_offsets = []
@@ -97,6 +101,7 @@ class Seq2SeqLMBatch(Batch):
                 r.stopping_parameters, tokenizer
             )
             stopping_criterias.append(stopping_criteria)
+            top_n_tokens.append(r.top_n_tokens)
             max_truncation = max(max_truncation, r.truncate)
             max_decode_tokens += stopping_criteria.max_new_tokens
             padding_right_offset = max(
@@ -126,6 +131,9 @@ class Seq2SeqLMBatch(Batch):
             prefix_offsets.append(0)
             read_offsets.append(1)
         all_decoder_input_ids = decoder_input_ids.view(-1).split(1)
+        top_n_tokens_tensor = torch.tensor(
+            top_n_tokens, device=device, dtype=torch.int64
+        )

         max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
@@ -146,6 +154,8 @@ class Seq2SeqLMBatch(Batch):
             read_offsets=read_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
+            top_n_tokens=top_n_tokens,
+            top_n_tokens_tensor=top_n_tokens_tensor,
             max_input_length=max_input_length.item(),
             max_decoder_input_length=1,
             padding_right_offset=padding_right_offset,
@@ -173,6 +183,7 @@ class Seq2SeqLMBatch(Batch):
         next_token_choosers = []
         stopping_criterias = []
+        top_n_tokens = []

         max_input_length = 0
         max_decoder_input_length = 0
@@ -204,6 +215,7 @@ class Seq2SeqLMBatch(Batch):
             next_token_choosers.append(self.next_token_choosers[idx])
             stopping_criteria = self.stopping_criterias[idx]
             stopping_criterias.append(stopping_criteria)
+            top_n_tokens.append(self.top_n_tokens[idx])
             remaining_decode_tokens = (
                 stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
             )
@@ -239,6 +251,7 @@ class Seq2SeqLMBatch(Batch):
             layer[2] = layer[2][keep_indices, :, -max_input_length:]
             layer[3] = layer[3][keep_indices, :, -max_input_length:]

+        top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices]
         max_tokens = (
             len(request_ids) * (max_input_length + max_decoder_input_length)
             + remaining_decode_tokens
@@ -254,6 +267,8 @@ class Seq2SeqLMBatch(Batch):
         self.read_offsets = read_offsets
         self.next_token_choosers = next_token_choosers
         self.stopping_criterias = stopping_criterias
+        self.top_n_tokens = top_n_tokens
+        self.top_n_tokens_tensor = top_n_tokens_tensor
         self.max_input_length = max_input_length
         self.max_decoder_input_length = max_decoder_input_length
         self.padding_right_offset = padding_right_offset
@@ -289,6 +304,7 @@ class Seq2SeqLMBatch(Batch):
         read_offsets = []
         next_token_choosers = []
         stopping_criterias = []
+        top_n_tokens = []
         max_tokens = 0

         # Batch tensors
@@ -296,6 +312,7 @@ class Seq2SeqLMBatch(Batch):
         decoder_input_ids = None
         decoder_attention_mask = None
         encoder_last_hidden_state = None
+        top_n_tokens_tensor = None
         past_key_values = []

         # Used for slicing correctly inside the tensors
@@ -312,6 +329,7 @@ class Seq2SeqLMBatch(Batch):
             read_offsets.extend(batch.read_offsets)
             next_token_choosers.extend(batch.next_token_choosers)
             stopping_criterias.extend(batch.stopping_criterias)
+            top_n_tokens.extend(batch.top_n_tokens)

             if i == 0:
                 requests_idx_mapping = batch.requests_idx_mapping
@@ -384,6 +402,12 @@ class Seq2SeqLMBatch(Batch):
                 ),
             )

+            if top_n_tokens_tensor is None:
+                top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
+                    total_batch_size,
+                )
+            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
+
             # Copy to correct indices
             encoder_last_hidden_state[
                 start_index:end_index, -batch.max_input_length :, :
@@ -488,6 +512,8 @@ class Seq2SeqLMBatch(Batch):
            read_offsets=read_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
+           top_n_tokens=top_n_tokens,
+           top_n_tokens_tensor=top_n_tokens_tensor,
            max_input_length=max_input_length,
            max_decoder_input_length=max_decoder_input_length,
            padding_right_offset=padding_right_offset,
@@ -613,6 +639,12 @@ class Seq2SeqLM(Model):
             batch.past_key_values,
         )

+        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
+            batch.top_n_tokens,
+            batch.top_n_tokens_tensor,
+            torch.softmax(logits[:, -1], -1),
+        )
+
         # Finished requests
         generations: List[Generation] = []
         stopped = True
@@ -628,6 +660,9 @@ class Seq2SeqLM(Model):
             batch.next_token_choosers,
             batch.stopping_criterias,
             batch.all_decoder_input_ids,
+            batch.top_n_tokens,
+            batch_top_token_ids,
+            batch_top_token_logprobs,
         )

         # For each member of the batch
@@ -641,6 +676,9 @@ class Seq2SeqLM(Model):
             next_token_chooser,
             stopping_criteria,
             all_decoder_input_ids,
+            top_n_tokens,
+            top_token_ids,
+            top_token_logprobs,
         ) in enumerate(iterator):
             # Select next token
             next_token_id, logprobs = next_token_chooser(
@@ -698,6 +736,24 @@ class Seq2SeqLM(Model):
             else:
                 prefill_tokens = None

+            if top_n_tokens > 0:
+                toptoken_texts = self.tokenizer.batch_decode(
+                    top_token_ids,
+                    clean_up_tokenization_spaces=False,
+                    skip_special_tokens=False,
+                )
+                special_toptokens = [
+                    token_id in self.all_special_ids for token_id in top_token_ids
+                ]
+                top_tokens = TopTokens(
+                    top_token_ids,
+                    top_token_logprobs,
+                    toptoken_texts,
+                    special_toptokens,
+                )
+            else:
+                top_tokens = None
+
             generation = Generation(
                 request.id,
                 prefill_tokens,
@@ -706,6 +762,7 @@ class Seq2SeqLM(Model):
                 next_token_text,
                 next_token_id_squeezed.item() in self.all_special_ids,
                 generated_text,
+                top_tokens,
             )

             generations.append(generation)
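A note on the per-request decoding above: `self.tokenizer.batch_decode(top_token_ids, ...)` is applied to a flat list of candidate ids so that each candidate yields its own text fragment. The standalone sketch below is not part of the diff; the `t5-small` checkpoint is only an illustrative choice, and any Hugging Face tokenizer with the standard API should behave the same way.

```python
# Illustration of why the candidate ids are decoded with batch_decode:
# each id becomes its own string, instead of all ids being joined together.
from transformers import AutoTokenizer

# Any seq2seq tokenizer works; t5-small is just a small, convenient example.
tokenizer = AutoTokenizer.from_pretrained("t5-small")

top_token_ids = tokenizer.encode("hello world", add_special_tokens=False)[:2]

# One decoded string per candidate id (what the PR stores in TopTokens.texts).
per_candidate_texts = tokenizer.batch_decode(
    top_token_ids,
    clean_up_tokenization_spaces=False,
    skip_special_tokens=False,
)
print(per_candidate_texts)

# Decoding the same ids as one sequence would merge them into a single string.
print(tokenizer.decode(top_token_ids))
```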
+from functools import total_ordering
 import torch
 from abc import ABC, abstractmethod
@@ -71,6 +72,25 @@ class PrefillTokens:
         return len(self.token_ids)

+@dataclass
+class TopTokens:
+    token_ids: List[int]
+    logprobs: List[float]
+    texts: List[str]
+    is_special: List[bool]
+
+    def to_pb(self) -> generate_pb2.TopTokens:
+        return generate_pb2.TopTokens(
+            ids=self.token_ids,
+            logprobs=self.logprobs,
+            texts=self.texts,
+            is_special=self.is_special,
+        )
+
+    def __len__(self):
+        return len(self.token_ids)
+
 @dataclass
 class Generation:
     request_id: int
@@ -80,6 +100,8 @@ class Generation:
     token_text: str
     token_is_special: bool
     generated_text: Optional[GeneratedText]
+    # Optional for now, since it's not yet supported for every model.
+    top_tokens: Optional[TopTokens]

     def to_pb(self) -> generate_pb2.Generation:
         return generate_pb2.Generation(
@@ -94,4 +116,5 @@ class Generation:
             generated_text=self.generated_text.to_pb()
             if self.generated_text is not None
             else None,
+            top_tokens=self.top_tokens.to_pb() if self.top_tokens is not None else None,
         )
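For readers skimming the diff, here is a minimal, self-contained sketch of the shape of the new `TopTokens` payload. The class name `TopTokensSketch` and the example values are illustrative only; the real dataclass above additionally serializes to `generate_pb2.TopTokens` through `to_pb()`.

```python
# Self-contained stand-in for the TopTokens dataclass added above (no protobuf import).
from dataclasses import dataclass
from typing import List


@dataclass
class TopTokensSketch:
    token_ids: List[int]    # candidate token ids for one generation step
    logprobs: List[float]   # log-probabilities aligned with token_ids
    texts: List[str]        # decoded text of each candidate
    is_special: List[bool]  # whether each candidate is a special token

    def __len__(self) -> int:
        # The four lists are parallel; their shared length is the number of candidates.
        return len(self.token_ids)


# A generation step that requested top_n_tokens=3 (illustrative values):
step = TopTokensSketch(
    token_ids=[312, 198, 13],
    logprobs=[-0.21, -1.9, -3.4],
    texts=[" the", "\n", "."],
    is_special=[False, False, False],
)
assert len(step) == 3
```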
 import re
-import torch
-from transformers import (
-    RepetitionPenaltyLogitsProcessor,
-    PreTrainedTokenizerBase,
-)
-from typing import List, Tuple, Optional
+from typing import Callable, List, Optional, Tuple
+
+import torch

 from text_generation_server.pb import generate_pb2
 from text_generation_server.pb.generate_pb2 import FinishReason
-from text_generation_server.utils.watermark import WatermarkLogitsProcessor
 from text_generation_server.utils.logits_process import (
-    static_warper,
+    HeterogeneousProcessorWrapper,
     HeterogeneousRepetitionPenaltyLogitsProcessor,
     HeterogeneousTemperatureLogitsWarper,
     HeterogeneousTopKLogitsWarper,
     HeterogeneousTopPLogitsWarper,
     HeterogeneousTypicalLogitsWarper,
-    HeterogeneousProcessorWrapper,
+    static_warper,
 )
+from text_generation_server.utils.watermark import WatermarkLogitsProcessor
+from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor

 class NextTokenChooser:
@@ -229,11 +225,10 @@ class HeterogeneousNextTokenChooser:
             scores = warper(input_ids, scores)

         next_ids = self.choice(scores)
-        next_logprobs = torch.gather(
-            torch.log_softmax(scores, -1), 1, next_ids.view(-1, 1)
-        ).view(-1)
+        logprobs = torch.log_softmax(scores, -1)
+        next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)

-        return next_ids, next_logprobs
+        return next_ids, next_logprobs, logprobs

     def filter(self, indices):
         if self.watermark_processor is not None:
@@ -339,3 +334,50 @@ class HeterogeneousSampling:
         self.greedy_indices = new_greedy_indices
         self.sampling_mapping = new_sampling_mapping
         return self
+
+
+def batch_top_tokens(
+    top_n_tokens: list[int], top_n_tokens_tensor: torch.Tensor, logprobs: torch.Tensor
+) -> Tuple[List[List[int]], List[List[float]]]:
+    """Find the top n most likely tokens for a batch of generations.
+
+    When multiple tokens have equal probabilities and they don't all fit, the
+    remaining tokens are also returned.
+    """
+    max_top_n = max(top_n_tokens)
+    # Early exit when top_n_tokens is not used
+    if max_top_n == 0:
+        return [[]] * len(top_n_tokens), [[]] * len(top_n_tokens)
+
+    # Ensure top_n doesn't exceed vocab size
+    top_n_tokens = [min(tok, logprobs.size(-1)) for tok in top_n_tokens]
+
+    # Parallel kthvalue adapted from https://discuss.pytorch.org/t/how-to-efficiently-get-the-k-th-largest-values-in-parallel/160529/2
+    # Sorted topk is faster than torch.sort() since we only need a small subset
+    sorted_top_k = torch.topk(logprobs, k=max_top_n, dim=1, sorted=True).values
+    nth_highest = torch.gather(
+        sorted_top_k, 1, (top_n_tokens_tensor - 1).clip(min=0).unsqueeze(1)
+    )
+    nth_highest[nth_highest == -float("inf")] = torch.finfo(logprobs.dtype).min
+
+    # Find the new "fuzzy" top n values
+    top_n_indices = (logprobs >= nth_highest).nonzero()
+    _, top_n_ishes = torch.unique_consecutive(top_n_indices[:, 0], return_counts=True)
+
+    # Take a new topk for these new max n values
+    top_k = torch.topk(logprobs, k=top_n_ishes.max(), dim=1, sorted=True)
+
+    top_n_ishes = top_n_ishes.tolist()
+    top_indices = top_k.indices.tolist()
+    top_values = top_k.values.tolist()
+
+    return (
+        [
+            idxs[:n] if req_n > 0 else []
+            for idxs, n, req_n in zip(top_indices, top_n_ishes, top_n_tokens)
+        ],
+        [
+            vals[:n] if req_n > 0 else []
+            for vals, n, req_n in zip(top_values, top_n_ishes, top_n_tokens)
+        ],
+    )
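The docstring of `batch_top_tokens` notes that when several tokens tie at the cutoff, all of them are returned. The standalone snippet below (toy numbers, not part of the PR) replays the same torch operations on a two-request batch to show where that "fuzzy" candidate count comes from.

```python
# Toy replay of the tie-aware selection in batch_top_tokens (illustrative values only).
import torch

top_n_tokens = [2, 0]  # request 0 wants 2 candidates, request 1 has top tokens disabled
top_n_tokens_tensor = torch.tensor(top_n_tokens)
probs = torch.tensor(
    [
        [0.40, 0.30, 0.30, 0.00, 0.00],  # two tokens tie at the n=2 cutoff
        [0.20, 0.20, 0.20, 0.20, 0.20],
    ]
)

max_top_n = max(top_n_tokens)
sorted_top_k = torch.topk(probs, k=max_top_n, dim=1, sorted=True).values
nth_highest = torch.gather(
    sorted_top_k, 1, (top_n_tokens_tensor - 1).clip(min=0).unsqueeze(1)
)

# Everything >= the n-th highest value survives, so both tied 0.30 tokens are kept.
top_n_indices = (probs >= nth_highest).nonzero()
_, top_n_ishes = torch.unique_consecutive(top_n_indices[:, 0], return_counts=True)
print(top_n_ishes.tolist())  # [3, 5] -> request 0 gets 3 candidates instead of 2

# As in the real function, requests with top_n_tokens == 0 are masked out at the end,
# so request 1 would still receive an empty list.
```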