Commit 746e59a2 authored by Myle Ott's avatar Myle Ott Committed by Facebook Github Bot
Browse files

Improve support for `python setup.py build_ext --inplace`

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/852

Differential Revision: D17147452

Pulled By: myleott

fbshipit-source-id: 5fd9c7da3cc019c7beec98d41db1aef1329ee57a
parent c1951aa2
...@@ -173,11 +173,7 @@ def filter_by_size(indices, dataset, max_positions, raise_exception=False): ...@@ -173,11 +173,7 @@ def filter_by_size(indices, dataset, max_positions, raise_exception=False):
if hasattr(dataset, 'sizes') and isinstance(dataset.sizes, np.ndarray): if hasattr(dataset, 'sizes') and isinstance(dataset.sizes, np.ndarray):
ignored = indices[dataset.sizes > max_positions].tolist() ignored = indices[dataset.sizes > max_positions].tolist()
indices = indices[dataset.sizes <= max_positions] indices = indices[dataset.sizes <= max_positions]
elif ( elif hasattr(dataset, 'sizes') and isinstance(dataset.sizes, list) and len(dataset.sizes) == 1:
hasattr(dataset, 'sizes') and
isinstance(dataset.sizes, list) and
len(dataset.sizes) == 1
):
ignored = indices[dataset.sizes[0] > max_positions].tolist() ignored = indices[dataset.sizes[0] > max_positions].tolist()
indices = indices[dataset.sizes[0] <= max_positions] indices = indices[dataset.sizes[0] <= max_positions]
else: else:
...@@ -221,7 +217,8 @@ def batch_by_size( ...@@ -221,7 +217,8 @@ def batch_by_size(
from fairseq.data.data_utils_fast import batch_by_size_fast from fairseq.data.data_utils_fast import batch_by_size_fast
except ImportError: except ImportError:
raise ImportError( raise ImportError(
'Please build Cython components with: `pip install --editable .`' 'Please build Cython components with: `pip install --editable .` '
'or `python setup.py build_ext --inplace`'
) )
max_tokens = max_tokens if max_tokens is not None else sys.maxsize max_tokens = max_tokens if max_tokens is not None else sys.maxsize
......
...@@ -49,7 +49,8 @@ class TokenBlockDataset(FairseqDataset): ...@@ -49,7 +49,8 @@ class TokenBlockDataset(FairseqDataset):
) )
except ImportError: except ImportError:
raise ImportError( raise ImportError(
'Please build Cython components with: `pip install --editable .`' 'Please build Cython components with: `pip install --editable .` '
'or `python setup.py build_ext --inplace`'
) )
super().__init__() super().__init__()
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from setuptools import setup, find_packages, Extension from setuptools import setup, find_packages, Extension
from setuptools.command.build_ext import build_ext
import sys import sys
...@@ -46,20 +47,13 @@ extensions = [ ...@@ -46,20 +47,13 @@ extensions = [
] ]
def my_build_ext(pars): class CustomBuildExtCommand(build_ext):
""" """Source: https://stackoverflow.com/a/42163080"""
Delay loading of numpy headers. def run(self):
More details: https://stackoverflow.com/a/54138355 # Import numpy here, only when headers are needed
""" import numpy
from setuptools.command.build_ext import build_ext as _build_ext self.include_dirs.append(numpy.get_include())
super().run()
class build_ext(_build_ext):
def finalize_options(self):
_build_ext.finalize_options(self)
__builtins__.__NUMPY_SETUP__ = False
import numpy
self.include_dirs.append(numpy.get_include())
return build_ext(pars)
setup( setup(
...@@ -105,6 +99,6 @@ setup( ...@@ -105,6 +99,6 @@ setup(
'fairseq-validate = fairseq_cli.validate:cli_main', 'fairseq-validate = fairseq_cli.validate:cli_main',
], ],
}, },
cmdclass={'build_ext': my_build_ext}, cmdclass={'build_ext': CustomBuildExtCommand},
zip_safe=False, zip_safe=False,
) )
...@@ -9,9 +9,9 @@ Train a new model on one or across multiple GPUs. ...@@ -9,9 +9,9 @@ Train a new model on one or across multiple GPUs.
import collections import collections
import math import math
import numpy as np
import random import random
import numpy as np
import torch import torch
from fairseq import checkpoint_utils, distributed_utils, options, progress_bar, tasks, utils from fairseq import checkpoint_utils, distributed_utils, options, progress_bar, tasks, utils
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment