gaoqiong / lm-evaluation-harness · Commits

Commit fb436108, authored Sep 11, 2023 by haileyschoelkopf

make Task._config a public property

parent 0b99c7d2

Showing 1 changed file with 72 additions and 67 deletions:

lm_eval/api/task.py (+72 -67)
@@ -246,6 +246,11 @@ class Task(abc.ABC):
             download_mode=download_mode,
         )
 
+    @property
+    def config(self):
+        """Returns the TaskConfig associated with this class."""
+        return self._config
+
     @abc.abstractmethod
     def has_training_docs(self):
         """Whether the task has a training set"""
@@ -348,7 +353,7 @@ class Task(abc.ABC):
         ), f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!"
 
         eval_logger.info(
-            f"Building contexts for task '{self._config.task}' on rank {rank}..."
+            f"Building contexts for task '{self.config.task}' on rank {rank}..."
         )
 
         instances = []
@@ -358,14 +363,14 @@ class Task(abc.ABC):
             # sample fewshot context #TODO: need to offset doc_id by rank now!
             fewshot_ctx = self.fewshot_context(
                 doc,
-                self._config.num_fewshot,
+                self.config.num_fewshot,
             )
 
-            # TODO: we should override self._config.repeats if doing greedy gen so users don't waste time+compute
+            # TODO: we should override self.config.repeats if doing greedy gen so users don't waste time+compute
             inst = self.construct_requests(
                 doc=doc,
                 ctx=fewshot_ctx,
-                metadata=(self._config["task"], doc_id, self._config.repeats),
+                metadata=(self.config["task"], doc_id, self.config.repeats),
             )
 
             if not isinstance(inst, list):
@@ -453,9 +458,9 @@ class Task(abc.ABC):
         if num_fewshot == 0:
             # always prepend the (possibly empty) task description
-            labeled_examples = self._config.description
+            labeled_examples = self.config.description
         else:
-            labeled_examples = self._config.description + self.sampler.get_context(
+            labeled_examples = self.config.description + self.sampler.get_context(
                 doc, num_fewshot
             )
@@ -465,7 +470,7 @@ class Task(abc.ABC):
         elif type(example) == list:
             return [labeled_examples + ex for ex in example]
         elif type(example) == int:
-            if self._config.doc_to_choice is not None:
+            if self.config.doc_to_choice is not None:
                 choices = self.doc_to_choice(doc)
                 return labeled_examples + choices[example]
             else:
@@ -488,7 +493,7 @@ class Task(abc.ABC):
         """
         # TODO: this should only return the overrides applied to a non-YAML task's configuration.
         # (num_fewshot)
-        return self._config.to_dict()
+        return self.config.to_dict()
 
 
 class ConfigurableTask(Task):
@@ -503,35 +508,35 @@ class ConfigurableTask(Task):
         self._config = self.CONFIG
 
         # Use new configurations if there was no preconfiguration
-        if self._config is None:
+        if self.config is None:
             self._config = TaskConfig(**config)
         # Overwrite configs
         else:
             if config is not None:
                 self._config.__dict__.update(config)
 
-        if self._config is None:
+        if self.config is None:
             raise ValueError(
                 "Must pass a config to ConfigurableTask, either in cls.CONFIG or `config` kwarg"
             )
 
-        if self._config.output_type is not None:
-            assert self._config.output_type in ALL_OUTPUT_TYPES
-            self.OUTPUT_TYPE = self._config.output_type
+        if self.config.output_type is not None:
+            assert self.config.output_type in ALL_OUTPUT_TYPES
+            self.OUTPUT_TYPE = self.config.output_type
 
-        if self._config.dataset_path is not None:
-            self.DATASET_PATH = self._config.dataset_path
+        if self.config.dataset_path is not None:
+            self.DATASET_PATH = self.config.dataset_path
 
-        if self._config.dataset_name is not None:
-            self.DATASET_NAME = self._config.dataset_name
+        if self.config.dataset_name is not None:
+            self.DATASET_NAME = self.config.dataset_name
 
         self._metric_fn_list = {}
         self._metric_fn_kwargs = {}
         self._aggregation_list = {}
         self._higher_is_better = {}
 
-        _metric_list = DEFAULT_METRIC_REGISTRY[self._config.output_type]
-        if self._config.metric_list is None:
+        _metric_list = DEFAULT_METRIC_REGISTRY[self.config.output_type]
+        if self.config.metric_list is None:
             # TODO: handle this in TaskConfig.__post_init__ ?
             for metric_name in _metric_list:
                 self._metric_fn_list[metric_name] = get_metric(metric_name)
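One subtlety in the constructor above: `self._config = self.CONFIG` runs before the first `if self.config is None:` check, so the property always has something to return, and every assignment still targets `self._config` directly because the property defines no setter. A condensed sketch of the merge order, reusing the illustrative `Task`/`TaskConfig` stand-ins from earlier (it guards the `config is None` case slightly differently from the diff so it runs standalone, and the real constructor goes on to wire up metrics, filters, and prompts):

class ConfigurableTask(Task):
    CONFIG = None  # subclasses may preconfigure a TaskConfig here

    def __init__(self, config: Optional[dict] = None):
        # Class-level preconfiguration comes first...
        self._config = self.CONFIG
        # ...then a config dict passed at construction time either
        # builds a fresh TaskConfig or overrides individual fields.
        if self.config is None and config is not None:
            self._config = TaskConfig(**config)
        elif config is not None:
            self._config.__dict__.update(config)
        if self.config is None:
            raise ValueError(
                "Must pass a config to ConfigurableTask, "
                "either in cls.CONFIG or `config` kwarg"
            )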
@@ -540,7 +545,7 @@ class ConfigurableTask(Task):
                 )
                 self._higher_is_better[metric_name] = is_higher_better(metric_name)
         else:
-            for metric_config in self._config.metric_list:
+            for metric_config in self.config.metric_list:
                 assert "metric" in metric_config
                 metric_name = metric_config["metric"]
                 kwargs = {
@@ -549,7 +554,7 @@ class ConfigurableTask(Task):
                     if key not in ["metric", "aggregation", "higher_is_better"]
                 }
 
-                if self._config.process_results is not None:
+                if self.config.process_results is not None:
                     self._metric_fn_list[metric_name] = None
                     self._metric_fn_kwargs[metric_name] = {}
                 elif callable(metric_name):
@@ -592,13 +597,13 @@ class ConfigurableTask(Task):
                 )
                 self._higher_is_better[metric_name] = is_higher_better(metric_name)
 
-        self.download(self._config.dataset_kwargs)
+        self.download(self.config.dataset_kwargs)
         self._training_docs = None
         self._fewshot_docs = None
 
-        if self._config.filter_list is not None:
+        if self.config.filter_list is not None:
             self._filters = []
-            for filter_config in self._config.filter_list:
+            for filter_config in self.config.filter_list:
                 for filter_pipeline in filter_config:
                     filter_name = filter_config["name"]
                     filter_functions = filter_config["filter"]
@@ -613,10 +618,10 @@ class ConfigurableTask(Task):
         else:
             self._filters = [build_filter_ensemble("none", [["take_first", None]])]
 
-        if self._config.use_prompt is not None:
-            eval_logger.info(f"loading prompt {self._config.use_prompt}")
+        if self.config.use_prompt is not None:
+            eval_logger.info(f"loading prompt {self.config.use_prompt}")
             self.prompt = get_prompt(
-                self._config.use_prompt, self.DATASET_PATH, self.DATASET_NAME
+                self.config.use_prompt, self.DATASET_PATH, self.DATASET_NAME
             )
         else:
             self.prompt = None
@@ -643,7 +648,7 @@ class ConfigurableTask(Task):
             test_text = self.doc_to_text(test_doc)
             test_target = self.doc_to_target(test_doc)
 
-            if self._config.doc_to_choice is not None:
+            if self.config.doc_to_choice is not None:
                 test_choice = self.doc_to_choice(test_doc)
                 if type(test_choice) is not list:
                     eval_logger.error("doc_to_choice must return list")
@@ -671,7 +676,7 @@ class ConfigurableTask(Task):
             for choice in check_choices:
                 choice_has_whitespace = True if " " in choice else False
                 delimiter_has_whitespace = (
-                    True if " " in self._config.target_delimiter else False
+                    True if " " in self.config.target_delimiter else False
                 )
 
                 if delimiter_has_whitespace and choice_has_whitespace:
@@ -692,67 +697,67 @@ class ConfigurableTask(Task):
                 )
 
     def has_training_docs(self) -> bool:
-        if self._config.training_split is not None:
+        if self.config.training_split is not None:
             return True
         else:
             return False
 
     def has_validation_docs(self) -> bool:
-        if self._config.validation_split is not None:
+        if self.config.validation_split is not None:
             return True
         else:
             return False
 
     def has_test_docs(self) -> bool:
-        if self._config.test_split is not None:
+        if self.config.test_split is not None:
             return True
         else:
             return False
 
     def training_docs(self) -> datasets.Dataset:
         if self.has_training_docs():
-            if self._config.process_docs is not None:
-                return self._config.process_docs(
-                    self.dataset[self._config.training_split]
+            if self.config.process_docs is not None:
+                return self.config.process_docs(
+                    self.dataset[self.config.training_split]
                 )
-            return self.dataset[self._config.training_split]
+            return self.dataset[self.config.training_split]
 
     def validation_docs(self) -> datasets.Dataset:
         if self.has_validation_docs():
-            if self._config.process_docs is not None:
-                return self._config.process_docs(
-                    self.dataset[self._config.validation_split]
+            if self.config.process_docs is not None:
+                return self.config.process_docs(
+                    self.dataset[self.config.validation_split]
                 )
-            return self.dataset[self._config.validation_split]
 
+            return self.dataset[self.config.validation_split]
 
     def test_docs(self) -> datasets.Dataset:
         if self.has_test_docs():
-            if self._config.process_docs is not None:
-                return self._config.process_docs(self.dataset[self._config.test_split])
-            return self.dataset[self._config.test_split]
+            if self.config.process_docs is not None:
+                return self.config.process_docs(self.dataset[self.config.test_split])
+            return self.dataset[self.config.test_split]
 
     def fewshot_docs(self):
-        if self._config.fewshot_split is not None:
-            return self.dataset[self._config.fewshot_split]
+        if self.config.fewshot_split is not None:
+            return self.dataset[self.config.fewshot_split]
         else:
-            if self._config.num_fewshot > 0:
+            if self.config.num_fewshot > 0:
                 eval_logger.warning(
-                    f"Task '{self._config.task}': "
+                    f"Task '{self.config.task}': "
                     "num_fewshot > 0 but fewshot_split is None. "
                     "using preconfigured rule."
                 )
             return super().fewshot_docs()
 
     def should_decontaminate(self):
-        return self._config.should_decontaminate
+        return self.config.should_decontaminate
 
     def doc_to_decontamination_query(self, doc):
-        if self._config.should_decontaminate:
-            if self._config.doc_to_decontamination_query in self.features:
-                return doc[self._config.doc_to_decontamination_query]
+        if self.config.should_decontaminate:
+            if self.config.doc_to_decontamination_query in self.features:
+                return doc[self.config.doc_to_decontamination_query]
             else:
                 return ast.literal_eval(
-                    utils.apply_template(self._config.doc_to_decontamination_query, doc)
+                    utils.apply_template(self.config.doc_to_decontamination_query, doc)
                 )
 
     def _process_doc(self, doc):
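The docs accessors in this hunk all follow the same shape: check whether the corresponding split name is set on the config, optionally pass the split through a user-supplied `process_docs` hook, and return it. A condensed sketch of that shape, assuming the stand-in config also carries `training_split` and `process_docs` fields and with a plain dict standing in for the `datasets.DatasetDict`:

def training_docs_sketch(task):
    # Mirrors the accessor shape above; `task.dataset` is assumed to be
    # a mapping from split name to documents (a dict here for brevity).
    if task.config.training_split is None:
        return None
    docs = task.dataset[task.config.training_split]
    if task.config.process_docs is not None:
        return task.config.process_docs(docs)  # user hook: e.g. shuffle/filter
    return docs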
@@ -771,13 +776,13 @@ class ConfigurableTask(Task):
         if self.prompt is not None:
             doc_to_text = self.prompt
         else:
-            doc_to_text = self._config.doc_to_text
+            doc_to_text = self.config.doc_to_text
 
         if type(doc_to_text) == int:
             return doc_to_text
         elif type(doc_to_text) == str:
             if doc_to_text in self.features:
-                # if self._config.doc_to_choice is not None:
+                # if self.config.doc_to_choice is not None:
                 #     return self.doc_to_choice(doc)[doc[doc_to_text]]
                 # else:
                 return doc[doc_to_text]
@@ -796,7 +801,7 @@ class ConfigurableTask(Task):
                 return applied_prompt[0]
             else:
                 eval_logger.warning("Applied prompt returns empty string")
-                return self._config.fewshot_delimiter
+                return self.config.fewshot_delimiter
         else:
             print(type(doc_to_text))
             raise TypeError
@@ -806,13 +811,13 @@ class ConfigurableTask(Task):
         if self.prompt is not None:
             doc_to_target = self.prompt
         else:
-            doc_to_target = self._config.doc_to_target
+            doc_to_target = self.config.doc_to_target
 
         if type(doc_to_target) == int:
             return doc_to_target
         elif type(doc_to_target) == str:
             if doc_to_target in self.features:
-                # if self._config.doc_to_choice is not None:
+                # if self.config.doc_to_choice is not None:
                 #     return self.doc_to_choice(doc)[doc[doc_to_target]]
                 # else:
                 return doc[doc_to_target]
@@ -839,7 +844,7 @@ class ConfigurableTask(Task):
                 return applied_prompt[1]
             else:
                 eval_logger.warning("Applied prompt returns empty string")
-                return self._config.fewshot_delimiter
+                return self.config.fewshot_delimiter
         else:
             raise TypeError
@@ -847,10 +852,10 @@ class ConfigurableTask(Task):
         if self.prompt is not None:
             doc_to_choice = self.prompt
-        elif self._config.doc_to_choice is None:
+        elif self.config.doc_to_choice is None:
             eval_logger.error("doc_to_choice was called but not set in config")
         else:
-            doc_to_choice = self._config.doc_to_choice
+            doc_to_choice = self.config.doc_to_choice
 
         if type(doc_to_choice) == str:
             return ast.literal_eval(utils.apply_template(doc_to_choice, doc))
@@ -871,8 +876,8 @@ class ConfigurableTask(Task):
         # in multiple_choice tasks, this should be castable to an int corresponding to the index
         # within the answer choices, while doc_to_target is the string version of {{answer_choices[gold]}}.
-        if self._config.gold_alias is not None:
-            doc_to_target = self._config.gold_alias
+        if self.config.gold_alias is not None:
+            doc_to_target = self.config.gold_alias
         else:
             return self.doc_to_target(doc)
@@ -896,7 +901,7 @@ class ConfigurableTask(Task):
         elif self.OUTPUT_TYPE == "multiple_choice":
             choices = self.doc_to_choice(doc)
-            target_delimiter = self._config.target_delimiter
+            target_delimiter = self.config.target_delimiter
             if self.multiple_input:
                 # If there are multiple inputs, choices are placed in the ctx
                 cont = self.doc_to_target(doc)
@@ -938,7 +943,7 @@ class ConfigurableTask(Task):
             return request_list
 
         elif self.OUTPUT_TYPE == "greedy_until":
-            arguments = (ctx, self._config.generation_kwargs)
+            arguments = (ctx, self.config.generation_kwargs)
 
         return Instance(
             request_type=self.OUTPUT_TYPE, doc=doc, arguments=arguments, idx=0, **kwargs
@@ -946,8 +951,8 @@ class ConfigurableTask(Task):
     def process_results(self, doc, results):
-        if callable(self._config.process_results):
-            return self._config.process_results(doc, results)
+        if callable(self.config.process_results):
+            return self.config.process_results(doc, results)
 
         result_dict = {}
         use_metric = list(self._metric_fn_list.keys())
@@ -1036,7 +1041,7 @@ class ConfigurableTask(Task):
         elif self.OUTPUT_TYPE == "greedy_until":
             gold = self.doc_to_target(doc)
-            if self._config.doc_to_choice is not None:
+            if self.config.doc_to_choice is not None:
                 # If you set doc_to_choice,
                 # it assumes that doc_to_target returns a number.
                 choices = self.doc_to_choice(doc)