Unverified commit e41452e8, authored by Benjamin Lefaudeux, committed by GitHub

[OSS/ShardedDDP] making APIs more private (#582)

* making APIs more private
* linting
parent befbc73a
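Editor's note on the user-facing effect of this change (a sketch, not part of the commit): the rank-translation helper moves from a static method on `OSS` to a free function in `fairscale.optim.utils`, and the partition caches (`per_device_params`, `param_to_rank`, `param_to_index`, `local_params`) become underscore-prefixed properties. A hedged before/after for caller code, assuming fairscale is installed:

```python
import torch.distributed as dist
from fairscale.optim.utils import get_global_rank

group = dist.group.WORLD  # for the default group the helper returns the rank unchanged

# Before this commit:
#   reference = OSS.get_global_rank(group, 0)
# After this commit:
reference = get_global_rank(group, 0)
```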
@@ -22,7 +22,7 @@ import torch.distributed as dist
 from fairscale.nn.misc import GradBucket
 from fairscale.optim import OSS
-from fairscale.optim.utils import Workhandle
+from fairscale.optim.utils import Workhandle, get_global_rank


 def _trainable(param: torch.Tensor) -> bool:
@@ -122,11 +122,11 @@ class ShardedDataParallel(nn.Module):
         self.process_group = process_group if process_group is not None else dist.group.WORLD
         self.backend = dist.get_backend(self.process_group)
         self.world_size_scaling = 1.0 / dist.get_world_size(self.process_group)  # > 0
-        self.reference_global_rank = OSS.get_global_rank(self.process_group, 0)  # picking rank 0 as the reference
+        self.reference_global_rank = get_global_rank(self.process_group, 0)  # picking rank 0 as the reference
         self.rank = dist.get_rank(self.process_group)
-        self.global_rank = OSS.get_global_rank(self.process_group, self.rank)
+        self.global_rank = get_global_rank(self.process_group, self.rank)
         self._local_to_global_rank = [
-            OSS.get_global_rank(self.process_group, i) for i in range(dist.get_world_size(self.process_group))
+            get_global_rank(self.process_group, i) for i in range(dist.get_world_size(self.process_group))
         ]

         # Expose some of the PytorchDDP attributes, some frameworks rely on them.
@@ -149,7 +149,7 @@ class ShardedDataParallel(nn.Module):
         # - we build an iterator which goes through all the parameters involved globally
         self._all_params = list(
             chain(
-                *[sum([sum(p, []) for p in optim.per_device_params.values()], []) for optim in self.sharded_optimizers]
+                *[sum([sum(p, []) for p in optim._per_device_params.values()], []) for optim in self.sharded_optimizers]
             )
         )
         self._trainable_params: List[torch.Tensor] = []
@@ -288,10 +288,10 @@ class ShardedDataParallel(nn.Module):
         # Update ShardedDDP given the new partitions
         for (
             device_per_rank_params
-        ) in optim.per_device_params.values():  # all the params on this device (inc all ranks)
+        ) in optim._per_device_params.values():  # all the params on this device (inc all ranks)
             for device_params in device_per_rank_params:
                 for param in filter(lambda x: x.requires_grad, device_params):
-                    self._trainable_param_to_rank[param] = optim.param_to_rank[param]
+                    self._trainable_param_to_rank[param] = optim._param_to_rank[param]

         self._setup_bucket_strategy()
         self._setup_backward_hooks()
...
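For orientation, here is a minimal end-to-end sketch of how the two classes touched above are typically combined. It is illustrative only and not part of the diff; it assumes fairscale is installed, a process group has already been initialized, and a GPU is available.

```python
import torch
from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS

model = torch.nn.Linear(16, 4).cuda()
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.1)
model = ShardedDataParallel(model, optimizer)

loss = model(torch.randn(8, 16).cuda()).sum()
loss.backward()   # ShardedDDP reduces each grad to the rank that owns it
optimizer.step()  # OSS updates only the local shard, then broadcasts the new params
```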
@@ -17,7 +17,7 @@ from torch.optim import SGD, Optimizer
 from fairscale.nn.misc import ParamBucket

-from .utils import broadcast_object, calc_grad_norm, recursive_copy_to_device
+from .utils import broadcast_object, calc_grad_norm, get_global_rank, recursive_copy_to_device

 __all__ = ["OSS"]
@@ -89,11 +89,11 @@ class OSS(Optimizer):
         self.in_super_constructor = False

         # Partition information. lazy evaluation, computed when requested
-        self._per_device_params: Dict[torch.device, List[List[Parameter]]] = OrderedDict()  # device, rank, params
-        self._param_rank: Dict[torch.Tensor, int] = {}
+        self.__per_device_params: Dict[torch.device, List[List[Parameter]]] = OrderedDict()  # device, rank, params
+        self.__param_rank: Dict[torch.Tensor, int] = {}
         self._partition_parameters: List[List[dict]] = []
-        self._param_to_index: Dict[int, int] = {}
-        self._local_params: Optional[List[torch.Tensor]] = None
+        self.__param_to_index: Dict[int, int] = {}
+        self.__local_params: Optional[List[torch.Tensor]] = None

         # Default empty values + immutables
         self._optim_defaults = default
@@ -103,8 +103,8 @@ class OSS(Optimizer):
         self.world_size = dist.get_world_size(self.group)
         self.backend = dist.get_backend(self.group)
         self.rank = dist.get_rank(self.group)
-        self.global_rank = self.get_global_rank(self.group, self.rank)
-        self._local_to_global_rank = [self.get_global_rank(self.group, i) for i in range(self.world_size)]
+        self.global_rank = get_global_rank(self.group, self.rank)
+        self._local_to_global_rank = [get_global_rank(self.group, i) for i in range(self.world_size)]
         self.broadcast_fp16 = broadcast_fp16
         self.buckets: Dict[torch.device, Dict[int, ParamBucket]] = {}
@@ -151,69 +151,6 @@ class OSS(Optimizer):
         return self._partition_parameters

-    @property
-    def local_params(self) -> List[torch.Tensor]:
-        """ Iterable which goes through the parameters that this rank owns
-        """
-        if self._local_params is None:
-            self._local_params = list(
-                chain(
-                    *[
-                        list(filter(lambda x: x.grad is not None, device_params[self.rank]))
-                        for device_params in self.per_device_params.values()
-                    ]
-                )
-            )
-
-        # Make sure that the iterator is not consumed, only expose a copy
-        return self._local_params
-
-    @property
-    def param_to_index(self) -> Dict[int, int]:
-        """ Hash table in between parameter indices in the global optimizer scheme, and the actual params
-        """
-        if len(self._param_to_index) == 0:
-            self._param_to_index = {id(p): i for i, p in enumerate(chain(*(g["params"] for g in self.param_groups)))}
-
-        return self._param_to_index
-
-    @property
-    def per_device_params(self) -> Dict[torch.device, List[List[Parameter]]]:
-        """Sorted list of all the params, first per device then per rank.
-        Within a list params are sorted per number of elements to allow for an easy bucketing.
-        """
-        if len(self._per_device_params) == 0:
-            # Go through all params, log them per device
-            # The ordering is important here, needs to be the same on all ranks
-            # So that ulterior broadcast calls are matching
-            for param_group in self.param_groups:
-                for param in param_group["params"]:
-                    device = param.device
-                    if self._per_device_params.get(device) is None:
-                        self._per_device_params[device] = [[] for _ in range(self.world_size)]
-                    self._per_device_params[device][self.param_to_rank[param]] += [param]
-
-            # Sort param_lists by size
-            for device in self._per_device_params.keys():
-                for rank_params in self._per_device_params[device]:
-                    rank_params.sort(key=lambda x: x.numel())
-
-        return self._per_device_params
-
-    @property
-    def param_to_rank(self) -> Dict[torch.Tensor, int]:
-        """param to data parallel rank"""
-        if len(self._param_rank) == 0:
-            for rank, param_groups in enumerate(self.partition_parameters()):
-                for param_group in param_groups:
-                    for param in param_group["params"]:
-                        self._param_rank[param] = rank
-
-            logging.debug("ZeRO: Parameters dispatched to ranks %s " % list(self._param_rank.values()))
-
-        return self._param_rank
-
     # NOTE(msb) We add a kwargs in order to support Optimizer sub-classes that support extra kwargs.
     # For example, the apex library contains fused optimizers with a step that supports extra kwargs.
     def step(self, closure: Optional[Callable[[], float]] = None, **kwargs: Any) -> Optional[float]:
@@ -281,7 +218,7 @@ class OSS(Optimizer):
         # To avoid double counting, only consider parameters on rank zero + anything marked 'model_parallel'
         # 'model_parallel' flag is set in Megatron-LM:
         # https://github.com/NVIDIA/Megatron-LM/blob/19301985dd31c8b612095cbad15bd903e8ddd497/megatron/mpu/layers.py#L54
-        local_params = filter_params_fn(self.local_params) if filter_params_fn is not None else self.local_params
+        local_params = filter_params_fn(self._local_params) if filter_params_fn is not None else self._local_params

         local_norm = calc_grad_norm(local_params, norm_type).to(self._default_device)
         # Compute the norm on this grad set,
@@ -301,9 +238,9 @@ class OSS(Optimizer):
         clip_coef = torch.tensor(max_norm, dtype=total_norm.dtype, device=total_norm.device) / (total_norm + 1e-6)
         if clip_coef < 1:
-            for device, device_params in self.per_device_params.items():
+            for device, device_params in self._per_device_params.items():
                 for p in filter(lambda x: x.grad is not None, device_params[self.rank]):
-                    p.grad.detach().mul_(clip_coef.to(device))  # type: ignore
+                    p.grad.detach().mul_(clip_coef.to(device))  # type: ignore  # mypy trips on the filter

         return total_norm
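The clipping rule visible in this hunk can be read in isolation: once the global gradient norm is known, every local gradient is scaled by `max_norm / (total_norm + 1e-6)` whenever that coefficient is below one. A standalone sketch of that rule (the names here are illustrative, not fairscale's):

```python
import torch

def scale_grads_(grads, total_norm: torch.Tensor, max_norm: float) -> None:
    # Same arithmetic as the hunk above, applied in place to a flat list of grads.
    clip_coef = torch.tensor(max_norm, dtype=total_norm.dtype, device=total_norm.device) / (total_norm + 1e-6)
    if clip_coef < 1:
        for g in grads:
            g.detach().mul_(clip_coef.to(g.device))

grads = [torch.randn(10) for _ in range(3)]
total_norm = torch.norm(torch.stack([torch.norm(g) for g in grads]))  # global L2 norm
scale_grads_(grads, total_norm, max_norm=1.0)
```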
@@ -426,7 +363,7 @@ class OSS(Optimizer):
                 for local_param_index in local_pg["params"]:
                     # Update the state, if any
                     if local_param_index in s["state"].keys():
-                        global_id = self.param_to_index[local_index_to_param_id[local_param_index]]
+                        global_id = self._param_to_index[local_index_to_param_id[local_param_index]]
                         state_dict["state"][global_id] = s["state"][local_param_index]

         # Make sure that the parameters are sorted in the state, as expected for a pytorch dict
@@ -462,7 +399,7 @@ class OSS(Optimizer):
             # Populate the sharded optimizer state on the fly,
             # remove the params that this rank does not own
-            if self.param_to_rank[param] != self.rank:
+            if self._param_to_rank[param] != self.rank:
                 state_dict["state"][key] = {}
             else:
                 self.optim.state[param] = recursive_copy_to_device(value, non_blocking=True, device=param.device)
@@ -485,7 +422,7 @@ class OSS(Optimizer):
         # Create the optim which will work on the param shard
         if not hasattr(self, "optim"):
             self._clear_cache()
-            self._default_device = list(self.per_device_params.keys())[0]
+            self._default_device = list(self._per_device_params.keys())[0]
             self.optim = self._optim_constructor(self.partition_parameters()[self.rank], **self._optim_defaults)
             OSS._sync_param_groups(self.optim.param_groups, self.param_groups)
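The line building `self.optim` is the heart of the sharding: each rank instantiates the wrapped optimizer over its own partition only. A toy, non-fairscale illustration of that idea (the round-robin split and the rank value are assumptions made for the sketch; fairscale's real partitioning is size-aware):

```python
import torch

params = [torch.nn.Parameter(torch.randn(4)) for _ in range(8)]
world_size, rank = 4, 1                                           # pretend values for the sketch
partitions = [params[r::world_size] for r in range(world_size)]   # naive round-robin split
local_optim = torch.optim.SGD(partitions[rank], lr=0.1)           # this rank updates only its shard
```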
@@ -517,20 +454,73 @@ class OSS(Optimizer):
             # Update the bucketing strategy accordingly
             self._setup_flat_buffers()

+    @property
+    def _local_params(self) -> List[torch.Tensor]:
+        """ Iterable which goes through the parameters that this rank owns """
+        if self.__local_params is None:
+            self.__local_params = list(
+                chain(
+                    *[
+                        list(filter(lambda x: x.grad is not None, device_params[self.rank]))
+                        for device_params in self._per_device_params.values()
+                    ]
+                )
+            )
+
+        # Make sure that the iterator is not consumed, only expose a copy
+        return self.__local_params
+
+    @property
+    def _param_to_index(self) -> Dict[int, int]:
+        """ Hash table in between parameter indices in the global optimizer scheme, and the actual params """
+        if len(self.__param_to_index) == 0:
+            self.__param_to_index = {id(p): i for i, p in enumerate(chain(*(g["params"] for g in self.param_groups)))}
+
+        return self.__param_to_index
+
+    @property
+    def _per_device_params(self) -> Dict[torch.device, List[List[Parameter]]]:
+        """Sorted list of all the params, first per device then per rank.
+        Within a list params are sorted per number of elements to allow for an easy bucketing.
+        """
+        if len(self.__per_device_params) == 0:
+            # Go through all params, log them per device
+            # The ordering is important here, needs to be the same on all ranks
+            # So that ulterior broadcast calls are matching
+            for param_group in self.param_groups:
+                for param in param_group["params"]:
+                    device = param.device
+                    if self.__per_device_params.get(device) is None:
+                        self.__per_device_params[device] = [[] for _ in range(self.world_size)]
+                    self.__per_device_params[device][self._param_to_rank[param]] += [param]
+
+            # Sort param_lists by size
+            for device in self.__per_device_params.keys():
+                for rank_params in self.__per_device_params[device]:
+                    rank_params.sort(key=lambda x: x.numel())
+
+        return self.__per_device_params
+
+    @property
+    def _param_to_rank(self) -> Dict[torch.Tensor, int]:
+        """Map the params to the rank which owns them"""
+        if len(self.__param_rank) == 0:
+            for rank, param_groups in enumerate(self.partition_parameters()):
+                for param_group in param_groups:
+                    for param in param_group["params"]:
+                        self.__param_rank[param] = rank
+
+            logging.debug("FairScale OSS: Parameters dispatched to ranks %s " % list(self.__param_rank.values()))
+
+        return self.__param_rank
+
     def _clear_cache(self) -> None:
         self._partition_parameters.clear()
-        self._per_device_params.clear()
-        self._param_rank.clear()
-        self._param_to_index.clear()
-        self._local_params = None
+        self.__per_device_params.clear()
+        self.__param_rank.clear()
+        self.__param_to_index.clear()
+        self.__local_params = None

-    @staticmethod
-    def get_global_rank(group: Any, rank: int) -> int:
-        if group is dist.group.WORLD:
-            return rank
-        else:
-            global_rank = dist.distributed_c10d._get_global_rank(group, rank)
-        return global_rank
-
     @staticmethod
     def _sync_param_groups(source: List[Dict[Any, Any]], destination: List[Dict[Any, Any]]) -> None:
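The caches behind the new private properties use double leading underscores, so Python name-mangles them per class and outside code has to go through the single-underscore properties, which `_clear_cache` can invalidate. A minimal, generic illustration of that mangling (not fairscale code; the class and attribute names are made up):

```python
class Owner:
    def __init__(self) -> None:
        self.__cache = {"a": 1}   # stored on the instance as _Owner__cache

    @property
    def _cache(self) -> dict:
        return self.__cache


o = Owner()
assert o._cache == {"a": 1}          # the single-underscore property still works
assert o._Owner__cache == {"a": 1}   # the mangled name is what actually exists
# o.__cache would raise AttributeError outside the class body
```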
@@ -548,7 +538,7 @@ class OSS(Optimizer):
         # if NCCL broadcasts will be done in an independent stream
         # make sure that prior compute work is complete
         if torch.device("cuda").type == self._default_device.type:
-            for device in self.per_device_params.keys():
+            for device in self._per_device_params.keys():
                 torch.cuda.synchronize(device=device)

         work_handles = []  # Work handles are consumed within this scope, no callback
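For context on the synchronize call above: when NCCL broadcasts run on a separate stream, the default compute stream is drained first so the broadcast reads fully written tensors. A tiny standalone sketch of that ordering, guarded so it is a no-op without a GPU:

```python
import torch

if torch.cuda.is_available():
    device = torch.device("cuda", 0)
    x = torch.randn(1024, device=device) * 2   # work queued on the default stream
    torch.cuda.synchronize(device=device)      # block until that work has finished
```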
@@ -585,7 +575,7 @@ class OSS(Optimizer):
         `refresh_trainability` is called.
         """
-        for device, per_rank_params in self.per_device_params.items():
+        for device, per_rank_params in self._per_device_params.items():
             # Only wipe the existing buckets if there are none
             # (could be that this is called twice, when trainability changes)
             if device not in self.buckets.keys():
@@ -610,7 +600,7 @@ class OSS(Optimizer):
                 self.buckets[device][dst_rank] = bucket

         # Clear the buffer keys which are not in use anymore (could be that the devices changed)
-        devices_in_use = list(self.per_device_params.keys())
+        devices_in_use = list(self._per_device_params.keys())
         devices_to_pop = list(filter(lambda x: x not in devices_in_use, self.buckets.keys()))
         for d in devices_to_pop:
             self.buckets.pop(d)
@@ -18,6 +18,13 @@ class Workhandle:
         self.callback = callback


+def get_global_rank(group: Any, rank: int) -> int:
+    if group is dist.group.WORLD:
+        return rank
+
+    return dist.distributed_c10d._get_global_rank(group, rank)
+
+
 # Credits: classy_vision/generic/distributed_util.py
 def recursive_copy_to_device(value: Any, non_blocking: bool, device: torch.device) -> Any:
     """
...
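A usage sketch for the `get_global_rank` helper added to the utils module above (the sub-group layout and rank count are assumptions): ranks inside a `torch.distributed` sub-group are renumbered from zero, and collectives addressed by global id need the translation back.

```python
import torch.distributed as dist
from fairscale.optim.utils import get_global_rank

if dist.is_initialized() and dist.get_world_size() >= 4:
    group = dist.new_group(ranks=[2, 3])       # local rank 0 in this group is global rank 2
    if dist.get_rank() in (2, 3):
        local_rank = dist.get_rank(group)
        assert get_global_rank(group, local_rank) == dist.get_rank()
```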
@@ -681,7 +681,7 @@ def run_gradient_clipping(rank, world_size, tempfile_name):
     assert torch.allclose(oss_total_norm, total_norm), "torch and fairscale should return the same grad norm"

     # Check that the params have indeed been clipped
-    for params in sharded_optimizer.per_device_params.values():
+    for params in sharded_optimizer._per_device_params.values():
         for param in filter(lambda x: x.grad is not None, params[rank]):
             assert torch.norm(param.grad, p=norm) < CLIP_NORM, f"param grad norm above clip : {param.grad}"
...