Unverified commit 379c6bf0, authored by Joshua Meier and committed by GitHub

Support optimizer state sharding for megatron (#121)

support optimizer state sharding for megatron
parent 1c2a6f6b
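
The motivation for the change: when OSS is constructed with a process sub-group (for Megatron-style model parallelism, the data-parallel group), dist.get_rank(self.group) returns a rank that is local to that group, while dist.broadcast(..., src=...) expects a rank in the global WORLD group. Every src=rank in the broadcasts below therefore has to be translated to a global rank first. A minimal sketch of that distinction, assuming a 4-process job split into two data-parallel groups (the group layout and names are illustrative, not part of the commit):

import torch
import torch.distributed as dist

def demo() -> None:
    # Assumes the usual launcher env vars (RANK, WORLD_SIZE, MASTER_ADDR/PORT) are set.
    dist.init_process_group("gloo", init_method="env://")
    global_rank = dist.get_rank()  # rank in WORLD, e.g. 0..3

    # Every process must create every group; each then picks the one it belongs to.
    group_a = dist.new_group(ranks=[0, 2])
    group_b = dist.new_group(ranks=[1, 3])
    my_ranks, my_group = ([0, 2], group_a) if global_rank % 2 == 0 else ([1, 3], group_b)

    local_rank = dist.get_rank(my_group)  # rank *inside* the group: 0 or 1
    print(f"world rank {global_rank} -> group rank {local_rank}")

    # broadcast() takes a global source rank, so group-local rank 0 has to be
    # mapped back through the ranks list the group was built from.
    src_global = my_ranks[0]
    t = torch.full((1,), float(global_rank))
    dist.broadcast(t, src=src_global, group=my_group)  # all members now hold my_ranks[0]'s value

    dist.destroy_process_group()
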
@@ -72,6 +72,8 @@ class OSS(Optimizer):
         self.world_size = dist.get_world_size(self.group)
         self.rank = dist.get_rank(self.group)
+        self.global_rank = self.get_global_rank(self.group, self.rank)
         self.optim = optim(self.partition_parameters()[self.rank], **default)

         # Optional consolidated optimizer state
@@ -88,7 +90,7 @@ class OSS(Optimizer):
     # Partition helpers
     def partition_parameters(self) -> List[List[dict]]:
-        """Partitions parameters across distributed ranks.
+        """Partitions parameters across distributed data parallel ranks.

         Returns a list of param_groups (which is a list of dict) where each
         element of the list contains the param_groups for a rank. Element 0
@@ -135,6 +137,7 @@ class OSS(Optimizer):
     @property
     def param_to_rank(self) -> Dict[torch.Tensor, int]:
+        """param to data parallel rank"""
         if len(self._param_rank) == 0:
             for rank, param_groups in enumerate(self.partition_parameters()):
                 for param_group in param_groups:
@@ -142,6 +145,13 @@ class OSS(Optimizer):
                         self._param_rank[param] = rank
         return self._param_rank

+    def get_global_rank(self, group, rank):
+        if group is dist.group.WORLD:
+            return rank
+        else:
+            global_rank = dist.distributed_c10d._get_global_rank(group, rank)
+        return global_rank
+
     # NOTE(msb) We add a kwargs in order to support Optimizer sub-classes that support extra kwargs.
     # For example, the apex library contains fused optimizers with a step that supports extra kwargs.
     def step(self, closure: Optional[Callable[[], float]] = None, **kwargs: Any) -> Optional[float]:
@@ -174,14 +184,15 @@ class OSS(Optimizer):
                     # Gloo will rightly assert on this operation for any tensor that requires grad.
                     # We save and restore the grad requirement state to work around that, in our case
                     # the grad is only useful on the source rank.
+                    global_rank = self.get_global_rank(self.group, rank)
                     requires_grad.append((param, param.requires_grad))
                     param.requires_grad = False
-                    requests.append(dist.broadcast(tensor=param, src=rank, group=self.group, async_op=True))
+                    requests.append(dist.broadcast(tensor=param, src=global_rank, group=self.group, async_op=True))

         for fut, req_grad in zip(requests, requires_grad):
             fut.wait()
             req_grad[0].requires_grad = req_grad[1]

         return loss

     def local_state_dict(self) -> dict:
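
The comment block in the hunk above explains the save/restore dance around requires_grad: the broadcast writes into the parameter in place, and Gloo rejects an in-place collective on a leaf tensor that requires grad. A stand-alone sketch of that pattern (the helper name and signature are illustrative, not fairscale API):

from typing import Iterable, List, Tuple

import torch
import torch.distributed as dist

def broadcast_params(params: Iterable[torch.nn.Parameter], src_global_rank: int, group=None) -> None:
    # Temporarily drop requires_grad so the in-place broadcast is accepted,
    # then restore the original flags once the async requests complete.
    saved: List[Tuple[torch.nn.Parameter, bool]] = []
    handles = []
    for p in params:
        saved.append((p, p.requires_grad))
        p.requires_grad = False
        handles.append(dist.broadcast(tensor=p, src=src_global_rank, group=group, async_op=True))

    for handle, (p, flag) in zip(handles, saved):
        handle.wait()
        p.requires_grad = flag
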
@@ -330,7 +341,7 @@ class OSS(Optimizer):
         empty_buffer = torch.tensor([0], dtype=torch.uint8, device=self._device)
         all_states: List[Dict[str, Any]] = []

-        for rank in range(dist.get_world_size(group=self.group)):
+        for rank in range(self.world_size):
             if rank == self.rank:
                 logging.debug("Saving self state")
                 all_states.append(
@@ -338,12 +349,13 @@ class OSS(Optimizer):
                 )

                 # Sync with other replicas
-                broadcast_object(empty_buffer, src_rank=rank, group=self.group, dist_device=self._device)
+                broadcast_object(empty_buffer, src_rank=self.global_rank, group=self.group, dist_device=self._device)
             else:
                 # Fetch the optim state from the other replicas
-                logging.debug("Receiving state from rank %s ", rank)
+                global_rank = self.get_global_rank(self.group, rank)
+                logging.debug("Receiving state from rank %s ", global_rank)
                 replica_state = broadcast_object(
-                    empty_buffer, src_rank=rank, group=self.group, dist_device=self._device
+                    empty_buffer, src_rank=global_rank, group=self.group, dist_device=self._device
                 )

                 all_states.append(
@@ -358,17 +370,18 @@ class OSS(Optimizer):
         """Broadcast this rank's state shard, discard others"""
         empty_buffer = torch.tensor([0], dtype=torch.uint8, device=self._device)

-        for rank in range(dist.get_world_size(group=self.group)):
+        for rank in range(self.world_size):
             if rank == self.rank:
                 # Send the state to the reference replica
                 logging.debug(
                     "Sending the sharded optimizer state to the reference replica from rank %s", rank,
                 )
-                broadcast_object(self.local_state_dict(), src_rank=rank, group=self.group, dist_device=self._device)
+                broadcast_object(self.local_state_dict(), src_rank=self.global_rank, group=self.group, dist_device=self._device)
             else:
+                global_rank = self.get_global_rank(self.group, rank)
                 # Discard this tensor/rank, broadcast necessary for syncing
-                logging.debug("Discarding broadcast from rank %s", rank)
-                broadcast_object(empty_buffer, src_rank=rank, group=self.group, dist_device=self._device)
+                logging.debug("Discarding broadcast from rank %s", global_rank)
+                broadcast_object(empty_buffer, src_rank=global_rank, group=self.group, dist_device=self._device)

     def _free_other_grads(self) -> None:
         """Free all the gradients only useful for the other ranks
@@ -380,3 +393,4 @@ class OSS(Optimizer):
             for p in partition:
                 for t in p["params"]:
                     t.grad = None
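
For context, this is roughly how the sharded optimizer is meant to be driven after the change: hand OSS the data-parallel process group, and the new self.global_rank / get_global_rank translation keeps the broadcasts addressed to the right global ranks. The constructor call below follows the OSS signature visible in this diff (an optim class, a group, extra kwargs forwarded as defaults); data_parallel_group is a placeholder for whatever group the model-parallel framework exposes, so treat the sketch as an assumption rather than a guarantee:

import torch
from fairscale.optim import OSS

def build_optimizer(model: torch.nn.Module, data_parallel_group) -> OSS:
    # Optimizer state is sharded across the ranks of `data_parallel_group` only;
    # each model-parallel slice keeps its own shards.
    return OSS(
        model.parameters(),
        optim=torch.optim.AdamW,    # any torch.optim class; kwargs below are forwarded to it
        group=data_parallel_group,  # group-local ranks get translated to global ones internally
        lr=1e-4,
    )
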