Commit 39cd4ce2 authored by Liezl Puzon, committed by Facebook Github Bot

Load pretrained encoder or decoder (#705)

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/705

This adds functionality in fairseq to load a pretrained encoder or decoder from another pretrained model into the current model.
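
As a usage sketch (the checkpoint path and the `model` variable below are placeholders for illustration, not part of this diff), the new helper can be called on an already-built model component:

```python
from fairseq import checkpoint_utils

# `model` is assumed to be an encoder-decoder model built elsewhere;
# only the "encoder.*" weights from the checkpoint are copied in.
model.encoder = checkpoint_utils.load_pretrained_component_from_model(
    component=model.encoder,
    checkpoint="/path/to/pretrained_model.pt",  # hypothetical checkpoint path
)
```

Because the load uses `strict=True`, a checkpoint whose encoder/decoder architecture differs from the target component fails loudly instead of silently skipping weights.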

Reviewed By: jmp84

Differential Revision: D15207084

fbshipit-source-id: 32a710ff77389928e20793c71d312863df9dd8ae
parent 7176667d
@@ -6,6 +6,7 @@
# can be found in the PATENTS file in the same directory.
from collections import OrderedDict
from typing import Union
import logging
import os
import re
@@ -15,7 +16,7 @@ import torch
from torch.serialization import default_restore_location
from fairseq import tasks
from fairseq.models import FairseqEncoder, FairseqDecoder
def load_checkpoint_to_cpu(path):
"""Loads a checkpoint to CPU (with upgrading for backward compatibility)."""
@@ -175,3 +176,34 @@ def _upgrade_state_dict(state):
            'iterations_in_epoch': state['extra_state'].get('batch_offset', 0),
        }
    return state
def load_pretrained_component_from_model(
    component: Union[FairseqEncoder, FairseqDecoder], checkpoint: str
):
    """
    Load a pretrained FairseqEncoder or FairseqDecoder from checkpoint into the
    provided `component` object. If state_dict fails to load, there may be a
    mismatch in the architecture of the corresponding `component` found in the
    `checkpoint` file.
    """
    if not os.path.exists(checkpoint):
        raise IOError('Model file not found: {}'.format(checkpoint))
    state = load_checkpoint_to_cpu(checkpoint)
    if isinstance(component, FairseqEncoder):
        component_type = "encoder"
    elif isinstance(component, FairseqDecoder):
        component_type = "decoder"
    else:
        raise ValueError(
            "component to load must be either a FairseqEncoder or "
            "FairseqDecoder. Loading other component types is not supported."
        )
    component_state_dict = OrderedDict()
    for key in state["model"].keys():
        if key.startswith(component_type):
            # encoder.input_layers.0.0.weight --> input_layers.0.0.weight
            component_subkey = key[len(component_type) + 1:]
            component_state_dict[component_subkey] = state["model"][key]
    # strict=True raises if the checkpoint's architecture does not match `component`
    component.load_state_dict(component_state_dict, strict=True)
    return component