Commit 2b40017b authored by haileyschoelkopf's avatar haileyschoelkopf
Browse files

Merge branch 'main' into add-chat-templating

parents bbcdffb8 ff739414
...@@ -61,7 +61,7 @@ jobs: ...@@ -61,7 +61,7 @@ jobs:
# pip install bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt # pip install bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt
# if [ -f requirements.txt ]; then pip install -r requirements.txt; fi # if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Test with pytest - name: Test with pytest
run: python -m pytest --showlocals -s -vv -n=auto --ignore=tests/tests_master --ignore=tests/extra run: python -m pytest --showlocals -s -vv -n=auto
- name: Archive artifacts - name: Archive artifacts
uses: actions/upload-artifact@v3 uses: actions/upload-artifact@v3
with: with:
......
...@@ -271,7 +271,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None: ...@@ -271,7 +271,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
default=_handle_non_serializable, default=_handle_non_serializable,
ensure_ascii=False, ensure_ascii=False,
) )
filename.open("w").write(samples_dumped) filename.write_text(samples_dumped, encoding="utf-8")
print( print(
f"{args.model} ({args.model_args}), gen_kwargs: ({args.gen_kwargs}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, " f"{args.model} ({args.model_args}), gen_kwargs: ({args.gen_kwargs}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, "
......
...@@ -1131,6 +1131,15 @@ class ConfigurableTask(Task): ...@@ -1131,6 +1131,15 @@ class ConfigurableTask(Task):
# sometimes, a multiple_target dataset has exceptions where one doc has only one string answer # sometimes, a multiple_target dataset has exceptions where one doc has only one string answer
# print(gold) # print(gold)
gold = [gold] gold = [gold]
if metric == "exact_match":
result = [result for _ in range(len(gold))]
scores = self._metric_fn_list[metric](
references=gold,
predictions=result,
**self._metric_fn_kwargs[metric],
)[metric]
result_score = 1.0 if scores > 0.0 else 0.0
else:
for gold_option in gold: for gold_option in gold:
try: try:
result_score = self._metric_fn_list[metric]( result_score = self._metric_fn_list[metric](
......
...@@ -749,7 +749,7 @@ class HFLM(LM): ...@@ -749,7 +749,7 @@ class HFLM(LM):
generation_kwargs["do_sample"] = False generation_kwargs["do_sample"] = False
# build stopping criteria # build stopping criteria
stopping_criteria = stop_sequences_criteria( stopping_criteria = stop_sequences_criteria(
self.tokenizer, stop, 1, context.shape[0] self.tokenizer, stop, context.shape[1], context.shape[0]
) )
return self.model.generate( return self.model.generate(
input_ids=context, input_ids=context,
......
...@@ -17,4 +17,4 @@ metric_list: ...@@ -17,4 +17,4 @@ metric_list:
aggregation: mean aggregation: mean
higher_is_better: true higher_is_better: true
metadata: metadata:
version: 0.0 version: 1.0
...@@ -27,4 +27,4 @@ filter_list: ...@@ -27,4 +27,4 @@ filter_list:
- function: "take_first" - function: "take_first"
num_fewshot: 0 num_fewshot: 0
metadata: metadata:
version: 1.0 version: 2.0
...@@ -24,4 +24,4 @@ filter_list: ...@@ -24,4 +24,4 @@ filter_list:
- function: "take_first" - function: "take_first"
num_fewshot: 0 num_fewshot: 0
metadata: metadata:
version: 0 version: 1.0
...@@ -18,4 +18,4 @@ generation_kwargs: ...@@ -18,4 +18,4 @@ generation_kwargs:
temperature: 0.0 temperature: 0.0
num_fewshot: 0 num_fewshot: 0
metadata: metadata:
version: 0 version: 1.0
...@@ -18,4 +18,4 @@ generation_kwargs: ...@@ -18,4 +18,4 @@ generation_kwargs:
temperature: 0.0 temperature: 0.0
num_fewshot: 0 num_fewshot: 0
metadata: metadata:
version: 0 version: 1.0
group: belebele group: belebele
dataset_path: facebook/belebele dataset_path: facebook/belebele
test_split: test
fewshot_split: test
fewshot_config: fewshot_config:
sampler: first_n sampler: first_n
output_type: multiple_choice output_type: multiple_choice
......
...@@ -8,7 +8,7 @@ import requests ...@@ -8,7 +8,7 @@ import requests
from tqdm import tqdm from tqdm import tqdm
from lm_eval.logger import eval_logger from lm_eval.utils import logging
API_URL = "https://datasets-server.huggingface.co/splits?dataset=facebook/belebele" API_URL = "https://datasets-server.huggingface.co/splits?dataset=facebook/belebele"
...@@ -39,8 +39,8 @@ if __name__ == "__main__": ...@@ -39,8 +39,8 @@ if __name__ == "__main__":
def query(): def query():
response = requests.get(API_URL) response = requests.get(API_URL)
return response.json()["splits"] return response.json()["splits"]
print(query())
languages = [split["config"] for split in query()] languages = [split["split"] for split in query()]
for lang in tqdm(languages): for lang in tqdm(languages):
yaml_dict = { yaml_dict = {
...@@ -48,11 +48,12 @@ if __name__ == "__main__": ...@@ -48,11 +48,12 @@ if __name__ == "__main__":
"task": f"belebele_{args.task_prefix}_{lang}" "task": f"belebele_{args.task_prefix}_{lang}"
if args.task_prefix != "" if args.task_prefix != ""
else f"belebele_{lang}", else f"belebele_{lang}",
"dataset_name": lang, "test_split": lang,
"fewshot_split":lang,
} }
file_save_path = args.save_prefix_path + f"_{lang}.yaml" file_save_path = args.save_prefix_path + f"_{lang}.yaml"
eval_logger.info(f"Saving yaml for subset {lang} to {file_save_path}") logging.info(f"Saving yaml for subset {lang} to {file_save_path}")
with open(file_save_path, "w") as yaml_file: with open(file_save_path, "w") as yaml_file:
yaml.dump( yaml.dump(
yaml_dict, yaml_dict,
......
"dataset_name": "acm_Arab" "fewshot_split": "acm_Arab"
"include": "_default_template_yaml" "include": "_default_template_yaml"
"task": "belebele_acm_Arab" "task": "belebele_acm_Arab"
"test_split": "acm_Arab"
"dataset_name": "afr_Latn" "fewshot_split": "afr_Latn"
"include": "_default_template_yaml" "include": "_default_template_yaml"
"task": "belebele_afr_Latn" "task": "belebele_afr_Latn"
"test_split": "afr_Latn"
"dataset_name": "als_Latn" "fewshot_split": "als_Latn"
"include": "_default_template_yaml" "include": "_default_template_yaml"
"task": "belebele_als_Latn" "task": "belebele_als_Latn"
"test_split": "als_Latn"
"dataset_name": "amh_Ethi" "fewshot_split": "amh_Ethi"
"include": "_default_template_yaml" "include": "_default_template_yaml"
"task": "belebele_amh_Ethi" "task": "belebele_amh_Ethi"
"test_split": "amh_Ethi"
"dataset_name": "apc_Arab" "fewshot_split": "apc_Arab"
"include": "_default_template_yaml" "include": "_default_template_yaml"
"task": "belebele_apc_Arab" "task": "belebele_apc_Arab"
"test_split": "apc_Arab"
"dataset_name": "arb_Arab" "fewshot_split": "arb_Arab"
"include": "_default_template_yaml" "include": "_default_template_yaml"
"task": "belebele_arb_Arab" "task": "belebele_arb_Arab"
"test_split": "arb_Arab"
"dataset_name": "arb_Latn" "fewshot_split": "arb_Latn"
"include": "_default_template_yaml" "include": "_default_template_yaml"
"task": "belebele_arb_Latn" "task": "belebele_arb_Latn"
"test_split": "arb_Latn"
"dataset_name": "ars_Arab" "fewshot_split": "ars_Arab"
"include": "_default_template_yaml" "include": "_default_template_yaml"
"task": "belebele_ars_Arab" "task": "belebele_ars_Arab"
"test_split": "ars_Arab"
"dataset_name": "ary_Arab" "fewshot_split": "ary_Arab"
"include": "_default_template_yaml" "include": "_default_template_yaml"
"task": "belebele_ary_Arab" "task": "belebele_ary_Arab"
"test_split": "ary_Arab"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment