Commit 6399ef9e authored by lintangsutawika's avatar lintangsutawika
Browse files

handles prompt from promptsource

parent c41bb266
...@@ -584,23 +584,41 @@ class ConfigurableTask(Task): ...@@ -584,23 +584,41 @@ class ConfigurableTask(Task):
def doc_to_text(self, doc): def doc_to_text(self, doc):
if self._config.use_prompt is not None: if self._config.use_prompt is not None:
doc_to_text = get_prompt(self._config.use_prompt) doc_to_text = get_prompt(
self._config.use_prompt,
self.DATASET_NAME,
self.DATASET_PATH
)
else: else:
doc_to_text = self._config.doc_to_text doc_to_text = self._config.doc_to_text
if type(doc_to_text) == str: if type(doc_to_text) == str:
return utils.apply_template(doc_to_text, doc) return utils.apply_template(doc_to_text, doc)
elif callable(doc_to_text): elif callable(doc_to_text):
return doc_to_text(doc) if hasattr(doc_to_text, "apply"):
return doc_to_text.apply(doc)[0]
else:
return doc_to_text(doc)
else: else:
raise TypeError raise TypeError
def doc_to_target(self, doc): def doc_to_target(self, doc):
doc_to_target = self._config.doc_to_target if self._config.use_prompt is not None:
doc_to_target = get_prompt(
self._config.use_prompt,
self.DATASET_NAME,
self.DATASET_PATH
)
else:
doc_to_target = self._config.doc_to_target
if type(doc_to_target) == str: if type(doc_to_target) == str:
return utils.apply_template(doc_to_target, doc) return utils.apply_template(doc_to_target, doc)
elif callable(doc_to_target): elif callable(doc_to_target):
return doc_to_target(doc) if hasattr(doc_to_target, "apply"):
return doc_to_target.apply(doc)[1]
else:
return doc_to_target(doc)
else: else:
raise TypeError raise TypeError
......
...@@ -6,17 +6,24 @@ ...@@ -6,17 +6,24 @@
PROMPT_REGISTRY = { PROMPT_REGISTRY = {
"qa-basic": { "qa-basic": {
"question-newline-answer": "Question: {{question}}\nAnswer:", "question-newline-answer": "Question: {{question}}\nAnswer:",
"q-newline-a": "Q: {question}\nA:" "q-newline-a": "Q: {{question}}\nA:"
}, },
} }
def get_prompt(prompt_id: str): def get_prompt(prompt_id: str, dataset_name=None, dataset_path=None):
# unpack prompt name # unpack prompt name
try:
category_name, prompt_name = prompt_id.split(":") category_name, prompt_name = prompt_id.split(":")
except: if category_name == "promptsource":
raise ValueError( from promptsource.templates import DatasetTemplates
f"expected only a single `:` as separator between \ if prompt_name in prompts.all_template_names:
prompt category and name, but got `{prompt_id}` instead" prompts = DatasetTemplates(dataset_name, dataset_path)
) return prompts[prompt_name]
return PROMPT_REGISTRY[category_name][prompt_name] else:
\ No newline at end of file try:
return PROMPT_REGISTRY[category_name][prompt_name]
except:
raise ValueError(
f"expected only a single `:` as separator between \
prompt category and name, but got `{prompt_id}` instead"
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment