Unverified Commit da74c09d authored by Yichen Yan's avatar Yichen Yan Committed by GitHub
Browse files

Trivial update to calculate target arch (#702)



* Trivial update to calculate target arch

* Update tilelang/contrib/nvrtc.py
Co-authored-by: default avatargemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

* fmt

---------
Co-authored-by: default avatargemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent 6f59668d
...@@ -52,9 +52,9 @@ def compile_cuda(code, ...@@ -52,9 +52,9 @@ def compile_cuda(code,
# "-gencode", "arch=compute_52,code=sm_52", # "-gencode", "arch=compute_52,code=sm_52",
# "-gencode", "arch=compute_70,code=sm_70" # "-gencode", "arch=compute_70,code=sm_70"
# ] # ]
compute_version = "".join( compute_version = get_target_compute_version(Target.current(allow_none=True))
get_target_compute_version(Target.current(allow_none=True)).split(".")) target_arch = get_target_arch(compute_version)
arch = ["-gencode", f"arch=compute_{compute_version},code=sm_{compute_version}"] arch = ["-gencode", f"arch=compute_{target_arch},code=sm_{target_arch}"]
temp = utils.tempdir() temp = utils.tempdir()
file_name = "tvm_kernels" file_name = "tvm_kernels"
...@@ -298,7 +298,7 @@ def get_target_compute_version(target=None): ...@@ -298,7 +298,7 @@ def get_target_compute_version(target=None):
"Try specifying it by adding '-arch=sm_xx' to your target.") "Try specifying it by adding '-arch=sm_xx' to your target.")
def parse_compute_version(compute_version): def parse_compute_version(compute_version) -> tuple[int, int]:
"""Parse compute capability string to divide major and minor version """Parse compute capability string to divide major and minor version
Parameters Parameters
...@@ -323,6 +323,14 @@ def parse_compute_version(compute_version): ...@@ -323,6 +323,14 @@ def parse_compute_version(compute_version):
raise RuntimeError("Compute version parsing error") from err raise RuntimeError("Compute version parsing error") from err
def get_target_arch(compute_version) -> str:
major, minor = parse_compute_version(compute_version)
target_arch = str(major * 10 + minor)
if major >= 9:
target_arch += "a"
return target_arch
def have_fp16(compute_version): def have_fp16(compute_version):
"""Either fp16 support is provided in the compute capability or not """Either fp16 support is provided in the compute capability or not
......
import cuda.bindings.nvrtc as nvrtc import cuda.bindings.nvrtc as nvrtc
from typing import Literal, Union, List, Optional, Tuple from typing import Literal, Union, List, Optional, Tuple
from tvm.target import Target from tvm.target import Target
from .nvcc import get_target_compute_version from .nvcc import get_target_compute_version, parse_compute_version
def get_nvrtc_version() -> Tuple[int, int]: def get_nvrtc_version() -> Tuple[int, int]:
...@@ -42,9 +42,9 @@ def compile_cuda(code: str, ...@@ -42,9 +42,9 @@ def compile_cuda(code: str,
if arch is None: if arch is None:
# If None, then it will use `tvm.target.Target.current().arch`. # If None, then it will use `tvm.target.Target.current().arch`.
# Target arch could be a str like "80", "90", "90a", etc. # Target arch could be a str like "80", "90", "90a", etc.
compute_version = "".join( major, minor = parse_compute_version(
get_target_compute_version(Target.current(allow_none=True)).split(".")) get_target_compute_version(Target.current(allow_none=True)))
arch = int(compute_version) arch = major * 10 + minor
prefix = "compute" if target_format == "ptx" else "sm" prefix = "compute" if target_format == "ptx" else "sm"
suffix = "a" if arch >= 90 else "" suffix = "a" if arch >= 90 else ""
arch_option = f"--gpu-architecture={prefix}_{arch}{suffix}" arch_option = f"--gpu-architecture={prefix}_{arch}{suffix}"
......
...@@ -64,15 +64,10 @@ def tilelang_callback_cuda_compile(code, target): ...@@ -64,15 +64,10 @@ def tilelang_callback_cuda_compile(code, target):
cutlass_path = os.environ["TL_CUTLASS_PATH"] cutlass_path = os.environ["TL_CUTLASS_PATH"]
else: else:
cutlass_path = osp.abspath(osp.join(project_root, "3rdparty/cutlass/include")) cutlass_path = osp.abspath(osp.join(project_root, "3rdparty/cutlass/include"))
compute_version = "".join(nvcc.get_target_compute_version(target).split(".")) target_arch = nvcc.get_target_arch(nvcc.get_target_compute_version(target))
# special handle for Hopper arch = [f"-arch=sm_{target_arch}"]
if compute_version == "90": format = "cubin"
arch = ["-arch=sm_90a"]
format = "cubin"
else:
arch = [f"-arch=sm_{compute_version}"]
format = "cubin"
# printing out number of registers # printing out number of registers
debug_option = "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage" debug_option = "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage"
......
...@@ -53,11 +53,10 @@ def _initialize_torch_cuda_arch_flags(): ...@@ -53,11 +53,10 @@ def _initialize_torch_cuda_arch_flags():
target = determine_target(return_object=True) target = determine_target(return_object=True)
# create tmp source file for torch cpp extension # create tmp source file for torch cpp extension
compute_version = "".join(nvcc.get_target_compute_version(target).split(".")) compute_version = nvcc.get_target_compute_version(target)
# set TORCH_CUDA_ARCH_LIST major, minor = nvcc.parse_compute_version(compute_version)
major = compute_version[0]
minor = compute_version[1]
# set TORCH_CUDA_ARCH_LIST
os.environ["TORCH_CUDA_ARCH_LIST"] = f"{major}.{minor}" os.environ["TORCH_CUDA_ARCH_LIST"] = f"{major}.{minor}"
......
...@@ -11,7 +11,7 @@ from tvm.target import Target ...@@ -11,7 +11,7 @@ from tvm.target import Target
from tilelang import tvm as tvm from tilelang import tvm as tvm
from tilelang.transform import PassConfigKey from tilelang.transform import PassConfigKey
from tilelang.contrib.nvcc import get_nvcc_compiler, get_target_compute_version from tilelang.contrib.nvcc import get_nvcc_compiler, get_target_arch, get_target_compute_version
from tilelang.contrib.rocm import find_rocm_path, get_rocm_arch from tilelang.contrib.rocm import find_rocm_path, get_rocm_arch
from tilelang.env import TILELANG_TEMPLATE_PATH from tilelang.env import TILELANG_TEMPLATE_PATH
...@@ -67,9 +67,7 @@ class LibraryGenerator(object): ...@@ -67,9 +67,7 @@ class LibraryGenerator(object):
if is_cuda_target(target): if is_cuda_target(target):
from tilelang.env import CUTLASS_INCLUDE_DIR from tilelang.env import CUTLASS_INCLUDE_DIR
src = tempfile.NamedTemporaryFile(mode="w", suffix=".cu", delete=False) src = tempfile.NamedTemporaryFile(mode="w", suffix=".cu", delete=False)
compute_version = "".join(get_target_compute_version(target).split(".")) target_arch = get_target_arch(get_target_compute_version(target))
if compute_version == "90":
compute_version = "90a"
libpath = src.name.replace(".cu", ".so") libpath = src.name.replace(".cu", ".so")
disable_fast_math = self.pass_configs.get(PassConfigKey.TL_DISABLE_FAST_MATH, False) disable_fast_math = self.pass_configs.get(PassConfigKey.TL_DISABLE_FAST_MATH, False)
...@@ -91,7 +89,7 @@ class LibraryGenerator(object): ...@@ -91,7 +89,7 @@ class LibraryGenerator(object):
src.name, src.name,
"-lcuda", "-lcuda",
"-gencode", "-gencode",
f"arch=compute_{compute_version},code=sm_{compute_version}", f"arch=compute_{target_arch},code=sm_{target_arch}",
] ]
if not disable_fast_math: if not disable_fast_math:
command += ["--use_fast_math"] command += ["--use_fast_math"]
......
...@@ -36,11 +36,7 @@ def _get_workspace_dir_name() -> pathlib.Path: ...@@ -36,11 +36,7 @@ def _get_workspace_dir_name() -> pathlib.Path:
target = determine_target(return_object=True) target = determine_target(return_object=True)
# create tmp source file for torch cpp extension # create tmp source file for torch cpp extension
compute_version = "".join(nvcc.get_target_compute_version(target).split(".")) arch = nvcc.get_target_arch(nvcc.get_target_compute_version(target))
# set TORCH_CUDA_ARCH_LIST
major = compute_version[0]
minor = compute_version[1]
arch = f"{major}_{minor}"
except Exception: except Exception:
arch = "noarch" arch = "noarch"
# e.g.: $HOME/.cache/tilelang/75_80_89_90/ # e.g.: $HOME/.cache/tilelang/75_80_89_90/
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment