gaoqiong / lm-evaluation-harness · Commits

Commit 9b192374, authored Jun 30, 2025 by Baber
update type hints
parent cb8dfe63
Showing 1 changed file with 40 additions and 25 deletions.

lm_eval/api/task.py  (+40, -25)
@@ -63,11 +63,10 @@ class MetricConfig:
     aggregation_fn: Optional[Callable] = None
     higher_is_better: bool = True
     hf_evaluate: bool = False
-    sample_metric: bool = True
     is_elementwise: bool = True

     @cached_property
-    def metric_names(self) -> str:
+    def metric_name(self) -> str:
         return self.name

     @cached_property
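Note: the property is renamed from the plural metric_names to the singular metric_name, since it returns the single configured name; the call sites later in this diff are updated to match. A minimal sketch of the cached_property pattern, using an illustrative stripped-down class rather than the library's full MetricConfig:

from dataclasses import dataclass
from functools import cached_property


@dataclass
class _MetricConfigSketch:
    # illustrative subset of MetricConfig's fields
    name: str = "acc"

    @cached_property
    def metric_name(self) -> str:
        # computed once, then cached on the instance
        return self.name


m = _MetricConfigSketch(name="acc_mutual_info")
print(m.metric_name)  # acc_mutual_info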
@@ -82,6 +81,12 @@ class MetricConfig:
             return is_higher_better(self.name)
         return self.higher_is_better

+    def calculate_metric(self, *args, **kwargs) -> Any:
+        """Calculates the metric using the provided function and arguments."""
+        if self.fn is None:
+            raise ValueError(f"Metric function for {self.name} is not defined.")
+        return self.fn(*args, **{**self.kwargs, **kwargs})
+

 @dataclass
 class RepeatConfig:
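Note: in the added calculate_metric, the dict merge {**self.kwargs, **kwargs} lets call-time keyword arguments override the defaults stored on the config, because the right-hand operand of the merge wins. A minimal, self-contained illustration of that merge rule (the variable names and values are invented):

config_kwargs = {"average": "micro", "beta": 1.0}   # stored on the config
call_kwargs = {"average": "macro"}                  # passed at call time
merged = {**config_kwargs, **call_kwargs}
print(merged)  # {'average': 'macro', 'beta': 1.0}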
@@ -108,6 +113,16 @@ class FewshotConfig:
     process_docs: Optional[Callable] = None


+@dataclass
+class DatasetConfig:
+    """Encapsulates information about a dataset."""
+
+    dataset_path: Optional[str] = None
+    dataset_name: Optional[str] = None
+    dataset_kwargs: Optional[dict] = None
+    custom_dataset: Optional[Callable] = None
+
+
 @dataclass
 class TaskConfig(dict):
     # task naming/registry
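Note: the new DatasetConfig groups the dataset-related fields into one dataclass. A hedged usage sketch; the class body mirrors the addition above, while the argument values are invented for illustration:

from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class DatasetConfig:
    """Encapsulates information about a dataset."""

    dataset_path: Optional[str] = None
    dataset_name: Optional[str] = None
    dataset_kwargs: Optional[dict] = None
    custom_dataset: Optional[Callable] = None


# illustrative values only
cfg = DatasetConfig(dataset_path="hellaswag", dataset_kwargs={"trust_remote_code": True})
print(cfg.dataset_name)  # None (unset fields default to None)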
@@ -132,8 +147,8 @@ class TaskConfig(dict):
     process_docs: Optional[Callable] = None
     doc_to_text: Optional[Union[Callable, str]] = None
     doc_to_target: Optional[Union[Callable, str]] = None
-    doc_to_image: Union[Callable, str] = None
-    doc_to_audio: Union[Callable, str] = None
+    doc_to_image: Union[Callable, str, None] = None
+    doc_to_audio: Union[Callable, str, None] = None
     unsafe_code: bool = False
     doc_to_choice: Optional[Union[Callable, str, dict, list]] = None
     process_results: Optional[Union[Callable, str]] = None
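Note: Union[Callable, str, None] is the same type as Optional[Union[Callable, str]], so the annotation now matches the = None default on doc_to_image and doc_to_audio. A quick check:

from typing import Callable, Optional, Union

# typing flattens and normalizes unions, so these compare equal
assert Union[Callable, str, None] == Optional[Union[Callable, str]]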
@@ -466,17 +481,17 @@ class Task(abc.ABC):
         return self._config

     @abc.abstractmethod
-    def has_training_docs(self):
+    def has_training_docs(self) -> bool:
         """Whether the task has a training set"""
         pass

     @abc.abstractmethod
-    def has_validation_docs(self):
+    def has_validation_docs(self) -> bool:
         """Whether the task has a validation set"""
         pass

     @abc.abstractmethod
-    def has_test_docs(self):
+    def has_test_docs(self) -> bool:
         """Whether the task has a test set"""
         pass
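Note: the abstract split predicates are now annotated to return bool. A standalone sketch of what a subclass promises under the new signatures (the classes below are illustrative, not lm_eval's Task):

import abc


class _TaskSketch(abc.ABC):
    @abc.abstractmethod
    def has_training_docs(self) -> bool: ...

    @abc.abstractmethod
    def has_validation_docs(self) -> bool: ...

    @abc.abstractmethod
    def has_test_docs(self) -> bool: ...


class _ExampleTaskSketch(_TaskSketch):
    def has_training_docs(self) -> bool:
        return True

    def has_validation_docs(self) -> bool:
        return True

    def has_test_docs(self) -> bool:
        return False


print(_ExampleTaskSketch().has_test_docs())  # False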
@@ -536,7 +551,7 @@ class Task(abc.ABC):
         """
         return self._instances

-    def fewshot_examples(self, k, rnd):
+    def fewshot_examples(self, k, rnd) -> Iterable[dict]:
         if self._training_docs is None:
             self._training_docs = list(self.training_docs())
@@ -548,11 +563,11 @@ class Task(abc.ABC):
         )

     @abc.abstractmethod
-    def doc_to_text(self, doc):
+    def doc_to_text(self, doc) -> str:
         pass

     @abc.abstractmethod
-    def doc_to_target(self, doc):
+    def doc_to_target(self, doc) -> Union[str, int]:
         pass

     # not an abstractmethod because not every language-only task has to implement this
@@ -562,7 +577,7 @@ class Task(abc.ABC):
     def doc_to_audio(self, doc):
         raise NotImplementedError

-    def doc_to_prefix(self, doc):
+    def doc_to_prefix(self, doc) -> str:
         return ""

     def build_all_requests(
@@ -734,12 +749,12 @@ class Task(abc.ABC):
         return getattr(self._config, key, None)

     @classmethod
-    def count_bytes(cls, doc):
+    def count_bytes(cls, doc) -> int:
         """Used for byte-level perplexity metrics in rolling loglikelihood"""
         return len(doc.encode("utf-8"))

     @classmethod
-    def count_words(cls, doc):
+    def count_words(cls, doc) -> int:
         """Downstream loglikelihood_rolling perplexity tasks with custom word boundaries should override this!"""
         return len(re.split(r"\s+", doc))
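Note: both class methods now advertise an int return. A small worked example of the two counters on a short string:

import re

doc = "naïve café"
print(len(doc.encode("utf-8")))    # 12 bytes: 'ï' and 'é' each take two bytes in UTF-8
print(len(re.split(r"\s+", doc)))  # 2 words, split on whitespace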
@@ -853,7 +868,7 @@ class Task(abc.ABC):
             self.sampler.rnd = self.fewshot_rnd

     @property
-    def eval_docs(self) -> Union[datasets.Dataset, List[dict]]:
+    def eval_docs(self) -> Union[datasets.Dataset, Iterable[dict]]:
         if self.has_test_docs():
             return self.test_docs()
         elif self.has_validation_docs():
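Note: widening the annotation from List[dict] to Iterable[dict] admits lazy sources such as generators, not only concrete lists. A hedged illustration (the documents are invented):

from typing import Dict, Iterable


def eval_docs_sketch() -> Iterable[Dict]:
    # a generator satisfies Iterable[dict] but not List[dict]
    yield {"question": "2 + 2 = ?", "answer": "4"}
    yield {"question": "Capital of France?", "answer": "Paris"}


for doc in eval_docs_sketch():
    print(doc["answer"])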
@@ -952,7 +967,7 @@ class ConfigurableTask(Task):
         if self.config.dataset_name is not None:
             self.DATASET_NAME = self.config.dataset_name

-        self.metric_list: list[MetricConfig] = self._config.get_metrics()
+        self.metric_list: list[MetricConfig] = self.config.get_metrics()
         self.download(self.config.dataset_kwargs)
         self._training_docs = None
@@ -1088,7 +1103,7 @@ class ConfigurableTask(Task):
         else:
             return False

-    def training_docs(self) -> datasets.Dataset:
+    def training_docs(self) -> Optional[datasets.Dataset]:
         if self.has_training_docs():
             if self.config.process_docs is not None:
                 return self.config.process_docs(
@@ -1096,7 +1111,7 @@ class ConfigurableTask(Task):
                 )
             return self.dataset[self.config.training_split]

-    def validation_docs(self) -> datasets.Dataset:
+    def validation_docs(self) -> Optional[datasets.Dataset]:
         if self.has_validation_docs():
             if self.config.process_docs is not None:
                 return self.config.process_docs(
@@ -1104,7 +1119,7 @@ class ConfigurableTask(Task):
                 )
             return self.dataset[self.config.validation_split]

-    def test_docs(self) -> datasets.Dataset:
+    def test_docs(self) -> Optional[datasets.Dataset]:
         if self.has_test_docs():
             if self.config.process_docs is not None:
                 return self.config.process_docs(self.dataset[self.config.test_split])
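Note: the Optional return types on training_docs, validation_docs, and test_docs reflect that each accessor falls off the end and implicitly returns None when the corresponding split is absent. A minimal sketch of that control flow (no datasets dependency; a plain list stands in for a Dataset):

from typing import Optional


def split_docs_sketch(has_split: bool, data: list) -> Optional[list]:
    if has_split:
        return data
    # implicit None when the split is absent


print(split_docs_sketch(True, ["doc1", "doc2"]))   # ['doc1', 'doc2']
print(split_docs_sketch(False, ["doc1", "doc2"]))  # None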
@@ -1170,7 +1185,7 @@ class ConfigurableTask(Task):
         fewshot_as_multiturn: bool = False,
         chat_template: Optional[Callable] = None,
         gen_prefix: Optional[str] = None,
-    ) -> Union[str, List[str]]:
+    ) -> Union[str, List[str], None]:
         """Returns a fewshot context string that is made up of a prepended description
         (if provided), the `num_fewshot` number of examples, and an appended prompt example.
@@ -1457,7 +1472,7 @@ class ConfigurableTask(Task):
         else:
             raise TypeError

-    def doc_to_image(self, doc: Any, doc_to_image=None) -> Union[int, str, list]:
+    def doc_to_image(self, doc: Any, doc_to_image=None) -> Union[int, str, list, None]:
         if doc_to_image is not None:
             doc_to_image = doc_to_image
         elif self.config.doc_to_image is not None:
@@ -1480,7 +1495,7 @@ class ConfigurableTask(Task):
         else:
             return None

-    def doc_to_audio(self, doc: Any, doc_to_audio=None) -> Union[int, str, list]:
+    def doc_to_audio(self, doc: Any, doc_to_audio=None) -> Union[int, str, list, None]:
         if doc_to_audio is not None:
             doc_to_audio = doc_to_audio
         elif self.config.doc_to_audio is not None:
@@ -1503,7 +1518,7 @@ class ConfigurableTask(Task):
         else:
             return None

-    def doc_to_prefix(self, doc):
+    def doc_to_prefix(self, doc) -> Optional[str]:
         if (gen_prefix := self.config.gen_prefix) is not None:
             if gen_prefix in self.features:
                 return doc[gen_prefix]
@@ -1550,7 +1565,7 @@ class ConfigurableTask(Task):
             arguments = [(ctx, f"{target_delimiter}{cont}") for cont in choices]

             # TODO: we should raise a warning telling users this will at most ~2x runtime.
-            if "acc_mutual_info" in [m.metric_names for m in self.metric_list]:
+            if "acc_mutual_info" in [m.metric_name for m in self.metric_list]:
                 # if we are calculating multiple choice accuracy
                 # using mutual information instead of raw loglikelihood as metric, need unconditional lls.
@@ -1617,7 +1632,7 @@ class ConfigurableTask(Task):
             return self.config.process_results(doc, results)

         result_dict = {}
-        use_metric = list(m.metric_names for m in self.metric_list)
+        use_metric = list(m.metric_name for m in self.metric_list)
         if self.OUTPUT_TYPE == "loglikelihood":
             results = results[0]
             ll, is_greedy = results
@@ -1815,7 +1830,7 @@ class ConfigurableTask(Task):
         return getattr(self._config, key, None)

     @property
-    def task_name(self) -> Any:
+    def task_name(self) -> Optional[str]:
         return getattr(self.config, "task", None)

     def __repr__(self):