Commit 68a3c4f3 authored by Lei Wang, committed by LeiWang1999

[Enhancement] Fallback transposed_ldmatrix into `SM75_U16x4_LDSM_T` when warp_n is 8 (#498)

* Remove a debug print statement from block_sparse_attn_triton.py and implement a timeout handler in the autotuner for function execution. This enhances the robustness of the autotuner by allowing it to handle timeouts gracefully.

* Enhance the autotuner module by adding a timeout handler for function execution, improving robustness when handling long-running tasks. This change introduces a custom TimeoutException and updates the run_with_timeout function for better signal management (a sketch of such a helper follows this change list).

* Add merge shared memory allocations pass and related configurations

- Introduced a new pass for merging shared memory allocations in GPU kernels, allowing for more efficient memory usage.
- Registered configuration options for debugging and controlling the merging behavior.
- Updated relevant files to integrate the new pass into the TileLang engine and transform modules.
- Adjusted import paths and added documentation for the new functionality (a sketch of enabling such a pass follows this change list).

* Reduce the num_stages parameter in the GEMM functions of test_tilelang_kernel_gemm.py from 3 to 1 for improved performance (a sketch showing where num_stages enters a tilelang GEMM follows this change list).

* Update Copy type in OperandTraits for GEMM templates to use conditional selection based on num_warp_n. This change enhances memory access patterns for different configurations in CUDA kernels.

* lint fix
parent c93e8695
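
For illustration, a minimal sketch of a signal-based run_with_timeout in the spirit of the autotuner change above; the actual signature and details in tilelang's autotuner may differ:

import signal


class TimeoutException(Exception):
    """Raised when the wrapped function exceeds its time budget."""


def run_with_timeout(func, timeout_seconds, *args, **kwargs):
    """Run func(*args, **kwargs), raising TimeoutException if it takes
    longer than timeout_seconds. Uses SIGALRM, so it is POSIX-only and
    must be called from the main thread."""

    def _handler(signum, frame):
        raise TimeoutException(f"function timed out after {timeout_seconds}s")

    old_handler = signal.signal(signal.SIGALRM, _handler)
    signal.alarm(int(timeout_seconds))
    try:
        return func(*args, **kwargs)
    finally:
        signal.alarm(0)  # cancel any pending alarm
        signal.signal(signal.SIGALRM, old_handler)  # restore prior handler

A caller such as run_with_timeout(compile_and_profile, 30, config) (hypothetical names) would then see a TimeoutException for configurations that hang instead of blocking the tuning loop.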
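Similarly, a hedged sketch of enabling a shared-memory merging pass through a TVM-style PassContext. The flag shown is upstream TVM's "tir.merge_static_smem"; tilelang's own registered keys for debugging and controlling the merge may be named differently.

import tvm

# Assumption: tilelang lowers through TVM, so the merge behavior can be
# toggled with a PassContext config flag during the build.
with tvm.transform.PassContext(config={"tir.merge_static_smem": True}):
    # Build the kernel inside this context so the lowering pipeline
    # applies the merged shared-memory allocation, e.g.:
    #   rt_mod = tvm.build(mod, target="cuda")
    pass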
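And for context on the num_stages change: not the test file's actual contents, but a minimal tilelang GEMM sketch showing where num_stages enters (tile sizes and names are illustrative):

import tilelang.language as T


def matmul(M, N, K, block_M=128, block_N=128, block_K=32, num_stages=1):

    @T.prim_func
    def main(
            A: T.Buffer((M, K), "float16"),
            B: T.Buffer((K, N), "float16"),
            C: T.Buffer((M, N), "float16"),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M),
                      threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), "float16")
            B_shared = T.alloc_shared((block_K, block_N), "float16")
            C_local = T.alloc_fragment((block_M, block_N), "float32")
            T.clear(C_local)
            # num_stages=1 disables multi-buffered software pipelining of
            # the K loop; num_stages=3 would triple-buffer it.
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main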
@@ -6,6 +6,7 @@ import sys
 @tilelang.testing.requires_cuda
+@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
 def test_example_mla_decode():
     with mock.patch.object(sys, 'argv', ["example_mla_decode.py"]):
         example_mla_decode.main()
...
@@ -96,7 +96,8 @@ struct OperandTraits<16, N, K, false, num_warp_n,
       Swizzle<2, 3, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
   using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                         Step<_2, _1>{}));
-  using Copy = SM75_U16x8_LDSM_T;
+  using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U16x4_LDSM_T,
+                                         SM75_U16x8_LDSM_T>::type;
 };

 template <int N, int K, int num_warp_n>
@@ -106,7 +107,8 @@ struct OperandTraits<16, N, K, false, num_warp_n,
       Swizzle<3, 3, 3>{}, Layout<Shape<_64, _8>, Stride<_1, _64>>{}));
   using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                         Step<_2, _1>{}));
-  using Copy = SM75_U16x8_LDSM_T;
+  using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U16x4_LDSM_T,
+                                         SM75_U16x8_LDSM_T>::type;
 };

 template <int N, int K, int num_warp_n>
...
@@ -199,7 +199,8 @@ struct OperandTraits<16, N, K, false, num_warp_n,
       Swizzle<2, 3, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
   using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                         Step<_2, _1>{}));
-  using Copy = SM75_U16x8_LDSM_T;
+  using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U16x4_LDSM_T,
+                                         SM75_U16x8_LDSM_T>::type;
 };

 template <int N, int K, int num_warp_n>
@@ -209,7 +210,8 @@ struct OperandTraits<16, N, K, false, num_warp_n,
       Swizzle<3, 3, 3>{}, Layout<Shape<_64, _8>, Stride<_1, _64>>{}));
   using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                         Step<_2, _1>{}));
-  using Copy = SM75_U16x8_LDSM_T;
+  using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U16x4_LDSM_T,
+                                         SM75_U16x8_LDSM_T>::type;
 };

 template <int N, int K, int num_warp_n>
...
@@ -253,7 +253,8 @@ struct OperandTraits<16, N, K, false, num_warp_n,
       Swizzle<2, 3, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
   using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                         Step<_2, _1>{}));
-  using Copy = SM75_U16x8_LDSM_T;
+  using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U16x4_LDSM_T,
+                                         SM75_U16x8_LDSM_T>::type;
 };

 template <int N, int K, int num_warp_n>
@@ -263,7 +264,8 @@ struct OperandTraits<16, N, K, false, num_warp_n,
       Swizzle<3, 3, 3>{}, Layout<Shape<_64, _8>, Stride<_1, _64>>{}));
   using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                         Step<_2, _1>{}));
-  using Copy = SM75_U16x8_LDSM_T;
+  using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U16x4_LDSM_T,
+                                         SM75_U16x8_LDSM_T>::type;
 };

 template <int N, int K, int num_warp_n>
...