Commit 45388d48 authored by Thor Johnsen's avatar Thor Johnsen
Browse files

Make separate apex option for distributed lamb

parent 848a2844
...@@ -96,6 +96,25 @@ if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 4): ...@@ -96,6 +96,25 @@ if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 4):
version_ge_1_5 = ['-DVERSION_GE_1_5'] version_ge_1_5 = ['-DVERSION_GE_1_5']
version_dependent_macros = version_ge_1_1 + version_ge_1_3 + version_ge_1_5 version_dependent_macros = version_ge_1_1 + version_ge_1_3 + version_ge_1_5
if "--distributed_lamb" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension
sys.argv.remove("--distributed_lamb")
from torch.utils.cpp_extension import BuildExtension
cmdclass['build_ext'] = BuildExtension
if torch.utils.cpp_extension.CUDA_HOME is None:
raise RuntimeError("--distributed_lamb was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else:
ext_modules.append(
CUDAExtension(name='distributed_lamb_cuda',
sources=['apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp',
'apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb_kernel.cu'],
include_dirs=[os.path.join(this_dir, 'csrc')],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
'nvcc':['-O3',
'--use_fast_math'] + version_dependent_macros}))
if "--cuda_ext" in sys.argv: if "--cuda_ext" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension from torch.utils.cpp_extension import CUDAExtension
sys.argv.remove("--cuda_ext") sys.argv.remove("--cuda_ext")
...@@ -220,8 +239,6 @@ if "--deprecated_fused_lamb" in sys.argv: ...@@ -220,8 +239,6 @@ if "--deprecated_fused_lamb" in sys.argv:
CUDAExtension(name='fused_lamb_cuda', CUDAExtension(name='fused_lamb_cuda',
sources=['apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp', sources=['apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp',
'apex/contrib/csrc/optimizers/fused_lamb_cuda_kernel.cu', 'apex/contrib/csrc/optimizers/fused_lamb_cuda_kernel.cu',
'apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp',
'apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb_kernel.cu',
'csrc/multi_tensor_l2norm_kernel.cu'], 'csrc/multi_tensor_l2norm_kernel.cu'],
include_dirs=[os.path.join(this_dir, 'csrc')], include_dirs=[os.path.join(this_dir, 'csrc')],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros, extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment