gaoqiong / lm-evaluation-harness
Commit 6c86eb47, authored Apr 28, 2022 by Tian Yun
Parent: fce17ee1

Fixed generation truncation for GPT-2 and T5
Showing 3 changed files with 17 additions and 13 deletions:

lm_eval/base.py         +1  -1
lm_eval/models/gpt2.py  +12 -8
lm_eval/models/t5.py    +4  -4
lm_eval/base.py
@@ -384,7 +384,7 @@ class BaseLM(LM):
                 torch.tensor(primary_until),
             )
 
-            s = self.tok_decode(cont[0].tolist()[context_enc.shape[1]:])
+            s = self.tok_decode(cont.tolist())
 
             for term in until:
                 s = s.split(term)[0]
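For context: this line changes because `_model_generate` (see the model diffs below) now returns only the newly generated tokens, so `cont` is a 1-D tensor that can be decoded directly; final truncation on stop strings is still done by the `split` loop. A minimal, self-contained sketch of that stop-sequence truncation (the helper name and example strings are illustrative, not the harness's API):

# Hypothetical helper mirroring the split loop above: keep only the text
# generated before the first occurrence of any stop sequence.
def truncate_at_stop_sequences(text, until):
    for term in until:
        text = text.split(term)[0]
    return text

# Example: a greedy continuation that ran past the stop strings.
print(truncate_at_stop_sequences(" Paris.\n\nQuestion: What is", ["\n\n", "Question:"]))
# -> " Paris."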
lm_eval/models/gpt2.py
@@ -12,6 +12,7 @@ class HFLM(BaseLM):
                  subfolder=None,
                  tokenizer=None,
                  batch_size=1,
+                 parallelize=False):
         super().__init__()
@@ -32,7 +33,7 @@ class HFLM(BaseLM):
         self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(
             pretrained,
             revision=revision + ("/" + subfolder if subfolder is not None else ""),
-        ).to(self.device)
+        )
         self.gpt2.eval()
 
         # pretrained tokenizer for neo is broken for now so just hard-coding this to gpt2
@@ -68,9 +69,11 @@ class HFLM(BaseLM):
         self.batch_size_per_gpu = batch_size  # todo: adaptive batch size
 
         # TODO: fix multi-gpu
         # gpus = torch.cuda.device_count()
         # if gpus > 1:
         #     self.gpt2 = nn.DataParallel(self.gpt2)
+        if parallelize:
+            self.gpt2.parallelize()
+            self._device = torch.device('cuda:0')
+        else:
+            self.gpt2.to(self._device)
 
     @property
     def eot_token(self):
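A note on the new parallelize path: `model.parallelize()` was the naive model-parallelism API that transformers shipped for GPT-2-style models around this time (later deprecated in favor of `device_map`-based loading), and inputs are then expected on the first device, which is why `_device` is pinned to `cuda:0`. A rough usage sketch under those assumptions (requires multiple GPUs; the checkpoint name is only an example):

import torch
import transformers

model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")
model.eval()
model.parallelize()                # spread transformer blocks across available GPUs
device = torch.device("cuda:0")    # inputs live on the first device, as in the diff

tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
context = tokenizer("Hello", return_tensors="pt").input_ids.to(device)
with torch.no_grad():
    out = model.generate(context, max_length=10, do_sample=False)
print(tokenizer.decode(out[0]))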
@@ -147,15 +150,16 @@ class HFLM(BaseLM):
         ])
 
     def _model_generate(self, context, max_length, stopping_criteria_ids):
-        stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
-        return self.gpt2.generate(
+        # stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
+        generations = self.gpt2.generate(
             context,
             max_length=max_length,
-            stopping_criteria=stopping_criteria,
+            # stopping_criteria=stopping_criteria,
             do_sample=False,
         )
+        # Remove the context from the generations
+        return generations[0, context.shape[1]:]
 
 
 # for backwards compatibility
 GPT2LM = HFLM
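The core of the GPT-2 fix: for a decoder-only model, `generate()` returns the prompt tokens followed by the continuation, so slicing at `context.shape[1]` removes the echoed context. A small stand-alone illustration (model, prompt, and lengths are only examples, not the harness's code):

import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")

context = tokenizer("The capital of France is", return_tensors="pt").input_ids
output = model.generate(context, max_length=context.shape[1] + 5, do_sample=False)

# generate() echoes the prompt for causal LMs, so drop the first
# context.shape[1] tokens to keep only the newly generated continuation.
continuation = output[0, context.shape[1]:]
print(tokenizer.decode(continuation))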
lm_eval/models/t5.py
@@ -62,7 +62,7 @@ class T5LM(BaseLM):
     @property
     def max_gen_toks(self):
-        return self.tokenizer.model_max_length
+        return 256
 
     @property
     def batch_size(self):
@@ -187,10 +187,10 @@ class T5LM(BaseLM):
         ])
 
     def _model_generate(self, context, max_length, stopping_criteria_ids):
-        stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
+        # stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
         return self.t5.generate(
             context,
             max_length=max_length,
-            stopping_criteria=stopping_criteria,
+            # stopping_criteria=stopping_criteria,
             do_sample=False,
-        )
+        )[0]
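By contrast, T5 is an encoder-decoder model: its `generate()` output contains only decoder-side tokens (no echo of the input), so indexing the first batch element is all the truncation needed here; capping `max_gen_toks` at 256 presumably just bounds generation length instead of deferring to `tokenizer.model_max_length`. A small illustration under those assumptions (checkpoint and prompt are only examples):

import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained("t5-small")
model = transformers.AutoModelForSeq2SeqLM.from_pretrained("t5-small")

context = tokenizer("translate English to German: Hello, world!", return_tensors="pt").input_ids
output = model.generate(context, max_length=32, do_sample=False)

# Encoder-decoder generate() returns only decoder tokens, so [0] selects the
# single sequence and no prompt-stripping slice is required.
print(tokenizer.decode(output[0], skip_special_tokens=True))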