"vscode:/vscode.git/clone" did not exist on "9b34672eec873e7bab8382c9e60063b2caa35eee"
Commit 753e8670 authored by JessicaOjo


add manual xnli prompt, add multichoice for openai models, and adapt multichoice metric for openai model
parent f720ce81
......@@ -3,7 +3,7 @@ from typing import Literal, Optional, Tuple
OutputType = Literal[
"loglikelihood", "loglikelihood_rolling", "generate_until", "multiple_choice"
"loglikelihood", "loglikelihood_rolling", "generate_until", "multiple_choice", "multiple_choice_gpt"
]
......
......@@ -23,7 +23,20 @@ def bypass_agg(arr):
@register_aggregation("mean")
def mean(arr):
return sum(arr) / len(arr)
if isinstance(arr[0], (list, np.ndarray)):
return sum(arr[0]) / len(arr[0])
else:
return sum(arr) / len(arr)
@register_aggregation("acc_gpt")
def acc_gpt(arr):
unzipped_list = list(zip(*arr))
golds = unzipped_list[0]
preds = unzipped_list[1]
accuracy = sklearn.metrics.accuracy_score(golds, preds)
return accuracy
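# Hedged sketch of the new aggregation's behaviour: each item is expected to be a
# (gold, prediction) pair collected per document, e.g.
#     acc_gpt([("B", "B"), ("A", "C"), ("D", "D")])
# unzips to golds ("B", "A", "D") and preds ("B", "C", "D") and returns
# sklearn.metrics.accuracy_score(golds, preds) == 2/3.
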
@register_aggregation("median")
......@@ -151,7 +164,7 @@ def brier_score_fn(items): # This is a passthrough function
@register_metric(
metric="acc",
higher_is_better=True,
output_type=["loglikelihood", "multiple_choice"],
output_type=["loglikelihood", "multiple_choice", "multiple_choice_gpt"],
aggregation="mean",
)
def acc_fn(items): # This is a passthrough function
......@@ -277,7 +290,7 @@ def mcc_fn(items): # This is a passthrough function
@register_metric(
metric="f1",
higher_is_better=True,
output_type="multiple_choice",
output_type=["multiple_choice", "multiple_choice_gpt"],
aggregation="f1",
)
def f1_fn(items): # This is a passthrough function
......
......@@ -87,6 +87,7 @@ DEFAULT_METRIC_REGISTRY = {
],
"loglikelihood_rolling": ["word_perplexity", "byte_perplexity", "bits_per_byte"],
"multiple_choice": ["acc", "acc_norm"],
"multiple_choice_gpt": ["acc"],
"generate_until": ["exact_match"],
}
......
......@@ -44,6 +44,7 @@ from lm_eval.prompts import get_prompt
ALL_OUTPUT_TYPES = [
"loglikelihood",
"multiple_choice",
"multiple_choice_gpt",
"loglikelihood_rolling",
"generate_until",
]
......@@ -1064,7 +1065,6 @@ class ConfigurableTask(Task):
eval_logger.warning("Applied prompt returns empty string")
return self.config.fewshot_delimiter
else:
print(type(doc_to_text))
raise TypeError
def doc_to_target(self, doc: Mapping) -> Union[int, str, list]:
......@@ -1142,7 +1142,7 @@ class ConfigurableTask(Task):
arguments = (ctx, self.doc_to_target(doc))
elif self.OUTPUT_TYPE == "loglikelihood_rolling":
arguments = (self.doc_to_target(doc),)
elif self.OUTPUT_TYPE == "multiple_choice":
elif "multiple_choice" in self.OUTPUT_TYPE:
choices = self.doc_to_choice(doc)
target_delimiter = self.config.target_delimiter
if self.multiple_input:
......@@ -1154,17 +1154,28 @@ class ConfigurableTask(Task):
else:
# Otherwise they are placed in the continuation
arguments = [(ctx, f"{target_delimiter}{cont}") for cont in choices]
request_list = [
    Instance(
        request_type="loglikelihood",
        doc=doc,
        arguments=arg,
        idx=i,
        **kwargs,
    )
    for i, arg in enumerate(arguments)
]
if self.OUTPUT_TYPE == "multiple_choice_gpt":
    request_list = [
        Instance(
            request_type="multiple_choice_gpt",
            doc=doc,
            arguments=arg,
            idx=i,
            **kwargs,
        )
        for i, arg in enumerate(arguments)
    ]
else:
    request_list = [
        Instance(
            request_type="loglikelihood",
            doc=doc,
            arguments=arg,
            idx=i,
            **kwargs,
        )
        for i, arg in enumerate(arguments)
    ]
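# Hedged illustration with hypothetical values: for a four-way multiple-choice doc with
# choices ["A", "B", "C", "D"] and target_delimiter " ", the branch above builds four
# Instances with arguments (ctx, " A"), (ctx, " B"), (ctx, " C"), (ctx, " D"); under
# "multiple_choice_gpt" they are tagged with that request_type instead of being scored
# as per-choice loglikelihood requests.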
# TODO: we should raise a warning telling users this will at most ~2x runtime.
if "acc_mutual_info" in self._metric_fn_list.keys():
# if we are calculating multiple choice accuracy
......@@ -1310,14 +1321,45 @@ class ConfigurableTask(Task):
acc_mutual_info = 1.0 if np.argmax(lls_mutual_info) == gold else 0.0
result_dict["acc_mutual_info"] = acc_mutual_info
elif self.OUTPUT_TYPE == "multiple_choice_gpt":
gold = self.doc_to_target(doc)
result = results[0]
choices = self.doc_to_choice(doc)
try:
gold = choices[gold]
gold = type(result)(gold)
except TypeError:
gold = gold
for metric in self._metric_fn_list.keys():
try:
result_score = self._metric_fn_list[metric](
references=[gold],
predictions=[result],
**self._metric_fn_kwargs[metric],
)
except (
TypeError
): # TODO: this is hacky and I don't want to do it
result_score = self._metric_fn_list[metric](
[gold, result]
)
if isinstance(result_score, dict):
# TODO: this handles the case where HF evaluate returns a dict.
result_score = result_score[metric]
result_dict[metric] = result_score
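# Hedged walk-through with hypothetical values: if doc_to_target(doc) returns the integer
# label 1, doc_to_choice(doc) returns ["A", "B", "C", "D"], and the chat model replied "B",
# then gold becomes choices[1] -> "B", is cast with type(result) so it matches the string
# prediction, and each metric is scored with references=["B"], predictions=["B"].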
elif self.OUTPUT_TYPE == "generate_until":
gold = self.doc_to_target(doc)
result = results[0]
if self.config.doc_to_choice is not None:
# If you set doc_to_choice,
# it assumes that doc_to_target returns a number.
choices = self.doc_to_choice(doc)
gold = choices[gold]
try:
# If you set doc_to_choice,
# it assumes that doc_to_target returns a number.
choices = self.doc_to_choice(doc)
gold = choices[gold]
except TypeError:
gold = gold
# we expect multiple_targets to be a list.
elif self.multiple_target:
gold = list(gold)
......@@ -1333,7 +1375,6 @@ class ConfigurableTask(Task):
scores = []
if not isinstance(gold, list):
# sometimes, a multiple_target dataset has exceptions where one doc has only one string answer
# print(gold)
gold = [gold]
if metric == "exact_match":
result = [result for _ in range(len(gold))]
......
......@@ -471,6 +471,52 @@ class OpenaiChatCompletionsLM(LM):
return grouper.get_original(res)
def multiple_choice_gpt(self, requests, disable_tqdm: bool = False) -> List[str]:
    res = defaultdict(list)
    re_ords = {}

    # we group requests by their generation_kwargs,
    # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
    # in the same batch.
    grouper = lm_eval.models.utils.Grouper(requests, lambda x: str(x.args[1]))
    for key, reqs in grouper.get_grouped().items():
        # within each set of reqs for given kwargs, we reorder by token length, descending.
        re_ords[key] = utils.Reorderer(
            [req.args for req in reqs], lambda x: (-len(x[0]), x[0])
        )

    pbar = tqdm(total=len(requests), disable=(disable_tqdm or (self.rank != 0)))
    for key, re_ord in re_ords.items():
        # n needs to be 1 because chat-completion messages are not batched;
        # each request is treated as a single conversation.
        chunks = lm_eval.models.utils.chunks(re_ord.get_reordered(), n=1)
        for chunk in chunks:
            contexts, all_gen_kwargs = zip(*chunk)
            inps = [{"role": "user", "content": context} for context in contexts]

            response = oa_completion(
                client=self.client,
                chat=True,
                messages=inps,
                model=self.model,
            )

            for resp, (context, args_) in zip(response.choices, chunk):
                s = resp.message.content
                res[key].append(s)

                self.cache_hook.add_partial(
                    "multiple_choice_gpt", context, s
                )
                pbar.update(1)
        # reorder this group of results back to original unsorted form
        res[key] = re_ord.get_original(res[key])

    pbar.close()

    return grouper.get_original(res)
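# Sketch of how results flow back (mirroring the generate_until plumbing this method
# appears to be adapted from): re_ord.get_original and grouper.get_original restore the
# original request order, so each request receives the raw chat-completion string, and
# that string is what ConfigurableTask.process_results sees as results[0] under the
# "multiple_choice_gpt" output type.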
def loglikelihood(self, requests, disable_tqdm: bool = False):
    raise NotImplementedError("No support for logits.")
......
import argparse
import yaml
languages = ['eng', 'amh', 'ibo', 'fra', 'sna', 'lin', 'wol', 'ewe', 'lug', 'xho', 'kin', 'twi', 'zul', 'orm', 'yor', 'hau', 'sot', 'swa']
LANGUAGES = {}
for lang in languages:
    LANGUAGES[lang] = {  # English
        "QUESTION": "Question:",
        "ANSWER": "Step-by-Step Answer:",
        "DIRECT": "Answer:",
        "REGEX": "The answer is (\\-?[0-9\\.\\,]+)",
    }
languages = ['eng', 'amh', 'ibo', 'fra', 'sna', 'lin', 'wol', 'ewe', 'lug', 'xho', 'kin', 'twi', 'zul', 'orm', 'yor', 'hau', 'sot', 'swa']
def add_regex_pattern(regex_pattern):
    if regex_pattern is None:
        return {}
    return {
        "filter_list": [
            {
                "name": "strict-match",
                "filter": [
                    {
                        "function": "regex",
                        "regex_pattern": f"""{regex_pattern}""",
                    },
                    {
                        "function": "take_first",
                    },
                ],
            },
            {
                "name": "flexible-extract",
                "filter": [
                    {
                        "function": "regex",
                        "regex_pattern": """(-?[$0-9.,]{2,})|(-?[0-9]+)""",
                        "group_select": -1,
                    },
                    {
                        "function": "take_first",
                    },
                ],
            },
        ],
    }
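# Example grounded in the function above: add_regex_pattern(None) returns {}, while
# add_regex_pattern("The answer is (\\-?[0-9\\.\\,]+)") returns a dict whose single
# "filter_list" entry holds two filters: "strict-match" (the given regex followed by
# take_first) and "flexible-extract" (a generic number regex with group_select=-1,
# followed by take_first).
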
configs = {
    "QUESTION": "Question:",
    "ANSWER": "Step-by-Step Answer:",
    "DIRECT": "Answer:",
    "REGEX": "The answer is (\\-?[0-9\\.\\,]+)",
}
def gen_lang_yamls(output_dir: str, overwrite: bool, mode: str) -> None:
......@@ -55,31 +18,19 @@ def gen_lang_yamls(output_dir: str, overwrite: bool, mode: str) -> None:
    :param overwrite: Whether to overwrite files if they already exist.
    """
    err = []
    for lang in LANGUAGES.keys():
    for lang in languages:
        try:
            QUESTION = LANGUAGES[lang]["QUESTION"]
            yaml_template = "cot_yaml"
            filter_list = {}
            DELIMITER = None
            if mode == "direct":
                ANSWER = LANGUAGES[lang]["DIRECT"]
                REGEX = None
                task_name = f"afrimgsm_direct_{lang}"
                yaml_template = "direct_yaml"
                yaml_template = "afrimgsm_common_yaml"
            elif mode == "native-cot":
                ANSWER = LANGUAGES[lang]["ANSWER"]
                REGEX = LANGUAGES[lang]["REGEX"]
                task_name = f"afrimgsm_native_cot_{lang}"
                filter_list = add_regex_pattern(REGEX)
                DELIMITER = "" if lang in ["zh", "ja"] else None
                yaml_template = "afrimgsm_common_yaml"
            elif mode == "en-cot":
                ANSWER = LANGUAGES["en"]["ANSWER"]
                REGEX = LANGUAGES["en"]["REGEX"]
                task_name = f"afrimgsm_en_cot_{lang}"
                yaml_template = "afrimgsm_common_yaml"
            file_name = f"{task_name}.yaml"
            ANSWER_TO_SKIP = len(LANGUAGES[lang]["ANSWER"]) + 1
            with open(
                f"{output_dir}/{file_name}", "w" if overwrite else "x", encoding="utf8"
            ) as f:
......@@ -88,23 +39,7 @@ def gen_lang_yamls(output_dir: str, overwrite: bool, mode: str) -> None:
                {
                    "include": yaml_template,
                    "dataset_name": lang,
                    "task": f"{task_name}",
                    "doc_to_text": f"""{{% if answer is not none %}}"""
                    f"""{{{{question+"\\n{ANSWER}"}}}}"""
                    f"""{{% else %}}"""
                    f"""{{{{"{QUESTION} "+question+"\\n{ANSWER}"}}}}"""
                    f"""{{% endif %}}""",
                    "doc_to_target": f"""{{% if answer is not none %}}"""
                    f"""{{{{answer[{ANSWER_TO_SKIP}:]}}}}"""
                    f"""{{% else %}}"""
                    f"""{{{{answer_number|string}}}}"""
                    f"""{{% endif %}}""",
                    **filter_list,
                    "generation_kwargs": {
                        "until": [QUESTION, "</s>", "<|im_end|>"],
                        "do_sample": False,
                    },
                    **({"target_delimiter": DELIMITER} if DELIMITER else {}),
                    "task": f"{task_name}"
                },
                f,
                allow_unicode=True,
......@@ -125,16 +60,16 @@ def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument(
"--overwrite",
default=False,
default=True,
action="store_true",
help="Overwrite files if they already exist",
)
parser.add_argument(
"--output-dir", default=".", help="Directory to write yaml files to"
"--output-dir", default="./direct", help="Directory to write yaml files to"
)
parser.add_argument(
"--mode",
default="native-cot",
default="direct",
choices=["direct", "native-cot", "en-cot"],
help="Mode of chain-of-thought",
)
......
import re
import sys
import unicodedata
from sklearn.metrics import f1_score
from lm_eval.filters.extraction import RegexFilter
def doc_to_choice(doc):
......@@ -8,15 +13,15 @@ def doc_to_choice(doc):
def doc_to_text(doc):
output = """You are a highly knowledgeable and intelligent artificial intelligence
model answers multiple-choice questions about '{subject}'
model answers multiple-choice questions about {subject}
Question: '''{question}'''
Question: {question}
Choices:
A: ''{choice1}'''
B: ''{choice2}'''
C: ''{choice3}'''
D: ''{choice4}'''
A: {choice1}
B: {choice2}
C: {choice3}
D: {choice4}
Answer: """
......
# Generated by utils.py
dataset_name: amh
include: afrixnli_manual_direct_yaml
task: afrixnli_manual_direct_amh
# Generated by utils.py
dataset_name: eng
include: afrixnli_manual_direct_yaml
task: afrixnli_manual_direct_eng
# Generated by utils.py
dataset_name: ewe
include: afrixnli_manual_direct_yaml
task: afrixnli_manual_direct_ewe
# Generated by utils.py
dataset_name: fra
include: afrixnli_manual_direct_yaml
task: afrixnli_manual_direct_fra
# Generated by utils.py
dataset_name: hau
include: afrixnli_manual_direct_yaml
task: afrixnli_manual_direct_hau
# Generated by utils.py
dataset_name: ibo
include: afrixnli_manual_direct_yaml
task: afrixnli_manual_direct_ibo
# Generated by utils.py
dataset_name: kin
include: afrixnli_manual_direct_yaml
task: afrixnli_manual_direct_kin
# Generated by utils.py
dataset_name: lin
include: afrixnli_manual_direct_yaml
task: afrixnli_manual_direct_lin
# Generated by utils.py
dataset_name: lug
include: afrixnli_manual_direct_yaml
task: afrixnli_manual_direct_lug
# Generated by utils.py
dataset_name: orm
include: afrixnli_manual_direct_yaml
task: afrixnli_manual_direct_orm
# Generated by utils.py
dataset_name: sna
include: afrixnli_manual_direct_yaml
task: afrixnli_manual_direct_sna
# Generated by utils.py
dataset_name: sot
include: afrixnli_manual_direct_yaml
task: afrixnli_manual_direct_sot
# Generated by utils.py
dataset_name: swa
include: afrixnli_manual_direct_yaml
task: afrixnli_manual_direct_swa