    [Refactor] Replace default fp8 dtype with cute to perform fast cast (#520) · 6addc509
    Lei Wang authored
    * [Refactor] Enhance GEMM Warp Partitioning Logic and Introduce Buffer Remapping (#516)
    
    * Improved the warp partitioning logic in `Gemm::ComputeWarpPartition` to better accommodate various GEMM policies, including FullRow, FullCol, and Square, ensuring optimal performance based on matrix dimensions.
    * Introduced a new `RemapBufferRewriter` class to handle buffer reference updates and padding annotations during statement transformations, enhancing memory access safety and clarity.
    * Updated the `OptimizeForTarget` function to include a new step for configuring index bitwidth, improving the overall optimization process.
    * Refactored existing code to utilize constants for warp sizes, enhancing maintainability and readability.
    * Added checks to ensure correct warp allocation and padding map handling, improving robustness in memory management strategies.
    
    * [Refactor] Update ConfigIndexBitwidthRewriter to Support Auto-Check Feature
    
    * Modified the constructor of `ConfigIndexBitwidthRewriter` to include an `auto_check` parameter, allowing for dynamic bitwidth adjustments based on input conditions.
    * Enhanced the `VisitExpr_` methods to apply the new auto-check logic, ensuring that integer types are upgraded to 64 bits when necessary, or to a specified index bitwidth otherwise.
    * Updated the `ConfigIndexBitwidth` pass to determine the index bitwidth based on the presence of configuration, improving flexibility in handling different scenarios.
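A minimal Python sketch of the selection rule described above, under stated assumptions: the helper `select_index_bits` is hypothetical, and "when necessary" is taken to mean that the largest index value no longer fits in a signed 32-bit integer. The real `ConfigIndexBitwidthRewriter` operates on TIR expressions in C++ and may use a different criterion.

```python
from typing import Optional

INT32_MAX = 2**31 - 1


def select_index_bits(max_index_value: int,
                      index_bitwidth: Optional[int] = None) -> int:
    """Pick the integer width used for index arithmetic.

    Assumed behavior (illustrative only):
    - A configured bitwidth is present: use it as given.
    - No configuration: auto-check, upgrading to 64 bits only when the
      largest index would overflow a signed 32-bit integer.
    """
    if index_bitwidth is not None:
        return index_bitwidth
    return 64 if max_index_value > INT32_MAX else 32


# A 65536 x 65536 buffer has 2**32 elements, so flat indices need 64 bits.
print(select_index_bits(65536 * 65536 - 1))      # 64
print(select_index_bits(1024 * 1024 - 1))        # 32
print(select_index_bits(1024 * 1024 - 1, 64))    # 64 (explicit configuration)
```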
    
    * Add dynamic matrix multiplication example and corresponding test
    
    * Introduced `example_dynamic.py` to demonstrate dynamic matrix multiplication using TileLang and PyTorch, including a main function for execution and performance profiling.
    * Added `test_example_dynamic.py` to validate the functionality of the dynamic matrix multiplication example.
    * The example includes detailed parameter configurations and checks against PyTorch's implementation for correctness.
    
    * lint fix
    
    * Add get_num_sms function to retrieve the number of streaming multiprocessors on the CUDA device
    
    * Implemented the `get_num_sms` function in `cuda_driver.py` to return the count of streaming multiprocessors for a specified CUDA device.
    * Updated the `__init__.py` file to include the new function in the module exports.
    
    * lint fix
    
    * Add global barrier state and expectation handling in CUDA code generation
    
    * Introduced `vid_global_barrier_state_` and `vid_global_barrier_expect_` to manage global barrier synchronization in the CUDA code generator.
    * Updated `Finish` method to declare the global barrier state if needed.
    * Implemented handling for `EvaluateNode` to initialize the barrier expectation.
    * Removed unnecessary extern declaration for the global barrier state in `PrintStorageSync` method.
    * Enhanced CUDA FP8 type definitions for better alignment and structure.
    
    * Enhance CUDA FP8 type handling and debug printing
    
    * Updated `cuda_fp8.h` to replace NVIDIA's FP8 types with CuTe's FP8 types for better compatibility and structure.
    * Added specializations for `debug_print_var` and `debug_print_buffer_value` functions to support the new FP8 types, improving debugging capabilities for these data types.
    * Updated `debug.h` to include the new `cuda_fp8.h` header for access to the FP8 type definitions.
    
    * Refactor CUDA code generation to remove unnecessary managed qualifier for global barrier state
    
    * Updated the `Finish` method in `codegen_cuda.cc` to declare the global barrier state without the `__managed__` qualifier, simplifying the declaration.
    * Added a new `sync_global` function in `builtin.py` to synchronize all threads in a block, enhancing synchronization capabilities in the TileLang framework.
    
    * Remove deprecated CUDA kernel and Python script for FP8 E4M3 casting
    
    * Deleted the `cast_to_fp8_e4m3_kernel` CUDA kernel implementation and its corresponding Python script, streamlining the codebase by removing unused components related to FP8 E4M3 type casting.
    * This cleanup enhances maintainability and reduces potential confusion regarding obsolete code.
    
    * lint fix
builtin.py 9.66 KB