Unverified Commit 65b8761d authored by Baber Abbasi, committed by GitHub

Switch Linting to `ruff` (#1166)

* add ruff and isort. remove black and flake8

* remove unnecessary dependencies

* remove dependency from table

* change order

* ran ruff

* check 3.9

* exclude evaluator

* update CI workflow

* use ruff config in pyproject.toml

* test

* add isort rules to ruff

* sort imports

* import `make_table`

* try stages for no-commit-to-branch

* turn on mypy for pre-commit

* test

* test

* test

* change no-commit-to-branch to default

* nits

* fixed dependency
parent 21d4ae98
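A minimal sketch of what the `use ruff config in pyproject.toml` and `add isort rules to ruff` steps typically amount to. Section and key names follow ruff's documented `[tool.ruff]` schema; the concrete rule selection, target version, and excluded path below are assumptions for illustration, not the PR's exact values:

```toml
# Hypothetical pyproject.toml excerpt: ruff as the single linter,
# with isort-style import sorting enabled via the "I" rule group.
[tool.ruff]
target-version = "py39"             # assumption: matches the "check 3.9" commit above
select = ["E", "F", "I"]            # pycodestyle, pyflakes, and isort rules
exclude = ["lm_eval/evaluator.py"]  # assumption: the "exclude evaluator" commit

[tool.ruff.isort]
lines-after-imports = 2
```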
+import copy
 import os
 import time
-from typing import List, Tuple, Optional
-import copy
 from collections import defaultdict
+from importlib.util import find_spec
+from typing import List, Optional, Tuple
 from tqdm import tqdm
 from lm_eval import utils
......@@ -44,13 +45,13 @@ def oa_completion(**kwargs):
Retry with back-off until they respond
"""
-    try:
-        import openai, tiktoken  # noqa: E401
-    except ModuleNotFoundError:
+    if not find_spec("openai") or not find_spec("tiktoken"):
         raise Exception(
-            "attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. \
-please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`",
+            "attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. "
+            "Please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`"
         )
+    else:
+        import openai

     backoff_time = 3
     while True:
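The hunk cuts off before the body of the retry loop; as a rough illustration of the back-off pattern the docstring names (function and constant names here are invented for the sketch, not taken from the PR):

```python
import time


def call_with_backoff(fn, initial_delay: float = 3.0, factor: float = 1.5):
    """Keep calling fn() until it succeeds, sleeping longer after each failure."""
    delay = initial_delay
    while True:
        try:
            return fn()
        except Exception as exc:  # the real code presumably catches API errors specifically
            print(f"request failed ({exc}); retrying in {delay:.1f}s")
            time.sleep(delay)
            delay *= factor
```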
......@@ -88,7 +89,8 @@ class OpenaiCompletionsLM(LM):
         super().__init__()
         self.seed = seed
         try:
-            import openai, tiktoken  # noqa: E401
+            import openai  # noqa: E401
+            import tiktoken
         except ModuleNotFoundError:
             raise Exception(
                 "attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. \
......@@ -154,8 +156,9 @@ class OpenaiCompletionsLM(LM):
         for context, continuation in [req.args for req in requests]:
             if context == "":
                 # end of text as context
-                context_enc, continuation_enc = [self.eot_token_id], self.tok_encode(
-                    continuation
+                context_enc, continuation_enc = (
+                    [self.eot_token_id],
+                    self.tok_encode(continuation),
                 )
             else:
                 context_enc, continuation_enc = self._encode_pair(context, continuation)
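The `# end of text as context` branch is the interesting part of this hunk: when a request has an empty context, the end-of-text token stands in as context so the model still conditions on one token when scoring the continuation. A simplified sketch (the token id and `encode` callable are placeholders, not the PR's actual helpers):

```python
from typing import Callable, List, Tuple

EOT_TOKEN_ID = 50256  # placeholder: GPT-2-style end-of-text id


def encode_context_continuation(
    context: str, continuation: str, encode: Callable[[str], List[int]]
) -> Tuple[List[int], List[int]]:
    if context == "":
        # end of text as context: condition on the EOT token instead of nothing
        return [EOT_TOKEN_ID], encode(continuation)
    # the real code uses a whitespace-aware _encode_pair; this is simplified
    return encode(context), encode(continuation)
```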
......@@ -326,13 +329,13 @@ def oa_chat_completion(client, **kwargs):
Retry with back-off until they respond
"""
-    try:
-        import openai, tiktoken  # noqa: E401
-    except ModuleNotFoundError:
+    if not find_spec("openai") or not find_spec("tiktoken"):
         raise Exception(
-            "attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. \
-please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`",
+            "attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. "
+            "Please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`"
         )
+    else:
+        import openai

     async def _get_completions(**kwargs):
         chat_completions = await client.chat.completions.create(**kwargs)
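Both openai hunks swap a `try: import ... except ModuleNotFoundError` guard for `importlib.util.find_spec`, which checks whether a package is installed without importing it, so the import cost and any import-time side effects are only paid on the success path. A standalone sketch of the idiom:

```python
from importlib.util import find_spec

# find_spec returns None when the package is absent and a ModuleSpec otherwise,
# so a plain truthiness test works as an installation check.
if not find_spec("openai") or not find_spec("tiktoken"):
    raise Exception("openai/tiktoken not installed")
else:
    import openai  # safe: the spec lookup already succeeded
```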
......@@ -364,7 +367,8 @@ class OpenaiChatCompletionsLM(LM):
"""
         super().__init__()
         try:
-            import openai, tiktoken  # noqa: E401
+            import openai  # noqa: E401
+            import tiktoken
         except ModuleNotFoundError:
             raise Exception(
                 "attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. \
......
......@@ -13,9 +13,11 @@ Homepage: https://textsynth.com/index.html
"""
 import logging
 import os
-import requests as _requests
 import time
+import requests as _requests
 from tqdm import tqdm
 from lm_eval.api.model import LM
 from lm_eval.api.registry import register_model
......@@ -149,7 +151,7 @@ class TextSynthLM(LM):
                 self.cache_hook.add_partial("generate_until", (inp, request_args), s)
             else:
                 logger.error(
-                    f"The following response does not contain generated `text`. "
+                    "The following response does not contain generated `text`. "
                     "Got:\n{resp}"
                 )
                 assert False
......
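The `logger.error` change above is ruff's F541 fix: an f-string with no placeholders loses its pointless `f` prefix. Note that the second fragment also has no `f` prefix, so `{resp}` is emitted literally rather than interpolated; the fix preserves that pre-existing behavior. A quick demonstration:

```python
resp = {"text": None}
# Adjacent string literals concatenate; without an f prefix the braces stay literal.
msg = "The following response does not contain generated `text`. " "Got:\n{resp}"
print(msg)  # ends with a literal "{resp}" -- resp is never substituted
```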
+import copy
 from collections import defaultdict
-from typing import List, Tuple, Optional, Literal, Union, Any
-from transformers import AutoTokenizer
+from importlib.util import find_spec
+from typing import List, Literal, Optional, Tuple, Union
+from tqdm import tqdm
+from lm_eval import utils
 from lm_eval.api.instance import Instance
 from lm_eval.api.model import LM
-import copy
-from tqdm import tqdm
 from lm_eval.api.registry import register_model
-from lm_eval import utils
 try:
-    from vllm import LLM, SamplingParams
     from ray.util.multiprocessing import Pool
+    from vllm import LLM, SamplingParams
     from vllm.transformers_utils.tokenizer import get_tokenizer
 except ModuleNotFoundError:
     pass
......@@ -54,12 +57,10 @@ class VLLM(LM):
     ):
         super().__init__()
-        try:
-            import vllm
-        except ModuleNotFoundError:
+        if not find_spec("vllm"):
             raise Exception(
-                "attempted to use 'vllm' LM type, but package `vllm` is not installed. \
-please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`",
+                "attempted to use 'vllm' LM type, but package `vllm` is not installed. "
+                "Please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`"
             )
         assert "cuda" in device or device is None, "vLLM only supports CUDA"
......@@ -193,8 +194,9 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`"
         for context, continuation in [req.args for req in requests]:
             if context == "":
                 # end of text as context
-                context_enc, continuation_enc = [self.eot_token_id], self.tok_encode(
-                    continuation
+                context_enc, continuation_enc = (
+                    [self.eot_token_id],
+                    self.tok_encode(continuation),
                 )
             else:
                 context_enc, continuation_enc = self._encode_pair(context, continuation)
......
......@@ -69,7 +69,6 @@ def get_prompt(prompt_id: str, dataset_name: str = None, subset_name: str = None
def load_prompt_list(
use_prompt: str, dataset_name=None, subset_name=None, yaml_path=None, **kwargs
):
category_name, prompt_name = use_prompt.split(":")
if category_name == "promptsource":
......@@ -113,7 +112,6 @@ class PromptString:
self.prompt_string = prompt_string
def apply(self, doc):
doc_to_text = self.prompt_string["doc_to_text"]
doc_to_target = self.prompt_string["doc_to_target"]
......
......@@ -180,7 +180,6 @@ def include_path(task_dir):
def initialize_tasks(verbosity="INFO"):
eval_logger.setLevel(getattr(logging, f"{verbosity}"))
task_dir = os.path.dirname(os.path.abspath(__file__)) + "/"
......
......@@ -24,7 +24,6 @@ def parse_args():
if __name__ == "__main__":
args = parse_args()
# get filename of base_yaml so we can `"include": ` it in our other YAMLs.
......@@ -37,7 +36,6 @@ if __name__ == "__main__":
dataset_path = "lukaemon/bbh"
for task in tqdm(datasets.get_dataset_infos(dataset_path).keys()):
resp = requests.get(
f"https://raw.githubusercontent.com/suzgunmirac/BIG-Bench-Hard/main/cot-prompts/{task}.txt"
).content.decode("utf-8")
......
......@@ -23,7 +23,6 @@ def parse_args():
if __name__ == "__main__":
args = parse_args()
# get filename of base_yaml so we can `"include": ` it in our other YAMLs.
......
......@@ -173,7 +173,6 @@ all_subtasks = [
def main() -> None:
for path, task_type in zip(
["multiple_choice", "generate_until"],
["multiple_choice_template_yaml", "generate_until_template_yaml"],
......
......@@ -73,7 +73,6 @@ all_subtasks = [
def main() -> None:
for task in all_subtasks:
file_name = f"{task}.yaml"
try:
with open(f"{file_name}", "w") as f:
......
......@@ -75,7 +75,6 @@ def parse_args():
if __name__ == "__main__":
args = parse_args()
# get filename of base_yaml so we can `"include": ` it in our other YAMLs.
......@@ -93,7 +92,9 @@ if __name__ == "__main__":
         if args.cot_prompt_path is not None:
             description = cot_file[subject_eng]
         else:
-            description = f"以下是中国关于{subject_zh}的单项选择题,请选出其中的正确答案。\n\n"
+            description = (
+                f"以下是中国关于{subject_zh}的单项选择题,请选出其中的正确答案。\n\n"
+            )
yaml_dict = {
"include": base_yaml_name,
......
......@@ -90,7 +90,6 @@ def parse_args():
if __name__ == "__main__":
args = parse_args()
# get filename of base_yaml so we can `"include": ` it in our other YAMLs.
......@@ -108,7 +107,9 @@ if __name__ == "__main__":
         if args.cot_prompt_path is not None:
             description = cot_file[subject_eng]
         else:
-            description = f"以下是关于{subject_zh}的单项选择题,请直接给出正确答案的选项。\n\n"
+            description = (
+                f"以下是关于{subject_zh}的单项选择题,请直接给出正确答案的选项。\n\n"
+            )
yaml_dict = {
"include": base_yaml_name,
......
#!/usr/bin/python
import os
import re
import sys
import math
import subprocess
import xml.sax.saxutils
from typing import List, Pattern, Tuple, Union, Dict, Any, Optional
......@@ -65,14 +63,14 @@ def normalize(s):
     if type(s) is not str:
         s = " ".join(s)
     # language-independent part:
-    for (pattern, replace) in normalize1:
+    for pattern, replace in normalize1:
         s = re.sub(pattern, replace, s)
     s = xml.sax.saxutils.unescape(s, {"&quot;": '"'})
     # language-dependent part (assuming Western languages):
     s = " %s " % s
     if not preserve_case:
         s = s.lower()  # this might not be identical to the original
-    for (pattern, replace) in normalize2:
+    for pattern, replace in normalize2:
         s = re.sub(pattern, replace, s)
     return s.split()
......@@ -95,7 +93,7 @@ def cook_refs(refs, n=4):
     maxcounts: Dict[Tuple[str], int] = {}
     for ref in refs:
         counts = count_ngrams(ref, n)
-        for (ngram, count) in counts.items():
+        for ngram, count in counts.items():
             maxcounts[ngram] = max(maxcounts.get(ngram, 0), count)
     return ([len(ref) for ref in refs], maxcounts)
......@@ -125,7 +123,7 @@ def cook_test(test, item, n=4):
result["correct"] = [0] * n
counts = count_ngrams(test, n)
for (ngram, count) in counts.items():
for ngram, count in counts.items():
result["correct"][len(ngram) - 1] += min(refmaxcounts.get(ngram, 0), count)
return result
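`count_ngrams` is used by both `cook_refs` and `cook_test` but falls outside the hunks shown; a plausible implementation consistent with how it is called here (token list in, mapping from n-gram tuple to count out):

```python
from collections import Counter
from typing import Dict, List, Tuple


def count_ngrams(words: List[str], n: int = 4) -> Dict[Tuple[str, ...], int]:
    """Count all 1..n-grams in a token list (a sketch, not the file's exact code)."""
    counts: Counter = Counter()
    for k in range(1, n + 1):
        for i in range(len(words) - k + 1):
            counts[tuple(words[i : i + k])] += 1
    return dict(counts)
```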
......@@ -222,7 +220,6 @@ def bleuFromMaps(m1, m2):
def smoothed_bleu_4(references, predictions, **kwargs):
predictionMap = {}
goldMap = {}
......
def doc_to_text(doc):
inputs = " ".join(doc["code_tokens"]).replace("\n", " ")
inputs = " ".join(inputs.strip().split())
......@@ -7,7 +6,6 @@ def doc_to_text(doc):
def doc_to_target(doc):
targets = " ".join(doc["docstring_tokens"]).replace("\n", "")
targets = " ".join(targets.strip().split())
......
......@@ -7,7 +7,7 @@ def doc_to_text(doc):
# Given a passage p, the conversation history {q1, a1, . . . qi−1, ai−1}
# and a question qi, the task is to predict the answer ai
doc_text = doc["story"] + "\n\n"
for (q, a) in zip_longest(
for q, a in zip_longest(
doc["questions"]["input_text"], doc["answers"]["input_text"][:-1]
): # omit target answer ai
question = f"Q: {q}\n\n"
......@@ -17,7 +17,6 @@ def doc_to_text(doc):
def doc_to_target(doc):
turn_id = len(doc["questions"]["input_text"])
# Returns unique answers and valid alternatives (Some questions in CoQA have multiple valid answers).
answers = []
......@@ -71,7 +70,6 @@ def compute_scores(gold_list, pred):
def process_results(doc, results):
gold_list = doc_to_target(doc)
pred = results[0].strip().split("\n")[0]
......
......@@ -21,7 +21,6 @@ def parse_args():
if __name__ == "__main__":
args = parse_args()
# get filename of base_yaml so we can `"include": ` it in our other YAMLs.
......@@ -30,7 +29,6 @@ if __name__ == "__main__":
base_yaml = yaml.full_load(f)
for name in tqdm(SUBSETS):
yaml_dict = {
"include": base_yaml_name,
"task": f"csatqa_{args.task_prefix}_{name}"
......
......@@ -62,7 +62,6 @@ def parse_answer(answer):
def process_results(doc, results):
preds, golds = results, doc["answers"]
max_em = 0
max_f1 = 0
......
......@@ -78,8 +78,7 @@ INSTRUCTION_CONFLICTS = {
# _KEYWORD + "key_sentences": instructions.KeySentenceChecker,
_KEYWORD + "forbidden_words": {_KEYWORD + "forbidden_words"},
_KEYWORD + "letter_frequency": {_KEYWORD + "letter_frequency"},
_LANGUAGE
+ "response_language": {
_LANGUAGE + "response_language": {
_LANGUAGE + "response_language",
_FORMAT + "multiple_sections",
_KEYWORD + "existence",
......@@ -90,16 +89,14 @@ INSTRUCTION_CONFLICTS = {
_CHANGE_CASES + "english_lowercase",
},
_LENGTH + "number_sentences": {_LENGTH + "number_sentences"},
_LENGTH
+ "number_paragraphs": {
_LENGTH + "number_paragraphs": {
_LENGTH + "number_paragraphs",
_LENGTH + "nth_paragraph_first_word",
_LENGTH + "number_sentences",
_LENGTH + "nth_paragraph_first_word",
},
_LENGTH + "number_words": {_LENGTH + "number_words"},
_LENGTH
+ "nth_paragraph_first_word": {
_LENGTH + "nth_paragraph_first_word": {
_LENGTH + "nth_paragraph_first_word",
_LENGTH + "number_paragraphs",
},
......@@ -110,23 +107,20 @@ INSTRUCTION_CONFLICTS = {
# _CONTENT + "rephrase_paragraph": instructions.RephraseParagraph,
_FORMAT + "constrained_response": set(INSTRUCTION_DICT.keys()),
_FORMAT + "number_highlighted_sections": {_FORMAT + "number_highlighted_sections"},
_FORMAT
+ "multiple_sections": {
_FORMAT + "multiple_sections": {
_FORMAT + "multiple_sections",
_LANGUAGE + "response_language",
_FORMAT + "number_highlighted_sections",
},
# TODO(tianjianlu): Re-enable rephrasing with preprocessing the message.
# _FORMAT + "rephrase": instructions.RephraseChecker,
_FORMAT
+ "json_format": set(INSTRUCTION_DICT.keys()).difference(
_FORMAT + "json_format": set(INSTRUCTION_DICT.keys()).difference(
{_KEYWORD + "forbidden_words", _KEYWORD + "existence"}
),
_FORMAT + "title": {_FORMAT + "title"},
# TODO(tianjianlu): Re-enable with specific prompts.
# _MULTITURN + "constrained_start": instructions.ConstrainedStartChecker,
_COMBINATION
+ "two_responses": set(INSTRUCTION_DICT.keys()).difference(
_COMBINATION + "two_responses": set(INSTRUCTION_DICT.keys()).difference(
{
_KEYWORD + "forbidden_words",
_KEYWORD + "existence",
......@@ -135,20 +129,17 @@ INSTRUCTION_CONFLICTS = {
_PUNCTUATION + "no_comma",
}
),
_COMBINATION
+ "repeat_prompt": set(INSTRUCTION_DICT.keys()).difference(
_COMBINATION + "repeat_prompt": set(INSTRUCTION_DICT.keys()).difference(
{_KEYWORD + "existence", _FORMAT + "title", _PUNCTUATION + "no_comma"}
),
_STARTEND + "end_checker": {_STARTEND + "end_checker"},
_CHANGE_CASES
+ "capital_word_frequency": {
_CHANGE_CASES + "capital_word_frequency": {
_CHANGE_CASES + "capital_word_frequency",
_CHANGE_CASES + "english_lowercase",
_CHANGE_CASES + "english_capital",
},
_CHANGE_CASES + "english_capital": {_CHANGE_CASES + "english_capital"},
_CHANGE_CASES
+ "english_lowercase": {
_CHANGE_CASES + "english_lowercase": {
_CHANGE_CASES + "english_lowercase",
_CHANGE_CASES + "english_capital",
},
......
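The `set(INSTRUCTION_DICT.keys()).difference({...})` entries above express "conflicts with every instruction except these few"; a toy version of the idiom with made-up instruction ids:

```python
# Hypothetical registry; the real keys come from ifeval's INSTRUCTION_DICT.
INSTRUCTION_DICT = {
    "keywords:existence": None,
    "detectable_format:title": None,
    "punctuation:no_comma": None,
    "length_constraints:number_words": None,
}

# "two_responses" conflicts with everything except the whitelisted ids.
conflicts = set(INSTRUCTION_DICT.keys()).difference(
    {"keywords:existence", "detectable_format:title"}
)
print(sorted(conflicts))  # ['length_constraints:number_words', 'punctuation:no_comma']
```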
......@@ -17,7 +17,6 @@
import functools
import random
import re
from typing import List
import immutabledict
import nltk
......
......@@ -94,7 +94,6 @@ LANGUAGES = {
def add_regex_pattern(regex_pattern):
if regex_pattern is None:
return {}
return {
......
......@@ -7,7 +7,6 @@ import argparse
from tqdm import tqdm
from lm_eval import utils
from lm_eval.logger import eval_logger
SUBJECTS = {
......@@ -82,7 +81,6 @@ def parse_args():
if __name__ == "__main__":
args = parse_args()
# get filename of base_yaml so we can `"include": ` it in our "other" YAMLs.
......@@ -98,7 +96,6 @@ if __name__ == "__main__":
ALL_CATEGORIES = []
for subject, category in tqdm(SUBJECTS.items()):
if category not in ALL_CATEGORIES:
ALL_CATEGORIES.append(category)
......