"vscode:/vscode.git/clone" did not exist on "27a8c9e4f189462c1cc4206317a27032887c8513"
Unverified Commit 63b90a51 authored by guillaume-be's avatar guillaume-be Committed by GitHub
Browse files

Optimized bad word ids (#13433)

* Optimized bad word ids generation

* Fixed optimized bad token ids

* Updated style
parent 5c7789d4
......@@ -16,7 +16,7 @@
import inspect
import math
from abc import ABC
from typing import Callable, Iterable, List
from typing import Callable, Iterable, List, Optional
import numpy as np
import torch
......@@ -369,46 +369,59 @@ class NoBadWordsLogitsProcessor(LogitsProcessor):
f"Each list in `bad_words_ids` has to be a list of positive integers, but is {bad_words_ids}."
)
self.bad_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], bad_words_ids))
bad_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], bad_words_ids))
self.bad_words_id_length_1 = []
self.bad_words_id_length_greater_than_1 = []
for word in bad_words_ids:
if len(word) == 1:
self.bad_words_id_length_1.append(word[0])
else:
self.bad_words_id_length_greater_than_1.append(word)
self.static_bad_words_mask: Optional[torch.LongTensor] = None
for banned_token_seq in self.bad_words_ids:
for banned_token_seq in self.bad_words_id_length_greater_than_1:
assert len(banned_token_seq) > 0, f"Banned words token sequences {bad_words_ids} cannot have an empty list"
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
banned_tokens = self._calc_banned_bad_words_ids(input_ids)
scores = self._set_scores_to_inf_for_banned_tokens(scores, banned_tokens)
if self.static_bad_words_mask is None and len(self.bad_words_id_length_1) > 0:
self.static_bad_words_mask = self._calc_static_bad_word_mask(scores)
dynamic_banned_tokens = self._calc_banned_bad_words_ids(input_ids.tolist())
scores = self._set_scores_to_inf_for_banned_tokens(scores, dynamic_banned_tokens)
return scores
def _tokens_match(self, prev_tokens: torch.LongTensor, tokens: List[int]) -> bool:
def _calc_static_bad_word_mask(self, scores: torch.FloatTensor) -> torch.BoolTensor:
static_bad_words_mask = torch.zeros(scores.shape[1])
static_bad_words_mask[self.bad_words_id_length_1] = 1
return static_bad_words_mask.unsqueeze(0).to(scores.device).bool()
def _tokens_match(self, prev_tokens: List[int], tokens: List[int]) -> bool:
if len(tokens) == 0:
# if bad word tokens is just one token always ban it
return True
elif len(tokens) > len(prev_tokens):
# if bad word tokens are longer then prev input_ids they can't be equal
return False
elif prev_tokens[-len(tokens) :].tolist() == tokens:
# if tokens match
return True
else:
return False
return prev_tokens[-len(tokens) :] == tokens
def _calc_banned_bad_words_ids(self, prev_input_ids: Iterable[int]) -> Iterable[int]:
def _calc_banned_bad_words_ids(self, prev_input_ids: List[List[int]]) -> Iterable[int]:
banned_tokens = []
for prev_input_ids_slice in prev_input_ids:
banned_tokens_slice = []
for banned_token_seq in self.bad_words_ids:
if self._tokens_match(prev_input_ids_slice, banned_token_seq[:-1]) is False:
# if tokens do not match continue
continue
banned_tokens_slice.append(banned_token_seq[-1])
for banned_token_seq in self.bad_words_id_length_greater_than_1:
if self._tokens_match(prev_input_ids_slice, banned_token_seq[:-1]):
banned_tokens_slice.append(banned_token_seq[-1])
banned_tokens.append(banned_tokens_slice)
return banned_tokens
def _set_scores_to_inf_for_banned_tokens(self, scores: torch.Tensor, banned_tokens: List[List[int]]) -> None:
def _set_scores_to_inf_for_banned_tokens(
self, scores: torch.Tensor, banned_tokens: List[List[int]]
) -> torch.Tensor:
"""
Modifies the scores in place by setting the banned token positions to `-inf`. Banned token is expected to be a
list of list of banned tokens to ban in the format [[batch index, vocabulary position],...
......@@ -428,21 +441,32 @@ class NoBadWordsLogitsProcessor(LogitsProcessor):
f"An invalid bad word ID is defined: {token}. This ID is not contained in the"
f"vocabulary, and is therefore ignored."
)
if not banned_mask_list:
if not banned_mask_list and self.static_bad_words_mask is None:
return scores
banned_mask = torch.LongTensor(banned_mask_list)
indices = torch.ones(len(banned_mask))
# A sparse tensor is generated from a list of coordinates: [[0, 1], [0, 2], [2, 0]]. A conversion to dense tensor generates:
# [ 0 1 1 ]
# [ 0 0 0 ]
# [ 1 0 0 ]
else:
if banned_mask_list:
banned_mask = torch.LongTensor(banned_mask_list)
indices = torch.ones(len(banned_mask))
# A sparse tensor is generated from a list of coordinates: [[0, 1], [0, 2], [2, 0]]. A conversion to dense tensor generates:
# [ 0 1 1 ]
# [ 0 0 0 ]
# [ 1 0 0 ]
banned_mask = (
torch.sparse.LongTensor(banned_mask.t(), indices, scores.size())
.to(scores.device)
.to_dense()
.bool()
)
if self.static_bad_words_mask is not None:
banned_mask = torch.bitwise_or(banned_mask, self.static_bad_words_mask)
else:
banned_mask = self.static_bad_words_mask
banned_mask = (
torch.sparse.LongTensor(banned_mask.t(), indices, scores.size()).to(scores.device).to_dense().bool()
)
scores = scores.masked_fill(banned_mask, -float("inf"))
return scores
scores = scores.masked_fill(banned_mask, -float("inf"))
return scores
class PrefixConstrainedLogitsProcessor(LogitsProcessor):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment