"vscode:/vscode.git/clone" did not exist on "9d4173462999fce5b0b4cf5377e13bc972369a8e"
Commit c37250ab authored by Suvrat Bhooshan, committed by Facebook GitHub Bot

Loading PreTrained Models (#406)

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/406

Add a static helper function, TranslationTask.load_pretrained_model, for loading a pretrained model from a checkpoint path plus source/target dictionary paths (a usage sketch follows the first diff below).

Reviewed By: myleott

Differential Revision: D13345276

fbshipit-source-id: 3a675ee1a144ceb8b010f30e1a6163ef670b53f3
parent 00e47d7c
@@ -9,7 +9,7 @@ import itertools
 import numpy as np
 import os
 
-from fairseq import options
+from fairseq import options, utils
 from fairseq.data import (
     data_utils, Dictionary, LanguagePairDataset, ConcatDataset,
     IndexedRawTextDataset, IndexedCachedDataset, IndexedDataset
@@ -63,6 +63,24 @@ class TranslationTask(FairseqTask):
                             help='amount to upsample primary dataset')
         # fmt: on
 
+    @staticmethod
+    def load_pretrained_model(path, src_dict_path, tgt_dict_path, arg_overrides=None):
+        model = utils.load_checkpoint_to_cpu(path)
+        args = model['args']
+        state_dict = model['model']
+        args = utils.override_model_args(args, arg_overrides)
+        src_dict = Dictionary.load(src_dict_path)
+        tgt_dict = Dictionary.load(tgt_dict_path)
+
+        assert src_dict.pad() == tgt_dict.pad()
+        assert src_dict.eos() == tgt_dict.eos()
+        assert src_dict.unk() == tgt_dict.unk()
+
+        task = TranslationTask(args, src_dict, tgt_dict)
+        model = task.build_model(args)
+        model.upgrade_state_dict(state_dict)
+        model.load_state_dict(state_dict, strict=True)
+        return model
+
     def __init__(self, args, src_dict, tgt_dict):
         super().__init__(args)
         self.src_dict = src_dict
...
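For context, here is a minimal usage sketch of the new helper. The checkpoint and dictionary paths are hypothetical placeholders; note that an empty arg_overrides dict is passed explicitly, because override_model_args (see the utils.py hunks below) iterates over the dict without a None check:

from fairseq.tasks.translation import TranslationTask

# Hypothetical paths; point these at a real checkpoint and its dictionaries.
model = TranslationTask.load_pretrained_model(
    path='checkpoints/checkpoint_best.pt',
    src_dict_path='data-bin/dict.de.txt',
    tgt_dict_path='data-bin/dict.en.txt',
    arg_overrides={},  # pass a dict; the None default would fail in override_model_args
)
model.eval()  # the helper returns a fully built model with weights loaded

Unlike load_ensemble_for_inference, this helper builds the task and dictionaries directly from the given paths instead of requiring a pre-constructed task.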
@@ -131,6 +131,12 @@ def _upgrade_state_dict(state):
     return state
 
 
+def load_checkpoint_to_cpu(path):
+    state = torch.load(path, map_location=lambda s, l: default_restore_location(s, 'cpu'))
+    state = _upgrade_state_dict(state)
+    return state
+
+
 def load_ensemble_for_inference(filenames, task, model_arg_overrides=None):
     """Load an ensemble of models for inference.
@@ -143,8 +149,7 @@ def load_ensemble_for_inference(filenames, task, model_arg_overrides=None):
     for filename in filenames:
         if not os.path.exists(filename):
             raise IOError('Model file not found: {}'.format(filename))
-        state = torch.load(filename, map_location=lambda s, l: default_restore_location(s, 'cpu'))
-        state = _upgrade_state_dict(state)
+        state = load_checkpoint_to_cpu(filename)
         states.append(state)
 
     ensemble = []
@@ -152,7 +157,7 @@ def load_ensemble_for_inference(filenames, task, model_arg_overrides=None):
         args = state['args']
         if model_arg_overrides is not None:
-            args = _override_model_args(args, model_arg_overrides)
+            args = override_model_args(args, model_arg_overrides)
 
         # build model for ensemble
         model = task.build_model(args)
@@ -162,12 +167,12 @@ def load_ensemble_for_inference(filenames, task, model_arg_overrides=None):
     # some args (e.g., tokens_per_sample) might have been updated while building the model
     if model_arg_overrides is not None:
-        args = _override_model_args(args, model_arg_overrides)
+        args = override_model_args(args, model_arg_overrides)
     return ensemble, args
 
 
-def _override_model_args(args, model_arg_overrides):
+def override_model_args(args, model_arg_overrides):
     # Uses model_arg_overrides {'arg_name': arg} to override model args
     for arg_name, arg_val in model_arg_overrides.items():
         setattr(args, arg_name, arg_val)
...
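The utils.py side of the change factors checkpoint loading into a reusable load_checkpoint_to_cpu and promotes the private _override_model_args to the public override_model_args. A minimal sketch of how the two compose, with a hypothetical checkpoint path and a hypothetical override:

from fairseq import utils

# Load a (possibly GPU-trained) checkpoint onto CPU and upgrade any
# legacy fields via _upgrade_state_dict.
state = utils.load_checkpoint_to_cpu('checkpoints/checkpoint_best.pt')
args = state['args']      # the argparse namespace saved at training time
weights = state['model']  # the model parameters (a state dict)

# Hypothetical override: adjust a stored model arg before rebuilding the model.
args = utils.override_model_args(args, {'max_source_positions': 1024})

The map_location lambda routes every tensor through default_restore_location to CPU, so checkpoints trained on GPU load cleanly on CPU-only machines.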