gaoqiong / lm-evaluation-harness · Commit 992f021b (Unverified)

Authored Oct 11, 2023 by Lintang Sutawika; committed by GitHub on Oct 11, 2023.
Merge pull request #912 from AndyWolfZwei/andy/big-refactor-auto-batch
[Refactor] Add _batch_scheduler in greedy_until
Parents: 2afe7770, 660dfb71
Changes: 3 changed files with 35 additions and 25 deletions.
- .github/workflows/unit_tests.yml (+1, -1)
- lm_eval/models/huggingface.py (+33, -23)
- lm_eval/utils.py (+1, -1)
.github/workflows/unit_tests.yml
```diff
@@ -43,7 +43,7 @@ jobs:
   # # mypy turned off for now
   # - name: Lint with mypy
   #   run: mypy . --ignore-missing-imports --check-untyped-defs --explicit-package-bases --warn-unreachable
-  Job 2:
+  testcpu:
     name: CPU Tests
     runs-on: ubuntu-latest
```
lm_eval/models/huggingface.py
```diff
@@ -621,6 +621,23 @@ class HFLM(LM):
 
         return loglikelihoods
 
+    def _batch_scheduler(self, pos, n_reordered_requests):
+        sched = pos // int(len(n_reordered_requests) / self.batch_schedule)
+        if sched in self.batch_sizes:
+            return self.batch_sizes[sched]
+        if (len(self.batch_sizes) > 1) and (
+            self.batch_sizes[sched - 1] == self.max_batch_size
+        ):
+            # if previous batch size is already maximal, skip recomputation
+            self.batch_sizes[sched] = self.max_batch_size
+            return self.batch_sizes[sched]
+        print(
+            f"Passed argument batch_size = auto:{self.batch_schedule}. Detecting largest batch size"
+        )
+        self.batch_sizes[sched] = self._detect_batch_size(n_reordered_requests, pos)
+        print(f"Determined largest batch size: {self.batch_sizes[sched]}")
+        return self.batch_sizes[sched]
+
     def _loglikelihood_tokens(
         self, requests, disable_tqdm: bool = False, override_bs=None
     ):
```
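In prose: the new method splits the reordered request list into `self.batch_schedule` equal segments, detects a workable batch size once per segment, caches it in `self.batch_sizes`, and short-circuits to `self.max_batch_size` once the previous segment has already hit the ceiling. A minimal standalone sketch of that bucket-and-cache behaviour (the class, the detector stub's fixed return value, and the example sizes are hypothetical, not from the diff):

```python
class TinyScheduler:
    """Standalone sketch of the bucket-and-cache logic in HFLM._batch_scheduler."""

    def __init__(self, batch_schedule=4, max_batch_size=64):
        self.batch_schedule = batch_schedule
        self.max_batch_size = max_batch_size
        self.batch_sizes = {}  # segment index -> detected batch size

    def _detect_batch_size(self, requests, pos):
        # stand-in for the real detector, which probes accelerator memory
        return 8

    def _batch_scheduler(self, pos, reordered_requests):
        # map the current position to one of `batch_schedule` segments
        sched = pos // int(len(reordered_requests) / self.batch_schedule)
        if sched in self.batch_sizes:
            return self.batch_sizes[sched]  # cached for this segment
        if len(self.batch_sizes) > 1 and self.batch_sizes[sched - 1] == self.max_batch_size:
            # previous segment already maxed out, skip re-detection
            self.batch_sizes[sched] = self.max_batch_size
            return self.batch_sizes[sched]
        self.batch_sizes[sched] = self._detect_batch_size(reordered_requests, pos)
        return self.batch_sizes[sched]


reqs = list(range(100))                  # 100 requests -> segments of 25
sched = TinyScheduler()
print(sched._batch_scheduler(0, reqs))   # segment 0: runs detection -> 8
print(sched._batch_scheduler(10, reqs))  # segment 0 again: cached -> 8
print(sched._batch_scheduler(30, reqs))  # segment 1: detects again -> 8
```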
```diff
@@ -644,25 +661,6 @@ class HFLM(LM):
 
         # automatic (variable) batch size detection for vectorization
         # pull longest context sample from request
-        def _batch_scheduler(pos):
-            sched = pos // int(n_reordered_requests / self.batch_schedule)
-            if sched in self.batch_sizes:
-                return self.batch_sizes[sched]
-            if (len(self.batch_sizes) > 1) and (
-                self.batch_sizes[sched - 1] == self.max_batch_size
-            ):
-                # if previous batch size is already maximal, skip recomputation
-                self.batch_sizes[sched] = self.max_batch_size
-                return self.batch_sizes[sched]
-            print(
-                f"Passed argument batch_size = auto:{self.batch_schedule}. Detecting largest batch size"
-            )
-            self.batch_sizes[sched] = self._detect_batch_size(
-                re_ord.get_reordered(), pos
-            )
-            print(f"Determined largest batch size: {self.batch_sizes[sched]}")
-            return self.batch_sizes[sched]
-
         for chunk in utils.chunks(
             tqdm(re_ord.get_reordered(), disable=(disable_tqdm or (self.rank != 0))),
             n=self.batch_size
```
```diff
@@ -670,7 +668,7 @@ class HFLM(LM):
             else override_bs
             if override_bs is not None
             else 0,
-            fn=_batch_scheduler
+            fn=self._batch_scheduler
             if self.batch_size == "auto"
             and n_reordered_requests > 0
             and not override_bs
```
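The one-line change above swaps the deleted local closure for the bound method at the call site in `_loglikelihood_tokens`. Because the method also needs the reordered request list, the callback contract becomes `fn(pos, iterable)` rather than `fn(pos)`, which is what the `lm_eval/utils.py` change at the bottom of this diff accommodates. A small hedged illustration of the guard and the handoff (all names and values below are made up for illustration, not taken from the harness):

```python
# Stand-ins, for illustration only: `scheduler` plays the role of
# self._batch_scheduler, i.e. any callable accepting (pos, reordered_requests).
batch_size = "auto"
override_bs = None
reordered = list(range(100))
scheduler = lambda pos, reordered_requests: 8

# mirrors: fn=self._batch_scheduler if batch_size == "auto"
#          and n_reordered_requests > 0 and not override_bs else None
fn = (
    scheduler
    if batch_size == "auto" and len(reordered) > 0 and not override_bs
    else None
)
print(fn(50, reordered) if fn else "fixed batch size used instead")  # -> 8
```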
```diff
@@ -838,12 +836,24 @@ class HFLM(LM):
             re_ords[key] = utils.Reorderer([req.args for req in reqs], _collate)
 
         pbar = tqdm(total=len(requests), disable=(self.rank != 0))
+
+        if self.batch_size == "auto":
+            # using rolling window with maximum context
+            print("Passed argument batch_size = auto. Detecting largest batch size")
+            batch_size = self._detect_batch_size()
+            print(f"Determined Largest batch size: {batch_size}")
+            adaptive_batch_size = batch_size
+
         # for each different set of kwargs, we execute all requests, by batch.
         for key, re_ord in re_ords.items():
-            for chunk in utils.chunks(re_ord.get_reordered(), self.batch_size):
+            for chunk in utils.chunks(
+                tqdm(re_ord.get_reordered(), disable=self.rank != 0),
+                n=self.batch_size
+                if self.batch_size != "auto"
+                else adaptive_batch_size
+                if adaptive_batch_size is not None
+                else 0,
+                fn=self._batch_scheduler
+                if self.batch_size == "auto" and not adaptive_batch_size
+                else None,
+            ):
                 contexts, all_gen_kwargs = zip(*chunk)
                 # we assume all gen kwargs in the batch are the same
```
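In `greedy_until`, the same machinery is now wired in: with a fixed integer batch size, `utils.chunks` receives that integer and no callback; with `batch_size = "auto"` and a size already pinned by `_detect_batch_size()`, it receives the pinned size; only with `auto` and no pre-detected size does `n` fall back to 0 and the scheduler callback take over per segment. A hedged sketch of just that selection logic (the helper function is hypothetical; only the conditional mirrors the diff):

```python
def select_chunking(batch_size, adaptive_batch_size, scheduler):
    """Return the (n, fn) pair that greedy_until would hand to utils.chunks."""
    n = (
        batch_size
        if batch_size != "auto"
        else adaptive_batch_size
        if adaptive_batch_size is not None
        else 0
    )
    fn = scheduler if batch_size == "auto" and not adaptive_batch_size else None
    return n, fn


scheduler = lambda pos, it: 8                      # stand-in for self._batch_scheduler
print(select_chunking(16, None, scheduler))        # fixed size    -> (16, None)
print(select_chunking("auto", 32, scheduler))      # pre-detected  -> (32, None)
n, fn = select_chunking("auto", None, scheduler)   # per-segment scheduling
print(n, fn is scheduler)                          # -> 0 True
```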
lm_eval/utils.py
```diff
@@ -78,7 +78,7 @@ def chunks(iter, n: int = 0, fn=None):
     arr = []
     for i, x in enumerate(iter):
         arr.append(x)
-        if len(arr) == (fn(i) if fn else n):
+        if len(arr) == (fn(i, iter) if fn else n):
             yield arr
             arr = []
```
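The generator itself needs only this one tweak: the callback now receives the iterable as well as the running index, so a scheduler can size each chunk from the requests themselves. A self-contained sketch of a `chunks`-style generator with the new two-argument callback (simplified relative to `lm_eval.utils.chunks`; the toy scheduler and sizes are made up):

```python
def chunks(iterable, n=0, fn=None):
    # simplified version of lm_eval.utils.chunks after this change:
    # the callback now receives (index, iterable), not just the index
    arr = []
    for i, x in enumerate(iterable):
        arr.append(x)
        if len(arr) == (fn(i, iterable) if fn else n):
            yield arr
            arr = []
    if arr:
        yield arr  # leftover items that never filled a chunk


# toy scheduler: small chunks for the first half of the list, larger ones afterwards
def toy_scheduler(i, iterable):
    return 2 if i < len(iterable) // 2 else 4


print(list(chunks(range(10), n=3)))
# fixed size    -> [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
print(list(chunks(list(range(10)), fn=toy_scheduler)))
# variable size -> [[0, 1], [2, 3], [4, 5, 6, 7], [8, 9]]
```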