Unverified Commit 4f597233 authored by Benjamin Lefaudeux, committed by GitHub

[feat] OSS flatten state dict (#65)

Changes the structure of the returned state dict with respect to the param_groups, making it closer to what a vanilla optimizer would return: the per-rank shards are un-sharded into a single flat list (plus a partition table recording which slice belongs to which rank), and sharded again when loading.
parent 6fe88a91
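
A minimal, standalone sketch of the flattening idea, using illustrative helper names rather than the OSS API: each rank contributes its shard's param_groups, the consolidated dict stores them in one flat list together with a "partition" table of (start, end) pairs, and a rank recovers its own slice from that table when loading.

# Illustrative sketch only -- toy data, hypothetical helpers, not the fairscale OSS API.
from typing import Any, Dict, List, Tuple

def flatten_param_groups(per_rank_states: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Concatenate every rank's param_groups and record the rank <-> slice mapping."""
    partition: List[Tuple[int, int]] = []
    param_groups: List[Dict[str, Any]] = []
    start = 0
    for shard in per_rank_states:
        param_groups.extend(shard["param_groups"])
        end = start + len(shard["param_groups"])
        partition.append((start, end))
        start = end
    return {
        "state": [shard["state"] for shard in per_rank_states],
        "param_groups": param_groups,
        "partition": partition,
    }

def shard_for_rank(state_dict: Dict[str, Any], rank: int) -> Dict[str, Any]:
    """Recover one rank's local shard from the flattened, consolidated dict."""
    begin, end = state_dict["partition"][rank]
    return {
        "state": state_dict["state"][rank],
        "param_groups": state_dict["param_groups"][begin:end],
    }

# Toy data: two ranks, one param group each.
per_rank = [
    {"state": {}, "param_groups": [{"lr": 0.1, "params": [0]}]},
    {"state": {}, "param_groups": [{"lr": 0.1, "params": [1]}]},
]
flat = flatten_param_groups(per_rank)
assert flat["partition"] == [(0, 1), (1, 2)]
assert shard_for_rank(flat, 1)["param_groups"] == [{"lr": 0.1, "params": [1]}]
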
@@ -49,8 +49,8 @@ def train(
     # Data setup, dummy data
     def collate(inputs: List[Any]):
         return {
-            "inputs": torch.stack([i[0] for i in inputs]).to(rank),
-            "label": torch.stack([i[1] for i in inputs]).to(rank),
+            "inputs": torch.stack([i[0] for i in inputs]).to(torch.device(rank)),
+            "label": torch.stack([i[1] for i in inputs]).to(torch.device(rank)),
         }
     dataloader = DataLoader(
@@ -119,7 +119,7 @@ def train(
     print(f"[{dist.get_rank()}] : Mean speed: {mean:.2f} +/- {std:.2f}")
     if use_oss and check_regression and dist.get_rank() == 0:
-        assert (mean - 3.0 * std) < reference_speed, "Speed regression detected"
+        assert (mean + 3.0 * std) > reference_speed, "Speed regression detected"
         assert max_memory < 1.05 * reference_memory, "Memory use regression detected"
         print("[Regression Test] VALID")
@@ -133,11 +133,12 @@ if __name__ == "__main__":
    parser.add_argument("--epochs", action="store", default=10, type=int)
    parser.add_argument("--batch_size", action="store", default=32, type=int)
    parser.add_argument("--data_size", action="store", default=512, type=int)
-   parser.add_argument("--check_regression", action="store", default=True, type=bool)
-   parser.add_argument("--reference_speed", action="store", default=39.82, type=float)
+   parser.add_argument("--check_regression", action="store_true", default=False)
+   parser.add_argument("--reference_speed", action="store", default=32.32, type=float)
    parser.add_argument("--reference_memory", action="store", default=4475, type=float)
    args = parser.parse_args()
    print(f"Benchmark arguments: {args}")
    print("\nBenchmark vanilla optimizer")
    mp.spawn(
......
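
A likely motivation for the flag change above (an inference, not stated in the commit message): argparse with type=bool converts any non-empty string to True, so passing "--check_regression False" would still enable the check, whereas action="store_true" with a False default only enables it when the flag is explicitly given. A small standalone sketch of the difference, with illustrative flag names:

import argparse

parser = argparse.ArgumentParser()
# Old form: type=bool turns any non-empty string into True,
# so "--buggy_flag False" still enables the flag.
parser.add_argument("--buggy_flag", action="store", default=True, type=bool)
# New form: off by default, enabled only when the flag is present.
parser.add_argument("--fixed_flag", action="store_true", default=False)

args = parser.parse_args(["--buggy_flag", "False"])
assert args.buggy_flag is True   # bool("False") evaluates to True
assert args.fixed_flag is False  # flag not passed, stays disabled
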
@@ -6,7 +6,7 @@
 import copy
 from itertools import chain
 import logging
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type
 import torch
 import torch.distributed as dist
@@ -146,9 +146,21 @@ class OSS(Optimizer):
            len(self._all_states) > 0
        ), "The optimizer state is not materialized, please call consolidate_state_dict on every replica beforehand"
+       # Flatten the param_groups, save the partition which logs the rank <> shard correspondence
+       partition: List[Tuple[int, int]] = []
+       param_groups: List[Dict[Any, Any]] = []
+       start = 0
+       for i, s in enumerate(self._all_states):
+           param_groups.extend(s["param_groups"])
+           end = start + len(s["param_groups"])
+           partition.append((start, end))
+           start = end
        return {
            "state": [s["state"] for s in self._all_states],
-           "param_groups": [s["param_groups"] for s in self._all_states],
+           "param_groups": param_groups,
+           "partition": partition,
        }
    def load_local_state_dict(self, state_dict: dict) -> None:
@@ -177,10 +189,13 @@ class OSS(Optimizer):
    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """ Restore the global parameter groups as well as the shard """
+       # Get this optimizer's param_groups shard
+       param_groups = state_dict["param_groups"][
+           state_dict["partition"][self.rank][0] : state_dict["partition"][self.rank][1]
+       ]
        # Dispatch this rank's state dictionary to the wrapped shard optimizer
-       self.load_local_state_dict(
-           {"state": state_dict["state"][self.rank], "param_groups": state_dict["param_groups"][self.rank]}
-       )
+       self.load_local_state_dict({"state": state_dict["state"][self.rank], "param_groups": param_groups})
    def add_param_group(self, param_group: dict) -> None:
        super().add_param_group(param_group)
......
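
With the flattened format in place, a checkpoint round-trip could look like the sketch below. This is an illustrative flow under assumptions, not code from this commit: the import path, the helper names (save_checkpoint, load_checkpoint) and the file handling are hypothetical; consolidate_state_dict is the call named in the assertion message above, and load_state_dict re-shards via the stored partition as shown in the diff.

import torch
import torch.distributed as dist

from fairscale.optim import OSS  # import path is an assumption in this sketch

def save_checkpoint(optimizer: "OSS", path: str) -> None:
    # Every rank participates so the shards can be gathered onto the recipient rank.
    optimizer.consolidate_state_dict()
    if dist.get_rank() == 0:
        # The consolidated dict holds the flat "param_groups" list plus the
        # "partition" table recording which slice belongs to which rank.
        torch.save(optimizer.state_dict(), path)
    dist.barrier()

def load_checkpoint(optimizer: "OSS", path: str) -> None:
    # Each rank loads the full dict; load_state_dict keeps only this rank's
    # slice of param_groups, selected through the stored partition.
    state = torch.load(path, map_location="cpu")
    optimizer.load_state_dict(state)
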
@@ -55,16 +55,16 @@ def test_state_dict():
    assert "state" in state_dict.keys()
    # Check that the pulled state is what we expect, and that we have all the expected keys
-   assert state_dict["param_groups"][0][0]["lr"] == 0.1
-   assert state_dict["param_groups"][0][0]["momentum"] == 0.9
-   assert not state_dict["param_groups"][0][0]["nesterov"]
-   assert state_dict["param_groups"][0][0]["weight_decay"] == 0.0
-   assert state_dict["param_groups"][0][0]["dampening"] == 0.0
+   assert state_dict["param_groups"][0]["lr"] == 0.1
+   assert state_dict["param_groups"][0]["momentum"] == 0.9
+   assert not state_dict["param_groups"][0]["nesterov"]
+   assert state_dict["param_groups"][0]["weight_decay"] == 0.0
+   assert state_dict["param_groups"][0]["dampening"] == 0.0
    # Check that the pulled state and the .param_groups attribute are in sync
-   for k in state_dict["param_groups"][0][0].keys():
+   for k in state_dict["param_groups"][0].keys():
        if k != "params":
-           assert state_dict["param_groups"][0][0][k] == o.param_groups[0][k]
+           assert state_dict["param_groups"][0][k] == o.param_groups[0][k]
    # Check that it's correctly loaded
    o = optim.OSS([x], lr=0.01)
......