Commit 7cd6b3cd authored by Lei Wang, committed by GitHub

[Bugfix] Put `InjectPtxAsyncCopy` Pass behind `ThreadSync` Pass (#97)

* bump version to v0.1.0

* [Enhancement] Add custom develop command for editable installs and update .gitignore

* [Documentation] Update README to include system dependencies installation instructions

* [Build] Update setup.py to support library file copying for both release and develop modes

* [Build] Refactor library file copying logic in setup.py

* [Documentation] Remove unnecessary install section header in Installation.md

* [Build] Add tox configuration and local distribution script for multi-Python version support

* [Build] Improve git submodule update function with better error handling

* [Build] Update LLVM configuration path in ROCm installation script

* [Build] Add .tox/ to .gitignore for tox testing environment

* [Build] Add support for TVM prebuild path configuration in CMakeLists.txt

* [Cleanup] Remove unused TVM runtime error codes header

* [Cleanup] Fix TVM grid constant type reference in CUDA module

* [Cleanup] Remove unused customized_code function from IR module

* [Feature] Add TileLang thread synchronization and storage access analysis passes

* [Build] Reorder DLL search path directories for more flexible library loading

* [Refactor] Improve thread synchronization and library path handling

- Rename ThreadSync and TileLangThreadSync functions in C++ code
- Update Python docstring for ThreadSync with more detailed description
- Reorder library path detection in tilelang environment setup
- Minor comment and code cleanup in CUDA and warp specialization modules

* [Refactor] Improve thread synchronization code style and formatting

- Standardize pointer type spacing in storage_access.h and storage_access.cc
- Update whitespace and indentation in thread_storage_sync.cc
- Reorder include statements in thread_partial_sync.cc
- Minor code formatting improvements across thread synchronization files

* [Refactor] Fix global function registration for ThreadSync

- Correct global function registration to use ThreadSync instead of TileLangThreadSync
- Update TVM global registration to match recent refactoring efforts

* [Refactor] Simplify ThreadSync global function registration

- Remove unnecessary whitespace in global function registration
- Compact the TVM global registration line for ThreadSync

* [Feature] Add WebGPU code generation support in TileLang

- Implement WebGPU code generator (codegen_webgpu.cc and codegen_webgpu.h)
- Add WebGPU target support in lower.py and target.py
- Update CMakeLists.txt to include WebGPU codegen source files
- Introduce WebGPU-specific code generation for WGSL shader language

* [Refactor] Improve WebGPU code generation formatting and readability

- Enhance code formatting in codegen_webgpu.cc and codegen_webgpu.h
- Standardize pointer type spacing and indentation
- Improve line breaks and reduce line length for better readability
- Minor code style improvements in WebGPU code generation

* [Test] Add WebGPU matrix multiplication code generation test

- Implement test_webgpu_codegen.py for WebGPU matrix multiplication
- Add assert_gemm_codegen function to validate WebGPU code generation
- Include basic matrix multiplication kernel test case

* Update README with WebGPU codegen support announcement

* Support multi-version PyPI package builds via tox

* Add support for CPU device backend with C code generation

- Introduce `is_cpu_device_backend` function to detect CPU backend with C code generation
- Modify `lower` function to handle special case of CPU device backend
- Update host and device call filtering for CPU backend
- Add conditional source code generation for C host target
- Extend JITKernel to support optional target_host parameter

* lint fix

* Enhance JIT kernel adapters with CTypes and Torch C++ backends

- Add CtypesKernelAdapter with dynamic library generation and kernel wrapping
- Implement TorchCPPKernelAdapter for CUDA kernel compilation
- Refactor BaseKernelAdapter to support more flexible initialization
- Improve error handling and argument processing in kernel adapters
- Update adapter initialization to support various execution backends

* Refactor and clean up code style in JIT CTypes adapter modules

- Apply consistent code formatting and whitespace in CTypes adapter files
- Remove unused imports and improve import organization
- Enhance readability of code in adapter, libgen, and wrapper modules
- Add missing whitespace and improve line breaks
- Minor linting and code style improvements across CTypes adapter files

* Add test for TileLang JIT GEMM with CTypes backend

- Implement comprehensive test for matrix multiplication using CTypes execution backend
- Create test functions for GEMM with float16 data type
- Add kernel source verification with custom callback
- Implement reference implementation using PyTorch for result validation
- Support various matrix multiplication configurations (transposition, block sizes)

* test fix

* Update TileLang JIT callback registration with override parameter

- Modify tilelang_callback_cuda_postproc to use @tvm.register_func(override=True)
- Ensure proper function registration with ability to replace existing implementations

* Reorder TileLang lowering passes for Hopper intrinsics and PTX async copy

- Adjust the order of LowerHopperIntrin and InjectPTXAsyncCopy passes
- Move these passes to ensure correct synchronization and device preparation
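A rough sketch of the reordering (pass names as they appear in this commit's lowering pipeline; the real code applies each transform to a TVM IRModule rather than manipulating name strings):

```python
# Before this commit: PTX async-copy injection ran ahead of thread
# synchronization, so cp.async lowering could not see the barriers
# that ThreadSync inserts.
before = [
    "LowerThreadAllreduce",
    "LowerHopperIntrin",
    "InjectPTXAsyncCopy",
    "AnnotateDeviceRegions",
    "SplitHostDevice",
    "MergeSharedMemoryAllocations",
    "ThreadSync(shared)",
    "ThreadSync(shared.dyn)",
    "MakePackedAPI",
]

# After this commit: ThreadSync runs first, then Hopper intrinsics are
# lowered and PTX async copies injected against the synchronized IR.
after = [
    "LowerThreadAllreduce",
    "AnnotateDeviceRegions",
    "SplitHostDevice",
    "MergeSharedMemoryAllocations",
    "ThreadSync(shared)",
    "ThreadSync(shared.dyn)",
    "LowerHopperIntrin",
    "InjectPTXAsyncCopy",
    "MakePackedAPI",
]
```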

* Rebase main

* shared.dyn

* lint fix

* test fix

* Add environment variable handling for TileLang template and CUTLASS paths

- Introduce fallback logic for TL_TEMPLATE_PATH environment variable
- Add support for optional TL_CUTLASS_PATH configuration
- Include TODO comment for future environment variable renaming
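The fallback logic can be sketched roughly as follows (`TL_TEMPLATE_PATH` and `TL_CUTLASS_PATH` are the variable names from this commit; the default path and helper names here are placeholders, not the actual tilelang code):

```python
import os

# Hypothetical sketch of the environment-variable fallback described
# above; the real code resolves defaults relative to the tilelang package.
DEFAULT_TEMPLATE_PATH = os.path.join(os.getcwd(), "src")  # placeholder default


def get_template_path():
    # TL_TEMPLATE_PATH wins if set; otherwise fall back to the bundled default.
    return os.environ.get("TL_TEMPLATE_PATH", DEFAULT_TEMPLATE_PATH)


def get_cutlass_path():
    # TL_CUTLASS_PATH is optional; None signals CUTLASS is not configured.
    return os.environ.get("TL_CUTLASS_PATH")
```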
parent 93294e61
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import tilelang
import tilelang.testing
import tilelang.language as T
import torch


def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):

    @T.prim_func
    def main(
            A: T.Buffer((M, K), dtype),
            B: T.Buffer((N, K), dtype),
            C: T.Buffer((M, N), dtype),
    ):
        with T.Kernel(
                T.ceildiv(N, block_N), T.ceildiv(M, block_M),
                threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            # changing num_stages to 0 gives correct results
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=1):
                T.copy(A[by * block_M, ko * block_K], A_shared)
                for j, k in T.Parallel(block_N, block_K):
                    B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]
                T.gemm(A_shared, B_shared, C_local, transpose_B=True)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def run_gemm_pipeline_test(N, block_M=128, block_N=128, block_K=32):
    func = matmul(N, N, N, block_M, block_N, block_K)
    jit_kernel = tilelang.JITKernel(func, out_idx=[2], target="cuda")
    torch.manual_seed(0)
    a = torch.randn(N, N, device="cuda", dtype=torch.float16)
    b = torch.randn(N, N, device="cuda", dtype=torch.float16)
    ref_c = a @ b.T
    c = jit_kernel(a, b)
    tilelang.testing.torch_assert_close(c, ref_c, rtol=1e-2, atol=0.2)


def test_pipeline_large_matrix():
    """Test pipeline stages with large matrix multiplication (8192x8192)"""
    run_gemm_pipeline_test(8192)


def test_pipeline_small_matrix():
    """Test pipeline stages with smaller matrix multiplication"""
    run_gemm_pipeline_test(1024)


if __name__ == "__main__":
    tilelang.testing.main()
@@ -88,7 +88,7 @@ def run_gemm(
     stramp = "&*(XS)"

-    @tvm.register_func(override=True)
+    @tvm.register_func("tilelang_callback_cuda_postproc", override=True)
     def tilelang_callback_cuda_postproc(code, _):
         code = f"// {stramp}\n" + code
         return code
@@ -203,14 +203,14 @@ def lower(
     mod = tl.transform.ThreadPartialSync("shared.dyn")(mod)
     mod = tir.transform.InferFragment()(mod)
     mod = tir.transform.LowerThreadAllreduce()(mod)
-    mod = tl.transform.LowerHopperIntrin()(mod)
-    mod = tir.transform.InjectPTXAsyncCopy()(mod)
     mod = tl.transform.AnnotateDeviceRegions()(mod)
     mod = tir.transform.SplitHostDevice()(mod)
     mod = tir.transform.MergeSharedMemoryAllocations()(mod)
     mod = tl.transform.ThreadSync("shared")(mod)
     mod = tl.transform.ThreadSync("shared.dyn")(mod)
+    mod = tl.transform.LowerHopperIntrin()(mod)
+    mod = tir.transform.InjectPTXAsyncCopy()(mod)
     mod = tl.transform.MakePackedAPI()(mod)
     mod = tir.transform.LowerDeviceKernelLaunch()(mod)