Commit 515d78b3 authored by Jason Phang's avatar Jason Phang
Browse files

refactor to explicit registries

parent 89de8d7d
from .gpt2 import GPT2LM
from .gpt3 import GPT3LM
from . import gpt2
from . import gpt3
MODEL_REGISTRY = {
"gpt2": GPT2LM,
"gpt3": GPT3LM,
"gpt2": gpt2.GPT2LM,
"gpt3": gpt3.GPT3LM,
}
......
import json
import random
from lm_eval.base import Dataset
from . import TASK_REGISTRY
@TASK_REGISTRY.register("coqa")
class CoQA(Dataset):
def has_training_docs(self):
return True
......
import numpy as np
from tqdm import auto as tqdm_lib
from . common import NLP_TASK, simple_accuracy_metric, yesno
from . import TASK_REGISTRY
@TASK_REGISTRY.register("boolq")
class BoolQ(NLP_TASK):
NLP_PATH = "super_glue"
NLP_NAME = "boolq"
......@@ -38,7 +36,6 @@ class BoolQ(NLP_TASK):
return simple_accuracy_metric(preds=preds, golds=golds)
@TASK_REGISTRY.register("cb")
class CommitmentBank(NLP_TASK):
NLP_PATH = "super_glue"
NLP_NAME = "cb"
......@@ -82,7 +79,6 @@ class CommitmentBank(NLP_TASK):
return simple_accuracy_metric(preds=preds, golds=golds)
@TASK_REGISTRY.register("copa")
class Copa(NLP_TASK):
NLP_PATH = "super_glue"
NLP_NAME = "copa"
......@@ -124,7 +120,6 @@ class Copa(NLP_TASK):
return choice[0].lower() + choice[1:]
@TASK_REGISTRY.register("wic")
class WordsInContext(NLP_TASK):
NLP_PATH = "super_glue"
NLP_NAME = "wic"
......@@ -162,7 +157,6 @@ class WordsInContext(NLP_TASK):
return simple_accuracy_metric(preds=preds, golds=golds)
@TASK_REGISTRY.register("wsc")
class WinogradSchemaChallenge(NLP_TASK):
NLP_PATH = "super_glue"
NLP_NAME = "wsc"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment