Unverified Commit ff739414 authored by Hailey Schoelkopf, committed by GitHub

Fix bug in multi-token Stop Sequences (#1268)

* fix incorrect lookback protections

* bump generate_until task versions
parent 818c056b
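
For context on the "lookback" this commit touches: the stopping criterion decodes a window of recently generated tokens and checks whether the stop string appears in that decoded text, using a window two tokens longer than the encoded stop sequence. A minimal sketch of why that slack is needed, assuming the publicly available "gpt2" tokenizer as an arbitrary example (nothing in this commit pins a particular tokenizer):

# Illustrative only: the same stop string can be produced by a different token
# sequence than the one obtained by encoding it directly, so comparing token ids
# exactly would miss valid stops; comparing the decoded text over a slightly
# longer window does not.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")

stop = "\n\n"
direct_ids = tok.encode(stop, add_special_tokens=False)        # encoding the stop string directly
newline_ids = tok.encode("\n", add_special_tokens=False) * 2   # a model emitting "\n" twice instead
print(direct_ids, newline_ids)            # two different id sequences...
print(tok.decode(newline_ids) == stop)    # ...that decode to the same stop string (True)
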
@@ -16,4 +16,4 @@ metric_list:
     aggregation: !function metrics.agg_bleu
     higher_is_better: true
 metadata:
-  version: 0.0
+  version: 1.0

@@ -636,6 +636,7 @@ class MultiTokenEOSCriteria(transformers.StoppingCriteria):
         self.done_tracker = [False] * batch_size
         self.sequence = sequence
         self.sequence_ids = tokenizer.encode(sequence, add_special_tokens=False)
+        # print(sequence, self.sequence_ids)
         # we look back for 2 more tokens than it takes to encode our stop sequence
         # because tokenizers suck, and a model might generate `['\n', '\n']` but our `sequence` is `['\n\n']`
         # and we don't want to mistakenly not stop a generation because our
@@ -643,16 +644,18 @@ class MultiTokenEOSCriteria(transformers.StoppingCriteria):
         # NOTE: there is a minor danger that this will end up looking back 2 tokens into the past, into the inputs to the model,
         # and stopping generation immediately as a result. With only 2 extra tokens of lookback, this risk is minimized
+        # Additionally, in lookback_ids_batch we should prevent ever looking back into the inputs as described.
         self.sequence_id_len = len(self.sequence_ids) + 2
         self.tokenizer = tokenizer
 
     def __call__(self, input_ids, scores, **kwargs) -> bool:
         # For efficiency, we compare the last n tokens where n is the number of tokens in the stop_sequence
-        lookback_ids_batch = input_ids[:, self.initial_decoder_input_length :][
-            :, -self.sequence_id_len :
-        ]
+        lookback_ids_batch = input_ids[:, self.initial_decoder_input_length :]
+
+        lookback_ids_batch = lookback_ids_batch[:, -self.sequence_id_len :]
         lookback_tokens_batch = self.tokenizer.batch_decode(lookback_ids_batch)
         for i, done in enumerate(self.done_tracker):
             if not done:
                 self.done_tracker[i] = self.sequence in lookback_tokens_batch[i]
...
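
Outside the harness, a stopping criterion with the patched lookback logic would be attached to a Hugging Face generate() call roughly as sketched below. The model name, prompt, and the class and variable names are illustrative assumptions; lm-evaluation-harness wires this up through its own helpers rather than this exact code.

# Standalone sketch mirroring the patched MultiTokenEOSCriteria logic:
# decode a short lookback window of *generated* tokens only, and stop once
# every row in the batch contains the stop string.
import transformers


class StopOnSubstring(transformers.StoppingCriteria):
    def __init__(self, sequence, tokenizer, initial_decoder_input_length, batch_size):
        self.initial_decoder_input_length = initial_decoder_input_length
        self.done_tracker = [False] * batch_size
        self.sequence = sequence
        sequence_ids = tokenizer.encode(sequence, add_special_tokens=False)
        # +2 tokens of slack to tolerate alternative tokenizations of the stop string
        self.sequence_id_len = len(sequence_ids) + 2
        self.tokenizer = tokenizer

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        # Slice away the prompt first, then keep only the last few generated tokens.
        lookback_ids = input_ids[:, self.initial_decoder_input_length :]
        lookback_ids = lookback_ids[:, -self.sequence_id_len :]
        lookback_text = self.tokenizer.batch_decode(lookback_ids)
        for i, done in enumerate(self.done_tracker):
            if not done:
                self.done_tracker[i] = self.sequence in lookback_text[i]
        return False not in self.done_tracker


tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")

prompt = "Q: What color is the sky?\nA:"
inputs = tokenizer(prompt, return_tensors="pt")
stopping = transformers.StoppingCriteriaList(
    [StopOnSubstring("\n\n", tokenizer, inputs["input_ids"].shape[1], batch_size=1)]
)
output = model.generate(**inputs, max_new_tokens=64, stopping_criteria=stopping)
print(tokenizer.decode(output[0][inputs["input_ids"].shape[1] :]))

Consistent with the comment added in this diff, the sketch strips the prompt before taking the last few tokens, so the lookback window never reaches back into the model's inputs.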