Unverified Commit 211b54ac authored by Nicolas Patry, committed by GitHub

Rebased #617 (#868)

# What does this PR do?

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)


## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the
      [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and
      [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?


## Who can review?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @


@OlivierDehaene OR @Narsil

 -->

---------
Co-authored-by: Vincent Brouwers <vincent.brouwers@ing.com>
parent 4486f78c
from text_generation_server.utils.tokens import batch_top_tokens
import torch
from dataclasses import dataclass
@@ -11,6 +12,7 @@ from text_generation_server.models.types import (
Batch,
Generation,
PrefillTokens,
TopTokens,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
@@ -48,6 +50,8 @@ class Seq2SeqLMBatch(Batch):
# Generation helpers
next_token_choosers: List[NextTokenChooser]
stopping_criterias: List[StoppingCriteria]
top_n_tokens: List[int]
top_n_tokens_tensor: torch.Tensor
# Metadata used for padding
max_input_length: int
@@ -78,7 +82,7 @@ class Seq2SeqLMBatch(Batch):
inputs = []
next_token_choosers = []
stopping_criterias = []
top_n_tokens = []
decoder_input_lengths = []
prefix_offsets = []
read_offsets = []
@@ -97,6 +101,7 @@ class Seq2SeqLMBatch(Batch):
r.stopping_parameters, tokenizer
)
stopping_criterias.append(stopping_criteria)
top_n_tokens.append(r.top_n_tokens)
max_truncation = max(max_truncation, r.truncate)
max_decode_tokens += stopping_criteria.max_new_tokens
padding_right_offset = max(
@@ -126,6 +131,9 @@ class Seq2SeqLMBatch(Batch):
prefix_offsets.append(0)
read_offsets.append(1)
all_decoder_input_ids = decoder_input_ids.view(-1).split(1)
top_n_tokens_tensor = torch.tensor(
top_n_tokens, device=device, dtype=torch.int64
)
max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
@@ -146,6 +154,8 @@ class Seq2SeqLMBatch(Batch):
read_offsets=read_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
max_input_length=max_input_length.item(),
max_decoder_input_length=1,
padding_right_offset=padding_right_offset,
@@ -173,6 +183,7 @@ class Seq2SeqLMBatch(Batch):
next_token_choosers = []
stopping_criterias = []
top_n_tokens = []
max_input_length = 0
max_decoder_input_length = 0
@@ -204,6 +215,7 @@ class Seq2SeqLMBatch(Batch):
next_token_choosers.append(self.next_token_choosers[idx])
stopping_criteria = self.stopping_criterias[idx]
stopping_criterias.append(stopping_criteria)
top_n_tokens.append(self.top_n_tokens[idx])
remaining_decode_tokens = (
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
)
@@ -239,6 +251,7 @@ class Seq2SeqLMBatch(Batch):
layer[2] = layer[2][keep_indices, :, -max_input_length:]
layer[3] = layer[3][keep_indices, :, -max_input_length:]
top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices]
max_tokens = (
len(request_ids) * (max_input_length + max_decoder_input_length)
+ remaining_decode_tokens
@@ -254,6 +267,8 @@ class Seq2SeqLMBatch(Batch):
self.read_offsets = read_offsets
self.next_token_choosers = next_token_choosers
self.stopping_criterias = stopping_criterias
self.top_n_tokens = top_n_tokens
self.top_n_tokens_tensor = top_n_tokens_tensor
self.max_input_length = max_input_length
self.max_decoder_input_length = max_decoder_input_length
self.padding_right_offset = padding_right_offset
@@ -289,6 +304,7 @@ class Seq2SeqLMBatch(Batch):
read_offsets = []
next_token_choosers = []
stopping_criterias = []
top_n_tokens = []
max_tokens = 0
# Batch tensors
@@ -296,6 +312,7 @@ class Seq2SeqLMBatch(Batch):
decoder_input_ids = None
decoder_attention_mask = None
encoder_last_hidden_state = None
top_n_tokens_tensor = None
past_key_values = []
# Used for slicing correctly inside the tensors
@@ -312,6 +329,7 @@ class Seq2SeqLMBatch(Batch):
read_offsets.extend(batch.read_offsets)
next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias)
top_n_tokens.extend(batch.top_n_tokens)
if i == 0:
requests_idx_mapping = batch.requests_idx_mapping
@@ -384,6 +402,12 @@ class Seq2SeqLMBatch(Batch):
),
)
if top_n_tokens_tensor is None:
top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
total_batch_size,
)
top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
# Copy to correct indices
encoder_last_hidden_state[
start_index:end_index, -batch.max_input_length :, :
@@ -488,6 +512,8 @@ class Seq2SeqLMBatch(Batch):
read_offsets=read_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
max_input_length=max_input_length,
max_decoder_input_length=max_decoder_input_length,
padding_right_offset=padding_right_offset,
@@ -613,6 +639,12 @@ class Seq2SeqLM(Model):
batch.past_key_values,
)
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
batch.top_n_tokens,
batch.top_n_tokens_tensor,
torch.softmax(logits[:, -1], -1),
)
# Finished requests
generations: List[Generation] = []
stopped = True
@@ -628,6 +660,9 @@ class Seq2SeqLM(Model):
batch.next_token_choosers,
batch.stopping_criterias,
batch.all_decoder_input_ids,
batch.top_n_tokens,
batch_top_token_ids,
batch_top_token_logprobs,
)
# For each member of the batch
@@ -641,6 +676,9 @@ class Seq2SeqLM(Model):
next_token_chooser,
stopping_criteria,
all_decoder_input_ids,
top_n_tokens,
top_token_ids,
top_token_logprobs,
) in enumerate(iterator):
# Select next token
next_token_id, logprobs = next_token_chooser(
@@ -698,6 +736,24 @@ class Seq2SeqLM(Model):
else:
prefill_tokens = None
if top_n_tokens > 0:
toptoken_texts = self.tokenizer.batch_decode(
top_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
special_toptokens = [
token_id in self.all_special_ids for token_id in top_token_ids
]
top_tokens = TopTokens(
top_token_ids,
top_token_logprobs,
toptoken_texts,
special_toptokens,
)
else:
top_tokens = None
generation = Generation(
request.id,
prefill_tokens,
@@ -706,6 +762,7 @@ class Seq2SeqLM(Model):
next_token_text,
next_token_id_squeezed.item() in self.all_special_ids,
generated_text,
top_tokens,
)
generations.append(generation)
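For readers skimming the diff, here is a schematic, self-contained sketch of the per-request assembly performed in the loop above; the toy vocabulary, the `toy_batch_decode` helper, and all values are made up, standing in for `self.tokenizer.batch_decode` and `self.all_special_ids`:

```python
from typing import Dict, List

# Stand-ins for the model's tokenizer and special-token ids (illustrative only).
toy_vocab: Dict[int, str] = {0: "<pad>", 1: "</s>", 2: "Hello", 3: "world"}
all_special_ids = {0, 1}

def toy_batch_decode(ids: List[int]) -> List[str]:
    # Plays the role of self.tokenizer.batch_decode(..., skip_special_tokens=False).
    return [toy_vocab[i] for i in ids]

top_n_tokens = 2                   # what this request asked for
top_token_ids = [2, 1]             # per-request output of batch_top_tokens
top_token_logprobs = [-0.3, -1.4]

if top_n_tokens > 0:
    toptoken_texts = toy_batch_decode(top_token_ids)
    special_toptokens = [tid in all_special_ids for tid in top_token_ids]
    # In the server these four lists are wrapped in the TopTokens dataclass
    # (defined in the types module shown below) and attached to the Generation.
```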
from functools import total_ordering
import torch
from abc import ABC, abstractmethod
@@ -71,6 +72,25 @@ class PrefillTokens:
return len(self.token_ids)
@dataclass
class TopTokens:
    token_ids: List[int]
    logprobs: List[float]
    texts: List[str]
    is_special: List[bool]

    def to_pb(self) -> generate_pb2.TopTokens:
        return generate_pb2.TopTokens(
            ids=self.token_ids,
            logprobs=self.logprobs,
            texts=self.texts,
            is_special=self.is_special,
        )

    def __len__(self):
        return len(self.token_ids)
@dataclass
class Generation:
request_id: int
@@ -80,6 +100,8 @@ class Generation:
token_text: str
token_is_special: bool
generated_text: Optional[GeneratedText]
# Optional for now, since it's not yet supported for every model.
top_tokens: Optional[TopTokens]
def to_pb(self) -> generate_pb2.Generation:
return generate_pb2.Generation(
@@ -94,4 +116,5 @@ class Generation:
generated_text=self.generated_text.to_pb()
if self.generated_text is not None
else None,
top_tokens=self.top_tokens.to_pb() if self.top_tokens is not None else None,
)
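Not part of the diff: a minimal sketch of constructing the new TopTokens value and serializing it, assuming the text_generation_server package and its compiled protobufs (generate_pb2) are importable; all values are illustrative.

```python
from text_generation_server.models.types import TopTokens

top = TopTokens(
    token_ids=[42, 7],            # illustrative values only
    logprobs=[-0.11, -2.35],
    texts=[" Hello", " Hi"],
    is_special=[False, False],
)
assert len(top) == 2              # __len__ counts token_ids
pb = top.to_pb()                  # -> generate_pb2.TopTokens(ids=..., logprobs=..., texts=..., is_special=...)
```

A Generation now carries this as an optional top_tokens field and serializes it the same way in Generation.to_pb().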
import re
import torch
from transformers import (
RepetitionPenaltyLogitsProcessor,
PreTrainedTokenizerBase,
)
from typing import List, Tuple, Optional
from typing import Callable, List, Optional, Tuple
import torch
from text_generation_server.pb import generate_pb2
from text_generation_server.pb.generate_pb2 import FinishReason
from text_generation_server.utils.watermark import WatermarkLogitsProcessor
from text_generation_server.utils.logits_process import (
static_warper,
HeterogeneousProcessorWrapper,
HeterogeneousRepetitionPenaltyLogitsProcessor,
HeterogeneousTemperatureLogitsWarper,
HeterogeneousTopKLogitsWarper,
HeterogeneousTopPLogitsWarper,
HeterogeneousTypicalLogitsWarper,
HeterogeneousProcessorWrapper,
static_warper,
)
from text_generation_server.utils.watermark import WatermarkLogitsProcessor
from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
class NextTokenChooser:
@@ -229,11 +225,10 @@ class HeterogeneousNextTokenChooser:
scores = warper(input_ids, scores)
next_ids = self.choice(scores)
next_logprobs = torch.gather(
torch.log_softmax(scores, -1), 1, next_ids.view(-1, 1)
).view(-1)
logprobs = torch.log_softmax(scores, -1)
next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
return next_ids, next_logprobs
return next_ids, next_logprobs, logprobs
def filter(self, indices):
if self.watermark_processor is not None:
@@ -339,3 +334,50 @@ class HeterogeneousSampling:
self.greedy_indices = new_greedy_indices
self.sampling_mapping = new_sampling_mapping
return self
def batch_top_tokens(
    top_n_tokens: list[int], top_n_tokens_tensor: torch.Tensor, logprobs: torch.Tensor
) -> Tuple[List[List[int]], List[List[float]]]:
    """Find the top n most likely tokens for a batch of generations.

    When multiple tokens have equal probabilities and they don't all fit, the
    remaining tokens are also returned.
    """
    max_top_n = max(top_n_tokens)
    # Early exit when top_n_tokens is not used
    if max_top_n == 0:
        return [[]] * len(top_n_tokens), [[]] * len(top_n_tokens)

    # Ensure top_n doesn't exceed vocab size
    top_n_tokens = [min(tok, logprobs.size(-1)) for tok in top_n_tokens]

    # Parallel kthvalue adapted from https://discuss.pytorch.org/t/how-to-efficiently-get-the-k-th-largest-values-in-parallel/160529/2
    # Sorted topk is faster than torch.sort() since we only need a small subset
    sorted_top_k = torch.topk(logprobs, k=max_top_n, dim=1, sorted=True).values
    nth_highest = torch.gather(
        sorted_top_k, 1, (top_n_tokens_tensor - 1).clip(min=0).unsqueeze(1)
    )
    nth_highest[nth_highest == -float("inf")] = torch.finfo(logprobs.dtype).min

    # Find the new "fuzzy" top n values
    top_n_indices = (logprobs >= nth_highest).nonzero()
    _, top_n_ishes = torch.unique_consecutive(top_n_indices[:, 0], return_counts=True)

    # Take a new topk for these new max n values
    top_k = torch.topk(logprobs, k=top_n_ishes.max(), dim=1, sorted=True)
    top_n_ishes = top_n_ishes.tolist()
    top_indices = top_k.indices.tolist()
    top_values = top_k.values.tolist()

    return (
        [
            idxs[:n] if req_n > 0 else []
            for idxs, n, req_n in zip(top_indices, top_n_ishes, top_n_tokens)
        ],
        [
            vals[:n] if req_n > 0 else []
            for vals, n, req_n in zip(top_values, top_n_ishes, top_n_tokens)
        ],
    )
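To illustrate the contract of batch_top_tokens (not part of the diff): a hedged example on a tiny, made-up probability matrix, mirroring how the seq2seq model calls it with torch.softmax(logits[:, -1], -1); it assumes the text_generation_server package is installed.

```python
import torch

from text_generation_server.utils.tokens import batch_top_tokens

# Hypothetical batch of two requests over a 4-token vocabulary.
probs = torch.softmax(
    torch.tensor([[1.0, 2.0, 3.0, 4.0], [4.0, 3.0, 2.0, 1.0]]), dim=-1
)
top_n = [2, 0]  # the second request did not ask for top tokens
top_n_tensor = torch.tensor(top_n, dtype=torch.int64)

ids, logprobs = batch_top_tokens(top_n, top_n_tensor, probs)
# ids[0] -> the (at least) 2 most likely token ids for request 0, e.g. [3, 2]
# ids[1] -> [] because that request asked for 0 top tokens
```

Because ties are kept, a row can contain more than its requested n entries when several tokens share the n-th probability.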