Commit dadfd4a8 authored by Herbie Bradley

Add logiqa calibration

parent 1c5a73c9
import abc
import ast
import functools
import itertools
import random
import re
from collections.abc import Callable
from dataclasses import asdict, dataclass, field
from typing import Any, List, Literal, Tuple, Union

import datasets
import evaluate
import numpy as np
import scipy.special as sp
import yaml
from tqdm import tqdm

from lm_eval import utils
from lm_eval.api import samplers
from lm_eval.api.filter import FilterEnsemble
from lm_eval.api.instance import Instance
from lm_eval.api.metrics import (
    bits_per_byte,
    mean,
    metric_max_over_ground_truths,
    weighted_perplexity,
)
from lm_eval.api.registry import (
    AGGREGATION_REGISTRY,
    DEFAULT_METRIC_REGISTRY,
    OUTPUT_TYPE_REGISTRY,
    get_aggregation,
    get_default_aggregation,
    get_metric,
    is_higher_better,
)
from lm_eval.filters import build_filter_ensemble
from lm_eval.logger import eval_logger
from lm_eval.prompts import get_prompt

ALL_OUTPUT_TYPES = [
    "loglikelihood",
@@ -89,7 +87,6 @@ class TaskConfig(dict):
    metadata: str = None  # by default, not used in the code. allows for users to pass arbitrary info to tasks

    def __post_init__(self):
        if self.generation_kwargs is not None:
            if self.output_type != "greedy_until":
                eval_logger.warning(
@@ -472,7 +469,6 @@ class Task(abc.ABC):
        return labeled_examples + str(example)

    def apply_filters(self):
        if hasattr(self, "_filters"):
            for f in self._filters:
                f.apply(self._instances)
@@ -570,7 +566,6 @@ class ConfigurableTask(Task):
                        "aggregation"
                    ]
                else:
                    INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()}
                    metric_agg = get_default_aggregation(metric_name)
                    eval_logger.warning(
@@ -656,8 +651,9 @@ class ConfigurableTask(Task):
        if type(test_target) is list:
            self.multiple_target = len(test_target)

        # (normalized probability, correctness) pairs collected in process_results,
        # consumed by the calibration plot in the evaluator.
        self.calibrations: list = []

    def download(self, dataset_kwargs=None):
        self.dataset = datasets.load_dataset(
            path=self.DATASET_PATH,
            name=self.DATASET_NAME,
@@ -740,7 +736,6 @@ class ConfigurableTask(Task):
        return doc

    def doc_to_text(self, doc):
        if self.prompt is not None:
            doc_to_text = self.prompt
        else:
@@ -775,7 +770,6 @@ class ConfigurableTask(Task):
            raise TypeError

    def doc_to_target(self, doc: dict) -> Union[int, str, list]:
        if self.prompt is not None:
            doc_to_target = self.prompt
        else:
@@ -817,14 +811,12 @@ class ConfigurableTask(Task):
            raise TypeError

    def doc_to_choice(self, doc: Any) -> List[str]:
        if self.prompt is not None:
            doc_to_choice = self.prompt
        elif self._config.doc_to_choice is None:
            eval_logger.error("doc_to_choice was called but not set in config")
        else:
            doc_to_choice = self._config.doc_to_choice

        if type(doc_to_choice) == str:
            return ast.literal_eval(utils.apply_template(doc_to_choice, doc))
        elif type(doc_to_choice) == list:
@@ -861,13 +853,11 @@ class ConfigurableTask(Task):
    def construct_requests(
        self, doc: dict, ctx: str, **kwargs
    ) -> Union[List[Instance], Instance]:
        if self.OUTPUT_TYPE == "loglikelihood":
            arguments = (ctx, self.doc_to_target(doc))
        elif self.OUTPUT_TYPE == "loglikelihood_rolling":
            arguments = (self.doc_to_target(doc),)
        elif self.OUTPUT_TYPE == "multiple_choice":
            choices = self.doc_to_choice(doc)
            target_delimiter = self._config.target_delimiter
            if self.multiple_input:
@@ -918,7 +908,6 @@ class ConfigurableTask(Task):
        )

    def process_results(self, doc, results):
        if callable(self._config.process_results):
            return self._config.process_results(doc, results)
@@ -953,7 +942,6 @@ class ConfigurableTask(Task):
                ),
            }
        elif self.OUTPUT_TYPE == "multiple_choice":
            lls, is_greedy = zip(*results)
            # retrieve choices in List[str] form, to compute choice lengths, etc.
@@ -980,13 +968,18 @@ class ConfigurableTask(Task):
            gold = self.doc_to_target(doc)
            if type(gold) is str:
                gold = choices.index(gold)

            # Convert lls from log-probabilities to normalized probabilities
            norm_probs = np.exp(lls - sp.logsumexp(lls))

            if self.multiple_target:
                acc = 1.0 if pred in gold else 0.0
                acc_norm = 1.0 if pred_norm in gold else 0.0
                exact_match = int(any([is_greedy[i] for i in gold]))
            else:
                acc = 1.0 if pred == gold else 0.0
                # Record one (normalized probability, correctness) pair per answer
                # choice for the calibration plot built in the evaluator.
                for i, choice in enumerate(choices):
                    calib_score = 1.0 if i == gold else 0.0
                    self.calibrations.append((norm_probs[i], calib_score))
                acc_norm = 1.0 if pred_norm == gold else 0.0
                # TODO: this gets score of 0 on arc_challenge for pythia-70m. need to test that this works properly
                exact_match = int(is_greedy[gold])
@@ -1007,7 +1000,6 @@ class ConfigurableTask(Task):
                result_dict["acc_mutual_info"] = acc_mutual_info

        elif self.OUTPUT_TYPE == "greedy_until":
            gold = self.doc_to_target(doc)
            if self._config.doc_to_choice is not None:
                # If you set doc_to_choice,
...
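For context, the calibration data gathered in process_results above reduces to this: convert the per-choice log-likelihoods into a probability distribution with a numerically stable softmax, then pair each choice's probability with whether it is the gold answer. A minimal standalone sketch, where the log-likelihood values and gold index are made up for illustration:

import numpy as np
import scipy.special as sp

# Made-up per-choice log-likelihoods for one 4-way multiple-choice document.
lls = np.array([-4.2, -1.3, -3.8, -2.9])

# Numerically stable softmax: subtracting logsumexp before exponentiating
# avoids overflow and guarantees the probabilities sum to 1.
norm_probs = np.exp(lls - sp.logsumexp(lls))
assert np.isclose(norm_probs.sum(), 1.0)

gold = 1  # hypothetical index of the correct choice
calibrations = [(p, 1.0 if i == gold else 0.0) for i, p in enumerate(norm_probs)]
# roughly [(0.04, 0.0), (0.75, 1.0), (0.06, 0.0), (0.15, 0.0)]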
import collections
import itertools
import json
import logging
import random
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch

import lm_eval.api
import lm_eval.api.metrics
import lm_eval.api.registry
import lm_eval.benchmarks
import lm_eval.models
import lm_eval.tasks
from lm_eval.logger import eval_logger
from lm_eval.utils import (
    create_iterator,
    get_git_commit_hash,
    make_table,
    positional_deprecated,
    run_task_tests,
)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(logging.StreamHandler(sys.stdout))
@@ -117,7 +115,6 @@ def simple_evaluate(
    task_dict = lm_eval.tasks.get_task_dict(tasks)
    for task_name in task_dict.keys():
        task_obj = task_dict[task_name]
        if type(task_obj) == tuple:
            group, task_obj = task_obj
@@ -224,7 +221,6 @@ def evaluate(
    # get lists of each type of request
    for task_name, task in task_dict.items():
        if type(task) == tuple:
            group, task = task
            task_groups[task_name] = group
@@ -345,11 +341,36 @@ def evaluate(
                for metric, value in metrics.items():
                    vals[(task_name, key, metric)].append(value)
        # Reliability diagram for this task: bin the per-choice (probability,
        # correctness) pairs and plot mean predicted probability against
        # empirical accuracy in each bin.
        calibs = sorted(task.calibrations, key=lambda x: x[0])

        def bin_list_into_subsets(input_list, num_subsets=10):
            subset_size = len(input_list) // num_subsets
            remainder = len(input_list) % num_subsets
            subsets = []
            start = 0
            for _ in range(num_subsets):
                subset_end = start + subset_size + (1 if remainder > 0 else 0)
                subsets.append(input_list[start:subset_end])
                start = subset_end
                remainder -= 1
            return subsets

        subsets = bin_list_into_subsets(calibs, 10)
        x_coords = [np.mean([x[0] for x in subset]) for subset in subsets]
        y_coords = [np.mean([x[1] for x in subset]) for subset in subsets]
        model_name = lm.config._name_or_path.split("/")[1]
        plt.plot(x_coords, y_coords, label=model_name)
        plt.plot([0, 1], [0, 1], linestyle="--", color="black")  # perfect calibration
        plt.xlabel("Probabilities")
        plt.ylabel("Frequencies")
        plt.title("Calibration")
        plt.legend()
        plt.savefig(f"{model_name}-long.png")
    if lm.world_size > 1:
        # if multigpu, then gather data across all ranks
        # first gather logged samples across all ranks
        for task_name, task_samples in list(samples.items()):
            full_samples = [None] * lm.world_size
            torch.distributed.all_gather_object(full_samples, task_samples)
@@ -358,7 +379,6 @@ def evaluate(
        # then collect metrics across all ranks
        vals_torch = collections.defaultdict(list)
        for (task_name, key, metric), items in vals.items():
            numitem = 0
            if type(items[0]) == tuple:
                numitem = len(items[0])
...
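As a reference for the plotting code above, here is a standalone sketch of the binning step on synthetic data: bin_list_into_subsets is reproduced from the diff, while the (probability, correctness) pairs are randomly generated to stand in for task.calibrations.

import random

import numpy as np


def bin_list_into_subsets(input_list, num_subsets=10):
    # Split a sorted list into num_subsets nearly equal-sized bins.
    subset_size = len(input_list) // num_subsets
    remainder = len(input_list) % num_subsets
    subsets, start = [], 0
    for _ in range(num_subsets):
        end = start + subset_size + (1 if remainder > 0 else 0)
        subsets.append(input_list[start:end])
        start = end
        remainder -= 1
    return subsets


# Synthetic (predicted probability, correctness) pairs from a roughly calibrated
# predictor: a prediction made with probability p is correct about p of the time.
random.seed(0)
calibs = []
for _ in range(1000):
    p = random.random()
    calibs.append((p, 1.0 if random.random() < p else 0.0))
calibs.sort(key=lambda x: x[0])

subsets = bin_list_into_subsets(calibs, 10)
x_coords = [np.mean([c[0] for c in s]) for s in subsets]  # mean predicted probability
y_coords = [np.mean([c[1] for c in s]) for s in subsets]  # empirical accuracy
for x, y in zip(x_coords, y_coords):
    print(f"predicted {x:.2f}  observed {y:.2f}")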
@@ -5,7 +5,8 @@ output_type: multiple_choice
training_split: train
validation_split: validation
test_split: test
num_fewshot: 5
doc_to_choice: !function utils_logiqa.doc_to_choice
doc_to_text: !function utils_logiqa.doc_to_text
doc_to_target: !function utils_logiqa.doc_to_target
doc_to_decontamination_query: "{{context}}"
...
@@ -4,17 +4,17 @@ def doc_to_text(doc) -> str:
    Passage: <passage>
    Question: <question>
    Choices:
    (A) <choice1>
    (B) <choice2>
    (C) <choice3>
    (D) <choice4>
    Answer:
    """
    choices = ["a", "b", "c", "d"]
    prompt = "Passage: " + doc["context"] + "\n"
    prompt += "Question: " + doc["question"] + "\nChoices:\n"
    for choice, option in zip(choices, doc["options"]):
        prompt += f"({choice.upper()}) {option}\n"
    prompt += "Answer:"
    return prompt
@@ -22,3 +22,7 @@ def doc_to_text(doc) -> str:
def doc_to_target(doc) -> int:
    choices = ["a", "b", "c", "d"]
    return choices.index(doc["label"].strip())


def doc_to_choice(doc):
    return ["(A)", "(B)", "(C)", "(D)"]