Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
61d11b8a
Commit
61d11b8a
authored
Aug 01, 2025
by
Baber
Browse files
require multiple_inputs and multiple_targets to be explicitly set in taskconfig
parent
57b86c47
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
64 additions
and
48 deletions
+64
-48
lm_eval/api/task.py
lm_eval/api/task.py
+59
-48
lm_eval/tasks/winogrande/default.yaml
lm_eval/tasks/winogrande/default.yaml
+1
-0
lm_eval/utils.py
lm_eval/utils.py
+4
-0
No files found.
lm_eval/api/task.py
View file @
61d11b8a
...
@@ -39,6 +39,7 @@ from lm_eval.api.registry import (
...
@@ -39,6 +39,7 @@ from lm_eval.api.registry import (
from
lm_eval.caching.cache
import
load_from_cache
,
save_to_cache
from
lm_eval.caching.cache
import
load_from_cache
,
save_to_cache
from
lm_eval.filters
import
build_filter_ensemble
from
lm_eval.filters
import
build_filter_ensemble
from
lm_eval.prompts
import
get_prompt
from
lm_eval.prompts
import
get_prompt
from
lm_eval.utils
import
validate_index
ALL_OUTPUT_TYPES
=
[
ALL_OUTPUT_TYPES
=
[
...
@@ -96,6 +97,8 @@ class TaskConfig(dict):
...
@@ -96,6 +97,8 @@ class TaskConfig(dict):
should_decontaminate
:
bool
=
False
should_decontaminate
:
bool
=
False
doc_to_decontamination_query
:
Optional
[
str
]
=
None
doc_to_decontamination_query
:
Optional
[
str
]
=
None
gen_prefix
:
Optional
[
str
]
=
None
gen_prefix
:
Optional
[
str
]
=
None
multiple_inputs
:
bool
=
False
multiple_targets
:
bool
=
False
metadata
:
Optional
[
dict
]
=
(
metadata
:
Optional
[
dict
]
=
(
None
# by default, not used in the code. allows for users to pass arbitrary info to tasks
None
# by default, not used in the code. allows for users to pass arbitrary info to tasks
)
)
...
@@ -767,6 +770,12 @@ class ConfigurableTask(Task):
...
@@ -767,6 +770,12 @@ class ConfigurableTask(Task):
)
)
self
.
OUTPUT_TYPE
=
self
.
config
.
output_type
self
.
OUTPUT_TYPE
=
self
.
config
.
output_type
self
.
multiple_targets
=
self
.
config
.
multiple_targets
self
.
multiple_inputs
=
self
.
config
.
multiple_inputs
assert
not
(
self
.
multiple_targets
and
self
.
multiple_inputs
),
(
"Cannot have both multiple_targets and multiple_inputs"
)
if
self
.
config
.
doc_to_image
is
not
None
:
if
self
.
config
.
doc_to_image
is
not
None
:
# mark the task as requiring multimodality.
# mark the task as requiring multimodality.
self
.
MULTIMODAL
=
True
self
.
MULTIMODAL
=
True
...
@@ -923,50 +932,54 @@ class ConfigurableTask(Task):
...
@@ -923,50 +932,54 @@ class ConfigurableTask(Task):
# Test One Doc
# Test One Doc
self
.
features
=
list
(
self
.
task_docs
.
features
.
keys
())
self
.
features
=
list
(
self
.
task_docs
.
features
.
keys
())
self
.
multiple_input
=
0
self
.
multiple_target
=
0
test_doc
=
self
.
task_docs
[
0
]
test_doc
=
self
.
task_docs
[
0
]
test_text
=
self
.
doc_to_text
(
test_doc
)
test_text
=
self
.
doc_to_text
(
test_doc
)
test_target
=
self
.
doc_to_target
(
test_doc
)
test_target
=
self
.
doc_to_target
(
test_doc
)
if
self
.
config
.
doc_to_choice
is
not
None
:
if
self
.
config
.
doc_to_choice
is
not
None
:
test_choice
=
self
.
doc_to_choice
(
test_doc
)
test_choice
=
self
.
doc_to_choice
(
test_doc
)
if
not
isinstance
(
test_choice
,
list
):
if
self
.
multiple_inputs
:
eval_logger
.
error
(
"doc_to_choice must return list"
)
# we require:
# doc_to_text: int
# doc_to_choice: list
# doc_to_target: str
# e.g. text: 1, choice: [Maria was better than Sarah, Sarah was better than Sarah]
# target: so she was envious
assert
isinstance
(
test_text
,
int
),
(
"doc_to_text must return int for multiple inputs"
)
assert
isinstance
(
test_target
,
str
),
(
"doc_to_target must return str for multiple inputs"
)
assert
self
.
config
.
output_type
!=
"generate_until"
,
(
"Only multiple-choice tasks can be used with multiple inputs"
)
test_text
=
test_choice
[
0
]
elif
self
.
multiple_targets
:
# we require:
# doc_to_text: str
# doc_to_choice: list
# doc_to_target: list
assert
isinstance
(
test_target
,
(
list
,
tuple
)),
(
"doc_to_target must be an iterable for multiple targets"
)
test_target
=
test_target
[
0
]
else
:
else
:
num_choice
=
len
(
test_choice
)
assert
isinstance
(
test_target
,
int
),
(
"doc_to_target must return int for multiple choices"
if
isinstance
(
test_text
,
int
):
eval_logger
.
debug
(
"doc_to_text returned an int. Assuming multiple inputs."
)
)
self
.
multiple_input
=
num_choice
else
:
test_choice
=
None
if
isinstance
(
test_target
,
list
):
eval_logger
.
debug
(
"doc_to_target returned a list. Assuming multiple targets."
)
self
.
multiple_target
=
len
(
test_target
)
else
:
if
(
isinstance
(
test_target
,
int
))
and
(
test_choice
is
not
None
):
test_target
=
test_choice
[
test_target
]
test_target
=
test_choice
[
test_target
]
else
:
test_target
=
str
(
test_target
)
if
test_choice
is
not
None
:
assert
hasattr
(
test_choice
,
"__iter__"
)
and
not
isinstance
(
check_choices
=
test_choice
test_choice
,
(
str
,
bytes
)
else
:
),
"doc_to_choice must be an iterable!"
check_choices
=
[
test_target
]
if
self
.
config
.
doc_to_choice
is
not
None
:
for
choice
in
test_choice
:
for
choice
in
check_choices
:
choice_has_whitespace
=
choice
[
0
].
isspace
()
choice_has_whitespace
=
True
if
choice
[
0
].
isspace
()
else
False
delimiter_has_whitespace
=
(
delimiter_has_whitespace
=
(
True
self
.
config
.
target_delimiter
.
rstrip
()
if
self
.
config
.
target_delimiter
.
rstrip
()
!=
self
.
config
.
target_delimiter
!=
self
.
config
.
target_delimiter
else
False
)
)
if
delimiter_has_whitespace
and
choice_has_whitespace
:
if
delimiter_has_whitespace
and
choice_has_whitespace
:
...
@@ -1162,7 +1175,7 @@ class ConfigurableTask(Task):
...
@@ -1162,7 +1175,7 @@ class ConfigurableTask(Task):
example
=
self
.
doc_to_text
(
doc
)
example
=
self
.
doc_to_text
(
doc
)
if
apply_chat_template
:
if
apply_chat_template
:
if
self
.
multiple_input
:
if
self
.
multiple_input
s
:
# TODO: append prefill?
# TODO: append prefill?
if
not
labeled_examples
:
if
not
labeled_examples
:
return
""
return
""
...
@@ -1222,7 +1235,7 @@ class ConfigurableTask(Task):
...
@@ -1222,7 +1235,7 @@ class ConfigurableTask(Task):
if
gen_prefix
is
not
None
if
gen_prefix
is
not
None
else
""
else
""
)
)
if
self
.
multiple_input
:
if
self
.
multiple_input
s
:
return
labeled_examples
return
labeled_examples
if
isinstance
(
example
,
str
):
if
isinstance
(
example
,
str
):
return
labeled_examples
+
example
+
prefix
return
labeled_examples
+
example
+
prefix
...
@@ -1371,7 +1384,7 @@ class ConfigurableTask(Task):
...
@@ -1371,7 +1384,7 @@ class ConfigurableTask(Task):
if
doc_to_choice
in
self
.
features
:
if
doc_to_choice
in
self
.
features
:
return
doc
[
doc_to_choice
]
return
doc
[
doc_to_choice
]
else
:
else
:
return
ast
.
literal_eval
(
utils
.
apply_template
(
doc_to_choice
,
doc
)
)
return
utils
.
apply_template
(
doc_to_choice
,
doc
)
elif
isinstance
(
doc_to_choice
,
list
):
elif
isinstance
(
doc_to_choice
,
list
):
return
doc_to_choice
return
doc_to_choice
elif
isinstance
(
doc_to_choice
,
dict
):
elif
isinstance
(
doc_to_choice
,
dict
):
...
@@ -1454,7 +1467,7 @@ class ConfigurableTask(Task):
...
@@ -1454,7 +1467,7 @@ class ConfigurableTask(Task):
target_delimiter
=
self
.
config
.
target_delimiter
target_delimiter
=
self
.
config
.
target_delimiter
if
apply_chat_template
:
if
apply_chat_template
:
target_delimiter
=
""
target_delimiter
=
""
if
self
.
multiple_input
:
if
self
.
multiple_input
s
:
# If there are multiple inputs, choices are placed in the ctx
# If there are multiple inputs, choices are placed in the ctx
# apply chat_template to choices if apply_chat_template
# apply chat_template to choices if apply_chat_template
cont
=
self
.
doc_to_target
(
doc
)
cont
=
self
.
doc_to_target
(
doc
)
...
@@ -1595,24 +1608,22 @@ class ConfigurableTask(Task):
...
@@ -1595,24 +1608,22 @@ class ConfigurableTask(Task):
pred
=
np
.
argmax
(
lls
)
pred
=
np
.
argmax
(
lls
)
pred_norm
=
np
.
argmax
(
lls
/
completion_len
)
pred_norm
=
np
.
argmax
(
lls
/
completion_len
)
if
self
.
multiple_input
:
gold
=
(
gold
=
self
.
doc_to_text
(
doc
)
self
.
doc_to_text
(
doc
)
else
:
if
not
self
.
multiple_inputs
gold
=
self
.
doc_to_target
(
doc
)
else
self
.
doc_to_text
(
doc
)
)
gold_index_error
=
False
if
isinstance
(
gold
,
list
):
if
isinstance
(
gold
,
list
):
gold
=
[
i
if
i
<
len
(
choices
)
else
-
100
for
i
in
gold
]
gold
=
[
validate_index
(
g
,
len
(
choices
))
for
g
in
gold
]
if
-
100
in
gold
:
gold_index_error
=
-
100
in
gold
gold_index_error
=
True
else
:
else
:
if
isinstance
(
gold
,
int
):
if
isinstance
(
gold
,
int
):
gold
=
gold
if
gold
<
len
(
choices
)
else
-
100
gold
=
validate_index
(
gold
,
len
(
choices
)
)
elif
isinstance
(
gold
,
str
):
elif
isinstance
(
gold
,
str
):
gold
=
choices
.
index
(
gold
)
if
gold
in
choices
else
-
100
gold
=
choices
.
index
(
gold
)
if
gold
in
choices
else
-
100
if
gold
==
-
100
:
gold_index_error
=
gold
==
-
100
gold_index_error
=
True
if
gold_index_error
:
if
gold_index_error
:
eval_logger
.
warning
(
eval_logger
.
warning
(
...
@@ -1620,7 +1631,7 @@ class ConfigurableTask(Task):
...
@@ -1620,7 +1631,7 @@ class ConfigurableTask(Task):
f
"Sample:
\n\n
{
doc
}
\n\n
"
f
"Sample:
\n\n
{
doc
}
\n\n
"
)
)
if
self
.
multiple_target
:
if
self
.
multiple_target
s
:
acc
=
1.0
if
pred
in
gold
else
0.0
acc
=
1.0
if
pred
in
gold
else
0.0
acc_norm
=
1.0
if
pred_norm
in
gold
else
0.0
acc_norm
=
1.0
if
pred_norm
in
gold
else
0.0
exact_match
=
int
(
any
([
is_greedy
[
i
]
if
i
!=
-
100
else
0
for
i
in
gold
]))
exact_match
=
int
(
any
([
is_greedy
[
i
]
if
i
!=
-
100
else
0
for
i
in
gold
]))
...
...
lm_eval/tasks/winogrande/default.yaml
View file @
61d11b8a
...
@@ -9,6 +9,7 @@ doc_to_target: !function preprocess_winogrande.doc_to_target
...
@@ -9,6 +9,7 @@ doc_to_target: !function preprocess_winogrande.doc_to_target
doc_to_choice
:
!function
preprocess_winogrande.doc_to_choice
doc_to_choice
:
!function
preprocess_winogrande.doc_to_choice
should_decontaminate
:
true
should_decontaminate
:
true
doc_to_decontamination_query
:
sentence
doc_to_decontamination_query
:
sentence
multiple_inputs
:
true
metric_list
:
metric_list
:
-
metric
:
acc
-
metric
:
acc
aggregation
:
mean
aggregation
:
mean
...
...
lm_eval/utils.py
View file @
61d11b8a
...
@@ -623,3 +623,7 @@ def hash_dict_images(data_dict):
...
@@ -623,3 +623,7 @@ def hash_dict_images(data_dict):
if
importlib
.
util
.
find_spec
(
"PIL"
)
if
importlib
.
util
.
find_spec
(
"PIL"
)
else
data_dict
else
data_dict
)
)
def validate_index(index: int, length: int) -> int:
    """Return ``index`` when it is below ``length``, otherwise ``-100``.

    Only the upper bound is checked: any value strictly smaller than
    ``length`` (negative numbers included) is returned unchanged.

    Args:
        index: Candidate index to check.
        length: Exclusive upper bound for ``index``.

    Returns:
        ``index`` if ``index < length``; the sentinel ``-100`` otherwise.
    """
    return index if index < length else -100
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment