Unverified Commit 2350968e authored by Crutcher Dunnavant, committed by GitHub

Move f/utils => f/internal; move testing libs to fair_dev/testing (#1004)

parent 3b727945
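For downstream code that imports these helpers, the commit is a pure module rename: test utilities move from fairscale.utils.testing to fair_dev.testing.testing, and the remaining fairscale.utils helpers (torch_version, params, version) move to fairscale.internal. A minimal sketch of a compatibility shim follows; the module paths are taken from the diff below, but the try/except pattern itself is only an illustration, not part of this commit.

# Hedged sketch: import the relocated helpers on either layout.
try:
    # New layout introduced by #1004.
    from fair_dev.testing.testing import dist_init, teardown
    from fairscale.internal import torch_version
except ImportError:
    # Old layout, prior to this commit.
    from fairscale.utils.testing import dist_init, teardown
    from fairscale.utils import torch_version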
@@ -14,8 +14,8 @@ import torch.multiprocessing as mp
 import torch.nn as nn
 from torch.optim import Adam
+from fair_dev.testing.testing import in_temporary_directory, skip_if_single_gpu, temp_files_ctx
 from fairscale.nn import FullyShardedDataParallel
-from fairscale.utils.testing import in_temporary_directory, skip_if_single_gpu, temp_files_ctx
 from tests.nn.data_parallel.test_fsdp import DistributedTest, MixtureOfExperts, rename_test, spawn_and_init
 USE_TEMPFILE = True  # False for debugging
...
@@ -17,10 +17,10 @@ import torch.multiprocessing as mp
 from torch.nn import Linear, Module
 from torch.optim import SGD
+from fair_dev.testing.testing import dist_init, skip_if_single_gpu, teardown
+from fairscale.internal import torch_version
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
 from fairscale.nn.data_parallel import TrainingState
-from fairscale.utils import torch_version
-from fairscale.utils.testing import dist_init, skip_if_single_gpu, teardown
 def _test_func(rank, world_size, fsdp_config, tempfile_name, unused):
...
@@ -20,12 +20,12 @@ import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel
 import torch.optim as optim
+from fair_dev.testing.testing import dist_init, skip_if_single_gpu, teardown, temp_files_ctx
+from fairscale.internal import torch_version
 from fairscale.nn import checkpoint_wrapper
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
 from fairscale.nn.data_parallel import auto_wrap_bn
 from fairscale.nn.wrap import enable_wrap, wrap
-from fairscale.utils import torch_version
-from fairscale.utils.testing import dist_init, skip_if_single_gpu, teardown, temp_files_ctx
 class Model(nn.Module):
...
@@ -17,10 +17,10 @@ import torch.multiprocessing as mp
 from torch.nn import Linear, Module, Sequential
 from torch.optim import SGD
+from fair_dev.testing.testing import dist_init, skip_if_no_cuda, teardown
+from fairscale.internal import torch_version
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
 from fairscale.nn.data_parallel import TrainingState
-from fairscale.utils import torch_version
-from fairscale.utils.testing import dist_init, skip_if_no_cuda, teardown
 def _test_func(rank, world_size, fsdp_config, tempfile_name, unused):
...
@@ -23,9 +23,9 @@ except ImportError as ie:
     pytestmark = pytest.mark.skipif(True, reason=ie.msg)
     pass
+from fair_dev.testing.testing import dist_init, spawn_for_all_world_sizes
 from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
 from fairscale.nn.data_parallel import FullyShardedDataParallel, OffloadConfig, TrainingState
-from fairscale.utils.testing import dist_init, spawn_for_all_world_sizes
 # How to use remote-pdb: https://gist.github.com/sshleifer/9d43351957179c13606e015b072927d4
 # All helper functions called by spawn must be either @classmethod, @staticmethod
...
@@ -11,10 +11,10 @@ import torch
 from torch import nn
 from torch.optim import SGD, Adadelta, Adam  # type: ignore
+from fair_dev.testing.testing import dist_init, objects_are_equal, spawn_for_all_world_sizes
+from fairscale.internal.params import recursive_copy_to_device
 from fairscale.nn import FullyShardedDataParallel
 from fairscale.nn.data_parallel.fsdp_optim_utils import is_singleton_tensor
-from fairscale.utils.params import recursive_copy_to_device
-from fairscale.utils.testing import dist_init, objects_are_equal, spawn_for_all_world_sizes
 from .test_fsdp import (
     DistributedTest,
...
@@ -19,10 +19,10 @@ from torch.cuda import Event
 import torch.multiprocessing as mp
 import torch.nn as nn
+from fair_dev.testing.testing import dist_init, get_cycles_per_ms, skip_if_single_gpu, teardown, temp_files_ctx
+from fairscale.internal import torch_version
 from fairscale.nn import enable_wrap, wrap
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
-from fairscale.utils import torch_version
-from fairscale.utils.testing import dist_init, get_cycles_per_ms, skip_if_single_gpu, teardown, temp_files_ctx
 class Layer(nn.Module):
...
@@ -13,8 +13,8 @@ import pytest
 import torch
 from torch.nn import Linear, Module
+from fair_dev.testing.testing import dist_init, skip_if_no_cuda, teardown, temp_files_ctx
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
-from fairscale.utils.testing import dist_init, skip_if_no_cuda, teardown, temp_files_ctx
 # A fixture to get tempfiles and ensure they are cleaned up.
...
@@ -33,10 +33,7 @@ from torch.nn import (
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import SGD
-from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
-from fairscale.nn.data_parallel import TrainingState, auto_wrap_bn
-from fairscale.utils import torch_version
-from fairscale.utils.testing import (
+from fair_dev.testing.testing import (
     dist_init,
     objects_are_equal,
     rmf,
@@ -45,6 +42,9 @@ from fairscale.utils.testing import (
     teardown,
     torch_cuda_version,
 )
+from fairscale.internal import torch_version
+from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
+from fairscale.nn.data_parallel import TrainingState, auto_wrap_bn
 if torch_version() >= (1, 8, 0):
     from fairscale.optim.grad_scaler import ShardedGradScaler
...
@@ -17,8 +17,8 @@ import torch.multiprocessing as mp
 from torch.nn import Linear, Module
 from torch.optim import SGD
+from fair_dev.testing.testing import dist_init, objects_are_equal, skip_if_single_gpu, teardown, temp_files_ctx
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
-from fairscale.utils.testing import dist_init, objects_are_equal, skip_if_single_gpu, teardown, temp_files_ctx
 class Model(Module):
...
@@ -17,9 +17,7 @@ from torch import nn
 import torch.multiprocessing as mp
 from torch.optim import SGD
-from fairscale.experimental.nn import MEVO
-from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
-from fairscale.utils.testing import (
+from fair_dev.testing.testing import (
     dist_init,
     in_circle_ci,
     objects_are_equal,
@@ -27,6 +25,8 @@ from fairscale.utils.testing import (
     teardown,
     temp_files_ctx,
 )
+from fairscale.experimental.nn import MEVO
+from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
 VOCAB = 4
 D_MODEL = 2
...
@@ -11,9 +11,9 @@ import pytest
 import torch
 from torch import nn
+from fair_dev.testing.testing import dist_init, objects_are_equal, skip_if_cuda, teardown, temp_files_ctx
+from fairscale.internal import torch_version
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
-from fairscale.utils import torch_version
-from fairscale.utils.testing import dist_init, objects_are_equal, skip_if_cuda, teardown, temp_files_ctx
 from .test_fsdp import (
     CONFIG_OPTIONS,
...
@@ -11,7 +11,7 @@ from parameterized import parameterized
 import pytest
 import torch
-from fairscale.utils.version import torch_version
+from fairscale.internal.version import torch_version
 from .test_fsdp import CONFIG_OPTIONS, DistributedTest, rename_test, spawn_and_init
...
@@ -18,10 +18,10 @@ import torch.multiprocessing as mp
 from torch.nn import Linear, Sequential
 from torch.optim import SGD
+from fair_dev.testing.testing import dist_init, skip_if_single_gpu, teardown
+from fairscale.internal import torch_version
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
 from fairscale.nn.data_parallel.fully_sharded_data_parallel import TrainingState
-from fairscale.utils import torch_version
-from fairscale.utils.testing import dist_init, skip_if_single_gpu, teardown
 def _test_func(rank, world_size, model, fsdp_config, tempfile_name, unused, test_case):
...
@@ -13,9 +13,9 @@ from torch import nn
 import torch.distributed
 import torch.multiprocessing as mp
+from fair_dev.testing.testing import dist_init, skip_if_single_gpu, teardown, temp_files_ctx
 from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
-from fairscale.utils.testing import dist_init, skip_if_single_gpu, teardown, temp_files_ctx
 @skip_if_single_gpu
...
@@ -16,9 +16,7 @@ import torch.distributed as dist
 import torch.multiprocessing as mp
 from torch.nn import Linear, Sequential
-from fairscale.nn.data_parallel import ShardedDataParallel
-from fairscale.optim import OSS
-from fairscale.utils.testing import (
+from fair_dev.testing.testing import (
     GPT2,
     SGDWithPausingCompute,
     available_devices,
@@ -28,6 +26,8 @@ from fairscale.utils.testing import (
     skip_if_single_gpu,
     temp_files_ctx,
 )
+from fairscale.nn.data_parallel import ShardedDataParallel
+from fairscale.optim import OSS
 def _get_mlp(tripwire: bool = False):
...
@@ -19,10 +19,10 @@ import torch.multiprocessing as mp
 from torch.nn import Linear, Sequential
 from torch.nn.parallel import DistributedDataParallel as DDP
+from fair_dev.testing.testing import check_same_model_params, skip_if_no_cuda, skip_if_single_gpu, temp_files_ctx
+from fairscale.internal import torch_version
 from fairscale.nn.data_parallel import ShardedDataParallel
 from fairscale.optim import OSS
-from fairscale.utils import torch_version
-from fairscale.utils.testing import check_same_model_params, skip_if_no_cuda, skip_if_single_gpu, temp_files_ctx
 if torch_version() >= (1, 8, 0):
     from fairscale.optim.grad_scaler import ShardedGradScaler
...
@@ -10,8 +10,8 @@ import unittest
 import torch
+from fair_dev.testing.testing import objects_are_equal
 from fairscale.nn import FlattenParamsWrapper
-from fairscale.utils.testing import objects_are_equal
 class TestFlattenParams(unittest.TestCase):
...
@@ -23,10 +23,10 @@
 import torch
 import torch.nn.functional as F
+from fair_dev.testing.testing import IdentityLayer, dist_init, set_random_seed, spawn_for_all_world_sizes
 from fairscale.nn.model_parallel import initialize as mpu
 from fairscale.nn.model_parallel.cross_entropy import vocab_parallel_cross_entropy
 from fairscale.nn.model_parallel.mappings import scatter_to_model_parallel_region
-from fairscale.utils.testing import IdentityLayer, dist_init, set_random_seed, spawn_for_all_world_sizes
 def torch_cross_entropy(batch_size, seq_length, vocab_size, logits_scale, seed):
...
@@ -22,8 +22,8 @@
 import torch
+from fair_dev.testing.testing import dist_init, spawn_for_all_world_sizes
 from fairscale.nn.model_parallel import initialize as mpu
-from fairscale.utils.testing import dist_init, spawn_for_all_world_sizes
 def run_test_initialize_model_parallel(rank, model_parallel_size, filename, filename_rpc):
...