Commit e3fee7ea authored by Baber

nit

parent 3e3a0d8f
+from __future__ import annotations
+
 import abc
 import ast
 import logging
@@ -8,12 +10,7 @@ from copy import deepcopy
 from typing import (
     TYPE_CHECKING,
     Any,
-    Dict,
-    List,
     Literal,
-    Optional,
-    Tuple,
-    Union,
 )

 import datasets
@@ -54,23 +51,23 @@ class Task(abc.ABC):
         {"question": ..., question, answer)
     """

-    VERSION: Optional[Union[int, str]] = None
+    VERSION: int | str | None = None

     # The name of the `Task` benchmark as denoted in the HuggingFace datasets Hub
     # or a path to a custom `datasets` loading script.
-    DATASET_PATH: Optional[str] = None
+    DATASET_PATH: str | None = None

     # The name of a subset within `DATASET_PATH`.
-    DATASET_NAME: Optional[str] = None
+    DATASET_NAME: str | None = None

-    OUTPUT_TYPE: Optional[OutputType] = None
+    OUTPUT_TYPE: OutputType | None = None

     def __init__(
         self,
-        data_dir: Optional[str] = None,
-        cache_dir: Optional[str] = None,
-        download_mode: Optional[datasets.DownloadMode] = None,
-        config: Optional[Mapping] = None,  # Union[dict, TaskConfig]
+        data_dir: str | None = None,
+        cache_dir: str | None = None,
+        download_mode: datasets.DownloadMode | None = None,
+        config: Mapping | None = None,  # Union[dict, TaskConfig]
     ) -> None:
         """
         :param data_dir: str
@@ -94,21 +91,21 @@ class Task(abc.ABC):
                 Fresh download and fresh dataset.
         """
         self.download(data_dir, cache_dir, download_mode)
-        self._training_docs: Optional[list] = None
-        self._fewshot_docs: Optional[list] = None
-        self._instances: Optional[List[Instance]] = None
+        self._training_docs: list | None = None
+        self._fewshot_docs: list | None = None
+        self._instances: list[Instance] | None = None

         self._config: TaskConfig = TaskConfig.from_yaml({**config})

         self._filters = [build_filter_ensemble("none", [("take_first", None)])]
-        self.fewshot_rnd: Optional[random.Random] = (
+        self.fewshot_rnd: random.Random | None = (
             None  # purposely induce errors in case of improper usage
         )

     def download(
         self,
-        data_dir: Optional[str] = None,
-        cache_dir: Optional[str] = None,
+        data_dir: str | None = None,
+        cache_dir: str | None = None,
         download_mode=None,
     ) -> None:
         """Downloads and returns the task dataset.
@@ -235,7 +232,7 @@ class Task(abc.ABC):
         pass

     @abc.abstractmethod
-    def doc_to_target(self, doc: dict) -> Union[str, int]:
+    def doc_to_target(self, doc: dict) -> str | int:
         pass

     # not an abstractmethod because not every language-only task has to implement this
@@ -251,16 +248,16 @@ class Task(abc.ABC):
     def build_all_requests(
         self,
         *,
-        limit: Union[int, None] = None,
-        samples: Optional[List[int]] = None,
+        limit: int | None = None,
+        samples: list[int] | None = None,
         rank: int = 0,
         world_size: int = 1,
         cache_requests: bool = False,
         rewrite_requests_cache: bool = False,
-        system_instruction: Optional[str] = None,
+        system_instruction: str | None = None,
         apply_chat_template: bool = False,
         fewshot_as_multiturn: bool = False,
-        chat_template: Optional[Callable] = None,
+        chat_template: Callable | None = None,
         tokenizer_name: str = "",
     ) -> None:
         """Build a set of Instances for a task, and store them in task.instances"""
@@ -362,7 +359,7 @@ class Task(abc.ABC):
            save_to_cache(file_name=cache_key, obj=instances)

     @abc.abstractmethod
