fsdp.py 10.9 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
6
#
# See LICENSE for license information.

import os
import argparse
7

8
9
10
11
12
13
14
from functools import partial

import torch
import torch.distributed as dist
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel, MixedPrecision
from torch.distributed.fsdp.wrap import always_wrap_policy, transformer_auto_wrap_policy
15
16
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
17
    checkpoint_wrapper,
18
)
19
20
21

import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling
22
23
24
25
26
27
28
29
30
31
from transformer_engine.pytorch.distributed import prepare_te_modules_for_fsdp

# Launch information exported by torchrun; the defaults allow a plain single-process run.
LOCAL_RANK = int(os.environ.get("LOCAL_RANK", "0"))
WORLD_SIZE = int(os.environ.get("WORLD_SIZE", "1"))

# Fixed seed shared by the default CPU/CUDA generators and by the named CUDA RNG state
# registered below for activation checkpointing.
rng_seed = 1234
torch.manual_seed(rng_seed)
torch.cuda.manual_seed(rng_seed)

# Tracker holding the named "model-parallel-rng" CUDA RNG state. It is handed to the TE
# layers and to the checkpoint wrapper through get_cuda_rng_tracker().
CUDA_RNG_STATES_TRACKER = te.distributed.CudaRNGStatesTracker()
CUDA_RNG_STATES_TRACKER.add("model-parallel-rng", rng_seed)


35
36
37
def get_cuda_rng_tracker():
    """Return the module-wide CUDA RNG states tracker.

    Passed as ``get_rng_state_tracker`` both to the TE layer constructors and to the
    activation-checkpointing wrapper.
    """
    tracker = CUDA_RNG_STATES_TRACKER
    return tracker

38

39
40
41
42
def apply_fsdp_checkpointing(model, blocks):
    """Apply activation checkpointing to the selected submodules of ``model`` in place.

    Every submodule that is an instance of ``blocks`` is wrapped with
    ``checkpoint_wrapper`` using ``te.distributed.checkpoint`` as the checkpoint
    function. Its activations are then recomputed during the backward pass instead
    of being saved during the forward pass.

    Args:
        model: The (FSDP-wrapped) model whose submodules are wrapped.
        blocks: A class, or tuple of classes as accepted by ``isinstance``. It selects
            which submodules to checkpoint.

    Returns:
        None. ``model`` is modified directly.
    """
    # Keyword arguments other than ``checkpoint_fn`` are passed through by
    # checkpoint_wrapper to the checkpoint function. The shared RNG tracker is handed
    # to TE's checkpoint so that the recomputed forward pass can use the same tracked
    # CUDA RNG state.
    wrapper = partial(
        checkpoint_wrapper,
        checkpoint_fn=te.distributed.checkpoint,
        use_reentrant=False,
        get_rng_state_tracker=get_cuda_rng_tracker,
    )

    def check_fn(submodule):
        return isinstance(submodule, blocks)

    apply_activation_checkpointing(model, checkpoint_wrapper_fn=wrapper, check_fn=check_fn)
51

52

53
54
55
def lowercase(s):
    """Return ``str(s)`` in lower case; used for case-insensitive CLI name lookups."""
    text = str(s)
    return text.lower()

56

57
58
def torch_dtype(d):
    """Convert a command-line dtype name into a ``torch.dtype``.

    Used as the argparse ``type=`` converter for ``--dtype``. The accepted names are
    case-insensitive: ``fp32``/``float32``, ``fp16``/``float16`` and
    ``bf16``/``bfloat16``. A ``torch.dtype`` that is already one of these supported
    types is returned unchanged.

    Raises:
        TypeError: if ``d`` does not name a supported dtype. argparse reports this as
            an invalid-value usage error.
    """
    typemap = {
        "fp32": torch.float32,
        "float32": torch.float32,
        "fp16": torch.float16,
        "float16": torch.float16,
        "bf16": torch.bfloat16,
        "bfloat16": torch.bfloat16,
    }
    # Allow already-converted values (e.g. a programmatic default) to pass through.
    if isinstance(d, torch.dtype) and d in typemap.values():
        return d
    try:
        # Same normalization as lowercase(d): str() then lower().
        return typemap[str(d).lower()]
    except KeyError:
        raise TypeError(
            f"Unsupported dtype {d!r}; expected one of {sorted(typemap)}"
        ) from None

70

71
# Lower-case command-line layer names mapped to TE module classes. te_layer() uses
# this for conversion, and its values are the allowed --layer-type choices.
te_layer_map = dict(
    linear=te.Linear,
    layernorm=te.LayerNorm,
    rmsnorm=te.RMSNorm,
    layernormlinear=te.LayerNormLinear,
    layernormmlp=te.LayerNormMLP,
    multiheadattention=te.MultiheadAttention,
    transformerlayer=te.TransformerLayer,
)
80
81


82
def te_layer(l):
    """Convert a case-insensitive layer name into the matching TE module class.

    Used as the argparse ``type=`` converter for ``--layer-type`` and
    ``--checkpoint-layer``. Valid names are the keys of ``te_layer_map``.

    Args:
        l: A layer name such as ``"TransformerLayer"`` or ``"linear"``, or None.

    Returns:
        The TE module class, or None when ``l`` is None.

    Raises:
        TypeError: if ``l`` is not a known layer name. argparse reports this as an
            invalid-value usage error.
    """
    if l is None:
        return None
    try:
        return te_layer_map[lowercase(l)]
    except KeyError:
        raise TypeError(
            f"Unsupported TE layer type {l!r}; expected one of {list(te_layer_map)}"
        ) from None

89

90
91
def get_layer_args(opts):
    """Build the constructor arguments for one ``opts.layer_type`` layer.

    The hidden size is ``opts.num_heads * opts.head_dim``. Every layer gets
    ``params_dtype``, ``device`` and ``get_rng_state_tracker``; the remaining
    arguments depend on the layer type.

    Returns:
        A ``(layer_args, layer_kwargs)`` pair: a tuple of positional arguments and a
        dict of keyword arguments.
    """
    hidden_size = opts.num_heads * opts.head_dim
    layer_type = opts.layer_type
    kwargs = {
        "params_dtype": opts.dtype,
        # Parameters live on the meta device (deferred init) unless --no-defer-init.
        "device": "meta" if not opts.no_defer_init else "cuda",
        "get_rng_state_tracker": get_cuda_rng_tracker,
    }

    if layer_type in (te.Linear, te.LayerNormLinear, te.LayerNormMLP):
        # Stacked layers must map hidden_size -> hidden_size so they can be chained.
        # A lone layer uses a 3x wider output/FFN dimension.
        out_size = hidden_size if opts.num_layers != 1 else 3 * hidden_size
        kwargs["bias"] = True
        if layer_type == te.LayerNormMLP:
            kwargs["seq_length"] = opts.seq_length
        return (hidden_size, out_size), kwargs

    if layer_type == te.MultiheadAttention:
        kwargs.update(fuse_qkv_params=True, input_layernorm=True)
        return (hidden_size, opts.num_heads), kwargs

    if layer_type == te.TransformerLayer:
        kwargs.update(fuse_qkv_params=True, seq_length=opts.seq_length)
        return (hidden_size, 3 * hidden_size, opts.num_heads), kwargs

    # Any other type (e.g. te.LayerNorm / te.RMSNorm) only needs the hidden size.
    return (hidden_size,), kwargs

114

115
def parse_fsdp_args():
    """Parse the command-line options of this FSDP example.

    Returns:
        An ``argparse.Namespace``. Its ``layer_type`` and ``checkpoint_layer``
        attributes hold TE module classes (converted by ``te_layer``). Its ``dtype``
        attribute holds a ``torch.dtype`` (converted by ``torch_dtype``).
    """
    # Show layer names, not class reprs, in the usage/help output.
    layer_names = "{" + ",".join(te_layer_map) + "}"

    parser = argparse.ArgumentParser(
        description="Run Transformer Engine modules with the "
        + "torch.distributed.fsdp.FullyShardedDataParallel strategy."
    )
    parser.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        default=False,
        help="Print out information from all GPUs instead of only the root GPU-0.",
    )
    parser.add_argument("-b", "--batch-size", type=int, default=32, help="Input batch size.")
    # NOTE(review): default is 1048 (not 1024); confirm this is intentional.
    parser.add_argument("-s", "--seq-length", type=int, default=1048, help="Input sequence length.")
    parser.add_argument(
        "-n", "--num-heads", type=int, default=16, help="Number of attention heads."
    )
    parser.add_argument(
        "-d",
        "--head-dim",
        type=int,
        default=128,
        help="Dimension of each attention head (number of KV channels).",
    )
    parser.add_argument(
        "-i", "--num-iters", type=int, default=5, help="Number of dummy 'training' iterations."
    )
    parser.add_argument(
        "-k",
        "--num-layers",
        type=int,
        default=3,
        help="Number of modules chained together with nn.Sequential.",
    )
    parser.add_argument(
        "--layer-type",
        type=te_layer,
        default=te.TransformerLayer,
        choices=list(te_layer_map.values()),
        metavar=layer_names,
        help="TE module type used to construct the test model.",
    )
    parser.add_argument("--seed", type=int, default=1234, help="PyTorch RNG seed.")
    parser.add_argument(
        "--profile-memory",
        action="store_true",
        help="Enable memory profiling via torch.profiler.profile().",
    )
    parser.add_argument(
        "--profile-name",
        type=str,
        default=None,
        help="Name used in the memory snapshot file (gpu<LOCAL_RANK>_<name>.pickle).",
    )
    parser.add_argument(
        "--checkpoint-layer",
        type=te_layer,
        default=None,
        metavar=layer_names,
        help="Recompute activations of the selected layer during the backward "
        + "pass instead of saving.",
    )
    parser.add_argument(
        "--no-fp8",
        action="store_true",
        default=False,
        help="Disables the te.fp8_autocast() context.",
    )
    parser.add_argument(
        "--no-defer-init",
        action="store_true",
        help="Initialize module parameters directly on the GPU instead of deferring "
        + "initialization (meta device) until after FSDP sharding.",
    )
    parser.add_argument(
        "--no-te-fsdp",
        action="store_true",
        help="Disable sharding of intermediate/activation tensors in TE modules "
        + "(also skipped when --checkpoint-layer is given).",
    )
    parser.add_argument(
        "--dtype",
        type=torch_dtype,
        default=torch.bfloat16,
        help="Data type for input tensor and Transformer Engine module parameters.",
    )

    return parser.parse_args()

196

197
198
def dist_print(text, all_ranks=False, no_new_line=False):
    """Print ``text`` prefixed with this process's GPU index.

    Only local rank 0 prints unless ``all_ranks`` is True. ``no_new_line`` suppresses
    the trailing newline.
    """
    if not (all_ranks or LOCAL_RANK == 0):
        return
    prefix = f"[GPU-{LOCAL_RANK}] "
    print(prefix + text, end="" if no_new_line else "\n")
201

202

203
def _build_te_model(opts):
    """Construct the unsharded, homogeneous TE model (only one layer type) with NO PARALLELISM.

    With ``opts.num_layers > 1`` the layers are chained in an ``nn.Sequential``.
    Otherwise a single ``opts.layer_type`` instance is returned. Attention-based
    layers get a 1-based ``layer_number``.
    """
    layer_args, layer_kwargs = get_layer_args(opts)
    if opts.num_layers <= 1:
        # Single layer model
        return opts.layer_type(*layer_args, **layer_kwargs)

    numbered = opts.layer_type in (te.MultiheadAttention, te.TransformerLayer)
    layers = []
    for i in range(opts.num_layers):
        kwargs = {**layer_kwargs, "layer_number": i + 1} if numbered else layer_kwargs
        layers.append(opts.layer_type(*layer_args, **kwargs))
    return nn.Sequential(*layers)


def _shard_te_model(te_model, opts, process_group):
    """Wrap ``te_model`` in FSDP over ``process_group`` and return the wrapped model.

    If ``opts.checkpoint_layer`` is set, activation checkpointing is applied to that
    layer type. Otherwise, unless ``opts.no_te_fsdp`` is set, the TE modules are
    prepared so that FSDP also shards their internal buffers.
    """
    # NOTE: The TE model itself has no inherent parallelism. FSDP shards model parameters and
    #       controls all communication.
    wrap_policy = always_wrap_policy
    if opts.layer_type == te.TransformerLayer:
        # NOTE: FSDP causes illegal memory access without this special policy for Transformers
        wrap_policy = partial(
            transformer_auto_wrap_policy, transformer_layer_cls={te.TransformerLayer}
        )
    te_model = FullyShardedDataParallel(
        te_model,
        process_group=process_group,
        use_orig_params=True,
        mixed_precision=MixedPrecision(
            param_dtype=opts.dtype,
            reduce_dtype=torch.float32,
        ),
        auto_wrap_policy=wrap_policy,
    )

    if opts.checkpoint_layer is not None:
        # Recompute the activations of the selected layer during the backward pass instead of
        # saving them during the forward pass
        apply_fsdp_checkpointing(te_model, blocks=opts.checkpoint_layer)
    elif not opts.no_te_fsdp:
        # Prepare TE modules to shard internal buffers that FSDP cannot shard on its own
        prepare_te_modules_for_fsdp(te_model)
    return te_model


def _run_training(opts):
    """Build, shard and train the model. The default process group must already exist."""
    verbose = opts.verbose  # --verbose: print from every rank, not only GPU-0
    dist_print(f"WORLD_SIZE = {WORLD_SIZE}", all_ranks=verbose)
    torch.manual_seed(opts.seed)

    te_model = _build_te_model(opts)

    # Allocated device memory before FSDP shards the parameters (1 MiB = 2**20 bytes)
    device = f"cuda:{LOCAL_RANK}"
    pre_mem_use = torch.cuda.memory_allocated(device=device) / 2**20
    dist_print(f"Pre-FSDP memory use = {pre_mem_use}MiB", all_ranks=verbose)

    all_gpus = dist.new_group(backend="nccl")
    te_model = _shard_te_model(te_model, opts, all_gpus)

    # Allocated device memory after the model parameters are sharded
    post_mem_use = torch.cuda.memory_allocated(device=device) / 2**20
    dist_print(f"Post-FSDP memory use = {post_mem_use}MiB", all_ranks=verbose)
    label = (
        "FSDP-Wrapped + Checkpointed" if opts.checkpoint_layer is not None else "FSDP-Wrapped"
    )
    dist_print(f"{label} TE Model:\n{te_model}", all_ranks=verbose)

    # Fp8 setup for TE
    fp8_recipe = DelayedScaling(
        fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max"
    )

    # Optimizer must be created after the model is wrapped in FSDP and the parameters are sharded
    optim = torch.optim.Adam(te_model.parameters(), lr=0.0001)

    # Profile memory use, or time the loop and track peak memory
    if opts.profile_memory:
        torch.cuda.memory._record_memory_history(max_entries=100000)
    else:
        torch.cuda.reset_peak_memory_stats()
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        torch.cuda.synchronize()
        start_event.record()

    hidden_size = opts.num_heads * opts.head_dim
    for _ in range(opts.num_iters):
        # Generate a random input batch
        x = torch.rand(
            opts.seq_length,
            opts.batch_size,
            hidden_size,
            dtype=opts.dtype,
            device="cuda",
        )
        # fp8_autocast needs to be given the FSDP process group for amax reductions
        with te.fp8_autocast(enabled=not opts.no_fp8, fp8_recipe=fp8_recipe, fp8_group=all_gpus):
            y = te_model(x)
            loss = y.sum()
        # calculate gradient and take training step outside the fp8_autocast context
        loss.backward()
        optim.step()
        optim.zero_grad(set_to_none=True)
        del x

    if opts.profile_memory:
        torch.cuda.memory._dump_snapshot(f"gpu{LOCAL_RANK}_{opts.profile_name}.pickle")
        torch.cuda.memory._record_memory_history(enabled=None)
        return

    end_event.record()
    torch.cuda.synchronize()
    peak_mem = torch.cuda.max_memory_allocated() / 2**20
    train_time = start_event.elapsed_time(end_event) / 1000.0  # elapsed_time() is in ms
    dist_print(f"Training Time: {train_time}s", all_ranks=verbose)
    if opts.num_iters > 0:
        dist_print(f"Avg. Iter. Time: {train_time / opts.num_iters}s", all_ranks=verbose)
    dist_print(f"Peak Memory Use: {peak_mem}MiB", all_ranks=verbose)


def train(opts):
    """Run the FSDP + Transformer Engine example for the parsed CLI options ``opts``.

    Initializes the NCCL process group and builds and shards the model. It then runs
    ``opts.num_iters`` dummy training steps and reports memory use and timing (or
    dumps a memory snapshot with ``--profile-memory``). The process group is
    destroyed on exit, including when an error is raised.
    """
    # Bind this process to its own GPU before creating NCCL communicators.
    torch.cuda.set_device(LOCAL_RANK)
    # Initialize torch.distributed global process group
    dist.init_process_group(backend="nccl")
    try:
        _run_training(opts)
    finally:
        # Release the communicators; PyTorch recommends destroying the group before exit.
        dist.destroy_process_group()
308
309


310
311
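# Run with:
#   torchrun --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) fsdp.py
# Parameter initialization is deferred (meta device) by default; pass --no-defer-init to
# initialize parameters directly on the GPU instead.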
312
313
314
if __name__ == "__main__":
    # Parse the CLI options and run the example.
    train(parse_fsdp_args())