from . import gpt2
from . import gpt3

# Maps the name accepted by ``get_model`` to the object registered for it.
MODEL_REGISTRY = {
    "gpt2": gpt2.GPT2LM,
    "gpt3": gpt3.GPT3LM,
}


def get_model(model_name):
    """Return the object registered under *model_name* in ``MODEL_REGISTRY``.

    The registered value is returned as-is and is not instantiated here.
    Presumably it is a model class; confirm against ``gpt2``/``gpt3``.

    Args:
        model_name: Registry key, e.g. ``"gpt2"`` or ``"gpt3"``.

    Returns:
        The value stored in ``MODEL_REGISTRY`` for *model_name*.

    Raises:
        KeyError: If *model_name* is not registered. The message lists the
            available names. ``KeyError`` is kept so existing callers that
            catch it still work.
    """
    try:
        return MODEL_REGISTRY[model_name]
    except KeyError:
        available = ", ".join(sorted(MODEL_REGISTRY))
        # ``from None``: the original bare KeyError adds nothing beyond this message.
        raise KeyError(
            f"Unknown model {model_name!r}; available models: {available}"
        ) from None