Commit 39cd4ce2 authored by Liezl Puzon's avatar Liezl Puzon Committed by Facebook Github Bot
Browse files

Load pretrained encoder or decoder (#705)

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/705

This adds functionality in fairseq to load a pretrained encoder or decoder from another pretrained model into the current model.

Reviewed By: jmp84

Differential Revision: D15207084

fbshipit-source-id: 32a710ff77389928e20793c71d312863df9dd8ae
parent 7176667d
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
# can be found in the PATENTS file in the same directory. # can be found in the PATENTS file in the same directory.
from collections import OrderedDict from collections import OrderedDict
from typing import Union
import logging import logging
import os import os
import re import re
...@@ -15,7 +16,7 @@ import torch ...@@ -15,7 +16,7 @@ import torch
from torch.serialization import default_restore_location from torch.serialization import default_restore_location
from fairseq import tasks from fairseq import tasks
from fairseq.models import FairseqEncoder, FairseqDecoder
def load_checkpoint_to_cpu(path): def load_checkpoint_to_cpu(path):
"""Loads a checkpoint to CPU (with upgrading for backward compatibility).""" """Loads a checkpoint to CPU (with upgrading for backward compatibility)."""
...@@ -175,3 +176,34 @@ def _upgrade_state_dict(state): ...@@ -175,3 +176,34 @@ def _upgrade_state_dict(state):
'iterations_in_epoch': state['extra_state'].get('batch_offset', 0), 'iterations_in_epoch': state['extra_state'].get('batch_offset', 0),
} }
return state return state
def load_pretrained_component_from_model(
    component: Union[FairseqEncoder, FairseqDecoder], checkpoint: str
):
    """
    Load the weights of a pretrained encoder or decoder from `checkpoint`
    into the provided `component` (in place) and return it.

    The checkpoint is expected to hold a full model state dict under its
    "model" key, with component parameters namespaced as "encoder.*" or
    "decoder.*".

    Args:
        component: the FairseqEncoder or FairseqDecoder instance to load into.
        checkpoint: path to a fairseq checkpoint file.

    Returns:
        The same `component` instance with its parameters loaded.

    Raises:
        IOError: if `checkpoint` does not exist.
        ValueError: if `component` is neither a FairseqEncoder nor a
            FairseqDecoder.
        RuntimeError: (propagated from ``load_state_dict(strict=True)``) if
            the checkpoint's architecture for this component does not match
            `component`.
    """
    if not os.path.exists(checkpoint):
        raise IOError('Model file not found: {}'.format(checkpoint))
    state = load_checkpoint_to_cpu(checkpoint)
    if isinstance(component, FairseqEncoder):
        component_type = "encoder"
    elif isinstance(component, FairseqDecoder):
        component_type = "decoder"
    else:
        raise ValueError(
            "component to load must be either a FairseqEncoder or "
            "FairseqDecoder. Loading other component types are not supported."
        )
    # Require the full "<component_type>." prefix (not just the bare word)
    # so an unrelated top-level key such as "decoder_embed_tokens" is not
    # mistaken for a component parameter and mis-sliced.
    prefix = component_type + "."
    component_state_dict = OrderedDict()
    for key, value in state["model"].items():
        if key.startswith(prefix):
            # encoder.input_layers.0.0.weight --> input_layers.0.0.weight
            component_state_dict[key[len(prefix):]] = value
    component.load_state_dict(component_state_dict, strict=True)
    return component
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment