Commit 1768fd3b authored by Baber

ruff rules; types

parent f650197a
@@ -32,10 +32,9 @@ repos:
     rev: v0.12.2
     hooks:
       # Run the linter.
-      - id: ruff
-        args:
-          - --fix
+      - id: ruff-check
+        args: [ --fix ]
       # Run the formatter.
       - id: ruff-format
   - repo: https://github.com/codespell-project/codespell
     rev: v2.4.1
......
 from dataclasses import asdict, dataclass
 from inspect import getsource
-from typing import Any, Callable, List, Optional, Union
+from typing import Any, Callable, Optional, Union
 @dataclass
 class AggMetricConfig(dict):
     metric: Optional[str] = None
     aggregation: Optional[str] = "mean"
-    weight_by_size: Optional[str] = False
+    weight_by_size: bool = False
     # list of filter names which should be incorporated into the aggregated metric.
     filter_list: Optional[Union[str, list]] = "none"
@@ -27,7 +27,7 @@ class GroupConfig(dict):
     group_alias: Optional[str] = None
     task: Optional[Union[str, list]] = None
     aggregate_metric_list: Optional[
-        Union[List[AggMetricConfig], AggMetricConfig, dict]
+        Union[list[AggMetricConfig], AggMetricConfig, dict]
     ] = None
     version: Optional[str] = None
     metadata: Optional[dict] = (
......
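Note: `weight_by_size` is a boolean flag, so `bool = False` is the accurate annotation; the `Optional[str] = False` it replaces mistyped the field. Below is a minimal standalone sketch of the same dataclass-config pattern with the corrected annotations; the `Toy*` classes are illustrative stand-ins, not the library's real definitions.

```python
from dataclasses import asdict, dataclass
from typing import Optional, Union


# Toy stand-ins mirroring the annotations in the diff above; not the real classes.
@dataclass
class ToyAggMetricConfig:
    metric: Optional[str] = None
    aggregation: Optional[str] = "mean"
    weight_by_size: bool = False  # a plain flag, previously mistyped as Optional[str]
    filter_list: Union[str, list] = "none"


@dataclass
class ToyGroupConfig:
    group: Optional[str] = None
    # Builtin generics (list[...]) replace typing.List, matching the Python 3.9+ target.
    aggregate_metric_list: Optional[
        Union[list[ToyAggMetricConfig], ToyAggMetricConfig, dict]
    ] = None


cfg = ToyGroupConfig(
    group="demo_group",
    aggregate_metric_list=[ToyAggMetricConfig(metric="acc", weight_by_size=True)],
)
print(asdict(cfg))
```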
+from __future__ import annotations
+
 import logging
 import math
 import os
 import random
 import re
 import string
-from collections.abc import Iterable
-from typing import Callable, List, Optional, Sequence, TypeVar
+from collections.abc import Callable, Iterable, Sequence
+from typing import Generic, TypeVar
 import numpy as np
@@ -31,7 +33,7 @@ def nanmean(arr: list[float]) -> float:
 @register_aggregation("mean")
-def mean(arr: list[float]) -> float:
+def mean(arr: Sequence[float]) -> float:
     return sum(arr) / len(arr)
@@ -70,7 +72,7 @@ def f1_score(items):
 @register_aggregation("matthews_corrcoef")
-def matthews_corrcoef(items):
+def matthews_corrcoef(items: Iterable[tuple[int, int] | tuple[str, str]]) -> float:
     from sklearn.metrics import matthews_corrcoef
     unzipped_list = list(zip(*items))
@@ -80,7 +82,7 @@ def matthews_corrcoef(items):
 @register_aggregation("bleu")
-def bleu(items):
+def bleu(items: Iterable[tuple[str, str]]):
     """The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric
     for evaluating a generated sentence to a reference sentence. It counts matching
     n-grams in the candidate translation to n-grams in the reference text, where
@@ -117,7 +119,7 @@ def chrf(items):
 @register_aggregation("ter")
-def ter(items):
+def ter(items: Iterable[tuple[str, str]]):
     """Translation Error Rate is an error metric for machine translation that
     measures the number of edits required to change a system output into one
     of the references
@@ -135,7 +137,9 @@ def ter(items):
 @register_aggregation("brier_score")
-def brier_score(items):  # This is a passthrough function
+def brier_score(
+    items: Iterable[tuple[str, float]],
+):  # This is a passthrough function
     gold, predictions = list(zip(*items))
     bs, num_class = np.array(predictions).shape
@@ -203,8 +207,8 @@ def acc_mutual_info_fn(items):  # This is a passthrough function
 # See the License for the specific language governing permissions and
 # limitations under the License.
 def exact_match_hf_evaluate(
-    predictions,
-    references,
+    predictions: Iterable[str],
+    references: Iterable[str],
     regexes_to_ignore=None,
     ignore_case=False,
     ignore_punctuation=False,
@@ -266,7 +270,7 @@ def perplexity_fn(items):  # This is a passthrough function
     output_type="loglikelihood_rolling",
     aggregation="weighted_perplexity",
 )
-def word_perplexity_fn(items):  # This is a passthrough function
+def word_perplexity_fn(items: T) -> T:  # This is a passthrough function
     return items
@@ -276,7 +280,7 @@ def word_perplexity_fn(items):  # This is a passthrough function
     output_type="loglikelihood_rolling",
     aggregation="weighted_perplexity",
 )
-def byte_perplexity_fn(items):  # This is a passthrough function
+def byte_perplexity_fn(items: T) -> T:  # This is a passthrough function
     return items
@@ -286,7 +290,7 @@ def byte_perplexity_fn(items):  # This is a passthrough function
     output_type="loglikelihood_rolling",
     aggregation="bits_per_byte",
 )
-def bits_per_byte_fn(items):  # This is a passthrough function
+def bits_per_byte_fn(items: T) -> T:  # This is a passthrough function
     return items
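Note: the passthrough aggregations return exactly what they receive, which the new `items: T -> T` signatures capture; the module already uses a `T` type variable elsewhere (e.g. `Sequence[T]` in the stderr helpers), so a module-level `T = TypeVar(...)` is assumed. A minimal self-contained sketch of the pattern, with a hypothetical `passthrough` name rather than the library's registered functions:

```python
from __future__ import annotations

from typing import TypeVar

T = TypeVar("T")


def passthrough(items: T) -> T:
    """Identity aggregation: the real statistic is computed later by the caller.

    Annotating with a TypeVar preserves the input type for type checkers, e.g. a
    list of (loglikelihood, token_count) tuples stays exactly that type.
    """
    return items


pairs = [(-12.3, 7), (-4.1, 3)]
same_pairs = passthrough(pairs)
assert same_pairs is pairs  # nothing is aggregated here
```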
@@ -295,7 +299,7 @@ def pop_stddev(arr):
     return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / len(arr))
-def sample_stddev(arr: Sequence[T]) -> float:
+def sample_stddev(arr: Sequence[float]) -> float:
     mu = mean(arr)
     return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / (len(arr) - 1))
@@ -416,7 +420,7 @@ def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
     return max(scores_for_ground_truths)
-def weighted_mean(items: List[tuple[float, float]]) -> float:
+def weighted_mean(items: list[tuple[float, float]]) -> float:
     a, b = zip(*items)
     return sum(a) / sum(b)
@@ -427,15 +431,15 @@ def is_non_str_iterable(obj):
 def _sacreformat(refs, preds):
     """Format refs and preds for sacrebleu corpus calculation. It is very particular"""
-    # Sacrebleu expects (List[str], List[List[str])
+    # Sacrebleu expects (list[str], list[list[str])
     # e.g. sacrebleu.corpus_bleu([pred_t], [[ref1_stream], [ref2_stream], ...])
     # Note [ref1_stream] is the first reference for each pred.
     # So lists are size N and (M, N) for N preds and M possible refs for each pred
     # This is a different order of dimensions that I would expect
-    # We expect refs to be List[str] or List[List[str]], the outer list corresponding to preds
-    # Must become List[List[str]] with the inner list corresponding to preds
+    # We expect refs to be list[str] or list[list[str]], the outer list corresponding to preds
+    # Must become list[list[str]] with the inner list corresponding to preds
     if not is_non_str_iterable(refs):
         refs = list(refs)
     if not is_non_str_iterable(refs[0]):
@@ -443,7 +447,7 @@ def _sacreformat(refs, preds):
         refs = list(zip(*refs))
     # Note the number of refs in each ref list much match the number of preds
-    # We expect preds to be List[str] or List[List[str]]. Must become List[str]
+    # We expect preds to be list[str] or list[list[str]]. Must become list[str]
     if not is_non_str_iterable(preds):
         preds = list(preds)
     if is_non_str_iterable(preds[0]):
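Note: the shape juggling those comments describe is easiest to see on a tiny example. The sketch below is standalone toy data, not the `_sacreformat` function itself: two predictions with two references each arrive as an N x M structure and must be transposed into M reference streams of length N before a sacrebleu-style corpus scorer can consume them.

```python
# Two predictions, each with two candidate references (shape N x M: N preds, M refs per pred).
preds = ["the cat sat", "a dog barked"]
refs_per_pred = [
    ["the cat sat down", "a cat was sitting"],
    ["the dog barked", "a dog was barking"],
]

# Corpus scorers want M reference "streams", each of length N:
# stream k holds the k-th reference for every prediction.
ref_streams = [list(stream) for stream in zip(*refs_per_pred)]

assert ref_streams == [
    ["the cat sat down", "the dog barked"],
    ["a cat was sitting", "a dog was barking"],
]
# sacrebleu.corpus_bleu(preds, ref_streams) would then score all predictions at once.
```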
@@ -456,7 +460,7 @@ def _sacreformat(refs, preds):
 # stderr stuff
-class _bootstrap_internal:
+class _bootstrap_internal(Generic[T]):
     """
     Pool worker: `(i, xs)` → `n` bootstrap replicates
     of `f(xs)`using a RNG seeded with `i`.
@@ -539,7 +543,7 @@ def bootstrap_stderr(
 def stderr_for_metric(
     metric: Callable[[Sequence[T]], float], bootstrap_iters: int
-) -> Optional[Callable[[Sequence[T]], float]]:
+) -> Callable[[Sequence[T]], float] | None:
     """
     Return a function that estimates the standard error of `metric(xs)`.
@@ -569,10 +573,10 @@ def stderr_for_metric(
     stderr = {mean: mean_stderr, acc_all: acc_all_stderr}
-    return stderr.get(metric, None)
+    return stderr.get(metric)

-def pooled_sample_stderr(stderrs: List[float], sizes: List[int]):
+def pooled_sample_stderr(stderrs: list[float], sizes: list[int]):
     # Used to aggregate bootstrapped stderrs across subtasks in a group,
     # when we are weighting by the size of each subtask.
     #
@@ -590,7 +594,7 @@ def pooled_sample_stderr(stderrs: List[float], sizes: List[int]):
     return np.sqrt(pooled_sample_var / sum(sizes))
-def combined_sample_stderr(stderrs: List[float], sizes: List[int], metrics=None):
+def combined_sample_stderr(stderrs: list[float], sizes: list[int], metrics=None):
     assert metrics is not None, (
         "Need to pass a list of each subtask's metric for this stderr aggregation"
     )
@@ -622,7 +626,9 @@ def combined_sample_stderr(stderrs: List[float], sizes: List[int], metrics=None)
     return np.sqrt(variance)
-def aggregate_subtask_metrics(metrics, sizes, weight_by_size=True):
+def aggregate_subtask_metrics(
+    metrics: list[float], sizes: list[float], weight_by_size: bool = True
+):
     # A helper function that is used to aggregate
     # subtask scores cross-task.
     # TODO: does not hold for non-mean aggregations
@@ -631,4 +637,4 @@ def aggregate_subtask_metrics(metrics, sizes, weight_by_size=True):
     assert len(metrics) == len(sizes)
-    return sum([metric * size for metric, size in zip(metrics, sizes)]) / sum(sizes)
+    return sum(metric * size for metric, size in zip(metrics, sizes)) / sum(sizes)
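Note: the size-weighted aggregation that `aggregate_subtask_metrics` performs (now written as a generator expression instead of `sum([...])`) is just a weighted mean over subtasks. A small worked sketch with toy numbers:

```python
# Per-subtask accuracies and the number of documents in each subtask.
metrics = [0.80, 0.60, 0.90]
sizes = [100, 50, 10]

# Size-weighted mean: larger subtasks pull the group score harder.
weighted = sum(m * n for m, n in zip(metrics, sizes)) / sum(sizes)

# Unweighted alternative (what weight_by_size=False would correspond to).
unweighted = sum(metrics) / len(metrics)

print(f"weighted={weighted:.4f}")      # 0.7438  (= 119 / 160)
print(f"unweighted={unweighted:.4f}")  # 0.7667
```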
@@ -1053,7 +1053,9 @@ class ConfigurableTask(Task):
             print(type(doc_to_text))
             raise TypeError

-    def doc_to_target(self, doc: dict, doc_to_target=None) -> Union[int, str, list]:
+    def doc_to_target(
+        self, doc: dict, doc_to_target=None
+    ) -> Union[int, str, list[int]]:
         # if self.prompt is not None:
         #     doc_to_target = self.prompt
         if doc_to_target is not None:
@@ -1100,7 +1102,9 @@ class ConfigurableTask(Task):
             raise TypeError

     def doc_to_choice(
-        self, doc: dict, doc_to_choice: Union[str, list, dict, None] = None
+        self,
+        doc: dict,
+        doc_to_choice: Union[str, list, dict, Callable[..., list[str]], None] = None,
     ) -> List[str]:
         # if self.prompt is not None:
         #     doc_to_choice = self.prompt
@@ -1123,8 +1127,8 @@ class ConfigurableTask(Task):
             return list(doc_to_choice.values())
         elif callable(doc_to_choice):
            return doc_to_choice(doc)
-        elif hasattr(doc_to_choice, "get_answer_choices_list"):
-            return doc_to_choice.get_answer_choices_list(doc)
+        # elif hasattr(doc_to_choice, "get_answer_choices_list"):
+        #     return doc_to_choice.get_answer_choices_list(doc)
         else:
             raise TypeError
@@ -1333,6 +1337,8 @@ class ConfigurableTask(Task):
                 raise ValueError
             # and this stores our "regular" conditional loglikelihoods
             lls = lls[: len(choices)]
+        else:
+            lls_unconditional = None

         pred = np.argmax(lls)
         pred_norm = np.argmax(lls / completion_len)
@@ -1390,6 +1396,9 @@ class ConfigurableTask(Task):
         }

         if "acc_mutual_info" in use_metric:
+            assert lls_unconditional is not None, (
+                "lls_unconditional should not be None if acc_mutual_info is in use_metric"
+            )
             lls_mutual_info = [
                 ll_c - ll_u for ll_c, ll_u in zip(lls, lls_unconditional)
             ]
......
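Note: the new `else: lls_unconditional = None` branch plus the assert make the mutual-information path explicit: `acc_mutual_info` scores each choice by log P(choice | context) - log P(choice), so the unconditional loglikelihoods must exist whenever that metric is requested. A self-contained sketch of the selection rule with toy numbers, not the task's real request plumbing:

```python
import numpy as np

# Conditional loglikelihoods log P(choice | context) for 3 answer choices.
lls = np.array([-4.2, -3.9, -5.1])
# Unconditional loglikelihoods log P(choice) of the same strings, no context.
lls_unconditional = np.array([-3.0, -1.2, -1.0])

pred = int(np.argmax(lls))  # plain accuracy picks choice 1 here

# acc_mutual_info picks argmax of log P(choice | context) - log P(choice),
# discounting choices that are likely regardless of the context.
lls_mutual_info = lls - lls_unconditional
pred_mutual_info = int(np.argmax(lls_mutual_info))  # picks choice 0 (-1.2 vs -2.7 vs -4.1)

print(pred, pred_mutual_info)
```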
@@ -3,8 +3,8 @@ from typing import Any, Callable, Union
 def serialize_callable(
-    value: Union[Callable, str], keep_callable=False
-) -> Union[Callable, str]:
+    value: Union[Callable[..., Any], str], keep_callable=False
+) -> Union[Callable[..., Any], str]:
     """Serializes a given function or string.

     If 'keep_callable' is True, the original callable is returned.
......
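Note: the `Callable[..., Any]` annotations only tighten the signature; behaviorally the helper still either keeps the callable or falls back to a string form, per its docstring. A rough standalone sketch of that keep-or-stringify idea, assuming `inspect.getsource` as the string fallback; the `serialize_callable_sketch` and `doc_to_text` names are hypothetical, not the library's implementation.

```python
from inspect import getsource
from typing import Any, Callable, Union


def serialize_callable_sketch(
    value: Union[Callable[..., Any], str], keep_callable: bool = False
) -> Union[Callable[..., Any], str]:
    """Rough illustration: keep the callable itself, or fall back to a string form."""
    if keep_callable:
        return value
    if callable(value):
        try:
            # Source text is handy for logging configs reproducibly.
            return getsource(value)
        except (OSError, TypeError):
            # Builtins or interactively defined functions may have no retrievable source.
            return getattr(value, "__name__", str(value))
    return value  # already a string


def doc_to_text(doc: dict) -> str:
    return doc["question"]


print(serialize_callable_sketch(doc_to_text)[:20])
print(serialize_callable_sketch(doc_to_text, keep_callable=True) is doc_to_text)
```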
+from __future__ import annotations
+
 import itertools
 import json
 import logging
@@ -5,7 +7,7 @@ import os
 import random
 import time
 from collections import defaultdict
-from typing import TYPE_CHECKING, List, Optional, Union
+from typing import TYPE_CHECKING, Any, List, Optional, Union

 import numpy as np
 import torch
@@ -49,7 +51,7 @@ eval_logger = logging.getLogger(__name__)
 @positional_deprecated
 def simple_evaluate(
     model,
-    model_args: Optional[Union[str, dict]] = None,
+    model_args: Optional[Union[str, dict[str, Any]]] = None,
     tasks: Optional[List[Union[str, dict, object]]] = None,
     num_fewshot: Optional[int] = None,
     batch_size: Optional[Union[int, str]] = None,
@@ -420,7 +422,7 @@ def simple_evaluate(
 def evaluate(
     lm: "LM",
     task_dict,
-    limit: Optional[int] = None,
+    limit: int | float | None = None,
     samples: Optional[dict] = None,
     cache_requests: bool = False,
     rewrite_requests_cache: bool = False,
......
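Note: `limit: int | float | None` is PEP 604 union syntax. With the project targeting Python 3.9, it only stays legal at runtime because the file now starts with `from __future__ import annotations`, which keeps annotations unevaluated. A small sketch of why the future import matters on 3.9; `run_eval` is a hypothetical function, not the evaluator itself.

```python
# Without this import, `int | float | None` in an annotation raises TypeError on
# Python 3.9: the union would be evaluated eagerly, and builtin types only support
# the `|` operator from 3.10 onward. With it, annotations stay as strings.
from __future__ import annotations


def run_eval(limit: int | float | None = None) -> str:
    # A fractional limit could mean "use this share of the docs", an int an absolute cap.
    if limit is None:
        return "all docs"
    if isinstance(limit, float):
        return f"{limit:.0%} of docs"
    return f"first {limit} docs"


print(run_eval())          # all docs
print(run_eval(0.1))       # 10% of docs
print(run_eval(limit=50))  # first 50 docs
```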
@@ -107,16 +107,19 @@ plugins.md028.enabled = false # no-blanks-blockquote
 plugins.md029.allow_extended_start_values = true # ol-prefix
 plugins.md034.enabled = false # no-bare-urls

-[tool.ruff.lint]
-extend-select = ["I", "W605"]
+[tool.ruff]
+target-version = "py39"
+lint.extend-select = ["I", "UP", "E", "C419", "F", "B", "SIM"]
+lint.ignore = ["E402", "E731", "E501", "E111", "E114", "E117"]

 [tool.ruff.lint.isort]
+combine-as-imports = true
 lines-after-imports = 2
 known-first-party = ["lm_eval"]

 [tool.ruff.lint.extend-per-file-ignores]
 "__init__.py" = ["F401","F402","F403"]
+"utils.py" = ["F401"]

 [dependency-groups]
 dev = [
......
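Note: the newly selected rule families map directly onto the code changes in this commit: `UP` (pyupgrade) drives the `List` -> `list` rewrites, `C419` flags an unnecessary list built inside `sum(...)`, `B`/`SIM`/`E`/`F` add bugbear, simplification, pycodestyle and pyflakes checks, and `I` keeps imports sorted. A small before/after sketch of code the new rules would rewrite; the `total` function is a toy example, not from the repo.

```python
# Before: `def total(scores: List[float]) -> float: return sum([s * 2 for s in scores])`
# trips UP006 (prefer builtin `list` over typing.List) and C419 (unnecessary list
# comprehension inside sum()); the ruff-fixed version looks like this.
def total(scores: list[float]) -> float:
    # A generator expression avoids building an intermediate list just to sum it.
    return sum(s * 2 for s in scores)


print(total([1.0, 2.5, 3.0]))  # 13.0
```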