"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "a46d97c96dfb2f7f9ddc7f4f889d9856b46428ad"
Unverified Commit 5c5866b3 authored by Sean Naren's avatar Sean Naren Committed by GitHub
Browse files

Add is root check to only cast to FP16 on main FSDP wrapper (#452)

parent c114a219
...@@ -736,7 +736,7 @@ class FullyShardedDataParallel(nn.Module): ...@@ -736,7 +736,7 @@ class FullyShardedDataParallel(nn.Module):
# Start of a forward pass. # Start of a forward pass.
self.training_state = TrainingState.FORWARD self.training_state = TrainingState.FORWARD
if self.mixed_precision: if self._is_root and self.mixed_precision:
args, kwargs = cast_inputs_to_fp16(*args, **kwargs) args, kwargs = cast_inputs_to_fp16(*args, **kwargs)
# All-gather full parameters. This will also transfer FP32 parameters to # All-gather full parameters. This will also transfer FP32 parameters to
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment