Unverified commit 2350968e authored by Crutcher Dunnavant, committed by GitHub

Move f/utils => f/internal; move testing libs to fair_dev/testing (#1004)

parent 3b727945
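For downstream code, the change is a pure rename of import paths: fairscale.utils.* becomes fairscale.internal.*, and the testing helpers move out of the package into fair_dev.testing.testing. A minimal sketch of the migration, assuming a caller that only uses these helpers (only the module paths are taken from the hunks below; the surrounding comments are illustrative):

# Before this commit (old paths, now removed):
#   from fairscale.utils import torch_version
#   from fairscale.utils.state_dict import replace_by_prefix_
#   from fairscale.utils.testing import dist_init, get_worker_map

# After this commit (new paths):
from fairscale.internal import torch_version
from fairscale.internal.state_dict import replace_by_prefix_
from fair_dev.testing.testing import dist_init, get_worker_map

# The imported names themselves are unchanged, so only the import lines need editing.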
@@ -21,12 +21,12 @@ from torch.utils.data import DataLoader
import torchtext
from torchtext.data.utils import get_tokenizer
+from fair_dev.testing.testing import dist_init, get_worker_map
from fairscale.experimental.nn.ampnet_pipe import pipe
from fairscale.nn.model_parallel import initialize_model_parallel
from fairscale.nn.model_parallel.initialize import get_pipeline_parallel_group
from fairscale.nn.pipe import LazyModule
from fairscale.optim import GradScaler
-from fairscale.utils.testing import dist_init, get_worker_map
try:
from fairscale.optim import Adam # type: ignore
......
@@ -16,9 +16,9 @@ from torch.nn.parallel import DistributedDataParallel as DDP
import utils
from benchmarks.golden_configs.lm_wikitext2 import Pipe as lm_wikitext2
+from fair_dev.testing.testing import dist_init
from fairscale.nn import Pipe
from fairscale.nn.model_parallel import initialize_model_parallel
-from fairscale.utils.testing import dist_init
MPI_PORT = 29500
RPC_PORT = 29501
......
@@ -49,9 +49,9 @@ from torch.distributed import rpc
import torch.multiprocessing as mp
import torch.nn as nn
+from fairscale.internal import torch_version
from fairscale.nn.model_parallel import destroy_model_parallel, initialize_model_parallel
from fairscale.nn.model_parallel.random import model_parallel_cuda_manual_seed
-from fairscale.utils import torch_version
if TYPE_CHECKING:
Base = nn.Module[Tensor]
......
@@ -10,8 +10,8 @@ import torch
from torch import Tensor, nn
from torch.distributed import rpc
+from fairscale.internal import torch_version
from fairscale.nn.pipe import microbatch
-from fairscale.utils import torch_version
from .data import DataConsumer
from .graph import Node, PipelineModulesGraph
......
@@ -17,7 +17,7 @@ import numpy as np
import torch
from torch.serialization import DEFAULT_PROTOCOL as DEFAULT_PROTOCOL
-from fairscale.utils import torch_version
+from fairscale.internal import torch_version
try:
from torch.utils._pytree import tree_map
......
@@ -10,8 +10,8 @@ from torch import Tensor
import torch.distributed as dist
from torch.distributed import ProcessGroup
+from fairscale.internal import torch_version
from fairscale.nn.checkpoint import is_checkpointing, is_recomputing
-from fairscale.utils import torch_version
def _forward(input: Tensor, affine: bool, mean: Tensor, invstd: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
......
@@ -14,7 +14,7 @@ from torch import Tensor
import torch.nn as nn
import torch.utils.checkpoint as torch_checkpoint
-from fairscale.utils.containers import pack_kwargs, split_non_tensors, unpack_kwargs, unpack_non_tensors
+from fairscale.internal.containers import pack_kwargs, split_non_tensors, unpack_kwargs, unpack_non_tensors
from .checkpoint_utils import patch_batchnorm
......
@@ -40,19 +40,19 @@ import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
-from fairscale.nn.misc import FlattenParamsWrapper
-from fairscale.nn.wrap import auto_wrap, config_auto_wrap_policy, enable_wrap
-from fairscale.utils.containers import apply_to_tensors
-from fairscale.utils.parallel import (
+from fairscale.internal.containers import apply_to_tensors
+from fairscale.internal.parallel import (
ProcessGroupName,
chunk_and_pad,
enable_pytorch_sync_bn,
get_process_group_cached,
validate_process_group,
)
-from fairscale.utils.params import calc_grad_norm, recursive_copy_to_device
-from fairscale.utils.reduce_scatter_bucketer import ReduceScatterBucketer
-from fairscale.utils.state_dict import replace_by_prefix_
+from fairscale.internal.params import calc_grad_norm, recursive_copy_to_device
+from fairscale.internal.reduce_scatter_bucketer import ReduceScatterBucketer
+from fairscale.internal.state_dict import replace_by_prefix_
+from fairscale.nn.misc import FlattenParamsWrapper
+from fairscale.nn.wrap import auto_wrap, config_auto_wrap_policy, enable_wrap
from . import fsdp_optim_utils as ou
......
@@ -21,9 +21,9 @@ from torch.autograd import Variable
import torch.autograd.profiler as profiler
import torch.distributed as dist
+from fairscale.internal.params import Workhandle, get_global_rank
from fairscale.nn.misc import GradBucket
from fairscale.optim import OSS
-from fairscale.utils.params import Workhandle, get_global_rank
def _trainable(param: torch.Tensor) -> bool:
......
@@ -44,7 +44,7 @@ except ImportError:
import_ssd_offload = False
pass
-from fairscale.utils.state_dict import replace_by_prefix_
+from fairscale.internal.state_dict import replace_by_prefix_
if TYPE_CHECKING:
from collections import OrderedDict # noqa: F401
......