Commit f19408c3 authored by Baber

feat: implement caching decorator for request handling and improve cache key generation

parent 7aaceeec
@@ -2,7 +2,7 @@ import abc
 import hashlib
 import json
 import logging
-import os
+from pathlib import Path
 from typing import Dict, List, Optional, Tuple, Type, TypeVar, Union

 import transformers
@@ -230,7 +230,7 @@ class CacheHook:


 class CachingLM:
-    def __init__(self, lm, cache_db) -> None:
+    def __init__(self, lm: "LM", cache_db: str) -> None:
         """LM wrapper that returns cached results if they exist, and uses the underlying LM if not.

         :param lm: LM
@@ -239,12 +239,22 @@ class CachingLM:
             Path to cache db
         """
         self.lm = lm
-        self.cache_db = cache_db
-        if os.path.dirname(cache_db):
-            os.makedirs(os.path.dirname(cache_db), exist_ok=True)
-        self.dbdict = SqliteDict(cache_db, autocommit=True)

-        # add hook to lm
+        # Setup cache path
+        cache_path = Path(cache_db)
+        if cache_path.is_dir() or (not cache_path.suffix and not cache_path.exists()):
+            cache_path = cache_path / "cache.db"
+        self.cache_db = cache_path
+        cache_path.parent.mkdir(parents=True, exist_ok=True)
+
+        # Initialize database with WAL mode for concurrent access
+        self.dbdict = SqliteDict(str(cache_path), autocommit=True, timeout=30.0)
+
+        # Enable WAL mode for better concurrency
+        self.dbdict.conn.execute("PRAGMA journal_mode=WAL")
+        self.dbdict.conn.commit()
+
         lm.set_cache_hook(self.get_cache_hook())

     def __getattr__(self, attr: str):
...
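For illustration (not part of the commit): the new constructor accepts either a file path or a directory-like argument for `cache_db`. A minimal standalone sketch of that resolution logic, with `resolve_cache_path` as a hypothetical name; running it creates directories as a side effect:

```python
from pathlib import Path

def resolve_cache_path(cache_db: str) -> Path:
    """Mirror of the Path logic in CachingLM.__init__ above (illustrative helper)."""
    cache_path = Path(cache_db)
    # Directory-like arguments (existing dirs, or suffix-less paths that don't
    # exist yet) get a default "cache.db" file placed inside them.
    if cache_path.is_dir() or (not cache_path.suffix and not cache_path.exists()):
        cache_path = cache_path / "cache.db"
    # Create parent directories so SqliteDict can open the file.
    cache_path.parent.mkdir(parents=True, exist_ok=True)
    return cache_path

assert resolve_cache_path("caches/run1").name == "cache.db"    # directory-style argument
assert resolve_cache_path("caches/run1.db").name == "run1.db"  # explicit file argument
```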
@@ -36,7 +36,7 @@ from lm_eval.api.registry import (
     get_metric_aggregation,
     is_higher_better,
 )
-from lm_eval.caching.cache import load_from_cache, save_to_cache
+from lm_eval.caching.cache import cache_instances
 from lm_eval.filters import build_filter_ensemble
 from lm_eval.prompts import get_prompt
@@ -387,6 +387,7 @@ class Task(abc.ABC):
     def doc_to_prefix(self, doc):
         return ""

+    @cache_instances
     def build_all_requests(
         self,
         *,
@@ -394,68 +395,28 @@ class Task(abc.ABC):
         samples: Optional[List[int]] = None,
         rank: int = 0,
         world_size: int = 1,
-        cache_requests: bool = False,
-        rewrite_requests_cache: bool = False,
         system_instruction: Optional[str] = None,
         apply_chat_template: bool = False,
         fewshot_as_multiturn: bool = False,
         chat_template: Optional[Callable] = None,
+        cache_requests: bool = False,
+        rewrite_requests_cache: bool = False,
         tokenizer_name: str = "",
-    ) -> None:
+    ) -> Optional[List[List[Instance]]]:
         """Build a set of Instances for a task, and store them in task.instances"""
-        # used with caching
-        og_limit = limit
-
-        cache_key = f"requests-{self._config.task}-{self.config.num_fewshot}shot-rank{rank}-world_size{world_size}"
-        cache_key += "-chat_template" if apply_chat_template else ""
-        cache_key += "-fewshot_as_multiturn" if fewshot_as_multiturn else ""
-        cache_key += (
-            f"-system_prompt_hash{utils.hash_string(system_instruction)}"
-            if system_instruction is not None
-            else ""
-        )
-        cache_key += f"-tokenizer{tokenizer_name}"
-
-        cached_instances = load_from_cache(file_name=cache_key, cache=cache_requests)
-
-        if cache_requests and cached_instances and not rewrite_requests_cache:
-            cached_instances = cached_instances[:limit]
-
-            flattened_instances = [
-                instance
-                for instance_group in cached_instances
-                for instance in instance_group
-            ]
-
-            self._instances = flattened_instances
-            return
-
         eval_logger.info(f"Building contexts for {self.config.task} on rank {rank}...")

         instances = []

-        # process all documents when caching is specified for simplicity
-        if (
-            cache_requests
-            and (not cached_instances or rewrite_requests_cache)
-            and limit is not None
-        ):
-            limit = None
-
         doc_id_docs = list(
             self.doc_iterator(
                 rank=rank, limit=limit, samples=samples, world_size=world_size
             )
         )

-        num_docs = len(doc_id_docs)
-
-        for doc_id, doc in tqdm(
-            doc_id_docs,
-            total=num_docs,
-        ):
-            # sample fewshot context #TODO: need to offset doc_id by rank now!
+        for doc_id, doc in tqdm(doc_id_docs, total=len(doc_id_docs)):
+            # sample fewshot context
             fewshot_ctx = self.fewshot_context(
                 doc,
                 0 if self.config.num_fewshot is None else self.config.num_fewshot,
@@ -466,7 +427,7 @@ class Task(abc.ABC):
                 gen_prefix=self.doc_to_prefix(doc),
             )

-            # TODO: we should override self.config.repeats if doing greedy gen so users don't waste time+compute
+            # construct requests
             inst = self.construct_requests(
                 doc=doc,
                 ctx=fewshot_ctx,
@@ -480,23 +441,25 @@ class Task(abc.ABC):
             instances.append(inst)

-        # now flatten, this is to allow slicing to work with pickles
-        sliced_instances = instances[:og_limit]
-
-        flattened_instances = [
-            instance
-            for instance_group in sliced_instances
-            for instance in instance_group
-        ]
-
-        self._instances = flattened_instances
-
-        if len(self._instances) == 0:
-            raise ValueError("task.build_requests() did not find any docs!")
-
-        if cache_requests and (not cached_instances or rewrite_requests_cache):
-            save_to_cache(file_name=cache_key, obj=instances)
+        # Handle non-caching case
+        if not cache_requests:
+            # Apply limit at document level, then flatten
+            if limit is not None:
+                instances = instances[:limit]
+
+            flattened_instances = [
+                instance for instance_group in instances for instance in instance_group
+            ]
+            self._instances = flattened_instances
+
+            if len(self._instances) == 0:
+                raise ValueError("task.build_requests() did not find any docs!")
+
+            return None
+
+        # Return instances for decorator to handle
+        return instances

     @abc.abstractmethod
     def construct_requests(self, doc, ctx, **kwargs):
...
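Note the contract the new return type encodes: `build_all_requests` produces one instance group per document, and slicing by `limit` happens at the group (document) level before flattening. A small illustration with stand-in values, since real `Instance` objects need a full task config:

```python
# One inner list per document; strings stand in for Instance objects.
instance_groups = [["doc0_req0", "doc0_req1"], ["doc1_req0"]]

limit = 1
sliced = instance_groups[:limit]  # limit applies to documents, not requests
flattened = [inst for group in sliced for inst in group]

assert flattened == ["doc0_req0", "doc0_req1"]  # both requests of doc 0 survive
```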
 import hashlib
 import logging
 import os
+from functools import wraps
+from typing import Callable, List, Optional, Union

-import dill

 eval_logger = logging.getLogger(__name__)

@@ -27,6 +27,8 @@ def load_from_cache(file_name: str, cache: bool = False):
     if not cache:
         return

     try:
+        import dill
+
         path = f"{PATH}/{file_name}{FILE_SUFFIX}"
         with open(path, "rb") as file:

@@ -39,6 +41,8 @@ def load_from_cache(file_name: str, cache: bool = False):


 def save_to_cache(file_name, obj):
+    import dill
+
     if not os.path.exists(PATH):
         os.mkdir(PATH)

@@ -57,3 +61,152 @@ def delete_cache(key: str = ""):
         if file.startswith(key) and file.endswith(FILE_SUFFIX):
             file_path = f"{PATH}/{file}"
             os.unlink(file_path)
+
+
+def _build_cache_key(
+    task: str,
+    num_fewshot: int,
+    rank: int,
+    world_size: int,
+    apply_chat_template: bool,
+    fewshot_as_multiturn: bool,
+    system_instruction: Optional[str],
+    tokenizer_name: str,
+) -> str:
+    """Build cache key from parameters"""
+    cache_key = f"requests-{task}-{num_fewshot}shot-rank{rank}-world_size{world_size}"
+    if apply_chat_template:
+        cache_key += "-chat_template"
+    if fewshot_as_multiturn:
+        cache_key += "-fewshot_as_multiturn"
+    if system_instruction is not None:
+        # Import here to avoid circular imports
+        from lm_eval import utils
+
+        cache_key += f"-system_prompt_hash{utils.hash_string(system_instruction)}"
+    cache_key += f"-tokenizer{tokenizer_name}"
+    return cache_key
+
+
+def cache_instances(func):
+    """Decorator to handle request caching for build_all_requests"""
+
+    @wraps(func)
+    def wrapper(
+        self,
+        *,
+        limit: Union[int, None] = None,
+        samples: Optional[List[int]] = None,
+        rank: int = 0,
+        world_size: int = 1,
+        cache_requests: bool = False,
+        rewrite_requests_cache: bool = False,
+        system_instruction: Optional[str] = None,
+        apply_chat_template: bool = False,
+        fewshot_as_multiturn: bool = False,
+        chat_template: Optional[Callable] = None,
+        tokenizer_name: str = "",
+        **kwargs,
+    ):
+        # If caching is disabled, just call the original function;
+        # the method will handle setting self._instances
+        if not cache_requests:
+            return func(
+                self,
+                limit=limit,
+                samples=samples,
+                rank=rank,
+                world_size=world_size,
+                cache_requests=cache_requests,
+                rewrite_requests_cache=rewrite_requests_cache,
+                system_instruction=system_instruction,
+                apply_chat_template=apply_chat_template,
+                fewshot_as_multiturn=fewshot_as_multiturn,
+                chat_template=chat_template,
+                tokenizer_name=tokenizer_name,
+                **kwargs,
+            )
+
+        # Build cache key
+        cache_key = _build_cache_key(
+            self._config.task,
+            self.config.num_fewshot,
+            rank,
+            world_size,
+            apply_chat_template,
+            fewshot_as_multiturn,
+            system_instruction,
+            tokenizer_name,
+        )
+
+        # Try to load from cache
+        cached_instances = load_from_cache(file_name=cache_key, cache=cache_requests)
+
+        # Return cached instances if available and not rewriting
+        if cached_instances and not rewrite_requests_cache:
+            cached_instances = (
+                cached_instances[:limit] if limit is not None else cached_instances
+            )
+            flattened_instances = [
+                instance
+                for instance_group in cached_instances
+                for instance in instance_group
+            ]
+            self._instances = flattened_instances
+            eval_logger.debug(
+                f"Using {len(flattened_instances)} contexts for {self.config.task} on rank {rank}..."
+            )
+            return
+
+        # Store original limit for later use
+        original_limit = limit
+
+        # Process all documents when caching for simplicity
+        if limit is not None:
+            limit = None
+
+        # Call the original function with modified parameters
+        instances = func(
+            self,
+            limit=limit,
+            samples=samples,
+            rank=rank,
+            world_size=world_size,
+            cache_requests=cache_requests,
+            rewrite_requests_cache=rewrite_requests_cache,
+            system_instruction=system_instruction,
+            apply_chat_template=apply_chat_template,
+            fewshot_as_multiturn=fewshot_as_multiturn,
+            chat_template=chat_template,
+            tokenizer_name=tokenizer_name,
+            **kwargs,
+        )
+
+        # Check if method handled everything (non-cache mode returns None)
+        if instances is None:
+            return
+
+        # Apply original limit if specified
+        sliced_instances = (
+            instances[:original_limit] if original_limit is not None else instances
+        )
+
+        # Flatten and set instances
+        flattened_instances = [
+            instance
+            for instance_group in sliced_instances
+            for instance in instance_group
+        ]
+        self._instances = flattened_instances
+
+        # Validate results
+        if len(self._instances) == 0:
+            raise ValueError("task.build_requests() did not find any docs!")
+
+        # Save to cache if we generated new instances
+        if not cached_instances or rewrite_requests_cache:
+            save_to_cache(file_name=cache_key, obj=instances)
+
+    return wrapper
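For reference, the key layout `_build_cache_key` above produces, with illustrative argument values:

```python
key = _build_cache_key(
    task="hellaswag",
    num_fewshot=5,
    rank=0,
    world_size=2,
    apply_chat_template=True,
    fewshot_as_multiturn=False,
    system_instruction=None,
    tokenizer_name="gpt2",
)
assert key == "requests-hellaswag-5shot-rank0-world_size2-chat_template-tokenizergpt2"
```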
@@ -241,15 +241,7 @@ def simple_evaluate(
     if use_cache is not None:
-        eval_logger.info(f"Using cache at {use_cache + '_rank' + str(lm.rank) + '.db'}")
-        lm = lm_eval.api.model.CachingLM(
-            lm,
-            use_cache
-            # each rank receives a different cache db.
-            # necessary to avoid multiple writes to cache at once
-            + "_rank"
-            + str(lm.rank)
-            + ".db",
-        )
+        eval_logger.info(f"Using cache at {use_cache}")
+        lm = lm_eval.api.model.CachingLM(lm, use_cache)

     if task_manager is None:
         metadata = (
...
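The per-rank db suffix is dropped because, with WAL mode enabled in `CachingLM` (see the model diff above), rank processes can share a single cache file: WAL lets readers proceed while one writer commits. A rough sketch of that sharing pattern, using only the calls that appear in the diff (`SqliteDict`, `.conn.execute`); the key name is hypothetical:

```python
from sqlitedict import SqliteDict

# Each rank opens the same file; the timeout covers brief writer contention.
db = SqliteDict("cache.db", autocommit=True, timeout=30.0)
db.conn.execute("PRAGMA journal_mode=WAL")  # readers no longer block the writer

db["hash_of_request_0"] = "cached response"  # ranks write disjoint keys
print(db["hash_of_request_0"])
db.close()
```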
@@ -208,14 +208,14 @@ def sanitize_model_name(model_name: str) -> str:
     """
     Given the model name, returns a sanitized version of it.
     """
-    return re.sub(r"[\"<>:/\|\\?\*\[\]]+", "__", model_name)
+    return re.sub(r"[\"<>:/|\\?*\[\]]+", "__", model_name)


 def sanitize_task_name(task_name: str) -> str:
     """
     Given the task name, returns a sanitized version of it.
     """
-    return re.sub(r"\W", "_", task_name)
+    return re.sub(r"\W+", "_", task_name)


 def get_latest_filename(filenames: List[str]) -> str:
...
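The model-name pattern only drops redundant escapes (`\|` and `\*` inside a character class already match the literal characters), so its behavior is unchanged. The task-name change from `\W` to `\W+` collapses runs of separators into a single underscore:

```python
import re

name = "mmlu: abstract algebra"
re.sub(r"\W", "_", name)   # 'mmlu__abstract_algebra' (one "_" per character)
re.sub(r"\W+", "_", name)  # 'mmlu_abstract_algebra'  (one "_" per run)
```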