Unverified Commit 8f8f8ef9 authored by Min Xu's avatar Min Xu Committed by GitHub
Browse files

[chore] move fair_dev into fairscale (#1078)


Co-authored-by: default avatarMin Xu <min.xu.public@gmail.com>
parent bfd57ff3
......@@ -21,8 +21,8 @@ from torch.utils.data import DataLoader
import torchtext
from torchtext.data.utils import get_tokenizer
from fair_dev.testing.testing import dist_init, get_worker_map
from fairscale.experimental.nn.ampnet_pipe import pipe
from fairscale.fair_dev.testing.testing import dist_init, get_worker_map
from fairscale.nn.model_parallel import initialize_model_parallel
from fairscale.nn.model_parallel.initialize import get_pipeline_parallel_group
from fairscale.nn.pipe import LazyModule
......
......@@ -16,7 +16,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
import utils
from benchmarks.golden_configs.lm_wikitext2 import Pipe as lm_wikitext2
from fair_dev.testing.testing import dist_init
from fairscale.fair_dev.testing.testing import dist_init
from fairscale.nn import Pipe
from fairscale.nn.model_parallel import initialize_model_parallel
......
NOTE:
The experimental and fair_dev submodules are not part of the fairscale public
API. There can be breaking changes in them at anytime.
......@@ -5,6 +5,10 @@
################################################################################
# Import most common subpackages
#
# NOTE: we don't maintain any public APIs in both experimental and fair_dev
# sub-modules. Code in them are experimental or for developer only. They
# can be changed, removed, anytime.
################################################################################
from typing import List
......
......@@ -22,8 +22,8 @@ from torch import nn
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader, Dataset
from fair_dev.testing.testing import get_worker_map, torch_spawn
from fairscale.experimental.nn.ampnet_pipe.pipe import AMPnetPipe
from fairscale.fair_dev.testing.testing import get_worker_map, torch_spawn
class MySGD(Optimizer):
......
......@@ -15,8 +15,8 @@ from torch import nn
import torch.distributed
import torch.nn.functional as F
from fair_dev.testing.testing import skip_if_single_gpu, spawn_for_all_world_sizes
import fairscale.experimental.nn.data_parallel.gossip as gossip
from fairscale.fair_dev.testing.testing import skip_if_single_gpu, spawn_for_all_world_sizes
# Enfore CUBLAS reproducibility, see https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
......
......@@ -12,9 +12,9 @@ import os
import pytest
import torch
from fair_dev.testing.testing import skip_if_no_cuda
from fairscale.experimental.nn import MEVO
from fairscale.experimental.nn.mevo import BaselineSoftmaxNllLoss, get_data
from fairscale.fair_dev.testing.testing import skip_if_no_cuda
@pytest.fixture(scope="session", params=[torch.float16, torch.float32])
......
......@@ -20,8 +20,8 @@ import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
import torch.nn as nn
from fair_dev.testing.testing import skip_due_to_flakyness, skip_if_single_gpu
from fairscale.experimental.nn.distributed_pipeline import DistributedLoss, DistributedPipeline, PipelineModulesGraph
from fairscale.fair_dev.testing.testing import skip_due_to_flakyness, skip_if_single_gpu
from fairscale.internal import torch_version
pytestmark = pytest.mark.skipif(
......
......@@ -14,8 +14,8 @@ import numpy as np
import pytest
import torch
from fair_dev.testing.testing import skip_if_no_cuda
from fairscale.experimental.nn.offload import OffloadModel
from fairscale.fair_dev.testing.testing import skip_if_no_cuda
from fairscale.internal import torch_version
if torch_version() >= (1, 8, 0):
......
......@@ -10,12 +10,12 @@ import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel
from fair_dev.testing.testing import GPT2, dist_init, skip_if_no_cuda, skip_if_single_gpu, temp_files_ctx
from fairscale.experimental.tooling.layer_memory_tracker import (
LayerwiseMemoryTracker,
ProcessGroupTracker,
find_best_reset_points,
)
from fairscale.fair_dev.testing.testing import GPT2, dist_init, skip_if_no_cuda, skip_if_single_gpu, temp_files_ctx
from fairscale.nn import FullyShardedDataParallel
......
......@@ -11,8 +11,8 @@ import pytest
import torch
from torch import nn
from fair_dev.testing.testing import objects_are_equal
from fairscale.experimental.wgit.sha1_store import SHA1_Store
from fairscale.fair_dev.testing.testing import objects_are_equal
# Get the absolute path of the parent at the beginning before any os.chdir(),
# so that we can proper clean it up at any CWD.
......
......@@ -6,8 +6,8 @@
import pytest
import torch
from fair_dev.testing.testing import objects_are_equal
from fairscale.experimental.wgit.signal_sparsity import SignalSparsity, random_sparse_mask
from fairscale.fair_dev.testing.testing import objects_are_equal
# Our own tolerance
ATOL = 1e-6
......
......@@ -8,8 +8,8 @@ import time
import pytest
import torch
from fair_dev.testing.testing import objects_are_equal, skip_if_no_cuda
from fairscale.experimental.wgit.signal_sparsity_profiling import EnergyConcentrationProfile as ECP
from fairscale.fair_dev.testing.testing import objects_are_equal, skip_if_no_cuda
# Our own tolerance
ATOL = 1e-6
......
......@@ -10,7 +10,7 @@ import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint as torch_checkpoint_wrapper
from fair_dev.testing.testing import skip_if_no_cuda
from fairscale.fair_dev.testing.testing import skip_if_no_cuda
from fairscale.internal import torch_version
from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper, disable_checkpointing
from fairscale.nn.misc import FlattenParamsWrapper
......
......@@ -14,7 +14,7 @@ import torch
from torch.nn import BatchNorm2d, LayerNorm, Linear, Sequential
from torch.optim import SGD
from fair_dev.testing.testing import objects_are_equal
from fairscale.fair_dev.testing.testing import objects_are_equal
from fairscale.internal import torch_version
from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
......
......@@ -18,7 +18,7 @@ import torch
from torch import nn
import torch.distributed
from fair_dev.testing.testing import (
from fairscale.fair_dev.testing.testing import (
DeviceAndTypeCheckModule,
DummyProcessGroup,
dist_init,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment