Unverified Commit e4a0804c authored by msbaines, committed by GitHub

[refactor] optim/oss: save memory and time by avoiding duplicate copy of parameters (#57)

parent 220ee323
@@ -85,10 +85,9 @@ class OSS(Optimizer):
                 param_lists[rank].append(param)
                 sizes[rank] += param.numel()
             for rank, params in enumerate(param_lists):
-                if len(params) > 0:
-                    param_group_rank = copy.copy(param_group)
-                    param_group_rank["params"] = params
-                    param_groups[rank].append(param_group_rank)
+                param_group_rank = copy.copy(param_group)
+                param_group_rank["params"] = params
+                param_groups[rank].append(param_group_rank)
         return param_groups
 
     # NOTE(msb) We add a kwargs in order to support Optimizer sub-classes that support extra kwargs.
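The hunk above drops the len(params) > 0 guard, so the partitioning now emits one param group per rank for every global group, even when a rank receives no tensors; group indices therefore stay aligned across ranks (the test change at the bottom of this commit relies on that). A minimal standalone sketch of the greedy-by-size partitioning, using a hypothetical partition_one_group helper rather than the fairscale method itself:

import copy
import torch

# Hypothetical helper, not the fairscale source: greedily shard one param
# group across ranks by element count.
def partition_one_group(param_group, world_size):
    param_lists = [[] for _ in range(world_size)]
    sizes = [0] * world_size
    for param in param_group["params"]:
        # Give each tensor to the rank currently holding the fewest elements.
        rank = sizes.index(min(sizes))
        param_lists[rank].append(param)
        sizes[rank] += param.numel()
    shards = []
    for params in param_lists:
        group = copy.copy(param_group)
        group["params"] = params  # may be empty; the group is kept regardless
        shards.append(group)
    return shards

group = {"params": [torch.zeros(4), torch.zeros(3), torch.zeros(5)], "lr": 0.1}
shards = partition_one_group(group, world_size=2)
assert len(shards) == 2                                # one group per rank, always
assert sorted(sum(p.numel() for p in g["params"]) for g in shards) == [4, 8]
assert all(g["lr"] == 0.1 for g in shards)             # hyperparameters copied into every shard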
@@ -134,7 +133,7 @@ class OSS(Optimizer):
             len(self._all_states) > 0
         ), "The optimizer state is not materialized, please call consolidate_state_dict on every replica beforehand"
 
-        return {"state": self._all_states, "param_groups": self.param_groups}
+        return {"state": self._all_states}
 
     def load_local_state_dict(self, state_dict: dict) -> None:
         """ Loads this rank's state_dict. """
@@ -146,8 +145,11 @@ class OSS(Optimizer):
         # Dispatch this rank's state dictionary to the wrapped shard optimizer
         self.load_local_state_dict(state_dict["state"][self.rank])
 
-        # Restore the global param_groups
-        self.param_groups = recursive_copy_to_device(state_dict["param_groups"], non_blocking=True, device=self._device)
+        # Restore the global param_groups (the params themselves are already correct)
+        for global_group, local_group in zip(self.param_groups, self.optim.param_groups):
+            for k, v in local_group.items():
+                if k != "params":
+                    global_group[k] = v
 
     def add_param_group(self, param_group: dict) -> None:
         super().add_param_group(param_group)
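With both hunks applied, save/load round-trips through the per-rank shards only: load_state_dict rehydrates this rank's shard via load_local_state_dict and then copies every non-"params" entry of the local optimizer's groups back into the exposed OSS.param_groups, keeping the two in sync. A hedged single-process usage sketch (the gloo process-group setup and port are placeholders; a real run would do the same on every rank):

import torch
import torch.distributed as dist
from fairscale.optim import OSS

# Single-process "distributed" group just so OSS can be constructed (placeholder port).
if not dist.is_initialized():
    dist.init_process_group(backend="gloo", init_method="tcp://127.0.0.1:29501",
                            rank=0, world_size=1)

x = torch.tensor([1.0], requires_grad=True)
o = OSS([x], lr=0.1)

x.sum().backward()
o.step()

o.consolidate_state_dict()            # gather every rank's shard (here: just rank 0)
sd = o.state_dict()
assert set(sd.keys()) == {"state"}    # no duplicated top-level param_groups any more
assert sd["state"][0]["param_groups"][0]["lr"] == 0.1

o2 = OSS([x], lr=0.01)                # different lr on purpose
o2.load_state_dict(sd)                # loads this rank's shard and syncs hyperparameters back
assert o2.param_groups[0]["lr"] == 0.1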
@@ -51,10 +51,10 @@ def test_state_dict():
     state_dict = o.state_dict()
 
     # Check that the pulled state is what we expect
-    assert state_dict["param_groups"][0]["lr"] == 0.1
+    assert state_dict["state"][0]["param_groups"][0]["lr"] == 0.1
     # Check that the pulled state and the .param_groups attribute are in sync
-    assert state_dict["param_groups"][0]["lr"] == o.param_groups[0]["lr"]
+    assert state_dict["state"][0]["param_groups"][0]["lr"] == o.param_groups[0]["lr"]
 
     # Check that it's correctly loaded
     o = optim.OSS([x], lr=0.01)
@@ -113,10 +113,7 @@ def run_test_add_param_group(rank, world_size):
     assert len(o.param_groups) == 2
     # Verify that added group is added to the correct partition making all have 8 elements.
     assert sum([x.numel() for g in o.optim.param_groups for x in g["params"]]) == 8
-    if rank == 1:
-        assert len(o.optim.param_groups) == 2
-    else:
-        assert len(o.optim.param_groups) == 1
+    assert len(o.optim.param_groups) == 2
 
 
 def test_add_param_group():
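The collapsed assertion in the last hunk works because, after the partitioning change, every rank holds a (possibly empty) group for the newly added parameters. That in turn relies on torch.optim accepting a param group whose params list is empty; a tiny check of that assumption:

import torch

p = torch.tensor([1.0, 2.0], requires_grad=True)
# A group with no tensors is accepted, so a rank that received none of the new
# parameters can still carry the group and report the same group count.
sgd = torch.optim.SGD([{"params": [p]}, {"params": []}], lr=0.1)
assert len(sgd.param_groups) == 2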