Commit 8fada609 authored by Baber

Merge branch 'main' into mathvista

parents 0007b74a 1208afd3
@@ -29,7 +29,7 @@ repos:
       - id: mixed-line-ending
         args: [--fix=lf]
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.7.4
+    rev: v0.9.3
     hooks:
       # Run the linter.
       - id: ruff
@@ -38,7 +38,7 @@ repos:
       # Run the formatter.
       - id: ruff-format
   - repo: https://github.com/codespell-project/codespell
-    rev: v2.3.0
+    rev: v2.4.1
     hooks:
       - id: codespell
         exclude: >
......
import warnings

import torch
import torch.nn as nn
from transformer_lens import HookedTransformer
from transformers import AutoConfig

from lm_eval import evaluator
from lm_eval.models.huggingface import HFLM


def evaluate_lm_eval(lens_model: HookedTransformer, tasks: list[str], **kwargs):
    class HFLikeModelAdapter(nn.Module):
        """Adapts HookedTransformer to match the HuggingFace interface expected by lm-eval"""

        def __init__(self, model: HookedTransformer):
            super().__init__()
            self.model = model
            self.tokenizer = model.tokenizer
            self.config = AutoConfig.from_pretrained(model.cfg.tokenizer_name)
            self.device = model.cfg.device
            self.tie_weights = lambda: self

        def forward(self, input_ids=None, attention_mask=None, **kwargs):
            output = self.model(input_ids, attention_mask=attention_mask, **kwargs)
            # Make sure output has the expected .logits attribute
            if not hasattr(output, "logits"):
                if isinstance(output, torch.Tensor):
                    output.logits = output
            return output

        # Only delegate specific attributes we know we need
        def to(self, *args, **kwargs):
            return self.model.to(*args, **kwargs)

        def eval(self):
            self.model.eval()
            return self

        def train(self, mode=True):
            self.model.train(mode)
            return self

    model = HFLikeModelAdapter(lens_model)
    warnings.filterwarnings("ignore", message="Failed to get model SHA for")
    results = evaluator.simple_evaluate(
        model=HFLM(pretrained=model, tokenizer=model.tokenizer),
        tasks=tasks,
        verbosity="WARNING",
        **kwargs,
    )
    return results


if __name__ == "__main__":
    # Load base model
    model = HookedTransformer.from_pretrained("pythia-70m")
    res = evaluate_lm_eval(model, tasks=["arc_easy"])
    print(res["results"])
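Usage note (not part of the commit): because `evaluate_lm_eval` forwards `**kwargs` straight to `evaluator.simple_evaluate`, standard harness options such as `num_fewshot` or `limit` can also be passed through the wrapper. A minimal sketch, with task names and numbers chosen purely for illustration:

lens_model = HookedTransformer.from_pretrained("pythia-70m")
res = evaluate_lm_eval(
    lens_model,
    tasks=["arc_easy", "lambada_openai"],
    num_fewshot=0,  # zero-shot prompting
    limit=50,  # evaluate only a small subset of documents per task, as a quick smoke test
)
print(res["results"])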
@@ -112,6 +112,4 @@ class ConfigurableGroup(abc.ABC):
         return self._config.group
 
     def __repr__(self):
-        return (
-            f"ConfigurableGroup(group={self.group}," f"group_alias={self.group_alias})"
-        )
+        return f"ConfigurableGroup(group={self.group},group_alias={self.group_alias})"
@@ -527,9 +527,9 @@ def pooled_sample_stderr(stderrs: List[float], sizes: List[int]):
 def combined_sample_stderr(stderrs: List[float], sizes: List[int], metrics=None):
-    assert (
-        metrics is not None
-    ), "Need to pass a list of each subtask's metric for this stderr aggregation"
+    assert metrics is not None, (
+        "Need to pass a list of each subtask's metric for this stderr aggregation"
+    )
     assert len(stderrs) == len(sizes) and len(sizes) == len(metrics)
 
     # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1390 for more documentation.
......
@@ -17,13 +17,13 @@ def register_model(*names):
     def decorate(cls):
         for name in names:
-            assert issubclass(
-                cls, LM
-            ), f"Model '{name}' ({cls.__name__}) must extend LM class"
+            assert issubclass(cls, LM), (
+                f"Model '{name}' ({cls.__name__}) must extend LM class"
+            )
 
-            assert (
-                name not in MODEL_REGISTRY
-            ), f"Model named '{name}' conflicts with existing model! Please register with a non-conflicting alias instead."
+            assert name not in MODEL_REGISTRY, (
+                f"Model named '{name}' conflicts with existing model! Please register with a non-conflicting alias instead."
+            )
 
             MODEL_REGISTRY[name] = cls
         return cls
@@ -48,9 +48,9 @@ func2task_index = {}
 def register_task(name):
     def decorate(fn):
-        assert (
-            name not in TASK_REGISTRY
-        ), f"task named '{name}' conflicts with existing registered task!"
+        assert name not in TASK_REGISTRY, (
+            f"task named '{name}' conflicts with existing registered task!"
+        )
         TASK_REGISTRY[name] = fn
         ALL_TASKS.add(name)
@@ -104,9 +104,9 @@ def register_metric(**args):
         ]:
             if key in args:
                value = args[key]
-                assert (
-                    value not in registry
-                ), f"{key} named '{value}' conflicts with existing registered {key}!"
+                assert value not in registry, (
+                    f"{key} named '{value}' conflicts with existing registered {key}!"
+                )
 
                 if key == "metric":
                     registry[name] = fn
@@ -140,9 +140,9 @@ def get_metric(name: str, hf_evaluate_metric=False) -> Callable:
 def register_aggregation(name: str):
     def decorate(fn):
-        assert (
-            name not in AGGREGATION_REGISTRY
-        ), f"aggregation named '{name}' conflicts with existing registered aggregation!"
+        assert name not in AGGREGATION_REGISTRY, (
+            f"aggregation named '{name}' conflicts with existing registered aggregation!"
+        )
         AGGREGATION_REGISTRY[name] = fn
         return fn
......
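For context, a minimal sketch of how the `register_model` decorator reformatted above is typically used; the class and alias names here are hypothetical, and the assertions in the hunk are what fire if the class does not extend `LM` or if the alias is already registered.

from typing import List, Tuple

from lm_eval.api.model import LM
from lm_eval.api.registry import register_model


@register_model("my-dummy-lm")
class MyDummyLM(LM):
    # Stub implementations of LM's abstract interface, only to keep the sketch self-contained.
    def loglikelihood(self, requests) -> List[Tuple[float, bool]]:
        raise NotImplementedError

    def loglikelihood_rolling(self, requests) -> List[float]:
        raise NotImplementedError

    def generate_until(self, requests) -> List[str]:
        raise NotImplementedError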
@@ -71,9 +71,9 @@ class ContextSampler:
             )
             self.docs = self.docs.select(fewshot_indices)
 
-    def get_context(self, doc: dict, num_fewshot: int, assistant_prefill: str = None):
+    def get_context(self, doc: dict, num_fewshot: int, gen_prefix: str = None):
         # draw an extra fewshot sample if using same split as evaluating on
-        prefix = assistant_prefill + " " if assistant_prefill else ""
+        prefix = gen_prefix + " " if gen_prefix else ""
         n_samples = (
             num_fewshot + 1
             if self.config.fewshot_split == self.config.test_split
@@ -115,10 +115,10 @@ class ContextSampler:
         doc: dict,
         num_fewshot: int,
         fewshot_as_multiturn: bool = False,
-        assistant_prefill: Optional[str] = None,
+        gen_prefix: Optional[str] = None,
     ):
         # TODO: Do we need any other delimiter
-        prefix = assistant_prefill + " " if assistant_prefill else ""
+        prefix = gen_prefix + " " if gen_prefix else ""
         chat_history = []
         # draw an extra fewshot sample if using same split as evaluating on
         n_samples = (
@@ -163,7 +163,7 @@ class ContextSampler:
                 {
                     "role": "user",
                     "content": self.get_context(
-                        doc, num_fewshot, assistant_prefill=assistant_prefill
+                        doc, num_fewshot, gen_prefix=gen_prefix
                     ),
                 }
             )
@@ -184,9 +184,9 @@ class FirstNSampler(ContextSampler):
         Draw the first `n` samples in order from the specified split.
         Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU.
         """
-        assert (
-            n <= len(self.docs)
-        ), f"Error: number of fewshot samples requested exceeds the {len(self.docs)} that are available."
+        assert n <= len(self.docs), (
+            f"Error: number of fewshot samples requested exceeds the {len(self.docs)} that are available."
+        )
         return self.docs[:n]
......
@@ -93,7 +93,7 @@ class TaskConfig(dict):
     filter_list: Optional[Union[str, list]] = None
     should_decontaminate: bool = False
     doc_to_decontamination_query: Optional[str] = None
-    assistant_prefill: Optional[str] = None
+    gen_prefix: Optional[str] = None
     metadata: Optional[dict] = (
         None  # by default, not used in the code. allows for users to pass arbitrary info to tasks
     )
@@ -371,6 +371,9 @@ class Task(abc.ABC):
     def doc_to_image(self, doc):
         raise NotImplementedError
 
+    def doc_to_prefix(self, doc):
+        return ""
+
     def build_all_requests(
         self,
         *,
@@ -444,7 +447,7 @@ class Task(abc.ABC):
                 apply_chat_template,
                 fewshot_as_multiturn,
                 chat_template,
-                assistant_prefill=self.config.assistant_prefill,
+                gen_prefix=self.doc_to_prefix(doc),
             )
 
             # TODO: we should override self.config.repeats if doing greedy gen so users don't waste time+compute
@@ -453,6 +456,7 @@ class Task(abc.ABC):
                 ctx=fewshot_ctx,
                 metadata=(self.config["task"], doc_id, self.config.repeats),
                 apply_chat_template=apply_chat_template,
+                chat_template=chat_template,
             )
 
             if not isinstance(inst, list):
@@ -544,13 +548,7 @@ class Task(abc.ABC):
         return len(re.split(r"\s+", doc))
 
     @utils.positional_deprecated
-    def fewshot_context(
-        self,
-        doc,
-        num_fewshot,
-        rnd=None,
-        description=None,
-    ):
+    def fewshot_context(self, doc, num_fewshot, rnd=None, description=None, **kwargs):
         """Returns a fewshot context string that is made up of a prepended description
         (if provided), the `num_fewshot` number of examples, and an appended prompt example.
@@ -1006,7 +1004,7 @@ class ConfigurableTask(Task):
         labeled_examples: List[Dict[str, str]],
         question: str,
         fewshot_as_multiturn: bool = False,
-        assistant_prefill: Optional[str] = None,
+        gen_prefix: Optional[str] = None,
     ) -> None:
         """Adds a target question to the labeled examples list.
         If fewshot_as_multiturn is True, or labeled_examples is empty, or the last entry is a system turn, appends the question as a new user entry.
@@ -1022,8 +1020,8 @@ class ConfigurableTask(Task):
         else:
             # if fewshot_as_multiturn is True, append as next user entry (last is always assistant)
             labeled_examples.append({"role": "user", "content": question})
-        if assistant_prefill:
-            labeled_examples.append({"role": "assistant", "content": assistant_prefill})
+        if gen_prefix:
+            labeled_examples.append({"role": "assistant", "content": gen_prefix})
 
     @utils.positional_deprecated
     def fewshot_context(
@@ -1034,7 +1032,7 @@ class ConfigurableTask(Task):
         apply_chat_template: bool = False,
         fewshot_as_multiturn: bool = False,
         chat_template: Optional[Callable] = None,
-        assistant_prefill: Optional[str] = None,
+        gen_prefix: Optional[str] = None,
     ) -> Union[str, List[str]]:
         """Returns a fewshot context string that is made up of a prepended description
         (if provided), the `num_fewshot` number of examples, and an appended prompt example.
@@ -1081,7 +1079,6 @@ class ConfigurableTask(Task):
                 labeled_examples.append({"role": "system", "content": system_prompt})
             else:
                 labeled_examples = system_prompt
-
         # if few-shot - append examples after the system prompt
         if num_fewshot > 0:
             if apply_chat_template:
@@ -1090,25 +1087,27 @@ class ConfigurableTask(Task):
                         doc,
                         num_fewshot,
                         fewshot_as_multiturn,
-                        assistant_prefill=assistant_prefill,
+                        gen_prefix=gen_prefix,
                     )
                 )
             else:
                 labeled_examples += self.sampler.get_context(
-                    doc, num_fewshot, assistant_prefill=assistant_prefill
+                    doc, num_fewshot, gen_prefix=gen_prefix
                 )
 
         example = self.doc_to_text(doc)
         if apply_chat_template:
             if self.multiple_input:
                 # TODO: append prefill?
+                if not labeled_examples:
+                    return ""
                 return chat_template(labeled_examples)
             if isinstance(example, str):
                 self.append_target_question(
                     labeled_examples,
                     example,
                     fewshot_as_multiturn,
-                    assistant_prefill=assistant_prefill,
+                    gen_prefix=gen_prefix,
                 )
             # for loglikelihood create a list of questions with appended choices
             elif isinstance(example, list):
@@ -1120,13 +1119,13 @@ class ConfigurableTask(Task):
                         chat,
                         ex,
                         fewshot_as_multiturn,
-                        assistant_prefill=assistant_prefill,
+                        gen_prefix=gen_prefix,
                     )
                     # TODO: append prefill?
                     labeled_examples_list.append(
                         chat_template(
                             chat,
-                            add_generation_prompt=False if assistant_prefill else True,
+                            add_generation_prompt=False if gen_prefix else True,
                         )
                     )
                 return labeled_examples_list
@@ -1138,24 +1137,24 @@ class ConfigurableTask(Task):
                     labeled_examples,
                     choices[example],
                     fewshot_as_multiturn,
-                    assistant_prefill=assistant_prefill,
+                    gen_prefix=gen_prefix,
                 )
             else:
                 self.append_target_question(
                     labeled_examples,
                     str(example),
                     fewshot_as_multiturn,
-                    assistant_prefill=assistant_prefill,
+                    gen_prefix=gen_prefix,
                 )
             # return lm.apply_chat_template(labeled_examples)
             return chat_template(
                 labeled_examples,
-                add_generation_prompt=False if assistant_prefill else True,
+                add_generation_prompt=False if gen_prefix else True,
             )
         else:
             prefix = (
-                self.config.target_delimiter + assistant_prefill
-                if assistant_prefill is not None
+                self.config.target_delimiter + gen_prefix
+                if gen_prefix is not None
                 else ""
             )
             if self.multiple_input:
@@ -1342,10 +1341,19 @@ class ConfigurableTask(Task):
         else:
             return None
 
+    def doc_to_prefix(self, doc):
+        if (gen_prefix := self.config.gen_prefix) is not None:
+            if gen_prefix in self.features:
+                return doc[gen_prefix]
+            else:
+                return utils.apply_template(gen_prefix, doc)
+        return None
+
     def construct_requests(
         self, doc: dict, ctx: str, **kwargs
     ) -> Union[List[Instance], Instance]:
         apply_chat_template = kwargs.pop("apply_chat_template", False)
+        chat_template: Callable | None = kwargs.pop("chat_template", None)
 
         aux_arguments = None
@@ -1360,9 +1368,20 @@ class ConfigurableTask(Task):
             target_delimiter = ""
         if self.multiple_input:
             # If there are multiple inputs, choices are placed in the ctx
+            # apply chat_template to choices if apply_chat_template
             cont = self.doc_to_target(doc)
             arguments = [
-                (ctx + choice, f"{target_delimiter}{cont}") for choice in choices
+                (
+                    ctx
+                    + (
+                        chat_template([{"role": "user", "content": choice}])
+                        if apply_chat_template
+                        else choice
+                    ),
+                    f"{target_delimiter}{cont}",
+                )
+                for choice in choices
             ]
         else:
             # Otherwise they are placed in the continuation
......
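To make the new `gen_prefix` resolution concrete: `doc_to_prefix` either reads a dataset column (when `gen_prefix` names one of the task's features) or treats the string as a template rendered against the document. Below is a simplified standalone sketch of that lookup, using plain `str.format` in place of the harness's `utils.apply_template`; the function and document are illustrative only.

from typing import Optional


def resolve_prefix(gen_prefix: Optional[str], doc: dict, features: set) -> Optional[str]:
    if gen_prefix is None:
        return None
    if gen_prefix in features:
        # gen_prefix names a dataset column: use that column's value for this document
        return doc[gen_prefix]
    # otherwise treat gen_prefix as a template filled in from the document
    return gen_prefix.format(**doc)


doc = {"question": "2 + 2 = ?", "prefix": "Answer:"}
print(resolve_prefix("prefix", doc, features={"question", "prefix"}))  # -> "Answer:"
print(resolve_prefix("The answer to {question} is", doc, features={"question"}))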
@@ -151,7 +151,7 @@ def get_train_overlap(docs_by_task_set: dict, ngrams_path: str, limit: int) -> d
         elapsed = time.perf_counter() - start
         print(f"Read took {elapsed:0.5f} seconds.")
-        print(f"Speed: {(os.path.getsize(file)/1000000.0)/elapsed}MB/second")
+        print(f"Speed: {(os.path.getsize(file) / 1000000.0) / elapsed}MB/second")
 
         print(duplicates)
......
@@ -34,9 +34,9 @@ class TakeKFilter(Filter):
         # need resp to be subscriptable to check below
         resps = list(resps)
         # check we have at least k responses per doc, else we can't take the first k
-        assert (
-            len(resps[0]) >= self.k
-        ), f"Need at least {self.k} responses per doc to take first {self.k}, but got {len(resps[0])} only! Please increase TaskConfig.repeats ."
+        assert len(resps[0]) >= self.k, (
+            f"Need at least {self.k} responses per doc to take first {self.k}, but got {len(resps[0])} only! Please increase TaskConfig.repeats ."
+        )
         return map(lambda r: r[: self.k], resps)
......
@@ -43,9 +43,9 @@ class MapFilter(Filter):
         """
         if mapping_dict is None:
             mapping_dict = {}
-        assert isinstance(
-            mapping_dict, dict
-        ), "Provided mapping_dict is not a dictionary"
+        assert isinstance(mapping_dict, dict), (
+            "Provided mapping_dict is not a dictionary"
+        )
         self.mapping_dict = mapping_dict
         self.default_value = default_value
......
@@ -488,7 +488,7 @@ class EvaluationTracker:
         else:
             dataset_summary += f"{self.general_config_tracker.model_name}\n"
         dataset_summary += (
-            f"The dataset is composed of {len(card_metadata)-1} configuration(s), each one corresponding to one of the evaluated task.\n\n"
+            f"The dataset is composed of {len(card_metadata) - 1} configuration(s), each one corresponding to one of the evaluated task.\n\n"
             f"The dataset has been created from {len(results_files)} run(s). Each run can be found as a specific split in each "
             'configuration, the split being named using the timestamp of the run.The "train" split is always pointing to the latest results.\n\n'
             'An additional configuration "results" store all the aggregated results of the run.\n\n'
@@ -501,7 +501,7 @@ class EvaluationTracker:
         )
         dataset_summary += (
             "## Latest results\n\n"
-            f'These are the [latest results from run {latest_datetime}]({last_results_file_path.replace("/resolve/", "/blob/")}) '
+            f"These are the [latest results from run {latest_datetime}]({last_results_file_path.replace('/resolve/', '/blob/')}) "
             "(note that there might be results for other tasks in the repos if successive evals didn't cover the same tasks. "
             'You find each in the results and the "latest" split for each eval):\n\n'
             f"```python\n{results_string}\n```"
......
@@ -225,7 +225,7 @@ class WandbLogger:
         instance = [x["arguments"][0][0] for x in data]
         labels = [x["arguments"][0][1] for x in data]
         resps = [
-            f'log probability of continuation is {x["resps"][0][0][0]} '
+            f"log probability of continuation is {x['resps'][0][0][0]} "
             + "\n\n"
             + "continuation will {} generated with greedy sampling".format(
                 "not be" if not x["resps"][0][0][1] else "be"
@@ -233,7 +233,7 @@ class WandbLogger:
             for x in data
         ]
         filtered_resps = [
-            f'log probability of continuation is {x["filtered_resps"][0][0]} '
+            f"log probability of continuation is {x['filtered_resps'][0][0]} "
             + "\n\n"
             + "continuation will {} generated with greedy sampling".format(
                 "not be" if not x["filtered_resps"][0][1] else "be"
......
@@ -195,9 +195,9 @@ class TemplateAPI(TemplateLM):
         """Helper method to transform the prompt into the expected API input format. messages consist of batched requests"""
         if isinstance(messages[0], JsonChatStr):
             # for chat completions we need to decode the json string to list[dict,...]
-            assert (
-                self._batch_size == 1
-            ), "non-tokenized chat requests are only supported with batch_size=1"
+            assert self._batch_size == 1, (
+                "non-tokenized chat requests are only supported with batch_size=1"
+            )
             # list[dict["role":..., "content":...],...]
             return json.loads(messages[0].prompt)
@@ -506,9 +506,9 @@ class TemplateAPI(TemplateLM):
             return await tqdm_asyncio.gather(*tasks, desc="Requesting API")
 
     def _loglikelihood_tokens(self, requests, **kwargs) -> List[Tuple[float, bool]]:
-        assert (
-            self.tokenizer is not None
-        ), "Tokenizer is required for loglikelihood tasks to compute context lengths."
+        assert self.tokenizer is not None, (
+            "Tokenizer is required for loglikelihood tasks to compute context lengths."
+        )
         res = []
 
         def _collate(req: LogLikelihoodInputs):
......
@@ -60,9 +60,9 @@ class HFMultimodalLM(HFLM):
         super().__init__(pretrained, **kwargs)
 
-        assert (
-            self.batch_size != "auto"
-        ), "Batch size 'auto' is not yet supported for hf-multimodal models."
+        assert self.batch_size != "auto", (
+            "Batch size 'auto' is not yet supported for hf-multimodal models."
+        )
         self.chat_applied: bool = False
 
         # TODO: phi-3.5 "image placeholders" are <image_1>, <image_2>, ... in order. how to handle this case
@@ -82,9 +82,9 @@ class HFMultimodalLM(HFLM):
                 or getattr(self.config, "image_token_index", None)
             )
         )
-        assert (
-            self.image_token_id is not None
-        ), "Must have a non-None image_token_id to evaluate a Hugging Face AutoModelForVision2Seq model. Please pass `image_token_id` in `--model_args` if model's config does not already specify one."
+        assert self.image_token_id is not None, (
+            "Must have a non-None image_token_id to evaluate a Hugging Face AutoModelForVision2Seq model. Please pass `image_token_id` in `--model_args` if model's config does not already specify one."
+        )
         # get the string this token ID corresponds to
         self.image_token = self.tok_decode(
             [self.image_token_id], skip_special_tokens=False
......
@@ -99,7 +99,9 @@ class HFLM(TemplateLM):
             eval_logger.warning(
                 "`pretrained` model kwarg is not of type `str`. Many other model arguments may be ignored. Please do not launch via accelerate or use `parallelize=True` if passing an existing model this way."
             )
-            assert not parallelize, "`parallelize=True` is not compatible with passing pre-initialized model to `pretrained`"
+            assert not parallelize, (
+                "`parallelize=True` is not compatible with passing pre-initialized model to `pretrained`"
+            )
             self._model = pretrained
             self._device = self._model.device
             self._config = self._model.config
@@ -571,9 +573,9 @@ class HFLM(TemplateLM):
         if not autogptq and not gptqmodel:
             if model_kwargs.get("load_in_4bit", None):
-                assert (
-                    transformers.__version__ >= "4.30.0"
-                ), "load_in_4bit requires transformers >= 4.30.0"
+                assert transformers.__version__ >= "4.30.0", (
+                    "load_in_4bit requires transformers >= 4.30.0"
+                )
             if transformers.__version__ >= "4.30.0":
                 if model_kwargs.get("load_in_4bit", None):
                     if model_kwargs.get("bnb_4bit_compute_dtype", None):
@@ -905,16 +907,16 @@ class HFLM(TemplateLM):
         self, logits: torch.Tensor, contlen: int = None, inplen: int = None
     ) -> torch.Tensor:
         if self.backend == "causal":
-            assert (
-                contlen and inplen
-            ), "Must pass input len and cont. len to select scored logits for causal LM"
+            assert contlen and inplen, (
+                "Must pass input len and cont. len to select scored logits for causal LM"
+            )
             # discard right-padding.
             # also discard the input/context tokens. we'll only score continuations.
             logits = logits[inplen - contlen : inplen]
         elif self.backend == "seq2seq":
-            assert (
-                contlen and not inplen
-            ), "Selecting scored logits for Seq2SeqLM requires only cont. len"
+            assert contlen and not inplen, (
+                "Selecting scored logits for Seq2SeqLM requires only cont. len"
+            )
             # only discard right-padding.
             # the logits input to this fn only contain decoder-side tokens.
             logits = logits[:contlen]
@@ -1329,9 +1331,9 @@ class HFLM(TemplateLM):
         if self.backend == "causal":
             # max len for inputs = max length, minus room to generate the max new tokens
             max_ctx_len = self.max_length - max_gen_toks
-            assert (
-                max_ctx_len > 0
-            ), f"Invalid configuration: requested max tokens to generate ({max_gen_toks}) must be less than model's maximum sequence length ({self.max_length})."
+            assert max_ctx_len > 0, (
+                f"Invalid configuration: requested max tokens to generate ({max_gen_toks}) must be less than model's maximum sequence length ({self.max_length})."
+            )
         elif self.backend == "seq2seq":
             # max len for inputs = encoder's whole max_length
             max_ctx_len = self.max_length
......
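As a quick illustration of the context budget protected by the `max_ctx_len` assertion above (the numbers below are hypothetical, not taken from the diff):

max_length = 2048  # model's maximum sequence length
max_gen_toks = 256  # tokens reserved for newly generated text
max_ctx_len = max_length - max_gen_toks  # 1792 tokens left for the (possibly truncated) prompt
assert max_ctx_len > 0
print(max_ctx_len)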
@@ -206,7 +206,7 @@ class NEURON_HF(TemplateLM):
                     "Only float16/bfloat16/float32 are supported."
                 )
-                print(f"{'='*20} \n exporting model to neuron")
+                print(f"{'=' * 20} \n exporting model to neuron")
                 self.model = CustomNeuronModelForCausalLM.from_pretrained(
                     pretrained,
                     revision=revision,
@@ -220,19 +220,17 @@ class NEURON_HF(TemplateLM):
                 )
                 neuron_config = self.model.config.neuron
                 print(
-                    f"SUCCESS: neuron model exported with config {neuron_config}. \n {'='*20}"
+                    f"SUCCESS: neuron model exported with config {neuron_config}. \n {'=' * 20}"
                 )
             else:
-                print(
-                    f"{'='*20} \n loading neuron model with config" f" {neuron_config}..."
-                )
+                print(f"{'=' * 20} \n loading neuron model with config {neuron_config}...")
                 self.model = CustomNeuronModelForCausalLM.from_pretrained(
                     pretrained,
                     revision=revision,
                     trust_remote_code=trust_remote_code,
                     low_cpu_mem_usage=low_cpu_mem_usage,
                 )
-                print(f"SUCCESS: neuron model loaded. \n {'='*20}")
+                print(f"SUCCESS: neuron model loaded. \n {'=' * 20}")
 
         self.truncation = truncation
@@ -353,9 +351,9 @@ class NEURON_HF(TemplateLM):
         )
 
     def _select_cont_toks(self, logits, contlen=None, inplen=None):
-        assert (
-            contlen and inplen
-        ), "Must pass input len and cont. len to select scored logits for causal LM"
+        assert contlen and inplen, (
+            "Must pass input len and cont. len to select scored logits for causal LM"
+        )
         # discard right-padding.
         # also discard the input/context tokens. we'll only score continuations.
         logits = logits[inplen - contlen : inplen]
......
@@ -145,9 +145,9 @@ class LocalChatCompletion(LocalCompletionsAPI):
         eos=None,
         **kwargs,
     ) -> dict:
-        assert (
-            type(messages) is not str
-        ), "chat-completions require the --apply_chat_template flag."
+        assert type(messages) is not str, (
+            "chat-completions require the --apply_chat_template flag."
+        )
         gen_kwargs.pop("do_sample", False)
         if "max_tokens" in gen_kwargs:
             max_tokens = gen_kwargs.pop("max_tokens")
@@ -219,13 +219,12 @@ class OpenAICompletionsAPI(LocalCompletionsAPI):
         return key
 
     def loglikelihood(self, requests, **kwargs):
-        assert (
-            self.model
-            in [
-                "babbage-002",
-                "davinci-002",
-            ]
-        ), f"Prompt loglikelihoods are only supported by OpenAI's API for {['babbage-002', 'davinci-002']}."
+        assert self.model in [
+            "babbage-002",
+            "davinci-002",
+        ], (
+            f"Prompt loglikelihoods are only supported by OpenAI's API for {['babbage-002', 'davinci-002']}."
+        )
         return super().loglikelihood(requests, **kwargs)
 
     def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
@@ -276,9 +275,9 @@ class OpenAIChatCompletion(LocalChatCompletion):
         eos="<|endoftext|>",
         **kwargs,
     ) -> dict:
-        assert (
-            type(messages) is not str
-        ), "chat-completions require the --apply_chat_template flag."
+        assert type(messages) is not str, (
+            "chat-completions require the --apply_chat_template flag."
+        )
         gen_kwargs.pop("do_sample", False)
         if "max_tokens" in gen_kwargs:
             max_tokens = gen_kwargs.pop("max_tokens")
......
@@ -21,9 +21,9 @@ class IPEXLM(HFLM):
     ) -> None:
         if "backend" in kwargs:
             # currently only supports causal models
-            assert (
-                kwargs["backend"] == "causal"
-            ), "Currently, only IPEXModelForCausalLM is supported."
+            assert kwargs["backend"] == "causal", (
+                "Currently, only IPEXModelForCausalLM is supported."
+            )
 
         super().__init__(
             backend=kwargs.pop("backend", "causal"),
......
@@ -29,9 +29,9 @@ class OptimumLM(HFLM):
     ) -> None:
         if "backend" in kwargs:
             # optimum currently only supports causal models
-            assert (
-                kwargs["backend"] == "causal"
-            ), "Currently, only OVModelForCausalLM is supported."
+            assert kwargs["backend"] == "causal", (
+                "Currently, only OVModelForCausalLM is supported."
+            )
 
         self.openvino_device = device
......
@@ -155,9 +155,9 @@ def pad_and_concat(
         length in the batch. Used for batching inputs and continuations in
         seq2seq models.
     """
-    assert (
-        padding_side == "left" or padding_side == "right"
-    ), f"Unrecognized padding type: '{padding_side}' not 'left' or 'right'"
+    assert padding_side == "left" or padding_side == "right", (
+        f"Unrecognized padding type: '{padding_side}' not 'left' or 'right'"
+    )
 
     for i, tensor in enumerate(tensors):
         if len(tensor.shape) == 2:
......