Unverified commit 2350968e authored by Crutcher Dunnavant, committed by GitHub

Move f/utils => f/internal; move testing libs to fair_dev/testing (#1004)

parent 3b727945
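
For downstream code, the change is purely a matter of import paths: general-purpose helpers move from fairscale.utils.* to fairscale.internal.*, and the test helpers move out of the installed package into fair_dev/testing. A minimal sketch of the rename, using only module and symbol names that appear in the hunks below (the fair_dev package presumably ships with the source tree rather than the pip-installed fairscale, so it is an assumption that it is only importable from a repo checkout):

```python
# Before this commit:
#   from fairscale.utils import torch_version
#   from fairscale.utils.containers import pack_kwargs
#   from fairscale.utils.testing import dist_init, get_worker_map

# After this commit:
from fairscale.internal import torch_version                     # version helper, now under fairscale.internal
from fairscale.internal.containers import pack_kwargs            # containers/params/parallel/... follow the same rename
from fair_dev.testing.testing import dist_init, get_worker_map   # test helpers, now in the fair_dev dev-only package
```
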
@@ -21,12 +21,12 @@ from torch.utils.data import DataLoader
 import torchtext
 from torchtext.data.utils import get_tokenizer
 
+from fair_dev.testing.testing import dist_init, get_worker_map
 from fairscale.experimental.nn.ampnet_pipe import pipe
 from fairscale.nn.model_parallel import initialize_model_parallel
 from fairscale.nn.model_parallel.initialize import get_pipeline_parallel_group
 from fairscale.nn.pipe import LazyModule
 from fairscale.optim import GradScaler
-from fairscale.utils.testing import dist_init, get_worker_map
 
 try:
     from fairscale.optim import Adam  # type: ignore
@@ -16,9 +16,9 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 import utils
 
 from benchmarks.golden_configs.lm_wikitext2 import Pipe as lm_wikitext2
+from fair_dev.testing.testing import dist_init
 from fairscale.nn import Pipe
 from fairscale.nn.model_parallel import initialize_model_parallel
-from fairscale.utils.testing import dist_init
 
 MPI_PORT = 29500
 RPC_PORT = 29501
@@ -49,9 +49,9 @@ from torch.distributed import rpc
 import torch.multiprocessing as mp
 import torch.nn as nn
 
+from fairscale.internal import torch_version
 from fairscale.nn.model_parallel import destroy_model_parallel, initialize_model_parallel
 from fairscale.nn.model_parallel.random import model_parallel_cuda_manual_seed
-from fairscale.utils import torch_version
 
 if TYPE_CHECKING:
     Base = nn.Module[Tensor]
@@ -10,8 +10,8 @@ import torch
 from torch import Tensor, nn
 from torch.distributed import rpc
 
+from fairscale.internal import torch_version
 from fairscale.nn.pipe import microbatch
-from fairscale.utils import torch_version
 
 from .data import DataConsumer
 from .graph import Node, PipelineModulesGraph
@@ -17,7 +17,7 @@ import numpy as np
 import torch
 from torch.serialization import DEFAULT_PROTOCOL as DEFAULT_PROTOCOL
 
-from fairscale.utils import torch_version
+from fairscale.internal import torch_version
 
 try:
     from torch.utils._pytree import tree_map
@@ -10,8 +10,8 @@ from torch import Tensor
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
 
+from fairscale.internal import torch_version
 from fairscale.nn.checkpoint import is_checkpointing, is_recomputing
-from fairscale.utils import torch_version
 
 
 def _forward(input: Tensor, affine: bool, mean: Tensor, invstd: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
@@ -14,7 +14,7 @@ from torch import Tensor
 import torch.nn as nn
 import torch.utils.checkpoint as torch_checkpoint
 
-from fairscale.utils.containers import pack_kwargs, split_non_tensors, unpack_kwargs, unpack_non_tensors
+from fairscale.internal.containers import pack_kwargs, split_non_tensors, unpack_kwargs, unpack_non_tensors
 
 from .checkpoint_utils import patch_batchnorm
@@ -40,19 +40,19 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn.parameter import Parameter
 
-from fairscale.nn.misc import FlattenParamsWrapper
-from fairscale.nn.wrap import auto_wrap, config_auto_wrap_policy, enable_wrap
-from fairscale.utils.containers import apply_to_tensors
-from fairscale.utils.parallel import (
+from fairscale.internal.containers import apply_to_tensors
+from fairscale.internal.parallel import (
     ProcessGroupName,
     chunk_and_pad,
     enable_pytorch_sync_bn,
     get_process_group_cached,
     validate_process_group,
 )
-from fairscale.utils.params import calc_grad_norm, recursive_copy_to_device
-from fairscale.utils.reduce_scatter_bucketer import ReduceScatterBucketer
-from fairscale.utils.state_dict import replace_by_prefix_
+from fairscale.internal.params import calc_grad_norm, recursive_copy_to_device
+from fairscale.internal.reduce_scatter_bucketer import ReduceScatterBucketer
+from fairscale.internal.state_dict import replace_by_prefix_
+from fairscale.nn.misc import FlattenParamsWrapper
+from fairscale.nn.wrap import auto_wrap, config_auto_wrap_policy, enable_wrap
 
 from . import fsdp_optim_utils as ou
@@ -21,9 +21,9 @@ from torch.autograd import Variable
 import torch.autograd.profiler as profiler
 import torch.distributed as dist
 
+from fairscale.internal.params import Workhandle, get_global_rank
 from fairscale.nn.misc import GradBucket
 from fairscale.optim import OSS
-from fairscale.utils.params import Workhandle, get_global_rank
 
 
 def _trainable(param: torch.Tensor) -> bool:
@@ -44,7 +44,7 @@ except ImportError:
     import_ssd_offload = False
     pass
 
-from fairscale.utils.state_dict import replace_by_prefix_
+from fairscale.internal.state_dict import replace_by_prefix_
 
 if TYPE_CHECKING:
     from collections import OrderedDict  # noqa: F401
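
Projects that need to run against fairscale trees on both sides of this move can fall back at import time. A hedged sketch of that pattern; it is not part of fairscale itself, and only the module paths are taken from the hunks above:

```python
# Prefer the post-move locations, fall back to the pre-move ones.
try:
    from fairscale.internal import torch_version
    from fairscale.internal.state_dict import replace_by_prefix_
except ImportError:  # older fairscale: helpers still live under fairscale.utils
    from fairscale.utils import torch_version
    from fairscale.utils.state_dict import replace_by_prefix_

try:
    # fair_dev lives in the source tree rather than the installed package,
    # so this import is expected to work only in a repo checkout.
    from fair_dev.testing.testing import dist_init, get_worker_map
except ImportError:  # older trees keep the test helpers in fairscale.utils.testing
    from fairscale.utils.testing import dist_init, get_worker_map
```
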