"""Utilities for selecting and loading models."""
from typing import Optional, Tuple, Type

import torch
import torch.nn as nn
from transformers import AutoConfig, PretrainedConfig

from cacheflow.model_executor.memory_analyzer import (
    CacheFlowMemoryAnalyzer, GPT2MemoryAnalyzer, GPTNeoXMemoryAnalyzer,
    LlamaMemoryAnalyzer, OPTMemoryAnalyzer)
from cacheflow.model_executor.models import (
    GPT2LMHeadModel, GPTNeoXForCausalLM, LlamaForCausalLM, OPTForCausalLM)
from cacheflow.model_executor.utils import get_torch_dtype
from cacheflow.model_executor.weight_utils import initialize_dummy_weights


# TODO(woosuk): Lazy-load the model classes.
_MODEL_REGISTRY = {
    "GPT2LMHeadModel": GPT2LMHeadModel,
    "GPTNeoXForCausalLM": GPTNeoXForCausalLM,
    "LlamaForCausalLM": LlamaForCausalLM,
    "OPTForCausalLM": OPTForCausalLM,
}

_MEMORY_ANALYZERS = {
    "GPT2LMHeadModel": GPT2MemoryAnalyzer,
    "GPTNeoXForCausalLM": GPTNeoXMemoryAnalyzer,
    "LlamaForCausalLM": LlamaMemoryAnalyzer,
    "OPTForCausalLM": OPTMemoryAnalyzer,
}
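# To support a new architecture, register its model class and memory analyzer
# in both dicts above. A minimal sketch (BloomForCausalLM and
# BloomMemoryAnalyzer are hypothetical names, not provided by this package):
#   _MODEL_REGISTRY["BloomForCausalLM"] = BloomForCausalLM
#   _MEMORY_ANALYZERS["BloomForCausalLM"] = BloomMemoryAnalyzer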


def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
    architectures = getattr(config, "architectures", [])
    for arch in architectures:
        if arch in _MODEL_REGISTRY:
            return _MODEL_REGISTRY[arch]
    raise ValueError(
        f"Model architectures {architectures} are not currently supported. "
        f"Supported architectures: {list(_MODEL_REGISTRY.keys())}"
    )


def _get_memory_analyzer(
    config: PretrainedConfig,
) -> Type[CacheFlowMemoryAnalyzer]:
    architectures = getattr(config, "architectures", [])
    for arch in architectures:
        if arch in _MEMORY_ANALYZERS:
            return _MEMORY_ANALYZERS[arch]
    raise ValueError(
        f"Model architectures {architectures} are not currently supported. "
        f"Supported architectures: {list(_MEMORY_ANALYZERS.keys())}"
    )


def _get_dtype(config: PretrainedConfig, dtype: str) -> torch.dtype:
    # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
    # because config.torch_dtype can be None.
    config_dtype = getattr(config, "torch_dtype", None)
    if config_dtype is None:
        config_dtype = torch.float32
    if dtype == "default":
        if config_dtype == torch.float32:
            # Following the common practice, we use float16 for float32 models.
            torch_dtype = torch.float16
        else:
            torch_dtype = config_dtype
    else:
        torch_dtype = get_torch_dtype(dtype)
        if torch_dtype != config_dtype and config_dtype != torch.float32:
            # TODO(woosuk): Allow using float16 for bfloat16 models and
            # vice versa. Print a warning message and continue.
            raise ValueError(
                f"Cannot use {torch_dtype} for {config_dtype} model.")
    return torch_dtype
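# A few illustrative resolutions of the logic above (an informal sketch, not
# executed here; assumes get_torch_dtype("half") maps to torch.float16):
#   config.torch_dtype is None           -> dtype="default" gives torch.float16
#   config.torch_dtype is torch.bfloat16 -> dtype="default" gives torch.bfloat16
#   config.torch_dtype is torch.bfloat16 -> dtype="half" raises ValueError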


def get_model(
    model_name: str,
    dtype: str,
    cache_dir: Optional[str],
    use_dummy_weights: bool,
    use_np_cache: bool,
) -> Tuple[nn.Module, torch.dtype]:
    config = AutoConfig.from_pretrained(model_name)
    torch_dtype = _get_dtype(config, dtype)
    # Set the default dtype so that the model parameters are created in the
    # target dtype when the model class is instantiated below.
    torch.set_default_dtype(torch_dtype)
    model_class = _get_model_architecture(config)

    # Create a model instance.
    # The weights will be initialized as empty tensors.
    model = model_class(config)
    if use_dummy_weights:
        model = model.cuda()
        # NOTE(woosuk): For accurate performance evaluation, we assign
        # random values to the weights.
        initialize_dummy_weights(model)
    else:
        # Load the weights from the cached or downloaded files.
        model.load_weights(model_name, cache_dir, use_np_cache)
        model = model.cuda()
    return model.eval(), torch_dtype
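# Example usage (an illustrative sketch; the model name is one arbitrary OPT
# checkpoint, not a requirement of this function):
#   model, torch_dtype = get_model(
#       model_name="facebook/opt-125m",
#       dtype="default",
#       cache_dir=None,
#       use_dummy_weights=False,
#       use_np_cache=False,
#   )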


def get_memory_analyzer(
    model_name: str,
    block_size: int,
    dtype: str,
    gpu_memory: int,
    cpu_memory: int,
    tensor_parallel_size: int = 1,
) -> CacheFlowMemoryAnalyzer:
    config = AutoConfig.from_pretrained(model_name)
    torch_dtype = _get_dtype(config, dtype)
    memory_analyzer = _get_memory_analyzer(config)
    return memory_analyzer(
        model_name, block_size, torch_dtype, gpu_memory, cpu_memory,
        tensor_parallel_size)
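# Example usage (an illustrative sketch; the memory sizes are arbitrary and
# assume byte units, which is an assumption, not documented behavior):
#   analyzer = get_memory_analyzer(
#       model_name="facebook/opt-125m",
#       block_size=16,
#       dtype="default",
#       gpu_memory=16 * 1024**3,
#       cpu_memory=32 * 1024**3,
#   )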