Unverified Commit 532909c0 authored by Kiersten Stokes's avatar Kiersten Stokes Committed by GitHub
Browse files

Ensure backwards compatibility in fewshot_context by using kwargs (#3079)


Signed-off-by: default avatarkiersten-stokes <kierstenstokes@gmail.com>
parent 8bc46207
...@@ -458,11 +458,13 @@ class Task(abc.ABC): ...@@ -458,11 +458,13 @@ class Task(abc.ABC):
# sample fewshot context #TODO: need to offset doc_id by rank now! # sample fewshot context #TODO: need to offset doc_id by rank now!
fewshot_ctx = self.fewshot_context( fewshot_ctx = self.fewshot_context(
doc, doc,
0 if self.config.num_fewshot is None else self.config.num_fewshot, num_fewshot=0
system_instruction, if self.config.num_fewshot is None
apply_chat_template, else self.config.num_fewshot,
fewshot_as_multiturn, system_instruction=system_instruction,
chat_template, apply_chat_template=apply_chat_template,
fewshot_as_multiturn=fewshot_as_multiturn,
chat_template=chat_template,
gen_prefix=self.doc_to_prefix(doc), gen_prefix=self.doc_to_prefix(doc),
) )
......
...@@ -6,7 +6,6 @@ Addressing this need, we present Unitxt, an innovative library for customizable ...@@ -6,7 +6,6 @@ Addressing this need, we present Unitxt, an innovative library for customizable
import importlib.util import importlib.util
import re import re
from collections.abc import Callable
from functools import partial from functools import partial
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
...@@ -110,18 +109,10 @@ class Unitxt(ConfigurableTask): ...@@ -110,18 +109,10 @@ class Unitxt(ConfigurableTask):
def get_arguments(self, doc, ctx): def get_arguments(self, doc, ctx):
return (ctx, {"until": ["\n"]}) return (ctx, {"until": ["\n"]})
def fewshot_context( def fewshot_context(self, doc, **kwargs) -> str:
self,
doc: str,
num_fewshot: int,
system_instruction: Optional[str] = None,
apply_chat_template: bool = False,
fewshot_as_multiturn: bool = False,
chat_template: Optional[Callable] = None,
gen_prefix: Optional[str] = None,
) -> str:
if isinstance(self.doc_to_text(doc), list): if isinstance(self.doc_to_text(doc), list):
if apply_chat_template: if kwargs.get("apply_chat_template"):
chat_template = kwargs.get("chat_template")
formated_source = chat_template(self.doc_to_text(doc)) formated_source = chat_template(self.doc_to_text(doc))
return formated_source return formated_source
else: else:
...@@ -129,15 +120,7 @@ class Unitxt(ConfigurableTask): ...@@ -129,15 +120,7 @@ class Unitxt(ConfigurableTask):
"Got chat template format from Unitxt, but apply_chat_template is false. Add '--apply_chat_template' to command line." "Got chat template format from Unitxt, but apply_chat_template is false. Add '--apply_chat_template' to command line."
) )
else: else:
return super().fewshot_context( return super().fewshot_context(doc=doc, **kwargs)
doc=doc,
num_fewshot=num_fewshot,
system_instruction=system_instruction,
apply_chat_template=apply_chat_template,
fewshot_as_multiturn=fewshot_as_multiturn,
chat_template=chat_template,
gen_prefix=gen_prefix,
)
def construct_requests(self, doc, ctx, **kwargs): def construct_requests(self, doc, ctx, **kwargs):
"""Uses RequestFactory to construct Requests and returns an iterable of """Uses RequestFactory to construct Requests and returns an iterable of
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment