- 17 Dec, 2025 1 commit
-
-
Lei Wang authored
* [Enhancement] Update examples and tests for improved type handling and functionality - Enhanced various example scripts to support new data types and improve compatibility with PyTorch. - Updated tests across multiple modules to ensure correct functionality with the latest changes in type handling. - Refactored code in examples to streamline operations and improve clarity, particularly in tensor operations and memory management. - Added comprehensive tests for new features and fixed existing issues related to type conversions and buffer handling. * [Refactor] Update accumulation data type to float32 across examples - Changed accumulation data type from "float" to T.float32 in multiple example scripts to ensure consistency and improve numerical stability. - This update affects various modules including flash attention, GEMM analysis, convolution, and deepseek MLA examples, enhancing type handling across the board. * [Refactor] Standardize data type usage across benchmark scripts - Updated data type definitions in benchmark scripts to use T.float16 and T.float32 consistently, enhancing clarity and type handling. - Adjusted dtype assignments in matmul functions and configuration setups to align with the new standard. - Improved overall code consistency and maintainability by ensuring uniform data type usage across various modules. * [Refactor] Standardize data type usage in templates and scripts - Updated data type definitions in various templates and scripts to use string representations (e.g., "float16", "int32") instead of T.float16 and T.int32 for improved consistency and clarity. - Enhanced overall code maintainability by ensuring uniform data type usage across multiple modules, including convolution, elementwise operations, and matrix multiplication templates. - This change aims to streamline type handling and improve compatibility with existing workflows. * [Refactor] Standardize data type usage in examples and benchmarks - Updated data type definitions in various example and benchmark scripts to use T.float16 and T.int32 consistently, enhancing clarity and maintainability. - Adjusted dtype assignments in kernel functions and configuration setups to align with the new standard. - Improved overall code consistency by ensuring uniform data type usage across multiple modules, including attention mechanisms, matrix multiplication, and GEMM examples. * [Refactor] Import dtypes from language.v2 module - Added import statement for dtypes from the language.v2 module to enhance type handling and maintain consistency across the codebase. - This change aims to streamline data type management and improve overall code clarity. * fix * [Refactor] Standardize data type usage across scripts - Updated data type definitions in various scripts to use string representations (e.g., "float16", "int8") instead of T.float16 and T.int8 for improved consistency and clarity. - Adjusted dtype assignments in functions and configuration setups to align with the new standard, enhancing overall code maintainability. - This change affects multiple modules, including benchmark and attention mechanisms, ensuring uniform data type usage throughout the codebase. * [Refactor] Update data type handling for consistency and clarity - Changed string representations of data types in the Hint class to use T.float32 and T.int32 for improved consistency. - Added new data types "int4" and "int16" to the dtypes module, enhancing type support across the codebase. - Updated function signatures and assertions in the lop3 and mxfp modules to utilize the new data types, ensuring uniformity in type handling. - This refactor aims to streamline data type management and improve overall code clarity and maintainability. * [Enhancement] Improve data type handling and error messaging - Introduced a mapping for canonical data types to their display strings, enhancing clarity in type representation. - Updated the dtype creation logic to utilize the new mapping, ensuring more intuitive handling of string inputs. - Refined error messages in the lop3 module to provide clearer feedback on invalid source formats, improving debugging and user experience. * [Fix] Correct boolean flag in GEMM SP test case - Updated the boolean flag in the test_gemm_sp_sm90 function to ensure proper functionality in the test case. - This change enhances the accuracy of the test and aligns it with expected behavior for the GEMM SP implementation. * [Refactor] Standardize data type usage across scripts - Updated data type definitions in various scripts to use T.float16 and T.bfloat16 consistently, enhancing clarity and maintainability. - Adjusted dtype assignments in function signatures and argument parsing to align with the new standard, ensuring uniform data type usage throughout the codebase. - This change affects multiple modules, including benchmarks and examples, improving overall code consistency and readability. * [Refactor] Standardize data type usage in various modules - Updated data type assignments in multiple scripts to utilize T.float32, T.int8, and T.int32 consistently, enhancing clarity and maintainability. - Adjusted function signatures and parameter types across benchmarks, examples, and tests to align with the new standard, ensuring uniform data type usage throughout the codebase. - This change improves overall code consistency and readability, impacting modules related to matrix multiplication, GEMM, and tensor operations. * [Refactor] Update argument parsing for data types in benchmarks - Changed argument parsing for data types in benchmark_matmul_intrinsic.py and benchmark_matmul_sp.py to use string representations ("float16", "int8", "float") instead of T.float16 and T.float. - This update enhances consistency in data type handling across benchmark scripts, improving clarity and maintainability. * [Refactor] Update data type handling in benchmark and example scripts - Changed data type arguments in benchmark and example scripts to use string representations ("float16") instead of T.float16 for improved consistency. - Updated function signatures and argument parsing to align with the new standard, enhancing clarity and maintainability across the codebase. - This change affects multiple modules related to attention mechanisms and tensor operations, ensuring uniform data type usage throughout the examples. * [Refactor] Fix data type conversion in multiple scripts - Corrected the usage of the data type conversion method from dtype..as_torch() to dtype.as_torch() across various benchmark and example scripts. - This change enhances consistency in data type handling and improves code readability, impacting modules related to attention mechanisms and tensor operations. * [Refactor] Update float8 data type usage across multiple scripts - Changed instances of T.float8_e4m3 to T.float8_e4m3fn in various benchmark, example, and test scripts to ensure consistency in data type handling. - This update enhances clarity and maintainability across the codebase, particularly in modules related to matrix multiplication and tensor operations. * [Refactor] Enhance float8 data type handling in CUDA code generation - Updated the handling of float8 data types in the CUDA code generation to include additional float8 variants, improving type conversion logic. - Adjusted conditions to ensure proper type checks for float8 conversions, enhancing clarity and maintainability in the codebase. - Modified layout inference to streamline float8 type checks, ensuring consistency across the implementation. - This change impacts modules related to matrix operations and CUDA code generation, improving overall type handling and conversion accuracy. * [Refactor] Streamline float8 data type handling in CUDA and related modules - Enhanced float8 data type handling in CUDA code generation by refining type conversion logic and ensuring consistent type checks. - Updated layout inference for float8 types to improve clarity and maintainability across the implementation. - This change impacts modules related to matrix operations and CUDA code generation, improving overall type handling and conversion accuracy. * [Refactor] Remove unnecessary cache disabling in float8 example script - Eliminated the call to tilelang.disable_cache() in example_group_per_split_token_cast_to_fp8.py to streamline the code. - This change enhances clarity and maintainability of the example script without affecting its functionality. * [Refactor] Update data type usage in debug print tests - Changed the argument for dtype in the test_debug_print_buffer function from a string representation to the corresponding T.bool type. - This update enhances consistency in data type handling within the test suite, improving clarity and maintainability. * lint fix * Update function parameter types from `str` to `T.dtype` for improved type safety in attention sink and related examples * Refactor `gemv_alloc_reducer` function signature for improved readability by formatting parameters across multiple lines.
-
- 12 Dec, 2025 1 commit
-
-
Lei Wang authored
-
- 12 Nov, 2025 1 commit
-
-
pengxin99 authored
* Fix division by zero in RMS normalization * Fix rsqrt calculation to avoid division by zero
-
- 25 Jun, 2025 1 commit
-
-
Cunxiao Ni authored
* [Example] Update kernel compilation in examples to use @tilelang.jit - Refactored multiple examples to eliminate the use of `tilelang.compile` for kernel creation, directly invoking the functions instead. - Added `@tilelang.jit` decorators with appropriate output indices to enhance performance and maintainability. - Improved code clarity by simplifying the kernel invocation process across various examples, ensuring consistency in how kernels are defined and executed. * format * Update example_tilelang_sparse_gqa_decode_varlen_indice.py * Update example_dequant_gemm_fine_grained.py * Update example_gemm_autotune.py --------- Co-authored-by:Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
-
- 04 Jun, 2025 1 commit
-
-
alex_xiao authored
* [CI]Add norm and layout_plot * fix lint * Remove obsolete test files for RMS normalization and plot layout, streamlining the testing suite. * Add make_mma_load_base_layout function to create MMA result layouts - Introduced a new function `make_mma_load_base_layout` for generating layout functions for storing MMA results in fragment buffers. - Added detailed docstring explaining parameters, return values, and potential exceptions. - Implemented logic for handling different data types and matrix configurations, including assertions for input validation. - Defined internal functions for mapping fragment indices to threads and local indices, enhancing the layout functionality. * Enhance MMA load test with additional imports and functionality - Added imports for `tilelang.language`, `Literal`, `Callable`, `DataType`, `IndexMap`, and `get_mma_micro_size` to support extended functionality. - Improved the `make_mma_load_base_layout` function by ensuring it can handle various data types and configurations. - Updated the test function `test_mma_load_base_layout` to validate the layout for float16 matrix A. * Fix formatting in test_fragment_mma_load_a.py by adding a blank line for improved readability. * Add RMS normalization functions to test_rms_norm.py - Introduced `rms_norm` and `rms_norm_splitk` functions for RMS normalization, enhancing the testing capabilities. - Implemented kernel functions with shared memory allocation and parallel processing for improved performance. - Updated the test function to validate the new RMS normalization implementations. * Add reference program for RMS normalization in test_rms_norm.py - Introduced `ref_program` function to provide a reference implementation for RMS normalization. - This addition enhances the testing framework by allowing comparisons against a known reference output. * Enhance RMS normalization tests with additional imports and formatting - Added import for `tilelang.language` to support extended functionality in `test_rms_norm.py`. - Improved code readability by adding blank lines for better separation of code sections. * Update RMS normalization test parameters and enhance layout plotting - Increased matrix dimensions in `test_rms_norm` to 8192 for improved performance testing. - Removed obsolete test functions in `test_fragment_mma_load_a.py` to streamline the test suite. - Enhanced layout plotting functionality by ensuring proper visualization of base, warp, and block layouts in `test_fragment_mma_load_a.py`. * Refactor RMS normalization test parameters and improve layout plotting readability - Simplified the parameters in `test_rms_norm` by removing `blk_k` for clarity. - Enhanced code readability in `test_fragment_mma_load_a.py` by adjusting the formatting of the `block_layout` definition and removing the unused `warp_cols` variable. * Enhance RMS normalization with split-k implementation and additional profiling - Added a new function `test_rms_norm_splitk` to test the split-k variant of RMS normalization. - Updated the main RMS normalization script to include profiling for the split-k implementation. - Ensured all checks pass with appropriate latency measurements for both reference and tile-lang implementations. * Remove obsolete test file `test_fragment_mma_load_a.py` to streamline the test suite. * Refactor `rms_norm.py` to streamline benchmarking output and remove redundant code. Comment out the `plot_layout` call in `fragment_mma_load_a.py` for clarity. * Refactor `test_rms_norm.py` by removing redundant test function `test_rms_norm_splitk` to streamline the test suite and improve clarity. --------- Co-authored-by:Your Name <you@example.com>
-
- 04 Apr, 2025 1 commit
-
-
Yu Cheng authored
- Introduced a new local fragment for squared values to improve performance. - Updated the computation of the RMS normalization to use the new fragment, enhancing memory efficiency. - Refactored the final multiplication step to operate on the local fragment instead of shared memory. - Added a configuration option to the kernel compilation for better control over TMA lowering. These changes enhance the efficiency and clarity of the RMS normalization implementation.
-
- 30 Mar, 2025 1 commit
-
-
Leslin authored
* Update elementwise_add.py [Bugfix] Replace profiler.mod with profiler.adapter to fix AttributeError * Update rms_norm.py [Bugfix] Replace profiler.mod with profiler.adapter to fix AttributeError * Remove adapter argument from do_bench call * Remove adapter argument from do_bench call --------- Co-authored-by:Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
-
- 26 Mar, 2025 1 commit
-
-
Lei Wang authored
* [Refactor] Improve flash attention example and layout comparison logic - Removed unnecessary annotation for `lse_local_split` in the flash attention example to streamline the code. - Updated the handling of `lse_local_split` to utilize parallel processing for better performance. - Refactored kernel compilation and profiling logic to enhance clarity and maintainability in the flash attention example. - Added a condition in `FragmentNode::IsEqual` to handle broadcast cases, improving the robustness of layout comparisons. * lint fix * [Enhancement] Add support for shared memory scope in Fill operation - Introduced handling for `shared.dyn` and `shared` memory scopes in the Fill operation. - Implemented parallel operation and layout inference for improved performance in shared memory scenarios. - Updated thread loop partitioning and vectorization logic to accommodate new memory scope handling. * [Refactor] Remove deprecated decorator and enhance Cython kernel handling - Removed the deprecated decorator from the main module and added a new implementation in the utils module for better organization. - Introduced a pointer map in the Cython kernel adapter to manage pointer arguments, improving runtime shape resolution. - Updated the Cython kernel wrapper to utilize the new pointer map for handling kernel arguments. - Enhanced error checking in the tensor utility functions to ensure static shapes are enforced. - Added a new proxy module for buffer and tensor handling, streamlining the interface for TIR programs. * [Feature] Add matrix multiplication test and kernel implementation - Introduced a new test file `test_tilelang_language_ptr.py` that implements a matrix multiplication function using TileLang's primitives. - The `matmul_test` function defines a kernel for performing tile-level GEMM operations with customizable block sizes and data types. - Added a `run_matmul` function to compile and execute the kernel, along with a test function to validate the implementation. - Updated the `proxy.py` file to enhance type handling for buffer and tensor proxies, ensuring compatibility with TIR programs. - Minor formatting improvements in `deprecated.py` for better readability. * lint fix * [Refactor] Update tensor creation in matrix multiplication test - Replaced `T.Tensor.from_ptr` with `T.make_tensor` in `matmul_test` for improved clarity and consistency. - Updated imports in `__init__.py` to include `make_tensor`. - Added `make_tensor` function in `proxy.py` to streamline tensor creation from pointers. * [Refactor] Update tensor definitions across multiple files - Replaced instances of `T.Tensor` with updated tensor definitions in various benchmark and example files to enhance consistency and clarity. - Adjusted tensor shapes and types in functions related to matrix multiplication, attention mechanisms, and other operations. - Improved documentation in README and example files to reflect changes in tensor usage. * lint fix * [Refactor] Update tensor types in attention and matrix multiplication examples - Replaced instances of `T.Tensor` with `T.SharedTensor` and `T.FragmentTensor` in various attention and matrix multiplication functions to improve consistency and clarity. - Adjusted tensor definitions in benchmark and example files to align with the new tensor types. - Enhanced the overall structure and readability of the code by standardizing tensor usage across multiple files. * lint fix * [Refactor] Update tensor types in GEMM example and test files - Replaced instances of `T.Tensor` with `T.LocalTensor` and `T.Buffer` in the GEMM example and related test functions to improve consistency and clarity. - Enhanced the overall structure of the code by standardizing tensor usage across multiple files, aligning with recent updates in tensor definitions. * [Refactor] Update tensor usage in customize.py - Replaced instances of `T.Tensor` with `T.Buffer` in the `reshape` and `view` functions to enhance consistency with recent tensor definitions. - Improved code clarity by standardizing buffer usage across the file. * [Refactor] Update tensor types in test_tilelang_transform_annotate_device_regions.py - Replaced instances of `T.Tensor` with `T.Buffer` in the `before` and `expected` methods of the `TestAnnotateThreadExtent` and `TestAnnotateDeviceScope` classes to enhance consistency with recent tensor definitions. - Improved code clarity by standardizing buffer usage across the test file. * [Refactor] Update tensor types to SharedBuffer and FragmentBuffer - Replaced instances of `T.SharedTensor` and `T.FragmentTensor` with `T.SharedBuffer` and `T.FragmentBuffer` across multiple benchmark, example, and test files to enhance consistency with recent tensor definitions. - Improved code clarity and structure by standardizing buffer usage in attention and matrix multiplication functions. * [Refactor] Introduce Tensor alias for Buffer in proxy.py - Added a new alias `Tensor` for `Buffer` in `proxy.py` to facilitate JIT compilation, ensuring that inputs and outputs are mapped with `torch.Tensor`. - This change enhances clarity and consistency in tensor usage across the codebase.
-
- 11 Mar, 2025 1 commit
-
-
Yu Cheng authored
* [Dev][Bugfix] Add RMS Normalization Kernels and Fix Reduce Bug - Implement two RMS normalization implementations in TileLang: * `rms_norm_splitk`: Split-K reduction approach for large matrices * `rms_norm`: Full reduction kernel with simplified implementation - Add reference implementation using PyTorch for validation - Include performance benchmarking for both kernel variants - Demonstrate flexible block size and matrix size configurations * [Examples] Simplify RMS Normalization Kernel Compilation - Remove commented-out code for split-K RMS normalization - Simplify kernel compilation by removing explicit TMA lowering configuration - Update copyright header to Tile-AI Corporation - Streamline main script for RMS normalization example
-