Commit 8ce2c35d authored by Yongqiang Wang, committed by Facebook Github Bot

Implement reducing footprint of average checkpoint correctly (#747)

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/747

In https://github.com/pytorch/fairseq/pull/647, checkpoint averaging
was not implemented correctly when it comes to shared parameters. This diff
has the correct implementation and a test case to guard against future changes.
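Background on the failure mode: PyTorch preserves storage sharing when tensors are saved together, so tied weights come back from torch.load as two state-dict keys aliasing one tensor; accumulating into them in place without clone() then double-counts. A minimal sketch of the pitfall (the keys and shapes are illustrative, not the script's code):

    import torch

    # Illustrative only: two state-dict keys backed by one shared tensor,
    # as happens when tied weights are saved and loaded together.
    w = torch.ones(2, 2)
    state = {'FC1.weight': w, 'FC2.weight': w}

    # Without clone(), both accumulators alias the same storage, so the
    # two in-place additions below hit the same tensor and double-count.
    acc = {k: p for k, p in state.items()}
    acc['FC1.weight'] += torch.ones(2, 2)
    acc['FC2.weight'] += torch.ones(2, 2)
    assert acc['FC1.weight'][0, 0].item() == 3.0  # expected 2.0

    # With clone(), each key accumulates independently and the sum is correct.
    acc = {k: p.clone() for k, p in state.items()}
    acc['FC1.weight'] += torch.ones(2, 2)
    acc['FC2.weight'] += torch.ones(2, 2)
    assert acc['FC1.weight'][0, 0].item() == 2.0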

Reviewed By: myleott

Differential Revision: D15402943

fbshipit-source-id: 8004836d5c2571814ea54844650618008a9ee522
parent 6b0cce84
@@ -27,6 +27,8 @@ def average_checkpoints(inputs):
     params_dict = collections.OrderedDict()
     params_keys = None
     new_state = None
+    num_models = len(inputs)
+
     for f in inputs:
         state = torch.load(
             f,
@@ -50,20 +52,19 @@ def average_checkpoints(inputs):
         )
         for k in params_keys:
-            if k not in params_dict:
-                params_dict[k] = []
             p = model_params[k]
             if isinstance(p, torch.HalfTensor):
                 p = p.float()
-            params_dict[k].append(p)
+            if k not in params_dict:
+                params_dict[k] = p.clone()
+                # NOTE: clone() is needed in case p is a shared parameter
+            else:
+                params_dict[k] += p
     averaged_params = collections.OrderedDict()
-    # v should be a list of torch Tensor.
     for k, v in params_dict.items():
-        summed_v = None
-        for x in v:
-            summed_v = summed_v + x if summed_v is not None else x
-        averaged_params[k] = summed_v / len(v)
+        averaged_params[k] = v
+        averaged_params[k].div_(num_models)
     new_state['model'] = averaged_params
     return new_state
...
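The rewritten loop keeps one running sum per parameter key and divides in place with div_(), so peak memory is a single accumulator per parameter instead of a list holding every checkpoint's copy. A minimal usage sketch of the function above (the checkpoint paths are hypothetical):

    import torch

    from scripts.average_checkpoints import average_checkpoints

    # Average three training checkpoints and save the combined state dict.
    new_state = average_checkpoints(
        ['checkpoint1.pt', 'checkpoint2.pt', 'checkpoint3.pt']
    )
    torch.save(new_state, 'averaged.pt')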
@@ -9,13 +9,33 @@ import collections
 import os
 import tempfile
 import unittest
+import shutil

 import numpy as np
 import torch
+from torch import nn

 from scripts.average_checkpoints import average_checkpoints


+class ModelWithSharedParameter(nn.Module):
+    def __init__(self):
+        super(ModelWithSharedParameter, self).__init__()
+        self.embedding = nn.Embedding(1000, 200)
+        self.FC1 = nn.Linear(200, 200)
+        self.FC2 = nn.Linear(200, 200)
+        # tie weight in FC2 to FC1
+        self.FC2.weight = nn.Parameter(self.FC1.weight)
+        self.FC2.bias = nn.Parameter(self.FC1.bias)
+        self.relu = nn.ReLU()
+
+    def forward(self, input):
+        return self.FC2(self.relu(self.FC1(input))) + self.FC1(input)
+
+
 class TestAverageCheckpoints(unittest.TestCase):
     def test_average_checkpoints(self):
         params_0 = collections.OrderedDict(
@@ -67,6 +87,60 @@ class TestAverageCheckpoints(unittest.TestCase):
             err_msg='Tensor value mismatch for key {}'.format(k_expected)
         )

+    def test_average_checkpoints_with_shared_parameters(self):
+        def _construct_model_with_shared_parameters(path, value):
+            m = ModelWithSharedParameter()
+            nn.init.constant_(m.FC1.weight, value)
+            torch.save(
+                {'model': m.state_dict()},
+                path
+            )
+            return m
+
+        tmpdir = tempfile.mkdtemp()
+        paths = []
+
+        path = os.path.join(tmpdir, "m1.pt")
+        m1 = _construct_model_with_shared_parameters(path, 1.0)
+        paths.append(path)
+
+        path = os.path.join(tmpdir, "m2.pt")
+        m2 = _construct_model_with_shared_parameters(path, 2.0)
+        paths.append(path)
+
+        path = os.path.join(tmpdir, "m3.pt")
+        m3 = _construct_model_with_shared_parameters(path, 3.0)
+        paths.append(path)
+
+        new_model = average_checkpoints(paths)
+        self.assertTrue(
+            torch.equal(
+                new_model['model']['embedding.weight'],
+                (m1.embedding.weight +
+                 m2.embedding.weight +
+                 m3.embedding.weight) / 3.0
+            )
+        )
+
+        self.assertTrue(
+            torch.equal(
+                new_model['model']['FC1.weight'],
+                (m1.FC1.weight +
+                 m2.FC1.weight +
+                 m3.FC1.weight) / 3.0
+            )
+        )
+
+        self.assertTrue(
+            torch.equal(
+                new_model['model']['FC2.weight'],
+                (m1.FC2.weight +
+                 m2.FC2.weight +
+                 m3.FC2.weight) / 3.0
+            )
+        )
+
+        shutil.rmtree(tmpdir)


 if __name__ == '__main__':
     unittest.main()
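Worth noting: in the averaged state dict, 'FC1.weight' and 'FC2.weight' end up as separate, equal-valued tensors (the accumulators were cloned). The tie itself lives in the module, not the dict, so it is re-established when the averaged dict is loaded back with load_state_dict.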