Commit 36df0dad authored by Myle Ott's avatar Myle Ott Committed by Facebook Github Bot
Browse files

Fix RoBERTa model import (fixes #918)

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/920

Differential Revision: D16540932

Pulled By: myleott

fbshipit-source-id: b64438ad8651ecc8fe8904c5f69fa6111b4bed64
parent ce7f044b
......@@ -122,8 +122,10 @@ def register_model_architecture(model_name, arch_name):
# automatically import any Python files in the models/ directory
for file in os.listdir(os.path.dirname(__file__)):
if not file.startswith('_'):
models_dir = os.path.dirname(__file__)
for file in os.listdir(models_dir):
path = os.path.join(models_dir, file)
if not file.startswith('_') and (file.endswith('.py') or os.path.isdir(path)):
model_name = file[:file.find('.py')] if file.endswith('.py') else file
module = importlib.import_module('fairseq.models.' + model_name)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment