Unverified Commit 65b8761d authored by Baber Abbasi, committed by GitHub

Switch Linting to `ruff` (#1166)

* add ruff and isort. remove black and flake8

* remove unnecessary dependencies

* remove dependency from table

* change order

* ran ruff

* check 3.9

* exclude evaluator

* update CI workflow

* use ruff config in pyproject.toml

* test

* add isort rules to ruff

* sort imports

* import `make_table`

* try stages for no-commit-to-branch

* turn on mypy for pre-commit

* test

* test

* test

* change no-commit-to-branch to default

* nits

* fixed dependency
parent 21d4ae98
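
The pyproject.toml and CI changes themselves are collapsed in this view. For orientation, here is a minimal sketch of a ruff configuration with isort rules enabled, in line with the commit log above (the table names follow ruff's documented pyproject.toml layout, but the concrete rule selection, line length, and exclude list are assumptions, not the values from #1166):

[tool.ruff]
# pycodestyle errors ("E"), pyflakes ("F"), and import sorting ("I", replacing isort)
select = ["E", "F", "I"]
line-length = 88
exclude = ["lm_eval/evaluator.py"]  # assumed spelling of "exclude evaluator" from the log

[tool.ruff.isort]
# group the package's own modules as first-party when sorting imports
known-first-party = ["lm_eval"]

Running ruff with `--fix` under such a config performs both the lint fixes and the import re-ordering that make up most of the mechanical churn in the diff below.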
+import copy
 import os
 import time
-from typing import List, Tuple, Optional
-import copy
 from collections import defaultdict
+from importlib.util import find_spec
+from typing import List, Optional, Tuple

 from tqdm import tqdm

 from lm_eval import utils
@@ -44,13 +45,13 @@ def oa_completion(**kwargs):
     Retry with back-off until they respond
     """
-    try:
-        import openai, tiktoken  # noqa: E401
-    except ModuleNotFoundError:
+    if not find_spec("openai") or not find_spec("tiktoken"):
         raise Exception(
-            "attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. \
-please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`",
+            "attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. "
+            "Please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`"
         )
+    else:
+        import openai

     backoff_time = 3
     while True:
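
This hunk swaps a try/except around the import for an importlib.util.find_spec availability check, so the informative error can be raised without importing the optional packages up front; the same pattern appears again below for vLLM. A minimal self-contained sketch of the idea (the `missing_optional` helper is a made-up name for illustration):

from importlib.util import find_spec


def missing_optional(*modules: str) -> list:
    # find_spec returns None when a module is not installed, without importing it
    return [m for m in modules if find_spec(m) is None]


missing = missing_optional("openai", "tiktoken")
if missing:
    raise ModuleNotFoundError(
        f"optional packages not installed: {', '.join(missing)}; "
        "install them via `pip install lm-eval[openai]`"
    )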
@@ -88,7 +89,8 @@ class OpenaiCompletionsLM(LM):
         super().__init__()
         self.seed = seed
         try:
-            import openai, tiktoken  # noqa: E401
+            import openai  # noqa: E401
+            import tiktoken
         except ModuleNotFoundError:
             raise Exception(
                 "attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. \
@@ -154,8 +156,9 @@ class OpenaiCompletionsLM(LM):
         for context, continuation in [req.args for req in requests]:
             if context == "":
                 # end of text as context
-                context_enc, continuation_enc = [self.eot_token_id], self.tok_encode(
-                    continuation
+                context_enc, continuation_enc = (
+                    [self.eot_token_id],
+                    self.tok_encode(continuation),
                 )
             else:
                 context_enc, continuation_enc = self._encode_pair(context, continuation)
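
The reflowed assignment is purely cosmetic; the behavior it preserves is that an empty context is replaced by the end-of-text token, so the continuation is still conditioned on one token. A toy rendering of that branch (the token id and tokenizer are invented stand-ins, not the class's real ones):

EOT_TOKEN_ID = 0  # stand-in end-of-text token id


def tok_encode(s: str) -> list:
    # stand-in tokenizer: one integer per character
    return [ord(c) for c in s]


context, continuation = "", " hello"
if context == "":
    # end of text as context
    context_enc, continuation_enc = (
        [EOT_TOKEN_ID],
        tok_encode(continuation),
    )
assert context_enc == [EOT_TOKEN_ID]
assert continuation_enc == tok_encode(" hello")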
@@ -326,13 +329,13 @@ def oa_chat_completion(client, **kwargs):
     Retry with back-off until they respond
     """
-    try:
-        import openai, tiktoken  # noqa: E401
-    except ModuleNotFoundError:
+    if not find_spec("openai") or not find_spec("tiktoken"):
         raise Exception(
-            "attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. \
-please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`",
+            "attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. "
+            "Please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`"
         )
+    else:
+        import openai

     async def _get_completions(**kwargs):
         chat_completions = await client.chat.completions.create(**kwargs)
@@ -364,7 +367,8 @@ class OpenaiChatCompletionsLM(LM):
         """
         super().__init__()
         try:
-            import openai, tiktoken  # noqa: E401
+            import openai  # noqa: E401
+            import tiktoken
         except ModuleNotFoundError:
             raise Exception(
                 "attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. \
...
@@ -13,9 +13,11 @@ Homepage: https://textsynth.com/index.html
 """
 import logging
 import os
-import requests as _requests
 import time

+import requests as _requests
 from tqdm import tqdm

 from lm_eval.api.model import LM
 from lm_eval.api.registry import register_model
@@ -149,7 +151,7 @@ class TextSynthLM(LM):
                 self.cache_hook.add_partial("generate_until", (inp, request_args), s)
             else:
                 logger.error(
-                    f"The following response does not contain generated `text`. "
+                    "The following response does not contain generated `text`. "
                     "Got:\n{resp}"
                 )
                 assert False
...
+import copy
 from collections import defaultdict
-from typing import List, Tuple, Optional, Literal, Union, Any
-from transformers import AutoTokenizer
+from importlib.util import find_spec
+from typing import List, Literal, Optional, Tuple, Union
+from tqdm import tqdm
+from lm_eval import utils
 from lm_eval.api.instance import Instance
 from lm_eval.api.model import LM
-import copy
-from tqdm import tqdm
 from lm_eval.api.registry import register_model
-from lm_eval import utils

 try:
-    from vllm import LLM, SamplingParams
     from ray.util.multiprocessing import Pool
+    from vllm import LLM, SamplingParams
     from vllm.transformers_utils.tokenizer import get_tokenizer
 except ModuleNotFoundError:
     pass
@@ -54,12 +57,10 @@ class VLLM(LM):
     ):
         super().__init__()
-        try:
-            import vllm
-        except ModuleNotFoundError:
+        if not find_spec("vllm"):
             raise Exception(
-                "attempted to use 'vllm' LM type, but package `vllm` is not installed. \
-please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`",
+                "attempted to use 'vllm' LM type, but package `vllm` is not installed. "
+                "Please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`"
             )

         assert "cuda" in device or device is None, "vLLM only supports CUDA"
@@ -193,8 +194,9 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`"
         for context, continuation in [req.args for req in requests]:
             if context == "":
                 # end of text as context
-                context_enc, continuation_enc = [self.eot_token_id], self.tok_encode(
-                    continuation
+                context_enc, continuation_enc = (
+                    [self.eot_token_id],
+                    self.tok_encode(continuation),
                 )
             else:
                 context_enc, continuation_enc = self._encode_pair(context, continuation)
...
@@ -69,7 +69,6 @@ def get_prompt(prompt_id: str, dataset_name: str = None, subset_name: str = None
 def load_prompt_list(
     use_prompt: str, dataset_name=None, subset_name=None, yaml_path=None, **kwargs
 ):
-
     category_name, prompt_name = use_prompt.split(":")

     if category_name == "promptsource":
@@ -113,7 +112,6 @@ class PromptString:
         self.prompt_string = prompt_string

     def apply(self, doc):
-
         doc_to_text = self.prompt_string["doc_to_text"]
         doc_to_target = self.prompt_string["doc_to_target"]
...
@@ -180,7 +180,6 @@ def include_path(task_dir):
 def initialize_tasks(verbosity="INFO"):
     eval_logger.setLevel(getattr(logging, f"{verbosity}"))
-
     task_dir = os.path.dirname(os.path.abspath(__file__)) + "/"
...
@@ -24,7 +24,6 @@ def parse_args():
 if __name__ == "__main__":
     args = parse_args()
-
     # get filename of base_yaml so we can `"include": ` it in our other YAMLs.
@@ -37,7 +36,6 @@ if __name__ == "__main__":
     dataset_path = "lukaemon/bbh"
-
     for task in tqdm(datasets.get_dataset_infos(dataset_path).keys()):
         resp = requests.get(
             f"https://raw.githubusercontent.com/suzgunmirac/BIG-Bench-Hard/main/cot-prompts/{task}.txt"
         ).content.decode("utf-8")
...
@@ -23,7 +23,6 @@ def parse_args():
 if __name__ == "__main__":
     args = parse_args()
-
     # get filename of base_yaml so we can `"include": ` it in our other YAMLs.
...
@@ -173,7 +173,6 @@ all_subtasks = [
 def main() -> None:
-
     for path, task_type in zip(
         ["multiple_choice", "generate_until"],
         ["multiple_choice_template_yaml", "generate_until_template_yaml"],
...
@@ -73,7 +73,6 @@ all_subtasks = [
 def main() -> None:
-
     for task in all_subtasks:
         file_name = f"{task}.yaml"
         try:
             with open(f"{file_name}", "w") as f:
...
@@ -75,7 +75,6 @@ def parse_args():
 if __name__ == "__main__":
     args = parse_args()
-
     # get filename of base_yaml so we can `"include": ` it in our other YAMLs.
@@ -93,7 +92,9 @@ if __name__ == "__main__":
         if args.cot_prompt_path is not None:
             description = cot_file[subject_eng]
         else:
-            description = f"以下是中国关于{subject_zh}的单项选择题,请选出其中的正确答案。\n\n"
+            description = (
+                f"以下是中国关于{subject_zh}的单项选择题,请选出其中的正确答案。\n\n"
+            )

         yaml_dict = {
             "include": base_yaml_name,
...
@@ -90,7 +90,6 @@ def parse_args():
 if __name__ == "__main__":
     args = parse_args()
-
     # get filename of base_yaml so we can `"include": ` it in our other YAMLs.
@@ -108,7 +107,9 @@ if __name__ == "__main__":
         if args.cot_prompt_path is not None:
             description = cot_file[subject_eng]
         else:
-            description = f"以下是关于{subject_zh}的单项选择题,请直接给出正确答案的选项。\n\n"
+            description = (
+                f"以下是关于{subject_zh}的单项选择题,请直接给出正确答案的选项。\n\n"
+            )

         yaml_dict = {
             "include": base_yaml_name,
...
 #!/usr/bin/python
-import os
 import re
 import sys
 import math
-import subprocess
 import xml.sax.saxutils
 from typing import List, Pattern, Tuple, Union, Dict, Any, Optional
@@ -65,14 +63,14 @@ def normalize(s):
     if type(s) is not str:
         s = " ".join(s)
     # language-independent part:
-    for (pattern, replace) in normalize1:
+    for pattern, replace in normalize1:
         s = re.sub(pattern, replace, s)
     s = xml.sax.saxutils.unescape(s, {"&quot;": '"'})
     # language-dependent part (assuming Western languages):
     s = " %s " % s
     if not preserve_case:
         s = s.lower()  # this might not be identical to the original
-    for (pattern, replace) in normalize2:
+    for pattern, replace in normalize2:
         s = re.sub(pattern, replace, s)
     return s.split()
@@ -95,7 +93,7 @@ def cook_refs(refs, n=4):
     maxcounts: Dict[Tuple[str], int] = {}
     for ref in refs:
         counts = count_ngrams(ref, n)
-        for (ngram, count) in counts.items():
+        for ngram, count in counts.items():
             maxcounts[ngram] = max(maxcounts.get(ngram, 0), count)
     return ([len(ref) for ref in refs], maxcounts)
@@ -125,7 +123,7 @@ def cook_test(test, item, n=4):
     result["correct"] = [0] * n
     counts = count_ngrams(test, n)
-    for (ngram, count) in counts.items():
+    for ngram, count in counts.items():
         result["correct"][len(ngram) - 1] += min(refmaxcounts.get(ngram, 0), count)
     return result
@@ -222,7 +220,6 @@ def bleuFromMaps(m1, m2):
 def smoothed_bleu_4(references, predictions, **kwargs):
-
     predictionMap = {}
     goldMap = {}
...
 def doc_to_text(doc):
     inputs = " ".join(doc["code_tokens"]).replace("\n", " ")
     inputs = " ".join(inputs.strip().split())
@@ -7,7 +6,6 @@ def doc_to_text(doc):
 def doc_to_target(doc):
-
     targets = " ".join(doc["docstring_tokens"]).replace("\n", "")
     targets = " ".join(targets.strip().split())
...
@@ -7,7 +7,7 @@ def doc_to_text(doc):
     # Given a passage p, the conversation history {q1, a1, . . . qi−1, ai−1}
     # and a question qi, the task is to predict the answer ai
     doc_text = doc["story"] + "\n\n"
-    for (q, a) in zip_longest(
+    for q, a in zip_longest(
         doc["questions"]["input_text"], doc["answers"]["input_text"][:-1]
     ):  # omit target answer ai
         question = f"Q: {q}\n\n"
@@ -17,7 +17,6 @@ def doc_to_text(doc):
 def doc_to_target(doc):
     turn_id = len(doc["questions"]["input_text"])
     # Returns unique answers and valid alternatives (Some questions in CoQA have multiple valid answers).
-
     answers = []
@@ -71,7 +70,6 @@ def compute_scores(gold_list, pred):
 def process_results(doc, results):
-
     gold_list = doc_to_target(doc)
     pred = results[0].strip().split("\n")[0]
...
@@ -21,7 +21,6 @@ def parse_args():
 if __name__ == "__main__":
     args = parse_args()
-
     # get filename of base_yaml so we can `"include": ` it in our other YAMLs.
@@ -30,7 +29,6 @@ if __name__ == "__main__":
         base_yaml = yaml.full_load(f)

     for name in tqdm(SUBSETS):
-
         yaml_dict = {
             "include": base_yaml_name,
             "task": f"csatqa_{args.task_prefix}_{name}"
...
@@ -62,7 +62,6 @@ def parse_answer(answer):
 def process_results(doc, results):
     preds, golds = results, doc["answers"]
-
     max_em = 0
     max_f1 = 0
...
@@ -78,8 +78,7 @@ INSTRUCTION_CONFLICTS = {
     # _KEYWORD + "key_sentences": instructions.KeySentenceChecker,
     _KEYWORD + "forbidden_words": {_KEYWORD + "forbidden_words"},
     _KEYWORD + "letter_frequency": {_KEYWORD + "letter_frequency"},
-    _LANGUAGE
-    + "response_language": {
+    _LANGUAGE + "response_language": {
         _LANGUAGE + "response_language",
         _FORMAT + "multiple_sections",
         _KEYWORD + "existence",
@@ -90,16 +89,14 @@ INSTRUCTION_CONFLICTS = {
         _CHANGE_CASES + "english_lowercase",
     },
     _LENGTH + "number_sentences": {_LENGTH + "number_sentences"},
-    _LENGTH
-    + "number_paragraphs": {
+    _LENGTH + "number_paragraphs": {
         _LENGTH + "number_paragraphs",
         _LENGTH + "nth_paragraph_first_word",
         _LENGTH + "number_sentences",
         _LENGTH + "nth_paragraph_first_word",
     },
     _LENGTH + "number_words": {_LENGTH + "number_words"},
-    _LENGTH
-    + "nth_paragraph_first_word": {
+    _LENGTH + "nth_paragraph_first_word": {
         _LENGTH + "nth_paragraph_first_word",
         _LENGTH + "number_paragraphs",
     },
@@ -110,23 +107,20 @@ INSTRUCTION_CONFLICTS = {
     # _CONTENT + "rephrase_paragraph": instructions.RephraseParagraph,
     _FORMAT + "constrained_response": set(INSTRUCTION_DICT.keys()),
     _FORMAT + "number_highlighted_sections": {_FORMAT + "number_highlighted_sections"},
-    _FORMAT
-    + "multiple_sections": {
+    _FORMAT + "multiple_sections": {
         _FORMAT + "multiple_sections",
         _LANGUAGE + "response_language",
         _FORMAT + "number_highlighted_sections",
     },
     # TODO(tianjianlu): Re-enable rephrasing with preprocessing the message.
     # _FORMAT + "rephrase": instructions.RephraseChecker,
-    _FORMAT
-    + "json_format": set(INSTRUCTION_DICT.keys()).difference(
+    _FORMAT + "json_format": set(INSTRUCTION_DICT.keys()).difference(
         {_KEYWORD + "forbidden_words", _KEYWORD + "existence"}
     ),
     _FORMAT + "title": {_FORMAT + "title"},
     # TODO(tianjianlu): Re-enable with specific prompts.
     # _MULTITURN + "constrained_start": instructions.ConstrainedStartChecker,
-    _COMBINATION
-    + "two_responses": set(INSTRUCTION_DICT.keys()).difference(
+    _COMBINATION + "two_responses": set(INSTRUCTION_DICT.keys()).difference(
         {
             _KEYWORD + "forbidden_words",
             _KEYWORD + "existence",
@@ -135,20 +129,17 @@ INSTRUCTION_CONFLICTS = {
             _PUNCTUATION + "no_comma",
         }
     ),
-    _COMBINATION
-    + "repeat_prompt": set(INSTRUCTION_DICT.keys()).difference(
+    _COMBINATION + "repeat_prompt": set(INSTRUCTION_DICT.keys()).difference(
         {_KEYWORD + "existence", _FORMAT + "title", _PUNCTUATION + "no_comma"}
     ),
     _STARTEND + "end_checker": {_STARTEND + "end_checker"},
-    _CHANGE_CASES
-    + "capital_word_frequency": {
+    _CHANGE_CASES + "capital_word_frequency": {
         _CHANGE_CASES + "capital_word_frequency",
         _CHANGE_CASES + "english_lowercase",
         _CHANGE_CASES + "english_capital",
     },
     _CHANGE_CASES + "english_capital": {_CHANGE_CASES + "english_capital"},
-    _CHANGE_CASES
-    + "english_lowercase": {
+    _CHANGE_CASES + "english_lowercase": {
         _CHANGE_CASES + "english_lowercase",
         _CHANGE_CASES + "english_capital",
     },
...
@@ -17,7 +17,6 @@
 import functools
 import random
 import re
-from typing import List

 import immutabledict
 import nltk
...
@@ -94,7 +94,6 @@ LANGUAGES = {
 def add_regex_pattern(regex_pattern):
     if regex_pattern is None:
         return {}
-
     return {
...
@@ -7,7 +7,6 @@ import argparse
 from tqdm import tqdm

-from lm_eval import utils
 from lm_eval.logger import eval_logger

 SUBJECTS = {
@@ -82,7 +81,6 @@ def parse_args():
 if __name__ == "__main__":
     args = parse_args()
-
     # get filename of base_yaml so we can `"include": ` it in our "other" YAMLs.
@@ -98,7 +96,6 @@ if __name__ == "__main__":
     ALL_CATEGORIES = []
     for subject, category in tqdm(SUBJECTS.items()):
         if category not in ALL_CATEGORIES:
             ALL_CATEGORIES.append(category)
-
...