Unverified commit 1d8107bf, authored by Stella Biderman, committed by GitHub

Merge pull request #362 from EleutherAI/cleanup-for-release

Cleanup `README.md` and package deps

Parents: fdd3dbc3 1e5d55d9
@@ -32,7 +32,9 @@ jobs:
       run: |
         python -m pip install --upgrade pip
         pip install flake8 pytest pytest-cov
-        pip install -e .[dev]
+        pip install -e .[dev,multilingual]
+        # Install optional git dependencies
+        pip install bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt
         if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
     - name: Lint with flake8
       run: |
This diff is collapsed.
This diff is collapsed.
@@ -16,6 +16,20 @@ from lm_eval import metrics
 from lm_eval.base import Task, rf
 from typing import List

+try:
+    import nagisa
+
+    HAS_NAGISA = True
+except ImportError:
+    HAS_NAGISA = False
+
+try:
+    import jieba
+
+    HAS_JIEBA = True
+except ImportError:
+    HAS_JIEBA = False
+
 _CITATION = """
 @inproceedings{post-2018-call,

@@ -63,14 +77,22 @@ def create_tasks_from_benchmarks(benchmark_dict):
 def zh_split(zh_text: List[str]) -> List[str]:
     """Chinese splitting"""
-    import jieba
+    if not HAS_JIEBA:
+        raise ImportError(
+            "Chinese text splitting requires the `jieba` package. "
+            "Please install it with:\npip install jieba"
+        )

     return [" ".join(jieba.cut(txt.strip())) for txt in zh_text]


 def ja_split(ja_text: List[str]) -> List[str]:
     """Japanese splitting"""
-    import nagisa
+    if not HAS_NAGISA:
+        raise ImportError(
+            "Japanese text splitting requires the `nagisa` package. "
+            "Please install it with:\npip install nagisa"
+        )

     return [" ".join(nagisa.tagging(txt.strip()).words) for txt in ja_text]
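Note: the guarded imports above make `jieba` and `nagisa` optional at import time; a missing package only matters once Chinese or Japanese splitting is actually requested. A minimal standalone sketch of the same pattern (the `segment` helper and example string are illustrative, not from this diff):

```python
from typing import List

try:
    import jieba  # optional; provided by the `multilingual` extra

    HAS_JIEBA = True
except ImportError:
    HAS_JIEBA = False


def segment(zh_text: List[str]) -> List[str]:
    """Whitespace-join jieba tokens so word-level BLEU can be computed."""
    if not HAS_JIEBA:
        # Importing the module never fails; only using the feature does.
        raise ImportError("Chinese splitting needs jieba: pip install jieba")
    return [" ".join(jieba.cut(txt.strip())) for txt in zh_text]


if __name__ == "__main__":
    if HAS_JIEBA:
        print(segment(["今天天气很好。"]))
    else:
        print("jieba not installed; segment() would raise ImportError")
```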
@@ -27,6 +27,14 @@ from lm_eval.base import rf, Task
 from lm_eval.metrics import mean

+try:
+    import bleurt
+
+    HAS_BLEURT = True
+except ImportError:
+    HAS_BLEURT = False
+
 _CITATION = """
 @misc{lin2021truthfulqa,
     title={TruthfulQA: Measuring How Models Mimic Human Falsehoods},

@@ -164,6 +172,12 @@ class TruthfulQAGeneration(Task):
     def __init__(self):
         super().__init__()
+        if not HAS_BLEURT:
+            raise ImportError(
+                "`TruthfulQAGeneration` requires the `bleurt` package. Please install it with:\n"
+                "pip install bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt"
+                "\nWARNING: Installing any other version of bleurt may result in different results."
+            )
         self.bleurt = datasets.load_metric("bleurt")

     def has_training_docs(self):
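Note: once the pinned `bleurt` package is installed, `datasets.load_metric("bleurt")` pulls a default BLEURT checkpoint and can score prediction/reference pairs. A rough sketch of that usage (not the harness's own scoring path; the example sentences are made up):

```python
import datasets

# Requires the pinned bleurt install from the error message above; the first
# call also downloads a default BLEURT checkpoint.
bleurt = datasets.load_metric("bleurt")

result = bleurt.compute(
    predictions=["Paris is the capital of France."],
    references=["The capital of France is Paris."],
)
print(result["scores"])  # one similarity score per prediction/reference pair
```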
...@@ -5,7 +5,6 @@ import collections ...@@ -5,7 +5,6 @@ import collections
import functools import functools
import inspect import inspect
import sys import sys
import pytest
from typing import List from typing import List
...@@ -187,6 +186,8 @@ def run_task_tests(task_list: List[str]): ...@@ -187,6 +186,8 @@ def run_task_tests(task_list: List[str]):
""" """
Find the package root and run the tests for the given tasks Find the package root and run the tests for the given tasks
""" """
import pytest
package_root = find_test_root(start_path=pathlib.Path(__file__)) package_root = find_test_root(start_path=pathlib.Path(__file__))
task_string = " or ".join(task_list) task_string = " or ".join(task_list)
args = [ args = [
......
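Note: moving `import pytest` inside `run_task_tests` means importing the library no longer requires a dev-only dependency; pytest is resolved lazily, only when task tests are actually run. A small sketch of the same deferred-import pattern (the function and arguments here are illustrative, not the harness's real code):

```python
def run_selected_tests(task_names):
    # Deferred import: users who never call this function don't need pytest installed.
    import pytest

    # Build a `-k` expression like "lambada or hellaswag" and let pytest collect
    # the matching tests; pytest.main returns a shell-style exit code.
    expression = " or ".join(task_names)
    return pytest.main(["-x", "-k", expression])


if __name__ == "__main__":
    raise SystemExit(run_selected_tests(["lambada", "hellaswag"]))
```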
"""
Usage:
python make_table_tasks.py --output <markdown_filename>
"""
import argparse
import logging
from lm_eval import tasks from lm_eval import tasks
from pytablewriter import MarkdownTableWriter from pytablewriter import MarkdownTableWriter
writer = MarkdownTableWriter()
writer.headers = ["Task Name", "Train", "Val", "Test", "Val/Test Docs", "Metrics"]
values = [] logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def chk(tf): def check(tf):
if tf: if tf:
return "✓" return "✓"
else: else:
return " " return " "
for tname, Task in tasks.TASK_REGISTRY.items(): if __name__ == "__main__":
task = Task() parser = argparse.ArgumentParser()
parser.add_argument("--output", type=str, default="task_table.md")
v = [ args = parser.parse_args()
tname,
chk(task.has_training_docs()), writer = MarkdownTableWriter()
chk(task.has_validation_docs()), writer.headers = ["Task Name", "Train", "Val", "Test", "Val/Test Docs", "Metrics"]
chk(task.has_test_docs()), values = []
len(list(task.test_docs() if task.has_test_docs() else task.validation_docs())),
", ".join(task.aggregation().keys()), tasks = tasks.TASK_REGISTRY.items()
] tasks = sorted(tasks, key=lambda x: x[0])
print(v) for tname, Task in tasks:
values.append(v) task = Task()
v = [
writer.value_matrix = values tname,
check(task.has_training_docs()),
print(writer.dumps()) check(task.has_validation_docs()),
check(task.has_test_docs()),
len(
list(
task.test_docs() if task.has_test_docs() else task.validation_docs()
)
),
", ".join(task.aggregation().keys()),
]
logger.info(v)
values.append(v)
writer.value_matrix = values
table = writer.dumps()
with open(args.output, "w") as f:
f.write(table)
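Note: the rewritten script sorts the task registry, logs each row, and writes the table to the `--output` path instead of printing to stdout. For reference, a tiny standalone sketch of how `pytablewriter` renders such a table (the row values below are invented):

```python
from pytablewriter import MarkdownTableWriter

writer = MarkdownTableWriter()
writer.headers = ["Task Name", "Train", "Val", "Test", "Val/Test Docs", "Metrics"]
writer.value_matrix = [["example_task", "✓", "✓", " ", 1000, "acc, acc_norm"]]

# dumps() returns the Markdown source for the table, which the script then
# writes to the file given via --output.
with open("task_table.md", "w") as f:
    f.write(writer.dumps())
```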
@@ -14,6 +14,7 @@ setuptools.setup(
     url="https://github.com/EleutherAI/lm-evaluation-harness",
     packages=setuptools.find_packages(),
     classifiers=[
+        "Development Status :: 3 - Alpha",
         "Programming Language :: Python :: 3",
         "License :: OSI Approved :: MIT License",
         "Operating System :: OS Independent",

@@ -21,29 +22,23 @@ setuptools.setup(
     python_requires=">=3.6",
     install_requires=[
         "datasets>=2.0.0",
-        "click>=7.1",
+        "jsonlines",
+        "numexpr",
+        "openai>=0.6.4",
+        "pybind11>=2.6.2",
+        "pycountry",
+        "pytablewriter",
+        "rouge-score>=0.0.4",
+        "sacrebleu==1.5.0",
         "scikit-learn>=0.24.1",
+        "sqlitedict",
         "torch>=1.7",
+        "tqdm-multiprocess",
         "transformers>=4.1",
-        "sqlitedict==1.6.0",
-        "pytablewriter==0.58.0",
-        "sacrebleu==1.5.0",
-        "rouge-score==0.0.4",
-        "pycountry==20.7.3",
-        "numexpr>=2.7.2",
-        "lm_dataformat==0.0.20",
-        "pybind11==2.6.2",
-        "tqdm-multiprocess==0.0.11",
-        "zstandard==0.15.2",
-        "jsonlines==2.0.0",
-        "mock==4.0.3",
-        "openai==0.6.4",
-        "jieba==0.42.1",
-        "nagisa==0.2.7",
-        "bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt",
-    ],
-    dependency_links=[
-        "https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt",
+        "zstandard",
     ],
-    extras_require={"dev": ["pytest", "black", "pre-commit"]},
+    extras_require={
+        "dev": ["black", "flake8", "pre-commit", "pytest", "pytest-cov"],
+        "multilingual": ["nagisa>=0.2.7", "jieba>=0.42.1"],
+    },
 )
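Note: the dependency cleanup relaxes most version pins and moves the CJK tokenizer packages behind a `multilingual` extra. A condensed sketch of how such an `extras_require` split behaves (hypothetical package name; `pip install -e .` pulls only `install_requires`, while `pip install -e ".[dev,multilingual]"` adds the extras):

```python
import setuptools

setuptools.setup(
    name="example-harness",  # hypothetical; illustrates the extras mechanism only
    version="0.0.1",
    packages=setuptools.find_packages(),
    install_requires=["datasets>=2.0.0"],  # always installed
    extras_require={
        # installed only when the [dev] extra is requested
        "dev": ["pytest", "black"],
        # installed only when the [multilingual] extra is requested
        "multilingual": ["jieba>=0.42.1", "nagisa>=0.2.7"],
    },
)
```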
@@ -258,8 +258,9 @@ def textsynth_mock_completion(**kwargs):
     import requests

     os.makedirs("tests/testdata", exist_ok=True)
+    hash_kwargs = {k: v for k, v in kwargs.items() if k != "headers"}
     hash = hashlib.sha256(
-        json.dumps(kwargs, sort_keys=True).encode("utf-8")
+        json.dumps(hash_kwargs, sort_keys=True).encode("utf-8")
     ).hexdigest()
     fname = f"tests/testdata/textsynth_test_{hash}.pkl"
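Note: dropping `headers` before hashing keeps the cached fixture filename stable regardless of which API credentials are in the request. A small self-contained sketch of that effect (helper name and kwargs are illustrative):

```python
import hashlib
import json


def cache_name(kwargs: dict) -> str:
    # Same idea as the change above: exclude the volatile/secret "headers"
    # field so the cache key depends only on the request payload.
    hash_kwargs = {k: v for k, v in kwargs.items() if k != "headers"}
    digest = hashlib.sha256(
        json.dumps(hash_kwargs, sort_keys=True).encode("utf-8")
    ).hexdigest()
    return f"textsynth_test_{digest}.pkl"


a = cache_name({"prompt": "2+2=", "headers": {"Authorization": "Bearer key-A"}})
b = cache_name({"prompt": "2+2=", "headers": {"Authorization": "Bearer key-B"}})
assert a == b  # different API keys, same cached fixture
```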
@@ -7,10 +7,7 @@ from itertools import islice

 @pytest.mark.parametrize("taskname,task_class", tasks.TASK_REGISTRY.items())
 def test_basic_interface(taskname, task_class):
     print("Evaluating task", taskname)
-    # dl = task_class.download
-    # task_class.download = MagicMock()
     task = task_class()
-    # task_class.download = dl

     assert task.has_training_docs() in [True, False]
     assert task.has_validation_docs() in [True, False]
@@ -51,7 +51,7 @@ def flatten(d, parent_key="", sep="."):
     items = []
     for k, v in d.items():
         new_key = parent_key + sep + k if parent_key else k
-        if isinstance(v, collections.MutableMapping):
+        if isinstance(v, collections.abc.MutableMapping):
             items.extend(flatten(v, new_key, sep=sep).items())
         else:
             items.append((new_key, v))
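Note: `collections.MutableMapping` was only an alias for `collections.abc.MutableMapping`, deprecated since Python 3.3 and removed in Python 3.10, so the qualified form is needed on newer interpreters. A quick sketch of the flattening behaviour with the fixed check (example data invented):

```python
import collections.abc


def flatten(d, parent_key="", sep="."):
    """Flatten nested mappings into a single dict with dotted keys."""
    items = []
    for k, v in d.items():
        new_key = parent_key + sep + k if parent_key else k
        if isinstance(v, collections.abc.MutableMapping):
            items.extend(flatten(v, new_key, sep=sep).items())
        else:
            items.append((new_key, v))
    return dict(items)


print(flatten({"results": {"lambada": {"ppl": 3.2, "acc": 0.68}}}))
# {'results.lambada.ppl': 3.2, 'results.lambada.acc': 0.68}
```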