Unverified Commit 3ecf76f4 authored by Min Xu, committed by GitHub

[cleanup] CI test updates; mypy cleanup; partial broadcast_object cleanup; pre-commit documentation (#744)

* changelog; mypy; oss cleanup

* more broadcast_object cleanup in FSDP

* one more mypy fix

* retire pytorch 1.6 from circleci, add new nightly, add 1.8 LTS and 1.9 stable release

* update torch version for LTS

* minor fixes

* update cache key

* trying newer gpu VMs

* bump the cache

* update to gpu.medium, which should be 2 GPUs

* update nightly version

* add pre-commit instruction

* fixed CHANGELOG after merging

* updated to newer nightly

* retained the older broadcast function for older GPUs for oss.py

* fixed a bug

* added a comment

* fixing a test for pytorch 1.10

* testing a fix

* Update fairscale/optim/oss.py

* Update CONTRIBUTING.md
Co-authored-by: Min Xu <min.xu.public@gmail.com>
parent 95d31d4d
...@@ -35,15 +35,17 @@ gpu: &gpu ...@@ -35,15 +35,17 @@ gpu: &gpu
CUDA_VERSION: "10.2" CUDA_VERSION: "10.2"
CUDA_HOME: /usr/local/cuda-10.2 CUDA_HOME: /usr/local/cuda-10.2
machine: machine:
# This image actually has cuda-11.1 installed, but it doesn't seem to affect us
# using pytorch cu10 builds below.
image: ubuntu-1604-cuda-10.2:202012-01 image: ubuntu-1604-cuda-10.2:202012-01
resource_class: gpu.large resource_class: gpu.large
gpu_cu111: &gpu_cu111 gpu_cu111: &gpu_cu111
environment: environment:
CUDA_VERSION: "11.1" CUDA_VERSION: "11.2"
CUDA_HOME: /usr/local/cuda-11.1 CUDA_HOME: /usr/local/cuda-11.2
machine: machine:
image: ubuntu-1604-cuda-11.1:202012-01 image: ubuntu-2004-cuda-11.2:202103-01
resource_class: gpu.large resource_class: gpu.large
# ------------------------------------------------------------------------------------- # -------------------------------------------------------------------------------------
...@@ -62,21 +64,6 @@ setup_venv: &setup_venv ...@@ -62,21 +64,6 @@ setup_venv: &setup_venv
which pip which pip
pip install --upgrade pip pip install --upgrade pip
install_dep_160: &install_dep_160
- run:
name: Install Dependencies with torch 1.6.0
command: |
# check if we have restored venv cache (/home/circleci/venv) correctly, if so, just skip
if [ -f /home/circleci/venv/check_version.py ]; then python /home/circleci/venv/check_version.py torch eq 1.6 && exit 0; fi
# start installing
pip install --progress-bar off torch==1.6.0+cu101 torchvision==0.7.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html
pip install --progress-bar off -r requirements-test.txt
pip install --progress-bar off -r requirements-benchmarks.txt
python -c 'import torch; print("Torch version:", torch.__version__)'
python -c 'import torch; assert torch.__version__.split(".")[:2] == ["1", "6"], "wrong torch version"'
python -m torch.utils.collect_env
wget -O /home/circleci/venv/check_version.py https://raw.githubusercontent.com/min-xu-ai/check_verion/main/check_version.py
install_dep_171: &install_dep_171 install_dep_171: &install_dep_171
- run: - run:
name: Install Dependencies with torch 1.7.1 name: Install Dependencies with torch 1.7.1
...@@ -94,12 +81,12 @@ install_dep_171: &install_dep_171 ...@@ -94,12 +81,12 @@ install_dep_171: &install_dep_171
install_dep_181: &install_dep_181 install_dep_181: &install_dep_181
- run: - run:
name: Install Dependencies with torch 1.8.1 name: Install Dependencies with torch 1.8.1 (LTS)
command: | command: |
# check if we have restored venv cache (/home/circleci/venv) correctly, if so, just skip # check if we have restored venv cache (/home/circleci/venv) correctly, if so, just skip
if [ -f /home/circleci/venv/check_version.py ]; then python /home/circleci/venv/check_version.py torch eq 1.8 && exit 0; fi if [ -f /home/circleci/venv/check_version.py ]; then python /home/circleci/venv/check_version.py torch eq 1.8 && exit 0; fi
# start installing # start installing
pip install --progress-bar off torch==1.8.1+cu101 torchvision==0.9.1+cu101 -f https://download.pytorch.org/whl/torch_stable.html pip install --progress-bar off torch==1.8.1+cu102 torchvision==0.9.1+cu102 -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
pip install --progress-bar off -r requirements-test.txt pip install --progress-bar off -r requirements-test.txt
pip install --progress-bar off -r requirements-benchmarks.txt pip install --progress-bar off -r requirements-benchmarks.txt
python -c 'import torch; print("Torch version:", torch.__version__)' python -c 'import torch; print("Torch version:", torch.__version__)'
...@@ -112,9 +99,9 @@ install_dep_190: &install_dep_190 ...@@ -112,9 +99,9 @@ install_dep_190: &install_dep_190
name: Install Dependencies with torch 1.9.0 name: Install Dependencies with torch 1.9.0
command: | command: |
# check if we have restored venv cache (/home/circleci/venv) correctly, if so, just skip # check if we have restored venv cache (/home/circleci/venv) correctly, if so, just skip
if [ -f /home/circleci/venv/check_version.py ]; then python /home/circleci/venv/check_version.py torch eq 1.8 && exit 0; fi if [ -f /home/circleci/venv/check_version.py ]; then python /home/circleci/venv/check_version.py torch eq 1.9 && exit 0; fi
# start installing # start installing
pip install --progress-bar off install torch==1.9.0+cu102 torchvision==0.10.0+cu102 -f https://download.pytorch.org/whl/torch_stable.html pip install --progress-bar off torch==1.9.0+cu111 torchvision==0.10.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html
pip install --progress-bar off -r requirements-test.txt pip install --progress-bar off -r requirements-test.txt
pip install --progress-bar off -r requirements-benchmarks.txt pip install --progress-bar off -r requirements-benchmarks.txt
python -c 'import torch; print("Torch version:", torch.__version__)' python -c 'import torch; print("Torch version:", torch.__version__)'
...@@ -122,6 +109,21 @@ install_dep_190: &install_dep_190 ...@@ -122,6 +109,21 @@ install_dep_190: &install_dep_190
python -m torch.utils.collect_env python -m torch.utils.collect_env
wget -O /home/circleci/venv/check_version.py https://raw.githubusercontent.com/min-xu-ai/check_verion/main/check_version.py wget -O /home/circleci/venv/check_version.py https://raw.githubusercontent.com/min-xu-ai/check_verion/main/check_version.py
install_dep_pytorch_nightly: &install_dep_pytorch_nightly
- run:
name: Install Dependencies with a torch nightly preview build
command: |
# check if we have restored venv cache (/home/circleci/venv) correctly, if so, just skip
if [ -f /home/circleci/venv/check_version.py ]; then python /home/circleci/venv/check_version.py torch eq 1.10 && exit 0; fi
# start installing
pip install --progress-bar off --pre torch==1.10.0.dev20210901+cu111 torchvision==0.11.0.dev20210901+cu111 -f https://download.pytorch.org/whl/nightly/cu111/torch_nightly.html
pip install --progress-bar off -r requirements-test.txt
pip install --progress-bar off -r requirements-benchmarks.txt
python -c 'import torch; print("Torch version:", torch.__version__)'
python -c 'import torch; assert torch.__version__.split(".")[:2] == ["1", "10"], "wrong torch version"'
python -m torch.utils.collect_env
wget -O /home/circleci/venv/check_version.py https://raw.githubusercontent.com/min-xu-ai/check_verion/main/check_version.py
install_repo: &install_repo install_repo: &install_repo
- run: - run:
name: Install Repository name: Install Repository
...@@ -337,14 +339,14 @@ jobs: ...@@ -337,14 +339,14 @@ jobs:
# Cache the venv directory that contains dependencies # Cache the venv directory that contains dependencies
- restore_cache: - restore_cache:
keys: keys:
- cache-key-cpu-py39-181-0-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}} - cache-key-cpu-py39-181-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
- <<: *install_dep_181 - <<: *install_dep_181
- save_cache: - save_cache:
paths: paths:
- ~/venv - ~/venv
key: cache-key-cpu-py39-181-0-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}} key: cache-key-cpu-py39-181-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
- <<: *install_repo - <<: *install_repo
...@@ -358,14 +360,13 @@ jobs: ...@@ -358,14 +360,13 @@ jobs:
- store_test_results: - store_test_results:
path: test-results path: test-results
gpu_tests_171:
gpu_tests_160:
parameters: parameters:
test_list_file: test_list_file:
type: string type: string
default: "/dev/non_exist" default: "/dev/non_exist"
<<: *gpu <<: *gpu_cu111
working_directory: ~/fairscale working_directory: ~/fairscale
...@@ -374,22 +375,23 @@ jobs: ...@@ -374,22 +375,23 @@ jobs:
- run: nvidia-smi - run: nvidia-smi
# Run this to make sure we use python3 from the system.
- setup_pyenv: - setup_pyenv:
version: 3.7.0 version: 3.8.6
- <<: *setup_venv - <<: *setup_venv
# Cache the venv directory that contains dependencies # Cache the venv directory that contains dependencies
- restore_cache: - restore_cache:
keys: keys:
- cache-key-gpu-160-101-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}} - cache-key-py38-gpu-171-111-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
- <<: *install_dep_160 - <<: *install_dep_171
- save_cache: - save_cache:
paths: paths:
- ~/venv - ~/venv
key: cache-key-gpu-160-101-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}} key: cache-key-py38-gpu-171-111-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
- <<: *install_repo - <<: *install_repo
...@@ -401,13 +403,13 @@ jobs: ...@@ -401,13 +403,13 @@ jobs:
- <<: *upload_coverage - <<: *upload_coverage
gpu_tests_171: gpu_tests_181:
parameters: parameters:
test_list_file: test_list_file:
type: string type: string
default: "/dev/non_exist" default: "/dev/non_exist"
<<: *gpu_cu111 <<: *gpu
working_directory: ~/fairscale working_directory: ~/fairscale
...@@ -418,21 +420,21 @@ jobs: ...@@ -418,21 +420,21 @@ jobs:
# Run this to make sure we use python3 from the system. # Run this to make sure we use python3 from the system.
- setup_pyenv: - setup_pyenv:
version: 3.8.6 version: 3.7.0
- <<: *setup_venv - <<: *setup_venv
# Cache the venv directory that contains dependencies # Cache the venv directory that contains dependencies
- restore_cache: - restore_cache:
keys: keys:
- cache-key-gpu-171-110-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}} - cache-key-py37-gpu-181-102-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
- <<: *install_dep_171 - <<: *install_dep_181
- save_cache: - save_cache:
paths: paths:
- ~/venv - ~/venv
key: cache-key-gpu-171-110-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}} key: cache-key-py37-gpu-181-102-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
- <<: *install_repo - <<: *install_repo
...@@ -444,13 +446,13 @@ jobs: ...@@ -444,13 +446,13 @@ jobs:
- <<: *upload_coverage - <<: *upload_coverage
gpu_tests_181: gpu_tests_190:
parameters: parameters:
test_list_file: test_list_file:
type: string type: string
default: "/dev/non_exist" default: "/dev/non_exist"
<<: *gpu <<: *gpu_cu111
working_directory: ~/fairscale working_directory: ~/fairscale
...@@ -461,21 +463,21 @@ jobs: ...@@ -461,21 +463,21 @@ jobs:
# Run this to make sure we use python3 from the system. # Run this to make sure we use python3 from the system.
- setup_pyenv: - setup_pyenv:
version: 3.7.0 version: 3.8.6
- <<: *setup_venv - <<: *setup_venv
# Cache the venv directory that contains dependencies # Cache the venv directory that contains dependencies
- restore_cache: - restore_cache:
keys: keys:
- cache-key-gpu-181-101-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}} - cache-key-py38-gpu-190-111-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
- <<: *install_dep_181 - <<: *install_dep_190
- save_cache: - save_cache:
paths: paths:
- ~/venv - ~/venv
key: cache-key-gpu-181-101-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}} key: cache-key-py38-gpu-190-111-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
- <<: *install_repo - <<: *install_repo
...@@ -485,15 +487,13 @@ jobs: ...@@ -485,15 +487,13 @@ jobs:
- store_test_results: - store_test_results:
path: test-results path: test-results
- <<: *upload_coverage gpu_tests_pytorch_nightly:
gpu_tests_190:
parameters: parameters:
test_list_file: test_list_file:
type: string type: string
default: "/dev/non_exist" default: "/dev/non_exist"
<<: *gpu <<: *gpu_cu111
working_directory: ~/fairscale working_directory: ~/fairscale
...@@ -504,21 +504,21 @@ jobs: ...@@ -504,21 +504,21 @@ jobs:
# Run this to make sure we use python3 from the system. # Run this to make sure we use python3 from the system.
- setup_pyenv: - setup_pyenv:
version: 3.7.0 version: 3.8.6
- <<: *setup_venv - <<: *setup_venv
# Cache the venv directory that contains dependencies # Cache the venv directory that contains dependencies
- restore_cache: - restore_cache:
keys: keys:
- cache-key-gpu-190-101-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}} - cache-key-py38-gpu-pytorch-nightly-111-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
- <<: *install_dep_190 - <<: *install_dep_pytorch_nightly
- save_cache: - save_cache:
paths: paths:
- ~/venv - ~/venv
key: cache-key-gpu-190-101-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}} key: cache-key-py38-gpu-pytorch-nightly-111-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
- <<: *install_repo - <<: *install_repo
...@@ -546,7 +546,7 @@ jobs: ...@@ -546,7 +546,7 @@ jobs:
# Cache the venv directory that contains dependencies # Cache the venv directory that contains dependencies
- restore_cache: - restore_cache:
keys: keys:
- cache-key-benchmarks-181-101-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}} - cache-key-py37-benchmarks-181-101-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
# Cache the MNIST directory that contains benchmark data # Cache the MNIST directory that contains benchmark data
- restore_cache: - restore_cache:
...@@ -558,7 +558,7 @@ jobs: ...@@ -558,7 +558,7 @@ jobs:
- save_cache: - save_cache:
paths: paths:
- ~/venv - ~/venv
key: cache-key-benchmarks-181-101-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}} key: cache-key-py37-benchmarks-181-101-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
- <<: *install_repo - <<: *install_repo
...@@ -595,7 +595,7 @@ jobs: ...@@ -595,7 +595,7 @@ jobs:
# Cache the venv directory that contains dependencies # Cache the venv directory that contains dependencies
- restore_cache: - restore_cache:
keys: keys:
- cache-key-benchmarks-181-101-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}} - cache-key-py37-benchmarks-181-101-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
# Cache the MNIST directory that contains benchmark data # Cache the MNIST directory that contains benchmark data
...@@ -608,7 +608,7 @@ jobs: ...@@ -608,7 +608,7 @@ jobs:
- save_cache: - save_cache:
paths: paths:
- ~/venv - ~/venv
key: cache-key-benchmarks-181-101-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}} key: cache-key-py37-benchmarks-181-101-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
- <<: *install_repo - <<: *install_repo
...@@ -627,29 +627,29 @@ workflows: ...@@ -627,29 +627,29 @@ workflows:
- cpu_tests_py37 - cpu_tests_py37
- cpu_tests_py38 - cpu_tests_py38
- cpu_tests_py39 - cpu_tests_py39
- gpu_tests_160:
test_list_file: tests/ci_test_list_1.txt
- gpu_tests_171: - gpu_tests_171:
test_list_file: tests/ci_test_list_1.txt test_list_file: tests/ci_test_list_1.txt
- gpu_tests_181: - gpu_tests_181:
test_list_file: tests/ci_test_list_1.txt test_list_file: tests/ci_test_list_1.txt
- gpu_tests_190: - gpu_tests_190:
test_list_file: tests/ci_test_list_1.txt test_list_file: tests/ci_test_list_1.txt
- gpu_tests_160: - gpu_tests_pytorch_nightly:
test_list_file: tests/ci_test_list_2.txt test_list_file: tests/ci_test_list_1.txt
- gpu_tests_171: - gpu_tests_171:
test_list_file: tests/ci_test_list_2.txt test_list_file: tests/ci_test_list_2.txt
- gpu_tests_181: - gpu_tests_181:
test_list_file: tests/ci_test_list_2.txt test_list_file: tests/ci_test_list_2.txt
- gpu_tests_190: - gpu_tests_190:
test_list_file: tests/ci_test_list_2.txt test_list_file: tests/ci_test_list_2.txt
- gpu_tests_160: - gpu_tests_pytorch_nightly:
test_list_file: tests/ci_test_list_3.txt test_list_file: tests/ci_test_list_2.txt
- gpu_tests_171: - gpu_tests_171:
test_list_file: tests/ci_test_list_3.txt test_list_file: tests/ci_test_list_3.txt
- gpu_tests_181: - gpu_tests_181:
test_list_file: tests/ci_test_list_3.txt test_list_file: tests/ci_test_list_3.txt
- gpu_tests_190: - gpu_tests_190:
test_list_file: tests/ci_test_list_3.txt test_list_file: tests/ci_test_list_3.txt
- gpu_tests_pytorch_nightly:
test_list_file: tests/ci_test_list_3.txt
- benchmarks_1 - benchmarks_1
- benchmarks_2 - benchmarks_2
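The install steps above gate on a small helper, check_version.py, that is fetched from an external repository and not shown in this diff. A minimal sketch of what such a cache check might look like, assuming the `<package> eq <major.minor>` command-line form used in the config (the real script may differ):
```
#!/usr/bin/env python3
"""Sketch of a venv cache check: exit 0 if the installed package already matches
the requested major.minor version, non-zero otherwise (hypothetical; the real
check_version.py lives in a separate repository)."""
import importlib
import sys


def main() -> int:
    pkg, op, wanted = sys.argv[1:4]  # e.g. "torch eq 1.9"
    assert op == "eq", "only equality checks are sketched here"
    mod = importlib.import_module(pkg)
    installed = ".".join(mod.__version__.split(".")[:2])
    return 0 if installed == wanted else 1


if __name__ == "__main__":
    sys.exit(main())
```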
...@@ -17,7 +17,7 @@ We actively welcome your pull requests. ...@@ -17,7 +17,7 @@ We actively welcome your pull requests.
2. If you've added code that should be tested, add tests. 2. If you've added code that should be tested, add tests.
3. If you've changed APIs, update the documentation. 3. If you've changed APIs, update the documentation.
4. Ensure the test suite passes. 4. Ensure the test suite passes.
5. Make sure your code lints. 5. Make sure your code passes static analysis (see below).
6. If you haven't already, complete the Contributor License Agreement ("CLA"). 6. If you haven't already, complete the Contributor License Agreement ("CLA").
## Contributor License Agreement ("CLA") ## Contributor License Agreement ("CLA")
...@@ -50,27 +50,36 @@ outlined on that page and do not file a public issue. ...@@ -50,27 +50,36 @@ outlined on that page and do not file a public issue.
* We follow the [PEP8](https://www.python.org/dev/peps/pep-0008/) style guide. * We follow the [PEP8](https://www.python.org/dev/peps/pep-0008/) style guide.
* In your editor, install the [editorconfig](https://editorconfig.org/) extension * In your editor, install the [editorconfig](https://editorconfig.org/) extension
which should ensure that you are following the same standards as us. which should ensure that you are following the same standards as us.
* Please run black and isort before opening up your PR.
```
black .
isort .
flake8
```
* Please read the [editorconfig](.editorconfig) file to understand the exact coding style preferences. * Please read the [editorconfig](.editorconfig) file to understand the exact coding style preferences.
* Please place Python code related to models in fairscale/nn. Place Python code related to optimizers * Please place Python code related to models in fairscale/nn. Place Python code related to optimizers
in fairscale/optim. Place C++ extensions in fairscale/clib. in fairscale/optim. Place C++ extensions in fairscale/clib.
* Please put `__all__:List[str] = []` in new `__init__.py` files for consistent importing behavior * Please put `__all__:List[str] = []` in new `__init__.py` files for consistent importing behavior
and less development overhead in maintaining an importing list. and less development overhead in maintaining an importing list.
* Please set up pre-commit before opening up your PR.
## Testing ### Pre-commit
### Static analysis ```
pip install -r requirements-dev.txt
pre-commit install
```
After the above, your `git commit` command will automatically trigger the pre-commit
checks, which run the static code analysis tools we use.
### Run static analysis by hand (without using pre-commit)
Note that trailing spaces are not checked by the manual commands below, but they are checked by the pre-commit hooks above.
``` ```
black .
isort .
flake8
mypy --ignore-missing-imports --scripts-are-modules --pretty . mypy --ignore-missing-imports --scripts-are-modules --pretty .
``` ```
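Once installed, the same hooks can also be run over the entire tree on demand (a standard pre-commit invocation, useful before pushing):
```
pre-commit run --all-files
```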
## Testing
### Unit tests ### Unit tests
``` ```
......
...@@ -47,8 +47,8 @@ def _calculate_stats(input: Tensor, eps: float, process_group: ProcessGroup) -> ...@@ -47,8 +47,8 @@ def _calculate_stats(input: Tensor, eps: float, process_group: ProcessGroup) ->
if torch_version()[:2] >= (1, 7): if torch_version()[:2] >= (1, 7):
_forward = torch.jit.script(_forward) # type: ignore _forward = torch.jit.script(_forward)
_track_running_stats = torch.jit.script(_track_running_stats) # type: ignore _track_running_stats = torch.jit.script(_track_running_stats)
class _SyncBatchNormFunction(torch.autograd.Function): class _SyncBatchNormFunction(torch.autograd.Function):
......
...@@ -46,7 +46,7 @@ from fairscale.utils.parallel import ( ...@@ -46,7 +46,7 @@ from fairscale.utils.parallel import (
get_process_group_cached, get_process_group_cached,
validate_process_group, validate_process_group,
) )
from fairscale.utils.params import broadcast_object, calc_grad_norm, recursive_copy_to_device from fairscale.utils.params import calc_grad_norm, recursive_copy_to_device
from fairscale.utils.reduce_scatter_bucketer import ReduceScatterBucketer from fairscale.utils.reduce_scatter_bucketer import ReduceScatterBucketer
from fairscale.utils.state_dict import replace_by_prefix_ from fairscale.utils.state_dict import replace_by_prefix_
...@@ -526,13 +526,14 @@ class FullyShardedDataParallel(nn.Module): ...@@ -526,13 +526,14 @@ class FullyShardedDataParallel(nn.Module):
if self.move_grads_to_cpu: if self.move_grads_to_cpu:
total_norm = total_norm.cpu() total_norm = total_norm.cpu()
# Now multiply each grad by (max_norm/total_norm), same as torch 1.7 https://tinyurl.com/3wtxhhqq) # Now multiply each grad by (max_norm/total_norm), same as torch 1.7 https://tinyurl.com/3wtxhhqq)
clip_coef = torch.tensor(max_norm, dtype=total_norm.dtype, device=total_norm.device) / (total_norm + 1e-6) clip_coef = torch.tensor(max_norm, dtype=total_norm.dtype, device=total_norm.device) / (total_norm + 1e-6)
if clip_coef < 1: if clip_coef < 1:
# multiply by clip_coef # multiply by clip_coef
for p in params_with_grad: for p in params_with_grad:
p.grad.detach().mul_(clip_coef.to(p.grad.device)) # type: ignore assert p.grad is not None
p.grad.detach().mul_(clip_coef.to(p.grad.device))
return total_norm return total_norm
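The clipping logic in the hunk above scales every gradient by max_norm / (total_norm + eps) when the total norm exceeds max_norm, mirroring torch.nn.utils.clip_grad_norm_. A standalone sketch with plain tensors (not the FSDP code path; values are made up):
```
import torch

max_norm, eps = 1.0, 1e-6
grads = [torch.tensor([3.0, 4.0]), torch.tensor([0.0, 12.0])]  # per-parameter gradients
# total L2 norm over all gradients: sqrt(5**2 + 12**2) == 13
total_norm = torch.norm(torch.stack([g.norm(2) for g in grads]), 2)
clip_coef = max_norm / (total_norm + eps)
if clip_coef < 1:
    for g in grads:
        g.mul_(clip_coef)  # same in-place scaling as p.grad.detach().mul_ above
# after clipping, the combined norm is (approximately) max_norm
assert torch.isclose(torch.norm(torch.stack([g.norm(2) for g in grads]), 2), torch.tensor(max_norm), atol=1e-4)
```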
...@@ -1566,7 +1567,7 @@ class FullyShardedDataParallel(nn.Module): ...@@ -1566,7 +1567,7 @@ class FullyShardedDataParallel(nn.Module):
# Fill output_tensor with (p.data for each shard in self.world_size) # Fill output_tensor with (p.data for each shard in self.world_size)
if hasattr(dist, "_all_gather_base"): if hasattr(dist, "_all_gather_base"):
# New version of PyTorch has all_gather_base, which is faster than chunk and then all_gather. # New version of PyTorch has all_gather_base, which is faster than chunk and then all_gather.
dist._all_gather_base(output_tensor, p_data, group=self.process_group) # type: ignore dist._all_gather_base(output_tensor, p_data, group=self.process_group)
else: else:
chunks = list(output_tensor.chunk(self.world_size)) chunks = list(output_tensor.chunk(self.world_size))
dist.all_gather(chunks, p_data, group=self.process_group) dist.all_gather(chunks, p_data, group=self.process_group)
...@@ -1828,19 +1829,17 @@ class FullyShardedDataParallel(nn.Module): ...@@ -1828,19 +1829,17 @@ class FullyShardedDataParallel(nn.Module):
raise ValueError(msg) raise ValueError(msg)
def _broadcast_pad_info_to_r0(self) -> List[List[List[int]]]: def _broadcast_pad_info_to_r0(self) -> List[List[List[int]]]:
"""Collect [x.numel_padded_per_param for x in self._fsdp_instances] from teach rank.""" """Collect [x.numel_padded_per_param for x in self._fsdp_instances] from each rank."""
dummy_tensor = torch.tensor([0], dtype=torch.uint8, device=self.compute_device)
world_pad_info: List[List[List[int]]] = [] # this will contain values from the whole world. world_pad_info: List[List[List[int]]] = [] # this will contain values from the whole world.
my_pad_info: List[List[int]] = [cast(List[int], m.numel_padded_per_param) for m in self._fsdp_instances]
for rank in range(self.world_size): for rank in range(self.world_size):
if rank == self.rank: if rank == self.rank:
pad_info = [m.numel_padded_per_param for m in self._fsdp_instances] pad_info = my_pad_info
else: else:
pad_info = dummy_tensor # type: ignore pad_info = [[0]] * len(my_pad_info)
pad_info = broadcast_object( dist.broadcast_object_list(pad_info, src=rank, group=self.process_group)
pad_info, src_rank=rank, group=self.process_group, dist_device=self.compute_device
)
if self.rank == 0: if self.rank == 0:
world_pad_info.append(pad_info) # type: ignore world_pad_info.append(pad_info)
return world_pad_info return world_pad_info
def _gather_optim_state( def _gather_optim_state(
......
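The _broadcast_pad_info_to_r0 change above relies on dist.broadcast_object_list, whose calling convention is easy to get wrong: every rank passes a list of the same length, the source rank's list holds the real objects, and the call overwrites the other ranks' placeholder entries in place. A minimal single-process sketch (gloo backend, world_size = 1; the pad-info payload is hypothetical):
```
import os

import torch.distributed as dist

# single-process setup so the collective can run without launching workers
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

payload = [[3, 0, 5]]  # on non-source ranks this would be a placeholder, e.g. [[0]]
dist.broadcast_object_list(payload, src=0)  # mutates the list entries in place
assert payload[0] == [3, 0, 5]

dist.destroy_process_group()
```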
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
from collections import OrderedDict from collections import OrderedDict
import copy import copy
import io
from itertools import chain from itertools import chain
import logging import logging
from math import inf from math import inf
...@@ -17,7 +18,7 @@ from torch.nn import Parameter ...@@ -17,7 +18,7 @@ from torch.nn import Parameter
from torch.optim import SGD, Optimizer from torch.optim import SGD, Optimizer
from fairscale.nn.misc import ParamBucket from fairscale.nn.misc import ParamBucket
from fairscale.utils.params import broadcast_object, calc_grad_norm, get_global_rank, recursive_copy_to_device from fairscale.utils.params import calc_grad_norm, get_global_rank, recursive_copy_to_device
__all__ = ["OSS"] __all__ = ["OSS"]
...@@ -27,6 +28,54 @@ else: ...@@ -27,6 +28,54 @@ else:
_params_t = Any _params_t = Any
_gpu_is_old: Optional[bool] = None
def _gpu_capabilities_older_than_50() -> bool:
"""Return True if the GPU's compute capability is older than SM50."""
global _gpu_is_old
if _gpu_is_old is None:
for i in range(torch.cuda.device_count()):
major, minor = torch.cuda.get_device_capability(f"cuda:{i}")
if major <= 5:
_gpu_is_old = True
if _gpu_is_old is None:
_gpu_is_old = False
return _gpu_is_old
def _broadcast_object(
obj: Any, src_rank: int, group: object = dist.group.WORLD, dist_device: torch.device = torch.device("cpu")
) -> Any:
"""
Either broadcast from master to the fleet (default),
or use the src setting as the original rank.
This is only needed for some older GPUs where dist.broadcast_object_list seems to hang. Also
the hang behavior persists across processes once it happens, i.e. once we call dist.broadcast_object_list,
subsequent calls with _broadcast_object also hang.
"""
if dist.get_rank() == src_rank:
# Emit data
buffer = io.BytesIO()
torch.save(obj, buffer)
data = bytearray(buffer.getbuffer())
length_tensor = torch.LongTensor([len(data)]).to(dist_device)
data_send_tensor = torch.ByteTensor(data).to(dist_device)
dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False)
dist.broadcast(data_send_tensor, src=src_rank, group=group, async_op=False)
else:
# Fetch from the source
length_tensor = torch.LongTensor([0]).to(dist_device)
dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False)
data_recv_tensor = torch.empty([int(length_tensor.item())], dtype=torch.uint8, device=dist_device)
dist.broadcast(data_recv_tensor, src=src_rank, group=group, async_op=False)
buffer = io.BytesIO(data_recv_tensor.cpu().numpy())
obj = torch.load(buffer, map_location=dist_device)
return obj
class OSS(Optimizer): class OSS(Optimizer):
"""Wraps an arbitrary :class:`optim.Optimizer <torch.optim.Optimizer>` """Wraps an arbitrary :class:`optim.Optimizer <torch.optim.Optimizer>`
optimizer and shards its state as described by ZeRO_. optimizer and shards its state as described by ZeRO_.
...@@ -285,17 +334,30 @@ class OSS(Optimizer): ...@@ -285,17 +334,30 @@ class OSS(Optimizer):
if should_send_state if should_send_state
else torch.tensor([0], dtype=torch.uint8, device=dist_device) else torch.tensor([0], dtype=torch.uint8, device=dist_device)
) )
broadcast_object( if _gpu_capabilities_older_than_50():
state_to_share, src_rank=self.global_rank, group=self.group, dist_device=dist_device, _broadcast_object(
state_to_share, src_rank=self.global_rank, group=self.group, dist_device=dist_device
)
else:
obj_list = [state_to_share]
dist.broadcast_object_list(
obj_list, src=self.global_rank, group=self.group,
) )
else: else:
# Fetch the optim state from the other replicas # Fetch the optim state from the other replicas
replica_state = broadcast_object( if _gpu_capabilities_older_than_50():
replica_state = _broadcast_object(
torch.tensor([0], dtype=torch.uint8, device=dist_device), torch.tensor([0], dtype=torch.uint8, device=dist_device),
src_rank=self._local_to_global_rank[rank], src_rank=self._local_to_global_rank[rank],
group=self.group, group=self.group,
dist_device=dist_device, dist_device=dist_device,
) )
else:
obj_list = [torch.tensor([0], dtype=torch.uint8, device=dist_device)]
dist.broadcast_object_list(
obj_list, src=self._local_to_global_rank[rank], group=self.group,
)
replica_state = obj_list[0]
if should_collect_state: if should_collect_state:
self._all_states.append( self._all_states.append(
......
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from collections import abc from collections import abc
import io
from math import inf from math import inf
from typing import Any, Callable, Dict, List, Optional from typing import Any, Callable, Dict, List, Optional
...@@ -56,36 +55,6 @@ def recursive_copy_to_device(value: Any, non_blocking: bool, device: torch.devic ...@@ -56,36 +55,6 @@ def recursive_copy_to_device(value: Any, non_blocking: bool, device: torch.devic
return value return value
# backward compatibility - this is needed for torch 1.5 which does not expose this functionality
# FIXME: to be dropped alongside torch1.5 support, when time comes
def broadcast_object(
obj: Any, src_rank: int, group: object = dist.group.WORLD, dist_device: torch.device = torch.device("cpu")
) -> Any:
"""
Either broadcast from master to the fleet (default),
or use the src setting as the original rank.
"""
if dist.get_rank() == src_rank:
# Emit data
buffer = io.BytesIO()
torch.save(obj, buffer)
data = bytearray(buffer.getbuffer())
length_tensor = torch.LongTensor([len(data)]).to(dist_device)
data_send_tensor = torch.ByteTensor(data).to(dist_device)
dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False)
dist.broadcast(data_send_tensor, src=src_rank, group=group, async_op=False)
else:
# Fetch from the source
length_tensor = torch.LongTensor([0]).to(dist_device)
dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False)
data_recv_tensor = torch.empty([int(length_tensor.item())], dtype=torch.uint8, device=dist_device)
dist.broadcast(data_recv_tensor, src=src_rank, group=group, async_op=False)
buffer = io.BytesIO(data_recv_tensor.cpu().numpy())
obj = torch.load(buffer, map_location=dist_device)
return obj
def calc_grad_norm(parameters: List[torch.nn.Parameter], p: float) -> torch.Tensor: def calc_grad_norm(parameters: List[torch.nn.Parameter], p: float) -> torch.Tensor:
r"""Calculate gradient norm of an iterable of parameters. r"""Calculate gradient norm of an iterable of parameters.
Returns: Returns:
......
...@@ -27,7 +27,7 @@ class Bucket: ...@@ -27,7 +27,7 @@ class Bucket:
return return
# reduce-scatter bucket # reduce-scatter bucket
if hasattr(dist, "_reduce_scatter_base"): if hasattr(dist, "_reduce_scatter_base"):
dist._reduce_scatter_base( # type: ignore dist._reduce_scatter_base(
self.output_shard[: self.offset], self.data[:, : self.offset].contiguous(), group=self.group self.output_shard[: self.offset], self.data[:, : self.offset].contiguous(), group=self.group
) )
else: else:
...@@ -132,7 +132,7 @@ class ReduceScatterBucketer: ...@@ -132,7 +132,7 @@ class ReduceScatterBucketer:
output = torch.zeros_like(input_list[0]) output = torch.zeros_like(input_list[0])
if hasattr(dist, "_reduce_scatter_base"): if hasattr(dist, "_reduce_scatter_base"):
input_flattened = torch.cat(input_list) input_flattened = torch.cat(input_list)
dist._reduce_scatter_base(output, input_flattened, group=group) # type: ignore dist._reduce_scatter_base(output, input_flattened, group=group)
else: else:
# fallback # fallback
dist.reduce_scatter(output, input_list, group=group) dist.reduce_scatter(output, input_list, group=group)
......
...@@ -79,4 +79,4 @@ if __name__ == "__main__": ...@@ -79,4 +79,4 @@ if __name__ == "__main__":
# Bump this number if you want to force a CI cache invalidation on the pip venv. # Bump this number if you want to force a CI cache invalidation on the pip venv.
# CI cache version: 4 # CI cache version: 8
...@@ -29,6 +29,7 @@ from . import optim as optim ...@@ -29,6 +29,7 @@ from . import optim as optim
from . import nn as nn from . import nn as nn
from . import testing as testing from . import testing as testing
from . import utils as utils from . import utils as utils
from . import jit as jit
#MODIFIED BY TORCHGPIPE #MODIFIED BY TORCHGPIPE
from . import backends from . import backends
......
...@@ -46,6 +46,9 @@ def all_to_all_single(output: Tensor, input: Tensor, output_split_size: Optional ...@@ -46,6 +46,9 @@ def all_to_all_single(output: Tensor, input: Tensor, output_split_size: Optional
def all_reduce(tensor: Tensor, op: ReduceOp = ReduceOp.SUM, group:Optional[ProcessGroup] = None, async_op: bool = False): ... def all_reduce(tensor: Tensor, op: ReduceOp = ReduceOp.SUM, group:Optional[ProcessGroup] = None, async_op: bool = False): ...
def all_gather(tensor_list: List[Tensor], tensor: Tensor, group:Optional[ProcessGroup] = None, async_op: bool = False): ... def all_gather(tensor_list: List[Tensor], tensor: Tensor, group:Optional[ProcessGroup] = None, async_op: bool = False): ...
def reduce_scatter(tensor: Tensor, input_list: List[Tensor], op: ReduceOp = ReduceOp.SUM, group:Optional[ProcessGroup] = None, async_op: bool = False): ... def reduce_scatter(tensor: Tensor, input_list: List[Tensor], op: ReduceOp = ReduceOp.SUM, group:Optional[ProcessGroup] = None, async_op: bool = False): ...
# These two functions take flattened tensors directly, avoiding internal buffer allocation overheads.
def _all_gather_base(input_tensor: Tensor, output_tensor: Tensor, group:Optional[ProcessGroup] = None): ...
def _reduce_scatter_base(output_tensor: Tensor, input_tensor: Tensor, group:Optional[ProcessGroup] = None): ...
def destroy_process_group() -> None: ... def destroy_process_group() -> None: ...
......
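The two private collectives stubbed above differ from dist.all_gather / dist.reduce_scatter in that they take a single flat buffer instead of a list of per-rank tensors, which is what avoids the extra chunking. A shape-only sketch of the expected layout (no process group involved; world_size of 4 is assumed):
```
import torch

world_size = 4
shard = torch.randn(10)                           # each rank's local contribution
output = torch.empty(world_size * shard.numel())  # flat buffer for _all_gather_base
# the list-based fallback fills per-rank views of that same buffer instead
chunks = list(output.chunk(world_size))
assert len(chunks) == world_size and chunks[0].numel() == shard.numel()

# _reduce_scatter_base similarly takes one concatenated input per rank
input_list = [torch.randn(8) for _ in range(world_size)]
input_flattened = torch.cat(input_list)
assert input_flattened.numel() == world_size * input_list[0].numel()
```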
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import Callable
def script(fn: Callable) -> Callable: ...
...@@ -113,6 +113,10 @@ class DistributedTest(unittest.TestCase): ...@@ -113,6 +113,10 @@ class DistributedTest(unittest.TestCase):
else: else:
assert next(model.parameters()).device == torch.device("cpu") assert next(model.parameters()).device == torch.device("cpu")
shard_loss = cls._train_for_several_steps(model, num_steps, autocast, lr=lr, norm_type=norm_type) shard_loss = cls._train_for_several_steps(model, num_steps, autocast, lr=lr, norm_type=norm_type)
if config.get("cpu_offload", False):
# In pytorch 1.10, assert_allclose below checks for tensor device match. Therefore,
# we need to move the CPU tensor to CUDA in case we are doing cpu_offload.
shard_loss = shard_loss.cuda()
shard_state_dict = model.state_dict() shard_state_dict = model.state_dict()
try: try:
......
...@@ -92,7 +92,6 @@ class TestOptimizerUtils(DistributedTest): ...@@ -92,7 +92,6 @@ class TestOptimizerUtils(DistributedTest):
tstart = time() tstart = time()
sd = fsdp.gather_full_optim_state_dict(fsdp_optim, recipient_rank=0) sd = fsdp.gather_full_optim_state_dict(fsdp_optim, recipient_rank=0)
duration = time() - tstart duration = time() - tstart
# Switching from fairscale.utils.params.broadcast_object to torch.broadcast_object_list will cause this to raise
assert duration < fsdp.world_size, f"gather optim state took {duration} seconds, suspect change in _consolidate" assert duration < fsdp.world_size, f"gather optim state took {duration} seconds, suspect change in _consolidate"
cuda_gb_after = torch.cuda.memory_stats(fsdp.rank)["allocated_bytes.all.current"] / 1024 ** 3 cuda_gb_after = torch.cuda.memory_stats(fsdp.rank)["allocated_bytes.all.current"] / 1024 ** 3
......
...@@ -22,7 +22,6 @@ import torch.multiprocessing as mp ...@@ -22,7 +22,6 @@ import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
import fairscale.optim as optim import fairscale.optim as optim
import fairscale.utils as utils
from fairscale.utils import torch_version from fairscale.utils import torch_version
from fairscale.utils.testing import ( from fairscale.utils.testing import (
check_same_model_params, check_same_model_params,
...@@ -36,15 +35,6 @@ BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO ...@@ -36,15 +35,6 @@ BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO
DEVICE = "cuda" if torch.cuda.is_available() else torch.device("cpu") DEVICE = "cuda" if torch.cuda.is_available() else torch.device("cpu")
RECIPIENT_RANK = 1 RECIPIENT_RANK = 1
try:
from torch.distributed import broadcast_object_list # noqa
_torch_broadcast_object = True
except ImportError:
from fairscale.utils.params import broadcast_object # noqa
_torch_broadcast_object = False
def dist_init(rank, world_size, tempfile_name, backend=BACKEND): def dist_init(rank, world_size, tempfile_name, backend=BACKEND):
url = "file://" + tempfile_name url = "file://" + tempfile_name
...@@ -52,15 +42,9 @@ def dist_init(rank, world_size, tempfile_name, backend=BACKEND): ...@@ -52,15 +42,9 @@ def dist_init(rank, world_size, tempfile_name, backend=BACKEND):
def sync_object_ranks(something_to_sync: Any, reference_rank: int, device: torch.device) -> Any: def sync_object_ranks(something_to_sync: Any, reference_rank: int, device: torch.device) -> Any:
if _torch_broadcast_object:
package = [something_to_sync] package = [something_to_sync]
dist.broadcast_object_list(package, src=reference_rank, group=dist.group.WORLD) dist.broadcast_object_list(package, src=reference_rank, group=dist.group.WORLD)
package_sync = package[0] package_sync = package[0]
else:
package_sync = utils.params.broadcast_object(
something_to_sync, src_rank=reference_rank, group=dist.group.WORLD, dist_device=device
)
return package_sync return package_sync
......