Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
efbe6e7f
Commit
efbe6e7f
authored
Apr 04, 2021
by
Leo Gao
Browse files
Implement partial caching
Now, if a run gets interrupted halfway, you can easily resume
parent
8fe59e59
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
50 additions
and
6 deletions
+50
-6
lm_eval/base.py
lm_eval/base.py
+27
-0
lm_eval/models/gpt2.py
lm_eval/models/gpt2.py
+10
-1
lm_eval/models/gpt3.py
lm_eval/models/gpt3.py
+13
-5
No files found.
lm_eval/base.py
View file @
efbe6e7f
...
@@ -6,6 +6,9 @@ from lm_eval.metrics import mean
...
@@ -6,6 +6,9 @@ from lm_eval.metrics import mean
class
LM
(
abc
.
ABC
):
class
LM
(
abc
.
ABC
):
def
__init__
(
self
):
self
.
cache_hook
=
CacheHook
(
None
)
@
abc
.
abstractmethod
@
abc
.
abstractmethod
def
loglikelihood
(
self
,
requests
):
def
loglikelihood
(
self
,
requests
):
"""Compute log-likelihood of generating a continuation from a context.
"""Compute log-likelihood of generating a continuation from a context.
...
@@ -60,6 +63,9 @@ class LM(abc.ABC):
...
@@ -60,6 +63,9 @@ class LM(abc.ABC):
"""
"""
return
cls
()
return
cls
()
def
set_cache_hook
(
self
,
cache_hook
):
self
.
cache_hook
=
cache_hook
class
Task
(
abc
.
ABC
):
class
Task
(
abc
.
ABC
):
"""A task represents an entire benchmark including its dataset, problems,
"""A task represents an entire benchmark including its dataset, problems,
...
@@ -251,6 +257,21 @@ def hash_args(attr, args):
...
@@ -251,6 +257,21 @@ def hash_args(attr, args):
return
hashlib
.
sha256
(
dat
.
encode
(
'utf-8'
)).
hexdigest
()
return
hashlib
.
sha256
(
dat
.
encode
(
'utf-8'
)).
hexdigest
()
class
CacheHook
:
def
__init__
(
self
,
cachinglm
):
if
cachinglm
is
None
:
self
.
dbdict
=
None
return
self
.
dbdict
=
cachinglm
.
dbdict
def
add_partial
(
self
,
attr
,
req
,
res
):
if
self
.
dbdict
is
None
:
return
hsh
=
hash_args
(
attr
,
req
)
self
.
dbdict
[
hsh
]
=
res
class
CachingLM
:
class
CachingLM
:
def
__init__
(
self
,
lm
,
cache_db
):
def
__init__
(
self
,
lm
,
cache_db
):
self
.
lm
=
lm
self
.
lm
=
lm
...
@@ -258,6 +279,9 @@ class CachingLM:
...
@@ -258,6 +279,9 @@ class CachingLM:
os
.
makedirs
(
os
.
path
.
dirname
(
cache_db
),
exist_ok
=
True
)
os
.
makedirs
(
os
.
path
.
dirname
(
cache_db
),
exist_ok
=
True
)
self
.
dbdict
=
SqliteDict
(
cache_db
,
autocommit
=
True
)
self
.
dbdict
=
SqliteDict
(
cache_db
,
autocommit
=
True
)
# add hook to lm
lm
.
set_cache_hook
(
self
.
get_cache_hook
())
def
__getattr__
(
self
,
attr
):
def
__getattr__
(
self
,
attr
):
def
fn
(
requests
):
def
fn
(
requests
):
res
=
[]
res
=
[]
...
@@ -293,6 +317,9 @@ class CachingLM:
...
@@ -293,6 +317,9 @@ class CachingLM:
return
res
return
res
return
fn
return
fn
def
get_cache_hook
(
self
):
return
CacheHook
(
self
)
class
Request
:
class
Request
:
...
...
lm_eval/models/gpt2.py
View file @
efbe6e7f
...
@@ -10,6 +10,7 @@ class GPT2LM(LM):
...
@@ -10,6 +10,7 @@ class GPT2LM(LM):
MAX_GEN_TOKS
=
256
MAX_GEN_TOKS
=
256
def
__init__
(
self
,
device
=
None
,
pretrained
=
'gpt2'
):
def
__init__
(
self
,
device
=
None
,
pretrained
=
'gpt2'
):
super
().
__init__
()
if
device
:
if
device
:
self
.
device
=
torch
.
device
(
device
)
self
.
device
=
torch
.
device
(
device
)
else
:
else
:
...
@@ -69,7 +70,12 @@ class GPT2LM(LM):
...
@@ -69,7 +70,12 @@ class GPT2LM(LM):
logits
=
torch
.
gather
(
logits
,
2
,
cont_toks
.
unsqueeze
(
-
1
)).
squeeze
(
-
1
)
# [batch, seq]
logits
=
torch
.
gather
(
logits
,
2
,
cont_toks
.
unsqueeze
(
-
1
)).
squeeze
(
-
1
)
# [batch, seq]
res
.
append
((
float
(
logits
.
sum
()),
bool
(
max_equal
)))
answer
=
(
float
(
logits
.
sum
()),
bool
(
max_equal
))
# partial caching
self
.
cache_hook
.
add_partial
(
"loglikelihood"
,
(
context
,
continuation
),
answer
)
res
.
append
(
answer
)
return
reord
.
get_original
(
res
)
return
reord
.
get_original
(
res
)
...
@@ -103,6 +109,9 @@ class GPT2LM(LM):
...
@@ -103,6 +109,9 @@ class GPT2LM(LM):
for
term
in
until
:
for
term
in
until
:
s
=
s
.
split
(
term
)[
0
]
s
=
s
.
split
(
term
)[
0
]
# partial caching
self
.
cache_hook
.
add_partial
(
"greedy_until"
,
(
context
,
until
),
s
)
res
.
append
(
s
)
res
.
append
(
s
)
return
reord
.
get_original
(
res
)
return
reord
.
get_original
(
res
)
lm_eval/models/gpt3.py
View file @
efbe6e7f
...
@@ -48,6 +48,7 @@ class GPT3LM(LM):
...
@@ -48,6 +48,7 @@ class GPT3LM(LM):
:param truncate: bool
:param truncate: bool
Truncate input if too long (if False and input is too long, throw error)
Truncate input if too long (if False and input is too long, throw error)
"""
"""
super
().
__init__
()
import
openai
import
openai
self
.
engine
=
engine
self
.
engine
=
engine
self
.
tokenizer
=
transformers
.
GPT2TokenizerFast
.
from_pretrained
(
'gpt2'
)
self
.
tokenizer
=
transformers
.
GPT2TokenizerFast
.
from_pretrained
(
'gpt2'
)
...
@@ -104,8 +105,13 @@ class GPT3LM(LM):
...
@@ -104,8 +105,13 @@ class GPT3LM(LM):
logprobs
=
10
,
logprobs
=
10
,
)
)
for
resp
,
ctxlen
in
zip
(
response
.
choices
,
ctxlens
):
for
resp
,
ctxlen
,
(
context
,
continuation
)
in
zip
(
response
.
choices
,
ctxlens
,
chunk
):
res
.
append
(
get_result
(
resp
,
ctxlen
))
answer
=
get_result
(
resp
,
ctxlen
)
res
.
append
(
answer
)
# partial caching
self
.
cache_hook
.
add_partial
(
"loglikelihood"
,
(
context
,
continuation
),
answer
)
return
reord
.
get_original
(
res
)
return
reord
.
get_original
(
res
)
...
@@ -149,13 +155,15 @@ class GPT3LM(LM):
...
@@ -149,13 +155,15 @@ class GPT3LM(LM):
stop
=
until
stop
=
until
)
)
for
resp
in
response
.
choices
:
for
resp
,
(
context
,
until
)
in
zip
(
response
.
choices
,
chunk
)
:
s
=
resp
[
'text'
]
s
=
resp
[
'text'
]
for
term
in
until
:
for
term
in
until
:
s
=
s
.
split
(
term
)[
0
]
s
=
s
.
split
(
term
)[
0
]
# partial caching
self
.
cache_hook
.
add_partial
(
"greedy_until"
,
(
context
,
until
),
s
)
res
.
append
(
s
)
res
.
append
(
s
)
return
reord
.
get_original
(
res
)
return
reord
.
get_original
(
res
)()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment