model_utils.py 1.75 KB
Newer Older
Woosuk Kwon's avatar
Woosuk Kwon committed
1
2
from typing import Tuple, Union

Zhuohan Li's avatar
Zhuohan Li committed
3
import numpy as np
Woosuk Kwon's avatar
Woosuk Kwon committed
4
import torch
Woosuk Kwon's avatar
Woosuk Kwon committed
5
import torch.nn as nn
Zhuohan Li's avatar
Zhuohan Li committed
6
from transformers import AutoConfig
Woosuk Kwon's avatar
Woosuk Kwon committed
7

8
9
from cacheflow.models.memory_analyzer import CacheFlowMemoryAnalyzer
from cacheflow.models.memory_analyzer import OPTMemoryAnalyzer
Woosuk Kwon's avatar
Minor  
Woosuk Kwon committed
10
from cacheflow.models.opt import OPTForCausalLM
11
from cacheflow.models.utils import get_torch_dtype
Woosuk Kwon's avatar
Woosuk Kwon committed
12

13
14

# Registry of supported model implementations. Each key is a substring that
# is searched for in the requested model name; the value is the class used to
# build that model.
_MODELS = {
    'opt': OPTForCausalLM,
}

18
19
# Registry of memory analyzers. Each key is a substring that is searched for
# in the requested model name; the value is the analyzer class for that model.
_MEMORY_ANALYZERS = {
    'opt': OPTMemoryAnalyzer,
}

Woosuk Kwon's avatar
Woosuk Kwon committed
22

Woosuk Kwon's avatar
Woosuk Kwon committed
23
24
25
def get_model(
    model_name: str,
    dtype: Union[torch.dtype, str],
    path: str,
) -> Tuple[nn.Module, torch.dtype]:
    """Build a supported model, load its weights, and return it in eval mode.

    The implementation class is chosen by substring match: the first key of
    ``_MODELS`` that occurs in ``model_name`` is used.

    Args:
        model_name: Name handed to ``AutoConfig.from_pretrained`` and to the
            model class's ``download_weights``.
        dtype: A torch dtype or a string, resolved via ``get_torch_dtype``.
        path: Forwarded to ``download_weights`` as its ``path`` argument.

    Returns:
        A ``(model, torch_dtype)`` tuple: the model after ``.eval()`` and the
        resolved torch dtype.

    Raises:
        ValueError: If no key of ``_MODELS`` occurs in ``model_name``.

    Note:
        This calls ``torch.set_default_dtype``, which changes the default
        floating-point dtype for the whole process. The previous default is
        not restored, including when ``ValueError`` is raised.
    """
    torch_dtype = get_torch_dtype(dtype)
    # Set the default dtype before the model is constructed.
    torch.set_default_dtype(torch_dtype)
    config = AutoConfig.from_pretrained(model_name)
    for model_class_name, model_class in _MODELS.items():
        if model_class_name in model_name:
            # Download model weights if it's not cached.
            weights_dir = model_class.download_weights(model_name, path=path)
            # Create a model instance.
            model = model_class(config)
            # Load the weights from the cached or downloaded files.
            model.load_weights(weights_dir)
            return model.eval(), torch_dtype
    raise ValueError(f'Unsupported model name: {model_name}')
41
42


43
44
45
46
def get_memory_analyzer(
    model_name: str,
    block_size: int,
    dtype: Union[torch.dtype, str],
    gpu_memory: int,
    cpu_memory: int,
    tensor_parallel_size: int = 1,
) -> CacheFlowMemoryAnalyzer:
    """Create the memory analyzer that matches ``model_name``.

    The analyzer class is chosen by substring match: the first key of
    ``_MEMORY_ANALYZERS`` that occurs in ``model_name`` is used.

    Args:
        model_name: Model name; also passed through to the analyzer.
        block_size: Passed through to the analyzer unchanged.
        dtype: A torch dtype or a string, resolved via ``get_torch_dtype``
            before being passed to the analyzer.
        gpu_memory: Passed through to the analyzer unchanged.
        cpu_memory: Passed through to the analyzer unchanged.
        tensor_parallel_size: Passed through to the analyzer unchanged.

    Returns:
        An instance of the matching analyzer class.

    Raises:
        ValueError: If no key of ``_MEMORY_ANALYZERS`` occurs in
            ``model_name``.
    """
    torch_dtype = get_torch_dtype(dtype)
    for name_key, analyzer_class in _MEMORY_ANALYZERS.items():
        if name_key in model_name:
            return analyzer_class(
                model_name, block_size, torch_dtype, gpu_memory, cpu_memory,
                tensor_parallel_size)
    raise ValueError(f'Unsupported model name: {model_name}')