"vscode:/vscode.git/clone" did not exist on "119589fcb3b9568d3536d9011ca607286bf66e81"
Unverified commit 65b8761d authored by Baber Abbasi, committed by GitHub

Switch Linting to `ruff` (#1166)

* add ruff and isort. remove black and flake8

* remove unnecessary dependencies

* remove dependency from table

* change order

* ran ruff

* check 3.9

* exclude evaluator

* update CI workflow

* use ruff config in pyproject.toml

* test

* add isort rules to ruff

* sort imports

* import `make_table`

* try stages for no-commit-to-branch

* turn on mypy for pre-commit

* test

* test

* test

* change no-commit-to-branch to default

* nits

* fixed dependency
parent 21d4ae98
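Note: the pyproject.toml side of this change ("use ruff config in pyproject.toml", "add isort rules to ruff", "exclude evaluator") is not shown in the diff excerpt below. As a rough sketch only (the rule selection and excluded path are assumptions, not copied from this commit), a ruff 0.1.x configuration of that shape looks like:

[tool.ruff]
# Pyflakes ("F") and pycodestyle errors ("E") cover the checks previously run via flake8;
# "I" turns on isort-style import sorting inside ruff.
select = ["E", "F", "I"]
# Path assumed for illustration of the "exclude evaluator" item above.
exclude = ["lm_eval/evaluator.py"]

[tool.ruff.isort]
# Treat the package as first-party so imports group as stdlib / third-party / local.
known-first-party = ["lm_eval"]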
@@ -56,7 +56,7 @@ jobs:
         if: steps.changed-tasks.outputs.tasks_any_modified == 'true' || steps.changed-tasks.outputs.api_any_modified == 'true'
         run: |
           python -m pip install --upgrade pip
-          pip install -e '.[testing]' --extra-index-url https://download.pytorch.org/whl/cpu
+          pip install -e '.[dev]' --extra-index-url https://download.pytorch.org/whl/cpu
           # Install optional git dependencies
           # pip install bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt
           # if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
......
@@ -17,29 +17,22 @@ jobs:
   linter:
     name: Linters
     runs-on: ubuntu-latest
-    timeout-minutes: 20
+    timeout-minutes: 5
     steps:
       - name: Checkout Code
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4
       - name: Set up Python 3.8
-        uses: actions/setup-python@v4
+        uses: actions/setup-python@v5
         with:
           python-version: 3.8
           cache: pip
-          cache-dependency-path: setup.py
-      - name: Install dependencies
-        run: pip install -e '.[linting,testing]' --extra-index-url https://download.pytorch.org/whl/cpu ; export SKIP=no-commit-to-branch # env var deactivates --no-commit-to-branch
       - name: Pre-Commit
+        env:
+          SKIP: "no-commit-to-branch,mypy"
         uses: pre-commit/action@v3.0.0
-      - name: Lint with pylint
-        run: python -m pylint --disable=all -e W0311 --jobs=0 --indent-string=' ' **/*.py
-      - name: Lint with flake8
-        run: |
-          # stop the build if there are Python syntax errors or undefined names
-          flake8 . --count --select=F,E9,E71,E72,E501,E112,E113,W6 --extend-ignore=F541 --show-source --statistics --exit-zero
-          # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
-          flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
 # # mypy turned off for now
 # - name: Lint with mypy
 #   run: mypy . --ignore-missing-imports --check-untyped-defs --explicit-package-bases --warn-unreachable
@@ -53,17 +46,17 @@ jobs:
     timeout-minutes: 30
     steps:
       - name: Checkout Code
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4
      - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v4
+        uses: actions/setup-python@v5
         with:
           python-version: ${{ matrix.python-version }}
           cache: pip
-          cache-dependency-path: setup.py
+          cache-dependency-path: pyproject.toml
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          pip install -e '.[testing,anthropic,sentencepiece]' --extra-index-url https://download.pytorch.org/whl/cpu
+          pip install -e '.[dev,anthropic,sentencepiece]' --extra-index-url https://download.pytorch.org/whl/cpu
           # Install optional git dependencies
           # pip install bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt
           # if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
......
@@ -27,14 +27,16 @@ repos:
         args: [--remove]
       - id: mixed-line-ending
         args: [--fix=lf]
-  - repo: https://github.com/pycqa/flake8
-    rev: 3.7.9
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    # Ruff version.
+    rev: v0.1.8
     hooks:
-      - id: flake8
-  - repo: https://github.com/psf/black
-    rev: 22.3.0
-    hooks:
-      - id: black
+      # Run the linter.
+      - id: ruff
+        args:
+          - --fix
+      # Run the formatter.
+      - id: ruff-format
   - repo: https://github.com/codespell-project/codespell
     rev: v2.1.0
     hooks:
......
@@ -49,11 +49,10 @@ pip install -e .
 We also provide a number of optional dependencies for extended functionality. Extras can be installed via `pip install -e ".[NAME]"`
 | Name | Use |
-| ------------- | ------------------------------------- |
+|---------------|---------------------------------------|
 | anthropic | For using Anthropic's models |
-| dev | You probably don't want to use this |
 | gptq | For loading models with GPTQ |
-| testing | You probably don't want to use this |
+| dev | You probably don't want to use this |
 | multilingual | For multilingual tokenizers |
 | openai | For using OpenAI's models |
 | promptsource | For using PromtSource prompts |
......
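The extras listed here correspond to optional dependency groups in pyproject.toml, which is also where the `.[dev]` extra installed by CI above comes from. A minimal sketch of that mapping (the package lists are placeholders for illustration, not the project's actual pins):

[project.optional-dependencies]
# Group names mirror the README table; package lists are illustrative placeholders.
anthropic = ["anthropic"]
dev = ["pytest", "pre-commit", "ruff"]
gptq = ["auto-gptq"]
multilingual = ["nagisa", "jieba"]
openai = ["openai", "tiktoken"]
promptsource = ["promptsource"]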
+import argparse
+import json
+import logging
 import os
 import re
 import sys
-import json
-import logging
-import argparse
-import numpy as np
 from pathlib import Path
 from typing import Union
+import numpy as np
 from lm_eval import evaluator, utils
-from lm_eval.tasks import initialize_tasks, include_path
 from lm_eval.api.registry import ALL_TASKS
+from lm_eval.tasks import include_path, initialize_tasks
+from lm_eval.utils import make_table
 def _handle_non_serializable(o):
@@ -170,7 +171,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
         task_names = ALL_TASKS
     elif args.tasks == "list":
         eval_logger.info(
-            "Available Tasks:\n - {}".format(f"\n - ".join(sorted(ALL_TASKS)))
+            "Available Tasks:\n - {}".format("\n - ".join(sorted(ALL_TASKS)))
         )
         sys.exit()
     else:
@@ -271,9 +272,9 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
         f"{args.model} ({args.model_args}), gen_kwargs: ({args.gen_kwargs}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, "
         f"batch_size: {args.batch_size}{f' ({batch_sizes})' if batch_sizes else ''}"
     )
-    print(evaluator.make_table(results))
+    print(make_table(results))
     if "groups" in results:
-        print(evaluator.make_table(results, "groups"))
+        print(make_table(results, "groups"))
 if __name__ == "__main__":
......
 from dataclasses import dataclass
 from typing import List
-from lm_eval.api.instance import Instance
 from datasets import Dataset
+from lm_eval.api.instance import Instance
 class Filter:
     """
@@ -42,7 +43,6 @@ class FilterEnsemble:
     filters: List[Filter]
     def apply(self, instances: List[Instance], docs: List[Dataset]) -> None:
         resps = [
             inst.resps for inst in instances
         ]  # operate just on the model responses
......
+import logging
 import math
+import random
 from collections.abc import Iterable
+import evaluate
 import numpy as np
 import sacrebleu
 import sklearn.metrics
-import random
-import evaluate
-from lm_eval.api.registry import register_metric, register_aggregation
+from lm_eval.api.registry import register_aggregation, register_metric
-import logging
 eval_logger = logging.getLogger("lm-eval")
 # Register Aggregations First
 @register_aggregation("mean")
 def mean(arr):
......
 import abc
+import hashlib
+import json
+import logging
 import os
+from typing import List, Optional, Tuple, Type, TypeVar
 import torch
-from typing import Union, List, Tuple, Optional, Type, TypeVar
 from sqlitedict import SqliteDict
-import json
-import hashlib
 from tqdm import tqdm
 from lm_eval import utils
-import logging
 eval_logger = logging.getLogger("lm-eval")
......
-import os
+import logging
 import evaluate
 from lm_eval.api.model import LM
-import logging
 eval_logger = logging.getLogger("lm-eval")
@@ -91,7 +92,6 @@ DEFAULT_METRIC_REGISTRY = {
 def register_metric(**args):
     # TODO: do we want to enforce a certain interface to registered metrics?
     def decorate(fn):
         assert "metric" in args
         name = args["metric"]
@@ -100,7+100,6 @@ def register_metric(**args):
         ("higher_is_better", HIGHER_IS_BETTER_REGISTRY),
         ("aggregation", METRIC_AGGREGATION_REGISTRY),
     ]:
         if key in args:
             value = args[key]
             assert (
@@ -120,7+119,6 @@ def register_metric(**args):
 def get_metric(name, hf_evaluate_metric=False):
     if not hf_evaluate_metric:
         if name in METRIC_REGISTRY:
             return METRIC_REGISTRY[name]
@@ -151,7 +149,6 @@ def register_aggregation(name):
 def get_aggregation(name):
     try:
         return AGGREGATION_REGISTRY[name]
     except KeyError:
@@ -161,7 +158,6 @@ def get_aggregation(name):
 def get_metric_aggregation(name):
     try:
         return METRIC_AGGREGATION_REGISTRY[name]
     except KeyError:
......
@@ -40,18 +40,18 @@ class ContextSampler:
                     self.doc_to_text(doc)
                     if (
                         self.config.doc_to_choice is None
-                        or type(self.doc_to_text(doc)) is str
+                        or isinstance(self.doc_to_text(doc), str)
                     )
                     else self.doc_to_choice(doc)[self.doc_to_text(doc)]
                 )
                 + self.target_delimiter
                 + (
                     str(self.doc_to_target(doc)[0])
-                    if type(self.doc_to_target(doc)) is list
+                    if isinstance(self.doc_to_target(doc), list)
                     else self.doc_to_target(doc)
                     if (
                         self.config.doc_to_choice is None
-                        or type(self.doc_to_target(doc)) is str
+                        or isinstance(self.doc_to_target(doc), str)
                     )
                     else str(self.doc_to_choice(doc)[self.doc_to_target(doc)])
                 )
@@ -77,8 +77,8 @@ class FirstNSampler(ContextSampler):
         Draw the first `n` samples in order from the specified split.
         Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU.
         """
-        assert n <= len(
-            self.docs
-        ), f"Error: number of fewshot samples requested exceeds the {len(self.docs)} that are available."
+        assert (
+            n <= len(self.docs)
+        ), f"Error: number of fewshot samples requested exceeds the {len(self.docs)} that are available."
         return self.docs[:n]
......
 import abc
-from dataclasses import dataclass, field, asdict
-import os
-import re
 import ast
-import yaml
 import logging
-import evaluate
+import os
 import random
-import itertools
-import functools
-from tqdm import tqdm
+import re
+from collections.abc import Callable
+from dataclasses import asdict, dataclass
+from typing import Any, List, Literal, Tuple, Union
 import datasets
 import numpy as np
-from typing import Union, List, Any, Tuple, Literal
-from collections.abc import Callable
 from lm_eval import utils
 from lm_eval.api import samplers
 from lm_eval.api.instance import Instance
-from lm_eval.api.filter import FilterEnsemble
-from lm_eval.prompts import get_prompt
-from lm_eval.filters import build_filter_ensemble
 from lm_eval.api.metrics import (
+    bits_per_byte,
     mean,
     weighted_perplexity,
-    bits_per_byte,
-    metric_max_over_ground_truths,
 )
 from lm_eval.api.registry import (
-    get_metric,
+    AGGREGATION_REGISTRY,
+    DEFAULT_METRIC_REGISTRY,
     get_aggregation,
+    get_metric,
     get_metric_aggregation,
     is_higher_better,
-    DEFAULT_METRIC_REGISTRY,
-    OUTPUT_TYPE_REGISTRY,
-    AGGREGATION_REGISTRY,
 )
+from lm_eval.filters import build_filter_ensemble
+from lm_eval.prompts import get_prompt
 ALL_OUTPUT_TYPES = [
     "loglikelihood",
@@ -349,9 +339,7 @@ class Task(abc.ABC):
         elif self.has_validation_docs():
             docs = self.validation_docs()
         else:
-            assert (
-                False
-            ), f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!"
+            assert False, f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!"
         eval_logger.info(f"Building contexts for task on rank {rank}...")
@@ -603,9 +591,9 @@ class ConfigurableTask(Task):
             if "aggregation" in metric_config:
                 agg_name = metric_config["aggregation"]
-                if type(agg_name) == str:
+                if isinstance(agg_name, str):
                     self._aggregation_list[metric_name] = get_aggregation(agg_name)
-                elif callable(agg_name):
+                elif callable(agg_name):  # noqa: E721
                     self._aggregation_list[metric_name] = metric_config[
                         "aggregation"
                     ]
@@ -672,9 +660,7 @@ class ConfigurableTask(Task):
         elif self.has_validation_docs():
             self.task_docs = self.validation_docs()
         else:
-            assert (
-                False
-            ), f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!"
+            assert False, f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!"
         # Test One Doc
         self.features = list(self.task_docs.features.keys())
@@ -686,20 +672,20 @@ class ConfigurableTask(Task):
         if self.config.doc_to_choice is not None:
             test_choice = self.doc_to_choice(test_doc)
-            if type(test_choice) is not list:
+            if not isinstance(test_choice, list):
                 eval_logger.error("doc_to_choice must return list")
             else:
                 num_choice = len(test_choice)
-            if type(test_text) is int:
+            if isinstance(test_text, int):
                 self.multiple_input = num_choice
         else:
             test_choice = None
-        if type(test_target) is list:
+        if isinstance(test_target, list):
             self.multiple_target = len(test_target)
         else:
-            if (type(test_target) is int) and (test_choice is not None):
+            if (isinstance(test_target, int)) and (test_choice is not None):
                 test_target = test_choice[test_target]
             else:
                 test_target = str(test_target)
@@ -808,11 +794,11 @@ class ConfigurableTask(Task):
         )
         example = self.doc_to_text(doc)
-        if type(example) == str:
+        if isinstance(example, str):
             return labeled_examples + example
-        elif type(example) == list:
+        elif isinstance(example, list):
             return [labeled_examples + ex for ex in example]
-        elif type(example) == int:
+        elif isinstance(example, int):
             if self.config.doc_to_choice is not None:
                 choices = self.doc_to_choice(doc)
                 return labeled_examples + choices[example]
@@ -864,9 +850,9 @@ class ConfigurableTask(Task):
         else:
             doc_to_text = self.config.doc_to_text
-        if type(doc_to_text) == int:
+        if isinstance(doc_to_text, int):
             return doc_to_text
-        elif type(doc_to_text) == str:
+        elif isinstance(doc_to_text, str):
             if doc_to_text in self.features:
                 # if self.config.doc_to_choice is not None:
                 #     return self.doc_to_choice(doc)[doc[doc_to_text]]
@@ -898,9 +884,9 @@ class ConfigurableTask(Task):
         else:
             doc_to_target = self.config.doc_to_target
-        if type(doc_to_target) == int:
+        if isinstance(doc_to_target, int):
             return doc_to_target
-        elif type(doc_to_target) == str:
+        elif isinstance(doc_to_target, str):
             if doc_to_target in self.features:
                 # if self.config.doc_to_choice is not None:
                 #     return self.doc_to_choice(doc)[doc[doc_to_target]]
@@ -921,7 +907,7 @@ class ConfigurableTask(Task):
                     return target_string
                 else:
                     return target_string
-        elif type(doc_to_target) == list:
+        elif isinstance(doc_to_target, list):
             return doc_to_target
         elif callable(doc_to_target):
             return doc_to_target(doc)
@@ -944,14 +930,14 @@ class ConfigurableTask(Task):
         else:
             doc_to_choice = self.config.doc_to_choice
-        if type(doc_to_choice) == str:
+        if isinstance(doc_to_choice, str):
             if doc_to_choice in self.features:
                 return doc[doc_to_choice]
             else:
                 return ast.literal_eval(utils.apply_template(doc_to_choice, doc))
-        elif type(doc_to_choice) == list:
+        elif isinstance(doc_to_choice, list):
             return doc_to_choice
-        elif type(doc_to_choice) == dict:
+        elif isinstance(doc_to_choice, dict):
             return list(doc_to_choice.values())
         elif callable(doc_to_choice):
             return doc_to_choice(doc)
@@ -1078,14 +1064,14 @@ class ConfigurableTask(Task):
             gold = self.doc_to_target(doc)
             gold_index_error = False
-            if type(gold) is list:
+            if isinstance(gold, list):
                 gold = [i if i < len(choices) else -100 for i in gold]
                 if -100 in gold:
                     gold_index_error = True
             else:
-                if type(gold) is int:
+                if isinstance(gold, int):
                     gold = gold if gold < len(choices) else -100
-                elif type(gold) is str:
+                elif isinstance(gold, str):
                     gold = choices.index(gold) if gold in choices else -100
                 if gold == -100:
@@ -1175,9 +1161,7 @@ class ConfigurableTask(Task):
                         predictions=[result],
                         **self._metric_fn_kwargs[metric],
                     )
-                except (
-                    TypeError
-                ):  # needed for now in order to use a different interface between our own metrics and HF Evaluate metrics
+                except TypeError:  # needed for now in order to use a different interface between our own metrics and HF Evaluate metrics
                     result_score = self._metric_fn_list[metric]([gold, result])
                 if isinstance(result_score, dict):
                     # TODO: this handles the case where HF evaluate returns a dict.
......
+import datetime
+import io
+import json
+import mmap
 import os
+from pathlib import Path
 from typing import Any
-import zstandard
-import json
 import jsonlines
-import io
-import datetime
-import mmap
 import tqdm
-from pathlib import Path
+import zstandard
 def json_serial(obj: Any) -> str:
......
-import time
-import random
-import pickle
-import json
+import collections
 import glob
+import json
 import os
-import collections
+import pickle
+import random
+import time
-from .janitor import Janitor, word_ngrams
 from .archiver import ZStdTextReader
+from .janitor import Janitor, word_ngrams
 # Was used for testing the evaluator decoupled from the full logic below
@@ -109,7 +109,7 @@ def get_train_overlap(docs_by_task_set: dict, ngrams_path: str, limit: int) -> d
     print(f"Merging lookups took {elapsed:0.5f} seconds.")
     print(f"{ngrams_n_size} grams files found in {ngrams_path}:")
-    files = glob.glob(os.path.join(ngrams_path, f"*.sorted.zst"))
+    files = glob.glob(os.path.join(ngrams_path, "*.sorted.zst"))
     print(files)
     for file in files:
@@ -135,11 +135,7 @@ def get_train_overlap(docs_by_task_set: dict, ngrams_path: str, limit: int) -> d
                     matching_unique += 1
                     for task_name, task_set, doc_ids in merged_lookup[ngram]:
                         task_doc_set = duplicates[(task_name, task_set)]
-                        for (
-                            doc_id
-                        ) in (
-                            doc_ids
-                        ):  # Record contamination across all relevant task/set combos
+                        for doc_id in doc_ids:  # Record contamination across all relevant task/set combos
                             task_doc_set.add(doc_id)
                     del merged_lookup[ngram]  # No point matching again
                 else:
......
+import pickle
 import re
 import string
-import pickle
 import traceback
-from pprint import pprint
-from typing import Iterator, Sequence, TypeVar, List, Tuple
+from typing import Iterator, List, Sequence, Tuple, TypeVar
 # This is a cpp module. Compile janitor_util.cpp with:
 # c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes) janitor_util.cpp -o janitor_util$(python3-config --extension-suffix) -undefined dynamic_lookup
......
 import random
 import itertools
-import json
 import collections
-import sys
 import torch
@@ -17,8 +15,6 @@ import lm_eval.api.registry
 from lm_eval.utils import (
     positional_deprecated,
     run_task_tests,
-    make_table,
-    create_iterator,
     get_git_commit_hash,
     simple_parse_args_string,
     eval_logger,
@@ -91,7 +87,7 @@ def simple_evaluate(
     if gen_kwargs is not None:
         gen_kwargs = simple_parse_args_string(gen_kwargs)
         eval_logger.warning(
-            f"generation_kwargs specified through cli, these settings will be used over set parameters in yaml tasks."
+            "generation_kwargs specified through cli, these settings will be used over set parameters in yaml tasks."
         )
         if gen_kwargs == "":
             gen_kwargs = None
@@ -118,7 +114,9 @@ def simple_evaluate(
                 use_cache
                 # each rank receives a different cache db.
                 # necessary to avoid multiple writes to cache at once
-                + "_rank" + str(lm.rank) + ".db",
+                + "_rank"
+                + str(lm.rank)
+                + ".db",
             )
     task_dict = lm_eval.tasks.get_task_dict(tasks)
@@ -513,9 +511,7 @@ def evaluate(
                     ) + total_size * current_size / (
                         (total_size + current_size)
                         * (total_size + current_size - 1)
-                    ) * (
-                        results[group][metric] - metric_score
-                    ) ** 2
+                    ) * (results[group][metric] - metric_score) ** 2
                 else:
                     results[group][metric] = metric_score
                 results[group][stderr] = var_score
......
@@ -32,7 +32,7 @@ def build_filter_ensemble(filter_name, components):
     Create a filtering pipeline.
     """
     filters = []
-    for (function, kwargs) in components:
+    for function, kwargs in components:
         if kwargs is None:
             f = get_filter(function)()
         else:
......
-from lm_eval.api.model import LM
-from lm_eval.api.registry import register_model
-from tqdm import tqdm
 import time
+from typing import Any, List, Tuple
+from tqdm import tqdm
 from lm_eval import utils
-from typing import List, Any, Tuple
+from lm_eval.api.model import LM
+from lm_eval.api.registry import register_model
 eval_logger = utils.eval_logger
......
 import random
 from lm_eval.api.model import LM
 from lm_eval.api.registry import register_model
......
-import requests
 import logging
 import time
-from tqdm import tqdm
+import requests
 from requests.exceptions import RequestException
+from tqdm import tqdm
 from lm_eval.api.model import LM
 from lm_eval.api.registry import register_model
 logger = logging.getLogger(__name__)
......
+import copy
 import os
-from packaging import version
+from collections import defaultdict
+from pathlib import Path
+from typing import List, Literal, Optional, Tuple, Union
 import torch
+import torch.nn.functional as F
 import transformers
+from accelerate import Accelerator, DistributedType, find_executable_batch_size
+from packaging import version
+from peft import PeftModel
+from peft import __version__ as PEFT_VERSION
+from tqdm import tqdm
 from transformers.models.auto.modeling_auto import (
     MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
     MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
 )
-from peft import __version__ as PEFT_VERSION, PeftModel
-import copy
-from collections import defaultdict
-from tqdm import tqdm
-from pathlib import Path
-import torch.nn.functional as F
 from lm_eval import utils
 from lm_eval.api.instance import Instance
 from lm_eval.api.model import LM
 from lm_eval.api.registry import register_model
+from lm_eval.utils import stop_sequences_criteria
-from lm_eval.utils import MultiTokenEOSCriteria, stop_sequences_criteria
-from accelerate import Accelerator, find_executable_batch_size, DistributedType
-from typing import List, Optional, Union, Tuple, Literal
 eval_logger = utils.eval_logger
@@ -107,9 +106,7 @@ class HFLM(LM):
             eval_logger.warning(
                 "`pretrained` model kwarg is not of type `str`. Many other model arguments may be ignored. Please do not launch via accelerate or use `parallelize=True` if passing an existing model this way."
             )
-            assert (
-                not parallelize
-            ), "`parallelize=True` is not compatible with passing pre-initialized model to `pretrained`"
+            assert not parallelize, "`parallelize=True` is not compatible with passing pre-initialized model to `pretrained`"
             self._model = pretrained
             self._device = self._model.device
@@ -279,10 +276,13 @@ class HFLM(LM):
                     "with 'accelerate launch *script*'. "
                     f"Current run will proceed with {accelerator.num_processes} devices."
                 )
-                assert accelerator.distributed_type in [
-                    DistributedType.FSDP,
-                    DistributedType.MULTI_GPU,
-                ], "Unsupported distributed type provided. Only DDP and FSDP are supported."
+                assert (
+                    accelerator.distributed_type
+                    in [
+                        DistributedType.FSDP,
+                        DistributedType.MULTI_GPU,
+                    ]
+                ), "Unsupported distributed type provided. Only DDP and FSDP are supported."
                 if accelerator.distributed_type == DistributedType.FSDP:
                     self._model = accelerator.prepare(self.model)
                 else:
@@ -417,7 +417,6 @@ class HFLM(LM):
         revision: str = "main",
         trust_remote_code: bool = False,
     ) -> None:
         self._config = transformers.AutoConfig.from_pretrained(
             pretrained,
             revision=revision,
@@ -751,8 +750,9 @@ class HFLM(LM):
         for context, continuation in [req.args for req in requests]:
             if context == "":
                 # end of text as context
-                context_enc, continuation_enc = [self.eot_token_id], self.tok_encode(
-                    continuation
+                context_enc, continuation_enc = (
+                    [self.eot_token_id],
+                    self.tok_encode(continuation),
                 )
             else:
                 context_enc, continuation_enc = self._encode_pair(context, continuation)
@@ -995,9 +995,7 @@ class HFLM(LM):
                 greedy_tokens = logits.argmax(dim=-1)
                 cont_toks = torch.tensor(
                     cont_toks, dtype=torch.long, device=self.device
-                ).unsqueeze(
-                    0
-                )  # [1, seq]
+                ).unsqueeze(0)  # [1, seq]
                 max_equal = (greedy_tokens == cont_toks).all()
                 # Obtain log-probs at the corresponding continuation token indices
......