# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

# Basic Example for Using PyTorch Fully Sharded Data Parallel mode with Transformer Engine

```bash
# FSDP without deferred initialization:
#     Duplicate modules are initialized on each device. Device memory load is reduced only after
#     torch.distributed.fsdp.FullyShardedDataParallel shards the model parameters.
$ torchrun --standalone --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) fsdp.py --no-defer-init
# Sample output on 8xL40S:
#    [GPU-0] WORLD_SIZE = 8
#    [GPU-0] TransformerEngine Model:
#    TransformerLayer(
#    (self_attention): MultiheadAttention(
#        (layernorm_qkv): LayerNormLinear()
#        (core_attention): DotProductAttention(
#        (flash_attention): FlashAttention()
#        (fused_attention): FusedAttention()
#        (unfused_attention): UnfusedDotProductAttention(
#            (scale_mask_softmax): FusedScaleMaskSoftmax()
#            (attention_dropout): Dropout(p=0.1, inplace=False)
#        )
#        )
#        (proj): Linear()
#    )
#    (layernorm_mlp): LayerNormMLP()
#    )
#    [GPU-0] Pre-FSDP memory use = 83.935232MiB
#    [GPU-0] Post-FSDP memory use = 10.491904MiB
#    [GPU-0] Iter. 1
#    [GPU-0] Iter. 2
#    [GPU-0] Iter. 3
#    [GPU-0] Training Time: 6.647654296875s
#    [GPU-0] Avg. Iter. Time: 2.2158847656250003s
#    [GPU-0] Peak memory use = 3000MiB

# FSDP with deferred initialization:
#    Modules are initialized with empty parameters via the `device='meta'` option. There is zero load
#    on device memory until torch.distributed.fsdp.FullyShardedDataParallel triggers a reset on the
#    already sharded model parameters.
$ torchrun --standalone --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) fsdp.py
# Sample output on 8xL40S:
#    [GPU-0] WORLD_SIZE = 8
#    ...
#    [GPU-0] Pre-FSDP memory use = 0.0MiB
#    [GPU-0] Post-FSDP memory use = 10.491904MiB
#    ...
```

**NOTE:** This example has `autocast()` enabled by default. To run on GPUs without FP8 support
(e.g., A100), add the `--no-fp8` option to the commands shown above.