gaoqiong / lm-evaluation-harness
Commit 3789d340, authored Jun 09, 2023 by Benjamin Fattori, committed by lintangsutawika on Jun 22, 2023

batch support for loglikelihood tokens

Parent: 0a3b8069
Showing 1 changed file with 16 additions and 37 deletions: lm_eval/models/seq2seq.py (+16, -37)
lm_eval/models/seq2seq.py (view file @ 3789d340)
@@ -127,7 +127,7 @@ class Seq2SeqHFLM(LM):
     def tok_decode(self, tokens):
         return self.tokenizer.decode(tokens, skip_special_tokens=True)

-    def _model_call(self, inps, labels=None):
+    def _model_call(self, inps, attn_mask=None, labels=None):
         """
         inps: a torch tensor of shape [batch, sequence_ctx]
         the size of sequence may vary from call to call
@@ -139,7 +139,7 @@ class Seq2SeqHFLM(LM):
         logits returned from the model
         """
         with torch.no_grad():
-            return self.model(input_ids=inps, labels=labels).logits
+            return self.model(input_ids=inps, attention_mask=attn_mask, labels=labels).logits

     def _model_generate(self, context, max_length, stop):
@@ -194,10 +194,11 @@ class Seq2SeqHFLM(LM):
         ):
             inps = []
             conts = []
+            encoder_attns = []
             cont_toks_list = []

-            padding_length_inp = None
-            padding_length_cont = None
+            max_batch_length_inp = None
+            max_batch_length_cont = None

             for _, context_enc, continuation_enc in chunk:
                 # sanity check
@@ -217,44 +218,22 @@ class Seq2SeqHFLM(LM):
                 ).to(self.device)
                 (contlen,) = cont.shape

-                padding_length_inp = (
-                    padding_length_inp if padding_length_inp is not None else inplen
-                )
-                padding_length_cont = (
-                    padding_length_cont if padding_length_cont is not None else contlen
-                )
-
-                inp = torch.cat(
-                    [
-                        inp,  # [seq]
-                        torch.zeros(padding_length_inp - inplen, dtype=torch.long).to(
-                            inp.device
-                        ),  # [padding_length - seq]
-                    ],
-                    dim=0,
-                )
+                max_batch_length_inp = max(max_batch_length_inp, inplen) if max_batch_length_inp is not None else inplen
+                max_batch_length_cont = max(max_batch_length_cont, contlen) if max_batch_length_cont is not None else contlen

-                cont = torch.cat(
-                    [
-                        cont,  # [seq]
-                        torch.zeros(padding_length_cont - contlen, dtype=torch.long).to(
-                            cont.device
-                        ),  # [padding_length - seq]
-                    ],
-                    dim=0,
-                )
-
-                inps.append(inp.unsqueeze(0))  # [1, padding_length]
-                conts.append(cont.unsqueeze(0))  # [1, padding_length]
+                inps.append(inp)  # [1, inp_len]
+                conts.append(cont)  # [1, cont_len]
+                encoder_attns.append(torch.ones_like(inp))

                 cont_toks_list.append(continuation_enc)

-            batched_inps = torch.cat(inps, dim=0)  # [batch, padding_length]
-            batched_conts = torch.cat(conts, dim=0)  # [batch, padding_length]
+            batched_inps = utils.pad_and_concat(max_batch_length_inp, inps)  # [batch, padding_length]
+            batched_conts = utils.pad_and_concat(max_batch_length_cont, conts)  # [batch, padding_length]
+            batched_encoder_mask = utils.pad_and_concat(max_batch_length_inp, encoder_attns)  # need to make attention mask here too

             multi_logits = F.log_softmax(
-                self._model_call(batched_inps, labels=batched_conts), dim=-1
+                self._model_call(
+                    batched_inps, attn_mask=batched_encoder_mask, labels=batched_conts
+                ),
+                dim=-1,
             ).cpu()  # [batch, padding_length, vocab]

             for (cache_key, _, _), logits, cont_toks in zip(
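The new code path leans on a utils.pad_and_concat helper whose body is not part of this diff. As a rough illustration only (the real helper in lm_eval.utils may differ in signature and padding side), a minimal sketch that right-pads each 1-D tensor with zeros to a common length and stacks the results into a [batch, padding_length] tensor could look like this:

# Hypothetical sketch only -- not the actual lm_eval.utils.pad_and_concat implementation;
# assumes right-padding with zeros and 1-D tensors as inputs.
import torch

def pad_and_concat(max_length, tensors):
    """Right-pad each 1-D tensor to max_length and stack into [batch, max_length]."""
    rows = []
    for t in tensors:
        pad_len = max_length - t.shape[0]
        padding = torch.zeros(pad_len, dtype=t.dtype, device=t.device)
        rows.append(torch.cat([t, padding], dim=0).unsqueeze(0))  # [1, max_length]
    return torch.cat(rows, dim=0)  # [batch, max_length]

# Usage mirroring the hunk above: variable-length token sequences become one padded
# batch, and per-sequence all-ones tensors become the encoder attention mask.
a = torch.tensor([5, 6, 7], dtype=torch.long)
b = torch.tensor([8, 9], dtype=torch.long)
batched = pad_and_concat(3, [a, b])                                  # [[5, 6, 7], [8, 9, 0]]
mask = pad_and_concat(3, [torch.ones_like(a), torch.ones_like(b)])   # [[1, 1, 1], [1, 1, 0]]

Batching the requests this way lets a single _model_call score every (context, continuation) pair in the chunk at once, with the attention mask marking which positions are real tokens rather than padding.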