# Model registry: maps short model names to their language-model wrappers.
from .gpt2 import GPT2LM
from .gpt3 import GPT3LM
# Short model name -> model wrapper object; get_model() resolves names here.
# Values are the imported wrappers themselves (not instances).
MODEL_REGISTRY = dict(
    gpt2=GPT2LM,
    gpt3=GPT3LM,
)


def get_model(model_name: str):
    """Look up the model wrapper registered under ``model_name``.

    Args:
        model_name: Key into ``MODEL_REGISTRY`` (e.g. ``"gpt2"``, ``"gpt3"``).

    Returns:
        The registered object itself (presumably the wrapper class, e.g.
        ``GPT2LM``); it is not instantiated here.

    Raises:
        KeyError: If ``model_name`` is not registered. The message lists the
            valid names. The exception type is unchanged from a plain dict
            lookup, so existing ``except KeyError`` handlers keep working.
    """
    try:
        return MODEL_REGISTRY[model_name]
    except KeyError:
        available = ", ".join(sorted(MODEL_REGISTRY))
        # `from None` hides the redundant inner KeyError from the traceback.
        raise KeyError(
            f"Unknown model {model_name!r}; available models: {available}"
        ) from None