Commit dadfd4a8 authored by Herbie Bradley

Add logiqa calibration

parent 1c5a73c9
import abc
import ast
import functools
import itertools
import random
import re
from collections.abc import Callable
from dataclasses import asdict, dataclass, field
from typing import Any, List, Literal, Tuple, Union

import datasets
import evaluate
import numpy as np
import scipy.special as sp
import yaml
from tqdm import tqdm
from lm_eval import utils
from lm_eval.api import samplers
from lm_eval.api.filter import FilterEnsemble
from lm_eval.api.instance import Instance
from lm_eval.api.metrics import (
    bits_per_byte,
    mean,
    metric_max_over_ground_truths,
    weighted_perplexity,
)
from lm_eval.api.registry import (
    AGGREGATION_REGISTRY,
    DEFAULT_METRIC_REGISTRY,
    OUTPUT_TYPE_REGISTRY,
    get_aggregation,
    get_default_aggregation,
    get_metric,
    is_higher_better,
)
from lm_eval.filters import build_filter_ensemble
from lm_eval.logger import eval_logger
from lm_eval.prompts import get_prompt
ALL_OUTPUT_TYPES = [
"loglikelihood",
@@ -89,7 +87,6 @@ class TaskConfig(dict):
    metadata: str = None  # by default, not used in the code. allows users to pass arbitrary info to tasks

    def __post_init__(self):
        if self.generation_kwargs is not None:
            if self.output_type != "greedy_until":
                eval_logger.warning(
@@ -472,7 +469,6 @@ class Task(abc.ABC):
        return labeled_examples + str(example)

    def apply_filters(self):
        if hasattr(self, "_filters"):
            for f in self._filters:
                f.apply(self._instances)
@@ -570,7 +566,6 @@ class ConfigurableTask(Task):
                    "aggregation"
                ]
            else:
                metric_agg = get_default_aggregation(metric_name)
                eval_logger.warning(
@@ -656,8 +651,9 @@ class ConfigurableTask(Task):
        if type(test_target) is list:
            self.multiple_target = len(test_target)
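
        # Calibration bookkeeping: process_results() appends one
        # (normalized probability, correctness) pair per answer choice.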
        self.calibrations: list = []

    def download(self, dataset_kwargs=None):
        self.dataset = datasets.load_dataset(
            path=self.DATASET_PATH,
            name=self.DATASET_NAME,
@@ -740,7 +736,6 @@
        return doc

    def doc_to_text(self, doc):
        if self.prompt is not None:
            doc_to_text = self.prompt
        else:
@@ -775,7 +770,6 @@
            raise TypeError

    def doc_to_target(self, doc: dict) -> Union[int, str, list]:
        if self.prompt is not None:
            doc_to_target = self.prompt
        else:
@@ -817,14 +811,12 @@
            raise TypeError

    def doc_to_choice(self, doc: Any) -> List[str]:
        if self.prompt is not None:
            doc_to_choice = self.prompt
        elif self._config.doc_to_choice is None:
            eval_logger.error("doc_to_choice was called but not set in config")
        else:
            doc_to_choice = self._config.doc_to_choice

        if type(doc_to_choice) == str:
            return ast.literal_eval(utils.apply_template(doc_to_choice, doc))
        elif type(doc_to_choice) == list:
@@ -861,13 +853,11 @@
    def construct_requests(
        self, doc: dict, ctx: str, **kwargs
    ) -> Union[List[Instance], Instance]:
        if self.OUTPUT_TYPE == "loglikelihood":
            arguments = (ctx, self.doc_to_target(doc))
        elif self.OUTPUT_TYPE == "loglikelihood_rolling":
            arguments = (self.doc_to_target(doc),)
        elif self.OUTPUT_TYPE == "multiple_choice":
            choices = self.doc_to_choice(doc)
            target_delimiter = self._config.target_delimiter
            if self.multiple_input:
@@ -918,7 +908,6 @@
        )

    def process_results(self, doc, results):
        if callable(self._config.process_results):
            return self._config.process_results(doc, results)
@@ -953,7 +942,6 @@
                ),
            }

        elif self.OUTPUT_TYPE == "multiple_choice":
            lls, is_greedy = zip(*results)
            # retrieve choices in List[str] form, to compute choice lengths, etc.
@@ -980,13 +968,18 @@
            gold = self.doc_to_target(doc)
            if type(gold) is str:
                gold = choices.index(gold)

            # Convert lls from log-likelihoods to probabilities normalized
            # across the answer choices (a softmax over lls).
            norm_probs = np.exp(np.array(lls) - sp.logsumexp(lls))

            if self.multiple_target:
                acc = 1.0 if pred in gold else 0.0
                acc_norm = 1.0 if pred_norm in gold else 0.0
                exact_match = int(any([is_greedy[i] for i in gold]))
            else:
                acc = 1.0 if pred == gold else 0.0
                # Record one (confidence, correctness) pair per answer choice;
                # evaluate() bins these to draw the calibration curve.
                for i in range(len(choices)):
                    calib_score = 1.0 if i == gold else 0.0
                    self.calibrations.append((norm_probs[i], calib_score))
                acc_norm = 1.0 if pred_norm == gold else 0.0
                # TODO: this gets a score of 0 on arc_challenge for pythia-70m; need to test that this works properly
                exact_match = int(is_greedy[gold])
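                # NOTE: calibration pairs are currently only collected on this
                # single-target path, not in the multiple_target branch above.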
@@ -1007,7 +1000,6 @@
            result_dict["acc_mutual_info"] = acc_mutual_info

        elif self.OUTPUT_TYPE == "greedy_until":
            gold = self.doc_to_target(doc)
            if self._config.doc_to_choice is not None:
                # If you set doc_to_choice,
......
import collections
import itertools
import json
import logging
import random
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch

import lm_eval.api
import lm_eval.api.metrics
import lm_eval.api.registry
import lm_eval.benchmarks
import lm_eval.models
import lm_eval.tasks
from lm_eval.logger import eval_logger
from lm_eval.utils import (
    create_iterator,
    get_git_commit_hash,
    make_table,
    positional_deprecated,
    run_task_tests,
)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(logging.StreamHandler(sys.stdout))
@@ -117,7 +115,6 @@ def simple_evaluate(
    task_dict = lm_eval.tasks.get_task_dict(tasks)
    for task_name in task_dict.keys():
        task_obj = task_dict[task_name]
        if type(task_obj) == tuple:
            group, task_obj = task_obj
@@ -224,7 +221,6 @@ def evaluate(
    # get lists of each type of request
    for task_name, task in task_dict.items():
        if type(task) == tuple:
            group, task = task
            task_groups[task_name] = group
@@ -345,11 +341,36 @@ def evaluate(
            for metric, value in metrics.items():
                vals[(task_name, key, metric)].append(value)
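
        # Reliability diagram: sort each task's (probability, correctness)
        # pairs by model confidence, then split them into equal-count bins.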
        calibs = sorted(task.calibrations, key=lambda x: x[0])
        def bin_list_into_subsets(input_list, num_subsets=10):
            subset_size = len(input_list) // num_subsets
            remainder = len(input_list) % num_subsets
            subsets = []
            start = 0
            for _ in range(num_subsets):
                subset_end = start + subset_size + (1 if remainder > 0 else 0)
                subsets.append(input_list[start:subset_end])
                start = subset_end
                remainder -= 1
            return subsets
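        # e.g. 23 pairs split into 10 bins gives sizes [3, 3, 3, 2, 2, 2, 2, 2, 2, 2]:
        # the first `remainder` bins each absorb one extra item.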
        subsets = bin_list_into_subsets(calibs, 10)
        # Per-bin mean confidence (x) against per-bin empirical accuracy (y).
        x_coords = [np.mean([x[0] for x in subset]) for subset in subsets]
        y_coords = [np.mean([x[1] for x in subset]) for subset in subsets]

        # Assumes a HuggingFace-style "org/model" identifier.
        model_name = lm.config._name_or_path.split("/")[1]
        plt.plot(x_coords, y_coords, label=model_name)
        plt.plot([0, 1], [0, 1], linestyle="--", color="black")
        plt.xlabel("Probabilities")
        plt.ylabel("Frequencies")
        plt.title("Calibration")
        plt.legend()
        plt.savefig(f"{model_name}-long.png")
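        # A perfectly calibrated model tracks the dashed y = x diagonal:
        # answers given with confidence p are correct about a fraction p of the time.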
    if lm.world_size > 1:
        # if multigpu, then gather data across all ranks
        # first gather logged samples across all ranks
        for task_name, task_samples in list(samples.items()):
            full_samples = [None] * lm.world_size
            torch.distributed.all_gather_object(full_samples, task_samples)
@@ -358,7 +379,6 @@
        # then collect metrics across all ranks
        vals_torch = collections.defaultdict(list)
        for (task_name, key, metric), items in vals.items():
            numitem = 0
            if type(items[0]) == tuple:
                numitem = len(items[0])
......
@@ -5,7 +5,8 @@ output_type: multiple_choice
training_split: train
validation_split: validation
test_split: test
doc_to_choice: "{{options}}"
num_fewshot: 5
doc_to_choice: !function utils_logiqa.doc_to_choice
doc_to_text: !function utils_logiqa.doc_to_text
doc_to_target: !function utils_logiqa.doc_to_target
doc_to_decontamination_query: "{{context}}"
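# doc_to_text, doc_to_target, and doc_to_choice are provided by utils_logiqa.py
# so the "(A)"-style choice labels stay in sync with the prompt format.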
......
@@ -4,17 +4,17 @@ def doc_to_text(doc) -> str:
    Passage: <passage>
    Question: <question>
    Choices:
    (A) <choice1>
    (B) <choice2>
    (C) <choice3>
    (D) <choice4>
    Answer:
    """
    choices = ["a", "b", "c", "d"]
    prompt = "Passage: " + doc["context"] + "\n"
    prompt += "Question: " + doc["question"] + "\nChoices:\n"
    for choice, option in zip(choices, doc["options"]):
        prompt += f"({choice.upper()}) {option}\n"
    prompt += "Answer:"
    return prompt
@@ -22,3 +22,7 @@ def doc_to_text(doc) -> str:
def doc_to_target(doc) -> int:
    choices = ["a", "b", "c", "d"]
    return choices.index(doc["label"].strip())


def doc_to_choice(doc):
    # Choice labels mirror the "(A)" option format produced by doc_to_text.
    return ["(A)", "(B)", "(C)", "(D)"]