#!/usr/bin/python3

# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
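"""Toy example for debugging FSDP2 (fully_shard) with Transformer Engine modules.

Hedged usage sketch -- the script asserts a torchrun/torchelastic launch (it checks
TORCHELASTIC_RUN_ID), and the flags are the ones defined in _parse_args() below:

    torchrun --nproc_per_node=8 run_fsdp2_model.py --sharding-dims 2 4 --fp8-init
"""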

import os
import sys
import argparse


import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn, optim
from torch.distributed import DeviceMesh
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import init_device_mesh
from contextlib import nullcontext

import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling

class SimpleNet(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(SimpleNet, self).__init__()
        self.fc1 = te.Linear(input_size, hidden_size)
        self.fc2 = te.Linear(hidden_size, output_size)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


def save_custom_attrs(module):
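    """Snapshot any extra (non-standard) attributes attached to each parameter.

    Hedged rationale: fully_shard() swaps module parameters for DTensor-backed ones, so
    attributes that Transformer Engine set on the original parameter objects would be
    lost; they are captured here and re-applied by restore_custom_attrs() after sharding.
    """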
    custom_attrs = {}
    for name, param in module.named_parameters():
        attrs = vars(param)
        custom_attrs[name] = {k: v for k, v in attrs.items()}
    return custom_attrs


def restore_custom_attrs(module, custom_attrs):
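    """Re-apply the attributes captured by save_custom_attrs() to the module's parameters."""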
    for name, param in module.named_parameters():
        if name in custom_attrs:
            for attr_name, attr_value in custom_attrs[name].items():
                setattr(param, attr_name, attr_value)


def _parse_args(argv=None, namespace=None):
    parser = argparse.ArgumentParser(description="Toy example for debugging fully_shard()")
    parser.add_argument("--input-size", type=int, default=2048, help="Input size for the model")
    parser.add_argument("--hidden-size", type=int, default=2048, help="Hidden layer size")
    parser.add_argument("--output-size", type=int, default=2048, help="Output size for the model")
    parser.add_argument("--batch-size", type=int, default=2048, help="Output size for the model")
    parser.add_argument(
        "--fp8-init", action="store_true", default=False, help="Initialize primary weights in FP8."
    )
    parser.add_argument(
        "--iter", type=int, default=10, help="Number of iterations for forward pass"
    )
    parser.add_argument("--seed", type=int, default=42, help="RNG seed.")
    # The sharding mesh is given as one or two space-separated integers
    parser.add_argument(
        "--sharding-dims",
        type=int,
        nargs="+",
        help="Sharding mesh dimensions: one value (world size) for FSDP, two values (replicate, shard) for HSDP",
    )
    args = parser.parse_args(argv, namespace)
    if args.sharding_dims:
        assert len(args.sharding_dims) <= 2
    return args


# Module classes that receive their own fully_shard() call before the root module is sharded
sub_modules_to_wrap = [te.Linear]


def _train(args):
    assert "TORCHELASTIC_RUN_ID" in os.environ
    WORLD_RANK = int(os.getenv("RANK", "0"))
    WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
    LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
    LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
    assert LOCAL_SIZE == WORLD_SIZE

    # Set device and initialize RNG states
    torch.cuda.set_device(WORLD_RANK)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # Initialize torch.distributed global process group and get DP/TP groups
    dist_init_kwargs = {
        "backend": "nccl",
        "rank": WORLD_RANK,
        "world_size": WORLD_SIZE,
    }
    assert dist.is_nccl_available()
    dist.init_process_group(**dist_init_kwargs)
    nccl_world = dist.new_group(backend="nccl")
    device = torch.device(f"cuda:{LOCAL_RANK}")

    # FP8 Configuration
    fp8_format = Format.HYBRID
    fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")
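    # Hedged note: this recipe is not consumed below; to run the forward pass in FP8 it would
    # typically be passed to te.fp8_autocast(...) -- see the sketch above the training loop.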

    # Build the model, optionally initializing primary weights directly in FP8
    build_model_context = nullcontext
    build_model_context_args = {}
    if args.fp8_init:
        from transformer_engine.pytorch import fp8_model_init

        build_model_context = fp8_model_init
        build_model_context_args["enabled"] = True

    # Build the model with the specified context
    with build_model_context(**build_model_context_args):
        model = SimpleNet(args.input_size, args.hidden_size, args.output_size)
    # Move the model to the correct device
    model.to(device)

    if LOCAL_RANK == 0:
        print(f"Rank {LOCAL_RANK}: Applying FSDP fully_shard() to the model...")
    # Create a DeviceMesh for fully_shard
    device_ids = list(range(WORLD_SIZE))
    if LOCAL_RANK == 0:
        print(f"sharding-dims:{args.sharding_dims}")
    # Setup the sharding mesh for FSDP/HSDP
    if args.sharding_dims is None:  # FSDP over the full world size
        mesh = DeviceMesh("cuda", device_ids)
    elif len(args.sharding_dims) == 1:  # FSDP with an explicit shard dimension
        assert args.sharding_dims[0] == WORLD_SIZE
        mesh = DeviceMesh("cuda", device_ids)
    elif len(args.sharding_dims) == 2:  # HSDP: (replicate, shard)
        assert args.sharding_dims[0] * args.sharding_dims[1] == WORLD_SIZE
        mesh = init_device_mesh(
            "cuda",
            (args.sharding_dims[0], args.sharding_dims[1]),
            mesh_dim_names=("replicate", "shard"),
        )
    else:
        assert False, "--sharding-dims accepts at most two values"

    # Apply FSDP/HSDP
    custom_attrs = save_custom_attrs(model)
    for sub_module in model.modules():
        if any(
            isinstance(sub_module, sub_module_to_wrap) for sub_module_to_wrap in sub_modules_to_wrap
        ):
            fully_shard(sub_module, mesh=mesh)
    fully_shard(model, mesh=mesh)
    restore_custom_attrs(model, custom_attrs)

    # Create the optimizer after fully_shard() so it references the sharded (DTensor) parameters.
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
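    # Hedged sketch (not part of the original script): to run the forward pass in FP8, the
    # model call inside the loop below could be wrapped with TE's autocast and the recipe
    # defined above:
    #
    #     with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    #         output = model(input_data)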

    for iteration in range(args.iter):
        # Zero the parameter gradients
        optimizer.zero_grad()
        input_data = torch.randn(args.batch_size, args.input_size).to(device)
        output = model(input_data)
        target = torch.randn(args.batch_size, args.output_size).to(device)
        loss = F.mse_loss(output, target)
        loss.backward()
        optimizer.step()
        if LOCAL_RANK == 0:
            print(f"Rank {LOCAL_RANK}: Iteration {iteration} completed.")

    dist.destroy_process_group()
    if LOCAL_RANK == 0:
        print(f"Rank {LOCAL_RANK}: Done...")
    return 0


if __name__ == "__main__":
    sys.exit(_train(_parse_args()))