Commit 451e73f1 authored by Baber

add classes to inputs/outputs

parent e30978c7
 from dataclasses import dataclass, field
-from typing import Literal, Optional, Tuple
+from typing import Generic, Literal, Optional, Tuple, TypeVar
+
+from lm_eval.api.types import GenerateInput, LoglikelihoodInput

 OutputType = Literal[
     "loglikelihood", "loglikelihood_rolling", "generate_until", "multiple_choice"
 ]

+T = TypeVar("T", LoglikelihoodInput, GenerateInput)

 @dataclass
-class Instance:
+class Instance(Generic[T]):
     request_type: OutputType
     doc: dict
-    arguments: tuple
+    arguments: T
     idx: int
     metadata: Tuple[Optional[str], Optional[int], Optional[int]] = field(
         default_factory=lambda: (None, None, None)

@@ -29,10 +33,8 @@ class Instance:
         self.task_name, self.doc_id, self.repeats = self.metadata

     @property
-    def args(self):
+    def args(self) -> T:
         """
         Returns (string,) where `string` is the string to calculate loglikelihood over
         """
-        return (
-            self.arguments if isinstance(self.arguments, tuple) else (self.arguments,)
-        )
+        return self.arguments
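For context, a minimal sketch of how a typed Instance is constructed and consumed after this change (assumes the lm_eval package from this branch is importable; the doc contents below are invented for illustration):

from lm_eval.api.instance import Instance
from lm_eval.api.types import LoglikelihoodInput

# arguments is now a typed dataclass instead of a bare tuple
req: Instance[LoglikelihoodInput] = Instance(
    request_type="loglikelihood",
    doc={"question": "2 + 2 =", "answer": "4"},
    arguments=LoglikelihoodInput(context="2 + 2 =", continuation=" 4"),
    idx=0,
)

# .args returns the dataclass itself, so fields are read by name rather than by position
print(req.args.context, req.args.continuation)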
@@ -8,6 +8,11 @@ from typing import TYPE_CHECKING, Any, Iterable, Optional, Type, TypeVar, Union

from tqdm import tqdm

from lm_eval import utils
+from lm_eval.api.instance import Instance
+from lm_eval.api.types import (
+    LoglikelihoodInput,
+    LoglikelihoodOutput,
+)

if TYPE_CHECKING:

@@ -34,7 +39,7 @@ class LM(abc.ABC):
        self.cache_hook: "CacheHook" = CacheHook(None)

    @abc.abstractmethod
-    def loglikelihood(self, requests) -> list[tuple[float, bool]]:
+    def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]:
        """Compute log-likelihood of generating a continuation from a context.

        Downstream tasks should attempt to use loglikelihood instead of other
        LM calls whenever possible.

@@ -59,7 +64,7 @@ class LM(abc.ABC):
        pass

    @abc.abstractmethod
-    def loglikelihood_rolling(self, requests) -> list[float]:
+    def loglikelihood_rolling(self, requests: list[Instance]) -> list[float]:
        """Compute full log-likelihood of a string, with no truncation, for perplexity computation
        - We will use the full max context length of the model.
        - For inputs that exceed the max context length, we divide the tokenized string into chunks of up to

@@ -101,7 +106,7 @@ class LM(abc.ABC):
    # TODO: Add an optional max length
    @abc.abstractmethod
-    def generate_until(self, requests) -> list[str]:
+    def generate_until(self, requests: list[Instance]) -> list[str]:
        """Generate greedily until a stopping sequence

        :param requests: list[Instance]

@@ -376,7 +381,9 @@ class TemplateLM(LM):
        self, requests: list["Instance"], disable_tqdm: bool = False
    ) -> list[tuple[float, bool]]:
        new_reqs = []
-        for context, continuation in [req.args for req in requests]:
+        for context, continuation in (
+            (req.args.context, req.args.continuation) for req in requests
+        ):
            if context == "":
                # BOS or EOS as context
                context_enc, continuation_enc = (

@@ -392,12 +399,14 @@ class TemplateLM(LM):
    @abc.abstractmethod
    def loglikelihood_rolling(
-        self, requests, disable_tqdm: bool = False
+        self, requests: list[Instance], disable_tqdm: bool = False
    ) -> list[float]:
        pass

    @abc.abstractmethod
-    def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]:
+    def generate_until(
+        self, requests: list[Instance], disable_tqdm: bool = False
+    ) -> list[str]:
        pass

    def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:

...
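As a rough illustration of the retyped abstract interface, a toy subclass might look like the sketch below (hypothetical class name; assumes loglikelihood, loglikelihood_rolling, and generate_until are the abstract methods that need overriding):

from lm_eval.api.instance import Instance
from lm_eval.api.model import LM


class ConstantLM(LM):
    """Hypothetical stub model that satisfies the abstract interface."""

    def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]:
        # each request carries a LoglikelihoodInput in req.args
        return [(0.0, True) for _ in requests]

    def loglikelihood_rolling(self, requests: list[Instance]) -> list[float]:
        return [0.0 for _ in requests]

    def generate_until(self, requests: list[Instance]) -> list[str]:
        # each request carries a GenerateInput in req.args
        return [req.args.prompt for req in requests]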
@@ -36,6 +36,7 @@ from lm_eval.api.registry import (
    get_metric_aggregation,
    is_higher_better,
)
+from lm_eval.api.types import GenerateInput, LoglikelihoodInput
from lm_eval.caching.cache import load_from_cache, save_to_cache
from lm_eval.filters import build_filter_ensemble
from lm_eval.prompts import get_prompt

@@ -1493,6 +1494,13 @@ class ConfigurableTask(Task):
        elif self.OUTPUT_TYPE == "generate_until":
            arguments = (ctx, deepcopy(self.config.generation_kwargs))
+        else:
+            raise ValueError(
+                f"Unsupported OUTPUT_TYPE: '{self.OUTPUT_TYPE}'. "
+                f"Expected one of: 'loglikelihood', 'loglikelihood_rolling', "
+                f"'multiple_choice', 'generate_until'"
+            )

        multimodal_arg = {}
        if (
            self.config.doc_to_image

@@ -1521,7 +1529,7 @@ class ConfigurableTask(Task):
                Instance(
                    request_type="loglikelihood",
                    doc=doc,
-                    arguments=arg,
+                    arguments=LoglikelihoodInput(context=arg[0], continuation=arg[1]),
                    idx=i,
                    **kwargs,
                )

@@ -1533,7 +1541,9 @@ class ConfigurableTask(Task):
        return Instance(
            request_type=self.OUTPUT_TYPE,
            doc=doc,
-            arguments=arguments,
+            arguments=LoglikelihoodInput(*arguments)
+            if self.OUTPUT_TYPE in ["loglikelihood", "loglikelihood_rolling"]
+            else GenerateInput(*arguments),
            idx=0,
            **kwargs,
        )

@@ -1846,7 +1856,7 @@ class MultipleChoiceTask(Task):
class PerplexityTask(Task):
-    OUTPUT_TYPE = "loglikelihood_rolling"
+    OUTPUT_TYPE: OutputType = "loglikelihood_rolling"

    def has_training_docs(self) -> bool:
        return False

...
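A small sketch of the wrapping logic added above: the per-OUTPUT_TYPE argument tuples are passed into the matching input dataclass (example values invented for illustration; assumes lm_eval from this branch is importable):

from lm_eval.api.types import GenerateInput, LoglikelihoodInput

# loglikelihood / loglikelihood_rolling style arguments: (context, continuation)
arguments = ("The capital of France is", " Paris")
ll_args = LoglikelihoodInput(*arguments)

# generate_until style arguments: (prompt, generation_kwargs)
arguments = ("Q: 2 + 2 =\nA:", {"until": ["\n"], "max_gen_toks": 32})
gen_args = GenerateInput(*arguments)

print(ll_args.continuation, gen_args.gen_kwargs)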
+from dataclasses import dataclass
+from typing import Optional
+
+
+@dataclass
+class GenerateInput:
+    """
+    Inputs for the generate function.
+    """
+
+    prompt: str
+    gen_kwargs: dict
+    multimodal_arg: Optional[dict] = None
+
+
+@dataclass
+class GenerateOutput:
+    """
+    Outputs for the generate function.
+    """
+
+    text: str
+    metadata: dict = None
+
+
+@dataclass
+class LoglikelihoodInput:
+    """
+    Inputs for the loglikelihood function.
+    """
+
+    context: str
+    continuation: Optional[str] = None
+
+
+@dataclass
+class LoglikelihoodOutput:
+    """
+    Outputs for the loglikelihood function.
+    """
+
+    loglikelihood: float
+    is_greedy: Optional[bool] = None
+    ctx_tokens: Optional[list[int]] = None
+    cont_tokens: Optional[list[int]] = None
+    metadata: Optional[dict] = None
+
+    def __iter__(self):
+        return iter((self.loglikelihood, self.is_greedy))
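Because LoglikelihoodOutput defines __iter__, call sites that still unpack a (loglikelihood, is_greedy) tuple keep working; a minimal sketch (values invented for illustration):

from lm_eval.api.types import LoglikelihoodOutput

out = LoglikelihoodOutput(loglikelihood=-1.23, is_greedy=True, cont_tokens=[42])

# tuple-style unpacking still works thanks to __iter__
ll, greedy = out
print(ll, greedy, out.cont_tokens)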
@@ -560,6 +560,8 @@ def evaluate(
        # create `K` copies of each request `req` based off `K = req.repeats`
        cloned_reqs = []
        for req in reqs:
+            # Note: [req] * req.repeats creates multiple references to the same request object,
+            # not separate copies. This means all repeated entries point to the same req.resps list
            cloned_reqs.extend([req] * req.repeats)

        if (lm.world_size > 1) and (padding_requests[reqtype] > 0):

@@ -567,6 +569,8 @@ def evaluate(
                cloned_reqs.extend([req] * req.repeats)

        # run requests through model
+        # Since cloned_reqs contains references to original objects, each response
+        # automatically gets appended to the correct req.resps list
        resps = getattr(lm, reqtype)(cloned_reqs)

        # put responses from model into a list of length K for each request.

...
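The new comments describe standard Python list-multiplication semantics; a self-contained illustration:

class Req:
    """Stand-in for a request object with a resps list (hypothetical name)."""

    def __init__(self):
        self.resps = []


req = Req()
cloned = [req] * 3            # three references to the same object, not three copies

cloned[0].resps.append("response")
print(req.resps)              # ['response'] -- the original object sees the append
print(cloned[1] is req)       # True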
@@ -27,6 +27,7 @@ from lm_eval import utils
from lm_eval.api.instance import Instance
from lm_eval.api.model import TemplateLM
from lm_eval.api.registry import register_model
+from lm_eval.api.types import GenerateInput, GenerateOutput, LoglikelihoodOutput
from lm_eval.models.utils import (
    Collator,
    clear_torch_cache,

@@ -965,7 +966,7 @@ class HFLM(TemplateLM):
    def loglikelihood_rolling(
        self, requests: List[Instance], disable_tqdm: bool = False
-    ) -> List[float]:
+    ) -> list[LoglikelihoodOutput]:
        adaptive_batch_size = None
        if self.batch_size == "auto":
            # using rolling window with maximum context

@@ -1025,7 +1026,7 @@ class HFLM(TemplateLM):
                    override_bs=len(batch_windows),
                )
                # Store results with their request indices
-                all_nlls.extend(zip(batch_indices, batch_nlls))
+                all_nlls.extend(zip(batch_indices, (x.loglikelihood for x in batch_nlls)))

        # Remove padding if necessary
        if (self.world_size > 1) and (pad_amnt > 0):

@@ -1038,8 +1039,8 @@ class HFLM(TemplateLM):
            # Get all nlls for this request
            request_nlls = all_nlls[current_idx : current_idx + window_count]
            # Sum up the nlls for this request (discarding is_greedy)
-            request_total = sum(nll[0] for _, nll in request_nlls)
-            loglikelihoods.append(request_total)
+            request_total = sum(nll for _, nll in request_nlls)
+            loglikelihoods.append(LoglikelihoodOutput(loglikelihood=request_total))
            current_idx += window_count

            string = requests[len(loglikelihoods) - 1].args[0]
@@ -1071,7 +1072,7 @@ class HFLM(TemplateLM):
        requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
        disable_tqdm: bool = False,
        override_bs: int = None,
-    ) -> List[Tuple[float, bool]]:
+    ) -> List[LoglikelihoodOutput]:
        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
        res = []

@@ -1286,7 +1287,13 @@ class HFLM(TemplateLM):
                # Answer: (log prob, is-exact-match)
                answer = (float(logits.sum()), bool(max_equal))

-                res.append(answer)
+                res.append(
+                    LoglikelihoodOutput(
+                        *answer,
+                        ctx_tokens=ctx_tokens,
+                        cont_tokens=cont_toks.tolist(),
+                    )
+                )

                if request_str is not None:
                    # special case: loglikelihood_rolling produces a number of loglikelihood requests

@@ -1302,8 +1309,8 @@ class HFLM(TemplateLM):
        return re_ord.get_original(res)

    def generate_until(
-        self, requests: List[Instance], disable_tqdm: bool = False
-    ) -> List[str]:
+        self, requests: List[Instance[GenerateInput]], disable_tqdm: bool = False
+    ) -> List[GenerateOutput]:
        res = []

        def _collate(req: Tuple[str, dict]):

@@ -1420,7 +1427,7 @@ class HFLM(TemplateLM):
                    # for seq2seq case where self.tok_decode(self.eot_token_id) = ''
                    s = s.split(term)[0]

-                res.append(s)
+                res.append(GenerateOutput(text=s))
                self.cache_hook.add_partial("generate_until", (context, gen_kwargs), s)
                pbar.update(1)

...
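With this change generate_until returns GenerateOutput objects rather than bare strings, so a consumer reads the generated text via the .text attribute; a minimal sketch (values invented for illustration; assumes lm_eval from this branch is importable):

from lm_eval.api.types import GenerateOutput, LoglikelihoodOutput

gen_resps = [GenerateOutput(text="Paris"), GenerateOutput(text="4")]
texts = [r.text for r in gen_resps]

ll_resps = [LoglikelihoodOutput(loglikelihood=-0.7, is_greedy=True)]
scores = [tuple(r) for r in ll_resps]   # (-0.7, True) via __iter__

print(texts, scores)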