Unverified Commit 40478291 authored by guillaume-be, committed by GitHub

[Performance improvement] "Bad tokens ids" optimization (#6064)

* Optimized banned token masking

* Avoid duplicate EOS masking if present in bad_words_ids

* Updated mask generation to handle empty banned token list

* Addition of unit tests for the updated bad_words_ids masking

* Updated timeout handling in `test_postprocess_next_token_scores_large_bad_words_list` unit test

* Updated timeout handling in `test_postprocess_next_token_scores_large_bad_words_list` unit test (timeout does not work on Windows)

* Moving Marian import to the test context to allow TF-only environments to run

* Moving imports to torch_available test

* Updated operations device and test

* Updated operations device and test

* Added docstring and comment for in-place scores modification

* Moving test to own test_generation_utils, use of lighter models for testing

* removed unneeded imports in test_modeling_common

* revert formatting change for ModelTesterMixin

* Updated caching, simplified eos token id test, removed unnecessary @require_torch

* formatting compliance

parent 87e124c2
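
The optimization below boils down to replacing a Python-level loop that issues one indexing assignment per batch row with a single in-place `masked_fill_` over a precomputed boolean mask. A minimal sketch of the before and after shapes of that idea (the function names here are invented for the comparison; the committed version builds its mask through a sparse tensor, as shown in the diff):

    import torch


    def mask_banned_tokens_loop(scores: torch.Tensor, banned_tokens) -> None:
        # Before: one indexing assignment per batch row, each paying
        # Python-interpreter and kernel-launch overhead.
        for i, row_banned in enumerate(banned_tokens):
            scores[i, row_banned] = -float("inf")


    def mask_banned_tokens_vectorized(scores: torch.Tensor, banned_tokens) -> None:
        # After: collect all (batch index, token id) coordinates once, build a
        # single (batch, vocab) boolean mask, and apply it with one fill.
        coords = [(i, tok) for i, row in enumerate(banned_tokens) for tok in row]
        if not coords:  # empty bad-words list: nothing to mask
            return
        idx = torch.tensor(coords, dtype=torch.long, device=scores.device).t()
        mask = torch.zeros_like(scores, dtype=torch.bool)
        mask[idx[0], idx[1]] = True
        scores.masked_fill_(mask, -float("inf"))

Both variants produce the same scores; the second touches the tensor a constant number of times regardless of how many sequences appear in bad_words_ids, which is the scenario the large-bad-words-list test below exercises.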
tests/test_generation_utils.py

import random
import unittest

import timeout_decorator

from transformers import is_torch_available
from transformers.file_utils import cached_property
from transformers.testing_utils import require_torch


if is_torch_available():
    import torch

    from transformers import (
        MarianConfig,
        MarianMTModel,
    )


@require_torch
class GenerationUtilsTest(unittest.TestCase):
    @cached_property
    def config(self):
        config = MarianConfig.from_pretrained("sshleifer/tiny-marian-en-de")
        return config

    @cached_property
    def model(self):
        return MarianMTModel(self.config)

    def test_postprocess_next_token_scores(self):
        config = self.config
        model = self.model
        # Initialize an input id tensor with batch size 8 and sequence length 12
        input_ids = torch.arange(0, 96, 1).view((8, 12))
        eos = config.eos_token_id
        bad_words_ids_test_cases = [[[299]], [[23, 24], [54]], [[config.eos_token_id]], []]
        masked_scores = [
            [(0, 299), (1, 299), (2, 299), (3, 299), (4, 299), (5, 299), (6, 299), (7, 299)],
            [(1, 24), (0, 54), (1, 54), (2, 54), (3, 54), (4, 54), (5, 54), (6, 54), (7, 54)],
            [(0, eos), (1, eos), (2, eos), (3, eos), (4, eos), (5, eos), (6, eos), (7, eos)],
            [],
        ]

        for test_case_index, bad_words_ids in enumerate(bad_words_ids_test_cases):
            # Initialize a scores tensor with batch size 8 and vocabulary size 300
            scores = torch.rand((8, 300))
            output = model.postprocess_next_token_scores(
                scores,
                input_ids,
                0,
                bad_words_ids,
                13,
                15,
                config.max_length,
                config.eos_token_id,
                config.repetition_penalty,
                32,
                5,
            )
            for masked_score in masked_scores[test_case_index]:
                self.assertTrue(output[masked_score[0], masked_score[1]] == -float("inf"))

    @timeout_decorator.timeout(10)
    def test_postprocess_next_token_scores_large_bad_words_list(self):
        config = self.config
        model = self.model
        # Initialize an input id tensor with batch size 8 and sequence length 12
        input_ids = torch.arange(0, 96, 1).view((8, 12))

        bad_words_ids = []
        for _ in range(100):
            length_bad_word = random.randint(1, 4)
            bad_words_ids.append(random.sample(range(1, 300), length_bad_word))

        scores = torch.rand((8, 300))
        _ = model.postprocess_next_token_scores(
            scores,
            input_ids,
            0,
            bad_words_ids,
            13,
            15,
            config.max_length,
            config.eos_token_id,
            config.repetition_penalty,
            32,
            5,
        )
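
On the timeout handling called out in the commit messages: `timeout_decorator`'s default mechanism raises via SIGALRM, which does not exist on Windows, hence the caveat above. The library also exposes a signal-free mode; a minimal sketch of that option (not what this commit ships, just the documented alternative):

    import timeout_decorator


    @timeout_decorator.timeout(10, use_signals=False)
    def expensive_call():
        # use_signals=False runs the call in a separate process instead of
        # relying on SIGALRM, so it also works on Windows. The decorated
        # function and its return value must be picklable in this mode.
        return sum(range(10_000))


    print(expensive_call())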
src/transformers/generation_utils.py

@@ -15,7 +15,7 @@
 # limitations under the License.

 import logging
-from typing import Iterable, Optional, Tuple
+from typing import Iterable, List, Optional, Tuple

 import torch
 from torch import Tensor
@@ -89,11 +89,12 @@ class GenerationMixin:
                 scores[i, banned_tokens] = -float("inf")

         if bad_words_ids is not None:
+            # Exclude EOS token (already processed)
+            bad_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], bad_words_ids))
             # calculate a list of banned tokens according to bad words
-            banned_tokens = calc_banned_bad_words_ids(input_ids, bad_words_ids)
-            for i, banned_tokens in enumerate(banned_tokens):
-                scores[i, banned_tokens] = -float("inf")
+            banned_tokens = calc_banned_bad_words_ids(input_ids.tolist(), bad_words_ids)
+            # Modify the scores in place by setting the banned tokens logits to `-inf`
+            set_scores_to_inf_for_banned_tokens(scores, banned_tokens)

         return scores
@@ -893,7 +894,7 @@ def calc_banned_bad_words_ids(prev_input_ids: Iterable[int], bad_words_ids: Iter
                 bad_words_ids
             )

-            if _tokens_match(prev_input_ids_slice.tolist(), banned_token_seq[:-1]) is False:
+            if _tokens_match(prev_input_ids_slice, banned_token_seq[:-1]) is False:
                 # if tokens do not match continue
                 continue
@@ -904,6 +905,30 @@ def calc_banned_bad_words_ids(prev_input_ids: Iterable[int], bad_words_ids: Iter
     return banned_tokens


+def set_scores_to_inf_for_banned_tokens(scores: torch.Tensor, banned_tokens: List[List[int]]) -> None:
+    """ Modifies the scores in place by setting the banned token positions to `-inf`. Banned token is expected to be
+    a list of list of banned tokens to ban in the format [[batch index, vocabulary position],...]
+        Args:
+            scores: logits distribution of shape (batch size, vocabulary size)
+            banned_tokens: list of list of tokens to ban of length (batch_size)
+    """
+    banned_mask_list = []
+    for idx, batch_banned_tokens in enumerate(banned_tokens):
+        for token in batch_banned_tokens:
+            banned_mask_list.append([idx, token])
+    if not banned_mask_list:
+        return
+    banned_mask = torch.LongTensor(banned_mask_list)
+    indices = torch.ones(len(banned_mask))
+    # A sparse tensor is generated from a list of coordinates: [[0, 1], [0, 2], [2, 0]]. A conversion to dense tensor generates:
+    # [ 0  1  1 ]
+    # [ 0  0  0 ]
+    # [ 1  0  0 ]
+    banned_mask = torch.sparse.LongTensor(banned_mask.t(), indices, scores.size()).to(scores.device).to_dense().bool()
+    scores.masked_fill_(banned_mask, -float("inf"))
+
+
 def top_k_top_p_filtering(
     logits: Tensor,
     top_k: int = 0,
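
To make the coordinate-to-mask comment in `set_scores_to_inf_for_banned_tokens` concrete, the same conversion can be run standalone. The snippet below uses `torch.sparse_coo_tensor`, the non-deprecated spelling of the `torch.sparse.LongTensor` constructor used in the diff; the shapes and coordinates are taken from the comment itself:

    import torch

    # Banned (batch index, vocabulary position) coordinates from the comment above
    coords = torch.tensor([[0, 1], [0, 2], [2, 0]]).t()  # shape (2, nnz): row ids, col ids
    values = torch.ones(coords.size(1), dtype=torch.long)

    # Scattering the coordinates into a dense (batch, vocab) grid puts a 1 at
    # every banned position and a 0 everywhere else
    banned_mask = torch.sparse_coo_tensor(coords, values, (3, 3)).to_dense().bool()
    print(banned_mask)
    # tensor([[False,  True,  True],
    #         [False, False, False],
    #         [ True, False, False]])

    # Applying the mask mirrors the final line of the new function
    scores = torch.rand((3, 3))
    scores.masked_fill_(banned_mask, -float("inf"))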