Unverified Commit 72c6bab2 authored by anj-s, committed by GitHub

[chore] Rename and move checkpoint_activations from misc folder. (#654)

* rename files

* add newly renamed file

* rename and move checkpoint activations related files

* add test files to ci list

* fix lint errors

* modify docs

* add changelog

* retain old path for now

* fix lint errors

* add another import test case

* fix merge conflict

* add missing test file
parent c141f8db
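
For context, a minimal sketch (not part of this commit) of what the move means for downstream imports: the wrapper now lives under `fairscale.nn.checkpoint`, while the old `fairscale.nn.misc` path is retained for now, as the diffs below show. The module shapes and the `{}` second argument simply mirror the updated tests in this PR and are illustrative only.

```python
# Illustrative sketch only; mirrors the import paths exercised by the tests in this PR.
import torch.nn as nn

# New canonical location introduced by this change.
from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper

# Old path, retained for now via fairscale/nn/misc/__init__.py.
from fairscale.nn.misc import checkpoint_wrapper as deprecated_checkpoint_wrapper

ffn = nn.Sequential(nn.Linear(32, 128), nn.Dropout(p=0.5), nn.Linear(128, 32))
ffn = checkpoint_wrapper(ffn, {})  # same call pattern as the updated tests

legacy_ffn = nn.Sequential(nn.Linear(32, 128), nn.Dropout(p=0.5), nn.Linear(128, 32))
legacy_ffn = deprecated_checkpoint_wrapper(legacy_ffn, {})
```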
@@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - FSDP: workaround AMP autocast cache issue with clear\_autocast\_cache flag
 - setup.py: hide CUDA extensions behind BUILD_CUDA_EXTENSIONS envvar
 - SDP: re-expose the module property ([#647](https://github.com/facebookresearch/fairscale/pull/647))
+- Cleanup - rename and move the checkpoint_activations wrapper ([#654](https://github.com/facebookresearch/fairscale/pull/654))

 ### Added
 - FSDP: added `force\_input\_to\_fp32` flag for SyncBatchNorm
@@ -11,5 +11,5 @@ API Reference
 nn/sharded_ddp
 nn/fsdp
 nn/fsdp_tips
-nn/misc/checkpoint_activations
+nn/checkpoint/checkpoint_activations
 experimental/nn/offload_model
 checkpoint_wrapper
 ==================
-.. autoclass:: fairscale.nn.misc.checkpoint_wrapper
+.. autoclass:: fairscale.nn.checkpoint.checkpoint_wrapper
     :members:
     :undoc-members:
@@ -5,8 +5,9 @@
 from typing import List

+from .checkpoint import checkpoint_wrapper
 from .data_parallel import FullyShardedDataParallel, ShardedDataParallel
-from .misc import FlattenParamsWrapper, checkpoint_wrapper
+from .misc import FlattenParamsWrapper
 from .moe import MOELayer, Top2Gate
 from .pipe import Pipe, PipeRPCWrapper
 from .wrap import auto_wrap, default_auto_wrap_policy, enable_wrap, wrap
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from typing import List

from .checkpoint_activations import checkpoint_wrapper

__all__: List[str] = []
@@ -5,7 +5,10 @@
 from typing import List

-from .checkpoint_activations import checkpoint_wrapper
+# TODO(anj-s): Remove this once we have deprecated fairscale.nn.misc.checkpoint_wrapper path
+# in favor of fairscale.nn.checkpoint.checkpoint_wrapper.
+from fairscale.nn.checkpoint import checkpoint_wrapper
 from .flatten_params_wrapper import FlattenParamsWrapper
 from .param_bucket import GradBucket, ParamBucket
-tests/nn/misc/test_checkpoint_activations.py
-tests/nn/misc/test_checkpoint_activations_norm.py
 tests/nn/data_parallel/test_fsdp_overlap.py
 tests/nn/data_parallel/test_fsdp_multiple_forward.py
 tests/nn/data_parallel/test_fsdp_apply.py
@@ -8,6 +6,10 @@ tests/utils/test_reduce_scatter_bucketer.py
 tests/utils/test_containers.py
 tests/utils/test_parallel.py
 tests/utils/test_state_dict.py
+tests/nn/checkpoint/test_checkpoint_activations.py
+tests/nn/checkpoint/test_checkpoint_activations_norm.py
+tests/nn/misc/test_grad_bucket.py
+tests/nn/misc/test_param_bucket.py
 tests/nn/wrap/test_wrap.py
 tests/nn/pipe_process/test_pipe.py
 tests/nn/pipe_process/test_transparency.py
@@ -10,7 +10,8 @@ import torch
 import torch.nn as nn
 from torch.utils.checkpoint import checkpoint as torch_checkpoint_wrapper

-from fairscale.nn.misc.checkpoint_activations import checkpoint_wrapper
+from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
+from fairscale.nn.misc import checkpoint_wrapper as deprecated_checkpoint_wrapper
 from fairscale.utils.testing import skip_if_no_cuda, torch_version
@@ -252,3 +253,17 @@ def test_multiin_multiout(device, multiout, checkpoint_config):
         if no_cpt[key] != cpt[key]:
             print(no_cpt, cpt)
             assert 0
+
+
+def test_deprecated_path():
+    # Check if import works as before.
+    # from fairscale.nn.misc.checkpoint_activations import checkpoint_wrapper
+    from fairscale.nn import checkpoint_wrapper
+
+    ffn = nn.Sequential(nn.Linear(32, 128), nn.Dropout(p=0.5), nn.Linear(128, 32),)
+    ffn = checkpoint_wrapper(ffn, {})
+
+    # Check if direct import works as before.
+    ffn = nn.Sequential(nn.Linear(32, 128), nn.Dropout(p=0.5), nn.Linear(128, 32),)
+    ffn = deprecated_checkpoint_wrapper(ffn, {})
@@ -14,7 +14,7 @@ import torch
 from torch.nn import BatchNorm2d, LayerNorm, Linear, Sequential
 from torch.optim import SGD

-from fairscale.nn.misc.checkpoint_activations import checkpoint_wrapper
+from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
 from fairscale.utils.testing import objects_are_equal, torch_version

 NORM_TYPES = [LayerNorm, BatchNorm2d]
@@ -16,8 +16,8 @@ from parameterized import parameterized
 import torch
 from torch import nn

+from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
 from fairscale.nn.data_parallel import FullyShardedDataParallel, TrainingState
-from fairscale.nn.misc.checkpoint_activations import checkpoint_wrapper
 from fairscale.utils.testing import (
     DeviceAndTypeCheckModule,
     DummyProcessGroup,