Commit 6fb8dde6 authored by Baber

fix `cost_estimate`

parent b2bf7bc4
@@ -2,7 +2,7 @@ import random
 import transformers
-from lm_eval import evaluator, tasks
+from lm_eval import evaluator
 from lm_eval.api.model import LM
@@ -11,6 +11,8 @@ class DryrunLM(LM):
         self.tokencost = 0
         self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained("gpt2")
         self.tokenizer.pad_token = "<|endoftext|>"
+        self._rank = 0
+        self._world_size = 1

     @classmethod
     def create_from_arg_string(cls, arg_string):
@@ -18,21 +20,21 @@ class DryrunLM(LM):
     def loglikelihood(self, requests):
         res = []

-        for ctx, cont in requests:
+        for ctx, cont in [req.args for req in requests]:
             res.append((-random.random(), False))
-            self.tokencost += len(self.tokenizer.tokenize(ctx + cont))
+            # +1 for API models as they require at least one gen token
+            self.tokencost += len(self.tokenizer.tokenize(ctx + cont)) + 1

         return res

     def generate_until(self, requests):
         res = []

-        for ctx, _ in requests:
+        for ctx, gen_kwargs in [reg.args for reg in requests]:
             res.append("lol")
-            # assume worst case - generates until 256
-            self.tokencost += len(self.tokenizer.tokenize(ctx)) + 256
+            max_new = gen_kwargs.get("max_gen_toks", 256)
+            # assume worst case - generates until max_new tokens
+            self.tokencost += len(self.tokenizer.tokenize(ctx)) + max_new

         return res

@@ -54,8 +56,8 @@ def main():
     for taskname in task_list.split(","):
         lm.tokencost = 0
         evaluator.simple_evaluate(
-            lm=lm,
-            task_dict={taskname: tasks.get_task(taskname)()},
+            model=lm,
+            tasks=[taskname],
             num_fewshot=0,
             limit=None,
             bootstrap_iters=10,
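For context on the change: requests now arrive as objects whose prompt/continuation pair is carried on `.args`, which is why the loops unpack `req.args` instead of iterating over bare tuples. Below is a minimal sketch of the same token-accounting logic, using a hypothetical `FakeRequest` stand-in (not the real `lm_eval` request class) so it runs without the harness:

```python
import transformers


class FakeRequest:
    """Hypothetical stand-in for the harness's request objects; only exposes .args."""

    def __init__(self, ctx, cont):
        self.args = (ctx, cont)


tokenizer = transformers.GPT2TokenizerFast.from_pretrained("gpt2")
requests = [FakeRequest("The capital of France is", " Paris")]

tokencost = 0
for ctx, cont in [req.args for req in requests]:
    # mirror the patched loglikelihood: prompt + continuation tokens,
    # plus 1 because API models bill at least one generated token
    tokencost += len(tokenizer.tokenize(ctx + cont)) + 1

print(tokencost)
```

The same pattern applies to `generate_until`, where `.args` carries `(ctx, gen_kwargs)` and the worst case assumed is `max_gen_toks` (default 256) generated tokens on top of the prompt.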