gaoqiong / lm-evaluation-harness · Commits

Commit e377c47f, authored Jul 02, 2024 by Nathan Habib

linting

parent 84f59a7f
Showing 2 changed files with 37 additions and 24 deletions (+37 -24)
lm_eval/models/huggingface.py (+27 -21)
lm_eval/models/utils.py (+10 -3)
lm_eval/models/huggingface.py
@@ -13,7 +13,6 @@ from accelerate import (
     InitProcessGroupKwargs,
     find_executable_batch_size,
 )
 from accelerate.utils import get_max_memory
 from huggingface_hub import HfApi
 from packaging import version
 from peft import PeftModel
@@ -680,17 +679,25 @@ class HFLM(TemplateLM):
         return None

     def _detect_batch_size(self, requests=None, pos: int = 0) -> int:
-        if len(requests[0]) == 3:  # logprob evals
+        if len(requests[0]) == 3:  # logprob evals
             _, context_enc, continuation_enc = requests[pos]
             max_length = len(
                 (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1]
             )
             max_context_enc = len(context_enc[-(self.max_length + 1) :])
             max_cont_enc = len(continuation_enc[-(self.max_length + 1) :])
-            security_margin_factor = 4  # batch sizes for log prob evals sometimes generate OOMs
-        elif len(requests[0]) == 2:  # generative evals
+            security_margin_factor = (
+                4  # batch sizes for log prob evals sometimes generate OOMs
+            )
+        elif len(requests[0]) == 2:  # generative evals
             # using rolling window with maximum context
-            longest_context = max([len(self.tok_encode(request[0])) + request[1].get("max_gen_toks", self.max_length) for request in requests[pos:]])
+            longest_context = max(
+                [
+                    len(self.tok_encode(request[0]))
+                    + request[1].get("max_gen_toks", self.max_length)
+                    for request in requests[pos:]
+                ]
+            )
             if longest_context > self.max_length:
                 eval_logger.warning(
                     f"Longest context length of {longest_context} exceeds max_length of {self.max_length}. Truncating to max_length."
@@ -701,7 +708,6 @@ class HFLM(TemplateLM):
                 max_cont_enc = max_length
             security_margin_factor = 4

         # if OOM, then halves batch_size and tries again
         @find_executable_batch_size(starting_batch_size=self.max_batch_size)
         def forward_batch(batch_size):
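For context on the two hunks above: _detect_batch_size probes GPU memory by running dummy forward passes through the nested forward_batch helper, which is wrapped in accelerate's find_executable_batch_size decorator; the decorator retries with a halved batch size whenever a CUDA out-of-memory error is raised, and the security_margin / security_margin_factor values appear to inflate the probe so the detected batch size keeps some headroom. A minimal sketch of that retry pattern, assuming a generic PyTorch model (the probe_batch_size / sample_batch names and shapes below are illustrative, not taken from this diff):

import torch
from accelerate import find_executable_batch_size

def probe_batch_size(model, sample_batch, device="cuda", max_batch_size=512):
    # find_executable_batch_size calls the wrapped function with
    # starting_batch_size first, then halves batch_size and retries
    # whenever a CUDA OOM is raised, until a call succeeds.
    @find_executable_batch_size(starting_batch_size=max_batch_size)
    def forward_batch(batch_size):
        batch = sample_batch[:1].repeat(batch_size, 1).to(device)
        with torch.no_grad():
            model(batch)  # dummy forward pass, output discarded
        return batch_size

    return forward_batch()  # called without arguments on purpose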
@@ -711,7 +717,9 @@ class HFLM(TemplateLM):
             batched_conts = torch.ones(
                 (batch_size + security_margin, length), device=self.device
             ).long()
-            test_batch = torch.ones((batch_size + security_margin, length), device=self.device).long()
+            test_batch = torch.ones(
+                (batch_size + security_margin, length), device=self.device
+            ).long()
             call_kwargs = {
                 "attn_mask": test_batch,
                 "labels": batched_conts,
@@ -722,7 +730,7 @@ class HFLM(TemplateLM):
                 (batch_size + security_margin, max_length), device=self.device
             ).long()
-            for _ in range(5 * security_margin_factor):
+            for _ in range(5 * security_margin_factor):
                 logits = self._model_call(inps=test_batch, **call_kwargs).float()
                 scores = F.log_softmax(logits, dim=-1)  # noqa: F841
@@ -1122,7 +1130,9 @@ class HFLM(TemplateLM):
             }

             multi_logits = F.log_softmax(
-                self._model_call(batched_inps, **call_kwargs), dim=-1, dtype=torch.float16
+                self._model_call(batched_inps, **call_kwargs),
+                dim=-1,
+                dtype=torch.float16,
             )  # [batch, padding_length (inp or cont), vocab]

             for (request_str, ctx_tokens, _), logits, inplen, cont_toks in zip(
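The hunk above only rewraps the call, but the dtype argument it carries is the interesting part: passing dtype=torch.float16 to F.log_softmax casts the logits before the operation, so the resulting [batch, padding_length, vocab] log-probability tensor is stored in half precision and uses roughly half the memory. An illustrative, standalone snippet (shapes made up):

import torch
import torch.nn.functional as F

logits = torch.randn(2, 8, 32000)  # [batch, seq, vocab] in float32
logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float16)
print(logprobs.dtype)  # torch.float16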
@@ -1200,16 +1210,8 @@ class HFLM(TemplateLM):
             disable=(disable_tqdm or (self.rank != 0)),
             desc="Running generate_until requests",
         )
-        batch_size = (
-            self.batch_size
-            if self.batch_size != "auto"
-            else 0
-        )
-        batch_fn = (
-            self._batch_scheduler
-            if self.batch_size == "auto"
-            else None
-        )
+        batch_size = self.batch_size if self.batch_size != "auto" else 0
+        batch_fn = self._batch_scheduler if self.batch_size == "auto" else None
         # we group requests by their generation_kwargs,
         # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
@@ -1221,7 +1223,9 @@ class HFLM(TemplateLM):
             group_by="gen_kwargs",
             group_fn=lambda x: x[1],
         )
-        chunks = re_ords.get_batched(n=batch_size, batch_fn=batch_fn, reset_batch_fn=self._reset_batch_scheduler)
+        chunks = re_ords.get_batched(
+            n=batch_size, batch_fn=batch_fn, reset_batch_fn=self._reset_batch_scheduler
+        )
         for chunk in chunks:
             contexts, all_gen_kwargs = zip(*chunk)
             # we assume all gen kwargs in the batch are the same
@@ -1252,7 +1256,9 @@ class HFLM(TemplateLM):
             if "max_gen_toks" in kwargs.keys():
                 max_gen_toks = kwargs.pop("max_gen_toks")
-                if max_gen_toks > self.max_length:  # some model have low max length limit
+                if (
+                    max_gen_toks > self.max_length
+                ):  # some model have low max length limit
                     max_gen_toks = self.max_gen_toks
             else:
                 max_gen_toks = self.max_gen_toks
lm_eval/models/utils.py
@@ -389,7 +389,12 @@ class Collator:
                 self._arr_with_indices, fn=self._group_fn, group_by="contexts"
             )

-    def get_batched(self, n: int = 1, batch_fn: Optional[Callable] = None, reset_batch_fn: Optional[Callable] = None) -> Iterator:
+    def get_batched(
+        self,
+        n: int = 1,
+        batch_fn: Optional[Callable] = None,
+        reset_batch_fn: Optional[Callable] = None,
+    ) -> Iterator:
         """
         Generates and yields batches from the reordered array. The method of grouping and batching
         depends on the parameter `group_by`.
@@ -402,7 +407,7 @@ class Collator:
         - n (int): The size of each batch. Defaults to 1.
         - batch_fn ([Callable[[int, Iterable], int]] | None): A function to determine the size of
           each batch. Optional, defaults to None.
-        - reset_batch_fn ([Callable[[int, Iterable], int]] | None): A function to reset the scheduler of
+        - reset_batch_fn ([Callable[[int, Iterable], int]] | None): A function to reset the scheduler of
           the batch_fn, if present, when we change group in generative mode.

         Returns:
@@ -414,7 +419,9 @@ class Collator:
         """
         if self._group_by == "gen_kwargs":
             for key, values in self._arr_with_indices.items():  # type: ignore
-                if reset_batch_fn is not None:  # with each group change, we must recompute the batch size, so we restart the scheduler
+                if (
+                    reset_batch_fn is not None
+                ):  # with each group change, we must recompute the batch size, so we restart the scheduler
                     reset_batch_fn()
                 values = self._reorder(values)
                 batch = self.get_chunks(values, n=n, fn=batch_fn)
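Taken together, the utils.py hunks thread a reset_batch_fn through Collator.get_batched so that, in generative mode, the batch-size scheduler is restarted whenever the iterator moves to a new gen_kwargs group (matching the reset_batch_fn=self._reset_batch_scheduler call added in huggingface.py above). A toy sketch of that pattern, not the harness's actual Collator, with hypothetical names:

from typing import Iterable, Iterator, List, Optional

class ToyScheduler:
    """Caches a batch size per group; reset() forces a recompute."""

    def __init__(self, default: int = 2):
        self.default = default
        self._cached: Optional[int] = None

    def batch_fn(self, pos: int, items: Iterable) -> int:
        if self._cached is None:
            self._cached = self.default  # the harness would probe GPU memory here
        return self._cached

    def reset(self) -> None:
        self._cached = None

def get_batched(groups: List[List[int]], sched: ToyScheduler) -> Iterator[List[int]]:
    for values in groups:
        sched.reset()  # new group -> recompute the batch size
        n = sched.batch_fn(0, values)
        for i in range(0, len(values), n):
            yield values[i : i + n]

for batch in get_batched([[1, 2, 3, 4, 5], [6, 7, 8]], ToyScheduler(default=2)):
    print(batch)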