Commit 5b8a7506 authored by Baber

remove other schemas. work on metrics

parent ba1d4483
@@ -3,7 +3,6 @@ from dataclasses import dataclass
 from typing import Callable, Iterable, List, Union
 from lm_eval.api.instance import Instance
-from lm_eval.api.schemas import GenerateOutput
 class Filter(ABC):
@@ -47,13 +46,13 @@ class FilterEnsemble:
         resps, docs = zip(*((inst.resps, inst.doc) for inst in instances))
         # TODO: add backward
         # unwrap responses from GenerateOutput as the filters expect strings
-        resps = tuple(
-            [
-                item.text if isinstance(item, GenerateOutput) else str(item)
-                for item in sublist
-            ]
-            for sublist in resps
-        )
+        # resps = tuple(
+        #     [
+        #         item.text if isinstance(item, GenerateOutput) else item
+        #         for item in sublist
+        #     ]
+        #     for sublist in resps
+        # )
         for f in self.filters:
             # apply filters in sequence
...
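With the unwrap step commented out, `FilterEnsemble.apply` now hands each instance's raw `resps` list straight to the filters, which after this commit hold plain strings rather than `GenerateOutput` objects. A minimal sketch of a filter written against that shape (the `LowercaseFilter` name is hypothetical, not part of the commit):

```python
# Hypothetical filter illustrating the data shape FilterEnsemble now passes
# through: an iterable of per-instance lists of plain strings.
class LowercaseFilter:
    def apply(self, resps, docs):
        # each sublist holds one instance's string responses
        return [[r.lower() for r in sublist] for sublist in resps]


resps = (["Paris", "PARIS"], ["Berlin"])
print(LowercaseFilter().apply(resps, docs=[{}, {}]))
# -> [['paris', 'paris'], ['berlin']]
```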
 from dataclasses import dataclass, field
-from typing import Generic, Literal, Optional, Tuple, TypeVar, Union
-from lm_eval.api.schemas import GenerateInput, LoglikelihoodInput
+from typing import Literal, Optional, Tuple
+# from lm_eval.api.schemas import GenerateInput, LoglikelihoodInput
 OutputType = Literal[
     "loglikelihood", "loglikelihood_rolling", "generate_until", "multiple_choice"
 ]
-T = TypeVar("T", LoglikelihoodInput, GenerateInput)
+# T = TypeVar("T", LoglikelihoodInput, GenerateInput)
 @dataclass
-class Instance(Generic[T]):
+class Instance:
     request_type: OutputType
     doc: dict
-    arguments: T
+    arguments: tuple
     idx: int
     metadata: Tuple[Optional[str], Optional[int], Optional[int]] = field(
         default_factory=lambda: (None, None, None)
     )
-    resps: list[Union[GenerateInput, LoglikelihoodInput]] = field(default_factory=list)
+    resps: list = field(default_factory=list)
     filtered_resps: dict = field(default_factory=dict)
     # initialized after init
@@ -33,7 +34,7 @@ class Instance(Generic[T]):
         self.task_name, self.doc_id, self.repeats = self.metadata
     @property
-    def args(self) -> T:
+    def args(self):
         """
         Returns (string,) where `string` is the string to calculate loglikelihood over
         """
...
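With `Generic[T]` dropped, `arguments` goes back to being a positional tuple. A minimal sketch of constructing and reading an `Instance` under that assumption (field values are illustrative, and it assumes `args` still just returns the stored tuple):

```python
from lm_eval.api.instance import Instance

# arguments is a plain (context, continuation) tuple again, not a
# LoglikelihoodInput dataclass; the values here are illustrative.
inst = Instance(
    request_type="loglikelihood",
    doc={"question": "2 + 2 = ?"},
    arguments=("Q: 2 + 2 = ?\nA:", " 4"),
    idx=0,
    metadata=("demo_task", 0, 1),
)
context, continuation = inst.args  # tuple unpacking replaces attribute access
```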
@@ -9,10 +9,12 @@ from tqdm import tqdm
 from lm_eval import utils
 from lm_eval.api.instance import Instance
-from lm_eval.api.schemas import (
-    LoglikelihoodInput,
-    LoglikelihoodOutput,
-)
+# from lm_eval.api.schemas import (
+#     LoglikelihoodInput,
+#     LoglikelihoodOutput,
+# )
 if TYPE_CHECKING:
@@ -381,9 +383,7 @@ class TemplateLM(LM):
         self, requests: list["Instance"], disable_tqdm: bool = False
     ) -> list[tuple[float, bool]]:
         new_reqs = []
-        for context, continuation in (
-            (req.args.context, req.args.continuation) for req in requests
-        ):
+        for context, continuation in [req.args for req in requests]:
             if context == "":
                 # BOS or EOS as context
                 context_enc, continuation_enc = (
...
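Because `req.args` is now already a `(context, continuation)` tuple, the loop unpacks it directly instead of reading dataclass attributes. A toy stand-in showing the equivalent access pattern:

```python
# Toy stand-in for an Instance: args is already a (context, continuation) tuple.
class Req:
    def __init__(self, args):
        self.args = args


requests = [Req(("The capital of France is", " Paris"))]
# before: ((req.args.context, req.args.continuation) for req in requests)
for context, continuation in [req.args for req in requests]:
    print(repr(context), repr(continuation))
```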
@@ -2,61 +2,60 @@ from dataclasses import dataclass
 from typing import Optional
-@dataclass
-class GenerateInput:
-    """
-    Inputs for the generate function.
-    """
-
-    prompt: str
-    gen_kwargs: dict
-    multimodal_arg: Optional[dict] = None
-
-    def __iter__(self):
-        return (
-            iter((self.prompt, self.gen_kwargs))
-            if not self.multimodal_arg
-            else iter((self.prompt, self.gen_kwargs, self.multimodal_arg))
-        )
-
-    def __getitem__(self, item: int):
-        return [self.prompt, self.gen_kwargs][item]
-
-
-@dataclass
-class GenerateOutput:
-    """
-    Outputs for the generate function.
-    """
-
-    text: str
-    metadata: dict = None
-
-
-@dataclass
-class LoglikelihoodInput:
-    """
-    Inputs for the loglikelihood function.
-    """
-
-    context: str
-    continuation: Optional[str] = None
-
-
-@dataclass
-class LoglikelihoodOutput:
-    """
-    Outputs for the loglikelihood function.
-    """
-
-    loglikelihood: float
-    is_greedy: Optional[bool] = None
-    ctx_tokens: Optional[list[int]] = None
-    cont_tokens: Optional[list[int]] = None
-    metadata: Optional[dict] = None
-
-    def __iter__(self):
-        return iter((self.loglikelihood, self.is_greedy))
+# @dataclass
+# class GenerateInput:
+#     """
+#     Inputs for the generate function.
+#     """
+#
+#     prompt: str
+#     gen_kwargs: dict
+#     multimodal_arg: Optional[dict] = None
+#
+#     def __iter__(self):
+#         return (
+#             iter((self.prompt, self.gen_kwargs))
+#             if not self.multimodal_arg
+#             else iter((self.prompt, self.gen_kwargs, self.multimodal_arg))
+#         )
+#
+#     def __getitem__(self, item: int):
+#         return [self.prompt, self.gen_kwargs][item]
+#
+#
+# @dataclass
+# class GenerateOutput:
+#     """
+#     Outputs for the generate function.
+#     """
+#
+#     text: str
+#     metadata: dict = None
+#
+#
+# @dataclass
+# class LoglikelihoodInput:
+#     """
+#     Inputs for the loglikelihood function.
+#     """
+#
+#     context: str
+#     continuation: Optional[str] = None
+#
+#
+# class LoglikelihoodOutput(NamedTuple):
+#     """
+#     Outputs for the loglikelihood function.
+#     """
+#
+#     loglikelihood: float
+#     is_greedy: Optional[bool] = None
+#     ctx_tokens: Optional[list[int]] = None
+#     cont_tokens: Optional[list[int]] = None
+#     metadata: Optional[dict] = None
+
+#     def __iter__(self):
+#         return iter((self.loglikelihood, self.is_greedy))
 @dataclass
@@ -66,7 +65,7 @@ class MetricResult:
     """
     doc_id: str | int | None
-    scores: list[dict[str, float]] | None
+    scores: list[dict[str, float]] | dict
     filter_key: str = None
     metric_name: str = None
     metadata: Optional[dict] = None
@@ -76,6 +75,8 @@ class MetricResult:
             return iter([])
         # Group values by metric key
+        if not isinstance(self.scores, list):
+            self.scores = [self.scores]
         grouped = {}
         for score_dict in self.scores:
             for key, value in score_dict.items():
@@ -99,4 +100,8 @@ class MetricResult:
     def metric_keys(self) -> list[str]:
         if self.scores is None:
             return []
-        return list(self.scores[0].keys()) if self.scores else []
+        return (
+            list(self.scores[0].keys())
+            if isinstance(self.scores, list)
+            else list(self.scores.keys())
+        )
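With `scores` widened to also accept a bare dict, `__iter__` first normalizes it to a one-element list and `metric_keys` branches on the type. A sketch of both shapes, assuming `MetricResult` as defined in this hunk (values are illustrative):

```python
from lm_eval.api.schemas import MetricResult

# scores may now be a single dict (one jointly scored doc)...
single = MetricResult(doc_id=0, scores={"acc": 1.0, "f1": 0.5})
print(single.metric_keys())  # ['acc', 'f1']

# ...or the original list-of-dicts shape, one dict per response
multi = MetricResult(doc_id=1, scores=[{"acc": 1.0}, {"acc": 0.0}])
print(multi.metric_keys())  # ['acc']
```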
@@ -37,7 +37,7 @@ from lm_eval.api.registry import (
     get_metric_aggregation,
     is_higher_better,
 )
-from lm_eval.api.schemas import GenerateInput, LoglikelihoodInput, MetricResult
+from lm_eval.api.schemas import MetricResult
 from lm_eval.caching.cache import load_from_cache, save_to_cache
 from lm_eval.filters import build_filter_ensemble
 from lm_eval.prompts import get_prompt
@@ -1531,7 +1531,8 @@ class ConfigurableTask(Task):
                 Instance(
                     request_type="loglikelihood",
                     doc=doc,
-                    arguments=LoglikelihoodInput(context=arg[0], continuation=arg[1]),
+                    arguments=arg,
+                    # arguments=LoglikelihoodInput(context=arg[0], continuation=arg[1]),
                     idx=i,
                     **kwargs,
                 )
@@ -1543,9 +1544,9 @@ class ConfigurableTask(Task):
         return Instance(
             request_type=self.OUTPUT_TYPE,
             doc=doc,
-            arguments=LoglikelihoodInput(*arguments)
-            if self.OUTPUT_TYPE in ["loglikelihood", "loglikelihood_rolling"]
-            else GenerateInput(*arguments),
+            arguments=arguments,
+            # if self.OUTPUT_TYPE in ["loglikelihood", "loglikelihood_rolling"]
+            # else GenerateInput(*arguments),
             idx=0,
             **kwargs,
         )
@@ -1819,15 +1820,21 @@ class ConfigurableTask(Task):
         for doc_id, doc in doc_iterator:
             # doc_id_true = indices[doc_id] if indices else doc_id
             requests = instances_by_doc_id[doc_id]
-            metrics = [
-                self.process_results(doc, response)
-                for req in requests
-                for response in (
-                    req.filtered_resps[filter_key]
-                    if isinstance(req.filtered_resps[filter_key], list)
-                    else [req.filtered_resps[filter_key]]
-                )
-            ]
+            if len(requests) > 1:
+                # if one doc has multiple instances then calculate metric together
+                metrics = self.process_results(
+                    doc, [req.filtered_resps[filter_key] for req in requests]
+                )
+            else:
+                metrics = [
+                    self.process_results(doc, response)
+                    for req in requests
+                    for response in (
+                        req.filtered_resps[filter_key]
+                        if isinstance(req.filtered_resps[filter_key], list)
+                        else [req.filtered_resps[filter_key]]
+                    )
+                ]
             all_metrics[filter_key].append(
                 MetricResult(scores=metrics, doc_id=doc_id, filter_key=filter_key)
             )
...
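The new branch scores all of a doc's instances in one `process_results` call when there are several, producing a single score dict instead of a list; the dict handling added to `MetricResult` above is what absorbs that shape. A self-contained sketch of the two paths (the toy `process_results` is illustrative):

```python
def process_results(doc, responses):
    # toy metric: count how many responses were scored together
    return {"n_responses": len(responses)}


filter_key = "none"
requests = [
    {"filtered_resps": {filter_key: "yes"}},
    {"filtered_resps": {filter_key: "no"}},
]

if len(requests) > 1:
    # one doc, several instances: score them jointly -> a single dict
    metrics = process_results({}, [r["filtered_resps"][filter_key] for r in requests])
else:
    # single instance: score per response -> a list of dicts
    metrics = [process_results({}, r["filtered_resps"][filter_key]) for r in requests]

print(metrics)  # {'n_responses': 2}
```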
@@ -647,9 +647,7 @@ def evaluate(
                         ensure_ascii=False,
                     )
                 ),
-                "prompt_hash": hash_string(
-                    requests[0].arguments.prompt
-                ),
+                "prompt_hash": hash_string(requests[0].arguments[0]),
                 "target_hash": hash_string(str(target)),
             }
             example.update(
...
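Since `arguments` is a tuple again, the prompt is its first element and can be hashed directly. A sketch assuming the usual sha256-based `hash_string` helper from `lm_eval.utils`:

```python
import hashlib


def hash_string(s: str) -> str:
    # assumed to mirror lm_eval.utils.hash_string: sha256 hex digest of the text
    return hashlib.sha256(s.encode("utf-8")).hexdigest()


arguments = ("What is 2 + 2?", {"max_gen_toks": 16})  # (prompt, gen_kwargs)
print(hash_string(arguments[0]))  # stable hash of the prompt text only
```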
@@ -20,7 +20,7 @@ class TakeFirstFilter(Filter):
         """
         Assuming each entry of `resps` is a list of model responses, we discard all but the first response.
         """
-        return map(lambda r: r, resps)
+        return map(lambda r: r[0], resps)
 @register_filter("take_first_k")
...
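The fix makes `take_first` actually keep only the first response per instance; the identity `lambda r: r` returned each list unchanged. Behavior before and after:

```python
resps = [["first", "second"], ["only"]]

# old (bug): the identity lambda returned each response list untouched
assert list(map(lambda r: r, resps)) == [["first", "second"], ["only"]]

# new: keep just the first response from every instance's list
assert list(map(lambda r: r[0], resps)) == ["first", "only"]
```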
@@ -27,7 +27,6 @@ from lm_eval import utils
 from lm_eval.api.instance import Instance
 from lm_eval.api.model import TemplateLM
 from lm_eval.api.registry import register_model
-from lm_eval.api.schemas import GenerateInput, GenerateOutput, LoglikelihoodOutput
 from lm_eval.models.utils import (
     Collator,
     clear_torch_cache,
@@ -966,7 +965,7 @@ class HFLM(TemplateLM):
     def loglikelihood_rolling(
         self, requests: List[Instance], disable_tqdm: bool = False
-    ) -> list[LoglikelihoodOutput]:
+    ) -> List[float]:
         adaptive_batch_size = None
         if self.batch_size == "auto":
             # using rolling window with maximum context
@@ -1026,7 +1025,7 @@ class HFLM(TemplateLM):
                 override_bs=len(batch_windows),
             )
             # Store results with their request indices
-            all_nlls.extend(zip(batch_indices, (x.loglikelihood for x in batch_nlls)))
+            all_nlls.extend(zip(batch_indices, batch_nlls))
         # Remove padding if necessary
         if (self.world_size > 1) and (pad_amnt > 0):
@@ -1039,8 +1038,8 @@ class HFLM(TemplateLM):
             # Get all nlls for this request
             request_nlls = all_nlls[current_idx : current_idx + window_count]
             # Sum up the nlls for this request (discarding is_greedy)
-            request_total = sum(nll for nll in request_nlls)
-            loglikelihoods.append(LoglikelihoodOutput(loglikelihood=request_total))
+            request_total = sum(nll[0] for _, nll in request_nlls)
+            loglikelihoods.append(request_total)
             current_idx += window_count
             string = requests[len(loglikelihoods) - 1].args[0]
@@ -1072,7 +1071,7 @@ class HFLM(TemplateLM):
         requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
         disable_tqdm: bool = False,
         override_bs: int = None,
-    ) -> List[LoglikelihoodOutput]:
+    ) -> List[Tuple[float, bool]]:
         # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
         res = []
@@ -1287,13 +1286,7 @@ class HFLM(TemplateLM):
                 # Answer: (log prob, is-exact-match)
                 answer = (float(logits.sum()), bool(max_equal))
-                res.append(
-                    LoglikelihoodOutput(
-                        *answer,
-                        ctx_tokens=ctx_tokens,
-                        cont_tokens=cont_toks.tolist(),
-                    )
-                )
+                res.append(answer)
                 if request_str is not None:
                     # special case: loglikelihood_rolling produces a number of loglikelihood requests
@@ -1309,8 +1302,8 @@ class HFLM(TemplateLM):
         return re_ord.get_original(res)
     def generate_until(
-        self, requests: List[Instance[GenerateInput]], disable_tqdm: bool = False
-    ) -> List[GenerateOutput]:
+        self, requests: List[Instance], disable_tqdm: bool = False
+    ) -> List[str]:
         res = []
         def _collate(req: Tuple[str, dict]):
@@ -1321,8 +1314,8 @@ class HFLM(TemplateLM):
            # padded context length. this is useful to simplify the batching logic and more importantly to make
            # automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end
-            toks = self.tok_encode(req.prompt)
-            return -len(toks), req.prompt
+            toks = self.tok_encode(req[0])
+            return -len(toks), req[0]
         pbar = tqdm(
             total=len(requests),
@@ -1358,7 +1351,7 @@ class HFLM(TemplateLM):
             [reg.args for reg in requests],
             sort_fn=_collate,
             group_by="gen_kwargs",
-            group_fn=lambda x: x.gen_kwargs,
+            group_fn=lambda x: x[1],
         )
         chunks = re_ords.get_batched(n=batch_size, batch_fn=batch_fn)
         eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False)
@@ -1427,7 +1420,7 @@ class HFLM(TemplateLM):
                     # for seq2seq case where self.tok_decode(self.eot_token_id) = ''
                     s = s.split(term)[0]
-                res.append(GenerateOutput(text=s))
+                res.append(s)
                 self.cache_hook.add_partial("generate_until", (context, gen_kwargs), s)
                 pbar.update(1)
...
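With the schema wrappers gone, the `HFLM` methods return plain values again, matching the signatures above. A sketch of the shapes a caller now consumes (values are illustrative):

```python
from typing import List, Tuple

rolling: List[float] = [-12.7]                     # loglikelihood_rolling
scored: List[Tuple[float, bool]] = [(-3.2, True)]  # _loglikelihood_tokens
generations: List[str] = ["Paris"]                 # generate_until

# tuple unpacking replaces the old .loglikelihood / .is_greedy attributes
nll, is_greedy = scored[0]
print(nll, is_greedy, generations[0], sum(rolling))
```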