Commit 4e5a328e authored by haileyschoelkopf

bugfixes + add write_out

parent ff8903f2
@@ -362,10 +362,3 @@ def stderr_for_metric(metric, bootstrap_iters):
     stderr = {mean: mean_stderr, acc_all: acc_all_stderr}
     return stderr.get(metric, None)
-
-
-def yesno(x):
-    if x:
-        return "yes"
-    else:
-        return "no"
@@ -63,7 +63,7 @@ class TaskConfig(dict):
     fewshot_split: str = None  # TODO: assert that this is not None if num_fewshot > 0 (?); assert whether this is the same split as the one being evaluated (?)
     # formatting / prompting options.
     # see docs/advanced_task_guide.md for more info
-    template_aliases: str = None
+    template_aliases: str = ""
     doc_to_text: Union[Callable, str] = None
     doc_to_target: Union[Callable, str] = None
     gold_alias: Union[Callable, str] = None
@@ -89,7 +89,7 @@ class TaskConfig(dict):
         # allow user-specified aliases so that users can
         # force prompt-compatibility for some prompt regardless of
         # field names in prompt
-        if self.template_aliases is not None:
+        if type(self.template_aliases) == str:
             if type(self.doc_to_text) == str:
                 self.doc_to_text = self.template_aliases + self.doc_to_text
...
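The two TaskConfig changes above work together: template_aliases now defaults to "" rather than None, and the guard checks that it is a string before prepending it to doc_to_text, so the concatenation can never see None. A minimal sketch of the failure this avoids (the Jinja-style template strings below are made up for illustration):

template_aliases = ""                         # new default: empty string, not None
doc_to_text = "{{question}}\nAnswer:"         # hypothetical prompt template
doc_to_text = template_aliases + doc_to_text  # "" + str is fine; None + str raises TypeError
print(doc_to_text)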
@@ -199,6 +199,19 @@ def evaluate(
         task.build_all_requests(limit=limit, rank=lm.rank, world_size=lm.world_size)
+        eval_logger.info(
+            f"Task: {task_name}; number of requests on this rank: {len(task.instances)}"
+        )
+
+        if write_out:
+            for inst in task.instances:
+                # print the prompt for the first few documents
+                if inst.doc_id < 4:
+                    print(
+                        f"Task: {task_name}; document {inst.doc_id}; context prompt (starting on next line):\n{inst.args[0]}\n(end of prompt on previous line)"
+                    )
+                    print("Request:", inst)
+
         # aggregate Instances by LM method requested to get output.
         reqtype = (
             "loglikelihood"
@@ -338,16 +351,16 @@ def evaluate(
                 # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap,
                 # so we run them for fewer iterations. still looking for a cleaner way to do this
-                stderr = lm_eval.api.metrics.stderr_for_metric(
-                    metric=task.aggregation()[metric],
-                    bootstrap_iters=min(bootstrap_iters, 1000)
-                    if metric in ["bleu", "chrf", "ter"]
-                    else bootstrap_iters,
-                )
-                if stderr is not None:
-                    results[task_name][metric + "_stderr" + "," + key] = stderr(items)
+                if bootstrap_iters > 0:
+                    stderr = lm_eval.api.metrics.stderr_for_metric(
+                        metric=task.aggregation()[metric],
+                        bootstrap_iters=min(bootstrap_iters, 1000)
+                        if metric in ["bleu", "chrf", "ter"]
+                        else bootstrap_iters,
+                    )
+
+                    if stderr is not None:
+                        results[task_name][metric + "_stderr" + "," + key] = stderr(items)
     return {
         "results": dict(results),
...
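The net effect of the new guard: bootstrap_iters=0 now disables bootstrapped stderr computation entirely, while bleu/chrf/ter remain capped at 1000 iterations because they are expensive to bootstrap. A sketch of just the iteration-selection logic (the helper name is mine, not from the codebase):

def effective_bootstrap_iters(metric, bootstrap_iters):
    # returns the iteration count to use; 0 means skip stderr entirely
    if bootstrap_iters <= 0:
        return 0  # new behavior from this commit
    # bleu/chrf/ter are costly to bootstrap, so cap them at 1000 iterations
    if metric in ["bleu", "chrf", "ter"]:
        return min(bootstrap_iters, 1000)
    return bootstrap_iters

assert effective_bootstrap_iters("bleu", 100000) == 1000
assert effective_bootstrap_iters("acc", 100000) == 100000
assert effective_bootstrap_iters("ter", 0) == 0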