Commit 203e3231 authored by Hubert Lu's avatar Hubert Lu
Browse files

scaled_upper_triang_masked_softmax_cuda and scaled_masked_softmax_cuda in --cuda_ext are skipped

parent 8091b3e2
...@@ -85,7 +85,7 @@ if "--cpp_ext" in sys.argv or "--cuda_ext" in sys.argv: ...@@ -85,7 +85,7 @@ if "--cpp_ext" in sys.argv or "--cuda_ext" in sys.argv:
if TORCH_MAJOR == 0: if TORCH_MAJOR == 0:
raise RuntimeError("--cpp_ext requires Pytorch 1.0 or later, " raise RuntimeError("--cpp_ext requires Pytorch 1.0 or later, "
"found torch.__version__ = {}".format(torch.__version__)) "found torch.__version__ = {}".format(torch.__version__))
cmdclass['build_ext'] = BuildExtension
if "--cpp_ext" in sys.argv: if "--cpp_ext" in sys.argv:
sys.argv.remove("--cpp_ext") sys.argv.remove("--cpp_ext")
ext_modules.append( ext_modules.append(
...@@ -233,7 +233,7 @@ if "--cuda_ext" in sys.argv: ...@@ -233,7 +233,7 @@ if "--cuda_ext" in sys.argv:
'csrc/fused_dense_cuda.cu'], 'csrc/fused_dense_cuda.cu'],
extra_compile_args={'cxx': ['-O3'] + version_dependent_macros, extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc':['-O3'] + version_dependent_macros})) 'nvcc':['-O3'] + version_dependent_macros}))
"""
ext_modules.append( ext_modules.append(
CUDAExtension(name='scaled_upper_triang_masked_softmax_cuda', CUDAExtension(name='scaled_upper_triang_masked_softmax_cuda',
sources=['csrc/megatron/scaled_upper_triang_masked_softmax.cpp', sources=['csrc/megatron/scaled_upper_triang_masked_softmax.cpp',
...@@ -257,6 +257,7 @@ if "--cuda_ext" in sys.argv: ...@@ -257,6 +257,7 @@ if "--cuda_ext" in sys.argv:
'-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr', '--expt-relaxed-constexpr',
'--expt-extended-lambda'] + version_dependent_macros})) '--expt-extended-lambda'] + version_dependent_macros}))
"""
if "--bnp" in sys.argv: if "--bnp" in sys.argv:
sys.argv.remove("--bnp") sys.argv.remove("--bnp")
...@@ -580,6 +581,7 @@ setup( ...@@ -580,6 +581,7 @@ setup(
'apex.egg-info',)), 'apex.egg-info',)),
description='PyTorch Extensions written by NVIDIA', description='PyTorch Extensions written by NVIDIA',
ext_modules=ext_modules, ext_modules=ext_modules,
cmdclass={'build_ext': BuildExtension} if ext_modules else {}, cmdclass=cmdclass,
#cmdclass={'build_ext': BuildExtension} if ext_modules else {},
extras_require=extras, extras_require=extras,
) )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment