Commit 9d8e0532 authored by haileyschoelkopf

change id_ to idx in instance

parent 2a9da9fb
 from dataclasses import dataclass, field
-from typing import Literal
+from typing import Literal, Tuple


 @dataclass
 class Instance:
     request_type: str = Literal["loglikelihood", "loglikelihood_rolling", "greedy_until"]
     doc: dict = None
     arguments: tuple = None
-    id_: int = None
-    metadata: tuple = None # TODO: better typehints here
+    idx: int = None
+    metadata: tuple = Tuple[str, int, int] # TODO: better typehints here
     resps: list = field(default_factory=list)
     filtered_resps: dict = field(default_factory=dict)
......
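For readers skimming the rename: `idx` orders the several `Instance` objects a single document can produce (e.g. one per answer choice), so downstream code can re-sort responses deterministically. A minimal, self-contained sketch of that usage follows; the example doc and strings are illustrative, not from the repo:

```python
# Hedged sketch: a standalone copy of the Instance dataclass above, used
# the way MultipleChoiceTask does later in this diff -- one Instance per
# choice, with idx recording choice order. Example data is hypothetical.
from dataclasses import dataclass, field

@dataclass
class Instance:
    request_type: str = None
    doc: dict = None
    arguments: tuple = None
    idx: int = None
    metadata: tuple = None  # e.g. (task_name, doc_id, repeats) -- assumed
    resps: list = field(default_factory=list)
    filtered_resps: dict = field(default_factory=dict)

doc = {"query": "2 + 2 =", "choices": ["3", "4"]}
instances = [
    Instance(
        request_type="loglikelihood",
        doc=doc,
        arguments=("2 + 2 =", " {}".format(choice)),
        idx=i,
    )
    for i, choice in enumerate(doc["choices"])
]
# Scoring code needs responses back in choice order, hence the other
# half of the rename: sorting by `idx` instead of `id_`.
instances.sort(key=lambda x: x.idx)
```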
-class Sampler: # TODO: make this abstract class?
+class Sampler:
     def __init__(self, docs, task, fewshot_indices=None, rnd=None):
@@ -17,14 +16,17 @@ class Sampler: # TODO: make this abstract class?
         if fewshot_indices: # subset few-shot docs from
             self.docs = self.docs.select(fewshot_indices)

     def get_context(self, doc, num_fewshot):
-        # draw an extra fewshot sample if
+        # draw an extra fewshot sample if using same split as evaluating on
         n_samples = num_fewshot + 1 if self.config.fewshot_split == self.config.test_split else num_fewshot

         # draw `n_samples` docs from fewshot_docs
         fewshotex = self.sample(n_samples)

         # get rid of the doc that's the one we're evaluating, if it's in the fewshot
         # TODO: should we just stop people from using fewshot from same split as evaluating?
         selected_docs = [x for x in fewshotex if x != doc][:num_fewshot]

         labeled_examples = (
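The +1 oversampling above exists because, when few-shot examples come from the split being evaluated, the current doc can come back in the draw. A self-contained sketch of the same oversample-then-filter technique, with plain lists and hypothetical names standing in for the real dataset objects:

```python
# Hedged sketch of get_context's oversample-then-filter step; the
# function name, plain-list docs, and seed are assumptions for the demo.
import random

def select_fewshot(fewshot_docs, doc, num_fewshot, same_split, rnd=None):
    rnd = rnd or random.Random(1234)
    # Draw one extra when the few-shot pool is the eval split, so the
    # doc under evaluation can be discarded without coming up short.
    n_samples = num_fewshot + 1 if same_split else num_fewshot
    fewshotex = rnd.sample(fewshot_docs, n_samples)
    # Drop the evaluated doc if drawn, then truncate back to num_fewshot.
    return [x for x in fewshotex if x != doc][:num_fewshot]

docs = [{"q": "question %d" % i} for i in range(10)]
print(select_fewshot(docs, docs[3], num_fewshot=4, same_split=True))
```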
@@ -53,7 +55,7 @@ class BalancedSampler(Sampler):
     def sample(self, n):
         """
         TODO: this should return approximately class-balanced samples from our fewshot examples.
-        TODO: what order should they be in?
+        TODO: what order should they be in? maybe random?
         """
         pass
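`sample` is still a stub, so the following is purely a hypothetical reading of the two TODOs: bucket docs by label, draw round-robin for approximate balance, and answer "maybe random?" with a final shuffle. The `label_key` field and all logic here are assumptions, not repo code:

```python
# Hypothetical sketch for BalancedSampler.sample (the method is only a
# TODO in the diff): round-robin over per-label buckets, then shuffle.
import random
from collections import defaultdict

def balanced_sample(docs, n, label_key="label", rnd=None):
    rnd = rnd or random.Random(1234)
    buckets = defaultdict(list)
    for d in docs:
        buckets[d[label_key]].append(d)
    for bucket in buckets.values():
        rnd.shuffle(bucket)
    # Take one doc per label in turn until n are drawn (approximate
    # balance when classes have unequal counts).
    out = []
    while len(out) < n and any(buckets.values()):
        for label in list(buckets):
            if buckets[label] and len(out) < n:
                out.append(buckets[label].pop())
    rnd.shuffle(out)  # one answer to "what order should they be in?"
    return out
```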
......
@@ -469,7 +469,7 @@ class ConfigurableTask(Task):
     def construct_requests(self, doc, ctx, **kwargs):
         if self.OUTPUT_TYPE == "greedy_until":
-            return Instance(request_type=self.OUTPUT_TYPE, doc=doc, arguments=(ctx, "\n\n"), id_=0, **kwargs)
+            return Instance(request_type=self.OUTPUT_TYPE, doc=doc, arguments=(ctx, "\n\n"), idx=0, **kwargs)

     def process_results(self, doc, results):
@@ -511,7 +511,7 @@ class MultipleChoiceTask(Task):
                 request_type="loglikelihood",
                 doc=doc,
                 arguments=(ctx, " {}".format(choice)),
-                id_=i,
+                idx=i,
                 **kwargs,
             )
             for i, choice in enumerate(doc["choices"])]
@@ -589,7 +589,7 @@ class PerplexityTask(Task, abc.ABC):
     def construct_requests(self, doc, ctx, **kwargs):
         assert not ctx
-        return Instance(request_type=self.OUTPUT_TYPE, doc=doc, arguments=(self.doc_to_target(doc),), id_=0, **kwargs)
+        return Instance(request_type=self.OUTPUT_TYPE, doc=doc, arguments=(self.doc_to_target(doc),), idx=0, **kwargs)
         # req = rf.loglikelihood_rolling(self.doc_to_target(doc))
         # return req
......
@@ -181,7 +181,7 @@ def evaluate(
     for doc_id, doc in enumerate(itertools.islice(task.test_docs(), 0, limit) if task.has_test_docs() else task.validation_docs()):
         # subset instances to only this document id ; sort by idx
         requests = list(filter(lambda x: x.doc_id == doc_id, task.instances))
-        requests.sort(key=lambda x: x.id_)
+        requests.sort(key=lambda x: x.idx)

         metrics = task.process_results(doc, [req.filtered_resps[key] for req in requests])
         for metric, value in metrics.items():
             vals[(task_name, key, metric)].append(value)
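The consumer side of the rename, as a runnable miniature: filter a flat instance list down to one `doc_id`, sort by `idx`, and hand responses to scoring in emission order. `SimpleNamespace` stands in for the real `Instance`; field names follow the dataclass earlier in the diff:

```python
# Hedged sketch of evaluate()'s per-document regrouping; data is
# deliberately shuffled so the sort by idx visibly matters.
from types import SimpleNamespace

instances = [
    SimpleNamespace(doc_id=d, idx=i, filtered_resps={"none": "resp-%d-%d" % (d, i)})
    for d in (1, 0) for i in (1, 0)  # out of order on purpose
]

doc_id, key = 0, "none"
requests = list(filter(lambda x: x.doc_id == doc_id, instances))
requests.sort(key=lambda x: x.idx)
print([req.filtered_resps[key] for req in requests])
# ['resp-0-0', 'resp-0-1']
```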
......
@@ -88,7 +88,7 @@ class GradeSchoolMath8K(Task):
         """
         # NOTE: The paper implements "verifiers" that assign a score to multiple
         # solutions and output the highest ranked solution.
-        return Instance(request_type=self.OUTPUT_TYPE, doc=doc, arguments=(ctx, ["\n"]), id_=0, **kwargs)
+        return Instance(request_type=self.OUTPUT_TYPE, doc=doc, arguments=(ctx, ["\n"]), idx=0, **kwargs)
         # completion = rf.greedy_until(ctx, ["\n"])
         # return completion
......