Commit dbe8689f authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999
Browse files

[Bugfix] Enhance smem copy selector for uncommon shape (#510)

* [Refactor] Enhance GEMM warp partitioning logic for improved performance and flexibility

* Updated the warp partitioning logic in `Gemm::ComputeWarpPartition` to better handle various GEMM policies, including FullRow, FullCol, and Square.
* Implemented checks to dynamically adjust warp allocation based on matrix dimensions, ensuring optimal performance.
* Introduced a new `SelectCopy` template to streamline memory access patterns in CUDA templates, enhancing compatibility across different architectures.
* Refactored the Python `GemmWarpPolicy` class to align with the updated C++ logic, improving clarity and maintainability in warp allocation strategies.

* [Refactor] Optimize matrix multiplication parameters and performance in quickstart example

* Updated thread count in the kernel context from 256 to 128 to enhance performance.
* Increased block sizes for matrix dimensions (M, N, block_M, block_N) to 1024 and 128 respectively, improving computational efficiency.
* Adjusted the pipeline stages in the GEMM loop from 0 to 3 for better parallel execution.
* Cleaned up comments for clarity and corrected a typo in the memory copy comment.

* [Refactor] Simplify Copy type selection in OperandTraits for improved clarity

* Replaced the conditional Copy type definition with a new SelectCopy template in OperandTraits, enhancing readability and maintainability of the code.
* This change streamlines the logic for selecting memory copy patterns based on matrix dimensions and warp configurations.
parent 094796b6
......@@ -68,44 +68,90 @@ std::pair<int, int> Gemm::ComputeWarpPartition(int num_warps, Target target,
if (this->policy == GemmWarpPolicy::kFullRow ||
this->policy == GemmWarpPolicy::kSquare) {
m_warp = num_warps;
ICHECK(this->M % num_warps == 0) << this->M << " % " << num_warps;
n_warp = 1;
} else if (this->policy == GemmWarpPolicy::kFullCol) {
m_warp = 4;
n_warp = num_warps / 4;
ICHECK(this->N % n_warp == 0) << this->N << " % " << n_warp;
m_warp = 1;
n_warp = num_warps;
} else {
ICHECK(0) << "Unknown GemmWarpPolicy";
}
return {m_warp, n_warp};
}
if (this->policy == GemmWarpPolicy::kFullRow) {
// Try to partition M first
m_warp = num_warps;
ICHECK(this->M % num_warps == 0) << this->M << " % " << num_warps;
n_warp = 1;
// If M cannot be evenly divided by m_warp*16, try to split remaining warps
// to N
if (this->M % (m_warp * 16) != 0) {
// Calculate how many warps we can use for M
int max_m_warps = this->M / 16;
m_warp = max_m_warps;
// Use remaining warps for N
n_warp = num_warps / m_warp;
if (n_warp == 0)
n_warp = 1;
}
} else if (this->policy == GemmWarpPolicy::kFullCol) {
// Try to partition N first
m_warp = 1;
n_warp = num_warps;
ICHECK(this->N % num_warps == 0) << this->N << " % " << num_warps;
// If N cannot be evenly divided by n_warp*8, try to split remaining warps
// to M
if (this->N % (n_warp * 8) != 0) {
// Calculate how many warps we can use for N
int max_n_warps = this->N / 8;
n_warp = max_n_warps;
// Use remaining warps for M
m_warp = num_warps / n_warp;
if (m_warp == 0)
m_warp = 1;
}
} else if (this->policy == GemmWarpPolicy::kSquare) {
auto factors = toPrimeFactors(num_warps);
for (int factor : factors) {
bool M_divisible = (this->M % (factor * m_warp)) == 0;
bool N_divisible = (this->N % (factor * n_warp)) == 0;
if (M_divisible && N_divisible) {
// put N dimension first
// because usually n in mma
// is more smaller than m
if (this->N / n_warp >= this->M / m_warp)
n_warp *= factor;
else
m_warp *= factor;
} else if (N_divisible) {
n_warp *= factor;
} else if (M_divisible) {
m_warp *= factor;
} else {
ICHECK(0) << "Cannot compute warp partition for shape" << M << " " << N
<< " with num_warps " << num_warps;
// First calculate the maximum possible warps for each dimension
int max_m_warps = this->M / 16; // Each warp needs at least 16 elements in M
int max_n_warps = this->N / 8; // Each warp needs at least 8 elements in N
// Calculate the ideal ratio of M/N warps based on the matrix dimensions
float ideal_ratio = 1.0f;
if (this->N > 0) {
ideal_ratio = static_cast<float>(this->M) / this->N;
}
// Start with a balanced initial guess
m_warp = 1;
n_warp = 1;
// Try to find the best balanced partition
int best_m = 1;
int best_n = 1;
float best_balance = std::numeric_limits<float>::max();
// Try all possible combinations that satisfy the constraints
for (int m = 1; m <= max_m_warps && m <= num_warps; m++) {
int n = num_warps / m;
if (n > max_n_warps)
continue;
if (m * n != num_warps)
continue;
// Calculate how balanced this partition is
float m_per_warp = static_cast<float>(this->M) / (m * 16);
float n_per_warp = static_cast<float>(this->N) / (n * 8);
float balance = std::abs(m_per_warp / n_per_warp - ideal_ratio);
if (balance < best_balance) {
best_balance = balance;
best_m = m;
best_n = n;
}
}
m_warp = best_m;
n_warp = best_n;
} else {
ICHECK(0) << "Unknown GemmWarpPolicy";
}
......
......@@ -56,6 +56,23 @@ struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n, N> {
};
#endif
template <int N, int num_warp_n, bool transpose> struct SelectCopy {
static constexpr int remainder = (N / num_warp_n) % 16;
using type = std::conditional_t<
remainder == 4 || remainder == 8 || remainder == 0,
std::conditional_t<
transpose,
std::conditional_t<
remainder == 4, SM75_U32x1_LDSM_N,
std::conditional_t<remainder == 8, SM75_U32x2_LDSM_N,
SM75_U32x4_LDSM_N>>,
std::conditional_t<
remainder == 4, SM75_U16x2_LDSM_T,
std::conditional_t<remainder == 8, SM75_U16x4_LDSM_T,
SM75_U16x8_LDSM_T>>>,
DefaultCopy>;
};
template <int Bits, int N, int K, bool K_inner, int num_warp_n,
typename Enable = void>
struct OperandTraits {
......@@ -75,8 +92,7 @@ struct OperandTraits<16, N, K, true, num_warp_n,
using LayoutAtom = decltype(composition(
Swizzle<2, 3, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
SM75_U32x4_LDSM_N>::type;
using Copy = typename SelectCopy<N, num_warp_n, true>::type;
};
template <int N, int K, int num_warp_n>
......@@ -85,8 +101,7 @@ struct OperandTraits<16, N, K, true, num_warp_n,
using LayoutAtom = decltype(composition(
Swizzle<3, 3, 3>{}, Layout<Shape<_8, _64>, Stride<_64, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
SM75_U32x4_LDSM_N>::type;
using Copy = typename SelectCopy<N, num_warp_n, true>::type;
};
template <int N, int K, int num_warp_n>
......@@ -96,8 +111,7 @@ struct OperandTraits<16, N, K, false, num_warp_n,
Swizzle<2, 3, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
Step<_2, _1>{}));
using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U16x4_LDSM_T,
SM75_U16x8_LDSM_T>::type;
using Copy = typename SelectCopy<N, num_warp_n, false>::type;
};
template <int N, int K, int num_warp_n>
......@@ -107,8 +121,7 @@ struct OperandTraits<16, N, K, false, num_warp_n,
Swizzle<3, 3, 3>{}, Layout<Shape<_64, _8>, Stride<_1, _64>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
Step<_2, _1>{}));
using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U16x4_LDSM_T,
SM75_U16x8_LDSM_T>::type;
using Copy = typename SelectCopy<N, num_warp_n, false>::type;
};
template <int N, int K, int num_warp_n>
......@@ -117,8 +130,7 @@ struct OperandTraits<32, N, K, true, num_warp_n,
using LayoutAtom = decltype(composition(
Swizzle<3, 2, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
SM75_U32x4_LDSM_N>::type;
using Copy = typename SelectCopy<N, num_warp_n, true>::type;
};
template <int N, int K, int num_warp_n>
......@@ -127,8 +139,7 @@ struct OperandTraits<32, N, K, true, num_warp_n,
using LayoutAtom = decltype(composition(
Swizzle<2, 2, 3>{}, Layout<Shape<_8, _16>, Stride<_16, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
SM75_U32x4_LDSM_N>::type;
using Copy = typename SelectCopy<N, num_warp_n, true>::type;
};
template <int N, int K, int num_warp_n>
......@@ -157,7 +168,7 @@ struct OperandTraits<8, N, K, true, num_warp_n,
using LayoutAtom = decltype(composition(
Swizzle<2, 4, 3>{}, Layout<Shape<_8, _64>, Stride<_64, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = SM75_U32x4_LDSM_N;
using Copy = typename SelectCopy<N, num_warp_n, true>::type;
};
template <int N, int K, int num_warp_n>
......@@ -166,8 +177,7 @@ struct OperandTraits<8, N, K, true, num_warp_n,
using LayoutAtom = decltype(composition(
Swizzle<3, 4, 3>{}, Layout<Shape<_8, _128>, Stride<_128, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
SM75_U32x4_LDSM_N>::type;
using Copy = typename SelectCopy<N, num_warp_n, true>::type;
};
template <int N, int K, int num_warp_n>
......
......@@ -159,6 +159,23 @@ struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n, N> {
};
#endif
template <int N, int num_warp_n, bool transpose> struct SelectCopy {
static constexpr int remainder = (N / num_warp_n) % 16;
using type = std::conditional_t<
remainder == 4 || remainder == 8 || remainder == 0,
std::conditional_t<
transpose,
std::conditional_t<
remainder == 4, SM75_U32x1_LDSM_N,
std::conditional_t<remainder == 8, SM75_U32x2_LDSM_N,
SM75_U32x4_LDSM_N>>,
std::conditional_t<
remainder == 4, SM75_U16x2_LDSM_T,
std::conditional_t<remainder == 8, SM75_U16x4_LDSM_T,
SM75_U16x8_LDSM_T>>>,
DefaultCopy>;
};
template <int Bits, int N, int K, bool K_inner, int num_warp_n,
typename Enable = void>
struct OperandTraits {
......@@ -178,8 +195,7 @@ struct OperandTraits<16, N, K, true, num_warp_n,
using LayoutAtom = decltype(composition(
Swizzle<2, 3, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
SM75_U32x4_LDSM_N>::type;
using Copy = typename SelectCopy<N, num_warp_n, true>::type;
};
template <int N, int K, int num_warp_n>
......@@ -188,8 +204,7 @@ struct OperandTraits<16, N, K, true, num_warp_n,
using LayoutAtom = decltype(composition(
Swizzle<3, 3, 3>{}, Layout<Shape<_8, _64>, Stride<_64, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
SM75_U32x4_LDSM_N>::type;
using Copy = typename SelectCopy<N, num_warp_n, true>::type;
};
template <int N, int K, int num_warp_n>
......@@ -199,8 +214,7 @@ struct OperandTraits<16, N, K, false, num_warp_n,
Swizzle<2, 3, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
Step<_2, _1>{}));
using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U16x4_LDSM_T,
SM75_U16x8_LDSM_T>::type;
using Copy = typename SelectCopy<N, num_warp_n, false>::type;
};
template <int N, int K, int num_warp_n>
......@@ -210,8 +224,7 @@ struct OperandTraits<16, N, K, false, num_warp_n,
Swizzle<3, 3, 3>{}, Layout<Shape<_64, _8>, Stride<_1, _64>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
Step<_2, _1>{}));
using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U16x4_LDSM_T,
SM75_U16x8_LDSM_T>::type;
using Copy = typename SelectCopy<N, num_warp_n, false>::type;
};
template <int N, int K, int num_warp_n>
......@@ -220,8 +233,7 @@ struct OperandTraits<32, N, K, true, num_warp_n,
using LayoutAtom = decltype(composition(
Swizzle<3, 2, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
SM75_U32x4_LDSM_N>::type;
using Copy = typename SelectCopy<N, num_warp_n, true>::type;
};
template <int N, int K, int num_warp_n>
......@@ -230,8 +242,7 @@ struct OperandTraits<32, N, K, true, num_warp_n,
using LayoutAtom = decltype(composition(
Swizzle<2, 2, 3>{}, Layout<Shape<_8, _16>, Stride<_16, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
SM75_U32x4_LDSM_N>::type;
using Copy = typename SelectCopy<N, num_warp_n, true>::type;
};
template <int N, int K, int num_warp_n>
......@@ -260,8 +271,7 @@ struct OperandTraits<8, N, K, true, num_warp_n,
using LayoutAtom = decltype(composition(
Swizzle<2, 4, 3>{}, Layout<Shape<_8, _64>, Stride<_64, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
SM75_U32x4_LDSM_N>::type;
using Copy = typename SelectCopy<N, num_warp_n, true>::type;
};
template <int N, int K, int num_warp_n>
......@@ -270,8 +280,7 @@ struct OperandTraits<8, N, K, true, num_warp_n,
using LayoutAtom = decltype(composition(
Swizzle<3, 4, 3>{}, Layout<Shape<_8, _128>, Stride<_128, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
SM75_U32x4_LDSM_N>::type;
using Copy = typename SelectCopy<N, num_warp_n, true>::type;
};
template <int N, int K, int num_warp_n>
......
......@@ -217,14 +217,30 @@ struct OperandTraits {
using Copy = DefaultCopy;
};
template <int N, int num_warp_n, bool transpose> struct SelectCopy {
static constexpr int remainder = (N / num_warp_n) % 16;
using type = std::conditional_t<
remainder == 4 || remainder == 8 || remainder == 0,
std::conditional_t<
transpose,
std::conditional_t<
remainder == 4, SM75_U32x1_LDSM_N,
std::conditional_t<remainder == 8, SM75_U32x2_LDSM_N,
SM75_U32x4_LDSM_N>>,
std::conditional_t<
remainder == 4, SM75_U16x2_LDSM_T,
std::conditional_t<remainder == 8, SM75_U16x4_LDSM_T,
SM75_U16x8_LDSM_T>>>,
DefaultCopy>;
};
template <int N, int K, int num_warp_n>
struct OperandTraits<16, N, K, true, num_warp_n,
typename std::enable_if<K % 64 == 32>::type> {
using LayoutAtom = decltype(composition(
Swizzle<2, 3, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
SM75_U32x4_LDSM_N>::type;
using Copy = typename SelectCopy<N, num_warp_n, true>::type;
};
template <int N, int K, int num_warp_n>
......@@ -233,8 +249,7 @@ struct OperandTraits<16, N, K, true, num_warp_n,
using LayoutAtom = decltype(composition(
Swizzle<3, 3, 3>{}, Layout<Shape<_8, _64>, Stride<_64, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
SM75_U32x4_LDSM_N>::type;
using Copy = typename SelectCopy<N, num_warp_n, true>::type;
};
template <int N, int K, int num_warp_n>
......@@ -244,8 +259,7 @@ struct OperandTraits<16, N, K, false, num_warp_n,
Swizzle<2, 3, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
Step<_2, _1>{}));
using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U16x4_LDSM_T,
SM75_U16x8_LDSM_T>::type;
using Copy = typename SelectCopy<N, num_warp_n, false>::type;
};
template <int N, int K, int num_warp_n>
......@@ -255,8 +269,7 @@ struct OperandTraits<16, N, K, false, num_warp_n,
Swizzle<3, 3, 3>{}, Layout<Shape<_64, _8>, Stride<_1, _64>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
Step<_2, _1>{}));
using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U16x4_LDSM_T,
SM75_U16x8_LDSM_T>::type;
using Copy = typename SelectCopy<N, num_warp_n, false>::type;
};
template <int N, int K, int num_warp_n>
......@@ -265,8 +278,7 @@ struct OperandTraits<32, N, K, true, num_warp_n,
using LayoutAtom = decltype(composition(
Swizzle<3, 2, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
SM75_U32x4_LDSM_N>::type;
using Copy = typename SelectCopy<N, num_warp_n, true>::type;
};
template <int N, int K, int num_warp_n>
......@@ -275,8 +287,7 @@ struct OperandTraits<32, N, K, true, num_warp_n,
using LayoutAtom = decltype(composition(
Swizzle<2, 2, 3>{}, Layout<Shape<_8, _16>, Stride<_16, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
SM75_U32x4_LDSM_N>::type;
using Copy = typename SelectCopy<N, num_warp_n, true>::type;
};
template <int N, int K, int num_warp_n>
......
......@@ -87,36 +87,73 @@ class GemmWarpPolicy(IntEnum):
if self.is_full_row():
# FullRow policy: Allocate all warps to rows.
m_warp = num_warps
assert (M % num_warps == 0), "M must be divisible by num_warps for FullRow policy"
n_warp = 1
# If M cannot be evenly divided by m_warp*16, try to split remaining warps to N
if M % (m_warp * 16) != 0:
# Calculate how many warps we can use for M
max_m_warps = M // 16
m_warp = max_m_warps
# Use remaining warps for N
n_warp = num_warps // m_warp
if n_warp == 0:
n_warp = 1
elif self.is_full_col():
# FullCol policy: Allocate all warps to columns.
m_warp = 1
n_warp = num_warps
assert (N % num_warps == 0), "N must be divisible by num_warps for FullCol policy"
# If N cannot be evenly divided by n_warp*8, try to split remaining warps to M
if N % (n_warp * 8) != 0:
# Calculate how many warps we can use for N
max_n_warps = N // 8
n_warp = max_n_warps
# Use remaining warps for M
m_warp = num_warps // n_warp
if m_warp == 0:
m_warp = 1
elif self.is_square():
# Square policy: Try to balance warps across rows and columns.
factors = self.to_prime_factors(num_warps)
for factor in factors:
M_divisible = (M % (factor * m_warp)) == 0
N_divisible = (N % (factor * n_warp)) == 0
# Assign the factor to either m_warp or n_warp based on divisibility and aspect ratio.
if M_divisible and N_divisible:
# Prefer to assign to rows if M is larger, otherwise to columns.
if N / n_warp >= M / m_warp:
n_warp *= factor
else:
m_warp *= factor
elif M_divisible:
m_warp *= factor
elif N_divisible:
n_warp *= factor
else:
# If no divisibility condition is met, raise an error.
raise ValueError(
f"Cannot compute warp partition for shape {M} x {N} with num_warps {num_warps}"
)
# First calculate the maximum possible warps for each dimension
max_m_warps = M // 16 # Each warp needs at least 16 elements in M
max_n_warps = N // 8 # Each warp needs at least 8 elements in N
# Calculate the ideal ratio of M/N warps based on the matrix dimensions
ideal_ratio = 1.0
if N > 0:
ideal_ratio = float(M) / N
# Start with a balanced initial guess
m_warp = 1
n_warp = 1
# Try to find the best balanced partition
best_m = 1
best_n = 1
best_balance = float('inf')
# Try all possible combinations that satisfy the constraints
for m in range(1, min(max_m_warps, num_warps) + 1):
n = num_warps // m
if n > max_n_warps:
continue
if m * n != num_warps:
continue
# Calculate how balanced this partition is
m_per_warp = float(M) / (m * 16)
n_per_warp = float(N) / (n * 8)
balance = abs(m_per_warp / n_per_warp - ideal_ratio)
if balance < best_balance:
best_balance = balance
best_m = m
best_n = n
m_warp = best_m
n_warp = best_n
else:
# Raise an error for unknown policies.
raise ValueError(f"Unknown GemmWarpPolicy: {self}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment