gaoqiong / lm-evaluation-harness / Commits

Commit db5dff9c
authored Jul 03, 2025 by Baber

type hints

parent 023bfe0d

Showing 3 changed files with 105 additions and 96 deletions (+105 -96):

  lm_eval/api/model.py     +30  -18
  lm_eval/api/registry.py   +6   -2
  lm_eval/api/task.py      +69  -76
lm_eval/api/model.py  (+30 -18)

@@ -3,7 +3,7 @@ import hashlib
 import json
 import logging
 import os
-from typing import Dict, List, Optional, Tuple, Type, TypeVar, Union
+from typing import TYPE_CHECKING, Optional, Type, TypeVar, Union

 import transformers
 from sqlitedict import SqliteDict

@@ -12,6 +12,10 @@ from tqdm import tqdm
 from lm_eval import utils

+if TYPE_CHECKING:
+    from lm_eval.api.instance import Instance
+
+
 eval_logger = logging.getLogger(__name__)

 T = TypeVar("T", bound="LM")

@@ -30,7 +34,7 @@ class LM(abc.ABC):
         self.cache_hook = CacheHook(None)

     @abc.abstractmethod
-    def loglikelihood(self, requests) -> List[Tuple[float, bool]]:
+    def loglikelihood(self, requests) -> list[tuple[float, bool]]:
         """Compute log-likelihood of generating a continuation from a context.
         Downstream tasks should attempt to use loglikelihood instead of other
         LM calls whenever possible.

@@ -55,7 +59,7 @@ class LM(abc.ABC):
         pass

     @abc.abstractmethod
-    def loglikelihood_rolling(self, requests) -> List[float]:
+    def loglikelihood_rolling(self, requests) -> list[float]:
         """Compute full log-likelihood of a string, with no truncation, for perplexity computation
         - We will use the full max context length of the model.
         - For inputs that exceed the max context length, we divide the tokenized string into chunks of up to

@@ -97,7 +101,7 @@ class LM(abc.ABC):
     # TODO: Add an optional max length
     @abc.abstractmethod
-    def generate_until(self, requests) -> List[str]:
+    def generate_until(self, requests) -> list[str]:
         """Generate greedily until a stopping sequence

         :param requests: list[Instance]

@@ -114,7 +118,7 @@ class LM(abc.ABC):
         pass

     def apply_chat_template(
-        self, chat_history: List[Dict[str, str]], add_generation_prompt=True
+        self, chat_history: list[dict[str, str]], add_generation_prompt=True
     ) -> str:
         """
         Defines how to transform few-shot examples provided as chat history into a format that can be used as input to the LM.

@@ -173,14 +177,14 @@ class LM(abc.ABC):
         return cls(**arg_dict, **additional_config)

     @property
-    def rank(self):
+    def rank(self) -> int:
         # used in the case of parallelism. Hardcoded to
         # ensure no errors arise using API models which do
         # not support multi-device parallelism nor expect it.
         return self._rank

     @property
-    def world_size(self):
+    def world_size(self) -> int:
         # used in the case of parallelism. Hardcoded to
         # ensure no errors arise using API models which do
         # not support multi-device parallelism nor expect it.

@@ -230,7 +234,7 @@ class CacheHook:
 class CachingLM:
-    def __init__(self, lm, cache_db) -> None:
+    def __init__(self, lm: "LM", cache_db: str) -> None:
         """LM wrapper that returns cached results if they exist, and uses the underlying LM if not.

         :param lm: LM

@@ -253,7 +257,7 @@ class CachingLM:
             eval_logger.debug(f"Passing through attribute '{attr}' to underlying LM")
             return lm_attr

-        def fn(requests):
+        def fn(requests: list[Instance]) -> list[Instance]:
             res = []
             remaining_reqs = []
             warned = False

@@ -322,28 +326,35 @@ class TemplateLM(LM):
     @property
     @abc.abstractmethod
-    def eot_token_id(self):
+    def eot_token_id(self) -> int:
         pass

     @property
-    def prefix_token_id(self):
+    def prefix_token_id(self) -> int:
         # it is used as prefix for loglikelihood
         return self.eot_token_id

     @abc.abstractmethod
-    def tok_encode(self, string: str, **kwargs) -> List[int]:
+    def tok_encode(self, string: str, **kwargs) -> list[int]:
         """
         Tokenize a string using the model's tokenizer and return a list of token IDs.
         """
         pass

     @abc.abstractmethod
-    def _loglikelihood_tokens(self, requests, **kwargs) -> List[Tuple[float, bool]]:
+    def _loglikelihood_tokens(
+        self, requests: list[Instance], **kwargs
+    ) -> list[tuple[float, bool]]:
         pass

     def _encode_pair(
         self, context: str, continuation: str
-    ) -> Tuple[List[int], List[int]]:
+    ) -> tuple[list[int], list[int]]:
+        """Encodes a pair of context and continuation strings into token IDs.
+        Ensures that encode(context + continuation) == encode(context) + encode(continuation)
+        """
         n_spaces = len(context) - len(context.rstrip())
         if n_spaces > 0:
             continuation = context[-n_spaces:] + continuation
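
Note: the new `_encode_pair` docstring promises encode(context + continuation) == encode(context) + encode(continuation). A minimal standalone sketch of that invariant, assuming only a `tok_encode` callable (the helper name and the free-function form are illustrative, not part of the commit):

    def encode_pair_sketch(tok_encode, context: str, continuation: str):
        # Many BPE tokenizers merge a trailing space into the *next* token, so
        # encoding "context " and "continuation" separately would not sum to the
        # encoding of the joined string. Moving trailing whitespace from the
        # context onto the continuation keeps the invariant.
        n_spaces = len(context) - len(context.rstrip())
        if n_spaces > 0:
            continuation = context[-n_spaces:] + continuation
            context = context[:-n_spaces]

        whole_enc = tok_encode(context + continuation)
        context_enc = tok_encode(context)
        # whatever follows the context prefix is attributed to the continuation
        continuation_enc = whole_enc[len(context_enc):]
        return context_enc, continuation_enc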
@@ -364,8 +375,8 @@ class TemplateLM(LM):
         return context_enc, continuation_enc

     def loglikelihood(
-        self, requests, disable_tqdm: bool = False
-    ) -> List[Tuple[float, bool]]:
+        self, requests: list[Instance], disable_tqdm: bool = False
+    ) -> list[tuple[float, bool]]:
         new_reqs = []
         for context, continuation in [req.args for req in requests]:
             if context == "":

@@ -384,15 +395,16 @@ class TemplateLM(LM):
     @abc.abstractmethod
     def loglikelihood_rolling(
         self, requests, disable_tqdm: bool = False
-    ) -> List[float]:
+    ) -> list[float]:
         pass

     @abc.abstractmethod
-    def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]:
+    def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]:
         pass

     def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
         """
+        Assumes tokenizer has a chat_template attribute (self.tokenizer.chat_template: dict | str)
         Set and get the appropriate chat template for the model.
         This method sets the tokenizer's chat_template and returns the template string for reproducibility.
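Note: the edits in this file follow two recurring patterns: PEP 585 builtin generics (list[...], tuple[...], dict[...]) replacing typing.List/Tuple/Dict, and annotation-only imports moved behind TYPE_CHECKING. A minimal sketch of the combined pattern (the function body is a placeholder, not code from the commit):

    # typing pattern used throughout this commit (illustrative sketch)
    from __future__ import annotations  # keep annotations unevaluated at runtime

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # needed only by the type checker; avoids importing Instance at runtime
        from lm_eval.api.instance import Instance


    def loglikelihood(requests: list[Instance]) -> list[tuple[float, bool]]:
        # builtin generics (PEP 585, Python 3.9+) replace typing.List / typing.Tuple
        return [(0.0, True) for _ in requests]
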
lm_eval/api/registry.py  (+6 -2)

@@ -8,6 +8,10 @@ if TYPE_CHECKING:
 eval_logger = logging.getLogger(__name__)

 MODEL_REGISTRY = {}

+DEFAULTS = {
+    "model": {"max_length": 2048},
+    "tasks": {"generate_until": {"max_length": 2048}},
+}

 def register_model(*names):

@@ -167,7 +171,7 @@ def get_metric_aggregation(name: str) -> Optional[Callable[[], Dict[str, Callabl
     eval_logger.warning(f"{name} metric is not assigned a default aggregation!")


-def is_higher_better(metric_name) -> Optional[bool]:
+def is_higher_better(metric_name: str) -> Optional[bool]:
     try:
         return HIGHER_IS_BETTER_REGISTRY[metric_name]
     except KeyError:

@@ -176,7 +180,7 @@ def is_higher_better(metric_name) -> Optional[bool]:
     )


-def register_filter(name):
+def register_filter(name: str):
     def decorate(cls):
         if name in FILTER_REGISTRY:
             eval_logger.info(
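Note: register_filter is a class decorator that stores the decorated class in FILTER_REGISTRY under the given name. A hedged usage sketch (the filter name, class, and apply() signature are illustrative; the real filter interface is defined in lm_eval.api.filter, not shown in this diff):

    from lm_eval.api.registry import register_filter


    @register_filter("lowercase")  # hypothetical filter name
    class LowercaseFilter:
        # assumed interface: receives per-document model responses, returns filtered ones
        def apply(self, resps, docs):
            return [[r.lower() for r in resp] for resp in resps]
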
lm_eval/api/task.py  (+69 -76)

@@ -9,6 +9,7 @@ from dataclasses import asdict, dataclass, field
 from functools import cached_property
 from inspect import getsource
 from typing import (
+    TYPE_CHECKING,
     Any,
     Dict,
     Iterable,

@@ -50,6 +51,10 @@ ALL_OUTPUT_TYPES = [
     "generate_until",
 ]

+if TYPE_CHECKING:
+    from lm_eval.api.filter import FilterEnsemble
+
+
 eval_logger = logging.getLogger(__name__)

@@ -81,7 +86,7 @@ class MetricConfig:
             return is_higher_better(self.name)
         return self.higher_is_better

-    def calculate_metric(self, *args, **kwargs) -> Any:
+    def compute_metric(self, *args, **kwargs) -> Any:
         """Calculates the metric using the provided function and arguments."""
         if self.fn is None:
             raise ValueError(f"Metric function for {self.name} is not defined.")

@@ -99,7 +104,7 @@ class RepeatConfig:
     """Encapsulates information about a single repeat."""

     repeats: int = 1
-    metric_fn: Optional[Callable] = None
+    metric_fn: Optional[Callable] = "pass@N"
     kwargs: Optional[dict] = None

@@ -246,15 +251,15 @@ class TaskConfig(dict):
     output_type: OutputType = "generate_until"
     generation_kwargs: Optional[dict] = None
     repeats: int = 1
-    filter_list: Optional[Union[str, list]] = None
+    filter_list: Optional[list[dict]] = None
     should_decontaminate: bool = False
     doc_to_decontamination_query: Optional[str] = None
     gen_prefix: Optional[str] = None
     metadata: Optional[dict] = (
         None  # by default, not used in the code. allows for users to pass arbitrary info to tasks
     )

-    _metric_list = None
-    _filter_list = None
+    _metric_list: list[MetricConfig] = None
+    _filter_list: list[FilterConfig] = None

     def __post_init__(self) -> None:
         if self.generation_kwargs is not None:

@@ -289,16 +294,13 @@ class TaskConfig(dict):
                 f"{self.task}: No `generation_kwargs` specified in task config, defaulting to {self.generation_kwargs}"
             )

-        if self.metric_list is not None:
-            for metric_config in self.metric_list:
-                if "metric" not in metric_config:
-                    raise ValueError(
-                        "'metric' key not provided for an entry in 'metric_list', must be specified!"
-                    )
+        if self.metric_list and not all("metric" in cfg for cfg in self.metric_list):
+            raise ValueError("each entry in metric_list must include a 'metric' key")

     def get_metrics(self) -> list["MetricConfig"]:
         metrics = []
         if self.metric_list is None:
+            # ---------- 1. If no metrics defined, use defaults for output type ----------
             _metric_list = DEFAULT_METRIC_REGISTRY[self.output_type]
             eval_logger.info(
                 f"No metrics defined in config, using default metrics for {self.output_type}={_metric_list}"

@@ -313,11 +315,8 @@ class TaskConfig(dict):
                 for metric_name in _metric_list
             )
         else:
+            # ---------- 2. How will the samples be evaluated ----------
             for metric_config in self.metric_list:
-                if "metric" not in metric_config:
-                    raise ValueError(
-                        "'metric' key not provided for an entry in 'metric_list', must be specified!"
-                    )
                 metric_name = metric_config["metric"]
                 _metric_fn_kwargs = {
                     key: metric_config[key]
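Note: the reworked __post_init__ check requires every metric_list entry to carry a "metric" key, and get_metrics builds the metric's keyword arguments from the entry's other keys. A hedged example of a conforming entry (the metric name and extra kwargs are illustrative; the exact set of reserved keys is not shown in this hunk):

    metric_list = [
        {
            "metric": "exact_match",   # required key
            "aggregation": "mean",
            "higher_is_better": True,
            "ignore_case": True,       # forwarded to the metric function as a kwarg
        }
    ]
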
@@ -379,34 +378,30 @@ class TaskConfig(dict):
             )
         return metrics

-    def get_filters(self):
-        if self.filter_list is not None:
-            _filter_list = []
-            for filter_config in self.filter_list:
-                _filter_list.append(
-                    build_filter_ensemble(
-                        filter_name=filter_config["name"],
-                        components=[
-                            [
-                                {
-                                    key: function[key]
-                                    for key in function
-                                    if key != "function"
-                                }
-                            ]
-                            for function in filter_config["filter"]
-                        ],
-                    )
-                )
-        else:
-            # TODO: handle repeats in a more general way rather than just discarding
-            eval_logger.debug(
-                "No custom filters defined. Using default 'take_first' filter for handling repeats."
-            )
-            _filter_list = [build_filter_ensemble("none", [["take_first", None]])]
-        return _filter_list
+    def get_filters(self) -> list["FilterEnsemble"]:
+        if not self.filter_list:
+            # TODO: handle repeats in a more general way rather than just discarding
+            eval_logger.debug(
+                "No custom filters defined; falling back to 'take_first' for handling repeats."
+            )
+            return [build_filter_ensemble("none", [["take_first", None]])]
+
+        def _strip_fn(d: dict) -> dict:
+            return {k: v for k, v in d.items() if k != "function"}
+
+        configs = (
+            self.filter_list.values()
+            if isinstance(self.filter_list, dict)
+            else self.filter_list
+        )
+        return [
+            build_filter_ensemble(
+                filter_name=cfg["name"],
+                components=[[_strip_fn(f) for f in cfg["filter"]]],
+            )
+            for cfg in configs
+        ]

     def __getitem__(self, item):
         return getattr(self, item)
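Note: in both versions of get_filters, each filter_list entry supplies a "name" for the ensemble and a "filter" list whose steps name a registered filter "function"; keys other than "function" are passed through as that step's arguments. A hedged example of the shape this implies (the filter names and regex pattern are illustrative):

    filter_list = [
        {
            "name": "strict-match",
            "filter": [
                # each step names a registered filter function plus its kwargs
                {"function": "regex", "regex_pattern": r"answer is (\-?[0-9\.\,]+)"},
                {"function": "take_first"},
            ],
        }
    ]
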
@@ -415,31 +410,27 @@ class TaskConfig(dict):
         return setattr(self, item, value)

     def to_dict(self, keep_callable: bool = False) -> dict:
-        """dumps the current config as a dictionary object, as a printable format.
-        null fields will not be printed.
-        Used for dumping results alongside full task configuration
+        """Return a printable dict with Nones stripped and callables serialised.

         :return: dict
             A printable dictionary version of the TaskConfig object.
-
-        # TODO: should any default value in the TaskConfig not be printed?
         """
-        cfg_dict = asdict(self)
-        # remove values that are `None`
-        for k, v in list(cfg_dict.items()):
-            if v is None:
-                cfg_dict.pop(k)
-            elif k == "metric_list":
-                for metric_dict in v:
-                    for metric_key, metric_value in metric_dict.items():
-                        if callable(metric_value):
-                            metric_dict[metric_key] = self.serialize_function(
-                                metric_value, keep_callable=keep_callable
-                            )
-                cfg_dict[k] = v
-            elif callable(v):
-                cfg_dict[k] = self.serialize_function(v, keep_callable=keep_callable)
-        return cfg_dict
+
+        def _maybe_serialize(val):
+            return (
+                self.serialize_function(val, keep_callable=keep_callable)
+                if callable(val)
+                else val
+            )
+
+        cfg = asdict(self)
+        return {
+            k: [{mk: _maybe_serialize(mv) for mk, mv in md.items()} for md in v]
+            if k == "metric_list"
+            else _maybe_serialize(v)
+            for k, v in cfg.items()
+            if v is not None
+        }

     def serialize_function(
         self, value: Union[Callable, str], keep_callable=False

@@ -627,7 +618,7 @@ class Task(abc.ABC):
         return doc

     @property
-    def instances(self) -> List[Instance]:
+    def instances(self) -> list[Instance]:
         """After calling `task.build_all_requests()`, tasks
         maintain a list of the dataset instances which will be evaluated.
         """

@@ -639,27 +630,27 @@ class Task(abc.ABC):
             return rnd.sample(self._training_docs, k)

-    def doc_to_decontamination_query(self, doc):
+    def doc_to_decontamination_query(self, doc: dict):
         raise NotImplementedError(
             "Override doc_to_decontamination_query with document specific decontamination query."
         )

     @abc.abstractmethod
-    def doc_to_text(self, doc) -> str:
+    def doc_to_text(self, doc: dict) -> str:
         pass

     @abc.abstractmethod
-    def doc_to_target(self, doc) -> Union[str, int]:
+    def doc_to_target(self, doc: dict) -> Union[str, int]:
         pass

     # not an abstractmethod because not every language-only task has to implement this
-    def doc_to_image(self, doc):
+    def doc_to_image(self, doc: dict):
         raise NotImplementedError

-    def doc_to_audio(self, doc):
+    def doc_to_audio(self, doc: dict):
         raise NotImplementedError

-    def doc_to_prefix(self, doc) -> str:
+    def doc_to_prefix(self, doc: dict) -> str:
         return ""

     def build_all_requests(

@@ -776,7 +767,7 @@ class Task(abc.ABC):
             save_to_cache(file_name=cache_key, obj=instances)

     @abc.abstractmethod
-    def construct_requests(self, doc, ctx, **kwargs):
+    def construct_requests(self, doc: dict, ctx: Union[list[dict], str], **kwargs):
         """Uses RequestFactory to construct Requests and returns an iterable of
         Requests which will be sent to the LM.

@@ -797,7 +788,7 @@ class Task(abc.ABC):
         pass

     @abc.abstractmethod
-    def process_results(self, doc, results):
+    def process_results(self, doc: dict, results: list):
         """Take a single document and the LM results and evaluates, returning a
         dict where keys are the names of submetrics and values are the values of
         the metric for that one document

@@ -1446,7 +1437,7 @@ class ConfigurableTask(Task):
         """
         return doc

-    def doc_to_text(self, doc, doc_to_text=None):
+    def doc_to_text(self, doc: dict, doc_to_text: Optional[int, str, Callable] = None):
         if self.prompt is not None:
             doc_to_text = self.prompt
         elif doc_to_text is not None:

@@ -1482,7 +1473,7 @@ class ConfigurableTask(Task):
             print(type(doc_to_text))
             raise TypeError

-    def doc_to_target(self, doc: Mapping, doc_to_target=None) -> Union[int, str, list]:
+    def doc_to_target(self, doc: dict, doc_to_target=None) -> Union[int, str, list]:
         if self.prompt is not None:
             doc_to_target = self.prompt
         elif doc_to_target is not None:

@@ -1528,7 +1519,9 @@ class ConfigurableTask(Task):
         else:
             raise TypeError

-    def doc_to_choice(self, doc: Any, doc_to_choice=None) -> List[str]:
+    def doc_to_choice(
+        self, doc: dict, doc_to_choice: Union[str, list, dict] = None
+    ) -> List[str]:
         if self.prompt is not None:
             doc_to_choice = self.prompt
         elif doc_to_choice is not None:

@@ -1554,7 +1547,7 @@ class ConfigurableTask(Task):
         else:
             raise TypeError

-    def doc_to_image(self, doc: Any, doc_to_image=None) -> Union[int, str, list, None]:
+    def doc_to_image(self, doc: dict, doc_to_image=None) -> Union[int, str, list, None]:
         if doc_to_image is not None:
             doc_to_image = doc_to_image
         elif self.config.doc_to_image is not None:

@@ -1600,7 +1593,7 @@ class ConfigurableTask(Task):
         else:
             return None

-    def doc_to_prefix(self, doc) -> Optional[str]:
+    def doc_to_prefix(self, doc: dict) -> Optional[str]:
         if (gen_prefix := self.config.gen_prefix) is not None:
             if gen_prefix in self.features:
                 return doc[gen_prefix]

@@ -1709,7 +1702,7 @@ class ConfigurableTask(Task):
             **kwargs,
         )

-    def process_results(self, doc, results):
+    def process_results(self, doc: dict, results: list) -> dict:
         if callable(self.config.process_results):
             return self.config.process_results(doc, results)