Unverified Commit 8f8f8ef9 authored by Min Xu, committed by GitHub

[chore] move fair_dev into fairscale (#1078)


Co-authored-by: Min Xu <min.xu.public@gmail.com>
parent bfd57ff3
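
Every hunk below makes the same kind of change: the testing helpers that were previously imported from the standalone fair_dev package are now imported from the copy vendored inside the fairscale package. A minimal before/after sketch of the import-path change (helper names taken from the hunks below; each test file imports its own subset):

    # before this commit: helpers came from the separate top-level fair_dev package
    # from fair_dev.testing.testing import dist_init, teardown

    # after this commit: the same helpers are imported from within fairscale
    from fairscale.fair_dev.testing.testing import dist_init, teardown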
@@ -21,7 +21,7 @@ import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel
 import torch.optim as optim
-from fair_dev.testing.testing import dist_init, objects_are_equal, rmf, skip_if_single_gpu, teardown
+from fairscale.fair_dev.testing.testing import dist_init, objects_are_equal, rmf, skip_if_single_gpu, teardown
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
...
@@ -12,7 +12,7 @@ from unittest.mock import patch
 from parameterized import parameterized
 import torch
-from fair_dev.testing.testing import DummyProcessGroup, make_cudnn_deterministic, objects_are_equal
+from fairscale.fair_dev.testing.testing import DummyProcessGroup, make_cudnn_deterministic, objects_are_equal
 from fairscale.nn.data_parallel import FullyShardedDataParallel
 from .test_fsdp import DistributedTest, NestedWrappedModule, rename_test, spawn_and_init
...
...@@ -6,7 +6,7 @@ import unittest ...@@ -6,7 +6,7 @@ import unittest
import torch import torch
from torch import nn from torch import nn
from fair_dev.testing.testing import dist_init from fairscale.fair_dev.testing.testing import dist_init
from fairscale.nn import FullyShardedDataParallel as FSDP from fairscale.nn import FullyShardedDataParallel as FSDP
from fairscale.nn import auto_wrap, enable_wrap from fairscale.nn import auto_wrap, enable_wrap
......
@@ -16,7 +16,7 @@ import torch
 from torch.nn import Linear, Module
 from torch.optim import SGD
-from fair_dev.testing.testing import dist_init, rmf, skip_if_no_cuda, teardown
+from fairscale.fair_dev.testing.testing import dist_init, rmf, skip_if_no_cuda, teardown
 from fairscale.internal import torch_version
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
 from fairscale.nn.data_parallel import TrainingState
...
@@ -18,7 +18,7 @@ import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel
 import torch.optim as optim
-from fair_dev.testing.testing import dist_init, dump_all_tensors, skip_if_single_gpu, teardown, temp_files_ctx
+from fairscale.fair_dev.testing.testing import dist_init, dump_all_tensors, skip_if_single_gpu, teardown, temp_files_ctx
 from fairscale.internal import torch_version
 from fairscale.internal.parallel import get_process_group_cached
 from fairscale.nn import checkpoint_wrapper
...
@@ -14,7 +14,7 @@ import torch.multiprocessing as mp
 import torch.nn as nn
 from torch.optim import Adam
-from fair_dev.testing.testing import in_temporary_directory, skip_if_single_gpu, temp_files_ctx
+from fairscale.fair_dev.testing.testing import in_temporary_directory, skip_if_single_gpu, temp_files_ctx
 from fairscale.nn import FullyShardedDataParallel
 from tests.nn.data_parallel.test_fsdp import DistributedTest, MixtureOfExperts, rename_test, spawn_and_init
...
@@ -17,7 +17,7 @@ import torch.multiprocessing as mp
 from torch.nn import Linear, Module
 from torch.optim import SGD
-from fair_dev.testing.testing import dist_init, skip_if_single_gpu, teardown
+from fairscale.fair_dev.testing.testing import dist_init, skip_if_single_gpu, teardown
 from fairscale.internal import torch_version
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
 from fairscale.nn.data_parallel import TrainingState
...
@@ -20,7 +20,7 @@ import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel
 import torch.optim as optim
-from fair_dev.testing.testing import dist_init, skip_if_single_gpu, teardown, temp_files_ctx
+from fairscale.fair_dev.testing.testing import dist_init, skip_if_single_gpu, teardown, temp_files_ctx
 from fairscale.internal import torch_version
 from fairscale.nn import checkpoint_wrapper
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
...
@@ -17,7 +17,7 @@ import torch.multiprocessing as mp
 from torch.nn import Linear, Module, Sequential
 from torch.optim import SGD
-from fair_dev.testing.testing import dist_init, skip_if_no_cuda, teardown
+from fairscale.fair_dev.testing.testing import dist_init, skip_if_no_cuda, teardown
 from fairscale.internal import torch_version
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
 from fairscale.nn.data_parallel import TrainingState
...
@@ -25,7 +25,7 @@ except ImportError as ie:
     pytestmark = pytest.mark.skipif(True, reason=ie.msg)
     pass
-from fair_dev.testing.testing import dist_init, spawn_for_all_world_sizes
+from fairscale.fair_dev.testing.testing import dist_init, spawn_for_all_world_sizes
 from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
 from fairscale.nn.data_parallel import FullyShardedDataParallel, OffloadConfig, TrainingState
...
@@ -12,7 +12,7 @@ import torch
 from torch import nn
 from torch.optim import SGD, Adadelta, Adam  # type: ignore
-from fair_dev.testing.testing import dist_init, objects_are_equal, spawn_for_all_world_sizes
+from fairscale.fair_dev.testing.testing import dist_init, objects_are_equal, spawn_for_all_world_sizes
 from fairscale.internal.params import recursive_copy_to_device
 from fairscale.nn.data_parallel import FullyShardedDataParallel, get_fsdp_instances
 from fairscale.nn.data_parallel.fsdp_optim_utils import is_singleton_tensor
...
@@ -19,7 +19,13 @@ from torch.cuda import Event
 import torch.multiprocessing as mp
 import torch.nn as nn
-from fair_dev.testing.testing import dist_init, get_cycles_per_ms, skip_if_single_gpu, teardown, temp_files_ctx
+from fairscale.fair_dev.testing.testing import (
+    dist_init,
+    get_cycles_per_ms,
+    skip_if_single_gpu,
+    teardown,
+    temp_files_ctx,
+)
 from fairscale.internal import torch_version
 from fairscale.nn import enable_wrap, wrap
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
...
@@ -13,7 +13,7 @@ import pytest
 import torch
 from torch.nn import Linear, Module
-from fair_dev.testing.testing import dist_init, skip_if_no_cuda, teardown, temp_files_ctx
+from fairscale.fair_dev.testing.testing import dist_init, skip_if_no_cuda, teardown, temp_files_ctx
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
...
@@ -33,7 +33,7 @@ from torch.nn import (
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import SGD
-from fair_dev.testing.testing import (
+from fairscale.fair_dev.testing.testing import (
     dist_init,
     objects_are_equal,
     rmf,
...
@@ -17,7 +17,13 @@ import torch.multiprocessing as mp
 from torch.nn import Linear, Module
 from torch.optim import SGD
-from fair_dev.testing.testing import dist_init, objects_are_equal, skip_if_single_gpu, teardown, temp_files_ctx
+from fairscale.fair_dev.testing.testing import (
+    dist_init,
+    objects_are_equal,
+    skip_if_single_gpu,
+    teardown,
+    temp_files_ctx,
+)
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
...
@@ -17,7 +17,8 @@ from torch import nn
 import torch.multiprocessing as mp
 from torch.optim import SGD
-from fair_dev.testing.testing import (
+from fairscale.experimental.nn import MEVO
+from fairscale.fair_dev.testing.testing import (
     dist_init,
     in_circle_ci,
     objects_are_equal,
@@ -25,7 +26,6 @@ from fair_dev.testing.testing import (
     teardown,
     temp_files_ctx,
 )
-from fairscale.experimental.nn import MEVO
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

 VOCAB = 4
...
@@ -11,7 +11,7 @@ import pytest
 import torch
 from torch import nn
-from fair_dev.testing.testing import dist_init, objects_are_equal, skip_if_cuda, teardown, temp_files_ctx
+from fairscale.fair_dev.testing.testing import dist_init, objects_are_equal, skip_if_cuda, teardown, temp_files_ctx
 from fairscale.internal import torch_version
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
...
@@ -18,7 +18,7 @@ import torch.multiprocessing as mp
 from torch.nn import Linear, Sequential
 from torch.optim import SGD
-from fair_dev.testing.testing import dist_init, skip_if_single_gpu, teardown
+from fairscale.fair_dev.testing.testing import dist_init, skip_if_single_gpu, teardown
 from fairscale.internal import torch_version
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
 from fairscale.nn.data_parallel.fully_sharded_data_parallel import TrainingState
...
@@ -13,7 +13,7 @@ from torch import nn
 import torch.distributed
 import torch.multiprocessing as mp
-from fair_dev.testing.testing import dist_init, skip_if_single_gpu, teardown, temp_files_ctx
+from fairscale.fair_dev.testing.testing import dist_init, skip_if_single_gpu, teardown, temp_files_ctx
 from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
...
@@ -16,7 +16,7 @@ import torch.distributed as dist
 import torch.multiprocessing as mp
 from torch.nn import Linear, Sequential
-from fair_dev.testing.testing import (
+from fairscale.fair_dev.testing.testing import (
     GPT2,
     SGDWithPausingCompute,
     available_devices,
...