Commit e2a0b87d authored by Myle Ott, committed by Facebook Github Bot

Back out "reduce memory footprint for average_checkpoints" (#743)

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/743

Original commit changeset: 0afe37c9a031

According to edunov: "We need to be careful here with shared parameters, I believe right now it is broken if you have shared encoder/decoder input embeddings (encoder.embed_tokens.weight and decoder.embed_tokens.weight) as they get updated several times"

We also have OSS issues that look related, e.g., https://github.com/pytorch/fairseq/issues/732.

Backing this out until we can confirm the correct behavior for shared params.
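
A minimal sketch of the failure mode described above, using hypothetical two-element tensors rather than real fairseq checkpoints: when encoder.embed_tokens.weight and decoder.embed_tokens.weight are tied, both keys reference the same tensor object, so an accumulation strategy that adds in place (as in the backed-out change) updates that shared storage twice per checkpoint.

import collections
import torch

def inplace_average(checkpoints):
    # Mirrors the backed-out accumulation strategy: keep a running sum per key
    # and divide by the number of checkpoints at the end.
    params_dict = collections.OrderedDict()
    for model_params in checkpoints:
        for k, p in model_params.items():
            if k not in params_dict:
                params_dict[k] = p
            else:
                params_dict[k] += p  # in-place add mutates the aliased tensor
    return collections.OrderedDict(
        (k, v / len(checkpoints)) for k, v in params_dict.items()
    )

# Two hypothetical checkpoints with tied encoder/decoder embeddings:
# in each one, both keys point at the same tensor object.
shared1 = torch.ones(2)
shared2 = torch.full((2,), 3.0)
ckpt1 = {'encoder.embed_tokens.weight': shared1, 'decoder.embed_tokens.weight': shared1}
ckpt2 = {'encoder.embed_tokens.weight': shared2, 'decoder.embed_tokens.weight': shared2}

avg = inplace_average([ckpt1, ckpt2])
# The true average of 1.0 and 3.0 is 2.0, but the shared tensor absorbed
# checkpoint 2 twice (once per key), so both entries come out as 3.5.
print(avg['encoder.embed_tokens.weight'])  # tensor([3.5000, 3.5000])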

Differential Revision: D15372673

fbshipit-source-id: 8683c0f2514e21fa1e9d2fe6dfc48d98957a2831
parent e797f633
@@ -27,8 +27,6 @@ def average_checkpoints(inputs):
     params_dict = collections.OrderedDict()
     params_keys = None
     new_state = None
-    num_models = len(inputs)
     for f in inputs:
         state = torch.load(
             f,
@@ -52,18 +50,20 @@ def average_checkpoints(inputs):
             )
         for k in params_keys:
+            if k not in params_dict:
+                params_dict[k] = []
             p = model_params[k]
             if isinstance(p, torch.HalfTensor):
                 p = p.float()
-            if k not in params_dict:
-                params_dict[k] = p
-            else:
-                params_dict[k] += p
+            params_dict[k].append(p)
     averaged_params = collections.OrderedDict()
+    # v should be a list of torch Tensor.
     for k, v in params_dict.items():
-        averaged_params[k] = v / num_models
+        summed_v = None
+        for x in v:
+            summed_v = summed_v + x if summed_v is not None else x
+        averaged_params[k] = summed_v / len(v)
     new_state['model'] = averaged_params
     return new_state
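
For contrast, a sketch mirroring the restored list-based accumulation shown in the diff above, run on the same hypothetical tied checkpoints: because the sum is built in a fresh tensor rather than in place, the tied source tensors are never mutated and each key averages correctly.

import collections
import torch

def list_average(checkpoints):
    # Mirrors the restored logic: collect every checkpoint's tensor per key,
    # then sum out of place and divide by the number of collected tensors.
    params_dict = collections.OrderedDict()
    for model_params in checkpoints:
        for k, p in model_params.items():
            if k not in params_dict:
                params_dict[k] = []
            params_dict[k].append(p)
    averaged_params = collections.OrderedDict()
    for k, v in params_dict.items():
        summed_v = None
        for x in v:
            summed_v = summed_v + x if summed_v is not None else x  # allocates a new tensor
        averaged_params[k] = summed_v / len(v)
    return averaged_params

# Same hypothetical tied checkpoints as in the sketch above.
shared1 = torch.ones(2)
shared2 = torch.full((2,), 3.0)
ckpts = [
    {'encoder.embed_tokens.weight': shared1, 'decoder.embed_tokens.weight': shared1},
    {'encoder.embed_tokens.weight': shared2, 'decoder.embed_tokens.weight': shared2},
]

print(list_average(ckpts)['encoder.embed_tokens.weight'])  # tensor([2., 2.])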