Commit d2410c42 authored by Myle Ott's avatar Myle Ott Committed by Facebook Github Bot
Browse files

Minor cleanup for setup.py

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/1078

Differential Revision: D17072514

Pulled By: myleott

fbshipit-source-id: 69a8c8c9cc7caa7e04c414329a5d79e6e1a6621c
parent 920b85d4
...@@ -10,11 +10,11 @@ except ImportError: ...@@ -10,11 +10,11 @@ except ImportError:
import contextlib import contextlib
import itertools import itertools
import os import os
import numpy as np
import sys import sys
import types import types
import numpy as np
def infer_language_pair(path): def infer_language_pair(path):
"""Infer language pair from filename: <split>.<lang1>-<lang2>.(...).idx""" """Infer language pair from filename: <split>.<lang1>-<lang2>.(...).idx"""
...@@ -204,12 +204,14 @@ def batch_by_size( ...@@ -204,12 +204,14 @@ def batch_by_size(
raise ImportError( raise ImportError(
'Please build Cython components with: `pip install --editable .`' 'Please build Cython components with: `pip install --editable .`'
) )
max_tokens = max_tokens if max_tokens is not None else sys.maxsize max_tokens = max_tokens if max_tokens is not None else sys.maxsize
max_sentences = max_sentences if max_sentences is not None else sys.maxsize max_sentences = max_sentences if max_sentences is not None else sys.maxsize
bsz_mult = required_batch_size_multiple bsz_mult = required_batch_size_multiple
if isinstance(indices, types.GeneratorType): if isinstance(indices, types.GeneratorType):
indices = np.fromiter(indices, dtype=np.int64, count=-1) indices = np.fromiter(indices, dtype=np.int64, count=-1)
return batch_by_size_fast(indices, num_tokens_fn, max_tokens, max_sentences, bsz_mult) return batch_by_size_fast(indices, num_tokens_fn, max_tokens, max_sentences, bsz_mult)
......
...@@ -11,6 +11,7 @@ from fairseq.models import MODEL_REGISTRY ...@@ -11,6 +11,7 @@ from fairseq.models import MODEL_REGISTRY
dependencies = [ dependencies = [
'numpy',
'regex', 'regex',
'requests', 'requests',
'torch', 'torch',
......
...@@ -11,47 +11,45 @@ import sys ...@@ -11,47 +11,45 @@ import sys
if sys.version_info < (3,): if sys.version_info < (3,):
sys.exit('Sorry, Python3 is required for fairseq.') sys.exit('Sorry, Python3 is required for fairseq.')
with open('README.md') as f: with open('README.md') as f:
readme = f.read() readme = f.read()
if sys.platform == 'darwin': if sys.platform == 'darwin':
extra_compile_args = ['-stdlib=libc++', '-O3'] extra_compile_args = ['-stdlib=libc++', '-O3']
extra_link_args = ['-stdlib=libc++']
else: else:
extra_compile_args = ['-std=c++11', '-O3'] extra_compile_args = ['-std=c++11', '-O3']
extra_link_args = ['-std=c++11']
bleu = Extension(
'fairseq.libbleu',
sources=[
'fairseq/clib/libbleu/libbleu.cpp',
'fairseq/clib/libbleu/module.cpp',
],
extra_compile_args=extra_compile_args,
)
def get_cython_modules(): extensions = [
token_block_utils = Extension( Extension(
"fairseq.data.token_block_utils_fast", 'fairseq.libbleu',
["fairseq/data/token_block_utils_fast.pyx"], sources=[
'fairseq/clib/libbleu/libbleu.cpp',
'fairseq/clib/libbleu/module.cpp',
],
extra_compile_args=extra_compile_args, extra_compile_args=extra_compile_args,
extra_link_args=extra_link_args, ),
) Extension(
data_utils_fast = Extension( 'fairseq.data.data_utils_fast',
"fairseq.data.data_utils_fast", sources=['fairseq/data/data_utils_fast.pyx'],
["fairseq/data/data_utils_fast.pyx"], language='c++',
language="c++",
extra_compile_args=extra_compile_args, extra_compile_args=extra_compile_args,
extra_link_args=extra_link_args, ),
) Extension(
return [token_block_utils, data_utils_fast] 'fairseq.data.token_block_utils_fast',
sources=['fairseq/data/token_block_utils_fast.pyx'],
language='c++',
extra_compile_args=extra_compile_args,
),
]
def my_build_ext(pars): def my_build_ext(pars):
""" """
Delay loading of numpy headers. Delay loading of numpy headers.
More details: https://stackoverflow.com/questions/54117786/add-numpy-get-include-argument-to-setuptools-without-preinstalled-numpy More details: https://stackoverflow.com/a/54138355
""" """
from setuptools.command.build_ext import build_ext as _build_ext from setuptools.command.build_ext import build_ext as _build_ext
...@@ -81,6 +79,7 @@ setup( ...@@ -81,6 +79,7 @@ setup(
setup_requires=[ setup_requires=[
'numpy', 'numpy',
'cython', 'cython',
'numpy',
'setuptools>=18.0', 'setuptools>=18.0',
], ],
install_requires=[ install_requires=[
...@@ -93,7 +92,7 @@ setup( ...@@ -93,7 +92,7 @@ setup(
'tqdm', 'tqdm',
], ],
packages=find_packages(exclude=['scripts', 'tests']), packages=find_packages(exclude=['scripts', 'tests']),
ext_modules=get_cython_modules() + [bleu], ext_modules=extensions,
test_suite='tests', test_suite='tests',
entry_points={ entry_points={
'console_scripts': [ 'console_scripts': [
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment