Unverified Commit a7c9a8b9 authored by Siyuan Feng's avatar Siyuan Feng Committed by GitHub
Browse files

Refactor to support upstream tvm (#595)



**Summarize part of the rebase pr:**

1. **Support T.thread_return() → CUDA return syntax**  
   Added support for translating `T.thread_return()` to CUDA's native `return` statement.

2. **Dynamic type support for function inputs**  
   Functions now accept dynamically typed parameters using `typing`:
   ```python
   dyn_type = T.int32 or T.float
   @T.prim_func
   def main(
       a: dyn_type,
   )
   ```

3. **Device Function Codegen**  
   Added support for generating `__device__` functions in CUDA:
   ```python
   @I.ir_module
   class Module:
       @T.prim_func(private=True)
       def add(a: T.int32, b: T.int32) -> T.int32:
           return a + b

       @T.prim_func
       def main(
           A: T.Buffer((128, 128), "int32"),
           B: T.Buffer((128, 128), "int32"),
           C: T.Buffer((128, 128), "int32"),
       ):
           T.func_attr({"global_symbol": "main"})
           length: T.int32 = Module.add(64, 64)  # Host call
           for bx in T.thread_binding(length, "blockIdx.x"):
               for tx in T.thread_binding(length, "threadIdx.x"):
                   C[bx, tx] = Module.add(A[bx, tx], B[bx, tx])  # Device call
   ```
   After compilation, `add` becomes a CUDA `__device__` function.

4. **Cython-based Python/C++ interop**  
   Replaced ctypes with Cython for all Python/C++ interactions:
   - Python → C++ calls
   - C++ → Cython calls  
   This improves performance by around 100x and reduces CPU overhead during compile/runtime.

5. **FP8 data type standardization**  
   Migrated `e5m2_float8` and similar types to Torch-standardized variants`float8_e5m2` and etc.



* Refactor CMakeLists.txt to set default build type and manage dependencies for tvm_cython modules

* Update default value of `check_well_formed` parameter in `prim_func` to False for improved flexibility in TIR function parsing.

* Add StorageRewrite function to transform module

Introduced the StorageRewrite function in the tilelang.transform module, which returns a TVM transform pass. This addition enhances the functionality of the module by providing a new transformation option for users.

* Refactor null option handling in IR and layout inference

- Updated instances of `NullOpt` to `std::nullopt` in `ir.cc` and `parallel.cc` for consistency with modern C++ practices.
- Enhanced layout inference logic in `layout_inference.cc` to improve type safety by replacing `as<Fragment>().get()` with `as<FragmentNode>()`.
- Adjusted error handling in `multi_version_buffer_rewriter.cc` and `persist_threadblock.cc` to use more concise null checks.
- Cleaned up test files by commenting out `tilelang.testing.main()` and replacing it with specific test function calls for better clarity.
- Removed unused test file `test_tilelang_kernel_deepseek_nsa.py` to streamline the testing suite.

* Update TVM subproject and refactor cluster planning and tile operation handling

- Updated the TVM subproject to a dirty commit state.
- Refactored copyright headers in `cluster_planning.cc` to reflect the new licensing.
- Enhanced error handling in `lower_tile_op.cc` to check for missing padding map annotations.
- Modified test files to improve clarity and functionality, including adjustments to kernel compilation and test assertions.
- Updated various test cases to ensure proper handling of annotations and configurations in the TileLang testing framework.

* Update annotation type in warp specialized test for consistency

- Changed the annotation type in the `test_warp_specialized` function from a literal integer to `T.int32(3)` for improved type safety and consistency with the TileLang framework.

* Refactor test execution in warp specialized test

- Replaced the direct call to `test_warp_specialized()` with `tilelang.testing.main()` in the test file to standardize test execution and improve integration with the TileLang testing framework.

* refactor

* [Enhancement] Add strict layout map for improved buffer layout inference (#594)

- Introduced a `strict_layout_map` to enhance layout inference by ensuring that buffers with strict layout requirements are properly accounted for during the inference process.
- Updated the inference logic to check for the presence of buffers in the `strict_layout_map` before applying layout changes, improving the accuracy of layout assignments.
- Refactored the layout inference steps to include the copying of layouts into the new strict map, ensuring a clear separation of layout handling based on inference levels.

* [Example] Update examples to use @tilelang.jit (#597)

* [Example] Update kernel compilation in examples to use @tilelang.jit

- Refactored multiple examples to eliminate the use of `tilelang.compile` for kernel creation, directly invoking the functions instead.
- Added `@tilelang.jit` decorators with appropriate output indices to enhance performance and maintainability.
- Improved code clarity by simplifying the kernel invocation process across various examples, ensuring consistency in how kernels are defined and executed.

* format

* Update example_tilelang_sparse_gqa_decode_varlen_indice.py

* Update example_dequant_gemm_fine_grained.py

* Update example_gemm_autotune.py

---------
Co-authored-by: default avatarLei Wang <34334180+LeiWang1999@users.noreply.github.com>

* [Enhancement] Refine error messaging in LowerBulkCopy for global and shared range checks (#599)

* [Enhancement] Improve error messaging for global and shared range legality checks in LowerBulkCopy

- Updated error messages in the LowerBulkCopy function to provide clearer context when global and shared ranges are illegal.
- Enhanced the readability of the error output by including tensor names, improving debugging and validation processes during bulk copy operations.

* [Enhancement] Refine error messaging in LowerBulkCopy for global and shared range checks

- Improved the clarity of error messages in the LowerBulkCopy function by enhancing the output format.
- Included additional context in error messages to aid debugging when global and shared ranges are found to be illegal, ensuring better traceability during bulk copy operations.

* [Enhancement] Introduce PassConfig `TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE` to enable aggressive shared memory reuse (#602)

* [Enhancement] Add aggressive shared memory merge option in memory allocation

- Introduced a new configuration option `tl.enable_aggressive_shared_memory_merge` to enable aggressive merging of shared memory allocations.
- Updated the `SharedMemLinearAccessPatternFinder` class to support an aggressive merge strategy, allowing for improved memory reuse.
- Modified the `MergeSharedMemoryAllocations` function to incorporate the new merging strategy based on the configuration.
- Enhanced the `PassConfigKey` enumeration to include the new aggressive merge option, ensuring it can be configured appropriately.

* lint fix

* [Enhancement] Add aggressive shared memory merge configuration option

- Introduced a new configuration option `kEnableAggressiveSharedMemoryMerge` to enable aggressive merging of shared memory allocations, enhancing memory management capabilities.

* [Enhancement] Update MergeSharedMemoryAllocations to support aggressive merge option

- Modified the `MergeSharedMemoryAllocations` function to accept an `enable_aggressive_merge` parameter, allowing for more flexible memory management.
- Introduced a new helper function `should_enable_aggressive_merge` to determine the aggressive merge configuration based on the pass context and target.
- Updated the relevant calls in the `phase.py` and `__init__.py` files to utilize the new aggressive merge functionality, enhancing the overall memory allocation strategy.

* [Refactor] Update accumulation handling in gemm_sm90.h (#603)

- Replaced the use of `tiled_mma.accumulate_ = GMMA::ScaleOut::Zero` with a call to `clear(acc)` for better clarity and maintainability in the accumulation logic.
- This change enhances the readability of the code by standardizing the approach to clearing accumulation values across multiple sections of the file.

* [Enhancement] Add tma bulk copy. (#600)

* [Bugfix] Fixed mha_bwd shape inconsistency error (#604)

* lint fix

* Update requirements-lint.txt to maintain clang-format version consistency

* [Bugfix] Avoid duplicate data access when cross thread buffer meet replicate register (#606)

* [Enhancement] Improve debug output formatting in layout and fragment nodes

- Updated the `DebugOutput` methods in `LayoutNode` and `FragmentNode` to provide more structured and informative output, including transformation details and thread range information.
- Enhanced layout inference logic in `ParallelOp` to add predicates for cross-thread shared memory access, improving layout handling in parallel operations.
- Minor adjustment in `layout_inference.cc` to ensure clarity in parallel loop handling.

* lint fix

* [Enhancement] Support tf32 gemm_rs (#607)

- Added a line break in `quickstart.py` for better readability.
- Simplified the JIT kernel compilation in `quickstart.py` by removing the unused execution backend option.
- Modified `example_elementwise_add.py` to disable cache for `tilelang` and optimized the element-wise addition kernel by utilizing shared memory for input tensors, improving performance.
- Updated default values for matrix dimensions and block sizes in the argument parser to enhance usability.

* [Enhancement] Introduce option `TL_DISABLE_FAST_MATH` and `TL_ENABLE_PTXAS_VERBOSE_OUTPUT` (#609)

* [Enhancement] Introduce new PassConfig options for fast math and PTXAS verbosity

- Added `kDisableFastMath` and `kEnablePTXASVerboseOutput` configuration options to enhance control over compilation settings.
- Updated `LibraryGenerator` to utilize these new pass configurations, allowing for more flexible compilation behavior based on user preferences.
- Enhanced `PassConfigKey` enumeration to include the new options, ensuring they can be configured appropriately in the pass context.

* [Refactor] Update PTXAS verbosity configuration key in LibraryGenerator

- Changed the configuration key for PTXAS verbosity from `TL_VERBOSE_PTXAS_OUTPUT` to `TL_ENABLE_PTXAS_VERBOSE_OUTPUT` to align with the new naming convention introduced in recent enhancements.
- This update ensures consistency in the configuration options used within the `LibraryGenerator` class, improving clarity and maintainability of the code.

* lint fix

* fix build

* [Experimental][Language] add `T.GEMM_SP` for sm90 sparse tensor core (#526)

* [experimental] add a draft gemm_sp

* [3rdparty] bump cutlass to v3.9.3

* [lint] run format.sh

* [chore] rebase

* [chore] use abs path

* [gemm_sp] add metadata layout

* [ci] add more example

* [lint] run format.sh

* [chore] polish

* [chore] move gemm_sp to experimental

* [chore] polish

* [lint] run format.sh

* [Enhancement] Improve bulk copy handling and update GEMM sparse tensor test

* Added a warning log for unsupported non-swizzled global layouts in the bulk copy operation, ensuring fallback to normal copy.
* Refactored the GEMM sparse tensor test by removing unnecessary imports and simplifying the kernel compilation process.
* Updated the test to directly call the `run_gemm_sp` function, enhancing clarity and functionality.

* Implement Test

* [Enhancement] Update GEMM SP and SM89 templates for improved functionality

* Refactored GEMM SP computation to enhance warp partitioning logic, ensuring compatibility with Hopper architecture.
* Updated layout inference to support new WGMMA conditions and improved error messaging for unsupported targets.
* Modified SM89 templates to utilize new MMA atom structures, enhancing performance and compatibility with fp8 types.
* Added conditional inclusion for GEMM SP header based on CUDA architecture version.

* lint fix

* [gemm_sp] support more layout and data types

* Enhancement: sync T.gemm_sp's layout inference with T.gemm

* Enhancement: support more block_k in compress util

* [Enhancement] enable block_k=64

* [Lint] run format.sh

* [Enhancement] compressor support more dtype

* Enhancement: enable block_K=32

* [Lint] format.sh

* [Fixbug] fix shape

* Refactor: sync gemm

* [Enhancement] enable transpose

* [Enhancement] enable fp8_e4m3

* [Enhancement] enable int8

* [Lint] run format.sh

* [Benchmark] add gemm_sp benchmark

* [Example] fix 256 threads hang

* [CI] fix ci

* [Chore] resolve gemini feedback

* [Benchmark] increase search space

* [Lint] format

* [CI] skip sparse tensor core related tests as only sm90 is supported

* [CI] pass local run

* Update gemm_sm89.h

* lint fix

* lint fix

* [Enhancement] Add support for sparse GEMM and initialize CUDA architecture flags

- Introduced a new boolean flag `enable_sparse_gemm_` to control the inclusion of sparse GEMM functionality in CUDA code generation.
- Updated the `Finish` method to conditionally include the sparse GEMM header based on the new flag.
- Implemented logic in `VisitStmt_` to enable sparse GEMM when the corresponding external call is detected.
- Added a function to initialize the `TORCH_CUDA_ARCH_LIST` environment variable based on the target compute version, enhancing compatibility with PyTorch.
- Refactored the initialization function into the appropriate module and ensured it is called in the sparse utilities module.

* Update test_compress_utils.py

---------
Co-authored-by: default avatarLeiWang1999 <leiwang1999@outlook.com>
Co-authored-by: default avatarLei Wang <34334180+LeiWang1999@users.noreply.github.com>

* [Doc] Phaseout Legacy documentations (#610)

- Added a new entry in the README for the introduction of `T.gemm_sp` supporting 2:4 sparse tensor core.
- Removed several outdated documentation files related to convolution, flash attention, and other tutorials to streamline the documentation structure.

* [Refactor] Phaseout Pass ParallelLoopTransformer (#611)

* Refactor layout inference by removing the ParallelLoopTransformer class. Updated layout inference logic to streamline buffer access collection and condition handling in parallel loops. This change simplifies the code structure and enhances maintainability.

* Update MHA backward test cases to use reduced dimensions for batch size and context length

* fix build

* [Enhancement] Update ReduceOp initialization values for integer types (#614)

* [Enhancement] Update ReduceOp initialization values for integer types

- Modified the `MakeInitValue` method in `ReduceOp` to handle integer data types correctly by returning appropriate minimum and maximum values based on the bit width.
- Added checks for integer types to ensure correct initialization for `kMax` and `kMin` reduction types, enhancing the robustness of the reduction operations.

* [Enhancement] Update ReduceOp to handle unsigned integer initialization values

- Enhanced the `MakeInitValue` method in `ReduceOp` to include support for unsigned integer data types.
- Added conditions to return appropriate initialization values for `kMax` and `kMin` reduction types based on the data type, improving the robustness of reduction operations.

* Bump transformers from 4.50.0 to 4.51.0 in /examples/bitnet-1.58b (#615)

Bumps [transformers](https://github.com/huggingface/transformers) from 4.50.0 to 4.51.0.
- [Release notes](https://github.com/huggingface/transformers/releases)
- [Commits](https://github.com/huggingface/transformers/compare/v4.50.0...v4.51.0

)

---
updated-dependencies:
- dependency-name: transformers
  dependency-version: 4.51.0
  dependency-type: direct:production
...
Signed-off-by: default avatardependabot[bot] <support@github.com>
Co-authored-by: default avatardependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* [Refactor] refactor autotune examples (#617)

* [Refactor] Update tilelang kernel functions and remove unused imports

- Refactored the `flashattn_fwd`, `flashattn_bwd_preprocess`, and `flashattn_bwd_postprocess` functions to utilize direct kernel calls instead of cached versions, improving clarity and performance.
- Added `@tilelang.jit` decorators with specified output indices to enhance kernel compilation.
- Removed unused import of `cached` from `tilelang`, streamlining the code.
- Commented out the main testing function call in `test_tilelang_kernel_mha_bwd.py` for potential future use.

* [Refactor] Simplify configuration generation in benchmark and example scripts

- Refactored the `get_configs` functions in multiple benchmark and example scripts to utilize a dictionary-based approach for parameter configuration, improving readability and maintainability.
- Updated the `flashattn` and `chunk_scan_fwd` functions to directly accept configuration parameters, enhancing flexibility in kernel tuning.
- Removed redundant code and streamlined the configuration generation process across various files, ensuring consistency in how configurations are defined and utilized.

* [Refactor] Update configuration handling in benchmark scripts

- Refactored the `get_configs` functions in benchmark scripts to accept a variable argument list, improving flexibility in configuration management.
- Enhanced the `matmul` and `flashattn` functions to utilize the updated configuration approach, streamlining parameter handling for kernel tuning.
- Added `@autotune` decorators to relevant functions, ensuring consistent autotuning behavior across benchmarks.
- Cleaned up redundant code and improved overall readability in the affected files.

* [Refactor] Clean up formatting and update subproject commit

- Updated the subproject commit reference in the TVM directory to indicate a dirty state.
- Removed unnecessary blank lines and improved formatting in the `benchmark_matmul` and `benchmark_matmul_fp8` scripts for better readability.
- Streamlined the function definitions in the `flashattn` example script to enhance clarity and maintainability.

* [Refactor] Update AutoTuner configuration handling

- Modified the AutoTuner class to check if kernel parameters are set before processing tunable arguments, improving robustness in configuration handling.
- Enhanced the logic for skipping compilation when tunable parameters are already provided, ensuring efficient use of resources.
- Updated comments for clarity and maintainability.

* lint fix

* Update TVM subproject commit to indicate dirty state and modify MHA backward test cases

- Updated the subproject commit reference in the TVM directory to reflect a dirty state.
- Adjusted the `test_mha_bwd` function to use a new configuration for the MHA backward tests, changing the context size from 128 to 256.
- Uncommented the main testing function call for potential execution.

* lint fix

* Bump transformers from 4.51.0 to 4.52.1 in /examples/bitnet-1.58b (#619)

Bumps [transformers](https://github.com/huggingface/transformers) from 4.51.0 to 4.52.1.
- [Release notes](https://github.com/huggingface/transformers/releases)
- [Commits](https://github.com/huggingface/transformers/compare/v4.51.0...v4.52.1

)

---
updated-dependencies:
- dependency-name: transformers
  dependency-version: 4.52.1
  dependency-type: direct:production
...
Signed-off-by: default avatardependabot[bot] <support@github.com>
Co-authored-by: default avatardependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* Fix PTXAS options flag in LibraryGenerator for consistency (#620)

* Refactor FP8 type handling across multiple files to standardize usage of "float8_e4m3" and "float8_e5m2" instead of "e4m3_float8" and "e5m2_float8". This includes updates in benchmarks, examples, tests, and internal utilities.

* [Refactor] Add parallel loop transform pass for condition extraction (#618)

* [Refactor] Add parallel loop transform

* done format check

* pull 3rdparty repo

* Refactor loop variable handling in transformation utilities

- Updated the logic in `loop_parallel_transform_utils.h` to simplify the handling of related loop variables.
- Removed the check that enforced a single related loop variable, replacing it with a return statement when multiple variables are detected, enhancing clarity and maintainability of the transformation process.

* Update loop_parallel_transform_utils.h

* Refactor loop variable handling in transformation utilities

- Enhanced the logic in `loop_parallel_transform_utils.h` to improve clarity and maintainability by simplifying the handling of related loop variables.
- Replaced the previous enforcement of a single related loop variable with a return statement for multiple variables detected.

* remove disable cache flag as commit id has been key component

* lint fix

---------
Co-authored-by: default avatarLeiWang1999 <leiwang1999@outlook.com>
Co-authored-by: default avatarLei Wang <34334180+LeiWang1999@users.noreply.github.com>

* [Dev] Update linear attention examples to enhance performance on Hopper GPUs (#621)

* Tune linear attention examples on H100

* Add retnet fwd kernel

* fix lint

* [Enhancement] Add ahead of time cython compilation in setup.py (#622)

* [Enhancement] Add Cython support and compiler detection in setup.py

- Introduced a new `CythonExtension` class for building Cython-based extensions, enhancing the build process for Cython projects.
- Implemented functions to detect the Cython compiler and C++ compiler, improving compatibility and user experience.
- Updated the build process to handle Cython extensions alongside CMake extensions, ensuring a seamless integration for users.
- Added caching mechanisms for Cython compilation to optimize build times and reduce unnecessary recompilation.

* [Enhancement] Add Cython dependency and enable CMake extension building

- Added Cython as a required dependency in `pyproject.toml` to support Cython-based extensions.
- Updated `setup.py` to enable building CMake extensions, improving the build process for projects utilizing both Cython and CMake.
- Modified the Cython compiler detection logic to streamline installation instructions for users.

* [Enhancement] Support more flexible layout host pythonic expr (#623)

* [Refactor] Enhance expression handling in utils.py and update wrapper to use pythonic_expr

- Added support for additional TIR expressions (FloorDiv, Min, Max, Add, Sub, FloorMod) in the pythonic_expr function to improve string representation.
- Replaced the deprecated legalize_c function calls in TLCUDASourceWrapper and TLCPUSourceWrapper with pythonic_expr for better expression handling in kernel launch code.

* [Refactor] Simplify expression handling in pythonic_expr function

- Consolidated binary and min/max operation handling in the pythonic_expr function to improve readability and maintainability.
- Replaced individual checks for binary operations with a mapping approach, streamlining the code and enhancing performance in expression representation.

* [Enhancement] Improve expression representation in pythonic_expr function

- Added operator precedence handling to the pythonic_expr function, enhancing the conversion of TVM PrimExpr to Python-style strings.
- Updated the visitor logic to intelligently add parentheses based on operator precedence, improving the accuracy of expression representation.
- Included a docstring for better clarity on the function's purpose and usage.

* test fix

* [Enhancement] support composable expression for shape with symbolic vars (#624)

* [Refactor] Enhance expression handling in utils.py and update wrapper to use pythonic_expr

- Added support for additional TIR expressions (FloorDiv, Min, Max, Add, Sub, FloorMod) in the pythonic_expr function to improve string representation.
- Replaced the deprecated legalize_c function calls in TLCUDASourceWrapper and TLCPUSourceWrapper with pythonic_expr for better expression handling in kernel launch code.

* [Refactor] Simplify expression handling in pythonic_expr function

- Consolidated binary and min/max operation handling in the pythonic_expr function to improve readability and maintainability.
- Replaced individual checks for binary operations with a mapping approach, streamlining the code and enhancing performance in expression representation.

* [Enhancement] Improve expression representation in pythonic_expr function

- Added operator precedence handling to the pythonic_expr function, enhancing the conversion of TVM PrimExpr to Python-style strings.
- Updated the visitor logic to intelligently add parentheses based on operator precedence, improving the accuracy of expression representation.
- Included a docstring for better clarity on the function's purpose and usage.

* test fix

* minor update

* 🐍

Fix the file name "test_exmaple_tilelang_nsa" (#629)

* [Enhancement] Add CPU utilization and count settings for Auto-Tuning (#630)

* [Enhancement] Add CPU utilization and count settings for Auto-Tuning

- Introduced environment variables for CPU utilization, counts, and maximum CPU count for auto-tuning.
- Updated the AutoTuner class to utilize these new settings, improving flexibility and performance in multi-threaded environments.
- Enhanced logging to provide better insights into the auto-tuning process based on the configured CPU settings.

* typo fix

* [AutoTune] Support `with set_autotune_inputs` to set auto tuning input tensors (#632)

* [Refactor] Simplify and modularize autotuner implementation

- Removed unused imports and extensive code sections from the autotuner module to enhance readability and maintainability.
- Modularized the code by introducing new imports for autotuning and capturing functionalities, streamlining the overall structure.
- Improved logging setup and removed redundant timeout handling functions, focusing on core autotuning logic.
- Updated the AutoTuner class to better utilize the new modular structure, ensuring efficient performance during auto-tuning processes.

* [Refactor] Clean up and enhance capture and tuner modules

- Improved code readability by removing unnecessary blank lines and organizing imports in `capture.py` and `tuner.py`.
- Enhanced logging in the `AutoTuner` class to provide clearer warnings regarding the usage of `supply_prog` in the context of auto-tuning.
- Streamlined the `CaptureStack` class for better thread-local context management.

* lint fix

* [Refactor] Simplify configuration and autotuning logic in blocksparse GEMM example

- Updated `get_configs` function to reduce the number of configurations, enhancing performance and clarity.
- Removed the `get_best_config` function, integrating its logic directly into the `blocksparse_matmul` function with the `@autotune` decorator for streamlined autotuning.
- Adjusted the main function to directly utilize the autotuned kernel, simplifying the overall structure and improving readability.
- Deleted obsolete test file for autotuning decorator, cleaning up the codebase.

* [Refactor] Improve code formatting and readability in autotune test file

- Reformatted the `matmul` function and `get_configs` function for better readability by adjusting line breaks and indentation.
- Fixed a typo in the `enable_rasteration` parameter name to ensure consistency.
- Cleaned up unnecessary blank lines to enhance overall code clarity.

* Update example_blocksparse_gemm.py

* Update capture.py

* [Pass] Introduce flag to diable cp async lowering (#633)

* [Enhancement] Update PipelinePlanner to support async copy configuration

- Modified the `Substitute` method in `PipelinePlanner` to accept a `use_async_copy` parameter, allowing for more flexible pipeline planning based on async copy requirements.
- Updated the constructor of `PipelinePlanner` to initialize the `use_async_copy_` member variable.
- Adjusted the logic in the pipeline planning process to conditionally apply async copy annotations based on the new parameter.
- Commented out the `LoopVectorizeDynamic` call in `LowerAndLegalize` to prevent unintended modifications during the legalizing phase.

* Refactor PipelinePlanning function for improved readability

- Adjusted the formatting of the `use_async_copy` variable assignment in the `PipelinePlanning` function to enhance code clarity and maintainability.

* fix typo (#635)

* [Pass][Simplify] Introduce symbolic level simplify for condition expression (#634)

* [Enhancement] Add argument simplification option to StmtSimplifier

- Introduced a new `simplify_arguments` flag in the `StmtSimplifier::Apply` method to control argument simplification behavior.
- Updated the `Simplify` function to accept the new flag, allowing for enhanced flexibility in the simplification process.
- Adjusted the `LowerAndLegalize` and `_Simplify` functions to utilize the new argument, ensuring consistent behavior across the codebase.
- Added comments to clarify the purpose of the new flag and its impact on simplification logic.

* lint fix

* [Enhancement] Improve layout inference and reduce operation handling

- Updated `ParallelOp::InferLayout` to check for pure buffer stores, enhancing layout inference logic.
- Modified `ReduceOp::Lower` to include all threads in the AllReduce operation, improving performance on specific architectures.
- Added a TODO comment in `AllReduce` to consider merging synchronization barriers for optimization.

* lint fix

* [Enhancement] Add input validation for GEMM parameters

- Introduced checks to ensure that the dimensions M and N are divisible by their respective warp sizes (kMPerWarp and kNPerWarp) in the Gemm::ComputeWarpPartition method.
- Added informative error messages to assist in debugging when the input parameters do not meet the required conditions.

* bug fix

* Enhance test coverage by adding LLVM requirement decorator to multiple function call tests. This ensures that tests for argument count, type code, null data pointer, and dimensionality checks are only executed when LLVM is available, improving test reliability and clarity.

* lint fix

* Fix software pipeline stage annotation and update optional config handling in StmtSimplifier

* Add Python executable detection in CMake configuration and update TVM submodule reference. Remove unused vectorization tests for improved clarity.

* Update TVM submodule reference and refactor FFI registration to use static initialization blocks for improved organization and clarity.

* Refactor attribute handling in layout and IR nodes to use reflection registration. This change replaces the VisitAttrs method with a RegisterReflection method for improved clarity and organization across multiple classes, including KernelLaunchFrameNode, WarpSpecializeFrameNode, LayoutNode, FragmentNode, and SwizzledLayoutNode.

* finish rebase

* tvm update

* Refactor FFI registration across tilelang modules to use the updated `tvm.ffi` namespace. This includes changes in various files to replace `tvm._ffi` with `tvm.ffi`, enhancing consistency and clarity in the codebase.

* lint fix

* Update TVM submodule reference and modify CUDA runtime argument handling to use the new runtime constants for improved clarity and consistency.

* lint fix

* Refactor tensor data type references from "e4m3_float8" and "e5m2_float8" to "float8_e4m3" and "float8_e5m2" across multiple files for consistency and clarity.

* lint fix

* Refactor forward_index initialization in Fragment class to default to an empty array instead of None, ensuring consistent handling of optional outputs.

* test fix

* lint fix

* bugfix

* lint fix

* reduce fix

* lint fix

* carver fix

* cast fix

* Update submodule and enhance kernel launch functionality with optional block size parameter; add device kernel launch transformation.

* lint fix

* bugfix

* Refactor test execution in test_tilelang_cpu_gemm.py and enhance device call checks in lower.py to exclude C packed functions from kernel launch conditions.

* lint fix

* Update runtime.cc

* phase out lisence

* Update subproject commit for TVM to 555cc71

* Update subproject commit for TVM to d39953fa

* Update subproject commit for TVM to 9574805f

* Update subproject commit for TVM to a08b7c3

* fix ci

* ci fix

---------
Signed-off-by: default avatardependabot[bot] <support@github.com>
Co-authored-by: default avatarLeiWang1999 <leiwang1999@outlook.com>
Co-authored-by: default avatarLei Wang <34334180+LeiWang1999@users.noreply.github.com>
Co-authored-by: default avatarCunxiao Ni <85601223+Cunxiao2002@users.noreply.github.com>
Co-authored-by: default avatarYuxi Chi <cherichy@outlook.com>
Co-authored-by: default avatarNathan Chen <120630832+Nathancgy@users.noreply.github.com>
Co-authored-by: default avatarbotbw <wang1570@e.ntu.edu.sg>
Co-authored-by: default avatardependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: default avatarxs-keju <93414213+xs-keju@users.noreply.github.com>
Co-authored-by: default avatarTong WU <109033598+Rachmanino@users.noreply.github.com>
Co-authored-by: default avatarKadir Nar <kadir.nar@hotmail.com>
Co-authored-by: default avatarYuqing Xia <35415939+xiayuqing0622@users.noreply.github.com>
Co-authored-by: default avatarxwhzz <wh.xie@outlook.com>

parent 8edd6941
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file clasuter_planning.cc
* \brief Plan the cluster for GPU(sm90+) blocks
*/
#include <tvm/arith/analyzer.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
......@@ -132,8 +115,10 @@ tvm::transform::Pass ClusterPlanning() {
return CreatePrimFuncPass(pass_func, 0, "tl.ClusterPlanning", {});
}
TVM_REGISTER_GLOBAL("tl.transform.ClusterPlanning")
.set_body_typed(ClusterPlanning);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.ClusterPlanning", ClusterPlanning);
});
} // namespace transform
} // namespace tir
......
......@@ -599,7 +599,7 @@ public:
return Scalarize(GetRef<Stmt>(op));
}
Stmt then_case = this->VisitStmt(op->then_case);
Optional<Stmt> else_case = NullOpt;
Optional<Stmt> else_case = std::nullopt;
if (op->else_case) {
else_case = this->VisitStmt(op->else_case.value());
}
......@@ -681,10 +681,6 @@ public:
stmt = Substitute(stmt, {{var_, idx}});
return For(idx, IntImm(var_->dtype, 0), var_lanes_, ForKind::kSerial, stmt);
}
// ProducerStore
Stmt VisitStmt_(const ProducerStoreNode *op) final {
LOG(FATAL) << "ProducerProvide cannot appear in a TIR PrimFunc";
}
private:
// analyzer
......
#include "../op/builtin.h"
#include <tvm/runtime/registry.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/data_type_rewriter.h>
#include <tvm/tir/op.h>
......@@ -85,8 +86,11 @@ tvm::transform::Pass ConfigIndexBitwidth() {
return CreatePrimFuncPass(pass_func, 0, "tl.ConfigIndexBitwidth", {});
}
TVM_REGISTER_GLOBAL("tl.transform.ConfigIndexBitwidth")
.set_body_typed(ConfigIndexBitwidth);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.ConfigIndexBitwidth",
ConfigIndexBitwidth);
});
} // namespace tl
} // namespace tvm
......@@ -5,7 +5,8 @@
#include "./storage_access.h"
#include "arith/ir_mutator_with_analyzer.h"
#include "arith/ir_visitor_with_analyzer.h"
#include <tvm/runtime/registry.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
......@@ -115,8 +116,11 @@ tvm::transform::Pass EliminateStorageSyncForMBarrier() {
{});
}
TVM_REGISTER_GLOBAL("tl.transform.EliminateStorageSyncForMBarrier")
.set_body_typed(EliminateStorageSyncForMBarrier);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.EliminateStorageSyncForMBarrier",
EliminateStorageSyncForMBarrier);
});
} // namespace transform
} // namespace tl
......
......@@ -24,6 +24,7 @@
#include "arith/ir_mutator_with_analyzer.h"
#include "tir/transforms/ir_utils.h"
#include <tvm/arith/iter_affine_map.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/data_type_rewriter.h>
#include <tvm/tir/stmt_functor.h>
......@@ -352,12 +353,7 @@ private:
};
PrimFunc FlattenBufferRewriter(PrimFunc f) {
// Only apply this pass to TIR that is not from TE schedules
if (!IsFromLegacyTESchedule(f)) {
return BufferFlattener::Flatten(f);
} else {
return f;
}
}
using namespace tir::transform;
......@@ -368,7 +364,10 @@ tvm::transform::Pass FlattenBuffer() {
return CreatePrimFuncPass(pass_func, 0, "tl.FlattenBuffer", {});
}
TVM_REGISTER_GLOBAL("tl.transform.FlattenBuffer").set_body_typed(FlattenBuffer);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.FlattenBuffer", FlattenBuffer);
});
} // namespace tl
} // namespace tvm
......@@ -22,6 +22,7 @@
* \brief Legalize the program from frontend
*/
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
......@@ -88,8 +89,10 @@ Pass FrontendLegalize() {
return CreatePrimFuncPass(pass_func, 0, "tl.FrontendLegalize", {});
}
TVM_REGISTER_GLOBAL("tl.transform.FrontendLegalize")
.set_body_typed(FrontendLegalize);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.FrontendLegalize", FrontendLegalize);
});
} // namespace tl
} // namespace tvm
......@@ -3,6 +3,7 @@
* \brief Bind the If Stmt to each Stmt in SeqStmt
*/
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
......@@ -80,7 +81,10 @@ tvm::transform::Pass IfStmtBinding() {
return CreatePrimFuncPass(pass_func, 0, "tl.IfStmtBinding", {});
}
TVM_REGISTER_GLOBAL("tl.transform.IfStmtBinding").set_body_typed(IfStmtBinding);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.IfStmtBinding", IfStmtBinding);
});
} // namespace tl
} // namespace tvm
......@@ -22,6 +22,7 @@
* \brief Inject fence between generic and async proxies (sm90+)
*/
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
......@@ -193,8 +194,10 @@ tvm::transform::Pass InjectFenceProxy() {
return CreatePrimFuncPass(pass_func, 0, "tl.InjectFenceProxy", {});
}
TVM_REGISTER_GLOBAL("tl.transform.InjectFenceProxy")
.set_body_typed(InjectFenceProxy);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.InjectFenceProxy", InjectFenceProxy);
});
} // namespace tl
} // namespace tvm
......@@ -22,6 +22,7 @@
* \brief Transform annotated loops into pipelined one that parallelize
* producers and consumers
*/
#include <tvm/ffi/reflection/registry.h>
#include <tvm/target/target.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/transform.h>
......@@ -737,7 +738,7 @@ private:
}
if (!is_unit_loop) {
Map<String, ObjectRef> preserved_annotations;
Map<String, Any> preserved_annotations;
for (const auto &kv : pipeline_loop_->annotations) {
const String &key = kv.first;
if (kv.first != tir::attr::software_pipeline_stage &&
......@@ -748,7 +749,7 @@ private:
}
new_loop = For(Downcast<Var>(new_loop_var), pipeline_loop_->min, extent,
unroll_loop ? ForKind::kUnrolled : pipeline_loop_->kind,
std::move(new_loop), NullOpt, preserved_annotations);
std::move(new_loop), std::nullopt, preserved_annotations);
}
// Update producer heads in the global async states.
for (const auto &[stage_id, state] : async_states_local) {
......@@ -955,7 +956,7 @@ private:
std::unordered_set<int> pipeline_async_stages;
if (auto annot =
op->annotations.Get(tir::attr::software_pipeline_async_stages)) {
for (auto s : Downcast<Array<Integer>>(annot)) {
for (auto s : Downcast<Array<Integer>>(annot.value())) {
pipeline_async_stages.insert(s->value);
}
}
......@@ -1038,8 +1039,11 @@ tir::transform::Pass InjectSoftwarePipeline() {
return CreatePrimFuncPass(pass_func, 0, "tl.InjectSoftwarePipeline", {});
}
TVM_REGISTER_GLOBAL("tl.transform.InjectSoftwarePipeline")
.set_body_typed(InjectSoftwarePipeline);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.InjectSoftwarePipeline",
InjectSoftwarePipeline);
});
} // namespace tl
} // namespace tvm
......@@ -21,6 +21,7 @@
* \brief Replace copy from global to shared with async copy
* \file inject_ptx_async_copy.cc
*/
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
......@@ -231,8 +232,10 @@ tvm::transform::Pass InjectPTXAsyncCopy() {
return CreatePrimFuncPass(pass_func, 0, "tl.InjectPTXAsyncCopy", {});
}
TVM_REGISTER_GLOBAL("tl.transform.InjectPTXAsyncCopy")
.set_body_typed(InjectPTXAsyncCopy);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.InjectPTXAsyncCopy", InjectPTXAsyncCopy);
});
} // namespace tl
} // namespace tvm
......@@ -23,6 +23,7 @@
*/
#include <tvm/arith/analyzer.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
......@@ -306,8 +307,10 @@ tvm::transform::Pass InjectTmaBarrier() {
return CreatePrimFuncPass(pass_func, 0, "tl.InjectTmaBarrier", {});
}
TVM_REGISTER_GLOBAL("tl.transform.InjectTmaBarrier")
.set_body_typed(InjectTmaBarrier);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.InjectTmaBarrier", InjectTmaBarrier);
});
} // namespace tl
} // namespace tvm
......@@ -3,6 +3,7 @@
* \brief infer the fragment/shared memory layout
*/
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/index_map.h>
#include <tvm/tir/op.h>
......@@ -138,11 +139,10 @@ public:
if (layout_map.count(buffer)) {
// If replicate size of this buffer is greater than the old one
if (buffer.scope() == "local.fragment" &&
level != InferLevel::kStrict &&
!strict_layout_map.count(buffer)) {
const FragmentNode *dst_layout = layout.as<Fragment>().get();
level != InferLevel::kStrict) {
const FragmentNode *dst_layout = layout.as<FragmentNode>();
const FragmentNode *src_layout =
layout_map[buffer].as<Fragment>().get();
layout_map[buffer].as<FragmentNode>();
if (as_const_int(dst_layout->ReplicateExtent()) &&
as_const_int(src_layout->ReplicateExtent()) &&
(*as_const_int(dst_layout->ReplicateExtent()) >
......@@ -313,7 +313,7 @@ private:
auto var = call->args[1].as<Var>().value();
return buffer_data_to_buffer_[var];
}
return NullOpt;
return std::nullopt;
}
void addToUseList(const Buffer &buffer) {
......@@ -354,11 +354,9 @@ private:
}
if (op->annotations.count(attr::kLayoutMap)) {
// Check if the layout map is Map<Var, Layout>
auto map = op->annotations.Get(attr::kLayoutMap).as<Map<Var, Layout>>();
ICHECK(map.defined()) << "layout map is not defined";
ICHECK(map.value().defined()) << "layout map is not defined";
for (const auto &[var, layout] : map.value()) {
auto map =
op->annotations.Get(attr::kLayoutMap)->as<Map<Var, Layout>>().value();
for (const auto &[var, layout] : map) {
ICHECK(buffer_data_to_buffer_.count(var))
<< "buffer " << var << " is not found in the block";
auto buffer = buffer_data_to_buffer_[var];
......@@ -519,8 +517,10 @@ tvm::transform::Pass LayoutInference() {
return CreatePrimFuncPass(pass_func, 0, "tl.LayoutInference", {});
}
TVM_REGISTER_GLOBAL("tl.transform.LayoutInference")
.set_body_typed(LayoutInference);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LayoutInference", LayoutInference);
});
} // namespace tl
} // namespace tvm
......@@ -3,6 +3,7 @@
* \brief legalize safe memory access
*/
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
......@@ -313,7 +314,7 @@ private:
}
if (op->annotations.count(attr::kPaddingMap)) {
auto map = op->annotations.Get(attr::kPaddingMap)
.as<Map<Var, PrimExpr>>()
->as<Map<Var, PrimExpr>>()
.value();
for (const auto &[var, padding] : map) {
ICHECK(buffer_data_to_buffer_.count(var))
......@@ -353,8 +354,11 @@ tvm::transform::Pass LegalizeSafeMemoryAccess() {
}
// Register the pass globally so it can be used in the compilation pipeline
TVM_REGISTER_GLOBAL("tl.transform.LegalizeSafeMemoryAccess")
.set_body_typed(LegalizeSafeMemoryAccess);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LegalizeSafeMemoryAccess",
LegalizeSafeMemoryAccess);
});
} // namespace tl
} // namespace tvm
......@@ -22,6 +22,7 @@
* \brief infer the fragment/shared memory layout
*/
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
......@@ -88,8 +89,11 @@ tvm::transform::Pass LegalizeVectorizedLoop() {
}
// Register the pass globally so it can be used in the compilation pipeline
TVM_REGISTER_GLOBAL("tl.transform.LegalizeVectorizedLoop")
.set_body_typed(LegalizeVectorizedLoop);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LegalizeVectorizedLoop",
LegalizeVectorizedLoop);
});
} // namespace tl
} // namespace tvm
......@@ -6,6 +6,7 @@
#include <cstdint>
#include <tvm/arith/iter_affine_map.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
......@@ -145,9 +146,7 @@ private:
const DataType &access_type = buffer->dtype;
// i // 2, i % 8 can also be vectorized as factor 16
int max_vector_size = vector_load_bits_max_ / access_type.bits();
if (access_type.is_e4m3_float8() or access_type.is_e5m2_float8()) {
max_vector_size = 1; // [temporarily] do not vectorize float8
}
// so we should disable this GCD optimization
max_vector_size = arith::ZeroAwareGCD(max_vector_size, extent_ptr->value);
......@@ -532,8 +531,11 @@ tvm::transform::Pass LoopVectorizeDynamic() {
}
// Register the pass globally so it can be used in the compilation pipeline
TVM_REGISTER_GLOBAL("tl.transform.LoopVectorizeDynamic")
.set_body_typed(LoopVectorizeDynamic);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LoopVectorizeDynamic",
LoopVectorizeDynamic);
});
} // namespace tl
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file lower_device_kernel_launch.cc
* \brief Split device function from host.
*/
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/transform.h>
#include <tvm/target/target.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include "runtime/thread_storage_scope.h"
#include "tir/transforms/ir_utils.h"
namespace tvm {
namespace tl {
using namespace tir;
namespace {
struct KernelInfo {
// The device on which the PrimFunc runs
Target target;
// The externally visible symbol which may refer to the PrimFunc
// when launching a device kernel.
String global_symbol;
// The parameters accepted by the PrimFunc. Used to rewrite
// `launch_args` to be in terms of the calling scope.
Array<Var> params;
// The launch parameters that should annotate the PrimFunc, if the
// kernel is ever called from the host.
Array<String> launch_params;
// Additional arguments which must be provided to the host-side
// PackedFunc. These may be in terms of the function's parameters
// (e.g. a function that computes the average of `N` elements, and
// which must be launched with `N` CUDA threads).
Array<PrimExpr> launch_args;
// The extent of each thread
Map<String, PrimExpr> thread_extent;
// The amount of dynamic shared memory used
Optional<PrimExpr> dyn_shmem_size{std::nullopt};
};
/*!
* \brief Visitor class to collect device-side program information.
*/
class DeviceInfoCollector : public StmtVisitor {
public:
static KernelInfo Collect(const GlobalVar &gvar, const PrimFunc &func) {
DeviceInfoCollector collector;
collector.info_.target =
func->GetAttr<Target>(tvm::attr::kTarget).value().WithoutHost();
collector.info_.params = func->params;
collector(func->body);
// The dynamic shared memory is required to be the last of the
// kernel launch parameters
if (collector.dyn_shmem_size) {
collector.info_.launch_params.push_back(
tvm::runtime::launch_param::kUseDynamicSharedMemoryTag);
}
collector.info_.global_symbol =
func->GetAttr<String>(tvm::attr::kGlobalSymbol)
.value_or(gvar->name_hint);
collector.info_.launch_args = collector.info_.launch_params.Map(
[&](const auto &param) { return collector.GetArgument(param); });
collector.info_.dyn_shmem_size = collector.dyn_shmem_size;
collector.info_.thread_extent = collector.thread_extent;
return collector.info_;
}
private:
PrimExpr GetArgument(const String &launch_param) const {
if (launch_param ==
tvm::runtime::launch_param::kUseDynamicSharedMemoryTag) {
CHECK(dyn_shmem_size.defined())
<< "Compute kernel requires launch parameter \"" << launch_param
<< "\", but PrimFunc did not contain Allocate node with shared "
"dynamic scope.";
return dyn_shmem_size.value();
}
auto extent = thread_extent.Get(launch_param);
CHECK(extent) << "Compute kernel requires launch parameter \""
<< launch_param
<< "\", but PrimFunc does not contain AttrStmt \""
<< tir::attr::thread_extent
<< "\" defining this thread extent";
return extent.value();
}
void VisitStmt_(const AttrStmtNode *op) final {
if (op->attr_key == tir::attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
ICHECK_NE(iv->thread_tag.length(), 0U);
// thread_extent can appear multiple times
// use the first appearance as def.
if (!defined_thread.count(iv.get())) {
defined_thread.insert(iv.get());
info_.launch_params.push_back(iv->thread_tag);
thread_extent.Set(iv->thread_tag, op->value);
}
}
StmtVisitor::VisitStmt_(op);
}
void VisitStmt_(const AllocateNode *op) final {
auto storage_scope =
runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var));
if (storage_scope.rank == runtime::StorageRank::kShared &&
storage_scope.tag == ".dyn") {
ICHECK(!dyn_shmem_size.defined())
<< "Only one dynamic shared memory allocation is allowed.";
ICHECK_GT(op->extents.size(), 0);
PrimExpr dyn_size = Integer(1);
for (const auto &extent : op->extents) {
dyn_size *= extent;
}
dyn_size *= op->dtype.bytes() * op->dtype.lanes();
dyn_shmem_size = dyn_size;
}
StmtVisitor::VisitStmt_(op);
}
// The collected results
KernelInfo info_;
// recording what thread axis have been visited.
std::unordered_set<const IterVarNode *> defined_thread;
// The extent of each thread
Map<String, PrimExpr> thread_extent;
// The amount of dynamic shared memory used
Optional<PrimExpr> dyn_shmem_size{std::nullopt};
};
class ReturnRemover : public StmtExprMutator {
public:
static Stmt Apply(const Stmt &stmt) {
ReturnRemover mutator;
return mutator(stmt);
}
private:
using Parent = StmtExprMutator;
Stmt VisitStmt_(const EvaluateNode *op) override {
if (auto *call = op->value.as<CallNode>()) {
if (call->op.same_as(builtin::ret())) {
ICHECK_EQ(call->args.size(), 1);
auto as_int = call->args[0].as<IntImmNode>();
ICHECK(as_int && as_int->value == 0)
<< "Device kernel may only contain successful return, T.ret(0)";
return Evaluate(0);
}
}
return Parent::VisitStmt_(op);
}
PrimExpr VisitExpr_(const CallNode *op) override {
if (op->op.same_as(builtin::ret())) {
LOG(FATAL) << "Call to builtin::ret() should only appear within an "
"Evaluate node";
}
return Parent::VisitExpr_(op);
}
};
} // namespace
class DeviceKernelMutator : public StmtExprMutator {
public:
using Parent = StmtExprMutator;
explicit DeviceKernelMutator(
std::unordered_map<const GlobalVarNode *, KernelInfo> device_info_map)
: device_info_map_(std::move(device_info_map)) {}
PrimFunc RewriteKernelLaunchSite(const GlobalVar &gvar, PrimFunc func) {
ICHECK(!current_target_.defined());
auto it = device_info_map_.find(gvar.get());
ICHECK(it != device_info_map_.end());
current_target_ = it->second.target;
auto body = VisitStmt(func->body);
if (!body.same_as(func->body)) {
func.CopyOnWrite()->body = body;
}
current_target_ = std::nullopt;
return func;
}
PrimFunc UpdateKernelAttributes(const GlobalVar &gvar, PrimFunc func) const {
bool is_kernel_launch = device_kernel_launch_.count(gvar.get());
bool is_call_extern = extern_function_call_.count(gvar.get());
CHECK(!is_kernel_launch || !is_call_extern)
<< "Function " << gvar << " has multiple callees, "
<< "and would need to be lowered into a call_extern at some call "
"sites, "
<< "and a device kernel launch at others. "
<< "This case is not yet supported.";
if (is_kernel_launch || is_call_extern) {
func =
WithAttr(std::move(func), tvm::tir::attr::kIsGlobalFunc, Bool(true));
}
if (is_kernel_launch) {
const auto &info = device_info_map_.at(gvar.get());
// Kernel launches provide an int32 error code to the caller,
// but do not accept any return type from the callee.
{
auto write_ptr = func.CopyOnWrite();
write_ptr->ret_type = VoidType();
write_ptr->body = ReturnRemover::Apply(write_ptr->body);
}
func =
WithAttrs(std::move(func),
{{tvm::attr::kCallingConv,
Integer(tvm::CallingConv::kDeviceKernelLaunch)},
{tvm::tir::attr::kKernelLaunchParams, info.launch_params},
{tvm::attr::kGlobalSymbol, info.global_symbol}});
}
// @lei: workaround as we may require c host codegen, so we need to set the
// global symbol for cpu backend.
func = WithAttr(func, tvm::attr::kGlobalSymbol, gvar->name_hint);
const auto &info = device_info_map_.at(gvar.get());
const auto &thread_extent = info.thread_extent;
func = WithAttr(std::move(func), "thread_extent", thread_extent);
if (info.dyn_shmem_size.defined()) {
func = WithAttr(std::move(func), "dyn_shared_memory_buf",
info.dyn_shmem_size.value());
}
return func;
}
private:
PrimExpr VisitExpr_(const CallNode *op) override {
auto node = Downcast<Call>(Parent::VisitExpr_(op));
auto *gvar = op->op.as<GlobalVarNode>();
if (!gvar)
return std::move(node);
auto it = device_info_map_.find(gvar);
ICHECK(it != device_info_map_.end())
<< "CallNode attempted subroutine call to " << gvar->name_hint
<< ", but " << gvar->name_hint << " did not appear within the IRModule";
const KernelInfo &dev_info = it->second;
auto caller_target = current_target_.value();
auto callee_target = dev_info.target;
bool same_target = caller_target->str() == callee_target->str();
if (same_target) {
// Calls within the same target may be handled at codegen time
// as internal subroutine calls.
return std::move(node);
}
bool same_device_type = caller_target->GetTargetDeviceType() ==
callee_target->GetTargetDeviceType();
if (same_device_type) {
// Calls to another target using the same device (e.g. LLVM
// calling a custom TIRToRuntime target) do not require a kernel
// launch, but need to be replaced with call_extern.
extern_function_call_.insert(gvar);
Array<PrimExpr> args;
args.push_back(StringImm(gvar->name_hint));
for (const auto &arg : node->args) {
args.push_back(arg);
}
return Call(node->dtype, builtin::call_extern(), args);
}
ICHECK(dev_info.launch_params.defined())
<< "CallNode attempted kernel launch to " << gvar->name_hint
<< " on target " << dev_info.target << ", but subroutine "
<< gvar->name_hint
<< " did not have the tir::attr::kKernelLaunchParams attribute "
<< "required for cross-target kernel launch";
// Collected kernel information may be in terms of the callee's
// arguments, but we need expressions for them in terms of the
// caller's parameters. The param_map allows substitution of
// parameter values into the thread extents, to generate
// expressions that are valid within the caller.
Map<Var, PrimExpr> param_map = [&]() {
Map<Var, PrimExpr> param_map;
CHECK_EQ(node->args.size(), dev_info.params.size())
<< "Function " << gvar->name_hint << " accepts "
<< dev_info.params.size()
<< " arguments as input, but is called using " << node->args.size()
<< " arguments";
for (size_t i = 0; i < node->args.size(); i++) {
param_map.Set(dev_info.params[i], node->args[i]);
}
return param_map;
}();
device_kernel_launch_.insert(gvar);
Array<PrimExpr> call_args;
call_args.push_back(StringImm(dev_info.global_symbol));
for (PrimExpr arg : node->args) {
call_args.push_back(arg);
}
for (const auto &launch_arg : dev_info.launch_args) {
call_args.push_back(Substitute(launch_arg, param_map));
}
auto dtype = node->dtype.is_void() ? DataType::Int(32) : node->dtype;
return Call(dtype, builtin::tvm_call_packed(), call_args);
}
Optional<Target> current_target_;
std::unordered_map<const GlobalVarNode *, KernelInfo> device_info_map_;
std::unordered_set<const GlobalVarNode *> device_kernel_launch_;
std::unordered_set<const GlobalVarNode *> extern_function_call_;
};
namespace transform {
tvm::transform::Pass LowerDeviceKernelLaunch() {
auto pass_func = [](IRModule mod,
tir::transform::PassContext ctx) -> IRModule {
auto mutator = [&mod]() {
std::unordered_map<const GlobalVarNode *, KernelInfo> device_info_map;
for (const auto &[gvar, base_func] : mod->functions) {
if (auto prim_func = base_func.as<PrimFunc>()) {
device_info_map[gvar.get()] =
DeviceInfoCollector::Collect(gvar, prim_func.value());
}
}
return DeviceKernelMutator(std::move(device_info_map));
}();
{
IRModule updates;
for (const auto &[gvar, base_func] : mod->functions) {
if (auto *ptr = base_func.as<PrimFuncNode>()) {
auto prim_func =
mutator.RewriteKernelLaunchSite(gvar, GetRef<PrimFunc>(ptr));
if (!prim_func.same_as(base_func)) {
updates->Add(gvar, prim_func);
}
}
}
if (updates->functions.size()) {
mod.CopyOnWrite()->Update(updates);
}
}
{
IRModule updates;
for (const auto &[gvar, base_func] : mod->functions) {
if (auto *ptr = base_func.as<PrimFuncNode>()) {
auto prim_func =
mutator.UpdateKernelAttributes(gvar, GetRef<PrimFunc>(ptr));
if (!prim_func.same_as(base_func)) {
updates->Add(gvar, prim_func);
}
}
}
if (updates->functions.size()) {
mod.CopyOnWrite()->Update(updates);
}
}
return mod;
};
return tvm::transform::CreateModulePass(pass_func, 0,
"tl.LowerDeviceKernelLaunch", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LowerDeviceKernelLaunch",
LowerDeviceKernelLaunch);
});
} // namespace transform
} // namespace tl
} // namespace tvm
......@@ -22,7 +22,8 @@
* \brief Lower the special device storage access.
*/
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/target/target_info.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/builtin.h>
......@@ -141,8 +142,11 @@ Pass LowerDeviceStorageAccessInfo() {
{});
}
TVM_REGISTER_GLOBAL("tl.transform.LowerDeviceStorageAccessInfo")
.set_body_typed(LowerDeviceStorageAccessInfo);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LowerDeviceStorageAccessInfo",
LowerDeviceStorageAccessInfo);
});
} // namespace transform
} // namespace tl
......
......@@ -3,6 +3,7 @@
* \brief Lower Hopper intrinsics cuda GPU(sm90+)
*/
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h>
......@@ -149,8 +150,10 @@ tvm::transform::Pass LowerHopperIntrin() {
return CreatePrimFuncPass(pass_func, 0, "tl.LowerHopperIntrin", {});
}
TVM_REGISTER_GLOBAL("tl.transform.LowerHopperIntrin")
.set_body_typed(LowerHopperIntrin);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LowerHopperIntrin", LowerHopperIntrin);
});
#endif // (CUDA_MAJOR_VERSION >= 12)
} // namespace tl
......
......@@ -3,6 +3,7 @@
* \brief Lower L2 persistent annotation
*/
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h>
......@@ -98,8 +99,10 @@ tvm::transform::Pass LowerL2Persistent() {
return CreatePrimFuncPass(pass_func, 0, "tl.LowerL2Persistent", {});
}
TVM_REGISTER_GLOBAL("tl.transform.LowerL2Persistent")
.set_body_typed(LowerL2Persistent);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LowerL2Persistent", LowerL2Persistent);
});
} // namespace tl
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file lower_opaque_block.cc
*/
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include "tir/transforms/ir_utils.h"
namespace tvm {
namespace tl {
using namespace tir;
using namespace tir::attr;
/*!
* \brief Remove Block to ensure that the TIR can not be scheduled again.
*/
class OpaqueBlockLower : public StmtExprMutator {
public:
static Stmt Rewrite(Stmt body) {
OpaqueBlockLower lower;
lower.storage_align_ = CollectStorageAlignAnnotation(body);
return lower(std::move(body));
}
private:
Stmt VisitStmt_(const BlockRealizeNode *op) final {
// We have convert blocks into opaque blocks in previous passes.
ICHECK(op->iter_values.empty())
<< "Non-opaque blocks are not allowed in FlattenBuffer. Please "
"call pass ConvertBlocksToOpaque before.";
// Step 1. Visit the body
Block new_block = Downcast<Block>(this->VisitStmt(op->block));
PrimExpr predicate = this->VisitExpr(op->predicate);
// Step 2. Transform the `predicate` to if-then-else
Stmt body = new_block->body;
if (!is_one(predicate)) {
body = IfThenElse(predicate, std::move(body));
}
// Step 3. Handle allocations in reverse order
for (size_t i = new_block->alloc_buffers.size(); i > 0; --i) {
const Buffer &buffer = new_block->alloc_buffers[i - 1];
Array<PrimExpr> allocation_shape = GetBufferAllocationShape(buffer);
body = DeclBuffer(buffer, std::move(body));
Map<String, ffi::Any> allocate_annotations;
auto it = storage_align_.find(buffer->data);
if (it != storage_align_.end()) {
StorageAlignAnnotation allocate_aligns;
for (auto tuple : it->second) {
tuple.Set<0>(-1);
allocate_aligns.push_back(tuple);
}
allocate_annotations.Set(tir::attr::buffer_dim_align, allocate_aligns);
}
body = Allocate(buffer->data, buffer->dtype, allocation_shape,
const_true(), std::move(body), allocate_annotations);
}
// Step 4. Handle annotations, block annotations are not preserved by
// default.
std::vector<std::pair<std::string, PrimExpr>> pragma_attrs;
HandleAnnotations(new_block->annotations, &pragma_attrs, /*is_block=*/true);
for (auto it = pragma_attrs.rbegin(); it != pragma_attrs.rend(); ++it) {
body = AttrStmt(Integer(0), it->first, it->second, std::move(body));
}
return body;
}
Stmt VisitStmt_(const BlockNode *op) final {
Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
if (block->annotations.count("stmt_group")) {
return block->body;
}
return block;
}
Stmt VisitStmt_(const ForNode *op) final {
// Step 1. Update unit loop info.
PrimExpr min = this->VisitExpr(op->min);
PrimExpr extent = this->VisitExpr(op->extent);
if (is_one(extent) && op->annotations.empty()) {
// handling unit loop
unit_loop_vars_[op->loop_var] = min;
}
// Step 2. Visit recursively
Stmt body = this->VisitStmt(op->body);
// Step 3. Handle annotations
std::vector<std::pair<std::string, PrimExpr>> pragma_attrs;
Map<String, ffi::Any> new_annotations =
HandleAnnotations(op->annotations, &pragma_attrs, /*is_block=*/false);
// Step 4. Create new For loop accordingly
if (op->kind == ForKind::kThreadBinding) {
// Case 1. Thread binding
ICHECK(op->thread_binding.defined());
String thread_tag = op->thread_binding.value()->thread_tag;
body = MakeLaunchThread(min, extent, op->loop_var, thread_tag, body);
} else if (is_one(extent) && op->annotations.empty()) {
// Case 2. Unit loop
return body;
} else {
// Case 3. An ordinary loop
body = For(op->loop_var, std::move(min), std::move(extent), op->kind,
std::move(body), std::nullopt, new_annotations);
}
// Step 5. Insert nested attrs
for (auto it = pragma_attrs.rbegin(); it != pragma_attrs.rend(); ++it) {
body = AttrStmt(op->loop_var, it->first, it->second, std::move(body));
}
return body;
}
PrimExpr VisitExpr_(const VarNode *op) final {
Var var = GetRef<Var>(op);
auto it = unit_loop_vars_.find(var);
if (it == unit_loop_vars_.end()) {
return var;
} else {
PrimExpr expr = it->second;
if (expr.dtype() != var.dtype()) {
expr = tvm::cast(var.dtype(), std::move(expr));
}
return expr;
}
}
static Stmt MakeLaunchThread(PrimExpr min, PrimExpr extent, Var var,
String thread_tag, Stmt body) {
IterVar iter_var(/*dom=*/Range::FromMinExtent(min, extent),
/*var=*/std::move(var),
/*iter_type=*/IterVarType::kThreadIndex,
/*thread_tag=*/thread_tag);
String attr_key = (thread_tag == "vthread" || thread_tag == "vthread.x" ||
thread_tag == "vthread.y" || thread_tag == "vthread.z")
? tir::attr::virtual_thread
: tir::attr::thread_extent;
return AttrStmt(/*node=*/std::move(iter_var),
/*attr_key=*/std::move(attr_key),
/*value=*/std::move(extent),
/*body=*/std::move(body));
}
/*! \brief Convert attr value from annotation map into PrimExpr. */
PrimExpr ConvertAttrValue(const String &key, const Any &obj) {
if (obj == nullptr) {
return PrimExpr();
} else if (auto expr = obj.try_cast<PrimExpr>()) {
return expr.value();
} else if (auto str = obj.try_cast<String>()) {
return std::move(StringImm(str.value()));
} else {
LOG(FATAL) << "Illegal attribute of key " << key << ", value type "
<< obj.GetTypeKey() << " not supported";
return PrimExpr();
}
}
/*!
* \brief Helper to handle annotation dict.
* (1) if the attr key is prefixed by `pragma_`, move to ordered kv list. They
* are lowered to `AttrStmt` by legacy TE schedule convention.
* (2) the non-pragma loop annotations are preserved
* (3) the non-pragma block annotations are dropped
* \return New annotation dict with preserved keys. Also update pragma attr
* pairs ordered by key.
*/
Map<String, ffi::Any>
HandleAnnotations(const Map<String, ffi::Any> &annotations,
std::vector<std::pair<std::string, PrimExpr>> *pragma_attrs,
bool is_block) {
Map<String, ffi::Any> preserved_annotations;
pragma_attrs->clear();
for (const auto &kv : annotations) {
const String &key = kv.first;
if (tir::attr::IsPragmaKey(key)) {
pragma_attrs->emplace_back(key, ConvertAttrValue(key, kv.second));
} else if (!is_block) {
// the loop annotation is preserved
preserved_annotations.Set(key, kv.second);
}
}
std::sort(
pragma_attrs->begin(), pragma_attrs->end(),
[](const auto &p1, const auto &p2) { return p1.first < p2.first; });
return preserved_annotations;
}
/*! \brief Record the loop_var and loop start value of unit loops, whose
* extent is one. */
std::unordered_map<Var, PrimExpr> unit_loop_vars_;
/*! \brief Attr keys to preserve into loop annotations. */
std::unordered_set<std::string> preserved_annotations_;
/*! \brief The map from buffer var to its storage alignment information. */
std::unordered_map<Var, StorageAlignAnnotation> storage_align_;
};
PrimFunc TLLowerOpaqueBlock(PrimFunc f) {
auto fptr = f.CopyOnWrite();
fptr->body = OpaqueBlockLower::Rewrite(std::move(fptr->body));
return f;
}
tir::transform::Pass LowerOpaqueBlock() {
using namespace tir::transform;
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
return TLLowerOpaqueBlock(std::move(f));
};
return CreatePrimFuncPass(pass_func, 0, "tl.LowerOpaqueBlock", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LowerOpaqueBlock", LowerOpaqueBlock);
});
} // namespace tl
} // namespace tvm
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment