Unverified Commit 7fb06776 authored by Yichen Yan's avatar Yichen Yan Committed by GitHub
Browse files

[Backend] Add metal backend (#799)



* Reset

* Fix other CUDA issue

* fmt

* fmt

* fix cuda error

* fix

* fix

* fmt

* cleanup

* fix

* remove copyright

* trivial update

* readme update

* lint fix

---------
Co-authored-by: default avatarLei Wang <34334180+LeiWang1999@users.noreply.github.com>
Co-authored-by: default avatarLeiWang1999 <leiwang1999@outlook.com>
parent 394e17d0
name: CI Test on Metal
on: [pull_request]

env:
  PYTHON_VERSION: '3.12'
  VENV_DIR: tilelang_ci

jobs:
  format-check:
    runs-on: [macos-latest]
    permissions:
      contents: write
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
        with:
          fetch-depth: 0
          submodules: recursive
      - name: Install python via uv
        uses: astral-sh/setup-uv@v6
        with:
          enable-cache: true
          ignore-nothing-to-cache: true
          activate-environment: true
          python-version: ${{ env.PYTHON_VERSION }}
      - name: Ensure venv (local & persistent)
        # Guard with an explicit `if` so a missing requirements file does not
        # fail the step (a bare `[[ -f … ]] && cmd` exits non-zero when false).
        run: |
          if [[ -f requirements-test.txt ]]; then
            uv pip install -r requirements-test.txt --no-build-isolation
          fi
      - name: Run format check
        run: |
          set -ex
          mkdir -p build
          # run cmake to create the build directory with compile_commands.json
          cd build; cmake .. -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DUSE_METAL=ON; cd ..
          if ! output=$(./format.sh 2>&1); then
            echo "------------------------------------"
            echo "message:"
            # printf is immune to echo's escape handling; print the output once
            # (the original printed it twice via echo AND printf).
            printf '%s\n' "$output"
            echo "------------------------------------"
            exit 1
          fi
  build-test-metal:
    runs-on: [macos-latest]
    needs: format-check
    permissions:
      contents: read
    env:
      CMAKE_C_COMPILER_LAUNCHER: ccache
      CMAKE_CXX_COMPILER_LAUNCHER: ccache
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
        with:
          fetch-depth: 1
          submodules: recursive
      - name: ccache
        uses: hendrikmuhs/ccache-action@v1.2
        with:
          create-symlink: true
          # NOTE(review): this job defines no matrix, so the original
          # `matrix.os` expanded to an empty string; runner.os is the
          # intended per-OS cache discriminator.
          key: ${{ github.job }}-${{ runner.os }}
      - name: Install python via uv
        uses: astral-sh/setup-uv@v6
        with:
          enable-cache: true
          ignore-nothing-to-cache: true
          activate-environment: true
          python-version: ${{ env.PYTHON_VERSION }}
      - name: Ensure venv (local & persistent)
        run: uv pip install -r requirements-test.txt -r requirements-build.txt
      - name: Build wheel
        run: |
          source .venv/bin/activate
          uv pip install -v --no-build-isolation .
      - name: Run metal test
        run: |
          cd testing/python
          unset PYTHONPATH
          python -m pytest -k metal -v -r fE --durations=0 --timeout=3600
...@@ -108,13 +108,21 @@ endif() ...@@ -108,13 +108,21 @@ endif()
if(DEFINED TVM_PREBUILD_PATH) if(DEFINED TVM_PREBUILD_PATH)
message(STATUS "Using prebuilt TVM from ${TVM_PREBUILD_PATH}") message(STATUS "Using prebuilt TVM from ${TVM_PREBUILD_PATH}")
add_library(tvm SHARED IMPORTED) add_library(tvm SHARED IMPORTED)
find_library(TVM_LIBRARY_LOCATION
NAMES tvm
HINTS "${TVM_PREBUILD_PATH}"
)
set_target_properties(tvm PROPERTIES set_target_properties(tvm PROPERTIES
IMPORTED_LOCATION "${TVM_PREBUILD_PATH}/libtvm.so" IMPORTED_LOCATION "${TVM_LIBRARY_LOCATION}"
INTERFACE_INCLUDE_DIRECTORIES "${TVM_PREBUILD_PATH}/../include" INTERFACE_INCLUDE_DIRECTORIES "${TVM_PREBUILD_PATH}/../include"
) )
add_library(tvm_runtime SHARED IMPORTED) add_library(tvm_runtime SHARED IMPORTED)
find_library(TVM_RUNTIME_LIBRARY_LOCATION
NAMES tvm_runtime
HINTS "${TVM_PREBUILD_PATH}"
)
set_target_properties(tvm_runtime PROPERTIES set_target_properties(tvm_runtime PROPERTIES
IMPORTED_LOCATION "${TVM_PREBUILD_PATH}/libtvm_runtime.so" IMPORTED_LOCATION "${TVM_RUNTIME_LIBRARY_LOCATION}"
INTERFACE_INCLUDE_DIRECTORIES "${TVM_PREBUILD_PATH}/../include" INTERFACE_INCLUDE_DIRECTORIES "${TVM_PREBUILD_PATH}/../include"
) )
else() else()
...@@ -157,6 +165,13 @@ if(USE_ROCM) ...@@ -157,6 +165,13 @@ if(USE_ROCM)
list(APPEND TILE_LANG_SRCS ${TILE_LANG_HIP_SRCS}) list(APPEND TILE_LANG_SRCS ${TILE_LANG_HIP_SRCS})
endif() endif()
if(USE_METAL)
tilelang_file_glob(GLOB TILE_LANG_METAL_SRCS
src/target/rt_mod_metal.cc
)
list(APPEND TILE_LANG_SRCS ${TILE_LANG_METAL_SRCS})
endif()
message(STATUS "Collected source files: ${TILE_LANG_SRCS}") message(STATUS "Collected source files: ${TILE_LANG_SRCS}")
# Add TileLang object library # Add TileLang object library
...@@ -221,6 +236,9 @@ target_compile_definitions(tilelang_objs PRIVATE -DTILE_LANG_EXPORTS) ...@@ -221,6 +236,9 @@ target_compile_definitions(tilelang_objs PRIVATE -DTILE_LANG_EXPORTS)
# Shared library # Shared library
add_library(tilelang SHARED $<TARGET_OBJECTS:tilelang_objs>) add_library(tilelang SHARED $<TARGET_OBJECTS:tilelang_objs>)
target_link_libraries(tilelang PUBLIC tvm_runtime) target_link_libraries(tilelang PUBLIC tvm_runtime)
if(USE_METAL)
target_link_libraries(tilelang PUBLIC tvm)
endif()
# Static library # Static library
add_library(tilelang_static STATIC $<TARGET_OBJECTS:tilelang_objs>) add_library(tilelang_static STATIC $<TARGET_OBJECTS:tilelang_objs>)
......
...@@ -13,6 +13,7 @@ Tile Language (**tile-lang**) is a concise domain-specific language designed to ...@@ -13,6 +13,7 @@ Tile Language (**tile-lang**) is a concise domain-specific language designed to
<img src=./images/MatmulExample.png /> <img src=./images/MatmulExample.png />
## Latest News ## Latest News
- 10/07/2025 🍎: Added Apple Metal Device support, check out [Pull Request #799](https://github.com/tile-ai/tilelang/pull/799) for details.
- 09/29/2025 🎉: Thrilled to announce that ​​AscendC​​ and ​Ascend​NPU IR​​ backends targeting Huawei Ascend chips are now supported! - 09/29/2025 🎉: Thrilled to announce that ​​AscendC​​ and ​Ascend​NPU IR​​ backends targeting Huawei Ascend chips are now supported!
Check out the preview here: Check out the preview here:
🔗 [link](https://github.com/tile-ai/tilelang-ascend). 🔗 [link](https://github.com/tile-ai/tilelang-ascend).
......
#!/bin/bash
# Build tilelang with the Metal backend enabled (macOS only: uses sysctl).
set -eux

git submodule update --init --recursive

# Start from a clean build tree seeded with TVM's default cmake config.
rm -rf build
mkdir build
cp 3rdparty/tvm/cmake/config.cmake build
cd build
echo "set(USE_METAL ON)" >> config.cmake

CMAKE_C_COMPILER_LAUNCHER=ccache CMAKE_CXX_COMPILER_LAUNCHER=ccache cmake ..

# Use half the logical cores, but never fewer than one job:
# on a single-core machine the original computed "make -j0", which is invalid.
CORES=$(sysctl -n hw.logicalcpu)
MAKE_JOBS=$(( CORES / 2 ))
if [ "${MAKE_JOBS}" -lt 1 ]; then
  MAKE_JOBS=1
fi
make -j"${MAKE_JOBS}"
...@@ -32,19 +32,60 @@ logging.basicConfig( ...@@ -32,19 +32,60 @@ logging.basicConfig(
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _read_bool_env(name: str, default: bool = False) -> bool:
if env := os.environ.get(name):
env = env.lower()
if env in ['on', '1', 'true']:
return True
elif env in ['', 'off', '0', 'false']:
return False
return default
# Environment variables False/True # Environment variables False/True
PYPI_BUILD = os.environ.get("PYPI_BUILD", "False").lower() == "true" PYPI_BUILD = _read_bool_env('PYPI_BUILD')
PACKAGE_NAME = "tilelang" PACKAGE_NAME = "tilelang"
ROOT_DIR = os.path.dirname(__file__) ROOT_DIR = os.path.dirname(__file__)
# Cache directory for the Cython JIT adapter. tvm may need this directory even
# when the Cython backend is not built, so create it eagerly at import time.
CYCACHE = Path(os.path.join(ROOT_DIR, "tilelang", "jit", "adapter", "cython", ".cycache"))
# parents=True guards against missing intermediate directories (e.g. a pruned
# checkout); exist_ok=True makes this a no-op when the cache already exists,
# so the separate exists() check the original performed is unnecessary.
CYCACHE.mkdir(parents=True, exist_ok=True)
IS_LINUX = platform.system() == 'Linux'
MAYBE_METAL = platform.mac_ver()[2] == 'arm64'
# Add LLVM control environment variable # Add LLVM control environment variable
USE_LLVM = os.environ.get("USE_LLVM", "False").lower() == "true" USE_LLVM = _read_bool_env('USE_LLVM')
# Add ROCM control environment variable
USE_ROCM = _read_bool_env("USE_ROCM")
# Add ROCM control environment variable # Add ROCM control environment variable
USE_ROCM = os.environ.get("USE_ROCM", "False").lower() == "true" USE_METAL = _read_bool_env("USE_METAL", MAYBE_METAL)
# Add ROCM control environment variable
USE_CUDA = _read_bool_env("USE_CUDA", IS_LINUX and not USE_ROCM)
# Build with Debug mode # Build with Debug mode
DEBUG_MODE = os.environ.get("DEBUG_MODE", "False").lower() == "true" DEBUG_MODE = _read_bool_env('DEBUG_MODE')
# Include commit ID in wheel filename and package metadata # Include commit ID in wheel filename and package metadata
WITH_COMMITID = os.environ.get("WITH_COMMITID", "True").lower() == "true" WITH_COMMITID = _read_bool_env("WITH_COMMITID")
# Shared-library artifacts to bundle into the wheel; only the extension differs
# by platform (.so on Linux, .dylib on macOS). Order matters for lookup.
TVM_PREBUILD_ITEMS = [
    "lib" + stem + (".so" if IS_LINUX else ".dylib")
    for stem in ("tvm_runtime", "tvm", "tilelang", "tilelang_module")
]

# Debug-symbol bundles emitted on macOS (from tvm's internal cython?) that must
# not ship in the wheel; Linux builds produce none, so nothing to delete there.
TVM_PREBUILD_ITEMS_TO_DELETE = [] if IS_LINUX else [
    'libtvm_runtime.dylib.dSYM',
    'libtvm.dylib.dSYM',
]
def load_module_from_path(module_name, path): def load_module_from_path(module_name, path):
...@@ -65,24 +106,17 @@ if USE_ROCM and not ROCM_HOME: ...@@ -65,24 +106,17 @@ if USE_ROCM and not ROCM_HOME:
raise ValueError( raise ValueError(
"ROCM support is enabled (USE_ROCM=True) but ROCM_HOME is not set or detected.") "ROCM support is enabled (USE_ROCM=True) but ROCM_HOME is not set or detected.")
if not USE_ROCM and not CUDA_HOME: if USE_CUDA and not CUDA_HOME:
raise ValueError( raise ValueError(
"CUDA support is enabled by default (USE_ROCM=False) but CUDA_HOME is not set or detected.") "CUDA support is enabled by default on linux if `USE_ROCM=False`," \
" but CUDA_HOME is not set or detected.")
# Ensure one of CUDA or ROCM is available # Ensure one of CUDA or ROCM is available
if not (CUDA_HOME or ROCM_HOME): if IS_LINUX and not (CUDA_HOME or ROCM_HOME):
raise ValueError( raise ValueError(
"Failed to automatically detect CUDA or ROCM installation. Please set the CUDA_HOME or ROCM_HOME environment variable manually (e.g., export CUDA_HOME=/usr/local/cuda or export ROCM_HOME=/opt/rocm)." "Failed to automatically detect CUDA or ROCM installation. Please set the CUDA_HOME or ROCM_HOME environment variable manually (e.g., export CUDA_HOME=/usr/local/cuda or export ROCM_HOME=/opt/rocm)."
) )
# TileLang only supports Linux platform
assert sys.platform.startswith("linux"), "TileLang only supports Linux platform (including WSL)."
def _is_linux_like():
return (sys.platform == "darwin" or sys.platform.startswith("linux") or
sys.platform.startswith("freebsd"))
def get_path(*filepath) -> str: def get_path(*filepath) -> str:
return os.path.join(ROOT_DIR, *filepath) return os.path.join(ROOT_DIR, *filepath)
...@@ -144,7 +178,9 @@ def get_rocm_version(): ...@@ -144,7 +178,9 @@ def get_rocm_version():
return Version("5.0.0") return Version("5.0.0")
def get_tilelang_version(with_cuda=True, with_system_info=True, with_commit_id=False) -> str: def get_tilelang_version(with_cuda=USE_CUDA,
with_system_info=not MAYBE_METAL,
with_commit_id=False) -> str:
version = find_version(get_path(".", "VERSION")) version = find_version(get_path(".", "VERSION"))
local_version_parts = [] local_version_parts = []
if with_system_info: if with_system_info:
...@@ -194,9 +230,6 @@ def get_cplus_compiler(): ...@@ -194,9 +230,6 @@ def get_cplus_compiler():
The path to the default C/C++ compiler, or None if none was found. The path to the default C/C++ compiler, or None if none was found.
""" """
if not _is_linux_like():
return None
env_cxx = os.environ.get("CXX") or os.environ.get("CC") env_cxx = os.environ.get("CXX") or os.environ.get("CC")
if env_cxx: if env_cxx:
return env_cxx return env_cxx
...@@ -371,6 +404,8 @@ def patch_libs(libpath): ...@@ -371,6 +404,8 @@ def patch_libs(libpath):
and have a hard-coded rpath. and have a hard-coded rpath.
Set rpath to the directory of libs so auditwheel works well. Set rpath to the directory of libs so auditwheel works well.
""" """
if not IS_LINUX:
return
# check if patchelf is installed # check if patchelf is installed
# find patchelf in the system # find patchelf in the system
patchelf_path = shutil.which("patchelf") patchelf_path = shutil.which("patchelf")
...@@ -432,13 +467,6 @@ class TileLangBuilPydCommand(build_py): ...@@ -432,13 +467,6 @@ class TileLangBuilPydCommand(build_py):
os.makedirs(target_dir) os.makedirs(target_dir)
shutil.copy2(source_dir, target_dir) shutil.copy2(source_dir, target_dir)
TVM_PREBUILD_ITEMS = [
"libtvm_runtime.so",
"libtvm.so",
"libtilelang.so",
"libtilelang_module.so",
]
potential_dirs = [ potential_dirs = [
ext_output_dir, ext_output_dir,
self.build_lib, self.build_lib,
...@@ -468,6 +496,14 @@ class TileLangBuilPydCommand(build_py): ...@@ -468,6 +496,14 @@ class TileLangBuilPydCommand(build_py):
else: else:
logger.info(f"WARNING: {item} not found in any expected directories!") logger.info(f"WARNING: {item} not found in any expected directories!")
for item in TVM_PREBUILD_ITEMS_TO_DELETE:
source_lib_file = None
for dir in potential_dirs:
candidate = os.path.join(dir, item)
if os.path.exists(candidate):
shutil.rmtree(candidate)
break
TVM_CONFIG_ITEMS = [ TVM_CONFIG_ITEMS = [
f"{build_temp_dir}/config.cmake", f"{build_temp_dir}/config.cmake",
] ]
...@@ -587,10 +623,10 @@ class CMakeExtension(Extension): ...@@ -587,10 +623,10 @@ class CMakeExtension(Extension):
:param sourcedir: Directory containing the top-level CMakeLists.txt. :param sourcedir: Directory containing the top-level CMakeLists.txt.
""" """
def __init__(self, name, sourcedir=""): def __init__(self, name, sourcedir="", **kwargs):
# We pass an empty 'sources' list because # We pass an empty 'sources' list because
# the actual build is handled by CMake, not setuptools. # the actual build is handled by CMake, not setuptools.
super().__init__(name=name, sources=[]) super().__init__(name=name, sources=[], **kwargs)
# Convert the source directory to an absolute path # Convert the source directory to an absolute path
# so that CMake can correctly locate the CMakeLists.txt. # so that CMake can correctly locate the CMakeLists.txt.
...@@ -642,7 +678,7 @@ class TilelangExtensionBuild(build_ext): ...@@ -642,7 +678,7 @@ class TilelangExtensionBuild(build_ext):
# To make it works with editable install, # To make it works with editable install,
# we need to copy the lib*.so files to the tilelang/lib directory # we need to copy the lib*.so files to the tilelang/lib directory
import glob import glob
files = glob.glob("*.so") files = glob.glob("*.so" if IS_LINUX else "*.dylib")
if os.path.exists(PACKAGE_NAME): if os.path.exists(PACKAGE_NAME):
target_lib_dir = os.path.join(PACKAGE_NAME, "lib") target_lib_dir = os.path.join(PACKAGE_NAME, "lib")
for file in files: for file in files:
...@@ -724,7 +760,10 @@ class TilelangExtensionBuild(build_ext): ...@@ -724,7 +760,10 @@ class TilelangExtensionBuild(build_ext):
os.system(f"{cython} {cython_wrapper_path} --cplus -o {source_path}") os.system(f"{cython} {cython_wrapper_path} --cplus -o {source_path}")
python_include_path = sysconfig.get_path("include") python_include_path = sysconfig.get_path("include")
cc = get_cplus_compiler() cc = get_cplus_compiler()
if MAYBE_METAL:
cc += ' -Wl,-undefined,dynamic_lookup'
command = f"{cc} -shared -pthread -fPIC -fwrapv -O2 -Wall -fno-strict-aliasing -I{python_include_path} {source_path} -o {temp_path}" command = f"{cc} -shared -pthread -fPIC -fwrapv -O2 -Wall -fno-strict-aliasing -I{python_include_path} {source_path} -o {temp_path}"
logger.info(command)
os.system(command) os.system(command)
# rename the temp file to the library file # rename the temp file to the library file
...@@ -783,7 +822,7 @@ class TilelangExtensionBuild(build_ext): ...@@ -783,7 +822,7 @@ class TilelangExtensionBuild(build_ext):
"-G", "-G",
"Ninja", "Ninja",
] ]
if not USE_ROCM: if USE_CUDA and not USE_ROCM:
cmake_args.append(f"-DCMAKE_CUDA_COMPILER={os.path.join(CUDA_HOME, 'bin', 'nvcc')}") cmake_args.append(f"-DCMAKE_CUDA_COMPILER={os.path.join(CUDA_HOME, 'bin', 'nvcc')}")
# Create the temporary build directory (if it doesn't exist). # Create the temporary build directory (if it doesn't exist).
...@@ -804,12 +843,17 @@ class TilelangExtensionBuild(build_ext): ...@@ -804,12 +843,17 @@ class TilelangExtensionBuild(build_ext):
content_lines.append(f"set(USE_LLVM {llvm_config_path})") content_lines.append(f"set(USE_LLVM {llvm_config_path})")
# Append GPU backend configuration based on environment # Append GPU backend configuration based on environment
if USE_ROCM: if USE_METAL:
content_lines += [
"set(USE_METAL ON)",
"set(USE_ROCM OFF)",
]
elif USE_ROCM:
content_lines += [ content_lines += [
f"set(USE_ROCM {ROCM_HOME})", f"set(USE_ROCM {ROCM_HOME})",
"set(USE_CUDA OFF)", "set(USE_CUDA OFF)",
] ]
else: elif CUDA_HOME:
content_lines += [ content_lines += [
f"set(USE_CUDA {CUDA_HOME})", f"set(USE_CUDA {CUDA_HOME})",
"set(USE_ROCM OFF)", "set(USE_ROCM OFF)",
...@@ -846,6 +890,12 @@ class TilelangExtensionBuild(build_ext): ...@@ -846,6 +890,12 @@ class TilelangExtensionBuild(build_ext):
cwd=build_temp) cwd=build_temp)
ext_modules = [
CMakeExtension("TileLangCXX", sourcedir="."),
]
if not MAYBE_METAL:
ext_modules.append(CythonExtension("TileLangCython", sourcedir="."))
setup( setup(
name=PACKAGE_NAME, name=PACKAGE_NAME,
version=(get_tilelang_version(with_cuda=False, with_system_info=False, with_commit_id=False) version=(get_tilelang_version(with_cuda=False, with_system_info=False, with_commit_id=False)
......
// Currently mps backend use the codegen from tvm without modification.
// But in the future we're likely to add functions on top of that.
// Added an empty file for now.
import tilelang
from tilelang import tvm as tvm
import tilelang.testing
import tilelang.language as T
import torch
@tilelang.jit(execution_backend='torch')
def matmul(M, N, K, block_M, block_N, block_K, dtype="float32", accum_dtype="float"):
    """Build a tiled GEMM kernel computing C = A @ B, JIT-compiled through the
    torch execution backend (used here for the Metal/MPS path).

    M, N, K are the full problem sizes; block_M/block_N/block_K are the tile
    sizes. Returns the compiled prim_func.
    """

    @T.prim_func
    def gemm(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        # One kernel block per (block_N, block_M) output tile, 128 threads each.
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype, scope='shared')
            B_shared = T.alloc_shared((block_K, block_N), dtype, scope='shared')
            # Accumulate in accum_dtype to limit rounding error for low-precision inputs.
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            # num_stages=0: no software pipelining over the K-dimension loop.
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0):
                T.copy(A[by * block_M, ko * block_K], A_shared, coalesced_width=2)
                T.copy(B[ko * block_K, bx * block_N], B_shared, coalesced_width=2)
                for i, j, k in T.Parallel(block_M, block_N, block_K):
                    C_local[i, j] += A_shared[i, k] * B_shared[k, j]
            # Write the accumulated tile back to global C.
            T.copy(C_local, C[by * block_M, bx * block_N], coalesced_width=2)

    return gemm
def assert_gemm(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    dtype="float32",
    accum_dtype="float",
    atol=1e-8,
):
    """Compile the tiled GEMM for the given shape and validate it on the MPS
    device against torch's reference matmul, within ``atol`` tolerance."""
    kernel = matmul(M, N, K, block_M, block_N, block_K, dtype=dtype, accum_dtype=accum_dtype)
    torch_dtype = getattr(torch, dtype)

    # Integer dtypes get uniform values in [0, 100); floats get standard normals.
    if 'int' in dtype:
        lhs = torch.randint(100, (M, K), dtype=torch_dtype, device='mps')
        rhs = torch.randint(100, (K, N), dtype=torch_dtype, device='mps')
    else:
        lhs = torch.randn(M, K, dtype=torch_dtype, device='mps')
        rhs = torch.randn(K, N, dtype=torch_dtype, device='mps')
    out = torch.zeros(M, N, dtype=torch_dtype, device='mps')

    kernel(lhs, rhs, out)

    expected = lhs @ rhs
    assert torch.allclose(expected, out, atol=atol)
    # The compiled kernel should expose its generated source.
    assert kernel.kernel_source is not None
@tilelang.testing.requires_metal
def test_gemm_float32():
    # 1024^3 GEMM with 16x16x16 tiles; fp32 uses the default tight tolerance.
    assert_gemm(1024, 1024, 1024, 16, 16, 16)
@tilelang.testing.requires_metal
def test_gemm_float16():
    # fp16 accumulates much more rounding error at K=1024, hence atol=1.
    assert_gemm(1024, 1024, 1024, 16, 16, 16, dtype='float16', atol=1)
@tilelang.testing.requires_metal
def test_gemm_int32():
    # int32 inputs are exact; atol=1 matches the fp16 case for uniformity.
    assert_gemm(1024, 1024, 1024, 16, 16, 16, dtype='int32', atol=1)
if __name__ == "__main__":
    # Only run the suite when an MPS (Apple GPU) device is actually available.
    if torch.mps.is_available():
        tilelang.testing.main()
...@@ -465,13 +465,24 @@ class AutoTuner: ...@@ -465,13 +465,24 @@ class AutoTuner:
futures = [] futures = []
future_to_index = {} future_to_index = {}
def device_wrapper(func, device, **config_arg): def cuda_device_wrapper(func, device):
torch.cuda.set_device(device)
return func(**config_arg) def inner(**config_arg):
torch.cuda.set_device(device)
return func(**config_arg)
return inner
for i, config_arg in enumerate(config_args): for i, config_arg in enumerate(config_args):
compile_func = self.jit_compile
if torch.cuda.is_available():
device = torch.cuda.current_device()
compile_func = cuda_device_wrapper(self.jit_compile, device)
future = pool.submit( future = pool.submit(
functools.partial(device_wrapper, self.jit_compile, torch.cuda.current_device()), compile_func,
**config_arg, **config_arg,
) )
futures.append(future) futures.append(future)
...@@ -543,7 +554,7 @@ class AutoTuner: ...@@ -543,7 +554,7 @@ class AutoTuner:
func=best_kernel.prim_func, func=best_kernel.prim_func,
kernel=best_kernel) kernel=best_kernel)
if self.compile_args.execution_backend == "dlpack": if self.compile_args.execution_backend in ("dlpack", "torch"):
logger.warning("DLPack backend does not support cache saving to disk.") logger.warning("DLPack backend does not support cache saving to disk.")
else: else:
with self._lock: with self._lock:
......
...@@ -191,8 +191,8 @@ class KernelCache: ...@@ -191,8 +191,8 @@ class KernelCache:
pass_configs=pass_configs, pass_configs=pass_configs,
compile_flags=compile_flags, compile_flags=compile_flags,
) )
if execution_backend == "dlpack": if execution_backend in ("dlpack", "torch"):
self.logger.warning("DLPack backend does not support cache saving to disk.") self.logger.warning("DLPack or torch backend does not support cache saving to disk.")
else: else:
with self._lock: with self._lock:
if env.is_cache_enabled(): if env.is_cache_enabled():
......
from .arch_base import TileDevice from .arch_base import TileDevice
from .cuda import CUDA from .cuda import *
from .cpu import CPU from .cpu import *
from .cdna import CDNA from .cdna import *
from .metal import *
from typing import Union from typing import Union
from tvm.target import Target from tvm.target import Target
import torch import torch
...@@ -17,6 +18,8 @@ def get_arch(target: Union[str, Target] = "cuda") -> TileDevice: ...@@ -17,6 +18,8 @@ def get_arch(target: Union[str, Target] = "cuda") -> TileDevice:
return CPU(target) return CPU(target)
elif target.kind.name == "hip": elif target.kind.name == "hip":
return CDNA(target) return CDNA(target)
elif target.kind.name == "metal":
return METAL(target)
else: else:
raise ValueError(f"Unsupported target: {target.kind.name}") raise ValueError(f"Unsupported target: {target.kind.name}")
...@@ -28,18 +31,25 @@ def auto_infer_current_arch() -> TileDevice: ...@@ -28,18 +31,25 @@ def auto_infer_current_arch() -> TileDevice:
return get_arch("hip") return get_arch("hip")
if torch.cuda.is_available(): if torch.cuda.is_available():
return get_arch("cuda") return get_arch("cuda")
elif torch.mps.is_available():
return get_arch("metal")
else: else:
return get_arch("llvm") return get_arch("llvm")
from .cpu import is_cpu_arch # noqa: F401 __all__ = [
from .cuda import ( 'is_cpu_arch',
is_cuda_arch, # noqa: F401 'is_cuda_arch',
is_volta_arch, # noqa: F401 'is_volta_arch',
is_ampere_arch, # noqa: F401 'is_ampere_arch',
is_ada_arch, # noqa: F401 'is_ada_arch',
is_hopper_arch, # noqa: F401 'is_hopper_arch',
is_tensorcore_supported_precision, # noqa: F401 'is_tensorcore_supported_precision',
has_mma_support, # noqa: F401 'has_mma_support',
) 'is_cdna_arch',
from .cdna import is_cdna_arch # noqa: F401 'is_metal_arch',
'CUDA',
'CDNA',
'METAL',
'CPU',
]
...@@ -30,3 +30,9 @@ class CDNA(TileDevice): ...@@ -30,3 +30,9 @@ class CDNA(TileDevice):
self.transaction_size: List[int] = [32, 128] # in bytes self.transaction_size: List[int] = [32, 128] # in bytes
self.bandwidth: List[int] = [1300, 14000] self.bandwidth: List[int] = [1300, 14000]
__all__ = [
'is_cdna_arch',
'CDNA',
]
...@@ -18,3 +18,9 @@ class CPU(TileDevice): ...@@ -18,3 +18,9 @@ class CPU(TileDevice):
raise RuntimeError("Cannot find cpu device 0.") raise RuntimeError("Cannot find cpu device 0.")
self.device: tvm.runtime.Device = device self.device: tvm.runtime.Device = device
self.platform: str = "CPU" self.platform: str = "CPU"
__all__ = [
'is_cpu_arch',
'CPU',
]
...@@ -145,3 +145,15 @@ class CUDA(TileDevice): ...@@ -145,3 +145,15 @@ class CUDA(TileDevice):
def __repr__(self): def __repr__(self):
return f"CUDA({self.target})" return f"CUDA({self.target})"
__all__ = [
'is_cuda_arch',
'is_volta_arch',
'is_ampere_arch',
'is_ada_arch',
'is_hopper_arch',
'is_tensorcore_supported_precision',
'has_mma_support',
"CUDA",
]
from tvm.target import Target
from .arch_base import TileDevice
def is_metal_arch(arch: TileDevice) -> bool:
    """Return True if ``arch`` is a Metal (Apple GPU) tile-device wrapper."""
    return isinstance(arch, METAL)
class METAL(TileDevice):
    """Tile-device description for Apple Metal GPU targets.

    Intentionally minimal for now: it only records the resolved TVM target.
    Capability fields (bandwidth, transaction sizes, ...) may be added as the
    Metal backend matures.
    """

    def __init__(self, target: Target | str):
        # Accept either a target string (e.g. "metal") or a Target object.
        if isinstance(target, str):
            target = Target(target)
        self.target = target

    def __repr__(self):
        # Mirrors the CUDA arch class's repr style for consistent logging.
        return f"METAL({self.target})"
__all__ = [
'is_metal_arch',
'METAL',
]
...@@ -19,6 +19,7 @@ import functools ...@@ -19,6 +19,7 @@ import functools
import os import os
import shutil import shutil
import subprocess import subprocess
import platform
# pylint: disable=invalid-name # pylint: disable=invalid-name
import sys import sys
...@@ -89,6 +90,10 @@ def get_cplus_compiler(): ...@@ -89,6 +90,10 @@ def get_cplus_compiler():
return None return None
def is_darwin():
    """Return True when the current OS is macOS (kernel name 'Darwin')."""
    system_name = platform.system()
    return system_name == 'Darwin'
def create_shared(output, objects, options=None, cc=None, cwd=None, ccache_env=None): def create_shared(output, objects, options=None, cc=None, cwd=None, ccache_env=None):
"""Create shared library. """Create shared library.
......
...@@ -181,6 +181,8 @@ def device_codegen_without_compile(device_mod: tvm.IRModule, target: Target) -> ...@@ -181,6 +181,8 @@ def device_codegen_without_compile(device_mod: tvm.IRModule, target: Target) ->
device_mod = tvm.ffi.get_global_func("target.build.llvm")(device_mod, target) device_mod = tvm.ffi.get_global_func("target.build.llvm")(device_mod, target)
elif target.kind.name == "webgpu": elif target.kind.name == "webgpu":
device_mod = tvm.ffi.get_global_func("target.build.tilelang_webgpu")(device_mod, target) device_mod = tvm.ffi.get_global_func("target.build.tilelang_webgpu")(device_mod, target)
elif target.kind.name == "metal":
device_mod = tvm.ffi.get_global_func("target.build.metal")(device_mod, target)
else: else:
raise ValueError(f"Target {target.kind.name} is not supported") raise ValueError(f"Target {target.kind.name} is not supported")
......
...@@ -16,6 +16,7 @@ from typing import ( ...@@ -16,6 +16,7 @@ from typing import (
Optional, Optional,
) )
from tilelang import tvm as tvm from tilelang import tvm as tvm
from tilelang.jit.adapter.utils import is_metal_target
from tvm.tir import PrimFunc from tvm.tir import PrimFunc
from tvm.target import Target from tvm.target import Target
...@@ -74,6 +75,9 @@ def compile( ...@@ -74,6 +75,9 @@ def compile(
# This path is not a performance critical path, so we can afford to convert the target. # This path is not a performance critical path, so we can afford to convert the target.
target = Target(determine_target(target)) target = Target(determine_target(target))
if is_metal_target(target):
assert execution_backend == 'torch', 'Currently metal target only support `tl.jit(execution_backend="torch")`'
return cached( return cached(
func=func, func=func,
out_idx=out_idx, out_idx=out_idx,
...@@ -264,7 +268,7 @@ def jit( # This is the new public interface ...@@ -264,7 +268,7 @@ def jit( # This is the new public interface
Compilation target for TVM (e.g., "cuda", "llvm"). Defaults to "auto". Compilation target for TVM (e.g., "cuda", "llvm"). Defaults to "auto".
target_host : Union[str, Target], optional target_host : Union[str, Target], optional
Target host for cross-compilation. Defaults to None. Target host for cross-compilation. Defaults to None.
execution_backend : Literal["dlpack", "ctypes", "cython"], optional execution_backend : Literal["dlpack", "ctypes", "cython", "nvrtc"], optional
Backend for kernel execution and argument passing. Defaults to "cython". Backend for kernel execution and argument passing. Defaults to "cython".
verbose : bool, optional verbose : bool, optional
Enables verbose logging during compilation. Defaults to False. Enables verbose logging during compilation. Defaults to False.
......
...@@ -3,3 +3,4 @@ from .dlpack import TorchDLPackKernelAdapter # noqa: F401 ...@@ -3,3 +3,4 @@ from .dlpack import TorchDLPackKernelAdapter # noqa: F401
from .ctypes import CtypesKernelAdapter # noqa: F401 from .ctypes import CtypesKernelAdapter # noqa: F401
from .cython import CythonKernelAdapter # noqa: F401 from .cython import CythonKernelAdapter # noqa: F401
from .nvrtc import NVRTCKernelAdapter # noqa: F401 from .nvrtc import NVRTCKernelAdapter # noqa: F401
from .torch import MetalKernelAdapter # noqa: F401
...@@ -21,11 +21,11 @@ from tvm.relax import TensorType ...@@ -21,11 +21,11 @@ from tvm.relax import TensorType
from tilelang.jit.adapter.base import BaseKernelAdapter from tilelang.jit.adapter.base import BaseKernelAdapter
from tilelang.jit.adapter.wrapper import TLWrapper from tilelang.jit.adapter.wrapper import TLWrapper
from tilelang.jit.adapter.libgen import LibraryGenerator from tilelang.jit.adapter.libgen import LibraryGenerator
from tilelang.jit.adapter.utils import is_cuda_target, is_hip_target, is_cpu_target from tilelang.jit.adapter.utils import is_cuda_target, is_hip_target, is_cpu_target, is_metal_target
from tilelang.utils.target import determine_target from tilelang.utils.target import determine_target
from tilelang.utils.language import retrieve_func_from_module from tilelang.utils.language import retrieve_func_from_module
from tilelang.utils.tensor import map_torch_type from tilelang.utils.tensor import map_torch_type
from tilelang.contrib.cc import get_cplus_compiler from tilelang.contrib.cc import get_cplus_compiler, is_darwin
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -153,7 +153,9 @@ with open(cython_wrapper_path, "r") as f: ...@@ -153,7 +153,9 @@ with open(cython_wrapper_path, "r") as f:
os.system(f"{cython} {cython_wrapper_path} --cplus -o {source_path}") os.system(f"{cython} {cython_wrapper_path} --cplus -o {source_path}")
python_include_path = sysconfig.get_path("include") python_include_path = sysconfig.get_path("include")
cc = get_cplus_compiler() cc = get_cplus_compiler()
command = f"{cc} -shared -pthread -fPIC -fwrapv -O2 -Wall -fno-strict-aliasing -I{python_include_path} {source_path} -o {temp_path}" dynamic_flag = '-Wl,-undefined,dynamic_lookup' if is_darwin(
) else '-Wl,--unresolved-symbols=ignore-all'
command = f"{cc} -shared -pthread -fPIC -fwrapv -O2 -Wall -fno-strict-aliasing {dynamic_flag} -I{python_include_path} {source_path} -o {temp_path}"
os.system(command) os.system(command)
# rename the temp file to the library file # rename the temp file to the library file
...@@ -450,6 +452,8 @@ class CythonKernelAdapter(BaseKernelAdapter): ...@@ -450,6 +452,8 @@ class CythonKernelAdapter(BaseKernelAdapter):
device = torch.device("cuda") device = torch.device("cuda")
elif is_cpu_target(self.target): elif is_cpu_target(self.target):
device = torch.device("cpu") device = torch.device("cpu")
elif is_metal_target(self.target):
device = torch.device("mps")
else: else:
raise ValueError(f"Unsupported target: {self.target}") raise ValueError(f"Unsupported target: {self.target}")
......
# Re-export the Metal/MPS (torch-execution) adapter as this package's public API.
from .metal import MetalKernelAdapter

__all__ = ['MetalKernelAdapter']
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment