"include/vscode:/vscode.git/clone" did not exist on "a054f7d604d3bfee9e4ad410df15397bc354ae3d"
Unverified commit 7fb06776, authored by Yichen Yan and committed by GitHub.

[Backend] Add metal backend (#799)



* Reset

* Fix other CUDA issue

* fmt

* fmt

* fix cuda error

* fix

* fix

* fmt

* cleanup

* fix

* remove copyright

* trivial update

* readme update

* lint fix

---------
Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
parent 394e17d0
name: CI Test on Metal
on: [pull_request]

env:
  PYTHON_VERSION: '3.12'
  VENV_DIR: tilelang_ci

jobs:
  format-check:
    runs-on: [macos-latest]
    permissions:
      contents: write
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
        with:
          fetch-depth: 0
          submodules: recursive
      - name: Install python via uv
        uses: astral-sh/setup-uv@v6
        with:
          enable-cache: true
          ignore-nothing-to-cache: true
          activate-environment: true
          python-version: ${{ env.PYTHON_VERSION }}
      - name: Ensure venv (local & persistent)
        run: |
          if [[ -f requirements-test.txt ]]; then
            uv pip install -r requirements-test.txt --no-build-isolation
          fi
      - name: Run format check
        run: |
          set -ex
          mkdir -p build
          # run cmake to create the build directory with compile_commands.json
          cd build; cmake .. -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DUSE_METAL=ON; cd ..
          if ! output=$(./format.sh 2>&1); then
            echo "------------------------------------"
            echo "message:"
            printf '%s\n' "$output"
            echo "------------------------------------"
            exit 1
          fi

  build-test-metal:
    runs-on: [macos-latest]
    needs: format-check
    permissions:
      contents: read
    env:
      CMAKE_C_COMPILER_LAUNCHER: ccache
      CMAKE_CXX_COMPILER_LAUNCHER: ccache
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
        with:
          fetch-depth: 1
          submodules: recursive
      - name: ccache
        uses: hendrikmuhs/ccache-action@v1.2
        with:
          create-symlink: true
          key: ${{ github.job }}-${{ runner.os }}
      - name: Install python via uv
        uses: astral-sh/setup-uv@v6
        with:
          enable-cache: true
          ignore-nothing-to-cache: true
          activate-environment: true
          python-version: ${{ env.PYTHON_VERSION }}
      - name: Ensure venv (local & persistent)
        run: uv pip install -r requirements-test.txt -r requirements-build.txt
      - name: Run metal test
        run: |
          cd testing/python
          unset PYTHONPATH
          python -m pytest -k metal -v -r fE --durations=0 --timeout=3600
......@@ -108,13 +108,21 @@ endif()
if(DEFINED TVM_PREBUILD_PATH)
  message(STATUS "Using prebuilt TVM from ${TVM_PREBUILD_PATH}")
  add_library(tvm SHARED IMPORTED)
+ find_library(TVM_LIBRARY_LOCATION
+   NAMES tvm
+   HINTS "${TVM_PREBUILD_PATH}"
+ )
  set_target_properties(tvm PROPERTIES
-   IMPORTED_LOCATION "${TVM_PREBUILD_PATH}/libtvm.so"
+   IMPORTED_LOCATION "${TVM_LIBRARY_LOCATION}"
    INTERFACE_INCLUDE_DIRECTORIES "${TVM_PREBUILD_PATH}/../include"
  )
  add_library(tvm_runtime SHARED IMPORTED)
+ find_library(TVM_RUNTIME_LIBRARY_LOCATION
+   NAMES tvm_runtime
+   HINTS "${TVM_PREBUILD_PATH}"
+ )
  set_target_properties(tvm_runtime PROPERTIES
-   IMPORTED_LOCATION "${TVM_PREBUILD_PATH}/libtvm_runtime.so"
+   IMPORTED_LOCATION "${TVM_RUNTIME_LIBRARY_LOCATION}"
    INTERFACE_INCLUDE_DIRECTORIES "${TVM_PREBUILD_PATH}/../include"
  )
else()
......@@ -157,6 +165,13 @@ if(USE_ROCM)
  list(APPEND TILE_LANG_SRCS ${TILE_LANG_HIP_SRCS})
endif()

+if(USE_METAL)
+  tilelang_file_glob(GLOB TILE_LANG_METAL_SRCS
+    src/target/rt_mod_metal.cc
+  )
+  list(APPEND TILE_LANG_SRCS ${TILE_LANG_METAL_SRCS})
+endif()

message(STATUS "Collected source files: ${TILE_LANG_SRCS}")

# Add TileLang object library
......@@ -221,6 +236,9 @@ target_compile_definitions(tilelang_objs PRIVATE -DTILE_LANG_EXPORTS)
# Shared library
add_library(tilelang SHARED $<TARGET_OBJECTS:tilelang_objs>)
target_link_libraries(tilelang PUBLIC tvm_runtime)
+if(USE_METAL)
+  target_link_libraries(tilelang PUBLIC tvm)
+endif()

# Static library
add_library(tilelang_static STATIC $<TARGET_OBJECTS:tilelang_objs>)
......
......@@ -13,6 +13,7 @@ Tile Language (**tile-lang**) is a concise domain-specific language designed to
<img src=./images/MatmulExample.png />

## Latest News
+- 10/07/2025 🍎: Added Apple Metal device support; check out [Pull Request #799](https://github.com/tile-ai/tilelang/pull/799) for details.
- 09/29/2025 🎉: Thrilled to announce that AscendC and AscendNPU IR backends targeting Huawei Ascend chips are now supported!
  Check out the preview here:
  🔗 [link](https://github.com/tile-ai/tilelang-ascend).
......
#!/bin/bash
set -eux
git submodule update --init --recursive
rm -rf build
mkdir build
cp 3rdparty/tvm/cmake/config.cmake build
cd build
echo "set(USE_METAL ON)" >> config.cmake
CMAKE_C_COMPILER_LAUNCHER=ccache CMAKE_CXX_COMPILER_LAUNCHER=ccache cmake ..
CORES=$(sysctl -n hw.logicalcpu)
MAKE_JOBS=$(( CORES / 2 ))
make -j${MAKE_JOBS}
......@@ -32,19 +32,60 @@ logging.basicConfig(
logger = logging.getLogger(__name__)

+def _read_bool_env(name: str, default: bool = False) -> bool:
+    if env := os.environ.get(name):
+        env = env.lower()
+        if env in ['on', '1', 'true']:
+            return True
+        elif env in ['', 'off', '0', 'false']:
+            return False
+    return default

# Environment variables False/True
-PYPI_BUILD = os.environ.get("PYPI_BUILD", "False").lower() == "true"
+PYPI_BUILD = _read_bool_env('PYPI_BUILD')
PACKAGE_NAME = "tilelang"
ROOT_DIR = os.path.dirname(__file__)
+CYCACHE = Path(os.path.join(ROOT_DIR, "tilelang", "jit", "adapter", "cython", ".cycache"))
+if not CYCACHE.exists():
+    # tvm may need this; the cython backend is not always built, so create the directory here.
+    CYCACHE.mkdir(exist_ok=True)
+IS_LINUX = platform.system() == 'Linux'
+# Apple Silicon: platform.mac_ver() reports machine type 'arm64' (empty string off macOS)
+MAYBE_METAL = platform.mac_ver()[2] == 'arm64'

# Add LLVM control environment variable
-USE_LLVM = os.environ.get("USE_LLVM", "False").lower() == "true"
+USE_LLVM = _read_bool_env('USE_LLVM')
# Add ROCM control environment variable
-USE_ROCM = os.environ.get("USE_ROCM", "False").lower() == "true"
+USE_ROCM = _read_bool_env("USE_ROCM")
+# Add Metal control environment variable (defaults on for Apple Silicon)
+USE_METAL = _read_bool_env("USE_METAL", MAYBE_METAL)
+# Add CUDA control environment variable (defaults on for Linux without ROCM)
+USE_CUDA = _read_bool_env("USE_CUDA", IS_LINUX and not USE_ROCM)
# Build with Debug mode
-DEBUG_MODE = os.environ.get("DEBUG_MODE", "False").lower() == "true"
+DEBUG_MODE = _read_bool_env('DEBUG_MODE')
# Include commit ID in wheel filename and package metadata
-WITH_COMMITID = os.environ.get("WITH_COMMITID", "True").lower() == "true"
+WITH_COMMITID = _read_bool_env("WITH_COMMITID")

+TVM_PREBUILD_ITEMS = [
+    "libtvm_runtime.so",
+    "libtvm.so",
+    "libtilelang.so",
+    "libtilelang_module.so",
+] if IS_LINUX else [
+    "libtvm_runtime.dylib",
+    "libtvm.dylib",
+    "libtilelang.dylib",
+    "libtilelang_module.dylib",
+]
+# dSYM debug-symbol bundles emitted by tvm's macOS build; stripped from wheels below
+TVM_PREBUILD_ITEMS_TO_DELETE = [] if IS_LINUX else [
+    'libtvm_runtime.dylib.dSYM',
+    'libtvm.dylib.dSYM',
+]
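As an aside, here are a few illustrative checks of how `_read_bool_env` above behaves; this is an editor's sketch assuming the helper is in scope, not part of the commit:

import os

os.environ['USE_METAL'] = 'ON'
assert _read_bool_env('USE_METAL') is True          # 'on', '1', 'true' -> True
os.environ['USE_METAL'] = 'off'
assert _read_bool_env('USE_METAL') is False         # 'off', '0', 'false' -> False
os.environ.pop('USE_METAL')
assert _read_bool_env('USE_METAL', True) is True    # unset (or '') -> the default
os.environ['USE_METAL'] = 'yes'
assert _read_bool_env('USE_METAL') is False         # unrecognized value -> the default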
def load_module_from_path(module_name, path):
......@@ -65,24 +106,17 @@ if USE_ROCM and not ROCM_HOME:
    raise ValueError(
        "ROCM support is enabled (USE_ROCM=True) but ROCM_HOME is not set or detected.")

-if not USE_ROCM and not CUDA_HOME:
+if USE_CUDA and not CUDA_HOME:
    raise ValueError(
-        "CUDA support is enabled by default (USE_ROCM=False) but CUDA_HOME is not set or detected.")
+        "CUDA support is enabled by default on Linux if `USE_ROCM=False`,"
+        " but CUDA_HOME is not set or detected.")

# Ensure one of CUDA or ROCM is available
-if not (CUDA_HOME or ROCM_HOME):
+if IS_LINUX and not (CUDA_HOME or ROCM_HOME):
    raise ValueError(
        "Failed to automatically detect CUDA or ROCM installation. Please set the CUDA_HOME or ROCM_HOME environment variable manually (e.g., export CUDA_HOME=/usr/local/cuda or export ROCM_HOME=/opt/rocm)."
    )

-# TileLang only supports Linux platform
-assert sys.platform.startswith("linux"), "TileLang only supports Linux platform (including WSL)."

def _is_linux_like():
    return (sys.platform == "darwin" or sys.platform.startswith("linux") or
            sys.platform.startswith("freebsd"))

def get_path(*filepath) -> str:
    return os.path.join(ROOT_DIR, *filepath)
......@@ -144,7 +178,9 @@ def get_rocm_version():
    return Version("5.0.0")

-def get_tilelang_version(with_cuda=True, with_system_info=True, with_commit_id=False) -> str:
+def get_tilelang_version(with_cuda=USE_CUDA,
+                         with_system_info=not MAYBE_METAL,
+                         with_commit_id=False) -> str:
    version = find_version(get_path(".", "VERSION"))
    local_version_parts = []
    if with_system_info:
......@@ -194,9 +230,6 @@ def get_cplus_compiler():
        The path to the default C/C++ compiler, or None if none was found.
    """
-    if not _is_linux_like():
-        return None

    env_cxx = os.environ.get("CXX") or os.environ.get("CC")
    if env_cxx:
        return env_cxx
......@@ -371,6 +404,8 @@ def patch_libs(libpath):
    and have a hard-coded rpath.
    Set rpath to the directory of libs so auditwheel works well.
    """
+    if not IS_LINUX:
+        return
    # check if patchelf is installed
    # find patchelf in the system
    patchelf_path = shutil.which("patchelf")
......@@ -432,13 +467,6 @@ class TileLangBuilPydCommand(build_py):
                os.makedirs(target_dir)
            shutil.copy2(source_dir, target_dir)

-        TVM_PREBUILD_ITEMS = [
-            "libtvm_runtime.so",
-            "libtvm.so",
-            "libtilelang.so",
-            "libtilelang_module.so",
-        ]
        potential_dirs = [
            ext_output_dir,
            self.build_lib,
else:
logger.info(f"WARNING: {item} not found in any expected directories!")
for item in TVM_PREBUILD_ITEMS_TO_DELETE:
source_lib_file = None
for dir in potential_dirs:
candidate = os.path.join(dir, item)
if os.path.exists(candidate):
shutil.rmtree(candidate)
break
TVM_CONFIG_ITEMS = [
f"{build_temp_dir}/config.cmake",
]
......@@ -587,10 +623,10 @@ class CMakeExtension(Extension):
        :param sourcedir: Directory containing the top-level CMakeLists.txt.
    """

-    def __init__(self, name, sourcedir=""):
+    def __init__(self, name, sourcedir="", **kwargs):
        # We pass an empty 'sources' list because
        # the actual build is handled by CMake, not setuptools.
-        super().__init__(name=name, sources=[])
+        super().__init__(name=name, sources=[], **kwargs)

        # Convert the source directory to an absolute path
        # so that CMake can correctly locate the CMakeLists.txt.
......@@ -642,7 +678,7 @@ class TilelangExtensionBuild(build_ext):
        # To make it work with editable installs,
        # we need to copy the lib*.so files to the tilelang/lib directory
        import glob
-        files = glob.glob("*.so")
+        files = glob.glob("*.so" if IS_LINUX else "*.dylib")
        if os.path.exists(PACKAGE_NAME):
            target_lib_dir = os.path.join(PACKAGE_NAME, "lib")
            for file in files:
......@@ -724,7 +760,10 @@ class TilelangExtensionBuild(build_ext):
        os.system(f"{cython} {cython_wrapper_path} --cplus -o {source_path}")
        python_include_path = sysconfig.get_path("include")
        cc = get_cplus_compiler()
+        if MAYBE_METAL:
+            # on macOS, let undefined Python symbols resolve at load time
+            cc += ' -Wl,-undefined,dynamic_lookup'
        command = f"{cc} -shared -pthread -fPIC -fwrapv -O2 -Wall -fno-strict-aliasing -I{python_include_path} {source_path} -o {temp_path}"
+        logger.info(command)
        os.system(command)
        # rename the temp file to the library file
......@@ -783,7 +822,7 @@ class TilelangExtensionBuild(build_ext):
            "-G",
            "Ninja",
        ]
-        if not USE_ROCM:
+        if USE_CUDA and not USE_ROCM:
            cmake_args.append(f"-DCMAKE_CUDA_COMPILER={os.path.join(CUDA_HOME, 'bin', 'nvcc')}")

        # Create the temporary build directory (if it doesn't exist).
......@@ -804,12 +843,17 @@ class TilelangExtensionBuild(build_ext):
            content_lines.append(f"set(USE_LLVM {llvm_config_path})")

        # Append GPU backend configuration based on environment
-        if USE_ROCM:
+        if USE_METAL:
+            content_lines += [
+                "set(USE_METAL ON)",
+                "set(USE_ROCM OFF)",
+            ]
+        elif USE_ROCM:
            content_lines += [
                f"set(USE_ROCM {ROCM_HOME})",
                "set(USE_CUDA OFF)",
            ]
-        else:
+        elif CUDA_HOME:
            content_lines += [
                f"set(USE_CUDA {CUDA_HOME})",
                "set(USE_ROCM OFF)",
            cwd=build_temp)

ext_modules = [
    CMakeExtension("TileLangCXX", sourcedir="."),
]
+if not MAYBE_METAL:
+    ext_modules.append(CythonExtension("TileLangCython", sourcedir="."))

setup(
    name=PACKAGE_NAME,
    version=(get_tilelang_version(with_cuda=False, with_system_info=False, with_commit_id=False)
......
// Currently the Metal backend uses the codegen from tvm without modification,
// but we're likely to add functions on top of it in the future.
// This file is an empty placeholder for now.
import tilelang
from tilelang import tvm as tvm
import tilelang.testing
import tilelang.language as T
import torch


@tilelang.jit(execution_backend='torch')
def matmul(M, N, K, block_M, block_N, block_K, dtype="float32", accum_dtype="float"):

    @T.prim_func
    def gemm(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype, scope='shared')
            B_shared = T.alloc_shared((block_K, block_N), dtype, scope='shared')
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0):
                T.copy(A[by * block_M, ko * block_K], A_shared, coalesced_width=2)
                T.copy(B[ko * block_K, bx * block_N], B_shared, coalesced_width=2)
                for i, j, k in T.Parallel(block_M, block_N, block_K):
                    C_local[i, j] += A_shared[i, k] * B_shared[k, j]
            T.copy(C_local, C[by * block_M, bx * block_N], coalesced_width=2)

    return gemm


def assert_gemm(
        M,
        N,
        K,
        block_M,
        block_N,
        block_K,
        dtype="float32",
        accum_dtype="float",
        atol=1e-8,
):
    jit_kernel = matmul(M, N, K, block_M, block_N, block_K, dtype=dtype, accum_dtype=accum_dtype)
    torch_dtype = getattr(torch, dtype)

    if 'int' in dtype:
        a = torch.randint(100, (M, K), dtype=torch_dtype, device='mps')
        b = torch.randint(100, (K, N), dtype=torch_dtype, device='mps')
    else:
        a = torch.randn(M, K, dtype=torch_dtype, device='mps')
        b = torch.randn(K, N, dtype=torch_dtype, device='mps')
    c = torch.zeros(M, N, dtype=torch_dtype, device='mps')
    jit_kernel(a, b, c)

    assert torch.allclose(a @ b, c, atol=atol)
    assert jit_kernel.kernel_source is not None


@tilelang.testing.requires_metal
def test_gemm_float32():
    assert_gemm(1024, 1024, 1024, 16, 16, 16)


@tilelang.testing.requires_metal
def test_gemm_float16():
    assert_gemm(1024, 1024, 1024, 16, 16, 16, dtype='float16', atol=1)


@tilelang.testing.requires_metal
def test_gemm_int32():
    assert_gemm(1024, 1024, 1024, 16, 16, 16, dtype='int32', atol=1)


if __name__ == "__main__":
    if torch.mps.is_available():
        tilelang.testing.main()
......@@ -465,13 +465,24 @@ class AutoTuner:
        futures = []
        future_to_index = {}

-        def device_wrapper(func, device, **config_arg):
-            torch.cuda.set_device(device)
-            return func(**config_arg)
+        def cuda_device_wrapper(func, device):
+            # bind the CUDA device in a closure so each worker thread
+            # selects it before compiling
+            def inner(**config_arg):
+                torch.cuda.set_device(device)
+                return func(**config_arg)
+
+            return inner

        for i, config_arg in enumerate(config_args):
+            compile_func = self.jit_compile
+            if torch.cuda.is_available():
+                device = torch.cuda.current_device()
+                compile_func = cuda_device_wrapper(self.jit_compile, device)
            future = pool.submit(
-                functools.partial(device_wrapper, self.jit_compile, torch.cuda.current_device()),
+                compile_func,
                **config_arg,
            )
            futures.append(future)
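The closure above exists because `functools.partial` evaluated `torch.cuda.current_device()` in the submitting thread even when CUDA was absent; the rewrite touches CUDA only when it is available and binds the device per submission. A minimal sketch of the pattern, with a toy stand-in for `self.jit_compile` (names here are illustrative, not from the commit):

from concurrent.futures import ThreadPoolExecutor

def make_wrapper(func, device_id):
    # the closure captures device_id once, at submission time
    def inner(**kwargs):
        # a real worker would call torch.cuda.set_device(device_id) here
        return func(device=device_id, **kwargs)
    return inner

def fake_compile(device, tile=16):  # hypothetical stand-in for jit_compile
    return f"compiled tile={tile} on device {device}"

with ThreadPoolExecutor(max_workers=2) as pool:
    futures = [pool.submit(make_wrapper(fake_compile, 0), tile=t) for t in (16, 32)]
    print([f.result() for f in futures])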
......@@ -543,7 +554,7 @@ class AutoTuner:
                func=best_kernel.prim_func,
                kernel=best_kernel)

-            if self.compile_args.execution_backend == "dlpack":
+            if self.compile_args.execution_backend in ("dlpack", "torch"):
                logger.warning("The DLPack and torch backends do not support cache saving to disk.")
            else:
                with self._lock:
......
......@@ -191,8 +191,8 @@ class KernelCache:
            pass_configs=pass_configs,
            compile_flags=compile_flags,
        )
-        if execution_backend == "dlpack":
-            self.logger.warning("DLPack backend does not support cache saving to disk.")
+        if execution_backend in ("dlpack", "torch"):
+            self.logger.warning("The DLPack and torch backends do not support cache saving to disk.")
        else:
            with self._lock:
                if env.is_cache_enabled():
......
from .arch_base import TileDevice
-from .cuda import CUDA
-from .cpu import CPU
-from .cdna import CDNA
+from .cuda import *
+from .cpu import *
+from .cdna import *
+from .metal import *
from typing import Union
from tvm.target import Target
import torch

......@@ -17,6 +18,8 @@ def get_arch(target: Union[str, Target] = "cuda") -> TileDevice:
        return CPU(target)
    elif target.kind.name == "hip":
        return CDNA(target)
+    elif target.kind.name == "metal":
+        return METAL(target)
    else:
        raise ValueError(f"Unsupported target: {target.kind.name}")
......@@ -28,18 +31,25 @@ def auto_infer_current_arch() -> TileDevice:
        return get_arch("hip")
    if torch.cuda.is_available():
        return get_arch("cuda")
+    elif torch.mps.is_available():
+        return get_arch("metal")
    else:
        return get_arch("llvm")

-from .cpu import is_cpu_arch  # noqa: F401
-from .cuda import (
-    is_cuda_arch,  # noqa: F401
-    is_volta_arch,  # noqa: F401
-    is_ampere_arch,  # noqa: F401
-    is_ada_arch,  # noqa: F401
-    is_hopper_arch,  # noqa: F401
-    is_tensorcore_supported_precision,  # noqa: F401
-    has_mma_support,  # noqa: F401
-)
-from .cdna import is_cdna_arch  # noqa: F401
+__all__ = [
+    'is_cpu_arch',
+    'is_cuda_arch',
+    'is_volta_arch',
+    'is_ampere_arch',
+    'is_ada_arch',
+    'is_hopper_arch',
+    'is_tensorcore_supported_precision',
+    'has_mma_support',
+    'is_cdna_arch',
+    'is_metal_arch',
+    'CUDA',
+    'CDNA',
+    'METAL',
+    'CPU',
+]
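A short usage sketch of the dispatch above; the import path is illustrative and assumes this arch package is importable:

from tvm.target import Target
# illustrative import path; adjust to where this package actually lives
from tilelang.carver.arch import get_arch, is_metal_arch, METAL

arch = get_arch("metal")  # dispatches on target.kind.name
assert is_metal_arch(arch) and isinstance(arch, METAL)
assert isinstance(arch.target, Target)
# auto_infer_current_arch() takes the same branch when torch.mps.is_available()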
......@@ -30,3 +30,9 @@ class CDNA(TileDevice):
        self.transaction_size: List[int] = [32, 128]  # in bytes
        self.bandwidth: List[int] = [1300, 14000]

+__all__ = [
+    'is_cdna_arch',
+    'CDNA',
+]
......@@ -18,3 +18,9 @@ class CPU(TileDevice):
            raise RuntimeError("Cannot find cpu device 0.")
        self.device: tvm.runtime.Device = device
        self.platform: str = "CPU"

+__all__ = [
+    'is_cpu_arch',
+    'CPU',
+]
......@@ -145,3 +145,15 @@ class CUDA(TileDevice):
    def __repr__(self):
        return f"CUDA({self.target})"

+__all__ = [
+    'is_cuda_arch',
+    'is_volta_arch',
+    'is_ampere_arch',
+    'is_ada_arch',
+    'is_hopper_arch',
+    'is_tensorcore_supported_precision',
+    'has_mma_support',
+    'CUDA',
+]
from tvm.target import Target
from .arch_base import TileDevice


def is_metal_arch(arch: TileDevice) -> bool:
    return isinstance(arch, METAL)


class METAL(TileDevice):

    def __init__(self, target: Target | str):
        if isinstance(target, str):
            target = Target(target)
        self.target = target


__all__ = [
    'is_metal_arch',
    'METAL',
]
......@@ -19,6 +19,7 @@ import functools
import os
import shutil
import subprocess
+import platform

# pylint: disable=invalid-name
import sys

......@@ -89,6 +90,10 @@ def get_cplus_compiler():
    return None

+def is_darwin():
+    return platform.system() == 'Darwin'

def create_shared(output, objects, options=None, cc=None, cwd=None, ccache_env=None):
    """Create shared library.
......@@ -181,6 +181,8 @@ def device_codegen_without_compile(device_mod: tvm.IRModule, target: Target) ->
        device_mod = tvm.ffi.get_global_func("target.build.llvm")(device_mod, target)
    elif target.kind.name == "webgpu":
        device_mod = tvm.ffi.get_global_func("target.build.tilelang_webgpu")(device_mod, target)
+    elif target.kind.name == "metal":
+        device_mod = tvm.ffi.get_global_func("target.build.metal")(device_mod, target)
    else:
        raise ValueError(f"Target {target.kind.name} is not supported")
......
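The dispatch above resolves a codegen builder registered in TVM's global function table; a hedged sketch of probing for it, assuming mainline TVM's get_global_func signature with allow_missing:

import tvm

# "target.build.metal" is registered by TVM's C++ side when built with USE_METAL=ON
build_metal = tvm.ffi.get_global_func("target.build.metal", allow_missing=True)
if build_metal is None:
    print("this TVM build has no Metal codegen (rebuild with USE_METAL=ON)")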
......@@ -16,6 +16,7 @@ from typing import (
    Optional,
)
from tilelang import tvm as tvm
+from tilelang.jit.adapter.utils import is_metal_target
from tvm.tir import PrimFunc
from tvm.target import Target

......@@ -74,6 +75,9 @@ def compile(
    # This path is not a performance-critical path, so we can afford to convert the target.
    target = Target(determine_target(target))

+    if is_metal_target(target):
+        assert execution_backend == 'torch', \
+            'Currently the metal target only supports `tl.jit(execution_backend="torch")`'
    return cached(
        func=func,
        out_idx=out_idx,
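Given the guard above, compiling for Metal goes through the torch execution backend; a minimal usage sketch (the kernel name is illustrative, e.g. the gemm prim_func from the test file earlier):

import tilelang

# hypothetical: gemm_func is a T.prim_func such as the gemm defined above
kernel = tilelang.compile(gemm_func, target="metal", execution_backend="torch")
# equivalently, decorate the kernel factory with @tilelang.jit(execution_backend='torch')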
......@@ -264,7 +268,7 @@ def jit(  # This is the new public interface
        Compilation target for TVM (e.g., "cuda", "llvm"). Defaults to "auto".
    target_host : Union[str, Target], optional
        Target host for cross-compilation. Defaults to None.
-    execution_backend : Literal["dlpack", "ctypes", "cython"], optional
+    execution_backend : Literal["dlpack", "ctypes", "cython", "nvrtc", "torch"], optional
        Backend for kernel execution and argument passing. Defaults to "cython".
    verbose : bool, optional
        Enables verbose logging during compilation. Defaults to False.
......
......@@ -3,3 +3,4 @@ from .dlpack import TorchDLPackKernelAdapter # noqa: F401
from .ctypes import CtypesKernelAdapter # noqa: F401
from .cython import CythonKernelAdapter # noqa: F401
from .nvrtc import NVRTCKernelAdapter # noqa: F401
+from .torch import MetalKernelAdapter  # noqa: F401
......@@ -21,11 +21,11 @@ from tvm.relax import TensorType
from tilelang.jit.adapter.base import BaseKernelAdapter
from tilelang.jit.adapter.wrapper import TLWrapper
from tilelang.jit.adapter.libgen import LibraryGenerator
-from tilelang.jit.adapter.utils import is_cuda_target, is_hip_target, is_cpu_target
+from tilelang.jit.adapter.utils import is_cuda_target, is_hip_target, is_cpu_target, is_metal_target
from tilelang.utils.target import determine_target
from tilelang.utils.language import retrieve_func_from_module
from tilelang.utils.tensor import map_torch_type
-from tilelang.contrib.cc import get_cplus_compiler
+from tilelang.contrib.cc import get_cplus_compiler, is_darwin

logger = logging.getLogger(__name__)

......@@ -153,7 +153,9 @@ with open(cython_wrapper_path, "r") as f:
    os.system(f"{cython} {cython_wrapper_path} --cplus -o {source_path}")
    python_include_path = sysconfig.get_path("include")
    cc = get_cplus_compiler()
-    command = f"{cc} -shared -pthread -fPIC -fwrapv -O2 -Wall -fno-strict-aliasing -I{python_include_path} {source_path} -o {temp_path}"
+    dynamic_flag = ('-Wl,-undefined,dynamic_lookup'
+                    if is_darwin() else '-Wl,--unresolved-symbols=ignore-all')
+    command = f"{cc} -shared -pthread -fPIC -fwrapv -O2 -Wall -fno-strict-aliasing {dynamic_flag} -I{python_include_path} {source_path} -o {temp_path}"
    os.system(command)
    # rename the temp file to the library file

......@@ -450,6 +452,8 @@ class CythonKernelAdapter(BaseKernelAdapter):
            device = torch.device("cuda")
        elif is_cpu_target(self.target):
            device = torch.device("cpu")
+        elif is_metal_target(self.target):
+            device = torch.device("mps")
        else:
            raise ValueError(f"Unsupported target: {self.target}")
......
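A note on the linker flags chosen above: on macOS, `-Wl,-undefined,dynamic_lookup` defers resolution of Python C-API symbols to load time (the running interpreter provides them), mirroring GNU ld's `--unresolved-symbols=ignore-all` on Linux. A tiny sketch of the same selection logic in isolation:

import platform

dynamic_flag = ('-Wl,-undefined,dynamic_lookup'
                if platform.system() == 'Darwin'
                else '-Wl,--unresolved-symbols=ignore-all')
print(dynamic_flag)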
from .metal import MetalKernelAdapter
__all__ = ['MetalKernelAdapter']