memory_analyzer.py 15.1 KB
Newer Older
1
2
3
import torch
from transformers import AutoConfig

Woosuk Kwon's avatar
Woosuk Kwon committed
4
from cacheflow.logger import init_logger
5
6
from cacheflow.models.utils import get_dtype_size

Woosuk Kwon's avatar
Woosuk Kwon committed
7
8
9

# Module-level logger; get_max_num_cpu_blocks uses it to warn when the
# requested swap space is a large fraction of CPU memory.
logger = init_logger(__name__)

Woosuk Kwon's avatar
Woosuk Kwon committed
10
# Number of bytes in one gibibyte (2**30); converts GiB quantities to bytes.
_GiB = 1024 ** 3
11
12
13
14
15
16
17
18
19
20
21


class CacheFlowMemoryAnalyzer:
    """Base class that estimates how many KV-cache blocks fit in GPU/CPU memory.

    Subclasses must set ``block_size``, ``hidden_size``, ``num_layers``,
    ``tensor_parallel_size``, ``dtype``, ``gpu_memory`` and ``cpu_memory`` on
    the instance, and implement ``get_param_size`` and ``get_max_act_size``.
    """

    def get_param_size(self) -> int:
        """Return the size in bytes of the model weights held by one GPU."""
        raise NotImplementedError()

    def get_max_act_size(self, max_num_batched_tokens: int) -> int:
        """Return the peak activation size in bytes for a batch of tokens."""
        raise NotImplementedError()

    def get_workspace_size(self) -> int:
        """Return extra GPU memory (bytes) reserved beyond weights/activations."""
        return 1 * _GiB

    def get_cache_block_size(self) -> int:
        """Return the size in bytes of one KV-cache block across all layers.

        A block holds keys and values for ``block_size`` tokens in every layer;
        the hidden dimension is split across tensor-parallel ranks.
        """
        key_cache_block = self.block_size * self.hidden_size // self.tensor_parallel_size
        value_cache_block = key_cache_block
        total = self.num_layers * (key_cache_block + value_cache_block)
        dtype_size = get_dtype_size(self.dtype)
        return dtype_size * total

    def get_max_num_cpu_blocks(
        self,
        swap_space_gib: float,
    ) -> int:
        """Return how many KV-cache blocks fit in the CPU swap space.

        Args:
            swap_space_gib: CPU swap space to reserve, in GiB.

        Returns:
            The number of whole cache blocks that fit in the swap space.

        Raises:
            ValueError: If the swap space exceeds 80% of ``self.cpu_memory``.
        """
        # Cast to int so fractional GiB values still yield an integer count.
        swap_space = int(swap_space_gib * _GiB)
        cpu_memory = self.cpu_memory
        if swap_space > 0.8 * cpu_memory:
            raise ValueError(f'The swap space ({swap_space_gib:.2f} GiB) '
                             'takes more than 80% of the available memory '
                             f'({cpu_memory / _GiB:.2f} GiB). '
                             'Please check the swap space size.')
        if swap_space > 0.5 * cpu_memory:
            logger.warning(f'The swap space ({swap_space_gib:.2f} GiB) '
                           'takes more than 50% of the available memory '
                           f'({cpu_memory / _GiB:.2f} GiB). '
                           'This may slow the system performance.')
        max_num_blocks = swap_space // self.get_cache_block_size()
        return max_num_blocks

    def get_max_num_gpu_blocks(
        self,
        max_num_batched_tokens: int,
        memory_utilization: float = 0.95,
    ) -> int:
        """Return how many KV-cache blocks fit in the GPU memory budget.

        The budget is ``memory_utilization * self.gpu_memory`` minus the
        parameter, peak-activation and workspace sizes.

        Args:
            max_num_batched_tokens: Maximum number of tokens in one batch; used
                to size the activation estimate.
            memory_utilization: Fraction of GPU memory that may be used.

        Raises:
            RuntimeError: If nothing is left for the KV cache.
        """
        # NOTE(woosuk): This assumes that the machine has homogeneous GPUs.
        usable_memory = int(memory_utilization * self.gpu_memory)

        param_size = self.get_param_size()
        act_size = self.get_max_act_size(max_num_batched_tokens)
        workspace_size = self.get_workspace_size()

        max_cache_size = usable_memory - (param_size + act_size + workspace_size)
        if max_cache_size <= 0:
            raise RuntimeError('Not enough GPU memory.')
        max_num_blocks = max_cache_size // self.get_cache_block_size()
        return max_num_blocks

78

Woosuk Kwon's avatar
Woosuk Kwon committed
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
class GPT2MemoryAnalyzer(CacheFlowMemoryAnalyzer):
    """Estimates parameter, activation and KV-cache memory for GPT-2 models."""

    def __init__(
        self,
        model_name: str,
        block_size: int,
        dtype: torch.dtype,
        gpu_memory: int,
        cpu_memory: int,
        tensor_parallel_size: int,
    ) -> None:
        self.model_name = model_name
        self.block_size = block_size
        self.dtype = dtype
        self.gpu_memory = gpu_memory
        self.cpu_memory = cpu_memory
        self.tensor_parallel_size = tensor_parallel_size

        cfg = AutoConfig.from_pretrained(model_name)
        hidden = cfg.hidden_size
        self.num_layers = cfg.num_hidden_layers
        self.hidden_size = hidden
        self.num_heads = cfg.num_attention_heads
        self.head_size = hidden // self.num_heads
        # When the config leaves n_inner unset, fall back to a 4x-wide FFN.
        self.ffn_size = 4 * hidden if cfg.n_inner is None else cfg.n_inner
        self.vocab_size = cfg.vocab_size
        self.max_position = cfg.max_position_embeddings

    def get_param_size(self) -> int:
        """Return the bytes of model weights held by one tensor-parallel rank.

        No separate LM-head term is added; the output projection is assumed to
        share the word-embedding weights.
        """
        h = self.hidden_size
        f = self.ffn_size
        tp = self.tensor_parallel_size

        embeddings = self.vocab_size * h // tp + self.max_position * h

        # q, k, v and output projections: sharded weight plus a bias vector.
        attn_proj = h * h // tp + h
        attention = 2 * h + 4 * attn_proj

        ffn_in = h * f // tp + f
        ffn_out = f * h // tp + h
        mlp = 2 * h + ffn_in + ffn_out

        num_elements = embeddings + self.num_layers * (attention + mlp)
        return get_dtype_size(self.dtype) * num_elements

    def get_max_act_size(
        self,
        max_num_batched_tokens: int,
    ) -> int:
        """Return an approximate peak activation size in bytes.

        The estimate takes the largest per-layer activation (QKV or FFN output)
        plus the residual stream, doubled for input and output, and compares it
        against the output-logits buffer. FlashAttention is assumed, so
        attention maps are never materialized in GPU DRAM.
        """
        tokens = max_num_batched_tokens
        tp = self.tensor_parallel_size
        residual = tokens * self.hidden_size
        qkv = 3 * (tokens * self.hidden_size) // tp
        ffn = tokens * self.ffn_size // tp
        # Input and output of the widest layer are alive at the same time.
        peak_layer = 2 * (max(qkv, ffn) + residual)
        logits = 2 * (tokens * self.vocab_size)
        return get_dtype_size(self.dtype) * max(peak_layer, logits)

