model_utils.py 1.21 KB
Newer Older
Woosuk Kwon's avatar
Woosuk Kwon committed
1
2
3
from typing import Union

import torch
Woosuk Kwon's avatar
Woosuk Kwon committed
4
5
import torch.nn as nn

6
7
from cacheflow.models.memory_analyzer import CacheFlowMemoryAnalyzer
from cacheflow.models.memory_analyzer import OPTMemoryAnalyzer
Woosuk Kwon's avatar
Minor  
Woosuk Kwon committed
8
from cacheflow.models.opt import OPTForCausalLM
9
from cacheflow.models.utils import get_torch_dtype
Woosuk Kwon's avatar
Woosuk Kwon committed
10

11
12

# Registry of model classes. Each key is a substring that is searched for
# in the requested model name; the value is the class used to load it.
_MODELS = {
    'opt': OPTForCausalLM,
}

16
17
# Registry of memory analyzer classes. Each key is a substring that is
# searched for in the requested model name; the value is the analyzer class.
_MEMORY_ANALYZERS = {
    'opt': OPTMemoryAnalyzer,
}

Woosuk Kwon's avatar
Woosuk Kwon committed
20

Woosuk Kwon's avatar
Woosuk Kwon committed
21
22
23
24
def get_model(
    model_name: str,
    dtype: Union[torch.dtype, str],
) -> nn.Module:
    """Load the model registered for ``model_name`` and put it in eval mode.

    The dtype is first normalized with ``get_torch_dtype``. The model class
    is the first entry of ``_MODELS`` (in insertion order) whose key occurs
    as a substring of ``model_name``.

    Args:
        model_name: Name passed to ``from_pretrained``; also used to pick
            the model class by substring match.
        dtype: Either a ``torch.dtype`` or a string accepted by
            ``get_torch_dtype``.

    Returns:
        The result of calling ``.eval()`` on the loaded model.

    Raises:
        ValueError: If no key of ``_MODELS`` is contained in ``model_name``.
    """
    resolved_dtype = get_torch_dtype(dtype)

    # Pick the first registered class whose key appears in the name.
    matched_cls = next(
        (cls for key, cls in _MODELS.items() if key in model_name),
        None,
    )
    if matched_cls is None:
        raise ValueError(f'Unsupported model name: {model_name}')

    loaded = matched_cls.from_pretrained(
        model_name, torch_dtype=resolved_dtype)
    return loaded.eval()
32
33


34
35
36
37
38
39
40
41
42
43
44
def get_memory_analyzer(
    model_name: str,
    block_size: int,
    dtype: Union[torch.dtype, str],
) -> CacheFlowMemoryAnalyzer:
    """Build the memory analyzer registered for ``model_name``.

    The dtype is first normalized with ``get_torch_dtype``. The analyzer
    class is the first entry of ``_MEMORY_ANALYZERS`` (in insertion order)
    whose key occurs as a substring of ``model_name``.

    Args:
        model_name: Model name; forwarded to the analyzer and used to pick
            the analyzer class by substring match.
        block_size: Forwarded unchanged to the analyzer constructor.
        dtype: Either a ``torch.dtype`` or a string accepted by
            ``get_torch_dtype``.

    Returns:
        An instance of the matching analyzer class, constructed with
        ``(model_name, block_size, <resolved dtype>)``.

    Raises:
        ValueError: If no key of ``_MEMORY_ANALYZERS`` is contained in
            ``model_name``.
    """
    resolved_dtype = get_torch_dtype(dtype)

    for name_fragment, analyzer_cls in _MEMORY_ANALYZERS.items():
        if name_fragment not in model_name:
            continue
        # First registered fragment found in the name wins.
        return analyzer_cls(model_name, block_size, resolved_dtype)

    raise ValueError(f'Unsupported model name: {model_name}')