Unverified Commit 6769119f authored by Hailey Schoelkopf's avatar Hailey Schoelkopf Committed by GitHub
Browse files

Merge pull request #816 from EleutherAI/flan-benchmark

[Refactor] Flan benchmark
parents 4824a832 7d5e511c
......@@ -4,7 +4,9 @@ import sklearn.metrics
def mean_3class_f1(predictions, references): # This is a passthrough function
string_label = ["entailment", "contradiction", "neutral"]
predictions = string_label.index(predictions[0])
predictions = (
string_label.index(predictions[0]) if predictions[0] in string_label else 0
)
references = string_label.index(references[0])
return (predictions, references)
......
......@@ -9,6 +9,9 @@ output_type: greedy_until
doc_to_text: "copa choice1: {{choice1}} choice2: {{choice2}} premise: {{premise}} question: {{question}}"
doc_to_target: label
doc_to_choice: ['choice1', 'choice2']
generation_kwargs:
until:
- "</s>"
metric_list:
- metric: exact_match
aggregation: mean
......
......@@ -12,8 +12,6 @@ doc_to_choice: "{% set group_id = idx.question|string %}{{[group_id+'_False', gr
generation_kwargs:
until:
- "</s>"
do_sample: false
temperature: 0.5
metric_list:
- metric: !function t5_utils.f1
aggregation: !function t5_utils.agg_f1
......
......@@ -8,6 +8,9 @@ output_type: greedy_until
process_docs: !function t5_utils.process_docs
doc_to_text: !function t5_utils.doc_to_text
doc_to_target: "{{idx.passage|string}}+{{idx.query}}_{{answers}}"
generation_kwargs:
until:
- "</s>"
metric_list:
- metric: !function t5_utils.em
aggregation: !function t5_utils.squad_em_agg
......
......@@ -9,6 +9,9 @@ output_type: greedy_until
doc_to_text: "rte hypothesis: {{hypothesis}} premise: {{premise}}"
doc_to_target: label
doc_to_choice: ['entailment', 'not_entailment']
generation_kwargs:
until:
- "</s>"
metric_list:
- metric: exact_match
aggregation: mean
......
......@@ -9,6 +9,9 @@ output_type: greedy_until
doc_to_text: "wic sentence1: {{sentence1}} sentence2: {{sentence2}} word: {{word}}"
doc_to_target: label
doc_to_choice: ['False', 'True']
generation_kwargs:
until:
- "</s>"
metric_list:
- metric: exact_match
aggregation: mean
......
......@@ -8,6 +8,9 @@ validation_split: validation
output_type: greedy_until
doc_to_text: !function "t5_utils.doc_to_text"
doc_to_target: label
generation_kwargs:
until:
- "</s>"
metric_list:
- metric: accuracy
aggregation: mean
......
......@@ -421,37 +421,45 @@ def import_function(loader, node):
yaml.add_constructor("!function", import_function)
def load_yaml_config(yaml_path):
with open(yaml_path, "rb") as file:
yaml_config = yaml.full_load(file)
def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None):
if yaml_config is None:
with open(yaml_path, "rb") as file:
yaml_config = yaml.full_load(file)
if yaml_dir is None:
yaml_dir = os.path.dirname(yaml_path)
if "include" in yaml_config:
include_path = yaml_config["include"]
del yaml_config["include"]
if type(include_path) == str:
include_path = [include_path]
# Load from the last one first
include_path.reverse()
final_yaml_config = {}
for path in include_path:
# Assumes that path is a full path.
# If not found, assume the included yaml
# is in the same dir as the original yaml
if not os.path.isfile(path):
path = os.path.normpath(os.path.join(yaml_dir, path))
try:
included_yaml_config = load_yaml_config(path)
final_yaml_config.update(included_yaml_config)
except Exception as ex:
# If failed to load, ignore
raise ex
final_yaml_config.update(yaml_config)
return final_yaml_config
return yaml_config
assert yaml_dir is not None
if "include" in yaml_config:
include_path = yaml_config["include"]
del yaml_config["include"]
if type(include_path) == str:
include_path = [include_path]
# Load from the last one first
include_path.reverse()
final_yaml_config = {}
for path in include_path:
# Assumes that path is a full path.
# If not found, assume the included yaml
# is in the same dir as the original yaml
if not os.path.isfile(path):
path = os.path.join(yaml_dir, path)
try:
included_yaml_config = load_yaml_config(path)
final_yaml_config.update(included_yaml_config)
except Exception as ex:
# If failed to load, ignore
raise ex
final_yaml_config.update(yaml_config)
return final_yaml_config
return yaml_config
def regex_replace(string, pattern, repl, count: int = 0):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment