Commit 39a65c92 authored by hubertlu-tw's avatar hubertlu-tw
Browse files

Add IS_ROCM_PYTORCH if statement for some newly-added extensions

parent 1436a66a
...@@ -357,7 +357,7 @@ if os.path.exists(os.path.join(torch_dir, 'include', 'ATen', 'CUDAGenerator.h')) ...@@ -357,7 +357,7 @@ if os.path.exists(os.path.join(torch_dir, 'include', 'ATen', 'CUDAGenerator.h'))
if "--fast_layer_norm" in sys.argv: if "--fast_layer_norm" in sys.argv:
sys.argv.remove("--fast_layer_norm") sys.argv.remove("--fast_layer_norm")
if CUDA_HOME is None: if CUDA_HOME is None and not IS_ROCM_PYTORCH:
raise RuntimeError("--fast_layer_norm was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.") raise RuntimeError("--fast_layer_norm was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else: else:
# Check, if CUDA11 is installed for compute capability 8.0 # Check, if CUDA11 is installed for compute capability 8.0
...@@ -386,7 +386,7 @@ if "--fast_layer_norm" in sys.argv: ...@@ -386,7 +386,7 @@ if "--fast_layer_norm" in sys.argv:
if "--fmha" in sys.argv: if "--fmha" in sys.argv:
sys.argv.remove("--fmha") sys.argv.remove("--fmha")
if CUDA_HOME is None: if CUDA_HOME is None and not IS_ROCM_PYTORCH:
raise RuntimeError("--fmha was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.") raise RuntimeError("--fmha was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else: else:
# Check, if CUDA11 is installed for compute capability 8.0 # Check, if CUDA11 is installed for compute capability 8.0
...@@ -523,7 +523,7 @@ if "--fast_multihead_attn" in sys.argv: ...@@ -523,7 +523,7 @@ if "--fast_multihead_attn" in sys.argv:
if "--transducer" in sys.argv: if "--transducer" in sys.argv:
sys.argv.remove("--transducer") sys.argv.remove("--transducer")
if CUDA_HOME is None: if CUDA_HOME is None and not IS_ROCM_PYTORCH:
raise RuntimeError("--transducer was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.") raise RuntimeError("--transducer was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else: else:
ext_modules.append( ext_modules.append(
...@@ -544,7 +544,7 @@ if "--transducer" in sys.argv: ...@@ -544,7 +544,7 @@ if "--transducer" in sys.argv:
if "--fast_bottleneck" in sys.argv: if "--fast_bottleneck" in sys.argv:
sys.argv.remove("--fast_bottleneck") sys.argv.remove("--fast_bottleneck")
if CUDA_HOME is None: if CUDA_HOME is None and not IS_ROCM_PYTORCH:
raise RuntimeError("--fast_bottleneck was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.") raise RuntimeError("--fast_bottleneck was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else: else:
subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/cudnn-frontend/"]) subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/cudnn-frontend/"])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment