Unverified Commit 379c6bf0 authored by Joshua Meier, committed by GitHub

Support optimizer state sharding for Megatron (#121)

parent 1c2a6f6b
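
The crux of the change: torch.distributed collectives such as broadcast() take a global rank as src, while OSS iterates over ranks that are local to its process group. When that group is a Megatron-style data-parallel subgroup rather than dist.group.WORLD, the two no longer coincide, so the group-local rank has to be translated first. Below is a minimal sketch of that translation, assuming an already-initialized 8-process job split into data-parallel subgroups; the group layout and function name are illustrative, not taken from Megatron itself.

# Sketch only: shows why group-local ranks must be mapped to global ranks
# before being used as the `src` of a collective. Assumes torch.distributed
# is already initialized (e.g. via torch.distributed.launch) with world_size=8.
import torch
import torch.distributed as dist

def build_data_parallel_group(model_parallel_size: int = 2):
    """Split the world into data-parallel subgroups (illustrative layout)."""
    world_size = dist.get_world_size()
    my_group = None
    for i in range(model_parallel_size):
        ranks = list(range(i, world_size, model_parallel_size))
        # Every process must call new_group() for every group, even the ones
        # it does not belong to.
        group = dist.new_group(ranks)
        if dist.get_rank() in ranks:
            my_group = group
    return my_group

group = build_data_parallel_group()
rank_in_group = dist.get_rank(group)  # 0..3, local to the data-parallel group

# broadcast() expects the sender's *global* rank, which is what the new
# OSS.get_global_rank() below computes via _get_global_rank().
src = dist.distributed_c10d._get_global_rank(group, rank_in_group)
payload = torch.zeros(1)
dist.broadcast(tensor=payload, src=src, group=group)
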
@@ -72,6 +72,8 @@ class OSS(Optimizer):
         self.world_size = dist.get_world_size(self.group)
         self.rank = dist.get_rank(self.group)
+        self.global_rank = self.get_global_rank(self.group, self.rank)
         self.optim = optim(self.partition_parameters()[self.rank], **default)
         # Optional consolidated optimizer state
@@ -88,7 +90,7 @@ class OSS(Optimizer):
     # Partition helpers
     def partition_parameters(self) -> List[List[dict]]:
-        """Partitions parameters across distributed ranks.
+        """Partitions parameters across distributed data parallel ranks.
         Returns a list of param_groups (which is a list of dict) where each
         element of the list contains the param_groups for a rank. Element 0
@@ -135,6 +137,7 @@ class OSS(Optimizer):
     @property
     def param_to_rank(self) -> Dict[torch.Tensor, int]:
+        '''param to data parallel rank'''
         if len(self._param_rank) == 0:
             for rank, param_groups in enumerate(self.partition_parameters()):
                 for param_group in param_groups:
@@ -142,6 +145,13 @@ class OSS(Optimizer):
                         self._param_rank[param] = rank
         return self._param_rank
+    def get_global_rank(self, group, rank):
+        if group is dist.group.WORLD:
+            return rank
+        else:
+            global_rank = dist.distributed_c10d._get_global_rank(group, rank)
+        return global_rank
     # NOTE(msb) We add a kwargs in order to support Optimizer sub-classes that support extra kwargs.
     # For example, the apex library contains fused optimizers with a step that supports extra kwargs.
     def step(self, closure: Optional[Callable[[], float]] = None, **kwargs: Any) -> Optional[float]:
@@ -174,14 +184,15 @@ class OSS(Optimizer):
                     # Gloo will rightly assert on this operation for any tensor that requires grad.
                     # We save and restore the grad requirement state to work around that, in our case
                     # the grad is only useful on the source rank.
+                    global_rank = self.get_global_rank(self.group, rank)
                     requires_grad.append((param, param.requires_grad))
                     param.requires_grad = False
-                    requests.append(dist.broadcast(tensor=param, src=rank, group=self.group, async_op=True))
+                    requests.append(dist.broadcast(tensor=param, src=global_rank, group=self.group, async_op=True))
         for fut, req_grad in zip(requests, requires_grad):
             fut.wait()
             req_grad[0].requires_grad = req_grad[1]
         return loss
     def local_state_dict(self) -> dict:
@@ -330,7 +341,7 @@ class OSS(Optimizer):
         empty_buffer = torch.tensor([0], dtype=torch.uint8, device=self._device)
         all_states: List[Dict[str, Any]] = []
-        for rank in range(dist.get_world_size(group=self.group)):
+        for rank in range(self.world_size):
             if rank == self.rank:
                 logging.debug("Saving self state")
                 all_states.append(
@@ -338,12 +349,13 @@ class OSS(Optimizer):
                 )
                 # Sync with other replicas
-                broadcast_object(empty_buffer, src_rank=rank, group=self.group, dist_device=self._device)
+                broadcast_object(empty_buffer, src_rank=self.global_rank, group=self.group, dist_device=self._device)
             else:
                 # Fetch the optim state from the other replicas
-                logging.debug("Receiving state from rank %s ", rank)
+                global_rank = self.get_global_rank(self.group, rank)
+                logging.debug("Receiving state from rank %s ", global_rank)
                 replica_state = broadcast_object(
-                    empty_buffer, src_rank=rank, group=self.group, dist_device=self._device
+                    empty_buffer, src_rank=global_rank, group=self.group, dist_device=self._device
                 )
                 all_states.append(
@@ -358,17 +370,18 @@ class OSS(Optimizer):
         """Broadcast this rank's state shard, discard others"""
         empty_buffer = torch.tensor([0], dtype=torch.uint8, device=self._device)
-        for rank in range(dist.get_world_size(group=self.group)):
+        for rank in range(self.world_size):
             if rank == self.rank:
                 # Send the state to the reference replica
                 logging.debug(
                     "Sending the sharded optimizer state to the reference replica from rank %s", rank,
                 )
-                broadcast_object(self.local_state_dict(), src_rank=rank, group=self.group, dist_device=self._device)
+                broadcast_object(self.local_state_dict(), src_rank=self.global_rank, group=self.group, dist_device=self._device)
             else:
+                global_rank = self.get_global_rank(self.group, rank)
                 # Discard this tensor/rank, broadcast necessary for syncing
-                logging.debug("Discarding broadcast from rank %s", rank)
-                broadcast_object(empty_buffer, src_rank=rank, group=self.group, dist_device=self._device)
+                logging.debug("Discarding broadcast from rank %s", global_rank)
+                broadcast_object(empty_buffer, src_rank=global_rank, group=self.group, dist_device=self._device)
     def _free_other_grads(self) -> None:
         """Free all the gradients only useful for the other ranks
@@ -380,3 +393,4 @@ class OSS(Optimizer):
             for p in partition:
                 for t in p["params"]:
                     t.grad = None
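
For context, a hedged usage sketch of how this sharded optimizer could be driven once it accepts a non-global group: the optimizer state is partitioned across the data-parallel replicas only, and the group argument keeps all broadcasts inside that subgroup. Apart from the constructor arguments visible in the diff (optim, group, the **default hyper-parameters passed through), everything here is assumed, including the import path and the consolidate_state_dict()/state_dict() calls used to regather the full state.

# Hedged usage sketch, not part of the commit. `build_data_parallel_group`
# comes from the sketch above; the import path, learning rate and the
# consolidate_state_dict()/state_dict() surface are assumptions.
import torch
from fairscale.optim.oss import OSS  # assumed location of this OSS class

group = build_data_parallel_group()

model = torch.nn.Linear(16, 16)
optimizer = OSS(model.parameters(), optim=torch.optim.SGD, group=group, lr=0.1)

data = torch.randn(8, 16)
target = torch.randn(8, 16)

for _ in range(3):
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(data), target)
    loss.backward()
    # In a real setup the gradients would be all-reduced over `group` here.
    # step() updates only this rank's shard, then broadcasts the fresh
    # parameters from each owner's *global* rank, as in the diff above.
    optimizer.step()

# Regather the full optimizer state on every replica (assumed API):
optimizer.consolidate_state_dict()
full_state = optimizer.state_dict()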