Commit e7c18e53 authored by haileyschoelkopf

refactor Instance dataclass

parent d2a9b759
from dataclasses import dataclass, field
+from typing import Literal

@dataclass
class Instance:
-    request_type: str = None # TODO: make this an enum?
+    request_type: str = Literal["loglikelihood", "loglikelihood_rolling", "greedy_until"]
    doc: dict = None
    arguments: tuple = None
    id_: int = None
@@ -15,6 +16,7 @@ class Instance:
    repeats: str = None

    def __post_init__(self):
        # unpack metadata field
        self.task_name, self.doc_id, self.repeats = self.metadata

    @property
@@ -23,53 +25,3 @@ class Instance:
        Returns (string,) where `string` is the string to calculate loglikelihood over
        """
        return self.arguments if isinstance(self.arguments, tuple) else (self.arguments,)
-# import abc
-# class Instance(abc.ABC):
-#     """
-#     A class used to bind together all necessary information and metadata for
-#     running forward pass of a model on a specific datapoint.
-#     """
-#     # all Instance subclasses have an attribute which is the name of the LM() class function they call to get outputs.
-#     request_type = None
-#     def __init__(self, doc, arguments=None, id_=None, metadata=("", None, None)):
-#         self.doc = doc  # store the document which we're using. this is a dict
-#         self.arguments = arguments
-#         # need: task name, doc idx, num. repeats
-#         self.task_name, self.doc_id, self.repeats = metadata
-#         # id_ = idx within a doc's requests
-#         self.id_ = id_
-#         # handle repeats internally. should be able to run K times on exact same input/output pair
-#         # self.repeats = repeats
-#         # list containing the returns from each call of the model on this particular set of arguments.
-#         self.resps = []
-#         # filtered_resps should end up a dict, with a different key for each set of filters to apply. calculate results against each key in filtered_resps
-#         self.filtered_resps = {}
-#     # TODO: add more info as needed for detailed logging
-#     def __repr__(self):
-#         return f"Req_{self.request_type}{self.args}{self.id_}"

-@dataclass
-class LoglikelihoodInstance(Instance):
-    request_type: str = "loglikelihood"

-@dataclass
-class RollingLoglikelihoodInstance(Instance):
-    request_type: str = "loglikelihood_rolling"

-@dataclass
-class GenerationInstance(Instance):
-    request_type: str = "greedy_until"
@@ -10,7 +10,7 @@ import datasets
import numpy as np
from lm_eval.api import METRIC_REGISTRY, AGGREGATION_REGISTRY
-from lm_eval.api.instance import LoglikelihoodInstance, RollingLoglikelihoodInstance, GenerationInstance
+from lm_eval.api.instance import Instance
from lm_eval.api.metrics import mean, weighted_perplexity, weighted_mean, bits_per_byte
from lm_eval import utils

@@ -460,7 +460,7 @@ class ConfigurableTask(Task):
    def construct_requests(self, doc, ctx, **kwargs):
        if self.OUTPUT_TYPE == "greedy_until":
-            return GenerationInstance(doc=doc, arguments=(ctx, "\n\n"), id_=0, **kwargs)
+            return Instance(request_type=self.OUTPUT_TYPE, doc=doc, arguments=(ctx, "\n\n"), id_=0, **kwargs)

    def process_results(self, doc, results):

@@ -498,7 +498,8 @@ class MultipleChoiceTask(Task):
    def construct_requests(self, doc, ctx, **kwargs):
-        return [LoglikelihoodInstance(
+        return [Instance(
+            request_type="loglikelihood",
            doc=doc,
            arguments=(ctx, " {}".format(choice)),
            id_=i,

@@ -579,7 +580,7 @@ class PerplexityTask(Task, abc.ABC):
    def construct_requests(self, doc, ctx, **kwargs):
        assert not ctx
-        return RollingLoglikelihoodInstance(doc=doc, arguments=(self.doc_to_target(doc),), id_=0, **kwargs)
+        return Instance(request_type=self.OUTPUT_TYPE, doc=doc, arguments=(self.doc_to_target(doc),), id_=0, **kwargs)
        # req = rf.loglikelihood_rolling(self.doc_to_target(doc))
        # return req
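
The task-side changes all follow the same pattern: `construct_requests` now returns a plain `Instance` whose `request_type` mirrors the task's `OUTPUT_TYPE` (or a literal string such as `"loglikelihood"`), rather than selecting one of the old subclasses. A hedged sketch of a hypothetical task written against the refactored API; the class and its data are not in the repo:

```python
from lm_eval.api.instance import Instance

class ExampleRollingTask:  # stands in for a Task subclass; not a real task in this commit
    OUTPUT_TYPE = "loglikelihood_rolling"

    def doc_to_target(self, doc):
        return doc["text"]

    def construct_requests(self, doc, ctx, **kwargs):
        # Same shape as the PerplexityTask change above: one Instance, typed by OUTPUT_TYPE.
        return Instance(
            request_type=self.OUTPUT_TYPE,
            doc=doc,
            arguments=(self.doc_to_target(doc),),
            id_=0,
            **kwargs,
        )
```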
@@ -18,7 +18,7 @@ Homepage: https://github.com/openai/grade-school-math
"""
import re
from lm_eval.api.task import Task
-from lm_eval.api.instance import GenerationInstance
+from lm_eval.api.instance import Instance
from lm_eval.api.metrics import mean
from lm_eval import utils

@@ -87,7 +87,7 @@ class GradeSchoolMath8K(Task):
        """
        # NOTE: The paper implements "verifiers" that assign a score to multiple
        # solutions and output the highest ranked solution.
-        return GenerationInstance(doc=doc, arguments=(ctx, ["\n"]), id_=0, **kwargs)
+        return Instance(request_type=self.OUTPUT_TYPE, doc=doc, arguments=(ctx, ["\n"]), id_=0, **kwargs)
        # completion = rf.greedy_until(ctx, ["\n"])
        # return completion
@@ -13,7 +13,7 @@ in the broader discourse.
Homepage: https://zenodo.org/record/2630551#.X4Xzn5NKjUI
"""
from lm_eval.api.task import Task
-from lm_eval.api.instance import LoglikelihoodInstance
+from lm_eval.api.instance import Instance
from lm_eval.api.metrics import mean, perplexity

@@ -59,9 +59,7 @@ class LambadaBase(Task):
        return " " + doc["text"].rsplit(" ", 1)[1]

    def construct_requests(self, doc, ctx, **kwargs):
-        return LoglikelihoodInstance(doc=doc, arguments=(ctx, self.doc_to_target(doc)), **kwargs)
-        return ll, is_greedy
+        return Instance(request_type=self.OUTPUT_TYPE, doc=doc, arguments=(ctx, self.doc_to_target(doc)), **kwargs)

    def process_results(self, doc, results):
        # TODO: this ^ is a hack. filters should make it so that we only have one response per request that we score
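
Since `request_type` is now a string on every request (and, per the commented-out docstring removed above, names the `LM` class function used to answer it), downstream runner code can dispatch by attribute lookup instead of checking instance types. A rough, hypothetical sketch, not part of this commit, assuming per-request (non-batched) calls:

```python
def run_instances(lm, instances):
    """Hypothetical dispatch loop: call lm.<request_type>(arguments) for each request."""
    responses = []
    for inst in instances:
        request_fn = getattr(lm, inst.request_type)   # e.g. lm.loglikelihood
        responses.append(request_fn(inst.arguments))  # assumed call signature, for illustration
    return responses
```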