Unverified commit e4a0804c, authored by msbaines, committed by GitHub
Browse files

[refactor] optim/oss: save memory and time by avoiding duplicate copy of parameters (#57)

parent 220ee323
......@@ -85,10 +85,9 @@ class OSS(Optimizer):
param_lists[rank].append(param)
sizes[rank] += param.numel()
for rank, params in enumerate(param_lists):
if len(params) > 0:
param_group_rank = copy.copy(param_group)
param_group_rank["params"] = params
param_groups[rank].append(param_group_rank)
param_group_rank = copy.copy(param_group)
param_group_rank["params"] = params
param_groups[rank].append(param_group_rank)
return param_groups
# NOTE(msb) We add a kwargs in order to support Optimizer sub-classes that support extra kwargs.
......@@ -134,7 +133,7 @@ class OSS(Optimizer):
len(self._all_states) > 0
), "The optimizer state is not materialized, please call consolidate_state_dict on every replica beforehand"
return {"state": self._all_states, "param_groups": self.param_groups}
return {"state": self._all_states}
def load_local_state_dict(self, state_dict: dict) -> None:
""" Loads this rank's state_dict. """
......@@ -146,8 +145,11 @@ class OSS(Optimizer):
# Dispatch this rank's state dictionary to the wrapped shard optimizer
self.load_local_state_dict(state_dict["state"][self.rank])
# Restore the global param_groups
self.param_groups = recursive_copy_to_device(state_dict["param_groups"], non_blocking=True, device=self._device)
# Restore the global param_groups (the params themselves are already correct)
for global_group, local_group in zip(self.param_groups, self.optim.param_groups):
for k, v in local_group.items():
if k != "params":
global_group[k] = v
def add_param_group(self, param_group: dict) -> None:
super().add_param_group(param_group)
......
......@@ -51,10 +51,10 @@ def test_state_dict():
state_dict = o.state_dict()
# Check that the pulled state is what we expect
assert state_dict["param_groups"][0]["lr"] == 0.1
assert state_dict["state"][0]["param_groups"][0]["lr"] == 0.1
# Check that the pulled state and the .param_groups attribute are in sync
assert state_dict["param_groups"][0]["lr"] == o.param_groups[0]["lr"]
assert state_dict["state"][0]["param_groups"][0]["lr"] == o.param_groups[0]["lr"]
# Check that it's correctly loaded
o = optim.OSS([x], lr=0.01)
......@@ -113,10 +113,7 @@ def run_test_add_param_group(rank, world_size):
assert len(o.param_groups) == 2
# Verify that added group is added to the correct partition making all have 8 elements.
assert sum([x.numel() for g in o.optim.param_groups for x in g["params"]]) == 8
if rank == 1:
assert len(o.optim.param_groups) == 2
else:
assert len(o.optim.param_groups) == 1
assert len(o.optim.param_groups) == 2
def test_add_param_group():
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment