"include/vscode:/vscode.git/clone" did not exist on "19dd98c89eb40c1d250cbef0f301b7a9c9670845"
Commit f264f2e2 authored by Baber's avatar Baber
Browse files

type hints

parent 230352ce
...@@ -33,7 +33,7 @@ repos: ...@@ -33,7 +33,7 @@ repos:
hooks: hooks:
# Run the linter. # Run the linter.
- id: ruff-check - id: ruff-check
args: [ --fix ] args: [ --fix, --unsafe-fixes]
# Run the formatter. # Run the formatter.
- id: ruff-format - id: ruff-format
- repo: https://github.com/codespell-project/codespell - repo: https://github.com/codespell-project/codespell
......
...@@ -11,6 +11,7 @@ from typing import ( ...@@ -11,6 +11,7 @@ from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
Literal, Literal,
overload,
) )
import datasets import datasets
...@@ -192,7 +193,7 @@ class Task(abc.ABC): ...@@ -192,7 +193,7 @@ class Task(abc.ABC):
elif self.has_validation_docs(): elif self.has_validation_docs():
return self.validation_docs() return self.validation_docs()
else: else:
if self.config.get("num_fewshot", 0) > 0: if self.config.num_fewshot and self.config.num_fewshot > 0:
eval_logger.warning( eval_logger.warning(
f"[Task: {self.config.task}] has_training_docs and has_validation_docs are False" f"[Task: {self.config.task}] has_training_docs and has_validation_docs are False"
", using test_docs as fewshot_docs but this is not recommended." ", using test_docs as fewshot_docs but this is not recommended."
...@@ -331,7 +332,7 @@ class Task(abc.ABC): ...@@ -331,7 +332,7 @@ class Task(abc.ABC):
inst = self.construct_requests( inst = self.construct_requests(
doc=doc, doc=doc,
ctx=fewshot_ctx, ctx=fewshot_ctx,
metadata=(self.config["task"], doc_id, self.config.repeats), metadata=(self.config.task, doc_id, self.config.repeats),
apply_chat_template=apply_chat_template, apply_chat_template=apply_chat_template,
chat_template=chat_template, chat_template=chat_template,
) )
...@@ -990,9 +991,21 @@ class ConfigurableTask(Task): ...@@ -990,9 +991,21 @@ class ConfigurableTask(Task):
""" """
return doc return doc
@overload
def doc_to_text(self, doc: dict, doc_to_text: None = None) -> str | int: ...
@overload
def doc_to_text(self, doc: dict, doc_to_text: int) -> int: ...
@overload
def doc_to_text(self, doc: dict, doc_to_text: str) -> str: ...
@overload
def doc_to_text(self, doc: dict, doc_to_text: Callable[..., str]) -> str: ...
def doc_to_text( def doc_to_text(
self, doc: dict, doc_to_text: int | str | Callable[..., str] | None = None self, doc: dict, doc_to_text: int | str | Callable[..., str] | None = None
) -> str: ) -> str | int:
# if self.prompt is not None: # if self.prompt is not None:
# doc_to_text = self.prompt # doc_to_text = self.prompt
doc_to_text = doc_to_text or self.config.doc_to_text doc_to_text = doc_to_text or self.config.doc_to_text
...@@ -1025,6 +1038,25 @@ class ConfigurableTask(Task): ...@@ -1025,6 +1038,25 @@ class ConfigurableTask(Task):
print(type(doc_to_text)) print(type(doc_to_text))
raise TypeError raise TypeError
@overload
def doc_to_target(
self, doc: dict, doc_to_target: None = None
) -> int | str | list[int]: ...
@overload
def doc_to_target(self, doc: dict, doc_to_target: int) -> int: ...
@overload
def doc_to_target(self, doc: dict, doc_to_target: str) -> int | str | list[int]: ...
@overload
def doc_to_target(self, doc: dict, doc_to_target: list) -> list[int]: ...
@overload
def doc_to_target(
self, doc: dict, doc_to_target: Callable[..., int | str | list[int]]
) -> int | str | list[int]: ...
def doc_to_target(self, doc: dict, doc_to_target=None) -> int | str | list[int]: def doc_to_target(self, doc: dict, doc_to_target=None) -> int | str | list[int]:
# if self.prompt is not None: # if self.prompt is not None:
# doc_to_target = self.prompt # doc_to_target = self.prompt
...@@ -1071,6 +1103,23 @@ class ConfigurableTask(Task): ...@@ -1071,6 +1103,23 @@ class ConfigurableTask(Task):
else: else:
raise TypeError raise TypeError
@overload
def doc_to_choice(self, doc: dict, doc_to_choice: None = None) -> list[str]: ...
@overload
def doc_to_choice(self, doc: dict, doc_to_choice: str) -> list[str]: ...
@overload
def doc_to_choice(self, doc: dict, doc_to_choice: list) -> list[str]: ...
@overload
def doc_to_choice(self, doc: dict, doc_to_choice: dict) -> list[str]: ...
@overload
def doc_to_choice(
self, doc: dict, doc_to_choice: Callable[..., list[str]]
) -> list[str]: ...
def doc_to_choice( def doc_to_choice(
self, self,
doc: dict, doc: dict,
...@@ -1102,6 +1151,18 @@ class ConfigurableTask(Task): ...@@ -1102,6 +1151,18 @@ class ConfigurableTask(Task):
else: else:
raise TypeError raise TypeError
@overload
def doc_to_image(self, doc: dict, doc_to_image: None = None) -> None: ...
@overload
def doc_to_image(self, doc: dict, doc_to_image: list) -> list: ...
@overload
def doc_to_image(self, doc: dict, doc_to_image: str) -> int | str | None: ...
@overload
def doc_to_image(self, doc: dict, doc_to_image: Callable[..., Any]) -> Any: ...
def doc_to_image(self, doc: dict, doc_to_image=None) -> int | str | list | None: def doc_to_image(self, doc: dict, doc_to_image=None) -> int | str | list | None:
if doc_to_image is not None: if doc_to_image is not None:
doc_to_image = doc_to_image doc_to_image = doc_to_image
...@@ -1125,6 +1186,18 @@ class ConfigurableTask(Task): ...@@ -1125,6 +1186,18 @@ class ConfigurableTask(Task):
else: else:
return None return None
@overload
def doc_to_audio(self, doc: Any, doc_to_audio: None = None) -> None: ...
@overload
def doc_to_audio(self, doc: Any, doc_to_audio: list) -> list: ...
@overload
def doc_to_audio(self, doc: Any, doc_to_audio: str) -> int | str | None: ...
@overload
def doc_to_audio(self, doc: Any, doc_to_audio: Callable[..., Any]) -> Any: ...
def doc_to_audio(self, doc: Any, doc_to_audio=None) -> int | str | list | None: def doc_to_audio(self, doc: Any, doc_to_audio=None) -> int | str | list | None:
if doc_to_audio is not None: if doc_to_audio is not None:
doc_to_audio = doc_to_audio doc_to_audio = doc_to_audio
...@@ -1369,15 +1442,15 @@ class ConfigurableTask(Task): ...@@ -1369,15 +1442,15 @@ class ConfigurableTask(Task):
elif self.OUTPUT_TYPE == "generate_until": elif self.OUTPUT_TYPE == "generate_until":
gold = self.doc_to_target(doc) gold = self.doc_to_target(doc)
result = results[0] result = results[0]
for metric in self._metric_fn_list: for metric in self.config._metric_list:
try: try:
result_score = self._metric_fn_list[metric]( result_score = metric.fn(
references=[gold] if not isinstance(gold, list) else gold, references=[gold] if not isinstance(gold, list) else gold,
predictions=[result], predictions=[result],
**self._metric_fn_kwargs[metric], **metric.kwargs,
) )
except TypeError: # needed for now in order to use a different interface between our own metrics and HF Evaluate metrics except TypeError: # needed for now in order to use a different interface between our own metrics and HF Evaluate metrics
result_score = self._metric_fn_list[metric]([gold, result]) result_score = metric.fn([gold, result])
if isinstance(result_score, dict): if isinstance(result_score, dict):
# TODO: this handles the case where HF evaluate returns a dict. # TODO: this handles the case where HF evaluate returns a dict.
# This allows for multiple metrics to be returned from the same function # This allows for multiple metrics to be returned from the same function
......
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable, Mapping from collections.abc import Callable, Mapping
from dataclasses import dataclass from dataclasses import dataclass, field
from functools import cached_property from functools import cached_property
from typing import Any from typing import Any
...@@ -11,8 +11,8 @@ class MetricConfig: ...@@ -11,8 +11,8 @@ class MetricConfig:
"""Encapsulates information about a single metric.""" """Encapsulates information about a single metric."""
name: str name: str
fn: Callable | None = None fn: Callable
kwargs: Mapping[str, Any] | None = None kwargs: Mapping[str, Any] = field(default_factory=dict)
aggregation_fn: Callable | None = None aggregation_fn: Callable | None = None
higher_is_better: bool = True higher_is_better: bool = True
hf_evaluate: bool = False hf_evaluate: bool = False
......
...@@ -44,7 +44,7 @@ class FewshotConfig: ...@@ -44,7 +44,7 @@ class FewshotConfig:
num_fewshot: Callable[[], int] num_fewshot: Callable[[], int]
split: str | None = None split: str | None = None
sampler: str | Callable = "default" sampler: str | Callable = "default"
samples: Callable[[], list[dict]] | list[dict] | None = None samples: Callable[[], Iterable[dict]] | Iterable[dict] | None = None
process_docs: Callable[[list[dict[str, Any]]], Iterable[dict[str, Any]]] | None = ( process_docs: Callable[[list[dict[str, Any]]], Iterable[dict[str, Any]]] | None = (
None None
) )
...@@ -71,7 +71,7 @@ class FewshotConfig: ...@@ -71,7 +71,7 @@ class FewshotConfig:
def _get_raw_docs( def _get_raw_docs(
self, dataset self, dataset
) -> list[dict] | Callable[[], Iterable[dict]] | None: ) -> list[dict] | Callable[[], Iterable[dict[str, Any]]] | None:
"""Get raw documents from configured source.""" """Get raw documents from configured source."""
if self.split is not None: if self.split is not None:
return dataset[self.split] return dataset[self.split]
...@@ -425,12 +425,6 @@ class TaskConfig: ...@@ -425,12 +425,6 @@ class TaskConfig:
# Create and return TaskConfig instance # Create and return TaskConfig instance
return cls(**config_dict) return cls(**config_dict)
def __getitem__(self, item):
return getattr(self, item)
def __setitem__(self, item, value):
return setattr(self, item, value)
def to_dict(self, keep_callable: bool = False) -> dict: def to_dict(self, keep_callable: bool = False) -> dict:
def _ser(x): def _ser(x):
if isinstance(x, dict): if isinstance(x, dict):
......
import collections import collections
import fnmatch import fnmatch
import functools
import hashlib import hashlib
import importlib.util import importlib.util
import inspect import inspect
...@@ -8,10 +7,12 @@ import json ...@@ -8,10 +7,12 @@ import json
import logging import logging
import os import os
import re import re
from collections.abc import Generator
from dataclasses import asdict, is_dataclass from dataclasses import asdict, is_dataclass
from functools import lru_cache, partial, wraps
from itertools import islice from itertools import islice
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Generator, List, Optional, Tuple from typing import Any, Callable, Optional
import numpy as np import numpy as np
import yaml import yaml
...@@ -91,7 +92,7 @@ def escaped_split(text, sep_char, maxsplit=-1): ...@@ -91,7 +92,7 @@ def escaped_split(text, sep_char, maxsplit=-1):
return text return text
maxsplit = max(0, maxsplit) maxsplit = max(0, maxsplit)
return re.split(r"(?<!\\)" + sep_char, text, maxsplit) return re.split(r"(?<!\\)" + sep_char, text, maxsplit=maxsplit)
def handle_arg_string(arg): def handle_arg_string(arg):
...@@ -108,7 +109,7 @@ def handle_arg_string(arg): ...@@ -108,7 +109,7 @@ def handle_arg_string(arg):
def handle_non_serializable(o): def handle_non_serializable(o):
if isinstance(o, np.int64) or isinstance(o, np.int32): if isinstance(o, np.integer):
return int(o) return int(o)
elif isinstance(o, set): elif isinstance(o, set):
return list(o) return list(o)
...@@ -218,21 +219,21 @@ def sanitize_task_name(task_name: str) -> str: ...@@ -218,21 +219,21 @@ def sanitize_task_name(task_name: str) -> str:
return re.sub(r"\W", "_", task_name) return re.sub(r"\W", "_", task_name)
def get_latest_filename(filenames: List[str]) -> str: def get_latest_filename(filenames: list[str]) -> str:
""" """
Given a list of filenames, returns the filename with the latest datetime. Given a list of filenames, returns the filename with the latest datetime.
""" """
return max(filenames, key=lambda f: get_file_datetime(f)) return max(filenames, key=lambda f: get_file_datetime(f))
def get_results_filenames(filenames: List[str]) -> List[str]: def get_results_filenames(filenames: list[str]) -> list[str]:
""" """
Extracts filenames that correspond to aggregated results. Extracts filenames that correspond to aggregated results.
""" """
return [f for f in filenames if "/results_" in f and ".json" in f] return [f for f in filenames if "/results_" in f and ".json" in f]
def get_sample_results_filenames(filenames: List[str]) -> List[str]: def get_sample_results_filenames(filenames: list[str]) -> list[str]:
""" """
Extracts filenames that correspond to sample results. Extracts filenames that correspond to sample results.
""" """
...@@ -240,8 +241,8 @@ def get_sample_results_filenames(filenames: List[str]) -> List[str]: ...@@ -240,8 +241,8 @@ def get_sample_results_filenames(filenames: List[str]) -> List[str]:
def get_rolling_token_windows( def get_rolling_token_windows(
token_list: List[int], prefix_token: int, max_seq_len: int, context_len: int token_list: list[int], prefix_token: int, max_seq_len: int, context_len: int
) -> Generator[Tuple[List[int], List[int]], None, None]: ) -> Generator[tuple[list[int], list[int]], None, None]:
""" """
- context_len allows for a rolling window context, allowing each prediction window to potentially - context_len allows for a rolling window context, allowing each prediction window to potentially
condition on some context condition on some context
...@@ -283,8 +284,8 @@ def get_rolling_token_windows( ...@@ -283,8 +284,8 @@ def get_rolling_token_windows(
def make_disjoint_window( def make_disjoint_window(
pair: Tuple[List[int], List[int]], pair: tuple[list[int], list[int]],
) -> Tuple[List[int], List[int]]: ) -> tuple[list[int], list[int]]:
"""Takes output from get_rolling_token_windows and makes the context not overlap with the continuation""" """Takes output from get_rolling_token_windows and makes the context not overlap with the continuation"""
a, b = pair a, b = pair
return a[: len(a) - (len(b) - 1)], b return a[: len(a) - (len(b) - 1)], b
...@@ -303,7 +304,7 @@ class EnhancedJSONEncoder(json.JSONEncoder): ...@@ -303,7 +304,7 @@ class EnhancedJSONEncoder(json.JSONEncoder):
class Reorderer: class Reorderer:
def __init__(self, arr: List[Any], fn: Callable) -> None: def __init__(self, arr: list[Any], fn: Callable) -> None:
"""Reorder an array according to some function """Reorder an array according to some function
Args: Args:
...@@ -406,11 +407,11 @@ def make_table(result_dict, column: str = "results", sort_results: bool = False) ...@@ -406,11 +407,11 @@ def make_table(result_dict, column: str = "results", sort_results: bool = False)
# TODO: fix # TODO: fix
hib = "↑" hib = "↑"
v = "%.4f" % v if isinstance(v, float) else v v = f"{v:.4f}" if isinstance(v, float) else v
if m + "_stderr" + "," + f in dic: if m + "_stderr" + "," + f in dic:
se = dic[m + "_stderr" + "," + f] se = dic[m + "_stderr" + "," + f]
se = " N/A" if se == "N/A" else "%.4f" % se se = " N/A" if se == "N/A" else f"{se:.4f}"
values.append([k, version, f, n, m, hib, v, "±", se]) values.append([k, version, f, n, m, hib, v, "±", se])
else: else:
values.append([k, version, f, n, m, hib, v, "", ""]) values.append([k, version, f, n, m, hib, v, "", ""])
...@@ -431,7 +432,8 @@ def positional_deprecated(fn): ...@@ -431,7 +432,8 @@ def positional_deprecated(fn):
wrapped function, `fn`. wrapped function, `fn`.
""" """
@functools.wraps(fn) wraps(fn)
def _wrapper(*args, **kwargs): def _wrapper(*args, **kwargs):
if len(args) != 1 if inspect.ismethod(fn) else 0: if len(args) != 1 if inspect.ismethod(fn) else 0:
print( print(
...@@ -477,7 +479,7 @@ def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None, mode="full ...@@ -477,7 +479,7 @@ def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None, mode="full
if yaml_path is None: if yaml_path is None:
raise ValueError("yaml_path must be provided if mode is 'full'.") raise ValueError("yaml_path must be provided if mode is 'full'.")
# Attach yaml_path to the import function so that it can be used later # Attach yaml_path to the import function so that it can be used later
constructor_fn = functools.partial(import_function, yaml_path=Path(yaml_path)) constructor_fn = partial(import_function, yaml_path=Path(yaml_path))
loader = yaml.CLoader if yaml.__with_libyaml__ else yaml.FullLoader loader = yaml.CLoader if yaml.__with_libyaml__ else yaml.FullLoader
# Add the import_function constructor to the YAML loader # Add the import_function constructor to the YAML loader
...@@ -526,13 +528,18 @@ def regex_replace(string, pattern, repl, count: int = 0): ...@@ -526,13 +528,18 @@ def regex_replace(string, pattern, repl, count: int = 0):
env = Environment( env = Environment(
loader=BaseLoader, undefined=StrictUndefined, keep_trailing_newline=True loader=BaseLoader(), undefined=StrictUndefined, keep_trailing_newline=True
) )
env.filters["regex_replace"] = regex_replace env.filters["regex_replace"] = regex_replace
@lru_cache(maxsize=128)
def _compile(raw: str):
return env.from_string(raw)
def apply_template(template: str, doc: dict) -> str: def apply_template(template: str, doc: dict) -> str:
rtemplate = env.from_string(template) rtemplate = _compile(template)
return rtemplate.render(**doc) return rtemplate.render(**doc)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment