Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
e0311cd5
Commit
e0311cd5
authored
Sep 06, 2023
by
baberabb
Browse files
Merge remote-tracking branch 'origin/big-refactor' into big-refactor_python_final
parents
96c60cf6
f86d6874
Changes
44
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
63 additions
and
25 deletions
+63
-25
lm_eval/utils.py
lm_eval/utils.py
+11
-14
main.py
main.py
+20
-7
mypy.ini
mypy.ini
+29
-0
setup.py
setup.py
+3
-4
No files found.
lm_eval/utils.py
View file @
e0311cd5
...
...
@@ -10,7 +10,7 @@ import collections
import
importlib.util
import
fnmatch
from
typing
import
List
,
Literal
,
Union
from
typing
import
Iterator
,
List
,
Literal
,
Union
import
gc
import
torch
...
...
@@ -65,7 +65,7 @@ def join_iters(iters):
yield
from
iter
def
chunks
(
iter
,
n
=
0
,
fn
=
None
):
def
chunks
(
iter
,
n
:
int
=
0
,
fn
=
None
):
arr
=
[]
for
i
,
x
in
enumerate
(
iter
):
arr
.
append
(
x
)
...
...
@@ -87,11 +87,11 @@ def group(arr, fn):
class
MultiChoice
:
def
__init__
(
self
,
choices
):
def
__init__
(
self
,
choices
)
->
None
:
self
.
choices
=
choices
# Simple wildcard support (linux filename patterns)
def
__contains__
(
self
,
values
):
def
__contains__
(
self
,
values
)
->
bool
:
for
value
in
values
.
split
(
","
):
if
len
(
fnmatch
.
filter
(
self
.
choices
,
value
))
==
0
:
eval_logger
.
info
(
f
"Available tasks to choose:"
)
...
...
@@ -100,7 +100,7 @@ class MultiChoice:
raise
ValueError
(
"'{}' is not in task list"
.
format
(
value
))
return
True
def
__iter__
(
self
):
def
__iter__
(
self
)
->
Iterator
:
for
choice
in
self
.
choices
:
yield
choice
...
...
@@ -108,7 +108,6 @@ class MultiChoice:
# Returns a list containing all values of the source_list that
# match at least one of the patterns
def
pattern_match
(
patterns
,
source_list
):
if
type
(
patterns
)
==
str
:
patterns
=
[
patterns
]
...
...
@@ -177,7 +176,7 @@ def make_disjoint_window(pair):
class
Reorderer
:
def
__init__
(
self
,
arr
,
fn
):
def
__init__
(
self
,
arr
,
fn
)
->
None
:
self
.
size
=
len
(
arr
)
arr
=
list
(
enumerate
(
arr
))
arr
=
group
(
arr
,
lambda
x
:
fn
(
x
[
1
]))
...
...
@@ -212,7 +211,7 @@ class Grouper:
objects in `arr` satisfying `key == fn(ob)`.
"""
def
__init__
(
self
,
arr
,
fn
):
def
__init__
(
self
,
arr
,
fn
)
->
None
:
# self.orig_arr = arr
self
.
size
=
len
(
arr
)
arr
=
list
(
enumerate
(
arr
))
...
...
@@ -263,7 +262,7 @@ class Grouper:
return
res
def
make_table
(
result_dict
,
column
=
"results"
):
def
make_table
(
result_dict
,
column
:
str
=
"results"
):
"""Generate table of results."""
from
pytablewriter
import
MarkdownTableWriter
,
LatexTableWriter
...
...
@@ -393,7 +392,6 @@ def get_git_commit_hash():
def
import_function
(
loader
,
node
):
function_name
=
loader
.
construct_scalar
(
node
)
yaml_path
=
os
.
path
.
dirname
(
loader
.
name
)
...
...
@@ -428,7 +426,6 @@ def load_yaml_config(yaml_path):
include_path
.
reverse
()
final_yaml_config
=
{}
for
path
in
include_path
:
# Assumes that path is a full path.
# If not found, assume the included yaml
# is in the same dir as the original yaml
...
...
@@ -447,7 +444,7 @@ def load_yaml_config(yaml_path):
return
yaml_config
def
regex_replace
(
string
,
pattern
,
repl
,
count
=
0
):
def
regex_replace
(
string
,
pattern
,
repl
,
count
:
int
=
0
):
"""Implements the `re.sub` function as a custom Jinja filter."""
return
re
.
sub
(
pattern
,
repl
,
string
,
count
=
count
)
...
...
@@ -521,7 +518,7 @@ def pad_and_concat(
return
torch
.
cat
(
tensors
,
dim
=
0
)
def
clear_torch_cache
():
def
clear_torch_cache
()
->
None
:
gc
.
collect
()
torch
.
cuda
.
empty_cache
()
...
...
@@ -546,7 +543,7 @@ class MultiTokenEOSCriteria(transformers.StoppingCriteria):
tokenizer
:
transformers
.
PreTrainedTokenizer
,
initial_decoder_input_length
:
int
,
batch_size
:
int
,
):
)
->
None
:
self
.
initial_decoder_input_length
=
initial_decoder_input_length
self
.
done_tracker
=
[
False
]
*
batch_size
self
.
sequence
=
sequence
...
...
main.py
View file @
e0311cd5
...
...
@@ -9,24 +9,26 @@ from pathlib import Path
from
lm_eval
import
evaluator
,
utils
from
lm_eval.api.registry
import
ALL_TASKS
from
lm_eval.logger
import
eval_logger
from
lm_eval.logger
import
eval_logger
,
SPACING
from
lm_eval.tasks
import
include_task_folder
from
lm_eval.benchmarks
import
include_benchmarks
os
.
environ
[
"TOKENIZERS_PARALLELISM"
]
=
"false"
def
parse_args
():
parser
=
argparse
.
ArgumentParser
()
def
parse_args
()
->
argparse
.
Namespace
:
parser
=
argparse
.
ArgumentParser
(
formatter_class
=
argparse
.
RawTextHelpFormatter
)
parser
.
add_argument
(
"--model"
,
required
=
True
,
help
=
"Name of model e.g. `hf`"
)
parser
.
add_argument
(
"--tasks"
,
default
=
None
,
help
=
"Available Tasks:
\n
- {}"
.
format
(
"
\n
- "
.
join
(
sorted
(
ALL_TASKS
))),
)
parser
.
add_argument
(
"--model_args"
,
default
=
""
,
help
=
"String arguments for model, e.g. `pretrained=EleutherAI/pythia-160m,dtype=float32`"
,
)
parser
.
add_argument
(
"--tasks"
,
default
=
None
# , choices=utils.MultiChoice(sorted(ALL_TASKS))
)
parser
.
add_argument
(
"--num_fewshot"
,
type
=
int
,
...
...
@@ -99,7 +101,7 @@ def parse_args():
return
parser
.
parse_args
()
def
main
():
def
main
()
->
None
:
args
=
parse_args
()
if
args
.
limit
:
...
...
@@ -126,10 +128,21 @@ def main():
else
:
tasks_list
=
args
.
tasks
.
split
(
","
)
task_names
=
utils
.
pattern_match
(
tasks_list
,
ALL_TASKS
)
task_missing
=
[]
for
task
in
[
task
for
task
in
tasks_list
if
task
not
in
task_names
]:
if
os
.
path
.
isfile
(
task
):
config
=
utils
.
load_yaml_config
(
task
)
task_names
.
append
(
config
)
else
:
task_missing
.
append
(
task
)
if
task_missing
!=
[]:
missing
=
", "
.
join
(
task_missing
)
eval_logger
.
error
(
f
"Tasks were not found:
{
missing
}
\n
"
f
"
{
SPACING
}
Try `lm-eval -h` for list of available tasks"
,
)
raise
ValueError
(
f
"Tasks
{
missing
}
were not found."
)
if
args
.
output_path
:
path
=
Path
(
args
.
output_path
)
...
...
mypy.ini
0 → 100644
View file @
e0311cd5
[mypy]
python_version
=
3.9
show_traceback
=
True
check_untyped_defs
=
True
no_implicit_reexport
=
True
warn_unreachable
=
True
warn_unused_configs
=
True
warn_unused_ignores
=
True
warn_redundant_casts
=
True
# We ignore errors everywhere to gradually add type annotations
[mypy-lm_eval.*]
ignore_errors
=
True
[mypy-lm_eval.api.*]
ignore_errors
=
True
[mypy-lm_eval.prompts.*]
ignore_errors
=
True
[mypy-lm_eval.models.*]
ignore_errors
=
True
[mypy-scripts.*]
ignore_errors
=
True
[mypy-main]
ignore_errors
=
True
setup.py
View file @
e0311cd5
...
...
@@ -15,7 +15,7 @@ extras_require = {
],
"testing"
:
[
"pytest"
,
"pytest-cov"
,
"pytest-xdist"
],
"multilingual"
:
[
"nagisa>=0.2.7"
,
"jieba>=0.42.1"
],
"sentencepiece"
:
[
"sentencepiece>=0.1.98"
,
"protobuf>=4.22.1"
],
"sentencepiece"
:
[
"sentencepiece>=0.1.98"
,
"protobuf>=4.22.1"
,
"pycountry"
],
"promptsource"
:
[
"promptsource @ git+https://github.com/bigscience-workshop/promptsource.git#egg=promptsource"
],
...
...
@@ -53,7 +53,7 @@ setuptools.setup(
],
python_requires
=
">=3.9"
,
install_requires
=
[
"accelerate>=0.1
8
.0"
,
"accelerate>=0.
2
1.0"
,
"evaluate"
,
"datasets>=2.0.0"
,
"evaluate>=0.4.0"
,
...
...
@@ -62,10 +62,9 @@ setuptools.setup(
"omegaconf>=2.2"
,
"peft>=0.2.0"
,
"pybind11>=2.6.2"
,
"pycountry"
,
"pytablewriter"
,
"rouge-score>=0.0.4"
,
"sacrebleu
=
=1.5.0"
,
"sacrebleu
>
=1.5.0"
,
"scikit-learn>=0.24.1"
,
"sqlitedict"
,
"torch>=1.8"
,
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment