Commit eb757608 authored by Lei Wang, committed by LeiWang1999
Browse files

[AMD] Fix missing Composable Kernel include path when compiling kernels on AMD GPUs (#334)

* [Enhancement] Add new matrix multiplication functions and tests for GEMM with transpose options

- Introduced `matmul_rs` function for flexible matrix multiplication with optional transposition.
- Added `run_gemm_rs` function to facilitate testing of the new matrix multiplication implementation.
- Expanded test coverage for GEMM with additional cases for transposition configurations.
- Corrected index usage in `gemm.h` to ensure proper matrix layout handling.

These changes enhance the GEMM functionality and improve testing capabilities for various matrix configurations.

* [Enhancement] Add Composable Kernel Path Handling in Environment Setup

- Introduced support for the Composable Kernel by adding a new environment variable `TL_COMPOSABLE_KERNEL_PATH`.
- Updated the environment setup to check for the existence of the Composable Kernel and log a warning if not found.
- Modified the `LibraryGenerator` to include the Composable Kernel include directory during compilation for HIP targets.

These changes improve the integration of the Composable Kernel into the TileLang environment, enhancing flexibility for users.
parent 85e411c8
...@@ -35,6 +35,7 @@ def _find_cuda_home() -> str: ...@@ -35,6 +35,7 @@ def _find_cuda_home() -> str:
CUDA_HOME = _find_cuda_home() CUDA_HOME = _find_cuda_home()
CUTLASS_INCLUDE_DIR: str = os.environ.get("TL_CUTLASS_PATH", None) CUTLASS_INCLUDE_DIR: str = os.environ.get("TL_CUTLASS_PATH", None)
COMPOSABLE_KERNEL_INCLUDE_DIR: str = os.environ.get("TL_COMPOSABLE_KERNEL_PATH", None)
TVM_PYTHON_PATH: str = os.environ.get("TVM_IMPORT_PYTHON_PATH", None) TVM_PYTHON_PATH: str = os.environ.get("TVM_IMPORT_PYTHON_PATH", None)
TVM_LIBRARY_PATH: str = os.environ.get("TVM_LIBRARY_PATH", None) TVM_LIBRARY_PATH: str = os.environ.get("TVM_LIBRARY_PATH", None)
TILELANG_TEMPLATE_PATH: str = os.environ.get("TL_TEMPLATE_PATH", None) TILELANG_TEMPLATE_PATH: str = os.environ.get("TL_TEMPLATE_PATH", None)
...@@ -49,6 +50,9 @@ TILELANG_CLEAR_CACHE = os.environ.get("TILELANG_CLEAR_CACHE", "0") ...@@ -49,6 +50,9 @@ TILELANG_CLEAR_CACHE = os.environ.get("TILELANG_CLEAR_CACHE", "0")
# SETUP ENVIRONMENT VARIABLES # SETUP ENVIRONMENT VARIABLES
CUTLASS_NOT_FOUND_MESSAGE = ("CUTLASS is not installed or found in the expected path") CUTLASS_NOT_FOUND_MESSAGE = ("CUTLASS is not installed or found in the expected path")
", which may lead to compilation bugs when utilize tilelang backend." ", which may lead to compilation bugs when utilize tilelang backend."
# Warning text logged when no Composable Kernel checkout can be located.
# Both literals must sit inside the parentheses so that implicit string
# concatenation joins them; a literal placed after the closing parenthesis is
# a separate no-op expression statement and never becomes part of the message.
COMPOSABLE_KERNEL_NOT_FOUND_MESSAGE = (
    "Composable Kernel is not installed or found in the expected path"
    ", which may lead to compilation bugs when utilize tilelang backend.")
TL_TEMPLATE_NOT_FOUND_MESSAGE = ("TileLang is not installed or found in the expected path") TL_TEMPLATE_NOT_FOUND_MESSAGE = ("TileLang is not installed or found in the expected path")
", which may lead to compilation bugs when utilize tilelang backend." ", which may lead to compilation bugs when utilize tilelang backend."
TVM_LIBRARY_NOT_FOUND_MESSAGE = ("TVM is not installed or found in the expected path") TVM_LIBRARY_NOT_FOUND_MESSAGE = ("TVM is not installed or found in the expected path")
...@@ -110,6 +114,20 @@ if os.environ.get("TL_CUTLASS_PATH", None) is None: ...@@ -110,6 +114,20 @@ if os.environ.get("TL_CUTLASS_PATH", None) is None:
else: else:
logger.warning(CUTLASS_NOT_FOUND_MESSAGE) logger.warning(CUTLASS_NOT_FOUND_MESSAGE)
# Resolve the Composable Kernel include directory when the user has not
# supplied TL_COMPOSABLE_KERNEL_PATH explicitly.
if os.environ.get("TL_COMPOSABLE_KERNEL_PATH", None) is None:
    # Candidate checkout roots, in priority order: the copy located next to
    # this module, then the copy one directory above it.
    install_ck_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "3rdparty", "composable_kernel")
    develop_ck_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "composable_kernel")
    for _ck_root in (install_ck_path, develop_ck_path):
        # Only the existence of the checkout root decides the outcome. The
        # include directory is a compiler search path, not a Python import
        # path, so membership in sys.path is irrelevant here.
        if os.path.exists(_ck_root):
            COMPOSABLE_KERNEL_INCLUDE_DIR = os.path.join(_ck_root, "include")
            # Mirror the resolved directory into the environment variable.
            os.environ["TL_COMPOSABLE_KERNEL_PATH"] = COMPOSABLE_KERNEL_INCLUDE_DIR
            break
    else:
        # Neither candidate exists: warn and leave the include dir unset.
        logger.warning(COMPOSABLE_KERNEL_NOT_FOUND_MESSAGE)
if os.environ.get("TL_TEMPLATE_PATH", None) is None: if os.environ.get("TL_TEMPLATE_PATH", None) is None:
install_tl_template_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "src") install_tl_template_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "src")
develop_tl_template_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "src") develop_tl_template_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "src")
...@@ -151,6 +169,7 @@ is_cache_enabled = CacheState.is_enabled ...@@ -151,6 +169,7 @@ is_cache_enabled = CacheState.is_enabled
__all__ = [ __all__ = [
"CUTLASS_INCLUDE_DIR", "CUTLASS_INCLUDE_DIR",
"COMPOSABLE_KERNEL_INCLUDE_DIR",
"TVM_PYTHON_PATH", "TVM_PYTHON_PATH",
"TVM_LIBRARY_PATH", "TVM_LIBRARY_PATH",
"TILELANG_TEMPLATE_PATH", "TILELANG_TEMPLATE_PATH",
......
...@@ -8,7 +8,7 @@ import os ...@@ -8,7 +8,7 @@ import os
import tempfile import tempfile
import subprocess import subprocess
import logging import logging
from tilelang.env import TILELANG_TEMPLATE_PATH, CUTLASS_INCLUDE_DIR from tilelang.env import TILELANG_TEMPLATE_PATH
from tilelang.contrib.rocm import find_rocm_path, get_rocm_arch from tilelang.contrib.rocm import find_rocm_path, get_rocm_arch
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -31,9 +31,10 @@ class LibraryGenerator(object): ...@@ -31,9 +31,10 @@ class LibraryGenerator(object):
lib_path = self.libpath lib_path = self.libpath
return ctypes.CDLL(lib_path) return ctypes.CDLL(lib_path)
def compile_lib(self, timeout: float = None, with_tl: bool = True): def compile_lib(self, timeout: float = None):
target = self.target target = self.target
if is_cuda_target(target): if is_cuda_target(target):
from tilelang.env import CUTLASS_INCLUDE_DIR
src = tempfile.NamedTemporaryFile(mode="w", suffix=".cu", delete=False) src = tempfile.NamedTemporaryFile(mode="w", suffix=".cu", delete=False)
compute_version = "".join(get_target_compute_version(target).split(".")) compute_version = "".join(get_target_compute_version(target).split("."))
if compute_version == "90": if compute_version == "90":
...@@ -55,8 +56,12 @@ class LibraryGenerator(object): ...@@ -55,8 +56,12 @@ class LibraryGenerator(object):
"-gencode", "-gencode",
f"arch=compute_{compute_version},code=sm_{compute_version}", f"arch=compute_{compute_version},code=sm_{compute_version}",
] ]
command += [
"-I" + CUTLASS_INCLUDE_DIR,
]
elif is_hip_target(target): elif is_hip_target(target):
from tilelang.env import COMPOSABLE_KERNEL_INCLUDE_DIR
src = tempfile.NamedTemporaryFile(mode="w", suffix=".cpp", delete=False) src = tempfile.NamedTemporaryFile(mode="w", suffix=".cpp", delete=False)
libpath = src.name.replace(".cpp", ".so") libpath = src.name.replace(".cpp", ".so")
rocm_path = find_rocm_path() rocm_path = find_rocm_path()
...@@ -69,23 +74,23 @@ class LibraryGenerator(object): ...@@ -69,23 +74,23 @@ class LibraryGenerator(object):
"--shared", "--shared",
src.name, src.name,
] ]
command += [
"-I" + COMPOSABLE_KERNEL_INCLUDE_DIR,
]
elif is_cpu_target(target): elif is_cpu_target(target):
from tilelang.contrib.cc import get_cplus_compiler from tilelang.contrib.cc import get_cplus_compiler
src = tempfile.NamedTemporaryFile(mode="w", suffix=".cpp", delete=False) src = tempfile.NamedTemporaryFile(mode="w", suffix=".cpp", delete=False)
libpath = src.name.replace(".cpp", ".so") libpath = src.name.replace(".cpp", ".so")
command = [get_cplus_compiler(), "-std=c++17", "-fPIC", "-shared", src.name] command = [get_cplus_compiler(), "-std=c++17", "-fPIC", "-shared", src.name]
with_tl = False
command += [ command += [
"-I" + TILELANG_TEMPLATE_PATH, "-I" + TILELANG_TEMPLATE_PATH,
] ]
else: else:
raise ValueError(f"Unsupported target: {target}") raise ValueError(f"Unsupported target: {target}")
if with_tl:
command += [ command += [
"-I" + TILELANG_TEMPLATE_PATH, "-I" + TILELANG_TEMPLATE_PATH,
"-I" + CUTLASS_INCLUDE_DIR,
] ]
command += ["-o", libpath] command += ["-o", libpath]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment