Commit 746e59a2 authored by Myle Ott's avatar Myle Ott Committed by Facebook Github Bot
Browse files

Improve support for `python setup.py build_ext --inplace`

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/852

Differential Revision: D17147452

Pulled By: myleott

fbshipit-source-id: 5fd9c7da3cc019c7beec98d41db1aef1329ee57a
parent c1951aa2
......@@ -173,11 +173,7 @@ def filter_by_size(indices, dataset, max_positions, raise_exception=False):
if hasattr(dataset, 'sizes') and isinstance(dataset.sizes, np.ndarray):
ignored = indices[dataset.sizes > max_positions].tolist()
indices = indices[dataset.sizes <= max_positions]
elif (
hasattr(dataset, 'sizes') and
isinstance(dataset.sizes, list) and
len(dataset.sizes) == 1
):
elif hasattr(dataset, 'sizes') and isinstance(dataset.sizes, list) and len(dataset.sizes) == 1:
ignored = indices[dataset.sizes[0] > max_positions].tolist()
indices = indices[dataset.sizes[0] <= max_positions]
else:
......@@ -221,7 +217,8 @@ def batch_by_size(
from fairseq.data.data_utils_fast import batch_by_size_fast
except ImportError:
raise ImportError(
'Please build Cython components with: `pip install --editable .`'
'Please build Cython components with: `pip install --editable .` '
'or `python setup.py build_ext --inplace`'
)
max_tokens = max_tokens if max_tokens is not None else sys.maxsize
......
......@@ -49,7 +49,8 @@ class TokenBlockDataset(FairseqDataset):
)
except ImportError:
raise ImportError(
'Please build Cython components with: `pip install --editable .`'
'Please build Cython components with: `pip install --editable .` '
'or `python setup.py build_ext --inplace`'
)
super().__init__()
......
......@@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.
from setuptools import setup, find_packages, Extension
from setuptools.command.build_ext import build_ext
import sys
......@@ -46,20 +47,13 @@ extensions = [
]
def my_build_ext(pars):
"""
Delay loading of numpy headers.
More details: https://stackoverflow.com/a/54138355
"""
from setuptools.command.build_ext import build_ext as _build_ext
class build_ext(_build_ext):
def finalize_options(self):
_build_ext.finalize_options(self)
__builtins__.__NUMPY_SETUP__ = False
import numpy
self.include_dirs.append(numpy.get_include())
return build_ext(pars)
class CustomBuildExtCommand(build_ext):
"""Source: https://stackoverflow.com/a/42163080"""
def run(self):
# Import numpy here, only when headers are needed
import numpy
self.include_dirs.append(numpy.get_include())
super().run()
setup(
......@@ -105,6 +99,6 @@ setup(
'fairseq-validate = fairseq_cli.validate:cli_main',
],
},
cmdclass={'build_ext': my_build_ext},
cmdclass={'build_ext': CustomBuildExtCommand},
zip_safe=False,
)
......@@ -9,9 +9,9 @@ Train a new model on one or across multiple GPUs.
import collections
import math
import numpy as np
import random
import numpy as np
import torch
from fairseq import checkpoint_utils, distributed_utils, options, progress_bar, tasks, utils
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment