Unverified Commit fea4662b authored by robotcator's avatar robotcator Committed by GitHub
Browse files

add --disable_extension flags to disable cuda extension (#21)

* add setup flag

* use sys argv rather than env variable

* add readme

* refine  code & readme

* refine
parent a44b9c68
...@@ -20,6 +20,7 @@ Installation ...@@ -20,6 +20,7 @@ Installation
You can use `python setup.py install` or `pip install .` to build Uni-Core from source. The CUDA version in the build environment should be the same as the one in PyTorch. You can use `python setup.py install` or `pip install .` to build Uni-Core from source. The CUDA version in the build environment should be the same as the one in PyTorch.
You can also use `python setup.py install --disable-cuda-ext` to disable the CUDA extension operators when CUDA is not available.
**Use pre-compiled python wheels** **Use pre-compiled python wheels**
......
...@@ -14,6 +14,16 @@ import sys ...@@ -14,6 +14,16 @@ import sys
from setuptools import find_packages, setup from setuptools import find_packages, setup
# Detect and strip the custom '--disable-cuda-ext' flag before setuptools
# parses sys.argv; setuptools would otherwise reject the unknown option.
# Fix: the original loop used enumerate() but never used the index — a plain
# membership test plus a comprehension expresses the same filtering directly.
DISABLE_CUDA_EXTENSION = '--disable-cuda-ext' in sys.argv
sys.argv = [arg for arg in sys.argv if arg != '--disable-cuda-ext']
if sys.version_info < (3, 7): if sys.version_info < (3, 7):
sys.exit("Sorry, Python >= 3.7 is required for unicore.") sys.exit("Sorry, Python >= 3.7 is required for unicore.")
...@@ -31,7 +41,7 @@ def write_version_py(): ...@@ -31,7 +41,7 @@ def write_version_py():
version = write_version_py() version = write_version_py()
# ninja build does not work unless include_dirs are abs path # # ninja build does not work unless include_dirs are abs path
this_dir = os.path.dirname(os.path.abspath(__file__)) this_dir = os.path.dirname(os.path.abspath(__file__))
def get_cuda_bare_metal_version(cuda_dir): def get_cuda_bare_metal_version(cuda_dir):
...@@ -44,7 +54,7 @@ def get_cuda_bare_metal_version(cuda_dir): ...@@ -44,7 +54,7 @@ def get_cuda_bare_metal_version(cuda_dir):
return raw_output, bare_metal_major, bare_metal_minor return raw_output, bare_metal_major, bare_metal_minor
if not torch.cuda.is_available(): if not torch.cuda.is_available() and not DISABLE_CUDA_EXTENSION:
print('\nWarning: Torch did not find available GPUs on this system.\n', print('\nWarning: Torch did not find available GPUs on this system.\n',
'If your intention is to cross-compile, this is not an error.\n' 'If your intention is to cross-compile, this is not an error.\n'
'By default, it will cross-compile for Volta (compute capability 7.0), Turing (compute capability 7.5),\n' 'By default, it will cross-compile for Volta (compute capability 7.0), Turing (compute capability 7.5),\n'
...@@ -71,8 +81,8 @@ ext_modules = [] ...@@ -71,8 +81,8 @@ ext_modules = []
extras = {} extras = {}
if not DISABLE_CUDA_EXTENSION:
def get_cuda_bare_metal_version(cuda_dir): def get_cuda_bare_metal_version(cuda_dir):
raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
output = raw_output.split() output = raw_output.split()
release_idx = output.index("release") + 1 release_idx = output.index("release") + 1
...@@ -82,7 +92,7 @@ def get_cuda_bare_metal_version(cuda_dir): ...@@ -82,7 +92,7 @@ def get_cuda_bare_metal_version(cuda_dir):
return raw_output, bare_metal_major, bare_metal_minor return raw_output, bare_metal_major, bare_metal_minor
def check_cuda_torch_binary_vs_bare_metal(cuda_dir): def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir) raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
torch_binary_major = torch.version.cuda.split(".")[0] torch_binary_major = torch.version.cuda.split(".")[0]
torch_binary_minor = torch.version.cuda.split(".")[1] torch_binary_minor = torch.version.cuda.split(".")[1]
...@@ -97,20 +107,20 @@ def check_cuda_torch_binary_vs_bare_metal(cuda_dir): ...@@ -97,20 +107,20 @@ def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
cmdclass['build_ext'] = BuildExtension cmdclass['build_ext'] = BuildExtension
if torch.utils.cpp_extension.CUDA_HOME is None: if torch.utils.cpp_extension.CUDA_HOME is None:
raise RuntimeError("Nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.") raise RuntimeError("Nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
check_cuda_torch_binary_vs_bare_metal(torch.utils.cpp_extension.CUDA_HOME) # check_cuda_torch_binary_vs_bare_metal(torch.utils.cpp_extension.CUDA_HOME)
generator_flag = [] generator_flag = []
torch_dir = torch.__path__[0] torch_dir = torch.__path__[0]
if os.path.exists(os.path.join(torch_dir, 'include', 'ATen', 'CUDAGenerator.h')): if os.path.exists(os.path.join(torch_dir, 'include', 'ATen', 'CUDAGenerator.h')):
generator_flag = ['-DOLD_GENERATOR'] generator_flag = ['-DOLD_GENERATOR']
ext_modules.append( ext_modules.append(
CUDAExtension(name='unicore_fused_rounding', CUDAExtension(name='unicore_fused_rounding',
sources=['csrc/rounding/interface.cpp', sources=['csrc/rounding/interface.cpp',
'csrc/rounding/fp32_to_bf16.cu'], 'csrc/rounding/fp32_to_bf16.cu'],
include_dirs=[os.path.join(this_dir, 'csrc')], include_dirs=[os.path.join(this_dir, 'csrc')],
...@@ -125,7 +135,7 @@ CUDAExtension(name='unicore_fused_rounding', ...@@ -125,7 +135,7 @@ CUDAExtension(name='unicore_fused_rounding',
'--expt-relaxed-constexpr', '--expt-relaxed-constexpr',
'--expt-extended-lambda'] + generator_flag})) '--expt-extended-lambda'] + generator_flag}))
ext_modules.append( ext_modules.append(
CUDAExtension(name='unicore_fused_multi_tensor', CUDAExtension(name='unicore_fused_multi_tensor',
sources=['csrc/multi_tensor/interface.cpp', sources=['csrc/multi_tensor/interface.cpp',
'csrc/multi_tensor/multi_tensor_l2norm_kernel.cu'], 'csrc/multi_tensor/multi_tensor_l2norm_kernel.cu'],
...@@ -143,7 +153,7 @@ ext_modules.append( ...@@ -143,7 +153,7 @@ ext_modules.append(
})) }))
ext_modules.append( ext_modules.append(
CUDAExtension(name='unicore_fused_adam', CUDAExtension(name='unicore_fused_adam',
sources=['csrc/adam/interface.cpp', sources=['csrc/adam/interface.cpp',
'csrc/adam/adam_kernel.cu'], 'csrc/adam/adam_kernel.cu'],
...@@ -151,7 +161,7 @@ ext_modules.append( ...@@ -151,7 +161,7 @@ ext_modules.append(
extra_compile_args={'cxx': ['-O3'], extra_compile_args={'cxx': ['-O3'],
'nvcc':['-O3', '--use_fast_math']})) 'nvcc':['-O3', '--use_fast_math']}))
ext_modules.append( ext_modules.append(
CUDAExtension(name='unicore_fused_softmax_dropout', CUDAExtension(name='unicore_fused_softmax_dropout',
sources=['csrc/softmax_dropout/interface.cpp', sources=['csrc/softmax_dropout/interface.cpp',
'csrc/softmax_dropout/softmax_dropout_kernel.cu'], 'csrc/softmax_dropout/softmax_dropout_kernel.cu'],
...@@ -168,7 +178,7 @@ ext_modules.append( ...@@ -168,7 +178,7 @@ ext_modules.append(
'--expt-extended-lambda'] + generator_flag})) '--expt-extended-lambda'] + generator_flag}))
ext_modules.append( ext_modules.append(
CUDAExtension(name='unicore_fused_layernorm', CUDAExtension(name='unicore_fused_layernorm',
sources=['csrc/layernorm/interface.cpp', sources=['csrc/layernorm/interface.cpp',
'csrc/layernorm/layernorm.cu'], 'csrc/layernorm/layernorm.cu'],
...@@ -185,7 +195,7 @@ ext_modules.append( ...@@ -185,7 +195,7 @@ ext_modules.append(
'--expt-extended-lambda'] + generator_flag})) '--expt-extended-lambda'] + generator_flag}))
ext_modules.append( ext_modules.append(
CUDAExtension(name='unicore_fused_layernorm_backward_gamma_beta', CUDAExtension(name='unicore_fused_layernorm_backward_gamma_beta',
sources=['csrc/layernorm/interface_gamma_beta.cpp', sources=['csrc/layernorm/interface_gamma_beta.cpp',
'csrc/layernorm/layernorm_backward.cu'], 'csrc/layernorm/layernorm_backward.cu'],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment