Commit 173b2bc3 authored by Baber's avatar Baber
Browse files

Merge branch 'main' into humaneval

# Conflicts:
#	lm_eval/api/task.py
parents 74344829 bb098f13
...@@ -3,7 +3,7 @@ import hashlib ...@@ -3,7 +3,7 @@ import hashlib
import json import json
import logging import logging
import os import os
from typing import Dict, List, Optional, Tuple, Type, TypeVar from typing import Dict, List, Optional, Tuple, Type, TypeVar, Union
import transformers import transformers
from sqlitedict import SqliteDict from sqlitedict import SqliteDict
...@@ -55,7 +55,7 @@ class LM(abc.ABC): ...@@ -55,7 +55,7 @@ class LM(abc.ABC):
pass pass
@abc.abstractmethod @abc.abstractmethod
def loglikelihood_rolling(self, requests) -> List[Tuple[float]]: def loglikelihood_rolling(self, requests) -> List[float]:
"""Compute full log-likelihood of a string, with no truncation, for perplexity computation """Compute full log-likelihood of a string, with no truncation, for perplexity computation
- We will use the full max context length of the model. - We will use the full max context length of the model.
- For inputs that exceed the max context length, we divide the tokenized string into chunks of up to - For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
...@@ -101,14 +101,13 @@ class LM(abc.ABC): ...@@ -101,14 +101,13 @@ class LM(abc.ABC):
"""Generate greedily until a stopping sequence """Generate greedily until a stopping sequence
:param requests: list[Instance] :param requests: list[Instance]
A list of Instance objects with property `args` which returns a tuple (context, until). A list of Instance objects with property `args` which returns a tuple (context, gen_kwargs).
context: str context: str
Context string Context string
until: [str] gen_kwargs: dict
The string sequences to generate until. These string sequences A dictionary of keyword arguments to pass to the generation function e.g. top_k, until, etc.
may each span across multiple tokens, or may be part of one token.
:return: list[str] :return: list[str]
A list of strings continuation A list of model generated continuations.
continuation: str continuation: str
The generated continuation. The generated continuation.
""" """
...@@ -193,15 +192,13 @@ class LM(abc.ABC): ...@@ -193,15 +192,13 @@ class LM(abc.ABC):
"To use this model with chat templates, please implement the 'tokenizer_name' property." "To use this model with chat templates, please implement the 'tokenizer_name' property."
) )
@property def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
def chat_template(self) -> str: """Returns the chat template structure for user/assistant messages if a template is provided.
"""Must be defined for LM subclasses that implement Chat Templating. This method is intended to be overridden in a subclass to define a specific chat template format.
Should return the structure of the chat template applied to user/assistant messages. For models that do not support chat templates, this method returns None by default.
This is used only to save in the experiment results for reproducibility.
""" """
raise NotImplementedError(
"To use this model with chat templates, please implement the 'chat_template' property." return ""
)
def set_cache_hook(self, cache_hook) -> None: def set_cache_hook(self, cache_hook) -> None:
self.cache_hook = cache_hook self.cache_hook = cache_hook
...@@ -246,9 +243,10 @@ class CachingLM: ...@@ -246,9 +243,10 @@ class CachingLM:
# add hook to lm # add hook to lm
lm.set_cache_hook(self.get_cache_hook()) lm.set_cache_hook(self.get_cache_hook())
def __getattr__(self, attr): def __getattr__(self, attr: str):
lm_attr = getattr(self.lm, attr) lm_attr = getattr(self.lm, attr)
if not callable(lm_attr): if attr not in ["loglikelihood", "loglikelihood_rolling", "generate_until"]:
eval_logger.debug(f"Passing through attribute '{attr}' to underlying LM")
return lm_attr return lm_attr
def fn(requests): def fn(requests):
...@@ -283,8 +281,11 @@ class CachingLM: ...@@ -283,8 +281,11 @@ class CachingLM:
eval_logger.info( eval_logger.info(
f"Cached requests: {len(requests) - len(remaining_reqs)}, Requests remaining: {len(remaining_reqs)}" f"Cached requests: {len(requests) - len(remaining_reqs)}, Requests remaining: {len(remaining_reqs)}"
) )
# actually run the LM on the requests that do not have cached results if remaining_reqs:
rem_res = getattr(self.lm, attr)(remaining_reqs) # actually run the LM on the requests that do not have cached results
rem_res = getattr(self.lm, attr)(remaining_reqs)
else:
rem_res = []
# stick the new ones back into the list and also cache any of the new ones # stick the new ones back into the list and also cache any of the new ones
resptr = 0 resptr = 0
...@@ -313,6 +314,8 @@ class TemplateLM(LM): ...@@ -313,6 +314,8 @@ class TemplateLM(LM):
and boilerplate often included in other LM subclasses. and boilerplate often included in other LM subclasses.
""" """
tokenizer = None
@property @property
@abc.abstractmethod @abc.abstractmethod
def eot_token_id(self): def eot_token_id(self):
...@@ -324,14 +327,19 @@ class TemplateLM(LM): ...@@ -324,14 +327,19 @@ class TemplateLM(LM):
return self.eot_token_id return self.eot_token_id
@abc.abstractmethod @abc.abstractmethod
def tok_encode(self, string: str, **kwargs): def tok_encode(self, string: str, **kwargs) -> List[int]:
"""
Tokenize a string using the model's tokenizer and return a list of token IDs.
"""
pass pass
@abc.abstractmethod @abc.abstractmethod
def _loglikelihood_tokens(self, requests, **kwargs): def _loglikelihood_tokens(self, requests, **kwargs) -> List[Tuple[float, bool]]:
pass pass
def _encode_pair(self, context, continuation): def _encode_pair(
self, context: str, continuation: str
) -> Tuple[List[int], List[int]]:
n_spaces = len(context) - len(context.rstrip()) n_spaces = len(context) - len(context.rstrip())
if n_spaces > 0: if n_spaces > 0:
continuation = context[-n_spaces:] + continuation continuation = context[-n_spaces:] + continuation
...@@ -372,9 +380,110 @@ class TemplateLM(LM): ...@@ -372,9 +380,110 @@ class TemplateLM(LM):
@abc.abstractmethod @abc.abstractmethod
def loglikelihood_rolling( def loglikelihood_rolling(
self, requests, disable_tqdm: bool = False self, requests, disable_tqdm: bool = False
) -> List[Tuple[float, bool]]: ) -> List[float]:
pass pass
@abc.abstractmethod @abc.abstractmethod
def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]: def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]:
pass pass
def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
"""
Set and get the appropriate chat template for the model.
This method sets the tokenizer's chat_template and returns the template string for reproducibility.
The template selection logic is adapted from the Transformers library's `apply_chat_template`
method in the Tokenizer class. The original implementation can be found at:
https://github.com/huggingface/transformers/blob/fc35907f95459d7a6c5281dfadd680b6f7b620e3/src/transformers/tokenization_utils_base.py#L1687
This method ensures that the right template is chosen based on the following:
0. If the model has no 'tokenizer' attribute: assumes that there is only a single possible chat template, handled on the model provider side internally. Returns the empty string.
1. If the model's tokenizer has multiple templates:
a. Use the specified template if it exists in the dictionary.
b. Use the default template from the list if no specific template is provided.
c. Raise an error if no default template exists and no specific template is provided.
2. If the model's tokenizer has a single template or no template:
a. Use the tokenizer's chat template if available.
b. Fall back to the default chat template if no tokenizer chat template exists.
Args:
chat_template (Union[bool, str]): Specifies the chat template to use.
- If False or None, no template is applied.
- If True, the default or only available template is used.
- If a string, the template with the matching name is used.
Returns:
Optional[str]: The selected chat template, or None if no template is applied.
"""
if self.tokenizer is None:
return ""
if chat_template is False or chat_template is None:
eval_logger.warning(
"model.chat_template was called with the chat_template set to False or None. "
"Therefore no chat template will be applied. Make sure this is an intended behavior."
)
return None
# Convert boolean chat_template to None to ensure compatibility with the adapted logic
if isinstance(chat_template, bool):
chat_template = None
using_default_template = False
# First, handle the cases when the model has a dict of multiple templates
try:
template = (
self.tokenizer.chat_template or self.tokenizer.default_chat_template
)
except AttributeError:
return None
if isinstance(template, dict):
using_default_dict = self.tokenizer.chat_template is None
if chat_template is not None:
if chat_template in template:
selected_template = template[chat_template]
if using_default_dict:
using_default_template = True
else:
raise ValueError(
f"The specified chat template '{chat_template}' is not available. "
f"Available template names are {sorted(template.keys())}."
)
else:
# If user didn't pass a chat template, use the default template from the dict
if "default" in template:
selected_template = template["default"]
using_default_template = True
else:
raise ValueError(
"This model has multiple chat templates with no default specified! Please either pass a chat "
"template or the name of the template you wish to use to the `chat_template` argument. Available "
f"template names are {sorted(template.keys())}."
)
# Cases when the model has a single template or no template
else:
# priority: `chat_template` argument > `tokenizer.chat_template` > `tokenizer.default_chat_template
if isinstance(chat_template, str):
eval_logger.warning(
"Chat template name provided, but the tokenizer's chat template is not a dictionary. "
"Using the tokenizer's chat template or the default template instead."
)
if self.tokenizer.chat_template is not None:
selected_template = self.tokenizer.chat_template
else:
selected_template = self.tokenizer.default_chat_template
using_default_template = True
if using_default_template:
eval_logger.warning(
"No chat template is set for this tokenizer, falling back to a default class-level template. This is "
"very error-prone, because models are often trained with templates different from the class default! "
"Default chat templates are a legacy feature and will be removed in Transformers v4.43, at which "
"point any code depending on them will stop working. We recommend setting a valid chat template before "
"then to ensure that this model continues working without issues."
)
return selected_template
import logging import logging
from typing import Callable, Dict from typing import Callable, Dict, Union
import evaluate as hf_evaluate import evaluate as hf_evaluate
...@@ -185,8 +185,12 @@ def register_filter(name): ...@@ -185,8 +185,12 @@ def register_filter(name):
return decorate return decorate
def get_filter(filter_name: str) -> type: def get_filter(filter_name: Union[str, Callable]) -> Callable:
try: try:
return FILTER_REGISTRY[filter_name] return FILTER_REGISTRY[filter_name]
except KeyError: except KeyError as e:
eval_logger.warning(f"filter `{filter_name}` is not registered!") if callable(filter_name):
return filter_name
else:
eval_logger.warning(f"filter `{filter_name}` is not registered!")
raise e
from functools import partial
import datasets import datasets
...@@ -15,9 +17,38 @@ class ContextSampler: ...@@ -15,9 +17,38 @@ class ContextSampler:
self.target_delimiter = self.config.target_delimiter self.target_delimiter = self.config.target_delimiter
self.fewshot_delimiter = self.config.fewshot_delimiter self.fewshot_delimiter = self.config.fewshot_delimiter
self.doc_to_text = self.task.doc_to_text if (
self.doc_to_target = self.task.doc_to_target self.config.fewshot_config is not None
self.doc_to_choice = self.task.doc_to_choice and self.config.fewshot_config.get("doc_to_text", None) is not None
):
self.doc_to_text = partial(
self.task.doc_to_text,
doc_to_text=self.config.fewshot_config.get("doc_to_text", None),
)
else:
self.doc_to_text = self.task.doc_to_text
if (
self.config.fewshot_config is not None
and self.config.fewshot_config.get("doc_to_target", None) is not None
):
self.doc_to_target = partial(
self.task.doc_to_target,
doc_to_target=self.config.fewshot_config.get("doc_to_target", None),
)
else:
self.doc_to_target = self.task.doc_to_target
if (
self.config.fewshot_config is not None
and self.config.fewshot_config.get("doc_to_choice", None) is not None
):
self.doc_to_choice = partial(
self.task.doc_to_choice,
doc_to_choice=self.config.fewshot_config.get("doc_to_choice", None),
)
else:
self.doc_to_choice = self.task.doc_to_choice
self.docs = docs # HF dataset split, provided by task._fewshot_docs() self.docs = docs # HF dataset split, provided by task._fewshot_docs()
if fewshot_indices: # subset few-shot docs from if fewshot_indices: # subset few-shot docs from
...@@ -51,15 +82,17 @@ class ContextSampler: ...@@ -51,15 +82,17 @@ class ContextSampler:
if self.config.doc_to_choice is None or isinstance(doc_content, str) if self.config.doc_to_choice is None or isinstance(doc_content, str)
else self.doc_to_choice(doc)[doc_content] else self.doc_to_choice(doc)[doc_content]
) )
labeled_examples += self.target_delimiter
labeled_examples += ( if doc_target != "":
str(doc_target[0]) labeled_examples += self.target_delimiter
if isinstance(doc_target, list) labeled_examples += (
else doc_target str(doc_target[0])
if self.config.doc_to_choice is None or isinstance(doc_target, str) if isinstance(doc_target, list)
else str(self.doc_to_choice(doc)[doc_target]) else doc_target
) if self.config.doc_to_choice is None or isinstance(doc_target, str)
labeled_examples += self.fewshot_delimiter else str(self.doc_to_choice(doc)[doc_target])
)
labeled_examples += self.fewshot_delimiter
return labeled_examples return labeled_examples
......
...@@ -56,8 +56,7 @@ class TaskConfig(dict): ...@@ -56,8 +56,7 @@ class TaskConfig(dict):
# task naming/registry # task naming/registry
task: Optional[str] = None task: Optional[str] = None
task_alias: Optional[str] = None task_alias: Optional[str] = None
group: Optional[Union[str, list]] = None tag: Optional[Union[str, list]] = None
group_alias: Optional[Union[str, list]] = None
# HF dataset options. # HF dataset options.
# which dataset to use, # which dataset to use,
# and what splits for what purpose # and what splits for what purpose
...@@ -68,13 +67,14 @@ class TaskConfig(dict): ...@@ -68,13 +67,14 @@ class TaskConfig(dict):
validation_split: Optional[str] = None validation_split: Optional[str] = None
test_split: Optional[str] = None test_split: Optional[str] = None
fewshot_split: Optional[str] = ( fewshot_split: Optional[str] = (
None # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaling (?) None # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaluating (?)
) )
# formatting / prompting options. # formatting / prompting options.
# see docs/advanced_task_guide.md for more info # see docs/advanced_task_guide.md for more info
process_docs: Optional[Callable] = None process_docs: Optional[Callable] = None
doc_to_text: Optional[Union[Callable, str]] = None doc_to_text: Optional[Union[Callable, str]] = None
doc_to_target: Optional[Union[Callable, str]] = None doc_to_target: Optional[Union[Callable, str]] = None
doc_to_image: Union[Callable, str] = None
doc_to_choice: Optional[Union[Callable, str, dict, list]] = None doc_to_choice: Optional[Union[Callable, str, dict, list]] = None
process_results: Optional[Union[Callable, str]] = None process_results: Optional[Union[Callable, str]] = None
use_prompt: Optional[str] = None use_prompt: Optional[str] = None
...@@ -365,18 +365,23 @@ class Task(abc.ABC): ...@@ -365,18 +365,23 @@ class Task(abc.ABC):
def doc_to_target(self, doc): def doc_to_target(self, doc):
pass pass
# not an abstractmethod because not every language-only task has to implement this
def doc_to_image(self, doc):
raise NotImplementedError
def build_all_requests( def build_all_requests(
self, self,
*, *,
limit=None, limit: Union[int, None] = None,
rank=None, rank: int = 0,
world_size=None, world_size: int = 1,
cache_requests=False, cache_requests: bool = False,
rewrite_requests_cache=False, rewrite_requests_cache: bool = False,
system_instruction=None, system_instruction: Optional[str] = None,
apply_chat_template=False, apply_chat_template: bool = False,
fewshot_as_multiturn=False, fewshot_as_multiturn: bool = False,
lm=None, chat_template: Optional[Callable] = None,
tokenizer_name: str = "",
) -> None: ) -> None:
"""Build a set of Instances for a task, and store them in task.instances""" """Build a set of Instances for a task, and store them in task.instances"""
...@@ -391,9 +396,9 @@ class Task(abc.ABC): ...@@ -391,9 +396,9 @@ class Task(abc.ABC):
if system_instruction is not None if system_instruction is not None
else "" else ""
) )
cache_key += f"-tokenizer{lm.tokenizer_name}" if apply_chat_template else "" cache_key += f"-tokenizer{tokenizer_name}"
cached_instances = load_from_cache(file_name=cache_key) cached_instances = load_from_cache(file_name=cache_key, cache=cache_requests)
if cache_requests and cached_instances and not rewrite_requests_cache: if cache_requests and cached_instances and not rewrite_requests_cache:
cached_instances = cached_instances[:limit] cached_instances = cached_instances[:limit]
...@@ -436,7 +441,7 @@ class Task(abc.ABC): ...@@ -436,7 +441,7 @@ class Task(abc.ABC):
system_instruction, system_instruction,
apply_chat_template, apply_chat_template,
fewshot_as_multiturn, fewshot_as_multiturn,
lm, chat_template,
) )
# TODO: we should override self.config.repeats if doing greedy gen so users don't waste time+compute # TODO: we should override self.config.repeats if doing greedy gen so users don't waste time+compute
...@@ -444,6 +449,7 @@ class Task(abc.ABC): ...@@ -444,6 +449,7 @@ class Task(abc.ABC):
doc=doc, doc=doc,
ctx=fewshot_ctx, ctx=fewshot_ctx,
metadata=(self.config["task"], doc_id, self.config.repeats), metadata=(self.config["task"], doc_id, self.config.repeats),
apply_chat_template=apply_chat_template,
) )
if not isinstance(inst, list): if not isinstance(inst, list):
...@@ -722,6 +728,10 @@ class ConfigurableTask(Task): ...@@ -722,6 +728,10 @@ class ConfigurableTask(Task):
) )
self.OUTPUT_TYPE = self.config.output_type self.OUTPUT_TYPE = self.config.output_type
if self.config.doc_to_image is not None:
# mark the task as requiring multimodality.
self.MULTIMODAL = True
if self.config.dataset_path is not None: if self.config.dataset_path is not None:
self.DATASET_PATH = self.config.dataset_path self.DATASET_PATH = self.config.dataset_path
...@@ -979,7 +989,7 @@ class ConfigurableTask(Task): ...@@ -979,7 +989,7 @@ class ConfigurableTask(Task):
else: else:
if (self.config.num_fewshot is not None) and (self.config.num_fewshot > 0): if (self.config.num_fewshot is not None) and (self.config.num_fewshot > 0):
eval_logger.warning( eval_logger.warning(
f"Task '{self.config.task}': " f"[Task: {self.config.task}] "
"num_fewshot > 0 but fewshot_split is None. " "num_fewshot > 0 but fewshot_split is None. "
"using preconfigured rule." "using preconfigured rule."
) )
...@@ -1014,7 +1024,7 @@ class ConfigurableTask(Task): ...@@ -1014,7 +1024,7 @@ class ConfigurableTask(Task):
system_instruction: Optional[str] = None, system_instruction: Optional[str] = None,
apply_chat_template: bool = False, apply_chat_template: bool = False,
fewshot_as_multiturn: bool = False, fewshot_as_multiturn: bool = False,
lm=None, chat_template: Optional[Callable] = None,
) -> str: ) -> str:
"""Returns a fewshot context string that is made up of a prepended description """Returns a fewshot context string that is made up of a prepended description
(if provided), the `num_fewshot` number of examples, and an appended prompt example. (if provided), the `num_fewshot` number of examples, and an appended prompt example.
...@@ -1029,8 +1039,8 @@ class ConfigurableTask(Task): ...@@ -1029,8 +1039,8 @@ class ConfigurableTask(Task):
Whether to apply the chat template to the fewshot context. Whether to apply the chat template to the fewshot context.
:param fewshot_as_multiturn: bool :param fewshot_as_multiturn: bool
Whether to provide the fewshot examples as a multiturn conversation or a single user turn. Whether to provide the fewshot examples as a multiturn conversation or a single user turn.
:param lm: :param chat_template:
Language model with definition of the tokenizer/function to use for applying the chat template. callable (from lm.apply_chat_template) that takes in a list[Dict] chat transcript and renders it into a string.
:returns: str :returns: str
The fewshot context. The fewshot context.
""" """
...@@ -1077,7 +1087,7 @@ class ConfigurableTask(Task): ...@@ -1077,7 +1087,7 @@ class ConfigurableTask(Task):
example = self.doc_to_text(doc) example = self.doc_to_text(doc)
if apply_chat_template: if apply_chat_template:
if self.multiple_input: if self.multiple_input:
return lm.apply_chat_template(labeled_examples) return chat_template(labeled_examples)
if isinstance(example, str): if isinstance(example, str):
self.append_target_question( self.append_target_question(
labeled_examples, example, fewshot_as_multiturn labeled_examples, example, fewshot_as_multiturn
...@@ -1089,7 +1099,7 @@ class ConfigurableTask(Task): ...@@ -1089,7 +1099,7 @@ class ConfigurableTask(Task):
for ex in example: for ex in example:
chat = deepcopy(labeled_examples) chat = deepcopy(labeled_examples)
self.append_target_question(chat, ex, fewshot_as_multiturn) self.append_target_question(chat, ex, fewshot_as_multiturn)
labeled_examples_list.append(lm.apply_chat_template(chat)) labeled_examples_list.append(chat_template(chat))
return labeled_examples_list return labeled_examples_list
# if example is an integer, append the choice or convert to string # if example is an integer, append the choice or convert to string
elif isinstance(example, int): elif isinstance(example, int):
...@@ -1103,7 +1113,7 @@ class ConfigurableTask(Task): ...@@ -1103,7 +1113,7 @@ class ConfigurableTask(Task):
labeled_examples, str(example), fewshot_as_multiturn labeled_examples, str(example), fewshot_as_multiturn
) )
# return lm.apply_chat_template(labeled_examples) # return lm.apply_chat_template(labeled_examples)
return lm.apply_chat_template(labeled_examples) return chat_template(labeled_examples)
else: else:
if self.multiple_input: if self.multiple_input:
return labeled_examples return labeled_examples
...@@ -1158,9 +1168,11 @@ class ConfigurableTask(Task): ...@@ -1158,9 +1168,11 @@ class ConfigurableTask(Task):
""" """
return doc return doc
def doc_to_text(self, doc): def doc_to_text(self, doc, doc_to_text=None):
if self.prompt is not None: if self.prompt is not None:
doc_to_text = self.prompt doc_to_text = self.prompt
elif doc_to_text is not None:
doc_to_text = doc_to_text
else: else:
doc_to_text = self.config.doc_to_text doc_to_text = self.config.doc_to_text
...@@ -1192,9 +1204,11 @@ class ConfigurableTask(Task): ...@@ -1192,9 +1204,11 @@ class ConfigurableTask(Task):
print(type(doc_to_text)) print(type(doc_to_text))
raise TypeError raise TypeError
def doc_to_target(self, doc: Mapping) -> Union[int, str, list]: def doc_to_target(self, doc: Mapping, doc_to_target=None) -> Union[int, str, list]:
if self.prompt is not None: if self.prompt is not None:
doc_to_target = self.prompt doc_to_target = self.prompt
elif doc_to_target is not None:
doc_to_target = doc_to_target
else: else:
doc_to_target = self.config.doc_to_target doc_to_target = self.config.doc_to_target
...@@ -1236,9 +1250,11 @@ class ConfigurableTask(Task): ...@@ -1236,9 +1250,11 @@ class ConfigurableTask(Task):
else: else:
raise TypeError raise TypeError
def doc_to_choice(self, doc: Any) -> List[str]: def doc_to_choice(self, doc: Any, doc_to_choice=None) -> List[str]:
if self.prompt is not None: if self.prompt is not None:
doc_to_choice = self.prompt doc_to_choice = self.prompt
elif doc_to_choice is not None:
doc_to_choice = doc_to_choice
elif self.config.doc_to_choice is None: elif self.config.doc_to_choice is None:
eval_logger.error("doc_to_choice was called but not set in config") eval_logger.error("doc_to_choice was called but not set in config")
else: else:
...@@ -1260,9 +1276,36 @@ class ConfigurableTask(Task): ...@@ -1260,9 +1276,36 @@ class ConfigurableTask(Task):
else: else:
raise TypeError raise TypeError
def doc_to_image(self, doc: Any, doc_to_image=None) -> Union[int, str, list]:
if doc_to_image is not None:
doc_to_image = doc_to_image
elif self.config.doc_to_image is not None:
doc_to_image = self.config.doc_to_image
else:
return None
if isinstance(doc_to_image, list):
image_feature = [
self.doc_to_image(doc, feature) for feature in doc_to_image
]
return [feature for feature in image_feature if feature is not None]
elif isinstance(doc_to_image, str):
if doc_to_image in self.features:
return doc[doc_to_image]
else:
return ast.literal_eval(utils.apply_template(doc_to_image, doc))
elif callable(doc_to_image):
return doc_to_image(doc)
else:
return None
def construct_requests( def construct_requests(
self, doc: dict, ctx: str, **kwargs self, doc: dict, ctx: str, **kwargs
) -> Union[List[Instance], Instance]: ) -> Union[List[Instance], Instance]:
apply_chat_template = kwargs.pop("apply_chat_template", False)
aux_arguments = None
if self.OUTPUT_TYPE == "loglikelihood": if self.OUTPUT_TYPE == "loglikelihood":
arguments = (ctx, self.doc_to_target(doc)) arguments = (ctx, self.doc_to_target(doc))
elif self.OUTPUT_TYPE == "loglikelihood_rolling": elif self.OUTPUT_TYPE == "loglikelihood_rolling":
...@@ -1270,6 +1313,8 @@ class ConfigurableTask(Task): ...@@ -1270,6 +1313,8 @@ class ConfigurableTask(Task):
elif self.OUTPUT_TYPE == "multiple_choice": elif self.OUTPUT_TYPE == "multiple_choice":
choices = self.doc_to_choice(doc) choices = self.doc_to_choice(doc)
target_delimiter = self.config.target_delimiter target_delimiter = self.config.target_delimiter
if apply_chat_template:
target_delimiter = ""
if self.multiple_input: if self.multiple_input:
# If there are multiple inputs, choices are placed in the ctx # If there are multiple inputs, choices are placed in the ctx
cont = self.doc_to_target(doc) cont = self.doc_to_target(doc)
...@@ -1280,6 +1325,37 @@ class ConfigurableTask(Task): ...@@ -1280,6 +1325,37 @@ class ConfigurableTask(Task):
# Otherwise they are placed in the continuation # Otherwise they are placed in the continuation
arguments = [(ctx, f"{target_delimiter}{cont}") for cont in choices] arguments = [(ctx, f"{target_delimiter}{cont}") for cont in choices]
# TODO: we should raise a warning telling users this will at most ~2x runtime.
if "acc_mutual_info" in self._metric_fn_list.keys():
# if we are calculating multiple choice accuracy
# using mutual information instead of raw loglikelihood as metric, need unconditional lls.
# here mutual info refers to calculating
# log(P(choice|ctx) / P(choice)) = log(P(choice|ctx)) - log(P(choice))
# in other words normalizing by subtracting the unconditional logprob of each choice.
aux_arguments = [("", f"{choice}") for choice in choices]
arguments.extend(aux_arguments)
elif self.OUTPUT_TYPE == "generate_until":
arguments = (ctx, deepcopy(self.config.generation_kwargs))
multimodal_arg = {}
if (
self.config.doc_to_image
): # TODO: ensure that non-multimodal tasks aren't getting visual args
multimodal_arg = {
**multimodal_arg,
**{"visual": self.doc_to_image(doc)},
}
if bool(multimodal_arg):
if isinstance(arguments, list):
arguments = [arg + (multimodal_arg,) for arg in arguments]
else:
arguments = arguments + (multimodal_arg,)
if self.OUTPUT_TYPE == "multiple_choice":
request_list = [ request_list = [
Instance( Instance(
request_type="loglikelihood", request_type="loglikelihood",
...@@ -1290,33 +1366,15 @@ class ConfigurableTask(Task): ...@@ -1290,33 +1366,15 @@ class ConfigurableTask(Task):
) )
for i, arg in enumerate(arguments) for i, arg in enumerate(arguments)
] ]
# TODO: we should raise a warning telling users this will at most ~2x runtime.
if "acc_mutual_info" in self._metric_fn_list.keys():
# if we are calculating multiple choice accuracy
# using mutual information instead of raw loglikelihood as metric, need unconditional lls.
# here mutual info refers to calculating
# log(P(choice|ctx) / P(choice)) = log(P(choice|ctx)) - log(P(choice))
# in other words normalizing by subtracting the unconditional logprob of each choice.
request_list.extend(
[
Instance(
request_type="loglikelihood",
doc=doc,
arguments=("", "{}".format(choice)),
idx=i,
**kwargs,
)
for i, choice in enumerate(choices)
]
)
return request_list return request_list
elif self.OUTPUT_TYPE == "generate_until":
arguments = (ctx, deepcopy(self.config.generation_kwargs))
return Instance( return Instance(
request_type=self.OUTPUT_TYPE, doc=doc, arguments=arguments, idx=0, **kwargs request_type=self.OUTPUT_TYPE,
doc=doc,
arguments=arguments,
idx=0,
**kwargs,
) )
def process_results(self, doc, results): def process_results(self, doc, results):
...@@ -1445,7 +1503,10 @@ class ConfigurableTask(Task): ...@@ -1445,7 +1503,10 @@ class ConfigurableTask(Task):
# we expect multiple_targets to be a list. # we expect multiple_targets to be a list.
elif self.multiple_target: elif self.multiple_target:
gold = list(gold) gold = list(gold)
elif type(gold) != type(result) and not isinstance(result, List): elif (
type(gold) is not type(result)
and "bypass" not in self._metric_fn_list.keys()
):
# cast gold to the same type as result # cast gold to the same type as result
gold = type(result)(gold) gold = type(result)(gold)
...@@ -1519,10 +1580,13 @@ class ConfigurableTask(Task): ...@@ -1519,10 +1580,13 @@ class ConfigurableTask(Task):
def get_config(self, key: str) -> Any: def get_config(self, key: str) -> Any:
return getattr(self._config, key, None) return getattr(self._config, key, None)
@property
def task_name(self) -> Any:
return getattr(self.config, "task", None)
def __repr__(self): def __repr__(self):
return ( return (
f"ConfigurableTask(task_name={getattr(self.config, 'task', None)}," f"ConfigurableTask(task_name={getattr(self.config, 'task', None)},"
f"group_name={getattr(self.config, 'group', None)},"
f"output_type={self.OUTPUT_TYPE}," f"output_type={self.OUTPUT_TYPE},"
f"num_fewshot={getattr(self.config, 'num_fewshot', None)}," f"num_fewshot={getattr(self.config, 'num_fewshot', None)},"
f"num_samples={len(self.eval_docs)})" f"num_samples={len(self.eval_docs)})"
......
...@@ -21,7 +21,9 @@ HASH_PREFIX = hashlib.sha256(HASH_INPUT.encode("utf-8")).hexdigest() ...@@ -21,7 +21,9 @@ HASH_PREFIX = hashlib.sha256(HASH_INPUT.encode("utf-8")).hexdigest()
FILE_SUFFIX = f".{HASH_PREFIX}.pickle" FILE_SUFFIX = f".{HASH_PREFIX}.pickle"
def load_from_cache(file_name): def load_from_cache(file_name: str, cache: bool = False):
if not cache:
return
try: try:
path = f"{PATH}/{file_name}{FILE_SUFFIX}" path = f"{PATH}/{file_name}{FILE_SUFFIX}"
......
...@@ -110,12 +110,15 @@ class TextReader: ...@@ -110,12 +110,15 @@ class TextReader:
def read_tqdm(self, update_frequency: int = 10000): def read_tqdm(self, update_frequency: int = 10000):
current_file_position = 0 current_file_position = 0
line_counter = 0 line_counter = 0
with open(self.file_path, "r", encoding="utf-8") as fh, tqdm.tqdm( with (
total=os.path.getsize(self.file_path), open(self.file_path, "r", encoding="utf-8") as fh,
dynamic_ncols=True, tqdm.tqdm(
unit="byte", total=os.path.getsize(self.file_path),
unit_scale=1, dynamic_ncols=True,
) as progress: unit="byte",
unit_scale=1,
) as progress,
):
with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj: with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
for line in iter(mmap_obj.readline, b""): for line in iter(mmap_obj.readline, b""):
line = line.decode("utf-8") line = line.decode("utf-8")
......
...@@ -11,19 +11,25 @@ import torch ...@@ -11,19 +11,25 @@ import torch
import lm_eval.api.metrics import lm_eval.api.metrics
import lm_eval.api.registry import lm_eval.api.registry
import lm_eval.api.task
import lm_eval.models import lm_eval.models
from lm_eval.caching.cache import delete_cache from lm_eval.caching.cache import delete_cache
from lm_eval.evaluator_utils import ( from lm_eval.evaluator_utils import (
consolidate_group_results,
consolidate_results, consolidate_results,
get_sample_size, get_sample_size,
get_subtask_list,
get_task_list, get_task_list,
prepare_print_tasks, prepare_print_tasks,
print_writeout, print_writeout,
run_task_tests, run_task_tests,
) )
from lm_eval.loggers import EvaluationTracker from lm_eval.loggers import EvaluationTracker
from lm_eval.loggers.utils import add_env_info, get_git_commit_hash from lm_eval.loggers.utils import add_env_info, add_tokenizer_info, get_git_commit_hash
from lm_eval.tasks import TaskManager, get_task_dict from lm_eval.tasks import (
TaskManager,
get_task_dict,
)
from lm_eval.utils import ( from lm_eval.utils import (
eval_logger, eval_logger,
handle_non_serializable, handle_non_serializable,
...@@ -35,7 +41,7 @@ from lm_eval.utils import ( ...@@ -35,7 +41,7 @@ from lm_eval.utils import (
if TYPE_CHECKING: if TYPE_CHECKING:
from lm_eval.api.model import LM from lm_eval.api.model import LM
from lm_eval.tasks import Task from lm_eval.api.task import Task
@positional_deprecated @positional_deprecated
...@@ -44,7 +50,7 @@ def simple_evaluate( ...@@ -44,7 +50,7 @@ def simple_evaluate(
model_args: Optional[Union[str, dict]] = None, model_args: Optional[Union[str, dict]] = None,
tasks: Optional[List[Union[str, dict, object]]] = None, tasks: Optional[List[Union[str, dict, object]]] = None,
num_fewshot: Optional[int] = None, num_fewshot: Optional[int] = None,
batch_size: Optional[int] = None, batch_size: Optional[Union[int, str]] = None,
max_batch_size: Optional[int] = None, max_batch_size: Optional[int] = None,
device: Optional[str] = None, device: Optional[str] = None,
use_cache: Optional[str] = None, use_cache: Optional[str] = None,
...@@ -58,7 +64,7 @@ def simple_evaluate( ...@@ -58,7 +64,7 @@ def simple_evaluate(
log_samples: bool = True, log_samples: bool = True,
evaluation_tracker: Optional[EvaluationTracker] = None, evaluation_tracker: Optional[EvaluationTracker] = None,
system_instruction: Optional[str] = None, system_instruction: Optional[str] = None,
apply_chat_template: bool = False, apply_chat_template: Union[bool, str] = False,
fewshot_as_multiturn: bool = False, fewshot_as_multiturn: bool = False,
gen_kwargs: Optional[str] = None, gen_kwargs: Optional[str] = None,
task_manager: Optional[TaskManager] = None, task_manager: Optional[TaskManager] = None,
...@@ -106,8 +112,11 @@ def simple_evaluate( ...@@ -106,8 +112,11 @@ def simple_evaluate(
If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
:param system_instruction: str :param system_instruction: str
System instruction to be applied to the prompt System instruction to be applied to the prompt
:param apply_chat_template: bool :param apply_chat_template: Union[bool, str]
If True, apply chat template to the prompt Specifies whether to apply a chat template to the prompt.
- If set to True, the default chat template is applied.
- If set to a string, applies the specified chat template by name.
Defaults to False (no chat template applied).
:param fewshot_as_multiturn: bool :param fewshot_as_multiturn: bool
Whether to provide the fewshot examples as a multiturn conversation or a single user turn. Whether to provide the fewshot examples as a multiturn conversation or a single user turn.
:param gen_kwargs: str :param gen_kwargs: str
...@@ -148,6 +157,9 @@ def simple_evaluate( ...@@ -148,6 +157,9 @@ def simple_evaluate(
seed_message.append(f"Setting torch manual seed to {torch_random_seed}") seed_message.append(f"Setting torch manual seed to {torch_random_seed}")
torch.manual_seed(torch_random_seed) torch.manual_seed(torch_random_seed)
if fewshot_random_seed is not None:
seed_message.append(f"Setting fewshot manual seed to {fewshot_random_seed}")
if seed_message: if seed_message:
eval_logger.info(" | ".join(seed_message)) eval_logger.info(" | ".join(seed_message))
...@@ -199,7 +211,9 @@ def simple_evaluate( ...@@ -199,7 +211,9 @@ def simple_evaluate(
) )
else: else:
if not isinstance(model, lm_eval.api.model.LM): if not isinstance(model, lm_eval.api.model.LM):
raise TypeError raise TypeError(
f"The value of `model` passed to simple_evaluate() was of type {type(model)}, but is required to be a subclass of lm_eval.api.model.LM . This may be because you are passing an initialized Hugging Face PreTrainedModel without having wrapped it in `lm_eval.models.huggingface.HFLM(pretrained=my_model)` first."
)
eval_logger.info("Using pre-initialized model") eval_logger.info("Using pre-initialized model")
lm = model lm = model
...@@ -219,48 +233,58 @@ def simple_evaluate( ...@@ -219,48 +233,58 @@ def simple_evaluate(
task_manager = TaskManager(verbosity) task_manager = TaskManager(verbosity)
task_dict = get_task_dict(tasks, task_manager) task_dict = get_task_dict(tasks, task_manager)
for task_name in task_dict.keys():
task_obj = task_dict[task_name]
if isinstance(task_obj, tuple):
_, task_obj = task_obj
if task_obj is None:
continue
if task_obj.get_config("output_type") == "generate_until":
if gen_kwargs is not None:
task_obj.set_config(
key="generation_kwargs", value=gen_kwargs, update=True
)
if predict_only: # helper function to recursively apply config overrides to leaf subtasks, skipping their constituent groups.
log_samples = True # (setting of num_fewshot ; bypassing metric calculation ; setting fewshot seed)
eval_logger.info( def _adjust_config(task_dict):
f"Processing {task_name} in output-only mode. Metrics will not be calculated!" adjusted_task_dict = {}
) for task_name, task_obj in task_dict.items():
# we have to change the class properties post-hoc. This is pretty hacky. if isinstance(task_obj, dict):
task_obj.override_metric(metric_name="bypass") adjusted_task_dict = {
**adjusted_task_dict,
# override tasks' fewshot values to the provided num_fewshot arg value **{task_name: _adjust_config(task_obj)},
# except if tasks have it set to 0 manually in their configs--then we should never overwrite that }
if num_fewshot is not None:
if (default_num_fewshot := task_obj.get_config("num_fewshot")) == 0:
eval_logger.info(
f"num_fewshot has been set to 0 for {task_name} in its config. Manual configuration will be ignored."
)
else: else:
eval_logger.warning( if task_obj.get_config("output_type") == "generate_until":
f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {num_fewshot}" if gen_kwargs is not None:
) task_obj.set_config(
task_obj.set_config(key="num_fewshot", value=num_fewshot) key="generation_kwargs", value=gen_kwargs, update=True
else: )
# if num_fewshot not provided, and the task does not define a default one, default to 0
if (default_num_fewshot := task_obj.get_config("num_fewshot")) is None: if predict_only:
task_obj.set_config(key="num_fewshot", value=0) eval_logger.info(
# fewshot_random_seed set for tasks, even with a default num_fewshot (e.g. in the YAML file) f"Processing {task_name} in output-only mode. Metrics will not be calculated!"
task_obj.set_fewshot_seed(seed=fewshot_random_seed) )
eval_logger.info( # we have to change the class properties post-hoc. This is pretty hacky.
f"Setting fewshot random generator seed to {fewshot_random_seed}" task_obj.override_metric(metric_name="bypass")
)
# override tasks' fewshot values to the provided num_fewshot arg value
# except if tasks have it set to 0 manually in their configs--then we should never overwrite that
if num_fewshot is not None:
if (default_num_fewshot := task_obj.get_config("num_fewshot")) == 0:
eval_logger.info(
f"num_fewshot has been set to 0 for {task_name} in its config. Manual configuration will be ignored."
)
else:
eval_logger.warning(
f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {num_fewshot}"
)
task_obj.set_config(key="num_fewshot", value=num_fewshot)
else:
# if num_fewshot not provided, and the task does not define a default one, default to 0
if (
default_num_fewshot := task_obj.get_config("num_fewshot")
) is None:
task_obj.set_config(key="num_fewshot", value=0)
# fewshot_random_seed set for tasks, even with a default num_fewshot (e.g. in the YAML file)
task_obj.set_fewshot_seed(seed=fewshot_random_seed)
adjusted_task_dict[task_name] = task_obj
return adjusted_task_dict
task_dict = _adjust_config(task_dict)
if check_integrity: if check_integrity:
run_task_tests(task_list=tasks) run_task_tests(task_list=tasks)
...@@ -270,7 +294,10 @@ def simple_evaluate( ...@@ -270,7 +294,10 @@ def simple_evaluate(
model_source=model, model_source=model,
model_args=model_args, model_args=model_args,
system_instruction=system_instruction, system_instruction=system_instruction,
chat_template=lm.chat_template if apply_chat_template else None, chat_template=lm.chat_template(apply_chat_template)
if apply_chat_template
else None,
fewshot_as_multiturn=fewshot_as_multiturn,
) )
results = evaluate( results = evaluate(
...@@ -281,7 +308,7 @@ def simple_evaluate( ...@@ -281,7 +308,7 @@ def simple_evaluate(
rewrite_requests_cache=rewrite_requests_cache, rewrite_requests_cache=rewrite_requests_cache,
bootstrap_iters=bootstrap_iters, bootstrap_iters=bootstrap_iters,
write_out=write_out, write_out=write_out,
log_samples=log_samples, log_samples=True if predict_only else log_samples,
system_instruction=system_instruction, system_instruction=system_instruction,
apply_chat_template=apply_chat_template, apply_chat_template=apply_chat_template,
fewshot_as_multiturn=fewshot_as_multiturn, fewshot_as_multiturn=fewshot_as_multiturn,
...@@ -325,6 +352,7 @@ def simple_evaluate( ...@@ -325,6 +352,7 @@ def simple_evaluate(
results["git_hash"] = get_git_commit_hash() results["git_hash"] = get_git_commit_hash()
results["date"] = start_date results["date"] = start_date
add_env_info(results) # additional environment info to results add_env_info(results) # additional environment info to results
add_tokenizer_info(results, lm) # additional info about tokenizer
return results return results
else: else:
return None return None
...@@ -341,7 +369,7 @@ def evaluate( ...@@ -341,7 +369,7 @@ def evaluate(
write_out: bool = False, write_out: bool = False,
log_samples: bool = True, log_samples: bool = True,
system_instruction: Optional[str] = None, system_instruction: Optional[str] = None,
apply_chat_template: bool = False, apply_chat_template: Union[bool, str] = False,
fewshot_as_multiturn: bool = False, fewshot_as_multiturn: bool = False,
verbosity: str = "INFO", verbosity: str = "INFO",
): ):
...@@ -361,8 +389,11 @@ def evaluate( ...@@ -361,8 +389,11 @@ def evaluate(
If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
:param system_instruction: str :param system_instruction: str
System instruction to be applied to the prompt System instruction to be applied to the prompt
:param apply_chat_template: bool :param apply_chat_template: Union[bool, str]
If True, apply chat template to the prompt Specifies whether to apply a chat template to the prompt.
- If set to True, the default chat template is applied.
- If set to a string, applies the specified chat template by name.
Defaults to False (no chat template applied).
:param fewshot_as_multiturn: bool :param fewshot_as_multiturn: bool
Whether to provide the fewshot examples as a multiturn conversation or a single user turn. Whether to provide the fewshot examples as a multiturn conversation or a single user turn.
:return :return
...@@ -371,6 +402,11 @@ def evaluate( ...@@ -371,6 +402,11 @@ def evaluate(
eval_logger.setLevel(getattr(logging, f"{verbosity}")) eval_logger.setLevel(getattr(logging, f"{verbosity}"))
if apply_chat_template:
eval_logger.warning(
"Chat template formatting change affects loglikelihood and multiple-choice tasks. See docs/chat-template-readme.md for details."
)
# tracks all Instances/requests a model must generate output on. # tracks all Instances/requests a model must generate output on.
requests = defaultdict(list) requests = defaultdict(list)
# stores the amount to pad out reqs per req. type so that # stores the amount to pad out reqs per req. type so that
...@@ -378,16 +414,40 @@ def evaluate( ...@@ -378,16 +414,40 @@ def evaluate(
padding_requests = defaultdict(int) padding_requests = defaultdict(int)
# get lists of group hierarchy and each type of request # get lists of group hierarchy and each type of request
task_hierarchy, eval_tasks = get_task_list(task_dict) eval_tasks = get_task_list(task_dict)
if not log_samples: if not log_samples:
if not all( if not all(
"bypass" not in getattr(task_output.task, "_metric_fn_list", {}).keys() "bypass" not in getattr(task_output.task, "_metric_fn_list", {}).keys()
for task_output in eval_tasks for task_output in eval_tasks
): ):
raise ValueError("log_samples must be True for 'bypass' metric-only tasks") raise ValueError("log_samples must be True for 'bypass' metric-only tasks")
# validation check: are we running multimodal task <-> non-multimodal model class, or vice-versa.
incompatible_tasks = []
for task_output in eval_tasks: for task_output in eval_tasks:
task: Task = task_output.task task: Task = task_output.task
limit = get_sample_size(task, limit)
if getattr(lm, "MULTIMODAL", False) != getattr(task, "MULTIMODAL", False):
incompatible_tasks.append(task_output.task_name)
if len(incompatible_tasks) > 0:
if not getattr(lm, "MULTIMODAL", False):
raise ValueError(
f"Attempted to run tasks: {incompatible_tasks} which require multimodal input, but the selected model type does not currently implement this. Multimodal support is currently restricted to the ['hf-multimodal', 'vllm-vlm'] model type."
)
else:
raise ValueError(
f"Attempted to run tasks: {incompatible_tasks} which are text-only, but used a model type which only currently supports multimodal tasks."
)
# end multimodality validation check
# Cache the limit arg.
limit_arg = limit
limits = []
for task_output in eval_tasks:
task: Task = task_output.task
limit = get_sample_size(task, limit_arg)
limits.append(limit)
task.build_all_requests( task.build_all_requests(
limit=limit, limit=limit,
rank=lm.rank, rank=lm.rank,
...@@ -395,9 +455,14 @@ def evaluate( ...@@ -395,9 +455,14 @@ def evaluate(
cache_requests=cache_requests, cache_requests=cache_requests,
rewrite_requests_cache=rewrite_requests_cache, rewrite_requests_cache=rewrite_requests_cache,
system_instruction=system_instruction, system_instruction=system_instruction,
apply_chat_template=apply_chat_template, apply_chat_template=bool(apply_chat_template),
fewshot_as_multiturn=fewshot_as_multiturn, fewshot_as_multiturn=fewshot_as_multiturn,
lm=lm, chat_template=getattr(lm, "apply_chat_template")
if apply_chat_template
else None,
tokenizer_name=getattr(lm, "tokenizer_name", "")
if apply_chat_template
else "",
) )
eval_logger.debug( eval_logger.debug(
f"Task: {task_output.task_name}; number of requests on this rank: {len(task.instances)}" f"Task: {task_output.task_name}; number of requests on this rank: {len(task.instances)}"
...@@ -452,7 +517,7 @@ def evaluate( ...@@ -452,7 +517,7 @@ def evaluate(
WORLD_SIZE = lm.world_size WORLD_SIZE = lm.world_size
### Postprocess outputs ### ### Postprocess outputs ###
# TODO: del model here, maybe (idea: allow user to specify device of e.g. reward model separately) # TODO: del model here, maybe (idea: allow user to specify device of e.g. reward model separately)
for task_output in eval_tasks: for task_output, limit in zip(eval_tasks, limits):
task = task_output.task task = task_output.task
task.apply_filters() task.apply_filters()
...@@ -487,6 +552,8 @@ def evaluate( ...@@ -487,6 +552,8 @@ def evaluate(
"filtered_resps": [ "filtered_resps": [
req.filtered_resps[filter_key] for req in requests req.filtered_resps[filter_key] for req in requests
], ],
"filter": filter_key,
"metrics": list(metrics.keys()),
"doc_hash": hash_string( "doc_hash": hash_string(
json.dumps( json.dumps(
requests[0].doc, requests[0].doc,
...@@ -550,106 +617,45 @@ def evaluate( ...@@ -550,106 +617,45 @@ def evaluate(
### Calculate group metrics ### ### Calculate group metrics ###
if bool(results): if bool(results):
for group, task_list in reversed(task_hierarchy.items()): results, versions, show_group_table, *_ = consolidate_group_results(
if len(task_list) == 0: results, versions, task_dict
# task_hierarchy entries are either )
# `group_name: [subtask1, subtask2, ...]`
# or `task_name: []`. results_agg, group_agg = prepare_print_tasks(task_dict, results)
# we only want to operate on groups here. subtask_list = get_subtask_list(task_dict)
continue
# collect all higher_is_better values for metrics
# collect all higher_is_better values for metrics # in the group's subtasks.
# in the group's subtasks. # TODO: clean this up ; unify with the below metric_list loop?
# TODO: clean this up ; unify with the below metric_list loop? _higher_is_better = {}
_higher_is_better = {} for group, task_list in subtask_list.items():
if (
len(task_list) != 0
): # subtask list will list "task_name": [] for solo tasks
for task in task_list: for task in task_list:
for m, h in higher_is_better[task].items(): for m, h in higher_is_better[task].items():
if m not in _higher_is_better.keys(): if m not in _higher_is_better.keys():
_higher_is_better[m] = h _higher_is_better[m] = h
if (
m in _higher_is_better
and _higher_is_better[m] is not None
and _higher_is_better[m] != h
):
eval_logger.warning(
f"Higher_is_better values for metric {m} in group {group} are not consistent. Defaulting to None."
)
_higher_is_better[m] = None
higher_is_better[group] = _higher_is_better
# collect all metric keys used by a subtask in the group.
metric_list = list(
{
key
for task in task_list
for key in results[task].keys()
if "_stderr" not in key and key not in ["alias", "samples"]
}
)
for metric in metric_list:
stderr = "_stderr,".join(metric.split(","))
# gather metrics, sizes, and stderrs from subtasks
metrics = [
results[task][metric]
for task in task_list
if metric in results[task]
] # TODO: copy?
stderrs = [
results[task][stderr]
for task in task_list
if stderr in results[task]
]
sizes = [
results[task]["samples"]
for task in task_list
if metric in results[task]
]
# compute group's pooled metric and stderr
results[group][
metric
] = lm_eval.api.metrics.aggregate_subtask_metrics(metrics, sizes)
# TODO: calculate grouped metric using aggregation fn
if "N/A" in stderrs:
results[group][stderr] = "N/A"
else:
results[group][
stderr
] = lm_eval.api.metrics.pooled_sample_stderr(stderrs, sizes)
# TODO: allow GroupConfigs to choose which variance formula is used, for back-compatibility
# To use the old (likely incorrect) variance formula, comment out the above and uncomment this line:
# results[group][stderr] = lm_eval.api.metrics.combined_sample_stderr(stderrs, sizes, metrics=metrics)
results[group]["samples"] = sum(sizes)
results_agg = defaultdict(dict)
groups_agg = defaultdict(dict)
all_tasks_list = list(task_hierarchy.keys())
while True:
add_tasks_list = list(k for k in results_agg.keys())
left_tasks_list = sorted(list(set(all_tasks_list) - set(add_tasks_list)))
if len(left_tasks_list) == 0:
break
_task_hierarchy = {
k: v for k, v in task_hierarchy.items() if k in left_tasks_list
}
_results_agg, _groups_agg = prepare_print_tasks(_task_hierarchy, results)
results_agg = {**results_agg, **_results_agg} if (
groups_agg = {**groups_agg, **_groups_agg} m in _higher_is_better
and _higher_is_better[m] is not None
for group_name, task_list in task_hierarchy.items(): and _higher_is_better[m] != h
if task_list: ):
num_fewshot[group_name] = num_fewshot[ eval_logger.warning(
task_list[0] f"Higher_is_better values for metric {m} in group {group} are not consistent. Defaulting to None."
] # TODO: validate this )
_higher_is_better[m] = None
higher_is_better[group] = _higher_is_better
results_dict = { results_dict = {
"results": dict(results_agg.items()), "results": dict(results_agg.items()),
**({"groups": dict(groups_agg.items())} if bool(groups_agg) else {}), **(
"group_subtasks": dict(reversed(task_hierarchy.items())), {"groups": dict(group_agg.items())}
if (bool(group_agg) & show_group_table)
else {}
),
"group_subtasks": dict(reversed(subtask_list.items())),
"configs": dict(sorted(configs.items())), "configs": dict(sorted(configs.items())),
"versions": dict(sorted(versions.items())), "versions": dict(sorted(versions.items())),
"n-shot": dict(sorted(num_fewshot.items())), "n-shot": dict(sorted(num_fewshot.items())),
...@@ -662,7 +668,7 @@ def evaluate( ...@@ -662,7 +668,7 @@ def evaluate(
len(task_output.task.eval_docs), len(task_output.task.eval_docs),
), ),
} }
for task_output in eval_tasks for task_output, limit in zip(eval_tasks, limits)
}, },
} }
if log_samples: if log_samples:
......
...@@ -2,9 +2,15 @@ import collections ...@@ -2,9 +2,15 @@ import collections
import math import math
import pathlib import pathlib
import sys import sys
from typing import Dict, List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
from lm_eval.api import metrics from lm_eval.api.group import ConfigurableGroup
from lm_eval.api.metrics import (
aggregate_subtask_metrics,
pooled_sample_stderr,
stderr_for_metric,
)
from lm_eval.api.task import Task
from lm_eval.utils import eval_logger, positional_deprecated from lm_eval.utils import eval_logger, positional_deprecated
...@@ -98,7 +104,7 @@ class TaskOutput: ...@@ -98,7 +104,7 @@ class TaskOutput:
self.agg_metrics[metric_key] = agg_fn(items) self.agg_metrics[metric_key] = agg_fn(items)
self.sample_len = len(items) # TODO: same sample size for each metric? self.sample_len = len(items) # TODO: same sample size for each metric?
if isinstance(bootstrap_iters, int): if isinstance(bootstrap_iters, int):
stderr_fn = metrics.stderr_for_metric( stderr_fn = stderr_for_metric(
metric=agg_fn, metric=agg_fn,
bootstrap_iters=min(bootstrap_iters, 100) bootstrap_iters=min(bootstrap_iters, 100)
if metric in ["bleu", "chrf", "ter"] if metric in ["bleu", "chrf", "ter"]
...@@ -116,23 +122,71 @@ class TaskOutput: ...@@ -116,23 +122,71 @@ class TaskOutput:
return ( return (
f"TaskOutput(task_name={self.task_name}, " f"TaskOutput(task_name={self.task_name}, "
f"group_name={self.group_name}, " f"group_name={self.group_name}, "
f"version={self.version}," f"version={self.version}, "
f"n_shot={self.n_shot}" f"n_shot={self.n_shot}, "
f"task_alias={self.task_alias}, group_alias={self.group_alias})" f"task_alias={self.task_alias}, "
f"group_alias={self.group_alias})"
) )
def get_task_list(task_dict: dict) -> Tuple[Dict[str, list], List[TaskOutput]]: def get_task_list(task_dict: dict) -> List[TaskOutput]:
task_hierarchy = collections.defaultdict(list) outputs = []
outputs = list(TaskOutput.from_taskdict(x, y) for x, y in task_dict.items()) for task_name, task_obj in task_dict.items():
for task_output in outputs: if isinstance(task_obj, dict):
if group_name := task_output.group_name: _outputs = get_task_list(task_obj)
task_hierarchy[group_name].append(task_output.task_name) outputs.extend(_outputs)
else: else:
task_hierarchy[task_output.task_name] = [] task_output = TaskOutput.from_taskdict(task_name, task_obj)
# returns task_hierarchy tracking which groups contain which subtasks, outputs.append(task_output)
# and a list of TaskOutput classes for each non-group subtask
return task_hierarchy, [x for x in outputs if x.task] return outputs
def get_subtask_list(task_dict, task_root=None, depth=0):
subtask_list = {}
for group_obj, task_obj in task_dict.items():
if isinstance(group_obj, ConfigurableGroup):
# group_name = group_obj.group_name
group_name = group_obj.group_name
else:
group_name = group_obj
if isinstance(task_obj, dict):
_subtask_list = get_subtask_list(
task_obj, task_root=group_name, depth=depth + 1
)
if task_root:
subtask_list.setdefault((task_root, depth), []).extend(
[
_task
for (_task, _depth) in _subtask_list.keys()
if (_depth - 1) == depth
]
)
subtask_list = {**subtask_list, **_subtask_list}
else:
if isinstance(task_obj, ConfigurableGroup):
# group_or_task_name = task_obj.group_name
group_or_task_name = task_obj.group_name
elif isinstance(task_obj, Task):
# group_or_task_name = task_obj.task_name
group_or_task_name = task_obj.task_name
if task_root is None:
subtask_list.setdefault((group_or_task_name, depth), [])
else:
subtask_list.setdefault((task_root, depth), []).append(
group_or_task_name
)
if depth == 0:
_subtask_list = {}
for group_key, task_list in subtask_list.items():
group_name, depth = group_key
_subtask_list[group_name] = task_list
subtask_list = _subtask_list
return subtask_list
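# Illustrative sketch (not part of the diff): a simplified stand-in for get_subtask_list
# that walks plain strings instead of ConfigurableGroup/Task objects, purely to show the
# shape of the mapping returned at depth 0 ({group or task name: [immediate child names]}).
def sketch_subtask_list(task_dict):
    subtask_list = {}
    for name, obj in task_dict.items():
        if isinstance(obj, dict):
            # a group: record its immediate children, recursing into nested groups
            subtask_list[name] = list(obj.keys())
            for child_name, child_obj in obj.items():
                if isinstance(child_obj, dict):
                    subtask_list.update(sketch_subtask_list({child_name: child_obj}))
        else:
            # a standalone task keeps an empty child list
            subtask_list[name] = []
    return subtask_list

assert sketch_subtask_list(
    {"mmlu_stem": {"abstract_algebra": "task", "astronomy": "task"}, "wikitext": "task"}
) == {"mmlu_stem": ["abstract_algebra", "astronomy"], "wikitext": []}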
def print_writeout(task) -> None: def print_writeout(task) -> None:
...@@ -155,70 +209,95 @@ def get_sample_size(task, limit: Optional[int]) -> Union[int, None]: ...@@ -155,70 +209,95 @@ def get_sample_size(task, limit: Optional[int]) -> Union[int, None]:
def prepare_print_tasks( def prepare_print_tasks(
task_hierarchy: dict, results: dict, tab=0 task_dict: dict,
results: dict,
task_depth=0,
group_depth=0,
) -> Tuple[dict, dict]: ) -> Tuple[dict, dict]:
""" """
@param task_hierarchy: Dictionary representing the group hierarchy of tasks. Each key is a group name and its @param task_dict: Dictionary representing the group hierarchy of tasks. Each key is a group name and its
value is a list of task names. value is a list of task names.
@param results: Dictionary containing the results of each task. Each key is a @param results: Dictionary containing the results of each task. Each key is a
group name and its value is a dictionary of task results. group name and its value is a dictionary of task results.
@param tab: The indentation level for printing the task @param task_depth: The indentation level for printing the task
hierarchy. Default is 0.
@param group_depth: The indentation level for printing the group
hierarchy. Default is 0. hierarchy. Default is 0.
@return: A tuple of two dictionaries: results_agg and groups_agg. results_agg contains @return: A tuple of two dictionaries: results_agg and groups_agg. results_agg contains
aggregated results for each task, and groups_agg contains aggregated results for each group. aggregated results for each task, and groups_agg contains aggregated results for each group.
Prepares the task hierarchy and aggregates the results for each task and group recursively for printing. Prepares the task hierarchy and aggregates the results for each task and group recursively for printing.
""" """
results_agg = collections.defaultdict(dict)
groups_agg = collections.defaultdict(dict)
(group_name, task_list), *_ = task_hierarchy.items()
task_list = sorted(task_list)
results_agg[group_name] = results[group_name].copy()
# results_agg[group_name]["tab"] = tab
if "samples" in results_agg[group_name]:
results_agg[group_name].pop("samples")
tab_string = " " * tab + "- " if tab > 0 else ""
if "alias" in results_agg[group_name]: def _sort_task_dict(task_dict):
results_agg[group_name]["alias"] = tab_string + results_agg[group_name]["alias"] """
else: Helper utility. Sorts the task dict at the current level of the hierarchy based on alphabetized task name.
results_agg[group_name]["alias"] = tab_string + group_name Required so that we end up sorting within each sub-header correctly.
"""
if len(task_list) > 0:
groups_agg[group_name] = results[group_name].copy() return dict(
# groups_agg[group_name]["tab"] = tab sorted(
if "samples" in groups_agg[group_name]: task_dict.items(),
groups_agg[group_name].pop("samples") key=lambda item: item[0].group_name
if isinstance(item[0], ConfigurableGroup)
if "alias" in groups_agg[group_name]: else item[0],
groups_agg[group_name]["alias"] = (
tab_string + groups_agg[group_name]["alias"]
) )
else: )
groups_agg[group_name]["alias"] = tab_string + group_name
for task_name in task_list: task_agg = collections.defaultdict(dict)
if task_name in task_hierarchy: group_agg = collections.defaultdict(dict)
_task_hierarchy = { task_dict = _sort_task_dict(task_dict)
**{task_name: task_hierarchy[task_name]}, for task_or_group_name, task_or_group_obj in task_dict.items():
**task_hierarchy, tab_string = " " * task_depth + "- " if task_depth > 0 else ""
} if isinstance(task_or_group_name, ConfigurableGroup):
# string_name = task_or_group_name.group_name
name = task_or_group_name.group_name
from_configurable_group = True
task_or_group_obj = _sort_task_dict(task_or_group_obj)
elif isinstance(task_or_group_name, str):
name = task_or_group_name
if isinstance(task_or_group_obj, Task):
# string_name = task_or_group_obj.task_name
name = task_or_group_obj.task_name
from_configurable_group = False
task_agg[name] = results[name].copy()
if from_configurable_group:
if task_or_group_name.group_alias is not None:
alias = task_or_group_name.group_alias
else: else:
_task_hierarchy = { alias = task_or_group_name.group
**{task_name: []}, else:
**task_hierarchy, if "alias" in task_agg[name]:
} alias = task_agg[name]["alias"]
else:
_results_agg, _groups_agg = prepare_print_tasks( alias = name
_task_hierarchy, results, tab + 1
task_agg[name]["alias"] = tab_string + alias
if "samples" in task_agg[name]:
task_agg[name].pop("samples")
if from_configurable_group and (" " not in results[name]):
group_tab_string = " " * group_depth + "- " if group_depth > 0 else ""
group_agg[name] = results[name].copy()
group_agg[name]["alias"] = group_tab_string + alias
if "samples" in group_agg[name]:
group_agg[name].pop("samples")
if isinstance(task_or_group_obj, dict):
task_depth += 1
group_depth += 1
_task_agg, _group_agg = prepare_print_tasks(
task_or_group_obj, results, task_depth, group_depth
) )
results_agg = {**results_agg, **_results_agg} task_agg = {
groups_agg = {**groups_agg, **_groups_agg} **task_agg,
**_task_agg,
return results_agg, groups_agg }
group_agg = {**group_agg, **_group_agg}
task_depth -= 1
group_depth -= 1
return task_agg, group_agg
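# Illustrative sketch (not part of the diff): the only per-row formatting applied above is
# the depth-based alias prefix, so a subtask nested one level below its group header is
# rendered as " - <alias>" in the printed table.
def alias_with_depth(alias: str, depth: int) -> str:
    tab_string = " " * depth + "- " if depth > 0 else ""
    return tab_string + alias

assert alias_with_depth("mmlu", 0) == "mmlu"
assert alias_with_depth("abstract_algebra", 1) == " - abstract_algebra"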
def consolidate_results( def consolidate_results(
...@@ -261,6 +340,8 @@ def consolidate_results( ...@@ -261,6 +340,8 @@ def consolidate_results(
for task_output in eval_tasks: for task_output in eval_tasks:
if "task_alias" in (task_config := task_output.task_config): if "task_alias" in (task_config := task_output.task_config):
results[task_output.task_name]["alias"] = task_config["task_alias"] results[task_output.task_name]["alias"] = task_config["task_alias"]
else:
results[task_output.task_name]["alias"] = task_output.task_name
if group_alias := task_output.group_alias: if group_alias := task_output.group_alias:
if group_alias not in results and (group_name := task_output.group_name): if group_alias not in results and (group_name := task_output.group_name):
results[group_name]["alias"] = group_alias results[group_name]["alias"] = group_alias
...@@ -275,12 +356,153 @@ def consolidate_results( ...@@ -275,12 +356,153 @@ def consolidate_results(
metric_key metric_key
] ]
results[task_output.task_name]["samples"] = task_output.sample_len results[task_output.task_name]["samples"] = task_output.sample_len
results[task_output.task_name][ results[task_output.task_name][f"{metric}_stderr,{filter_key}"] = (
f"{metric}_stderr,{filter_key}" task_output.agg_metrics[f"{metric}_stderr,{filter_key}"]
] = task_output.agg_metrics[f"{metric}_stderr,{filter_key}"] )
return results, samples, configs, versions, num_fewshot, higher_is_better return results, samples, configs, versions, num_fewshot, higher_is_better
def consolidate_group_results(
results,
versions,
task_dict,
task_root=None,
show_group_table=False,
task_aggregation_list=None,
) -> Tuple[dict, dict, bool, Union[None, dict]]:
"""
(Recursively) calculates groups' aggregated metrics and updates the results and versions dictionaries with this info.
@return: a tuple [results, versions, show_group_table, task_aggregation_list] with formats described below:
- results: A defaultdict with task names (and, after this function is called, group names of
groups that perform aggregation) as keys, and dictionaries with "alias" and metric,filter_name pairs as keys.
- versions: A defaultdict with task names (and, after this function is called, group names of
groups that perform aggregation) as keys, and float values representing the task or group's version if a version is specified (defaulting to None).
- show_group_table: a boolean which is true if there exists a group that requires printing of its aggregated scores in a group table.
- task_aggregation_list: a defaultdict listing the subtasks to average over to produce a given group's end metric.
The method then returns the updated results, versions, show_group_table, and task_aggregation_list as a tuple.
In the top-level invocation of this function, task_aggregation_list is ignored.
"""
if task_root is None:
task_root = {}
if task_aggregation_list is None:
task_aggregation_list = {}
for group_or_task, group_or_task_info in task_dict.items():
# Convert to string
if isinstance(group_or_task, ConfigurableGroup):
group_config = group_or_task.config
group_or_task = group_or_task.group_name
else:
group_config = None
if isinstance(group_or_task_info, Task):
if task_root:
task_aggregation_list.setdefault(task_root, []).append(
group_or_task_info.task_name
)
else:
(
results,
versions,
show_group_table,
_task_aggregation_list,
) = consolidate_group_results(
results,
versions,
group_or_task_info,
group_or_task,
show_group_table,
task_aggregation_list,
)
if task_root:
task_aggregation_list.setdefault(task_root, []).extend(
task_aggregation_list.get(group_or_task, [])
)
if (group_config is None) or (
group_config["aggregate_metric_list"] is None
):
results[group_or_task][" "] = " "
continue
if "aggregate_metric_list" in group_config:
agg_metric_list = group_config["aggregate_metric_list"]
show_group_table = show_group_table | bool(
group_config["aggregate_metric_list"]
)
task_list = _task_aggregation_list[group_or_task]
metric_list = list(
{
key
for task in task_list
for key in results[task].keys()
if "_stderr" not in key and key not in ["task", "alias", "samples"]
}
)
for metric in metric_list:
stderr = "_stderr,".join(metric.split(","))
# gather metrics, sizes, and stderrs from subtasks
metrics = [
results[task][metric]
for task in task_list
if metric in results[task]
] # TODO: copy?
stderrs = [
results[task][stderr]
for task in task_list
if stderr in results[task]
]
sizes = [
results[task]["samples"]
for task in task_list
if metric in results[task]
]
for metric_config in agg_metric_list:
for filter_name in metric_config["filter_list"]:
if metric != ",".join([metric_config["metric"], filter_name]):
continue
# compute group's pooled metric and stderr
if metric_config["aggregation"] == "mean":
aggregate_fn = aggregate_subtask_metrics
elif callable(metric_config["aggregation"]):
aggregate_fn = metric_config["aggregation"]
else:
raise ValueError(
f"Currently, only 'mean' is supported for automatically aggregating scores across groups' subtasks. Got '{metric_config['aggregation']}' for group '{group_or_task}'"
)
results[group_or_task][metric] = aggregate_fn(
metrics,
sizes,
metric_config["weight_by_size"],
)
# TODO: calculate groups' metrics using arbitrary agg fns
if "N/A" in stderrs:
results[group_or_task][stderr] = "N/A"
else:
# NOTE: this assumes we are using the mean to aggregate. There are warnings about this elsewhere
results[group_or_task][stderr] = pooled_sample_stderr(
stderrs, sizes
)
results[group_or_task]["samples"] = sum(sizes)
group_metadata = group_config.get("metadata", None)
if group_metadata is not None:
versions[group_or_task] = group_metadata.get("version", None)
# print(results)
return results, versions, show_group_table, task_aggregation_list
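# Illustrative sketch (not part of the diff): with aggregation == "mean" and
# weight_by_size enabled, a group's score is effectively the size-weighted mean of its
# subtasks' scores. A minimal stand-in for that idea (not the library implementation):
def weighted_mean(metrics, sizes, weight_by_size=True):
    if not weight_by_size:
        sizes = [1] * len(metrics)
    return sum(m * n for m, n in zip(metrics, sizes)) / sum(sizes)

# e.g. two subtasks with acc 0.50 (100 samples) and 0.80 (300 samples) -> 0.725
assert abs(weighted_mean([0.5, 0.8], [100, 300]) - 0.725) < 1e-9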
@positional_deprecated @positional_deprecated
def find_test_root(start_path: pathlib.Path) -> pathlib.Path: def find_test_root(start_path: pathlib.Path) -> pathlib.Path:
""" """
......
...@@ -37,16 +37,18 @@ class RegexFilter(Filter): ...@@ -37,16 +37,18 @@ class RegexFilter(Filter):
if match: if match:
match = match[self.group_select] match = match[self.group_select]
if isinstance(match, tuple): if isinstance(match, tuple):
match = [m for m in match if m][0] match = [m for m in match if m]
if match:
match = match[0]
else:
match = self.fallback
match = match.strip() match = match.strip()
else: else:
match = self.fallback match = self.fallback
filtered.append(match) filtered.append(match)
return filtered return filtered
# print(resps)
filtered_resps = list(map(lambda x: filter_set(x), resps)) filtered_resps = list(map(lambda x: filter_set(x), resps))
# print(filtered_resps)
return filtered_resps return filtered_resps
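# Illustrative sketch (not part of the diff): with a multi-group regex, re.findall yields
# tuples of groups, so the filter above keeps the first non-empty group and only uses the
# fallback when every group is empty. The pattern and response below are hypothetical.
import re

pattern = re.compile(r"answer is (\d+)|answer: (\w+)")
groups = pattern.findall("I think the answer is 42")[-1]  # -> ('42', '')
picked = next((g for g in groups if g), "[invalid]")
assert picked == "42"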
...@@ -62,11 +64,8 @@ class WhitespaceFilter(Filter): ...@@ -62,11 +64,8 @@ class WhitespaceFilter(Filter):
def filter_set(inst): def filter_set(inst):
filtered_resp = [] filtered_resp = []
for resp in inst: for resp in inst:
if resp.startswith(" "): resp = resp.lstrip()
resp = resp[1:]
filtered_resp.append(resp) filtered_resp.append(resp)
return filtered_resp return filtered_resp
filtered_resps = [filter_set(resp) for resp in resps] filtered_resps = [filter_set(resp) for resp in resps]
...@@ -165,7 +164,7 @@ class MultiChoiceRegexFilter(RegexFilter): ...@@ -165,7 +164,7 @@ class MultiChoiceRegexFilter(RegexFilter):
fallback_regex = re.compile("|".join(fallback_regexes)) fallback_regex = re.compile("|".join(fallback_regexes))
without_paren_fallback_regex = "|".join(without_paren_fallback_regexes) without_paren_fallback_regex = "|".join(without_paren_fallback_regexes)
without_paren_fallback_regex = re.compile( without_paren_fallback_regex = re.compile(
f":[\s]*({without_paren_fallback_regex})" rf":[\s]*({without_paren_fallback_regex})"
) )
filtered = [] filtered = []
......
import json import json
import os
import re import re
import time import time
from collections import defaultdict from collections import defaultdict
...@@ -14,6 +15,7 @@ from huggingface_hub import ( ...@@ -14,6 +15,7 @@ from huggingface_hub import (
HfApi, HfApi,
hf_hub_url, hf_hub_url,
) )
from huggingface_hub.utils import build_hf_headers, get_session, hf_raise_for_status
from lm_eval.utils import ( from lm_eval.utils import (
eval_logger, eval_logger,
...@@ -48,6 +50,7 @@ class GeneralConfigTracker: ...@@ -48,6 +50,7 @@ class GeneralConfigTracker:
model_name_sanitized: str = None model_name_sanitized: str = None
system_instruction: str = None system_instruction: str = None
system_instruction_sha: str = None system_instruction_sha: str = None
fewshot_as_multiturn: bool = None
chat_template: str = None chat_template: str = None
chat_template_sha: str = None chat_template_sha: str = None
start_time: float = None start_time: float = None
...@@ -80,6 +83,7 @@ class GeneralConfigTracker: ...@@ -80,6 +83,7 @@ class GeneralConfigTracker:
model_args: str, model_args: str,
system_instruction: str, system_instruction: str,
chat_template: str, chat_template: str,
fewshot_as_multiturn: bool,
) -> None: ) -> None:
"""Logs model parameters and job ID.""" """Logs model parameters and job ID."""
self.model_source = model_source self.model_source = model_source
...@@ -91,6 +95,7 @@ class GeneralConfigTracker: ...@@ -91,6 +95,7 @@ class GeneralConfigTracker:
) )
self.chat_template = chat_template self.chat_template = chat_template
self.chat_template_sha = hash_string(chat_template) if chat_template else None self.chat_template_sha = hash_string(chat_template) if chat_template else None
self.fewshot_as_multiturn = fewshot_as_multiturn
def log_end_time(self) -> None: def log_end_time(self) -> None:
"""Logs the end time of the evaluation and calculates the total evaluation time.""" """Logs the end time of the evaluation and calculates the total evaluation time."""
...@@ -109,12 +114,15 @@ class EvaluationTracker: ...@@ -109,12 +114,15 @@ class EvaluationTracker:
output_path: str = None, output_path: str = None,
hub_results_org: str = "", hub_results_org: str = "",
hub_repo_name: str = "", hub_repo_name: str = "",
details_repo_name: str = "",
results_repo_name: str = "",
push_results_to_hub: bool = False, push_results_to_hub: bool = False,
push_samples_to_hub: bool = False, push_samples_to_hub: bool = False,
public_repo: bool = False, public_repo: bool = False,
token: str = "", token: str = "",
leaderboard_url: str = "", leaderboard_url: str = "",
point_of_contact: str = "", point_of_contact: str = "",
gated: bool = False,
) -> None: ) -> None:
""" """
Creates all the necessary loggers for evaluation tracking. Creates all the necessary loggers for evaluation tracking.
...@@ -123,12 +131,15 @@ class EvaluationTracker: ...@@ -123,12 +131,15 @@ class EvaluationTracker:
output_path (str): Path to save the results. If not provided, the results won't be saved. output_path (str): Path to save the results. If not provided, the results won't be saved.
hub_results_org (str): The Hugging Face organization to push the results to. If not provided, the results will be pushed to the owner of the Hugging Face token. hub_results_org (str): The Hugging Face organization to push the results to. If not provided, the results will be pushed to the owner of the Hugging Face token.
hub_repo_name (str): The name of the Hugging Face repository to push the results to. If not provided, the results will be pushed to `lm-eval-results`. hub_repo_name (str): The name of the Hugging Face repository to push the results to. If not provided, the results will be pushed to `lm-eval-results`.
details_repo_name (str): The name of the Hugging Face repository to push the details to. If not provided, the details will be pushed to `lm-eval-results`.
results_repo_name (str): The name of the Hugging Face repository to push the results to. If not provided, the results will be pushed to the details repository (`details_repo_name`).
push_results_to_hub (bool): Whether to push the results to the Hugging Face hub. push_results_to_hub (bool): Whether to push the results to the Hugging Face hub.
push_samples_to_hub (bool): Whether to push the samples to the Hugging Face hub. push_samples_to_hub (bool): Whether to push the samples to the Hugging Face hub.
public_repo (bool): Whether to push the results to a public or private repository. public_repo (bool): Whether to push the results to a public or private repository.
token (str): Token to use when pushing to the Hugging Face hub. This token should have write access to `hub_results_org`. token (str): Token to use when pushing to the Hugging Face hub. This token should have write access to `hub_results_org`.
leaderboard_url (str): URL to the leaderboard on the Hugging Face hub on the dataset card. leaderboard_url (str): URL to the leaderboard on the Hugging Face hub on the dataset card.
point_of_contact (str): Contact information on the Hugging Face hub dataset card. point_of_contact (str): Contact information on the Hugging Face hub dataset card.
gated (bool): Whether to gate the repository.
""" """
self.general_config_tracker = GeneralConfigTracker() self.general_config_tracker = GeneralConfigTracker()
...@@ -139,6 +150,7 @@ class EvaluationTracker: ...@@ -139,6 +150,7 @@ class EvaluationTracker:
self.leaderboard_url = leaderboard_url self.leaderboard_url = leaderboard_url
self.point_of_contact = point_of_contact self.point_of_contact = point_of_contact
self.api = HfApi(token=token) if token else None self.api = HfApi(token=token) if token else None
self.gated_repo = gated
if not self.api and (push_results_to_hub or push_samples_to_hub): if not self.api and (push_results_to_hub or push_samples_to_hub):
raise ValueError( raise ValueError(
...@@ -156,9 +168,24 @@ class EvaluationTracker: ...@@ -156,9 +168,24 @@ class EvaluationTracker:
f"hub_results_org was not specified. Results will be pushed to '{hub_results_org}'." f"hub_results_org was not specified. Results will be pushed to '{hub_results_org}'."
) )
hub_repo_name = hub_repo_name if hub_repo_name else "lm-eval-results" if hub_repo_name == "":
self.hub_results_repo = f"{hub_results_org}/{hub_repo_name}" details_repo_name = (
self.hub_results_repo_private = f"{hub_results_org}/{hub_repo_name}-private" details_repo_name if details_repo_name != "" else "lm-eval-results"
)
results_repo_name = (
results_repo_name if results_repo_name != "" else details_repo_name
)
else:
details_repo_name = hub_repo_name
results_repo_name = hub_repo_name
eval_logger.warning(
"hub_repo_name was specified. Both details and results will be pushed to the same repository. Using hub_repo_name is no longer recommended, details_repo_name and results_repo_name should be used instead."
)
self.details_repo = f"{hub_results_org}/{details_repo_name}"
self.details_repo_private = f"{hub_results_org}/{details_repo_name}-private"
self.results_repo = f"{hub_results_org}/{results_repo_name}"
self.results_repo_private = f"{hub_results_org}/{results_repo_name}-private"
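# Illustrative sketch (not part of the diff): with hypothetical arguments
# hub_results_org="my-org", details_repo_name="lm-eval-details" and results_repo_name
# left empty, results fall back to the details repository.
hub_results_org = "my-org"
details_repo_name = "lm-eval-details"
results_repo_name = ""  # not provided
results_repo_name = results_repo_name if results_repo_name != "" else details_repo_name
assert f"{hub_results_org}/{details_repo_name}" == "my-org/lm-eval-details"
assert f"{hub_results_org}/{results_repo_name}" == "my-org/lm-eval-details"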
def save_results_aggregated( def save_results_aggregated(
self, self,
...@@ -208,9 +235,9 @@ class EvaluationTracker: ...@@ -208,9 +235,9 @@ class EvaluationTracker:
if self.api and self.push_results_to_hub: if self.api and self.push_results_to_hub:
repo_id = ( repo_id = (
self.hub_results_repo self.results_repo
if self.public_repo if self.public_repo
else self.hub_results_repo_private else self.results_repo_private
) )
self.api.create_repo( self.api.create_repo(
repo_id=repo_id, repo_id=repo_id,
...@@ -218,10 +245,15 @@ class EvaluationTracker: ...@@ -218,10 +245,15 @@ class EvaluationTracker:
private=not self.public_repo, private=not self.public_repo,
exist_ok=True, exist_ok=True,
) )
self.api.upload_folder( self.api.upload_file(
repo_id=repo_id, repo_id=repo_id,
folder_path=str(path), path_or_fileobj=str(
path_in_repo=self.general_config_tracker.model_name_sanitized, path.joinpath(f"results_{self.date_id}.json")
),
path_in_repo=os.path.join(
self.general_config_tracker.model_name,
f"results_{self.date_id}.json",
),
repo_type="dataset", repo_type="dataset",
commit_message=f"Adding aggregated results for {self.general_config_tracker.model_name}", commit_message=f"Adding aggregated results for {self.general_config_tracker.model_name}",
) )
...@@ -275,6 +307,7 @@ class EvaluationTracker: ...@@ -275,6 +307,7 @@ class EvaluationTracker:
sample["resps"] = sanitize_list(sample["resps"]) sample["resps"] = sanitize_list(sample["resps"])
sample["filtered_resps"] = sanitize_list(sample["filtered_resps"]) sample["filtered_resps"] = sanitize_list(sample["filtered_resps"])
sample["arguments"] = arguments sample["arguments"] = arguments
sample["target"] = str(sample["target"])
sample_dump = ( sample_dump = (
json.dumps( json.dumps(
...@@ -285,14 +318,14 @@ class EvaluationTracker: ...@@ -285,14 +318,14 @@ class EvaluationTracker:
+ "\n" + "\n"
) )
with open(file_results_samples, "a") as f: with open(file_results_samples, "a", encoding="utf-8") as f:
f.write(sample_dump) f.write(sample_dump)
if self.api and self.push_samples_to_hub: if self.api and self.push_samples_to_hub:
repo_id = ( repo_id = (
self.hub_results_repo self.details_repo
if self.public_repo if self.public_repo
else self.hub_results_repo_private else self.details_repo_private
) )
self.api.create_repo( self.api.create_repo(
repo_id=repo_id, repo_id=repo_id,
...@@ -300,6 +333,18 @@ class EvaluationTracker: ...@@ -300,6 +333,18 @@ class EvaluationTracker:
private=not self.public_repo, private=not self.public_repo,
exist_ok=True, exist_ok=True,
) )
try:
if self.gated_repo:
headers = build_hf_headers()
r = get_session().put(
url=f"https://huggingface.co/api/datasets/{repo_id}/settings",
headers=headers,
json={"gated": "auto"},
)
hf_raise_for_status(r)
except Exception as e:
eval_logger.warning("Could not gate the repository")
eval_logger.info(repr(e))
self.api.upload_folder( self.api.upload_folder(
repo_id=repo_id, repo_id=repo_id,
folder_path=str(path), folder_path=str(path),
...@@ -324,9 +369,7 @@ class EvaluationTracker: ...@@ -324,9 +369,7 @@ class EvaluationTracker:
""" """
eval_logger.info("Recreating metadata card") eval_logger.info("Recreating metadata card")
repo_id = ( repo_id = self.details_repo if self.public_repo else self.details_repo_private
self.hub_results_repo if self.public_repo else self.hub_results_repo_private
)
files_in_repo = self.api.list_repo_files(repo_id=repo_id, repo_type="dataset") files_in_repo = self.api.list_repo_files(repo_id=repo_id, repo_type="dataset")
results_files = get_results_filenames(files_in_repo) results_files = get_results_filenames(files_in_repo)
...@@ -357,7 +400,10 @@ class EvaluationTracker: ...@@ -357,7 +400,10 @@ class EvaluationTracker:
results_datetime, results_datetime,
) )
latest_task_results_datetime[samples_key] = latest_datetime latest_task_results_datetime[samples_key] = latest_datetime
latest_task_results_datetime[results_key] = latest_datetime latest_task_results_datetime[results_key] = max(
latest_task_results_datetime[results_key],
latest_datetime,
)
# Create metadata card # Create metadata card
card_metadata = MetadataConfigs() card_metadata = MetadataConfigs()
...@@ -374,14 +420,15 @@ class EvaluationTracker: ...@@ -374,14 +420,15 @@ class EvaluationTracker:
sanitized_last_eval_date_results = re.sub( sanitized_last_eval_date_results = re.sub(
r"[^\w\.]", "_", latest_task_results_datetime[config_name] r"[^\w\.]", "_", latest_task_results_datetime[config_name]
) )
# Ensure that all results files are listed in the metadata card
current_results = card_metadata.get(config_name, {"data_files": []})
current_results["data_files"].append(
{"split": eval_date_sanitized, "path": [str(results_filename)]}
)
card_metadata[config_name] = current_results
# If the results file is the newest, update the "latest" field in the metadata card
if eval_date_sanitized == sanitized_last_eval_date_results: if eval_date_sanitized == sanitized_last_eval_date_results:
# Ensure that all results files are listed in the metadata card
current_results = card_metadata.get(config_name, {"data_files": []})
current_results["data_files"].append(
{"split": eval_date_sanitized, "path": [str(results_filename)]}
)
card_metadata[config_name] = current_results
# If the results file is the newest, update the "latest" field in the metadata card
card_metadata[config_name]["data_files"].append( card_metadata[config_name]["data_files"].append(
{"split": "latest", "path": [str(results_filename)]} {"split": "latest", "path": [str(results_filename)]}
) )
...@@ -400,65 +447,20 @@ class EvaluationTracker: ...@@ -400,65 +447,20 @@ class EvaluationTracker:
sanitized_last_eval_date_results = re.sub( sanitized_last_eval_date_results = re.sub(
r"[^\w\.]", "_", latest_task_results_datetime[config_name] r"[^\w\.]", "_", latest_task_results_datetime[config_name]
) )
# Ensure that all sample results files are listed in the metadata card
current_details_for_task = card_metadata.get(
config_name, {"data_files": []}
)
current_details_for_task["data_files"].append(
{"split": eval_date_sanitized, "path": [str(results_filename)]}
)
card_metadata[config_name] = current_details_for_task
# If the samples results file is the newest, update the "latest" field in the metadata card
if eval_date_sanitized == sanitized_last_eval_date_results: if eval_date_sanitized == sanitized_last_eval_date_results:
# Ensure that all sample results files are listed in the metadata card
current_details_for_task = card_metadata.get(
config_name, {"data_files": []}
)
current_details_for_task["data_files"].append(
{"split": eval_date_sanitized, "path": [str(results_filename)]}
)
card_metadata[config_name] = current_details_for_task
# If the samples results file is the newest, update the "latest" field in the metadata card
card_metadata[config_name]["data_files"].append( card_metadata[config_name]["data_files"].append(
{"split": "latest", "path": [str(results_filename)]} {"split": "latest", "path": [str(results_filename)]}
) )
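# Illustrative sketch (not part of the diff): the approximate shape of one config entry in
# the metadata card after the loops above, with a dated split plus a "latest" alias. The
# model name, date and file path below are hypothetical.
example_config_entry = {
    "data_files": [
        {"split": "2025_01_01T00_00_00", "path": ["my-model/samples_task_2025-01-01T00-00-00.jsonl"]},
        {"split": "latest", "path": ["my-model/samples_task_2025-01-01T00-00-00.jsonl"]},
    ]
}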
# Special case for MMLU with a single split covering it all
# We add another config with all MMLU splits results together for easy inspection
SPECIAL_TASKS = ["mmlu", "gpqa", "minerva_math"]
for special_task in SPECIAL_TASKS:
if special_task in config_name:
special_task = f"{model_name}__{special_task}"
former_entry = card_metadata.get(special_task, {"data_files": []})
former_split = [
(i, entry)
for i, entry in enumerate(former_entry["data_files"])
if entry.get("split", None) == eval_date_sanitized
]
if len(former_split) == 0:
former_entry["data_files"].append(
{
"split": eval_date_sanitized,
"path": [str(results_filename)],
}
)
else:
split_index, _ = former_split[0]
former_entry["data_files"][split_index]["path"].append(
str(results_filename)
)
if eval_date_sanitized == sanitized_last_eval_date_results:
latest_split = [
(i, entry)
for i, entry in enumerate(former_entry["data_files"])
if entry.get("split", None) == "latest"
]
if len(latest_split) == 0:
former_entry["data_files"].append(
{"split": "latest", "path": [str(results_filename)]}
)
else:
latest_index, _ = latest_split[0]
former_entry["data_files"][latest_index]["path"].append(
str(results_filename)
)
card_metadata[special_task] = former_entry
# Get latest results and extract info to update metadata card examples # Get latest results and extract info to update metadata card examples
latest_datetime = max(latest_task_results_datetime.values()) latest_datetime = max(latest_task_results_datetime.values())
latest_model_name = max( latest_model_name = max(
......
...@@ -110,3 +110,34 @@ def add_env_info(storage: Dict[str, Any]): ...@@ -110,3 +110,34 @@ def add_env_info(storage: Dict[str, Any]):
"upper_git_hash": upper_dir_commit, # in case this repo is submodule "upper_git_hash": upper_dir_commit, # in case this repo is submodule
} }
storage.update(added_info) storage.update(added_info)
def add_tokenizer_info(storage: Dict[str, Any], lm):
if getattr(lm, "tokenizer", False):
try:
tokenizer_info = {
"tokenizer_pad_token": [
lm.tokenizer.pad_token,
str(lm.tokenizer.pad_token_id),
],
"tokenizer_eos_token": [
lm.tokenizer.eos_token,
str(lm.tokenizer.eos_token_id),
],
"tokenizer_bos_token": [
lm.tokenizer.bos_token,
str(lm.tokenizer.bos_token_id),
],
"eot_token_id": getattr(lm, "eot_token_id", None),
"max_length": getattr(lm, "max_length", None),
}
storage.update(tokenizer_info)
except Exception as err:
logger.debug(
f"Logging detailed tokenizer info failed with {err}, skipping..."
)
# seems gguf and textsynth do not have tokenizer
else:
logger.debug(
"LM does not have a 'tokenizer' attribute, not logging tokenizer metadata to results."
)
...@@ -15,10 +15,9 @@ logger = logging.getLogger(__name__) ...@@ -15,10 +15,9 @@ logger = logging.getLogger(__name__)
def get_wandb_printer() -> Literal["Printer"]: def get_wandb_printer() -> Literal["Printer"]:
"""Returns a wandb printer instance for pretty stdout.""" """Returns a wandb printer instance for pretty stdout."""
from wandb.sdk.lib.printer import get_printer from wandb.sdk.lib.printer import new_printer
from wandb.sdk.wandb_settings import Settings
printer = get_printer(Settings()._jupyter) printer = new_printer()
return printer return printer
...@@ -49,6 +48,9 @@ class WandbLogger: ...@@ -49,6 +48,9 @@ class WandbLogger:
self.wandb_args: Dict[str, Any] = kwargs self.wandb_args: Dict[str, Any] = kwargs
# pop the step key from the args to save for all logging calls
self.step = self.wandb_args.pop("step", None)
# initialize a W&B run # initialize a W&B run
if wandb.run is None: if wandb.run is None:
self.run = wandb.init(**self.wandb_args) self.run = wandb.init(**self.wandb_args)
...@@ -153,11 +155,11 @@ class WandbLogger: ...@@ -153,11 +155,11 @@ class WandbLogger:
# log the complete eval result to W&B Table # log the complete eval result to W&B Table
table = make_table(["Tasks"] + columns, "results") table = make_table(["Tasks"] + columns, "results")
self.run.log({"evaluation/eval_results": table}) self.run.log({"evaluation/eval_results": table}, step=self.step)
if "groups" in self.results.keys(): if "groups" in self.results.keys():
table = make_table(["Groups"] + columns, "groups") table = make_table(["Groups"] + columns, "groups")
self.run.log({"evaluation/group_eval_results": table}) self.run.log({"evaluation/group_eval_results": table}, step=self.step)
def _log_results_as_artifact(self) -> None: def _log_results_as_artifact(self) -> None:
"""Log results as JSON artifact to W&B.""" """Log results as JSON artifact to W&B."""
...@@ -175,13 +177,13 @@ class WandbLogger: ...@@ -175,13 +177,13 @@ class WandbLogger:
"""Log evaluation results to W&B.""" """Log evaluation results to W&B."""
# Log configs to wandb # Log configs to wandb
configs = self._get_config() configs = self._get_config()
self.run.config.update(configs) self.run.config.update(configs, allow_val_change=self.step is not None)
wandb_summary, self.wandb_results = self._sanitize_results_dict() wandb_summary, self.wandb_results = self._sanitize_results_dict()
# update wandb.run.summary with items that were removed # update wandb.run.summary with items that were removed
self.run.summary.update(wandb_summary) self.run.summary.update(wandb_summary)
# Log the evaluation metrics to wandb # Log the evaluation metrics to wandb
self.run.log(self.wandb_results) self.run.log(self.wandb_results, step=self.step)
# Log the evaluation metrics as W&B Table # Log the evaluation metrics as W&B Table
self._log_results_as_table() self._log_results_as_table()
# Log the results dict as json to W&B Artifacts # Log the results dict as json to W&B Artifacts
...@@ -330,7 +332,7 @@ class WandbLogger: ...@@ -330,7 +332,7 @@ class WandbLogger:
# log the samples as a W&B Table # log the samples as a W&B Table
df = self._generate_dataset(eval_preds, self.task_configs.get(task_name)) df = self._generate_dataset(eval_preds, self.task_configs.get(task_name))
self.run.log({f"{task_name}_eval_results": df}) self.run.log({f"{task_name}_eval_results": df}, step=self.step)
# log the samples as a json file as W&B Artifact # log the samples as a json file as W&B Artifact
self._log_samples_as_artifact(eval_preds, task_name) self._log_samples_as_artifact(eval_preds, task_name)
...@@ -349,4 +351,4 @@ class WandbLogger: ...@@ -349,4 +351,4 @@ class WandbLogger:
# log the samples as a json file as W&B Artifact # log the samples as a json file as W&B Artifact
self._log_samples_as_artifact(eval_preds, task_name) self._log_samples_as_artifact(eval_preds, task_name)
self.run.log({f"{group}_eval_results": grouped_df}) self.run.log({f"{group}_eval_results": grouped_df}, step=self.step)
from . import ( from . import (
anthropic_llms, anthropic_llms,
api_models,
dummy, dummy,
gguf, gguf,
hf_vlms,
huggingface, huggingface,
ibm_watsonx_ai,
mamba_lm, mamba_lm,
nemo_lm, nemo_lm,
neuralmagic, neuralmagic,
neuron_optimum, neuron_optimum,
openai_completions, openai_completions,
optimum_ipex,
optimum_lm, optimum_lm,
textsynth, textsynth,
vllm_causallms, vllm_causallms,
vllm_vlms,
) )
......
from typing import Any, List, Tuple import os
from functools import cached_property
from typing import Any, Dict, List, Tuple, Union
from tqdm import tqdm from tqdm import tqdm
from lm_eval import utils from lm_eval import utils
from lm_eval.api.model import LM from lm_eval.api.model import LM
from lm_eval.api.registry import register_model from lm_eval.api.registry import register_model
from lm_eval.models.utils import retry_on_specific_exceptions from lm_eval.models.openai_completions import LocalCompletionsAPI
from lm_eval.models.utils import handle_stop_sequences, retry_on_specific_exceptions
eval_logger = utils.eval_logger eval_logger = utils.eval_logger
...@@ -42,8 +45,8 @@ def anthropic_completion( ...@@ -42,8 +45,8 @@ def anthropic_completion(
try: try:
import anthropic import anthropic
except ModuleNotFoundError: except ModuleNotFoundError as exception:
raise Exception( raise type(exception)(
"attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \ "attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \
please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install -e '.[anthropic]'`", please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install -e '.[anthropic]'`",
) )
...@@ -105,8 +108,8 @@ def anthropic_chat( ...@@ -105,8 +108,8 @@ def anthropic_chat(
try: try:
import anthropic import anthropic
except ModuleNotFoundError: except ModuleNotFoundError as exception:
raise Exception( raise type(exception)(
"attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \ "attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \
please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install -e '.[anthropic]'`", please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install -e '.[anthropic]'`",
) )
...@@ -138,7 +141,7 @@ please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install ...@@ -138,7 +141,7 @@ please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install
return messages() return messages()
@register_model("anthropic") @register_model("anthropic-completions")
class AnthropicLM(LM): class AnthropicLM(LM):
REQ_CHUNK_SIZE = 20 # TODO: not used REQ_CHUNK_SIZE = 20 # TODO: not used
...@@ -165,8 +168,8 @@ class AnthropicLM(LM): ...@@ -165,8 +168,8 @@ class AnthropicLM(LM):
try: try:
import anthropic import anthropic
except ModuleNotFoundError: except ModuleNotFoundError as exception:
raise Exception( raise type(exception)(
"attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \ "attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \
please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install -e '.[anthropic]'`", please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install -e '.[anthropic]'`",
) )
...@@ -214,8 +217,8 @@ please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install ...@@ -214,8 +217,8 @@ please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install
def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]: def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]:
try: try:
import anthropic import anthropic
except ModuleNotFoundError: except ModuleNotFoundError as exception:
raise Exception( raise type(exception)(
"attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \ "attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \
please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install -e '.[anthropic]'`", please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install -e '.[anthropic]'`",
) )
...@@ -271,90 +274,94 @@ please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install ...@@ -271,90 +274,94 @@ please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install
@register_model("anthropic-chat", "anthropic-chat-completions") @register_model("anthropic-chat", "anthropic-chat-completions")
class AnthropicChatLM(AnthropicLM): class AnthropicChat(LocalCompletionsAPI):
REQ_CHUNK_SIZE = 20 # TODO: not used
def __init__( def __init__(
self, self,
model: str, base_url="https://api.anthropic.com/v1/messages",
batch_size: int = 1, tokenizer_backend=None,
max_tokens: int = 256, **kwargs,
temperature: float = 0, # defaults to 1 ):
**kwargs, # top_p, top_k, etc. super().__init__(
) -> None: base_url=base_url, tokenizer_backend=tokenizer_backend, **kwargs
"""Anthropic API wrapper. )
eval_logger.warning(
:param model: str "Chat completions does not support batching. Defaulting to batch size 1."
Anthropic model e.g. 'claude-3-opus-20240229', 'claude-3-sonnet-20240229' )
:param max_tokens: int self._batch_size = 1
Maximum number of tokens to sample from the model self.anthropic_version = "2023-06-01"
:param temperature: float eval_logger.warning(
Sampling temperature f"Using Anthropic Version: {self.anthropic_version}. Confirm the current version here: https://docs.anthropic.com/en/api/versioning"
:param kwargs: Any )
Additional model_args to pass to the API client
"""
super().__init__()
try:
import anthropic
except ModuleNotFoundError:
raise Exception(
"attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \
please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install -e '.[anthropic]'`",
)
self.model = model
# defaults to os.environ.get("ANTHROPIC_API_KEY")
self.client = anthropic.Anthropic()
self.temperature = temperature
self.max_tokens = max_tokens
self.tokenizer = self.client.get_tokenizer()
self.kwargs = kwargs
@property
def max_gen_toks(self) -> int:
return self.max_tokens
def generate_until(self, requests) -> List[str]: @cached_property
try: def api_key(self):
import anthropic """Override this property to return the API key for the API request."""
except ModuleNotFoundError: key = os.environ.get("ANTHROPIC_API_KEY", None)
raise Exception( if key is None:
"attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \ raise ValueError(
please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install -e '.[anthropic]'`", "API key not found. Please set the ANTHROPIC_API_KEY environment variable."
) )
return key
if not requests: @cached_property
return [] def header(self):
return {
_requests: List[Tuple[str, dict]] = [req.args for req in requests] "x-api-key": f"{self.api_key}",
"anthropic-version": self.anthropic_version,
}
def _create_payload(
self,
messages: List[Dict],
generate=True,
gen_kwargs: dict = None,
eos="\n\nHuman:",
**kwargs,
) -> dict:
system = (
messages[0].get("content") if messages[0].get("role") == "system" else None
)
if system:
messages = messages[1:]
gen_kwargs.pop("do_sample", False)
max_tokens = gen_kwargs.pop("max_gen_toks", self._max_gen_toks)
temperature = gen_kwargs.pop("temperature", 0)
stop = handle_stop_sequences(gen_kwargs.pop("until", ["\n\nHuman:"]), eos=eos)
if not isinstance(stop, list):
stop = [stop]
out = {
"messages": messages,
"model": self.model,
"max_tokens": max_tokens,
"temperature": temperature,
"stop_sequences": stop,
**gen_kwargs,
}
if system:
out["system"] = system
return out
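# Illustrative sketch (not part of the diff): the payload shape produced for a hypothetical
# request, with a leading "system" message hoisted into the top-level "system" field as the
# Anthropic Messages API expects. The model name and values below are placeholders.
example_payload = {
    "messages": [{"role": "user", "content": "What is 2 + 2?"}],
    "model": "claude-3-haiku-20240307",
    "max_tokens": 256,
    "temperature": 0,
    "stop_sequences": ["\n\nHuman:"],
    "system": "You are a terse assistant.",
}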
def parse_generations(
self, outputs: Union[Dict, List[Dict]], **kwargs
) -> List[str]:
res = [] res = []
for request in tqdm(_requests): if not isinstance(outputs, list):
try: outputs = [outputs]
inp = request[0] for out in outputs:
request_args = request[1] for choices in out["content"]:
# generation_kwargs res.append(choices["text"])
until = request_args.get("until")
max_tokens = request_args.get("max_gen_toks", self.max_length)
temperature = request_args.get("temperature", self.temperature)
response = anthropic_chat(
client=self.client,
model=self.model,
prompt=inp,
max_tokens=max_tokens,
temperature=temperature, # TODO: implement non-greedy sampling for Anthropic
stop=until, # type: ignore
**self.kwargs,
)
res.append(response)
self.cache_hook.add_partial("generate_until", request, response)
except anthropic.APIConnectionError as e: # type: ignore # noqa: F821
eval_logger.critical(f"Server unreachable: {e.__cause__}")
break
except anthropic.APIStatusError as e: # type: ignore # noqa: F821
eval_logger.critical(f"API error {e.status_code}: {e.message}")
break
return res return res
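# Illustrative sketch (not part of the diff): the Messages API returns "content" as a list
# of blocks, so parse_generations flattens each block's "text" field. The response dict
# below is a hypothetical example.
example_response = {"content": [{"type": "text", "text": "4"}]}
texts = [block["text"] for block in example_response["content"]]
assert texts == ["4"]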
def tok_encode(
self,
string: str,
left_truncate_len=None,
add_special_tokens=None,
**kwargs,
) -> List[str]:
return [string]
def loglikelihood(self, requests, **kwargs):
raise NotImplementedError(
"Anthropic Chat Completions API does not support the return of loglikelihood"
)
import abc
import asyncio
import copy
import itertools
import json
from functools import cached_property
from typing import (
Any,
Awaitable,
Callable,
Dict,
Iterable,
List,
Literal,
NamedTuple,
Optional,
Tuple,
Union,
)
try:
import requests
from aiohttp import ClientSession, ClientTimeout, TCPConnector
from tenacity import RetryError, retry, stop_after_attempt, wait_exponential
from tqdm import tqdm
from tqdm.asyncio import tqdm_asyncio
except ModuleNotFoundError:
pass
from importlib.util import find_spec
from lm_eval import utils
from lm_eval.api.instance import Instance
from lm_eval.api.model import TemplateLM
from lm_eval.models.utils import Collator, chunks, configure_pad_token
LogLikelihoodInputs = Tuple[Tuple[str, str], List[int], List[int]]
# utility class to keep track of json encoded chats
class JsonChatStr(NamedTuple):
prompt: str
def encode(self, encoding):
return self.prompt.encode(encoding)
eval_logger = utils.eval_logger
class TemplateAPI(TemplateLM):
def __init__(
self,
model: str = None,
pretrained: str = None, # `model` takes precedence over `pretrained` when passed.
base_url: str = None,
tokenizer: Optional[str] = None,
# Loglikelihood tasks require a tokenizer to calculate context lengths,
# however the requests can be sent as a string if the API doesn't support token inputs.
# use tokenized_requests=False
tokenizer_backend: Optional[
Literal["tiktoken", "huggingface", "None", "none"]
] = "huggingface",
truncate: bool = False,
# number of concurrent requests. More useful if not batching
num_concurrent: int = 1,
max_retries: int = 3,
max_gen_toks: int = 256,
batch_size: Union[str, int] = 1,
seed: int = 1234,
max_length: Optional[int] = 2048,
add_bos_token: bool = False,
custom_prefix_token_id: int = None,
# send the requests as tokens or strings
tokenized_requests: bool = True,
trust_remote_code: bool = False,
revision: Optional[str] = "main",
use_fast_tokenizer: bool = True,
verify_certificate: bool = True,
eos_string: str = None,
# timeout in seconds
timeout: int = 300,
**kwargs,
) -> None:
super().__init__()
missing_packages = [
pkg
for pkg in ["aiohttp", "tqdm", "tenacity", "requests"]
if find_spec(pkg) is None
]
if missing_packages:
raise ModuleNotFoundError(
f"Attempted to use an API model, but the required packages {missing_packages} are not installed. "
'Please install these via `pip install lm-eval[api]` or `pip install -e ."[api]"`'
)
self.model = model or pretrained
self.base_url = base_url
self.tokenizer = tokenizer
if not isinstance(batch_size, int) and "auto" in batch_size:
eval_logger.warning(
"Automatic batch size is not supported for API models. Defaulting to batch size 1."
)
elif int(batch_size) > 1:
eval_logger.warning(
"Batch size > 1 detected. Ensure your API supports batched requests with varying total sequence lengths."
)
self._batch_size = int(batch_size) if batch_size != "auto" else 1
self._truncate = truncate
self._max_gen_toks = int(max_gen_toks)
self._seed = int(seed)
# max_length - 1 as we always have 1 token for generation
eval_logger.info(f"Using max length {max_length} - 1")
self.max_length = max_length - 1
if int(num_concurrent) <= 1:
eval_logger.info(
"Concurrent requests are disabled. To enable concurrent requests, set `num_concurrent` > 1."
)
self._concurrent = int(num_concurrent)
self.tokenizer_backend = (
None if tokenizer_backend in ("None", "none") else tokenizer_backend
)
self.add_bos_token = add_bos_token
self.custom_prefix_token_id = custom_prefix_token_id
self.tokenized_requests = tokenized_requests
self.max_retries = int(max_retries)
self.verify_certificate = verify_certificate
self._eos_string = eos_string
self.timeout = int(timeout)
eval_logger.info(f"Using tokenizer {self.tokenizer_backend}")
if self.tokenizer_backend is None:
self.tokenizer = None
self.tokenized_requests = False
else:
if self.tokenizer is None:
if self.tokenizer_backend == "huggingface":
import transformers
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
self.tokenizer if self.tokenizer else self.model,
trust_remote_code=trust_remote_code,
revision=revision,
use_fast=use_fast_tokenizer,
)
# Not used as the API will handle padding but to mirror the behavior of the HFLM
self.tokenizer = configure_pad_token(self.tokenizer)
elif self.tokenizer_backend == "tiktoken":
try:
import tiktoken
self.tokenizer = tiktoken.encoding_for_model(self.model)
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
"Attempted to use 'openai' LM type, but the package `tiktoken` is not installed. "
"Please install it via `pip install lm-eval[api]` or `pip install -e .[api]`."
) from e
if "openai" not in self.base_url:
eval_logger.warning(
f"Passed `base_url={self.base_url}` but using (OpenAI) Tiktoken tokenizer backend. "
"Pass `tokenizer_backend=huggingface` and provide the HF tokenizer name if your model does not use Tiktoken."
)
else:
import transformers
assert isinstance(tokenizer, str), "tokenizer must be a string"
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
tokenizer,
trust_remote_code=trust_remote_code,
revision=revision,
use_fast=use_fast_tokenizer,
)
@abc.abstractmethod
def _create_payload(
self,
messages: Union[List[List[int]], List[dict], List[str], str],
*,
generate: bool = True,
gen_kwargs: Optional[dict] = None,
seed: int = 1234,
eos: str = None,
**kwargs,
) -> dict:
"""This method is responsible for creating the json payload that will be sent to the API."""
raise NotImplementedError
def create_message(
self,
messages: Union[List[List[int]], List[str], List[JsonChatStr]],
generate=False,
) -> Union[List[List[int]], List[dict], List[str], str]:
"""Helper method to transform the prompt into the expected API input format. messages consist of batched requests"""
if isinstance(messages[0], JsonChatStr):
# for chat completions we need to decode the json string to list[dict,...]
assert (
self._batch_size == 1
), "non-tokenized chat requests are only supported with batch_size=1"
# list[dict["role":..., "content":...],...]
return json.loads(messages[0].prompt)
if not self.tokenized_requests:
# if messages are tokenized:
if isinstance(messages[0][0], int):
# assuming decoding is lossless. However, this is only for loglikelihood requests
# as we need to compute the context length. For generations, we don't need to tokenize.
messages = self.decode_batch(messages)
if self._batch_size <= 1:
# if batch is 1 return str
return messages[0]
else:
# list[str,...]
return messages
# list[list[int], ...]
return messages
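# Illustrative sketch (not part of the diff): the three input shapes create_message can
# hand to _create_payload, shown with a local stand-in for JsonChatStr and hypothetical
# token ids.
import json
from typing import NamedTuple

class _JsonChatStr(NamedTuple):
    prompt: str

chat = _JsonChatStr(json.dumps([{"role": "user", "content": "hi"}]))
assert json.loads(chat.prompt) == [{"role": "user", "content": "hi"}]  # chat request -> list[dict]
plain_prompt = "The capital of France is"                              # untokenized, batch 1 -> str
token_batch = [[464, 3139, 286]]                                       # tokenized -> list[list[int]]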
@staticmethod
@abc.abstractmethod
def parse_logprobs(
outputs: Union[Any, List[Any]],
tokens: List[List[int]] = None,
ctxlen: List[int] = None,
**kwargs,
) -> List[Tuple[float, bool]]:
"""Method used to parse the logprobs from the (batched) API response. This method should return a list of tuples"""
raise NotImplementedError
@staticmethod
@abc.abstractmethod
def parse_generations(outputs: Union[Any, List[Any]], **kwargs) -> List[str]:
"""Method used to parse the generations from the (batched) API response. This method should return a list of str"""
raise NotImplementedError
@cached_property
def api_key(self) -> str:
"""Override this property to return the API key for the API request."""
return ""
@cached_property
def header(self) -> dict:
"""Override this property to return the headers for the API request."""
return {"Authorization": f"Bearer {self.api_key}"}
@property
def tokenizer_name(self) -> str:
"""Must be defined for LM subclasses which implement Chat Templating.
Should return the name of the tokenizer or chat template used.
Used only to properly fingerprint caches when requests are being cached with `--cache_requests`, otherwise not used.
"""
return ""
def apply_chat_template(
self, chat_history: List[Dict[str, str]]
) -> Union[str, JsonChatStr]:
"""Applies a chat template to a list of chat history between user and model."""
if self.tokenizer_backend == "huggingface" and self.tokenized_requests:
return self.tokenizer.apply_chat_template(
chat_history, tokenize=False, add_generation_prompt=True
)
else:
# bit of a hack. We'll load back before sending to the API
return JsonChatStr(json.dumps(chat_history))
@cached_property
def eot_token_id(self) -> Optional[int]:
if self.tokenizer is None:
return None
else:
if self.tokenizer_backend == "huggingface":
return self.tokenizer.eos_token_id
elif self.tokenizer_backend == "tiktoken":
return self.tokenizer.eot_token
@cached_property
def eos_string(self) -> Optional[str]:
if self._eos_string:
return self._eos_string
elif self.tokenizer is not None:
if self.tokenizer_backend == "huggingface":
return self.tokenizer.eos_token
elif self.tokenizer_backend == "tiktoken":
return self.tokenizer.decode([self.tokenizer.eot_token])
else:
eval_logger.warning(
"Cannot determine EOS string to pass to stop sequence. Manually set by passing `eos_string` to model_args."
)
return None
@cached_property
def prefix_token_id(self) -> Optional[int]:
if self.tokenizer is None:
return None
else:
if self.custom_prefix_token_id is not None:
return self.custom_prefix_token_id
if self.tokenizer_backend == "huggingface":
if self.tokenizer.bos_token_id is not None:
return self.tokenizer.bos_token_id
return self.tokenizer.eos_token_id
else:
return self.tokenizer.eot_token
def tok_encode(
self,
string: str,
left_truncate_len: int = None,
add_special_tokens: bool = False,
truncation: bool = False,
**kwargs,
) -> Union[List[List[int]], List[int], List[str]]:
if self.tokenizer_backend is None:
return [string]
elif self.tokenizer_backend == "huggingface":
# by default for CausalLM, special tokens are not added unless self.add_bos_token is set
if not add_special_tokens:
add_special_tokens = False or self.add_bos_token
encoding: Union[List[List[int]], List[int]] = self.tokenizer(
string,
add_special_tokens=add_special_tokens,
truncation=truncation,
return_attention_mask=False,
).input_ids
# left-truncate the encoded context to be at most `left_truncate_len` tokens long
if left_truncate_len:
if not isinstance(string, str):
encoding = [enc[-left_truncate_len:] for enc in encoding]
else:
encoding = encoding[-left_truncate_len:]
return encoding
else:
try:
encoding = self.tokenizer.encode(string)
except Exception:
encoding = self.tokenizer.encode_batch(string)
return encoding
def decode_batch(self, tokens: List[List[int]]) -> List[str]:
if self.tokenizer_backend == "huggingface":
return self.tokenizer.batch_decode(tokens)
elif self.tokenizer_backend == "tiktoken":
return self.tokenizer.decode_batch(tokens)
def model_call(
self,
messages: Union[List[List[int]], List[str], List[JsonChatStr]],
*,
generate: bool = True,
gen_kwargs: Optional[Dict] = None,
**kwargs,
) -> Optional[dict]:
# !!! Copy: shared dict for each request, need new object !!!
gen_kwargs = copy.deepcopy(gen_kwargs)
try:
response = requests.post(
self.base_url,
json=self._create_payload(
self.create_message(messages),
generate=generate,
gen_kwargs=gen_kwargs,
seed=self._seed,
eos=self.eos_string,
**kwargs,
),
headers=self.header,
verify=self.verify_certificate,
)
if not response.ok:
eval_logger.warning(
f"API request failed with error message: {response.text}. Retrying..."
)
response.raise_for_status()
return response.json()
except RetryError:
eval_logger.error(
"API request failed after multiple retries. Please check the API status."
)
return None
async def amodel_call(
self,
session: ClientSession,
messages: Union[List[List[int]], List[str], List[JsonChatStr]],
*,
generate: bool = True,
cache_keys: list = None,
ctxlens: Optional[List[int]] = None,
gen_kwargs: Optional[Dict] = None,
**kwargs,
) -> Union[List[str], List[Tuple[float, bool]], None]:
# !!! Copy: shared dict for each request, need new object !!!
gen_kwargs = copy.deepcopy(gen_kwargs)
payload = self._create_payload(
self.create_message(messages),
generate=generate,
gen_kwargs=gen_kwargs,
seed=self._seed,
**kwargs,
)
cache_method = "generate_until" if generate else "loglikelihood"
try:
async with session.post(
self.base_url,
json=payload,
headers=self.header,
) as response:
if not response.ok:
error_text = await response.text()
eval_logger.warning(
f"API request failed with error message: {error_text}. Retrying..."
)
# raising exception will retry the request
response.raise_for_status()
outputs = await response.json()
answers = (
self.parse_generations(
outputs=outputs,
)
if generate
else self.parse_logprobs(
outputs=outputs,
tokens=messages,
ctxlens=ctxlens,
)
)
if cache_keys:
for res, cache in zip(answers, cache_keys):
self.cache_hook.add_partial(cache_method, cache, res)
return answers
# If the retries also fail
except RetryError:
eval_logger.error(
"API request failed after multiple retries. Please check the API status."
)
return None
def batch_loglikelihood_requests(
self, chunks: Iterable[List[LogLikelihoodInputs]]
) -> Tuple[List[List[int]], List[int], List[Tuple[str, str]]]:
inputs = []
ctxlens = []
cache_keys = []
for chunk in chunks:
for cache_key, context_enc, continuation_enc in chunk:
# max_length - 1 as we always have 1 token for generation
inp = (context_enc + continuation_enc)[-self.max_length :]
if len(inp) < len(context_enc + continuation_enc):
eval_logger.warning(
f"Context length ({len(context_enc)}) + continuation length ({len(continuation_enc)}) > max_length ({self.max_length}). Left truncating context."
)
ctxlen = len(context_enc) - max(
0, len(context_enc) + len(continuation_enc) - self.max_length
)
inputs.append(inp)
ctxlens.append(ctxlen)
cache_keys.append(cache_key)
return inputs, ctxlens, cache_keys
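# Illustrative sketch (not part of the diff): how the context length is adjusted when
# context + continuation overflow max_length. With a hypothetical max_length of 8, a
# 7-token context and a 3-token continuation, two context tokens are cut from the left.
max_length = 8
context_enc = list(range(7))        # 7 context tokens
continuation_enc = [100, 101, 102]  # 3 continuation tokens
inp = (context_enc + continuation_enc)[-max_length:]
ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - max_length)
assert len(inp) == 8 and ctxlen == 5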
async def get_batched_requests(
self,
requests: list,
cache_keys: list,
*,
generate: bool = True,
ctxlens: List[int] = None,
**kwargs,
) -> Union[List[List[str]], List[List[Tuple[float, bool]]]]:
ctxlens = ctxlens if ctxlens else [None] * len(requests)
conn = TCPConnector(limit=self._concurrent)
async with ClientSession(
connector=conn, timeout=ClientTimeout(total=self.timeout)
) as session:
retry_: Callable[..., Awaitable[Any]] = retry(
stop=stop_after_attempt(self.max_retries),
wait=wait_exponential(multiplier=0.5, min=1, max=10),
reraise=True,
)(self.amodel_call)
# Create tasks for each batch of request
tasks = [
asyncio.create_task(
retry_(
session=session,
messages=message,
cache_keys=cache_key,
generate=generate,
ctxlens=ctxlen,
**kwargs,
)
)
for message, cache_key, ctxlen in zip(
chunks(requests, n=self._batch_size),
chunks(cache_keys, n=self._batch_size),
chunks(ctxlens, n=self._batch_size),
)
]
return await tqdm_asyncio.gather(*tasks, desc="Requesting API")
def _loglikelihood_tokens(self, requests, **kwargs) -> List[Tuple[float, bool]]:
assert (
self.tokenizer is not None
), "Tokenizer is required for loglikelihood tasks to compute context lengths."
res = []
def _collate(req: LogLikelihoodInputs):
"""Defines the key for the sorted method"""
# the negative sign on len(toks) sorts descending - this has a few advantages:
# - time estimates will always be over not underestimates, which is more useful for planning
# - to know the size of a batch when going through the list, you know the first one is always the batch
# padded context length. this is useful to simplify the batching logic and more importantly to make
# automatic adaptive batches much much easier to implement
# - any OOMs will happen right away rather than near the end
toks = req[1] + req[2]
return -len(toks), tuple(toks)
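        # e.g. two requests whose context+continuation token lengths are 7 and 3
        # get keys (-7, ...) and (-3, ...); the ascending sort therefore places the
        # longest request first.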
re_ord = Collator(
requests,
sort_fn=_collate,
group_by=None,
)
# if concurrent then we'll batch in the async context
chunked = re_ord.get_batched(n=self._batch_size if self._concurrent <= 1 else 0)
if self._concurrent <= 1:
pbar = tqdm(desc="Requesting API", total=len(requests))
for chunk in chunked:
inputs, ctxlens, cache_keys = self.batch_loglikelihood_requests([chunk])
outputs = retry(
stop=stop_after_attempt(self.max_retries),
wait=wait_exponential(multiplier=0.5, min=1, max=10),
reraise=True,
)(self.model_call)(messages=inputs, generate=False)
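                # tenacity wrapper sketch: model_call is retried up to
                # self.max_retries times, waiting with exponential backoff
                # (clamped to roughly 1-10 seconds) between attempts.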
if isinstance(outputs, dict):
outputs = [outputs]
for answer_, cache_key in zip(
self.parse_logprobs(
outputs=outputs, tokens=inputs, ctxlens=ctxlens
),
cache_keys,
):
if answer_ is not None:
res.append(answer_)
# cache requests that aren't from a loglikelihood_rolling request
if cache_key is not None:
self.cache_hook.add_partial(
"loglikelihood", cache_key, answer_
)
pbar.update(1)
else:
inputs, ctxlens, cache_keys = self.batch_loglikelihood_requests(chunked)
res = itertools.chain.from_iterable(
asyncio.run(
self.get_batched_requests(
inputs, cache_keys, generate=False, ctxlens=ctxlens
)
)
)
return re_ord.get_original(res)
def generate_until(
self, requests: List[Instance], disable_tqdm: bool = False
) -> List[str]:
res = []
def _collate_gen(_requests):
# sort by the length of the non-tokenized contexts
return -len(_requests[0])
# Let the API deal with tokenization
requests, all_gen_kwargs = zip(*(req.args for req in requests))
if self.tokenized_requests:
encodings_list = self.tok_encode(
requests, add_special_tokens=self.add_bos_token
)
else:
encodings_list = [None] * len(requests)
requests = [
(a, b, c) for a, b, c in zip(requests, all_gen_kwargs, encodings_list)
]
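        # each request is now a (context_str, gen_kwargs, encoded_context_or_None) triple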
re_ord = Collator(
requests,
sort_fn=_collate_gen,
group_by="gen_kwargs",
)
chunked = re_ord.get_batched(
n=self._batch_size if self._concurrent <= 1 else 0, batch_fn=None
)
if self._concurrent <= 1:
pbar = tqdm(desc="Requesting API", total=len(requests))
for chunk in chunked:
contexts, all_gen_kwargs, encodings_list = zip(*chunk)
if self.tokenized_requests:
max_gen_toks = all_gen_kwargs[0].get(
"max_gen_toks", self._max_gen_toks
)
max_context_len = self.max_length - max_gen_toks
encodings_list = [x[-max_context_len:] for x in encodings_list]
if any(
len(x) + max_gen_toks > self.max_length for x in encodings_list
):
eval_logger.warning(
f"Some contexts exceeded (max length: ({self.max_length}) - max_gen_toks: ({max_gen_toks}). They were left truncated."
)
else:
eval_logger.info(
"Tokenized requests are disabled. Context + generation length is not checked."
)
req = encodings_list if self.tokenized_requests else contexts
outputs = retry(
stop=stop_after_attempt(self.max_retries),
wait=wait_exponential(multiplier=0.5, min=1, max=10),
reraise=True,
)(self.model_call)(
messages=req,
generate=True,
gen_kwargs=copy.deepcopy(all_gen_kwargs[0]),
)
for generated_text, context in zip(
self.parse_generations(
outputs=outputs,
contexts=contexts,
),
contexts,
):
if generated_text is not None:
res.append(generated_text)
# partial caching
if context is not None:
self.cache_hook.add_partial(
"generate_until",
(context, all_gen_kwargs[0]),
generated_text,
)
pbar.update(1)
else:
for chunk in chunked:
contexts, all_gen_kwargs, encodings_list = zip(*chunk)
if self.tokenized_requests:
max_gen_toks = all_gen_kwargs[0].get(
"max_gen_toks", self._max_gen_toks
)
max_context_len = self.max_length - max_gen_toks
encodings_list = [x[-max_context_len:] for x in encodings_list]
if any(
len(x) + max_gen_toks > self.max_length for x in encodings_list
):
eval_logger.warning(
f"Some contexts exceeded (max length: ({self.max_length}) - max_gen_toks ({max_gen_toks}). They were left truncated."
)
else:
eval_logger.info(
"Tokenized requests are disabled. Context + generation length is not checked."
)
req = encodings_list if self.tokenized_requests else contexts
results = itertools.chain.from_iterable(
asyncio.run(
self.get_batched_requests(
req,
cache_keys=[(ctx, all_gen_kwargs[0]) for ctx in contexts],
generate=True,
gen_kwargs=copy.deepcopy(all_gen_kwargs[0]),
)
)
)
res.extend(results)
return re_ord.get_original(res)
def loglikelihood_rolling(
self, requests: List[Instance], disable_tqdm: bool = False
) -> List[float]:
loglikelihoods = []
for (string,) in tqdm([req.args for req in requests], disable=disable_tqdm):
rolling_token_windows = list(
map(
utils.make_disjoint_window,
utils.get_rolling_token_windows(
token_list=self.tok_encode(string),
prefix_token=self.prefix_token_id,
# max_seq_len - (1 for context)
max_seq_len=self.max_length - 1,
context_len=1,
),
)
)
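            # Rough illustration (hypothetical tokens; exact splits come from the
            # utils above): with max_length=4 (so max_seq_len=3), prefix token P and
            # tokens [t1..t6], the disjoint windows are approximately
            #   context [P]  -> score [t1, t2, t3]
            #   context [t3] -> score [t4, t5, t6]
            # i.e. every token is scored exactly once, within the model's context limit.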
# TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
rolling_token_windows = [(None,) + x for x in rolling_token_windows]
string_nll = self._loglikelihood_tokens(
rolling_token_windows,
disable_tqdm=True,
)
# discard is_greedy
string_nll = [x[0] for x in string_nll]
string_nll = sum(string_nll)
loglikelihoods.append(string_nll)
# cache this loglikelihood_rolling request
self.cache_hook.add_partial("loglikelihood_rolling", (string,), string_nll)
return loglikelihoods
...@@ -26,9 +26,9 @@ class DummyLM(LM): ...@@ -26,9 +26,9 @@ class DummyLM(LM):
def generate_until(self, requests, disable_tqdm: bool = False): def generate_until(self, requests, disable_tqdm: bool = False):
res = [] res = []
for ctx, _ in tqdm(requests, disable=disable_tqdm): for request in tqdm(requests, disable=disable_tqdm):
res.append("lol") res.append("lol")
assert ctx.strip() != "" assert request.arguments[0].strip() != ""
return res return res
......
...@@ -68,7 +68,9 @@ class GGUFLM(LM): ...@@ -68,7 +68,9 @@ class GGUFLM(LM):
logger.error(f"RequestException: {e}") logger.error(f"RequestException: {e}")
time.sleep(delay) # wait before retrying time.sleep(delay) # wait before retrying
else: else:
raise Exception(f"Failed to get a valid response after {retries} retries.") raise RuntimeError(
f"Failed to get a valid response after {retries} retries."
)
def loglikelihood(self, requests, disable_tqdm: bool = False): def loglikelihood(self, requests, disable_tqdm: bool = False):
if not requests: if not requests:
......
import copy
import gc
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import transformers
from tqdm import tqdm
from transformers import BatchEncoding
from lm_eval import utils
from lm_eval.api.instance import Instance
from lm_eval.api.registry import register_model
from lm_eval.models.huggingface import HFLM
from lm_eval.models.utils import (
Collator,
flatten_image_list,
handle_stop_sequences,
pad_and_concat,
replace_placeholders,
stop_sequences_criteria,
)
DEFAULT_IMAGE_PLACEHOLDER = "<image>"
eval_logger = utils.eval_logger
@register_model("hf-multimodal")
class HFMultimodalLM(HFLM):
"""
An abstracted Hugging Face model class for multimodal LMs like Llava and Idefics.
"""
AUTO_MODEL_CLASS = transformers.AutoModelForVision2Seq
MULTIMODAL = True # flag to indicate, for now, that this model type can run multimodal requests
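    # Hedged usage sketch (model and task names below are illustrative only):
    #   lm_eval --model hf-multimodal \
    #       --model_args pretrained=llava-hf/llava-1.5-7b-hf,max_images=2 \
    #       --tasks <some multimodal task>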
def __init__(
self,
pretrained: Union[str, transformers.PreTrainedModel],
image_token_id: Optional[int] = None,
image_string: Optional[str] = None,
interleave: bool = True,
# TODO: handle whitespace in image placeholder (replacement)
max_images: Optional[int] = 999,
        convert_img_format: bool = False,
**kwargs,
):
# We initialize using HFLM's init. Sub-methods like _create_model and _create_tokenizer
# modify init behavior.
super().__init__(pretrained, **kwargs)
assert (
self.batch_size != "auto"
), "Batch size 'auto' is not yet supported for hf-multimodal models."
self.chat_applied: bool = False
# TODO: phi-3.5 "image placeholders" are <image_1>, <image_2>, ... in order. how to handle this case
# HF AutoModelForVision2Seq models have an `image_token_id` value in their configs
# denoting the token which indicates a location where an image will be substituted in.
# This can take different string values across models, e.g. <image> for Idefics2 and <|image_pad|> for Qwen2-VL
self.interleave = interleave
self.max_images = max_images
self.rgb = convert_img_format
# WARNING: improperly set image_token_id can lead to ignored image input or other (potentially silent) errors!
if not image_string:
self.image_token_id = (
int(image_token_id)
if image_token_id
else (
getattr(self.config, "image_token_id", None)
or getattr(self.config, "image_token_index", None)
)
)
assert (
self.image_token_id is not None
), "Must have a non-None image_token_id to evaluate a Hugging Face AutoModelForVision2Seq model. Please pass `image_token_id` in `--model_args` if model's config does not already specify one."
# get the string this token ID corresponds to
self.image_token = self.tok_decode(
[self.image_token_id], skip_special_tokens=False
)
if image_token_id is not None:
eval_logger.info(
f"A non-default image_token_id with image_token_id={self.image_token_id} and string value '{self.image_token}' was specified manually. Note that using an improper image_token placeholder may lead to ignored image input or errors!"
)
else:
eval_logger.info(
f"A non-default image_token string with string value image_string='{image_string}' was specified manually. Note that using an improper image_token placeholder may lead to ignored image input or errors!"
)
self.image_token = image_string
def _create_tokenizer(
self,
pretrained: Union[str, transformers.PreTrainedModel],
tokenizer: Optional[
Union[
str,
transformers.ProcessorMixin,
]
],
revision: Optional[str] = "main",
trust_remote_code: Optional[bool] = False,
**kwargs,
) -> None:
"""
Helper method during initialization.
For the multimodal variant, we initialize not just
`self.tokenizer` but also `self.processor`.
"""
        if tokenizer:
            if isinstance(tokenizer, str):
                self.processor = transformers.AutoProcessor.from_pretrained(
                    tokenizer,
                    revision=revision,
                    trust_remote_code=trust_remote_code,
                    # use_fast=use_fast_tokenizer,
                )
            else:
                assert isinstance(
                    tokenizer, transformers.ProcessorMixin
                )  # TODO: check this condition
                self.processor = tokenizer
            # keep self.tokenizer in sync with the processor, mirroring the
            # `pretrained`-based path below (and the declared `-> None` return type)
            self.tokenizer = self.processor.tokenizer
            return
# Get tokenizer based on 'pretrained'
if isinstance(pretrained, str):
model_name = pretrained
else:
# get the HF hub name via accessor on model
model_name = self.model.name_or_path
self.processor = transformers.AutoProcessor.from_pretrained(
model_name,
revision=revision,
trust_remote_code=trust_remote_code,
# use_fast=use_fast_tokenizer,
)
self.tokenizer = self.processor.tokenizer
def tok_multimodal_encode(
self, string, images, left_truncate_len=None, add_special_tokens=None
):
"""Helper function which encodes an image + string combo using AutoProcessor"""
# We inherit special token kwarg setup from HFLM.tok_encode
# special_tokens_kwargs = {}
# by default for CausalLM - false or self.add_bos_token is set
# if add_special_tokens is None:
# special_tokens_kwargs = {"add_special_tokens": False or self.add_bos_token}
# otherwise the method explicitly defines the value
# else:
# special_tokens_kwargs = {"add_special_tokens": add_special_tokens}
# encode text+images
# TODO: why does (Qwen2-VL) processor error when attempting to add special tokens to text?
encoding = self.processor(
text=string, images=images, return_tensors=None
) # , **special_tokens_kwargs)
# remove (and store) our tokenized text
text_encoding = encoding.pop("input_ids")
encoding.pop("attention_mask")
# left-truncate the encoded context to be at most `left_truncate_len` tokens long
if left_truncate_len:
text_encoding = text_encoding[-left_truncate_len:]
return text_encoding, encoding # image_encoding is a dict
def _encode_multimodal_pair(self, context, continuation, images):
"""Helper function to perform the role of TemplateLM._encode_pair
Except allowing for image input to also be processed alongside `context`.
This method is a bit messy due to the need to defer conversion of image and text token input
into PyTorch tensors until the main inference loop.
"""
n_spaces = len(context) - len(context.rstrip())
if n_spaces > 0:
continuation = context[-n_spaces:] + continuation
context = context[:-n_spaces]
# TODO: replace default <image> placeholder with self.image_token, for contexts
whole_enc, image_enc = self.tok_multimodal_encode(
context + continuation, images
)
context_enc, _ = self.tok_multimodal_encode(context, images)
# tok_multimodal_encode returns List[List[int]] for tokenized text. Get rid of the batch dim
# since we only are encoding a single string.
# TODO: this is a bit hacky, it'd be nice to make this generally cleaner
whole_enc, context_enc = whole_enc[0], context_enc[0]
context_enc_len = len(context_enc)
continuation_enc = whole_enc[context_enc_len:]
return context_enc, continuation_enc, image_enc
def apply_chat_template(self, chat_history: List[Dict[str, str]]) -> str:
self.chat_applied = True
if not self.interleave:
for content in chat_history:
c = []
text = content["content"]
# Count and remove image placeholders
image_count = min(
self.max_images, text.count(DEFAULT_IMAGE_PLACEHOLDER)
)
text = text.replace(DEFAULT_IMAGE_PLACEHOLDER, "")
# Add image entries
for _ in range(image_count):
c.append({"type": "image", "image": None})
# Add single text entry at the end
c.append({"type": "text", "text": text})
content["content"] = c
else:
for content in chat_history:
c = []
text = content["content"]
expected_image_count = min(
self.max_images, text.count(DEFAULT_IMAGE_PLACEHOLDER)
)
actual_image_count = 0
text_parts = text.split(DEFAULT_IMAGE_PLACEHOLDER)
for i, part in enumerate(text_parts):
# TODO: concatenate text parts (esp. if skipping images)?
if part: # Add non-empty text parts
c.append({"type": "text", "text": part})
if (
(i < len(text_parts) - 1) and i < self.max_images
): # Add image placeholder after each split except the last
c.append({"type": "image"})
actual_image_count += 1
content["content"] = c
if actual_image_count != expected_image_count:
raise ValueError(
f"Mismatch in image placeholder count. Expected: {expected_image_count}, Actual: {actual_image_count}"
)
return self.processor.apply_chat_template(
chat_history, add_generation_prompt=True
)
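    # Example of the transformation performed above (interleave=True, max_images >= 1):
    #   {"role": "user", "content": "<image> What is shown here?"}
    # becomes
    #   {"role": "user", "content": [
    #       {"type": "image"},
    #       {"type": "text", "text": " What is shown here?"},
    #   ]}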
def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
if hasattr(self.processor, "apply_chat_template"):
_tokenizer = self.tokenizer
self.tokenizer = self.processor
selected_template = super().chat_template(chat_template)
self.tokenizer = _tokenizer
return selected_template
else:
return super().chat_template(chat_template)
def tok_batch_multimodal_encode(
self,
strings: List[str], # note that input signature of this fn is different
images: List[List], # TODO: images are pil.Image at the moment, update typehint
padding_side: str = "left",
        left_truncate_len: Optional[int] = None,
truncation: bool = False,
) -> Union[
BatchEncoding, Dict[str, torch.Tensor]
]: # note that this return signature differs from HFLM tok_batch_encode.
# NOTE: here, we replace <image> tags with our model's corresponding image_token string value.
if not self.chat_applied:
# TODO<baber>: This still keeps the whitespace in the image placeholder, which is not ideal.
strings = [
replace_placeholders(
string, DEFAULT_IMAGE_PLACEHOLDER, self.image_token, self.max_images
)
for string in strings
]
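            # e.g. (hypothetical, with image_token="<|image_pad|>", max_images=2):
            #   "<image> cat <image> dog <image>"
            #     -> roughly "<|image_pad|> cat <|image_pad|> dog " (extras dropped)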
# encode a batch of strings. converts to tensors and pads automatically, unlike tok_encode.
old_padding_side = self.tokenizer.padding_side
self.tokenizer.padding_side = padding_side
# add_special_tokens = {"add_special_tokens": False or self.add_bos_token}
images = [img[: self.max_images] for img in images]
if self.rgb:
images = [[img.convert("RGB") for img in sublist] for sublist in images]
# certain models like llava expect a single-level image list even for bs>1, multi-image. TODO: port this over to loglikelihoods
if getattr(self.config, "model_type", "") == "llava":
images = flatten_image_list(images)
encoding = self.processor(
images=images,
text=strings,
truncation=truncation,
padding="longest",
return_tensors="pt",
# **add_special_tokens, # TODO: at least some Processors error out when passing this. How do we control whether text gets BOS added?
)
encoding.to( # TODO: our other tokenization methods in HFLM don't typically move to device. this breaks convention
self.device, self.model.dtype
) # TODO: This only casts the pixel values. Should they always be float16?
if left_truncate_len:
encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:]
encoding["attention_mask"] = encoding["attention_mask"][
:, -left_truncate_len:
]
self.tokenizer.padding_side = old_padding_side
return encoding
def _model_multimodal_call(self, inps, imgs, attn_mask=None, labels=None):
"""
TODO: update docstring
"""
# note: imgs is a dict.
with torch.no_grad():
return self.model(inps, **imgs).logits
def _model_multimodal_generate(self, inputs, max_length, stop, **generation_kwargs):
generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0)
do_sample = generation_kwargs.get("do_sample", None)
# The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies
if generation_kwargs.get("temperature") == 0.0 and do_sample is None:
generation_kwargs["do_sample"] = do_sample = False
if do_sample is False and generation_kwargs.get("temperature") == 0.0:
generation_kwargs.pop("temperature")
stopping_criteria = stop_sequences_criteria(
self.tokenizer,
stop,
inputs["input_ids"].shape[1],
inputs["input_ids"].shape[0],
)
return self.model.generate(
**inputs,
max_length=max_length,
stopping_criteria=stopping_criteria,
pad_token_id=self.tokenizer.pad_token_id,
use_cache=True,
**generation_kwargs,
)
def _batch_images(self, image_encs):
"""
Helper function: batch together image encodings across examples in a batch.
# TODO: for variable-sized images, this may break down.
"""
batched_imgs = {}
for key in image_encs[0].keys():
batched_imgs[key] = torch.cat(
[
torch.tensor(
image_enc[key], device=self.device, dtype=self.model.dtype
)
for image_enc in image_encs
],
dim=0,
)
return batched_imgs
def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
raise NotImplementedError(
"model type `hf-multimodal` does not support loglikelihood_rolling. Use 'hf' model type for text-only loglikelihood_rolling tasks ",
"this is because we do not support measuring the loglikelihood a model assigns to an image.",
)
def loglikelihood(
self, requests: List[Instance], disable_tqdm: bool = False
) -> List[Tuple[float, bool]]:
raise NotImplementedError(
"'loglikelihood' requests for model type `hf-multimodal` are not yet tested. This feature will be enabled when a loglikelihood-based multiple-choice VQA dataset is added!"
)
new_reqs = []
for context, continuation, aux_arguments in [req.args for req in requests]:
if context == "":
raise ValueError(
"Must get non-empty context for multimodal requests! You might be trying to run 'loglikelihood_rolling', which is not supported in the multimodal case."
)
else:
visuals = aux_arguments["visual"]
context_enc, continuation_enc, image_enc = self._encode_multimodal_pair(
context, continuation, visuals
)
# TODO: key to pick for caching images
new_reqs.append(
(
(context, continuation, visuals),
context_enc,
continuation_enc,
image_enc,
)
)
return self._loglikelihood_tokens(new_reqs, disable_tqdm=disable_tqdm)
def _loglikelihood_tokens(
self,
requests: List[
Tuple[Tuple[None, str, str], List[int], List[int], List[int]]
], # TODO: update typehint to be correct
disable_tqdm: bool = False,
override_bs: int = None,
) -> List[Tuple[float, bool]]:
res = []
# TODO: **improve multimodal collation.** We currently ignore image size when ordering docs. ideally we'd take them into account
def _collate(req: Tuple[Tuple[str, str], List[int], List[int]]):
"""Defines the key for the sorted method"""
# the negative sign on len(toks) sorts descending - this has a few advantages:
# - time estimates will always be over not underestimates, which is more useful for planning
# - to know the size of a batch when going through the list, you know the first one is always the batch
# padded context length. this is useful to simplify the batching logic and more importantly to make
# automatic adaptive batches much much easier to implement
# - any OOMs will happen right away rather than near the end
toks = req[1] + req[2]
return -len(toks), tuple(toks)
def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):
"""Defines the key to group and lookup one-token continuations"""
            # Use with group_by="contexts" (optional)
            # allows for the creation of a lookup, so we can reuse logits in case of one-token continuations.
            # speeds up some multiple-choice tasks proportionally to the number of choices.
            # groups requests by context + continuation[:-1] and infers on one request per group.
return req[-1] + req[-3] + req[-2][:-1]
re_ord = Collator(
requests,
sort_fn=_collate,
group_by="contexts" # TODO: can't group-by just "contexts" any more, need to incorporate imgs
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
and self.logits_cache
else None,
group_fn=_lookup_one_token_cont,
)
# automatic (variable) batch size detection for vectorization
# pull longest context sample from request
n_reordered_requests = len(re_ord)
batch_size = (
self.batch_size
if self.batch_size != "auto"
else override_bs
if override_bs is not None
else 0
)
batch_fn = (
self._batch_scheduler
if self.batch_size == "auto"
and n_reordered_requests > 0
and not override_bs
else None
)
chunks = re_ord.get_batched(n=batch_size, batch_fn=batch_fn)
pbar = tqdm(
total=len(requests),
disable=(disable_tqdm or (self.rank != 0)),
desc="Running loglikelihood requests with text+image input",
)
for chunk in chunks:
imgs = []
inps = []
cont_toks_list = []
inplens = []
padding_len_inp = None
# because vectorizing is annoying, we first convert each (context, continuation) pair to padded
# tensors, then we pack them together into a batch, call the model, and then pick it all apart
# again because vectorizing is annoying
for _, context_enc, continuation_enc, image_enc in chunk:
# sanity check
assert len(image_enc) > 0
assert len(context_enc) > 0
assert len(continuation_enc) > 0
assert len(continuation_enc) <= self.max_length
# how this all works (illustrated on a causal decoder-only setup):
# CTX CONT
# inp 0 1 2 3|4 5 6 7 8 9 <- last token is deleted by inp[:, :-1]
# model \ \
# logits 1 2 3|4 5 6 7 8 9 <- the ctx half gets tossed out by the
# cont_toks 4 5 6 7 8 9 [:, -len(continuation_enc):, :self.vocab_size] slice
# when too long to fit in context, truncate from the left
# TODO: assuming that we won't handle enc-dec Vision2Seq models. Is that a safe assumption?
inp = torch.tensor(
(context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
dtype=torch.long,
device=self.device,
)
(inplen,) = inp.shape
padding_len_inp = (
max(padding_len_inp, inplen)
if padding_len_inp is not None
else inplen
)
inps.append(inp) # [1, inp_length]
cont_toks_list.append(continuation_enc)
inplens.append(inplen)
imgs.append(image_enc)
# create encoder attn mask and batched conts, if seq2seq
call_kwargs = {}
batched_inps = pad_and_concat(
padding_len_inp, inps, padding_side="right"
) # [batch, padding_len_inp]
# batch our examples' image inputs together
batched_imgs = self._batch_images(
imgs
) # TODO: fix/test for bs>1 case with differently-sized imgs!
multi_logits = F.log_softmax(
self._model_multimodal_call(batched_inps, batched_imgs, **call_kwargs),
dim=-1,
) # [batch, padding_length (inp or cont), vocab]
for (
request_str,
ctx_tokens,
_,
image_encs,
), logits, inplen, cont_toks in zip(
chunk, multi_logits, inplens, cont_toks_list
):
# Slice to original seq length
contlen = len(cont_toks)
# take only logits in the continuation
# (discard context toks if decoder-only ; discard right-padding)
# also discards + checks for "virtual tokens" in the causal LM's input window
# from prompt/prefix tuning tokens, if applicable
ctx_len = (
inplen + (logits.shape[0] - padding_len_inp)
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
else None
)
logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len)
logits = logits.unsqueeze(0) # [1, seq, vocab]
# Check if per-token argmax is exactly equal to continuation
greedy_tokens = logits.argmax(dim=-1)
# check for one-token continuation cache hits.
# noop in case group_by != "contexts" or no cache hit and returns the
# original args. Otherwise, expands the logits batch dimension and yields each
# batch along with matching continuation tokens and prompt strings.
# logits -> [1, seq, vocab]
for request_str, cont_toks, logits in re_ord.get_cache(
req_str=request_str,
cxt_toks=ctx_tokens,
cont_toks=cont_toks,
logits=logits,
):
cont_toks = torch.tensor(
cont_toks, dtype=torch.long, device=self.device
).unsqueeze(0) # [1, seq]
max_equal = (greedy_tokens == cont_toks).all()
# Obtain log-probs at the corresponding continuation token indices
# last_token_slice = logits[:, -1, :].squeeze(0).tolist()
logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(
-1
) # [1, seq]
# Answer: (log prob, is-exact-match)
answer = (float(logits.sum()), bool(max_equal))
res.append(answer)
self.cache_hook.add_partial(
"loglikelihood", request_str, answer
) # TODO: choose convention for adding images into the cache key
pbar.update(1)
pbar.close()
return re_ord.get_original(res)
def generate_until(
self, requests: List[Instance], disable_tqdm: bool = False
) -> List[str]:
# TODO: back out to HFLM.generate_until() for all requests without aux_arguments (text-only reqs)
res = []
def _collate(x):
# the negative sign on len(toks) sorts descending - this has a few advantages:
# - time estimates will always be over not underestimates, which is more useful for planning
# - to know the size of a batch when going through the list, you know the first one is always the batch
# padded context length. this is useful to simplify the batching logic and more importantly to make
# automatic adaptive batches much much easier to implement
# - any OOMs will happen right away rather than near the end
toks = self.tok_encode(x[0])
return -len(toks), x[0]
pbar = tqdm(
total=len(requests),
disable=(disable_tqdm or (self.rank != 0)),
desc="Running generate_until requests with text+image input",
)
# TODO: port auto-batch sizing into this.
# we group requests by their generation_kwargs,
# so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
# in the same batch.
re_ords = Collator(
[reg.args for reg in requests],
_collate,
group_by="gen_kwargs",
group_fn=lambda x: x[1],
)
chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
### Up to here: was identical to non-multimodal HFLM generate_until ###
eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False)
for chunk in chunks:
contexts, all_gen_kwargs, aux_arguments = zip(*chunk)
visuals = [arg["visual"] for arg in aux_arguments]
if not isinstance(contexts, list):
contexts = list(
contexts
) # for Qwen2-VL, processor is unhappy accepting a tuple of strings instead of a list.
# TODO: could we upstream this workaround to HF?
### this part onward: same as HFLM ###
# we assume all gen kwargs in the batch are the same
# this is safe to assume because the `grouper` object ensures it.
gen_kwargs = all_gen_kwargs[0]
# unpack our keyword arguments.
if isinstance(gen_kwargs, dict):
kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1
# add EOS token to stop sequences
until = handle_stop_sequences(kwargs.pop("until", None), eos=eos)
else:
raise ValueError(
f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
)
if "max_gen_toks" in kwargs.keys():
max_gen_toks = kwargs.pop("max_gen_toks")
else:
max_gen_toks = self.max_gen_toks
### end stuff that's entirely copied verbatim from HFLM ###
max_ctx_len = self.max_length - max_gen_toks
inputs = self.tok_batch_multimodal_encode(
contexts,
visuals,
left_truncate_len=max_ctx_len,
truncation=self.truncation,
)
context_enc = inputs["input_ids"]
if "max_length" not in kwargs:
kwargs["max_length"] = context_enc.shape[1] + max_gen_toks
cont = self._model_multimodal_generate(inputs, stop=until, **kwargs)
del inputs
torch.cuda.empty_cache()
            gc.collect()
### essentially same as HFLM beyond this line!
cont_toks_list = cont.tolist()
for cont_toks, context in zip(cont_toks_list, contexts):
# discard context + left-padding toks if using causal decoder-only VLM
cont_toks = cont_toks[context_enc.shape[1] :]
s = self.tok_decode(cont_toks)
# use secondary stop seqs to cut off should-have-been-stopped content post-hoc
for term in until:
if len(term) > 0:
# ignore '' separator,
# for seq2seq case where self.tok_decode(self.eot_token_id) = ''
s = s.split(term)[0]
res.append(s)
self.cache_hook.add_partial(
"generate_until", (context, gen_kwargs), s
) # TODO: cache key for multimodal input should be what?
pbar.update(1)
# reorder this group of results back to original unsorted form
res = re_ords.get_original(res)
pbar.close()
return res
...@@ -4,15 +4,16 @@ from datetime import timedelta ...@@ -4,15 +4,16 @@ from datetime import timedelta
from pathlib import Path from pathlib import Path
from typing import Dict, List, Literal, Optional, Tuple, Union from typing import Dict, List, Literal, Optional, Tuple, Union
import jinja2
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import transformers import transformers
from accelerate import ( from accelerate import (
Accelerator, Accelerator,
DistributedType,
InitProcessGroupKwargs, InitProcessGroupKwargs,
find_executable_batch_size, find_executable_batch_size,
) )
from accelerate.utils import get_max_memory
from huggingface_hub import HfApi from huggingface_hub import HfApi
from packaging import version from packaging import version
from peft import PeftModel from peft import PeftModel
...@@ -30,7 +31,9 @@ from lm_eval.api.registry import register_model ...@@ -30,7 +31,9 @@ from lm_eval.api.registry import register_model
from lm_eval.models.utils import ( from lm_eval.models.utils import (
Collator, Collator,
clear_torch_cache, clear_torch_cache,
configure_pad_token,
get_dtype, get_dtype,
handle_stop_sequences,
pad_and_concat, pad_and_concat,
stop_sequences_criteria, stop_sequences_criteria,
) )
...@@ -39,31 +42,6 @@ from lm_eval.models.utils import ( ...@@ -39,31 +42,6 @@ from lm_eval.models.utils import (
eval_logger = utils.eval_logger eval_logger = utils.eval_logger
def _get_accelerate_args(
device_map_option: Optional[str] = "auto",
max_memory_per_gpu: Optional[Union[int, str]] = None,
max_cpu_memory: Optional[Union[int, str]] = None,
offload_folder: Optional[str] = "./offload",
gpus: Optional[int] = None,
) -> dict:
"""Returns the kwargs needed to apply `accelerate` in `AutoModel.from_pretrained`."""
max_memory = {}
if max_memory_per_gpu is not None:
max_memory_per_gpu_map = {
device_idx: max_memory_per_gpu for device_idx in range(gpus)
}
max_memory.update(max_memory_per_gpu_map)
if max_cpu_memory is not None:
max_memory["cpu"] = max_cpu_memory
args = {}
if max_memory:
args["max_memory"] = max_memory
args["device_map"] = device_map_option
args["offload_folder"] = offload_folder
return args
@register_model("hf-auto", "hf", "huggingface") @register_model("hf-auto", "hf", "huggingface")
class HFLM(TemplateLM): class HFLM(TemplateLM):
""" """
...@@ -79,7 +57,7 @@ class HFLM(TemplateLM): ...@@ -79,7 +57,7 @@ class HFLM(TemplateLM):
def __init__( def __init__(
self, self,
pretrained: Union[str, transformers.PreTrainedModel], pretrained: Union[str, transformers.PreTrainedModel],
backend: Optional[Literal["default", "causal", "seq2seq"]] = "default", backend: Literal["default", "causal", "seq2seq"] = "default",
# override whether the model should be treated as decoder-only (causal) or encoder-decoder (seq2seq) # override whether the model should be treated as decoder-only (causal) or encoder-decoder (seq2seq)
revision: Optional[str] = "main", revision: Optional[str] = "main",
subfolder: Optional[str] = None, subfolder: Optional[str] = None,
...@@ -104,7 +82,6 @@ class HFLM(TemplateLM): ...@@ -104,7 +82,6 @@ class HFLM(TemplateLM):
# arguments used for splitting a model across GPUs naively. # arguments used for splitting a model across GPUs naively.
# only used if `parallelize=True`. # only used if `parallelize=True`.
parallelize: Optional[bool] = False, parallelize: Optional[bool] = False,
device_map_option: Optional[str] = "auto",
max_memory_per_gpu: Optional[Union[int, str]] = None, max_memory_per_gpu: Optional[Union[int, str]] = None,
max_cpu_memory: Optional[Union[int, str]] = None, max_cpu_memory: Optional[Union[int, str]] = None,
offload_folder: Optional[Union[str, os.PathLike]] = "./offload", offload_folder: Optional[Union[str, os.PathLike]] = "./offload",
...@@ -112,10 +89,11 @@ class HFLM(TemplateLM): ...@@ -112,10 +89,11 @@ class HFLM(TemplateLM):
peft: Optional[str] = None, peft: Optional[str] = None,
delta: Optional[str] = None, delta: Optional[str] = None,
autogptq: Optional[Union[bool, str]] = False, autogptq: Optional[Union[bool, str]] = False,
gptqmodel: Optional[bool] = False,
gguf_file: Optional[str] = None,
**kwargs, **kwargs,
) -> None: ) -> None:
super().__init__() super().__init__()
# optionally: take in an already-initialized transformers.PreTrainedModel # optionally: take in an already-initialized transformers.PreTrainedModel
if not isinstance(pretrained, str): if not isinstance(pretrained, str):
eval_logger.warning( eval_logger.warning(
...@@ -127,21 +105,6 @@ class HFLM(TemplateLM): ...@@ -127,21 +105,6 @@ class HFLM(TemplateLM):
self._config = self._model.config self._config = self._model.config
gpus = 0 gpus = 0
if tokenizer:
assert isinstance(
tokenizer, transformers.PreTrainedTokenizer
) or isinstance(tokenizer, transformers.PreTrainedTokenizerFast)
self.tokenizer = tokenizer
else:
# Get tokenizer
model_name = self._model.name_or_path
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
model_name,
revision=revision,
trust_remote_code=trust_remote_code,
use_fast=use_fast_tokenizer,
)
else: else:
assert isinstance(device, str) assert isinstance(device, str)
assert isinstance(pretrained, str) assert isinstance(pretrained, str)
...@@ -156,6 +119,7 @@ class HFLM(TemplateLM): ...@@ -156,6 +119,7 @@ class HFLM(TemplateLM):
if "npu" in accelerator.device.type: if "npu" in accelerator.device.type:
gpus = torch.npu.device_count() gpus = torch.npu.device_count()
# using one process with no model parallelism
if not (parallelize or accelerator.num_processes > 1): if not (parallelize or accelerator.num_processes > 1):
# use user-passed device # use user-passed device
device_list = set( device_list = set(
...@@ -181,14 +145,19 @@ class HFLM(TemplateLM): ...@@ -181,14 +145,19 @@ class HFLM(TemplateLM):
if torch.cuda.is_available() if torch.cuda.is_available()
else torch.device("cpu") else torch.device("cpu")
) )
else: else: # Parallelism managed by accelerate
if device != "cuda": if device != "cuda":
eval_logger.info( eval_logger.info(
f"Using `accelerate launch` or `parallelize=True`, device '{device}' will be overridden when placing model." f"Using `accelerate launch` or `parallelize=True`, device '{device}' will be overridden when placing model."
) )
# TODO: include in warning that `load_in_8bit` etc. affect this too # TODO: include in warning that `load_in_8bit` etc. affect this too
self._device = torch.device(device) self._device = (
self.accelerator.device
if hasattr(self, "accelerator")
else torch.device(device)
)
revision = str(revision) # cast to string if not already one
# TODO: update this to be less of a hack once subfolder is fixed in HF # TODO: update this to be less of a hack once subfolder is fixed in HF
revision = revision + ("/" + subfolder if subfolder is not None else "") revision = revision + ("/" + subfolder if subfolder is not None else "")
...@@ -196,9 +165,10 @@ class HFLM(TemplateLM): ...@@ -196,9 +165,10 @@ class HFLM(TemplateLM):
pretrained, pretrained,
revision=revision, revision=revision,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
gguf_file=gguf_file,
) )
# determine which of 'causal' and 'seq2seq' backends to use # determine which of 'causal' and 'seq2seq' backends to use for HF models
self._get_backend( self._get_backend(
config=self.config, backend=backend, trust_remote_code=trust_remote_code config=self.config, backend=backend, trust_remote_code=trust_remote_code
) )
...@@ -210,6 +180,7 @@ class HFLM(TemplateLM): ...@@ -210,6 +180,7 @@ class HFLM(TemplateLM):
revision=revision, revision=revision,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
use_fast_tokenizer=use_fast_tokenizer, use_fast_tokenizer=use_fast_tokenizer,
gguf_file=gguf_file,
) )
# if we passed `pretrained` as a string, initialize our model now # if we passed `pretrained` as a string, initialize our model now
...@@ -221,13 +192,14 @@ class HFLM(TemplateLM): ...@@ -221,13 +192,14 @@ class HFLM(TemplateLM):
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
parallelize=parallelize, parallelize=parallelize,
gpus=gpus, gpus=gpus,
device_map_option=device_map_option,
max_memory_per_gpu=max_memory_per_gpu, max_memory_per_gpu=max_memory_per_gpu,
max_cpu_memory=max_cpu_memory, max_cpu_memory=max_cpu_memory,
offload_folder=offload_folder, offload_folder=offload_folder,
peft=peft, peft=peft,
delta=delta, delta=delta,
autogptq=autogptq, autogptq=autogptq,
gptqmodel=gptqmodel,
gguf_file=gguf_file,
**kwargs, **kwargs,
) )
...@@ -236,52 +208,17 @@ class HFLM(TemplateLM): ...@@ -236,52 +208,17 @@ class HFLM(TemplateLM):
self.model.eval() self.model.eval()
self.model.tie_weights() self.model.tie_weights()
if isinstance(pretrained, str) and (gpus >= 1 or str(self.device) == "mps"):
# TODO: can remove this whole snippet except in the mps case, perhaps?
if not (parallelize or autogptq or hasattr(self, "accelerator")):
# place model onto device requested manually,
# if not using HF Accelerate or device_map
# or any other option that preloads model onto device
try:
self.model.to(self.device)
except ValueError:
eval_logger.debug(
"Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes` or `device_map` is provided. If the desired GPU is being used, this message is safe to ignore."
)
self.truncation = truncation self.truncation = truncation
self.logits_cache = logits_cache self.logits_cache = logits_cache
self.vocab_size = self.tokenizer.vocab_size self.vocab_size = self.tokenizer.vocab_size
# select (or create) a pad token to use # select (or create) a pad token to use
if self.tokenizer.pad_token: self.tokenizer = configure_pad_token(self.tokenizer, model_config=self.config)
pass
elif self.tokenizer.unk_token:
self.tokenizer.pad_token_id = self.tokenizer.unk_token_id
elif self.tokenizer.eos_token:
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
else:
if getattr(self.config, "model_type", None) == "qwen":
# Qwen's trust_remote_code tokenizer does not allow for adding special tokens
self.tokenizer.pad_token = "<|endoftext|>"
elif (
self.tokenizer.__class__.__name__ == "RWKVWorldTokenizer"
or self.tokenizer.__class__.__name__ == "Rwkv5Tokenizer"
):
# The RWKV world tokenizer, does not allow for adding special tokens / setting the pad token (which is set as 0)
# The additional tokenizer name check is needed, as there exists rwkv4 models with neox tokenizer
# ---
# Note that the world tokenizer class name, might change in the future for the final huggingface merge
# https://github.com/huggingface/transformers/pull/26963
assert self.tokenizer.pad_token_id == 0
else:
self.tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
# TODO: override this for Gemma
self.add_bos_token = add_bos_token self.add_bos_token = add_bos_token
if getattr(self.config, "model_type", None) == "gemma": if "gemma" in getattr(self.config, "model_type", ""):
self.add_bos_token = True self.add_bos_token = True
eval_logger.info( eval_logger.info(
f"Model type is '{self.config.model_type}', a BOS token will be used as Gemma underperforms without it." f"Model type is '{self.config.model_type}', part of the Gemma family--a BOS token will be used as Gemma underperforms without it."
) )
self._max_length = max_length self._max_length = max_length
...@@ -301,49 +238,46 @@ class HFLM(TemplateLM): ...@@ -301,49 +238,46 @@ class HFLM(TemplateLM):
self.batch_size_per_gpu = int(batch_size) self.batch_size_per_gpu = int(batch_size)
if isinstance(pretrained, str): if isinstance(pretrained, str):
if gpus >= 1 or str(self.device) == "mps":
# TODO: can remove this whole snippet except in the mps case, perhaps?
if not (parallelize or autogptq or hasattr(self, "accelerator")):
# place model onto device requested manually,
# if not using HF Accelerate or device_map
# or any other option that preloads model onto device
try:
self.model.to(self.device)
except ValueError:
eval_logger.debug(
"Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes` or `device_map` is provided. If the desired GPU is being used, this message is safe to ignore."
)
# multigpu data-parallel support when launched with accelerate # multigpu data-parallel support when launched with accelerate
if gpus > 1: if gpus > 1:
if parallelize: if accelerator.num_processes > 1:
if accelerator.num_processes > 1: if parallelize:
raise RuntimeError( eval_logger.warning(
"Attempted to use both a HF Accelerate `device_map` and to launch via `accelerate launch`. If this is the case, please either remove `parallelize=True` from --model_args or launch outside of the Accelerate launcher." "You are both using a HF Accelerate `device_map` (`--model_args parallelize=True`) and launching via `accelerate launch`. This will attempt to do model and data parallelism depending on the resources available."
) )
else: elif gpus > accelerator.num_processes:
pass
elif accelerator.num_processes == 1:
# if we aren't launching via accelerate, ditch
self._rank = 0
self._world_size = 1
else:
if gpus > accelerator.num_processes:
eval_logger.warning( eval_logger.warning(
"WARNING: The number of total system GPUs does not match the number of spawned processes. " "WARNING: The number of total system GPUs does not match the number of spawned processes. "
"If you would like to use data parallelism, please launch the script " "If you would like to use data parallelism, please launch the script "
"with 'accelerate launch *script*'. " "with 'accelerate launch *script*'. "
f"Current run will proceed with {accelerator.num_processes} devices." f"Current run will proceed with {accelerator.num_processes} devices."
) )
assert ( if self.accelerator.is_local_main_process:
accelerator.distributed_type eval_logger.info(
in [ f"Using {gpus} devices with data parallelism"
DistributedType.FSDP, )
DistributedType.MULTI_GPU,
DistributedType.MULTI_NPU,
]
), "Unsupported distributed type provided. Only DDP and FSDP are supported."
if accelerator.distributed_type == DistributedType.FSDP:
self._model = accelerator.prepare(self.model)
else:
self._model = accelerator.prepare_model(
self.model, evaluation_mode=True
)
self._device = torch.device(f"{accelerator.device}") self._device = torch.device(f"{accelerator.device}")
self.accelerator = accelerator self.accelerator = accelerator
if self.accelerator.is_local_main_process:
eval_logger.info(f"Using {gpus} devices with data parallelism")
self._rank = self.accelerator.local_process_index self._rank = self.accelerator.local_process_index
self._world_size = self.accelerator.num_processes self._world_size = self.accelerator.num_processes
else:
# if we aren't launching via accelerate, ditch
self._rank = 0
self._world_size = 1
else: else:
# if a PreTrainedModel was passed into HFLM, we forgo distributed setup. # if a PreTrainedModel was passed into HFLM, we forgo distributed setup.
eval_logger.warning( eval_logger.warning(
...@@ -358,6 +292,94 @@ class HFLM(TemplateLM): ...@@ -358,6 +292,94 @@ class HFLM(TemplateLM):
f"Loglikelihood prefix token id used in evaluation: {self.prefix_token_id}" f"Loglikelihood prefix token id used in evaluation: {self.prefix_token_id}"
) )
def _get_accelerate_args(
self,
parallelize: Optional[bool] = None,
device_map: Optional[str] = "auto",
max_memory_per_gpu: Optional[Union[int, str]] = None,
max_cpu_memory: Optional[Union[int, str]] = None,
offload_folder: Optional[str] = "./offload",
gpus: Optional[int] = None,
) -> dict:
"""Returns the kwargs needed to apply `accelerate` in `AutoModel.from_pretrained`."""
num_local_processes = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
num_machines = int(os.environ.get("WORLD_SIZE", 0)) // num_local_processes
if (
num_machines == 0
and hasattr(self, "accelerator")
and self.accelerator is not None
):
eval_logger.info(
"We are not in a distributed setting for accelerate. Setting model_parallel to False."
)
parallelize = False
if parallelize is None:
# If parallelism is unset by the user, we automatically assign model parallelism
# if enough extra GPUs are available
max_memory_all_gpus = get_max_memory()
# We just want gpu, not cpu, max memory
if "cpu" in max_memory_all_gpus:
del max_memory_all_gpus["cpu"]
parallelize = bool(num_local_processes < len(max_memory_all_gpus))
eval_logger.info(
f"Setting model parallel to {parallelize} since "
f"the number of local processes is {num_local_processes} "
f"and the number of GPUs is {len(max_memory_all_gpus)}"
)
args = {}
if parallelize: # Model parallelism will be used
max_memory = {}
if max_memory_per_gpu is not None: # Using the provided memory requirements
max_memory_per_gpu_map = {
device_idx: max_memory_per_gpu for device_idx in range(gpus)
}
else: # Estimating the possible memory requirements
max_memory_all_gpus = get_max_memory()
if "cpu" in max_memory_all_gpus:
del max_memory_all_gpus["cpu"]
if not hasattr(self, "accelerator"):
max_memory_per_gpu_map = {
k: v for k, v in max_memory_all_gpus.items()
}
else:
# use only 1 / num_processes of the GPUs if we are running under accelerate launch
max_memory_per_gpu_map = {
k: v
for k, v in max_memory_all_gpus.items()
if k % num_local_processes
== (self.accelerator.process_index % num_local_processes)
}
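                    # e.g. (hypothetical): 8 visible GPUs with LOCAL_WORLD_SIZE=2
                    # -> process 0 keeps GPUs {0, 2, 4, 6}, process 1 keeps {1, 3, 5, 7}.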
args["max_memory"] = max_memory_per_gpu_map
args["device_map"] = "auto" if device_map is None else device_map
eval_logger.info(
f"Model parallel was set to True, setting max memory per GPU to {max_memory_per_gpu_map} and device map to {args.get('device_map')}"
)
if max_cpu_memory is not None:
max_memory["cpu"] = max_cpu_memory
args["offload_folder"] = offload_folder
elif (
device_map is None
): # No model parallelism, we use the default provided device for our model
if hasattr(self, "accelerator"):
device_map = {"": f"{self.accelerator.device}"}
else:
device_map = {"": str(self.device)}
args["max_memory"] = None
args["device_map"] = device_map
eval_logger.info(
f"Model parallel was set to False, max memory was not set, and device map was set to {device_map}"
)
else:
args["max_memory"] = None
args["device_map"] = None
eval_logger.info("Model parallel was set to False.")
return args
@property @property
def config(self): def config(self):
# return the associated transformers.AutoConfig for the given pretrained model. # return the associated transformers.AutoConfig for the given pretrained model.
...@@ -423,33 +445,31 @@ class HFLM(TemplateLM): ...@@ -423,33 +445,31 @@ class HFLM(TemplateLM):
def tokenizer_name(self) -> str: def tokenizer_name(self) -> str:
return self.tokenizer.name_or_path.replace("/", "__") return self.tokenizer.name_or_path.replace("/", "__")
@property
def chat_template(self) -> str:
if self.tokenizer.chat_template is not None:
return self.tokenizer.chat_template
return self.tokenizer.default_chat_template
def _get_backend( def _get_backend(
self, self,
config: Union[transformers.PretrainedConfig, transformers.AutoConfig], config: Union[transformers.PretrainedConfig, transformers.AutoConfig],
backend: Optional[Literal["default", "causal", "seq2seq"]] = "default", backend: Literal["default", "causal", "seq2seq"] = "default",
trust_remote_code: Optional[bool] = False, trust_remote_code: Optional[bool] = False,
) -> None: ) -> None:
""" """
Helper method during initialization. Helper method during initialization.
Determines the backend ("causal" (decoder-only) or "seq2seq" (encoder-decoder)) Determines the backend ("causal" (decoder-only) or "seq2seq" (encoder-decoder)) model type to be used.
model type to be used. sets `self.AUTO_MODEL_CLASS` appropriately if not already set.
**If not calling HFLM.__init__() or HFLM._get_backend() within a subclass of HFLM,
user must set `self.backend` to be either "causal" or "seq2seq" manually!**
""" """
assert backend in ["default", "causal", "seq2seq"] assert backend in ["default", "causal", "seq2seq"]
if backend != "default": if backend != "default":
# if we've settled on non-default backend, use that manually # if we've settled on non-default backend, use that manually
if backend == "causal": if backend == "causal":
self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM self.backend = backend
elif backend == "seq2seq": elif backend == "seq2seq":
self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM self.backend = backend
eval_logger.info( eval_logger.info(
f"Overrode HF model backend type, and using type '{backend}'" f"Overrode HF model backend type, and using type '{self.backend}'"
) )
else: else:
# determine and use the default HF backend for this model, based on its config + metadata. # determine and use the default HF backend for this model, based on its config + metadata.
...@@ -460,37 +480,46 @@ class HFLM(TemplateLM): ...@@ -460,37 +480,46 @@ class HFLM(TemplateLM):
# first check if model type is listed under seq2seq models, since some # first check if model type is listed under seq2seq models, since some
# models like MBart are listed in both seq2seq and causal mistakenly in HF transformers. # models like MBart are listed in both seq2seq and causal mistakenly in HF transformers.
# these special cases should be treated as seq2seq models. # these special cases should be treated as seq2seq models.
self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM self.backend = "seq2seq"
eval_logger.debug(f"Using model type '{self.backend}'")
elif ( elif (
getattr(self.config, "model_type") in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES getattr(self.config, "model_type") in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
): ):
self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM self.backend = "causal"
eval_logger.debug(f"Using model type '{self.backend}'")
else: else:
if not trust_remote_code: if not trust_remote_code:
eval_logger.warning( eval_logger.warning(
"HF model type is neither marked as CausalLM or Seq2SeqLM. \ "HF model type is neither marked as CausalLM or Seq2SeqLM. \
This is expected if your model requires `trust_remote_code=True` but may be an error otherwise." This is expected if your model requires `trust_remote_code=True` but may be an error otherwise."
"Setting backend to causal"
) )
# if model type is neither in HF transformers causal or seq2seq model registries # if model type is neither in HF transformers causal or seq2seq model registries
# then we default to AutoModelForCausalLM # then we default to assuming AutoModelForCausalLM
self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM self.backend = "causal"
eval_logger.info(
f"Model type cannot be determined. Using default model type '{self.backend}'"
)
assert self.AUTO_MODEL_CLASS in [ if self.AUTO_MODEL_CLASS is None:
transformers.AutoModelForCausalLM, if self.backend == "causal":
transformers.AutoModelForSeq2SeqLM, self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
] elif self.backend == "seq2seq":
return None self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM
def _get_config( def _get_config(
self, self,
pretrained: str, pretrained: str,
revision: str = "main", revision: str = "main",
trust_remote_code: bool = False, trust_remote_code: bool = False,
gguf_file: Optional[str] = None,
) -> None: ) -> None:
"""Return the model config for HuggingFace models"""
self._config = transformers.AutoConfig.from_pretrained( self._config = transformers.AutoConfig.from_pretrained(
pretrained, pretrained,
revision=revision, revision=revision,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
gguf_file=gguf_file,
) )
def _create_model( def _create_model(
...@@ -504,7 +533,6 @@ class HFLM(TemplateLM): ...@@ -504,7 +533,6 @@ class HFLM(TemplateLM):
# (accelerate naive PP (device_map) options) # (accelerate naive PP (device_map) options)
parallelize: Optional[bool] = False, parallelize: Optional[bool] = False,
gpus: Optional[int] = None, gpus: Optional[int] = None,
device_map_option: Optional[str] = "auto",
max_memory_per_gpu: Optional[Union[int, str]] = None, max_memory_per_gpu: Optional[Union[int, str]] = None,
max_cpu_memory: Optional[Union[int, str]] = None, max_cpu_memory: Optional[Union[int, str]] = None,
offload_folder: Optional[str] = "./offload", offload_folder: Optional[str] = "./offload",
...@@ -512,6 +540,8 @@ class HFLM(TemplateLM): ...@@ -512,6 +540,8 @@ class HFLM(TemplateLM):
peft: Optional[str] = None, peft: Optional[str] = None,
delta: Optional[str] = None, delta: Optional[str] = None,
autogptq: Optional[Union[bool, str]] = False, autogptq: Optional[Union[bool, str]] = False,
gptqmodel: Optional[bool] = False,
gguf_file: Optional[str] = None,
**kwargs, **kwargs,
) -> None: ) -> None:
""" """
...@@ -528,27 +558,18 @@ class HFLM(TemplateLM): ...@@ -528,27 +558,18 @@ class HFLM(TemplateLM):
model_kwargs = kwargs if kwargs else {} model_kwargs = kwargs if kwargs else {}
if parallelize: model_kwargs.update(
model_kwargs.update( self._get_accelerate_args(
_get_accelerate_args( parallelize=parallelize,
device_map_option, # TODO: phase out device_map_option? device_map=kwargs.get("device_map", None),
max_memory_per_gpu, max_memory_per_gpu=max_memory_per_gpu,
max_cpu_memory, max_cpu_memory=max_cpu_memory,
offload_folder, offload_folder=offload_folder,
gpus, gpus=gpus,
)
) )
elif "device_map" not in model_kwargs: )
# set a device_map to initialize model on the right GPU.
# this is needed because it seems that the default behavior
# for quantized models now seems to be device_map="auto"
# which breaks data-parallel mode.
if hasattr(self, "accelerator"):
model_kwargs.update({"device_map": {"": f"{self.accelerator.device}"}})
else:
model_kwargs.update({"device_map": {"": str(self.device)}})
if not autogptq: if not autogptq and not gptqmodel:
if model_kwargs.get("load_in_4bit", None): if model_kwargs.get("load_in_4bit", None):
assert ( assert (
transformers.__version__ >= "4.30.0" transformers.__version__ >= "4.30.0"
...@@ -559,31 +580,52 @@ class HFLM(TemplateLM): ...@@ -559,31 +580,52 @@ class HFLM(TemplateLM):
model_kwargs["bnb_4bit_compute_dtype"] = get_dtype( model_kwargs["bnb_4bit_compute_dtype"] = get_dtype(
model_kwargs["bnb_4bit_compute_dtype"] model_kwargs["bnb_4bit_compute_dtype"]
) )
self._model = self.AUTO_MODEL_CLASS.from_pretrained( self._model = self.AUTO_MODEL_CLASS.from_pretrained(
pretrained, pretrained,
revision=revision, revision=revision,
torch_dtype=get_dtype(dtype), torch_dtype=get_dtype(dtype),
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
gguf_file=gguf_file,
**model_kwargs, **model_kwargs,
) )
        else:
            if autogptq and gptqmodel:
                raise ValueError(
                    "Cannot use both 'autogptq' and 'gptqmodel' options at the same time."
                )

            if autogptq:
                try:
                    from auto_gptq import AutoGPTQForCausalLM
                except ModuleNotFoundError as exception:
                    raise type(exception)(
                        "Tried to load auto_gptq, but auto-gptq is not installed ",
                        "please install auto-gptq via pip install lm-eval[gptq] or pip install -e .[gptq]",
                    )

                self._model = AutoGPTQForCausalLM.from_quantized(
                    pretrained,
                    trust_remote_code=trust_remote_code,
                    model_basename=None if autogptq is True else Path(autogptq).stem,
                    use_safetensors=True
                    if autogptq is True
                    else autogptq.endswith(".safetensors"),
                    **model_kwargs,
                )

            if gptqmodel:
                try:
                    from gptqmodel import GPTQModel
                except ModuleNotFoundError as exception:
                    raise type(exception)(
                        "Tried to load gptqmodel, but gptqmodel is not installed ",
                        "please install gptqmodel via `pip install gptqmodel --no-build-isolation` or `pip install lm-eval[gptqmodel] --no-build-isolation`",
                    )

                self._model = GPTQModel.from_quantized(
                    pretrained, trust_remote_code=trust_remote_code, **model_kwargs
                )
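For reference, a hedged sketch of how one of the two quantized-loading paths above might be selected from user code via the evaluator API; the checkpoint name is a placeholder, and exactly one of `autogptq`/`gptqmodel` should be set since the code above rejects both at once:

# Sketch only: placeholder checkpoint; flag names come from the signature in this diff.
import lm_eval

results = lm_eval.simple_evaluate(
    model="hf",
    model_args="pretrained=some-org/some-model-GPTQ,gptqmodel=True",
    tasks=["lambada_openai"],
)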
        if peft and delta:
            raise ValueError(
...@@ -596,10 +638,10 @@ class HFLM(TemplateLM):
                raise AssertionError("load_in_4bit requires peft >= 0.4.0")
            if self._model.config.vocab_size != len(self.tokenizer):
                # resize model for LoRAs with added tokens
                eval_logger.info(
                    f"Model config indicates vocab_size='{self._model.config.vocab_size}', but found tokenizer with vocab size '{len(self.tokenizer)}'. Resizing model embedding layer..."
                )
                self._model.resize_token_embeddings(len(self.tokenizer))
            self._model = PeftModel.from_pretrained(
                self._model, peft, revision=revision
            )
...@@ -642,6 +684,7 @@ class HFLM(TemplateLM):
        revision: Optional[str] = "main",
        trust_remote_code: Optional[bool] = False,
        use_fast_tokenizer: Optional[bool] = True,
        gguf_file: Optional[str] = None,
    ) -> None:
        """
        Helper method during initialization.
...@@ -649,14 +692,21 @@ class HFLM(TemplateLM):
        Create a tokenizer object corresponding to the correct
        tokenizer for value of `pretrained`, or use the pre-initialized tokenizer passed.
        """
        kwargs = {
            "revision": revision,
            "trust_remote_code": trust_remote_code,
        }

        # gguf format embeds tokenizer and is not compatible with hf tokenizer `use_fast` param
        if gguf_file is not None:
            kwargs["gguf_file"] = gguf_file
        else:
            kwargs["use_fast"] = use_fast_tokenizer
        if tokenizer:
            if isinstance(tokenizer, str):
                self.tokenizer = transformers.AutoTokenizer.from_pretrained(
                    tokenizer, **kwargs
                )
            else:
                assert isinstance(
...@@ -671,10 +721,7 @@ class HFLM(TemplateLM):
            # get the HF hub name via accessor on model
            model_name = self.model.name_or_path
            self.tokenizer = transformers.AutoTokenizer.from_pretrained(
                model_name, **kwargs
            )
        return None
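The kwargs split above exists because GGUF checkpoints embed their tokenizer inside the .gguf file, so the fast/slow switch does not apply. A small sketch of the two resulting `AutoTokenizer` calls, assuming a transformers release with GGUF support; the repo and file names are placeholders:

from transformers import AutoTokenizer

# Regular HF repo: the use_fast switch is meaningful.
tok_hf = AutoTokenizer.from_pretrained(
    "EleutherAI/pythia-160m", revision="main", use_fast=True
)

# GGUF checkpoint: tokenizer data lives inside the .gguf file itself,
# so gguf_file is passed instead of use_fast (placeholder names).
tok_gguf = AutoTokenizer.from_pretrained(
    "some-org/some-model-GGUF", gguf_file="model-q4_k_m.gguf"
)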
...@@ -694,7 +741,7 @@ class HFLM(TemplateLM):
        # if OOM, then halves batch_size and tries again
        @find_executable_batch_size(starting_batch_size=self.max_batch_size)
        def forward_batch(batch_size):
            if self.backend == "seq2seq":
                length = max(max_context_enc, max_cont_enc)
                batched_conts = torch.ones(
                    (batch_size, length), device=self.device
...@@ -745,7 +792,7 @@ class HFLM(TemplateLM):
        # by default for CausalLM - false or self.add_bos_token is set
        if add_special_tokens is None:
            if self.backend == "causal":
                special_tokens_kwargs = {
                    "add_special_tokens": False or self.add_bos_token
                }
...@@ -773,7 +820,7 @@ class HFLM(TemplateLM):
        self.tokenizer.padding_side = padding_side
        add_special_tokens = {}
        if self.backend == "causal":
            add_special_tokens = {"add_special_tokens": False or self.add_bos_token}
        encoding = self.tokenizer(
...@@ -784,6 +831,12 @@ class HFLM(TemplateLM):
            **add_special_tokens,
        )
        if left_truncate_len:
            original_lengths = encoding["input_ids"].size(1)
            if original_lengths > left_truncate_len:
                eval_logger.warn(
                    f"Left truncation applied. Original sequence length was {original_lengths}, "
                    f"truncating to last {left_truncate_len} tokens. Some content will be lost.",
                )
            encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:]
            encoding["attention_mask"] = encoding["attention_mask"][
                :, -left_truncate_len:
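To make the truncation direction concrete, here is a tiny standalone sketch of the slicing used above; the tensor values are made up:

import torch

# Toy batch: one sequence of 10 token ids.
input_ids = torch.arange(10).unsqueeze(0)        # shape [1, 10]
attention_mask = torch.ones_like(input_ids)

left_truncate_len = 6
if input_ids.size(1) > left_truncate_len:
    # keep only the last `left_truncate_len` positions, mirroring the slice above
    input_ids = input_ids[:, -left_truncate_len:]
    attention_mask = attention_mask[:, -left_truncate_len:]

print(input_ids.tolist())   # [[4, 5, 6, 7, 8, 9]]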
...@@ -851,14 +904,14 @@ class HFLM(TemplateLM):
    def _select_cont_toks(
        self, logits: torch.Tensor, contlen: int = None, inplen: int = None
    ) -> torch.Tensor:
        if self.backend == "causal":
            assert (
                contlen and inplen
            ), "Must pass input len and cont. len to select scored logits for causal LM"
            # discard right-padding.
            # also discard the input/context tokens. we'll only score continuations.
            logits = logits[inplen - contlen : inplen]
        elif self.backend == "seq2seq":
            assert (
                contlen and not inplen
            ), "Selecting scored logits for Seq2SeqLM requires only cont. len"
...@@ -871,8 +924,6 @@ class HFLM(TemplateLM):
    def loglikelihood_rolling(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[float]:
        adaptive_batch_size = None
        if self.batch_size == "auto":
            # using rolling window with maximum context
...@@ -881,10 +932,17 @@ class HFLM(TemplateLM):
            print(f"Determined Largest batch size: {batch_size}")
            adaptive_batch_size = batch_size
        # First, collect all windows from all requests
        all_windows = []  # List of (request_idx, window) tuples
        request_window_counts = []  # Track number of windows per request

        for req_idx, (string,) in enumerate(
            tqdm(
                [req.args for req in requests],
                disable=(disable_tqdm or (self.rank != 0)),
            )
        ):
            rolling_token_windows: List[Tuple[List[int], List[int]]] = list(
                map(
                    utils.make_disjoint_window,
                    utils.get_rolling_token_windows(
...@@ -897,34 +955,55 @@ class HFLM(TemplateLM):
            )

            # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
            windows = [(None,) + x for x in rolling_token_windows]

            # Store windows with their request index
            all_windows.extend((req_idx, window) for window in windows)
            request_window_counts.append(len(windows))

        # Handle distributed case padding
        pad_amnt = 0
        if self.world_size > 1:
            mytensor = torch.tensor(len(all_windows), device=self.device)
            gathered = self.accelerator.gather(mytensor).cpu().detach().numpy().tolist()
            pad_amnt = max(gathered) - gathered[self.rank]
            if pad_amnt > 0:
                all_windows += pad_amnt * [all_windows[0]]

        all_nlls = []
        batch_size = adaptive_batch_size or self.batch_size
        for i in range(0, len(all_windows), batch_size):
            batch = all_windows[i : i + batch_size]
            # Extract just the windows for processing, keeping track of request indices
            batch_indices, batch_windows = zip(*batch)

            batch_nlls = self._loglikelihood_tokens(
                requests=batch_windows,
                disable_tqdm=False,
                override_bs=len(batch_windows),
            )
            # Store results with their request indices
            all_nlls.extend(zip(batch_indices, batch_nlls))

        # Remove padding if necessary
        if (self.world_size > 1) and (pad_amnt > 0):
            all_nlls = all_nlls[:-pad_amnt]

        # Reconstruct per-request loglikelihoods
        loglikelihoods = []
        current_idx = 0
        for window_count in request_window_counts:
            # Get all nlls for this request
            request_nlls = all_nlls[current_idx : current_idx + window_count]
            # Sum up the nlls for this request (discarding is_greedy)
            request_total = sum(nll[0] for _, nll in request_nlls)
            loglikelihoods.append(request_total)
            current_idx += window_count

            string = requests[len(loglikelihoods) - 1].args[0]
            self.cache_hook.add_partial(
                "loglikelihood_rolling", (string,), request_total
            )

        return loglikelihoods
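The reworked loop above flattens every request's rolling windows into one list so windows from different documents can share a batch. A simplified, self-contained sketch of the windowing idea (not the library's exact `get_rolling_token_windows`/`make_disjoint_window` implementation): each window scores a disjoint chunk of tokens, conditioned on whatever precedes it, clipped to the model's window:

from typing import List, Tuple


def simple_rolling_windows(
    tokens: List[int], prefix_token: int, max_seq_len: int
) -> List[Tuple[List[int], List[int]]]:
    """Split `tokens` into disjoint (context, continuation) windows.

    Each token is scored exactly once; context + continuation never exceeds
    max_seq_len. This is a teaching sketch, not the library routine.
    """
    assert max_seq_len >= 2
    windows = []
    full = [prefix_token] + tokens  # full[i + 1] == tokens[i]
    start = 0
    while start < len(tokens):
        end = min(start + max_seq_len - 1, len(tokens))
        continuation = tokens[start:end]
        # context: everything before the continuation (including the prefix
        # token), clipped so len(context) + len(continuation) <= max_seq_len
        context = full[: start + 1][-(max_seq_len - len(continuation)) :]
        windows.append((context, continuation))
        start = end
    return windows


print(simple_rolling_windows(list(range(1, 11)), prefix_token=0, max_seq_len=4))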
...@@ -978,8 +1057,7 @@ class HFLM(TemplateLM):
            requests,
            sort_fn=_collate,
            group_by="contexts"
            if self.backend == "causal" and self.logits_cache
            else None,
            group_fn=_lookup_one_token_cont,
        )
...@@ -1036,14 +1114,21 @@ class HFLM(TemplateLM):
            # cont_toks 4 5 6 7 8 9 [:, -len(continuation_enc):, :self.vocab_size] slice
            # when too long to fit in context, truncate from the left
            if self.backend == "causal":
                total_length = len(context_enc) + len(continuation_enc)
                if total_length > self.max_length + 1:
                    eval_logger.warn(
                        f"Combined length of context ({len(context_enc)}) and continuation ({len(continuation_enc)}) "
                        f"exceeds model's maximum length ({self.max_length}). "
                        f"Truncating {total_length - self.max_length + 1} tokens from the left."
                    )
                inp = torch.tensor(
                    (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
                    dtype=torch.long,
                    device=self.device,
                )
                (inplen,) = inp.shape
            elif self.backend == "seq2seq":
                inp = torch.tensor(
                    (context_enc)[-self.max_length :],
                    dtype=torch.long,
...@@ -1083,11 +1168,11 @@ class HFLM(TemplateLM):
            # create encoder attn mask and batched conts, if seq2seq
            call_kwargs = {}
            if self.backend == "causal":
                batched_inps = pad_and_concat(
                    padding_len_inp, inps, padding_side="right"
                )  # [batch, padding_len_inp]
            elif self.backend == "seq2seq":
                # TODO: left-pad encoder inps and mask?
                batched_inps = pad_and_concat(
                    padding_len_inp, inps
...@@ -1118,7 +1203,7 @@ class HFLM(TemplateLM):
                # from prompt/prefix tuning tokens, if applicable
                ctx_len = (
                    inplen + (logits.shape[0] - padding_len_inp)
                    if self.backend == "causal"
                    else None
                )
                logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len)
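To make the indexing concrete: for a causal model the input is `(context + continuation)[-(max_length + 1):][:-1]`, and the logits that score the continuation are the `contlen` positions ending at `inplen`. A toy sketch with plain lists and random logits (all values made up):

import torch

max_length = 8
context_enc = [11, 12, 13, 14, 15]
continuation_enc = [16, 17, 18]

# model input: drop the final token, keep at most max_length tokens
inp = (context_enc + continuation_enc)[-(max_length + 1):][:-1]
inplen, contlen = len(inp), len(continuation_enc)

# pretend per-position logits over a 100-token vocabulary
logits = torch.randn(inplen, 100)

# positions inplen - contlen .. inplen - 1 predict the continuation tokens
cont_logits = logits[inplen - contlen : inplen]
print(inp, cont_logits.shape)   # [11, ..., 17]  torch.Size([3, 100])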
...@@ -1154,7 +1239,13 @@ class HFLM(TemplateLM):
                res.append(answer)

                if request_str is not None:
                    # special case: loglikelihood_rolling produces a number of loglikelihood requests
                    # all with cache key None. instead do add_partial on the per-example level
                    # in the loglikelihood_rolling() function for those.
                    self.cache_hook.add_partial(
                        "loglikelihood", request_str, answer
                    )
                pbar.update(1)

        pbar.close()
...@@ -1214,43 +1305,34 @@ class HFLM(TemplateLM):
            group_fn=lambda x: x[1],
        )
        chunks = re_ords.get_batched(n=batch_size, batch_fn=batch_fn)
        eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False)
        for chunk in chunks:
            contexts, all_gen_kwargs = zip(*chunk)
            # we assume all gen kwargs in the batch are the same
            # this is safe to assume because the `grouper` object ensures it.
            gen_kwargs = all_gen_kwargs[0]
            # unpack our keyword arguments.
            if isinstance(gen_kwargs, dict):
                kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
                # add EOS token to stop sequences
                until = handle_stop_sequences(kwargs.pop("until", None), eos=eos)
            else:
                raise ValueError(
                    f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
                )
            if "max_gen_toks" in kwargs.keys():
                max_gen_toks = kwargs.pop("max_gen_toks")
            else:
                max_gen_toks = self.max_gen_toks
            # set the max length in tokens of inputs ("context_enc")
            if self.backend == "causal":
                # max len for inputs = max length, minus room to generate the max new tokens
                max_ctx_len = self.max_length - max_gen_toks
                assert (
                    max_ctx_len > 0
                ), f"Invalid configuration: requested max tokens to generate ({max_gen_toks}) must be less than model's maximum sequence length ({self.max_length})."
            elif self.backend == "seq2seq":
                # max len for inputs = encoder's whole max_length
                max_ctx_len = self.max_length
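The `handle_stop_sequences` helper introduced above normalizes whatever the task supplied for `until` and makes sure the EOS string is always among the stop sequences. A simplified standalone sketch of that normalization, not the exact library implementation:

from typing import List, Optional, Union


def normalize_stop_sequences(
    until: Optional[Union[str, List[str]]], eos: str
) -> List[str]:
    """Coerce `until` to a list of strings and append the EOS text if missing."""
    if until is None:
        until = []
    elif isinstance(until, str):
        until = [until]
    elif not isinstance(until, list):
        raise ValueError(
            f"Expected `until` to be of type Union[str, list] but got {until}"
        )
    if eos not in until:
        until = until + [eos]
    return until


print(normalize_stop_sequences("\n\n", eos="<|endoftext|>"))
# ['\n\n', '<|endoftext|>']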
...@@ -1277,7 +1359,7 @@ class HFLM(TemplateLM):
            cont_toks_list = cont.tolist()
            for cont_toks, context in zip(cont_toks_list, contexts):
                # discard context + left-padding toks if using causal decoder-only LM
                if self.backend == "causal":
                    cont_toks = cont_toks[context_enc.shape[1] :]
                s = self.tok_decode(cont_toks)
...@@ -1304,9 +1386,20 @@ class HFLM(TemplateLM):
        """
        Method to apply a chat template to a list of chat history between user and model.
        """
        try:
            chat_templated = self.tokenizer.apply_chat_template(
                chat_history, tokenize=False, add_generation_prompt=True
            )
        except jinja2.exceptions.TemplateError:
            eval_logger.warning(
                "Failed to apply chat template. removing the system role in chat history."
            )
            chat_history = [msg for msg in chat_history if msg["role"] != "system"]
            chat_templated = self.tokenizer.apply_chat_template(
                chat_history, tokenize=False, add_generation_prompt=True
            )

        return chat_templated
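The `jinja2.exceptions.TemplateError` fallback above covers tokenizers whose chat template rejects a system turn. A hedged standalone sketch of the same retry pattern; the model name is a placeholder for any checkpoint whose template disallows system messages:

import jinja2
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("some-org/chat-model")  # placeholder

messages = [
    {"role": "system", "content": "You are a terse assistant."},
    {"role": "user", "content": "Name one prime number."},
]

try:
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
except jinja2.exceptions.TemplateError:
    # some templates (e.g. ones without a system slot) raise here;
    # retry with the system turn dropped, as the method above does
    messages = [m for m in messages if m["role"] != "system"]
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )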
    def get_model_info(self) -> dict:
        """
...@@ -1332,7 +1425,7 @@ class HFLM(TemplateLM):
            model_info = HfApi().model_info(repo_id=pretrained, revision=revision)
            return model_info.sha
        except Exception as e:
            eval_logger.debug(
                f"Failed to get model SHA for {pretrained} at revision {revision}. Error: {e}"
            )
            return ""
...