Unverified commit 703fbffd authored by Baber Abbasi, committed by GitHub

assistant prefill (#2615)

* add assistant prefix

* add arc_challenge from llama

* nit

* nit

* nit

* add assistant prefix

* add mmlu_llama

* nit

* nit

* Revert "nit"

This reverts commit 6a97f8356237305e375212b966b30e8de59dd4bc.

* fix regex bug

* add assistant_prefix to vllm

* add `Question:`

* add mmlu_pro

* add fewshot assistant_prefix

* use `assistant_prefill`

* typehints

* nits

* nits

* add to docs

* add readme
parent e86cece6
......@@ -37,6 +37,7 @@ Prompting / in-context formatting options:
- **doc_to_choice** (`Union[Callable, str]`, *optional*) — Jinja2 template, string, or function to process a sample into a list of possible string choices for `multiple_choice` tasks. Left undefined for `generate_until` tasks.
- **fewshot_delimiter** (`str`, *optional*, defaults to "\n\n") — String to insert between few-shot examples.
- **target_delimiter** (`str`, *optional*, defaults to `" "`) — String to insert between input and target output for the datapoint being tested.
- **assistant_prefill** (`str`, *optional*) — String used to prefill the assistant's response, inserted after the assistant token (e.g. `<|assistant|>`). For example, if the task asks the model to answer a question, an `assistant_prefill` of `"The answer is:"` nudges the model to complete the answer directly. When no chat template is used, the string is instead appended to the end of the prompt (see the sketch below).
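As a rough illustration of where the prefill lands in each mode (a minimal sketch with made-up strings, not the harness's exact rendering):

```python
# Illustrative only: shows where assistant_prefill ends up in each mode.
question = "Question: What is 2 + 2?"
prefill = "The answer is:"

# Without a chat template: appended to the end of the rendered prompt,
# after the target delimiter (default " ").
plain_prompt = question + " " + prefill
# -> "Question: What is 2 + 2? The answer is:"

# With a chat template: sent as a trailing assistant turn that the model
# continues, instead of opening a fresh assistant turn.
chat = [
    {"role": "user", "content": question},
    {"role": "assistant", "content": prefill},
]
```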
Runtime configuration options:
- **num_fewshot** (`int`, *optional*, defaults to 0) — Number of few-shot examples before the input.
......
......@@ -113,13 +113,17 @@ class LM(abc.ABC):
"""
pass
def apply_chat_template(self, chat_history: List[Dict[str, str]]) -> str:
def apply_chat_template(
self, chat_history: List[Dict[str, str]], add_generation_prompt=True
) -> str:
"""
Defines how to transform few-shot examples provided as chat history into a format that can be used as input to the LM.
:param chat_history: list[dict[str, str]]
A list of dictionaries with keys 'role' and 'content'.
Values are strings representing the role name and the content of the message, respectively.
:param add_generation_prompt: bool
Whether to append the assistant generation prompt (e.g. <|assistant|>) to the end of the chat history. Set to False when prefilling an assistant message.
:return: str
A string representing the chat history in a format that can be used as input to the LM.
"""
......
from functools import partial
from typing import TYPE_CHECKING, Iterable, Optional, Union
import datasets
if TYPE_CHECKING:
from random import Random
from lm_eval.api.task import ConfigurableTask, Task
class ContextSampler:
def __init__(self, docs, task, fewshot_indices=None, rnd=None) -> None:
def __init__(
self,
docs: list[dict],
task: Union["Task", "ConfigurableTask"],
fewshot_indices: Optional[Iterable] = None,
rnd: Optional["Random"] = None,
) -> None:
self.rnd = rnd
if not self.rnd:
raise ValueError(
......@@ -58,8 +71,9 @@ class ContextSampler:
)
self.docs = self.docs.select(fewshot_indices)
def get_context(self, doc, num_fewshot):
def get_context(self, doc: dict, num_fewshot: int, assistant_prefill: Optional[str] = None):
# draw an extra fewshot sample if using same split as evaluating on
prefix = assistant_prefill + " " if assistant_prefill else ""
n_samples = (
num_fewshot + 1
if self.config.fewshot_split == self.config.test_split
......@@ -77,14 +91,14 @@ class ContextSampler:
for doc in selected_docs:
doc_content = self.doc_to_text(doc)
doc_target = self.doc_to_target(doc)
labeled_examples += (
doc_content
if self.config.doc_to_choice is None or isinstance(doc_content, str)
else self.doc_to_choice(doc)[doc_content]
)
if self.config.doc_to_choice is None or isinstance(doc_content, str):
labeled_examples += doc_content
else:
labeled_examples += self.doc_to_choice(doc)[doc_content]
if doc_target != "":
labeled_examples += self.target_delimiter
labeled_examples += prefix
labeled_examples += (
str(doc_target[0])
if isinstance(doc_target, list)
......@@ -98,10 +112,13 @@ class ContextSampler:
def get_chat_context(
self,
doc,
num_fewshot,
doc: dict,
num_fewshot: int,
fewshot_as_multiturn: bool = False,
assistant_prefill: Optional[str] = None,
):
# TODO: Do we need any other delimiter
prefix = assistant_prefill + " " if assistant_prefill else ""
chat_history = []
# draw an extra fewshot sample if using same split as evaluating on
n_samples = (
......@@ -132,23 +149,28 @@ class ContextSampler:
chat_history.append(
{
"role": "assistant",
"content": str(doc_target[0])
"content": prefix + str(doc_target[0])
if isinstance(doc_target, list)
else doc_target
else prefix + doc_target
if self.config.doc_to_choice is None
or isinstance(doc_target, str)
else str(self.doc_to_choice(doc)[doc_target]),
else prefix + str(self.doc_to_choice(doc)[doc_target]),
}
)
else:
# get fewshot context as one user turn
chat_history.append(
{"role": "user", "content": self.get_context(doc, num_fewshot)}
{
"role": "user",
"content": self.get_context(
doc, num_fewshot, assistant_prefill=assistant_prefill
),
}
)
return chat_history
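A hedged sketch of what `get_chat_context` now returns for a single few-shot example with `fewshot_as_multiturn=True` and `assistant_prefill="The best answer is"` (contents abbreviated):

```python
# Hypothetical result; the prefix is prepended to each few-shot target.
chat_history = [
    {"role": "user", "content": "Question: ...\nA. ...\nB. ...\nC. ...\nD. ..."},
    {"role": "assistant", "content": "The best answer is A."},
]
```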
def sample(self, n):
def sample(self, n: int):
"""
Draw `n` samples from our fewshot docs. This method should be overridden by subclasses.
"""
......@@ -157,7 +179,7 @@ class ContextSampler:
class FirstNSampler(ContextSampler):
def sample(self, n) -> None:
def sample(self, n: int) -> None:
"""
Draw the first `n` samples in order from the specified split.
Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU.
......@@ -169,7 +191,7 @@ class FirstNSampler(ContextSampler):
class BalancedSampler(ContextSampler):
def sample(self, n) -> None:
def sample(self, n: int) -> None:
"""
TODO: this should return approximately class-balanced samples from our fewshot examples.
TODO: what order should they be in? maybe random?
......@@ -179,7 +201,7 @@ class BalancedSampler(ContextSampler):
class ManualSampler(ContextSampler):
def sample(self, n) -> None:
def sample(self, n: int) -> None:
""" """
pass
......@@ -190,7 +212,7 @@ SAMPLER_REGISTRY = {
}
def get_sampler(name):
def get_sampler(name: str):
try:
return SAMPLER_REGISTRY[name]
except KeyError:
......
......@@ -93,6 +93,7 @@ class TaskConfig(dict):
filter_list: Optional[Union[str, list]] = None
should_decontaminate: bool = False
doc_to_decontamination_query: Optional[str] = None
assistant_prefill: Optional[str] = None
metadata: Optional[dict] = (
None # by default, not used in the code. allows for users to pass arbitrary info to tasks
)
......@@ -443,6 +444,7 @@ class Task(abc.ABC):
apply_chat_template,
fewshot_as_multiturn,
chat_template,
assistant_prefill=self.config.assistant_prefill,
)
# TODO: we should override self.config.repeats if doing greedy gen so users don't waste time+compute
......@@ -1004,6 +1006,7 @@ class ConfigurableTask(Task):
labeled_examples: List[Dict[str, str]],
question: str,
fewshot_as_multiturn: bool = False,
assistant_prefill: Optional[str] = None,
) -> None:
"""Adds a target question to the labeled examples list.
If fewshot_as_multiturn is True, or labeled_examples is empty, or the last entry is a system turn, appends the question as a new user entry.
......@@ -1019,17 +1022,20 @@ class ConfigurableTask(Task):
else:
# if fewshot_as_multiturn is True, append as next user entry (last is always assistant)
labeled_examples.append({"role": "user", "content": question})
if assistant_prefill:
labeled_examples.append({"role": "assistant", "content": assistant_prefill})
@utils.positional_deprecated
def fewshot_context(
self,
doc: str,
doc: dict,
num_fewshot: int,
system_instruction: Optional[str] = None,
apply_chat_template: bool = False,
fewshot_as_multiturn: bool = False,
chat_template: Optional[Callable] = None,
) -> str:
assistant_prefill: Optional[str] = None,
) -> Union[str, List[str]]:
"""Returns a fewshot context string that is made up of a prepended description
(if provided), the `num_fewshot` number of examples, and an appended prompt example.
......@@ -1048,7 +1054,6 @@ class ConfigurableTask(Task):
:returns: Union[str, List[str]]
The fewshot context, or a list of context strings when the example expands into one prompt per choice.
"""
if apply_chat_template:
labeled_examples = []
else:
......@@ -1082,19 +1087,28 @@ class ConfigurableTask(Task):
if apply_chat_template:
labeled_examples.extend(
self.sampler.get_chat_context(
doc, num_fewshot, fewshot_as_multiturn
doc,
num_fewshot,
fewshot_as_multiturn,
assistant_prefill=assistant_prefill,
)
)
else:
labeled_examples += self.sampler.get_context(doc, num_fewshot)
labeled_examples += self.sampler.get_context(
doc, num_fewshot, assistant_prefill=assistant_prefill
)
example = self.doc_to_text(doc)
if apply_chat_template:
if self.multiple_input:
# TODO: append prefill?
return chat_template(labeled_examples)
if isinstance(example, str):
self.append_target_question(
labeled_examples, example, fewshot_as_multiturn
labeled_examples,
example,
fewshot_as_multiturn,
assistant_prefill=assistant_prefill,
)
# for loglikelihood create a list of questions with appended choices
elif isinstance(example, list):
......@@ -1102,37 +1116,62 @@ class ConfigurableTask(Task):
# copy chat history for each example and append the answer
for ex in example:
chat = deepcopy(labeled_examples)
self.append_target_question(chat, ex, fewshot_as_multiturn)
labeled_examples_list.append(chat_template(chat))
self.append_target_question(
chat,
ex,
fewshot_as_multiturn,
assistant_prefill=assistant_prefill,
)
# TODO: append prefill?
labeled_examples_list.append(
chat_template(
chat,
add_generation_prompt=not assistant_prefill,
)
)
return labeled_examples_list
# if example is an integer, append the choice or convert to string
elif isinstance(example, int):
if self.config.doc_to_choice is not None:
choices = self.doc_to_choice(doc)
self.append_target_question(
labeled_examples, choices[example], fewshot_as_multiturn
labeled_examples,
choices[example],
fewshot_as_multiturn,
assistant_prefill=assistant_prefill,
)
else:
self.append_target_question(
labeled_examples, str(example), fewshot_as_multiturn
labeled_examples,
str(example),
fewshot_as_multiturn,
assistant_prefill=assistant_prefill,
)
# return lm.apply_chat_template(labeled_examples)
return chat_template(labeled_examples)
return chat_template(
labeled_examples,
add_generation_prompt=not assistant_prefill,
)
else:
prefix = (
self.config.target_delimiter + assistant_prefill
if assistant_prefill is not None
else ""
)
if self.multiple_input:
return labeled_examples
if isinstance(example, str):
return labeled_examples + example
return labeled_examples + example + prefix
elif isinstance(example, list):
return [labeled_examples + ex for ex in example]
return [labeled_examples + ex + prefix for ex in example]
elif isinstance(example, int):
if self.config.doc_to_choice is not None:
choices = self.doc_to_choice(doc)
return labeled_examples + choices[example]
return labeled_examples + choices[example] + prefix
else:
return labeled_examples + str(example)
return labeled_examples + str(example) + prefix
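A quick sketch of the non-chat branch above (values illustrative): the prefill, preceded by the target delimiter, lands at the very end of every returned prompt, including each choice-expanded variant:

```python
labeled_examples = "Question: ...? The best answer is A.\n\n"  # few-shot block
prefix = " " + "The best answer is"  # target_delimiter + assistant_prefill

single = labeled_examples + "Question: ...?" + prefix
per_choice = [labeled_examples + ex + prefix for ex in ["... A?", "... B?"]]
# Every variant now ends with "The best answer is", ready for completion.
```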
def apply_filters(self):
def apply_filters(self) -> Optional[List[Instance]]:
"""Iterates over FilterEnsembles and applies them to instances"""
if hasattr(self, "_filters"):
for f in self._filters:
......@@ -1144,7 +1183,7 @@ class ConfigurableTask(Task):
def should_decontaminate(self):
return self.config.should_decontaminate
def doc_to_decontamination_query(self, doc):
def doc_to_decontamination_query(self, doc: dict):
if self.config.should_decontaminate:
if self.config.doc_to_decontamination_query is None:
return self.doc_to_text(doc)
......
......@@ -8,12 +8,17 @@ from lm_eval.api.registry import register_filter
@register_filter("regex")
class RegexFilter(Filter):
""" """
"""A filter that extracts values from text using regex pattern matching.
This filter applies a regex pattern to each model response and extracts matched values.
If no match is found, returns a fallback value. Useful for extracting structured data
(like numbers) from unstructured model outputs.
"""
def __init__(
self,
regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
group_select=0,
group_select: int = 0,
fallback: str = "[invalid]",
) -> None:
"""
......@@ -25,7 +30,7 @@ class RegexFilter(Filter):
self.group_select = group_select
self.fallback = fallback
def apply(self, resps, docs):
def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
# here, we assume we have a list, in which each element is
# a list of model responses for some particular input/target pair.
# so we process each of these (same input/target response sets)
......@@ -55,12 +60,9 @@ class RegexFilter(Filter):
@register_filter("remove_whitespace")
class WhitespaceFilter(Filter):
""" """
def __init__(self) -> None:
pass
"""Filters out leading whitespace from responses."""
def apply(self, resps, docs):
def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
def filter_set(inst):
filtered_resp = []
for resp in inst:
......@@ -105,7 +107,7 @@ class MultiChoiceRegexFilter(RegexFilter):
self.ignore_punctuation = ignore_punctuation
self.regexes_to_ignore = regexes_to_ignore
def apply(self, resps, docs):
def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
# here, we assume we have a list, in which each element is
# a list of model responses for some particular input/target pair.
# so we process each of these (same input/target response sets)
......
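A short usage sketch for the regex filter (the module path and exact output are my assumptions about this repo's layout, not taken from the diff):

```python
from lm_eval.filters.extraction import RegexFilter  # assumed module path

# Capture the answer letter that these Llama-style prompts ask for.
f = RegexFilter(regex_pattern=r"The best answer is ([A-D])", fallback="[invalid]")

resps = [["The best answer is B."], ["I am not sure."]]  # one list per doc
docs = [{}, {}]  # unused by this filter
print(list(f.apply(resps, docs)))  # roughly [["B"], ["[invalid]"]]
```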
......@@ -253,12 +253,15 @@ class TemplateAPI(TemplateLM):
return ""
def apply_chat_template(
self, chat_history: List[Dict[str, str]]
self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
) -> Union[str, JsonChatStr]:
"""Applies a chat template to a list of chat history between user and model."""
if self.tokenizer_backend == "huggingface" and self.tokenized_requests:
return self.tokenizer.apply_chat_template(
chat_history, tokenize=False, add_generation_prompt=True
chat_history,
tokenize=False,
add_generation_prompt=add_generation_prompt,
continue_final_message=not add_generation_prompt,
)
else:
# bit of a hack. We'll load back before sending to the API
......
......@@ -200,7 +200,9 @@ class HFMultimodalLM(HFLM):
return context_enc, continuation_enc, image_enc
def apply_chat_template(self, chat_history: List[Dict[str, str]]) -> str:
def apply_chat_template(
self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
) -> str:
self.chat_applied = True
if not self.interleave:
for content in chat_history:
......@@ -250,7 +252,9 @@ class HFMultimodalLM(HFLM):
)
return self.processor.apply_chat_template(
chat_history, add_generation_prompt=True
chat_history,
add_generation_prompt=add_generation_prompt,
continue_final_message=not add_generation_prompt,
)
def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
......
......@@ -1382,13 +1382,18 @@ class HFLM(TemplateLM):
return res
def apply_chat_template(self, chat_history: List[Dict[str, str]]) -> str:
def apply_chat_template(
self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
) -> str:
"""
Method to apply a chat template to a chat history between user and model.
"""
try:
chat_templated = self.tokenizer.apply_chat_template(
chat_history, tokenize=False, add_generation_prompt=True
chat_history,
tokenize=False,
add_generation_prompt=add_generation_prompt,
continue_final_message=not add_generation_prompt,
)
except jinja2.exceptions.TemplateError:
eval_logger.warning(
......@@ -1396,7 +1401,10 @@ class HFLM(TemplateLM):
)
chat_history = [msg for msg in chat_history if msg["role"] != "system"]
chat_templated = self.tokenizer.apply_chat_template(
chat_history, tokenize=False, add_generation_prompt=True
chat_history,
tokenize=False,
add_generation_prompt=add_generation_prompt,
continue_final_message=not add_generation_prompt,
)
return chat_templated
......
......@@ -184,14 +184,21 @@ class VLLM(TemplateLM):
def max_gen_toks(self):
return self._max_gen_toks
def apply_chat_template(self, chat_history: List[Dict[str, str]]) -> str:
def apply_chat_template(
self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
) -> str:
"""
Method to apply a chat template to a chat history between user and model.
"""
return self.tokenizer.apply_chat_template(
chat_history, tokenize=False, add_generation_prompt=True
chat_templated = self.tokenizer.apply_chat_template(
chat_history,
tokenize=False,
add_generation_prompt=add_generation_prompt,
continue_final_message=not add_generation_prompt,
)
return chat_templated
@property
def tokenizer_name(self) -> str:
return self.tokenizer.name_or_path.replace("/", "__")
......
......@@ -144,7 +144,9 @@ class VLLM_VLM(VLLM):
)
return outputs
def apply_chat_template(self, chat_history: List[Dict[str, str]]) -> str:
def apply_chat_template(
self, chat_history: List[Dict[str, str]], add_generation_prompt=True
) -> str:
self.chat_applied = True
if not self.interleave:
for content in chat_history:
......@@ -194,7 +196,9 @@ class VLLM_VLM(VLLM):
)
return self.processor.apply_chat_template(
chat_history, add_generation_prompt=True
chat_history,
add_generation_prompt=add_generation_prompt,
continue_final_message=not add_generation_prompt,
)
def generate_until(
......
tag:
- llama
task: arc_challenge_chat
dataset_path: allenai/ai2_arc
dataset_name: ARC-Challenge
output_type: generate_until
training_split: train
validation_split: validation
test_split: test
fewshot_split: train
doc_to_text: 'Given the following question and four candidate answers (A, B, C and D), choose the best answer.\nQuestion: {{question.strip()}}\nA. {{choices.text[0]}}\nB. {{choices.text[1]}}\nC. {{choices.text[2]}}{% if choices.text|length > 3 %}\nD. {{choices.text[3]}}{% endif %}\nYour response should end with "The best answer is [the_answer_letter]" where the [the_answer_letter] is one of A, B, C or D.'
assistant_prefill: 'The best answer is'
fewshot_delimiter: "\n\n"
doc_to_target: "{{ 'ABCD'[answerKey|int - 1] if answerKey|string in '1234' else answerKey }}"
num_fewshot: 0
generation_kwargs:
max_gen_toks: 100
until:
- "\n\n"
- "."
metric_list:
- metric: exact_match
aggregation: mean
higher_is_better: true
ignore_case: true
ignore_punctuation: true
filter_list:
- name: remove_whitespace
filter:
- function: remove_whitespace
- function: take_first
metadata:
version: 1.0
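Putting this config together, the conversation the model sees for a zero-shot example looks roughly like this (a hand-rendered sketch, not actual harness output):

```python
# Hand-rendered sketch of one arc_challenge_chat prompt under a chat template.
chat = [
    {
        "role": "user",
        "content": (
            "Given the following question and four candidate answers (A, B, C and D), "
            "choose the best answer.\nQuestion: ...\nA. ...\nB. ...\nC. ...\nD. ...\n"
            'Your response should end with "The best answer is [the_answer_letter]" '
            "where the [the_answer_letter] is one of A, B, C or D."
        ),
    },
    {"role": "assistant", "content": "The best answer is"},  # assistant_prefill
]
```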
# Llama Evals
### Paper
Title: Llama Evals
Abstract: Evals reproducing those provided by the Llama team in the Hugging Face repo.
Homepage: `https://huggingface.co/collections/meta-llama/llama-31-evals-66a2c5a14c2093e58298ac7f`
Note: The tasks are formatted to be run with `apply_chat_template` and `fewshot_as_multiturn`.
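A minimal sketch of invoking them through the Python API, assuming the harness's `simple_evaluate` entry point (the checkpoint is illustrative):

```python
import lm_eval

results = lm_eval.simple_evaluate(
    model="hf",
    model_args="pretrained=meta-llama/Llama-3.1-8B-Instruct",  # illustrative
    tasks=["mmlu_llama", "arc_challenge_chat"],
    apply_chat_template=True,   # render prompts through the chat template
    fewshot_as_multiturn=True,  # few-shot examples as alternating turns
)
```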
### Citation
```
BibTeX-formatted citation goes here
```
### Groups, Tags, and Tasks
#### Groups
* `mmlu_llama`: `Aggregate group for the mmlu_llama subject groups (stem, humanities, social sciences, other)`
#### Tags
* `llama`: `Tasks formatted to match the Llama team's published evaluation setup`
#### Tasks
* `mmlu_llama`: `generation variant of MMLU`
* `arc_challenge_chat`: `generation variant of ARC-Challenge using the MMLU format`
### Checklist
For adding novel benchmarks/datasets to the library:
* [ ] Is the task an existing benchmark in the literature?
* [ ] Have you referenced the original paper that introduced the task?
* [ ] If yes, does the original paper provide a reference implementation? If so, have you checked against the reference implementation and documented how to run such a test?
If other tasks on this dataset are already supported:
* [ ] Is the "Main" variant of this task clearly denoted?
* [ ] Have you provided a short sentence in a README on what each new variant adds / evaluates?
* [ ] Have you noted which, if any, published evaluation setups are matched by this variant?
dataset_path: hails/mmlu_no_train # a copy of `cais/mmlu` with no auxiliary_train split
output_type: generate_until
test_split: test
fewshot_split: dev
fewshot_config:
sampler: first_n
doc_to_text: "Given the following question and four candidate answers (A, B, C and D), choose the best answer.\nQuestion: {{question.strip()}}\nA. {{choices[0]}}\nB. {{choices[1]}}\nC. {{choices[2]}}\nD. {{choices[3]}}\nYour response should end with \"The best answer is [the_answer_letter]\" where the [the_answer_letter] is one of A, B, C or D."
assistant_prefill: "The best answer is"
doc_to_target: "{{['A.','B.','C.','D.'][answer]}}"
num_fewshot: 5
metric_list:
- metric: exact_match
aggregation: mean
higher_is_better: true
ignore_case: true
ignore_punctuation: true
regexes_to_ignore:
- "\\$"
- "\\.$"
generation_kwargs:
until:
- "."
max_gen_toks: 10
filter_list:
- name: strict_match
filter:
- function: remove_whitespace
- function: take_first
metadata:
version: 1.0
dataset_kwargs:
trust_remote_code: true
group: mmlu_llama_humanities
group_alias: humanities
task:
- mmlu_llama_humanities_tasks
aggregate_metric_list:
- metric: exact_match
aggregation: mean
weight_by_size: True
filter_list: [strict_match]
metadata:
version: 1
group: mmlu_llama_other
group_alias: other
task:
- mmlu_llama_other_tasks
aggregate_metric_list:
- metric: exact_match
aggregation: mean
weight_by_size: True
filter_list: [strict_match]
metadata:
version: 1
group: mmlu_llama_social_sciences
group_alias: social sciences
task:
- mmlu_llama_social_sciences_tasks
aggregate_metric_list:
- metric: exact_match
aggregation: mean
weight_by_size: True
filter_list: [strict_match]
metadata:
version: 1
group: mmlu_llama_stem
group_alias: stem
task:
- mmlu_llama_stem_tasks
aggregate_metric_list:
- metric: exact_match
aggregation: mean
weight_by_size: True
filter_list: [strict_match]
metadata:
version: 0
group: mmlu_llama
task:
- mmlu_llama_stem
- mmlu_llama_other
- mmlu_llama_social_sciences
- mmlu_llama_humanities
aggregate_metric_list:
- metric: exact_match
aggregation: mean
weight_by_size: True
filter_list: [strict_match]
metadata:
version: 1
"dataset_name": "abstract_algebra"
"include": "_continuation_template_yaml"
"tag": "mmlu_llama_stem_tasks"
"task": "mmlu_llama_abstract_algebra"
"task_alias": "abstract algebra"
"dataset_name": "anatomy"
"include": "_continuation_template_yaml"
"tag": "mmlu_llama_stem_tasks"
"task": "mmlu_llama_anatomy"
"task_alias": "anatomy"