from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from transformers import AutoConfig

from cacheflow.models.gpt_neox import GPTNeoXForCausalLM
from cacheflow.models.llama import LlamaForCausalLM
from cacheflow.models.memory_analyzer import CacheFlowMemoryAnalyzer
from cacheflow.models.memory_analyzer import GPTNeoXMemoryAnalyzer
from cacheflow.models.memory_analyzer import LlamaMemoryAnalyzer
from cacheflow.models.memory_analyzer import OPTMemoryAnalyzer
from cacheflow.models.opt import OPTForCausalLM
from cacheflow.models.utils import get_torch_dtype

# Maps a model-name keyword to the model class implementing it.
# get_model() matches these keys by substring, so e.g. a model name
# containing 'opt' selects OPTForCausalLM.
_MODELS = dict(
    llama=LlamaForCausalLM,
    opt=OPTForCausalLM,
    stablelm=GPTNeoXForCausalLM,
    pythia=GPTNeoXForCausalLM,
)

24
# Maps a model-name keyword to its memory analyzer class.
# Keys mirror _MODELS; get_memory_analyzer() matches them by substring.
_MEMORY_ANALYZERS = dict(
    llama=LlamaMemoryAnalyzer,
    opt=OPTMemoryAnalyzer,
    stablelm=GPTNeoXMemoryAnalyzer,
    pythia=GPTNeoXMemoryAnalyzer,
)

Woosuk Kwon's avatar
Woosuk Kwon committed
31

Woosuk Kwon's avatar
Woosuk Kwon committed
32
33
34
def get_model(
    model_name: str,
    dtype: Union[torch.dtype, str],
    cache_dir: Optional[str],
    use_dummy_weights: bool,
    use_np_cache: bool,
) -> Tuple[nn.Module, torch.dtype]:
    """Instantiate a supported causal-LM model and move it to the GPU.

    Args:
        model_name: HuggingFace-style model name. It is matched by
            substring against the keys of _MODELS (e.g. a name containing
            'opt' selects OPTForCausalLM).
        dtype: Target dtype, either a torch.dtype or its string name.
        cache_dir: Directory holding cached weight files, or None for the
            default cache location.
        use_dummy_weights: If True, skip downloading/loading checkpoints
            and fill the weights with random values instead (used for
            performance evaluation).
        use_np_cache: Passed through to the model's load_weights(); enables
            the numpy weight-cache loading path.

    Returns:
        A tuple of (model in eval mode on the GPU, resolved torch dtype).
        NOTE: the original annotation said ``nn.Module`` but the function
        has always returned this 2-tuple; the annotation is now correct.

    Raises:
        ValueError: If no key in _MODELS is a substring of model_name.
    """
    torch_dtype = get_torch_dtype(dtype)
    # NOTE(review): this changes the process-wide default dtype and is
    # never restored afterwards.
    torch.set_default_dtype(torch_dtype)
    config = AutoConfig.from_pretrained(model_name)
    for model_class_name, model_class in _MODELS.items():
        if model_class_name in model_name:
            if use_dummy_weights:
                # Create a model instance.
                # The weights will be initialized as empty tensors.
                model = model_class(config)
                model = model.cuda()
                # NOTE(woosuk): For precise performance evaluation, we assign
                # random values to the weights.
                model.initialize_dummy_weights()
            else:
                # Create a model instance.
                model = model_class(config)
                # Load the weights from the cached or downloaded files.
                model.load_weights(model_name, cache_dir, use_np_cache)
                model = model.cuda()
            return model.eval(), torch_dtype
    raise ValueError(f'Unsupported model name: {model_name}')
60
61


62
63
64
65
def get_memory_analyzer(
    model_name: str,
    block_size: int,
    dtype: Union[torch.dtype, str],
    gpu_memory: int,
    cpu_memory: int,
    tensor_parallel_size: int = 1,
) -> CacheFlowMemoryAnalyzer:
    """Construct the memory analyzer matching ``model_name``.

    The keys of _MEMORY_ANALYZERS are matched by substring against
    ``model_name``; the first hit wins. Raises ValueError when no key
    matches.
    """
    torch_dtype = get_torch_dtype(dtype)
    analyzer_cls = next(
        (cls for keyword, cls in _MEMORY_ANALYZERS.items()
         if keyword in model_name),
        None)
    if analyzer_cls is None:
        raise ValueError(f'Unsupported model name: {model_name}')
    return analyzer_cls(
        model_name, block_size, torch_dtype, gpu_memory, cpu_memory,
        tensor_parallel_size)