Unverified Commit 462b7f3a authored by Sourab Mangrulkar, committed by GitHub

fixing fsdp autowrap functionality (#17922)

* fixing fsdp autowrap functionality

* update version and quality

* update torch version to latest stable version
parent 3a064bd4
@@ -384,13 +384,12 @@ class Trainer:
             if args.local_rank == -1:
                 raise ValueError("Using fsdp only works in distributed training.")
-            # dep_version_check("torch>=1.12.0.dev20220418+cu113")
-            # Would have to update setup.py with torch>=1.12.0.dev20220418+cu113
-            # which isn't ideally given that it's a dev version
-            # and it will force people not using FSDP to also use torch>=1.12.0.dev20220418+cu113
+            # dep_version_check("torch>=1.12.0")
+            # Would have to update setup.py with torch>=1.12.0
+            # which isn't ideally given that it will force people not using FSDP to also use torch>=1.12.0
             # below is the current alternative.
-            if version.parse(torch.__version__) < version.parse("1.12.0.dev20220418+cu113"):
-                raise ValueError("FSDP requires PyTorch >= 1.12.0.dev20220418+cu113")
+            if version.parse(torch.__version__) < version.parse("1.12.0"):
+                raise ValueError("FSDP requires PyTorch >= 1.12.0")
             from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy
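
For readers looking at the hunk above out of context: `version` comes from the `packaging` library that transformers already imports in trainer.py. A minimal standalone sketch of the same runtime gate follows; the `require_torch_for_fsdp` helper name is only for illustration and is not part of this PR:

import torch
from packaging import version


def require_torch_for_fsdp(minimum: str = "1.12.0") -> None:
    # Check the installed torch at runtime instead of pinning torch>=1.12.0 in setup.py,
    # so users who never enable FSDP are not forced to upgrade.
    if version.parse(torch.__version__) < version.parse(minimum):
        raise ValueError(f"FSDP requires PyTorch >= {minimum}")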
@@ -1285,7 +1284,7 @@ class Trainer:
             # PyTorch FSDP!
             from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
             from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
-            from torch.distributed.fsdp.wrap import default_auto_wrap_policy
+            from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

             if FSDPOption.OFFLOAD in self.args.fsdp:
                 cpu_offload = CPUOffload(offload_params=True)
@@ -1296,7 +1295,7 @@ class Trainer:
             if FSDPOption.AUTO_WRAP in self.args.fsdp:
                 if self.args.fsdp_min_num_params > 0:
                     auto_wrap_policy = functools.partial(
-                        default_auto_wrap_policy, min_num_params=self.args.fsdp_min_num_params
+                        size_based_auto_wrap_policy, min_num_params=self.args.fsdp_min_num_params
                     )

             if type(model) != FSDP:
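
The functional core of the change is that PyTorch 1.12 renamed `default_auto_wrap_policy` to `size_based_auto_wrap_policy`. Below is a minimal sketch of the same wrapping pattern outside the Trainer; the `wrap_with_fsdp` helper and the 40M-parameter default threshold are illustrative assumptions, not code from this PR:

import functools

import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy


def wrap_with_fsdp(model: nn.Module, min_num_params: int = 40_000_000, offload: bool = False) -> FSDP:
    # FSDP needs an initialized process group (e.g. a script launched with torchrun).
    assert dist.is_initialized(), "call dist.init_process_group() before wrapping"

    # Bind the size threshold, mirroring how the Trainer maps
    # `--fsdp_min_num_params` onto the policy's `min_num_params` argument.
    auto_wrap_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=min_num_params)

    return FSDP(
        model,
        auto_wrap_policy=auto_wrap_policy,
        cpu_offload=CPUOffload(offload_params=offload),
    )

Run this under a distributed launcher (for example `torchrun --nproc_per_node=2 train.py`) so the process group exists before the model is wrapped.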