Unverified Commit 1d1d15ea authored by Benjamin Lefaudeux, committed by GitHub

[fix] OSS pytorch-compliant state dict (#61)

* Aligning the optimizer state dict with what PyTorch expects

* Adding a check on the dict keys, ensuring that `state` and `param_groups` are both there

* After installing the pinned isort, black and co., a one-liner to please the linter
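
With this change, the consolidated `state_dict()` returned by OSS exposes the same two top-level keys as a plain `torch.optim.Optimizer`, with one entry per rank under each key. A minimal sketch of the resulting layout, with hypothetical values and assuming a two-rank setup with SGD plus momentum:

```python
# Hypothetical illustration only: shape of the consolidated OSS state dict
# after this change, for an assumed 2-rank setup and one param group per rank.
consolidated = {
    # one local optimizer "state" dict per rank, in rank order
    "state": [
        {0: {"momentum_buffer": "tensor for rank 0's shard"}},
        {0: {"momentum_buffer": "tensor for rank 1's shard"}},
    ],
    # one local "param_groups" list per rank, in rank order
    "param_groups": [
        [{"lr": 0.1, "momentum": 0.9, "params": [0]}],
        [{"lr": 0.1, "momentum": 0.9, "params": [0]}],
    ],
}

# Both keys that PyTorch checkpointing code expects are now present
assert "state" in consolidated and "param_groups" in consolidated
```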
parent 4488e17c
@@ -50,7 +50,7 @@ def train(
             "label": torch.stack([i[1] for i in inputs]).to(rank),
         }
 
-    def _print(msg):
+    def print_(msg):
         if dist.get_rank() == 0:
             print(msg)

@@ -91,7 +91,7 @@ def train(
         epoch_end = time.monotonic()
         measurements.append(data_size / (epoch_end - epoch_start))
-        _print(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec")
+        print_(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec")
 
     torch.cuda.synchronize(rank)
     training_stop = time.monotonic()
@@ -58,14 +58,14 @@ class OSS(Optimizer):
         # Build the wrapped optimizer, responsible for a shard of the params
         self.group = group
         self.rank = dist.get_rank(group)
-        param_groups = self.partition_parameters()
-        self.optim = optim(param_groups[self.rank], **defaults)
+        split_param_groups = self.partition_parameters()
+        self.optim = optim(split_param_groups[self.rank], **defaults)
 
         # Optional consolidated optimizer state
         self._all_states: List[Dict[str, Any]] = []
 
         # Current device is set by the parameters allocated to this rank
-        self._device = self.partition_parameters()[self.rank][0]["params"][0].device
+        self._device = split_param_groups[self.rank][0]["params"][0].device
 
     def partition_parameters(self) -> List[List[dict]]:
         """Partitions parameters across distributed ranks.

@@ -122,7 +122,7 @@ class OSS(Optimizer):
         if self.rank == recipient_rank:
             # Pull the sharded state from all the other replicas
             # Store all the states in order, rank by rank
-            logging.debug("Pulling the sharded SGD state from all replicas")
+            logging.debug("Pulling the sharded optimizer state from all replicas")
             self._all_states = self._collect_sharded_states()
         else:
             # Acknowledge broadcasts, and send this rank's shard when needed

@@ -140,7 +140,10 @@ class OSS(Optimizer):
             len(self._all_states) > 0
         ), "The optimizer state is not materialized, please call consolidate_state_dict on every replica beforehand"
 
-        return {"state": self._all_states}
+        return {
+            "state": [s["state"] for s in self._all_states],
+            "param_groups": [s["param_groups"] for s in self._all_states],
+        }
 
     def load_local_state_dict(self, state_dict: dict) -> None:
         """ Loads this rank's state_dict. """

@@ -169,7 +172,12 @@ class OSS(Optimizer):
     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         """ Restore the global parameter groups as well as the shard """
         # Dispatch this rank's state dictionary to the wrapped shard optimizer
-        self.load_local_state_dict(state_dict["state"][self.rank])
+        self.load_local_state_dict(
+            {"state": state_dict["state"][self.rank], "param_groups": state_dict["param_groups"][self.rank]}
+        )
+
+        # Update the param_groups attribute for this instance
+        # TODO(ben)
 
     def add_param_group(self, param_group: dict) -> None:
         super().add_param_group(param_group)

@@ -224,7 +232,7 @@ class OSS(Optimizer):
             if rank == self.rank:
                 # Send the state to the reference replica
                 logging.debug(
-                    "Sending the sharded SGD state to the reference replica from rank %s", rank,
+                    "Sending the sharded optimizer state to the reference replica from rank %s", rank,
                 )
                 broadcast_object(self.local_state_dict(), src_rank=rank, group=self.group, dist_device=self._device)
             else:
@@ -50,11 +50,15 @@ def test_state_dict():
     o.consolidate_state_dict()  # Sync state dict in between replicas - even if there are none
     state_dict = o.state_dict()
 
+    # Check that the state dict is pytorch-compliant key wise
+    assert "param_groups" in state_dict.keys()
+    assert "state" in state_dict.keys()
+
     # Check that the pulled state is what we expect
-    assert state_dict["state"][0]["param_groups"][0]["lr"] == 0.1
+    assert state_dict["param_groups"][0][0]["lr"] == 0.1
 
     # Check that the pulled state and the .param_groups attribute are in sync
-    assert state_dict["state"][0]["param_groups"][0]["lr"] == o.param_groups[0]["lr"]
+    assert state_dict["param_groups"][0][0]["lr"] == o.param_groups[0]["lr"]
 
     # Check that it's correctly loaded
     o = optim.OSS([x], lr=0.01)
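
For reference, a checkpointing round trip with the updated interface might look like the sketch below. It is not part of the diff: `model`, the training loop, and the process group initialization are assumed, and the reference rank is taken to be 0.

```python
# Hypothetical usage sketch, assuming torch.distributed has been initialized
# and `model` is the same module on every rank.
import torch
import torch.distributed as dist
from fairscale.optim import OSS

optimizer = OSS(model.parameters(), lr=0.1, momentum=0.9)

# ... run some training steps ...

# Every rank takes part in the consolidation (each one broadcasts its shard),
# but only the reference rank materializes the full state.
optimizer.consolidate_state_dict()
if dist.get_rank() == 0:
    torch.save(optimizer.state_dict(), "oss_optim.pt")  # has "state" and "param_groups"

# Later, after rebuilding the same model and OSS optimizer on every rank:
optimizer.load_state_dict(torch.load("oss_optim.pt"))  # each rank restores its own shard
```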