Commit 1576e99e authored by Albert Jiang

allow task-specific descriptions to be passed to request construction
parent 99fdec0c
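
With this change a description_dict entry doubles as a per-task sampling config. A minimal sketch of the intended call, assuming the harness's existing evaluate() entry point (my_lm is a hypothetical LM instance):

    from lm_eval import evaluator, tasks

    description_dict = {
        # parsed by Math.parse_description below: 32 samples at temperature 1.0,
        # aggregated by majority voting
        "math_algebra": "majority_voting=32,sampling_temperature=1.0",
    }
    results = evaluator.evaluate(
        lm=my_lm,  # hypothetical: any lm_eval.base.LM instance
        task_dict=tasks.get_task_dict(["math_algebra"]),
        num_fewshot=0,
        limit=None,
        description_dict=description_dict,
    )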
@@ -7,6 +7,7 @@ import os
 import json
 import hashlib
 import datasets
+import inspect
 from sqlitedict import SqliteDict
 from tqdm import tqdm
 import torch
@@ -329,7 +330,7 @@ class BaseLM(LM):
         return re_ord.get_original(res)

-    def multiple_temperature_sample_until(self, requests, k=32, temperature=0.3):
+    def generate(self, requests):
         res = []

         def _collate(x):
@@ -338,21 +339,33 @@ class BaseLM(LM):
         re_ord = utils.Reorderer(requests, _collate)

-        for context, until in tqdm(re_ord.get_reordered()):
+        for request in tqdm(re_ord.get_reordered()):
+            if len(request) == 2:
+                # Unpack greedy sample request
+                context, until = request
+                k, temperature = 1, 0.
+                _model_generate_kwargs = {}
+            elif len(request) == 4:
+                # Unpack temperature sample request
+                context, until, k, temperature = request
+                for key in ["k", "temperature"]:
+                    assert key in inspect.getfullargspec(self._model_generate).args, \
+                        f"Model generation parameter '{key}' not accepted as an argument for _model_generate"
+                _model_generate_kwargs = {"k": k, "temperature": temperature}
+            else:
+                raise AssertionError
             if isinstance(until, str):
                 until = [until]

             (primary_until,) = self.tok_encode(until[0])

             context_enc = torch.tensor(
                 [self.tok_encode(context)[self.max_gen_toks - self.max_length :]]
             ).to(self.device)
+            assert context_enc.shape[0] == 1
+            context_enc = context_enc.expand(k, context_enc.shape[1])

             cont = self._model_generate(
                 context_enc, context_enc.shape[1] + self.max_gen_toks, primary_until,
-                temperature=temperature
+                **_model_generate_kwargs
             )

             generated_tokens = cont[:, context_enc.shape[1]:]
@@ -361,50 +374,16 @@ class BaseLM(LM):
             s = [candidate.split(term)[0] for candidate in s]

             # partial caching
-            self.cache_hook.add_partial("multiple_temperature_sample_until", (context, until, k, temperature), s)
+            self.cache_hook.add_partial("generate", (context, until, k, temperature), s)

             res.append(s)

         return re_ord.get_original(res)

     def greedy_until(self, requests):
         # TODO: implement fully general `until` that handles until that are
         #       multiple tokens or that span multiple tokens correctly
         # TODO: extract to TokenizedLM?
-        res = []
-
-        def _collate(x):
-            toks = self.tok_encode(x[0])
-            return len(toks), x[0]
-
-        re_ord = utils.Reorderer(requests, _collate)
-
-        for context, until in tqdm(re_ord.get_reordered()):
-            if isinstance(until, str):
-                until = [until]
-
-            (primary_until,) = self.tok_encode(until[0])
-
-            context_enc = torch.tensor(
-                [self.tok_encode(context)[self.max_gen_toks - self.max_length :]]
-            ).to(self.device)
-
-            cont = self._model_generate(
-                context_enc, context_enc.shape[1] + self.max_gen_toks, primary_until
-            )
-
-            s = self.tok_decode(cont[0].tolist()[context_enc.shape[1] :])
-
-            for term in until:
-                s = s.split(term)[0]
-
-            # partial caching
-            self.cache_hook.add_partial("greedy_until", (context, until), s)
-
-            res.append(s)
-
-        return re_ord.get_original(res)
+        return self.generate(requests)


 class Task(abc.ABC):
@@ -881,7 +860,7 @@ class CachingLM:
 REQUEST_RETURN_LENGTHS = {
     "loglikelihood": 2,
     "greedy_until": None,
-    "multiple_temperature_sample_until": None,
+    "generate": None,
     "loglikelihood_rolling": None,
 }
...
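
The request tuples that generate() dispatches on have two shapes; a minimal sketch (the prompts are made up, lm is any BaseLM subclass):

    # len == 2: greedy decoding, mapped internally to k=1, temperature=0.
    greedy_request = ("Question: 2+2=?\nAnswer:", ["\n"])

    # len == 4: k samples at the given temperature, forwarded to _model_generate
    sampled_request = ("Question: 2+2=?\nAnswer:", ["\n"], 32, 0.3)

    # lm.generate([...]) returns, per request, the list of k decoded candidates,
    # each truncated at the first stop sequence.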
@@ -2,6 +2,7 @@ import collections
 import itertools
 import numpy as np
 import random
+import inspect

 import lm_eval.metrics
 import lm_eval.models
 import lm_eval.tasks
@@ -177,6 +178,7 @@ def evaluate(
     docs = {}
     docs_for_decontamination = collections.defaultdict(list)
+    task_to_description = {}

     # get lists of each type of request
     for task_name, task in task_dict_items:
@@ -203,6 +205,7 @@ def evaluate(
             if description_dict and task_name in description_dict
             else ""
         )
+        task_to_description[task_name] = description

         for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
@@ -215,7 +218,10 @@ def evaluate(
             ctx = task.fewshot_context(
                 doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
             )
-            reqs = task.construct_requests(doc, ctx)
+            if "description" in inspect.getfullargspec(task.construct_requests).args:
+                reqs = task.construct_requests(doc, ctx, description=description)
+            else:
+                reqs = task.construct_requests(doc, ctx)
             if not isinstance(reqs, (list, tuple)):
                 reqs = [reqs]
             for i, req in enumerate(reqs):
@@ -262,7 +268,11 @@ def evaluate(
         task = task_dict[task_name]
         doc = docs[(task_name, doc_id)]

-        metrics = task.process_results(doc, requests)
+        # be backward compatible with tasks whose process_results does not accept a description
+        if "description" in inspect.getfullargspec(task.process_results).args:
+            metrics = task.process_results(doc, requests, task_to_description[task_name])
+        else:
+            metrics = task.process_results(doc, requests)
         for metric, value in metrics.items():
             vals[(task_name, metric)].append(value)
...
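
The inspect.getfullargspec check above is what keeps old task classes working; a self-contained illustration with two hypothetical tasks:

    import inspect

    class OldTask:
        def process_results(self, doc, results):
            return {"acc": 0}

    class NewTask:
        def process_results(self, doc, results, description=""):
            return {"acc": 0}

    for task in (OldTask(), NewTask()):
        # dispatch on whether the override accepts a `description` argument
        accepts = "description" in inspect.getfullargspec(task.process_results).args
        print(type(task).__name__, accepts)  # OldTask False, NewTask True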
@@ -121,16 +121,23 @@ class HFLM(BaseLM):
         with torch.no_grad():
             return self.gpt2(inps)[0][:, :, :50257]

-    def _model_generate(self, context, max_length, eos_token_id, temperature=0.):
-        assert temperature >= 0.
-        if temperature == 0.:
-            return self.gpt2.generate(
-                context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False
-            )
-        else:
-            return self.gpt2.generate(
-                context, max_length=max_length, eos_token_id=eos_token_id, do_sample=True, temperature=temperature
-            )
+    def _model_generate(self, context, max_length, eos_token_id, k=1, temperature=0.):
+        assert (isinstance(k, int) and k >= 1), f"Incorrect number of candidates to generate: {k}"
+        assert temperature >= 0., f"Negative sampling temperature: {temperature}"
+
+        # Whether to sample or to decode greedily
+        do_sample = (temperature != 0.)
+        if not do_sample:
+            # If decoding greedily, only sample once
+            assert k == 1, f"Decoding greedily but {k} generations"
+
+        if k > 1:
+            context = context.expand(k, context.shape[1])
+
+        return self.gpt2.generate(
+            context, max_length=max_length, eos_token_id=eos_token_id,
+            do_sample=do_sample, temperature=temperature
+        )

     # for backwards compatibility
...
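
In the sampling path, expand() turns the single prompt row into k views of the same storage, so one generate() call yields k candidates. A standalone sketch with a stock GPT-2; the model choice and prompt are illustrative, not from the commit:

    import torch
    from transformers import GPT2LMHeadModel, GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = GPT2LMHeadModel.from_pretrained("gpt2")

    context = torch.tensor([tokenizer.encode("2 + 2 =")])  # shape (1, seq_len)
    k = 4
    context = context.expand(k, context.shape[1])          # (k, seq_len), no data copy
    out = model.generate(
        context,
        max_length=context.shape[1] + 16,
        do_sample=True,
        temperature=0.3,
        pad_token_id=tokenizer.eos_token_id,  # silence the missing-pad warning
    )
    print([tokenizer.decode(row[context.shape[1]:]) for row in out])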
@@ -157,7 +157,6 @@ TASK_REGISTRY = {
     "mutual_plus": mutual.MuTualPlus,
     # math
     "math_algebra": hendrycks_math.MathAlgebra,
-    "math_algebra_maj@k": hendrycks_math.MathAlgebraMaj,
     "math_counting_and_prob": hendrycks_math.MathCountingAndProbability,
     "math_geometry": hendrycks_math.MathGeometry,
     "math_intermediate_algebra": hendrycks_math.MathIntermediateAlgebra,
...
@@ -27,6 +27,8 @@ _CITATION = """
 class Math(Task):
     DATASET_PATH = inspect.getfile(lm_eval.datasets.hendrycks_math.hendrycks_math)
     DATASET_NAME = None
+    MAJORITY_VOTING = "majority_voting"
+    SAMPLING_TEMPERATURE = "sampling_temperature"

     def has_training_docs(self):
         return True
@@ -62,21 +64,72 @@ class Math(Task):
     def doc_to_target(self, doc):
         return " " + doc["solution"]

-    def construct_requests(self, doc, ctx):
-        return rf.greedy_until(ctx, ["\n"])
+    def parse_description(self, description):
+        """description is a string of comma-separated key=value pairs,
+        e.g. "majority_voting=32,sampling_temperature=1.0"
+        """
+        parsed_dict = {}
+        for term in description.split(","):
+            if not term.strip():
+                continue
+            key, value = term.split("=")
+            parsed_dict[key] = value
+        return parsed_dict
+
+    def construct_requests(self, doc, ctx, description=""):
+        if not description.strip():
+            return rf.generate(ctx, ["\n"])
+
+        parsed_description = self.parse_description(description=description)
+        majority_voting_value = int(parsed_description.get(self.MAJORITY_VOTING, 1))
+        sampling_temperature_value = float(parsed_description.get(self.SAMPLING_TEMPERATURE, 1.0))
+        return rf.generate(ctx, ["\n"],
+                           majority_voting_value, sampling_temperature_value)
+
+    def get_pure_answer(self, candidate):
+        indices = [pos for pos, char in enumerate(candidate) if char == "$"]
+        if len(indices) <= 1:
+            return candidate
+        return candidate[indices[0] + 1 : indices[-1]]
+
+    def majority_vote(self, candidates):
+        answers = []
+        for candidate in candidates:
+            answer = self.get_pure_answer(candidate)
+            try:
+                answer = self.remove_boxed(self.last_boxed_only_string(answer))
+            except:
+                answer = None
+            answers.append(answer)
+        answer_votes = {}
+        for answer in answers:
+            answer_votes[answer] = answer_votes.get(answer, 0) + 1
+        max_vote = 0
+        elected = None
+        for answer, vote in answer_votes.items():
+            if vote > max_vote and answer is not None:
+                elected = answer
+                max_vote = vote
+        return elected

-    def process_results(self, doc, results):
+    def process_results(self, doc, results, description=""):
         retval = 0
-        indices = [pos for pos, char in enumerate(results[0]) if char == "$"]
-        if len(indices) <= 1:
-            answer = results[0]
-        else:
-            answer = results[0][indices[0] + 1 : indices[-1]]
+        if description == "":
+            answer = self.get_pure_answer(results[0])
+        elif self.MAJORITY_VOTING in self.parse_description(description):
+            answer = self.majority_vote(results[0])
+        else:
+            raise AssertionError

         if self.is_equiv(
             answer, self.remove_boxed(self.last_boxed_only_string(doc["solution"]))
         ):
             retval = 1
         return {"acc": retval}

     def aggregation(self):
@@ -286,48 +339,6 @@ class MathAlgebra(Math):
     DATASET_NAME = "algebra"


-class MathAlgebraMaj(Math):
-    VERSION = 1
-    DATASET_NAME = "algebra"
-
-    def construct_requests(self, doc, ctx):
-        return rf.multiple_temperature_sample_until(ctx, ["\n"])
-
-    def process_results(self, doc, results):
-        retval = 0
-        candidates = results[0]
-
-        answers = []
-        for candidate in candidates:
-            indices = [pos for pos, char in enumerate(candidate) if char == "$"]
-            if len(indices) <= 1:
-                answer = candidate
-            else:
-                answer = candidate[indices[0] + 1 : indices[-1]]
-            try:
-                answer = self.remove_boxed(self.last_boxed_only_string(answer))
-            except:
-                answer = None
-            answers.append(answer)
-
-        answer_votes = {}
-        for answer in answers:
-            answer_votes[answer] = answer_votes.get(answer, 0) + 1
-
-        max_vote = 0
-        elected = None
-        for answer, vote in answer_votes.items():
-            if vote > max_vote and answer is not None:
-                elected = answer
-                max_vote = vote
-
-        if self.is_equiv(
-            elected, self.remove_boxed(self.last_boxed_only_string(doc["solution"]))
-        ):
-            retval = 1
-        return {"acc": retval}


 class MathCountingAndProbability(Math):
     VERSION = 1
     DATASET_NAME = "counting_and_probability"
...
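
To make the election logic concrete, a standalone rerun on made-up candidates, with collections.Counter standing in for the hand-rolled vote dict (the $-stripping mirrors get_pure_answer; unboxing via remove_boxed is omitted):

    from collections import Counter

    candidates = [
        "The answer is $\\boxed{4}$.",
        "We also get $\\boxed{4}$.",
        "Alternatively, $\\boxed{5}$.",
    ]

    def pure_answer(candidate):
        # keep only the text between the first and last "$"
        indices = [pos for pos, char in enumerate(candidate) if char == "$"]
        if len(indices) <= 1:
            return candidate
        return candidate[indices[0] + 1 : indices[-1]]

    votes = Counter(pure_answer(c) for c in candidates)
    print(votes.most_common(1)[0])  # ('\\boxed{4}', 2): the majority answer wins 2-1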