-    def construct_requests(self, doc: dict, ctx: Union[list[dict], str], **kwargs):
+    def construct_requests(self, doc: dict, ctx: list[dict] | str, **kwargs):
         """Uses RequestFactory to construct Requests and returns an iterable of
         Requests which will be sent to the LM.
@@ -402,7 +399,7 @@ class Task(abc.ABC):
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metric scores
         """
-        pass
+        return True

     @deprecated("not used anymore")
     def higher_is_better(self):
@@ -411,7 +408,7 @@ class Task(abc.ABC):
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
         """
-        pass
+        return True

     def get_config(self, key: str) -> Any:
         return getattr(self._config, key, None)
@@ -485,7 +482,7 @@ class Task(abc.ABC):
         example = self.doc_to_text(doc)
         return description + labeled_examples + example

-    def apply_filters(self) -> Optional[List[Instance]]:
+    def apply_filters(self) -> list[Instance] | None:
         """Iterates over FilterEnsembles and applies them to instances"""
         if hasattr(self, "_filters"):
             for f in self._filters:
@@ -530,13 +527,13 @@ class Task(abc.ABC):
         self._config.metric_list = [MetricConfig(name=metric_name)]
         self._config.process_results = lambda *args: {"bypass": 0}

-    def set_fewshot_seed(self, seed: Optional[int] = None) -> None:
+    def set_fewshot_seed(self, seed: int | None = None) -> None:
         self.fewshot_rnd = random.Random(seed)
         if hasattr(self, "sampler"):
             self.sampler.rnd = self.fewshot_rnd

     @property
-    def eval_docs(self) -> Union[datasets.Dataset, Iterable[dict]]:
+    def eval_docs(self) -> datasets.Dataset | Iterable[dict]:
         if self.has_test_docs():
             return self.test_docs()
         elif self.has_validation_docs():
@@ -550,13 +547,13 @@ class Task(abc.ABC):
         self,
         *,
         rank: int = 0,
-        limit: Union[int, None] = None,
+        limit: int | None = None,
         world_size: int = 1,
-        samples: Optional[List[int]] = None,
-    ) -> Iterator[Tuple[int, Any]]:
+        samples: list[int] | None = None,
+    ) -> Iterator[tuple[int, Any]]:
         if samples:
             n = len(self.eval_docs)
-            assert all([e < n for e in samples]), (
+            assert all(e < n for e in samples), (
                 f"Elements of --samples should be in the interval [0,k-1] where k is the number of total examples. In this case, k={n}."
             )
         eval_logger.info(
@@ -589,7 +586,7 @@ class ConfigurableTask(Task):
         data_dir=None,
         cache_dir=None,
         download_mode=None,
-        config: Optional[dict] = None,
+        config: dict | None = None,
     ) -> None:
         # Get pre-configured attributes
         self._config = self.CONFIG
@@ -607,9 +604,8 @@ class ConfigurableTask(Task):
                 "Must pass a config to ConfigurableTask, either in cls.CONFIG or `config` kwarg"
             )

-        if isinstance(self.config.metadata, dict):
-            if "version" in self.config.metadata:
-                self.VERSION = self.config.metadata["version"]
+        if isinstance(self.config.metadata, dict) and "version" in self.config.metadata:
+            self.VERSION = self.config.metadata["version"]

         if self.config.output_type is not None:
             if self.config.output_type not in ALL_OUTPUT_TYPES:
@@ -695,18 +691,13 @@ class ConfigurableTask(Task):
             else:
                 test_target = str(test_target)

-        if test_choice is not None:
-            check_choices = test_choice
-        else:
-            check_choices = [test_target]
+        check_choices = test_choice if test_choice is not None else [test_target]
         if self.config.doc_to_choice is not None:
             for choice in check_choices:
-                choice_has_whitespace = True if choice[0].isspace() else False
+                choice_has_whitespace = choice[0].isspace()
                 delimiter_has_whitespace = (
-                    True
-                    if self.config.target_delimiter.rstrip()
+                    self.config.target_delimiter.rstrip()
                     != self.config.target_delimiter
-                    else False
                 )

                 if delimiter_has_whitespace and choice_has_whitespace:
@@ -718,9 +709,7 @@ class ConfigurableTask(Task):
                         f'Both target_delimiter "{self.config.target_delimiter}" and target choice: "{choice}" do not have whitespace, ignore if the language you are evaluating on does not require/use whitespace'
                     )

-    def download(
-        self, dataset_kwargs: Optional[Dict[str, Any]] = None, **kwargs
-    ) -> None:
+    def download(self, dataset_kwargs: dict[str, Any] | None = None, **kwargs) -> None:
         self.config.dataset_kwargs, self.config.metadata = (
             self.config.dataset_kwargs or {},
             self.config.metadata or {},
@@ -739,24 +728,15 @@ class ConfigurableTask(Task):
         )

     def has_training_docs(self) -> bool:
-        if self.config.training_split is not None:
-            return True
-        else:
-            return False
+        return self.config.training_split is not None

     def has_validation_docs(self) -> bool:
-        if self.config.validation_split is not None:
-            return True
-        else:
-            return False
+        return self.config.validation_split is not None

     def has_test_docs(self) -> bool:
-        if self.config.test_split is not None:
-            return True
-        else:
-            return False
+        return self.config.test_split is not None

-    def training_docs(self) -> Optional[datasets.Dataset]:
+    def training_docs(self) -> datasets.Dataset | None:
         if self.has_training_docs():
             if self.config.process_docs is not None:
                 return self.config.process_docs(
@@ -764,7 +744,7 @@ class ConfigurableTask(Task):
                 )
             return self.dataset[self.config.training_split]

-    def validation_docs(self) -> Optional[datasets.Dataset]:
+    def validation_docs(self) -> datasets.Dataset | None:
         if self.has_validation_docs():
             if self.config.process_docs is not None:
                 return self.config.process_docs(
@@ -772,7 +752,7 @@ class ConfigurableTask(Task):
                 )
             return self.dataset[self.config.validation_split]

-    def test_docs(self) -> Optional[datasets.Dataset]:
+    def test_docs(self) -> datasets.Dataset | None:
         if self.has_test_docs():
             if self.config.process_docs is not None:
                 return self.config.process_docs(self.dataset[self.config.test_split])
@@ -785,22 +765,25 @@ class ConfigurableTask(Task):
             return docs

         # Fallback to parent implementation
-        if _num_fewshot := self.config.num_fewshot:
-            if isinstance(_num_fewshot, int) and _num_fewshot > 0:
-                eval_logger.warning(
-                    f"[Task: {self.config.task}] "
-                    "num_fewshot > 0 but no fewshot source configured. "
-                    "Using preconfigured rule."
-                )
+        if (
+            (_num_fewshot := self.config.num_fewshot)
+            and isinstance(_num_fewshot, int)
+            and _num_fewshot > 0
+        ):
+            eval_logger.warning(
+                f"[Task: {self.config.task}] "
+                "num_fewshot > 0 but no fewshot source configured. "
+                "Using preconfigured rule."
+            )

         return super().fewshot_docs()

     @staticmethod
     def append_target_question(
-        labeled_examples: List[Dict[str, str]],
+        labeled_examples: list[dict[str, str]],
         question: str,
         fewshot_as_multiturn: bool = False,
-        gen_prefix: Optional[str] = None,
+        gen_prefix: str | None = None,
     ) -> None:
         """Adds a target question to the labeled examples list.
         If fewshot_as_multiturn is True, or labeled_examples is empty, or the last entry is a system turn, appends the question as a new user entry.
