Commit 227ed7ec authored by Lei Wang, committed by LeiWang1999

[Enhancement] Avoid TVM FFI handling when out_idx is specified (#209)

* Optimize CMake build process with dynamic job count calculation

- Modify build_csrc function to use 90% of available CPU cores
- Ensure at least one job is used during compilation
- Improve build performance by dynamically adjusting parallel job count

* Optimize build_csrc function with multiprocessing module

- Replace os.cpu_count() with multiprocessing.cpu_count()
- Maintain existing 90% CPU utilization logic
- Improve CPU core count calculation for build process (a minimal sketch of the job-count logic follows this list)
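
A minimal sketch of the job-count logic these two items describe, assuming a `build_csrc`-style helper that drives CMake (the exact command line is illustrative):

```python
import multiprocessing
import subprocess

def build_csrc(build_dir: str) -> None:
    # Use roughly 90% of the available CPU cores, but never fewer than one job.
    num_jobs = max(1, int(multiprocessing.cpu_count() * 0.9))
    subprocess.check_call(
        ["cmake", "--build", ".", "--parallel", str(num_jobs)],
        cwd=build_dir,
    )
```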

* Add dynamic shape support with out_idx in Cython JIT kernel compilation

- Implement `run_cython_dynamic_shape_with_out_idx` function in test_tilelang_jit_gemm_cython.py
- Update Cython wrapper to handle dynamic symbolic shapes during tensor allocation
- Add support for resolving dynamic shape dimensions using input tensor references
- Enhance flexibility of JIT kernel compilation with symbolic shape handling (a usage sketch follows this list)
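
A hedged usage sketch of dynamic shapes flowing through `out_idx` with the Cython backend; `matmul_program` and the tensor shapes are illustrative stand-ins, not the actual test contents:

```python
import tilelang
import torch

# Assumes `matmul_program` declares its M dimension symbolically.
kernel = tilelang.compile(matmul_program, out_idx=-1, execution_backend="cython")

a = torch.randn(127, 1024, dtype=torch.float16, device="cuda")   # dynamic M = 127
b = torch.randn(1024, 1024, dtype=torch.float16, device="cuda")
c = kernel(a, b)  # the wrapper allocates c, resolving M from a.shape[0]
```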

* Enhance error reporting for dynamic symbolic shape resolution in Cython JIT kernel

- Add a detailed error message when a dynamic symbolic dimension is not found in dynamic_symbolic_map (see the sketch after this list)
- Improve debugging by providing context about missing symbolic dimensions
- Maintain existing dynamic shape resolution logic
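
The resolution-and-diagnostics pattern described here, as a small Python sketch; the `{var: (tensor_index, dim_index)}` map layout follows the wrapper code shown in the diff below:

```python
from tvm import tir

def resolve_dim(dim, dynamic_symbolic_map, input_tensors):
    """Resolve one output dimension, naming the missing symbol on failure."""
    if isinstance(dim, tir.Var):
        assert dim in dynamic_symbolic_map, (
            f"Dynamic symbolic dimension {dim} not found in dynamic_symbolic_map")
        ref_tensor_idx, ref_shape_idx = dynamic_symbolic_map[dim]
        return input_tensors[ref_tensor_idx].shape[ref_shape_idx]
    return int(dim)
```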

* Fix Copy operation handling for scalar and multi-dimensional tensors

- Add special handling for scalar tensor copy operations
- Enhance error reporting in MakeIndices method with more detailed diagnostic information
- Improve SIMT loop generation to support zero-dimensional tensors
- Add explicit check and handling for scalar tensor scenarios

* Refactor Copy operation code formatting and improve readability

- Improve code formatting in MakeIndices and MakeSIMTLoop methods
- Add line breaks to enhance readability of complex ICHECK statements
- Simplify code structure in scalar tensor handling
- Remove unnecessary whitespace and improve code alignment

* Simplify GEMM example with direct kernel compilation

- Update copyright header to Tile-AI Corporation
- Remove Profiler import and usage
- Replace tilelang.lower() with tilelang.compile()
- Simplify kernel execution workflow
- Update kernel source retrieval method (the simplified workflow is sketched below)
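
A condensed sketch of the simplified workflow, assuming a `matmul(M, N, K)` program builder like the one in the GEMM example (shapes and dtypes are placeholders):

```python
import tilelang
import torch

program = matmul(1024, 1024, 1024)              # assumed program builder
kernel = tilelang.compile(program, out_idx=-1)  # replaces tilelang.lower() + Profiler

a = torch.randn(1024, 1024, dtype=torch.float16, device="cuda")
b = torch.randn(1024, 1024, dtype=torch.float16, device="cuda")
c = kernel(a, b)                                # output allocated and returned

print(kernel.get_kernel_source())               # new source retrieval method
```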

* Enhance block sparse attention implementation

- Update `blocksparse_flashattn` to use 2 stages for improved performance.
- Change `block_mask_dtype` from `int8` to `bool` for better memory efficiency (a mask-construction sketch follows this list).
- Modify condition checks in the kernel to utilize boolean values.
- Introduce a new example for top-k sparse attention and a benchmark for native sparse attention.
- Add support for asynchronous copy in PTX and improve pipeline planning with condition handling.
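
A small PyTorch sketch of a boolean block mask of the kind this change introduces (block layout, density, and causal structure here are illustrative assumptions):

```python
import torch

batch, heads, seq_len, block_size = 1, 8, 2048, 64
num_blocks = seq_len // block_size

# True marks a (query-block, key-block) pair the kernel should attend to.
block_mask = torch.rand(batch, heads, num_blocks, num_blocks, device="cuda") < 0.25
# Keep block-level causality: a query block never looks at later key blocks.
block_mask &= torch.tril(
    torch.ones(num_blocks, num_blocks, dtype=torch.bool, device="cuda"))
```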

* Refactor and clean up code formatting across multiple files

- Added whitespace for improved readability in `example_blocksparse_gemm.py`, `example_tilelang_nsa_fwd.py`, and `benchmark_nsa_fwd.py`.
- Enhanced code structure and alignment in `inject_ptx_async_copy.cc` and `pipeline_planning.cc`.
- Updated comments and documentation for clarity in `__init__.py` and `phase.py`.
- Ensured consistent formatting and style across the codebase.

* Add kernel source printing in example_tilelang_nsa_fwd.py and implement IfThenElse node replacement in inject_pipeline.cc

- Added a print statement to output the kernel source in `example_tilelang_nsa_fwd.py` for debugging purposes.
- Introduced a new function `replace_if_then_else` in `inject_pipeline.cc` to transform IfThenElse nodes while preserving attributes, enhancing the handling of conditional statements in the pipeline.

* Refactor condition handling in inject_pipeline.cc

- Change the data structure for mapping conditions to statements from a Map to an Array for improved performance and simplicity (a Python analogue is sketched below).
- Update condition comparison logic to use StructuralEqual for better accuracy.
- Enhance logging to provide detailed insights into condition changes and statement processing.
- Adjust final statement construction to utilize the new data structure, ensuring correct handling of conditions and statements.
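
The grouping strategy described above, sketched as a Python analogue; the real pass operates on TVM `PrimExpr`/`Stmt` nodes in C++, with a pluggable `equal` predicate standing in here for `StructuralEqual`:

```python
def group_by_condition(cond_stmt_pairs, equal):
    """Group statements under structurally-equal conditions.

    A list of (condition, [statements]) pairs is used instead of a dict because
    the grouping key needs a custom equality predicate rather than hashing.
    """
    groups = []
    for cond, stmt in cond_stmt_pairs:
        for existing_cond, stmts in groups:
            if equal(existing_cond, cond):
                stmts.append(stmt)
                break
        else:
            groups.append((cond, [stmt]))
    return groups
```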

* Improve logging and formatting in inject_pipeline.cc

- Enhance logging statements for better clarity on condition changes and statement processing.
- Adjust formatting for improved readability, including line breaks and consistent spacing.
- Ensure accurate condition comparison and handling in the pipeline logic.

* Refactor logging and clean up inject_pipeline.cc

- Remove excessive logging statements to streamline the code and improve performance.
- Simplify condition handling by eliminating unnecessary log outputs related to condition changes and statement processing.
- Maintain the core functionality while enhancing code readability and maintainability.

* Update Dockerfiles to specify exact version of libstdcxx-ng

- Change installation command in multiple Dockerfiles to use `libstdcxx-ng=12` instead of `libstdcxx-ng-12` for consistency and to avoid potential issues with package resolution.
- Ensure all Dockerfiles from cu118 to cu126 reflect this change for uniformity across builds.

* Refactor and enhance examples and kernel handling

- Adjusted the pipeline stages in `example_blocksparse_gemm.py` from 2 to 1 for improved performance.
- Added kernel source printing in `benchmark_nsa_fwd.py` for better debugging and profiling insights.
- Updated tensor allocation and parameter handling in `CtypesKernelAdapter` and `CythonKernelWrapper` to cache parameter dtypes and shapes, improving efficiency and clarity.
- Enhanced the handling of dynamic shapes in the Cython JIT kernel compilation process.
- Modified the benchmark script to accommodate new tensor output parameters and improved batch size defaults for testing (both calling conventions are sketched below).
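
The two calling conventions this refactor distinguishes, as a hedged sketch based on the benchmark diff further below (`program` and the tensor arguments stand in for the NSA benchmark's actual objects):

```python
import tilelang
import torch

# out_idx set: the adapter allocates the output itself, using dtypes and shapes
# cached at initialization instead of per-call TVM FFI lookups.
kernel = tilelang.compile(program, out_idx=-1)
out = kernel(Q, K, V, block_indices)

# out_idx=None: the caller preallocates the output and passes it explicitly.
kernel = tilelang.compile(program, out_idx=None, execution_backend="cython")
out = torch.empty((batch, seq_len, head_query, dim), dtype=dtype, device="cuda")
kernel(Q, K, V, block_indices, out)
```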

* Update copyright header in Cython wrapper to reflect Tile-AI Corporation

* Revert change
parent 8678aac0
@@ -116,6 +116,10 @@ def parallel_nsa_fwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
o_slc: torch.Tensor,
o_swa: Optional[torch.Tensor],
lse_slc: torch.Tensor,
lse_swa: Optional[torch.Tensor],
block_indices: torch.LongTensor,
block_counts: Union[torch.LongTensor, int],
block_size: int,
@@ -140,10 +144,6 @@ def parallel_nsa_fwd(
assert NK == 1, "The key dimension can not be larger than 256"
grid = (T, NV, B * H)
o_slc = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device)
o_swa = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device) if window_size > 0 else None
lse_slc = torch.empty(B, T, HQ, dtype=torch.float, device=q.device)
lse_swa = torch.empty(B, T, HQ, dtype=torch.float, device=q.device) if window_size > 0 else None
parallel_nsa_fwd_kernel[grid](
q=q,
@@ -497,6 +497,12 @@ def tilelang_sparse_attention(batch,
scores_sum = T.alloc_fragment([G], accum_dtype)
logsum = T.alloc_fragment([G], accum_dtype)
# T.use_swizzle(10)
T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
T.annotate_layout({K_shared: tilelang.layout.make_swizzled_layout(K_shared)})
T.annotate_layout({V_shared: tilelang.layout.make_swizzled_layout(V_shared)})
i_t, i_v, i_bh = bx, by, bz
i_b, i_h = i_bh // head_kv, i_bh % head_kv
@@ -603,20 +609,27 @@ def benchmark_nsa(batch_size,
scale=scale,
)
print(program)
kernel = tilelang.compile(program, out_idx=-1)
kernel = tilelang.compile(program, out_idx=None, execution_backend="cython")
print(kernel.get_kernel_source())
profiler = kernel.get_profiler()
profiler_latency = profiler.do_bench(profiler.mod)
print(f"Profiler latency: {profiler_latency} ms")
# Create input tensors
Q = torch.randn((batch_size, seq_len, head_query, dim), dtype=dtype, device='cuda')
K = torch.randn((batch_size, seq_len, heads, dim), dtype=dtype, device='cuda')
V = torch.randn((batch_size, seq_len, heads, dim), dtype=dtype, device='cuda')
out = torch.empty((batch_size, seq_len, head_query, dim), dtype=dtype, device='cuda')
# Generate block indices
block_indices = generate_block_indices(batch_size, seq_len, heads, selected_blocks, block_size)
block_indices = generate_block_indices(batch_size, seq_len, heads, selected_blocks,
block_size).to(torch.int32)
# Warmup
for _ in range(warmup):
out = kernel(Q, K, V, block_indices.to(torch.int32))
kernel(Q, K, V, block_indices, out)
# Synchronize before timing
torch.cuda.synchronize()
@@ -624,7 +637,7 @@ def benchmark_nsa(batch_size,
# Benchmark
start_time = time.time()
for _ in range(iterations):
out = kernel(Q, K, V, block_indices.to(torch.int32))
kernel(Q, K, V, block_indices, out)
torch.cuda.synchronize()
end_time = time.time()
@@ -710,6 +723,8 @@ def benchmark_triton_nsa(batch_size,
block_indices = generate_block_indices(batch_size, seq_len, heads, selected_blocks, block_size)
block_counts = torch.randint(
1, selected_blocks + 1, (batch_size, seq_len, heads), device='cuda')
o_slc = torch.empty((batch_size, seq_len, head_query, dim), dtype=dtype, device='cuda')
lse_slc = torch.empty((batch_size, seq_len, head_query), dtype=torch.float, device='cuda')
# Warmup
for _ in range(warmup):
@@ -717,6 +732,10 @@ def benchmark_triton_nsa(batch_size,
q=Q,
k=K,
v=V,
o_slc=o_slc,
o_swa=None,
lse_slc=lse_slc,
lse_swa=None,
block_indices=block_indices,
block_counts=block_counts,
block_size=block_size,
@@ -733,6 +752,10 @@ def benchmark_triton_nsa(batch_size,
q=Q,
k=K,
v=V,
o_slc=o_slc,
o_swa=None,
lse_slc=lse_slc,
lse_swa=None,
block_indices=block_indices,
block_counts=block_counts,
block_size=block_size,
@@ -881,13 +904,13 @@ def run_benchmark_suite(impl='all'):
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Benchmark TileLang Sparse Attention")
parser.add_argument("--batch", type=int, default=2, help="Batch size")
parser.add_argument("--batch", type=int, default=32, help="Batch size")
parser.add_argument("--seq_len", type=int, default=1024, help="Sequence length")
parser.add_argument("--heads", type=int, default=1, help="Number of heads")
parser.add_argument("--head_query", type=int, default=16, help="Number of query heads")
parser.add_argument("--dim", type=int, default=64, help="Head dimension")
parser.add_argument("--selected_blocks", type=int, default=8, help="Number of selected blocks")
parser.add_argument("--block_size", type=int, default=64, help="Block size")
parser.add_argument("--dim", type=int, default=128, help="Head dimension")
parser.add_argument("--selected_blocks", type=int, default=16, help="Number of selected blocks")
parser.add_argument("--block_size", type=int, default=32, help="Block size")
parser.add_argument(
"--dtype", type=str, default="float16", help="Data type (float16 or float32)")
parser.add_argument("--scale", type=float, default=0.1, help="Attention scale factor")
@@ -898,7 +921,7 @@ if __name__ == "__main__":
parser.add_argument(
"--impl",
type=str,
default="tilelang",
default="all",
choices=["tilelang", "triton", "all"],
help="Implementation to benchmark (tilelang, triton, or all)")
......
@@ -182,8 +182,9 @@ public:
}
Stmt VisitStmt_(const BufferStoreNode *store) {
if (in_async && (store->buffer.scope() == "shared" ||
store->buffer.scope() == "shared.dyn")) {
bool is_shared = (store->buffer.scope() == "shared" ||
store->buffer.scope() == "shared.dyn");
if (in_async && is_shared) {
if (auto *load = store->value.as<BufferLoadNode>()) {
return InjectPTX(load, store);
} else if (auto *call = store->value.as<CallNode>()) {
......
@@ -34,6 +34,10 @@ class CtypesKernelAdapter(BaseKernelAdapter):
# Pass configs for the compiler
pass_configs: Optional[Dict[str, Any]] = None
# Add new cache attributes
param_dtypes: Optional[List[torch.dtype]] = None # Cache for parameter dtypes
param_shapes: Optional[List[List]] = None # Cache for parameter shapes
def __init__(self,
rt_mod,
params: List[TensorType],
@@ -61,6 +65,20 @@ class CtypesKernelAdapter(BaseKernelAdapter):
else:
self.ir_module = func_or_mod
# Cache parameter information during initialization
self.param_dtypes = [map_torch_type(param.dtype) for param in params]
self.param_shapes = []
for param in params:
native_shape = []
for dim in param.shape:
if isinstance(dim, tir.IntImm):
native_shape.append(int(dim))
elif isinstance(dim, tir.Var):
native_shape.append(dim) # Keep tir.Var for dynamic dimensions
else:
native_shape.append(dim)
self.param_shapes.append(native_shape)
self.dynamic_symbolic_map = self._process_dynamic_symbolic()
self.target = Target.canon_target(determine_target(target))
@@ -135,9 +153,15 @@ class CtypesKernelAdapter(BaseKernelAdapter):
# tensor pointers
for i in range(len(self.params)):
if i in self.result_idx:
dtype = map_torch_type(self.params[i].dtype)
shape = list(map(int, self.params[i].shape))
# use the device of the first input tensor if available
dtype = self.param_dtypes[i]
shape = []
# Now working with native Python list, no FFI calls needed
for s in self.param_shapes[i]:
if isinstance(s, tir.Var):
ref_tensor_idx, ref_shape_idx = self.dynamic_symbolic_map[s]
shape.append(ins[ref_tensor_idx].shape[ref_shape_idx])
else: # Already converted to Python int during initialization
shape.append(s)
device = ins[0].device if len(ins) > 0 else torch.cuda.current_device()
tensor = torch.empty(*shape, dtype=dtype, device=device)
else:
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# cython: language_level=3
import torch
@@ -19,12 +17,29 @@ cdef class CythonKernelWrapper:
list result_idx # Indices of output tensors in the params list
list params # List of parameter specifications (includes both inputs and outputs)
object lib # Reference to the compiled library containing the kernel
# Add new cache attributes
list param_dtypes # Cache for parameter dtypes
list param_shapes # Cache for parameter shapes as native Python lists
def __cinit__(self, result_idx, params, lib):
# Initialize wrapper with kernel configuration
self.result_idx = result_idx
self.params = params
self.lib = lib
# Convert TVM types to native Python types during initialization
self.param_dtypes = [map_torch_type(param.dtype) for param in params]
# Convert TVM shape arrays to native Python lists
self.param_shapes = []
for param in params:
native_shape = []
for dim in param.shape:
if isinstance(dim, tir.IntImm):
native_shape.append(int(dim))
elif isinstance(dim, tir.Var):
native_shape.append(dim) # Keep tir.Var for dynamic dimensions
else:
native_shape.append(dim)
self.param_shapes.append(native_shape)
def set_dynamic_symbolic_map(self, dynamic_symbolic_map):
self.dynamic_symbolic_map = dynamic_symbolic_map
@@ -57,29 +72,22 @@ cdef class CythonKernelWrapper:
cdef int ins_idx = 0
cdef list tensor_list = []
cdef list call_args = []
# Prepare input and output tensors
for i in range(len(self.params)):
if i in self.result_idx:
# Create empty output tensor with specified dtype and shape
dtype = map_torch_type(self.params[i].dtype)
dtype = self.param_dtypes[i]
shape = []
for s in self.params[i].shape:
# Now working with native Python list, no FFI calls needed
for s in self.param_shapes[i]:
if isinstance(s, tir.Var):
# find the corresponding input tensor and shape dimension
assert s in self.dynamic_symbolic_map, f"Dynamic symbolic dimension \
{s} not found in dynamic_symbolic_map"
ref_tensor_idx, ref_shape_idx = self.dynamic_symbolic_map[s]
shape.append(tensor_list[ref_tensor_idx].shape[ref_shape_idx])
elif isinstance(s, (tir.IntImm, int)):
shape.append(int(s))
else:
raise ValueError(f"Unsupported shape type: {type(s)}")
else: # Already converted to Python int during initialization
shape.append(s)
device = inputs[0].device if len(inputs) > 0 else torch.cuda.current_device()
tensor = torch.empty(*shape, dtype=dtype, device=device)
else:
# Use provided input tensor
tensor = inputs[ins_idx]
ins_idx += 1
tensor_list.append(tensor)
......