Unverified Commit 7fb06776 authored by Yichen Yan, committed by GitHub

[Backend] Add metal backend (#799)



* Reset

* Fix other CUDA issue

* fmt

* fmt

* fix cuda error

* fix

* fix

* fmt

* cleanup

* fix

* remove copyright

* trivial update

* readme update

* lint fix

---------
Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
parent 394e17d0
from functools import wraps
from typing import Callable, Optional, Union

import torch
from tvm import tir

from tilelang import tvm as tvm

from ..base import BaseKernelAdapter
from tilelang.engine.param import KernelParam


class MetalKernelAdapter(BaseKernelAdapter):

    def __init__(
            self,
            params: list[KernelParam],
            result_idx: list[int],
            # target: Union[str, Target],
            func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
            # host_mod: Optional[tvm.IRModule] = None,
            device_mod: Optional[tvm.IRModule] = None,
            kernel_global_source: Optional[str] = None,
            verbose: bool = False,
            # pass_configs: Optional[Dict[str, Any]] = None,
            # compile_flags: Optional[List[str]] = None
    ):
        self.kernel_global_source = kernel_global_source
        self.kernel_name = func_or_mod.__name__ + '_kernel'
        self.verbose = verbose

        # Recover the launch geometry from the device module: the kernel's
        # thread_extent attribute carries the threadIdx.*/blockIdx.* extents.
        self.block_info = [1, 1, 1]
        self.grid_info = [1, 1, 1]
        for var, func in device_mod.functions.items():
            assert var.name_hint == self.kernel_name
            thread_extent = func.attrs['thread_extent']
            for tag, extent in thread_extent.items():
                if "threadIdx" in tag:
                    self.block_info["xyz".index(tag[-1])] = extent
                elif "blockIdx" in tag:
                    self.grid_info["xyz".index(tag[-1])] = extent
            break
        else:
            raise AssertionError(f'no kernel with name {func_or_mod.__name__}')
        # print(self.block_info, self.grid_info)

        super().__init__(func_or_mod, result_idx=result_idx, params=params)

    _kernel = None

    def _convert_torch_func(self) -> Callable:
        # Compile the Metal source on first use and cache the bound launcher.
        if self._kernel is None:
            _kernel = getattr(
                torch.mps.compile_shader(self.kernel_global_source), self.kernel_name)
            # torch.mps expects the total grid size, not the number of groups.
            _threads = [x * y for (x, y) in zip(self.block_info, self.grid_info)]

            @wraps(_kernel)
            def launcher(*args: torch.Tensor):
                return _kernel(
                    *args,
                    threads=_threads,
                    group_size=self.block_info,
                )

            self._kernel = launcher
        return self._kernel
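For orientation, a standalone sketch of the torch.mps.compile_shader API the adapter builds on. The threads/group_size keywords mirror the launcher above; the Metal kernel source here is illustrative, and the sketch assumes a PyTorch build with MPS support.

import torch

_SOURCE = """
#include <metal_stdlib>
using namespace metal;
kernel void add_one(device float *buf [[buffer(0)]],
                    uint gid [[thread_position_in_grid]]) {
    buf[gid] += 1.0f;
}
"""

lib = torch.mps.compile_shader(_SOURCE)  # compiles the Metal source at runtime
x = torch.zeros(64, device="mps")
# threads = total grid size, group_size = threadgroup shape,
# matching the keywords passed by launcher() above.
lib.add_one(x, threads=[64, 1, 1], group_size=[64, 1, 1])
assert (x == 1).all()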
@@ -60,6 +60,10 @@ def is_cpu_target(target: Target) -> bool:
     return target.kind.name in ["c"]


+def is_metal_target(target: Target) -> bool:
+    return target.kind.name == "metal"
+
+
 def get_annotated_mod(
     func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
     target: Union[str, Target] = "auto",
...
@@ -3,8 +3,8 @@ from tilelang import tvm as tvm
 from typing import Optional, List, Dict, Union, Any
 from tvm import IRModule
 from tvm.target import Target
-from .utils import (match_declare_kernel, match_declare_kernel_cpu, is_cuda_target, is_hip_target,
-                    is_cpu_target, get_annotated_mod, pythonic_expr)
+from .utils import (is_metal_target, match_declare_kernel, match_declare_kernel_cpu, is_cuda_target,
+                    is_hip_target, is_cpu_target, get_annotated_mod, pythonic_expr)
 import re
 import logging
 import textwrap
@@ -1066,6 +1066,28 @@ class TLCPUSourceWrapper(object):
         raise ValueError("Cannot find primary function in the module.")


+class TLMetalSourceWrapper(object):
+
+    def __init__(self,
+                 scheduled_ir_module: IRModule,
+                 source: str,
+                 target: Target,
+                 device_mod: Optional[IRModule] = None,
+                 host_mod: Optional[IRModule] = None,
+                 pass_configs: Optional[Dict[str, Any]] = None):
+        self.mod = scheduled_ir_module
+        self.target = target
+        self.source = source
+        self.pass_configs = pass_configs
+        self.device_mod = device_mod
+        self.host_mod = host_mod
+        self.lib_code = self.update_lib_code(source)
+
+    def update_lib_code(self, code: str):
+        self.lib_code = code
+        return self.lib_code
+
+
 class TLWrapper(BaseWrapper):
     """
     A wrapper class for the TileLang backend.
@@ -1104,6 +1126,8 @@ class TLWrapper(BaseWrapper):
             wrapper_class = TLHIPSourceWrapper
         elif is_cpu_target(self.target):
             wrapper_class = TLCPUSourceWrapper
+        elif is_metal_target(self.target):
+            wrapper_class = TLMetalSourceWrapper
         else:
             raise ValueError(f"Unsupported platform: {self.arch.platform}")
         wrapper = wrapper_class(
...
 from typing import Any, Callable, Dict, List, Literal, Optional, Union
+from tilelang.jit.adapter.utils import is_metal_target
 from tvm.target import Target
 from tvm.tir import PrimFunc
@@ -8,7 +9,7 @@ from tilelang import tvm
 from tilelang import env
 from tilelang.engine.param import CompiledArtifact, KernelParam
-from tilelang.jit.adapter import (BaseKernelAdapter, CtypesKernelAdapter, CythonKernelAdapter,
-                                  NVRTCKernelAdapter, TorchDLPackKernelAdapter)
+from tilelang.jit.adapter import (BaseKernelAdapter, CtypesKernelAdapter, CythonKernelAdapter,
+                                  NVRTCKernelAdapter, TorchDLPackKernelAdapter, MetalKernelAdapter)
 from tilelang.profiler import Profiler, TensorSupplyType
 from tilelang.utils.target import AVALIABLE_TARGETS, determine_target
 import logging
@@ -103,6 +104,7 @@ class JITKernel(object):
             "ctypes",
             "cython",
             "nvrtc",
+            "torch",
         ], f"Invalid execution backend. {execution_backend}"
         if execution_backend == "cython":
             from tilelang.contrib.cc import get_cplus_compiler
@@ -278,6 +280,20 @@ class JITKernel(object):
                 pass_configs=pass_configs,
                 compile_flags=compile_flags,
             )
+        elif execution_backend == "torch":
+            assert is_metal_target(target)
+            adapter = MetalKernelAdapter(
+                params=artifact.params,
+                result_idx=out_idx,
+                # target=target,
+                func_or_mod=tilelang_func,
+                # host_mod=artifact.host_mod,
+                device_mod=artifact.device_mod,
+                kernel_global_source=artifact.kernel_source,
+                verbose=verbose,
+                # pass_configs=pass_configs,
+                # compile_flags=compile_flags,
+            )
         else:
             # Handle invalid backend.
             raise ValueError(f"Invalid execution backend: {execution_backend}")
...
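For illustration, a hedged sketch of how this branch would be reached. JITKernel and the keyword names are taken from the hunks above; the exact user-facing entry point and import path may differ.

# Hypothetical invocation; argument names follow the diff, and the public
# entry point may be a higher-level API rather than JITKernel directly.
from tilelang.jit import JITKernel  # assumed import path

kernel = JITKernel(
    my_prim_func,               # a TileLang tir.PrimFunc (placeholder name)
    out_idx=[1],                # indices of output arguments
    target="metal",             # the assert above requires a Metal target
    execution_backend="torch",  # routes through MetalKernelAdapter
)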
@@ -54,6 +54,11 @@ class suppress_stdout_stderr:
             self.errnull_file.close()


+IS_CUDA = torch.cuda.is_available()
+device = 'cuda:0' if IS_CUDA else 'mps:0'
+Event = torch.cuda.Event if IS_CUDA else torch.mps.Event
+
+
 def do_bench(
     fn: Callable,
     warmup: float = 25,
@@ -92,7 +97,7 @@ def do_bench(
     # Initial function call and synchronization
     fn()
-    torch.cuda.synchronize()
+    torch.accelerator.synchronize()

     # Create L2 cache flush buffer (256 MB)
     # Fast flush uses int32 (4 bytes), regular uses int8 (1 byte)
@@ -108,7 +113,8 @@ def do_bench(
         cache.zero_()
         fn()
     end_event.record()
-    torch.cuda.synchronize()
+    start_event.synchronize()
+    end_event.synchronize()
     estimate_ms = start_event.elapsed_time(end_event) / 5

     # Calculate warmup and repeat counts (minimum 1 iteration each)
...
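For reference, a minimal sketch of the device-agnostic timing pattern do_bench now relies on, using the module-level Event alias introduced above:

# Works on both backends: Event is torch.cuda.Event or torch.mps.Event.
start_event = Event(enable_timing=True)
end_event = Event(enable_timing=True)
start_event.record()
for _ in range(5):
    fn()                       # the benchmarked callable
end_event.record()
start_event.synchronize()      # wait on the events themselves instead of
end_event.synchronize()        # a device-wide torch.cuda.synchronize()
estimate_ms = start_event.elapsed_time(end_event) / 5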
@@ -5,11 +5,21 @@ import random
 import torch
 import numpy as np
 from tilelang.contrib import nvcc
-from tvm.testing.utils import *
-from tvm.testing.utils import _compose
+from tvm.testing.utils import (requires_cuda, requires_package, requires_llvm, requires_metal,
+                               requires_rocm, _compose)
 from tilelang.utils.tensor import torch_assert_close as torch_assert_close

+__all__ = [
+    'requires_package',
+    'requires_cuda',
+    'requires_metal',
+    'requires_rocm',
+    'requires_llvm',
+    'main',
+    'requires_cuda_compute_version',
+] + [f'requires_cuda_compute_version_{op}' for op in ('ge', 'gt', 'le', 'lt', 'eq')]
+
 # pytest.main() wrapper to allow running single test file
 def main():
...
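With requires_metal now re-exported, tests can gate on a Metal device the same way CUDA tests do. An illustrative test; the tilelang.testing import path is assumed:

import tilelang.testing  # assumed import path for the module above

@tilelang.testing.requires_metal
def test_metal_smoke():
    # Collected everywhere, but only runs when Metal is available.
    ...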
import torch

IS_CUDA = torch.cuda.is_available()
IS_MPS = torch.mps.is_available()


def get_current_device():
    # Returns a CUDA device index, the MPS device string, or None (CPU).
    device = None
    if IS_CUDA:
        device = torch.cuda.current_device()
    elif IS_MPS:
        device = "mps:0"
    return device
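A small usage example of the new helper; note that torch.cuda.current_device() returns an integer index, which torch factory functions accept as a CUDA device:

import torch
from tilelang.utils.device import get_current_device  # assumed module path

# cuda:N on CUDA machines, "mps:0" on Apple Silicon, None (CPU) otherwise.
x = torch.randn(4, 4, device=get_current_device())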
+from platform import mac_ver
 from typing import Literal, Union
 from tilelang import tvm as tvm
 from tilelang import _ffi_api
@@ -12,6 +13,7 @@ AVALIABLE_TARGETS = {
     "webgpu",
     "c",  # represent c source backend
     "llvm",
+    "metal",
 }
@@ -41,6 +43,14 @@ def check_hip_availability() -> bool:
         return False


+def check_metal_availability() -> bool:
+    mac_release, _, arch = mac_ver()
+    if not mac_release:
+        return False
+    # todo: check torch version?
+    return arch == 'arm64'
+
+
 def determine_target(target: Union[str, Target, Literal["auto"]] = "auto",
                      return_object: bool = False) -> Union[str, Target]:
     """
@@ -74,8 +84,10 @@ def determine_target(target: Union[str, Target, Literal["auto"]] = "auto",
             return_var = "cuda"
         elif is_hip_available:
             return_var = "hip"
+        elif check_metal_availability():
+            return_var = "metal"
         else:
-            raise ValueError("No CUDA or HIP available on this system.")
+            raise ValueError("No CUDA, HIP, or MPS available on this system.")
     else:
         # Validate the target if it's not "auto"
         assert isinstance(
...
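For reference, check_metal_availability() above keys off platform.mac_ver(), which returns an empty release string on anything other than macOS; illustrative outputs:

>>> from platform import mac_ver
>>> mac_ver()  # on an Apple Silicon Mac (release string varies)
('14.5', ('', '', ''), 'arm64')
>>> mac_ver()  # on Linux or Windows
('', ('', '', ''), '')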
@@ -58,10 +58,11 @@ def adapt_torch2tvm(arg):

 def get_tensor_supply(supply_type: TensorSupplyType = TensorSupplyType.Integer):
     from tilelang.engine.param import KernelParam
+    from .device import get_current_device

     def get_tensor(param: KernelParam) -> torch.Tensor:
         dtype: torch.dtype = param.dtype
-        device: torch.device = torch.cuda.current_device()
+        device = get_current_device()

         if hasattr(param, "shape") and not param.shape:
             raise ValueError(
...