Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
621214e1
Commit
621214e1
authored
Jun 22, 2023
by
haileyschoelkopf
Committed by
lintangsutawika
Jun 22, 2023
Browse files
fix issues with encoder_attns, test lambada
parent
1c409035
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
6 additions
and
1 deletion
+6
-1
lm_eval/models/huggingface.py
lm_eval/models/huggingface.py
+6
-1
No files found.
lm_eval/models/huggingface.py
View file @
621214e1
...
@@ -226,7 +226,8 @@ class HFLM(LM):
...
@@ -226,7 +226,8 @@ class HFLM(LM):
logits returned from the model's decoder
logits returned from the model's decoder
"""
"""
with
torch
.
no_grad
():
with
torch
.
no_grad
():
if
attn_mask
or
labels
:
if
attn_mask
is
not
None
or
labels
is
not
None
:
assert
attn_mask
is
not
None
and
labels
is
not
None
assert
self
.
AUTO_MODEL_CLASS
==
transformers
.
AutoModelForSeq2SeqLM
assert
self
.
AUTO_MODEL_CLASS
==
transformers
.
AutoModelForSeq2SeqLM
return
self
.
model
(
return
self
.
model
(
input_ids
=
inps
,
attention_mask
=
attn_mask
,
labels
=
labels
input_ids
=
inps
,
attention_mask
=
attn_mask
,
labels
=
labels
...
@@ -394,6 +395,10 @@ class HFLM(LM):
...
@@ -394,6 +395,10 @@ class HFLM(LM):
device
=
self
.
device
,
device
=
self
.
device
,
)
)
(
inplen
,)
=
inp
.
shape
(
inplen
,)
=
inp
.
shape
# build encoder attn masks
encoder_attns
.
append
(
torch
.
ones_like
(
inp
))
cont
=
torch
.
tensor
(
cont
=
torch
.
tensor
(
(
continuation_enc
)[
-
self
.
max_length
:],
(
continuation_enc
)[
-
self
.
max_length
:],
# TODO: left-shift these?
# TODO: left-shift these?
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment