Unverified Commit d27c0c08 authored by LSinev, committed by GitHub

Apply code autoformatting with Ruff to tasks/*.py and *__init__.py (#1469)

parent f78e2da4
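For context: `extend-select = ["I"]` in pyproject.toml enables Ruff's isort-style import sorting, which is what most of the hunks below apply. Imports are split into standard-library, third-party, and first-party groups (`known-first-party = ["lm_eval"]`), each alphabetized, with plain `import x` statements placed ahead of `from x import y` statements inside a group. A minimal sketch of the resulting layout, using module names from the affected files (the grouping comments are illustrative, not part of the commit):

# standard library
import re
from functools import reduce

# third party
import numpy as np
from datasets import load_metric

# first party, per known-first-party = ["lm_eval"]
from lm_eval.api.instance import Instance
from lm_eval.api.task import Task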
 import re
+from abc import abstractmethod
+from functools import reduce
 import numpy as np
 import transformers.data.metrics.squad_metrics as squad_metrics
-from abc import abstractmethod
 from datasets import load_metric
 from transformers import AutoTokenizer
-from functools import reduce
-from lm_eval.api.task import Task
-from lm_eval.api.metrics import mean
 from lm_eval.api.instance import Instance
-from lm_eval.api.registry import register_task
+from lm_eval.api.metrics import mean
+from lm_eval.api.task import Task


 _CITATION = """
 @inproceedings{shaham-etal-2022-scrolls,
 @@ -44,6 +44,7 @@ _CITATION = """
 def _download_metric():
     import os
     import shutil
+
     from huggingface_hub import hf_hub_download

     scrolls_metric_path = hf_hub_download(
 @@ -148,7 +149,7 @@ class _SCROLLSTask(Task):
             del self.dataset["test"]
         for split in self.dataset:
             self.dataset[split] = _drop_duplicates_in_input(self.dataset[split])
-        if self.PRUNE_TOKENIZERS is not None and self.PRUNE_TOKENIZERS is not None:
+        if self.PRUNE_TOKENIZERS is not None:
             self.prune()

     def _get_prune_text(self, sample):
...
 @@ -13,14 +13,15 @@ also determine when no answer is supported by the paragraph and abstain from ans
 Homepage: https://rajpurkar.github.io/SQuAD-explorer/
 """
-import datasets
-from math import exp
 from functools import partial
+from math import exp
+import datasets
 from packaging import version
-from lm_eval.api.task import ConfigurableTask
 from lm_eval.api.instance import Instance
+from lm_eval.api.task import ConfigurableTask


 _CITATION = """
 @misc{rajpurkar2018know,
 @@ -35,7 +36,6 @@ _CITATION = """
 def _squad_metric(predictions, references):
-    # squad_metric = load("squad_v2")
     squad_metric = datasets.load_metric("squad_v2")
     return squad_metric.compute(predictions=predictions, references=references)
 @@ -52,7 +52,7 @@ class SQuAD2(ConfigurableTask):
     DATASET_NAME = None

     def __init__(self):
-        super().__init__(config={'metadata': {'version': self.VERSION}})
+        super().__init__(config={"metadata": {"version": self.VERSION}})
         # HF changed squad on us so we have to make sure we aren't running the old one
         assert version.parse(datasets.__version__) >= version.parse(
...
-import sklearn
 import numpy as np
+import sklearn


 def cb_multi_fi(items):
...
+import collections
 import re
 import string
-import collections
-import numpy as np

+import numpy as np
 from datasets import Dataset

 from lm_eval.api.metrics import metric_max_over_ground_truths
...
 import re
+from typing import List


 def doc_to_text(x):
     text = re.sub(r" X ", " *" + x["span2_text"] + "* ", _wsc_inputs(x))
     return "wsc: " + text
 @@ -23,14 +24,14 @@ def _wsc_inputs(x):
             [
                 " ".join(words[:pronoun_index]),
                 "X",
-                " ".join(words[pronoun_index + 1:]),
+                " ".join(words[pronoun_index + 1 :]),
             ]
         )

     # Handle some special cases.
     if (
         x["text"]
         == 'The boy continued to whip the pony , and eventually the pony threw him over. John laughed out quite loud. "Good for him," he said. '
     ):
         return (
             "The boy continued to whip the pony , and eventually the pony threw "
 @@ -39,8 +40,8 @@ def _wsc_inputs(x):
     # Using the span2_index, we get 'use' instead of 'it'.
     if (
         x["text"]
         == "When they had eventually calmed down a bit , and had gotten home, Mr. Farley put the magic pebble in an iron safe . Some day they might want to use it , but really for now, what more could they wish for?"
     ):
         return (
             "When they had eventually calmed down a bit , and had gotten home, "
...
 import datasets
-import sacrebleu
 import numpy as np
+import sacrebleu
 from rouge_score import rouge_scorer, scoring
...
 @@ -51,7 +51,9 @@ def gen_lang_yamls(output_dir: str, overwrite: bool) -> None:
     for lang in LANGUAGES:
         file_name = f"xwinograd_{lang}.yaml"
         try:
-            with open(f"{output_dir}/{file_name}", "w" if overwrite else "x", encoding="utf-8") as f:
+            with open(
+                f"{output_dir}/{file_name}", "w" if overwrite else "x", encoding="utf-8"
+            ) as f:
                 f.write("# Generated by utils.py\n")
                 yaml.dump(
                     {
...
 @@ -90,9 +90,6 @@ all = [
     "lm_eval[wandb]",
 ]

-[tool.ruff]
-extend-exclude = ["lm_eval/tasks/*.py"]
-
 [tool.ruff.lint]
 extend-select = ["I"]
 @@ -101,5 +98,4 @@ lines-after-imports = 2
 known-first-party = ["lm_eval"]

 [tool.ruff.extend-per-file-ignores]
-"__init__.py" = ["F401","F402","F403","I"]
-"lm_eval/tasks/*"= ["E721"]
+"__init__.py" = ["F401","F402","F403"]