Commit c93e8695 authored by Lei Wang, committed by LeiWang1999

[Bugfix] Fix Hopper GEMM layout for small tile size (#497)

* [Enhancement] Improve GEMM layout function and documentation

* Added detailed documentation for the makeGemmABLayout function, explaining parameters and layout selection strategies.
* Updated the layout selection logic to use mat_continuous consistently, enhancing clarity and correctness in memory layout calculations.
* Adjusted the InferLayout method to reflect changes in the layout function, ensuring accurate matrix dimension handling for transposed cases.

* lint fix
parent 73ae8087
@@ -501,6 +501,36 @@ Layout makeGemmVoltaABLayout(int stride, int continuous, bool is_a,
   return makeGemmABLayoutPadded(stride, continuous, 16);
 }

+/*!
+ * \brief Creates a memory layout for GEMM's A or B matrices.
+ *
+ * This function selects an appropriate memory layout based on the matrix
+ * dimensions, element size, continuity, and a k-factor. It aims to optimize
+ * memory access patterns, potentially using swizzling techniques or specialized
+ * layouts for different data types and hardware characteristics.
+ *
+ * \param mat_stride The leading dimension of the matrix (e.g., K for a
+ *   row-major M x K matrix). This is the number of elements to skip to get to
+ *   the same column in the next row (row-major) or to the same row in the next
+ *   column (column-major).
+ * \param mat_continuous The length of the dimension stored contiguously in
+ *   memory (e.g., K for a row-major M x K matrix, or M for a column-major
+ *   M x K matrix).
+ * \param continuity The size of the dimension that is continuous from the
+ *   perspective of memory bank access. This is used to select specific
+ *   swizzling strategies. It might be the same as mat_continuous or different
+ *   based on tiling or hardware details.
+ * \param element_size The size of each element in the matrix, in bits (e.g., 8,
+ *   16, 32, 64).
+ * \param kfactor An integer factor that influences layout selection,
+ *   particularly for fp64 and int8 types. It often relates to how the K
+ *   dimension of the GEMM (M x K * K x N) is handled or tiled.
+ *   - For fp64 (element_size == 64):
+ *     - kfactor == 1 often implies K is in the "outer" loop (e.g., KxN matrix).
+ *     - kfactor == 2 often implies K is in the "inner" loop (e.g., NxK matrix).
+ *   - For int8 (element_size == 8):
+ *     - kfactor == 1 uses a padded layout.
+ * \return A Layout object representing the chosen memory layout.
+ */
 Layout makeGemmABLayout(int mat_stride, int mat_continuous, int continuity,
                         int element_size, int kfactor) {
   if (element_size == 64) {
@@ -513,9 +543,9 @@ Layout makeGemmABLayout(int mat_stride, int mat_continuous, int continuity,
   int vector_size = 128 / element_size;
   if (kfactor == 1 && element_size == 8) // int8 KxN
     return makeGemmABLayoutPadded(mat_stride, mat_continuous, element_size);
-  else if (continuity % (vector_size * 8) == 0)
+  else if (mat_continuous % (vector_size * 8) == 0)
     return makeFullBankSwizzleLayout(mat_stride, mat_continuous, element_size);
-  else if (continuity % (vector_size * 4) == 0)
+  else if (mat_continuous % (vector_size * 4) == 0)
     return makeHalfBankSwizzleLayout(mat_stride, mat_continuous, element_size);
   else {
     return makeGemmABLayoutPadded(mat_stride, mat_continuous, element_size);
......
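The branch order in the hunk above can be sketched as a standalone function. This is a minimal illustration, not the project's code: `LayoutKind` and `selectLayoutKind` are hypothetical names standing in for the `Layout` objects built by the real helpers, and the `element_size == 64` branch, which the diff elides, is left out. It assumes a C++14 or later compiler.

```cpp
// Hypothetical tag type; the real helpers return Layout objects instead.
enum class LayoutKind { Padded, HalfBankSwizzle, FullBankSwizzle };

// Sketch of the post-commit selection for element sizes other than 64 bits.
// The 64-bit branch is not shown in the diff, so it is not modeled here.
constexpr LayoutKind selectLayoutKind(int mat_continuous, int element_size,
                                      int kfactor) {
  // Number of elements that fit in one 128-bit vector.
  const int vector_size = 128 / element_size;
  if (kfactor == 1 && element_size == 8) // int8 KxN
    return LayoutKind::Padded;
  if (mat_continuous % (vector_size * 8) == 0)
    return LayoutKind::FullBankSwizzle;
  if (mat_continuous % (vector_size * 4) == 0)
    return LayoutKind::HalfBankSwizzle;
  return LayoutKind::Padded;
}

// 16-bit elements: vector_size is 8, so the thresholds are 64 and 32.
static_assert(selectLayoutKind(64, 16, 2) == LayoutKind::FullBankSwizzle,
              "multiple of 64 selects the full-bank swizzle");
static_assert(selectLayoutKind(32, 16, 2) == LayoutKind::HalfBankSwizzle,
              "multiple of 32 but not 64 selects the half-bank swizzle");
static_assert(selectLayoutKind(16, 16, 2) == LayoutKind::Padded,
              "anything else falls back to the padded layout");
```

Because the checks are ordered from the widest swizzle to the narrowest, a contiguous extent that satisfies both divisibility tests always takes the full-bank variant.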
@@ -243,9 +243,10 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
       const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
       const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
       const int64_t continuity =
-          trans_A ? mat_continuous / (warp_m / 4) : mat_continuous;
-      results.Set(A, makeGemmABLayout(mat_stride, mat_continuous, continuity,
-                                      A->dtype.bits(), trans_A ? 1 : 2));
+          trans_A ? 4 * mat_continuous / warp_m : mat_continuous;
+      results.Set(A,
+                  makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
+                                   A->dtype.bits(), trans_A ? 1 : 2));
     } else {
       ICHECK(trans_A == false);
       auto fragment =
......
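To see how the two hunks together can change the selected layout for a transposed A, the sketch below compares the value that was tested for swizzle divisibility before the commit (`continuity`, derived from `warp_m`) with the value tested afterwards (`mat_continuous`). The inputs `warp_m = 64`, `mat_continuous = 64`, and 16-bit elements are assumptions chosen for illustration, and `LayoutKind`, `pickBySwizzle`, and `oldContinuity` are hypothetical names that do not exist in the repository. It assumes a C++14 or later compiler.

```cpp
#include <cstdint>

// Hypothetical tag type; the real helpers return Layout objects instead.
enum class LayoutKind { Padded, HalfBankSwizzle, FullBankSwizzle };

// Swizzle divisibility test shared by both versions; `extent` is whichever
// value the code checks (continuity before the commit, mat_continuous after).
constexpr LayoutKind pickBySwizzle(std::int64_t extent, int element_size) {
  const int vector_size = 128 / element_size;
  if (extent % (vector_size * 8) == 0)
    return LayoutKind::FullBankSwizzle;
  if (extent % (vector_size * 4) == 0)
    return LayoutKind::HalfBankSwizzle;
  return LayoutKind::Padded;
}

// Expression removed by the commit (requires warp_m >= 4).
constexpr std::int64_t oldContinuity(bool trans_A, std::int64_t mat_continuous,
                                     int warp_m) {
  return trans_A ? mat_continuous / (warp_m / 4) : mat_continuous;
}

// Assumed inputs: transposed A, 64 contiguous fp16 elements, warp_m = 64.
constexpr std::int64_t kMatContinuous = 64;
constexpr LayoutKind before =
    pickBySwizzle(oldContinuity(true, kMatContinuous, 64), 16);
constexpr LayoutKind after = pickBySwizzle(kMatContinuous, 16);

// Before: continuity = 64 / (64 / 4) = 4, which fails both swizzle tests.
static_assert(before == LayoutKind::Padded, "old path falls back to padding");
// After: 64 is a multiple of 8 * 8 = 64, so the full-bank swizzle is chosen.
static_assert(after == LayoutKind::FullBankSwizzle, "new path swizzles");
```

Under these assumed inputs the old path falls back to the padded layout, while the new path picks the full-bank swizzle from the same buffer shape.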