Unverified Commit 9dda03d6 authored by Baber Abbasi's avatar Baber Abbasi Committed by GitHub
Browse files

fix gen_prefix (#2630)

* switch arg
parent 703fbffd
......@@ -71,9 +71,9 @@ class ContextSampler:
)
self.docs = self.docs.select(fewshot_indices)
def get_context(self, doc: dict, num_fewshot: int, assistant_prefill: str = None):
def get_context(self, doc: dict, num_fewshot: int, gen_prefix: str = None):
# draw an extra fewshot sample if using same split as evaluating on
prefix = assistant_prefill + " " if assistant_prefill else ""
prefix = gen_prefix + " " if gen_prefix else ""
n_samples = (
num_fewshot + 1
if self.config.fewshot_split == self.config.test_split
......@@ -115,10 +115,10 @@ class ContextSampler:
doc: dict,
num_fewshot: int,
fewshot_as_multiturn: bool = False,
assistant_prefill: Optional[str] = None,
gen_prefix: Optional[str] = None,
):
# TODO: Do we need any other delimiter
prefix = assistant_prefill + " " if assistant_prefill else ""
prefix = gen_prefix + " " if gen_prefix else ""
chat_history = []
# draw an extra fewshot sample if using same split as evaluating on
n_samples = (
......@@ -163,7 +163,7 @@ class ContextSampler:
{
"role": "user",
"content": self.get_context(
doc, num_fewshot, assistant_prefill=assistant_prefill
doc, num_fewshot, gen_prefix=gen_prefix
),
}
)
......
......@@ -93,7 +93,7 @@ class TaskConfig(dict):
filter_list: Optional[Union[str, list]] = None
should_decontaminate: bool = False
doc_to_decontamination_query: Optional[str] = None
assistant_prefill: Optional[str] = None
gen_prefix: Optional[str] = None
metadata: Optional[dict] = (
None # by default, not used in the code. allows for users to pass arbitrary info to tasks
)
......@@ -371,6 +371,9 @@ class Task(abc.ABC):
def doc_to_image(self, doc):
raise NotImplementedError
def doc_to_prefix(self, doc):
    """Return the generation prefix to prepend before the model's answer.

    The base task defines no prefix, so this yields the empty string;
    subclasses override it to derive a prefix from *doc* or task config.
    # NOTE(review): ConfigurableTask's override returns None when unset;
    # callers mostly test truthiness so "" and None behave alike — confirm
    # the one `is not None` comparison downstream is never reached here.
    """
    return ""
def build_all_requests(
self,
*,
......@@ -444,7 +447,7 @@ class Task(abc.ABC):
apply_chat_template,
fewshot_as_multiturn,
chat_template,
assistant_prefill=self.config.assistant_prefill,
gen_prefix=self.doc_to_prefix(doc),
)
# TODO: we should override self.config.repeats if doing greedy gen so users don't waste time+compute
......@@ -544,13 +547,7 @@ class Task(abc.ABC):
return len(re.split(r"\s+", doc))
@utils.positional_deprecated
def fewshot_context(
self,
doc,
num_fewshot,
rnd=None,
description=None,
):
def fewshot_context(self, doc, num_fewshot, rnd=None, description=None, **kwargs):
"""Returns a fewshot context string that is made up of a prepended description
(if provided), the `num_fewshot` number of examples, and an appended prompt example.
......@@ -1006,7 +1003,7 @@ class ConfigurableTask(Task):
labeled_examples: List[Dict[str, str]],
question: str,
fewshot_as_multiturn: bool = False,
assistant_prefill: Optional[str] = None,
gen_prefix: Optional[str] = None,
) -> None:
"""Adds a target question to the labeled examples list.
If fewshot_as_multiturn is True, or labeled_examples is empty, or the last entry is a system turn, appends the question as a new user entry.
......@@ -1022,8 +1019,8 @@ class ConfigurableTask(Task):
else:
# if fewshot_as_multiturn is True, append as next user entry (last is always assistant)
labeled_examples.append({"role": "user", "content": question})
if assistant_prefill:
labeled_examples.append({"role": "assistant", "content": assistant_prefill})
if gen_prefix:
labeled_examples.append({"role": "assistant", "content": gen_prefix})
@utils.positional_deprecated
def fewshot_context(
......@@ -1034,7 +1031,7 @@ class ConfigurableTask(Task):
apply_chat_template: bool = False,
fewshot_as_multiturn: bool = False,
chat_template: Optional[Callable] = None,
assistant_prefill: Optional[str] = None,
gen_prefix: Optional[str] = None,
) -> Union[str, List[str]]:
"""Returns a fewshot context string that is made up of a prepended description
(if provided), the `num_fewshot` number of examples, and an appended prompt example.
......@@ -1081,7 +1078,6 @@ class ConfigurableTask(Task):
labeled_examples.append({"role": "system", "content": system_prompt})
else:
labeled_examples = system_prompt
# if few-shot - append examples after the system prompt
if num_fewshot > 0:
if apply_chat_template:
......@@ -1090,12 +1086,12 @@ class ConfigurableTask(Task):
doc,
num_fewshot,
fewshot_as_multiturn,
assistant_prefill=assistant_prefill,
gen_prefix=gen_prefix,
)
)
else:
labeled_examples += self.sampler.get_context(
doc, num_fewshot, assistant_prefill=assistant_prefill
doc, num_fewshot, gen_prefix=gen_prefix
)
example = self.doc_to_text(doc)
......@@ -1108,7 +1104,7 @@ class ConfigurableTask(Task):
labeled_examples,
example,
fewshot_as_multiturn,
assistant_prefill=assistant_prefill,
gen_prefix=gen_prefix,
)
# for loglikelihood create a list of questions with appended choices
elif isinstance(example, list):
......@@ -1120,13 +1116,13 @@ class ConfigurableTask(Task):
chat,
ex,
fewshot_as_multiturn,
assistant_prefill=assistant_prefill,
gen_prefix=gen_prefix,
)
# TODO: append prefill?
labeled_examples_list.append(
chat_template(
chat,
add_generation_prompt=False if assistant_prefill else True,
add_generation_prompt=False if gen_prefix else True,
)
)
return labeled_examples_list
......@@ -1138,24 +1134,24 @@ class ConfigurableTask(Task):
labeled_examples,
choices[example],
fewshot_as_multiturn,
assistant_prefill=assistant_prefill,
gen_prefix=gen_prefix,
)
else:
self.append_target_question(
labeled_examples,
str(example),
fewshot_as_multiturn,
assistant_prefill=assistant_prefill,
gen_prefix=gen_prefix,
)
# return lm.apply_chat_template(labeled_examples)
return chat_template(
labeled_examples,
add_generation_prompt=False if assistant_prefill else True,
add_generation_prompt=False if gen_prefix else True,
)
else:
prefix = (
self.config.target_delimiter + assistant_prefill
if assistant_prefill is not None
self.config.target_delimiter + gen_prefix
if gen_prefix is not None
else ""
)
if self.multiple_input:
......@@ -1342,6 +1338,14 @@ class ConfigurableTask(Task):
else:
return None
def doc_to_prefix(self, doc):
    """Resolve the configured ``gen_prefix`` for a single document.

    If ``gen_prefix`` names a dataset feature (column), the document's
    value for that column is returned verbatim; otherwise the string is
    treated as a template and rendered against *doc*. Returns None when
    no ``gen_prefix`` is configured for this task.
    """
    gen_prefix = self.config.gen_prefix
    if gen_prefix is None:
        return None
    # Column lookup takes precedence over template rendering.
    if gen_prefix in self.features:
        return doc[gen_prefix]
    return utils.apply_template(gen_prefix, doc)
def construct_requests(
self, doc: dict, ctx: str, **kwargs
) -> Union[List[Instance], Instance]:
......
......@@ -9,7 +9,7 @@ validation_split: validation
test_split: test
fewshot_split: train
doc_to_text: 'Given the following question and four candidate answers (A, B, C and D), choose the best answer.\nQuestion: {{question.strip()}}\nA. {{choices.text[0]}}\nB. {{choices.text[1]}}\nC. {{choices.text[2]}}{% if choices.text|length > 3 %}\nD. {{choices.text[3]}}{% endif %}\nYour response should end with "The best answer is [the_answer_letter]" where the [the_answer_letter] is one of A, B, C or D.'
assistant_prefill: 'The best answer is'
gen_prefix: 'The best answer is'
fewshot_delimiter: "\n\n"
doc_to_target: "{{ 'ABCD'[answerKey|int - 1] if answerKey|string in '1234' else answerKey }}"
num_fewshot: 0
......
......@@ -5,7 +5,7 @@ fewshot_split: dev
fewshot_config:
sampler: first_n
doc_to_text: "Given the following question and four candidate answers (A, B, C and D), choose the best answer.\nQuestion: {{question.strip()}}\nA. {{choices[0]}}\nB. {{choices[1]}}\nC. {{choices[2]}}\nD. {{choices[3]}}\nYour response should end with \"The best answer is [the_answer_letter]\" where the [the_answer_letter] is one of A, B, C or D."
assistant_prefill: "The best answer is"
gen_prefix: "The best answer is"
doc_to_target: "{{['A.','B.','C.','D.'][answer]}}"
num_fewshot: 5
metric_list:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment