Unverified commit 13445c55 authored by Benjamin Lefaudeux, committed by GitHub

[feature-fix-refactor][ShardedDDP] Make it possible to change trainability graph on the fly (#369)

* Better unit testing
* Make it possible to refresh the DDP assumptions when the model has changed. Keep it optional, so that the extra check can be skipped to save time
* Enable gradient accumulation tests
parent 1a636557
......@@ -5,6 +5,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [next rel] - TBD
### Fixed
- ShardedDDP and OSS handle model trainability changes during training ([#369](https://github.com/facebookresearch/fairscale/issues/369))
## [0.1.6] - 2021-02-10
### Added
......
This diff is collapsed.
......@@ -3,19 +3,19 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from collections import OrderedDict, deque
from collections import OrderedDict
import copy
from itertools import chain
import logging
from math import inf
from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, List, Optional, Type, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union
import torch
import torch.distributed as dist
from torch.nn import Parameter
from torch.optim import SGD, Optimizer
from .utils import Workhandle, broadcast_object, recursive_copy_to_device
from .utils import broadcast_object, recursive_copy_to_device
__all__ = ["OSS"]
......@@ -52,6 +52,14 @@ class OSS(Optimizer):
torch.distributed group (default: group.WORLD)
broadcast_buffer_size (int):
(deprecated) used to cap the size of the broadcast buffers, not being used anymore.
.. warning:: the communication patterns used by OSS depend on the "trainability" graph,
meaning that parameters whose `requires_grad` flag is set are handled differently from the
others. This is not reevaluated at every step, so please call `refresh_trainable()` if your
model changed (some parameters were frozen or unfrozen, for instance).
If used with :class:`fairscale.nn.ShardedDDP`, an automatic change detection is available
via the `auto_refresh_trainable` parameter.
"""
#: The optimizer used for a given shard
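To make the warning above concrete, here is a minimal usage sketch (not taken from the repository): it assumes a `torch.distributed` process group has already been initialized, since OSS needs one, and uses toy model dimensions.

```python
import torch
from fairscale.optim import OSS

# Assumes torch.distributed.init_process_group(...) has already been called.
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 2))
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3)

# ... some training happens, then the first layer gets frozen ...
for p in model[0].parameters():
    p.requires_grad = False

# The partitioning and flat buffers are now stale: refresh them explicitly.
optimizer.refresh_trainable()
```

When the model is wrapped in `ShardedDDP` with `auto_refresh_trainable` enabled, this explicit call should not be needed, per the warning above.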
......@@ -81,27 +89,22 @@ class OSS(Optimizer):
self._param_to_index: Dict[int, int] = {}
self._local_params: Optional[List[torch.Tensor]] = None
# Build the wrapped optimizer, responsible for a shard of the params
# Default empty values + immutables
self._optim_defaults = default
self._optim_constructor = optim
self.group = group if group is not None else dist.group.WORLD
self.world_size = dist.get_world_size(self.group)
self.rank = dist.get_rank(self.group)
self.global_rank = self.get_global_rank(self.group, self.rank)
self.optim = optim(self.partition_parameters()[self.rank], **default)
# - Sync local and global param_groups keys
for global_group, local_group in zip(self.param_groups, self.optim.param_groups):
for key, value in local_group.items():
if key != "params":
global_group[key] = value
self.buckets: Dict[torch.device, List[torch.Tensor]] = {}
# Optional consolidated optimizer state
self._all_states: List[Dict[str, Any]] = []
self._all_states: List[Dict[str, Any]] = [] # Optional consolidated optimizer state
self._default_device = torch.device("cpu")
# Current default device is set by the parameters allocated to this rank
self._device = list(self.per_device_params.keys())[0]
self.work_handles: Deque[Workhandle] = deque()
self.buckets: Dict[torch.device, List[torch.Tensor]] = {}
self._setup_flat_buffers()
# Setup everything which is related to the parameters to be trained
# (partition and optimizer for the shard)
self.refresh_trainable()
# Partition helpers
def partition_parameters(self) -> List[List[dict]]:
......@@ -277,12 +280,12 @@ class OSS(Optimizer):
# Compute the norm on this grad set,
# then sync all the norms from all ranks
if norm_type == inf:
total_norm = max(p.grad.detach().abs().max().to(self._device) for p in local_params)
total_norm = max(p.grad.detach().abs().max().to(self._default_device) for p in local_params)
# all reduce over data parallel and model parallel workers
dist.all_reduce(total_norm, op=torch.distributed.ReduceOp.MAX, group=dist.group.WORLD)
else:
local_norm = torch.norm(
input=torch.stack([torch.norm(input=p.grad.detach(), p=norm_type, dtype=torch.float32).to(self._device) for p in local_params]), # type: ignore
input=torch.stack([torch.norm(input=p.grad.detach(), p=norm_type, dtype=torch.float32).to(self._default_device) for p in local_params]), # type: ignore
p=norm_type,
)
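As background for the two branches above: p-norms compose across disjoint parameter shards, which is what makes a per-rank computation followed by a single reduction correct. The sketch below is a standalone illustration of that identity, not the fairscale code path; the function name and arguments are made up for the example, and it assumes a process group is initialized and the gradients live on a device the backend can reduce (e.g. CUDA for NCCL).

```python
import torch
import torch.distributed as dist

def sharded_grad_norm(local_grads, norm_type, group):
    """Global gradient norm when each rank only holds a disjoint shard of the grads."""
    if norm_type == float("inf"):
        # inf-norm: the global max is the max of the per-rank maxima
        total = max(g.detach().abs().max() for g in local_grads)
        dist.all_reduce(total, op=dist.ReduceOp.MAX, group=group)
        return total
    # p-norm: global_norm ** p == sum over ranks of local_norm ** p
    local = torch.stack([torch.norm(g.detach(), p=norm_type) for g in local_grads])
    total = torch.norm(local, p=norm_type) ** norm_type
    dist.all_reduce(total, group=group)  # the default op is SUM
    return total ** (1.0 / norm_type)
```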
......@@ -412,16 +415,25 @@ class OSS(Optimizer):
OSS._sync_param_groups(state_dict["param_groups"], self.param_groups)
OSS._sync_param_groups(self.param_groups, self.optim.param_groups)
def refresh_trainable(self) -> None:
""" Updates the partitioning and communication patterns if the trainability (`requires_grad`)
of some parameters has changed.
"""
# Create the optim which will work on the param shard
if not hasattr(self, "optim"):
self._clear_cache()
self._default_device = list(self.per_device_params.keys())[0]
self.optim = self._optim_constructor(self.partition_parameters()[self.rank], **self._optim_defaults)
OSS._sync_param_groups(self.optim.param_groups, self.param_groups)
self._setup_flat_buffers()
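The ShardedDDP side of the change (collapsed diff above) is described as detecting trainability changes automatically when `auto_refresh_trainable` is set. One possible shape for such a check, as a hedged sketch rather than the actual implementation, is to fingerprint the `requires_grad` flags and call `refresh_trainable()` whenever the fingerprint moves:

```python
import torch

def trainability_signature(module: torch.nn.Module) -> tuple:
    """Cheap fingerprint of which parameters currently require gradients."""
    return tuple(p.requires_grad for p in module.parameters())

class TrainabilityWatcher:
    """Illustrative helper: calls `on_change` whenever the fingerprint changes."""

    def __init__(self, module: torch.nn.Module, on_change) -> None:
        self.module = module
        self.on_change = on_change
        self.signature = trainability_signature(module)

    def maybe_refresh(self) -> None:
        current = trainability_signature(self.module)
        if current != self.signature:
            self.signature = current
            self.on_change()  # e.g. OSS.refresh_trainable on each sharded optimizer
```

A wrapper could run `maybe_refresh()` at the start of every forward pass; skipping the check (the opt-out path) saves that per-step walk over the parameters.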
def _broadcast_state_dict(self) -> None:
"""Broadcast this rank's state shard, discard others"""
# Default to CPU space to gain some memory headroom
local_cpu_state = recursive_copy_to_device(
self.optim.state_dict(), non_blocking=True, device=torch.device("cpu")
)
# The tensor cannot actually be empty, even though its content is meaningless
dummy_sync_tensor = torch.tensor([1], device=self._device)
dummy_sync_tensor = torch.tensor([1], device=self._default_device)
for rank in range(self.world_size):
if rank == self.rank:
......@@ -431,17 +443,20 @@ class OSS(Optimizer):
)
# legacy compatibility for old torch versions
broadcast_object(
self.local_state_dict(), src_rank=self.global_rank, group=self.group, dist_device=self._device
self.local_state_dict(),
src_rank=self.global_rank,
group=self.group,
dist_device=self._default_device,
)
else:
global_rank = self.get_global_rank(self.group, rank)
# Discard this tensor/rank; the broadcast is still necessary for syncing, and because NCCL does not support gather
broadcast_object(
torch.tensor([dummy_sync_tensor], dtype=torch.uint8, device=self._device),
torch.tensor([dummy_sync_tensor], dtype=torch.uint8, device=self._default_device),
src_rank=global_rank,
group=self.group,
dist_device=self._device,
dist_device=self._default_device,
)
def _collect_sharded_states(self) -> List[Dict[str, Any]]:
......@@ -457,19 +472,19 @@ class OSS(Optimizer):
# Sync with other replicas
broadcast_object(
torch.tensor([0], dtype=torch.uint8, device=self._device),
torch.tensor([0], dtype=torch.uint8, device=self._default_device),
src_rank=self.global_rank,
group=self.group,
dist_device=self._device,
dist_device=self._default_device,
)
else:
# Fetch the optim state from the other replicas
global_rank = self.get_global_rank(self.group, rank)
replica_state = broadcast_object(
torch.tensor([0], dtype=torch.uint8, device=self._device),
torch.tensor([0], dtype=torch.uint8, device=self._default_device),
src_rank=global_rank,
group=self.group,
dist_device=self._device,
dist_device=self._default_device,
)
all_states.append(
......@@ -546,23 +561,6 @@ class OSS(Optimizer):
if last_work_handle:
last_work_handle.wait()
def _consume_work_handles(self) -> None:
"""Consume all the futures which are tied to this optimizer's buckets.
We start from the first/older ones, since they are the most likely to be ready and non-blocking
"""
while len(self.work_handles) > 0:
work_handle = self.work_handles.popleft()
work_handle.handle.wait()
if work_handle.callback is not None:
work_handle.callback()
def _try_consume_work_handle(self) -> None:
"""Try to consume the oldest future. This is non blocking, if not ready we'll pass"""
while len(self.work_handles) > 0 and self.work_handles[0].handle.is_completed():
work_handle = self.work_handles.popleft()
if work_handle.callback is not None:
work_handle.callback()
def _setup_flat_buffers(self) -> None:
"""Make all params which are on the same device and tied to the same rank views of a single buffer.
This is used at construction time, and anytime parameter trainability is changed (frozen or unfrozen) and
......@@ -570,19 +568,35 @@ class OSS(Optimizer):
"""
for device, per_rank_params in self.per_device_params.items():
self.buckets[device] = []
# Only create the bucket list for this device if it does not exist yet
# (this method can be called again when trainability changes)
if device not in self.buckets.keys():
self.buckets[device] = []
# Make parameters a view of the bucket
for dst_rank, params in enumerate(per_rank_params):
if len(params) > 0:
# Clone the non-trainable params: if they stay views of a bucket, their data would be lost when the bucket is reassigned
for param in filter(lambda x: not x.requires_grad, params):
param.data = param.data.detach().clone()
# Merge all the trainable params in a single bucket
trainable_params = list(filter(lambda x: x.requires_grad, params))
buffer_size = sum(map(lambda x: x.numel(), trainable_params))
self.buckets[device].append(torch.empty(buffer_size, dtype=params[0].dtype, device=device))
bucket = torch.empty(buffer_size, dtype=params[0].dtype, device=device)
offset = 0
for param in trainable_params:
# This parameter becomes a view of the bucket
offset_next = offset + param.numel()
self.buckets[device][dst_rank][offset:offset_next].copy_(param.data.flatten())
param.data = self.buckets[device][dst_rank][offset:offset_next].view_as(param.data)
bucket[offset:offset_next].copy_(param.data.flatten())
param.data = bucket[offset:offset_next].view_as(param.data)
offset = offset_next
# Either replace the existing bucket, or create it
if len(self.buckets[device]) == dst_rank:
self.buckets[device].append(bucket)
else:
self.buckets[device][dst_rank] = bucket
else:
self.buckets[device].append(torch.zeros(1, device=device))
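For readers skimming the diff, the bucketing trick above boils down to packing the trainable tensors of one (device, rank) pair into a single contiguous buffer and turning each parameter into a view of it, so one collective can cover them all. A minimal single-device sketch, assuming at least one trainable parameter (the helper name is illustrative, not fairscale API):

```python
import torch

def flatten_into_bucket(params):
    """Pack trainable tensors into one contiguous buffer and alias them to it."""
    trainable = [p for p in params if p.requires_grad]
    bucket = torch.empty(
        sum(p.numel() for p in trainable), dtype=trainable[0].dtype, device=trainable[0].device
    )
    offset = 0
    for p in trainable:
        end = offset + p.numel()
        bucket[offset:end].copy_(p.data.flatten())   # preserve the current values
        p.data = bucket[offset:end].view_as(p.data)  # the param now aliases the bucket
        offset = end
    return bucket
```

After `bucket = flatten_into_bucket(model.parameters())`, broadcasting `bucket` refreshes every trainable parameter in one call, which is the point of the per-device/per-rank buckets built above.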
......@@ -3,11 +3,11 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import collections
import io
from typing import Any, Callable, Dict, Optional
import torch
from torch._six import container_abcs
import torch.distributed as dist
......@@ -38,7 +38,7 @@ def recursive_copy_to_device(value: Any, non_blocking: bool, device: torch.devic
return values if isinstance(value, list) else tuple(values)
if isinstance(value, container_abcs.Mapping):
if isinstance(value, collections.abc.Mapping):
device_val: Dict[str, Any] = {}
for key, val in value.items():
device_val[key] = recursive_copy_to_device(val, non_blocking=non_blocking, device=device)
......@@ -89,6 +89,7 @@ class Bucket:
self.max_size = buffer.numel()
# Current status for this buffer
self.fill = 0
self.params_checked_in = 0
self.max_params_checked_in = 0 # attribute present for convenience purposes
self.destination = -1
......
......@@ -406,3 +406,11 @@ def objects_are_equal(a: Any, b: Any, raise_exception: bool = False) -> bool:
return False
else:
return a == b
def check_same_model_params(model_a: torch.nn.Module, model_b: torch.nn.Module, message: str = "") -> None:
for p_a, p_b in zip(model_a.parameters(), model_b.parameters()):
assert torch.allclose(p_a, p_b, atol=1e-3), f"Model parameters differ\n{p_a} {p_b}\n" + message
for b_a, b_b in zip(model_a.buffers(), model_b.buffers()):
assert torch.allclose(b_a, b_b), f"Model buffers differ {b_a} - {b_b}\n" + message
......@@ -18,12 +18,13 @@ class DistributedDataParallel(Module[T_co]):
check_reduction: bool = ...
broadcast_bucket_size: float = ...
bucket_bytes_cap: float = ...
find_unused_parameters: bool = ...
# TODO type process_group once `distributed` module is stubbed
def __init__(self, module: Module[T_co], device_ids: Optional[_devices_t] = ...,
output_device: Optional[_device_t] = ..., dim: int = ...,
broadcast_buffers: bool = ..., process_group: Optional[Any] = ..., bucket_cap_mb: float = ...,
check_reduction: bool = ...) -> None: ...
check_reduction: bool = ..., find_unused_parameters: bool = ...) -> None: ...
def forward(self, *inputs: Any, **kwargs: Any) -> T_co: ...
......
......@@ -23,7 +23,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS
from fairscale.optim.grad_scaler import ShardedGradScaler
from fairscale.utils.testing import GPT2, skip_if_no_cuda, skip_if_py38, skip_if_single_gpu
from fairscale.utils.testing import GPT2, check_same_model_params, skip_if_no_cuda, skip_if_py38, skip_if_single_gpu
def run_one_step(rank, world_size, backend, device, temp_file_name):
......@@ -133,67 +133,66 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
torch.cuda.set_device(rank)
torch.manual_seed(rank)
np.random.seed(rank)
NUMBER_BATCHS = 5
INPUTS = 2
BATCH_SIZE = 32
def check_parity(amp: bool, accumulate: bool, change_train_graph: bool):
# The API should be exactly the same between the sharded and non-sharded variants, hence a single generic closure
def closure(model, scaler, input_tensor, should_accumulate):
accumulate_steps = 3 if should_accumulate else 1
model.zero_grad()
def step():
if scaler is not None:
with torch.cuda.amp.autocast():
loss = model(input_tensor).abs().sum()
scaler.scale(loss).backward()
else:
loss = model(input_tensor).abs().sum()
loss.backward()
with model.no_sync() if should_accumulate else suppress():
for _ in range(accumulate_steps - 1):
step()
step()
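The closure above follows the usual DDP gradient-accumulation pattern: run the first `accumulate_steps - 1` backward passes under `no_sync()` so gradients stay local, and let the last backward trigger the reduction. A distilled sketch of that pattern (names are placeholders, not test code):

```python
def accumulate_then_step(ddp_model, optimizer, micro_batches):
    """Accumulate gradients locally over several micro-batches, sync only once."""
    optimizer.zero_grad()
    with ddp_model.no_sync():  # gradients are accumulated but not all-reduced
        for batch in micro_batches[:-1]:
            ddp_model(batch).abs().sum().backward()
    # The last backward runs outside no_sync and reduces the accumulated grads.
    ddp_model(micro_batches[-1]).abs().sum().backward()
    optimizer.step()
```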
def check_parity(amp: bool):
# Any model works. Add one different buffer per rank
model = Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
model = Sequential(Linear(INPUTS, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
model.register_buffer("test_buffer", torch.ones((1)) * rank)
model.to(device)
sharded_optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
# Make sure that the model starts with a non-trainable parameter, so that we check that the buckets
# are properly reassigned when/if this changes
next(model.parameters()).requires_grad = False
sharded_optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-5, momentum=0.99)
sharded_ddp_model = ShardedDataParallel(
module=model, sharded_optimizer=sharded_optimizer, broadcast_buffers=True
)
ddp_model_single = copy.deepcopy(model)
ddp_optimizer = torch.optim.SGD(ddp_model_single.parameters(), lr=1e-3, momentum=0.99)
ddp_model = DDP(ddp_model_single, device_ids=[rank], broadcast_buffers=True)
ddp_optimizer = torch.optim.SGD(ddp_model_single.parameters(), lr=1e-5, momentum=0.99)
ddp_model = DDP(ddp_model_single, device_ids=[rank], broadcast_buffers=True, find_unused_parameters=True)
ddp_scaler = TorchGradScaler() if amp else None
sharded_ddp_scaler = ShardedGradScaler() if amp else None
def check_same_model_params():
for pg, ddp_pg in zip(sharded_optimizer.param_groups, ddp_optimizer.param_groups):
for p, ddp_p in zip(pg["params"], ddp_pg["params"]):
assert torch.allclose(
p, ddp_p, atol=1e-3
), f"Model parameters differ in between DDP and ShardedDDP {p} {ddp_p}"
for b, ddp_b in zip(sharded_ddp_model.buffers(), ddp_model.buffers()):
assert torch.allclose(
b, ddp_b, atol=1e-3
), f"Model buffers differ in between DDP and ShardedDDP. AMP {amp}"
# The model should be synchronized in between the ranks at construction time, check that
check_same_model_params()
check_same_model_params(sharded_ddp_model, ddp_model)
# The models should stay the same in between the ranks
for i in range(10):
input_tensor = torch.rand((64, 2)).to(device)
# Typical training loop, check that we get the exact same results as DDP
for i in range(NUMBER_BATCHS):
input_tensor = torch.rand((BATCH_SIZE, INPUTS)).to(device)
def closure_ddp(input_tensor=input_tensor):
ddp_optimizer.zero_grad()
if ddp_scaler is not None:
with torch.cuda.amp.autocast():
ddp_loss = ddp_model(input_tensor).abs().sum()
ddp_scaler.scale(ddp_loss).backward()
else:
ddp_loss = ddp_model(input_tensor).abs().sum()
ddp_loss.backward()
return ddp_loss
return closure(ddp_model, ddp_scaler, input_tensor, accumulate)
def closure_sharded(input_tensor=input_tensor):
sharded_optimizer.zero_grad()
if sharded_ddp_scaler is not None:
with torch.cuda.amp.autocast():
sharded_loss = sharded_ddp_model(input_tensor).abs().sum()
sharded_ddp_scaler.scale(sharded_loss).backward()
else:
sharded_loss = sharded_ddp_model(input_tensor).abs().sum()
sharded_loss.backward()
return sharded_loss
return closure(sharded_ddp_model, sharded_ddp_scaler, input_tensor, accumulate)
# Step/scale both
if ddp_scaler is not None:
......@@ -210,13 +209,28 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
else:
sharded_optimizer.step(closure=closure_sharded)
check_same_model_params()
check_same_model_params(sharded_ddp_model, ddp_model, f"Rank: {rank} - Step {i} broke")
check_parity(amp=False)
# Flip the trainability of the first parameter back and forth
if i == 0 and change_train_graph:
next(sharded_ddp_model.parameters()).requires_grad = not next(
sharded_ddp_model.parameters()
).requires_grad
next(ddp_model.parameters()).requires_grad = not next(ddp_model.parameters()).requires_grad
check_same_model_params(sharded_ddp_model, ddp_model, f"Rank: {rank} - Trainability refresh {i} broke")
# Catch a version of pytorch which would not support AMP
# Test all combinations: AMP, Accumulate, Change train graph
amp_tests = [False]
if hasattr(torch.cuda.amp, "autocast"):
check_parity(amp=True)
amp_tests.append(True)
for accumulate in [False, True]:
for change_train_graph in [False, True]:
for amp in amp_tests:
print(
f"Checking configuration: accumulate {accumulate} - change train graph {change_train_graph} - amp {amp}"
)
check_parity(amp=amp, accumulate=accumulate, change_train_graph=change_train_graph)
dist.destroy_process_group()
......@@ -417,6 +431,8 @@ def run_test_ddp_sync_batch_norm(rank, world_size, backend, device, temp_file_na
model = Sequential(Linear(2, 3), torch.nn.BatchNorm1d(3), Linear(3, 3)).to(device)
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
model.to(device) # in pytorch 1.5 syncBN switches to the default device/cpu
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
ddp_model = ShardedDataParallel(model, optimizer)
......
......@@ -11,7 +11,7 @@
import copy
from math import inf
import tempfile
from typing import Any, Type, cast
from typing import Any, Dict, Type, cast
import unittest
import numpy as np
......@@ -22,7 +22,7 @@ import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
import fairscale.optim as optim
from fairscale.utils.testing import skip_if_no_cuda, skip_if_py39_no_cuda, skip_if_single_gpu
from fairscale.utils.testing import check_same_model_params, skip_if_no_cuda, skip_if_py39_no_cuda, skip_if_single_gpu
BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO # type: ignore
DEVICE = "cuda" if torch.cuda.is_available() else torch.device("cpu")
......@@ -688,10 +688,6 @@ def run_state_dict_distributed(rank, world_size, tempfile_name):
model.zero_grad()
outputs = head(model(inputs))
def check_equal_models(message: str):
for param1, param2 in zip(model_oss1.parameters(), model_oss2.parameters()):
assert torch.allclose(param1, param2), message
# pull the current state, broadcast it to all ranks
sharded_optimizer2.consolidate_state_dict(recipient_rank=RECIPIENT_RANK) # all ranks
state_dict2 = sharded_optimizer2.state_dict() if rank == RECIPIENT_RANK else {}
......@@ -701,12 +697,16 @@ def run_state_dict_distributed(rank, world_size, tempfile_name):
sharded_optimizer2 = optim.OSS(model_oss2.parameters(), lr=1e6, momentum=0.0001)
sharded_optimizer2.add_param_group({"params": head_oss2.parameters()})
sharded_optimizer2.load_state_dict(state_dict2)
check_equal_models("parameters of the two identical models have diverged (before any steps)")
check_same_model_params(
model_oss1, model_oss2, "parameters of the two identical models have diverged (before any steps)"
)
# now take a step and check that parameters are equal
run_grad_step(model_oss1, head_oss1, sharded_optimizer1)
run_grad_step(model_oss2, head_oss2, sharded_optimizer2)
check_equal_models("parameters of the two identical models have diverged (after stepping)")
check_same_model_params(
model_oss1, model_oss2, "parameters of the two identical models have diverged (after stepping)"
)
# save the state dict for one model only, then distribute to the other ranks
sharded_optimizer2.consolidate_state_dict(recipient_rank=RECIPIENT_RANK) # all ranks
......@@ -722,7 +722,9 @@ def run_state_dict_distributed(rank, world_size, tempfile_name):
# take a step
run_grad_step(model_oss1, head_oss1, sharded_optimizer1)
run_grad_step(model_oss2, head_oss2, sharded_optimizer2)
check_equal_models("parameters of the two identical models have diverged (after consolidating)")
check_same_model_params(
model_oss1, model_oss2, "parameters of the two identical models have diverged (after consolidating)"
)
# save again for one rank, then distribute to the others
sharded_optimizer2.consolidate_state_dict(recipient_rank=RECIPIENT_RANK) # all ranks
......@@ -737,7 +739,9 @@ def run_state_dict_distributed(rank, world_size, tempfile_name):
# take a step
run_grad_step(model_oss1, head_oss1, sharded_optimizer1)
run_grad_step(model_oss2, head_oss2, sharded_optimizer2)
check_equal_models("parameters of the two identical models have diverged (after reloading)")
check_same_model_params(
model_oss1, model_oss2, "parameters of the two identical models have diverged (after reloading)"
)
dist.destroy_process_group()
......@@ -768,7 +772,7 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
out_channels = 3
batch = 64
def check_optimizer_equivalence(optimizer: Type[torch.optim.Optimizer]):
def check_optimizer_equivalence(optimizer: Type[torch.optim.Optimizer], change_train_graph: bool = False):
# Any model works. Add one different buffer per rank
trunk = torch.nn.Sequential(torch.nn.Linear(in_channels, hidden), torch.nn.Linear(hidden, hidden))
trunk.register_buffer("test_buffer", torch.ones((1)) * rank)
......@@ -777,14 +781,14 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
head = torch.nn.Linear(hidden, out_channels).to(device)
# Define a model to be trained by OSS
oss_model = torch.nn.Sequential(trunk, head)
oss_module = torch.nn.Sequential(trunk, head)
oss_trainable_params = [
{"params": trunk.parameters(), "lr": 1e-5},
{"params": head.parameters(), "lr": 1e-4},
]
optimizer_settings = {}
if isinstance(optim, torch.optim.SGD):
optimizer_settings: Dict[Any, Any] = {}
if isinstance(optimizer, torch.optim.SGD):
optimizer_settings["momentum"] = 0.9
sharded_optimizer = optim.OSS(
......@@ -795,7 +799,7 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
**optimizer_settings,
)
oss_ddp_model = DDP(module=oss_model, device_ids=[rank], broadcast_buffers=True)
oss_ddp_model = DDP(module=oss_module, device_ids=[rank], broadcast_buffers=True, find_unused_parameters=True)
# Define a model to be trained by normal pytorch + DDP
ddp_trunk = copy.deepcopy(trunk)
......@@ -807,19 +811,7 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
{"params": ddp_head.parameters(), "lr": 1e-4},
]
ddp_optimizer = optimizer(ddp_trainable_params, **optimizer_settings) # type: ignore
ddp_model = DDP(module=ddp_module, device_ids=[rank], broadcast_buffers=True)
def check_same_model_params():
for pg, ddp_pg in zip(sharded_optimizer.param_groups, ddp_optimizer.param_groups):
for p, ddp_p in zip(pg["params"], ddp_pg["params"]):
assert torch.allclose(
p, ddp_p, atol=1e-3
), f"Model parameters differ in between Pytorch optim and OSS \n{p} {ddp_p}\nworld size {world_size}"
for b, ddp_b in zip(oss_ddp_model.buffers(), ddp_model.buffers()):
assert torch.allclose(
b, ddp_b
), f"Model buffers differ in between Pytorch optim and OSS\nworld size {world_size}"
ddp_model = DDP(module=ddp_module, device_ids=[rank], broadcast_buffers=True, find_unused_parameters=True)
def check_step():
input_tensor = torch.rand((batch, in_channels)).to(device)
......@@ -843,13 +835,21 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
loss_ddp, loss_sharded_optim
), f"Losses differ in between Pytorch optim and OSS\nworld size {world_size}"
check_same_model_params(oss_ddp_model, ddp_model)
# The model should be synchronized in between the ranks at construction time, check that
check_same_model_params()
check_same_model_params(oss_ddp_model, ddp_model)
# The models should stay the same in between ddp and sharded optimizer
for i in range(5):
check_step()
check_same_model_params()
# Check that altering the trainable parameters does not cause DDP and OSS to diverge
if change_train_graph:
# Flip the first parameter from trainable to non-trainable and vice-versa
next(ddp_module.parameters()).requires_grad = not next(ddp_module.parameters()).requires_grad
next(oss_module.parameters()).requires_grad = not next(oss_module.parameters()).requires_grad
# sharded_optimizer.refresh_trainable()
# Check that the checkpoints are compatible
# - get states
......@@ -864,10 +864,10 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
# - run one step and check that the models are still the same
check_step()
check_same_model_params()
for opt in [torch.optim.SGD, torch.optim.Adam]:
check_optimizer_equivalence(opt)
for opt in [torch.optim.Adam, torch.optim.SGD]:
check_optimizer_equivalence(opt, change_train_graph=False)
check_optimizer_equivalence(opt, change_train_graph=True)
dist.destroy_process_group()
......