# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import os
import argparse
from functools import partial

import torch
import torch.distributed as dist
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel, MixedPrecision
from torch.distributed.fsdp.wrap import always_wrap_policy, transformer_auto_wrap_policy

import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling

def lowercase(s):
    """Coerce any value to its string form and return it lowercased."""
    text = str(s)
    return text.lower()

def torch_dtype(d):
    """Map a dtype name (e.g. 'fp32', 'bf16', 'float16') to a torch.dtype.

    Args:
        d: Dtype name; matched case-insensitively.

    Raises:
        TypeError: If *d* does not name a supported dtype.
    """
    typemap = {
        'fp32' : torch.float32,
        'float32' : torch.float32,
        'fp16' : torch.float16,
        'float16' : torch.float16,
        'bf16' : torch.bfloat16,
        'bfloat16' : torch.bfloat16
    }
    # Normalize once instead of calling lowercase() on every lookup.
    key = lowercase(d)
    if key not in typemap:
        # Informative message instead of a bare TypeError; callers catching
        # TypeError (argparse's type= machinery) still work unchanged.
        raise TypeError(f"Unsupported dtype name: {d!r}")
    return typemap[key]

# Lowercase command-line names mapped to their Transformer Engine module classes.
te_layer_map = {
    'linear': te.Linear,
    'layernorm': te.LayerNorm,
    'rmsnorm': te.RMSNorm,
    'layernormlinear': te.LayerNormLinear,
    'layernormmlp': te.LayerNormMLP,
    'multiheadattention': te.MultiheadAttention,
    'transformerlayer': te.TransformerLayer
}
def te_layer(l):
    """Map a layer name (e.g. 'transformerlayer') to its TE module class.

    Args:
        l: Layer name; matched case-insensitively against te_layer_map.

    Raises:
        TypeError: If *l* does not name a supported TE layer.
    """
    # Normalize once instead of calling lowercase() twice.
    key = lowercase(l)
    if key not in te_layer_map:
        # Informative message instead of a bare TypeError; argparse still
        # treats any TypeError from type= as an invalid-argument error.
        raise TypeError(f"Unsupported layer type: {l!r}")
    return te_layer_map[key]

def get_layer_args(args):
    """Build the constructor arguments for the configured TE layer type.

    Returns a ``(positional, keyword)`` pair suitable for
    ``args.layer_type(*positional, **keyword)``.
    """
    hidden_size = args.num_heads * args.head_dim
    layer_cls = args.layer_type
    positional = (hidden_size,)
    keyword = {
        'params_dtype': args.dtype,
        'device': 'meta' if args.defer_init else 'cuda'
    }
    if layer_cls in (te.Linear, te.LayerNormLinear, te.LayerNormMLP):
        # A single layer may widen its output; stacked layers must keep the
        # input and output sizes equal so they can be chained.
        if args.num_layers == 1:
            positional += (3 * hidden_size,)
        else:
            positional += (hidden_size,)
        keyword['bias'] = True
        if layer_cls is te.LayerNormMLP:
            keyword['seq_length'] = args.seq_length
    elif layer_cls is te.MultiheadAttention:
        positional += (args.num_heads,)
        keyword['fuse_qkv_params'] = True
    elif layer_cls is te.TransformerLayer:
        positional += (3 * hidden_size, args.num_heads)
        keyword['fuse_qkv_params'] = True
        keyword['seq_length'] = args.seq_length
    return positional, keyword

def parse_fsdp_args():
    """Define and parse the command-line options for the FSDP example."""
    parser = argparse.ArgumentParser(
        description="Run Transformer Engine modules with the "
                    "torch.distributed.fsdp.FullyShardedDataParallel strategy.")
    add = parser.add_argument
    add("-t", "--layer-type", type=te_layer, default=te.TransformerLayer,
        choices=list(te_layer_map.values()),
        help="TE module type used to construct the test model.")
    add("--no-fp8", action="store_true", default=False,
        help="Disables the te.fp8_autocast() context.")
    add('-i', "--num-iters", type=int, default=3,
        help="Number of dummy 'training' iterations.")
    add('-b', "--batch-size", type=int, default=32,
        help="Input batch size.")
    add('-s', "--seq-length", type=int, default=1048,
        help="Input sequence length.")
    add('-n', "--num-heads", type=int, default=16,
        help="Number of attention heads.")
    add('-d', "--head-dim", type=int, default=128,
        help="Dimension of each attention head (number of KV channels).")
    add('-l', "--num-layers", type=int, default=1,
        help="Number of modules chained together with nn.Sequential.")
    add("--seed", type=int, default=1234,
        help="PyTorch RNG seed.")
    add("--defer-init", action="store_true",
        help="Defer module parameter initialization until after FSDP sharding.")
    add('-v', "--verbose", action="store_true", default=False,
        help="Print out information from all GPUs instead of only the root GPU-0.")
    add("--dtype", type=torch_dtype, default=torch.bfloat16,
        help="Data type for input tensor and Transformer Engine module parameters.")
    return parser.parse_args()

def train(args):
    """Run timed dummy training iterations of a TE model sharded with FSDP.

    Expects ``LOCAL_RANK`` and ``WORLD_SIZE`` in the environment (as set by
    ``torchrun``) and one CUDA device per participating rank.
    """
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    # Initialize torch.distributed global process group
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(local_rank)
    if local_rank == 0:
        print(f"[GPU-0] WORLD_SIZE = {world_size}\n\n", end='')
    torch.manual_seed(args.seed)

    # Construct a simple homogeneous model (only one layer type) with NO PARALLELISM
    layer_args, layer_kwargs = get_layer_args(args)
    if args.num_layers > 1:
        te_layer_list = []
        for i in range(args.num_layers):
            # Only attention-bearing modules accept a layer_number kwarg, but
            # every layer type must still be appended -- the original code only
            # appended inside the 'if', leaving nn.Sequential() empty for
            # non-attention layer types.
            if args.layer_type in [te.MultiheadAttention, te.TransformerLayer]:
                layer_kwargs['layer_number'] = i+1
            te_layer_list.append(args.layer_type(*layer_args, **layer_kwargs))
        te_model = nn.Sequential(*te_layer_list)
    else:
        # Single layer model
        te_model = args.layer_type(*layer_args, **layer_kwargs)
    if local_rank == 0:
        print(f"[GPU-0] TransformerEngine Model:\n{te_model}\n", end='')

    # Print out allocated device memory before the model parameters are sharded by FSDP
    pre_mem_use = torch.cuda.memory_allocated(device=f"cuda:{local_rank}") * 1e-6
    if local_rank == 0 or args.verbose:
        print(f"[GPU-{local_rank}] Pre-FSDP memory use = {pre_mem_use}MiB\n", end='')

    # Wrap the model with FSDP
    # NOTE: The TE model itself has no inherent parallelism. FSDP shards model parameters and
    #       controls all communication.
    all_gpus = dist.new_group(backend='nccl')
    fsdp_wrap_policy = always_wrap_policy
    if args.layer_type == te.TransformerLayer:
        # NOTE: FSDP causes illegal memory access without this special policy for Transformers
        fsdp_wrap_policy = partial(transformer_auto_wrap_policy,
                                   transformer_layer_cls={te.TransformerLayer})
    te_model = FullyShardedDataParallel(te_model,
                                        process_group=all_gpus,
                                        use_orig_params=True,
                                        mixed_precision=MixedPrecision(
                                            param_dtype=args.dtype,
                                            reduce_dtype=torch.float32,
                                        ),
                                        sync_module_states=True,
                                        auto_wrap_policy=fsdp_wrap_policy)

    # Print out allocated device memory after the model parameters are sharded
    post_mem_use = torch.cuda.memory_allocated(device=f"cuda:{local_rank}") * 1e-6
    if local_rank == 0 or args.verbose:
        print(f"[GPU-{local_rank}] Post-FSDP memory use = {post_mem_use}MiB\n", end='')

    # Fp8 setup for TE
    fp8_format = Format.HYBRID
    fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max")

    # Optimizer must be created after the model is wrapped in FSDP and the parameters are sharded
    optim = torch.optim.Adam(te_model.parameters(), lr=0.0001)

    # Start and time dummy "training" iterations
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for i in range(args.num_iters):
        # Generate a random input batch
        x = torch.rand(args.seq_length, args.batch_size,
                       args.num_heads*args.head_dim).to(dtype=args.dtype).cuda()
        # fp8_autocast needs to be given the FSDP process group for amax reductions
        with te.fp8_autocast(enabled=not args.no_fp8, fp8_recipe=fp8_recipe, fp8_group=all_gpus):
            y = te_model(x)
            loss = y.sum()
        # calculate gradient and take training step outside the fp8_autocast context
        loss.backward()
        optim.step()
        # Reset gradients so successive iterations do not accumulate into each
        # other (the original loop never cleared them).
        optim.zero_grad(set_to_none=True)
        del x
        if local_rank == 0:
            print(f"[GPU-0] Iter. {i+1}\n", end='')
    end.record()
    torch.cuda.synchronize()

    # Print out "training" time and peak memory use stats
    train_time = start.elapsed_time(end)/1000.
    max_memory_alloc = int(torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}") * 1e-6)
    if local_rank == 0 or args.verbose:
        print(f"[GPU-{local_rank}] Training Time: {train_time}s\n" +
              f"[GPU-{local_rank}] Avg. Iter. Time: {train_time /args.num_iters}s\n" +
              f"[GPU-{local_rank}] Peak memory use = {max_memory_alloc}MiB\n\n", end='')


if __name__ == "__main__":
    args = parse_fsdp_args()
    train(args)