Commit 8344af52 authored by Lei Wang, committed by GitHub

[Bugfix] Implement boundary check for buffer shapes with dynamic symbolic dimensions (#173)

* [Refactor] Update BitBLAS Benchmark with TileLang Carver Imports and Roller Hints Generation

- Replace BitBLAS imports with TileLang Carver imports in benchmark_matmul.py
- Modify roller hints generation using new TileLang Carver template and utility functions
- Update get_roller_hints_from_func to handle None cases and improve return logic (see the sketch below)
- Adjust DefaultPolicy to handle different codegen dictionary formats
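
For orientation, a minimal sketch of how the Carver-based hint generation above might be driven; get_roller_hints_from_func is named in the log, but its import path, keyword arguments, and fallback behavior are assumptions for illustration, not the benchmark's actual code.

# Hypothetical driver for roller-hint generation (import path and kwargs assumed).
from tilelang.carver.utils import get_roller_hints_from_func  # assumed location

def get_hints_or_default(func, arch, topk=10):
    # Per the log, the helper now returns None when no hints apply,
    # so callers should fall back to a default schedule.
    hints = get_roller_hints_from_func(func, arch=arch, topk=topk)
    return hints if hints is not None else []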

* [Refactor] Update Thread Binding and Import Statements in TileLang Kernels

- Replace T.thread_binding() with T.get_thread_binding() across multiple kernel test files (illustrated below)
- Update import statements for MMA layout and macro generator in dequantize GEMM and FP8 examples
- Move map_torch_type utility function to tilelang.utils.tensor
- Remove unnecessary imports and improve code organization
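
As a concrete illustration of the renaming, a hedged before/after fragment; the zero-argument form of get_thread_binding() is an assumption inferred from the bullet above, not code lifted from the test files.

import tilelang.language as T

def kernel_body():
    # Before (per the log): the thread axis was created and bound in one call:
    #     tx = T.thread_binding(128, thread="threadIdx.x")
    # After: query the binding the kernel launch already established.
    tx = T.get_thread_binding()  # assumed zero-argument form
    return tx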

* Refactor Native Sparse Attention Example with Enhanced Triton Kernel

- Update parallel_nsa_fwd_kernel to support more flexible sparse attention computation
- Add support for block counts and offsets in the Triton kernel (host-side inputs sketched below)
- Modify kernel grid and computation logic for improved performance
- Update example script to use naive_nsa_simple reference implementation
- Improve type hints and kernel configuration
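
The kernel internals are too involved for a short excerpt, but the two new inputs can be sketched host-side; the names block_counts/block_offsets come from the log, while the shapes and dtypes here are assumptions.

import torch

# Hypothetical host-side tensors for a block-sparse attention kernel.
B, T, S = 2, 128, 16  # batch size, query length, max selected KV blocks
block_offsets = torch.zeros(B, T, S, dtype=torch.int32)  # selected block ids
block_counts = torch.randint(1, S + 1, (B, T), dtype=torch.int32)
# Each (b, t) program instance would read block_counts[b, t] valid entries
# from block_offsets[b, t, :] and ignore the padded remainder.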

* Add Native Sparse Attention Examples with Tilelang and Triton Implementations

- Introduce new example scripts for native sparse attention:
  * example_tilelang_nsa_fwd.py: Forward pass implementation using TileLang
  * example_tilelang_nsa_decode.py: Decoding-specific sparse attention implementation
  * example_triton_nsa_fwd.py: Triton-based sparse attention forward pass
- Update reference.py with naive implementations for sparse attention (a simplified sketch follows this list)
- Support different sparse attention scenarios including forward pass and inference
- Add comprehensive testing and validation against reference implementations
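
For intuition, a heavily simplified sketch of what a naive block-sparse attention reference computes (single head, causal, fixed-size KV blocks); the real reference.py is more general, and the function and argument names here are made up.

import torch

def naive_block_sparse_attn(q, k, v, block_idx, block_size):
    # q, k, v: [T, D]; block_idx: [T, S] holds the KV block ids each query attends to.
    seq_len, dim = q.shape
    out = torch.zeros_like(q)
    for t in range(seq_len):
        # Gather the K/V rows of the selected blocks, then apply a causal mask.
        cols = torch.cat([torch.arange(b * block_size, (b + 1) * block_size)
                          for b in block_idx[t].tolist()])
        cols = cols[cols <= t]
        scores = (q[t] @ k[cols].T) / dim**0.5
        out[t] = torch.softmax(scores, dim=-1) @ v[cols]
    return out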

* Lint fix

* Add Variable-Length Native Sparse Attention Examples for TileLang and Triton

- Introduce new example scripts for variable-length native sparse attention:
  * example_tilelang_nsa_fwd_varlen.py: TileLang implementation with variable sequence lengths
  * example_triton_nsa_fwd_varlen.py: Triton implementation with variable sequence lengths
- Update reference.py to support variable-length sparse attention scenarios (packing convention sketched below)
- Enhance existing sparse attention implementations to handle variable-length inputs
- Add comprehensive testing and validation for variable-length sparse attention
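
Variable-length support conventionally packs all sequences into one tensor indexed by cumulative lengths; a small sketch of that convention (assumed here, since the log does not show the exact interface).

import torch

# Three sequences of lengths 5, 3, and 8 packed into one 16-row tensor.
seq_lens = torch.tensor([5, 3, 8])
cu_seqlens = torch.cat([torch.zeros(1, dtype=seq_lens.dtype),
                        seq_lens.cumsum(0)])  # tensor([0, 5, 8, 16])
# Sequence i occupies rows cu_seqlens[i]:cu_seqlens[i+1] of the packed
# Q/K/V tensors, so per-sequence work is indexed from these offsets.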

* Refactor Native Sparse Attention Examples: Code Style and Formatting Improvements

- Standardize function and parameter formatting across NSA example files
- Improve code readability by adjusting indentation and line breaks
- Enhance type hints and parameter alignment
- Remove unnecessary whitespaces and optimize imports
- Maintain consistent code style across TileLang and Triton implementations

* Add debug logging and extend execution backend in JIT and loop vectorization

- Add detailed logging in loop vectorization to help diagnose buffer shape handling
- Extend JIT execution backend to include 'cython' option
- Guard the IntImmNode downcast in the BufferLoadNode visit method so buffers with dynamic symbolic shapes no longer dereference a null pointer

* Remove debug logging in loop vectorization BufferLoadNode visit method

- Remove unnecessary INFO log statements in VisitExpr_ method
- Simplify code by eliminating redundant logging
- Maintain core logic for handling buffer load node visits
parent 8e1845d2
@@ -73,12 +73,14 @@ private:
     if (node->buffer.scope() == "shared" || node->buffer.scope() == "global" ||
         node->buffer.scope() == "shared.dyn")
       has_nonlocal_memory_access_ = true;
-    if (node->buffer->shape.size() == 1 &&
-        node->buffer->shape[0].as<IntImmNode>()->value == 1) {
+    if (node->buffer->shape.size() == 1) {
+      // TODO(lei): This should be improved as
+      // constant buffer that tl hack to use as local register.
+      auto boundary_check = node->buffer->shape[0].as<IntImmNode>();
+      if (boundary_check && boundary_check->value == 1) {
         return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
+      }
     }
     UpdateVectorSize(node->indices, node->buffer);
     return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
   }
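
To see why the guard matters: with a dynamic symbolic shape, shape[0] is a tir.Var, not an IntImmNode, so the old unconditional ->value dereference followed a null pointer. A small illustration using the standard TVM Python API (assuming tvm is importable):

from tvm import tir

# A buffer whose extent is a dynamic symbolic variable: exactly the case
# the old as<IntImmNode>()->value dereference crashed on.
n = tir.Var("n", "int32")
buf = tir.decl_buffer((n,), "float32", name="dyn_buf")
print(type(buf.shape[0]))  # tvm.tir.expr.Var, not IntImm, so the cast yields nullptr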
@@ -24,7 +24,7 @@ def jit(
     func: Callable = None,
     *,  # Enforce keyword-only arguments from here on
     out_idx: Union[List[int], int] = None,
-    execution_backend: Literal["dlpack", "ctypes"] = "dlpack",
+    execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
     target: Union[str, Target] = "auto",
     verbose: bool = False,
 ) -> BaseKernelAdapter:
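
A hedged usage sketch of the widened backend option; the keyword names mirror the signature in the hunk above, while the import path and the program argument are placeholders.

from tilelang import jit  # assumed import path

def compile_program(program):
    # "cython" is now both an accepted Literal value and the new default;
    # passing it explicitly keeps the choice visible at the call site.
    return jit(program, execution_backend="cython", target="auto")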