Unverified commit e41452e8, authored by Benjamin Lefaudeux and committed by GitHub

[OSS/ShardedDDP] making APIs more private (#582)

* making APIs more private
* linting
parent befbc73a
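
For downstream code, this commit is essentially a rename: `OSS.local_params`, `OSS.param_to_index`, `OSS.per_device_params` and `OSS.param_to_rank` become underscore-prefixed internal properties, and `get_global_rank` moves from a static method on `OSS` to a module-level helper in `fairscale.optim.utils`. A minimal caller-side sketch of the new spelling (the helper functions below are hypothetical, an initialized `torch.distributed` process group is assumed, and the underscore attributes are internal so they may change without notice):

# Hypothetical caller-side sketch only; not part of this commit's diff.
import torch.distributed as dist
from fairscale.optim import OSS
from fairscale.optim.utils import get_global_rank  # previously OSS.get_global_rank

def describe_sharding(optim: OSS) -> None:
    # Per-device / per-rank parameter layout, previously optim.per_device_params.
    for device, per_rank in optim._per_device_params.items():
        print(device, [len(params) for params in per_rank])
    # Ownership map, previously optim.param_to_rank.
    owned = sum(1 for rank in optim._param_to_rank.values() if rank == optim.rank)
    print(f"rank {optim.rank} owns {owned} tensors")

def my_global_rank(group=dist.group.WORLD) -> int:
    # get_global_rank is now a free function in fairscale.optim.utils.
    return get_global_rank(group, dist.get_rank(group))
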
@@ -22,7 +22,7 @@ import torch.distributed as dist
 from fairscale.nn.misc import GradBucket
 from fairscale.optim import OSS
-from fairscale.optim.utils import Workhandle
+from fairscale.optim.utils import Workhandle, get_global_rank
 def _trainable(param: torch.Tensor) -> bool:
@@ -122,11 +122,11 @@ class ShardedDataParallel(nn.Module):
         self.process_group = process_group if process_group is not None else dist.group.WORLD
         self.backend = dist.get_backend(self.process_group)
         self.world_size_scaling = 1.0 / dist.get_world_size(self.process_group)  # > 0
-        self.reference_global_rank = OSS.get_global_rank(self.process_group, 0)  # picking rank 0 as the reference
+        self.reference_global_rank = get_global_rank(self.process_group, 0)  # picking rank 0 as the reference
         self.rank = dist.get_rank(self.process_group)
-        self.global_rank = OSS.get_global_rank(self.process_group, self.rank)
+        self.global_rank = get_global_rank(self.process_group, self.rank)
         self._local_to_global_rank = [
-            OSS.get_global_rank(self.process_group, i) for i in range(dist.get_world_size(self.process_group))
+            get_global_rank(self.process_group, i) for i in range(dist.get_world_size(self.process_group))
         ]
         # Expose some of the PytorchDDP attributes, some frameworks rely on them.
@@ -149,7 +149,7 @@ class ShardedDataParallel(nn.Module):
         # - we build an iterator which goes through all the parameters involved globally
         self._all_params = list(
             chain(
-                *[sum([sum(p, []) for p in optim.per_device_params.values()], []) for optim in self.sharded_optimizers]
+                *[sum([sum(p, []) for p in optim._per_device_params.values()], []) for optim in self.sharded_optimizers]
             )
         )
         self._trainable_params: List[torch.Tensor] = []
@@ -288,10 +288,10 @@ class ShardedDataParallel(nn.Module):
             # Update ShardedDDP given the new partitions
             for (
                 device_per_rank_params
-            ) in optim.per_device_params.values():  # all the params on this device (inc all ranks)
+            ) in optim._per_device_params.values():  # all the params on this device (inc all ranks)
                 for device_params in device_per_rank_params:
                     for param in filter(lambda x: x.requires_grad, device_params):
-                        self._trainable_param_to_rank[param] = optim.param_to_rank[param]
+                        self._trainable_param_to_rank[param] = optim._param_to_rank[param]
         self._setup_bucket_strategy()
         self._setup_backward_hooks()
@@ -17,7 +17,7 @@ from torch.optim import SGD, Optimizer
 from fairscale.nn.misc import ParamBucket
-from .utils import broadcast_object, calc_grad_norm, recursive_copy_to_device
+from .utils import broadcast_object, calc_grad_norm, get_global_rank, recursive_copy_to_device
 __all__ = ["OSS"]
@@ -89,11 +89,11 @@ class OSS(Optimizer):
         self.in_super_constructor = False
         # Partition information. lazy evaluation, computed when requested
-        self._per_device_params: Dict[torch.device, List[List[Parameter]]] = OrderedDict()  # device, rank, params
-        self._param_rank: Dict[torch.Tensor, int] = {}
+        self.__per_device_params: Dict[torch.device, List[List[Parameter]]] = OrderedDict()  # device, rank, params
+        self.__param_rank: Dict[torch.Tensor, int] = {}
         self._partition_parameters: List[List[dict]] = []
-        self._param_to_index: Dict[int, int] = {}
-        self._local_params: Optional[List[torch.Tensor]] = None
+        self.__param_to_index: Dict[int, int] = {}
+        self.__local_params: Optional[List[torch.Tensor]] = None
         # Default empty values + immutables
         self._optim_defaults = default
@@ -103,8 +103,8 @@
         self.world_size = dist.get_world_size(self.group)
         self.backend = dist.get_backend(self.group)
         self.rank = dist.get_rank(self.group)
-        self.global_rank = self.get_global_rank(self.group, self.rank)
-        self._local_to_global_rank = [self.get_global_rank(self.group, i) for i in range(self.world_size)]
+        self.global_rank = get_global_rank(self.group, self.rank)
+        self._local_to_global_rank = [get_global_rank(self.group, i) for i in range(self.world_size)]
         self.broadcast_fp16 = broadcast_fp16
         self.buckets: Dict[torch.device, Dict[int, ParamBucket]] = {}
@@ -151,69 +151,6 @@
         return self._partition_parameters
-    @property
-    def local_params(self) -> List[torch.Tensor]:
-        """ Iterable which goes through the parameters that this rank owns
-        """
-        if self._local_params is None:
-            self._local_params = list(
-                chain(
-                    *[
-                        list(filter(lambda x: x.grad is not None, device_params[self.rank]))
-                        for device_params in self.per_device_params.values()
-                    ]
-                )
-            )
-        # Make sure that the iterator is not consumed, only expose a copy
-        return self._local_params
-    @property
-    def param_to_index(self) -> Dict[int, int]:
-        """ Hash table in between parameter indices in the global optimizer scheme, and the actual params
-        """
-        if len(self._param_to_index) == 0:
-            self._param_to_index = {id(p): i for i, p in enumerate(chain(*(g["params"] for g in self.param_groups)))}
-        return self._param_to_index
-    @property
-    def per_device_params(self) -> Dict[torch.device, List[List[Parameter]]]:
-        """Sorted list of all the params, first per device then per rank.
-        Within a list params are sorted per number of elements to allow for an easy bucketing.
-        """
-        if len(self._per_device_params) == 0:
-            # Go through all params, log them per device
-            # The ordering is important here, needs to be the same on all ranks
-            # So that ulterior broadcast calls are matching
-            for param_group in self.param_groups:
-                for param in param_group["params"]:
-                    device = param.device
-                    if self._per_device_params.get(device) is None:
-                        self._per_device_params[device] = [[] for _ in range(self.world_size)]
-                    self._per_device_params[device][self.param_to_rank[param]] += [param]
-            # Sort param_lists by size
-            for device in self._per_device_params.keys():
-                for rank_params in self._per_device_params[device]:
-                    rank_params.sort(key=lambda x: x.numel())
-        return self._per_device_params
-    @property
-    def param_to_rank(self) -> Dict[torch.Tensor, int]:
-        """param to data parallel rank"""
-        if len(self._param_rank) == 0:
-            for rank, param_groups in enumerate(self.partition_parameters()):
-                for param_group in param_groups:
-                    for param in param_group["params"]:
-                        self._param_rank[param] = rank
-            logging.debug("ZeRO: Parameters dispatched to ranks %s " % list(self._param_rank.values()))
-        return self._param_rank
     # NOTE(msb) We add a kwargs in order to support Optimizer sub-classes that support extra kwargs.
     # For example, the apex library contains fused optimizers with a step that supports extra kwargs.
     def step(self, closure: Optional[Callable[[], float]] = None, **kwargs: Any) -> Optional[float]:
@@ -281,7 +218,7 @@
         # To avoid double counting, only consider parameters on rank zero + anything marked 'model_parallel'
         # 'model_parallel' flag is set in Megatron-LM:
         # https://github.com/NVIDIA/Megatron-LM/blob/19301985dd31c8b612095cbad15bd903e8ddd497/megatron/mpu/layers.py#L54
-        local_params = filter_params_fn(self.local_params) if filter_params_fn is not None else self.local_params
+        local_params = filter_params_fn(self._local_params) if filter_params_fn is not None else self._local_params
         local_norm = calc_grad_norm(local_params, norm_type).to(self._default_device)
         # Compute the norm on this grad set,
@@ -301,9 +238,9 @@
         clip_coef = torch.tensor(max_norm, dtype=total_norm.dtype, device=total_norm.device) / (total_norm + 1e-6)
         if clip_coef < 1:
-            for device, device_params in self.per_device_params.items():
+            for device, device_params in self._per_device_params.items():
                 for p in filter(lambda x: x.grad is not None, device_params[self.rank]):
-                    p.grad.detach().mul_(clip_coef.to(device))  # type: ignore
+                    p.grad.detach().mul_(clip_coef.to(device))  # type: ignore # mypy trips on the filter
         return total_norm
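
The clipping rule shown above is unchanged by this commit and mirrors `torch.nn.utils.clip_grad_norm_`: gradients are rescaled by `max_norm / (total_norm + 1e-6)` only when that coefficient is below one. A small illustrative restatement (the function name below is made up for this note):

# Illustrative only; restates the clipping coefficient used in OSS.clip_grad_norm.
import torch

def clip_coefficient(total_norm: torch.Tensor, max_norm: float) -> torch.Tensor:
    coef = torch.tensor(max_norm, dtype=total_norm.dtype, device=total_norm.device) / (total_norm + 1e-6)
    # A coefficient >= 1 means the gradients already fit under max_norm: no rescaling.
    return torch.clamp(coef, max=1.0)
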
@@ -426,7 +363,7 @@
                 for local_param_index in local_pg["params"]:
                     # Update the state, if any
                     if local_param_index in s["state"].keys():
-                        global_id = self.param_to_index[local_index_to_param_id[local_param_index]]
+                        global_id = self._param_to_index[local_index_to_param_id[local_param_index]]
                         state_dict["state"][global_id] = s["state"][local_param_index]
         # Make sure that the parameters are sorted in the state, as expected for a pytorch dict
@@ -462,7 +399,7 @@
             # Populate the sharded optimizer state on the fly,
             # remove the params that this rank does not own
-            if self.param_to_rank[param] != self.rank:
+            if self._param_to_rank[param] != self.rank:
                 state_dict["state"][key] = {}
             else:
                 self.optim.state[param] = recursive_copy_to_device(value, non_blocking=True, device=param.device)
@@ -485,7 +422,7 @@
         # Create the optim which will work on the param shard
         if not hasattr(self, "optim"):
            self._clear_cache()
-            self._default_device = list(self.per_device_params.keys())[0]
+            self._default_device = list(self._per_device_params.keys())[0]
            self.optim = self._optim_constructor(self.partition_parameters()[self.rank], **self._optim_defaults)
            OSS._sync_param_groups(self.optim.param_groups, self.param_groups)
@@ -517,20 +454,73 @@
         # Update the bucketing strategy accordingly
         self._setup_flat_buffers()
+    @property
+    def _local_params(self) -> List[torch.Tensor]:
+        """ Iterable which goes through the parameters that this rank owns """
+        if self.__local_params is None:
+            self.__local_params = list(
+                chain(
+                    *[
+                        list(filter(lambda x: x.grad is not None, device_params[self.rank]))
+                        for device_params in self._per_device_params.values()
+                    ]
+                )
+            )
+        # Make sure that the iterator is not consumed, only expose a copy
+        return self.__local_params
+    @property
+    def _param_to_index(self) -> Dict[int, int]:
+        """ Hash table in between parameter indices in the global optimizer scheme, and the actual params """
+        if len(self.__param_to_index) == 0:
+            self.__param_to_index = {id(p): i for i, p in enumerate(chain(*(g["params"] for g in self.param_groups)))}
+        return self.__param_to_index
+    @property
+    def _per_device_params(self) -> Dict[torch.device, List[List[Parameter]]]:
+        """Sorted list of all the params, first per device then per rank.
+        Within a list params are sorted per number of elements to allow for an easy bucketing.
+        """
+        if len(self.__per_device_params) == 0:
+            # Go through all params, log them per device
+            # The ordering is important here, needs to be the same on all ranks
+            # So that ulterior broadcast calls are matching
+            for param_group in self.param_groups:
+                for param in param_group["params"]:
+                    device = param.device
+                    if self.__per_device_params.get(device) is None:
+                        self.__per_device_params[device] = [[] for _ in range(self.world_size)]
+                    self.__per_device_params[device][self._param_to_rank[param]] += [param]
+            # Sort param_lists by size
+            for device in self.__per_device_params.keys():
+                for rank_params in self.__per_device_params[device]:
+                    rank_params.sort(key=lambda x: x.numel())
+        return self.__per_device_params
+    @property
+    def _param_to_rank(self) -> Dict[torch.Tensor, int]:
+        """Map the params to the rank which owns them"""
+        if len(self.__param_rank) == 0:
+            for rank, param_groups in enumerate(self.partition_parameters()):
+                for param_group in param_groups:
+                    for param in param_group["params"]:
+                        self.__param_rank[param] = rank
+            logging.debug("FairScale OSS: Parameters dispatched to ranks %s " % list(self.__param_rank.values()))
+        return self.__param_rank
     def _clear_cache(self) -> None:
         self._partition_parameters.clear()
-        self._per_device_params.clear()
-        self._param_rank.clear()
-        self._param_to_index.clear()
-        self._local_params = None
-    @staticmethod
-    def get_global_rank(group: Any, rank: int) -> int:
-        if group is dist.group.WORLD:
-            return rank
-        else:
-            global_rank = dist.distributed_c10d._get_global_rank(group, rank)
-        return global_rank
+        self.__per_device_params.clear()
+        self.__param_rank.clear()
+        self.__param_to_index.clear()
+        self.__local_params = None
     @staticmethod
     def _sync_param_groups(source: List[Dict[Any, Any]], destination: List[Dict[Any, Any]]) -> None:
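
The `_per_device_params` docstring above notes that each per-rank list is kept sorted by `numel()` to ease bucketing. A hedged sketch of what that ordering buys (the helper below is illustrative, not fairscale's actual `ParamBucket` logic):

# Illustrative greedy packing over size-sorted tensors; not fairscale internals.
from typing import List
import torch

def greedy_bucket(sorted_params: List[torch.Tensor], budget: int) -> List[torch.Tensor]:
    """Take the prefix of size-sorted params that fits within `budget` elements."""
    bucket: List[torch.Tensor] = []
    used = 0
    for p in sorted_params:  # assumed sorted by numel(), smallest first
        if used + p.numel() > budget:
            break
        bucket.append(p)
        used += p.numel()
    return bucket
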
@@ -548,7 +538,7 @@ class OSS(Optimizer):
         # if NCCL broadcasts will be done in an independent stream
         # make sure that prior compute work is complete
         if torch.device("cuda").type == self._default_device.type:
-            for device in self.per_device_params.keys():
+            for device in self._per_device_params.keys():
                 torch.cuda.synchronize(device=device)
         work_handles = []  # Work handles are consumed within this scope, no callback
@@ -585,7 +575,7 @@
         `refresh_trainability` is called.
         """
-        for device, per_rank_params in self.per_device_params.items():
+        for device, per_rank_params in self._per_device_params.items():
             # Only wipe the existing buckets if there are none
             # (could be that this is called twice, when trainability changes)
             if device not in self.buckets.keys():
@@ -610,7 +600,7 @@
                 self.buckets[device][dst_rank] = bucket
         # Clear the buffer keys which are not in use anymore (could be that the devices changed)
-        devices_in_use = list(self.per_device_params.keys())
+        devices_in_use = list(self._per_device_params.keys())
         devices_to_pop = list(filter(lambda x: x not in devices_in_use, self.buckets.keys()))
         for d in devices_to_pop:
             self.buckets.pop(d)
@@ -18,6 +18,13 @@ class Workhandle:
         self.callback = callback
+def get_global_rank(group: Any, rank: int) -> int:
+    if group is dist.group.WORLD:
+        return rank
+    return dist.distributed_c10d._get_global_rank(group, rank)
 # Credits: classy_vision/generic/distributed_util.py
 def recursive_copy_to_device(value: Any, non_blocking: bool, device: torch.device) -> Any:
     """
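
The relocated `get_global_rank` helper maps a rank that is local to a process group onto its global rank, which is what collectives such as `dist.broadcast` expect for their `src` argument; for the default WORLD group it is the identity. A usage sketch (assumes `torch.distributed` is initialized; the subgroup and helper name are illustrative):

# Illustrative usage of the relocated helper; the subgroup below is an example.
import torch.distributed as dist
from fairscale.optim.utils import get_global_rank

def broadcast_src(group) -> int:
    # Rank 0 *within* `group` can correspond to a different global rank.
    return get_global_rank(group, 0)

# e.g. broadcast within a subgroup made of the even global ranks:
# even_group = dist.new_group(ranks=[0, 2, 4, 6])
# dist.broadcast(tensor, src=broadcast_src(even_group), group=even_group)
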
@@ -681,7 +681,7 @@ def run_gradient_clipping(rank, world_size, tempfile_name):
         assert torch.allclose(oss_total_norm, total_norm), "torch and fairscale should return the same grad norm"
         # Check that the params have indeed been clipped
-        for params in sharded_optimizer.per_device_params.values():
+        for params in sharded_optimizer._per_device_params.values():
             for param in filter(lambda x: x.grad is not None, params[rank]):
                 assert torch.norm(param.grad, p=norm) < CLIP_NORM, f"param grad norm above clip : {param.grad}"