@@ -824,12 +807,12 @@ class ConfigurableTask(Task):
         self,
         doc: dict,
         num_fewshot: int,
-        system_instruction: Optional[str] = None,
+        system_instruction: str | None = None,
         apply_chat_template: bool = False,
         fewshot_as_multiturn: bool = False,
-        chat_template: Optional[Callable] = None,
-        gen_prefix: Optional[str] = None,
-    ) -> Union[str, List[str], None]:
+        chat_template: Callable | None = None,
+        gen_prefix: str | None = None,
+    ) -> str | list[str] | None:
         """Returns a fewshot context string that is made up of a prepended description
         (if provided), the `num_fewshot` number of examples, and an appended prompt example.
@@ -850,10 +833,7 @@ class ConfigurableTask(Task):
         :returns: str
             The fewshot context.
         """
-        if apply_chat_template:
-            labeled_examples = []
-        else:
-            labeled_examples = ""
+        labeled_examples = [] if apply_chat_template else ""

         # get task description
         if description := self.config.description:
@@ -923,7 +903,7 @@ class ConfigurableTask(Task):
                     labeled_examples_list.append(
                         chat_template(
                             chat,
-                            add_generation_prompt=False if gen_prefix else True,
+                            add_generation_prompt=not gen_prefix,
                         )
                     )
                 return labeled_examples_list
@@ -947,7 +927,7 @@ class ConfigurableTask(Task):
             # return lm.apply_chat_template(labeled_examples)
             return chat_template(
                 labeled_examples,
-                add_generation_prompt=False if gen_prefix else True,
+                add_generation_prompt=not gen_prefix,
             )
         else:
             prefix = (
@@ -968,7 +948,7 @@ class ConfigurableTask(Task):
             else:
                 return labeled_examples + str(example) + prefix

-    def apply_filters(self) -> Optional[List[Instance]]:
+    def apply_filters(self) -> list[Instance] | None:
         """Iterates over FilterEnsembles and applies them to instances"""
         if hasattr(self, "_filters"):
             for f in self._filters:
@@ -1008,9 +988,7 @@ class ConfigurableTask(Task):
         """
         return doc

-    def doc_to_text(
-        self, doc: dict, doc_to_text: Union[int, str, Callable, None] = None
-    ):
+    def doc_to_text(self, doc: dict, doc_to_text: int | str | Callable | None = None):
         # if self.prompt is not None:
         #     doc_to_text = self.prompt
         if doc_to_text is not None:
@@ -1046,9 +1024,7 @@ class ConfigurableTask(Task):
             print(type(doc_to_text))
             raise TypeError

-    def doc_to_target(
-        self, doc: dict, doc_to_target=None
-    ) -> Union[int, str, list[int]]:
+    def doc_to_target(self, doc: dict, doc_to_target=None) -> int | str | list[int]:
         # if self.prompt is not None:
         #     doc_to_target = self.prompt
         if doc_to_target is not None:
@@ -1097,8 +1073,8 @@ class ConfigurableTask(Task):
     def doc_to_choice(
         self,
         doc: dict,
-        doc_to_choice: Union[str, list, dict, Callable[..., list[str]], None] = None,
-    ) -> List[str]:
+        doc_to_choice: str | list | dict | Callable[..., list[str]] | None = None,
+    ) -> list[str]:
         # if self.prompt is not None:
         #     doc_to_choice = self.prompt
         if doc_to_choice is not None:
@@ -1125,7 +1101,7 @@ class ConfigurableTask(Task):
         else:
             raise TypeError

-    def doc_to_image(self, doc: dict, doc_to_image=None) -> Union[int, str, list, None]:
+    def doc_to_image(self, doc: dict, doc_to_image=None) -> int | str | list | None:
         if doc_to_image is not None:
             doc_to_image = doc_to_image
         elif self.config.doc_to_image is not None:
@@ -1148,7 +1124,7 @@ class ConfigurableTask(Task):
         else:
             return None

-    def doc_to_audio(self, doc: Any, doc_to_audio=None) -> Union[int, str, list, None]:
+    def doc_to_audio(self, doc: Any, doc_to_audio=None) -> int | str | list | None:
         if doc_to_audio is not None:
             doc_to_audio = doc_to_audio
         elif self.config.doc_to_audio is not None:
@@ -1171,7 +1147,7 @@ class ConfigurableTask(Task):
         else:
             return None

-    def doc_to_prefix(self, doc: dict) -> Optional[str]:
+    def doc_to_prefix(self, doc: dict) -> str | None:
         if (gen_prefix := self.config.gen_prefix) is not None:
             if gen_prefix in self.features:
                 return doc[gen_prefix]
@@ -1181,7 +1157,7 @@ class ConfigurableTask(Task):
     def construct_requests(
         self, doc: dict, ctx: str, **kwargs
-    ) -> Union[List[Instance], Instance]:
+    ) -> list[Instance] | Instance:
         apply_chat_template = kwargs.pop("apply_chat_template", False)
         chat_template: Callable | None = kwargs.pop("chat_template", None)
@@ -1317,7 +1293,7 @@ class ConfigurableTask(Task):
         elif self.OUTPUT_TYPE == "multiple_choice":
             lls, is_greedy = zip(*results)

-            # retrieve choices in List[str] form, to compute choice lengths, etc.
+            # retrieve choices in list[str] form, to compute choice lengths, etc.
             choices = self.doc_to_choice(doc)
             completion_len = np.array([float(len(i)) for i in choices])
@@ -1364,7 +1340,7 @@ class ConfigurableTask(Task):
             if self.multiple_target:
                 acc = 1.0 if pred in gold else 0.0
                 acc_norm = 1.0 if pred_norm in gold else 0.0
-                exact_match = int(any([is_greedy[i] if i != -100 else 0 for i in gold]))
+                exact_match = int(any(is_greedy[i] if i != -100 else 0 for i in gold))
             else:
                 acc = 1.0 if pred == gold else 0.0
                 acc_norm = 1.0 if pred_norm == gold else 0.0
@@ -1406,7 +1382,7 @@ class ConfigurableTask(Task):
                 # it assumes that doc_to_target returns a number.
                 choices = self.doc_to_choice(doc)
                 gold = choices[gold]
-            for metric in self._metric_fn_list.keys():
+            for metric in self._metric_fn_list:
                 try:
                     result_score = self._metric_fn_list[metric](
                         references=[gold] if not isinstance(gold, list) else gold,
@@ -1440,7 +1416,7 @@ class ConfigurableTask(Task):
         return getattr(self._config, key, None)

     @property
-    def task_name(self) -> Optional[str]:
+    def task_name(self) -> str | None:
         return getattr(self.config, "task", None)

     def __repr__(self):
@@ -1458,7 +1434,7 @@ class MultipleChoiceTask(Task):
     def doc_to_target(self, doc: dict) -> str:
         return " " + doc["choices"][doc["gold"]]

-    def construct_requests(self, doc: dict, ctx: str, **kwargs) -> List[Instance]:
+    def construct_requests(self, doc: dict, ctx: str, **kwargs) -> list[Instance]:
         # TODO: add mutual info here?
         return [
             Instance(
@@ -1471,7 +1447,7 @@ class MultipleChoiceTask(Task):
             for i, choice in enumerate(doc["choices"])
         ]

-    def process_results(self, doc: dict, results: Iterable[Tuple[float, bool]]) -> dict:
+    def process_results(self, doc: dict, results: Iterable[tuple[float, bool]]) -> dict:
         results = [
             res[0] for res in results
         ]  # only retain loglikelihoods, discard is_greedy TODO: do we need is_greedy anywhere?
@@ -1505,7 +1481,7 @@ class PerplexityTask(Task):
     def has_training_docs(self) -> bool:
         return False

-    def fewshot_examples(self, k: int, rnd) -> List:
+    def fewshot_examples(self, k: int, rnd) -> list:
         if k != 0:
             raise ValueError(
                 "The number of fewshot examples must be 0 for perplexity tasks."
@@ -1536,7 +1512,7 @@ class PerplexityTask(Task):
     def doc_to_target(self, doc):
         return doc

-    def construct_requests(self, doc: dict, ctx: Optional[str], **kwargs):
+    def construct_requests(self, doc: dict, ctx: str | None, **kwargs):
         if bool(ctx):
             raise ValueError
@@ -1548,7 +1524,7 @@ class PerplexityTask(Task):
             **kwargs,
         )

-    def process_results(self, doc: dict, results: Tuple[float]) -> dict:
+    def process_results(self, doc: dict, results: tuple[float]) -> dict:
         (loglikelihood,) = results
         words = self.count_words(self.doc_to_target(doc))
         bytes_ = self.count_bytes(self.doc_to_target(doc))