# theoretical_memory_usage.py (10 KB) -- from a blame view; lines 1-20 committed by xingjinliang.
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

"""Computes theoretical memory footprint for model training."""


import math

# 2**20 bytes; divisor used to turn the byte counts computed in this module into "MB".
NUM_BYTES_IN_MEGABYTE = 2**20


def _num_parameters_in_mla_block(args, num_experts, gated_linear_multiplier):
    """Return the parameter count of one multi-latent-attention (MLA) transformer layer.

    The layer is modelled as MLA self-attention, an MLP made of ``num_experts`` routed
    experts plus an optional router and optional shared experts, and two layernorms.
    Layernorms are counted as 2 parameters per element, like elsewhere in this file.
    """
    q_head_dim = args.qk_head_dim + args.qk_pos_emb_head_dim
    if args.q_lora_rank is None:
        # No query compression: a single hidden -> (heads * q_head_dim) projection
        # and no q_layernorm.
        query_parameters = args.hidden_size * (args.num_attention_heads * q_head_dim)
    else:
        query_parameters = (
            # q_down
            args.hidden_size * args.q_lora_rank
            # q_up
            + args.q_lora_rank * (args.num_attention_heads * q_head_dim)
            # q_layernorm
            + 2 * args.q_lora_rank
        )
    attention_parameters = (
        query_parameters
        # kv_down
        + args.hidden_size * (args.kv_lora_rank + args.qk_pos_emb_head_dim)
        # kv_up
        + args.kv_lora_rank * (args.num_attention_heads * (args.qk_head_dim + args.v_head_dim))
        # kv_layernorm
        + 2 * args.kv_lora_rank
        # linear_proj: (heads * v_head_dim) -> hidden.
        + (args.v_head_dim * args.num_attention_heads) * args.hidden_size
    )

    # Router weights only exist when the model actually uses MoE.
    router_parameters = 0 if args.num_experts is None else args.hidden_size * args.num_experts
    # Shared experts are optional; an unset size means "no shared experts".
    shared_expert_size = args.moe_shared_expert_intermediate_size or 0

    return (
        attention_parameters
        # Routed experts (a single dense MLP when num_experts is 1).
        # NOTE(review): uses ffn_hidden_size as the per-expert hidden size; confirm this
        # matches the expert FFN size configured for MoE runs.
        + (2 * (args.ffn_hidden_size * args.hidden_size) * num_experts * gated_linear_multiplier)
        # Router.
        + router_parameters
        # Shared experts are gated MLPs too, so they get the same SwiGLU multiplier.
        + (2 * shared_expert_size * args.hidden_size * gated_linear_multiplier)
        # Transformer layernorms.
        + (4 * args.hidden_size)
    )


def _num_parameters_in_mtp_layers(args, num_parameters_in_transformer_block):
    """Return the parameter count of the multi-token-prediction (MTP) layers.

    Includes the MTP embedding/output weights: two copies when they are not shared with
    the main model, one copy when they are shared but pipeline parallelism is used.
    Returns 0 when ``args.num_nextn_predict_layers`` is missing, None or 0.
    """
    num_mtp_layers = getattr(args, "num_nextn_predict_layers", None) or 0
    if not num_mtp_layers:
        return 0

    num_parameters = num_mtp_layers * (
        # Transformer block.
        num_parameters_in_transformer_block
        # MTP layernorms.
        + (6 * args.hidden_size)
        # Linear projection (2 * hidden -> hidden).
        + 2 * args.hidden_size * args.hidden_size
    )

    # Params of the MTP embedding and MTP output layer.
    embedding_or_output = num_mtp_layers * args.hidden_size * args.padded_vocab_size
    if not args.share_mtp_embedding_and_output_weight:
        num_parameters += 2 * embedding_or_output
    elif args.pipeline_model_parallel_size > 1:
        num_parameters += embedding_or_output
    return num_parameters


def compute_weight_and_optimizer_memory(args, verbose=False):
    """Compute the theoretical weight + optimizer-state memory of the most loaded model shard.

    Counts parameters of the transformer layers (standard attention or multi-latent
    attention), the embedding/output layers and any multi-token-prediction (MTP) layers,
    then multiplies the parameter count of the most loaded shard by a per-parameter byte
    cost: 18 bytes, or ``6 + 12 / data_parallel_size`` with the distributed optimizer.

    Args:
        args: Megatron-style argument namespace; the attributes read are those used below.
        verbose: If true, print the intermediate parameter counts.

    Returns:
        Memory in bytes (``report_theoretical_memory`` divides it by
        ``NUM_BYTES_IN_MEGABYTE`` to report MB).

    Side effect:
        When ``args.group_query_attention`` is false, ``args.num_query_groups`` is set to
        ``args.num_attention_heads`` on the passed-in namespace.
    """
    # Group Query Attention.
    if not args.group_query_attention:
        args.num_query_groups = args.num_attention_heads
    # MoE: a dense MLP is treated as a single expert.
    num_experts = 1 if args.num_experts is None else args.num_experts
    gated_linear_multiplier = 3 / 2 if args.swiglu else 1

    if not args.multi_latent_attention:
        # Attention projection size (only meaningful for standard attention).
        query_projection_size = args.kv_channels * args.num_attention_heads
        query_projection_to_hidden_size_ratio = query_projection_size / args.hidden_size
        num_parameters_in_transformer_block = (
            2
            * args.hidden_size
            * args.hidden_size
            * (
                # Attention.
                (
                    (1 + (args.num_query_groups / args.num_attention_heads))
                    * query_projection_to_hidden_size_ratio
                )
                # MLP.
                + ((args.ffn_hidden_size / args.hidden_size) * num_experts * gated_linear_multiplier)
                # Transformer layernorms.
                + (2 / args.hidden_size)
            )
        )
    else:
        num_parameters_in_transformer_block = _num_parameters_in_mla_block(
            args, num_experts, gated_linear_multiplier
        )

    num_parameters_in_transformer_layers = (
        args.num_layers * num_parameters_in_transformer_block
        # Final layernorm.
        + (2 * args.hidden_size)
    )

    embedding_size = args.hidden_size * args.padded_vocab_size
    if args.untie_embeddings_and_output_weights:
        num_parameters_in_embedding_layers = 2 * embedding_size
    else:
        num_parameters_in_embedding_layers = embedding_size

    num_parameters_in_mtp_layers = _num_parameters_in_mtp_layers(
        args, num_parameters_in_transformer_block
    )

    num_total_parameters = (
        num_parameters_in_transformer_layers
        + num_parameters_in_embedding_layers
        + num_parameters_in_mtp_layers
    )
    if verbose:
        print(
            f"Number of parameters in transformer layers in billions: "
            f"{num_parameters_in_transformer_layers / 10**9:.2f}"
        )
        print(
            f"Number of parameters in embedding layers in billions: "
            f"{num_parameters_in_embedding_layers / 10**9:.2f}"
        )
        print(
            f"Number of parameters in mtp layers in billions: "
            f"{num_parameters_in_mtp_layers / 10**9:.2f}"
        )
        print(f"Total number of parameters in billions: {num_total_parameters / 10**9:.2f}")

    # Most loaded model shard has
    # (1/pp_size transformer layers + 1 embedding layer + all MTP layers) / tp_size.
    num_parameters_on_most_loaded_model_shard = (
        (num_parameters_in_transformer_layers / args.pipeline_model_parallel_size)
        + embedding_size
        + num_parameters_in_mtp_layers
    ) / args.tensor_model_parallel_size
    if args.untie_embeddings_and_output_weights and args.pipeline_model_parallel_size == 1:
        # Without pipelining, the untied output layer lives on the same shard.
        num_parameters_on_most_loaded_model_shard += (
            embedding_size / args.tensor_model_parallel_size
        )
    if verbose:
        print(
            f"Number of parameters in most loaded shard in billions: "
            f"{num_parameters_on_most_loaded_model_shard / 10**9:.4f}"
        )

    if args.pipeline_model_parallel_size > 1:
        # Other shards just have (1/pp_size transformer layers) / tp_size.
        num_parameters_on_other_model_shards = num_parameters_in_transformer_layers / (
            args.pipeline_model_parallel_size * args.tensor_model_parallel_size
        )
        if verbose:
            print(
                f"Number of parameters in other shards in billions: "
                f"{num_parameters_on_other_model_shards / 10**9:.4f}"
            )

    # With the distributed optimizer, the 12-byte part is divided across data-parallel ranks.
    num_bytes_per_parameter = (
        18 if not args.use_distributed_optimizer else 6 + (12 / args.data_parallel_size)
    )
    weight_and_optimizer_memory = (
        num_parameters_on_most_loaded_model_shard * num_bytes_per_parameter
    )

    return weight_and_optimizer_memory


def compute_activation_memory(args, num_microbatches, verbose=False):
    """Estimate the activation memory, in bytes, held by one tensor-parallel rank.

    The per-layer term follows Table 2 of https://arxiv.org/pdf/2205.05198.pdf. Since the
    aim is the peak footprint, every term is evaluated for the first pipeline stage.

    Args:
        args: Megatron-style argument namespace.
        num_microbatches: Microbatches per step, or None. Only used to discount the
            non-interleaved pipeline schedule when fewer than pp_size microbatches exist.
        verbose: If true, print intermediate figures.

    Returns:
        Activation bytes divided by the tensor-parallel size (tensor and sequence
        parallelism partition activations across TP ranks).
    """
    # TODO: This function needs to take into account query_projection_size potentially being
    # different from hidden_size.
    pp_size = args.pipeline_model_parallel_size
    vpp_size = args.virtual_pipeline_model_parallel_size
    tokens = args.seq_length * args.micro_batch_size

    # Self-attention + MLP footprint of a single transformer layer.
    per_layer_bytes = tokens * args.hidden_size * (
        18 + 4 * (args.ffn_hidden_size / args.hidden_size)
    )
    if verbose:
        print(
            f"Activation memory footprint per transformer layer: "
            f"{per_layer_bytes / NUM_BYTES_IN_MEGABYTE / args.tensor_model_parallel_size:.1f} MB"
        )
    total_bytes = per_layer_bytes * args.num_layers

    # Embedding input and embedding dropout, each with pp_size microbatches in flight.
    total_bytes += 8 * tokens * pp_size
    total_bytes += tokens * args.hidden_size * pp_size

    if vpp_size is not None:
        # Interleaved schedule keeps extra microbatches alive on the first stage.
        penalty = 1 + (pp_size - 1) / (pp_size * vpp_size)
        in_flight = math.ceil(penalty * pp_size)
        if verbose:
            print(f"Memory penalty from interleaved schedule: {penalty:.2f}")
            print(f"Number of in-flight microbatches: {in_flight}")
        total_bytes *= penalty
    elif pp_size > 1:
        # Non-interleaved schedule: there may be fewer than pp_size microbatches in flight.
        if num_microbatches is None:
            in_flight = pp_size
        else:
            total_bytes *= min(1, num_microbatches / pp_size)
            in_flight = min(num_microbatches, pp_size)
        if verbose:
            print(f"Number of in-flight microbatches: {in_flight}")

    if pp_size == 1:
        # The single stage also holds the output-layer inputs and the CE loss.
        total_bytes += tokens * args.hidden_size * 4 * (
            1 + (args.padded_vocab_size / args.hidden_size)
        )

    return total_bytes / args.tensor_model_parallel_size


def report_theoretical_memory(args, num_microbatches=None, verbose=False):
    """Print the theoretical per-GPU memory footprint in MB.

    Always reports the weight + optimizer memory. The activation estimate (and the total)
    is only reported when sequence parallelism and selective recomputation are both enabled,
    because the activation formulae assume that configuration.

    Args:
        args: Megatron-style argument namespace.
        num_microbatches: Forwarded to ``compute_activation_memory``.
        verbose: Forwarded to the compute functions to print intermediate figures.
    """
    weights_mb = compute_weight_and_optimizer_memory(args, verbose=verbose) / NUM_BYTES_IN_MEGABYTE

    activation_formula_applies = args.sequence_parallel and args.recompute_granularity == 'selective'
    if not activation_formula_applies:
        print(f"Theoretical memory footprints: weight and optimizer={weights_mb:.2f} MB")
        return

    activations_mb = (
        compute_activation_memory(args, num_microbatches=num_microbatches, verbose=verbose)
        / NUM_BYTES_IN_MEGABYTE
    )
    overall_mb = weights_mb + activations_mb
    print(
        f"Theoretical memory footprints: weight and optimizer={weights_mb:.2f} MB, "
        f"activation={activations_mb:.2f} MB, total={overall_mb:.2f} MB\n"
    )