Commit c37250ab authored by Suvrat Bhooshan's avatar Suvrat Bhooshan Committed by Facebook Github Bot
Browse files

Loading PreTrained Models (#406)

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/406

Static helper function in TranslationTask to load pretrained models

Reviewed By: myleott

Differential Revision: D13345276

fbshipit-source-id: 3a675ee1a144ceb8b010f30e1a6163ef670b53f3
parent 00e47d7c
......@@ -9,7 +9,7 @@ import itertools
import numpy as np
import os
from fairseq import options
from fairseq import options, utils
from fairseq.data import (
data_utils, Dictionary, LanguagePairDataset, ConcatDataset,
IndexedRawTextDataset, IndexedCachedDataset, IndexedDataset
......@@ -63,6 +63,24 @@ class TranslationTask(FairseqTask):
help='amount to upsample primary dataset')
# fmt: on
@staticmethod
def load_pretrained_model(path, src_dict_path, tgt_dict_path, arg_overrides=None):
    """Load a pretrained translation model from a checkpoint file.

    Args:
        path: filesystem path of the checkpoint to load.
        src_dict_path: path to the source-language dictionary file.
        tgt_dict_path: path to the target-language dictionary file.
        arg_overrides: optional dict ``{'arg_name': value}`` of model args
            to override in the checkpoint's saved args, or ``None``.

    Returns:
        The model built from the checkpoint args with its weights loaded.
    """
    state = utils.load_checkpoint_to_cpu(path)
    args = state['args']
    state_dict = state['model']
    # BUGFIX: override_model_args iterates the overrides dict, so passing
    # the default ``None`` through unconditionally would raise
    # AttributeError. Only apply overrides when they are provided.
    if arg_overrides is not None:
        args = utils.override_model_args(args, arg_overrides)

    src_dict = Dictionary.load(src_dict_path)
    tgt_dict = Dictionary.load(tgt_dict_path)
    # Both dictionaries must agree on the indices of the special symbols,
    # otherwise the embeddings/criterion would disagree about padding etc.
    assert src_dict.pad() == tgt_dict.pad()
    assert src_dict.eos() == tgt_dict.eos()
    assert src_dict.unk() == tgt_dict.unk()

    task = TranslationTask(args, src_dict, tgt_dict)
    model = task.build_model(args)
    model.upgrade_state_dict(state_dict)
    model.load_state_dict(state_dict, strict=True)
    return model
def __init__(self, args, src_dict, tgt_dict):
super().__init__(args)
self.src_dict = src_dict
......
......@@ -131,6 +131,12 @@ def _upgrade_state_dict(state):
return state
def load_checkpoint_to_cpu(path):
    """Load a checkpoint onto the CPU and upgrade it to the current format.

    Args:
        path: filesystem path of the checkpoint file.

    Returns:
        The checkpoint state dict, with every tensor mapped to the CPU and
        passed through ``_upgrade_state_dict`` for backward compatibility.
    """
    # Force all tensors onto the CPU no matter which device they were
    # saved from (e.g. a GPU checkpoint loaded on a CPU-only machine).
    cpu_state = torch.load(
        path,
        map_location=lambda storage, loc: default_restore_location(storage, 'cpu'),
    )
    return _upgrade_state_dict(cpu_state)
def load_ensemble_for_inference(filenames, task, model_arg_overrides=None):
"""Load an ensemble of models for inference.
......@@ -143,8 +149,7 @@ def load_ensemble_for_inference(filenames, task, model_arg_overrides=None):
for filename in filenames:
if not os.path.exists(filename):
raise IOError('Model file not found: {}'.format(filename))
state = torch.load(filename, map_location=lambda s, l: default_restore_location(s, 'cpu'))
state = _upgrade_state_dict(state)
state = load_checkpoint_to_cpu(filename)
states.append(state)
ensemble = []
......@@ -152,7 +157,7 @@ def load_ensemble_for_inference(filenames, task, model_arg_overrides=None):
args = state['args']
if model_arg_overrides is not None:
args = _override_model_args(args, model_arg_overrides)
args = override_model_args(args, model_arg_overrides)
# build model for ensemble
model = task.build_model(args)
......@@ -162,12 +167,12 @@ def load_ensemble_for_inference(filenames, task, model_arg_overrides=None):
# some args (e.g., tokens_per_sample) might have been updated while building the model
if model_arg_overrides is not None:
args = _override_model_args(args, model_arg_overrides)
args = override_model_args(args, model_arg_overrides)
return ensemble, args
def _override_model_args(args, model_arg_overrides):
def override_model_args(args, model_arg_overrides):
# Uses model_arg_overrides {'arg_name': arg} to override model args
for arg_name, arg_val in model_arg_overrides.items():
setattr(args, arg_name, arg_val)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment