"README_origin.md" did not exist on "c238f1cde6d983963f5c2eee572e0cb852f81a44"
Commit f7873a49 authored by haileyschoelkopf's avatar haileyschoelkopf
Browse files

update multi-token stop sequence handling

parent afda6551
@@ -889,8 +889,6 @@ class HFLM(LM):
            max_gen_toks = kwargs.pop("max_gen_toks")
        else:
            max_gen_toks = self.max_gen_toks
-       # first stop sequence is used to halt generation upon encountering
-       primary_until = [until[0]]
        # set the max length in tokens of inputs ("context_enc")
        if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
@@ -916,7 +914,7 @@
            cont = self._model_generate(
                context=context_enc,
                attention_mask=attn_masks,
-               stop=primary_until,
+               stop=until,
                **kwargs,
            )
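With `stop=until`, every stop string requested for a generation can now halt decoding, not just `until[0]`. The snippet below is not part of the commit; it is a minimal sketch, with illustrative class and variable names, of the mechanism this relies on under the standard `transformers` generate API: one string-matching `StoppingCriteria` per entry in the stop list, collected into a `StoppingCriteriaList`.

```python
# Minimal sketch: one string-level stopping criterion per stop string,
# so that any entry in `until` can terminate decoding (batch size 1 for brevity).
import transformers


class StopOnString(transformers.StoppingCriteria):
    def __init__(self, sequence: str, tokenizer, lookback: int):
        self.sequence = sequence
        self.tokenizer = tokenizer
        self.lookback = lookback  # number of trailing tokens to decode and scan

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        tail = input_ids[0, -self.lookback:]
        return self.sequence in self.tokenizer.decode(tail)


tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")

until = ["\n\n", "Question:"]  # example stop strings
stopping_criteria = transformers.StoppingCriteriaList(
    # +2 extra tokens of lookback, mirroring the MultiTokenEOSCriteria change below
    [StopOnString(s, tokenizer, len(tokenizer.encode(s)) + 2) for s in until]
)

inputs = tokenizer("Q: What is 2 + 2?\nA:", return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=64, stopping_criteria=stopping_criteria)
print(tokenizer.decode(out[0]))
```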
@@ -579,7 +579,14 @@ class MultiTokenEOSCriteria(transformers.StoppingCriteria):
        self.done_tracker = [False] * batch_size
        self.sequence = sequence
        self.sequence_ids = tokenizer.encode(sequence, add_special_tokens=False)
-       self.sequence_id_len = len(self.sequence_ids)
+       # we look back for 2 more tokens than it takes to encode our stop sequence
+       # because tokenizers suck, and a model might generate `['\n', '\n']` but our `sequence` is `['\n\n']`
+       # and we don't want to mistakenly not stop a generation because our
+       # (string) stop sequence was output in a different tokenization
+       # NOTE: there is a minor danger that this will end up looking back 2 tokens into the past, into the inputs to the model,
+       # and stopping generation immediately as a result. With only 2 extra tokens of lookback, this risk is minimized
+       self.sequence_id_len = len(self.sequence_ids) + 2
        self.tokenizer = tokenizer

    def __call__(self, input_ids, scores, **kwargs) -> bool:
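The `+ 2` in the new `sequence_id_len` exists because a model can emit the stop string under a different tokenization than `tokenizer.encode(sequence)` returns. A quick way to see the mismatch, assuming the GPT-2 tokenizer (which typically merges `"\n\n"` into a single token, while a model may instead emit two separate `"\n"` tokens):

```python
# Sketch of the tokenization mismatch that the extra 2-token lookback guards against.
# Assumes the GPT-2 tokenizer; other tokenizers may merge or split differently.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

sequence = "\n\n"
sequence_ids = tokenizer.encode(sequence, add_special_tokens=False)  # typically one merged token
emitted_ids = tokenizer.encode("\n", add_special_tokens=False) * 2   # model emits "\n" twice instead

print(sequence_ids, emitted_ids)                  # different token ids
print(sequence in tokenizer.decode(emitted_ids))  # True: the decoded text still contains the stop string

# An exact comparison of the last len(sequence_ids) token ids would miss the second case,
# so the criterion decodes len(sequence_ids) + 2 trailing tokens and does a string-level check.
```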
@@ -589,7 +596,6 @@ class MultiTokenEOSCriteria(transformers.StoppingCriteria):
        ]
        lookback_tokens_batch = self.tokenizer.batch_decode(lookback_ids_batch)
        for i, done in enumerate(self.done_tracker):
            if not done:
                self.done_tracker[i] = self.sequence in lookback_tokens_batch[i]
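For completeness, the string-level batch check in `__call__` can be exercised on its own. The sketch below is not from the harness; the lookback windows, stop string, and `done_tracker` values are made up for illustration, and it again assumes the GPT-2 tokenizer.

```python
# Sketch: decode a per-row lookback window and mark rows that already emitted the stop string.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

sequence = "\n\n"
sequence_id_len = len(tokenizer.encode(sequence, add_special_tokens=False)) + 2

# pretend these are the last `sequence_id_len` tokens of each row in a batch of 2
lookback_ids_batch = [
    tokenizer.encode("foo\n\n", add_special_tokens=False)[-sequence_id_len:],
    tokenizer.encode("foo bar", add_special_tokens=False)[-sequence_id_len:],
]
lookback_tokens_batch = tokenizer.batch_decode(lookback_ids_batch)

done_tracker = [False, False]
for i, done in enumerate(done_tracker):
    if not done:
        done_tracker[i] = sequence in lookback_tokens_batch[i]

print(done_tracker)  # row 0 has produced the stop string, row 1 has not
```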