Commit dc08ea1c authored by Tri Dao's avatar Tri Dao
Browse files

Support H100 for other CUDA extensions

parent 1b18f1b7
# Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py # Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py
import torch
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
from setuptools import setup, find_packages
import subprocess
import sys import sys
import warnings import warnings
import os import os
from packaging.version import parse, Version
from setuptools import setup, find_packages
import subprocess
import torch
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
# ninja build does not work unless include_dirs are abs path # ninja build does not work unless include_dirs are abs path
this_dir = os.path.dirname(os.path.abspath(__file__)) this_dir = os.path.dirname(os.path.abspath(__file__))
...@@ -16,22 +19,19 @@ def get_cuda_bare_metal_version(cuda_dir): ...@@ -16,22 +19,19 @@ def get_cuda_bare_metal_version(cuda_dir):
raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
output = raw_output.split() output = raw_output.split()
release_idx = output.index("release") + 1 release_idx = output.index("release") + 1
release = output[release_idx].split(".") bare_metal_version = parse(output[release_idx].split(",")[0])
bare_metal_major = release[0]
bare_metal_minor = release[1][0]
return raw_output, bare_metal_major, bare_metal_minor return raw_output, bare_metal_version
def check_cuda_torch_binary_vs_bare_metal(cuda_dir): def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir) raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir)
torch_binary_major = torch.version.cuda.split(".")[0] torch_binary_version = parse(torch.version.cuda)
torch_binary_minor = torch.version.cuda.split(".")[1]
print("\nCompiling cuda extensions with") print("\nCompiling cuda extensions with")
print(raw_output + "from " + cuda_dir + "/bin\n") print(raw_output + "from " + cuda_dir + "/bin\n")
if (bare_metal_major != torch_binary_major) or (bare_metal_minor != torch_binary_minor): if (bare_metal_version != torch_binary_version):
raise RuntimeError( raise RuntimeError(
"Cuda extensions are being compiled with a version of Cuda that does " "Cuda extensions are being compiled with a version of Cuda that does "
"not match the version used to compile Pytorch binaries. " "not match the version used to compile Pytorch binaries. "
...@@ -53,8 +53,8 @@ def raise_if_cuda_home_none(global_option: str) -> None: ...@@ -53,8 +53,8 @@ def raise_if_cuda_home_none(global_option: str) -> None:
def append_nvcc_threads(nvcc_extra_args): def append_nvcc_threads(nvcc_extra_args):
_, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME) _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2: if bare_metal_version >= Version("11.2"):
return nvcc_extra_args + ["--threads", "4"] return nvcc_extra_args + ["--threads", "4"]
return nvcc_extra_args return nvcc_extra_args
...@@ -72,15 +72,18 @@ if not torch.cuda.is_available(): ...@@ -72,15 +72,18 @@ if not torch.cuda.is_available():
"If you wish to cross-compile for a single specific architecture,\n" "If you wish to cross-compile for a single specific architecture,\n"
'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n', 'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n',
) )
if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None: if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None:
_, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME) _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) == 11: if bare_metal_version >= Version("11.8"):
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0"
elif bare_metal_version >= Version("11.1"):
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
elif bare_metal_version == Version("11.0"):
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0" os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
if int(bare_metal_minor) > 0:
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
else: else:
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5" os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"
print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
TORCH_MAJOR = int(torch.__version__.split(".")[0]) TORCH_MAJOR = int(torch.__version__.split(".")[0])
TORCH_MINOR = int(torch.__version__.split(".")[1]) TORCH_MINOR = int(torch.__version__.split(".")[1])
...@@ -98,10 +101,16 @@ if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl. ...@@ -98,10 +101,16 @@ if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.
raise_if_cuda_home_none("--ft_attention") raise_if_cuda_home_none("--ft_attention")
# Check, if CUDA11 is installed for compute capability 8.0 # Check, if CUDA11 is installed for compute capability 8.0
cc_flag = [] cc_flag = []
# cc_flag.append("-gencode") _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
# cc_flag.append("arch=compute_70,code=sm_70") if bare_metal_version < Version("11.0"):
raise RuntimeError("ft_attention is only supported on CUDA 11 and above")
cc_flag.append("-gencode")
cc_flag.append("arch=compute_70,code=sm_70")
cc_flag.append("-gencode") cc_flag.append("-gencode")
cc_flag.append("arch=compute_80,code=sm_80") cc_flag.append("arch=compute_80,code=sm_80")
if bare_metal_version >= Version("11.8"):
cc_flag.append("-gencode")
cc_flag.append("arch=compute_90,code=sm_90")
ext_modules.append( ext_modules.append(
CUDAExtension( CUDAExtension(
......
import os import os
import subprocess import subprocess
from packaging.version import parse, Version
import torch import torch
from setuptools import setup from setuptools import setup
...@@ -10,16 +11,14 @@ def get_cuda_bare_metal_version(cuda_dir): ...@@ -10,16 +11,14 @@ def get_cuda_bare_metal_version(cuda_dir):
raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
output = raw_output.split() output = raw_output.split()
release_idx = output.index("release") + 1 release_idx = output.index("release") + 1
release = output[release_idx].split(".") bare_metal_version = parse(output[release_idx].split(",")[0])
bare_metal_major = release[0]
bare_metal_minor = release[1][0]
return raw_output, bare_metal_major, bare_metal_minor return raw_output, bare_metal_version
def append_nvcc_threads(nvcc_extra_args): def append_nvcc_threads(nvcc_extra_args):
_, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME) _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2: if bare_metal_version >= Version("11.2"):
return nvcc_extra_args + ["--threads", "4"] return nvcc_extra_args + ["--threads", "4"]
return nvcc_extra_args return nvcc_extra_args
......
# Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py # Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py
import sys
import warnings
import os
from packaging.version import parse, Version
import torch import torch
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
from setuptools import setup, find_packages from setuptools import setup, find_packages
import subprocess import subprocess
import sys
import warnings
import os
# ninja build does not work unless include_dirs are abs path # ninja build does not work unless include_dirs are abs path
this_dir = os.path.dirname(os.path.abspath(__file__)) this_dir = os.path.dirname(os.path.abspath(__file__))
...@@ -16,22 +17,19 @@ def get_cuda_bare_metal_version(cuda_dir): ...@@ -16,22 +17,19 @@ def get_cuda_bare_metal_version(cuda_dir):
raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
output = raw_output.split() output = raw_output.split()
release_idx = output.index("release") + 1 release_idx = output.index("release") + 1
release = output[release_idx].split(".") bare_metal_version = parse(output[release_idx].split(",")[0])
bare_metal_major = release[0]
bare_metal_minor = release[1][0]
return raw_output, bare_metal_major, bare_metal_minor return raw_output, bare_metal_version
def check_cuda_torch_binary_vs_bare_metal(cuda_dir): def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir) raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir)
torch_binary_major = torch.version.cuda.split(".")[0] torch_binary_version = parse(torch.version.cuda)
torch_binary_minor = torch.version.cuda.split(".")[1]
print("\nCompiling cuda extensions with") print("\nCompiling cuda extensions with")
print(raw_output + "from " + cuda_dir + "/bin\n") print(raw_output + "from " + cuda_dir + "/bin\n")
if (bare_metal_major != torch_binary_major) or (bare_metal_minor != torch_binary_minor): if (bare_metal_version != torch_binary_version):
raise RuntimeError( raise RuntimeError(
"Cuda extensions are being compiled with a version of Cuda that does " "Cuda extensions are being compiled with a version of Cuda that does "
"not match the version used to compile Pytorch binaries. " "not match the version used to compile Pytorch binaries. "
...@@ -53,8 +51,8 @@ def raise_if_cuda_home_none(global_option: str) -> None: ...@@ -53,8 +51,8 @@ def raise_if_cuda_home_none(global_option: str) -> None:
def append_nvcc_threads(nvcc_extra_args): def append_nvcc_threads(nvcc_extra_args):
_, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME) _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2: if bare_metal_version >= Version("11.2"):
return nvcc_extra_args + ["--threads", "4"] return nvcc_extra_args + ["--threads", "4"]
return nvcc_extra_args return nvcc_extra_args
...@@ -72,15 +70,18 @@ if not torch.cuda.is_available(): ...@@ -72,15 +70,18 @@ if not torch.cuda.is_available():
"If you wish to cross-compile for a single specific architecture,\n" "If you wish to cross-compile for a single specific architecture,\n"
'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n', 'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n',
) )
if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None: if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None:
_, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME) _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) == 11: if bare_metal_version >= Version("11.8"):
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0"
elif bare_metal_version >= Version("11.1"):
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
elif bare_metal_version == Version("11.0"):
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0" os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
if int(bare_metal_minor) > 0:
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
else: else:
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5" os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"
print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
TORCH_MAJOR = int(torch.__version__.split(".")[0]) TORCH_MAJOR = int(torch.__version__.split(".")[0])
TORCH_MINOR = int(torch.__version__.split(".")[1]) TORCH_MINOR = int(torch.__version__.split(".")[1])
...@@ -98,10 +99,16 @@ if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl. ...@@ -98,10 +99,16 @@ if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.
raise_if_cuda_home_none("--fast_layer_norm") raise_if_cuda_home_none("--fast_layer_norm")
# Check, if CUDA11 is installed for compute capability 8.0 # Check, if CUDA11 is installed for compute capability 8.0
cc_flag = [] cc_flag = []
_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
if bare_metal_version < Version("11.0"):
raise RuntimeError("dropout_layer_norm is only supported on CUDA 11 and above")
cc_flag.append("-gencode") cc_flag.append("-gencode")
cc_flag.append("arch=compute_70,code=sm_70") cc_flag.append("arch=compute_70,code=sm_70")
cc_flag.append("-gencode") cc_flag.append("-gencode")
cc_flag.append("arch=compute_80,code=sm_80") cc_flag.append("arch=compute_80,code=sm_80")
if bare_metal_version >= Version("11.8"):
cc_flag.append("-gencode")
cc_flag.append("arch=compute_90,code=sm_90")
ext_modules.append( ext_modules.append(
CUDAExtension( CUDAExtension(
......
# Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py # Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py
import sys
import warnings
import os
from packaging.version import parse, Version
import torch import torch
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
from setuptools import setup, find_packages from setuptools import setup, find_packages
import subprocess import subprocess
import sys
import warnings
import os
# ninja build does not work unless include_dirs are abs path # ninja build does not work unless include_dirs are abs path
this_dir = os.path.dirname(os.path.abspath(__file__)) this_dir = os.path.dirname(os.path.abspath(__file__))
...@@ -16,22 +17,19 @@ def get_cuda_bare_metal_version(cuda_dir): ...@@ -16,22 +17,19 @@ def get_cuda_bare_metal_version(cuda_dir):
raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
output = raw_output.split() output = raw_output.split()
release_idx = output.index("release") + 1 release_idx = output.index("release") + 1
release = output[release_idx].split(".") bare_metal_version = parse(output[release_idx].split(",")[0])
bare_metal_major = release[0]
bare_metal_minor = release[1][0]
return raw_output, bare_metal_major, bare_metal_minor return raw_output, bare_metal_version
def check_cuda_torch_binary_vs_bare_metal(cuda_dir): def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir) raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir)
torch_binary_major = torch.version.cuda.split(".")[0] torch_binary_version = parse(torch.version.cuda)
torch_binary_minor = torch.version.cuda.split(".")[1]
print("\nCompiling cuda extensions with") print("\nCompiling cuda extensions with")
print(raw_output + "from " + cuda_dir + "/bin\n") print(raw_output + "from " + cuda_dir + "/bin\n")
if (bare_metal_major != torch_binary_major) or (bare_metal_minor != torch_binary_minor): if (bare_metal_version != torch_binary_version):
raise RuntimeError( raise RuntimeError(
"Cuda extensions are being compiled with a version of Cuda that does " "Cuda extensions are being compiled with a version of Cuda that does "
"not match the version used to compile Pytorch binaries. " "not match the version used to compile Pytorch binaries. "
...@@ -53,8 +51,8 @@ def raise_if_cuda_home_none(global_option: str) -> None: ...@@ -53,8 +51,8 @@ def raise_if_cuda_home_none(global_option: str) -> None:
def append_nvcc_threads(nvcc_extra_args): def append_nvcc_threads(nvcc_extra_args):
_, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME) _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2: if bare_metal_version >= Version("11.2"):
return nvcc_extra_args + ["--threads", "4"] return nvcc_extra_args + ["--threads", "4"]
return nvcc_extra_args return nvcc_extra_args
...@@ -72,15 +70,18 @@ if not torch.cuda.is_available(): ...@@ -72,15 +70,18 @@ if not torch.cuda.is_available():
"If you wish to cross-compile for a single specific architecture,\n" "If you wish to cross-compile for a single specific architecture,\n"
'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n', 'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n',
) )
if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None: if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None:
_, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME) _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) == 11: if bare_metal_version >= Version("11.8"):
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0"
elif bare_metal_version >= Version("11.1"):
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
elif bare_metal_version == Version("11.0"):
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0" os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
if int(bare_metal_minor) > 0:
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
else: else:
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5" os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"
print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
TORCH_MAJOR = int(torch.__version__.split(".")[0]) TORCH_MAJOR = int(torch.__version__.split(".")[0])
TORCH_MINOR = int(torch.__version__.split(".")[1]) TORCH_MINOR = int(torch.__version__.split(".")[1])
...@@ -91,10 +92,16 @@ ext_modules = [] ...@@ -91,10 +92,16 @@ ext_modules = []
raise_if_cuda_home_none("rotary_emb") raise_if_cuda_home_none("rotary_emb")
# Check, if CUDA11 is installed for compute capability 8.0 # Check, if CUDA11 is installed for compute capability 8.0
cc_flag = [] cc_flag = []
_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
if bare_metal_version < Version("11.0"):
raise RuntimeError("rotary_emb is only supported on CUDA 11 and above")
cc_flag.append("-gencode") cc_flag.append("-gencode")
cc_flag.append("arch=compute_70,code=sm_70") cc_flag.append("arch=compute_70,code=sm_70")
cc_flag.append("-gencode") cc_flag.append("-gencode")
cc_flag.append("arch=compute_80,code=sm_80") cc_flag.append("arch=compute_80,code=sm_80")
if bare_metal_version >= Version("11.8"):
cc_flag.append("-gencode")
cc_flag.append("arch=compute_90,code=sm_90")
ext_modules.append( ext_modules.append(
CUDAExtension( CUDAExtension(
......
# Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py # Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py
import sys
import warnings
import os
from packaging.version import parse, Version
import torch import torch
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
from setuptools import setup, find_packages from setuptools import setup, find_packages
import subprocess import subprocess
import sys
import warnings
import os
# ninja build does not work unless include_dirs are abs path # ninja build does not work unless include_dirs are abs path
this_dir = os.path.dirname(os.path.abspath(__file__)) this_dir = os.path.dirname(os.path.abspath(__file__))
...@@ -16,22 +17,19 @@ def get_cuda_bare_metal_version(cuda_dir): ...@@ -16,22 +17,19 @@ def get_cuda_bare_metal_version(cuda_dir):
raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
output = raw_output.split() output = raw_output.split()
release_idx = output.index("release") + 1 release_idx = output.index("release") + 1
release = output[release_idx].split(".") bare_metal_version = parse(output[release_idx].split(",")[0])
bare_metal_major = release[0]
bare_metal_minor = release[1][0]
return raw_output, bare_metal_major, bare_metal_minor return raw_output, bare_metal_version
def check_cuda_torch_binary_vs_bare_metal(cuda_dir): def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir) raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir)
torch_binary_major = torch.version.cuda.split(".")[0] torch_binary_version = parse(torch.version.cuda)
torch_binary_minor = torch.version.cuda.split(".")[1]
print("\nCompiling cuda extensions with") print("\nCompiling cuda extensions with")
print(raw_output + "from " + cuda_dir + "/bin\n") print(raw_output + "from " + cuda_dir + "/bin\n")
if (bare_metal_major != torch_binary_major) or (bare_metal_minor != torch_binary_minor): if (bare_metal_version != torch_binary_version):
raise RuntimeError( raise RuntimeError(
"Cuda extensions are being compiled with a version of Cuda that does " "Cuda extensions are being compiled with a version of Cuda that does "
"not match the version used to compile Pytorch binaries. " "not match the version used to compile Pytorch binaries. "
...@@ -53,8 +51,8 @@ def raise_if_cuda_home_none(global_option: str) -> None: ...@@ -53,8 +51,8 @@ def raise_if_cuda_home_none(global_option: str) -> None:
def append_nvcc_threads(nvcc_extra_args): def append_nvcc_threads(nvcc_extra_args):
_, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME) _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2: if bare_metal_version >= Version("11.2"):
return nvcc_extra_args + ["--threads", "4"] return nvcc_extra_args + ["--threads", "4"]
return nvcc_extra_args return nvcc_extra_args
...@@ -72,15 +70,18 @@ if not torch.cuda.is_available(): ...@@ -72,15 +70,18 @@ if not torch.cuda.is_available():
"If you wish to cross-compile for a single specific architecture,\n" "If you wish to cross-compile for a single specific architecture,\n"
'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n', 'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n',
) )
if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None: if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None:
_, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME) _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) == 11: if bare_metal_version >= Version("11.8"):
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0"
elif bare_metal_version >= Version("11.1"):
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
elif bare_metal_version == Version("11.0"):
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0" os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
if int(bare_metal_minor) > 0:
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
else: else:
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5" os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"
print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
TORCH_MAJOR = int(torch.__version__.split(".")[0]) TORCH_MAJOR = int(torch.__version__.split(".")[0])
TORCH_MINOR = int(torch.__version__.split(".")[1]) TORCH_MINOR = int(torch.__version__.split(".")[1])
...@@ -98,10 +99,16 @@ if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl. ...@@ -98,10 +99,16 @@ if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.
raise_if_cuda_home_none("--xentropy") raise_if_cuda_home_none("--xentropy")
# Check, if CUDA11 is installed for compute capability 8.0 # Check, if CUDA11 is installed for compute capability 8.0
cc_flag = [] cc_flag = []
_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
if bare_metal_version < Version("11.0"):
raise RuntimeError("xentropy is only supported on CUDA 11 and above")
cc_flag.append("-gencode") cc_flag.append("-gencode")
cc_flag.append("arch=compute_70,code=sm_70") cc_flag.append("arch=compute_70,code=sm_70")
cc_flag.append("-gencode") cc_flag.append("-gencode")
cc_flag.append("arch=compute_80,code=sm_80") cc_flag.append("arch=compute_80,code=sm_80")
if bare_metal_version >= Version("11.8"):
cc_flag.append("-gencode")
cc_flag.append("arch=compute_90,code=sm_90")
ext_modules.append( ext_modules.append(
CUDAExtension( CUDAExtension(
......
...@@ -421,6 +421,8 @@ class FusedMLP(nn.Module): ...@@ -421,6 +421,8 @@ class FusedMLP(nn.Module):
'auto': heuristic will be picked automatically: 'auto': heuristic will be picked automatically:
For CUDA >= 11.8, we set heuristic=0 for both fp16 and bf16 for best perf. For CUDA >= 11.8, we set heuristic=0 for both fp16 and bf16 for best perf.
For CUDA <= 11.7, we set heuristic=1 for fp16 and heuristic=-1 for bf16. For CUDA <= 11.7, we set heuristic=1 for fp16 and heuristic=-1 for bf16.
For H100, we set heuristic=-1 for both fp16 and bf16 as the fused cuBlasLt implementation
is slower than the unfused version.
return_residual: whether to return the input x along with the output. This is for return_residual: whether to return the input x along with the output. This is for
performance reason: for post-norm architecture, returning the input allows us performance reason: for post-norm architecture, returning the input allows us
to fuse the backward of nn.Linear with the residual connection. to fuse the backward of nn.Linear with the residual connection.
...@@ -442,8 +444,11 @@ class FusedMLP(nn.Module): ...@@ -442,8 +444,11 @@ class FusedMLP(nn.Module):
dtype = x.dtype if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype() dtype = x.dtype if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype()
if self.heuristic == 'auto': if self.heuristic == 'auto':
if self.activation == 'gelu_approx': if self.activation == 'gelu_approx':
cuda_ver = tuple(map(int, torch.version.cuda.split('.'))) if torch.cuda.get_device_capability('cuda') == (9, 0):
heuristic = 0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1) heuristic = -1
else:
cuda_ver = tuple(map(int, torch.version.cuda.split('.')))
heuristic = 0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1)
else: else:
heuristic = 0 heuristic = 0
else: else:
......
...@@ -108,7 +108,7 @@ raise_if_cuda_home_none("flash_attn") ...@@ -108,7 +108,7 @@ raise_if_cuda_home_none("flash_attn")
cc_flag = [] cc_flag = []
_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
if bare_metal_version < Version("11.0"): if bare_metal_version < Version("11.0"):
raise RuntimeError("FlashAttention is only supported on CUDA 11") raise RuntimeError("FlashAttention is only supported on CUDA 11 and above")
cc_flag.append("-gencode") cc_flag.append("-gencode")
cc_flag.append("arch=compute_75,code=sm_75") cc_flag.append("arch=compute_75,code=sm_75")
cc_flag.append("-gencode") cc_flag.append("-gencode")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment