Unverified commit 6a5cde6a authored by Jess, committed by GitHub

Merge pull request #23 from JessicaOjo/africamgsm

manual xnli, bypass multiple choice logits for openai
parents fb142ccd 9701ef6e
......@@ -3,7 +3,7 @@ from typing import Literal, Optional, Tuple
OutputType = Literal[
"loglikelihood", "loglikelihood_rolling", "generate_until", "multiple_choice"
"loglikelihood", "loglikelihood_rolling", "generate_until", "multiple_choice", "multiple_choice_gpt"
]
......
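The hunk above registers "multiple_choice_gpt" as a valid output type, so a task config can ask for free-form chat answers instead of per-choice loglikelihoods. A minimal sketch of such a config, built with yaml.safe_dump so it stays in Python; the task and dataset names here are placeholders, not taken from this PR:

# Hedged sketch: a task config that opts into the new output type.
# Field names follow the harness's TaskConfig; the values are placeholders.
import yaml

example_task = {
    "task": "my_manual_xnli_eng",            # placeholder task name
    "dataset_name": "eng",
    "output_type": "multiple_choice_gpt",     # new output type from this PR
    "metric_list": [{"metric": "acc"}],
}
print(yaml.safe_dump(example_task, allow_unicode=True))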
......@@ -23,7 +23,20 @@ def bypass_agg(arr):
@register_aggregation("mean")
def mean(arr):
    return sum(arr) / len(arr)
    if isinstance(arr[0], (list, np.ndarray)):
        return sum(arr[0]) / len(arr[0])
    else:
        return sum(arr) / len(arr)


@register_aggregation("acc_gpt")
def acc_gpt(arr):
    unzipped_list = list(zip(*arr))
    golds = unzipped_list[0]
    preds = unzipped_list[1]
    accuracy = sklearn.metrics.accuracy_score(golds, preds)
    return accuracy

@register_aggregation("median")
......@@ -151,7 +164,7 @@ def brier_score_fn(items): # This is a passthrough function
@register_metric(
metric="acc",
higher_is_better=True,
output_type=["loglikelihood", "multiple_choice"],
output_type=["loglikelihood", "multiple_choice", "multiple_choice_gpt"],
aggregation="mean",
)
def acc_fn(items): # This is a passthrough function
......@@ -277,7 +290,7 @@ def mcc_fn(items): # This is a passthrough function
@register_metric(
metric="f1",
higher_is_better=True,
output_type="multiple_choice",
output_type=["multiple_choice", "multiple_choice_gpt"],
aggregation="f1",
)
def f1_fn(items): # This is a passthrough function
......
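The new acc_gpt aggregation expects each item to be a (gold, prediction) pair accumulated across documents. A small self-contained check of that behaviour, assuming scikit-learn is installed:

# Sketch of acc_gpt-style aggregation over (gold, prediction) pairs.
import sklearn.metrics

items = [("A", "A"), ("B", "C"), ("D", "D")]  # illustrative (gold, pred) tuples
golds, preds = zip(*items)
print(sklearn.metrics.accuracy_score(golds, preds))  # 2 of 3 correct -> 0.666...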
......@@ -87,6 +87,7 @@ DEFAULT_METRIC_REGISTRY = {
],
"loglikelihood_rolling": ["word_perplexity", "byte_perplexity", "bits_per_byte"],
"multiple_choice": ["acc", "acc_norm"],
"multiple_choice_gpt": ["acc"],
"generate_until": ["exact_match"],
}
......
......@@ -44,6 +44,7 @@ from lm_eval.prompts import get_prompt
ALL_OUTPUT_TYPES = [
"loglikelihood",
"multiple_choice",
"multiple_choice_gpt",
"loglikelihood_rolling",
"generate_until",
]
......@@ -1064,7 +1065,6 @@ class ConfigurableTask(Task):
eval_logger.warning("Applied prompt returns empty string")
return self.config.fewshot_delimiter
else:
print(type(doc_to_text))
raise TypeError
def doc_to_target(self, doc: Mapping) -> Union[int, str, list]:
......@@ -1142,7 +1142,7 @@ class ConfigurableTask(Task):
    arguments = (ctx, self.doc_to_target(doc))
elif self.OUTPUT_TYPE == "loglikelihood_rolling":
    arguments = (self.doc_to_target(doc),)
elif self.OUTPUT_TYPE == "multiple_choice":
elif "multiple_choice" in self.OUTPUT_TYPE:
    choices = self.doc_to_choice(doc)
    target_delimiter = self.config.target_delimiter
    if self.multiple_input:
......@@ -1154,17 +1154,28 @@
    else:
        # Otherwise they are placed in the continuation
        arguments = [(ctx, f"{target_delimiter}{cont}") for cont in choices]
    request_list = [
        Instance(
            request_type="loglikelihood",
            doc=doc,
            arguments=arg,
            idx=i,
            **kwargs,
        )
        for i, arg in enumerate(arguments)
    ]
    if self.OUTPUT_TYPE == "multiple_choice_gpt":
        request_list = [
            Instance(
                request_type="multiple_choice_gpt",
                doc=doc,
                arguments=arg,
                idx=i,
                **kwargs,
            )
            for i, arg in enumerate(arguments)
        ]
    else:
        request_list = [
            Instance(
                request_type="loglikelihood",
                doc=doc,
                arguments=arg,
                idx=i,
                **kwargs,
            )
            for i, arg in enumerate(arguments)
        ]
    # TODO: we should raise a warning telling users this will at most ~2x runtime.
    if "acc_mutual_info" in self._metric_fn_list.keys():
        # if we are calculating multiple choice accuracy
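Under the new branch, a "multiple_choice_gpt" task reuses the same per-choice argument tuples but tags each Instance with the new request type, so the backend returns generated text rather than a loglikelihood. Roughly, the shape of the arguments for one document looks like this (the context and choices are invented for illustration):

# Illustrative only: argument tuples built for one document.
ctx = "Premise: ...\nHypothesis: ...\nAnswer:"
choices = ["entailment", "neutral", "contradiction"]
target_delimiter = " "
arguments = [(ctx, f"{target_delimiter}{cont}") for cont in choices]
# Each tuple becomes Instance(request_type="multiple_choice_gpt", arguments=arg, ...),
# mirroring the loglikelihood branch in the diff above.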
......@@ -1310,14 +1321,45 @@ class ConfigurableTask(Task):
        acc_mutual_info = 1.0 if np.argmax(lls_mutual_info) == gold else 0.0
        result_dict["acc_mutual_info"] = acc_mutual_info

elif self.OUTPUT_TYPE == "multiple_choice_gpt":
    gold = self.doc_to_target(doc)
    result = results[0]
    choices = self.doc_to_choice(doc)
    try:
        gold = choices[gold]
        gold = type(result)(gold)
    except TypeError:
        gold = gold

    for metric in self._metric_fn_list.keys():
        try:
            result_score = self._metric_fn_list[metric](
                references=[gold],
                predictions=[result],
                **self._metric_fn_kwargs[metric],
            )
        except (
            TypeError
        ):  # TODO: this is hacky and I don't want to do it
            result_score = self._metric_fn_list[metric](
                [gold, result]
            )
        if isinstance(result_score, dict):
            # TODO: this handles the case where HF evaluate returns a dict.
            result_score = result_score[metric]
        result_dict[metric] = result_score

elif self.OUTPUT_TYPE == "generate_until":
    gold = self.doc_to_target(doc)
    result = results[0]
    if self.config.doc_to_choice is not None:
        # If you set doc_to_choice,
        # it assumes that doc_to_target returns a number.
        choices = self.doc_to_choice(doc)
        gold = choices[gold]
        try:
            # If you set doc_to_choice,
            # it assumes that doc_to_target returns a number.
            choices = self.doc_to_choice(doc)
            gold = choices[gold]
        except TypeError:
            gold = gold
    # we expect multiple_targets to be a list.
    elif self.multiple_target:
        gold = list(gold)
......@@ -1333,7 +1375,6 @@ class ConfigurableTask(Task):
    scores = []
    if not isinstance(gold, list):
        # sometimes, a multiple_target dataset has exceptions where one doc has only one string answer
        # print(gold)
        gold = [gold]
    if metric == "exact_match":
        result = [result for _ in range(len(gold))]
......
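For "multiple_choice_gpt", process_results compares the model's raw string answer against the gold choice text, so what reaches aggregation is effectively a (gold, prediction) pair. A hedged, stand-alone illustration of the gold-mapping step (the doc fields are invented):

# Sketch of mapping a gold index to choice text before scoring (doc is invented).
doc = {"label": 0}
choices = ["entailment", "neutral", "contradiction"]
result = "entailment"          # string returned by the chat model
gold = doc["label"]            # doc_to_target would return this index
try:
    gold = choices[gold]        # index -> choice text
    gold = type(result)(gold)   # coerce gold to the prediction's type
except TypeError:
    pass
print((gold, result))           # ('entailment', 'entailment')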
......@@ -471,6 +471,52 @@ class OpenaiChatCompletionsLM(LM):
return grouper.get_original(res)
    def multiple_choice_gpt(self, requests, disable_tqdm: bool = False) -> List[str]:
        res = defaultdict(list)
        re_ords = {}

        # we group requests by their generation_kwargs,
        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
        # in the same batch.
        grouper = lm_eval.models.utils.Grouper(requests, lambda x: str(x.args[1]))
        for key, reqs in grouper.get_grouped().items():
            # within each set of reqs for given kwargs, we reorder by token length, descending.
            re_ords[key] = utils.Reorderer(
                [req.args for req in reqs], lambda x: (-len(x[0]), x[0])
            )

        pbar = tqdm(total=len(requests), disable=(disable_tqdm or (self.rank != 0)))
        for key, re_ord in re_ords.items():
            # n needs to be 1 because messages in
            # chat completion are not batched but
            # are treated as a single conversation.
            chunks = lm_eval.models.utils.chunks(re_ord.get_reordered(), n=1)
            for chunk in chunks:
                contexts, all_gen_kwargs = zip(*chunk)
                inps = [{"role": "user", "content": context} for context in contexts]

                response = oa_completion(
                    client=self.client,
                    chat=True,
                    messages=inps,
                    model=self.model,
                )

                for resp, (context, args_) in zip(response.choices, chunk):
                    s = resp.message.content
                    res[key].append(s)

                    self.cache_hook.add_partial(
                        "multiple_choice_gpt", context, s
                    )
                    pbar.update(1)
            # reorder this group of results back to original unsorted form
            res[key] = re_ord.get_original(res[key])

        pbar.close()
        return grouper.get_original(res)

    def loglikelihood(self, requests, disable_tqdm: bool = False):
        raise NotImplementedError("No support for logits.")
......
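The chat path sends each prompt as a single user message and stores the assistant's text verbatim; there is no access to token logits, which is why loglikelihood stays unimplemented for this backend. A minimal stand-alone sketch of the same call shape with the openai client (model name and prompt are placeholders):

# Hypothetical sketch of the chat-completion call this method wraps.
from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment
response = client.chat.completions.create(
    model="gpt-3.5-turbo",  # placeholder model name
    messages=[{"role": "user", "content": "Premise: ...\nHypothesis: ...\nAnswer:"}],
)
print(response.choices[0].message.content)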
import argparse
import yaml
languages = ['eng', 'amh', 'ibo', 'fra', 'sna', 'lin', 'wol', 'ewe', 'lug', 'xho', 'kin', 'twi', 'zul', 'orm', 'yor', 'hau', 'sot', 'swa']
languages_REGEX = {"eng":"The answer is (\\-?[0-9\\.\\,]+)",
"amh":"መልሱ (\\-?[0-9\\.\\,]+)",
"ibo":"Azịza ya bụ (\\-?[0-9\\.\\,]+)",
'fra':"La réponse est(\\-?[0-9\\.\\,]+)",
'sna':"Mhinduro kumubvunzo ndi (\\-?[0-9\\.\\,]+)",
'lin':"Eyano ezali (\\-?[0-9\\.\\,]+)",
'wol': "Tontu li (\\-?[0-9\\.\\,]+)",
'ewe': "ŋuɖoɖoae nye (\\-?[0-9\\.\\,]+)",
'lug': "Ansa eri (\\-?[0-9\\.\\,]+)",
'xho': "Impendulo ngu (\\-?[0-9\\.\\,]+)",
'kin': "Igisubizo ni (\\-?[0-9\\.\\,]+)",
'twi': "Ne nnyiano yɛ (\\-?[0-9\\.\\,]+)",
'zul': "Impendulo ithi (\\-?[0-9\\.\\,]+)",
'orm': "Deebiin isaa (\\-?[0-9\\.\\,]+)",
'yor': "Ìdáhùn náà ni (\\-?[0-9\\.\\,]+)",
'hau': "Amsar ita ce (\\-?[0-9\\.\\,]+)",
'sot': "Karabo ke (\\-?[0-9\\.\\,]+)",
'swa': "Jibu ni (\\-?[0-9\\.\\,]+)",
}
LANGUAGES = {}
languages = ['eng', 'amh', 'ibo', 'fra', 'sna', 'lin', 'wol', 'ewe', 'lug', 'xho', 'kin', 'twi', 'zul', 'orm', 'yor', 'hau', 'sot', 'swa']
for lang in languages:
    if lang == 'amh':
        LANGUAGES[lang] = {
            "QUESTION": "ጥያቄ:",
            "ANSWER": "በቅደም ተከተል መልስ:",
            "DIRECT": "Answer:",
            "REGEX": languages_REGEX[lang]}
    elif lang == 'yor':
        LANGUAGES[lang] = {
            "QUESTION": "Ìbéèrè:",
            "ANSWER": "Ìdáhùn lẹ́sẹsẹ:",
            "DIRECT": "Answer:",
            "REGEX": languages_REGEX[lang]}
    else:
        LANGUAGES[lang] = {
            "QUESTION": "Question:",
            "ANSWER": "Step-by-Step Answer:",
            "DIRECT": "Answer:",
            "REGEX": languages_REGEX[lang]}
def add_regex_pattern(regex_pattern):
    if regex_pattern is None:
        return {}
    return {
        "filter_list": [
            {
                "name": "strict-match",
                "filter": [
                    {
                        "function": "regex",
                        "regex_pattern": f"""{regex_pattern}""",
                    },
                    {
                        "function": "take_first",
                    },
                ],
            },
            {
                "name": "flexible-extract",
                "filter": [
                    {
                        "function": "regex",
                        "regex_pattern": """(-?[$0-9.,]{2,})|(-?[0-9]+)""",
                        "group_select": -1,
                    },
                    {
                        "function": "take_first",
                    },
                ],
            },
        ],
    }
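add_regex_pattern wraps a language's answer regex in the harness's two-stage filter pipeline: a strict match on the localized pattern plus a permissive numeric fallback. Dumping its output for the English pattern shows the structure it builds (yaml is already imported at the top of the script):

# Example: filters built for the English answer pattern.
print(yaml.safe_dump(add_regex_pattern("The answer is (\\-?[0-9\\.\\,]+)"),
                     allow_unicode=True))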
configs = {
    "QUESTION": "Question:",
    "ANSWER": "Step-by-Step Answer:",
    "DIRECT": "Answer:",
    "REGEX": "The answer is (\\-?[0-9\\.\\,]+)"}
def gen_lang_yamls(output_dir: str, overwrite: bool, mode: str) -> None:
......@@ -89,46 +18,19 @@ def gen_lang_yamls(output_dir: str, overwrite: bool, mode: str) -> None:
:param overwrite: Whether to overwrite files if they already exist.
"""
err = []
for lang in LANGUAGES.keys():
for lang in languages:
try:
yaml_template = "cot_yaml"
filter_list = {}
DELIMITER = None
if mode == "direct":
ANSWER = LANGUAGES['eng']["DIRECT"]
QUESTION = LANGUAGES['eng']["QUESTION"]
REGEX = None
task_name = f"afrimgsm_direct_{lang}"
yaml_template = "direct_yaml"
if mode == "direct-native":
ANSWER = LANGUAGES[lang]["DIRECT"]
QUESTION = LANGUAGES[lang]["QUESTION"]
REGEX = None
task_name = f"afrimgsm_direct_native_{lang}"
yaml_template = "direct_native_yaml"
yaml_template = "afrimgsm_common_yaml"
elif mode == "native-cot":
ANSWER = LANGUAGES[lang]["ANSWER"]
REGEX = LANGUAGES[lang]["REGEX"]
QUESTION = LANGUAGES[lang]["QUESTION"]
task_name = f"afrimgsm_native_cot_{lang}"
filter_list = add_regex_pattern(REGEX)
DELIMITER = "" if lang in ["zh", "ja"] else None
yaml_template = "afrimgsm_common_yaml"
elif mode == "en-cot":
ANSWER = LANGUAGES["eng"]["ANSWER"]
REGEX = LANGUAGES["eng"]["REGEX"]
QUESTION = LANGUAGES["eng"]["QUESTION"]
task_name = f"afrimgsm_en_cot_{lang}"
elif mode == "translate-direct":
ANSWER = LANGUAGES['eng']["DIRECT"]
QUESTION = LANGUAGES['eng']["QUESTION"]
REGEX = None
task_name = f"translate_afrimgsm_direct_{lang}"
yaml_template = "translate_direct_yaml"
yaml_template = "afrimgsm_common_yaml"
file_name = f"{task_name}.yaml"
ANSWER_TO_SKIP = len(LANGUAGES[lang]["ANSWER"]) + 1
with open(
f"{output_dir}/{file_name}", "w" if overwrite else "x", encoding="utf8"
) as f:
......@@ -137,23 +39,7 @@ def gen_lang_yamls(output_dir: str, overwrite: bool, mode: str) -> None:
{
"include": yaml_template,
"dataset_name": lang,
"task": f"{task_name}",
"doc_to_text": f"""{{% if answer is not none %}}"""
f"""{{{{question+"\\n{ANSWER}"}}}}"""
f"""{{% else %}}"""
f"""{{{{"{QUESTION} "+question+"\\n{ANSWER}"}}}}"""
f"""{{% endif %}}""",
"doc_to_target": f"""{{% if answer is not none %}}"""
f"""{{{{answer[{ANSWER_TO_SKIP}:]}}}}"""
f"""{{% else %}}"""
f"""{{{{answer_number|string}}}}"""
f"""{{% endif %}}""",
**filter_list,
"generation_kwargs": {
"until": [QUESTION, "</s>", "<|im_end|>"],
"do_sample": False,
},
**({"target_delimiter": DELIMITER} if DELIMITER else {}),
"task": f"{task_name}"
},
f,
allow_unicode=True,
......@@ -174,12 +60,12 @@ def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument(
"--overwrite",
default=False,
default=True,
action="store_true",
help="Overwrite files if they already exist",
)
parser.add_argument(
"--output-dir", default=".", help="Directory to write yaml files to"
"--output-dir", default="./direct", help="Directory to write yaml files to"
)
parser.add_argument(
"--mode",
......
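With the new defaults the generator overwrites existing files and writes into ./direct, so a typical invocation, assuming "direct" is one of the accepted --mode values handled above, is simply:

python utils.py --mode direct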
import re
import sys
import unicodedata
from sklearn.metrics import f1_score
from lm_eval.filters.extraction import RegexFilter
def doc_to_choice(doc):
......@@ -8,15 +13,15 @@ def doc_to_choice(doc):
def doc_to_text(doc):
output = """You are a highly knowledgeable and intelligent artificial intelligence
model answers multiple-choice questions about '{subject}'
model answers multiple-choice questions about {subject}
Question: '''{question}'''
Question: {question}
Choices:
A: ''{choice1}'''
B: ''{choice2}'''
C: ''{choice3}'''
D: ''{choice4}'''
A: {choice1}
B: {choice2}
C: {choice3}
D: {choice4}
Answer: """
......
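doc_to_text fills that template from the document fields before the prompt is sent to the model. A hedged sketch of how the rendered prompt might look, with the values invented for illustration:

# Illustrative rendering of the prompt template above (values are invented).
prompt = (
    "You are a highly knowledgeable and intelligent artificial intelligence\n"
    "model answers multiple-choice questions about {subject}\n\n"
    "Question: {question}\n\n"
    "Choices:\n"
    "A: {choice1}\nB: {choice2}\nC: {choice3}\nD: {choice4}\n\n"
    "Answer: "
).format(
    subject="elementary mathematics",
    question="What is 2 + 2?",
    choice1="3", choice2="4", choice3="5", choice4="6",
)
print(prompt)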
# Generated by utils.py
dataset_name: amh
include: afrixnli_manual_direct_yaml
task: afrixnli_manual_direct_amh
# Generated by utils.py
dataset_name: eng
include: afrixnli_manual_direct_yaml
task: afrixnli_manual_direct_eng
# Generated by utils.py
dataset_name: ewe
include: afrixnli_manual_direct_yaml
task: afrixnli_manual_direct_ewe
# Generated by utils.py
dataset_name: fra
include: afrixnli_manual_direct_yaml
task: afrixnli_manual_direct_fra
# Generated by utils.py
dataset_name: hau
include: afrixnli_manual_direct_yaml
task: afrixnli_manual_direct_hau
# Generated by utils.py
dataset_name: ibo
include: afrixnli_manual_direct_yaml
task: afrixnli_manual_direct_ibo
# Generated by utils.py
dataset_name: kin
include: afrixnli_manual_direct_yaml
task: afrixnli_manual_direct_kin
# Generated by utils.py
dataset_name: lin
include: afrixnli_manual_direct_yaml
task: afrixnli_manual_direct_lin
# Generated by utils.py
dataset_name: lug
include: afrixnli_manual_direct_yaml
task: afrixnli_manual_direct_lug
# Generated by utils.py
dataset_name: orm
include: afrixnli_manual_direct_yaml
task: afrixnli_manual_direct_orm
# Generated by utils.py
dataset_name: sna
include: afrixnli_manual_direct_yaml
task: afrixnli_manual_direct_sna
# Generated by utils.py
dataset_name: sot
include: afrixnli_manual_direct_yaml
task: afrixnli_manual_direct_sot
# Generated by utils.py
dataset_name: swa
include: afrixnli_manual_direct_yaml
task: afrixnli_manual_direct_swa
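These generated stubs only bind a dataset_name and task name to the shared afrixnli_manual_direct_yaml include. Once registered, any of them can be selected by name on the harness command line, for example (model name and arguments are placeholders):

lm_eval --model openai-chat-completions --model_args model=gpt-3.5-turbo --tasks afrixnli_manual_direct_eng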