Unverified Commit 5c5866b3 authored by Sean Naren's avatar Sean Naren Committed by GitHub
Browse files

Add is root check to only cast to FP16 on main FSDP wrapper (#452)

parent c114a219
......@@ -736,7 +736,7 @@ class FullyShardedDataParallel(nn.Module):
# Start of a forward pass.
self.training_state = TrainingState.FORWARD
if self.mixed_precision:
if self._is_root and self.mixed_precision:
args, kwargs = cast_inputs_to_fp16(*args, **kwargs)
# All-gather full parameters. This will also transfer FP32 parameters to
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment