Unverified Commit 7d7edf6d authored by Anupam Bhatnagar, committed by GitHub

Setup pre-commit github action and apply pre-commit to all files (#849)

* adding pre-commit files

* applying pre-commit to all files

* adding no-strict-optional argument to mypy in circle ci config

* fix typo

* updating python versions

* [skip ci] remove extra args

* adding python 3.9

* [skip ci] set pre-commit version in requirements-dev.txt

* set CACHE_VERSION

* move linters from circleci to github actions

* update python version

* update python version in benchmarks_2

* moving to python 3.9.7
parent 6f3931a4
@@ -6,6 +6,14 @@
 # https://github.com/facebookresearch/detectron2/blob/main/.circleci/config.yml
 #
 # Pro tip: download circle ci cli to validate the config locally during development.
+#
+# To reset/clean the cache update the CACHE_VERSION variable in project settings
+# in the fairscale project in CircleCI. The CACHE_VERSION follows the convention
+# v$(FAIRSCALE_VERSION)-${CACHE_NUMBER}. E.g. v0.4.2-1. CACHE_NUMBER must start
+# at 1 and increase in whole numbers. When changing the CACHE_VERSION manually
+# always set the FAIRSCALE_VERSION value to the fairscale version being tested.
+# To reset the cache when not updating the fairscale version, only update the
+# CACHE_NUMBER value.
 version: 2.1

 orbs:
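For illustration, a minimal sketch of how a job threads CACHE_VERSION into its cache keys (the "cache-key-example" name is hypothetical; {{ .Environment.CACHE_VERSION }} and {{ checksum ... }} are standard CircleCI cache-key templates, and the real keys appear in the hunks below):

    - restore_cache:
        keys:
          # Bumping CACHE_VERSION in the CircleCI project settings invalidates this key.
          - cache-key-example-{{ .Environment.CACHE_VERSION }}-{{ checksum "setup.py" }}
    - save_cache:
        paths:
          - ~/venv
        key: cache-key-example-{{ .Environment.CACHE_VERSION }}-{{ checksum "setup.py" }}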
@@ -15,23 +23,26 @@ orbs:
 # -------------------------------------------------------------------------------------
 cpu_py37: &cpu_py37
   docker:
+    # python version 3.7.12
     - image: circleci/python:3.7
   resource_class: large

 cpu_py38: &cpu_py38
   docker:
+    # python version 3.8.12
     - image: circleci/python:3.8
   resource_class: large

 cpu_py39: &cpu_py39
   docker:
+    # python version 3.9.7
     - image: circleci/python:3.9
   resource_class: large

-# Here are list of GPU images:
+# Here is the list of GPU images:
 # https://circleci.com/docs/2.0/configuration-reference/#available-linux-gpu-images
-# We need to use multiple gpus for several jobs. the resource_class values are
-# available here T101565170
+# We need to use multiple gpus for several jobs. The resource_class
+# values are available here T101565170
 # gpu.nvidia.small.multi = 2 gpus with 16 GB ram each
 # gpu.nvidia.medium.multi = 4 gpus with 16 GB ram each
@@ -122,30 +133,6 @@ install_repo: &install_repo
     # Test import.
     python -c 'import sys; sys.path = sys.path[1:]; import fairscale'

-run_isort: &run_isort
-  - run:
-      name: Run Linter (isort)
-      command: |
-        isort . --check
-
-run_black: &run_black
-  - run:
-      name: Run Linter (black)
-      command: |
-        black --check .
-
-run_mypy: &run_mypy
-  - run:
-      name: Run type-checking (mypy)
-      command: |
-        mypy --ignore-missing-imports --scripts-are-modules --pretty .
-
-run_flake8: &run_flake8
-  - run:
-      name: Run Linter (flake8)
-      command: |
-        flake8 --show-source --statistics
-
 check_test_list: &check_test_list
   - run:
       name: Verify that unit test list files are correct
@@ -260,21 +247,16 @@ jobs:
       # Cache the venv directory that contains dependencies
       - restore_cache:
          keys:
-           - cache-key-cpu-py37-torch-1-10-0-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
+           - cache-key-cpu-py37-torch-1-10-0-{{.Environment.CACHE_VERSION }}-{{checksum "setup.py"}}-{{checksum "requirements-test.txt"}}

      - <<: *install_dep_1_10_0

      - save_cache:
          paths:
            - ~/venv
-         key: cache-key-cpu-py37-torch-1-10-0-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
+         key: cache-key-cpu-py37-torch-1-10-0-{{.Environment.CACHE_VERSION }}-{{checksum "setup.py"}}-{{checksum "requirements-test.txt"}}

      - <<: *install_repo

-     - <<: *run_isort
-     - <<: *run_black
-     - <<: *run_mypy
-     - <<: *run_flake8
      - <<: *run_unittests

      - <<: *run_doc_build
@@ -294,20 +276,15 @@
      # Cache the venv directory that contains dependencies
      - restore_cache:
          keys:
-           - cache-key-cpu-py38-torch-1-10-0-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
+           - cache-key-cpu-py38-torch-1-10-0-{{.Environment.CACHE_VERSION }}-{{checksum "setup.py"}}-{{checksum "requirements-test.txt"}}

      - <<: *install_dep_1_10_0

      - save_cache:
          paths:
            - ~/venv
-         key: cache-key-cpu-py38-torch-1-10-0-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
+         key: cache-key-cpu-py38-torch-1-10-0-{{.Environment.CACHE_VERSION }}-{{checksum "setup.py"}}-{{checksum "requirements-test.txt"}}

      - <<: *install_repo

-     - <<: *run_isort
-     - <<: *run_black
-     - <<: *run_mypy
-     - <<: *run_flake8
      - <<: *run_unittests

      - <<: *run_doc_build
@@ -327,21 +304,16 @@
      # Cache the venv directory that contains dependencies
      - restore_cache:
          keys:
-           - cache-key-cpu-py39-torch-1-10-0-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
+           - cache-key-cpu-py39-torch-1-10-0-{{.Environment.CACHE_VERSION }}-{{checksum "setup.py"}}-{{checksum "requirements-test.txt"}}

      - <<: *install_dep_1_10_0

      - save_cache:
          paths:
            - ~/venv
-         key: cache-key-cpu-py39-torch-1-10-0-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
+         key: cache-key-cpu-py39-torch-1-10-0-{{.Environment.CACHE_VERSION }}-{{checksum "setup.py"}}-{{checksum "requirements-test.txt"}}

      - <<: *install_repo

-     - <<: *run_isort
-     - <<: *run_black
-     - <<: *run_mypy
-     - <<: *run_flake8
      - <<: *run_unittests

      - <<: *run_doc_build
@@ -365,21 +337,21 @@
      # Run this to make sure we use python3 from the system.
      - setup_pyenv:
-         version: 3.7.0
+         version: 3.9.7

      - <<: *setup_venv

      # Cache the venv directory that contains dependencies
      - restore_cache:
          keys:
-           - cache-key-py37-gpu-torch-1-8-1-cuda-11-2-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
+           - cache-key-py-3-9-7-gpu-torch-1-8-1-cuda-11-2-{{.Environment.CACHE_VERSION }}-{{checksum "setup.py"}}-{{checksum "requirements-test.txt"}}

      - <<: *install_dep_1_8_1

      - save_cache:
          paths:
            - ~/venv
-         key: cache-key-py37-gpu-torch-1-8-1-cuda-11-2-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
+         key: cache-key-py-3-9-7-gpu-torch-1-8-1-cuda-11-2-{{.Environment.CACHE_VERSION }}-{{checksum "setup.py"}}-{{checksum "requirements-test.txt"}}

      - <<: *install_repo
@@ -408,21 +380,21 @@
      # Run this to make sure we use python3 from the system.
      - setup_pyenv:
-         version: 3.8.6
+         version: 3.9.7

      - <<: *setup_venv

      # Cache the venv directory that contains dependencies
      - restore_cache:
          keys:
-           - cache-key-py38-gpu-torch-1-10-0-cuda-11-2-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
+           - cache-key-py-3-9-7-gpu-torch-1-10-0-cuda-11-2-{{.Environment.CACHE_VERSION }}-{{checksum "setup.py"}}-{{checksum "requirements-test.txt"}}

      - <<: *install_dep_1_10_0

      - save_cache:
          paths:
            - ~/venv
-         key: cache-key-py38-gpu-torch-1-10-0-cuda-11-2-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
+         key: cache-key-py-3-9-7-gpu-torch-1-10-0-cuda-11-2-{{.Environment.CACHE_VERSION }}-{{checksum "setup.py"}}-{{checksum "requirements-test.txt"}}

      - <<: *install_repo
@@ -449,21 +421,21 @@
      # Run this to make sure we use python3 from the system.
      - setup_pyenv:
-         version: 3.8.6
+         version: 3.9.7

      - <<: *setup_venv

      # Cache the venv directory that contains dependencies
      - restore_cache:
          keys:
-           - cache-key-py38-gpu-pytorch-nightly-cuda-11-2-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
+           - cache-key-py-3-9-7-gpu-pytorch-nightly-cuda-11-2-{{.Environment.CACHE_VERSION }}-{{checksum "setup.py"}}-{{checksum "requirements-test.txt"}}

      - <<: *install_dep_pytorch_nightly

      - save_cache:
          paths:
            - ~/venv
-         key: cache-key-py38-gpu-pytorch-nightly-cuda-11-2-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
+         key: cache-key-py-3-9-7-gpu-pytorch-nightly-cuda-11-2-{{.Environment.CACHE_VERSION }}-{{checksum "setup.py"}}-{{checksum "requirements-test.txt"}}

      - <<: *install_repo
@@ -484,26 +456,26 @@
      - run: nvidia-smi

      - setup_pyenv:
-         version: 3.7.0
+         version: 3.9.7

      - <<: *setup_venv

      # Cache the venv directory that contains dependencies
      - restore_cache:
          keys:
-           - cache-key-py37-benchmarks-torch-1-10-0-cuda-11-2-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
+           - cache-key-py-3-9-7-benchmarks-torch-1-10-0-cuda-11-2-{{.Environment.CACHE_VERSION }}-{{checksum "setup.py"}}-{{checksum "requirements-test.txt"}}

      # Cache the MNIST directory that contains benchmark data
      - restore_cache:
          keys:
-           - cache-key-benchmark-MNIST-{{ checksum "benchmarks/datasets/mnist.py"}}
+           - cache-key-benchmark-MNIST-{{.Environment.CACHE_VERSION }}-{{checksum "benchmarks/datasets/mnist.py"}}

      - <<: *install_dep_1_10_0

      - save_cache:
          paths:
            - ~/venv
-         key: cache-key-py37-benchmarks-torch-1-10-0-cuda-11-2-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
+         key: cache-key-py-3-9-7-benchmarks-torch-1-10-0-cuda-11-2-{{.Environment.CACHE_VERSION }}-{{checksum "setup.py"}}-{{checksum "requirements-test.txt"}}

      - <<: *install_repo

@@ -520,7 +492,7 @@ jobs:
      - save_cache:
          paths:
            - /tmp/MNIST
-         key: cache-key-benchmark-MNIST-{{ checksum "benchmarks/datasets/mnist.py"}}
+         key: cache-key-benchmark-MNIST-{{.Environment.CACHE_VERSION }}-{{checksum "benchmarks/datasets/mnist.py"}}

   benchmarks_2:
     <<: *gpu_cu_11_2_medium_multi
@@ -533,27 +505,27 @@
      - run: nvidia-smi

      - setup_pyenv:
-         version: 3.7.0
+         version: 3.9.7

      - <<: *setup_venv

      # Cache the venv directory that contains dependencies
      - restore_cache:
          keys:
-           - cache-key-py37-benchmarks-torch-1-10-0-cuda-11-2-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
+           - cache-key-py-3-9-7-benchmarks-torch-1-10-0-cuda-11-2-{{.Environment.CACHE_VERSION }}-{{checksum "setup.py"}}-{{checksum "requirements-test.txt"}}

      # Cache the MNIST directory that contains benchmark data
      - restore_cache:
          keys:
-           - cache-key-benchmark-MNIST-{{ checksum "benchmarks/datasets/mnist.py"}}
+           - cache-key-benchmark-MNIST-{{.Environment.CACHE_VERSION }}-{{checksum "benchmarks/datasets/mnist.py"}}

      - <<: *install_dep_1_10_0

      - save_cache:
          paths:
            - ~/venv
-         key: cache-key-py37-benchmarks-torch-1-10-0-cuda-11-2-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
+         key: cache-key-py-3-9-7-benchmarks-torch-1-10-0-cuda-11-2-{{.Environment.CACHE_VERSION }}-{{checksum "setup.py"}}-{{checksum "requirements-test.txt"}}

      - <<: *install_repo

@@ -562,7 +534,7 @@
      - save_cache:
          paths:
            - /tmp/MNIST
-         key: cache-key-benchmark-MNIST-{{ checksum "benchmarks/datasets/mnist.py"}}
+         key: cache-key-benchmark-MNIST-{{.Environment.CACHE_VERSION}}-{{checksum "benchmarks/datasets/mnist.py"}}

 workflows:
...
name: pre-commit
on:
pull_request:
push:
branches: [main]
jobs:
pre-commit:
runs-on: ubuntu-latest
strategy:
matrix:
# make sure python versions are consistent with those used in .circleci/config.yml
python-version: ['3.7.12', '3.8.12', '3.9.7']
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- uses: pre-commit/action@v2.0.3
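The pre-commit/action step above runs every hook in .pre-commit-config.yaml against the whole checkout. A minimal sketch of the local equivalent, assuming pre-commit is installed from the pinned entry in requirements-dev.txt mentioned in the commit message:

    pip install -r requirements-dev.txt   # installs the pinned pre-commit version
    pre-commit install                    # run the hooks automatically on every git commit
    pre-commit run --all-files            # one-shot run over the entire repo, mirroring CI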
@@ -8,7 +8,7 @@ default_language_version:
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v2.5.0
+    rev: v4.0.1
     hooks:
       - id: trailing-whitespace
       - id: check-ast

@@ -20,29 +20,31 @@ repos:
       - id: end-of-file-fixer

   - repo: https://github.com/ambv/black
-    rev: 19.10b0
+    rev: 21.10b0
     hooks:
       - id: black

   - repo: https://gitlab.com/pycqa/flake8
-    rev: 3.7.9
+    rev: 4.0.1
     hooks:
       - id: flake8
+        args: [--show-source, --statistics]

   - repo: https://github.com/asottile/seed-isort-config
-    rev: v2.1.0
+    rev: v2.2.0
     hooks:
       - id: seed-isort-config

   - repo: https://github.com/pycqa/isort
-    rev: 5.6.4
+    rev: 5.10.1
     hooks:
       - id: isort
         exclude: README.md
         additional_dependencies: [toml]

   - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: 'v0.790'
+    rev: 'v0.910'
     hooks:
       - id: mypy
+        args: [--no-strict-optional, --ignore-missing-imports, --scripts-are-modules, --pretty]
         additional_dependencies: [numpy]
@@ -16,11 +16,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
   and gradient memory to be sharded despite being needed from different layers due to
   weight sharing. [#836]
 - OffloadModel: Fix node names to enable correct sharding in auto_shard.py [#830]
 - OSS: Relaxed speed and memory constraints on OSS golden data due to regression when we bumped up the
   PyTorch version to 1.9. [#828] [#825]
 - Chore: Update PyTorch version that we run benchmarks with. [#823]
 - Chore: Update PyTorch version that we run test with. [#809]
 - OffloadModel: Extend auto_shard.py to allow dealing with conditionals automatically when tracing with
   torch.fx. This will work for most cases except when the conditional is part of the root instance. [#817]
 - [MEVO]: a custom layer to help big vocab trainings. Experimental. Docs is still TBD. [#840]
 - SlowMoDistributedDataParallel[feature][experimental] - This is a distributed training wrapper which should be useful on clusters with slow network interconnects (eg Ethernet). This improves on performance as compared to Distributed Data Parallel in such clusters. [#378]
...
@@ -42,7 +42,10 @@ class BenchmarkLMDataset(Dataset):
     """

     def __init__(
-        self, vocab_size=10000, max_source_positions=1024, total_samples=10000,
+        self,
+        vocab_size=10000,
+        max_source_positions=1024,
+        total_samples=10000,
     ):
         self.vocab_size = vocab_size
         self.max_source_positions = max_source_positions
...
@@ -35,7 +35,7 @@ KERNELS = [
 def run_on_gpu(kernel, data, repeats, no_grad, fwd_bwd):
-    """ Measure both GPU runtime and peak memory usage of a kernel. """
+    """Measure both GPU runtime and peak memory usage of a kernel."""
     tokens = data[0].shape[0]

     def get_cuda_data():
...
@@ -142,7 +142,7 @@ class MySGD(Optimizer):
         super(MySGD, self).__setstate__(state)

     def step(self, closure=None):
-        """ Performs a single optimization step.
+        """Performs a single optimization step.

         Args:
             closure (callable, optional): A closure that reevaluates the model
                 and returns the loss.

@@ -162,15 +162,15 @@ class MySGD(Optimizer):
 class SpectrainSGDMomentum(Optimizer):
     r"""
     Implements a SGD with momentum optimizer with Spectrain based weight
     prediction. Please refer to the spectrain paper: https://arxiv.org/pdf/1809.02839.pdf
     for more details.

     Args:
         params (iterable): iterable of parameters to optimize or dicts defining
             parameter groups
         lr (float): learning rate (required)
         momentum (float): momentum (default=0.9)
     """

     def __init__(self, params, lr, momentum=0.9):

@@ -234,7 +234,7 @@ class SpectrainSGDMomentum(Optimizer):
             p.data.sub_(param_state["momentum_buffer"].data, alpha=multiplier)

     def step(self, weight_prediction=True, closure=None):
-        """ Performs a single optimization step.
+        """Performs a single optimization step.

         Args:
             weight_prediction (bool, optional): Enable weight prediction based updates
             closure (callable, optional): A closure that reevaluates the model
...
@@ -413,7 +413,10 @@ parser.add_argument(
     help="Print debugging statements which is more verbose than the default.",
 )
 parser.add_argument(
-    "--model_name", default="lm", type=str, help="Language Model(LM) used to benchmark nn.pipe.",
+    "--model_name",
+    default="lm",
+    type=str,
+    help="Language Model(LM) used to benchmark nn.pipe.",
 )
 parser.add_argument(
     "--use_synthetic_data", default=True, action="store_true", help="Uses synthetic data for running benchmarks."
...
@@ -320,14 +320,22 @@ if __name__ == "__main__":
     if args.optim_type == OptimType.oss_ddp or args.optim_type == OptimType.everyone:
         logging.info("\n*** Benchmark OSS with DDP")
         mp.spawn(
-            train, args=(args, BACKEND, OptimType.oss_ddp, args.check_regression), nprocs=args.world_size, join=True,  # type: ignore
+            train,
+            args=(args, BACKEND, OptimType.oss_ddp, args.check_regression),
+            nprocs=args.world_size,
+            join=True,  # type: ignore
         )

     if args.optim_type == OptimType.oss_sharded_ddp or args.optim_type == OptimType.everyone:
         logging.info("\n*** Benchmark OSS with ShardedDDP")
         mp.spawn(
             train,  # type: ignore
-            args=(args, BACKEND, OptimType.oss_sharded_ddp, args.check_regression,),
+            args=(
+                args,
+                BACKEND,
+                OptimType.oss_sharded_ddp,
+                args.check_regression,
+            ),
             nprocs=args.world_size,
             join=True,
         )
@@ -34,13 +34,20 @@ from fairscale.nn.pipe.worker import Task
 def create_task_without_skip_trackers(
-    checkpoint_stop: int, i: int, j: int, batch: Batch, partition: nn.Sequential,
+    checkpoint_stop: int,
+    i: int,
+    j: int,
+    batch: Batch,
+    partition: nn.Sequential,
 ) -> Task:
     # Determine whether checkpointing or not.
     if i < checkpoint_stop:

         def function(
-            input: TensorOrTensors, partition: nn.Sequential = partition, chunk_id: int = i, part_id: int = j,
+            input: TensorOrTensors,
+            partition: nn.Sequential = partition,
+            chunk_id: int = i,
+            part_id: int = j,
         ) -> TensorOrTensors:
             with record_function("chunk%d-part%d" % (chunk_id, part_id)):
                 return partition(input)

@@ -52,7 +59,10 @@ def create_task_without_skip_trackers(
     else:

         def compute(
-            batch: Batch = batch, partition: nn.Sequential = partition, chunk_id: int = i, part_id: int = j,
+            batch: Batch = batch,
+            partition: nn.Sequential = partition,
+            chunk_id: int = i,
+            part_id: int = j,
         ) -> Batch:
             with record_function("chunk%d-part%d" % (chunk_id, part_id)):
                 return batch.call(partition)

@@ -93,7 +103,11 @@ class AsyncAMPnetEventLoop:
     def async_send_inner(self, batch: Batch, index: int) -> Tuple[Batch, PipeMessage]:
         task = create_task_without_skip_trackers(
-            self.checkpoint_stop, index, self.group.rank(), batch, self.partitions[0].module,
+            self.checkpoint_stop,
+            index,
+            self.group.rank(),
+            batch,
+            self.partitions[0].module,
         )
         result = task.compute()
         task.finalize(result)

@@ -258,7 +272,11 @@ class AsyncAMPnetEventLoop:
                 if self.weight_prediction:
                     optimizer.update_weight_using_future_predictions(cur_rank, N, count, self.chunks, forward=True)
                 task = create_task_without_skip_trackers(
-                    self.checkpoint_stop, args.microbatch_index, self.group.rank(), batch, self.partitions[0].module,
+                    self.checkpoint_stop,
+                    args.microbatch_index,
+                    self.group.rank(),
+                    batch,
+                    self.partitions[0].module,
                 )
                 output = task.compute()
                 activations[args.microbatch_index] = output
...
@@ -20,9 +20,9 @@ __all__ = ["AMPnetPipe"]
 class AMPnetPipe(AsyncPipe):
     """
     AMPnetPipe is the asynchronous version of the MultiProcessPipe implementation
     which avoids the bubble issue, by using stale weights and gradients.
     The implementation closely follows the paper: https://arxiv.org/abs/1705.09786
     """

     def __init__(self, **kwargs: Any) -> None:

@@ -46,8 +46,8 @@ class AMPnetPipe(AsyncPipe):
         assert self.group
         rank = self.group.rank()
-        transport = self.pipeline.transport  # type: ignore
-        checkpoint_stop = self.pipeline.checkpoint_stop  # type: ignore
+        transport = self.pipeline.transport
+        checkpoint_stop = self.pipeline.checkpoint_stop
         ampnet_event_loop = AsyncAMPnetEventLoop(
             partitions,
             self.group,
...
@@ -312,7 +312,7 @@ class SlowMoDistributedDataParallel(Module):
         self.logger.debug("Initialization of SlowMoDistributedDataParallel complete")

     def _initialize_logger(self, verbose: bool, process_rank: int) -> None:
-        """ Initializes the logger """
+        """Initializes the logger"""
         self.logger = logging.getLogger(__name__)
         if verbose:
             self.logger.setLevel(logging.DEBUG)

@@ -331,7 +331,7 @@ class SlowMoDistributedDataParallel(Module):
         master_group: Optional[torch.distributed.ProcessGroup],
         local_node_group: Optional[torch.distributed.ProcessGroup],
     ) -> Tuple[int, int]:
-        """ Creates the process groups required for the SlowMo implementation """
+        """Creates the process groups required for the SlowMo implementation"""
         self.local_rank = process_rank % self.nprocs_per_node
         assert (

@@ -392,7 +392,12 @@ class SlowMoDistributedDataParallel(Module):
             self.logger.debug("Initializing local process groups")
             for node in range(logical_world_size):
-                node_processes_ranks = list(range(node * self.nprocs_per_node, (node + 1) * self.nprocs_per_node,))
+                node_processes_ranks = list(
+                    range(
+                        node * self.nprocs_per_node,
+                        (node + 1) * self.nprocs_per_node,
+                    )
+                )
                 # Process group to communicate between processes on this machine
                 new_local_group = create_process_group(node_processes_ranks)
                 if process_rank in node_processes_ranks:

@@ -401,24 +406,26 @@ class SlowMoDistributedDataParallel(Module):
             self.logger.debug("Initialization of local groups complete")

     def forward(self, *inputs: Any, **kwargs: Any) -> Union[torch.Tensor, List[torch.Tensor]]:
-        """ Forward pass performed in parallel across all devices on node """
+        """Forward pass performed in parallel across all devices on node"""
         return self.module(*inputs, **kwargs)

     def _sync_params(self) -> None:
-        """ Synchronize parameters across devices (intra-node) """
+        """Synchronize parameters across devices (intra-node)"""
         if self.local_node_group is None:
             return

         # intra-node parameter sync
         params = cast(List[torch.Tensor], list(self.module.parameters()))
         communication_op = functools.partial(
-            dist.broadcast, src=self.logical_rank * self.nprocs_per_node, group=self.local_node_group,
+            dist.broadcast,
+            src=self.logical_rank * self.nprocs_per_node,
+            group=self.local_node_group,
         )
         communicate(params, communication_op)
         self.logger.debug("Intra-node param sync complete")

     def _sync_buffers(self) -> None:
-        """ Synchronize buffers across nodes """
+        """Synchronize buffers across nodes"""
         # module buffer sync
         if self.broadcast_buffers and len(self.module_buffers) > 0:
             # Synchronize buffers across processes.

@@ -432,17 +439,18 @@ class SlowMoDistributedDataParallel(Module):
             dist._broadcast_coalesced(process_group, tensors, buffer_size)

     def _create_event_recorder(self, event_name: str) -> EventRecorder:
-        """ Creates an cuda event recorder which helps in profiling """
+        """Creates an cuda event recorder which helps in profiling"""
         return create_event_recorder(event_name, dummy=not self.profile_mode)

     def _fp16_fp32_iterator(
         self, optimizer: torch.optim.Optimizer, fp32_params: Optional[torch.Tensor]
     ) -> Iterable[Tuple[torch.Tensor, torch.Tensor]]:
-        """ Iterator for those fp16 parameters which have a fp32 copy """
+        """Iterator for those fp16 parameters which have a fp32 copy"""
         # Handle apex fp16 optimizer
         if hasattr(optimizer, "_amp_stash") and hasattr(optimizer._amp_stash, "fp16_groups"):
             for p_fp16_group, p_fp32_group in zip(
-                optimizer._amp_stash.fp16_groups, optimizer._amp_stash.fp32_from_fp16_groups,
+                optimizer._amp_stash.fp16_groups,
+                optimizer._amp_stash.fp32_from_fp16_groups,
             ):
                 for p_fp16, p_fp32 in zip(p_fp16_group, p_fp32_group):
                     yield p_fp16, p_fp32

@@ -594,12 +602,12 @@ class SlowMoDistributedDataParallel(Module):
                     ef1.copy_(p_fp32 - p_fp16.float())

     def perform_slowmo(self, optimizer: torch.optim.Optimizer, fp32_params: Optional[torch.Tensor] = None) -> None:
-        """ This is to be called after optimizer.step(). It performs the approximate averaging using
+        """This is to be called after optimizer.step(). It performs the approximate averaging using
         the base algorithm (SGP/ LocalSGD) and the slow momentum step. Since LocalSGD and the slow
         momentum step are not performed every iteration, it only performs those when needed.

         It is recommended to call ``model.zero_grad(set_to_none=True)`` just before calling this function. This
         is because ``model.zero_grad(set_to_none=True)`` frees up the memory occupied by the gradients, some of which
         may be reused by this function.

         Args:

@@ -645,7 +653,7 @@ class SlowMoDistributedDataParallel(Module):
         self.num_updates += 1

     def _init_global_momentum_buffers(self, optimizer: torch.optim.Optimizer) -> None:
-        """ Initializes the slow momentum buffers """
+        """Initializes the slow momentum buffers"""
         self.global_momentum_buffers_initialized = True

         if not self.slowmo:

@@ -707,7 +715,7 @@ class SlowMoDistributedDataParallel(Module):
             self.global_momentum_buffer = torch.zeros_like(self.old_params).detach()

     def _distributed_comm(self, optimizer: torch.optim.Optimizer, mode: str) -> None:
-        """ Performs the communication needed for the efficient SlowMo implementation """
+        """Performs the communication needed for the efficient SlowMo implementation"""
         offset = 0
         slowmo_comm_lists: List[List[torch.Tensor]] = [[] for _ in range(self.slowmo_num_shards)]
         with torch.no_grad():

@@ -743,7 +751,7 @@ class SlowMoDistributedDataParallel(Module):
                 communicate(slowmo_comm_list, communication_op)

     def _global_momentum_step(self, optimizer: torch.optim.Optimizer) -> None:
-        """ Performs the slow momentum step """
+        """Performs the slow momentum step"""
         if not self.slowmo:
             return

@@ -760,7 +768,7 @@ class SlowMoDistributedDataParallel(Module):
             self._distributed_comm(optimizer, mode="scatter")

     def _perform_local_optimization(self, optimizer: torch.optim.Optimizer) -> None:
-        """ Performs the slow momentum on the local shard """
+        """Performs the slow momentum on the local shard"""
         assert self.portion_start is not None

         with torch.no_grad():

@@ -838,7 +846,7 @@ class SlowMoDistributedDataParallel(Module):
         self.logger.debug("making forward pre-hook")

         def hook(*unused: Any) -> None:
-            """ Query gossip queue and de-bias during forward pass """
+            """Query gossip queue and de-bias during forward pass"""
             # sync buffers before the forward pass
             self._sync_buffers()

@@ -869,7 +877,7 @@ class SlowMoDistributedDataParallel(Module):
         use_streams: bool = True,
         slowmo_sgp_average_params: bool = False,
     ) -> None:
-        """ Perform initialization for Stochastic Gradient Push base algorithm """
+        """Perform initialization for Stochastic Gradient Push base algorithm"""
         if graph is None:
             graph = NPDDEGraph(logical_rank, logical_world_size, self.nprocs_per_node, self.local_rank)

@@ -959,7 +967,7 @@ class SlowMoDistributedDataParallel(Module):
         super(SlowMoDistributedDataParallel, self).load_state_dict(cast(Dict[str, torch.Tensor], state_dict))

     def _sgp_ps_numerator(self) -> None:
-        """ Convert model params to ps-numerator """
+        """Convert model params to ps-numerator"""
         if not self.is_sgp_ps_numerator:
             if not self.lazy_mixing:
                 ps_weight = self.ps_weight

@@ -969,7 +977,7 @@ class SlowMoDistributedDataParallel(Module):
             self.is_sgp_ps_numerator = True

     def _sgp_unbias(self) -> None:
-        """ Convert model params to de-biased estimate """
+        """Convert model params to de-biased estimate"""
         if self.is_sgp_ps_numerator:
             if not self.lazy_mixing:
                 ps_weight = self.ps_weight

@@ -992,7 +1000,7 @@ class SlowMoDistributedDataParallel(Module):
         return self

     def _sgp_query_gossip_queue(self, non_blocking: bool = False) -> bool:
-        """ Check gossip-queue for push-sum residuals and update model """
+        """Check gossip-queue for push-sum residuals and update model"""
         if not self.gossip_enable:
             return False

@@ -1046,7 +1054,7 @@ class SlowMoDistributedDataParallel(Module):
             return False

     def _sgp_transfer_params(self, mix: bool = True) -> bool:
-        """ Transfers COPY of model parameters to gossip queue """
+        """Transfers COPY of model parameters to gossip queue"""
         if not self.gossip_enable or self.process_rank % self.nprocs_per_node != 0:
             return False

@@ -1130,7 +1138,7 @@ class SlowMoDistributedDataParallel(Module):
         gossip_ps_factor: torch.Tensor,
         gossip_stream: torch.cuda.Stream,
     ) -> None:
-        """ Gossip thread, which performs push-sum on model params """
+        """Gossip thread, which performs push-sum on model params"""
         logger = make_logger(dist_config["logical_rank"], dist_config["verbose"])
         gossip_params_by_dtype = group_by_dtype(gossip_params)
...
@@ -30,7 +30,7 @@ class dist_backend(str, Enum):
 class Gossiper(object):
-    """ Generic gossip averaging object for multi-peer communication
+    """Generic gossip averaging object for multi-peer communication

     Args:
         msg (torch.Tensor): message used to initialize recv buffer

@@ -121,7 +121,7 @@ class Gossiper(object):
         self._graph_manager.peers_per_itr = v

     def refresh_peers_(self, rotate: Optional[bool] = None) -> None:
-        """ Update in- and out-peers """
+        """Update in- and out-peers"""
         if rotate is None:
             rotate = self._graph_manager.is_dynamic_graph()
         # cannot cycle peers in a static graph

@@ -129,11 +129,11 @@ class Gossiper(object):
         self.out_edges, self.in_edges = self._graph_manager.get_edges(rotate)

     def refresh_mixing_weights_(self, residual_adjusted: bool = False) -> None:
-        """ Update mixing-matrix weights """
+        """Update mixing-matrix weights"""
         self.mixing_weights = self._mixing_manager.get_mixing_weights(residual_adjusted)

     def mix_out_msg_(self, out_msg: torch.Tensor, ps_weight: torch.Tensor) -> Iterator[torch.Tensor]:
-        """ Returns a generator mixing messages on the fly """
+        """Returns a generator mixing messages on the fly"""
         self.refresh_mixing_weights_(residual_adjusted=True)
         self.ps_weight = ps_weight

@@ -153,14 +153,14 @@ class Gossiper(object):
             yield out_msg.mul(weight.type(out_msg.dtype))  # type: ignore

     def clean_msg_buffers_(self) -> None:
-        """ Clean outgoing message buffer """
+        """Clean outgoing message buffer"""
         while len(self.out_msg_buffer) > 0:
             req, msg = self.out_msg_buffer.pop()
             req.wait()
             msg.set_()

     def parse_in_msg_buffer(self) -> Tuple[torch.Tensor, torch.Tensor]:
-        """ Parse in-msg buffer and return msg and ps-weight separately """
+        """Parse in-msg buffer and return msg and ps-weight separately"""
         msg = self.in_msg_buffer
         if not self.regular:
             return msg.narrow(0, 0, len(msg) - 1), msg[-1]

@@ -168,15 +168,15 @@ class Gossiper(object):
             return msg, self.ps_weight * self.peers_per_itr_device

     def mix(self, out_msg: torch.Tensor, ps_weight: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-        """ Single gossip step """
+        """Single gossip step"""
         raise NotImplementedError


 class PushSum(Gossiper):
-    """ 1-peer Push-Sum consensus averaging module """
+    """1-peer Push-Sum consensus averaging module"""

     def mix(self, out_msg: torch.Tensor, ps_weight: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-        """ Consensus averaging step """
+        """Consensus averaging step"""
         # out_msg must be on the correct device
         assert out_msg.device.type == self.device.type
         if self.logger is not None:

@@ -189,7 +189,12 @@ class PushSum(Gossiper):
         for out_edge in self.out_edges:
             msg = next(mixed_out_msgs)
             assert self.rank == out_edge.src
-            req = dist.broadcast(tensor=msg, src=out_edge.src, group=out_edge.process_group, async_op=True,)
+            req = dist.broadcast(
+                tensor=msg,
+                src=out_edge.src,
+                group=out_edge.process_group,
+                async_op=True,
+            )
             self.out_msg_buffer.append((req, msg))

         # blocking recv w/ some code optimization to avoid buffer prep overhead

@@ -204,7 +209,9 @@ class PushSum(Gossiper):
             for in_edge in self.in_edges:
                 dist.broadcast(
-                    tensor=self.placeholder, src=in_edge.src, group=in_edge.process_group,
+                    tensor=self.placeholder,
+                    src=in_edge.src,
+                    group=in_edge.process_group,
                 )
                 self.in_msg_buffer.add_(self.placeholder)  # type: ignore

@@ -214,7 +221,7 @@

 class PushPull(Gossiper):
-    """ Doubly-stochastic consensus averaging module """
+    """Doubly-stochastic consensus averaging module"""

     def mix(self, out_msg: torch.Tensor, ps_weight: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         # out_msg must be on the correct device

@@ -232,11 +239,15 @@ class PushPull(Gossiper):
             if not self.passive:
                 dist.broadcast(tensor=msg, src=out_edge.src, group=out_edge.process_group)
                 dist.broadcast(
-                    tensor=self.in_msg_buffer, src=in_edge.src, group=in_edge.process_group,
+                    tensor=self.in_msg_buffer,
+                    src=in_edge.src,
+                    group=in_edge.process_group,
                 )
             else:
                 dist.broadcast(
-                    tensor=self.in_msg_buffer, src=in_edge.src, group=in_edge.process_group,
+                    tensor=self.in_msg_buffer,
+                    src=in_edge.src,
+                    group=in_edge.process_group,
                 )
                 dist.broadcast(tensor=msg, src=out_edge.src, group=out_edge.process_group)

@@ -251,11 +262,15 @@ class PushPull(Gossiper):
             if not self.passive:
                 dist.broadcast(tensor=msg, src=out_edge.src, group=out_edge.process_group)
                 dist.broadcast(
-                    tensor=self.placeholder, src=in_edge.src, group=in_edge.process_group,
+                    tensor=self.placeholder,
+                    src=in_edge.src,
+                    group=in_edge.process_group,
                 )
             else:
                 dist.broadcast(
-                    tensor=self.placeholder, src=in_edge.src, group=in_edge.process_group,
+                    tensor=self.placeholder,
+                    src=in_edge.src,
+                    group=in_edge.process_group,
                 )
                 dist.broadcast(tensor=msg, src=out_edge.src, group=out_edge.process_group)
             self.in_msg_buffer.add_(self.placeholder)  # type: ignore
...
@@ -77,26 +77,26 @@ class GraphManager(ABC):
     @abstractmethod
     def is_regular_graph(self) -> bool:
-        """ Whether each node has the same number of in-peers as out-peers """
+        """Whether each node has the same number of in-peers as out-peers"""
         raise NotImplementedError

     @abstractmethod
     def is_bipartite_graph(self) -> bool:
-        """ Whether graph is bipartite or not """
+        """Whether graph is bipartite or not"""
         raise NotImplementedError

     @abstractmethod
     def is_passive(self, rank: Optional[int] = None) -> bool:
-        """ Whether 'rank' is a passive node or not """
+        """Whether 'rank' is a passive node or not"""
         raise NotImplementedError

     @abstractmethod
     def is_dynamic_graph(self) -> bool:
-        """ Whether the graph-type is dynamic (as opposed to static) """
+        """Whether the graph-type is dynamic (as opposed to static)"""
         raise NotImplementedError

     def get_peers(self, rotate: bool = False) -> Tuple[List[int], List[int]]:
-        """ Returns the out and in-peers corresponding to 'self.rank' """
+        """Returns the out and in-peers corresponding to 'self.rank'"""
         # cycle through in- and out-peers by updating group-index
         if rotate:
             self._rotate_group_indices()

@@ -113,8 +113,8 @@ class GraphManager(ABC):
         return out_peers, in_peers

     def get_edges(self, rotate: bool = False) -> Tuple[List[Edge], List[Edge]]:
-        """ Returns the pairwise process groups between rank and the out and
-        in-peers corresponding to 'self.rank' """
+        """Returns the pairwise process groups between rank and the out and
+        in-peers corresponding to 'self.rank'"""
         # cycle through in- and out-peers by updating group-index
         if rotate:
             self._rotate_group_indices()

@@ -131,17 +131,17 @@ class GraphManager(ABC):
         return out_edges, in_edges

     def _rotate_group_indices(self) -> None:
-        """ Incerement group indices to point to the next out-peer """
+        """Incerement group indices to point to the next out-peer"""
         increment = self.peers_per_itr
         for i, group_index in enumerate(self._group_indices):
             self._group_indices[i] = int((group_index + increment) % len(self.phone_book[self.rank]))

     def _rotate_forward(self, r: int, p: int) -> int:
-        """ Helper function returns peer that is p hops ahead of r """
+        """Helper function returns peer that is p hops ahead of r"""
         return (r + p) % self.world_size

     def _rotate_backward(self, r: int, p: int) -> int:
-        """ Helper function returns peer that is p hops behind r """
+        """Helper function returns peer that is p hops behind r"""
         return (r - p) % self.world_size
...
@@ -32,18 +32,18 @@ class MixingManager(ABC):
     @abstractmethod
     def is_uniform(self) -> bool:
-        """ Whether mixing weights are distributed uniformly over peers """
+        """Whether mixing weights are distributed uniformly over peers"""
         raise NotImplementedError

     @abstractmethod
     def get_mixing_weights(self, residual_adjusted: bool = True) -> Dict[Union[str, int], torch.Tensor]:
-        """ Create mixing weight dictionary using uniform allocation """
+        """Create mixing weight dictionary using uniform allocation"""
         raise NotImplementedError


 class UniformMixing(MixingManager):
     def get_mixing_weights(self, residual_adjusted: bool = True) -> Dict[Union[str, int], torch.Tensor]:
-        """ Create mixing weight dictionary using uniform allocation """
+        """Create mixing weight dictionary using uniform allocation"""
         mixing_weights: Dict[Union[str, int], torch.Tensor] = {}
         out_peers, _ = self.graph_manager.get_peers()
...
@@ -36,7 +36,7 @@ def create_event_recorder(event_name: str, dummy: bool = False) -> EventRecorder
 class CudaEventRecorder(EventRecorder):
-    """ Allows profiling in an easy-to-use manner. CudaEventRecorder can be used
+    """Allows profiling in an easy-to-use manner. CudaEventRecorder can be used
     in a loop. When it is used in a loop (or when an event recorder is created
     multiple times with the same name), get_timings returns the statistics of the
     timings since the last reset. Note: in case the number of timings is greater than

@@ -92,19 +92,22 @@ class CudaEventRecorder(EventRecorder):
             time_taken_list = [event_recorder.find_time_elapsed() for event_recorder in event_recorder_list]

             all_timings_str += ("{}: Time taken: avg: {}, std: {}, count: " "{}\n").format(
-                event_name, statistics.mean(time_taken_list), statistics.pstdev(time_taken_list), len(time_taken_list),
+                event_name,
+                statistics.mean(time_taken_list),
+                statistics.pstdev(time_taken_list),
+                len(time_taken_list),
             )

         return all_timings_str

     @classmethod
     def get_timings(cls) -> str:
-        """ Returns the timings since last reset was called """
+        """Returns the timings since last reset was called"""
         return cls.get_common_timings(cls.event_recorders, "Timings since last reset")

     @classmethod
     def get_all_timings(cls) -> str:
-        """ Returns the statistics of all the timings """
+        """Returns the statistics of all the timings"""
         return cls.get_common_timings(cls.all_event_recorders, "All timings")
...
@@ -86,7 +86,10 @@ def communicate(tensors: List[torch.Tensor], communication_op: Any, logger: logg
     if logger is not None:
         logger.debug("Commmunication completed")
     with torch.no_grad():
-        for f, t in zip(unflatten_tensors(flat_tensor, tensors_with_same_dtype), tensors_with_same_dtype,):
+        for f, t in zip(
+            unflatten_tensors(flat_tensor, tensors_with_same_dtype),
+            tensors_with_same_dtype,
+        ):
             t.copy_(f)
     if logger is not None:
         logger.debug("Unflatten completed")
...
@@ -15,7 +15,7 @@ from .data import DataConsumer
 class MultiInputSequential(nn.Module):
     """A variation of nn.Sequential, that allows the first module in the sequence accepts
     multiple inputs. To be used internally by _split_module
     """

     def __init__(self, *modules: nn.Module) -> None:

@@ -198,7 +198,9 @@ class PipelineModulesGraph(nn.Module):
             remote_module = partition[0].module.get_module_rref()
         else:
             remote_module = rpc.remote(
-                partition[0].module.on, RemoteSequential, args=([p.module.get_module_rref() for p in partition],),
+                partition[0].module.on,
+                RemoteSequential,
+                args=([p.module.get_module_rref() for p in partition],),
             )
         partitions.append((partition, remote_module))
...
@@ -25,7 +25,7 @@ ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType]
 class DistributedPipelineRecord:
-    """ A class for storing a single mini-batch (consisting of multiple micro-batches) as input to
+    """A class for storing a single mini-batch (consisting of multiple micro-batches) as input to
     a single partition.

     Args:
         device: the local device that runs the partition.

@@ -73,7 +73,7 @@ class DistributedPipelineRecord:
         return {}

     def feed(self, chunk: int, input_idx: int, input: Tensor) -> Tensor:
-        """ This function is called remotely to provide individual tensors of a given chunk."""
+        """This function is called remotely to provide individual tensors of a given chunk."""
         if input.device.type == "cpu":
             input = input.to(self.device)
         cuda_stream = torch.cuda.current_stream(input.device) if input.device.type == "cuda" else None

@@ -267,8 +267,8 @@ class PartitionHandler:
     def run_pipeline(self, pipeline_record_rref: rpc.RRef) -> Optional[Tensor]:
         """Processes a min-batch on this partition.
         If this is the last partition (pipeline_record has no consumer), concatenates results of processing
         all chunks and returns the result as the output of the model on the whole mini-batch.
         """
         pipeline_record = pipeline_record_rref.local_value()
         self.run(pipeline_record)
...
@@ -70,7 +70,12 @@ class DistributedPipeline(nn.Module):
     DataConsumer = DataConsumer[Partition]

-    def __init__(self, graph: PipelineModulesGraph, chunks: int = 1, checkpoint: str = "except_last",) -> None:
+    def __init__(
+        self,
+        graph: PipelineModulesGraph,
+        chunks: int = 1,
+        checkpoint: str = "except_last",
+    ) -> None:
         super().__init__()
         check_pytorch_version()
...