"src/vscode:/vscode.git/clone" did not exist on "836974229bdf0e2d329bdfdb0f9c4920eae224b6"
Commit 1768fd3b authored by Baber

ruff rules; types

parent f650197a
......@@ -32,10 +32,9 @@ repos:
rev: v0.12.2
hooks:
# Run the linter.
- id: ruff
args:
- --fix
# Run the formatter.
- id: ruff-check
args: [ --fix ]
# Run the formatter.
- id: ruff-format
- repo: https://github.com/codespell-project/codespell
rev: v2.4.1
......
from dataclasses import asdict, dataclass
from inspect import getsource
from typing import Any, Callable, List, Optional, Union
from typing import Any, Callable, Optional, Union
@dataclass
class AggMetricConfig(dict):
metric: Optional[str] = None
aggregation: Optional[str] = "mean"
weight_by_size: Optional[str] = False
weight_by_size: bool = False
# list of filter names which should be incorporated into the aggregated metric.
filter_list: Optional[Union[str, list]] = "none"
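A minimal sketch of how the tightened annotation reads in use, assuming AggMetricConfig is importable from this config module (field names are taken from the hunk above):

    # weight_by_size is now a plain bool rather than Optional[str],
    # so a boolean literal matches the declared type.
    cfg = AggMetricConfig(
        metric="acc",
        aggregation="mean",
        weight_by_size=True,
        filter_list="none",
    )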
......@@ -27,7 +27,7 @@ class GroupConfig(dict):
group_alias: Optional[str] = None
task: Optional[Union[str, list]] = None
aggregate_metric_list: Optional[
Union[List[AggMetricConfig], AggMetricConfig, dict]
Union[list[AggMetricConfig], AggMetricConfig, dict]
] = None
version: Optional[str] = None
metadata: Optional[dict] = (
......
from __future__ import annotations
import logging
import math
import os
import random
import re
import string
from collections.abc import Iterable
from typing import Callable, List, Optional, Sequence, TypeVar
from collections.abc import Callable, Iterable, Sequence
from typing import Generic, TypeVar
import numpy as np
......@@ -31,7 +33,7 @@ def nanmean(arr: list[float]) -> float:
@register_aggregation("mean")
def mean(arr: list[float]) -> float:
def mean(arr: Sequence[float]) -> float:
return sum(arr) / len(arr)
......@@ -70,7 +72,7 @@ def f1_score(items):
@register_aggregation("matthews_corrcoef")
def matthews_corrcoef(items):
def matthews_corrcoef(items: Iterable[tuple[int, int] | tuple[str, str]]) -> float:
from sklearn.metrics import matthews_corrcoef
unzipped_list = list(zip(*items))
......@@ -80,7 +82,7 @@ def matthews_corrcoef(items):
@register_aggregation("bleu")
def bleu(items):
def bleu(items: Iterable[tuple[str, str]]):
"""The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric
for evaluating a generated sentence to a reference sentence. It counts matching
n-grams in the candidate translation to n-grams in the reference text, where
......@@ -117,7 +119,7 @@ def chrf(items):
@register_aggregation("ter")
def ter(items):
def ter(items: Iterable[tuple[str, str]]):
"""Translation Error Rate is an error metric for machine translation that
measures the number of edits required to change a system output into one
of the references
......@@ -135,7 +137,9 @@ def ter(items):
@register_aggregation("brier_score")
def brier_score(items): # This is a passthrough function
def brier_score(
items: Iterable[tuple[str, float]],
): # This is a passthrough function
gold, predictions = list(zip(*items))
bs, num_class = np.array(predictions).shape
......@@ -203,8 +207,8 @@ def acc_mutual_info_fn(items): # This is a passthrough function
# See the License for the specific language governing permissions and
# limitations under the License.
def exact_match_hf_evaluate(
predictions,
references,
predictions: Iterable[str],
references: Iterable[str],
regexes_to_ignore=None,
ignore_case=False,
ignore_punctuation=False,
......@@ -266,7 +270,7 @@ def perplexity_fn(items): # This is a passthrough function
output_type="loglikelihood_rolling",
aggregation="weighted_perplexity",
)
def word_perplexity_fn(items): # This is a passthrough function
def word_perplexity_fn(items: T) -> T: # This is a passthrough function
return items
......@@ -276,7 +280,7 @@ def word_perplexity_fn(items): # This is a passthrough function
output_type="loglikelihood_rolling",
aggregation="weighted_perplexity",
)
def byte_perplexity_fn(items): # This is a passthrough function
def byte_perplexity_fn(items: T) -> T: # This is a passthrough function
return items
......@@ -286,7 +290,7 @@ def byte_perplexity_fn(items): # This is a passthrough function
output_type="loglikelihood_rolling",
aggregation="bits_per_byte",
)
def bits_per_byte_fn(items): # This is a passthrough function
def bits_per_byte_fn(items: T) -> T: # This is a passthrough function
return items
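The passthrough functions are now typed with a TypeVar so the declared return type mirrors the input type. A self-contained sketch of the pattern (T itself is assumed to be declared alongside the imports, as the TypeVar import above suggests):

    from typing import TypeVar

    T = TypeVar("T")

    def passthrough(items: T) -> T:
        # Returns its argument unchanged; the TypeVar tells a type checker
        # that whatever type goes in is exactly what comes out.
        return items

    pairs = passthrough([(0.5, 12)])  # inferred as list[tuple[float, int]]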
......@@ -295,7 +299,7 @@ def pop_stddev(arr):
return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / len(arr))
def sample_stddev(arr: Sequence[T]) -> float:
def sample_stddev(arr: Sequence[float]) -> float:
mu = mean(arr)
return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / (len(arr) - 1))
......@@ -416,7 +420,7 @@ def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
return max(scores_for_ground_truths)
def weighted_mean(items: List[tuple[float, float]]) -> float:
def weighted_mean(items: list[tuple[float, float]]) -> float:
a, b = zip(*items)
return sum(a) / sum(b)
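weighted_mean treats each item as a (weighted value, weight) pair; a quick worked example, purely to illustrate the expected shape of items:

    # Sum of first elements divided by sum of second elements:
    # (2.0 + 3.0) / (4.0 + 6.0) == 0.5
    assert weighted_mean([(2.0, 4.0), (3.0, 6.0)]) == 0.5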
......@@ -427,15 +431,15 @@ def is_non_str_iterable(obj):
def _sacreformat(refs, preds):
"""Format refs and preds for sacrebleu corpus calculation. It is very particular"""
# Sacrebleu expects (List[str], List[List[str]])
# Sacrebleu expects (list[str], list[list[str]])
# e.g. sacrebleu.corpus_bleu([pred_t], [[ref1_stream], [ref2_stream], ...])
# Note [ref1_stream] is the first reference for each pred.
# So lists are size N and (M, N) for N preds and M possible refs for each pred
# This is a different order of dimensions than I would expect
# We expect refs to be List[str] or List[List[str]], the outer list corresponding to preds
# Must become List[List[str]] with the inner list corresponding to preds
# We expect refs to be list[str] or list[list[str]], the outer list corresponding to preds
# Must become list[list[str]] with the inner list corresponding to preds
if not is_non_str_iterable(refs):
refs = list(refs)
if not is_non_str_iterable(refs[0]):
......@@ -443,7 +447,7 @@ def _sacreformat(refs, preds):
refs = list(zip(*refs))
# Note the number of refs in each ref list must match the number of preds
# We expect preds to be List[str] or List[List[str]]. Must become List[str]
# We expect preds to be list[str] or list[list[str]]. Must become list[str]
if not is_non_str_iterable(preds):
preds = list(preds)
if is_non_str_iterable(preds[0]):
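A sketch of the shapes _sacreformat reconciles, with made-up strings, since the comments above are easy to misread:

    # Per-prediction inputs: N predictions, each with M references.
    preds = ["the cat sat", "hello world"]
    refs = [
        ["a cat sat", "the cat sits"],  # references for preds[0]
        ["hello world", "hi world"],    # references for preds[1]
    ]

    # sacrebleu.corpus_bleu expects the transpose: M reference streams,
    # each of length N, so list(zip(*refs)) yields
    # [("a cat sat", "hello world"), ("the cat sits", "hi world")].
    ref_streams = [list(stream) for stream in zip(*refs)]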
......@@ -456,7 +460,7 @@ def _sacreformat(refs, preds):
# stderr stuff
class _bootstrap_internal:
class _bootstrap_internal(Generic[T]):
"""
Pool worker: `(i, xs)` → `n` bootstrap replicates
of `f(xs)` using an RNG seeded with `i`.
......@@ -539,7 +543,7 @@ def bootstrap_stderr(
def stderr_for_metric(
metric: Callable[[Sequence[T]], float], bootstrap_iters: int
) -> Optional[Callable[[Sequence[T]], float]]:
) -> Callable[[Sequence[T]], float] | None:
"""
Return a function that estimates the standard error of `metric(xs)`.
......@@ -569,10 +573,10 @@ def stderr_for_metric(
stderr = {mean: mean_stderr, acc_all: acc_all_stderr}
return stderr.get(metric, None)
return stderr.get(metric)
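Callers now receive either a stderr estimator or None; a minimal usage sketch under that assumption:

    stderr_fn = stderr_for_metric(mean, bootstrap_iters=1000)
    scores = [0.0, 1.0, 1.0, 0.0, 1.0]
    stderr = stderr_fn(scores) if stderr_fn is not None else None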
def pooled_sample_stderr(stderrs: List[float], sizes: List[int]):
def pooled_sample_stderr(stderrs: list[float], sizes: list[int]):
# Used to aggregate bootstrapped stderrs across subtasks in a group,
# when we are weighting by the size of each subtask.
#
......@@ -590,7 +594,7 @@ def pooled_sample_stderr(stderrs: List[float], sizes: List[int]):
return np.sqrt(pooled_sample_var / sum(sizes))
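A usage sketch with made-up numbers, only to show the expected argument shapes (per-subtask standard errors paired with their sample sizes):

    subtask_stderrs = [0.012, 0.020, 0.016]
    subtask_sizes = [500, 250, 400]
    group_stderr = pooled_sample_stderr(subtask_stderrs, subtask_sizes)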
def combined_sample_stderr(stderrs: List[float], sizes: List[int], metrics=None):
def combined_sample_stderr(stderrs: list[float], sizes: list[int], metrics=None):
assert metrics is not None, (
"Need to pass a list of each subtask's metric for this stderr aggregation"
)
......@@ -622,7 +626,9 @@ def combined_sample_stderr(stderrs: List[float], sizes: List[int], metrics=None)
return np.sqrt(variance)
def aggregate_subtask_metrics(metrics, sizes, weight_by_size=True):
def aggregate_subtask_metrics(
metrics: list[float], sizes: list[float], weight_by_size: bool = True
):
# A helper function that is used to aggregate
# subtask scores cross-task.
# TODO: does not hold for non-mean aggregations
......@@ -631,4 +637,4 @@ def aggregate_subtask_metrics(metrics, sizes, weight_by_size=True):
assert len(metrics) == len(sizes)
return sum([metric * size for metric, size in zip(metrics, sizes)]) / sum(sizes)
return sum(metric * size for metric, size in zip(metrics, sizes)) / sum(sizes)
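A worked example of the size-weighted aggregation in the expression above:

    # (0.50 * 100 + 0.80 * 300) / (100 + 300) == 290 / 400 == 0.725
    aggregate_subtask_metrics([0.50, 0.80], [100, 300], weight_by_size=True)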
......@@ -1053,7 +1053,9 @@ class ConfigurableTask(Task):
print(type(doc_to_text))
raise TypeError
def doc_to_target(self, doc: dict, doc_to_target=None) -> Union[int, str, list]:
def doc_to_target(
self, doc: dict, doc_to_target=None
) -> Union[int, str, list[int]]:
# if self.prompt is not None:
# doc_to_target = self.prompt
if doc_to_target is not None:
......@@ -1100,7 +1102,9 @@ class ConfigurableTask(Task):
raise TypeError
def doc_to_choice(
self, doc: dict, doc_to_choice: Union[str, list, dict, None] = None
self,
doc: dict,
doc_to_choice: Union[str, list, dict, Callable[..., list[str]], None] = None,
) -> List[str]:
# if self.prompt is not None:
# doc_to_choice = self.prompt
......@@ -1123,8 +1127,8 @@ class ConfigurableTask(Task):
return list(doc_to_choice.values())
elif callable(doc_to_choice):
return doc_to_choice(doc)
elif hasattr(doc_to_choice, "get_answer_choices_list"):
return doc_to_choice.get_answer_choices_list(doc)
# elif hasattr(doc_to_choice, "get_answer_choices_list"):
# return doc_to_choice.get_answer_choices_list(doc)
else:
raise TypeError
......@@ -1333,6 +1337,8 @@ class ConfigurableTask(Task):
raise ValueError
# and this stores our "regular" conditional loglikelihoods
lls = lls[: len(choices)]
else:
lls_unconditional = None
pred = np.argmax(lls)
pred_norm = np.argmax(lls / completion_len)
......@@ -1390,6 +1396,9 @@ class ConfigurableTask(Task):
}
if "acc_mutual_info" in use_metric:
assert lls_unconditional is not None, (
"lls_unconditional should not be None if acc_mutual_info is in use_metric"
)
lls_mutual_info = [
ll_c - ll_u for ll_c, ll_u in zip(lls, lls_unconditional)
]
......
......@@ -3,8 +3,8 @@ from typing import Any, Callable, Union
def serialize_callable(
value: Union[Callable, str], keep_callable=False
) -> Union[Callable, str]:
value: Union[Callable[..., Any], str], keep_callable=False
) -> Union[Callable[..., Any], str]:
"""Serializes a given function or string.
If 'keep_callable' is True, the original callable is returned.
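A usage sketch based on the docstring and the getsource import near the top of the file; the behavior of the non-keep_callable path is assumed, not shown in this hunk:

    def doc_to_text(doc):
        return doc["text"]

    same_fn = serialize_callable(doc_to_text, keep_callable=True)  # returns the callable itself
    serialized = serialize_callable(doc_to_text)  # presumably the function's source text via getsource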
......
from __future__ import annotations
import itertools
import json
import logging
......@@ -5,7 +7,7 @@ import os
import random
import time
from collections import defaultdict
from typing import TYPE_CHECKING, List, Optional, Union
from typing import TYPE_CHECKING, Any, List, Optional, Union
import numpy as np
import torch
......@@ -49,7 +51,7 @@ eval_logger = logging.getLogger(__name__)
@positional_deprecated
def simple_evaluate(
model,
model_args: Optional[Union[str, dict]] = None,
model_args: Optional[Union[str, dict[str, Any]]] = None,
tasks: Optional[List[Union[str, dict, object]]] = None,
num_fewshot: Optional[int] = None,
batch_size: Optional[Union[int, str]] = None,
......@@ -420,7 +422,7 @@ def simple_evaluate(
def evaluate(
lm: "LM",
task_dict,
limit: Optional[int] = None,
limit: int | float | None = None,
samples: Optional[dict] = None,
cache_requests: bool = False,
rewrite_requests_cache: bool = False,
......
......@@ -107,16 +107,19 @@ plugins.md028.enabled = false # no-blanks-blockquote
plugins.md029.allow_extended_start_values = true # ol-prefix
plugins.md034.enabled = false # no-bare-urls
[tool.ruff.lint]
extend-select = ["I", "W605"]
[tool.ruff]
target-version = "py39"
lint.extend-select = ["I", "UP", "E", "C419", "F", "B", "SIM"]
lint.ignore = ["E402", "E731", "E501", "E111", "E114", "E117"]
[tool.ruff.lint.isort]
combine-as-imports = true
lines-after-imports = 2
known-first-party = ["lm_eval"]
[tool.ruff.lint.extend-per-file-ignores]
"__init__.py" = ["F401","F402","F403"]
"utils.py" = ["F401"]
[dependency-groups]
dev = [
......