Unverified Commit da74c09d authored by Yichen Yan's avatar Yichen Yan Committed by GitHub
Browse files

Trivial update to calculate target arch (#702)



* Trivial update to calculate target arch

* Update tilelang/contrib/nvrtc.py
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

* fmt

---------
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent 6f59668d
......@@ -52,9 +52,9 @@ def compile_cuda(code,
# "-gencode", "arch=compute_52,code=sm_52",
# "-gencode", "arch=compute_70,code=sm_70"
# ]
compute_version = "".join(
get_target_compute_version(Target.current(allow_none=True)).split("."))
arch = ["-gencode", f"arch=compute_{compute_version},code=sm_{compute_version}"]
compute_version = get_target_compute_version(Target.current(allow_none=True))
target_arch = get_target_arch(compute_version)
arch = ["-gencode", f"arch=compute_{target_arch},code=sm_{target_arch}"]
temp = utils.tempdir()
file_name = "tvm_kernels"
......@@ -298,7 +298,7 @@ def get_target_compute_version(target=None):
"Try specifying it by adding '-arch=sm_xx' to your target.")
def parse_compute_version(compute_version):
def parse_compute_version(compute_version) -> tuple[int, int]:
"""Parse compute capability string to divide major and minor version
Parameters
......@@ -323,6 +323,14 @@ def parse_compute_version(compute_version):
raise RuntimeError("Compute version parsing error") from err
def get_target_arch(compute_version) -> str:
    """Translate a compute capability string into an NVCC target-arch token.

    Parameters
    ----------
    compute_version : str
        Compute capability in dotted form, e.g. ``"8.6"`` or ``"9.0"``
        (parsed by :func:`parse_compute_version`).

    Returns
    -------
    str
        The arch token used in ``sm_XX`` / ``compute_XX`` flags, e.g.
        ``"86"``. For Hopper and newer (major >= 9) the ``"a"`` variant is
        emitted (e.g. ``"90a"``) to enable the arch-specific features.
    """
    major, minor = parse_compute_version(compute_version)
    suffix = "a" if major >= 9 else ""
    return f"{major * 10 + minor}{suffix}"
def have_fp16(compute_version):
"""Either fp16 support is provided in the compute capability or not
......
import cuda.bindings.nvrtc as nvrtc
from typing import Literal, Union, List, Optional, Tuple
from tvm.target import Target
from .nvcc import get_target_compute_version
from .nvcc import get_target_compute_version, parse_compute_version
def get_nvrtc_version() -> Tuple[int, int]:
......@@ -42,9 +42,9 @@ def compile_cuda(code: str,
if arch is None:
# If None, then it will use `tvm.target.Target.current().arch`.
# Target arch could be a str like "80", "90", "90a", etc.
compute_version = "".join(
get_target_compute_version(Target.current(allow_none=True)).split("."))
arch = int(compute_version)
major, minor = parse_compute_version(
get_target_compute_version(Target.current(allow_none=True)))
arch = major * 10 + minor
prefix = "compute" if target_format == "ptx" else "sm"
suffix = "a" if arch >= 90 else ""
arch_option = f"--gpu-architecture={prefix}_{arch}{suffix}"
......
......@@ -64,15 +64,10 @@ def tilelang_callback_cuda_compile(code, target):
cutlass_path = os.environ["TL_CUTLASS_PATH"]
else:
cutlass_path = osp.abspath(osp.join(project_root, "3rdparty/cutlass/include"))
compute_version = "".join(nvcc.get_target_compute_version(target).split("."))
target_arch = nvcc.get_target_arch(nvcc.get_target_compute_version(target))
# special handle for Hopper
if compute_version == "90":
arch = ["-arch=sm_90a"]
format = "cubin"
else:
arch = [f"-arch=sm_{compute_version}"]
format = "cubin"
arch = [f"-arch=sm_{target_arch}"]
format = "cubin"
# printing out number of registers
debug_option = "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage"
......
......@@ -53,11 +53,10 @@ def _initialize_torch_cuda_arch_flags():
target = determine_target(return_object=True)
# create tmp source file for torch cpp extension
compute_version = "".join(nvcc.get_target_compute_version(target).split("."))
# set TORCH_CUDA_ARCH_LIST
major = compute_version[0]
minor = compute_version[1]
compute_version = nvcc.get_target_compute_version(target)
major, minor = nvcc.parse_compute_version(compute_version)
# set TORCH_CUDA_ARCH_LIST
os.environ["TORCH_CUDA_ARCH_LIST"] = f"{major}.{minor}"
......
......@@ -11,7 +11,7 @@ from tvm.target import Target
from tilelang import tvm as tvm
from tilelang.transform import PassConfigKey
from tilelang.contrib.nvcc import get_nvcc_compiler, get_target_compute_version
from tilelang.contrib.nvcc import get_nvcc_compiler, get_target_arch, get_target_compute_version
from tilelang.contrib.rocm import find_rocm_path, get_rocm_arch
from tilelang.env import TILELANG_TEMPLATE_PATH
......@@ -67,9 +67,7 @@ class LibraryGenerator(object):
if is_cuda_target(target):
from tilelang.env import CUTLASS_INCLUDE_DIR
src = tempfile.NamedTemporaryFile(mode="w", suffix=".cu", delete=False)
compute_version = "".join(get_target_compute_version(target).split("."))
if compute_version == "90":
compute_version = "90a"
target_arch = get_target_arch(get_target_compute_version(target))
libpath = src.name.replace(".cu", ".so")
disable_fast_math = self.pass_configs.get(PassConfigKey.TL_DISABLE_FAST_MATH, False)
......@@ -91,7 +89,7 @@ class LibraryGenerator(object):
src.name,
"-lcuda",
"-gencode",
f"arch=compute_{compute_version},code=sm_{compute_version}",
f"arch=compute_{target_arch},code=sm_{target_arch}",
]
if not disable_fast_math:
command += ["--use_fast_math"]
......
......@@ -36,11 +36,7 @@ def _get_workspace_dir_name() -> pathlib.Path:
target = determine_target(return_object=True)
# create tmp source file for torch cpp extension
compute_version = "".join(nvcc.get_target_compute_version(target).split("."))
# set TORCH_CUDA_ARCH_LIST
major = compute_version[0]
minor = compute_version[1]
arch = f"{major}_{minor}"
arch = nvcc.get_target_arch(nvcc.get_target_compute_version(target))
except Exception:
arch = "noarch"
# e.g.: $HOME/.cache/tilelang/75_80_89_90/
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment