Unverified Commit 5888a695 authored by Leo Gao's avatar Leo Gao Committed by GitHub
Browse files

Merge pull request #33 from zphang/retrieval

Refactor for explicit imports
parents 635a2155 515d78b3
......@@ -122,17 +122,3 @@ class Dataset(abc.ABC):
) + "\n\n"
example = self.doc_to_text(doc, include_target=False).strip()
return description + labeled_examples + example
class Registry:
    """Simple name -> class registry populated via a decorator.

    Used to collect model/task classes under string keys so they can be
    looked up by name at runtime.
    """
    def __init__(self, registry_name):
        # Human-readable registry label used in error messages (e.g. "models").
        self.registry_name = registry_name
        # Mapping of registered name -> registered class.
        self.registry = {}
    def register(self, name):
        """Return a class decorator that registers the class under ``name``.

        Raises:
            ValueError: if ``name`` is already registered in this registry.
        """
        def register_cls(new_cls):
            if name in self.registry:
                # Bug fix: the original format string had a single placeholder
                # for two arguments, so the duplicate name was dropped from
                # the error message.
                raise ValueError(
                    'Cannot register duplicate {} ({})'.format(self.registry_name, name)
                )
            self.registry[name] = new_cls
            return new_cls
        return register_cls
import importlib
import os
from lm_eval.base import Registry
from . import gpt2
from . import gpt3
# Registry instance that model modules populate via @MODEL_REGISTRY.register(...).
MODEL_REGISTRY = Registry(registry_name="models")

# Load all modules in models directory to populate registry: importing each
# sibling module/package triggers its @MODEL_REGISTRY.register decorators.
models_dir = os.path.dirname(__file__)
for file in os.listdir(models_dir):
    path = os.path.join(models_dir, file)
    # Skip private (_*) and hidden (.*) entries; accept .py files and
    # subpackage directories.
    if (
        not file.startswith('_')
        and not file.startswith('.')
        and (file.endswith('.py') or os.path.isdir(path))
    ):
        # NOTE(review): find('.py') locates the FIRST ".py" in the name, so a
        # file like "a.python.py" would be truncated to "a"; file[:-3] would be
        # safer — confirm no such filenames exist.
        module_name = file[:file.find('.py')] if file.endswith('.py') else file
        module = importlib.import_module('lm_eval.models.' + module_name)
# NOTE(review): this rebinds MODEL_REGISTRY — earlier in this file it is a
# Registry instance — to a plain dict. This looks like merge/diff residue of
# the pre-refactor version; confirm which binding is intended before relying
# on either.
MODEL_REGISTRY = {
    "gpt2": gpt2.GPT2LM,
    "gpt3": gpt3.GPT3LM,
}
def get_model(model_name):
    """Look up a model class by its registered name.

    Raises:
        KeyError: if ``model_name`` is not a known model.
    """
    # The original body contained two consecutive return statements (merge
    # residue); the second was unreachable. MODEL_REGISTRY's latest binding in
    # this file is a plain dict, so index it directly.
    return MODEL_REGISTRY[model_name]
......@@ -3,10 +3,8 @@ import torch
import torch.nn.functional as F
from lm_eval.base import LM
from lm_eval import utils
from . import MODEL_REGISTRY
@MODEL_REGISTRY.register("gpt2")
class GPT2LM(LM):
def __init__(self, device="cpu"):
self.device = torch.device(device)
......
import importlib
import os
from lm_eval.base import Registry
from . import superglue
from . import glue
# Registry instance that task modules populate via @TASK_REGISTRY.register(...).
TASK_REGISTRY = Registry(registry_name="tasks")

# Load all modules in tasks directory to populate registry: importing each
# sibling module/package triggers its @TASK_REGISTRY.register decorators.
tasks_dir = os.path.dirname(__file__)
for file in os.listdir(tasks_dir):
    path = os.path.join(tasks_dir, file)
    # Skip private (_*) and hidden (.*) entries; accept .py files and
    # subpackage directories.
    if (
        not file.startswith('_')
        and not file.startswith('.')
        and (file.endswith('.py') or os.path.isdir(path))
    ):
        # NOTE(review): find('.py') locates the FIRST ".py" in the name, so a
        # file like "a.python.py" would be truncated to "a"; file[:-3] would be
        # safer — confirm no such filenames exist.
        module_name = file[:file.find('.py')] if file.endswith('.py') else file
        module = importlib.import_module('lm_eval.tasks.' + module_name)
# NOTE(review): this rebinds TASK_REGISTRY — earlier in this file it is a
# Registry instance — to a plain dict mapping task name -> task class. This
# looks like merge/diff residue of the pre-refactor version; confirm which
# binding is intended before relying on either.
TASK_REGISTRY = {
    "cola": glue.CoLA,
    "mnli": glue.MNLI,
    "mrpc": glue.MRPC,
    "rte": glue.RTE,
    "qnli": glue.QNLI,
    "qqp": glue.QQP,
    "stsb": glue.STSB,
    "sst": glue.SST,
    "wnli": glue.WNLI,
    "boolq": superglue.BoolQ,
    "commitmentbank": superglue.CommitmentBank,
    "copa": superglue.Copa,
    "wic": superglue.WordsInContext,
    "wsc": superglue.WinogradSchemaChallenge,
}
# Sorted list of every registered task name. The original contained two
# conflicting assignments (merge residue); the first accessed `.registry` on
# what is a plain dict at this point in the file, which would raise
# AttributeError at import time. Iterating the dict yields its keys, and
# sorted() already returns a list, so no intermediate list() is needed.
ALL_TASKS = sorted(TASK_REGISTRY)
def get_task(task_name):
    """Look up a task class by its registered name.

    Raises:
        KeyError: if ``task_name`` is not a known task.
    """
    # The original body contained two consecutive return statements (merge
    # residue); the second was unreachable. TASK_REGISTRY's latest binding in
    # this file is a plain dict, so index it directly.
    return TASK_REGISTRY[task_name]
def get_task_dict(task_name_list):
......
import json
import random
from lm_eval.base import Dataset
from . import TASK_REGISTRY
@TASK_REGISTRY.register("coqa")
class CoQA(Dataset):
    def has_training_docs(self):
        """Return True: this task provides training documents."""
        return True
......
......@@ -3,7 +3,6 @@ from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import f1_score, matthews_corrcoef
from tqdm import auto as tqdm_lib
from . common import NLP_TASK, simple_accuracy_metric, yesno
from . import TASK_REGISTRY
def get_accuracy_and_f1(preds, golds):
......@@ -23,7 +22,6 @@ def get_accuracy_and_f1(preds, golds):
}
@TASK_REGISTRY.register("cola")
class CoLA(NLP_TASK):
NLP_PATH = "glue"
NLP_NAME = "cola"
......@@ -66,7 +64,6 @@ class CoLA(NLP_TASK):
}
@TASK_REGISTRY.register("mnli")
class MNLI(NLP_TASK):
NLP_PATH = "glue"
NLP_NAME = "mnli"
......@@ -110,15 +107,14 @@ class MNLI(NLP_TASK):
num_fewshot=num_fewshot,
)
probs = np.array([
self.lm.loglikelihood(ctx, ' True'),
self.lm.loglikelihood(ctx, ' Neither'),
self.lm.loglikelihood(ctx, ' False'),
lm.loglikelihood(ctx, ' True'),
lm.loglikelihood(ctx, ' Neither'),
lm.loglikelihood(ctx, ' False'),
])
preds.append(np.argmax(probs))
return simple_accuracy_metric(preds=preds, golds=golds)
@TASK_REGISTRY.register("mrpc")
class MRPC(NLP_TASK):
NLP_PATH = "glue"
NLP_NAME = "mrpc"
......@@ -157,7 +153,6 @@ class MRPC(NLP_TASK):
return get_accuracy_and_f1(preds=preds, golds=golds)
@TASK_REGISTRY.register("rte")
class RTE(NLP_TASK):
NLP_PATH = "glue"
NLP_NAME = "rte"
......@@ -195,7 +190,6 @@ class RTE(NLP_TASK):
return simple_accuracy_metric(preds=preds, golds=golds)
@TASK_REGISTRY.register("qnli")
class QNLI(NLP_TASK):
NLP_PATH = "glue"
NLP_NAME = "qnli"
......@@ -229,11 +223,10 @@ class QNLI(NLP_TASK):
provide_description=provide_description,
num_fewshot=num_fewshot,
)
preds.append(self.lm.loglikelihood(ctx, ' False') > self.lm.loglikelihood(ctx, ' True'))
preds.append(lm.loglikelihood(ctx, ' False') > lm.loglikelihood(ctx, ' True'))
return simple_accuracy_metric(preds=preds, golds=golds)
@TASK_REGISTRY.register("qqp")
class QQP(NLP_TASK):
NLP_PATH = "glue"
NLP_NAME = "qqp"
......@@ -272,7 +265,6 @@ class QQP(NLP_TASK):
return get_accuracy_and_f1(preds=preds, golds=golds)
@TASK_REGISTRY.register("stsb")
class STSB(NLP_TASK):
NLP_PATH = "glue"
NLP_NAME = "stsb"
......@@ -330,7 +322,6 @@ class STSB(NLP_TASK):
}
@TASK_REGISTRY.register("sst")
class SST(NLP_TASK):
NLP_PATH = "glue"
NLP_NAME = "sst2"
......@@ -368,7 +359,6 @@ class SST(NLP_TASK):
return simple_accuracy_metric(preds=preds, golds=golds)
@TASK_REGISTRY.register("wnli")
class WNLI(NLP_TASK):
NLP_PATH = "glue"
NLP_NAME = "wnli"
......@@ -404,9 +394,9 @@ class WNLI(NLP_TASK):
num_fewshot=num_fewshot,
)
probs = np.array([
self.lm.loglikelihood(ctx, ' True'),
self.lm.loglikelihood(ctx, ' Neither'),
self.lm.loglikelihood(ctx, ' False'),
lm.loglikelihood(ctx, ' True'),
lm.loglikelihood(ctx, ' Neither'),
lm.loglikelihood(ctx, ' False'),
])
preds.append(np.argmax(probs))
return simple_accuracy_metric(preds=preds, golds=golds)
import numpy as np
from tqdm import auto as tqdm_lib
from . common import NLP_TASK, simple_accuracy_metric, yesno
from . import TASK_REGISTRY
@TASK_REGISTRY.register("boolq")
class BoolQ(NLP_TASK):
NLP_PATH = "super_glue"
NLP_NAME = "boolq"
......@@ -38,7 +36,6 @@ class BoolQ(NLP_TASK):
return simple_accuracy_metric(preds=preds, golds=golds)
@TASK_REGISTRY.register("cb")
class CommitmentBank(NLP_TASK):
NLP_PATH = "super_glue"
NLP_NAME = "cb"
......@@ -82,7 +79,6 @@ class CommitmentBank(NLP_TASK):
return simple_accuracy_metric(preds=preds, golds=golds)
@TASK_REGISTRY.register("copa")
class Copa(NLP_TASK):
NLP_PATH = "super_glue"
NLP_NAME = "copa"
......@@ -124,8 +120,6 @@ class Copa(NLP_TASK):
return choice[0].lower() + choice[1:]
@TASK_REGISTRY.register("wic")
class WordsInContext(NLP_TASK):
NLP_PATH = "super_glue"
NLP_NAME = "wic"
......@@ -163,7 +157,6 @@ class WordsInContext(NLP_TASK):
return simple_accuracy_metric(preds=preds, golds=golds)
@TASK_REGISTRY.register("wsc")
class WinogradSchemaChallenge(NLP_TASK):
NLP_PATH = "super_glue"
NLP_NAME = "wsc"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment