from .gpt2 import GPT2LM from .gpt3 import GPT3LM
MODEL_REGISTRY = { "gpt2": GPT2LM, "gpt3": GPT3LM, }
def get_model(model_name):
return MODEL_REGISTRY[model_name]