Commit 6399ef9e authored by lintangsutawika's avatar lintangsutawika
Browse files

handles prompt from promptsource

parent c41bb266
......@@ -584,23 +584,41 @@ class ConfigurableTask(Task):
def doc_to_text(self, doc):
if self._config.use_prompt is not None:
doc_to_text = get_prompt(self._config.use_prompt)
doc_to_text = get_prompt(
self._config.use_prompt,
self.DATASET_NAME,
self.DATASET_PATH
)
else:
doc_to_text = self._config.doc_to_text
if type(doc_to_text) == str:
return utils.apply_template(doc_to_text, doc)
elif callable(doc_to_text):
return doc_to_text(doc)
if hasattr(doc_to_text, "apply"):
return doc_to_text.apply(doc)[0]
else:
return doc_to_text(doc)
else:
raise TypeError
def doc_to_target(self, doc):
doc_to_target = self._config.doc_to_target
if self._config.use_prompt is not None:
doc_to_target = get_prompt(
self._config.use_prompt,
self.DATASET_NAME,
self.DATASET_PATH
)
else:
doc_to_target = self._config.doc_to_target
if type(doc_to_target) == str:
return utils.apply_template(doc_to_target, doc)
elif callable(doc_to_target):
return doc_to_target(doc)
if hasattr(doc_to_target, "apply"):
return doc_to_target.apply(doc)[1]
else:
return doc_to_target(doc)
else:
raise TypeError
......
......@@ -6,17 +6,24 @@
PROMPT_REGISTRY = {
"qa-basic": {
"question-newline-answer": "Question: {{question}}\nAnswer:",
"q-newline-a": "Q: {question}\nA:"
"q-newline-a": "Q: {{question}}\nA:"
},
}
def get_prompt(prompt_id: str):
def get_prompt(prompt_id: str, dataset_name=None, dataset_path=None):
# unpack prompt name
try:
category_name, prompt_name = prompt_id.split(":")
except:
raise ValueError(
f"expected only a single `:` as separator between \
prompt category and name, but got `{prompt_id}` instead"
)
return PROMPT_REGISTRY[category_name][prompt_name]
\ No newline at end of file
if category_name == "promptsource":
from promptsource.templates import DatasetTemplates
if prompt_name in prompts.all_template_names:
prompts = DatasetTemplates(dataset_name, dataset_path)
return prompts[prompt_name]
else:
try:
return PROMPT_REGISTRY[category_name][prompt_name]
except:
raise ValueError(
f"expected only a single `:` as separator between \
prompt category and name, but got `{prompt_id}` instead"
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment