Commit 1576e99e authored by Albert Jiang

allow task-specific descriptions to be passed to request construction
parent 99fdec0c
@@ -7,6 +7,7 @@ import os
 import json
 import hashlib
 import datasets
+import inspect
 from sqlitedict import SqliteDict
 from tqdm import tqdm
 import torch
@@ -329,7 +330,7 @@ class BaseLM(LM):
         return re_ord.get_original(res)

-    def multiple_temperature_sample_until(self, requests, k=32, temperature=0.3):
+    def generate(self, requests):
         res = []

         def _collate(x):
@@ -338,21 +339,33 @@ class BaseLM(LM):
         re_ord = utils.Reorderer(requests, _collate)

-        for context, until in tqdm(re_ord.get_reordered()):
+        for request in tqdm(re_ord.get_reordered()):
+            if len(request) == 2:
+                # Unpack greedy sample request
+                context, until = request
+                k, temperature = 1, 0.
+                _model_generate_kwargs = {}
+            elif len(request) == 4:
+                # Unpack temperature sample request
+                context, until, k, temperature = request
+                for key in ["k", "temperature"]:
+                    assert key in inspect.getfullargspec(self._model_generate).args, \
+                        f"Model generation parameter '{key}' not accepted as an argument for _model_generate"
+                _model_generate_kwargs = {"k": k, "temperature": temperature}
+            else:
+                raise AssertionError
             if isinstance(until, str):
                 until = [until]
             (primary_until,) = self.tok_encode(until[0])
             context_enc = torch.tensor(
                 [self.tok_encode(context)[self.max_gen_toks - self.max_length :]]
             ).to(self.device)
-            assert context_enc.shape[0] == 1
-            context_enc = context_enc.expand(k, context_enc.shape[1])
             cont = self._model_generate(
                 context_enc, context_enc.shape[1] + self.max_gen_toks, primary_until,
-                temperature=temperature
+                **_model_generate_kwargs
             )
             generated_tokens = cont[:, context_enc.shape[1]:]
@@ -361,50 +374,16 @@ class BaseLM(LM):
                 s = [candidate.split(term)[0] for candidate in s]

             # partial caching
-            self.cache_hook.add_partial("multiple_temperature_sample_until", (context, until, k, temperature), s)
+            self.cache_hook.add_partial("generate", (context, until, k, temperature), s)
             res.append(s)

         return re_ord.get_original(res)

     def greedy_until(self, requests):
-        # TODO: implement fully general `until` that handles until that are
-        #       multiple tokens or that span multiple tokens correctly
-        # TODO: extract to TokenizedLM?
-        res = []
-
-        def _collate(x):
-            toks = self.tok_encode(x[0])
-            return len(toks), x[0]
-
-        re_ord = utils.Reorderer(requests, _collate)
-
-        for context, until in tqdm(re_ord.get_reordered()):
-            if isinstance(until, str):
-                until = [until]
-
-            (primary_until,) = self.tok_encode(until[0])
-
-            context_enc = torch.tensor(
-                [self.tok_encode(context)[self.max_gen_toks - self.max_length :]]
-            ).to(self.device)
-
-            cont = self._model_generate(
-                context_enc, context_enc.shape[1] + self.max_gen_toks, primary_until
-            )
-
-            s = self.tok_decode(cont[0].tolist()[context_enc.shape[1] :])
-
-            for term in until:
-                s = s.split(term)[0]
-
-            # partial caching
-            self.cache_hook.add_partial("greedy_until", (context, until), s)
-
-            res.append(s)
-
-        return re_ord.get_original(res)
+        return self.generate(requests)


 class Task(abc.ABC):
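
A minimal sketch, not part of this commit, of the request convention the unified `generate` method now dispatches on: 2-tuples are greedy requests, 4-tuples are temperature-sampling requests. `unpack_generate_request` is a hypothetical helper, named here only for illustration.

```python
def unpack_generate_request(request):
    # 2-tuple: greedy decoding; 4-tuple: sample k candidates at a temperature
    if len(request) == 2:
        context, until = request
        return context, until, {}
    elif len(request) == 4:
        context, until, k, temperature = request
        return context, until, {"k": k, "temperature": temperature}
    raise AssertionError(f"Unexpected request arity: {len(request)}")

# unpack_generate_request(("2+2=", ["\n"]))          -> greedy decoding
# unpack_generate_request(("2+2=", ["\n"], 32, 0.3)) -> 32 samples at T=0.3
```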
@@ -881,7 +860,7 @@ class CachingLM:
 REQUEST_RETURN_LENGTHS = {
     "loglikelihood": 2,
     "greedy_until": None,
-    "multiple_temperature_sample_until": None,
+    "generate": None,
     "loglikelihood_rolling": None,
 }
...
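
Note that the partial-cache key recorded by `add_partial` above includes `k` and `temperature`, so greedy and sampled generations of the same prompt are cached separately. A hedged sketch of why this matters, assuming a hash scheme along the lines of CachingLM's (the exact scheme may differ):

```python
import hashlib
import json

def cache_key(req_type, args):
    # Assumed to mirror CachingLM's hashing of (request type, args)
    return hashlib.sha256(json.dumps([req_type, *args]).encode("utf-8")).hexdigest()

greedy_key = cache_key("generate", ("2+2=", ["\n"], 1, 0.0))
sampled_key = cache_key("generate", ("2+2=", ["\n"], 32, 0.3))
assert greedy_key != sampled_key  # the two runs never collide in the cache
```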
@@ -2,6 +2,7 @@ import collections
 import itertools
 import numpy as np
 import random
+import inspect

 import lm_eval.metrics
 import lm_eval.models
 import lm_eval.tasks
@@ -177,6 +178,7 @@ def evaluate(
     docs = {}
     docs_for_decontamination = collections.defaultdict(list)
+    task_to_description = {}

     # get lists of each type of request
     for task_name, task in task_dict_items:
@@ -203,6 +205,7 @@ def evaluate(
             if description_dict and task_name in description_dict
             else ""
         )
+        task_to_description[task_name] = description

         for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
@@ -215,7 +218,10 @@ def evaluate(
             ctx = task.fewshot_context(
                 doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
             )
-            reqs = task.construct_requests(doc, ctx)
+            if "description" in inspect.getfullargspec(task.construct_requests).args:
+                reqs = task.construct_requests(doc, ctx, description=description)
+            else:
+                reqs = task.construct_requests(doc, ctx)

             if not isinstance(reqs, (list, tuple)):
                 reqs = [reqs]
             for i, req in enumerate(reqs):
@@ -262,7 +268,11 @@ def evaluate(
         task = task_dict[task_name]
         doc = docs[(task_name, doc_id)]

-        metrics = task.process_results(doc, requests)
+        # stay backward compatible with tasks whose process_results does not accept a description
+        if "description" in inspect.getfullargspec(task.process_results).args:
+            metrics = task.process_results(doc, requests, task_to_description[task_name])
+        else:
+            metrics = task.process_results(doc, requests)
         for metric, value in metrics.items():
             vals[(task_name, metric)].append(value)
...
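
The `inspect.getfullargspec` check used in both call sites above is a general backward-compatibility pattern: forward the new keyword only if the callee declares it. A self-contained sketch with illustrative names, not code from this repository:

```python
import inspect

def call_with_optional_description(fn, *args, description=""):
    # Only forward `description` to functions that declare the parameter
    if "description" in inspect.getfullargspec(fn).args:
        return fn(*args, description=description)
    return fn(*args)

def legacy_task(doc, ctx):               # old-style signature
    return (doc, ctx)

def new_task(doc, ctx, description=""):  # new-style signature
    return (doc, ctx, description)

assert call_with_optional_description(legacy_task, "d", "c", description="x") == ("d", "c")
assert call_with_optional_description(new_task, "d", "c", description="x") == ("d", "c", "x")
```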
@@ -121,16 +121,23 @@ class HFLM(BaseLM):
         with torch.no_grad():
             return self.gpt2(inps)[0][:, :, :50257]

-    def _model_generate(self, context, max_length, eos_token_id, temperature=0.):
-        assert temperature >= 0.
-        if temperature == 0.:
-            return self.gpt2.generate(
-                context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False
-            )
-        else:
-            return self.gpt2.generate(
-                context, max_length=max_length, eos_token_id=eos_token_id, do_sample=True, temperature=temperature
-            )
+    def _model_generate(self, context, max_length, eos_token_id, k=1, temperature=0.):
+        assert isinstance(k, int) and k >= 1, f"Incorrect number of candidates to generate: {k}"
+        assert temperature >= 0., f"Negative sampling temperature: {temperature}"
+        # Whether to sample or to decode greedily
+        do_sample = temperature != 0.
+        if not do_sample:
+            # If decoding greedily, only sample once
+            assert k == 1, f"Decoding greedily but {k} generations requested"
+        if k > 1:
+            # Repeat the context row k times so one call yields k candidates
+            context = context.expand(k, context.shape[1])
+        return self.gpt2.generate(
+            context, max_length=max_length, eos_token_id=eos_token_id,
+            do_sample=do_sample, temperature=temperature
+        )

 # for backwards compatibility
...
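
A standalone sketch of what the new `_model_generate` does when k > 1 with Hugging Face `generate`: the single context row is expanded to k identical rows, and sampling then yields one candidate per row. The model name and `pad_token_id` handling here are illustrative assumptions, not taken from the diff:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

context = tok("Question: 2 + 2 = ", return_tensors="pt").input_ids  # shape (1, seq)
k, temperature = 4, 0.3
context = context.expand(k, context.shape[1])                       # shape (k, seq)

out = model.generate(
    context,
    max_length=context.shape[1] + 16,
    eos_token_id=tok.encode("\n")[0],
    do_sample=temperature != 0.0,
    temperature=temperature,
    pad_token_id=tok.eos_token_id,  # gpt2 defines no pad token
)
candidates = [tok.decode(row[context.shape[1]:]) for row in out]    # k continuations
```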
@@ -157,7 +157,6 @@ TASK_REGISTRY = {
     "mutual_plus": mutual.MuTualPlus,
     # math
     "math_algebra": hendrycks_math.MathAlgebra,
-    "math_algebra_maj@k": hendrycks_math.MathAlgebraMaj,
     "math_counting_and_prob": hendrycks_math.MathCountingAndProbability,
     "math_geometry": hendrycks_math.MathGeometry,
     "math_intermediate_algebra": hendrycks_math.MathIntermediateAlgebra,
...
@@ -27,6 +27,8 @@ _CITATION = """
 class Math(Task):
     DATASET_PATH = inspect.getfile(lm_eval.datasets.hendrycks_math.hendrycks_math)
     DATASET_NAME = None
+    MAJORITY_VOTING = "majority_voting"
+    SAMPLING_TEMPERATURE = "sampling_temperature"

     def has_training_docs(self):
         return True
@@ -62,21 +64,72 @@ class Math(Task):
     def doc_to_target(self, doc):
         return " " + doc["solution"]

-    def construct_requests(self, doc, ctx):
-        return rf.greedy_until(ctx, ["\n"])
+    def parse_description(self, description):
+        """description is a string of comma-separated key=value pairs,
+        e.g. "majority_voting=32,sampling_temperature=1.0"
+        """
+        parsed_dict = {}
+        for term in description.split(","):
+            if not term.strip():
+                continue
+            key, value = term.split("=")
+            parsed_dict[key] = value
+        return parsed_dict
+
+    def construct_requests(self, doc, ctx, description=""):
+        if not description.strip():
+            return rf.generate(ctx, ["\n"])
+        parsed_description = self.parse_description(description=description)
+        majority_voting_value = int(parsed_description.get(self.MAJORITY_VOTING, 1))
+        sampling_temperature_value = float(parsed_description.get(self.SAMPLING_TEMPERATURE, 1.0))
+        return rf.generate(ctx, ["\n"],
+                           majority_voting_value, sampling_temperature_value)
+
+    def get_pure_answer(self, candidate):
+        # Extract the text between the first and last "$", if both are present
+        indices = [pos for pos, char in enumerate(candidate) if char == "$"]
+        if len(indices) <= 1:
+            return candidate
+        return candidate[indices[0] + 1 : indices[-1]]
+
+    def majority_vote(self, candidates):
+        answers = []
+        for candidate in candidates:
+            answer = self.get_pure_answer(candidate)
+            try:
+                answer = self.remove_boxed(self.last_boxed_only_string(answer))
+            except Exception:
+                answer = None
+            answers.append(answer)
+        # Tally the votes and elect the most frequent non-None answer
+        answer_votes = {}
+        for answer in answers:
+            answer_votes[answer] = answer_votes.get(answer, 0) + 1
+        max_vote = 0
+        elected = None
+        for answer, vote in answer_votes.items():
+            if vote > max_vote and answer is not None:
+                elected = answer
+                max_vote = vote
+        return elected

-    def process_results(self, doc, results):
+    def process_results(self, doc, results, description=""):
         retval = 0
-        indices = [pos for pos, char in enumerate(results[0]) if char == "$"]
-        if len(indices) <= 1:
-            answer = results[0]
-        else:
-            answer = results[0][indices[0] + 1 : indices[-1]]
+        if description == "":
+            answer = self.get_pure_answer(results[0])
+        elif self.MAJORITY_VOTING in self.parse_description(description):
+            answer = self.majority_vote(results[0])
+        else:
+            raise AssertionError
         if self.is_equiv(
             answer, self.remove_boxed(self.last_boxed_only_string(doc["solution"]))
         ):
             retval = 1
         return {"acc": retval}

     def aggregation(self):
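
A worked illustration, not from the diff, of the two helpers added above, assuming a `Math` task instance named `task`:

```python
desc = "majority_voting=3,sampling_temperature=1.0"
task.parse_description(desc)
# -> {"majority_voting": "3", "sampling_temperature": "1.0"}  (values stay strings)

candidates = [
    "The answer is $\\boxed{4}$",
    "So we get $\\boxed{4}$",
    "Therefore $\\boxed{5}$",
]
task.majority_vote(candidates)
# Each candidate is normalized via get_pure_answer and remove_boxed to
# "4", "4", "5"; "4" wins the vote 2-1, so the elected answer is "4".
```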
@@ -286,48 +339,6 @@ class MathAlgebra(Math):
     DATASET_NAME = "algebra"


-class MathAlgebraMaj(Math):
-    VERSION = 1
-    DATASET_NAME = "algebra"
-
-    def construct_requests(self, doc, ctx):
-        return rf.multiple_temperature_sample_until(ctx, ["\n"])
-
-    def process_results(self, doc, results):
-        retval = 0
-        candidates = results[0]
-        answers = []
-        for candidate in candidates:
-            indices = [pos for pos, char in enumerate(candidate) if char == "$"]
-            if len(indices) <= 1:
-                answer = candidate
-            else:
-                answer = candidate[indices[0] + 1 : indices[-1]]
-            try:
-                answer = self.remove_boxed(self.last_boxed_only_string(answer))
-            except:
-                answer = None
-            answers.append(answer)
-        answer_votes = {}
-        for answer in answers:
-            answer_votes[answer] = answer_votes.get(answer, 0) + 1
-        max_vote = 0
-        elected = None
-        for answer, vote in answer_votes.items():
-            if vote > max_vote and answer is not None:
-                elected = answer
-                max_vote = vote
-        if self.is_equiv(
-            elected, self.remove_boxed(self.last_boxed_only_string(doc["solution"]))
-        ):
-            retval = 1
-        return {"acc": retval}


 class MathCountingAndProbability(Math):
     VERSION = 1
     DATASET_NAME = "counting_and_probability"
...
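
With the dedicated `math_algebra_maj@k` registry entry removed, maj@k behavior is driven entirely through `description_dict`. A hedged end-to-end sketch: the `evaluate` signature is abbreviated, and `my_lm` is a placeholder for an instantiated `BaseLM` subclass:

```python
import lm_eval.tasks
import lm_eval.evaluator

results = lm_eval.evaluator.evaluate(
    lm=my_lm,
    task_dict=lm_eval.tasks.get_task_dict(["math_algebra"]),
    num_fewshot=0,
    limit=None,
    # Parsed by Math.parse_description into k=32 samples at temperature 1.0,
    # with answers combined by majority vote
    description_dict={"math_algebra": "majority_voting=32,sampling_temperature=1.0"},
)
```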