Commit f264f2e2 authored by Baber's avatar Baber
Browse files

type hints

parent 230352ce
......@@ -33,7 +33,7 @@ repos:
hooks:
# Run the linter.
- id: ruff-check
args: [ --fix ]
args: [ --fix, --unsafe-fixes]
# Run the formatter.
- id: ruff-format
- repo: https://github.com/codespell-project/codespell
......
......@@ -11,6 +11,7 @@ from typing import (
TYPE_CHECKING,
Any,
Literal,
overload,
)
import datasets
......@@ -192,7 +193,7 @@ class Task(abc.ABC):
elif self.has_validation_docs():
return self.validation_docs()
else:
if self.config.get("num_fewshot", 0) > 0:
if self.config.num_fewshot and self.config.num_fewshot > 0:
eval_logger.warning(
f"[Task: {self.config.task}] has_training_docs and has_validation_docs are False"
", using test_docs as fewshot_docs but this is not recommended."
......@@ -331,7 +332,7 @@ class Task(abc.ABC):
inst = self.construct_requests(
doc=doc,
ctx=fewshot_ctx,
metadata=(self.config["task"], doc_id, self.config.repeats),
metadata=(self.config.task, doc_id, self.config.repeats),
apply_chat_template=apply_chat_template,
chat_template=chat_template,
)
......@@ -990,9 +991,21 @@ class ConfigurableTask(Task):
"""
return doc
@overload
def doc_to_text(self, doc: dict, doc_to_text: None = None) -> str | int: ...
@overload
def doc_to_text(self, doc: dict, doc_to_text: int) -> int: ...
@overload
def doc_to_text(self, doc: dict, doc_to_text: str) -> str: ...
@overload
def doc_to_text(self, doc: dict, doc_to_text: Callable[..., str]) -> str: ...
def doc_to_text(
self, doc: dict, doc_to_text: int | str | Callable[..., str] | None = None
) -> str:
) -> str | int:
# if self.prompt is not None:
# doc_to_text = self.prompt
doc_to_text = doc_to_text or self.config.doc_to_text
......@@ -1025,6 +1038,25 @@ class ConfigurableTask(Task):
print(type(doc_to_text))
raise TypeError
@overload
def doc_to_target(
self, doc: dict, doc_to_target: None = None
) -> int | str | list[int]: ...
@overload
def doc_to_target(self, doc: dict, doc_to_target: int) -> int: ...
@overload
def doc_to_target(self, doc: dict, doc_to_target: str) -> int | str | list[int]: ...
@overload
def doc_to_target(self, doc: dict, doc_to_target: list) -> list[int]: ...
@overload
def doc_to_target(
self, doc: dict, doc_to_target: Callable[..., int | str | list[int]]
) -> int | str | list[int]: ...
def doc_to_target(self, doc: dict, doc_to_target=None) -> int | str | list[int]:
# if self.prompt is not None:
# doc_to_target = self.prompt
......@@ -1071,6 +1103,23 @@ class ConfigurableTask(Task):
else:
raise TypeError
@overload
def doc_to_choice(self, doc: dict, doc_to_choice: None = None) -> list[str]: ...
@overload
def doc_to_choice(self, doc: dict, doc_to_choice: str) -> list[str]: ...
@overload
def doc_to_choice(self, doc: dict, doc_to_choice: list) -> list[str]: ...
@overload
def doc_to_choice(self, doc: dict, doc_to_choice: dict) -> list[str]: ...
@overload
def doc_to_choice(
self, doc: dict, doc_to_choice: Callable[..., list[str]]
) -> list[str]: ...
def doc_to_choice(
self,
doc: dict,
......@@ -1102,6 +1151,18 @@ class ConfigurableTask(Task):
else:
raise TypeError
@overload
def doc_to_image(self, doc: dict, doc_to_image: None = None) -> None: ...
@overload
def doc_to_image(self, doc: dict, doc_to_image: list) -> list: ...
@overload
def doc_to_image(self, doc: dict, doc_to_image: str) -> int | str | None: ...
@overload
def doc_to_image(self, doc: dict, doc_to_image: Callable[..., Any]) -> Any: ...
def doc_to_image(self, doc: dict, doc_to_image=None) -> int | str | list | None:
if doc_to_image is not None:
doc_to_image = doc_to_image
......@@ -1125,6 +1186,18 @@ class ConfigurableTask(Task):
else:
return None
@overload
def doc_to_audio(self, doc: Any, doc_to_audio: None = None) -> None: ...
@overload
def doc_to_audio(self, doc: Any, doc_to_audio: list) -> list: ...
@overload
def doc_to_audio(self, doc: Any, doc_to_audio: str) -> int | str | None: ...
@overload
def doc_to_audio(self, doc: Any, doc_to_audio: Callable[..., Any]) -> Any: ...
def doc_to_audio(self, doc: Any, doc_to_audio=None) -> int | str | list | None:
if doc_to_audio is not None:
doc_to_audio = doc_to_audio
......@@ -1369,15 +1442,15 @@ class ConfigurableTask(Task):
elif self.OUTPUT_TYPE == "generate_until":
gold = self.doc_to_target(doc)
result = results[0]
for metric in self._metric_fn_list:
for metric in self.config._metric_list:
try:
result_score = self._metric_fn_list[metric](
result_score = metric.fn(
references=[gold] if not isinstance(gold, list) else gold,
predictions=[result],
**self._metric_fn_kwargs[metric],
**metric.kwargs,
)
except TypeError: # needed for now in order to use a different interface between our own metrics and HF Evaluate metrics
result_score = self._metric_fn_list[metric]([gold, result])
result_score = metric.fn([gold, result])
if isinstance(result_score, dict):
# TODO: this handles the case where HF evaluate returns a dict.
# This allows for multiple metrics to be returned from the same function
......
from __future__ import annotations
from collections.abc import Callable, Mapping
from dataclasses import dataclass
from dataclasses import dataclass, field
from functools import cached_property
from typing import Any
......@@ -11,8 +11,8 @@ class MetricConfig:
"""Encapsulates information about a single metric."""
name: str
fn: Callable | None = None
kwargs: Mapping[str, Any] | None = None
fn: Callable
kwargs: Mapping[str, Any] = field(default_factory=dict)
aggregation_fn: Callable | None = None
higher_is_better: bool = True
hf_evaluate: bool = False
......
......@@ -44,7 +44,7 @@ class FewshotConfig:
num_fewshot: Callable[[], int]
split: str | None = None
sampler: str | Callable = "default"
samples: Callable[[], list[dict]] | list[dict] | None = None
samples: Callable[[], Iterable[dict]] | Iterable[dict] | None = None
process_docs: Callable[[list[dict[str, Any]]], Iterable[dict[str, Any]]] | None = (
None
)
......@@ -71,7 +71,7 @@ class FewshotConfig:
def _get_raw_docs(
self, dataset
) -> list[dict] | Callable[[], Iterable[dict]] | None:
) -> list[dict] | Callable[[], Iterable[dict[str, Any]]] | None:
"""Get raw documents from configured source."""
if self.split is not None:
return dataset[self.split]
......@@ -425,12 +425,6 @@ class TaskConfig:
# Create and return TaskConfig instance
return cls(**config_dict)
def __getitem__(self, item):
return getattr(self, item)
def __setitem__(self, item, value):
return setattr(self, item, value)
def to_dict(self, keep_callable: bool = False) -> dict:
def _ser(x):
if isinstance(x, dict):
......
import collections
import fnmatch
import functools
import hashlib
import importlib.util
import inspect
......@@ -8,10 +7,12 @@ import json
import logging
import os
import re
from collections.abc import Generator
from dataclasses import asdict, is_dataclass
from functools import lru_cache, partial, wraps
from itertools import islice
from pathlib import Path
from typing import Any, Callable, Generator, List, Optional, Tuple
from typing import Any, Callable, Optional
import numpy as np
import yaml
......@@ -91,7 +92,7 @@ def escaped_split(text, sep_char, maxsplit=-1):
return text
maxsplit = max(0, maxsplit)
return re.split(r"(?<!\\)" + sep_char, text, maxsplit)
return re.split(r"(?<!\\)" + sep_char, text, maxsplit=maxsplit)
def handle_arg_string(arg):
......@@ -108,7 +109,7 @@ def handle_arg_string(arg):
def handle_non_serializable(o):
if isinstance(o, np.int64) or isinstance(o, np.int32):
if isinstance(o, np.integer):
return int(o)
elif isinstance(o, set):
return list(o)
......@@ -218,21 +219,21 @@ def sanitize_task_name(task_name: str) -> str:
return re.sub(r"\W", "_", task_name)
def get_latest_filename(filenames: List[str]) -> str:
def get_latest_filename(filenames: list[str]) -> str:
"""
Given a list of filenames, returns the filename with the latest datetime.
"""
return max(filenames, key=lambda f: get_file_datetime(f))
def get_results_filenames(filenames: List[str]) -> List[str]:
def get_results_filenames(filenames: list[str]) -> list[str]:
"""
Extracts filenames that correspond to aggregated results.
"""
return [f for f in filenames if "/results_" in f and ".json" in f]
def get_sample_results_filenames(filenames: List[str]) -> List[str]:
def get_sample_results_filenames(filenames: list[str]) -> list[str]:
"""
Extracts filenames that correspond to sample results.
"""
......@@ -240,8 +241,8 @@ def get_sample_results_filenames(filenames: List[str]) -> List[str]:
def get_rolling_token_windows(
token_list: List[int], prefix_token: int, max_seq_len: int, context_len: int
) -> Generator[Tuple[List[int], List[int]], None, None]:
token_list: list[int], prefix_token: int, max_seq_len: int, context_len: int
) -> Generator[tuple[list[int], list[int]], None, None]:
"""
- context_len allows for a rolling window context, allowing each prediction window to potentially
condition on some context
......@@ -283,8 +284,8 @@ def get_rolling_token_windows(
def make_disjoint_window(
pair: Tuple[List[int], List[int]],
) -> Tuple[List[int], List[int]]:
pair: tuple[list[int], list[int]],
) -> tuple[list[int], list[int]]:
"""Takes output from get_rolling_token_windows and makes the context not overlap with the continuation"""
a, b = pair
return a[: len(a) - (len(b) - 1)], b
......@@ -303,7 +304,7 @@ class EnhancedJSONEncoder(json.JSONEncoder):
class Reorderer:
def __init__(self, arr: List[Any], fn: Callable) -> None:
def __init__(self, arr: list[Any], fn: Callable) -> None:
"""Reorder an array according to some function
Args:
......@@ -406,11 +407,11 @@ def make_table(result_dict, column: str = "results", sort_results: bool = False)
# TODO: fix
hib = "↑"
v = "%.4f" % v if isinstance(v, float) else v
v = f"{v:.4f}" if isinstance(v, float) else v
if m + "_stderr" + "," + f in dic:
se = dic[m + "_stderr" + "," + f]
se = " N/A" if se == "N/A" else "%.4f" % se
se = " N/A" if se == "N/A" else f"{se:.4f}"
values.append([k, version, f, n, m, hib, v, "±", se])
else:
values.append([k, version, f, n, m, hib, v, "", ""])
......@@ -431,7 +432,8 @@ def positional_deprecated(fn):
wrapped function, `fn`.
"""
@functools.wraps(fn)
wraps(fn)
def _wrapper(*args, **kwargs):
if len(args) != 1 if inspect.ismethod(fn) else 0:
print(
......@@ -477,7 +479,7 @@ def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None, mode="full
if yaml_path is None:
raise ValueError("yaml_path must be provided if mode is 'full'.")
# Attach yaml_path to the import function so that it can be used later
constructor_fn = functools.partial(import_function, yaml_path=Path(yaml_path))
constructor_fn = partial(import_function, yaml_path=Path(yaml_path))
loader = yaml.CLoader if yaml.__with_libyaml__ else yaml.FullLoader
# Add the import_function constructor to the YAML loader
......@@ -526,13 +528,18 @@ def regex_replace(string, pattern, repl, count: int = 0):
env = Environment(
loader=BaseLoader, undefined=StrictUndefined, keep_trailing_newline=True
loader=BaseLoader(), undefined=StrictUndefined, keep_trailing_newline=True
)
env.filters["regex_replace"] = regex_replace
@lru_cache(maxsize=128)
def _compile(raw: str):
return env.from_string(raw)
def apply_template(template: str, doc: dict) -> str:
rtemplate = env.from_string(template)
rtemplate = _compile(template)
return rtemplate.render(**doc)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment