__init__.py 465 Bytes
Newer Older
Jason Phang's avatar
Jason Phang committed
1
2
from . import gpt2
from . import gpt3
bzantium's avatar
bzantium committed
3
4
from . import huggingface
from . import textsynth
Leo Gao's avatar
Leo Gao committed
5
from . import dummy
Jason Phang's avatar
gpt3  
Jason Phang committed
6

Jason Phang's avatar
Jason Phang committed
7
# Lookup table from the model-type names accepted by ``get_model`` to the
# LM classes that implement them. "hf" and "hf-causal" are two names for the
# same class (gpt2.HFLM). Entries are kept in their original insertion order.
MODEL_REGISTRY = dict(
    [
        ("hf", gpt2.HFLM),
        ("hf-causal", gpt2.HFLM),
        ("hf-causal-experimental", huggingface.AutoCausalLM),
        ("hf-seq2seq", huggingface.AutoSeq2SeqLM),
        ("gpt2", gpt2.GPT2LM),
        ("gpt3", gpt3.GPT3LM),
        ("textsynth", textsynth.TextSynthLM),
        ("dummy", dummy.DummyLM),
    ]
)
Jason Phang's avatar
gpt3  
Jason Phang committed
17
18
19


def get_model(model_name):
    """Return the model class registered under ``model_name``.

    Args:
        model_name: A key of ``MODEL_REGISTRY`` (e.g. ``"gpt2"`` or
            ``"hf-causal"``).

    Returns:
        The class that ``MODEL_REGISTRY`` maps ``model_name`` to.

    Raises:
        KeyError: If ``model_name`` is not registered. The message lists the
            registered names so a typo is easy to spot.
    """
    try:
        return MODEL_REGISTRY[model_name]
    except KeyError:
        # Re-raise the same exception type (callers may catch KeyError), but
        # with an actionable message. ``from None`` drops the redundant
        # original traceback, which only repeated the bad key.
        available = ", ".join(sorted(MODEL_REGISTRY))
        raise KeyError(
            f"Unknown model {model_name!r}; available models: {available}"
        ) from None