Commit 73f3029c authored by lintangsutawika's avatar lintangsutawika
Browse files

precommit format

parent f701ba7d
...@@ -50,7 +50,7 @@ dataset_kwargs: null # any extra keyword arguments that should be passed to the ...@@ -50,7 +50,7 @@ dataset_kwargs: null # any extra keyword arguments that should be passed to the
``` ```
dataset_path: json dataset_path: json
dataset_name: null dataset_name: null
dataset_kwargs: dataset_kwargs:
data_files: /path/to/my/json data_files: /path/to/my/json
``` ```
------------------------------- -------------------------------
......
# from .evaluator import evaluate, simple_evaluate # from .evaluator import evaluate, simple_evaluate
# from .logger import eval_logger, SPACING # from .logger import eval_logger, SPACING
\ No newline at end of file
...@@ -9,10 +9,6 @@ import numpy as np ...@@ -9,10 +9,6 @@ import numpy as np
from pathlib import Path from pathlib import Path
from typing import Union from typing import Union
import logging
SPACING = " " * 47
logging.basicConfig( logging.basicConfig(
format="%(asctime)s,%(msecs)03d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s", format="%(asctime)s,%(msecs)03d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
datefmt="%Y-%m-%d:%H:%M:%S", datefmt="%Y-%m-%d:%H:%M:%S",
...@@ -28,13 +24,14 @@ def _handle_non_serializable(o): ...@@ -28,13 +24,14 @@ def _handle_non_serializable(o):
else: else:
return str(o) return str(o)
def parse_eval_args() -> argparse.Namespace: def parse_eval_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter) parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
parser.add_argument("--model", default="hf", help="Name of model e.g. `hf`") parser.add_argument("--model", default="hf", help="Name of model e.g. `hf`")
parser.add_argument( parser.add_argument(
"--tasks", "--tasks",
default=None, default=None,
help="To get full list of tasks, use the command lm-eval --tasks list" help="To get full list of tasks, use the command lm-eval --tasks list",
) )
parser.add_argument( parser.add_argument(
"--model_args", "--model_args",
...@@ -145,7 +142,9 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None: ...@@ -145,7 +142,9 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
if args.tasks is None: if args.tasks is None:
task_names = ALL_TASKS task_names = ALL_TASKS
elif args.tasks == "list": elif args.tasks == "list":
eval_logger.info("Available Tasks:\n - {}".format(f"\n - ".join(sorted(ALL_TASKS)))) eval_logger.info(
"Available Tasks:\n - {}".format(f"\n - ".join(sorted(ALL_TASKS)))
)
sys.exit() sys.exit()
else: else:
if os.path.isdir(args.tasks): if os.path.isdir(args.tasks):
...@@ -169,10 +168,10 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None: ...@@ -169,10 +168,10 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
missing = ", ".join(task_missing) missing = ", ".join(task_missing)
eval_logger.error( eval_logger.error(
f"Tasks were not found: {missing}\n" f"Tasks were not found: {missing}\n"
f"{SPACING}Try `lm-eval -h` for list of available tasks", f"{' ' * 47}Try `lm-eval --tasks list` for list of available tasks",
) )
raise ValueError( raise ValueError(
f"Tasks {missing} were not found. Try `lm-eval -h` for list of available tasks." f"Tasks {missing} were not found. Try `lm-eval --tasks list` for list of available tasks."
) )
if args.output_path: if args.output_path:
......
...@@ -10,6 +10,7 @@ import evaluate ...@@ -10,6 +10,7 @@ import evaluate
from lm_eval.api.registry import register_metric, register_aggregation from lm_eval.api.registry import register_metric, register_aggregation
import logging import logging
eval_logger = logging.getLogger("lm-eval") eval_logger = logging.getLogger("lm-eval")
# Register Aggregations First # Register Aggregations First
......
...@@ -12,6 +12,7 @@ from tqdm import tqdm ...@@ -12,6 +12,7 @@ from tqdm import tqdm
from lm_eval import utils from lm_eval import utils
import logging import logging
eval_logger = logging.getLogger("lm-eval") eval_logger = logging.getLogger("lm-eval")
T = TypeVar("T", bound="LM") T = TypeVar("T", bound="LM")
......
...@@ -3,6 +3,7 @@ import evaluate ...@@ -3,6 +3,7 @@ import evaluate
from lm_eval.api.model import LM from lm_eval.api.model import LM
import logging import logging
eval_logger = logging.getLogger("lm-eval") eval_logger = logging.getLogger("lm-eval")
MODEL_REGISTRY = {} MODEL_REGISTRY = {}
......
...@@ -4,6 +4,7 @@ from dataclasses import dataclass, field, asdict ...@@ -4,6 +4,7 @@ from dataclasses import dataclass, field, asdict
import re import re
import ast import ast
import yaml import yaml
import logging
import evaluate import evaluate
import random import random
import itertools import itertools
...@@ -46,9 +47,10 @@ ALL_OUTPUT_TYPES = [ ...@@ -46,9 +47,10 @@ ALL_OUTPUT_TYPES = [
"generate_until", "generate_until",
] ]
import logging
eval_logger = logging.getLogger("lm-eval") eval_logger = logging.getLogger("lm-eval")
@dataclass @dataclass
class TaskConfig(dict): class TaskConfig(dict):
# task naming/registry # task naming/registry
......
...@@ -17,6 +17,7 @@ import logging ...@@ -17,6 +17,7 @@ import logging
eval_logger = logging.getLogger("lm-eval") eval_logger = logging.getLogger("lm-eval")
def register_configurable_task(config: Dict[str, str]) -> int: def register_configurable_task(config: Dict[str, str]) -> int:
SubClass = type( SubClass = type(
config["task"] + "ConfigurableTask", config["task"] + "ConfigurableTask",
...@@ -141,7 +142,7 @@ def include_task_folder(task_dir: str, register_task: bool = True) -> None: ...@@ -141,7 +142,7 @@ def include_task_folder(task_dir: str, register_task: bool = True) -> None:
if type(config["task"]) == list: if type(config["task"]) == list:
register_configurable_group(config, yaml_path) register_configurable_group(config, yaml_path)
# Log this silently and show it only when # Log this silently and show it only when
# the user defines the appropriate verbosity. # the user defines the appropriate verbosity.
except ModuleNotFoundError as e: except ModuleNotFoundError as e:
eval_logger.debug( eval_logger.debug(
......
...@@ -20,6 +20,7 @@ from jinja2 import BaseLoader, Environment, StrictUndefined ...@@ -20,6 +20,7 @@ from jinja2 import BaseLoader, Environment, StrictUndefined
from itertools import islice from itertools import islice
import logging import logging
eval_logger = logging.getLogger("lm-eval") eval_logger = logging.getLogger("lm-eval")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment