Commit eb757608 authored by Lei Wang, committed by LeiWang1999
Browse files

[AMD] Fix missing Composable Kernel include path when compiling kernels on AMD GPUs (#334)

* [Enhancement] Add new matrix multiplication functions and tests for GEMM with transpose options

- Introduced `matmul_rs` function for flexible matrix multiplication with optional transposition.
- Added `run_gemm_rs` function to facilitate testing of the new matrix multiplication implementation.
- Expanded test coverage for GEMM with additional cases for transposition configurations.
- Corrected index usage in `gemm.h` to ensure proper matrix layout handling.

These changes enhance the GEMM functionality and improve testing capabilities for various matrix configurations.

* [Enhancement] Add Composable Kernel Path Handling in Environment Setup

- Introduced support for the Composable Kernel by adding a new environment variable `TL_COMPOSABLE_KERNEL_PATH`.
- Updated the environment setup to check for the existence of the Composable Kernel and log a warning if not found.
- Modified the `LibraryGenerator` to include the Composable Kernel include directory during compilation for HIP targets.

These changes improve the integration of the Composable Kernel into the TileLang environment, enhancing flexibility for users.
parent 85e411c8
......@@ -35,6 +35,7 @@ def _find_cuda_home() -> str:
# Toolchain and dependency locations used by the rest of the package.
CUDA_HOME = _find_cuda_home()
# Each of the following is the raw value of the named environment variable, or
# None when that variable is unset (so the `str` annotations are optimistic).
CUTLASS_INCLUDE_DIR: str = os.environ.get("TL_CUTLASS_PATH", None)
COMPOSABLE_KERNEL_INCLUDE_DIR: str = os.environ.get("TL_COMPOSABLE_KERNEL_PATH", None)
TVM_PYTHON_PATH: str = os.environ.get("TVM_IMPORT_PYTHON_PATH", None)
TVM_LIBRARY_PATH: str = os.environ.get("TVM_LIBRARY_PATH", None)
TILELANG_TEMPLATE_PATH: str = os.environ.get("TL_TEMPLATE_PATH", None)
......@@ -49,6 +50,9 @@ TILELANG_CLEAR_CACHE = os.environ.get("TILELANG_CLEAR_CACHE", "0")
# SETUP ENVIRONMENT VARIABLES
# Warning texts logged when a bundled third-party component cannot be located.
# Both halves of every message live inside ONE parenthesized expression so that
# the adjacent string literals are concatenated. Previously the parentheses were
# closed after the first literal, which left the second literal as a no-op
# expression statement and truncated the message after "expected path".
CUTLASS_NOT_FOUND_MESSAGE = (
    "CUTLASS is not installed or found in the expected path"
    ", which may lead to compilation bugs when utilize tilelang backend.")
COMPOSABLE_KERNEL_NOT_FOUND_MESSAGE = (
    "Composable Kernel is not installed or found in the expected path"
    ", which may lead to compilation bugs when utilize tilelang backend.")
TL_TEMPLATE_NOT_FOUND_MESSAGE = (
    "TileLang is not installed or found in the expected path"
    ", which may lead to compilation bugs when utilize tilelang backend.")
TVM_LIBRARY_NOT_FOUND_MESSAGE = ("TVM is not installed or found in the expected path")
......@@ -110,6 +114,20 @@ if os.environ.get("TL_CUTLASS_PATH", None) is None:
else:
logger.warning(CUTLASS_NOT_FOUND_MESSAGE)
if os.environ.get("TL_COMPOSABLE_KERNEL_PATH", None) is None:
    # No user override: fall back to the bundled Composable Kernel checkout.
    # Two layouts are probed relative to this file's directory:
    #   - installed package:  <this dir>/3rdparty/composable_kernel
    #   - source checkout:    <this dir>/../3rdparty/composable_kernel
    install_ck_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "3rdparty", "composable_kernel")
    develop_ck_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "composable_kernel")
    if os.path.exists(install_ck_path):
        # Build the include path once and publish it both as the module-level
        # constant and as the environment variable.
        COMPOSABLE_KERNEL_INCLUDE_DIR = os.path.join(install_ck_path, "include")
        os.environ["TL_COMPOSABLE_KERNEL_PATH"] = COMPOSABLE_KERNEL_INCLUDE_DIR
    elif os.path.exists(develop_ck_path):
        # The former extra `not in sys.path` test was dropped: sys.path governs
        # Python imports and has no bearing on a C++ include directory.
        COMPOSABLE_KERNEL_INCLUDE_DIR = os.path.join(develop_ck_path, "include")
        os.environ["TL_COMPOSABLE_KERNEL_PATH"] = COMPOSABLE_KERNEL_INCLUDE_DIR
    else:
        # Neither layout exists; COMPOSABLE_KERNEL_INCLUDE_DIR keeps its prior
        # value and the user is warned.
        logger.warning(COMPOSABLE_KERNEL_NOT_FOUND_MESSAGE)
if os.environ.get("TL_TEMPLATE_PATH", None) is None:
install_tl_template_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "src")
develop_tl_template_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "src")
......@@ -151,6 +169,7 @@ is_cache_enabled = CacheState.is_enabled
__all__ = [
"CUTLASS_INCLUDE_DIR",
"COMPOSABLE_KERNEL_INCLUDE_DIR",
"TVM_PYTHON_PATH",
"TVM_LIBRARY_PATH",
"TILELANG_TEMPLATE_PATH",
......
......@@ -8,7 +8,7 @@ import os
import tempfile
import subprocess
import logging
from tilelang.env import TILELANG_TEMPLATE_PATH, CUTLASS_INCLUDE_DIR
from tilelang.env import TILELANG_TEMPLATE_PATH
from tilelang.contrib.rocm import find_rocm_path, get_rocm_arch
logger = logging.getLogger(__name__)
......@@ -31,9 +31,10 @@ class LibraryGenerator(object):
lib_path = self.libpath
return ctypes.CDLL(lib_path)
def compile_lib(self, timeout: float = None, with_tl: bool = True):
def compile_lib(self, timeout: float = None):
target = self.target
if is_cuda_target(target):
from tilelang.env import CUTLASS_INCLUDE_DIR
src = tempfile.NamedTemporaryFile(mode="w", suffix=".cu", delete=False)
compute_version = "".join(get_target_compute_version(target).split("."))
if compute_version == "90":
......@@ -55,8 +56,12 @@ class LibraryGenerator(object):
"-gencode",
f"arch=compute_{compute_version},code=sm_{compute_version}",
]
command += [
"-I" + CUTLASS_INCLUDE_DIR,
]
elif is_hip_target(target):
from tilelang.env import COMPOSABLE_KERNEL_INCLUDE_DIR
src = tempfile.NamedTemporaryFile(mode="w", suffix=".cpp", delete=False)
libpath = src.name.replace(".cpp", ".so")
rocm_path = find_rocm_path()
......@@ -69,23 +74,23 @@ class LibraryGenerator(object):
"--shared",
src.name,
]
command += [
"-I" + COMPOSABLE_KERNEL_INCLUDE_DIR,
]
elif is_cpu_target(target):
from tilelang.contrib.cc import get_cplus_compiler
src = tempfile.NamedTemporaryFile(mode="w", suffix=".cpp", delete=False)
libpath = src.name.replace(".cpp", ".so")
command = [get_cplus_compiler(), "-std=c++17", "-fPIC", "-shared", src.name]
with_tl = False
command += [
"-I" + TILELANG_TEMPLATE_PATH,
]
else:
raise ValueError(f"Unsupported target: {target}")
if with_tl:
command += [
"-I" + TILELANG_TEMPLATE_PATH,
"-I" + CUTLASS_INCLUDE_DIR,
]
command += ["-o", libpath]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment