Unverified Commit fae29959 authored by Min Xu, committed by GitHub

[chore] [cleanup]: pytest, pytorch new versions, fix tests (#933)



* update pytest versions

* [test] test related changes

- upgrade to newer pytorch versions
- added a function to make tests more deterministic on A100 GPUs with TF32
- fixed some tests so that they are correctly skipped on a single-GPU system (see the sketch below)
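
For reference, the skip-on-single-GPU pattern used throughout the test changes below can be sketched roughly as follows; the real skip_if_single_gpu helper lives in fairscale.utils.testing and may differ in detail:

# Rough sketch (assumption, not the exact fairscale implementation): skip a test
# unless at least two CUDA devices are visible.
import pytest
import torch

skip_if_single_gpu = pytest.mark.skipif(
    not torch.cuda.is_available() or torch.cuda.device_count() < 2,
    reason="multiple GPUs required",
)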

* more fixes

* formatting overly long lines

* format

* better test without triggering a warning

* fix an optim state bug with newer pytorch

- the adam optimizer seems to return "step" as a singleton tensor now in the
  nightly build
- this fixes it, assuming a non-tensor value can still be loaded back by
  the optimizer (see the sketch below)
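
A minimal sketch of the idea (helper and variable names here are illustrative; the actual change is in fully_sharded_data_parallel.py further down in this diff):

# Normalize Adam's "step" state so it is always a plain int, whether the optimizer
# stored it as an int or as a 0-dim (singleton) tensor, as newer nightlies do.
import torch

def normalize_step(param_state: dict) -> None:
    step = param_state.get("step")
    if torch.is_tensor(step) and step.numel() == 1:
        param_state["step"] = int(step.item())

param_state = {"step": torch.tensor(10.0)}
normalize_step(param_state)
assert param_state["step"] == 10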

* improve oss.py

- using min_loss for regression checking is a bit more reliable (see the sketch below)
- also increased the number of epochs from 10 to 12
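
A minimal sketch of the min-loss check (numbers are hypothetical; the real check lives in benchmarks/oss.py, shown later in this diff):

# Track the best (minimum) loss seen during training and compare that against the
# golden reference value, instead of the noisier final-epoch loss.
min_loss = float("inf")
for epoch_loss in [2.1, 1.4, 1.6, 1.3, 1.5]:  # per-epoch losses, for illustration only
    min_loss = min(min_loss, epoch_loss)

reference_loss = 1.35  # hypothetical golden value
epsilon = 1e-2
# Any min_loss below reference_loss + epsilon passes; only upward regressions fail.
assert min_loss - reference_loss < epsilon, f"Loss regression: {min_loss} vs. {reference_loss}"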

* small oss.py fix

* Update fairscale/nn/data_parallel/fully_sharded_data_parallel.py
Co-authored-by: Min Xu <min.xu.public@gmail.com>
parent 8527c587
@@ -79,14 +79,14 @@ setup_venv: &setup_venv
       pip install --upgrade pip

 # most recent LTS version
-install_dep_1_8_1: &install_dep_1_8_1
+install_dep_1_8_2: &install_dep_1_8_2
   - run:
-      name: Install Dependencies with torch 1.8.1 (LTS)
+      name: Install Dependencies with torch 1.8.2 (LTS)
       command: |
         # check if we have restored venv cache (/home/circleci/venv) correctly, if so, just skip
         if [ -f /home/circleci/venv/check_version.py ]; then python /home/circleci/venv/check_version.py torch eq 1.8 && exit 0; fi
         # start installing
-        pip install --progress-bar off torch==1.8.1+cu102 torchvision==0.9.1+cu102 -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
+        pip install --progress-bar off torch==1.8.2+cu102 torchvision==0.9.2+cu102 -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
         pip install --progress-bar off -r requirements-dev.txt
         pip install --progress-bar off -r requirements-benchmarks.txt
         python -c 'import torch; print("Torch version:", torch.__version__)'
@@ -95,14 +95,14 @@ install_dep_1_8_1: &install_dep_1_8_1
         wget -O /home/circleci/venv/check_version.py https://raw.githubusercontent.com/min-xu-ai/check_verion/main/check_version.py

 # most recent stable version
-install_dep_1_10_0: &install_dep_1_10_0
+install_dep_1_10_2: &install_dep_1_10_2
   - run:
-      name: Install Dependencies with torch 1.10.0
+      name: Install Dependencies with torch 1.10.2
       command: |
         # check if we have restored venv cache (/home/circleci/venv) correctly, if so, just skip
         if [ -f /home/circleci/venv/check_version.py ]; then python /home/circleci/venv/check_version.py torch eq 1.10 && exit 0; fi
         # start installing
-        pip install --progress-bar off torch==1.10.0+cu111 torchvision==0.11.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html
+        pip install --progress-bar off torch==1.10.2+cu113 torchvision==0.11.3+cu113 -f https://download.pytorch.org/whl/torch_stable.html
         pip install --progress-bar off -r requirements-dev.txt
         pip install --progress-bar off -r requirements-benchmarks.txt
         python -c 'import torch; print("Torch version:", torch.__version__)'
@@ -115,13 +115,13 @@ install_dep_pytorch_nightly: &install_dep_pytorch_nightly
       name: Install Dependencies with a torch nightly preview build
       command: |
         # check if we have restored venv cache (/home/circleci/venv) correctly, if so, just skip
-        if [ -f /home/circleci/venv/check_version.py ]; then python /home/circleci/venv/check_version.py torch eq 1.10 && exit 0; fi
+        if [ -f /home/circleci/venv/check_version.py ]; then python /home/circleci/venv/check_version.py torch eq 1.12 && exit 0; fi
         # start installing
-        pip install --progress-bar off --pre torch==1.11.0.dev20211231+cu111 torchvision==0.12.0.dev20211231+cu111 -f https://download.pytorch.org/whl/nightly/cu111/torch_nightly.html
+        pip install --progress-bar off --pre torch==1.12.0.dev20220210+cu113 torchvision==0.13.0.dev20220210+cu113 -f https://download.pytorch.org/whl/nightly/cu113/torch_nightly.html
         pip install --progress-bar off -r requirements-dev.txt
         pip install --progress-bar off -r requirements-benchmarks.txt
         python -c 'import torch; print("Torch version:", torch.__version__)'
-        python -c 'import torch; assert torch.__version__.split(".")[:2] == ["1", "11"], "wrong torch version"'
+        python -c 'import torch; assert torch.__version__.split(".")[:2] == ["1", "12"], "wrong torch version"'
         python -m torch.utils.collect_env
         wget -O /home/circleci/venv/check_version.py https://raw.githubusercontent.com/min-xu-ai/check_verion/main/check_version.py
@@ -161,7 +161,7 @@ run_oss_benchmark: &run_oss_benchmark
       name: Run OSS Benchmark
       command: |
         python benchmarks/oss.py --world_size 4 --epochs 2
-        python benchmarks/oss.py --check_regression --world_size 4 --optim_type oss_sharded_ddp
+        python benchmarks/oss.py --check_regression --world_size 4 --optim_type oss_sharded_ddp --epochs 12

 run_oss_gloo: &run_oss_gloo
   - run:
@@ -249,7 +249,7 @@ jobs:
           keys:
             - cache-key-cpu-py37-torch-1-10-0-{{.Environment.CACHE_VERSION }}-{{checksum "setup.py"}}-{{checksum "requirements-dev.txt"}}

-      - <<: *install_dep_1_10_0
+      - <<: *install_dep_1_10_2

       - save_cache:
           paths:
@@ -277,7 +277,7 @@ jobs:
       - restore_cache:
           keys:
             - cache-key-cpu-py38-torch-1-10-0-{{.Environment.CACHE_VERSION }}-{{checksum "setup.py"}}-{{checksum "requirements-dev.txt"}}

-      - <<: *install_dep_1_10_0
+      - <<: *install_dep_1_10_2

       - save_cache:
           paths:
@@ -306,7 +306,7 @@ jobs:
           keys:
             - cache-key-cpu-py39-torch-1-10-0-{{.Environment.CACHE_VERSION }}-{{checksum "setup.py"}}-{{checksum "requirements-dev.txt"}}

-      - <<: *install_dep_1_10_0
+      - <<: *install_dep_1_10_2

       - save_cache:
           paths:
@@ -346,7 +346,7 @@ jobs:
           keys:
             - cache-key-py-3-9-7-gpu-torch-1-8-1-cuda-11-2-{{.Environment.CACHE_VERSION }}-{{checksum "setup.py"}}-{{checksum "requirements-dev.txt"}}

-      - <<: *install_dep_1_8_1
+      - <<: *install_dep_1_8_2

       - save_cache:
           paths:
@@ -389,7 +389,7 @@ jobs:
           keys:
             - cache-key-py-3-9-7-gpu-torch-1-10-0-cuda-11-2-{{.Environment.CACHE_VERSION }}-{{checksum "setup.py"}}-{{checksum "requirements-dev.txt"}}

-      - <<: *install_dep_1_10_0
+      - <<: *install_dep_1_10_2

       - save_cache:
           paths:
@@ -470,7 +470,7 @@ jobs:
           keys:
             - cache-key-benchmark-MNIST-{{.Environment.CACHE_VERSION }}-{{checksum "benchmarks/datasets/mnist.py"}}

-      - <<: *install_dep_1_10_0
+      - <<: *install_dep_1_10_2

       - save_cache:
           paths:
@@ -520,7 +520,7 @@ jobs:
           keys:
             - cache-key-benchmark-MNIST-{{.Environment.CACHE_VERSION }}-{{checksum "benchmarks/datasets/mnist.py"}}

-      - <<: *install_dep_1_10_0
+      - <<: *install_dep_1_10_2

       - save_cache:
           paths:
...
@@ -109,7 +109,8 @@ def validate_benchmark(measurements, final_loss, args, check_regression):
     assert max_memory < 1.05 * golden_data["reference_memory"], (
         f"Memory use regression detected: " f"{max_memory} vs. {1.05* golden_data['reference_memory']}"
     )
-    assert abs(cast(float, final_loss) - golden_data["reference_loss"]) < 1e-2, (
+    # any min_loss < golden + epsilon is OK.
+    assert cast(float, final_loss) - golden_data["reference_loss"] < 1e-2, (
         f"Loss regression detected: " f"{final_loss} vs. {golden_data['reference_loss']}"
     )
     logging.info("[Regression Test] VALID")
@@ -176,6 +177,7 @@ def train(
     measurements = []
     final_loss: Optional[float] = -1.0
+    min_loss = 100.0
     need_profiling = args.profile

     for epoch in range(args.epochs):
@@ -264,14 +266,21 @@ def train(
                 logging.info("... State dict collected")

         measurements.append(n_items / epoch_runtime)
+        min_loss = min(min_loss, final_loss)
         if dist.get_rank() == 0:
-            logging.info(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec. Loss {final_loss:.3f}")
+            logging.info(
+                f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec. "
+                f"Loss {final_loss:.3f} min loss {min_loss:.3f}"
+            )

     training_stop = time.monotonic()
     img_per_sec = n_items / (training_stop - training_start) * args.epochs
     logging.info(f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec inc. checkpoint")

-    validate_benchmark(measurements, final_loss, args, check_regression)
+    # Use min_loss to check instead of final_loss since the final_loss is a bit random.
+    # If the training min_loss reaches a certain number, we can be reasonably certain that
+    # the training process was correct.
+    validate_benchmark(measurements, min_loss, args, check_regression)

     dist.destroy_process_group()  # type: ignore
...
@@ -119,7 +119,7 @@ def _unflatten_optim_state(
     if not combined_state:
         return {}, global_to_local_id

-    # copy non tensor state to all global entries
+    # copy non tensor state (like the "step" count) to all global entries
     unflat_state = {i: copy.deepcopy(non_tensor_state[0]) for i in range(sum(num_global_params))}

     if non_tensor_state[0].keys() == combined_state[0].keys():
...
@@ -367,13 +367,16 @@ class FullyShardedDataParallel(nn.Module):
         # In a unit test dummy enviromnent, the process_group_reduce_scatter can be None.
         if self.process_group_reduce_scatter is not None:
             reduce_scatter_group_size = self.process_group_reduce_scatter.size()
-            # Roll back to use the default process group for reduce scatter operation when the world size and reduce scatter process group size are differnt.
+            # Roll back to the default process group for the reduce scatter operation when
+            # the world size and the reduce scatter process group size are different.
             if self.world_size != reduce_scatter_group_size:
                 self.process_group_reduce_scatter = self.process_group
                 logging.warn(
-                    "Rolled back to use the default process group for the reduce scatter operation because the reduce_scatter process group"
-                    f"size is {reduce_scatter_group_size}, which is different with the world size {self.world_size}. Please make sure the process_group"
-                    "parameter uses all the available ranks for the optimized performance."
+                    "Rolled back to use the default process group for the reduce scatter "
+                    "operation because the reduce_scatter process group "
+                    f"size is {reduce_scatter_group_size}, which is different from the "
+                    f"world size {self.world_size}. Please make sure the process_group "
+                    "parameter uses all the available ranks for optimal performance."
                 )
         self.reshard_after_forward = self._orig_reshard_after_forward = reshard_after_forward
         self.disable_reshard_on_root = disable_reshard_on_root
@@ -2309,6 +2312,20 @@ class FullyShardedDataParallel(nn.Module):
         return [m for m in self.modules() if isinstance(m, FullyShardedDataParallel)]

     def _remove_uncollectable_params_from_optim_state_dict(self, osd: Dict) -> Dict:
+        """Return a new state dict filtering out the entries, like MoE layers, which have
+        the ``no_broadcast_optim_state`` flag set.
+
+        We also make room for the optimizer state on rank 0.
+        """
+        # In PyTorch version 1.12, Adam's `step` state changed from an int to a singleton
+        # tensor. We convert it back here. Otherwise, the step counter would be treated
+        # like a singleton tensor and comparison with the original state dict would fail.
+        for _, bufs in osd["state"].items():
+            if "step" in bufs.keys():
+                assert type(bufs["step"]) is int or ou.is_singleton_tensor(bufs["step"])
+                if ou.is_singleton_tensor(bufs["step"]):
+                    bufs["step"] = bufs["step"].item()
+        # Get uncollected_ids.
         uncollected_ids = [i for i, m in enumerate(self._fsdp_instances) if m.no_broadcast_optim_state]
         new_dct = {"state": {k: v for k, v in osd["state"].items() if k not in uncollected_ids}}
         if self.rank == 0:
...
@@ -136,6 +136,15 @@ def torch_cuda_version(compiled: bool = False) -> Tuple[int, ...]:
     return tuple(int(n) for n in numbering)


+def make_cudnn_deterministic() -> None:
+    """Make cudnn (matmul) deterministic."""
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+    # TF32 also makes things nondeterministic. Disable it.
+    torch.backends.cuda.matmul.allow_tf32 = False  # type: ignore
+    torch.backends.cudnn.allow_tf32 = False  # type: ignore
+
+
 def dist_init(rank: int, world_size: int, filename: str, filename_rpc: str = "") -> bool:
     """
     Initialize torch distributed, based on a temporary file shared across ranks, which makes it possible for unrelated
@@ -218,8 +227,7 @@ def test_runner(
 ) -> None:
     # At this point we're in a new process, torch options need to be set again
     if deterministic:
-        torch.backends.cudnn.deterministic = True
-        torch.backends.cudnn.benchmark = False
+        make_cudnn_deterministic()
         torch.manual_seed(1357)

     test_func(rank, *args, **kwargs)
@@ -270,8 +278,7 @@ def worker_process(
     )

     if torch.cuda.is_available() and not hasattr(torch.backends.cudnn, "flags"):
-        torch.backends.cudnn.benchmark = False
-        torch.backends.cudnn.deterministic = True
+        make_cudnn_deterministic()

     try:
         with context:
...
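
As a side note on why make_cudnn_deterministic() above also turns off TF32: on Ampere GPUs such as the A100, TF32 matmuls can produce slightly different results than full FP32, which breaks exact-match comparisons in tests. A minimal sketch (not part of the diff) illustrating the effect:

# Compare the same matmul with TF32 enabled vs. disabled. On an A100 the results
# typically differ slightly; on pre-Ampere GPUs they match, since TF32 is unused there.
import torch

def matmul_with_tf32(enabled: bool) -> torch.Tensor:
    # These are the same flags make_cudnn_deterministic() sets to False.
    torch.backends.cuda.matmul.allow_tf32 = enabled
    torch.backends.cudnn.allow_tf32 = enabled
    torch.manual_seed(0)
    a = torch.randn(1024, 1024, device="cuda")
    b = torch.randn(1024, 1024, device="cuda")
    return a @ b

if torch.cuda.is_available():
    max_diff = (matmul_with_tf32(True) - matmul_with_tf32(False)).abs().max()
    print(f"max abs difference between TF32 and FP32 matmul: {max_diff}")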
@@ -14,9 +14,9 @@ mypy == 0.910
 pre-commit >= 2.15.0

 # Tools for unit tests & coverage.
-pytest == 5.4.1
-pytest-cov == 2.10.0
-pytest-timeout == 1.4.2
+pytest == 7.0.0
+pytest-cov == 3.0.0
+pytest-timeout == 2.1.0
 remote-pdb >= 2.1.0
 parameterized >= 0.8.1
...
@@ -22,6 +22,7 @@ import torch.nn as nn

 from fairscale.experimental.nn.distributed_pipeline import DistributedLoss, DistributedPipeline, PipelineModulesGraph
 from fairscale.utils import torch_version
+from fairscale.utils.testing import skip_if_single_gpu

 pytestmark = pytest.mark.skipif(
     not torch.cuda.is_available() or torch_version() < (1, 9, 0),
@@ -103,6 +104,7 @@ def create(devices):

 @rpc_test()
+@skip_if_single_gpu
 def create_multiple_layers():
     model = [RemoteModuleParams(nn.Linear, (4, 4), {}), RemoteModuleParams(nn.ReLU, (), {})]
     pipe = create_sequence_pipeline(model, balance=[1, 1], chunks=1, devices=["worker0/cpu", "worker0/cpu"])
@@ -110,6 +112,7 @@ def create_multiple_layers():

 @rpc_test(world_size=2)
 @pytest.mark.parametrize("devices", DEVICES)
+@skip_if_single_gpu
 def create_multiple_workers(devices):
     model = [RemoteModuleParams(nn.Linear, (4, 4), {}), RemoteModuleParams(nn.ReLU, (), {})]
     pipe = create_sequence_pipeline(model, balance=[1, 1], chunks=1, devices=devices[:2])
@@ -117,6 +120,7 @@ def create_multiple_workers(devices):

 @rpc_test(world_size=2)
 @pytest.mark.parametrize("devices", DEVICES)
+@skip_if_single_gpu
 def parameter_rrefs(devices):
     model = [RemoteModuleParams(nn.Linear, (4, 4), {}), RemoteModuleParams(nn.ReLU, (), {})]
     pipe = create_sequence_pipeline(model, balance=[1, 1], chunks=1, devices=devices[:2])
@@ -149,6 +153,7 @@ def forward_chunks(devices):

 @rpc_test(world_size=2)
 @pytest.mark.parametrize("devices", DEVICES)
 @pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
+@skip_if_single_gpu
 def forward_multi(devices, checkpoint):
     device = devices[0].split("/")[1]
     torch.random.manual_seed(3)
@@ -166,6 +171,7 @@ def forward_multi(devices, checkpoint):

 @rpc_test(world_size=2)
 @pytest.mark.parametrize("devices", DEVICES)
+@skip_if_single_gpu
 def backward(devices):
     device = devices[0].split("/")[1]
     torch.random.manual_seed(3)
@@ -183,6 +189,7 @@ def backward(devices):

 @rpc_test(world_size=2)
 @pytest.mark.parametrize("devices", DEVICES)
+@skip_if_single_gpu
 def update(devices):
     device = devices[0].split("/")[1]
     torch.random.manual_seed(3)
@@ -223,6 +230,7 @@ def extract_partitions(graph: PipelineModulesGraph, pipeline: DistributedPipelin

 @rpc_test(world_size=2)
 @pytest.mark.parametrize("devices", DEVICES)
+@skip_if_single_gpu
 def multi_input_multi_output_layers(devices):
     device = devices[0].split("/")[1]
     torch.random.manual_seed(3)
@@ -289,6 +297,7 @@ class ShardedLinearLayer(nn.Module):

 @rpc_test(world_size=2)
 @pytest.mark.parametrize("devices", DEVICES)
+@skip_if_single_gpu
 def auto_graph_extract(devices):
     from fairscale.experimental.nn.distributed_pipeline.trace import make_graph
...
@@ -801,8 +801,11 @@ class MixtureOfExperts(NestedWrappedModule):
         if wrapper_config is not None:
             # we create a process group of size 1 for the expert params
-            expert_group = torch.distributed.new_group([group.rank()])  # world size 1 means no shard
-            expert = FullyShardedDataParallel(expert, expert_group, **wrapper_config)
+            # we also need to pass that group as the reduce_scatter group.
+            expert_group = torch.distributed.new_group([group.rank()])
+            expert = FullyShardedDataParallel(
+                expert, process_group=expert_group, process_group_reduce_scatter=expert_group, **wrapper_config
+            )

             shared = FullyShardedDataParallel(shared, group, **wrapper_config)
...
@@ -13,7 +13,7 @@ from parameterized import parameterized
 import torch

 from fairscale.nn.data_parallel import FullyShardedDataParallel
-from fairscale.utils.testing import DummyProcessGroup, objects_are_equal
+from fairscale.utils.testing import DummyProcessGroup, make_cudnn_deterministic, objects_are_equal

 from .test_fsdp import DistributedTest, NestedWrappedModule, rename_test, spawn_and_init
@@ -64,6 +64,7 @@ class TestGradAcc(DistributedTest):

     @classmethod
     def _test_grad_acc(self, model, batch_dim, use_no_sync_context=True):
+        make_cudnn_deterministic()
         # Generate two input batches. We'll test that we get the same grads if
         # we train on them sequentially while accumulating grads (with no_sync
         # or without no_sync) vs. concatenating the batches and training in one go.
...
@@ -66,7 +66,7 @@ def _test_func(rank, world_size, fsdp_config, tempfile_name, unused):

 # We use strings for precision and flatten instead of bool to
 # make the pytest output more readable.
 @skip_if_no_cuda
-@pytest.mark.parametrize("world_size", [1, 2])
+@pytest.mark.parametrize("world_size", [1, 2] if torch.cuda.device_count() > 1 else [1])
 @pytest.mark.parametrize("precision", ["full", "mixed"])
 @pytest.mark.parametrize("flatten", ["flatten", "no_flatten"])
 def test(world_size, precision, flatten):
...
@@ -24,11 +24,13 @@ from .test_fsdp import (
 )


-def first_tensor_numel(dct):
+def all_tensors_numel_except_for_step(dct):
+    """Compute the sum of numel over all tensors in a dict, except when the key is `step`."""
+    ret = 0
     for k, v in dct.items():
-        if torch.is_tensor(v):
-            return v.numel()
-    return 0
+        if k != "step" and torch.is_tensor(v):
+            ret += v.numel()
+    return ret


 def assert_equal(a, b):
@@ -123,8 +125,8 @@ class TestOptimizerUtils(DistributedTest):
         assert_equal(len(sd["state"]), len(unwrapped_sd["state"]))
         assert_equal(len(sd["param_groups"][0]["params"]), len(unwrapped_sd["param_groups"][0]["params"]))
         assert_equal(
-            sum([first_tensor_numel(v) for k, v in sd["state"].items()]),
-            sum([first_tensor_numel(v) for k, v in unwrapped_sd["state"].items()]),
+            sum([all_tensors_numel_except_for_step(v) for k, v in sd["state"].items()]),
+            sum([all_tensors_numel_except_for_step(v) for k, v in unwrapped_sd["state"].items()]),
         )

         original_shard_sd = fsdp_optim.state_dict()
@@ -133,8 +135,8 @@ class TestOptimizerUtils(DistributedTest):
         original_shard_sd = recursive_copy_to_device(original_shard_sd, non_blocking=False, device="cpu")

         # Before asserting that the dicts are equal, we check keys individually to allow nice tracebacks.
         assert_equal(
-            [first_tensor_numel(v) for k, v in shard_sd["state"].items()],
-            [first_tensor_numel(v) for k, v in original_shard_sd["state"].items()],
+            [all_tensors_numel_except_for_step(v) for k, v in shard_sd["state"].items()],
+            [all_tensors_numel_except_for_step(v) for k, v in original_shard_sd["state"].items()],
         )
         assert_equal(
             [v for k, v in shard_sd["param_groups"][0].items()],
...
@@ -207,6 +207,7 @@ def run_test_two_inputs(rank, world_size, backend, device, temp_file_name, reduc

 @pytest.mark.parametrize("reduce_buffer_size", [0, 2 ** 20])
 @pytest.mark.parametrize("backend", ["gloo", "nccl"])
 @pytest.mark.parametrize("device", available_devices)
+@skip_if_single_gpu
 def test_inputs(reduce_buffer_size, backend, device):
     # Check that the ShardedDDP wrapper accepts tuple(tensors) as inputs
     world_size = 2
...
-# coding=utf-8
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 #
 # This source code is licensed under the BSD license found in the
@@ -136,7 +134,7 @@ def run_test_initialize_affine_weight(rank, model_parallel_size, filename, filen
                 torch.distributed.get_rank(), error
             )
         )
-    assert error < 1.0e-6
+    assert error < 1.0e-6, error

     # ------------
     # Row parallel
@@ -157,7 +155,7 @@ def run_test_initialize_affine_weight(rank, model_parallel_size, filename, filen
         print(
             " row parallel max error (should be zero) on global rank {}: {}".format(torch.distributed.get_rank(), error)
         )
-    assert error < 1.0e-6
+    assert error < 1.0e-6, error

     # Reset groups
     mpu.destroy_model_parallel()
@@ -217,18 +215,18 @@ def run_test_column_parallel_linear(rank, model_parallel_size, filename, filenam
     error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
     torch.distributed.barrier()
     print(" error in dLdA on global rank {}: {}".format(torch.distributed.get_rank(), error))
-    assert error < 1.0e-6
+    assert error < 1.0e-6, error

     my_dLdb = torch.split(dLdb, output_size_coeff, dim=0)[rank].contiguous().clone()
     error = my_dLdb.sub(linear_layer.bias.grad).abs().max()
     torch.distributed.barrier()
     print(" error in dLdb on global rank {}: {}".format(torch.distributed.get_rank(), error))
-    assert error < 1.0e-6
+    assert error < 1.0e-6, error

     error = dLdX.sub(identity_layer.weight.grad).abs().max()
     torch.distributed.barrier()
     print(" error in dLdX on global rank {}: {}".format(torch.distributed.get_rank(), error))
-    assert error < 1.0e-6
+    assert error < 1.0e-6, error

     # Reset groups
     mpu.destroy_model_parallel()
@@ -278,17 +276,17 @@ def run_test_row_parallel_linear(rank, model_parallel_size, filename, filename_r
     error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
     torch.distributed.barrier()
     print(" error in dLdA on global rank {}: {}".format(torch.distributed.get_rank(), error))
-    assert error < 1.0e-6
+    assert error < 1.0e-6, error

     error = dLdb.sub(linear_layer.bias.grad).abs().max()
     torch.distributed.barrier()
     print(" error in dLdb on global rank {}: {}".format(torch.distributed.get_rank(), error))
-    assert error < 1.0e-6
+    assert error < 1.0e-6, error

     error = dLdX.sub(identity_layer.weight.grad).abs().max()
     torch.distributed.barrier()
     print(" error in dLdX on global rank {}: {}".format(torch.distributed.get_rank(), error))
-    assert error < 1.0e-6
+    assert error < 1.0e-6, error

     # Reset groups
     mpu.destroy_model_parallel()
...
@@ -13,6 +13,7 @@ import torch.multiprocessing as mp

 from fairscale.nn import MOELayer, Top2Gate
 from fairscale.utils import torch_version
+from fairscale.utils.testing import make_cudnn_deterministic

 pytestmark = pytest.mark.skipif(
     not (torch.cuda.is_available() and torch_version() >= (1, 8, 0)), reason="cuda and torch>=1.8.0 required"
@@ -68,6 +69,7 @@ def expert_params(device):

 @pg_test()
 @pytest.mark.parametrize("device", devices)
 def forward(device):
+    make_cudnn_deterministic()
     model_dim = 8
     num_experts = dist.get_world_size(dist.group.WORLD)
     input = torch.randn(4, 16, model_dim).to(device)
@@ -85,6 +87,7 @@ def forward(device):

 @pg_test()
 @pytest.mark.parametrize("device", devices)
 def forward_multi(device):
+    make_cudnn_deterministic()
     torch.set_printoptions(threshold=5000)
     num_local_experts = 4
     model_dim = 4
@@ -128,6 +131,7 @@ class RoundRobinGate(torch.nn.Module):

 @pg_test()
 @pytest.mark.parametrize("device", devices)
 def forward_routing(device):
+    make_cudnn_deterministic()
     model_dim = 8
     num_experts = dist.get_world_size()
     input = torch.randn(4, 16, model_dim).to(device)
@@ -149,6 +153,7 @@ def forward_routing(device):

 @pg_test()
 @pytest.mark.parametrize("device", devices)
 def forward_routing_multi(device):
+    make_cudnn_deterministic()
     model_dim = 8
     num_local_experts = 4
     num_experts = dist.get_world_size(dist.group.WORLD) * num_local_experts
@@ -174,6 +179,7 @@ def forward_routing_multi(device):

 @pg_test()
 @pytest.mark.parametrize("device", devices)
 def backward(device):
+    make_cudnn_deterministic()
     loss = torch.nn.MSELoss()
     model_dim = 8
     num_experts = dist.get_world_size(dist.group.WORLD)
...
@@ -27,7 +27,7 @@ def do_test_forward(device):
     gate = Top2Gate(4, 6).to(device)
     capacity = 2 * 12 // 6
     l_aux, combine_weights, dispatch_mask = gate(input)
-    assert pytest.approx(l_aux.item(), 0.0283)
+    assert pytest.approx(l_aux.item(), rel=0.01) == 0.0267, l_aux
     assert combine_weights.shape == (12, 6, 4)
     assert dispatch_mask.shape == (12, 6, 4)
     assert torch.equal(combine_weights.bool(), dispatch_mask)
@@ -35,9 +35,9 @@ def do_test_forward(device):
     assert torch.all(combine_weights >= 0.0)
     assert torch.all(combine_weights <= 1.0)
     weights_sum = torch.sum(combine_weights).item()
-    assert round(weights_sum) == pytest.approx(weights_sum)
+    assert round(weights_sum) == pytest.approx(weights_sum), weights_sum
     # For this random seed, we get 12 slots filled.
-    assert weights_sum == pytest.approx(12.0)
+    assert weights_sum == pytest.approx(12.0), weights_sum


 def test_forward_cpu():
...
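
Worth noting about the top2gating change above: the old form, assert pytest.approx(l_aux.item(), 0.0283), never actually compared anything, because pytest.approx() only builds an approximate-comparison object, which is truthy on its own; the assertion only becomes meaningful once that object is compared with ==, as the new line does. A small illustration with made-up numbers:

# The approx object by itself is truthy, so this assertion passes for any value.
import pytest

value = 123.456
assert pytest.approx(value, 0.0283)  # always passes; no comparison happens
# The meaningful form compares the measured value against an expected one.
assert value == pytest.approx(123.46, rel=0.01)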
@@ -70,6 +70,7 @@ def load_data(model_type: str) -> Union[DataLoader, Tuple[Any, Any]]:
         torch.manual_seed(10)
         transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
+        # TODO: we should NOT do this download over and over again during tests.
         train_ds = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
         train_ds_loader = torch.utils.data.DataLoader(train_ds, batch_size=128, shuffle=False, num_workers=2)
@@ -204,9 +205,11 @@ def train_vision_model(model: SimpleConvNet, per_layer_scaling=False):

 @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
 def test_vision_model() -> None:
-    # The os.environ below doesn't seem to be enough if the test is run on CI with many other tests together.
+    # The os.environ below doesn't seem to be enough if the test is run on CI with many other tests
+    # together.
     # see: https://app.circleci.com/pipelines/github/facebookresearch/fairscale/4086/workflows/72b1470a-55f8-4a45-afe5-04641b093bef/jobs/45179/tests#failed-test-0
     # Skipping for now.
+    # Also, TODO (Min): improve the downloading code above before re-enabling this.
     skip_a_test_if_in_CI()
     # Remove randomness from various sources while testing.
     torch.use_deterministic_algorithms(True)  # type: ignore
...
@@ -819,7 +819,7 @@ def run_state_dict_distributed(rank, world_size, tempfile_name):
     dist.destroy_process_group()


-@skip_if_no_cuda
+@skip_if_single_gpu
 def test_state_dict_distributed():
     world_size = 2
     temp_file_name = tempfile.mkstemp()[1]
...
@@ -21,7 +21,7 @@ from torch.optim.lr_scheduler import LambdaLR
 from fairscale.optim import AdaScale
 from fairscale.utils.golden_testing_data import adascale_test_data
-from fairscale.utils.testing import skip_if_no_cuda
+from fairscale.utils.testing import make_cudnn_deterministic, skip_if_no_cuda
 from fairscale.utils.testing_memory import find_tensor_by_shape
@@ -63,6 +63,7 @@ def test_loss_accum_cpu():

 @pytest.mark.parametrize("test_case", adascale_test_data)
 def test_grad_accum(test_case, cpu):
     """Test the basic functionality on CPU/GPU with gradient accumulation without DDP"""
+    make_cudnn_deterministic()
     model = Linear(2, 2, bias=True)
     if not cpu:
         if torch.cuda.device_count() < 1:
...