gaoqiong / lm-evaluation-harness · Commits

Commit 4e0d0e3a (unverified), authored Jun 28, 2023 by Hailey Schoelkopf, committed by GitHub on Jun 28, 2023

Merge pull request #619 from EleutherAI/cachinglm-only

[Refactor] CachingLM support via `--use_cache`

Parents: 9dea125b, 5a5442ff
Showing 8 changed files with 172 additions and 25 deletions (+172 / -25):

  lm_eval/api/model.py                  +109   -0
  lm_eval/api/task.py                     +4   -5
  lm_eval/evaluator.py                   +14   -4
  lm_eval/models/anthropic_llms.py        +5   -0
  lm_eval/models/huggingface.py          +16  -10
  lm_eval/models/openai_completions.py   +16   -4
  lm_eval/models/textsynth.py             +6   -0
  main.py                                 +2   -2
lm_eval/api/model.py

import abc
import os
from typing import Union

from sqlitedict import SqliteDict
import json
import hashlib
from tqdm import tqdm

from lm_eval import utils
from lm_eval.logger import eval_logger


class LM(abc.ABC):
...

@@ -12,6 +19,7 @@ class LM(abc.ABC):
        (inputs/outputs should be tokenization-agnostic.)
        """
        self.cache_hook = CacheHook(None)

    @abc.abstractmethod
    def loglikelihood(self, requests):
...

@@ -118,3 +126,104 @@ class LM(abc.ABC):
        # ensure no errors arise using API models which do
        # not support multi-device parallelism nor expect it.
        return 1

    def set_cache_hook(self, cache_hook):
        self.cache_hook = cache_hook


### SQLite-based caching of LM responses


def hash_args(attr, args):
    dat = json.dumps([attr] + list(args))
    return hashlib.sha256(dat.encode("utf-8")).hexdigest()


class CacheHook:
    def __init__(self, cachinglm):
        if cachinglm is None:
            self.dbdict = None
            return

        self.dbdict = cachinglm.dbdict

    def add_partial(self, attr, req, res):
        if self.dbdict is None:
            return
        hsh = hash_args(attr, req)
        self.dbdict[hsh] = res


class CachingLM:
    def __init__(self, lm, cache_db):
        """LM wrapper that returns cached results if they exist, and uses the underlying LM if not.

        :param lm: LM
            Underlying LM
        :param cache_db: str
            Path to cache db
        """
        self.lm = lm
        self.cache_db = cache_db
        if os.path.dirname(cache_db):
            os.makedirs(os.path.dirname(cache_db), exist_ok=True)
        self.dbdict = SqliteDict(cache_db, autocommit=True)

        # add hook to lm
        lm.set_cache_hook(self.get_cache_hook())

    def __getattr__(self, attr):
        lm_attr = getattr(self.lm, attr)
        if not callable(lm_attr):
            return lm_attr

        def fn(requests):
            res = []
            remaining_reqs = []
            warned = False
            # figure out which ones are cached and which ones are new
            eval_logger.info(
                f"Loading '{attr}' responses from cache '{self.cache_db}' where possible..."
            )
            for req in tqdm(requests):
                hsh = hash_args(attr, req.args)
                if attr == "greedy_until" and req.args[1].get("do_sample", False):
                    # when we are doing non-greedy generation, don't use the cache
                    # (else every "randomly sampled" generation would be identical for repeats > 1).
                    if not warned:
                        eval_logger.warning(
                            f"Arguments to lm.greedy_until() '{req.args[1]}' include non-deterministic sampling. Caching will not be performed for such requests."
                        )
                        warned = True
                    res.append(None)
                    remaining_reqs.append(req)
                elif hsh in self.dbdict:
                    ob = self.dbdict[hsh]
                    assert ob is not None
                    res.append(ob)
                else:
                    res.append(None)
                    remaining_reqs.append(req)

            # actually run the LM on the requests that do not have cached results
            rem_res = getattr(self.lm, attr)(remaining_reqs)

            # stick the new ones back into the list and also cache any of the new ones
            resptr = 0
            for req, r in zip(remaining_reqs, rem_res):
                while res[resptr] is not None:
                    resptr += 1

                res[resptr] = r

                # caching
                hsh = hash_args(attr, req.args)
                self.dbdict[hsh] = r
            self.dbdict.commit()

            return res

        return fn

    def get_cache_hook(self):
        return CacheHook(self)
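For orientation, here is a minimal usage sketch of the caching layer added above (not part of the commit): the `EchoLM` backend and the `lm_cache/echo.db` path are hypothetical, and it assumes `loglikelihood`, `loglikelihood_rolling`, and `greedy_until` are the request methods a concrete `LM` must implement.

# Hypothetical sketch -- not part of this commit.
from lm_eval.api.model import LM, CachingLM

class EchoLM(LM):
    """Toy backend used only to exercise CachingLM; not a real model."""

    def loglikelihood(self, requests):
        res = [(0.0, True) for _ in requests]
        for req, r in zip(requests, res):
            # backends report results through the hook so CachingLM can persist them
            self.cache_hook.add_partial("loglikelihood", req.args, r)
        return res

    def loglikelihood_rolling(self, requests):
        return [0.0 for _ in requests]

    def greedy_until(self, requests):
        return ["" for _ in requests]

lm = CachingLM(EchoLM(), "lm_cache/echo.db")  # SqliteDict file, created on demand
# lm.loglikelihood(reqs) now serves hits from the db and only forwards misses;
# keys are sha256(json.dumps([attr] + list(req.args))), as in hash_args() above.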
lm_eval/api/task.py

@@ -52,7 +52,6 @@ class TaskConfig(dict):
    task: str = None
    group: Union[str, list] = None
    reference: str = None
    dataset_path: str = None
    dataset_name: str = None
...

@@ -67,6 +66,8 @@ class TaskConfig(dict):
    doc_to_target: Union[Callable, str] = None
    use_prompt: str = None
    description: str = ""
+   target_delimiter: str = " "
+   fewshot_delimiter: str = "\n\n"
    num_fewshot: int = 0
    batch_size: int = 1
...

@@ -76,8 +77,6 @@ class TaskConfig(dict):
    gold_alias: Union[Callable, str] = None
    output_type: str = "greedy_until"
    generation_kwargs: dict = None
-   target_delimiter: str = " "
-   fewshot_delimiter: str = "\n\n"
    filter_list: Union[str, list] = None
    should_decontaminate: bool = False
    doc_to_decontamination_query: str = None
...

@@ -343,7 +342,7 @@ class Task(abc.ABC):
            fewshot_ctx = self.fewshot_context(
                doc, self._config.num_fewshot, rnd=random.Random()
            )

-           # TODO: hardcoded for now: # of runs on each input to be 2. # TODO: we should override this if doing greedy gen so users don't waste time+compute
+           # TODO: we should override this if doing greedy gen so users don't waste time+compute
            inst = self.construct_requests(
                doc=doc,
                ctx=fewshot_ctx,
...

@@ -773,7 +772,7 @@ class ConfigurableTask(Task):
                Instance(
                    request_type="loglikelihood",
                    doc=doc,
-                   arguments=("", "{}".format(choice)),
+                   arguments=("", " {}".format(choice)),
                    idx=i,
                    **kwargs,
                )
...
lm_eval/evaluator.py

@@ -39,7 +39,7 @@ def simple_evaluate(
    batch_size=None,
    max_batch_size=None,
    device=None,
-   no_cache=False,
+   use_cache=None,
    limit=None,
    bootstrap_iters=100000,
    check_integrity=False,
...

@@ -64,8 +64,8 @@ def simple_evaluate(
        Maximal batch size to try with automatic batch size detection
    :param device: str, optional
        PyTorch device (e.g. "cpu" or "cuda:0") for running models
-   :param no_cache: bool
-       Whether or not to cache
+   :param use_cache: str, optional
+       A path to a sqlite db file for caching model responses. `None` if not caching.
    :param limit: int or float, optional
        Limit the number of examples per task (only use this for testing), If <1, limit is a percentage of the total number of examples.
    :param bootstrap_iters:
...

@@ -99,6 +99,16 @@ def simple_evaluate(
        assert isinstance(model, lm_eval.api.model.LM)
        lm = model

    if use_cache is not None:
        print(f"Using cache at {use_cache + '_rank' + str(lm.rank) + '.db'}")
        lm = lm_eval.api.model.CachingLM(
            lm,
            use_cache
            # each rank receives a different cache db.
            # necessary to avoid multiple writes to cache at once
            + "_rank"
            + str(lm.rank)
            + ".db",
        )

    task_dict = lm_eval.tasks.get_task_dict(tasks, num_fewshot=num_fewshot)

    if check_integrity:
...

@@ -127,7 +137,7 @@ def simple_evaluate(
        if hasattr(lm, "batch_sizes")
        else [],
        "device": device,
-       "no_cache": no_cache,
+       "use_cache": use_cache,
        "limit": limit,
        "bootstrap_iters": bootstrap_iters,
    }
...
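A hedged sketch of how a caller might exercise the new `use_cache` parameter; the model and task names below are placeholders, and only the `use_cache` handling mirrors the code above.

# Hypothetical call -- model and task names are placeholders.
import lm_eval.evaluator

results = lm_eval.evaluator.simple_evaluate(
    model="hf",                    # placeholder backend name
    tasks=["lambada_openai"],      # placeholder task list
    use_cache="lm_cache/run1",     # becomes lm_cache/run1_rank0.db on rank 0, etc.
)
# With use_cache=None (the default), no CachingLM wrapper is applied.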
lm_eval/models/anthropic_llms.py

@@ -88,6 +88,8 @@ class AnthropicLM(LM):
        if not requests:
            return []

        requests = [req.args for req in requests]

        res = []
        for request in tqdm(requests):
            inp = request[0]
...

@@ -102,6 +104,9 @@ class AnthropicLM(LM):
                stop=until,
            )
            res.append(response)

            self.cache_hook.add_partial("greedy_until", request, response)
        return res

    def _model_call(self, inps):
...
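Because the Anthropic backend reports results against the raw `(context, gen_kwargs)` tuple it received, the cache key is fully determined by those arguments. A small illustration using `hash_args` from `lm_eval/api/model.py` (the request contents are made up):

# Hypothetical request contents, shown only to illustrate the cache key.
from lm_eval.api.model import hash_args

request = ("Q: 2+2=?\nA:", {"until": ["\n"]})
key = hash_args("greedy_until", request)
print(key)  # same sha256 hex digest on every run, so repeated evals hit the cache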
lm_eval/models/huggingface.py

@@ -486,6 +486,8 @@ class HFLM(LM):
            res.append(answer)

            self.cache_hook.add_partial("loglikelihood", cache_key, answer)

        return re_ord.get_original(res)

    def greedy_until(self, requests):
...

@@ -497,26 +499,28 @@ class HFLM(LM):
        re_ord = utils.Reorderer([req.args for req in requests], _collate)

-       for context, gen_kwargs in tqdm(re_ord.get_reordered()):
+       for context, gen_kwargs in tqdm(re_ord.get_reordered(), disable=(self.rank != 0)):
            until = None
            if isinstance(gen_kwargs, dict):
-               gen_kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
-               if "until" in gen_kwargs.keys():
-                   until = gen_kwargs.pop("until")
+               kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
+               if "until" in kwargs.keys():
+                   until = kwargs.pop("until")
                    if isinstance(until, str):
-                       until = [gen_kwargs]
+                       until = [kwargs]
                    elif not isinstance(until, list):
                        raise ValueError(
-                           f"Expected `gen_kwargs['until']` to be of type Union[str,list] but got {until}"
+                           f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
                        )
            else:
                raise ValueError(
-                   f"Expected `gen_kwargs` to be of type `dict` but got {gen_kwargs}"
+                   f"Expected `kwargs` to be of type `dict` but got {kwargs}"
                )

            if not until:
                until = [self.tok_decode(self.eot_token_id)]
-           if "max_gen_toks" in gen_kwargs.keys():
-               max_gen_toks = gen_kwargs.pop("max_gen_toks")
+           if "max_gen_toks" in kwargs.keys():
+               max_gen_toks = kwargs.pop("max_gen_toks")
            else:
                max_gen_toks = self.max_gen_toks
            # first stop sequence is used to halt generation upon encountering
...

@@ -539,7 +543,7 @@ class HFLM(LM):
                context=context_enc,
                max_length=context_enc.shape[1] + max_gen_toks,
                stop=primary_until,
-               **gen_kwargs,
+               **kwargs,
            )

            cont_toks_list = cont[0].tolist()
...

@@ -556,4 +560,6 @@ class HFLM(LM):
            res.append(s)

            self.cache_hook.add_partial("greedy_until", (context, gen_kwargs), s)

        return re_ord.get_original(res)
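The switch from mutating `gen_kwargs` to a deep-copied `kwargs` matters because, with `repeats > 1`, several requests can share one generation-kwargs dict; popping `"until"` or `"max_gen_toks"` in place would corrupt it for the next repeat. A standalone sketch of the pattern (the `prepare` helper and its defaults are illustrative, not part of the commit):

import copy

gen_kwargs = {"until": ["\n\n"], "max_gen_toks": 32}  # shared across repeated requests

def prepare(gen_kwargs):
    kwargs = copy.deepcopy(gen_kwargs)     # work on a private copy ("edge case for repeats > 1")
    until = kwargs.pop("until", [])
    max_gen_toks = kwargs.pop("max_gen_toks", 256)
    return until, max_gen_toks, kwargs     # remaining kwargs are forwarded to generation

for _ in range(2):                          # simulate repeats > 1
    until, max_gen_toks, rest = prepare(gen_kwargs)

assert gen_kwargs == {"until": ["\n\n"], "max_gen_toks": 32}  # original left intact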
lm_eval/models/openai_completions.py

@@ -194,7 +194,7 @@ class OpenaiCompletionsLM(LM):
                yield ret, lastuntil

        # todo: more intelligent batching for heterogeneous `until`
-       for chunk, until in tqdm(
+       for chunk, request_args in tqdm(
            list(sameuntil_chunks(re_ord.get_reordered(), self.REQ_CHUNK_SIZE))
        ):
            inps = []
...

@@ -203,6 +203,13 @@ class OpenaiCompletionsLM(LM):
                inp = context_enc[-(self.max_length - self.max_gen_toks) :]
                inps.append(inp)

            try:
                until = request_args["until"][0]  # TODO: does this handle a list of stop seqs correctly?
            except KeyError:
                until = "<|endoftext|>"

            response = oa_completion(
                engine=self.engine,
                prompt=inps,
...

@@ -212,14 +219,19 @@ class OpenaiCompletionsLM(LM):
                stop=until,
            )

-           for resp, (context, until_) in zip(response.choices, chunk):
+           for resp, (context, args_) in zip(response.choices, chunk):
                s = resp["text"]

                until_ = args_.get(["until"], [])

                for term in until_:
-                   s = s.split(term)[0]
+                   if len(term) > 0:
+                       s = s.split(term)[0]

                # partial caching
-               self.cache_hook.add_partial("greedy_until", (context, until_), s)
+               self.cache_hook.add_partial("greedy_until", (context, {"until": until_}), s)

                res.append(s)
...
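The added `len(term) > 0` guard skips empty stop strings, which would otherwise make `str.split` raise `ValueError: empty separator`. A standalone sketch of the trimming step (sample text and stop sequences are made up):

def trim_at_stops(s, until):
    """Cut generated text at the first occurrence of any stop sequence."""
    for term in until:
        if len(term) > 0:        # guard: s.split("") raises ValueError: empty separator
            s = s.split(term)[0]
    return s

print(trim_at_stops("Paris.\n\nQ: next question", ["\n\n", ""]))  # -> "Paris."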
lm_eval/models/textsynth.py

@@ -101,6 +101,10 @@ class TextSynthLM(LM):
                logprob = resp["logprob"]
                is_greedy = resp["is_greedy"]
                res.append((logprob, is_greedy))

                self.cache_hook.add_partial(
                    "loglikelihood", (context, continuation), (logprob, is_greedy)
                )
            else:
                logger.error(
                    f"The following response does not contain `logprobs`. Got:\n{resp}"
...

@@ -141,6 +145,8 @@ class TextSynthLM(LM):
            if "text" in resp:
                s = resp["text"]
                res.append(s)

                self.cache_hook.add_partial("greedy_until", (inp, request_args), s)
            else:
                logger.error(
                    f"The following response does not contain generated `text`. "
...
main.py

@@ -39,7 +39,7 @@ def parse_args():
        "If <1, limit is a percentage of the total number of examples.",
    )
    parser.add_argument("--data_sampling", type=float, default=None)
-   parser.add_argument("--no_cache", action="store_true")
+   parser.add_argument("--use_cache", type=str, default=None)
    parser.add_argument("--decontamination_ngrams_path", default=None)
    parser.add_argument("--check_integrity", action="store_true")
    parser.add_argument("--write_out", action="store_true", default=False)
...

@@ -85,7 +85,7 @@ def main():
        batch_size=args.batch_size,
        max_batch_size=args.max_batch_size,
        device=args.device,
-       no_cache=args.no_cache,
+       use_cache=args.use_cache,
        limit=args.limit,
        decontamination_ngrams_path=args.decontamination_ngrams_path,
        check_integrity=args.check_integrity,
...
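For reference, a minimal standalone sketch of the flag change: the old `--no_cache` was a boolean switch, while the new `--use_cache` takes a path prefix for the SQLite cache. The `add_argument` call is taken from the diff above; the sample value is illustrative.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--use_cache", type=str, default=None)  # as added in main.py above

args = parser.parse_args(["--use_cache", "lm_cache/myrun"])
print(args.use_cache)  # "lm_cache/myrun"; simple_evaluate appends "_rank{rank}.db" per process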