Unverified commit 65b8761d authored by Baber Abbasi, committed by GitHub

Switch Linting to `ruff` (#1166)

* add ruff and isort. remove black and flake8

* remove unnecessary dependencies

* remove dependency from table

* change order

* ran ruff

* check 3.9

* exclude evaluator

* update CI workflow

* use ruff config in pyproject.toml

* test

* add isort rules to ruff

* sort imports

* import `make_table`

* try stages for no-commit-to-branch

* turn on mypy for pre-commit

* test

* test

* test

* change no-commit-to-branch to default

* nits

* fixed dependency
parent 21d4ae98
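
Most of the hunks below are mechanical: the isort rules enabled through ruff (`extend-select = ["I"]` in pyproject.toml) regroup and alphabetize imports, and ruff's default Pyflakes checks flag unused imports (F401) and f-strings with no placeholders (F541). A minimal sketch of the import layout those rules produce — the module names are illustrative and assume torch, PyYAML, and lm_eval itself are installed:

```python
# Illustrative sketch only, not a file from this commit: the import layout
# ruff's isort rules produce, assuming torch, PyYAML and lm_eval are installed.
import collections  # 1) standard library, alphabetized
import os

import torch  # 2) third-party packages
import yaml

from lm_eval import utils  # 3) first-party, per known-first-party = ["lm_eval"]


# lines-after-imports = 2 leaves two blank lines before the first definition.
def report() -> None:
    print(torch.__version__, yaml.__version__, utils.__name__)
    print(os.getcwd(), collections.Counter("ruff"))


if __name__ == "__main__":
    report()
```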
 import yaml
-import inspect
 import datasets
 from tqdm import tqdm

 def main() -> None:
     dataset_path = "EleutherAI/advanced_ai_risk"
     for task in tqdm(datasets.get_dataset_infos(dataset_path).keys()):
         file_name = f"{task}.yaml"
...
 import yaml
-import inspect
 import datasets
 from tqdm import tqdm

 def main() -> None:
     dataset_path = "EleutherAI/persona"
     for task in tqdm(datasets.get_dataset_infos(dataset_path).keys()):
         file_name = f"{task}.yaml"
...
 import argparse
-from typing import Dict, List
 import yaml
...
 def doc_to_text(doc) -> str:
     ctxs = "\n".join(doc["CONTEXTS"])
     return "Abstract: {}\nQuestion: {}\nAnswer:".format(
-        ctxs, doc["QUESTION"], doc["final_decision"]
+        ctxs,
+        doc["QUESTION"],
     )

@@ -3,7 +3,6 @@ from functools import partial
 def process_docs(dataset, set_answer_type="bool"):
     FEATURES = ["title", "abstract", "question", "answer", "answer_type"]

     def _categorise_answer(answer_blob):
...
@@ -235,7 +235,6 @@ class _SCROLLSMultipleChoiceTask(_SCROLLSTask):
         }

     def construct_requests(self, doc, ctx, **kwargs):
         request_list = [
             Instance(
                 request_type="loglikelihood",
...
@@ -14,7 +14,6 @@ also determine when no answer is supported by the paragraph and abstain from ans
 Homepage: https://rajpurkar.github.io/SQuAD-explorer/
 """
 import datasets
-from evaluate import load
 from math import exp
 from functools import partial

@@ -120,14 +119,14 @@ class SQuAD2(Task):
                 doc=doc,
                 arguments=(ctx, {"until": ["\n"]}),
                 idx=0,
-                **kwargs
+                **kwargs,
             ),
             Instance(
                 request_type="loglikelihood",
                 doc=doc,
                 arguments=(ctx, " " + "unanswerable"),
                 idx=0,
-                **kwargs
+                **kwargs,
             ),
         ]
...
@@ -2,7 +2,6 @@ import sklearn.metrics
 def mean_3class_f1(predictions, references):  # This is a passthrough function
     string_label = ["entailment", "contradiction", "neutral"]
     predictions = (
         string_label.index(predictions[0]) if predictions[0] in string_label else 0

@@ -13,7 +12,6 @@ def mean_3class_f1(predictions, references):  # This is a passthrough function
 def agg_mean_3class_f1(items):
     predictions, references = zip(*items)
     """Computes the unweighted average of the F1 per class."""
...
@@ -5,7 +5,6 @@ import sklearn.metrics
 def f1(predictions, references):  # This is a passthrough function
     _prediction = predictions[0]
     _reference = references[0].split("_")[-1]
     string_label = ["False", "True"]

@@ -20,7 +19,6 @@ def f1(predictions, references):  # This is a passthrough function
 def agg_f1(items):
     predictions, references = zip(*items)
     references, predictions = np.asarray(references), np.asarray(predictions)

@@ -28,7 +26,6 @@ def agg_f1(items):
 def em(predictions, references):  # This is a passthrough function
     _prediction = predictions[0]
     _group, _reference = references[0].split("_")
     string_label = ["False", "True"]
...
@@ -3,14 +3,12 @@ import string
 import collections
 import numpy as np
-from tqdm import tqdm
-from datasets import Dataset, concatenate_datasets
+from datasets import Dataset
 from lm_eval.api.metrics import metric_max_over_ground_truths

 def doc_to_text(doc):
     passage = doc["passage"]
     passage = re.sub(r"(\.|\?|\!|\"|\')\n@highlight\n", r"\1 ", passage)
     passage = re.sub(r"\n@highlight\n", ". ", passage)

@@ -34,7 +32,6 @@ def process_docs(dataset):
         }
         answers = doc.pop("answers")
         for idx, answer in enumerate(answers):
             for key in split_doc.keys():
                 if key in doc:
                     split_doc[key].append(doc[key])
...
@@ -8,7 +8,6 @@ def doc_to_text(x):
 def _wsc_inputs(x):
     words = x["text"].split(" ")

     # We would need some special logic to handle the case where the pronoun is the

@@ -55,7 +54,6 @@ def _wsc_inputs(x):
 class WSCPostprocess(Filter):
     def __init__(self, **kwargs):
         self.determiners = {
             "a",
             "an",

@@ -86,10 +84,8 @@ class WSCPostprocess(Filter):
         return " ".join([w for w in s.split(" ") if w not in self.determiners])

     def apply(self, resps, docs):
         filtered_resps = []
         for prediction, reference in zip(*(resps, docs["span1_text"])):
             prediction = self.clean(prediction[0])
             reference = self.clean(reference)
...
 import argparse
-from typing import Dict, List
 import yaml
-import sacrebleu

 try:
     import pycountry
...
@@ -6,7 +6,6 @@ from rouge_score import rouge_scorer, scoring
 def process_results_mc2(doc, results):
     lls, is_greedy = zip(*results)

     # Split on the first `0` as everything before it is true (`1`).

@@ -20,7 +19,6 @@ def process_results_mc2(doc, results):
 def process_docs_gen(dataset: datasets.Dataset) -> datasets.Dataset:
     return dataset.map(preprocess_function)

@@ -49,7 +47,6 @@ def preprocess_function(examples):
 def process_results_gen(doc, results):
     completion = results[0]
     true_refs, false_refs = doc["correct_answers"], doc["incorrect_answers"]
     all_refs = true_refs + false_refs
...
 import argparse
-from typing import Dict, List
 import yaml
...
-import os
-import re
-import sys
-import yaml
-import inspect
-import pathlib
-import functools
-import subprocess
-import collections
-import importlib.util
-import fnmatch
-from typing import Iterator, List, Literal, Union, Any, Callable
-import gc
+import collections
+import fnmatch
+import functools
+import gc
+import importlib.util
+import inspect
+import logging
+import os
+import pathlib
+import re
+import subprocess
+import sys
+from itertools import islice
+from typing import Any, Callable, Iterator, List, Literal, Union
+
 import torch
 import transformers
+import yaml
 from jinja2 import BaseLoader, Environment, StrictUndefined
-from itertools import islice
-import logging

 logging.basicConfig(
     format="%(asctime)s,%(msecs)03d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",

@@ -143,7 +141,7 @@ class MultiChoice:
     def __contains__(self, values) -> bool:
         for value in values.split(","):
             if len(fnmatch.filter(self.choices, value)) == 0:
-                eval_logger.info(f"Available tasks to choose:")
+                eval_logger.info("Available tasks to choose:")
                 for choice in self.choices:
                     eval_logger.info(f" - {choice}")
                 raise ValueError("'{}' is not in task list".format(value))

@@ -157,7 +155,7 @@ class MultiChoice:
 # Returns a list containing all values of the source_list that
 # match at least one of the patterns
 def pattern_match(patterns, source_list):
-    if type(patterns) == str:
+    if isinstance(patterns, str):
         patterns = [patterns]
     task_names = set()

@@ -332,7 +330,7 @@ class Grouper:
 def make_table(result_dict, column: str = "results"):
     """Generate table of results."""
-    from pytablewriter import MarkdownTableWriter, LatexTableWriter
+    from pytablewriter import LatexTableWriter, MarkdownTableWriter
     if column == "results":
         column_name = "Tasks"

@@ -466,7 +464,7 @@ def import_function(loader, node):
     yaml_path = os.path.dirname(loader.name)
     *module_name, function_name = function_name.split(".")
-    if type(module_name) == list:
+    if isinstance(module_name, list):
         module_name = ".".join(module_name)
     module_path = os.path.normpath(os.path.join(yaml_path, "{}.py".format(module_name)))

@@ -496,7 +494,7 @@ def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None):
         include_path = yaml_config["include"]
         del yaml_config["include"]
-        if type(include_path) == str:
+        if isinstance(include_path, str):
             include_path = [include_path]
         # Load from the last one first
...
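
The `type(...) ==` comparisons rewritten to `isinstance(...)` above follow ruff's E721 rule, the same code the new per-file ignores exempt for `lm_eval/tasks/*`. A small, self-contained illustration of why the two are not interchangeable; the class name is made up for the example:

```python
class Answer(str):
    """A str subclass standing in for any subclass a task might return."""


value = Answer("yes")

# Exact type comparison ignores inheritance, so the subclass fails the check.
print(type(value) == str)      # False

# isinstance accepts the class and its subclasses, which is usually the intent.
print(isinstance(value, str))  # True
```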
@@ -9,21 +9,19 @@ warn_unused_ignores = True
 warn_redundant_casts = True

 # We ignore errors everywhere to gradually add type annotations
-[mypy-lm_eval.*]
-ignore_errors = True
-
-[mypy-lm_eval.api.*]
-ignore_errors = True
-
-[mypy-lm_eval.prompts.*]
-ignore_errors = True
-
-[mypy-lm_eval.models.*]
-ignore_errors = True
-
-[mypy-scripts.*]
-ignore_errors = True
-
-[mypy-main]
-ignore_errors = True
+# [mypy-lm_eval.*]
+# ignore_errors = True
+#
+# [mypy-lm_eval.api.*]
+# ignore_errors = True
+#
+# [mypy-lm_eval.prompts.*]
+# ignore_errors = True
+#
+# [mypy-lm_eval.models.*]
+# ignore_errors = True
+#
+# [mypy-scripts.*]
+# ignore_errors = True
+#
+# [mypy-main]
...
@@ -54,14 +54,7 @@ Homepage = "https://github.com/EleutherAI/lm-evaluation-harness"
 Repository = "https://github.com/EleutherAI/lm-evaluation-harness"

 [project.optional-dependencies]
-dev = ["black", "flake8", "pre-commit", "pytest", "pytest-cov"]
-linting = [
-    "flake8",
-    "pylint",
-    "mypy",
-    "pre-commit",
-]
-testing = ["pytest", "pytest-cov", "pytest-xdist"]
+dev = ["pytest", "pytest-cov", "pytest-xdist", "pre-commit", "mypy"]
 multilingual = ["nagisa>=0.2.7", "jieba>=0.42.1", "pycountry"]
 math = ["sympy>=1.12", "antlr4-python3-runtime==4.11"]
 sentencepiece = ["sentencepiece>=0.1.98", "protobuf>=4.22.1"]

@@ -88,3 +81,17 @@ all = [
     "lm_eval[ifeval]",
     "lm_eval[zeno]",
 ]
+
+[tool.ruff]
+extend-exclude = ["lm_eval/evaluator.py", "lm_eval/tasks/*.py"]
+
+[tool.ruff.lint]
+extend-select = ["I"]
+
+[tool.ruff.isort]
+lines-after-imports = 2
+known-first-party = ["lm_eval"]
+
+[tool.ruff.extend-per-file-ignores]
+"__init__.py" = ["F401","F402","F403","I"]
+"lm_eval/tasks/*"= ["E721"]
...
-import os
-import yaml
 import argparse
-from tqdm import tqdm
+import os
+import yaml
 from promptsource.templates import DatasetTemplates
+from tqdm import tqdm

 from lm_eval import utils

 # from lm_eval.api.registry import ALL_TASKS
 from lm_eval.logger import eval_logger

 # from lm_eval.tasks import include_task_folder

@@ -22,7 +21,6 @@ def parse_args():
 if __name__ == "__main__":
     args = parse_args()

     with open(args.benchmark_path) as file:
...
-import glob
 import argparse
+import glob
+import logging
 import os
-import subprocess
 import shutil
+import subprocess

 from tqdm import tqdm
 from tqdm_multiprocess import TqdmMultiProcessPool
-import logging
 from tqdm_multiprocess.logger import setup_logger_tqdm

 logger = logging.getLogger(__name__)

@@ -35,7 +35,7 @@ def compress_and_move(working_directory, output_directory, process_count):
     tasks = []
     bucket_file_paths = glob.glob(
-        os.path.join(working_directory, "output", f"*.bkt.txt.sorted")
+        os.path.join(working_directory, "output", "*.bkt.txt.sorted")
     )
     for bucket_file_path in bucket_file_paths:
         task = (process_task, (working_directory, output_directory, bucket_file_path))
...
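
The `f"*.bkt.txt.sorted"` to `"*.bkt.txt.sorted"` rewrite above, and the similar ones in the next file, come from ruff's F541 check: an f-string with no placeholders is just a plain string with a misleading prefix. A tiny illustration, with file names used only as examples:

```python
# An f-string with no {placeholders} interpolates nothing; F541 flags it.
done_marker = f"ngram_buckets.done"   # flagged by ruff (F541)
done_marker = "ngram_buckets.done"    # equivalent plain string, as in the diff

# The prefix is kept where something really is interpolated:
bucket_id = 7
bucket_name = f"ngrams_{bucket_id}.bkt.txt"
print(done_marker, bucket_name)
```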
@@ -21,22 +21,22 @@ Arguments
 """
 import argparse
+import glob
 import json
-import pickle
+import logging
 import os
+import pickle
+import signal
 import sys
 from pathlib import Path
-import glob
-import signal
 from signal import SIGINT

 from tqdm import tqdm
+from tqdm_multiprocess.logger import setup_logger_tqdm
+
+from lm_eval.decontamination.archiver import Reader, TextArchive
 from lm_eval.decontamination.janitor import Janitor, word_ngrams
-from lm_eval.decontamination.archiver import TextArchive, Reader
-import logging
-from tqdm_multiprocess.logger import setup_logger_tqdm

 logger = logging.getLogger(__name__)

@@ -89,7 +89,7 @@ class Buckets:
             os.path.join(directory, f"ngrams_{i}.bkt.txt") for i in range(num_buckets)
         ]
         self.buckets = list(map(TextArchive, self.bucket_files))
-        self.checkpoint_file = os.path.join(directory, f"bucket_offsets.ckpt")
+        self.checkpoint_file = os.path.join(directory, "bucket_offsets.ckpt")

         if os.path.exists(self.checkpoint_file):
             self.bucket_offsets = pickle.load(open(self.checkpoint_file, "rb"))

@@ -119,7 +119,6 @@ class Buckets:
 def do_ngrams_in_buckets(n_value, working_directory, bucket_count):
     pile_statistics = json.load(open("pile_statistics.json", "r"))
     pile_document_count = pile_statistics["Document Count"]
     start_offsets = pile_statistics["File Start Offsets"]

@@ -130,13 +129,13 @@ def do_ngrams_in_buckets(n_value, working_directory, bucket_count):
     logger.info(f"Generating {n_value}-grams and bucketing.")

     # Done file
-    done_file = os.path.join(output_directory, f"ngram_buckets.done")
+    done_file = os.path.join(output_directory, "ngram_buckets.done")
     if os.path.exists(done_file):
         logger.info("ngrams already generated and bucketed, skipping")
         return

     # Checkpoint
-    checkpoint_file = os.path.join(working_directory, f"pile_offset.ckpt")
+    checkpoint_file = os.path.join(working_directory, "pile_offset.ckpt")
     if os.path.exists(checkpoint_file):
         checkpoint_offset = pickle.load(open(checkpoint_file, "rb"))
         iterate = True
...