Commit e7c18e53 authored by haileyschoelkopf

refactor Instance dataclass

parent d2a9b759
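This commit collapses the request-specific subclasses (LoglikelihoodInstance, RollingLoglikelihoodInstance, GenerationInstance) into a single Instance dataclass keyed by a request_type field. Below is a minimal usage sketch: the field names follow the diff, while the metadata keyword and every concrete value are assumptions made for illustration.

# Minimal sketch of constructing the unified Instance after this refactor.
# Field names come from the diff below; the metadata keyword and the example
# values are assumptions for illustration, not part of the commit.
from lm_eval.api.instance import Instance

request = Instance(
    request_type="greedy_until",                  # replaces GenerationInstance
    doc={"question": "What is 2 + 2?"},           # hypothetical document dict
    arguments=("Q: What is 2 + 2?\nA:", ["\n"]),  # (context, stop sequences)
    id_=0,
    metadata=("example_task", 0, 1),              # unpacked into task_name, doc_id, repeats
)
assert request.task_name == "example_task" and request.repeats == 1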
 from dataclasses import dataclass, field
+from typing import Literal
 
 
 @dataclass
 class Instance:
-    request_type: str = None # TODO: make this an enum?
+    request_type: str = Literal["loglikelihood", "loglikelihood_rolling", "greedy_until"]
     doc: dict = None
     arguments: tuple = None
     id_: int = None
@@ -15,6 +16,7 @@ class Instance:
     repeats: str = None
 
     def __post_init__(self):
+        # unpack metadata field
        self.task_name, self.doc_id, self.repeats = self.metadata
 
     @property
@@ -23,53 +25,3 @@ class Instance:
         Returns (string,) where `string` is the string to calculate loglikelihood over
         """
         return self.arguments if isinstance(self.arguments, tuple) else (self.arguments,)
-
-# import abc
-# class Instance(abc.ABC):
-#     """
-#     A class used to bind together all necessary information and metadata for
-#     running forward pass of a model on a specific datapoint.
-#     """
-#     # all Instance subclasses have an attribute which is the name of the LM() class function they call to get outputs.
-#     request_type = None
-
-#     def __init__(self, doc, arguments=None, id_=None, metadata=("", None, None)):
-#         self.doc = doc # store the document which we're using. this is a dict
-#         self.arguments = arguments
-#         # need: task name, doc idx, num. repeats
-#         self.task_name, self.doc_id, self.repeats = metadata
-#         # id_ = idx within a doc's requests
-#         self.id_ = id_
-#         # handle repeats internally. should be able to run K times on exact same input/output pair
-#         # self.repeats = repeats
-#         # list containing the returns from each call of the model on this particular set of arguments.
-#         self.resps = []
-#         # filtered_resps should end up a dict, with a different key for each set of filters to apply. calculate results against each key in filtered_resps
-#         self.filtered_resps = {}
-#         #TODO: add more info as needed for detailed logging
-
-#     def __repr__(self):
-#         return f"Req_{self.request_type}{self.args}{self.id_}"
-
-
-@dataclass
-class LoglikelihoodInstance(Instance):
-    request_type: str = "loglikelihood"
-
-
-@dataclass
-class RollingLoglikelihoodInstance(Instance):
-    request_type: str = "loglikelihood_rolling"
-
-
-@dataclass
-class GenerationInstance(Instance):
-    request_type: str = "greedy_until"
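
For reference, the deleted subclasses map one-to-one onto request_type strings on the unified dataclass. A migration sketch follows; the lookup-table name is hypothetical, while the strings are the defaults from the removed classes above.

# Migration sketch only; SUBCLASS_TO_REQUEST_TYPE is a hypothetical helper,
# not something added by this commit.
SUBCLASS_TO_REQUEST_TYPE = {
    "LoglikelihoodInstance": "loglikelihood",
    "RollingLoglikelihoodInstance": "loglikelihood_rolling",
    "GenerationInstance": "greedy_until",
}
# e.g. GenerationInstance(doc=doc, arguments=(ctx, ["\n"]), id_=0)
#  ->  Instance(request_type="greedy_until", doc=doc, arguments=(ctx, ["\n"]), id_=0)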
@@ -10,7 +10,7 @@ import datasets
 import numpy as np
 
 from lm_eval.api import METRIC_REGISTRY, AGGREGATION_REGISTRY
-from lm_eval.api.instance import LoglikelihoodInstance, RollingLoglikelihoodInstance, GenerationInstance
+from lm_eval.api.instance import Instance
 from lm_eval.api.metrics import mean, weighted_perplexity, weighted_mean, bits_per_byte
 from lm_eval import utils
@@ -460,7 +460,7 @@ class ConfigurableTask(Task):
     def construct_requests(self, doc, ctx, **kwargs):
         if self.OUTPUT_TYPE == "greedy_until":
-            return GenerationInstance(doc=doc, arguments=(ctx, "\n\n"), id_=0, **kwargs)
+            return Instance(request_type=self.OUTPUT_TYPE, doc=doc, arguments=(ctx, "\n\n"), id_=0, **kwargs)
 
     def process_results(self, doc, results):
@@ -498,7 +498,8 @@ class MultipleChoiceTask(Task):
     def construct_requests(self, doc, ctx, **kwargs):
-        return [LoglikelihoodInstance(
+        return [Instance(
+            request_type="loglikelihood",
             doc=doc,
             arguments=(ctx, " {}".format(choice)),
             id_=i,
@@ -579,7 +580,7 @@ class PerplexityTask(Task, abc.ABC):
     def construct_requests(self, doc, ctx, **kwargs):
         assert not ctx
-        return RollingLoglikelihoodInstance(doc=doc, arguments=(self.doc_to_target(doc),), id_=0, **kwargs)
+        return Instance(request_type=self.OUTPUT_TYPE, doc=doc, arguments=(self.doc_to_target(doc),), id_=0, **kwargs)
         # req = rf.loglikelihood_rolling(self.doc_to_target(doc))
         # return req
...
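
The construct_requests changes in this commit all follow the same pattern: the method builds a plain Instance, taking request_type from the task's OUTPUT_TYPE (or a literal string) and forwarding **kwargs, which is expected to carry the metadata tuple that Instance.__post_init__ unpacks. A condensed sketch, with the task class, its OUTPUT_TYPE value, and the metadata contents assumed for illustration:

# Condensed sketch of the construct_requests pattern in these hunks; the task
# class, its OUTPUT_TYPE value, and the metadata tuple are illustrative assumptions.
from lm_eval.api.instance import Instance

class ExamplePerplexityTask:
    OUTPUT_TYPE = "loglikelihood_rolling"

    def doc_to_target(self, doc):
        return doc["text"]

    def construct_requests(self, doc, ctx, **kwargs):
        assert not ctx
        # kwargs is expected to include metadata=(task_name, doc_id, repeats),
        # which Instance.__post_init__ unpacks into attributes.
        return Instance(
            request_type=self.OUTPUT_TYPE,
            doc=doc,
            arguments=(self.doc_to_target(doc),),
            id_=0,
            **kwargs,
        )

req = ExamplePerplexityTask().construct_requests(
    {"text": "a sample passage"}, ctx="", metadata=("example_task", 0, 1)
)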
@@ -18,7 +18,7 @@ Homepage: https://github.com/openai/grade-school-math
 """
 import re
 
 from lm_eval.api.task import Task
-from lm_eval.api.instance import GenerationInstance
+from lm_eval.api.instance import Instance
 from lm_eval.api.metrics import mean
 from lm_eval import utils
@@ -87,7 +87,7 @@ class GradeSchoolMath8K(Task):
         """
         # NOTE: The paper implements "verifiers" that assign a score to multiple
         # solutions and output the highest ranked solution.
-        return GenerationInstance(doc=doc, arguments=(ctx, ["\n"]), id_=0, **kwargs)
+        return Instance(request_type=self.OUTPUT_TYPE, doc=doc, arguments=(ctx, ["\n"]), id_=0, **kwargs)
         # completion = rf.greedy_until(ctx, ["\n"])
         # return completion
...
@@ -13,7 +13,7 @@ in the broader discourse.
 Homepage: https://zenodo.org/record/2630551#.X4Xzn5NKjUI
 """
 from lm_eval.api.task import Task
-from lm_eval.api.instance import LoglikelihoodInstance
+from lm_eval.api.instance import Instance
 from lm_eval.api.metrics import mean, perplexity
@@ -59,9 +59,7 @@ class LambadaBase(Task):
         return " " + doc["text"].rsplit(" ", 1)[1]
 
     def construct_requests(self, doc, ctx, **kwargs):
-        return LoglikelihoodInstance(doc=doc, arguments=(ctx, self.doc_to_target(doc)), **kwargs)
-        return ll, is_greedy
+        return Instance(request_type=self.OUTPUT_TYPE, doc=doc, arguments=(ctx, self.doc_to_target(doc)), **kwargs)
 
     def process_results(self, doc, results):
         # TODO: this ^ is a hack. filters should make it so that we only have one response per request that we score
...