Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
d5cd9655
Commit
d5cd9655
authored
Feb 04, 2021
by
Leo Gao
Browse files
Implement caching
parent
1815286c
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
64 additions
and
2 deletions
+64
-2
lm_eval/base.py
lm_eval/base.py
+58
-1
main.py
main.py
+6
-1
No files found.
lm_eval/base.py
View file @
d5cd9655
...
...
@@ -235,9 +235,66 @@ def perplexity(items):
return
math
.
exp
(
-
mean
(
items
))
req_ret_lens
=
{
'loglikelihood'
:
2
'loglikelihood'
:
2
,
}
import
os
import
json
import
hashlib
from
sqlitedict
import
SqliteDict
def
hash_args
(
args
):
dat
=
b
""
for
arg
in
args
:
assert
isinstance
(
arg
,
str
)
or
isinstance
(
arg
,
int
)
dat
+=
str
(
arg
).
encode
()
dat
+=
b
"
\0
"
return
hashlib
.
sha256
(
dat
).
hexdigest
()
class
CachingLM
:
def
__init__
(
self
,
lm
,
cache_db
):
self
.
lm
=
lm
self
.
cache_db
=
cache_db
os
.
makedirs
(
os
.
path
.
dirname
(
cache_db
),
exist_ok
=
True
)
self
.
dbdict
=
SqliteDict
(
cache_db
,
autocommit
=
True
)
def
__getattr__
(
self
,
attr
):
def
fn
(
requests
):
res
=
[]
remaining_reqs
=
[]
# figure out which ones are cached and which ones are new
for
req
in
requests
:
hsh
=
attr
+
'_'
+
hash_args
(
req
)
if
hsh
in
self
.
dbdict
:
ob
=
self
.
dbdict
[
hsh
]
assert
ob
is
not
None
res
.
append
(
ob
)
else
:
res
.
append
(
None
)
remaining_reqs
.
append
(
req
)
# actually run the LM
rem_res
=
getattr
(
self
.
lm
,
attr
)(
remaining_reqs
)
# stick the new ones back into the list and also cache any of the new ones
resptr
=
0
for
req
,
r
in
zip
(
remaining_reqs
,
rem_res
):
while
res
[
resptr
]
is
not
None
:
resptr
+=
1
res
[
resptr
]
=
r
# caching
hsh
=
attr
+
'_'
+
hash_args
(
req
)
self
.
dbdict
[
hsh
]
=
r
return
res
return
fn
class
Request
:
def
__init__
(
self
,
type
,
args
,
index
=
None
):
...
...
main.py
View file @
d5cd9655
...
...
@@ -5,7 +5,7 @@ import random
import
itertools
import
collections
from
lm_eval
import
models
,
tasks
,
evaluator
from
lm_eval
import
models
,
tasks
,
evaluator
,
base
def
parse_args
():
...
...
@@ -18,14 +18,19 @@ def parse_args():
parser
.
add_argument
(
'--seed'
,
type
=
int
,
default
=
1234
)
parser
.
add_argument
(
'--output_path'
,
default
=
None
)
parser
.
add_argument
(
'--limit'
,
type
=
int
,
default
=
None
)
parser
.
add_argument
(
'--cache'
,
action
=
"store_true"
)
return
parser
.
parse_args
()
def
main
():
args
=
parse_args
()
random
.
seed
(
args
.
seed
)
np
.
random
.
seed
(
args
.
seed
)
lm
=
models
.
get_model
(
args
.
model
).
create_from_arg_string
(
args
.
model_args
)
if
args
.
cache
:
lm
=
base
.
CachingLM
(
lm
,
'lm_cache/'
+
args
.
model
+
'_'
+
args
.
model_args
.
replace
(
'='
,
'-'
).
replace
(
','
,
'_'
)
+
'.db'
)
if
args
.
tasks
==
"all_tasks"
:
task_names
=
tasks
.
ALL_TASKS
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment