Unverified commit 47c269b6, authored by Masaki Kozuki, committed by GitHub
Browse files

build fused grad accum w/ wgrad only if cuda>10 (#1312)

parent ddc08039
...@@ -298,42 +298,42 @@ if "--cuda_ext" in sys.argv: ...@@ -298,42 +298,42 @@ if "--cuda_ext" in sys.argv:
) )
# Check, if CUDA11 is installed for compute capability 8.0 # Check, if CUDA11 is installed for compute capability 8.0
cc_flag = []
_, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME) _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) >= 11: if int(bare_metal_major) >= 11:
cc_flag = []
cc_flag.append("-gencode") cc_flag.append("-gencode")
cc_flag.append("arch=compute_80,code=sm_80") cc_flag.append("arch=compute_80,code=sm_80")
if int(bare_metal_minor) > 0: if int(bare_metal_minor) > 0:
cc_flag.append("-gencode") cc_flag.append("-gencode")
cc_flag.append("arch=compute_86,code=sm_86") cc_flag.append("arch=compute_86,code=sm_86")
ext_modules.append( ext_modules.append(
CUDAExtension( CUDAExtension(
name="fused_weight_gradient_mlp_cuda", name="fused_weight_gradient_mlp_cuda",
include_dirs=[os.path.join(this_dir, "csrc")], include_dirs=[os.path.join(this_dir, "csrc")],
sources=[ sources=[
"csrc/megatron/fused_weight_gradient_dense.cpp", "csrc/megatron/fused_weight_gradient_dense.cpp",
"csrc/megatron/fused_weight_gradient_dense_cuda.cu", "csrc/megatron/fused_weight_gradient_dense_cuda.cu",
"csrc/megatron/fused_weight_gradient_dense_16bit_prec_cuda.cu", "csrc/megatron/fused_weight_gradient_dense_16bit_prec_cuda.cu",
], ],
extra_compile_args={ extra_compile_args={
"cxx": ["-O3"] + version_dependent_macros, "cxx": ["-O3"] + version_dependent_macros,
"nvcc": append_nvcc_threads( "nvcc": append_nvcc_threads(
[ [
"-O3", "-O3",
"-gencode", "-gencode",
"arch=compute_70,code=sm_70", "arch=compute_70,code=sm_70",
"-U__CUDA_NO_HALF_OPERATORS__", "-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__", "-U__CUDA_NO_HALF_CONVERSIONS__",
"--expt-relaxed-constexpr", "--expt-relaxed-constexpr",
"--expt-extended-lambda", "--expt-extended-lambda",
"--use_fast_math", "--use_fast_math",
] ]
+ version_dependent_macros + version_dependent_macros
+ cc_flag + cc_flag
), ),
}, },
)
) )
)
if "--permutation_search" in sys.argv: if "--permutation_search" in sys.argv:
sys.argv.remove("--permutation_search") sys.argv.remove("--permutation_search")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment