Unverified commit fb49b515, authored by msbaines and committed by GitHub
Browse files

[fix] optim/oss: fix state cast (#56)

Work around a PyTorch bug that casts optimizer state (pytorch/pytorch#43706).

Copied from https://github.com/pytorch/fairseq/blob/v0.9.0/fairseq/optim/fp16_optimizer.py#L251-L268
parent e4a0804c
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import copy import copy
from itertools import chain
import logging import logging
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type
...@@ -140,6 +141,19 @@ class OSS(Optimizer): ...@@ -140,6 +141,19 @@ class OSS(Optimizer):
self.optim.load_state_dict(state_dict) self.optim.load_state_dict(state_dict)
# Workaround PyTorch bug that casts state (https://github.com/pytorch/pytorch/issues/43706)
# Copied from https://github.com/pytorch/fairseq/blob/v0.9.0/fairseq/optim/fp16_optimizer.py#L251-L268
groups = self.optim.param_groups
saved_groups = state_dict["param_groups"]
id_map = {
old_id: p
for old_id, p in zip(chain(*(g["params"] for g in saved_groups)), chain(*(g["params"] for g in groups)))
}
for k, v in state_dict["state"].items():
if k in id_map:
param = id_map[k]
self.optim.state[param] = recursive_copy_to_device(v, non_blocking=True, device=param.device)
def load_state_dict(self, state_dict: Dict[str, Any]) -> None: def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
""" Restore the global parameter groups as well as the shard """ """ Restore the global parameter groups as well as the shard """
# Dispatch this rank's state dictionary to the wrapped shard optimizer # Dispatch this rank's state dictionary to the wrapped shard optimizer
......
...@@ -7,6 +7,7 @@ _params_t = Union[Iterable[Tensor], Iterable[dict]] ...@@ -7,6 +7,7 @@ _params_t = Union[Iterable[Tensor], Iterable[dict]]
class Optimizer(object): class Optimizer(object):
param_groups: List[dict] param_groups: List[dict]
state: dict
def __init__(self, params: _params_t, defaults: dict) -> None: ... def __init__(self, params: _params_t, defaults: dict) -> None: ...
def state_dict(self) -> dict: ... def state_dict(self) -> dict: ...
def load_state_dict(self, state_dict: dict) -> None: ... def load_state_dict(self, state_dict: dict) -> None: ...
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment