Commit e4db76cb authored by haileyschoelkopf's avatar haileyschoelkopf
Browse files

Merge branch 'main' into multimodal-prototyping

parents 6cc6e9cd ad80f555
......@@ -5,8 +5,7 @@ doc_to_target: ' {{translation["fr"]}}'
doc_to_text: 'English phrase: {{translation["en"]}}
French phrase:'
group:
- generate_until
tag:
- translation
- wmt14
- gpt3_translation_benchmarks
......
......@@ -5,8 +5,7 @@ doc_to_target: ' {{translation["en"]}}'
doc_to_text: 'French phrase: {{translation["fr"]}}
English phrase:'
group:
- generate_until
tag:
- translation
- wmt14
- gpt3_translation_benchmarks
......
......@@ -5,8 +5,7 @@ doc_to_target: ' {{translation["en"]}}'
doc_to_text: 'German phrase: {{translation["de"]}}
English phrase:'
group:
- generate_until
tag:
- translation
- wmt16
- gpt3_translation_benchmarks
......
......@@ -5,8 +5,7 @@ doc_to_target: ' {{translation["de"]}}'
doc_to_text: 'English phrase: {{translation["en"]}}
German phrase:'
group:
- generate_until
tag:
- translation
- wmt16
- gpt3_translation_benchmarks
......
......@@ -5,8 +5,7 @@ doc_to_target: ' {{translation["ro"]}}'
doc_to_text: 'English phrase: {{translation["en"]}}
Romanian phrase:'
group:
- generate_until
tag:
- translation
- wmt16
- gpt3_translation_benchmarks
......
......@@ -5,8 +5,7 @@ doc_to_target: ' {{translation["en"]}}'
doc_to_text: 'Romanian phrase: {{translation["ro"]}}
English phrase:'
group:
- generate_until
tag:
- translation
- wmt16
- gpt3_translation_benchmarks
......
group:
tag:
- truthfulqa
task: truthfulqa_gen
dataset_path: truthful_qa
......
group:
tag:
- truthfulqa
task: truthfulqa_mc1
dataset_path: truthful_qa
......
include: unitxt_tasks.classification.multi_class
task: 20_newsgroups
dataset_name: card=cards.20_newsgroups,template=templates.classification.multi_class.title
include: unitxt
recipe: card=cards.20_newsgroups,template=templates.classification.multi_class.title
include: unitxt_tasks.classification.multi_class
task: ag_news
dataset_name: card=cards.ag_news,template=templates.classification.multi_class.title
include: unitxt
recipe: card=cards.ag_news,template=templates.classification.multi_class.title
include: unitxt_tasks.classification.multi_class
task: argument_topic
dataset_name: card=cards.argument_topic,template=templates.classification.multi_class.title
include: unitxt
recipe: card=cards.argument_topic,template=templates.classification.multi_class.title
include: unitxt_tasks.span_labeling.extraction
task: atis
dataset_name: card=cards.atis,template=templates.span_labeling.extraction.title
include: unitxt
recipe: card=cards.atis,template=templates.span_labeling.extraction.title
include: unitxt_tasks.classification.multi_class
task: banking77
dataset_name: card=cards.banking77,template=templates.classification.multi_class.title
include: unitxt
recipe: card=cards.banking77,template=templates.classification.multi_class.title
include: unitxt_tasks.classification.multi_class
task: claim_stance_topic
dataset_name: card=cards.claim_stance_topic,template=templates.classification.multi_class.title
include: unitxt
recipe: card=cards.claim_stance_topic,template=templates.classification.multi_class.title
include: unitxt_tasks.summarization.abstractive
task: cnn_dailymail
dataset_name: card=cards.cnn_dailymail,template=templates.summarization.abstractive.full
include: unitxt
recipe: card=cards.cnn_dailymail,template=templates.summarization.abstractive.full
include: unitxt_tasks.grammatical_error_correction
task: coedit_gec
dataset_name: card=cards.coedit_gec,template=templates.grammatical_error_correction.simple
include: unitxt
recipe: card=cards.coedit_gec,template=templates.grammatical_error_correction.simple
include: unitxt_tasks.classification.multi_class
task: dbpedia_14
dataset_name: card=cards.dbpedia_14,template=templates.classification.multi_class.title
include: unitxt
recipe: card=cards.dbpedia_14,template=templates.classification.multi_class.title
include: unitxt_tasks.classification.multi_class
task: ethos_binary
dataset_name: card=cards.ethos_binary,template=templates.classification.multi_class.title
include: unitxt
recipe: card=cards.ethos_binary,template=templates.classification.multi_class.title
include: unitxt_tasks.classification.multi_class
task: financial_tweets
dataset_name: card=cards.financial_tweets,template=templates.classification.multi_class.title
include: unitxt
recipe: card=cards.financial_tweets,template=templates.classification.multi_class.title
#
# This file generates a set of LM Eval Harness YAML files
# that load unitxt datasets (https://github.com/IBM/unitxt)
#
import unitxt_wrapper
import yaml
from unitxt.artifact import fetch_artifact
from unitxt.standard import StandardRecipe
# This code is required to properly dump LM harness YAML that contains references to functions
def function_representer(dumper: yaml.SafeDumper, func) -> yaml.nodes.ScalarNode:
    """Represent a Python function as a ``!function``-tagged YAML scalar.

    Args:
        dumper: The PyYAML dumper that is serializing the document.
        func: The function object to serialize. Only its ``__module__`` and
            ``__name__`` attributes are read.

    Returns:
        The node produced by ``dumper.represent_scalar`` for the tag
        ``!function`` and the value ``"<module>.<name>"`` of ``func``.
    """
    return dumper.represent_scalar(
        "!function", f"{func.__module__}.{func.__name__}", style=None
    )
def write_task_yaml(filename, data):
    """Write a task configuration dict to ``filename`` as YAML.

    Before dumping, ``function_representer`` is registered for the type of
    ``data["process_results"]`` so that entry can be serialized. The dump is
    done with ``sort_keys=False``.
    """
    callable_type = type(data["process_results"])
    yaml.add_representer(callable_type, function_representer)
    with open(filename, "w") as out_file:
        yaml.dump(data, out_file, sort_keys=False)
def write_card_yaml(filename, data):
    """Write a card configuration dict to ``filename`` as YAML, keys unsorted."""
    with open(filename, "w") as out_file:
        yaml.dump(data, stream=out_file, sort_keys=False)
# Maps each supported Unitxt task identifier to the template used when building
# the dataset recipe for a card of that task. main() generates one base YAML per
# key, and generate_card_yaml() rejects cards whose task is not listed here.
default_template_per_task = {
    "tasks.classification.multi_label": "templates.classification.multi_label.title",
    "tasks.classification.multi_class": "templates.classification.multi_class.title",
    "tasks.summarization.abstractive": "templates.summarization.abstractive.full",
    "tasks.regression.two_texts": "templates.regression.two_texts.simple",
    "tasks.qa.with_context.extractive": "templates.qa.with_context.simple",
    "tasks.grammatical_error_correction": "templates.grammatical_error_correction.simple",
    "tasks.span_labeling.extraction": "templates.span_labeling.extraction.title",
}
def generate_task_yaml(task: str):
    """Generate the base LM Eval Harness YAML file for a Unitxt task.

    The task definition is fetched with ``fetch_artifact`` and a common
    configuration is filled in with one ``metric_list`` entry per metric of
    that task. The result is written to a file named ``unitxt_<task>``.
    The 'dataset_name' and 'task' entries are left unspecified here; the
    per-card YAML files include this one and supply them.

    Args:
        task: Unitxt task identifier, e.g. ``tasks.classification.multi_class``.
    """
    print("*" * 80)
    print("*")
    print(f"* Generating YAML base file for task {task}")
    print("*")

    task_definition, _ = fetch_artifact(task)

    # One entry per task metric; the "metrics." prefix becomes "unitxt_".
    metric_list = [
        {
            "metric": metric_name.replace("metrics.", "unitxt_"),
            "aggregation": "unitxt",
            "higher_is_better": True,
        }
        for metric_name in task_definition.metrics
    ]

    data = {
        "group": ["unitxt"],
        "dataset_path": "unitxt/data",
        "output_type": "generate_until",
        "training_split": "train",
        "validation_split": "test",
        "doc_to_text": "{{source}}",
        "doc_to_target": "target",
        "process_results": unitxt_wrapper.process_results,
        "generation_kwargs": {"until": ["</s>"]},
        "metric_list": metric_list,
        # Key was previously misspelled "verison", so no "version" entry was
        # ever written into the task metadata.
        "metadata": {"version": 1.0},
    }
    write_task_yaml(f"unitxt_{task}", data)
def generate_card_yaml(card: str):
    """Generate the LM Eval Harness YAML file for one Unitxt dataset card.

    The generated ``<card>.yaml`` includes the base YAML of the card's task and
    sets 'task' and 'dataset_name' for this card. A small sample of the dataset
    is also loaded and printed so the prompt and target can be inspected.

    Args:
        card: Unitxt card name without the ``cards.`` prefix.

    Raises:
        ValueError: If the card's task has no entry in
            ``default_template_per_task``.
    """
    print("*" * 80)
    print("*")
    print(f"* Generating YAML file for unitxt dataset {card}")
    print("*")

    card_definition, _ = fetch_artifact(f"cards.{card}")
    task = card_definition.task.__id__
    if task not in default_template_per_task:
        raise ValueError(
            f"Default template was not defined for task {task} in 'default_template_per_task' dict in generate_yamls.py"
        )
    template = default_template_per_task[task]

    data = {
        "include": f"unitxt_{task}",
        "task": card,
        "dataset_name": f"card=cards.{card},template={template}",
    }

    # Building the recipe directly is faster than loading the same
    # configuration through load_dataset('unitxt/data', ...).
    recipe = StandardRecipe(card=f"cards.{card}", template=template, loader_limit=100)
    dataset = recipe().to_dataset()
    print(dataset)

    # Show the first test example's prompt and expected answer.
    for heading, field in (("Sample input:", "source"), ("Sample output:", "target")):
        print(heading)
        print(dataset["test"][0][field])

    write_card_yaml(f"{card}.yaml", data)
def main():
    """Generate all task base YAML files, then one YAML file per listed card.

    Card names are read one per line from the file ``unitxt_datasets`` in the
    current directory. Lines starting with ``#`` and blank lines are skipped,
    and a line starting with ``### END ###`` stops processing. Any failure is
    reported on stdout and then re-raised.
    """
    for task in default_template_per_task:
        try:
            generate_task_yaml(task)
        except Exception as e:
            print(f"Unable to generate YAML for {task} due to:")
            print(e)
            raise

    with open("unitxt_datasets") as f:
        for line in f:
            unitxt_dataset = line.strip()
            if unitxt_dataset.startswith("### END ###"):
                # Return instead of calling exit(): the file is closed by the
                # `with` block and an importing caller is not terminated.
                return
            # Blank lines would otherwise be treated as an empty card name.
            if not unitxt_dataset or unitxt_dataset.startswith("#"):
                continue
            try:
                generate_card_yaml(unitxt_dataset)
            except Exception as e:
                print(f"Unable to generate YAML for {unitxt_dataset} due to:")
                print(e)
                raise


if __name__ == "__main__":
    main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment