Commit 451e73f1 authored by Baber

add classes to inputs/outputs

parent e30978c7
from dataclasses import dataclass, field
from typing import Literal, Optional, Tuple
from typing import Generic, Literal, Optional, Tuple, TypeVar
from lm_eval.api.types import GenerateInput, LoglikelihoodInput
OutputType = Literal[
"loglikelihood", "loglikelihood_rolling", "generate_until", "multiple_choice"
]
T = TypeVar("T", LoglikelihoodInput, GenerateInput)
@dataclass
class Instance:
class Instance(Generic[T]):
request_type: OutputType
doc: dict
arguments: tuple
arguments: T
idx: int
metadata: Tuple[Optional[str], Optional[int], Optional[int]] = field(
default_factory=lambda: (None, None, None)
......@@ -29,10 +33,8 @@ class Instance:
self.task_name, self.doc_id, self.repeats = self.metadata
@property
def args(self):
def args(self) -> T:
"""
Returns (string,) where `string` is the string to calculate loglikelihood over
"""
return (
self.arguments if isinstance(self.arguments, tuple) else (self.arguments,)
)
return self.arguments
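As a quick illustration of the new generic dataclass, here is a minimal sketch of constructing and consuming an Instance[LoglikelihoodInput]; the doc and strings are made-up example values, only the class and field names come from this diff:

from lm_eval.api.instance import Instance
from lm_eval.api.types import LoglikelihoodInput

# Example values are illustrative; only the field names come from the dataclass above.
req = Instance(
    request_type="loglikelihood",
    doc={"question": "2+2=", "answer": " 4"},
    arguments=LoglikelihoodInput(context="2+2=", continuation=" 4"),
    idx=0,
)

# `args` now returns the typed dataclass as-is instead of coercing to a tuple.
assert req.args.context == "2+2="
assert req.args.continuation == " 4"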
......@@ -8,6 +8,11 @@ from typing import TYPE_CHECKING, Any, Iterable, Optional, Type, TypeVar, Union
from tqdm import tqdm
from lm_eval import utils
from lm_eval.api.instance import Instance
from lm_eval.api.types import (
LoglikelihoodInput,
LoglikelihoodOutput,
)
if TYPE_CHECKING:
......@@ -34,7 +39,7 @@ class LM(abc.ABC):
self.cache_hook: "CacheHook" = CacheHook(None)
@abc.abstractmethod
def loglikelihood(self, requests) -> list[tuple[float, bool]]:
def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]:
"""Compute log-likelihood of generating a continuation from a context.
Downstream tasks should attempt to use loglikelihood instead of other
LM calls whenever possible.
......@@ -59,7 +64,7 @@ class LM(abc.ABC):
pass
@abc.abstractmethod
def loglikelihood_rolling(self, requests) -> list[float]:
def loglikelihood_rolling(self, requests: list[Instance]) -> list[float]:
"""Compute full log-likelihood of a string, with no truncation, for perplexity computation
- We will use the full max context length of the model.
- For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
......@@ -101,7 +106,7 @@ class LM(abc.ABC):
# TODO: Add an optional max length
@abc.abstractmethod
def generate_until(self, requests) -> list[str]:
def generate_until(self, requests: list[Instance]) -> list[str]:
"""Generate greedily until a stopping sequence
:param requests: list[Instance]
......@@ -376,7 +381,9 @@ class TemplateLM(LM):
self, requests: list["Instance"], disable_tqdm: bool = False
) -> list[tuple[float, bool]]:
new_reqs = []
for context, continuation in [req.args for req in requests]:
for context, continuation in (
(req.args.context, req.args.continuation) for req in requests
):
if context == "":
# BOS or EOS as context
context_enc, continuation_enc = (
......@@ -392,12 +399,14 @@ class TemplateLM(LM):
@abc.abstractmethod
def loglikelihood_rolling(
self, requests, disable_tqdm: bool = False
self, requests: list[Instance], disable_tqdm: bool = False
) -> list[float]:
pass
@abc.abstractmethod
def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]:
def generate_until(
self, requests: list[Instance], disable_tqdm: bool = False
) -> list[str]:
pass
def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
......
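A small sketch of the access-pattern change in TemplateLM.loglikelihood above: the (context, continuation) pair is now read off the typed dataclass by attribute instead of being unpacked from a bare tuple (the request values below are stand-ins, not the harness's real plumbing):

from lm_eval.api.types import LoglikelihoodInput

# Stand-ins for `req.args` on a batch of loglikelihood requests.
args_list = [
    LoglikelihoodInput(context="The capital of France is", continuation=" Paris"),
    LoglikelihoodInput(context="", continuation="Hello world"),  # empty context -> BOS/EOS path
]

# old: for context, continuation in [req.args for req in requests]:
# new: attribute access on the typed input
for context, continuation in ((a.context, a.continuation) for a in args_list):
    if context == "":
        print("would fall back to BOS/EOS as the context")
    else:
        print(f"encode context={context!r}, continuation={continuation!r}")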
......@@ -36,6 +36,7 @@ from lm_eval.api.registry import (
get_metric_aggregation,
is_higher_better,
)
from lm_eval.api.types import GenerateInput, LoglikelihoodInput
from lm_eval.caching.cache import load_from_cache, save_to_cache
from lm_eval.filters import build_filter_ensemble
from lm_eval.prompts import get_prompt
......@@ -1493,6 +1494,13 @@ class ConfigurableTask(Task):
elif self.OUTPUT_TYPE == "generate_until":
arguments = (ctx, deepcopy(self.config.generation_kwargs))
else:
raise ValueError(
f"Unsupported OUTPUT_TYPE: '{self.OUTPUT_TYPE}'. "
f"Expected one of: 'loglikelihood', 'loglikelihood_rolling', "
f"'multiple_choice', 'generate_until'"
)
multimodal_arg = {}
if (
self.config.doc_to_image
......@@ -1521,7 +1529,7 @@ class ConfigurableTask(Task):
Instance(
request_type="loglikelihood",
doc=doc,
arguments=arg,
arguments=LoglikelihoodInput(context=arg[0], continuation=arg[1]),
idx=i,
**kwargs,
)
......@@ -1533,7 +1541,9 @@ class ConfigurableTask(Task):
return Instance(
request_type=self.OUTPUT_TYPE,
doc=doc,
arguments=arguments,
arguments=LoglikelihoodInput(*arguments)
if self.OUTPUT_TYPE in ["loglikelihood", "loglikelihood_rolling"]
else GenerateInput(*arguments),
idx=0,
**kwargs,
)
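A condensed sketch of what the construct_requests change above does with the raw argument tuples; the example tuple and kwargs are illustrative, while the wrapping rule mirrors the diff:

from lm_eval.api.types import GenerateInput, LoglikelihoodInput

output_type = "generate_until"                      # one of the OutputType literals
arguments = ("Q: 2+2=\nA:", {"max_gen_toks": 16})   # (ctx, generation_kwargs); values made up

wrapped = (
    LoglikelihoodInput(*arguments)
    if output_type in ("loglikelihood", "loglikelihood_rolling")
    else GenerateInput(*arguments)
)
assert isinstance(wrapped, GenerateInput)
assert wrapped.prompt.startswith("Q:") and wrapped.gen_kwargs["max_gen_toks"] == 16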
......@@ -1846,7 +1856,7 @@ class MultipleChoiceTask(Task):
class PerplexityTask(Task):
OUTPUT_TYPE = "loglikelihood_rolling"
OUTPUT_TYPE: OutputType = "loglikelihood_rolling"
def has_training_docs(self) -> bool:
return False
......
from dataclasses import dataclass
from typing import Optional
@dataclass
class GenerateInput:
"""
Inputs for the generate function.
"""
prompt: str
gen_kwargs: dict
multimodal_arg: Optional[dict] = None
@dataclass
class GenerateOutput:
"""
Outputs for the generate function.
"""
text: str
metadata: Optional[dict] = None
@dataclass
class LoglikelihoodInput:
"""
Inputs for the loglikelihood function.
"""
context: str
continuation: Optional[str] = None
@dataclass
class LoglikelihoodOutput:
"""
Outputs for the loglikelihood function.
"""
loglikelihood: float
is_greedy: Optional[bool] = None
ctx_tokens: Optional[list[int]] = None
cont_tokens: Optional[list[int]] = None
metadata: Optional[dict] = None
def __iter__(self):
return iter((self.loglikelihood, self.is_greedy))
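Because LoglikelihoodOutput defines __iter__ over (loglikelihood, is_greedy), call sites that still unpack the old (float, bool) tuples keep working while new code can use the named fields; a quick sketch with made-up numbers:

from lm_eval.api.types import LoglikelihoodOutput

out = LoglikelihoodOutput(loglikelihood=-0.42, is_greedy=True, cont_tokens=[3681])

# tuple-style unpacking still works via __iter__ ...
ll, greedy = out
assert (ll, greedy) == (-0.42, True)

# ... and the extra token-level fields are available by name.
assert out.cont_tokens == [3681]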
......@@ -560,6 +560,8 @@ def evaluate(
# create `K` copies of each request `req` based off `K = req.repeats`
cloned_reqs = []
for req in reqs:
# Note: [req] * req.repeats creates multiple references to the same request object,
# not separate copies. This means all repeated entries point to the same req.resps list
cloned_reqs.extend([req] * req.repeats)
if (lm.world_size > 1) and (padding_requests[reqtype] > 0):
......@@ -567,6 +569,8 @@ def evaluate(
cloned_reqs.extend([req] * req.repeats)
# run requests through model
# Since cloned_reqs contains references to original objects, each response
# automatically gets appended to the correct req.resps list
resps = getattr(lm, reqtype)(cloned_reqs)
# put responses from model into a list of length K for each request.
......
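The aliasing that the two comments above describe is easy to demonstrate with a toy request object; `[req] * req.repeats` repeats one reference rather than copying, so a response appended through any entry lands on the original request's resps list:

class ToyReq:  # hypothetical stand-in for an Instance, for illustration only
    def __init__(self, repeats: int):
        self.repeats = repeats
        self.resps = []

req = ToyReq(repeats=3)
cloned = [req] * req.repeats          # three references to the SAME object, not copies
assert all(r is req for r in cloned)

cloned[0].resps.append("model output")
assert req.resps == ["model output"]  # the response shows up on the original request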
......@@ -27,6 +27,7 @@ from lm_eval import utils
from lm_eval.api.instance import Instance
from lm_eval.api.model import TemplateLM
from lm_eval.api.registry import register_model
from lm_eval.api.types import GenerateInput, GenerateOutput, LoglikelihoodOutput
from lm_eval.models.utils import (
Collator,
clear_torch_cache,
......@@ -965,7 +966,7 @@ class HFLM(TemplateLM):
def loglikelihood_rolling(
self, requests: List[Instance], disable_tqdm: bool = False
) -> List[float]:
) -> list[LoglikelihoodOutput]:
adaptive_batch_size = None
if self.batch_size == "auto":
# using rolling window with maximum context
......@@ -1025,7 +1026,7 @@ class HFLM(TemplateLM):
override_bs=len(batch_windows),
)
# Store results with their request indices
all_nlls.extend(zip(batch_indices, batch_nlls))
all_nlls.extend(zip(batch_indices, (x.loglikelihood for x in batch_nlls)))
# Remove padding if necessary
if (self.world_size > 1) and (pad_amnt > 0):
......@@ -1038,8 +1039,8 @@ class HFLM(TemplateLM):
# Get all nlls for this request
request_nlls = all_nlls[current_idx : current_idx + window_count]
# Sum up the nlls for this request (discarding is_greedy)
request_total = sum(nll[0] for _, nll in request_nlls)
loglikelihoods.append(request_total)
request_total = sum(nll for _, nll in request_nlls)
loglikelihoods.append(LoglikelihoodOutput(loglikelihood=request_total))
current_idx += window_count
string = requests[len(loglikelihoods) - 1].args[0]
......@@ -1071,7 +1072,7 @@ class HFLM(TemplateLM):
requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
disable_tqdm: bool = False,
override_bs: int = None,
) -> List[Tuple[float, bool]]:
) -> List[LoglikelihoodOutput]:
# TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
res = []
......@@ -1286,7 +1287,13 @@ class HFLM(TemplateLM):
# Answer: (log prob, is-exact-match)
answer = (float(logits.sum()), bool(max_equal))
res.append(answer)
res.append(
LoglikelihoodOutput(
*answer,
ctx_tokens=ctx_tokens,
cont_tokens=cont_toks.tolist(),
)
)
if request_str is not None:
# special case: loglikelihood_rolling produces a number of loglikelihood requests
......@@ -1302,8 +1309,8 @@ class HFLM(TemplateLM):
return re_ord.get_original(res)
def generate_until(
self, requests: List[Instance], disable_tqdm: bool = False
) -> List[str]:
self, requests: List[Instance[GenerateInput]], disable_tqdm: bool = False
) -> List[GenerateOutput]:
res = []
def _collate(req: Tuple[str, dict]):
......@@ -1420,7 +1427,7 @@ class HFLM(TemplateLM):
# for seq2seq case where self.tok_decode(self.eot_token_id) = ''
s = s.split(term)[0]
res.append(s)
res.append(GenerateOutput(text=s))
self.cache_hook.add_partial("generate_until", (context, gen_kwargs), s)
pbar.update(1)
......
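One detail from the _loglikelihood_tokens change above: LoglikelihoodOutput(*answer, ...) fills the first two dataclass fields positionally, so the (log prob, is-exact-match) tuple maps onto loglikelihood and is_greedy while the token lists ride along as keywords; a quick check with made-up values:

from lm_eval.api.types import GenerateOutput, LoglikelihoodOutput

answer = (-1.73, False)  # (log prob, is-exact-match), as computed in the diff
out = LoglikelihoodOutput(*answer, ctx_tokens=[101, 102], cont_tokens=[103])
assert out.loglikelihood == -1.73 and out.is_greedy is False

# generate_until now wraps plain strings the same way:
gen = GenerateOutput(text="4")
assert gen.text == "4"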