gaoqiong / lm-evaluation-harness · Commits

Commit eb7b9095, authored Jun 09, 2023 by Benjamin Fattori

    batch support for loglikelihood tokens

parent 226063ce
Showing 2 changed files with 40 additions and 38 deletions:

    lm_eval/models/seq2seq.py   +16  -37
    lm_eval/utils.py            +24   -1
lm_eval/models/seq2seq.py
@@ -127,7 +127,7 @@ class Seq2SeqHFLM(LM):
     def tok_decode(self, tokens):
         return self.tokenizer.decode(tokens, skip_special_tokens=True)
 
-    def _model_call(self, inps, labels=None):
+    def _model_call(self, inps, attn_mask=None, labels=None):
         """
         inps: a torch tensor of shape [batch, sequence_ctx]
         the size of sequence may vary from call to call
@@ -139,7 +139,7 @@ class Seq2SeqHFLM(LM):
         logits returned from the model
         """
         with torch.no_grad():
-            return self.model(input_ids=inps, labels=labels).logits
+            return self.model(input_ids=inps, attention_mask=attn_mask, labels=labels).logits
 
     def _model_generate(self, context, max_length, stop):
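For orientation, the attn_mask threaded through _model_call above is the standard attention_mask argument of a Hugging Face seq2seq forward pass, so the encoder ignores padded positions. Below is a minimal standalone sketch of that call; the transformers import, the t5-small checkpoint, and the example strings are illustrative assumptions, not part of this commit.

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

# Two inputs of different lengths; padding=True right-pads and builds the 0/1 mask.
enc = tokenizer(
    ["translate English to German: A house.",
     "translate English to German: Good morning, how are you?"],
    return_tensors="pt",
    padding=True,
)
labels = tokenizer(
    ["Ein Haus.", "Guten Morgen, wie geht es dir?"],
    return_tensors="pt",
    padding=True,
).input_ids

with torch.no_grad():
    out = model(
        input_ids=enc.input_ids,
        attention_mask=enc.attention_mask,  # zeros tell the encoder to ignore padding
        labels=labels,                      # decoder targets, one logit row per target position
    )

print(out.logits.shape)  # [batch, target_len, vocab_size]

Only the logits are read downstream; the loss that also comes back when labels is passed is simply not used by the harness.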
@@ -194,10 +194,11 @@ class Seq2SeqHFLM(LM):
         ):
             inps = []
             conts = []
+            encoder_attns = []
             cont_toks_list = []
 
-            padding_length_inp = None
-            padding_length_cont = None
+            max_batch_length_inp = None
+            max_batch_length_cont = None
 
             for _, context_enc, continuation_enc in chunk:
                 # sanity check
@@ -217,44 +218,22 @@ class Seq2SeqHFLM(LM):
                 ).to(self.device)
                 (contlen,) = cont.shape
 
-                padding_length_inp = (
-                    padding_length_inp if padding_length_inp is not None else inplen
-                )
-                padding_length_cont = (
-                    padding_length_cont if padding_length_cont is not None else contlen
-                )
-
-                inp = torch.cat(
-                    [
-                        inp,  # [seq]
-                        torch.zeros(padding_length_inp - inplen, dtype=torch.long).to(
-                            inp.device
-                        ),  # [padding_length - seq]
-                    ],
-                    dim=0,
-                )
-
-                cont = torch.cat(
-                    [
-                        cont,  # [seq]
-                        torch.zeros(padding_length_cont - contlen, dtype=torch.long).to(
-                            cont.device
-                        ),  # [padding_length - seq]
-                    ],
-                    dim=0,
-                )
-
-                inps.append(inp.unsqueeze(0))  # [1, padding_length]
-                conts.append(cont.unsqueeze(0))  # [1, padding_length]
+                max_batch_length_inp = max(max_batch_length_inp, inplen) if max_batch_length_inp is not None else inplen
+                max_batch_length_cont = max(max_batch_length_cont, contlen) if max_batch_length_cont is not None else contlen
+
+                inps.append(inp)  # [1, inp_len]
+                conts.append(cont)  # [1, cont_len]
+                encoder_attns.append(torch.ones_like(inp))
 
                 cont_toks_list.append(continuation_enc)
 
-            batched_inps = torch.cat(inps, dim=0)  # [batch, padding_length]
-            batched_conts = torch.cat(conts, dim=0)  # [batch, padding_length]
+            batched_inps = utils.pad_and_concat(max_batch_length_inp, inps)  # [batch, padding_length]
+            batched_conts = utils.pad_and_concat(max_batch_length_cont, conts)  # [batch, padding_length]
+            batched_encoder_mask = utils.pad_and_concat(max_batch_length_inp, encoder_attns)  # need to make attention mask here too
 
             multi_logits = F.log_softmax(
-                self._model_call(batched_inps, labels=batched_conts), dim=-1
+                self._model_call(batched_inps, attn_mask=batched_encoder_mask, labels=batched_conts), dim=-1
             ).cpu()  # [batch, padding_length, vocab]
 
             for (cache_key, _, _), logits, cont_toks in zip(
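The net effect of the hunk above: each context and continuation is kept at its natural length inside the loop, only the running maxima are tracked, and padding happens once per batch via utils.pad_and_concat, with torch.ones_like(inp) providing a mask of 1s over real tokens that becomes 0s wherever padding is added. Here is a toy, self-contained sketch of that bookkeeping; the token ids are made up, and pad_and_concat's per-tensor behavior is reimplemented with torch.nn.functional.pad so the snippet runs on its own.

import torch
import torch.nn.functional as F

def right_pad(t, length):
    # Zero-pad a 1-D tensor on the right up to `length` (what utils.pad_and_concat does per tensor).
    return F.pad(t, (0, length - t.shape[0]))

# Hypothetical (context_enc, continuation_enc) pairs of unequal lengths.
examples = [([37, 8, 1782], [10932, 1]), ([37, 564], [4])]

inps, conts, encoder_attns = [], [], []
max_inp = max_cont = None
for context_enc, continuation_enc in examples:
    inp = torch.tensor(context_enc, dtype=torch.long)
    cont = torch.tensor(continuation_enc, dtype=torch.long)
    max_inp = max(max_inp, inp.shape[0]) if max_inp is not None else inp.shape[0]
    max_cont = max(max_cont, cont.shape[0]) if max_cont is not None else cont.shape[0]
    inps.append(inp)                            # unpadded context
    conts.append(cont)                          # unpadded continuation
    encoder_attns.append(torch.ones_like(inp))  # 1 for every real context token

batched_inps = torch.stack([right_pad(t, max_inp) for t in inps])    # [2, 3]
batched_conts = torch.stack([right_pad(t, max_cont) for t in conts]) # [2, 2]
batched_encoder_mask = torch.stack([right_pad(t, max_inp) for t in encoder_attns])

print(batched_encoder_mask.tolist())  # [[1, 1, 1], [1, 1, 0]]

The printed mask is the point of batched_encoder_mask: the second context contributes only two real tokens, so its padded third position is excluded from encoder attention.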
lm_eval/utils.py

@@ -14,7 +14,7 @@ from typing import List
 from omegaconf import OmegaConf
 from jinja2 import BaseLoader, Environment, StrictUndefined
 from itertools import islice
-
+import torch
 
 
 class ExitCodeError(Exception):
     pass
@@ -327,3 +327,26 @@ def create_iterator(raw_iterator, rank, world_size, limit=None):
     among ranks in multigpu setting or only pulling a sample of documents
     """
     return islice(raw_iterator, rank, limit, world_size)
+
+
+def pad_and_concat(max_length: int, tensors: List[torch.Tensor]):
+    """
+    Method for padding a list of tensors given the maximum tensor
+    length in the batch. Used for batching inputs and continuations in
+    seq2seq models.
+    """
+    for i, tensor in enumerate(tensors):
+        tensor_len = tensor.shape[0]
+        if tensor_len < max_length:
+            tensors[i] = torch.cat(
+                [
+                    tensor,  # [seq]
+                    torch.zeros(max_length - tensor_len, dtype=torch.long).to(
+                        tensor.device
+                    ),  # [padding_length - seq]
+                ],
+                dim=0,
+            ).unsqueeze(0)
+        else:
+            tensors[i] = tensor.unsqueeze(0)
+
+    return torch.cat(tensors, dim=0)
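A quick usage sketch of the new helper; the tensor values are illustrative, and the import assumes this branch of the harness is installed so that lm_eval.utils exposes pad_and_concat.

import torch
from lm_eval import utils  # assumes this commit is checked out

a = torch.tensor([11, 12, 13], dtype=torch.long)  # length 3
b = torch.tensor([21, 22], dtype=torch.long)      # length 2

batch = utils.pad_and_concat(3, [a, b])
print(batch.tolist())  # [[11, 12, 13], [21, 22, 0]], shape [2, 3]

Note that the helper mutates the list it is given (each element is replaced by its padded, unsqueezed version) and always pads with id 0. That is harmless in the seq2seq path above: encoder padding is zeroed out by batched_encoder_mask, and the continuation tokens used for scoring come from cont_toks_list, which keeps each continuation at its original length.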