Commit 6fb8dde6 authored by Baber

fix `cost_estimate`

parent b2bf7bc4
@@ -2,7 +2,7 @@ import random
 import transformers
-from lm_eval import evaluator, tasks
+from lm_eval import evaluator
 from lm_eval.api.model import LM
@@ -11,6 +11,8 @@ class DryrunLM(LM):
         self.tokencost = 0
         self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained("gpt2")
         self.tokenizer.pad_token = "<|endoftext|>"
+        self._rank = 0
+        self._world_size = 1
 
     @classmethod
     def create_from_arg_string(cls, arg_string):
@@ -18,21 +20,21 @@ class DryrunLM(LM):
     def loglikelihood(self, requests):
         res = []
 
-        for ctx, cont in requests:
+        for ctx, cont in [req.args for req in requests]:
             res.append((-random.random(), False))
-            self.tokencost += len(self.tokenizer.tokenize(ctx + cont))
+            # +1 for API models as they require at least one gen token
+            self.tokencost += len(self.tokenizer.tokenize(ctx + cont)) + 1
 
         return res
 
     def generate_until(self, requests):
         res = []
 
-        for ctx, _ in requests:
+        for ctx, gen_kwargs in [reg.args for reg in requests]:
            res.append("lol")
 
-            # assume worst case - generates until 256
-            self.tokencost += len(self.tokenizer.tokenize(ctx)) + 256
+            max_new = gen_kwargs.get("max_gen_toks", 256)
+            # assume worst case - generates until max_new tokens
+            self.tokencost += len(self.tokenizer.tokenize(ctx)) + max_new
 
         return res
@@ -54,8 +56,8 @@ def main():
     for taskname in task_list.split(","):
         lm.tokencost = 0
         evaluator.simple_evaluate(
-            lm=lm,
-            task_dict={taskname: tasks.get_task(taskname)()},
+            model=lm,
+            tasks=[taskname],
             num_fewshot=0,
             limit=None,
             bootstrap_iters=10,
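For context, DryrunLM never calls a real model: it only accumulates how many tokens each request would consume (prompt plus continuation plus one generation token for loglikelihood requests, and prompt plus max_gen_toks for generation requests). Below is a minimal sketch of how such a running token total could be turned into a rough price figure; the estimate_cost helper and the flat per-1k-token rate are illustrative assumptions, not part of this commit or of lm-eval's API.

# Sketch only: convert a dry-run token count into a rough dollar figure.
# The rate below is a placeholder, not a real provider price.

def estimate_cost(tokencost: int, usd_per_1k_tokens: float = 0.002) -> float:
    """Rough upper-bound cost for `tokencost` tokens at a flat per-1k rate."""
    return tokencost / 1000 * usd_per_1k_tokens

if __name__ == "__main__":
    # Hypothetical value of lm.tokencost after one simple_evaluate() dry run.
    dryrun_total = 1_234_567
    print(f"~{dryrun_total} tokens -> ~${estimate_cost(dryrun_total):.2f}")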