Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
62572f05
"vscode:/vscode.git/clone" did not exist on "e91c1182d6d964bc2a3f8e4f5ef03d075618aceb"
Commit
62572f05
authored
May 07, 2024
by
lintangsutawika
Browse files
adjust group scoring with using ConfigurableGroup
parent
fe2a1472
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
134 additions
and
106 deletions
+134
-106
lm_eval/evaluator.py
lm_eval/evaluator.py
+134
-106
No files found.
lm_eval/evaluator.py
View file @
62572f05
...
...
@@ -22,7 +22,7 @@ from lm_eval.evaluator_utils import (
run_task_tests
,
)
from
lm_eval.logging_utils
import
add_env_info
,
get_git_commit_hash
from
lm_eval.tasks
import
TaskManager
,
get_task_dict
from
lm_eval.tasks
import
ConfigurableGroup
,
ConfigurableTask
,
TaskManager
,
get_task_dict
from
lm_eval.utils
import
eval_logger
,
positional_deprecated
,
simple_parse_args_string
...
...
@@ -204,17 +204,24 @@ def simple_evaluate(
+
".db"
,
)
if
check_integrity
:
run_task_tests
(
task_list
=
tasks
)
if
task_manager
is
None
:
task_manager
=
TaskManager
(
verbosity
)
task_dict
=
get_task_dict
(
tasks
,
task_manager
)
for
task_name
in
task_dict
.
keys
():
task_obj
=
task_dict
[
task_name
]
if
isinstance
(
task_obj
,
tuple
):
_
,
task_obj
=
task_obj
if
isinstance
(
task_obj
,
lm_eval
.
api
.
task
.
ConfigurableTask
)
is
False
:
continue
def
_adjust_config
(
task_dict
):
adjusted_task_dict
=
{}
for
task_name
,
task_obj
in
task_dict
.
items
():
if
isinstance
(
task_obj
,
dict
):
adjusted_task_dict
=
{
**
adjusted_task_dict
,
**
{
task_name
:
_adjust_config
(
task_obj
)}
}
else
:
if
task_obj
.
get_config
(
"output_type"
)
==
"generate_until"
:
if
gen_kwargs
is
not
None
:
task_obj
.
set_config
(
...
...
@@ -246,9 +253,11 @@ def simple_evaluate(
if
(
default_num_fewshot
:
=
task_obj
.
get_config
(
"num_fewshot"
))
is
None
:
task_obj
.
set_config
(
key
=
"num_fewshot"
,
value
=
0
)
if
check_integrity
:
run_task_tests
(
task_list
=
tasks
)
adjusted_task_dict
[
task_name
]
=
task_obj
return
adjusted_task_dict
task_dict
=
_adjust_config
(
task_dict
)
results
=
evaluate
(
lm
=
lm
,
task_dict
=
task_dict
,
...
...
@@ -330,7 +339,10 @@ def evaluate(
padding_requests
=
defaultdict
(
int
)
# get lists of group hierarchy and each type of request
task_hierarchy
,
eval_tasks
=
get_task_list
(
task_dict
)
eval_tasks
=
get_task_list
(
task_dict
)
# print("task_hierarchy")
# print(task_hierarchy)
# import sys; sys.exit()
if
not
log_samples
:
if
not
all
(
"bypass"
not
in
getattr
(
task_output
.
task
,
"_metric_fn_list"
,
{}).
keys
()
...
...
@@ -484,24 +496,36 @@ def evaluate(
### Calculate group metrics ###
if
bool
(
results
):
show_group_table
=
False
for
group
,
group_info
in
reversed
(
task_hierarchy
.
items
()):
task_list
=
group_info
[
"tasks"
]
if
len
(
task_list
)
==
0
:
# task_hierarchy entries are either
# `group_name: [subtask1, subtask2, ...]`
# or `task_name: []`.
# we only want to operate on groups here.
continue
def
process_group
(
results
,
task_dict
,
task_root
=
None
,
task_hierarchy
=
None
,
show_group_table
=
False
):
group_config
=
lm_eval
.
api
.
task
.
GroupConfig
(
**
group_info
[
"config"
]
if
"config"
in
group_info
else
{}
)
if
task_root
is
None
:
task_root
=
{}
if
task_hierarchy
is
None
:
task_hierarchy
=
{}
for
group_or_task
,
group_or_task_info
in
task_dict
.
items
():
if
isinstance
(
group_or_task_info
,
ConfigurableTask
):
if
task_root
:
task_hierarchy
.
setdefault
(
task_root
,
[]).
append
(
group_or_task
)
else
:
results
,
_task_hierarchy
,
show_group_table
=
process_group
(
results
,
group_or_task_info
,
group_or_task
,
task_hierarchy
,
show_group_table
)
if
task_root
:
task_hierarchy
.
setdefault
(
task_root
,
[]).
extend
(
task_hierarchy
.
get
(
group_or_task
,
[]))
if
isinstance
(
group_or_task
,
ConfigurableGroup
):
group_config
=
group_or_task
.
config
group
=
group_or_task
.
group
show_group_table
=
show_group_table
|
group_config
[
"aggregate_metric"
]
if
group_config
[
"aggregate_metric"
]
is
False
:
results
[
group
][
" "
]
=
" "
continue
elif
isinstance
(
group_or_task
,
str
):
results
[
group_or_task
][
" "
]
=
" "
continue
task_list
=
_task_hierarchy
[
group_or_task
]
metric_list
=
list
(
{
key
...
...
@@ -550,6 +574,9 @@ def evaluate(
# results[group][stderr] = lm_eval.api.metrics.combined_sample_stderr(stderrs, sizes, metrics=metrics)
results
[
group
][
"samples"
]
=
sum
(
sizes
)
return
results
,
task_hierarchy
,
show_group_table
results
,
task_hierarchy
,
show_group_table
=
process_group
(
results
,
task_dict
)
results_agg
=
defaultdict
(
dict
)
groups_agg
=
defaultdict
(
dict
)
...
...
@@ -575,6 +602,7 @@ def evaluate(
task_list
[
0
]
]
# TODO: validate this
import
sys
;
sys
.
exit
()
results_dict
=
{
"results"
:
dict
(
results_agg
.
items
()),
**
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment