"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "cf416764f4d08bd24dc625a706e0ad7540ffd2c0"
Unverified commit 89136ff7, authored by Joao Gante, committed by GitHub

Generate: sequence bias can handle same terminations (#24822)

parent 37d8611a
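
This patch removes the shared length_greather_than_1_bias tensor and applies each multi-token sequence's bias individually, so two biased sequences may now end on the same token (previously this raised a ValueError during bias preparation). A hypothetical usage sketch of what the change enables; the checkpoint, phrases, helper, and bias values below are illustrative and not part of the patch:

from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative sketch only: two down-weighted phrases that end in the same token.
# Before this commit, SequenceBiasLogitsProcessor rejected such a bias dictionary.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

def phrase_ids(phrase):
    # hypothetical helper: token ids of a phrase, without special tokens
    return tuple(tokenizer(phrase, add_special_tokens=False).input_ids)

sequence_bias = {
    phrase_ids(" very good"): -10.0,   # both phrases terminate in the " good" token
    phrase_ids(" quite good"): -10.0,
}

inputs = tokenizer("The movie was", return_tensors="pt")
out = model.generate(**inputs, sequence_bias=sequence_bias, max_new_tokens=10, do_sample=False)
print(tokenizer.decode(out[0], skip_special_tokens=True))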
@@ -624,9 +624,7 @@ class SequenceBiasLogitsProcessor(LogitsProcessor):
         # Bias variables that will be populated on the first call (for retrocompatibility purposes, the vocabulary size
         # is infered in the first usage, which inhibits initializing here)
-        self.sequences_length_greater_than_1 = []
         self.length_1_bias = None
-        self.length_greather_than_1_bias = None
         self.prepared_bias_variables = False
 
     @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
@@ -642,11 +640,9 @@ class SequenceBiasLogitsProcessor(LogitsProcessor):
         bias += self.length_1_bias
 
         # 4 - include the bias from length > 1, after determining which biased sequences may be completed.
-        # `matching_mask` is a (batch_size, vocab_size) boolean mask that is True for all tokens whose corresponding
-        # bias should be applied. The bias is applied on the last token of the sequence, if (and only if) the sequence
-        # may become complete this iteration.
-        matching_mask = torch.zeros_like(scores, dtype=torch.bool)
-        for sequence_ids in self.sequences_length_greater_than_1:
+        for sequence_ids, sequence_bias in self.sequence_bias.items():
+            if len(sequence_ids) == 1:  # the sequence is of length 1, already applied
+                continue
             if len(sequence_ids) > input_ids.shape[1]:  # the sequence is longer than the context, ignore
                 continue
             prefix_length = len(sequence_ids) - 1
@@ -655,12 +651,9 @@ class SequenceBiasLogitsProcessor(LogitsProcessor):
                 input_ids[:, -prefix_length:],
                 torch.tensor(sequence_ids[:-1], dtype=input_ids.dtype, device=input_ids.device),
             ).prod(dim=1)
-            matching_mask[:, last_token] |= matching_rows.bool()
-        bias += torch.where(
-            matching_mask,
-            self.length_greather_than_1_bias,
-            torch.tensor(0.0, device=self.length_greather_than_1_bias.device),
-        )
+            bias[:, last_token] += torch.where(
+                matching_rows.bool(), sequence_bias, torch.tensor(0.0, device=input_ids.device)
+            )
 
         # 5 - apply the bias to the scores
         scores = scores + bias
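
For reference, a standalone toy version of the rewritten step 4 above (tensor shapes, token ids, and bias values are made up): each matching sequence adds its own bias onto its last token, so sequences that share a termination accumulate their biases instead of colliding in a single shared tensor.

import torch

# Toy sketch, not the library code: apply each multi-token sequence's bias on its
# last token for every row whose recent tokens match the sequence prefix.
sequence_bias = {(1, 0): -5.0, (2, 0): -7.0}  # two sequences sharing termination 0
input_ids = torch.tensor([[3, 1], [3, 2]])    # batch of 2 prefixes
vocab_size = 4
bias = torch.zeros((input_ids.shape[0], vocab_size))

for sequence_ids, seq_bias in sequence_bias.items():
    if len(sequence_ids) == 1 or len(sequence_ids) > input_ids.shape[1]:
        continue  # length-1 biases are handled separately; too-long sequences can't complete
    prefix_length = len(sequence_ids) - 1
    last_token = sequence_ids[-1]
    matching_rows = torch.eq(
        input_ids[:, -prefix_length:],
        torch.tensor(sequence_ids[:-1], dtype=input_ids.dtype),
    ).prod(dim=1)
    bias[:, last_token] += torch.where(
        matching_rows.bool(), torch.tensor(seq_bias), torch.tensor(0.0)
    )

print(bias)  # row 0 gets -5.0 on token 0 (prefix 1 matched), row 1 gets -7.0 (prefix 2 matched)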
@@ -668,12 +661,10 @@ class SequenceBiasLogitsProcessor(LogitsProcessor):
 
     def _prepare_bias_variables(self, scores: torch.FloatTensor):
         vocabulary_size = scores.shape[-1]
-        sequence_bias = self.sequence_bias
-        tokens_with_bias = []
 
         # Check biased tokens out of bounds
         invalid_biases = []
-        for sequence_ids in sequence_bias:
+        for sequence_ids in self.sequence_bias:
             for token_id in sequence_ids:
                 if token_id >= vocabulary_size:
                     invalid_biases.append(token_id)
@@ -686,20 +677,9 @@ class SequenceBiasLogitsProcessor(LogitsProcessor):
 
         # Precompute the bias tensors to be applied. Sequences of length 1 are kept separately, as they can be applied
         # with simpler logic.
         self.length_1_bias = torch.zeros((vocabulary_size,), dtype=torch.float).to(scores.device)
-        self.length_greather_than_1_bias = torch.zeros((vocabulary_size,), dtype=torch.float).to(scores.device)
-        for sequence_ids, bias in sequence_bias.items():
+        for sequence_ids, bias in self.sequence_bias.items():
             if len(sequence_ids) == 1:
                 self.length_1_bias[sequence_ids[-1]] = bias
-            else:
-                self.sequences_length_greater_than_1.append(sequence_ids)
-                if self.length_greather_than_1_bias[sequence_ids[-1]] != 0.0:
-                    raise ValueError(
-                        "Setting a bias on sequences that share a common token termination is not yet supported. "
-                        "Please open an issue if you see this error message (after checking that it doesn't already "
-                        "exist)."
-                    )
-                self.length_greather_than_1_bias[sequence_ids[-1]] = bias
-                tokens_with_bias.append(sequence_ids[-1])
 
         self.prepared_bias_variables = True
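
With the shared tensor gone, _prepare_bias_variables only precomputes the length-1 biases. A small sketch of the resulting state, inspecting internal attributes purely for illustration (bias values and vocabulary size are made up):

import torch
from transformers.generation.logits_process import SequenceBiasLogitsProcessor

# Sketch: after the first call, only length-1 biases live in a vocab-sized tensor;
# longer sequences are applied on the fly in step 4 of __call__.
processor = SequenceBiasLogitsProcessor(sequence_bias={(2,): 1.5, (0, 3): -2.0, (1, 3): -4.0})
input_ids = torch.tensor([[0, 1]], dtype=torch.long)
scores = torch.zeros((1, 5))  # batch_size=1, illustrative vocab_size=5
processor(input_ids, scores)
print(processor.length_1_bias)           # bias of 1.5 at token id 2, zeros elsewhere
print(processor.prepared_bias_variables)  # True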
@@ -520,6 +520,9 @@ class LogitsProcessorTest(unittest.TestCase):
         input_ids = torch.tensor([[0, 1, 3, 1], [0, 1, 0, 1]], device=torch_device, dtype=torch.long)
         positive_bias = {(1,): 100.0, (4,): 100.0}
         negative_bias = {(1, 0): -100.0, (0, 1, 2): -100.0, (1, 3, 1, 3): -100.0}
+        # biases the same termination twice, to ensure we can handle overlapping terminations (it won't have an effect
+        # on the test cases, though)
+        negative_bias.update({(1, 3, 1, 3, 1, 3): -100.0})
         sequence_bias = {**positive_bias, **negative_bias}
 
         # scores = 0 to facilitate checks