Commit efb43450 authored by Myle Ott, committed by Facebook Github Bot

Fix resuming training when using --memory-efficient-fp16

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/678

Differential Revision: D15956712

Pulled By: myleott

fbshipit-source-id: 5048d06ddfbec0045558a22c777a966cca1ec396
parent 39a60b84
fairseq/optim/fp16_optimizer.py
@@ -5,6 +5,8 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.
 
+from itertools import chain
+
 import torch
 
 from fairseq import optim, utils
@@ -292,8 +294,28 @@ class MemoryEfficientFP16Optimizer(optim.FairseqOptimizer):
         """
         if 'loss_scale' in state_dict:
             self.scaler.loss_scale = state_dict['loss_scale']
 
         self.wrapped_optimizer.load_state_dict(state_dict, optimizer_overrides)
+
+        # Hack: PyTorch automatically casts the optimizer state to match the
+        # type of the current parameters. But with --memory-efficient-fp16 the
+        # params are FP16 while the optimizer state is FP32 and we don't want
+        # to cast. A workaround is to manually copy back the original state
+        # after the optimizer has been loaded.
+        groups = self.optimizer.param_groups
+        saved_groups = state_dict['param_groups']
+        id_map = {
+            old_id: p
+            for old_id, p in zip(
+                chain(*(g['params'] for g in saved_groups)),
+                chain(*(g['params'] for g in groups))
+            )
+        }
+
+        for k, v in state_dict['state'].items():
+            if k in id_map:
+                param = id_map[k]
+                self.optimizer.state[param] = v
 
     def backward(self, loss):
         """Computes the sum of gradients of the given tensor w.r.t. graph leaves.
tests/test_memory_efficient_fp16.py 0 → 100644
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import argparse
import unittest

import torch

from fairseq.optim.adam import FairseqAdam
from fairseq.optim.fp16_optimizer import MemoryEfficientFP16Optimizer


class TestMemoryEfficientFP16(unittest.TestCase):

    def test_load_state_dict(self):
        # define simple FP16 model
        model = torch.nn.Linear(5, 5).cuda().half()
        params = list(model.parameters())

        # initialize memory efficient FP16 optimizer
        optimizer = FairseqAdam(
            argparse.Namespace(
                lr=[0.00001],
                adam_betas='(0.9, 0.999)',
                adam_eps=1e-8,
                weight_decay=0.0,
            ),
            params,
        )
        me_optimizer = MemoryEfficientFP16Optimizer(
            argparse.Namespace(
                fp16_init_scale=1,
                fp16_scale_window=1,
                fp16_scale_tolerance=1,
                threshold_loss_scale=1,
            ),
            params,
            optimizer,
        )

        # optimizer state is created in the first step
        loss = model(torch.rand(5).cuda().half()).sum()
        me_optimizer.backward(loss)
        me_optimizer.step()

        # reload state
        state = me_optimizer.state_dict()
        me_optimizer.load_state_dict(state)
        for k, v in me_optimizer.optimizer.state.items():
            self.assertTrue(k.dtype == torch.float16)
            for v_i in v.values():
                if torch.is_tensor(v_i):
                    self.assertTrue(v_i.dtype == torch.float32)


if __name__ == '__main__':
    unittest.main()
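
The new test requires a GPU, since the model, input, and loss are created with .cuda().half(). Assuming the file is added as tests/test_memory_efficient_fp16.py, it can be run on a CUDA machine with, e.g.:

python -m unittest tests.test_memory_efficient_fp16 -v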