Unverified Commit a7c9a8b9 authored by Siyuan Feng's avatar Siyuan Feng Committed by GitHub
Browse files

Refactor to support upstream tvm (#595)



**Summarize part of the rebase pr:**

1. **Support T.thread_return() → CUDA return syntax**  
   Added support for translating `T.thread_return()` to CUDA's native `return` statement.

2. **Dynamic type support for function inputs**  
   Functions now accept dynamically typed parameters using `typing`:
   ```python
   dyn_type = T.int32 or T.float
   @T.prim_func
   def main(
       a: dyn_type,
   )
   ```

3. **Device Function Codegen**  
   Added support for generating `__device__` functions in CUDA:
   ```python
   @I.ir_module
   class Module:
       @T.prim_func(private=True)
       def add(a: T.int32, b: T.int32) -> T.int32:
           return a + b

       @T.prim_func
       def main(
           A: T.Buffer((128, 128), "int32"),
           B: T.Buffer((128, 128), "int32"),
           C: T.Buffer((128, 128), "int32"),
       ):
           T.func_attr({"global_symbol": "main"})
           length: T.int32 = Module.add(64, 64)  # Host call
           for bx in T.thread_binding(length, "blockIdx.x"):
               for tx in T.thread_binding(length, "threadIdx.x"):
                   C[bx, tx] = Module.add(A[bx, tx], B[bx, tx])  # Device call
   ```
   After compilation, `add` becomes a CUDA `__device__` function.

4. **Cython-based Python/C++ interop**  
   Replaced ctypes with Cython for all Python/C++ interactions:
   - Python → C++ calls
   - C++ → Cython calls  
   This improves performance by around 100x and reduces CPU overhead during compile/runtime.

5. **FP8 data type standardization**  
   Migrated `e5m2_float8` and similar types to Torch-standardized variants`float8_e5m2` and etc.



* Refactor CMakeLists.txt to set default build type and manage dependencies for tvm_cython modules

* Update default value of `check_well_formed` parameter in `prim_func` to False for improved flexibility in TIR function parsing.

* Add StorageRewrite function to transform module

Introduced the StorageRewrite function in the tilelang.transform module, which returns a TVM transform pass. This addition enhances the functionality of the module by providing a new transformation option for users.

* Refactor null option handling in IR and layout inference

- Updated instances of `NullOpt` to `std::nullopt` in `ir.cc` and `parallel.cc` for consistency with modern C++ practices.
- Enhanced layout inference logic in `layout_inference.cc` to improve type safety by replacing `as<Fragment>().get()` with `as<FragmentNode>()`.
- Adjusted error handling in `multi_version_buffer_rewriter.cc` and `persist_threadblock.cc` to use more concise null checks.
- Cleaned up test files by commenting out `tilelang.testing.main()` and replacing it with specific test function calls for better clarity.
- Removed unused test file `test_tilelang_kernel_deepseek_nsa.py` to streamline the testing suite.

* Update TVM subproject and refactor cluster planning and tile operation handling

- Updated the TVM subproject to a dirty commit state.
- Refactored copyright headers in `cluster_planning.cc` to reflect the new licensing.
- Enhanced error handling in `lower_tile_op.cc` to check for missing padding map annotations.
- Modified test files to improve clarity and functionality, including adjustments to kernel compilation and test assertions.
- Updated various test cases to ensure proper handling of annotations and configurations in the TileLang testing framework.

* Update annotation type in warp specialized test for consistency

- Changed the annotation type in the `test_warp_specialized` function from a literal integer to `T.int32(3)` for improved type safety and consistency with the TileLang framework.

* Refactor test execution in warp specialized test

- Replaced the direct call to `test_warp_specialized()` with `tilelang.testing.main()` in the test file to standardize test execution and improve integration with the TileLang testing framework.

* refactor

* [Enhancement] Add strict layout map for improved buffer layout inference (#594)

- Introduced a `strict_layout_map` to enhance layout inference by ensuring that buffers with strict layout requirements are properly accounted for during the inference process.
- Updated the inference logic to check for the presence of buffers in the `strict_layout_map` before applying layout changes, improving the accuracy of layout assignments.
- Refactored the layout inference steps to include the copying of layouts into the new strict map, ensuring a clear separation of layout handling based on inference levels.

* [Example] Update examples to use @tilelang.jit (#597)

* [Example] Update kernel compilation in examples to use @tilelang.jit

- Refactored multiple examples to eliminate the use of `tilelang.compile` for kernel creation, directly invoking the functions instead.
- Added `@tilelang.jit` decorators with appropriate output indices to enhance performance and maintainability.
- Improved code clarity by simplifying the kernel invocation process across various examples, ensuring consistency in how kernels are defined and executed.

* format

* Update example_tilelang_sparse_gqa_decode_varlen_indice.py

* Update example_dequant_gemm_fine_grained.py

* Update example_gemm_autotune.py

---------
Co-authored-by: default avatarLei Wang <34334180+LeiWang1999@users.noreply.github.com>

* [Enhancement] Refine error messaging in LowerBulkCopy for global and shared range checks (#599)

* [Enhancement] Improve error messaging for global and shared range legality checks in LowerBulkCopy

- Updated error messages in the LowerBulkCopy function to provide clearer context when global and shared ranges are illegal.
- Enhanced the readability of the error output by including tensor names, improving debugging and validation processes during bulk copy operations.

* [Enhancement] Refine error messaging in LowerBulkCopy for global and shared range checks

- Improved the clarity of error messages in the LowerBulkCopy function by enhancing the output format.
- Included additional context in error messages to aid debugging when global and shared ranges are found to be illegal, ensuring better traceability during bulk copy operations.

* [Enhancement] Introduce PassConfig `TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE` to enable aggressive shared memory reuse (#602)

* [Enhancement] Add aggressive shared memory merge option in memory allocation

- Introduced a new configuration option `tl.enable_aggressive_shared_memory_merge` to enable aggressive merging of shared memory allocations.
- Updated the `SharedMemLinearAccessPatternFinder` class to support an aggressive merge strategy, allowing for improved memory reuse.
- Modified the `MergeSharedMemoryAllocations` function to incorporate the new merging strategy based on the configuration.
- Enhanced the `PassConfigKey` enumeration to include the new aggressive merge option, ensuring it can be configured appropriately.

* lint fix

* [Enhancement] Add aggressive shared memory merge configuration option

- Introduced a new configuration option `kEnableAggressiveSharedMemoryMerge` to enable aggressive merging of shared memory allocations, enhancing memory management capabilities.

* [Enhancement] Update MergeSharedMemoryAllocations to support aggressive merge option

- Modified the `MergeSharedMemoryAllocations` function to accept an `enable_aggressive_merge` parameter, allowing for more flexible memory management.
- Introduced a new helper function `should_enable_aggressive_merge` to determine the aggressive merge configuration based on the pass context and target.
- Updated the relevant calls in the `phase.py` and `__init__.py` files to utilize the new aggressive merge functionality, enhancing the overall memory allocation strategy.

* [Refactor] Update accumulation handling in gemm_sm90.h (#603)

- Replaced the use of `tiled_mma.accumulate_ = GMMA::ScaleOut::Zero` with a call to `clear(acc)` for better clarity and maintainability in the accumulation logic.
- This change enhances the readability of the code by standardizing the approach to clearing accumulation values across multiple sections of the file.

* [Enhancement] Add tma bulk copy. (#600)

* [Bugfix] Fixed mha_bwd shape inconsistency error (#604)

* lint fix

* Update requirements-lint.txt to maintain clang-format version consistency

* [Bugfix] Avoid duplicate data access when cross thread buffer meet replicate register (#606)

* [Enhancement] Improve debug output formatting in layout and fragment nodes

- Updated the `DebugOutput` methods in `LayoutNode` and `FragmentNode` to provide more structured and informative output, including transformation details and thread range information.
- Enhanced layout inference logic in `ParallelOp` to add predicates for cross-thread shared memory access, improving layout handling in parallel operations.
- Minor adjustment in `layout_inference.cc` to ensure clarity in parallel loop handling.

* lint fix

* [Enhancement] Support tf32 gemm_rs (#607)

- Added a line break in `quickstart.py` for better readability.
- Simplified the JIT kernel compilation in `quickstart.py` by removing the unused execution backend option.
- Modified `example_elementwise_add.py` to disable cache for `tilelang` and optimized the element-wise addition kernel by utilizing shared memory for input tensors, improving performance.
- Updated default values for matrix dimensions and block sizes in the argument parser to enhance usability.

* [Enhancement] Introduce option `TL_DISABLE_FAST_MATH` and `TL_ENABLE_PTXAS_VERBOSE_OUTPUT` (#609)

* [Enhancement] Introduce new PassConfig options for fast math and PTXAS verbosity

- Added `kDisableFastMath` and `kEnablePTXASVerboseOutput` configuration options to enhance control over compilation settings.
- Updated `LibraryGenerator` to utilize these new pass configurations, allowing for more flexible compilation behavior based on user preferences.
- Enhanced `PassConfigKey` enumeration to include the new options, ensuring they can be configured appropriately in the pass context.

* [Refactor] Update PTXAS verbosity configuration key in LibraryGenerator

- Changed the configuration key for PTXAS verbosity from `TL_VERBOSE_PTXAS_OUTPUT` to `TL_ENABLE_PTXAS_VERBOSE_OUTPUT` to align with the new naming convention introduced in recent enhancements.
- This update ensures consistency in the configuration options used within the `LibraryGenerator` class, improving clarity and maintainability of the code.

* lint fix

* fix build

* [Experimental][Language] add `T.GEMM_SP` for sm90 sparse tensor core (#526)

* [experimental] add a draft gemm_sp

* [3rdparty] bump cutlass to v3.9.3

* [lint] run format.sh

* [chore] rebase

* [chore] use abs path

* [gemm_sp] add metadata layout

* [ci] add more example

* [lint] run format.sh

* [chore] polish

* [chore] move gemm_sp to experimental

* [chore] polish

* [lint] run format.sh

* [Enhancement] Improve bulk copy handling and update GEMM sparse tensor test

* Added a warning log for unsupported non-swizzled global layouts in the bulk copy operation, ensuring fallback to normal copy.
* Refactored the GEMM sparse tensor test by removing unnecessary imports and simplifying the kernel compilation process.
* Updated the test to directly call the `run_gemm_sp` function, enhancing clarity and functionality.

* Implement Test

* [Enhancement] Update GEMM SP and SM89 templates for improved functionality

* Refactored GEMM SP computation to enhance warp partitioning logic, ensuring compatibility with Hopper architecture.
* Updated layout inference to support new WGMMA conditions and improved error messaging for unsupported targets.
* Modified SM89 templates to utilize new MMA atom structures, enhancing performance and compatibility with fp8 types.
* Added conditional inclusion for GEMM SP header based on CUDA architecture version.

* lint fix

* [gemm_sp] support more layout and data types

* Enhancement: sync T.gemm_sp's layout inference with T.gemm

* Enhancement: support more block_k in compress util

* [Enhancement] enable block_k=64

* [Lint] run format.sh

* [Enhancement] compressor support more dtype

* Enhancement: enable block_K=32

* [Lint] format.sh

* [Fixbug] fix shape

* Refactor: sync gemm

* [Enhancement] enable transpose

* [Enhancement] enable fp8_e4m3

* [Enhancement] enable int8

* [Lint] run format.sh

* [Benchmark] add gemm_sp benchmark

* [Example] fix 256 threads hang

* [CI] fix ci

* [Chore] resolve gemini feedback

* [Benchmark] increase search space

* [Lint] format

* [CI] skip sparse tensor core related tests as only sm90 is supported

* [CI] pass local run

* Update gemm_sm89.h

* lint fix

* lint fix

* [Enhancement] Add support for sparse GEMM and initialize CUDA architecture flags

- Introduced a new boolean flag `enable_sparse_gemm_` to control the inclusion of sparse GEMM functionality in CUDA code generation.
- Updated the `Finish` method to conditionally include the sparse GEMM header based on the new flag.
- Implemented logic in `VisitStmt_` to enable sparse GEMM when the corresponding external call is detected.
- Added a function to initialize the `TORCH_CUDA_ARCH_LIST` environment variable based on the target compute version, enhancing compatibility with PyTorch.
- Refactored the initialization function into the appropriate module and ensured it is called in the sparse utilities module.

* Update test_compress_utils.py

---------
Co-authored-by: default avatarLeiWang1999 <leiwang1999@outlook.com>
Co-authored-by: default avatarLei Wang <34334180+LeiWang1999@users.noreply.github.com>

* [Doc] Phaseout Legacy documentations (#610)

- Added a new entry in the README for the introduction of `T.gemm_sp` supporting 2:4 sparse tensor core.
- Removed several outdated documentation files related to convolution, flash attention, and other tutorials to streamline the documentation structure.

* [Refactor] Phaseout Pass ParallelLoopTransformer (#611)

* Refactor layout inference by removing the ParallelLoopTransformer class. Updated layout inference logic to streamline buffer access collection and condition handling in parallel loops. This change simplifies the code structure and enhances maintainability.

* Update MHA backward test cases to use reduced dimensions for batch size and context length

* fix build

* [Enhancement] Update ReduceOp initialization values for integer types (#614)

* [Enhancement] Update ReduceOp initialization values for integer types

- Modified the `MakeInitValue` method in `ReduceOp` to handle integer data types correctly by returning appropriate minimum and maximum values based on the bit width.
- Added checks for integer types to ensure correct initialization for `kMax` and `kMin` reduction types, enhancing the robustness of the reduction operations.

* [Enhancement] Update ReduceOp to handle unsigned integer initialization values

- Enhanced the `MakeInitValue` method in `ReduceOp` to include support for unsigned integer data types.
- Added conditions to return appropriate initialization values for `kMax` and `kMin` reduction types based on the data type, improving the robustness of reduction operations.

* Bump transformers from 4.50.0 to 4.51.0 in /examples/bitnet-1.58b (#615)

Bumps [transformers](https://github.com/huggingface/transformers) from 4.50.0 to 4.51.0.
- [Release notes](https://github.com/huggingface/transformers/releases)
- [Commits](https://github.com/huggingface/transformers/compare/v4.50.0...v4.51.0

)

---
updated-dependencies:
- dependency-name: transformers
  dependency-version: 4.51.0
  dependency-type: direct:production
...
Signed-off-by: default avatardependabot[bot] <support@github.com>
Co-authored-by: default avatardependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* [Refactor] refactor autotune examples (#617)

* [Refactor] Update tilelang kernel functions and remove unused imports

- Refactored the `flashattn_fwd`, `flashattn_bwd_preprocess`, and `flashattn_bwd_postprocess` functions to utilize direct kernel calls instead of cached versions, improving clarity and performance.
- Added `@tilelang.jit` decorators with specified output indices to enhance kernel compilation.
- Removed unused import of `cached` from `tilelang`, streamlining the code.
- Commented out the main testing function call in `test_tilelang_kernel_mha_bwd.py` for potential future use.

* [Refactor] Simplify configuration generation in benchmark and example scripts

- Refactored the `get_configs` functions in multiple benchmark and example scripts to utilize a dictionary-based approach for parameter configuration, improving readability and maintainability.
- Updated the `flashattn` and `chunk_scan_fwd` functions to directly accept configuration parameters, enhancing flexibility in kernel tuning.
- Removed redundant code and streamlined the configuration generation process across various files, ensuring consistency in how configurations are defined and utilized.

* [Refactor] Update configuration handling in benchmark scripts

- Refactored the `get_configs` functions in benchmark scripts to accept a variable argument list, improving flexibility in configuration management.
- Enhanced the `matmul` and `flashattn` functions to utilize the updated configuration approach, streamlining parameter handling for kernel tuning.
- Added `@autotune` decorators to relevant functions, ensuring consistent autotuning behavior across benchmarks.
- Cleaned up redundant code and improved overall readability in the affected files.

* [Refactor] Clean up formatting and update subproject commit

- Updated the subproject commit reference in the TVM directory to indicate a dirty state.
- Removed unnecessary blank lines and improved formatting in the `benchmark_matmul` and `benchmark_matmul_fp8` scripts for better readability.
- Streamlined the function definitions in the `flashattn` example script to enhance clarity and maintainability.

* [Refactor] Update AutoTuner configuration handling

- Modified the AutoTuner class to check if kernel parameters are set before processing tunable arguments, improving robustness in configuration handling.
- Enhanced the logic for skipping compilation when tunable parameters are already provided, ensuring efficient use of resources.
- Updated comments for clarity and maintainability.

* lint fix

* Update TVM subproject commit to indicate dirty state and modify MHA backward test cases

- Updated the subproject commit reference in the TVM directory to reflect a dirty state.
- Adjusted the `test_mha_bwd` function to use a new configuration for the MHA backward tests, changing the context size from 128 to 256.
- Uncommented the main testing function call for potential execution.

* lint fix

* Bump transformers from 4.51.0 to 4.52.1 in /examples/bitnet-1.58b (#619)

Bumps [transformers](https://github.com/huggingface/transformers) from 4.51.0 to 4.52.1.
- [Release notes](https://github.com/huggingface/transformers/releases)
- [Commits](https://github.com/huggingface/transformers/compare/v4.51.0...v4.52.1

)

---
updated-dependencies:
- dependency-name: transformers
  dependency-version: 4.52.1
  dependency-type: direct:production
...
Signed-off-by: default avatardependabot[bot] <support@github.com>
Co-authored-by: default avatardependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* Fix PTXAS options flag in LibraryGenerator for consistency (#620)

* Refactor FP8 type handling across multiple files to standardize usage of "float8_e4m3" and "float8_e5m2" instead of "e4m3_float8" and "e5m2_float8". This includes updates in benchmarks, examples, tests, and internal utilities.

* [Refactor] Add parallel loop transform pass for condition extraction (#618)

* [Refactor] Add parallel loop transform

* done format check

* pull 3rdparty repo

* Refactor loop variable handling in transformation utilities

- Updated the logic in `loop_parallel_transform_utils.h` to simplify the handling of related loop variables.
- Removed the check that enforced a single related loop variable, replacing it with a return statement when multiple variables are detected, enhancing clarity and maintainability of the transformation process.

* Update loop_parallel_transform_utils.h

* Refactor loop variable handling in transformation utilities

- Enhanced the logic in `loop_parallel_transform_utils.h` to improve clarity and maintainability by simplifying the handling of related loop variables.
- Replaced the previous enforcement of a single related loop variable with a return statement for multiple variables detected.

* remove disable cache flag as commit id has been key component

* lint fix

---------
Co-authored-by: default avatarLeiWang1999 <leiwang1999@outlook.com>
Co-authored-by: default avatarLei Wang <34334180+LeiWang1999@users.noreply.github.com>

* [Dev] Update linear attention examples to enhance performance on Hopper GPUs (#621)

* Tune linear attention examples on H100

* Add retnet fwd kernel

* fix lint

* [Enhancement] Add ahead of time cython compilation in setup.py (#622)

* [Enhancement] Add Cython support and compiler detection in setup.py

- Introduced a new `CythonExtension` class for building Cython-based extensions, enhancing the build process for Cython projects.
- Implemented functions to detect the Cython compiler and C++ compiler, improving compatibility and user experience.
- Updated the build process to handle Cython extensions alongside CMake extensions, ensuring a seamless integration for users.
- Added caching mechanisms for Cython compilation to optimize build times and reduce unnecessary recompilation.

* [Enhancement] Add Cython dependency and enable CMake extension building

- Added Cython as a required dependency in `pyproject.toml` to support Cython-based extensions.
- Updated `setup.py` to enable building CMake extensions, improving the build process for projects utilizing both Cython and CMake.
- Modified the Cython compiler detection logic to streamline installation instructions for users.

* [Enhancement] Support more flexible layout host pythonic expr (#623)

* [Refactor] Enhance expression handling in utils.py and update wrapper to use pythonic_expr

- Added support for additional TIR expressions (FloorDiv, Min, Max, Add, Sub, FloorMod) in the pythonic_expr function to improve string representation.
- Replaced the deprecated legalize_c function calls in TLCUDASourceWrapper and TLCPUSourceWrapper with pythonic_expr for better expression handling in kernel launch code.

* [Refactor] Simplify expression handling in pythonic_expr function

- Consolidated binary and min/max operation handling in the pythonic_expr function to improve readability and maintainability.
- Replaced individual checks for binary operations with a mapping approach, streamlining the code and enhancing performance in expression representation.

* [Enhancement] Improve expression representation in pythonic_expr function

- Added operator precedence handling to the pythonic_expr function, enhancing the conversion of TVM PrimExpr to Python-style strings.
- Updated the visitor logic to intelligently add parentheses based on operator precedence, improving the accuracy of expression representation.
- Included a docstring for better clarity on the function's purpose and usage.

* test fix

* [Enhancement] support composable expression for shape with symbolic vars (#624)

* [Refactor] Enhance expression handling in utils.py and update wrapper to use pythonic_expr

- Added support for additional TIR expressions (FloorDiv, Min, Max, Add, Sub, FloorMod) in the pythonic_expr function to improve string representation.
- Replaced the deprecated legalize_c function calls in TLCUDASourceWrapper and TLCPUSourceWrapper with pythonic_expr for better expression handling in kernel launch code.

* [Refactor] Simplify expression handling in pythonic_expr function

- Consolidated binary and min/max operation handling in the pythonic_expr function to improve readability and maintainability.
- Replaced individual checks for binary operations with a mapping approach, streamlining the code and enhancing performance in expression representation.

* [Enhancement] Improve expression representation in pythonic_expr function

- Added operator precedence handling to the pythonic_expr function, enhancing the conversion of TVM PrimExpr to Python-style strings.
- Updated the visitor logic to intelligently add parentheses based on operator precedence, improving the accuracy of expression representation.
- Included a docstring for better clarity on the function's purpose and usage.

* test fix

* minor update

* 🐍

Fix the file name "test_exmaple_tilelang_nsa" (#629)

* [Enhancement] Add CPU utilization and count settings for Auto-Tuning (#630)

* [Enhancement] Add CPU utilization and count settings for Auto-Tuning

- Introduced environment variables for CPU utilization, counts, and maximum CPU count for auto-tuning.
- Updated the AutoTuner class to utilize these new settings, improving flexibility and performance in multi-threaded environments.
- Enhanced logging to provide better insights into the auto-tuning process based on the configured CPU settings.

* typo fix

* [AutoTune] Support `with set_autotune_inputs` to set auto tuning input tensors (#632)

* [Refactor] Simplify and modularize autotuner implementation

- Removed unused imports and extensive code sections from the autotuner module to enhance readability and maintainability.
- Modularized the code by introducing new imports for autotuning and capturing functionalities, streamlining the overall structure.
- Improved logging setup and removed redundant timeout handling functions, focusing on core autotuning logic.
- Updated the AutoTuner class to better utilize the new modular structure, ensuring efficient performance during auto-tuning processes.

* [Refactor] Clean up and enhance capture and tuner modules

- Improved code readability by removing unnecessary blank lines and organizing imports in `capture.py` and `tuner.py`.
- Enhanced logging in the `AutoTuner` class to provide clearer warnings regarding the usage of `supply_prog` in the context of auto-tuning.
- Streamlined the `CaptureStack` class for better thread-local context management.

* lint fix

* [Refactor] Simplify configuration and autotuning logic in blocksparse GEMM example

- Updated `get_configs` function to reduce the number of configurations, enhancing performance and clarity.
- Removed the `get_best_config` function, integrating its logic directly into the `blocksparse_matmul` function with the `@autotune` decorator for streamlined autotuning.
- Adjusted the main function to directly utilize the autotuned kernel, simplifying the overall structure and improving readability.
- Deleted obsolete test file for autotuning decorator, cleaning up the codebase.

* [Refactor] Improve code formatting and readability in autotune test file

- Reformatted the `matmul` function and `get_configs` function for better readability by adjusting line breaks and indentation.
- Fixed a typo in the `enable_rasteration` parameter name to ensure consistency.
- Cleaned up unnecessary blank lines to enhance overall code clarity.

* Update example_blocksparse_gemm.py

* Update capture.py

* [Pass] Introduce flag to diable cp async lowering (#633)

* [Enhancement] Update PipelinePlanner to support async copy configuration

- Modified the `Substitute` method in `PipelinePlanner` to accept a `use_async_copy` parameter, allowing for more flexible pipeline planning based on async copy requirements.
- Updated the constructor of `PipelinePlanner` to initialize the `use_async_copy_` member variable.
- Adjusted the logic in the pipeline planning process to conditionally apply async copy annotations based on the new parameter.
- Commented out the `LoopVectorizeDynamic` call in `LowerAndLegalize` to prevent unintended modifications during the legalizing phase.

* Refactor PipelinePlanning function for improved readability

- Adjusted the formatting of the `use_async_copy` variable assignment in the `PipelinePlanning` function to enhance code clarity and maintainability.

* fix typo (#635)

* [Pass][Simplify] Introduce symbolic level simplify for condition expression (#634)

* [Enhancement] Add argument simplification option to StmtSimplifier

- Introduced a new `simplify_arguments` flag in the `StmtSimplifier::Apply` method to control argument simplification behavior.
- Updated the `Simplify` function to accept the new flag, allowing for enhanced flexibility in the simplification process.
- Adjusted the `LowerAndLegalize` and `_Simplify` functions to utilize the new argument, ensuring consistent behavior across the codebase.
- Added comments to clarify the purpose of the new flag and its impact on simplification logic.

* lint fix

* [Enhancement] Improve layout inference and reduce operation handling

- Updated `ParallelOp::InferLayout` to check for pure buffer stores, enhancing layout inference logic.
- Modified `ReduceOp::Lower` to include all threads in the AllReduce operation, improving performance on specific architectures.
- Added a TODO comment in `AllReduce` to consider merging synchronization barriers for optimization.

* lint fix

* [Enhancement] Add input validation for GEMM parameters

- Introduced checks to ensure that the dimensions M and N are divisible by their respective warp sizes (kMPerWarp and kNPerWarp) in the Gemm::ComputeWarpPartition method.
- Added informative error messages to assist in debugging when the input parameters do not meet the required conditions.

* bug fix

* Enhance test coverage by adding LLVM requirement decorator to multiple function call tests. This ensures that tests for argument count, type code, null data pointer, and dimensionality checks are only executed when LLVM is available, improving test reliability and clarity.

* lint fix

* Fix software pipeline stage annotation and update optional config handling in StmtSimplifier

* Add Python executable detection in CMake configuration and update TVM submodule reference. Remove unused vectorization tests for improved clarity.

* Update TVM submodule reference and refactor FFI registration to use static initialization blocks for improved organization and clarity.

* Refactor attribute handling in layout and IR nodes to use reflection registration. This change replaces the VisitAttrs method with a RegisterReflection method for improved clarity and organization across multiple classes, including KernelLaunchFrameNode, WarpSpecializeFrameNode, LayoutNode, FragmentNode, and SwizzledLayoutNode.

* finish rebase

* tvm update

* Refactor FFI registration across tilelang modules to use the updated `tvm.ffi` namespace. This includes changes in various files to replace `tvm._ffi` with `tvm.ffi`, enhancing consistency and clarity in the codebase.

* lint fix

* Update TVM submodule reference and modify CUDA runtime argument handling to use the new runtime constants for improved clarity and consistency.

* lint fix

* Refactor tensor data type references from "e4m3_float8" and "e5m2_float8" to "float8_e4m3" and "float8_e5m2" across multiple files for consistency and clarity.

* lint fix

* Refactor forward_index initialization in Fragment class to default to an empty array instead of None, ensuring consistent handling of optional outputs.

* test fix

* lint fix

* bugfix

* lint fix

* reduce fix

* lint fix

* carver fix

* cast fix

* Update submodule and enhance kernel launch functionality with optional block size parameter; add device kernel launch transformation.

* lint fix

* bugfix

* Refactor test execution in test_tilelang_cpu_gemm.py and enhance device call checks in lower.py to exclude C packed functions from kernel launch conditions.

* lint fix

* Update runtime.cc

* phase out lisence

* Update subproject commit for TVM to 555cc71

* Update subproject commit for TVM to d39953fa

* Update subproject commit for TVM to 9574805f

* Update subproject commit for TVM to a08b7c3

* fix ci

* ci fix

---------
Signed-off-by: default avatardependabot[bot] <support@github.com>
Co-authored-by: default avatarLeiWang1999 <leiwang1999@outlook.com>
Co-authored-by: default avatarLei Wang <34334180+LeiWang1999@users.noreply.github.com>
Co-authored-by: default avatarCunxiao Ni <85601223+Cunxiao2002@users.noreply.github.com>
Co-authored-by: default avatarYuxi Chi <cherichy@outlook.com>
Co-authored-by: default avatarNathan Chen <120630832+Nathancgy@users.noreply.github.com>
Co-authored-by: default avatarbotbw <wang1570@e.ntu.edu.sg>
Co-authored-by: default avatardependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: default avatarxs-keju <93414213+xs-keju@users.noreply.github.com>
Co-authored-by: default avatarTong WU <109033598+Rachmanino@users.noreply.github.com>
Co-authored-by: default avatarKadir Nar <kadir.nar@hotmail.com>
Co-authored-by: default avatarYuqing Xia <35415939+xiayuqing0622@users.noreply.github.com>
Co-authored-by: default avatarxwhzz <wh.xie@outlook.com>

parent 8edd6941
......@@ -6,7 +6,7 @@
#include "tvm/tir/expr.h"
#include "tvm/tir/stmt.h"
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
......@@ -209,8 +209,10 @@ tvm::transform::Pass LowerSharedBarrier() {
return CreatePrimFuncPass(pass_func, 0, "tl.LowerSharedBarrier", {});
}
TVM_REGISTER_GLOBAL("tl.transform.LowerSharedBarrier")
.set_body_typed(LowerSharedBarrier);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LowerSharedBarrier", LowerSharedBarrier);
});
} // namespace transform
} // namespace tl
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Lower allreduce to device implementable ir.
* \file lower_thread_allreduce.cc
*/
#include <tvm/arith/analyzer.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/target/target.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <unordered_set>
#include "runtime/thread_storage_scope.h"
#include "tir/transforms/ir_utils.h"
#include "tir/transforms/update_pointer_storage_scope.h"
namespace tvm {
namespace tl {
using namespace tir;
using runtime::StorageRank;
using runtime::StorageScope;
/*!
* \brief collect the mapping from the buffer var to its allocate
*/
class AllocateCollector : public StmtExprVisitor {
private:
bool IsDynamicSharedMemory(Var buffer_var) {
StorageScope storage_scope =
runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
return storage_scope.rank == runtime::StorageRank::kShared &&
storage_scope.tag == ".dyn";
}
bool IsStaticSharedMemory(Var buffer_var) {
StorageScope storage_scope =
runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
return storage_scope.rank == runtime::StorageRank::kShared &&
storage_scope.tag == "";
}
public:
void VisitStmt_(const AllocateNode *op) final {
if (IsDynamicSharedMemory(op->buffer_var)) {
dyn_shmem_allocs_[op->buffer_var.get()] = op;
} else if (IsStaticSharedMemory(op->buffer_var)) {
static_shmem_allocs_[op->buffer_var.get()] = op;
}
StmtExprVisitor::VisitStmt_(op);
}
// The dynamic mapping from the original buffer var to its allocate
std::unordered_map<const VarNode *, const AllocateNode *> dyn_shmem_allocs_;
// The static mapping from the original buffer var to its allocate
std::unordered_map<const VarNode *, const AllocateNode *>
static_shmem_allocs_;
};
class ThreadAllreduceBuilder final : public StmtExprMutator {
public:
explicit ThreadAllreduceBuilder(const TargetNode *target,
bool is_dynamic = false)
: target_(target),
warp_size_(
target->GetAttr<Integer>("thread_warp_size", 1).value().IntValue()),
max_num_threads_(target->GetAttr<Integer>("max_num_threads", -1)
.value()
.IntValue()) {
if (is_dynamic) {
shared_scope = "shared.dyn";
}
}
Stmt VisitStmt_(const AttrStmtNode *op) final {
if (op->attr_key == tir::attr::thread_extent) {
thread_extents_.push_back(op);
Stmt ret = StmtExprMutator::VisitStmt_(op);
thread_extents_.pop_back();
return ret;
} else if (op->attr_key == tir::attr::reduce_scope) {
const CommReducerNode *combiner = op->node.as<CommReducerNode>();
ICHECK(combiner);
reduce_combiner_.push_back(combiner);
Stmt ret = StmtExprMutator::VisitStmt_(op);
reduce_combiner_.pop_back();
return ret;
} else {
return StmtExprMutator::VisitStmt_(op);
}
}
Stmt VisitStmt_(const EvaluateNode *op) final {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<EvaluateNode>();
const CallNode *call = op->value.as<CallNode>();
if (call && call->op.same_as(builtin::tvm_thread_allreduce())) {
return MakeAllreduce(call);
} else {
return stmt;
}
}
Stmt VisitStmt_(const AllocateNode *op) final {
auto node = Downcast<Allocate>(StmtExprMutator::VisitStmt_(op));
if (auto it = alloc_remap_.find(node->buffer_var.get());
it != alloc_remap_.end()) {
Buffer buf = Downcast<Buffer>(it->second);
auto write_ptr = node.CopyOnWrite();
write_ptr->buffer_var = buf->data;
write_ptr->dtype = buf->dtype;
write_ptr->extents = buf->shape;
write_ptr->condition = const_true(buf->dtype.lanes());
if (buf.scope() == shared_scope) {
// Use volatile access to shared buffer.
write_ptr->body =
AttrStmt(buf->data, tir::attr::volatile_scope, 1, write_ptr->body);
}
}
return std::move(node);
}
Optional<Buffer> GetRemappedBuffer(const Buffer &buf) {
if (auto it = buf_remap_.find(buf.get()); it != buf_remap_.end()) {
return it->second;
}
if (auto it = var_remap_.find(buf->data.get()); it != var_remap_.end()) {
Buffer new_buf = buf;
new_buf.CopyOnWrite()->data = it->second;
buf_remap_[buf.get()] = new_buf;
return new_buf;
}
return std::nullopt;
}
Stmt VisitStmt_(const DeclBufferNode *op) final {
auto node = Downcast<DeclBuffer>(StmtExprMutator::VisitStmt_(op));
if (auto buf = GetRemappedBuffer(node->buffer)) {
node.CopyOnWrite()->buffer = buf.value();
}
return std::move(node);
}
PrimExpr VisitExpr_(const BufferLoadNode *op) final {
if (auto it = load_remap_.find(op->buffer->data.get());
it != load_remap_.end()) {
for (const auto &index : op->indices) {
ICHECK(is_zero(index))
<< "The index of buffer " << op->buffer << " is " << index;
}
return it->second;
}
BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
op = load.get();
if (auto opt = GetRemappedBuffer(load->buffer)) {
load.CopyOnWrite()->buffer = opt.value();
}
return std::move(load);
}
Stmt VisitStmt_(const BufferStoreNode *op) final {
BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
if (auto opt = GetRemappedBuffer(store->buffer)) {
store.CopyOnWrite()->buffer = opt.value();
}
return std::move(store);
}
private:
// Thread entry
struct ThreadEntry {
runtime::ThreadScope scope;
IterVar iv;
int extent;
// comparator
bool operator<(const ThreadEntry &other) const {
return scope.dim_index < other.scope.dim_index;
}
};
// make allreduce.
Stmt MakeAllreduce(const CallNode *call) {
ICHECK(!reduce_combiner_.empty());
const CommReducerNode *combiner = reduce_combiner_.back();
size_t size = combiner->result.size();
const IntImmNode *size_of_args = call->args[0].as<IntImmNode>();
ICHECK(size_of_args) << call->args[0]->GetTypeKey();
ICHECK_EQ(size, size_of_args->value);
Array<PrimExpr> inits = combiner->identity_element;
std::vector<PrimExpr> values(size);
std::vector<DataType> types(size);
PrimExpr cond = call->args[size + 1];
for (size_t idx = 0; idx < size; ++idx) {
values[idx] = call->args[1 + idx];
if (!is_one(cond)) {
values[idx] = Select(cond, values[idx], inits[idx]);
}
types[idx] = values[idx].dtype();
}
std::vector<Buffer> buffers(size);
for (size_t idx = 0; idx < size; ++idx) {
PrimExpr arg = call->args[2 + size + idx];
// Loads from boolean buffers may have cast nodes inserted by
// earlier passes.
if (auto cast = arg.as<CastNode>()) {
arg = cast->value;
}
buffers[idx] = Downcast<BufferLoad>(arg)->buffer;
}
std::unordered_set<const VarNode *> reduce_set;
for (size_t i = 2 + 2 * size; i < call->args.size(); ++i) {
const VarNode *v = call->args[i].as<VarNode>();
// The simply optimization replace a iteration variable with a constant
// when extent of the iteration is 1. As threaded IterVar always started
// from 0, we can just ignore this variable in this case.
if (v) {
reduce_set.insert(v);
} else {
ICHECK(call->args[i].as<IntImmNode>() &&
call->args[i].as<IntImmNode>()->value == 0)
<< "arg" << i << "should be a VarNode or IntImmNode "
<< "while it is " << call->args[i];
}
}
size_t nmatch = 0;
std::vector<ThreadEntry> vred, vpar;
int reduce_dim_index = -1;
for (const AttrStmtNode *attr : thread_extents_) {
ThreadEntry e;
IterVar iv = Downcast<IterVar>(attr->node);
e.scope = runtime::ThreadScope::Create(iv->thread_tag);
e.iv = iv;
ICHECK_LE(e.scope.rank, 1);
ICHECK_GE(e.scope.dim_index, 0)
<< "vthread do not work with cross thread reduction";
if (e.scope.rank == 1) {
const auto *ptr = attr->value.as<IntImmNode>();
ICHECK(ptr) << "Need constant extent for reduce set " << iv;
e.extent = static_cast<int>(ptr->value);
// ignore variables equal to 0
if (e.extent == 1) {
continue;
}
if (reduce_set.count(iv->var.get())) {
bool already_exists = false;
for (const auto &entry : vred) {
if (entry.scope.dim_index == e.scope.dim_index) {
already_exists = true;
break;
}
}
if (!already_exists) {
vred.push_back(e);
++nmatch;
reduce_dim_index = e.scope.dim_index;
}
} else {
bool already_exists = false;
for (const auto &entry : vpar) {
if (entry.scope.dim_index == e.scope.dim_index) {
already_exists = true;
break;
}
}
if (!already_exists) {
vpar.push_back(e);
}
}
}
}
// remove reduce thread from parallel thread
if (reduce_dim_index != -1) {
for (size_t i = 0; i < vpar.size(); ++i) {
if (vpar[i].scope.dim_index == reduce_dim_index) {
vpar.erase(vpar.begin() + i);
break;
}
}
}
ICHECK_EQ(nmatch, reduce_set.size())
<< "Not all reduce index are presented in the context";
std::sort(vred.begin(), vred.end());
std::sort(vpar.begin(), vpar.end());
// the size of each index.
int reduce_extent, group_extent;
PrimExpr reduce_index = FlattenThread(vred, &reduce_extent);
PrimExpr group_index = FlattenThread(vpar, &group_extent);
// the longest contiguous reduce extent after flattening
int contiguous_reduce_extent = 1;
std::vector<std::tuple<int, int, bool>>
block_threads; // tuple(dim_index, extent, is_reduce)
for (const ThreadEntry &thr : vred) {
if (thr.scope.rank == 1) { // threadIdx
block_threads.emplace_back(thr.scope.dim_index, thr.extent, true);
}
}
for (const ThreadEntry &thr : vpar) {
if (thr.scope.rank == 1) { // threadIdx
block_threads.emplace_back(thr.scope.dim_index, thr.extent, false);
}
}
// sort according to dim_index
std::sort(block_threads.begin(), block_threads.end());
for (auto &&thr_attr : block_threads) {
auto [dim_index, extent, is_reduce] = thr_attr;
(void)dim_index; // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81767
if (is_reduce) {
contiguous_reduce_extent *= extent;
} else {
break;
}
}
std::vector<Stmt> seq;
std::vector<Buffer> new_alloc_bufs;
//
// This is an optimization. For small reduction sizes, it may be beneficial
// for a single warp to performance the entire reduction. No trips to shared
// memory and no cross warp synchronizations are required.
// The following code emits the reduction as follows:
//
// Allocate reduction vars v[i], i = 0..size-1
//
// for offset from WARP_SIZE to 1 by 2
//
// a <- load(v[i])
// b <- shuffle_down(load(v[i], offset))
// v[i] <- reduction(a, b)
//
// broadcast results from lane 0 to all other lanes and store
// the final reduction result to the proper location.
//
// When the thread extent is multiple of warp size, we can use a two-stage
// warp-level reduction to optimize. This is implemented by applying the
// algorithm above twice.
//
// For example, suppose we want to use 512 threads to reduce 512 elements
// and the warp size is 32. In this case there are (512 / 32) = 16 warps.
// In the first stage, each of the 16 warps reduces 32 elements. So after
// the stage, we have 16 remaining elements to be reduced, one for each
// warp. We store the 16 elements in shared memory, and start the second
// stage. In the second stage we use the first 16 lanes of the first warp to
// reduce the remaining elements, and this reduction can also be optimized
// by shuffle_down warp-level primitives.
PrimExpr zero_index = make_const(reduce_index->dtype, 0);
if (IsWarpReduction(types, group_extent, reduce_extent,
contiguous_reduce_extent)) {
std::vector<PrimExpr> reduce_results;
DataType mask_dtype = DataType::UInt(32);
PrimExpr mask = Call(mask_dtype, builtin::tvm_warp_activemask(), {});
if (reduce_extent <= warp_size_) {
std::tie(reduce_results, new_alloc_bufs) = MakeWarpAllreduce(
values, types, combiner, reduce_index, reduce_extent, group_index,
mask, std::nullopt, &seq);
// Broadcast the reduction result from lane 0 to all other lanes.
// This avoids to emit predicated stores, as all threads are
// uniformly writing the same result.
for (size_t i = 0; i < size; ++i) {
Buffer buf = Downcast<BufferLoad>(reduce_results[i])->buffer;
PrimExpr val = BufferLoad(buf, {zero_index});
ICHECK_EQ(val->dtype, types[i]);
PrimExpr splat =
WarpShuffle(builtin::tvm_warp_shuffle(), new_alloc_bufs.back(),
val, reduce_extent * group_index);
seq.push_back(BufferStore(buf, splat, {zero_index}));
}
} else {
int n_warps = reduce_extent / warp_size_;
std::vector<Buffer> local_bufs;
// 1. Create the staging buffer in shared memory.
std::vector<Buffer> staging_shared_bufs;
staging_shared_bufs.reserve(size);
for (size_t i = 0; i < size; ++i) {
Buffer staging_shared_buf = decl_buffer(
/*shape=*/{make_const(reduce_index->dtype,
n_warps * group_extent)},
/*dtype=*/buffers[i]->dtype, /*name=*/"red_buf_staging",
/*storage_scope=*/shared_scope);
staging_shared_bufs.push_back(staging_shared_buf);
new_alloc_bufs.push_back(staging_shared_buf);
}
// 2. First round of allreduce.
std::tie(reduce_results, local_bufs) =
MakeWarpAllreduce(values, types, combiner, reduce_index, warp_size_,
group_index, mask, std::nullopt, &seq);
new_alloc_bufs.insert(new_alloc_bufs.end(), local_bufs.begin(),
local_bufs.end());
// 3. Write allreduce results to staging buffer.
std::vector<Stmt> write_staging_buf;
write_staging_buf.reserve(size);
for (size_t i = 0; i < size; ++i) {
new_alloc_bufs.push_back(
Downcast<BufferLoad>(reduce_results[i])->buffer);
write_staging_buf.push_back(BufferStore(
/*buffer=*/staging_shared_bufs[i],
/*value=*/reduce_results[i],
/*indices=*/
{group_index * n_warps + floordiv(reduce_index, warp_size_)}));
}
PrimExpr cond = floormod(reduce_index, warp_size_) == zero_index;
seq.push_back(IfThenElse(cond, SeqStmt::Flatten(write_staging_buf)));
seq.push_back(SyncThread(shared_scope));
// 4. Load staging buffer.
// Second round of allreduce.
for (size_t i = 0; i < size; ++i) {
values[i] =
BufferLoad(/*buffer=*/staging_shared_bufs[i],
/*indices=*/{group_index * n_warps + reduce_index});
}
std::tie(reduce_results, local_bufs) = MakeWarpAllreduce(
values, types, combiner, reduce_index, n_warps, group_index, mask,
/*predicate=*/reduce_index <
make_const(reduce_index->dtype, n_warps),
&seq);
new_alloc_bufs.insert(new_alloc_bufs.end(), local_bufs.begin(),
local_bufs.end());
// 5. Create shared memory buffer(s) of `group_extent` elements, storing
// the allreduce results so each thread can access.
std::vector<Stmt> write_result;
write_result.reserve(size);
for (size_t i = 0; i < size; ++i) {
new_alloc_bufs.push_back(
Downcast<BufferLoad>(reduce_results[i])->buffer);
Buffer broadcast_shared_buf = decl_buffer(
/*shape=*/{make_const(reduce_index->dtype, group_extent)},
/*dtype=*/buffers[i]->dtype, /*name=*/"red_result",
/*storage_scope=*/shared_scope);
write_result.push_back(BufferStore(broadcast_shared_buf,
reduce_results[i], {group_index}));
// Update `reduce_results`, pointing to the value loaded from the
// shared memory buffer.
reduce_results[i] = BufferLoad(broadcast_shared_buf, {group_index});
}
seq.push_back(IfThenElse(reduce_index == zero_index,
SeqStmt::Flatten(write_result)));
seq.push_back(SyncThread(shared_scope));
}
// Write back allreduce results and update existing allocations.
for (size_t i = 0; i < size; ++i) {
ICHECK(!load_remap_.count(buffers[i]->data.get()));
PrimExpr pred = const_true(types[i].lanes());
Buffer buf = Downcast<BufferLoad>(reduce_results[i])->buffer;
ICHECK_EQ(reduce_results[i]->dtype, types[i]);
load_remap_[buffers[i]->data.get()] = reduce_results[i];
auto node =
Allocate(buf->data, types[i], buf->shape, pred, Evaluate(0));
alloc_remap_[buffers[i]->data.get()] = buf;
var_remap_[buffers[i]->data.get()] = buf->data;
buf_remap_[buffers[i].get()] = buf;
}
} else {
std::vector<Buffer> shared_bufs(size);
if (reduce_extent == 1) {
// special case, no reduction is needed.
std::vector<Stmt> stores;
for (size_t i = 0; i < size; ++i) {
stores.push_back(BufferStore(buffers[i], values[i], {0}));
}
return SeqStmt::Flatten(stores);
}
// This sync is necessary because there might be incomplete read of
// previous iteration on the same buffer.
seq.emplace_back(SyncThread(shared_scope));
for (size_t idx = 0; idx < size; ++idx) {
shared_bufs[idx] = decl_buffer(
{IntImm(group_index->dtype, group_extent * reduce_extent)},
types[idx], "red_buf" + std::to_string(idx), shared_scope);
seq.emplace_back(
BufferStore(shared_bufs[idx], values[idx],
{BufIndex(reduce_index, group_index, reduce_extent)}));
}
seq.emplace_back(SyncThread(shared_scope));
seq.emplace_back(MakeBufAllreduce(
combiner, types, shared_bufs, reduce_index, group_index,
reduce_extent, group_extent, contiguous_reduce_extent));
for (size_t idx = 0; idx < size; ++idx) {
ICHECK(!load_remap_.count(buffers[idx]->data.get()));
PrimExpr pred = const_true(types[idx].lanes());
BufferLoad load(shared_bufs[idx],
{BufIndex(make_zero(reduce_index.dtype()), group_index,
reduce_extent)});
ICHECK_EQ(load->dtype, types[idx]);
load_remap_[buffers[idx]->data.get()] = load;
alloc_remap_[buffers[idx]->data.get()] = shared_bufs[idx];
var_remap_[buffers[idx]->data.get()] = shared_bufs[idx]->data;
buf_remap_[buffers[idx].get()] = shared_bufs[idx];
}
}
// Fix all local allocations as all statements are built.
Stmt body = SeqStmt::Flatten(seq);
for (Buffer buf : new_alloc_bufs) {
body = DeclBuffer(buf, body);
body = Allocate(buf->data, buf->dtype, buf->shape,
const_true(buf->dtype.lanes()), body);
}
return body;
}
std::pair<std::vector<PrimExpr>, std::vector<Buffer>>
MakeWarpAllreduce(std::vector<PrimExpr> src_values, //
std::vector<DataType> dtypes, //
const CommReducerNode *combiner, //
PrimExpr reduce_index, int reduce_extent, //
PrimExpr group_index, //
PrimExpr mask, Optional<PrimExpr> predicate, //
std::vector<Stmt> *seq) {
int n_buffers = src_values.size();
std::vector<Buffer> shared_bufs;
std::vector<Buffer> local_bufs;
shared_bufs.reserve(n_buffers);
// This is the index to the reduction variable, one reduction
// variable per warp. Local scope seems easier to reason without
// relying on a pattern match pass to fix it later.
Array<PrimExpr> zero_indices = {0};
Array<PrimExpr> shape = {1};
std::vector<Stmt> load_values;
load_values.reserve(n_buffers);
for (int idx = 0; idx < n_buffers; ++idx) {
shared_bufs.push_back(decl_buffer(
shape, dtypes[idx], "red_buf" + std::to_string(idx), "local"));
load_values.push_back(
BufferStore(shared_bufs[idx], src_values[idx], zero_indices));
// Uses a local variable to store the shuffled data. Later
// on, an allocation will be built for this local variable.
local_bufs.push_back(
decl_buffer(shape, dtypes[idx], "t" + std::to_string(idx), "local"));
}
if (predicate.defined()) {
seq->push_back(
IfThenElse(predicate.value(), SeqStmt::Flatten(load_values)));
} else {
seq->insert(seq->end(), load_values.begin(), load_values.end());
}
// The mask for this reducer, as this reducer may sit inside
// a divergent control flow. Here it uses a variable to cache the current
// active channels.
Optional<Buffer> mask_buffer;
if (need_warp_shuffle_mask_) {
mask_buffer = decl_buffer(shape, mask->dtype, "mask", "local");
seq->emplace_back(BufferStore(mask_buffer.value(), mask, zero_indices));
// Push the buffer description. Later this will have an
// allocation built for it.
local_bufs.push_back(mask_buffer.value());
}
// Emit reductions within a warp.
int start_offset = 1;
while (start_offset * 2 < reduce_extent) {
start_offset *= 2;
}
for (int offset = start_offset; offset > 0; offset /= 2) {
// Load reduction values, no synchronization needed.
Array<PrimExpr> a, b;
for (int i = 0; i < n_buffers; ++i) {
Buffer shared_buf = shared_bufs[i];
BufferLoad val(shared_buf, zero_indices);
ICHECK_EQ(val->dtype, dtypes[i]);
a.push_back(val);
// __shfl_*sync calls shall not appear in if_then_else expressions
// as this is causing extra divergency. E.g.
//
// v1 = (v2 < v3) ? v3 : __shfl_sync(mask, v1, 0);
//
// behaves differently from
//
// int t = __shfl_sync(mask, v1, 0);
// v1 = (v2 < v3) ? v3 : t;
//
// The former may cause dead lock as there is a divergent
// branch with a warp sync call inside.
PrimExpr other = WarpShuffle(builtin::tvm_warp_shuffle_down(),
mask_buffer, val, offset);
Buffer local_buf = local_bufs[i];
Stmt s = BufferStore(local_buf, other, zero_indices);
seq->push_back(s);
BufferLoad load = BufferLoad(local_buf, zero_indices);
ICHECK_EQ(load->dtype, dtypes[i]);
b.push_back(load);
}
// Do reductions.
Array<PrimExpr> ret = (*combiner)(a, b);
// Store the reduction result to itself.
std::vector<Stmt> stores;
stores.reserve(n_buffers);
for (int i = 0; i < n_buffers; ++i) {
Buffer buf = shared_bufs[i];
stores.push_back(BufferStore(buf, ret[i], zero_indices));
}
// During the sub-warp reduction, values from inactive threads could be
// read, which is an undefined behavior according to the cuda document.
//
// In practice, the return value are usually 0, which does no harm to sum
// reduction. However, the result can be incorrect in max or prod
// reduction. Therefore an additional range check has to be performed to
// ensure the correctness.
if (offset * 2 > reduce_extent) {
PrimExpr cond = reduce_index + offset < reduce_extent;
seq->push_back(IfThenElse(cond, SeqStmt::Flatten(stores)));
} else {
seq->push_back(SeqStmt::Flatten(stores));
}
}
std::vector<PrimExpr> reduce_results;
reduce_results.reserve(n_buffers);
for (int i = 0; i < n_buffers; ++i) {
reduce_results.push_back(BufferLoad(shared_bufs[i], zero_indices));
}
return {reduce_results, local_bufs};
}
// make allreduce.
Stmt MakeBufAllreduce(const CommReducerNode *combiner,
const std::vector<DataType> &types,
const Array<Buffer> &shared_bufs, PrimExpr reduce_index,
PrimExpr group_index, int reduce_extent,
int group_extent, int contiguous_reduce_extent) {
// Get next power of two
int reduce_align = 1;
while (reduce_extent > reduce_align) {
reduce_align = reduce_align << 1;
}
ICHECK_GT(reduce_align, 1);
std::vector<Stmt> seq;
size_t size = shared_bufs.size();
PrimExpr buf_index = BufIndex(reduce_index, group_index, reduce_extent);
// make reduction
auto fload = [&](int offset) {
Array<PrimExpr> a, b;
for (size_t i = 0; i < size; ++i) {
BufferLoad b_load(
shared_bufs[i],
{BufIndex(reduce_index + offset, group_index, reduce_extent)});
ICHECK_EQ(b_load->dtype, types[i]);
b.push_back(b_load);
BufferLoad a_load(shared_bufs[i], {buf_index});
ICHECK_EQ(a_load->dtype, types[i]);
a.push_back(a_load);
}
Array<PrimExpr> ret = (*combiner)(a, b);
return ret;
};
auto fstore = [&](const Array<PrimExpr> &ret) {
std::vector<Stmt> stores(size);
for (size_t i = 0; i < size; ++i) {
stores[i] = BufferStore(shared_bufs[i], ret[i], {buf_index});
}
return SeqStmt::Flatten(stores);
};
auto freduce = [&](int offset) {
auto ret = fload(offset);
return fstore(ret);
};
// Step one, check for
if (reduce_align > reduce_extent) {
// reduction with the boundary condition
reduce_align = reduce_align >> 1;
PrimExpr cond = reduce_index < (reduce_extent - reduce_align);
seq.emplace_back(IfThenElse(cond, freduce(reduce_align)));
seq.emplace_back(SyncThread(shared_scope));
}
// normal synchronization
bool warp_align =
group_extent == 1 || contiguous_reduce_extent % warp_size_ == 0;
while (reduce_align > contiguous_reduce_extent ||
reduce_align > warp_size_ || !warp_align) {
if (reduce_align == 1) {
break;
}
reduce_align = reduce_align >> 1;
PrimExpr cond = reduce_index < reduce_align;
seq.emplace_back(IfThenElse(cond, freduce(reduce_align)));
seq.emplace_back(SyncThread(shared_scope));
}
// in warp synchronization.
if (reduce_align > 1) {
PrimExpr in_warp_cond = reduce_index < (reduce_align >> 1);
std::vector<Stmt> in_warp_seq;
while (reduce_align > 1) {
reduce_align = reduce_align >> 1;
// freduce can read/write to the same memory location. For
// example, with reduce_align of 4, threadIdx 3 reads from
// memory location 7 as threadIdx 7 is writing to it.
// Therefore, we need to separate out the load from the store
// with a memory barrier in-between. This isn't necessary for
// the earlier normal synchronization, because those are each
// protected by an if-statement. The if-statement is avoided
// here to reduce thread divergence.
auto loads = fload(reduce_align);
Array<Var> in_warp_local_vars;
for (auto expr : loads) {
Var var("w_" + std::to_string(reduce_align) + "_" +
std::to_string(in_warp_local_vars.size()),
expr->dtype);
in_warp_local_vars.push_back(var);
}
std::vector<Stmt> in_let_statement;
in_let_statement.emplace_back(SyncThread("warp"));
in_let_statement.emplace_back(
fstore({in_warp_local_vars.begin(), in_warp_local_vars.end()}));
in_let_statement.emplace_back(SyncThread("warp"));
Stmt body = SeqStmt::Flatten(in_let_statement);
for (size_t i = 0; i < size; i++) {
body = LetStmt(in_warp_local_vars[i], loads[i], body);
}
in_warp_seq.push_back(body);
}
Stmt warp_body = SeqStmt::Flatten(in_warp_seq);
seq.emplace_back(IfThenElse(in_warp_cond, warp_body));
seq.emplace_back(SyncThread(shared_scope));
}
return SeqStmt::Flatten(seq);
}
// Flatten the thread index.
// Also return a warp number,
PrimExpr FlattenThread(const std::vector<ThreadEntry> &tvec,
int *out_total_extent) {
int &total_extent = *out_total_extent;
total_extent = 1;
if (tvec.size() == 0) {
return make_zero(DataType::Int(32));
}
PrimExpr ret;
for (const ThreadEntry &e : tvec) {
if (ret.defined()) {
ret = ret + e.iv->var * total_extent;
} else {
ICHECK_EQ(total_extent, 1);
ret = e.iv->var;
}
total_extent *= e.extent;
}
return ret;
}
// The local buffer index.
PrimExpr BufIndex(PrimExpr reduce_index, PrimExpr group_index,
int reduce_extent) {
if (!is_zero(group_index)) {
return analyzer_.Simplify(group_index * reduce_extent + reduce_index);
} else {
return reduce_index;
}
}
// sync thread op.
static Stmt SyncThread(const std::string &sync) {
return Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(),
{StringImm(sync)}));
}
// Emit warp shuffle calls.
PrimExpr WarpShuffle(const Op &op, Optional<Buffer> mask_buffer, PrimExpr val,
PrimExpr delta_or_lane) {
Array<PrimExpr> indices = {0};
PrimExpr mask;
if (mask_buffer.defined()) {
mask = BufferLoad(mask_buffer.value(), indices);
} else {
mask = IntImm(DataType::Int(32), 0);
}
PrimExpr width = IntImm(DataType::Int(32), warp_size_);
Array<PrimExpr> args{mask, val, delta_or_lane, width, width};
return Call(val.dtype(), op, args);
}
// Check if we can use warp level reduction.
//
// Note: The ROCm backend will only have warp reductions for now.
// Also, the warp/wavefront size differs (64 on rocm, 32 on cuda and metal).
bool IsWarpReduction(const std::vector<DataType> &types, int group_extent,
int reduce_extent, int contiguous_reduce_extent) {
if ((target_->kind->name != "cuda") && (target_->kind->name != "rocm") &&
(target_->kind->name != "metal")) {
return false;
}
need_warp_shuffle_mask_ = target_->kind->name != "metal";
// rocm only supports 32 bit operands for shuffling at the moment
if ((target_->kind->name == "rocm") &&
(std::any_of(types.begin(), types.end(), [](DataType ty) {
if (ty.is_fixed_length_vector())
return ty.bits() * ty.lanes() != 32;
return ty.bits() != 32;
}))) {
return false;
}
// Supported types:
// {u}int, {u}long, {u}long long, float, double, half/half2
if (std::any_of(types.begin(), types.end(), [](DataType ty) {
if (ty.is_float16())
return ty.lanes() > 2;
if (ty.is_fixed_length_vector())
return true;
return ty.bytes() < 4 || ty.bytes() > 8;
})) {
return false;
}
if (thread_extents_.empty()) {
return false;
}
// reduce region must be contiguous.
if (contiguous_reduce_extent != reduce_extent) {
return false;
}
// whether reduce_extent and group_extent are valid for warp reduction.
if (target_->kind->name == "rocm") {
return reduce_extent == warp_size_;
} else {
if (reduce_extent == 1) {
return false; // no need to warp reduce
} else {
bool is_subwarp_reduction = warp_size_ % reduce_extent == 0;
bool is_multiwarp_reduction =
max_num_threads_ != -1 &&
max_num_threads_ <= warp_size_ * warp_size_ &&
reduce_extent % warp_size_ == 0;
if (is_subwarp_reduction || is_multiwarp_reduction) {
return true;
} else {
return group_extent == 1 && reduce_extent <= warp_size_;
}
}
}
}
// The target.
const TargetNode *target_ = nullptr;
// The shared scope.
String shared_scope = "shared";
// The warp size of the device.
int warp_size_{1};
// The maximum number of threads of the device. "-1" denotes unknown.
int max_num_threads_{-1};
// A boolean indicating if the target supports warp-level masking.
bool need_warp_shuffle_mask_;
// surrounding scope of thread extent.
std::vector<const AttrStmtNode *> thread_extents_;
std::vector<const CommReducerNode *> reduce_combiner_;
// The load remap
std::unordered_map<const VarNode *, PrimExpr> load_remap_;
// Allocate remap
std::unordered_map<const VarNode *, Buffer> alloc_remap_;
// BufferVar remap
std::unordered_map<const VarNode *, Var> var_remap_;
// Buffer remap
std::unordered_map<const BufferNode *, Buffer> buf_remap_;
// Internal analyzer
arith::Analyzer analyzer_;
};
namespace transform {
using namespace tir::transform;
tvm::transform::Pass LowerThreadAllreduce() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
AllocateCollector collector;
collector(f->body);
bool is_dynamic = collector.dyn_shmem_allocs_.size() > 1;
auto *n = f.CopyOnWrite();
auto target = f->GetAttr<Target>(tvm::attr::kTarget);
ICHECK(target.defined())
<< "LowerThreadAllreduce: Require the target attribute";
const TargetNode *target_node = target.as<TargetNode>();
ThreadAllreduceBuilder thread_all_reduce(target_node, is_dynamic);
n->body = thread_all_reduce(n->body);
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tl.LowerThreadAllreduce", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LowerThreadAllreduce",
LowerThreadAllreduce);
});
} // namespace transform
} // namespace tl
} // namespace tvm
......@@ -3,6 +3,7 @@
* \brief Lower the tile op for further codegen.
*/
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
......@@ -108,12 +109,14 @@ private:
* \return The rewritten block.
*/
Stmt RewritePaddingMap(const BlockNode *op) {
auto padding_map =
op->annotations.Get(attr::kPaddingMap).as<Map<Var, PrimExpr>>().value();
auto padding_map = op->annotations.Get(attr::kPaddingMap);
if (!padding_map) {
LOG(FATAL) << "Padding map annotation is missing";
}
Map<Var, Var> var_remap = CreateVarRemap();
Map<Var, PrimExpr> new_padding_map =
RemapPaddingMap(padding_map, var_remap);
Map<Var, PrimExpr> new_padding_map = RemapPaddingMap(
Downcast<Map<Var, PrimExpr>>(padding_map.value()), var_remap);
auto block = Downcast<Block>(IRMutatorWithAnalyzer::VisitStmt_(op));
auto block_ptr = block.CopyOnWrite();
......@@ -235,7 +238,7 @@ private:
}
PrimExpr HandleAccessPtrAndOffset(PrimExpr access_ptr,
Optional<PrimExpr> offset = NullOpt,
Optional<PrimExpr> offset = std::nullopt,
DataType dtype = DataType::Int(32)) {
// The 2th arg of T.tvm_access_ptr call is offset, we set it to 0 and
// accumulate it to smem_offset
......@@ -318,7 +321,7 @@ private:
op->op.same_as(tl::tma_store()))) {
has_tma_ = true;
}
Array<RelayExpr> ptx_instructions = {builtin::ptx_ldmatrix(),
Array<RelaxExpr> ptx_instructions = {builtin::ptx_ldmatrix(),
builtin::mma_store()};
if (std::find(ptx_instructions.begin(), ptx_instructions.end(), op->op) ==
......@@ -354,7 +357,7 @@ private:
// mma_store now
auto access_ptr = call->args[2];
auto new_access_ptr =
HandleAccessPtrAndOffset(access_ptr, NullOpt, call->dtype);
HandleAccessPtrAndOffset(access_ptr, std::nullopt, call->dtype);
auto new_call = call.CopyOnWrite();
new_call->args.Set(2, new_access_ptr);
} else {
......@@ -496,7 +499,10 @@ tvm::transform::Pass LowerTileOp() {
return CreatePrimFuncPass(pass_func, 0, "tl.LowerTileOp", {});
}
TVM_REGISTER_GLOBAL("tl.transform.LowerTileOp").set_body_typed(LowerTileOp);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LowerTileOp", LowerTileOp);
});
} // namespace transform
} // namespace tl
......
......@@ -20,8 +20,10 @@
/*!
* \file make_packed_api.cc Lower PrimFunc to use the packed function API.
*/
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/module.h>
#include <tvm/target/target.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/buffer.h>
......@@ -30,7 +32,6 @@
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <unordered_set>
#include <utility>
#include <vector>
......@@ -75,7 +76,7 @@ public:
private:
struct ConvertedInfo {
int tcode{-1};
int type_index{-1};
PrimExpr expr;
Buffer dummy_val_buffer;
Buffer dummy_tcode_buffer;
......@@ -87,13 +88,13 @@ private:
// convert val's data type to FFI data type, return type code
DataType dtype = val.dtype();
if (dtype.is_int() || dtype.is_uint()) {
info.tcode = kTVMArgInt;
info.type_index = ffi::TypeIndex::kTVMFFIInt;
info.expr = Cast(DataType::Int(64), val);
} else if (dtype.is_float()) {
info.tcode = kTVMArgFloat;
info.type_index = ffi::TypeIndex::kTVMFFIFloat;
info.expr = Cast(DataType::Float(64), val);
} else if (dtype.is_void()) {
info.tcode = kTVMNullptr;
info.type_index = ffi::TypeIndex::kTVMFFINone;
info.expr = val;
} else {
LOG(FATAL) << "data type " << dtype << " not supported yet";
......@@ -101,18 +102,18 @@ private:
// If multiple return locations have the same data type, use the
// same dummy buffer declaration.
auto it = dummy_val_buffer_map_.find(info.tcode);
auto it = dummy_val_buffer_map_.find(info.type_index);
if (it != dummy_val_buffer_map_.end()) {
info.dummy_val_buffer = it->second;
} else {
info.dummy_val_buffer =
Buffer(ret_var_, info.expr.dtype(), {1}, {1}, ConstInt32(0),
ret_var_->name_hint, 0, 0, kDefault);
dummy_val_buffer_map_[info.tcode] = info.dummy_val_buffer;
dummy_val_buffer_map_[info.type_index] = info.dummy_val_buffer;
}
// The tcode is always a 32-bit int, so we don't need to have a separate
// map.
// The type_index is always a 32-bit int, so we don't need to have a
// separate map.
if (!dummy_tcode_buffer_.defined()) {
dummy_tcode_buffer_ =
Buffer(ret_tcode_, DataType::Int(32), {1}, {1}, ConstInt32(0),
......@@ -126,7 +127,8 @@ private:
Stmt WriteToOut(PrimExpr val) {
auto info = ConvertForFFI(val);
Stmt store_val = BufferStore(info.dummy_val_buffer, info.expr, {0});
Stmt store_tcode = BufferStore(info.dummy_tcode_buffer, info.tcode, {0});
Stmt store_tcode =
BufferStore(info.dummy_tcode_buffer, info.type_index, {0});
Stmt ret_zero = Evaluate(tvm::ret(0));
return SeqStmt({store_val, store_tcode, ret_zero});
}
......@@ -153,7 +155,7 @@ public:
if (rewriter.made_change_) {
return stmt;
} else {
return NullOpt;
return std::nullopt;
}
}
......@@ -204,21 +206,21 @@ inline Stmt MakeAssertNotNull(PrimExpr ptr, std::string msg) {
* \param func The function to be inspected
*
* \returns The global_symbol to be used for the function at call
* sites, or NullOpt if the function is to remain unchanged.
* sites, or std::nullopt if the function is to remain unchanged.
*/
Optional<String> RequiresPackedAPI(const PrimFunc &func) {
// A function with an explicit calling convention has already been
// lowered, and should not be modified.
if (auto opt = func->GetAttr<Integer>(tvm::attr::kCallingConv)) {
if (CallingConv(opt.value()->value) != CallingConv::kDefault) {
return NullOpt;
return std::nullopt;
}
}
// Internal function calls do not need the PackedFunc API
auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
if (!global_symbol.defined()) {
return NullOpt;
return std::nullopt;
}
return global_symbol;
......@@ -344,9 +346,9 @@ PrimFunc MakePackedAPI(PrimFunc func) {
}
// type code checks
Var tcode(param->name_hint + ".code", DataType::Int(32));
Var type_index(param->name_hint + ".code", DataType::Int(32));
seq_init.emplace_back(LetStmt(
tcode,
type_index,
BufferLoad(buf_packed_arg_type_ids, {IntImm(DataType::Int(32), i)}),
nop));
DataType t = param.dtype();
......@@ -354,20 +356,22 @@ PrimFunc MakePackedAPI(PrimFunc func) {
std::ostringstream msg;
msg << name_hint << ": Expect arg[" << i << "] to be pointer";
seq_init.emplace_back(
AssertStmt(tcode == kTVMOpaqueHandle || tcode == kTVMNDArrayHandle ||
tcode == kTVMDLTensorHandle || tcode == kTVMNullptr,
AssertStmt(type_index == ffi::TypeIndex::kTVMFFINone ||
type_index == ffi::TypeIndex::kTVMFFIOpaquePtr ||
type_index == ffi::TypeIndex::kTVMFFIDLTensorPtr ||
type_index >= ffi::TypeIndex::kTVMFFIStaticObjectBegin,
tvm::tir::StringImm(msg.str()), nop));
} else if (t.is_int() || t.is_uint()) {
std::ostringstream msg;
msg << name_hint << ": Expect arg[" << i << "] to be int";
seq_init.emplace_back(
AssertStmt(tcode == kDLInt, tvm::tir::StringImm(msg.str()), nop));
seq_init.emplace_back(AssertStmt(type_index == kDLInt,
tvm::tir::StringImm(msg.str()), nop));
} else {
ICHECK(t.is_float());
std::ostringstream msg;
msg << name_hint << ": Expect arg[" << i << "] to be float";
seq_init.emplace_back(
AssertStmt(tcode == kDLFloat, tvm::tir::StringImm(msg.str()), nop));
seq_init.emplace_back(AssertStmt(type_index == kDLFloat,
tvm::tir::StringImm(msg.str()), nop));
}
}
......@@ -406,13 +410,7 @@ PrimFunc MakePackedAPI(PrimFunc func) {
seq_check.push_back(
AttrStmt(node, tir::attr::device_type, device_type, nop));
bool need_set_device =
(target_device_type != kDLMicroDev &&
(
// or is c source target
target_device_type != kDLCPU || target->kind->name != "llvm"));
if (need_set_device) {
if (runtime::DeviceAPI::NeedSetDevice(target_device_type)) {
Stmt set_device =
Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(),
{StringImm(runtime::symbol::tvm_set_device),
......@@ -468,7 +466,6 @@ PrimFunc MakePackedAPI(PrimFunc func) {
<< " are used, but are not passed in as API arguments";
func_ptr->buffer_map = Map<Var, Buffer>();
func_ptr->checked_type_ = func_ptr->func_type_annotation();
func_ptr->ret_type = PrimType(DataType::Int(32)); // return the function.
return func;
}
......@@ -516,8 +513,10 @@ tvm::transform::Pass MakePackedAPI() {
return tvm::transform::CreateModulePass(pass_func, 0, "tl.MakePackedAPI", {});
}
TVM_REGISTER_GLOBAL("tl.transform.MakePackedAPI").set_body_typed([]() {
return MakePackedAPI();
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.MakePackedAPI",
[]() { return MakePackedAPI(); });
});
} // namespace tl
......
......@@ -3,6 +3,7 @@
* \brief Merge the If Stmt in SeqStmt
*/
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
......@@ -91,7 +92,10 @@ tvm::transform::Pass MergeIfStmt() {
return CreatePrimFuncPass(pass_func, 0, "tl.MergeIfStmt", {});
}
TVM_REGISTER_GLOBAL("tl.transform.MergeIfStmt").set_body_typed(MergeIfStmt);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.MergeIfStmt", MergeIfStmt);
});
} // namespace tl
} // namespace tvm
......@@ -23,8 +23,9 @@
* memory allocation. This pass merges multiple TIR-level dynamic or static
* shared memory allocations into one allocation.
*/
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/runtime/logging.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
......@@ -1048,8 +1049,11 @@ Pass MergeSharedMemoryAllocations(bool enable_aggressive_merge = false,
{});
}
TVM_REGISTER_GLOBAL("tl.transform.MergeSharedMemoryAllocations")
.set_body_typed(MergeSharedMemoryAllocations);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.MergeSharedMemoryAllocations",
MergeSharedMemoryAllocations);
});
} // namespace transform
} // namespace tl
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file warp_specialized_pipeline.cc
* \brief Warp specialized Pipeline for cuda GPU (sm90+)
*/
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
......@@ -220,14 +202,14 @@ private:
Stmt VisitStmt_(const ForNode *op) final {
loop_stack_.emplace_back(op->loop_var, op->extent);
auto num_stages_anno = op->annotations.Get("num_stages");
if (!num_stages_anno.defined()) {
if (!num_stages_anno) {
auto for_node = StmtExprMutator::VisitStmt_(op);
loop_stack_.pop_back();
return for_node;
}
ICHECK(num_stages_anno.as<IntImmNode>());
int num_stages = static_cast<int>(num_stages_anno.as<IntImmNode>()->value);
ICHECK(num_stages_anno->as<IntImmNode>());
int num_stages = static_cast<int>(num_stages_anno->as<IntImmNode>()->value);
const SeqStmtNode *pipeline_body_seq = op->body.as<SeqStmtNode>();
CHECK(pipeline_body_seq) << "ValueError: The body of the software pipeline "
......@@ -340,8 +322,10 @@ tvm::transform::Pass MultiVersionBuffer() {
return CreatePrimFuncPass(pass_func, 0, "tl.MultiVersionBuffer", {});
}
TVM_REGISTER_GLOBAL("tl.transform.MultiVersionBuffer")
.set_body_typed(MultiVersionBuffer);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.MultiVersionBuffer", MultiVersionBuffer);
});
} // namespace tl
} // namespace tvm
......@@ -3,6 +3,7 @@
* \brief Lower L2 persistent annotation
*/
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h>
......@@ -59,8 +60,10 @@ tvm::transform::Pass PersistThreadblock() {
return CreatePrimFuncPass(pass_func, 0, "tl.PersistThreadblock", {});
}
TVM_REGISTER_GLOBAL("tl.transform.PersistThreadblock")
.set_body_typed(PersistThreadblock);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.PersistThreadblock", PersistThreadblock);
});
} // namespace tl
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file pipeline_planning.cc
* \brief Plan the software pipeline
*/
#include <tvm/arith/analyzer.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h>
......@@ -224,12 +201,12 @@ private:
auto order_anno = loop->annotations.Get("tl_pipeline_order");
auto stage_anno = loop->annotations.Get("tl_pipeline_stage");
auto num_stages_anno = loop->annotations.Get("num_stages");
if (order_anno.defined() && stage_anno.defined()) {
if (order_anno && stage_anno) {
// Check if order_anno or stage_anno contains -1, which means TMA+WS is
// enabled
bool ws_tma_enabled = false;
auto order_array = Downcast<Array<Integer>>(order_anno);
auto stage_array = Downcast<Array<Integer>>(stage_anno);
auto order_array = Downcast<Array<Integer>>(order_anno.value());
auto stage_array = Downcast<Array<Integer>>(stage_anno.value());
for (const auto &val : order_array) {
if (val->value == -1) {
ws_tma_enabled = true;
......@@ -249,20 +226,20 @@ private:
return StmtExprMutator::VisitStmt_(loop);
}
Map<String, ObjectRef> annotations;
Map<String, Any> annotations;
for (const auto &[key, value] : loop->annotations) {
if (key != "tl_pipeline_order") {
annotations.Set(key, value);
}
}
annotations.Set(tir::attr::software_pipeline_order, order_anno);
annotations.Set(tir::attr::software_pipeline_order, order_anno.value());
for (const auto &[key, value] : loop->annotations) {
if (key != "tl_pipeline_stage") {
annotations.Set(key, value);
}
}
annotations.Set(tir::attr::software_pipeline_stage, stage_anno);
annotations.Set(tir::attr::software_pipeline_stage, stage_anno.value());
if (TargetHasAsyncCopy(target_) && use_async_copy_)
annotations.Set(tir::attr::software_pipeline_async_stages,
Array<Integer>{0});
......@@ -271,9 +248,9 @@ private:
return for_node;
}
if (!num_stages_anno.defined())
if (!num_stages_anno)
return StmtExprMutator::VisitStmt_(loop);
int num_stages = num_stages_anno.as<IntImmNode>()->value;
int num_stages = num_stages_anno->as<IntImmNode>()->value;
Stmt pipeline_body{nullptr};
if (const auto *realize = loop->body.as<BlockRealizeNode>()) {
const auto &block = realize->block;
......@@ -443,7 +420,7 @@ private:
}
// Finally, make the pipeline annotation
Map<String, ObjectRef> annotations;
Map<String, Any> annotations;
for (const auto &[key, value] : loop->annotations) {
if (key != "num_stages") {
annotations.Set(key, value);
......@@ -496,8 +473,10 @@ tvm::transform::Pass PipelinePlanning() {
return CreatePrimFuncPass(pass_func, 0, "tl.PipelinePlanning", {});
}
TVM_REGISTER_GLOBAL("tl.transform.PipelinePlanning")
.set_body_typed(PipelinePlanning);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.PipelinePlanning", PipelinePlanning);
});
} // namespace tl
} // namespace tvm
/*!
* \file simplify.cc
* \brief Remove useless parameters of TL PrimFunc.
* \brief Statement simplifier based on analyzer and remove useless parameters
* of TL PrimFunc.
*/
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h>
......@@ -19,39 +21,45 @@ namespace tl {
using namespace tir;
using namespace arith;
struct SimplifyConfigNode : public tvm::AttrsNode<SimplifyConfigNode> {
struct SimplifyConfigNode : public AttrsNodeReflAdapter<SimplifyConfigNode> {
bool transitively_prove_inequalities;
bool propagate_knowns_to_prove_conditional;
bool propagate_knowns_to_simplify_expressions;
bool convert_boolean_to_and_of_ors;
bool apply_constraints_to_boolean_branches;
TVM_DECLARE_ATTRS(SimplifyConfigNode, "tl.transform.SimplifyConfig") {
TVM_ATTR_FIELD(transitively_prove_inequalities)
.describe("If true, simplify conditionals with transitive combinations "
"of scoped constraints")
.set_default(false);
TVM_ATTR_FIELD(propagate_knowns_to_prove_conditional)
.describe("If true, known buffer values are propagated and used to "
"statically prove conditionals")
.set_default(false);
TVM_ATTR_FIELD(propagate_knowns_to_simplify_expressions)
.describe("If true, known buffer values are propagated and used to "
"replace BufferLoad wherever "
"possible")
.set_default(false);
TVM_ATTR_FIELD(convert_boolean_to_and_of_ors)
.describe("If true, simplify conditionals into an AND of ORs")
.set_default(false);
TVM_ATTR_FIELD(apply_constraints_to_boolean_branches)
.describe("If true, simplify each branch of AND/OR "
"under a constraints provided by the other branch")
.set_default(false);
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<SimplifyConfigNode>()
.def_ro("transitively_prove_inequalities",
&SimplifyConfigNode::transitively_prove_inequalities,
"If true, simplify conditionals with transitive combinations "
"of scoped constraints",
refl::DefaultValue(false))
.def_ro("propagate_knowns_to_prove_conditional",
&SimplifyConfigNode::propagate_knowns_to_prove_conditional,
"If true, known buffer values are propagated and used to "
"statically prove conditionals",
refl::DefaultValue(false))
.def_ro("propagate_knowns_to_simplify_expressions",
&SimplifyConfigNode::propagate_knowns_to_simplify_expressions,
"If true, known buffer values are propagated and used to "
"replace BufferLoad wherever "
"possible",
refl::DefaultValue(false))
.def_ro("convert_boolean_to_and_of_ors",
&SimplifyConfigNode::convert_boolean_to_and_of_ors,
"If true, simplify conditionals into an AND of ORs",
refl::DefaultValue(false))
.def_ro("apply_constraints_to_boolean_branches",
&SimplifyConfigNode::apply_constraints_to_boolean_branches,
"If true, simplify each branch of AND/OR under a constraints "
"provided by the other "
"branch",
refl::DefaultValue(false));
}
static constexpr const char *_type_key = "tl.transform.SimplifyConfig";
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(SimplifyConfigNode, BaseAttrsNode);
RewriteSimplifier::Extension GetEnabledExtensions() const {
RewriteSimplifier::Extension flags = RewriteSimplifier::kNone;
......@@ -200,6 +208,7 @@ public:
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SimplifyConfig, Attrs,
SimplifyConfigNode);
};
TVM_FFI_STATIC_INIT_BLOCK({ SimplifyConfigNode::RegisterReflection(); });
TVM_REGISTER_NODE_TYPE(SimplifyConfigNode);
TVM_REGISTER_PASS_CONFIG_OPTION("tl.Simplify", SimplifyConfig);
......@@ -207,7 +216,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tl.Simplify", SimplifyConfig);
class StmtSimplifier : public IRMutatorWithAnalyzer {
public:
static PrimFunc Apply(PrimFunc func, Analyzer *analyzer,
Optional<SimplifyConfig> config_opt = NullOpt,
Optional<SimplifyConfig> config_opt = std::nullopt,
bool simplify_arguments = false) {
auto config = config_opt.value_or(AttrsWithDefaultValues<SimplifyConfig>());
analyzer->rewrite_simplify.SetEnabledExtensions(
......@@ -229,6 +238,7 @@ public:
// Begin to remove useless var and buffer
// First get used buffers
simplifier.used_buffers_ = CollectUsedBuffers(func);
bool param_updated = false;
Array<Var> new_params;
Map<Var, Buffer> new_buffer_map;
......@@ -239,13 +249,18 @@ public:
simplifier.used_buffers_.end()) {
new_params.push_back(var);
new_buffer_map.Set(var, func->buffer_map[var]);
} else if (simplifier.used_in_buffer_def_.find(
func->buffer_map[var]->data.get()) !=
simplifier.used_in_buffer_def_.end()) {
new_params.push_back(var);
new_buffer_map.Set(var, func->buffer_map[var]);
} else {
param_updated = true;
}
}
}
if (simplify_arguments && param_updated) {
if (param_updated) {
return PrimFunc(new_params, func.CopyOnWrite()->body, func->ret_type,
new_buffer_map, func->attrs, func->span);
} else {
......@@ -444,7 +459,7 @@ private:
arith::ProofStrength::kSymbolicBound)) {
return Bool(true);
}
return NullOpt;
return std::nullopt;
}
}
......@@ -452,7 +467,7 @@ private:
std::optional<ControlFlowGraph> touch_pattern_;
Map<Var, PrimExpr> non_inlined_bindings_;
Optional<Stmt> current_stmt_{NullOpt};
Optional<Stmt> current_stmt_{std::nullopt};
std::unordered_set<const VarNode *> used_in_buffer_def_;
std::unordered_set<const VarNode *> used_vars_;
std::unordered_set<const BufferNode *> used_buffers_;
......@@ -469,7 +484,10 @@ tvm::transform::Pass Simplify(bool simplify_arguments = true) {
return CreatePrimFuncPass(pass_func, 0, "tl.Simplify", {});
}
TVM_REGISTER_GLOBAL("tl.transform.Simplify").set_body_typed(Simplify);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.Simplify", Simplify);
});
} // namespace tl
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file storage_rewrite.cc
* \brief Memory access pattern analysis and optimization.
* Re-write data access to enable memory sharing when possible.
*/
#include <tvm/arith/analyzer.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/type.h>
#include <tvm/target/target_info.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <map>
#include <unordered_map>
#include <unordered_set>
#include "arith/int_operator.h"
#include "runtime/thread_storage_scope.h"
#include "tir/ir/buffer_common.h"
#include "tir/transforms/ir_utils.h"
namespace tvm {
namespace tl {
using runtime::StorageRank;
using runtime::StorageScope;
using namespace tir;
/*!
* \brief Perform data type legalization on the given BufferLoadNode pointer.
* Equal to BufferLoadNode::LegalizeDType, but operates on a pointer.
* \param n A pointer to a writable BufferLoadNode.
*/
static void LegalizeBufferLoadDType(BufferLoadNode *n) {
// Check that all indices except the last one have a scalar dtype
for (int i = 0; i < static_cast<int>(n->indices.size()) - 1; i++) {
ICHECK(n->indices[i].dtype().is_scalar())
<< "Only the last index of a buffer access may be a vector type.";
}
// If there are no indices, set the dtype to the buffer's dtype
if (n->indices.empty()) {
n->dtype = n->buffer->dtype;
} else {
auto index_dtype = n->indices.back().dtype();
bool is_buffer_dtype_scalable = n->buffer->dtype.is_scalable_vector();
bool is_index_scalable = index_dtype.is_scalable_vector();
// Do not allow both index dtype and buffer dtype to be scalable vectors
ICHECK(!(is_index_scalable && is_buffer_dtype_scalable))
<< "Index dtype and buffer dtype cannot both be scalable.";
if (is_index_scalable) {
// Index is a scalable vector, while the buffer is not
n->dtype = n->buffer->dtype.with_scalable_vscale_factor(
index_dtype.vscale_factor() * n->buffer->dtype.lanes());
} else if (is_buffer_dtype_scalable) {
// The buffer is a scalable vector, while the index is not
n->dtype = n->buffer->dtype.with_scalable_vscale_factor(
n->buffer->dtype.vscale_factor() * index_dtype.lanes());
} else {
// Neither side is a scalable vector, multiply lanes
n->dtype = n->buffer->dtype.with_lanes(index_dtype.lanes() *
n->buffer->dtype.lanes());
}
}
}
/*!
* \brief collect the mapping from the buffer var to its allocate
*/
class AllocateCollector : public StmtExprVisitor {
private:
bool IsDynamicSharedMemory(Var buffer_var) {
StorageScope storage_scope =
runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
return storage_scope.rank == runtime::StorageRank::kShared &&
storage_scope.tag == ".dyn";
}
bool IsStaticSharedMemory(Var buffer_var) {
StorageScope storage_scope =
runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
return storage_scope.rank == runtime::StorageRank::kShared &&
storage_scope.tag == "";
}
public:
void VisitStmt_(const AllocateNode *op) final {
if (IsDynamicSharedMemory(op->buffer_var)) {
dyn_shmem_allocs_[op->buffer_var.get()] = op;
} else if (IsStaticSharedMemory(op->buffer_var)) {
static_shmem_allocs_[op->buffer_var.get()] = op;
}
StmtExprVisitor::VisitStmt_(op);
}
// The dynamic mapping from the original buffer var to its allocate
std::unordered_map<const VarNode *, const AllocateNode *> dyn_shmem_allocs_;
// The static mapping from the original buffer var to its allocate
std::unordered_map<const VarNode *, const AllocateNode *>
static_shmem_allocs_;
};
// Find a linear pattern of storage access
// Used for liveness analysis.
// Composite scopes(loop/thread_launch/IfThen) is represented by two points:
// before_scope -> scope_body -> after_scope
//
// The linear_seq_ stores before_scope and after_scope.
// The access to the arrays are stored at the after_scope point.
//
// Define "scope" as the body of For/thread_launch/IfThenElse
// This pass tries to detect last point that we need to keep memory
// alive under the same scope as allocate.
// The storage need to be kept alive between allocate and last access.
// The free point is only inserted at the same scope of allocate.
//
class LinearAccessPatternFinder final : public StmtExprVisitor {
public:
/*! \brief record the touch hist of statment. */
struct StmtEntry {
// The statment
const Object *stmt;
// The index in the linear_seq_ to point to end of the nested scope.
// This is only set to non-zero if stmt is a nested scope.
// if offset > 0, means this is the begin, the end entry is current_index +
// offset if offset < 0, means this is the end, the begin entry is
// current_index + offset
int64_t scope_pair_offset{0};
// The buffer variables this statment touched.
std::vector<const VarNode *> touched;
};
// The scope of each allocation
struct AllocEntry {
// The physical dimension of the allocation.
size_t num_physical_dimensions{0};
// scope level
size_t level{0};
// allocation stmt
const AllocateNode *alloc{nullptr};
};
void VisitStmt_(const AllocateNode *op) final {
size_t level = scope_.size();
const VarNode *buf = op->buffer_var.get();
AllocEntry entry;
entry.alloc = op;
entry.level = level;
// Since StorageRewrite occurs after StorageFlatten/FlattenBuffer,
// all allocations specify the extent of physical dimensions, and
// is 1 for flat memory spaces.
entry.num_physical_dimensions = op->extents.size();
alloc_info_[buf] = entry;
StmtExprVisitor::VisitStmt_(op);
}
void VisitStmt_(const BufferStoreNode *op) final {
scope_.push_back(StmtEntry());
// visit subexpr
StmtExprVisitor::VisitStmt_(op);
all_buffers_accessed_.insert(op->buffer.get());
// Add write access.
const VarNode *buffer_var = op->buffer->data.get();
auto it = alloc_info_.find(buffer_var);
if (it != alloc_info_.end() && it->second.alloc) {
ICHECK_LT(it->second.level, scope_.size());
scope_[it->second.level].touched.push_back(buffer_var);
ICHECK_EQ(op->buffer->axis_separators.size() + 1,
it->second.num_physical_dimensions)
<< "Buffer " << op->buffer->name << " is allocated with "
<< it->second.num_physical_dimensions
<< " physical dimensions, but is accessed as having "
<< op->buffer->axis_separators.size() + 1 << " physical dimensions"
<< std::endl;
}
StmtEntry e = scope_.back();
scope_.pop_back();
if (e.touched.size() != 0) {
e.stmt = op;
linear_seq_.push_back(e);
}
}
void VisitExpr_(const BufferLoadNode *op) final {
// Add write access.
StmtExprVisitor::VisitExpr_(op);
all_buffers_accessed_.insert(op->buffer.get());
const VarNode *buffer_var = op->buffer->data.get();
auto it = alloc_info_.find(buffer_var);
if (it != alloc_info_.end() && it->second.alloc) {
ICHECK_LT(it->second.level, scope_.size())
<< "Load memory in places other than store.";
scope_[it->second.level].touched.push_back(buffer_var);
ICHECK_EQ(op->buffer->axis_separators.size() + 1,
it->second.num_physical_dimensions)
<< "Buffer " << op->buffer->name << " is allocated with "
<< it->second.num_physical_dimensions
<< " physical dimensions, but is accessed as having "
<< op->buffer->axis_separators.size() + 1 << " physical dimensions"
<< std::endl;
}
}
void VisitStmt_(const EvaluateNode *op) final {
scope_.push_back(StmtEntry());
// visit subexpr
StmtExprVisitor::VisitStmt_(op);
StmtEntry e = scope_.back();
scope_.pop_back();
if (e.touched.size() != 0) {
e.stmt = op;
linear_seq_.push_back(e);
}
}
void VisitExpr_(const VarNode *buf) final {
// Directly reference to the variable count as a read.
auto it = alloc_info_.find(buf);
if (it != alloc_info_.end() && it->second.alloc) {
ICHECK_LT(it->second.level, scope_.size()) << " buf=" << buf->name_hint;
scope_[it->second.level].touched.push_back(buf);
}
}
template <typename T> void VisitNewScope(const T *op) {
scope_.push_back(StmtEntry());
StmtEntry e;
e.stmt = op;
int64_t begin_index = static_cast<int64_t>(linear_seq_.size());
// before scope.
linear_seq_.push_back(e);
StmtExprVisitor::VisitStmt_(op);
// after scope.
e.touched = std::move(scope_.back().touched);
scope_.pop_back();
int64_t end_index = static_cast<int64_t>(linear_seq_.size());
ICHECK_GT(end_index, begin_index);
e.scope_pair_offset = begin_index - end_index;
linear_seq_.push_back(e);
// record the pointer to end index.
ICHECK_NE(end_index, 0U);
linear_seq_[begin_index].scope_pair_offset = end_index - begin_index;
}
void VisitStmt_(const AttrStmtNode *op) final {
// Only record the outer most thread extent.
if (op->attr_key == tir::attr::thread_extent && !in_thread_env_) {
in_thread_env_ = true;
VisitNewScope(op);
in_thread_env_ = false;
} else if (op->attr_key == tir::attr::extern_scope) {
VisitNewScope(op);
} else if (op->attr_key == tir::attr::virtual_thread) {
VisitNewScope(op);
} else {
StmtExprVisitor::VisitStmt_(op);
}
}
void VisitStmt_(const IfThenElseNode *op) final { VisitNewScope(op); }
void VisitStmt_(const ForNode *op) final { VisitNewScope(op); }
void VisitStmt_(const WhileNode *op) final { VisitNewScope(op); }
void VisitStmt_(const AssertStmtNode *op) final { VisitNewScope(op); }
void VisitStmt_(const LetStmtNode *op) final { VisitNewScope(op); }
// linearized access sequence.
std::vector<StmtEntry> linear_seq_;
// The storage scope of each buffer
std::unordered_map<const VarNode *, AllocEntry> alloc_info_;
// A record of which Buffer objects have been accessed, to prune
// unused DeclBuffer instances.
std::unordered_set<const BufferNode *> all_buffers_accessed_;
private:
// Whether already in thread env.
bool in_thread_env_{false};
// The scope stack.
std::vector<StmtEntry> scope_;
};
// Verify if the statement can be run safely via inplace fashion
//
// Detect pattern: dst[index] = f(src[index])
//
// WARNING: the current detection algorithm cannot handle the case
// when a location in an array is written multiple times
//
// For example, the following program will pass the check,
// but we cannot make A and B to be the same array.
//
// A[0] = B[0] + 1
// A[0] = B[0] + 1
//
// The high level code generator needs to ensure that the generated
// code only write each location of the target array once.
//
// This is the case with IR generated by the current compute schedule.
// We explicitly return false if we find there is an extern block
// which can be arbitrary IR.
//
// Neve-the-less, inplace detector should be used with care in mind.
// We may also consider introduce a condition checker that checks
// if every index only visited once for an absolute sufficient condition.
//
// The code after inplace transformation is no longer idempotent.
//
class InplaceOpVerifier : public StmtExprVisitor {
public:
bool Check(const Object *stmt, const VarNode *dst, const VarNode *src) {
dst_ = dst;
src_ = src;
result_ = true;
if (stmt->IsInstance<AttrStmtNode>()) {
VisitStmt_(static_cast<const AttrStmtNode *>(stmt));
} else if (stmt->IsInstance<ForNode>()) {
VisitStmt_(static_cast<const ForNode *>(stmt));
} else if (stmt->IsInstance<IfThenElseNode>()) {
VisitStmt_(static_cast<const IfThenElseNode *>(stmt));
} else if (stmt->IsInstance<WhileNode>()) {
VisitStmt_(static_cast<const WhileNode *>(stmt));
} else if (stmt->IsInstance<BufferStoreNode>()) {
VisitStmt_(static_cast<const BufferStoreNode *>(stmt));
} else {
return false;
}
return result_;
}
using StmtExprVisitor::VisitStmt_;
void VisitStmt(const Stmt &n) final {
if (!result_)
return;
StmtExprVisitor::VisitStmt(n);
}
void VisitExpr(const PrimExpr &n) final {
if (!result_)
return;
StmtExprVisitor::VisitExpr(n);
}
void VisitExpr_(const VarNode *op) final {
// assume all opaque access is unsafe
if (op == dst_ || op == src_) {
result_ = false;
return;
}
}
void VisitStmt_(const BufferStoreNode *op) final {
++mem_nest_;
for (const auto &index : op->indices) {
this->VisitExpr(index);
}
--mem_nest_;
if (op->buffer->data.get() == dst_) {
store_ = op;
this->VisitExpr(op->value);
store_ = nullptr;
} else {
this->VisitExpr(op->value);
}
}
void VisitStmt_(const AttrStmtNode *op) final {
// always reject extern code
if (op->attr_key == tir::attr::extern_scope ||
op->attr_key == tir::attr::volatile_scope) {
result_ = false;
return;
}
StmtExprVisitor::VisitStmt_(op);
}
void VisitExpr_(const BufferLoadNode *op) final {
const VarNode *buf = op->buffer->data.get();
// cannot read from dst_ (no reduction)
if (buf == dst_) {
result_ = false;
return;
}
// do not allow indirect memory load
if (mem_nest_ != 0) {
result_ = false;
return;
}
if (src_ == buf) {
if (store_ == nullptr || store_->value.dtype() != op->dtype) {
result_ = false;
return;
}
ICHECK_EQ(store_->indices.size(), op->indices.size())
<< "Store/Load occur to the same buffer " << buf->name_hint
<< " with differing number of indices";
for (size_t i = 0; i < store_->indices.size(); i++) {
if (!tir::ExprDeepEqual()(store_->indices[i], op->indices[i])) {
result_ = false;
return;
}
}
}
++mem_nest_;
StmtExprVisitor::VisitExpr_(op);
--mem_nest_;
}
private:
// result of the check
bool result_{true};
// destination memory
const VarNode *dst_;
// source variable
const VarNode *src_;
// counter of load,
// it is not safe to inplace when there is nested load like A[B[i]]
int mem_nest_{0};
// The current store to be inspected
const BufferStoreNode *store_{nullptr};
};
/* \brief Rewrite and merge memory allocation.
*
* Using LinearAccessPatternFinder, determines which buffers could share an
* allocation. This includes both sequential usage of the same buffer and
* merging small allocations at the same scope into a single larger allocation.
* The merging of small allocations requires the codegen to cast the resulting
* value from the storage type to the output type after access.
*/
class StoragePlanRewriter : public StmtExprMutator {
public:
using StmtEntry = LinearAccessPatternFinder::StmtEntry;
using AllocEntry = LinearAccessPatternFinder::AllocEntry;
Stmt Rewrite(Stmt stmt, bool detect_inplace, bool enable_reuse,
bool reuse_require_exact_matched_dtype) {
detect_inplace_ = detect_inplace;
// plan the rewrite
LinearAccessPatternFinder finder;
finder(stmt);
this->LivenessAnalysis(finder.linear_seq_);
this->PlanMemory(finder.linear_seq_, finder.alloc_info_, enable_reuse,
reuse_require_exact_matched_dtype);
all_buffers_accessed_ = finder.all_buffers_accessed_;
this->PrepareNewAlloc();
// start rewrite
stmt = operator()(std::move(stmt));
if (attach_map_.count(nullptr)) {
return MakeAttach(attach_map_.at(nullptr), stmt);
}
return stmt;
}
template <typename Node> Node VisitBufferAccess(Node node) {
auto it = alloc_map_.find(node->buffer->data.get());
if (it != alloc_map_.end()) {
Buffer buf = RemapBuffer(node->buffer, it->second->alloc_var);
Array<PrimExpr> indices = node->indices;
indices.Set(indices.size() - 1,
RemapIndex(node->buffer->dtype, indices[indices.size() - 1],
it->second));
auto writer = node.CopyOnWrite();
writer->buffer = buf;
writer->indices = indices;
}
return node;
}
Buffer RemapBuffer(Buffer buf, Var new_backing_array) {
auto key = buf.get();
auto it = buffer_remap_.find(key);
if (it != buffer_remap_.end()) {
ICHECK_EQ(it->second->data.get(), new_backing_array.get())
<< "Cannot remap buffer " << buf->name << " to use backing array "
<< new_backing_array->name_hint << ", previously used backing array "
<< it->second->data->name_hint;
return it->second;
}
Buffer remapped = Buffer(
new_backing_array, buf->dtype, buf->shape, buf->strides,
buf->elem_offset, new_backing_array->name_hint, buf->data_alignment,
buf->offset_factor, buf->buffer_type, buf->axis_separators, buf->span);
buffer_remap_[key] = remapped;
return remapped;
}
Stmt VisitStmt_(const BufferStoreNode *op) final {
auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
return VisitBufferAccess(std::move(node));
}
PrimExpr VisitExpr_(const BufferLoadNode *op) final {
auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
return VisitBufferAccess(std::move(node));
}
PrimExpr VisitExpr_(const VarNode *op) final {
auto it = alloc_map_.find(op);
if (it != alloc_map_.end()) {
if (it->second->bits_offset != 0) {
LOG(WARNING)
<< "Use a merged buffer variable address, could cause error";
}
return it->second->alloc_var;
} else {
return GetRef<PrimExpr>(op);
}
}
PrimExpr VisitExpr_(const CallNode *op) final {
if (op->op.same_as(builtin::tvm_access_ptr())) {
ICHECK_EQ(op->args.size(), 5U);
DataType dtype = op->args[0].dtype();
const VarNode *buffer = op->args[1].as<VarNode>();
auto it = alloc_map_.find(buffer);
if (it == alloc_map_.end()) {
return StmtExprMutator::VisitExpr_(op);
}
const StorageEntry *se = it->second;
PrimExpr offset = this->VisitExpr(op->args[2]);
PrimExpr extent = this->VisitExpr(op->args[3]);
uint64_t elem_bits = dtype.bits() * dtype.lanes();
ICHECK_EQ(se->bits_offset % elem_bits, 0U);
if (se->bits_offset != 0) {
offset =
make_const(offset.dtype(), se->bits_offset / elem_bits) + offset;
}
return Call(op->dtype, op->op,
{op->args[0], se->alloc_var, offset, extent, op->args[4]});
} else {
return StmtExprMutator::VisitExpr_(op);
}
}
Stmt VisitStmt_(const AttrStmtNode *op) final {
if (op->attr_key == tir::attr::thread_extent ||
op->attr_key == tir::attr::virtual_thread ||
tir::attr::IsPragmaKey(op->attr_key)) {
// remake all the allocation at the attach scope.
if (attach_map_.count(op)) {
auto &svec = attach_map_[op];
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<AttrStmtNode>();
return AttrStmt(op->node, op->attr_key, op->value,
MakeAttach(svec, op->body));
} else {
return StmtExprMutator::VisitStmt_(op);
}
} else if (op->attr_key == tir::attr::volatile_scope) {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<AttrStmtNode>();
auto it = alloc_map_.find(op->node.as<VarNode>());
if (it == alloc_map_.end())
return stmt;
return AttrStmt(it->second->alloc_var, op->attr_key, op->value, op->body);
} else {
return StmtExprMutator::VisitStmt_(op);
}
}
Stmt VisitStmt_(const ForNode *op) final {
ICHECK(op->kind != ForKind::kVectorized)
<< "VectorizeLoop before LiftStorageAlloc";
// remake all the allocation at the attach scope.
if (attach_map_.count(op)) {
auto &svec = attach_map_[op];
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<ForNode>();
return For(op->loop_var, op->min, op->extent, op->kind,
MakeAttach(svec, op->body), op->thread_binding,
op->annotations);
} else {
return StmtExprMutator::VisitStmt_(op);
}
}
Stmt VisitStmt_(const AllocateNode *op) final {
return this->VisitStmt(op->body);
}
Stmt VisitStmt_(const DeclBufferNode *op) final {
if (hoisted_buffer_decls_.count(op->buffer.get()) ||
!all_buffers_accessed_.count(op->buffer.get())) {
return this->VisitStmt(op->body);
}
auto node = Downcast<DeclBuffer>(StmtExprMutator::VisitStmt_(op));
if (auto it = alloc_map_.find(op->buffer->data.get());
it != alloc_map_.end()) {
Buffer buf = RemapBuffer(op->buffer, it->second->alloc_var);
node.CopyOnWrite()->buffer = buf;
}
return std::move(node);
}
private:
struct StorageEntry {
// The scope that this alloc attaches after
// For shared/local memory it is beginning of the thread extent.
// for global memory it is nullptr, means beginning of everything.
const Object *attach_scope_{nullptr};
// The constant size of the buffer in bits, only used if it is constant
uint64_t const_nbits{0};
// The storage scope.
StorageScope scope;
// The physical dimensionality of the allocations. Since
// StorageRewrite is applied after StorageFlatten/FlattenBuffer,
// this is size of `AllocateNode::extents`. If moved
size_t ndim;
// Allocs that shares this entry.
std::vector<const AllocateNode *> allocs;
// The children of this entry, not including itself.
std::vector<StorageEntry *> merged_children;
// The replacement Allocate, if any. May also include associated
// DeclBuffer statement.
std::vector<Stmt> alloc_nest;
// The var expr of new allocation.
Var alloc_var;
// The allocation element type.
DataType elem_type;
// This is non-zero if this allocate is folded into another one
// the address(in bits) becomes alloc_var + bits_offset;
// can be effectively converted to the element type.
// We need to convert bit_offset to offset of specific element type later.
//
// We use bits(instead of bytes) to support non-conventional indexing in
// hardware. When we are merging buffer together, the bits_offset are set to
// be aligned to certain value given by the max_simd_bits property of the
// special memory.
//
// This allows effective sharing among different types as long as their
// alignment requirement fits into the max_simd_bits.
uint64_t bits_offset{0};
};
// Checks whether the storage_scope is especially tagged for a specific
// memory. Special memory is all combined into a single allocation.
bool IsSpecialTaggedMemory(const StorageScope &scope) {
return scope.tag.length() != 0 && scope.tag != ".dyn" &&
scope.tag != ".workspace" && scope.tag != ".vtcm";
}
// Alllocate entry of node.
// Event entry in liveness analysis
struct EventEntry {
// variables we generate
std::vector<const VarNode *> gen;
// variables we kill
std::vector<const VarNode *> kill;
};
Stmt MakeAttach(const std::vector<StorageEntry *> &svec, Stmt body) {
for (auto it = svec.rbegin(); it != svec.rend(); it++) {
body = MergeNest((*it)->alloc_nest, body);
}
return body;
}
// Remap the index
PrimExpr RemapIndex(DataType dtype, PrimExpr index, StorageEntry *e) {
if (e->bits_offset == 0)
return index;
uint64_t elem_bits = dtype.bits();
ICHECK_EQ(e->bits_offset % elem_bits, 0U);
return make_const(index.dtype(), e->bits_offset / elem_bits) + index;
}
// Prepare the new allocations
void PrepareNewAlloc() {
for (size_t i = 0; i < alloc_vec_.size(); ++i) {
StorageEntry *e = alloc_vec_[i].get();
attach_map_[e->attach_scope_].push_back(e);
}
// find allocation via attach map.
for (auto &kv : attach_map_) {
// find the element with the most amount of bytes.
std::vector<StorageEntry *> &vec = kv.second;
// try to find merge, for tagged memory
for (size_t i = 0; i < vec.size(); ++i) {
StorageEntry *e = vec[i];
if (IsSpecialTaggedMemory(e->scope)) {
ICHECK_NE(e->const_nbits, 0U)
<< "Special tagged memory must be const size";
for (size_t j = 0; j < i; ++j) {
if (e->scope == vec[j]->scope) {
vec[j]->merged_children.push_back(e);
break;
}
}
}
}
// Start allocation
for (size_t i = 0; i < vec.size(); ++i) {
StorageEntry *e = vec[i];
// already merged
if (e->bits_offset != 0)
continue;
if (e->merged_children.size() != 0) {
NewAllocTagMerged(e);
continue;
}
// Get the allocation size;
e->alloc_var = e->allocs[0]->buffer_var;
DataType alloc_type = e->allocs[0]->dtype;
for (const AllocateNode *op : e->allocs) {
if (op->dtype.lanes() > alloc_type.lanes()) {
alloc_type = op->dtype;
}
}
bool all_allocs_identical = std::all_of(
e->allocs.begin() + 1, e->allocs.end(),
[&](const AllocateNode *op) -> bool {
const AllocateNode *first = *e->allocs.begin();
if (op->dtype != first->dtype) {
return false;
}
if (op->extents.size() != first->extents.size()) {
return false;
}
ExprDeepEqual expr_equal;
for (size_t i = 0; i < op->extents.size(); i++) {
if (!expr_equal(op->extents[i], first->extents[i])) {
return false;
}
}
return true;
});
if (all_allocs_identical) {
// simply use the original allocation.
e->alloc_nest.push_back(
Allocate(e->alloc_var, alloc_type, e->allocs[0]->extents,
e->allocs[0]->condition, Evaluate(0)));
if (auto ptr = e->allocs[0]->body.as<DeclBufferNode>()) {
e->alloc_nest.push_back(DeclBuffer(
RemapBuffer(ptr->buffer, e->alloc_var), Evaluate(0)));
hoisted_buffer_decls_.insert(ptr->buffer.get());
}
if (IsSpecialTaggedMemory(e->scope)) {
MemoryInfo info = GetMemoryInfo(e->scope.to_string());
if (info.defined()) {
uint64_t total_elem = e->const_nbits / e->elem_type.bits();
ICHECK_LE(total_elem * e->elem_type.bits(), info->max_num_bits)
<< "Allocation exceed bound of memory tag "
<< e->scope.to_string();
}
}
} else {
// Build a merged allocation
PrimExpr combo_size;
for (const AllocateNode *op : e->allocs) {
ICHECK_EQ(op->extents.size(), 1)
<< "Buffer var " << op->buffer_var->name_hint
<< " was identified as a re-usable allocation, but has "
<< op->extents.size() << " physical dimensions. "
<< "Currently, only flat 1-d memory spaces should be "
"identified as re-usable "
"allocations.";
PrimExpr sz = op->extents[0];
auto nbits = op->dtype.bits() * op->dtype.lanes();
if (const auto *imm = sz.as<IntImmNode>()) {
if (imm->value > std::numeric_limits<int>::max() / nbits) {
LOG(WARNING) << "The allocation requires : " << imm->value
<< " * " << nbits
<< " bits, which is greater than the maximum of"
" int32. The size is cast to int64."
<< "\n";
sz = make_const(DataType::Int(64), imm->value);
}
}
// transform to bits
auto sz_nbits = sz * nbits;
if (combo_size.defined()) {
combo_size = max(combo_size, sz_nbits);
} else {
combo_size = sz_nbits;
}
}
// transform to alloc bytes
auto type_bits = alloc_type.bits() * alloc_type.lanes();
bool divided =
analyzer_.CanProve(indexmod(combo_size, type_bits) == 0);
combo_size = indexdiv(combo_size, type_bits);
// round up for can not divided
if (!divided) {
combo_size = combo_size + make_const(DataType::Int(32), 1);
}
combo_size = analyzer_.Simplify(combo_size);
e->alloc_nest.push_back(Allocate(e->alloc_var, alloc_type,
{combo_size}, const_true(),
Evaluate(0)));
if (IsSpecialTaggedMemory(e->scope)) {
MemoryInfo info = GetMemoryInfo(e->scope.to_string());
if (info.defined()) {
uint64_t total_elem = e->const_nbits / e->elem_type.bits();
ICHECK_LE(total_elem * e->elem_type.bits(), info->max_num_bits)
<< "Allocation exceed bound of memory tag "
<< e->scope.to_string();
}
}
}
}
}
}
// New allocation for merged data
void NewAllocTagMerged(StorageEntry *e) {
ICHECK_NE(e->scope.tag.length(), 0U);
// allocate with element type.
ICHECK_NE(e->const_nbits, 0U);
MemoryInfo info = GetMemoryInfo(e->scope.to_string());
uint64_t total_bits = e->const_nbits;
// By default, align to 32 bits.
size_t align = 32;
if (info.defined()) {
align = info->max_simd_bits;
}
// Always align to max_simd_bits
// so we can remap types by keeping this property
if (total_bits % align != 0) {
total_bits += align - (total_bits % align);
}
e->alloc_var = e->allocs[0]->buffer_var;
for (StorageEntry *child : e->merged_children) {
ICHECK_NE(child->const_nbits, 0U);
ICHECK_NE(total_bits, 0U);
child->bits_offset = total_bits;
child->alloc_var = e->alloc_var;
total_bits += child->const_nbits;
if (total_bits % align != 0) {
total_bits += align - (total_bits % align);
}
}
uint64_t type_bits = e->elem_type.bits() * e->elem_type.lanes();
PrimExpr alloc_size = make_const(e->allocs[0]->extents[0].dtype(),
(total_bits + type_bits - 1) / type_bits);
e->alloc_nest.push_back(Allocate(e->alloc_var, e->elem_type, {alloc_size},
const_true(), Evaluate(0)));
if (info.defined()) {
ICHECK_LE(total_bits, info->max_num_bits)
<< "Allocation exceed bound of memory tag " << e->scope.to_string();
}
}
// Liveness analysis to find gen and kill point of each variable.
void LivenessAnalysis(const std::vector<StmtEntry> &seq) {
// find kill point, do a reverse linear scan.
std::unordered_set<const VarNode *> touched;
for (size_t i = seq.size(); i != 0; --i) {
const StmtEntry &s = seq[i - 1];
for (const VarNode *buffer : s.touched) {
if (!touched.count(buffer)) {
touched.insert(buffer);
event_map_[s.stmt].kill.push_back(buffer);
}
}
}
// find gen point, do forward scan
touched.clear();
for (size_t i = 0; i < seq.size(); ++i) {
int64_t offset = seq[i].scope_pair_offset;
if (offset < 0)
continue;
const StmtEntry &s = seq[i + offset];
for (const VarNode *buffer : s.touched) {
if (!touched.count(buffer)) {
touched.insert(buffer);
event_map_[s.stmt].gen.push_back(buffer);
}
}
}
}
void PlanNewScope(const Object *op) {
if (thread_scope_ != nullptr) {
ICHECK(thread_scope_ == op);
// erase all memory atatched to this scope.
for (auto it = const_free_map_.begin(); it != const_free_map_.end();) {
if (it->second->attach_scope_ == op) {
it = const_free_map_.erase(it);
} else {
++it;
}
}
for (auto it = sym_free_list_.begin(); it != sym_free_list_.end();) {
if ((*it)->attach_scope_ == op) {
it = sym_free_list_.erase(it);
} else {
++it;
}
}
thread_scope_ = nullptr;
} else {
thread_scope_ = op;
}
}
// Memory plan algorithm
void
PlanMemory(const std::vector<StmtEntry> &seq,
const std::unordered_map<const VarNode *, AllocEntry> &alloc_info,
bool enable_reuse, bool reuse_require_exact_matched_dtype) {
std::unordered_set<const VarNode *> inplace_flag;
for (size_t i = 0; i < seq.size(); ++i) {
const StmtEntry &s = seq[i];
auto it = event_map_.find(seq[i].stmt);
// scope_pair_offset >= 0 means it is either
// - leaf stmt(offset = 0)
// - beginning of scope(offset < 0)
// In both cases, we need to handle the gen event correctly
if (it != event_map_.end() && seq[i].scope_pair_offset >= 0) {
// Inplace operation detection
// specially handle this
bool detect_inplace = detect_inplace_ && (it->second.gen.size() <= 2);
for (const VarNode *var : it->second.gen) {
ICHECK(alloc_info.count(var));
const AllocEntry &entry = alloc_info.at(var);
const AllocateNode *alloc = entry.alloc;
auto storage_scope =
StorageScope::Create(GetPtrStorageScope(GetRef<Var>(var)));
StorageEntry *dst_entry = nullptr;
// inplace detection
if (detect_inplace) {
// only one inplace var for s.stmt
bool inplace_found = false;
for (const VarNode *src : it->second.kill) {
if (!inplace_flag.count(src) && alloc_map_.count(src)) {
InplaceOpVerifier visitor;
StorageEntry *src_entry = alloc_map_.at(src);
if (src_entry->scope == storage_scope &&
src_entry->attach_scope_ == thread_scope_ &&
src_entry->elem_type == alloc->dtype.element_of() &&
visitor.Check(s.stmt, var, src)) {
uint64_t const_nbits =
static_cast<uint64_t>(alloc->ConstantAllocationSize()) *
alloc->dtype.bits() * alloc->dtype.lanes();
if (src_entry->const_nbits == const_nbits && !inplace_found) {
// successfully inplace
dst_entry = src_entry;
inplace_flag.insert(src);
inplace_found = true;
}
}
}
}
}
if (dst_entry == nullptr) {
dst_entry = FindAlloc(alloc, thread_scope_, storage_scope,
entry.num_physical_dimensions, enable_reuse,
reuse_require_exact_matched_dtype);
}
dst_entry->allocs.emplace_back(alloc);
alloc_map_[var] = dst_entry;
}
}
// enter/exit new scope
if (s.stmt->IsInstance<AttrStmtNode>()) {
const auto *op = static_cast<const AttrStmtNode *>(s.stmt);
if (op->attr_key == tir::attr::thread_extent ||
op->attr_key == tir::attr::virtual_thread ||
tir::attr::IsPragmaKey(op->attr_key)) {
PlanNewScope(op);
} else {
ICHECK(op->attr_key == tir::attr::extern_scope);
}
} else if (s.stmt->IsInstance<ForNode>()) {
const auto *op = static_cast<const ForNode *>(s.stmt);
if (op->kind == ForKind::kParallel) {
if (thread_scope_ == nullptr || thread_scope_ == op) {
PlanNewScope(op);
}
}
}
// scope_pair_offset <= 0 means it is either
// - leaf stmt(offset = 0)
// - end of scope(offset < 0)
// In both cases, we need to handle the kill event correctly
if (it != event_map_.end() && seq[i].scope_pair_offset <= 0) {
for (const VarNode *var : it->second.kill) {
// skip space which are already replaced by inplace
if (!inplace_flag.count(var)) {
this->Free(var);
}
}
}
}
}
// Allocate new storage entry.
StorageEntry *NewAlloc(const AllocateNode *op, const Object *attach_scope,
const StorageScope &scope, size_t const_nbits) {
ICHECK(op != nullptr);
// Re-use not successful, allocate a new buffer.
auto entry = std::make_unique<StorageEntry>();
entry->attach_scope_ = attach_scope;
entry->scope = scope;
entry->elem_type = op->dtype.element_of();
entry->const_nbits = const_nbits;
StorageEntry *e = entry.get();
alloc_vec_.emplace_back(std::move(entry));
return e;
}
StorageEntry *FindAlloc(const AllocateNode *op, const Object *attach_scope,
const StorageScope &scope,
size_t num_physical_dimensions, bool enable_reuse,
bool reuse_require_exact_matched_dtype) {
ICHECK(op != nullptr);
// skip plan for local variable,
// compiler can do a better job with register allocation.
const uint64_t match_range = 16;
uint64_t op_elem_bits = op->dtype.bits() * op->dtype.lanes();
uint64_t const_nbits =
static_cast<uint64_t>(op->ConstantAllocationSize() * op_elem_bits);
// If the size of the array isn't known at compile-time, it must
// have its own allocation with size determined at runtime.
bool is_known_size = (const_nbits != 0);
// Currently, only flat memory spaces can be re-used. Packing
// into N-d space (e.g. 2-d texture memory on GPUs) will require
// more in-depth algorithms.
bool is_flat_memory_space = (num_physical_dimensions == 1);
// disable reuse of small arrays, they will be lowered to registers in LLVM
// This rules only apply if we are using non special memory
bool is_small_array =
(scope.tag.length() == 0) &&
(scope.rank >= StorageRank::kWarp || op->dtype.is_handle() ||
(is_known_size && const_nbits <= 32));
if (!enable_reuse || is_small_array || !is_flat_memory_space) {
return NewAlloc(op, attach_scope, scope, const_nbits);
}
if (is_known_size) {
// constant allocation.
auto begin = const_free_map_.lower_bound(const_nbits / match_range);
auto mid = const_free_map_.lower_bound(const_nbits);
auto end = const_free_map_.upper_bound(const_nbits * match_range);
// start looking at the buffer that is bigger than the required size first
for (auto it = mid; it != end; ++it) {
StorageEntry *e = it->second;
if (e->attach_scope_ != attach_scope)
continue;
if (e->scope != scope)
continue;
// when not divided, no reuse, eg, float4 vs float3
if (e->bits_offset % op_elem_bits != 0)
continue;
if (reuse_require_exact_matched_dtype && e->elem_type != op->dtype) {
continue;
}
e->const_nbits = std::max(const_nbits, e->const_nbits);
const_free_map_.erase(it);
return e;
}
// then start looking at smaller buffers.
for (auto it = mid; it != begin;) {
--it;
StorageEntry *e = it->second;
if (e->attach_scope_ != attach_scope)
continue;
if (e->scope != scope)
continue;
if (e->elem_type != op->dtype.element_of())
continue;
if (reuse_require_exact_matched_dtype && e->elem_type != op->dtype) {
continue;
}
e->const_nbits = std::max(const_nbits, e->const_nbits);
const_free_map_.erase(it);
return e;
}
} else {
// Simple strategy: round roubin.
for (auto it = sym_free_list_.begin(); it != sym_free_list_.end(); ++it) {
StorageEntry *e = *it;
if (e->attach_scope_ != attach_scope)
continue;
if (e->scope != scope)
continue;
if (e->elem_type != op->dtype.element_of())
continue;
sym_free_list_.erase(it);
return e;
}
}
return NewAlloc(op, attach_scope, scope, const_nbits);
}
// simulated free.
void Free(const VarNode *var) {
auto it = alloc_map_.find(var);
ICHECK(it != alloc_map_.end());
StorageEntry *e = it->second;
ICHECK_NE(e->allocs.size(), 0U);
// disable reuse of small arrays, they will be lowered to registers in LLVM
// This rules only apply if we are using non special memory
if (e->scope.tag.length() == 0) {
// Disable sharing of local memory.
if (e->scope.rank >= StorageRank::kWarp ||
e->allocs[0]->dtype.is_handle())
return;
// disable reuse of small arrays
if (e->const_nbits > 0 && e->const_nbits <= 32)
return;
}
// normal free.
if (e->const_nbits != 0) {
const_free_map_.insert({e->const_nbits, e});
} else {
sym_free_list_.push_back(e);
}
}
// thread scope.
const Object *thread_scope_{nullptr};
// whether enable inplace detection.
bool detect_inplace_{false};
// Locations of free ops.
std::unordered_map<const Object *, EventEntry> event_map_;
// constant size free map.
std::multimap<uint64_t, StorageEntry *> const_free_map_;
// symbolic free list, for non constant items.
std::list<StorageEntry *> sym_free_list_;
// The allocation attach map
std::unordered_map<const Object *, std::vector<StorageEntry *>> attach_map_;
// The allocation assign map
std::unordered_map<const VarNode *, StorageEntry *> alloc_map_;
// The allocations
std::vector<std::unique_ptr<StorageEntry>> alloc_vec_;
// The buffer objects being remapped
std::unordered_map<const BufferNode *, Buffer> buffer_remap_;
// Buffers whose DeclBuffer has been hoisted to be adjacent to the new
// Allocate location
std::unordered_set<const BufferNode *> hoisted_buffer_decls_;
// Any buffers that is accessed at some point. DeclBuffer instances
// that do not appear in this list may be removed.
std::unordered_set<const BufferNode *> all_buffers_accessed_;
// analyzer
arith::Analyzer analyzer_;
};
/* Helper struct containing information on how a buffer is declared and used
*
*/
struct BufferVarInfo {
enum DeclarationLocation {
kPrimFuncParam = (1 << 0),
kPrimFuncBufferMap = (1 << 1),
kAllocateNode = (1 << 2),
kAllocateConstNode = (1 << 3),
kLetNode = (1 << 4),
};
// The tir::Var that represents this buffer.
Var var;
// The data type of an element of the buffer.
DataType element_dtype;
/* The extent of the buffer.
*
* If multidimensional, the extent of the last dimension of the buffer. If
* the size is unknown (e.g. pointer arguments to PrimFunc with no
* corresponding entry in buffer_map), then extent is zero.
*/
PrimExpr extent;
// Where the buffer was declared
DeclarationLocation declaration_location;
// When accessed, which element type is it accessed as. This may
// differ both in base type (e.g. int32* cast to float32* after
// packing in StorageRewrite) or in number of lanes (e.g. float16*
// cast to float16x4*).
std::unordered_set<DataType> access_dtype;
// Data types used for scalar reads. This is used to record vectorized read
// dtypes that can be shuffled for scalar reads when
// rewrite_scalar_read_to_vector_shuffle is enabled.
std::unordered_set<DataType> scalar_read_dtype;
DataType get_preferred_dtype() const {
std::unordered_set<DataType> base_access_dtype;
for (auto dtype : access_dtype) {
base_access_dtype.insert(dtype.element_of());
}
for (auto dtype : scalar_read_dtype) {
base_access_dtype.insert(dtype.element_of());
}
// If the array is accessed as multiple base types within a
// function, no point in changing the declared type. CodeGenC can
// handle this with a type-cast prior to indexing. Vulkan will
// raise an error at code-gen time, if a later pass doesn't split
// it out.
if (base_access_dtype.size() != 1) {
return element_dtype;
}
DataType preferred_base_type = *base_access_dtype.begin();
// If there is only one vectorizable size used to access the
// buffer, and if that access size is compatible with the array
// size, then the buffer is vectorizable. In the future, this
// could be improved to allow vectorized buffer access of size
// GCD(*lanes_used), if necessary.
// When there are scalar reads and no writes, access_dtype can be empty and
// we should avoid rewriting.
int preferred_lanes = element_dtype.lanes();
if (element_dtype.lanes() == 1 && (access_dtype.size() == 1)) {
int lanes = access_dtype.begin()->lanes();
// Check the scalar read dtypes are compatible with the vectorized access
// dtype.
for (auto dtype : scalar_read_dtype) {
if (dtype.lanes() % lanes != 0) {
return element_dtype;
}
}
arith::Analyzer analyzer_;
arith::ModularSet me = analyzer_.modular_set(extent);
if ((me->coeff % lanes == 0) && (me->base % lanes == 0)) {
preferred_lanes = lanes;
}
}
return preferred_base_type.with_lanes(preferred_lanes);
}
};
/* Checks whether buffers are accessed as scalar or vector parameters in a
* function.
*
*/
class VectorTypeAccessChecker : public StmtExprVisitor {
public:
/* Constructor
*
* @param params The parameters passed to a PrimFunc
*
* @param buffer_map The buffer_map associated with a PrimFunc
*
* @param allow_untyped_handles If a buffer or pointer variable is
* missing a type annotation, assume that it has the same underlying
* type as it is later accessed, with scalar element types.
*/
VectorTypeAccessChecker(const Array<tir::Var> &params,
const Map<Var, Buffer> &buffer_map,
bool allow_untyped_pointers = false,
bool detect_scalar_read_patterns = true)
: allow_untyped_pointers_(allow_untyped_pointers),
detect_scalar_read_patterns_(detect_scalar_read_patterns) {
// If a parameter is in the buffer map, we want to track the
// version in the map.
for (auto it : buffer_map) {
Buffer &buffer = it.second;
Var buffer_var = buffer->data;
DataType dtype = buffer->dtype;
PrimExpr extent =
buffer->shape.size() ? buffer->shape[buffer->shape.size() - 1] : 0;
OnArrayDeclaration(buffer_var, dtype, extent,
BufferVarInfo::kPrimFuncParam);
}
// If a pointer parameter isn't in the buffer map, then we want to
// track the parameter itself.
for (Var buffer_var : params) {
auto pointer_type = GetPointerType(buffer_var->type_annotation);
if (pointer_type.has_value() && (buffer_map.count(buffer_var) == 0)) {
DataType dtype = pointer_type.value();
PrimExpr extent = 0;
OnArrayDeclaration(buffer_var, dtype, extent,
BufferVarInfo::kPrimFuncBufferMap);
}
}
}
void VisitExpr_(const BufferLoadNode *op) final {
OnArrayAccess(op->dtype, op->buffer->data.get(), op->indices,
/*is_buffer_load=*/true);
StmtExprVisitor::VisitExpr_(op);
}
void VisitStmt_(const BufferStoreNode *op) final {
OnArrayAccess(op->value.dtype(), op->buffer->data.get(), op->indices,
/*is_buffer_load=*/false);
StmtExprVisitor::VisitStmt_(op);
}
void VisitExpr_(const CallNode *op) final {
if (op->op.same_as(builtin::tvm_access_ptr())) {
DataType dtype = op->args[0].dtype();
const VarNode *buffer = op->args[1].as<VarNode>();
PrimExpr index = op->args[2];
OnArrayAccess(dtype, buffer, {index}, false);
} else if (op->op.same_as(builtin::address_of())) {
if (auto load = op->args[0].as<BufferLoadNode>()) {
OnArrayAccess(load->dtype, load->buffer->data.get(), load->indices,
/*is_buffer_load=*/false);
}
}
StmtExprVisitor::VisitExpr_(op);
}
void VisitStmt_(const AllocateNode *op) final {
const Array<PrimExpr> &extents = op->extents;
PrimExpr extent = extents[extents.size() - 1];
OnArrayDeclaration(op->buffer_var, op->dtype, extent,
BufferVarInfo::kAllocateNode);
StmtExprVisitor::VisitStmt_(op);
}
void VisitStmt_(const AllocateConstNode *op) final {
const Array<PrimExpr> &extents = op->extents;
PrimExpr extent =
extents.size() ? extents[extents.size() - 1] : NullValue<PrimExpr>();
OnArrayDeclaration(op->buffer_var, op->dtype, extent,
BufferVarInfo::kAllocateConstNode);
StmtExprVisitor::VisitStmt_(op);
}
void VisitExpr_(const LetNode *op) final {
HandleLetNode(op->var);
StmtExprVisitor::VisitExpr_(op);
}
void VisitStmt_(const LetStmtNode *op) final {
HandleLetNode(op->var);
StmtExprVisitor::VisitStmt_(op);
}
void HandleLetNode(Var let_var) {
if (let_var->dtype.is_handle()) {
auto pointer_type = GetPointerType(let_var->type_annotation);
if (pointer_type.has_value()) {
OnArrayDeclaration(let_var, pointer_type.value(), 0,
BufferVarInfo::kLetNode);
} else if (allow_untyped_pointers_) {
OnArrayDeclaration(let_var, let_var->dtype, 0, BufferVarInfo::kLetNode);
} else {
LOG(FATAL) << "Let statement of variable " << let_var->name_hint
<< " is missing a type annotation, "
<< "or type annotation is not a pointer to primitive";
}
}
}
/* Update the type map for a buffer based on its declaration
*
* @param buffer The VarNode representing the buffer.
*
* @param element_dtype The dtype of a single element of the buffer.
* If unknown, when used with the allow_untyped_handles option,
* should be a handle dtype.
*
* @param extent The extent of the buffer. Zero if size is unknown.
*
* @param declaration_location How the buffer was allocated, so that
* some locations can be rewritten without others.
*/
void
OnArrayDeclaration(Var buffer, DataType element_dtype, PrimExpr extent,
BufferVarInfo::DeclarationLocation declaration_location) {
ICHECK(info_map_.find(buffer.get()) == info_map_.end())
<< "Array declaration of " << buffer->name_hint
<< " occurred multiple times.";
if (element_dtype == DataType::Bool()) {
element_dtype = DataType::Int(8).with_lanes(element_dtype.lanes());
}
info_map_[buffer.get()] =
BufferVarInfo{buffer, element_dtype, extent, declaration_location};
}
/* Update the type map for a buffer based on its usage
*
* @param value_dtype The dtype of the value being stored to or
* loaded from the buffer.
*
* @param buffer The VarNode representing the buffer.
*
* @param indices The index at which the value is being stored/loaded.
*
* @param is_buffer_load Whether the access is BufferLoad
*/
void OnArrayAccess(DataType value_dtype, const VarNode *buffer,
const Array<PrimExpr> &indices, bool is_buffer_load) {
auto it = info_map_.find(buffer);
ICHECK(it != info_map_.end())
<< "Load/Store of buffer " << buffer->name_hint << " (" << buffer
<< ") occurred before its declaration.";
if (value_dtype.is_scalable_vector()) {
// Scalable types are not currently supported in storage_rewrite. Scalable
// buffer accesses are not currently checked and therefore are not
// rewritten.
return;
}
BufferVarInfo &var_info = it->second;
if (value_dtype.element_of() == DataType::Bool()) {
value_dtype = DataType::Int(8).with_lanes(value_dtype.lanes());
}
if (var_info.element_dtype.is_handle()) {
ICHECK(allow_untyped_pointers_)
<< "Variable " << buffer->name_hint
<< " was missing a type annotation in its declaration";
var_info.element_dtype = value_dtype.element_of();
}
for (int i = 0; i < static_cast<int>(indices.size()) - 1; i++) {
ICHECK(indices[i].dtype().is_scalar())
<< "Only the last index of a buffer access may be a vector type.";
}
int index_lanes = indices.size() ? indices.back().dtype().lanes() : 1;
DataType access_dtype = value_dtype;
int lanes_used = var_info.element_dtype.lanes();
// This can happen due to a previous pass that had rewrite_store_load =
// false. This occurs from the StorageRewrite in tvm::lower, followed by
// the PointerValueTypeRewrite in BuildSPIRV. The rewrite_store_load =
// false is necessary because the C-based codegens do not yet support
// vectorized pointer types (e.g. float16x4*). Once they do, this if
// statement should instead be replaced by the below ICHECK_EQ.
if (index_lanes * var_info.element_dtype.lanes() != value_dtype.lanes()) {
ICHECK_EQ(index_lanes, value_dtype.lanes());
lanes_used = 1;
var_info.element_dtype = var_info.element_dtype.with_lanes(1);
}
// TODO(Lunderberg): Uncomment this check once it can be applied.
// See https://discuss.tvm.apache.org/t/pre-rfc-vectorized-tir-buffers/10615
// for discussion.
// ICHECK_EQ(index_lanes * var_info.element_dtype.lanes(),
// value_dtype.lanes())
// << "Attempting to retrieve " << value_dtype.lanes() << " lanes of
// data with "
// << index_lanes << " indices into an array whose elements have "
// << var_info.element_dtype.lanes() << " lanes. "
// << "Expected output with " << index_lanes *
// var_info.element_dtype.lanes()
// << " lanes.";
// If the index is a RampNode with stride of 1 and offset
// divisible by the number of number of lanes, and the predicate
// does not apply any masking, then this array access could be
// vectorized.
if (indices.size()) {
const RampNode *ramp_index = indices[indices.size() - 1].as<RampNode>();
if (ramp_index && is_one(ramp_index->stride)) {
if (ramp_index->lanes->IsInstance<IntImmNode>()) {
int lanes =
static_cast<int>(Downcast<IntImm>(ramp_index->lanes)->value);
arith::ModularSet me = analyzer_.modular_set(ramp_index->base);
if ((me->coeff % lanes == 0) && (me->base % lanes == 0)) {
lanes_used = lanes;
}
}
}
}
if (detect_scalar_read_patterns_ && is_buffer_load && indices.size()) {
const PrimExpr last_dim_index = indices[indices.size() - 1];
if (last_dim_index.dtype().lanes() == 1) {
arith::ModularSet me = analyzer_.modular_set(last_dim_index);
var_info.scalar_read_dtype.emplace(access_dtype.with_lanes(me->coeff));
return;
}
}
var_info.access_dtype.insert(access_dtype.with_lanes(lanes_used));
}
// Map of buffer variable information determined
std::unordered_map<const VarNode *, BufferVarInfo> info_map_;
//
bool allow_untyped_pointers_{false};
// Whether to detect scalar read patterns for rewriting to vector shuffle
bool detect_scalar_read_patterns_{true};
// internal analyzer
arith::Analyzer analyzer_;
};
/* \brief Rewrites buffer/pointer variables from scalar types to vectorized
* types.
*
* Some runtimes do not allow casting between composite types and the underlying
* base type (e.g. Vulkan, casting from 1-lane float16* to 4-lane float16x4*).
* In these cases, in order to have vectorized load/store on an array, the
* element type of that array must be vectorized. This is in contrast to
* C-style runtimes, in which `float16x4* vec = *(float16x4*)(float_arr +
* offset)` is valid.
*
* By default, VectorTypeRewriter will attempt to rewrite all buffer variables
* to vectorized access, if the load/store occurring in the PrimFunc are all
* vectorized. This includes adjusting the indices being used to access the
* array. (e.g. If `float16* scalar_arr` is being converted to `float16x4*
* vec_arr`, then `scalar_arr[Ramp(offset, 1, 4)]` will be converted to
* `vec_arr[offset/4]`.)
*
* Currently, several of the C-style runtimes do not support buffers whose
* elements are vectorized types, or rely on the presence of the Ramp nodes to
* identify vectorized loads. The boolean parameters in the constructor are to
* mimic the previous behavior of VectorTypeRewriter, to avoid breaking these
* runtimes. Once all runtimes support vectorized buffer elements, these
* parameters can be removed.
*/
class VectorTypeRewriter : public StmtExprMutator {
public:
/* Constructor
*
* @param checker The VectorTypeAccessChecker that has previously read out
* information from the PrimFunc
*
* @param rewrite_params Whether pointer-type parameters passed into the
* function should be rewritten from scalar types to vectorized types.
*
* @param rewrite_buffer_map Whether buffers present in the buffer_map should
* have their data variable be rewritten from scalar types to vectorized
* types.
*
* @param rewrite_allocate_node Whether the buffer variable associated with
* AllocateNodes should be rewritten from scalar types to vectorized types.
*
* @param rewrite_indices Whether the indices to the Load and Store nodes
* should be rewritten to correspond to the new buffer_var type.
*
* @param rewrite_let_node Whether pointer declarations in let nodes
* should be re-written.
*/
VectorTypeRewriter(
const std::unordered_map<const VarNode *, BufferVarInfo> &info_map,
bool rewrite_params = true, bool rewrite_buffer_map = true,
bool rewrite_allocate_node = true, bool rewrite_indices = true,
bool rewrite_let_node = true, bool rewrite_allocate_const_node = true,
bool rewrite_scalar_read_to_vector_shuffle = true)
: rewrite_indices_(rewrite_indices) {
int rewrite_mask = 0;
if (rewrite_params) {
rewrite_mask |= BufferVarInfo::kPrimFuncParam;
}
if (rewrite_buffer_map) {
rewrite_mask |= BufferVarInfo::kPrimFuncBufferMap;
}
if (rewrite_allocate_node) {
rewrite_mask |= BufferVarInfo::kAllocateNode;
}
if (rewrite_let_node) {
rewrite_mask |= BufferVarInfo::kLetNode;
}
if (rewrite_allocate_const_node) {
rewrite_mask |= BufferVarInfo::kAllocateConstNode;
}
// Rewrite any buffer variables whose preferred type isn't their current
// type.
for (const auto &pair : info_map) {
const auto &var_info = pair.second;
DataType preferred = var_info.get_preferred_dtype();
if (preferred != var_info.element_dtype &&
(rewrite_mask & var_info.declaration_location)) {
Var old_buffer_var = var_info.var;
Var new_buffer_var(old_buffer_var->name_hint,
PointerType(PrimType(preferred),
GetPtrStorageScope(old_buffer_var)),
old_buffer_var->span);
rewrite_map_[var_info.var.get()] = {var_info.var, new_buffer_var,
var_info.element_dtype, preferred};
}
}
}
/*!
* \brief Mutator for BufferLoad or BufferStore.
* \return The rewritten node and the shuffle index. (Only for BufferLoad)
* When the shuffle index is non-negative, the caller should generate Shuffle
* to extract the element from the vector.
*/
template <typename Node> std::pair<Node, int> VisitBufferAccess(Node node) {
int shuffle_index = -1;
if (!rewrite_indices_) {
return {node, shuffle_index};
}
auto it = rewrite_map_.find(node->buffer->data.get());
if (it == rewrite_map_.end()) {
return {node, shuffle_index};
}
const auto &info = it->second;
Array<PrimExpr> indices = node->indices;
const PrimExpr &last_dim_index = indices[indices.size() - 1];
const RampNode *ramp_index = indices[indices.size() - 1].as<RampNode>();
if (node->buffer->dtype.is_scalable_vector() ||
last_dim_index.dtype().is_scalable_vector()) {
// Scalable types are not currently supported in storage_rewrite. Scalable
// buffer accesses are not currently checked and therefore are not
// rewritten.
return {node, shuffle_index};
}
if (ramp_index && is_one(ramp_index->stride) &&
ramp_index->lanes->IsInstance<IntImmNode>()) {
int lanes = static_cast<int>(Downcast<IntImm>(ramp_index->lanes)->value);
PrimExpr new_index =
ramp_index->base / make_const(ramp_index->base.dtype(), lanes);
if (lanes != info.factor()) {
ICHECK(info.factor() && lanes % info.factor() == 0);
int new_lanes = lanes / info.factor();
new_index = Ramp(new_index * new_lanes, ramp_index->stride, new_lanes,
ramp_index->span);
}
indices.Set(indices.size() - 1, new_index);
} else if (last_dim_index.dtype().lanes() == 1 && info.factor() > 1) {
arith::ModularSet me = analyzer_.modular_set(last_dim_index);
ICHECK(me->coeff == 0 || info.factor() % me->coeff == 0);
PrimExpr new_index =
last_dim_index / make_const(last_dim_index.dtype(), info.factor());
shuffle_index = me->base % info.factor();
;
indices.Set(indices.size() - 1, new_index);
}
auto writer = node.CopyOnWrite();
writer->buffer = RemapBuffer(node->buffer);
writer->indices = indices;
return {node, shuffle_index};
}
PrimExpr VisitExpr_(const BufferLoadNode *op) final {
auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
auto [modified, shuffle_index] = VisitBufferAccess(node);
// Not needed for BufferStoreNode, so we can't just call
// LegalizeDtype() in VisitBufferAccess.
if (node.same_as(modified)) {
return std::move(node);
} else {
auto writer = modified.CopyOnWrite();
// writer->LegalizeDType();
LegalizeBufferLoadDType(writer);
if (shuffle_index >= 0) {
return Shuffle::ExtractElement(std::move(modified), shuffle_index);
}
return std::move(modified);
}
}
Stmt VisitStmt_(const BufferStoreNode *op) final {
auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
auto [modified, shuffle_index] = VisitBufferAccess(std::move(node));
ICHECK(shuffle_index < 0);
return std::move(modified);
}
Stmt VisitStmt_(const LetStmtNode *op) final {
auto it = rewrite_map_.find(op->var.get());
PrimExpr value = this->VisitExpr(op->value);
Stmt body = this->VisitStmt(op->body);
Var var = (it == rewrite_map_.end()) ? op->var : it->second.new_buffer_var;
if (var.same_as(op->var) && value.same_as(op->value) &&
body.same_as(op->body)) {
return GetRef<Stmt>(op);
}
return LetStmt(var, value, body);
}
Buffer RemapBuffer(Buffer buf) {
auto cache_key = buf.get();
auto cache_it = buffer_map_.find(cache_key);
if (cache_it != buffer_map_.end()) {
return cache_it->second;
}
auto info_it = rewrite_map_.find(buf->data.get());
if (info_it != rewrite_map_.end()) {
auto &info = info_it->second;
Array<PrimExpr> shape = buf->shape;
PrimExpr last_dim = shape[shape.size() - 1];
shape.Set(shape.size() - 1,
last_dim / make_const(last_dim.dtype(), info.factor()));
auto writer = buf.CopyOnWrite();
writer->data = info.new_buffer_var;
writer->dtype = info.new_element_dtype;
writer->shape = shape;
}
buffer_map_[cache_key] = buf;
return buf;
}
PrimExpr VisitExpr_(const CallNode *op) final {
if (op->op.same_as(builtin::tvm_access_ptr())) {
PrimExpr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<CallNode>();
if (!rewrite_indices_) {
return expr;
}
const VarNode *buffer_var = op->args[1].as<VarNode>();
auto it = rewrite_map_.find(buffer_var);
if (it == rewrite_map_.end()) {
return expr;
}
const auto &info = it->second;
PrimExpr index = op->args[2];
PrimExpr extent = op->args[3];
PrimExpr flag = op->args[4];
PrimExpr e_dtype = tir::TypeAnnotation(info.new_element_dtype);
int factor = info.factor();
extent = extent / make_const(extent.dtype(), factor);
index = index / make_const(index.dtype(), factor);
Array<PrimExpr> acc_args{e_dtype, info.new_buffer_var, index, extent,
flag};
return Call(info.new_element_dtype, builtin::tvm_access_ptr(), acc_args);
} else {
return StmtExprMutator::VisitExpr_(op);
}
}
Stmt VisitStmt_(const AllocateNode *op) final {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<AllocateNode>();
auto it = rewrite_map_.find(op->buffer_var.get());
if (it == rewrite_map_.end()) {
return stmt;
}
const auto &info = it->second;
Var new_buffer_var = info.new_buffer_var;
Array<PrimExpr> extents = op->extents;
PrimExpr last_extent = extents[extents.size() - 1];
extents.Set(extents.size() - 1,
last_extent / make_const(last_extent.dtype(), info.factor()));
return Allocate(new_buffer_var, info.new_element_dtype, extents,
op->condition, op->body);
}
Stmt VisitStmt_(const AllocateConstNode *op) final {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<AllocateConstNode>();
auto it = rewrite_map_.find(op->buffer_var.get());
if (it == rewrite_map_.end()) {
return stmt;
}
const auto &info = it->second;
Var new_buffer_var = info.new_buffer_var;
int factor = info.new_element_dtype.lanes() / op->dtype.lanes();
Array<PrimExpr> extents = op->extents;
extents.Set(extents.size() - 1, extents[extents.size() - 1] /
make_const(extents[0].dtype(), factor));
return AllocateConst(new_buffer_var, info.new_element_dtype, extents,
op->data, op->body);
}
/* Update the parameters and all remaining variable references
*
* Should be called after calling operator() on the body of the
* function.
*
* @param func A pointer to the PrimFunc being modified.
*/
void Finalize(PrimFunc *func_ptr) {
ICHECK(func_ptr) << "Finalize expects a non-null pointer";
auto &func = *func_ptr;
auto *n = func.CopyOnWrite();
// Remap any remaining references to the old buffer variables
Map<Var, Var> var_remap;
for (const auto &pair : rewrite_map_) {
const auto &info = pair.second;
var_remap.Set(info.old_buffer_var, info.new_buffer_var);
}
n->body = Substitute(n->body, var_remap);
// Remap the argument list to use the new buffer variables.
Array<Var> new_params;
for (const auto &old_param : n->params) {
auto it = rewrite_map_.find(old_param.get());
if (it == rewrite_map_.end()) {
new_params.push_back(old_param);
} else {
const auto &info = it->second;
new_params.push_back(info.new_buffer_var);
}
}
n->params = new_params;
// Remap the Buffer objects in PrimFunc::buffer_map so that the
// buffers use the new buffer variables
Map<Var, Buffer> new_buffer_map;
for (const auto &pair : n->buffer_map) {
Var key = pair.first;
Buffer old_buffer = pair.second;
Var old_var = old_buffer->data;
Buffer new_buffer = RemapBuffer(old_buffer);
new_buffer_map.Set(key, new_buffer);
}
n->buffer_map = new_buffer_map;
}
private:
struct RewriteInfo {
Var old_buffer_var;
Var new_buffer_var;
DataType old_element_dtype;
DataType new_element_dtype;
int factor() const {
int old_lanes = old_element_dtype.lanes();
int new_lanes = new_element_dtype.lanes();
ICHECK_EQ(new_lanes % old_lanes, 0);
return new_lanes / old_lanes;
}
};
bool rewrite_indices_{true};
std::unordered_map<const VarNode *, RewriteInfo> rewrite_map_;
std::unordered_map<const BufferNode *, Buffer> buffer_map_;
arith::Analyzer analyzer_;
};
// Rewrite allocates, pointer parameters, and buffer map into vectorized
// versions if each access into a buffer is the same vector type.
PrimFunc PointerValueTypeRewrite(
PrimFunc f, bool allow_untyped_pointers = false, bool rewrite_params = true,
bool rewrite_buffer_map = true, bool rewrite_allocate_node = true,
bool rewrite_indices = true, bool rewrite_let_node = true,
bool rewrite_allocate_const_node = true,
bool rewrite_scalar_read_to_vector_shuffle = true) {
VectorTypeAccessChecker checker(f->params, f->buffer_map,
allow_untyped_pointers,
rewrite_scalar_read_to_vector_shuffle);
checker(f->body);
VectorTypeRewriter rewriter(
checker.info_map_, rewrite_params, rewrite_buffer_map,
rewrite_allocate_node, rewrite_indices, rewrite_let_node,
rewrite_allocate_const_node, rewrite_scalar_read_to_vector_shuffle);
PrimFuncNode *n = f.CopyOnWrite();
n->body = rewriter(std::move(n->body));
rewriter.Finalize(&f);
return f;
}
using namespace tir::transform;
namespace transform {
Pass StorageRewrite() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
bool enable_reuse = true;
bool reuse_require_exact_matched_dtype = false;
bool merge_static_smem =
ctx->GetConfig<Bool>("tir.merge_static_smem", Bool(false)).value();
AllocateCollector collector;
collector(f->body);
bool has_dynamic = collector.dyn_shmem_allocs_.size() > 1;
if (has_dynamic || merge_static_smem) {
// For IRModule utilizing dynamic shared memory, reuse is not enabled
// Because dynamic doesn't require maintaining the readability and
// it benefits from a more optimized allocation strategy through the
// Pass `MergeSharedMemoryAllocations`.
// When `merge_static_smem` is true, we will reuse and merge shared
// memory in a dedicated pass `MergeSharedMemoryAllocations`.
// And so we don't enable reuse in this pass.
enable_reuse = false;
}
Optional<Target> target = f->GetAttr<Target>("target");
if (target.defined() && (target.value()->kind->name == "vulkan" ||
target.value()->kind->name == "webgpu")) {
// Require exactly same-dtype matching in smem reuse for Vulkan and WebGPU
reuse_require_exact_matched_dtype = true;
}
auto *n = f.CopyOnWrite();
n->body =
StoragePlanRewriter().Rewrite(std::move(n->body), true, enable_reuse,
reuse_require_exact_matched_dtype);
// Parameters may not be rewritten, but internal allocations may.
// Vectorization of AllocateConst is currently disabled, as it has
// indexing issues for types that include padding (e.g. int8x3
// padded out to 32 bits) would require either rewriting
// AllocateConst::data, or would require the code generators to
// handle vectorized constants.
return PointerValueTypeRewrite(std::move(f), true, false, false, false,
true, true, false, false);
};
return CreatePrimFuncPass(pass_func, 0, "tir.StorageRewrite", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.StorageRewrite", StorageRewrite);
});
Pass PointerValueTypeRewrite() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
return tl::PointerValueTypeRewrite(std::move(f));
};
return CreatePrimFuncPass(pass_func, 0, "tl.PointerValueTypeRewrite", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.PointerValueTypeRewrite",
PointerValueTypeRewrite);
});
} // namespace transform
} // namespace tl
} // namespace tvm
/*!
* \file thread_storage_sync.cc
*/
#include <tvm/runtime/registry.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
......@@ -269,7 +270,7 @@ private:
scope_.pop_back();
s.access.insert(s.access.end(), v.begin(), v.end());
num_partial_threads_ = NullOpt;
num_partial_threads_ = std::nullopt;
} else {
TileLangStorageAccessVisitor::VisitStmt_(op);
}
......@@ -371,8 +372,11 @@ Pass TileLangThreadPartialSync(String storage_scope) {
return CreatePrimFuncPass(pass_func, 0, "tl.ThreadPartialSync", {});
}
TVM_REGISTER_GLOBAL("tl.transform.ThreadPartialSync")
.set_body_typed(TileLangThreadPartialSync);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.ThreadPartialSync",
TileLangThreadPartialSync);
});
} // namespace transform
} // namespace tl
......
......@@ -20,7 +20,8 @@
/*!
* \file thread_storage_sync.cc
*/
#include <tvm/runtime/registry.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
......@@ -367,7 +368,7 @@ private:
scope_.pop_back();
s.access.insert(s.access.end(), v.begin(), v.end());
num_partial_threads_ = NullOpt;
num_partial_threads_ = std::nullopt;
} else {
TileLangStorageAccessVisitor::VisitStmt_(op);
}
......@@ -786,7 +787,10 @@ tvm::transform::Pass ThreadSync(String storage_scope) {
return CreatePrimFuncPass(pass_func, 0, "tl.ThreadSync", {});
}
TVM_REGISTER_GLOBAL("tl.transform.ThreadSync").set_body_typed(ThreadSync);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.ThreadSync", ThreadSync);
});
} // namespace transform
} // namespace tl
......
......@@ -22,7 +22,8 @@
*/
// Loop vectorizer as in Halide pipeline.
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
......@@ -631,7 +632,7 @@ public:
return Scalarize(GetRef<Stmt>(op));
}
Stmt then_case = this->VisitStmt(op->then_case);
Optional<Stmt> else_case = NullOpt;
Optional<Stmt> else_case = std::nullopt;
if (op->else_case) {
else_case = this->VisitStmt(op->else_case.value());
}
......@@ -688,10 +689,6 @@ public:
stmt = Substitute(stmt, {{var_, idx}});
return For(idx, IntImm(var_->dtype, 0), var_lanes_, ForKind::kSerial, stmt);
}
// ProducerStore
Stmt VisitStmt_(const ProducerStoreNode *op) final {
LOG(FATAL) << "ProducerProvide cannot appear in a TIR PrimFunc";
}
private:
// analyzer
......@@ -787,6 +784,10 @@ private:
}
};
inline bool TargetHasSVE() {
return Target::Current()->GetFeature<Bool>("has_sve").value_or(false);
}
class LoopVectorizer : public StmtMutator {
public:
Stmt VisitStmt_(const ForNode *op) final {
......@@ -796,7 +797,7 @@ public:
if (!extent_as_int || extent_as_int->value < 1) {
bool is_scalable_expr =
CheckContains::ExprContains(op->extent, arith::IsVScaleCall);
ICHECK(is_scalable_expr && arith::TargetHasSVE())
ICHECK(is_scalable_expr && TargetHasSVE())
<< "Failed to vectorize loop with extent " << op->extent
<< " for target " << Target::Current();
}
......@@ -837,7 +838,10 @@ tvm::transform::Pass VectorizeLoop(bool enable_vectorize = true) {
return CreatePrimFuncPass(pass_func, 0, "tl.VectorizeLoop", {});
}
TVM_REGISTER_GLOBAL("tl.transform.VectorizeLoop").set_body_typed(VectorizeLoop);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.VectorizeLoop", VectorizeLoop);
});
} // namespace tl
} // namespace tvm
......@@ -5,6 +5,7 @@
#include "arith/ir_visitor_with_analyzer.h"
#include "tir/analysis/var_use_def_analysis.h"
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
......@@ -447,7 +448,7 @@ private:
order_anno.push_back(Integer(op_info.order));
stage_anno.push_back(Integer(op_info.stage));
}
Map<String, ObjectRef> for_annotations = op->annotations;
Map<String, Any> for_annotations = op->annotations;
for_annotations.erase("tl_pipeline_group");
for_annotations.Set("software_pipeline_order", order_anno);
for_annotations.Set("software_pipeline_stage", stage_anno);
......@@ -636,9 +637,9 @@ private:
Stmt VisitStmt_(const ForNode *op) final {
int num_stages = 1;
auto num_stages_anno = op->annotations.Get("num_stages");
if (num_stages_anno.defined()) {
ICHECK(num_stages_anno.as<IntImmNode>());
num_stages = static_cast<int>(num_stages_anno.as<IntImmNode>()->value);
if (num_stages_anno) {
ICHECK(num_stages_anno->as<IntImmNode>());
num_stages = static_cast<int>(num_stages_anno->as<IntImmNode>()->value);
ICHECK(num_stages_ == 1) << "Nested pipeline not supported.";
}
loop_stack_.emplace_back(op->loop_var, op->extent);
......@@ -648,16 +649,16 @@ private:
Array<Integer> stage_info_array;
auto group_anno = op->annotations.Get("tl_pipeline_group");
if (group_anno.defined()) {
group_info_array = Downcast<Array<Array<Integer>>>(group_anno);
if (group_anno) {
group_info_array = Downcast<Array<Array<Integer>>>(group_anno.value());
}
auto order_anno = op->annotations.Get("tl_pipeline_order");
if (order_anno.defined()) {
order_info_array = Downcast<Array<Integer>>(order_anno);
if (order_anno) {
order_info_array = Downcast<Array<Integer>>(order_anno.value());
}
auto stage_anno = op->annotations.Get("tl_pipeline_stage");
if (stage_anno.defined()) {
stage_info_array = Downcast<Array<Integer>>(stage_anno);
if (stage_anno) {
stage_info_array = Downcast<Array<Integer>>(stage_anno.value());
}
PipelineInfo pipeline_info(group_info_array, order_info_array,
......@@ -686,8 +687,8 @@ private:
auto result = FilterByRole(op);
Stmt grouped_for_node;
if (result.as<ForNode>() && group_anno.defined() &&
group_info_array.size() > 0 && !is_emitting_producer_) {
if (result.as<ForNode>() && group_anno && group_info_array.size() > 0 &&
!is_emitting_producer_) {
GroupOpRewriter group_op_rewriter(pipeline_info_);
auto for_node = Downcast<For>(result);
grouped_for_node = group_op_rewriter(for_node);
......@@ -707,7 +708,7 @@ private:
for_node.CopyOnWrite()->annotations.erase("tl_pipeline_order");
for_node.CopyOnWrite()->annotations.erase("tl_pipeline_stage");
}
if (is_emitting_producer_ || !group_anno.defined() ||
if (is_emitting_producer_ || !group_anno ||
group_info_array.size() == 0) {
loop_stack_.pop_back();
return for_node;
......@@ -1230,8 +1231,10 @@ tvm::transform::Pass WarpSpecialized() {
return CreatePrimFuncPass(pass_func, 0, "tl.WarpSpecialized", {});
}
TVM_REGISTER_GLOBAL("tl.transform.WarpSpecialized")
.set_body_typed(WarpSpecialized);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.WarpSpecialized", WarpSpecialized);
});
} // namespace tl
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file warp_specialized_pipeline.cc
* \brief Warp specialized Pipeline for cuda GPU (sm90+)
*/
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
......@@ -131,7 +113,7 @@ private:
Stmt VisitStmt_(const ForNode *op) final {
auto order_anno = op->annotations.Get("tl_pipeline_order");
if (!order_anno.defined()) {
if (!order_anno) {
return StmtExprMutator::VisitStmt_(op);
}
......@@ -281,8 +263,10 @@ tvm::transform::Pass RewriteWgmmaSync() {
return CreatePrimFuncPass(pass_func, 0, "tl.RewriteWgmmaSync", {});
}
TVM_REGISTER_GLOBAL("tl.transform.RewriteWgmmaSync")
.set_body_typed(RewriteWgmmaSync);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.RewriteWgmmaSync", RewriteWgmmaSync);
});
} // namespace tl
} // namespace tvm
......@@ -4,6 +4,8 @@ from tilelang import tvm as tvm
import tilelang.language as T
import torch
tilelang.disable_cache()
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
num_stages = 0
......
......@@ -40,8 +40,8 @@ def tl_matmul(
assert in_dtype in [
"float16",
"bfloat16",
"e4m3_float8",
"e5m2_float8",
"float8_e4m3",
"float8_e5m2",
"int8",
], "Currently only float16 and int8 are supported"
assert out_dtype in [
......@@ -52,7 +52,7 @@ def tl_matmul(
micro_size_x = micro_size_y = micro_size_k = 16
is_float8 = in_dtype in ["e4m3_float8", "e5m2_float8"]
is_float8 = in_dtype in ["float8_e4m3", "float8_e5m2"]
if out_dtype == "int32" or is_float8:
micro_size_k = 32
......@@ -220,4 +220,5 @@ def test_assert_tl_matmul_bfloat16():
if __name__ == "__main__":
tilelang.testing.main()
# tilelang.testing.main()
test_assert_tl_matmul_bfloat16()
# ruff: noqa
from tilelang import tvm as tvm
import tilelang.testing
import tilelang.language as T
import torch
from typing import Optional, Union
from einops import rearrange, repeat
tilelang.testing.set_random_seed(42)
def naive_nsa_ref(q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g_slc: torch.Tensor,
g_swa: torch.Tensor,
block_indices: torch.LongTensor,
block_counts: Optional[Union[torch.LongTensor, int]] = None,
block_size: int = 64,
window_size: int = 0,
scale: Optional[float] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
head_first: bool = False) -> torch.Tensor:
if scale is None:
scale = k.shape[-1]**-0.5
if cu_seqlens is not None:
assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided"
if head_first:
raise RuntimeError(
"Sequences with variable lengths are not supported for head-first mode")
if head_first:
q, k, v, block_indices = map(lambda x: rearrange(x, 'b h t d -> b t h d'),
(q, k, v, block_indices))
g_slc, g_swa = map(lambda x: rearrange(x, 'b h t -> b t h'), (g_slc, g_swa))
if isinstance(block_counts, torch.Tensor):
block_counts = rearrange(block_counts, 'b h t -> b t h')
dtype = q.dtype
G = q.shape[2] // k.shape[2]
BS = block_size
S = block_indices.shape[-1]
k, v, block_indices = (repeat(x, 'b t h d -> b t (h g) d', g=G) for x in (k, v, block_indices))
if isinstance(block_counts, torch.Tensor):
block_counts = repeat(block_counts, 'b t h -> b t (h g)', g=G)
c = torch.arange(S).repeat_interleave(BS).unsqueeze(1).expand(-1, q.shape[2]).to(q.device)
q, k, v = map(lambda x: x.float(), (q, k, v))
o_slc = torch.zeros_like(v)
o_swa = torch.zeros_like(v) if window_size > 0 else None
varlen = True
if cu_seqlens is None:
varlen = False
B, T = q.shape[:2]
cu_seqlens = torch.cat(
[block_indices.new_tensor(range(0, B * T, T)),
block_indices.new_tensor([B * T])])
for i in range(len(cu_seqlens) - 1):
if not varlen:
q_b, k_b, v_b, g_slc_b, g_swa_b, i_b = q[i], k[i], v[i], g_slc[i], g_swa[
i], block_indices[i]
if isinstance(block_counts, torch.Tensor):
s_b = block_counts[i]
else:
s_b = block_counts
else:
T = cu_seqlens[i + 1] - cu_seqlens[i]
q_b, k_b, v_b, g_slc_b, g_swa_b, i_b = map(
lambda x: x[0][cu_seqlens[i]:cu_seqlens[i + 1]],
(q, k, v, g_slc, g_swa, block_indices))
if isinstance(block_counts, torch.Tensor):
s_b = block_counts[0][cu_seqlens[i]:cu_seqlens[i + 1]]
else:
s_b = block_counts
i_b = i_b.unsqueeze(-1) * BS + i_b.new_tensor(range(BS))
# [T, S*BS, HQ]
i_b = i_b.view(T, block_indices.shape[2], -1).transpose(1, 2)
for i_q in range(T):
# [HQ, D]
q_i = q_b[i_q] * scale
# [HQ]
g_slc_i = g_slc_b[i_q]
# [HQ]
g_swa_i = g_swa_b[i_q]
# [S*BS, HQ]
i_i = i_b[i_q]
# [HQ]
if isinstance(block_counts, torch.Tensor):
s_i = s_b[i_q]
else:
s_i = s_b
# [S*BS, HQ, -1]
k_i_slc, v_i_slc = map(
lambda x: x.gather(
0,
i_i.clamp(0, T - 1).unsqueeze(-1).expand(*i_i.shape, x.shape[-1])), (k_b, v_b))
# [S*BS, HQ]
attn_slc = torch.einsum('h d, n h d -> n h', q_i, k_i_slc).masked_fill(
torch.logical_or(i_i < 0, i_i > i_q) |
(c >= s_i if block_counts is not None else False), float('-inf')).softmax(0)
if not varlen:
o_slc[i, i_q] = torch.einsum('n h, n h v -> h v', attn_slc,
v_i_slc) * g_slc_i.unsqueeze(-1)
else:
o_slc[0][cu_seqlens[i] + i_q] = torch.einsum('n h, n h v -> h v', attn_slc,
v_i_slc) * g_slc_i.unsqueeze(-1)
if window_size > 0:
k_i_swa, v_i_swa = map(lambda x: x[max(0, i_q - window_size + 1):i_q + 1],
(k_b, v_b))
attn_swa = torch.einsum('h d, n h d -> n h', q_i, k_i_swa).softmax(0)
if not varlen:
o_swa[i, i_q] = torch.einsum('n h, n h v -> h v', attn_swa,
v_i_swa) * g_swa_i.unsqueeze(-1)
else:
o_swa[0][cu_seqlens[i] + i_q] = torch.einsum('n h, n h v -> h v', attn_swa,
v_i_swa) * g_swa_i.unsqueeze(-1)
if head_first:
o_slc = rearrange(o_slc, 'b t h d -> b h t d')
o_swa = rearrange(o_swa, 'b t h d -> b h t d')
return o_slc.to(dtype) + o_swa.to(dtype) if o_swa is not None else o_slc.to(dtype)
def native_sparse_attention(batch,
heads,
seq_len,
dim,
is_causal,
scale=None,
block_size=64,
groups=16,
selected_blocks=16,
num_stages=0,
threads=32):
if scale is None:
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
else:
scale = scale * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim]
kv_shape = [batch, seq_len, head_kv, dim]
block_indices_shape = [batch, seq_len, head_kv, selected_blocks]
block_indices_dtype = "int32"
dtype = "float16"
accum_dtype = "float"
block_S = block_size
block_T = min(128, tilelang.math.next_power_of_2(dim))
NK = tilelang.cdiv(dim, block_T)
NV = tilelang.cdiv(dim, block_T)
assert NK == 1, "The key dimension can not be larger than 256"
S = selected_blocks
G = groups
BS = block_S
BK = BV = block_T
@T.prim_func
def native_sparse_attention(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype),
Output: T.Tensor(q_shape, dtype),
):
with T.Kernel(seq_len, NV, batch * head_kv, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([G, BK], dtype)
K_shared = T.alloc_shared([BS, BK], dtype)
V_shared = T.alloc_shared([BS, BV], dtype)
O_shared = T.alloc_shared([G, BV], dtype)
acc_s = T.alloc_fragment([G, BS], accum_dtype)
acc_s_cast = T.alloc_fragment([G, BS], dtype)
acc_o = T.alloc_fragment([G, BV], accum_dtype)
scores_max = T.alloc_fragment([G], accum_dtype)
scores_max_prev = T.alloc_fragment([G], accum_dtype)
scores_scale = T.alloc_fragment([G], accum_dtype)
scores_sum = T.alloc_fragment([G], accum_dtype)
logsum = T.alloc_fragment([G], accum_dtype)
i_t, i_v, i_bh = bx, by, bz
i_b, i_h = i_bh // head_kv, i_bh % head_kv
NS = S
T.copy(Q[i_b, i_t, i_h * G:(i_h + 1) * G, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
for i in T.Pipelined(NS, num_stages=num_stages):
i_s = BlockIndices[i_b, i_t, i_h, i] * BS
if i_s <= i_t and i_s >= 0:
# [BS, BK]
T.copy(K[i_b, i_s:i_s + BS, i_h, :], K_shared)
if is_causal:
for i, j in T.Parallel(G, BS):
acc_s[i, j] = T.if_then_else(i_t >= (i_s + j), 0,
-T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(
Q_shared,
K_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullRow)
# Softmax
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=True)
for i in T.Parallel(G):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(G, BS):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(G):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
# Rescale
for i, j in T.Parallel(G, BV):
acc_o[i, j] *= scores_scale[i]
# V * softmax(Q * K)
T.copy(V[i_b, i_s:i_s + BS, i_h, i_v * BV:(i_v + 1) * BV], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(G, BV):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[i_b, i_t, i_h * G:(i_h + 1) * G, i_v * BV:(i_v + 1) * BV])
return native_sparse_attention
def run_native_sparse_attention(batch,
heads,
seq_len,
dim,
is_causal,
scale=None,
block_size=64,
groups=16,
selected_blocks=16,
num_stages=0,
threads=32):
dtype = torch.float16
head_kv = heads // groups
program = native_sparse_attention(batch, heads, seq_len, dim, is_causal, scale, block_size,
groups, selected_blocks, num_stages, threads)
kernel = tilelang.compile(program, out_idx=-1)
Q = torch.randn((batch, seq_len, heads, dim), dtype=dtype).cuda()
K = torch.randn((batch, seq_len, head_kv, dim), dtype=dtype).cuda()
V = torch.randn((batch, seq_len, head_kv, dim), dtype=dtype).cuda()
g_slc = torch.ones((batch, seq_len, heads), dtype=dtype).cuda()
g_swa = torch.ones((batch, seq_len, heads), dtype=dtype).cuda()
block_indices = torch.full((batch, seq_len, head_kv, selected_blocks),
seq_len,
dtype=torch.long,
device='cuda')
for b in range(batch):
for t in range(seq_len):
for h in range(head_kv):
i_i = torch.randperm(max(1, (t // block_size)))[:selected_blocks]
block_indices[b, t, h, :len(i_i)] = i_i
block_indices = block_indices.sort(-1)[0]
block_counts = torch.randint(1, selected_blocks + 1, (batch, seq_len, head_kv), device='cuda')
out = kernel(Q, K, V, block_indices.to(torch.int32))
ref = naive_nsa_ref(
q=Q,
k=K,
v=V,
g_slc=g_slc,
g_swa=g_swa,
block_indices=block_indices,
block_counts=block_counts,
block_size=block_size,
scale=scale,
)
torch.testing.assert_close(ref, out, atol=1e-2, rtol=1e-2)
def test_tilelang_kernel_deepseek_nsa():
# disable pipeline
run_native_sparse_attention(
batch=2,
heads=64,
seq_len=1,
dim=16,
is_causal=True,
scale=None,
block_size=32,
groups=16,
selected_blocks=16,
num_stages=0,
threads=32)
# enable pipeline
run_native_sparse_attention(
batch=2,
heads=64,
seq_len=1,
dim=16,
is_causal=True,
scale=None,
block_size=32,
groups=16,
selected_blocks=16,
num_stages=2,
threads=32)
if __name__ == "__main__":
tilelang.testing.main()
......@@ -97,7 +97,7 @@ def test_fp4_fp16_convert_close():
block_K,
"float16",
)
print(program.script())
kernel = tilelang.compile(program, out_idx=[1])
B = torch.randint(0, 16, (N, K // 2), dtype=torch.uint8, device="cuda").to(torch.uint8)
......@@ -642,4 +642,5 @@ def test_assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4():
if __name__ == "__main__":
tilelang.testing.main()
# tilelang.testing.main()
test_fp4_fp16_convert_close()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment