import torch
from transformers import AutoConfig

from cacheflow.models.utils import get_dtype_size

_GiB = 1 << 30


class CacheFlowMemoryAnalyzer:
    """Abstract interface for model-specific memory analyzers.

    A subclass estimates, for one model architecture, how many KV-cache
    blocks fit in the available GPU memory and in the CPU swap space.
    """

    def get_max_num_gpu_blocks(
        self,
        max_num_batched_tokens: int,
        memory_utilization: float,
    ) -> int:
        """Return the max number of KV-cache blocks that fit on the GPU.

        Args:
            max_num_batched_tokens: Upper bound on tokens processed in one
                batch; used to estimate peak activation memory.
            memory_utilization: Fraction of total GPU memory to use.
        """
        raise NotImplementedError()

    def get_max_num_cpu_blocks(
        self,
        # NOTE: Aligned with the concrete implementation
        # (OPTMemoryAnalyzer.get_max_num_cpu_blocks), which takes the swap
        # space in GiB rather than a utilization fraction.
        swap_space: int,
    ) -> int:
        """Return the max number of KV-cache blocks for `swap_space` GiB of CPU memory."""
        raise NotImplementedError()


class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer):
    """Memory analyzer for OPT models.

    Estimates parameter, peak-activation, and workspace memory so that the
    remaining GPU memory can be carved into KV-cache blocks, and converts a
    CPU swap-space budget (in GiB) into a block count.
    """

    def __init__(
        self,
        model_name: str,
        block_size: int,
        dtype: torch.dtype,
        gpu_memory: int,
        cpu_memory: int,
        tensor_parallel_size: int,
    ) -> None:
        """Load the model config and cache the dimensions used for sizing.

        Args:
            model_name: HuggingFace model name or path; used to fetch config.
            block_size: Number of tokens stored per KV-cache block.
            dtype: Data type of parameters and KV-cache entries.
            gpu_memory: Total GPU memory in bytes.
            cpu_memory: Total CPU memory in bytes.
            tensor_parallel_size: Number of tensor-parallel workers the
                weights and activations are sharded across.
        """
        self.model_name = model_name
        self.block_size = block_size
        self.dtype = dtype
        self.gpu_memory = gpu_memory
        self.cpu_memory = cpu_memory
        self.tensor_parallel_size = tensor_parallel_size

        config = AutoConfig.from_pretrained(model_name)
        self.num_layers = config.num_hidden_layers
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_size = config.hidden_size // self.num_heads
        self.ffn_size = config.ffn_dim
        self.embedding_size = config.word_embed_proj_dim
        self.vocab_size = config.vocab_size
        self.max_position = config.max_position_embeddings

    def _get_param_size(self) -> int:
        """Return the per-GPU parameter size in bytes."""
        # The word embedding is sharded across tensor-parallel workers.
        word_embedding = self.vocab_size * self.embedding_size // self.tensor_parallel_size
        if self.embedding_size != self.hidden_size:
            # OPT inserts project_in/project_out linear layers (no bias)
            # only when word_embed_proj_dim differs from hidden_size, each
            # of shape (word_embed_proj_dim, hidden_size).
            # Fixes the previous comparison/sizing against vocab_size,
            # which both mis-triggered and over-counted these layers.
            word_embedding += 2 * self.embedding_size * self.hidden_size
        position_embedding = self.max_position * self.hidden_size

        # Attention block: LayerNorm plus Q/K/V/output projections.
        # Weights are sharded across workers; biases are replicated.
        ln1 = 2 * self.hidden_size
        qkvo = 4 * (self.hidden_size * self.hidden_size // self.tensor_parallel_size
                    + self.hidden_size)
        mha = ln1 + qkvo

        # FFN block: LayerNorm plus the two sharded linear layers.
        ln2 = 2 * self.hidden_size
        ffn1 = self.hidden_size * self.ffn_size // self.tensor_parallel_size + self.ffn_size
        ffn2 = self.ffn_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
        ffn = ln2 + ffn1 + ffn2

        total = (word_embedding + position_embedding +
                 self.num_layers * (mha + ffn))
        dtype_size = get_dtype_size(self.dtype)
        return dtype_size * total

    def _get_max_act_size(
        self,
        max_num_batched_tokens: int,
    ) -> int:
        """Return an upper bound on per-GPU activation memory in bytes.

        Approximates the peak by 1) the largest intermediate activation
        tensor and 2) the residual tensor, doubled to account for input
        and output living simultaneously. Assumes FlashAttention, so the
        attention maps are never materialized in GPU DRAM.
        """
        residual = max_num_batched_tokens * self.hidden_size
        qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size
        ffn = max_num_batched_tokens * self.ffn_size // self.tensor_parallel_size
        # Double the activation size for input and output.
        max_act = 2 * (max(qkv, ffn) + residual)
        dtype_size = get_dtype_size(self.dtype)
        return dtype_size * max_act

    def _get_workspace_size(self) -> int:
        """Return fixed headroom (bytes) reserved for temporary buffers."""
        return 1 * _GiB

    def _get_cache_block_size(self) -> int:
        """Return the size in bytes of one KV-cache block across all layers."""
        # One block holds `block_size` tokens; key and value caches are
        # the same size per layer.
        key_cache_block = self.block_size * self.num_heads * self.head_size
        value_cache_block = self.block_size * self.num_heads * self.head_size
        total = self.num_layers * (key_cache_block + value_cache_block)
        dtype_size = get_dtype_size(self.dtype)
        return dtype_size * total

    def get_max_num_gpu_blocks(
        self,
        max_num_batched_tokens: int,
        memory_utilization: float = 0.95,
    ) -> int:
        """Return how many KV-cache blocks fit in the usable GPU memory.

        Raises:
            ValueError: If parameters, activations, and workspace already
                exceed the usable GPU memory, leaving no room for the cache.
        """
        # NOTE(woosuk): This assumes that the machine has homogeneous GPUs.
        usable_memory = int(memory_utilization * self.gpu_memory)

        param_size = self._get_param_size()
        act_size = self._get_max_act_size(max_num_batched_tokens)
        workspace_size = self._get_workspace_size()

        max_cache_size = usable_memory - (param_size + act_size + workspace_size)
        if max_cache_size <= 0:
            # Previously a negative block count was silently returned.
            raise ValueError(
                'Not enough GPU memory for the model weights, activations, '
                'and workspace; no space is left for the KV cache.')
        max_num_blocks = max_cache_size // self._get_cache_block_size()
        return max_num_blocks

    def get_max_num_cpu_blocks(
        self,
        swap_space: int,
    ) -> int:
        """Return how many KV-cache blocks fit in `swap_space` GiB of CPU memory.

        Raises:
            ValueError: If the swap space exceeds 80% of CPU memory.
        """
        swap_space = swap_space * _GiB
        if swap_space > 0.8 * self.cpu_memory:
            # Added the missing space after 'GiB).' in the concatenated message.
            raise ValueError(f'The swap space ({swap_space / _GiB:.2f} GiB) '
                             'takes more than 80% of the available memory '
                             f'({self.cpu_memory / _GiB:.2f} GiB). '
                             'Please check the swap space size.')
        if swap_space > 0.5 * self.cpu_memory:
            print(f'WARNING: The swap space ({swap_space / _GiB:.2f} GiB) '
                  'takes more than 50% of the available memory '
                  f'({self.cpu_memory / _GiB:.2f} GiB). '
                  'This may slow the system performance.')
        max_num_blocks = swap_space // self._get_cache_block_size()
        return max_num_blocks