gaoqiong / lm-evaluation-harness
Commit 96ea7ddc
Authored Apr 26, 2022 by Tian Yun
Parent: c27e29e1

Added stopping criteria for generation
Changes: 2 changed files, with 66 additions and 2 deletions

lm_eval/models/gpt2.py  +33 -2
tests/test_gpt2.py      +33 -0
lm_eval/models/gpt2.py (view file @ 96ea7ddc)

```diff
...
@@ -116,10 +116,41 @@ class HFLM(BaseLM):
         with torch.no_grad():
             return self.gpt2(inps)[0][:, :, :50257]

-    def _model_generate(self, context, max_length, eos_token_id):
+    def _get_stopping_criteria(self, stopping_criteria_ids):
+        class MultitokenEOSCriteria(transformers.StoppingCriteria):
+            def __init__(self, eos_seq_id: torch.LongTensor, tokenizer):
+                self.eos_seq = tokenizer.decode(eos_seq_id)
+                self.eos_seq_id = eos_seq_id
+                self.eos_seq_len = len(eos_seq_id) + 1
+                self.tokenizer = tokenizer
+
+            def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+                last_token_id = input_ids[0, -self.eos_seq_len:]
+                last_tokens = self.tokenizer.decode(last_token_id)
+                is_stopped = self.eos_seq in last_tokens
+                return is_stopped
+
+        class EOSCriteria(transformers.StoppingCriteria):
+            def __init__(self, eos_token_id: torch.LongTensor):
+                self.eos_token_id = eos_token_id
+
+            def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+                return input_ids[0, -1] == self.eos_token_id
+
+        return transformers.StoppingCriteriaList([
+            MultitokenEOSCriteria(stopping_criteria_ids, self.tokenizer),
+            EOSCriteria(stopping_criteria_ids)
+        ])
+
+    def _model_generate(self, context, max_length, stopping_criteria_ids):
+        stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
         return self.gpt2.generate(
             context,
             max_length=max_length,
-            eos_token_id=eos_token_id,
-            do_sample=False
+            stopping_criteria=stopping_criteria,
+            do_sample=False,
         )

     # for backwards compatibility
...
```
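For context, here is a minimal self-contained sketch (not part of the commit) of how a multi-token stopping criterion like the `MultitokenEOSCriteria` above plugs into Hugging Face `generate()`. The model name, prompt, and stop string are illustrative; the harness itself wires this up through `_model_generate`.

```python
import torch
import transformers


class MultitokenEOSCriteria(transformers.StoppingCriteria):
    """Halt generation once the decoded tail of the sequence contains the stop string."""

    def __init__(self, eos_seq_id: torch.LongTensor, tokenizer):
        self.eos_seq = tokenizer.decode(eos_seq_id)
        self.eos_seq_id = eos_seq_id
        # Look back one extra token so matches that straddle a token
        # boundary are still caught after detokenization.
        self.eos_seq_len = len(eos_seq_id) + 1
        self.tokenizer = tokenizer

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        last_tokens = self.tokenizer.decode(input_ids[0, -self.eos_seq_len:])
        return self.eos_seq in last_tokens


tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")

stop_ids = torch.tensor(tokenizer.encode("say that"))  # multi-token stop sequence
context = torch.tensor([tokenizer.encode("i like")])

output = model.generate(
    context,
    max_length=20,
    stopping_criteria=transformers.StoppingCriteriaList(
        [MultitokenEOSCriteria(stop_ids, tokenizer)]
    ),
    do_sample=False,  # greedy decoding, as in the harness
)
print(tokenizer.decode(output[0]))  # the stop string is included in the output
```

Note that, as in the commit, the criterion inspects only the first sequence in the batch (`input_ids[0]`), so it assumes batch size 1.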
tests/test_gpt2.py (new file, 0 → 100644; view file @ 96ea7ddc)
```python
import random

import lm_eval.models as models
import pytest
import torch
from transformers import StoppingCriteria


@pytest.mark.parametrize(
    "eos_token,test_input,expected",
    [
        ("not", "i like", "i like to say that I'm not"),
        ("say that", "i like", "i like to say that"),
        ("great", "big science is", "big science is a great"),
        (
            "<|endoftext|>",
            "big science has",
            "big science has been done in the past, but it's not the same as the science of the",
        ),
    ],
)
def test_stopping_criteria(eos_token, test_input, expected):
    random.seed(42)
    torch.random.manual_seed(42)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    gpt2 = models.get_model("gpt2")(device=device)
    context = torch.tensor([gpt2.tokenizer.encode(test_input)])
    stopping_criteria_ids = gpt2.tokenizer.encode(eos_token)

    generations = gpt2._model_generate(
        context, max_length=20, stopping_criteria_ids=stopping_criteria_ids
    )
    generations = gpt2.tokenizer.decode(generations[0])
    assert generations == expected
```
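The expected strings are the deterministic greedy continuations of the `gpt2` checkpoint, truncated once the stop sequence appears (the stop string itself stays in the output). One reason the criterion decodes and substring-matches rather than comparing raw token ids, illustrated by the hypothetical snippet below: GPT-2's BPE can tokenize the same stop string differently depending on the surrounding context.

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")

# The same stop string maps to different ids depending on whether it follows
# a space, so exact id matching at the tail of a generation is fragile.
print(tok.encode("say that"))   # word-initial "say"
print(tok.encode(" say that"))  # " say" with a leading space is a different token
```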