# memory_analyzer.py
# Memory analyzers that estimate how many cache blocks fit in GPU and CPU
# memory for a given model.
import torch
from transformers import AutoConfig

from cacheflow.models.utils import get_dtype_size

# Module-level constants.
# Number of bytes in one GiB; used to convert between GiB and bytes.
_GiB = 1 << 30


class CacheFlowMemoryAnalyzer:
    """Base class for estimating how many cache blocks fit in memory.

    Subclasses are expected to implement `get_max_num_gpu_blocks` and
    `get_cache_block_size`, and to set `self.cpu_memory` (in bytes), which
    `get_max_num_cpu_blocks` reads.
    """

    def get_max_num_gpu_blocks(
        self,
        max_num_batched_tokens: int,
        memory_utilization: float = 0.95,
    ) -> int:
        """Return the maximum number of cache blocks that fit in GPU memory.

        Must be overridden by subclasses. The default for
        `memory_utilization` matches the one used by the subclasses in this
        file.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()

    def get_workspace_size(self) -> int:
        """Return the workspace size in bytes (fixed at 1 GiB)."""
        return 1 * _GiB

    def get_cache_block_size(self) -> int:
        """Return the size of a single cache block.

        Must be overridden by subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()

    def get_max_num_cpu_blocks(
        self,
        swap_space: int,
    ) -> int:
        """Return how many cache blocks fit in the CPU swap space.

        Args:
            swap_space: Size of the swap space in GiB.

        Returns:
            The swap space in bytes floor-divided by
            `get_cache_block_size()`.

        Raises:
            ValueError: If the swap space exceeds 80% of `self.cpu_memory`.
        """
        # Work in bytes from here on; `self.cpu_memory` is compared in bytes.
        swap_space = swap_space * _GiB
        cpu_memory = self.cpu_memory
        if swap_space > 0.8 * cpu_memory:
            raise ValueError(f'The swap space ({swap_space / _GiB:.2f} GiB) '
                             'takes more than 80% of the available memory '
                             f'({cpu_memory / _GiB:.2f} GiB). '
                             'Please check the swap space size.')
        if swap_space > 0.5 * cpu_memory:
            # Large but still allowed: only warn.
            print(f'WARNING: The swap space ({swap_space / _GiB:.2f} GiB) '
                  'takes more than 50% of the available memory '
                  f'({cpu_memory / _GiB:.2f} GiB). '
                  'This may slow the system performance.')
        max_num_blocks = swap_space // self.get_cache_block_size()
        return max_num_blocks
# Analyzer for OPT models.


class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer):
    """Memory analyzer for OPT models.

    Reads the model shape from the Hugging Face config of `model_name` and
    estimates parameter, activation and cache-block sizes from it.
    Sizes are element counts multiplied by `get_dtype_size(self.dtype)`
    (presumably bytes per element -- confirm against `get_dtype_size`).
    """

    # OPT's learned positional embedding table has this many rows in addition
    # to `max_position_embeddings` (position ids are shifted by this offset).
    _POSITION_EMBEDDING_OFFSET = 2

    def __init__(
        self,
        model_name: str,
        block_size: int,
        dtype: torch.dtype,
        gpu_memory: int,
        cpu_memory: int,
        tensor_parallel_size: int,
    ) -> None:
        """Store the sizing inputs and load the model shape from its config.

        Args:
            model_name: Name or path passed to `AutoConfig.from_pretrained`.
            block_size: Multiplier applied to `hidden_size` when sizing one
                cache block.
            dtype: Data type used to size parameters, activations and cache.
            gpu_memory: GPU memory budget, in the same unit as the computed
                sizes.
            cpu_memory: CPU memory in bytes; read by
                `get_max_num_cpu_blocks`.
            tensor_parallel_size: Divisor applied to the sharded tensors.
        """
        self.model_name = model_name
        self.block_size = block_size
        self.dtype = dtype
        self.gpu_memory = gpu_memory
        self.cpu_memory = cpu_memory
        self.tensor_parallel_size = tensor_parallel_size

        config = AutoConfig.from_pretrained(model_name)
        self.num_layers = config.num_hidden_layers
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_size = config.hidden_size // self.num_heads
        self.ffn_size = config.ffn_dim
        self.embedding_size = config.word_embed_proj_dim
        self.vocab_size = config.vocab_size
        self.max_position = config.max_position_embeddings

    def _get_param_size(self) -> int:
        """Return the estimated size of the model parameters."""
        tp = self.tensor_parallel_size
        hidden = self.hidden_size

        word_embedding = self.vocab_size * self.embedding_size // tp
        if self.embedding_size != hidden:
            # Project in/out.
            word_embedding += 2 * self.embedding_size * hidden
        # The positional table holds `max_position + offset` rows.
        position_embedding = (
            (self.max_position + self._POSITION_EMBEDDING_OFFSET) * hidden)

        ln1 = 2 * hidden
        # q, k, v and out all have the same size: a weight divided across
        # tensor-parallel ranks plus a bias counted in full.
        attn_proj = hidden * hidden // tp + hidden
        mha = ln1 + 4 * attn_proj

        ln2 = 2 * hidden
        ffn1 = hidden * self.ffn_size // tp + self.ffn_size
        ffn2 = self.ffn_size * hidden // tp + hidden
        ffn = ln2 + ffn1 + ffn2

        total = (word_embedding + position_embedding +
                 self.num_layers * (mha + ffn))
        dtype_size = get_dtype_size(self.dtype)
        return dtype_size * total

    def _get_max_act_size(
        self,
        max_num_batched_tokens: int,
    ) -> int:
        """Return the estimated peak activation size.

        Args:
            max_num_batched_tokens: Maximum number of tokens in one batch.
        """
        # NOTE: We approximately calculate the maximum activation size by
        # estimating
        # 1) the maximum activation tensor size during inference
        # 2) the residual tensor size during inference
        # Here, we assume that FlashAttention is used and
        # thus the attention maps are never materialized in GPU DRAM.
        tp = self.tensor_parallel_size
        residual = max_num_batched_tokens * self.hidden_size
        qkv = 3 * (max_num_batched_tokens * self.hidden_size) // tp
        ffn = max_num_batched_tokens * self.ffn_size // tp
        # Double the activation size for input and output.
        max_act = 2 * (max(qkv, ffn) + residual)
        # Size of output logits.
        output_logits = 2 * (max_num_batched_tokens * self.vocab_size)
        max_act = max(max_act, output_logits)
        dtype_size = get_dtype_size(self.dtype)
        return dtype_size * max_act

    def get_cache_block_size(self) -> int:
        """Return the size of one cache block across all layers.

        A block holds a key part and an equally sized value part per layer.
        """
        key_cache_block = (
            self.block_size * self.hidden_size // self.tensor_parallel_size)
        value_cache_block = key_cache_block
        total = self.num_layers * (key_cache_block + value_cache_block)
        dtype_size = get_dtype_size(self.dtype)
        return dtype_size * total

    def get_max_num_gpu_blocks(
        self,
        max_num_batched_tokens: int,
        memory_utilization: float = 0.95,
    ) -> int:
        """Return how many cache blocks fit in the remaining GPU memory.

        Args:
            max_num_batched_tokens: Maximum number of tokens in one batch.
            memory_utilization: Fraction of `gpu_memory` that may be used.

        Returns:
            The memory left after parameters, activations and workspace,
            floor-divided by the cache block size.

        Raises:
            RuntimeError: If nothing is left for the cache.
        """
        # NOTE(woosuk): This assumes that the machine has homogeneous GPUs.
        usable_memory = int(memory_utilization * self.gpu_memory)

        param_size = self._get_param_size()
        act_size = self._get_max_act_size(max_num_batched_tokens)
        workspace_size = self.get_workspace_size()

        max_cache_size = usable_memory - (
            param_size + act_size + workspace_size)
        if max_cache_size <= 0:
            raise RuntimeError('Not enough GPU memory.')
        max_num_blocks = max_cache_size // self.get_cache_block_size()
        return max_num_blocks

# Analyzer for LLaMA models.

class LlamaMemoryAnalyzer(CacheFlowMemoryAnalyzer):
    """Memory analyzer for LLaMA models.

    Reads the model shape from the Hugging Face config of `model_name` and
    estimates parameter, activation and cache-block sizes from it.
    Sizes are element counts multiplied by `get_dtype_size(self.dtype)`
    (presumably bytes per element -- confirm against `get_dtype_size`).
    """

    def __init__(
        self,
        model_name: str,
        block_size: int,
        dtype: torch.dtype,
        gpu_memory: int,
        cpu_memory: int,
        tensor_parallel_size: int,
    ) -> None:
        """Store the sizing inputs and load the model shape from its config.

        Args:
            model_name: Name or path passed to `AutoConfig.from_pretrained`.
            block_size: Multiplier applied to `hidden_size` when sizing one
                cache block.
            dtype: Data type used to size parameters, activations and cache.
            gpu_memory: GPU memory budget, in the same unit as the computed
                sizes.
            cpu_memory: CPU memory in bytes; read by
                `get_max_num_cpu_blocks`.
            tensor_parallel_size: Divisor applied to the sharded tensors.
        """
        self.model_name = model_name
        self.block_size = block_size
        self.dtype = dtype
        self.gpu_memory = gpu_memory
        self.cpu_memory = cpu_memory
        self.tensor_parallel_size = tensor_parallel_size

        config = AutoConfig.from_pretrained(model_name)
        self.num_layers = config.num_hidden_layers
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_size = config.hidden_size // self.num_heads
        self.ffn_size = config.intermediate_size
        self.vocab_size = config.vocab_size
        # Hard-coded rather than read from the config; only used to size the
        # rotary embedding table.
        self.max_position = 8192

    def _get_param_size(self) -> int:
        """Return the estimated size of the model parameters."""
        tp = self.tensor_parallel_size
        hidden = self.hidden_size

        # NOTE: LLaMA does not tie the input embedding and the output
        # projection, so both matrices are counted. It has no learned
        # position-embedding table; positions are handled by the rotary
        # embedding counted per layer below.
        word_embedding = self.vocab_size * hidden // tp
        lm_head = self.vocab_size * hidden // tp

        # NOTE: LLaMA does not have bias terms.
        ln1 = hidden
        # q, k, v and out all have the same size.
        attn_proj = hidden * hidden // tp
        # Rotary embedding.
        # TODO(woosuk): Share the rotary embedding between layers.
        rot = self.max_position * self.head_size
        mha = ln1 + 4 * attn_proj + rot

        ln2 = hidden
        gate = hidden * self.ffn_size // tp
        down = self.ffn_size * hidden // tp
        up = hidden * self.ffn_size // tp
        ffn = ln2 + gate + down + up

        total = word_embedding + self.num_layers * (mha + ffn) + lm_head
        dtype_size = get_dtype_size(self.dtype)
        return dtype_size * total

    def _get_max_act_size(
        self,
        max_num_batched_tokens: int,
    ) -> int:
        """Return the estimated peak activation size.

        Args:
            max_num_batched_tokens: Maximum number of tokens in one batch.
        """
        # NOTE: We approximately calculate the maximum activation size by
        # estimating
        # 1) the maximum activation tensor size during inference
        # 2) the residual tensor size during inference
        # Here, we assume that FlashAttention is used and
        # thus the attention maps are never materialized in GPU DRAM.
        tp = self.tensor_parallel_size
        residual = max_num_batched_tokens * self.hidden_size
        qkv = 3 * (max_num_batched_tokens * self.hidden_size) // tp
        # The factor of 2 presumably covers the gate and up projection
        # outputs -- confirm against the model implementation.
        ffn = 2 * (max_num_batched_tokens * self.ffn_size) // tp
        # Double the activation size for input and output.
        max_act = 2 * (max(qkv, ffn) + residual)
        # Size of output logits.
        output_logits = 2 * (max_num_batched_tokens * self.vocab_size)
        max_act = max(max_act, output_logits)
        dtype_size = get_dtype_size(self.dtype)
        return dtype_size * max_act

    def get_cache_block_size(self) -> int:
        """Return the size of one cache block across all layers.

        A block holds a key part and an equally sized value part per layer.
        """
        key_cache_block = (
            self.block_size * self.hidden_size // self.tensor_parallel_size)
        value_cache_block = key_cache_block
        total = self.num_layers * (key_cache_block + value_cache_block)
        dtype_size = get_dtype_size(self.dtype)
        return dtype_size * total

    def get_max_num_gpu_blocks(
        self,
        max_num_batched_tokens: int,
        memory_utilization: float = 0.95,
    ) -> int:
        """Return how many cache blocks fit in the remaining GPU memory.

        Args:
            max_num_batched_tokens: Maximum number of tokens in one batch.
            memory_utilization: Fraction of `gpu_memory` that may be used.

        Returns:
            The memory left after parameters, activations and workspace,
            floor-divided by the cache block size.

        Raises:
            RuntimeError: If nothing is left for the cache.
        """
        # NOTE(woosuk): This assumes that the machine has homogeneous GPUs.
        usable_memory = int(memory_utilization * self.gpu_memory)

        param_size = self._get_param_size()
        act_size = self._get_max_act_size(max_num_batched_tokens)
        workspace_size = self.get_workspace_size()

        max_cache_size = usable_memory - (
            param_size + act_size + workspace_size)
        if max_cache_size <= 0:
            raise RuntimeError('Not enough GPU memory.')
        max_num_blocks = max_cache_size // self.get_cache_block_size()
        return max_num_blocks