__init__.py 382 Bytes
Newer Older
Jason Phang's avatar
Jason Phang committed
1
from . import gpt2
Tian Yun's avatar
Tian Yun committed
2
from . import gptj
Jason Phang's avatar
Jason Phang committed
3
from . import gpt3
Tian Yun's avatar
Tian Yun committed
4
5
from . import t5
from . import t0
Leo Gao's avatar
Leo Gao committed
6
from . import dummy
Jason Phang's avatar
gpt3  
Jason Phang committed
7

Jason Phang's avatar
Jason Phang committed
8
# Lookup table from the short model name passed to ``get_model`` to the
# implementation registered for it. Several keys may share one entry: the
# "mt5" key reuses the same class as "t5".
MODEL_REGISTRY = dict(
    hf=gpt2.HFLM,
    gpt2=gpt2.GPT2LM,
    gptj=gptj.GPTJLM,
    gpt3=gpt3.GPT3LM,
    t5=t5.T5LM,
    mt5=t5.T5LM,
    t0=t0.T0LM,
    dummy=dummy.DummyLM,
)
Jason Phang's avatar
gpt3  
Jason Phang committed
18
19
20


def get_model(model_name):
    """Return the entry registered in ``MODEL_REGISTRY`` under ``model_name``.

    Args:
        model_name: Registry key, e.g. ``"gpt2"``, ``"t5"`` or ``"dummy"``.

    Returns:
        The value stored in ``MODEL_REGISTRY`` for that key (for example
        ``gpt2.GPT2LM`` for ``"gpt2"``).

    Raises:
        KeyError: If ``model_name`` is not registered. The message lists the
            registered names so a typo on the command line is easy to spot.
    """
    try:
        return MODEL_REGISTRY[model_name]
    except KeyError:
        available = ", ".join(sorted(MODEL_REGISTRY))
        # Re-raise the same exception type so existing ``except KeyError``
        # callers keep working; ``from None`` hides the redundant inner
        # KeyError from the traceback.
        raise KeyError(
            f"Unknown model {model_name!r}; available models: {available}"
        ) from None