Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
6fb8dde6
Commit
6fb8dde6
authored
Sep 26, 2024
by
Baber
Browse files
fix `cost_estimate`
parent
b2bf7bc4
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
10 deletions
+12
-10
scripts/cost_estimate.py
scripts/cost_estimate.py
+12
-10
No files found.
scripts/cost_estimate.py
View file @
6fb8dde6
...
...
@@ -2,7 +2,7 @@ import random
import
transformers
from
lm_eval
import
evaluator
,
tasks
from
lm_eval
import
evaluator
from
lm_eval.api.model
import
LM
...
...
@@ -11,6 +11,8 @@ class DryrunLM(LM):
self
.
tokencost
=
0
self
.
tokenizer
=
transformers
.
GPT2TokenizerFast
.
from_pretrained
(
"gpt2"
)
self
.
tokenizer
.
pad_token
=
"<|endoftext|>"
self
.
_rank
=
0
self
.
_world_size
=
1
@
classmethod
def
create_from_arg_string
(
cls
,
arg_string
):
...
...
@@ -18,21 +20,21 @@ class DryrunLM(LM):
def
loglikelihood
(
self
,
requests
):
res
=
[]
for
ctx
,
cont
in
requests
:
for
ctx
,
cont
in
[
req
.
args
for
req
in
requests
]:
res
.
append
((
-
random
.
random
(),
False
))
self
.
tokencost
+=
len
(
self
.
tokenizer
.
tokenize
(
ctx
+
cont
))
# +1 for API models as they require at least one gen token
self
.
tokencost
+=
len
(
self
.
tokenizer
.
tokenize
(
ctx
+
cont
))
+
1
return
res
def
generate_until
(
self
,
requests
):
res
=
[]
for
ctx
,
_
in
requests
:
for
ctx
,
gen_kwargs
in
[
reg
.
args
for
reg
in
requests
]
:
res
.
append
(
"lol"
)
# assume worst case - generates until
256
self
.
tokencost
+=
len
(
self
.
tokenizer
.
tokenize
(
ctx
))
+
256
max_new
=
gen_kwargs
.
get
(
"max_gen_toks"
,
256
)
# assume worst case - generates until
max_new tokens
self
.
tokencost
+=
len
(
self
.
tokenizer
.
tokenize
(
ctx
))
+
max_new
return
res
...
...
@@ -54,8 +56,8 @@ def main():
for
taskname
in
task_list
.
split
(
","
):
lm
.
tokencost
=
0
evaluator
.
simple_evaluate
(
l
m
=
lm
,
task
_dict
=
{
taskname
:
tasks
.
get_task
(
taskname
)()}
,
m
odel
=
lm
,
task
s
=
[
taskname
]
,
num_fewshot
=
0
,
limit
=
None
,
bootstrap_iters
=
10
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment