Commit 39fc5a6d authored by Lei Wang, committed by GitHub

[Dev][jit] Introduce jit for kernel functions (#12)

* instruction update

* replace link with TileLang/tile-lang

* [Dev][Adapter] Implement Torch DLPack Kernel Adapter and related utilities

* lint fix

* Implement JIT Compiler Components

* Documents update

* lint fix

* update logo

* install script fix
parent 18718446
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
import os
from pathlib import Path
from typing import List, Union
import torch.utils.cpp_extension as torch_cpp_ext
from filelock import FileLock
from .env import CUTLASS_INCLUDE_DIR, TILELANG_TEMPLATE_PATH, TILELANG_JIT_DIR
from contextlib import suppress
class TileLangJITLogger(logging.Logger):
def __init__(self, name):
super().__init__(name)
self.setLevel(logging.INFO)
# Add a StreamHandler for console output
stream_handler = logging.StreamHandler()
stream_handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))
self.addHandler(stream_handler)
def info(self, msg):
super().info("tilelang.jit: " + msg)
logger = TileLangJITLogger("tilelang.jit")
def check_cuda_arch():
    # CUDA arch check, intended to gate fp8 support; currently a stub that only
    # iterates over the detected arch flags without enforcing a minimum architecture.
    for cuda_arch_flags in torch_cpp_ext._get_cuda_arch_flags():  # noqa: B007
        pass
def remove_unwanted_pytorch_nvcc_flags():
REMOVE_NVCC_FLAGS = [
"-D__CUDA_NO_HALF_OPERATORS__",
"-D__CUDA_NO_HALF_CONVERSIONS__",
"-D__CUDA_NO_BFLOAT16_CONVERSIONS__",
"-D__CUDA_NO_HALF2_OPERATORS__",
]
for flag in REMOVE_NVCC_FLAGS:
    # The flag may not be present in this PyTorch version; ignore that case.
    with suppress(ValueError):
        torch_cpp_ext.COMMON_NVCC_FLAGS.remove(flag)
remove_unwanted_pytorch_nvcc_flags()
sm90a_nvcc_flags = ["-gencode", "arch=compute_90a,code=sm_90a"]
def load_cuda_ops(
name: str,
sources: List[Union[str, Path]],
extra_cflags: List[str] = None,
extra_cuda_cflags: List[str] = None,
extra_ldflags=None,
extra_include_paths=None,
verbose=False,
):
if extra_cflags is None:
extra_cflags = []
if extra_cuda_cflags is None:
extra_cuda_cflags = []
cflags = ["-O3", "-Wno-switch-bool"]
cuda_cflags = [
"-O3",
"-std=c++17",
"-use_fast_math",
]
cflags += extra_cflags
cuda_cflags += extra_cuda_cflags
check_cuda_arch()
build_directory = TILELANG_JIT_DIR / name
os.makedirs(build_directory, exist_ok=True)
if extra_include_paths is None:
extra_include_paths = [
CUTLASS_INCLUDE_DIR,
TILELANG_TEMPLATE_PATH,
]
lock = FileLock(TILELANG_JIT_DIR / f"{name}.lock", thread_local=False)
with lock:
module = torch_cpp_ext.load(
name,
[str(source) for source in sources],
extra_cflags=cflags,
extra_cuda_cflags=cuda_cflags,
extra_ldflags=extra_ldflags,
extra_include_paths=[str(path) for path in extra_include_paths],
build_directory=build_directory,
verbose=verbose,
with_cuda=True,
keep_intermediates=False,
)
logger.info(f"Finished loading JIT ops: {name}")
return module
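For orientation, a minimal usage sketch of load_cuda_ops follows. It is illustrative only: the op name, source file, and import path are assumptions, not part of this commit.

from pathlib import Path

from tilelang.jit import load_cuda_ops  # assumed import location of the helper above

module = load_cuda_ops(
    name="example_gemm",  # also used as the build/cache directory name under TILELANG_JIT_DIR
    sources=[Path("example_gemm.cu")],  # hypothetical CUDA source file
    extra_cuda_cflags=["-gencode", "arch=compute_90a,code=sm_90a"],  # optional: target sm_90a
    verbose=True,
)
# `module` behaves like a regular torch C++ extension; its bound functions can be
# called directly from Python.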
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Library information. This is a standalone file that can be used to get various info.
Modified from flashinfer
"""
import pathlib
import re
import warnings
from torch.utils.cpp_extension import _get_cuda_arch_flags
from tilelang.env import (
CUTLASS_INCLUDE_DIR, # noqa: F401
TILELANG_TEMPLATE_PATH, # noqa: F401
)
def _initialize_torch_cuda_arch_flags():
    import os
    from tilelang.contrib import nvcc
    from tilelang.utils.target import determine_target

    target = determine_target(return_object=True)
    # Derive the compute capability of the active target (e.g. "9.0") and expose it
    # to the torch C++ extension builder via TORCH_CUDA_ARCH_LIST.
    major, minor = nvcc.get_target_compute_version(target).split(".")
    os.environ["TORCH_CUDA_ARCH_LIST"] = f"{major}.{minor}"
def _get_workspace_dir_name() -> pathlib.Path:
try:
with warnings.catch_warnings():
# Ignore the warning for TORCH_CUDA_ARCH_LIST not set
warnings.filterwarnings("ignore", r".*TORCH_CUDA_ARCH_LIST.*", module="torch")
flags = _get_cuda_arch_flags()
arch = "_".join(sorted(set(re.findall(r"compute_(\d+)", "".join(flags)))))
except Exception:
arch = "noarch"
# e.g.: $HOME/.cache/tilelang/75_80_89_90/
return pathlib.Path.home() / ".cache" / "tilelang" / arch
# Ensure TORCH_CUDA_ARCH_LIST is set before any torch cpp_extension build happens.
_initialize_torch_cuda_arch_flags()
TILELANG_JIT_WORKSPACE_DIR = _get_workspace_dir_name()
TILELANG_JIT_DIR = TILELANG_JIT_WORKSPACE_DIR / "cached_ops"
TILELANG_GEN_SRC_DIR = TILELANG_JIT_WORKSPACE_DIR / "generated"
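To make the cache layout concrete, here is a small sketch of how the arch suffix of the workspace path is derived; the flag values are illustrative, assuming TORCH_CUDA_ARCH_LIST resolves to 9.0.

import pathlib
import re

# Example flags as torch.utils.cpp_extension._get_cuda_arch_flags() might return
# them for TORCH_CUDA_ARCH_LIST="9.0" (illustrative values).
flags = ["-gencode", "arch=compute_90,code=sm_90"]
arch = "_".join(sorted(set(re.findall(r"compute_(\d+)", "".join(flags)))))
print(arch)  # -> "90"
print(pathlib.Path.home() / ".cache" / "tilelang" / arch / "cached_ops")
# -> e.g. /home/<user>/.cache/tilelang/90/cached_ops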
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import List, Union, Any, Callable, Literal
from tvm.target import Target
import tilelang
from tilelang import tvm as tvm
from tvm.tir import PrimFunc
from tilelang.jit.adapter import TorchCPPKernelAdapter, TorchDLPackKernelAdapter, BaseKernelAdapter
from tilelang.utils.target import determine_target, AVALIABLE_TARGETS
from tilelang.profiler import Profiler, TensorSupplyType
class JITKernel(object):
"""
A wrapper class for compiling and invoking TileLang (TVM TIR) functions as PyTorch-compatible functions.
Attributes
----------
rt_module : tvm.runtime.Module
The runtime module compiled by TVM.
rt_params : dict
Parameters for the compiled runtime module (e.g., weights or constants).
torch_function : Callable
The compiled function that can be invoked as a PyTorch-compatible function.
"""
rt_module: tvm.runtime.Module = None
rt_params: dict = None
adapter: BaseKernelAdapter = None
torch_function: Callable = None
def __init__(
self,
func: PrimFunc = None,
out_idx: Union[List[int], int] = None,
execution_backend: Literal["dl_pack", "torch_cpp", "ctypes"] = "dl_pack",
target: Union[str, Target] = "auto",
verbose: bool = False,
):
"""
Initializes a JITKernel instance.
Parameters
----------
func : tvm.tir.PrimFunc, optional
The TileLang TIR function to compile and wrap.
out_idx : Union[List[int], int], optional
Index(es) of the output tensors to return (default: None).
execution_backend : Literal["dl_pack", "torch_cpp", "ctypes"], optional
Execution backend to use for kernel execution (default: "dl_pack").
target : Union[str, Target], optional
Compilation target, either as a string or a TVM Target object (default: "auto").
verbose : bool, optional
Whether to enable verbose output (default: False).
"""
self.func = func
self.out_idx = out_idx
self.execution_backend = execution_backend
self.verbose = verbose
# If the target is specified as a string, validate it and convert it to a TVM Target.
if isinstance(target, str):
    assert target in AVALIABLE_TARGETS, f"Invalid target: {target}"
    target = determine_target(target)
# Ensure the target is always a TVM Target object and keep it for compilation.
self.target = Target(target)
# Validate the execution backend.
assert execution_backend in ["dl_pack", "torch_cpp", "ctypes"], "Invalid execution backend."
# Compile the TileLang function and create a kernel adapter for execution.
adapter = self._compile_and_create_adapter(func)
# The adapter's function is assigned as the callable function for this instance.
self.adapter = adapter
self.torch_function = adapter.func
def __call__(self, *args: Any, **kwds: Any) -> Any:
"""
Invokes the compiled function with the given arguments.
Parameters
----------
*args : Any
Positional arguments for the function.
**kwds : Any
Keyword arguments for the function.
Returns
-------
Any
The result of the function execution.
"""
return self.torch_function(*args, **kwds)
def _compile_and_create_adapter(self, tilelang_func: PrimFunc) -> BaseKernelAdapter:
"""
Compiles the given TileLang PrimFunc using TVM and creates a kernel adapter.
Parameters
----------
tilelang_func : tvm.tir.PrimFunc
The TileLang (TVM TIR) function to compile.
Returns
-------
BaseKernelAdapter
The compiled and ready-to-run kernel adapter.
"""
verbose = self.verbose
target = self.target
out_idx = self.out_idx
execution_backend = self.execution_backend
# Compile the function with TVM, optimizing with shared memory lowering.
with tvm.transform.PassContext(opt_level=3):
rt_mod, params = tilelang.lower(tilelang_func, target=target)
# Store the runtime module and parameters for later use.
self.rt_module = rt_mod
self.rt_params = params
# Create an adapter based on the specified execution backend.
if execution_backend == "dl_pack":
# Use TorchDLPackKernelAdapter for interoperability with PyTorch via DLPack.
adapter = TorchDLPackKernelAdapter(rt_mod, params=params, result_idx=out_idx)
elif execution_backend == "torch_cpp":
    # Torch C++ backend: the adapter can be constructed, but execution through it
    # is not supported yet, so an error is raised after construction.
    adapter = TorchCPPKernelAdapter(
        rt_mod,
        params=params,
        result_idx=out_idx,
        target=target,
        prim_func=tilelang_func,
        verbose=verbose,
    )
    raise NotImplementedError("Torch CPP backend is not fully implemented.")
elif execution_backend == "ctypes":
# CTYPES backend (not implemented yet).
raise NotImplementedError("CTypes backend is not implemented.")
else:
# Handle invalid backend.
raise ValueError(f"Invalid execution backend: {execution_backend}")
return adapter
@classmethod
def from_tilelang_function(cls, tilelang_func: PrimFunc, **kwargs):
"""
Alternative constructor to create a JITKernel directly from a TileLang PrimFunc.
Parameters
----------
tilelang_func : tvm.tir.PrimFunc
The TileLang (TVM TIR) function to compile.
**kwargs : dict
Additional keyword arguments to pass to the constructor.
Returns
-------
JITKernel
    An instance of JITKernel wrapping the compiled function.
"""
return cls(func=tilelang_func, **kwargs)
def get_profiler(self,
tensor_supply_type: TensorSupplyType = TensorSupplyType.Integer) -> Profiler:
"""
Creates a profiler to benchmark the compiled runtime module.
Parameters
----------
tensor_supply_type : TensorSupplyType, optional
The type of input tensors to supply for profiling (default: TensorSupplyType.Integer).
Returns
-------
Profiler
A Profiler instance for benchmarking the runtime module.
"""
return Profiler(self.rt_module, self.rt_params, self.out_idx, tensor_supply_type)
def get_kernel_source(self) -> str:
"""
Returns the source code of the compiled kernel function.
Returns
-------
str
The source code of the compiled kernel function.
"""
return self.rt_module.imported_modules[0].get_source()
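To show how JITKernel is meant to be used end to end, a sketch follows. It is illustrative only: the matmul program mirrors the public TileLang example and may not match the DSL at this exact commit, and the import path of JITKernel is an assumption.

import torch
import tilelang.language as T
from tilelang.jit import JITKernel  # assumed import location of the class above


def matmul(M, N, K, block_M=128, block_N=128, block_K=32, dtype="float16", accum_dtype="float"):
    # Illustrative TileLang program producing a PrimFunc (mirrors the public example).
    @T.prim_func
    def main(
            A: T.Buffer((M, K), dtype),
            B: T.Buffer((K, N), dtype),
            C: T.Buffer((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


func = matmul(1024, 1024, 1024)
# out_idx=[2] marks the third buffer (C) as the output to be allocated and returned.
kernel = JITKernel(func, out_idx=[2], execution_backend="dl_pack", target="auto")

a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
c = kernel(a, b)  # runs the compiled kernel through the DLPack adapter

profiler = kernel.get_profiler()
latency_ms = profiler.do_bench()  # defaults to the TVM time evaluator (see the do_bench diff below)
print(kernel.get_kernel_source())  # generated CUDA source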
@@ -2,7 +2,7 @@
 # Licensed under the MIT License.
 """The profiler and convert to torch utils"""
-from typing import Any, List, Literal
+from typing import List, Literal
 from functools import partial
 import torch
 from contextlib import suppress
@@ -11,8 +11,8 @@ import tvm
 from torch.utils.dlpack import to_dlpack
 from tvm.runtime import ndarray
 from tvm.relay import TensorType
-from tvm.contrib.dlpack import to_pytorch_func
+from tilelang.jit.adapter import TorchDLPackKernelAdapter
 from tilelang.utils.tensor import (
     get_tensor_supply,
     TensorSupplyType,
@@ -20,53 +20,7 @@ from tilelang.utils.tensor import (
 )
-class ConvertTorch:
-    def __init__(self, mod, params: List[TensorType], result_idx: List[int]) -> None:
-        self.mod = mod
-        self.params = params
-        self.result_idx = result_idx
-        self.func = self._convert_torch_func()
-    def _convert_torch_func(self) -> callable:
-        torch_func = to_pytorch_func(self.mod)
-        def func(*ins: List[torch.Tensor]):
-            if len(ins) + len(self.result_idx) != len(self.params):
-                raise ValueError(
-                    f"Expected {len(self.params)} inputs, got {len(ins) + len(self.result_idx)} with {len(ins)} inputs and {len(self.result_idx)} outputs"
-                )
-            ins_idx = 0
-            args = []
-            # use the device of the first input tensor if available
-            device = ins[0].device if len(ins) > 0 else torch.cuda.current_device()
-            for i in range(len(self.params)):
-                if i in self.result_idx:
-                    dtype = torch.__getattribute__(str(self.params[i].dtype))
-                    shape = list(map(int, self.params[i].shape))
-                    tensor = torch.empty(*shape, dtype=dtype, device=device)
-                else:
-                    tensor = ins[ins_idx]
-                    ins_idx += 1
-                args.append(tensor)
-            torch_func(*args)
-            if len(self.result_idx) == 1:
-                return args[self.result_idx[0]]
-            else:
-                return [args[i] for i in self.result_idx]
-        return func
-    def __call__(self, *args: Any, **kwds: Any) -> Any:
-        return self.func(*args, **kwds)
-    def get_kernel_source(self) -> str:
-        return self.mod.imported_modules[0].get_source()
-class Profiler(ConvertTorch):
+class Profiler(TorchDLPackKernelAdapter):
     def __init__(
         self,
@@ -145,7 +99,7 @@ class Profiler(ConvertTorch):
     def do_bench(
         self,
-        func: callable,
+        func: callable = None,
         warmup=25,
         rep=100,
         n_warmup=1,
@@ -153,6 +107,11 @@ class Profiler(ConvertTorch):
         profiler: Literal["torch", "tvm", "auto"] = "auto",
         input_tensors: List[torch.Tensor] = None,
     ):
+        if func is None:
+            # set default value if not provided
+            func = self.mod
+            profiler = "tvm"
         if profiler == "torch":
             ins = self._get_inputs() if input_tensors is None else input_tensors
             bench_func = partial(func, *ins)
@@ -179,6 +138,8 @@ class Profiler(ConvertTorch):
             # Transform Latency to ms
             return time_evaluator(*tvm_inputs).mean * 1e3
         elif profiler == "auto":
+            # TODO(lei): select appropriate profiler based on the function
+            # class
             ins = self._get_inputs()
             bench_func = partial(func, *ins)
             torch_res = do_bench(
...
@@ -3,7 +3,6 @@
 """The profiler and convert to torch utils"""
 from .target import determine_target  # noqa: F401
-from .profiler import Profiler  # noqa: F401
 from .tensor import TensorSupplyType, torch_assert_close  # noqa: F401
 from .language import (
     is_global,  # noqa: F401
...
@@ -42,7 +42,8 @@ def check_hip_availability() -> bool:
         return False
-def determine_target(target: Union[str, Target, Literal["auto"]]) -> Union[str, Target]:
+def determine_target(target: Union[str, Target, Literal["auto"]] = "auto",
+                     return_object: bool = False) -> Union[str, Target]:
     """
     Determine the appropriate target for compilation (CUDA, HIP, or manual selection).
@@ -58,6 +59,9 @@ def determine_target(target: Union[str, Target, Literal["auto"]]) -> Union[str, Target]:
         ValueError: If no CUDA or HIP is available and the target is "auto".
         AssertionError: If the target is invalid.
     """
+    return_var: Union[str, Target] = target
     if target == "auto":
         # Check for CUDA and HIP availability
         is_cuda_available = check_cuda_availability()
@@ -65,13 +69,18 @@ def determine_target(target: Union[str, Target, Literal["auto"]]) -> Union[str, Target]:
         # Determine the target based on availability
         if is_cuda_available:
-            return "cuda"
+            return_var = "cuda"
         elif is_hip_available:
-            return "hip"
+            return_var = "hip"
         else:
             raise ValueError("No CUDA or HIP available on this system.")
     else:
         # Validate the target if it's not "auto"
         assert isinstance(
             target, Target) or target in AVALIABLE_TARGETS, f"Target {target} is not supported"
-    return target
+        return_var = target
+    if return_object:
+        return Target(return_var)
+    return return_var
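For reference, a short sketch of how the new return_object flag is used (which string comes back depends on the backend available on the machine):

from tvm.target import Target
from tilelang.utils.target import determine_target

target_str = determine_target("auto")                      # "cuda" or "hip", as before
target_obj = determine_target("auto", return_object=True)  # now a tvm.target.Target instance
assert isinstance(target_obj, Target)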