Unverified Commit 72c6bab2 authored by anj-s, committed by GitHub

[chore] Rename and move checkpoint_activations from misc folder. (#654)

* rename files

* add newly renamed file

* rename and move checkpoint activations related files

* add test files to ci list

* fix lint errors

* modify docs

* add changelog

* retain old path for now

* fix lint errors

* add another import test case

* fix merge conflict

* add missing test file
parent c141f8db
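The net effect of this change: `checkpoint_wrapper` now lives in `fairscale.nn.checkpoint`, the top-level `fairscale.nn` re-export is kept, and `fairscale.nn.misc` retains a deprecated alias for now (the old deep module path `fairscale.nn.misc.checkpoint_activations` goes away). A minimal import sketch of what this enables, assuming fairscale at this revision; it is not part of the diff itself:

```python
# All three names resolve to the same wrapper after this commit.
from fairscale.nn.checkpoint import checkpoint_wrapper  # new canonical location
from fairscale.nn import checkpoint_wrapper as nn_checkpoint_wrapper  # unchanged top-level re-export
from fairscale.nn.misc import checkpoint_wrapper as deprecated_checkpoint_wrapper  # deprecated alias, retained for now

assert checkpoint_wrapper is nn_checkpoint_wrapper is deprecated_checkpoint_wrapper
```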
......@@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- FSDP: workaround AMP autocast cache issue with clear\_autocast\_cache flag
- setup.py: hide CUDA extensions behind BUILD_CUDA_EXTENSIONS envvar
- SDP: re-expose the module property ([#647](https://github.com/facebookresearch/fairscale/pull/647))
- Cleanup: rename and move the checkpoint_activations wrapper ([#654](https://github.com/facebookresearch/fairscale/pull/654))
### Added
- FSDP: added `force\_input\_to\_fp32` flag for SyncBatchNorm
......
......@@ -11,5 +11,5 @@ API Reference
nn/sharded_ddp
nn/fsdp
nn/fsdp_tips
nn/misc/checkpoint_activations
nn/checkpoint/checkpoint_activations
experimental/nn/offload_model
checkpoint_wrapper
==================
.. autoclass:: fairscale.nn.misc.checkpoint_wrapper
.. autoclass:: fairscale.nn.checkpoint.checkpoint_wrapper
    :members:
    :undoc-members:
......@@ -5,8 +5,9 @@
from typing import List
from .checkpoint import checkpoint_wrapper
from .data_parallel import FullyShardedDataParallel, ShardedDataParallel
from .misc import FlattenParamsWrapper, checkpoint_wrapper
from .misc import FlattenParamsWrapper
from .moe import MOELayer, Top2Gate
from .pipe import Pipe, PipeRPCWrapper
from .wrap import auto_wrap, default_auto_wrap_policy, enable_wrap, wrap
......
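Because `fairscale.nn` now pulls the wrapper from the new `checkpoint` subpackage, top-level usage is unchanged. A small, self-contained usage sketch (layer sizes and input shape are illustrative, not taken from this diff):

```python
import torch
import torch.nn as nn

from fairscale.nn import checkpoint_wrapper

# Wrap a block so its activations are recomputed in backward instead of being stored.
ffn = checkpoint_wrapper(nn.Sequential(nn.Linear(32, 128), nn.ReLU(), nn.Linear(128, 32)))

x = torch.randn(8, 32, requires_grad=True)
loss = ffn(x).sum()
loss.backward()  # the wrapper re-runs the forward here to rebuild the saved activations
```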
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from typing import List
from .checkpoint_activations import checkpoint_wrapper
__all__: List[str] = []
......@@ -5,7 +5,10 @@
from typing import List
from .checkpoint_activations import checkpoint_wrapper
# TODO(anj-s): Remove this once we have deprecated fairscale.nn.misc.checkpoint_wrapper path
# in favor of fairscale.nn.checkpoint.checkpoint_wrapper.
from fairscale.nn.checkpoint import checkpoint_wrapper
from .flatten_params_wrapper import FlattenParamsWrapper
from .param_bucket import GradBucket, ParamBucket
......
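Since the old path is only retained for now (see the TODO above), downstream code that needs to work on both older and newer fairscale releases may prefer a guarded import. A sketch of that pattern; it is not part of fairscale itself:

```python
try:
    # Location introduced by this change.
    from fairscale.nn.checkpoint import checkpoint_wrapper
except ImportError:
    # Older fairscale releases only expose the wrapper under fairscale.nn.misc.
    from fairscale.nn.misc import checkpoint_wrapper
```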
tests/nn/misc/test_checkpoint_activations.py
tests/nn/misc/test_checkpoint_activations_norm.py
tests/nn/data_parallel/test_fsdp_overlap.py
tests/nn/data_parallel/test_fsdp_multiple_forward.py
tests/nn/data_parallel/test_fsdp_apply.py
......@@ -8,6 +6,10 @@ tests/utils/test_reduce_scatter_bucketer.py
tests/utils/test_containers.py
tests/utils/test_parallel.py
tests/utils/test_state_dict.py
tests/nn/checkpoint/test_checkpoint_activations.py
tests/nn/checkpoint/test_checkpoint_activations_norm.py
tests/nn/misc/test_grad_bucket.py
tests/nn/misc/test_param_bucket.py
tests/nn/wrap/test_wrap.py
tests/nn/pipe_process/test_pipe.py
tests/nn/pipe_process/test_transparency.py
......
......@@ -10,7 +10,8 @@ import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint as torch_checkpoint_wrapper
from fairscale.nn.misc.checkpoint_activations import checkpoint_wrapper
from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
from fairscale.nn.misc import checkpoint_wrapper as deprecated_checkpoint_wrapper
from fairscale.utils.testing import skip_if_no_cuda, torch_version
......@@ -252,3 +253,17 @@ def test_multiin_multiout(device, multiout, checkpoint_config):
if no_cpt[key] != cpt[key]:
    print(no_cpt, cpt)
    assert 0
def test_deprecated_path():
    # Check if import works as before.
    # from fairscale.nn.misc.checkpoint_activations import checkpoint_wrapper
    from fairscale.nn import checkpoint_wrapper

    ffn = nn.Sequential(nn.Linear(32, 128), nn.Dropout(p=0.5), nn.Linear(128, 32),)
    ffn = checkpoint_wrapper(ffn, {})

    # Check if direct import works as before.
    ffn = nn.Sequential(nn.Linear(32, 128), nn.Dropout(p=0.5), nn.Linear(128, 32),)
    ffn = deprecated_checkpoint_wrapper(ffn, {})
......@@ -14,7 +14,7 @@ import torch
from torch.nn import BatchNorm2d, LayerNorm, Linear, Sequential
from torch.optim import SGD
from fairscale.nn.misc.checkpoint_activations import checkpoint_wrapper
from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
from fairscale.utils.testing import objects_are_equal, torch_version
NORM_TYPES = [LayerNorm, BatchNorm2d]
......
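The norm test above now imports the wrapper from its new module path while still covering normalization layers. A self-contained sketch of that kind of usage, with illustrative sizes and a BatchNorm1d stand-in for the norm types exercised by the test:

```python
import torch
from torch.nn import BatchNorm1d, Linear, ReLU, Sequential
from torch.optim import SGD

from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper

# A small block containing a norm layer, wrapped for activation checkpointing.
model = checkpoint_wrapper(Sequential(Linear(16, 32), BatchNorm1d(32), ReLU(), Linear(32, 4)))
optim = SGD(model.parameters(), lr=0.1)

x = torch.randn(8, 16)
loss = model(x).sum()
loss.backward()
optim.step()
```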
......@@ -16,8 +16,8 @@ from parameterized import parameterized
import torch
from torch import nn
from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
from fairscale.nn.data_parallel import FullyShardedDataParallel, TrainingState
from fairscale.nn.misc.checkpoint_activations import checkpoint_wrapper
from fairscale.utils.testing import (
DeviceAndTypeCheckModule,
DummyProcessGroup,
......