Commit 74df9bea authored by zhaoying1's avatar zhaoying1
Browse files

added deepseekv2

parents
Pipeline #1652 failed with stages
in 0 seconds
import argparse
import json
import logging
import os
import re
import sys
from functools import partial
from pathlib import Path
from typing import Union
import numpy as np
from lm_eval import evaluator, utils
from lm_eval.evaluator import request_caching_arg_to_dict
from lm_eval.logging_utils import WandbLogger
from lm_eval.tasks import TaskManager, include_path, initialize_tasks
from lm_eval.utils import make_table, simple_parse_args_string
DEFAULT_RESULTS_FILE = "results.json"
def _handle_non_serializable(o):
if isinstance(o, np.int64) or isinstance(o, np.int32):
return int(o)
elif isinstance(o, set):
return list(o)
else:
return str(o)
def _int_or_none_list_arg_type(max_len: int, value: str, split_char: str = ","):
def parse_value(item):
item = item.strip().lower()
if item == "none":
return None
try:
return int(item)
except ValueError:
raise argparse.ArgumentTypeError(f"{item} is not an integer or None")
items = [parse_value(v) for v in value.split(split_char)]
num_items = len(items)
if num_items == 1:
# Makes downstream handling the same for single and multiple values
items = items * max_len
elif num_items != max_len:
raise argparse.ArgumentTypeError(
f"Argument requires {max_len} integers or None, separated by '{split_char}'"
)
return items
def parse_eval_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
parser.add_argument("--model", "-m", default="hf", help="Name of model e.g. `hf`")
parser.add_argument(
"--tasks",
"-t",
default=None,
metavar="task1,task2",
help="To get full list of tasks, use the command lm-eval --tasks list",
)
parser.add_argument(
"--model_args",
"-a",
default="",
help="Comma separated string arguments for model, e.g. `pretrained=EleutherAI/pythia-160m,dtype=float32`",
)
parser.add_argument(
"--num_fewshot",
"-f",
type=int,
default=None,
metavar="N",
help="Number of examples in few-shot context",
)
parser.add_argument(
"--batch_size",
"-b",
type=str,
default=1,
metavar="auto|auto:N|N",
help="Acceptable values are 'auto', 'auto:N' or N, where N is an integer. Default 1.",
)
parser.add_argument(
"--max_batch_size",
type=int,
default=None,
metavar="N",
help="Maximal batch size to try with --batch_size auto.",
)
parser.add_argument(
"--device",
type=str,
default=None,
help="Device to use (e.g. cuda, cuda:0, cpu).",
)
parser.add_argument(
"--output_path",
"-o",
default=None,
type=str,
metavar="DIR|DIR/file.json",
help="The path to the output file where the result metrics will be saved. If the path is a directory and log_samples is true, the results will be saved in the directory. Else the parent directory will be used.",
)
parser.add_argument(
"--limit",
"-L",
type=float,
default=None,
metavar="N|0<N<1",
help="Limit the number of examples per task. "
"If <1, limit is a percentage of the total number of examples.",
)
parser.add_argument(
"--use_cache",
"-c",
type=str,
default=None,
metavar="DIR",
help="A path to a sqlite db file for caching model responses. `None` if not caching.",
)
parser.add_argument(
"--cache_requests",
type=str,
default=None,
choices=["true", "refresh", "delete"],
help="Speed up evaluation by caching the building of dataset requests. `None` if not caching.",
)
parser.add_argument(
"--check_integrity",
action="store_true",
help="Whether to run the relevant part of the test suite for the tasks.",
)
parser.add_argument(
"--write_out",
"-w",
action="store_true",
default=False,
help="Prints the prompt for the first few documents.",
)
parser.add_argument(
"--log_samples",
"-s",
action="store_true",
default=False,
help="If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis. Use with --output_path.",
)
parser.add_argument(
"--show_config",
action="store_true",
default=False,
help="If True, shows the the full config of all tasks at the end of the evaluation.",
)
parser.add_argument(
"--include_path",
type=str,
default=None,
metavar="DIR",
help="Additional path to include if there are external tasks to include.",
)
parser.add_argument(
"--gen_kwargs",
default=None,
help=(
"String arguments for model generation on greedy_until tasks,"
" e.g. `temperature=0,top_k=0,top_p=0`."
),
)
parser.add_argument(
"--verbosity",
"-v",
type=str.upper,
default="INFO",
metavar="CRITICAL|ERROR|WARNING|INFO|DEBUG",
help="Controls the reported logging error level. Set to DEBUG when testing + adding new task configurations for comprehensive log output.",
)
parser.add_argument(
"--wandb_args",
default="",
help="Comma separated string arguments passed to wandb.init, e.g. `project=lm-eval,job_type=eval",
)
parser.add_argument(
"--predict_only",
"-x",
action="store_true",
default=False,
help="Use with --log_samples. Only model outputs will be saved and metrics will not be evaluated.",
)
parser.add_argument(
"--seed",
type=partial(_int_or_none_list_arg_type, 3),
default="0,1234,1234", # for backward compatibility
help=(
"Set seed for python's random, numpy and torch.\n"
"Accepts a comma-separated list of 3 values for python's random, numpy, and torch seeds, respectively, "
"or a single integer to set the same seed for all three.\n"
"The values are either an integer or 'None' to not set the seed. Default is `0,1234,1234` (for backward compatibility).\n"
"E.g. `--seed 0,None,8` sets `random.seed(0)` and `torch.manual_seed(8)`. Here numpy's seed is not set since the second value is `None`.\n"
"E.g, `--seed 42` sets all three seeds to 42."
),
)
parser.add_argument(
"--trust_remote_code",
action="store_true",
help="Sets trust_remote_code to True to execute code to create HF Datasets from the Hub",
)
return parser.parse_args()
def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
if not args:
# we allow for args to be passed externally, else we parse them ourselves
args = parse_eval_args()
if args.wandb_args:
wandb_logger = WandbLogger(**simple_parse_args_string(args.wandb_args))
eval_logger = utils.eval_logger
eval_logger.setLevel(getattr(logging, f"{args.verbosity}"))
eval_logger.info(f"Verbosity set to {args.verbosity}")
os.environ["TOKENIZERS_PARALLELISM"] = "false"
if args.predict_only:
args.log_samples = True
if (args.log_samples or args.predict_only) and not args.output_path:
raise ValueError(
"Specify --output_path if providing --log_samples or --predict_only"
)
initialize_tasks(args.verbosity)
task_manager = TaskManager(args.verbosity, include_path=args.include_path)
if args.limit:
eval_logger.warning(
" --limit SHOULD ONLY BE USED FOR TESTING."
"REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
)
if args.include_path is not None:
eval_logger.info(f"Including path: {args.include_path}")
include_path(args.include_path)
if args.tasks is None:
eval_logger.error("Need to specify task to evaluate.")
sys.exit()
elif args.tasks == "list":
eval_logger.info(
"Available Tasks:\n - {}".format("\n - ".join(task_manager.all_tasks))
)
sys.exit()
else:
if os.path.isdir(args.tasks):
import glob
task_names = []
yaml_path = os.path.join(args.tasks, "*.yaml")
for yaml_file in glob.glob(yaml_path):
config = utils.load_yaml_config(yaml_file)
task_names.append(config)
else:
task_list = args.tasks.split(",")
task_names = task_manager.match_tasks(task_list)
for task in [task for task in task_list if task not in task_names]:
if os.path.isfile(task):
config = utils.load_yaml_config(task)
task_names.append(config)
task_missing = [
task for task in task_list if task not in task_names and "*" not in task
] # we don't want errors if a wildcard ("*") task name was used
if task_missing:
missing = ", ".join(task_missing)
eval_logger.error(
f"Tasks were not found: {missing}\n"
f"{utils.SPACING}Try `lm-eval --tasks list` for list of available tasks",
)
raise ValueError(
f"Tasks not found: {missing}. Try `lm-eval --tasks list` for list of available tasks, or '--verbosity DEBUG' to troubleshoot task registration issues."
)
if args.output_path:
path = Path(args.output_path)
# check if file or 'dir/results.json' exists
if path.is_file():
raise FileExistsError(f"File already exists at {path}")
output_path_file = path.joinpath(DEFAULT_RESULTS_FILE)
if output_path_file.is_file():
eval_logger.warning(
f"File {output_path_file} already exists. Results will be overwritten."
)
# if path json then get parent dir
elif path.suffix in (".json", ".jsonl"):
output_path_file = path
path.parent.mkdir(parents=True, exist_ok=True)
path = path.parent
else:
path.mkdir(parents=True, exist_ok=True)
# Respect user's value passed in via CLI, otherwise default to True and add to comma-separated model args
if args.trust_remote_code:
os.environ["HF_DATASETS_TRUST_REMOTE_CODE"] = str(args.trust_remote_code)
args.model_args = (
args.model_args
+ f",trust_remote_code={os.environ['HF_DATASETS_TRUST_REMOTE_CODE']}"
)
eval_logger.info(f"Selected Tasks: {task_names}")
eval_logger.info("Loading selected tasks...")
request_caching_args = request_caching_arg_to_dict(
cache_requests=args.cache_requests
)
results = evaluator.simple_evaluate(
model=args.model,
model_args=args.model_args,
tasks=task_names,
num_fewshot=args.num_fewshot,
batch_size=args.batch_size,
max_batch_size=args.max_batch_size,
device=args.device,
use_cache=args.use_cache,
limit=args.limit,
check_integrity=args.check_integrity,
write_out=args.write_out,
log_samples=args.log_samples,
gen_kwargs=args.gen_kwargs,
task_manager=task_manager,
predict_only=args.predict_only,
random_seed=args.seed[0],
numpy_random_seed=args.seed[1],
torch_random_seed=args.seed[2],
**request_caching_args,
)
if results is not None:
if args.log_samples:
samples = results.pop("samples")
dumped = json.dumps(
results, indent=2, default=_handle_non_serializable, ensure_ascii=False
)
if args.show_config:
print(dumped)
batch_sizes = ",".join(map(str, results["config"]["batch_sizes"]))
# Add W&B logging
if args.wandb_args:
try:
wandb_logger.post_init(results)
wandb_logger.log_eval_result()
if args.log_samples:
wandb_logger.log_eval_samples(samples)
except Exception as e:
eval_logger.info(f"Logging to Weights and Biases failed due to {e}")
if args.output_path:
output_path_file.open("w", encoding="utf-8").write(dumped)
if args.log_samples:
for task_name, config in results["configs"].items():
output_name = "{}_{}".format(
re.sub("/|=", "__", args.model_args), task_name
)
filename = path.joinpath(f"{output_name}.jsonl")
samples_dumped = json.dumps(
samples[task_name],
indent=2,
default=_handle_non_serializable,
ensure_ascii=False,
)
filename.write_text(samples_dumped, encoding="utf-8")
print(
f"{args.model} ({args.model_args}), gen_kwargs: ({args.gen_kwargs}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, "
f"batch_size: {args.batch_size}{f' ({batch_sizes})' if batch_sizes else ''}"
)
print(make_table(results))
if "groups" in results:
print(make_table(results, "groups"))
if args.wandb_args:
# Tear down wandb run once all the logging is done.
wandb_logger.run.finish()
if __name__ == "__main__":
cli_evaluate()
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable, Iterable, List, Union
from lm_eval.api.instance import Instance
class Filter(ABC):
"""
Filter classes operate on a per-task level.
They take all model outputs (`instance.resps` for all `task.instances`)
across all instances of a task, and perform operations.
In a single run, one can configure any number of separate filters or lists of filters.
"""
def __init__(self, **kwargs) -> None:
"""
Can define custom behavior here, if an individual instantiation of a Filter class should have state.
"""
@abstractmethod
def apply(self, resps: Union[List, Iterable], docs: List[dict]) -> Iterable:
"""
Defines the operation to perform on a list of the `inst.resps` properties of `Instance` objects.
Should return the list of (filtered) response lists *in the same order as they were input*, e.g.
if pass in [<inst.resps for instance 0>, <inst.resps for instance 1>] should return
[<filtered resps for instance 0>, <filtered resps for instance 1>]
"""
return resps
@dataclass
class FilterEnsemble:
"""
FilterEnsemble creates a pipeline applying multiple filters.
Its intended usage is to stack multiple post-processing steps in order.
`task.apply_filters` should use a list of FilterEnsemble classes that it stores, to apply each
pipeline separately.
"""
name: str
filters: List[Callable[[], Filter]]
def apply(self, instances: List[Instance]) -> None:
resps, docs = zip(*((inst.resps, inst.doc) for inst in instances))
resps, docs = list(resps), list(docs)
for f in self.filters:
# apply filters in sequence
resps = f().apply(resps, docs)
# add the end results after filtering to filtered_requests of their respective source instances.
# has key `self.name`: each FilterEnsemble applied in a given run should use a different name.
for inst, resp in zip(instances, resps):
inst.filtered_resps[self.name] = resp
from dataclasses import dataclass, field
from typing import Literal, Optional, Tuple
OutputType = Literal[
"loglikelihood", "loglikelihood_rolling", "generate_until", "multiple_choice"
]
@dataclass
class Instance:
request_type: OutputType
doc: dict
arguments: tuple
idx: int
metadata: Tuple[Optional[str], Optional[int], Optional[int]] = field(
default_factory=lambda: (None, None, None)
)
resps: list = field(default_factory=list)
filtered_resps: dict = field(default_factory=dict)
# initialized after init
task_name: Optional[str] = None
doc_id: Optional[int] = None
repeats: Optional[int] = None
def __post_init__(self) -> None:
# unpack metadata field
self.task_name, self.doc_id, self.repeats = self.metadata
@property
def args(self):
"""
Returns (string,) where `string` is the string to calculate loglikelihood over
"""
return (
self.arguments if isinstance(self.arguments, tuple) else (self.arguments,)
)
import logging
import math
import random
from collections.abc import Iterable
from typing import List
import evaluate as hf_evaluate
import numpy as np
import sacrebleu
import sklearn.metrics
from lm_eval.api.registry import register_aggregation, register_metric
eval_logger = logging.getLogger("lm-eval")
# Register Aggregations First
@register_aggregation("bypass")
def bypass_agg(arr):
return 999
@register_aggregation("mean")
def mean(arr):
return sum(arr) / len(arr)
@register_aggregation("median")
def median(arr):
return arr[len(arr) // 2]
# Certain metrics must be calculated across all documents in a benchmark.
# We use them as aggregation metrics, paired with no-op passthrough metric fns.
@register_aggregation("perplexity")
def perplexity(items):
return math.exp(-mean(items))
@register_aggregation("weighted_perplexity")
def weighted_perplexity(items):
return math.exp(-weighted_mean(items))
@register_aggregation("bits_per_byte")
def bits_per_byte(items):
return -weighted_mean(items) / math.log(2)
@register_aggregation("f1")
def f1_score(items):
unzipped_list = list(zip(*items))
golds = unzipped_list[0]
preds = unzipped_list[1]
fscore = sklearn.metrics.f1_score(golds, preds)
return np.max(fscore)
@register_aggregation("matthews_corrcoef")
def matthews_corrcoef(items):
unzipped_list = list(zip(*items))
golds = unzipped_list[0]
preds = unzipped_list[1]
# print(preds)
return sklearn.metrics.matthews_corrcoef(golds, preds)
@register_aggregation("bleu")
def bleu(items):
"""The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric
for evaluating a generated sentence to a reference sentence. It counts matching
n-grams in the candidate translation to n-grams in the reference text, where
1-gram or unigram would be each token and a bigram comparison would be each
word pair. The comparison is made regardless of word order
Source: https://machinelearningmastery.com/calculate-bleu-score-for-text-python/
Paper: https://www.aclweb.org/anthology/P02-1040/
Higher is better
"""
refs = list(zip(*items))[0]
preds = list(zip(*items))[1]
refs, preds = _sacreformat(refs, preds)
return sacrebleu.corpus_bleu(preds, refs).score
@register_aggregation("chrf")
def chrf(items):
"""chrF++ is a tool for automatic evaluation of machine translation output
based on character n-gram precision and recall enhanced with word n-grams.
Source: https://github.com/m-popovic/chrF
Paper: https://www.aclweb.org/anthology/W15-3049.pdf
Higher is better # TODO I think
"""
refs = list(zip(*items))[0]
preds = list(zip(*items))[1]
refs, preds = _sacreformat(refs, preds)
return sacrebleu.corpus_chrf(preds, refs).score
@register_aggregation("ter")
def ter(items):
"""Translation Error Rate is an error metric for machine translation that
measures the number of edits required to change a system output into one
of the references
Source: http://www.cs.umd.edu/~snover/tercom/
Paper: http://mt-archive.info/AMTA-2006-Snover.pdf
Lower is better
"""
refs = list(zip(*items))[0]
preds = list(zip(*items))[1]
refs, preds = _sacreformat(refs, preds)
return sacrebleu.corpus_ter(preds, refs).score
@register_aggregation("brier_score")
def brier_score(items): # This is a passthrough function
gold, predictions = list(zip(*items))
gold = list(gold)
gold_one_hot = np.eye(np.max(gold) + 1)[gold]
predictions = list(zip(*items))[1]
return np.mean(np.sum((predictions - gold_one_hot) ** 2, axis=1))
@register_metric(
metric="brier_score",
higher_is_better=False,
output_type=["multiple_choice"],
aggregation="brier_score",
)
def brier_score_fn(items): # This is a passthrough function
return items
@register_metric(
metric="acc",
higher_is_better=True,
output_type=["loglikelihood", "multiple_choice"],
aggregation="mean",
)
def acc_fn(items): # This is a passthrough function
return items
@register_metric(
metric="acc_norm",
higher_is_better=True,
output_type=["loglikelihood", "multiple_choice"],
aggregation="mean",
)
def acc_norm_fn(items): # This is a passthrough function
return items
@register_metric(
metric="acc_mutual_info",
higher_is_better=True,
output_type="multiple_choice",
aggregation="mean",
)
def acc_mutual_info_fn(items): # This is a passthrough function
return items
#exact_match = hf_evaluate.load("exact_match")
exact_match = hf_evaluate.load("/workspace/evaluate/metrics/exact_match/exact_match.py", module_type="metric")
@register_metric(
metric="exact_match",
higher_is_better=True,
output_type="generate_until",
aggregation="mean",
)
def exact_match_fn(**kwargs):
return exact_match.compute(**kwargs)
@register_metric(
metric="perplexity",
higher_is_better=False,
output_type="loglikelihood",
aggregation="perplexity",
)
def perplexity_fn(items): # This is a passthrough function
return items
@register_metric(
metric="word_perplexity",
higher_is_better=False,
output_type="loglikelihood_rolling",
aggregation="weighted_perplexity",
)
def word_perplexity_fn(items): # This is a passthrough function
return items
@register_metric(
metric="byte_perplexity",
higher_is_better=False,
output_type="loglikelihood_rolling",
aggregation="weighted_perplexity",
)
def byte_perplexity_fn(items): # This is a passthrough function
return items
@register_metric(
metric="bits_per_byte",
higher_is_better=False,
output_type="loglikelihood_rolling",
aggregation="bits_per_byte",
)
def bits_per_byte_fn(items): # This is a passthrough function
return items
def pop_stddev(arr):
mu = mean(arr)
return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / len(arr))
def sample_stddev(arr):
mu = mean(arr)
return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / (len(arr) - 1))
def mean_stderr(arr):
return sample_stddev(arr) / math.sqrt(len(arr))
@register_metric(
metric="bypass",
higher_is_better=True,
output_type=["loglikelihood", "multiple_choice", "generate_until"],
aggregation="bypass",
)
def bypass(items):
return None
@register_metric(
metric="mcc",
higher_is_better=True,
output_type="multiple_choice",
aggregation="matthews_corrcoef",
)
def mcc_fn(items): # This is a passthrough function
return items
@register_metric(
metric="f1",
higher_is_better=True,
output_type="multiple_choice",
aggregation="f1",
)
def f1_fn(items): # This is a passthrough function
return items
@register_metric(
metric="bleu",
higher_is_better=True,
output_type="generate_until",
aggregation="bleu",
)
def bleu_fn(items): # This is a passthrough function
return items
@register_metric(
metric="chrf",
higher_is_better=True,
output_type="generate_until",
aggregation="chrf",
)
def chrf_fn(items): # This is a passthrough function
return items
@register_metric(
metric="ter",
higher_is_better=True,
output_type="generate_until",
aggregation="ter",
)
def ter_fn(items): # This is a passthrough function
return items
@register_metric(
metric="acc_all",
higher_is_better=True,
output_type="loglikelihood",
aggregation="mean",
)
def acc_all(items):
# Only count as correct if all answers are labeled correctly for each question
question_scoring_dict = {}
preds = list(zip(*items))[0]
docs = list(zip(*items))[1]
for doc, pred in zip(docs, preds):
paragraph_id = doc["idx"]["paragraph"]
question_id = doc["idx"]["question"]
if (paragraph_id, question_id) not in question_scoring_dict:
question_scoring_dict[(paragraph_id, question_id)] = []
gold_label = doc["label"] == 1
question_scoring_dict[(paragraph_id, question_id)].append(gold_label == pred)
acc = np.mean([int(all(x)) for x in question_scoring_dict.values()])
return acc
def acc_all_stderr(items):
# Only count as correct if all answers are labeled correctly for each question
question_scoring_dict = {}
preds = list(zip(*items))[0]
docs = list(zip(*items))[1]
for doc, pred in zip(docs, preds):
question_id = doc["idx"]["question"]
if question_id not in question_scoring_dict:
question_scoring_dict[question_id] = []
gold_label = doc["label"] == 1
question_scoring_dict[question_id].append(gold_label == pred)
acc = mean_stderr([int(all(x)) for x in question_scoring_dict.values()])
return acc
def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
"""Compute max metric between prediction and each ground truth."""
scores_for_ground_truths = []
for ground_truth in ground_truths:
score = metric_fn(prediction, ground_truth)
scores_for_ground_truths.append(score)
return max(scores_for_ground_truths)
def weighted_mean(items):
a, b = zip(*items)
return sum(a) / sum(b)
def is_non_str_iterable(obj):
return isinstance(obj, Iterable) and not isinstance(obj, str)
def _sacreformat(refs, preds):
"""Format refs and preds for sacrebleu corpus calculation. It is very particular"""
# Sacrebleu expects (List[str], List[List[str])
# e.g. sacrebleu.corpus_bleu([pred_t], [[ref1_stream], [ref2_stream], ...])
# Note [ref1_stream] is the first reference for each pred.
# So lists are size N and (M, N) for N preds and M possible refs for each pred
# This is a different order of dimensions that I would expect
# We expect refs to be List[str] or List[List[str]], the outer list corresponding to preds
# Must become List[List[str]] with the inner list corresponding to preds
if not is_non_str_iterable(refs):
refs = list(refs)
if not is_non_str_iterable(refs[0]):
refs = [[ref] for ref in refs]
refs = list(zip(*refs))
# Note the number of refs in each ref list much match the number of preds
# We expect preds to be List[str] or List[List[str]]. Must become List[str]
if not is_non_str_iterable(preds):
preds = list(preds)
if is_non_str_iterable(preds[0]):
assert len(preds[0]) == 1, f"Pred must be a str, was {preds[0]}"
preds = [pred[0] for pred in preds]
return refs, preds
# stderr stuff
class _bootstrap_internal:
def __init__(self, f, n) -> None:
self.f = f
self.n = n
def __call__(self, v):
i, xs = v
rnd = random.Random()
rnd.seed(i)
res = []
for _ in range(self.n):
res.append(self.f(rnd.choices(xs, k=len(xs))))
return res
def bootstrap_stderr(f, xs, iters):
import multiprocessing as mp
pool = mp.Pool(mp.cpu_count())
# this gives a biased estimate of the stderr (i.e w/ the mean, it gives something
# equivalent to stderr calculated without Bessel's correction in the stddev.
# Unfortunately, I haven't been able to figure out what the right correction is
# to make the bootstrap unbiased - i considered multiplying by sqrt(n/(n-1)) but
# that would be ad-hoc and I can't prove that that would actually be an unbiased estimator)
# Thankfully, shouldn't matter because our samples are pretty big usually anyways
res = []
chunk_size = min(1000, iters)
from tqdm import tqdm
print("bootstrapping for stddev:", f.__name__)
for bootstrap in tqdm(
pool.imap(
_bootstrap_internal(f, chunk_size),
[(i, xs) for i in range(iters // chunk_size)],
),
total=iters // chunk_size,
):
# sample w replacement
res.extend(bootstrap)
pool.close()
return sample_stddev(res)
def stderr_for_metric(metric, bootstrap_iters):
bootstrappable = [
median,
matthews_corrcoef,
f1_score,
perplexity,
bleu,
chrf,
ter,
]
if metric in bootstrappable:
return lambda x: bootstrap_stderr(metric, x, iters=bootstrap_iters)
stderr = {mean: mean_stderr, acc_all: acc_all_stderr}
return stderr.get(metric, None)
def pooled_sample_stderr(stderrs: List[float], sizes: List[int]):
# Used to aggregate bootstrapped stderrs across subtasks in a group,
# when we are weighting by the size of each subtask.
#
assert len(stderrs) == len(sizes)
# formula source: https://en.wikipedia.org/wiki/Pooled_variance
# and: https://stats.stackexchange.com/a/4841331
# this empirically seems to match running `stderr_for_metric` on all instances
# from the subtasks concatenated with each other.
pooled_sample_var = (
sum([(size - 1) * stderr**2 * size for size, stderr in zip(sizes, stderrs)])
) / (sum(sizes) - len(sizes))
return np.sqrt(pooled_sample_var / sum(sizes))
def combined_sample_stderr(stderrs: List[float], sizes: List[int], metrics=None):
assert (
metrics is not None
), "Need to pass a list of each subtask's metric for this stderr aggregation"
assert len(stderrs) == len(sizes) and len(sizes) == len(metrics)
# See https://github.com/EleutherAI/lm-evaluation-harness/pull/1390 for more documentation.
# This formula depends on sample means.
# removed because it seems to give erroneously huge stderrs for groupings of tasks
# and does not seem to match up with bootstrap-calculated stderrs for groups.
### don't use this unless a statistician has told you it's the right thing to do ###
# accumulators: we'll aggregate pairwise N - 1 times
variance = stderrs[0] ** 2
curr_size = sizes[0]
curr_score = metrics[0]
for stderr, size, score in zip(stderrs[1:], sizes[1:], metrics[1:]):
curr_score = ((curr_score * curr_size) + (score * size)) / (
curr_size + size
) # NOTE: this assumes our aggregation fn is "mean"
variance = ((curr_size - 1) * variance + (size - 1) * (stderr**2)) / (
curr_size + size - 1
) + curr_size * size / ((curr_size + size) * (curr_size + size - 1)) * (
curr_score - score
) ** 2
return np.sqrt(variance)
def aggregate_subtask_metrics(metrics, sizes, weight_by_size=True):
# A helper function that is used to aggregate
# subtask scores cross-task.
# TODO: does not hold for non-mean aggregations
if not weight_by_size:
sizes = [1] * len(sizes)
assert len(metrics) == len(sizes)
return sum([metric * size for metric, size in zip(metrics, sizes)]) / sum(sizes)
import abc
import hashlib
import json
import logging
import os
from typing import List, Optional, Tuple, Type, TypeVar
from sqlitedict import SqliteDict
from tqdm import tqdm
from lm_eval import utils
eval_logger = logging.getLogger("lm-eval")
T = TypeVar("T", bound="LM")
class LM(abc.ABC):
def __init__(self) -> None:
"""Defines the interface that should be implemented by all LM subclasses.
LMs are assumed to take text (strings) as input and yield strings as output
(inputs/outputs should be tokenization-agnostic.)
"""
# set rank and world size to a single process, by default.
self._rank = 0
self._world_size = 1
self.cache_hook = CacheHook(None)
@abc.abstractmethod
def loglikelihood(self, requests) -> List[Tuple[float, bool]]:
"""Compute log-likelihood of generating a continuation from a context.
Downstream tasks should attempt to use loglikelihood instead of other
LM calls whenever possible.
:param requests: list[Instance]
A list of Instance objects, with property `args` which returns a tuple (context, continuation).
`context: str`
Context string. Implementations of LM must be able to handle an
empty context string.
`continuation: str`
The continuation over which log likelihood will be calculated. If
there is a word boundary, the space should be in the continuation.
For example, context="hello" continuation=" world" is correct.
:return: list[tuple[float, bool]]
A list of pairs (logprob, isgreedy)
`logprob: float`
The log probability of `continuation`.
`isgreedy`:
Whether `continuation` would be generated by greedy sampling from `context`.
"""
pass
@abc.abstractmethod
def loglikelihood_rolling(self, requests) -> List[Tuple[float]]:
"""Compute full log-likelihood of a string, with no truncation, for perplexity computation
- We will use the full max context length of the model.
- For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
the max context length.
- IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations
which may simply concatenate multiple documents together.
- IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into
multiple chunks, the last input will still a full-sized context.
Example:
Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ]
Prefix: EOT
Max context length: 4
Resulting input/prediction pairs:
INPUT: EOT 0 1 2
PRED: 0 1 2 3
INPUT: 3 4 5 6
PRED: 4 5 6 7
INPUT: 5 6 7 8
PRED: 8 9
Observe that:
1. Each token is predicted exactly once
2. For the last pair, we provide the full context, but only score the last two tokens
:param requests: list[Instance]
A list of Instance objects with property `args` which returns a tuple (context,).
string: str
String for which we are computing overall loglikelihood
:return: list[tuple[float]]
A list of tuples (logprob,)
logprob: float
The log probability of `context` conditioned on the EOT token.
"""
pass
# TODO: Add an optional max length
@abc.abstractmethod
def generate_until(self, requests) -> List[str]:
"""Generate greedily until a stopping sequence
:param requests: list[Instance]
A list of Instance objects with property `args` which returns a tuple (context, until).
context: str
Context string
until: [str]
The string sequences to generate until. These string sequences
may each span across multiple tokens, or may be part of one token.
:return: list[str]
A list of strings continuation
continuation: str
The generated continuation.
"""
pass
@classmethod
def create_from_arg_string(
cls: Type[T], arg_string: str, additional_config: Optional[dict] = None
) -> T:
"""
Creates an instance of the LM class using the given argument string and additional config.
Parameters:
- arg_string: A string containing arguments in the format key1=value1,key2=value2.
- additional_config: Optional dictionary containing additional configuration parameters.
Returns:
- Instance of the LM class.
"""
additional_config = {} if additional_config is None else additional_config
args = utils.simple_parse_args_string(arg_string)
args2 = {k: v for k, v in additional_config.items() if v is not None}
return cls(**args, **args2)
@classmethod
def create_from_arg_obj(
cls: Type[T], arg_dict: dict, additional_config: Optional[dict] = None
) -> T:
"""
Creates an instance of the LM class using the given arg_obj
Parameters:
- arg_obj: A dict containing arguments in the format key1=value1,key2=value2.
- additional_config: Optional dictionary containing additional configuration parameters.
Returns:
- Instance of the LM class.
"""
additional_config = {} if additional_config is None else additional_config
additional_config = {
k: v for k, v in additional_config.items() if v is not None
}
return cls(**arg_dict, **additional_config)
@property
def rank(self):
# used in the case of parallelism. Hardcoded to
# ensure no errors arise using API models which do
# not support multi-device parallelism nor expect it.
return self._rank
@property
def world_size(self):
# used in the case of parallelism. Hardcoded to
# ensure no errors arise using API models which do
# not support multi-device parallelism nor expect it.
return self._world_size
def set_cache_hook(self, cache_hook) -> None:
self.cache_hook = cache_hook
### SQLite-based caching of LM responses
def hash_args(attr, args):
dat = json.dumps([attr] + list(args))
return hashlib.sha256(dat.encode("utf-8")).hexdigest()
class CacheHook:
def __init__(self, cachinglm) -> None:
if cachinglm is None:
self.dbdict = None
return
self.dbdict = cachinglm.dbdict
def add_partial(self, attr, req, res) -> None:
if self.dbdict is None:
return
hsh = hash_args(attr, req)
self.dbdict[hsh] = res
class CachingLM:
def __init__(self, lm, cache_db) -> None:
"""LM wrapper that returns cached results if they exist, and uses the underlying LM if not.
:param lm: LM
Underlying LM
:param cache_db: str
Path to cache db
"""
self.lm = lm
self.cache_db = cache_db
if os.path.dirname(cache_db):
os.makedirs(os.path.dirname(cache_db), exist_ok=True)
self.dbdict = SqliteDict(cache_db, autocommit=True)
# add hook to lm
lm.set_cache_hook(self.get_cache_hook())
def __getattr__(self, attr):
lm_attr = getattr(self.lm, attr)
if not callable(lm_attr):
return lm_attr
def fn(requests):
res = []
remaining_reqs = []
warned = False
# figure out which ones are cached and which ones are new
eval_logger.info(
f"Loading '{attr}' responses from cache '{self.cache_db}' where possible..."
)
for req in tqdm(requests, desc="Checking cached requests"):
hsh = hash_args(attr, req.args)
if attr == "generate_until" and req.args[1].get("do_sample", False):
# when we are doing non-greedy generation, don't use the cache
# (else every "randomly sampled" generation would be identical for repeats > 1).
if not warned:
eval_logger.warning(
f"Arguments to lm.generate_until() '{req.args[1]}' include non-deterministic sampling. Caching will not be performed for such requests."
)
warned = True
res.append(None)
remaining_reqs.append(req)
elif hsh in self.dbdict:
ob = self.dbdict[hsh]
assert ob is not None
res.append(ob)
else:
res.append(None)
remaining_reqs.append(req)
eval_logger.info(
f"Cached requests: {len(requests) - len(remaining_reqs)}, Requests remaining: {len(remaining_reqs)}"
)
# actually run the LM on the requests that do not have cached results
rem_res = getattr(self.lm, attr)(remaining_reqs)
# stick the new ones back into the list and also cache any of the new ones
resptr = 0
for req, r in zip(remaining_reqs, rem_res):
while res[resptr] is not None:
resptr += 1
res[resptr] = r
# caching
hsh = hash_args(attr, req.args)
self.dbdict[hsh] = r
self.dbdict.commit()
return res
return fn
def get_cache_hook(self):
return CacheHook(self)
class TemplateLM(LM):
"""
A class acting as intermediary between the LM base class
and boilerplate often included in other LM subclasses.
"""
@property
@abc.abstractmethod
def eot_token_id(self):
pass
@abc.abstractmethod
def tok_encode(self, string: str, **kwargs):
pass
@abc.abstractmethod
def _loglikelihood_tokens(self, requests, **kwargs):
pass
def _encode_pair(self, context, continuation):
n_spaces = len(context) - len(context.rstrip())
if n_spaces > 0:
continuation = context[-n_spaces:] + continuation
context = context[:-n_spaces]
whole_enc = self.tok_encode(context + continuation)
context_enc = self.tok_encode(context)
context_enc_len = len(context_enc)
continuation_enc = whole_enc[context_enc_len:]
return context_enc, continuation_enc
def loglikelihood(self, requests) -> List[Tuple[float, bool]]:
new_reqs = []
for context, continuation in [req.args for req in requests]:
if context == "":
# end of text as context
context_enc, continuation_enc = (
[self.eot_token_id],
self.tok_encode(continuation),
)
else:
context_enc, continuation_enc = self._encode_pair(context, continuation)
new_reqs.append(((context, continuation), context_enc, continuation_enc))
return self._loglikelihood_tokens(new_reqs)
@abc.abstractmethod
def loglikelihood_rolling(self, requests) -> List[Tuple[float, bool]]:
pass
@abc.abstractmethod
def generate_until(self, requests) -> List[str]:
pass
import logging
from typing import Callable, Dict
import evaluate as hf_evaluate
from lm_eval.api.model import LM
eval_logger = logging.getLogger("lm-eval")
MODEL_REGISTRY = {}
def register_model(*names):
# either pass a list or a single alias.
# function receives them as a tuple of strings
def decorate(cls):
for name in names:
assert issubclass(
cls, LM
), f"Model '{name}' ({cls.__name__}) must extend LM class"
assert (
name not in MODEL_REGISTRY
), f"Model named '{name}' conflicts with existing model! Please register with a non-conflicting alias instead."
MODEL_REGISTRY[name] = cls
return cls
return decorate
def get_model(model_name):
try:
return MODEL_REGISTRY[model_name]
except KeyError:
raise ValueError(
f"Attempted to load model '{model_name}', but no model for this name found! Supported model names: {', '.join(MODEL_REGISTRY.keys())}"
)
TASK_REGISTRY = {}
GROUP_REGISTRY = {}
ALL_TASKS = set()
func2task_index = {}
def register_task(name):
def decorate(fn):
assert (
name not in TASK_REGISTRY
), f"task named '{name}' conflicts with existing registered task!"
TASK_REGISTRY[name] = fn
ALL_TASKS.add(name)
func2task_index[fn.__name__] = name
return fn
return decorate
def register_group(name):
def decorate(fn):
func_name = func2task_index[fn.__name__]
if name in GROUP_REGISTRY:
GROUP_REGISTRY[name].append(func_name)
else:
GROUP_REGISTRY[name] = [func_name]
ALL_TASKS.add(name)
return fn
return decorate
OUTPUT_TYPE_REGISTRY = {}
METRIC_REGISTRY = {}
METRIC_AGGREGATION_REGISTRY = {}
AGGREGATION_REGISTRY: Dict[str, Callable[[], Dict[str, Callable]]] = {}
HIGHER_IS_BETTER_REGISTRY = {}
DEFAULT_METRIC_REGISTRY = {
"loglikelihood": [
"perplexity",
"acc",
],
"loglikelihood_rolling": ["word_perplexity", "byte_perplexity", "bits_per_byte"],
"multiple_choice": ["acc", "acc_norm"],
"generate_until": ["exact_match"],
}
def register_metric(**args):
# TODO: do we want to enforce a certain interface to registered metrics?
def decorate(fn):
assert "metric" in args
name = args["metric"]
for key, registry in [
("metric", METRIC_REGISTRY),
("higher_is_better", HIGHER_IS_BETTER_REGISTRY),
("aggregation", METRIC_AGGREGATION_REGISTRY),
]:
if key in args:
value = args[key]
assert (
value not in registry
), f"{key} named '{value}' conflicts with existing registered {key}!"
if key == "metric":
registry[name] = fn
elif key == "aggregation":
registry[name] = AGGREGATION_REGISTRY[value]
else:
registry[name] = value
return fn
return decorate
def get_metric(name: str, hf_evaluate_metric=False) -> Callable:
if not hf_evaluate_metric:
if name in METRIC_REGISTRY:
return METRIC_REGISTRY[name]
else:
eval_logger.warning(
f"Could not find registered metric '{name}' in lm-eval, searching in HF Evaluate library..."
)
try:
metric_object = hf_evaluate.load(name)
return metric_object.compute
except Exception:
eval_logger.error(
f"{name} not found in the evaluate library! Please check https://huggingface.co/evaluate-metric",
)
def register_aggregation(name: str):
def decorate(fn):
assert (
name not in AGGREGATION_REGISTRY
), f"aggregation named '{name}' conflicts with existing registered aggregation!"
AGGREGATION_REGISTRY[name] = fn
return fn
return decorate
def get_aggregation(name: str) -> Callable[[], Dict[str, Callable]]:
try:
return AGGREGATION_REGISTRY[name]
except KeyError:
eval_logger.warning(f"{name} not a registered aggregation metric!")
def get_metric_aggregation(name: str) -> Callable[[], Dict[str, Callable]]:
try:
return METRIC_AGGREGATION_REGISTRY[name]
except KeyError:
eval_logger.warning(f"{name} metric is not assigned a default aggregation!")
def is_higher_better(metric_name) -> bool:
try:
return HIGHER_IS_BETTER_REGISTRY[metric_name]
except KeyError:
eval_logger.warning(
f"higher_is_better not specified for metric '{metric_name}'!"
)
class ContextSampler:
def __init__(self, docs, task, fewshot_indices=None, rnd=None) -> None:
self.rnd = rnd
assert self.rnd, "must pass rnd to FewShotSampler!"
self.task = task
self.config = task._config
self.target_delimiter = self.config.target_delimiter
self.fewshot_delimiter = self.config.fewshot_delimiter
self.doc_to_text = self.task.doc_to_text
self.doc_to_target = self.task.doc_to_target
self.doc_to_choice = self.task.doc_to_choice
self.docs = docs # HF dataset split, provided by task._fewshot_docs()
if fewshot_indices: # subset few-shot docs from
self.docs = self.docs.select(fewshot_indices)
def get_context(self, doc, num_fewshot):
# draw an extra fewshot sample if using same split as evaluating on
n_samples = (
num_fewshot + 1
if self.config.fewshot_split == self.config.test_split
else num_fewshot
)
# draw `n_samples` docs from fewshot_docs
fewshotex = self.sample(n_samples)
# get rid of the doc that's the one we're evaluating, if it's in the fewshot
# TODO: should we just stop people from using fewshot from same split as evaluating?
selected_docs = [x for x in fewshotex if x != doc][:num_fewshot]
labeled_examples = (
self.fewshot_delimiter.join(
[
# TODO: is separating doc_to_text and doc_to_target by one space always desired?
(
self.doc_to_text(doc)
if (
self.config.doc_to_choice is None
or isinstance(self.doc_to_text(doc), str)
)
else self.doc_to_choice(doc)[self.doc_to_text(doc)]
)
+ self.target_delimiter
+ (
str(self.doc_to_target(doc)[0])
if isinstance(self.doc_to_target(doc), list)
else self.doc_to_target(doc)
if (
self.config.doc_to_choice is None
or isinstance(self.doc_to_target(doc), str)
)
else str(self.doc_to_choice(doc)[self.doc_to_target(doc)])
)
for doc in selected_docs
]
)
+ self.fewshot_delimiter
)
return labeled_examples
def sample(self, n):
"""
Draw `n` samples from our fewshot docs. This method should be overridden by subclasses.
"""
return self.rnd.sample(self.docs, n)
class FirstNSampler(ContextSampler):
def sample(self, n) -> None:
"""
Draw the first `n` samples in order from the specified split.
Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU.
"""
assert (
n <= len(self.docs)
), f"Error: number of fewshot samples requested exceeds the {len(self.docs)} that are available."
return self.docs[:n]
class BalancedSampler(ContextSampler):
def sample(self, n) -> None:
"""
TODO: this should return approximately class-balanced samples from our fewshot examples.
TODO: what order should they be in? maybe random?
"""
pass
class ManualSampler(ContextSampler):
def sample(self, n) -> None:
""" """
pass
SAMPLER_REGISTRY = {
"default": ContextSampler,
"first_n": FirstNSampler,
}
def get_sampler(name):
try:
return SAMPLER_REGISTRY[name]
except KeyError:
raise ValueError(
f"Attempted to use contextsampler '{name}', but no sampling strategy for this name found! Supported model names: {', '.join(SAMPLER_REGISTRY.keys())}"
)
import abc
import ast
import logging
import random
import re
from collections.abc import Callable
from copy import deepcopy
from dataclasses import asdict, dataclass
from inspect import getsource
from typing import (
Any,
Dict,
Iterable,
Iterator,
List,
Literal,
Mapping,
Optional,
Tuple,
Union,
)
import datasets
import numpy as np
from tqdm import tqdm
from lm_eval import utils
from lm_eval.api import samplers
from lm_eval.api.instance import Instance, OutputType
from lm_eval.api.metrics import bits_per_byte, mean, weighted_perplexity
from lm_eval.api.registry import (
AGGREGATION_REGISTRY,
DEFAULT_METRIC_REGISTRY,
get_aggregation,
get_metric,
get_metric_aggregation,
is_higher_better,
)
from lm_eval.caching.cache import load_from_cache, save_to_cache
from lm_eval.filters import build_filter_ensemble
from lm_eval.prompts import get_prompt
ALL_OUTPUT_TYPES = [
"loglikelihood",
"multiple_choice",
"loglikelihood_rolling",
"generate_until",
]
eval_logger = logging.getLogger("lm-eval")
@dataclass
class TaskConfig(dict):
# task naming/registry
task: Optional[str] = None
task_alias: Optional[str] = None
group: Optional[Union[str, list]] = None
group_alias: Optional[Union[str, list]] = None
# HF dataset options.
# which dataset to use,
# and what splits for what purpose
dataset_path: Optional[str] = None
dataset_name: Optional[str] = None
dataset_kwargs: Optional[dict] = None
training_split: Optional[str] = None
validation_split: Optional[str] = None
test_split: Optional[str] = None
fewshot_split: Optional[
str
] = None # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaling (?)
# formatting / prompting options.
# see docs/advanced_task_guide.md for more info
process_docs: Optional[Callable] = None
doc_to_text: Optional[Union[Callable, str]] = None
doc_to_target: Optional[Union[Callable, str]] = None
doc_to_choice: Optional[Union[Callable, str, dict, list]] = None
process_results: Optional[Union[Callable, str]] = None
use_prompt: Optional[str] = None
description: str = ""
target_delimiter: str = " "
fewshot_delimiter: str = "\n\n"
fewshot_config: Optional[dict] = None
# runtime configuration options
num_fewshot: Optional[int] = None
# scoring options
metric_list: Optional[list] = None
output_type: OutputType = "generate_until"
generation_kwargs: Optional[dict] = None
repeats: int = 1
filter_list: Optional[Union[str, list]] = None
should_decontaminate: bool = False
doc_to_decontamination_query: Optional[str] = None
metadata: Optional[
dict
] = None # by default, not used in the code. allows for users to pass arbitrary info to tasks
def __post_init__(self) -> None:
if self.generation_kwargs is not None:
if self.output_type != "generate_until":
raise ValueError(
f"[{self.task}] passed `generation_kwargs`, but not using `output_type: generate_until`!"
)
if "temperature" in self.generation_kwargs:
self.generation_kwargs["temperature"] = float(
self.generation_kwargs["temperature"]
)
if "until" not in self.generation_kwargs:
self.generation_kwargs["until"] = [self.fewshot_delimiter]
else:
if self.output_type == "generate_until":
# ensure that we greedily generate in absence of explicit arguments otherwise
self.generation_kwargs = {
"until": (
None
if self.fewshot_delimiter is None
else [self.fewshot_delimiter]
),
"do_sample": False,
}
def __getitem__(self, item):
return getattr(self, item)
def __setitem__(self, item, value):
return setattr(self, item, value)
def to_dict(self, keep_callable: bool = False) -> dict:
"""dumps the current config as a dictionary object, as a printable format.
null fields will not be printed.
Used for dumping results alongside full task configuration
:return: dict
A printable dictionary version of the TaskConfig object.
# TODO: should any default value in the TaskConfig not be printed?
"""
cfg_dict = asdict(self)
# remove values that are `None`
for k, v in list(cfg_dict.items()):
if v is None:
cfg_dict.pop(k)
elif k == "metric_list":
for metric_dict in v:
for metric_key, metric_value in metric_dict.items():
if callable(metric_value):
metric_dict[metric_key] = self.serialize_function(
metric_value, keep_callable=keep_callable
)
cfg_dict[k] = v
elif callable(v):
cfg_dict[k] = self.serialize_function(v, keep_callable=keep_callable)
return cfg_dict
def serialize_function(
self, value: Union[Callable, str], keep_callable=False
) -> Union[Callable, str]:
"""Serializes a given function or string.
If 'keep_callable' is True, the original callable is returned.
Otherwise, attempts to return the source code of the callable using 'getsource'.
"""
if keep_callable:
return value
else:
try:
return getsource(value)
except (TypeError, OSError):
return str(value)
class Task(abc.ABC):
"""A task represents an entire benchmark including its dataset, problems,
answers, and evaluation methods. See BoolQ for a simple example implementation
A `doc` can be any python object which represents one instance of evaluation.
This is usually a dictionary e.g.
{"question": ..., "answer": ...} or
{"question": ..., question, answer)
"""
VERSION: Optional[Union[int, str]] = None
# The name of the `Task` benchmark as denoted in the HuggingFace datasets Hub
# or a path to a custom `datasets` loading script.
DATASET_PATH: Optional[str] = None
# The name of a subset within `DATASET_PATH`.
DATASET_NAME: Optional[str] = None
OUTPUT_TYPE: Optional[OutputType] = None
def __init__(
self,
data_dir: Optional[str] = None,
cache_dir: Optional[str] = None,
download_mode: Optional[datasets.DownloadMode] = None,
config: Optional[Mapping] = None, # Union[dict, TaskConfig]
) -> None:
"""
:param data_dir: str
Stores the path to a local folder containing the `Task`'s data files.
Use this to specify the path to manually downloaded data (usually when
the dataset is not publicly accessible).
:param cache_dir: str
The directory to read/write the `Task` dataset. This follows the
HuggingFace `datasets` API with the default cache directory located at:
`~/.cache/huggingface/datasets`
NOTE: You can change the cache location globally for a given process
to another directory:
`export HF_DATASETS_CACHE="/path/to/another/directory"`
:param download_mode: datasets.DownloadMode
How to treat pre-existing `Task` downloads and data.
- `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS`
Reuse download and reuse dataset.
- `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS`
Reuse download with fresh dataset.
- `datasets.DownloadMode.FORCE_REDOWNLOAD`
Fresh download and fresh dataset.
"""
self.download(data_dir, cache_dir, download_mode)
self._training_docs: Optional[list] = None
self._fewshot_docs: Optional[list] = None
self._instances: Optional[List[Instance]] = None
self._config: TaskConfig = TaskConfig({**config}) if config else TaskConfig()
self._filters = [build_filter_ensemble("none", [["take_first", None]])]
def download(
self,
data_dir: Optional[str] = None,
cache_dir: Optional[str] = None,
download_mode=None,
) -> None:
"""Downloads and returns the task dataset.
Override this method to download the dataset from a custom API.
:param data_dir: str
Stores the path to a local folder containing the `Task`'s data files.
Use this to specify the path to manually downloaded data (usually when
the dataset is not publicly accessible).
:param cache_dir: str
The directory to read/write the `Task` dataset. This follows the
HuggingFace `datasets` API with the default cache directory located at:
`~/.cache/huggingface/datasets`
NOTE: You can change the cache location globally for a given process
by setting the shell environment variable, `HF_DATASETS_CACHE`,
to another directory:
`export HF_DATASETS_CACHE="/path/to/another/directory"`
:param download_mode: datasets.DownloadMode
How to treat pre-existing `Task` downloads and data.
- `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS`
Reuse download and reuse dataset.
- `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS`
Reuse download with fresh dataset.
- `datasets.DownloadMode.FORCE_REDOWNLOAD`
Fresh download and fresh dataset.
"""
self.dataset = datasets.load_dataset(
path=self.DATASET_PATH,
name=self.DATASET_NAME,
data_dir=data_dir,
cache_dir=cache_dir,
download_mode=download_mode,
)
@property
def config(self) -> TaskConfig:
"""Returns the TaskConfig associated with this class."""
return self._config
@abc.abstractmethod
def has_training_docs(self):
"""Whether the task has a training set"""
pass
@abc.abstractmethod
def has_validation_docs(self):
"""Whether the task has a validation set"""
pass
@abc.abstractmethod
def has_test_docs(self):
"""Whether the task has a test set"""
pass
def training_docs(self) -> Iterable:
"""
:return: Iterable[obj]
A iterable of any object, that doc_to_text can handle
"""
return []
def validation_docs(self) -> Iterable:
"""
:return: Iterable[obj]
A iterable of any object, that doc_to_text can handle
"""
return []
def test_docs(self) -> Iterable:
"""
:return: Iterable[obj]
A iterable of any object, that doc_to_text can handle
"""
return []
def fewshot_docs(self) -> Iterable:
"""
:return: Iterable[obj]
A iterable of any object, that doc_to_text can handle
"""
if self.has_training_docs():
return self.training_docs()
elif self.has_validation_docs():
return self.validation_docs()
else:
eval_logger.warning(
f"[Task: {self.config.task}] has_training_docs and has_validation_docs are False"
", using test_docs as fewshot_docs but this is not recommended."
)
return self.test_docs()
def _process_doc(self, doc: dict) -> dict:
"""
Override this to process (detokenize, strip, replace, etc.) individual
documents. This can be used in a map over documents of a data split.
E.g. `map(self._process_doc, self.dataset["validation"])`
:return: dict
The processed version of the specified `doc`.
"""
return doc
@property
def instances(self) -> List[Instance]:
"""After calling `task.build_all_requests()`, tasks
maintain a list of the dataset instances which will be evaluated.
"""
return self._instances
def fewshot_examples(self, k, rnd):
if self._training_docs is None:
self._training_docs = list(self.training_docs())
return rnd.sample(self._training_docs, k)
def doc_to_decontamination_query(self, doc):
raise NotImplementedError(
"Override doc_to_decontamination_query with document specific decontamination query."
)
@abc.abstractmethod
def doc_to_text(self, doc):
pass
@abc.abstractmethod
def doc_to_target(self, doc):
pass
def build_all_requests(
self,
*,
limit=None,
rank=None,
world_size=None,
cache_requests=False,
rewrite_requests_cache=False,
) -> None:
"""Build a set of Instances for a task, and store them in task.instances"""
# used with caching
og_limit = limit
cache_key = f"requests-{self._config.task}"
cached_instances = load_from_cache(file_name=cache_key)
if cache_requests and cached_instances and not rewrite_requests_cache:
cached_instances = cached_instances[:limit]
flattened_instances = [
instance
for instance_group in cached_instances
for instance in instance_group
]
self._instances = flattened_instances
return
eval_logger.info(f"Building contexts for {self.config.task} on rank {rank}...")
instances = []
# process all documents when caching is specified for simplicity
if (
cache_requests
and (not cached_instances or rewrite_requests_cache)
and limit is not None
):
limit = None
doc_id_docs = list(
self.doc_iterator(rank=rank, limit=limit, world_size=world_size)
)
num_docs = len(doc_id_docs)
for doc_id, doc in tqdm(
doc_id_docs,
total=num_docs,
):
# sample fewshot context #TODO: need to offset doc_id by rank now!
fewshot_ctx = self.fewshot_context(
doc,
0 if self.config.num_fewshot is None else self.config.num_fewshot,
)
# TODO: we should override self.config.repeats if doing greedy gen so users don't waste time+compute
inst = self.construct_requests(
doc=doc,
ctx=fewshot_ctx,
metadata=(self.config["task"], doc_id, self.config.repeats),
)
if not isinstance(inst, list):
inst = [inst]
instances.append(inst)
# now flatten, this is to allow slicing to work with pickles
sliced_instances = instances[:og_limit]
flattened_instances = [
instance
for instance_group in sliced_instances
for instance in instance_group
]
self._instances = flattened_instances
if len(self._instances) == 0:
raise ValueError("task.build_requests() did not find any docs!")
if cache_requests and (not cached_instances or rewrite_requests_cache):
save_to_cache(file_name=cache_key, obj=instances)
@abc.abstractmethod
def construct_requests(self, doc, ctx, **kwargs):
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
:param doc_idx: int
The index of a document within `self.test_docs()` or `self.validation_docs()`,
whichever is the main split used.
:param repeats: int
TODO: update this docstring
The number of times each instance in a dataset is inferred on. Defaults to 1,
can be increased for techniques like majority voting.
"""
pass
@abc.abstractmethod
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
the metric for that one document
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param results:
The results of the requests created in construct_requests.
"""
pass
@abc.abstractmethod
def aggregation(self):
"""
:returns: {str: [metric_score] -> float}
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metric scores
"""
pass
@abc.abstractmethod
def higher_is_better(self):
"""
:returns: {str: bool}
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
"""
pass
def get_config(self, key: str) -> Any:
return getattr(self._config, key, None)
@classmethod
def count_bytes(cls, doc):
"""Used for byte-level perplexity metrics in rolling loglikelihood"""
return len(doc.encode("utf-8"))
@classmethod
def count_words(cls, doc):
"""Downstream loglikelihood_rolling perplexity tasks with custom word boundaries should override this!"""
return len(re.split(r"\s+", doc))
@utils.positional_deprecated
def fewshot_context(
self,
doc,
num_fewshot,
rnd=random.Random(1234),
description=None,
):
"""Returns a fewshot context string that is made up of a prepended description
(if provided), the `num_fewshot` number of examples, and an appended prompt example.
:param doc: str
The document as returned from training_docs, validation_docs, or test_docs.
:param num_fewshot: int
The number of fewshot examples to provide in the returned context string.
:param rnd: random.Random
The pseudo-random number generator used to randomly sample examples.
WARNING: This is currently a required arg although it's optionalized with a default `None`.
:param description: str
The task's description that will be prepended to the fewshot examples.
:returns: str
The fewshot context.
"""
if rnd is None:
raise ValueError(
"A `random.Random` generator argument must be provided to `rnd`"
)
description = description if description else ""
if num_fewshot == 0:
labeled_examples = ""
else:
# for sets with no training docs, draw from other set *but ensure no overlap with current doc*
if self.has_training_docs():
fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
else:
if self._fewshot_docs is None:
self._fewshot_docs = list(
self.validation_docs()
if self.has_validation_docs()
else self.test_docs()
)
fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)
# get rid of the doc that's the one we're evaluating, if it's in the fewshot
fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]
labeled_examples = (
"\n\n".join(
[
self.doc_to_text(doc) + self.doc_to_target(doc)
for doc in fewshotex
]
)
+ "\n\n"
)
example = self.doc_to_text(doc)
return description + labeled_examples + example
def apply_filters(self) -> Optional[List[Instance]]:
"""Iterates over FilterEnsembles and applies them to instances"""
if hasattr(self, "_filters"):
for f in self._filters:
f.apply(self._instances)
else:
eval_logger.warning("No filter defined, passing through instances")
return self._instances
def dump_config(self) -> dict:
"""Returns the config as a dictionary."""
# TODO: this should only return the overrides applied to a non-YAML task's configuration.
# (num_fewshot)
return self.config.to_dict()
def set_config(self, key: str, value: Any, update: bool = False) -> None:
"""Set or update the configuration for a given key."""
if key is None:
raise ValueError("Key must be provided.")
if update:
current_value = getattr(self._config, key, {})
if not isinstance(current_value, dict):
raise TypeError(
f"Expected a dict for key '{key}', got {type(current_value).__name__} instead."
)
current_value.update(value)
else:
setattr(self._config, key, value)
def override_metric(self, metric_name: str) -> None:
"""
Override the default metrics used for evaluation with custom metrics.
Parameters:
- metric_name (str): The name of the custom metric to override. Should be registered in api.metrics.
"""
(
self._metric_fn_list,
self._aggregation_list,
self._metric_fn_kwargs,
self._higher_is_better,
) = ({}, {}, {}, {})
self._metric_fn_list[metric_name] = get_metric(metric_name)
self._aggregation_list[metric_name] = get_metric_aggregation(metric_name)
self._higher_is_better[metric_name] = is_higher_better(metric_name)
self._metric_fn_kwargs[metric_name] = {}
if not isinstance(self, ConfigurableTask):
self.process_results = lambda x, y: {metric_name: get_metric(metric_name)}
self.aggregation = lambda: {
metric_name: get_metric_aggregation(metric_name)
}
setattr(self._config, "metric_list", [{"metric": metric_name}])
setattr(self._config, "process_results", None)
@property
def eval_docs(self) -> Union[datasets.Dataset, List[dict]]:
if self.has_test_docs():
return self.test_docs()
elif self.has_validation_docs():
return self.validation_docs()
else:
raise ValueError(
f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!"
)
def doc_iterator(
self, *, rank: int = 0, limit: Union[int, None] = None, world_size: int = 1
) -> Iterator[Tuple[int, Any]]:
limit = int(limit) if limit else None
doc_iterator = utils.create_iterator(
enumerate(self.eval_docs),
rank=int(rank),
limit=limit,
world_size=int(world_size),
)
return doc_iterator
class ConfigurableTask(Task):
VERSION = "Yaml"
OUTPUT_TYPE = None
CONFIG = None
def __init__(
self,
data_dir=None,
cache_dir=None,
download_mode=None,
config: Optional[dict] = None,
) -> None: # TODO no super() call here
# Get pre-configured attributes
self._config = self.CONFIG
# Use new configurations if there was no preconfiguration
if self.config is None:
self._config = TaskConfig(**config)
# Overwrite configs
else:
if config is not None:
self._config.__dict__.update(config)
if self.config is None:
raise ValueError(
"Must pass a config to ConfigurableTask, either in cls.CONFIG or `config` kwarg"
)
if isinstance(self.config.metadata, dict):
if "version" in self.config.metadata:
self.VERSION = self.config.metadata["version"]
if self.config.output_type is not None:
if self.config.output_type not in ALL_OUTPUT_TYPES:
raise ValueError(
f"Got invalid output_type '{self.config.output_type}', must be in '{','.join(ALL_OUTPUT_TYPES)}'"
)
self.OUTPUT_TYPE = self.config.output_type
if self.config.dataset_path is not None:
self.DATASET_PATH = self.config.dataset_path
if self.config.dataset_name is not None:
self.DATASET_NAME = self.config.dataset_name
self._metric_fn_list = {}
self._metric_fn_kwargs = {}
self._aggregation_list = {}
self._higher_is_better = {}
if self.config.metric_list is None:
# TODO: handle this in TaskConfig.__post_init__ ?
_metric_list = DEFAULT_METRIC_REGISTRY[self.config.output_type]
for metric_name in _metric_list:
self._metric_fn_list[metric_name] = get_metric(metric_name)
self._metric_fn_kwargs[metric_name] = {}
self._aggregation_list[metric_name] = get_metric_aggregation(
metric_name
)
self._higher_is_better[metric_name] = is_higher_better(metric_name)
else:
for metric_config in self.config.metric_list:
if "metric" not in metric_config:
raise ValueError(
"'metric' key not provided for an entry in 'metric_list', must be specified!"
)
metric_name = metric_config["metric"]
kwargs = {
key: metric_config[key]
for key in metric_config
if key
not in ["metric", "aggregation", "higher_is_better", "hf_evaluate"]
}
hf_evaluate_metric = (
"hf_evaluate" in metric_config
and metric_config["hf_evaluate"] is True
)
if self.config.process_results is not None:
self._metric_fn_list[metric_name] = None
self._metric_fn_kwargs[metric_name] = {}
elif callable(metric_name):
metric_fn = metric_name.__call__
metric_name = metric_name.__name__
self._metric_fn_list[metric_name] = metric_fn
self._metric_fn_kwargs[metric_name] = kwargs
else:
self._metric_fn_list[metric_name] = get_metric(
metric_name, hf_evaluate_metric
)
self._metric_fn_kwargs[metric_name] = kwargs
if "aggregation" in metric_config:
agg_name = metric_config["aggregation"]
if isinstance(agg_name, str):
self._aggregation_list[metric_name] = get_aggregation(agg_name)
elif callable(agg_name): # noqa: E721
self._aggregation_list[metric_name] = metric_config[
"aggregation"
]
else:
INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()}
metric_agg = get_metric_aggregation(metric_name)
eval_logger.warning(
f"[Task: {self.config.task}] metric {metric_name} is defined, but aggregation is not. "
f"using default "
f"aggregation={INV_AGG_REGISTRY[metric_agg]}"
)
self._aggregation_list[metric_name] = metric_agg
if "higher_is_better" in metric_config:
self._higher_is_better[metric_name] = metric_config[
"higher_is_better"
]
else:
eval_logger.warning(
f"[Task: {self.config.task}] metric {metric_name} is defined, but higher_is_better is not. "
f"using default "
f"higher_is_better={is_higher_better(metric_name)}"
)
self._higher_is_better[metric_name] = is_higher_better(metric_name)
self.download(self.config.dataset_kwargs)
self._training_docs = None
self._fewshot_docs = None
if self.config.filter_list is not None:
self._filters = []
for filter_config in self.config.filter_list:
filter_name = filter_config["name"]
filter_functions = filter_config["filter"]
components = []
for function in filter_functions:
kwargs = {
key: function[key] for key in function if key != "function"
}
components.append([function["function"], kwargs])
filter_pipeline = build_filter_ensemble(filter_name, components)
self._filters.append(filter_pipeline)
else:
self._filters = [build_filter_ensemble("none", [["take_first", None]])]
if self.config.use_prompt is not None:
eval_logger.info(f"loading prompt {self.config.use_prompt}")
self.prompt = get_prompt(
self.config.use_prompt, self.DATASET_PATH, self.DATASET_NAME
)
else:
self.prompt = None
if self.fewshot_docs() is not None:
self.sampler = samplers.get_sampler(
self.config.fewshot_config.get("sampler", "default")
if self.config.fewshot_config
else "default"
)(list(self.fewshot_docs()), self, rnd=random.Random(1234))
self.task_docs = self.eval_docs
# Test One Doc
self.features = list(self.task_docs.features.keys())
self.multiple_input = 0
self.multiple_target = 0
test_doc = self.task_docs[0]
test_text = self.doc_to_text(test_doc)
test_target = self.doc_to_target(test_doc)
if self.config.doc_to_choice is not None:
test_choice = self.doc_to_choice(test_doc)
if not isinstance(test_choice, list):
eval_logger.error("doc_to_choice must return list")
else:
num_choice = len(test_choice)
if isinstance(test_text, int):
self.multiple_input = num_choice
else:
test_choice = None
if isinstance(test_target, list):
self.multiple_target = len(test_target)
else:
if (isinstance(test_target, int)) and (test_choice is not None):
test_target = test_choice[test_target]
else:
test_target = str(test_target)
if test_choice is not None:
check_choices = test_choice
else:
check_choices = [test_target]
if self.config.doc_to_choice is not None:
for choice in check_choices:
choice_has_whitespace = True if choice[0].isspace() else False
delimiter_has_whitespace = (
True
if self.config.target_delimiter.rstrip()
!= self.config.target_delimiter
else False
)
if delimiter_has_whitespace and choice_has_whitespace:
eval_logger.debug(
f'Both target_delimiter "{self.config.target_delimiter}" and target choice: "{choice}" have whitespace'
)
elif (not delimiter_has_whitespace) and (not choice_has_whitespace):
eval_logger.debug(
f'Both target_delimiter "{self.config.target_delimiter}" and target choice: "{choice}" do not have whitespace, ignore if the language you are evaluating on does not require/use whitespace'
)
def download(self, dataset_kwargs: Optional[Dict[str, Any]] = None) -> None:
print(self.DATASET_NAME)
dataset_kwargs = {}
if "cmmlu" in self.DATASET_PATH:
dataset_kwargs['data_files'] = {
'dev': self.DATASET_PATH+"/"+self.DATASET_NAME+"/1.0.1/efcc940752ea4a1ea94d2727f11f83858d64fc8e/cmmlu-dev.arrow",
'test': self.DATASET_PATH + "/" + self.DATASET_NAME + "/1.0.1/efcc940752ea4a1ea94d2727f11f83858d64fc8e/cmmlu-test.arrow",
}
elif "ceval" in self.DATASET_PATH:
dataset_kwargs['data_files'] = {
'val': self.DATASET_PATH+"/"+self.DATASET_NAME+"/1.0.0/3923b519fd180e689d0961bf3a032ece929742f3/ceval-exam-val.arrow",
'dev': self.DATASET_PATH + "/" + self.DATASET_NAME + "/1.0.0/3923b519fd180e689d0961bf3a032ece929742f3/ceval-exam-val.arrow"
}
elif "hellaswag" in self.DATASET_PATH:
dataset_kwargs['data_files'] = {
'train': self.DATASET_PATH +"/default/0.1.0/512a66dd8b1b1643ab4a48aa4f150d04c91680da6a4096498a5e5f799623d5ae/hellaswag-train.arrow",
'validation': self.DATASET_PATH + "/default/0.1.0/512a66dd8b1b1643ab4a48aa4f150d04c91680da6a4096498a5e5f799623d5ae/hellaswag-validation.arrow"
}
elif "mmlu" in self.DATASET_PATH:
dataset_kwargs['data_files'] = {
'dev': self.DATASET_PATH + "/" + self.DATASET_NAME + "/1.0.0/b7d5f7f21003c21be079f11495ee011332b980bd1cd7e70cc740e8c079e5bda2/mmlu_no_train-validation.arrow",
'test': self.DATASET_PATH + "/" + self.DATASET_NAME + "/1.0.0/b7d5f7f21003c21be079f11495ee011332b980bd1cd7e70cc740e8c079e5bda2/mmlu_no_train-test.arrow"
}
elif "mathqa" in self.DATASET_PATH:
dataset_kwargs['data_files'] = {
'train': self.DATASET_PATH +"/default/0.1.0/c4f1cc784c04c4957b50c97858f23893b633eea6/math_qa-train.arrow",
'validation': self.DATASET_PATH + "/default/0.1.0/c4f1cc784c04c4957b50c97858f23893b633eea6/math_qa-validation.arrow",
'test': self.DATASET_PATH + "/default/0.1.0/c4f1cc784c04c4957b50c97858f23893b633eea6/math_qa-test.arrow"
}
elif "gsm8k" in self.DATASET_PATH:
dataset_kwargs['data_files'] = {
'train': self.DATASET_PATH +"/main/0.0.0/e53f048856ff4f594e959d75785d2c2d37b678ee/gsm8k-train.arrow",
'test': self.DATASET_PATH +"/main/0.0.0/e53f048856ff4f594e959d75785d2c2d37b678ee/gsm8k-test.arrow"
}
else:
raise ValueError("Only Support cmmlu, ceval-valid, hellaswag, mmlu and mathqa datasets")
self.dataset = datasets.load_dataset(
"arrow",
**dataset_kwargs if dataset_kwargs is not None else {},
)
def has_training_docs(self) -> bool:
if self.config.training_split is not None:
return True
else:
return False
def has_validation_docs(self) -> bool:
if self.config.validation_split is not None:
return True
else:
return False
def has_test_docs(self) -> bool:
if self.config.test_split is not None:
return True
else:
return False
def training_docs(self) -> datasets.Dataset:
if self.has_training_docs():
if self.config.process_docs is not None:
return self.config.process_docs(
self.dataset[self.config.training_split]
)
return self.dataset[self.config.training_split]
def validation_docs(self) -> datasets.Dataset:
if self.has_validation_docs():
if self.config.process_docs is not None:
return self.config.process_docs(
self.dataset[self.config.validation_split]
)
return self.dataset[self.config.validation_split]
def test_docs(self) -> datasets.Dataset:
if self.has_test_docs():
if self.config.process_docs is not None:
return self.config.process_docs(self.dataset[self.config.test_split])
return self.dataset[self.config.test_split]
def fewshot_docs(self):
if self.config.fewshot_split is not None:
if self.config.process_docs is not None:
return self.config.process_docs(self.dataset[self.config.fewshot_split])
return self.dataset[self.config.fewshot_split]
else:
if (self.config.num_fewshot is not None) and (self.config.num_fewshot > 0):
eval_logger.warning(
f"Task '{self.config.task}': "
"num_fewshot > 0 but fewshot_split is None. "
"using preconfigured rule."
)
return super().fewshot_docs()
@utils.positional_deprecated
def fewshot_context(self, doc: str, num_fewshot: int) -> str:
"""Returns a fewshot context string that is made up of a prepended description
(if provided), the `num_fewshot` number of examples, and an appended prompt example.
:param doc: str
The document as returned from training_docs, validation_docs, or test_docs.
:param num_fewshot: int
The number of fewshot examples to provide in the returned context string.
:returns: str
The fewshot context.
"""
if num_fewshot == 0:
# always prepend the (possibly empty) task description
labeled_examples = self.config.description
else:
labeled_examples = self.config.description + self.sampler.get_context(
doc, num_fewshot
)
example = self.doc_to_text(doc)
if self.multiple_input:
return labeled_examples
else:
if isinstance(example, str):
return labeled_examples + example
elif isinstance(example, list):
return [labeled_examples + ex for ex in example]
elif isinstance(example, int):
if self.config.doc_to_choice is not None:
choices = self.doc_to_choice(doc)
return labeled_examples + choices[example]
else:
return labeled_examples + str(example)
def apply_filters(self):
"""Iterates over FilterEnsembles and applies them to instances"""
if hasattr(self, "_filters"):
for f in self._filters:
f.apply(self._instances)
else:
eval_logger.warning("No filter defined, passing through instances")
return self._instances
def should_decontaminate(self):
return self.config.should_decontaminate
def doc_to_decontamination_query(self, doc):
if self.config.should_decontaminate:
if self.config.doc_to_decontamination_query is None:
return self.doc_to_text(doc)
else:
doc_to_decontamination_query = self.config.doc_to_decontamination_query
if doc_to_decontamination_query in self.features:
return doc[doc_to_decontamination_query]
elif callable(doc_to_decontamination_query):
return doc_to_decontamination_query(doc)
else:
return ast.literal_eval(
utils.apply_template(
self.config.doc_to_decontamination_query, doc
)
)
def _process_doc(self, doc: dict) -> dict:
"""
Override this to process (detokenize, strip, replace, etc.) individual
documents. This can be used in a map over documents of a data split.
E.g. `map(self._process_doc, self.dataset["validation"])`
:return: dict
The processed version of the specified `doc`.
"""
return doc
def doc_to_text(self, doc):
if self.prompt is not None:
doc_to_text = self.prompt
else:
doc_to_text = self.config.doc_to_text
if isinstance(doc_to_text, int):
return doc_to_text
elif isinstance(doc_to_text, str):
if doc_to_text in self.features:
# if self.config.doc_to_choice is not None:
# return self.doc_to_choice(doc)[doc[doc_to_text]]
# else:
return doc[doc_to_text]
else:
text_string = utils.apply_template(doc_to_text, doc)
if text_string.isdigit() and self._config.doc_to_choice is not None:
return ast.literal_eval(text_string)
else:
return text_string
elif callable(doc_to_text):
return doc_to_text(doc)
# Used when applying a Promptsource template
elif hasattr(doc_to_text, "apply"):
applied_prompt = doc_to_text.apply(doc)
if len(applied_prompt) == 2:
return applied_prompt[0]
else:
eval_logger.warning("Applied prompt returns empty string")
return self.config.fewshot_delimiter
else:
print(type(doc_to_text))
raise TypeError
def doc_to_target(self, doc: Mapping) -> Union[int, str, list]:
if self.prompt is not None:
doc_to_target = self.prompt
else:
doc_to_target = self.config.doc_to_target
if isinstance(doc_to_target, int):
return doc_to_target
elif isinstance(doc_to_target, str):
if doc_to_target in self.features:
# if self.config.doc_to_choice is not None:
# return self.doc_to_choice(doc)[doc[doc_to_target]]
# else:
return doc[doc_to_target]
else:
target_string = utils.apply_template(doc_to_target, doc)
if target_string.isdigit() and self._config.doc_to_choice is not None:
return ast.literal_eval(target_string)
elif (
len(target_string) >= 2
and (target_string[0] == "[")
and (target_string[-1] == "]")
):
try:
return ast.literal_eval(target_string)
except (SyntaxError, ValueError):
return target_string
else:
return target_string
elif isinstance(doc_to_target, list):
return doc_to_target
elif callable(doc_to_target):
return doc_to_target(doc)
# Used when applying a Promptsource template
elif hasattr(doc_to_target, "apply"):
applied_prompt = doc_to_target.apply(doc)
if len(applied_prompt) == 2:
return applied_prompt[1]
else:
eval_logger.warning("Applied prompt returns empty string")
return self.config.fewshot_delimiter
else:
raise TypeError
def doc_to_choice(self, doc: Any) -> List[str]:
if self.prompt is not None:
doc_to_choice = self.prompt
elif self.config.doc_to_choice is None:
eval_logger.error("doc_to_choice was called but not set in config")
else:
doc_to_choice = self.config.doc_to_choice
if isinstance(doc_to_choice, str):
if doc_to_choice in self.features:
return doc[doc_to_choice]
else:
return ast.literal_eval(utils.apply_template(doc_to_choice, doc))
elif isinstance(doc_to_choice, list):
return doc_to_choice
elif isinstance(doc_to_choice, dict):
return list(doc_to_choice.values())
elif callable(doc_to_choice):
return doc_to_choice(doc)
elif hasattr(doc_to_choice, "get_answer_choices_list"):
return doc_to_choice.get_answer_choices_list(doc)
else:
raise TypeError
def construct_requests(
self, doc: dict, ctx: str, **kwargs
) -> Union[List[Instance], Instance]:
if self.OUTPUT_TYPE == "loglikelihood":
arguments = (ctx, self.doc_to_target(doc))
elif self.OUTPUT_TYPE == "loglikelihood_rolling":
arguments = (self.doc_to_target(doc),)
elif self.OUTPUT_TYPE == "multiple_choice":
choices = self.doc_to_choice(doc)
target_delimiter = self.config.target_delimiter
if self.multiple_input:
# If there are multiple inputs, choices are placed in the ctx
cont = self.doc_to_target(doc)
arguments = [
(ctx + choice, f"{target_delimiter}{cont}") for choice in choices
]
else:
# Otherwise they are placed in the continuation
arguments = [(ctx, f"{target_delimiter}{cont}") for cont in choices]
request_list = [
Instance(
request_type="loglikelihood",
doc=doc,
arguments=arg,
idx=i,
**kwargs,
)
for i, arg in enumerate(arguments)
]
# TODO: we should raise a warning telling users this will at most ~2x runtime.
if "acc_mutual_info" in self._metric_fn_list.keys():
# if we are calculating multiple choice accuracy
# using mutual information instead of raw loglikelihood as metric, need unconditional lls.
# here mutual info refers to calculating
# log(P(choice|ctx) / P(choice)) = log(P(choice|ctx)) - log(P(choice))
# in other words normalizing by subtracting the unconditional logprob of each choice.
request_list.extend(
[
Instance(
request_type="loglikelihood",
doc=doc,
arguments=("", "{}".format(choice)),
idx=i,
**kwargs,
)
for i, choice in enumerate(choices)
]
)
return request_list
elif self.OUTPUT_TYPE == "generate_until":
arguments = (ctx, deepcopy(self.config.generation_kwargs))
return Instance(
request_type=self.OUTPUT_TYPE, doc=doc, arguments=arguments, idx=0, **kwargs
)
def process_results(self, doc, results):
if callable(self.config.process_results):
return self.config.process_results(doc, results)
result_dict = {}
use_metric = list(self._metric_fn_list.keys())
if self.OUTPUT_TYPE == "loglikelihood":
results = results[0]
ll, is_greedy = results
return {
**({"perplexity": ll} if "perplexity" in use_metric else {}),
**({"acc": int(is_greedy)} if "acc" in use_metric else {}),
}
elif self.OUTPUT_TYPE == "loglikelihood_rolling":
(loglikelihood,) = results
_words = self.count_words(self.doc_to_target(doc))
_bytes = self.count_bytes(self.doc_to_target(doc))
return {
**(
{"word_perplexity": (loglikelihood, _words)}
if "word_perplexity" in use_metric
else {}
),
**(
{"byte_perplexity": (loglikelihood, _bytes)}
if "byte_perplexity" in use_metric
else {}
),
**(
{"bits_per_byte": (loglikelihood, _bytes)}
if "bits_per_byte" in use_metric
else {}
),
}
elif self.OUTPUT_TYPE == "multiple_choice":
lls, is_greedy = zip(*results)
# retrieve choices in List[str] form, to compute choice lengths, etc.
choices = self.doc_to_choice(doc)
completion_len = np.array([float(len(i)) for i in choices])
if (
2 * len(choices) == len(lls)
and "acc_mutual_info" in self._metric_fn_list.keys()
):
# then we are doing mutual info.
# this stores the "dryrun" / unconditional answer loglikelihoods
lls_unconditional = lls[1::2]
if len(lls_unconditional) != len(choices):
raise ValueError
# and this stores our "regular" conditional loglikelihoods
lls = lls[::2]
pred = np.argmax(lls)
pred_norm = np.argmax(lls / completion_len)
if self.multiple_input:
gold = self.doc_to_text(doc)
else:
gold = self.doc_to_target(doc)
gold_index_error = False
if isinstance(gold, list):
gold = [i if i < len(choices) else -100 for i in gold]
if -100 in gold:
gold_index_error = True
else:
if isinstance(gold, int):
gold = gold if gold < len(choices) else -100
elif isinstance(gold, str):
gold = choices.index(gold) if gold in choices else -100
if gold == -100:
gold_index_error = True
if gold_index_error:
eval_logger.warning(
f"Label index was not in within range of available choices,"
f"Sample:\n\n{doc}\n\n"
)
if self.multiple_target:
acc = 1.0 if pred in gold else 0.0
acc_norm = 1.0 if pred_norm in gold else 0.0
exact_match = int(any([is_greedy[i] if i != -100 else 0 for i in gold]))
else:
acc = 1.0 if pred == gold else 0.0
acc_norm = 1.0 if pred_norm == gold else 0.0
# TODO: this gets score of 0 on arc_challenge for pythia-70m. need to test that this works properly
exact_match = int(is_greedy[gold]) if gold != -100 else 0
prob_norm = utils.softmax(lls)
# TODO use keyword arguments to the metric?
# gold, pred, norm stuff, the original lls,
result_dict = {
**({"acc": acc} if "acc" in use_metric else {}),
**({"f1": (gold, pred)} if "f1" in use_metric else {}),
**({"mcc": (gold, pred)} if "mcc" in use_metric else {}),
**({"acc_norm": acc_norm} if "acc_norm" in use_metric else {}),
**({"exact_match": exact_match} if "exact_match" in use_metric else {}),
**(
{"brier_score": (gold, prob_norm)}
if "brier_score" in use_metric
else {}
),
}
if "acc_mutual_info" in use_metric:
lls_mutual_info = [
ll_c - ll_u for ll_c, ll_u in zip(lls, lls_unconditional)
]
acc_mutual_info = 1.0 if np.argmax(lls_mutual_info) == gold else 0.0
result_dict["acc_mutual_info"] = acc_mutual_info
elif self.OUTPUT_TYPE == "generate_until":
gold = self.doc_to_target(doc)
result = results[0]
if self.config.doc_to_choice is not None:
# If you set doc_to_choice,
# it assumes that doc_to_target returns a number.
choices = self.doc_to_choice(doc)
gold = choices[gold]
# we expect multiple_targets to be a list.
elif self.multiple_target:
gold = list(gold)
elif type(gold) != type(result):
# cast gold to the same type as result
gold = type(result)(gold)
for metric in self._metric_fn_list.keys():
if self.multiple_target:
# in the case where we have multiple targets,
# return true if any are true
# TODO: this may break for multipLe_target, non zero-or-1 metrics
scores = []
if not isinstance(gold, list):
# sometimes, a multiple_target dataset has exceptions where one doc has only one string answer
# print(gold)
gold = [gold]
if metric == "exact_match":
result = [result for _ in range(len(gold))]
scores = self._metric_fn_list[metric](
references=gold,
predictions=result,
**self._metric_fn_kwargs[metric],
)[metric]
result_score = 1.0 if scores > 0.0 else 0.0
else:
for gold_option in gold:
try:
result_score = self._metric_fn_list[metric](
references=[gold_option],
predictions=[result],
**self._metric_fn_kwargs[metric],
)
except (
TypeError
): # TODO: this is hacky and I don't want to do it
result_score = self._metric_fn_list[metric](
[gold_option, result]
)
if isinstance(result_score, dict):
# TODO: this handles the case where HF evaluate returns a dict.
result_score = result_score[metric]
scores.append(result_score)
if any(scores):
result_score = 1.0
else:
result_score = 0.0
else:
try:
result_score = self._metric_fn_list[metric](
references=[gold],
predictions=[result],
**self._metric_fn_kwargs[metric],
)
except TypeError: # needed for now in order to use a different interface between our own metrics and HF Evaluate metrics
result_score = self._metric_fn_list[metric]([gold, result])
if isinstance(result_score, dict):
# TODO: this handles the case where HF evaluate returns a dict.
result_score = result_score[metric]
result_dict[metric] = result_score
else:
raise ValueError(
f"Passed invalid output_type '{self.OUTPUT_TYPE}' ! Please use one of ",
"'loglikelihood', 'loglikelihood_rolling', 'generate_until' or 'multiple_choice'",
)
return result_dict
def aggregation(self) -> dict:
return self._aggregation_list
def higher_is_better(self) -> dict:
return self._higher_is_better
def get_config(self, key: str) -> Any:
return getattr(self._config, key, None)
def __repr__(self):
return (
f"ConfigurableTask(task_name={getattr(self.config, 'task', None)},"
f"group_name={getattr(self.config, 'group', None)},"
f"output_type={self.OUTPUT_TYPE},"
f"num_fewshot={getattr(self.config, 'num_fewshot', None)},"
f"num_samples={len(self.eval_docs)})"
)
class MultipleChoiceTask(Task):
OUTPUT_TYPE = "loglikelihood"
def doc_to_target(self, doc: dict) -> str:
return " " + doc["choices"][doc["gold"]]
def construct_requests(self, doc: dict, ctx: str, **kwargs) -> List[Instance]:
# TODO: add mutual info here?
return [
Instance(
request_type="loglikelihood",
doc=doc,
arguments=(ctx, " {}".format(choice)),
idx=i,
**kwargs,
)
for i, choice in enumerate(doc["choices"])
]
def process_results(self, doc: dict, results: Iterable[Tuple[float, bool]]) -> dict:
results = [
res[0] for res in results
] # only retain loglikelihoods, discard is_greedy TODO: do we need is_greedy anywhere?
gold = doc["gold"]
acc = 1.0 if np.argmax(results) == gold else 0.0
completion_len = np.array([float(len(i)) for i in doc["choices"]])
acc_norm = 1.0 if np.argmax(results / completion_len) == gold else 0.0
return {
"acc": acc,
"acc_norm": acc_norm,
}
def higher_is_better(self) -> dict:
return {
"acc": True,
"acc_norm": True,
}
def aggregation(self) -> dict:
return {
"acc": mean,
"acc_norm": mean,
}
class PerplexityTask(Task):
OUTPUT_TYPE = "loglikelihood_rolling"
def has_training_docs(self) -> bool:
return False
def fewshot_examples(self, k: int, rnd) -> List:
if k != 0:
raise ValueError(
"The number of fewshot examples must be 0 for perplexity tasks."
)
return []
def fewshot_context(self, doc: dict, num_fewshot: int) -> Literal[""]:
if num_fewshot != 0:
raise ValueError(
"The number of fewshot examples must be 0 for perplexity tasks."
)
return ""
def higher_is_better(self) -> dict:
return {
"word_perplexity": False,
"byte_perplexity": False,
"bits_per_byte": False,
}
def doc_to_decontamination_query(self, doc):
return doc
def doc_to_text(self, doc) -> str:
return ""
def doc_to_target(self, doc):
return doc
def construct_requests(self, doc: dict, ctx: Optional[str], **kwargs):
if bool(ctx):
raise ValueError
return Instance(
request_type=self.OUTPUT_TYPE,
doc=doc,
arguments=(self.doc_to_target(doc),),
idx=0,
**kwargs,
)
def process_results(self, doc: dict, results: Tuple[float]) -> dict:
(loglikelihood,) = results
words = self.count_words(self.doc_to_target(doc))
bytes_ = self.count_bytes(self.doc_to_target(doc))
return {
"word_perplexity": (loglikelihood, words),
"byte_perplexity": (loglikelihood, bytes_),
"bits_per_byte": (loglikelihood, bytes_),
}
def aggregation(self) -> dict:
return {
"word_perplexity": weighted_perplexity,
"byte_perplexity": weighted_perplexity,
"bits_per_byte": bits_per_byte,
}
@classmethod
def count_bytes(cls, doc) -> int:
return len(doc.encode("utf-8"))
@classmethod
def count_words(cls, doc) -> int:
"""Downstream tasks with custom word boundaries should override this!"""
return len(re.split(r"\s+", doc))
import hashlib
import os
import dill
from lm_eval.utils import eval_logger
MODULE_DIR = os.path.dirname(os.path.realpath(__file__))
OVERRIDE_PATH = os.getenv("LM_HARNESS_CACHE_PATH")
PATH = OVERRIDE_PATH if OVERRIDE_PATH else f"{MODULE_DIR}/.cache"
# This should be sufficient for uniqueness
HASH_INPUT = "EleutherAI-lm-evaluation-harness"
HASH_PREFIX = hashlib.sha256(HASH_INPUT.encode("utf-8")).hexdigest()
FILE_SUFFIX = f".{HASH_PREFIX}.pickle"
def load_from_cache(file_name):
try:
path = f"{PATH}/{file_name}{FILE_SUFFIX}"
with open(path, "rb") as file:
cached_task_dict = dill.loads(file.read())
return cached_task_dict
except Exception:
eval_logger.debug(f"{file_name} is not cached, generating...")
pass
def save_to_cache(file_name, obj):
if not os.path.exists(PATH):
os.mkdir(PATH)
file_path = f"{PATH}/{file_name}{FILE_SUFFIX}"
eval_logger.debug(f"Saving {file_path} to cache...")
with open(file_path, "wb") as file:
file.write(dill.dumps(obj))
# NOTE the "key" param is to allow for flexibility
def delete_cache(key: str = ""):
files = os.listdir(PATH)
for file in files:
if file.startswith(key) and file.endswith(FILE_SUFFIX):
file_path = f"{PATH}/{file}"
os.unlink(file_path)
import datetime
import io
import json
import mmap
import os
from pathlib import Path
from typing import Any
import jsonlines
import tqdm
import zstandard
def json_serial(obj: Any) -> str:
"""JSON serializer for objects not serializable by default json code"""
if isinstance(obj, (datetime.datetime,)):
return obj.isoformat()
raise TypeError("Type %s not serializable" % type(obj))
# Modified version of lm_dataformat Archive for single file.
class Archive:
def __init__(self, file_path: str, compression_level: int = 3) -> None:
self.file_path = file_path
dir_name = os.path.dirname(file_path)
if dir_name:
os.makedirs(dir_name, exist_ok=True)
self.fh = open(self.file_path, "wb")
self.cctx = zstandard.ZstdCompressor(level=compression_level)
self.compressor = self.cctx.stream_writer(self.fh)
def add_data(self, data, meta=None) -> None:
if meta is None:
meta = {}
self.compressor.write(
json.dumps({"text": data, "meta": meta}, default=json_serial).encode(
"UTF-8"
)
+ b"\n"
)
def commit(self) -> None:
self.compressor.flush(zstandard.FLUSH_FRAME)
self.fh.flush()
self.fh.close()
# Modified version of lm_dataformat Reader with self.fh set, allowing peeking for tqdm.
class Reader:
def __init__(self) -> None:
pass
def read(
self,
file,
get_meta: bool = False,
autojoin_paragraphs: bool = True,
para_joiner: str = "\n\n",
):
with open(file, "rb") as fh:
self.fh = fh
cctx = zstandard.ZstdDecompressor()
reader = io.BufferedReader(cctx.stream_reader(fh))
rdr = jsonlines.Reader(reader)
for ob in rdr:
# naive jsonl where each object is just the string itself, with no meta. For legacy compatibility.
if isinstance(ob, str):
assert not get_meta
yield ob
continue
text = ob["text"]
if autojoin_paragraphs and isinstance(text, list):
text = para_joiner.join(text)
if get_meta:
yield text, (ob["meta"] if "meta" in ob else {})
else:
yield text
class TextArchive:
def __init__(self, file_path, mode: str = "rb+") -> None:
self.file_path = file_path
dir_name = os.path.dirname(file_path)
if dir_name:
os.makedirs(dir_name, exist_ok=True)
if not os.path.exists(file_path):
Path(file_path).touch()
self.fh = open(self.file_path, mode)
def add_data(self, data) -> None:
self.fh.write(data.encode("UTF-8") + b"\n")
def commit(self) -> None:
self.fh.flush()
self.fh.close()
class TextReader:
def __init__(self, file_path) -> None:
self.file_path = file_path
# Optimized mmap read with infrequent tqdm updates to maintain speed
# Tested up to 250MB/s.
def read_tqdm(self, update_frequency: int = 10000):
current_file_position = 0
line_counter = 0
with open(self.file_path, "r", encoding="utf-8") as fh, tqdm.tqdm(
total=os.path.getsize(self.file_path),
dynamic_ncols=True,
unit="byte",
unit_scale=1,
) as progress:
with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
for line in iter(mmap_obj.readline, b""):
line = line.decode("utf-8")
line_counter += 1
if line_counter == update_frequency:
new_file_pos = mmap_obj.tell()
bytes_read = new_file_pos - current_file_position
current_file_position = new_file_pos
progress.update(bytes_read)
line_counter = 0
yield line[:-1]
def read_and_tell(self):
current_file_position = 0
with open(self.file_path, "r", encoding="utf8") as fh:
with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
for line in iter(mmap_obj.readline, b""):
line = line.decode("utf-8")
new_file_pos = mmap_obj.tell()
raw_bytes_read = new_file_pos - current_file_position
current_file_position = new_file_pos
yield line[:-1], raw_bytes_read
def read(self):
with open(self.file_path, "r", encoding="utf8") as fh:
with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
for line in iter(mmap_obj.readline, b""):
line = line.decode("utf-8")
yield line[:-1]
def read_slow(self):
with open(self.file_path, "r", encoding="utf8") as fh:
while True:
line = fh.readline()
if line == -1 or line == "":
break
else:
yield line[:-1]
# Optimized for speed. Decompresses the archive in shell before
# using the mmap'd TextReader.
class ZStdTextReader:
def __init__(self, file) -> None:
self.file = file
def read_tqdm(self):
decompressed_file = self.file[:-4]
print("Decompressing file, please wait...")
os.system(f"zstd -d {self.file}") # linux decompress is faster
reader = TextReader(decompressed_file)
yield from reader.read_tqdm()
os.remove(decompressed_file)
import collections
import glob
import json
import os
import pickle
import random
import time
from .archiver import ZStdTextReader
from .janitor import Janitor, word_ngrams
# Was used for testing the evaluator decoupled from the full logic below
def get_train_overlap_stub(docs: dict, ngrams_path: str, ngrams_n_size: str):
simulated_overlap = 0.1
contaminated = int(len(docs) * simulated_overlap)
return random.sample(range(len(docs)), contaminated)
# Returns a dictionary containing all overlapping documents in each
# task. In the standard use case, an overlap occurs when any of the 13-grams
# found in the task document exist in the training set documents.
#
# To generate 13-grams for the pile see scripts/clean_training_data. The final output of these
# scripts are an info.json file containing the n_gram_size (13) and a bunch of "ngrams_{x}.bkt.txt.sorted.zst"
# files. These should exist in the "ngrams_path" provided to this function.
# Algorithm:
# 1. Build lookups for each dataset {ngram: list(document_ids)}
# 2. Merge into an overall lookup {ngram: [(task_name, task_set, doc_ids),]}
# 3. Full scan the 13-grams from the training set against the merged lookup,
# saving matches in the "duplicates" dictionary {(task_name, task_set): set(doc_ids)}
# 4. Strip the task_set from the dictionary keys and return
#
# We cache the task+set lookups as well as the overlaps.
def get_train_overlap(docs_by_task_set: dict, ngrams_path: str, limit: int) -> dict:
# return get_train_overlap_stub(docs, ngrams_path, ngrams_n_size)
info_dict_path = os.path.join(ngrams_path, "info.json")
info_dict = json.load(open(info_dict_path, "r", encoding="utf-8"))
ngrams_n_size = info_dict["ngram_size"]
janitor = Janitor()
# Build lookup for each dataset first in case we use different task combinations later
print("Building Lookups...")
start = time.perf_counter()
def get_overlaps_dump_path(task_name, task_set, ngrams_n_size, limit) -> str:
return f"data/{task_name}/{task_set}_{ngrams_n_size}grams_limit{limit}.overlaps"
lookups = {}
duplicates = {} # (task_name, task_set): set(doc_ids)}
sets_to_decontaminate = len(docs_by_task_set.keys())
for (task_name, task_set), docs in docs_by_task_set.items():
if not os.path.exists(f"data/{task_name}"):
os.mkdir(f"data/{task_name}")
# Check if we've decontaminated this combination before
overlaps_dump_path = get_overlaps_dump_path(
task_name, task_set, ngrams_n_size, limit
)
if os.path.exists(overlaps_dump_path):
duplicates[(task_name, task_set)] = pickle.load(
open(overlaps_dump_path, "rb")
)
sets_to_decontaminate -= 1
continue
else:
duplicates[(task_name, task_set)] = set()
# Build/load the task lookup {ngram: set(documents)}.
task_set_lookup_path = (
f"data/{task_name}/{task_set}_{ngrams_n_size}grams_limit{limit}.lookup"
)
if os.path.exists(task_set_lookup_path):
print(f"{task_set_lookup_path} available, loading...")
lookups[(task_name, task_set)] = pickle.load(
open(task_set_lookup_path, "rb")
)
else:
print(f"{task_set_lookup_path} not available, building...")
lookup = collections.defaultdict(set)
for doc_id, document in enumerate(docs):
ngrams = word_ngrams(janitor.normalize_string(document), ngrams_n_size)
for ngram in ngrams:
lookup[ngram].add(doc_id)
pickle.dump(lookup, open(task_set_lookup_path, "wb"))
lookups[(task_name, task_set)] = lookup
elapsed = time.perf_counter() - start
print(f"Building lookups took {elapsed:0.5f} seconds.")
matched_ngrams = []
if sets_to_decontaminate > 0:
print("Merging lookups...")
start = time.perf_counter()
merged_lookup = collections.defaultdict(list)
for (task_name, task_set), lookup in lookups.items():
for ngram, doc_ids in lookup.items():
merged_lookup[ngram].append((task_name, task_set, doc_ids))
elapsed = time.perf_counter() - start
print(f"Merging lookups took {elapsed:0.5f} seconds.")
print(f"{ngrams_n_size} grams files found in {ngrams_path}:")
files = glob.glob(os.path.join(ngrams_path, "*.sorted.zst"))
print(files)
for file in files:
start = time.perf_counter()
print(f"Scanning {file}")
reader = ZStdTextReader(file)
total_ngrams = 0
unique_ngrams = 0
matching_unique = 0
non_matching_unique = 0
current_ngram = ""
for line in reader.read_tqdm(): # Scan training set ngrams file
total_ngrams += 1
[ngram, document_id] = line.rsplit(" ", 1)
if (
ngram != current_ngram
): # Only need to match the ngram once in training set
unique_ngrams += 1
current_ngram = ngram
if ngram in merged_lookup:
matched_ngrams.append(ngram) # For logging
matching_unique += 1
for task_name, task_set, doc_ids in merged_lookup[ngram]:
task_doc_set = duplicates[(task_name, task_set)]
for doc_id in doc_ids: # Record contamination across all relevant task/set combos
task_doc_set.add(doc_id)
del merged_lookup[ngram] # No point matching again
else:
non_matching_unique += 1
print(f"Total Ngrams: {total_ngrams}")
print(f"Unique Ngrams: {unique_ngrams}")
print(f"Unique Matching: {matching_unique}")
print(f"Unique Non Matching: {non_matching_unique}")
print("Matched ngrams:")
for ngram in matched_ngrams:
print(ngram)
elapsed = time.perf_counter() - start
print(f"Read took {elapsed:0.5f} seconds.")
print(f"Speed: {(os.path.getsize(file)/1000000.0)/elapsed}MB/second")
print(duplicates)
# Dump overlaps separately
for (task_name, task_set), doc_ids in duplicates.items():
overlaps_dump_path = get_overlaps_dump_path(
task_name, task_set, ngrams_n_size, limit
)
pickle.dump(doc_ids, open(overlaps_dump_path, "wb"))
# Strip task set and return
return {task_name: doc_ids for (task_name, task_set), doc_ids in duplicates.items()}
import pickle
import re
import string
import traceback
from typing import Iterator, List, Sequence, Tuple, TypeVar
# This is a cpp module. Compile janitor_util.cpp with:
# c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes) janitor_util.cpp -o janitor_util$(python3-config --extension-suffix) -undefined dynamic_lookup
try:
import janitor_util
JANITOR_CPP = True
except Exception:
print("WARNING: C++ module could not be loaded. Janitor running in python mode")
traceback.print_exc()
JANITOR_CPP = False
T = TypeVar("T")
# Implementation from nltk source
# https://www.nltk.org/_modules/nltk/util.html
def form_ngrams(sequence: Iterator[T], n: int) -> Iterator[Tuple[T, ...]]:
history = []
while n > 1:
# PEP 479, prevent RuntimeError from being raised when StopIteration bubbles out of generator
try:
next_item = next(sequence)
except StopIteration:
# no more data, terminate the generator
return
history.append(next_item)
n -= 1
for item in sequence:
history.append(item)
yield tuple(history)
del history[0]
def word_ngrams(s: str, n: int) -> Iterator[str]:
"""Splits a string into ngram words"""
tokens = s.split() # not a generator :(
ngram_seqs = form_ngrams(iter(tokens), n)
return (" ".join(ngram) for ngram in ngram_seqs)
# Does character sequences only - combined faster function to play around with later
# def word_ngrams_indices_combined(sequence, n):
# current_word = ""
# history = []
# gap = False;
# start = 0
# end = 0
# for character in sequence:
# if character == " ":
# if not gap:
# gap = True
# history.append(current_word)
# end += len(current_word) - 1
# current_word = ""
# if len(history) == n:
# yield (tuple(history), start, end)
# del history[0]
# start = end + 1
# end = start
# else:
# gap = False
# current_word += character
# https://stackoverflow.com/questions/13734451/string-split-with-indices-in-python
def split_indices(s: str) -> Iterator[Tuple[str, Tuple[int, int]]]:
"""Splits a string on whitespaces and records the indices of each in the original string.
@:return generator((word, (start_idx, end_idx)), ...)
"""
return ((m.group(0), (m.start(), m.end() - 1)) for m in re.finditer(r"\S+", s))
def word_ngrams_indices(s: str, n: int) -> Iterator[Tuple[str, Tuple[int, int]]]:
"""Splits a string into pairs of (ngram words, their start/end indices)"""
tokens_with_indices = split_indices(s)
# Generator of ngrams of (word, idx_pairs)
# (
# [(word, (start,end)), (word, (start, end))...],
# [(word, (start, end)), ...],
# ...
# )
ngram_seqs_with_indices = form_ngrams(tokens_with_indices, n)
# Generator of pairs of word and index ngrams
# (
# ([word, word, ...], [(start,end), (start,end), ...]),
# ...
# )
ngram_indices_pairs = (
zip(*ngram_with_indices) for ngram_with_indices in ngram_seqs_with_indices
)
# Generator of ( (word_ngram, (start, end)), (word_ngram, start, end)), ...)
return (
(" ".join(ngram_seq), (indices[0][0], indices[-1][1]))
for ngram_seq, indices in ngram_indices_pairs
)
class Janitor:
# FIXME delete_chars: Should anything else go here? Special chars?
def __init__(
self,
ngram_n: int = 13,
window_to_remove: int = 200,
too_dirty_cutoff: int = 10,
minimum_slice_length: int = 200,
delete_chars: str = string.punctuation,
) -> None:
self.ngram_n = ngram_n
self.window_to_remove = window_to_remove
self.too_dirty_cutoff = too_dirty_cutoff
self.minimum_slice_length = minimum_slice_length
self.delete_chars = delete_chars
self.dirt_ngrams = set()
# If in python, we'll translate uppercase to lowercase and delete naughty characters.
# This is fast by python standards
# https://stackoverflow.com/questions/638893/what-is-the-most-efficient-way-in-python-to-convert-a-string-to-all-lowercase-st
self.translation_table = str.maketrans(
string.ascii_lowercase + string.ascii_uppercase, # These characters
string.ascii_lowercase * 2, # Become these characters
self.delete_chars, # These are deleted
)
##############
# I/O for saving contamination ngrams
##############
def save_contamination_ngrams(self, filename: str) -> None:
with open(filename, "wb") as fp:
pickle.dump(filename, fp)
def load_contamination_ngrams(self, filename: str) -> None:
with open(filename, "rb") as fp:
self.dirt_ngrams = pickle.load(fp)
##############
# Call these :)
##############
def register_contaminant(self, dirt_string: str) -> None:
"""Register a string as contamination to be removed, e.g. a test set
This breaks the dirt_string into ngrams to store for future cleaning"""
if JANITOR_CPP:
return self.register_contaminant_cpp(dirt_string)
else:
print("WARNING: Janitor running in python mode")
return self.register_contaminant_python(dirt_string)
def clean(self, dirty_string: str) -> List[str]:
"""Clean a string (e.g. a training set) by removing all ngrams previously
registered as contaminants. Returns a list of clean chunks, or empty if
the string was too dirty"""
if JANITOR_CPP:
return self.clean_cpp(dirty_string)
else:
print("WARNING: Janitor running in python mode")
return self.clean_python(dirty_string)
def _split_chunks(
self, dirty_string: str, dirty_parts: Sequence[Tuple]
) -> List[str]:
clean_chunks = []
splice_idx = 0
end = -1
for i, (ngram, start, end) in enumerate(dirty_parts):
if i >= self.too_dirty_cutoff:
return []
start = max(0, start - self.window_to_remove)
end = min(len(dirty_string), end + self.window_to_remove)
if start - splice_idx > self.minimum_slice_length:
clean_chunks.append(dirty_string[splice_idx:start])
splice_idx = end
if end < len(dirty_string) - self.minimum_slice_length:
clean_chunks.append(dirty_string[end + 1 :])
return clean_chunks
##############
# Fast C++
##############
def register_contaminant_cpp(self, dirt_string) -> None:
self.dirt_ngrams.update(
janitor_util.clean_ngram(dirt_string, self.delete_chars, self.ngram_n)
)
def clean_cpp(self, dirty_string: str) -> List[str]:
contamination_indices = janitor_util.clean_ngram_with_indices(
dirty_string, self.delete_chars, self.ngram_n
)
return self._split_chunks(dirty_string, contamination_indices)
##############
# Slow python
##############
def normalize_string(self, s: str) -> str:
return s.translate(self.translation_table)
def register_contaminant_python(self, dirt_string: str) -> None:
self.dirt_ngrams.update(
word_ngrams(self.normalize_string(dirt_string), self.ngram_n)
)
def clean_python(self, dirty_string: str) -> List[str]:
contamination_indices = (
(None, *idx_pair)
for dirty_ngram, idx_pair in word_ngrams_indices(dirty_string, self.ngram_n)
if self.normalize_string(dirty_ngram) in self.dirt_ngrams
)
return self._split_chunks(dirty_string, contamination_indices)
##################################################################
# Tests
#################################################################
# def print_cpp():
# source = """ ,, I'm a very !dirty,, ,, dirty boy. Clean me daddy. \n\nhe he he hehe heh. lastword """ * 2
# for i in range(1, 10, 2):
# pprint(janitor_util.clean_ngram(source, string.punctuation, i))
# for ngram, start, end in \
# janitor_util.clean_ngram_with_indices(source, string.punctuation, i):
# print(ngram, "\t", start, end, source[start:end].replace("\n", "\\n"))
# def test_cpp():
# source = """ ,, I'm a very !dirty,, ,, dirty boy. Clean me daddy. \n\nhe he he hehe heh. lastword """ * 2
# contaminant = "dirty boy. Clean he he"
# jan_python = Janitor()
# jan_cpp = Janitor()
# jan_python.register_contaminant_python(contaminant)
# jan_cpp.register_contaminant(contaminant)
# assert jan_python.dirt_ngrams == jan_cpp.dirt_ngrams, (jan_python.dirt_ngrams, jan_cpp.dirt_ngrams)
# assert jan_python.clean_python(source) == jan_cpp.clean(source), \
# (jan_python.clean_python(source), jan_cpp.clean(source))
# print("Passed test, python==cpp")
# def benchmark():
# # Download and put in data folder: enwik8 (100 MB) from https://cs.fit.edu/~mmahoney/compression/textdata.html
# setup = \
# """
# with open("data/enwik8", "r") as f:
# data = f.read()
# jan = Janitor(too_dirty_cutoff=1000)
# jan.register_contaminant('''
# theories is that there is a connection between &quot;geekdom&quot; and autism.
# This is hinted, for instance, by a ''Wired Magazine'' article in 2001 entitled &quot;
# The [[Geek]] Syndrome&quot;, which is a point argued by many in the autism rights
# movement{{ref|Wired}}. This article, many professionals assert, is just one example of
# the media's application of mental disease labels to what is actually variant normal behavior
# &amp;mdash;they argue that shyness, lack of athletic ability or social skills, and intellectual
# interests, even when they seem unusual to others, are not in themselves signs of autism or
# Asperger's syndrome. Others assert that it is actually the medical profession which is applying
# mental disease labels to children who in the past would have simply been accepted as a little
# different or even labeled 'gifted'. See [[clinomorphism]] for further discussion of this issue.
# Due to the recent publicity surrounding autism and autis
# ultan Al Nahyan]] granted [[Petroleum]] concessions, and oil was first found in 1958. At first,
# oil money had a marginal impact. A few lowrise concete buildings were erected, and the first
# paved road was completed in 1961, but Sheikh Shakbut, uncertain whether the new oil royalties
# would last, took a cautious approach, preferring to save the revenue rather than investing it in
# development. His brother, [[Zayed bin Sultan Al Nahayan]], saw that oil wealth had the potential
# to transform Abu Dhabi. The ruling Al Nahayan family decided that Sheikh Zayed should replace his
# brother as Ruler and carry out his vision of developing the country. On [[August 6]], [[1966]],
# with the assistance of the British, Sheikh Zayed became the new ruler. See generally, Al-Fahim, M,
# ''From Rags to Riches: A Story of Abu Dhabi'', Chapter Six (London Centre of Arab Studies, 1995),
# ISBN 1 900404 00 1. With the announcement by Britain in 1968 that it would withdraw from the
# Gulf area by 1971, Sheikh Zayed became the main driving force behind the formation of the
# [[United Arab Emirates]]. After the Emirates gained independence in 1971,
# ''')
# """
# n = 1
# print(f"Timing {n} run on 100 MB")
# print("Register contaminant")
# # print("\tPython", timeit.timeit("jan.register_contaminant_python(data)", setup=setup, globals=globals(), number=n))
# print("\tCpp", timeit.timeit("jan.register_contaminant(data)", setup=setup, globals=globals(), number=n))
# print("Clean")
# # print("\tPython", timeit.timeit("jan.clean_python(data)", setup=setup, globals=globals(), number=n))
# print("\tCpp", timeit.timeit("jan.clean(data)", setup=setup, globals=globals(), number=n))
# def test_janitor_general():
# source = """ ,, I'm a very !dirty,, ,, dirty boy. Clean me daddy. \n\nhe he he hehe heh. lastword """ * 2
# contaminant = "dirty boy. Clean he he"
# jan = Janitor(ngram_n=3)
# jan.register_contaminant(contaminant)
# cleaned = " ".join(jan.clean(source))
# for contam in jan.dirt_ngrams:
# assert contam not in cleaned, contam
# filename = "data/saved_contam"
# jan.save_contamination_ngrams(filename)
# jan = Janitor(ngram_n=3)
# jan.load_contamination_ngrams(filename)
# cleaned = " ".join(jan.clean(source))
# for contam in jan.dirt_ngrams:
# assert contam not in cleaned, contam
# if __name__ == "__main__":
# test()
# # print_cpp()
# # test_cpp()
# # benchmark()
import itertools
import logging
import random
from collections import defaultdict
from typing import TYPE_CHECKING, List, Optional, Union
import numpy as np
import torch
import lm_eval.api.metrics
import lm_eval.api.registry
import lm_eval.models
from lm_eval.caching.cache import delete_cache
from lm_eval.evaluator_utils import (
consolidate_results,
get_sample_size,
get_task_list,
prepare_print_tasks,
print_writeout,
run_task_tests,
)
from lm_eval.logging_utils import add_env_info, get_git_commit_hash
from lm_eval.tasks import TaskManager, get_task_dict
from lm_eval.utils import eval_logger, positional_deprecated, simple_parse_args_string
if TYPE_CHECKING:
from lm_eval.api.model import LM
from lm_eval.tasks import Task
@positional_deprecated
def simple_evaluate(
model,
model_args: Optional[Union[str, dict]] = None,
tasks: Optional[List[Union[str, dict, object]]] = None,
num_fewshot: Optional[int] = None,
batch_size: Optional[int] = None,
max_batch_size: Optional[int] = None,
device: Optional[str] = None,
use_cache: Optional[str] = None,
cache_requests: bool = False,
rewrite_requests_cache: bool = False,
delete_requests_cache: bool = False,
limit: Optional[Union[int, float]] = None,
bootstrap_iters: int = 100000,
check_integrity: bool = False,
write_out: bool = False,
log_samples: bool = True,
gen_kwargs: Optional[str] = None,
task_manager: Optional[TaskManager] = None,
verbosity: str = "INFO",
predict_only: bool = False,
random_seed: int = 0,
numpy_random_seed: int = 1234,
torch_random_seed: int = 1234,
):
"""Instantiate and evaluate a model on a list of tasks.
:param model: Union[str, LM]
Name of model or LM object, see lm_eval.models.get_model
:param model_args: Optional[str, dict]
String or dict arguments for each model class, see LM.create_from_arg_string and LM.create_from_arg_object.
Ignored if `model` argument is a LM object.
:param tasks: list[Union[str, dict, Task]]
List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
:param num_fewshot: int
Number of examples in few-shot context
:param batch_size: int or str, optional
Batch size for model
:param max_batch_size: int, optional
Maximal batch size to try with automatic batch size detection
:param device: str, optional
PyTorch device (e.g. "cpu" or "cuda:0") for running models
:param use_cache: str, optional
A path to a sqlite db file for caching model responses. `None` if not caching.
:param cache_requests: bool, optional
Speed up evaluation by caching the building of dataset requests. `None` if not caching.
:param rewrite_requests_cache: bool, optional
Rewrites all of the request cache if set to `True`. `None` if not desired.
:param delete_requests_cache: bool, optional
Deletes all of the request cache if set to `True`. `None` if not desired.
:param limit: int or float, optional
Limit the number of examples per task (only use this for testing), If <1, limit is a percentage of the total number of examples.
:param bootstrap_iters:
Number of iterations for bootstrap statistics
:param check_integrity: bool
Whether to run the relevant part of the test suite for the tasks
:param write_out: bool
If True, write out an example document and model input for checking task integrity
:param log_samples: bool
If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
:param gen_kwargs: str
String arguments for model generation
Ignored for all tasks with loglikelihood output_type
:param predict_only: bool
If true only model outputs will be generated and returned. Metrics will not be evaluated
:param random_seed: int
Random seed for python's random module. If set to None, the seed will not be set.
:param numpy_random_seed: int
Random seed for numpy. If set to None, the seed will not be set.
:param torch_random_seed: int
Random seed for torch. If set to None, the seed will not be set.
:return
Dictionary of results
"""
eval_logger.setLevel(getattr(logging, f"{verbosity}"))
if delete_requests_cache:
eval_logger.info("Deleting requests cache...")
delete_cache()
seed_message = []
if random_seed is not None:
# See https://github.com/EleutherAI/lm-evaluation-harness/pull/1412
seed_message.append(f"Setting random seed to {random_seed}")
random.seed(random_seed)
if numpy_random_seed is not None:
seed_message.append(f"Setting numpy seed to {numpy_random_seed}")
np.random.seed(numpy_random_seed)
if torch_random_seed is not None:
seed_message.append(f"Setting torch manual seed to {torch_random_seed}")
torch.manual_seed(torch_random_seed)
if seed_message:
eval_logger.info(" | ".join(seed_message))
if tasks is None:
tasks = []
if len(tasks) == 0:
raise ValueError(
"No tasks specified, or no tasks found. Please verify the task names."
)
if gen_kwargs is not None:
gen_kwargs = simple_parse_args_string(gen_kwargs)
eval_logger.warning(
"generation_kwargs specified through cli, these settings will update set parameters in yaml tasks. "
"Ensure 'do_sample=True' for non-greedy decoding!"
)
if gen_kwargs == "":
gen_kwargs = None
if isinstance(model, str):
if model_args is None:
model_args = ""
if isinstance(model_args, dict):
lm = lm_eval.api.registry.get_model(model).create_from_arg_obj(
model_args,
{
"batch_size": batch_size,
"max_batch_size": max_batch_size,
"device": device,
},
)
else:
lm = lm_eval.api.registry.get_model(model).create_from_arg_string(
model_args,
{
"batch_size": batch_size,
"max_batch_size": max_batch_size,
"device": device,
},
)
else:
if not isinstance(model, lm_eval.api.model.LM):
raise TypeError
lm = model
if use_cache is not None:
eval_logger.info(f"Using cache at {use_cache + '_rank' + str(lm.rank) + '.db'}")
lm = lm_eval.api.model.CachingLM(
lm,
use_cache
# each rank receives a different cache db.
# necessary to avoid multiple writes to cache at once
+ "_rank"
+ str(lm.rank)
+ ".db",
)
if task_manager is None:
task_manager = TaskManager(verbosity)
eval_logger.info(
"get_task_dict has been updated to accept an optional argument, `task_manager`"
"Read more here:https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/interface.md#external-library-usage"
)
task_dict = get_task_dict(tasks, task_manager)
for task_name in task_dict.keys():
task_obj = task_dict[task_name]
if isinstance(task_obj, tuple):
_, task_obj = task_obj
if task_obj is None:
continue
if task_obj.get_config("output_type") == "generate_until":
if gen_kwargs is not None:
task_obj.set_config(
key="generation_kwargs", value=gen_kwargs, update=True
)
if predict_only:
log_samples = True
eval_logger.info(
f"Processing {task_name} in output-only mode. Metrics will not be calculated!"
)
# we have to change the class properties post-hoc. This is pretty hacky.
task_obj.override_metric(metric_name="bypass")
if num_fewshot is not None:
if (default_num_fewshot := task_obj.get_config("num_fewshot")) == 0:
eval_logger.info(
f"num_fewshot has been set to 0 for {task_name} in its config. Manual configuration will be ignored."
)
else:
eval_logger.warning(
f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {num_fewshot}"
)
task_obj.set_config(key="num_fewshot", value=num_fewshot)
if check_integrity:
run_task_tests(task_list=tasks)
results = evaluate(
lm=lm,
task_dict=task_dict,
limit=limit,
cache_requests=cache_requests,
rewrite_requests_cache=rewrite_requests_cache,
bootstrap_iters=bootstrap_iters,
write_out=write_out,
log_samples=log_samples,
verbosity=verbosity,
)
if lm.rank == 0:
if isinstance(model, str):
model_name = model
elif hasattr(model, "config") and hasattr(model.config, "_name_or_path"):
model_name = model.config._name_or_path
else:
model_name = type(model).__name__
# add info about the model and few shot config
results["config"] = {
"model": model_name,
"model_args": model_args,
"batch_size": batch_size,
"batch_sizes": (
list(lm.batch_sizes.values()) if hasattr(lm, "batch_sizes") else []
),
"device": device,
"use_cache": use_cache,
"limit": limit,
"bootstrap_iters": bootstrap_iters,
"gen_kwargs": gen_kwargs,
}
#results["git_hash"] = get_git_commit_hash()
#add_env_info(results) # additional environment info to results
return results
else:
return None
@positional_deprecated
def evaluate(
lm: "LM",
task_dict,
limit: Optional[int] = None,
cache_requests: bool = False,
rewrite_requests_cache: bool = False,
bootstrap_iters: Optional[int] = 100000,
write_out: bool = False,
log_samples: bool = True,
verbosity: str = "INFO",
):
"""Instantiate and evaluate a model on a list of tasks.
:param lm: obj
Language Model
:param task_dict: dict[str, Task]
Dictionary of tasks. Tasks will be taken to have name type(task).config.task .
:param limit: int, optional
Limit the number of examples per task (only use this for testing)
:param bootstrap_iters:
Number of iterations for bootstrap statistics
:param write_out: bool
If True, write out an example document and model input for checking task integrity
:param log_samples: bool
If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
:return
Dictionary of results
"""
eval_logger.setLevel(getattr(logging, f"{verbosity}"))
# tracks all Instances/requests a model must generate output on.
requests = defaultdict(list)
# stores the amount to pad out reqs per req. type so that
# number of fwd passes per distributed rank is equal
padding_requests = defaultdict(int)
# get lists of group hierarchy and each type of request
task_hierarchy, eval_tasks = get_task_list(task_dict)
if not log_samples:
if not all(
"bypass" not in getattr(task_output.task, "_metric_fn_list", {}).keys()
for task_output in eval_tasks
):
raise ValueError("log_samples must be True for 'bypass' metric-only tasks")
for task_output in eval_tasks:
task: Task = task_output.task
limit = get_sample_size(task, limit)
task.build_all_requests(
limit=limit,
rank=lm.rank,
world_size=lm.world_size,
cache_requests=cache_requests,
rewrite_requests_cache=rewrite_requests_cache,
)
eval_logger.debug(
f"Task: {task_output.task_name}; number of requests on this rank: {len(task.instances)}"
)
if write_out:
print_writeout(task)
# aggregate Instances by LM method requested to get output.
for instance in task.instances:
reqtype = instance.request_type
requests[reqtype].append(instance)
if lm.world_size > 1:
instances_rnk = torch.tensor(len(task._instances), device=lm.device)
gathered_item = (
lm.accelerator.gather(instances_rnk).cpu().detach().numpy().tolist()
)
# "multiple_choice" task types dispatch (several) "loglikelihood" request types
reqtype = (
"loglikelihood"
if task.OUTPUT_TYPE == "multiple_choice"
else task.OUTPUT_TYPE
)
# compute number of pseudo-batches to pad with (FSDP/DDP require even batches among ranks)
numpad = max(gathered_item) - gathered_item[lm.rank]
# todo: may not account for padding in cases like SquadV2 which has multiple req types
padding_requests[reqtype] += numpad
### Run LM on inputs, get all outputs ###
# execute each type of request
for reqtype, reqs in requests.items():
eval_logger.info(f"Running {reqtype} requests")
# create `K` copies of each request `req` based off `K = req.repeats`
cloned_reqs = []
for req in reqs:
cloned_reqs.extend([req] * req.repeats)
if (lm.world_size > 1) and (padding_requests[reqtype] > 0):
for _ in range(padding_requests[reqtype]):
cloned_reqs.extend([req] * req.repeats)
# run requests through model
resps = getattr(lm, reqtype)(cloned_reqs)
# put responses from model into a list of length K for each request.
for x, req in zip(resps, cloned_reqs):
req.resps.append(x)
if lm.world_size > 1:
lm.accelerator.wait_for_everyone()
RANK = lm.rank
WORLD_SIZE = lm.world_size
### Postprocess outputs ###
# TODO: del model here, maybe (idea: allow user to specify device of e.g. reward model separately)
for task_output in eval_tasks:
task = task_output.task
task.apply_filters()
### Collect values of metrics on all datapoints ###
# # unpack results and sort back in order and return control to Task
# TODO: make it possible to use a different metric per filter
# Pre-process task.instances to group by doc_id
instances_by_doc_id = defaultdict(list)
for instance in task.instances:
instances_by_doc_id[instance.doc_id].append(instance)
# Sort instances within each group
for instances in instances_by_doc_id.values():
instances.sort(key=lambda x: x.idx)
# iterate over different filters used
for filter_key in task.instances[0].filtered_resps.keys():
doc_iterator = task.doc_iterator(
rank=RANK, limit=limit, world_size=WORLD_SIZE
)
for doc_id, doc in doc_iterator:
requests = instances_by_doc_id[doc_id]
metrics = task.process_results(
doc, [req.filtered_resps[filter_key] for req in requests]
)
if log_samples:
target = task.doc_to_target(doc)
example = {
"doc_id": doc_id,
"doc": doc,
"target": target,
"arguments": [req.args for req in requests],
"resps": [req.resps for req in requests],
"filtered_resps": [
req.filtered_resps[filter_key] for req in requests
],
}
example.update(metrics)
task_output.logged_samples.append(example)
for metric, value in metrics.items():
task_output.sample_metrics[(metric, filter_key)].append(value)
if WORLD_SIZE > 1:
# if multigpu, then gather data across all ranks to rank 0
# first gather logged samples across all ranks
for task_output in eval_tasks:
if log_samples:
# for task_name, task_samples in list(samples.items()):
full_samples = [None] * WORLD_SIZE if RANK == 0 else None
torch.distributed.gather_object(
obj=task_output.logged_samples,
object_gather_list=full_samples,
dst=0,
)
if RANK == 0:
task_output.logged_samples = list(
itertools.chain.from_iterable(full_samples)
)
# then collect metrics across all ranks
for metrics in task_output.sample_metrics:
metric_list = [None] * WORLD_SIZE if RANK == 0 else None
torch.distributed.gather_object(
obj=task_output.sample_metrics[metrics],
object_gather_list=metric_list,
dst=0,
)
if RANK == 0:
task_output.sample_metrics[metrics] = list(
itertools.chain.from_iterable(metric_list)
)
if RANK == 0:
### Aggregate results over all datapoints ###
# aggregate results ; run bootstrap CIs
for task_output in eval_tasks:
task_output.calculate_aggregate_metric(bootstrap_iters=bootstrap_iters)
results, samples, configs, versions, num_fewshot = consolidate_results(
eval_tasks
)
### Calculate group metrics ###
if bool(results):
for group, task_list in reversed(task_hierarchy.items()):
if len(task_list) == 0:
# task_hierarchy entries are either
# `group_name: [subtask1, subtask2, ...]`
# or `task_name: []`.
# we only want to operate on groups here.
continue
metric_list = list(
{
key
for task in task_list
for key in results[task].keys()
if "_stderr" not in key and key not in ["alias", "samples"]
}
)
for metric in metric_list:
stderr = "_stderr,".join(metric.split(","))
# gather metrics, sizes, and stderrs from subtasks
metrics = [
results[task][metric]
for task in task_list
if metric in results[task]
] # TODO: copy?
stderrs = [
results[task][stderr]
for task in task_list
if stderr in results[task]
]
sizes = [
results[task]["samples"]
for task in task_list
if metric in results[task]
]
# compute group's pooled metric and stderr
results[group][
metric
] = lm_eval.api.metrics.aggregate_subtask_metrics(metrics, sizes)
# TODO: calculate grouped metric using aggregation fn
if "N/A" in stderrs:
results[group][stderr] = "N/A"
else:
results[group][
stderr
] = lm_eval.api.metrics.pooled_sample_stderr(stderrs, sizes)
# TODO: allow GroupConfigs to choose which variance formula is used, for back-compatibility
# To use the old (likely incorrect) variance formula, comment out the above and uncomment this line:
# results[group][stderr] = lm_eval.api.metrics.combined_sample_stderr(stderrs, sizes, metrics=metrics)
results[group]["samples"] = sum(sizes)
results_agg = defaultdict(dict)
groups_agg = defaultdict(dict)
all_tasks_list = list(task_hierarchy.keys())
while True:
add_tasks_list = list(k for k in results_agg.keys())
left_tasks_list = sorted(list(set(all_tasks_list) - set(add_tasks_list)))
if len(left_tasks_list) == 0:
break
_task_hierarchy = {
k: v for k, v in task_hierarchy.items() if k in left_tasks_list
}
_results_agg, _groups_agg = prepare_print_tasks(_task_hierarchy, results)
results_agg = {**results_agg, **_results_agg}
groups_agg = {**groups_agg, **_groups_agg}
for group_name, task_list in task_hierarchy.items():
if task_list:
num_fewshot[group_name] = num_fewshot[
task_list[0]
] # TODO: validate this
results_dict = {
"results": dict(results_agg.items()),
**({"groups": dict(groups_agg.items())} if bool(groups_agg) else {}),
"group_subtasks": dict(reversed(task_hierarchy.items())),
"configs": dict(sorted(configs.items())),
"versions": dict(sorted(versions.items())),
"n-shot": dict(sorted(num_fewshot.items())),
}
if log_samples:
results_dict["samples"] = dict(samples)
return results_dict
else:
return None
def request_caching_arg_to_dict(cache_requests: str) -> dict:
request_caching_args = {
"cache_requests": cache_requests in {"true", "refresh"},
"rewrite_requests_cache": cache_requests == "refresh",
"delete_requests_cache": cache_requests == "delete",
}
return request_caching_args
import collections
import math
import pathlib
import sys
from typing import Dict, List, Optional, Tuple, Union
from lm_eval.api import metrics
from lm_eval.utils import eval_logger, positional_deprecated
class TaskOutput:
"""
Wrapper class for Task outputs.It contains various attributes and methods to manage and calculate metrics for the task.
Attributes:
task (object): The task object.
task_name (str): The name of the task.
task_config (dict): The configuration of the task.
version (str): The version of the task.
group_name (str): The name of the task group.
n_shot (int): The number of shots for the task.
task_alias (str): The alias of the task.
group_alias (str): The alias of the task group.
is_group (bool): Indicates if the task is a group.
logged_samples (list): The list of logged samples.
sample_len (int): The length of the samples.
sample_metrics (defaultdict): The dictionary of samples' metrics.
agg_metrics (defaultdict): The dictionary of aggregate metrics.
Methods:
from_taskdict(cls, task_name: str, task):
Creates a TaskOutput instance from a task dictionary.
calculate_aggregate_metric(bootstrap_iters=100000) -> None:
Calculates the aggregate metrics for the task.
"""
def __init__(
self,
task=None,
task_name=None,
task_config=None,
version=None,
group_name=None,
n_shot=None,
task_alias=None,
group_alias=None,
is_group=None,
):
self.task = task
self.task_config = task_config
self.task_name = task_name
self.group_name = group_name
self.version = version
self.n_shot = n_shot
self.task_alias = task_alias
self.group_alias = group_alias
self.is_group = is_group
self.logged_samples = []
self.sample_len = None
self.sample_metrics = collections.defaultdict(list)
self.agg_metrics = collections.defaultdict(list)
@classmethod
def from_taskdict(cls, task_name: str, task):
if isinstance(task, tuple):
group_name, task = task
else:
group_name = None
if not task:
# these gets filtered out in get_task_list
# once they are added to group hierarchy
is_group = True
return cls(
task=task, task_name=task_name, is_group=is_group, group_name=group_name
)
version = task.VERSION
task_config = dict(task.dump_config())
if (n_shot := task_config.get("num_fewshot")) == 0:
n_shot = task_config.get("metadata", {}).get("num_fewshot", 0)
task_alias = task_config.get("alias")
group_alias = task_config.get("group_alias")
return cls(
task=task,
task_name=task_name,
task_config=task_config,
group_name=group_name,
version=version,
n_shot=n_shot,
task_alias=task_alias,
group_alias=group_alias,
)
def calculate_aggregate_metric(self, bootstrap_iters=100000) -> None:
for (metric, filter_key), items in self.sample_metrics.items():
agg_fn = self.task.aggregation()[metric]
metric_key = f"{metric},{filter_key}"
self.agg_metrics[metric_key] = agg_fn(items)
self.sample_len = len(items) # TODO: same sample size for each metric?
if bootstrap_iters:
stderr_fn = metrics.stderr_for_metric(
metric=agg_fn,
bootstrap_iters=min(bootstrap_iters, 100)
if metric in ["bleu", "chrf", "ter"]
else bootstrap_iters,
)
self.agg_metrics[f"{metric}_stderr,{filter_key}"] = (
stderr_fn(items) if (stderr_fn and len(items) > 1) else "N/A"
)
def __repr__(self):
return (
f"TaskOutput(task_name={self.task_name}, "
f"group_name={self.group_name}, "
f"version={self.version},"
f"n_shot={self.n_shot}"
f"task_alias={self.task_alias}, group_alias={self.group_alias})"
)
def get_task_list(task_dict: dict) -> Tuple[Dict[str, list], List[TaskOutput]]:
task_hierarchy = collections.defaultdict(list)
outputs = list(TaskOutput.from_taskdict(x, y) for x, y in task_dict.items())
for task_output in outputs:
if group_name := task_output.group_name:
task_hierarchy[group_name].append(task_output.task_name)
else:
task_hierarchy[task_output.task_name] = []
# returns task_hierarchy tracking which groups contain which subtasks,
# and a list of TaskOutput classes for each non-group subtask
return task_hierarchy, [x for x in outputs if x.task]
def print_writeout(task) -> None:
for inst in task.instances:
# print the prompt for the first few documents
if inst.doc_id < 1:
eval_logger.info(
f"Task: {task}; document {inst.doc_id}; context prompt (starting on next line):\
\n{inst.args[0]}\n(end of prompt on previous line)\ntarget string or answer choice index (starting on next line):\n{task.doc_to_target(inst.doc)}\n(end of target on previous line)"
)
eval_logger.info(f"Request: {str(inst)}")
def get_sample_size(task, limit: Optional[int]) -> Union[int, None]:
if limit is not None:
limit = (
int(math.ceil(len(task.eval_docs) * limit)) if limit < 1.0 else int(limit)
)
return limit
def prepare_print_tasks(
task_hierarchy: dict, results: dict, tab=0
) -> Tuple[dict, dict]:
"""
@param task_hierarchy: Dictionary representing the group hierarchy of tasks. Each key is a group name and its
value is a list of task names.
@param results: Dictionary containing the results of each task. Each key is a
group name and its value is a dictionary of task results.
@param tab: The indentation level for printing the task
hierarchy. Default is 0.
@return: A tuple of two dictionaries: results_agg and groups_agg. results_agg contains
aggregated results for each task, and groups_agg contains aggregated results for each group.
Prepares the task hierarchy and aggregates the results for each task and group recursively for printing.
"""
results_agg = collections.defaultdict(dict)
groups_agg = collections.defaultdict(dict)
(group_name, task_list), *_ = task_hierarchy.items()
task_list = sorted(task_list)
results_agg[group_name] = results[group_name].copy()
# results_agg[group_name]["tab"] = tab
if "samples" in results_agg[group_name]:
results_agg[group_name].pop("samples")
tab_string = " " * tab + "- " if tab > 0 else ""
if "alias" in results_agg[group_name]:
results_agg[group_name]["alias"] = tab_string + results_agg[group_name]["alias"]
else:
results_agg[group_name]["alias"] = tab_string + group_name
if len(task_list) > 0:
groups_agg[group_name] = results[group_name].copy()
# groups_agg[group_name]["tab"] = tab
if "samples" in groups_agg[group_name]:
groups_agg[group_name].pop("samples")
if "alias" in groups_agg[group_name]:
groups_agg[group_name]["alias"] = (
tab_string + groups_agg[group_name]["alias"]
)
else:
groups_agg[group_name]["alias"] = tab_string + group_name
for task_name in task_list:
if task_name in task_hierarchy:
_task_hierarchy = {
**{task_name: task_hierarchy[task_name]},
**task_hierarchy,
}
else:
_task_hierarchy = {
**{task_name: []},
**task_hierarchy,
}
_results_agg, _groups_agg = prepare_print_tasks(
_task_hierarchy, results, tab + 1
)
results_agg = {**results_agg, **_results_agg}
groups_agg = {**groups_agg, **_groups_agg}
return results_agg, groups_agg
def consolidate_results(
eval_tasks: List[TaskOutput],
) -> Tuple[dict, dict, dict, dict, dict]:
"""
@param eval_tasks: list(TaskOutput).
@return: A tuple containing the consolidated results, samples, configs, versions, and num_fewshot.
Consolidates the results of multiple evaluation tasks into a single structure.
The method iterates over each evaluation instance and extracts relevant information to create the consolidated
results structure. The consolidated results structure has the following properties:
- results: A defaultdict with task names as keys and dictionaries as values. Each dictionary contains
metric/filter pairs as keys and corresponding metric values as values. The "alias" key is used to store task
aliases specified in the task configuration.
- samples: A defaultdict with task names as keys and lists of log samples as values.
- configs: A defaultdict with task names as keys and task configurations as values.
- versions: A defaultdict with task names as keys and task versions as values.
- num_fewshot: A defaultdict with task names as keys and number of few-shot samples as values.
The method then returns the consolidated results, samples, configs, versions, and num_fewshot as a tuple.
"""
# stores the final result for each task, for each metric/filter pair.
results = collections.defaultdict(dict)
# logs info about each document evaluated.
samples = collections.defaultdict(list)
# store num-fewshot value per task
num_fewshot = collections.defaultdict(int)
# Tracks the YAML configs of all chosen task
configs = collections.defaultdict(dict)
# Tracks each task's version.
versions = collections.defaultdict(dict)
for task_output in eval_tasks:
if "task_alias" in (task_config := task_output.task_config):
results[task_output.task_name]["alias"] = task_config["task_alias"]
if group_alias := task_output.group_alias:
if group_alias not in results and (group_name := task_output.group_name):
results[group_name]["alias"] = group_alias
num_fewshot[task_output.task_name] = task_output.n_shot
configs[task_output.task_name] = task_output.task_config
versions[task_output.task_name] = task_output.version
samples[task_output.task_name] = task_output.logged_samples
for (metric, filter_key), items in task_output.sample_metrics.items():
metric_key = f"{metric},{filter_key}"
results[task_output.task_name][metric_key] = task_output.agg_metrics[
metric_key
]
results[task_output.task_name]["samples"] = task_output.sample_len
results[task_output.task_name][
f"{metric}_stderr,{filter_key}"
] = task_output.agg_metrics[f"{metric}_stderr,{filter_key}"]
return results, samples, configs, versions, num_fewshot
@positional_deprecated
def find_test_root(start_path: pathlib.Path) -> pathlib.Path:
"""
Search upward in the directory tree to a maximum of three layers
to find and return the package root (containing the 'tests' folder)
"""
cur_path = start_path.resolve()
max_layers = 3
for _ in range(max_layers):
if (cur_path / "tests" / "test_version_stable.py").exists():
return cur_path
else:
cur_path = cur_path.parent.resolve()
raise FileNotFoundError(
f"Unable to find package root within {max_layers} upwards" + f"of {start_path}"
)
@positional_deprecated
def run_task_tests(task_list: List[str]):
"""
Find the package root and run the tests for the given tasks
"""
import pytest
package_root = find_test_root(start_path=pathlib.Path(__file__))
task_string = " or ".join(task_list)
args = [
f"{package_root}/tests/test_version_stable.py",
f"--rootdir={package_root}",
"-k",
f"{task_string}",
]
sys.path.append(str(package_root))
pytest_return_val = pytest.main(args)
if pytest_return_val:
raise ValueError(
f"Not all tests for the specified tasks ({task_list}) ran successfully! Error code: {pytest_return_val}"
)
from functools import partial
from typing import List, Union
from lm_eval.api.filter import FilterEnsemble
from . import extraction, selection, transformation
FILTER_REGISTRY = {
"take_first": selection.TakeFirstFilter,
"regex": extraction.RegexFilter,
"majority_vote": selection.MajorityVoteFilter,
"take_first_k": selection.TakeKFilter,
"remove_whitespace": extraction.WhitespaceFilter,
"lowercase": transformation.LowercaseFilter,
"uppercase": transformation.UppercaseFilter,
"map": transformation.MapFilter,
"multi_choice_regex": extraction.MultiChoiceRegexFilter,
# TODO: implement this filter. either it should take in an arbitrary "scoring"/reward function
# that takes an input and returns a scalar and then should select the max reward,
# or should implement different filters for different ways of handling a reward model's inference.
# "arg_max": selection.ArgMaxFilter,
}
def get_filter(filter_name: str) -> Union[type, str]:
if filter_name in FILTER_REGISTRY:
return FILTER_REGISTRY[filter_name]
else:
return filter_name
def build_filter_ensemble(
filter_name: str, components: List[List[str]]
) -> FilterEnsemble:
"""
Create a filtering pipeline.
"""
filters = []
for function, kwargs in components:
if kwargs is None:
kwargs = {}
# create a filter given its name in the registry
f = partial(get_filter(function), **kwargs)
# add the filter as a pipeline step
filters.append(f)
return FilterEnsemble(name=filter_name, filters=filters)
from lm_eval.api.filter import Filter
class DecontaminationFilter(Filter):
"""
A filter which evaluates
"""
name = "track_decontamination"
def __init__(self, path) -> None:
"""
TODO: make sure only ever run one time on the train set (should this be cached as a class var? keyed by value for "path").
should further cache result on a given (task_name, doc_id)
"""
self._decontam_results = None
def apply(self, resps, docs) -> None:
"""
Return {"no_contamination", "only_contamination"} keys for the 2 different subsets
"""
pass
import re
import sys
import unicodedata
from lm_eval.api.filter import Filter
class RegexFilter(Filter):
""" """
def __init__(
self,
regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
group_select=0,
fallback: str = "[invalid]",
) -> None:
"""
pass a string `regex` to run `re.compile(r"regex")` on.
`fallback` defines the output returned if no matches for the regex are located.
"""
self.regex_pattern = regex_pattern
self.regex = re.compile(regex_pattern)
self.group_select = group_select
self.fallback = fallback
def apply(self, resps, docs):
# here, we assume we have a list, in which each element is
# a list of model responses for some particular input/target pair.
# so we process each of these (same input/target response sets)
# independently (and keep them a list.)
def filter_set(inst):
filtered = []
for resp in inst:
match = self.regex.findall(resp)
if match:
match = match[self.group_select]
if isinstance(match, tuple):
match = [m for m in match if m][0]
match = match.strip()
else:
match = self.fallback
filtered.append(match)
return filtered
# print(resps)
filtered_resps = list(map(lambda x: filter_set(x), resps))
# print(filtered_resps)
return filtered_resps
class WhitespaceFilter(Filter):
""" """
def __init__(self) -> None:
pass
def apply(self, resps, docs):
def filter_set(inst):
filtered_resp = []
for resp in inst:
if resp.startswith(" "):
resp = resp[1:]
filtered_resp.append(resp)
return filtered_resp
filtered_resps = [filter_set(resp) for resp in resps]
return filtered_resps
class MultiChoiceRegexFilter(RegexFilter):
"""
A filter used to extract a model's answer on multiple choice questions with
letter answers. assumes each document has a "choices" field
containing the list of answer choices and that the answer label symbols
are of the form (A), (B), (C), ... or A, B, C.
"""
def __init__(
self,
regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
group_select=0,
fallback: str = "[invalid]",
ignore_case=False,
ignore_punctuation=False,
regexes_to_ignore=None,
) -> None:
"""
regex_pattern: The basic regex pattern to use. If fails to match, we will use the customized match procedure
- step 1 : We parse the choices between ([A-Z])s then try to find these choices in the response.
- step 2 : We parse the choice with regex :[\s]*([A-?]), where ? varies by number of choices.
group_select: Selects the (group_select)th match from the findall result.
ignore_case: Ignores the case during step 1 matching
ignore_punctuation: Remove the punctuation during step 1 matching
regexes_to_ignore: Remove these regexes during step 1 matching
"""
super().__init__(regex_pattern, group_select, fallback)
self.ignore_case = ignore_case
self.ignore_punctuation = ignore_punctuation
self.regexes_to_ignore = regexes_to_ignore
def apply(self, resps, docs):
# here, we assume we have a list, in which each element is
# a list of model responses for some particular input/target pair.
# so we process each of these (same input/target response sets)
# independently (and keep them a list.)
def find_match(regex, resp, convert_dict={}):
match = regex.findall(resp)
if match:
match = match[self.group_select]
if isinstance(match, tuple):
match = [m for m in match if m][0]
match = match.strip()
if match and match in convert_dict:
match = convert_dict[match]
return match
punct_tbl = dict.fromkeys(
i
for i in range(sys.maxunicode)
if unicodedata.category(chr(i)).startswith("P")
)
def filter_ignores(st):
if self.regexes_to_ignore is not None:
for s in self.regexes_to_ignore:
st = re.sub(s, "", st)
if self.ignore_case:
st = st.lower()
if self.ignore_punctuation:
# https://stackoverflow.com/a/266162
st = st.translate(punct_tbl)
return st
filtered_resps = []
for r, doc in zip(resps, docs):
fallback_regexes = []
choice_to_alpha = {}
next_alpha = "A"
without_paren_fallback_regexes = []
without_paren_to_target = {}
choices = doc["choices"]
for c in choices:
m = filter_ignores(c.strip())
fallback_regexes.append(f"{re.escape(m)}")
choice_to_alpha[m] = f"({next_alpha})"
without_paren_fallback_regexes.append(next_alpha)
without_paren_to_target[next_alpha] = f"({next_alpha})"
next_alpha = chr(ord(next_alpha) + 1)
fallback_regex = re.compile("|".join(fallback_regexes))
without_paren_fallback_regex = "|".join(without_paren_fallback_regexes)
without_paren_fallback_regex = re.compile(
f":[\s]*({without_paren_fallback_regex})"
)
filtered = []
for resp in r:
match = find_match(self.regex, resp)
if not match:
match = find_match(
fallback_regex, filter_ignores(resp), choice_to_alpha
)
if not match:
match = find_match(
without_paren_fallback_regex, resp, without_paren_to_target
)
if not match:
match = self.fallback
filtered.append(match)
filtered_resps.append(filtered)
return filtered_resps
from collections import Counter
from lm_eval.api.filter import Filter
class TakeFirstFilter(Filter):
def __init__(self) -> None:
"""
Can define custom behavior here, if an individual instantiation of a Filter class should have state.
"""
def apply(self, resps, docs):
"""
Assuming each entry of `resps` is a list of model responses, we discard all but the first response.
"""
return map(lambda r: r[0], resps)
class TakeKFilter(Filter):
def __init__(self, **kwargs) -> None:
self.k = kwargs.pop("k")
super().__init__(**kwargs)
def apply(self, resps, docs):
# need resp to be subscriptable to check below
resps = list(resps)
# check we have at least k responses per doc, else we can't take the first k
assert (
len(resps[0]) >= self.k
), f"Need at least {self.k} responses per doc to take first {self.k}, but got {len(resps[0])} only! Please increase TaskConfig.repeats ."
return map(lambda r: r[: self.k], resps)
class MajorityVoteFilter(Filter):
def __init__(self) -> None:
"""
Can define custom behavior here, if an individual instantiation of a Filter class should have state.
"""
def apply(self, resps, docs):
"""
Each entry of `resps` is a list of model responses.
We select the response that occurs most frequently in each entry of `resps`.
"""
def select_majority(resp):
counts = Counter(resp)
vote = counts.most_common(1)[0][0]
return vote
return map(lambda r: [select_majority(r)], resps)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment