Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
53754d41
Unverified
Commit
53754d41
authored
Jul 17, 2023
by
Lintang Sutawika
Committed by
GitHub
Jul 17, 2023
Browse files
Merge pull request #679 from EleutherAI/fix-padding-ranks
[Refactor] Fix padding ranks
parents
e5161a6d
abccc756
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
7 additions
and
4 deletions
+7
-4
lm_eval/evaluator.py
lm_eval/evaluator.py
+7
-4
No files found.
lm_eval/evaluator.py
View file @
53754d41
...
@@ -191,6 +191,8 @@ def evaluate(
...
@@ -191,6 +191,8 @@ def evaluate(
samples
=
collections
.
defaultdict
(
list
)
samples
=
collections
.
defaultdict
(
list
)
requests
=
collections
.
defaultdict
(
list
)
requests
=
collections
.
defaultdict
(
list
)
padding_requests
=
collections
.
defaultdict
(
int
)
# get lists of each type of request
# get lists of each type of request
for
task_name
,
task
in
task_dict
.
items
():
for
task_name
,
task
in
task_dict
.
items
():
versions
[
task_name
]
=
task
.
VERSION
versions
[
task_name
]
=
task
.
VERSION
...
@@ -239,6 +241,7 @@ def evaluate(
...
@@ -239,6 +241,7 @@ def evaluate(
# compute number of pseudobatches to pad with (FSDP/DDP require even batches among ranks)
# compute number of pseudobatches to pad with (FSDP/DDP require even batches among ranks)
numpad
=
max
(
gathered_item
)
-
gathered_item
[
lm
.
rank
]
numpad
=
max
(
gathered_item
)
-
gathered_item
[
lm
.
rank
]
padding_requests
[
task
.
OUTPUT_TYPE
]
+=
numpad
### Run LM on inputs, get all outputs ###
### Run LM on inputs, get all outputs ###
# execute each type of request
# execute each type of request
...
@@ -249,8 +252,8 @@ def evaluate(
...
@@ -249,8 +252,8 @@ def evaluate(
for
req
in
reqs
:
for
req
in
reqs
:
cloned_reqs
.
extend
([
req
]
*
req
.
repeats
)
cloned_reqs
.
extend
([
req
]
*
req
.
repeats
)
if
(
lm
.
world_size
>
1
)
and
(
num
pad
>
0
):
if
(
lm
.
world_size
>
1
)
and
(
pad
ding_requests
[
reqtype
]
>
0
):
for
_
in
range
(
num
pad
):
for
_
in
range
(
pad
ding_requests
[
reqtype
]
):
cloned_reqs
.
extend
([
req
]
*
req
.
repeats
)
cloned_reqs
.
extend
([
req
]
*
req
.
repeats
)
# run requests through model
# run requests through model
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment