gaoqiong / lm-evaluation-harness

Commit efbe6e7f authored Apr 04, 2021 by Leo Gao
Implement partial caching

Now, if a run gets interrupted halfway through, you can easily resume it.
parent 8fe59e59

Showing 3 changed files with 50 additions and 6 deletions:

  lm_eval/base.py          +27 -0
  lm_eval/models/gpt2.py   +10 -1
  lm_eval/models/gpt3.py   +13 -5
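As a rough orientation for the commit message above, here is a minimal usage sketch (not part of this diff) of how the partial cache is consumed. CachingLM and GPT2LM are real classes in this repo; the cache path and the request list below are made up for illustration.

    # Sketch only: the cache path and the requests are hypothetical examples.
    from lm_eval.base import CachingLM
    from lm_eval.models.gpt2 import GPT2LM

    lm = CachingLM(GPT2LM(device="cuda"), "lm_cache/gpt2.db")

    # With this commit, each answer is written to lm_cache/gpt2.db the moment the
    # model produces it (via the new CacheHook), not only after the whole call
    # returns. If this run is killed partway through, rerunning the same script
    # serves the already-computed requests from the cache and only recomputes
    # the remainder.
    results = lm.loglikelihood([
        ("The capital of France is", " Paris"),
        ("2 + 2 =", " 4"),
    ])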
lm_eval/base.py

@@ -6,6 +6,9 @@ from lm_eval.metrics import mean
 class LM(abc.ABC):
+    def __init__(self):
+        self.cache_hook = CacheHook(None)
+
     @abc.abstractmethod
     def loglikelihood(self, requests):
         """Compute log-likelihood of generating a continuation from a context.
@@ -60,6 +63,9 @@ class LM(abc.ABC):
         """
         return cls()
 
+    def set_cache_hook(self, cache_hook):
+        self.cache_hook = cache_hook
+
 
 class Task(abc.ABC):
     """A task represents an entire benchmark including its dataset, problems,
@@ -251,6 +257,21 @@ def hash_args(attr, args):
     return hashlib.sha256(dat.encode('utf-8')).hexdigest()
 
 
+class CacheHook:
+    def __init__(self, cachinglm):
+        if cachinglm is None:
+            self.dbdict = None
+            return
+
+        self.dbdict = cachinglm.dbdict
+
+    def add_partial(self, attr, req, res):
+        if self.dbdict is None:
+            return
+        hsh = hash_args(attr, req)
+        self.dbdict[hsh] = res
+
+
 class CachingLM:
     def __init__(self, lm, cache_db):
         self.lm = lm
@@ -258,6 +279,9 @@ class CachingLM:
         os.makedirs(os.path.dirname(cache_db), exist_ok=True)
         self.dbdict = SqliteDict(cache_db, autocommit=True)
 
+        # add hook to lm
+        lm.set_cache_hook(self.get_cache_hook())
+
     def __getattr__(self, attr):
         def fn(requests):
             res = []
@@ -293,6 +317,9 @@ class CachingLM:
             return res
         return fn
 
+    def get_cache_hook(self):
+        return CacheHook(self)
+
 
 class Request:
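To make the mechanism above concrete, here is a self-contained sketch (not part of the commit) of the hook wiring: a plain dict stands in for the SqliteDict cache, and a toy model stands in for GPT2LM/GPT3LM. The JSON-based construction of dat in hash_args is an assumption, since the hunk above only shows its final hashing line.

    import hashlib
    import json


    def hash_args(attr, args):
        # Keying scheme assumed from context: method name plus the request args,
        # serialized and hashed (only the hexdigest line appears in the hunk).
        dat = json.dumps([attr] + list(args))
        return hashlib.sha256(dat.encode('utf-8')).hexdigest()


    class CacheHook:
        # Same shape as the CacheHook added above.
        def __init__(self, cachinglm):
            self.dbdict = None if cachinglm is None else cachinglm.dbdict

        def add_partial(self, attr, req, res):
            if self.dbdict is None:
                return  # bare LM with the default hook: caching is a no-op
            self.dbdict[hash_args(attr, req)] = res


    class ToyCachingLM:
        # Stands in for CachingLM; a dict replaces SqliteDict(cache_db, autocommit=True).
        def __init__(self, lm):
            self.lm = lm
            self.dbdict = {}
            lm.set_cache_hook(CacheHook(self))  # same wiring as CachingLM.__init__


    class ToyLM:
        # Stands in for a model backend such as GPT2LM.
        def __init__(self):
            self.cache_hook = CacheHook(None)  # default no-op hook, as in LM.__init__

        def set_cache_hook(self, cache_hook):
            self.cache_hook = cache_hook

        def loglikelihood(self, requests):
            res = []
            for context, continuation in requests:
                answer = (-float(len(continuation)), True)  # fake (logprob, is_greedy)
                # Report each result as soon as it is known, exactly like the
                # add_partial calls added to gpt2.py and gpt3.py below.
                self.cache_hook.add_partial("loglikelihood", (context, continuation), answer)
                res.append(answer)
            return res


    lm = ToyLM()
    wrapper = ToyCachingLM(lm)
    lm.loglikelihood([("ctx", " a"), ("ctx", " b")])
    assert len(wrapper.dbdict) == 2  # both answers were cached immediately, mid-run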
lm_eval/models/gpt2.py

@@ -10,6 +10,7 @@ class GPT2LM(LM):
     MAX_GEN_TOKS = 256
 
     def __init__(self, device=None, pretrained='gpt2'):
+        super().__init__()
         if device:
             self.device = torch.device(device)
         else:
@@ -69,7 +70,12 @@ class GPT2LM(LM):
                 logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [batch, seq]
 
-                res.append((float(logits.sum()), bool(max_equal)))
+                answer = (float(logits.sum()), bool(max_equal))
+
+                # partial caching
+                self.cache_hook.add_partial("loglikelihood", (context, continuation), answer)
+
+                res.append(answer)
 
         return reord.get_original(res)
@@ -103,6 +109,9 @@ class GPT2LM(LM):
             for term in until:
                 s = s.split(term)[0]
 
+            # partial caching
+            self.cache_hook.add_partial("greedy_until", (context, until), s)
+
             res.append(s)
 
         return reord.get_original(res)
lm_eval/models/gpt3.py

@@ -48,6 +48,7 @@ class GPT3LM(LM):
         :param truncate: bool
             Truncate input if too long (if False and input is too long, throw error)
         """
+        super().__init__()
 
         import openai
         self.engine = engine
         self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
@@ -104,8 +105,13 @@ class GPT3LM(LM):
                 logprobs=10,
             )
 
-            for resp, ctxlen in zip(response.choices, ctxlens):
-                res.append(get_result(resp, ctxlen))
+            for resp, ctxlen, (context, continuation) in zip(response.choices, ctxlens, chunk):
+                answer = get_result(resp, ctxlen)
+
+                res.append(answer)
+
+                # partial caching
+                self.cache_hook.add_partial("loglikelihood", (context, continuation), answer)
 
         return reord.get_original(res)
@@ -149,13 +155,15 @@ class GPT3LM(LM):
                 stop=until
             )
 
-            for resp in response.choices:
+            for resp, (context, until) in zip(response.choices, chunk):
                 s = resp['text']
 
                 for term in until:
                     s = s.split(term)[0]
 
+                # partial caching
+                self.cache_hook.add_partial("greedy_until", (context, until), s)
+
                 res.append(s)
 
-            return reord.get_original(res)
+        return reord.get_original(res)