gaoqiong / lm-evaluation-harness · Commits · f7873a49
"README_origin.md" did not exist on "c238f1cde6d983963f5c2eee572e0cb852f81a44"
Commit f7873a49, authored Nov 21, 2023 by haileyschoelkopf

update multi-token stopsequence handling

Parent: afda6551
Showing 2 changed files with 9 additions and 5 deletions:

- lm_eval/models/huggingface.py (+1, -3)
- lm_eval/utils.py (+8, -2)
lm_eval/models/huggingface.py

```diff
@@ -889,8 +889,6 @@ class HFLM(LM):
                 max_gen_toks = kwargs.pop("max_gen_toks")
             else:
                 max_gen_toks = self.max_gen_toks
-            # first stop sequence is used to halt generation upon encountering
-            primary_until = [until[0]]
             # set the max length in tokens of inputs ("context_enc")
             if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
@@ -916,7 +914,7 @@ class HFLM(LM):
             cont = self._model_generate(
                 context=context_enc,
                 attention_mask=attn_masks,
-                stop=primary_until,
+                stop=until,
                 **kwargs,
             )
```
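With this change, `_model_generate` receives the full `until` list instead of only its first entry, so decoding can halt on whichever stop sequence the model produces first. A condensed sketch of how that `stop` list is consumed downstream, simplified from the harness (names and signatures are approximate, not verbatim):

```python
import transformers

from lm_eval.utils import MultiTokenEOSCriteria


def model_generate(model, tokenizer, context, max_length, stop):
    # Build one stopping criterion per stop string. Before this commit only
    # [until[0]] arrived here; now the whole list does, so any of the task's
    # stop sequences can end generation.
    stopping_criteria = transformers.StoppingCriteriaList(
        [
            MultiTokenEOSCriteria(s, tokenizer, context.shape[1], context.shape[0])
            for s in stop
        ]
    )
    return model.generate(
        input_ids=context,
        max_length=max_length,
        stopping_criteria=stopping_criteria,
        pad_token_id=tokenizer.pad_token_id,
        use_cache=True,
    )
```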
lm_eval/utils.py

```diff
@@ -579,7 +579,14 @@ class MultiTokenEOSCriteria(transformers.StoppingCriteria):
         self.done_tracker = [False] * batch_size
         self.sequence = sequence
         self.sequence_ids = tokenizer.encode(sequence, add_special_tokens=False)
-        self.sequence_id_len = len(self.sequence_ids)
+        # we look back for 2 more tokens than it takes to encode our stop sequence
+        # because tokenizers suck, and a model might generate `['\n', '\n']` but our `sequence` is `['\n\n']`
+        # and we don't want to mistakenly not stop a generation because our
+        # (string) stop sequence was output in a different tokenization
+        # NOTE: there is a minor danger that this will end up looking back 2 tokens into the past, into the inputs to the model,
+        # and stopping generation immediately as a result. With only 2 extra tokens of lookback, this risk is minimized
+        self.sequence_id_len = len(self.sequence_ids) + 2
         self.tokenizer = tokenizer

     def __call__(self, input_ids, scores, **kwargs) -> bool:
```
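The motivation for the two extra tokens of lookback is that the same stop string can surface under more than one tokenization. A quick illustration, assuming a GPT-2-style BPE tokenizer (the exact ids are illustrative and vary by tokenizer):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")

# The stop string "\n\n" encodes to a single token id...
print(tok.encode("\n\n", add_special_tokens=False))    # e.g. [628]

# ...but a model may instead emit two separate "\n" tokens:
print(tok.encode("\n", add_special_tokens=False) * 2)  # e.g. [198, 198]

# An id-level comparison against [628] would miss the second form. Decoding
# the last len(sequence_ids) + 2 generated tokens back to text and testing
# `sequence in decoded` catches both spellings.
```

The corresponding lookback window in `__call__` is sliced with the widened `sequence_id_len` below.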
```diff
@@ -589,7 +596,6 @@ class MultiTokenEOSCriteria(transformers.StoppingCriteria):
         ]
         lookback_tokens_batch = self.tokenizer.batch_decode(lookback_ids_batch)
         for i, done in enumerate(self.done_tracker):
             if not done:
                 self.done_tracker[i] = self.sequence in lookback_tokens_batch[i]
```
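To see the criterion in action, a small model-free exercise of `__call__` directly, using the constructor shown above (the prompt and tensor shapes mirror what `generate` passes in):

```python
import torch
import transformers

from lm_eval.utils import MultiTokenEOSCriteria

tok = transformers.AutoTokenizer.from_pretrained("gpt2")
prompt_ids = tok("A haiku:\n", return_tensors="pt")["input_ids"]

criterion = MultiTokenEOSCriteria(
    sequence="\n\n",
    tokenizer=tok,
    initial_decoder_input_length=prompt_ids.shape[1],
    batch_size=1,
)

# Simulate the model emitting two single-newline tokens rather than the
# one-token "\n\n" encoding; the string-level lookback still catches it.
newline_id = tok.encode("\n", add_special_tokens=False)[0]
generated = torch.tensor([[newline_id, newline_id]])
input_ids = torch.cat([prompt_ids, generated], dim=1)

print(criterion(input_ids, scores=None))  # True: generation would stop here
```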