"...gpu/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "efa5dcceb2b6a466fd44d5b58acffb90727b79e2"
Unverified commit c5ed8cdc authored by Lintang Sutawika, committed by GitHub

Merge pull request #501 from EleutherAI/update-config

Update config
parents f6b76f5d c17e3659
import re


def doc_to_text(x):
    def _mark_span(text, span_str, span_idx, mark):
        pattern_tmpl = r"^((?:\S+\s){N})(W)"
        pattern = re.sub("N", str(span_idx), pattern_tmpl)
        pattern = re.sub("W", span_str, pattern)
        return re.sub(pattern, r"\1{0} \2 {0}".format(mark), text)

    text = x["text"]
    text = _mark_span(text, x["span1_text"], x["span1_index"], "*")
    # Compensate for the 2 "words" added in the previous step.
    span2_index = x["span2_index"] + 2 * (x["span1_index"] < x["span2_index"])
    text = _mark_span(text, x["span2_text"], span2_index, "#")
    return text
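
For reference, a minimal sketch of what this preprocessing produces on a hypothetical WSC-style record (the field values below are invented for illustration, the span indices are 0-based whitespace-token positions, and doc_to_text above is assumed to be in scope):

# Hypothetical record, not taken from the dataset.
example = {
    "text": "Mark told Pete many lies about himself, which Pete included in his book. He should have been more skeptical.",
    "span1_text": "Mark",
    "span1_index": 0,
    "span2_text": "He",
    "span2_index": 13,
}

print(doc_to_text(example))
# * Mark * told Pete many lies about himself, which Pete included in his book. # He # should have been more skeptical.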
group:
  - super-glue-t5-prompt
task: t5-prompt
reference: "From Raffel et al., 2019"
dataset_path: super_glue
dataset_name: wsc
training_split: train
validation_split: validation
doc_to_text: !function "preprocess_wsc.doc_to_text"
doc_to_target: "{% set answer_choices = ['False', 'True'] %}{{answer_choices[label]}}"
metric_list:
  - metric: exact_match
    aggregation: mean
    higher_is_better: true
    ignore_case: true
    ignore_punctuation: true
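
The doc_to_target field is a Jinja template. A minimal standalone sketch of how it resolves for a document with label = 1 (this is not the harness's own rendering path, just an illustration using the same Environment settings that appear later in this diff):

from jinja2 import BaseLoader, Environment, StrictUndefined

env = Environment(loader=BaseLoader, undefined=StrictUndefined)
template = "{% set answer_choices = ['False', 'True'] %}{{answer_choices[label]}}"

# Render against a document that only needs to provide `label`.
print(env.from_string(template).render(label=1))  # -> True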
from . import arc
from . import gsm8k
from . import lambada
from . import pile
from . import wikitext
# TODO: define via __all__
@@ -10,8 +10,10 @@ NOTE: This `Task` is based on WikiText-2.
 Homepage: https://www.salesforce.com/products/einstein/ai-research/the-wikitext-dependency-language-modeling-dataset/
 """
 import re
-from lm_eval.api.task import PerplexityTask, register_task
+from lm_eval.api.task import PerplexityTask
+from lm_eval.api.register import register_task, register_group

 _CITATION = """
 @misc{merity2016pointer,
@@ -58,6 +60,7 @@ def wikitext_detokenizer(string):
     return string

 @register_task("wikitext")
 class WikiText(PerplexityTask):
     VERSION = "2.0"
...
 import os
-import pathlib
 import re
-import collections
-import functools
-import inspect
 import sys
+import yaml
+import inspect
+import pathlib
+import functools
+import subprocess
+import collections
+import importlib.util

 from typing import List

 from omegaconf import OmegaConf
@@ -146,7 +150,6 @@ class Reorderer:
         return res

 def make_table(result_dict):
     """Generate table of results."""
     from pytablewriter import MarkdownTableWriter, LatexTableWriter
@@ -253,6 +256,61 @@ def get_git_commit_hash():
     return git_hash

def import_function(loader, node):
    """Resolve a `module.function` reference to a callable, loading the module
    from a .py file that sits in the same directory as the YAML file."""
    function_name = loader.construct_scalar(node)
    yaml_path = os.path.dirname(loader.name)

    module_name, function_name = function_name.split(".")
    module_path = os.path.join(yaml_path, "{}.py".format(module_name))

    spec = importlib.util.spec_from_file_location(module_name, module_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    function = getattr(module, function_name)
    return function


# Add the import_function constructor to the YAML loader
yaml.add_constructor("!function", import_function)
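
A small sketch of the !function tag in action, assuming the constructor above has already been registered (e.g. by importing the module that defines it); the helpers.py module and its shout function are made-up stand-ins written to a throwaway directory:

import os
import tempfile

import yaml

with tempfile.TemporaryDirectory() as tmp:
    # A hypothetical helper module sitting next to the YAML file.
    with open(os.path.join(tmp, "helpers.py"), "w") as f:
        f.write("def shout(s):\n    return s.upper()\n")
    with open(os.path.join(tmp, "task.yaml"), "w") as f:
        f.write('doc_to_text: !function "helpers.shout"\n')

    with open(os.path.join(tmp, "task.yaml"), "rb") as f:
        cfg = yaml.full_load(f)

    # The tag resolves to the callable itself.
    print(cfg["doc_to_text"]("hello"))  # -> HELLO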

def load_yaml_config(yaml_path):
    with open(yaml_path, "rb") as file:
        yaml_config = yaml.full_load(file)
        yaml_dir = os.path.dirname(yaml_path)

        if "include" in yaml_config:
            include_path = yaml_config["include"]
            del yaml_config["include"]

            if type(include_path) == str:
                include_path = [include_path]

            # Load from the last one first
            include_path.reverse()
            final_yaml_config = {}
            for path in include_path:
                # Assumes that path is a full path.
                # If not found, assume the included yaml
                # is in the same dir as the original yaml
                if not os.path.isfile(path):
                    path = os.path.join(yaml_dir, path)

                try:
                    included_yaml_config = load_yaml_config(path)
                    final_yaml_config.update(included_yaml_config)
                except Exception as ex:
                    # If an include fails to load, surface the error.
                    raise ex

            final_yaml_config.update(yaml_config)
            return final_yaml_config

        return yaml_config
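
And a small sketch of the include merging, again with throwaway files (the keys are hypothetical, and load_yaml_config above is assumed to be in scope); the including file is merged last, so its keys win:

import os
import tempfile

with tempfile.TemporaryDirectory() as tmp:
    with open(os.path.join(tmp, "base.yaml"), "w") as f:
        f.write("num_fewshot: 0\nmetric: exact_match\n")
    with open(os.path.join(tmp, "task.yaml"), "w") as f:
        f.write("include: base.yaml\nnum_fewshot: 5\n")

    # -> {'num_fewshot': 5, 'metric': 'exact_match'}
    print(load_yaml_config(os.path.join(tmp, "task.yaml")))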

env = Environment(loader=BaseLoader, undefined=StrictUndefined)
...
-import argparse
+import os
 import json
-import logging
 import fnmatch
-import yaml
-import os
+import argparse

-from lm_eval import evaluator, tasks
-from lm_eval.api.task import ConfigurableTask, TASK_REGISTRY
+from lm_eval import evaluator, utils
+from lm_eval.tasks import ALL_TASKS
+from lm_eval.logger import eval_logger

-logging.getLogger("openai").setLevel(logging.WARNING)
-os.environ['TOKENIZERS_PARALLELISM'] = 'false'
-ALL_TASKS = sorted(list(TASK_REGISTRY))
+os.environ["TOKENIZERS_PARALLELISM"] = "false"

 class MultiChoice:
     def __init__(self, choices):
         self.choices = choices
+        print(f"{ALL_TASKS} is this")

     # Simple wildcard support (linux filename patterns)
     def __contains__(self, values):
         for value in values.split(","):
             if len(fnmatch.filter(self.choices, value)) == 0:
-                return False
+                eval_logger.warning("{} is not in task list.".format(value))
+                # eval_logger.info(f"{choices} is this")

         return True
@@ -47,7 +44,6 @@ def parse_args():
     parser.add_argument("--decontamination_ngrams_path", default=None)
     parser.add_argument("--description_dict_path", default=None)
     parser.add_argument("--check_integrity", action="store_true")
     return parser.parse_args()
@@ -65,30 +61,29 @@ def main():
     args = parse_args()

     if args.limit:
-        print(
-            "WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
+        eval_logger.warning(
+            " --limit SHOULD ONLY BE USED FOR TESTING."
+            "REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
         )

-    if args.tasks is None:
-        if args.config:
-            task_names = []
-            for config_files in args.config.split(","):
-                with open(config_files, "r") as f:
-                    config = yaml.load(f, yaml.Loader)
-                    if args.num_fewshot != 0:
-                        config["num_fewshot"] = args.num_fewshot
-                    if args.batch_size != None:
-                        config["batch_size"] = args.batch_size
-                    task_names.append(config)
-        else:
-            task_names = ALL_TASKS
-    else:
-        task_names = pattern_match(args.tasks.split(","), ALL_TASKS)
-    print(f"Selected Tasks: {task_names}")
+    if args.tasks is not None:
+        if os.path.isdir(args.tasks):
+            import glob
+            task_names = []
+            yaml_path = os.path.join(args.tasks, "*.yaml")
+            for yaml_file in glob.glob(yaml_path):
+                config = utils.load_yaml_config(yaml_file)
+                task_names.append(config)
+        else:
+            tasks_list = args.tasks.split(",")
+            task_names = pattern_match(tasks_list, ALL_TASKS)
+            for task in [task for task in tasks_list if task not in task_names]:
+                if os.path.isfile(task):
+                    config = utils.load_yaml_config(task)
+                    task_names.append(config)
+    eval_logger.info(f"Selected Tasks: {task_names}")

     results = evaluator.simple_evaluate(
         model=args.model,
...
@@ -42,6 +42,6 @@ setuptools.setup(
     extras_require={
         "dev": ["black", "flake8", "pre-commit", "pytest", "pytest-cov"],
         "multilingual": ["nagisa>=0.2.7", "jieba>=0.42.1"],
-        "sentencepiece": ["sentencepiece>=0.1.98", "protobuf>=4.22.1"]
+        "sentencepiece": ["sentencepiece>=0.1.98", "protobuf>=4.22.1"],
     },
 )