"tests/models/test_layers_utils.py" did not exist on "02a76c2c81915846eb679ce9f24fbe9806e49c20"
model_utils.py 3.66 KB
Newer Older
1
from typing import Optional
Woosuk Kwon's avatar
Woosuk Kwon committed
2
3

import torch
Woosuk Kwon's avatar
Woosuk Kwon committed
4
import torch.nn as nn
Zhuohan Li's avatar
Zhuohan Li committed
5
from transformers import AutoConfig
6
from transformers import PretrainedConfig
Woosuk Kwon's avatar
Woosuk Kwon committed
7

8
from cacheflow.models.memory_analyzer import CacheFlowMemoryAnalyzer
Woosuk Kwon's avatar
Woosuk Kwon committed
9
from cacheflow.models.memory_analyzer import GPT2MemoryAnalyzer
10
from cacheflow.models.memory_analyzer import GPTNeoXMemoryAnalyzer
Woosuk Kwon's avatar
Woosuk Kwon committed
11
from cacheflow.models.memory_analyzer import LlamaMemoryAnalyzer
12
from cacheflow.models.memory_analyzer import OPTMemoryAnalyzer
Woosuk Kwon's avatar
Woosuk Kwon committed
13
from cacheflow.models.gpt2 import GPT2LMHeadModel
14
from cacheflow.models.gpt_neox import GPTNeoXForCausalLM
Woosuk Kwon's avatar
Woosuk Kwon committed
15
from cacheflow.models.llama import LlamaForCausalLM
Woosuk Kwon's avatar
Minor  
Woosuk Kwon committed
16
from cacheflow.models.opt import OPTForCausalLM
17
from cacheflow.models.utils import get_torch_dtype
Woosuk Kwon's avatar
Woosuk Kwon committed
18

19
20

_MODELS = {
Woosuk Kwon's avatar
Woosuk Kwon committed
21
    'gpt2': GPT2LMHeadModel,
Woosuk Kwon's avatar
Woosuk Kwon committed
22
    'llama': LlamaForCausalLM,
Woosuk Kwon's avatar
Woosuk Kwon committed
23
    'opt': OPTForCausalLM,
24
25
    'stablelm': GPTNeoXForCausalLM,
    'pythia': GPTNeoXForCausalLM,
26
    'dolly-v2': GPTNeoXForCausalLM,
Woosuk Kwon's avatar
Woosuk Kwon committed
27
28
}

29
_MEMORY_ANALYZERS = {
Woosuk Kwon's avatar
Woosuk Kwon committed
30
    'gpt2': GPT2MemoryAnalyzer,
Woosuk Kwon's avatar
Woosuk Kwon committed
31
    'llama': LlamaMemoryAnalyzer,
32
    'opt': OPTMemoryAnalyzer,
33
34
    'stablelm': GPTNeoXMemoryAnalyzer,
    'pythia': GPTNeoXMemoryAnalyzer,
35
    'dolly-v2': GPTNeoXMemoryAnalyzer,
Woosuk Kwon's avatar
Woosuk Kwon committed
36
37
}

Woosuk Kwon's avatar
Woosuk Kwon committed
38

39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
def _get_dtype(config: PretrainedConfig, dtype: str) -> torch.dtype:
    config_dtype: torch.dtype = getattr(config, 'torch_dtype', torch.float32)
    if dtype == 'default':
        if config_dtype == torch.float32:
            # Following the common practice, we use float16 for float32 models.
            torch_dtype = torch.float16
        else:
            torch_dtype = config_dtype
    else:
        torch_dtype = get_torch_dtype(dtype)
        if torch_dtype != config_dtype and config_dtype != torch.float32:
            # TODO(woosuk): Allow using float16 for bfloat16 models and
            # vice versa. Print a warning message and continue.
            raise ValueError(
                f'Cannot use {torch_dtype} for {config_dtype} model.')
    return torch_dtype


Woosuk Kwon's avatar
Woosuk Kwon committed
57
58
def get_model(
    model_name: str,
59
    dtype: str,
60
    cache_dir: Optional[str],
61
    use_dummy_weights: bool,
62
    use_np_cache: bool,
Woosuk Kwon's avatar
Woosuk Kwon committed
63
) -> nn.Module:
Zhuohan Li's avatar
Zhuohan Li committed
64
    config = AutoConfig.from_pretrained(model_name)
65
66
    torch_dtype = _get_dtype(config, dtype)
    torch.set_default_dtype(torch_dtype)
Zhuohan Li's avatar
Zhuohan Li committed
67
68
    for model_class_name, model_class in _MODELS.items():
        if model_class_name in model_name:
69
70
71
72
73
74
            if use_dummy_weights:
                # Create a model instance.
                # The weights will be initialized as empty tensors.
                model = model_class(config)
                model = model.cuda()
                # NOTE(woosuk): For precise performance evaluation, we assign
75
                # random values to the weights.
76
77
78
79
80
                model.initialize_dummy_weights()
            else:
                # Create a model instance.
                model = model_class(config)
                # Load the weights from the cached or downloaded files.
81
                model.load_weights(model_name, cache_dir, use_np_cache)
82
                model = model.cuda()
Zhuohan Li's avatar
Zhuohan Li committed
83
            return model.eval(), torch_dtype
84
    raise ValueError(f'Unsupported model name: {model_name}')
85
86


87
88
89
def get_memory_analyzer(
    model_name: str,
    block_size: int,
90
    dtype: str,
91
92
    gpu_memory: int,
    cpu_memory: int,
Zhuohan Li's avatar
Zhuohan Li committed
93
    tensor_parallel_size: int = 1,
94
) -> CacheFlowMemoryAnalyzer:
95
96
    config = AutoConfig.from_pretrained(model_name)
    torch_dtype = _get_dtype(config, dtype)
97
98
99
    for model_class, memory_analyzer in _MEMORY_ANALYZERS.items():
        if model_class in model_name:
            return memory_analyzer(
100
101
                model_name, block_size, torch_dtype, gpu_memory, cpu_memory,
                tensor_parallel_size)
102
    raise ValueError(f'Unsupported model name: {model_name}')