gaoqiong / lm-evaluation-harness · Commits

Commit e377c47f
authored Jul 02, 2024 by Nathan Habib

linting

parent 84f59a7f
Showing 2 changed files with 37 additions and 24 deletions (+37 -24)
lm_eval/models/huggingface.py    +27 -21
lm_eval/models/utils.py          +10 -3
lm_eval/models/huggingface.py (view file @ e377c47f)
...
@@ -13,7 +13,6 @@ from accelerate import (
     InitProcessGroupKwargs,
     find_executable_batch_size,
 )
-from accelerate.utils import get_max_memory
 from huggingface_hub import HfApi
 from packaging import version
 from peft import PeftModel
...
@@ -680,17 +679,25 @@ class HFLM(TemplateLM):
         return None

     def _detect_batch_size(self, requests=None, pos: int = 0) -> int:
         if len(requests[0]) == 3:  # logprob evals
             _, context_enc, continuation_enc = requests[pos]
             max_length = len(
                 (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1]
             )
             max_context_enc = len(context_enc[-(self.max_length + 1) :])
             max_cont_enc = len(continuation_enc[-(self.max_length + 1) :])
-            security_margin_factor = 4  # batch sizes for log prob evals sometimes generate OOMs
+            security_margin_factor = (
+                4  # batch sizes for log prob evals sometimes generate OOMs
+            )
         elif len(requests[0]) == 2:  # generative evals
             # using rolling window with maximum context
-            longest_context = max([len(self.tok_encode(request[0])) + request[1].get("max_gen_toks", self.max_length) for request in requests[pos:]])
+            longest_context = max(
+                [
+                    len(self.tok_encode(request[0]))
+                    + request[1].get("max_gen_toks", self.max_length)
+                    for request in requests[pos:]
+                ]
+            )
             if longest_context > self.max_length:
                 eval_logger.warning(
                     f"Longest context length of {longest_context} exceeds max_length of {self.max_length}. Truncating to max_length."
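For orientation, the sizing logic touched above handles two request shapes: 3-tuples for logprob evals and 2-tuples for generative evals. A standalone sketch of the same arithmetic with toy values (not code from this commit; the numbers and lists are made up):

# Illustrative sketch only: reproduces the length arithmetic from the hunk
# above with toy data instead of the harness's real tokenized requests.
max_length = 2048  # assumed model context size

# logprob-style request: (key, context_enc, continuation_enc)
context_enc = list(range(3000))
continuation_enc = list(range(10))

# longest sequence the model actually sees: the last max_length tokens,
# with the final token dropped (it is only a target, never an input)
probe_len = len((context_enc + continuation_enc)[-(max_length + 1):][:-1])
print(probe_len)  # 2048

# generative-style request: prompt length plus the generation budget,
# clamped to max_length (the warning in the hunk fires when it exceeds it)
prompt_len, max_gen_toks = 1500, 256
longest_context = min(prompt_len + max_gen_toks, max_length)
print(longest_context)  # 1756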
...
@@ -701,7 +708,6 @@ class HFLM(TemplateLM):
             max_cont_enc = max_length
             security_margin_factor = 4

         # if OOM, then halves batch_size and tries again
         @find_executable_batch_size(starting_batch_size=self.max_batch_size)
         def forward_batch(batch_size):
...
@@ -711,7 +717,9 @@ class HFLM(TemplateLM):
                 batched_conts = torch.ones(
                     (batch_size + security_margin, length), device=self.device
                 ).long()
-                test_batch = torch.ones((batch_size + security_margin, length), device=self.device).long()
+                test_batch = torch.ones(
+                    (batch_size + security_margin, length), device=self.device
+                ).long()
                 call_kwargs = {
                     "attn_mask": test_batch,
                     "labels": batched_conts,
...
@@ -722,7 +730,7 @@ class HFLM(TemplateLM):
                     (batch_size + security_margin, max_length), device=self.device
                 ).long()
             for _ in range(5 * security_margin_factor):
                 logits = self._model_call(inps=test_batch, **call_kwargs).float()
                 scores = F.log_softmax(logits, dim=-1)  # noqa: F841
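The decorated forward_batch probe above relies on accelerate's halve-on-OOM retry behaviour. A minimal, hand-rolled sketch of that idea (illustration only; find_largest_batch_size and fake_forward are hypothetical names, not the accelerate implementation):

# Illustrative sketch, not the accelerate implementation: retry the probe
# with a halved batch size whenever the attempt runs out of memory, and
# settle on the largest size that succeeds.
def find_largest_batch_size(try_batch, starting_batch_size=64):
    batch_size = starting_batch_size
    while batch_size > 0:
        try:
            return try_batch(batch_size)
        except RuntimeError as e:
            if "out of memory" not in str(e):
                raise
            batch_size //= 2
    raise RuntimeError("no executable batch size found")


# toy workload: pretend anything above 16 sequences would OOM
def fake_forward(batch_size):
    if batch_size > 16:
        raise RuntimeError("CUDA out of memory.")
    return batch_size


print(find_largest_batch_size(fake_forward))  # 16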
...
@@ -1122,7 +1130,9 @@ class HFLM(TemplateLM):
             }

             multi_logits = F.log_softmax(
-                self._model_call(batched_inps, **call_kwargs), dim=-1, dtype=torch.float16
+                self._model_call(batched_inps, **call_kwargs),
+                dim=-1,
+                dtype=torch.float16,
             )  # [batch, padding_length (inp or cont), vocab]

             for (request_str, ctx_tokens, _), logits, inplen, cont_toks in zip(
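The reformatted call above is the scoring step of the loglikelihood path: log-probabilities are taken over the vocabulary and later gathered at the continuation-token positions. A self-contained sketch of that pattern on random tensors (shapes and variable names are illustrative only):

# Illustrative sketch, not part of the commit: score continuation tokens by
# taking log-probabilities over the vocab dim and gathering the chosen ids.
import torch
import torch.nn.functional as F

batch, seq_len, vocab = 2, 5, 11
logits = torch.randn(batch, seq_len, vocab)

# [batch, seq_len, vocab] log-probabilities, as in the hunk above
multi_logits = F.log_softmax(logits, dim=-1)

# toy continuation token ids, one per position
cont_toks = torch.randint(vocab, (batch, seq_len))

# per-token log-probability of each continuation token
token_logprobs = torch.gather(multi_logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)
print(token_logprobs.shape)  # torch.Size([2, 5])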
...
@@ -1200,16 +1210,8 @@ class HFLM(TemplateLM):
             disable=(disable_tqdm or (self.rank != 0)),
             desc="Running generate_until requests",
         )
-        batch_size = (
-            self.batch_size
-            if self.batch_size != "auto"
-            else 0
-        )
-        batch_fn = (
-            self._batch_scheduler
-            if self.batch_size == "auto"
-            else None
-        )
+        batch_size = self.batch_size if self.batch_size != "auto" else 0
+        batch_fn = self._batch_scheduler if self.batch_size == "auto" else None
         # we group requests by their generation_kwargs,
         # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
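A brief note on the collapsed lines above: with batch_size="auto" the fixed size drops to 0 and the batch scheduler callable takes over. A toy restatement of that selection (values made up; the lambda stands in for self._batch_scheduler):

# Illustrative sketch, not part of the commit: mirrors the auto/fixed
# batch-size selection from the hunk above with stand-in values.
requested = "auto"  # what the user passed as the batch size
scheduler = lambda pos, reqs: 8  # stand-in for self._batch_scheduler

batch_size = requested if requested != "auto" else 0
batch_fn = scheduler if requested == "auto" else None
print(batch_size, batch_fn is not None)  # 0 True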
...
@@ -1221,7 +1223,9 @@ class HFLM(TemplateLM):
             group_by="gen_kwargs",
             group_fn=lambda x: x[1],
         )
-        chunks = re_ords.get_batched(n=batch_size, batch_fn=batch_fn, reset_batch_fn=self._reset_batch_scheduler)
+        chunks = re_ords.get_batched(
+            n=batch_size, batch_fn=batch_fn, reset_batch_fn=self._reset_batch_scheduler
+        )
         for chunk in chunks:
             contexts, all_gen_kwargs = zip(*chunk)
             # we assume all gen kwargs in the batch are the same
...
@@ -1252,7 +1256,9 @@ class HFLM(TemplateLM):
             if "max_gen_toks" in kwargs.keys():
                 max_gen_toks = kwargs.pop("max_gen_toks")
-                if max_gen_toks > self.max_length:  # some model have low max length limit
+                if (
+                    max_gen_toks > self.max_length
+                ):  # some model have low max length limit
                     max_gen_toks = self.max_gen_toks
             else:
                 max_gen_toks = self.max_gen_toks
...
lm_eval/models/utils.py (view file @ e377c47f)
...
@@ -389,7 +389,12 @@ class Collator:
                 self._arr_with_indices, fn=self._group_fn, group_by="contexts"
             )

-    def get_batched(self, n: int = 1, batch_fn: Optional[Callable] = None, reset_batch_fn: Optional[Callable] = None) -> Iterator:
+    def get_batched(
+        self,
+        n: int = 1,
+        batch_fn: Optional[Callable] = None,
+        reset_batch_fn: Optional[Callable] = None,
+    ) -> Iterator:
         """
         Generates and yields batches from the reordered array. The method of grouping and batching
         depends on the parameter `group_by`.
...
@@ -402,7 +407,7 @@ class Collator:
         - n (int): The size of each batch. Defaults to 1.
         - batch_fn ([Callable[[int, Iterable], int]] | None): A function to determine the size of
           each batch. Optional, defaults to None.
         - reset_batch_fn ([Callable[[int, Iterable], int]] | None): A function to reset the scheduler of
           the batch_fn, if present, when we change group in generative mode.
         Returns:
...
@@ -414,7 +419,9 @@ class Collator:
         """
         if self._group_by == "gen_kwargs":
             for key, values in self._arr_with_indices.items():  # type: ignore
-                if reset_batch_fn is not None:  # with each group change, we must recompute the batch size, so we restart the scheduler
+                if (
+                    reset_batch_fn is not None
+                ):  # with each group change, we must recompute the batch size, so we restart the scheduler
                     reset_batch_fn()
                 values = self._reorder(values)
                 batch = self.get_chunks(values, n=n, fn=batch_fn)
...
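To close, the reset_batch_fn hook wrapped above is what lets the auto batch-size scheduler start over whenever get_batched moves to a new gen_kwargs group. A minimal sketch of that batching pattern (batched_by_group is a hypothetical stand-in, not the Collator implementation):

# Illustrative sketch, not the Collator implementation: yield batches group
# by group, calling an optional reset hook at every group boundary so a
# dynamic batch-size scheduler can recompute for the new kind of request.
from typing import Callable, Dict, Iterator, List, Optional


def batched_by_group(
    groups: Dict[str, List[int]],
    n: int = 1,
    batch_fn: Optional[Callable[[int, List[int]], int]] = None,
    reset_batch_fn: Optional[Callable[[], None]] = None,
) -> Iterator[List[int]]:
    for _group, values in groups.items():
        if reset_batch_fn is not None:
            reset_batch_fn()  # new group: start the scheduler from scratch
        i = 0
        while i < len(values):
            size = batch_fn(i, values) if batch_fn is not None else n
            yield values[i : i + size]
            i += size


# toy usage: two gen_kwargs groups, fixed batch size of 2
print(list(batched_by_group({"greedy": [1, 2, 3], "temp=0.8": [4, 5]}, n=2)))
# [[1, 2], [3], [4, 5]]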