Unverified Commit c80e4cae authored by wdmwhh, committed by GitHub

Fixed fp16_optimizer state bug (#580)

parent acee61d7
 # Copyright (c) Open-MMLab. All rights reserved.
 import copy
+from collections import defaultdict
+from itertools import chain
 from torch.nn.utils import clip_grad
...
@@ -67,8 +69,16 @@ class Fp16OptimizerHook(OptimizerHook):
         2. Convert the main model from fp32 to fp16.
         """
         # keep a copy of fp32 weights
+        old_groups = runner.optimizer.param_groups
         runner.optimizer.param_groups = copy.deepcopy(
             runner.optimizer.param_groups)
+        state = defaultdict(dict)
+        p_map = {old_p: p for old_p, p in
+                 zip(chain(*(g['params'] for g in old_groups)),
+                     chain(*(g['params'] for g in runner.optimizer.param_groups)))}
+        for k, v in runner.optimizer.state.items():
+            state[p_map[k]] = v
+        runner.optimizer.state = state
         # convert model to fp16
         wrap_fp16_model(runner.model)
...
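For context: torch.optim.Optimizer.state is keyed by the parameter objects themselves, so deep-copying param_groups to keep fp32 master weights leaves any existing state (e.g. momentum buffers from a resumed checkpoint) attached to the old parameters and invisible to the new copies. Below is a minimal sketch of the problem and of the re-keying pattern used in this patch, written against plain PyTorch; the parameter shape, SGD settings, and variable names are illustrative only and not part of the commit.

import copy
from collections import defaultdict
from itertools import chain

import torch

# Toy optimizer with momentum so that opt.state gets populated.
param = torch.nn.Parameter(torch.zeros(3))
opt = torch.optim.SGD([param], lr=0.1, momentum=0.9)
param.grad = torch.ones(3)
opt.step()  # opt.state is now keyed by the original Parameter object

# Deep-copying param_groups creates new Parameter objects, so the
# existing state entries (keyed by the old parameters) become orphaned.
old_groups = opt.param_groups
opt.param_groups = copy.deepcopy(opt.param_groups)
new_param = opt.param_groups[0]['params'][0]
assert new_param not in opt.state  # the bug: state is lost for the copy

# The fix: re-key the state dict from each old parameter to its copy.
p_map = {old_p: p for old_p, p in
         zip(chain(*(g['params'] for g in old_groups)),
             chain(*(g['params'] for g in opt.param_groups)))}
state = defaultdict(dict)
for k, v in opt.state.items():
    state[p_map[k]] = v
opt.state = state
assert new_param in opt.state  # momentum buffer follows the fp32 copy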