"megatron/model/vscode:/vscode.git/clone" did not exist on "06fc51cef50fded88e0142f32b40dc615e39672a"
Unverified Commit 9dda03d6 authored by Baber Abbasi's avatar Baber Abbasi Committed by GitHub
Browse files

fix gen_prefix (#2630)

* switch arg
parent 703fbffd
...@@ -71,9 +71,9 @@ class ContextSampler: ...@@ -71,9 +71,9 @@ class ContextSampler:
) )
self.docs = self.docs.select(fewshot_indices) self.docs = self.docs.select(fewshot_indices)
def get_context(self, doc: dict, num_fewshot: int, assistant_prefill: str = None): def get_context(self, doc: dict, num_fewshot: int, gen_prefix: str = None):
# draw an extra fewshot sample if using same split as evaluating on # draw an extra fewshot sample if using same split as evaluating on
prefix = assistant_prefill + " " if assistant_prefill else "" prefix = gen_prefix + " " if gen_prefix else ""
n_samples = ( n_samples = (
num_fewshot + 1 num_fewshot + 1
if self.config.fewshot_split == self.config.test_split if self.config.fewshot_split == self.config.test_split
...@@ -115,10 +115,10 @@ class ContextSampler: ...@@ -115,10 +115,10 @@ class ContextSampler:
doc: dict, doc: dict,
num_fewshot: int, num_fewshot: int,
fewshot_as_multiturn: bool = False, fewshot_as_multiturn: bool = False,
assistant_prefill: Optional[str] = None, gen_prefix: Optional[str] = None,
): ):
# TODO: Do we need any other delimiter # TODO: Do we need any other delimiter
prefix = assistant_prefill + " " if assistant_prefill else "" prefix = gen_prefix + " " if gen_prefix else ""
chat_history = [] chat_history = []
# draw an extra fewshot sample if using same split as evaluating on # draw an extra fewshot sample if using same split as evaluating on
n_samples = ( n_samples = (
...@@ -163,7 +163,7 @@ class ContextSampler: ...@@ -163,7 +163,7 @@ class ContextSampler:
{ {
"role": "user", "role": "user",
"content": self.get_context( "content": self.get_context(
doc, num_fewshot, assistant_prefill=assistant_prefill doc, num_fewshot, gen_prefix=gen_prefix
), ),
} }
) )
......
...@@ -93,7 +93,7 @@ class TaskConfig(dict): ...@@ -93,7 +93,7 @@ class TaskConfig(dict):
filter_list: Optional[Union[str, list]] = None filter_list: Optional[Union[str, list]] = None
should_decontaminate: bool = False should_decontaminate: bool = False
doc_to_decontamination_query: Optional[str] = None doc_to_decontamination_query: Optional[str] = None
assistant_prefill: Optional[str] = None gen_prefix: Optional[str] = None
metadata: Optional[dict] = ( metadata: Optional[dict] = (
None # by default, not used in the code. allows for users to pass arbitrary info to tasks None # by default, not used in the code. allows for users to pass arbitrary info to tasks
) )
...@@ -371,6 +371,9 @@ class Task(abc.ABC): ...@@ -371,6 +371,9 @@ class Task(abc.ABC):
def doc_to_image(self, doc): def doc_to_image(self, doc):
raise NotImplementedError raise NotImplementedError
def doc_to_prefix(self, doc):
return ""
def build_all_requests( def build_all_requests(
self, self,
*, *,
...@@ -444,7 +447,7 @@ class Task(abc.ABC): ...@@ -444,7 +447,7 @@ class Task(abc.ABC):
apply_chat_template, apply_chat_template,
fewshot_as_multiturn, fewshot_as_multiturn,
chat_template, chat_template,
assistant_prefill=self.config.assistant_prefill, gen_prefix=self.doc_to_prefix(doc),
) )
# TODO: we should override self.config.repeats if doing greedy gen so users don't waste time+compute # TODO: we should override self.config.repeats if doing greedy gen so users don't waste time+compute
...@@ -544,13 +547,7 @@ class Task(abc.ABC): ...@@ -544,13 +547,7 @@ class Task(abc.ABC):
return len(re.split(r"\s+", doc)) return len(re.split(r"\s+", doc))
@utils.positional_deprecated @utils.positional_deprecated
def fewshot_context( def fewshot_context(self, doc, num_fewshot, rnd=None, description=None, **kwargs):
self,
doc,
num_fewshot,
rnd=None,
description=None,
):
"""Returns a fewshot context string that is made up of a prepended description """Returns a fewshot context string that is made up of a prepended description
(if provided), the `num_fewshot` number of examples, and an appended prompt example. (if provided), the `num_fewshot` number of examples, and an appended prompt example.
...@@ -1006,7 +1003,7 @@ class ConfigurableTask(Task): ...@@ -1006,7 +1003,7 @@ class ConfigurableTask(Task):
labeled_examples: List[Dict[str, str]], labeled_examples: List[Dict[str, str]],
question: str, question: str,
fewshot_as_multiturn: bool = False, fewshot_as_multiturn: bool = False,
assistant_prefill: Optional[str] = None, gen_prefix: Optional[str] = None,
) -> None: ) -> None:
"""Adds a target question to the labeled examples list. """Adds a target question to the labeled examples list.
If fewshot_as_multiturn is True, or labeled_examples is empty, or the last entry is a system turn, appends the question as a new user entry. If fewshot_as_multiturn is True, or labeled_examples is empty, or the last entry is a system turn, appends the question as a new user entry.
...@@ -1022,8 +1019,8 @@ class ConfigurableTask(Task): ...@@ -1022,8 +1019,8 @@ class ConfigurableTask(Task):
else: else:
# if fewshot_as_multiturn is True, append as next user entry (last is always assistant) # if fewshot_as_multiturn is True, append as next user entry (last is always assistant)
labeled_examples.append({"role": "user", "content": question}) labeled_examples.append({"role": "user", "content": question})
if assistant_prefill: if gen_prefix:
labeled_examples.append({"role": "assistant", "content": assistant_prefill}) labeled_examples.append({"role": "assistant", "content": gen_prefix})
@utils.positional_deprecated @utils.positional_deprecated
def fewshot_context( def fewshot_context(
...@@ -1034,7 +1031,7 @@ class ConfigurableTask(Task): ...@@ -1034,7 +1031,7 @@ class ConfigurableTask(Task):
apply_chat_template: bool = False, apply_chat_template: bool = False,
fewshot_as_multiturn: bool = False, fewshot_as_multiturn: bool = False,
chat_template: Optional[Callable] = None, chat_template: Optional[Callable] = None,
assistant_prefill: Optional[str] = None, gen_prefix: Optional[str] = None,
) -> Union[str, List[str]]: ) -> Union[str, List[str]]:
"""Returns a fewshot context string that is made up of a prepended description """Returns a fewshot context string that is made up of a prepended description
(if provided), the `num_fewshot` number of examples, and an appended prompt example. (if provided), the `num_fewshot` number of examples, and an appended prompt example.
...@@ -1081,7 +1078,6 @@ class ConfigurableTask(Task): ...@@ -1081,7 +1078,6 @@ class ConfigurableTask(Task):
labeled_examples.append({"role": "system", "content": system_prompt}) labeled_examples.append({"role": "system", "content": system_prompt})
else: else:
labeled_examples = system_prompt labeled_examples = system_prompt
# if few-shot - append examples after the system prompt # if few-shot - append examples after the system prompt
if num_fewshot > 0: if num_fewshot > 0:
if apply_chat_template: if apply_chat_template:
...@@ -1090,12 +1086,12 @@ class ConfigurableTask(Task): ...@@ -1090,12 +1086,12 @@ class ConfigurableTask(Task):
doc, doc,
num_fewshot, num_fewshot,
fewshot_as_multiturn, fewshot_as_multiturn,
assistant_prefill=assistant_prefill, gen_prefix=gen_prefix,
) )
) )
else: else:
labeled_examples += self.sampler.get_context( labeled_examples += self.sampler.get_context(
doc, num_fewshot, assistant_prefill=assistant_prefill doc, num_fewshot, gen_prefix=gen_prefix
) )
example = self.doc_to_text(doc) example = self.doc_to_text(doc)
...@@ -1108,7 +1104,7 @@ class ConfigurableTask(Task): ...@@ -1108,7 +1104,7 @@ class ConfigurableTask(Task):
labeled_examples, labeled_examples,
example, example,
fewshot_as_multiturn, fewshot_as_multiturn,
assistant_prefill=assistant_prefill, gen_prefix=gen_prefix,
) )
# for loglikelihood create a list of questions with appended choices # for loglikelihood create a list of questions with appended choices
elif isinstance(example, list): elif isinstance(example, list):
...@@ -1120,13 +1116,13 @@ class ConfigurableTask(Task): ...@@ -1120,13 +1116,13 @@ class ConfigurableTask(Task):
chat, chat,
ex, ex,
fewshot_as_multiturn, fewshot_as_multiturn,
assistant_prefill=assistant_prefill, gen_prefix=gen_prefix,
) )
# TODO: append prefill? # TODO: append prefill?
labeled_examples_list.append( labeled_examples_list.append(
chat_template( chat_template(
chat, chat,
add_generation_prompt=False if assistant_prefill else True, add_generation_prompt=False if gen_prefix else True,
) )
) )
return labeled_examples_list return labeled_examples_list
...@@ -1138,24 +1134,24 @@ class ConfigurableTask(Task): ...@@ -1138,24 +1134,24 @@ class ConfigurableTask(Task):
labeled_examples, labeled_examples,
choices[example], choices[example],
fewshot_as_multiturn, fewshot_as_multiturn,
assistant_prefill=assistant_prefill, gen_prefix=gen_prefix,
) )
else: else:
self.append_target_question( self.append_target_question(
labeled_examples, labeled_examples,
str(example), str(example),
fewshot_as_multiturn, fewshot_as_multiturn,
assistant_prefill=assistant_prefill, gen_prefix=gen_prefix,
) )
# return lm.apply_chat_template(labeled_examples) # return lm.apply_chat_template(labeled_examples)
return chat_template( return chat_template(
labeled_examples, labeled_examples,
add_generation_prompt=False if assistant_prefill else True, add_generation_prompt=False if gen_prefix else True,
) )
else: else:
prefix = ( prefix = (
self.config.target_delimiter + assistant_prefill self.config.target_delimiter + gen_prefix
if assistant_prefill is not None if gen_prefix is not None
else "" else ""
) )
if self.multiple_input: if self.multiple_input:
...@@ -1342,6 +1338,14 @@ class ConfigurableTask(Task): ...@@ -1342,6 +1338,14 @@ class ConfigurableTask(Task):
else: else:
return None return None
def doc_to_prefix(self, doc):
if (gen_prefix := self.config.gen_prefix) is not None:
if gen_prefix in self.features:
return doc[gen_prefix]
else:
return utils.apply_template(gen_prefix, doc)
return None
def construct_requests( def construct_requests(
self, doc: dict, ctx: str, **kwargs self, doc: dict, ctx: str, **kwargs
) -> Union[List[Instance], Instance]: ) -> Union[List[Instance], Instance]:
......
...@@ -9,7 +9,7 @@ validation_split: validation ...@@ -9,7 +9,7 @@ validation_split: validation
test_split: test test_split: test
fewshot_split: train fewshot_split: train
doc_to_text: 'Given the following question and four candidate answers (A, B, C and D), choose the best answer.\nQuestion: {{question.strip()}}\nA. {{choices.text[0]}}\nB. {{choices.text[1]}}\nC. {{choices.text[2]}}{% if choices.text|length > 3 %}\nD. {{choices.text[3]}}{% endif %}\nYour response should end with "The best answer is [the_answer_letter]" where the [the_answer_letter] is one of A, B, C or D.' doc_to_text: 'Given the following question and four candidate answers (A, B, C and D), choose the best answer.\nQuestion: {{question.strip()}}\nA. {{choices.text[0]}}\nB. {{choices.text[1]}}\nC. {{choices.text[2]}}{% if choices.text|length > 3 %}\nD. {{choices.text[3]}}{% endif %}\nYour response should end with "The best answer is [the_answer_letter]" where the [the_answer_letter] is one of A, B, C or D.'
assistant_prefill: 'The best answer is' gen_prefix: 'The best answer is'
fewshot_delimiter: "\n\n" fewshot_delimiter: "\n\n"
doc_to_target: "{{ 'ABCD'[answerKey|int - 1] if answerKey|string in '1234' else answerKey }}" doc_to_target: "{{ 'ABCD'[answerKey|int - 1] if answerKey|string in '1234' else answerKey }}"
num_fewshot: 0 num_fewshot: 0
......
...@@ -5,7 +5,7 @@ fewshot_split: dev ...@@ -5,7 +5,7 @@ fewshot_split: dev
fewshot_config: fewshot_config:
sampler: first_n sampler: first_n
doc_to_text: "Given the following question and four candidate answers (A, B, C and D), choose the best answer.\nQuestion: {{question.strip()}}\nA. {{choices[0]}}\nB. {{choices[1]}}\nC. {{choices[2]}}\nD. {{choices[3]}}\nYour response should end with \"The best answer is [the_answer_letter]\" where the [the_answer_letter] is one of A, B, C or D." doc_to_text: "Given the following question and four candidate answers (A, B, C and D), choose the best answer.\nQuestion: {{question.strip()}}\nA. {{choices[0]}}\nB. {{choices[1]}}\nC. {{choices[2]}}\nD. {{choices[3]}}\nYour response should end with \"The best answer is [the_answer_letter]\" where the [the_answer_letter] is one of A, B, C or D."
assistant_prefill: "The best answer is" gen_prefix: "The best answer is"
doc_to_target: "{{['A.','B.','C.','D.'][answer]}}" doc_to_target: "{{['A.','B.','C.','D.'][answer]}}"
num_fewshot: 5 num_fewshot: 5
metric_list: metric_list:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment