Unverified Commit d89f5e66 authored by Masaki Kozuki's avatar Masaki Kozuki Committed by GitHub
Browse files

[submodule update] Bump cudnn-frontend to v0.6.1 (#1353)

* bump version

* add guard

* fix the cond
parent 727a6452
Subproject commit 7b83dba83fa31381aeca508d89aab94f4639ac6d Subproject commit fa611998a360cbabaa2dcc7c9859748144114fc0
...@@ -58,6 +58,13 @@ def append_nvcc_threads(nvcc_extra_args): ...@@ -58,6 +58,13 @@ def append_nvcc_threads(nvcc_extra_args):
return nvcc_extra_args return nvcc_extra_args
def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int) -> bool:
green = torch.backends.cudnn.is_available() and torch.backends.cudnn.version() >= required_cudnn_version
if not green:
warnings.warn(f"Skip `{global_option}` as it requires cuDNN {required_cudnn_version} or later")
return green
if not torch.cuda.is_available(): if not torch.cuda.is_available():
# https://github.com/NVIDIA/apex/issues/486 # https://github.com/NVIDIA/apex/issues/486
# Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(), # Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(),
...@@ -649,15 +656,16 @@ if "--fast_bottleneck" in sys.argv: ...@@ -649,15 +656,16 @@ if "--fast_bottleneck" in sys.argv:
if "--fused_conv_bias_relu" in sys.argv: if "--fused_conv_bias_relu" in sys.argv:
sys.argv.remove("--fused_conv_bias_relu") sys.argv.remove("--fused_conv_bias_relu")
raise_if_cuda_home_none("--fused_conv_bias_relu") raise_if_cuda_home_none("--fused_conv_bias_relu")
subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/cudnn-frontend/"]) if check_cudnn_version_and_warn("--fused_conv_bias_relu", 8400):
ext_modules.append( subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/cudnn-frontend/"])
CUDAExtension( ext_modules.append(
name="fused_conv_bias_relu", CUDAExtension(
sources=["apex/contrib/csrc/conv_bias_relu/conv_bias_relu.cpp"], name="fused_conv_bias_relu",
include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/cudnn-frontend/include")], sources=["apex/contrib/csrc/conv_bias_relu/conv_bias_relu.cpp"],
extra_compile_args={"cxx": ["-O3"] + version_dependent_macros + generator_flag}, include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/cudnn-frontend/include")],
extra_compile_args={"cxx": ["-O3"] + version_dependent_macros + generator_flag},
)
) )
)
setup( setup(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment