gaoqiong / lm-evaluation-harness · Commits

Commit 4d49dd03
Authored Dec 28, 2023 by lintangsutawika
Parent: c6a91582

    aggregation to compute_metric

Showing 1 changed file, lm_eval/evaluator.py, with 14 additions and 16 deletions.
--- a/lm_eval/evaluator.py
+++ b/lm_eval/evaluator.py
@@ -367,7 +367,7 @@ def evaluate(
                 # subset instances to only this document id ; sort by idx
                 requests = list(filter(lambda x: x.doc_id == doc_id, task.instances))
                 requests.sort(key=lambda x: x.idx)
-                items = task.process_results(
+                metrics = task.process_results(
                     doc, [req.filtered_resps[key] for req in requests]
                 )
                 if log_samples:
@@ -380,11 +380,10 @@ def evaluate(
                         "resps": [req.resps for req in requests],
                         "filtered_resps": [req.filtered_resps[key] for req in requests],
                     }
-                    example.update(items)
+                    example.update(metrics)
                     samples[task_name].append(example)
-                vals[(task_name, key)].append(items)
-                # for metric, value in results.items():
-                # vals[(task_name, key, metric)].append(value)
+                for metric, value in metrics.items():
+                    vals[(task_name, key, metric)].append(value)
 
     if lm.world_size > 1:
         # if multigpu, then gather data across all ranks
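Taken together, the two hunks above change how per-document results are accumulated: instead of appending the whole result dict under a (task_name, key) pair, each metric value is appended under its own (task_name, filter key, metric) bucket, so downstream code sees a flat list of values per metric. A small self-contained sketch of that accumulation, with placeholder task and metric names:

import collections

# vals maps (task_name, filter_key, metric) -> list of per-document values,
# mirroring the new accumulation pattern (names below are placeholders).
vals = collections.defaultdict(list)

per_doc_results = [
    {"acc": 1.0, "f1": 0.8},  # doc 0
    {"acc": 0.0, "f1": 0.5},  # doc 1
]

task_name, key = "demo_task", "none"  # "none" = the unfiltered response key
for metrics in per_doc_results:
    for metric, value in metrics.items():
        vals[(task_name, key, metric)].append(value)

print(vals[("demo_task", "none", "acc")])  # [1.0, 0.0]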
@@ -397,8 +396,7 @@ def evaluate(
         # then collect metrics across all ranks
         vals_torch = collections.defaultdict(list)
-        # for (task_name, key, metric), items in vals.items():
-        for (task_name, key), items in vals.items():
+        for (task_name, key, metric), items in vals.items():
             numitem = 0
             if type(items[0]) == tuple:
                 numitem = len(items[0])
@@ -434,8 +432,7 @@ def evaluate(
             gathered_item = [tuple(g) for g in gathered_item]
 
             if lm.rank == 0:
-                # vals_torch[(task_name, key, metric)] = gathered_item
-                vals_torch[(task_name, key)] = gathered_item
+                vals_torch[(task_name, key, metric)] = gathered_item
 
         vals = vals_torch
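With metric folded into the key, the multi-GPU branch gathers each bucket independently, and rank 0 ends up with one flat list of values per (task_name, key, metric). The sketch below conveys that idea with torch.distributed.all_gather_object; it is only an illustrative assumption, and the harness's real gather logic (including the numitem tuple handling shown above) is not reproduced:

import collections

import torch.distributed as dist


def gather_vals(vals, rank):
    """Collect per-metric value lists from every rank onto rank 0 (sketch only)."""
    vals_torch = collections.defaultdict(list)
    for (task_name, key, metric), items in vals.items():
        gathered = [None] * dist.get_world_size()
        dist.all_gather_object(gathered, items)  # every rank contributes its local list
        if rank == 0:
            # flatten the rank-major lists into one list per metric
            vals_torch[(task_name, key, metric)] = [v for part in gathered for v in part]
    return vals_torch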
@@ -443,25 +440,26 @@ def evaluate(
         ### Aggregate results over all datapoints ###
         # aggregate results ; run bootstrap CIs
-        # for (task_name, key, metric), items in vals.items():
-        for (task_name, key), items in vals.items():
+        for (task_name, key, metric), items in vals.items():
             task = task_dict[task_name]
-            # metric_key = metric + "," + key
+            metric_key = metric + "," + key
 
             if type(task) == tuple:
                 group_name, task = task
             else:
                 group_name = None
 
-            for metric_key, metric_fn in task.aggregation().items():
-                results[task_name][metric_key] = metric_fn(*list(zip(*items)))
-                results[task_name]["samples"] = len(items)
+            metric_fn = task.compute_metric()[metric]
+            results[task_name][metric_key] = metric_fn(items)
+            results[task_name]["samples"] = len(items)
 
             # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
             # so we run them less iterations. still looking for a cleaner way to do this
             if bootstrap_iters > 0:
                 stderr = lm_eval.api.metrics.stderr_for_metric(
-                    metric=task.aggregation()[metric],
+                    # metric=task.aggregation()[metric],
+                    metric=task.compute_metric()[metric],
                     bootstrap_iters=min(bootstrap_iters, 100)
                     if metric in ["bleu", "chrf", "ter"]
                     else bootstrap_iters,
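Aggregation then reduces to a per-bucket lookup: the new task.compute_metric() accessor supplies one aggregation function per metric, which is applied to the flat list of values, and stderr_for_metric bootstraps that same function, capped at 100 iterations for the expensive corpus-level metrics (bleu, chrf, ter). Below is a standalone sketch of that aggregate-and-bootstrap step; the plain mean and the bootstrap_stderr helper are stand-ins, not the harness's API:

import random
import statistics


def bootstrap_stderr(agg_fn, items, iters):
    """Standard deviation of the aggregate over `iters` resamples (illustrative only)."""
    estimates = [
        agg_fn([random.choice(items) for _ in range(len(items))]) for _ in range(iters)
    ]
    return statistics.pstdev(estimates)


# Placeholder stand-ins for task.compute_metric() and one vals bucket.
compute_metric = {"acc": statistics.mean}
items = [1.0, 0.0, 1.0, 1.0]  # per-document values for ("demo_task", "none", "acc")
metric, key = "acc", "none"

metric_fn = compute_metric[metric]
metric_key = metric + "," + key  # mirrors the reinstated metric_key line above
results = {metric_key: metric_fn(items), "samples": len(items)}

bootstrap_iters = 1000
iters = min(bootstrap_iters, 100) if metric in ["bleu", "chrf", "ter"] else bootstrap_iters
results[metric_key + "_stderr"] = bootstrap_stderr(metric_fn, items, iters)
print(results)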