__init__.py 294 Bytes
Newer Older
Jason Phang's avatar
Jason Phang committed
1
2
from . import gpt2
from . import gpt3
Jonathan Tow's avatar
Jonathan Tow committed
3
from . import textsynth
Leo Gao's avatar
Leo Gao committed
4
from . import dummy
Jason Phang's avatar
gpt3  
Jason Phang committed
5

Jason Phang's avatar
Jason Phang committed
6
# Maps each model-type name accepted by get_model() to the object that
# get_model() returns for it. Register new backends here.
MODEL_REGISTRY = {
    "hf-causal": gpt2.HFLM,
    "openai": gpt3.GPT3LM,
    "textsynth": textsynth.TextSynthLM,
    "dummy": dummy.DummyLM,
}
Jason Phang's avatar
gpt3  
Jason Phang committed
12
13
14


def get_model(model_name):
    """Look up the registry entry for ``model_name``.

    Args:
        model_name: A key of ``MODEL_REGISTRY``, such as ``"hf-causal"``,
            ``"openai"``, ``"textsynth"`` or ``"dummy"``.

    Returns:
        The registered object (e.g. ``gpt2.HFLM``), returned as-is. It is
        not called or instantiated here.

    Raises:
        KeyError: If ``model_name`` is not registered. The message lists the
            registered names so a typo is easy to spot.
    """
    try:
        return MODEL_REGISTRY[model_name]
    except KeyError:
        available = ", ".join(sorted(MODEL_REGISTRY))
        # Keep raising KeyError so existing callers that catch it still work,
        # but replace the bare key with an actionable message.
        raise KeyError(
            "Unknown model {!r}; available models: {}".format(model_name, available)
        ) from None