Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
ffc3a456
Commit
ffc3a456
authored
Jun 28, 2023
by
haileyschoelkopf
Browse files
push WIP batched multi-kwarg code
parent
83f957bc
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
93 additions
and
71 deletions
+93
-71
lm_eval/models/huggingface.py
lm_eval/models/huggingface.py
+93
-71
No files found.
lm_eval/models/huggingface.py
View file @
ffc3a456
...
...
@@ -3,6 +3,7 @@ import transformers
from
transformers.models.auto.modeling_auto
import
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
import
copy
from
collections
import
defaultdict
from
tqdm
import
tqdm
import
torch.nn.functional
as
F
...
...
@@ -520,19 +521,33 @@ class HFLM(LM):
return
re_ord
.
get_original
(
res
)
def
greedy_until
(
self
,
requests
):
res
=
[]
res
=
defaultdict
(
list
)
re_ords
=
{}
def
_collate
(
x
):
# the negative sign on len(toks) sorts descending - this has a few advantages:
# - time estimates will always be over not underestimates, which is more useful for planning
# - to know the size of a batch when going through the list, you know the first one is always the batch
# padded context length. this is useful to simplify the batching logic and more importantly to make
# automatic adaptive batches much much easier to implement
# - any OOMs will happen right away rather than near the end
toks
=
self
.
tok_encode
(
x
[
0
])
return
-
len
(
toks
),
x
[
0
]
re_ord
=
utils
.
Reorderer
([
req
.
args
for
req
in
requests
],
_collate
)
grouper
=
utils
.
Grouper
(
requests
,
lambda
x
:
str
(
x
.
args
[
1
]))
for
key
,
reqs
in
grouper
.
get_grouped
().
items
():
re_ords
[
key
]
=
utils
.
Reorderer
([
req
.
args
for
req
in
reqs
],
_collate
)
pbar
=
tqdm
(
total
=
len
(
requests
))
assert
len
(
requests
)
==
sum
(
[
len
(
list
(
re_ord
.
get_reordered
()))
for
re_ord
in
re_ords
.
values
()]
)
for
key
,
re_ord
in
re_ords
.
items
():
for
chunk
in
utils
.
chunks
(
tqdm
(
#
tqdm(
re_ord
.
get_reordered
(),
disable
=
(
self
.
rank
!=
0
),
),
#
disable=(self.rank != 0),
#
),
self
.
batch_size
,
):
contexts
,
all_gen_kwargs
=
zip
(
*
chunk
)
...
...
@@ -599,8 +614,15 @@ class HFLM(LM):
if
len
(
term
)
>
0
:
# ignore '' separator, for seq2seq case where
s
=
s
.
split
(
term
)[
0
]
res
.
append
(
s
)
res
[
str
(
gen_kwargs
)].
append
(
s
)
# TODO: move this to res[-1].append(s) to separate per re_ord
self
.
cache_hook
.
add_partial
(
"greedy_until"
,
(
context
,
gen_kwargs
),
s
)
return
re_ord
.
get_original
(
res
)
self
.
cache_hook
.
add_partial
(
"greedy_until"
,
(
context
,
gen_kwargs
),
s
)
pbar
.
update
(
1
)
res
[
key
]
=
re_ord
.
get_original
(
res
[
key
])
pbar
.
close
()
return
grouper
.
get_original
(
res
)
# return utils.join_iters([re_ord.get_original(rs) for re_ord, rs in zip(re_ords, res.values())])
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment