Commit 8a8c2982 authored by lintangsutawika's avatar lintangsutawika
Browse files

rework doc_to_visual

parent 7d7a3a1c
...@@ -1278,14 +1278,21 @@ class ConfigurableTask(Task): ...@@ -1278,14 +1278,21 @@ class ConfigurableTask(Task):
else: else:
raise TypeError raise TypeError
def doc_to_visual(self, doc: dict) -> Union[int, str, list]: def doc_to_visual(self, doc: Any) -> Union[int, str, list]:
if self.config.doc_to_visual is None:
eval_logger.error("doc_to_visual was called but not set in config")
else:
doc_to_visual = self.config.doc_to_visual
if isinstance(self.config.doc_to_visual, str): if isinstance(self.config.doc_to_visual, str):
assert self.config.doc_to_visual in self.features if doc_to_visual in self.features:
# Single Image. Still return a list for consistency return doc[doc_to_visual]
return doc[self.config.doc_to_visual] else:
return ast.literal_eval(utils.apply_template(doc_to_visual, doc))
elif callable(doc_to_visual):
return doc_to_visual(doc)
else: else:
assert callable(self.config.doc_to_visual) return None
return self.config.doc_to_visual(doc)
def construct_requests( def construct_requests(
self, doc: dict, ctx: str, **kwargs self, doc: dict, ctx: str, **kwargs
...@@ -1307,16 +1314,6 @@ class ConfigurableTask(Task): ...@@ -1307,16 +1314,6 @@ class ConfigurableTask(Task):
# Otherwise they are placed in the continuation # Otherwise they are placed in the continuation
arguments = [(ctx, f"{target_delimiter}{cont}") for cont in choices] arguments = [(ctx, f"{target_delimiter}{cont}") for cont in choices]
request_list = [
Instance(
request_type="loglikelihood",
doc=doc,
arguments=arg,
idx=i,
**kwargs,
)
for i, arg in enumerate(arguments)
]
# TODO: we should raise a warning telling users this will at most ~2x runtime. # TODO: we should raise a warning telling users this will at most ~2x runtime.
if "acc_mutual_info" in self._metric_fn_list.keys(): if "acc_mutual_info" in self._metric_fn_list.keys():
# if we are calculating multiple choice accuracy # if we are calculating multiple choice accuracy
...@@ -1325,31 +1322,37 @@ class ConfigurableTask(Task): ...@@ -1325,31 +1322,37 @@ class ConfigurableTask(Task):
# here mutual info refers to calculating # here mutual info refers to calculating
# log(P(choice|ctx) / P(choice)) = log(P(choice|ctx)) - log(P(choice)) # log(P(choice|ctx) / P(choice)) = log(P(choice|ctx)) - log(P(choice))
# in other words normalizing by subtracting the unconditional logprob of each choice. # in other words normalizing by subtracting the unconditional logprob of each choice.
request_list.extend( aux_arguments = [("", f"{choice}") for choice in choices]
[ else:
Instance( aux_arguments = None
request_type="loglikelihood",
doc=doc,
arguments=("", "{}".format(choice)),
idx=i,
**kwargs,
)
for i, choice in enumerate(choices)
]
)
return request_list
elif self.OUTPUT_TYPE == "generate_until": elif self.OUTPUT_TYPE == "generate_until":
if self.INPUT_TYPE == "text_image": arguments = (ctx, deepcopy(self.config.generation_kwargs))
arguments = (
ctx, if self.doc_to_visual:
deepcopy(self.config.generation_kwargs), if isinstance(arguments, list):
self.doc_to_visual, arguments = [arg + (self.doc_to_visual(doc),) for arg in arguments]
doc, else:
self.config.task, arguments = arguments + (self.doc_to_visual(doc),)
)
elif self.INPUT_TYPE == "text": if isinstance(arguments, type):
arguments = (ctx, deepcopy(self.config.generation_kwargs)) if aux_arguments is not None:
all_arg_list = [arguments, arg_list]
else:
all_arg_list = [arguments]
for arg_list in all_arg_list:
request_list = [
Instance(
request_type="loglikelihood",
doc=doc,
arguments=arg,
idx=i,
**kwargs,
)
for i, arg in enumerate(arg_list)
]
return request_list
return Instance( return Instance(
request_type=self.OUTPUT_TYPE, request_type=self.OUTPUT_TYPE,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment