Unverified Commit 7fb06776 authored by Yichen Yan, committed by GitHub

[Backend] Add metal backend (#799)



* Reset

* Fix other CUDA issue

* fmt

* fmt

* fix cuda error

* fix

* fix

* fmt

* cleanup

* fix

* remove copyright

* trivial update

* readme update

* lint fix

---------
Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
parent 394e17d0
from functools import wraps
from typing import Callable, Optional, Union

import torch
from tvm import tir

from tilelang import tvm as tvm
from ..base import BaseKernelAdapter
from tilelang.engine.param import KernelParam


class MetalKernelAdapter(BaseKernelAdapter):

    def __init__(
        self,
        params: list[KernelParam],
        result_idx: list[int],
        # target: Union[str, Target],
        func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
        # host_mod: Optional[tvm.IRModule] = None,
        device_mod: Optional[tvm.IRModule] = None,
        kernel_global_source: Optional[str] = None,
        verbose: bool = False,
        # pass_configs: Optional[Dict[str, Any]] = None,
        # compile_flags: Optional[List[str]] = None
    ):
        self.kernel_global_source = kernel_global_source
        self.kernel_name = func_or_mod.__name__ + '_kernel'
        self.verbose = verbose

        # Recover the launch configuration (threadgroup and grid sizes)
        # from the thread_extent attributes of the device-side function.
        self.block_info = [1, 1, 1]
        self.grid_info = [1, 1, 1]
        for var, func in device_mod.functions.items():
            assert var.name_hint == self.kernel_name
            thread_extent = func.attrs['thread_extent']
            for tag, extent in thread_extent.items():
                # tag is e.g. "threadIdx.x" / "blockIdx.y"; its last
                # character selects the launch dimension.
                if "threadIdx" in tag:
                    self.block_info["xyz".index(tag[-1])] = extent
                elif "blockIdx" in tag:
                    self.grid_info["xyz".index(tag[-1])] = extent
            break
        else:
            raise AssertionError(f'no kernel with name {func_or_mod.__name__}')
        # print(self.block_info, self.grid_info)
        super().__init__(func_or_mod, result_idx=result_idx, params=params)

    _kernel = None

    def _convert_torch_func(self) -> Callable:
        # Compile the Metal source lazily and cache the resulting launcher.
        if self._kernel is None:
            _kernel = getattr(
                torch.mps.compile_shader(self.kernel_global_source), self.kernel_name)
            # torch.mps expects the total thread count per axis, i.e.
            # threadgroup (block) size times grid size.
            _threads = [x * y for (x, y) in zip(self.block_info, self.grid_info)]

            @wraps(_kernel)
            def launcher(*args: torch.Tensor):
                return _kernel(
                    *args,
                    threads=_threads,
                    group_size=self.block_info,
                )

            self._kernel = launcher
        return self._kernel
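
For context, a minimal standalone sketch of the `torch.mps.compile_shader` contract the launcher above relies on. The `add_one` kernel name and its source are illustrative, not part of this change; `threads`/`group_size` follow the same convention the adapter uses.

```python
import torch

# Build a tiny Metal library from source (requires a PyTorch build with
# MPS support on Apple silicon). Kernel name and body are illustrative.
lib = torch.mps.compile_shader("""
kernel void add_one(device float *buf [[buffer(0)]],
                    uint idx [[thread_position_in_grid]]) {
    buf[idx] += 1.0;
}
""")

x = torch.zeros(64, device="mps")
# Mirrors the launcher: `threads` is the total thread count per axis
# (block * grid), `group_size` is the threadgroup (block) shape.
lib.add_one(x, threads=[64, 1, 1], group_size=[32, 1, 1])
assert torch.all(x == 1)
```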
......@@ -60,6 +60,10 @@ def is_cpu_target(target: Target) -> bool:
    return target.kind.name in ["c"]


def is_metal_target(target: Target) -> bool:
    return target.kind.name == "metal"


def get_annotated_mod(
        func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
        target: Union[str, Target] = "auto",
......
......@@ -3,8 +3,8 @@ from tilelang import tvm as tvm
from typing import Optional, List, Dict, Union, Any
from tvm import IRModule
from tvm.target import Target
-from .utils import (match_declare_kernel, match_declare_kernel_cpu, is_cuda_target, is_hip_target,
-                    is_cpu_target, get_annotated_mod, pythonic_expr)
+from .utils import (is_metal_target, match_declare_kernel, match_declare_kernel_cpu, is_cuda_target,
+                    is_hip_target, is_cpu_target, get_annotated_mod, pythonic_expr)
import re
import logging
import textwrap
......@@ -1066,6 +1066,28 @@ class TLCPUSourceWrapper(object):
        raise ValueError("Cannot find primary function in the module.")


class TLMetalSourceWrapper(object):

    def __init__(self,
                 scheduled_ir_module: IRModule,
                 source: str,
                 target: Target,
                 device_mod: Optional[IRModule] = None,
                 host_mod: Optional[IRModule] = None,
                 pass_configs: Optional[Dict[str, Any]] = None):
        self.mod = scheduled_ir_module
        self.target = target
        self.source = source
        self.pass_configs = pass_configs
        self.device_mod = device_mod
        self.host_mod = host_mod
        self.lib_code = self.update_lib_code(source)

    def update_lib_code(self, code: str):
        # Metal needs no host-side stub: torch.mps compiles the generated
        # source directly, so the library code is the source itself.
        self.lib_code = code
        return self.lib_code
class TLWrapper(BaseWrapper):
    """
    A wrapper class for the TileLang backend.
......@@ -1104,6 +1126,8 @@ class TLWrapper(BaseWrapper):
            wrapper_class = TLHIPSourceWrapper
        elif is_cpu_target(self.target):
            wrapper_class = TLCPUSourceWrapper
        elif is_metal_target(self.target):
            wrapper_class = TLMetalSourceWrapper
        else:
            raise ValueError(f"Unsupported platform: {self.arch.platform}")
        wrapper = wrapper_class(
......
from typing import Any, Callable, Dict, List, Literal, Optional, Union
from tilelang.jit.adapter.utils import is_metal_target
from tvm.target import Target
from tvm.tir import PrimFunc
......@@ -8,7 +9,7 @@ from tilelang import tvm
from tilelang import env
from tilelang.engine.param import CompiledArtifact, KernelParam
from tilelang.jit.adapter import (BaseKernelAdapter, CtypesKernelAdapter, CythonKernelAdapter,
-                                  NVRTCKernelAdapter, TorchDLPackKernelAdapter)
+                                  NVRTCKernelAdapter, TorchDLPackKernelAdapter, MetalKernelAdapter)
from tilelang.profiler import Profiler, TensorSupplyType
from tilelang.utils.target import AVALIABLE_TARGETS, determine_target
import logging
......@@ -103,6 +104,7 @@ class JITKernel(object):
            "ctypes",
            "cython",
            "nvrtc",
            "torch",
        ], f"Invalid execution backend. {execution_backend}"
        if execution_backend == "cython":
            from tilelang.contrib.cc import get_cplus_compiler
......@@ -278,6 +280,20 @@ class JITKernel(object):
                pass_configs=pass_configs,
                compile_flags=compile_flags,
            )
        elif execution_backend == "torch":
            assert is_metal_target(target)
            adapter = MetalKernelAdapter(
                params=artifact.params,
                result_idx=out_idx,
                # target=target,
                func_or_mod=tilelang_func,
                # host_mod=artifact.host_mod,
                device_mod=artifact.device_mod,
                kernel_global_source=artifact.kernel_source,
                verbose=verbose,
                # pass_configs=pass_configs,
                # compile_flags=compile_flags,
            )
        else:
            # Handle invalid backend.
            raise ValueError(f"Invalid execution backend: {execution_backend}")
......
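
A hedged end-to-end usage sketch of the new backend. `my_prim_func` is a placeholder for any TileLang `PrimFunc`, and `tilelang.compile` is assumed to forward these arguments to `JITKernel`; nothing below is prescribed by this diff beyond the `target`/`execution_backend` pairing.

```python
import tilelang
import torch

# my_prim_func stands in for a real TileLang PrimFunc (hypothetical here).
kernel = tilelang.compile(
    my_prim_func,
    out_idx=[-1],               # index of the output tensor
    target="metal",             # the torch backend asserts a Metal target
    execution_backend="torch",  # routes through MetalKernelAdapter
)

a = torch.randn(1024, 1024, device="mps", dtype=torch.float16)
b = torch.randn(1024, 1024, device="mps", dtype=torch.float16)
c = kernel(a, b)                # dispatched via torch.mps.compile_shader
```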
......@@ -54,6 +54,11 @@ class suppress_stdout_stderr:
        self.errnull_file.close()


IS_CUDA = torch.cuda.is_available()
device = 'cuda:0' if IS_CUDA else 'mps:0'
Event = torch.cuda.Event if IS_CUDA else torch.mps.Event


def do_bench(
    fn: Callable,
    warmup: float = 25,
......@@ -92,7 +97,7 @@ def do_bench(
    # Initial function call and synchronization
    fn()
-    torch.cuda.synchronize()
+    torch.accelerator.synchronize()

    # Create L2 cache flush buffer (256 MB)
    # Fast flush uses int32 (4 bytes), regular uses int8 (1 byte)
......@@ -108,7 +113,8 @@ def do_bench(
        cache.zero_()
        fn()
    end_event.record()
-    torch.cuda.synchronize()
+    start_event.synchronize()
+    end_event.synchronize()

    estimate_ms = start_event.elapsed_time(end_event) / 5

    # Calculate warmup and repeat counts (minimum 1 iteration each)
......
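
The change above works because CUDA and MPS events share the record/synchronize/elapsed_time API, so synchronizing on the events keeps one code path for both devices. A minimal timing sketch under that assumption (requires a PyTorch build with either backend):

```python
import torch

IS_CUDA = torch.cuda.is_available()
device = 'cuda:0' if IS_CUDA else 'mps:0'
Event = torch.cuda.Event if IS_CUDA else torch.mps.Event

x = torch.randn(1024, 1024, device=device)
start, end = Event(enable_timing=True), Event(enable_timing=True)

start.record()
y = x @ x
end.record()

# Synchronize on the events themselves rather than torch.cuda.synchronize(),
# mirroring the device-neutral pattern in do_bench above.
start.synchronize()
end.synchronize()
print(f"matmul: {start.elapsed_time(end):.3f} ms")
```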
......@@ -5,11 +5,21 @@ import random
import torch
import numpy as np
from tilelang.contrib import nvcc
-from tvm.testing.utils import *
-from tvm.testing.utils import _compose
+from tvm.testing.utils import (requires_cuda, requires_package, requires_llvm, requires_metal,
+                               requires_rocm, _compose)
from tilelang.utils.tensor import torch_assert_close as torch_assert_close

__all__ = [
    'requires_package',
    'requires_cuda',
    'requires_metal',
    'requires_rocm',
    'requires_llvm',
    'main',
    'requires_cuda_compute_version',
] + [f'requires_cuda_compute_version_{op}' for op in ('ge', 'gt', 'le', 'lt', 'eq')]
# pytest.main() wrapper to allow running single test file
def main():
......
import torch

IS_CUDA = torch.cuda.is_available()
IS_MPS = torch.mps.is_available()


def get_current_device():
    device = None
    if IS_CUDA:
        device = torch.cuda.current_device()
    elif IS_MPS:
        device = "mps:0"
    return device
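
Hypothetical usage; the import path is assumed from the diff's relative `from .device import get_current_device` later in this commit:

```python
import torch
from tilelang.utils.device import get_current_device  # assumed module path

dev = get_current_device()  # CUDA ordinal (int), "mps:0", or None
if dev is not None:
    # torch accepts both an int CUDA ordinal and the "mps:0" string here.
    t = torch.empty(8, device=dev)
```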
from platform import mac_ver
from typing import Literal, Union

from tilelang import tvm as tvm
from tilelang import _ffi_api
......@@ -12,6 +13,7 @@ AVALIABLE_TARGETS = {
    "webgpu",
    "c",  # represents the C source backend
    "llvm",
    "metal",
}
......@@ -41,6 +43,14 @@ def check_hip_availability() -> bool:
    return False


def check_metal_availability() -> bool:
    # mac_ver() yields an empty release string on non-macOS platforms.
    mac_release, _, arch = mac_ver()
    if not mac_release:
        return False
    # todo: check torch version?
    return arch == 'arm64'
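
`platform.mac_ver()` is a cheap probe that needs no torch import, which is why the check can run at target-resolution time. Illustrative outputs (version values assumed, not from this diff):

```python
from platform import mac_ver

# Apple-silicon macOS: ('15.1', ('', '', ''), 'arm64')  -> Metal available
# Intel macOS:         ('13.6', ('', '', ''), 'x86_64') -> not available
# Linux / Windows:     ('',     ('', '', ''), '')       -> not available
release, _, arch = mac_ver()
print(bool(release) and arch == 'arm64')
```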
def determine_target(target: Union[str, Target, Literal["auto"]] = "auto",
                     return_object: bool = False) -> Union[str, Target]:
    """
......@@ -74,8 +84,10 @@ def determine_target(target: Union[str, Target, Literal["auto"]] = "auto",
            return_var = "cuda"
        elif is_hip_available:
            return_var = "hip"
        elif check_metal_availability():
            return_var = "metal"
        else:
-            raise ValueError("No CUDA or HIP available on this system.")
+            raise ValueError("No CUDA, HIP, or MPS available on this system.")
    else:
        # Validate the target if it's not "auto"
        assert isinstance(
......
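
A sketch of the expected resolution on an Apple-silicon Mac without CUDA or ROCm (outputs assumed, not from this diff):

```python
from tilelang.utils.target import determine_target

print(determine_target("auto"))   # -> "metal" when only MPS is available
print(determine_target("metal"))  # explicit targets are validated, then returned
```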
......@@ -58,10 +58,11 @@ def adapt_torch2tvm(arg):
def get_tensor_supply(supply_type: TensorSupplyType = TensorSupplyType.Integer):
    from tilelang.engine.param import KernelParam
    from .device import get_current_device

    def get_tensor(param: KernelParam) -> torch.Tensor:
        dtype: torch.dtype = param.dtype
-        device: torch.device = torch.cuda.current_device()
+        device = get_current_device()

        if hasattr(param, "shape") and not param.shape:
            raise ValueError(