Unverified Commit 66b2b514 authored by Min Xu's avatar Min Xu Committed by GitHub
Browse files

[test] fine tune test for checkpoint & DDP (#148)

- fixed typing
- make it run less often to reduce CI time

Testing: ran it in a loop to make sure it runs at the right frequency.
parent a0042113
......@@ -2,5 +2,7 @@
from typing import Tuple
from .. import Tensor
from torch.nn.modules.module import Module
# Type stub: takes a tuple of Tensors and returns a tuple of Tensors.
def detach_variable(inputs: Tuple[Tensor,...]) -> Tuple[Tensor,...]: ...
# Type stub: `function` is annotated as a Module; additional positional and
# keyword arguments are accepted untyped, and no return type is declared.
def checkpoint(function: Module, *args, **kwargs): ...
......@@ -6,10 +6,11 @@
# Test checkpoint and PyTorch DDP interactions.
import os
import random
import tempfile
import numpy # type: ignore
import numpy
import pytest
import torch
import torch.distributed as dist
......@@ -17,13 +18,21 @@ import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn import Linear, Sequential
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.checkpoint import checkpoint as torch_checkpoint # type: ignore
from torch.utils.checkpoint import checkpoint as torch_checkpoint
from fairscale.nn.pipe.checkpoint import Checkpointing, Function, TensorOrTensors
from fairscale.nn.pipe.microbatch import Batch
# This test mainly checks how PyTorch and checkpointing behave together. Pipe's
# checkpointing code is already tested in another file, so this test can run
# less frequently: it is enabled only when the process id modulo 100 equals 42.
# We use getpid() rather than `random` in case random is seeded to be deterministic.
run_test = os.getpid() % 100 == 42
# Skip markers applied to the tests in this file.
# Skips when torch reports that CUDA is not available.
skip_if_no_cuda = pytest.mark.skipif(
    not torch.cuda.is_available(),
    reason="cuda required",
)
# Skips when torch reports fewer than two CUDA devices.
skip_if_single_gpu = pytest.mark.skipif(
    torch.cuda.device_count() < 2,
    reason="multiple GPUs required",
)
# Skips unless the module-level `run_test` flag is set.
skip_if_not_needed = pytest.mark.skipif(
    not run_test,
    reason="Skipping due to test frequency",
)
def set_random_seed(seed: int) -> None:
......@@ -179,6 +188,7 @@ def run(rank, world_size, temp_file_name, checkpoint, test_func):
dist.destroy_process_group()
@skip_if_not_needed
@skip_if_no_cuda
@skip_if_single_gpu
@pytest.mark.parametrize("checkpoint", [pipe_checkpoint, torch_checkpoint])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment