Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
f19408c3
Commit
f19408c3
authored
May 25, 2025
by
Baber
Browse files
feat: implement caching decorator for request handling and improve cache key generation
parent
7aaceeec
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
197 additions
and
79 deletions
+197
-79
lm_eval/api/model.py
lm_eval/api/model.py
+17
-7
lm_eval/api/task.py
lm_eval/api/task.py
+22
-59
lm_eval/caching/cache.py
lm_eval/caching/cache.py
+155
-2
lm_eval/evaluator.py
lm_eval/evaluator.py
+1
-9
lm_eval/utils.py
lm_eval/utils.py
+2
-2
No files found.
lm_eval/api/model.py
View file @
f19408c3
...
@@ -2,7 +2,7 @@ import abc
...
@@ -2,7 +2,7 @@ import abc
import
hashlib
import
hashlib
import
json
import
json
import
logging
import
logging
import
os
from
pathlib
import
Path
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Type
,
TypeVar
,
Union
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Type
,
TypeVar
,
Union
import
transformers
import
transformers
...
@@ -230,7 +230,7 @@ class CacheHook:
...
@@ -230,7 +230,7 @@ class CacheHook:
class
CachingLM
:
class
CachingLM
:
def
__init__
(
self
,
lm
,
cache_db
)
->
None
:
def
__init__
(
self
,
lm
:
"LM"
,
cache_db
:
str
)
->
None
:
"""LM wrapper that returns cached results if they exist, and uses the underlying LM if not.
"""LM wrapper that returns cached results if they exist, and uses the underlying LM if not.
:param lm: LM
:param lm: LM
...
@@ -239,12 +239,22 @@ class CachingLM:
...
@@ -239,12 +239,22 @@ class CachingLM:
Path to cache db
Path to cache db
"""
"""
self
.
lm
=
lm
self
.
lm
=
lm
self
.
cache_db
=
cache_db
if
os
.
path
.
dirname
(
cache_db
):
os
.
makedirs
(
os
.
path
.
dirname
(
cache_db
),
exist_ok
=
True
)
self
.
dbdict
=
SqliteDict
(
cache_db
,
autocommit
=
True
)
# add hook to lm
# Setup cache path
cache_path
=
Path
(
cache_db
)
if
cache_path
.
is_dir
()
or
(
not
cache_path
.
suffix
and
not
cache_path
.
exists
()):
cache_path
=
cache_path
/
"cache.db"
self
.
cache_db
=
cache_path
cache_path
.
parent
.
mkdir
(
parents
=
True
,
exist_ok
=
True
)
# Initialize database with WAL mode for concurrent access
self
.
dbdict
=
SqliteDict
(
str
(
cache_path
),
autocommit
=
True
,
timeout
=
30.0
)
# Enable WAL mode for better concurrency
self
.
dbdict
.
conn
.
execute
(
"PRAGMA journal_mode=WAL"
)
self
.
dbdict
.
conn
.
commit
()
lm
.
set_cache_hook
(
self
.
get_cache_hook
())
lm
.
set_cache_hook
(
self
.
get_cache_hook
())
def
__getattr__
(
self
,
attr
:
str
):
def
__getattr__
(
self
,
attr
:
str
):
...
...
lm_eval/api/task.py
View file @
f19408c3
...
@@ -36,7 +36,7 @@ from lm_eval.api.registry import (
...
@@ -36,7 +36,7 @@ from lm_eval.api.registry import (
get_metric_aggregation
,
get_metric_aggregation
,
is_higher_better
,
is_higher_better
,
)
)
from
lm_eval.caching.cache
import
load_from_cache
,
save_to_cache
from
lm_eval.caching.cache
import
cache_instances
from
lm_eval.filters
import
build_filter_ensemble
from
lm_eval.filters
import
build_filter_ensemble
from
lm_eval.prompts
import
get_prompt
from
lm_eval.prompts
import
get_prompt
...
@@ -387,6 +387,7 @@ class Task(abc.ABC):
...
@@ -387,6 +387,7 @@ class Task(abc.ABC):
def
doc_to_prefix
(
self
,
doc
):
def
doc_to_prefix
(
self
,
doc
):
return
""
return
""
@
cache_instances
def
build_all_requests
(
def
build_all_requests
(
self
,
self
,
*
,
*
,
...
@@ -394,68 +395,28 @@ class Task(abc.ABC):
...
@@ -394,68 +395,28 @@ class Task(abc.ABC):
samples
:
Optional
[
List
[
int
]]
=
None
,
samples
:
Optional
[
List
[
int
]]
=
None
,
rank
:
int
=
0
,
rank
:
int
=
0
,
world_size
:
int
=
1
,
world_size
:
int
=
1
,
cache_requests
:
bool
=
False
,
rewrite_requests_cache
:
bool
=
False
,
system_instruction
:
Optional
[
str
]
=
None
,
system_instruction
:
Optional
[
str
]
=
None
,
apply_chat_template
:
bool
=
False
,
apply_chat_template
:
bool
=
False
,
fewshot_as_multiturn
:
bool
=
False
,
fewshot_as_multiturn
:
bool
=
False
,
chat_template
:
Optional
[
Callable
]
=
None
,
chat_template
:
Optional
[
Callable
]
=
None
,
cache_requests
:
bool
=
False
,
rewrite_requests_cache
:
bool
=
False
,
tokenizer_name
:
str
=
""
,
tokenizer_name
:
str
=
""
,
)
->
None
:
)
->
Optional
[
List
[
List
[
Instance
]]]
:
"""Build a set of Instances for a task, and store them in task.instances"""
"""Build a set of Instances for a task, and store them in task.instances"""
# used with caching
og_limit
=
limit
cache_key
=
f
"requests-
{
self
.
_config
.
task
}
-
{
self
.
config
.
num_fewshot
}
shot-rank
{
rank
}
-world_size
{
world_size
}
"
cache_key
+=
"-chat_template"
if
apply_chat_template
else
""
cache_key
+=
"-fewshot_as_multiturn"
if
fewshot_as_multiturn
else
""
cache_key
+=
(
f
"-system_prompt_hash
{
utils
.
hash_string
(
system_instruction
)
}
"
if
system_instruction
is
not
None
else
""
)
cache_key
+=
f
"-tokenizer
{
tokenizer_name
}
"
cached_instances
=
load_from_cache
(
file_name
=
cache_key
,
cache
=
cache_requests
)
if
cache_requests
and
cached_instances
and
not
rewrite_requests_cache
:
cached_instances
=
cached_instances
[:
limit
]
flattened_instances
=
[
instance
for
instance_group
in
cached_instances
for
instance
in
instance_group
]
self
.
_instances
=
flattened_instances
return
eval_logger
.
info
(
f
"Building contexts for
{
self
.
config
.
task
}
on rank
{
rank
}
..."
)
eval_logger
.
info
(
f
"Building contexts for
{
self
.
config
.
task
}
on rank
{
rank
}
..."
)
instances
=
[]
instances
=
[]
# process all documents when caching is specified for simplicity
if
(
cache_requests
and
(
not
cached_instances
or
rewrite_requests_cache
)
and
limit
is
not
None
):
limit
=
None
doc_id_docs
=
list
(
doc_id_docs
=
list
(
self
.
doc_iterator
(
self
.
doc_iterator
(
rank
=
rank
,
limit
=
limit
,
samples
=
samples
,
world_size
=
world_size
rank
=
rank
,
limit
=
limit
,
samples
=
samples
,
world_size
=
world_size
)
)
)
)
num_docs
=
len
(
doc_id_docs
)
for
doc_id
,
doc
in
tqdm
(
doc_id_docs
,
total
=
len
(
doc_id_docs
)):
# sample fewshot context
for
doc_id
,
doc
in
tqdm
(
doc_id_docs
,
total
=
num_docs
,
):
# sample fewshot context #TODO: need to offset doc_id by rank now!
fewshot_ctx
=
self
.
fewshot_context
(
fewshot_ctx
=
self
.
fewshot_context
(
doc
,
doc
,
0
if
self
.
config
.
num_fewshot
is
None
else
self
.
config
.
num_fewshot
,
0
if
self
.
config
.
num_fewshot
is
None
else
self
.
config
.
num_fewshot
,
...
@@ -466,7 +427,7 @@ class Task(abc.ABC):
...
@@ -466,7 +427,7 @@ class Task(abc.ABC):
gen_prefix
=
self
.
doc_to_prefix
(
doc
),
gen_prefix
=
self
.
doc_to_prefix
(
doc
),
)
)
#
TODO: we should override self.config.repeats if doing greedy gen so users don't waste time+compute
#
construct requests
inst
=
self
.
construct_requests
(
inst
=
self
.
construct_requests
(
doc
=
doc
,
doc
=
doc
,
ctx
=
fewshot_ctx
,
ctx
=
fewshot_ctx
,
...
@@ -480,14 +441,14 @@ class Task(abc.ABC):
...
@@ -480,14 +441,14 @@ class Task(abc.ABC):
instances
.
append
(
inst
)
instances
.
append
(
inst
)
# now flatten, this is to allow slicing to work with pickles
# Handle non-caching case
if
not
cache_requests
:
sliced_instances
=
instances
[:
og_limit
]
# Apply limit at document level, then flatten
if
limit
is
not
None
:
instances
=
instances
[:
limit
]
flattened_instances
=
[
flattened_instances
=
[
instance
instance
for
instance_group
in
instances
for
instance
in
instance_group
for
instance_group
in
sliced_instances
for
instance
in
instance_group
]
]
self
.
_instances
=
flattened_instances
self
.
_instances
=
flattened_instances
...
@@ -495,8 +456,10 @@ class Task(abc.ABC):
...
@@ -495,8 +456,10 @@ class Task(abc.ABC):
if
len
(
self
.
_instances
)
==
0
:
if
len
(
self
.
_instances
)
==
0
:
raise
ValueError
(
"task.build_requests() did not find any docs!"
)
raise
ValueError
(
"task.build_requests() did not find any docs!"
)
if
cache_requests
and
(
not
cached_instances
or
rewrite_requests_cache
):
return
None
save_to_cache
(
file_name
=
cache_key
,
obj
=
instances
)
# Return instances for decorator to handle
return
instances
@
abc
.
abstractmethod
@
abc
.
abstractmethod
def
construct_requests
(
self
,
doc
,
ctx
,
**
kwargs
):
def
construct_requests
(
self
,
doc
,
ctx
,
**
kwargs
):
...
...
lm_eval/caching/cache.py
View file @
f19408c3
import
hashlib
import
hashlib
import
logging
import
logging
import
os
import
os
from
functools
import
wraps
import
dill
from
typing
import
Callable
,
List
,
Optional
,
Union
eval_logger
=
logging
.
getLogger
(
__name__
)
eval_logger
=
logging
.
getLogger
(
__name__
)
...
@@ -27,6 +27,8 @@ def load_from_cache(file_name: str, cache: bool = False):
...
@@ -27,6 +27,8 @@ def load_from_cache(file_name: str, cache: bool = False):
if
not
cache
:
if
not
cache
:
return
return
try
:
try
:
import
dill
path
=
f
"
{
PATH
}
/
{
file_name
}{
FILE_SUFFIX
}
"
path
=
f
"
{
PATH
}
/
{
file_name
}{
FILE_SUFFIX
}
"
with
open
(
path
,
"rb"
)
as
file
:
with
open
(
path
,
"rb"
)
as
file
:
...
@@ -39,6 +41,8 @@ def load_from_cache(file_name: str, cache: bool = False):
...
@@ -39,6 +41,8 @@ def load_from_cache(file_name: str, cache: bool = False):
def
save_to_cache
(
file_name
,
obj
):
def
save_to_cache
(
file_name
,
obj
):
import
dill
if
not
os
.
path
.
exists
(
PATH
):
if
not
os
.
path
.
exists
(
PATH
):
os
.
mkdir
(
PATH
)
os
.
mkdir
(
PATH
)
...
@@ -57,3 +61,152 @@ def delete_cache(key: str = ""):
...
@@ -57,3 +61,152 @@ def delete_cache(key: str = ""):
if
file
.
startswith
(
key
)
and
file
.
endswith
(
FILE_SUFFIX
):
if
file
.
startswith
(
key
)
and
file
.
endswith
(
FILE_SUFFIX
):
file_path
=
f
"
{
PATH
}
/
{
file
}
"
file_path
=
f
"
{
PATH
}
/
{
file
}
"
os
.
unlink
(
file_path
)
os
.
unlink
(
file_path
)
def
_build_cache_key
(
task
:
str
,
num_fewshot
:
int
,
rank
:
int
,
world_size
:
int
,
apply_chat_template
:
bool
,
fewshot_as_multiturn
:
bool
,
system_instruction
:
Optional
[
str
],
tokenizer_name
:
str
,
)
->
str
:
"""Build cache key from parameters"""
cache_key
=
f
"requests-
{
task
}
-
{
num_fewshot
}
shot-rank
{
rank
}
-world_size
{
world_size
}
"
if
apply_chat_template
:
cache_key
+=
"-chat_template"
if
fewshot_as_multiturn
:
cache_key
+=
"-fewshot_as_multiturn"
if
system_instruction
is
not
None
:
# Import utils here to avoid circular imports
import
utils
cache_key
+=
f
"-system_prompt_hash
{
utils
.
hash_string
(
system_instruction
)
}
"
cache_key
+=
f
"-tokenizer
{
tokenizer_name
}
"
return
cache_key
def
cache_instances
(
func
):
"""Decorator to handle request caching for build_all_requests"""
@
wraps
(
func
)
def
wrapper
(
self
,
*
,
limit
:
Union
[
int
,
None
]
=
None
,
samples
:
Optional
[
List
[
int
]]
=
None
,
rank
:
int
=
0
,
world_size
:
int
=
1
,
cache_requests
:
bool
=
False
,
rewrite_requests_cache
:
bool
=
False
,
system_instruction
:
Optional
[
str
]
=
None
,
apply_chat_template
:
bool
=
False
,
fewshot_as_multiturn
:
bool
=
False
,
chat_template
:
Optional
[
Callable
]
=
None
,
tokenizer_name
:
str
=
""
,
**
kwargs
,
):
# If caching is disabled, just call the original function
# The method will handle setting self._instances
if
not
cache_requests
:
return
func
(
self
,
limit
=
limit
,
samples
=
samples
,
rank
=
rank
,
world_size
=
world_size
,
cache_requests
=
cache_requests
,
rewrite_requests_cache
=
rewrite_requests_cache
,
system_instruction
=
system_instruction
,
apply_chat_template
=
apply_chat_template
,
fewshot_as_multiturn
=
fewshot_as_multiturn
,
chat_template
=
chat_template
,
tokenizer_name
=
tokenizer_name
,
**
kwargs
,
)
# Build cache key
cache_key
=
_build_cache_key
(
self
.
_config
.
task
,
self
.
config
.
num_fewshot
,
rank
,
world_size
,
apply_chat_template
,
fewshot_as_multiturn
,
system_instruction
,
tokenizer_name
,
)
# Try to load from cache
cached_instances
=
load_from_cache
(
file_name
=
cache_key
,
cache
=
cache_requests
)
# Return cached instances if available and not rewriting
if
cached_instances
and
not
rewrite_requests_cache
:
cached_instances
=
(
cached_instances
[:
limit
]
if
limit
is
not
None
else
cached_instances
)
flattened_instances
=
[
instance
for
instance_group
in
cached_instances
for
instance
in
instance_group
]
self
.
_instances
=
flattened_instances
eval_logger
.
debug
(
f
"Using
{
len
(
flattened_instances
)
}
contexts for
{
self
.
config
.
task
}
on rank
{
rank
}
..."
)
return
# Store original limit for later use
original_limit
=
limit
# Process all documents when caching for simplicity
if
limit
is
not
None
:
limit
=
None
# Call the original function with modified parameters
instances
=
func
(
self
,
limit
=
limit
,
samples
=
samples
,
rank
=
rank
,
world_size
=
world_size
,
cache_requests
=
cache_requests
,
rewrite_requests_cache
=
rewrite_requests_cache
,
system_instruction
=
system_instruction
,
apply_chat_template
=
apply_chat_template
,
fewshot_as_multiturn
=
fewshot_as_multiturn
,
chat_template
=
chat_template
,
tokenizer_name
=
tokenizer_name
,
**
kwargs
,
)
# Check if method handled everything (non-cache mode returns None)
if
instances
is
None
:
return
# Apply original limit if specified
sliced_instances
=
(
instances
[:
original_limit
]
if
original_limit
is
not
None
else
instances
)
# Flatten and set instances
flattened_instances
=
[
instance
for
instance_group
in
sliced_instances
for
instance
in
instance_group
]
self
.
_instances
=
flattened_instances
# Validate results
if
len
(
self
.
_instances
)
==
0
:
raise
ValueError
(
"task.build_requests() did not find any docs!"
)
# Save to cache if we generated new instances
if
not
cached_instances
or
rewrite_requests_cache
:
save_to_cache
(
file_name
=
cache_key
,
obj
=
instances
)
return
wrapper
lm_eval/evaluator.py
View file @
f19408c3
...
@@ -241,15 +241,7 @@ def simple_evaluate(
...
@@ -241,15 +241,7 @@ def simple_evaluate(
if
use_cache
is
not
None
:
if
use_cache
is
not
None
:
eval_logger
.
info
(
f
"Using cache at
{
use_cache
+
'_rank'
+
str
(
lm
.
rank
)
+
'.db'
}
"
)
eval_logger
.
info
(
f
"Using cache at
{
use_cache
+
'_rank'
+
str
(
lm
.
rank
)
+
'.db'
}
"
)
lm
=
lm_eval
.
api
.
model
.
CachingLM
(
lm
=
lm_eval
.
api
.
model
.
CachingLM
(
lm
,
use_cache
)
lm
,
use_cache
# each rank receives a different cache db.
# necessary to avoid multiple writes to cache at once
+
"_rank"
+
str
(
lm
.
rank
)
+
".db"
,
)
if
task_manager
is
None
:
if
task_manager
is
None
:
metadata
=
(
metadata
=
(
...
...
lm_eval/utils.py
View file @
f19408c3
...
@@ -208,14 +208,14 @@ def sanitize_model_name(model_name: str) -> str:
...
@@ -208,14 +208,14 @@ def sanitize_model_name(model_name: str) -> str:
"""
"""
Given the model name, returns a sanitized version of it.
Given the model name, returns a sanitized version of it.
"""
"""
return
re
.
sub
(
r
"[\"<>:/
\
|\\?
\
*\[\]]+"
,
"__"
,
model_name
)
return
re
.
sub
(
r
"[\"<>:/|\\?*\[\]]+"
,
"__"
,
model_name
)
def
sanitize_task_name
(
task_name
:
str
)
->
str
:
def
sanitize_task_name
(
task_name
:
str
)
->
str
:
"""
"""
Given the task name, returns a sanitized version of it.
Given the task name, returns a sanitized version of it.
"""
"""
return
re
.
sub
(
r
"\W"
,
"_"
,
task_name
)
return
re
.
sub
(
r
"\W
+
"
,
"_"
,
task_name
)
def
get_latest_filename
(
filenames
:
List
[
str
])
->
str
:
def
get_latest_filename
(
filenames
:
List
[
str
])
->
str
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment