#!/usr/bin/python3

# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import os
import sys
import argparse

import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn, optim
from torch.distributed import DeviceMesh
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import init_device_mesh
from contextlib import nullcontext


class SimpleNet(nn.Module):
    """Two-layer network built from Transformer Engine ``te.Linear`` layers.

    The forward pass is ``fc2(relu(fc1(x)))``.
    """

    def __init__(self, input_size, hidden_size, output_size):
        """Create the two linear layers.

        Args:
            input_size: Number of input features of ``fc1``.
            hidden_size: Number of output features of ``fc1`` and input
                features of ``fc2``.
            output_size: Number of output features of ``fc2``.
        """
        super().__init__()
        self.fc1 = te.Linear(input_size, hidden_size)
        self.fc2 = te.Linear(hidden_size, output_size)

    def forward(self, x):
        """Apply ``fc1``, a ReLU, then ``fc2`` to ``x`` and return the result."""
        hidden = F.relu(self.fc1(x))
        return self.fc2(hidden)


def save_custom_attrs(module):
    """Snapshot the instance attributes of every parameter of ``module``.

    Args:
        module: Object exposing ``named_parameters()`` that yields
            ``(name, param)`` pairs.

    Returns:
        A dict mapping each parameter name to a shallow copy of that
        parameter's ``vars()`` dictionary.
    """
    return {
        param_name: dict(vars(parameter))
        for param_name, parameter in module.named_parameters()
    }


def restore_custom_attrs(module, custom_attrs):
    """Write previously saved attributes back onto the parameters of ``module``.

    Args:
        module: Object exposing ``named_parameters()`` that yields
            ``(name, param)`` pairs.
        custom_attrs: Mapping of parameter name to a dict of
            ``{attribute_name: value}``, e.g. the result of
            ``save_custom_attrs``. Parameters whose name is not a key of
            this mapping are left untouched.
    """
    for param_name, parameter in module.named_parameters():
        if param_name not in custom_attrs:
            continue
        for key, value in custom_attrs[param_name].items():
            setattr(parameter, key, value)


def _parse_args(argv=None, namespace=None):
    parser = argparse.ArgumentParser(description="Toy example for debugging fully_shard()")
    parser.add_argument("--input-size", type=int, default=2048, help="Input size for the model")
    parser.add_argument("--hidden-size", type=int, default=2048, help="Hidden layer size")
    parser.add_argument("--output-size", type=int, default=2048, help="Output size for the model")
    parser.add_argument("--batch-size", type=int, default=2048, help="Output size for the model")
    parser.add_argument(
        "--fp8-init", action="store_true", default=False, help="Initialize primary weights in FP8."
    )
    parser.add_argument(
        "--iter", type=int, default=10, help="Number of iterations for forward pass"
    )
    parser.add_argument("--seed", type=int, default=42, help="RNG seed.")
    # Adding hsdp_dim as a list argument, comma-separated
    parser.add_argument(
        "--sharding-dims",
        type=int,
        nargs="+",
        help='FSDP/HSDP sharding dimensions ("replicate", "shard")',
    )
    args = parser.parse_args(argv, namespace)
    if args.sharding_dims:
        assert len(args.sharding_dims) <= 2
    return args


# Module types that _train() passes to fully_shard() individually, before it
# calls fully_shard() on the whole model.
sub_modules_to_wrap = [te.Linear]


def _train(args):
    """Run a short FSDP2/HSDP training loop on ``SimpleNet`` under FP8 autocast.

    The function expects to be started by a launcher that sets
    ``TORCHELASTIC_RUN_ID`` and requires ``LOCAL_WORLD_SIZE == WORLD_SIZE``.
    It builds the model (optionally inside ``quantized_model_init()``),
    applies ``fully_shard()`` to every module listed in
    ``sub_modules_to_wrap`` and to the root model, then runs ``args.iter``
    Adam steps on random inputs/targets with an MSE loss.

    Args:
        args: Namespace from ``_parse_args``. Reads ``seed``, ``fp8_init``,
            ``input_size``, ``hidden_size``, ``output_size``, ``batch_size``,
            ``iter`` and ``sharding_dims``.

    Returns:
        ``0`` once the loop has finished and the process group is destroyed.

    Raises:
        AssertionError: If the launcher environment is not as expected, NCCL
            is unavailable, or ``sharding_dims`` does not match the world
            size.
    """
    assert "TORCHELASTIC_RUN_ID" in os.environ, "TORCHELASTIC_RUN_ID is not set"
    WORLD_RANK = int(os.getenv("RANK", "0"))
    WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
    LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
    LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
    assert (
        LOCAL_SIZE == WORLD_SIZE
    ), f"LOCAL_WORLD_SIZE ({LOCAL_SIZE}) must equal WORLD_SIZE ({WORLD_SIZE})"

    # Set device and initialize RNG states. The current CUDA device and the
    # device the model/data are moved to are the same object, so they cannot
    # disagree.
    device = torch.device(f"cuda:{LOCAL_RANK}")
    torch.cuda.set_device(device)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # Initialize torch.distributed global process group
    dist_init_kwargs = {
        "backend": "nccl",
        "rank": WORLD_RANK,
        "world_size": WORLD_SIZE,
    }
    assert dist.is_nccl_available(), "NCCL backend is not available"
    dist.init_process_group(**dist_init_kwargs)
    # NOTE(review): this group is not referenced again in this function; the
    # call is kept so the sequence of collective calls is unchanged.
    nccl_world = dist.new_group(backend="nccl")

    # FP8 Configuration
    fp8_format = Format.HYBRID
    fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")

    # Create build context manager
    if args.fp8_init:
        from transformer_engine.pytorch import quantized_model_init

        build_model_context = quantized_model_init()
    else:
        build_model_context = nullcontext()

    # Build the model with the specified context
    with build_model_context:
        model = SimpleNet(args.input_size, args.hidden_size, args.output_size)

    # Move the model to the correct device
    model.to(device)

    if LOCAL_RANK == 0:
        print(f"Rank {LOCAL_RANK}: Applying FSDP fully_shard() to the model...")
    # Creating a DeviceMesh for fully_shard
    device_ids = list(range(WORLD_SIZE))
    if LOCAL_RANK == 0:
        print(f"sharding-dims:{args.sharding_dims}")
    # Setup the sharding mesh for FSDP/HSDP
    sharding_dims = args.sharding_dims
    if sharding_dims is None:  # FSDP
        mesh = DeviceMesh("cuda", device_ids)
    elif len(sharding_dims) == 1:
        assert (
            sharding_dims[0] == WORLD_SIZE
        ), f"sharding dim {sharding_dims[0]} must equal world size {WORLD_SIZE}"
        mesh = DeviceMesh("cuda", device_ids)
    elif len(sharding_dims) == 2:  # HSDP
        assert (
            sharding_dims[0] * sharding_dims[1] == WORLD_SIZE
        ), f"product of sharding dims {sharding_dims} must equal world size {WORLD_SIZE}"
        mesh = init_device_mesh(
            "cuda",
            (sharding_dims[0], sharding_dims[1]),
            mesh_dim_names=("replicate", "shard"),
        )
    else:
        # Raised explicitly (not `assert False`) so it also fires under -O.
        raise AssertionError(f"sharding_dims must have 1 or 2 entries, got {sharding_dims}")

    # Apply FSDP/HSDP. Parameter attributes are snapshotted before the
    # fully_shard() calls and written back afterwards.
    custom_attrs = save_custom_attrs(model)
    for sub_module in model.modules():
        if any(
            isinstance(sub_module, sub_module_to_wrap) for sub_module_to_wrap in sub_modules_to_wrap
        ):
            fully_shard(sub_module, mesh=mesh)
    fully_shard(model, mesh=mesh)
    restore_custom_attrs(model, custom_attrs)

    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    for iteration in range(args.iter):
        # Zero the parameter gradients
        optimizer.zero_grad()
        input_data = torch.randn(args.batch_size, args.input_size).to(device)
        with te.autocast(enabled=True, recipe=fp8_recipe):
            output = model(input_data)
        target = torch.randn(args.batch_size, args.output_size).to(device)
        loss = F.mse_loss(output, target)
        loss.backward()
        optimizer.step()
        if LOCAL_RANK == 0:
            print(f"Rank {LOCAL_RANK}: Iteration {iteration} completed.")

    dist.destroy_process_group()
    if LOCAL_RANK == 0:
        print(f"Rank {LOCAL_RANK}: Done...")
    return 0


# Script entry point: parse the command line, run the training loop, and use
# its return value as the process exit status.
if __name__ == "__main__":
    sys.exit(_train(_parse_args()))