Unverified commit 63b90a51 authored by guillaume-be, committed by GitHub

Optimized bad word ids (#13433)

* Optimized bad word ids generation

* Fixed optimized bad token ids

* Updated style
parent 5c7789d4
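
The optimization splits `bad_words_ids` by length at construction time: bad words that are a single token are folded into one boolean mask over the vocabulary, built once and reused at every generation step, while only multi-token sequences still require per-step prefix matching against the generated ids. A minimal sketch of the static-mask idea (toy sizes, not the PR code verbatim):

    import torch

    vocab_size = 8
    single_token_bad_words = [2, 5]  # bad word ids of length 1

    # Build the mask once; shape (1, vocab_size) broadcasts over the batch.
    static_mask = torch.zeros(vocab_size)
    static_mask[single_token_bad_words] = 1
    static_mask = static_mask.unsqueeze(0).bool()

    # Fake next-token logits for a batch of 2; banned columns become -inf.
    scores = torch.randn(2, vocab_size)
    scores = scores.masked_fill(static_mask, -float("inf"))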
@@ -16,7 +16,7 @@
 import inspect
 import math
 from abc import ABC
-from typing import Callable, Iterable, List
+from typing import Callable, Iterable, List, Optional

 import numpy as np
 import torch
@@ -369,46 +369,59 @@ class NoBadWordsLogitsProcessor(LogitsProcessor):
                 f"Each list in `bad_words_ids` has to be a list of positive integers, but is {bad_words_ids}."
             )

-        self.bad_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], bad_words_ids))
+        bad_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], bad_words_ids))
+        self.bad_words_id_length_1 = []
+        self.bad_words_id_length_greater_than_1 = []
+        for word in bad_words_ids:
+            if len(word) == 1:
+                self.bad_words_id_length_1.append(word[0])
+            else:
+                self.bad_words_id_length_greater_than_1.append(word)
+        self.static_bad_words_mask: Optional[torch.LongTensor] = None

-        for banned_token_seq in self.bad_words_ids:
+        for banned_token_seq in self.bad_words_id_length_greater_than_1:
             assert len(banned_token_seq) > 0, f"Banned words token sequences {bad_words_ids} cannot have an empty list"

     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
-        banned_tokens = self._calc_banned_bad_words_ids(input_ids)
-        scores = self._set_scores_to_inf_for_banned_tokens(scores, banned_tokens)
+        if self.static_bad_words_mask is None and len(self.bad_words_id_length_1) > 0:
+            self.static_bad_words_mask = self._calc_static_bad_word_mask(scores)
+
+        dynamic_banned_tokens = self._calc_banned_bad_words_ids(input_ids.tolist())
+        scores = self._set_scores_to_inf_for_banned_tokens(scores, dynamic_banned_tokens)

         return scores

-    def _tokens_match(self, prev_tokens: torch.LongTensor, tokens: List[int]) -> bool:
+    def _calc_static_bad_word_mask(self, scores: torch.FloatTensor) -> torch.BoolTensor:
+        static_bad_words_mask = torch.zeros(scores.shape[1])
+        static_bad_words_mask[self.bad_words_id_length_1] = 1
+        return static_bad_words_mask.unsqueeze(0).to(scores.device).bool()
+
+    def _tokens_match(self, prev_tokens: List[int], tokens: List[int]) -> bool:
         if len(tokens) == 0:
             # if bad word tokens is just one token always ban it
             return True
         elif len(tokens) > len(prev_tokens):
             # if bad word tokens are longer then prev input_ids they can't be equal
             return False
-        elif prev_tokens[-len(tokens) :].tolist() == tokens:
-            # if tokens match
-            return True
         else:
-            return False
+            return prev_tokens[-len(tokens) :] == tokens

-    def _calc_banned_bad_words_ids(self, prev_input_ids: Iterable[int]) -> Iterable[int]:
+    def _calc_banned_bad_words_ids(self, prev_input_ids: List[List[int]]) -> Iterable[int]:
         banned_tokens = []
         for prev_input_ids_slice in prev_input_ids:
             banned_tokens_slice = []
-            for banned_token_seq in self.bad_words_ids:
-                if self._tokens_match(prev_input_ids_slice, banned_token_seq[:-1]) is False:
-                    # if tokens do not match
-                    continue
-
-                banned_tokens_slice.append(banned_token_seq[-1])
+            for banned_token_seq in self.bad_words_id_length_greater_than_1:
+                if self._tokens_match(prev_input_ids_slice, banned_token_seq[:-1]):
+                    banned_tokens_slice.append(banned_token_seq[-1])

             banned_tokens.append(banned_tokens_slice)

         return banned_tokens

-    def _set_scores_to_inf_for_banned_tokens(self, scores: torch.Tensor, banned_tokens: List[List[int]]) -> None:
+    def _set_scores_to_inf_for_banned_tokens(
+        self, scores: torch.Tensor, banned_tokens: List[List[int]]
+    ) -> torch.Tensor:
         """
         Modifies the scores in place by setting the banned token positions to `-inf`. Banned token is expected to be a
         list of list of banned tokens to ban in the format [[batch index, vocabulary position],...
@@ -428,9 +441,11 @@ class NoBadWordsLogitsProcessor(LogitsProcessor):
                     f"An invalid bad word ID is defined: {token}. This ID is not contained in the"
                     f"vocabulary, and is therefore ignored."
                 )
-        if not banned_mask_list:
+        if not banned_mask_list and self.static_bad_words_mask is None:
             return scores

-        banned_mask = torch.LongTensor(banned_mask_list)
-        indices = torch.ones(len(banned_mask))
-        # A sparse tensor is generated from a list of coordinates: [[0, 1], [0, 2], [2, 0]]. A conversion to dense tensor generates:
+        else:
+            if banned_mask_list:
+                banned_mask = torch.LongTensor(banned_mask_list)
+                indices = torch.ones(len(banned_mask))
+                # A sparse tensor is generated from a list of coordinates: [[0, 1], [0, 2], [2, 0]]. A conversion to dense tensor generates:
@@ -439,8 +454,17 @@ class NoBadWordsLogitsProcessor(LogitsProcessor):
-        # [ 1 0 0 ]
+                # [ 1 0 0 ]

-        banned_mask = (
-            torch.sparse.LongTensor(banned_mask.t(), indices, scores.size()).to(scores.device).to_dense().bool()
-        )
+                banned_mask = (
+                    torch.sparse.LongTensor(banned_mask.t(), indices, scores.size())
+                    .to(scores.device)
+                    .to_dense()
+                    .bool()
+                )
+
+                if self.static_bad_words_mask is not None:
+                    banned_mask = torch.bitwise_or(banned_mask, self.static_bad_words_mask)
+            else:
+                banned_mask = self.static_bad_words_mask

-        scores = scores.masked_fill(banned_mask, -float("inf"))
-        return scores
+            scores = scores.masked_fill(banned_mask, -float("inf"))
+            return scores
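
For reference, here is a worked example of the sparse-to-dense conversion described in the comment above. It uses `torch.sparse_coo_tensor`, the non-deprecated equivalent of the `torch.sparse.LongTensor` call in the diff (shapes are illustrative):

    import torch

    # [batch index, vocabulary position] coordinates: [[0, 1], [0, 2], [2, 0]]
    coords = torch.LongTensor([[0, 1], [0, 2], [2, 0]])
    values = torch.ones(len(coords))
    mask = torch.sparse_coo_tensor(coords.t(), values, (3, 3)).to_dense().bool()
    # tensor([[False,  True,  True],
    #         [False, False, False],
    #         [ True, False, False]])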
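And a hedged end-to-end sketch of how this processor is typically reached: `generate` accepts `bad_words_ids` and wires it into `NoBadWordsLogitsProcessor` (the model and banned words below are placeholders):

    from transformers import GPT2LMHeadModel, GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = GPT2LMHeadModel.from_pretrained("gpt2")

    # Single-token entries hit the precomputed static mask; multi-token
    # entries still go through per-step prefix matching.
    bad_words = ["idiot", "completely hopeless"]
    bad_words_ids = tokenizer(bad_words, add_prefix_space=True, add_special_tokens=False).input_ids

    input_ids = tokenizer("The weather today is", return_tensors="pt").input_ids
    output = model.generate(input_ids, max_length=20, bad_words_ids=bad_words_ids)
    print(tokenizer.decode(output[0], skip_special_tokens=True))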