Commit a6b358ca authored by Rayyyyy's avatar Rayyyyy
Browse files

update version

parent ed53d51c
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable, Iterable, List, Union
from lm_eval.api.instance import Instance
class Filter(ABC):
"""
Filter classes operate on a per-task level.
They take all model outputs (`instance.resps` for all `task.instances`)
across all instances of a task, and perform operations.
In a single run, one can configure any number of separate filters or lists of filters.
"""
def __init__(self, **kwargs) -> None:
"""
Can define custom behavior here, if an individual instantiation of a Filter class should have state.
"""
@abstractmethod
def apply(self, resps: Union[List, Iterable], docs: List[dict]) -> Iterable:
"""
Defines the operation to perform on a list of the `inst.resps` properties of `Instance` objects.
Should return the list of (filtered) response lists *in the same order as they were input*, e.g.
if pass in [<inst.resps for instance 0>, <inst.resps for instance 1>] should return
[<filtered resps for instance 0>, <filtered resps for instance 1>]
"""
return resps
@dataclass
class FilterEnsemble:
"""
FilterEnsemble creates a pipeline applying multiple filters.
Its intended usage is to stack multiple post-processing steps in order.
`task.apply_filters` should use a list of FilterEnsemble classes that it stores, to apply each
pipeline separately.
"""
name: str
filters: List[Callable[[], Filter]]
def apply(self, instances: List[Instance]) -> None:
resps, docs = zip(*((inst.resps, inst.doc) for inst in instances))
resps, docs = list(resps), list(docs)
for f in self.filters:
# apply filters in sequence
resps = f().apply(resps, docs)
# add the end results after filtering to filtered_requests of their respective source instances.
# has key `self.name`: each FilterEnsemble applied in a given run should use a different name.
for inst, resp in zip(instances, resps):
inst.filtered_resps[self.name] = resp
from dataclasses import dataclass, field
from typing import Literal, Optional, Tuple
OutputType = Literal[
"loglikelihood", "loglikelihood_rolling", "generate_until", "multiple_choice"
]
@dataclass
class Instance:
request_type: OutputType
doc: dict
arguments: tuple
idx: int
metadata: Tuple[Optional[str], Optional[int], Optional[int]] = field(
default_factory=lambda: (None, None, None)
)
resps: list = field(default_factory=list)
filtered_resps: dict = field(default_factory=dict)
# initialized after init
task_name: Optional[str] = None
doc_id: Optional[int] = None
repeats: Optional[int] = None
def __post_init__(self) -> None:
# unpack metadata field
self.task_name, self.doc_id, self.repeats = self.metadata
@property
def args(self):
"""
Returns (string,) where `string` is the string to calculate loglikelihood over
"""
return (
self.arguments if isinstance(self.arguments, tuple) else (self.arguments,)
)
This diff is collapsed.
import abc
import hashlib
import json
import logging
import os
from typing import List, Optional, Tuple, Type, TypeVar
import transformers
from sqlitedict import SqliteDict
from tqdm import tqdm
from lm_eval import utils
eval_logger = logging.getLogger("lm-eval")
T = TypeVar("T", bound="LM")
class LM(abc.ABC):
def __init__(self) -> None:
"""Defines the interface that should be implemented by all LM subclasses.
LMs are assumed to take text (strings) as input and yield strings as output
(inputs/outputs should be tokenization-agnostic.)
"""
# set rank and world size to a single process, by default.
self._rank = 0
self._world_size = 1
self.cache_hook = CacheHook(None)
@abc.abstractmethod
def loglikelihood(self, requests) -> List[Tuple[float, bool]]:
"""Compute log-likelihood of generating a continuation from a context.
Downstream tasks should attempt to use loglikelihood instead of other
LM calls whenever possible.
:param requests: list[Instance]
A list of Instance objects, with property `args` which returns a tuple (context, continuation).
`context: str`
Context string. Implementations of LM must be able to handle an
empty context string.
`continuation: str`
The continuation over which log likelihood will be calculated. If
there is a word boundary, the space should be in the continuation.
For example, context="hello" continuation=" world" is correct.
:return: list[tuple[float, bool]]
A list of pairs (logprob, isgreedy)
`logprob: float`
The log probability of `continuation`.
`isgreedy`:
Whether `continuation` would be generated by greedy sampling from `context`.
"""
pass
@abc.abstractmethod
def loglikelihood_rolling(self, requests) -> List[Tuple[float]]:
"""Compute full log-likelihood of a string, with no truncation, for perplexity computation
- We will use the full max context length of the model.
- For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
the max context length.
- IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations
which may simply concatenate multiple documents together.
- IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into
multiple chunks, the last input will still a full-sized context.
Example:
Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ]
Prefix: BOS/EOS
Max context length: 4
Resulting input/prediction pairs:
INPUT: BOS 0 1 2
PRED: 0 1 2 3
INPUT: 3 4 5 6
PRED: 4 5 6 7
INPUT: 5 6 7 8
PRED: 8 9
Observe that:
1. Each token is predicted exactly once
2. For the last pair, we provide the full context, but only score the last two tokens
:param requests: list[Instance]
A list of Instance objects with property `args` which returns a tuple (context,).
string: str
String for which we are computing overall loglikelihood
:return: list[tuple[float]]
A list of tuples (logprob,)
logprob: float
The log probability of `context` conditioned on the BOS/EOS token.
Can also be overridden for custom cases by `prefix_token_id`.
"""
pass
# TODO: Add an optional max length
@abc.abstractmethod
def generate_until(self, requests) -> List[str]:
"""Generate greedily until a stopping sequence
:param requests: list[Instance]
A list of Instance objects with property `args` which returns a tuple (context, until).
context: str
Context string
until: [str]
The string sequences to generate until. These string sequences
may each span across multiple tokens, or may be part of one token.
:return: list[str]
A list of strings continuation
continuation: str
The generated continuation.
"""
pass
@classmethod
def create_from_arg_string(
cls: Type[T], arg_string: str, additional_config: Optional[dict] = None
) -> T:
"""
Creates an instance of the LM class using the given argument string and additional config.
Parameters:
- arg_string: A string containing arguments in the format key1=value1,key2=value2.
- additional_config: Optional dictionary containing additional configuration parameters.
Returns:
- Instance of the LM class.
"""
additional_config = {} if additional_config is None else additional_config
args = utils.simple_parse_args_string(arg_string)
args2 = {k: v for k, v in additional_config.items() if v is not None}
return cls(**args, **args2)
@classmethod
def create_from_arg_obj(
cls: Type[T], arg_dict: dict, additional_config: Optional[dict] = None
) -> T:
"""
Creates an instance of the LM class using the given arg_obj
Parameters:
- arg_obj: A dict containing arguments in the format key1=value1,key2=value2.
- additional_config: Optional dictionary containing additional configuration parameters.
Returns:
- Instance of the LM class.
"""
additional_config = {} if additional_config is None else additional_config
additional_config = {
k: v for k, v in additional_config.items() if v is not None
}
return cls(**arg_dict, **additional_config)
@property
def rank(self):
# used in the case of parallelism. Hardcoded to
# ensure no errors arise using API models which do
# not support multi-device parallelism nor expect it.
return self._rank
@property
def world_size(self):
# used in the case of parallelism. Hardcoded to
# ensure no errors arise using API models which do
# not support multi-device parallelism nor expect it.
return self._world_size
def set_cache_hook(self, cache_hook) -> None:
self.cache_hook = cache_hook
### SQLite-based caching of LM responses
def hash_args(attr, args):
dat = json.dumps([attr] + list(args))
return hashlib.sha256(dat.encode("utf-8")).hexdigest()
class CacheHook:
def __init__(self, cachinglm) -> None:
if cachinglm is None:
self.dbdict = None
return
self.dbdict = cachinglm.dbdict
def add_partial(self, attr, req, res) -> None:
if self.dbdict is None:
return
hsh = hash_args(attr, req)
self.dbdict[hsh] = res
class CachingLM:
def __init__(self, lm, cache_db) -> None:
"""LM wrapper that returns cached results if they exist, and uses the underlying LM if not.
:param lm: LM
Underlying LM
:param cache_db: str
Path to cache db
"""
self.lm = lm
self.cache_db = cache_db
if os.path.dirname(cache_db):
os.makedirs(os.path.dirname(cache_db), exist_ok=True)
self.dbdict = SqliteDict(cache_db, autocommit=True)
# add hook to lm
lm.set_cache_hook(self.get_cache_hook())
def __getattr__(self, attr):
lm_attr = getattr(self.lm, attr)
if not callable(lm_attr):
return lm_attr
def fn(requests):
res = []
remaining_reqs = []
warned = False
# figure out which ones are cached and which ones are new
eval_logger.info(
f"Loading '{attr}' responses from cache '{self.cache_db}' where possible..."
)
for req in tqdm(requests, desc="Checking cached requests"):
hsh = hash_args(attr, req.args)
if attr == "generate_until" and req.args[1].get("do_sample", False):
# when we are doing non-greedy generation, don't use the cache
# (else every "randomly sampled" generation would be identical for repeats > 1).
if not warned:
eval_logger.warning(
f"Arguments to lm.generate_until() '{req.args[1]}' include non-deterministic sampling. Caching will not be performed for such requests."
)
warned = True
res.append(None)
remaining_reqs.append(req)
elif hsh in self.dbdict:
ob = self.dbdict[hsh]
assert ob is not None
res.append(ob)
else:
res.append(None)
remaining_reqs.append(req)
eval_logger.info(
f"Cached requests: {len(requests) - len(remaining_reqs)}, Requests remaining: {len(remaining_reqs)}"
)
# actually run the LM on the requests that do not have cached results
rem_res = getattr(self.lm, attr)(remaining_reqs)
# stick the new ones back into the list and also cache any of the new ones
resptr = 0
for req, r in zip(remaining_reqs, rem_res):
while res[resptr] is not None:
resptr += 1
res[resptr] = r
# caching
hsh = hash_args(attr, req.args)
self.dbdict[hsh] = r
self.dbdict.commit()
return res
return fn
def get_cache_hook(self):
return CacheHook(self)
class TemplateLM(LM):
"""
A class acting as intermediary between the LM base class
and boilerplate often included in other LM subclasses.
"""
@property
@abc.abstractmethod
def eot_token_id(self):
pass
@property
def prefix_token_id(self):
# it is used as prefix for loglikelihood
return self.eot_token_id
@abc.abstractmethod
def tok_encode(self, string: str, **kwargs):
pass
@abc.abstractmethod
def _loglikelihood_tokens(self, requests, **kwargs):
pass
def _encode_pair(self, context, continuation):
n_spaces = len(context) - len(context.rstrip())
if n_spaces > 0:
continuation = context[-n_spaces:] + continuation
context = context[:-n_spaces]
model_class = getattr(self, "AUTO_MODEL_CLASS", None)
if model_class == transformers.AutoModelForSeq2SeqLM:
context_enc = self.tok_encode(context)
continuation_enc = self.tok_encode(continuation, add_special_tokens=False)
else:
whole_enc = self.tok_encode(context + continuation)
context_enc = self.tok_encode(context)
context_enc_len = len(context_enc)
continuation_enc = whole_enc[context_enc_len:]
return context_enc, continuation_enc
def loglikelihood(
self, requests, disable_tqdm: bool = False
) -> List[Tuple[float, bool]]:
new_reqs = []
for context, continuation in [req.args for req in requests]:
if context == "":
# BOS or EOS as context
context_enc, continuation_enc = (
[self.prefix_token_id],
self.tok_encode(continuation),
)
else:
context_enc, continuation_enc = self._encode_pair(context, continuation)
new_reqs.append(((context, continuation), context_enc, continuation_enc))
return self._loglikelihood_tokens(new_reqs, disable_tqdm=disable_tqdm)
@abc.abstractmethod
def loglikelihood_rolling(
self, requests, disable_tqdm: bool = False
) -> List[Tuple[float, bool]]:
pass
@abc.abstractmethod
def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]:
pass
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
import hashlib
import os
import dill
from lm_eval.utils import eval_logger
MODULE_DIR = os.path.dirname(os.path.realpath(__file__))
OVERRIDE_PATH = os.getenv("LM_HARNESS_CACHE_PATH")
PATH = OVERRIDE_PATH if OVERRIDE_PATH else f"{MODULE_DIR}/.cache"
# This should be sufficient for uniqueness
HASH_INPUT = "EleutherAI-lm-evaluation-harness"
HASH_PREFIX = hashlib.sha256(HASH_INPUT.encode("utf-8")).hexdigest()
FILE_SUFFIX = f".{HASH_PREFIX}.pickle"
def load_from_cache(file_name):
try:
path = f"{PATH}/{file_name}{FILE_SUFFIX}"
with open(path, "rb") as file:
cached_task_dict = dill.loads(file.read())
return cached_task_dict
except Exception:
eval_logger.debug(f"{file_name} is not cached, generating...")
pass
def save_to_cache(file_name, obj):
if not os.path.exists(PATH):
os.mkdir(PATH)
file_path = f"{PATH}/{file_name}{FILE_SUFFIX}"
eval_logger.debug(f"Saving {file_path} to cache...")
with open(file_path, "wb") as file:
file.write(dill.dumps(obj))
# NOTE the "key" param is to allow for flexibility
def delete_cache(key: str = ""):
files = os.listdir(PATH)
for file in files:
if file.startswith(key) and file.endswith(FILE_SUFFIX):
file_path = f"{PATH}/{file}"
os.unlink(file_path)
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment