Unverified Commit d2b16757 authored by Hailey Schoelkopf, committed by GitHub

Merge branch 'big-refactor' into configurable-tasks

parents 1f23061b fa686d04
 from dataclasses import dataclass, field
-from typing import Literal
+from typing import Literal, Tuple

 @dataclass
 class Instance:
     request_type: str = Literal["loglikelihood", "loglikelihood_rolling", "greedy_until"]
     doc: dict = None
     arguments: tuple = None
-    id_: int = None
-    metadata: tuple = None # TODO: better typehints here
+    idx: int = None
+    metadata: tuple = Tuple[str, int, int] # TODO: better typehints here
     resps: list = field(default_factory=list)
     filtered_resps: dict = field(default_factory=dict)
...
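For orientation, a minimal usage sketch of the renamed dataclass above, assuming the metadata triple is meant to carry (task_name, doc_id, repeats); that layout is inferred from the Tuple[str, int, int] hint and the field values below are invented, not taken from this commit.

# Hypothetical usage of the Instance dataclass from the diff above.
inst = Instance(
    request_type="loglikelihood",
    doc={"question": "2+2=", "choices": ["3", "4"]},
    arguments=("2+2=", " 4"),
    idx=1,                        # position among this doc's requests
    metadata=("my_task", 0, 1),   # assumed (task_name, doc_id, repeats)
)
inst.resps.append((-0.7, True))             # raw model response
inst.filtered_resps["none"] = (-0.7, True)  # response after a "none" filter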
-class Sampler:
+class Sampler: # TODO: make this abstract class?
     def __init__(self, docs, task, fewshot_indices=None, rnd=None):
@@ -17,14 +16,17 @@ class Sampler: # TODO: make this abstract class?
         if fewshot_indices: # subset few-shot docs from
             self.docs = self.docs.select(fewshot_indices)

     def get_context(self, doc, num_fewshot):
-        # draw an extra fewshot sample if
+        # draw an extra fewshot sample if using same split as evaluating on
         n_samples = num_fewshot + 1 if self.config.fewshot_split == self.config.test_split else num_fewshot
+        # draw `n_samples` docs from fewshot_docs
         fewshotex = self.sample(n_samples)
         # get rid of the doc that's the one we're evaluating, if it's in the fewshot
+        # TODO: should we just stop people from using fewshot from same split as evaluating?
         selected_docs = [x for x in fewshotex if x != doc][:num_fewshot]

         labeled_examples = (
@@ -53,7 +55,7 @@ class BalancedSampler(Sampler):
     def sample(self, n):
         """
         TODO: this should return approximately class-balanced samples from our fewshot examples.
-        TODO: what order should they be in?
+        TODO: what order should they be in? maybe random?
         """
         pass
...
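A condensed, self-contained sketch of the over-draw-and-filter idea in get_context, with self.sample and the split comparison replaced by plain parameters; pick_fewshot and its arguments are hypothetical stand-ins, not the harness API.

import random

def pick_fewshot(docs, doc, num_fewshot, same_split, seed=1234):
    # draw one extra candidate when fewshot docs come from the eval split,
    # since the doc being evaluated may itself be drawn
    n_samples = num_fewshot + 1 if same_split else num_fewshot
    candidates = random.Random(seed).sample(docs, n_samples)
    # drop the evaluated doc if it was drawn, then trim back to num_fewshot
    return [x for x in candidates if x != doc][:num_fewshot]

docs = [{"q": i} for i in range(10)]
print(pick_fewshot(docs, {"q": 3}, num_fewshot=2, same_split=True))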
@@ -540,7 +540,7 @@ class ConfigurableTask(Task):
                 request_type="loglikelihood",
                 doc=doc,
                 arguments=(ctx, " {}".format(choice)),
-                id_=i,
+                idx=i,
                 **kwargs,
             )
             for i, choice in enumerate(ast.literal_eval(utils.apply_template(self._config.template_aliases + "{{answer_choices}}", doc)))
@@ -553,7 +553,7 @@ class ConfigurableTask(Task):
             request_type=self.OUTPUT_TYPE,
             doc=doc,
             arguments=arguments,
-            id_=0,
+            idx=0,
             **kwargs
         )
@@ -631,7 +631,7 @@ class MultipleChoiceTask(Task):
                 request_type="loglikelihood",
                 doc=doc,
                 arguments=(ctx, " {}".format(choice)),
-                id_=i,
+                idx=i,
                 **kwargs,
             )
             for i, choice in enumerate(doc["choices"])]
@@ -704,7 +704,7 @@ class PerplexityTask(Task, abc.ABC):
     def construct_requests(self, doc, ctx, **kwargs):
         assert not ctx
-        return Instance(request_type=self.OUTPUT_TYPE, doc=doc, arguments=(self.doc_to_target(doc),), id_=0, **kwargs)
+        return Instance(request_type=self.OUTPUT_TYPE, doc=doc, arguments=(self.doc_to_target(doc),), idx=0, **kwargs)
         # req = rf.loglikelihood_rolling(self.doc_to_target(doc))
         # return req
...
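Reusing the Instance sketch from the first file, the per-choice construction in ConfigurableTask and MultipleChoiceTask amounts to the following, with the template machinery replaced by a plain list of answer choices; construct_choice_requests is a hypothetical helper written for brevity, not code from this commit.

def construct_choice_requests(doc, ctx, choices, **kwargs):
    # one loglikelihood request per answer choice; idx records which
    # choice each request scores, so responses can be re-ordered later
    return [
        Instance(
            request_type="loglikelihood",
            doc=doc,
            arguments=(ctx, " {}".format(choice)),
            idx=i,
            **kwargs,
        )
        for i, choice in enumerate(choices)
    ]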
@@ -182,7 +182,7 @@ def evaluate(
         for doc_id, doc in enumerate(itertools.islice(task.test_docs(), 0, limit) if task.has_test_docs() else task.validation_docs()):
             # subset instances to only this document id; sort by idx
             requests = list(filter(lambda x: x.doc_id == doc_id, task.instances))
-            requests.sort(key=lambda x: x.id_)
+            requests.sort(key=lambda x: x.idx)
             metrics = task.process_results(doc, [req.filtered_resps[key] for req in requests])
             for metric, value in metrics.items():
                 vals[(task_name, key, metric)].append(value)
...
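A toy illustration of the regrouping step above, with SimpleNamespace standing in for real Instance objects and "none" as the filter key (both assumptions):

from types import SimpleNamespace

instances = [
    SimpleNamespace(doc_id=0, idx=1, filtered_resps={"none": -2.3}),
    SimpleNamespace(doc_id=0, idx=0, filtered_resps={"none": -0.7}),
    SimpleNamespace(doc_id=1, idx=0, filtered_resps={"none": -1.1}),
]
key = "none"
for doc_id in (0, 1):
    # subset instances to only this document id, then sort by idx so
    # responses line up with the order the requests were constructed in
    requests = sorted(
        (x for x in instances if x.doc_id == doc_id), key=lambda x: x.idx
    )
    print(doc_id, [req.filtered_resps[key] for req in requests])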
@@ -88,7 +88,7 @@ class GradeSchoolMath8K(Task):
         """
         # NOTE: The paper implements "verifiers" that assign a score to multiple
         # solutions and output the highest ranked solution.
-        return Instance(request_type=self.OUTPUT_TYPE, doc=doc, arguments=(ctx, ["\n"]), id_=0, **kwargs)
+        return Instance(request_type=self.OUTPUT_TYPE, doc=doc, arguments=(ctx, ["\n"]), idx=0, **kwargs)
         # completion = rf.greedy_until(ctx, ["\n"])
         # return completion
...
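For contrast with the loglikelihood requests above, a sketch of the single generation request GSM8K builds, again reusing the Instance sketch from the first file; the prompt text is invented.

# Hypothetical greedy_until request: arguments pairs the prompt with the
# stop sequences, so generation halts at the first newline.
ctx = "Question: If you have 3 apples and buy 2 more, how many do you have?\nAnswer:"
request = Instance(
    request_type="greedy_until",
    doc={"question": "..."},
    arguments=(ctx, ["\n"]),
    idx=0,  # a single generation request per doc
)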