model_utils.py 1.92 KB
Newer Older
Woosuk Kwon's avatar
Woosuk Kwon committed
1
2
from typing import Tuple, Union

Zhuohan Li's avatar
Zhuohan Li committed
3
import numpy as np
Woosuk Kwon's avatar
Woosuk Kwon committed
4
import torch
Woosuk Kwon's avatar
Woosuk Kwon committed
5
import torch.nn as nn
Zhuohan Li's avatar
Zhuohan Li committed
6
from transformers import AutoConfig
Woosuk Kwon's avatar
Woosuk Kwon committed
7

8
from cacheflow.models.memory_analyzer import CacheFlowMemoryAnalyzer
Woosuk Kwon's avatar
Woosuk Kwon committed
9
from cacheflow.models.memory_analyzer import LlamaMemoryAnalyzer
10
from cacheflow.models.memory_analyzer import OPTMemoryAnalyzer
Woosuk Kwon's avatar
Woosuk Kwon committed
11
from cacheflow.models.llama import LlamaForCausalLM
Woosuk Kwon's avatar
Minor  
Woosuk Kwon committed
12
from cacheflow.models.opt import OPTForCausalLM
13
from cacheflow.models.utils import get_torch_dtype
Woosuk Kwon's avatar
Woosuk Kwon committed
14

15
16

# Maps a substring of the HF model name to the model class implementing it.
# get_model() selects the class whose key occurs inside `model_name`.
_MODELS = {
    'llama': LlamaForCausalLM,
    'opt': OPTForCausalLM,
}

21
# Maps a substring of the HF model name to its memory analyzer class.
# Keys must stay in sync with _MODELS; get_memory_analyzer() does the
# same substring lookup as get_model().
_MEMORY_ANALYZERS = {
    'llama': LlamaMemoryAnalyzer,
    'opt': OPTMemoryAnalyzer,
}

Woosuk Kwon's avatar
Woosuk Kwon committed
26

Woosuk Kwon's avatar
Woosuk Kwon committed
27
28
29
def get_model(
    model_name: str,
    dtype: Union[torch.dtype, str],
    path: str,
) -> Tuple[nn.Module, torch.dtype]:
    """Load a supported model and return it in eval mode.

    Fix: the original annotation claimed ``-> nn.Module`` but the function
    returns a ``(model, dtype)`` tuple.

    Args:
        model_name: HF model identifier; matched by substring against the
            keys of _MODELS.
        dtype: target torch dtype, or a string alias understood by
            get_torch_dtype.
        path: local directory used to cache downloaded weights.

    Returns:
        Tuple of (model in eval mode, resolved torch dtype).

    Raises:
        ValueError: if no key in _MODELS is a substring of model_name.
    """
    torch_dtype = get_torch_dtype(dtype)
    # Make freshly created parameters use the target dtype directly,
    # avoiding a post-hoc cast of the whole model.
    torch.set_default_dtype(torch_dtype)
    config = AutoConfig.from_pretrained(model_name)
    for model_class_name, model_class in _MODELS.items():
        if model_class_name in model_name:
            # Download model weights if they are not cached yet.
            weights_dir = model_class.get_weights(model_name, path=path)
            # Create a model instance.
            model = model_class(config)
            # Load the weights from the cached or downloaded files.
            model.load_weights(weights_dir)
            return model.eval(), torch_dtype
    raise ValueError(f'Unsupported model name: {model_name}')
45
46


47
48
49
50
def get_memory_analyzer(
    model_name: str,
    block_size: int,
    dtype: Union[torch.dtype, str],
    gpu_memory: int,
    cpu_memory: int,
    tensor_parallel_size: int = 1,
) -> CacheFlowMemoryAnalyzer:
    """Build the memory analyzer matching *model_name*.

    The analyzer class is chosen by substring match of the keys of
    _MEMORY_ANALYZERS against the model name, mirroring get_model's
    lookup scheme.

    Raises:
        ValueError: if no analyzer key is a substring of model_name.
    """
    torch_dtype = get_torch_dtype(dtype)
    for key, analyzer_cls in _MEMORY_ANALYZERS.items():
        if key not in model_name:
            continue
        return analyzer_cls(
            model_name, block_size, torch_dtype, gpu_memory, cpu_memory,
            tensor_parallel_size)
    raise ValueError(f'Unsupported model name: {model_name}')