"examples/nas/oneshot/enas/utils.py" did not exist on "bb797e10e460c086a7de192dce2dae6681bbfcf0"
Commit 2b56339e authored by Baber's avatar Baber
Browse files

Merge branch 'main' into longcxt

parents 0b533339 703fbffd
# Generated by generate_tasks.py
include: mlqa_common_yaml
task: mlqa_vi_zh
dataset_name: mlqa.vi.zh
process_results: !function utils.process_results_vi

# Generated by generate_tasks.py
include: mlqa_common_yaml
task: mlqa_zh_ar
dataset_name: mlqa.zh.ar
process_results: !function utils.process_results_zh

# Generated by generate_tasks.py
include: mlqa_common_yaml
task: mlqa_zh_de
dataset_name: mlqa.zh.de
process_results: !function utils.process_results_zh

# Generated by generate_tasks.py
include: mlqa_common_yaml
task: mlqa_zh_en
dataset_name: mlqa.zh.en
process_results: !function utils.process_results_zh

# Generated by generate_tasks.py
include: mlqa_common_yaml
task: mlqa_zh_es
dataset_name: mlqa.zh.es
process_results: !function utils.process_results_zh

# Generated by generate_tasks.py
include: mlqa_common_yaml
task: mlqa_zh_hi
dataset_name: mlqa.zh.hi
process_results: !function utils.process_results_zh

# Generated by generate_tasks.py
include: mlqa_common_yaml
task: mlqa_zh_vi
dataset_name: mlqa.zh.vi
process_results: !function utils.process_results_zh

# Generated by generate_tasks.py
include: mlqa_common_yaml
task: mlqa_zh_zh
dataset_name: mlqa.zh.zh
process_results: !function utils.process_results_zh
"""
Code based on Official evaluation script for the MLQA dataset.
Repo: https://github.com/facebookresearch/MLQA/blob/main/mlqa_evaluation_v1.py
"""
import re
import string
import sys
import unicodedata
from collections import Counter
import datasets
PUNCT = {
    chr(i)
    for i in range(sys.maxunicode)
    if unicodedata.category(chr(i)).startswith("P")
}.union(string.punctuation)

WHITESPACE_LANGS = ["en", "es", "hi", "vi", "de", "ar"]
MIXED_SEGMENTATION_LANGS = ["zh"]


def whitespace_tokenize(text):
    return text.split()


def mixed_segmentation(text):
    segs_out = []
    temp_str = ""
    for char in text:
        if re.search(r"[\u4e00-\u9fa5]", char) or char in PUNCT:
            if temp_str != "":
                ss = whitespace_tokenize(temp_str)
                segs_out.extend(ss)
                temp_str = ""
            segs_out.append(char)
        else:
            temp_str += char
    if temp_str != "":
        ss = whitespace_tokenize(temp_str)
        segs_out.extend(ss)
    return segs_out


def normalize_answer(s, lang):
    """Lower text and remove punctuation, articles and extra whitespace."""

    def remove_articles(text, lang):
        if lang == "en":
            return re.sub(r"\b(a|an|the)\b", " ", text)
        elif lang == "es":
            return re.sub(r"\b(un|una|unos|unas|el|la|los|las)\b", " ", text)
        elif lang == "hi":
            return text  # Hindi does not have formal articles
        elif lang == "vi":
            return re.sub(r"\b(của|là|cái|chiếc|những)\b", " ", text)
        elif lang == "de":
            return re.sub(
                r"\b(ein|eine|einen|einem|eines|einer|der|die|das|den|dem|des)\b",
                " ",
                text,
            )
        elif lang == "ar":
            return re.sub(r"\sال^|ال", " ", text)
        elif lang == "zh":
            return text  # Chinese does not have formal articles
        else:
            raise Exception("Unknown Language {}".format(lang))

    def white_space_fix(text, lang):
        if lang in WHITESPACE_LANGS:
            tokens = whitespace_tokenize(text)
        elif lang in MIXED_SEGMENTATION_LANGS:
            tokens = mixed_segmentation(text)
        else:
            raise Exception("Unknown Language {}".format(lang))
        return " ".join([t for t in tokens if t.strip() != ""])

    def remove_punc(text):
        return "".join(ch for ch in text if ch not in PUNCT)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s)), lang), lang)


def f1_score(prediction, ground_truth, lang):
    prediction_tokens = normalize_answer(prediction, lang).split()
    ground_truth_tokens = normalize_answer(ground_truth, lang).split()
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1


def exact_match_score(prediction, ground_truth, lang):
    return normalize_answer(prediction, lang) == normalize_answer(ground_truth, lang)


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths, lang):
    scores_for_ground_truths = []
    for ground_truth in ground_truths:
        score = metric_fn(prediction, ground_truth, lang)
        scores_for_ground_truths.append(score)
    return max(scores_for_ground_truths)


def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:
    def _process_doc(doc):
        out_doc = {
            "context": doc["context"],
            "question": doc["question"],
            "answers": doc["answers"]["text"],
        }
        return out_doc

    return dataset.map(_process_doc)


# Base function
def process_results_lang(doc, results, lang):
    ground_truths = doc["answers"]
    prediction = results[0].strip()
    exact_match = metric_max_over_ground_truths(
        exact_match_score, prediction, ground_truths, lang
    )
    f1 = metric_max_over_ground_truths(f1_score, prediction, ground_truths, lang)
    return {"exact_match": exact_match, "f1": f1}


# Language Wrapper functions
def process_results_en(doc, results):
    return process_results_lang(doc, results, "en")


def process_results_es(doc, results):
    return process_results_lang(doc, results, "es")


def process_results_hi(doc, results):
    return process_results_lang(doc, results, "hi")


def process_results_vi(doc, results):
    return process_results_lang(doc, results, "vi")


def process_results_de(doc, results):
    return process_results_lang(doc, results, "de")


def process_results_ar(doc, results):
    return process_results_lang(doc, results, "ar")


def process_results_zh(doc, results):
    return process_results_lang(doc, results, "zh")
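
As a quick illustration of how these helpers combine, here is a small, hypothetical usage sketch; the document and prediction values below are made up:

# Hypothetical example: scoring one Chinese prediction against two references.
doc = {"answers": ["北京大学", "北京"]}
results = ["北京大学。"]

scores = process_results_zh(doc, results)
# normalize_answer strips the punctuation ("。") and mixed_segmentation splits
# Han characters into single-character tokens, so the prediction matches the
# first reference exactly:
# scores == {"exact_match": True, "f1": 1.0}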
@@ -4,7 +4,8 @@ from functools import reduce
 import numpy as np
 import transformers.data.metrics.squad_metrics as squad_metrics
-from datasets import Dataset, load_metric
+from datasets import Dataset
+from evaluate import load
 from transformers import AutoTokenizer
 
 from lm_eval.api.instance import Instance

@@ -48,7 +49,10 @@ def _download_metric():
     from huggingface_hub import hf_hub_download
 
     scrolls_metric_path = hf_hub_download(
-        repo_id="tau/scrolls", repo_type="dataset", filename="metrics/scrolls.py"
+        repo_id="tau/scrolls",
+        repo_type="dataset",
+        filename="metrics/scrolls.py",
+        revision="refs/pr/5",
     )
     updated_scrolls_metric_path = (
         os.path.dirname(scrolls_metric_path)

@@ -119,7 +123,7 @@ class _SCROLLSTask(ConfigurableTask):
     def __init__(self, config=None):
         super().__init__(config={"metadata": {"version": self.VERSION}})
         if self.DATASET_NAME is not None:
-            self.metric = load_metric(_download_metric(), config_name=self.DATASET_NAME)
+            self.metric = load(_download_metric(), config_name=self.DATASET_NAME)
 
     def has_training_docs(self):
         return True

@@ -253,11 +257,14 @@ class _SCROLLSMultipleChoiceTask(_SCROLLSTask):
         }
 
     def construct_requests(self, doc, ctx, **kwargs):
+        apply_chat_template = kwargs.pop("apply_chat_template", False)
         request_list = [
             Instance(
                 request_type="loglikelihood",
                 doc=doc,
-                arguments=(ctx, " {}".format(choice)),
+                arguments=(ctx, " {}".format(choice))
+                if not apply_chat_template
+                else (ctx, "{}".format(choice)),
                 idx=i,
                 **kwargs,
             )

@@ -285,6 +292,7 @@ class _SCROLLSSummaryTask(_SCROLLSTask):
         }
 
     def construct_requests(self, doc, ctx, **kwargs):
+        kwargs.pop("apply_chat_template", False)
         return Instance(
             request_type="generate_until",
             doc=doc,

@@ -327,19 +335,22 @@ class Qasper(_SCROLLSTask):
         return {"f1": (prediction, doc["outputs"])}
 
     def construct_requests(self, doc, ctx, **kwargs):
+        apply_chat_template = kwargs.pop("apply_chat_template", False)
         if doc["is_yes_no"]:
             return [
                 Instance(
                     request_type="loglikelihood",
                     doc=doc,
-                    arguments=(ctx, " yes"),
+                    arguments=(ctx, " yes")
+                    if not apply_chat_template
+                    else (ctx, "yes"),
                     idx=0,
                     **kwargs,
                 ),
                 Instance(
                     request_type="loglikelihood",
                     doc=doc,
-                    arguments=(ctx, " no"),
+                    arguments=(ctx, " no") if not apply_chat_template else (ctx, "no"),
                     idx=1,
                     **kwargs,
                 ),

@@ -406,6 +417,7 @@ class NarrativeQA(_SCROLLSTask):
         return {"f1": (results[0], doc["outputs"])}
 
     def construct_requests(self, doc, ctx, **kwargs):
+        kwargs.pop("apply_chat_template", False)
         return Instance(
             request_type="generate_until",
             doc=doc,
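
A note on the apply_chat_template handling introduced above: when the context has already been rendered through a chat template, the continuation begins a new assistant message, so the loglikelihood target no longer needs a leading space. A minimal sketch of the resulting request arguments; the function name and prompt strings here are illustrative, not from the repo:

def choice_arguments(ctx, choice, apply_chat_template):
    # Plain prompts need a separating space before the continuation;
    # chat-templated prompts do not.
    return (ctx, " {}".format(choice)) if not apply_chat_template else (ctx, "{}".format(choice))

choice_arguments("Question: ...\nAnswer:", "yes", apply_chat_template=False)
# -> ("Question: ...\nAnswer:", " yes")
choice_arguments("<rendered chat prompt>", "yes", apply_chat_template=True)
# -> ("<rendered chat prompt>", "yes")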
 # File generated by `create-yamls.py`
-include: _phrases_es_common.yaml
+include: _phrases_es_common
 task: phrases_es-va
 doc_to_text: 'Oració en espanyol: {{es}}

 # File generated by `create-yamls.py`
-include: _phrases_es_common.yaml
+include: _phrases_es_common
 task: phrases_va-es
 doc_to_text: 'Oració en valencià: {{va}}
@@ -104,7 +104,8 @@ def simple_parse_args_string(args_string):
         return {}
     arg_list = [arg for arg in args_string.split(",") if arg]
     args_dict = {
-        k: handle_arg_string(v) for k, v in [arg.split("=") for arg in arg_list]
+        kv[0]: handle_arg_string("=".join(kv[1:]))
+        for kv in [arg.split("=") for arg in arg_list]
     }
     return args_dict
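
The motivation for this change is that an argument value can itself contain "=": the old unpacking of arg.split("=") into (k, v) raises a ValueError as soon as a value contains the character. A standalone sketch of the new behaviour, with a made-up input string and handle_arg_string replaced by an identity for brevity:

def parse(args_string):
    arg_list = [arg for arg in args_string.split(",") if arg]
    return {
        kv[0]: "=".join(kv[1:])  # re-join anything after the first "="
        for kv in [arg.split("=") for arg in arg_list]
    }

parse("pretrained=gpt2,gen_kwargs=until==END=")
# -> {"pretrained": "gpt2", "gen_kwargs": "until==END="}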
@@ -109,13 +109,14 @@ def main():
             if model_index == 0:  # Only need to assemble data for the first model
                 metrics = []
                 for metric in config["metric_list"]:
-                    metrics.append(
-                        ZenoMetric(
-                            name=metric["metric"],
-                            type="mean",
-                            columns=[metric["metric"]],
-                        )
-                    )
+                    if metric.get("aggregation") == "mean":
+                        metrics.append(
+                            ZenoMetric(
+                                name=metric["metric"],
+                                type="mean",
+                                columns=[metric["metric"]],
+                            )
+                        )
                 project = client.create_project(
                     name=args.project_name + (f"_{task}" if len(tasks) > 1 else ""),
                     view="text-classification",

@@ -168,7 +169,11 @@ def generate_dataset(
     Returns:
         pd.Dataframe: A dataframe that is ready to be uploaded to Zeno.
     """
-    ids = [x["doc_id"] for x in data]
+    ids = (
+        [x["doc_id"] for x in data]
+        if not config.get("filter_list")
+        else [f"{x['doc_id']}.{x['filter']}" for x in data]
+    )
     labels = [x["target"] for x in data]
     instance = [""] * len(ids)

@@ -190,6 +195,7 @@ def generate_dataset(
     return pd.DataFrame(
         {
             "id": ids,
+            "doc_id": [x["doc_id"] for x in data],
             "data": instance,
             "input_len": [len(x) for x in instance],
             "labels": labels,

@@ -208,8 +214,15 @@ def generate_system_df(data, config):
     Returns:
         pd.Dataframe: A dataframe that is ready to be uploaded to Zeno as a system.
     """
-    ids = [x["doc_id"] for x in data]
+    ids = (
+        [x["doc_id"] for x in data]
+        if not config.get("filter_list")
+        else [f"{x['doc_id']}.{x['filter']}" for x in data]
+    )
     system_dict = {"id": ids}
+    system_dict["doc_id"] = [x["doc_id"] for x in data]
+    if config.get("filter_list"):
+        system_dict["filter"] = [x["filter"] for x in data]
     system_dict["output"] = [""] * len(ids)
 
     if config["output_type"] == "loglikelihood":

@@ -228,11 +241,10 @@ def generate_system_df(data, config):
         system_dict["output"] = [str(x["filtered_resps"][0]) for x in data]
         system_dict["output_length"] = [len(str(x["filtered_resps"][0])) for x in data]
 
-    metrics = {}
-    for metric in config["metric_list"]:
-        if "aggregation" in metric and metric["aggregation"] == "mean":
-            metrics[metric["metric"]] = [x[metric["metric"]] for x in data]
+    metrics = {
+        metric["metric"]: [x[metric["metric"]] for x in data]
+        for metric in config["metric_list"]
+    }
     system_dict.update(metrics)
     system_df = pd.DataFrame(system_dict)
     return system_df
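
The "{doc_id}.{filter}" id scheme above exists because a task configured with a filter_list yields one row per (document, filter) pair, so bare document ids would collide in Zeno. A small sketch of the id construction with made-up data:

data = [
    {"doc_id": 0, "filter": "strict-match"},
    {"doc_id": 0, "filter": "flexible-extract"},
]
config = {"filter_list": [{"name": "strict-match"}, {"name": "flexible-extract"}]}

ids = (
    [x["doc_id"] for x in data]
    if not config.get("filter_list")
    else [f"{x['doc_id']}.{x['filter']}" for x in data]
)
# ids == ["0.strict-match", "0.flexible-extract"]; the separate "doc_id" column
# added to the dataframe keeps the raw document id for joining across systems.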