Commit 99fdec0c authored by Albert Jiang

implementing temperature sampling and maj@k for gpt2
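In maj@k evaluation, the model samples k candidate answers per problem at a nonzero temperature, and a problem counts as solved when the majority-voted answer matches the reference. The changes below wire this through the BaseLM interface, the GPT-2 (HFLM) backend, the request cache, the task registry, and a new MATH algebra task.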

parent 8c048e26
@@ -152,7 +152,7 @@ class BaseLM(LM):
         pass
 
     @abstractmethod
-    def _model_generate(self, context, max_length, eos_token_id):
+    def _model_generate(self, context, max_length, eos_token_id, temperature=0.):
         pass
 
     @abstractmethod
@@ -329,6 +329,44 @@ class BaseLM(LM):
 
         return re_ord.get_original(res)
 
+    def multiple_temperature_sample_until(self, requests, k=32, temperature=0.3):
+        res = []
+
+        def _collate(x):
+            toks = self.tok_encode(x[0])
+            return len(toks), x[0]
+
+        re_ord = utils.Reorderer(requests, _collate)
+        for context, until in tqdm(re_ord.get_reordered()):
+            if isinstance(until, str):
+                until = [until]
+
+            # assumes the primary stop sequence encodes to a single token
+            (primary_until,) = self.tok_encode(until[0])
+
+            # truncate the context from the left to leave room for generation
+            context_enc = torch.tensor(
+                [self.tok_encode(context)[self.max_gen_toks - self.max_length :]]
+            ).to(self.device)
+            assert context_enc.shape[0] == 1
+            # duplicate the single context into a batch of k identical rows
+            # so that one generate call draws all k samples
+            context_enc = context_enc.expand(k, context_enc.shape[1])
+
+            cont = self._model_generate(
+                context_enc,
+                context_enc.shape[1] + self.max_gen_toks,
+                primary_until,
+                temperature=temperature,
+            )
+
+            generated_tokens = cont[:, context_enc.shape[1] :]
+            s = [self.tok_decode(candidate) for candidate in generated_tokens]
+
+            # cut each candidate at the first occurrence of any stop sequence
+            for term in until:
+                s = [candidate.split(term)[0] for candidate in s]
+
+            # partial caching
+            self.cache_hook.add_partial(
+                "multiple_temperature_sample_until", (context, until, k, temperature), s
+            )
+            res.append(s)
+
+        return re_ord.get_original(res)
+
     def greedy_until(self, requests):
         # TODO: implement fully general `until` that handles until that are
         # multiple tokens or that span multiple tokens correctly
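The new method relies on `torch.Tensor.expand` to turn one encoded context into a batch of k identical rows, so a single sampled `generate` call yields all k candidates at once. A minimal standalone sketch of that pattern, assuming a plain transformers GPT-2 rather than the harness wrapper (the prompt string is illustrative):

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

k = 4
context_enc = torch.tensor([tokenizer.encode("Question: 1 + 1 = ?\nAnswer:")])
# one row expanded to k identical rows; expand creates a view, not k copies
context_enc = context_enc.expand(k, context_enc.shape[1])

# one sampled generate call produces k independent candidates
out = model.generate(
    context_enc,
    max_length=context_enc.shape[1] + 32,
    do_sample=True,
    temperature=0.3,
    pad_token_id=tokenizer.eos_token_id,
)
candidates = [tokenizer.decode(row[context_enc.shape[1]:]) for row in out]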
@@ -843,6 +881,7 @@ class CachingLM:
 REQUEST_RETURN_LENGTHS = {
     "loglikelihood": 2,
     "greedy_until": None,
+    "multiple_temperature_sample_until": None,
     "loglikelihood_rolling": None,
 }
@@ -121,10 +121,16 @@ class HFLM(BaseLM):
         with torch.no_grad():
             return self.gpt2(inps)[0][:, :, :50257]
 
-    def _model_generate(self, context, max_length, eos_token_id):
-        return self.gpt2.generate(
-            context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False
-        )
+    def _model_generate(self, context, max_length, eos_token_id, temperature=0.):
+        assert temperature >= 0.
+        if temperature == 0.:
+            # temperature 0 selects deterministic greedy decoding
+            return self.gpt2.generate(
+                context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False
+            )
+        else:
+            return self.gpt2.generate(
+                context, max_length=max_length, eos_token_id=eos_token_id,
+                do_sample=True, temperature=temperature,
+            )
 
     # for backwards compatibility
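Temperature here has its standard meaning: logits are divided by `temperature` before the softmax, so values below 1 sharpen the distribution toward the greedy choice and the 0 limit recovers argmax decoding. A minimal sketch of the transform itself (not code from this commit):

import torch

def sample_token(logits: torch.Tensor, temperature: float) -> int:
    # the temperature == 0. branch mirrors the do_sample=False case above
    if temperature == 0.:
        return int(logits.argmax())
    probs = torch.softmax(logits / temperature, dim=-1)
    return int(torch.multinomial(probs, num_samples=1))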
@@ -157,6 +157,7 @@ TASK_REGISTRY = {
     "mutual_plus": mutual.MuTualPlus,
     # math
     "math_algebra": hendrycks_math.MathAlgebra,
+    "math_algebra_maj@k": hendrycks_math.MathAlgebraMaj,
     "math_counting_and_prob": hendrycks_math.MathCountingAndProbability,
     "math_geometry": hendrycks_math.MathGeometry,
     "math_intermediate_algebra": hendrycks_math.MathIntermediateAlgebra,
@@ -286,6 +286,48 @@ class MathAlgebra(Math):
     DATASET_NAME = "algebra"
 
+class MathAlgebraMaj(Math):
+    VERSION = 1
+    DATASET_NAME = "algebra"
+
+    def construct_requests(self, doc, ctx):
+        return rf.multiple_temperature_sample_until(ctx, ["\n"])
+
+    def process_results(self, doc, results):
+        retval = 0
+        candidates = results[0]
+
+        answers = []
+        for candidate in candidates:
+            # keep only the text between the first and last "$" delimiters,
+            # when such a delimited span exists
+            indices = [pos for pos, char in enumerate(candidate) if char == "$"]
+            if len(indices) <= 1:
+                answer = candidate
+            else:
+                answer = candidate[indices[0] + 1 : indices[-1]]
+            try:
+                answer = self.remove_boxed(self.last_boxed_only_string(answer))
+            except Exception:
+                # unparseable candidates are recorded but never elected
+                answer = None
+            answers.append(answer)
+
+        # majority vote over the parsed answers
+        answer_votes = {}
+        for answer in answers:
+            answer_votes[answer] = answer_votes.get(answer, 0) + 1
+        max_vote = 0
+        elected = None
+        for answer, vote in answer_votes.items():
+            if answer is not None and vote > max_vote:
+                elected = answer
+                max_vote = vote
+
+        if self.is_equiv(
+            elected, self.remove_boxed(self.last_boxed_only_string(doc["solution"]))
+        ):
+            retval = 1
+        return {"acc": retval}
+
 class MathCountingAndProbability(Math):
     VERSION = 1
     DATASET_NAME = "counting_and_probability"
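`process_results` above implements the election with a plain dict; the same rule can be stated more compactly with `collections.Counter`. A hypothetical equivalent helper, shown only to make the voting rule explicit:

from collections import Counter

def majority_vote(answers):
    # drop unparseable (None) candidates, then take the most common answer;
    # ties resolve to the answer encountered first, as in the loop above
    parsed = [a for a in answers if a is not None]
    return Counter(parsed).most_common(1)[0][0] if parsed else None

# e.g. majority_vote(["4", "4", None, "5"]) -> "4"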