"src/lib/vscode:/vscode.git/clone" did not exist on "d187b1615a751d83d648bcbeab1dd71f8ff298db"
Unverified commit 6a5cde6a authored by Jess, committed by GitHub

Merge pull request #23 from JessicaOjo/africamgsm

manual xnli, bypass multiple choice logits for openai
parents fb142ccd 9701ef6e
@@ -3,7 +3,7 @@ from typing import Literal, Optional, Tuple
 OutputType = Literal[
-    "loglikelihood", "loglikelihood_rolling", "generate_until", "multiple_choice"
+    "loglikelihood", "loglikelihood_rolling", "generate_until", "multiple_choice", "multiple_choice_gpt"
 ]
...
@@ -23,7 +23,20 @@ def bypass_agg(arr):
@register_aggregation("mean") @register_aggregation("mean")
def mean(arr): def mean(arr):
return sum(arr) / len(arr) if isinstance(arr[0], (list, np.ndarray)):
return sum(arr[0]) / len(arr[0])
else:
return sum(arr) / len(arr)
@register_aggregation("acc_gpt")
def acc_gpt(arr):
unzipped_list = list(zip(*arr))
golds = unzipped_list[0]
preds = unzipped_list[1]
accuracy = sklearn.metrics.accuracy_score(golds, preds)
return accuracy
@register_aggregation("median") @register_aggregation("median")
@@ -151,7 +164,7 @@ def brier_score_fn(items):  # This is a passthrough function
 @register_metric(
     metric="acc",
     higher_is_better=True,
-    output_type=["loglikelihood", "multiple_choice"],
+    output_type=["loglikelihood", "multiple_choice", "multiple_choice_gpt"],
     aggregation="mean",
 )
 def acc_fn(items):  # This is a passthrough function
@@ -277,7 +290,7 @@ def mcc_fn(items):  # This is a passthrough function
 @register_metric(
     metric="f1",
     higher_is_better=True,
-    output_type="multiple_choice",
+    output_type=["multiple_choice", "multiple_choice_gpt"],
     aggregation="f1",
 )
 def f1_fn(items):  # This is a passthrough function
...
@@ -87,6 +87,7 @@ DEFAULT_METRIC_REGISTRY = {
     ],
     "loglikelihood_rolling": ["word_perplexity", "byte_perplexity", "bits_per_byte"],
     "multiple_choice": ["acc", "acc_norm"],
+    "multiple_choice_gpt": ["acc"],
     "generate_until": ["exact_match"],
 }
...
@@ -44,6 +44,7 @@ from lm_eval.prompts import get_prompt
 ALL_OUTPUT_TYPES = [
     "loglikelihood",
     "multiple_choice",
+    "multiple_choice_gpt",
     "loglikelihood_rolling",
     "generate_until",
 ]
@@ -1064,7 +1065,6 @@ class ConfigurableTask(Task):
                 eval_logger.warning("Applied prompt returns empty string")
                 return self.config.fewshot_delimiter
         else:
-            print(type(doc_to_text))
             raise TypeError

     def doc_to_target(self, doc: Mapping) -> Union[int, str, list]:
@@ -1142,7 +1142,7 @@ class ConfigurableTask(Task):
             arguments = (ctx, self.doc_to_target(doc))
         elif self.OUTPUT_TYPE == "loglikelihood_rolling":
             arguments = (self.doc_to_target(doc),)
-        elif self.OUTPUT_TYPE == "multiple_choice":
+        elif "multiple_choice" in self.OUTPUT_TYPE:
             choices = self.doc_to_choice(doc)
             target_delimiter = self.config.target_delimiter
             if self.multiple_input:
@@ -1154,17 +1154,28 @@ class ConfigurableTask(Task):
             else:
                 # Otherwise they are placed in the continuation
                 arguments = [(ctx, f"{target_delimiter}{cont}") for cont in choices]
-            request_list = [
-                Instance(
-                    request_type="loglikelihood",
-                    doc=doc,
-                    arguments=arg,
-                    idx=i,
-                    **kwargs,
-                )
-                for i, arg in enumerate(arguments)
-            ]
+            if self.OUTPUT_TYPE == "multiple_choice_gpt":
+                request_list = [
+                    Instance(
+                        request_type="multiple_choice_gpt",
+                        doc=doc,
+                        arguments=arg,
+                        idx=i,
+                        **kwargs,
+                    )
+                    for i, arg in enumerate(arguments)
+                ]
+            else:
+                request_list = [
+                    Instance(
+                        request_type="loglikelihood",
+                        doc=doc,
+                        arguments=arg,
+                        idx=i,
+                        **kwargs,
+                    )
+                    for i, arg in enumerate(arguments)
+                ]

             # TODO: we should raise a warning telling users this will at most ~2x runtime.
             if "acc_mutual_info" in self._metric_fn_list.keys():
                 # if we are calculating multiple choice accuracy
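To make the branch above concrete, here is a rough self-contained sketch of the two request shapes it builds for a three-choice doc. The FakeInstance dataclass and the example fields stand in for the harness's real Instance class and are illustrative assumptions only:

from dataclasses import dataclass
from typing import Tuple


@dataclass
class FakeInstance:  # stand-in for lm_eval's Instance, for illustration only
    request_type: str
    arguments: Tuple[str, str]
    idx: int


ctx = "Premise: ...\nHypothesis: ...\nAnswer:"
choices = [" entailment", " neutral", " contradiction"]

# output_type == "multiple_choice": one loglikelihood request per choice,
# later scored by comparing the choices' log-probabilities.
ll_requests = [
    FakeInstance("loglikelihood", (ctx, c), i) for i, c in enumerate(choices)
]

# output_type == "multiple_choice_gpt": the same arguments, but routed to the
# chat model as free-form generation and scored on the returned string.
gpt_requests = [
    FakeInstance("multiple_choice_gpt", (ctx, c), i) for i, c in enumerate(choices)
]

print(ll_requests[0].request_type, gpt_requests[0].request_type)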
@@ -1310,14 +1321,45 @@ class ConfigurableTask(Task):
                 acc_mutual_info = 1.0 if np.argmax(lls_mutual_info) == gold else 0.0
                 result_dict["acc_mutual_info"] = acc_mutual_info
+        elif self.OUTPUT_TYPE == "multiple_choice_gpt":
+            gold = self.doc_to_target(doc)
+            result = results[0]
+            choices = self.doc_to_choice(doc)
+            try:
+                gold = choices[gold]
+                gold = type(result)(gold)
+            except TypeError:
+                gold = gold
+
+            for metric in self._metric_fn_list.keys():
+                try:
+                    result_score = self._metric_fn_list[metric](
+                        references=[gold],
+                        predictions=[result],
+                        **self._metric_fn_kwargs[metric],
+                    )
+                except (
+                    TypeError
+                ):  # TODO: this is hacky and I don't want to do it
+                    result_score = self._metric_fn_list[metric](
+                        [gold, result]
+                    )
+                if isinstance(result_score, dict):
+                    # TODO: this handles the case where HF evaluate returns a dict.
+                    result_score = result_score[metric]
+                result_dict[metric] = result_score
         elif self.OUTPUT_TYPE == "generate_until":
             gold = self.doc_to_target(doc)
             result = results[0]
             if self.config.doc_to_choice is not None:
-                # If you set doc_to_choice,
-                # it assumes that doc_to_target returns a number.
-                choices = self.doc_to_choice(doc)
-                gold = choices[gold]
+                try:
+                    # If you set doc_to_choice,
+                    # it assumes that doc_to_target returns a number.
+                    choices = self.doc_to_choice(doc)
+                    gold = choices[gold]
+                except TypeError:
+                    gold = gold
             # we expect multiple_targets to be a list.
             elif self.multiple_target:
                 gold = list(gold)
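The try/except TypeError added in both branches above lets doc_to_target return either an index into the choices or the answer string itself. A standalone sketch of that coercion, where the helper name and example values are hypothetical:

def resolve_gold(gold, choices, result):
    # Standalone illustration of the coercion pattern used above, not harness code.
    try:
        gold = choices[gold]        # works when gold is an integer index
        gold = type(result)(gold)   # align the type with the model's response
    except TypeError:
        pass                        # gold was already a string label; keep it
    return gold


choices = ["entailment", "neutral", "contradiction"]
print(resolve_gold(1, choices, "neutral"))          # index  -> "neutral"
print(resolve_gold("neutral", choices, "neutral"))  # string -> "neutral"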
@@ -1333,7 +1375,6 @@ class ConfigurableTask(Task):
                     scores = []
                     if not isinstance(gold, list):
                         # sometimes, a multiple_target dataset has exceptions where one doc has only one string answer
-                        # print(gold)
                         gold = [gold]
                     if metric == "exact_match":
                         result = [result for _ in range(len(gold))]
...
@@ -471,6 +471,52 @@ class OpenaiChatCompletionsLM(LM):
         return grouper.get_original(res)

+    def multiple_choice_gpt(self, requests, disable_tqdm: bool = False) -> List[str]:
+        res = defaultdict(list)
+        re_ords = {}
+
+        # we group requests by their generation_kwargs,
+        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
+        # in the same batch.
+        grouper = lm_eval.models.utils.Grouper(requests, lambda x: str(x.args[1]))
+        for key, reqs in grouper.get_grouped().items():
+            # within each set of reqs for given kwargs, we reorder by token length, descending.
+            re_ords[key] = utils.Reorderer(
+                [req.args for req in reqs], lambda x: (-len(x[0]), x[0])
+            )
+
+        pbar = tqdm(total=len(requests), disable=(disable_tqdm or (self.rank != 0)))
+        for key, re_ord in re_ords.items():
+            # n needs to be 1 because messages in
+            # chat completion are not batch but
+            # is regarded as a single conversation.
+            chunks = lm_eval.models.utils.chunks(re_ord.get_reordered(), n=1)
+            for chunk in chunks:
+                contexts, all_gen_kwargs = zip(*chunk)
+                inps = [{"role": "user", "content": context} for context in contexts]
+
+                response = oa_completion(
+                    client=self.client,
+                    chat=True,
+                    messages=inps,
+                    model=self.model,
+                )
+
+                for resp, (context, args_) in zip(response.choices, chunk):
+                    s = resp.message.content
+                    res[key].append(s)
+
+                    self.cache_hook.add_partial(
+                        "multiple_choice_gpt", context, s
+                    )
+                    pbar.update(1)
+            # reorder this group of results back to original unsorted form
+            res[key] = re_ord.get_original(res[key])
+
+        pbar.close()
+
+        return grouper.get_original(res)
+
     def loglikelihood(self, requests, disable_tqdm: bool = False):
         raise NotImplementedError("No support for logits.")
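A hedged end-to-end sketch of exercising the new path through the harness's Python entry point. The model choice, the small limit, the need for an OPENAI_API_KEY in the environment, and the assumption that afrixnli_manual_direct_eng is configured with the new multiple_choice_gpt output type are all assumptions, not part of this diff:

import lm_eval

# Smoke-test one of the new manual XNLI tasks against the chat-completions backend.
results = lm_eval.simple_evaluate(
    model="openai-chat-completions",      # the LM class extended above
    model_args="model=gpt-3.5-turbo",     # hypothetical model choice
    tasks=["afrixnli_manual_direct_eng"],
    limit=8,                              # keep the smoke test cheap
)
print(results["results"])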
...
 import argparse

 import yaml

 languages = ['eng', 'amh', 'ibo', 'fra', 'sna', 'lin', 'wol', 'ewe', 'lug', 'xho', 'kin', 'twi', 'zul', 'orm', 'yor', 'hau', 'sot', 'swa']
-languages_REGEX = {"eng": "The answer is (\\-?[0-9\\.\\,]+)",
-                   "amh": "መልሱ (\\-?[0-9\\.\\,]+)",
-                   "ibo": "Azịza ya bụ (\\-?[0-9\\.\\,]+)",
-                   "fra": "La réponse est(\\-?[0-9\\.\\,]+)",
-                   "sna": "Mhinduro kumubvunzo ndi (\\-?[0-9\\.\\,]+)",
-                   "lin": "Eyano ezali (\\-?[0-9\\.\\,]+)",
-                   "wol": "Tontu li (\\-?[0-9\\.\\,]+)",
-                   "ewe": "ŋuɖoɖoae nye (\\-?[0-9\\.\\,]+)",
-                   "lug": "Ansa eri (\\-?[0-9\\.\\,]+)",
-                   "xho": "Impendulo ngu (\\-?[0-9\\.\\,]+)",
-                   "kin": "Igisubizo ni (\\-?[0-9\\.\\,]+)",
-                   "twi": "Ne nnyiano yɛ (\\-?[0-9\\.\\,]+)",
-                   "zul": "Impendulo ithi (\\-?[0-9\\.\\,]+)",
-                   "orm": "Deebiin isaa (\\-?[0-9\\.\\,]+)",
-                   "yor": "Ìdáhùn náà ni (\\-?[0-9\\.\\,]+)",
-                   "hau": "Amsar ita ce (\\-?[0-9\\.\\,]+)",
-                   "sot": "Karabo ke (\\-?[0-9\\.\\,]+)",
-                   "swa": "Jibu ni (\\-?[0-9\\.\\,]+)",
-                   }
-LANGUAGES = {}
-for lang in languages:
-    if lang == 'amh':
-        LANGUAGES[lang] = {
-            "QUESTION": "ጥያቄ:",
-            "ANSWER": "በቅደም ተከተል መልስ:",
-            "DIRECT": "Answer:",
-            "REGEX": languages_REGEX[lang]}
-    elif lang == 'yor':
-        LANGUAGES[lang] = {
-            "QUESTION": "Ìbéèrè:",
-            "ANSWER": "Ìdáhùn lẹ́sẹsẹ:",
-            "DIRECT": "Answer:",
-            "REGEX": languages_REGEX[lang]}
-    else:
-        LANGUAGES[lang] = {  # English
-            "QUESTION": "Question:",
-            "ANSWER": "Step-by-Step Answer:",
-            "DIRECT": "Answer:",
-            "REGEX": languages_REGEX[lang]}
+configs = {
+    "QUESTION": "Question:",
+    "ANSWER": "Step-by-Step Answer:",
+    "DIRECT": "Answer:",
+    "REGEX": "The answer is (\\-?[0-9\\.\\,]+)"}
-
-
-def add_regex_pattern(regex_pattern):
-    if regex_pattern is None:
-        return {}
-    return {
-        "filter_list": [
-            {
-                "name": "strict-match",
-                "filter": [
-                    {
-                        "function": "regex",
-                        "regex_pattern": f"""{regex_pattern}""",
-                    },
-                    {
-                        "function": "take_first",
-                    },
-                ],
-            },
-            {
-                "name": "flexible-extract",
-                "filter": [
-                    {
-                        "function": "regex",
-                        "regex_pattern": """(-?[$0-9.,]{2,})|(-?[0-9]+)""",
-                        "group_select": -1,
-                    },
-                    {
-                        "function": "take_first",
-                    },
-                ],
-            },
-        ],
-    }
 def gen_lang_yamls(output_dir: str, overwrite: bool, mode: str) -> None:
@@ -89,46 +18,19 @@ def gen_lang_yamls(output_dir: str, overwrite: bool, mode: str) -> None:
     :param overwrite: Whether to overwrite files if they already exist.
     """
     err = []
-    for lang in LANGUAGES.keys():
+    for lang in languages:
         try:
-            yaml_template = "cot_yaml"
-            filter_list = {}
-            DELIMITER = None
             if mode == "direct":
-                ANSWER = LANGUAGES['eng']["DIRECT"]
-                QUESTION = LANGUAGES['eng']["QUESTION"]
-                REGEX = None
                 task_name = f"afrimgsm_direct_{lang}"
-                yaml_template = "direct_yaml"
+                yaml_template = "afrimgsm_common_yaml"
-            if mode == "direct-native":
-                ANSWER = LANGUAGES[lang]["DIRECT"]
-                QUESTION = LANGUAGES[lang]["QUESTION"]
-                REGEX = None
-                task_name = f"afrimgsm_direct_native_{lang}"
-                yaml_template = "direct_native_yaml"
             elif mode == "native-cot":
-                ANSWER = LANGUAGES[lang]["ANSWER"]
-                REGEX = LANGUAGES[lang]["REGEX"]
-                QUESTION = LANGUAGES[lang]["QUESTION"]
                 task_name = f"afrimgsm_native_cot_{lang}"
-                filter_list = add_regex_pattern(REGEX)
-                DELIMITER = "" if lang in ["zh", "ja"] else None
+                yaml_template = "afrimgsm_common_yaml"
             elif mode == "en-cot":
-                ANSWER = LANGUAGES["eng"]["ANSWER"]
-                REGEX = LANGUAGES["eng"]["REGEX"]
-                QUESTION = LANGUAGES["eng"]["QUESTION"]
                 task_name = f"afrimgsm_en_cot_{lang}"
+                yaml_template = "afrimgsm_common_yaml"
-            elif mode == "translate-direct":
-                ANSWER = LANGUAGES['eng']["DIRECT"]
-                QUESTION = LANGUAGES['eng']["QUESTION"]
-                REGEX = None
-                task_name = f"translate_afrimgsm_direct_{lang}"
-                yaml_template = "translate_direct_yaml"

             file_name = f"{task_name}.yaml"
-            ANSWER_TO_SKIP = len(LANGUAGES[lang]["ANSWER"]) + 1
             with open(
                 f"{output_dir}/{file_name}", "w" if overwrite else "x", encoding="utf8"
             ) as f:
@@ -137,23 +39,7 @@ def gen_lang_yamls(output_dir: str, overwrite: bool, mode: str) -> None:
                     {
                         "include": yaml_template,
                         "dataset_name": lang,
-                        "task": f"{task_name}",
-                        "doc_to_text": f"""{{% if answer is not none %}}"""
-                        f"""{{{{question+"\\n{ANSWER}"}}}}"""
-                        f"""{{% else %}}"""
-                        f"""{{{{"{QUESTION} "+question+"\\n{ANSWER}"}}}}"""
-                        f"""{{% endif %}}""",
-                        "doc_to_target": f"""{{% if answer is not none %}}"""
-                        f"""{{{{answer[{ANSWER_TO_SKIP}:]}}}}"""
-                        f"""{{% else %}}"""
-                        f"""{{{{answer_number|string}}}}"""
-                        f"""{{% endif %}}""",
-                        **filter_list,
-                        "generation_kwargs": {
-                            "until": [QUESTION, "</s>", "<|im_end|>"],
-                            "do_sample": False,
-                        },
-                        **({"target_delimiter": DELIMITER} if DELIMITER else {}),
+                        "task": f"{task_name}"
                     },
                     f,
                     allow_unicode=True,
@@ -174,12 +60,12 @@ def main() -> None:
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "--overwrite",
-        default=False,
+        default=True,
         action="store_true",
         help="Overwrite files if they already exist",
     )
     parser.add_argument(
-        "--output-dir", default=".", help="Directory to write yaml files to"
+        "--output-dir", default="./direct", help="Directory to write yaml files to"
     )
     parser.add_argument(
         "--mode",
...
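With the simplified generator, each emitted task file reduces to an include plus two identifiers. A small sketch of roughly what the script would write for English in direct mode, mirroring the afrixnli YAML files added further down; the exact header line and key order are assumptions:

import yaml

# Roughly what gen_lang_yamls now writes to afrimgsm_direct_eng.yaml (illustration only).
config = {"include": "afrimgsm_common_yaml", "dataset_name": "eng", "task": "afrimgsm_direct_eng"}
print(yaml.dump(config, allow_unicode=True))
# dataset_name: eng
# include: afrimgsm_common_yaml
# task: afrimgsm_direct_eng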
+import re
+import sys
+import unicodedata
+
 from sklearn.metrics import f1_score

+from lm_eval.filters.extraction import RegexFilter


 def doc_to_choice(doc):
@@ -8,15 +13,15 @@ def doc_to_choice(doc):
 def doc_to_text(doc):
     output = """You are a highly knowledgeable and intelligent artificial intelligence
-    model answers multiple-choice questions about '{subject}'
+    model answers multiple-choice questions about {subject}

-    Question: '''{question}'''
+    Question: {question}

     Choices:
-    A: ''{choice1}'''
-    B: ''{choice2}'''
-    C: ''{choice3}'''
-    D: ''{choice4}'''
+    A: {choice1}
+    B: {choice2}
+    C: {choice3}
+    D: {choice4}

     Answer: """
...
# Generated by utils.py
dataset_name: amh
include: afrixnli_manual_direct_yaml
task: afrixnli_manual_direct_amh
# Generated by utils.py
dataset_name: eng
include: afrixnli_manual_direct_yaml
task: afrixnli_manual_direct_eng
# Generated by utils.py
dataset_name: ewe
include: afrixnli_manual_direct_yaml
task: afrixnli_manual_direct_ewe
# Generated by utils.py
dataset_name: fra
include: afrixnli_manual_direct_yaml
task: afrixnli_manual_direct_fra
# Generated by utils.py
dataset_name: hau
include: afrixnli_manual_direct_yaml
task: afrixnli_manual_direct_hau
# Generated by utils.py
dataset_name: ibo
include: afrixnli_manual_direct_yaml
task: afrixnli_manual_direct_ibo
# Generated by utils.py
dataset_name: kin
include: afrixnli_manual_direct_yaml
task: afrixnli_manual_direct_kin
# Generated by utils.py
dataset_name: lin
include: afrixnli_manual_direct_yaml
task: afrixnli_manual_direct_lin
# Generated by utils.py
dataset_name: lug
include: afrixnli_manual_direct_yaml
task: afrixnli_manual_direct_lug
# Generated by utils.py
dataset_name: orm
include: afrixnli_manual_direct_yaml
task: afrixnli_manual_direct_orm
# Generated by utils.py
dataset_name: sna
include: afrixnli_manual_direct_yaml
task: afrixnli_manual_direct_sna
# Generated by utils.py
dataset_name: sot
include: afrixnli_manual_direct_yaml
task: afrixnli_manual_direct_sot
# Generated by utils.py
dataset_name: swa
include: afrixnli_manual_direct_yaml
task: afrixnli_manual_direct_swa