Unverified Commit 4f597233 authored by Benjamin Lefaudeux, committed by GitHub

[feat] OSS flatten state dict (#65)

Changes the structure of the returned state dict so that its param_groups are closer to what a vanilla optimizer would return: the per-rank groups are un-sharded into a single flat list, and sharded again when loading.
parent 6fe88a91
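
In outline, state_dict now flattens every rank's param_groups into a single list and records a partition of (start, end) index pairs, so that load_state_dict can slice each rank's shard back out. A minimal standalone sketch of that scheme (illustrative names only, not the library code itself):

from typing import Any, Dict, List, Tuple

def flatten_param_groups(
    all_states: List[Dict[str, Any]]
) -> Tuple[List[Dict[str, Any]], List[Tuple[int, int]]]:
    """Flatten per-rank param_groups; partition[rank] = (start, end) slice."""
    param_groups: List[Dict[str, Any]] = []
    partition: List[Tuple[int, int]] = []
    start = 0
    for rank_state in all_states:
        param_groups.extend(rank_state["param_groups"])
        end = start + len(rank_state["param_groups"])
        partition.append((start, end))
        start = end
    return param_groups, partition

def shard_for_rank(
    param_groups: List[Dict[str, Any]], partition: List[Tuple[int, int]], rank: int
) -> List[Dict[str, Any]]:
    """Recover one rank's shard when loading, mirroring load_state_dict below."""
    start, end = partition[rank]
    return param_groups[start:end]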
@@ -49,8 +49,8 @@ def train(
     # Data setup, dummy data
     def collate(inputs: List[Any]):
        return {
-            "inputs": torch.stack([i[0] for i in inputs]).to(rank),
-            "label": torch.stack([i[1] for i in inputs]).to(rank),
+            "inputs": torch.stack([i[0] for i in inputs]).to(torch.device(rank)),
+            "label": torch.stack([i[1] for i in inputs]).to(torch.device(rank)),
        }

     dataloader = DataLoader(
@@ -119,7 +119,7 @@ def train(
     print(f"[{dist.get_rank()}] : Mean speed: {mean:.2f} +/- {std:.2f}")

     if use_oss and check_regression and dist.get_rank() == 0:
-        assert (mean - 3.0 * std) < reference_speed, "Speed regression detected"
+        assert (mean + 3.0 * std) > reference_speed, "Speed regression detected"
         assert max_memory < 1.05 * reference_memory, "Memory use regression detected"
         print("[Regression Test] VALID")
@@ -133,11 +133,12 @@ if __name__ == "__main__":
     parser.add_argument("--epochs", action="store", default=10, type=int)
     parser.add_argument("--batch_size", action="store", default=32, type=int)
     parser.add_argument("--data_size", action="store", default=512, type=int)
-    parser.add_argument("--check_regression", action="store", default=True, type=bool)
-    parser.add_argument("--reference_speed", action="store", default=39.82, type=float)
+    parser.add_argument("--check_regression", action="store_true", default=False)
+    parser.add_argument("--reference_speed", action="store", default=32.32, type=float)
     parser.add_argument("--reference_memory", action="store", default=4475, type=float)
     args = parser.parse_args()
+    print(f"Benchmark arguments: {args}")

     print("\nBenchmark vanilla optimizer")
     mp.spawn(
...
@@ -6,7 +6,7 @@
 import copy
 from itertools import chain
 import logging
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type

 import torch
 import torch.distributed as dist
@@ -146,9 +146,21 @@ class OSS(Optimizer):
             len(self._all_states) > 0
         ), "The optimizer state is not materialized, please call consolidate_state_dict on every replica beforehand"

+        # Flatten the param_groups, and save the partition which records the rank <-> shard correspondence
+        partition: List[Tuple[int, int]] = []
+        param_groups: List[Dict[Any, Any]] = []
+        start = 0
+        for s in self._all_states:
+            param_groups.extend(s["param_groups"])
+            end = start + len(s["param_groups"])
+            partition.append((start, end))
+            start = end
+
         return {
             "state": [s["state"] for s in self._all_states],
-            "param_groups": [s["param_groups"] for s in self._all_states],
+            "param_groups": param_groups,
+            "partition": partition,
         }

     def load_local_state_dict(self, state_dict: dict) -> None:
@@ -177,10 +189,13 @@ class OSS(Optimizer):
     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         """ Restore the global parameter groups as well as the shard """
+        # Get this optimizer's param_groups shard out of the flat list
+        param_groups = state_dict["param_groups"][
+            state_dict["partition"][self.rank][0] : state_dict["partition"][self.rank][1]
+        ]
+
         # Dispatch this rank's state dictionary to the wrapped shard optimizer
-        self.load_local_state_dict(
-            {"state": state_dict["state"][self.rank], "param_groups": state_dict["param_groups"][self.rank]}
-        )
+        self.load_local_state_dict({"state": state_dict["state"][self.rank], "param_groups": param_groups})

     def add_param_group(self, param_group: dict) -> None:
         super().add_param_group(param_group)
...
@@ -55,16 +55,16 @@ def test_state_dict():
     assert "state" in state_dict.keys()

     # Check that the pulled state is what we expect, and that we have all the expected keys
-    assert state_dict["param_groups"][0][0]["lr"] == 0.1
-    assert state_dict["param_groups"][0][0]["momentum"] == 0.9
-    assert not state_dict["param_groups"][0][0]["nesterov"]
-    assert state_dict["param_groups"][0][0]["weight_decay"] == 0.0
-    assert state_dict["param_groups"][0][0]["dampening"] == 0.0
+    assert state_dict["param_groups"][0]["lr"] == 0.1
+    assert state_dict["param_groups"][0]["momentum"] == 0.9
+    assert not state_dict["param_groups"][0]["nesterov"]
+    assert state_dict["param_groups"][0]["weight_decay"] == 0.0
+    assert state_dict["param_groups"][0]["dampening"] == 0.0

     # Check that the pulled state and the .param_groups attribute are in sync
-    for k in state_dict["param_groups"][0][0].keys():
+    for k in state_dict["param_groups"][0].keys():
         if k != "params":
-            assert state_dict["param_groups"][0][0][k] == o.param_groups[0][k]
+            assert state_dict["param_groups"][0][k] == o.param_groups[0][k]

     # Check that it's correctly loaded
     o = optim.OSS([x], lr=0.01)
...
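
For reference, a save/restore sketch using the consolidated state dict; this assumes an initialized process group, an existing model, and the fairscale.optim.OSS import path, and the optimizer calls mirror the ones exercised in the diff above:

import torch
import torch.distributed as dist
from fairscale.optim import OSS

# OSS wraps a vanilla optimizer (SGD by default) and shards its state across ranks
optimizer = OSS(model.parameters(), lr=0.1, momentum=0.9)

# ... train for a while ...

# Every replica must take part in consolidation before the state can be saved
optimizer.consolidate_state_dict()
if dist.get_rank() == 0:
    torch.save(optimizer.state_dict(), "optimizer.pt")

# On every rank: load_state_dict slices this rank's shard back out via "partition"
optimizer.load_state_dict(torch.load("optimizer.pt"))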