gaoqiong / lm-evaluation-harness · Commits

Commit b250b001
Authored Jun 28, 2023 by haileyschoelkopf
Parent: ffc3a456

    clean up batched code and add comments
Showing 1 changed file with 20 additions and 16 deletions.
lm_eval/models/huggingface.py  (+20 -16)

@@ -534,26 +534,27 @@ class HFLM(LM):
             toks = self.tok_encode(x[0])
             return -len(toks), x[0]

+        # we group requests by their generation_kwargs,
+        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
+        # in the same batch.
         grouper = utils.Grouper(requests, lambda x: str(x.args[1]))
         for key, reqs in grouper.get_grouped().items():
+            # within each set of reqs for given kwargs, we reorder by token length, descending.
             re_ords[key] = utils.Reorderer([req.args for req in reqs], _collate)

-        pbar = tqdm(total=len(requests))
-        assert len(requests) == sum([len(list(re_ord.get_reordered())) for re_ord in re_ords.values()])
+        pbar = tqdm(total=len(requests), disable=(self.rank != 0))

+        # for each different set of kwargs, we execute all requests, by batch.
         for key, re_ord in re_ords.items():
             for chunk in utils.chunks(
                 # tqdm(
                 re_ord.get_reordered(),
                 # disable=(self.rank != 0),
                 # ),
                 self.batch_size,
             ):
                 contexts, all_gen_kwargs = zip(*chunk)
-                gen_kwargs = all_gen_kwargs[0]
                 # TODO: handle case where not all gen kwargs are same
+                # we assume all gen kwargs in the batch are the same
+                # this is safe to assume because the `grouper` object ensures it.
+                gen_kwargs = all_gen_kwargs[0]
+                # unpack our keyword arguments.
                 until = None
                 if isinstance(gen_kwargs, dict):
                     kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
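Note on the hunk above (editorial aside, not part of the diff): utils.Grouper and utils.Reorderer are defined elsewhere in the harness and their implementations are not shown on this page. The sketch below is a minimal stand-in for the group-by-generation-kwargs, reorder-by-length, restore-original-order round-trip that this code depends on; the class and method names mirror the usage above, but the bodies are illustrative guesses, not the harness's actual code.

from collections import defaultdict

class Grouper:
    # bucket items by a key function, remembering each item's original position
    def __init__(self, arr, fn):
        self.size = len(arr)
        self.groups = defaultdict(list)  # key -> [(original_index, item), ...]
        for i, x in enumerate(arr):
            self.groups[fn(x)].append((i, x))

    def get_grouped(self):
        return {k: [x for _, x in v] for k, v in self.groups.items()}

    def get_original(self, grouped_results):
        # scatter per-group result lists back into one flat list in input order
        out = [None] * self.size
        for k, v in self.groups.items():
            for (i, _), r in zip(v, grouped_results[k]):
                out[i] = r
        return out

class Reorderer:
    # sort items by a collation key (e.g. -token_length, for descending) and undo it later
    def __init__(self, arr, fn):
        self.order = sorted(range(len(arr)), key=lambda i: fn(arr[i]))
        self.arr = [arr[i] for i in self.order]

    def get_reordered(self):
        return self.arr

    def get_original(self, results):
        out = [None] * len(self.order)
        for pos, i in enumerate(self.order):
            out[i] = results[pos]
        return out

# round-trip check: group by a toy "kwargs" key, process longest-first, restore order
reqs = ["bb", "a", "dddd", "ccc"]
g = Grouper(reqs, lambda s: str(len(s) % 2))
res = {}
for key, items in g.get_grouped().items():
    ro = Reorderer(items, lambda s: -len(s))
    res[key] = ro.get_original([s.upper() for s in ro.get_reordered()])
assert g.get_original(res) == ["BB", "A", "DDDD", "CCC"]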
@@ -586,13 +587,14 @@ class HFLM(LM):
                     # max len for inputs = encoder's whole max_length
                     max_ctx_len = self.max_length

-                # encode, pad, and truncate contexts
+                # encode, pad, and truncate contexts for this batch
                 context_enc, attn_masks = self.tok_batch_encode(contexts, left_truncate_len=max_ctx_len)
                 context_enc = context_enc.to(self.device)
                 attn_masks = attn_masks.to(self.device)

+                # perform batched generation
                 cont = self._model_generate(
                     context=context_enc,
                     attention_mask=attn_masks,
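Note on the hunk above (editorial aside, not part of the diff): tok_batch_encode and _model_generate are defined elsewhere in huggingface.py and are not shown here. As a rough sketch of the behavior this call site assumes, left-padded batch encoding with left truncation plus greedy batched generation can be approximated with a stock transformers tokenizer and model as below; the helper is a hypothetical stand-in, not the harness's implementation.

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"  # left-pad so generation continues from real text
model = AutoModelForCausalLM.from_pretrained("gpt2")

def tok_batch_encode(contexts, left_truncate_len=None):
    # hypothetical stand-in for the harness's method of the same name
    enc = tokenizer(contexts, return_tensors="pt", padding=True)
    input_ids, attn = enc["input_ids"], enc["attention_mask"]
    if left_truncate_len is not None:
        # keep only the rightmost tokens; the left side is the droppable prefix
        input_ids = input_ids[:, -left_truncate_len:]
        attn = attn[:, -left_truncate_len:]
    return input_ids, attn

context_enc, attn_masks = tok_batch_encode(
    ["The capital of France is", "2 + 2 ="], left_truncate_len=1024
)
cont = model.generate(
    input_ids=context_enc,
    attention_mask=attn_masks,
    max_new_tokens=16,
    do_sample=False,  # greedy decoding, matching the default greedy_until behavior
    pad_token_id=tokenizer.eos_token_id,
)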
@@ -611,18 +613,20 @@ class HFLM(LM):
                 # use secondary stop seqs to cut off should-have-been-stopped content post-hoc
                 for term in until:
-                    if len(term) > 0:
-                        # ignore '' separator, for seq2seq case where
+                    if len(term) > 0:
+                        # ignore '' separator,
+                        # for seq2seq case where self.tok_decode(self.eot_token_id) = ''
                         s = s.split(term)[0]

-                res[str(gen_kwargs)].append(s)
+                # TODO: move this to res[-1].append(s) to separate per re_ord
+                res[key].append(s)

                 self.cache_hook.add_partial("greedy_until", (context, gen_kwargs), s)
                 pbar.update(1)

+            # reorder this group of results back to original unsorted form
             res[key] = re_ord.get_original(res[key])

         pbar.close()

         return grouper.get_original(res)
-        # return utils.join_iters([re_ord.get_original(rs) for re_ord, rs in zip(re_ords, res.values())])
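Note on the hunk above (editorial aside, not part of the diff): the split(term)[0] idiom trims stop sequences after the fact, since batched HF generation can overrun string-level stop sequences. A tiny runnable sketch of the same idiom follows; the function name is invented for illustration.

def cut_at_stop_sequences(generated, until):
    # cumulative split-and-keep-prefix is equivalent to cutting the string at the
    # earliest occurrence of any stop sequence
    for term in until:
        if len(term) > 0:  # skip '', e.g. when a seq2seq tokenizer decodes EOS to ''
            generated = generated.split(term)[0]
    return generated

assert cut_at_stop_sequences("Answer: 42\nQ: next", ["\n", "Q:"]) == "Answer: 42"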