Unverified Commit 02ada95d authored by Jithun Nair, committed by GitHub

Merge pull request #52 from ROCmSoftwarePlatform/add_distributed_fused_lamb

add distributed fused lamb
parents 95797c8d 955256d1
@@ -163,17 +163,19 @@ if "--distributed_lamb" in sys.argv:
     from torch.utils.cpp_extension import BuildExtension
     cmdclass['build_ext'] = BuildExtension
-    if torch.utils.cpp_extension.CUDA_HOME is None:
+    if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
         raise RuntimeError("--distributed_lamb was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
     else:
+        print ("INFO: Building the distributed_lamb extension.")
+        nvcc_args_distributed_lamb = ['-O3', '--use_fast_math'] + version_dependent_macros
+        hipcc_args_distributed_lamb = ['-O3'] + version_dependent_macros
         ext_modules.append(
             CUDAExtension(name='distributed_lamb_cuda',
                           sources=['apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp',
                                    'apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb_kernel.cu'],
                           include_dirs=[os.path.join(this_dir, 'csrc')],
                           extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
-                                              'nvcc':['-O3',
-                                                      '--use_fast_math'] + version_dependent_macros}))
+                                              'nvcc': nvcc_args_distributed_lamb if not IS_ROCM_PYTORCH else hipcc_args_distributed_lamb}))
 
 if "--cuda_ext" in sys.argv:
     from torch.utils.cpp_extension import CUDAExtension
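
For reference, below is a minimal sketch of the ROCm/CUDA switch this hunk relies on. IS_ROCM_PYTORCH and version_dependent_macros are defined earlier in apex's setup.py and are not part of this diff; the detection shown here (torch.version.hip plus ROCM_HOME) follows the usual pattern and is an assumption, not code from this commit. The hipcc argument list drops --use_fast_math, which is an nvcc-specific flag.

# Minimal sketch (assumption, not part of this commit) of how the
# IS_ROCM_PYTORCH flag used in the hunk above is typically derived,
# and how the per-compiler argument lists are then selected.
import torch

try:
    # ROCM_HOME is only meaningful on ROCm builds of PyTorch; older
    # CUDA builds may not expose it at all.
    from torch.utils.cpp_extension import ROCM_HOME
except ImportError:
    ROCM_HOME = None

# True when PyTorch was built against HIP/ROCm and a ROCm install is present.
IS_ROCM_PYTORCH = (torch.version.hip is not None) and (ROCM_HOME is not None)

# Placeholder: in apex's setup.py this list carries -D defines derived from
# the installed PyTorch version.
version_dependent_macros = []

# Same argument lists as in the diff: the ROCm path omits --use_fast_math,
# which nvcc accepts but hipcc does not need in this form.
nvcc_args_distributed_lamb = ['-O3', '--use_fast_math'] + version_dependent_macros
hipcc_args_distributed_lamb = ['-O3'] + version_dependent_macros

compiler_args = hipcc_args_distributed_lamb if IS_ROCM_PYTORCH else nvcc_args_distributed_lamb
print("distributed_lamb compiler args:", compiler_args)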