Commit 8a8c2982 authored by lintangsutawika's avatar lintangsutawika
Browse files

rework doc_to_visual

parent 7d7a3a1c
...@@ -1278,14 +1278,21 @@ class ConfigurableTask(Task): ...@@ -1278,14 +1278,21 @@ class ConfigurableTask(Task):
else: else:
raise TypeError raise TypeError
def doc_to_visual(self, doc: dict) -> Union[int, str, list]: def doc_to_visual(self, doc: Any) -> Union[int, str, list]:
if self.config.doc_to_visual is None:
eval_logger.error("doc_to_visual was called but not set in config")
else:
doc_to_visual = self.config.doc_to_visual
if isinstance(self.config.doc_to_visual, str): if isinstance(self.config.doc_to_visual, str):
assert self.config.doc_to_visual in self.features if doc_to_visual in self.features:
# Single Image. Still return a list for consistency return doc[doc_to_visual]
return doc[self.config.doc_to_visual]
else: else:
assert callable(self.config.doc_to_visual) return ast.literal_eval(utils.apply_template(doc_to_visual, doc))
return self.config.doc_to_visual(doc) elif callable(doc_to_visual):
return doc_to_visual(doc)
else:
return None
def construct_requests( def construct_requests(
self, doc: dict, ctx: str, **kwargs self, doc: dict, ctx: str, **kwargs
...@@ -1307,16 +1314,6 @@ class ConfigurableTask(Task): ...@@ -1307,16 +1314,6 @@ class ConfigurableTask(Task):
# Otherwise they are placed in the continuation # Otherwise they are placed in the continuation
arguments = [(ctx, f"{target_delimiter}{cont}") for cont in choices] arguments = [(ctx, f"{target_delimiter}{cont}") for cont in choices]
request_list = [
Instance(
request_type="loglikelihood",
doc=doc,
arguments=arg,
idx=i,
**kwargs,
)
for i, arg in enumerate(arguments)
]
# TODO: we should raise a warning telling users this will at most ~2x runtime. # TODO: we should raise a warning telling users this will at most ~2x runtime.
if "acc_mutual_info" in self._metric_fn_list.keys(): if "acc_mutual_info" in self._metric_fn_list.keys():
# if we are calculating multiple choice accuracy # if we are calculating multiple choice accuracy
...@@ -1325,31 +1322,37 @@ class ConfigurableTask(Task): ...@@ -1325,31 +1322,37 @@ class ConfigurableTask(Task):
# here mutual info refers to calculating # here mutual info refers to calculating
# log(P(choice|ctx) / P(choice)) = log(P(choice|ctx)) - log(P(choice)) # log(P(choice|ctx) / P(choice)) = log(P(choice|ctx)) - log(P(choice))
# in other words normalizing by subtracting the unconditional logprob of each choice. # in other words normalizing by subtracting the unconditional logprob of each choice.
request_list.extend( aux_arguments = [("", f"{choice}") for choice in choices]
[ else:
aux_arguments = None
elif self.OUTPUT_TYPE == "generate_until":
arguments = (ctx, deepcopy(self.config.generation_kwargs))
if self.doc_to_visual:
if isinstance(arguments, list):
arguments = [arg + (self.doc_to_visual(doc),) for arg in arguments]
else:
arguments = arguments + (self.doc_to_visual(doc),)
if isinstance(arguments, type):
if aux_arguments is not None:
all_arg_list = [arguments, arg_list]
else:
all_arg_list = [arguments]
for arg_list in all_arg_list:
request_list = [
Instance( Instance(
request_type="loglikelihood", request_type="loglikelihood",
doc=doc, doc=doc,
arguments=("", "{}".format(choice)), arguments=arg,
idx=i, idx=i,
**kwargs, **kwargs,
) )
for i, choice in enumerate(choices) for i, arg in enumerate(arg_list)
] ]
)
return request_list
elif self.OUTPUT_TYPE == "generate_until": return request_list
if self.INPUT_TYPE == "text_image":
arguments = (
ctx,
deepcopy(self.config.generation_kwargs),
self.doc_to_visual,
doc,
self.config.task,
)
elif self.INPUT_TYPE == "text":
arguments = (ctx, deepcopy(self.config.generation_kwargs))
return Instance( return Instance(
request_type=self.OUTPUT_TYPE, request_type=self.OUTPUT_TYPE,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment