# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import os
import argparse

from functools import partial

import torch
import torch.distributed as dist
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel, MixedPrecision
from torch.distributed.fsdp.wrap import always_wrap_policy, transformer_auto_wrap_policy
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    checkpoint_wrapper
)

import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling
from transformer_engine.pytorch.distributed import prepare_te_modules_for_fsdp

LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))

# RNG state tracker for checkpointing
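# Activation recompute replays the forward pass, so dropout must draw from a tracked RNG
# stream whose state can be restored to match the original forward pass.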
rng_seed = 1234
torch.manual_seed(rng_seed)
torch.cuda.manual_seed(rng_seed)
CUDA_RNG_STATES_TRACKER = te.distributed.CudaRNGStatesTracker()
CUDA_RNG_STATES_TRACKER.add('model-parallel-rng', rng_seed)
def get_cuda_rng_tracker():
    return CUDA_RNG_STATES_TRACKER

def apply_fsdp_checkpointing(model, blocks):
    """apply activation checkpointing to model
    returns None as model is updated directly
    """
    wrapper = lambda m: checkpoint_wrapper(m,
                                           checkpoint_fn=te.distributed.checkpoint,
                                           use_reentrant=False,
                                           get_rng_state_tracker=get_cuda_rng_tracker)
    check_fn = lambda submodule: isinstance(submodule, blocks)
    apply_activation_checkpointing(model, checkpoint_wrapper_fn=wrapper, check_fn=check_fn)

def lowercase(s):
    return str(s).lower()

def torch_dtype(d):
    typemap = {
        'fp32' : torch.float32,
        'float32' : torch.float32,
        'fp16' : torch.float16,
        'float16' : torch.float16,
        'bf16' : torch.bfloat16,
        'bfloat16' : torch.bfloat16
    }
    if lowercase(d) not in typemap:
        raise TypeError(f"Unsupported dtype: {d}")
    return typemap[lowercase(d)]

te_layer_map = {
    'linear': te.Linear,
    'layernorm': te.LayerNorm,
    'rmsnorm': te.RMSNorm,
    'layernormlinear': te.LayerNormLinear,
    'layernormmlp': te.LayerNormMLP,
    'multiheadattention': te.MultiheadAttention,
    'transformerlayer': te.TransformerLayer
}
def te_layer(l):
    if l is not None:
        if lowercase(l) not in te_layer_map:
            raise TypeError(f"Unsupported layer type: {l}")
        return te_layer_map[lowercase(l)]
    return None

def get_layer_args(opts):
    hidden_size = opts.num_heads * opts.head_dim
    layer_args = (hidden_size, )
    layer_kwargs = {
        'params_dtype': opts.dtype,
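        # Initializing on the 'meta' device defers parameter allocation, so weights can be
        # materialized directly into FSDP shards instead of building the full model per GPU.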
        'device': 'cuda' if opts.no_defer_init else 'meta',
        'get_rng_state_tracker': get_cuda_rng_tracker,
    }
    if opts.layer_type in [te.Linear, te.LayerNormLinear, te.LayerNormMLP]:
        ffn_hidden_size = 3 * hidden_size if opts.num_layers == 1 else hidden_size
        layer_args += (ffn_hidden_size, )
        layer_kwargs['bias'] = True
        if opts.layer_type == te.LayerNormMLP:
            layer_kwargs['seq_length'] = opts.seq_length
    elif opts.layer_type == te.MultiheadAttention:
        layer_args += (opts.num_heads, )
        layer_kwargs['fuse_qkv_params'] = True
        layer_kwargs['input_layernorm'] = True
    elif opts.layer_type == te.TransformerLayer:
        layer_args += (3 * hidden_size, opts.num_heads)
        layer_kwargs['fuse_qkv_params'] = True
        layer_kwargs['seq_length'] = opts.seq_length
    return layer_args, layer_kwargs

def parse_fsdp_args():
    parser = argparse.ArgumentParser(description="Run Transformer Engine modules with the " +
                                    "torch.distributed.fsdp.FullyShardedDataParallel strategy.")
    parser.add_argument('-v', "--verbose", action="store_true", default=False,
                        help="Print out information from all GPUs instead of only the root GPU-0.")
    parser.add_argument('-b', "--batch-size", type=int, default=32,
                        help="Input batch size.")
    parser.add_argument('-s', "--seq-length", type=int, default=1048,
                        help="Input sequence length.")
    parser.add_argument('-n', "--num-heads", type=int, default=16,
                        help="Number of attention heads.")
    parser.add_argument('-d', "--head-dim", type=int, default=128,
                        help="Dimension of each attention head (number of KV channels).")
    parser.add_argument('-i', "--num-iters", type=int, default=5,
                        help="Number of dummy 'training' iterations.")
    parser.add_argument('-k', "--num-layers", type=int, default=3,
                        help="Number of modules chained together with nn.Sequential.")
    parser.add_argument("--layer-type", type=te_layer, default=te.TransformerLayer,
                        choices=list(te_layer_map.values()),
                        help="TE module type used to construct the test model.")
    parser.add_argument("--seed", type=int, default=1234,
                        help="PyTorch RNG seed.")
    parser.add_argument("--profile-memory", action="store_true",
                        help="Enable memory profiling via torch.profiler.profile().")
    parser.add_argument("--profile-name", type=str, default=None,
                        help="File path for memory profiling.")
    parser.add_argument("--checkpoint-layer", type=te_layer, default=None,
                        help="Recompute activations of the selected layer during the backward " + \
                             "pass instead of saving.")
    parser.add_argument("--no-fp8", action="store_true", default=False,
                        help="Disables the te.fp8_autocast() context.")
    parser.add_argument("--no-defer-init", action="store_true",
135
                        help="Defer module parameter initialization until after FSDP sharding.")
    parser.add_argument("--no-te-fsdp", action="store_true",
                        help="Disable sharding of intermediate/activation tensors in TE modules.")
    parser.add_argument("--dtype", type=torch_dtype, default=torch.bfloat16,
                        help="Data type for input tensor and Transformer Engine module parameters.")
    return parser.parse_args()

def dist_print(text, all_ranks=False, no_new_line=False):
    if LOCAL_RANK == 0 or all_ranks:
        end = '' if no_new_line else '\n'
        print(f"[GPU-{LOCAL_RANK}] " + text, end=end)

def train(opts):
    # Initialize torch.distributed global process group
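    # (torchrun provides MASTER_ADDR/MASTER_PORT, RANK, LOCAL_RANK and WORLD_SIZE via env vars)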
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(LOCAL_RANK)
    dist_print(f"WORLD_SIZE = {WORLD_SIZE}")
    torch.manual_seed(opts.seed)

    # Construct a simple homogeneous model (only one layer type) with NO PARALLELISM
    layer_args, layer_kwargs = get_layer_args(opts)
    if opts.num_layers > 1:
        te_layer_list = []
        for i in range(opts.num_layers):
            if opts.layer_type in [te.MultiheadAttention, te.TransformerLayer]:
                layer_kwargs['layer_number'] = i+1
            te_layer_list.append(opts.layer_type(*layer_args, **layer_kwargs))
        te_model = nn.Sequential(*te_layer_list)
    else:
        # Single layer model
        te_model = opts.layer_type(*layer_args, **layer_kwargs)

    # Print out allocated device memory before the model parameters are sharded by FSDP
    pre_mem_use = torch.cuda.memory_allocated(device=f"cuda:{LOCAL_RANK}") * 1e-6
    dist_print(f"Pre-FSDP memory use = {pre_mem_use}MiB")

    # Wrap the model with FSDP
    # NOTE: The TE model itself has no inherent parallelism. FSDP shards model parameters and
    #       controls all communication.
    all_gpus = dist.new_group(backend='nccl')
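    # always_wrap_policy places every submodule into its own FSDP unit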
    fsdp_wrap_policy = always_wrap_policy
    if opts.layer_type == te.TransformerLayer:
        # NOTE: FSDP causes illegal memory access without this special policy for Transformers
        fsdp_wrap_policy = partial(transformer_auto_wrap_policy,
                                   transformer_layer_cls={te.TransformerLayer})
    te_model = FullyShardedDataParallel(te_model,
                                        process_group=all_gpus,
                                        use_orig_params=True,
                                        mixed_precision=MixedPrecision(
                                            param_dtype=opts.dtype,
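                                            # reduce-scatter gradients in FP32 for stability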
                                            reduce_dtype=torch.float32,
                                        ),
                                        auto_wrap_policy=fsdp_wrap_policy)

    if opts.checkpoint_layer is not None:
        # Recompute the activations of the selected layer during the backward pass instead of
        # saving them during the forward pass
        apply_fsdp_checkpointing(te_model, blocks=opts.checkpoint_layer)
    elif not opts.no_te_fsdp:
        # Prepare TE modules to shard internal buffers that FSDP cannot shard on its own
        prepare_te_modules_for_fsdp(te_model)

    # Print out allocated device memory after the model parameters are sharded
    post_mem_use = torch.cuda.memory_allocated(device=f"cuda:{LOCAL_RANK}") * 1e-6
    dist_print(f"Post-FSDP memory use = {post_mem_use}MiB")
    dist_print(f"FSDP-Wrapped + Checkpointed TE Model:\n{te_model}")

    # Fp8 setup for TE
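    # HYBRID format uses E4M3 for forward-pass tensors and E5M2 for backward-pass gradients;
    # DelayedScaling derives FP8 scaling factors from a rolling amax history.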
    fp8_format = Format.HYBRID
    fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max")

    # Optimizer must be created after the model is wrapped in FSDP and the parameters are sharded
    optim = torch.optim.Adam(te_model.parameters(), lr=0.0001)

    # Profile memory use
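    # (the snapshot dumped after the loop can be inspected at https://pytorch.org/memory_viz)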
    if opts.profile_memory:
        torch.cuda.memory._record_memory_history(max_entries=100000)
    else:
        torch.cuda.reset_peak_memory_stats()
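        # Time the loop with CUDA events; Event.elapsed_time() returns milliseconds.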
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        torch.cuda.synchronize()
        start.record()

    for i in range(opts.num_iters):
        # Generate a random input batch
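        # NOTE: TE transformer modules expect sequence-first (seq, batch, hidden) inputs by default.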
        x = torch.rand(opts.seq_length, opts.batch_size, opts.num_heads*opts.head_dim,
                       dtype=opts.dtype, device='cuda')
        # fp8_autocast needs to be given the FSDP process group for amax reductions
        with te.fp8_autocast(enabled=not opts.no_fp8, fp8_recipe=fp8_recipe, fp8_group=all_gpus):
            y = te_model(x)
            loss = y.sum()
        # calculate gradient and take training step outside the fp8_autocast context
        loss.backward()
        optim.step()
        optim.zero_grad(set_to_none=True)
        del x


    if opts.profile_memory:
        torch.cuda.memory._dump_snapshot(f"gpu{LOCAL_RANK}_{opts.profile_name}.pickle")
        torch.cuda.memory._record_memory_history(enabled=None)
    else:
        end.record()
        torch.cuda.synchronize()
        peak_mem = torch.cuda.max_memory_allocated()
        train_time = start.elapsed_time(end)/1000.
        dist_print(f"Training Time: {train_time}s")
        dist_print(f"Avg. Iter. Time: {train_time / opts.num_iters}s")
        dist_print(f"Peak Memory Use: {peak_mem * 1e-6}MBs")


# Run with:
#   torchrun --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) fsdp.py
if __name__ == "__main__":
    args = parse_fsdp_args()
    train(args)