    [Bugfix] Add tf32 casting to GEMM templates (#556) · 8cc8db52
    Lei Wang authored
    * Add tf32 casting functionality to GEMM templates
    
    - Introduced a `cast_float_to_tf32` function to convert float32 values to tfloat32 format across gemm_sm80, gemm_sm89, and gemm_sm90 templates.
    - Implemented conditional casting in relevant sections of the GEMM operations to ensure compatibility with tfloat32 types.
    - Updated the handling of tensor views to support the new casting logic in the matrix operations.
    
    * lint fix
    
    * Refactor tfloat32 casting logic in GEMM templates
    
    - Replaced the `is_tfloat32` boolean with `need_tfloat32_cast` to improve clarity and accuracy in determining when to cast float32 to tfloat32.
    - Updated relevant sections in `gemm_sm80`, `gemm_sm89`, and `gemm_sm90` to utilize the new casting logic, enhancing compatibility with tfloat32 types.
    - Ensured consistent application of casting across tensor views, improving performance and correctness in matrix operations.
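
To illustrate the idea behind a compile-time flag such as `need_tfloat32_cast`, here is a minimal host-side C++17 sketch. It is not the code from the templates: `tf32_tag`, `need_tf32_cast_v`, and `maybe_cast_view` are hypothetical names, a `std::vector` stands in for a tensor view, and the condition used (raw element type is `float` while the operand type is tf32) is an assumption about when such a cast would be required.

```cpp
#include <cassert>
#include <type_traits>
#include <vector>

// Hypothetical stand-in for a tf32 operand type (illustration only).
struct tf32_tag { float storage; };

// Assumed condition: a cast is needed only when the caller stores float32
// but the matrix instruction consumes tf32 operands.
template <typename RawT, typename MmaT>
inline constexpr bool need_tf32_cast_v =
    std::is_same_v<RawT, float> && std::is_same_v<MmaT, tf32_tag>;

// Applies `cast` to every element of the view, but only when the flag is
// set. Otherwise the loop is discarded at compile time and the view is
// left untouched.
template <typename RawT, typename MmaT, typename CastFn>
void maybe_cast_view(std::vector<RawT>& view, [[maybe_unused]] CastFn cast) {
    if constexpr (need_tf32_cast_v<RawT, MmaT>) {
        for (auto& element : view) {
            element = cast(element);
        }
    }
}

// Small usage example; the doubling lambda is a placeholder for a real cast.
inline void maybe_cast_view_demo() {
    std::vector<float> floats{1.0f, 2.0f};
    maybe_cast_view<float, tf32_tag>(floats, [](float x) { return x * 2.0f; });
    assert(floats[0] == 2.0f && floats[1] == 4.0f);  // cast was applied

    std::vector<double> doubles{1.0};
    maybe_cast_view<double, double>(doubles, [](double x) { return x * 2.0; });
    assert(doubles[0] == 1.0);  // cast was skipped
}
```

Deriving the flag from both the raw and the operand type, rather than from a single "is tf32" boolean, makes explicit that the cast is only meaningful when the two differ.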
    
    * Refactor GEMM template functions for improved readability
    
    - Simplified the function signature of `body_rs` in both `gemm_sm80` and `gemm_sm90` templates for better clarity.
    - Adjusted the casting logic in `gemm_sm90` to ensure consistent application of `cast_float_to_tf32` across tensor views, enhancing performance and maintainability.
    
    * Enhance tf32 casting logic in GEMM templates
    
    - Updated the `cast_float_to_tf32` function in `gemm_sm80`, `gemm_sm89`, and `gemm_sm90` to apply the cast only when the input is finite, so that infinities and NaNs are left unchanged, improving robustness.
    - Simplified the `need_tfloat32_cast` logic to clarify the conditions under which tfloat32 casting is required, enhancing code readability and maintainability.
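
The finite-input guard described above can be sketched on the host as follows. This is an illustration under stated assumptions, not the `cast_float_to_tf32` implementation from the templates: `round_to_tf32_sketch` is a hypothetical name, the result stays in a float32 container, and the rounding (add half of the dropped bits, then clear them) is one common way to reduce a float32 mantissa from 23 to 10 bits.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <limits>

// Hypothetical host-side helper, not the function from the GEMM templates.
// Rounds a float32 to the nearest value with a 10-bit mantissa (tf32
// precision) and returns it in a float32 container.
inline float round_to_tf32_sketch(float value) {
    // Guard: leave infinities and NaNs untouched, so the bit arithmetic
    // below cannot alter an infinity or a NaN payload.
    if (!std::isfinite(value)) {
        return value;
    }
    std::uint32_t bits;
    std::memcpy(&bits, &value, sizeof(bits));  // well-defined type pun
    bits += 0x1000u;   // add half of the 13 mantissa bits about to be dropped
    bits &= ~0x1FFFu;  // clear the 13 low mantissa bits
    float rounded;
    std::memcpy(&rounded, &bits, sizeof(rounded));
    return rounded;
}

// Small self-check showing the intended behaviour.
inline void tf32_sketch_self_check() {
    const float half_step = std::ldexp(1.0f, -11);  // half a tf32 step at 1.0
    const float inf = std::numeric_limits<float>::infinity();
    assert(round_to_tf32_sketch(1.0f) == 1.0f);
    assert(round_to_tf32_sketch(1.0f + half_step) == 1.0f + std::ldexp(1.0f, -10));
    assert(round_to_tf32_sketch(inf) == inf);
    assert(std::isnan(round_to_tf32_sketch(std::numeric_limits<float>::quiet_NaN())));
}
```

In this sketch a finite value close to the float32 maximum can still round up to infinity, since the carry propagates into the exponent. The guard only protects inputs that are already non-finite.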
    
    * Refactor GEMM template functions and layout inference logic
    
    - Removed the `cast_float_to_tf32` function from `gemm_sm90` and updated the `body_sr` function to streamline the casting process for tensor views, enhancing code clarity and maintainability.
    - Improved layout inference in `layout_inference.cc` by adding checks for the layout map's definition, ensuring robustness in handling layout annotations.
    - Simplified the handling of layout maps in the `annotate_layout` function, allowing for more flexible layout definitions and error handling.
layout_inference.cc 23.8 KB