Unverified Commit 5739930f authored by anj-s, committed by GitHub

[chore] Rename and move utils.py from optim/ to utils/ (#669)

* rename and move optim/utils.py

* attach the new file
parent 99b30a04
@@ -23,9 +23,9 @@ import torch.nn.functional as F
 from fairscale.nn.misc import FlattenParamsWrapper
 from fairscale.nn.wrap import auto_wrap, default_auto_wrap_policy, enable_wrap
-from fairscale.optim.utils import broadcast_object, calc_grad_norm, recursive_copy_to_device
 from fairscale.utils.containers import apply_to_tensors
 from fairscale.utils.parallel import chunk_and_pad, enable_pytorch_sync_bn, validate_process_group
+from fairscale.utils.params import broadcast_object, calc_grad_norm, recursive_copy_to_device
 from fairscale.utils.reduce_scatter_bucketer import ReduceScatterBucketer
 from fairscale.utils.state_dict import replace_by_prefix_
@@ -23,7 +23,7 @@ import torch.distributed as dist
 from fairscale.nn.misc import GradBucket
 from fairscale.optim import OSS
-from fairscale.optim.utils import Workhandle, get_global_rank
+from fairscale.utils.params import Workhandle, get_global_rank
 def _trainable(param: torch.Tensor) -> bool:
@@ -17,8 +17,7 @@ from torch.nn import Parameter
 from torch.optim import SGD, Optimizer
 from fairscale.nn.misc import ParamBucket
-from .utils import broadcast_object, calc_grad_norm, get_global_rank, recursive_copy_to_device
+from fairscale.utils.params import broadcast_object, calc_grad_norm, get_global_rank, recursive_copy_to_device
 __all__ = ["OSS"]
@@ -11,7 +11,7 @@ from torch.optim import SGD, Adadelta, Adam # type: ignore
 from fairscale.nn import FullyShardedDataParallel
 from fairscale.nn.data_parallel.fsdp_optim_utils import is_singleton_tensor
-from fairscale.optim.utils import recursive_copy_to_device
+from fairscale.utils.params import recursive_copy_to_device
 from fairscale.utils.testing import objects_are_equal
 from .test_fsdp import (
@@ -92,7 +92,7 @@ class TestOptimizerUtils(DistributedTest):
         tstart = time()
         sd = fsdp.gather_full_optim_state_dict(fsdp_optim, recipient_rank=0)
         duration = time() - tstart
-        # Switching from fairscale.optim.utils.broadcast_object to torch.broadcast_object_list will cause this to raise
+        # Switching from fairscale.utils.params.broadcast_object to torch.broadcast_object_list will cause this to raise
         assert duration < fsdp.world_size, f"gather optim state took {duration} seconds, suspect change in _consolidate"
         cuda_gb_after = torch.cuda.memory_stats(fsdp.rank)["allocated_bytes.all.current"] / 1024 ** 3
@@ -22,6 +22,7 @@ import torch.multiprocessing as mp
 from torch.nn.parallel import DistributedDataParallel as DDP
 import fairscale.optim as optim
+import fairscale.utils as utils
 from fairscale.utils.testing import (
     check_same_model_params,
     check_same_models_across_ranks,
@@ -40,7 +41,7 @@ try:
     _torch_broadcast_object = True
 except ImportError:
-    from fairscale.optim.utils import broadcast_object # noqa
+    from fairscale.utils.params import broadcast_object # noqa
     _torch_broadcast_object = False
@@ -56,7 +57,7 @@ def sync_object_ranks(something_to_sync: Any, reference_rank: int, device: torch
         dist.broadcast_object_list(package, src=reference_rank, group=dist.group.WORLD)
         package_sync = package[0]
     else:
-        package_sync = optim.utils.broadcast_object(
+        package_sync = utils.params.broadcast_object(
             something_to_sync, src_rank=reference_rank, group=dist.group.WORLD, dist_device=device
         )
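For downstream code that imports these helpers directly, a small compatibility shim can absorb the rename. The following is an illustrative sketch, not part of this commit; it only assumes the two module paths shown in the diff above (`fairscale.optim.utils` before #669, `fairscale.utils.params` after), and mirrors the try/except fallback the test file already uses for `torch.distributed.broadcast_object_list`.

```python
# Illustrative compatibility shim (not part of this commit): prefer the new
# location of the shared helpers and fall back to the pre-#669 path, so the
# same code runs against both older and newer fairscale releases.
try:
    from fairscale.utils.params import broadcast_object, calc_grad_norm, recursive_copy_to_device
except ImportError:  # older fairscale releases still ship fairscale.optim.utils
    from fairscale.optim.utils import broadcast_object, calc_grad_norm, recursive_copy_to_device
```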