Unverified commit 8fc78924, authored by Koute and committed by GitHub
Browse files

Add `move_to_device` kwarg to the optimizer's `load_state_dict` (#1344)

This makes it possible to load an optimizer checkpoint without
automatically moving the optimizer's state to the GPU.
parent abb0c32a
......@@ -153,12 +153,14 @@ class Optimizer8bit(torch.optim.Optimizer):
def __setstate__(self, state):
    """Restore unpickled state by delegating to the base optimizer.

    Arguments:
        state (`dict`):
            The unpickled state mapping produced by ``__getstate__``.
    """
    # No 8-bit-specific state handling is needed here; the parent
    # class performs the actual restoration. (Returns None either way.)
    return super().__setstate__(state)
def load_state_dict(self, state_dict):
def load_state_dict(self, state_dict, move_to_device=True):
"""Load an optimizer state.
Arguments:
state_dict (`dict`):
An optimizer state (should be returned from a call to `state_dict`) to load.
move_to_device (`bool`, defaults to `True`):
Whether to move the optimizer's state to the device.
"""
# deepcopy, to be consistent with module API
state_dict = deepcopy(state_dict)
......@@ -195,7 +197,8 @@ class Optimizer8bit(torch.optim.Optimizer):
elif isinstance(value, dict):
for k, v in value.items():
if k in self.non_castable_tensor_keys:
value[k] = v.to(param.device)
if move_to_device:
value[k] = v.to(param.device)
else:
value[k] = cast(param, v)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment