Commit 5b8a7506 authored by Baber

remove other schemas. work on metrics

parent ba1d4483
......@@ -3,7 +3,6 @@ from dataclasses import dataclass
from typing import Callable, Iterable, List, Union
from lm_eval.api.instance import Instance
from lm_eval.api.schemas import GenerateOutput
class Filter(ABC):
......@@ -47,13 +46,13 @@ class FilterEnsemble:
resps, docs = zip(*((inst.resps, inst.doc) for inst in instances))
# TODO: add backward
# unwrap responses from GenerateOutput as the filters expect strings
resps = tuple(
[
item.text if isinstance(item, GenerateOutput) else str(item)
for item in sublist
]
for sublist in resps
)
# resps = tuple(
# [
# item.text if isinstance(item, GenerateOutput) else item
# for item in sublist
# ]
# for sublist in resps
# )
for f in self.filters:
# apply filters in sequence
......
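For context on the filter hunk above: `FilterEnsemble.apply` hands each filter, per document, a list of model responses, and each filter returns a list of the same shape. A minimal standalone sketch of the unwrapping idea (sample data and the `unwrap` helper are illustrative, not lm_eval code):

```python
# Each inner list holds the model responses produced for one document.
resps = (["Answer: 42", "Answer: 41"], ["Answer: 7"])

def unwrap(item):
    # responses are expected to be plain strings now that the schema
    # wrapper objects are gone; coerce defensively otherwise
    return item if isinstance(item, str) else str(item)

unwrapped = tuple([unwrap(item) for item in per_doc] for per_doc in resps)
print(unwrapped)  # (['Answer: 42', 'Answer: 41'], ['Answer: 7'])
```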
from dataclasses import dataclass, field
from typing import Generic, Literal, Optional, Tuple, TypeVar, Union
from typing import Literal, Optional, Tuple
from lm_eval.api.schemas import GenerateInput, LoglikelihoodInput
# from lm_eval.api.schemas import GenerateInput, LoglikelihoodInput
OutputType = Literal[
"loglikelihood", "loglikelihood_rolling", "generate_until", "multiple_choice"
]
T = TypeVar("T", LoglikelihoodInput, GenerateInput)
# T = TypeVar("T", LoglikelihoodInput, GenerateInput)
@dataclass
class Instance(Generic[T]):
class Instance:
request_type: OutputType
doc: dict
arguments: T
arguments: tuple
idx: int
metadata: Tuple[Optional[str], Optional[int], Optional[int]] = field(
default_factory=lambda: (None, None, None)
)
resps: list[Union[GenerateInput, LoglikelihoodInput]] = field(default_factory=list)
resps: list = field(default_factory=list)
filtered_resps: dict = field(default_factory=dict)
# initialized after init
......@@ -33,7 +34,7 @@ class Instance(Generic[T]):
self.task_name, self.doc_id, self.repeats = self.metadata
@property
def args(self) -> T:
def args(self):
"""
Returns (string,) where `string` is the string to calculate loglikelihood over
"""
......
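With `Generic[T]` dropped, `Instance.arguments` is a plain tuple whose shape depends on `request_type`: `(context, continuation)` for loglikelihood requests and `(prompt, gen_kwargs)` for `generate_until`. A hedged sketch of constructing one against this branch (field names come from the diff above; the document contents are made up):

```python
from lm_eval.api.instance import Instance

# a loglikelihood request: arguments is a (context, continuation) pair
inst = Instance(
    request_type="loglikelihood",
    doc={"question": "2 + 2 =", "answer": "4"},  # illustrative document
    arguments=("2 + 2 =", " 4"),
    idx=0,
)
context, continuation = inst.arguments  # plain tuple, no schema object
```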
......@@ -9,10 +9,12 @@ from tqdm import tqdm
from lm_eval import utils
from lm_eval.api.instance import Instance
from lm_eval.api.schemas import (
LoglikelihoodInput,
LoglikelihoodOutput,
)
# from lm_eval.api.schemas import (
# LoglikelihoodInput,
# LoglikelihoodOutput,
# )
if TYPE_CHECKING:
......@@ -381,9 +383,7 @@ class TemplateLM(LM):
self, requests: list["Instance"], disable_tqdm: bool = False
) -> list[tuple[float, bool]]:
new_reqs = []
for context, continuation in (
(req.args.context, req.args.continuation) for req in requests
):
for context, continuation in [req.args for req in requests]:
if context == "":
# BOS or EOS as context
context_enc, continuation_enc = (
......
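The rewritten loop assumes every request's `args` is a two-element `(context, continuation)` tuple, so it can be unpacked directly. A standalone illustration of the pattern, using a toy whitespace tokenizer in place of the real one:

```python
def toy_encode(text):
    # stand-in for the model tokenizer: whitespace split
    return text.split()

requests = [("The capital of France is", " Paris"), ("", " Hello")]

for context, continuation in requests:
    if context == "":
        # empty context: the real code falls back to BOS/EOS as the context
        context_enc, continuation_enc = [], toy_encode(continuation)
    else:
        context_enc, continuation_enc = toy_encode(context), toy_encode(continuation)
```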
......@@ -2,61 +2,60 @@ from dataclasses import dataclass
from typing import Optional
@dataclass
class GenerateInput:
"""
Inputs for the generate function.
"""
prompt: str
gen_kwargs: dict
multimodal_arg: Optional[dict] = None
def __iter__(self):
return (
iter((self.prompt, self.gen_kwargs))
if not self.multimodal_arg
else iter((self.prompt, self.gen_kwargs, self.multimodal_arg))
)
def __getitem__(self, item: int):
return [self.prompt, self.gen_kwargs][item]
@dataclass
class GenerateOutput:
"""
Outputs for the generate function.
"""
text: str
metadata: dict = None
@dataclass
class LoglikelihoodInput:
"""
Inputs for the loglikelihood function.
"""
context: str
continuation: Optional[str] = None
@dataclass
class LoglikelihoodOutput:
"""
Outputs for the loglikelihood function.
"""
loglikelihood: float
is_greedy: Optional[bool] = None
ctx_tokens: Optional[list[int]] = None
cont_tokens: Optional[list[int]] = None
metadata: Optional[dict] = None
def __iter__(self):
return iter((self.loglikelihood, self.is_greedy))
# @dataclass
# class GenerateInput:
# """
# Inputs for the generate function.
# """
#
# prompt: str
# gen_kwargs: dict
# multimodal_arg: Optional[dict] = None
#
# def __iter__(self):
# return (
# iter((self.prompt, self.gen_kwargs))
# if not self.multimodal_arg
# else iter((self.prompt, self.gen_kwargs, self.multimodal_arg))
# )
#
# def __getitem__(self, item: int):
# return [self.prompt, self.gen_kwargs][item]
#
#
# @dataclass
# class GenerateOutput:
# """
# Outputs for the generate function.
# """
#
# text: str
# metadata: dict = None
#
#
# @dataclass
# class LoglikelihoodInput:
# """
# Inputs for the loglikelihood function.
# """
#
# context: str
# continuation: Optional[str] = None
#
#
# class LoglikelihoodOutput(NamedTuple):
# """
# Outputs for the loglikelihood function.
# """
#
# loglikelihood: float
# is_greedy: Optional[bool] = None
# ctx_tokens: Optional[list[int]] = None
# cont_tokens: Optional[list[int]] = None
# metadata: Optional[dict] = None
# def __iter__(self):
# return iter((self.loglikelihood, self.is_greedy))
@dataclass
......@@ -66,7 +65,7 @@ class MetricResult:
"""
doc_id: str | int | None
scores: list[dict[str, float]] | None
scores: list[dict[str, float]] | dict
filter_key: str = None
metric_name: str = None
metadata: Optional[dict] = None
......@@ -76,6 +75,8 @@ class MetricResult:
return iter([])
# Group values by metric key
if not isinstance(self.scores, list):
self.scores = [self.scores]
grouped = {}
for score_dict in self.scores:
for key, value in score_dict.items():
......@@ -99,4 +100,8 @@ class MetricResult:
def metric_keys(self) -> list[str]:
if self.scores is None:
return []
return list(self.scores[0].keys()) if self.scores else []
return (
list(self.scores[0].keys())
if isinstance(self.scores, list)
else list(self.scores.keys())
)
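`MetricResult.scores` can now be either a single dict (one instance per doc) or a list of dicts (one per response); iteration first normalizes the dict case into a one-element list and then groups values by metric key. A standalone sketch of that normalization, outside the class:

```python
scores = {"acc": 1.0, "f1": 0.5}      # single-instance form (a bare dict)
if not isinstance(scores, list):
    scores = [scores]                  # normalize to the list-of-dicts form

grouped = {}
for score_dict in scores:
    for key, value in score_dict.items():
        grouped.setdefault(key, []).append(value)

print(grouped)                 # {'acc': [1.0], 'f1': [0.5]}
print(list(scores[0].keys()))  # what metric_keys returns for the list form
```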
......@@ -37,7 +37,7 @@ from lm_eval.api.registry import (
get_metric_aggregation,
is_higher_better,
)
from lm_eval.api.schemas import GenerateInput, LoglikelihoodInput, MetricResult
from lm_eval.api.schemas import MetricResult
from lm_eval.caching.cache import load_from_cache, save_to_cache
from lm_eval.filters import build_filter_ensemble
from lm_eval.prompts import get_prompt
......@@ -1531,7 +1531,8 @@ class ConfigurableTask(Task):
Instance(
request_type="loglikelihood",
doc=doc,
arguments=LoglikelihoodInput(context=arg[0], continuation=arg[1]),
arguments=arg,
# arguments=LoglikelihoodInput(context=arg[0], continuation=arg[1]),
idx=i,
**kwargs,
)
......@@ -1543,9 +1544,9 @@ class ConfigurableTask(Task):
return Instance(
request_type=self.OUTPUT_TYPE,
doc=doc,
arguments=LoglikelihoodInput(*arguments)
if self.OUTPUT_TYPE in ["loglikelihood", "loglikelihood_rolling"]
else GenerateInput(*arguments),
arguments=arguments,
# if self.OUTPUT_TYPE in ["loglikelihood", "loglikelihood_rolling"]
# else GenerateInput(*arguments),
idx=0,
**kwargs,
)
......@@ -1819,15 +1820,21 @@ class ConfigurableTask(Task):
for doc_id, doc in doc_iterator:
# doc_id_true = indices[doc_id] if indices else doc_id
requests = instances_by_doc_id[doc_id]
metrics = [
self.process_results(doc, response)
for req in requests
for response in (
req.filtered_resps[filter_key]
if isinstance(req.filtered_resps[filter_key], list)
else [req.filtered_resps[filter_key]]
if len(requests) > 1:
# if one doc has multiple instances, compute the metric over all of them together
metrics = self.process_results(
doc, [req.filtered_resps[filter_key] for req in requests]
)
]
else:
metrics = [
self.process_results(doc, response)
for req in requests
for response in (
req.filtered_resps[filter_key]
if isinstance(req.filtered_resps[filter_key], list)
else [req.filtered_resps[filter_key]]
)
]
all_metrics[filter_key].append(
MetricResult(scores=metrics, doc_id=doc_id, filter_key=filter_key)
)
......
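When several instances share one document (e.g. one request per answer choice), their filtered responses are now passed to `process_results` as a single list so the metric is computed over the whole set; single-instance docs keep the per-response path. A standalone sketch of the branching, with a stand-in scorer in place of the task's real `process_results`:

```python
def score(doc, responses):
    # stand-in scorer: accuracy over the supplied responses
    responses = responses if isinstance(responses, list) else [responses]
    return {"acc": sum(r == doc["gold"] for r in responses) / len(responses)}

doc = {"gold": "B"}
filtered = ["B", "A", "B"]   # filtered responses from three instances of one doc

if len(filtered) > 1:
    metrics = score(doc, filtered)               # one dict for the whole doc
else:
    metrics = [score(doc, r) for r in filtered]  # one dict per response
```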
......@@ -647,9 +647,7 @@ def evaluate(
ensure_ascii=False,
)
),
"prompt_hash": hash_string(
requests[0].arguments.prompt
),
"prompt_hash": hash_string(requests[0].arguments[0]),
"target_hash": hash_string(str(target)),
}
example.update(
......
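With `arguments` stored as a plain tuple, the prompt text for the logged example is its first element rather than a `.prompt` attribute. A minimal sketch, using a local SHA-256 stand-in for the harness's `hash_string` helper:

```python
import hashlib

def hash_string(s: str) -> str:
    # local stand-in for the harness helper
    return hashlib.sha256(s.encode("utf-8")).hexdigest()

arguments = ("What is 2 + 2?", {"max_gen_toks": 16})  # (prompt, gen_kwargs) tuple
prompt_hash = hash_string(arguments[0])               # hash the prompt text only
```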
......@@ -20,7 +20,7 @@ class TakeFirstFilter(Filter):
"""
Assuming each entry of `resps` is a list of model responses, we discard all but the first response.
"""
return map(lambda r: r, resps)
return map(lambda r: r[0], resps)
@register_filter("take_first_k")
......
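The fix above makes `take_first` actually index into each per-document response list instead of returning it unchanged. A quick illustration with made-up responses:

```python
resps = [["first answer", "second answer"], ["only answer"]]
firsts = list(map(lambda r: r[0], resps))
print(firsts)  # ['first answer', 'only answer']
```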
......@@ -27,7 +27,6 @@ from lm_eval import utils
from lm_eval.api.instance import Instance
from lm_eval.api.model import TemplateLM
from lm_eval.api.registry import register_model
from lm_eval.api.schemas import GenerateInput, GenerateOutput, LoglikelihoodOutput
from lm_eval.models.utils import (
Collator,
clear_torch_cache,
......@@ -966,7 +965,7 @@ class HFLM(TemplateLM):
def loglikelihood_rolling(
self, requests: List[Instance], disable_tqdm: bool = False
) -> list[LoglikelihoodOutput]:
) -> List[float]:
adaptive_batch_size = None
if self.batch_size == "auto":
# using rolling window with maximum context
......@@ -1026,7 +1025,7 @@ class HFLM(TemplateLM):
override_bs=len(batch_windows),
)
# Store results with their request indices
all_nlls.extend(zip(batch_indices, (x.loglikelihood for x in batch_nlls)))
all_nlls.extend(zip(batch_indices, batch_nlls))
# Remove padding if necessary
if (self.world_size > 1) and (pad_amnt > 0):
......@@ -1039,8 +1038,8 @@ class HFLM(TemplateLM):
# Get all nlls for this request
request_nlls = all_nlls[current_idx : current_idx + window_count]
# Sum up the nlls for this request (discarding is_greedy)
request_total = sum(nll for nll in request_nlls)
loglikelihoods.append(LoglikelihoodOutput(loglikelihood=request_total))
request_total = sum(nll[0] for _, nll in request_nlls)
loglikelihoods.append(request_total)
current_idx += window_count
string = requests[len(loglikelihoods) - 1].args[0]
......@@ -1072,7 +1071,7 @@ class HFLM(TemplateLM):
requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
disable_tqdm: bool = False,
override_bs: int = None,
) -> List[LoglikelihoodOutput]:
) -> List[Tuple[float, bool]]:
# TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
res = []
......@@ -1287,13 +1286,7 @@ class HFLM(TemplateLM):
# Answer: (log prob, is-exact-match)
answer = (float(logits.sum()), bool(max_equal))
res.append(
LoglikelihoodOutput(
*answer,
ctx_tokens=ctx_tokens,
cont_tokens=cont_toks.tolist(),
)
)
res.append(answer)
if request_str is not None:
# special case: loglikelihood_rolling produces a number of loglikelihood requests
......@@ -1309,8 +1302,8 @@ class HFLM(TemplateLM):
return re_ord.get_original(res)
def generate_until(
self, requests: List[Instance[GenerateInput]], disable_tqdm: bool = False
) -> List[GenerateOutput]:
self, requests: List[Instance], disable_tqdm: bool = False
) -> List[str]:
res = []
def _collate(req: Tuple[str, dict]):
......@@ -1321,8 +1314,8 @@ class HFLM(TemplateLM):
# padded context length. this is useful to simplify the batching logic and more importantly to make
# automatic adaptive batches much much easier to implement
# - any OOMs will happen right away rather than near the end
toks = self.tok_encode(req.prompt)
return -len(toks), req.prompt
toks = self.tok_encode(req[0])
return -len(toks), req[0]
pbar = tqdm(
total=len(requests),
......@@ -1358,7 +1351,7 @@ class HFLM(TemplateLM):
[reg.args for reg in requests],
sort_fn=_collate,
group_by="gen_kwargs",
group_fn=lambda x: x.gen_kwargs,
group_fn=lambda x: x[1],
)
chunks = re_ords.get_batched(n=batch_size, batch_fn=batch_fn)
eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False)
......@@ -1427,7 +1420,7 @@ class HFLM(TemplateLM):
# for seq2seq case where self.tok_decode(self.eot_token_id) = ''
s = s.split(term)[0]
res.append(GenerateOutput(text=s))
res.append(s)
self.cache_hook.add_partial("generate_until", (context, gen_kwargs), s)
pbar.update(1)
......
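After this change each request's `args` is a `(prompt, gen_kwargs)` tuple, so `_collate` keys off `req[0]` and grouping uses `x[1]`. A standalone sketch of sorting and grouping on those tuple positions (the harness's `Collator` is not reproduced here):

```python
import json
from collections import defaultdict

# (prompt, gen_kwargs) tuples, mirroring what req.args now yields
requests = [
    ("Translate to French: hello", {"until": ["\n"], "max_gen_toks": 32}),
    ("Translate to French: goodbye", {"until": ["\n"], "max_gen_toks": 32}),
    ("Summarize the passage:", {"until": ["###"], "max_gen_toks": 128}),
]

# _collate analogue: longer prompts first (the real code sorts by token count)
ordered = sorted(requests, key=lambda req: -len(req[0]))

# group_by="gen_kwargs" analogue: requests sharing kwargs get batched together
groups = defaultdict(list)
for req in ordered:
    groups[json.dumps(req[1], sort_keys=True)].append(req)
```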