Commit 8fada609 authored by Baber's avatar Baber
Browse files

Merge branch 'main' into mathvista

parents 0007b74a 1208afd3
......@@ -29,7 +29,7 @@ repos:
- id: mixed-line-ending
args: [--fix=lf]
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.7.4
rev: v0.9.3
hooks:
# Run the linter.
- id: ruff
......@@ -38,7 +38,7 @@ repos:
# Run the formatter.
- id: ruff-format
- repo: https://github.com/codespell-project/codespell
rev: v2.3.0
rev: v2.4.1
hooks:
- id: codespell
exclude: >
......
import warnings
import torch
import torch.nn as nn
from transformer_lens import HookedTransformer
from transformers import AutoConfig
from lm_eval import evaluator
from lm_eval.models.huggingface import HFLM
def evaluate_lm_eval(lens_model: HookedTransformer, tasks: list[str], **kwargs):
class HFLikeModelAdapter(nn.Module):
"""Adapts HookedTransformer to match the HuggingFace interface expected by lm-eval"""
def __init__(self, model: HookedTransformer):
super().__init__()
self.model = model
self.tokenizer = model.tokenizer
self.config = AutoConfig.from_pretrained(model.cfg.tokenizer_name)
self.device = model.cfg.device
self.tie_weights = lambda: self
def forward(self, input_ids=None, attention_mask=None, **kwargs):
output = self.model(input_ids, attention_mask=attention_mask, **kwargs)
# Make sure output has the expected .logits attribute
if not hasattr(output, "logits"):
if isinstance(output, torch.Tensor):
output.logits = output
return output
# Only delegate specific attributes we know we need
def to(self, *args, **kwargs):
return self.model.to(*args, **kwargs)
def eval(self):
self.model.eval()
return self
def train(self, mode=True):
self.model.train(mode)
return self
model = HFLikeModelAdapter(lens_model)
warnings.filterwarnings("ignore", message="Failed to get model SHA for")
results = evaluator.simple_evaluate(
model=HFLM(pretrained=model, tokenizer=model.tokenizer),
tasks=tasks,
verbosity="WARNING",
**kwargs,
)
return results
if __name__ == "__main__":
# Load base model
model = HookedTransformer.from_pretrained("pythia-70m")
res = evaluate_lm_eval(model, tasks=["arc_easy"])
print(res["results"])
......@@ -112,6 +112,4 @@ class ConfigurableGroup(abc.ABC):
return self._config.group
def __repr__(self):
return (
f"ConfigurableGroup(group={self.group}," f"group_alias={self.group_alias})"
)
return f"ConfigurableGroup(group={self.group},group_alias={self.group_alias})"
......@@ -527,9 +527,9 @@ def pooled_sample_stderr(stderrs: List[float], sizes: List[int]):
def combined_sample_stderr(stderrs: List[float], sizes: List[int], metrics=None):
assert (
metrics is not None
), "Need to pass a list of each subtask's metric for this stderr aggregation"
assert metrics is not None, (
"Need to pass a list of each subtask's metric for this stderr aggregation"
)
assert len(stderrs) == len(sizes) and len(sizes) == len(metrics)
# See https://github.com/EleutherAI/lm-evaluation-harness/pull/1390 for more documentation.
......
......@@ -17,13 +17,13 @@ def register_model(*names):
def decorate(cls):
for name in names:
assert issubclass(
cls, LM
), f"Model '{name}' ({cls.__name__}) must extend LM class"
assert issubclass(cls, LM), (
f"Model '{name}' ({cls.__name__}) must extend LM class"
)
assert (
name not in MODEL_REGISTRY
), f"Model named '{name}' conflicts with existing model! Please register with a non-conflicting alias instead."
assert name not in MODEL_REGISTRY, (
f"Model named '{name}' conflicts with existing model! Please register with a non-conflicting alias instead."
)
MODEL_REGISTRY[name] = cls
return cls
......@@ -48,9 +48,9 @@ func2task_index = {}
def register_task(name):
def decorate(fn):
assert (
name not in TASK_REGISTRY
), f"task named '{name}' conflicts with existing registered task!"
assert name not in TASK_REGISTRY, (
f"task named '{name}' conflicts with existing registered task!"
)
TASK_REGISTRY[name] = fn
ALL_TASKS.add(name)
......@@ -104,9 +104,9 @@ def register_metric(**args):
]:
if key in args:
value = args[key]
assert (
value not in registry
), f"{key} named '{value}' conflicts with existing registered {key}!"
assert value not in registry, (
f"{key} named '{value}' conflicts with existing registered {key}!"
)
if key == "metric":
registry[name] = fn
......@@ -140,9 +140,9 @@ def get_metric(name: str, hf_evaluate_metric=False) -> Callable:
def register_aggregation(name: str):
def decorate(fn):
assert (
name not in AGGREGATION_REGISTRY
), f"aggregation named '{name}' conflicts with existing registered aggregation!"
assert name not in AGGREGATION_REGISTRY, (
f"aggregation named '{name}' conflicts with existing registered aggregation!"
)
AGGREGATION_REGISTRY[name] = fn
return fn
......
......@@ -71,9 +71,9 @@ class ContextSampler:
)
self.docs = self.docs.select(fewshot_indices)
def get_context(self, doc: dict, num_fewshot: int, assistant_prefill: str = None):
def get_context(self, doc: dict, num_fewshot: int, gen_prefix: str = None):
# draw an extra fewshot sample if using same split as evaluating on
prefix = assistant_prefill + " " if assistant_prefill else ""
prefix = gen_prefix + " " if gen_prefix else ""
n_samples = (
num_fewshot + 1
if self.config.fewshot_split == self.config.test_split
......@@ -115,10 +115,10 @@ class ContextSampler:
doc: dict,
num_fewshot: int,
fewshot_as_multiturn: bool = False,
assistant_prefill: Optional[str] = None,
gen_prefix: Optional[str] = None,
):
# TODO: Do we need any other delimiter
prefix = assistant_prefill + " " if assistant_prefill else ""
prefix = gen_prefix + " " if gen_prefix else ""
chat_history = []
# draw an extra fewshot sample if using same split as evaluating on
n_samples = (
......@@ -163,7 +163,7 @@ class ContextSampler:
{
"role": "user",
"content": self.get_context(
doc, num_fewshot, assistant_prefill=assistant_prefill
doc, num_fewshot, gen_prefix=gen_prefix
),
}
)
......@@ -184,9 +184,9 @@ class FirstNSampler(ContextSampler):
Draw the first `n` samples in order from the specified split.
Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU.
"""
assert (
n <= len(self.docs)
), f"Error: number of fewshot samples requested exceeds the {len(self.docs)} that are available."
assert n <= len(self.docs), (
f"Error: number of fewshot samples requested exceeds the {len(self.docs)} that are available."
)
return self.docs[:n]
......
......@@ -93,7 +93,7 @@ class TaskConfig(dict):
filter_list: Optional[Union[str, list]] = None
should_decontaminate: bool = False
doc_to_decontamination_query: Optional[str] = None
assistant_prefill: Optional[str] = None
gen_prefix: Optional[str] = None
metadata: Optional[dict] = (
None # by default, not used in the code. allows for users to pass arbitrary info to tasks
)
......@@ -371,6 +371,9 @@ class Task(abc.ABC):
def doc_to_image(self, doc):
raise NotImplementedError
def doc_to_prefix(self, doc):
return ""
def build_all_requests(
self,
*,
......@@ -444,7 +447,7 @@ class Task(abc.ABC):
apply_chat_template,
fewshot_as_multiturn,
chat_template,
assistant_prefill=self.config.assistant_prefill,
gen_prefix=self.doc_to_prefix(doc),
)
# TODO: we should override self.config.repeats if doing greedy gen so users don't waste time+compute
......@@ -453,6 +456,7 @@ class Task(abc.ABC):
ctx=fewshot_ctx,
metadata=(self.config["task"], doc_id, self.config.repeats),
apply_chat_template=apply_chat_template,
chat_template=chat_template,
)
if not isinstance(inst, list):
......@@ -544,13 +548,7 @@ class Task(abc.ABC):
return len(re.split(r"\s+", doc))
@utils.positional_deprecated
def fewshot_context(
self,
doc,
num_fewshot,
rnd=None,
description=None,
):
def fewshot_context(self, doc, num_fewshot, rnd=None, description=None, **kwargs):
"""Returns a fewshot context string that is made up of a prepended description
(if provided), the `num_fewshot` number of examples, and an appended prompt example.
......@@ -1006,7 +1004,7 @@ class ConfigurableTask(Task):
labeled_examples: List[Dict[str, str]],
question: str,
fewshot_as_multiturn: bool = False,
assistant_prefill: Optional[str] = None,
gen_prefix: Optional[str] = None,
) -> None:
"""Adds a target question to the labeled examples list.
If fewshot_as_multiturn is True, or labeled_examples is empty, or the last entry is a system turn, appends the question as a new user entry.
......@@ -1022,8 +1020,8 @@ class ConfigurableTask(Task):
else:
# if fewshot_as_multiturn is True, append as next user entry (last is always assistant)
labeled_examples.append({"role": "user", "content": question})
if assistant_prefill:
labeled_examples.append({"role": "assistant", "content": assistant_prefill})
if gen_prefix:
labeled_examples.append({"role": "assistant", "content": gen_prefix})
@utils.positional_deprecated
def fewshot_context(
......@@ -1034,7 +1032,7 @@ class ConfigurableTask(Task):
apply_chat_template: bool = False,
fewshot_as_multiturn: bool = False,
chat_template: Optional[Callable] = None,
assistant_prefill: Optional[str] = None,
gen_prefix: Optional[str] = None,
) -> Union[str, List[str]]:
"""Returns a fewshot context string that is made up of a prepended description
(if provided), the `num_fewshot` number of examples, and an appended prompt example.
......@@ -1081,7 +1079,6 @@ class ConfigurableTask(Task):
labeled_examples.append({"role": "system", "content": system_prompt})
else:
labeled_examples = system_prompt
# if few-shot - append examples after the system prompt
if num_fewshot > 0:
if apply_chat_template:
......@@ -1090,25 +1087,27 @@ class ConfigurableTask(Task):
doc,
num_fewshot,
fewshot_as_multiturn,
assistant_prefill=assistant_prefill,
gen_prefix=gen_prefix,
)
)
else:
labeled_examples += self.sampler.get_context(
doc, num_fewshot, assistant_prefill=assistant_prefill
doc, num_fewshot, gen_prefix=gen_prefix
)
example = self.doc_to_text(doc)
if apply_chat_template:
if self.multiple_input:
# TODO: append prefill?
if not labeled_examples:
return ""
return chat_template(labeled_examples)
if isinstance(example, str):
self.append_target_question(
labeled_examples,
example,
fewshot_as_multiturn,
assistant_prefill=assistant_prefill,
gen_prefix=gen_prefix,
)
# for loglikelihood create a list of questions with appended choices
elif isinstance(example, list):
......@@ -1120,13 +1119,13 @@ class ConfigurableTask(Task):
chat,
ex,
fewshot_as_multiturn,
assistant_prefill=assistant_prefill,
gen_prefix=gen_prefix,
)
# TODO: append prefill?
labeled_examples_list.append(
chat_template(
chat,
add_generation_prompt=False if assistant_prefill else True,
add_generation_prompt=False if gen_prefix else True,
)
)
return labeled_examples_list
......@@ -1138,24 +1137,24 @@ class ConfigurableTask(Task):
labeled_examples,
choices[example],
fewshot_as_multiturn,
assistant_prefill=assistant_prefill,
gen_prefix=gen_prefix,
)
else:
self.append_target_question(
labeled_examples,
str(example),
fewshot_as_multiturn,
assistant_prefill=assistant_prefill,
gen_prefix=gen_prefix,
)
# return lm.apply_chat_template(labeled_examples)
return chat_template(
labeled_examples,
add_generation_prompt=False if assistant_prefill else True,
add_generation_prompt=False if gen_prefix else True,
)
else:
prefix = (
self.config.target_delimiter + assistant_prefill
if assistant_prefill is not None
self.config.target_delimiter + gen_prefix
if gen_prefix is not None
else ""
)
if self.multiple_input:
......@@ -1342,10 +1341,19 @@ class ConfigurableTask(Task):
else:
return None
def doc_to_prefix(self, doc):
if (gen_prefix := self.config.gen_prefix) is not None:
if gen_prefix in self.features:
return doc[gen_prefix]
else:
return utils.apply_template(gen_prefix, doc)
return None
def construct_requests(
self, doc: dict, ctx: str, **kwargs
) -> Union[List[Instance], Instance]:
apply_chat_template = kwargs.pop("apply_chat_template", False)
chat_template: Callable | None = kwargs.pop("chat_template", None)
aux_arguments = None
......@@ -1360,9 +1368,20 @@ class ConfigurableTask(Task):
target_delimiter = ""
if self.multiple_input:
# If there are multiple inputs, choices are placed in the ctx
# apply chat_template to choices if apply_chat_template
cont = self.doc_to_target(doc)
arguments = [
(ctx + choice, f"{target_delimiter}{cont}") for choice in choices
(
ctx
+ (
chat_template([{"role": "user", "content": choice}])
if apply_chat_template
else choice
),
f"{target_delimiter}{cont}",
)
for choice in choices
]
else:
# Otherwise they are placed in the continuation
......
......@@ -151,7 +151,7 @@ def get_train_overlap(docs_by_task_set: dict, ngrams_path: str, limit: int) -> d
elapsed = time.perf_counter() - start
print(f"Read took {elapsed:0.5f} seconds.")
print(f"Speed: {(os.path.getsize(file)/1000000.0)/elapsed}MB/second")
print(f"Speed: {(os.path.getsize(file) / 1000000.0) / elapsed}MB/second")
print(duplicates)
......
......@@ -34,9 +34,9 @@ class TakeKFilter(Filter):
# need resp to be subscriptable to check below
resps = list(resps)
# check we have at least k responses per doc, else we can't take the first k
assert (
len(resps[0]) >= self.k
), f"Need at least {self.k} responses per doc to take first {self.k}, but got {len(resps[0])} only! Please increase TaskConfig.repeats ."
assert len(resps[0]) >= self.k, (
f"Need at least {self.k} responses per doc to take first {self.k}, but got {len(resps[0])} only! Please increase TaskConfig.repeats ."
)
return map(lambda r: r[: self.k], resps)
......
......@@ -43,9 +43,9 @@ class MapFilter(Filter):
"""
if mapping_dict is None:
mapping_dict = {}
assert isinstance(
mapping_dict, dict
), "Provided mapping_dict is not a dictionary"
assert isinstance(mapping_dict, dict), (
"Provided mapping_dict is not a dictionary"
)
self.mapping_dict = mapping_dict
self.default_value = default_value
......
......@@ -488,7 +488,7 @@ class EvaluationTracker:
else:
dataset_summary += f"{self.general_config_tracker.model_name}\n"
dataset_summary += (
f"The dataset is composed of {len(card_metadata)-1} configuration(s), each one corresponding to one of the evaluated task.\n\n"
f"The dataset is composed of {len(card_metadata) - 1} configuration(s), each one corresponding to one of the evaluated task.\n\n"
f"The dataset has been created from {len(results_files)} run(s). Each run can be found as a specific split in each "
'configuration, the split being named using the timestamp of the run.The "train" split is always pointing to the latest results.\n\n'
'An additional configuration "results" store all the aggregated results of the run.\n\n'
......@@ -501,7 +501,7 @@ class EvaluationTracker:
)
dataset_summary += (
"## Latest results\n\n"
f'These are the [latest results from run {latest_datetime}]({last_results_file_path.replace("/resolve/", "/blob/")}) '
f"These are the [latest results from run {latest_datetime}]({last_results_file_path.replace('/resolve/', '/blob/')}) "
"(note that there might be results for other tasks in the repos if successive evals didn't cover the same tasks. "
'You find each in the results and the "latest" split for each eval):\n\n'
f"```python\n{results_string}\n```"
......
......@@ -225,7 +225,7 @@ class WandbLogger:
instance = [x["arguments"][0][0] for x in data]
labels = [x["arguments"][0][1] for x in data]
resps = [
f'log probability of continuation is {x["resps"][0][0][0]} '
f"log probability of continuation is {x['resps'][0][0][0]} "
+ "\n\n"
+ "continuation will {} generated with greedy sampling".format(
"not be" if not x["resps"][0][0][1] else "be"
......@@ -233,7 +233,7 @@ class WandbLogger:
for x in data
]
filtered_resps = [
f'log probability of continuation is {x["filtered_resps"][0][0]} '
f"log probability of continuation is {x['filtered_resps'][0][0]} "
+ "\n\n"
+ "continuation will {} generated with greedy sampling".format(
"not be" if not x["filtered_resps"][0][1] else "be"
......
......@@ -195,9 +195,9 @@ class TemplateAPI(TemplateLM):
"""Helper method to transform the prompt into the expected API input format. messages consist of batched requests"""
if isinstance(messages[0], JsonChatStr):
# for chat completions we need to decode the json string to list[dict,...]
assert (
self._batch_size == 1
), "non-tokenized chat requests are only supported with batch_size=1"
assert self._batch_size == 1, (
"non-tokenized chat requests are only supported with batch_size=1"
)
# list[dict["role":..., "content":...],...]
return json.loads(messages[0].prompt)
......@@ -506,9 +506,9 @@ class TemplateAPI(TemplateLM):
return await tqdm_asyncio.gather(*tasks, desc="Requesting API")
def _loglikelihood_tokens(self, requests, **kwargs) -> List[Tuple[float, bool]]:
assert (
self.tokenizer is not None
), "Tokenizer is required for loglikelihood tasks to compute context lengths."
assert self.tokenizer is not None, (
"Tokenizer is required for loglikelihood tasks to compute context lengths."
)
res = []
def _collate(req: LogLikelihoodInputs):
......
......@@ -60,9 +60,9 @@ class HFMultimodalLM(HFLM):
super().__init__(pretrained, **kwargs)
assert (
self.batch_size != "auto"
), "Batch size 'auto' is not yet supported for hf-multimodal models."
assert self.batch_size != "auto", (
"Batch size 'auto' is not yet supported for hf-multimodal models."
)
self.chat_applied: bool = False
# TODO: phi-3.5 "image placeholders" are <image_1>, <image_2>, ... in order. how to handle this case
......@@ -82,9 +82,9 @@ class HFMultimodalLM(HFLM):
or getattr(self.config, "image_token_index", None)
)
)
assert (
self.image_token_id is not None
), "Must have a non-None image_token_id to evaluate a Hugging Face AutoModelForVision2Seq model. Please pass `image_token_id` in `--model_args` if model's config does not already specify one."
assert self.image_token_id is not None, (
"Must have a non-None image_token_id to evaluate a Hugging Face AutoModelForVision2Seq model. Please pass `image_token_id` in `--model_args` if model's config does not already specify one."
)
# get the string this token ID corresponds to
self.image_token = self.tok_decode(
[self.image_token_id], skip_special_tokens=False
......
......@@ -99,7 +99,9 @@ class HFLM(TemplateLM):
eval_logger.warning(
"`pretrained` model kwarg is not of type `str`. Many other model arguments may be ignored. Please do not launch via accelerate or use `parallelize=True` if passing an existing model this way."
)
assert not parallelize, "`parallelize=True` is not compatible with passing pre-initialized model to `pretrained`"
assert not parallelize, (
"`parallelize=True` is not compatible with passing pre-initialized model to `pretrained`"
)
self._model = pretrained
self._device = self._model.device
self._config = self._model.config
......@@ -571,9 +573,9 @@ class HFLM(TemplateLM):
if not autogptq and not gptqmodel:
if model_kwargs.get("load_in_4bit", None):
assert (
transformers.__version__ >= "4.30.0"
), "load_in_4bit requires transformers >= 4.30.0"
assert transformers.__version__ >= "4.30.0", (
"load_in_4bit requires transformers >= 4.30.0"
)
if transformers.__version__ >= "4.30.0":
if model_kwargs.get("load_in_4bit", None):
if model_kwargs.get("bnb_4bit_compute_dtype", None):
......@@ -905,16 +907,16 @@ class HFLM(TemplateLM):
self, logits: torch.Tensor, contlen: int = None, inplen: int = None
) -> torch.Tensor:
if self.backend == "causal":
assert (
contlen and inplen
), "Must pass input len and cont. len to select scored logits for causal LM"
assert contlen and inplen, (
"Must pass input len and cont. len to select scored logits for causal LM"
)
# discard right-padding.
# also discard the input/context tokens. we'll only score continuations.
logits = logits[inplen - contlen : inplen]
elif self.backend == "seq2seq":
assert (
contlen and not inplen
), "Selecting scored logits for Seq2SeqLM requires only cont. len"
assert contlen and not inplen, (
"Selecting scored logits for Seq2SeqLM requires only cont. len"
)
# only discard right-padding.
# the logits input to this fn only contain decoder-side tokens.
logits = logits[:contlen]
......@@ -1329,9 +1331,9 @@ class HFLM(TemplateLM):
if self.backend == "causal":
# max len for inputs = max length, minus room to generate the max new tokens
max_ctx_len = self.max_length - max_gen_toks
assert (
max_ctx_len > 0
), f"Invalid configuration: requested max tokens to generate ({max_gen_toks}) must be less than model's maximum sequence length ({self.max_length})."
assert max_ctx_len > 0, (
f"Invalid configuration: requested max tokens to generate ({max_gen_toks}) must be less than model's maximum sequence length ({self.max_length})."
)
elif self.backend == "seq2seq":
# max len for inputs = encoder's whole max_length
max_ctx_len = self.max_length
......
......@@ -206,7 +206,7 @@ class NEURON_HF(TemplateLM):
"Only float16/bfloat16/float32 are supported."
)
print(f"{'='*20} \n exporting model to neuron")
print(f"{'=' * 20} \n exporting model to neuron")
self.model = CustomNeuronModelForCausalLM.from_pretrained(
pretrained,
revision=revision,
......@@ -220,19 +220,17 @@ class NEURON_HF(TemplateLM):
)
neuron_config = self.model.config.neuron
print(
f"SUCCESS: neuron model exported with config {neuron_config}. \n {'='*20}"
f"SUCCESS: neuron model exported with config {neuron_config}. \n {'=' * 20}"
)
else:
print(
f"{'='*20} \n loading neuron model with config" f" {neuron_config}..."
)
print(f"{'=' * 20} \n loading neuron model with config {neuron_config}...")
self.model = CustomNeuronModelForCausalLM.from_pretrained(
pretrained,
revision=revision,
trust_remote_code=trust_remote_code,
low_cpu_mem_usage=low_cpu_mem_usage,
)
print(f"SUCCESS: neuron model loaded. \n {'='*20}")
print(f"SUCCESS: neuron model loaded. \n {'=' * 20}")
self.truncation = truncation
......@@ -353,9 +351,9 @@ class NEURON_HF(TemplateLM):
)
def _select_cont_toks(self, logits, contlen=None, inplen=None):
assert (
contlen and inplen
), "Must pass input len and cont. len to select scored logits for causal LM"
assert contlen and inplen, (
"Must pass input len and cont. len to select scored logits for causal LM"
)
# discard right-padding.
# also discard the input/context tokens. we'll only score continuations.
logits = logits[inplen - contlen : inplen]
......
......@@ -145,9 +145,9 @@ class LocalChatCompletion(LocalCompletionsAPI):
eos=None,
**kwargs,
) -> dict:
assert (
type(messages) is not str
), "chat-completions require the --apply_chat_template flag."
assert type(messages) is not str, (
"chat-completions require the --apply_chat_template flag."
)
gen_kwargs.pop("do_sample", False)
if "max_tokens" in gen_kwargs:
max_tokens = gen_kwargs.pop("max_tokens")
......@@ -219,13 +219,12 @@ class OpenAICompletionsAPI(LocalCompletionsAPI):
return key
def loglikelihood(self, requests, **kwargs):
assert (
self.model
in [
"babbage-002",
"davinci-002",
]
), f"Prompt loglikelihoods are only supported by OpenAI's API for {['babbage-002', 'davinci-002']}."
assert self.model in [
"babbage-002",
"davinci-002",
], (
f"Prompt loglikelihoods are only supported by OpenAI's API for {['babbage-002', 'davinci-002']}."
)
return super().loglikelihood(requests, **kwargs)
def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
......@@ -276,9 +275,9 @@ class OpenAIChatCompletion(LocalChatCompletion):
eos="<|endoftext|>",
**kwargs,
) -> dict:
assert (
type(messages) is not str
), "chat-completions require the --apply_chat_template flag."
assert type(messages) is not str, (
"chat-completions require the --apply_chat_template flag."
)
gen_kwargs.pop("do_sample", False)
if "max_tokens" in gen_kwargs:
max_tokens = gen_kwargs.pop("max_tokens")
......
......@@ -21,9 +21,9 @@ class IPEXLM(HFLM):
) -> None:
if "backend" in kwargs:
# currently only supports causal models
assert (
kwargs["backend"] == "causal"
), "Currently, only IPEXModelForCausalLM is supported."
assert kwargs["backend"] == "causal", (
"Currently, only IPEXModelForCausalLM is supported."
)
super().__init__(
backend=kwargs.pop("backend", "causal"),
......
......@@ -29,9 +29,9 @@ class OptimumLM(HFLM):
) -> None:
if "backend" in kwargs:
# optimum currently only supports causal models
assert (
kwargs["backend"] == "causal"
), "Currently, only OVModelForCausalLM is supported."
assert kwargs["backend"] == "causal", (
"Currently, only OVModelForCausalLM is supported."
)
self.openvino_device = device
......
......@@ -155,9 +155,9 @@ def pad_and_concat(
length in the batch. Used for batching inputs and continuations in
seq2seq models.
"""
assert (
padding_side == "left" or padding_side == "right"
), f"Unrecognized padding type: '{padding_side}' not 'left' or 'right'"
assert padding_side == "left" or padding_side == "right", (
f"Unrecognized padding type: '{padding_side}' not 'left' or 'right'"
)
for i, tensor in enumerate(tensors):
if len(tensor.shape) == 2:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment