"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "6fd04ca922e5da7ef8c52d86118fc58b798a7e4a"
Unverified Commit 6658be22 authored by Min Xu's avatar Min Xu Committed by GitHub
Browse files

[test] Add unittest for checkpoint & DDP (#126)

* Add unittest for checkpoint & DDP

- this change adds test cases to reproduce the error with checkpoint & DDP
- mandeep mentioned that there is also deadlock in this case, but this
  change doesn't cover that.
- we cover cases where weight sharing is OK
- however, same module multiple checkpoint or find_unused_parameters are
  both not OK

* added norm checks
parent 341d8b2b
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Test checkpoint and PyTorch DDP interactions.
import random
import tempfile
import numpy # type: ignore
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn import Linear, Sequential
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.checkpoint import checkpoint as torch_checkpoint # type: ignore
from fairscale.nn.pipe.checkpoint import Checkpointing, Function, TensorOrTensors
from fairscale.nn.pipe.microbatch import Batch
skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
skip_if_single_gpu = pytest.mark.skipif(torch.cuda.device_count() < 2, reason="multiple GPUs required")
def set_random_seed(seed: int) -> None:
"""Set random seed for reproducability."""
random.seed(seed)
numpy.random.seed(seed)
torch.manual_seed(seed)
def pipe_checkpoint(function: Function, input: TensorOrTensors) -> TensorOrTensors:
"""Makes a checkpoint with a simple interface like
:func:`torch.utils.checkpoint.checkpoint`. It's only used to test or debug
:class:`Checkpoint` and :class:`Recompute` without boilerplate.
"""
batch = Batch(input, 0)
chk = Checkpointing(function, batch)
batch = chk.checkpoint()
chk.recompute(batch)
return batch.tensor_or_tensors
def basic(rank, checkpoint):
# get the model, wrap with DDP and fwd, bwd.
set_random_seed(31415)
model = Sequential(Linear(2000, 2000), Linear(2000, 2000))
model.to("cuda")
model = DDP(model, device_ids=[rank])
input_tensor = torch.rand((64, 2000)).cuda()
input_tensor.requires_grad = True
output_tensor = checkpoint(model, input_tensor)
for p in model.parameters():
assert p.grad is None
output_tensor.sum().backward()
norm = 0.0
for p in model.parameters():
assert p.grad is not None
norm += p.grad.norm().item()
assert numpy.allclose(norm, 78053.52978515625), norm
def weight_sharing(rank, checkpoint):
# get the model, wrap with DDP and fwd, bwd.
set_random_seed(31415)
l1 = Linear(2000, 2000)
l2 = Linear(2000, 2000)
l1.weight = l2.weight
model = Sequential(l1, l2)
model.to("cuda")
model = DDP(model, device_ids=[rank])
input_tensor = torch.rand((64, 2000)).cuda()
input_tensor.requires_grad = True
output_tensor = checkpoint(model, input_tensor)
output_tensor.sum().backward()
norm = 0.0
for p in model.parameters():
assert p.grad is not None
norm += p.grad.norm().item()
assert numpy.allclose(norm, 57004.34228515625), norm
def checkpoint_half(rank, checkpoint):
# get the model, wrap with DDP and fwd, bwd.
class M(nn.Module):
def __init__(self):
super(M, self).__init__()
self.l1 = Linear(2000, 2000)
self.l2 = Linear(2000, 2000)
def forward(self, inp):
x = self.l1(inp)
x = checkpoint(self.l2, x)
return x
set_random_seed(31415)
model = M()
model.to("cuda")
model = DDP(model, device_ids=[rank])
input_tensor = torch.rand((64, 2000)).cuda()
output_tensor = model(input_tensor)
output_tensor.sum().backward()
norm = 0.0
for p in model.parameters():
assert p.grad is not None
norm += p.grad.norm().item()
assert numpy.allclose(norm, 78053.52978515625), norm
def unused_param(rank, checkpoint):
# get the model, wrap with DDP and fwd, bwd.
class M(nn.Module):
def __init__(self):
super(M, self).__init__()
# The size 2000 is important. Without bigger size, it doesn't trigger the RuntimeError!
self.l1 = Linear(2000, 2000)
self.l2 = Linear(2000, 2000)
def forward(self, inp):
x = self.l1(inp)
x = checkpoint(self.l2, x)
return x
model = M()
model.to("cuda")
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
input_tensor = torch.rand((64, 2000)).cuda()
output_tensor = model(input_tensor)
try:
output_tensor.sum().backward()
except RuntimeError:
return
assert 0
def checkpoint_twice(rank, checkpoint):
# get the model, wrap with DDP and fwd, bwd.
class M(nn.Module):
def __init__(self):
super(M, self).__init__()
# The size 2000 is important. Without bigger size, it doesn't trigger the RuntimeError!
self.l1 = Linear(2000, 2000)
self.l2 = Linear(2000, 2000)
def forward(self, inp):
x = self.l1(inp)
x = checkpoint(self.l2, x)
x = checkpoint(self.l2, x)
return x
model = M()
model.to("cuda")
model = DDP(model, device_ids=[rank])
input_tensor = torch.rand((64, 2000)).cuda()
output_tensor = model(input_tensor)
try:
output_tensor.sum().backward()
except RuntimeError:
return
assert 0
def run(rank, world_size, temp_file_name, checkpoint, test_func):
# setup
url = "file://" + temp_file_name
dist.init_process_group(init_method=url, backend=dist.Backend.NCCL, rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
# actual test
test_func(rank, checkpoint)
# cleanup
dist.destroy_process_group()
@skip_if_no_cuda
@skip_if_single_gpu
@pytest.mark.parametrize("checkpoint", [pipe_checkpoint, torch_checkpoint])
@pytest.mark.parametrize("test_func", [basic, weight_sharing, checkpoint_half, unused_param, checkpoint_twice])
def test_basic_ddp(checkpoint, test_func):
temp_file_name = tempfile.mkstemp()[1]
world_size = 2
mp.spawn(run, args=(world_size, temp_file_name, checkpoint, test_func), nprocs=world_size, join=True)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment