Commit 6fc2ac49 authored by Baber
Browse files

fix circular import

parent a9c16905
...@@ -24,11 +24,9 @@ import datasets ...@@ -24,11 +24,9 @@ import datasets
import numpy as np import numpy as np
from tqdm import tqdm from tqdm import tqdm
import lm_eval.tasks
from lm_eval import utils from lm_eval import utils
from lm_eval.api import samplers from lm_eval.api import samplers
from lm_eval.api.instance import Instance, OutputType from lm_eval.api.instance import Instance, OutputType
from lm_eval.api.metrics import bits_per_byte, mean, weighted_perplexity
from lm_eval.api.registry import ( from lm_eval.api.registry import (
AGGREGATION_REGISTRY, AGGREGATION_REGISTRY,
DEFAULT_METRIC_REGISTRY, DEFAULT_METRIC_REGISTRY,
...@@ -1125,7 +1123,7 @@ class ConfigurableTask(Task): ...@@ -1125,7 +1123,7 @@ class ConfigurableTask(Task):
# get task description # get task description
if description := self.config.description: if description := self.config.description:
description = lm_eval.tasks.apply_template(self.config.description, doc) description = utils.apply_template(self.config.description, doc)
# create system prompt based on the provided system instruction and description # create system prompt based on the provided system instruction and description
if system_instruction is not None and description: if system_instruction is not None and description:
...@@ -1260,7 +1258,7 @@ class ConfigurableTask(Task): ...@@ -1260,7 +1258,7 @@ class ConfigurableTask(Task):
return doc_to_decontamination_query(doc) return doc_to_decontamination_query(doc)
else: else:
return ast.literal_eval( return ast.literal_eval(
lm_eval.tasks.apply_template( utils.apply_template(
self.config.doc_to_decontamination_query, doc self.config.doc_to_decontamination_query, doc
) )
) )
...@@ -1293,7 +1291,7 @@ class ConfigurableTask(Task): ...@@ -1293,7 +1291,7 @@ class ConfigurableTask(Task):
# else: # else:
return doc[doc_to_text] return doc[doc_to_text]
else: else:
text_string = lm_eval.tasks.apply_template(doc_to_text, doc) text_string = utils.apply_template(doc_to_text, doc)
if text_string.isdigit() and self._config.doc_to_choice is not None: if text_string.isdigit() and self._config.doc_to_choice is not None:
return ast.literal_eval(text_string) return ast.literal_eval(text_string)
else: else:
...@@ -1329,7 +1327,7 @@ class ConfigurableTask(Task): ...@@ -1329,7 +1327,7 @@ class ConfigurableTask(Task):
# else: # else:
return doc[doc_to_target] return doc[doc_to_target]
else: else:
target_string = lm_eval.tasks.apply_template(doc_to_target, doc) target_string = utils.apply_template(doc_to_target, doc)
if target_string.isdigit() and self._config.doc_to_choice is not None: if target_string.isdigit() and self._config.doc_to_choice is not None:
return ast.literal_eval(target_string) return ast.literal_eval(target_string)
elif ( elif (
...@@ -1372,9 +1370,7 @@ class ConfigurableTask(Task): ...@@ -1372,9 +1370,7 @@ class ConfigurableTask(Task):
if doc_to_choice in self.features: if doc_to_choice in self.features:
return doc[doc_to_choice] return doc[doc_to_choice]
else: else:
return ast.literal_eval( return ast.literal_eval(utils.apply_template(doc_to_choice, doc))
lm_eval.tasks.apply_template(doc_to_choice, doc)
)
elif isinstance(doc_to_choice, list): elif isinstance(doc_to_choice, list):
return doc_to_choice return doc_to_choice
elif isinstance(doc_to_choice, dict): elif isinstance(doc_to_choice, dict):
...@@ -1403,7 +1399,7 @@ class ConfigurableTask(Task): ...@@ -1403,7 +1399,7 @@ class ConfigurableTask(Task):
if doc_to_image in self.features: if doc_to_image in self.features:
return doc[doc_to_image] return doc[doc_to_image]
else: else:
return ast.literal_eval(lm_eval.tasks.apply_template(doc_to_image, doc)) return ast.literal_eval(utils.apply_template(doc_to_image, doc))
elif callable(doc_to_image): elif callable(doc_to_image):
return doc_to_image(doc) return doc_to_image(doc)
else: else:
...@@ -1426,7 +1422,7 @@ class ConfigurableTask(Task): ...@@ -1426,7 +1422,7 @@ class ConfigurableTask(Task):
if doc_to_audio in self.features: if doc_to_audio in self.features:
return doc[doc_to_audio] return doc[doc_to_audio]
else: else:
return ast.literal_eval(lm_eval.tasks.apply_template(doc_to_audio, doc)) return ast.literal_eval(utils.apply_template(doc_to_audio, doc))
elif callable(doc_to_audio): elif callable(doc_to_audio):
return doc_to_audio(doc) return doc_to_audio(doc)
else: else:
...@@ -1437,7 +1433,7 @@ class ConfigurableTask(Task): ...@@ -1437,7 +1433,7 @@ class ConfigurableTask(Task):
if gen_prefix in self.features: if gen_prefix in self.features:
return doc[gen_prefix] return doc[gen_prefix]
else: else:
return lm_eval.tasks.apply_template(gen_prefix, doc) return utils.apply_template(gen_prefix, doc)
return None return None
def construct_requests( def construct_requests(
...@@ -1802,6 +1798,8 @@ class MultipleChoiceTask(Task): ...@@ -1802,6 +1798,8 @@ class MultipleChoiceTask(Task):
} }
def aggregation(self) -> dict: def aggregation(self) -> dict:
from lm_eval.api.metrics import mean
return { return {
"acc": mean, "acc": mean,
"acc_norm": mean, "acc_norm": mean,
...@@ -1868,6 +1866,8 @@ class PerplexityTask(Task): ...@@ -1868,6 +1866,8 @@ class PerplexityTask(Task):
} }
def aggregation(self) -> dict: def aggregation(self) -> dict:
from lm_eval.api.metrics import bits_per_byte, weighted_perplexity
return { return {
"word_perplexity": weighted_perplexity, "word_perplexity": weighted_perplexity,
"byte_perplexity": weighted_perplexity, "byte_perplexity": weighted_perplexity,
......
...@@ -4,6 +4,7 @@ import os ...@@ -4,6 +4,7 @@ import os
from typing import Dict from typing import Dict
import lm_eval.tasks import lm_eval.tasks
import lm_eval.utils
from lm_eval import utils from lm_eval import utils
...@@ -123,7 +124,7 @@ class PromptString: ...@@ -123,7 +124,7 @@ class PromptString:
if "doc_to_choice" in self.prompt_string: if "doc_to_choice" in self.prompt_string:
raise NotImplementedError("Not yet implemented to accept doc_to_choice") raise NotImplementedError("Not yet implemented to accept doc_to_choice")
text_string = lm_eval.tasks.apply_template(doc_to_text, doc) text_string = lm_eval.utils.apply_template(doc_to_text, doc)
target_string = lm_eval.tasks.apply_template(doc_to_target, doc) target_string = lm_eval.utils.apply_template(doc_to_target, doc)
return [text_string, target_string] return [text_string, target_string]
...@@ -3,7 +3,6 @@ import functools ...@@ -3,7 +3,6 @@ import functools
import importlib.util import importlib.util
import inspect import inspect
import logging import logging
import re
import sys import sys
from functools import partial from functools import partial
from glob import iglob from glob import iglob
...@@ -11,7 +10,6 @@ from pathlib import Path ...@@ -11,7 +10,6 @@ from pathlib import Path
from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Union from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Union
import yaml import yaml
from jinja2 import BaseLoader, Environment, StrictUndefined
from yaml import YAMLError from yaml import YAMLError
from lm_eval import utils from lm_eval import utils
...@@ -177,28 +175,6 @@ def load_yaml_config( ...@@ -177,28 +175,6 @@ def load_yaml_config(
return final_cfg return final_cfg
def regex_replace(string, pattern, repl, count: int = 0):
    """Jinja filter that performs a regular-expression substitution.

    Replaces matches of `pattern` in `string` with `repl`. `count` limits
    the number of replacements; 0 (the default) replaces every match.
    """
    compiled = re.compile(pattern)
    return compiled.sub(repl, string, count=count)
@functools.lru_cache(maxsize=1)
def _template_env():
    """Create the shared Jinja environment on first use and reuse it afterwards.

    The environment uses `StrictUndefined`, keeps a template's trailing
    newline, and registers the `regex_replace` filter.
    """
    env = Environment(
        loader=BaseLoader(),
        undefined=StrictUndefined,
        keep_trailing_newline=True,
    )
    env.filters["regex_replace"] = regex_replace
    return env


@functools.lru_cache(maxsize=256)
def _compile_tpl(src: str):
    """Compile template source `src`, caching up to 256 distinct sources."""
    # Fetch the environment through the cached factory so this helper works
    # even when it is called before `apply_template` has ever run.
    return _template_env().from_string(src)


def apply_template(template: str, doc: dict) -> str:
    """Render the Jinja `template` string with the fields of `doc` as variables.

    Args:
        template: Jinja template source.
        doc: mapping whose items are passed to the template as keyword
            arguments.

    Returns:
        The rendered string.
    """
    return _compile_tpl(template).render(**doc)
def iter_yaml_files(root: Path) -> Generator[Path, Any, None]: def iter_yaml_files(root: Path) -> Generator[Path, Any, None]:
# '**/*.yaml' is handled internally by os.scandir. # '**/*.yaml' is handled internally by os.scandir.
for path in iglob("**/*.yaml", root_dir=root, recursive=True): for path in iglob("**/*.yaml", root_dir=root, recursive=True):
......
...@@ -13,6 +13,7 @@ from itertools import islice ...@@ -13,6 +13,7 @@ from itertools import islice
from typing import Any, Callable, Generator, List, Optional, Tuple from typing import Any, Callable, Generator, List, Optional, Tuple
import numpy as np import numpy as np
from jinja2 import BaseLoader, Environment, StrictUndefined
SPACING = " " * 47 SPACING = " " * 47
...@@ -511,3 +512,25 @@ def hash_dict_images(data_dict): ...@@ -511,3 +512,25 @@ def hash_dict_images(data_dict):
if importlib.util.find_spec("PIL") if importlib.util.find_spec("PIL")
else data_dict else data_dict
) )
def regex_replace(string, pattern, repl, count: int = 0):
    """Jinja filter that performs a regular-expression substitution.

    Replaces matches of `pattern` in `string` with `repl`. `count` limits
    the number of replacements; 0 (the default) replaces every match.
    """
    compiled = re.compile(pattern)
    return compiled.sub(repl, string, count=count)
@functools.lru_cache(maxsize=1)
def _template_env():
    """Create the shared Jinja environment on first use and reuse it afterwards.

    The environment uses `StrictUndefined`, keeps a template's trailing
    newline, and registers the `regex_replace` filter.
    """
    env = Environment(
        loader=BaseLoader(),
        undefined=StrictUndefined,
        keep_trailing_newline=True,
    )
    env.filters["regex_replace"] = regex_replace
    return env


@functools.lru_cache(maxsize=256)
def _compile_tpl(src: str):
    """Compile template source `src`, caching up to 256 distinct sources."""
    # Fetch the environment through the cached factory so this helper works
    # even when it is called before `apply_template` has ever run.
    return _template_env().from_string(src)


def apply_template(template: str, doc: dict) -> str:
    """Render the Jinja `template` string with the fields of `doc` as variables.

    Args:
        template: Jinja template source.
        doc: mapping whose items are passed to the template as keyword
            arguments.

    Returns:
        The rendered string.
    """
    return _compile_tpl(template).render(**doc)
...@@ -11,8 +11,8 @@ from lm_eval.api.metrics import ( ...@@ -11,8 +11,8 @@ from lm_eval.api.metrics import (
stderr_for_metric, stderr_for_metric,
) )
from lm_eval.models.utils import Collator from lm_eval.models.utils import Collator
from lm_eval.tasks import apply_template
from lm_eval.utils import ( from lm_eval.utils import (
apply_template,
get_rolling_token_windows, get_rolling_token_windows,
make_disjoint_window, make_disjoint_window,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment