Commit 8ce2c35d authored by Yongqiang Wang, committed by Facebook Github Bot

Implement reducing footprint of average checkpoint correctly (#747)

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/747

In https://github.com/pytorch/fairseq/pull/647, checkpoint averaging
was not implemented correctly for models with shared parameters. This diff
contains the correct implementation and a test case to guard against future
regressions; a short illustration of the failure mode follows the commit
metadata below.

Reviewed By: myleott

Differential Revision: D15402943

fbshipit-source-id: 8004836d5c2571814ea54844650618008a9ee522
parent 6b0cce84
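For context, a minimal sketch (illustrative only; the tensors and keys below are hypothetical, not from this diff) of why accumulating state-dict tensors in place without clone() double-counts shared parameters: when two keys alias one storage, the in-place add for each subsequent checkpoint hits that storage twice.

import torch

# Two hypothetical "checkpoints" whose 'fc1.weight' and 'fc2.weight'
# entries are tied, i.e. both keys reference the same tensor object.
t_a = torch.ones(2, 2)           # checkpoint A's shared tensor (all 1.0)
t_b = torch.full((2, 2), 3.0)    # checkpoint B's shared tensor (all 3.0)
ckpt_a = {'fc1.weight': t_a, 'fc2.weight': t_a}
ckpt_b = {'fc1.weight': t_b, 'fc2.weight': t_b}

# Accumulate without clone(), as the buggy version did.
params = {}
for ckpt in (ckpt_a, ckpt_b):
    for k, p in ckpt.items():
        if k not in params:
            params[k] = p        # both keys still alias checkpoint A's tensor
        else:
            params[k] += p       # the in-place add hits that one storage twice

print(params['fc1.weight'] / 2)  # all 3.5, not the expected (1.0 + 3.0) / 2 = 2.0
# With params[k] = p.clone() on first sight, each key gets its own
# accumulator and the average comes out as 2.0 for both keys.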
--- a/scripts/average_checkpoints.py
+++ b/scripts/average_checkpoints.py
@@ -27,6 +27,8 @@ def average_checkpoints(inputs):
     params_dict = collections.OrderedDict()
     params_keys = None
     new_state = None
+    num_models = len(inputs)
+
     for f in inputs:
         state = torch.load(
             f,
@@ -50,20 +52,19 @@ def average_checkpoints(inputs):
         )
         for k in params_keys:
-            if k not in params_dict:
-                params_dict[k] = []
             p = model_params[k]
             if isinstance(p, torch.HalfTensor):
                 p = p.float()
-            params_dict[k].append(p)
+            if k not in params_dict:
+                params_dict[k] = p.clone()
+                # NOTE: clone() is needed in case p is a shared parameter
+            else:
+                params_dict[k] += p
 
     averaged_params = collections.OrderedDict()
-    # v should be a list of torch Tensor.
     for k, v in params_dict.items():
-        summed_v = None
-        for x in v:
-            summed_v = summed_v + x if summed_v is not None else x
-        averaged_params[k] = summed_v / len(v)
+        averaged_params[k] = v
+        averaged_params[k].div_(num_models)
     new_state['model'] = averaged_params
     return new_state
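For reference, a minimal usage sketch of the updated function (the checkpoint paths below are hypothetical): average_checkpoints returns the first checkpoint's state with its 'model' entry replaced by the element-wise average over all inputs.

import torch
from scripts.average_checkpoints import average_checkpoints

# Average two saved checkpoints and write the result back to disk.
inputs = ['checkpoints/checkpoint_best.pt', 'checkpoints/checkpoint_last.pt']
new_state = average_checkpoints(inputs)
torch.save(new_state, 'checkpoints/checkpoint_averaged.pt')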
--- a/tests/test_average_checkpoints.py
+++ b/tests/test_average_checkpoints.py
@@ -9,13 +9,33 @@ import collections
 import os
 import tempfile
 import unittest
+import shutil
 
 import numpy as np
 import torch
+from torch import nn
 
 from scripts.average_checkpoints import average_checkpoints
 
 
+class ModelWithSharedParameter(nn.Module):
+    def __init__(self):
+        super(ModelWithSharedParameter, self).__init__()
+        self.embedding = nn.Embedding(1000, 200)
+        self.FC1 = nn.Linear(200, 200)
+        self.FC2 = nn.Linear(200, 200)
+        # tie weight and bias in FC2 to FC1
+        self.FC2.weight = nn.Parameter(self.FC1.weight)
+        self.FC2.bias = nn.Parameter(self.FC1.bias)
+        self.relu = nn.ReLU()
+
+    def forward(self, input):
+        return self.FC2(self.relu(self.FC1(input))) + self.FC1(input)
+
+
 class TestAverageCheckpoints(unittest.TestCase):
     def test_average_checkpoints(self):
         params_0 = collections.OrderedDict(
@@ -67,6 +87,60 @@ class TestAverageCheckpoints(unittest.TestCase):
                 err_msg='Tensor value mismatch for key {}'.format(k_expected)
             )
 
+    def test_average_checkpoints_with_shared_parameters(self):
+        def _construct_model_with_shared_parameters(path, value):
+            m = ModelWithSharedParameter()
+            nn.init.constant_(m.FC1.weight, value)
+            torch.save(
+                {'model': m.state_dict()},
+                path
+            )
+            return m
+
+        tmpdir = tempfile.mkdtemp()
+        paths = []
+        path = os.path.join(tmpdir, "m1.pt")
+        m1 = _construct_model_with_shared_parameters(path, 1.0)
+        paths.append(path)
+
+        path = os.path.join(tmpdir, "m2.pt")
+        m2 = _construct_model_with_shared_parameters(path, 2.0)
+        paths.append(path)
+
+        path = os.path.join(tmpdir, "m3.pt")
+        m3 = _construct_model_with_shared_parameters(path, 3.0)
+        paths.append(path)
+
+        new_model = average_checkpoints(paths)
+        self.assertTrue(
+            torch.equal(
+                new_model['model']['embedding.weight'],
+                (m1.embedding.weight +
+                 m2.embedding.weight +
+                 m3.embedding.weight) / 3.0
+            )
+        )
+
+        self.assertTrue(
+            torch.equal(
+                new_model['model']['FC1.weight'],
+                (m1.FC1.weight +
+                 m2.FC1.weight +
+                 m3.FC1.weight) / 3.0
+            )
+        )
+
+        self.assertTrue(
+            torch.equal(
+                new_model['model']['FC2.weight'],
+                (m1.FC2.weight +
+                 m2.FC2.weight +
+                 m3.FC2.weight) / 3.0
+            )
+        )
+        shutil.rmtree(tmpdir)
+
 
 if __name__ == '__main__':
     unittest.main()
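As a sanity check on the tied-weight setup above (illustrative only, not part of this diff), the following snippet confirms that both state-dict keys of ModelWithSharedParameter expose one underlying storage, which is exactly the case the clone() in average_checkpoints guards against.

# Assumes ModelWithSharedParameter from the test module above.
m = ModelWithSharedParameter()
sd = m.state_dict()
assert sd['FC1.weight'].data_ptr() == sd['FC2.weight'].data_ptr()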