Commit 0753f8b4, authored by Fei Sun and committed by Facebook GitHub Bot
Browse files

Add HSDP

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/463

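Enable HSDP (Hybrid Sharded Data Parallel, i.e. the FSDP hybrid sharding strategies "hybrid" and "hybrid_zero2") when training models.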

Reviewed By: wat3rBro

Differential Revision: D42658128

fbshipit-source-id: 3c37c3b6c4abaa54d677447ee704f2e18c9d3b26
parent c5bf9222
...@@ -66,10 +66,14 @@ class ShardingAlgorithm(str, Enum): ...@@ -66,10 +66,14 @@ class ShardingAlgorithm(str, Enum):
It matches the strings used in D2Go config with the enum class :class:`ShardingStrategy` used by Pytorch FSDP module: It matches the strings used in D2Go config with the enum class :class:`ShardingStrategy` used by Pytorch FSDP module:
"grad_optim" => ShardingAlgorithm.SHARD_GRAD_OP => ShardingStrategy.SHARD_GRAD_OP "grad_optim" => ShardingAlgorithm.SHARD_GRAD_OP => ShardingStrategy.SHARD_GRAD_OP
"full" => ShardingAlgorithm.FULL_SHARD => ShardingStrategy.FULL_SHARD "full" => ShardingAlgorithm.FULL_SHARD => ShardingStrategy.FULL_SHARD
"hybrid" => ShardingAlgorithm.HYBRID_SHARD => ShardingStrategy.HYBRID_SHARD
"hybrid_zero2" => ShardingAlgorithm.HYBRID_SHARD_ZERO2 => ShardingStrategy._HYBRID_SHARD_ZERO2
""" """
SHARD_GRAD_OP = "grad_optim" SHARD_GRAD_OP = "grad_optim"
FULL_SHARD = "full" FULL_SHARD = "full"
HYBRID_SHARD = "hybrid"
HYBRID_SHARD_ZERO2 = "hybrid_zero2"
def is_fsdp_enabled(cfg): def is_fsdp_enabled(cfg):
...@@ -161,9 +165,20 @@ def build_fsdp( ...@@ -161,9 +165,20 @@ def build_fsdp(
elif sharding_algorithm == ShardingAlgorithm.FULL_SHARD: elif sharding_algorithm == ShardingAlgorithm.FULL_SHARD:
sharding_strategy = ShardingStrategy.FULL_SHARD sharding_strategy = ShardingStrategy.FULL_SHARD
logger.info("Optimizer + Gradient + Horizontal Model Sharding (ZeRO-3) is used") logger.info("Optimizer + Gradient + Horizontal Model Sharding (ZeRO-3) is used")
elif sharding_algorithm == ShardingAlgorithm.HYBRID_SHARD:
sharding_strategy = ShardingStrategy.HYBRID_SHARD
logger.info(
"Optimizer + Gradient + Horizontal Model Sharding (ZeRO-3) within a node is used"
)
elif sharding_algorithm == ShardingAlgorithm.HYBRID_SHARD_ZERO2:
sharding_strategy = ShardingStrategy._HYBRID_SHARD_ZERO2
logger.info(
"Optimizer + Gradient State Sharding (ZeRO-2) within a node is used"
)
else: else:
raise ValueError( raise ValueError(
f"Invalid sharding algorithm for FSDP. Can be either {ShardingAlgorithm.SHARD_GRAD_OP} or {ShardingAlgorithm.FULL_SHARD}." f"Invalid sharding algorithm for FSDP. Can be {ShardingAlgorithm.SHARD_GRAD_OP}, "
+ f"{ShardingAlgorithm.FULL_SHARD}, {ShardingAlgorithm.HYBRID_SHARD}, or {ShardingAlgorithm.HYBRID_SHARD_ZERO2}."
) )
auto_wrap_policy = ( auto_wrap_policy = (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment