Unverified Commit d89f5e66 authored by Masaki Kozuki's avatar Masaki Kozuki Committed by GitHub
Browse files

[submodule update] Bump cudnn-frontend to v0.6.1 (#1353)

* bump version

* add guard

* fix the cond
parent 727a6452
Subproject commit 7b83dba83fa31381aeca508d89aab94f4639ac6d
Subproject commit fa611998a360cbabaa2dcc7c9859748144114fc0
......@@ -58,6 +58,13 @@ def append_nvcc_threads(nvcc_extra_args):
return nvcc_extra_args
def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int) -> bool:
green = torch.backends.cudnn.is_available() and torch.backends.cudnn.version() >= required_cudnn_version
if not green:
warnings.warn(f"Skip `{global_option}` as it requires cuDNN {required_cudnn_version} or later")
return green
if not torch.cuda.is_available():
# https://github.com/NVIDIA/apex/issues/486
# Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(),
......@@ -649,6 +656,7 @@ if "--fast_bottleneck" in sys.argv:
if "--fused_conv_bias_relu" in sys.argv:
sys.argv.remove("--fused_conv_bias_relu")
raise_if_cuda_home_none("--fused_conv_bias_relu")
if check_cudnn_version_and_warn("--fused_conv_bias_relu", 8400):
subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/cudnn-frontend/"])
ext_modules.append(
CUDAExtension(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment