Commit 7cef4d38 authored by Baber's avatar Baber
Browse files

move test one doc to method

parent ec767666
......@@ -656,57 +656,6 @@ class ConfigurableTask(Task):
)
self.task_docs = self.eval_docs
# Test One Doc
self.features: list[str] = list(self.task_docs.features.keys())
self.multiple_input = self.config.multiple_input
self.multiple_target = 0
test_doc = self.task_docs[0]
test_text = self.doc_to_text(test_doc)
test_target = self.doc_to_target(test_doc)
if self.config.doc_to_choice is not None:
test_choice = self.doc_to_choice(test_doc)
if not isinstance(test_choice, list):
eval_logger.error("doc_to_choice must return list")
else:
num_choice = len(test_choice)
if isinstance(test_text, int):
eval_logger.debug(
"doc_to_text returned an int. Assuming multiple inputs."
)
self.multiple_input = num_choice
else:
test_choice = None
if isinstance(test_target, list):
eval_logger.debug(
"doc_to_target returned a list. Assuming multiple targets."
)
self.multiple_target = len(test_target)
else:
if (isinstance(test_target, int)) and (test_choice is not None):
test_target = test_choice[test_target]
else:
test_target = str(test_target)
check_choices = test_choice if test_choice is not None else [test_target]
if self.config.doc_to_choice is not None:
for choice in check_choices:
choice_has_whitespace = choice[0].isspace()
delimiter_has_whitespace = (
self.config.target_delimiter.rstrip()
!= self.config.target_delimiter
)
if delimiter_has_whitespace and choice_has_whitespace:
eval_logger.debug(
f'Both target_delimiter "{self.config.target_delimiter}" and target choice: "{choice}" have whitespace'
)
elif (not delimiter_has_whitespace) and (not choice_has_whitespace):
eval_logger.debug(
f'Both target_delimiter "{self.config.target_delimiter}" and target choice: "{choice}" do not have whitespace, ignore if the language you are evaluating on does not require/use whitespace'
)
def download(
self, dataset_kwargs:dict[str, Any] | None = None, **kwargs
......@@ -1470,6 +1419,56 @@ class ConfigurableTask(Task):
def task_name(self) -> str | None:
return getattr(self.config, "task", None)
def runtime_checks(self, test_doc):
# Test One Doc
self.features: list[str] = list(self.task_docs.features.keys())
self.multiple_target = 0
test_text = self.doc_to_text(test_doc)
test_target = self.doc_to_target(test_doc)
if self.config.doc_to_choice is not None:
test_choice = self.doc_to_choice(test_doc)
if not isinstance(test_choice, list):
eval_logger.error("doc_to_choice must return list")
# else:
# num_choice = len(test_choice)
if isinstance(test_text, int):
eval_logger.debug(
"doc_to_text returned an int. Assuming multiple inputs."
)
else:
test_choice = None
if isinstance(test_target, list):
eval_logger.debug(
"doc_to_target returned a list. Assuming multiple targets."
)
self.multiple_target = len(test_target)
else:
if (isinstance(test_target, int)) and (test_choice is not None):
test_target = test_choice[test_target]
else:
test_target = str(test_target)
check_choices = test_choice if test_choice is not None else [test_target]
if self.config.doc_to_choice is not None:
for choice in check_choices:
choice_has_whitespace = choice[0].isspace()
delimiter_has_whitespace = (
self.config.target_delimiter.rstrip()
!= self.config.target_delimiter
)
if delimiter_has_whitespace and choice_has_whitespace:
eval_logger.debug(
f'Both target_delimiter "{self.config.target_delimiter}" and target choice: "{choice}" have whitespace'
)
elif (not delimiter_has_whitespace) and (not choice_has_whitespace):
eval_logger.debug(
f'Both target_delimiter "{self.config.target_delimiter}" and target choice: "{choice}" do not have whitespace, ignore if the language you are evaluating on does not require/use whitespace'
)
def __repr__(self):
return (
f"ConfigurableTask(task_name={getattr(self.config, 'task', None)},"
......@@ -1491,7 +1490,7 @@ class MultipleChoiceTask(Task):
Instance(
request_type="loglikelihood",
doc=doc,
arguments=(ctx, " {}".format(choice)),
arguments=(ctx, f" {choice}"),
idx=i,
**kwargs,
)
......
......@@ -171,7 +171,7 @@ class TaskConfig:
doc_to_decontamination_query: str | None = None
gen_prefix: str | None = None
multiple_input: bool = False
metadata: dict | None = field(
metadata: dict = field(
default_factory=dict
) # by default, not used in the code. allows for users to pass arbitrary info to tasks
......
from __future__ import annotations
import collections
import fnmatch
import hashlib
......@@ -12,11 +14,11 @@ from dataclasses import asdict, is_dataclass
from functools import lru_cache, partial, wraps
from itertools import islice
from pathlib import Path
from typing import Any, Callable, Optional
from typing import Any, Callable
import numpy as np
import yaml
from jinja2 import BaseLoader, Environment, StrictUndefined
from jinja2 import BaseLoader, Environment, StrictUndefined, Template
SPACING = " " * 47
......@@ -146,7 +148,7 @@ def sanitize_list(sub):
return str(sub)
def simple_parse_args_string(args_string: Optional[str]) -> dict:
def simple_parse_args_string(args_string: str | None) -> dict:
"""
Parses something like
args1=val1,arg2=val2
......@@ -181,7 +183,7 @@ def group(arr, fn):
# Returns a list containing all values of the source_list that
# match at least one of the patterns
def pattern_match(patterns, source_list):
def pattern_match(patterns: list[str], source_list: list[str]) -> list[str]:
if isinstance(patterns, str):
patterns = [patterns]
......@@ -198,7 +200,7 @@ def softmax(x) -> np.ndarray:
return e_x / e_x.sum()
def general_detokenize(string) -> str:
def general_detokenize(string: str) -> str:
string = string.replace(" n't", "n't")
string = string.replace(" )", ")")
string = string.replace("( ", "(")
......@@ -226,7 +228,7 @@ def sanitize_model_name(model_name: str) -> str:
"""
Given the model name, returns a sanitized version of it.
"""
return re.sub(r"[\"<>:/\|\\?\*\[\]]+", "__", model_name)
return re.sub(r"[\"<>:/|\\?*\[\]]+", "__", model_name)
def sanitize_task_name(task_name: str) -> str:
......@@ -489,7 +491,9 @@ def import_function(loader: yaml.Loader, node, yaml_path: Path):
return function
def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None, mode="full"):
def load_yaml_config(
yaml_path: str | None = None, yaml_config=None, yaml_dir=None, mode="full"
):
if mode == "simple":
constructor_fn = ignore_constructor
elif mode == "full":
......@@ -551,7 +555,7 @@ env.filters["regex_replace"] = regex_replace
@lru_cache(maxsize=128)
def _compile(raw: str):
def _compile(raw: str) -> Template:
return env.from_string(raw)
......@@ -560,7 +564,13 @@ def apply_template(template: str, doc: dict) -> str:
return rtemplate.render(**doc)
def create_iterator(raw_iterator, *, rank=0, world_size=1, limit=None):
def create_iterator(
raw_iterator: collections.Iterator,
*,
rank: int = 0,
world_size: int = 1,
limit: int | None = None,
) -> islice:
"""
Method for creating a (potentially) sliced and limited
iterator from a raw document iterator. Used for splitting data
......
......@@ -116,7 +116,7 @@ plugins.md034.enabled = false # no-bare-urls
[tool.ruff]
target-version = "py39"
lint.extend-select = ["I", "UP", "E", "C419", "F", "B", "SIM"]
lint.extend-select = ["I", "UP", "E", "C419", "F", "B", "SIM", "RUF034", "W605", "FURB"]
lint.fixable = ["I001", "F401", "UP"]
lint.ignore = ["E402", "E731", "E501", "E111", "E114", "E117", "E741"]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment