Commit f19408c3 authored by Baber's avatar Baber
Browse files

feat: implement caching decorator for request handling and improve cache key generation

parent 7aaceeec
......@@ -2,7 +2,7 @@ import abc
import hashlib
import json
import logging
import os
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Type, TypeVar, Union
import transformers
......@@ -230,7 +230,7 @@ class CacheHook:
class CachingLM:
def __init__(self, lm, cache_db) -> None:
def __init__(self, lm: "LM", cache_db: str) -> None:
"""LM wrapper that returns cached results if they exist, and uses the underlying LM if not.
:param lm: LM
......@@ -239,12 +239,22 @@ class CachingLM:
Path to cache db
"""
self.lm = lm
self.cache_db = cache_db
if os.path.dirname(cache_db):
os.makedirs(os.path.dirname(cache_db), exist_ok=True)
self.dbdict = SqliteDict(cache_db, autocommit=True)
# add hook to lm
# Setup cache path
cache_path = Path(cache_db)
if cache_path.is_dir() or (not cache_path.suffix and not cache_path.exists()):
cache_path = cache_path / "cache.db"
self.cache_db = cache_path
cache_path.parent.mkdir(parents=True, exist_ok=True)
# Initialize database with WAL mode for concurrent access
self.dbdict = SqliteDict(str(cache_path), autocommit=True, timeout=30.0)
# Enable WAL mode for better concurrency
self.dbdict.conn.execute("PRAGMA journal_mode=WAL")
self.dbdict.conn.commit()
lm.set_cache_hook(self.get_cache_hook())
def __getattr__(self, attr: str):
......
......@@ -36,7 +36,7 @@ from lm_eval.api.registry import (
get_metric_aggregation,
is_higher_better,
)
from lm_eval.caching.cache import load_from_cache, save_to_cache
from lm_eval.caching.cache import cache_instances
from lm_eval.filters import build_filter_ensemble
from lm_eval.prompts import get_prompt
......@@ -387,6 +387,7 @@ class Task(abc.ABC):
def doc_to_prefix(self, doc):
    """Return the generation prefix for *doc*; the base Task uses none."""
    return ""
@cache_instances
def build_all_requests(
self,
*,
......@@ -394,68 +395,28 @@ class Task(abc.ABC):
samples: Optional[List[int]] = None,
rank: int = 0,
world_size: int = 1,
cache_requests: bool = False,
rewrite_requests_cache: bool = False,
system_instruction: Optional[str] = None,
apply_chat_template: bool = False,
fewshot_as_multiturn: bool = False,
chat_template: Optional[Callable] = None,
cache_requests: bool = False,
rewrite_requests_cache: bool = False,
tokenizer_name: str = "",
) -> None:
) -> Optional[List[List[Instance]]]:
"""Build a set of Instances for a task, and store them in task.instances"""
# used with caching
og_limit = limit
cache_key = f"requests-{self._config.task}-{self.config.num_fewshot}shot-rank{rank}-world_size{world_size}"
cache_key += "-chat_template" if apply_chat_template else ""
cache_key += "-fewshot_as_multiturn" if fewshot_as_multiturn else ""
cache_key += (
f"-system_prompt_hash{utils.hash_string(system_instruction)}"
if system_instruction is not None
else ""
)
cache_key += f"-tokenizer{tokenizer_name}"
cached_instances = load_from_cache(file_name=cache_key, cache=cache_requests)
if cache_requests and cached_instances and not rewrite_requests_cache:
cached_instances = cached_instances[:limit]
flattened_instances = [
instance
for instance_group in cached_instances
for instance in instance_group
]
self._instances = flattened_instances
return
eval_logger.info(f"Building contexts for {self.config.task} on rank {rank}...")
instances = []
# process all documents when caching is specified for simplicity
if (
cache_requests
and (not cached_instances or rewrite_requests_cache)
and limit is not None
):
limit = None
doc_id_docs = list(
self.doc_iterator(
rank=rank, limit=limit, samples=samples, world_size=world_size
)
)
num_docs = len(doc_id_docs)
for doc_id, doc in tqdm(
doc_id_docs,
total=num_docs,
):
# sample fewshot context #TODO: need to offset doc_id by rank now!
for doc_id, doc in tqdm(doc_id_docs, total=len(doc_id_docs)):
# sample fewshot context
fewshot_ctx = self.fewshot_context(
doc,
0 if self.config.num_fewshot is None else self.config.num_fewshot,
......@@ -466,7 +427,7 @@ class Task(abc.ABC):
gen_prefix=self.doc_to_prefix(doc),
)
# TODO: we should override self.config.repeats if doing greedy gen so users don't waste time+compute
# construct requests
inst = self.construct_requests(
doc=doc,
ctx=fewshot_ctx,
......@@ -480,23 +441,25 @@ class Task(abc.ABC):
instances.append(inst)
# now flatten, this is to allow slicing to work with pickles
# Handle non-caching case
if not cache_requests:
# Apply limit at document level, then flatten
if limit is not None:
instances = instances[:limit]
sliced_instances = instances[:og_limit]
flattened_instances = [
instance for instance_group in instances for instance in instance_group
]
flattened_instances = [
instance
for instance_group in sliced_instances
for instance in instance_group
]
self._instances = flattened_instances
self._instances = flattened_instances
if len(self._instances) == 0:
raise ValueError("task.build_requests() did not find any docs!")
if len(self._instances) == 0:
raise ValueError("task.build_requests() did not find any docs!")
return None
if cache_requests and (not cached_instances or rewrite_requests_cache):
save_to_cache(file_name=cache_key, obj=instances)
# Return instances for decorator to handle
return instances
@abc.abstractmethod
def construct_requests(self, doc, ctx, **kwargs):
......
import hashlib
import logging
import os
import dill
from functools import wraps
from typing import Callable, List, Optional, Union
eval_logger = logging.getLogger(__name__)
......@@ -27,6 +27,8 @@ def load_from_cache(file_name: str, cache: bool = False):
if not cache:
return
try:
import dill
path = f"{PATH}/{file_name}{FILE_SUFFIX}"
with open(path, "rb") as file:
......@@ -39,6 +41,8 @@ def load_from_cache(file_name: str, cache: bool = False):
def save_to_cache(file_name, obj):
import dill
if not os.path.exists(PATH):
os.mkdir(PATH)
......@@ -57,3 +61,152 @@ def delete_cache(key: str = ""):
if file.startswith(key) and file.endswith(FILE_SUFFIX):
file_path = f"{PATH}/{file}"
os.unlink(file_path)
def _build_cache_key(
task: str,
num_fewshot: int,
rank: int,
world_size: int,
apply_chat_template: bool,
fewshot_as_multiturn: bool,
system_instruction: Optional[str],
tokenizer_name: str,
) -> str:
"""Build cache key from parameters"""
cache_key = f"requests-{task}-{num_fewshot}shot-rank{rank}-world_size{world_size}"
if apply_chat_template:
cache_key += "-chat_template"
if fewshot_as_multiturn:
cache_key += "-fewshot_as_multiturn"
if system_instruction is not None:
# Import utils here to avoid circular imports
import utils
cache_key += f"-system_prompt_hash{utils.hash_string(system_instruction)}"
cache_key += f"-tokenizer{tokenizer_name}"
return cache_key
def cache_instances(func):
    """Decorator handling request caching for ``Task.build_all_requests``.

    Behavior:
    - ``cache_requests=False``: call the wrapped method unchanged; it sets
      ``self._instances`` itself and returns ``None``.
    - cache hit (and not ``rewrite_requests_cache``): slice the cached
      instance groups to ``limit``, flatten them into ``self._instances``,
      and skip the wrapped method entirely.
    - cache miss or forced rewrite: call the wrapped method with
      ``limit=None`` so the cached artifact is limit-independent, re-apply
      the caller's limit to the returned instance groups, flatten into
      ``self._instances``, and persist the unsliced groups.

    Raises:
        ValueError: if a freshly built request set is empty.
    """

    @wraps(func)
    def wrapper(
        self,
        *,
        limit: Union[int, None] = None,
        samples: Optional[List[int]] = None,
        rank: int = 0,
        world_size: int = 1,
        cache_requests: bool = False,
        rewrite_requests_cache: bool = False,
        system_instruction: Optional[str] = None,
        apply_chat_template: bool = False,
        fewshot_as_multiturn: bool = False,
        chat_template: Optional[Callable] = None,
        tokenizer_name: str = "",
        **kwargs,
    ):
        # Single forwarding dict so the two call sites below cannot drift apart.
        call_kwargs = dict(
            limit=limit,
            samples=samples,
            rank=rank,
            world_size=world_size,
            cache_requests=cache_requests,
            rewrite_requests_cache=rewrite_requests_cache,
            system_instruction=system_instruction,
            apply_chat_template=apply_chat_template,
            fewshot_as_multiturn=fewshot_as_multiturn,
            chat_template=chat_template,
            tokenizer_name=tokenizer_name,
            **kwargs,
        )

        # Caching disabled: the wrapped method handles self._instances itself.
        if not cache_requests:
            return func(self, **call_kwargs)

        cache_key = _build_cache_key(
            self._config.task,
            self.config.num_fewshot,
            rank,
            world_size,
            apply_chat_template,
            fewshot_as_multiturn,
            system_instruction,
            tokenizer_name,
        )

        cached_instances = load_from_cache(file_name=cache_key, cache=cache_requests)

        # Serve from cache when available and a rewrite was not requested.
        if cached_instances and not rewrite_requests_cache:
            if limit is not None:
                cached_instances = cached_instances[:limit]
            # Flatten instance groups; slicing happens at the group level above.
            self._instances = [
                instance
                for instance_group in cached_instances
                for instance in instance_group
            ]
            eval_logger.debug(
                f"Using {len(self._instances)} contexts for {self.config.task} on rank {rank}..."
            )
            return

        # Cache miss (or forced rewrite): build *all* documents so the cached
        # artifact does not depend on the caller's limit.
        original_limit = limit
        call_kwargs["limit"] = None

        instances = func(self, **call_kwargs)

        # None means the wrapped method already populated self._instances
        # (its own non-cache path).
        if instances is None:
            return

        # Re-apply the caller's limit at the instance-group level, then flatten.
        sliced_instances = (
            instances[:original_limit] if original_limit is not None else instances
        )
        self._instances = [
            instance
            for instance_group in sliced_instances
            for instance in instance_group
        ]
        if len(self._instances) == 0:
            raise ValueError("task.build_requests() did not find any docs!")

        # Persist the unsliced groups for future runs.
        if not cached_instances or rewrite_requests_cache:
            save_to_cache(file_name=cache_key, obj=instances)

    return wrapper
......@@ -241,15 +241,7 @@ def simple_evaluate(
if use_cache is not None:
eval_logger.info(f"Using cache at {use_cache + '_rank' + str(lm.rank) + '.db'}")
lm = lm_eval.api.model.CachingLM(
lm,
use_cache
# each rank receives a different cache db.
# necessary to avoid multiple writes to cache at once
+ "_rank"
+ str(lm.rank)
+ ".db",
)
lm = lm_eval.api.model.CachingLM(lm, use_cache)
if task_manager is None:
metadata = (
......
......@@ -208,14 +208,14 @@ def sanitize_model_name(model_name: str) -> str:
"""
Given the model name, returns a sanitized version of it.
"""
return re.sub(r"[\"<>:/\|\\?\*\[\]]+", "__", model_name)
return re.sub(r"[\"<>:/|\\?*\[\]]+", "__", model_name)
def sanitize_task_name(task_name: str) -> str:
"""
Given the task name, returns a sanitized version of it.
"""
return re.sub(r"\W", "_", task_name)
return re.sub(r"\W+", "_", task_name)
def get_latest_filename(filenames: List[str]) -> str:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment