Commit 6efc8d5e (unverified), authored by Lintang Sutawika, committed by GitHub
Browse files

Merge pull request #702 from EleutherAI/num_fewshot-bug

[Refactor] Fixes for when using `num_fewshot`
parents 4e44f0aa e0b3cbf5
...@@ -130,6 +130,9 @@ class TaskConfig(dict): ...@@ -130,6 +130,9 @@ class TaskConfig(dict):
def __getitem__(self, item): def __getitem__(self, item):
return getattr(self, item) return getattr(self, item)
def __setitem__(self, item, value):
return setattr(self, item, value)
def to_dict(self): def to_dict(self):
"""dumps the current config as a dictionary object, as a printable format. """dumps the current config as a dictionary object, as a printable format.
null fields will not be printed. null fields will not be printed.
......
...@@ -35,7 +35,7 @@ def simple_evaluate( ...@@ -35,7 +35,7 @@ def simple_evaluate(
model, model,
model_args=None, model_args=None,
tasks=[], tasks=[],
num_fewshot=0, num_fewshot=None,
batch_size=None, batch_size=None,
max_batch_size=None, max_batch_size=None,
device=None, device=None,
...@@ -112,7 +112,17 @@ def simple_evaluate( ...@@ -112,7 +112,17 @@ def simple_evaluate(
+ "_rank" + str(lm.rank) + ".db", + "_rank" + str(lm.rank) + ".db",
) )
task_dict = lm_eval.tasks.get_task_dict(tasks, num_fewshot=num_fewshot) task_dict = lm_eval.tasks.get_task_dict(tasks)
for task_name in task_dict.keys():
config = task_dict[task_name]._config
if num_fewshot is not None:
if config["num_fewshot"] > 0:
default_num_fewshot = config["num_fewshot"]
eval_logger.warning(
f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {num_fewshot}"
)
task_dict[task_name]._config["num_fewshot"] = num_fewshot
if check_integrity: if check_integrity:
run_task_tests(task_list=tasks) run_task_tests(task_list=tasks)
...@@ -134,7 +144,6 @@ def simple_evaluate( ...@@ -134,7 +144,6 @@ def simple_evaluate(
if isinstance(model, str) if isinstance(model, str)
else model.model.config._name_or_path, else model.model.config._name_or_path,
"model_args": model_args, "model_args": model_args,
"num_fewshot": num_fewshot,
"batch_size": batch_size, "batch_size": batch_size,
"batch_sizes": list(lm.batch_sizes.values()) "batch_sizes": list(lm.batch_sizes.values())
if hasattr(lm, "batch_sizes") if hasattr(lm, "batch_sizes")
...@@ -169,8 +178,6 @@ def evaluate( ...@@ -169,8 +178,6 @@ def evaluate(
Language Model Language Model
:param task_dict: dict[str, Task] :param task_dict: dict[str, Task]
Dictionary of tasks. Tasks will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise. Dictionary of tasks. Tasks will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
:param num_fewshot: int
Number of examples in few-shot context
:param limit: int, optional :param limit: int, optional
Limit the number of examples per task (only use this for testing) Limit the number of examples per task (only use this for testing)
:param bootstrap_iters: :param bootstrap_iters:
......
...@@ -265,10 +265,20 @@ def make_table(result_dict): ...@@ -265,10 +265,20 @@ def make_table(result_dict):
md_writer = MarkdownTableWriter() md_writer = MarkdownTableWriter()
latex_writer = LatexTableWriter() latex_writer = LatexTableWriter()
md_writer.headers = ["Task", "Version", "Filter", "Metric", "Value", "", "Stderr"] md_writer.headers = [
"Task",
"Version",
"Fewshot",
"Filter",
"Metric",
"Value",
"",
"Stderr",
]
latex_writer.headers = [ latex_writer.headers = [
"Task", "Task",
"Version", "Version",
"Fewshot",
"Filter", "Filter",
"Metric", "Metric",
"Value", "Value",
...@@ -280,6 +290,7 @@ def make_table(result_dict): ...@@ -280,6 +290,7 @@ def make_table(result_dict):
for k, dic in result_dict["results"].items(): for k, dic in result_dict["results"].items():
version = result_dict["versions"][k] version = result_dict["versions"][k]
n = str(result_dict["configs"][k]["num_fewshot"])
for (mf), v in dic.items(): for (mf), v in dic.items():
m, _, f = mf.partition(",") m, _, f = mf.partition(",")
if m.endswith("_stderr"): if m.endswith("_stderr"):
...@@ -287,10 +298,11 @@ def make_table(result_dict): ...@@ -287,10 +298,11 @@ def make_table(result_dict):
if m + "_stderr" + "," + f in dic: if m + "_stderr" + "," + f in dic:
se = dic[m + "_stderr" + "," + f] se = dic[m + "_stderr" + "," + f]
values.append([k, version, f, m, "%.4f" % v, "±", "%.4f" % se]) values.append([k, version, n, f, m, "%.4f" % v, "±", "%.4f" % se])
else: else:
values.append([k, version, f, m, "%.4f" % v, "", ""]) values.append([k, version, n, f, m, "%.4f" % v, "", ""])
k = "" k = ""
n = ""
version = "" version = ""
md_writer.value_matrix = values md_writer.value_matrix = values
latex_writer.value_matrix = values latex_writer.value_matrix = values
......
...@@ -28,7 +28,7 @@ def parse_args(): ...@@ -28,7 +28,7 @@ def parse_args():
parser.add_argument( parser.add_argument(
"--num_fewshot", "--num_fewshot",
type=int, type=int,
default=0, default=None,
help="Number of examples in few-shot context", help="Number of examples in few-shot context",
) )
parser.add_argument("--batch_size", type=int, default=1) # TODO: only integers parser.add_argument("--batch_size", type=int, default=1) # TODO: only integers
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment