Unverified Commit 2350968e authored by Crutcher Dunnavant, committed by GitHub

Move f/utils => f/internal; move testing libs to fair_dev/testing (#1004)

parent 3b727945
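This commit is a mechanical rename of import paths across the tree: fairscale.utils.* becomes fairscale.internal.*, and the shared test helpers move out of the fairscale package into fair_dev.testing.testing. As an illustrative composite of the hunks below (not a verbatim excerpt from any single file; file names are omitted in this view), the before/after pattern looks like:

# Before (old import locations removed by this commit):
#   from fairscale.utils import torch_version
#   from fairscale.utils.object import pyobject_to_tensor, tensor_to_pyobject
#   from fairscale.utils.params import calc_grad_norm, get_global_rank, recursive_copy_to_device
#   from fairscale.utils.parallel import get_process_group_cached
#   from fairscale.utils.testing import dist_init, skip_if_no_cuda

# After: non-test helpers now live under fairscale.internal ...
from fairscale.internal import torch_version
from fairscale.internal.object import pyobject_to_tensor, tensor_to_pyobject
from fairscale.internal.params import calc_grad_norm, get_global_rank, recursive_copy_to_device
from fairscale.internal.parallel import get_process_group_cached

# ... and test-only utilities move to the separate fair_dev package.
from fair_dev.testing.testing import dist_init, skip_if_no_cuda

Out-of-tree code that imports from fairscale.utils would presumably need the same one-line substitutions.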
@@ -11,8 +11,8 @@ from typing import Dict, List, Optional
 import torch
+from fairscale.internal.object import pyobject_to_tensor, tensor_to_pyobject
 from fairscale.nn.model_parallel import get_pipeline_parallel_group
-from fairscale.utils.object import pyobject_to_tensor, tensor_to_pyobject
 from .types import MESSAGE_GENERATION_START, InputDevice, PipeMessage, Tensors
...
@@ -27,7 +27,7 @@ from torch import Tensor, nn
 import torch.autograd
 import torch.cuda
-from fairscale.utils import torch_version
+from fairscale.internal import torch_version
 from . import microbatch
 from .batchnorm import DeferredBatchNorm
...
@@ -18,7 +18,7 @@ import torch.distributed as dist
 from torch.optim import Optimizer
 from torch.optim.sgd import SGD
-from fairscale.utils import torch_version
+from fairscale.internal import torch_version
 class _GeneralMultiDeviceReplicator(object):
...
@@ -17,8 +17,8 @@ import torch.distributed as dist
 from torch.nn import Parameter
 from torch.optim import SGD, Optimizer
+from fairscale.internal.params import calc_grad_norm, get_global_rank, recursive_copy_to_device
 from fairscale.nn.misc import ParamBucket
-from fairscale.utils.params import calc_grad_norm, get_global_rank, recursive_copy_to_device
 __all__ = ["OSS"]
...
@@ -22,8 +22,8 @@ from torch import nn
 from torch.optim.optimizer import Optimizer
 from torch.utils.data import DataLoader, Dataset
+from fair_dev.testing.testing import get_worker_map, torch_spawn
 from fairscale.experimental.nn.ampnet_pipe.pipe import AMPnetPipe
-from fairscale.utils.testing import get_worker_map, torch_spawn
 class MySGD(Optimizer):
...
@@ -15,8 +15,8 @@ from torch import nn
 import torch.distributed
 import torch.nn.functional as F
+from fair_dev.testing.testing import skip_if_single_gpu, spawn_for_all_world_sizes
 import fairscale.experimental.nn.data_parallel.gossip as gossip
-from fairscale.utils.testing import skip_if_single_gpu, spawn_for_all_world_sizes
 # Enforce CUBLAS reproducibility, see https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility
 os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
...
@@ -14,7 +14,7 @@ import torch
 import torch.nn
 import torch.nn as nn
-from fairscale.utils import torch_version
+from fairscale.internal import torch_version
 class PositionalEncoding(nn.Module):
...
@@ -12,9 +12,9 @@ import os
 import pytest
 import torch
+from fair_dev.testing.testing import skip_if_no_cuda
 from fairscale.experimental.nn import MEVO
 from fairscale.experimental.nn.mevo import BaselineSoftmaxNllLoss, get_data
-from fairscale.utils.testing import skip_if_no_cuda
 @pytest.fixture(scope="session", params=[torch.float16, torch.float32])
...
@@ -20,9 +20,9 @@ import torch.distributed.rpc as rpc
 import torch.multiprocessing as mp
 import torch.nn as nn
+from fair_dev.testing.testing import skip_if_single_gpu
 from fairscale.experimental.nn.distributed_pipeline import DistributedLoss, DistributedPipeline, PipelineModulesGraph
-from fairscale.utils import torch_version
-from fairscale.utils.testing import skip_if_single_gpu
+from fairscale.internal import torch_version
 pytestmark = pytest.mark.skipif(
     not torch.cuda.is_available() or torch_version() < (1, 9, 0),
...
@@ -14,9 +14,9 @@ import numpy as np
 import pytest
 import torch
+from fair_dev.testing.testing import skip_if_no_cuda
 from fairscale.experimental.nn.offload import OffloadModel
-from fairscale.utils import torch_version
-from fairscale.utils.testing import skip_if_no_cuda
+from fairscale.internal import torch_version
 if torch_version() >= (1, 8, 0):
     from fairscale.experimental.nn.auto_shard import shard_model
...
@@ -10,13 +10,13 @@ import torch.multiprocessing as mp
 import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel
+from fair_dev.testing.testing import GPT2, dist_init, skip_if_no_cuda, skip_if_single_gpu, temp_files_ctx
 from fairscale.experimental.tooling.layer_memory_tracker import (
     LayerwiseMemoryTracker,
     ProcessGroupTracker,
     find_best_reset_points,
 )
 from fairscale.nn import FullyShardedDataParallel
-from fairscale.utils.testing import GPT2, dist_init, skip_if_no_cuda, skip_if_single_gpu, temp_files_ctx
 @skip_if_no_cuda()
...
@@ -10,11 +10,11 @@ import torch
 import torch.nn as nn
 from torch.utils.checkpoint import checkpoint as torch_checkpoint_wrapper
+from fair_dev.testing.testing import skip_if_no_cuda
+from fairscale.internal import torch_version
 from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper, disable_checkpointing
 from fairscale.nn.misc import FlattenParamsWrapper
 from fairscale.nn.misc import checkpoint_wrapper as deprecated_checkpoint_wrapper
-from fairscale.utils import torch_version
-from fairscale.utils.testing import skip_if_no_cuda
 def get_cuda_mem_allocated():
...
@@ -14,9 +14,9 @@ import torch
 from torch.nn import BatchNorm2d, LayerNorm, Linear, Sequential
 from torch.optim import SGD
+from fair_dev.testing.testing import objects_are_equal
+from fairscale.internal import torch_version
 from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
-from fairscale.utils import torch_version
-from fairscale.utils.testing import objects_are_equal
 NORM_TYPES = [LayerNorm, BatchNorm2d]
 MP_TYPES = ["fp32", "fp16", "call_half"]
...
@@ -18,10 +18,7 @@ import torch
 from torch import nn
 import torch.distributed
-from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
-from fairscale.nn.data_parallel import FullyShardedDataParallel, TrainingState
-from fairscale.utils import torch_version
-from fairscale.utils.testing import (
+from fair_dev.testing.testing import (
     DeviceAndTypeCheckModule,
     DummyProcessGroup,
     dist_init,
@@ -30,6 +27,9 @@ from fairscale.utils.testing import (
     skip_a_test_if_in_CI,
     spawn_for_all_world_sizes,
 )
+from fairscale.internal import torch_version
+from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
+from fairscale.nn.data_parallel import FullyShardedDataParallel, TrainingState
 if torch_version() >= (1, 8, 0):
     from fairscale.optim.grad_scaler import ShardedGradScaler
...
@@ -10,7 +10,7 @@ from parameterized import parameterized
 import pytest
 import torch.nn as nn
-from fairscale.utils import torch_version
+from fairscale.internal import torch_version
 from .test_fsdp import (
     CONFIG_OPTIONS,
...
@@ -21,8 +21,8 @@ import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel
 import torch.optim as optim
+from fair_dev.testing.testing import dist_init, objects_are_equal, rmf, skip_if_single_gpu, teardown
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
-from fairscale.utils.testing import dist_init, objects_are_equal, rmf, skip_if_single_gpu, teardown
 class FreezeModel(nn.Module):
...
@@ -12,8 +12,8 @@ from unittest.mock import patch
 from parameterized import parameterized
 import torch
+from fair_dev.testing.testing import DummyProcessGroup, make_cudnn_deterministic, objects_are_equal
 from fairscale.nn.data_parallel import FullyShardedDataParallel
-from fairscale.utils.testing import DummyProcessGroup, make_cudnn_deterministic, objects_are_equal
 from .test_fsdp import DistributedTest, NestedWrappedModule, rename_test, spawn_and_init
...
@@ -6,9 +6,9 @@ import unittest
 import torch
 from torch import nn
+from fair_dev.testing.testing import dist_init
 from fairscale.nn import FullyShardedDataParallel as FSDP
 from fairscale.nn import auto_wrap, enable_wrap
-from fairscale.utils.testing import dist_init
 def wrap_transformer_only(module, recurse, **kwargs):
...
@@ -16,10 +16,10 @@ import torch
 from torch.nn import Linear, Module
 from torch.optim import SGD
+from fair_dev.testing.testing import dist_init, rmf, skip_if_no_cuda, teardown
+from fairscale.internal import torch_version
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
 from fairscale.nn.data_parallel import TrainingState
-from fairscale.utils import torch_version
-from fairscale.utils.testing import dist_init, rmf, skip_if_no_cuda, teardown
 # A fixture to get tempfiles and ensure they are cleaned up.
...
@@ -18,12 +18,12 @@ import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel
 import torch.optim as optim
+from fair_dev.testing.testing import dist_init, dump_all_tensors, skip_if_single_gpu, teardown, temp_files_ctx
+from fairscale.internal import torch_version
+from fairscale.internal.parallel import get_process_group_cached
 from fairscale.nn import checkpoint_wrapper
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
 from fairscale.nn.data_parallel import auto_wrap_bn
-from fairscale.utils import torch_version
-from fairscale.utils.parallel import get_process_group_cached
-from fairscale.utils.testing import dist_init, dump_all_tensors, skip_if_single_gpu, teardown, temp_files_ctx
 def to_fsdp(module, fsdp_config):
...