Commit e2a0b87d authored by Myle Ott, committed by Facebook Github Bot

Back out "reduce memory footprint for average_checkpoints" (#743)

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/743

Original commit changeset: 0afe37c9a031

According to edunov: "We need to be careful here with shared parameters, I believe right now it is broken if you have shared encoder/decoder input embeddings (encoder.embed_tokens.weight and decoder.embed_tokens.weight) as they get updated several times"
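The hazard edunov describes can be reproduced with a minimal sketch (hypothetical tensors and keys, not fairseq's actual code): when two state-dict keys alias the same tensor, a running in-place sum adds each subsequent checkpoint into that tensor once per alias, so the shared parameter is "updated several times".

```python
import torch

def make_ckpt(val):
    # Encoder/decoder input embeddings *share* one tensor, as with
    # --share-all-embeddings style models.
    emb = torch.full((2, 2), val)
    return {'encoder.embed_tokens.weight': emb,
            'decoder.embed_tokens.weight': emb}  # same storage, two keys

ckpts = [make_ckpt(1.0), make_ckpt(3.0)]

# Running-sum averaging (the style of the change being backed out):
params = {}
for ck in ckpts:
    for k, p in ck.items():
        if k not in params:
            params[k] = p       # stores a reference to the shared tensor
        else:
            params[k] += p      # in-place: hits the shared tensor once per alias
avg = {k: v / len(ckpts) for k, v in params.items()}

# The average of 1.0 and 3.0 should be 2.0, but the shared tensor absorbed
# the second checkpoint twice (once via each key):
print(avg['encoder.embed_tokens.weight'][0, 0].item())  # → 3.5, not 2.0
```

The same aliasing is harmless when each checkpoint's parameters are only read, which is what the list-based accumulation restored by this revert does.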

We also have OSS issues that look related, e.g., https://github.com/pytorch/fairseq/issues/732.

Backing this out until we can confirm the correct behavior for shared params.

Differential Revision: D15372673

fbshipit-source-id: 8683c0f2514e21fa1e9d2fe6dfc48d98957a2831
parent e797f633
@@ -27,8 +27,6 @@ def average_checkpoints(inputs):
     params_dict = collections.OrderedDict()
     params_keys = None
     new_state = None
-    num_models = len(inputs)
-
     for f in inputs:
         state = torch.load(
             f,
@@ -52,18 +50,20 @@ def average_checkpoints(inputs):
             )
 
         for k in params_keys:
+            if k not in params_dict:
+                params_dict[k] = []
             p = model_params[k]
             if isinstance(p, torch.HalfTensor):
                 p = p.float()
-            if k not in params_dict:
-                params_dict[k] = p
-            else:
-                params_dict[k] += p
+            params_dict[k].append(p)
 
     averaged_params = collections.OrderedDict()
     # v should be a list of torch Tensor.
     for k, v in params_dict.items():
-        averaged_params[k] = v / num_models
+        summed_v = None
+        for x in v:
+            summed_v = summed_v + x if summed_v is not None else x
+        averaged_params[k] = summed_v / len(v)
     new_state['model'] = averaged_params
 
     return new_state
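The averaging approach this revert restores collects every parameter into a list and only sums at the end with out-of-place adds, so shared tensors are read but never mutated. A condensed, runnable sketch of that core (the function name is hypothetical; the real script's `torch.load` calls and checkpoint-key validation are omitted):

```python
import collections
import torch

def average_checkpoint_params(model_states):
    """Average a list of model state dicts; safe for shared (aliased) params."""
    params_dict = collections.OrderedDict()
    for model_params in model_states:
        for k, p in model_params.items():
            if k not in params_dict:
                params_dict[k] = []
            if isinstance(p, torch.HalfTensor):
                p = p.float()
            params_dict[k].append(p)  # only appends references; no in-place math

    averaged_params = collections.OrderedDict()
    for k, v in params_dict.items():
        summed_v = None
        for x in v:
            # out-of-place add: the original (possibly shared) tensors are untouched
            summed_v = summed_v + x if summed_v is not None else x
        averaged_params[k] = summed_v / len(v)
    return averaged_params
```

The cost is holding one copy of every parameter per checkpoint in memory at once, which is exactly the footprint the backed-out change tried to avoid.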