Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
8f448eed
Unverified
Commit
8f448eed
authored
Sep 05, 2023
by
Hailey Schoelkopf
Committed by
GitHub
Sep 05, 2023
Browse files
Merge pull request #809 from ethanhs/mypy
Add mypy baseline config
parents
cc7828dd
4721379e
Changes
28
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
112 additions
and
113 deletions
+112
-113
.pre-commit-config.yaml
.pre-commit-config.yaml
+6
-0
lm_eval/api/filter.py
lm_eval/api/filter.py
+2
-2
lm_eval/api/instance.py
lm_eval/api/instance.py
+1
-1
lm_eval/api/metrics.py
lm_eval/api/metrics.py
+1
-1
lm_eval/api/model.py
lm_eval/api/model.py
+5
-5
lm_eval/api/samplers.py
lm_eval/api/samplers.py
+3
-5
lm_eval/api/task.py
lm_eval/api/task.py
+8
-20
lm_eval/benchmarks/__init__.py
lm_eval/benchmarks/__init__.py
+1
-2
lm_eval/decontamination/archiver.py
lm_eval/decontamination/archiver.py
+19
-12
lm_eval/decontamination/decontaminate.py
lm_eval/decontamination/decontaminate.py
+4
-3
lm_eval/decontamination/janitor.py
lm_eval/decontamination/janitor.py
+25
-22
lm_eval/evaluator.py
lm_eval/evaluator.py
+7
-11
lm_eval/filters/decontamination.py
lm_eval/filters/decontamination.py
+2
-2
lm_eval/filters/extraction.py
lm_eval/filters/extraction.py
+4
-3
lm_eval/filters/selection.py
lm_eval/filters/selection.py
+3
-4
lm_eval/models/anthropic_llms.py
lm_eval/models/anthropic_llms.py
+2
-3
lm_eval/models/dummy.py
lm_eval/models/dummy.py
+1
-1
lm_eval/models/huggingface.py
lm_eval/models/huggingface.py
+11
-9
lm_eval/models/openai_completions.py
lm_eval/models/openai_completions.py
+4
-4
lm_eval/models/textsynth.py
lm_eval/models/textsynth.py
+3
-3
No files found.
.pre-commit-config.yaml
View file @
8f448eed
...
...
@@ -43,3 +43,9 @@ repos:
.*\.json|ignore.txt
)$
args
:
[
--check-filenames
,
--check-hidden
,
--ignore-words=ignore.txt
]
-
repo
:
https://github.com/pre-commit/mirrors-mypy
rev
:
v1.5.1
hooks
:
-
id
:
mypy
additional_dependencies
:
[
"
.[sentencepiece,multilingual,promptsource,gptq]"
,
"
types-PyYAML"
,
"
types-requests"
]
exclude
:
^tests/.*$
lm_eval/api/filter.py
View file @
8f448eed
...
...
@@ -14,7 +14,7 @@ class Filter:
"""
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
**
kwargs
)
->
None
:
"""
Can define custom behavior here, if an individual instantiation of a Filter class should have state.
"""
...
...
@@ -41,7 +41,7 @@ class FilterEnsemble:
name
:
str
filters
:
List
[
Filter
]
def
apply
(
self
,
instances
:
List
[
Instance
],
docs
:
List
[
Dataset
]):
def
apply
(
self
,
instances
:
List
[
Instance
],
docs
:
List
[
Dataset
])
->
None
:
resps
=
[
inst
.
resps
for
inst
in
instances
...
...
lm_eval/api/instance.py
View file @
8f448eed
...
...
@@ -19,7 +19,7 @@ class Instance:
doc_id
:
str
=
None
repeats
:
str
=
None
def
__post_init__
(
self
):
def
__post_init__
(
self
)
->
None
:
# unpack metadata field
self
.
task_name
,
self
.
doc_id
,
self
.
repeats
=
self
.
metadata
...
...
lm_eval/api/metrics.py
View file @
8f448eed
...
...
@@ -302,7 +302,7 @@ def _sacreformat(refs, preds):
class
_bootstrap_internal
:
def
__init__
(
self
,
f
,
n
):
def
__init__
(
self
,
f
,
n
)
->
None
:
self
.
f
=
f
self
.
n
=
n
...
...
lm_eval/api/model.py
View file @
8f448eed
...
...
@@ -13,7 +13,7 @@ from lm_eval.logger import eval_logger
class
LM
(
abc
.
ABC
):
def
__init__
(
self
):
def
__init__
(
self
)
->
None
:
"""Defines the interface that should be implemented by all LM subclasses.
LMs are assumed to take text (strings) as input and yield strings as output
(inputs/outputs should be tokenization-agnostic.)
...
...
@@ -133,7 +133,7 @@ class LM(abc.ABC):
# not support multi-device parallelism nor expect it.
return
self
.
_world_size
def
set_cache_hook
(
self
,
cache_hook
):
def
set_cache_hook
(
self
,
cache_hook
)
->
None
:
self
.
cache_hook
=
cache_hook
...
...
@@ -144,14 +144,14 @@ def hash_args(attr, args):
class
CacheHook
:
def
__init__
(
self
,
cachinglm
):
def
__init__
(
self
,
cachinglm
)
->
None
:
if
cachinglm
is
None
:
self
.
dbdict
=
None
return
self
.
dbdict
=
cachinglm
.
dbdict
def
add_partial
(
self
,
attr
,
req
,
res
):
def
add_partial
(
self
,
attr
,
req
,
res
)
->
None
:
if
self
.
dbdict
is
None
:
return
hsh
=
hash_args
(
attr
,
req
)
...
...
@@ -159,7 +159,7 @@ class CacheHook:
class
CachingLM
:
def
__init__
(
self
,
lm
,
cache_db
):
def
__init__
(
self
,
lm
,
cache_db
)
->
None
:
"""LM wrapper that returns cached results if they exist, and uses the underlying LM if not.
:param lm: LM
...
...
lm_eval/api/samplers.py
View file @
8f448eed
class
Sampler
:
def
__init__
(
self
,
docs
,
task
,
fewshot_indices
=
None
,
rnd
=
None
):
def
__init__
(
self
,
docs
,
task
,
fewshot_indices
=
None
,
rnd
=
None
)
->
None
:
self
.
rnd
=
rnd
assert
self
.
rnd
,
"must pass rnd to FewShotSampler!"
...
...
@@ -19,7 +18,6 @@ class Sampler:
self
.
docs
=
self
.
docs
.
select
(
fewshot_indices
)
def
get_context
(
self
,
doc
,
num_fewshot
):
# draw an extra fewshot sample if using same split as evaluating on
n_samples
=
(
num_fewshot
+
1
...
...
@@ -74,7 +72,7 @@ class Sampler:
class
BalancedSampler
(
Sampler
):
def
sample
(
self
,
n
):
def
sample
(
self
,
n
)
->
None
:
"""
TODO: this should return approximately class-balanced samples from our fewshot examples.
TODO: what order should they be in? maybe random?
...
...
@@ -84,7 +82,7 @@ class BalancedSampler(Sampler):
class
ManualSampler
(
Sampler
):
def
sample
(
self
,
n
):
def
sample
(
self
,
n
)
->
None
:
""" """
pass
...
...
lm_eval/api/task.py
View file @
8f448eed
...
...
@@ -88,8 +88,8 @@ class TaskConfig(dict):
metadata
:
str
=
None
# by default, not used in the code. allows for users to pass arbitrary info to tasks
def
__post_init__
(
self
):
def
__post_init__
(
self
)
->
None
:
if
"."
in
self
.
dataset_path
:
import
inspect
from
importlib
import
import_module
...
...
@@ -177,7 +177,7 @@ class Task(abc.ABC):
cache_dir
=
None
,
download_mode
=
None
,
config
=
None
,
):
)
->
None
:
"""
:param data_dir: str
Stores the path to a local folder containing the `Task`'s data files.
...
...
@@ -188,7 +188,6 @@ class Task(abc.ABC):
HuggingFace `datasets` API with the default cache directory located at:
`~/.cache/huggingface/datasets`
NOTE: You can change the cache location globally for a given process
by setting the shell environment variable, `HF_DATASETS_CACHE`,
to another directory:
`export HF_DATASETS_CACHE="/path/to/another/directory"`
:param download_mode: datasets.DownloadMode
...
...
@@ -219,7 +218,7 @@ class Task(abc.ABC):
list
(
self
.
fewshot_docs
()),
self
,
rnd
=
random
.
Random
(
1234
)
)
def
download
(
self
,
data_dir
=
None
,
cache_dir
=
None
,
download_mode
=
None
):
def
download
(
self
,
data_dir
=
None
,
cache_dir
=
None
,
download_mode
=
None
)
->
None
:
"""Downloads and returns the task dataset.
Override this method to download the dataset from a custom API.
...
...
@@ -328,7 +327,7 @@ class Task(abc.ABC):
return
rnd
.
sample
(
self
.
_training_docs
,
k
)
def
doc_to_decontamination_query
(
self
,
doc
):
def
doc_to_decontamination_query
(
self
,
doc
)
->
None
:
print
(
"Override doc_to_decontamination_query with document specific decontamination query."
)
...
...
@@ -342,7 +341,7 @@ class Task(abc.ABC):
def
doc_to_target
(
self
,
doc
):
pass
def
build_all_requests
(
self
,
limit
=
None
,
rank
=
None
,
world_size
=
None
):
def
build_all_requests
(
self
,
limit
=
None
,
rank
=
None
,
world_size
=
None
)
->
None
:
"""Build a set of Instances for a task, and store them in task.instances"""
if
self
.
has_test_docs
():
docs
=
self
.
test_docs
()
...
...
@@ -478,7 +477,6 @@ class Task(abc.ABC):
return
labeled_examples
+
str
(
example
)
def
apply_filters
(
self
):
if
hasattr
(
self
,
"_filters"
):
for
f
in
self
.
_filters
:
f
.
apply
(
self
.
_instances
)
...
...
@@ -504,7 +502,7 @@ class ConfigurableTask(Task):
def
__init__
(
self
,
data_dir
=
None
,
cache_dir
=
None
,
download_mode
=
None
,
config
:
dict
=
None
):
# TODO no super() call here
)
->
None
:
# TODO no super() call here
# Get pre-configured attributes
self
.
_config
=
self
.
CONFIG
...
...
@@ -576,7 +574,6 @@ class ConfigurableTask(Task):
"aggregation"
]
else
:
INV_AGG_REGISTRY
=
{
v
:
k
for
k
,
v
in
AGGREGATION_REGISTRY
.
items
()}
metric_agg
=
get_default_aggregation
(
metric_name
)
eval_logger
.
warning
(
...
...
@@ -689,8 +686,7 @@ class ConfigurableTask(Task):
f
'Both target_delimiter and target choice: "
{
choice
}
" does not have whitespace, ignore if the language you are evaluating on does not require/use whitespace'
)
def
download
(
self
,
dataset_kwargs
=
None
):
def
download
(
self
,
dataset_kwargs
=
None
)
->
None
:
self
.
dataset
=
datasets
.
load_dataset
(
path
=
self
.
DATASET_PATH
,
name
=
self
.
DATASET_NAME
,
...
...
@@ -782,7 +778,6 @@ class ConfigurableTask(Task):
return
doc
def
doc_to_text
(
self
,
doc
):
if
self
.
prompt
is
not
None
:
doc_to_text
=
self
.
prompt
else
:
...
...
@@ -817,7 +812,6 @@ class ConfigurableTask(Task):
raise
TypeError
def
doc_to_target
(
self
,
doc
:
dict
)
->
Union
[
int
,
str
,
list
]:
if
self
.
prompt
is
not
None
:
doc_to_target
=
self
.
prompt
else
:
...
...
@@ -859,7 +853,6 @@ class ConfigurableTask(Task):
raise
TypeError
def
doc_to_choice
(
self
,
doc
:
Any
)
->
List
[
str
]:
if
self
.
prompt
is
not
None
:
doc_to_choice
=
self
.
prompt
elif
self
.
_config
.
doc_to_choice
is
None
:
...
...
@@ -903,13 +896,11 @@ class ConfigurableTask(Task):
def
construct_requests
(
self
,
doc
:
dict
,
ctx
:
str
,
**
kwargs
)
->
Union
[
List
[
Instance
],
Instance
]:
if
self
.
OUTPUT_TYPE
==
"loglikelihood"
:
arguments
=
(
ctx
,
self
.
doc_to_target
(
doc
))
elif
self
.
OUTPUT_TYPE
==
"loglikelihood_rolling"
:
arguments
=
(
self
.
doc_to_target
(
doc
),)
elif
self
.
OUTPUT_TYPE
==
"multiple_choice"
:
choices
=
self
.
doc_to_choice
(
doc
)
target_delimiter
=
self
.
_config
.
target_delimiter
if
self
.
multiple_input
:
...
...
@@ -960,7 +951,6 @@ class ConfigurableTask(Task):
)
def
process_results
(
self
,
doc
,
results
):
if
callable
(
self
.
_config
.
process_results
):
return
self
.
_config
.
process_results
(
doc
,
results
)
...
...
@@ -995,7 +985,6 @@ class ConfigurableTask(Task):
),
}
elif
self
.
OUTPUT_TYPE
==
"multiple_choice"
:
lls
,
is_greedy
=
zip
(
*
results
)
# retrieve choices in List[str] form, to compute choice lengths, etc.
...
...
@@ -1067,7 +1056,6 @@ class ConfigurableTask(Task):
result_dict
[
"acc_mutual_info"
]
=
acc_mutual_info
elif
self
.
OUTPUT_TYPE
==
"greedy_until"
:
gold
=
self
.
doc_to_target
(
doc
)
if
self
.
_config
.
doc_to_choice
is
not
None
:
# If you set doc_to_choice,
...
...
@@ -1197,7 +1185,7 @@ class PerplexityTask(Task):
def
doc_to_decontamination_query
(
self
,
doc
):
return
doc
def
doc_to_text
(
self
,
doc
):
def
doc_to_text
(
self
,
doc
)
->
str
:
return
""
def
doc_to_target
(
self
,
doc
):
...
...
lm_eval/benchmarks/__init__.py
View file @
8f448eed
...
...
@@ -11,8 +11,7 @@ from lm_eval.api.registry import (
)
def
include_benchmarks
(
task_dir
):
def
include_benchmarks
(
task_dir
:
str
)
->
None
:
for
root
,
subdirs
,
file_list
in
os
.
walk
(
task_dir
):
if
(
subdirs
==
[]
or
subdirs
==
[
"__pycache__"
])
and
(
len
(
file_list
)
>
0
):
for
f
in
file_list
:
...
...
lm_eval/decontamination/archiver.py
View file @
8f448eed
import
os
from
typing
import
Any
import
zstandard
import
json
import
jsonlines
...
...
@@ -9,7 +10,7 @@ import tqdm
from
pathlib
import
Path
def
json_serial
(
obj
)
:
def
json_serial
(
obj
:
Any
)
->
str
:
"""JSON serializer for objects not serializable by default json code"""
if
isinstance
(
obj
,
(
datetime
.
datetime
,)):
...
...
@@ -19,7 +20,7 @@ def json_serial(obj):
# Modified version of lm_dataformat Archive for single file.
class
Archive
:
def
__init__
(
self
,
file_path
,
compression_level
=
3
)
:
def
__init__
(
self
,
file_path
:
str
,
compression_level
:
int
=
3
)
->
None
:
self
.
file_path
=
file_path
dir_name
=
os
.
path
.
dirname
(
file_path
)
if
dir_name
:
...
...
@@ -28,7 +29,7 @@ class Archive:
self
.
cctx
=
zstandard
.
ZstdCompressor
(
level
=
compression_level
)
self
.
compressor
=
self
.
cctx
.
stream_writer
(
self
.
fh
)
def
add_data
(
self
,
data
,
meta
=
{}):
def
add_data
(
self
,
data
,
meta
=
{})
->
None
:
self
.
compressor
.
write
(
json
.
dumps
({
"text"
:
data
,
"meta"
:
meta
},
default
=
json_serial
).
encode
(
"UTF-8"
...
...
@@ -36,7 +37,7 @@ class Archive:
+
b
"
\n
"
)
def
commit
(
self
):
def
commit
(
self
)
->
None
:
self
.
compressor
.
flush
(
zstandard
.
FLUSH_FRAME
)
self
.
fh
.
flush
()
self
.
fh
.
close
()
...
...
@@ -44,10 +45,16 @@ class Archive:
# Modified version of lm_dataformat Reader with self.fh set, allowing peeking for tqdm.
class
Reader
:
def
__init__
(
self
):
def
__init__
(
self
)
->
None
:
pass
def
read
(
self
,
file
,
get_meta
=
False
,
autojoin_paragraphs
=
True
,
para_joiner
=
"
\n\n
"
):
def
read
(
self
,
file
,
get_meta
:
bool
=
False
,
autojoin_paragraphs
:
bool
=
True
,
para_joiner
:
str
=
"
\n\n
"
,
):
with
open
(
file
,
"rb"
)
as
fh
:
self
.
fh
=
fh
cctx
=
zstandard
.
ZstdDecompressor
()
...
...
@@ -72,7 +79,7 @@ class Reader:
class
TextArchive
:
def
__init__
(
self
,
file_path
,
mode
=
"rb+"
)
:
def
__init__
(
self
,
file_path
,
mode
:
str
=
"rb+"
)
->
None
:
self
.
file_path
=
file_path
dir_name
=
os
.
path
.
dirname
(
file_path
)
if
dir_name
:
...
...
@@ -83,21 +90,21 @@ class TextArchive:
self
.
fh
=
open
(
self
.
file_path
,
mode
)
def
add_data
(
self
,
data
):
def
add_data
(
self
,
data
)
->
None
:
self
.
fh
.
write
(
data
.
encode
(
"UTF-8"
)
+
b
"
\n
"
)
def
commit
(
self
):
def
commit
(
self
)
->
None
:
self
.
fh
.
flush
()
self
.
fh
.
close
()
class
TextReader
:
def
__init__
(
self
,
file_path
):
def
__init__
(
self
,
file_path
)
->
None
:
self
.
file_path
=
file_path
# Optimized mmap read with infrequent tqdm updates to maintain speed
# Tested up to 250MB/s.
def
read_tqdm
(
self
,
update_frequency
=
10000
):
def
read_tqdm
(
self
,
update_frequency
:
int
=
10000
):
current_file_position
=
0
line_counter
=
0
with
open
(
self
.
file_path
,
"r"
)
as
fh
,
tqdm
.
tqdm
(
...
...
@@ -149,7 +156,7 @@ class TextReader:
# Optimized for speed. Decompresses the archive in shell before
# using the mmap'd TextReader.
class
ZStdTextReader
:
def
__init__
(
self
,
file
):
def
__init__
(
self
,
file
)
->
None
:
self
.
file
=
file
def
read_tqdm
(
self
):
...
...
lm_eval/decontamination/decontaminate.py
View file @
8f448eed
...
...
@@ -11,7 +11,7 @@ from .archiver import ZStdTextReader
# Was used for testing the evaluator decoupled from the full logic below
def
get_train_overlap_stub
(
docs
,
ngrams_path
,
ngrams_n_size
):
def
get_train_overlap_stub
(
docs
:
dict
,
ngrams_path
:
str
,
ngrams_n_size
:
str
):
simulated_overlap
=
0.1
contaminated
=
int
(
len
(
docs
)
*
simulated_overlap
)
return
random
.
sample
(
range
(
len
(
docs
)),
contaminated
)
...
...
@@ -25,6 +25,7 @@ def get_train_overlap_stub(docs, ngrams_path, ngrams_n_size):
# scripts are an info.json file containing the n_gram_size (13) and a bunch of "ngrams_{x}.bkt.txt.sorted.zst"
# files. These should exist in the "ngrams_path" provided to this function.
# Algorithm:
# 1. Build lookups for each dataset {ngram: list(document_ids)}
# 2. Merge into an overall lookup {ngram: [(task_name, task_set, doc_ids),]}
...
...
@@ -33,7 +34,7 @@ def get_train_overlap_stub(docs, ngrams_path, ngrams_n_size):
# 4. Strip the task_set from the dictionary keys and return
#
# We cache the task+set lookups as well as the overlaps.
def
get_train_overlap
(
docs_by_task_set
,
ngrams_path
,
limit
)
:
def
get_train_overlap
(
docs_by_task_set
:
dict
,
ngrams_path
:
str
,
limit
:
int
)
->
dict
:
# return get_train_overlap_stub(docs, ngrams_path, ngrams_n_size)
info_dict_path
=
os
.
path
.
join
(
ngrams_path
,
"info.json"
)
...
...
@@ -46,7 +47,7 @@ def get_train_overlap(docs_by_task_set, ngrams_path, limit):
print
(
"Building Lookups..."
)
start
=
time
.
perf_counter
()
def
get_overlaps_dump_path
(
task_name
,
task_set
,
ngrams_n_size
,
limit
):
def
get_overlaps_dump_path
(
task_name
,
task_set
,
ngrams_n_size
,
limit
)
->
str
:
return
f
"data/
{
task_name
}
/
{
task_set
}
_
{
ngrams_n_size
}
grams_limit
{
limit
}
.overlaps"
lookups
=
{}
...
...
lm_eval/decontamination/janitor.py
View file @
8f448eed
import
re
import
string
import
timeit
import
pickle
import
traceback
from
pprint
import
pprint
from
typing
import
Iterator
,
Sequence
,
TypeVar
# This is a cpp module. Compile janitor_util.cpp with:
# c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes) janitor_util.cpp -o janitor_util$(python3-config --extension-suffix) -undefined dynamic_lookup
...
...
@@ -16,10 +16,12 @@ except Exception:
traceback
.
print_exc
()
JANITOR_CPP
=
False
T
=
TypeVar
(
"T"
)
# Implementation from nltk source
# https://www.nltk.org/_modules/nltk/util.html
def
form_ngrams
(
sequence
,
n
)
:
def
form_ngrams
(
sequence
:
Iterator
[
T
],
n
:
int
)
->
Iterator
[
tuple
[
T
,
...]]
:
history
=
[]
while
n
>
1
:
# PEP 479, prevent RuntimeError from being raised when StopIteration bubbles out of generator
...
...
@@ -36,7 +38,7 @@ def form_ngrams(sequence, n):
del
history
[
0
]
def
word_ngrams
(
s
,
n
)
:
def
word_ngrams
(
s
:
str
,
n
:
int
)
->
Iterator
[
str
]
:
"""Splits a string into ngram words"""
tokens
=
s
.
split
()
# not a generator :(
ngram_seqs
=
form_ngrams
(
iter
(
tokens
),
n
)
...
...
@@ -68,14 +70,14 @@ def word_ngrams(s, n):
# https://stackoverflow.com/questions/13734451/string-split-with-indices-in-python
def
split_indices
(
s
)
:
def
split_indices
(
s
:
str
)
->
Iterator
[
tuple
[
str
,
tuple
[
int
,
int
]]]
:
"""Splits a string on whitespaces and records the indices of each in the original string.
@:return generator((word, (start_idx, end_idx)), ...)
"""
return
((
m
.
group
(
0
),
(
m
.
start
(),
m
.
end
()
-
1
))
for
m
in
re
.
finditer
(
r
"\S+"
,
s
))
def
word_ngrams_indices
(
s
,
n
)
:
def
word_ngrams_indices
(
s
:
str
,
n
:
int
)
->
Iterator
[
tuple
[
str
,
tuple
[
int
,
int
]]]
:
"""Splits a string into pairs of (ngram words, their start/end indices)"""
tokens_with_indices
=
split_indices
(
s
)
...
...
@@ -104,16 +106,15 @@ def word_ngrams_indices(s, n):
class
Janitor
:
# FIXME delete_chars: Should anything else go here? Special chars?
def
__init__
(
self
,
ngram_n
=
13
,
window_to_remove
=
200
,
too_dirty_cutoff
=
10
,
minimum_slice_length
=
200
,
delete_chars
=
string
.
punctuation
,
):
ngram_n
:
int
=
13
,
window_to_remove
:
int
=
200
,
too_dirty_cutoff
:
int
=
10
,
minimum_slice_length
:
int
=
200
,
delete_chars
:
str
=
string
.
punctuation
,
)
->
None
:
self
.
ngram_n
=
ngram_n
self
.
window_to_remove
=
window_to_remove
self
.
too_dirty_cutoff
=
too_dirty_cutoff
...
...
@@ -135,11 +136,11 @@ class Janitor:
# I/O for saving contamination ngrams
##############
def
save_contamination_ngrams
(
self
,
filename
)
:
def
save_contamination_ngrams
(
self
,
filename
:
str
)
->
None
:
with
open
(
filename
,
"wb"
)
as
fp
:
pickle
.
dump
(
filename
,
fp
)
def
load_contamination_ngrams
(
self
,
filename
)
:
def
load_contamination_ngrams
(
self
,
filename
:
str
)
->
None
:
with
open
(
filename
,
"rb"
)
as
fp
:
self
.
dirt_ngrams
=
pickle
.
load
(
fp
)
...
...
@@ -147,7 +148,7 @@ class Janitor:
# Call these :)
##############
def
register_contaminant
(
self
,
dirt_string
)
:
def
register_contaminant
(
self
,
dirt_string
:
str
)
->
None
:
"""Register a string as contamination to be removed, e.g. a test set
This breaks the dirt_string into ngrams to store for future cleaning"""
if
JANITOR_CPP
:
...
...
@@ -156,7 +157,7 @@ class Janitor:
print
(
"WARNING: Janitor running in python mode"
)
return
self
.
register_contaminant_python
(
dirt_string
)
def
clean
(
self
,
dirty_string
)
:
def
clean
(
self
,
dirty_string
:
str
)
->
list
[
str
]
:
"""Clean a string (e.g. a training set) by removing all ngrams previously
registered as contaminants. Returns a list of clean chunks, or empty if
the string was too dirty"""
...
...
@@ -166,7 +167,9 @@ class Janitor:
print
(
"WARNING: Janitor running in python mode"
)
return
self
.
clean_python
(
dirty_string
)
def
_split_chunks
(
self
,
dirty_string
,
dirty_parts
):
def
_split_chunks
(
self
,
dirty_string
:
str
,
dirty_parts
:
Sequence
[
tuple
]
)
->
list
[
str
]:
clean_chunks
=
[]
splice_idx
=
0
end
=
-
1
...
...
@@ -189,12 +192,12 @@ class Janitor:
# Fast C++
##############
def
register_contaminant_cpp
(
self
,
dirt_string
):
def
register_contaminant_cpp
(
self
,
dirt_string
)
->
None
:
self
.
dirt_ngrams
.
update
(
janitor_util
.
clean_ngram
(
dirt_string
,
self
.
delete_chars
,
self
.
ngram_n
)
)
def
clean_cpp
(
self
,
dirty_string
)
:
def
clean_cpp
(
self
,
dirty_string
:
str
)
->
list
[
str
]
:
contamination_indices
=
janitor_util
.
clean_ngram_with_indices
(
dirty_string
,
self
.
delete_chars
,
self
.
ngram_n
)
...
...
@@ -204,15 +207,15 @@ class Janitor:
# Slow python
##############
def
normalize_string
(
self
,
s
)
:
def
normalize_string
(
self
,
s
:
str
)
->
str
:
return
s
.
translate
(
self
.
translation_table
)
def
register_contaminant_python
(
self
,
dirt_string
)
:
def
register_contaminant_python
(
self
,
dirt_string
:
str
)
->
None
:
self
.
dirt_ngrams
.
update
(
word_ngrams
(
self
.
normalize_string
(
dirt_string
),
self
.
ngram_n
)
)
def
clean_python
(
self
,
dirty_string
)
:
def
clean_python
(
self
,
dirty_string
:
str
)
->
list
[
str
]
:
contamination_indices
=
(
(
None
,
*
idx_pair
)
for
dirty_ngram
,
idx_pair
in
word_ngrams_indices
(
dirty_string
,
self
.
ngram_n
)
...
...
lm_eval/evaluator.py
View file @
8f448eed
...
...
@@ -42,11 +42,11 @@ def simple_evaluate(
device
=
None
,
use_cache
=
None
,
limit
=
None
,
bootstrap_iters
=
100000
,
check_integrity
=
False
,
bootstrap_iters
:
int
=
100000
,
check_integrity
:
bool
=
False
,
decontamination_ngrams_path
=
None
,
write_out
=
False
,
log_samples
=
True
,
write_out
:
bool
=
False
,
log_samples
:
bool
=
True
,
):
"""Instantiate and evaluate a model on a list of tasks.
...
...
@@ -117,7 +117,6 @@ def simple_evaluate(
task_dict
=
lm_eval
.
tasks
.
get_task_dict
(
tasks
)
for
task_name
in
task_dict
.
keys
():
task_obj
=
task_dict
[
task_name
]
if
type
(
task_obj
)
==
tuple
:
group
,
task_obj
=
task_obj
...
...
@@ -175,10 +174,10 @@ def evaluate(
lm
,
task_dict
,
limit
=
None
,
bootstrap_iters
=
100000
,
bootstrap_iters
:
int
=
100000
,
decontamination_ngrams_path
=
None
,
write_out
=
False
,
log_samples
=
True
,
write_out
:
bool
=
False
,
log_samples
:
bool
=
True
,
):
"""Instantiate and evaluate a model on a list of tasks.
...
...
@@ -223,7 +222,6 @@ def evaluate(
# get lists of each type of request
for
task_name
,
task
in
task_dict
.
items
():
if
type
(
task
)
==
tuple
:
group
,
task
=
task
task_groups
[
task_name
]
=
group
...
...
@@ -349,7 +347,6 @@ def evaluate(
# if multigpu, then gather data across all ranks
# first gather logged samples across all ranks
for
task_name
,
task_samples
in
list
(
samples
.
items
()):
full_samples
=
[
None
]
*
lm
.
world_size
torch
.
distributed
.
all_gather_object
(
full_samples
,
task_samples
)
...
...
@@ -358,7 +355,6 @@ def evaluate(
# then collect metrics across all ranks
vals_torch
=
collections
.
defaultdict
(
list
)
for
(
task_name
,
key
,
metric
),
items
in
vals
.
items
():
numitem
=
0
if
type
(
items
[
0
])
==
tuple
:
numitem
=
len
(
items
[
0
])
...
...
lm_eval/filters/decontamination.py
View file @
8f448eed
...
...
@@ -9,7 +9,7 @@ class DecontaminationFilter(Filter):
name
=
"track_decontamination"
def
__init__
(
self
,
path
):
def
__init__
(
self
,
path
)
->
None
:
"""
TODO: make sure only ever run one time on the train set (should this be cached as a class var? keyed by value for "path").
...
...
@@ -17,7 +17,7 @@ class DecontaminationFilter(Filter):
"""
self
.
_decontam_results
=
None
def
apply
(
self
,
reps
,
docs
):
def
apply
(
self
,
re
s
ps
,
docs
)
->
None
:
"""
Return {"no_contamination", "only_contamination"} keys for the 2 different subsets
"""
...
...
lm_eval/filters/extraction.py
View file @
8f448eed
...
...
@@ -6,7 +6,9 @@ from lm_eval.api.filter import Filter
class
RegexFilter
(
Filter
):
""" """
def
__init__
(
self
,
regex_pattern
=
r
"#### (\-?[0-9\.\,]+)"
,
fallback
=
"[invalid]"
):
def
__init__
(
self
,
regex_pattern
:
str
=
r
"#### (\-?[0-9\.\,]+)"
,
fallback
:
str
=
"[invalid]"
)
->
None
:
"""
pass a string `regex` to run `re.compile(r"regex")` on.
`fallback` defines the output returned if no matches for the regex are located.
...
...
@@ -41,12 +43,11 @@ class RegexFilter(Filter):
class
WhitespaceFilter
(
Filter
):
""" """
def
__init__
(
self
):
def
__init__
(
self
)
->
None
:
pass
def
apply
(
self
,
resps
,
docs
):
def
filter_set
(
inst
):
filtered_resp
=
[]
for
resp
in
inst
:
if
resp
.
startswith
(
" "
):
...
...
lm_eval/filters/selection.py
View file @
8f448eed
...
...
@@ -4,7 +4,7 @@ from lm_eval.api.filter import Filter
class
TakeFirstFilter
(
Filter
):
def
__init__
(
self
):
def
__init__
(
self
)
->
None
:
"""
Can define custom behavior here, if an individual instantiation of a Filter class should have state.
"""
...
...
@@ -17,8 +17,7 @@ class TakeFirstFilter(Filter):
class
TakeKFilter
(
Filter
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
**
kwargs
)
->
None
:
self
.
k
=
kwargs
.
pop
(
"k"
)
super
().
__init__
(
*
args
,
**
kwargs
)
...
...
@@ -32,7 +31,7 @@ class TakeKFilter(Filter):
class
MajorityVoteFilter
(
Filter
):
def
__init__
(
self
):
def
__init__
(
self
)
->
None
:
"""
Can define custom behavior here, if an individual instantiation of a Filter class should have state.
"""
...
...
lm_eval/models/anthropic_llms.py
View file @
8f448eed
...
...
@@ -76,7 +76,7 @@ class AnthropicLM(LM):
max_tokens_to_sample
:
int
=
256
,
temperature
:
float
=
0
,
# defaults to 1
**
kwargs
,
# top_p, top_k, etc.
):
)
->
None
:
"""Anthropic API wrapper.
:param model: str
...
...
@@ -135,11 +135,10 @@ please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e
def
tok_decode
(
self
,
tokens
:
List
[
int
])
->
str
:
return
self
.
tokenizer
.
decode
(
tokens
)
def
_loglikelihood_tokens
(
self
,
requests
,
disable_tqdm
=
False
):
def
_loglikelihood_tokens
(
self
,
requests
,
disable_tqdm
:
bool
=
False
):
raise
NotImplementedError
(
"No support for logits."
)
def
greedy_until
(
self
,
requests
)
->
List
[
str
]:
if
not
requests
:
return
[]
...
...
lm_eval/models/dummy.py
View file @
8f448eed
...
...
@@ -5,7 +5,7 @@ from lm_eval.api.registry import register_model
@
register_model
(
"dummy"
)
class
DummyLM
(
LM
):
def
__init__
(
self
):
def
__init__
(
self
)
->
None
:
super
().
__init__
()
@
classmethod
...
...
lm_eval/models/huggingface.py
View file @
8f448eed
...
...
@@ -94,7 +94,7 @@ class HFLM(LM):
bnb_4bit_compute_dtype
:
Optional
[
Union
[
str
,
torch
.
dtype
]]
=
None
,
gptq
:
Optional
[
Union
[
bool
,
str
]]
=
False
,
gptq_use_triton
:
Optional
[
bool
]
=
False
,
):
)
->
None
:
super
().
__init__
()
assert
isinstance
(
device
,
str
)
...
...
@@ -347,7 +347,7 @@ class HFLM(LM):
return
self
.
_DEFAULT_MAX_LENGTH
@
property
def
max_gen_toks
(
self
):
def
max_gen_toks
(
self
)
->
int
:
return
256
@
property
...
...
@@ -366,7 +366,7 @@ class HFLM(LM):
def
world_size
(
self
):
return
self
.
_world_size
def
_detect_batch_size
(
self
,
requests
=
None
,
pos
=
0
):
def
_detect_batch_size
(
self
,
requests
=
None
,
pos
:
int
=
0
):
if
requests
:
_
,
context_enc
,
continuation_enc
=
requests
[
pos
]
max_length
=
len
(
...
...
@@ -432,11 +432,11 @@ class HFLM(LM):
return
encoding
def
tok_batch_encode
(
self
,
strings
:
List
[
str
],
padding_side
=
"left"
,
left_truncate_len
=
None
,
truncation
=
False
,
self
,
strings
:
List
[
str
],
padding_side
:
str
=
"left"
,
left_truncate_len
:
int
=
None
,
truncation
:
bool
=
False
,
):
# encode a batch of strings. converts to tensors and pads automatically, unlike tok_encode.
old_padding_side
=
self
.
tokenizer
.
padding_side
...
...
@@ -613,7 +613,9 @@ class HFLM(LM):
return
loglikelihoods
def
_loglikelihood_tokens
(
self
,
requests
,
disable_tqdm
=
False
,
override_bs
=
None
):
def
_loglikelihood_tokens
(
self
,
requests
,
disable_tqdm
:
bool
=
False
,
override_bs
=
None
):
# TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
res
=
[]
...
...
lm_eval/models/openai_completions.py
View file @
8f448eed
...
...
@@ -69,7 +69,7 @@ class OpenaiCompletionsLM(LM):
engine
:
str
=
"text-davinci-003"
,
truncate
:
bool
=
False
,
batch_size
:
int
=
1
,
):
)
->
None
:
"""
:param engine: str
...
...
@@ -99,12 +99,12 @@ class OpenaiCompletionsLM(LM):
return
self
.
end_of_text_token_id
@
property
def
max_length
(
self
):
def
max_length
(
self
)
->
int
:
# Note: the OpenAI API supports up to 2049 tokens, with the first token being the first input token
return
2048
@
property
def
max_gen_toks
(
self
):
def
max_gen_toks
(
self
)
->
int
:
return
256
@
property
...
...
@@ -152,7 +152,7 @@ class OpenaiCompletionsLM(LM):
return
self
.
_loglikelihood_tokens
(
new_reqs
)
def
_loglikelihood_tokens
(
self
,
requests
,
disable_tqdm
=
False
self
,
requests
,
disable_tqdm
:
bool
=
False
)
->
List
[
Tuple
[
float
,
bool
]]:
res
=
[]
...
...
lm_eval/models/textsynth.py
View file @
8f448eed
...
...
@@ -41,7 +41,7 @@ def textsynth_completion(**kwargs):
@
register_model
(
"textsynth"
)
class
TextSynthLM
(
LM
):
def
__init__
(
self
,
engine
,
truncate
=
False
)
:
def
__init__
(
self
,
engine
,
truncate
:
bool
=
False
)
->
None
:
"""
:param engine: str
TextSynth API engine (e.g. `gptj_6B`)
...
...
@@ -62,12 +62,12 @@ class TextSynthLM(LM):
raise
NotImplementedError
()
@
property
def
max_length
(
self
):
def
max_length
(
self
)
->
int
:
# NOTE: Turn on truncation to avoid errors on long inputs.
return
2048
@
property
def
max_gen_toks
(
self
):
def
max_gen_toks
(
self
)
->
int
:
return
256
@
property
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment