Unverified Commit 3b7373e2 authored by Benjamin Lefaudeux, committed by GitHub

[test][refactor][SDP] Using the nice context-based tempfiles (#640)

parent 8c8a625a
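
For context: this refactor swaps the ad-hoc tempfile.mkstemp()[1] calls in the tests for the temp_files_ctx context manager imported from fairscale.utils.testing, which yields the requested number of temporary file paths (used as the file:// rendezvous for torch.distributed) and presumably cleans them up when the block exits. The helper's implementation is not part of this diff; the sketch below is only an illustration of the assumed behaviour, not the actual fairscale code.

# Hypothetical sketch of a context-based temp-file helper in the spirit of
# fairscale.utils.testing.temp_files_ctx (assumed behaviour; the real helper may differ).
from contextlib import contextmanager
import os
import tempfile
from typing import Iterator, List


@contextmanager
def temp_files_ctx(num: int) -> Iterator[List[str]]:
    """Yield `num` temporary file paths and remove them when the block exits."""
    paths = [tempfile.mkstemp()[1] for _ in range(num)]
    try:
        yield paths
    finally:
        for path in paths:
            try:
                os.remove(path)
            except OSError:
                # A spawned worker may already have removed the file.
                pass


# Usage pattern matching the refactored tests: the temp file backs the file:// init_method
# and is cleaned up automatically once mp.spawn returns.
# with temp_files_ctx(num=1) as temp_files:
#     mp.spawn(run_one_step, args=(world_size, ..., temp_files[0], ...), nprocs=world_size, join=True)
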
@@ -8,7 +8,6 @@ Testing ShardedDDP
 """
 from contextlib import suppress
-import tempfile

 import numpy as np
 import pytest
@@ -27,6 +26,7 @@ from fairscale.utils.testing import (
     skip_if_less_than_four_gpu,
     skip_if_no_cuda,
     skip_if_single_gpu,
+    temp_files_ctx,
 )
@@ -134,13 +134,13 @@ def run_one_step(
 def run_test(backend, device, world_size, broadcast_buffers, grad_accumulation, reduce_buffer_size, optimizer_type):
-    temp_file_name = tempfile.mkstemp()[1]
-    mp.spawn(
-        run_one_step,
-        args=(world_size, backend, device, temp_file_name, broadcast_buffers, grad_accumulation, reduce_buffer_size),
-        nprocs=world_size,
-        join=True,
-    )
+    with temp_files_ctx(num=1) as temp_files:
+        mp.spawn(
+            run_one_step,
+            args=(world_size, backend, device, temp_files[0], broadcast_buffers, grad_accumulation, reduce_buffer_size),
+            nprocs=world_size,
+            join=True,
+        )


 @skip_if_no_cuda
@@ -160,24 +160,23 @@ def run_test(backend, device, world_size, broadcast_buffers, grad_accumulation,
 )
 def test_step(broadcast_buffers, grad_accumulation, reduce_buffer_size, optimizer_type, reduce_fp16, setup):
     world_size = 2
-    temp_file_name = tempfile.mkstemp()[1]
-    mp.spawn(
-        run_one_step,
-        args=(
-            world_size,
-            setup[0],
-            setup[1],
-            temp_file_name,
-            broadcast_buffers,
-            grad_accumulation,
-            reduce_buffer_size,
-            optimizer_type,
-            reduce_fp16,
-        ),
-        nprocs=world_size,
-        join=True,
-    )
+    with temp_files_ctx(num=1) as temp_files:
+        mp.spawn(
+            run_one_step,
+            args=(
+                world_size,
+                setup[0],
+                setup[1],
+                temp_files[0],
+                broadcast_buffers,
+                grad_accumulation,
+                reduce_buffer_size,
+                optimizer_type,
+                reduce_fp16,
+            ),
+            nprocs=world_size,
+            join=True,
+        )


 def run_test_two_inputs(rank, world_size, backend, device, temp_file_name, reduce_buffer_size):
@@ -200,7 +199,7 @@ def run_test_two_inputs(rank, world_size, backend, device, temp_file_name, reduc
         loss.backward()
         return loss

-    for i in range(5):
+    for _ in range(5):
         _ = optimizer.step(closure=closure)

     dist.destroy_process_group()
@@ -215,78 +214,82 @@ def test_inputs(reduce_buffer_size, backend, device):
     if backend == "nccl" and device == "cpu":
         pytest.skip("Incompatible combination, or cuda not available")
         return

-    mp.spawn(
-        run_test_two_inputs,
-        args=(world_size, backend, device, tempfile.mkstemp()[1], reduce_buffer_size),
-        nprocs=world_size,
-        join=True,
-    )
+    with temp_files_ctx(num=1) as temp_files:
+        mp.spawn(
+            run_test_two_inputs,
+            args=(world_size, backend, device, temp_files[0], reduce_buffer_size),
+            nprocs=world_size,
+            join=True,
+        )


 def test_ddp_attributes():
     # Check that ShardedDDP exposes the same attributes as Pytorch's DDP
     # - is multi_device_module
     # - device_type
-    dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)
+    with temp_files_ctx(num=1) as temp_files:
+        dist.init_process_group(init_method="file://" + temp_files[0], backend="gloo", rank=0, world_size=1)

-    model = Sequential(Linear(2, 3), Linear(3, 3))
-    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
-    ddp_model = ShardedDataParallel(model, optimizer)
+        model = Sequential(Linear(2, 3), Linear(3, 3))
+        optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
+        ddp_model = ShardedDataParallel(model, optimizer)

-    assert hasattr(ddp_model, "is_multi_device_module")
-    assert hasattr(ddp_model, "device_type")
-    dist.destroy_process_group()
+        assert hasattr(ddp_model, "is_multi_device_module")
+        assert hasattr(ddp_model, "device_type")
+        dist.destroy_process_group()


 def test_random_attributes():
-    # Check that ShardedDDP exposes the original module's attributes
-    dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)
+    with temp_files_ctx(num=1) as temp_files:
+        # Check that ShardedDDP exposes the original module's attributes
+        dist.init_process_group(init_method="file://" + temp_files[0], backend="gloo", rank=0, world_size=1)

-    model = Sequential(Linear(2, 3), Linear(3, 3))
-    model.banana = "sweet"
+        model = Sequential(Linear(2, 3), Linear(3, 3))
+        model.banana = "sweet"

-    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
-    ddp_model = ShardedDataParallel(model, optimizer)
+        optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
+        ddp_model = ShardedDataParallel(model, optimizer)

-    assert hasattr(ddp_model, "banana")
-    assert not hasattr(ddp_model, "orange")
+        assert hasattr(ddp_model, "banana")
+        assert not hasattr(ddp_model, "orange")

-    dist.destroy_process_group()
+        dist.destroy_process_group()


 def test_catch_grad_grad():
-    # Check that ShardedDDP exposes the original module's attributes
-    dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)
+    with temp_files_ctx(num=1) as temp_files:
+        # Check that ShardedDDP exposes the original module's attributes
+        dist.init_process_group(init_method="file://" + temp_files[0], backend="gloo", rank=0, world_size=1)

-    model = Sequential(Linear(2, 3), Linear(3, 3))
-    model.train()
-    chained_grad = torch.zeros_like(next(model.parameters()))
-    chained_grad.requires_grad = True
-    next(model.parameters()).grad = chained_grad
+        model = Sequential(Linear(2, 3), Linear(3, 3))
+        model.train()
+        chained_grad = torch.zeros_like(next(model.parameters()))
+        chained_grad.requires_grad = True
+        next(model.parameters()).grad = chained_grad

-    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
-    ddp_model = ShardedDataParallel(model, optimizer)
+        optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
+        ddp_model = ShardedDataParallel(model, optimizer)

-    inputs = torch.rand(100, 2)
-    with pytest.raises(RuntimeError):
-        _ = ddp_model(inputs)
+        inputs = torch.rand(100, 2)
+        with pytest.raises(RuntimeError):
+            _ = ddp_model(inputs)

-    dist.destroy_process_group()
+        dist.destroy_process_group()


 def test_mixed_types():
-    # Check that ShardedDDP exposes the original module's attributes
-    dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)
+    with temp_files_ctx(num=1) as temp_files:
+        # Check that ShardedDDP exposes the original module's attributes
+        dist.init_process_group(init_method="file://" + temp_files[0], backend="gloo", rank=0, world_size=1)

-    model = _get_mlp(tripwire=True)
+        model = _get_mlp(tripwire=True)

-    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
-    model = ShardedDataParallel(model, optimizer)
-    input_tensor = torch.rand((2, 2))
-    _ = model(input_tensor)
+        optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
+        model = ShardedDataParallel(model, optimizer)
+        input_tensor = torch.rand((2, 2))
+        _ = model(input_tensor)

-    dist.destroy_process_group()
+        dist.destroy_process_group()


 def run_test_train_eval_change(rank, world_size, file):
@@ -317,10 +320,10 @@ def run_test_train_eval_change(rank, world_size, file):
 def test_train_eval_change():
     world_size = 4
-    temp_file_name = tempfile.mkstemp()[1]
-    mp.spawn(
-        run_test_train_eval_change, args=(world_size, temp_file_name), nprocs=world_size, join=True,
-    )
+    with temp_files_ctx(num=1) as temp_files:
+        mp.spawn(
+            run_test_train_eval_change, args=(world_size, temp_files[0]), nprocs=world_size, join=True,
+        )


 def run_test_device_change(rank, world_size, backend, device, temp_file_name, reduce_buffer_size):
@@ -352,14 +355,14 @@ def test_device_change(reduce_buffer_size):
     # Check that ShardedDDP handles a device change properly
     world_size = 2
     backend = "nccl"
-    temp_file_name = tempfile.mkstemp()[1]
-    device = "cuda"
-    mp.spawn(
-        run_test_device_change,
-        args=(world_size, backend, device, temp_file_name, reduce_buffer_size),
-        nprocs=world_size,
-        join=True,
-    )
+    with temp_files_ctx(num=1) as temp_files:
+        device = "cuda"
+        mp.spawn(
+            run_test_device_change,
+            args=(world_size, backend, device, temp_files[0], reduce_buffer_size),
+            nprocs=world_size,
+            join=True,
+        )


 def run_test_training_change(rank, world_size, backend, device, temp_file_name, reduce_buffer_size):
@@ -389,14 +392,14 @@ def run_test_training_change(rank, world_size, backend, device, temp_file_name,
 def test_training_change(reduce_buffer_size):
     world_size = 2
     backend = "nccl"
-    temp_file_name = tempfile.mkstemp()[1]
     device = "cuda"
-    mp.spawn(
-        run_test_training_change,
-        args=(world_size, backend, device, temp_file_name, reduce_buffer_size),
-        nprocs=world_size,
-        join=True,
-    )
+    with temp_files_ctx(num=1) as temp_files:
+        mp.spawn(
+            run_test_training_change,
+            args=(world_size, backend, device, temp_files[0], reduce_buffer_size),
+            nprocs=world_size,
+            join=True,
+        )


 def run_test_ddp_sync_batch_norm(rank, world_size, backend, device, temp_file_name):
@@ -421,11 +424,14 @@ def test_ddp_sync_batch_norm():
     # Check that ShardedDDP is compatible with sync batch norm across multiple GPUs
     world_size = 2
     backend = "gloo"
-    temp_file_name = tempfile.mkstemp()[1]
     device = "cuda"
-    mp.spawn(
-        run_test_ddp_sync_batch_norm, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True
-    )
+    with temp_files_ctx(num=1) as temp_files:
+        mp.spawn(
+            run_test_ddp_sync_batch_norm,
+            args=(world_size, backend, device, temp_files[0]),
+            nprocs=world_size,
+            join=True,
+        )


 def run_test_two_optimizers(rank, world_size, backend, device, temp_file_name):
@@ -463,9 +469,11 @@ def test_two_optimizers():
     # Check that the ShardedDDP wrapper accepts tuple(tensors) as inputs
     world_size = 2
     backend = "gloo"
-    temp_file_name = tempfile.mkstemp()[1]
     device = "cpu"
-    mp.spawn(run_test_two_optimizers, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True)
+    with temp_files_ctx(num=1) as temp_files:
+        mp.spawn(
+            run_test_two_optimizers, args=(world_size, backend, device, temp_files[0]), nprocs=world_size, join=True
+        )


 def run_test_gpt2(rank, world_size, backend, device, temp_file_name):
@@ -510,9 +518,9 @@ def run_test_gpt2(rank, world_size, backend, device, temp_file_name):
 def test_gpt2(world_size):
     # Check that having trainable unused params is fine
     backend = "gloo"
-    temp_file_name = tempfile.mkstemp()[1]
     device = "cuda"
-    mp.spawn(run_test_gpt2, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True)
+    with temp_files_ctx(num=1) as temp_files:
+        mp.spawn(run_test_gpt2, args=(world_size, backend, device, temp_files[0]), nprocs=world_size, join=True)


 def run_test_multiple_groups(rank, world_size, tempfile_name, backend, reduce_buffer_size):
@@ -575,11 +583,10 @@ def run_test_multiple_groups(rank, world_size, tempfile_name, backend, reduce_bu
 @pytest.mark.parametrize("backend", ["gloo", "nccl"])
 def test_multiple_groups(reduce_buffer_size, backend):
     world_size = 4
-    temp_file_name = tempfile.mkstemp()[1]
-    mp.spawn(
-        run_test_multiple_groups,
-        args=(world_size, temp_file_name, backend, reduce_buffer_size),
-        nprocs=world_size,
-        join=True,
-    )
+    with temp_files_ctx(num=1) as temp_files:
+        mp.spawn(
+            run_test_multiple_groups,
+            args=(world_size, temp_files[0], backend, reduce_buffer_size),
+            nprocs=world_size,
+            join=True,
+        )
@@ -9,7 +9,6 @@ Testing ShardedDDP
 from contextlib import suppress
 import copy
-import tempfile

 import numpy as np
 import pytest
@@ -23,7 +22,13 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 from fairscale.nn.data_parallel import ShardedDataParallel
 from fairscale.optim import OSS
 from fairscale.optim.grad_scaler import ShardedGradScaler
-from fairscale.utils.testing import check_same_model_params, skip_if_no_cuda, skip_if_single_gpu, torch_version
+from fairscale.utils.testing import (
+    check_same_model_params,
+    skip_if_no_cuda,
+    skip_if_single_gpu,
+    temp_files_ctx,
+    torch_version,
+)

 """
 Check that ShardedDDP gets the same results as DDP in a variety of scenarii
@@ -250,24 +255,25 @@ def test_ddp_parity(
     world_size = torch.cuda.device_count()
     backend = dist.Backend.NCCL
-    mp.spawn(
-        run_ddp_parity,
-        args=(
-            world_size,
-            backend,
-            tempfile.mkstemp()[1],
-            reduce_buffer_size,
-            grad_accumulation,
-            change_train_graph,
-            fp16_reduction,
-            clip_grad_norm,
-            amp,
-            manual_reduction,
-            multiple_fw,
-        ),
-        nprocs=world_size,
-        join=True,
-    )
+    with temp_files_ctx(num=1) as temp_files:
+        mp.spawn(
+            run_ddp_parity,
+            args=(
+                world_size,
+                backend,
+                temp_files[0],
+                reduce_buffer_size,
+                grad_accumulation,
+                change_train_graph,
+                fp16_reduction,
+                clip_grad_norm,
+                amp,
+                manual_reduction,
+                multiple_fw,
+            ),
+            nprocs=world_size,
+            join=True,
+        )


 def run_ddp_parity_two_optim(rank, world_size, backend, temp_file_name, reduce_buffer_size):
@@ -340,9 +346,10 @@ def run_ddp_parity_two_optim(rank, world_size, backend, temp_file_name, reduce_b
 def test_ddp_parity_two_optim(reduce_buffer_size):
     world_size = 2
     backend = dist.Backend.NCCL
-    mp.spawn(
-        run_ddp_parity_two_optim,
-        args=(world_size, backend, tempfile.mkstemp()[1], reduce_buffer_size),
-        nprocs=world_size,
-        join=True,
-    )
+    with temp_files_ctx(num=1) as temp_files:
+        mp.spawn(
+            run_ddp_parity_two_optim,
+            args=(world_size, backend, temp_files[0], reduce_buffer_size),
+            nprocs=world_size,
+            join=True,
+        )