model_loader.py 2.62 KB
Newer Older
1
"""Utilities for selecting and loading models."""
2
import contextlib
3
4
from typing import Type

Woosuk Kwon's avatar
Woosuk Kwon committed
5
import torch
Woosuk Kwon's avatar
Woosuk Kwon committed
6
import torch.nn as nn
7
from transformers import PretrainedConfig
Woosuk Kwon's avatar
Woosuk Kwon committed
8

Woosuk Kwon's avatar
Woosuk Kwon committed
9
from vllm.config import ModelConfig
Woosuk Kwon's avatar
Woosuk Kwon committed
10
from vllm.model_executor.models import *  # pylint: disable=wildcard-import
Woosuk Kwon's avatar
Woosuk Kwon committed
11
from vllm.model_executor.weight_utils import initialize_dummy_weights
Woosuk Kwon's avatar
Woosuk Kwon committed
12

Woosuk Kwon's avatar
Woosuk Kwon committed
13
14
# Lookup table from an architecture name (as listed in a Hugging Face
# config's ``architectures`` field) to the vLLM class implementing it.
# Entry order is also the order of the "Supported architectures" list that
# ``_get_model_architecture`` prints when no entry matches.
# TODO(woosuk): Lazy-load the model classes.
_MODEL_REGISTRY = {
    "AquilaModel": AquilaForCausalLM,
    # The two Baichuan releases spell the class name differently.
    "BaiChuanForCausalLM": BaiChuanForCausalLM,  # baichuan-7b
    "BaichuanForCausalLM": BaichuanForCausalLM,  # baichuan-13b
    "BloomForCausalLM": BloomForCausalLM,
    "FalconForCausalLM": FalconForCausalLM,
    "GPT2LMHeadModel": GPT2LMHeadModel,
    "GPTBigCodeForCausalLM": GPTBigCodeForCausalLM,
    "GPTJForCausalLM": GPTJForCausalLM,
    "GPTNeoXForCausalLM": GPTNeoXForCausalLM,
    "InternLMForCausalLM": InternLMForCausalLM,
    "LlamaForCausalLM": LlamaForCausalLM,
    # Alternate capitalization, served by the same Llama implementation.
    "LLaMAForCausalLM": LlamaForCausalLM,  # For decapoda-research/llama-*
    "MPTForCausalLM": MPTForCausalLM,
    "OPTForCausalLM": OPTForCausalLM,
    "QWenLMHeadModel": QWenLMHeadModel,
    # Alternate name that resolves to the Falcon implementation.
    "RWForCausalLM": FalconForCausalLM,
}


@contextlib.contextmanager
def _set_default_torch_dtype(dtype: torch.dtype):
    """Sets the default torch dtype to the given dtype."""
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    yield
    torch.set_default_dtype(old_dtype)


43
def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
    """Return the registered model class for the config's architecture.

    ``config.architectures`` is scanned in order, and the first name present
    in ``_MODEL_REGISTRY`` determines the class.

    Args:
        config: The Hugging Face model config.

    Returns:
        The ``nn.Module`` subclass registered for the first supported
        architecture name.

    Raises:
        ValueError: If no listed architecture is registered, including when
            the config lists no architectures at all.
    """
    # ``architectures`` may be absent or explicitly ``None`` (the
    # transformers default); treat both as "no architectures listed" so the
    # caller gets the ValueError below rather than a TypeError from iterating
    # over None.
    architectures = getattr(config, "architectures", None) or []
    for arch in architectures:
        if arch in _MODEL_REGISTRY:
            return _MODEL_REGISTRY[arch]
    raise ValueError(
        f"Model architectures {architectures} are not supported for now. "
        f"Supported architectures: {list(_MODEL_REGISTRY.keys())}")


def get_model(model_config: ModelConfig) -> nn.Module:
    """Build the model described by ``model_config`` on CUDA in eval mode.

    The model class is resolved from ``model_config.hf_config`` and
    instantiated while ``model_config.dtype`` is torch's default dtype.
    With ``load_format == "dummy"``, the module is moved to CUDA and then
    filled by ``initialize_dummy_weights``. Otherwise, its ``load_weights``
    method is called with the model name, download directory and load
    format, and the module is then moved to CUDA.

    Raises:
        ValueError: Propagated from ``_get_model_architecture`` when no
            architecture in the config is registered.
    """
    hf_config = model_config.hf_config
    model_cls = _get_model_architecture(hf_config)

    with _set_default_torch_dtype(model_config.dtype):
        # Construct the module; its weights start out as empty tensors.
        net = model_cls(hf_config)
        if model_config.load_format != "dummy":
            # Populate weights from the cached or downloaded files first,
            # then move the module to the GPU.
            net.load_weights(model_config.model, model_config.download_dir,
                             model_config.load_format)
            net = net.cuda()
        else:
            # NOTE(woosuk): For accurate performance evaluation, random
            # values are assigned to the weights (after the move to GPU).
            net = net.cuda()
            initialize_dummy_weights(net)
    return net.eval()