model_utils.py 2.66 KB
Newer Older
Woosuk Kwon's avatar
Woosuk Kwon committed
1
2
3
from typing import Tuple, Union

import torch
import torch.nn as nn
from transformers import AutoConfig

from cacheflow.models.gpt_neox import GPTNeoXForCausalLM
from cacheflow.models.llama import LlamaForCausalLM
from cacheflow.models.memory_analyzer import CacheFlowMemoryAnalyzer
from cacheflow.models.memory_analyzer import GPTNeoXMemoryAnalyzer
from cacheflow.models.memory_analyzer import LlamaMemoryAnalyzer
from cacheflow.models.memory_analyzer import OPTMemoryAnalyzer
from cacheflow.models.opt import OPTForCausalLM
from cacheflow.models.utils import get_torch_dtype
Woosuk Kwon's avatar
Woosuk Kwon committed
15

16
17

# Maps a substring of the HuggingFace model name to the model class that
# serves it. get_model matches keys with `key in model_name`, so insertion
# order determines which entry wins if several keys appear in the name.
# 'stablelm' and 'pythia' checkpoints both use the GPT-NeoX architecture.
_MODELS = {
    'llama': LlamaForCausalLM,
    'opt': OPTForCausalLM,
    'stablelm': GPTNeoXForCausalLM,
    'pythia': GPTNeoXForCausalLM,
}

24
# Maps a substring of the HuggingFace model name to the memory analyzer for
# that architecture. Keys must stay in sync with _MODELS: get_memory_analyzer
# matches them with `key in model_name`, so insertion order sets priority.
_MEMORY_ANALYZERS = {
    'llama': LlamaMemoryAnalyzer,
    'opt': OPTMemoryAnalyzer,
    'stablelm': GPTNeoXMemoryAnalyzer,
    'pythia': GPTNeoXMemoryAnalyzer,
}

Woosuk Kwon's avatar
Woosuk Kwon committed
31

Woosuk Kwon's avatar
Woosuk Kwon committed
32
33
34
def get_model(
    model_name: str,
    dtype: Union[torch.dtype, str],
    path: str,
    use_dummy_weights: bool,
) -> Tuple[nn.Module, torch.dtype]:
    """Build a model on the GPU and return it with its torch dtype.

    The model class is chosen by substring-matching the keys of _MODELS
    against `model_name`.

    Args:
        model_name: HuggingFace model name (e.g. 'facebook/opt-125m').
        dtype: Desired parameter dtype, as a torch.dtype or its string name.
        path: Directory used to cache downloaded weight files.
        use_dummy_weights: If True, skip downloading and fill the weights
            with random values (for performance evaluation).

    Returns:
        A tuple of (model in eval mode on CUDA, the resolved torch dtype).

    Raises:
        ValueError: If no key of _MODELS occurs in `model_name`.
    """
    torch_dtype = get_torch_dtype(dtype)
    # NOTE(review): this mutates the process-wide default dtype so that the
    # model is created directly in `torch_dtype`, and it is never restored —
    # confirm callers do not depend on the previous default.
    torch.set_default_dtype(torch_dtype)
    config = AutoConfig.from_pretrained(model_name)
    for model_class_name, model_class in _MODELS.items():
        if model_class_name in model_name:
            if use_dummy_weights:
                # Create a model instance.
                # The weights will be initialized as empty tensors.
                model = model_class(config)
                model = model.cuda()
                # NOTE(woosuk): For precise performance evaluation, we assign
                # random values to the weights.
                model.initialize_dummy_weights()
            else:
                # Download model weights if it's not cached.
                weights_dir = model_class.get_weights(model_name, path=path)
                # Create a model instance.
                model = model_class(config)
                # Load the weights from the cached or downloaded files.
                model.load_weights(weights_dir)
                model = model.cuda()
            return model.eval(), torch_dtype
    raise ValueError(f'Unsupported model name: {model_name}')
61
62


63
64
65
66
def get_memory_analyzer(
    model_name: str,
    block_size: int,
    dtype: Union[torch.dtype, str],
    gpu_memory: int,
    cpu_memory: int,
    tensor_parallel_size: int = 1,
) -> CacheFlowMemoryAnalyzer:
    """Construct the memory analyzer matching `model_name`.

    The analyzer class is selected by substring-matching the keys of
    _MEMORY_ANALYZERS against the model name; raises ValueError when no
    key matches.
    """
    torch_dtype = get_torch_dtype(dtype)
    for arch_key, analyzer_cls in _MEMORY_ANALYZERS.items():
        if arch_key not in model_name:
            continue
        return analyzer_cls(
            model_name, block_size, torch_dtype, gpu_memory, cpu_memory,
            tensor_parallel_size)
    raise ValueError(f'Unsupported model name: {model_name}')