Commit e56b950a authored by lintangsutawika

able to use prompts from promptsource

parent 1c7521a0
@@ -228,7 +228,7 @@ class Task(abc.ABC):
            return self.validation_docs()
        else:
            eval_logger.warning(
                "has_training_docs and has_validation_docs are False, "
                "using test_docs but this is not recommended."
            )
            return self.test_docs()
@@ -519,7 +519,19 @@ class ConfigurableTask(Task):
                    [["take_first", None]]
                )
            ]

        if self._config.use_prompt is not None:
            eval_logger.info(f"Loading prompt {self._config.use_prompt}")
            self.prompt = get_prompt(
                self._config.use_prompt,
                self.DATASET_PATH,
                self.DATASET_NAME,
            )
        else:
            self.prompt = None

        if self.fewshot_docs() is not None:
            self.sampler = samplers.Sampler(
                list(self.fewshot_docs()), self, rnd=random.Random()
            )  # TODO: pass the correct docs in here
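For context, a sketch of how the arguments line up: `DATASET_PATH` (e.g. `super_glue`) is what promptsource calls `dataset_name`, while `DATASET_NAME` (e.g. `wsc.fixed`) becomes `subset_name`. The import path below is an assumption about where this diff's `get_prompt` lives, not something the diff confirms:

    # Sketch only; the module path is assumed.
    from lm_eval.prompts import get_prompt

    prompt = get_prompt(
        "promptsource:does the pronoun refer to",  # use_prompt from the YAML below
        "super_glue",  # DATASET_PATH -> promptsource dataset_name
        "wsc.fixed",   # DATASET_NAME -> promptsource subset_name
    )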
@@ -583,42 +595,35 @@ class ConfigurableTask(Task):
        return doc

    def doc_to_text(self, doc):
        if self.prompt is not None:
            doc_to_text = self.prompt
        else:
            doc_to_text = self._config.doc_to_text

        if type(doc_to_text) == str:
            return utils.apply_template(doc_to_text, doc)
        elif callable(doc_to_text):
            return doc_to_text(doc)
        elif hasattr(doc_to_text, "apply"):
            # promptsource Template: apply() renders [input, target]; take the input.
            return doc_to_text.apply(doc)[0]
        else:
            raise TypeError(f"doc_to_text has unexpected type {type(doc_to_text)}")
    def doc_to_target(self, doc):
        if self.prompt is not None:
            doc_to_target = self.prompt
        else:
            doc_to_target = self._config.doc_to_target

        if type(doc_to_target) == str:
            return utils.apply_template(doc_to_target, doc)
        elif callable(doc_to_target):
            return doc_to_target(doc)
        elif hasattr(doc_to_target, "apply"):
            # promptsource Template: apply() renders [input, target]; take the target.
            return doc_to_target.apply(doc)[1]
        else:
            raise TypeError(f"doc_to_target has unexpected type {type(doc_to_target)}")
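The `[0]`/`[1]` indexing follows promptsource's `Template.apply`, which renders a document into an input/target pair. A minimal sketch; the row dict is a stand-in in the `wsc.fixed` schema, not a real record:

    from promptsource.templates import DatasetTemplates

    # apply(doc) returns [rendered_input, rendered_target]: doc_to_text takes
    # index 0, doc_to_target takes index 1.
    template = DatasetTemplates("super_glue", "wsc.fixed")["does the pronoun refer to"]
    doc = {  # stand-in row, not a real dataset record
        "text": "Mark gave Pete the book because he had finished it.",
        "span1_index": 0, "span1_text": "Mark",
        "span2_index": 6, "span2_text": "he",
        "label": 1, "idx": 0,
    }
    input_text, target_text = template.apply(doc)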
......
from lm_eval.logger import eval_logger
from promptsource.templates import DatasetTemplates
# TODO: decide whether we want jinja2 or f-string prompts. would it be cursed to support both?
# Prompt library.
# Stores prompts in a dictionary indexed by 2 levels:
@@ -10,20 +13,35 @@ PROMPT_REGISTRY = {
},
}
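The two-level indexing the comment describes, sketched with a hypothetical entry (the real registry body is collapsed out of this diff):

    # category -> prompt name -> template; whether templates are jinja2 or
    # f-strings is still the open TODO above, so the value is illustrative.
    PROMPT_REGISTRY = {
        "qa-basic": {
            "question-newline-answer": "Question: {{question}}\nAnswer:",
        },
    }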
def get_prompt(prompt_id: str, dataset_name=None, subset_name=None):
    # Unpack the prompt id: "<category>:<prompt name>".
    category_name, prompt_name = prompt_id.split(":")
    eval_logger.info(f"Loading prompt from {category_name}")
    if category_name == "promptsource":
        try:
            if subset_name is None:
                prompts = DatasetTemplates(dataset_name=dataset_name)
            else:
                prompts = DatasetTemplates(
                    dataset_name=dataset_name, subset_name=subset_name
                )
        except Exception:
            raise ValueError(f"{dataset_name} and {subset_name} not found")

        if prompt_name in prompts.all_template_names:
            return prompts[prompt_name]
        else:
            raise ValueError(
                f"{prompt_name} not in prompt list {prompts.all_template_names}"
            )
    else:
        try:
            return PROMPT_REGISTRY[category_name][prompt_name]
        except KeyError:
            raise ValueError(
                f"expected only a single `:` as separator between "
                f"prompt category and name, but got `{prompt_id}` instead"
            )
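Both lookup branches in one sketch, reusing the hypothetical `qa-basic` entry from above:

    # promptsource category: resolved through DatasetTemplates.
    get_prompt("promptsource:does the pronoun refer to", "super_glue", "wsc.fixed")

    # any other category: resolved through PROMPT_REGISTRY.
    get_prompt("qa-basic:question-newline-answer")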
......
group:
- t0-eval
task: "does the pronoun refer to"
dataset_path: super_glue
dataset_name: wsc.fixed
training_split: train
validation_split: validation
use_prompt: "promptsource:does the pronoun refer to"
metric_list:
- metric: exact_match
aggregation: mean
higher_is_better: true
ignore_case: true
ignore_punctuation: true
......
group:
- t0-eval
task: "by p they mean"
dataset_path: super_glue
dataset_name: wsc.fixed
training_split: train
validation_split: validation
use_prompt: "promptsource:by p they mean"
metric_list:
- metric: exact_match
aggregation: mean
higher_is_better: true
ignore_case: true
ignore_punctuation: true
......
group:
- t0-eval
task: "in other words"
dataset_path: super_glue
dataset_name: wsc.fixed
training_split: train
validation_split: validation
use_prompt: "promptsource:in other words"
metric_list:
- metric: exact_match
aggregation: mean
higher_is_better: true
ignore_case: true
ignore_punctuation: true
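As a quick sanity check of one of these configs, an end-to-end sketch (assumes `datasets` and `promptsource` are installed and the Hub is reachable):

    from datasets import load_dataset
    from promptsource.templates import DatasetTemplates

    # Mirrors the "in other words" file: dataset_path=super_glue, dataset_name=wsc.fixed.
    doc = load_dataset("super_glue", "wsc.fixed", split="validation")[0]
    template = DatasetTemplates("super_glue", "wsc.fixed")["in other words"]
    input_text, target_text = template.apply(doc)
    print(input_text)
    print(target_text)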