Commit b7449cd9 authored by lintangsutawika

process focuses on doc_to_text, doc_to_target, and doc_to_choice

parent 5978dc4e
@@ -10,6 +10,10 @@ class Sampler:
         self.target_delimiter = self.config.target_delimiter
         self.fewshot_delimiter = self.config.fewshot_delimiter

+        self.doc_to_text = self.task.doc_to_text
+        self.doc_to_target = self.task.doc_to_target
+        self.doc_to_choice = self.task.doc_to_choice
+
         self.docs = docs  # HF dataset split, provided by task._fewshot_docs()
         if fewshot_indices:  # subset few-shot docs from
             self.docs = self.docs.select(fewshot_indices)
@@ -34,16 +38,19 @@ class Sampler:
             self.fewshot_delimiter.join(
                 [
                     # TODO: is separating doc_to_text and doc_to_target by one space always desired?
-                    self.task.doc_to_text(doc)
+                    self.doc_to_text(doc)
                     + self.target_delimiter
-                    + self.task.doc_to_target(doc)
+                    + (
+                        self.doc_to_target(doc)
+                        if self.config.doc_to_choice is None
+                        else self.doc_to_choice(doc)[self.doc_to_target(doc)]
+                    )
                     for doc in selected_docs
                 ]
             )
             + self.fewshot_delimiter
         )

-        # only returns the fewshot context! Does not append the document, do this outside the object
         return labeled_examples

     def sample(self, n):
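Note: with `doc_to_target` now returning a choice index whenever `doc_to_choice` is set, the sampler resolves the index to its string before joining. A minimal standalone sketch of the new few-shot assembly, with made-up docs, lambdas, and delimiters (not from this commit):

```python
# Made-up docs and accessors, purely to illustrate the assembly above.
docs = [
    {"question": "2 + 2 = ?", "choices": ["3", "4"], "label": 1},
    {"question": "1 + 1 = ?", "choices": ["1", "2"], "label": 1},
]
doc_to_text = lambda doc: doc["question"]
doc_to_target = lambda doc: doc["label"]    # returns an index...
doc_to_choice = lambda doc: doc["choices"]  # ...into this list

target_delimiter, fewshot_delimiter = " ", "\n\n"
labeled_examples = (
    fewshot_delimiter.join(
        doc_to_text(doc) + target_delimiter + doc_to_choice(doc)[doc_to_target(doc)]
        for doc in docs
    )
    + fewshot_delimiter
)
print(repr(labeled_examples))
# '2 + 2 = ? 4\n\n1 + 1 = ? 2\n\n'
```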
@@ -646,8 +646,10 @@ class ConfigurableTask(Task):
             ), f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!"

             # Test One Doc
+            self.features = list(docs.features.keys())
+            self.multiple_input = 0
+            self.multiple_target = 0
             test_doc = docs[0]
-            self.features = list(test_doc.keys())
             test_text = self.doc_to_text(test_doc)
             test_target = self.doc_to_target(test_doc)
@@ -655,17 +657,17 @@ class ConfigurableTask(Task):
             test_choice = self.doc_to_choice(test_doc)
             if type(test_choice) is not list:
                 eval_logger.error("doc_to_choice must return list")
+            else:
+                num_choice = len(test_choice)

-            if self._config.output_type == "multiple_choice":
-                if type(test_text) is list:
-                    self.multiple_input = len(test_text)
-                elif type(test_text) is str:
-                    self.multiple_input = 0
+            if type(test_text) is int:
+                self.multiple_input = num_choice

-                if type(test_target) is list:
-                    self.multiple_output = len(test_target)
-                else:
-                    self.multiple_output = 0
+            if type(test_target) is list:
+                self.multiple_target = len(test_target)
+
+            eval_logger.info(f" Input choices: {self.multiple_input}")
+            eval_logger.info(f"Output choices: {self.multiple_target}")

     def download(self, dataset_kwargs=None):
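For context on the flags introduced above: `doc_to_text` returning an int now marks a multiple-input task (each choice becomes its own context), while `doc_to_target` returning a list marks multiple acceptable targets. A standalone sketch of the detection rules, with invented test values:

```python
# Invented test values; mirrors the detection rules above in isolation.
def detect(test_text, test_target, test_choice):
    multiple_input = 0
    multiple_target = 0
    num_choice = len(test_choice) if isinstance(test_choice, list) else 0
    if isinstance(test_text, int):     # text is an index -> one context per choice
        multiple_input = num_choice
    if isinstance(test_target, list):  # several golds are acceptable
        multiple_target = len(test_target)
    return multiple_input, multiple_target

# COPA-style doc: the "text" indexes into two candidate contexts
print(detect(0, 1, ["premise A", "premise B"]))  # (2, 0)
# Ordinary multiple choice with two acceptable golds
print(detect("Q?", [0, 2], ["a", "b", "c"]))     # (0, 2)
```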
@@ -749,6 +751,9 @@ class ConfigurableTask(Task):
         if type(doc_to_text) == str:
             if doc_to_text in self.features:
+                # if self._config.doc_to_choice is not None:
+                #     return self.doc_to_choice(doc)[doc[doc_to_text]]
+                # else:
                 return doc[doc_to_text]
             else:
                 return utils.apply_template(doc_to_text, doc)
@@ -770,10 +775,10 @@ class ConfigurableTask(Task):
         if type(doc_to_target) == str:
             if doc_to_target in self.features:
-                if self._config.doc_to_choice is not None:
-                    return self.doc_to_choice(doc)[doc[doc_to_target]]
-                else:
-                    return doc[doc_to_target]
+                # if self._config.doc_to_choice is not None:
+                #     return self.doc_to_choice(doc)[doc[doc_to_target]]
+                # else:
+                return doc[doc_to_target]
             else:
                 return utils.apply_template(doc_to_target, doc)
         elif callable(doc_to_target):
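The effect of commenting out the indexing branch: when `doc_to_target` names a dataset column on a choice task, it now hands back the raw label rather than the resolved choice string. A hypothetical doc makes the difference concrete:

```python
# Hypothetical doc; shows the old (removed) vs. new return value.
doc = {"question": "2 + 2 = ?", "choices": ["3", "4"], "label": 1}

old_target = doc["choices"][doc["label"]]  # old behavior: resolved string, "4"
new_target = doc["label"]                  # new behavior: raw index, 1
print(old_target, new_target)
```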
@@ -811,14 +816,14 @@ class ConfigurableTask(Task):
         elif self.OUTPUT_TYPE == "loglikelihood_rolling":
             arguments = (self.doc_to_target(doc),)
         elif self.OUTPUT_TYPE == "multiple_choice":
-            # we pass the user-defined answer_choices var (in aliases) and translate the result to a Python list.
-            # TODO: any cleaner way to do this?
-            if self.multiple_input > 0:
-                choices = self.doc_to_text(doc)
+            choices = self.doc_to_choice(doc)
+
+            if self.multiple_input:
+                # If there are multiple inputs, choices are placed in the ctx
                 cont = self.doc_to_target(doc)
                 arguments = [(ctx, " {}".format(cont)) for ctx in choices]
             else:
-                choices = self.doc_to_choice(doc)
+                # Otherwise they are placed in the continuation
                 arguments = [(ctx, " {}".format(cont)) for cont in choices]

         request_list = [
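The two request layouts above, sketched with invented strings (the surrounding `Instance` construction is omitted):

```python
# Invented example values; only the (context, continuation) pairing is the point.
choices = ["it rained", "it snowed"]

# multiple_input: each choice is its own context, one shared continuation
cont = "yes"
args_multiple_input = [(ctx, " {}".format(cont)) for ctx in choices]
# [('it rained', ' yes'), ('it snowed', ' yes')]

# ordinary multiple choice: one shared context, each choice is a continuation
ctx = "The picnic was cancelled because"
args_multiple_choice = [(ctx, " {}".format(cont)) for cont in choices]
# [('The picnic was cancelled because', ' it rained'), ...]
print(args_multiple_input)
print(args_multiple_choice)
```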
@@ -920,10 +925,7 @@ class ConfigurableTask(Task):
             lls, is_greedy = zip(*results)

             # retrieve choices in List[str] form, to compute choice lengths, etc.
-            if self.multiple_input:
-                choices = [self.doc_to_target(doc)] * self.multiple_input
-            else:
-                choices = self.doc_to_choice(doc)
+            choices = self.doc_to_choice(doc)
             completion_len = np.array([float(len(i)) for i in choices])

             if (
@@ -937,33 +939,15 @@ class ConfigurableTask(Task):
                 # and this stores our "regular" conditional loglikelihoods
                 lls = lls[::2]

-            pred_idx = np.argmax(lls)
-            pred_idx_norm = np.argmax(lls / completion_len)
+            pred = np.argmax(lls)
+            pred_norm = np.argmax(lls / completion_len)

-            # Gives priority to evaluate base on gold_alias
-            if self._config.gold_alias is not None:
-                pred = pred_idx
-                pred_norm = pred_idx_norm
-                gold_idx = int(self.gold_alias(doc))
-                gold = gold_idx
+            if self.multiple_input:
+                gold = self.doc_to_text(doc)
             else:
-                pred = choices[pred_idx]
-                pred_norm = choices[pred_idx_norm]
                 gold = self.doc_to_target(doc)

-            if self.multiple_output > 0:
-                if type(gold[0]) == int:
-                    gold_idx = gold
-                    gold = [choices[idx] for idx in gold_idx]
-                elif type(gold[0]) == str:
-                    gold_idx = [choices.index(g) for g in gold]
-            else:
-                if type(gold) == int:
-                    gold_idx = gold
-                    gold = choices[gold_idx]
-                elif type(gold) == str:
-                    gold_idx = choices.index(gold)
-            if self.multiple_output:
+            if self.multiple_target:
                 acc = 1.0 if pred in gold else 0.0
                 acc_norm = 1.0 if pred_norm in gold else 0.0
             else:
@@ -972,8 +956,8 @@ class ConfigurableTask(Task):
             result_dict = {
                 **({"acc": acc} if "acc" in use_metric else {}),
-                **({"f1": (gold_idx, pred_idx)} if "f1" in use_metric else {}),
-                **({"mcc": (gold_idx, pred_idx)} if "mcc" in use_metric else {}),
+                **({"f1": (gold, pred)} if "f1" in use_metric else {}),
+                **({"mcc": (gold, pred)} if "mcc" in use_metric else {}),
                 **({"acc_norm": acc_norm} if "acc_norm" in use_metric else {}),
             }
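Prediction and gold now both live in index space (except for multiple-input tasks, where gold comes from `doc_to_text`), so `f1` and `mcc` receive index pairs directly. A self-contained sketch of the scoring path with made-up loglikelihoods:

```python
import numpy as np

# Made-up per-choice loglikelihoods and choices for a single doc.
lls = np.array([-4.2, -1.3, -3.7])
choices = ["London", "Paris", "Berlin"]
completion_len = np.array([float(len(c)) for c in choices])

pred = np.argmax(lls)                        # index of most likely choice -> 1
pred_norm = np.argmax(lls / completion_len)  # length-normalized variant -> 1
gold = 1                                     # index returned by doc_to_target

acc = 1.0 if pred == gold else 0.0
acc_norm = 1.0 if pred_norm == gold else 0.0
print(pred, pred_norm, acc, acc_norm)        # 1 1 1.0 1.0
```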
@@ -1013,7 +997,7 @@ class ConfigurableTask(Task):
         for key, result in zip(self._metric_fn_list.keys(), results):
             _dict = self._metric_fn_list[key](
-                references=gold if self.multiple_output else [gold],
+                references=gold if self.multiple_target else [gold],
                 predictions=[result],
                 **self._metric_fn_kwargs[key],
             )