Commit a0787a9f authored by baberabb
Browse files

Merge remote-tracking branch 'origin/big-refactor' into big-refactor_dp

parents 6359f083 dc5b3d5d
......@@ -50,7 +50,7 @@ Scoring details:
- **doc_to_decontamination_query** (`str`, *optional*) —
Other:
- **metadata** (`str`, *optional*) — An optional field where arbitrary metadata can be passed.
- **metadata** (`Union[str, list]`, *optional*) — An optional field where arbitrary metadata can be passed. For example, a `version` entry can be used to denote the version of the YAML config.
## Filters
......
......@@ -81,7 +81,7 @@ class TaskConfig(dict):
fewshot_delimiter: str = "\n\n"
fewshot_config: dict = None
# runtime configuration options
num_fewshot: int = 0
num_fewshot: int = -1
# scoring options
metric_list: list = None
output_type: str = "generate_until"
......@@ -91,7 +91,9 @@ class TaskConfig(dict):
should_decontaminate: bool = False
doc_to_decontamination_query: str = None
metadata: str = None # by default, not used in the code. allows for users to pass arbitrary info to tasks
metadata: Union[
str, list
] = None # by default, not used in the code. allows for users to pass arbitrary info to tasks
def __post_init__(self) -> None:
if self.dataset_path and ("." in self.dataset_path):
......
......@@ -134,13 +134,17 @@ def simple_evaluate(
config["generation_kwargs"].update(gen_kwargs)
if num_fewshot is not None:
if config["num_fewshot"] > 0:
if config["num_fewshot"] == 0:
eval_logger.info(
f"num_fewshot has been set to 0 for {task_name} in its config. Manual configuration will be ignored."
)
else:
default_num_fewshot = config["num_fewshot"]
eval_logger.warning(
f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {num_fewshot}"
)
task_obj._config["num_fewshot"] = num_fewshot
task_obj._config["num_fewshot"] = num_fewshot
if check_integrity:
run_task_tests(task_list=tasks)
......@@ -233,6 +237,8 @@ def evaluate(
# store the ordering of tasks and groups
task_order = collections.defaultdict(int)
task_group_alias = collections.defaultdict(dict)
# store num-fewshot value per task
num_fewshot = collections.defaultdict(int)
# get lists of each type of request
for task_name, task in task_dict.items():
......@@ -251,6 +257,12 @@ def evaluate(
versions[task_name] = task.VERSION
configs[task_name] = dict(task.dump_config())
if "num_fewshot" in configs[task_name]:
n_shot = configs[task_name]["num_fewshot"]
else:
n_shot = -1
num_fewshot[task_name] = n_shot
if "task_alias" in configs[task_name]:
task_group_alias[task_name] = configs[task_name]["task_alias"]
......@@ -612,11 +624,16 @@ def evaluate(
else:
groups_agg[group]["alias"] = tab_string + group
for group_name, task_list in task_hierarchy.items():
if task_list != []:
num_fewshot[group_name] = num_fewshot[task_list[0]]
results_dict = {
"results": dict(results_agg.items()),
**({"groups": dict(groups_agg.items())} if bool(groups_agg) else {}),
"configs": dict(sorted(configs.items())),
"versions": dict(sorted(versions.items())),
"n-shot": dict(sorted(num_fewshot.items())),
}
if log_samples:
results_dict["samples"] = dict(samples)
......
......@@ -158,12 +158,17 @@ class HFLM(LM):
trust_remote_code=trust_remote_code,
)
if getattr(self._config, "model_type") in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
elif (
not getattr(self._config, "model_type")
if (
getattr(self._config, "model_type")
in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
):
# first check if model type is listed under seq2seq models, since some
# models like MBart are listed in both seq2seq and causal mistakenly in HF transformers.
# these special cases should be treated as seq2seq models.
self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM
elif getattr(self._config, "model_type") in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
else:
if not trust_remote_code:
eval_logger.warning(
"HF model type is neither marked as CausalLM or Seq2SeqLM. \
......@@ -172,8 +177,6 @@ class HFLM(LM):
# if model type is neither in HF transformers causal or seq2seq model registries
# then we default to AutoModelForCausalLM
self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
else:
self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM
assert self.AUTO_MODEL_CLASS in [
transformers.AutoModelForCausalLM,
......
......@@ -22,3 +22,5 @@ metric_list:
- metric: acc
aggregation: mean
higher_is_better: true
metadata:
- version: 1.0
......@@ -19,3 +19,5 @@ metric_list:
- metric: acc_norm
aggregation: mean
higher_is_better: true
metadata:
- version: 1.0
......@@ -12,3 +12,5 @@ metric_list:
- metric: acc
aggregation: mean
higher_is_better: true
metadata:
- version: 1.0
......@@ -10,3 +10,5 @@ metric_list:
- metric: acc
aggregation: mean
higher_is_better: true
metadata:
- version: 1.0
......@@ -16,3 +16,5 @@ metric_list:
- metric: exact_match
aggregation: mean
higher_is_better: true
metadata:
- version: 0.0
......@@ -24,3 +24,5 @@ filter_list:
- function: "regex"
regex_pattern: "(?<=the answer is )(.*)(?=.)"
- function: "take_first"
metadata:
- version: 0.0
......@@ -22,3 +22,5 @@ filter_list:
- function: "regex"
regex_pattern: "((?<=The answer is )(.*)(?=.)|(?<=the answer is )(.*)(?=.)|(?<=The answer: )(.*)(?=.)|(?<=The final answer: )(.*)(?=.))"
- function: "take_first"
metadata:
- version: 0
......@@ -16,3 +16,5 @@ generation_kwargs:
- "\n\n"
do_sample: false
temperature: 0.0
metadata:
- version: 0
......@@ -16,3 +16,5 @@ generation_kwargs:
- "\n\n"
do_sample: false
temperature: 0.0
metadata:
- version: 0
......@@ -17,3 +17,5 @@ metric_list:
- metric: acc_norm
aggregation: mean
higher_is_better: true
metadata:
- version: 0.0
......@@ -14,3 +14,5 @@ metric_list:
aggregation: mean
higher_is_better: true
ignore_punctuation: true
metadata:
- version: 0.0
......@@ -11,3 +11,5 @@ doc_to_choice: "{{multiple_choice_targets}}"
metric_list:
- metric: acc
# TODO: brier score and other metrics
metadata:
- version: 0.0
......@@ -5,7 +5,10 @@ validation_split: train
doc_to_text: ""
doc_to_target: 0
doc_to_choice: "{{[sentence_good, sentence_bad]}}"
num_fewshot: 0
should_decontaminate: true
doc_to_decontamination_query: "{{sentence_good}} {{sentence_bad}}"
metric_list:
- metric: acc
metadata:
- version: 1.0
# Generated by utils.py
dataset_name: adjunct_island
include: template_yaml
include: _template_yaml
task: blimp_adjunct_island
# Generated by utils.py
dataset_name: anaphor_gender_agreement
include: template_yaml
include: _template_yaml
task: blimp_anaphor_gender_agreement
# Generated by utils.py
dataset_name: anaphor_number_agreement
include: template_yaml
include: _template_yaml
task: blimp_anaphor_number_agreement
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment