# model_loader.py
from typing import Optional, Tuple

import torch
import torch.nn as nn
from transformers import AutoConfig
from transformers import PretrainedConfig

from cacheflow.model_executor.memory_analyzer import (
    CacheFlowMemoryAnalyzer, GPT2MemoryAnalyzer, GPTNeoXMemoryAnalyzer,
    LlamaMemoryAnalyzer, OPTMemoryAnalyzer)
from cacheflow.model_executor.models import (
    GPT2LMHeadModel, GPTNeoXForCausalLM, LlamaForCausalLM, OPTForCausalLM)
from cacheflow.model_executor.utils import get_torch_dtype
from cacheflow.model_executor.weight_utils import initialize_dummy_weights


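# Keys are substrings that are matched against the Hugging Face model name
# (e.g., 'opt' matches 'facebook/opt-13b').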
_MODELS = {
    'gpt2': GPT2LMHeadModel,
    'llama': LlamaForCausalLM,
    'opt': OPTForCausalLM,
    'stablelm': GPTNeoXForCausalLM,
    'pythia': GPTNeoXForCausalLM,
    'dolly-v2': GPTNeoXForCausalLM,
}

_MEMORY_ANALYZERS = {
    'gpt2': GPT2MemoryAnalyzer,
    'llama': LlamaMemoryAnalyzer,
    'opt': OPTMemoryAnalyzer,
    'stablelm': GPTNeoXMemoryAnalyzer,
    'pythia': GPTNeoXMemoryAnalyzer,
    'dolly-v2': GPTNeoXMemoryAnalyzer,
}


def _get_dtype(config: PretrainedConfig, dtype: str) -> torch.dtype:
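    """Resolve the torch dtype from the model config and the requested dtype.

    With 'default', the config dtype is kept, except that float32 models
    are downcast to float16.
    """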
    # NOTE: getattr(config, 'torch_dtype', torch.float32) is not correct
    # because config.torch_dtype can be None.
    config_dtype = getattr(config, 'torch_dtype', None)
    if config_dtype is None:
        config_dtype = torch.float32
    if dtype == 'default':
        if config_dtype == torch.float32:
            # Following the common practice, we use float16 for float32 models.
            torch_dtype = torch.float16
        else:
            torch_dtype = config_dtype
    else:
        torch_dtype = get_torch_dtype(dtype)
        if torch_dtype != config_dtype and config_dtype != torch.float32:
            # TODO(woosuk): Allow using float16 for bfloat16 models and
            # vice versa. Print a warning message and continue.
            raise ValueError(
                f'Cannot use {torch_dtype} for {config_dtype} model.')
    return torch_dtype


def get_model(
    model_name: str,
    dtype: str,
    cache_dir: Optional[str],
    use_dummy_weights: bool,
    use_np_cache: bool,
) -> Tuple[nn.Module, torch.dtype]:
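    """Instantiate the model on GPU and load or randomize its weights.

    Returns the model in eval mode together with the resolved torch dtype.
    """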
    config = AutoConfig.from_pretrained(model_name)
    torch_dtype = _get_dtype(config, dtype)
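    # Ensure that model parameters are created directly in the resolved dtype.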
    torch.set_default_dtype(torch_dtype)
    for model_class_name, model_class in _MODELS.items():
        if model_class_name in model_name:
            if use_dummy_weights:
                # Create a model instance.
                # The weights will be initialized as empty tensors.
                model = model_class(config)
                model = model.cuda()
                # NOTE(woosuk): For precise performance evaluation, we assign
                # random values to the weights.
                initialize_dummy_weights(model)
            else:
                # Create a model instance.
                model = model_class(config)
                # Load the weights from the cached or downloaded files.
                model.load_weights(model_name, cache_dir, use_np_cache)
                model = model.cuda()
            return model.eval(), torch_dtype
    raise ValueError(f'Unsupported model name: {model_name}')


def get_memory_analyzer(
    model_name: str,
    block_size: int,
    dtype: str,
    gpu_memory: int,
    cpu_memory: int,
    tensor_parallel_size: int = 1,
) -> CacheFlowMemoryAnalyzer:
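    """Return the memory analyzer that matches the given model name."""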
    config = AutoConfig.from_pretrained(model_name)
    torch_dtype = _get_dtype(config, dtype)
    for model_class_name, memory_analyzer in _MEMORY_ANALYZERS.items():
        if model_class_name in model_name:
            return memory_analyzer(
                model_name, block_size, torch_dtype, gpu_memory, cpu_memory,
                tensor_parallel_size)
    raise ValueError(f'Unsupported model name: {model_name}')
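

# A minimal usage sketch (illustrative only; the model name and argument
# values below are assumptions, not defaults defined in this module):
#
#     model, torch_dtype = get_model(
#         model_name='facebook/opt-125m',
#         dtype='default',
#         cache_dir=None,
#         use_dummy_weights=False,
#         use_np_cache=False,
#     )
#     memory_analyzer = get_memory_analyzer(
#         model_name='facebook/opt-125m',
#         block_size=16,
#         dtype='default',
#         gpu_memory=16 * 1024 ** 3,
#         cpu_memory=32 * 1024 ** 3,
#     )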