[Enhancement] Enable runtime tensor data type validation (#146)
Lei Wang authored
    * Fix debug print buffer template for unsigned char type
    
    - Update debug_print_buffer_value template specialization for unsigned char
- Extend test_tilelang_debug_print.py with additional dtype tests, including a uint8 case for the debug print buffer function (sketched below)
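For context, a minimal sketch of what such a uint8 debug-print test could look like; the kernel body, the T.copy/T.print usage, and the tilelang.compile call are assumptions about the test's shape, not the commit's exact code:

```python
import torch
import tilelang
import tilelang.language as T


def debug_print_buffer(M=16, N=16, dtype="uint8"):
    # Kernel that copies the input into shared memory and prints it,
    # exercising the unsigned char template specialization.
    @T.prim_func
    def program(Q: T.Buffer((M, N), dtype)):
        with T.Kernel(1, threads=128):
            shared_buf = T.alloc_shared([M, N], dtype)
            T.copy(Q, shared_buf)
            T.print(shared_buf)

    kernel = tilelang.compile(program, target="cuda")
    kernel(torch.zeros(M, N, dtype=torch.uint8, device="cuda"))
```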
    
    * Refactor debug print buffer template formatting for unsigned char
    
    - Improve code formatting for debug_print_buffer_value template specialization
    - Adjust line breaks and indentation for better readability
    - Maintain consistent code style with other template specializations
    
    * Extract map_torch_type utility function to tilelang.utils.tensor
    
    - Move map_torch_type function from multiple test files to a centralized location
    - Import map_torch_type from tilelang.utils.tensor in kernel test files
- Improve code reusability by sharing one utility function for torch dtype mapping (a sketch follows)
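A hedged sketch of what such a helper could look like; the float8 special cases are an assumption, and the actual implementation in tilelang.utils.tensor may differ:

```python
import torch


def map_torch_type(intype: str) -> torch.dtype:
    """Map a TIR dtype string (e.g. "float16", "uint8") to a torch dtype."""
    typemap = {
        "e4m3_float8": torch.float8_e4m3fn,
        "e5m2_float8": torch.float8_e5m2,
    }
    if intype in typemap:
        return typemap[intype]
    # Most dtype strings match torch attribute names directly.
    return getattr(torch, intype)
```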
    
    * Add buffer dtype mapping for Cython kernel adapter
    
    - Introduce buffer_dtype_map in CythonKernelAdapter to track buffer variable dtypes
    - Add _process_buffer_dtype method to extract dtype information from TIR function
    - Update CythonKernelWrapper to support setting and validating buffer dtypes
- Verify argument dtypes against the expected buffer dtypes during kernel execution (illustrated below)
    - Improve logging message for Cython JIT adapter compilation
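To illustrate the intent, a self-contained sketch of this kind of runtime dtype validation; DtypeCheckingWrapper and its attributes are illustrative stand-ins, not the adapter's actual classes:

```python
import torch


class DtypeCheckingWrapper:
    """Illustrative stand-in for the dtype checks in CythonKernelWrapper."""

    def __init__(self, buffer_dtype_map):
        # Argument position -> expected torch dtype, extracted from the
        # buffer declarations of the TIR function.
        self.buffer_dtype_map = buffer_dtype_map

    def validate(self, tensors):
        for idx, expected in self.buffer_dtype_map.items():
            actual = tensors[idx].dtype
            if actual != expected:
                raise TypeError(
                    f"Argument {idx}: expected dtype {expected}, got {actual}")


wrapper = DtypeCheckingWrapper({0: torch.float16})
wrapper.validate([torch.zeros(4, dtype=torch.float16)])  # passes
```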
    
    * Add static shape mapping for Cython kernel adapter
    
    - Introduce static_shape_map in CythonKernelAdapter to track buffer variable static shapes
    - Add _process_static_shape method to extract static shape information from TIR function
    - Update CythonKernelWrapper to support setting and validating static shapes
- Verify argument shapes against the static extents during kernel execution (see the sketch below)
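A minimal sketch of such a static-shape check, assuming a map from argument position to (dimension, extent) pairs; the function name and map layout are illustrative:

```python
import torch


def check_static_shapes(static_shape_map, tensors):
    # static_shape_map: argument position -> [(dim, extent), ...] where the
    # extents are compile-time constants in the TIR function.
    for idx, dims in static_shape_map.items():
        for dim, extent in dims:
            actual = tensors[idx].shape[dim]
            if actual != extent:
                raise ValueError(
                    f"Argument {idx}: dim {dim} expected extent {extent}, "
                    f"got {actual}")


check_static_shapes({0: [(0, 4), (1, 8)]}, [torch.zeros(4, 8)])  # passes
```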
    
* Add a multi-head attention (MHA) backward pass test for the TileLang kernel
    
- Implement a comprehensive test for the multi-head attention backward pass
- Support both causal and non-causal attention scenarios
- Add a reference implementation for comparing kernel outputs (sketched below)
    - Test different batch sizes, head counts, sequence lengths, and head dimensions
    - Verify forward and backward pass correctness using torch.testing.assert_close
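For reference, a condensed sketch of the kind of reference attention and autograd-based backward check such a test performs; the (batch, seq, heads, dim) layout, the sizes, and the final comparison are assumptions:

```python
import torch


def ref_attention(Q, K, V, causal):
    # Q, K, V: [batch, seq, heads, dim]; plain scaled dot-product attention.
    scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) / (Q.shape[-1] ** 0.5)
    if causal:
        seq = scores.shape[-1]
        mask = torch.tril(torch.ones(seq, seq, dtype=torch.bool, device=Q.device))
        scores = scores.masked_fill(~mask, float("-inf"))
    return torch.einsum("bhqk,bkhd->bqhd", scores.softmax(dim=-1), V)


B, S, H, D = 1, 128, 4, 64
Q, K, V = (torch.randn(B, S, H, D, requires_grad=True) for _ in range(3))
out = ref_attention(Q, K, V, causal=True)
out.sum().backward()  # populates Q.grad / K.grad / V.grad for comparison
# torch.testing.assert_close(kernel_dQ, Q.grad, rtol=1e-2, atol=1e-2)
```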
    
    * Set random seed for MHA backward pass test
    
- Initialize the random seed so results are reproducible across runs
- Use tilelang.testing.set_random_seed(42) to ensure deterministic test results (a sketch of such a helper follows)
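A minimal sketch of what a seeding helper like this presumably covers; the actual tilelang.testing.set_random_seed implementation may differ:

```python
import random

import numpy as np
import torch


def set_random_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # also seeds CUDA generators on current torch
```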