"vscode:/vscode.git/clone" did not exist on "89120f1fbe0900b263b063942470934038e46faa"
Unverified Commit d924ca33 authored by ben's avatar ben Committed by GitHub
Browse files

Merge pull request #2 from EleutherAI/multigpu-feature-minor-edits

Multigpu feature minor edits
parents 650d3c76 c77fa461
from . import arc
from . import gsm8k
from . import lambada
from . import pile
from . import wikitext
# TODO: define via __all__
\ No newline at end of file
...@@ -10,8 +10,10 @@ NOTE: This `Task` is based on WikiText-2. ...@@ -10,8 +10,10 @@ NOTE: This `Task` is based on WikiText-2.
Homepage: https://www.salesforce.com/products/einstein/ai-research/the-wikitext-dependency-language-modeling-dataset/ Homepage: https://www.salesforce.com/products/einstein/ai-research/the-wikitext-dependency-language-modeling-dataset/
""" """
import re import re
from lm_eval.api.task import PerplexityTask, register_task
from lm_eval.api.task import PerplexityTask
from lm_eval.api.register import register_task, register_group
_CITATION = """ _CITATION = """
@misc{merity2016pointer, @misc{merity2016pointer,
...@@ -58,6 +60,7 @@ def wikitext_detokenizer(string): ...@@ -58,6 +60,7 @@ def wikitext_detokenizer(string):
return string return string
@register_task("wikitext") @register_task("wikitext")
class WikiText(PerplexityTask): class WikiText(PerplexityTask):
VERSION = "2.0" VERSION = "2.0"
......
import os import os
import pathlib
import re import re
import collections
import functools
import inspect
import sys import sys
import yaml
import inspect
import pathlib
import functools
import subprocess
import collections
import importlib.util
from typing import List from typing import List
from omegaconf import OmegaConf from omegaconf import OmegaConf
from jinja2 import BaseLoader, Environment, StrictUndefined from jinja2 import BaseLoader, Environment, StrictUndefined
from itertools import islice from itertools import islice
class ExitCodeError(Exception): class ExitCodeError(Exception):
pass pass
...@@ -146,7 +151,6 @@ class Reorderer: ...@@ -146,7 +151,6 @@ class Reorderer:
return res return res
def make_table(result_dict): def make_table(result_dict):
"""Generate table of results.""" """Generate table of results."""
from pytablewriter import MarkdownTableWriter, LatexTableWriter from pytablewriter import MarkdownTableWriter, LatexTableWriter
...@@ -253,6 +257,61 @@ def get_git_commit_hash(): ...@@ -253,6 +257,61 @@ def get_git_commit_hash():
return git_hash return git_hash
def import_function(loader, node):
function_name = loader.construct_scalar(node)
yaml_path = os.path.dirname(loader.name)
module_name, function_name = function_name.split(".")
module_path = os.path.join(yaml_path, "{}.py".format(module_name))
spec = importlib.util.spec_from_file_location(module_name, module_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
function = getattr(module, function_name)
return function
# Add the import_function constructor to the YAML loader
yaml.add_constructor("!function", import_function)
def load_yaml_config(yaml_path):
with open(yaml_path, "rb") as file:
yaml_config = yaml.full_load(file)
yaml_dir = os.path.dirname(yaml_path)
if "include" in yaml_config:
include_path = yaml_config["include"]
del yaml_config["include"]
if type(include_path) == str:
include_path = [include_path]
# Load from the last one first
include_path.reverse()
final_yaml_config = {}
for path in include_path:
# Assumes that path is a full path.
# If not found, assume the included yaml
# is in the same dir as the original yaml
if not os.path.isfile(path):
path = os.path.join(yaml_dir, path)
try:
included_yaml_config = load_yaml_config(path)
final_yaml_config.update(included_yaml_config)
except Exception as ex:
# If failed to load, ignore
raise ex
final_yaml_config.update(yaml_config)
return final_yaml_config
return yaml_config
env = Environment(loader=BaseLoader, undefined=StrictUndefined) env = Environment(loader=BaseLoader, undefined=StrictUndefined)
...@@ -261,10 +320,10 @@ def apply_template(template, doc): ...@@ -261,10 +320,10 @@ def apply_template(template, doc):
return rtemplate.render(**doc) return rtemplate.render(**doc)
def create_iterator(raw_iterator, rank, world_size, limit = None): def create_iterator(raw_iterator, rank, world_size, limit=None):
""" """
Method for creating a (potentially) sliced and limited Method for creating a (potentially) sliced and limited
iterator from a raw document iterator. Used for splitting data iterator from a raw document iterator. Used for splitting data
among ranks in multigpu setting or only pulling a sample of documents among ranks in multigpu setting or only pulling a sample of documents
""" """
return islice(raw_iterator, rank, limit, world_size) return islice(raw_iterator, rank, limit, world_size)
import argparse import os
import json import json
import logging
import fnmatch import fnmatch
import yaml import argparse
import os
from lm_eval import evaluator, tasks from lm_eval import evaluator, utils
from lm_eval.api.task import ConfigurableTask, TASK_REGISTRY from lm_eval.tasks import ALL_TASKS
from lm_eval.logger import eval_logger
logging.getLogger("openai").setLevel(logging.WARNING) os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
ALL_TASKS = sorted(list(TASK_REGISTRY))
class MultiChoice: class MultiChoice:
def __init__(self, choices): def __init__(self, choices):
self.choices = choices self.choices = choices
print(f"{ALL_TASKS} is this")
# Simple wildcard support (linux filename patterns) # Simple wildcard support (linux filename patterns)
def __contains__(self, values): def __contains__(self, values):
for value in values.split(","): for value in values.split(","):
if len(fnmatch.filter(self.choices, value)) == 0: if len(fnmatch.filter(self.choices, value)) == 0:
return False eval_logger.warning("{} is not in task list.".format(value))
# eval_logger.info(f"{choices} is this")
return True return True
...@@ -47,7 +44,6 @@ def parse_args(): ...@@ -47,7 +44,6 @@ def parse_args():
parser.add_argument("--decontamination_ngrams_path", default=None) parser.add_argument("--decontamination_ngrams_path", default=None)
parser.add_argument("--description_dict_path", default=None) parser.add_argument("--description_dict_path", default=None)
parser.add_argument("--check_integrity", action="store_true") parser.add_argument("--check_integrity", action="store_true")
return parser.parse_args() return parser.parse_args()
...@@ -65,30 +61,29 @@ def main(): ...@@ -65,30 +61,29 @@ def main():
args = parse_args() args = parse_args()
if args.limit: if args.limit:
print( eval_logger.warning(
"WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT." " --limit SHOULD ONLY BE USED FOR TESTING."
"REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
) )
if args.tasks is None: if args.tasks is not None:
if args.config: if os.path.isdir(args.tasks):
task_names = [] import glob
for config_files in args.config.split(","):
with open(config_files, "r") as f:
config = yaml.load(f, yaml.Loader)
if args.num_fewshot != 0:
config["num_fewshot"] = args.num_fewshot
if args.batch_size != None:
config["batch_size"] = args.batch_size
task_names = []
yaml_path = os.path.join(args.tasks, "*.yaml")
for yaml_file in glob.glob(yaml_path):
config = utils.load_yaml_config(yaml_file)
task_names.append(config) task_names.append(config)
else: else:
task_names = ALL_TASKS tasks_list = args.tasks.split(",")
else: task_names = pattern_match(tasks_list, ALL_TASKS)
task_names = pattern_match(args.tasks.split(","), ALL_TASKS) for task in [task for task in tasks_list if task not in task_names]:
if os.path.isfile(task):
print(f"Selected Tasks: {task_names}") config = utils.load_yaml_config(task)
task_names.append(config)
eval_logger.info(f"Selected Tasks: {task_names}")
results = evaluator.simple_evaluate( results = evaluator.simple_evaluate(
model=args.model, model=args.model,
......
...@@ -43,6 +43,6 @@ setuptools.setup( ...@@ -43,6 +43,6 @@ setuptools.setup(
extras_require={ extras_require={
"dev": ["black", "flake8", "pre-commit", "pytest", "pytest-cov"], "dev": ["black", "flake8", "pre-commit", "pytest", "pytest-cov"],
"multilingual": ["nagisa>=0.2.7", "jieba>=0.42.1"], "multilingual": ["nagisa>=0.2.7", "jieba>=0.42.1"],
"sentencepiece": ["sentencepiece>=0.1.98", "protobuf>=4.22.1"] "sentencepiece": ["sentencepiece>=0.1.98", "protobuf>=4.22.1"],
}, },
) )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment