Commit 6fc2ac49 authored by Baber's avatar Baber
Browse files

fix circular

parent a9c16905
......@@ -24,11 +24,9 @@ import datasets
import numpy as np
from tqdm import tqdm
import lm_eval.tasks
from lm_eval import utils
from lm_eval.api import samplers
from lm_eval.api.instance import Instance, OutputType
from lm_eval.api.metrics import bits_per_byte, mean, weighted_perplexity
from lm_eval.api.registry import (
AGGREGATION_REGISTRY,
DEFAULT_METRIC_REGISTRY,
......@@ -1125,7 +1123,7 @@ class ConfigurableTask(Task):
# get task description
if description := self.config.description:
description = lm_eval.tasks.apply_template(self.config.description, doc)
description = utils.apply_template(self.config.description, doc)
# create system prompt based on the provided system instruction and description
if system_instruction is not None and description:
......@@ -1260,7 +1258,7 @@ class ConfigurableTask(Task):
return doc_to_decontamination_query(doc)
else:
return ast.literal_eval(
lm_eval.tasks.apply_template(
utils.apply_template(
self.config.doc_to_decontamination_query, doc
)
)
......@@ -1293,7 +1291,7 @@ class ConfigurableTask(Task):
# else:
return doc[doc_to_text]
else:
text_string = lm_eval.tasks.apply_template(doc_to_text, doc)
text_string = utils.apply_template(doc_to_text, doc)
if text_string.isdigit() and self._config.doc_to_choice is not None:
return ast.literal_eval(text_string)
else:
......@@ -1329,7 +1327,7 @@ class ConfigurableTask(Task):
# else:
return doc[doc_to_target]
else:
target_string = lm_eval.tasks.apply_template(doc_to_target, doc)
target_string = utils.apply_template(doc_to_target, doc)
if target_string.isdigit() and self._config.doc_to_choice is not None:
return ast.literal_eval(target_string)
elif (
......@@ -1372,9 +1370,7 @@ class ConfigurableTask(Task):
if doc_to_choice in self.features:
return doc[doc_to_choice]
else:
return ast.literal_eval(
lm_eval.tasks.apply_template(doc_to_choice, doc)
)
return ast.literal_eval(utils.apply_template(doc_to_choice, doc))
elif isinstance(doc_to_choice, list):
return doc_to_choice
elif isinstance(doc_to_choice, dict):
......@@ -1403,7 +1399,7 @@ class ConfigurableTask(Task):
if doc_to_image in self.features:
return doc[doc_to_image]
else:
return ast.literal_eval(lm_eval.tasks.apply_template(doc_to_image, doc))
return ast.literal_eval(utils.apply_template(doc_to_image, doc))
elif callable(doc_to_image):
return doc_to_image(doc)
else:
......@@ -1426,7 +1422,7 @@ class ConfigurableTask(Task):
if doc_to_audio in self.features:
return doc[doc_to_audio]
else:
return ast.literal_eval(lm_eval.tasks.apply_template(doc_to_audio, doc))
return ast.literal_eval(utils.apply_template(doc_to_audio, doc))
elif callable(doc_to_audio):
return doc_to_audio(doc)
else:
......@@ -1437,7 +1433,7 @@ class ConfigurableTask(Task):
if gen_prefix in self.features:
return doc[gen_prefix]
else:
return lm_eval.tasks.apply_template(gen_prefix, doc)
return utils.apply_template(gen_prefix, doc)
return None
def construct_requests(
......@@ -1802,6 +1798,8 @@ class MultipleChoiceTask(Task):
}
def aggregation(self) -> dict:
from lm_eval.api.metrics import mean
return {
"acc": mean,
"acc_norm": mean,
......@@ -1868,6 +1866,8 @@ class PerplexityTask(Task):
}
def aggregation(self) -> dict:
from lm_eval.api.metrics import bits_per_byte, weighted_perplexity
return {
"word_perplexity": weighted_perplexity,
"byte_perplexity": weighted_perplexity,
......
......@@ -4,6 +4,7 @@ import os
from typing import Dict
import lm_eval.tasks
import lm_eval.utils
from lm_eval import utils
......@@ -123,7 +124,7 @@ class PromptString:
if "doc_to_choice" in self.prompt_string:
raise NotImplementedError("Not yet implemented to accept doc_to_choice")
text_string = lm_eval.tasks.apply_template(doc_to_text, doc)
target_string = lm_eval.tasks.apply_template(doc_to_target, doc)
text_string = lm_eval.utils.apply_template(doc_to_text, doc)
target_string = lm_eval.utils.apply_template(doc_to_target, doc)
return [text_string, target_string]
......@@ -3,7 +3,6 @@ import functools
import importlib.util
import inspect
import logging
import re
import sys
from functools import partial
from glob import iglob
......@@ -11,7 +10,6 @@ from pathlib import Path
from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Union
import yaml
from jinja2 import BaseLoader, Environment, StrictUndefined
from yaml import YAMLError
from lm_eval import utils
......@@ -177,28 +175,6 @@ def load_yaml_config(
return final_cfg
def regex_replace(string, pattern, repl, count: int = 0):
"""Implements the `re.sub` function as a custom Jinja filter."""
return re.sub(pattern, repl, string, count=count)
@functools.lru_cache(maxsize=256)
def _compile_tpl(src: str):
return apply_template._env.from_string(src)
def apply_template(template: str, doc: dict) -> str:
if not hasattr(apply_template, "_env"):
apply_template._env = Environment(
loader=BaseLoader(),
undefined=StrictUndefined,
keep_trailing_newline=True,
)
apply_template._env.filters["regex_replace"] = regex_replace
return _compile_tpl(template).render(**doc)
def iter_yaml_files(root: Path) -> Generator[Path, Any, None]:
# '**/*.yaml' is handled internally by os.scandir.
for path in iglob("**/*.yaml", root_dir=root, recursive=True):
......
......@@ -13,6 +13,7 @@ from itertools import islice
from typing import Any, Callable, Generator, List, Optional, Tuple
import numpy as np
from jinja2 import BaseLoader, Environment, StrictUndefined
SPACING = " " * 47
......@@ -511,3 +512,25 @@ def hash_dict_images(data_dict):
if importlib.util.find_spec("PIL")
else data_dict
)
def regex_replace(string, pattern, repl, count: int = 0):
"""Implements the `re.sub` function as a custom Jinja filter."""
return re.sub(pattern, repl, string, count=count)
@functools.lru_cache(maxsize=256)
def _compile_tpl(src: str):
return apply_template._env.from_string(src)
def apply_template(template: str, doc: dict) -> str:
if not hasattr(apply_template, "_env"):
apply_template._env = Environment(
loader=BaseLoader(),
undefined=StrictUndefined,
keep_trailing_newline=True,
)
apply_template._env.filters["regex_replace"] = regex_replace
return _compile_tpl(template).render(**doc)
......@@ -11,8 +11,8 @@ from lm_eval.api.metrics import (
stderr_for_metric,
)
from lm_eval.models.utils import Collator
from lm_eval.tasks import apply_template
from lm_eval.utils import (
apply_template,
get_rolling_token_windows,
make_disjoint_window,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment