# model_utils.py
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from transformers import AutoConfig

from cacheflow.models.memory_analyzer import CacheFlowMemoryAnalyzer
from cacheflow.models.memory_analyzer import GPT2MemoryAnalyzer
from cacheflow.models.memory_analyzer import GPTNeoXMemoryAnalyzer
from cacheflow.models.memory_analyzer import LlamaMemoryAnalyzer
from cacheflow.models.memory_analyzer import OPTMemoryAnalyzer
from cacheflow.models.gpt2 import GPT2LMHeadModel
from cacheflow.models.gpt_neox import GPTNeoXForCausalLM
from cacheflow.models.llama import LlamaForCausalLM
from cacheflow.models.opt import OPTForCausalLM
from cacheflow.models.utils import get_torch_dtype


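# Model classes and memory analyzers are selected by substring match on the
# model name (e.g., 'facebook/opt-125m' matches the 'opt' entry below).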
_MODELS = {
    'gpt2': GPT2LMHeadModel,
    'llama': LlamaForCausalLM,
    'opt': OPTForCausalLM,
    'stablelm': GPTNeoXForCausalLM,
    'pythia': GPTNeoXForCausalLM,
}

_MEMORY_ANALYZERS = {
    'gpt2': GPT2MemoryAnalyzer,
    'llama': LlamaMemoryAnalyzer,
    'opt': OPTMemoryAnalyzer,
    'stablelm': GPTNeoXMemoryAnalyzer,
    'pythia': GPTNeoXMemoryAnalyzer,
}


def get_model(
    model_name: str,
    dtype: Union[torch.dtype, str],
    cache_dir: Optional[str],
    use_dummy_weights: bool,
    use_np_cache: bool,
) -> Tuple[nn.Module, torch.dtype]:
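    """Instantiate a supported model and move it to the GPU.

    Returns a tuple of the model in eval mode and the torch dtype its
    weights were created with.
    """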
    torch_dtype = get_torch_dtype(dtype)
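    # NOTE: this changes PyTorch's global default dtype, so the model
    # parameters created below are allocated in torch_dtype.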
    torch.set_default_dtype(torch_dtype)
    config = AutoConfig.from_pretrained(model_name)
    for model_class_name, model_class in _MODELS.items():
        if model_class_name in model_name:
            if use_dummy_weights:
                # Create a model instance.
                # The weights will be initialized as empty tensors.
                model = model_class(config)
                model = model.cuda()
                # NOTE(woosuk): For precise performance evaluation, we assign
                # random values to the weights.
                model.initialize_dummy_weights()
            else:
                # Create a model instance.
                model = model_class(config)
                # Load the weights from the cached or downloaded files.
                model.load_weights(model_name, cache_dir, use_np_cache)
                model = model.cuda()
            return model.eval(), torch_dtype
    raise ValueError(f'Unsupported model name: {model_name}')


def get_memory_analyzer(
    model_name: str,
    block_size: int,
    dtype: Union[torch.dtype, str],
    gpu_memory: int,
    cpu_memory: int,
    tensor_parallel_size: int = 1,
) -> CacheFlowMemoryAnalyzer:
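    """Return the memory analyzer whose key is a substring of model_name."""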
    torch_dtype = get_torch_dtype(dtype)
    for model_class_name, memory_analyzer in _MEMORY_ANALYZERS.items():
        if model_class_name in model_name:
            return memory_analyzer(
                model_name, block_size, torch_dtype, gpu_memory, cpu_memory,
                tensor_parallel_size)
    raise ValueError(f'Unsupported model name: {model_name}')
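

if __name__ == '__main__':
    # Minimal usage sketch, not part of the module's API. The model name and
    # memory sizes below are illustrative assumptions; running this requires
    # a CUDA-capable GPU and access to the Hugging Face config for the model.
    model, torch_dtype = get_model(
        model_name='facebook/opt-125m',
        dtype=torch.float16,
        cache_dir=None,
        use_dummy_weights=True,  # skip downloading real weights
        use_np_cache=False,
    )
    analyzer = get_memory_analyzer(
        model_name='facebook/opt-125m',
        block_size=16,              # illustrative KV-cache block size
        dtype=torch_dtype,
        gpu_memory=16 * 1024 ** 3,  # 16 GiB, illustrative
        cpu_memory=32 * 1024 ** 3,  # 32 GiB, illustrative
    )
    print(type(model).__name__, type(analyzer).__name__)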