Unverified Commit 532909c0 authored by Kiersten Stokes's avatar Kiersten Stokes Committed by GitHub
Browse files

Ensure backwards compatibility in fewshot_context by using kwargs (#3079)


Signed-off-by: default avatarkiersten-stokes <kierstenstokes@gmail.com>
parent 8bc46207
......@@ -458,11 +458,13 @@ class Task(abc.ABC):
# sample fewshot context #TODO: need to offset doc_id by rank now!
fewshot_ctx = self.fewshot_context(
doc,
0 if self.config.num_fewshot is None else self.config.num_fewshot,
system_instruction,
apply_chat_template,
fewshot_as_multiturn,
chat_template,
num_fewshot=0
if self.config.num_fewshot is None
else self.config.num_fewshot,
system_instruction=system_instruction,
apply_chat_template=apply_chat_template,
fewshot_as_multiturn=fewshot_as_multiturn,
chat_template=chat_template,
gen_prefix=self.doc_to_prefix(doc),
)
......
......@@ -6,7 +6,6 @@ Addressing this need, we present Unitxt, an innovative library for customizable
import importlib.util
import re
from collections.abc import Callable
from functools import partial
from typing import Any, Dict, Optional
......@@ -110,18 +109,10 @@ class Unitxt(ConfigurableTask):
def get_arguments(self, doc, ctx):
return (ctx, {"until": ["\n"]})
def fewshot_context(
self,
doc: str,
num_fewshot: int,
system_instruction: Optional[str] = None,
apply_chat_template: bool = False,
fewshot_as_multiturn: bool = False,
chat_template: Optional[Callable] = None,
gen_prefix: Optional[str] = None,
) -> str:
def fewshot_context(self, doc, **kwargs) -> str:
if isinstance(self.doc_to_text(doc), list):
if apply_chat_template:
if kwargs.get("apply_chat_template"):
chat_template = kwargs.get("chat_template")
formated_source = chat_template(self.doc_to_text(doc))
return formated_source
else:
......@@ -129,15 +120,7 @@ class Unitxt(ConfigurableTask):
"Got chat template format from Unitxt, but apply_chat_template is false. Add '--apply_chat_template' to command line."
)
else:
return super().fewshot_context(
doc=doc,
num_fewshot=num_fewshot,
system_instruction=system_instruction,
apply_chat_template=apply_chat_template,
fewshot_as_multiturn=fewshot_as_multiturn,
chat_template=chat_template,
gen_prefix=gen_prefix,
)
return super().fewshot_context(doc=doc, **kwargs)
def construct_requests(self, doc, ctx, **kwargs):
"""Uses RequestFactory to construct Requests and returns an iterable of
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment