"docs/vscode:/vscode.git/clone" did not exist on "bbd61829a2591b5c8fdfd547565bfd3c08d1d582"
Unverified Commit 3ecf76f4 authored by Min Xu, committed by GitHub

[cleanup] CI test updates; mypy cleanup; partial broadcast_object cleanup; pre-commit documentation (#744)

* changelog; mypy; oss cleanup

* more broadcast_object cleanup in FSDP

* one more mypy fix

* retire pytorch 1.6 from circleci, add new lightly, add 1.8 LTS and 1.9 stable release

* update torch version for LTS

* minor fixes

* update cache key

* trying newer gpu VMs

* bump the cache

* update to gpu.medium, which should be 2 GPUs

* update nightly version

* add pre-commit instruction

* fixed CHANGELOG after merging

* updated to newer nightly

* retained the older broadcast function for older GPUs for oss.py

* fixed a bug

* added a comment

* fixing a test for pytorch 1.10

* testing a fix

* Update fairscale/optim/oss.py

* Update CONTRIBUTING.md
Co-authored-by: Min Xu <min.xu.public@gmail.com>
parent 95d31d4d
......@@ -35,15 +35,17 @@ gpu: &gpu
CUDA_VERSION: "10.2"
CUDA_HOME: /usr/local/cuda-10.2
machine:
# This image actually has cuda-11.1 installed, but it doesn't seem to affect us
# using pytorch cu10 builds below.
image: ubuntu-1604-cuda-10.2:202012-01
resource_class: gpu.large
gpu_cu111: &gpu_cu111
environment:
CUDA_VERSION: "11.1"
CUDA_HOME: /usr/local/cuda-11.1
CUDA_VERSION: "11.2"
CUDA_HOME: /usr/local/cuda-11.2
machine:
image: ubuntu-1604-cuda-11.1:202012-01
image: ubuntu-2004-cuda-11.2:202103-01
resource_class: gpu.large
# -------------------------------------------------------------------------------------
......@@ -62,21 +64,6 @@ setup_venv: &setup_venv
which pip
pip install --upgrade pip
install_dep_160: &install_dep_160
- run:
name: Install Dependencies with torch 1.6.0
command: |
# check if we have restored venv cache (/home/circleci/venv) correctly, if so, just skip
if [ -f /home/circleci/venv/check_version.py ]; then python /home/circleci/venv/check_version.py torch eq 1.6 && exit 0; fi
# start installing
pip install --progress-bar off torch==1.6.0+cu101 torchvision==0.7.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html
pip install --progress-bar off -r requirements-test.txt
pip install --progress-bar off -r requirements-benchmarks.txt
python -c 'import torch; print("Torch version:", torch.__version__)'
python -c 'import torch; assert torch.__version__.split(".")[:2] == ["1", "6"], "wrong torch version"'
python -m torch.utils.collect_env
wget -O /home/circleci/venv/check_version.py https://raw.githubusercontent.com/min-xu-ai/check_verion/main/check_version.py
install_dep_171: &install_dep_171
- run:
name: Install Dependencies with torch 1.7.1
......@@ -94,12 +81,12 @@ install_dep_171: &install_dep_171
install_dep_181: &install_dep_181
- run:
name: Install Dependencies with torch 1.8.1
name: Install Dependencies with torch 1.8.1 (LTS)
command: |
# check if we have restored venv cache (/home/circleci/venv) correctly, if so, just skip
if [ -f /home/circleci/venv/check_version.py ]; then python /home/circleci/venv/check_version.py torch eq 1.8 && exit 0; fi
# start installing
pip install --progress-bar off torch==1.8.1+cu101 torchvision==0.9.1+cu101 -f https://download.pytorch.org/whl/torch_stable.html
pip install --progress-bar off torch==1.8.1+cu102 torchvision==0.9.1+cu102 -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
pip install --progress-bar off -r requirements-test.txt
pip install --progress-bar off -r requirements-benchmarks.txt
python -c 'import torch; print("Torch version:", torch.__version__)'
......@@ -112,9 +99,9 @@ install_dep_190: &install_dep_190
name: Install Dependencies with torch 1.9.0
command: |
# check if we have restored venv cache (/home/circleci/venv) correctly, if so, just skip
if [ -f /home/circleci/venv/check_version.py ]; then python /home/circleci/venv/check_version.py torch eq 1.8 && exit 0; fi
if [ -f /home/circleci/venv/check_version.py ]; then python /home/circleci/venv/check_version.py torch eq 1.9 && exit 0; fi
# start installing
pip install --progress-bar off install torch==1.9.0+cu102 torchvision==0.10.0+cu102 -f https://download.pytorch.org/whl/torch_stable.html
pip install --progress-bar off torch==1.9.0+cu111 torchvision==0.10.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html
pip install --progress-bar off -r requirements-test.txt
pip install --progress-bar off -r requirements-benchmarks.txt
python -c 'import torch; print("Torch version:", torch.__version__)'
......@@ -122,6 +109,21 @@ install_dep_190: &install_dep_190
python -m torch.utils.collect_env
wget -O /home/circleci/venv/check_version.py https://raw.githubusercontent.com/min-xu-ai/check_verion/main/check_version.py
install_dep_pytorch_nightly: &install_dep_pytorch_nightly
- run:
name: Install Dependencies with a torch nightly preview build
command: |
# check if we have restored venv cache (/home/circleci/venv) correctly, if so, just skip
if [ -f /home/circleci/venv/check_version.py ]; then python /home/circleci/venv/check_version.py torch eq 1.10 && exit 0; fi
# start installing
pip install --progress-bar off --pre torch==1.10.0.dev20210901+cu111 torchvision==0.11.0.dev20210901+cu111 -f https://download.pytorch.org/whl/nightly/cu111/torch_nightly.html
pip install --progress-bar off -r requirements-test.txt
pip install --progress-bar off -r requirements-benchmarks.txt
python -c 'import torch; print("Torch version:", torch.__version__)'
python -c 'import torch; assert torch.__version__.split(".")[:2] == ["1", "10"], "wrong torch version"'
python -m torch.utils.collect_env
wget -O /home/circleci/venv/check_version.py https://raw.githubusercontent.com/min-xu-ai/check_verion/main/check_version.py
install_repo: &install_repo
- run:
name: Install Repository
......@@ -337,14 +339,14 @@ jobs:
# Cache the venv directory that contains dependencies
- restore_cache:
keys:
- cache-key-cpu-py39-181-0-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
- cache-key-cpu-py39-181-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
- <<: *install_dep_181
- save_cache:
paths:
- ~/venv
key: cache-key-cpu-py39-181-0-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
key: cache-key-cpu-py39-181-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
- <<: *install_repo
......@@ -358,14 +360,13 @@ jobs:
- store_test_results:
path: test-results
gpu_tests_160:
gpu_tests_171:
parameters:
test_list_file:
type: string
default: "/dev/non_exist"
<<: *gpu
<<: *gpu_cu111
working_directory: ~/fairscale
......@@ -374,22 +375,23 @@ jobs:
- run: nvidia-smi
# Run this to make sure we use python3 from the system.
- setup_pyenv:
version: 3.7.0
version: 3.8.6
- <<: *setup_venv
# Cache the venv directory that contains dependencies
- restore_cache:
keys:
- cache-key-gpu-160-101-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
- cache-key-py38-gpu-171-111-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
- <<: *install_dep_160
- <<: *install_dep_171
- save_cache:
paths:
- ~/venv
key: cache-key-gpu-160-101-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
key: cache-key-py38-gpu-171-111-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
- <<: *install_repo
......@@ -401,13 +403,13 @@ jobs:
- <<: *upload_coverage
gpu_tests_171:
gpu_tests_181:
parameters:
test_list_file:
type: string
default: "/dev/non_exist"
<<: *gpu_cu111
<<: *gpu
working_directory: ~/fairscale
......@@ -418,21 +420,21 @@ jobs:
# Run this to make sure we use python3 from the system.
- setup_pyenv:
version: 3.8.6
version: 3.7.0
- <<: *setup_venv
# Cache the venv directory that contains dependencies
- restore_cache:
keys:
- cache-key-gpu-171-110-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
- cache-key-py37-gpu-181-102-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
- <<: *install_dep_171
- <<: *install_dep_181
- save_cache:
paths:
- ~/venv
key: cache-key-gpu-171-110-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
key: cache-key-py37-gpu-181-102-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
- <<: *install_repo
......@@ -444,13 +446,13 @@ jobs:
- <<: *upload_coverage
gpu_tests_181:
gpu_tests_190:
parameters:
test_list_file:
type: string
default: "/dev/non_exist"
<<: *gpu
<<: *gpu_cu111
working_directory: ~/fairscale
......@@ -461,21 +463,21 @@ jobs:
# Run this to make sure we use python3 from the system.
- setup_pyenv:
version: 3.7.0
version: 3.8.6
- <<: *setup_venv
# Cache the venv directory that contains dependencies
- restore_cache:
keys:
- cache-key-gpu-181-101-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
- cache-key-py38-gpu-190-111-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
- <<: *install_dep_181
- <<: *install_dep_190
- save_cache:
paths:
- ~/venv
key: cache-key-gpu-181-101-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
key: cache-key-py38-gpu-190-111-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
- <<: *install_repo
......@@ -485,15 +487,13 @@ jobs:
- store_test_results:
path: test-results
- <<: *upload_coverage
gpu_tests_190:
gpu_tests_pytorch_nightly:
parameters:
test_list_file:
type: string
default: "/dev/non_exist"
<<: *gpu
<<: *gpu_cu111
working_directory: ~/fairscale
......@@ -504,21 +504,21 @@ jobs:
# Run this to make sure we use python3 from the system.
- setup_pyenv:
version: 3.7.0
version: 3.8.6
- <<: *setup_venv
# Cache the venv directory that contains dependencies
- restore_cache:
keys:
- cache-key-gpu-190-101-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
- cache-key-py38-gpu-pytorch-nightly-111-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
- <<: *install_dep_190
- <<: *install_dep_pytorch_nightly
- save_cache:
paths:
- ~/venv
key: cache-key-gpu-190-101-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
key: cache-key-py38-gpu-pytorch-nightly-111-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
- <<: *install_repo
......@@ -546,7 +546,7 @@ jobs:
# Cache the venv directory that contains dependencies
- restore_cache:
keys:
- cache-key-benchmarks-181-101-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
- cache-key-py37-benchmarks-181-101-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
# Cache the MNIST directory that contains benchmark data
- restore_cache:
......@@ -558,7 +558,7 @@ jobs:
- save_cache:
paths:
- ~/venv
key: cache-key-benchmarks-181-101-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
key: cache-key-py37-benchmarks-181-101-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
- <<: *install_repo
......@@ -595,7 +595,7 @@ jobs:
# Cache the venv directory that contains dependencies
- restore_cache:
keys:
- cache-key-benchmarks-181-101-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
- cache-key-py37-benchmarks-181-101-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
# Cache the MNIST directory that contains benchmark data
......@@ -608,7 +608,7 @@ jobs:
- save_cache:
paths:
- ~/venv
key: cache-key-benchmarks-181-101-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
key: cache-key-py37-benchmarks-181-101-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
- <<: *install_repo
......@@ -627,29 +627,29 @@ workflows:
- cpu_tests_py37
- cpu_tests_py38
- cpu_tests_py39
- gpu_tests_160:
test_list_file: tests/ci_test_list_1.txt
- gpu_tests_171:
test_list_file: tests/ci_test_list_1.txt
- gpu_tests_181:
test_list_file: tests/ci_test_list_1.txt
- gpu_tests_190:
test_list_file: tests/ci_test_list_1.txt
- gpu_tests_160:
test_list_file: tests/ci_test_list_2.txt
- gpu_tests_pytorch_nightly:
test_list_file: tests/ci_test_list_1.txt
- gpu_tests_171:
test_list_file: tests/ci_test_list_2.txt
- gpu_tests_181:
test_list_file: tests/ci_test_list_2.txt
- gpu_tests_190:
test_list_file: tests/ci_test_list_2.txt
- gpu_tests_160:
test_list_file: tests/ci_test_list_3.txt
- gpu_tests_pytorch_nightly:
test_list_file: tests/ci_test_list_2.txt
- gpu_tests_171:
test_list_file: tests/ci_test_list_3.txt
- gpu_tests_181:
test_list_file: tests/ci_test_list_3.txt
- gpu_tests_190:
test_list_file: tests/ci_test_list_3.txt
- gpu_tests_pytorch_nightly:
test_list_file: tests/ci_test_list_3.txt
- benchmarks_1
- benchmarks_2
......@@ -17,7 +17,7 @@ We actively welcome your pull requests.
2. If you've added code that should be tested, add tests.
3. If you've changed APIs, update the documentation.
4. Ensure the test suite passes.
5. Make sure your code lints.
5. Make sure your code passes static analysis (see below).
6. If you haven't already, complete the Contributor License Agreement ("CLA").
## Contributor License Agreement ("CLA")
......@@ -50,27 +50,36 @@ outlined on that page and do not file a public issue.
* We follow the [PEP8](https://www.python.org/dev/peps/pep-0008/) style guide.
* In your editor, install the [editorconfig](https://editorconfig.org/) extension
which should ensure that you are following the same standards as us.
* Please run black and isort before opening up your PR.
```
black .
isort .
flake8
```
* Please read the [editorconfig](.editorconfig) file to understand the exact coding style preferences.
* Please place Python code related to models in fairscale/nn. Place Python code related to optimizers
in fairscale/optim. Place C++ extensions in fairscale/clib.
* Please put `__all__:List[str] = []` in new `__init__.py` files for consistent importing behavior
and less development overhead in maintaining an importing list.
* Please set up pre-commit before opening up your PR.
## Testing
### Pre-commit
### Static analysis
```
pip install -r requirements-dev.txt
pre-commit install
```
After the above, your `git commit` command will automatically trigger pre-commit
checks, which are static code analysis tools we use.
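You can also run all the installed hooks by hand on the entire tree with `pre-commit run --all-files` (a standard pre-commit CLI command; the exact hook set is defined by the repository's pre-commit configuration).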
### Run static analysis by hand (without using pre-commit)
Note that trailing spaces are not checked by the manual commands below, but they are checked by the pre-commit hooks above.
```
black .
isort .
flake8
mypy --ignore-missing-imports --scripts-are-modules --pretty .
```
## Testing
### Unit tests
```
......
......@@ -47,8 +47,8 @@ def _calculate_stats(input: Tensor, eps: float, process_group: ProcessGroup) ->
if torch_version()[:2] >= (1, 7):
_forward = torch.jit.script(_forward) # type: ignore
_track_running_stats = torch.jit.script(_track_running_stats) # type: ignore
_forward = torch.jit.script(_forward)
_track_running_stats = torch.jit.script(_track_running_stats)
class _SyncBatchNormFunction(torch.autograd.Function):
......
......@@ -46,7 +46,7 @@ from fairscale.utils.parallel import (
get_process_group_cached,
validate_process_group,
)
from fairscale.utils.params import broadcast_object, calc_grad_norm, recursive_copy_to_device
from fairscale.utils.params import calc_grad_norm, recursive_copy_to_device
from fairscale.utils.reduce_scatter_bucketer import ReduceScatterBucketer
from fairscale.utils.state_dict import replace_by_prefix_
......@@ -526,13 +526,14 @@ class FullyShardedDataParallel(nn.Module):
if self.move_grads_to_cpu:
total_norm = total_norm.cpu()
# Now multiply each grad by (max_norm/total_norm), same as torch 1.7 https://tinyurl.com/3wtxhhqq)
clip_coef = torch.tensor(max_norm, dtype=total_norm.dtype, device=total_norm.device) / (total_norm + 1e-6)
if clip_coef < 1:
# multiply by clip_coef
for p in params_with_grad:
p.grad.detach().mul_(clip_coef.to(p.grad.device)) # type: ignore
assert p.grad is not None
p.grad.detach().mul_(clip_coef.to(p.grad.device))
return total_norm
......@@ -1566,7 +1567,7 @@ class FullyShardedDataParallel(nn.Module):
# Fill output_tensor with (p.data for each shard in self.world_size)
if hasattr(dist, "_all_gather_base"):
# New version of PyTorch has all_gather_base, which is faster than chunk and then all_gather.
dist._all_gather_base(output_tensor, p_data, group=self.process_group) # type: ignore
dist._all_gather_base(output_tensor, p_data, group=self.process_group)
else:
chunks = list(output_tensor.chunk(self.world_size))
dist.all_gather(chunks, p_data, group=self.process_group)
......@@ -1828,19 +1829,17 @@ class FullyShardedDataParallel(nn.Module):
raise ValueError(msg)
def _broadcast_pad_info_to_r0(self) -> List[List[List[int]]]:
"""Collect [x.numel_padded_per_param for x in self._fsdp_instances] from teach rank."""
dummy_tensor = torch.tensor([0], dtype=torch.uint8, device=self.compute_device)
"""Collect [x.numel_padded_per_param for x in self._fsdp_instances] from each rank."""
world_pad_info: List[List[List[int]]] = [] # this will contain values from the whole world.
my_pad_info: List[List[int]] = [cast(List[int], m.numel_padded_per_param) for m in self._fsdp_instances]
for rank in range(self.world_size):
if rank == self.rank:
pad_info = [m.numel_padded_per_param for m in self._fsdp_instances]
pad_info = my_pad_info
else:
pad_info = dummy_tensor # type: ignore
pad_info = broadcast_object(
pad_info, src_rank=rank, group=self.process_group, dist_device=self.compute_device
)
pad_info = [[0]] * len(my_pad_info)
dist.broadcast_object_list(pad_info, src=rank, group=self.process_group)
if self.rank == 0:
world_pad_info.append(pad_info) # type: ignore
world_pad_info.append(pad_info)
return world_pad_info
def _gather_optim_state(
......
......@@ -5,6 +5,7 @@
from collections import OrderedDict
import copy
import io
from itertools import chain
import logging
from math import inf
......@@ -17,7 +18,7 @@ from torch.nn import Parameter
from torch.optim import SGD, Optimizer
from fairscale.nn.misc import ParamBucket
from fairscale.utils.params import broadcast_object, calc_grad_norm, get_global_rank, recursive_copy_to_device
from fairscale.utils.params import calc_grad_norm, get_global_rank, recursive_copy_to_device
__all__ = ["OSS"]
......@@ -27,6 +28,54 @@ else:
_params_t = Any
_gpu_is_old: Optional[bool] = None
def _gpu_capabilities_older_than_50() -> bool:
"""Return True if the GPU's compute capability is older than SM50."""
global _gpu_is_old
if _gpu_is_old is None:
for i in range(torch.cuda.device_count()):
major, minor = torch.cuda.get_device_capability(f"cuda:{i}")
if major <= 5:
_gpu_is_old = True
if _gpu_is_old is None:
_gpu_is_old = False
return _gpu_is_old
def _broadcast_object(
obj: Any, src_rank: int, group: object = dist.group.WORLD, dist_device: torch.device = torch.device("cpu")
) -> Any:
"""
Either broadcast from master to the fleet (default),
or use the src setting as the original rank.
This is only needed for some older GPUs where dist.broadcast_object_list seems to hang. The hang
behavior also persists across processes once it happens, i.e. once we call dist.broadcast_object_list,
subsequent calls with _broadcast_object also hang.
"""
if dist.get_rank() == src_rank:
# Emit data
buffer = io.BytesIO()
torch.save(obj, buffer)
data = bytearray(buffer.getbuffer())
length_tensor = torch.LongTensor([len(data)]).to(dist_device)
data_send_tensor = torch.ByteTensor(data).to(dist_device)
dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False)
dist.broadcast(data_send_tensor, src=src_rank, group=group, async_op=False)
else:
# Fetch from the source
length_tensor = torch.LongTensor([0]).to(dist_device)
dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False)
data_recv_tensor = torch.empty([int(length_tensor.item())], dtype=torch.uint8, device=dist_device)
dist.broadcast(data_recv_tensor, src=src_rank, group=group, async_op=False)
buffer = io.BytesIO(data_recv_tensor.cpu().numpy())
obj = torch.load(buffer, map_location=dist_device)
return obj
class OSS(Optimizer):
"""Wraps an arbitrary :class:`optim.Optimizer <torch.optim.Optimizer>`
optimizer and shards its state as described by ZeRO_.
......@@ -285,17 +334,30 @@ class OSS(Optimizer):
if should_send_state
else torch.tensor([0], dtype=torch.uint8, device=dist_device)
)
broadcast_object(
state_to_share, src_rank=self.global_rank, group=self.group, dist_device=dist_device,
)
if _gpu_capabilities_older_than_50():
_broadcast_object(
state_to_share, src_rank=self.global_rank, group=self.group, dist_device=dist_device
)
else:
obj_list = [state_to_share]
dist.broadcast_object_list(
obj_list, src=self.global_rank, group=self.group,
)
else:
# Fetch the optim state from the other replicas
replica_state = broadcast_object(
torch.tensor([0], dtype=torch.uint8, device=dist_device),
src_rank=self._local_to_global_rank[rank],
group=self.group,
dist_device=dist_device,
)
if _gpu_capabilities_older_than_50():
replica_state = _broadcast_object(
torch.tensor([0], dtype=torch.uint8, device=dist_device),
src_rank=self._local_to_global_rank[rank],
group=self.group,
dist_device=dist_device,
)
else:
obj_list = [torch.tensor([0], dtype=torch.uint8, device=dist_device)]
dist.broadcast_object_list(
obj_list, src=self._local_to_global_rank[rank], group=self.group,
)
replica_state = obj_list[0]
if should_collect_state:
self._all_states.append(
......
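As a rough illustration of the dispatch pattern in the OSS hunk above, here is a minimal sketch of how it could be wrapped for reuse. The helper name `maybe_broadcast` is made up for this example and is not part of fairscale; it assumes torch.distributed is already initialized and that `_gpu_capabilities_older_than_50` and `_broadcast_object` from oss.py are in scope.
```
import torch
import torch.distributed as dist

def maybe_broadcast(obj, src_rank: int, group=None, device=torch.device("cuda")):
    """Broadcast `obj` from src_rank to every rank, avoiding the hang seen on pre-SM50 GPUs."""
    if _gpu_capabilities_older_than_50():
        # Manual two-phase broadcast (length tensor, then serialized bytes) for older GPUs
        # where dist.broadcast_object_list has been observed to hang.
        return _broadcast_object(obj, src_rank=src_rank, group=group, dist_device=device)
    obj_list = [obj]  # non-source ranks pass a placeholder object
    dist.broadcast_object_list(obj_list, src=src_rank, group=group)
    return obj_list[0]
```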
......@@ -4,7 +4,6 @@
# LICENSE file in the root directory of this source tree.
from collections import abc
import io
from math import inf
from typing import Any, Callable, Dict, List, Optional
......@@ -56,36 +55,6 @@ def recursive_copy_to_device(value: Any, non_blocking: bool, device: torch.devic
return value
# backward compatibility - this is needed for torch 1.5 which does not expose this functionality
# FIXME: to be dropped alongside torch1.5 support, when time comes
def broadcast_object(
obj: Any, src_rank: int, group: object = dist.group.WORLD, dist_device: torch.device = torch.device("cpu")
) -> Any:
"""
Either broadcast from master to the fleet (default),
or use the src setting as the original rank.
"""
if dist.get_rank() == src_rank:
# Emit data
buffer = io.BytesIO()
torch.save(obj, buffer)
data = bytearray(buffer.getbuffer())
length_tensor = torch.LongTensor([len(data)]).to(dist_device)
data_send_tensor = torch.ByteTensor(data).to(dist_device)
dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False)
dist.broadcast(data_send_tensor, src=src_rank, group=group, async_op=False)
else:
# Fetch from the source
length_tensor = torch.LongTensor([0]).to(dist_device)
dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False)
data_recv_tensor = torch.empty([int(length_tensor.item())], dtype=torch.uint8, device=dist_device)
dist.broadcast(data_recv_tensor, src=src_rank, group=group, async_op=False)
buffer = io.BytesIO(data_recv_tensor.cpu().numpy())
obj = torch.load(buffer, map_location=dist_device)
return obj
def calc_grad_norm(parameters: List[torch.nn.Parameter], p: float) -> torch.Tensor:
r"""Calculate gradient norm of an iterable of parameters.
Returns:
......
......@@ -27,7 +27,7 @@ class Bucket:
return
# reduce-scatter bucket
if hasattr(dist, "_reduce_scatter_base"):
dist._reduce_scatter_base( # type: ignore
dist._reduce_scatter_base(
self.output_shard[: self.offset], self.data[:, : self.offset].contiguous(), group=self.group
)
else:
......@@ -132,7 +132,7 @@ class ReduceScatterBucketer:
output = torch.zeros_like(input_list[0])
if hasattr(dist, "_reduce_scatter_base"):
input_flattened = torch.cat(input_list)
dist._reduce_scatter_base(output, input_flattened, group=group) # type: ignore
dist._reduce_scatter_base(output, input_flattened, group=group)
else:
# fallback
dist.reduce_scatter(output, input_list, group=group)
......
......@@ -79,4 +79,4 @@ if __name__ == "__main__":
# Bump this number if you want to force a CI cache invalidation on the pip venv.
# CI cache version: 4
# CI cache version: 8
......@@ -29,6 +29,7 @@ from . import optim as optim
from . import nn as nn
from . import testing as testing
from . import utils as utils
from . import jit as jit
#MODIFIED BY TORCHGPIPE
from . import backends
......
......@@ -46,6 +46,9 @@ def all_to_all_single(output: Tensor, input: Tensor, output_split_size: Optional
def all_reduce(tensor: Tensor, op: ReduceOp = ReduceOp.SUM, group:Optional[ProcessGroup] = None, async_op: bool = False): ...
def all_gather(tensor_list: List[Tensor], tensor: Tensor, group:Optional[ProcessGroup] = None, async_op: bool = False): ...
def reduce_scatter(tensor: Tensor, input_list: List[Tensor], op: ReduceOp = ReduceOp.SUM, group:Optional[ProcessGroup] = None, async_op: bool = False): ...
# These two functions take flattened tensors directly, avoiding internal buffer allocation overheads.
def _all_gather_base(input_tensor: Tensor, output_tensor: Tensor, group:Optional[ProcessGroup] = None): ...
def _reduce_scatter_base(output_tensor: Tensor, input_tensor: Tensor, group:Optional[ProcessGroup] = None): ...
def destroy_process_group() -> None: ...
......
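To illustrate why the flat-buffer collectives in the stub above matter, here is a hedged sketch of the two gather paths that the FSDP hunk earlier switches between. The helper name `gather_full_param` is invented for this example; it assumes an initialized process group and a 1-D shard tensor per rank.
```
import torch
import torch.distributed as dist

def gather_full_param(p_data: torch.Tensor, world_size: int, group=None) -> torch.Tensor:
    # One flat output buffer large enough for every rank's shard.
    output = p_data.new_empty(p_data.numel() * world_size)
    if hasattr(dist, "_all_gather_base"):
        # Newer PyTorch: gather straight into the flat buffer, no per-rank chunk tensors.
        dist._all_gather_base(output, p_data, group=group)
    else:
        # Fallback: all_gather into views of the flat buffer.
        chunks = list(output.chunk(world_size))
        dist.all_gather(chunks, p_data, group=group)
    return output
```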
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import Callable
def script(fn: Callable) -> Callable: ...
......@@ -113,6 +113,10 @@ class DistributedTest(unittest.TestCase):
else:
assert next(model.parameters()).device == torch.device("cpu")
shard_loss = cls._train_for_several_steps(model, num_steps, autocast, lr=lr, norm_type=norm_type)
if config.get("cpu_offload", False):
# In pytorch 1.10, assert_allclose below checks for tensor device match. Therefore,
# we need to move the CPU tensor to CUDA in case we are doing cpu_offload.
shard_loss = shard_loss.cuda()
shard_state_dict = model.state_dict()
try:
......
......@@ -92,7 +92,6 @@ class TestOptimizerUtils(DistributedTest):
tstart = time()
sd = fsdp.gather_full_optim_state_dict(fsdp_optim, recipient_rank=0)
duration = time() - tstart
# Switching from fairscale.utils.params.broadcast_object to torch.broadcast_object_list will cause this to raise
assert duration < fsdp.world_size, f"gather optim state took {duration} seconds, suspect change in _consolidate"
cuda_gb_after = torch.cuda.memory_stats(fsdp.rank)["allocated_bytes.all.current"] / 1024 ** 3
......
......@@ -22,7 +22,6 @@ import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
import fairscale.optim as optim
import fairscale.utils as utils
from fairscale.utils import torch_version
from fairscale.utils.testing import (
check_same_model_params,
......@@ -36,15 +35,6 @@ BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO
DEVICE = "cuda" if torch.cuda.is_available() else torch.device("cpu")
RECIPIENT_RANK = 1
try:
from torch.distributed import broadcast_object_list # noqa
_torch_broadcast_object = True
except ImportError:
from fairscale.utils.params import broadcast_object # noqa
_torch_broadcast_object = False
def dist_init(rank, world_size, tempfile_name, backend=BACKEND):
url = "file://" + tempfile_name
......@@ -52,15 +42,9 @@ def dist_init(rank, world_size, tempfile_name, backend=BACKEND):
def sync_object_ranks(something_to_sync: Any, reference_rank: int, device: torch.device) -> Any:
if _torch_broadcast_object:
package = [something_to_sync]
dist.broadcast_object_list(package, src=reference_rank, group=dist.group.WORLD)
package_sync = package[0]
else:
package_sync = utils.params.broadcast_object(
something_to_sync, src_rank=reference_rank, group=dist.group.WORLD, dist_device=device
)
package = [something_to_sync]
dist.broadcast_object_list(package, src=reference_rank, group=dist.group.WORLD)
package_sync = package[0]
return package_sync
......