Unverified Commit 703fbffd authored by Baber Abbasi, committed by GitHub

assistant prefill (#2615)

* add assistant prefix

* add arc_challenge from llama

* nit

* nit

* nit

* add assistant prefix

* add mmlu_llama

* nit

* nit

* Revert "nit"

This reverts commit 6a97f8356237305e375212b966b30e8de59dd4bc.

* fix regex bug

* add assistant_prefix to vllm

* add `Question:`

* add mmlu_pro

* add fewshot assistant_prefix

* use `assistant_prefill`

* typehints

* nits

* nits

* add to docs

* add readme
parent e86cece6
@@ -37,6 +37,7 @@ Prompting / in-context formatting options:
 - **doc_to_choice** (`Union[Callable, str]`, *optional*) — Jinja2 template, string, or function to process a sample into a list of possible string choices for `multiple_choice` tasks. Left undefined for `generate_until` tasks.
 - **fewshot_delimiter** (`str`, *optional*, defaults to "\n\n") — String to insert between few-shot examples.
 - **target_delimiter** (`str`, *optional*, defaults to `" "`) — String to insert between input and target output for the datapoint being tested.
+- **assistant_prefill** (`str`, *optional*) — String used to prefill the assistant's reply, appended after the `<|assistant|>` token. For example, for a question-answering task, `assistant_prefill` could be `"The answer is: "` so that the model continues directly with an answer. If no chat template is used, the string is instead appended to the end of the prompt (see the example below).

 Runtime configuration options:
 - **num_fewshot** (`int`, *optional*, defaults to 0) — Number of few-shot examples before the input.
......
@@ -113,13 +113,17 @@ class LM(abc.ABC):
         """
         pass

-    def apply_chat_template(self, chat_history: List[Dict[str, str]]) -> str:
+    def apply_chat_template(
+        self, chat_history: List[Dict[str, str]], add_generation_prompt=True
+    ) -> str:
         """
         Defines how to transform few-shot examples provided as chat history into a format that can be used as input to the LM.

         :param chat_history: list[dict[str, str]]
             A list of dictionaries with keys 'role' and 'content'.
             Values are strings representing the role name and the content of the message, respectively.
+        :param add_generation_prompt: bool
+            Whether to append the assistant's generation prefix (e.g. <|assistant|>) to the chat history. Set to False when prefilling an assistant message.
         :return: str
             A string representing the chat history in a format that can be used as input to the LM.
         """
......
 from functools import partial
+from typing import TYPE_CHECKING, Iterable, Optional, Union

 import datasets

+if TYPE_CHECKING:
+    from random import Random
+
+    from lm_eval.api.task import ConfigurableTask, Task
+

 class ContextSampler:
-    def __init__(self, docs, task, fewshot_indices=None, rnd=None) -> None:
+    def __init__(
+        self,
+        docs: list[dict],
+        task: Union["Task", "ConfigurableTask"],
+        fewshot_indices: Optional[Iterable] = None,
+        rnd: Optional["Random"] = None,
+    ) -> None:
         self.rnd = rnd
         if not self.rnd:
             raise ValueError(
@@ -58,8 +71,9 @@ class ContextSampler:
             )
             self.docs = self.docs.select(fewshot_indices)

-    def get_context(self, doc, num_fewshot):
+    def get_context(self, doc: dict, num_fewshot: int, assistant_prefill: str = None):
         # draw an extra fewshot sample if using same split as evaluating on
+        prefix = assistant_prefill + " " if assistant_prefill else ""
         n_samples = (
             num_fewshot + 1
             if self.config.fewshot_split == self.config.test_split
@@ -77,14 +91,14 @@ class ContextSampler:
         for doc in selected_docs:
             doc_content = self.doc_to_text(doc)
             doc_target = self.doc_to_target(doc)
-            labeled_examples += (
-                doc_content
-                if self.config.doc_to_choice is None or isinstance(doc_content, str)
-                else self.doc_to_choice(doc)[doc_content]
-            )
+            if self.config.doc_to_choice is None or isinstance(doc_content, str):
+                labeled_examples += doc_content
+            else:
+                labeled_examples += self.doc_to_choice(doc)[doc_content]
             if doc_target != "":
                 labeled_examples += self.target_delimiter
+                labeled_examples += prefix
                 labeled_examples += (
                     str(doc_target[0])
                     if isinstance(doc_target, list)
@@ -98,10 +112,13 @@ class ContextSampler:
     def get_chat_context(
         self,
-        doc,
-        num_fewshot,
+        doc: dict,
+        num_fewshot: int,
         fewshot_as_multiturn: bool = False,
+        assistant_prefill: Optional[str] = None,
     ):
+        # TODO: Do we need any other delimiter
+        prefix = assistant_prefill + " " if assistant_prefill else ""
         chat_history = []
         # draw an extra fewshot sample if using same split as evaluating on
         n_samples = (
@@ -132,23 +149,28 @@ class ContextSampler:
                 chat_history.append(
                     {
                         "role": "assistant",
-                        "content": str(doc_target[0])
+                        "content": prefix + str(doc_target[0])
                         if isinstance(doc_target, list)
-                        else doc_target
+                        else prefix + doc_target
                         if self.config.doc_to_choice is None
                         or isinstance(doc_target, str)
-                        else str(self.doc_to_choice(doc)[doc_target]),
+                        else prefix + str(self.doc_to_choice(doc)[doc_target]),
                     }
                 )
         else:
             # get fewshot context as one user turn
             chat_history.append(
-                {"role": "user", "content": self.get_context(doc, num_fewshot)}
+                {
+                    "role": "user",
+                    "content": self.get_context(
+                        doc, num_fewshot, assistant_prefill=assistant_prefill
+                    ),
+                }
             )

         return chat_history

-    def sample(self, n):
+    def sample(self, n: int):
         """
         Draw `n` samples from our fewshot docs. This method should be overridden by subclasses.
         """
@@ -157,7 +179,7 @@ class ContextSampler:

 class FirstNSampler(ContextSampler):
-    def sample(self, n) -> None:
+    def sample(self, n: int) -> None:
         """
         Draw the first `n` samples in order from the specified split.
         Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU.
@@ -169,7 +191,7 @@ class FirstNSampler(ContextSampler):

 class BalancedSampler(ContextSampler):
-    def sample(self, n) -> None:
+    def sample(self, n: int) -> None:
         """
         TODO: this should return approximately class-balanced samples from our fewshot examples.
         TODO: what order should they be in? maybe random?
@@ -179,7 +201,7 @@ class BalancedSampler(ContextSampler):

 class ManualSampler(ContextSampler):
-    def sample(self, n) -> None:
+    def sample(self, n: int) -> None:
         """ """
         pass
@@ -190,7 +212,7 @@ SAMPLER_REGISTRY = {
 }

-def get_sampler(name):
+def get_sampler(name: str):
     try:
         return SAMPLER_REGISTRY[name]
     except KeyError:
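For the plain-string path, the net effect is that each few-shot target is now preceded by the prefill plus a space. A small illustration with hypothetical doc contents:

```python
# Hypothetical doc_to_text / doc_to_target outputs for one few-shot example.
doc_content = "Question: Which gas do plants absorb?\nA. O2\nB. CO2"
doc_target = "B"
target_delimiter = " "
prefix = "The best answer is" + " "  # assistant_prefill + " "

labeled_example = doc_content + target_delimiter + prefix + doc_target
# -> "Question: Which gas do plants absorb?\nA. O2\nB. CO2 The best answer is B"
```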
......
@@ -93,6 +93,7 @@ class TaskConfig(dict):
     filter_list: Optional[Union[str, list]] = None
     should_decontaminate: bool = False
     doc_to_decontamination_query: Optional[str] = None
+    assistant_prefill: Optional[str] = None
     metadata: Optional[dict] = (
         None  # by default, not used in the code. allows for users to pass arbitrary info to tasks
     )
@@ -443,6 +444,7 @@ class Task(abc.ABC):
                 apply_chat_template,
                 fewshot_as_multiturn,
                 chat_template,
+                assistant_prefill=self.config.assistant_prefill,
             )
         # TODO: we should override self.config.repeats if doing greedy gen so users don't waste time+compute
@@ -1004,6 +1006,7 @@ class ConfigurableTask(Task):
         labeled_examples: List[Dict[str, str]],
         question: str,
         fewshot_as_multiturn: bool = False,
+        assistant_prefill: Optional[str] = None,
     ) -> None:
         """Adds a target question to the labeled examples list.
         If fewshot_as_multiturn is True, or labeled_examples is empty, or the last entry is a system turn, appends the question as a new user entry.
@@ -1019,17 +1022,20 @@ class ConfigurableTask(Task):
         else:
             # if fewshot_as_multiturn is True, append as next user entry (last is always assistant)
             labeled_examples.append({"role": "user", "content": question})
+        if assistant_prefill:
+            labeled_examples.append({"role": "assistant", "content": assistant_prefill})

     @utils.positional_deprecated
     def fewshot_context(
         self,
-        doc: str,
+        doc: dict,
         num_fewshot: int,
         system_instruction: Optional[str] = None,
         apply_chat_template: bool = False,
         fewshot_as_multiturn: bool = False,
         chat_template: Optional[Callable] = None,
-    ) -> str:
+        assistant_prefill: Optional[str] = None,
+    ) -> Union[str, List[str]]:
         """Returns a fewshot context string that is made up of a prepended description
         (if provided), the `num_fewshot` number of examples, and an appended prompt example.
@@ -1048,7 +1054,6 @@ class ConfigurableTask(Task):
         :returns: str
             The fewshot context.
         """
-
         if apply_chat_template:
             labeled_examples = []
         else:
@@ -1082,19 +1087,28 @@ class ConfigurableTask(Task):
             if apply_chat_template:
                 labeled_examples.extend(
                     self.sampler.get_chat_context(
-                        doc, num_fewshot, fewshot_as_multiturn
+                        doc,
+                        num_fewshot,
+                        fewshot_as_multiturn,
+                        assistant_prefill=assistant_prefill,
                     )
                 )
             else:
-                labeled_examples += self.sampler.get_context(doc, num_fewshot)
+                labeled_examples += self.sampler.get_context(
+                    doc, num_fewshot, assistant_prefill=assistant_prefill
+                )

         example = self.doc_to_text(doc)
         if apply_chat_template:
             if self.multiple_input:
+                # TODO: append prefill?
                 return chat_template(labeled_examples)
             if isinstance(example, str):
                 self.append_target_question(
-                    labeled_examples, example, fewshot_as_multiturn
+                    labeled_examples,
+                    example,
+                    fewshot_as_multiturn,
+                    assistant_prefill=assistant_prefill,
                 )
             # for loglikelihood create a list of questions with appended choices
             elif isinstance(example, list):
@@ -1102,37 +1116,62 @@ class ConfigurableTask(Task):
                 # copy chat history for each example and append the answer
                 for ex in example:
                     chat = deepcopy(labeled_examples)
-                    self.append_target_question(chat, ex, fewshot_as_multiturn)
-                    labeled_examples_list.append(chat_template(chat))
+                    self.append_target_question(
+                        chat,
+                        ex,
+                        fewshot_as_multiturn,
+                        assistant_prefill=assistant_prefill,
+                    )
+                    # TODO: append prefill?
+                    labeled_examples_list.append(
+                        chat_template(
+                            chat,
+                            add_generation_prompt=False if assistant_prefill else True,
+                        )
+                    )
                 return labeled_examples_list
             # if example is an integer, append the choice or convert to string
             elif isinstance(example, int):
                 if self.config.doc_to_choice is not None:
                     choices = self.doc_to_choice(doc)
                     self.append_target_question(
-                        labeled_examples, choices[example], fewshot_as_multiturn
+                        labeled_examples,
+                        choices[example],
+                        fewshot_as_multiturn,
+                        assistant_prefill=assistant_prefill,
                     )
                 else:
                     self.append_target_question(
-                        labeled_examples, str(example), fewshot_as_multiturn
+                        labeled_examples,
+                        str(example),
+                        fewshot_as_multiturn,
+                        assistant_prefill=assistant_prefill,
                     )
             # return lm.apply_chat_template(labeled_examples)
-            return chat_template(labeled_examples)
+            return chat_template(
+                labeled_examples,
+                add_generation_prompt=False if assistant_prefill else True,
+            )
         else:
+            prefix = (
+                self.config.target_delimiter + assistant_prefill
+                if assistant_prefill is not None
+                else ""
+            )
             if self.multiple_input:
                 return labeled_examples
             if isinstance(example, str):
-                return labeled_examples + example
+                return labeled_examples + example + prefix
             elif isinstance(example, list):
-                return [labeled_examples + ex for ex in example]
+                return [labeled_examples + ex + prefix for ex in example]
             elif isinstance(example, int):
                 if self.config.doc_to_choice is not None:
                     choices = self.doc_to_choice(doc)
-                    return labeled_examples + choices[example]
+                    return labeled_examples + choices[example] + prefix
                 else:
-                    return labeled_examples + str(example)
+                    return labeled_examples + str(example) + prefix

-    def apply_filters(self):
+    def apply_filters(self) -> Optional[List[Instance]]:
         """Iterates over FilterEnsembles and applies them to instances"""
         if hasattr(self, "_filters"):
             for f in self._filters:
@@ -1144,7 +1183,7 @@ class ConfigurableTask(Task):
     def should_decontaminate(self):
         return self.config.should_decontaminate

-    def doc_to_decontamination_query(self, doc):
+    def doc_to_decontamination_query(self, doc: dict):
         if self.config.should_decontaminate:
             if self.config.doc_to_decontamination_query is None:
                 return self.doc_to_text(doc)
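Taken together, when `assistant_prefill` is set the chat path ends the history with a partially written assistant turn and renders it without a generation prompt. A schematic of the final chat history (message contents hypothetical):

```python
chat_history = [
    {"role": "user", "content": "Question: ...\nA. ...\nB. ..."},  # few-shot question
    {"role": "assistant", "content": "The best answer is A"},      # prefilled few-shot target
    {"role": "user", "content": "Question: ...\nA. ...\nB. ..."},  # question under evaluation
    {"role": "assistant", "content": "The best answer is"},        # open prefill to be continued
]
rendered = chat_template(chat_history, add_generation_prompt=False)
```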
......
@@ -8,12 +8,17 @@ from lm_eval.api.registry import register_filter

 @register_filter("regex")
 class RegexFilter(Filter):
-    """ """
+    """A filter that extracts values from text using regex pattern matching.
+
+    This filter applies a regex pattern to each model response and extracts matched values.
+    If no match is found, returns a fallback value. Useful for extracting structured data
+    (like numbers) from unstructured model outputs.
+    """

     def __init__(
         self,
         regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
-        group_select=0,
+        group_select: int = 0,
         fallback: str = "[invalid]",
     ) -> None:
         """
@@ -25,7 +30,7 @@ class RegexFilter(Filter):
         self.group_select = group_select
         self.fallback = fallback

-    def apply(self, resps, docs):
+    def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
         # here, we assume we have a list, in which each element is
         # a list of model responses for some particular input/target pair.
         # so we process each of these (same input/target response sets)
@@ -55,12 +60,9 @@ class RegexFilter(Filter):

 @register_filter("remove_whitespace")
 class WhitespaceFilter(Filter):
-    """ """
-
-    def __init__(self) -> None:
-        pass
+    """Filters out leading whitespace from responses."""

-    def apply(self, resps, docs):
+    def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
         def filter_set(inst):
             filtered_resp = []
             for resp in inst:
@@ -105,7 +107,7 @@ class MultiChoiceRegexFilter(RegexFilter):
         self.ignore_punctuation = ignore_punctuation
         self.regexes_to_ignore = regexes_to_ignore

-    def apply(self, resps, docs):
+    def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
         # here, we assume we have a list, in which each element is
         # a list of model responses for some particular input/target pair.
         # so we process each of these (same input/target response sets)
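For reference, the default `regex_pattern` targets GSM8K-style final answers. A standalone sketch of the matching behavior (plain `re`, outside the filter class):

```python
import re

pattern = re.compile(r"#### (\-?[0-9\.\,]+)")  # RegexFilter default pattern
resp = "Adding the tens gives 40, plus 2 is 42.\n#### 42"

matches = pattern.findall(resp)
answer = matches[0] if matches else "[invalid]"  # `fallback` when nothing matches
assert answer == "42"
```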
......
@@ -253,12 +253,15 @@ class TemplateAPI(TemplateLM):
         return ""

     def apply_chat_template(
-        self, chat_history: List[Dict[str, str]]
+        self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
     ) -> Union[str, JsonChatStr]:
         """Applies a chat template to a list of chat history between user and model."""
         if self.tokenizer_backend == "huggingface" and self.tokenized_requests:
             return self.tokenizer.apply_chat_template(
-                chat_history, tokenize=False, add_generation_prompt=True
+                chat_history,
+                tokenize=False,
+                add_generation_prompt=add_generation_prompt,
+                continue_final_message=not add_generation_prompt,
             )
         else:
             # bit of a hack. We'll load back before sending to the API
......
@@ -200,7 +200,9 @@ class HFMultimodalLM(HFLM):
         return context_enc, continuation_enc, image_enc

-    def apply_chat_template(self, chat_history: List[Dict[str, str]]) -> str:
+    def apply_chat_template(
+        self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
+    ) -> str:
         self.chat_applied = True
         if not self.interleave:
             for content in chat_history:
@@ -250,7 +252,9 @@ class HFMultimodalLM(HFLM):
         )

         return self.processor.apply_chat_template(
-            chat_history, add_generation_prompt=True
+            chat_history,
+            add_generation_prompt=add_generation_prompt,
+            continue_final_message=not add_generation_prompt,
         )

     def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
......
@@ -1382,13 +1382,18 @@ class HFLM(TemplateLM):

         return res

-    def apply_chat_template(self, chat_history: List[Dict[str, str]]) -> str:
+    def apply_chat_template(
+        self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
+    ) -> str:
         """
         Method to apply a chat template to a list of chat history between user and model.
         """
         try:
             chat_templated = self.tokenizer.apply_chat_template(
-                chat_history, tokenize=False, add_generation_prompt=True
+                chat_history,
+                tokenize=False,
+                add_generation_prompt=add_generation_prompt,
+                continue_final_message=not add_generation_prompt,
             )
         except jinja2.exceptions.TemplateError:
             eval_logger.warning(
@@ -1396,7 +1401,10 @@ class HFLM(TemplateLM):
             )
             chat_history = [msg for msg in chat_history if msg["role"] != "system"]
             chat_templated = self.tokenizer.apply_chat_template(
-                chat_history, tokenize=False, add_generation_prompt=True
+                chat_history,
+                tokenize=False,
+                add_generation_prompt=add_generation_prompt,
+                continue_final_message=not add_generation_prompt,
             )

         return chat_templated
......
@@ -184,14 +184,21 @@ class VLLM(TemplateLM):
     def max_gen_toks(self):
         return self._max_gen_toks

-    def apply_chat_template(self, chat_history: List[Dict[str, str]]) -> str:
+    def apply_chat_template(
+        self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
+    ) -> str:
         """
         Method to apply a chat template to a list of chat history between user and model.
         """
-        return self.tokenizer.apply_chat_template(
-            chat_history, tokenize=False, add_generation_prompt=True
+        chat_templated = self.tokenizer.apply_chat_template(
+            chat_history,
+            tokenize=False,
+            add_generation_prompt=add_generation_prompt,
+            continue_final_message=not add_generation_prompt,
         )
+        return chat_templated

     @property
     def tokenizer_name(self) -> str:
         return self.tokenizer.name_or_path.replace("/", "__")
......
@@ -144,7 +144,9 @@ class VLLM_VLM(VLLM):
         )
         return outputs

-    def apply_chat_template(self, chat_history: List[Dict[str, str]]) -> str:
+    def apply_chat_template(
+        self, chat_history: List[Dict[str, str]], add_generation_prompt=True
+    ) -> str:
         self.chat_applied = True
         if not self.interleave:
             for content in chat_history:
@@ -194,7 +196,9 @@ class VLLM_VLM(VLLM):
         )

         return self.processor.apply_chat_template(
-            chat_history, add_generation_prompt=True
+            chat_history,
+            add_generation_prompt=add_generation_prompt,
+            continue_final_message=not add_generation_prompt,
         )

     def generate_until(
......
tag:
- llama
task: arc_challenge_chat
dataset_path: allenai/ai2_arc
dataset_name: ARC-Challenge
output_type: generate_until
training_split: train
validation_split: validation
test_split: test
fewshot_split: train
doc_to_text: 'Given the following question and four candidate answers (A, B, C and D), choose the best answer.\nQuestion: {{question.strip()}}\nA. {{choices.text[0]}}\nB. {{choices.text[1]}}\nC. {{choices.text[2]}}{% if choices.text|length > 3 %}\nD. {{choices.text[3]}}{% endif %}\nYour response should end with "The best answer is [the_answer_letter]" where the [the_answer_letter] is one of A, B, C or D.'
assistant_prefill: 'The best answer is'
fewshot_delimiter: "\n\n"
doc_to_target: "{{ 'ABCD'[answerKey|int - 1] if answerKey|string in '1234' else answerKey }}"
num_fewshot: 0
generation_kwargs:
max_gen_toks: 100
until:
- "\n\n"
- "."
metric_list:
- metric: exact_match
aggregation: mean
higher_is_better: true
ignore_case: true
ignore_punctuation: true
filter_list:
- name: remove_whitespace
filter:
- function: remove_whitespace
- function: take_first
metadata:
version: 1.0
# LLAMA Evals
### Paper
Title: LLAMA Evals
Abstract: Evals reproducing those provided by the LLAMA team in the Hugging Face repo.
Homepage: `https://huggingface.co/collections/meta-llama/llama-31-evals-66a2c5a14c2093e58298ac7f`
Note: The tasks are formatted to be run with apply_chat_template and fewshot_as_multiturn.
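For example, an invocation could look like this (model and model arguments are illustrative):

```
lm_eval --model hf \
    --model_args pretrained=meta-llama/Llama-3.1-8B-Instruct \
    --tasks mmlu_llama \
    --apply_chat_template \
    --fewshot_as_multiturn
```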
### Citation
```
BibTeX-formatted citation goes here
```
### Groups, Tags, and Tasks
#### Groups
* `mmlu_llama`: `generation variant of MMLU, aggregating the mmlu_llama_stem, mmlu_llama_other, mmlu_llama_social_sciences, and mmlu_llama_humanities subgroups`
#### Tags
* `llama`: `tasks reproducing the Llama team's evaluation setup`
#### Tasks
* `mmlu_llama`: `generation variant of MMLU`
* `arc_challenge_chat`: `generation variant of ARC-Challenge using MMLU format`
### Checklist
For adding novel benchmarks/datasets to the library:
* [ ] Is the task an existing benchmark in the literature?
* [ ] Have you referenced the original paper that introduced the task?
* [ ] If yes, does the original paper provide a reference implementation? If so, have you checked against the reference implementation and documented how to run such a test?
If other tasks on this dataset are already supported:
* [ ] Is the "Main" variant of this task clearly denoted?
* [ ] Have you provided a short sentence in a README on what each new variant adds / evaluates?
* [ ] Have you noted which, if any, published evaluation setups are matched by this variant?
dataset_path: hails/mmlu_no_train # a copy of `cais/mmlu` with no auxiliary_train split
output_type: generate_until
test_split: test
fewshot_split: dev
fewshot_config:
sampler: first_n
doc_to_text: "Given the following question and four candidate answers (A, B, C and D), choose the best answer.\nQuestion: {{question.strip()}}\nA. {{choices[0]}}\nB. {{choices[1]}}\nC. {{choices[2]}}\nD. {{choices[3]}}\nYour response should end with \"The best answer is [the_answer_letter]\" where the [the_answer_letter] is one of A, B, C or D."
assistant_prefill: "The best answer is"
doc_to_target: "{{['A.','B.','C.','D.'][answer]}}"
num_fewshot: 5
metric_list:
- metric: exact_match
aggregation: mean
higher_is_better: true
ignore_case: true
ignore_punctuation: true
regexes_to_ignore:
- "\\$"
- "\\.$"
generation_kwargs:
until:
- "."
max_gen_toks: 10
filter_list:
- name: strict_match
filter:
- function: remove_whitespace
- function: take_first
metadata:
version: 1.0
dataset_kwargs:
trust_remote_code: true
group: mmlu_llama_humanities
group_alias: humanities
task:
- mmlu_llama_humanities_tasks
aggregate_metric_list:
- metric: exact_match
aggregation: mean
weight_by_size: True
filter_list: [strict_match]
metadata:
version: 1
group: mmlu_llama_other
group_alias: other
task:
- mmlu_llama_other_tasks
aggregate_metric_list:
- metric: exact_match
aggregation: mean
weight_by_size: True
filter_list: [strict_match]
metadata:
version: 1
group: mmlu_llama_social_sciences
group_alias: social sciences
task:
- mmlu_llama_social_sciences_tasks
aggregate_metric_list:
- metric: exact_match
aggregation: mean
weight_by_size: True
filter_list: [strict_match]
metadata:
version: 1
group: mmlu_llama_stem
group_alias: stem
task:
- mmlu_llama_stem_tasks
aggregate_metric_list:
- metric: exact_match
aggregation: mean
weight_by_size: True
filter_list: [strict_match]
metadata:
version: 0
group: mmlu_llama
task:
- mmlu_llama_stem
- mmlu_llama_other
- mmlu_llama_social_sciences
- mmlu_llama_humanities
aggregate_metric_list:
- metric: exact_match
aggregation: mean
weight_by_size: True
filter_list: [strict_match]
metadata:
version: 1
"dataset_name": "abstract_algebra"
"include": "_continuation_template_yaml"
"tag": "mmlu_llama_stem_tasks"
"task": "mmlu_llama_abstract_algebra"
"task_alias": "abstract algebra"
"dataset_name": "anatomy"
"include": "_continuation_template_yaml"
"tag": "mmlu_llama_stem_tasks"
"task": "mmlu_llama_anatomy"
"task_alias": "anatomy"