Commit 99fdec0c authored by Albert Jiang

Implement temperature sampling and maj@k for gpt2

parent 8c048e26
@@ -152,7 +152,7 @@ class BaseLM(LM):
         pass
 
     @abstractmethod
-    def _model_generate(self, context, max_length, eos_token_id):
+    def _model_generate(self, context, max_length, eos_token_id, temperature=0.):
         pass
 
     @abstractmethod
@@ -329,6 +329,44 @@ class BaseLM(LM):
         return re_ord.get_original(res)
 
+    def multiple_temperature_sample_until(self, requests, k=32, temperature=0.3):
+        res = []
+
+        def _collate(x):
+            toks = self.tok_encode(x[0])
+            return len(toks), x[0]
+
+        re_ord = utils.Reorderer(requests, _collate)
+
+        for context, until in tqdm(re_ord.get_reordered()):
+            if isinstance(until, str):
+                until = [until]
+
+            (primary_until,) = self.tok_encode(until[0])
+
+            context_enc = torch.tensor(
+                [self.tok_encode(context)[self.max_gen_toks - self.max_length :]]
+            ).to(self.device)
+            assert context_enc.shape[0] == 1
+            context_enc = context_enc.expand(k, context_enc.shape[1])
+
+            cont = self._model_generate(
+                context_enc, context_enc.shape[1] + self.max_gen_toks, primary_until,
+                temperature=temperature
+            )
+
+            generated_tokens = cont[:, context_enc.shape[1]:]
+            s = [self.tok_decode(candidate) for candidate in generated_tokens]
+
+            for term in until:
+                s = [candidate.split(term)[0] for candidate in s]
+
+            # partial caching
+            self.cache_hook.add_partial("multiple_temperature_sample_until", (context, until, k, temperature), s)
+
+            res.append(s)
+
+        return re_ord.get_original(res)
+
     def greedy_until(self, requests):
         # TODO: implement fully general `until` that handles until that are
         # multiple tokens or that span multiple tokens correctly
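For reference, the core trick in `multiple_temperature_sample_until` is to encode the prompt once and expand it to `k` identical rows, so a single batched `_model_generate` call yields `k` independent samples. A minimal standalone sketch of that batching step (the tensor values are illustrative, not from the commit):

```python
import torch

# One encoded prompt, shape (1, seq_len); token ids are made up for illustration.
context_enc = torch.tensor([[464, 3280, 318]])
assert context_enc.shape[0] == 1

# Expand to k identical rows. expand() returns a view, so the prompt is not
# copied k times in memory; the k generations then diverge because sampling
# draws tokens independently for each batch element.
k = 4
batched = context_enc.expand(k, context_enc.shape[1])
print(batched.shape)  # torch.Size([4, 3])
```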
@@ -843,6 +881,7 @@ class CachingLM:
 
 REQUEST_RETURN_LENGTHS = {
     "loglikelihood": 2,
     "greedy_until": None,
+    "multiple_temperature_sample_until": None,
     "loglikelihood_rolling": None,
 }
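Registering the new request type in `REQUEST_RETURN_LENGTHS` lets `CachingLM` treat it like `greedy_until`: each request resolves to a single value, cached under its full argument tuple. The `add_partial` call above includes `k` and `temperature` in that tuple, so samples drawn under different settings cannot collide. A rough illustration of why those arguments belong in the key (this is not `CachingLM`'s actual hashing code):

```python
import hashlib
import json

def cache_key(request_type, args):
    # Hash the request type together with *all* of its arguments; because
    # k and temperature are part of args, changing either yields a new key.
    payload = request_type + json.dumps(args)
    return hashlib.sha256(payload.encode()).hexdigest()

key_a = cache_key("multiple_temperature_sample_until", ["some prompt", ["\n"], 32, 0.3])
key_b = cache_key("multiple_temperature_sample_until", ["some prompt", ["\n"], 32, 0.7])
assert key_a != key_b  # different temperature, different cache entry
```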
@@ -121,10 +121,16 @@ class HFLM(BaseLM):
         with torch.no_grad():
             return self.gpt2(inps)[0][:, :, :50257]
 
-    def _model_generate(self, context, max_length, eos_token_id):
-        return self.gpt2.generate(
-            context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False
-        )
+    def _model_generate(self, context, max_length, eos_token_id, temperature=0.):
+        assert temperature >= 0.
+        if temperature == 0.:
+            return self.gpt2.generate(
+                context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False
+            )
+        else:
+            return self.gpt2.generate(
+                context, max_length=max_length, eos_token_id=eos_token_id, do_sample=True, temperature=temperature
+            )
 
 
 # for backwards compatibility
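The `HFLM` override dispatches on `temperature`: exactly `0.` keeps the old deterministic greedy decoding, while any positive value switches `generate` into sampling mode, where logits are divided by the temperature before the softmax (so lower values concentrate probability on the top tokens). A self-contained sketch of the same dispatch against the Hugging Face API (the `"gpt2"` checkpoint is just the smallest convenient choice):

```python
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

def model_generate(context, max_length, eos_token_id, temperature=0.0):
    if temperature == 0.0:
        # Deterministic: argmax token at every step.
        return model.generate(
            context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False
        )
    # Stochastic: logits / temperature, then sample from the softmax.
    return model.generate(
        context, max_length=max_length, eos_token_id=eos_token_id,
        do_sample=True, temperature=temperature,
    )

context = tokenizer("The answer is", return_tensors="pt").input_ids
out = model_generate(context, max_length=16,
                     eos_token_id=tokenizer.eos_token_id, temperature=0.3)
print(tokenizer.decode(out[0]))
```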
@@ -157,6 +157,7 @@ TASK_REGISTRY = {
     "mutual_plus": mutual.MuTualPlus,
     # math
     "math_algebra": hendrycks_math.MathAlgebra,
+    "math_algebra_maj@k": hendrycks_math.MathAlgebraMaj,
     "math_counting_and_prob": hendrycks_math.MathCountingAndProbability,
     "math_geometry": hendrycks_math.MathGeometry,
     "math_intermediate_algebra": hendrycks_math.MathIntermediateAlgebra,
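Because the maj@k variant goes through the same `TASK_REGISTRY`, it is selected by name exactly like the existing tasks, e.g. `--tasks math_algebra_maj@k` on the evaluation CLI. A minimal lookup sketch:

```python
from lm_eval.tasks import TASK_REGISTRY

# The registry maps task names to Task classes; the new entry resolves
# to the majority-vote variant defined later in this commit.
task_class = TASK_REGISTRY["math_algebra_maj@k"]
task = task_class()  # an instance of hendrycks_math.MathAlgebraMaj
```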
@@ -286,6 +286,48 @@ class MathAlgebra(Math):
     DATASET_NAME = "algebra"
 
 
+class MathAlgebraMaj(Math):
+    VERSION = 1
+    DATASET_NAME = "algebra"
+
+    def construct_requests(self, doc, ctx):
+        return rf.multiple_temperature_sample_until(ctx, ["\n"])
+
+    def process_results(self, doc, results):
+        retval = 0
+        candidates = results[0]
+
+        answers = []
+        for candidate in candidates:
+            indices = [pos for pos, char in enumerate(candidate) if char == "$"]
+            if len(indices) <= 1:
+                answer = candidate
+            else:
+                answer = candidate[indices[0] + 1 : indices[-1]]
+            try:
+                answer = self.remove_boxed(self.last_boxed_only_string(answer))
+            except:
+                answer = None
+            answers.append(answer)
+
+        answer_votes = {}
+        for answer in answers:
+            answer_votes[answer] = answer_votes.get(answer, 0) + 1
+
+        max_vote = 0
+        elected = None
+        for answer, vote in answer_votes.items():
+            if vote > max_vote and answer is not None:
+                elected = answer
+                max_vote = vote
+
+        if self.is_equiv(
+            elected, self.remove_boxed(self.last_boxed_only_string(doc["solution"]))
+        ):
+            retval = 1
+        return {"acc": retval}
+
+
 class MathCountingAndProbability(Math):
     VERSION = 1
     DATASET_NAME = "counting_and_probability"