gaoqiong / lm-evaluation-harness · Commits

Commit ffc3a456, authored Jun 28, 2023 by haileyschoelkopf

    push WIP batched multi-kwarg code

Parent: 83f957bc
Showing 1 changed file with 93 additions and 71 deletions.

lm_eval/models/huggingface.py (+93, -71), view file @ ffc3a456
@@ -3,6 +3,7 @@ import transformers
 from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
 
 import copy
+from collections import defaultdict
 from tqdm import tqdm
 
 import torch.nn.functional as F
@@ -520,87 +521,108 @@ class HFLM(LM):
 
         return re_ord.get_original(res)
 
     def greedy_until(self, requests):
-        res = []
+        res = defaultdict(list)
+        re_ords = {}
 
         def _collate(x):
             # the negative sign on len(toks) sorts descending - this has a few advantages:
             # - time estimates will always be over not underestimates, which is more useful for planning
             # - to know the size of a batch when going through the list, you know the first one is always the batch
             #   padded context length. this is useful to simplify the batching logic and more importantly to make
             #   automatic adaptive batches much much easier to implement
             # - any OOMs will happen right away rather than near the end
             toks = self.tok_encode(x[0])
             return -len(toks), x[0]
 
-        re_ord = utils.Reorderer([req.args for req in requests], _collate)
+        grouper = utils.Grouper(requests, lambda x: str(x.args[1]))
+        for key, reqs in grouper.get_grouped().items():
+            re_ords[key] = utils.Reorderer([req.args for req in reqs], _collate)
 
-        for chunk in utils.chunks(
-            tqdm(
-                re_ord.get_reordered(),
-                disable=(self.rank != 0),
-            ),
-            self.batch_size,
-        ):
-            contexts, all_gen_kwargs = zip(*chunk)
-            gen_kwargs = all_gen_kwargs[0]
-            # TODO: handle case where not all gen kwargs are same
-            until = None
-            if isinstance(gen_kwargs, dict):
-                kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
-                if "until" in kwargs.keys():
-                    until = kwargs.pop("until")
-                    if isinstance(until, str):
-                        until = [kwargs]
-                    elif not isinstance(until, list):
-                        raise ValueError(
-                            f"Expected `generation_kwargs['until']` to be of type Union[str,list] but got {until}"
-                        )
-            else:
-                raise ValueError(
-                    f"Expected `generation_kwargs` to be of type `dict` but got {kwargs}"
-                )
-            if not until:
-                until = [self.tok_decode(self.eot_token_id)]
-            if "max_gen_toks" in kwargs.keys():
-                max_gen_toks = kwargs.pop("max_gen_toks")
-            else:
-                max_gen_toks = self.max_gen_toks
-            # first stop sequence is used to halt generation upon encountering
-            (primary_until) = until[0]
-
-            # set the max length in tokens of inputs ("context_enc")
-            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
-                # max len for inputs = max length, minus room to generate the max new tokens
-                max_ctx_len = self.max_length - max_gen_toks
-            elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
-                # max len for inputs = encoder's whole max_length
-                max_ctx_len = self.max_length
-
-            # encode, pad, and truncate contexts
-            context_enc, attn_masks = self.tok_batch_encode(
-                contexts, left_truncate_len=max_ctx_len
-            )
-            context_enc = context_enc.to(self.device)
-            attn_masks = attn_masks.to(self.device)
-
-            cont = self._model_generate(
-                context=context_enc,
-                attention_mask=attn_masks,
-                max_length=context_enc.shape[1] + max_gen_toks,
-                stop=primary_until,
-                **kwargs,
-            )
-
-            cont_toks_list = cont.tolist()
-            for cont_toks, context in zip(cont_toks_list, contexts):
-                # discard context + left-padding toks if using causal decoder-only LM
-                if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
-                    cont_toks = cont_toks[context_enc.shape[1] :]
-
-                s = self.tok_decode(cont_toks)
-
-                # use secondary stop seqs to cut off should-have-been-stopped content post-hoc
-                for term in until:
-                    if len(term) > 0:
-                        # ignore '' separator, for seq2seq case where
-                        s = s.split(term)[0]
-
-                res.append(s)
-
-                self.cache_hook.add_partial("greedy_until", (context, gen_kwargs), s)
-
-        return re_ord.get_original(res)
+        pbar = tqdm(total=len(requests))
+        assert len(requests) == sum(
+            [len(list(re_ord.get_reordered())) for re_ord in re_ords.values()]
+        )
+        for key, re_ord in re_ords.items():
+            for chunk in utils.chunks(
+                # tqdm(
+                re_ord.get_reordered(),
+                # disable=(self.rank != 0),
+                # ),
+                self.batch_size,
+            ):
+                contexts, all_gen_kwargs = zip(*chunk)
+                gen_kwargs = all_gen_kwargs[0]
+                # TODO: handle case where not all gen kwargs are same
+                until = None
+                if isinstance(gen_kwargs, dict):
+                    kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
+                    if "until" in kwargs.keys():
+                        until = kwargs.pop("until")
+                        if isinstance(until, str):
+                            until = [kwargs]
+                        elif not isinstance(until, list):
+                            raise ValueError(
+                                f"Expected `generation_kwargs['until']` to be of type Union[str,list] but got {until}"
+                            )
+                else:
+                    raise ValueError(
+                        f"Expected `generation_kwargs` to be of type `dict` but got {kwargs}"
+                    )
+                if not until:
+                    until = [self.tok_decode(self.eot_token_id)]
+                if "max_gen_toks" in kwargs.keys():
+                    max_gen_toks = kwargs.pop("max_gen_toks")
+                else:
+                    max_gen_toks = self.max_gen_toks
+                # first stop sequence is used to halt generation upon encountering
+                (primary_until) = until[0]
+
+                # set the max length in tokens of inputs ("context_enc")
+                if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
+                    # max len for inputs = max length, minus room to generate the max new tokens
+                    max_ctx_len = self.max_length - max_gen_toks
+                elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
+                    # max len for inputs = encoder's whole max_length
+                    max_ctx_len = self.max_length
+
+                # encode, pad, and truncate contexts
+                context_enc, attn_masks = self.tok_batch_encode(
+                    contexts, left_truncate_len=max_ctx_len
+                )
+                context_enc = context_enc.to(self.device)
+                attn_masks = attn_masks.to(self.device)
+
+                cont = self._model_generate(
+                    context=context_enc,
+                    attention_mask=attn_masks,
+                    max_length=context_enc.shape[1] + max_gen_toks,
+                    stop=primary_until,
+                    **kwargs,
+                )
+
+                cont_toks_list = cont.tolist()
+                for cont_toks, context in zip(cont_toks_list, contexts):
+                    # discard context + left-padding toks if using causal decoder-only LM
+                    if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
+                        cont_toks = cont_toks[context_enc.shape[1] :]
+
+                    s = self.tok_decode(cont_toks)
+
+                    # use secondary stop seqs to cut off should-have-been-stopped content post-hoc
+                    for term in until:
+                        if len(term) > 0:
+                            # ignore '' separator, for seq2seq case where
+                            s = s.split(term)[0]
+
+                    res[str(gen_kwargs)].append(s)
+                    # TODO: move this to res[-1].append(s) to separate per re_ord
+                    self.cache_hook.add_partial("greedy_until", (context, gen_kwargs), s)
+                    pbar.update(1)
+
+            res[key] = re_ord.get_original(res[key])
+
+        pbar.close()
+
+        return grouper.get_original(res)
+        # return utils.join_iters([re_ord.get_original(rs) for re_ord, rs in zip(re_ords, res.values())])
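The new control flow leans on utils.Grouper, whose implementation is not part of this diff; only its call sites appear (utils.Grouper(requests, lambda x: str(x.args[1])), get_grouped(), and get_original(res)). Below is a minimal sketch of the interface those calls imply, assuming the class buckets items by a key function and can later scatter per-bucket results back into the original request order. The names and internals are guesses for illustration, not the harness's actual code.

# Hypothetical sketch of the utils.Grouper interface this commit calls into.
# Only Grouper(...), .get_grouped(), and .get_original(...) appear in the diff,
# so everything below is an assumption about the intended behavior.
from collections import defaultdict


class Grouper:
    """Bucket items by key_fn, then restore the original ordering afterwards."""

    def __init__(self, items, key_fn):
        self.size = len(items)
        buckets = defaultdict(list)
        # remember each item's original index so get_original can undo the grouping
        for idx, item in enumerate(items):
            buckets[key_fn(item)].append((idx, item))
        self._indices = {k: [i for i, _ in pairs] for k, pairs in buckets.items()}
        self._grouped = {k: [x for _, x in pairs] for k, pairs in buckets.items()}

    def get_grouped(self):
        # dict: key -> list of items sharing that key, in encounter order
        return self._grouped

    def get_original(self, grouped_results):
        # grouped_results: key -> per-item results aligned with get_grouped()[key]
        out = [None] * self.size
        for key, results in grouped_results.items():
            for idx, r in zip(self._indices[key], results):
                out[idx] = r
        return out

Under that reading, res[key] = re_ord.get_original(res[key]) undoes the within-group length sort performed by each Reorderer, and the final grouper.get_original(res) undoes the grouping itself, so results come back in the order the requests arrived.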
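As for why requests are grouped by str(x.args[1]) in the first place: each request's second arg is its generation-kwargs dict, and only requests with identical kwargs can safely share one batched _model_generate call (the loop takes gen_kwargs = all_gen_kwargs[0] for the whole chunk). A toy illustration with made-up requests, showing how the stringified dict partitions a mixed stream into batchable buckets:

# toy data: (context, gen_kwargs) pairs; values invented for illustration only
requests = [
    ("ctx a", {"until": ["\n"], "max_gen_toks": 32}),
    ("ctx b", {"until": ["."]}),
    ("ctx c", {"until": ["\n"], "max_gen_toks": 32}),
]

buckets = {}
for ctx, gen_kwargs in requests:
    # str() of the dict is a cheap hashable stand-in for "same kwargs"
    buckets.setdefault(str(gen_kwargs), []).append(ctx)

# "ctx a" and "ctx c" land in one batchable bucket; "ctx b" in another
for key, ctxs in buckets.items():
    print(key, "->", ctxs)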