Unverified commit 66b2b514, authored by Min Xu, committed by GitHub

[test] fine tune test for checkpoint & DDP (#148)

- fixed typing
- made it run less often to reduce CI time

Testing: ran it in a loop to make sure it runs at the right frequency.
parent a0042113
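
The gate added below keys off `os.getpid() % 100`, so the test fires on roughly 1% of invocations. A minimal sketch (not part of this commit; `trials` is an arbitrary choice) of the kind of loop-check described in the message above:

```python
# Sketch (not from the commit): estimate how often the getpid() gate fires.
# Each subprocess gets a fresh pid, so repeated runs sample the pid space.
import subprocess
import sys

hits = 0
trials = 300
for _ in range(trials):
    out = subprocess.run(
        [sys.executable, "-c", "import os; print(os.getpid() % 100 == 42)"],
        capture_output=True,
        text=True,
    )
    hits += out.stdout.strip() == "True"

print(f"gate fired in {hits}/{trials} runs (~1% expected)")
```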
@@ -2,5 +2,7 @@
 from typing import Tuple
 from .. import Tensor
+from torch.nn.modules.module import Module

 def detach_variable(inputs: Tuple[Tensor,...]) -> Tuple[Tensor,...]: ...
+def checkpoint(function: Module, *args, **kwargs): ...
@@ -6,10 +6,11 @@
 # Test checkpoint and PyTorch DDP interactions.
+import os
 import random
 import tempfile

-import numpy  # type: ignore
+import numpy
 import pytest
 import torch
 import torch.distributed as dist
@@ -17,13 +18,21 @@ import torch.multiprocessing as mp
 import torch.nn as nn
 from torch.nn import Linear, Sequential
 from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.utils.checkpoint import checkpoint as torch_checkpoint  # type: ignore
+from torch.utils.checkpoint import checkpoint as torch_checkpoint

 from fairscale.nn.pipe.checkpoint import Checkpointing, Function, TensorOrTensors
 from fairscale.nn.pipe.microbatch import Batch

+# This test is mainly for checking pytorch & checkpointing behavior. pipe's checkpointing
+# code is tested already in another file. Therefore, we can run this test less frequently.
+# We use getpid() in case random is seeded to be deterministic.
+run_test = False
+if os.getpid() % 100 == 42:
+    run_test = True
+
 skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
 skip_if_single_gpu = pytest.mark.skipif(torch.cuda.device_count() < 2, reason="multiple GPUs required")
+skip_if_not_needed = pytest.mark.skipif(not run_test, reason="Skipping due to test frequency")

 def set_random_seed(seed: int) -> None:
@@ -179,6 +188,7 @@ def run(rank, world_size, temp_file_name, checkpoint, test_func):
     dist.destroy_process_group()

+@skip_if_not_needed
 @skip_if_no_cuda
 @skip_if_single_gpu
 @pytest.mark.parametrize("checkpoint", [pipe_checkpoint, torch_checkpoint])
......
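
For context, the pieces added in this diff compose as follows. This is an illustrative sketch (the test name and body are placeholders, not the real test file):

```python
# Sketch of the gating pattern from this commit, condensed for illustration.
import os

import pytest
import torch

# Run roughly 1 in 100 invocations; getpid() is unaffected by seeded RNGs.
run_test = os.getpid() % 100 == 42

skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
skip_if_single_gpu = pytest.mark.skipif(torch.cuda.device_count() < 2, reason="multiple GPUs required")
skip_if_not_needed = pytest.mark.skipif(not run_test, reason="Skipping due to test frequency")


@skip_if_not_needed  # frequency gate is checked first
@skip_if_no_cuda
@skip_if_single_gpu
def test_checkpoint_ddp():  # hypothetical name; real body elided, see the diff above
    ...
```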