Unverified Commit c6c66584 authored by Wing Lian's avatar Wing Lian Committed by GitHub
Browse files

Fix check for backword_pos (#23075)

parent f31a510b
...@@ -458,7 +458,9 @@ class Trainer: ...@@ -458,7 +458,9 @@ class Trainer:
self.fsdp = ShardingStrategy.NO_SHARD self.fsdp = ShardingStrategy.NO_SHARD
self.backward_prefetch = BackwardPrefetch.BACKWARD_PRE self.backward_prefetch = BackwardPrefetch.BACKWARD_PRE
if "backward_prefetch" in self.args.fsdp_config and "backward_pos" not in self.backward_prefetch: if "backward_prefetch" in self.args.fsdp_config and "backward_pos" in self.args.fsdp_config.get(
"backward_prefetch", []
):
self.backward_prefetch = BackwardPrefetch.BACKWARD_POST self.backward_prefetch = BackwardPrefetch.BACKWARD_POST
self.forward_prefetch = False self.forward_prefetch = False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment