"tests/git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "45fa3e44a22cd0b74071237304bdfef97b5b9bd6"
Unverified Commit 8fc78924 authored by Koute, committed by GitHub

Add `move_to_device` kwarg to the optimizer's `load_state_dict` (#1344)

This makes it possible to load an optimizer checkpoint without
automatically moving the optimizer's state to the GPU.
parent abb0c32a
@@ -153,12 +153,14 @@ class Optimizer8bit(torch.optim.Optimizer):
     def __setstate__(self, state):
         super().__setstate__(state)

-    def load_state_dict(self, state_dict):
+    def load_state_dict(self, state_dict, move_to_device=True):
         """Load an optimizer state.

         Arguments:
             state_dict (`dict`):
                 An optimizer state (should be returned from a call to `state_dict`) to load.
+            move_to_device (`bool`, defaults to `True`):
+                Whether to move the optimizer's state to the device.
         """
         # deepcopy, to be consistent with module API
         state_dict = deepcopy(state_dict)
@@ -195,7 +197,8 @@ class Optimizer8bit(torch.optim.Optimizer):
            elif isinstance(value, dict):
                for k, v in value.items():
                    if k in self.non_castable_tensor_keys:
-                        value[k] = v.to(param.device)
+                        if move_to_device:
+                            value[k] = v.to(param.device)
                    else:
                        value[k] = cast(param, v)
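A minimal usage sketch of the new keyword argument (the `bnb.optim.Adam8bit` optimizer, the model, and the checkpoint path below are placeholder assumptions; any `Optimizer8bit` subclass should behave the same way):

```python
import torch
import bitsandbytes as bnb  # assumed package exposing Optimizer8bit subclasses

model = torch.nn.Linear(128, 128).cuda()  # placeholder model
optimizer = bnb.optim.Adam8bit(model.parameters(), lr=1e-3)

# Load a previously saved optimizer state ("optimizer.pt" is a placeholder path).
state = torch.load("optimizer.pt", map_location="cpu")

# With move_to_device=False, the non-castable tensors in the loaded state stay on
# the device they were loaded onto (CPU here) instead of being moved to the
# parameters' device.
optimizer.load_state_dict(state, move_to_device=False)
```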