model_utils.py 2.42 KB
Newer Older
Woosuk Kwon's avatar
Woosuk Kwon committed
1
2
from typing import Tuple, Union

import numpy as np
import torch
import torch.nn as nn
from transformers import AutoConfig

from cacheflow.models.llama import LlamaForCausalLM
from cacheflow.models.memory_analyzer import CacheFlowMemoryAnalyzer
from cacheflow.models.memory_analyzer import LlamaMemoryAnalyzer
from cacheflow.models.memory_analyzer import OPTMemoryAnalyzer
from cacheflow.models.opt import OPTForCausalLM
from cacheflow.models.utils import get_torch_dtype

# Registry of supported model classes, keyed by a substring that is
# matched against the Hugging Face model name (e.g. 'facebook/opt-125m').
_MODELS = {
    'llama': LlamaForCausalLM,
    'opt': OPTForCausalLM,
}

# Registry of memory analyzers, keyed by the same name substrings as
# _MODELS; the two registries must stay in sync.
_MEMORY_ANALYZERS = {
    'llama': LlamaMemoryAnalyzer,
    'opt': OPTMemoryAnalyzer,
}


def get_model(
    model_name: str,
    dtype: Union[torch.dtype, str],
    path: str,
    use_dummy_weights: bool,
) -> Tuple[nn.Module, torch.dtype]:
    """Instantiate a supported model on the GPU and put it in eval mode.

    Args:
        model_name: Hugging Face model name; matched by substring against
            the keys of ``_MODELS`` (e.g. 'opt' matches 'facebook/opt-13b').
        dtype: Target dtype, either a ``torch.dtype`` or a string accepted
            by ``get_torch_dtype``.
        path: Directory used to cache downloaded weights.
        use_dummy_weights: If True, skip downloading and fill the weights
            with random values instead (for performance evaluation).

    Returns:
        A tuple ``(model, torch_dtype)`` where ``model`` is the CUDA model
        in eval mode. (The original annotation said ``nn.Module``, but the
        function has always returned this tuple.)

    Raises:
        ValueError: If no entry in ``_MODELS`` matches ``model_name``.
    """
    torch_dtype = get_torch_dtype(dtype)
    # NOTE: This mutates the process-wide default dtype and does not
    # restore it; model parameters created below inherit torch_dtype.
    torch.set_default_dtype(torch_dtype)
    config = AutoConfig.from_pretrained(model_name)
    for model_class_name, model_class in _MODELS.items():
        if model_class_name in model_name:
            if use_dummy_weights:
                # Create a model instance.
                # The weights will be initialized as empty tensors.
                model = model_class(config)
                model = model.cuda()
                # NOTE(woosuk): For precise performance evaluation, we assign
                # random values to the weights.
                model.initialize_dummy_weights()
            else:
                # Download model weights if it's not cached.
                weights_dir = model_class.get_weights(model_name, path=path)
                # Create a model instance.
                model = model_class(config)
                # Load the weights from the cached or downloaded files.
                model.load_weights(weights_dir)
                model = model.cuda()
            return model.eval(), torch_dtype
    raise ValueError(f'Unsupported model name: {model_name}')


def get_memory_analyzer(
    model_name: str,
    block_size: int,
    dtype: Union[torch.dtype, str],
    gpu_memory: int,
    cpu_memory: int,
    tensor_parallel_size: int = 1,
) -> CacheFlowMemoryAnalyzer:
    """Construct the memory analyzer registered for the given model.

    The keys of ``_MEMORY_ANALYZERS`` are matched as substrings of
    ``model_name`` and the first match wins.

    Raises:
        ValueError: If no registered analyzer matches ``model_name``.
    """
    torch_dtype = get_torch_dtype(dtype)
    for name_key, analyzer_cls in _MEMORY_ANALYZERS.items():
        if name_key in model_name:
            return analyzer_cls(
                model_name, block_size, torch_dtype, gpu_memory, cpu_memory,
                tensor_parallel_size)
    raise ValueError(f'Unsupported model name: {model_name}')