Commit e3fee7ea authored by Baber

nit

parent 3e3a0d8f
from __future__ import annotations
import abc
import ast
import logging
@@ -8,12 +10,7 @@ from copy import deepcopy
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Literal,
Optional,
Tuple,
Union,
)
import datasets
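Side note on why the removed `typing` imports are safe to drop: the `from __future__ import annotations` added at the top turns every annotation into a string that is never evaluated at runtime, so PEP 604 unions (`str | None`) and builtin generics (`list[Instance]`) work even on Python 3.8/3.9. A minimal sketch:

```python
from __future__ import annotations  # must precede all other imports


# With the future import active, these annotations are stored as strings and
# never evaluated at runtime, so the new syntax needs no typing.* fallbacks:
def pick(xs: list[int], default: int | None = None) -> int | None:
    return xs[0] if xs else default


print(pick([3, 1]))  # 3
print(pick([]))      # None
```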
@@ -54,23 +51,23 @@ class Task(abc.ABC):
{"question": ..., question, answer)
"""
VERSION: Optional[Union[int, str]] = None
VERSION: int | str | None = None
# The name of the `Task` benchmark as denoted in the HuggingFace datasets Hub
# or a path to a custom `datasets` loading script.
DATASET_PATH: Optional[str] = None
DATASET_PATH: str | None = None
# The name of a subset within `DATASET_PATH`.
DATASET_NAME: Optional[str] = None
DATASET_NAME: str | None = None
OUTPUT_TYPE: Optional[OutputType] = None
OUTPUT_TYPE: OutputType | None = None
def __init__(
self,
data_dir: Optional[str] = None,
cache_dir: Optional[str] = None,
download_mode: Optional[datasets.DownloadMode] = None,
config: Optional[Mapping] = None, # Union[dict, TaskConfig]
data_dir: str | None = None,
cache_dir: str | None = None,
download_mode: datasets.DownloadMode | None = None,
config: Mapping | None = None, # Union[dict, TaskConfig]
) -> None:
"""
:param data_dir: str
@@ -94,21 +91,21 @@ class Task(abc.ABC):
Fresh download and fresh dataset.
"""
self.download(data_dir, cache_dir, download_mode)
self._training_docs: Optional[list] = None
self._fewshot_docs: Optional[list] = None
self._instances: Optional[List[Instance]] = None
self._training_docs: list | None = None
self._fewshot_docs: list | None = None
self._instances: list[Instance] | None = None
self._config: TaskConfig = TaskConfig.from_yaml({**config})
self._filters = [build_filter_ensemble("none", [("take_first", None)])]
self.fewshot_rnd: Optional[random.Random] = (
self.fewshot_rnd: random.Random | None = (
None # purposely induce errors in case of improper usage
)
def download(
self,
data_dir: Optional[str] = None,
cache_dir: Optional[str] = None,
data_dir: str | None = None,
cache_dir: str | None = None,
download_mode=None,
) -> None:
"""Downloads and returns the task dataset.
@@ -235,7 +232,7 @@ class Task(abc.ABC):
pass
@abc.abstractmethod
def doc_to_target(self, doc: dict) -> Union[str, int]:
def doc_to_target(self, doc: dict) -> str | int:
pass
# not an abstractmethod because not every language-only task has to implement this
@@ -251,16 +248,16 @@ class Task(abc.ABC):
def build_all_requests(
self,
*,
limit: Union[int, None] = None,
samples: Optional[List[int]] = None,
limit: int | None = None,
samples: list[int] | None = None,
rank: int = 0,
world_size: int = 1,
cache_requests: bool = False,
rewrite_requests_cache: bool = False,
system_instruction: Optional[str] = None,
system_instruction: str | None = None,
apply_chat_template: bool = False,
fewshot_as_multiturn: bool = False,
chat_template: Optional[Callable] = None,
chat_template: Callable | None = None,
tokenizer_name: str = "",
) -> None:
"""Build a set of Instances for a task, and store them in task.instances"""
@@ -362,7 +359,7 @@ class Task(abc.ABC):
save_to_cache(file_name=cache_key, obj=instances)
@abc.abstractmethod
def construct_requests(self, doc: dict, ctx: Union[list[dict], str], **kwargs):
def construct_requests(self, doc: dict, ctx: list[dict] | str, **kwargs):
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
@@ -402,7 +399,7 @@ class Task(abc.ABC):
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metric scores
"""
pass
return True
@deprecated("not used anymore")
def higher_is_better(self):
@@ -411,7 +408,7 @@ class Task(abc.ABC):
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
"""
pass
return True
def get_config(self, key: str) -> Any:
return getattr(self._config, key, None)
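One behavioural detail in the two hunks above: a bare `pass` made the deprecated `aggregation`/`higher_is_better` stubs return `None`, whereas the new bodies return `True`, presumably to keep legacy callers that treat the result as a truthy flag working:

```python
def old_stub():
    pass          # implicitly returns None (falsy)

def new_stub():
    return True   # truthy default for legacy callers

assert old_stub() is None
assert new_stub() is True
```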
@@ -485,7 +482,7 @@ class Task(abc.ABC):
example = self.doc_to_text(doc)
return description + labeled_examples + example
def apply_filters(self) -> Optional[List[Instance]]:
def apply_filters(self) -> list[Instance] | None:
"""Iterates over FilterEnsembles and applies them to instances"""
if hasattr(self, "_filters"):
for f in self._filters:
@@ -530,13 +527,13 @@ class Task(abc.ABC):
self._config.metric_list = [MetricConfig(name=metric_name)]
self._config.process_results = lambda *args: {"bypass": 0}
def set_fewshot_seed(self, seed: Optional[int] = None) -> None:
def set_fewshot_seed(self, seed: int | None = None) -> None:
self.fewshot_rnd = random.Random(seed)
if hasattr(self, "sampler"):
self.sampler.rnd = self.fewshot_rnd
@property
def eval_docs(self) -> Union[datasets.Dataset, Iterable[dict]]:
def eval_docs(self) -> datasets.Dataset | Iterable[dict]:
if self.has_test_docs():
return self.test_docs()
elif self.has_validation_docs():
@@ -550,13 +547,13 @@ class Task(abc.ABC):
self,
*,
rank: int = 0,
limit: Union[int, None] = None,
limit: int | None = None,
world_size: int = 1,
samples: Optional[List[int]] = None,
) -> Iterator[Tuple[int, Any]]:
samples: list[int] | None = None,
) -> Iterator[tuple[int, Any]]:
if samples:
n = len(self.eval_docs)
assert all([e < n for e in samples]), (
assert all(e < n for e in samples), (
f"Elements of --samples should be in the interval [0,k-1] where k is the number of total examples. In this case, k={n}."
)
eval_logger.info(
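The `assert` cleanup above swaps a list comprehension for a generator expression, so `all()` short-circuits on the first out-of-range element instead of materializing the whole list first:

```python
samples = [0, 2, 99, 1]
n = 4
# evaluation stops at 99; no intermediate list is built
assert not all(e < n for e in samples)
```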
@@ -589,7 +586,7 @@ class ConfigurableTask(Task):
data_dir=None,
cache_dir=None,
download_mode=None,
config: Optional[dict] = None,
config: dict | None = None,
) -> None:
# Get pre-configured attributes
self._config = self.CONFIG
@@ -607,9 +604,8 @@ class ConfigurableTask(Task):
"Must pass a config to ConfigurableTask, either in cls.CONFIG or `config` kwarg"
)
if isinstance(self.config.metadata, dict):
if "version" in self.config.metadata:
self.VERSION = self.config.metadata["version"]
if isinstance(self.config.metadata, dict) and "version" in self.config.metadata:
self.VERSION = self.config.metadata["version"]
if self.config.output_type is not None:
if self.config.output_type not in ALL_OUTPUT_TYPES:
@@ -695,18 +691,13 @@ class ConfigurableTask(Task):
else:
test_target = str(test_target)
if test_choice is not None:
check_choices = test_choice
else:
check_choices = [test_target]
check_choices = test_choice if test_choice is not None else [test_target]
if self.config.doc_to_choice is not None:
for choice in check_choices:
choice_has_whitespace = True if choice[0].isspace() else False
choice_has_whitespace = choice[0].isspace()
delimiter_has_whitespace = (
True
if self.config.target_delimiter.rstrip()
self.config.target_delimiter.rstrip()
!= self.config.target_delimiter
else False
)
if delimiter_has_whitespace and choice_has_whitespace:
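Both simplifications in this hunk use the same identity: `True if cond else False` is just `bool(cond)`, and when `cond` is already a boolean expression the wrapper is redundant:

```python
choice = " yes"
target_delimiter = " "

choice_has_whitespace = choice[0].isspace()                                # True
delimiter_has_whitespace = target_delimiter.rstrip() != target_delimiter  # True
```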
@@ -718,9 +709,7 @@ class ConfigurableTask(Task):
f'Both target_delimiter "{self.config.target_delimiter}" and target choice: "{choice}" do not have whitespace, ignore if the language you are evaluating on does not require/use whitespace'
)
def download(
self, dataset_kwargs: Optional[Dict[str, Any]] = None, **kwargs
) -> None:
def download(self, dataset_kwargs: dict[str, Any] | None = None, **kwargs) -> None:
self.config.dataset_kwargs, self.config.metadata = (
self.config.dataset_kwargs or {},
self.config.metadata or {},
@@ -739,24 +728,15 @@ class ConfigurableTask(Task):
)
def has_training_docs(self) -> bool:
if self.config.training_split is not None:
return True
else:
return False
return self.config.training_split is not None
def has_validation_docs(self) -> bool:
if self.config.validation_split is not None:
return True
else:
return False
return self.config.validation_split is not None
def has_test_docs(self) -> bool:
if self.config.test_split is not None:
return True
else:
return False
return self.config.test_split is not None
def training_docs(self) -> Optional[datasets.Dataset]:
def training_docs(self) -> datasets.Dataset | None:
if self.has_training_docs():
if self.config.process_docs is not None:
return self.config.process_docs(
@@ -764,7 +744,7 @@ class ConfigurableTask(Task):
)
return self.dataset[self.config.training_split]
def validation_docs(self) -> Optional[datasets.Dataset]:
def validation_docs(self) -> datasets.Dataset | None:
if self.has_validation_docs():
if self.config.process_docs is not None:
return self.config.process_docs(
@@ -772,7 +752,7 @@ class ConfigurableTask(Task):
)
return self.dataset[self.config.validation_split]
def test_docs(self) -> Optional[datasets.Dataset]:
def test_docs(self) -> datasets.Dataset | None:
if self.has_test_docs():
if self.config.process_docs is not None:
return self.config.process_docs(self.dataset[self.config.test_split])
@@ -785,22 +765,25 @@ class ConfigurableTask(Task):
return docs
# Fallback to parent implementation
if _num_fewshot := self.config.num_fewshot:
if isinstance(_num_fewshot, int) and _num_fewshot > 0:
eval_logger.warning(
f"[Task: {self.config.task}] "
"num_fewshot > 0 but no fewshot source configured. "
"Using preconfigured rule."
)
if (
(_num_fewshot := self.config.num_fewshot)
and isinstance(_num_fewshot, int)
and _num_fewshot > 0
):
eval_logger.warning(
f"[Task: {self.config.task}] "
"num_fewshot > 0 but no fewshot source configured. "
"Using preconfigured rule."
)
return super().fewshot_docs()
@staticmethod
def append_target_question(
labeled_examples: List[Dict[str, str]],
labeled_examples: list[dict[str, str]],
question: str,
fewshot_as_multiturn: bool = False,
gen_prefix: Optional[str] = None,
gen_prefix: str | None = None,
) -> None:
"""Adds a target question to the labeled examples list.
If fewshot_as_multiturn is True, or labeled_examples is empty, or the last entry is a system turn, appends the question as a new user entry.
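The collapsed `fewshot_docs` condition above is the usual way to fold an assignment expression into a larger check: the walrus needs its own parentheses inside a compound boolean, and the chained `and`s short-circuit left to right, so `isinstance` only runs when the value is truthy. A sketch with an assumed config dict:

```python
config = {"num_fewshot": 5}

if (
    (n := config.get("num_fewshot"))
    and isinstance(n, int)
    and n > 0
):
    print(f"warn: num_fewshot={n} but no fewshot source configured")
```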
@@ -824,12 +807,12 @@ class ConfigurableTask(Task):
self,
doc: dict,
num_fewshot: int,
system_instruction: Optional[str] = None,
system_instruction: str | None = None,
apply_chat_template: bool = False,
fewshot_as_multiturn: bool = False,
chat_template: Optional[Callable] = None,
gen_prefix: Optional[str] = None,
) -> Union[str, List[str], None]:
chat_template: Callable | None = None,
gen_prefix: str | None = None,
) -> str | list[str] | None:
"""Returns a fewshot context string that is made up of a prepended description
(if provided), the `num_fewshot` number of examples, and an appended prompt example.
@@ -850,10 +833,7 @@ class ConfigurableTask(Task):
:returns: str
The fewshot context.
"""
if apply_chat_template:
labeled_examples = []
else:
labeled_examples = ""
labeled_examples = [] if apply_chat_template else ""
# get task description
if description := self.config.description:
@@ -923,7 +903,7 @@ class ConfigurableTask(Task):
labeled_examples_list.append(
chat_template(
chat,
add_generation_prompt=False if gen_prefix else True,
add_generation_prompt=not gen_prefix,
)
)
return labeled_examples_list
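`not gen_prefix` is exactly the old `False if gen_prefix else True`: when a generation prefix will be appended manually, the chat template must not also emit its own assistant-turn opener. A quick equivalence check:

```python
for gen_prefix in ("Answer:", "", None):
    assert (not gen_prefix) == (False if gen_prefix else True)
```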
@@ -947,7 +927,7 @@ class ConfigurableTask(Task):
# return lm.apply_chat_template(labeled_examples)
return chat_template(
labeled_examples,
add_generation_prompt=False if gen_prefix else True,
add_generation_prompt=not gen_prefix,
)
else:
prefix = (
@@ -968,7 +948,7 @@ class ConfigurableTask(Task):
else:
return labeled_examples + str(example) + prefix
def apply_filters(self) -> Optional[List[Instance]]:
def apply_filters(self) -> list[Instance] | None:
"""Iterates over FilterEnsembles and applies them to instances"""
if hasattr(self, "_filters"):
for f in self._filters:
@@ -1008,9 +988,7 @@ class ConfigurableTask(Task):
"""
return doc
def doc_to_text(
self, doc: dict, doc_to_text: Union[int, str, Callable, None] = None
):
def doc_to_text(self, doc: dict, doc_to_text: int | str | Callable | None = None):
# if self.prompt is not None:
# doc_to_text = self.prompt
if doc_to_text is not None:
@@ -1046,9 +1024,7 @@ class ConfigurableTask(Task):
print(type(doc_to_text))
raise TypeError
def doc_to_target(
self, doc: dict, doc_to_target=None
) -> Union[int, str, list[int]]:
def doc_to_target(self, doc: dict, doc_to_target=None) -> int | str | list[int]:
# if self.prompt is not None:
# doc_to_target = self.prompt
if doc_to_target is not None:
@@ -1097,8 +1073,8 @@ class ConfigurableTask(Task):
def doc_to_choice(
self,
doc: dict,
doc_to_choice: Union[str, list, dict, Callable[..., list[str]], None] = None,
) -> List[str]:
doc_to_choice: str | list | dict | Callable[..., list[str]] | None = None,
) -> list[str]:
# if self.prompt is not None:
# doc_to_choice = self.prompt
if doc_to_choice is not None:
@@ -1125,7 +1101,7 @@ class ConfigurableTask(Task):
else:
raise TypeError
def doc_to_image(self, doc: dict, doc_to_image=None) -> Union[int, str, list, None]:
def doc_to_image(self, doc: dict, doc_to_image=None) -> int | str | list | None:
if doc_to_image is not None:
doc_to_image = doc_to_image
elif self.config.doc_to_image is not None:
@@ -1148,7 +1124,7 @@ class ConfigurableTask(Task):
else:
return None
def doc_to_audio(self, doc: Any, doc_to_audio=None) -> Union[int, str, list, None]:
def doc_to_audio(self, doc: Any, doc_to_audio=None) -> int | str | list | None:
if doc_to_audio is not None:
doc_to_audio = doc_to_audio
elif self.config.doc_to_audio is not None:
@@ -1171,7 +1147,7 @@ class ConfigurableTask(Task):
else:
return None
def doc_to_prefix(self, doc: dict) -> Optional[str]:
def doc_to_prefix(self, doc: dict) -> str | None:
if (gen_prefix := self.config.gen_prefix) is not None:
if gen_prefix in self.features:
return doc[gen_prefix]
@@ -1181,7 +1157,7 @@ class ConfigurableTask(Task):
def construct_requests(
self, doc: dict, ctx: str, **kwargs
) -> Union[List[Instance], Instance]:
) -> list[Instance] | Instance:
apply_chat_template = kwargs.pop("apply_chat_template", False)
chat_template: Callable | None = kwargs.pop("chat_template", None)
@@ -1317,7 +1293,7 @@ class ConfigurableTask(Task):
elif self.OUTPUT_TYPE == "multiple_choice":
lls, is_greedy = zip(*results)
# retrieve choices in List[str] form, to compute choice lengths, etc.
# retrieve choices in list[str] form, to compute choice lengths, etc.
choices = self.doc_to_choice(doc)
completion_len = np.array([float(len(i)) for i in choices])
@@ -1364,7 +1340,7 @@ class ConfigurableTask(Task):
if self.multiple_target:
acc = 1.0 if pred in gold else 0.0
acc_norm = 1.0 if pred_norm in gold else 0.0
exact_match = int(any([is_greedy[i] if i != -100 else 0 for i in gold]))
exact_match = int(any(is_greedy[i] if i != -100 else 0 for i in gold))
else:
acc = 1.0 if pred == gold else 0.0
acc_norm = 1.0 if pred_norm == gold else 0.0
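For context on the surrounding code, a minimal sketch of how the multiple-choice metrics above fall out of the per-choice loglikelihoods (scores invented):

```python
import numpy as np

lls = np.array([-9.1, -6.4, -8.0])   # summed loglikelihood per choice
choices = ["Paris", "London", "Berlin"]
gold = 1

completion_len = np.array([float(len(c)) for c in choices])
pred = int(np.argmax(lls))                        # raw argmax -> 1
pred_norm = int(np.argmax(lls / completion_len))  # length-normalized -> 1

acc = 1.0 if pred == gold else 0.0
acc_norm = 1.0 if pred_norm == gold else 0.0
```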
@@ -1406,7 +1382,7 @@ class ConfigurableTask(Task):
# it assumes that doc_to_target returns a number.
choices = self.doc_to_choice(doc)
gold = choices[gold]
for metric in self._metric_fn_list.keys():
for metric in self._metric_fn_list:
try:
result_score = self._metric_fn_list[metric](
references=[gold] if not isinstance(gold, list) else gold,
@@ -1440,7 +1416,7 @@ class ConfigurableTask(Task):
return getattr(self._config, key, None)
@property
def task_name(self) -> Optional[str]:
def task_name(self) -> str | None:
return getattr(self.config, "task", None)
def __repr__(self):
@@ -1458,7 +1434,7 @@ class MultipleChoiceTask(Task):
def doc_to_target(self, doc: dict) -> str:
return " " + doc["choices"][doc["gold"]]
def construct_requests(self, doc: dict, ctx: str, **kwargs) -> List[Instance]:
def construct_requests(self, doc: dict, ctx: str, **kwargs) -> list[Instance]:
# TODO: add mutual info here?
return [
Instance(
@@ -1471,7 +1447,7 @@ class MultipleChoiceTask(Task):
for i, choice in enumerate(doc["choices"])
]
def process_results(self, doc: dict, results: Iterable[Tuple[float, bool]]) -> dict:
def process_results(self, doc: dict, results: Iterable[tuple[float, bool]]) -> dict:
results = [
res[0] for res in results
] # only retain loglikelihoods, discard is_greedy TODO: do we need is_greedy anywhere?
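End to end: `construct_requests` emits one `loglikelihood` Instance per choice, and `process_results` later receives one `(loglikelihood, is_greedy)` tuple per choice, keeping only the scores as above. A toy walkthrough with invented numbers:

```python
doc = {"query": "2 + 2 = ?", "choices": ["3", "4"], "gold": 1}
results = [(-5.1, False), (-0.4, True)]    # one tuple per choice

lls = [ll for ll, _is_greedy in results]   # discard is_greedy
pred = lls.index(max(lls))                 # 1
acc = 1.0 if pred == doc["gold"] else 0.0  # 1.0
```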
@@ -1505,7 +1481,7 @@ class PerplexityTask(Task):
def has_training_docs(self) -> bool:
return False
def fewshot_examples(self, k: int, rnd) -> List:
def fewshot_examples(self, k: int, rnd) -> list:
if k != 0:
raise ValueError(
"The number of fewshot examples must be 0 for perplexity tasks."
@@ -1536,7 +1512,7 @@ class PerplexityTask(Task):
def doc_to_target(self, doc):
return doc
def construct_requests(self, doc: dict, ctx: Optional[str], **kwargs):
def construct_requests(self, doc: dict, ctx: str | None, **kwargs):
if bool(ctx):
raise ValueError
@@ -1548,7 +1524,7 @@ class PerplexityTask(Task):
**kwargs,
)
def process_results(self, doc: dict, results: Tuple[float]) -> dict:
def process_results(self, doc: dict, results: tuple[float]) -> dict:
(loglikelihood,) = results
words = self.count_words(self.doc_to_target(doc))
bytes_ = self.count_bytes(self.doc_to_target(doc))
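The truncated method feeds these counts into the usual corpus-level perplexity metrics; a sketch of the arithmetic under the standard definitions (values invented):

```python
import math

loglikelihood, words, bytes_ = -123.4, 50, 280

word_perplexity = math.exp(-loglikelihood / words)       # e^(123.4 / 50)
byte_perplexity = math.exp(-loglikelihood / bytes_)      # e^(123.4 / 280)
bits_per_byte = -loglikelihood / (bytes_ * math.log(2))  # log-loss in bits
```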
......