Commit a3ffb8a7 authored by Kexin Yu's avatar Kexin Yu
Browse files

add l2norm source for FusedLAMB

parent 04927b3a
...@@ -211,7 +211,8 @@ if "--deprecated_fused_lamb" in sys.argv: ...@@ -211,7 +211,8 @@ if "--deprecated_fused_lamb" in sys.argv:
ext_modules.append( ext_modules.append(
CUDAExtension(name='fused_lamb_cuda', CUDAExtension(name='fused_lamb_cuda',
sources=['apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp', sources=['apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp',
'apex/contrib/csrc/optimizers/fused_lamb_cuda_kernel.cu'], 'apex/contrib/csrc/optimizers/fused_lamb_cuda_kernel.cu',
'csrc/multi_tensor_l2norm_kernel.cu'],
include_dirs=[os.path.join(this_dir, 'csrc')], include_dirs=[os.path.join(this_dir, 'csrc')],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros, extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
'nvcc':['-O3', 'nvcc':['-O3',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment