Unverified Commit 428110b8 authored by Min Xu, committed by GitHub

[docs] minor doc update (#459)

parent 8f77255b
@@ -17,7 +17,7 @@ FairScale supports:
 * Sharded training:
   * Optimizer state sharding (`fairscale.optim.OSS`)
   * Sharded Data Parallel (SDP) (`fairscale.nn.ShardedDataParallel`)
-  * Fully Sharded Data Parallel (FSDP) (`fairscale.nn.FullyShardedDataParallel`)
+  * Fully Sharded Data Parallel (FSDP) (`fairscale.nn.FullyShardedDataParallel`) (PyTorch >= 1.6)
 * Optimization at scale:
   * AdaScale SGD (`fairscale.optim.AdaScale`)
 * GPU memory optimization:
...
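The README hunk above notes that FSDP requires PyTorch >= 1.6. For orientation only, here is a minimal, illustrative sketch of wrapping a model with `fairscale.nn.FullyShardedDataParallel`; the toy model, rendezvous address, and single-process setup are assumptions for illustration and are not part of this change.

```python
# Illustrative sketch (not from this diff): wrapping a model with FSDP.
# Assumes one CUDA device and a hypothetical local rendezvous address.
import torch
import torch.distributed as dist
import torch.nn as nn
from fairscale.nn import FullyShardedDataParallel as FSDP


def run(rank: int = 0, world_size: int = 1) -> None:
    # FSDP relies on collectives such as reduce_scatter, so use the NCCL backend.
    dist.init_process_group(
        backend="nccl",
        init_method="tcp://127.0.0.1:29500",
        rank=rank,
        world_size=world_size,
    )
    torch.cuda.set_device(rank)

    model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4)).cuda()
    model = FSDP(model)  # parameters, gradients, and optimizer state are sharded across ranks
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    loss = model(torch.randn(8, 16, device="cuda")).sum()
    loss.backward()
    optimizer.step()
    dist.destroy_process_group()


if __name__ == "__main__":
    run()
```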
@@ -25,6 +25,7 @@ from fairscale.utils.testing import (
     get_cycles_per_ms,
     objects_are_equal,
     spawn_for_all_world_sizes,
+    torch_version,
 )

 # How to use remote-pdb: https://gist.github.com/sshleifer/9d43351957179c13606e015b072927d4
@@ -33,10 +34,8 @@ from fairscale.utils.testing import (
 class DistributedTest(unittest.TestCase):
     def setUp(self):
-        major, minor = torch.__version__.split(".")[:2]
-        major, minor = int(major), int(minor)
-        if major < 1 or (major == 1 and minor < 6):
-            raise unittest.SkipTest("Need pytorch version >= 1.6 due to autocast")
+        if torch_version() < (1, 6, 0):
+            raise unittest.SkipTest("Need pytorch version >= 1.6 due to lack of reduce_scatter")
         if not torch.cuda.is_available():
             raise unittest.SkipTest("CUDA not available, skipping test")
         if sys.platform == "win32":
...
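The hunk above replaces the inline version parsing with the `torch_version()` helper imported from `fairscale.utils.testing`. Its implementation is not shown in this diff; the sketch below is a hypothetical stand-in (`torch_version_tuple`) that mirrors what the removed inline code did, normalizing `torch.__version__` into a tuple that can be compared against `(1, 6, 0)`.

```python
# Hypothetical sketch of a torch_version-style helper; the real
# fairscale.utils.testing.torch_version may differ in detail.
from typing import Tuple

import torch


def torch_version_tuple() -> Tuple[int, ...]:
    # Drop any local suffix ("+cu113") and non-numeric tails ("a0") before converting,
    # e.g. "1.8.0a0+56b43f4" -> (1, 8, 0).
    parts = torch.__version__.split("+")[0].split(".")[:3]
    cleaned = []
    for part in parts:
        digits = "".join(ch for ch in part if ch.isdigit())
        cleaned.append(int(digits) if digits else 0)
    return tuple(cleaned)


if torch_version_tuple() < (1, 6, 0):
    print("PyTorch older than 1.6; the reduce_scatter-based tests would be skipped.")
```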
@@ -85,7 +85,7 @@ def _test_func(rank, world_size, model, fsdp_config, tempfile_name, unused, test
 def test_one_iteration(world_size, test_case, fsdp_config):
     """Test FSDP with uneven divide of parameter shards."""
     if torch_version() < (1, 6, 0):
-        pytest.skip("older pytorch doesn't support reduce_scatter in gloo backend")
+        pytest.skip("older pytorch doesn't support reduce_scatter")
     if world_size > torch.cuda.device_count():
         pytest.skip("Not enough GPUs.")
...
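Both skip checks refer to `reduce_scatter`, the collective that reduces a tensor across ranks while leaving each rank with only its own shard of the result, which is how FSDP shards reduced gradients. A single-process sketch of the primitive follows; the NCCL backend and rendezvous address are illustrative assumptions.

```python
# Illustrative single-process use of torch.distributed.reduce_scatter (PyTorch >= 1.6, NCCL).
import torch
import torch.distributed as dist


def demo(rank: int = 0, world_size: int = 1) -> None:
    dist.init_process_group(
        backend="nccl",
        init_method="tcp://127.0.0.1:29501",
        rank=rank,
        world_size=world_size,
    )
    torch.cuda.set_device(rank)

    output = torch.zeros(4, device="cuda")
    # One input chunk per rank; each rank keeps only the reduced chunk assigned to it.
    chunks = [torch.ones(4, device="cuda") for _ in range(world_size)]
    dist.reduce_scatter(output, chunks, op=dist.ReduceOp.SUM)
    print(output)  # with world_size == 1: tensor([1., 1., 1., 1.])
    dist.destroy_process_group()


if __name__ == "__main__":
    demo()
```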