Unverified Commit d52d2186 authored by Benjamin Lefaudeux, committed by GitHub

[perf][ShardedDDP] fp16 gradient reduce (#411)

* POC, testing against the DDP comm hook when available
* docs, adding a reference to DDP's compress hook
* updating changelog, prep for v0.1.8 release
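A minimal usage sketch of the new option (hedged: the model, optimizer settings, and device below are illustrative, and a distributed process group is assumed to already be initialized):

```python
import torch
from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS

# Illustrative model and optimizer; any module wrapped with an OSS optimizer works the same way.
model = torch.nn.Linear(16, 16).cuda()
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)

# reduce_fp16=True casts each gradient to fp16 just for the reduce step;
# the rank owning the shard restores the gradient to the parameter dtype afterwards.
ddp_model = ShardedDataParallel(model, optimizer, reduce_fp16=True)
```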
parent d10c34e7
......@@ -7,11 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## NEXT - TBD
### Added
- FullyShardedDataParallel (FSDP) ([#413](https://github.com/facebookresearch/fairscale/issues/413))
- ShardedDDP fp16 grad reduction option ([#402](https://github.com/facebookresearch/fairscale/issues/402))
### Fixed
- Catch corner case when the model is too small with respect to the world size, and shards are empty ([#406](https://github.com/facebookresearch/fairscale/pull/406))
- Memory leak in checkpoint_wrapper ([#412](https://github.com/facebookresearch/fairscale/pull/412))
......
......@@ -58,7 +58,11 @@ class ShardedDataParallel(nn.Module):
(default: True) Check whether the parameters' trainability (`requires_grad`) has changed and update both ShardedDDP
and OSS automatically if this is the case. If set to False, `refresh_trainable()` needs to be called anytime
a parameter is frozen or unfrozen.
reduce_fp16 (bool):
Cast the gradients to fp16 before reducing them. This is not needed if the model is already in fp16, but it will likely improve
performance for multi-node jobs using PyTorch AMP. The effect is similar to DDP's fp16_compress_hook_ and it also saves some memory.
.. _fp16_compress_hook: https://pytorch.org/docs/1.8.0/ddp_comm_hooks.html?highlight=fp16#torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_hook
.. warning::
ShardedDDP implements gradient sharding, meaning that each rank only owns a unique shard of the model gradients
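For comparison, the DDP counterpart referenced in the docstring can be sketched as follows (assuming PyTorch >= 1.8 with the ddp_comm_hooks module available and an initialized process group; the model and device ids are illustrative):

```python
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks

ddp_model = DDP(torch.nn.Linear(16, 16).cuda(), device_ids=[0])
# Gradient buckets are compressed to fp16 for the all-reduce and decompressed afterwards.
ddp_model.register_comm_hook(state=None, hook=default_hooks.fp16_compress_hook)
```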
......@@ -100,6 +104,7 @@ class ShardedDataParallel(nn.Module):
sync_models_at_startup: bool = True,
reduce_buffer_size: int = 0,
auto_refresh_trainable: bool = True,
reduce_fp16: bool = True,
):
super().__init__()
......@@ -107,6 +112,12 @@ class ShardedDataParallel(nn.Module):
self.sharded_optimizers = [sharded_optimizer] if isinstance(sharded_optimizer, OSS) else sharded_optimizer
self.enable_broadcast_buffers = broadcast_buffers
self.auto_refresh_trainable = auto_refresh_trainable
self.reduce_fp16 = reduce_fp16
if reduce_buffer_size > 0:
self.reduce_fp16 = False
logging.warning(
"fp16 gradient reduction is not compatible with reduction buffers, which are requested. fp16 grad reduction is deactivated."
)
# Handle a no_sync() context which prevents the gradient synchronization,
# accumulate in place
......@@ -401,10 +412,16 @@ class ShardedDataParallel(nn.Module):
self._grad_to_be_reduced[index] = False
param.grad.mul_(self.world_size_scaling)
if self.reduce_fp16:
param.grad.data = param.grad.data.half()
# Future work includes clearing up the buffer if possible
def cleanup() -> None:
if dst_rank != self.global_rank:
param.grad = None
else:
assert param.grad is not None
param.grad.data = param.grad.data.to(dtype=param.dtype)
# Async reduce for this buffer, log the future
dst_global_rank = OSS.get_global_rank(self.process_group, dst_rank)
......
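To make the hunk above easier to follow, here is a minimal, self-contained sketch of the per-parameter pattern it implements (the helper below is hypothetical, not fairscale's actual code; it assumes an initialized process group and that dst_global_rank is the rank owning the parameter's optimizer shard):

```python
import torch
import torch.distributed as dist


def reduce_param_grad_fp16(param: torch.nn.Parameter, dst_global_rank: int, world_size: int) -> None:
    assert param.grad is not None

    # Pre-scale so that the SUM reduce produces an average across ranks.
    param.grad.mul_(1.0 / world_size)

    # Compress the payload: the reduce itself is carried out in fp16.
    param.grad.data = param.grad.data.half()
    work = dist.reduce(param.grad.data, dst=dst_global_rank, op=dist.ReduceOp.SUM, async_op=True)

    def cleanup() -> None:
        if dist.get_rank() != dst_global_rank:
            # This rank does not own the shard, so the gradient can be dropped.
            param.grad = None
        else:
            # The shard owner casts the reduced gradient back to the parameter dtype.
            param.grad.data = param.grad.data.to(dtype=param.dtype)

    # ShardedDDP attaches the cleanup to the reduce future; waiting inline keeps the sketch simple.
    work.wait()
    cleanup()
```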
......@@ -28,4 +28,4 @@ use_parentheses = true
skip_glob = ["build/*", "stubs/*"]
# Don't split "import" and "from".
force_sort_within_sections = true
known_third_party = ["benchmark_dataset", "datasets", "golden_configs", "helpers", "models", "numpy", "pytest", "recommonmark", "setuptools", "torch", "torch_pg", "torchtext", "torchvision"]
known_third_party = ["benchmark_dataset", "datasets", "golden_configs", "helpers", "models", "numpy", "parameterized", "pytest", "recommonmark", "setuptools", "torch", "torch_pg", "torchtext", "torchvision"]
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""
Testing ShardedDDP
"""
from contextlib import suppress
import tempfile
import numpy as np
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn import Linear, Sequential
from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS
from fairscale.utils.testing import (
GPT2,
available_devices,
check_same_models_across_ranks,
skip_if_less_than_four_gpu,
skip_if_no_cuda,
skip_if_py38,
skip_if_single_gpu,
)
def _get_mlp():
return Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
class _DoubleInput(torch.nn.Module):
def __init__(self):
super().__init__()
self.mlp = _get_mlp()
def forward(self, x, y):
x1 = self.mlp(x)
x2 = self.mlp(y)
return torch.cat((x1, x2), dim=1)
def run_one_step(
rank, world_size, backend, device, temp_file_name, broadcast_buffers, grad_accumulation, reduce_buffer_size,
):
dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
if device == torch.device("cuda"):
torch.cuda.set_device(rank)
torch.manual_seed(rank)
np.random.seed(rank)
# Any model works. Add one different buffer per rank
model = _get_mlp()
model.register_buffer("test_buffer", torch.ones((1)) * rank)
model.to(device)
next(model.parameters()).requires_grad = False # Test non-trainable parameters
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
ddp_model = ShardedDataParallel(
model, optimizer, broadcast_buffers=broadcast_buffers, reduce_buffer_size=reduce_buffer_size
)
# The model should be synchronized in between the ranks at ShardedDataParallel construction time, check that
check_same_models_across_ranks(
ddp_model, dist.group.WORLD, params_should_be_equal=True, check_broadcast_buffers=broadcast_buffers
)
# Optim loop
def closure():
optimizer.zero_grad()
with ddp_model.no_sync() if grad_accumulation else suppress():
input_tensor = torch.rand((64, 2)).to(device)
loss = ddp_model(input_tensor).abs().sum()
loss.backward()
return loss
# The models should stay the same in between the ranks
for i in range(5):
_ = optimizer.step(closure=closure)
# when running on cpu/gloo the "nodes" are not really different
same_params = device == torch.device("cpu") or grad_accumulation
check_same_models_across_ranks(
ddp_model, dist.group.WORLD, params_should_be_equal=same_params, check_broadcast_buffers=broadcast_buffers
)
dist.destroy_process_group()
def run_test(backend, device, world_size, broadcast_buffers, grad_accumulation, reduce_buffer_size):
temp_file_name = tempfile.mkstemp()[1]
mp.spawn(
run_one_step,
args=(world_size, backend, device, temp_file_name, broadcast_buffers, grad_accumulation, reduce_buffer_size),
nprocs=world_size,
join=True,
)
@skip_if_no_cuda
@skip_if_single_gpu
@pytest.mark.parametrize("broadcast_buffers", [True, False])
@pytest.mark.parametrize("grad_accumulation", [True, False])
@pytest.mark.parametrize("reduce_buffer_size", [0, 2 ** 20])
def test_step_gpu(broadcast_buffers, grad_accumulation, reduce_buffer_size):
world_size = 2
run_test(
dist.Backend.NCCL, torch.device("cuda"), world_size, broadcast_buffers, grad_accumulation, reduce_buffer_size
)
@skip_if_py38
@pytest.mark.parametrize("broadcast_buffers", [True, False])
@pytest.mark.parametrize("grad_accumulation", [True, False])
@pytest.mark.parametrize("reduce_buffer_size", [0, 2 ** 20])
def test_step_cpu(broadcast_buffers, grad_accumulation, reduce_buffer_size):
world_size = 2
run_test(
dist.Backend.GLOO, torch.device("cpu"), world_size, broadcast_buffers, grad_accumulation, reduce_buffer_size
)
def run_test_two_inputs(rank, world_size, backend, device, temp_file_name, reduce_buffer_size):
dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
if device == "cuda":
torch.cuda.set_device(rank)
torch.manual_seed(rank)
np.random.seed(rank)
model = _DoubleInput().to(device)
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
ddp_model = ShardedDataParallel(model, optimizer, reduce_buffer_size=reduce_buffer_size)
# Optim loop
def closure():
optimizer.zero_grad()
input_tensor = torch.rand((64, 2)).to(device)
loss = ddp_model(input_tensor, input_tensor).abs().sum()
loss.backward()
return loss
for i in range(5):
_ = optimizer.step(closure=closure)
dist.destroy_process_group()
@pytest.mark.parametrize("reduce_buffer_size", [0, 2 ** 20])
@pytest.mark.parametrize("backend", ["gloo", "nccl"])
@pytest.mark.parametrize("device", available_devices)
def test_inputs(reduce_buffer_size, backend, device):
# Check that the ShardedDDP wrapper accepts tuple(tensors) as inputs
world_size = 2
if backend == "nccl" and device == "cpu":
pytest.skip("Incompatible combination, or cuda not available")
return
mp.spawn(
run_test_two_inputs,
args=(world_size, backend, device, tempfile.mkstemp()[1], reduce_buffer_size),
nprocs=world_size,
join=True,
)
def test_ddp_attributes():
# Check that ShardedDDP exposes the same attributes as PyTorch's DDP
# - is multi_device_module
# - device_type
dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)
model = Sequential(Linear(2, 3), Linear(3, 3))
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
ddp_model = ShardedDataParallel(model, optimizer)
assert hasattr(ddp_model, "is_multi_device_module")
assert hasattr(ddp_model, "device_type")
dist.destroy_process_group()
def test_random_attributes():
# Check that ShardedDDP exposes the original module's attributes
dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)
model = Sequential(Linear(2, 3), Linear(3, 3))
model.banana = "sweet"
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
ddp_model = ShardedDataParallel(model, optimizer)
assert hasattr(ddp_model, "banana")
assert not hasattr(ddp_model, "orange")
dist.destroy_process_group()
def run_test_device_change(rank, world_size, backend, device, temp_file_name, reduce_buffer_size):
# Check that moving the wrapped module to another device is caught and rejected
dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
model = Sequential(Linear(2, 3), Linear(3, 3)).cpu()  # kept on CPU on purpose, to test changing the device after the fact
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
ddp_model = ShardedDataParallel(
model, optimizer, sync_models_at_startup=False, reduce_buffer_size=reduce_buffer_size
)
try:
ddp_model.to(device)
assert False, "Changing devices should be caught and not supported"
except AssertionError:
pass
dist.destroy_process_group()
@skip_if_no_cuda
@skip_if_single_gpu
@pytest.mark.parametrize("reduce_buffer_size", [0, 2 ** 20])
def test_device_change(reduce_buffer_size):
# Check that moving the ShardedDDP-wrapped module to another device is caught and rejected
world_size = 2
backend = "nccl"
temp_file_name = tempfile.mkstemp()[1]
device = "cuda"
mp.spawn(
run_test_device_change,
args=(world_size, backend, device, temp_file_name, reduce_buffer_size),
nprocs=world_size,
join=True,
)
def run_test_training_change(rank, world_size, backend, device, temp_file_name, reduce_buffer_size):
group = dist.init_process_group(
init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size
)
torch.cuda.set_device(rank)
model = Sequential(Linear(2, 3), Linear(3, 3)).to(device)
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
ddp_model = ShardedDataParallel(model, optimizer, process_group=group, reduce_buffer_size=reduce_buffer_size)
inputs = torch.rand((10, 2), device=device)
outputs = ddp_model(inputs) # assert if the module has not been changed properly
_ = outputs.norm().backward()
ddp_model.eval()
ddp_model(inputs) # This will assert if eval() is not properly taken into account
ddp_model(inputs)
dist.destroy_process_group()
@skip_if_no_cuda
@skip_if_single_gpu
@pytest.mark.parametrize("reduce_buffer_size", [0, 2 ** 20])
def test_training_change(reduce_buffer_size):
world_size = 2
backend = "nccl"
temp_file_name = tempfile.mkstemp()[1]
device = "cuda"
mp.spawn(
run_test_training_change,
args=(world_size, backend, device, temp_file_name, reduce_buffer_size),
nprocs=world_size,
join=True,
)
def run_test_ddp_sync_batch_norm(rank, world_size, backend, device, temp_file_name):
dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
model = Sequential(Linear(2, 3), torch.nn.BatchNorm1d(3), Linear(3, 3)).to(device)
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
model.to(device)  # in PyTorch 1.5, SyncBN switches to the default device/cpu
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
ddp_model = ShardedDataParallel(model, optimizer)
assert isinstance(model[1], torch.nn.SyncBatchNorm)
# Ensures sync batch norm handles have been added
ddp_model(torch.randn(2, 2).to(device))
dist.destroy_process_group()
@skip_if_no_cuda
@skip_if_single_gpu
def test_ddp_sync_batch_norm():
# Check that ShardedDDP is compatible with sync batch norm across multiple GPUs
world_size = 2
backend = "gloo"
temp_file_name = tempfile.mkstemp()[1]
device = "cuda"
mp.spawn(
run_test_ddp_sync_batch_norm, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True
)
def run_test_two_optimizers(rank, world_size, backend, device, temp_file_name):
dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
if device == torch.device("cuda"):
torch.cuda.set_device(rank)
torch.manual_seed(rank)
np.random.seed(rank)
model = _DoubleInput().to(device)
parameters = list(model.parameters())
optimizer_1 = OSS(params=parameters[:-10], optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
optimizer_2 = OSS(params=parameters[-10:], optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
ddp_model = ShardedDataParallel(model, [optimizer_1, optimizer_2])
# Optim loop
def closure():
input_tensor = torch.rand((64, 2)).to(device)
loss = ddp_model(input_tensor, input_tensor).abs().sum()
loss.backward()
return loss
for i in range(5):
optimizer_1.zero_grad()
optimizer_2.zero_grad()
_ = optimizer_1.step(closure=closure)
_ = optimizer_2.step(closure=closure)
dist.destroy_process_group()
def test_two_optimizers():
# Check that ShardedDDP works when wrapping several optimizers
world_size = 2
backend = "gloo"
temp_file_name = tempfile.mkstemp()[1]
device = "cpu"
mp.spawn(run_test_two_optimizers, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True)
def run_test_gpt2(rank, world_size, backend, device, temp_file_name):
INPUT_DIM = 16
BATCH_SIZE = 10
STEPS = 10
url = "file://" + temp_file_name
dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
torch.manual_seed(rank)
np.random.seed(rank)
model = GPT2(
embed_dim=256, num_heads=2, num_layers=12, num_positions=INPUT_DIM * INPUT_DIM, num_vocab=512, num_classes=2
).to(device)
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
ddp_model = ShardedDataParallel(model, optimizer, reduce_buffer_size=0)
# Optim loop
def closure():
optimizer.zero_grad()
# Force int inputs to prevent the first grad from firing
input_tensor = torch.randint(10, (BATCH_SIZE, INPUT_DIM)).to(device)
loss = ddp_model(input_tensor).abs().sum()
loss.backward()
return loss
# Check for bucketing overflows
for i in range(STEPS):
_ = optimizer.step(closure=closure)
dist.destroy_process_group()
@skip_if_no_cuda
@skip_if_single_gpu
def test_gpt2():
# Check that ShardedDDP handles a larger model (GPT2) without bucketing overflows
world_size = 2
backend = "gloo"
temp_file_name = tempfile.mkstemp()[1]
device = "cuda"
mp.spawn(run_test_gpt2, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True)
def run_test_multiple_groups(rank, world_size, tempfile_name, backend, reduce_buffer_size):
# Only work with the even ranks, to check that the global_rank indexing is properly used
dist.init_process_group(init_method="file://" + tempfile_name, backend=backend, rank=rank, world_size=world_size)
sub_group_ranks = [0, 2]
process_group = torch.distributed.new_group(ranks=sub_group_ranks, backend=backend)
# Make sure that all the ranks get different training data
# So that the sync check in between their models is meaningful
torch.manual_seed(rank)
np.random.seed(rank)
# Standard deep learning setup
device = "cuda"
torch.cuda.set_device(rank)
epochs, batch, input_width, hidden, target_width = 5, 3, 20, 10, 5
loss_fn = torch.nn.L1Loss().to(device)
def check(optimizer, model):
# Just run a couple of epochs, check that the model is properly updated
for _ in range(epochs):
target = torch.rand((batch, target_width), device=device)
inputs = torch.rand((batch, input_width), device=device)
def closure():
optimizer.zero_grad()
output = model(inputs)
loss = loss_fn(output, target)
loss.backward()
return loss
_ = optimizer.step(closure=closure)
# Check that all the params are the same on all ranks
check_same_models_across_ranks(
model, process_group, params_should_be_equal=True, check_broadcast_buffers=True
)
if rank in sub_group_ranks:
# Model that does not fit in the broadcast bucket
model = torch.nn.Sequential(torch.nn.Linear(input_width, hidden), torch.nn.Linear(hidden, target_width)).to(
device
)
# With SGD, Momentum is required to get a state to shard
optimizer = OSS(model.parameters(), group=process_group, lr=1e-3, momentum=0.99)
model = ShardedDataParallel(
model, optimizer, process_group=process_group, reduce_buffer_size=reduce_buffer_size
)
check(optimizer, model)
dist.destroy_process_group(process_group)
@skip_if_less_than_four_gpu
@pytest.mark.parametrize("reduce_buffer_size", [0, 2 ** 20])
@pytest.mark.parametrize("backend", ["gloo", "nccl"])
def test_multiple_groups(reduce_buffer_size, backend):
world_size = 4
temp_file_name = tempfile.mkstemp()[1]
mp.spawn(
run_test_multiple_groups,
args=(world_size, temp_file_name, backend, reduce_buffer_size),
nprocs=world_size,
join=True,
)
......@@ -30,10 +30,18 @@ from fairscale.utils.testing import (
check_same_models_across_ranks,
skip_if_less_than_four_gpu,
skip_if_no_cuda,
skip_if_py38,
skip_if_single_gpu,
)
"""
Check that ShardedDDP gets the same results as DDP in a variety of scenarios
"""
_test_fp16_reduction = [False]
# The fp16 gradient compression hook is only available in recent PyTorch versions
try:
    from torch.distributed.algorithms.ddp_comm_hooks import default_hooks  # noqa: F401
    _test_fp16_reduction.append(True)
except ImportError:
    pass
def _get_mlp():
return Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
......@@ -50,90 +58,8 @@ class _DoubleInput(torch.nn.Module):
return torch.cat((x1, x2), dim=1)
def run_one_step(
rank, world_size, backend, device, temp_file_name, broadcast_buffers, grad_accumulation, reduce_buffer_size,
):
dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
if device == torch.device("cuda"):
torch.cuda.set_device(rank)
torch.manual_seed(rank)
np.random.seed(rank)
# Any model works. Add one different buffer per rank
model = _get_mlp()
model.register_buffer("test_buffer", torch.ones((1)) * rank)
model.to(device)
next(model.parameters()).requires_grad = False # Test non-trainable parameters
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
ddp_model = ShardedDataParallel(
model, optimizer, broadcast_buffers=broadcast_buffers, reduce_buffer_size=reduce_buffer_size
)
# The model should be synchronized in between the ranks at ShardedDataParallel construction time, check that
check_same_models_across_ranks(
ddp_model, dist.group.WORLD, params_should_be_equal=True, check_broadcast_buffers=broadcast_buffers
)
# Optim loop
def closure():
optimizer.zero_grad()
with ddp_model.no_sync() if grad_accumulation else suppress():
input_tensor = torch.rand((64, 2)).to(device)
loss = ddp_model(input_tensor).abs().sum()
loss.backward()
return loss
# The models should stay the same in between the ranks
for i in range(5):
_ = optimizer.step(closure=closure)
# when running on cpu/gloo the "nodes" are not really different
same_params = device == torch.device("cpu") or grad_accumulation
check_same_models_across_ranks(
ddp_model, dist.group.WORLD, params_should_be_equal=same_params, check_broadcast_buffers=broadcast_buffers
)
dist.destroy_process_group()
def run_test(backend, device, world_size, broadcast_buffers, grad_accumulation, reduce_buffer_size):
temp_file_name = tempfile.mkstemp()[1]
mp.spawn(
run_one_step,
args=(world_size, backend, device, temp_file_name, broadcast_buffers, grad_accumulation, reduce_buffer_size),
nprocs=world_size,
join=True,
)
@skip_if_no_cuda
@skip_if_single_gpu
@pytest.mark.parametrize("broadcast_buffers", [True, False])
@pytest.mark.parametrize("grad_accumulation", [True, False])
@pytest.mark.parametrize("reduce_buffer_size", [0, 2 ** 20])
def test_step_gpu(broadcast_buffers, grad_accumulation, reduce_buffer_size):
world_size = 2
run_test(
dist.Backend.NCCL, torch.device("cuda"), world_size, broadcast_buffers, grad_accumulation, reduce_buffer_size
)
@skip_if_py38
@pytest.mark.parametrize("broadcast_buffers", [True, False])
@pytest.mark.parametrize("grad_accumulation", [True, False])
@pytest.mark.parametrize("reduce_buffer_size", [0, 2 ** 20])
def test_step_cpu(broadcast_buffers, grad_accumulation, reduce_buffer_size):
world_size = 2
run_test(
dist.Backend.GLOO, torch.device("cpu"), world_size, broadcast_buffers, grad_accumulation, reduce_buffer_size
)
def run_ddp_parity(
rank, world_size, backend, temp_file_name, reduce_buffer_size, grad_accumulation, change_train_graph
rank, world_size, backend, temp_file_name, reduce_buffer_size, grad_accumulation, change_train_graph, fp16_reduction
):
dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
......@@ -188,12 +114,18 @@ def run_ddp_parity(
sharded_optimizer=sharded_optimizer,
broadcast_buffers=True,
reduce_buffer_size=reduce_buffer_size,
reduce_fp16=fp16_reduction,
)
ddp_model_single = copy.deepcopy(model)
ddp_optimizer = torch.optim.SGD(ddp_model_single.parameters(), lr=1e-4, momentum=0.99)
ddp_model = DDP(ddp_model_single, device_ids=[rank], broadcast_buffers=True, find_unused_parameters=True)
if fp16_reduction:
from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import fp16_compress_hook
ddp_model.register_comm_hook(state=None, hook=fp16_compress_hook) # type: ignore
ddp_scaler = TorchGradScaler() if amp else None
sharded_ddp_scaler = ShardedGradScaler() if amp else None
......@@ -269,12 +201,21 @@ def run_ddp_parity(
@pytest.mark.parametrize("reduce_buffer_size", [0, 2 ** 20])
@pytest.mark.parametrize("grad_accumulation", [True, False])
@pytest.mark.parametrize("change_train_graph", [True, False])
def test_ddp_parity(reduce_buffer_size, grad_accumulation, change_train_graph):
@pytest.mark.parametrize("fp16_reduction", _test_fp16_reduction)
def test_ddp_parity(reduce_buffer_size, grad_accumulation, change_train_graph, fp16_reduction):
world_size = torch.cuda.device_count()
backend = dist.Backend.NCCL
mp.spawn(
run_ddp_parity,
args=(world_size, backend, tempfile.mkstemp()[1], reduce_buffer_size, grad_accumulation, change_train_graph),
args=(
world_size,
backend,
tempfile.mkstemp()[1],
reduce_buffer_size,
grad_accumulation,
change_train_graph,
fp16_reduction,
),
nprocs=world_size,
join=True,
)
......