"docs/vscode:/vscode.git/clone" did not exist on "06123ed8ea4aee20ef131ab0bd52d0735147c36f"
Commit 39fc5a6d authored by Lei Wang, committed via GitHub

[Dev][jit] Introduce jit for kernel functions (#12)

* instruction update

* replace link with TileLang/tile-lang

* [Dev][Adapter] Implement Torch DLPack Kernel Adapter and related utilities

* lint fix

* Implement JIT Compiler Components

* Documents update

* lint fix

* update logo

* install script fix
Parent commit: 18718446
<img src=./images/logo-row.svg />
<div align="center">
# Tile Language
@@ -57,7 +59,7 @@ pip install tilelang
Alternatively, you can install directly from the GitHub repository:
```bash
pip install git+https://github.com/TileLang/tile-lang
pip install git+https://github.com/tile-ai/tilelang
```
Or install locally:
@@ -82,6 +84,9 @@ In this section, you’ll learn how to write and execute a straightforward GEMM
Below is an example that demonstrates more advanced features: layout annotation, parallelized copy, and swizzle for improved L2 cache locality. This snippet shows how to adapt your kernel to maximize performance on complex hardware.
```python
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import tilelang
import tilelang.language as T
# `make_mma_swizzle_layout` is a Python-defined layout function
# specifically designed for MMA operations
@@ -91,6 +96,7 @@ from tilelang.intrinsics import (
make_mma_swizzle_layout as make_swizzle_layout,)
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
# add decorator @tilelang.jit if you want to return a torch function
@T.prim_func
def main(
A: T.Buffer((M, K), dtype),
@@ -105,13 +111,13 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
# Apply layout optimizations or define your own layout (Optional)
# If not specified, we will deduce the layout automatically
T.annotate_layout({
A_shared: make_swizzle_layout(A_shared),
B_shared: make_swizzle_layout(B_shared),
})
# T.annotate_layout({
# A_shared: make_swizzle_layout(A_shared),
# B_shared: make_swizzle_layout(B_shared),
# })
# Enable rasterization for better L2 cache locality (Optional)
T.use_swizzle(panel_size=10, enable=True)
# T.use_swizzle(panel_size=10, enable=True)
# Clear local accumulation
T.clear(C_local)
@@ -133,6 +139,45 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
T.copy(C_local, C[by * block_M, bx * block_N])
return main
# 1. Define the kernel (matmul) and compile/lower it into an executable module
func = matmul(1024, 1024, 1024, 128, 128, 32)
# 2. Compile the kernel into a torch function
# out_idx specifies the index of the output buffer in the argument list
# if out_idx is specified, the tensor will be created during runtime
# target currently can be "cuda" or "hip" or "cpu".
jit_kernel = tilelang.JITKernel(func, out_idx=[2], target="cuda")
# 3. Test the kernel in Python with PyTorch data
import torch
# Create random input tensors on the GPU
a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
# Run the compiled kernel with PyTorch tensors
c = jit_kernel(a, b)
# Reference multiplication using PyTorch
ref_c = a @ b
# Validate correctness
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel output matches PyTorch reference.")
# 4. Retrieve and inspect the generated CUDA source (optional)
cuda_source = jit_kernel.get_kernel_source()
print("Generated CUDA kernel:\n", cuda_source)
# 5. Profile the kernel latency
profiler = jit_kernel.get_profiler()
latency = profiler.do_bench()
print(f"Latency: {latency} ms")
```
### Dive Deep into TileLang Beyond GEMM
@@ -152,4 +197,4 @@ TileLang has now been used in project [BitBLAS](https://github.com/microsoft/Bit
## Acknowledgements
We learned a lot from the [TVM](https://github.com/apache/tvm) community and would like to thank them for their contributions.
We learned a lot from the [TVM](https://github.com/apache/tvm) community and would like to thank them for their contributions. The initial version of this project is mainly contributed by [LeiWang1999](https://github.com/LeiWang1999), [chengyupku](https://github.com/chengyupku) and [nox-410](https://github.com/nox-410). Part of this work was done during the internship at Microsoft Research, under the supervision of Dr. Lingxiao Ma, Dr. Yuqing Xia, Dr. Jilong Xue, and Dr. Fan Yang.
@@ -22,7 +22,7 @@ RUN conda install pip cmake && conda clean --all
RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev
RUN git clone https://github.com/TileLang/tile-lang.git --recursive -b main TileLang \
RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \
&& cd TileLang && ./install.sh
CMD bash
To ease the process of installing all the dependencies, we provide a Dockerfile and a simple guideline to build a Docker image with all of the above installed. The Docker image is built on top of Ubuntu 20.04 and contains all the dependencies required to run the experiments. We currently only provide a Dockerfile for NVIDIA GPUs; a Dockerfile for AMD GPUs will be provided upon request.
```bash
git clone --recursive https://github.com/TileLang/tile-lang TileLang
git clone --recursive https://github.com/tile-ai/tilelang TileLang
cd TileLang/docker
# build the image, this may take a while (around 10+ minutes on our test machine)
docker build -t tilelang_cuda -f Dockerfile.cu120 .
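# Optionally, start a container from the freshly built image. This run command is
# a sample (not from the original guide); it assumes the NVIDIA Container Toolkit
# is installed, and the mount and flags should be adjusted to your environment.
docker run --gpus all -it --rm -v "$(pwd)":/workspace tilelang_cuda bash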
......
@@ -9,7 +9,7 @@
The easiest way to install TileLang is directly from PyPI using pip. To install the latest version, run the following command in your terminal.
**Note**: Currently, the TileLang wheel is only supported on Ubuntu 20.04 or later versions, as we build the wheel files on this platform. We currently only provide wheels for CUDA >= 11.0 and Python >= 3.8. **If you are using a different platform or environment, you may need to [build TileLang from source](https://github.com/TileLang/tile-lang/blob/main/docs/Installation.md#building-from-source).**
**Note**: Currently, the TileLang wheel is only supported on Ubuntu 20.04 or later versions, as we build the wheel files on this platform. We currently only provide wheels for CUDA >= 11.0 and Python >= 3.8. **If you are using a different platform or environment, you may need to [build TileLang from source](https://github.com/tile-ai/tilelang/blob/main/docs/Installation.md#building-from-source).**
```bash
pip install tilelang
@@ -24,7 +24,7 @@ pip install tilelang-0.0.0.dev0+ubuntu.20.4.cu120-py3-none-any.whl
To install the latest version of TileLang from the GitHub repository, you can run the following command:
```bash
pip install git+https://github.com/TileLang/tile-lang.git
pip install git+https://github.com/tile-ai/tilelang.git
```
After installing TileLang, you can verify the installation by running:
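For example, a minimal import check works (the exact verification snippet from the original document is elided in this excerpt):
```bash
python -c "import tilelang; print(tilelang.__name__)"
```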
@@ -56,7 +56,7 @@ sudo apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev
After installing the prerequisites, you can clone the TileLang repository and install it using pip:
```bash
git clone --recursive https://github.com/TileLang/tile-lang.git
git clone --recursive https://github.com/tile-ai/tilelang.git
cd TileLang
pip install . # Please be patient, this may take some time.
```
@@ -80,7 +80,7 @@ If you already have a compatible TVM installation, follow these steps:
1. **Clone the Repository:**
```bash
git clone --recursive https://github.com/TileLang/tile-lang
git clone --recursive https://github.com/tile-ai/tilelang
cd TileLang
```
@@ -114,7 +114,7 @@ If you prefer to use the built-in TVM version, follow these instructions:
1. **Clone the Repository:**
```bash
git clone --recursive https://github.com/TileLang/tile-lang
git clone --recursive https://github.com/tile-ai/tilelang
cd TileLang
```
@@ -152,7 +152,7 @@ For a simplified installation, use the provided script:
1. **Clone the Repository:**
```bash
git clone --recursive https://github.com/TileLang/tile-lang
git clone --recursive https://github.com/tile-ai/tilelang
cd TileLang
```
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import tilelang
import tilelang.language as T
# `make_mma_swizzle_layout` is a Python-defined layout function
# specifically designed for MMA operations.
# It keeps shared-memory layouts consistent with the NVIDIA CUTLASS library
# to avoid bank conflicts and maximize performance.
from tilelang.intrinsics import (
make_mma_swizzle_layout as make_swizzle_layout,) # noqa: F401
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
# add decorator @tilelang.jit if you want to return a torch function
@T.prim_func
def main(
A: T.Buffer((M, K), dtype),
B: T.Buffer((K, N), dtype),
C: T.Buffer((M, N), dtype),
):
# Kernel configuration remains similar
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
# Apply layout optimizations or define your own layout (Optional)
# If not specified, we will deduce the layout automatically
# T.annotate_layout({
# A_shared: make_swizzle_layout(A_shared),
# B_shared: make_swizzle_layout(B_shared),
# })
# Enable rasterization for better L2 cache locality (Optional)
# T.use_swizzle(panel_size=10, enable=True)
# Clear local accumulation
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
# Copy tile of A
# This is syntactic sugar for a parallelized copy
T.copy(A[by * block_M, k * block_K], A_shared)
# Demonstrate parallelized copy from global to shared for B
for ko, j in T.Parallel(block_K, block_N):
B_shared[ko, j] = B[k * block_K + ko, bx * block_N + j]
# Perform a tile-level GEMM on the shared buffers
# Currently this dispatches to CuTe on NVIDIA GPUs and HIP on AMD GPUs
T.gemm(A_shared, B_shared, C_local)
# Copy result back to global memory
T.copy(C_local, C[by * block_M, bx * block_N])
return main
# 1. Define the kernel (matmul) and compile/lower it into an executable module
func = matmul(1024, 1024, 1024, 128, 128, 32)
# 2. Compile the kernel into a torch function
# out_idx specifies the index of the output buffer in the argument list
# if out_idx is specified, the tensor will be created during runtime
# target currently can be "cuda" or "hip" or "cpu".
jit_kernel = tilelang.JITKernel(func, out_idx=[2], target="cuda")
# 3. Test the kernel in Python with PyTorch data
import torch
# Create random input tensors on the GPU
a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
# Run the compiled kernel with PyTorch tensors
c = jit_kernel(a, b)
# Reference multiplication using PyTorch
ref_c = a @ b
# Validate correctness
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel output matches PyTorch reference.")
# 4. Retrieve and inspect the generated CUDA source (optional)
cuda_source = jit_kernel.get_kernel_source()
print("Generated CUDA kernel:\n", cuda_source)
# 5. Profile the kernel latency
profiler = jit_kernel.get_profiler()
latency = profiler.do_bench()
print(f"Latency: {latency} ms")
#!/bin/bash
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
echo "Starting installation script..."
# Step 1: Install Python requirements
echo "Installing Python requirements from requirements.txt..."
pip install -r requirements.txt
if [ $? -ne 0 ]; then
echo "Error: Failed to install Python requirements."
exit 1
else
echo "Python requirements installed successfully."
fi
# Step 2: Define LLVM version and architecture
LLVM_VERSION="10.0.1"
IS_AARCH64=false
EXTRACT_PATH="3rdparty"
echo "LLVM version set to ${LLVM_VERSION}."
echo "Is AARCH64 architecture: $IS_AARCH64"
# Step 3: Determine the correct Ubuntu version based on LLVM version
UBUNTU_VERSION="16.04"
if [[ "$LLVM_VERSION" > "17.0.0" ]]; then
UBUNTU_VERSION="22.04"
elif [[ "$LLVM_VERSION" > "16.0.0" ]]; then
UBUNTU_VERSION="20.04"
elif [[ "$LLVM_VERSION" > "13.0.0" ]]; then
UBUNTU_VERSION="18.04"
fi
echo "Ubuntu version for LLVM set to ${UBUNTU_VERSION}."
# Step 4: Set download URL and file name for LLVM
BASE_URL="https://github.com/llvm/llvm-project/releases/download/llvmorg-${LLVM_VERSION}"
if $IS_AARCH64; then
FILE_NAME="clang+llvm-${LLVM_VERSION}-aarch64-linux-gnu.tar.xz"
else
FILE_NAME="clang+llvm-${LLVM_VERSION}-x86_64-linux-gnu-ubuntu-${UBUNTU_VERSION}.tar.xz"
fi
DOWNLOAD_URL="${BASE_URL}/${FILE_NAME}"
echo "Download URL for LLVM: ${DOWNLOAD_URL}"
# Step 5: Create extraction directory
echo "Creating extraction directory at ${EXTRACT_PATH}..."
mkdir -p "$EXTRACT_PATH"
if [ $? -ne 0 ]; then
echo "Error: Failed to create extraction directory."
exit 1
else
echo "Extraction directory created successfully."
fi
# Step 6: Download LLVM
echo "Downloading $FILE_NAME from $DOWNLOAD_URL..."
curl -L -o "${EXTRACT_PATH}/${FILE_NAME}" "$DOWNLOAD_URL"
if [ $? -ne 0 ]; then
echo "Error: Download failed!"
exit 1
else
echo "Download completed successfully."
fi
# Step 7: Extract LLVM
echo "Extracting $FILE_NAME to $EXTRACT_PATH..."
tar -xJf "${EXTRACT_PATH}/${FILE_NAME}" -C "$EXTRACT_PATH"
if [ $? -ne 0 ]; then
echo "Error: Extraction failed!"
exit 1
else
echo "Extraction completed successfully."
fi
# Step 8: Determine LLVM config path
LLVM_CONFIG_PATH="$(realpath ${EXTRACT_PATH}/$(basename ${FILE_NAME} .tar.xz)/bin/llvm-config)"
echo "LLVM config path determined as: $LLVM_CONFIG_PATH"
# Step 9: Clone and build TVM
echo "Cloning TVM repository and initializing submodules..."
# clone and build tvm
git submodule update --init --recursive
if [ -d build ]; then
rm -rf build
fi
mkdir build
cp 3rdparty/tvm/cmake/config.cmake build
cd build
echo "Configuring TVM build with LLVM and CUDA paths..."
echo "set(USE_LLVM $LLVM_CONFIG_PATH)" >> config.cmake
echo "Running CMake for TileLang..."
cmake ..
if [ $? -ne 0 ]; then
echo "Error: CMake configuration failed."
exit 1
fi
echo "Building TileLang with make..."
make -j
if [ $? -ne 0 ]; then
echo "Error: TileLang build failed."
exit 1
else
echo "TileLang build completed successfully."
fi
cd ../../..
# Step 11: Set environment variables
TILELANG_PATH="$(pwd)"
echo "Configuring environment variables for TVM..."
echo "export PYTHONPATH=${TILELANG_PATH}:\$PYTHONPATH" >> ~/.bashrc
# Step 12: Source .bashrc to apply changes
echo "Applying environment changes by sourcing .bashrc..."
source ~/.bashrc
if [ $? -ne 0 ]; then
echo "Error: Failed to source .bashrc."
exit 1
else
echo "Environment configured successfully."
fi
echo "Installation script completed successfully."
@@ -113,8 +113,9 @@ fi
cd ../../..
# Step 11: Set environment variables
TILELANG_PATH="$(pwd)"
echo "Configuring environment variables for TVM..."
echo "export PYTHONPATH=$(pwd):\$PYTHONPATH" >> ~/.bashrc
echo "export PYTHONPATH=${TILELANG_PATH}:\$PYTHONPATH" >> ~/.bashrc
echo "export CUDA_DEVICE_ORDER=PCI_BUS_ID" >> ~/.bashrc
# Step 12: Source .bashrc to apply changes
......
@@ -85,8 +85,12 @@ cd ../../..
# Define the lines to be added
TVM_HOME_ENV="export TVM_HOME=$(pwd)/3rdparty/tvm"
TILELANG_PYPATH_ENV="export PYTHONPATH=\$TVM_HOME/python:$(pwd):\$PYTHONPATH"
TILELANG_PATH="$(pwd)"
echo "Configuring environment variables for TVM..."
echo "export PYTHONPATH=${TILELANG_PATH}:\$PYTHONPATH" >> ~/.bashrc
echo "export CUDA_DEVICE_ORDER=PCI_BUS_ID" >> ~/.bashrc
TVM_HOME_ENV="export TVM_HOME=${TILELANG_PATH}/3rdparty/tvm"
TILELANG_PYPATH_ENV="export PYTHONPATH=\$TVM_HOME/python:${TILELANG_PATH}:\$PYTHONPATH"
CUDA_DEVICE_ORDER_ENV="export CUDA_DEVICE_ORDER=PCI_BUS_ID"
# Check and add the first line if not already present
......
@@ -447,7 +447,7 @@ setup(
],
license="MIT",
keywords="BLAS, CUDA, HIP, Code Generation, TVM",
url="https://github.com/TileLang/tile-lang",
url="https://github.com/tile-ai/tilelang",
classifiers=[
"Programming Language :: Python :: 3.8",
"License :: OSI Approved :: MIT License",
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from tilelang import tvm as tvm
import tilelang.testing
import tilelang
import torch
def matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
):
A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
import tilelang.language as T
@tilelang.jit(
out_idx=-1, # create the output tensor during runtime
execution_backend="dl_pack",
)
@T.prim_func
def main(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, in_dtype),
C: T.Buffer((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if trans_A:
T.copy(A[k * block_K, by * block_M], A_shared)
else:
T.copy(A[by * block_M, k * block_K], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
def run_gemm(
M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128,
):
matmul_kernel = matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
A = torch.randn(M, K, dtype=torch.__getattribute__(in_dtype)).cuda()
B = torch.randn(K, N, dtype=torch.__getattribute__(in_dtype)).cuda()
if trans_A:
A = A.T
if trans_B:
B = B.T
def ref_program(A, B):
import torch
C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(torch.__getattribute__(out_dtype))
return C
ref_C = ref_program(A, B)
C = matmul_kernel(A, B)
tilelang.testing.torch_assert_close(C, ref_C, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
def test_gemm_f16f16f16_nn():
run_gemm(
512,
1024,
768,
False,
False,
"float16",
"float16",
"float16",
128,
256,
32,
2,
)
if __name__ == "__main__":
# tilelang.testing.main()
test_gemm_f16f16f16_nn()
@@ -359,18 +359,4 @@ def run_matmul_rrr(
# )
if __name__ == "__main__":
# tilelang.testing.main()
run_matmul_ssr(
1024,
1024,
1024,
False,
True,
"float16",
"float16",
"float16",
128,
128,
32,
2,
)
tilelang.testing.main()
@@ -81,65 +81,7 @@ def deprecated(reason):
logger = logging.getLogger(__name__)
# SETUP ENVIRONMENT VARIABLES
CUTLASS_NOT_FOUND_MESSAGE = ("CUTLASS is not installed or found in the expected path")
", which may lead to compilation bugs when utilize tilelang backend."
TL_TEMPLATE_NOT_FOUND_MESSAGE = ("TileLang is not installed or found in the expected path")
", which may lead to compilation bugs when utilize tilelang backend."
TVM_LIBRARY_NOT_FOUND_MESSAGE = ("TVM is not installed or found in the expected path")
SKIP_LOADING_TILELANG_SO = os.environ.get("SKIP_LOADING_TILELANG_SO", "0")
# Handle TVM_IMPORT_PYTHON_PATH to import tvm from the specified path
TVM_IMPORT_PYTHON_PATH = os.environ.get("TVM_IMPORT_PYTHON_PATH", None)
if TVM_IMPORT_PYTHON_PATH is not None:
os.environ["PYTHONPATH"] = TVM_IMPORT_PYTHON_PATH + ":" + os.environ.get("PYTHONPATH", "")
sys.path.insert(0, TVM_IMPORT_PYTHON_PATH + "/python")
else:
install_tvm_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "3rdparty", "tvm")
install_tvm_library_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "lib")
if os.path.exists(install_tvm_path) and install_tvm_path not in sys.path:
os.environ["PYTHONPATH"] = install_tvm_path + "/python:" + os.environ.get("PYTHONPATH", "")
sys.path.insert(0, install_tvm_path + "/python")
develop_tvm_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "tvm")
develop_tvm_library_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "..", "build", "tvm")
if os.path.exists(develop_tvm_path) and develop_tvm_path not in sys.path:
os.environ["PYTHONPATH"] = develop_tvm_path + "/python:" + os.environ.get("PYTHONPATH", "")
sys.path.insert(0, develop_tvm_path + "/python")
if os.environ.get("TVM_LIBRARY_PATH") is None:
if os.path.exists(develop_tvm_library_path):
os.environ["TVM_LIBRARY_PATH"] = develop_tvm_library_path
elif os.path.exists(install_tvm_library_path):
os.environ["TVM_LIBRARY_PATH"] = install_tvm_library_path
else:
logger.warning(TVM_LIBRARY_NOT_FOUND_MESSAGE)
if os.environ.get("TL_CUTLASS_PATH", None) is None:
install_cutlass_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "3rdparty", "cutlass")
develop_cutlass_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "cutlass")
if os.path.exists(install_cutlass_path):
os.environ["TL_CUTLASS_PATH"] = install_cutlass_path + "/include"
elif (os.path.exists(develop_cutlass_path) and develop_cutlass_path not in sys.path):
os.environ["TL_CUTLASS_PATH"] = develop_cutlass_path + "/include"
else:
logger.warning(CUTLASS_NOT_FOUND_MESSAGE)
if os.environ.get("TL_TEMPLATE_PATH", None) is None:
install_tl_template_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "src")
develop_tl_template_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "src")
if os.path.exists(install_tl_template_path):
os.environ["TL_TEMPLATE_PATH"] = install_tl_template_path
elif (os.path.exists(develop_tl_template_path) and develop_tl_template_path not in sys.path):
os.environ["TL_TEMPLATE_PATH"] = develop_tl_template_path
else:
logger.warning(TL_TEMPLATE_NOT_FOUND_MESSAGE)
from .env import SKIP_LOADING_TILELANG_SO
import tvm
import tvm._ffi.base
@@ -163,8 +105,10 @@ def _load_tile_lang_lib():
if SKIP_LOADING_TILELANG_SO == "0":
_LIB, _LIB_PATH = _load_tile_lang_lib()
from .jit import jit, JITKernel # noqa: F401
from .profiler import Profiler # noqa: F401
from .utils import (
Profiler, # noqa: F401
TensorSupplyType, # noqa: F401
)
from .layout import (
......
@@ -107,11 +107,6 @@ def extrac_params(func: tir.PrimFunc):
def canon_target_host(target: Union[str, Target], target_host: Optional[Union[str, Target]]):
def target_is_c(target):
if isinstance(target, str):
return target == "c"
return target.kind.name == "c"
if not target_host:
target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm"
......
import sys
import os
import pathlib
import logging
logger = logging.getLogger(__name__)
CUTLASS_INCLUDE_DIR: str = os.environ.get("TL_CUTLASS_PATH", None)
TVM_PYTHON_PATH: str = os.environ.get("TVM_IMPORT_PYTHON_PATH", None)
TVM_LIBRARY_PATH: str = os.environ.get("TVM_LIBRARY_PATH", None)
TILELANG_TEMPLATE_PATH: str = os.environ.get("TL_TEMPLATE_PATH", None)
TILELANG_PACKAGE_PATH: str = pathlib.Path(__file__).resolve().parents[0]
# SETUP ENVIRONMENT VARIABLES
CUTLASS_NOT_FOUND_MESSAGE = ("CUTLASS is not installed or found in the expected path")
", which may lead to compilation bugs when utilize tilelang backend."
TL_TEMPLATE_NOT_FOUND_MESSAGE = ("TileLang is not installed or found in the expected path")
", which may lead to compilation bugs when utilize tilelang backend."
TVM_LIBRARY_NOT_FOUND_MESSAGE = ("TVM is not installed or found in the expected path")
SKIP_LOADING_TILELANG_SO = os.environ.get("SKIP_LOADING_TILELANG_SO", "0")
# Handle TVM_IMPORT_PYTHON_PATH to import tvm from the specified path
TVM_IMPORT_PYTHON_PATH = os.environ.get("TVM_IMPORT_PYTHON_PATH", None)
if TVM_IMPORT_PYTHON_PATH is not None:
os.environ["PYTHONPATH"] = (TVM_IMPORT_PYTHON_PATH + ":" + os.environ.get("PYTHONPATH", ""))
sys.path.insert(0, TVM_IMPORT_PYTHON_PATH)
else:
install_tvm_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "3rdparty", "tvm")
install_tvm_library_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "lib")
if os.path.exists(install_tvm_path) and install_tvm_path not in sys.path:
os.environ["PYTHONPATH"] = (
install_tvm_path + "/python:" + os.environ.get("PYTHONPATH", ""))
sys.path.insert(0, install_tvm_path + "/python")
TVM_IMPORT_PYTHON_PATH = install_tvm_path + "/python"
develop_tvm_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "tvm")
develop_tvm_library_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "..", "build", "tvm")
if os.path.exists(develop_tvm_path) and develop_tvm_path not in sys.path:
os.environ["PYTHONPATH"] = (
develop_tvm_path + "/python:" + os.environ.get("PYTHONPATH", ""))
sys.path.insert(0, develop_tvm_path + "/python")
TVM_IMPORT_PYTHON_PATH = develop_tvm_path + "/python"
if os.environ.get("TVM_LIBRARY_PATH") is None:
if os.path.exists(develop_tvm_library_path):
os.environ["TVM_LIBRARY_PATH"] = develop_tvm_library_path
elif os.path.exists(install_tvm_library_path):
os.environ["TVM_LIBRARY_PATH"] = install_tvm_library_path
else:
logger.warning(TVM_LIBRARY_NOT_FOUND_MESSAGE)
TVM_LIBRARY_PATH = os.environ.get("TVM_LIBRARY_PATH", None)
if os.environ.get("TL_CUTLASS_PATH", None) is None:
install_cutlass_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "3rdparty", "cutlass")
develop_cutlass_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "cutlass")
if os.path.exists(install_cutlass_path):
os.environ["TL_CUTLASS_PATH"] = install_cutlass_path + "/include"
CUTLASS_INCLUDE_DIR = install_cutlass_path + "/include"
elif (os.path.exists(develop_cutlass_path) and develop_cutlass_path not in sys.path):
os.environ["TL_CUTLASS_PATH"] = develop_cutlass_path + "/include"
CUTLASS_INCLUDE_DIR = develop_cutlass_path + "/include"
else:
logger.warning(CUTLASS_NOT_FOUND_MESSAGE)
if os.environ.get("TL_TEMPLATE_PATH", None) is None:
install_tl_template_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "src")
develop_tl_template_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "src")
if os.path.exists(install_tl_template_path):
os.environ["TL_TEMPLATE_PATH"] = install_tl_template_path
TILELANG_TEMPLATE_PATH = install_tl_template_path
elif (os.path.exists(develop_tl_template_path) and develop_tl_template_path not in sys.path):
os.environ["TL_TEMPLATE_PATH"] = develop_tl_template_path
TILELANG_TEMPLATE_PATH = develop_tl_template_path
else:
logger.warning(TL_TEMPLATE_NOT_FOUND_MESSAGE)
__all__ = [
"CUTLASS_INCLUDE_DIR",
"TVM_PYTHON_PATH",
"TVM_LIBRARY_PATH",
"TILELANG_TEMPLATE_PATH",
]
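For illustration only (not part of this commit's diff), downstream code can consult the resolved paths that `env.py` exports; the include-flag construction below is a hypothetical example:
```python
from tilelang import env

# Collect include directories for a hypothetical downstream compilation step.
include_flags = []
if env.CUTLASS_INCLUDE_DIR is not None:
    include_flags.append("-I" + env.CUTLASS_INCLUDE_DIR)
if env.TILELANG_TEMPLATE_PATH is not None:
    include_flags.append("-I" + env.TILELANG_TEMPLATE_PATH)
print(include_flags)
```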
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
This module provides an auto-tuning infrastructure for TileLang (tl) programs.
It includes functionality to JIT-compile TileLang programs into a runnable
kernel adapter using TVM.
"""
from typing import Callable, List, Literal, Union
from tilelang import tvm as tvm
from tvm.tir import PrimFunc
from tvm.target import Target
from tilelang.jit.adapter import BaseKernelAdapter
from tilelang.jit.kernel import JITKernel
from tilelang.utils.target import determine_target, AVALIABLE_TARGETS
from logging import getLogger
logger = getLogger(__name__)
def jit(
func: Callable = None,
*, # Enforce keyword-only arguments from here on
out_idx: Union[List[int], int] = None,
execution_backend: Literal["dl_pack", "torch_cpp", "ctypes"] = "dl_pack",
target: Union[str, Target] = "auto",
verbose: bool = False,
) -> BaseKernelAdapter:
"""
A decorator (or decorator factory) that JIT-compiles a given TileLang PrimFunc
into a runnable kernel adapter using TVM. If called with arguments, it returns
a decorator that can be applied to a function. If called without arguments,
it directly compiles the given function.
Parameters
----------
func : Callable, optional
The TileLang PrimFunc to JIT-compile. If None, this function returns a
decorator that expects a TileLang PrimFunc.
out_idx : Union[List[int], int], optional
The index (or list of indices) of the function outputs. This can be used
to specify which outputs from the compiled function will be returned.
execution_backend : Literal["dl_pack", "torch_cpp", "ctypes"], optional
The wrapper type to use for the kernel adapter. Currently, only "dl_pack"
and "torch_cpp" are supported.
target : Union[str, Target], optional
The compilation target for TVM. If set to "auto", an appropriate target
will be inferred automatically. Otherwise, must be one of the supported
strings in AVALIABLE_TARGETS or a TVM Target instance.
Returns
-------
BaseKernelAdapter
An adapter object that encapsulates the compiled function and can be
used to execute it.
Raises
------
AssertionError
If the provided target is an invalid string not present in AVALIABLE_TARGETS.
"""
# If the target is specified as a string, ensure it is valid and convert to a TVM Target.
if isinstance(target, str):
assert target in AVALIABLE_TARGETS, f"Invalid target: {target}"
target = determine_target(target)
target = Target(target)
assert execution_backend in ["dl_pack", "torch_cpp", "ctypes"], "Invalid execution backend."
def _compile_and_create_adapter(tilelang_func: PrimFunc) -> BaseKernelAdapter:
"""
Compile the given TileLang PrimFunc with TVM and build a kernel adapter.
Parameters
----------
tilelang_func : tvm.tir.PrimFunc
The TileLang (TVM TIR) function to compile.
Returns
-------
BaseKernelAdapter
The compiled and ready-to-run kernel adapter.
"""
if verbose:
logger.info(f"Compiling TileLang function:\n{tilelang_func}")
return JITKernel(
tilelang_func,
target=target,
verbose=verbose,
execution_backend=execution_backend,
out_idx=out_idx,
).adapter
# If `func` was given, compile it immediately and return the adapter.
if func is not None:
return _compile_and_create_adapter(func)
# Otherwise, return a decorator that expects a function to compile.
def real_decorator(tilelang_func: PrimFunc) -> BaseKernelAdapter:
return _compile_and_create_adapter(tilelang_func)
return real_decorator
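As a usage illustration (not part of the diff itself), the decorator form above can wrap a `T.prim_func` directly; the sketch below reuses the GEMM body from the README example with arbitrary 1024/128/32 tile sizes:

```python
import torch
import tilelang
import tilelang.language as T


def matmul_jit(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):

    @tilelang.jit(out_idx=[2], execution_backend="dl_pack", target="auto")
    @T.prim_func
    def main(
            A: T.Buffer((M, K), dtype),
            B: T.Buffer((K, N), dtype),
            C: T.Buffer((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)
            T.copy(C_local, C[by * block_M, bx * block_N])

    # Because of the decorator, `main` is already a compiled kernel adapter here.
    return main


kernel = matmul_jit(1024, 1024, 1024, 128, 128, 32)
a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
c = kernel(a, b)  # the output C is allocated at runtime because out_idx=[2]
```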
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .base import BaseKernelAdapter # noqa: F401
from .dl_pack import TorchDLPackKernelAdapter # noqa: F401
from .torch_cpp import TorchCPPKernelAdapter # noqa: F401
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""The profiler and convert to torch utils"""
from typing import Any, List
from tvm.relay import TensorType
class BaseKernelAdapter(object):
def __init__(self, mod, params: List[TensorType], result_idx: List[int]) -> None:
self.mod = mod
self.params = params
# result_idx is a list of indices of the output tensors
if result_idx is None:
result_idx = []
elif isinstance(result_idx, int):
if result_idx >= len(params) or result_idx < -len(params):
raise ValueError(
f"result_idx should be an integer between {-len(params)} and {len(params) - 1}")
if result_idx < 0:
result_idx = len(params) + result_idx
result_idx = [result_idx]
elif not isinstance(result_idx, list):
raise ValueError("result_idx should be a list of integers")
self.result_idx = result_idx
self.func = self._convert_torch_func()
def _convert_torch_func(self) -> callable:
raise NotImplementedError
def __call__(self, *args: Any, **kwds: Any) -> Any:
return self.func(*args, **kwds)
def get_kernel_source(self) -> str:
return self.mod.imported_modules[0].get_source()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""The profiler and convert to torch utils"""
from typing import List
from .base import BaseKernelAdapter
from tvm.relay import TensorType
class CtypesKernelAdapter(BaseKernelAdapter):
target = "cuda"
prim_func = None
def __init__(self,
mod,
params: List[TensorType],
result_idx: List[int],
target,
prim_func,
verbose: bool = False):
self.target = target
self.prim_func = prim_func
self.verbose = verbose
super().__init__(mod, params, result_idx)
raise NotImplementedError("CtypesKernelAdapter is not implemented yet.")
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""The profiler and convert to torch utils"""
import torch
from typing import List
from tvm.contrib.dlpack import to_pytorch_func
from .base import BaseKernelAdapter
class TorchDLPackKernelAdapter(BaseKernelAdapter):
def _convert_torch_func(self) -> callable:
torch_func = to_pytorch_func(self.mod)
def func(*ins: List[torch.Tensor]):
if len(ins) + len(self.result_idx) != len(self.params):
raise ValueError(
f"Expected {len(self.params)} inputs, got {len(ins) + len(self.result_idx)} with {len(ins)} inputs and {len(self.result_idx)} outputs"
)
ins_idx = 0
args = []
# use the device of the first input tensor if available
device = ins[0].device if len(ins) > 0 else torch.cuda.current_device()
for i in range(len(self.params)):
if i in self.result_idx:
dtype = torch.__getattribute__(str(self.params[i].dtype))
shape = list(map(int, self.params[i].shape))
tensor = torch.empty(*shape, dtype=dtype, device=device)
else:
tensor = ins[ins_idx]
ins_idx += 1
args.append(tensor)
torch_func(*args)
if len(self.result_idx) == 1:
return args[self.result_idx[0]]
else:
return [args[i] for i in self.result_idx]
return func
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""The profiler and convert to torch utils"""
import torch
from typing import List, Union
from .base import BaseKernelAdapter
from pathlib import Path
from tvm.relay import TensorType
from tilelang.jit.core import load_cuda_ops
from tilelang.jit.env import (TILELANG_JIT_WORKSPACE_DIR)
def torch_cpp_cuda_compile(code, target, verbose):
# TODO(lei): This is not fully implemented yet
# TODO(lei): extract name and magic number from module
name: str = "matmul"
magic_number = 0x9f
full_kernel_dir = TILELANG_JIT_WORKSPACE_DIR / Path(f"{name}_{magic_number}")
full_kernel_dir.mkdir(parents=True, exist_ok=True)
sources: List[Union[str, Path]] = []
tmp_cuda_kernel_file = (full_kernel_dir / "kernel.cu")
code = (
code + r"""
void kenrel_interface(void* A, void *B, void *C, int64_t cuda_stream) {
cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
main_kernel<<<dim3(4, 4, 1), dim3(128, 1, 1), 0, stream>>>((half_t *)A, (half_t *)B, (half_t *)C);
}
""")
with open(tmp_cuda_kernel_file, "w") as f:
f.write(code)
print(tmp_cuda_kernel_file)
sources.append(tmp_cuda_kernel_file)
tmp_host_file = (full_kernel_dir / "host.cpp")
host_code = r"""
#include <torch/extension.h>
#include <stdio.h>
#include <ATen/ATen.h>
void kenrel_interface(void* A, void *B, void *C, int64_t cuda_stream);
int dispather(at::Tensor& A, at::Tensor& B, at::Tensor& C, int64_t cuda_stream) {
kenrel_interface(
A.data_ptr(),
B.data_ptr(),
C.data_ptr(),
cuda_stream
);
return 0;
}
int dispather(at::Tensor& A, at::Tensor& B, at::Tensor& C, int64_t cuda_stream);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("matmul", &dispather, "matmul");
printf("Registering matmul\n");
}
"""
with open(tmp_host_file, "w") as f:
f.write(host_code)
sources.append(tmp_host_file)
module = load_cuda_ops(name=name, sources=sources, verbose=verbose)
return module.matmul
class TorchCPPKernelAdapter(BaseKernelAdapter):
target = "cuda"
prim_func = None
def __init__(self,
mod,
params: List[TensorType],
result_idx: List[int],
target,
prim_func,
verbose: bool = False):
self.target = target
self.prim_func = prim_func
self.verbose = verbose
super().__init__(mod, params, result_idx)
def _convert_torch_func(self) -> callable:
target = self.target
verbose = self.verbose
code = self.get_kernel_source()
torch_module = torch_cpp_cuda_compile(code, target, verbose)
# raise NotImplementedError("Please implement this function")
def func(*ins: List[torch.Tensor]):
if len(ins) + len(self.result_idx) != len(self.params):
raise ValueError(
f"Expected {len(self.params)} inputs, got {len(ins) + len(self.result_idx)} with {len(ins)} inputs and {len(self.result_idx)} outputs"
)
ins_idx = 0
args = []
# use the device of the first input tensor if available
device = ins[0].device if len(ins) > 0 else torch.cuda.current_device()
for i in range(len(self.params)):
if i in self.result_idx:
dtype = torch.__getattribute__(str(self.params[i].dtype))
shape = list(map(int, self.params[i].shape))
tensor = torch.empty(*shape, dtype=dtype, device=device)
else:
tensor = ins[ins_idx]
ins_idx += 1
args.append(tensor)
torch_module(*args, 0)
if len(self.result_idx) == 1:
return args[self.result_idx[0]]
else:
return [args[i] for i in self.result_idx]
return func