Commit 7cef4d38 authored by Baber's avatar Baber
Browse files

move test one doc to method

parent ec767666
...@@ -656,57 +656,6 @@ class ConfigurableTask(Task): ...@@ -656,57 +656,6 @@ class ConfigurableTask(Task):
) )
self.task_docs = self.eval_docs self.task_docs = self.eval_docs
# Test One Doc
self.features: list[str] = list(self.task_docs.features.keys())
self.multiple_input = self.config.multiple_input
self.multiple_target = 0
test_doc = self.task_docs[0]
test_text = self.doc_to_text(test_doc)
test_target = self.doc_to_target(test_doc)
if self.config.doc_to_choice is not None:
test_choice = self.doc_to_choice(test_doc)
if not isinstance(test_choice, list):
eval_logger.error("doc_to_choice must return list")
else:
num_choice = len(test_choice)
if isinstance(test_text, int):
eval_logger.debug(
"doc_to_text returned an int. Assuming multiple inputs."
)
self.multiple_input = num_choice
else:
test_choice = None
if isinstance(test_target, list):
eval_logger.debug(
"doc_to_target returned a list. Assuming multiple targets."
)
self.multiple_target = len(test_target)
else:
if (isinstance(test_target, int)) and (test_choice is not None):
test_target = test_choice[test_target]
else:
test_target = str(test_target)
check_choices = test_choice if test_choice is not None else [test_target]
if self.config.doc_to_choice is not None:
for choice in check_choices:
choice_has_whitespace = choice[0].isspace()
delimiter_has_whitespace = (
self.config.target_delimiter.rstrip()
!= self.config.target_delimiter
)
if delimiter_has_whitespace and choice_has_whitespace:
eval_logger.debug(
f'Both target_delimiter "{self.config.target_delimiter}" and target choice: "{choice}" have whitespace'
)
elif (not delimiter_has_whitespace) and (not choice_has_whitespace):
eval_logger.debug(
f'Both target_delimiter "{self.config.target_delimiter}" and target choice: "{choice}" do not have whitespace, ignore if the language you are evaluating on does not require/use whitespace'
)
def download( def download(
self, dataset_kwargs:dict[str, Any] | None = None, **kwargs self, dataset_kwargs:dict[str, Any] | None = None, **kwargs
...@@ -1470,6 +1419,56 @@ class ConfigurableTask(Task): ...@@ -1470,6 +1419,56 @@ class ConfigurableTask(Task):
def task_name(self) -> str | None: def task_name(self) -> str | None:
return getattr(self.config, "task", None) return getattr(self.config, "task", None)
def runtime_checks(self, test_doc):
    """Run one document through the task's doc_to_* hooks as a sanity check.

    Records the dataset's feature names on ``self.features`` and sets
    ``self.multiple_target`` to the number of targets when ``doc_to_target``
    returns a list (0 otherwise). When ``config.doc_to_choice`` is set, it
    also logs whether the target delimiter and each choice both carry, or
    both lack, whitespace at their junction.

    Args:
        test_doc: A single document, passed unchanged to ``doc_to_text``,
            ``doc_to_target`` and (if configured) ``doc_to_choice``.
    """
    self.features: list[str] = list(self.task_docs.features.keys())
    self.multiple_target = 0

    test_text = self.doc_to_text(test_doc)
    test_target = self.doc_to_target(test_doc)

    if self.config.doc_to_choice is not None:
        test_choice = self.doc_to_choice(test_doc)
        if not isinstance(test_choice, list):
            # Reported but not raised: the remaining checks still run.
            eval_logger.error("doc_to_choice must return list")
        if isinstance(test_text, int):
            eval_logger.debug(
                "doc_to_text returned an int. Assuming multiple inputs."
            )
    else:
        test_choice = None

    if isinstance(test_target, list):
        eval_logger.debug(
            "doc_to_target returned a list. Assuming multiple targets."
        )
        self.multiple_target = len(test_target)
    elif isinstance(test_target, int) and test_choice is not None:
        # An integer target is an index into the choices.
        test_target = test_choice[test_target]
    else:
        test_target = str(test_target)

    check_choices = test_choice if test_choice is not None else [test_target]
    if self.config.doc_to_choice is not None:
        # The delimiter does not change per choice, so inspect it once.
        delimiter = self.config.target_delimiter
        delimiter_has_whitespace = delimiter.rstrip() != delimiter
        for choice in check_choices:
            # An empty choice has no leading character; treat it as having
            # no whitespace instead of failing on choice[0].
            choice_has_whitespace = bool(choice) and choice[0].isspace()
            if delimiter_has_whitespace and choice_has_whitespace:
                eval_logger.debug(
                    f'Both target_delimiter "{delimiter}" and target choice: "{choice}" have whitespace'
                )
            elif (not delimiter_has_whitespace) and (not choice_has_whitespace):
                eval_logger.debug(
                    f'Both target_delimiter "{delimiter}" and target choice: "{choice}" do not have whitespace, ignore if the language you are evaluating on does not require/use whitespace'
                )
def __repr__(self): def __repr__(self):
return ( return (
f"ConfigurableTask(task_name={getattr(self.config, 'task', None)}," f"ConfigurableTask(task_name={getattr(self.config, 'task', None)},"
...@@ -1491,7 +1490,7 @@ class MultipleChoiceTask(Task): ...@@ -1491,7 +1490,7 @@ class MultipleChoiceTask(Task):
Instance( Instance(
request_type="loglikelihood", request_type="loglikelihood",
doc=doc, doc=doc,
arguments=(ctx, " {}".format(choice)), arguments=(ctx, f" {choice}"),
idx=i, idx=i,
**kwargs, **kwargs,
) )
......
...@@ -171,7 +171,7 @@ class TaskConfig: ...@@ -171,7 +171,7 @@ class TaskConfig:
doc_to_decontamination_query: str | None = None doc_to_decontamination_query: str | None = None
gen_prefix: str | None = None gen_prefix: str | None = None
multiple_input: bool = False multiple_input: bool = False
metadata: dict | None = field( metadata: dict = field(
default_factory=dict default_factory=dict
) # by default, not used in the code. allows for users to pass arbitrary info to tasks ) # by default, not used in the code. allows for users to pass arbitrary info to tasks
......
from __future__ import annotations
import collections import collections
import fnmatch import fnmatch
import hashlib import hashlib
...@@ -12,11 +14,11 @@ from dataclasses import asdict, is_dataclass ...@@ -12,11 +14,11 @@ from dataclasses import asdict, is_dataclass
from functools import lru_cache, partial, wraps from functools import lru_cache, partial, wraps
from itertools import islice from itertools import islice
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Optional from typing import Any, Callable
import numpy as np import numpy as np
import yaml import yaml
from jinja2 import BaseLoader, Environment, StrictUndefined from jinja2 import BaseLoader, Environment, StrictUndefined, Template
SPACING = " " * 47 SPACING = " " * 47
...@@ -146,7 +148,7 @@ def sanitize_list(sub): ...@@ -146,7 +148,7 @@ def sanitize_list(sub):
return str(sub) return str(sub)
def simple_parse_args_string(args_string: Optional[str]) -> dict: def simple_parse_args_string(args_string: str | None) -> dict:
""" """
Parses something like Parses something like
args1=val1,arg2=val2 args1=val1,arg2=val2
...@@ -181,7 +183,7 @@ def group(arr, fn): ...@@ -181,7 +183,7 @@ def group(arr, fn):
# Returns a list containing all values of the source_list that # Returns a list containing all values of the source_list that
# match at least one of the patterns # match at least one of the patterns
def pattern_match(patterns, source_list): def pattern_match(patterns: list[str], source_list: list[str]) -> list[str]:
if isinstance(patterns, str): if isinstance(patterns, str):
patterns = [patterns] patterns = [patterns]
...@@ -198,7 +200,7 @@ def softmax(x) -> np.ndarray: ...@@ -198,7 +200,7 @@ def softmax(x) -> np.ndarray:
return e_x / e_x.sum() return e_x / e_x.sum()
def general_detokenize(string) -> str: def general_detokenize(string: str) -> str:
string = string.replace(" n't", "n't") string = string.replace(" n't", "n't")
string = string.replace(" )", ")") string = string.replace(" )", ")")
string = string.replace("( ", "(") string = string.replace("( ", "(")
...@@ -226,7 +228,7 @@ def sanitize_model_name(model_name: str) -> str: ...@@ -226,7 +228,7 @@ def sanitize_model_name(model_name: str) -> str:
""" """
Given the model name, returns a sanitized version of it. Given the model name, returns a sanitized version of it.
""" """
return re.sub(r"[\"<>:/\|\\?\*\[\]]+", "__", model_name) return re.sub(r"[\"<>:/|\\?*\[\]]+", "__", model_name)
def sanitize_task_name(task_name: str) -> str: def sanitize_task_name(task_name: str) -> str:
...@@ -489,7 +491,9 @@ def import_function(loader: yaml.Loader, node, yaml_path: Path): ...@@ -489,7 +491,9 @@ def import_function(loader: yaml.Loader, node, yaml_path: Path):
return function return function
def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None, mode="full"): def load_yaml_config(
yaml_path: str | None = None, yaml_config=None, yaml_dir=None, mode="full"
):
if mode == "simple": if mode == "simple":
constructor_fn = ignore_constructor constructor_fn = ignore_constructor
elif mode == "full": elif mode == "full":
...@@ -551,7 +555,7 @@ env.filters["regex_replace"] = regex_replace ...@@ -551,7 +555,7 @@ env.filters["regex_replace"] = regex_replace
@lru_cache(maxsize=128) @lru_cache(maxsize=128)
def _compile(raw: str): def _compile(raw: str) -> Template:
return env.from_string(raw) return env.from_string(raw)
...@@ -560,7 +564,13 @@ def apply_template(template: str, doc: dict) -> str: ...@@ -560,7 +564,13 @@ def apply_template(template: str, doc: dict) -> str:
return rtemplate.render(**doc) return rtemplate.render(**doc)
def create_iterator(raw_iterator, *, rank=0, world_size=1, limit=None): def create_iterator(
raw_iterator: collections.Iterator,
*,
rank: int = 0,
world_size: int = 1,
limit: int | None = None,
) -> islice:
""" """
Method for creating a (potentially) sliced and limited Method for creating a (potentially) sliced and limited
iterator from a raw document iterator. Used for splitting data iterator from a raw document iterator. Used for splitting data
......
...@@ -116,7 +116,7 @@ plugins.md034.enabled = false # no-bare-urls ...@@ -116,7 +116,7 @@ plugins.md034.enabled = false # no-bare-urls
[tool.ruff] [tool.ruff]
target-version = "py39" target-version = "py39"
lint.extend-select = ["I", "UP", "E", "C419", "F", "B", "SIM"] lint.extend-select = ["I", "UP", "E", "C419", "F", "B", "SIM", "RUF034", "W605", "FURB"]
lint.fixable = ["I001", "F401", "UP"] lint.fixable = ["I001", "F401", "UP"]
lint.ignore = ["E402", "E731", "E501", "E111", "E114", "E117", "E741"] lint.ignore = ["E402", "E731", "E501", "E111", "E114", "E117", "E741"]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment