Commit 173b2bc3 authored by Baber's avatar Baber
Browse files

Merge branch 'main' into humaneval

# Conflicts:
#	lm_eval/api/task.py
parents 74344829 bb098f13
...@@ -3,7 +3,7 @@ import hashlib ...@@ -3,7 +3,7 @@ import hashlib
import json import json
import logging import logging
import os import os
from typing import Dict, List, Optional, Tuple, Type, TypeVar from typing import Dict, List, Optional, Tuple, Type, TypeVar, Union
import transformers import transformers
from sqlitedict import SqliteDict from sqlitedict import SqliteDict
...@@ -55,7 +55,7 @@ class LM(abc.ABC): ...@@ -55,7 +55,7 @@ class LM(abc.ABC):
pass pass
@abc.abstractmethod @abc.abstractmethod
def loglikelihood_rolling(self, requests) -> List[Tuple[float]]: def loglikelihood_rolling(self, requests) -> List[float]:
"""Compute full log-likelihood of a string, with no truncation, for perplexity computation """Compute full log-likelihood of a string, with no truncation, for perplexity computation
- We will use the full max context length of the model. - We will use the full max context length of the model.
- For inputs that exceed the max context length, we divide the tokenized string into chunks of up to - For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
...@@ -101,14 +101,13 @@ class LM(abc.ABC): ...@@ -101,14 +101,13 @@ class LM(abc.ABC):
"""Generate greedily until a stopping sequence """Generate greedily until a stopping sequence
:param requests: list[Instance] :param requests: list[Instance]
A list of Instance objects with property `args` which returns a tuple (context, until). A list of Instance objects with property `args` which returns a tuple (context, gen_kwargs).
context: str context: str
Context string Context string
until: [str] gen_kwargs: dict
The string sequences to generate until. These string sequences A dictionary of keyword arguments to pass to the generation function e.g. top_k, until, etc.
may each span across multiple tokens, or may be part of one token.
:return: list[str] :return: list[str]
A list of strings continuation A list of model generated continuations.
continuation: str continuation: str
The generated continuation. The generated continuation.
""" """
...@@ -193,15 +192,13 @@ class LM(abc.ABC): ...@@ -193,15 +192,13 @@ class LM(abc.ABC):
"To use this model with chat templates, please implement the 'tokenizer_name' property." "To use this model with chat templates, please implement the 'tokenizer_name' property."
) )
@property def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
def chat_template(self) -> str: """Returns the chat template structure for user/assistant messages if a template is provided.
"""Must be defined for LM subclasses that implement Chat Templating. This method is intended to be overridden in a subclass to define a specific chat template format.
Should return the structure of the chat template applied to user/assistant messages. For models that do not support chat templates, this method returns None by default.
This is used only to save in the experiment results for reproducibility.
""" """
raise NotImplementedError(
"To use this model with chat templates, please implement the 'chat_template' property." return ""
)
def set_cache_hook(self, cache_hook) -> None: def set_cache_hook(self, cache_hook) -> None:
self.cache_hook = cache_hook self.cache_hook = cache_hook
...@@ -246,9 +243,10 @@ class CachingLM: ...@@ -246,9 +243,10 @@ class CachingLM:
# add hook to lm # add hook to lm
lm.set_cache_hook(self.get_cache_hook()) lm.set_cache_hook(self.get_cache_hook())
def __getattr__(self, attr): def __getattr__(self, attr: str):
lm_attr = getattr(self.lm, attr) lm_attr = getattr(self.lm, attr)
if not callable(lm_attr): if attr not in ["loglikelihood", "loglikelihood_rolling", "generate_until"]:
eval_logger.debug(f"Passing through attribute '{attr}' to underlying LM")
return lm_attr return lm_attr
def fn(requests): def fn(requests):
...@@ -283,8 +281,11 @@ class CachingLM: ...@@ -283,8 +281,11 @@ class CachingLM:
eval_logger.info( eval_logger.info(
f"Cached requests: {len(requests) - len(remaining_reqs)}, Requests remaining: {len(remaining_reqs)}" f"Cached requests: {len(requests) - len(remaining_reqs)}, Requests remaining: {len(remaining_reqs)}"
) )
if remaining_reqs:
# actually run the LM on the requests that do not have cached results # actually run the LM on the requests that do not have cached results
rem_res = getattr(self.lm, attr)(remaining_reqs) rem_res = getattr(self.lm, attr)(remaining_reqs)
else:
rem_res = []
# stick the new ones back into the list and also cache any of the new ones # stick the new ones back into the list and also cache any of the new ones
resptr = 0 resptr = 0
...@@ -313,6 +314,8 @@ class TemplateLM(LM): ...@@ -313,6 +314,8 @@ class TemplateLM(LM):
and boilerplate often included in other LM subclasses. and boilerplate often included in other LM subclasses.
""" """
tokenizer = None
@property @property
@abc.abstractmethod @abc.abstractmethod
def eot_token_id(self): def eot_token_id(self):
...@@ -324,14 +327,19 @@ class TemplateLM(LM): ...@@ -324,14 +327,19 @@ class TemplateLM(LM):
return self.eot_token_id return self.eot_token_id
@abc.abstractmethod @abc.abstractmethod
def tok_encode(self, string: str, **kwargs): def tok_encode(self, string: str, **kwargs) -> List[int]:
"""
Tokenize a string using the model's tokenizer and return a list of token IDs.
"""
pass pass
@abc.abstractmethod @abc.abstractmethod
def _loglikelihood_tokens(self, requests, **kwargs): def _loglikelihood_tokens(self, requests, **kwargs) -> List[Tuple[float, bool]]:
pass pass
def _encode_pair(self, context, continuation): def _encode_pair(
self, context: str, continuation: str
) -> Tuple[List[int], List[int]]:
n_spaces = len(context) - len(context.rstrip()) n_spaces = len(context) - len(context.rstrip())
if n_spaces > 0: if n_spaces > 0:
continuation = context[-n_spaces:] + continuation continuation = context[-n_spaces:] + continuation
...@@ -372,9 +380,110 @@ class TemplateLM(LM): ...@@ -372,9 +380,110 @@ class TemplateLM(LM):
@abc.abstractmethod @abc.abstractmethod
def loglikelihood_rolling( def loglikelihood_rolling(
self, requests, disable_tqdm: bool = False self, requests, disable_tqdm: bool = False
) -> List[Tuple[float, bool]]: ) -> List[float]:
pass pass
@abc.abstractmethod @abc.abstractmethod
def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]: def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]:
pass pass
def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
"""
Set and get the appropriate chat template for the model.
This method sets the tokenizer's chat_template and returns the template string for reproducibility.
The template selection logic is adapted from the Transformers library's `apply_chat_template`
method in the Tokenizer class. The original implementation can be found at:
https://github.com/huggingface/transformers/blob/fc35907f95459d7a6c5281dfadd680b6f7b620e3/src/transformers/tokenization_utils_base.py#L1687
This method ensures that the right template is chosen based on the following:
0. If the model has no 'tokenizer' attribute: assumes that there is only a single possible chat template, handled on the model provider side internally. Returns the empty string.
1. If the model's tokenizer has multiple templates:
a. Use the specified template if it exists in the dictionary.
b. Use the default template from the list if no specific template is provided.
c. Raise an error if no default template exists and no specific template is provided.
2. If the model's tokenizer has a single template or no template:
a. Use the tokenizer's chat template if available.
b. Fall back to the default chat template if no tokenizer chat template exists.
Args:
chat_template (Union[bool, str]): Specifies the chat template to use.
- If False or None, no template is applied.
- If True, the default or only available template is used.
- If a string, the template with the matching name is used.
Returns:
Optional[str]: The selected chat template, or None if no template is applied.
"""
if self.tokenizer is None:
return ""
if chat_template is False or chat_template is None:
eval_logger.warning(
"model.chat_template was called with the chat_template set to False or None. "
"Therefore no chat template will be applied. Make sure this is an intended behavior."
)
return None
# Convert boolean chat_template to None to ensure compatibility with the adapted logic
if isinstance(chat_template, bool):
chat_template = None
using_default_template = False
# First, handle the cases when the model has a dict of multiple templates
try:
template = (
self.tokenizer.chat_template or self.tokenizer.default_chat_template
)
except AttributeError:
return None
if isinstance(template, dict):
using_default_dict = self.tokenizer.chat_template is None
if chat_template is not None:
if chat_template in template:
selected_template = template[chat_template]
if using_default_dict:
using_default_template = True
else:
raise ValueError(
f"The specified chat template '{chat_template}' is not available. "
f"Available template names are {sorted(template.keys())}."
)
else:
# If user didn't pass a chat template, use the default template from the dict
if "default" in template:
selected_template = template["default"]
using_default_template = True
else:
raise ValueError(
"This model has multiple chat templates with no default specified! Please either pass a chat "
"template or the name of the template you wish to use to the `chat_template` argument. Available "
f"template names are {sorted(template.keys())}."
)
# Cases when the model has a single template or no template
else:
# priority: `chat_template` argument > `tokenizer.chat_template` > `tokenizer.default_chat_template
if isinstance(chat_template, str):
eval_logger.warning(
"Chat template name provided, but the tokenizer's chat template is not a dictionary. "
"Using the tokenizer's chat template or the default template instead."
)
if self.tokenizer.chat_template is not None:
selected_template = self.tokenizer.chat_template
else:
selected_template = self.tokenizer.default_chat_template
using_default_template = True
if using_default_template:
eval_logger.warning(
"No chat template is set for this tokenizer, falling back to a default class-level template. This is "
"very error-prone, because models are often trained with templates different from the class default! "
"Default chat templates are a legacy feature and will be removed in Transformers v4.43, at which "
"point any code depending on them will stop working. We recommend setting a valid chat template before "
"then to ensure that this model continues working without issues."
)
return selected_template
import logging import logging
from typing import Callable, Dict from typing import Callable, Dict, Union
import evaluate as hf_evaluate import evaluate as hf_evaluate
...@@ -185,8 +185,12 @@ def register_filter(name): ...@@ -185,8 +185,12 @@ def register_filter(name):
return decorate return decorate
def get_filter(filter_name: str) -> type: def get_filter(filter_name: Union[str, Callable]) -> Callable:
try: try:
return FILTER_REGISTRY[filter_name] return FILTER_REGISTRY[filter_name]
except KeyError: except KeyError as e:
if callable(filter_name):
return filter_name
else:
eval_logger.warning(f"filter `{filter_name}` is not registered!") eval_logger.warning(f"filter `{filter_name}` is not registered!")
raise e
from functools import partial
import datasets import datasets
...@@ -15,8 +17,37 @@ class ContextSampler: ...@@ -15,8 +17,37 @@ class ContextSampler:
self.target_delimiter = self.config.target_delimiter self.target_delimiter = self.config.target_delimiter
self.fewshot_delimiter = self.config.fewshot_delimiter self.fewshot_delimiter = self.config.fewshot_delimiter
if (
self.config.fewshot_config is not None
and self.config.fewshot_config.get("doc_to_text", None) is not None
):
self.doc_to_text = partial(
self.task.doc_to_text,
doc_to_text=self.config.fewshot_config.get("doc_to_text", None),
)
else:
self.doc_to_text = self.task.doc_to_text self.doc_to_text = self.task.doc_to_text
if (
self.config.fewshot_config is not None
and self.config.fewshot_config.get("doc_to_target", None) is not None
):
self.doc_to_target = partial(
self.task.doc_to_target,
doc_to_target=self.config.fewshot_config.get("doc_to_target", None),
)
else:
self.doc_to_target = self.task.doc_to_target self.doc_to_target = self.task.doc_to_target
if (
self.config.fewshot_config is not None
and self.config.fewshot_config.get("doc_to_choice", None) is not None
):
self.doc_to_choice = partial(
self.task.doc_to_choice,
doc_to_choice=self.config.fewshot_config.get("doc_to_choice", None),
)
else:
self.doc_to_choice = self.task.doc_to_choice self.doc_to_choice = self.task.doc_to_choice
self.docs = docs # HF dataset split, provided by task._fewshot_docs() self.docs = docs # HF dataset split, provided by task._fewshot_docs()
...@@ -51,6 +82,8 @@ class ContextSampler: ...@@ -51,6 +82,8 @@ class ContextSampler:
if self.config.doc_to_choice is None or isinstance(doc_content, str) if self.config.doc_to_choice is None or isinstance(doc_content, str)
else self.doc_to_choice(doc)[doc_content] else self.doc_to_choice(doc)[doc_content]
) )
if doc_target != "":
labeled_examples += self.target_delimiter labeled_examples += self.target_delimiter
labeled_examples += ( labeled_examples += (
str(doc_target[0]) str(doc_target[0])
......
...@@ -56,8 +56,7 @@ class TaskConfig(dict): ...@@ -56,8 +56,7 @@ class TaskConfig(dict):
# task naming/registry # task naming/registry
task: Optional[str] = None task: Optional[str] = None
task_alias: Optional[str] = None task_alias: Optional[str] = None
group: Optional[Union[str, list]] = None tag: Optional[Union[str, list]] = None
group_alias: Optional[Union[str, list]] = None
# HF dataset options. # HF dataset options.
# which dataset to use, # which dataset to use,
# and what splits for what purpose # and what splits for what purpose
...@@ -68,13 +67,14 @@ class TaskConfig(dict): ...@@ -68,13 +67,14 @@ class TaskConfig(dict):
validation_split: Optional[str] = None validation_split: Optional[str] = None
test_split: Optional[str] = None test_split: Optional[str] = None
fewshot_split: Optional[str] = ( fewshot_split: Optional[str] = (
None # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaling (?) None # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaluating (?)
) )
# formatting / prompting options. # formatting / prompting options.
# see docs/advanced_task_guide.md for more info # see docs/advanced_task_guide.md for more info
process_docs: Optional[Callable] = None process_docs: Optional[Callable] = None
doc_to_text: Optional[Union[Callable, str]] = None doc_to_text: Optional[Union[Callable, str]] = None
doc_to_target: Optional[Union[Callable, str]] = None doc_to_target: Optional[Union[Callable, str]] = None
doc_to_image: Union[Callable, str] = None
doc_to_choice: Optional[Union[Callable, str, dict, list]] = None doc_to_choice: Optional[Union[Callable, str, dict, list]] = None
process_results: Optional[Union[Callable, str]] = None process_results: Optional[Union[Callable, str]] = None
use_prompt: Optional[str] = None use_prompt: Optional[str] = None
...@@ -365,18 +365,23 @@ class Task(abc.ABC): ...@@ -365,18 +365,23 @@ class Task(abc.ABC):
def doc_to_target(self, doc): def doc_to_target(self, doc):
pass pass
# not an abstractmethod because not every language-only task has to implement this
def doc_to_image(self, doc):
raise NotImplementedError
def build_all_requests( def build_all_requests(
self, self,
*, *,
limit=None, limit: Union[int, None] = None,
rank=None, rank: int = 0,
world_size=None, world_size: int = 1,
cache_requests=False, cache_requests: bool = False,
rewrite_requests_cache=False, rewrite_requests_cache: bool = False,
system_instruction=None, system_instruction: Optional[str] = None,
apply_chat_template=False, apply_chat_template: bool = False,
fewshot_as_multiturn=False, fewshot_as_multiturn: bool = False,
lm=None, chat_template: Optional[Callable] = None,
tokenizer_name: str = "",
) -> None: ) -> None:
"""Build a set of Instances for a task, and store them in task.instances""" """Build a set of Instances for a task, and store them in task.instances"""
...@@ -391,9 +396,9 @@ class Task(abc.ABC): ...@@ -391,9 +396,9 @@ class Task(abc.ABC):
if system_instruction is not None if system_instruction is not None
else "" else ""
) )
cache_key += f"-tokenizer{lm.tokenizer_name}" if apply_chat_template else "" cache_key += f"-tokenizer{tokenizer_name}"
cached_instances = load_from_cache(file_name=cache_key) cached_instances = load_from_cache(file_name=cache_key, cache=cache_requests)
if cache_requests and cached_instances and not rewrite_requests_cache: if cache_requests and cached_instances and not rewrite_requests_cache:
cached_instances = cached_instances[:limit] cached_instances = cached_instances[:limit]
...@@ -436,7 +441,7 @@ class Task(abc.ABC): ...@@ -436,7 +441,7 @@ class Task(abc.ABC):
system_instruction, system_instruction,
apply_chat_template, apply_chat_template,
fewshot_as_multiturn, fewshot_as_multiturn,
lm, chat_template,
) )
# TODO: we should override self.config.repeats if doing greedy gen so users don't waste time+compute # TODO: we should override self.config.repeats if doing greedy gen so users don't waste time+compute
...@@ -444,6 +449,7 @@ class Task(abc.ABC): ...@@ -444,6 +449,7 @@ class Task(abc.ABC):
doc=doc, doc=doc,
ctx=fewshot_ctx, ctx=fewshot_ctx,
metadata=(self.config["task"], doc_id, self.config.repeats), metadata=(self.config["task"], doc_id, self.config.repeats),
apply_chat_template=apply_chat_template,
) )
if not isinstance(inst, list): if not isinstance(inst, list):
...@@ -722,6 +728,10 @@ class ConfigurableTask(Task): ...@@ -722,6 +728,10 @@ class ConfigurableTask(Task):
) )
self.OUTPUT_TYPE = self.config.output_type self.OUTPUT_TYPE = self.config.output_type
if self.config.doc_to_image is not None:
# mark the task as requiring multimodality.
self.MULTIMODAL = True
if self.config.dataset_path is not None: if self.config.dataset_path is not None:
self.DATASET_PATH = self.config.dataset_path self.DATASET_PATH = self.config.dataset_path
...@@ -979,7 +989,7 @@ class ConfigurableTask(Task): ...@@ -979,7 +989,7 @@ class ConfigurableTask(Task):
else: else:
if (self.config.num_fewshot is not None) and (self.config.num_fewshot > 0): if (self.config.num_fewshot is not None) and (self.config.num_fewshot > 0):
eval_logger.warning( eval_logger.warning(
f"Task '{self.config.task}': " f"[Task: {self.config.task}] "
"num_fewshot > 0 but fewshot_split is None. " "num_fewshot > 0 but fewshot_split is None. "
"using preconfigured rule." "using preconfigured rule."
) )
...@@ -1014,7 +1024,7 @@ class ConfigurableTask(Task): ...@@ -1014,7 +1024,7 @@ class ConfigurableTask(Task):
system_instruction: Optional[str] = None, system_instruction: Optional[str] = None,
apply_chat_template: bool = False, apply_chat_template: bool = False,
fewshot_as_multiturn: bool = False, fewshot_as_multiturn: bool = False,
lm=None, chat_template: Optional[Callable] = None,
) -> str: ) -> str:
"""Returns a fewshot context string that is made up of a prepended description """Returns a fewshot context string that is made up of a prepended description
(if provided), the `num_fewshot` number of examples, and an appended prompt example. (if provided), the `num_fewshot` number of examples, and an appended prompt example.
...@@ -1029,8 +1039,8 @@ class ConfigurableTask(Task): ...@@ -1029,8 +1039,8 @@ class ConfigurableTask(Task):
Whether to apply the chat template to the fewshot context. Whether to apply the chat template to the fewshot context.
:param fewshot_as_multiturn: bool :param fewshot_as_multiturn: bool
Whether to provide the fewshot examples as a multiturn conversation or a single user turn. Whether to provide the fewshot examples as a multiturn conversation or a single user turn.
:param lm: :param chat_template:
Language model with definition of the tokenizer/function to use for applying the chat template. callable (from lm.apply_chat_template) that takes in a list[Dict] chat transcript and renders it into a string.
:returns: str :returns: str
The fewshot context. The fewshot context.
""" """
...@@ -1077,7 +1087,7 @@ class ConfigurableTask(Task): ...@@ -1077,7 +1087,7 @@ class ConfigurableTask(Task):
example = self.doc_to_text(doc) example = self.doc_to_text(doc)
if apply_chat_template: if apply_chat_template:
if self.multiple_input: if self.multiple_input:
return lm.apply_chat_template(labeled_examples) return chat_template(labeled_examples)
if isinstance(example, str): if isinstance(example, str):
self.append_target_question( self.append_target_question(
labeled_examples, example, fewshot_as_multiturn labeled_examples, example, fewshot_as_multiturn
...@@ -1089,7 +1099,7 @@ class ConfigurableTask(Task): ...@@ -1089,7 +1099,7 @@ class ConfigurableTask(Task):
for ex in example: for ex in example:
chat = deepcopy(labeled_examples) chat = deepcopy(labeled_examples)
self.append_target_question(chat, ex, fewshot_as_multiturn) self.append_target_question(chat, ex, fewshot_as_multiturn)
labeled_examples_list.append(lm.apply_chat_template(chat)) labeled_examples_list.append(chat_template(chat))
return labeled_examples_list return labeled_examples_list
# if example is an integer, append the choice or convert to string # if example is an integer, append the choice or convert to string
elif isinstance(example, int): elif isinstance(example, int):
...@@ -1103,7 +1113,7 @@ class ConfigurableTask(Task): ...@@ -1103,7 +1113,7 @@ class ConfigurableTask(Task):
labeled_examples, str(example), fewshot_as_multiturn labeled_examples, str(example), fewshot_as_multiturn
) )
# return lm.apply_chat_template(labeled_examples) # return lm.apply_chat_template(labeled_examples)
return lm.apply_chat_template(labeled_examples) return chat_template(labeled_examples)
else: else:
if self.multiple_input: if self.multiple_input:
return labeled_examples return labeled_examples
...@@ -1158,9 +1168,11 @@ class ConfigurableTask(Task): ...@@ -1158,9 +1168,11 @@ class ConfigurableTask(Task):
""" """
return doc return doc
def doc_to_text(self, doc): def doc_to_text(self, doc, doc_to_text=None):
if self.prompt is not None: if self.prompt is not None:
doc_to_text = self.prompt doc_to_text = self.prompt
elif doc_to_text is not None:
doc_to_text = doc_to_text
else: else:
doc_to_text = self.config.doc_to_text doc_to_text = self.config.doc_to_text
...@@ -1192,9 +1204,11 @@ class ConfigurableTask(Task): ...@@ -1192,9 +1204,11 @@ class ConfigurableTask(Task):
print(type(doc_to_text)) print(type(doc_to_text))
raise TypeError raise TypeError
def doc_to_target(self, doc: Mapping) -> Union[int, str, list]: def doc_to_target(self, doc: Mapping, doc_to_target=None) -> Union[int, str, list]:
if self.prompt is not None: if self.prompt is not None:
doc_to_target = self.prompt doc_to_target = self.prompt
elif doc_to_target is not None:
doc_to_target = doc_to_target
else: else:
doc_to_target = self.config.doc_to_target doc_to_target = self.config.doc_to_target
...@@ -1236,9 +1250,11 @@ class ConfigurableTask(Task): ...@@ -1236,9 +1250,11 @@ class ConfigurableTask(Task):
else: else:
raise TypeError raise TypeError
def doc_to_choice(self, doc: Any) -> List[str]: def doc_to_choice(self, doc: Any, doc_to_choice=None) -> List[str]:
if self.prompt is not None: if self.prompt is not None:
doc_to_choice = self.prompt doc_to_choice = self.prompt
elif doc_to_choice is not None:
doc_to_choice = doc_to_choice
elif self.config.doc_to_choice is None: elif self.config.doc_to_choice is None:
eval_logger.error("doc_to_choice was called but not set in config") eval_logger.error("doc_to_choice was called but not set in config")
else: else:
...@@ -1260,9 +1276,36 @@ class ConfigurableTask(Task): ...@@ -1260,9 +1276,36 @@ class ConfigurableTask(Task):
else: else:
raise TypeError raise TypeError
def doc_to_image(self, doc: Any, doc_to_image=None) -> Union[int, str, list]:
if doc_to_image is not None:
doc_to_image = doc_to_image
elif self.config.doc_to_image is not None:
doc_to_image = self.config.doc_to_image
else:
return None
if isinstance(doc_to_image, list):
image_feature = [
self.doc_to_image(doc, feature) for feature in doc_to_image
]
return [feature for feature in image_feature if feature is not None]
elif isinstance(doc_to_image, str):
if doc_to_image in self.features:
return doc[doc_to_image]
else:
return ast.literal_eval(utils.apply_template(doc_to_image, doc))
elif callable(doc_to_image):
return doc_to_image(doc)
else:
return None
def construct_requests( def construct_requests(
self, doc: dict, ctx: str, **kwargs self, doc: dict, ctx: str, **kwargs
) -> Union[List[Instance], Instance]: ) -> Union[List[Instance], Instance]:
apply_chat_template = kwargs.pop("apply_chat_template", False)
aux_arguments = None
if self.OUTPUT_TYPE == "loglikelihood": if self.OUTPUT_TYPE == "loglikelihood":
arguments = (ctx, self.doc_to_target(doc)) arguments = (ctx, self.doc_to_target(doc))
elif self.OUTPUT_TYPE == "loglikelihood_rolling": elif self.OUTPUT_TYPE == "loglikelihood_rolling":
...@@ -1270,6 +1313,8 @@ class ConfigurableTask(Task): ...@@ -1270,6 +1313,8 @@ class ConfigurableTask(Task):
elif self.OUTPUT_TYPE == "multiple_choice": elif self.OUTPUT_TYPE == "multiple_choice":
choices = self.doc_to_choice(doc) choices = self.doc_to_choice(doc)
target_delimiter = self.config.target_delimiter target_delimiter = self.config.target_delimiter
if apply_chat_template:
target_delimiter = ""
if self.multiple_input: if self.multiple_input:
# If there are multiple inputs, choices are placed in the ctx # If there are multiple inputs, choices are placed in the ctx
cont = self.doc_to_target(doc) cont = self.doc_to_target(doc)
...@@ -1280,16 +1325,6 @@ class ConfigurableTask(Task): ...@@ -1280,16 +1325,6 @@ class ConfigurableTask(Task):
# Otherwise they are placed in the continuation # Otherwise they are placed in the continuation
arguments = [(ctx, f"{target_delimiter}{cont}") for cont in choices] arguments = [(ctx, f"{target_delimiter}{cont}") for cont in choices]
request_list = [
Instance(
request_type="loglikelihood",
doc=doc,
arguments=arg,
idx=i,
**kwargs,
)
for i, arg in enumerate(arguments)
]
# TODO: we should raise a warning telling users this will at most ~2x runtime. # TODO: we should raise a warning telling users this will at most ~2x runtime.
if "acc_mutual_info" in self._metric_fn_list.keys(): if "acc_mutual_info" in self._metric_fn_list.keys():
# if we are calculating multiple choice accuracy # if we are calculating multiple choice accuracy
...@@ -1298,25 +1333,48 @@ class ConfigurableTask(Task): ...@@ -1298,25 +1333,48 @@ class ConfigurableTask(Task):
# here mutual info refers to calculating # here mutual info refers to calculating
# log(P(choice|ctx) / P(choice)) = log(P(choice|ctx)) - log(P(choice)) # log(P(choice|ctx) / P(choice)) = log(P(choice|ctx)) - log(P(choice))
# in other words normalizing by subtracting the unconditional logprob of each choice. # in other words normalizing by subtracting the unconditional logprob of each choice.
request_list.extend( aux_arguments = [("", f"{choice}") for choice in choices]
[
arguments.extend(aux_arguments)
elif self.OUTPUT_TYPE == "generate_until":
arguments = (ctx, deepcopy(self.config.generation_kwargs))
multimodal_arg = {}
if (
self.config.doc_to_image
): # TODO: ensure that non-multimodal tasks aren't getting visual args
multimodal_arg = {
**multimodal_arg,
**{"visual": self.doc_to_image(doc)},
}
if bool(multimodal_arg):
if isinstance(arguments, list):
arguments = [arg + (multimodal_arg,) for arg in arguments]
else:
arguments = arguments + (multimodal_arg,)
if self.OUTPUT_TYPE == "multiple_choice":
request_list = [
Instance( Instance(
request_type="loglikelihood", request_type="loglikelihood",
doc=doc, doc=doc,
arguments=("", "{}".format(choice)), arguments=arg,
idx=i, idx=i,
**kwargs, **kwargs,
) )
for i, choice in enumerate(choices) for i, arg in enumerate(arguments)
] ]
)
return request_list
elif self.OUTPUT_TYPE == "generate_until": return request_list
arguments = (ctx, deepcopy(self.config.generation_kwargs))
return Instance( return Instance(
request_type=self.OUTPUT_TYPE, doc=doc, arguments=arguments, idx=0, **kwargs request_type=self.OUTPUT_TYPE,
doc=doc,
arguments=arguments,
idx=0,
**kwargs,
) )
def process_results(self, doc, results): def process_results(self, doc, results):
...@@ -1445,7 +1503,10 @@ class ConfigurableTask(Task): ...@@ -1445,7 +1503,10 @@ class ConfigurableTask(Task):
# we expect multiple_targets to be a list. # we expect multiple_targets to be a list.
elif self.multiple_target: elif self.multiple_target:
gold = list(gold) gold = list(gold)
elif type(gold) != type(result) and not isinstance(result, List): elif (
type(gold) is not type(result)
and "bypass" not in self._metric_fn_list.keys()
):
# cast gold to the same type as result # cast gold to the same type as result
gold = type(result)(gold) gold = type(result)(gold)
...@@ -1519,10 +1580,13 @@ class ConfigurableTask(Task): ...@@ -1519,10 +1580,13 @@ class ConfigurableTask(Task):
def get_config(self, key: str) -> Any: def get_config(self, key: str) -> Any:
return getattr(self._config, key, None) return getattr(self._config, key, None)
@property
def task_name(self) -> Any:
return getattr(self.config, "task", None)
def __repr__(self): def __repr__(self):
return ( return (
f"ConfigurableTask(task_name={getattr(self.config, 'task', None)}," f"ConfigurableTask(task_name={getattr(self.config, 'task', None)},"
f"group_name={getattr(self.config, 'group', None)},"
f"output_type={self.OUTPUT_TYPE}," f"output_type={self.OUTPUT_TYPE},"
f"num_fewshot={getattr(self.config, 'num_fewshot', None)}," f"num_fewshot={getattr(self.config, 'num_fewshot', None)},"
f"num_samples={len(self.eval_docs)})" f"num_samples={len(self.eval_docs)})"
......
...@@ -21,7 +21,9 @@ HASH_PREFIX = hashlib.sha256(HASH_INPUT.encode("utf-8")).hexdigest() ...@@ -21,7 +21,9 @@ HASH_PREFIX = hashlib.sha256(HASH_INPUT.encode("utf-8")).hexdigest()
FILE_SUFFIX = f".{HASH_PREFIX}.pickle" FILE_SUFFIX = f".{HASH_PREFIX}.pickle"
def load_from_cache(file_name): def load_from_cache(file_name: str, cache: bool = False):
if not cache:
return
try: try:
path = f"{PATH}/{file_name}{FILE_SUFFIX}" path = f"{PATH}/{file_name}{FILE_SUFFIX}"
......
...@@ -110,12 +110,15 @@ class TextReader: ...@@ -110,12 +110,15 @@ class TextReader:
def read_tqdm(self, update_frequency: int = 10000): def read_tqdm(self, update_frequency: int = 10000):
current_file_position = 0 current_file_position = 0
line_counter = 0 line_counter = 0
with open(self.file_path, "r", encoding="utf-8") as fh, tqdm.tqdm( with (
open(self.file_path, "r", encoding="utf-8") as fh,
tqdm.tqdm(
total=os.path.getsize(self.file_path), total=os.path.getsize(self.file_path),
dynamic_ncols=True, dynamic_ncols=True,
unit="byte", unit="byte",
unit_scale=1, unit_scale=1,
) as progress: ) as progress,
):
with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj: with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
for line in iter(mmap_obj.readline, b""): for line in iter(mmap_obj.readline, b""):
line = line.decode("utf-8") line = line.decode("utf-8")
......
...@@ -11,19 +11,25 @@ import torch ...@@ -11,19 +11,25 @@ import torch
import lm_eval.api.metrics import lm_eval.api.metrics
import lm_eval.api.registry import lm_eval.api.registry
import lm_eval.api.task
import lm_eval.models import lm_eval.models
from lm_eval.caching.cache import delete_cache from lm_eval.caching.cache import delete_cache
from lm_eval.evaluator_utils import ( from lm_eval.evaluator_utils import (
consolidate_group_results,
consolidate_results, consolidate_results,
get_sample_size, get_sample_size,
get_subtask_list,
get_task_list, get_task_list,
prepare_print_tasks, prepare_print_tasks,
print_writeout, print_writeout,
run_task_tests, run_task_tests,
) )
from lm_eval.loggers import EvaluationTracker from lm_eval.loggers import EvaluationTracker
from lm_eval.loggers.utils import add_env_info, get_git_commit_hash from lm_eval.loggers.utils import add_env_info, add_tokenizer_info, get_git_commit_hash
from lm_eval.tasks import TaskManager, get_task_dict from lm_eval.tasks import (
TaskManager,
get_task_dict,
)
from lm_eval.utils import ( from lm_eval.utils import (
eval_logger, eval_logger,
handle_non_serializable, handle_non_serializable,
...@@ -35,7 +41,7 @@ from lm_eval.utils import ( ...@@ -35,7 +41,7 @@ from lm_eval.utils import (
if TYPE_CHECKING: if TYPE_CHECKING:
from lm_eval.api.model import LM from lm_eval.api.model import LM
from lm_eval.tasks import Task from lm_eval.api.task import Task
@positional_deprecated @positional_deprecated
...@@ -44,7 +50,7 @@ def simple_evaluate( ...@@ -44,7 +50,7 @@ def simple_evaluate(
model_args: Optional[Union[str, dict]] = None, model_args: Optional[Union[str, dict]] = None,
tasks: Optional[List[Union[str, dict, object]]] = None, tasks: Optional[List[Union[str, dict, object]]] = None,
num_fewshot: Optional[int] = None, num_fewshot: Optional[int] = None,
batch_size: Optional[int] = None, batch_size: Optional[Union[int, str]] = None,
max_batch_size: Optional[int] = None, max_batch_size: Optional[int] = None,
device: Optional[str] = None, device: Optional[str] = None,
use_cache: Optional[str] = None, use_cache: Optional[str] = None,
...@@ -58,7 +64,7 @@ def simple_evaluate( ...@@ -58,7 +64,7 @@ def simple_evaluate(
log_samples: bool = True, log_samples: bool = True,
evaluation_tracker: Optional[EvaluationTracker] = None, evaluation_tracker: Optional[EvaluationTracker] = None,
system_instruction: Optional[str] = None, system_instruction: Optional[str] = None,
apply_chat_template: bool = False, apply_chat_template: Union[bool, str] = False,
fewshot_as_multiturn: bool = False, fewshot_as_multiturn: bool = False,
gen_kwargs: Optional[str] = None, gen_kwargs: Optional[str] = None,
task_manager: Optional[TaskManager] = None, task_manager: Optional[TaskManager] = None,
...@@ -106,8 +112,11 @@ def simple_evaluate( ...@@ -106,8 +112,11 @@ def simple_evaluate(
If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
:param system_instruction: str :param system_instruction: str
System instruction to be applied to the prompt System instruction to be applied to the prompt
:param apply_chat_template: bool :param apply_chat_template: Union[bool, str]
If True, apply chat template to the prompt Specifies whether to apply a chat template to the prompt.
- If set to True, the default chat template is applied.
- If set to a string, applies the specified chat template by name.
Defaults to False (no chat template applied).
:param fewshot_as_multiturn: bool :param fewshot_as_multiturn: bool
Whether to provide the fewshot examples as a multiturn conversation or a single user turn. Whether to provide the fewshot examples as a multiturn conversation or a single user turn.
:param gen_kwargs: str :param gen_kwargs: str
...@@ -148,6 +157,9 @@ def simple_evaluate( ...@@ -148,6 +157,9 @@ def simple_evaluate(
seed_message.append(f"Setting torch manual seed to {torch_random_seed}") seed_message.append(f"Setting torch manual seed to {torch_random_seed}")
torch.manual_seed(torch_random_seed) torch.manual_seed(torch_random_seed)
if fewshot_random_seed is not None:
seed_message.append(f"Setting fewshot manual seed to {fewshot_random_seed}")
if seed_message: if seed_message:
eval_logger.info(" | ".join(seed_message)) eval_logger.info(" | ".join(seed_message))
...@@ -199,7 +211,9 @@ def simple_evaluate( ...@@ -199,7 +211,9 @@ def simple_evaluate(
) )
else: else:
if not isinstance(model, lm_eval.api.model.LM): if not isinstance(model, lm_eval.api.model.LM):
raise TypeError raise TypeError(
f"The value of `model` passed to simple_evaluate() was of type {type(model)}, but is required to be a subclass of lm_eval.api.model.LM . This may be because you are passing an initialized Hugging Face PreTrainedModel without having wrapped it in `lm_eval.models.huggingface.HFLM(pretrained=my_model)` first."
)
eval_logger.info("Using pre-initialized model") eval_logger.info("Using pre-initialized model")
lm = model lm = model
...@@ -219,13 +233,19 @@ def simple_evaluate( ...@@ -219,13 +233,19 @@ def simple_evaluate(
task_manager = TaskManager(verbosity) task_manager = TaskManager(verbosity)
task_dict = get_task_dict(tasks, task_manager) task_dict = get_task_dict(tasks, task_manager)
for task_name in task_dict.keys():
task_obj = task_dict[task_name]
if isinstance(task_obj, tuple):
_, task_obj = task_obj
if task_obj is None:
continue
# helper function to recursively apply config overrides to leaf subtasks, skipping their constituent groups.
# (setting of num_fewshot ; bypassing metric calculation ; setting fewshot seed)
def _adjust_config(task_dict):
adjusted_task_dict = {}
for task_name, task_obj in task_dict.items():
if isinstance(task_obj, dict):
adjusted_task_dict = {
**adjusted_task_dict,
**{task_name: _adjust_config(task_obj)},
}
else:
if task_obj.get_config("output_type") == "generate_until": if task_obj.get_config("output_type") == "generate_until":
if gen_kwargs is not None: if gen_kwargs is not None:
task_obj.set_config( task_obj.set_config(
...@@ -233,7 +253,6 @@ def simple_evaluate( ...@@ -233,7 +253,6 @@ def simple_evaluate(
) )
if predict_only: if predict_only:
log_samples = True
eval_logger.info( eval_logger.info(
f"Processing {task_name} in output-only mode. Metrics will not be calculated!" f"Processing {task_name} in output-only mode. Metrics will not be calculated!"
) )
...@@ -254,13 +273,18 @@ def simple_evaluate( ...@@ -254,13 +273,18 @@ def simple_evaluate(
task_obj.set_config(key="num_fewshot", value=num_fewshot) task_obj.set_config(key="num_fewshot", value=num_fewshot)
else: else:
# if num_fewshot not provided, and the task does not define a default one, default to 0 # if num_fewshot not provided, and the task does not define a default one, default to 0
if (default_num_fewshot := task_obj.get_config("num_fewshot")) is None: if (
default_num_fewshot := task_obj.get_config("num_fewshot")
) is None:
task_obj.set_config(key="num_fewshot", value=0) task_obj.set_config(key="num_fewshot", value=0)
# fewshot_random_seed set for tasks, even with a default num_fewshot (e.g. in the YAML file) # fewshot_random_seed set for tasks, even with a default num_fewshot (e.g. in the YAML file)
task_obj.set_fewshot_seed(seed=fewshot_random_seed) task_obj.set_fewshot_seed(seed=fewshot_random_seed)
eval_logger.info(
f"Setting fewshot random generator seed to {fewshot_random_seed}" adjusted_task_dict[task_name] = task_obj
)
return adjusted_task_dict
task_dict = _adjust_config(task_dict)
if check_integrity: if check_integrity:
run_task_tests(task_list=tasks) run_task_tests(task_list=tasks)
...@@ -270,7 +294,10 @@ def simple_evaluate( ...@@ -270,7 +294,10 @@ def simple_evaluate(
model_source=model, model_source=model,
model_args=model_args, model_args=model_args,
system_instruction=system_instruction, system_instruction=system_instruction,
chat_template=lm.chat_template if apply_chat_template else None, chat_template=lm.chat_template(apply_chat_template)
if apply_chat_template
else None,
fewshot_as_multiturn=fewshot_as_multiturn,
) )
results = evaluate( results = evaluate(
...@@ -281,7 +308,7 @@ def simple_evaluate( ...@@ -281,7 +308,7 @@ def simple_evaluate(
rewrite_requests_cache=rewrite_requests_cache, rewrite_requests_cache=rewrite_requests_cache,
bootstrap_iters=bootstrap_iters, bootstrap_iters=bootstrap_iters,
write_out=write_out, write_out=write_out,
log_samples=log_samples, log_samples=True if predict_only else log_samples,
system_instruction=system_instruction, system_instruction=system_instruction,
apply_chat_template=apply_chat_template, apply_chat_template=apply_chat_template,
fewshot_as_multiturn=fewshot_as_multiturn, fewshot_as_multiturn=fewshot_as_multiturn,
...@@ -325,6 +352,7 @@ def simple_evaluate( ...@@ -325,6 +352,7 @@ def simple_evaluate(
results["git_hash"] = get_git_commit_hash() results["git_hash"] = get_git_commit_hash()
results["date"] = start_date results["date"] = start_date
add_env_info(results) # additional environment info to results add_env_info(results) # additional environment info to results
add_tokenizer_info(results, lm) # additional info about tokenizer
return results return results
else: else:
return None return None
...@@ -341,7 +369,7 @@ def evaluate( ...@@ -341,7 +369,7 @@ def evaluate(
write_out: bool = False, write_out: bool = False,
log_samples: bool = True, log_samples: bool = True,
system_instruction: Optional[str] = None, system_instruction: Optional[str] = None,
apply_chat_template: bool = False, apply_chat_template: Union[bool, str] = False,
fewshot_as_multiturn: bool = False, fewshot_as_multiturn: bool = False,
verbosity: str = "INFO", verbosity: str = "INFO",
): ):
...@@ -361,8 +389,11 @@ def evaluate( ...@@ -361,8 +389,11 @@ def evaluate(
If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
:param system_instruction: str :param system_instruction: str
System instruction to be applied to the prompt System instruction to be applied to the prompt
:param apply_chat_template: bool :param apply_chat_template: Union[bool, str]
If True, apply chat template to the prompt Specifies whether to apply a chat template to the prompt.
- If set to True, the default chat template is applied.
- If set to a string, applies the specified chat template by name.
Defaults to False (no chat template applied).
:param fewshot_as_multiturn: bool :param fewshot_as_multiturn: bool
Whether to provide the fewshot examples as a multiturn conversation or a single user turn. Whether to provide the fewshot examples as a multiturn conversation or a single user turn.
:return :return
...@@ -371,6 +402,11 @@ def evaluate( ...@@ -371,6 +402,11 @@ def evaluate(
eval_logger.setLevel(getattr(logging, f"{verbosity}")) eval_logger.setLevel(getattr(logging, f"{verbosity}"))
if apply_chat_template:
eval_logger.warning(
"Chat template formatting change affects loglikelihood and multiple-choice tasks. See docs/chat-template-readme.md for details."
)
# tracks all Instances/requests a model must generate output on. # tracks all Instances/requests a model must generate output on.
requests = defaultdict(list) requests = defaultdict(list)
# stores the amount to pad out reqs per req. type so that # stores the amount to pad out reqs per req. type so that
...@@ -378,16 +414,40 @@ def evaluate( ...@@ -378,16 +414,40 @@ def evaluate(
padding_requests = defaultdict(int) padding_requests = defaultdict(int)
# get lists of group hierarchy and each type of request # get lists of group hierarchy and each type of request
task_hierarchy, eval_tasks = get_task_list(task_dict) eval_tasks = get_task_list(task_dict)
if not log_samples: if not log_samples:
if not all( if not all(
"bypass" not in getattr(task_output.task, "_metric_fn_list", {}).keys() "bypass" not in getattr(task_output.task, "_metric_fn_list", {}).keys()
for task_output in eval_tasks for task_output in eval_tasks
): ):
raise ValueError("log_samples must be True for 'bypass' metric-only tasks") raise ValueError("log_samples must be True for 'bypass' metric-only tasks")
# validation check: are we running multimodal task <-> non-multimodal model class, or vice-versa.
incompatible_tasks = []
for task_output in eval_tasks:
task: Task = task_output.task
if getattr(lm, "MULTIMODAL", False) != getattr(task, "MULTIMODAL", False):
incompatible_tasks.append(task_output.task_name)
if len(incompatible_tasks) > 0:
if not getattr(lm, "MULTIMODAL", False):
raise ValueError(
f"Attempted to run tasks: {incompatible_tasks} which require multimodal input, but the selected model type does not currently implement this. Multimodal support is currently restricted to the ['hf-multimodal', 'vllm-vlm'] model type."
)
else:
raise ValueError(
f"Attempted to run tasks: {incompatible_tasks} which are text-only, but used a model type which only currently supports multimodal tasks."
)
# end multimodality validation check
# Cache the limit arg.
limit_arg = limit
limits = []
for task_output in eval_tasks: for task_output in eval_tasks:
task: Task = task_output.task task: Task = task_output.task
limit = get_sample_size(task, limit)
limit = get_sample_size(task, limit_arg)
limits.append(limit)
task.build_all_requests( task.build_all_requests(
limit=limit, limit=limit,
rank=lm.rank, rank=lm.rank,
...@@ -395,9 +455,14 @@ def evaluate( ...@@ -395,9 +455,14 @@ def evaluate(
cache_requests=cache_requests, cache_requests=cache_requests,
rewrite_requests_cache=rewrite_requests_cache, rewrite_requests_cache=rewrite_requests_cache,
system_instruction=system_instruction, system_instruction=system_instruction,
apply_chat_template=apply_chat_template, apply_chat_template=bool(apply_chat_template),
fewshot_as_multiturn=fewshot_as_multiturn, fewshot_as_multiturn=fewshot_as_multiturn,
lm=lm, chat_template=getattr(lm, "apply_chat_template")
if apply_chat_template
else None,
tokenizer_name=getattr(lm, "tokenizer_name", "")
if apply_chat_template
else "",
) )
eval_logger.debug( eval_logger.debug(
f"Task: {task_output.task_name}; number of requests on this rank: {len(task.instances)}" f"Task: {task_output.task_name}; number of requests on this rank: {len(task.instances)}"
...@@ -452,7 +517,7 @@ def evaluate( ...@@ -452,7 +517,7 @@ def evaluate(
WORLD_SIZE = lm.world_size WORLD_SIZE = lm.world_size
### Postprocess outputs ### ### Postprocess outputs ###
# TODO: del model here, maybe (idea: allow user to specify device of e.g. reward model separately) # TODO: del model here, maybe (idea: allow user to specify device of e.g. reward model separately)
for task_output in eval_tasks: for task_output, limit in zip(eval_tasks, limits):
task = task_output.task task = task_output.task
task.apply_filters() task.apply_filters()
...@@ -487,6 +552,8 @@ def evaluate( ...@@ -487,6 +552,8 @@ def evaluate(
"filtered_resps": [ "filtered_resps": [
req.filtered_resps[filter_key] for req in requests req.filtered_resps[filter_key] for req in requests
], ],
"filter": filter_key,
"metrics": list(metrics.keys()),
"doc_hash": hash_string( "doc_hash": hash_string(
json.dumps( json.dumps(
requests[0].doc, requests[0].doc,
...@@ -550,22 +617,26 @@ def evaluate( ...@@ -550,22 +617,26 @@ def evaluate(
### Calculate group metrics ### ### Calculate group metrics ###
if bool(results): if bool(results):
for group, task_list in reversed(task_hierarchy.items()): results, versions, show_group_table, *_ = consolidate_group_results(
if len(task_list) == 0: results, versions, task_dict
# task_hierarchy entries are either )
# `group_name: [subtask1, subtask2, ...]`
# or `task_name: []`. results_agg, group_agg = prepare_print_tasks(task_dict, results)
# we only want to operate on groups here. subtask_list = get_subtask_list(task_dict)
continue
# collect all higher_is_better values for metrics # collect all higher_is_better values for metrics
# in the group's subtasks. # in the group's subtasks.
# TODO: clean this up ; unify with the below metric_list loop? # TODO: clean this up ; unify with the below metric_list loop?
_higher_is_better = {} _higher_is_better = {}
for group, task_list in subtask_list.items():
if (
len(task_list) != 0
): # subtask list will list "task_name": [] for solo tasks
for task in task_list: for task in task_list:
for m, h in higher_is_better[task].items(): for m, h in higher_is_better[task].items():
if m not in _higher_is_better.keys(): if m not in _higher_is_better.keys():
_higher_is_better[m] = h _higher_is_better[m] = h
if ( if (
m in _higher_is_better m in _higher_is_better
and _higher_is_better[m] is not None and _higher_is_better[m] is not None
...@@ -577,79 +648,14 @@ def evaluate( ...@@ -577,79 +648,14 @@ def evaluate(
_higher_is_better[m] = None _higher_is_better[m] = None
higher_is_better[group] = _higher_is_better higher_is_better[group] = _higher_is_better
# collect all metric keys used by a subtask in the group.
metric_list = list(
{
key
for task in task_list
for key in results[task].keys()
if "_stderr" not in key and key not in ["alias", "samples"]
}
)
for metric in metric_list:
stderr = "_stderr,".join(metric.split(","))
# gather metrics, sizes, and stderrs from subtasks
metrics = [
results[task][metric]
for task in task_list
if metric in results[task]
] # TODO: copy?
stderrs = [
results[task][stderr]
for task in task_list
if stderr in results[task]
]
sizes = [
results[task]["samples"]
for task in task_list
if metric in results[task]
]
# compute group's pooled metric and stderr
results[group][
metric
] = lm_eval.api.metrics.aggregate_subtask_metrics(metrics, sizes)
# TODO: calculate grouped metric using aggregation fn
if "N/A" in stderrs:
results[group][stderr] = "N/A"
else:
results[group][
stderr
] = lm_eval.api.metrics.pooled_sample_stderr(stderrs, sizes)
# TODO: allow GroupConfigs to choose which variance formula is used, for back-compatibility
# To use the old (likely incorrect) variance formula, comment out the above and uncomment this line:
# results[group][stderr] = lm_eval.api.metrics.combined_sample_stderr(stderrs, sizes, metrics=metrics)
results[group]["samples"] = sum(sizes)
results_agg = defaultdict(dict)
groups_agg = defaultdict(dict)
all_tasks_list = list(task_hierarchy.keys())
while True:
add_tasks_list = list(k for k in results_agg.keys())
left_tasks_list = sorted(list(set(all_tasks_list) - set(add_tasks_list)))
if len(left_tasks_list) == 0:
break
_task_hierarchy = {
k: v for k, v in task_hierarchy.items() if k in left_tasks_list
}
_results_agg, _groups_agg = prepare_print_tasks(_task_hierarchy, results)
results_agg = {**results_agg, **_results_agg}
groups_agg = {**groups_agg, **_groups_agg}
for group_name, task_list in task_hierarchy.items():
if task_list:
num_fewshot[group_name] = num_fewshot[
task_list[0]
] # TODO: validate this
results_dict = { results_dict = {
"results": dict(results_agg.items()), "results": dict(results_agg.items()),
**({"groups": dict(groups_agg.items())} if bool(groups_agg) else {}), **(
"group_subtasks": dict(reversed(task_hierarchy.items())), {"groups": dict(group_agg.items())}
if (bool(group_agg) & show_group_table)
else {}
),
"group_subtasks": dict(reversed(subtask_list.items())),
"configs": dict(sorted(configs.items())), "configs": dict(sorted(configs.items())),
"versions": dict(sorted(versions.items())), "versions": dict(sorted(versions.items())),
"n-shot": dict(sorted(num_fewshot.items())), "n-shot": dict(sorted(num_fewshot.items())),
...@@ -662,7 +668,7 @@ def evaluate( ...@@ -662,7 +668,7 @@ def evaluate(
len(task_output.task.eval_docs), len(task_output.task.eval_docs),
), ),
} }
for task_output in eval_tasks for task_output, limit in zip(eval_tasks, limits)
}, },
} }
if log_samples: if log_samples:
......
...@@ -2,9 +2,15 @@ import collections ...@@ -2,9 +2,15 @@ import collections
import math import math
import pathlib import pathlib
import sys import sys
from typing import Dict, List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
from lm_eval.api import metrics from lm_eval.api.group import ConfigurableGroup
from lm_eval.api.metrics import (
aggregate_subtask_metrics,
pooled_sample_stderr,
stderr_for_metric,
)
from lm_eval.api.task import Task
from lm_eval.utils import eval_logger, positional_deprecated from lm_eval.utils import eval_logger, positional_deprecated
...@@ -98,7 +104,7 @@ class TaskOutput: ...@@ -98,7 +104,7 @@ class TaskOutput:
self.agg_metrics[metric_key] = agg_fn(items) self.agg_metrics[metric_key] = agg_fn(items)
self.sample_len = len(items) # TODO: same sample size for each metric? self.sample_len = len(items) # TODO: same sample size for each metric?
if isinstance(bootstrap_iters, int): if isinstance(bootstrap_iters, int):
stderr_fn = metrics.stderr_for_metric( stderr_fn = stderr_for_metric(
metric=agg_fn, metric=agg_fn,
bootstrap_iters=min(bootstrap_iters, 100) bootstrap_iters=min(bootstrap_iters, 100)
if metric in ["bleu", "chrf", "ter"] if metric in ["bleu", "chrf", "ter"]
...@@ -116,23 +122,71 @@ class TaskOutput: ...@@ -116,23 +122,71 @@ class TaskOutput:
return ( return (
f"TaskOutput(task_name={self.task_name}, " f"TaskOutput(task_name={self.task_name}, "
f"group_name={self.group_name}, " f"group_name={self.group_name}, "
f"version={self.version}," f"version={self.version}, "
f"n_shot={self.n_shot}" f"n_shot={self.n_shot}, "
f"task_alias={self.task_alias}, group_alias={self.group_alias})" f"task_alias={self.task_alias}, "
f"group_alias={self.group_alias})"
) )
def get_task_list(task_dict: dict) -> Tuple[Dict[str, list], List[TaskOutput]]: def get_task_list(task_dict: dict) -> List[TaskOutput]:
task_hierarchy = collections.defaultdict(list) outputs = []
outputs = list(TaskOutput.from_taskdict(x, y) for x, y in task_dict.items()) for task_name, task_obj in task_dict.items():
for task_output in outputs: if isinstance(task_obj, dict):
if group_name := task_output.group_name: _outputs = get_task_list(task_obj)
task_hierarchy[group_name].append(task_output.task_name) outputs.extend(_outputs)
else: else:
task_hierarchy[task_output.task_name] = [] task_output = TaskOutput.from_taskdict(task_name, task_obj)
# returns task_hierarchy tracking which groups contain which subtasks, outputs.append(task_output)
# and a list of TaskOutput classes for each non-group subtask
return task_hierarchy, [x for x in outputs if x.task] return outputs
def get_subtask_list(task_dict, task_root=None, depth=0):
subtask_list = {}
for group_obj, task_obj in task_dict.items():
if isinstance(group_obj, ConfigurableGroup):
# group_name = group_obj.group_name
group_name = group_obj.group_name
else:
group_name = group_obj
if isinstance(task_obj, dict):
_subtask_list = get_subtask_list(
task_obj, task_root=group_name, depth=depth + 1
)
if task_root:
subtask_list.setdefault((task_root, depth), []).extend(
[
_task
for (_task, _depth) in _subtask_list.keys()
if (_depth - 1) == depth
]
)
subtask_list = {**subtask_list, **_subtask_list}
else:
if isinstance(task_obj, ConfigurableGroup):
# group_or_task_name = task_obj.group_name
group_or_task_name = task_obj.group_name
elif isinstance(task_obj, Task):
# group_or_task_name = task_obj.task_name
group_or_task_name = task_obj.task_name
if task_root is None:
subtask_list.setdefault((group_or_task_name, depth), [])
else:
subtask_list.setdefault((task_root, depth), []).append(
group_or_task_name
)
if depth == 0:
_subtask_list = {}
for group_key, task_list in subtask_list.items():
group_name, depth = group_key
_subtask_list[group_name] = task_list
subtask_list = _subtask_list
return subtask_list
def print_writeout(task) -> None: def print_writeout(task) -> None:
...@@ -155,70 +209,95 @@ def get_sample_size(task, limit: Optional[int]) -> Union[int, None]: ...@@ -155,70 +209,95 @@ def get_sample_size(task, limit: Optional[int]) -> Union[int, None]:
def prepare_print_tasks( def prepare_print_tasks(
task_hierarchy: dict, results: dict, tab=0 task_dict: dict,
results: dict,
task_depth=0,
group_depth=0,
) -> Tuple[dict, dict]: ) -> Tuple[dict, dict]:
""" """
@param task_hierarchy: Dictionary representing the group hierarchy of tasks. Each key is a group name and its @param task_dict: Dictionary representing the group hierarchy of tasks. Each key is a group name and its
value is a list of task names. value is a list of task names.
@param results: Dictionary containing the results of each task. Each key is a @param results: Dictionary containing the results of each task. Each key is a
group name and its value is a dictionary of task results. group name and its value is a dictionary of task results.
@param tab: The indentation level for printing the task @param task_depth: The indentation level for printing the task
hierarchy. Default is 0.
@param group_depth: The indentation level for printing the group
hierarchy. Default is 0. hierarchy. Default is 0.
@return: A tuple of two dictionaries: results_agg and groups_agg. results_agg contains @return: A tuple of two dictionaries: results_agg and groups_agg. results_agg contains
aggregated results for each task, and groups_agg contains aggregated results for each group. aggregated results for each task, and groups_agg contains aggregated results for each group.
Prepares the task hierarchy and aggregates the results for each task and group recursively for printing. Prepares the task hierarchy and aggregates the results for each task and group recursively for printing.
""" """
results_agg = collections.defaultdict(dict)
groups_agg = collections.defaultdict(dict)
(group_name, task_list), *_ = task_hierarchy.items() def _sort_task_dict(task_dict):
task_list = sorted(task_list) """
Helper utility. Sorts the task dict at the current level of the hierarchy based on alphabetized task name.
results_agg[group_name] = results[group_name].copy() Required so that we end up sorting within each sub-header correctly.
# results_agg[group_name]["tab"] = tab """
if "samples" in results_agg[group_name]:
results_agg[group_name].pop("samples")
tab_string = " " * tab + "- " if tab > 0 else "" return dict(
sorted(
task_dict.items(),
key=lambda item: item[0].group_name
if isinstance(item[0], ConfigurableGroup)
else item[0],
)
)
if "alias" in results_agg[group_name]: task_agg = collections.defaultdict(dict)
results_agg[group_name]["alias"] = tab_string + results_agg[group_name]["alias"] group_agg = collections.defaultdict(dict)
task_dict = _sort_task_dict(task_dict)
for task_or_group_name, task_or_group_obj in task_dict.items():
tab_string = " " * task_depth + "- " if task_depth > 0 else ""
if isinstance(task_or_group_name, ConfigurableGroup):
# string_name = task_or_group_name.group_name
name = task_or_group_name.group_name
from_configurable_group = True
task_or_group_obj = _sort_task_dict(task_or_group_obj)
elif isinstance(task_or_group_name, str):
name = task_or_group_name
if isinstance(task_or_group_obj, Task):
# string_name = task_or_group_obj.task_name
name = task_or_group_obj.task_name
from_configurable_group = False
task_agg[name] = results[name].copy()
if from_configurable_group:
if task_or_group_name.group_alias is not None:
alias = task_or_group_name.group_alias
else: else:
results_agg[group_name]["alias"] = tab_string + group_name alias = task_or_group_name.group
if len(task_list) > 0:
groups_agg[group_name] = results[group_name].copy()
# groups_agg[group_name]["tab"] = tab
if "samples" in groups_agg[group_name]:
groups_agg[group_name].pop("samples")
if "alias" in groups_agg[group_name]:
groups_agg[group_name]["alias"] = (
tab_string + groups_agg[group_name]["alias"]
)
else: else:
groups_agg[group_name]["alias"] = tab_string + group_name if "alias" in task_agg[name]:
alias = task_agg[name]["alias"]
for task_name in task_list:
if task_name in task_hierarchy:
_task_hierarchy = {
**{task_name: task_hierarchy[task_name]},
**task_hierarchy,
}
else: else:
_task_hierarchy = { alias = name
**{task_name: []},
**task_hierarchy, task_agg[name]["alias"] = tab_string + alias
} if "samples" in task_agg[name]:
task_agg[name].pop("samples")
_results_agg, _groups_agg = prepare_print_tasks(
_task_hierarchy, results, tab + 1 if from_configurable_group and (" " not in results[name]):
group_tab_string = " " * group_depth + "- " if group_depth > 0 else ""
group_agg[name] = results[name].copy()
group_agg[name]["alias"] = group_tab_string + alias
if "samples" in group_agg[name]:
group_agg[name].pop("samples")
if isinstance(task_or_group_obj, dict):
task_depth += 1
group_depth += 1
_task_agg, _group_agg = prepare_print_tasks(
task_or_group_obj, results, task_depth, group_depth
) )
results_agg = {**results_agg, **_results_agg} task_agg = {
groups_agg = {**groups_agg, **_groups_agg} **task_agg,
**_task_agg,
return results_agg, groups_agg }
group_agg = {**group_agg, **_group_agg}
task_depth -= 1
group_depth -= 1
return task_agg, group_agg
def consolidate_results( def consolidate_results(
...@@ -261,6 +340,8 @@ def consolidate_results( ...@@ -261,6 +340,8 @@ def consolidate_results(
for task_output in eval_tasks: for task_output in eval_tasks:
if "task_alias" in (task_config := task_output.task_config): if "task_alias" in (task_config := task_output.task_config):
results[task_output.task_name]["alias"] = task_config["task_alias"] results[task_output.task_name]["alias"] = task_config["task_alias"]
else:
results[task_output.task_name]["alias"] = task_output.task_name
if group_alias := task_output.group_alias: if group_alias := task_output.group_alias:
if group_alias not in results and (group_name := task_output.group_name): if group_alias not in results and (group_name := task_output.group_name):
results[group_name]["alias"] = group_alias results[group_name]["alias"] = group_alias
...@@ -275,12 +356,153 @@ def consolidate_results( ...@@ -275,12 +356,153 @@ def consolidate_results(
metric_key metric_key
] ]
results[task_output.task_name]["samples"] = task_output.sample_len results[task_output.task_name]["samples"] = task_output.sample_len
results[task_output.task_name][ results[task_output.task_name][f"{metric}_stderr,{filter_key}"] = (
f"{metric}_stderr,{filter_key}" task_output.agg_metrics[f"{metric}_stderr,{filter_key}"]
] = task_output.agg_metrics[f"{metric}_stderr,{filter_key}"] )
return results, samples, configs, versions, num_fewshot, higher_is_better return results, samples, configs, versions, num_fewshot, higher_is_better
def consolidate_group_results(
results,
versions,
task_dict,
task_root=None,
show_group_table=False,
task_aggregation_list=None,
) -> Tuple[dict, dict, bool, Union[None,]]:
"""
(Recursively) calculates groups' aggregated metrics and updates the results and versions dictionaries with this info.
@return: a tuple [results, versions, show_group_table, task_aggregation_list] with formats described below:
- results: A defaultdict with task names (and, after this function is called, group names of
groups that perform aggregation) as keys, and dictionaries with "alias" and metric,filter_name pairs as keys.
- versions: A defaultdict with task names (and, after this function is called, group names of
groups that perform aggregation) as keys, and float values representing the task or group's version if a version is specified. (defaulting to None).
- show_group_table: a boolean which is true if there exists a group that requires printing of its aggregated scores in a group table.
- task_aggregation_list: a defaultdict listing the subtasks to average over to produce a given group's end metric.
The method then returns the updated results, versions, show_group_table, and task_aggregation_list as a tuple.
In the top-level invocation of this function, task_aggregation_list is ignored.
"""
if task_root is None:
task_root = {}
if task_aggregation_list is None:
task_aggregation_list = {}
for group_or_task, group_or_task_info in task_dict.items():
# Convert to string
if isinstance(group_or_task, ConfigurableGroup):
group_config = group_or_task.config
group_or_task = group_or_task.group_name
else:
group_config = None
if isinstance(group_or_task_info, Task):
if task_root:
task_aggregation_list.setdefault(task_root, []).append(
group_or_task_info.task_name
)
else:
(
results,
versions,
show_group_table,
_task_aggregation_list,
) = consolidate_group_results(
results,
versions,
group_or_task_info,
group_or_task,
show_group_table,
task_aggregation_list,
)
if task_root:
task_aggregation_list.setdefault(task_root, []).extend(
task_aggregation_list.get(group_or_task, [])
)
if (group_config is None) or (
group_config["aggregate_metric_list"] is None
):
results[group_or_task][" "] = " "
continue
if "aggregate_metric_list" in group_config:
agg_metric_list = group_config["aggregate_metric_list"]
show_group_table = show_group_table | bool(
group_config["aggregate_metric_list"]
)
task_list = _task_aggregation_list[group_or_task]
metric_list = list(
{
key
for task in task_list
for key in results[task].keys()
if "_stderr" not in key and key not in ["task", "alias", "samples"]
}
)
for metric in metric_list:
stderr = "_stderr,".join(metric.split(","))
# gather metrics, sizes, and stderrs from subtasks
metrics = [
results[task][metric]
for task in task_list
if metric in results[task]
] # TODO: copy?
stderrs = [
results[task][stderr]
for task in task_list
if stderr in results[task]
]
sizes = [
results[task]["samples"]
for task in task_list
if metric in results[task]
]
for metric_config in agg_metric_list:
for filter_name in metric_config["filter_list"]:
if metric != ",".join([metric_config["metric"], filter_name]):
continue
# compute group's pooled metric and stderr
if metric_config["aggregation"] == "mean":
aggregate_fn = aggregate_subtask_metrics
elif callable(metric_config["aggregation"]):
aggregate_fn = metric_config["aggregation"]
else:
raise ValueError(
f"Currently, only 'mean' is supported for automatically aggregating scores across groups' subtasks. Got '{metric_config['aggregation']}' for group '{group_or_task}'"
)
results[group_or_task][metric] = aggregate_fn(
metrics,
sizes,
metric_config["weight_by_size"],
)
# TODO: calculate groups' metrics using arbitrary agg fns
if "N/A" in stderrs:
results[group_or_task][stderr] = "N/A"
else:
# NOTE: this assumes we are using the mean to aggregate. There are warnings about this elsewhere
results[group_or_task][stderr] = pooled_sample_stderr(
stderrs, sizes
)
results[group_or_task]["samples"] = sum(sizes)
group_metadata = group_config.get("metadata", None)
if group_metadata is not None:
versions[group_or_task] = group_metadata.get("version", None)
# print(results)
return results, versions, show_group_table, task_aggregation_list
@positional_deprecated @positional_deprecated
def find_test_root(start_path: pathlib.Path) -> pathlib.Path: def find_test_root(start_path: pathlib.Path) -> pathlib.Path:
""" """
......
...@@ -37,16 +37,18 @@ class RegexFilter(Filter): ...@@ -37,16 +37,18 @@ class RegexFilter(Filter):
if match: if match:
match = match[self.group_select] match = match[self.group_select]
if isinstance(match, tuple): if isinstance(match, tuple):
match = [m for m in match if m][0] match = [m for m in match if m]
if match:
match = match[0]
else:
match = self.fallback
match = match.strip() match = match.strip()
else: else:
match = self.fallback match = self.fallback
filtered.append(match) filtered.append(match)
return filtered return filtered
# print(resps)
filtered_resps = list(map(lambda x: filter_set(x), resps)) filtered_resps = list(map(lambda x: filter_set(x), resps))
# print(filtered_resps)
return filtered_resps return filtered_resps
...@@ -62,11 +64,8 @@ class WhitespaceFilter(Filter): ...@@ -62,11 +64,8 @@ class WhitespaceFilter(Filter):
def filter_set(inst): def filter_set(inst):
filtered_resp = [] filtered_resp = []
for resp in inst: for resp in inst:
if resp.startswith(" "): resp = resp.lstrip()
resp = resp[1:]
filtered_resp.append(resp) filtered_resp.append(resp)
return filtered_resp return filtered_resp
filtered_resps = [filter_set(resp) for resp in resps] filtered_resps = [filter_set(resp) for resp in resps]
...@@ -165,7 +164,7 @@ class MultiChoiceRegexFilter(RegexFilter): ...@@ -165,7 +164,7 @@ class MultiChoiceRegexFilter(RegexFilter):
fallback_regex = re.compile("|".join(fallback_regexes)) fallback_regex = re.compile("|".join(fallback_regexes))
without_paren_fallback_regex = "|".join(without_paren_fallback_regexes) without_paren_fallback_regex = "|".join(without_paren_fallback_regexes)
without_paren_fallback_regex = re.compile( without_paren_fallback_regex = re.compile(
f":[\s]*({without_paren_fallback_regex})" rf":[\s]*({without_paren_fallback_regex})"
) )
filtered = [] filtered = []
......
import json import json
import os
import re import re
import time import time
from collections import defaultdict from collections import defaultdict
...@@ -14,6 +15,7 @@ from huggingface_hub import ( ...@@ -14,6 +15,7 @@ from huggingface_hub import (
HfApi, HfApi,
hf_hub_url, hf_hub_url,
) )
from huggingface_hub.utils import build_hf_headers, get_session, hf_raise_for_status
from lm_eval.utils import ( from lm_eval.utils import (
eval_logger, eval_logger,
...@@ -48,6 +50,7 @@ class GeneralConfigTracker: ...@@ -48,6 +50,7 @@ class GeneralConfigTracker:
model_name_sanitized: str = None model_name_sanitized: str = None
system_instruction: str = None system_instruction: str = None
system_instruction_sha: str = None system_instruction_sha: str = None
fewshot_as_multiturn: bool = None
chat_template: str = None chat_template: str = None
chat_template_sha: str = None chat_template_sha: str = None
start_time: float = None start_time: float = None
...@@ -80,6 +83,7 @@ class GeneralConfigTracker: ...@@ -80,6 +83,7 @@ class GeneralConfigTracker:
model_args: str, model_args: str,
system_instruction: str, system_instruction: str,
chat_template: str, chat_template: str,
fewshot_as_multiturn: bool,
) -> None: ) -> None:
"""Logs model parameters and job ID.""" """Logs model parameters and job ID."""
self.model_source = model_source self.model_source = model_source
...@@ -91,6 +95,7 @@ class GeneralConfigTracker: ...@@ -91,6 +95,7 @@ class GeneralConfigTracker:
) )
self.chat_template = chat_template self.chat_template = chat_template
self.chat_template_sha = hash_string(chat_template) if chat_template else None self.chat_template_sha = hash_string(chat_template) if chat_template else None
self.fewshot_as_multiturn = fewshot_as_multiturn
def log_end_time(self) -> None: def log_end_time(self) -> None:
"""Logs the end time of the evaluation and calculates the total evaluation time.""" """Logs the end time of the evaluation and calculates the total evaluation time."""
...@@ -109,12 +114,15 @@ class EvaluationTracker: ...@@ -109,12 +114,15 @@ class EvaluationTracker:
output_path: str = None, output_path: str = None,
hub_results_org: str = "", hub_results_org: str = "",
hub_repo_name: str = "", hub_repo_name: str = "",
details_repo_name: str = "",
results_repo_name: str = "",
push_results_to_hub: bool = False, push_results_to_hub: bool = False,
push_samples_to_hub: bool = False, push_samples_to_hub: bool = False,
public_repo: bool = False, public_repo: bool = False,
token: str = "", token: str = "",
leaderboard_url: str = "", leaderboard_url: str = "",
point_of_contact: str = "", point_of_contact: str = "",
gated: bool = False,
) -> None: ) -> None:
""" """
Creates all the necessary loggers for evaluation tracking. Creates all the necessary loggers for evaluation tracking.
...@@ -123,12 +131,15 @@ class EvaluationTracker: ...@@ -123,12 +131,15 @@ class EvaluationTracker:
output_path (str): Path to save the results. If not provided, the results won't be saved. output_path (str): Path to save the results. If not provided, the results won't be saved.
hub_results_org (str): The Hugging Face organization to push the results to. If not provided, the results will be pushed to the owner of the Hugging Face token. hub_results_org (str): The Hugging Face organization to push the results to. If not provided, the results will be pushed to the owner of the Hugging Face token.
hub_repo_name (str): The name of the Hugging Face repository to push the results to. If not provided, the results will be pushed to `lm-eval-results`. hub_repo_name (str): The name of the Hugging Face repository to push the results to. If not provided, the results will be pushed to `lm-eval-results`.
details_repo_name (str): The name of the Hugging Face repository to push the details to. If not provided, the results will be pushed to `lm-eval-results`.
result_repo_name (str): The name of the Hugging Face repository to push the results to. If not provided, the results will not be pushed and will be found in the details_hub_repo.
push_results_to_hub (bool): Whether to push the results to the Hugging Face hub. push_results_to_hub (bool): Whether to push the results to the Hugging Face hub.
push_samples_to_hub (bool): Whether to push the samples to the Hugging Face hub. push_samples_to_hub (bool): Whether to push the samples to the Hugging Face hub.
public_repo (bool): Whether to push the results to a public or private repository. public_repo (bool): Whether to push the results to a public or private repository.
token (str): Token to use when pushing to the Hugging Face hub. This token should have write access to `hub_results_org`. token (str): Token to use when pushing to the Hugging Face hub. This token should have write access to `hub_results_org`.
leaderboard_url (str): URL to the leaderboard on the Hugging Face hub on the dataset card. leaderboard_url (str): URL to the leaderboard on the Hugging Face hub on the dataset card.
point_of_contact (str): Contact information on the Hugging Face hub dataset card. point_of_contact (str): Contact information on the Hugging Face hub dataset card.
gated (bool): Whether to gate the repository.
""" """
self.general_config_tracker = GeneralConfigTracker() self.general_config_tracker = GeneralConfigTracker()
...@@ -139,6 +150,7 @@ class EvaluationTracker: ...@@ -139,6 +150,7 @@ class EvaluationTracker:
self.leaderboard_url = leaderboard_url self.leaderboard_url = leaderboard_url
self.point_of_contact = point_of_contact self.point_of_contact = point_of_contact
self.api = HfApi(token=token) if token else None self.api = HfApi(token=token) if token else None
self.gated_repo = gated
if not self.api and (push_results_to_hub or push_samples_to_hub): if not self.api and (push_results_to_hub or push_samples_to_hub):
raise ValueError( raise ValueError(
...@@ -156,9 +168,24 @@ class EvaluationTracker: ...@@ -156,9 +168,24 @@ class EvaluationTracker:
f"hub_results_org was not specified. Results will be pushed to '{hub_results_org}'." f"hub_results_org was not specified. Results will be pushed to '{hub_results_org}'."
) )
hub_repo_name = hub_repo_name if hub_repo_name else "lm-eval-results" if hub_repo_name == "":
self.hub_results_repo = f"{hub_results_org}/{hub_repo_name}" details_repo_name = (
self.hub_results_repo_private = f"{hub_results_org}/{hub_repo_name}-private" details_repo_name if details_repo_name != "" else "lm-eval-results"
)
results_repo_name = (
results_repo_name if results_repo_name != "" else details_repo_name
)
else:
details_repo_name = hub_repo_name
results_repo_name = hub_repo_name
eval_logger.warning(
"hub_repo_name was specified. Both details and results will be pushed to the same repository. Using hub_repo_name is no longer recommended, details_repo_name and results_repo_name should be used instead."
)
self.details_repo = f"{hub_results_org}/{details_repo_name}"
self.details_repo_private = f"{hub_results_org}/{details_repo_name}-private"
self.results_repo = f"{hub_results_org}/{results_repo_name}"
self.results_repo_private = f"{hub_results_org}/{results_repo_name}-private"
def save_results_aggregated( def save_results_aggregated(
self, self,
...@@ -208,9 +235,9 @@ class EvaluationTracker: ...@@ -208,9 +235,9 @@ class EvaluationTracker:
if self.api and self.push_results_to_hub: if self.api and self.push_results_to_hub:
repo_id = ( repo_id = (
self.hub_results_repo self.results_repo
if self.public_repo if self.public_repo
else self.hub_results_repo_private else self.results_repo_private
) )
self.api.create_repo( self.api.create_repo(
repo_id=repo_id, repo_id=repo_id,
...@@ -218,10 +245,15 @@ class EvaluationTracker: ...@@ -218,10 +245,15 @@ class EvaluationTracker:
private=not self.public_repo, private=not self.public_repo,
exist_ok=True, exist_ok=True,
) )
self.api.upload_folder( self.api.upload_file(
repo_id=repo_id, repo_id=repo_id,
folder_path=str(path), path_or_fileobj=str(
path_in_repo=self.general_config_tracker.model_name_sanitized, path.joinpath(f"results_{self.date_id}.json")
),
path_in_repo=os.path.join(
self.general_config_tracker.model_name,
f"results_{self.date_id}.json",
),
repo_type="dataset", repo_type="dataset",
commit_message=f"Adding aggregated results for {self.general_config_tracker.model_name}", commit_message=f"Adding aggregated results for {self.general_config_tracker.model_name}",
) )
...@@ -275,6 +307,7 @@ class EvaluationTracker: ...@@ -275,6 +307,7 @@ class EvaluationTracker:
sample["resps"] = sanitize_list(sample["resps"]) sample["resps"] = sanitize_list(sample["resps"])
sample["filtered_resps"] = sanitize_list(sample["filtered_resps"]) sample["filtered_resps"] = sanitize_list(sample["filtered_resps"])
sample["arguments"] = arguments sample["arguments"] = arguments
sample["target"] = str(sample["target"])
sample_dump = ( sample_dump = (
json.dumps( json.dumps(
...@@ -285,14 +318,14 @@ class EvaluationTracker: ...@@ -285,14 +318,14 @@ class EvaluationTracker:
+ "\n" + "\n"
) )
with open(file_results_samples, "a") as f: with open(file_results_samples, "a", encoding="utf-8") as f:
f.write(sample_dump) f.write(sample_dump)
if self.api and self.push_samples_to_hub: if self.api and self.push_samples_to_hub:
repo_id = ( repo_id = (
self.hub_results_repo self.details_repo
if self.public_repo if self.public_repo
else self.hub_results_repo_private else self.details_repo_private
) )
self.api.create_repo( self.api.create_repo(
repo_id=repo_id, repo_id=repo_id,
...@@ -300,6 +333,18 @@ class EvaluationTracker: ...@@ -300,6 +333,18 @@ class EvaluationTracker:
private=not self.public_repo, private=not self.public_repo,
exist_ok=True, exist_ok=True,
) )
try:
if self.gated_repo:
headers = build_hf_headers()
r = get_session().put(
url=f"https://huggingface.co/api/datasets/{repo_id}/settings",
headers=headers,
json={"gated": "auto"},
)
hf_raise_for_status(r)
except Exception as e:
eval_logger.warning("Could not gate the repository")
eval_logger.info(repr(e))
self.api.upload_folder( self.api.upload_folder(
repo_id=repo_id, repo_id=repo_id,
folder_path=str(path), folder_path=str(path),
...@@ -324,9 +369,7 @@ class EvaluationTracker: ...@@ -324,9 +369,7 @@ class EvaluationTracker:
""" """
eval_logger.info("Recreating metadata card") eval_logger.info("Recreating metadata card")
repo_id = ( repo_id = self.details_repo if self.public_repo else self.details_repo_private
self.hub_results_repo if self.public_repo else self.hub_results_repo_private
)
files_in_repo = self.api.list_repo_files(repo_id=repo_id, repo_type="dataset") files_in_repo = self.api.list_repo_files(repo_id=repo_id, repo_type="dataset")
results_files = get_results_filenames(files_in_repo) results_files = get_results_filenames(files_in_repo)
...@@ -357,7 +400,10 @@ class EvaluationTracker: ...@@ -357,7 +400,10 @@ class EvaluationTracker:
results_datetime, results_datetime,
) )
latest_task_results_datetime[samples_key] = latest_datetime latest_task_results_datetime[samples_key] = latest_datetime
latest_task_results_datetime[results_key] = latest_datetime latest_task_results_datetime[results_key] = max(
latest_task_results_datetime[results_key],
latest_datetime,
)
# Create metadata card # Create metadata card
card_metadata = MetadataConfigs() card_metadata = MetadataConfigs()
...@@ -374,6 +420,8 @@ class EvaluationTracker: ...@@ -374,6 +420,8 @@ class EvaluationTracker:
sanitized_last_eval_date_results = re.sub( sanitized_last_eval_date_results = re.sub(
r"[^\w\.]", "_", latest_task_results_datetime[config_name] r"[^\w\.]", "_", latest_task_results_datetime[config_name]
) )
if eval_date_sanitized == sanitized_last_eval_date_results:
# Ensure that all results files are listed in the metadata card # Ensure that all results files are listed in the metadata card
current_results = card_metadata.get(config_name, {"data_files": []}) current_results = card_metadata.get(config_name, {"data_files": []})
current_results["data_files"].append( current_results["data_files"].append(
...@@ -381,7 +429,6 @@ class EvaluationTracker: ...@@ -381,7 +429,6 @@ class EvaluationTracker:
) )
card_metadata[config_name] = current_results card_metadata[config_name] = current_results
# If the results file is the newest, update the "latest" field in the metadata card # If the results file is the newest, update the "latest" field in the metadata card
if eval_date_sanitized == sanitized_last_eval_date_results:
card_metadata[config_name]["data_files"].append( card_metadata[config_name]["data_files"].append(
{"split": "latest", "path": [str(results_filename)]} {"split": "latest", "path": [str(results_filename)]}
) )
...@@ -400,6 +447,7 @@ class EvaluationTracker: ...@@ -400,6 +447,7 @@ class EvaluationTracker:
sanitized_last_eval_date_results = re.sub( sanitized_last_eval_date_results = re.sub(
r"[^\w\.]", "_", latest_task_results_datetime[config_name] r"[^\w\.]", "_", latest_task_results_datetime[config_name]
) )
if eval_date_sanitized == sanitized_last_eval_date_results:
# Ensure that all sample results files are listed in the metadata card # Ensure that all sample results files are listed in the metadata card
current_details_for_task = card_metadata.get( current_details_for_task = card_metadata.get(
config_name, {"data_files": []} config_name, {"data_files": []}
...@@ -409,56 +457,10 @@ class EvaluationTracker: ...@@ -409,56 +457,10 @@ class EvaluationTracker:
) )
card_metadata[config_name] = current_details_for_task card_metadata[config_name] = current_details_for_task
# If the samples results file is the newest, update the "latest" field in the metadata card # If the samples results file is the newest, update the "latest" field in the metadata card
if eval_date_sanitized == sanitized_last_eval_date_results:
card_metadata[config_name]["data_files"].append( card_metadata[config_name]["data_files"].append(
{"split": "latest", "path": [str(results_filename)]} {"split": "latest", "path": [str(results_filename)]}
) )
# Special case for MMLU with a single split covering it all
# We add another config with all MMLU splits results together for easy inspection
SPECIAL_TASKS = ["mmlu", "gpqa", "minerva_math"]
for special_task in SPECIAL_TASKS:
if special_task in config_name:
special_task = f"{model_name}__{special_task}"
former_entry = card_metadata.get(special_task, {"data_files": []})
former_split = [
(i, entry)
for i, entry in enumerate(former_entry["data_files"])
if entry.get("split", None) == eval_date_sanitized
]
if len(former_split) == 0:
former_entry["data_files"].append(
{
"split": eval_date_sanitized,
"path": [str(results_filename)],
}
)
else:
split_index, _ = former_split[0]
former_entry["data_files"][split_index]["path"].append(
str(results_filename)
)
if eval_date_sanitized == sanitized_last_eval_date_results:
latest_split = [
(i, entry)
for i, entry in enumerate(former_entry["data_files"])
if entry.get("split", None) == "latest"
]
if len(latest_split) == 0:
former_entry["data_files"].append(
{"split": "latest", "path": [str(results_filename)]}
)
else:
latest_index, _ = latest_split[0]
former_entry["data_files"][latest_index]["path"].append(
str(results_filename)
)
card_metadata[special_task] = former_entry
# Get latest results and extract info to update metadata card examples # Get latest results and extract info to update metadata card examples
latest_datetime = max(latest_task_results_datetime.values()) latest_datetime = max(latest_task_results_datetime.values())
latest_model_name = max( latest_model_name = max(
......
...@@ -110,3 +110,34 @@ def add_env_info(storage: Dict[str, Any]): ...@@ -110,3 +110,34 @@ def add_env_info(storage: Dict[str, Any]):
"upper_git_hash": upper_dir_commit, # in case this repo is submodule "upper_git_hash": upper_dir_commit, # in case this repo is submodule
} }
storage.update(added_info) storage.update(added_info)
def add_tokenizer_info(storage: Dict[str, Any], lm):
if getattr(lm, "tokenizer", False):
try:
tokenizer_info = {
"tokenizer_pad_token": [
lm.tokenizer.pad_token,
str(lm.tokenizer.pad_token_id),
],
"tokenizer_eos_token": [
lm.tokenizer.eos_token,
str(lm.tokenizer.eos_token_id),
],
"tokenizer_bos_token": [
lm.tokenizer.bos_token,
str(lm.tokenizer.bos_token_id),
],
"eot_token_id": getattr(lm, "eot_token_id", None),
"max_length": getattr(lm, "max_length", None),
}
storage.update(tokenizer_info)
except Exception as err:
logger.debug(
f"Logging detailed tokenizer info failed with {err}, skipping..."
)
# seems gguf and textsynth do not have tokenizer
else:
logger.debug(
"LM does not have a 'tokenizer' attribute, not logging tokenizer metadata to results."
)
...@@ -15,10 +15,9 @@ logger = logging.getLogger(__name__) ...@@ -15,10 +15,9 @@ logger = logging.getLogger(__name__)
def get_wandb_printer() -> Literal["Printer"]: def get_wandb_printer() -> Literal["Printer"]:
"""Returns a wandb printer instance for pretty stdout.""" """Returns a wandb printer instance for pretty stdout."""
from wandb.sdk.lib.printer import get_printer from wandb.sdk.lib.printer import new_printer
from wandb.sdk.wandb_settings import Settings
printer = get_printer(Settings()._jupyter) printer = new_printer()
return printer return printer
...@@ -49,6 +48,9 @@ class WandbLogger: ...@@ -49,6 +48,9 @@ class WandbLogger:
self.wandb_args: Dict[str, Any] = kwargs self.wandb_args: Dict[str, Any] = kwargs
# pop the step key from the args to save for all logging calls
self.step = self.wandb_args.pop("step", None)
# initialize a W&B run # initialize a W&B run
if wandb.run is None: if wandb.run is None:
self.run = wandb.init(**self.wandb_args) self.run = wandb.init(**self.wandb_args)
...@@ -153,11 +155,11 @@ class WandbLogger: ...@@ -153,11 +155,11 @@ class WandbLogger:
# log the complete eval result to W&B Table # log the complete eval result to W&B Table
table = make_table(["Tasks"] + columns, "results") table = make_table(["Tasks"] + columns, "results")
self.run.log({"evaluation/eval_results": table}) self.run.log({"evaluation/eval_results": table}, step=self.step)
if "groups" in self.results.keys(): if "groups" in self.results.keys():
table = make_table(["Groups"] + columns, "groups") table = make_table(["Groups"] + columns, "groups")
self.run.log({"evaluation/group_eval_results": table}) self.run.log({"evaluation/group_eval_results": table}, step=self.step)
def _log_results_as_artifact(self) -> None: def _log_results_as_artifact(self) -> None:
"""Log results as JSON artifact to W&B.""" """Log results as JSON artifact to W&B."""
...@@ -175,13 +177,13 @@ class WandbLogger: ...@@ -175,13 +177,13 @@ class WandbLogger:
"""Log evaluation results to W&B.""" """Log evaluation results to W&B."""
# Log configs to wandb # Log configs to wandb
configs = self._get_config() configs = self._get_config()
self.run.config.update(configs) self.run.config.update(configs, allow_val_change=self.step is not None)
wandb_summary, self.wandb_results = self._sanitize_results_dict() wandb_summary, self.wandb_results = self._sanitize_results_dict()
# update wandb.run.summary with items that were removed # update wandb.run.summary with items that were removed
self.run.summary.update(wandb_summary) self.run.summary.update(wandb_summary)
# Log the evaluation metrics to wandb # Log the evaluation metrics to wandb
self.run.log(self.wandb_results) self.run.log(self.wandb_results, step=self.step)
# Log the evaluation metrics as W&B Table # Log the evaluation metrics as W&B Table
self._log_results_as_table() self._log_results_as_table()
# Log the results dict as json to W&B Artifacts # Log the results dict as json to W&B Artifacts
...@@ -330,7 +332,7 @@ class WandbLogger: ...@@ -330,7 +332,7 @@ class WandbLogger:
# log the samples as a W&B Table # log the samples as a W&B Table
df = self._generate_dataset(eval_preds, self.task_configs.get(task_name)) df = self._generate_dataset(eval_preds, self.task_configs.get(task_name))
self.run.log({f"{task_name}_eval_results": df}) self.run.log({f"{task_name}_eval_results": df}, step=self.step)
# log the samples as a json file as W&B Artifact # log the samples as a json file as W&B Artifact
self._log_samples_as_artifact(eval_preds, task_name) self._log_samples_as_artifact(eval_preds, task_name)
...@@ -349,4 +351,4 @@ class WandbLogger: ...@@ -349,4 +351,4 @@ class WandbLogger:
# log the samples as a json file as W&B Artifact # log the samples as a json file as W&B Artifact
self._log_samples_as_artifact(eval_preds, task_name) self._log_samples_as_artifact(eval_preds, task_name)
self.run.log({f"{group}_eval_results": grouped_df}) self.run.log({f"{group}_eval_results": grouped_df}, step=self.step)
from . import ( from . import (
anthropic_llms, anthropic_llms,
api_models,
dummy, dummy,
gguf, gguf,
hf_vlms,
huggingface, huggingface,
ibm_watsonx_ai,
mamba_lm, mamba_lm,
nemo_lm, nemo_lm,
neuralmagic, neuralmagic,
neuron_optimum, neuron_optimum,
openai_completions, openai_completions,
optimum_ipex,
optimum_lm, optimum_lm,
textsynth, textsynth,
vllm_causallms, vllm_causallms,
vllm_vlms,
) )
......
from typing import Any, List, Tuple import os
from functools import cached_property
from typing import Any, Dict, List, Tuple, Union
from tqdm import tqdm from tqdm import tqdm
from lm_eval import utils from lm_eval import utils
from lm_eval.api.model import LM from lm_eval.api.model import LM
from lm_eval.api.registry import register_model from lm_eval.api.registry import register_model
from lm_eval.models.utils import retry_on_specific_exceptions from lm_eval.models.openai_completions import LocalCompletionsAPI
from lm_eval.models.utils import handle_stop_sequences, retry_on_specific_exceptions
eval_logger = utils.eval_logger eval_logger = utils.eval_logger
...@@ -42,8 +45,8 @@ def anthropic_completion( ...@@ -42,8 +45,8 @@ def anthropic_completion(
try: try:
import anthropic import anthropic
except ModuleNotFoundError: except ModuleNotFoundError as exception:
raise Exception( raise type(exception)(
"attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \ "attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \
please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install -e '.[anthropic]'`", please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install -e '.[anthropic]'`",
) )
...@@ -105,8 +108,8 @@ def anthropic_chat( ...@@ -105,8 +108,8 @@ def anthropic_chat(
try: try:
import anthropic import anthropic
except ModuleNotFoundError: except ModuleNotFoundError as exception:
raise Exception( raise type(exception)(
"attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \ "attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \
please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install -e '.[anthropic]'`", please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install -e '.[anthropic]'`",
) )
...@@ -138,7 +141,7 @@ please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install ...@@ -138,7 +141,7 @@ please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install
return messages() return messages()
@register_model("anthropic") @register_model("anthropic-completions")
class AnthropicLM(LM): class AnthropicLM(LM):
REQ_CHUNK_SIZE = 20 # TODO: not used REQ_CHUNK_SIZE = 20 # TODO: not used
...@@ -165,8 +168,8 @@ class AnthropicLM(LM): ...@@ -165,8 +168,8 @@ class AnthropicLM(LM):
try: try:
import anthropic import anthropic
except ModuleNotFoundError: except ModuleNotFoundError as exception:
raise Exception( raise type(exception)(
"attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \ "attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \
please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install -e '.[anthropic]'`", please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install -e '.[anthropic]'`",
) )
...@@ -214,8 +217,8 @@ please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install ...@@ -214,8 +217,8 @@ please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install
def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]: def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]:
try: try:
import anthropic import anthropic
except ModuleNotFoundError: except ModuleNotFoundError as exception:
raise Exception( raise type(exception)(
"attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \ "attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \
please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install -e '.[anthropic]'`", please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install -e '.[anthropic]'`",
) )
...@@ -271,90 +274,94 @@ please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install ...@@ -271,90 +274,94 @@ please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install
@register_model("anthropic-chat", "anthropic-chat-completions") @register_model("anthropic-chat", "anthropic-chat-completions")
class AnthropicChatLM(AnthropicLM): class AnthropicChat(LocalCompletionsAPI):
REQ_CHUNK_SIZE = 20 # TODO: not used
def __init__( def __init__(
self, self,
model: str, base_url="https://api.anthropic.com/v1/messages",
batch_size: int = 1, tokenizer_backend=None,
max_tokens: int = 256, **kwargs,
temperature: float = 0, # defaults to 1 ):
**kwargs, # top_p, top_k, etc. super().__init__(
) -> None: base_url=base_url, tokenizer_backend=tokenizer_backend, **kwargs
"""Anthropic API wrapper.
:param model: str
Anthropic model e.g. 'claude-3-opus-20240229', 'claude-3-sonnet-20240229'
:param max_tokens: int
Maximum number of tokens to sample from the model
:param temperature: float
Sampling temperature
:param kwargs: Any
Additional model_args to pass to the API client
"""
super().__init__()
try:
import anthropic
except ModuleNotFoundError:
raise Exception(
"attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \
please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install -e '.[anthropic]'`",
) )
eval_logger.warning(
self.model = model "Chat completions does not support batching. Defaulting to batch size 1."
# defaults to os.environ.get("ANTHROPIC_API_KEY") )
self.client = anthropic.Anthropic() self._batch_size = 1
self.temperature = temperature self.anthropic_version = "2023-06-01"
self.max_tokens = max_tokens eval_logger.warning(
self.tokenizer = self.client.get_tokenizer() f"Using Anthropic Version: {self.anthropic_version}. Confirm the current version here: https://docs.anthropic.com/en/api/versioning"
self.kwargs = kwargs
@property
def max_gen_toks(self) -> int:
return self.max_tokens
def generate_until(self, requests) -> List[str]:
try:
import anthropic
except ModuleNotFoundError:
raise Exception(
"attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \
please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install -e '.[anthropic]'`",
) )
if not requests: @cached_property
return [] def api_key(self):
"""Override this property to return the API key for the API request."""
key = os.environ.get("ANTHROPIC_API_KEY", None)
if key is None:
raise ValueError(
"API key not found. Please set the ANTHROPIC_API_KEY environment variable."
)
return key
_requests: List[Tuple[str, dict]] = [req.args for req in requests] @cached_property
def header(self):
return {
"x-api-key": f"{self.api_key}",
"anthropic-version": self.anthropic_version,
}
res = [] def _create_payload(
for request in tqdm(_requests): self,
try: messages: List[Dict],
inp = request[0] generate=True,
request_args = request[1] gen_kwargs: dict = None,
# generation_kwargs eos="\n\nHuman:",
until = request_args.get("until") **kwargs,
max_tokens = request_args.get("max_gen_toks", self.max_length) ) -> dict:
temperature = request_args.get("temperature", self.temperature) system = (
response = anthropic_chat( messages[0].get("content") if messages[0].get("role") == "system" else None
client=self.client,
model=self.model,
prompt=inp,
max_tokens=max_tokens,
temperature=temperature, # TODO: implement non-greedy sampling for Anthropic
stop=until, # type: ignore
**self.kwargs,
) )
res.append(response) if system:
messages = messages[1:]
gen_kwargs.pop("do_sample", False)
max_tokens = gen_kwargs.pop("max_gen_toks", self._max_gen_toks)
temperature = gen_kwargs.pop("temperature", 0)
stop = handle_stop_sequences(gen_kwargs.pop("until", ["\n\nHuman:"]), eos=eos)
if not isinstance(stop, list):
stop = [stop]
out = {
"messages": messages,
"model": self.model,
"max_tokens": max_tokens,
"temperature": temperature,
"stop_sequences": stop,
**gen_kwargs,
}
if system:
out["system"] = system
return out
def parse_generations(
self, outputs: Union[Dict, List[Dict]], **kwargs
) -> List[str]:
res = []
if not isinstance(outputs, list):
outputs = [outputs]
for out in outputs:
for choices in out["content"]:
res.append(choices["text"])
return res
self.cache_hook.add_partial("generate_until", request, response) def tok_encode(
except anthropic.APIConnectionError as e: # type: ignore # noqa: F821 self,
eval_logger.critical(f"Server unreachable: {e.__cause__}") string: str,
break left_truncate_len=None,
except anthropic.APIStatusError as e: # type: ignore # noqa: F821 add_special_tokens=None,
eval_logger.critical(f"API error {e.status_code}: {e.message}") **kwargs,
break ) -> List[str]:
return [string]
return res def loglikelihood(self, requests, **kwargs):
raise NotImplementedError(
"Anthropic Chat Completions API does not support the return of loglikelihood"
)
This diff is collapsed.
...@@ -26,9 +26,9 @@ class DummyLM(LM): ...@@ -26,9 +26,9 @@ class DummyLM(LM):
def generate_until(self, requests, disable_tqdm: bool = False): def generate_until(self, requests, disable_tqdm: bool = False):
res = [] res = []
for ctx, _ in tqdm(requests, disable=disable_tqdm): for request in tqdm(requests, disable=disable_tqdm):
res.append("lol") res.append("lol")
assert ctx.strip() != "" assert request.arguments[0].strip() != ""
return res return res
......
...@@ -68,7 +68,9 @@ class GGUFLM(LM): ...@@ -68,7 +68,9 @@ class GGUFLM(LM):
logger.error(f"RequestException: {e}") logger.error(f"RequestException: {e}")
time.sleep(delay) # wait before retrying time.sleep(delay) # wait before retrying
else: else:
raise Exception(f"Failed to get a valid response after {retries} retries.") raise RuntimeError(
f"Failed to get a valid response after {retries} retries."
)
def loglikelihood(self, requests, disable_tqdm: bool = False): def loglikelihood(self, requests, disable_tqdm: bool = False):
if not requests: if not requests:
......
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment