model_loader.py 4.39 KB
Newer Older
1
"""Utilities for selecting and loading models."""
2
import contextlib
yhu422's avatar
yhu422 committed
3
from typing import Tuple, Type
4

Woosuk Kwon's avatar
Woosuk Kwon committed
5
import torch
Woosuk Kwon's avatar
Woosuk Kwon committed
6
7
import torch.nn as nn

8
from vllm.config import DeviceConfig, ModelConfig
9
from vllm.model_executor.models import ModelRegistry
10
from vllm.model_executor.models.llava import LlavaForConditionalGeneration
11
12
from vllm.model_executor.weight_utils import (get_quant_config,
                                              initialize_dummy_weights)
13

14
15
16
17
# Model classes whose constructor takes ``vision_language_config`` as an
# extra positional argument (between the HF config and the linear method)
# when ``get_model`` instantiates them.
_VISION_MODEL_CLASSES = [
    LlavaForConditionalGeneration,
]

18

19
20
21
22
23
24
25
26
27
@contextlib.contextmanager
def _set_default_torch_dtype(dtype: torch.dtype):
    """Temporarily set torch's default floating-point dtype.

    The previous default dtype is captured on entry and restored on exit,
    including when the body of the ``with`` statement raises.

    Args:
        dtype: The dtype to install as the default for the duration of the
            ``with`` block.
    """
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        # Restore even on error so a failed model construction or weight
        # load does not leak a non-default dtype into the rest of the
        # process.
        torch.set_default_dtype(old_dtype)


yhu422's avatar
yhu422 committed
28
29
def _get_model_architecture(
        model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
    """Resolve the model class to instantiate for ``model_config``.

    Candidate architecture names are read from the HF config's
    ``architectures`` attribute and tried in order against
    ``ModelRegistry``; the first one the registry can load is returned.

    Args:
        model_config: Model configuration whose ``hf_config`` lists the
            candidate architectures.

    Returns:
        A ``(model_class, architecture_name)`` tuple for the first supported
        architecture.

    Raises:
        ValueError: If none of the listed architectures is supported,
            including when the config lists no architectures at all.
    """
    # HF configs may carry ``architectures = None`` instead of omitting the
    # attribute; normalise both cases to an empty list so the membership
    # test and loop below don't raise TypeError and we reach the ValueError.
    architectures = (getattr(model_config.hf_config, "architectures", None)
                     or [])
    # Special handling for quantized Mixtral: when quantization is enabled,
    # resolve Mixtral through the "QuantMixtralForCausalLM" registry entry.
    # FIXME(woosuk): This is a temporary hack.
    if (model_config.quantization is not None
            and "MixtralForCausalLM" in architectures):
        architectures = ["QuantMixtralForCausalLM"]

    for arch in architectures:
        model_cls = ModelRegistry.load_model_cls(arch)
        if model_cls is not None:
            return (model_cls, arch)
    raise ValueError(
        f"Model architectures {architectures} are not supported for now. "
        f"Supported architectures: {ModelRegistry.get_supported_archs()}")
Woosuk Kwon's avatar
Woosuk Kwon committed
44
45


yhu422's avatar
yhu422 committed
46
47
48
49
def get_architecture_class_name(model_config: ModelConfig) -> str:
    """Return the registry architecture name selected for ``model_config``.

    This is the name half of the ``(model_class, architecture_name)`` pair
    produced by ``_get_model_architecture``; any ``ValueError`` raised there
    for unsupported architectures propagates unchanged.
    """
    _, arch_name = _get_model_architecture(model_config)
    return arch_name


50
51
52
def _get_linear_method(model_config: ModelConfig):
    """Return the quantized linear method for ``model_config``, or None.

    Returns None when ``model_config.quantization`` is None. Otherwise the
    quant config is loaded and validated against the current GPU and the
    configured activation dtype before its linear method is returned.

    Raises:
        ValueError: If the current GPU's compute capability is below the
            quantization method's minimum, or if ``model_config.dtype`` is
            not one of the method's supported activation dtypes.
    """
    if model_config.quantization is None:
        return None
    quant_config = get_quant_config(model_config)

    # Encode (major, minor) as one integer, e.g. (8, 0) -> 80, so it can be
    # compared against the method's minimum capability.
    major, minor = torch.cuda.get_device_capability()
    capability = major * 10 + minor
    if capability < quant_config.get_min_capability():
        raise ValueError(
            f"The quantization method {model_config.quantization} is not "
            "supported for the current GPU. "
            f"Minimum capability: {quant_config.get_min_capability()}. "
            f"Current capability: {capability}.")

    supported_dtypes = quant_config.get_supported_act_dtypes()
    if model_config.dtype not in supported_dtypes:
        raise ValueError(
            f"{model_config.dtype} is not supported for quantization "
            f"method {model_config.quantization}. Supported dtypes: "
            f"{supported_dtypes}")
    return quant_config.get_linear_method()


def _build_model(model_class, model_config: ModelConfig, linear_method,
                 lora_config, vision_language_config) -> nn.Module:
    """Instantiate ``model_class`` with the constructor arguments it takes.

    Classes that define ``supported_lora_modules`` receive ``lora_config``;
    classes in ``_VISION_MODEL_CLASSES`` receive ``vision_language_config``;
    all others get only the HF config and the linear method.

    Raises:
        ValueError: If ``lora_config`` is truthy for a class that does not
            define ``supported_lora_modules``.
    """
    if hasattr(model_class, "supported_lora_modules"):
        return model_class(model_config.hf_config, linear_method,
                           lora_config)
    if lora_config:
        raise ValueError(
            f"Model {model_class.__name__} does not support LoRA, "
            "but LoRA is enabled. Support for this model may "
            "be added in the future. If this is important to you, "
            "please open an issue on github.")
    if model_class in _VISION_MODEL_CLASSES:
        return model_class(model_config.hf_config, vision_language_config,
                           linear_method)
    return model_class(model_config.hf_config, linear_method)


def get_model(model_config: ModelConfig, device_config: DeviceConfig,
              **kwargs) -> nn.Module:
    """Construct the model described by ``model_config`` and load weights.

    Parameters are created under ``model_config.dtype`` as torch's default
    dtype and on ``device_config.device``.

    Args:
        model_config: Describes the architecture, dtype, quantization and
            weight source of the model.
        device_config: Supplies the device the model is created on.
        **kwargs: Optional ``lora_config`` and ``vision_language_config``;
            each is treated as None when absent.

    Returns:
        The model in eval mode. Its weights are randomly initialized when
        ``model_config.load_format == "dummy"``, otherwise loaded via
        ``model.load_weights``.

    Raises:
        ValueError: If the architecture is unsupported, the quantization
            method is incompatible with the GPU or dtype, or LoRA is
            requested for a model that does not support it.
    """
    lora_config = kwargs.get("lora_config", None)
    vision_language_config = kwargs.get("vision_language_config", None)
    model_class = _get_model_architecture(model_config)[0]

    # Get the (maybe quantized) linear method.
    linear_method = _get_linear_method(model_config)

    with _set_default_torch_dtype(model_config.dtype):
        # Create a model instance.
        # The weights will be initialized as empty tensors.
        with torch.device(device_config.device):
            model = _build_model(model_class, model_config, linear_method,
                                 lora_config, vision_language_config)
        if model_config.load_format == "dummy":
            # NOTE(woosuk): For accurate performance evaluation, we assign
            # random values to the weights.
            initialize_dummy_weights(model)
        else:
            # Load the weights from the cached or downloaded files.
            model.load_weights(model_config.model, model_config.download_dir,
                               model_config.load_format,
                               model_config.revision)
    return model.eval()