Unverified Commit c5ed8cdc authored by Lintang Sutawika, committed by GitHub

Merge pull request #501 from EleutherAI/update-config

Update config
parents f6b76f5d c17e3659
import re


def doc_to_text(x):
    def _mark_span(text, span_str, span_idx, mark):
        # Surround the span_idx-th word (which should equal span_str) with `mark`.
        pattern_tmpl = r"^((?:\S+\s){N})(W)"
        pattern = re.sub("N", str(span_idx), pattern_tmpl)
        pattern = re.sub("W", span_str, pattern)
        return re.sub(pattern, r"\1{0} \2 {0}".format(mark), text)

    text = x["text"]
    text = _mark_span(text, x["span1_text"], x["span1_index"], "*")
    # Compensate for the two marker "words" added in the previous step.
    span2_index = x["span2_index"] + 2 * (x["span1_index"] < x["span2_index"])
    text = _mark_span(text, x["span2_text"], span2_index, "#")
    return text
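For intuition, a quick sketch of what this preprocessing yields on a made-up WSC-style record (the field values are illustrative, not taken from the dataset; span indices are 0-based word positions):

doc = {
    "text": "Mark told Pete many lies about himself, which Pete included in his book.",
    "span1_text": "Pete",
    "span1_index": 2,
    "span2_text": "his",
    "span2_index": 11,
}

print(doc_to_text(doc))
# Mark told * Pete * many lies about himself, which Pete included in # his # book.

The `*` pair marks span1 and the `#` pair marks span2; the `+ 2` offset compensates for the two `*` tokens inserted ahead of span2 whenever span1 precedes it.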
group:
  - super-glue-t5-prompt
task: t5-prompt
reference: "From Raffel et al. 2019"
dataset_path: super_glue
dataset_name: wsc
training_split: train
validation_split: validation
doc_to_text: !function "preprocess_wsc.doc_to_text"
doc_to_target: "{% set answer_choices = ['False', 'True'] %}{{answer_choices[label]}}"
metric_list:
  - metric: exact_match
    aggregation: mean
    higher_is_better: true
    ignore_case: true
    ignore_punctuation: true
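The `doc_to_target` value is a Jinja template. As a minimal sketch of the rendering mechanics, here it is evaluated with plain `jinja2`, using the same `Environment` settings the harness sets up in its utils:

from jinja2 import BaseLoader, Environment, StrictUndefined

env = Environment(loader=BaseLoader, undefined=StrictUndefined)
template = env.from_string(
    "{% set answer_choices = ['False', 'True'] %}{{answer_choices[label]}}"
)

print(template.render(label=0))  # False
print(template.render(label=1))  # True

`StrictUndefined` turns a missing `label` field into an error instead of silently rendering an empty string.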
from . import arc
from . import gsm8k
from . import lambada
from . import pile
from . import wikitext
# TODO: define via __all__
\ No newline at end of file
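The TODO presumably means declaring the package's public surface explicitly, along the lines of:

__all__ = ["arc", "gsm8k", "lambada", "pile", "wikitext"]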
@@ -10,8 +10,10 @@ NOTE: This `Task` is based on WikiText-2.
Homepage: https://www.salesforce.com/products/einstein/ai-research/the-wikitext-dependency-language-modeling-dataset/
"""
import re
-from lm_eval.api.task import PerplexityTask, register_task
+from lm_eval.api.task import PerplexityTask
+from lm_eval.api.register import register_task, register_group
_CITATION = """
@misc{merity2016pointer,
@@ -58,6 +60,7 @@ def wikitext_detokenizer(string):
    return string


+@register_task("wikitext")
class WikiText(PerplexityTask):
    VERSION = "2.0"
import os
-import pathlib
import re
-import collections
-import functools
-import inspect
import sys
+import yaml
+import inspect
+import pathlib
+import functools
+import subprocess
+import collections
+import importlib.util

from typing import List

from omegaconf import OmegaConf
@@ -146,7 +150,6 @@ class Reorderer:
        return res


def make_table(result_dict):
    """Generate table of results."""
    from pytablewriter import MarkdownTableWriter, LatexTableWriter
@@ -253,6 +256,61 @@ def get_git_commit_hash():
    return git_hash


def import_function(loader, node):
    function_name = loader.construct_scalar(node)
    yaml_path = os.path.dirname(loader.name)

    module_name, function_name = function_name.split(".")
    module_path = os.path.join(yaml_path, "{}.py".format(module_name))

    spec = importlib.util.spec_from_file_location(module_name, module_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    function = getattr(module, function_name)
    return function


# Add the import_function constructor to the YAML loader.
yaml.add_constructor("!function", import_function)
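To make the resolution rule concrete: the tag's scalar is split into `module.function`, and `module` is looked up as a `.py` file sitting next to the YAML being parsed (PyYAML exposes the file's path as `loader.name`). A hypothetical layout and load, assuming this module has been imported so the constructor is registered:

# tasks/wsc/preprocess_wsc.py  -- defines doc_to_text(x)
# tasks/wsc/task.yaml          -- contains: doc_to_text: !function "preprocess_wsc.doc_to_text"
import yaml

with open("tasks/wsc/task.yaml", "rb") as f:
    config = yaml.full_load(f)

assert callable(config["doc_to_text"])  # the value is now the function object itself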
def load_yaml_config(yaml_path):
    with open(yaml_path, "rb") as file:
        yaml_config = yaml.full_load(file)
        yaml_dir = os.path.dirname(yaml_path)

        if "include" in yaml_config:
            include_path = yaml_config["include"]
            del yaml_config["include"]

            if type(include_path) == str:
                include_path = [include_path]

            # Load from the last one first, so that earlier entries in
            # `include` take precedence.
            include_path.reverse()
            final_yaml_config = {}
            for path in include_path:
                # Assumes that path is a full path.
                # If not found, assume the included yaml
                # is in the same dir as the original yaml.
                if not os.path.isfile(path):
                    path = os.path.join(yaml_dir, path)

                try:
                    included_yaml_config = load_yaml_config(path)
                    final_yaml_config.update(included_yaml_config)
                except Exception as ex:
                    # If an include fails to load, surface the error.
                    raise ex

            final_yaml_config.update(yaml_config)
            return final_yaml_config

        return yaml_config
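A small sketch of the `include` semantics with hypothetical file contents: keys in the including file override the included defaults, and because the list is reversed before loading, earlier entries in `include` beat later ones on conflicts.

import pathlib
import tempfile

tmp = pathlib.Path(tempfile.mkdtemp())
(tmp / "base.yaml").write_text("batch_size: 8\nnum_fewshot: 0\n")
(tmp / "task.yaml").write_text("include: base.yaml\nnum_fewshot: 5\n")

print(load_yaml_config(str(tmp / "task.yaml")))
# {'batch_size': 8, 'num_fewshot': 5}  -- the including file wins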
env = Environment(loader=BaseLoader, undefined=StrictUndefined)
-import argparse
-import os
import json
import logging
import fnmatch
import yaml
+import os
+import argparse
-from lm_eval import evaluator, tasks
-from lm_eval.api.task import ConfigurableTask, TASK_REGISTRY
+from lm_eval import evaluator, utils
+from lm_eval.tasks import ALL_TASKS
+from lm_eval.logger import eval_logger

logging.getLogger("openai").setLevel(logging.WARNING)

-os.environ['TOKENIZERS_PARALLELISM'] = 'false'
-ALL_TASKS = sorted(list(TASK_REGISTRY))
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
class MultiChoice:
    def __init__(self, choices):
        self.choices = choices
-        print(f"{ALL_TASKS} is this")

    # Simple wildcard support (linux filename patterns)
    def __contains__(self, values):
        for value in values.split(","):
            if len(fnmatch.filter(self.choices, value)) == 0:
-                return False
+                eval_logger.warning("{} is not in task list.".format(value))
-        # eval_logger.info(f"{choices} is this")
        return True
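Roughly how the wildcard check behaves, with illustrative task names. In the updated version an unmatched entry only logs a warning instead of failing the membership test, so a comma-separated entry that is actually a path to a YAML config can still pass argparse validation and reach the `os.path.isfile` handling below:

choices = MultiChoice(["arc_easy", "arc_challenge", "lambada_openai"])

print("arc_*" in choices)                     # True: the pattern matches two tasks
print("arc_easy,lambada_openai" in choices)   # True: every entry matches
print("./my_task.yaml" in choices)            # True, but logs "is not in task list."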
@@ -47,7 +44,6 @@ def parse_args():
parser.add_argument("--decontamination_ngrams_path", default=None)
parser.add_argument("--description_dict_path", default=None)
parser.add_argument("--check_integrity", action="store_true")
return parser.parse_args()
@@ -65,30 +61,29 @@ def main():
    args = parse_args()

    if args.limit:
-        print(
-            "WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
+        eval_logger.warning(
+            " --limit SHOULD ONLY BE USED FOR TESTING."
+            "REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
        )
-    if args.tasks is None:
-        if args.config:
-            task_names = []
-            for config_files in args.config.split(","):
-                with open(config_files, "r") as f:
-                    config = yaml.load(f, yaml.Loader)
-                    if args.num_fewshot != 0:
-                        config["num_fewshot"] = args.num_fewshot
-                    if args.batch_size != None:
-                        config["batch_size"] = args.batch_size
+    if args.tasks is not None:
+        if os.path.isdir(args.tasks):
+            import glob
+
+            task_names = []
+            yaml_path = os.path.join(args.tasks, "*.yaml")
+            for yaml_file in glob.glob(yaml_path):
+                config = utils.load_yaml_config(yaml_file)
+                task_names.append(config)
        else:
            task_names = ALL_TASKS
    else:
-        task_names = pattern_match(args.tasks.split(","), ALL_TASKS)
-    print(f"Selected Tasks: {task_names}")
+        tasks_list = args.tasks.split(",")
+        task_names = pattern_match(tasks_list, ALL_TASKS)
+        for task in [task for task in tasks_list if task not in task_names]:
+            if os.path.isfile(task):
+                config = utils.load_yaml_config(task)
+                task_names.append(config)
+
+    eval_logger.info(f"Selected Tasks: {task_names}")
    results = evaluator.simple_evaluate(
        model=args.model,
@@ -42,6 +42,6 @@ setuptools.setup(
    extras_require={
        "dev": ["black", "flake8", "pre-commit", "pytest", "pytest-cov"],
        "multilingual": ["nagisa>=0.2.7", "jieba>=0.42.1"],
-        "sentencepiece": ["sentencepiece>=0.1.98", "protobuf>=4.22.1"]
+        "sentencepiece": ["sentencepiece>=0.1.98", "protobuf>=4.22.1"],
    },
)