gaoqiong / lm-evaluation-harness
Commit 9ad2fc3d (unverified)
Authored Jul 03, 2023 by Hailey Schoelkopf; committed by GitHub on Jul 03, 2023
Account for padding in inplen calculation
Parent: d8bf52c6
Showing 1 changed file with 6 additions and 4 deletions
lm_eval/base.py @ 9ad2fc3d
@@ -289,6 +289,7 @@ class BaseLM(LM):
         ):
             inps = []
             cont_toks_list = []
+            inplens = []
             padding_length = None
@@ -336,19 +337,20 @@ class BaseLM(LM):
                 inps.append(inp.unsqueeze(0))  # [1, padding_length]
                 cont_toks_list.append(cont)
+                inplens.append(inplen)
-            batched_inps = torch.cat(inps, dim=0)  # [batch, padding_length
+            batched_inps = torch.cat(inps, dim=0)  # [batch, padding_length]
             multi_logits = F.log_softmax(
                 self._model_call(batched_inps), dim=-1
             ).cpu()  # [batch, padding_length, vocab]
-            for (cache_key, _, _), logits, inp, cont_toks in zip(
-                chunk, multi_logits, inps, cont_toks_list
+            for (cache_key, _, _), logits, inp, inplen, cont_toks in zip(
+                chunk, multi_logits, inps, inplens, cont_toks_list
             ):
                 # Slice to original seq length
                 contlen = len(cont_toks)
-                inplen = logits.shape[0]
+                inplen = inplen + (logits.shape[0] - padding_length)
                 logits = logits[inplen - contlen : inplen].unsqueeze(0)  # [1, seq, vocab]
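The index arithmetic in the last hunk can be illustrated with a small standalone sketch. It is not code from lm_eval: it uses plain Python lists in place of tensors, the helper name slice_continuation and the toy data are hypothetical, and it assumes inputs are right-padded to padding_length, as the "# [1, padding_length]" comment in the diff suggests. Under that assumption, taking inplen from logits.shape[0] selects padded positions, while the per-example inplen recorded in inplens selects the continuation positions.

```python
def slice_continuation(logits, inplen, contlen, padding_length):
    """Return the rows of `logits` that score the continuation tokens.

    `logits` holds one row per position of a right-padded input (a plain
    list here, standing in for a [padding_length, vocab] tensor).
    `inplen` is the unpadded input length recorded when the batch was built.
    """
    # Same arithmetic as the added line in the diff: shift the recorded
    # length by however many rows the output differs from padding_length.
    inplen = inplen + (len(logits) - padding_length)
    return logits[inplen - contlen : inplen]


# Hypothetical toy example: 5 real positions, right-padded to 8.
padding_length = 8
inplen, contlen = 5, 2
logits = [f"pos{i}" for i in range(inplen)] + ["pad"] * (padding_length - inplen)

# Removed behaviour (inplen = logits.shape[0]): the slice lands on padding.
old = logits[len(logits) - contlen : len(logits)]

# Behaviour after the commit: the slice lands on the last real positions.
new = slice_continuation(logits, inplen, contlen, padding_length)

print(old)  # ['pad', 'pad']
print(new)  # ['pos3', 'pos4']
```

When the output has exactly padding_length rows, the correction term is zero and the recorded inplen is used unchanged. The term only matters if the model call returns a different number of positions than padding_length; when that happens in practice is not stated in the commit.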