Unverified commit 3b7373e2 authored by Benjamin Lefaudeux, committed by GitHub

[test][refactor][SDP] Using the nice context-based tempfiles (#640)

parent 8c8a625a
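
The change replaces the ad-hoc tempfile.mkstemp()[1] calls, which were repeated in every test and left the rendez-vous files behind, with the temp_files_ctx helper from fairscale.utils.testing: the context manager hands out the requested number of temporary file paths and removes them when the block exits, even if the test fails. As a rough sketch of what such a helper can look like (the real implementation lives in fairscale.utils.testing and may differ; only the usage visible in the diff, temp_files_ctx(num=N) yielding N paths, is taken as given):

# Minimal sketch of a context-based temp-file helper; an approximation, not the
# actual fairscale.utils.testing implementation.
from contextlib import contextmanager
import os
import tempfile
from typing import Generator, List


@contextmanager
def temp_files_ctx(num: int) -> Generator[List[str], None, None]:
    """Yield `num` temporary file paths and delete them when the block exits."""
    handles = [tempfile.mkstemp() for _ in range(num)]
    try:
        yield [path for _, path in handles]
    finally:
        # Cleanup runs even if the wrapped test body raises, which is the main
        # advantage over a bare tempfile.mkstemp() call.
        for fd, path in handles:
            try:
                os.close(fd)
                os.remove(path)
            except OSError:
                pass

Each test below then wraps its mp.spawn call in "with temp_files_ctx(num=1) as temp_files:" and passes temp_files[0] to the workers as the file:// rendez-vous for init_process_group, so the sync file disappears once the spawned processes have joined.
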
@@ -8,7 +8,6 @@ Testing ShardedDDP
 """
 from contextlib import suppress
-import tempfile
 import numpy as np
 import pytest
@@ -27,6 +26,7 @@ from fairscale.utils.testing import (
     skip_if_less_than_four_gpu,
     skip_if_no_cuda,
     skip_if_single_gpu,
+    temp_files_ctx,
 )
@@ -134,10 +134,10 @@ def run_one_step(
 def run_test(backend, device, world_size, broadcast_buffers, grad_accumulation, reduce_buffer_size, optimizer_type):
-    temp_file_name = tempfile.mkstemp()[1]
+    with temp_files_ctx(num=1) as temp_files:
         mp.spawn(
             run_one_step,
-        args=(world_size, backend, device, temp_file_name, broadcast_buffers, grad_accumulation, reduce_buffer_size),
+            args=(world_size, backend, device, temp_files[0], broadcast_buffers, grad_accumulation, reduce_buffer_size),
             nprocs=world_size,
             join=True,
         )
@@ -160,15 +160,14 @@ def run_test(backend, device, world_size, broadcast_buffers, grad_accumulation, reduce_buffer_size, optimizer_type):
 )
 def test_step(broadcast_buffers, grad_accumulation, reduce_buffer_size, optimizer_type, reduce_fp16, setup):
     world_size = 2
-    temp_file_name = tempfile.mkstemp()[1]
+    with temp_files_ctx(num=1) as temp_files:
         mp.spawn(
             run_one_step,
             args=(
                 world_size,
                 setup[0],
                 setup[1],
-            temp_file_name,
+                temp_files[0],
                 broadcast_buffers,
                 grad_accumulation,
                 reduce_buffer_size,
@@ -200,7 +199,7 @@ def run_test_two_inputs(rank, world_size, backend, device, temp_file_name, reduce_buffer_size):
         loss.backward()
         return loss
-    for i in range(5):
+    for _ in range(5):
         _ = optimizer.step(closure=closure)
     dist.destroy_process_group()
@@ -215,10 +214,10 @@ def test_inputs(reduce_buffer_size, backend, device):
     if backend == "nccl" and device == "cpu":
         pytest.skip("Incompatible combination, or cuda not available")
         return
+    with temp_files_ctx(num=1) as temp_files:
         mp.spawn(
             run_test_two_inputs,
-        args=(world_size, backend, device, tempfile.mkstemp()[1], reduce_buffer_size),
+            args=(world_size, backend, device, temp_files[0], reduce_buffer_size),
             nprocs=world_size,
             join=True,
         )
@@ -228,7 +227,8 @@ def test_ddp_attributes():
     # Check that ShardedDDP exposes the same attributes as Pytorch's DDP
     # - is multi_device_module
     # - device_type
-    dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)
+    with temp_files_ctx(num=1) as temp_files:
+        dist.init_process_group(init_method="file://" + temp_files[0], backend="gloo", rank=0, world_size=1)
         model = Sequential(Linear(2, 3), Linear(3, 3))
         optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
@@ -240,8 +240,9 @@ def test_ddp_attributes():
 def test_random_attributes():
+    with temp_files_ctx(num=1) as temp_files:
         # Check that ShardedDDP exposes the original module's attributes
-    dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)
+        dist.init_process_group(init_method="file://" + temp_files[0], backend="gloo", rank=0, world_size=1)
         model = Sequential(Linear(2, 3), Linear(3, 3))
         model.banana = "sweet"
@@ -256,8 +257,9 @@ def test_random_attributes():
 def test_catch_grad_grad():
+    with temp_files_ctx(num=1) as temp_files:
         # Check that ShardedDDP exposes the original module's attributes
-    dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)
+        dist.init_process_group(init_method="file://" + temp_files[0], backend="gloo", rank=0, world_size=1)
         model = Sequential(Linear(2, 3), Linear(3, 3))
         model.train()
@@ -276,8 +278,9 @@ def test_catch_grad_grad():
 def test_mixed_types():
+    with temp_files_ctx(num=1) as temp_files:
         # Check that ShardedDDP exposes the original module's attributes
-    dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)
+        dist.init_process_group(init_method="file://" + temp_files[0], backend="gloo", rank=0, world_size=1)
         model = _get_mlp(tripwire=True)
@@ -317,9 +320,9 @@ def run_test_train_eval_change(rank, world_size, file):
 def test_train_eval_change():
     world_size = 4
-    temp_file_name = tempfile.mkstemp()[1]
+    with temp_files_ctx(num=1) as temp_files:
         mp.spawn(
-        run_test_train_eval_change, args=(world_size, temp_file_name), nprocs=world_size, join=True,
+            run_test_train_eval_change, args=(world_size, temp_files[0]), nprocs=world_size, join=True,
         )
@@ -352,11 +355,11 @@ def test_device_change(reduce_buffer_size):
     # Check that ShardedDDP handles a device change properly
     world_size = 2
     backend = "nccl"
-    temp_file_name = tempfile.mkstemp()[1]
+    with temp_files_ctx(num=1) as temp_files:
         device = "cuda"
         mp.spawn(
             run_test_device_change,
-        args=(world_size, backend, device, temp_file_name, reduce_buffer_size),
+            args=(world_size, backend, device, temp_files[0], reduce_buffer_size),
             nprocs=world_size,
             join=True,
         )
@@ -389,11 +392,11 @@ def run_test_training_change(rank, world_size, backend, device, temp_file_name, reduce_buffer_size):
 def test_training_change(reduce_buffer_size):
     world_size = 2
     backend = "nccl"
-    temp_file_name = tempfile.mkstemp()[1]
     device = "cuda"
+    with temp_files_ctx(num=1) as temp_files:
         mp.spawn(
             run_test_training_change,
-        args=(world_size, backend, device, temp_file_name, reduce_buffer_size),
+            args=(world_size, backend, device, temp_files[0], reduce_buffer_size),
             nprocs=world_size,
             join=True,
         )
@@ -421,10 +424,13 @@ def test_ddp_sync_batch_norm():
     # Check that ShardedDDP is compatible with sync batch norm across multiple GPUs
     world_size = 2
     backend = "gloo"
-    temp_file_name = tempfile.mkstemp()[1]
     device = "cuda"
+    with temp_files_ctx(num=1) as temp_files:
         mp.spawn(
-        run_test_ddp_sync_batch_norm, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True
+            run_test_ddp_sync_batch_norm,
+            args=(world_size, backend, device, temp_files[0]),
+            nprocs=world_size,
+            join=True,
         )
@@ -463,9 +469,11 @@ def test_two_optimizers():
     # Check that the ShardedDDP wrapper accepts tuple(tensors) as inputs
     world_size = 2
     backend = "gloo"
-    temp_file_name = tempfile.mkstemp()[1]
     device = "cpu"
-    mp.spawn(run_test_two_optimizers, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True)
+    with temp_files_ctx(num=1) as temp_files:
+        mp.spawn(
+            run_test_two_optimizers, args=(world_size, backend, device, temp_files[0]), nprocs=world_size, join=True
+        )

 def run_test_gpt2(rank, world_size, backend, device, temp_file_name):
@@ -510,9 +518,9 @@ def run_test_gpt2(rank, world_size, backend, device, temp_file_name):
 def test_gpt2(world_size):
     # Check that having trainable unused params is fine
     backend = "gloo"
-    temp_file_name = tempfile.mkstemp()[1]
     device = "cuda"
-    mp.spawn(run_test_gpt2, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True)
+    with temp_files_ctx(num=1) as temp_files:
+        mp.spawn(run_test_gpt2, args=(world_size, backend, device, temp_files[0]), nprocs=world_size, join=True)

 def run_test_multiple_groups(rank, world_size, tempfile_name, backend, reduce_buffer_size):
@@ -575,11 +583,10 @@ def run_test_multiple_groups(rank, world_size, tempfile_name, backend, reduce_buffer_size):
 @pytest.mark.parametrize("backend", ["gloo", "nccl"])
 def test_multiple_groups(reduce_buffer_size, backend):
     world_size = 4
-    temp_file_name = tempfile.mkstemp()[1]
+    with temp_files_ctx(num=1) as temp_files:
         mp.spawn(
             run_test_multiple_groups,
-        args=(world_size, temp_file_name, backend, reduce_buffer_size),
+            args=(world_size, temp_files[0], backend, reduce_buffer_size),
             nprocs=world_size,
             join=True,
         )
@@ -9,7 +9,6 @@ Testing ShardedDDP
 from contextlib import suppress
 import copy
-import tempfile
 import numpy as np
 import pytest
@@ -23,7 +22,13 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 from fairscale.nn.data_parallel import ShardedDataParallel
 from fairscale.optim import OSS
 from fairscale.optim.grad_scaler import ShardedGradScaler
-from fairscale.utils.testing import check_same_model_params, skip_if_no_cuda, skip_if_single_gpu, torch_version
+from fairscale.utils.testing import (
+    check_same_model_params,
+    skip_if_no_cuda,
+    skip_if_single_gpu,
+    temp_files_ctx,
+    torch_version,
+)

 """
 Check that ShardedDDP gets the same results as DDP in a variety of scenarii
@@ -250,12 +255,13 @@ def test_ddp_parity(
     world_size = torch.cuda.device_count()
     backend = dist.Backend.NCCL
+    with temp_files_ctx(num=1) as temp_files:
         mp.spawn(
             run_ddp_parity,
             args=(
                 world_size,
                 backend,
-            tempfile.mkstemp()[1],
+                temp_files[0],
                 reduce_buffer_size,
                 grad_accumulation,
                 change_train_graph,
@@ -340,9 +346,10 @@ def run_ddp_parity_two_optim(rank, world_size, backend, temp_file_name, reduce_buffer_size):
 def test_ddp_parity_two_optim(reduce_buffer_size):
     world_size = 2
     backend = dist.Backend.NCCL
+    with temp_files_ctx(num=1) as temp_files:
         mp.spawn(
             run_ddp_parity_two_optim,
-        args=(world_size, backend, tempfile.mkstemp()[1], reduce_buffer_size),
+            args=(world_size, backend, temp_files[0], reduce_buffer_size),
             nprocs=world_size,
             join=True,
         )
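
For reference, the pattern applied throughout both test files boils down to the following self-contained example. The worker body is a placeholder rather than one of the real tests; only temp_files_ctx, mp.spawn and the file:// init_method mirror the code above.

# Hypothetical minimal test showing the spawn + temp_files_ctx pattern from this commit.
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from fairscale.utils.testing import temp_files_ctx


def _worker(rank: int, world_size: int, sync_file: str) -> None:
    # Every rank rendez-vous through the shared temporary file.
    dist.init_process_group(backend="gloo", init_method="file://" + sync_file, rank=rank, world_size=world_size)
    dist.all_reduce(torch.ones(1))  # placeholder work standing in for the real test body
    dist.destroy_process_group()


def test_spawn_pattern() -> None:
    world_size = 2
    with temp_files_ctx(num=1) as temp_files:
        # The sync file exists for the duration of the spawn and is removed afterwards.
        mp.spawn(_worker, args=(world_size, temp_files[0]), nprocs=world_size, join=True)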