Unverified Commit 5739930f authored by anj-s, committed by GitHub

[chore] Rename and move utils.py from optim/ to utils/ (#669)

* rename and move optim/utils.py

* attach the new file
parent 99b30a04
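
For downstream code that imported these helpers from the old location, a minimal compatibility sketch is shown below. The try/except wrapper and the idea of a "downstream module" are illustrative assumptions, not part of this commit; the import paths and symbol names are taken from the hunks that follow.

```python
# Hypothetical downstream module: keep working across fairscale versions
# released before and after this rename (optim/utils.py -> utils/params.py).
try:
    # New location introduced by this commit.
    from fairscale.utils.params import (
        Workhandle,
        broadcast_object,
        calc_grad_norm,
        get_global_rank,
        recursive_copy_to_device,
    )
except ImportError:
    # Older fairscale releases still expose the helpers here.
    from fairscale.optim.utils import (  # noqa
        Workhandle,
        broadcast_object,
        calc_grad_norm,
        get_global_rank,
        recursive_copy_to_device,
    )
```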
@@ -23,9 +23,9 @@ import torch.nn.functional as F
 from fairscale.nn.misc import FlattenParamsWrapper
 from fairscale.nn.wrap import auto_wrap, default_auto_wrap_policy, enable_wrap
-from fairscale.optim.utils import broadcast_object, calc_grad_norm, recursive_copy_to_device
 from fairscale.utils.containers import apply_to_tensors
 from fairscale.utils.parallel import chunk_and_pad, enable_pytorch_sync_bn, validate_process_group
+from fairscale.utils.params import broadcast_object, calc_grad_norm, recursive_copy_to_device
 from fairscale.utils.reduce_scatter_bucketer import ReduceScatterBucketer
 from fairscale.utils.state_dict import replace_by_prefix_
@@ -23,7 +23,7 @@ import torch.distributed as dist
 from fairscale.nn.misc import GradBucket
 from fairscale.optim import OSS
-from fairscale.optim.utils import Workhandle, get_global_rank
+from fairscale.utils.params import Workhandle, get_global_rank

 def _trainable(param: torch.Tensor) -> bool:
@@ -17,8 +17,7 @@ from torch.nn import Parameter
 from torch.optim import SGD, Optimizer

 from fairscale.nn.misc import ParamBucket
-
-from .utils import broadcast_object, calc_grad_norm, get_global_rank, recursive_copy_to_device
+from fairscale.utils.params import broadcast_object, calc_grad_norm, get_global_rank, recursive_copy_to_device

 __all__ = ["OSS"]
@@ -11,7 +11,7 @@ from torch.optim import SGD, Adadelta, Adam  # type: ignore
 from fairscale.nn import FullyShardedDataParallel
 from fairscale.nn.data_parallel.fsdp_optim_utils import is_singleton_tensor
-from fairscale.optim.utils import recursive_copy_to_device
+from fairscale.utils.params import recursive_copy_to_device
 from fairscale.utils.testing import objects_are_equal

 from .test_fsdp import (
@@ -92,7 +92,7 @@ class TestOptimizerUtils(DistributedTest):
         tstart = time()
         sd = fsdp.gather_full_optim_state_dict(fsdp_optim, recipient_rank=0)
         duration = time() - tstart
-        # Switching from fairscale.optim.utils.broadcast_object to torch.broadcast_object_list will cause this to raise
+        # Switching from fairscale.utils.params.broadcast_object to torch.broadcast_object_list will cause this to raise
         assert duration < fsdp.world_size, f"gather optim state took {duration} seconds, suspect change in _consolidate"
         cuda_gb_after = torch.cuda.memory_stats(fsdp.rank)["allocated_bytes.all.current"] / 1024 ** 3
@@ -22,6 +22,7 @@ import torch.multiprocessing as mp
 from torch.nn.parallel import DistributedDataParallel as DDP

 import fairscale.optim as optim
+import fairscale.utils as utils
 from fairscale.utils.testing import (
     check_same_model_params,
     check_same_models_across_ranks,
@@ -40,7 +41,7 @@ try:
     _torch_broadcast_object = True
 except ImportError:
-    from fairscale.optim.utils import broadcast_object  # noqa
+    from fairscale.utils.params import broadcast_object  # noqa

     _torch_broadcast_object = False
@@ -56,7 +57,7 @@ def sync_object_ranks(something_to_sync: Any, reference_rank: int, device: torch
         dist.broadcast_object_list(package, src=reference_rank, group=dist.group.WORLD)
         package_sync = package[0]
     else:
-        package_sync = optim.utils.broadcast_object(
+        package_sync = utils.params.broadcast_object(
             something_to_sync, src_rank=reference_rank, group=dist.group.WORLD, dist_device=device
         )
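
For reference, a self-contained sketch of the fallback pattern exercised in the last hunk, assuming `torch.distributed` has already been initialized. The docstring and the single-element `package` list on non-source ranks are inferred from the calls visible in the diff rather than copied from it.

```python
from typing import Any

import torch
import torch.distributed as dist

try:
    from torch.distributed import broadcast_object_list  # noqa  (available in newer torch)

    _torch_broadcast_object = True
except ImportError:
    # Older torch: fall back to fairscale's helper at its new location.
    from fairscale.utils.params import broadcast_object  # noqa

    _torch_broadcast_object = False


def sync_object_ranks(something_to_sync: Any, reference_rank: int, device: torch.device) -> Any:
    """Broadcast a picklable object from reference_rank to all ranks and return it."""
    if _torch_broadcast_object:
        # Native torch path: the object travels inside a one-element list.
        package = [something_to_sync]
        dist.broadcast_object_list(package, src=reference_rank, group=dist.group.WORLD)
        package_sync = package[0]
    else:
        # fairscale path: broadcast_object handles serialization and device placement.
        package_sync = broadcast_object(
            something_to_sync, src_rank=reference_rank, group=dist.group.WORLD, dist_device=device
        )
    return package_sync
```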