"vscode:/vscode.git/clone" did not exist on "f434d1f5d0a7e75f5a289b8350f2fe7b4487148f"
Commit e7c18e53 authored by haileyschoelkopf
Browse files

refactor Instance dataclass

parent d2a9b759
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Literal
@dataclass @dataclass
class Instance: class Instance:
request_type: str = None # TODO: make this an enum? request_type: str = Literal["loglikelihood", "loglikelihood_rolling", "greedy_until"]
doc: dict = None doc: dict = None
arguments: tuple = None arguments: tuple = None
id_: int = None id_: int = None
...@@ -15,6 +16,7 @@ class Instance: ...@@ -15,6 +16,7 @@ class Instance:
repeats: str = None repeats: str = None
def __post_init__(self): def __post_init__(self):
# unpack metadata field
self.task_name, self.doc_id, self.repeats = self.metadata self.task_name, self.doc_id, self.repeats = self.metadata
@property @property
...@@ -23,53 +25,3 @@ class Instance: ...@@ -23,53 +25,3 @@ class Instance:
Returns (string,) where `string` is the string to calculate loglikelihood over Returns (string,) where `string` is the string to calculate loglikelihood over
""" """
return self.arguments if isinstance(self.arguments, tuple) else (self.arguments,) return self.arguments if isinstance(self.arguments, tuple) else (self.arguments,)
# import abc
# class Instance(abc.ABC):
# """
# A class used to bind together all necessary information and metadata for
# running forward pass of a model on a specific datapoint.
# """
# # all Instance subclasses have an attribute which is the name of the LM() class function they call to get outputs.
# request_type = None
# def __init__(self, doc, arguments=None, id_=None, metadata=("", None, None)):
# self.doc = doc # store the document which we're using. this is a dict
# self.arguments = arguments
# # need: task name, doc idx, num. repeats
# self.task_name, self.doc_id, self.repeats = metadata
# # id_ = idx within a doc's requests
# self.id_ = id_
# # handle repeats internally. should be able to run K times on exact same input/output pair
# # self.repeats = repeats
# # list containing the returns from each call of the model on this particular set of arguments.
# self.resps = []
# # filtered_resps should end up a dict, with a different key for each set of filters to apply. calculate results against each key in filtered_resps
# self.filtered_resps = {}
# #TODO: add more info as needed for detailed logging
# def __repr__(self):
# return f"Req_{self.request_type}{self.args}{self.id_}"
@dataclass
class LoglikelihoodInstance(Instance):
    """Instance whose ``request_type`` defaults to ``"loglikelihood"``."""

    # Overrides the base-class field default; all other fields are inherited.
    request_type: str = "loglikelihood"
@dataclass
class RollingLoglikelihoodInstance(Instance):
    """Instance whose ``request_type`` defaults to ``"loglikelihood_rolling"``."""

    # Overrides the base-class field default; all other fields are inherited.
    request_type: str = "loglikelihood_rolling"
@dataclass
class GenerationInstance(Instance):
    """Instance whose ``request_type`` defaults to ``"greedy_until"``."""

    # Overrides the base-class field default; all other fields are inherited.
    request_type: str = "greedy_until"
...@@ -10,7 +10,7 @@ import datasets ...@@ -10,7 +10,7 @@ import datasets
import numpy as np import numpy as np
from lm_eval.api import METRIC_REGISTRY, AGGREGATION_REGISTRY from lm_eval.api import METRIC_REGISTRY, AGGREGATION_REGISTRY
from lm_eval.api.instance import LoglikelihoodInstance, RollingLoglikelihoodInstance, GenerationInstance from lm_eval.api.instance import Instance
from lm_eval.api.metrics import mean, weighted_perplexity, weighted_mean, bits_per_byte from lm_eval.api.metrics import mean, weighted_perplexity, weighted_mean, bits_per_byte
from lm_eval import utils from lm_eval import utils
...@@ -460,7 +460,7 @@ class ConfigurableTask(Task): ...@@ -460,7 +460,7 @@ class ConfigurableTask(Task):
def construct_requests(self, doc, ctx, **kwargs): def construct_requests(self, doc, ctx, **kwargs):
if self.OUTPUT_TYPE == "greedy_until": if self.OUTPUT_TYPE == "greedy_until":
return GenerationInstance(doc=doc, arguments=(ctx, "\n\n"), id_=0, **kwargs) return Instance(request_type=self.OUTPUT_TYPE, doc=doc, arguments=(ctx, "\n\n"), id_=0, **kwargs)
def process_results(self, doc, results): def process_results(self, doc, results):
...@@ -498,7 +498,8 @@ class MultipleChoiceTask(Task): ...@@ -498,7 +498,8 @@ class MultipleChoiceTask(Task):
def construct_requests(self, doc, ctx, **kwargs): def construct_requests(self, doc, ctx, **kwargs):
return [LoglikelihoodInstance( return [Instance(
request_type="loglikelihood",
doc=doc, doc=doc,
arguments=(ctx, " {}".format(choice)), arguments=(ctx, " {}".format(choice)),
id_=i, id_=i,
...@@ -579,7 +580,7 @@ class PerplexityTask(Task, abc.ABC): ...@@ -579,7 +580,7 @@ class PerplexityTask(Task, abc.ABC):
def construct_requests(self, doc, ctx, **kwargs): def construct_requests(self, doc, ctx, **kwargs):
assert not ctx assert not ctx
return RollingLoglikelihoodInstance(doc=doc, arguments=(self.doc_to_target(doc),), id_=0, **kwargs) return Instance(request_type=self.OUTPUT_TYPE, doc=doc, arguments=(self.doc_to_target(doc),), id_=0, **kwargs)
# req = rf.loglikelihood_rolling(self.doc_to_target(doc)) # req = rf.loglikelihood_rolling(self.doc_to_target(doc))
# return req # return req
......
...@@ -18,7 +18,7 @@ Homepage: https://github.com/openai/grade-school-math ...@@ -18,7 +18,7 @@ Homepage: https://github.com/openai/grade-school-math
""" """
import re import re
from lm_eval.api.task import Task from lm_eval.api.task import Task
from lm_eval.api.instance import GenerationInstance from lm_eval.api.instance import Instance
from lm_eval.api.metrics import mean from lm_eval.api.metrics import mean
from lm_eval import utils from lm_eval import utils
...@@ -87,7 +87,7 @@ class GradeSchoolMath8K(Task): ...@@ -87,7 +87,7 @@ class GradeSchoolMath8K(Task):
""" """
# NOTE: The paper implements "verifiers" that assign a score to multiple # NOTE: The paper implements "verifiers" that assign a score to multiple
# solutions and output the highest ranked solution. # solutions and output the highest ranked solution.
return GenerationInstance(doc=doc, arguments=(ctx, ["\n"]), id_=0, **kwargs) return Instance(request_type=self.OUTPUT_TYPE, doc=doc, arguments=(ctx, ["\n"]), id_=0, **kwargs)
# completion = rf.greedy_until(ctx, ["\n"]) # completion = rf.greedy_until(ctx, ["\n"])
# return completion # return completion
......
...@@ -13,7 +13,7 @@ in the broader discourse. ...@@ -13,7 +13,7 @@ in the broader discourse.
Homepage: https://zenodo.org/record/2630551#.X4Xzn5NKjUI Homepage: https://zenodo.org/record/2630551#.X4Xzn5NKjUI
""" """
from lm_eval.api.task import Task from lm_eval.api.task import Task
from lm_eval.api.instance import LoglikelihoodInstance from lm_eval.api.instance import Instance
from lm_eval.api.metrics import mean, perplexity from lm_eval.api.metrics import mean, perplexity
...@@ -59,9 +59,7 @@ class LambadaBase(Task): ...@@ -59,9 +59,7 @@ class LambadaBase(Task):
return " " + doc["text"].rsplit(" ", 1)[1] return " " + doc["text"].rsplit(" ", 1)[1]
def construct_requests(self, doc, ctx, **kwargs): def construct_requests(self, doc, ctx, **kwargs):
return LoglikelihoodInstance(doc=doc, arguments=(ctx, self.doc_to_target(doc)), **kwargs) return Instance(request_type=self.OUTPUT_TYPE, doc=doc, arguments=(ctx, self.doc_to_target(doc)), **kwargs)
return ll, is_greedy
def process_results(self, doc, results): def process_results(self, doc, results):
# TODO: this ^ is a hack. filters should make it so that we only have one response per request that we score # TODO: this ^ is a hack. filters should make it so that we only have one response per request that we score
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment