Commit b8122d98 authored by lintangsutawika

pre-commit

parent 01b129bb
-import datasets
 from functools import partial
+import datasets

 class ContextSampler:
     def __init__(self, docs, task, fewshot_indices=None, rnd=None) -> None:
         self.rnd = rnd
@@ -15,27 +17,36 @@ class ContextSampler:
         self.target_delimiter = self.config.target_delimiter
         self.fewshot_delimiter = self.config.fewshot_delimiter

-        if self.config.fewshot_config is not None and self.config.fewshot_config.get("doc_to_text", None) is not None:
+        if (
+            self.config.fewshot_config is not None
+            and self.config.fewshot_config.get("doc_to_text", None) is not None
+        ):
             self.doc_to_text = partial(
                 self.task.doc_to_text,
-                doc_to_text=self.config.fewshot_config.get("doc_to_text", None)
+                doc_to_text=self.config.fewshot_config.get("doc_to_text", None),
             )
         else:
             self.doc_to_text = self.task.doc_to_text

-        if self.config.fewshot_config is not None and self.config.fewshot_config.get("doc_to_target", None) is not None:
+        if (
+            self.config.fewshot_config is not None
+            and self.config.fewshot_config.get("doc_to_target", None) is not None
+        ):
             self.doc_to_target = partial(
                 self.task.doc_to_target,
-                doc_to_target=self.config.fewshot_config.get("doc_to_target", None)
+                doc_to_target=self.config.fewshot_config.get("doc_to_target", None),
             )
         else:
             self.doc_to_target = self.task.doc_to_target

-        if self.config.fewshot_config is not None and self.config.fewshot_config.get("doc_to_choice", None) is not None:
+        if (
+            self.config.fewshot_config is not None
+            and self.config.fewshot_config.get("doc_to_choice", None) is not None
+        ):
             self.doc_to_choice = partial(
                 self.task.doc_to_choice,
-                doc_to_choice=self.config.fewshot_config.get("doc_to_choice", None)
+                doc_to_choice=self.config.fewshot_config.get("doc_to_choice", None),
             )
         else:
             self.doc_to_choice = self.task.doc_to_choice
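The `partial` calls above pre-bind the fewshot-specific template as a keyword argument, so downstream code can keep calling `self.doc_to_text(doc)` with a single positional argument whether or not an override is configured. A minimal, self-contained sketch of that pattern (the `Task` class and templates here are hypothetical, not from this repository):

from functools import partial


class Task:
    # Hypothetical stand-in for the harness's task object.
    def doc_to_text(self, doc, doc_to_text=None):
        # Use the pre-bound override if one was supplied, else the default.
        template = doc_to_text or "Q: {question}\nA:"
        return template.format(**doc)


task = Task()
# Mirrors the __init__ above: bind the override once, call uniformly later.
fewshot_doc_to_text = partial(task.doc_to_text, doc_to_text="{question} =>")

doc = {"question": "2 + 2"}
print(task.doc_to_text(doc))     # "Q: 2 + 2\nA:"
print(fewshot_doc_to_text(doc))  # "2 + 2 =>"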
@@ -72,7 +83,7 @@ class ContextSampler:
                 else self.doc_to_choice(doc)[doc_content]
             )
             labeled_examples += self.target_delimiter
-            if doc_target is not "":
+            if doc_target != "":
                 labeled_examples += (
                     str(doc_target[0])
                     if isinstance(doc_target, list)
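The `doc_target is not ""` comparison replaced here was a genuine bug, not just a style issue: `is not` tests object identity, so its result depends on CPython's string interning rather than on the string's contents, and Python 3.8+ emits `SyntaxWarning: "is not" with a literal` for it. A quick illustration:

s = "".join(["ab", "ab"])  # "abab" constructed at runtime

print(s == "abab")  # True: equality compares the characters
print(s != "")      # True: the corrected check in the diff above
print(s is "abab")  # typically False in CPython: identity compares object ids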
......
@@ -608,16 +608,16 @@ def evaluate(
             ]
             # compute group's pooled metric and stderr
-            results[group][
-                metric
-            ] = lm_eval.api.metrics.aggregate_subtask_metrics(metrics, sizes)
+            results[group][metric] = (
+                lm_eval.api.metrics.aggregate_subtask_metrics(metrics, sizes)
+            )
             # TODO: calculate grouped metric using aggregation fn
             if "N/A" in stderrs:
                 results[group][stderr] = "N/A"
             else:
-                results[group][
-                    stderr
-                ] = lm_eval.api.metrics.pooled_sample_stderr(stderrs, sizes)
+                results[group][stderr] = (
+                    lm_eval.api.metrics.pooled_sample_stderr(stderrs, sizes)
+                )
             # TODO: allow GroupConfigs to choose which variance formula is used, for back-compatibility
             # To use the old (likely incorrect) variance formula, comment out the above and uncomment this line:
             # results[group][stderr] = lm_eval.api.metrics.combined_sample_stderr(stderrs, sizes, metrics=metrics)
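For readers unfamiliar with the two helpers being reformatted: `aggregate_subtask_metrics` pools per-subtask metrics into a group score, and `pooled_sample_stderr` pools their standard errors. A simplified sketch of size-weighted pooling in that spirit follows; it is an assumption-laden illustration, not the harness's actual formulas:

import math


def weighted_mean(metrics, sizes):
    # Each subtask contributes in proportion to its sample count.
    return sum(m * n for m, n in zip(metrics, sizes)) / sum(sizes)


def pooled_stderr(stderrs, sizes):
    # Recover each subtask's per-sample variance (var_i ~= se_i**2 * n_i),
    # pool the variances weighted by sample count, then convert back to a
    # standard error of the pooled mean.
    total = sum(sizes)
    pooled_var = sum(se**2 * n * n for se, n in zip(stderrs, sizes)) / total
    return math.sqrt(pooled_var / total)


metrics = [0.62, 0.71]    # illustrative per-subtask scores
stderrs = [0.020, 0.015]  # illustrative per-subtask standard errors
sizes = [400, 600]        # illustrative per-subtask sample counts

print(weighted_mean(metrics, sizes))            # 0.674
print(round(pooled_stderr(stderrs, sizes), 4))  # ~0.012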
......
@@ -275,9 +275,9 @@ def consolidate_results(
                 metric_key
             ]
             results[task_output.task_name]["samples"] = task_output.sample_len
-            results[task_output.task_name][
-                f"{metric}_stderr,{filter_key}"
-            ] = task_output.agg_metrics[f"{metric}_stderr,{filter_key}"]
+            results[task_output.task_name][f"{metric}_stderr,{filter_key}"] = (
+                task_output.agg_metrics[f"{metric}_stderr,{filter_key}"]
+            )
     return results, samples, configs, versions, num_fewshot, higher_is_better
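The `f"{metric}_stderr,{filter_key}"` key follows the harness's "metric,filter" naming convention for aggregate results: the metric name (with an optional `_stderr` suffix) joined by a comma to the filter that produced the responses. An illustrative dictionary shape (the numbers are invented):

results = {
    "hellaswag": {
        "acc,none": 0.571,          # metric under the default "none" filter
        "acc_stderr,none": 0.0049,  # its standard error
        "samples": 10042,
    }
}

metric, filter_key = "acc", "none"
print(results["hellaswag"][f"{metric}_stderr,{filter_key}"])  # 0.0049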
......