Commit db37024b authored by lintangsutawika's avatar lintangsutawika
Browse files

fix doc_to_target processing

parent 8eab2a58
...@@ -465,8 +465,11 @@ class Task(abc.ABC): ...@@ -465,8 +465,11 @@ class Task(abc.ABC):
elif type(example) == list: elif type(example) == list:
return [labeled_examples + ex for ex in example] return [labeled_examples + ex for ex in example]
elif type(example) == int: elif type(example) == int:
if self._config.doc_to_choice is not None:
choices = self.doc_to_choice(doc) choices = self.doc_to_choice(doc)
return labeled_examples + choices[example] return labeled_examples + choices[example]
else:
return labeled_examples + str(example)
def apply_filters(self): def apply_filters(self):
...@@ -790,7 +793,7 @@ class ConfigurableTask(Task): ...@@ -790,7 +793,7 @@ class ConfigurableTask(Task):
target_string = utils.apply_template(doc_to_target, doc) target_string = utils.apply_template(doc_to_target, doc)
if target_string.isdigit(): if target_string.isdigit():
return ast.literal_eval(target_string) return ast.literal_eval(target_string)
elif (target_string[0] == "[") and (target_string[-1] == "]"): elif len(target_string) >= 2 and (target_string[0] == "[") and (target_string[-1] == "]"):
return ast.literal_eval(target_string) return ast.literal_eval(target_string)
else: else:
return target_string return target_string
...@@ -1002,9 +1005,13 @@ class ConfigurableTask(Task): ...@@ -1002,9 +1005,13 @@ class ConfigurableTask(Task):
elif self.OUTPUT_TYPE == "greedy_until": elif self.OUTPUT_TYPE == "greedy_until":
gold = self.doc_to_target(doc) gold = self.doc_to_target(doc)
if type(gold) == int: if self._config.doc_to_choice is not None:
# If you set doc_to_choice,
# it assumes that doc_to_target returns a number.
choices = self.doc_to_choice(doc) choices = self.doc_to_choice(doc)
gold = choices[gold] gold = choices[gold]
else:
gold = str(gold)
for key, result in zip(self._metric_fn_list.keys(), results): for key, result in zip(self._metric_fn_list.keys(), results):
if self.multiple_target: if self.multiple_target:
......
...@@ -2,7 +2,7 @@ task: realtoxicityprompts ...@@ -2,7 +2,7 @@ task: realtoxicityprompts
dataset_path: "allenai/real-toxicity-prompts" dataset_path: "allenai/real-toxicity-prompts"
training_split: 'train' training_split: 'train'
test_split: 'train' test_split: 'train'
doc_to_text: "{{' '+prompt.text}}" doc_to_text: "{{prompt.text}}"
doc_to_target: "" doc_to_target: ""
metric_list: metric_list:
- metric: !function metric.toxicity_perspective_api - metric: !function metric.toxicity_perspective_api
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment