Unverified commit 5eb6b8c7 authored by Min Xu, committed by GitHub

[chores]: make CI more efficient and update py39 env a bit (#447)

* [chores]: CI py39 on GPU and more efficiency

* add test list files

* fix

* add test list files

* split benchmark run into 2 runs

* fix 1.8 version and balance benchmarks

* fix

* fix

* fix

* fix

* recording tests

* py39 install fix

* test again

* move tests

* reorg tests

* skip tests for torch 1.8 due to an upstream bug

* removed __init__.py from tests since it confuses pytest

* Revert "removed __init__.py from tests since it confuses pytest"

This reverts commit 7e156ba33dfaa5ed052031780613ec0cb57a45b0.

* don't include __init__ in file list

* notes on __init__.py and added missing ones

* fixed mypy in a test file

* balance test runtime

* better pip install

* balance more

* pip fix

* balance

* balance more, all tests should finish within 20m now

* minor license update

* trying cu102

* more doc and addressed Ben's comments

* debugging

* debugging...
parent 5ecac15a
@@ -4,6 +4,8 @@
#
# Adopted from
# https://github.com/facebookresearch/detectron2/blob/master/.circleci/config.yml
#
# Pro tip: download the CircleCI CLI to validate the config locally during development.
version: 2.1
@@ -25,6 +27,8 @@ cpu_py39: &cpu_py39
- image: circleci/python:3.9
resource_class: medium
# Here is the list of GPU images:
# https://circleci.com/docs/2.0/configuration-reference/#available-linux-gpu-images
gpu: &gpu
environment:
CUDA_VERSION: "10.1"
@@ -32,6 +36,13 @@ gpu: &gpu
image: ubuntu-1604-cuda-10.1:201909-23
resource_class: gpu.large
gpu_cu111: &gpu_cu111
environment:
CUDA_VERSION: "11.1"
machine:
image: ubuntu-1604-cuda-11.1:202012-01
resource_class: gpu.large
# -------------------------------------------------------------------------------------
# Re-usable commands
# -------------------------------------------------------------------------------------
@@ -86,18 +97,41 @@ install_dep_171: &install_dep_171
python -c 'import torch; assert torch.__version__.split(".")[:2] == ["1", "7"], "wrong torch version"'
python -m torch.utils.collect_env
install_dep_171_cu110: &install_dep_171_cu110
- run:
name: Install Dependencies with torch 1.7.1+cu110
command: |
sudo add-apt-repository universe
sudo apt-get update
sudo apt-get install -y libopenmpi-dev
pip install --progress-bar off torch==1.7.1+cu110 torchvision==0.8.2+cu110 -f https://download.pytorch.org/whl/torch_stable.html
pip install --progress-bar off -r requirements-test.txt
pip install --progress-bar off -r requirements-benchmarks.txt
pip install --progress-bar off git+https://github.com/msbaines/torch_pg.git@c85c96f#egg=torch-pg
python -c 'import torch; print("Torch version:", torch.__version__)'
python -c 'import torch; assert torch.__version__.split(".")[:2] == ["1", "7"], "wrong torch version"'
python -m torch.utils.collect_env
install_dep_180: &install_dep_180
- run:
name: Install Dependencies with torch 1.8.0 nightly
command: |
sudo apt-get install -y libopenmpi-dev
-pip install --pre --progress-bar off torch==1.8.0.dev20210128+cu110 -f https://download.pytorch.org/whl/nightly/cu110/torch_nightly.html
-pip install --progress-bar off git+https://github.com/min-xu-ai/torch_pg.git@c723ab4#egg=torch-pg
pip install --progress-bar off -r requirements-test.txt
-# TODO: We don't use 180 to run benchmark yet, because torchvision is not yet available on py39.
+# Since we are using nightly builds, we bypass the benchmarks req file
+# and install ourselves for testing.
#pip install --progress-bar off -r requirements-benchmarks.txt
+# torchvision nightly wants torch 1.9.
+pip install --pre --progress-bar off torchtext==0.6.0 \
+    torchvision==0.9.0.dev20210222+cu112 \
+    -f https://download.pytorch.org/whl/nightly/cu112/torch_nightly.html
+# We only use it a bit in benchmarking, so it might be safe to use 1.8.
+pip install --pre --progress-bar off torch==1.8.0.dev20210210+cu112 \
+    -f https://download.pytorch.org/whl/nightly/cu112/torch_nightly.html
+pip install --progress-bar off git+https://github.com/min-xu-ai/torch_pg.git@c723ab4#egg=torch-pg
python -c 'import torch; print("Torch version:", torch.__version__)'
python -c 'import torch; assert torch.__version__.split(".")[:2] == ["1", "8"], "wrong torch version"'
+pip list | grep torch
python -m torch.utils.collect_env
install_repo_cpu: &install_repo_cpu
@@ -115,6 +149,13 @@ install_repo_gpu: &install_repo_gpu
export CUDA_HOME=/usr/local/cuda-10.1
pip install -e .
install_repo_gpu_cu111: &install_repo_gpu_cu111
- run:
name: Install Repository
command: |
export CUDA_HOME=/usr/local/cuda-11.1
pip install -e .
run_isort: &run_isort
- run:
@@ -140,6 +181,12 @@ run_flake8: &run_flake8
command: |
flake8 --show-source --statistics
check_test_list: &check_test_list
- run:
name: Verify that unit test list files are correct
command: |
bash ./tests/ci_test_list_check.sh
# TODO (Min): figure out how to do coverage nightly or on-demand. Doing it
# on every commit seems like overkill since we can easily figure out which
@@ -209,17 +256,30 @@ run_doc_build: &run_doc_build
make singlehtml | tee make.out
! tail make.out | grep -q warning
# This is an alias to run all unit tests possible on a platform.
run_unittests: &run_unittests
- run:
name: Run all unit tests.
# We run everything and don't stop on failure on CPU, since docker time is cheaper.
command: |
pytest --junitxml=test-results/junit.xml --verbose --timeout 60
commands:
run_unittests:
# This is a command (like a function) that runs tests from a given test_list_file.
# If test_list_file is not given, this results in an error.
run_unittests_from_list:
parameters:
-test_dir:
+test_list_file:
type: string
-default: "." # Default to run all tests, which may take a long time on GPUs.
+default: "/dev/non_exist" # Default to error out.
steps:
- run:
name: Run Unit Tests
# We use pytest -x so that it stops on the first failure, to save GPU time, which is expensive.
command: |
-pytest --junitxml=test-results/junit.xml --verbose --timeout 60 <<parameters.test_dir>>
+if [ ! -f <<parameters.test_list_file>> ]; then exit 1; fi
+pytest -x --junitxml=test-results/junit.xml --verbose --timeout 60 `cat <<parameters.test_list_file>>`
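For illustration, here is roughly the shell this step runs once CircleCI substitutes the parameter (a sketch only, assuming a job passes test_list_file: tests/ci_test_list_1.txt as the workflows below do):

# Error out early if no valid test list file was passed in.
if [ ! -f tests/ci_test_list_1.txt ]; then exit 1; fi
# -x stops on the first failure, to save expensive GPU time.
pytest -x --junitxml=test-results/junit.xml --verbose --timeout 60 `cat tests/ci_test_list_1.txt`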
# -------------------------------------------------------------------------------------
# Jobs to run
@@ -233,6 +293,7 @@ jobs:
steps:
- checkout
- <<: *check_test_list
- <<: *setup_venv
# Cache the venv directory that contains dependencies
@@ -253,7 +314,7 @@ jobs:
- <<: *run_black
- <<: *run_mypy
- <<: *run_flake8
-- run_unittests
+- <<: *run_unittests
- <<: *run_mpi_unittests
- <<: *run_doc_build
@@ -267,6 +328,7 @@ jobs:
steps:
- checkout
- <<: *check_test_list
- <<: *setup_venv
# Cache the venv directory that contains dependencies
@@ -287,7 +349,7 @@ jobs:
- <<: *run_black
- <<: *run_mypy
- <<: *run_flake8
-- run_unittests
+- <<: *run_unittests
- <<: *run_mpi_unittests
- <<: *run_doc_build
@@ -301,12 +363,13 @@ jobs:
steps:
- checkout
- <<: *check_test_list
- <<: *setup_venv
# Cache the venv directory that contains dependencies
- restore_cache:
keys:
-- cache-key-cpu-py39-180-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
+- cache-key-cpu-py39-180-3-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
# py3.9 doesn't work well with torch < 1.8. See this PR:
# https://github.com/pytorch/pytorch/pull/50998
@@ -317,7 +380,7 @@ jobs:
- save_cache:
paths:
- ~/venv
-key: cache-key-cpu-py39-180-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
+key: cache-key-cpu-py39-180-3-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
- <<: *install_repo_cpu
@@ -325,7 +388,7 @@ jobs:
- <<: *run_black
- <<: *run_mypy
- <<: *run_flake8
-- run_unittests
+- <<: *run_unittests
- <<: *run_mpi_unittests
- <<: *run_doc_build
@@ -335,9 +398,9 @@ jobs:
gpu_tests_151:
parameters:
-test_dir:
+test_list_file:
type: string
-default: "."
+default: "/dev/non_exist"
<<: *gpu
@@ -366,17 +429,17 @@ jobs:
- <<: *install_repo_gpu
-- run_unittests:
+- run_unittests_from_list:
-test_dir: <<parameters.test_dir>>
+test_list_file: <<parameters.test_list_file>>
- store_test_results:
path: test-results
gpu_tests_160:
parameters:
-test_dir:
+test_list_file:
type: string
-default: "."
+default: "/dev/non_exist"
<<: *gpu
@@ -405,19 +468,19 @@ jobs:
- <<: *install_repo_gpu
-- run_unittests:
+- run_unittests_from_list:
-test_dir: <<parameters.test_dir>>
+test_list_file: <<parameters.test_list_file>>
- store_test_results:
path: test-results
gpu_tests_171:
parameters:
-test_dir:
+test_list_file:
type: string
-default: "."
+default: "/dev/non_exist"
-<<: *gpu
+<<: *gpu_cu111
working_directory: ~/fairscale
@@ -426,31 +489,32 @@ jobs:
- run: nvidia-smi
-- run: pyenv global 3.7.0
+# Run this to make sure we use python3 from the system.
+- run: pyenv global 3.8.6
- <<: *setup_venv
# Cache the venv directory that contains dependencies
- restore_cache:
keys:
-- cache-key-gpu-171-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
+- cache-key-gpu-cu111-171-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
-- <<: *install_dep_171
+- <<: *install_dep_171_cu110
- save_cache:
paths:
- ~/venv
-key: cache-key-gpu-171-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
+key: cache-key-gpu-cu111-171-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
-- <<: *install_repo_gpu
+- <<: *install_repo_gpu_cu111
-- run_unittests:
+- run_unittests_from_list:
-test_dir: <<parameters.test_dir>>
+test_list_file: <<parameters.test_list_file>>
- store_test_results:
path: test-results
-benchmarks:
+benchmarks_1:
<<: *gpu
working_directory: ~/fairscale
@@ -486,14 +550,47 @@ jobs:
- <<: *run_mp_pipe_benchmark
-- <<: *run_oss_benchmark
+- <<: *run_oss_amp
+- <<: *run_oss_for_each
- <<: *run_oss_gloo
-- <<: *run_oss_amp
-- <<: *run_oss_for_each
benchmarks_2:
<<: *gpu
working_directory: ~/fairscale
steps:
- checkout
- run: nvidia-smi
- run: pyenv uninstall -f 3.7.0
- run: pyenv install 3.7.0
- run: pyenv global 3.7.0
- <<: *setup_venv
# Cache the venv directory that contains dependencies
- restore_cache:
keys:
- cache-key-benchmarks-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
- <<: *install_dep_171
- save_cache:
paths:
- ~/venv
key: cache-key-benchmarks-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
- <<: *install_repo_gpu
- <<: *run_oss_benchmark
workflows:
@@ -504,27 +601,22 @@ workflows:
- cpu_tests_py38
- cpu_tests_py39
-- gpu_tests_151:
-test_dir: tests/experimental
-- gpu_tests_160:
-test_dir: tests/experimental
-- gpu_tests_171:
-test_dir: tests/experimental
-- gpu_tests_151:
-test_dir: tests/nn
-- gpu_tests_160:
-test_dir: tests/nn
-- gpu_tests_171:
-test_dir: tests/nn
-- gpu_tests_151:
-test_dir: tests/optim
-- gpu_tests_160:
-test_dir: tests/optim
-- gpu_tests_171:
-test_dir: tests/optim
-- gpu_tests_151:
-test_dir: tests/utils
-- gpu_tests_160:
-test_dir: tests/utils
-- gpu_tests_171:
-test_dir: tests/utils
-- benchmarks
+- gpu_tests_151:
+test_list_file: tests/ci_test_list_1.txt
+- gpu_tests_160:
+test_list_file: tests/ci_test_list_1.txt
+- gpu_tests_171:
+test_list_file: tests/ci_test_list_1.txt
+- gpu_tests_151:
+test_list_file: tests/ci_test_list_2.txt
+- gpu_tests_160:
+test_list_file: tests/ci_test_list_2.txt
+- gpu_tests_171:
+test_list_file: tests/ci_test_list_2.txt
+- gpu_tests_151:
+test_list_file: tests/ci_test_list_3.txt
+- gpu_tests_160:
+test_list_file: tests/ci_test_list_3.txt
+- gpu_tests_171:
+test_list_file: tests/ci_test_list_3.txt
+- benchmarks_1
+- benchmarks_2
@@ -101,10 +101,15 @@ def validate_benchmark(measurements, final_loss, args, check_regression):
logging.info(f"[{rank}] : Median speed: {median:.2f} +/- {mad:.2f}")
if check_regression and rank == 0:
-assert (median + 3.0 * mad) > golden_data["reference_speed"], "Speed regression detected"
-assert max_memory < 1.05 * golden_data["reference_memory"], "Memory use regression detected"
-assert abs(cast(float, final_loss) - golden_data["reference_loss"]) < 1e-3, "Loss regression detected"
+assert median + 3.0 * mad > golden_data["reference_speed"], (
+    f"Speed regression detected: {median + 3.0 * mad} vs. {golden_data['reference_speed']}"
+)
+assert max_memory < 1.05 * golden_data["reference_memory"], (
+    f"Memory use regression detected: {max_memory} vs. {1.05 * golden_data['reference_memory']}"
+)
+assert abs(cast(float, final_loss) - golden_data["reference_loss"]) < 1e-3, (
+    f"Loss regression detected: {final_loss} vs. {golden_data['reference_loss']}"
+)
logging.info("[Regression Test] VALID")
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import os
import re
@@ -55,7 +58,7 @@ else:
if __name__ == "__main__":
setuptools.setup(
name="fairscale",
-description="fairscale: A PyTorch library for large-scale and high-performance training.",
+description="FairScale: A PyTorch library for large-scale and high-performance training.",
version=find_version("fairscale/__init__.py"),
setup_requires=["ninja"],  # ninja is required to build extensions
install_requires=fetch_requirements(),
@@ -2,3 +2,19 @@
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
#
#
# We need to have __init__.py in the tests dir due to a pytest issue.
#
# If you have:
# tests/
#   aa/test_name.py
#   bb/test_name.py
#
# running `pytest tests` will give an error like "import file mismatch",
# because it can't distinguish between the files in `aa` and `bb` that
# share the same name. Adding __init__.py files fixes it.
#
# However, `pytest tests/__init__.py` triggers running unrelated tests,
# so we just don't include any __init__.py in the test list files.
# (A minimal reproduction of this issue is sketched after the test lists below.)
tests/nn/misc/test_flatten_params_wrapper.py
tests/nn/misc/test_checkpoint_activations.py
tests/nn/data_parallel/test_fsdp.py
tests/utils/test_reduce_scatter_bucketer.py
tests/utils/test_containers.py
tests/utils/test_parallel.py
tests/utils/test_state_dict.py
tests/nn/pipe_process/test_pipe.py
tests/nn/pipe_process/test_transparency.py
tests/nn/pipe_process/test_inplace.py
tests/nn/pipe_process/test_bugs.py
tests/nn/pipe_process/conftest.py
tests/nn/pipe_process/test_rpc.py
tests/nn/model_parallel/test_initialize.py
tests/nn/model_parallel/test_random.py
tests/nn/model_parallel/test_cross_entropy.py
tests/nn/model_parallel/test_layers.py
tests/nn/pipe/test_microbatch.py
tests/nn/pipe/test_checkpoint.py
tests/nn/pipe/test_worker.py
tests/nn/pipe/test_balance.py
tests/nn/pipe/test_pipe.py
tests/nn/pipe/test_transparency.py
tests/nn/pipe/test_inplace.py
tests/nn/pipe/test_copy.py
tests/nn/pipe/test_bugs.py
tests/nn/pipe/conftest.py
tests/nn/pipe/test_pipeline.py
tests/nn/pipe/test_phony.py
tests/nn/pipe/test_deferred_batch_norm.py
tests/nn/pipe/test_dependency.py
tests/nn/pipe/test_stream.py
tests/nn/data_parallel/test_fsdp_uneven.py
tests/nn/data_parallel/test_fsdp_grad_scaler.py
tests/nn/data_parallel/test_features_sharded_ddp.py
tests/nn/data_parallel/test_pytorch_parity_sharded_ddp.py
tests/nn/pipe/skip/test_gpipe.py
tests/nn/pipe/skip/test_verify_skippables.py
tests/nn/pipe/skip/test_stash_pop.py
tests/nn/pipe/skip/test_api.py
tests/nn/pipe/skip/test_leak.py
tests/nn/pipe/skip/test_portal.py
tests/nn/pipe/skip/test_tracker.py
tests/nn/pipe/skip/test_inspect_skip_layout.py
tests/nn/pipe/test_checkpoint_ddp.py
tests/nn/moe/test_moe_layer.py
tests/nn/moe/test_top2gating.py
tests/optim/test_single_node_adascale.py
tests/optim/test_adam.py
tests/optim/test_oss.py
tests/optim/test_oss_adascale.py
tests/optim/test_ddp_adascale.py
tests/experimental/nn/ampnet_pipe_process/test_ampnet_pipe.py
tests/experimental/nn/test_offload.py
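As referenced in the __init__.py note above, here is a minimal shell reproduction of the pytest "import file mismatch" issue (a sketch; the aa/bb layout is the hypothetical example from that note, not real fairscale directories):

mkdir -p demo/tests/aa demo/tests/bb
echo 'def test_name(): pass' > demo/tests/aa/test_name.py
echo 'def test_name(): pass' > demo/tests/bb/test_name.py
# Both files would be imported as module "test_name", so collection
# fails with an "import file mismatch" error:
(cd demo && pytest tests)
# Adding __init__.py files gives the two modules distinct package
# paths, which fixes the collision:
touch demo/tests/__init__.py demo/tests/aa/__init__.py demo/tests/bb/__init__.py
(cd demo && pytest tests)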
#!/bin/sh
# Verify that we don't miss any tests.
# Compare the test files that exist on disk with the ones listed in the
# CI test list files; any difference fails this check.
find tests -name \*.py -type f | grep -v __init__.py | sort | uniq > /tmp/find.out
cat tests/ci_test_list*.txt | sort | uniq > /tmp/cat.out
if ! diff /tmp/find.out /tmp/cat.out ; then
  echo "A unit test is missing from CI"
  echo "See the diff above to fix it"
  exit 1
fi
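The script above is what the check_test_list step runs on every CI job. It can also be run locally from the repository root after adding or moving a test file:

# Prints the diff and exits 1 if any test file on disk is missing
# from the tests/ci_test_list_*.txt files (or vice versa).
bash ./tests/ci_test_list_check.sh && echo "test lists are in sync"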
@@ -24,7 +24,14 @@ from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader, Dataset
from fairscale.experimental.nn.ampnet_pipe.pipe import AMPnetPipe
-from fairscale.utils.testing import get_worker_map, torch_spawn
+from fairscale.utils.testing import get_worker_map, torch_spawn, torch_version
# Currently on CI, there appears to be a bug with torch 1.8.
# See:
# https://app.circleci.com/pipelines/github/facebookresearch/fairscale/1892/workflows/8f658bf4-8052-4084-bb3e-4cc2c445c8aa/jobs/10080/parallel-runs/0/steps/0-112
# So we skip this file in that case until it is fixed.
if torch_version() >= (1, 8, 0):
    pytestmark = pytest.mark.skip
class MySGD(Optimizer):
@@ -150,7 +150,7 @@ class TestMixedPrecision(DistributedTest):
# Patch torch.distributed.reduce_scatter to check the dtype of the reduction
orig_reduce_scatter = torch.distributed.reduce_scatter
-model = DeviceAndTypeCheckModule(
+model: nn.Module = DeviceAndTypeCheckModule(
expected_input_dtype=in_dtype, expected_param_dtype=p_dtype, expected_loss_dtype=loss_dtype,
)
@@ -34,6 +34,13 @@ from fairscale.nn.model_parallel.initialize import (
from fairscale.nn.pipe import AsyncPipe, LazyModule, MultiProcessPipe
from fairscale.utils.testing import get_worker_map, torch_spawn, torch_version
# Currently on CI, there appears to be a bug with torch 1.8.
# See:
# https://app.circleci.com/pipelines/github/facebookresearch/fairscale/1892/workflows/8f658bf4-8052-4084-bb3e-4cc2c445c8aa/jobs/10080/parallel-runs/0/steps/0-112
# So we skip this file in that case until it is fixed.
if torch_version() >= (1, 8, 0):
    pytestmark = pytest.mark.skip
@torch_spawn([2])
@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])