"docs/en/changelog.md" did not exist on "93dc5dfc2b98e293bf93c8f8e4db3a0f277a44aa"
Commit e56b950a authored by lintangsutawika's avatar lintangsutawika
Browse files

able to use prompts from promptsource

parent 1c7521a0
...@@ -228,7 +228,7 @@ class Task(abc.ABC): ...@@ -228,7 +228,7 @@ class Task(abc.ABC):
return self.validation_docs() return self.validation_docs()
else: else:
eval_logger.warning( eval_logger.warning(
"has_training_docs and has_validation_docs are False", "has_training_docs and has_validation_docs are False"
"using test_docs but this is not recommended." "using test_docs but this is not recommended."
) )
return self.test_docs() return self.test_docs()
...@@ -519,7 +519,19 @@ class ConfigurableTask(Task): ...@@ -519,7 +519,19 @@ class ConfigurableTask(Task):
[["take_first", None]] [["take_first", None]]
) )
] ]
if self._config.use_prompt is not None:
eval_logger.info(
f"loading prompt {self._config.use_prompt}"
)
self.prompt = get_prompt(
self._config.use_prompt,
self.DATASET_PATH,
self.DATASET_NAME
)
else:
self.prompt = None
if self.fewshot_docs() != None: if self.fewshot_docs() != None:
self.sampler = samplers.Sampler(list(self.fewshot_docs()), self, rnd=random.Random()) # TODO: pass the correct docs in here self.sampler = samplers.Sampler(list(self.fewshot_docs()), self, rnd=random.Random()) # TODO: pass the correct docs in here
...@@ -583,42 +595,35 @@ class ConfigurableTask(Task): ...@@ -583,42 +595,35 @@ class ConfigurableTask(Task):
return doc return doc
def doc_to_text(self, doc): def doc_to_text(self, doc):
if self._config.use_prompt is not None:
doc_to_text = get_prompt( if self.prompt is not None:
self._config.use_prompt, doc_to_text = self.prompt
self.DATASET_NAME,
self.DATASET_PATH
)
else: else:
doc_to_text = self._config.doc_to_text doc_to_text = self._config.doc_to_text
if type(doc_to_text) == str: if type(doc_to_text) == str:
return utils.apply_template(doc_to_text, doc) return utils.apply_template(doc_to_text, doc)
elif callable(doc_to_text): elif callable(doc_to_text):
if hasattr(doc_to_text, "apply"): return doc_to_text(doc)
return doc_to_text.apply(doc)[0] if hasattr(doc_to_text, "apply"):
else: return doc_to_text.apply(doc)[0]
return doc_to_text(doc)
else: else:
print(type(doc_to_text))
raise TypeError raise TypeError
def doc_to_target(self, doc): def doc_to_target(self, doc):
if self._config.use_prompt is not None:
doc_to_target = get_prompt( if self.prompt is not None:
self._config.use_prompt, doc_to_target = self.prompt
self.DATASET_NAME,
self.DATASET_PATH
)
else: else:
doc_to_target = self._config.doc_to_target doc_to_target = self._config.doc_to_target
if type(doc_to_target) == str: if type(doc_to_target) == str:
return utils.apply_template(doc_to_target, doc) return utils.apply_template(doc_to_target, doc)
elif callable(doc_to_target): elif callable(doc_to_target):
if hasattr(doc_to_target, "apply"): return doc_to_target(doc)
return doc_to_target.apply(doc)[1] elif hasattr(doc_to_target, "apply"):
else: return doc_to_target.apply(doc)[1]
return doc_to_target(doc)
else: else:
raise TypeError raise TypeError
......
from lm_eval.logger import eval_logger
from promptsource.templates import DatasetTemplates
# TODO: decide whether we want jinja2 or f-string prompts. would it be cursed to support both? # TODO: decide whether we want jinja2 or f-string prompts. would it be cursed to support both?
# Prompt library. # Prompt library.
# Stores prompts in a dictionary indexed by 2 levels: # Stores prompts in a dictionary indexed by 2 levels:
...@@ -10,20 +13,35 @@ PROMPT_REGISTRY = { ...@@ -10,20 +13,35 @@ PROMPT_REGISTRY = {
}, },
} }
def get_prompt(prompt_id: str, dataset_name=None, dataset_path=None): def get_prompt(prompt_id: str, dataset_name=None, subset_name=None):
# unpack prompt name # unpack prompt name
category_name, prompt_name = prompt_id.split(":") category_name, prompt_name = prompt_id.split(":")
if category_name == "promptsource": eval_logger.info(
from promptsource.templates import DatasetTemplates f"Loading prompt from {category_name}"
if prompt_name in prompts.all_template_names: )
prompts = DatasetTemplates(dataset_name, dataset_path) if category_name == "promptsource":
return prompts[prompt_name] try:
# prompts = DatasetTemplates(dataset_name, dataset_path)
if subset_name == None:
prompts = DatasetTemplates(dataset_name=dataset_name)
else:
prompts = DatasetTemplates(dataset_name=dataset_name, subset_name=subset_name)
except:
raise ValueError(
f"{dataset_name} and {subset_name} not found"
)
if prompt_name in prompts.all_template_names:
return prompts[prompt_name]
else: else:
try: raise ValueError(
return PROMPT_REGISTRY[category_name][prompt_name] f"{prompt_name} not in prompt list {prompts.all_template_names}"
except: )
raise ValueError( else:
f"expected only a single `:` as separator between \ try:
prompt category and name, but got `{prompt_id}` instead" return PROMPT_REGISTRY[category_name][prompt_name]
) except:
raise ValueError(
f"expected only a single `:` as separator between \
prompt category and name, but got `{prompt_id}` instead"
)
group:
- t0-eval
task: "does the pronoun refer to"
dataset_path: super_glue
dataset_name: wsc.fixed
training_split: train
validation_split: validation
use_prompt: "promptsource:does the pronoun refer to"
metric_list:
- metric: exact_match
aggregation: mean
higher_is_better: true
ignore_case: true
ignore_punctuation: true
group:
- t0-eval
task: "by p they mean"
dataset_path: super_glue
dataset_name: wsc.fixed
training_split: train
validation_split: validation
use_prompt: "promptsource:by p they mean"
metric_list:
- metric: exact_match
aggregation: mean
higher_is_better: true
ignore_case: true
ignore_punctuation: true
group:
- t0-eval
task: "in other words"
dataset_path: super_glue
dataset_name: wsc.fixed
training_split: train
validation_split: validation
use_prompt: "promptsource:in other words"
metric_list:
- metric: exact_match
aggregation: mean
higher_is_better: true
ignore_case: true
ignore_punctuation: true
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment