Unverified Commit 9a6ca9bd authored by Min Xu's avatar Min Xu Committed by GitHub
Browse files

[cleanup] consistent __init__.py for import * (#550)

- fixes #471
- one less thing to worry about during development.
parent 0233efca
# Contributing to fairscale # Contributing to fairscale
We want to make contributing to this project as easy and transparent as We want to make contributing to this project as easy and transparent as
possible. possible.
## Our Development Process ## Our Development Process
Minor changes and improvements will be released on an ongoing basis. Larger Minor changes and improvements will be released on an ongoing basis. Larger
changes (e.g., changesets implementing a new paper) will be released on a changes (e.g., changesets implementing a new paper) will be released on a
more periodic basis. more periodic basis.
## Pull Requests ## Pull Requests
We actively welcome your pull requests. We actively welcome your pull requests.
1. Fork the repo and create your branch from `master`. 1. Fork the repo and create your branch from `master`.
...@@ -18,12 +21,14 @@ We actively welcome your pull requests. ...@@ -18,12 +21,14 @@ We actively welcome your pull requests.
6. If you haven't already, complete the Contributor License Agreement ("CLA"). 6. If you haven't already, complete the Contributor License Agreement ("CLA").
## Contributor License Agreement ("CLA") ## Contributor License Agreement ("CLA")
In order to accept your pull request, we need you to submit a CLA. You only need In order to accept your pull request, we need you to submit a CLA. You only need
to do this once to work on any of Facebook's open source projects. to do this once to work on any of Facebook's open source projects.
Complete your CLA here: <https://code.facebook.com/cla> Complete your CLA here: <https://code.facebook.com/cla>
## Issues ## Issues
We use GitHub issues to track public bugs. Please ensure your description is We use GitHub issues to track public bugs. Please ensure your description is
clear and has sufficient instructions to be able to reproduce the issue. clear and has sufficient instructions to be able to reproduce the issue.
...@@ -41,17 +46,21 @@ outlined on that page and do not file a public issue. ...@@ -41,17 +46,21 @@ outlined on that page and do not file a public issue.
``` ```
## Coding Style ## Coding Style
* In your editor, install the [editorconfig](https://editorconfig.org/) extension * In your editor, install the [editorconfig](https://editorconfig.org/) extension
which should ensure that you are following the same standards as us. which should ensure that you are following the same standards as us.
* Ideally, run black and isort before opening up your PR. * Please run black and isort before opening up your PR.
``` ```
black . black .
isort isort .
flake8 flake8
``` ```
* Read the [editorconfig](.editorconfig) file to understand the exact coding style preferences. * Please read the [editorconfig](.editorconfig) file to understand the exact coding style preferences.
* Place Python code related to models in fairscale/nn. Place Python code related to optimizers in fairscale/optim. Place C++ extensions in fairscale/clib. * Please place Python code related to models in fairscale/nn. Place Python code related to optimizers
in fairscale/optim. Place C++ extensions in fairscale/clib.
* Please put `__all__:List[str] = []` in new `__init__.py` files for consistent importing behavior
and less development overhead in maintaining an importing list.
## Testing ## Testing
...@@ -72,7 +81,8 @@ python -m pytest tests/nn/data_parallel/test_oss_ddp.py::test_on_cpu ...@@ -72,7 +81,8 @@ python -m pytest tests/nn/data_parallel/test_oss_ddp.py::test_on_cpu
### Check test coverage ### Check test coverage
``` ```
python -m pytest --cov-report term --cov=fairscale/nn/data_parallel tests/nn/data_parallel/test_oss_ddp.py::test_on_cpu python -m pytest --cov-report term --cov=fairscale/nn/data_parallel \
tests/nn/data_parallel/test_oss_ddp.py::test_on_cpu
``` ```
### CircleCI status ### CircleCI status
...@@ -108,6 +118,7 @@ Any line of the commit message cannot be longer 100 characters! This allows the ...@@ -108,6 +118,7 @@ Any line of the commit message cannot be longer 100 characters! This allows the
to read on github as well as in various git tools. to read on github as well as in various git tools.
### Type ### Type
Must be one of the following: Must be one of the following:
* **feat**: A new feature * **feat**: A new feature
...@@ -122,5 +133,6 @@ generation ...@@ -122,5 +133,6 @@ generation
* **docs**: Documentation only changes * **docs**: Documentation only changes
## License ## License
By contributing to fairscale, you agree that your contributions will be licensed By contributing to fairscale, you agree that your contributions will be licensed
under the LICENSE file in the root directory of this source tree. under the LICENSE file in the root directory of this source tree.
...@@ -10,4 +10,8 @@ __version__ = "0.3.2" ...@@ -10,4 +10,8 @@ __version__ = "0.3.2"
# Import most common subpackages # Import most common subpackages
################################################################################ ################################################################################
from typing import List
from . import nn from . import nn
__all__: List[str] = []
...@@ -7,4 +7,8 @@ ...@@ -7,4 +7,8 @@
# Import most common subpackages # Import most common subpackages
################################################################################ ################################################################################
from typing import List
from . import nn from . import nn
__all__: List[str] = []
...@@ -2,3 +2,7 @@ ...@@ -2,3 +2,7 @@
# #
# This source code is licensed under the BSD license found in the # This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from typing import List
__all__: List[str] = []
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from typing import List
__all__: List[str] = []
...@@ -3,23 +3,12 @@ ...@@ -3,23 +3,12 @@
# This source code is licensed under the BSD license found in the # This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from typing import List
from .data_parallel import FullyShardedDataParallel, ShardedDataParallel from .data_parallel import FullyShardedDataParallel, ShardedDataParallel
from .misc import FlattenParamsWrapper, checkpoint_wrapper from .misc import FlattenParamsWrapper, checkpoint_wrapper
from .moe import MOELayer, Top2Gate from .moe import MOELayer, Top2Gate
from .pipe import Pipe, PipeRPCWrapper from .pipe import Pipe, PipeRPCWrapper
from .wrap import auto_wrap, default_auto_wrap_policy, enable_wrap, wrap from .wrap import auto_wrap, default_auto_wrap_policy, enable_wrap, wrap
__all__ = [ __all__: List[str] = []
"FlattenParamsWrapper",
"checkpoint_wrapper",
"FullyShardedDataParallel",
"ShardedDataParallel",
"Pipe",
"PipeRPCWrapper",
"MOELayer",
"Top2Gate",
"auto_wrap",
"default_auto_wrap_policy",
"enable_wrap",
"wrap",
]
...@@ -3,5 +3,9 @@ ...@@ -3,5 +3,9 @@
# This source code is licensed under the BSD license found in the # This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from typing import List
from .fully_sharded_data_parallel import FullyShardedDataParallel, TrainingState, auto_wrap_bn from .fully_sharded_data_parallel import FullyShardedDataParallel, TrainingState, auto_wrap_bn
from .sharded_ddp import ShardedDataParallel from .sharded_ddp import ShardedDataParallel
__all__: List[str] = []
...@@ -3,6 +3,10 @@ ...@@ -3,6 +3,10 @@
# This source code is licensed under the BSD license found in the # This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from typing import List
from .checkpoint_activations import checkpoint_wrapper from .checkpoint_activations import checkpoint_wrapper
from .flatten_params_wrapper import FlattenParamsWrapper from .flatten_params_wrapper import FlattenParamsWrapper
from .param_bucket import GradBucket, ParamBucket from .param_bucket import GradBucket, ParamBucket
__all__: List[str] = []
...@@ -3,6 +3,8 @@ ...@@ -3,6 +3,8 @@
# This source code is licensed under the BSD license found in the # This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from typing import List
from .cross_entropy import vocab_parallel_cross_entropy from .cross_entropy import vocab_parallel_cross_entropy
from .initialize import ( from .initialize import (
destroy_model_parallel, destroy_model_parallel,
...@@ -20,3 +22,5 @@ from .initialize import ( ...@@ -20,3 +22,5 @@ from .initialize import (
from .layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding from .layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding
from .mappings import copy_to_model_parallel_region, gather_from_model_parallel_region from .mappings import copy_to_model_parallel_region, gather_from_model_parallel_region
from .random import get_cuda_rng_tracker, model_parallel_cuda_manual_seed from .random import get_cuda_rng_tracker, model_parallel_cuda_manual_seed
__all__: List[str] = []
...@@ -3,5 +3,9 @@ ...@@ -3,5 +3,9 @@
# This source code is licensed under the BSD license found in the # This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from typing import List
from .moe_layer import MOELayer from .moe_layer import MOELayer
from .top2gate import Top2Gate from .top2gate import Top2Gate
__all__: List[str] = []
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from typing import List
from .auto_wrap import auto_wrap, default_auto_wrap_policy, enable_wrap, wrap from .auto_wrap import auto_wrap, default_auto_wrap_policy, enable_wrap, wrap
__all__: List[str] = []
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
:mod:`fairscale.optim` is a package implementing various torch optimization algorithms. :mod:`fairscale.optim` is a package implementing various torch optimization algorithms.
""" """
import logging import logging
from typing import List
from .adascale import AdaScale, AdaScaleWrapper from .adascale import AdaScale, AdaScaleWrapper
from .oss import OSS from .oss import OSS
...@@ -19,3 +20,5 @@ try: ...@@ -19,3 +20,5 @@ try:
from .grad_scaler import GradScaler from .grad_scaler import GradScaler
except ImportError: except ImportError:
logging.warning("Torch AMP is not available on this platform") logging.warning("Torch AMP is not available on this platform")
__all__: List[str] = []
...@@ -2,3 +2,7 @@ ...@@ -2,3 +2,7 @@
# #
# This source code is licensed under the BSD license found in the # This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from typing import List
__all__: List[str] = []
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment