Commit 34a079e3 authored by lintangsutawika

add doc_to_image

parent 67a990e7
@@ -75,6 +75,7 @@ class TaskConfig(dict):
    process_docs: Optional[Callable] = None
    doc_to_text: Optional[Union[Callable, str]] = None
    doc_to_target: Optional[Union[Callable, str]] = None
    doc_to_image: Optional[Union[Callable, str]] = None
    doc_to_choice: Optional[Union[Callable, str, dict, list]] = None
    process_results: Optional[Union[Callable, str]] = None
    use_prompt: Optional[str] = None
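For context, here is a hedged sketch of how a task could populate the new field: like `doc_to_text` and `doc_to_target`, `doc_to_image` accepts either the name of a dataset column or a callable. The task name and column names below are hypothetical.

```python
# Minimal sketch; task and column names are hypothetical.
from lm_eval.api.task import TaskConfig  # import path assumed from this repository layout

config = TaskConfig(
    task="demo_vqa",          # hypothetical multimodal task
    doc_to_text="question",   # prompt text taken from the "question" column
    doc_to_target="answer",   # gold answer taken from the "answer" column
    doc_to_image="image",     # image(s) taken from the "image" column
)

# A callable works as well, e.g. wrapping a single image in a list:
# config = TaskConfig(task="demo_vqa", doc_to_image=lambda doc: [doc["image"]], ...)
```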
@@ -365,19 +366,22 @@ class Task(abc.ABC):
    def doc_to_target(self, doc):
        pass

    @abc.abstractmethod
    def doc_to_image(self, doc):
        pass

    def build_all_requests(
        self,
        *,
        limit: Union[int, None] = None,
        rank: int = 0,
        world_size: int = 1,
        cache_requests: bool = False,
        rewrite_requests_cache: bool = False,
        system_instruction: Optional[str] = None,
        apply_chat_template: bool = False,
        fewshot_as_multiturn: bool = False,
        chat_template: Optional[Callable] = None,
        tokenizer_name: str = "",
        limit=None,
        rank=None,
        world_size=None,
        cache_requests=False,
        rewrite_requests_cache=False,
        system_instruction=None,
        apply_chat_template=False,
        fewshot_as_multiturn=False,
        lm=None,
    ) -> None:
        """Build a set of Instances for a task, and store them in task.instances"""
@@ -392,7 +396,7 @@ class Task(abc.ABC):
            if system_instruction is not None
            else ""
        )
        cache_key += f"-tokenizer{tokenizer_name}"
        cache_key += f"-tokenizer{lm.tokenizer_name}" if apply_chat_template else ""

        cached_instances = load_from_cache(file_name=cache_key)
@@ -437,7 +441,7 @@ class Task(abc.ABC):
                system_instruction,
                apply_chat_template,
                fewshot_as_multiturn,
                chat_template,
                lm,
            )

        # TODO: we should override self.config.repeats if doing greedy gen so users don't waste time+compute
@@ -1015,7 +1019,7 @@ class ConfigurableTask(Task):
        system_instruction: Optional[str] = None,
        apply_chat_template: bool = False,
        fewshot_as_multiturn: bool = False,
        chat_template: Optional[Callable] = None,
        lm=None,
    ) -> str:
        """Returns a fewshot context string that is made up of a prepended description
        (if provided), the `num_fewshot` number of examples, and an appended prompt example.
@@ -1030,8 +1034,8 @@ class ConfigurableTask(Task):
            Whether to apply the chat template to the fewshot context.
        :param fewshot_as_multiturn: bool
            Whether to provide the fewshot examples as a multiturn conversation or a single user turn.
        :param chat_template: Callable
            Chat template to be applied to the fewshot context.
        :param lm:
            Language model with definition of the tokenizer/function to use for applying the chat template.
        :returns: str
            The fewshot context.
        """
@@ -1078,7 +1082,7 @@ class ConfigurableTask(Task):
        example = self.doc_to_text(doc)
        if apply_chat_template:
            if self.multiple_input:
                return chat_template(labeled_examples)
                return lm.apply_chat_template(labeled_examples)
            if isinstance(example, str):
                self.append_target_question(
                    labeled_examples, example, fewshot_as_multiturn
@@ -1090,7 +1094,7 @@ class ConfigurableTask(Task):
                for ex in example:
                    chat = deepcopy(labeled_examples)
                    self.append_target_question(chat, ex, fewshot_as_multiturn)
                    labeled_examples_list.append(chat_template(chat))
                    labeled_examples_list.append(lm.apply_chat_template(chat))
                return labeled_examples_list
            # if example is an integer, append the choice or convert to string
            elif isinstance(example, int):
@@ -1104,7 +1108,7 @@ class ConfigurableTask(Task):
                        labeled_examples, str(example), fewshot_as_multiturn
                    )
            # return lm.apply_chat_template(labeled_examples)
            return chat_template(labeled_examples)
            return lm.apply_chat_template(labeled_examples)
        else:
            if self.multiple_input:
                return labeled_examples
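To make the chat-template path above concrete, here is a sketch of what reaches `lm.apply_chat_template` when `apply_chat_template=True`: `labeled_examples` is a list of role/content messages (optional system instruction plus few-shot turns), and the current question is appended as the final user turn before rendering. The message contents below are invented, and `DummyLM` is the illustrative stub sketched after `build_all_requests`.

```python
# Invented conversation; only the role/content structure matters.
labeled_examples = [
    {"role": "system", "content": "Answer with a single number."},
    {"role": "user", "content": "Q: 2 + 2 = ?\nA:"},
    {"role": "assistant", "content": "4"},
]
example = "Q: 3 + 5 = ?\nA:"  # produced by doc_to_text(doc)

# Roughly what append_target_question does in the multiturn case:
labeled_examples.append({"role": "user", "content": example})

lm = DummyLM()  # illustrative stub from the earlier sketch
prompt = lm.apply_chat_template(labeled_examples)
```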
@@ -1261,9 +1265,27 @@ class ConfigurableTask(Task):
        else:
            raise TypeError

    def doc_to_image(self, doc: Any) -> Union[int, str, list]:
        if self.config.doc_to_image is None:
            eval_logger.error("doc_to_image was called but not set in config")
            return None
        else:
            doc_to_image = self.config.doc_to_image

        if isinstance(self.config.doc_to_image, str):
            if doc_to_image in self.features:
                return doc[doc_to_image]
            else:
                return ast.literal_eval(utils.apply_template(doc_to_image, doc))
        elif callable(doc_to_image):
            return doc_to_image(doc)
        else:
            return None

    def construct_requests(
        self, doc: dict, ctx: str, **kwargs
    ) -> Union[List[Instance], Instance]:
        aux_arguments = None

        if self.OUTPUT_TYPE == "loglikelihood":
            arguments = (ctx, self.doc_to_target(doc))
        elif self.OUTPUT_TYPE == "loglikelihood_rolling":
@@ -1281,16 +1303,6 @@ class ConfigurableTask(Task):
                # Otherwise they are placed in the continuation
                arguments = [(ctx, f"{target_delimiter}{cont}") for cont in choices]

            request_list = [
                Instance(
                    request_type="loglikelihood",
                    doc=doc,
                    arguments=arg,
                    idx=i,
                    **kwargs,
                )
                for i, arg in enumerate(arguments)
            ]

            # TODO: we should raise a warning telling users this will at most ~2x runtime.
            if "acc_mutual_info" in self._metric_fn_list.keys():
                # if we are calculating multiple choice accuracy
@@ -1299,25 +1311,52 @@ class ConfigurableTask(Task):
                # here mutual info refers to calculating
                # log(P(choice|ctx) / P(choice)) = log(P(choice|ctx)) - log(P(choice))
                # in other words normalizing by subtracting the unconditional logprob of each choice.
                aux_arguments = [("", f"{choice}") for choice in choices]

        elif self.OUTPUT_TYPE == "generate_until":
            arguments = (ctx, deepcopy(self.config.generation_kwargs))

        multimodal_arg = {}
        if self.config.doc_to_image:
            multimodal_arg = {
                **multimodal_arg,
                **{"visual": self.doc_to_image(doc)},
            }

        if bool(multimodal_arg):
            if isinstance(arguments, list):
                arguments = [arg + (multimodal_arg,) for arg in arguments]
            else:
                arguments = arguments + (multimodal_arg,)

        if isinstance(arguments, list):
            if aux_arguments is not None:
                all_arg_list = [arguments, aux_arguments]
            else:
                all_arg_list = [arguments]

            request_list = []
            for arg_list in all_arg_list:
                request_list.extend(
                    [
                        Instance(
                            request_type="loglikelihood",
                            doc=doc,
                            arguments=("", "{}".format(choice)),
                            arguments=arg,
                            idx=i,
                            **kwargs,
                        )
                        for i, choice in enumerate(choices)
                        for i, arg in enumerate(arg_list)
                    ]
                )

            return request_list

        elif self.OUTPUT_TYPE == "generate_until":
            arguments = (ctx, deepcopy(self.config.generation_kwargs))

        return request_list

        return Instance(
            request_type=self.OUTPUT_TYPE, doc=doc, arguments=arguments, idx=0, **kwargs
            request_type=self.OUTPUT_TYPE,
            doc=doc,
            arguments=arguments,
            idx=0,
            **kwargs,
        )
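To summarize what `construct_requests` now produces, here is a hedged sketch of the argument tuples for the two output types this commit touches; the context string, continuations, image payload, and generation kwargs are invented. Note that only the conditional `arguments` receive the multimodal dict; the unconditional `aux_arguments` used for `acc_mutual_info` (log P(choice|ctx) - log P(choice)) do not.

```python
# Invented values; the tuple shapes mirror what the code above builds.
ctx = "Q: What animal is in the picture?\nA:"
continuations = [" cat", " dog"]                 # target_delimiter + choice
visual = {"visual": ["img_000.png"]}             # from self.doc_to_image(doc)

# multiple_choice with acc_mutual_info and an image:
arguments = [(ctx, " cat", visual), (ctx, " dog", visual)]   # conditional log-probs
aux_arguments = [("", " cat"), ("", " dog")]                 # unconditional log-probs
# -> one "loglikelihood" Instance per tuple.

# generate_until with an image:
generation_kwargs = {"until": ["\n"], "max_gen_toks": 32}    # hypothetical kwargs
gen_arguments = (ctx, generation_kwargs, visual)
# -> a single Instance with request_type "generate_until".
```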
def process_results(self, doc, results):
@@ -1526,7 +1565,7 @@ class ConfigurableTask(Task):
            f"group_name={getattr(self.config, 'group', None)},"
            f"output_type={self.OUTPUT_TYPE},"
            f"num_fewshot={getattr(self.config, 'num_fewshot', None)},"
            f"num_samples={len(self.eval_docs)})"
            f"num_samples={len(self.eval_docs)})",
        )