Commit 2b56339e authored by Baber

Merge branch 'main' into longcxt

parents 0b533339 703fbffd
......@@ -85,5 +85,6 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install -e '.[dev,optimum,deepsparse,sparseml,api]' --extra-index-url https://download.pytorch.org/whl/cpu
pip install -U transformers peft
- name: Test with pytest
run: python -m pytest tests/models --showlocals -s -vv
......@@ -58,7 +58,7 @@ This mode supports a number of command-line arguments, the details of which can
* `--seed`: Set the seed for Python's `random`, NumPy, and PyTorch. Accepts a comma-separated list of 3 values for Python's `random`, NumPy, and PyTorch seeds, respectively, or a single integer to set the same seed for all three. Each value is either an integer or `None` to leave that seed unset. Default is `0,1234,1234` (for backward compatibility). E.g., `--seed 0,None,8` sets `random.seed(0)` and `torch.manual_seed(8)`; NumPy's seed is not set since the second value is `None`. E.g., `--seed 42` sets all three seeds to 42. (See the sketch after this list.)
* `--wandb_args`: Tracks logging to Weights and Biases for evaluation runs and includes args passed to `wandb.init`, such as `project` and `job_type`. Full list [here](https://docs.wandb.ai/ref/python/init). e.g., ```--wandb_args project=test-project,name=test-run```
* `--wandb_args`: Tracks logging to Weights and Biases for evaluation runs and includes args passed to `wandb.init`, such as `project` and `job_type`. Full list [here](https://docs.wandb.ai/ref/python/init). e.g., ```--wandb_args project=test-project,name=test-run```. Also accepts a `step` value (passed to `wandb.run.log`) that fixes the step at which results are logged, e.g., `--wandb_args step=123`.
* `--hf_hub_log_args` : Logs evaluation results to Hugging Face Hub. Accepts a string with the arguments separated by commas. Available arguments:
* `hub_results_org` - organization name on Hugging Face Hub, e.g., `EleutherAI`. If not provided, the results will be pushed to the owner of the Hugging Face token,
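For readers driving the harness from Python rather than the CLI, a minimal sketch of the seed plumbing follows. Only `numpy_random_seed`, `torch_random_seed`, and `fewshot_random_seed` are visible in this diff; the `random_seed` name and the model/task values are assumptions:

```python
# Rough Python-API equivalent of `--seed 0,None,8` (hedged sketch).
import lm_eval

results = lm_eval.simple_evaluate(
    model="hf",
    model_args="pretrained=gpt2",  # hypothetical model
    tasks=["lambada_openai"],      # hypothetical task
    random_seed=0,                 # python's random.seed(0) (kwarg name assumed)
    numpy_random_seed=None,        # leave numpy's seed unset
    torch_random_seed=8,           # torch.manual_seed(8)
)
```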
......
......@@ -37,6 +37,7 @@ Prompting / in-context formatting options:
- **doc_to_choice** (`Union[Callable, str]`, *optional*) — Jinja2 template, string, or function to process a sample into a list of possible string choices for `multiple_choice` tasks. Left undefined for `generate_until` tasks.
- **fewshot_delimiter** (`str`, *optional*, defaults to "\n\n") — String to insert between few-shot examples.
- **target_delimiter** (`str`, *optional*, defaults to `" "`) — String to insert between input and target output for the datapoint being tested.
- **assistant_prefill** (`str`, *optional*) — String to append after the `<|assistant|>` token. For example, if the task is to answer a question, `assistant_prefill` could be "The answer is: " to steer the model into answering directly (see the sketch below). If no chat template is used, this string is appended to the end of the prompt.
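A minimal sketch of how a prefill changes the rendered prompt when a chat template is in play, mirroring the `add_generation_prompt`/`continue_final_message` logic elsewhere in this diff (the tokenizer name is a placeholder):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")  # placeholder
chat = [{"role": "user", "content": "What is 2 + 2?"}]

# Without a prefill: close the prompt with the assistant generation prompt.
no_prefill = tok.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)

# With a prefill: append a partial assistant turn and continue it.
prefilled = tok.apply_chat_template(
    chat + [{"role": "assistant", "content": "The answer is: "}],
    tokenize=False,
    add_generation_prompt=False,
    continue_final_message=True,
)
```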
Runtime configuration options:
- **num_fewshot** (`int`, *optional*, defaults to 0) — Number of few-shot examples before the input.
......
......@@ -257,6 +257,11 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="Sets trust_remote_code to True to execute code to create HF Datasets from the Hub",
)
parser.add_argument(
"--confirm_run_unsafe_code",
action="store_true",
help="Confirm that you understand the risks of running unsafe code for tasks that require it",
)
return parser
......@@ -404,6 +409,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
numpy_random_seed=args.seed[1],
torch_random_seed=args.seed[2],
fewshot_random_seed=args.seed[3],
confirm_run_unsafe_code=args.confirm_run_unsafe_code,
**request_caching_args,
)
......
......@@ -113,13 +113,17 @@ class LM(abc.ABC):
"""
pass
def apply_chat_template(self, chat_history: List[Dict[str, str]]) -> str:
def apply_chat_template(
self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
) -> str:
"""
Defines how to transform few-shot examples provided as chat history into a format that can be used as input to the LM.
:param chat_history: list[dict[str, str]]
A list of dictionaries with keys 'role' and 'content'.
Values are strings representing the role name and the content of the message, respectively.
:param add_generation_prompt: bool
Whether to append an assistant generation prefix (e.g. <|assistant|>) after the chat history. Set to False when prefilling an assistant message.
:return: str
A string representing the chat history in a format that can be used as input to the LM.
"""
......
from functools import partial
from typing import TYPE_CHECKING, Iterable, Optional, Union
import datasets
if TYPE_CHECKING:
from random import Random
from lm_eval.api.task import ConfigurableTask, Task
class ContextSampler:
def __init__(self, docs, task, fewshot_indices=None, rnd=None) -> None:
def __init__(
self,
docs: list[dict],
task: Union["Task", "ConfigurableTask"],
fewshot_indices: Optional[Iterable] = None,
rnd: Optional["Random"] = None,
) -> None:
self.rnd = rnd
if not self.rnd:
raise ValueError(
......@@ -58,8 +71,9 @@ class ContextSampler:
)
self.docs = self.docs.select(fewshot_indices)
def get_context(self, doc, num_fewshot):
def get_context(self, doc: dict, num_fewshot: int, assistant_prefill: Optional[str] = None):
# draw an extra fewshot sample if using same split as evaluating on
prefix = assistant_prefill + " " if assistant_prefill else ""
n_samples = (
num_fewshot + 1
if self.config.fewshot_split == self.config.test_split
......@@ -77,14 +91,14 @@ class ContextSampler:
for doc in selected_docs:
doc_content = self.doc_to_text(doc)
doc_target = self.doc_to_target(doc)
labeled_examples += (
doc_content
if self.config.doc_to_choice is None or isinstance(doc_content, str)
else self.doc_to_choice(doc)[doc_content]
)
if self.config.doc_to_choice is None or isinstance(doc_content, str):
labeled_examples += doc_content
else:
labeled_examples += self.doc_to_choice(doc)[doc_content]
if doc_target != "":
labeled_examples += self.target_delimiter
labeled_examples += prefix
labeled_examples += (
str(doc_target[0])
if isinstance(doc_target, list)
......@@ -98,10 +112,13 @@ class ContextSampler:
def get_chat_context(
self,
doc,
num_fewshot,
doc: dict,
num_fewshot: int,
fewshot_as_multiturn: bool = False,
assistant_prefill: Optional[str] = None,
):
# TODO: do we need any other delimiter?
prefix = assistant_prefill + " " if assistant_prefill else ""
chat_history = []
# draw an extra fewshot sample if using same split as evaluating on
n_samples = (
......@@ -132,23 +149,28 @@ class ContextSampler:
chat_history.append(
{
"role": "assistant",
"content": str(doc_target[0])
"content": prefix + str(doc_target[0])
if isinstance(doc_target, list)
else doc_target
else prefix + doc_target
if self.config.doc_to_choice is None
or isinstance(doc_target, str)
else str(self.doc_to_choice(doc)[doc_target]),
else prefix + str(self.doc_to_choice(doc)[doc_target]),
}
)
else:
# get fewshot context as one user turn
chat_history.append(
{"role": "user", "content": self.get_context(doc, num_fewshot)}
{
"role": "user",
"content": self.get_context(
doc, num_fewshot, assistant_prefill=assistant_prefill
),
}
)
return chat_history
def sample(self, n):
def sample(self, n: int):
"""
Draw `n` samples from our fewshot docs. This method should be overridden by subclasses.
"""
......@@ -157,7 +179,7 @@ class ContextSampler:
class FirstNSampler(ContextSampler):
def sample(self, n) -> None:
def sample(self, n: int) -> None:
"""
Draw the first `n` samples in order from the specified split.
Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU.
......@@ -169,7 +191,7 @@ class FirstNSampler(ContextSampler):
class BalancedSampler(ContextSampler):
def sample(self, n) -> None:
def sample(self, n: int) -> None:
"""
TODO: this should return approximately class-balanced samples from our fewshot examples.
TODO: what order should they be in? maybe random?
......@@ -179,7 +201,7 @@ class BalancedSampler(ContextSampler):
class ManualSampler(ContextSampler):
def sample(self, n) -> None:
def sample(self, n: int) -> None:
""" """
pass
......@@ -190,7 +212,7 @@ SAMPLER_REGISTRY = {
}
def get_sampler(name):
def get_sampler(name: str):
try:
return SAMPLER_REGISTRY[name]
except KeyError:
......
......@@ -76,6 +76,7 @@ class TaskConfig(dict):
doc_to_text: Optional[Union[Callable, str]] = None
doc_to_target: Optional[Union[Callable, str]] = None
doc_to_image: Union[Callable, str] = None
unsafe_code: bool = False
doc_to_choice: Optional[Union[Callable, str, dict, list]] = None
process_results: Optional[Union[Callable, str]] = None
use_prompt: Optional[str] = None
......@@ -93,6 +94,7 @@ class TaskConfig(dict):
filter_list: Optional[Union[str, list]] = None
should_decontaminate: bool = False
doc_to_decontamination_query: Optional[str] = None
assistant_prefill: Optional[str] = None
metadata: Optional[dict] = (
None # by default, not used in the code. allows for users to pass arbitrary info to tasks
)
......@@ -400,7 +402,7 @@ class Task(abc.ABC):
)
cache_key += f"-tokenizer{tokenizer_name}"
cached_instances = load_from_cache(file_name=cache_key)
cached_instances = load_from_cache(file_name=cache_key, cache=cache_requests)
if cache_requests and cached_instances and not rewrite_requests_cache:
cached_instances = cached_instances[:limit]
......@@ -444,6 +446,7 @@ class Task(abc.ABC):
apply_chat_template,
fewshot_as_multiturn,
chat_template,
assistant_prefill=self.config.assistant_prefill,
)
# TODO: we should override self.config.repeats if doing greedy gen so users don't waste time+compute
......@@ -734,6 +737,9 @@ class ConfigurableTask(Task):
# mark the task as requiring multimodality.
self.MULTIMODAL = True
if self.config.unsafe_code is not False:
self.UNSAFE_CODE = True
if self.config.dataset_path is not None:
self.DATASET_PATH = self.config.dataset_path
......@@ -1012,6 +1018,7 @@ class ConfigurableTask(Task):
labeled_examples: List[Dict[str, str]],
question: str,
fewshot_as_multiturn: bool = False,
assistant_prefill: Optional[str] = None,
) -> None:
"""Adds a target question to the labeled examples list.
If fewshot_as_multiturn is True, or labeled_examples is empty, or the last entry is a system turn, appends the question as a new user entry.
......@@ -1027,17 +1034,20 @@ class ConfigurableTask(Task):
else:
# if fewshot_as_multiturn is True, append as next user entry (last is always assistant)
labeled_examples.append({"role": "user", "content": question})
if assistant_prefill:
labeled_examples.append({"role": "assistant", "content": assistant_prefill})
@utils.positional_deprecated
def fewshot_context(
self,
doc: str,
doc: dict,
num_fewshot: int,
system_instruction: Optional[str] = None,
apply_chat_template: bool = False,
fewshot_as_multiturn: bool = False,
chat_template: Optional[Callable] = None,
) -> str:
assistant_prefill: Optional[str] = None,
) -> Union[str, List[str]]:
"""Returns a fewshot context string that is made up of a prepended description
(if provided), the `num_fewshot` number of examples, and an appended prompt example.
......@@ -1056,7 +1066,6 @@ class ConfigurableTask(Task):
:returns: str | list[str]
The fewshot context, or a list of contexts when the example expands into multiple choices.
"""
if apply_chat_template:
labeled_examples = []
else:
......@@ -1090,19 +1099,28 @@ class ConfigurableTask(Task):
if apply_chat_template:
labeled_examples.extend(
self.sampler.get_chat_context(
doc, num_fewshot, fewshot_as_multiturn
doc,
num_fewshot,
fewshot_as_multiturn,
assistant_prefill=assistant_prefill,
)
)
else:
labeled_examples += self.sampler.get_context(doc, num_fewshot)
labeled_examples += self.sampler.get_context(
doc, num_fewshot, assistant_prefill=assistant_prefill
)
example = self.doc_to_text(doc)
if apply_chat_template:
if self.multiple_input:
# TODO: append prefill?
return chat_template(labeled_examples)
if isinstance(example, str):
self.append_target_question(
labeled_examples, example, fewshot_as_multiturn
labeled_examples,
example,
fewshot_as_multiturn,
assistant_prefill=assistant_prefill,
)
# for loglikelihood create a list of questions with appended choices
elif isinstance(example, list):
......@@ -1110,37 +1128,62 @@ class ConfigurableTask(Task):
# copy chat history for each example and append the answer
for ex in example:
chat = deepcopy(labeled_examples)
self.append_target_question(chat, ex, fewshot_as_multiturn)
labeled_examples_list.append(chat_template(chat))
self.append_target_question(
chat,
ex,
fewshot_as_multiturn,
assistant_prefill=assistant_prefill,
)
# TODO: append prefill?
labeled_examples_list.append(
chat_template(
chat,
add_generation_prompt=not assistant_prefill,
)
)
return labeled_examples_list
# if example is an integer, append the choice or convert to string
elif isinstance(example, int):
if self.config.doc_to_choice is not None:
choices = self.doc_to_choice(doc)
self.append_target_question(
labeled_examples, choices[example], fewshot_as_multiturn
labeled_examples,
choices[example],
fewshot_as_multiturn,
assistant_prefill=assistant_prefill,
)
else:
self.append_target_question(
labeled_examples, str(example), fewshot_as_multiturn
labeled_examples,
str(example),
fewshot_as_multiturn,
assistant_prefill=assistant_prefill,
)
# return lm.apply_chat_template(labeled_examples)
return chat_template(labeled_examples)
return chat_template(
labeled_examples,
add_generation_prompt=not assistant_prefill,
)
else:
prefix = (
self.config.target_delimiter + assistant_prefill
if assistant_prefill is not None
else ""
)
if self.multiple_input:
return labeled_examples
if isinstance(example, str):
return labeled_examples + example
return labeled_examples + example + prefix
elif isinstance(example, list):
return [labeled_examples + ex for ex in example]
return [labeled_examples + ex + prefix for ex in example]
elif isinstance(example, int):
if self.config.doc_to_choice is not None:
choices = self.doc_to_choice(doc)
return labeled_examples + choices[example]
return labeled_examples + choices[example] + prefix
else:
return labeled_examples + str(example)
return labeled_examples + str(example) + prefix
def apply_filters(self):
def apply_filters(self) -> Optional[List[Instance]]:
"""Iterates over FilterEnsembles and applies them to instances"""
if hasattr(self, "_filters"):
for f in self._filters:
......@@ -1152,7 +1195,7 @@ class ConfigurableTask(Task):
def should_decontaminate(self):
return self.config.should_decontaminate
def doc_to_decontamination_query(self, doc):
def doc_to_decontamination_query(self, doc: dict):
if self.config.should_decontaminate:
if self.config.doc_to_decontamination_query is None:
return self.doc_to_text(doc)
......@@ -1515,9 +1558,9 @@ class ConfigurableTask(Task):
# we expect multiple_targets to be a list.
elif self.multiple_target:
gold = list(gold)
elif (
type(gold) is not type(result)
and "bypass" not in self._metric_fn_list.keys()
# TODO: handle this better
elif type(gold) is not type(result) and not (
"bypass" in self._metric_fn_list.keys() or isinstance(result, list)
):
# cast gold to the same type as result
gold = type(result)(gold)
......@@ -1573,7 +1616,10 @@ class ConfigurableTask(Task):
result_score = self._metric_fn_list[metric]([gold, result])
if isinstance(result_score, dict):
# TODO: this handles the case where HF evaluate returns a dict.
result_score = result_score[metric]
# This allows for multiple metrics to be returned from the same function
for k, v in result_score.items():
result_dict[k] = v
return result_dict
result_dict[metric] = result_score
else:
raise ValueError(
......
......@@ -21,7 +21,9 @@ HASH_PREFIX = hashlib.sha256(HASH_INPUT.encode("utf-8")).hexdigest()
FILE_SUFFIX = f".{HASH_PREFIX}.pickle"
def load_from_cache(file_name):
def load_from_cache(file_name: str, cache: bool = False):
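# skip the cache lookup entirely unless request caching was enabled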
if not cache:
return
try:
path = f"{PATH}/{file_name}{FILE_SUFFIX}"
......
......@@ -74,6 +74,7 @@ def simple_evaluate(
numpy_random_seed: int = 1234,
torch_random_seed: int = 1234,
fewshot_random_seed: int = 1234,
confirm_run_unsafe_code: bool = False,
):
"""Instantiate and evaluate a model on a list of tasks.
......@@ -316,6 +317,7 @@ def simple_evaluate(
apply_chat_template=apply_chat_template,
fewshot_as_multiturn=fewshot_as_multiturn,
verbosity=verbosity,
confirm_run_unsafe_code=confirm_run_unsafe_code,
)
if lm.rank == 0:
......@@ -375,6 +377,7 @@ def evaluate(
apply_chat_template: Union[bool, str] = False,
fewshot_as_multiturn: bool = False,
verbosity: str = "INFO",
confirm_run_unsafe_code: bool = False,
):
"""Instantiate and evaluate a model on a list of tasks.
......@@ -384,6 +387,10 @@ def evaluate(
Dictionary of tasks. Tasks will be taken to have name type(task).config.task .
:param limit: int, optional
Limit the number of examples per task (only use this for testing)
:param cache_requests: bool, optional
Speed up evaluation by caching the building of dataset requests.
:param rewrite_requests_cache: bool, optional
Rewrites all the request cache if set to `True`.
:param bootstrap_iters:
Number of iterations for bootstrap statistics, used when calculating stderr. Set to 0 for skipping all stderr calculations.
:param write_out: bool
......@@ -399,6 +406,10 @@ def evaluate(
Defaults to False (no chat template applied).
:param fewshot_as_multiturn: bool
Whether to provide the fewshot examples as a multiturn conversation or a single user turn.
:param verbosity: str
Verbosity level for logging
:param confirm_run_unsafe_code: bool
Whether to confirm running tasks marked as unsafe.
:return
Dictionary of results
"""
......@@ -425,13 +436,19 @@ def evaluate(
):
raise ValueError("log_samples must be True for 'bypass' metric-only tasks")
# validation check: are we running multimodal task <-> non-multimodal model class, or vice-versa.
# validation checks:
# 1. are we running a multimodal task <-> non-multimodal model class, or vice-versa?
# 2. are we running code that is marked as unsafe?
incompatible_tasks = []
for task_output in eval_tasks:
task: Task = task_output.task
if getattr(lm, "MULTIMODAL", False) != getattr(task, "MULTIMODAL", False):
incompatible_tasks.append(task_output.task_name)
elif getattr(task, "UNSAFE_CODE", False) and not confirm_run_unsafe_code:
raise ValueError(
f"Attempted to run task: {task_output.task_name} which is marked as unsafe. Set confirm_run_unsafe_code=True to run this task."
)
if len(incompatible_tasks) > 0:
if not getattr(lm, "MULTIMODAL", False):
raise ValueError(
......@@ -441,7 +458,7 @@ def evaluate(
raise ValueError(
f"Attempted to run tasks: {incompatible_tasks} which are text-only, but used a model type which only currently supports multimodal tasks."
)
# end multimodality validation check
# end validation check
# Cache the limit arg.
limit_arg = limit
......
......@@ -7,6 +7,7 @@ from typing import List, Optional, Tuple, Union
from lm_eval.api.group import ConfigurableGroup
from lm_eval.api.metrics import (
aggregate_subtask_metrics,
mean,
pooled_sample_stderr,
stderr_for_metric,
)
......@@ -99,7 +100,12 @@ class TaskOutput:
def calculate_aggregate_metric(self, bootstrap_iters=100000) -> None:
for (metric, filter_key), items in self.sample_metrics.items():
try:
agg_fn = self.task.aggregation()[metric]
except KeyError:
# This happens when process_results outputs an arbitrary (unregistered) metric
# TODO: Handle this better and allow aggregate functions other than mean.
agg_fn = mean
metric_key = f"{metric},{filter_key}"
self.agg_metrics[metric_key] = agg_fn(items)
self.sample_len = len(items) # TODO: same sample size for each metric?
......
......@@ -4,7 +4,7 @@ from typing import List
from lm_eval.api.filter import FilterEnsemble
from lm_eval.api.registry import get_filter
from . import extraction, selection, transformation
from . import custom, extraction, selection, transformation
def build_filter_ensemble(
......
from lm_eval.api.filter import Filter
from lm_eval.api.registry import register_filter
@register_filter("custom")
class CustomFilter(Filter):
"""
Custom filter that applies a user-defined function to the model responses.
"""
def __init__(self, **kwargs) -> None:
self.filter_fn = kwargs.pop("filter_fn")
super().__init__(**kwargs)
def apply(self, resps, docs):
return self.filter_fn(resps, docs)
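A hedged usage sketch, based only on the class above (the function and values are illustrative):

```python
# Hypothetical usage: filter_fn maps (resps, docs) -> filtered resps.
def take_first(resps, docs):
    # keep only the first response per document
    return [r[:1] for r in resps]

f = CustomFilter(filter_fn=take_first)
print(f.apply([["a", "b"], ["c"]], docs=[{}, {}]))  # -> [['a'], ['c']]
```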
......@@ -8,12 +8,17 @@ from lm_eval.api.registry import register_filter
@register_filter("regex")
class RegexFilter(Filter):
""" """
"""A filter that extracts values from text using regex pattern matching.
This filter applies a regex pattern to each model response and extracts matched values.
If no match is found, returns a fallback value. Useful for extracting structured data
(like numbers) from unstructured model outputs.
"""
def __init__(
self,
regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
group_select=0,
group_select: int = 0,
fallback: str = "[invalid]",
) -> None:
"""
......@@ -25,7 +30,7 @@ class RegexFilter(Filter):
self.group_select = group_select
self.fallback = fallback
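An illustrative sketch of the documented behavior under the defaults above (expected output, not a tested result):

```python
# With the default pattern r"#### (\-?[0-9\.\,]+)", a "#### 42"-style answer
# is extracted; unmatched responses fall back to "[invalid]".
f = RegexFilter()
out = f.apply([["The total is #### 42", "no final answer given"]], docs=[{}])
# expected: [["42", "[invalid]"]]
```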
def apply(self, resps, docs):
def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
# here, we assume we have a list, in which each element is
# a list of model responses for some particular input/target pair.
# so we process each of these (same input/target response sets)
......@@ -55,12 +60,9 @@ class RegexFilter(Filter):
@register_filter("remove_whitespace")
class WhitespaceFilter(Filter):
""" """
def __init__(self) -> None:
pass
"""Filters out leading whitespace from responses."""
def apply(self, resps, docs):
def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
def filter_set(inst):
filtered_resp = []
for resp in inst:
......@@ -105,7 +107,7 @@ class MultiChoiceRegexFilter(RegexFilter):
self.ignore_punctuation = ignore_punctuation
self.regexes_to_ignore = regexes_to_ignore
def apply(self, resps, docs):
def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
# here, we assume we have a list, in which each element is
# a list of model responses for some particular input/target pair.
# so we process each of these (same input/target response sets)
......
......@@ -48,6 +48,9 @@ class WandbLogger:
self.wandb_args: Dict[str, Any] = kwargs
# pop the step key from the args and save it for all logging calls
self.step = self.wandb_args.pop("step", None)
# initialize a W&B run
if wandb.run is None:
self.run = wandb.init(**self.wandb_args)
......@@ -152,11 +155,11 @@ class WandbLogger:
# log the complete eval result to W&B Table
table = make_table(["Tasks"] + columns, "results")
self.run.log({"evaluation/eval_results": table})
self.run.log({"evaluation/eval_results": table}, step=self.step)
if "groups" in self.results.keys():
table = make_table(["Groups"] + columns, "groups")
self.run.log({"evaluation/group_eval_results": table})
self.run.log({"evaluation/group_eval_results": table}, step=self.step)
def _log_results_as_artifact(self) -> None:
"""Log results as JSON artifact to W&B."""
......@@ -174,13 +177,13 @@ class WandbLogger:
"""Log evaluation results to W&B."""
# Log configs to wandb
configs = self._get_config()
self.run.config.update(configs)
self.run.config.update(configs, allow_val_change=self.step is not None)
wandb_summary, self.wandb_results = self._sanitize_results_dict()
# update wandb.run.summary with items that were removed
self.run.summary.update(wandb_summary)
# Log the evaluation metrics to wandb
self.run.log(self.wandb_results)
self.run.log(self.wandb_results, step=self.step)
# Log the evaluation metrics as W&B Table
self._log_results_as_table()
# Log the results dict as json to W&B Artifacts
......@@ -329,7 +332,7 @@ class WandbLogger:
# log the samples as a W&B Table
df = self._generate_dataset(eval_preds, self.task_configs.get(task_name))
self.run.log({f"{task_name}_eval_results": df})
self.run.log({f"{task_name}_eval_results": df}, step=self.step)
# log the samples as a json file as W&B Artifact
self._log_samples_as_artifact(eval_preds, task_name)
......@@ -348,4 +351,4 @@ class WandbLogger:
# log the samples as a json file as W&B Artifact
self._log_samples_as_artifact(eval_preds, task_name)
self.run.log({f"{group}_eval_results": grouped_df})
self.run.log({f"{group}_eval_results": grouped_df}, step=self.step)
......@@ -253,12 +253,15 @@ class TemplateAPI(TemplateLM):
return ""
def apply_chat_template(
self, chat_history: List[Dict[str, str]]
self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
) -> Union[str, JsonChatStr]:
"""Applies a chat template to a list of chat history between user and model."""
if self.tokenizer_backend == "huggingface" and self.tokenized_requests:
return self.tokenizer.apply_chat_template(
chat_history, tokenize=False, add_generation_prompt=True
chat_history,
tokenize=False,
add_generation_prompt=add_generation_prompt,
continue_final_message=not add_generation_prompt,
)
else:
# bit of a hack. We'll load back before sending to the API
......
......@@ -200,7 +200,9 @@ class HFMultimodalLM(HFLM):
return context_enc, continuation_enc, image_enc
def apply_chat_template(self, chat_history: List[Dict[str, str]]) -> str:
def apply_chat_template(
self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
) -> str:
self.chat_applied = True
if not self.interleave:
for content in chat_history:
......@@ -250,7 +252,9 @@ class HFMultimodalLM(HFLM):
)
return self.processor.apply_chat_template(
chat_history, add_generation_prompt=True
chat_history,
add_generation_prompt=add_generation_prompt,
continue_final_message=not add_generation_prompt,
)
def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
......
......@@ -90,6 +90,7 @@ class HFLM(TemplateLM):
delta: Optional[str] = None,
autogptq: Optional[Union[bool, str]] = False,
gptqmodel: Optional[bool] = False,
gguf_file: Optional[str] = None,
**kwargs,
) -> None:
super().__init__()
......@@ -164,6 +165,7 @@ class HFLM(TemplateLM):
pretrained,
revision=revision,
trust_remote_code=trust_remote_code,
gguf_file=gguf_file,
)
# determine which of 'causal' and 'seq2seq' backends to use for HF models
......@@ -178,6 +180,7 @@ class HFLM(TemplateLM):
revision=revision,
trust_remote_code=trust_remote_code,
use_fast_tokenizer=use_fast_tokenizer,
gguf_file=gguf_file,
)
# if we passed `pretrained` as a string, initialize our model now
......@@ -196,6 +199,7 @@ class HFLM(TemplateLM):
delta=delta,
autogptq=autogptq,
gptqmodel=gptqmodel,
gguf_file=gguf_file,
**kwargs,
)
......@@ -508,12 +512,14 @@ class HFLM(TemplateLM):
pretrained: str,
revision: str = "main",
trust_remote_code: bool = False,
gguf_file: Optional[str] = None,
) -> None:
"""Return the model config for HuggingFace models"""
self._config = transformers.AutoConfig.from_pretrained(
pretrained,
revision=revision,
trust_remote_code=trust_remote_code,
gguf_file=gguf_file,
)
def _create_model(
......@@ -535,6 +541,7 @@ class HFLM(TemplateLM):
delta: Optional[str] = None,
autogptq: Optional[Union[bool, str]] = False,
gptqmodel: Optional[bool] = False,
gguf_file: Optional[str] = None,
**kwargs,
) -> None:
"""
......@@ -579,6 +586,7 @@ class HFLM(TemplateLM):
revision=revision,
torch_dtype=get_dtype(dtype),
trust_remote_code=trust_remote_code,
gguf_file=gguf_file,
**model_kwargs,
)
else:
......@@ -676,6 +684,7 @@ class HFLM(TemplateLM):
revision: Optional[str] = "main",
trust_remote_code: Optional[bool] = False,
use_fast_tokenizer: Optional[bool] = True,
gguf_file: Optional[str] = None,
) -> None:
"""
Helper method during initialization.
......@@ -683,14 +692,21 @@ class HFLM(TemplateLM):
Create a tokenizer object corresponding to the correct
tokenizer for value of `pretrained`, or use the pre-initialized tokenizer passed.
"""
kwargs = {
"revision": revision,
"trust_remote_code": trust_remote_code,
}
# the GGUF format embeds its own tokenizer and is not compatible with the HF tokenizer's `use_fast` param
if gguf_file is not None:
kwargs["gguf_file"] = gguf_file
else:
kwargs["use_fast"] = use_fast_tokenizer
if tokenizer:
if isinstance(tokenizer, str):
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
tokenizer,
revision=revision,
trust_remote_code=trust_remote_code,
use_fast=use_fast_tokenizer,
tokenizer, **kwargs
)
else:
assert isinstance(
......@@ -705,10 +721,7 @@ class HFLM(TemplateLM):
# get the HF hub name via accessor on model
model_name = self.model.name_or_path
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
model_name,
revision=revision,
trust_remote_code=trust_remote_code,
use_fast=use_fast_tokenizer,
model_name, **kwargs
)
return None
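A minimal sketch of the new `gguf_file` plumbing; the repo and file names are placeholders, not tested values:

```python
from lm_eval.models.huggingface import HFLM

lm = HFLM(
    pretrained="TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",  # hypothetical repo
    gguf_file="tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf",     # hypothetical file
)
```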
......@@ -1369,13 +1382,18 @@ class HFLM(TemplateLM):
return res
def apply_chat_template(self, chat_history: List[Dict[str, str]]) -> str:
def apply_chat_template(
self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
) -> str:
"""
Method to apply a chat template to a list of chat history between user and model.
"""
try:
chat_templated = self.tokenizer.apply_chat_template(
chat_history, tokenize=False, add_generation_prompt=True
chat_history,
tokenize=False,
add_generation_prompt=add_generation_prompt,
continue_final_message=not add_generation_prompt,
)
except jinja2.exceptions.TemplateError:
eval_logger.warning(
......@@ -1383,7 +1401,10 @@ class HFLM(TemplateLM):
)
chat_history = [msg for msg in chat_history if msg["role"] != "system"]
chat_templated = self.tokenizer.apply_chat_template(
chat_history, tokenize=False, add_generation_prompt=True
chat_history,
tokenize=False,
add_generation_prompt=add_generation_prompt,
continue_final_message=not add_generation_prompt,
)
return chat_templated
......
import os
from functools import cached_property
from operator import itemgetter
from typing import Any, Dict, List, Optional, Tuple, Union
from lm_eval.api.registry import register_model
......@@ -68,7 +69,9 @@ class LocalCompletionsAPI(TemplateAPI):
if not isinstance(outputs, list):
outputs = [outputs]
for out in outputs:
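# the API may return choices out of order, so realign them with ctxlens by index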
for choice, ctxlen in zip(out["choices"], ctxlens):
for choice, ctxlen in zip(
sorted(out["choices"], key=itemgetter("index")), ctxlens
):
assert ctxlen > 0, "Context length must be greater than 0"
logprobs = sum(choice["logprobs"]["token_logprobs"][ctxlen:-1])
tokens_logprobs = choice["logprobs"]["token_logprobs"][ctxlen:-1]
......@@ -87,8 +90,10 @@ class LocalCompletionsAPI(TemplateAPI):
if not isinstance(outputs, list):
outputs = [outputs]
for out in outputs:
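# collect choices into index-ordered slots, since they may arrive out of order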
tmp = [None] * len(out["choices"])
for choices in out["choices"]:
res.append(choices["text"])
tmp[choices["index"]] = choices["text"]
res = res + tmp
return res
@property
......@@ -157,8 +162,10 @@ class LocalChatCompletion(LocalCompletionsAPI):
if not isinstance(outputs, list):
outputs = [outputs]
for out in outputs:
tmp = [None] * len(out["choices"])
for choices in out["choices"]:
res.append(choices["message"]["content"])
tmp[choices["index"]] = choices["message"]["content"]
res = res + tmp
return res
def tok_encode(
......
......@@ -184,14 +184,21 @@ class VLLM(TemplateLM):
def max_gen_toks(self):
return self._max_gen_toks
def apply_chat_template(self, chat_history: List[Dict[str, str]]) -> str:
def apply_chat_template(
self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
) -> str:
"""
Method to apply a chat template to a list of chat history between user and model.
"""
return self.tokenizer.apply_chat_template(
chat_history, tokenize=False, add_generation_prompt=True
chat_templated = self.tokenizer.apply_chat_template(
chat_history,
tokenize=False,
add_generation_prompt=add_generation_prompt,
continue_final_message=not add_generation_prompt,
)
return chat_templated
@property
def tokenizer_name(self) -> str:
return self.tokenizer.name_or_path.replace("/", "__")
......
......@@ -144,7 +144,9 @@ class VLLM_VLM(VLLM):
)
return outputs
def apply_chat_template(self, chat_history: List[Dict[str, str]]) -> str:
def apply_chat_template(
self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
) -> str:
self.chat_applied = True
if not self.interleave:
for content in chat_history:
......@@ -194,7 +196,9 @@ class VLLM_VLM(VLLM):
)
return self.processor.apply_chat_template(
chat_history, add_generation_prompt=True
chat_history,
add_generation_prompt=add_generation_prompt,
continue_final_message=not add_generation_prompt,
)
def generate_until(
......