Commit 69d0f7f8 authored by Myle Ott's avatar Myle Ott Committed by Facebook Github Bot
Browse files

Rename _load_model_ensemble -> load_model_ensemble_and_task

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/738

Differential Revision: D16377803

Pulled By: myleott

fbshipit-source-id: 6beb2f78e7464b70ff65a965d2b747cdca0ca951
parent 7efde226
...@@ -153,11 +153,11 @@ def load_model_ensemble(filenames, arg_overrides=None, task=None): ...@@ -153,11 +153,11 @@ def load_model_ensemble(filenames, arg_overrides=None, task=None):
were used during model training were used during model training
task (fairseq.tasks.FairseqTask, optional): task to use for loading task (fairseq.tasks.FairseqTask, optional): task to use for loading
""" """
ensemble, args, _task = _load_model_ensemble(filenames, arg_overrides, task) ensemble, args, _task = load_model_ensemble_and_task(filenames, arg_overrides, task)
return ensemble, args return ensemble, args
def _load_model_ensemble(filenames, arg_overrides=None, task=None): def load_model_ensemble_and_task(filenames, arg_overrides=None, task=None):
from fairseq import tasks from fairseq import tasks
ensemble = [] ensemble = []
......
...@@ -191,7 +191,7 @@ class BaseFairseqModel(nn.Module): ...@@ -191,7 +191,7 @@ class BaseFairseqModel(nn.Module):
if os.path.exists(path): if os.path.exists(path):
kwargs[arg] = path kwargs[arg] = path
models, args, task = checkpoint_utils._load_model_ensemble( models, args, task = checkpoint_utils.load_model_ensemble_and_task(
[os.path.join(model_path, cpt) for cpt in checkpoint_file.split(':')], [os.path.join(model_path, cpt) for cpt in checkpoint_file.split(':')],
arg_overrides=kwargs, arg_overrides=kwargs,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment