Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
97f5c020
Commit
97f5c020
authored
Aug 08, 2023
by
baberabb
Browse files
added typehints
parent
b0f67f2c
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
29 additions
and
30 deletions
+29
-30
lm_eval/api/task.py
lm_eval/api/task.py
+28
-29
lm_eval/utils.py
lm_eval/utils.py
+1
-1
No files found.
lm_eval/api/task.py
View file @
97f5c020
...
@@ -13,7 +13,7 @@ from tqdm import tqdm
...
@@ -13,7 +13,7 @@ from tqdm import tqdm
import
datasets
import
datasets
import
numpy
as
np
import
numpy
as
np
from
typing
import
Union
from
typing
import
Union
,
List
,
Any
,
Tuple
,
Literal
from
collections.abc
import
Callable
from
collections.abc
import
Callable
from
lm_eval
import
utils
from
lm_eval
import
utils
...
@@ -477,7 +477,7 @@ class Task(abc.ABC):
...
@@ -477,7 +477,7 @@ class Task(abc.ABC):
eval_logger
.
warning
(
"No filter defined, passing through instances"
)
eval_logger
.
warning
(
"No filter defined, passing through instances"
)
return
self
.
_instances
return
self
.
_instances
def
dump_config
(
self
):
def
dump_config
(
self
)
->
dict
:
"""Returns a dictionary representing the task's config.
"""Returns a dictionary representing the task's config.
:returns: str
:returns: str
...
@@ -489,14 +489,13 @@ class Task(abc.ABC):
...
@@ -489,14 +489,13 @@ class Task(abc.ABC):
class
ConfigurableTask
(
Task
):
class
ConfigurableTask
(
Task
):
VERSION
=
"Yaml"
VERSION
=
"Yaml"
OUTPUT_TYPE
=
None
OUTPUT_TYPE
=
None
CONFIG
=
None
CONFIG
=
None
def
__init__
(
def
__init__
(
self
,
data_dir
=
None
,
cache_dir
=
None
,
download_mode
=
None
,
config
:
dict
=
None
self
,
data_dir
=
None
,
cache_dir
=
None
,
download_mode
=
None
,
config
:
dict
=
None
):
):
# TODO no super() call here
# Get pre-configured attributes
# Get pre-configured attributes
self
.
_config
=
self
.
CONFIG
self
.
_config
=
self
.
CONFIG
...
@@ -662,25 +661,25 @@ class ConfigurableTask(Task):
...
@@ -662,25 +661,25 @@ class ConfigurableTask(Task):
**
dataset_kwargs
if
dataset_kwargs
is
not
None
else
{},
**
dataset_kwargs
if
dataset_kwargs
is
not
None
else
{},
)
)
def
has_training_docs
(
self
):
def
has_training_docs
(
self
)
->
bool
:
if
self
.
_config
.
training_split
is
not
None
:
if
self
.
_config
.
training_split
is
not
None
:
return
True
return
True
else
:
else
:
return
False
return
False
def
has_validation_docs
(
self
):
def
has_validation_docs
(
self
)
->
bool
:
if
self
.
_config
.
validation_split
is
not
None
:
if
self
.
_config
.
validation_split
is
not
None
:
return
True
return
True
else
:
else
:
return
False
return
False
def
has_test_docs
(
self
):
def
has_test_docs
(
self
)
->
bool
:
if
self
.
_config
.
test_split
is
not
None
:
if
self
.
_config
.
test_split
is
not
None
:
return
True
return
True
else
:
else
:
return
False
return
False
def
training_docs
(
self
):
def
training_docs
(
self
)
->
datasets
.
Dataset
:
if
self
.
has_training_docs
():
if
self
.
has_training_docs
():
if
self
.
_config
.
process_docs
is
not
None
:
if
self
.
_config
.
process_docs
is
not
None
:
return
self
.
_config
.
process_docs
(
return
self
.
_config
.
process_docs
(
...
@@ -688,7 +687,7 @@ class ConfigurableTask(Task):
...
@@ -688,7 +687,7 @@ class ConfigurableTask(Task):
)
)
return
self
.
dataset
[
self
.
_config
.
training_split
]
return
self
.
dataset
[
self
.
_config
.
training_split
]
def
validation_docs
(
self
):
def
validation_docs
(
self
)
->
datasets
.
Dataset
:
if
self
.
has_validation_docs
():
if
self
.
has_validation_docs
():
if
self
.
_config
.
process_docs
is
not
None
:
if
self
.
_config
.
process_docs
is
not
None
:
return
self
.
_config
.
process_docs
(
return
self
.
_config
.
process_docs
(
...
@@ -696,7 +695,7 @@ class ConfigurableTask(Task):
...
@@ -696,7 +695,7 @@ class ConfigurableTask(Task):
)
)
return
self
.
dataset
[
self
.
_config
.
validation_split
]
return
self
.
dataset
[
self
.
_config
.
validation_split
]
def
test_docs
(
self
):
def
test_docs
(
self
)
->
datasets
.
Dataset
:
if
self
.
has_test_docs
():
if
self
.
has_test_docs
():
if
self
.
_config
.
process_docs
is
not
None
:
if
self
.
_config
.
process_docs
is
not
None
:
return
self
.
_config
.
process_docs
(
self
.
dataset
[
self
.
_config
.
test_split
])
return
self
.
_config
.
process_docs
(
self
.
dataset
[
self
.
_config
.
test_split
])
...
@@ -767,7 +766,7 @@ class ConfigurableTask(Task):
...
@@ -767,7 +766,7 @@ class ConfigurableTask(Task):
print
(
type
(
doc_to_text
))
print
(
type
(
doc_to_text
))
raise
TypeError
raise
TypeError
def
doc_to_target
(
self
,
doc
)
:
def
doc_to_target
(
self
,
doc
:
dict
)
->
Union
[
int
,
str
]
:
if
self
.
prompt
is
not
None
:
if
self
.
prompt
is
not
None
:
doc_to_target
=
self
.
prompt
doc_to_target
=
self
.
prompt
...
@@ -796,7 +795,7 @@ class ConfigurableTask(Task):
...
@@ -796,7 +795,7 @@ class ConfigurableTask(Task):
else
:
else
:
raise
TypeError
raise
TypeError
def
doc_to_choice
(
self
,
doc
)
:
def
doc_to_choice
(
self
,
doc
:
Any
)
->
List
[
str
]
:
if
self
.
prompt
is
not
None
:
if
self
.
prompt
is
not
None
:
doc_to_choice
=
self
.
prompt
doc_to_choice
=
self
.
prompt
...
@@ -838,7 +837,9 @@ class ConfigurableTask(Task):
...
@@ -838,7 +837,9 @@ class ConfigurableTask(Task):
else
:
else
:
raise
TypeError
raise
TypeError
def
construct_requests
(
self
,
doc
,
ctx
,
**
kwargs
):
def
construct_requests
(
self
,
doc
:
dict
,
ctx
:
str
,
**
kwargs
)
->
Union
[
List
[
Instance
],
Instance
]:
if
self
.
OUTPUT_TYPE
==
"loglikelihood"
:
if
self
.
OUTPUT_TYPE
==
"loglikelihood"
:
arguments
=
(
ctx
,
self
.
doc_to_target
(
doc
))
arguments
=
(
ctx
,
self
.
doc_to_target
(
doc
))
...
@@ -1037,13 +1038,12 @@ class ConfigurableTask(Task):
...
@@ -1037,13 +1038,12 @@ class ConfigurableTask(Task):
class
MultipleChoiceTask
(
Task
):
class
MultipleChoiceTask
(
Task
):
OUTPUT_TYPE
:
str
=
"loglikelihood"
OUTPUT_TYPE
:
str
=
"loglikelihood"
def
doc_to_target
(
self
,
doc
)
:
def
doc_to_target
(
self
,
doc
:
dict
)
->
str
:
return
" "
+
doc
[
"choices"
][
doc
[
"gold"
]]
return
" "
+
doc
[
"choices"
][
doc
[
"gold"
]]
def
construct_requests
(
self
,
doc
,
ctx
,
**
kwargs
)
:
def
construct_requests
(
self
,
doc
:
dict
,
ctx
:
str
,
**
kwargs
)
->
List
[
Instance
]
:
# TODO: add mutual info here?
# TODO: add mutual info here?
return
[
return
[
Instance
(
Instance
(
...
@@ -1056,7 +1056,7 @@ class MultipleChoiceTask(Task):
...
@@ -1056,7 +1056,7 @@ class MultipleChoiceTask(Task):
for
i
,
choice
in
enumerate
(
doc
[
"choices"
])
for
i
,
choice
in
enumerate
(
doc
[
"choices"
])
]
]
def
process_results
(
self
,
doc
,
results
)
:
def
process_results
(
self
,
doc
:
dict
,
results
:
List
[
Tuple
[
float
,
bool
]])
->
dict
:
results
=
[
results
=
[
res
[
0
]
for
res
in
results
res
[
0
]
for
res
in
results
]
# only retain loglikelihoods, discard is_greedy TODO: do we need is_greedy anywhere?
]
# only retain loglikelihoods, discard is_greedy TODO: do we need is_greedy anywhere?
...
@@ -1071,13 +1071,13 @@ class MultipleChoiceTask(Task):
...
@@ -1071,13 +1071,13 @@ class MultipleChoiceTask(Task):
"acc_norm"
:
acc_norm
,
"acc_norm"
:
acc_norm
,
}
}
def
higher_is_better
(
self
):
def
higher_is_better
(
self
)
->
dict
:
return
{
return
{
"acc"
:
True
,
"acc"
:
True
,
"acc_norm"
:
True
,
"acc_norm"
:
True
,
}
}
def
aggregation
(
self
):
def
aggregation
(
self
)
->
dict
:
return
{
return
{
"acc"
:
mean
,
"acc"
:
mean
,
"acc_norm"
:
mean
,
"acc_norm"
:
mean
,
...
@@ -1085,24 +1085,23 @@ class MultipleChoiceTask(Task):
...
@@ -1085,24 +1085,23 @@ class MultipleChoiceTask(Task):
class
PerplexityTask
(
Task
):
class
PerplexityTask
(
Task
):
OUTPUT_TYPE
=
"loglikelihood_rolling"
OUTPUT_TYPE
=
"loglikelihood_rolling"
def
has_training_docs
(
self
):
def
has_training_docs
(
self
)
->
bool
:
return
False
return
False
def
fewshot_examples
(
self
,
k
,
rnd
):
def
fewshot_examples
(
self
,
k
:
int
,
rnd
)
->
List
:
assert
k
==
0
assert
k
==
0
return
[]
return
[]
def
fewshot_context
(
self
,
doc
,
num_fewshot
)
:
def
fewshot_context
(
self
,
doc
:
dict
,
num_fewshot
:
int
)
->
Literal
[
""
]
:
assert
(
assert
(
num_fewshot
==
0
num_fewshot
==
0
),
"The number of fewshot examples must be 0 for perplexity tasks."
),
"The number of fewshot examples must be 0 for perplexity tasks."
return
""
return
""
def
higher_is_better
(
self
):
def
higher_is_better
(
self
)
->
dict
:
return
{
return
{
"word_perplexity"
:
False
,
"word_perplexity"
:
False
,
"byte_perplexity"
:
False
,
"byte_perplexity"
:
False
,
...
@@ -1118,7 +1117,7 @@ class PerplexityTask(Task):
...
@@ -1118,7 +1117,7 @@ class PerplexityTask(Task):
def
doc_to_target
(
self
,
doc
):
def
doc_to_target
(
self
,
doc
):
return
doc
return
doc
def
construct_requests
(
self
,
doc
,
ctx
,
**
kwargs
):
def
construct_requests
(
self
,
doc
:
dict
,
ctx
:
Union
[
str
,
None
]
,
**
kwargs
):
assert
not
ctx
assert
not
ctx
return
Instance
(
return
Instance
(
...
@@ -1129,7 +1128,7 @@ class PerplexityTask(Task):
...
@@ -1129,7 +1128,7 @@ class PerplexityTask(Task):
**
kwargs
,
**
kwargs
,
)
)
def
process_results
(
self
,
doc
,
results
)
:
def
process_results
(
self
,
doc
:
dict
,
results
:
float
)
->
dict
:
(
loglikelihood
,)
=
results
(
loglikelihood
,)
=
results
words
=
self
.
count_words
(
self
.
doc_to_target
(
doc
))
words
=
self
.
count_words
(
self
.
doc_to_target
(
doc
))
bytes_
=
self
.
count_bytes
(
self
.
doc_to_target
(
doc
))
bytes_
=
self
.
count_bytes
(
self
.
doc_to_target
(
doc
))
...
@@ -1139,7 +1138,7 @@ class PerplexityTask(Task):
...
@@ -1139,7 +1138,7 @@ class PerplexityTask(Task):
"bits_per_byte"
:
(
loglikelihood
,
bytes_
),
"bits_per_byte"
:
(
loglikelihood
,
bytes_
),
}
}
def
aggregation
(
self
):
def
aggregation
(
self
)
->
dict
:
return
{
return
{
"word_perplexity"
:
weighted_perplexity
,
"word_perplexity"
:
weighted_perplexity
,
"byte_perplexity"
:
weighted_perplexity
,
"byte_perplexity"
:
weighted_perplexity
,
...
@@ -1147,10 +1146,10 @@ class PerplexityTask(Task):
...
@@ -1147,10 +1146,10 @@ class PerplexityTask(Task):
}
}
@
classmethod
@
classmethod
def
count_bytes
(
cls
,
doc
):
def
count_bytes
(
cls
,
doc
)
->
int
:
return
len
(
doc
.
encode
(
"utf-8"
))
return
len
(
doc
.
encode
(
"utf-8"
))
@
classmethod
@
classmethod
def
count_words
(
cls
,
doc
):
def
count_words
(
cls
,
doc
)
->
int
:
"""Downstream tasks with custom word boundaries should override this!"""
"""Downstream tasks with custom word boundaries should override this!"""
return
len
(
re
.
split
(
r
"\s+"
,
doc
))
return
len
(
re
.
split
(
r
"\s+"
,
doc
))
lm_eval/utils.py
View file @
97f5c020
...
@@ -456,7 +456,7 @@ env = Environment(loader=BaseLoader, undefined=StrictUndefined)
...
@@ -456,7 +456,7 @@ env = Environment(loader=BaseLoader, undefined=StrictUndefined)
env
.
filters
[
"regex_replace"
]
=
regex_replace
env
.
filters
[
"regex_replace"
]
=
regex_replace
def
apply_template
(
template
,
doc
)
:
def
apply_template
(
template
:
str
,
doc
:
dict
)
->
str
:
rtemplate
=
env
.
from_string
(
template
)
rtemplate
=
env
.
from_string
(
template
)
return
rtemplate
.
render
(
**
doc
)
return
rtemplate
.
render
(
**
doc
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment