"""Utilities for selecting and loading models."""
from typing import Optional, Tuple, Type

import torch
import torch.nn as nn
from transformers import AutoConfig, PretrainedConfig

from cacheflow.model_executor.models import (
    GPT2LMHeadModel, GPTNeoXForCausalLM, LlamaForCausalLM, OPTForCausalLM)
from cacheflow.model_executor.utils import get_torch_dtype
from cacheflow.model_executor.weight_utils import initialize_dummy_weights


# TODO(woosuk): Lazy-load the model classes.
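# Maps architecture names from the Hugging Face config (config.architectures)
# to the model classes implemented in cacheflow.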
_MODEL_REGISTRY = {
    "GPT2LMHeadModel": GPT2LMHeadModel,
    "GPTNeoXForCausalLM": GPTNeoXForCausalLM,
    "LlamaForCausalLM": LlamaForCausalLM,
    "OPTForCausalLM": OPTForCausalLM,
}


def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
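    """Return the model class registered for the config's architecture."""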
    architectures = getattr(config, "architectures", [])
    for arch in architectures:
        if arch in _MODEL_REGISTRY:
            return _MODEL_REGISTRY[arch]
    raise ValueError(
        f"Model architectures {architectures} are not supported for now. "
        f"Supported architectures: {list(_MODEL_REGISTRY.keys())}"
    )


def _get_dtype(config: PretrainedConfig, dtype: str) -> torch.dtype:
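    """Pick the torch dtype based on the config and the requested dtype."""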
    # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
    # because config.torch_dtype can be None.
    config_dtype = getattr(config, "torch_dtype", None)
    if config_dtype is None:
        config_dtype = torch.float32
    if dtype == "default":
        if config_dtype == torch.float32:
            # Following the common practice, we use float16 for float32 models.
            torch_dtype = torch.float16
        else:
            torch_dtype = config_dtype
    else:
        torch_dtype = get_torch_dtype(dtype)
        if torch_dtype != config_dtype and config_dtype != torch.float32:
            # TODO(woosuk): Allow using float16 for bfloat16 models and
            # vice versa. Print a warning message and continue.
            raise ValueError(
                f"Cannot use {torch_dtype} for {config_dtype} model.")
    return torch_dtype


def get_model(
    model_name: str,
    dtype: str,
    cache_dir: Optional[str],
    use_dummy_weights: bool,
    use_np_cache: bool,
) -> Tuple[nn.Module, torch.dtype]:
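    """Create the model on GPU with real or randomly initialized weights."""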
    config = AutoConfig.from_pretrained(model_name)
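    # Set the default dtype so that the model parameters created below are
    # instantiated directly in torch_dtype.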
    torch_dtype = _get_dtype(config, dtype)
    torch.set_default_dtype(torch_dtype)
    model_class = _get_model_architecture(config)

    # Create a model instance.
    # The weights will be initialized as empty tensors.
    model = model_class(config)
    if use_dummy_weights:
        model = model.cuda()
        # NOTE(woosuk): For accurate performance evaluation, we assign
        # random values to the weights.
        initialize_dummy_weights(model)
    else:
        # Load the weights from the cached or downloaded files.
        model.load_weights(model_name, cache_dir, use_np_cache)
        model = model.cuda()
    return model.eval(), torch_dtype
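

# Example usage (a sketch for illustration, not part of the original module;
# assumes a CUDA device is available and uses "facebook/opt-125m" only as an
# example model name):
#
#   model, torch_dtype = get_model(
#       model_name="facebook/opt-125m",
#       dtype="default",
#       cache_dir=None,
#       use_dummy_weights=True,   # random weights, no download needed
#       use_np_cache=False,
#   )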