Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
8f448eed
Unverified
Commit
8f448eed
authored
Sep 05, 2023
by
Hailey Schoelkopf
Committed by
GitHub
Sep 05, 2023
Browse files
Merge pull request #809 from ethanhs/mypy
Add mypy baseline config
parents
cc7828dd
4721379e
Changes
28
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
52 additions
and
26 deletions
+52
-26
lm_eval/prompts/__init__.py
lm_eval/prompts/__init__.py
+2
-2
lm_eval/tasks/__init__.py
lm_eval/tasks/__init__.py
+4
-4
lm_eval/tasks/glue/mnli/utils.py
lm_eval/tasks/glue/mnli/utils.py
+1
-1
lm_eval/tasks/hendrycks_ethics/utils.py
lm_eval/tasks/hendrycks_ethics/utils.py
+1
-1
lm_eval/tasks/pubmedqa/preprocess_pubmedqa.py
lm_eval/tasks/pubmedqa/preprocess_pubmedqa.py
+2
-2
lm_eval/utils.py
lm_eval/utils.py
+11
-14
main.py
main.py
+2
-2
mypy.ini
mypy.ini
+29
-0
No files found.
lm_eval/prompts/__init__.py
View file @
8f448eed
...
@@ -5,7 +5,7 @@ from lm_eval.logger import eval_logger
...
@@ -5,7 +5,7 @@ from lm_eval.logger import eval_logger
# Stores prompts in a dictionary indexed by 2 levels:
# Stores prompts in a dictionary indexed by 2 levels:
# prompt category name, and prompt name.
# prompt category name, and prompt name.
# This allows us to access prompts
# This allows us to access prompts
PROMPT_REGISTRY
=
{
PROMPT_REGISTRY
:
dict
[
str
,
dict
[
str
,
str
]]
=
{
"qa-basic"
:
{
"qa-basic"
:
{
"question-newline-answer"
:
"Question: {{question}}
\n
Answer:"
,
"question-newline-answer"
:
"Question: {{question}}
\n
Answer:"
,
"q-newline-a"
:
"Q: {{question}}
\n
A:"
,
"q-newline-a"
:
"Q: {{question}}
\n
A:"
,
...
@@ -13,7 +13,7 @@ PROMPT_REGISTRY = {
...
@@ -13,7 +13,7 @@ PROMPT_REGISTRY = {
}
}
def
get_prompt
(
prompt_id
:
str
,
dataset_name
=
None
,
subset_name
=
None
):
def
get_prompt
(
prompt_id
:
str
,
dataset_name
:
str
=
None
,
subset_name
:
str
=
None
):
# unpack prompt name
# unpack prompt name
category_name
,
prompt_name
=
prompt_id
.
split
(
":"
)
category_name
,
prompt_name
=
prompt_id
.
split
(
":"
)
if
subset_name
is
None
:
if
subset_name
is
None
:
...
...
lm_eval/tasks/__init__.py
View file @
8f448eed
...
@@ -15,7 +15,7 @@ from lm_eval.api.registry import (
...
@@ -15,7 +15,7 @@ from lm_eval.api.registry import (
)
)
def
register_configurable_task
(
config
)
:
def
register_configurable_task
(
config
:
dict
[
str
,
str
])
->
int
:
SubClass
=
type
(
SubClass
=
type
(
config
[
"task"
]
+
"ConfigurableTask"
,
config
[
"task"
]
+
"ConfigurableTask"
,
(
ConfigurableTask
,),
(
ConfigurableTask
,),
...
@@ -38,7 +38,7 @@ def register_configurable_task(config):
...
@@ -38,7 +38,7 @@ def register_configurable_task(config):
return
0
return
0
def
check_prompt_config
(
config
)
:
def
check_prompt_config
(
config
:
dict
[
str
,
str
])
->
List
[
dict
[
str
,
str
]]
:
all_configs
=
[]
all_configs
=
[]
if
"use_prompt"
in
config
:
if
"use_prompt"
in
config
:
prompt_list
=
prompts
.
load_prompt_list
(
prompt_list
=
prompts
.
load_prompt_list
(
...
@@ -69,14 +69,14 @@ def check_prompt_config(config):
...
@@ -69,14 +69,14 @@ def check_prompt_config(config):
return
all_configs
return
all_configs
def
get_task_name_from_config
(
task_config
)
:
def
get_task_name_from_config
(
task_config
:
dict
[
str
,
str
])
->
str
:
if
"dataset_name"
in
task_config
:
if
"dataset_name"
in
task_config
:
return
"{dataset_path}_{dataset_name}"
.
format
(
**
task_config
)
return
"{dataset_path}_{dataset_name}"
.
format
(
**
task_config
)
else
:
else
:
return
"{dataset_path}"
.
format
(
**
task_config
)
return
"{dataset_path}"
.
format
(
**
task_config
)
def
include_task_folder
(
task_dir
)
:
def
include_task_folder
(
task_dir
:
str
)
->
None
:
"""
"""
Calling this function
Calling this function
"""
"""
...
...
lm_eval/tasks/glue/mnli/utils.py
View file @
8f448eed
def
doc_to_text
(
doc
):
def
doc_to_text
(
doc
)
->
str
:
return
"{}
\n
Question: {} True, False or Neither?
\n
Answer:"
.
format
(
return
"{}
\n
Question: {} True, False or Neither?
\n
Answer:"
.
format
(
doc
[
"premise"
],
doc
[
"premise"
],
doc
[
"hypothesis"
].
strip
()
doc
[
"hypothesis"
].
strip
()
...
...
lm_eval/tasks/hendrycks_ethics/utils.py
View file @
8f448eed
...
@@ -15,7 +15,7 @@ def _preproc_doc(doc):
...
@@ -15,7 +15,7 @@ def _preproc_doc(doc):
return
doc
return
doc
def
doc_to_text
(
doc
):
def
doc_to_text
(
doc
)
->
str
:
doc
=
_preproc_doc
(
doc
)
doc
=
_preproc_doc
(
doc
)
return
f
"Scenario 1:
{
doc
[
'scenarios'
][
0
]
}
\n
Scenario 2:
{
doc
[
'scenarios'
][
1
]
}
\n
Question: Is Scenario 1 preferable?
\n
Answer:"
return
f
"Scenario 1:
{
doc
[
'scenarios'
][
0
]
}
\n
Scenario 2:
{
doc
[
'scenarios'
][
1
]
}
\n
Question: Is Scenario 1 preferable?
\n
Answer:"
...
...
lm_eval/tasks/pubmedqa/preprocess_pubmedqa.py
View file @
8f448eed
def
doc_to_text
(
doc
):
def
doc_to_text
(
doc
)
->
str
:
ctxs
=
"
\n
"
.
join
(
doc
[
"context"
][
"contexts"
])
ctxs
=
"
\n
"
.
join
(
doc
[
"context"
][
"contexts"
])
return
"Abstract: {}
\n
Question: {}
\n
Answer:"
.
format
(
return
"Abstract: {}
\n
Question: {}
\n
Answer:"
.
format
(
ctxs
,
doc
[
"question"
],
doc
[
"final_decision"
]
ctxs
,
doc
[
"question"
],
doc
[
"final_decision"
]
)
)
def
doc_to_target
(
doc
):
def
doc_to_target
(
doc
)
->
str
:
return
" {}"
.
format
(
doc
[
"final_decision"
])
return
" {}"
.
format
(
doc
[
"final_decision"
])
...
...
lm_eval/utils.py
View file @
8f448eed
...
@@ -10,7 +10,7 @@ import collections
...
@@ -10,7 +10,7 @@ import collections
import
importlib.util
import
importlib.util
import
fnmatch
import
fnmatch
from
typing
import
List
,
Literal
,
Union
from
typing
import
Iterator
,
List
,
Literal
,
Union
import
gc
import
gc
import
torch
import
torch
...
@@ -65,7 +65,7 @@ def join_iters(iters):
...
@@ -65,7 +65,7 @@ def join_iters(iters):
yield
from
iter
yield
from
iter
def
chunks
(
iter
,
n
=
0
,
fn
=
None
):
def
chunks
(
iter
,
n
:
int
=
0
,
fn
=
None
):
arr
=
[]
arr
=
[]
for
i
,
x
in
enumerate
(
iter
):
for
i
,
x
in
enumerate
(
iter
):
arr
.
append
(
x
)
arr
.
append
(
x
)
...
@@ -87,11 +87,11 @@ def group(arr, fn):
...
@@ -87,11 +87,11 @@ def group(arr, fn):
class
MultiChoice
:
class
MultiChoice
:
def
__init__
(
self
,
choices
):
def
__init__
(
self
,
choices
)
->
None
:
self
.
choices
=
choices
self
.
choices
=
choices
# Simple wildcard support (linux filename patterns)
# Simple wildcard support (linux filename patterns)
def
__contains__
(
self
,
values
):
def
__contains__
(
self
,
values
)
->
bool
:
for
value
in
values
.
split
(
","
):
for
value
in
values
.
split
(
","
):
if
len
(
fnmatch
.
filter
(
self
.
choices
,
value
))
==
0
:
if
len
(
fnmatch
.
filter
(
self
.
choices
,
value
))
==
0
:
eval_logger
.
info
(
f
"Available tasks to choose:"
)
eval_logger
.
info
(
f
"Available tasks to choose:"
)
...
@@ -100,7 +100,7 @@ class MultiChoice:
...
@@ -100,7 +100,7 @@ class MultiChoice:
raise
ValueError
(
"'{}' is not in task list"
.
format
(
value
))
raise
ValueError
(
"'{}' is not in task list"
.
format
(
value
))
return
True
return
True
def
__iter__
(
self
):
def
__iter__
(
self
)
->
Iterator
:
for
choice
in
self
.
choices
:
for
choice
in
self
.
choices
:
yield
choice
yield
choice
...
@@ -108,7 +108,6 @@ class MultiChoice:
...
@@ -108,7 +108,6 @@ class MultiChoice:
# Returns a list containing all values of the source_list that
# Returns a list containing all values of the source_list that
# match at least one of the patterns
# match at least one of the patterns
def
pattern_match
(
patterns
,
source_list
):
def
pattern_match
(
patterns
,
source_list
):
if
type
(
patterns
)
==
str
:
if
type
(
patterns
)
==
str
:
patterns
=
[
patterns
]
patterns
=
[
patterns
]
...
@@ -177,7 +176,7 @@ def make_disjoint_window(pair):
...
@@ -177,7 +176,7 @@ def make_disjoint_window(pair):
class
Reorderer
:
class
Reorderer
:
def
__init__
(
self
,
arr
,
fn
):
def
__init__
(
self
,
arr
,
fn
)
->
None
:
self
.
size
=
len
(
arr
)
self
.
size
=
len
(
arr
)
arr
=
list
(
enumerate
(
arr
))
arr
=
list
(
enumerate
(
arr
))
arr
=
group
(
arr
,
lambda
x
:
fn
(
x
[
1
]))
arr
=
group
(
arr
,
lambda
x
:
fn
(
x
[
1
]))
...
@@ -212,7 +211,7 @@ class Grouper:
...
@@ -212,7 +211,7 @@ class Grouper:
objects in `arr` satisfying `key == fn(ob)`.
objects in `arr` satisfying `key == fn(ob)`.
"""
"""
def
__init__
(
self
,
arr
,
fn
):
def
__init__
(
self
,
arr
,
fn
)
->
None
:
# self.orig_arr = arr
# self.orig_arr = arr
self
.
size
=
len
(
arr
)
self
.
size
=
len
(
arr
)
arr
=
list
(
enumerate
(
arr
))
arr
=
list
(
enumerate
(
arr
))
...
@@ -263,7 +262,7 @@ class Grouper:
...
@@ -263,7 +262,7 @@ class Grouper:
return
res
return
res
def
make_table
(
result_dict
,
column
=
"results"
):
def
make_table
(
result_dict
,
column
:
str
=
"results"
):
"""Generate table of results."""
"""Generate table of results."""
from
pytablewriter
import
MarkdownTableWriter
,
LatexTableWriter
from
pytablewriter
import
MarkdownTableWriter
,
LatexTableWriter
...
@@ -393,7 +392,6 @@ def get_git_commit_hash():
...
@@ -393,7 +392,6 @@ def get_git_commit_hash():
def
import_function
(
loader
,
node
):
def
import_function
(
loader
,
node
):
function_name
=
loader
.
construct_scalar
(
node
)
function_name
=
loader
.
construct_scalar
(
node
)
yaml_path
=
os
.
path
.
dirname
(
loader
.
name
)
yaml_path
=
os
.
path
.
dirname
(
loader
.
name
)
...
@@ -428,7 +426,6 @@ def load_yaml_config(yaml_path):
...
@@ -428,7 +426,6 @@ def load_yaml_config(yaml_path):
include_path
.
reverse
()
include_path
.
reverse
()
final_yaml_config
=
{}
final_yaml_config
=
{}
for
path
in
include_path
:
for
path
in
include_path
:
# Assumes that path is a full path.
# Assumes that path is a full path.
# If not found, assume the included yaml
# If not found, assume the included yaml
# is in the same dir as the original yaml
# is in the same dir as the original yaml
...
@@ -447,7 +444,7 @@ def load_yaml_config(yaml_path):
...
@@ -447,7 +444,7 @@ def load_yaml_config(yaml_path):
return
yaml_config
return
yaml_config
def
regex_replace
(
string
,
pattern
,
repl
,
count
=
0
):
def
regex_replace
(
string
,
pattern
,
repl
,
count
:
int
=
0
):
"""Implements the `re.sub` function as a custom Jinja filter."""
"""Implements the `re.sub` function as a custom Jinja filter."""
return
re
.
sub
(
pattern
,
repl
,
string
,
count
=
count
)
return
re
.
sub
(
pattern
,
repl
,
string
,
count
=
count
)
...
@@ -521,7 +518,7 @@ def pad_and_concat(
...
@@ -521,7 +518,7 @@ def pad_and_concat(
return
torch
.
cat
(
tensors
,
dim
=
0
)
return
torch
.
cat
(
tensors
,
dim
=
0
)
def
clear_torch_cache
():
def
clear_torch_cache
()
->
None
:
gc
.
collect
()
gc
.
collect
()
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
...
@@ -546,7 +543,7 @@ class MultiTokenEOSCriteria(transformers.StoppingCriteria):
...
@@ -546,7 +543,7 @@ class MultiTokenEOSCriteria(transformers.StoppingCriteria):
tokenizer
:
transformers
.
PreTrainedTokenizer
,
tokenizer
:
transformers
.
PreTrainedTokenizer
,
initial_decoder_input_length
:
int
,
initial_decoder_input_length
:
int
,
batch_size
:
int
,
batch_size
:
int
,
):
)
->
None
:
self
.
initial_decoder_input_length
=
initial_decoder_input_length
self
.
initial_decoder_input_length
=
initial_decoder_input_length
self
.
done_tracker
=
[
False
]
*
batch_size
self
.
done_tracker
=
[
False
]
*
batch_size
self
.
sequence
=
sequence
self
.
sequence
=
sequence
...
...
main.py
View file @
8f448eed
...
@@ -16,7 +16,7 @@ from lm_eval.benchmarks import include_benchmarks
...
@@ -16,7 +16,7 @@ from lm_eval.benchmarks import include_benchmarks
os
.
environ
[
"TOKENIZERS_PARALLELISM"
]
=
"false"
os
.
environ
[
"TOKENIZERS_PARALLELISM"
]
=
"false"
def
parse_args
():
def
parse_args
()
->
argparse
.
Namespace
:
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--model"
,
required
=
True
,
help
=
"Name of model e.g. `hf`"
)
parser
.
add_argument
(
"--model"
,
required
=
True
,
help
=
"Name of model e.g. `hf`"
)
parser
.
add_argument
(
parser
.
add_argument
(
...
@@ -99,7 +99,7 @@ def parse_args():
...
@@ -99,7 +99,7 @@ def parse_args():
return
parser
.
parse_args
()
return
parser
.
parse_args
()
def
main
():
def
main
()
->
None
:
args
=
parse_args
()
args
=
parse_args
()
if
args
.
limit
:
if
args
.
limit
:
...
...
mypy.ini
0 → 100644
View file @
8f448eed
[mypy]
python_version
=
3.9
show_traceback
=
True
check_untyped_defs
=
True
no_implicit_reexport
=
True
warn_unreachable
=
True
warn_unused_configs
=
True
warn_unused_ignores
=
True
warn_redundant_casts
=
True
# We ignore errors everywhere to gradually add type annotations
[mypy-lm_eval.*]
ignore_errors
=
True
[mypy-lm_eval.api.*]
ignore_errors
=
True
[mypy-lm_eval.prompts.*]
ignore_errors
=
True
[mypy-lm_eval.models.*]
ignore_errors
=
True
[mypy-scripts.*]
ignore_errors
=
True
[mypy-main]
ignore_errors
=
True
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment