- 04 Nov, 2025 1 commit
-
-
Lei Wang authored
* [Feature] Enhance fill operation to support various buffer types - Added support for `BufferLoad` in the `fill` function to handle different buffer types. - Updated `Fill` class to process region descriptors and buffer regions, improving flexibility in buffer handling. - Introduced checks for static bounds in region definitions to ensure safety during operations. - Refactored loop induction variable handling in `FillNode` to accommodate sliced regions. * lint fix * [Refactor] Improve Python compatibility for ParamSpec and Self - Added compatibility handling for ParamSpec and Self to support Python versions below 3.10 and 3.11 respectively. - Updated type annotations across multiple files to ensure consistent usage of typing features. * [Update] Require Python 3.9 and enhance type annotations - Updated the minimum required Python version from 3.8 to 3.9 in `pyproject.toml`. - Removed references to Python 3.8 in classifiers. - Changed type annotations from `int | None` to `Optional[int]` in multiple example files for better clarity and compatibility. - Improved import statements to use `collections.abc` for `Iterable` and `contextlib` for `AbstractContextManager` in relevant files. * [Refactor] Update import statements to enhance type annotations - Replaced imports from `typing` with `collections.abc` for `Iterable` and `Mapping` in relevant files to improve compatibility and clarity. - Updated the caching decorator from `functools.lru_cache` to `functools.cache` for better performance in the C++ compiler retrieval function. - Adjusted import statements in the language proxy file to maintain consistency in type annotations. * disable rocm rs nt test. * lint fix
-
- 23 Oct, 2025 1 commit
-
-
Tong WU authored
* [Feature] Add vectorized float16 and float32 conversion support in CUDA codegen * Implemented handling for conversions between float16 and float32 types, specifically for vectorized operations using __half22float2 and __float22half2_rn. * Enhanced the existing code to support both directions of conversion based on the lane count. * Improved overall type handling in the VisitExpr_ method for better compatibility with TileLang. * [Feature] Add float32 to float8 conversion support in CUDA codegen * Implemented handling for conversion from float32 to float8 (E4M3/E5M2) in the VisitExpr_ method. * Added vectorized conversion support using __nv_cvt_float2_to_fp8x2 for float2 to fp8x2 transformations. * Enhanced type handling for better compatibility with TileLang, particularly for float8 types. * lint * fix a bug * [Enhancement] Support lanes=4 cases and add unit test for vectorized cast * lint * [Feature] Refactor bf16 convertion operations and remove legacy compile flags * lint
-
- 10 Oct, 2025 1 commit
-
-
Tong WU authored
* revert split+sum template for MHA backward * lint * Update example_mha_bwd.py * Update example_mha_bwd_wgmma_pipelined.py * Refactor attention sink examples to support bf16 and user-defined softmax scale * fix typos * Adding compile flags for fast math optimizations and enabling BF16 support in both GQA and MHA backward implementations. * Update backward configuration for GQA and MHA examples to align with flash attention * Refactor GQA backward implementation to improve atomic add performance * Allow for slightly larger numerical error for bf16 * upd readme to show bf16 benchmark results * lint * fix ci and lint * fix comments and lint * refactor atomic add --------- Co-authored-by:Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
-
- 26 Sep, 2025 1 commit
-
-
Tong WU authored
* [Example] Add a new example to support attention sink for MHA - Introduced a new example script for multi-head attention (MHA) with sliding window attention and sink tokens. - Added a reference attention function to validate the implementation against PyTorch. - Included argument parsing for command-line execution of the example. * [Example] Replace MHA sink forward example with updated implementation - Removed the old example script for multi-head attention (MHA) with sliding window attention and sink tokens. - Introduced a new example script that modifies the attention mechanism to enhance performance and maintainability. - Updated argument parsing and reference functions to align with the new implementation. * Enhance MHA sink example with sliding window support - Added a `window_size` parameter to the `flashattn` function to enable sliding window attention. - Implemented assertions to ensure `window_size` is compatible with `block_N`. - Updated the main function to include a `tune` option for performance tuning. - Introduced a new test file to validate both full attention and sliding window scenarios. - Adjusted FLOPS calculation to account for the sliding window configuration. * lint * [Fix] Add checkinf process to fix the bug of swa * Migrate to BSHD layout to align with triton baselines * lint * fix typo * Refactor MHA sink example to use seq_q and seq_kv parameters to accommodate the new sequence length parameters. * Add GQA sink example for optimized attention mechanism & lint fix * fix several typos and bugs * lint * fix speed issues of swa * Add flash attention example with backward pass for BHSD layout and corresponding test cases * Add backward pass implementation for flash attention with sinks and corresponding test case * fix lint and typo * Optimze the calculation of `dsinks` * Add support for swa backward and update examples * fix previous typos * Add example for GQA sink backward pass and update tests for both MHA and GQA sinks * fix lint * fix previous typos * typo
-