__init__.py 225 Bytes
Newer Older
Jason Phang's avatar
Jason Phang committed
1
2
from . import gpt2
from . import gpt3
Leo Gao's avatar
Leo Gao committed
3
from . import dummy
Jason Phang's avatar
gpt3  
Jason Phang committed
4

Jason Phang's avatar
Jason Phang committed
5
# Registry of selectable language-model implementations, keyed by the short
# name passed to ``get_model``. Insertion order: gpt2, gpt3, dummy.
MODEL_REGISTRY = dict(
    gpt2=gpt2.GPT2LM,
    gpt3=gpt3.GPT3LM,
    dummy=dummy.DummyLM,
)
Jason Phang's avatar
gpt3  
Jason Phang committed
10
11
12


def get_model(model_name):
    """Return the registry entry for ``model_name``.

    Args:
        model_name: A key of ``MODEL_REGISTRY`` (currently ``"gpt2"``,
            ``"gpt3"`` or ``"dummy"``).

    Returns:
        The object registered under ``model_name`` -- presumably the model
        class itself rather than an instance (TODO confirm against callers).

    Raises:
        KeyError: If ``model_name`` is not registered. The message lists the
            registered names. ``KeyError`` is kept (rather than e.g.
            ``ValueError``) so existing callers that catch it keep working.
    """
    try:
        return MODEL_REGISTRY[model_name]
    except KeyError:
        # map(str, ...) keeps the listing safe even if a non-string key is
        # ever registered; sorted() gives a stable, readable message.
        available = ", ".join(sorted(map(str, MODEL_REGISTRY)))
        # ``from None`` hides the uninformative chained KeyError traceback.
        raise KeyError(
            f"Unknown model {model_name!r}; available models: {available}"
        ) from None