Unverified Commit c593bcef authored by Neil Tenenholtz, committed by GitHub

Fix test of FSDP2 by correcting init logic and applying autocast (#2105)



* Fix test of FSDP2 by correcting init logic and applying autocast

This fixes multiple issues in the FSDP2 test, namely:
1. Previously, FP8 init was performed when `args.fp8_init == False`. I have updated the logic to match what I presume was intended by leveraging the `nullcontext` context manager.
2. `te.fp8_autocast` was previously not called; the recipe was created but unused. The autocast context manager now wraps the model's computation.
Signed-off-by: Neil Tenenholtz <ntenenz@users.noreply.github.com>
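The `nullcontext`-based selection described in point 1 can be sketched as follows. This is an illustrative stdlib-only sketch: `FakeQuantizedInit` and `build_context` are stand-ins invented here (Transformer Engine's real context manager is `transformer_engine.pytorch.quantized_model_init`, which may not be installed):

```python
from contextlib import nullcontext

class FakeQuantizedInit:
    """Illustrative stand-in for transformer_engine.pytorch.quantized_model_init."""
    def __enter__(self):
        return self
    def __exit__(self, *exc):
        return False

def build_context(fp8_init: bool):
    # Both branches return a context-manager *instance* (note the
    # parentheses), so the caller can use one uniform `with` statement.
    return FakeQuantizedInit() if fp8_init else nullcontext()

# Model construction is wrapped in whichever context was selected.
with build_context(fp8_init=False):
    pass  # ordinary (non-FP8) parameter initialization happens here
```

Returning instances from both branches is what the final "Fix bug when constructing context" commit addresses: mixing a class in one branch with an instance in the other forces the caller to special-case the `with` statement.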

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix typo
Signed-off-by: Neil Tenenholtz <ntenenz@users.noreply.github.com>

* Update tests/pytorch/distributed/run_fsdp2_model.py
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix bug when constructing context for model init
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

---------
Signed-off-by: Neil Tenenholtz <ntenenz@users.noreply.github.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent a7a69ca6
@@ -105,23 +105,19 @@ def _train(args):
     fp8_format = Format.HYBRID
     fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")
-    # Build model context (FP8 init)
-    build_model_context = nullcontext
-    build_model_context_args = {}
-    if not args.fp8_init:
-        from transformer_engine.pytorch import quantized_model_init
-        build_model_context = quantized_model_init
-        build_model_context_args["enabled"] = True
-    # Build the model with the specified context
-    with build_model_context(**build_model_context_args):
-        model = SimpleNet(args.input_size, args.hidden_size, args.output_size)
+    # Create build context manager
+    if args.fp8_init:
+        from transformer_engine.pytorch import quantized_model_init
+        build_model_context = quantized_model_init()
+    else:
+        build_model_context = nullcontext()
+    with build_model_context:
+        model = SimpleNet(args.input_size, args.hidden_size, args.output_size)
     # Move the model to the correct device
     model.to(device)
     if LOCAL_RANK == 0:
@@ -163,6 +159,7 @@ def _train(args):
         # Zero the parameter gradients
         optimizer.zero_grad()
         input_data = torch.randn(args.batch_size, args.input_size).to(device)
-        output = model(input_data)
+        with te.autocast(enabled=True, recipe=fp8_recipe):
+            output = model(input_data)
         target = torch.randn(args.batch_size, args.output_size).to(device)
         loss = F.mse_loss(output, target)