149
150
151
152
153
154
155
class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer):
    """Estimates parameter, activation and KV-cache memory for OPT models."""

    # OPT's learned positional-embedding table has this many rows beyond
    # ``max_position_embeddings`` (position ids are shifted by 2, as in the
    # Hugging Face ``OPTLearnedPositionalEmbedding``).
    _POSITION_OFFSET = 2

    def __init__(
        self,
        model_name: str,
        block_size: int,
        dtype: torch.dtype,
        gpu_memory: int,
        cpu_memory: int,
        tensor_parallel_size: int,
    ) -> None:
        self.model_name = model_name
        self.block_size = block_size
        self.dtype = dtype
        self.gpu_memory = gpu_memory
        self.cpu_memory = cpu_memory
        self.tensor_parallel_size = tensor_parallel_size

        config = AutoConfig.from_pretrained(model_name)
        self.num_layers = config.num_hidden_layers
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_size = config.hidden_size // self.num_heads
        self.ffn_size = config.ffn_dim
        # Width of the token embedding; when it differs from hidden_size,
        # project-in/out matrices are counted in get_param_size.
        self.embedding_size = config.word_embed_proj_dim
        self.vocab_size = config.vocab_size
        self.max_position = config.max_position_embeddings

    def get_param_size(self) -> int:
        """Return the bytes of model weights held by one tensor-parallel rank.

        No separate LM-head term is added; the output projection is assumed to
        share the word-embedding weights.
        """
        word_embedding = self.vocab_size * self.embedding_size // self.tensor_parallel_size
        if self.embedding_size != self.hidden_size:
            # Project in/out, counted without bias and without TP sharding.
            word_embedding += 2 * self.embedding_size * self.hidden_size
        # Include the extra offset rows of OPT's learned position table.
        position_embedding = (self.max_position + self._POSITION_OFFSET) * self.hidden_size

        ln1 = 2 * self.hidden_size
        q = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
        k = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
        v = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
        out = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
        mha = ln1 + q + k + v + out

        ln2 = 2 * self.hidden_size
        ffn1 = self.hidden_size * self.ffn_size // self.tensor_parallel_size + self.ffn_size
        ffn2 = self.ffn_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
        ffn = ln2 + ffn1 + ffn2

        total = (word_embedding + position_embedding +
                 self.num_layers * (mha + ffn))
        dtype_size = get_dtype_size(self.dtype)
        return dtype_size * total

    def get_max_act_size(
        self,
        max_num_batched_tokens: int,
    ) -> int:
        """Return an approximate peak activation size in bytes."""
        # NOTE: We approximately calculate the maximum activation size by
        # estimating
        # 1) the maximum activation tensor size during inference
        # 2) the residual tensor size during inference
        # Here, we assume that we use memory-efficient attention which
        # does not materialize the attention maps in GPU DRAM.
        residual = max_num_batched_tokens * self.hidden_size
        qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size
        ffn = max_num_batched_tokens * self.ffn_size // self.tensor_parallel_size
        # Double the activation size for input and output.
        max_act = 2 * (max(qkv, ffn) + residual)
        # Size of output logits.
        output_logits = 2 * (max_num_batched_tokens * self.vocab_size)
        max_act = max(max_act, output_logits)
        dtype_size = get_dtype_size(self.dtype)
        return dtype_size * max_act

Woosuk Kwon's avatar
Woosuk Kwon committed
222
223
224
225

class LlamaMemoryAnalyzer(CacheFlowMemoryAnalyzer):
    """Estimates memory for LLaMA models.

    These models have bias-free projections, untied embeddings and a
    gate/up/down MLP.
    """

    def __init__(
        self,
        model_name: str,
        block_size: int,
        dtype: torch.dtype,
        gpu_memory: int,
        cpu_memory: int,
        tensor_parallel_size: int,
    ) -> None:
        self.model_name = model_name
        self.block_size = block_size
        self.dtype = dtype
        self.gpu_memory = gpu_memory
        self.cpu_memory = cpu_memory
        self.tensor_parallel_size = tensor_parallel_size

        hf_config = AutoConfig.from_pretrained(model_name)
        self.num_layers = hf_config.num_hidden_layers
        self.hidden_size = hf_config.hidden_size
        self.num_heads = hf_config.num_attention_heads
        self.head_size = self.hidden_size // self.num_heads
        self.ffn_size = hf_config.intermediate_size
        self.vocab_size = hf_config.vocab_size
        # Fixed rotary-table length; deliberately not read from the config.
        self.max_position = 8192

    def get_param_size(self) -> int:
        """Return the bytes of model weights held by one tensor-parallel rank."""
        h = self.hidden_size
        tp = self.tensor_parallel_size
        square_proj = h * h // tp
        ffn_proj = h * self.ffn_size // tp

        # NOTE: LLaMA does not tie the two embeddings, so both are counted.
        embed_tokens = self.vocab_size * h // tp
        lm_head = self.vocab_size * h // tp

        # NOTE: LLaMA does not have bias terms; each norm adds h elements.
        # Per-layer rotary embedding table.
        # TODO(woosuk): Share the rotary embedding between layers.
        rotary = self.max_position * self.head_size
        attention = h + 4 * square_proj + rotary  # norm + q/k/v/out + rotary
        mlp = h + 3 * ffn_proj                     # norm + gate/up/down

        num_elements = embed_tokens + lm_head + self.num_layers * (attention + mlp)
        return get_dtype_size(self.dtype) * num_elements

    def get_max_act_size(
        self,
        max_num_batched_tokens: int,
    ) -> int:
        """Return an approximate peak activation size in bytes.

        The estimate takes the widest per-layer activation plus the residual
        stream, doubled for input and output, and compares it against the
        output logits. Memory-efficient attention is assumed, so attention
        maps are not materialized in GPU DRAM.
        """
        n = max_num_batched_tokens
        tp = self.tensor_parallel_size
        residual = n * self.hidden_size
        qkv = 3 * (n * self.hidden_size) // tp
        # Factor 2: presumably the gate and up projection outputs together.
        ffn = 2 * (n * self.ffn_size) // tp
        layer_peak = 2 * (max(qkv, ffn) + residual)
        logits = 2 * (n * self.vocab_size)
        return get_dtype_size(self.dtype) * max(layer_peak, logits)

297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348

class GPTNeoXMemoryAnalyzer(CacheFlowMemoryAnalyzer):
    """Estimates parameter, activation and KV-cache memory for GPT-NeoX models."""

    def __init__(
        self,
        model_name: str,
        block_size: int,
        dtype: torch.dtype,
        gpu_memory: int,
        cpu_memory: int,
        tensor_parallel_size: int,
    ) -> None:
        self.model_name = model_name
        self.block_size = block_size
        self.dtype = dtype
        self.gpu_memory = gpu_memory
        self.cpu_memory = cpu_memory
        self.tensor_parallel_size = tensor_parallel_size

        hf_config = AutoConfig.from_pretrained(model_name)
        self.num_layers = hf_config.num_hidden_layers
        self.hidden_size = hf_config.hidden_size
        self.num_heads = hf_config.num_attention_heads
        self.head_size = self.hidden_size // self.num_heads
        self.ffn_size = hf_config.intermediate_size
        self.vocab_size = hf_config.vocab_size
        # Fixed rotary-table length; deliberately not read from the config.
        self.max_position = 8192
        # Whether the LM head reuses the word-embedding weights.
        self.tie_word_embeddings = hf_config.tie_word_embeddings

    def get_param_size(self) -> int:
        """Return the bytes of model weights held by one tensor-parallel rank."""
        h = self.hidden_size
        f = self.ffn_size
        tp = self.tensor_parallel_size

        word_embedding = self.vocab_size * h // tp
        # A tied LM head adds no weights of its own.
        lm_head = 0 if self.tie_word_embeddings else self.vocab_size * h // tp

        # q, k, v and output projections: sharded weight plus a bias vector.
        proj = h * h // tp + h
        # Per-layer rotary embedding table.
        # TODO(woosuk): Share the rotary embedding between layers.
        rotary = self.max_position * self.head_size
        attention = 2 * h + 4 * proj + rotary

        mlp = 2 * h + (h * f // tp + f) + (f * h // tp + h)

        num_elements = word_embedding + lm_head + self.num_layers * (attention + mlp)
        return get_dtype_size(self.dtype) * num_elements

    def get_max_act_size(
        self,
        max_num_batched_tokens: int,
    ) -> int:
        """Return an approximate peak activation size in bytes.

        The estimate takes the widest per-layer activation plus the residual
        stream, doubled for input and output, and compares it against the
        output logits. Memory-efficient attention is assumed, so attention
        maps are not materialized in GPU DRAM.
        """
        n = max_num_batched_tokens
        tp = self.tensor_parallel_size
        residual = n * self.hidden_size
        qkv = 3 * (n * self.hidden_size) // tp
        ffn = 2 * (n * self.ffn_size) // tp
        layer_peak = 2 * (max(qkv, ffn) + residual)
        logits = 2 * (n * self.vocab_size)
        return get_dtype_size(self.dtype) * max(layer_peak, logits)