import abc
import hashlib
import json
import os
import random

import numpy as np
from sqlitedict import SqliteDict

from lm_eval.metrics import mean


class LM(abc.ABC):
    """Abstract language-model interface used by tasks."""

    def __init__(self):
        # A hook with no backing cache; CachingLM replaces it via set_cache_hook.
        self.cache_hook = CacheHook(None)

    @abc.abstractmethod
    def loglikelihood(self, requests):
        """Compute log-likelihood of generating a continuation from a context.
        Downstream tasks should attempt to use loglikelihood instead of other
        LM calls whenever possible.

        :param requests: list
            A list of pairs (context, continuation)
            context: str
                Context string. Implementations of LM must be able to handle an
                empty context string.
            continuation: str
                The continuation over which log likelihood will be calculated. If
                there is a word boundary, the space should be in the continuation.
                For example, context="hello" continuation=" world" is correct.
        :return: list
            A list of pairs (logprob, isgreedy)
            logprob: float
                The log probability of `continuation`
            isgreedy:
                Whether `continuation` would be generated by greedy sampling from `context`
        """
        pass

    # TODO: Add an optional max length
    @abc.abstractmethod
    def greedy_until(self, requests):
        """Generate greedily until a stopping sequence

        :param requests: list
            A list of pairs (context, until)
            context: str
                Context string
            until: [str]
                The string sequences to generate until. These string sequences
                may each span across multiple tokens, or may be part of one token.
        :return: list
            A list of strings continuation
            continuation: str
                The generated continuation.
        """
        pass

    @classmethod
    def create_from_arg_string(cls, arg_string):
        """Constructor method, in case models need additional arguments
        e.g. OpenAI API engine, paths for loading, other params

        :param arg_string: str
            Left up to individual model class to handle

        The base implementation ignores `arg_string` and calls `cls()`.
        """
        return cls()

    def set_cache_hook(self, cache_hook):
        """Replace the hook this LM can use to push partial results into a cache."""
        self.cache_hook = cache_hook


class Task(abc.ABC):
    """A task represents an entire benchmark including its dataset, problems,
    answers, and evaluation methods. See BoolQ for a simple example implementation

    A `doc` can be any python object which represents one instance of evaluation.
    This is usually a dictionary e.g.
        {"question": ..., "answer": ...} or a tuple (question, answer)
    """

    def __init__(self):
        self.download()
        # Lazily-populated caches used by fewshot_examples / fewshot_context.
        self._training_docs = None
        self._fewshot_docs = None

    def download(self):
        """Downloads the task dataset if necessary"""
        pass

    @abc.abstractmethod
    def has_training_docs(self):
        """Whether the task has a training set"""
        pass

    @abc.abstractmethod
    def has_validation_docs(self):
        """Whether the task has a validation set"""
        pass

    @abc.abstractmethod
    def has_test_docs(self):
        """Whether the task has a test set"""
        pass

    def training_docs(self):
        """
        :return: Iterable[obj]
            An iterable of any object, that doc_to_text can handle
        """
        return []

    def validation_docs(self):
        """
        :return: Iterable[obj]
            An iterable of any object, that doc_to_text can handle
        """
        return []

    def test_docs(self):
        """
        :return: Iterable[obj]
            An iterable of any object, that doc_to_text can handle
        """
        return []

    def fewshot_examples(self, k, rnd):
        """Sample `k` distinct training docs using the random generator `rnd`.

        The training docs are materialized into a list on first use and reused
        on later calls.
        """
        if self._training_docs is None:
            self._training_docs = list(self.training_docs())
        return rnd.sample(self._training_docs, k)

    @abc.abstractmethod
    def doc_to_text(self, doc):
        """Render the prompt part of `doc` as a string."""
        pass

    @abc.abstractmethod
    def doc_to_target(self, doc):
        """Render the expected answer for `doc`; it is appended directly after doc_to_text(doc)."""
        pass

    @abc.abstractmethod
    def construct_requests(self, doc, ctx):
        """ Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """
        pass

    @abc.abstractmethod
    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        """
        pass

    @abc.abstractmethod
    def aggregation(self):
        """
        :returns: {str: [metric_score] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metric scores
        """
        pass

    @abc.abstractmethod
    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        pass

    def fewshot_description(self):
        """Natural-language task description; empty by default."""
        return ""

    def fewshot_context(self, doc, num_fewshot, provide_description, rnd):
        """Build the full prompt for `doc`.

        The prompt is: the task description (only if `provide_description` is
        truthy and the description is non-empty) followed by "\\n===\\n\\n", then
        `num_fewshot` labeled examples separated by blank lines, then
        doc_to_text(doc).

        :param doc: the document being evaluated.
        :param num_fewshot: int, number of labeled examples to include.
        :param provide_description: whether to prepend fewshot_description().
        :param rnd: random generator used for sampling examples.
        :return: str, the assembled prompt.
        """
        raw_description = self.fewshot_description()
        description = (raw_description + "\n===\n\n") if provide_description and raw_description else ""

        if num_fewshot == 0:
            labeled_examples = ""
        else:
            # for sets with no training docs, draw from other set *but ensure no overlap with current doc*
            if self.has_training_docs():
                fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
            else:
                if self._fewshot_docs is None:
                    # has_validation_docs must be *called*: the bound method object
                    # itself is always truthy, which would always pick validation_docs.
                    self._fewshot_docs = list(
                        self.validation_docs() if self.has_validation_docs() else self.test_docs()
                    )

                # Sample one extra so that removing `doc` still leaves num_fewshot examples.
                fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)

                # get rid of the doc that's the one we're evaluating, if it's in the fewshot
                fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]

            labeled_examples = "\n\n".join(
                [self.doc_to_text(ex) + self.doc_to_target(ex) for ex in fewshotex]
            ) + "\n\n"

        example = self.doc_to_text(doc)
        return description + labeled_examples + example


class MultipleChoiceTask(Task):
    """Task whose docs carry a list of `choices` and the index `gold` of the correct one."""

    def doc_to_target(self, doc):
        return " " + doc['choices'][doc['gold']]

    def construct_requests(self, doc, ctx):
        """Return one loglikelihood request per choice (the logprob element, index 0)."""
        lls = [
            rf.loglikelihood(ctx, " {}".format(choice))[0]
            for choice in doc['choices']
        ]

        return lls

    def process_results(self, doc, results):
        """Score a multiple-choice doc.

        `results` is presumably one log-likelihood per choice, in the order of
        doc["choices"] (matching construct_requests) — confirm against the evaluator.

        acc: 1.0 if the highest-scoring choice is the gold one.
        acc_norm: same, after dividing each score by the choice's character length.
        """
        gold = doc["gold"]
        scores = np.asarray(results, dtype=float)

        acc = 1. if np.argmax(scores) == gold else 0.
        completion_len = np.array([float(len(i)) for i in doc["choices"]])
        acc_norm = 1. if np.argmax(scores / completion_len) == gold else 0.

        return {
            "acc": acc,
            "acc_norm": acc_norm,
        }

    def higher_is_better(self):
        return {
            "acc": True,
            "acc_norm": True,
        }

    def aggregation(self):
        return {
            "acc": mean,
            "acc_norm": mean,
        }


# Number of indexable return values per request type; None means the request
# returns a single value and cannot be indexed or iterated.
req_ret_lens = {
    'loglikelihood': 2,
    'greedy_until': None,
}


def hash_args(attr, args):
    """Return a stable SHA-256 hex key for a (method name, request args) pair."""
    dat = json.dumps([attr] + list(args))
    return hashlib.sha256(dat.encode('utf-8')).hexdigest()


class CacheHook:
    """Lets an LM write partial results into a CachingLM's database.

    With `cachinglm=None`, add_partial is a no-op.
    """

    def __init__(self, cachinglm):
        if cachinglm is None:
            self.dbdict = None
            return

        self.dbdict = cachinglm.dbdict

    def add_partial(self, attr, req, res):
        """Store `res` as the cached result of request `req` for method `attr`."""
        if self.dbdict is None:
            return
        hsh = hash_args(attr, req)
        self.dbdict[hsh] = res


class CachingLM:
    """Wraps an LM so that calls like `caching_lm.loglikelihood(requests)` are
    answered from an on-disk SqliteDict where possible and only the uncached
    requests are forwarded to the wrapped LM.
    """

    def __init__(self, lm, cache_db):
        self.lm = lm
        self.cache_db = cache_db
        # os.makedirs("") raises FileNotFoundError, so only create a directory
        # when the path actually has one (e.g. not for a bare "cache.db").
        cache_dir = os.path.dirname(cache_db)
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)
        self.dbdict = SqliteDict(cache_db, autocommit=True)

        # add hook to lm
        lm.set_cache_hook(self.get_cache_hook())

    def __getattr__(self, attr):
        """Return a caching wrapper around the wrapped LM's method `attr`."""
        def fn(requests):
            res = []
            # (position in `res`, request, cache key) for every cache miss
            missing = []

            # figure out which ones are cached and which ones are new
            for pos, req in enumerate(requests):
                hsh = hash_args(attr, req)
                if hsh in self.dbdict:
                    ob = self.dbdict[hsh]

                    assert ob is not None

                    res.append(ob)
                else:
                    res.append(None)
                    missing.append((pos, req, hsh))

            # actually run the LM
            rem_res = getattr(self.lm, attr)([req for _, req, _ in missing])

            # Fill results back by recorded position (robust even if the LM
            # returns None for some request) and cache the new ones.
            for (pos, _, hsh), r in zip(missing, rem_res):
                res[pos] = r
                self.dbdict[hsh] = r
            self.dbdict.commit()

            return res
        return fn

    def get_cache_hook(self):
        return CacheHook(self)


class Request:
    """A lazily-evaluated LM call of a given `type` with positional `args`.

    `index` selects one element of a multi-valued result (see req_ret_lens).
    Because __eq__ is defined without __hash__, instances are unhashable.
    """

    def __init__(self, type, args, index=None):
        if type not in req_ret_lens:
            raise NotImplementedError('The request type {} is not implemented!'.format(type))

        self.type = type
        self.args = args
        self.index = index

    def __iter__(self):
        if req_ret_lens[self.type] is None:
            raise IndexError('This request type does not return multiple arguments!')
        for i in range(req_ret_lens[self.type]):
            yield Request(self.type, self.args, i)

    def __getitem__(self, i):
        if req_ret_lens[self.type] is None:
            raise IndexError('This request type does not return multiple arguments!')
        return Request(self.type, self.args, i)

    def __eq__(self, other):
        if not isinstance(other, Request):
            return NotImplemented
        return self.type == other.type and self.args == other.args and self.index == other.index

    def __repr__(self):
        return f"Req_{self.type}{self.args}[{self.index}]\n"


class RequestFactory:
    """`rf.<type>(*args)` builds a Request(<type>, args)."""

    def __getattr__(self, attr):
        def fn(*args):
            return Request(attr, args)
        return fn


rf = RequestFactory()