Unverified Commit 50107e10 authored by Stella Biderman's avatar Stella Biderman Committed by GitHub
Browse files

Merge pull request #1009 from EleutherAI/fix-stopseqs

[Refactor] Improve Handling of Stop-Sequences for HF Batched Generation
parents a1403c8f f7873a49
......@@ -889,8 +889,6 @@ class HFLM(LM):
max_gen_toks = kwargs.pop("max_gen_toks")
else:
max_gen_toks = self.max_gen_toks
# first stop sequence is used to halt generation upon encountering
primary_until = [until[0]]
# set the max length in tokens of inputs ("context_enc")
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
......@@ -916,7 +914,7 @@ class HFLM(LM):
cont = self._model_generate(
context=context_enc,
attention_mask=attn_masks,
stop=primary_until,
stop=until,
**kwargs,
)
......
......@@ -579,7 +579,14 @@ class MultiTokenEOSCriteria(transformers.StoppingCriteria):
self.done_tracker = [False] * batch_size
self.sequence = sequence
self.sequence_ids = tokenizer.encode(sequence, add_special_tokens=False)
self.sequence_id_len = len(self.sequence_ids)
# we look back for 2 more tokens than it takes to encode our stop sequence
# because tokenizers suck, and a model might generate `['\n', '\n']` but our `sequence` is `['\n\n']`
# and we don't want to mistakenly not stop a generation because our
# (string) stop sequence was output in a different tokenization
# NOTE: there is a minor danger that this will end up looking back 2 tokens into the past, into the inputs to the model,
# and stopping generation immediately as a result. With only 2 extra tokens of lookback, this risk is minimized
self.sequence_id_len = len(self.sequence_ids) + 2
self.tokenizer = tokenizer
def __call__(self, input_ids, scores, **kwargs) -> bool:
......@@ -589,7 +596,6 @@ class MultiTokenEOSCriteria(transformers.StoppingCriteria):
]
lookback_tokens_batch = self.tokenizer.batch_decode(lookback_ids_batch)
for i, done in enumerate(self.done_tracker):
if not done:
self.done_tracker[i] = self.sequence in lookback_tokens_batch[i]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment