Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
336cc455
Commit
336cc455
authored
Nov 29, 2023
by
baberabb
Browse files
Merge remote-tracking branch 'origin/big-refactor' into big-refactor-mps
parents
7bb147b5
42f486ee
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
4 additions
and
11 deletions
+4
-11
lm_eval/api/task.py
lm_eval/api/task.py
+3
-3
lm_eval/evaluator.py
lm_eval/evaluator.py
+1
-8
No files found.
lm_eval/api/task.py
View file @
336cc455
...
...
@@ -81,7 +81,7 @@ class TaskConfig(dict):
fewshot_delimiter
:
str
=
"
\n\n
"
fewshot_config
:
dict
=
None
# runtime configuration options
num_fewshot
:
int
=
-
1
num_fewshot
:
int
=
None
# scoring options
metric_list
:
list
=
None
output_type
:
str
=
"generate_until"
...
...
@@ -361,7 +361,7 @@ class Task(abc.ABC):
# sample fewshot context #TODO: need to offset doc_id by rank now!
fewshot_ctx
=
self
.
fewshot_context
(
doc
,
self
.
config
.
num_fewshot
,
0
if
self
.
config
.
num_fewshot
is
None
else
self
.
config
.
num_fewshot
,
)
# TODO: we should override self.config.repeats if doing greedy gen so users don't waste time+compute
...
...
@@ -777,7 +777,7 @@ class ConfigurableTask(Task):
if
self
.
config
.
fewshot_split
is
not
None
:
return
self
.
dataset
[
self
.
config
.
fewshot_split
]
else
:
if
self
.
config
.
num_fewshot
>
0
:
if
(
self
.
config
.
num_fewshot
is
not
None
)
and
(
self
.
config
.
num_fewshot
>
0
)
:
eval_logger
.
warning
(
f
"Task '
{
self
.
config
.
task
}
': "
"num_fewshot > 0 but fewshot_split is None. "
...
...
lm_eval/evaluator.py
View file @
336cc455
...
...
@@ -260,7 +260,7 @@ def evaluate(
if
"num_fewshot"
in
configs
[
task_name
]:
n_shot
=
configs
[
task_name
][
"num_fewshot"
]
else
:
n_shot
=
-
1
n_shot
=
0
num_fewshot
[
task_name
]
=
n_shot
if
"task_alias"
in
configs
[
task_name
]:
...
...
@@ -440,7 +440,6 @@ def evaluate(
vals
=
vals_torch
if
lm
.
rank
==
0
:
### Get task ordering for correct sample-wide aggregation
group_to_task
=
{}
for
group
in
task_hierarchy
.
keys
():
...
...
@@ -451,7 +450,6 @@ def evaluate(
group_to_task
[
group
]
=
task_hierarchy
[
group
].
copy
()
for
task
in
task_hierarchy
[
group
]:
if
task
in
task_order
:
task_order
[
task
]
+=
1
else
:
...
...
@@ -498,9 +496,7 @@ def evaluate(
results
[
task_name
][
metric
+
"_stderr"
+
","
+
key
]
=
stderr
(
items
)
if
bool
(
results
):
for
group
,
task_list
in
reversed
(
task_hierarchy
.
items
()):
if
task_list
==
[]:
total_size
=
results
[
group
][
"samples"
]
else
:
...
...
@@ -520,7 +516,6 @@ def evaluate(
for
metric
in
[
key
for
key
in
metrics
.
keys
()
if
"_stderr"
not
in
key
]:
stderr
=
"_stderr,"
.
join
(
metric
.
split
(
","
))
stderr_score
=
results
[
task
][
stderr
]
var_score
=
stderr_score
**
2
...
...
@@ -557,11 +552,9 @@ def evaluate(
results
[
group
][
"samples"
]
=
total_size
def
print_tasks
(
task_hierarchy
,
task_order
,
task_version
,
task_group_alias
):
results_agg
=
collections
.
defaultdict
(
dict
)
groups_agg
=
collections
.
defaultdict
(
dict
)
for
group_name
,
task_list
in
task_hierarchy
.
items
():
order
=
task_order
[
group_name
]
results_agg
[
group_name
]
=
results
[
group_name
].
copy
()
results_agg
[
group_name
][
"tab"
]
=
order
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment