Unverified commit 5ad296bd authored by Ma Mingfei, committed by GitHub

Optimize prefill performance on cpu backend (#8750)

parent 9f81d741
...@@ -105,7 +105,19 @@ namespace { ...@@ -105,7 +105,19 @@ namespace {
#define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b) #define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)
// parallel routines // [NB] Parallel Routines
//
// * at::parallel_for - applies to most generic use cases; it is compiled
//     against OpenMP in the default torch release.
//
// * parallel_for - same functionality as above, but the payload partition
//     scheme can be chosen via balance211.
//
// * parallel_2d - parallelizes over 2 dimensions, used in GEMM, etc.;
//     balances the payload across both dimensions.
//
// grain size for each thread
constexpr int GRAIN_SIZE = 1024; constexpr int GRAIN_SIZE = 1024;
template <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type = 0> template <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
...@@ -113,6 +125,17 @@ inline T div_up(T x, T y) { ...@@ -113,6 +125,17 @@ inline T div_up(T x, T y) {
return (x + y - 1) / y; return (x + y - 1) / y;
} }
// at::get_thread_num() can only be used together with at::parallel_for(),
// as it is lazily initialized; otherwise it always returns 0.
inline int get_thread_num() {
#if defined(_OPENMP)
return omp_get_thread_num();
#else
return 0;
#endif
}
// balance the payload across threads
template <typename T> template <typename T>
inline void balance211(T n, T nth, T ith, T& n_start, T& n_end) { inline void balance211(T n, T nth, T ith, T& n_start, T& n_end) {
#if 0 #if 0
...@@ -153,6 +176,100 @@ inline void parallel_for(int n, const func_t& f) { ...@@ -153,6 +176,100 @@ inline void parallel_for(int n, const func_t& f) {
#endif #endif
} }
// for 1d parallelism, use `actual_nth` threads
// for 2d parallelism, use an even number of threads, e.g. 43 -> 42
int inline adjust_num_threads(int m) {
int actual_nth = at::get_num_threads();
if (m == 1) {
return actual_nth;
}
return std::max(1, (actual_nth >> 1) * 2);
}
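// [NB] Quick check of the rounding above (illustrative only, not part of this diff):
// with 43 worker threads,
//   adjust_num_threads(1) -> 43   (m == 1: 1d parallelism keeps actual_nth)
//   adjust_num_threads(8) -> 42   (m  > 1: (43 >> 1) * 2, an even count for the 2d grid)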
template <typename func_t>
inline void parallel_2d(int m, int n, const func_t& f) {
// make sure we have an even num_threads
int nth = adjust_num_threads(m);
// [NOTE] thread blocking:
//
//   1) prefer a square block per thread
//   2) use an even number of CPU cores
//   3) use all `num_threads` cores
//
// we have:
// TM * TN = T
// BM / TM = BN / TN
// then:
// TM = ((BM / BN) * T) ^ 0.5
//
float r = float(m) / n;
int nth_m = std::ceil(std::sqrt(r * nth));
int nth_n = 1;
for (; nth_m > 0; --nth_m) {
nth_n = nth / nth_m;
if (nth_m * nth_n == nth) {
break;
}
}
#if defined(_OPENMP)
#pragma omp parallel num_threads(nth)
{
int ith = omp_get_thread_num();
int ith_m = ith / nth_n;
int ith_n = ith % nth_n;
int thread_block_m = div_up(m, nth_m);
int thread_block_n = div_up(n, nth_n);
int begin_m = ith_m * thread_block_m;
int end_m = std::min(m, begin_m + thread_block_m);
int begin_n = ith_n * thread_block_n;
int end_n = std::min(n, begin_n + thread_block_n);
f(begin_m, end_m, begin_n, end_n);
}
#else
f(0, m, 0, n);
#endif
}
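// [NB] Illustrative sketch (not part of this diff) of the factorization above:
// given m x n blocks and nth threads, pick the largest nth_m <= ceil(sqrt((m / n) * nth))
// that divides nth exactly, so each thread owns a roughly square tile of blocks.
inline void thread_grid_example() {
  int m = 128, n = 128, nth = 16;  // illustrative sizes, not taken from this file
  float r = float(m) / n;
  int nth_m = int(std::ceil(std::sqrt(r * nth)));  // ceil(sqrt(16)) = 4
  int nth_n = 1;
  for (; nth_m > 0; --nth_m) {
    nth_n = nth / nth_m;
    if (nth_m * nth_n == nth) break;
  }
  TORCH_CHECK(nth_m == 4 && nth_n == 4);  // 4 x 4 thread grid, each tile 32 x 32 blocks
}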
// limit the max number of cache blocks
// when weights need to be pre-unpacked, e.g. fp8
#define MAX_CACHE_BLOCK_SIZE 4
template <typename T>
inline int get_cache_blocks(int chunk_size) {
// assume a 2MB L2 and use 50% of it
const int L2_size = 2048 * 1024 >> 1;
return std::max(1, int(L2_size / (chunk_size * sizeof(T))));
}
template <>
inline int get_cache_blocks<at::Float8_e4m3fn>(int chunk_size) {
// fp8 uses bf16 as accumulate type
int cache_block_size = get_cache_blocks<at::BFloat16>(chunk_size);
return std::min(MAX_CACHE_BLOCK_SIZE, cache_block_size);
}
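// [NB] Worked example of the sizing above (illustrative: K = 4096 and BLOCK_N = 32
// are assumptions, not values taken from this file). One bf16 panel of
// BLOCK_N * K = 131072 elements is 256KB, so 4 panels fit in the 1MB (50% of 2MB)
// L2 budget; the fp8 path unpacks to bf16 first and is additionally clamped to
// MAX_CACHE_BLOCK_SIZE.
inline void cache_block_example() {
  constexpr int K = 4096;   // assumed hidden size, for illustration only
  constexpr int BN = 32;    // assumed BLOCK_N, for illustration only
  int chunk_size = BN * K;  // 131072 bf16 elements == 256KB per B panel
  TORCH_CHECK(get_cache_blocks<at::BFloat16>(chunk_size) == 4);       // 1MB / 256KB
  TORCH_CHECK(get_cache_blocks<at::Float8_e4m3fn>(chunk_size) == 4);  // min(MAX_CACHE_BLOCK_SIZE, 4)
}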
// 2d sequential loop over the range [mb0, mb1) x [nb0, nb1)
template <typename T, typename func_t>
inline void loop_2d(int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1, int64_t chunk_size, const func_t& f) {
// number of cache blocks that fit in L2 for the innermost loop
int64_t cache_blocks_nb = get_cache_blocks<T>(chunk_size);
// loop order: [NB / cache_blocks_nb, MB, cache_blocks_nb]
// TODO: implement reverse order of [MB / cache_blocks_mb, NB, cache_blocks_mb]
for (int64_t nbb = nb0; nbb < nb1; nbb += cache_blocks_nb) {
for (int64_t mb = mb0; mb < mb1; ++mb) {
for (int64_t nb = nbb; nb < std::min(nbb + cache_blocks_nb, nb1); ++nb) {
f(mb, nb, nb - nbb);
}
}
}
}
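// [NB] Illustrative sketch (not in this diff) of the visit order produced by
// loop_2d when two column blocks fit in L2 (cache_blocks_nb = 2): all row blocks
// are swept inside each column strip, so an unpacked B panel parked at slot
// `nb_offset` is reused by every row block before the strip advances.
inline void loop_2d_order_example() {
  // chunk_size chosen (illustration only) so that exactly 2 bf16 panels fit in
  // the 1MB L2 budget, i.e. cache_blocks_nb = 2
  int64_t chunk_size = (2048 * 1024 >> 1) / (2 * sizeof(at::BFloat16));
  int64_t visit_mb[8], visit_nb[8], idx = 0;
  loop_2d<at::BFloat16>(0, 2, 0, 4, chunk_size, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
    (void)nb_offset;       // nb_offset (= nb - nbb) cycles 0,1 within each strip
    visit_mb[idx] = mb;    // and indexes the per-thread Btmp slot kept hot in L2
    visit_nb[idx] = nb;
    ++idx;
  });
  // order: (0,0)(0,1)(1,0)(1,1) then (0,2)(0,3)(1,2)(1,3)
  TORCH_CHECK(idx == 8 && visit_mb[2] == 1 && visit_nb[2] == 0);
}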
// data indexing for dimension collapse // data indexing for dimension collapse
template <typename T> template <typename T>
inline T data_index_init(T offset) { inline T data_index_init(T offset) {
......
...@@ -254,7 +254,7 @@ void tinygemm_kernel( ...@@ -254,7 +254,7 @@ void tinygemm_kernel(
return; return;
} }
// pattern: 1-4-16 // pattern: 1-4-16, N = 16, 32, 48, 64
constexpr int64_t BLOCK_M = 4; constexpr int64_t BLOCK_M = 4;
constexpr int64_t BLOCK_N = 64; constexpr int64_t BLOCK_N = 64;
const int64_t MB = div_up(M, BLOCK_M); const int64_t MB = div_up(M, BLOCK_M);
...@@ -268,35 +268,59 @@ void tinygemm_kernel( ...@@ -268,35 +268,59 @@ void tinygemm_kernel(
switch (mb_size << 4 | nb_size >> 4) { switch (mb_size << 4 | nb_size >> 4) {
// mb_size = 1 // mb_size = 1
case 0x11:
LAUNCH_TINYGEMM_KERNEL_NN(1, 16);
break;
case 0x12: case 0x12:
LAUNCH_TINYGEMM_KERNEL_NN(1, 32); LAUNCH_TINYGEMM_KERNEL_NN(1, 32);
break; break;
case 0x13:
LAUNCH_TINYGEMM_KERNEL_NN(1, 48);
break;
case 0x14: case 0x14:
LAUNCH_TINYGEMM_KERNEL_NN(1, 64); LAUNCH_TINYGEMM_KERNEL_NN(1, 64);
break; break;
// mb_size = 2 // mb_size = 2
case 0x21:
LAUNCH_TINYGEMM_KERNEL_NN(2, 16);
break;
case 0x22: case 0x22:
LAUNCH_TINYGEMM_KERNEL_NN(2, 32); LAUNCH_TINYGEMM_KERNEL_NN(2, 32);
break; break;
case 0x23:
LAUNCH_TINYGEMM_KERNEL_NN(2, 48);
break;
case 0x24: case 0x24:
LAUNCH_TINYGEMM_KERNEL_NN(2, 64); LAUNCH_TINYGEMM_KERNEL_NN(2, 64);
break; break;
// mb_size = 3 // mb_size = 3
case 0x31:
LAUNCH_TINYGEMM_KERNEL_NN(3, 16);
break;
case 0x32: case 0x32:
LAUNCH_TINYGEMM_KERNEL_NN(3, 32); LAUNCH_TINYGEMM_KERNEL_NN(3, 32);
break; break;
case 0x33:
LAUNCH_TINYGEMM_KERNEL_NN(3, 48);
break;
case 0x34: case 0x34:
LAUNCH_TINYGEMM_KERNEL_NN(3, 64); LAUNCH_TINYGEMM_KERNEL_NN(3, 64);
break; break;
// mb_size = 4 // mb_size = 4
case 0x41:
LAUNCH_TINYGEMM_KERNEL_NN(4, 16);
break;
case 0x42: case 0x42:
LAUNCH_TINYGEMM_KERNEL_NN(4, 32); LAUNCH_TINYGEMM_KERNEL_NN(4, 32);
break; break;
case 0x43:
LAUNCH_TINYGEMM_KERNEL_NN(4, 48);
break;
case 0x44: case 0x44:
LAUNCH_TINYGEMM_KERNEL_NN(4, 64); LAUNCH_TINYGEMM_KERNEL_NN(4, 64);
break; break;
default: default:
TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size"); TORCH_CHECK(false, "Unexpected block size, ", mb_size, " x ", nb_size);
} }
} }
} }
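// [NB] Illustrative sketch (not in this diff) of the dispatch key used above:
// `mb_size << 4 | nb_size >> 4` packs the M-tile height into the high nibble and
// the N-tile width (in multiples of 16 lanes) into the low nibble, which is why
// the newly added N = 16 and N = 48 remainder cases land on 0xY1 / 0xY3.
inline void tinygemm_case_key_example() {
  auto key = [](int64_t mb_size, int64_t nb_size) { return mb_size << 4 | nb_size >> 4; };
  TORCH_CHECK(key(2, 48) == 0x23);  // dispatches LAUNCH_TINYGEMM_KERNEL_NN(2, 48)
  TORCH_CHECK(key(4, 64) == 0x44);  // dispatches LAUNCH_TINYGEMM_KERNEL_NN(4, 64)
}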
...@@ -318,20 +342,15 @@ void weight_packed_linear_kernel_impl( ...@@ -318,20 +342,15 @@ void weight_packed_linear_kernel_impl(
const int64_t MB = div_up(M, BLOCK_M); const int64_t MB = div_up(M, BLOCK_M);
const int64_t NB = div_up(N, BLOCK_N); const int64_t NB = div_up(N, BLOCK_N);
// use avx512-bf16 when a) M is small; b) dtype is bfloat16, otherwise use amx c) N is small const bool use_brgemm = can_use_brgemm<scalar_t>(M);
const bool use_brgemm = (M > 4) || (!std::is_same_v<scalar_t, at::BFloat16>) || (N < 64);
// parallel on [MB, NB] // parallel on [MB, NB]
AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] { AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] {
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
int64_t mb{0}, nb{0};
data_index_init(begin, mb, MB, nb, NB);
// for brgemm, use float32 for accumulate // for brgemm, use float32 for accumulate
alignas(64) float Ctmp[BLOCK_M * BLOCK_N]; alignas(64) float Ctmp[BLOCK_M * BLOCK_N];
for (int64_t i = begin; i < end; ++i) { loop_2d<scalar_t>(mb0, mb1, nb0, nb1, BLOCK_N * K, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
UNUSED(i);
int64_t mb_start = mb * BLOCK_M; int64_t mb_start = mb * BLOCK_M;
int64_t mb_size = std::min(M - mb_start, BLOCK_M); int64_t mb_size = std::min(M - mb_start, BLOCK_M);
int64_t nb_start = nb * BLOCK_N; int64_t nb_start = nb * BLOCK_N;
...@@ -350,10 +369,7 @@ void weight_packed_linear_kernel_impl( ...@@ -350,10 +369,7 @@ void weight_packed_linear_kernel_impl(
/* ldb */ nb_size, /* ldb */ nb_size,
/* ldc */ out_strideM, /* ldc */ out_strideM,
/* brg */ use_brgemm); /* brg */ use_brgemm);
});
// move to the next index
data_index_step(mb, MB, nb, NB);
}
if (use_brgemm) { if (use_brgemm) {
at::native::cpublas::brgemm_release(); at::native::cpublas::brgemm_release();
......
...@@ -27,10 +27,10 @@ template <> ...@@ -27,10 +27,10 @@ template <>
inline bool can_use_brgemm<at::Half>(int M) { inline bool can_use_brgemm<at::Half>(int M) {
return true; return true;
} }
// TODO: add u8s8 brgemm, this requires PyTorch 2.7 // this requires PyTorch 2.7 or above
template <> template <>
inline bool can_use_brgemm<int8_t>(int M) { inline bool can_use_brgemm<int8_t>(int M) {
return false; return M > 4;
} }
template <> template <>
...@@ -198,4 +198,5 @@ void tinygemm_kernel( ...@@ -198,4 +198,5 @@ void tinygemm_kernel(
int64_t ldb, int64_t ldb,
int64_t ldc, int64_t ldc,
bool brg, bool brg,
int64_t block_size_K); int64_t block_size_K,
bool do_unpack = true);
...@@ -2,9 +2,6 @@ ...@@ -2,9 +2,6 @@
#include "gemm.h" #include "gemm.h"
#include "vec.h" #include "vec.h"
// we use 4x32 for BLOCK_M
#define BLOCK_SIZE_M_SCALE 4
namespace { namespace {
template <typename scalar_t> template <typename scalar_t>
...@@ -250,7 +247,8 @@ struct brgemm { ...@@ -250,7 +247,8 @@ struct brgemm {
int K, int K,
int lda, int lda,
int ldb, int ldb,
int ldc) { int ldc,
bool do_unpack = true) {
TORCH_CHECK(false, "struct brgemm: primary template not implemented!"); TORCH_CHECK(false, "struct brgemm: primary template not implemented!");
} }
}; };
...@@ -270,17 +268,20 @@ struct brgemm<at::BFloat16, at::Float8_e4m3fn, has_bias> { ...@@ -270,17 +268,20 @@ struct brgemm<at::BFloat16, at::Float8_e4m3fn, has_bias> {
int K, int K,
int lda, int lda,
int ldb, int ldb,
int ldc) { int ldc,
bool do_unpack = true) {
constexpr int BLOCK_N = block_size_n(); constexpr int BLOCK_N = block_size_n();
// [K, BLOCK_N] -> [K / 2, BLOCK_N * 2] // [K, BLOCK_N] -> [K / 2, BLOCK_N * 2]
const int ldb_tmp = BLOCK_N; const int ldb_tmp = BLOCK_N;
for (int k = 0; k < K; k += BLOCK_K) { if (do_unpack) {
int kb_size = std::min(BLOCK_K, K - k); for (int k = 0; k < K; k += BLOCK_K) {
int kb_size = std::min(BLOCK_K, K - k);
int idx = k >> 7; // k / BLOCK_K where BLOCK_K = 128 int idx = k >> 7; // k / BLOCK_K where BLOCK_K = 128
unpack_B(Btmp + k * ldb_tmp, B + k * ldb, N, kb_size, ldb, ldb_tmp, scale[idx]); unpack_B(Btmp + k * ldb_tmp, B + k * ldb, N, kb_size, ldb, ldb_tmp, scale[idx]);
}
} }
at::native::cpublas::brgemm(M, N, K, lda, ldb_tmp, BLOCK_N, /* add_C */ false, A, Btmp, Ctmp); at::native::cpublas::brgemm(M, N, K, lda, ldb_tmp, BLOCK_N, /* add_C */ false, A, Btmp, Ctmp);
...@@ -312,9 +313,11 @@ void tinygemm_kernel( ...@@ -312,9 +313,11 @@ void tinygemm_kernel(
int64_t ldb, int64_t ldb,
int64_t ldc, int64_t ldc,
bool brg, bool brg,
int64_t block_size_K) { int64_t block_size_K,
bool do_unpack = true) {
if (brg) { if (brg) {
brgemm<scalar_t, at::Float8_e4m3fn, has_bias>::apply(A, B, C, Btmp, Ctmp, bias, scale, M, N, K, lda, ldb, ldc); brgemm<scalar_t, at::Float8_e4m3fn, has_bias>::apply(
A, B, C, Btmp, Ctmp, bias, scale, M, N, K, lda, ldb, ldc, do_unpack);
return; return;
} }
...@@ -366,7 +369,7 @@ void fp8_scaled_mm_kernel_impl( ...@@ -366,7 +369,7 @@ void fp8_scaled_mm_kernel_impl(
int64_t block_size_N, int64_t block_size_N,
int64_t block_size_K, int64_t block_size_K,
int64_t buffer_size_per_thread) { int64_t buffer_size_per_thread) {
constexpr int64_t BLOCK_M = block_size_m() * BLOCK_SIZE_M_SCALE; constexpr int64_t BLOCK_M = block_size_m();
constexpr int64_t BLOCK_N = block_size_n(); constexpr int64_t BLOCK_N = block_size_n();
const int64_t MB = div_up(M, BLOCK_M); const int64_t MB = div_up(M, BLOCK_M);
const int64_t NB = div_up(N, BLOCK_N); const int64_t NB = div_up(N, BLOCK_N);
...@@ -378,16 +381,12 @@ void fp8_scaled_mm_kernel_impl( ...@@ -378,16 +381,12 @@ void fp8_scaled_mm_kernel_impl(
// parallel on [MB, NB] // parallel on [MB, NB]
AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] { AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] {
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
int64_t mb{0}, nb{0}; int tid = get_thread_num();
data_index_init(begin, mb, MB, nb, NB);
int tid = at::get_thread_num();
scalar_t* __restrict__ Btmp = buffer + tid * buffer_size_per_thread; scalar_t* __restrict__ Btmp = buffer + tid * buffer_size_per_thread;
float* __restrict__ Ctmp = (float*)((void*)(Btmp + BLOCK_N * K)); float* __restrict__ Ctmp = (float*)((void*)(Btmp + MAX_CACHE_BLOCK_SIZE * BLOCK_N * K));
for (int64_t i = begin; i < end; ++i) { loop_2d<at::Float8_e4m3fn>(mb0, mb1, nb0, nb1, BLOCK_N * K, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
UNUSED(i);
const float* scale_ptr = scales2 + (nb / blocks_n_per_group) * scale_size_K; const float* scale_ptr = scales2 + (nb / blocks_n_per_group) * scale_size_K;
int64_t mb_start = mb * BLOCK_M; int64_t mb_start = mb * BLOCK_M;
...@@ -395,11 +394,14 @@ void fp8_scaled_mm_kernel_impl( ...@@ -395,11 +394,14 @@ void fp8_scaled_mm_kernel_impl(
int64_t nb_start = nb * BLOCK_N; int64_t nb_start = nb * BLOCK_N;
int64_t nb_size = std::min(N - nb_start, BLOCK_N); int64_t nb_size = std::min(N - nb_start, BLOCK_N);
// only do unpacking for the first row
bool do_unpack = (mb == mb0);
tinygemm_kernel<scalar_t, has_bias>( tinygemm_kernel<scalar_t, has_bias>(
/* A */ mat1 + mb_start * mat1_strideM, /* A */ mat1 + mb_start * mat1_strideM,
/* B */ mat2 + nb_start * K, // nb * BLOCK_N * K /* B */ mat2 + nb_start * K, // nb * BLOCK_N * K
/* C */ out + mb_start * out_strideM + nb_start, /* C */ out + mb_start * out_strideM + nb_start,
/* Btmp */ Btmp, /* Btmp */ Btmp + nb_offset * BLOCK_N * K,
/* Ctmp */ Ctmp, /* Ctmp */ Ctmp,
/* scale */ scale_ptr, /* scale */ scale_ptr,
/* bias */ bias + nb_start, /* bias */ bias + nb_start,
...@@ -410,11 +412,9 @@ void fp8_scaled_mm_kernel_impl( ...@@ -410,11 +412,9 @@ void fp8_scaled_mm_kernel_impl(
/* ldb */ nb_size, /* ldb */ nb_size,
/* ldc */ out_strideM, /* ldc */ out_strideM,
/* brg */ use_brgemm, /* brg */ use_brgemm,
/* block_size_K */ block_size_K); /* block_size_K */ block_size_K,
/* do_unpack */ do_unpack);
// move to the next index });
data_index_step(mb, MB, nb, NB);
}
if (use_brgemm) { if (use_brgemm) {
at::native::cpublas::brgemm_release(); at::native::cpublas::brgemm_release();
...@@ -441,8 +441,10 @@ void tinygemm_kernel( ...@@ -441,8 +441,10 @@ void tinygemm_kernel(
int64_t ldb, int64_t ldb,
int64_t ldc, int64_t ldc,
bool brg, bool brg,
int64_t block_size_K) { int64_t block_size_K,
tinygemm_kernel<scalar_t, false>(A, B, C, Btmp, Ctmp, scale, nullptr, M, N, K, lda, ldb, ldc, brg, block_size_K); bool do_unpack) {
tinygemm_kernel<scalar_t, false>(
A, B, C, Btmp, Ctmp, scale, nullptr, M, N, K, lda, ldb, ldc, brg, block_size_K, do_unpack);
} }
#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE) \ #define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE) \
...@@ -460,7 +462,8 @@ void tinygemm_kernel( ...@@ -460,7 +462,8 @@ void tinygemm_kernel(
int64_t ldb, \ int64_t ldb, \
int64_t ldc, \ int64_t ldc, \
bool brg, \ bool brg, \
int64_t block_size_K) int64_t block_size_K, \
bool do_unpack)
INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16); INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16);
INSTANTIATE_TINYGEMM_TEMPLATE(at::Half); INSTANTIATE_TINYGEMM_TEMPLATE(at::Half);
...@@ -495,7 +498,7 @@ at::Tensor fp8_scaled_mm_cpu( ...@@ -495,7 +498,7 @@ at::Tensor fp8_scaled_mm_cpu(
int64_t block_size_N = block_size[0]; int64_t block_size_N = block_size[0];
int64_t block_size_K = block_size[1]; int64_t block_size_K = block_size[1];
constexpr int64_t BLOCK_M = block_size_m() * BLOCK_SIZE_M_SCALE; constexpr int64_t BLOCK_M = block_size_m();
constexpr int64_t BLOCK_N = block_size_n(); constexpr int64_t BLOCK_N = block_size_n();
TORCH_CHECK(block_size_N % BLOCK_N == 0, "fp8_scaled_mm_cpu: expect block_size_N to be multiples of BLOCK_N"); TORCH_CHECK(block_size_N % BLOCK_N == 0, "fp8_scaled_mm_cpu: expect block_size_N to be multiples of BLOCK_N");
TORCH_CHECK(block_size_K == BLOCK_K, "fp8_scaled_mm_cpu: expect block_size_K equals to BLOCK_K"); TORCH_CHECK(block_size_K == BLOCK_K, "fp8_scaled_mm_cpu: expect block_size_K equals to BLOCK_K");
...@@ -523,7 +526,7 @@ at::Tensor fp8_scaled_mm_cpu( ...@@ -523,7 +526,7 @@ at::Tensor fp8_scaled_mm_cpu(
// Btmp : [T, BLOCK_N * K] // Btmp : [T, BLOCK_N * K]
// Ctmp : [T, BLOCK_M * BLOCK_N] // Ctmp : [T, BLOCK_M * BLOCK_N]
int num_threads = at::get_num_threads(); int num_threads = at::get_num_threads();
int64_t size_per_thread = BLOCK_N * K + BLOCK_M * BLOCK_N * 2; int64_t size_per_thread = MAX_CACHE_BLOCK_SIZE * BLOCK_N * K + BLOCK_M * BLOCK_N * 2;
auto buffer = at::empty({num_threads, size_per_thread}, mat1.options()); auto buffer = at::empty({num_threads, size_per_thread}, mat1.options());
AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "fp8_scaled_mm_kernel_impl", [&] { AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "fp8_scaled_mm_kernel_impl", [&] {
......
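// [NB] Sketch of the per-thread scratch layout after this change (not code from
// the diff):
//
//   buffer[t] = [ Btmp : MAX_CACHE_BLOCK_SIZE * BLOCK_N * K  (scalar_t)
//               | Ctmp : BLOCK_M * BLOCK_N                   (float, hence the `* 2`) ]
//
// loop_2d gives each column block in a strip its own Btmp slot (via nb_offset),
// which is what lets tinygemm_kernel skip unpack_B for every row block after the
// first, i.e. do_unpack = (mb == mb0).
inline int64_t fp8_scratch_size_per_thread(int64_t BLOCK_M_, int64_t BLOCK_N_, int64_t K_) {
  // mirrors `size_per_thread` above, counted in scalar_t elements
  return MAX_CACHE_BLOCK_SIZE * BLOCK_N_ * K_ + BLOCK_M_ * BLOCK_N_ * 2;
}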
...@@ -4,6 +4,61 @@ ...@@ -4,6 +4,61 @@
namespace { namespace {
template <typename scalar_t, bool has_bias, int BLOCK_N>
struct scale_C {
static inline void apply(
scalar_t* __restrict__ C,
const int32_t* __restrict__ Ctmp,
const int32_t* __restrict__ Bcomp,
const float* __restrict__ bias,
float As,
const float* __restrict__ Bs) {
TORCH_CHECK(false, "scale_C: scalar path not implemented!");
}
};
#if defined(CPU_CAPABILITY_AVX512)
template <bool has_bias, int BLOCK_N>
struct scale_C<at::BFloat16, has_bias, BLOCK_N> {
static inline void apply(
at::BFloat16* __restrict__ C,
const int32_t* __restrict__ Ctmp,
const int32_t* __restrict__ Bcomp,
const float* __restrict__ bias,
float As,
const float* __restrict__ Bs) {
constexpr int COLS = BLOCK_N / 16;
static_assert(COLS % 2 == 0);
__m512 vc[COLS];
__m512 vd0 = _mm512_set1_ps(As);
auto compute = [&](auto col) {
__m512 vd1 = _mm512_loadu_ps(Bs + col * 16);
__m512i vcomp = _mm512_loadu_si512(Bcomp + col * 16);
__m512i vc32 = _mm512_loadu_si512(Ctmp + col * 16);
vc[col] = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc32, vcomp));
if constexpr (has_bias) {
__m512 vbias = _mm512_loadu_ps(bias + col * 16);
vc[col] = _mm512_fmadd_ps(_mm512_mul_ps(vc[col], vd0), vd1, vbias);
} else {
vc[col] = _mm512_mul_ps(_mm512_mul_ps(vc[col], vd0), vd1);
}
};
Unroll<COLS>{}(compute);
auto storec = [&](auto col) {
// for COLS = 2, 4 use 512bit store
if constexpr (col % 2 == 0) {
_mm512_storeu_si512(
reinterpret_cast<__m512i*>((C + col * 16)), (__m512i)(_mm512_cvtne2ps_pbh(vc[col + 1], vc[col + 0])));
}
};
Unroll<COLS>{}(storec);
}
};
#endif
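// [NB] Scalar reference for the AVX512 epilogue above (a readability sketch, not
// code from this diff): for one output row, the int32 brgemm accumulator is
// corrected by the B compensation term, scaled by the per-row activation scale As
// and the per-column weight scale Bs, optionally biased, and stored as bf16.
template <bool has_bias>
inline void scale_C_ref(
    at::BFloat16* __restrict__ C,
    const int32_t* __restrict__ Ctmp,
    const int32_t* __restrict__ Bcomp,
    const float* __restrict__ bias,
    float As,
    const float* __restrict__ Bs,
    int N) {
  for (int n = 0; n < N; ++n) {
    float v = static_cast<float>(Ctmp[n] - Bcomp[n]) * As * Bs[n];
    if constexpr (has_bias) {
      v += bias[n];
    }
    C[n] = static_cast<at::BFloat16>(v);
  }
}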
template <typename scalar_t, bool has_bias, int BLOCK_M, int BLOCK_N> template <typename scalar_t, bool has_bias, int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_nn { struct tinygemm_kernel_nn {
static inline void apply( static inline void apply(
...@@ -169,6 +224,17 @@ void tinygemm_kernel( ...@@ -169,6 +224,17 @@ void tinygemm_kernel(
// B compensation // B compensation
const int32_t* Bcomp = reinterpret_cast<const int32_t*>(B + block_size_n() * K); const int32_t* Bcomp = reinterpret_cast<const int32_t*>(B + block_size_n() * K);
if (brg) {
constexpr int BLOCK_N = block_size_n();
at::native::cpublas::brgemm(M, N, K, lda, ldb, BLOCK_N, /* add_C */ false, A, B, Ctmp);
// apply compensation and scale
for (int64_t m = 0; m < M; ++m) {
scale_C<scalar_t, has_bias, BLOCK_N>::apply(C + m * ldc, Ctmp + m * BLOCK_N, Bcomp, bias, As[m], Bs);
}
return;
}
// pattern: 1-4-16 // pattern: 1-4-16
constexpr int64_t BLOCK_M = 4; constexpr int64_t BLOCK_M = 4;
constexpr int64_t BLOCK_N = 64; constexpr int64_t BLOCK_N = 64;
...@@ -233,22 +299,17 @@ void int8_scaled_mm_kernel_impl( ...@@ -233,22 +299,17 @@ void int8_scaled_mm_kernel_impl(
const int64_t MB = div_up(M, BLOCK_M); const int64_t MB = div_up(M, BLOCK_M);
const int64_t NB = div_up(N, BLOCK_N); const int64_t NB = div_up(N, BLOCK_N);
// TODO: brgemm u8s8 depends on PyTorch 2.7 release. const bool use_brgemm = can_use_brgemm<int8_t>(M);
const bool use_brgemm = false;
// K + 4 after compensation // K + 4 after compensation
const int64_t packed_row_size = get_row_size<int8_t>(K); const int64_t packed_row_size = get_row_size<int8_t>(K);
AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] { AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] {
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
int64_t mb{0}, nb{0};
data_index_init(begin, mb, MB, nb, NB);
// for brgemm, use int32_t for accumulate // for brgemm, use int32_t for accumulate
alignas(64) int32_t Ctmp[BLOCK_M * BLOCK_N]; alignas(64) int32_t Ctmp[BLOCK_M * BLOCK_N];
for (int i = begin; i < end; ++i) { loop_2d<int8_t>(mb0, mb1, nb0, nb1, BLOCK_N * K, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
UNUSED(i);
int mb_start = mb * BLOCK_M; int mb_start = mb * BLOCK_M;
int mb_size = std::min(M - mb_start, BLOCK_M); int mb_size = std::min(M - mb_start, BLOCK_M);
int nb_start = nb * BLOCK_N; int nb_start = nb * BLOCK_N;
...@@ -269,10 +330,7 @@ void int8_scaled_mm_kernel_impl( ...@@ -269,10 +330,7 @@ void int8_scaled_mm_kernel_impl(
/* ldb */ nb_size, /* ldb */ nb_size,
/* ldc */ N, /* ldc */ N,
/* brg */ use_brgemm); /* brg */ use_brgemm);
});
// move to the next index
data_index_step(mb, MB, nb, NB);
}
if (use_brgemm) { if (use_brgemm) {
at::native::cpublas::brgemm_release(); at::native::cpublas::brgemm_release();
......
...@@ -579,36 +579,31 @@ void fused_experts_kernel_impl( ...@@ -579,36 +579,31 @@ void fused_experts_kernel_impl(
const int64_t stride_e = 2 * N * K; const int64_t stride_e = 2 * N * K;
const int64_t stride_n = K; const int64_t stride_n = K;
int64_t avg_M = std::max(int64_t(1), M * topk / E);
const bool use_brgemm = can_use_brgemm<scalar_t>(avg_M);
// here we only parallel on half of 2N to fuse silu_and_mul with gemm // here we only parallel on half of 2N to fuse silu_and_mul with gemm
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
// get local pointers // get local pointers
int tid = at::get_thread_num(); int tid = get_thread_num();
scalar_t* __restrict__ A = A_tmp + tid * BLOCK_M * K; scalar_t* __restrict__ A = A_tmp + tid * BLOCK_M * K;
float* __restrict__ C0 = C_tmp + tid * 2 * BLOCK_M * BLOCK_N; float* __restrict__ C0 = C_tmp + tid * 2 * BLOCK_M * BLOCK_N;
float* __restrict__ C1 = C0 + BLOCK_M * BLOCK_N; float* __restrict__ C1 = C0 + BLOCK_M * BLOCK_N;
bool is_brgemm_used = false; loop_2d<scalar_t>(mb0, mb1, nb0, nb1, BLOCK_N * K * 2, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
// nb_upper from top half and nb_lower from bottom half
for (int64_t i = begin; i < end; ++i) { int64_t nb_upper = nb, nb_lower = nb + NB;
int64_t mb = i / NB; int64_t n_size = std::min(N - nb * BLOCK_N, BLOCK_N);
int64_t nb = i % NB;
// nb0 from top half and nb1 from bottom half
int64_t nb0 = nb, nb1 = nb + NB;
int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N);
// B shape [K, n_size] in vnni format // B shape [K, n_size] in vnni format
int32_t expert_id = expert_ids[mb]; int32_t expert_id = expert_ids[mb];
const scalar_t* __restrict__ B0 = packed_w1 + expert_id * stride_e + nb0 * BLOCK_N * stride_n; const scalar_t* __restrict__ B0 = packed_w1 + expert_id * stride_e + nb_upper * BLOCK_N * stride_n;
const scalar_t* __restrict__ B1 = packed_w1 + expert_id * stride_e + nb1 * BLOCK_N * stride_n; const scalar_t* __restrict__ B1 = packed_w1 + expert_id * stride_e + nb_lower * BLOCK_N * stride_n;
// 1.a load A // 1.a load A
const int32_t* A_ids = sorted_ids + mb * BLOCK_M; const int32_t* A_ids = sorted_ids + mb * BLOCK_M;
int64_t m_size = offsets[mb + 1] - offsets[mb]; int64_t m_size = offsets[mb + 1] - offsets[mb];
const bool use_brgemm = can_use_brgemm<scalar_t>(m_size);
is_brgemm_used = is_brgemm_used || use_brgemm;
for (int64_t m = 0; m < m_size; ++m) { for (int64_t m = 0; m < m_size; ++m) {
int32_t index = A_ids[m] / topk; int32_t index = A_ids[m] / topk;
copy_stub(A + m * K, input + index * K, K); copy_stub(A + m * K, input + index * K, K);
...@@ -659,9 +654,9 @@ void fused_experts_kernel_impl( ...@@ -659,9 +654,9 @@ void fused_experts_kernel_impl(
/* ldb */ n_size, /* ldb */ n_size,
/* ldc */ N); /* ldc */ N);
} }
} });
if (is_brgemm_used) { if (use_brgemm) {
at::native::cpublas::brgemm_release(); at::native::cpublas::brgemm_release();
} }
}); });
...@@ -676,24 +671,16 @@ void fused_experts_kernel_impl( ...@@ -676,24 +671,16 @@ void fused_experts_kernel_impl(
const int64_t stride_oc = IC; const int64_t stride_oc = IC;
// parallel on [MB2, NB2] // parallel on [MB2, NB2]
at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) { parallel_2d(MB2, NB2, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
// get local pointers // get local pointers
int tid = at::get_thread_num(); int tid = get_thread_num();
// we won't be using C1 for gemm2 // we won't be using C1 for gemm2
float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N; float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N;
bool is_brgemm_used = false; loop_2d<scalar_t>(mb0, mb1, nb0, nb1, BLOCK_N * IC, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
for (int64_t i = begin; i < end; ++i) {
int64_t mb = i / NB2;
int64_t nb = i % NB2;
int64_t m_size = offsets[mb + 1] - offsets[mb]; int64_t m_size = offsets[mb + 1] - offsets[mb];
int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N); int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
const bool use_brgemm = can_use_brgemm<scalar_t>(m_size);
is_brgemm_used = is_brgemm_used || use_brgemm;
// A ptr from ic1 of [M * topk, N] in sorted order // A ptr from ic1 of [M * topk, N] in sorted order
// so as to avoid copy A to tmp buffer again // so as to avoid copy A to tmp buffer again
const scalar_t* __restrict__ A = ic1 + offsets[mb] * N; const scalar_t* __restrict__ A = ic1 + offsets[mb] * N;
...@@ -736,9 +723,9 @@ void fused_experts_kernel_impl( ...@@ -736,9 +723,9 @@ void fused_experts_kernel_impl(
float weight = topk_weights[index]; float weight = topk_weights[index];
copy_mul_stub(ic2 + index * K + nb * BLOCK_N, C + m * BLOCK_N, weight, n_size); copy_mul_stub(ic2 + index * K + nb * BLOCK_N, C + m * BLOCK_N, weight, n_size);
} }
} });
if (is_brgemm_used) { if (use_brgemm) {
at::native::cpublas::brgemm_release(); at::native::cpublas::brgemm_release();
} }
}); });
...@@ -776,36 +763,27 @@ void shared_expert_kernel_impl( ...@@ -776,36 +763,27 @@ void shared_expert_kernel_impl(
TORCH_CHECK(N % BLOCK_N == 0, "Fixme when N is not multiples of ", BLOCK_N); TORCH_CHECK(N % BLOCK_N == 0, "Fixme when N is not multiples of ", BLOCK_N);
const int64_t stride_n = K; const int64_t stride_n = K;
const bool use_brgemm = can_use_brgemm<scalar_t>(M);
// here we only parallel on half of 2N to fuse silu_and_mul with gemm // here we only parallel on half of 2N to fuse silu_and_mul with gemm
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
// get local pointers // get local pointers
int tid = at::get_thread_num(); int tid = get_thread_num();
float* __restrict__ C0 = C_tmp + tid * 2 * BLOCK_M * BLOCK_N; float* __restrict__ C0 = C_tmp + tid * 2 * BLOCK_M * BLOCK_N;
float* __restrict__ C1 = C0 + BLOCK_M * BLOCK_N; float* __restrict__ C1 = C0 + BLOCK_M * BLOCK_N;
bool is_brgemm_used = false; loop_2d<scalar_t>(mb0, mb1, nb0, nb1, BLOCK_N * K * 2, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
// nb_upper from top half and nb_lower from bottom half
for (int64_t i = begin; i < end; ++i) { int64_t nb_upper = nb, nb_lower = nb + NB;
int64_t mb = i / NB; int64_t n_size = std::min(N - nb * BLOCK_N, BLOCK_N);
int64_t nb = i % NB;
// nb0 from top half and nb1 from bottom half
int64_t nb0 = nb, nb1 = nb + NB;
int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N);
int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M); int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
// int64_t mb_start = mb * BLOCK_M;
// int64_t mb_size = std::min(M - mb_start, BLOCK_M);
// A shape [m_size, K] // A shape [m_size, K]
const scalar_t* A = input + mb * BLOCK_M * K; const scalar_t* A = input + mb * BLOCK_M * K;
// B shape [K, n_size] in vnni format // B shape [K, n_size] in vnni format
const scalar_t* __restrict__ B0 = packed_w1 + nb0 * BLOCK_N * stride_n; const scalar_t* __restrict__ B0 = packed_w1 + nb_upper * BLOCK_N * stride_n;
const scalar_t* __restrict__ B1 = packed_w1 + nb1 * BLOCK_N * stride_n; const scalar_t* __restrict__ B1 = packed_w1 + nb_lower * BLOCK_N * stride_n;
const bool use_brgemm = can_use_brgemm<scalar_t>(m_size);
is_brgemm_used = is_brgemm_used || use_brgemm;
if (use_brgemm) { if (use_brgemm) {
// 1.b gemm: C0 = A @ B0 // 1.b gemm: C0 = A @ B0
...@@ -850,9 +828,9 @@ void shared_expert_kernel_impl( ...@@ -850,9 +828,9 @@ void shared_expert_kernel_impl(
/* ldb */ n_size, /* ldb */ n_size,
/* ldc */ N); /* ldc */ N);
} }
} });
if (is_brgemm_used) { if (use_brgemm) {
at::native::cpublas::brgemm_release(); at::native::cpublas::brgemm_release();
} }
}); });
...@@ -866,24 +844,16 @@ void shared_expert_kernel_impl( ...@@ -866,24 +844,16 @@ void shared_expert_kernel_impl(
const int64_t stride_oc = IC; const int64_t stride_oc = IC;
// parallel on [MB2, NB2] // parallel on [MB2, NB2]
at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) { parallel_2d(MB2, NB2, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
// get local pointers // get local pointers
int tid = at::get_thread_num(); int tid = get_thread_num();
// we won't be using C1 for gemm2 // we won't be using C1 for gemm2
float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N; float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N;
bool is_brgemm_used = false; loop_2d<scalar_t>(mb0, mb1, nb0, nb1, BLOCK_N * IC, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
for (int64_t i = begin; i < end; ++i) {
int64_t mb = i / NB2;
int64_t nb = i % NB2;
int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M); int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N); int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
const bool use_brgemm = can_use_brgemm<scalar_t>(m_size);
is_brgemm_used = is_brgemm_used || use_brgemm;
// A shape [m_size, IC] // A shape [m_size, IC]
const scalar_t* __restrict__ A = ic1 + mb * BLOCK_M * N; const scalar_t* __restrict__ A = ic1 + mb * BLOCK_M * N;
...@@ -922,9 +892,9 @@ void shared_expert_kernel_impl( ...@@ -922,9 +892,9 @@ void shared_expert_kernel_impl(
for (int64_t m = 0; m < m_size; ++m) { for (int64_t m = 0; m < m_size; ++m) {
add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, n_size); add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, n_size);
} }
} });
if (is_brgemm_used) { if (use_brgemm) {
at::native::cpublas::brgemm_release(); at::native::cpublas::brgemm_release();
} }
}); });
...@@ -1086,7 +1056,7 @@ at::Tensor fused_experts_cpu( ...@@ -1086,7 +1056,7 @@ at::Tensor fused_experts_cpu(
// //
// for fp8 w8a16: // for fp8 w8a16:
// 7. intermediate_cache0 : [M * topk, 2N] // 7. intermediate_cache0 : [M * topk, 2N]
// 8. B_tmp : [T, BLOCK_N, std::max(K, N)] // 8. B_tmp : [T, MAX_CACHE_BLOCK_SIZE, BLOCK_N, std::max(K, N)]
// //
int64_t buffer_size_nbytes = M * topk * N * 2 + M * topk * K * 2 + int64_t buffer_size_nbytes = M * topk * N * 2 + M * topk * K * 2 +
num_threads * BLOCK_M * K * (use_int8_w8a8 ? 1 : 2) + num_threads * BLOCK_M * K * (use_int8_w8a8 ? 1 : 2) +
...@@ -1096,7 +1066,7 @@ at::Tensor fused_experts_cpu( ...@@ -1096,7 +1066,7 @@ at::Tensor fused_experts_cpu(
buffer_size_nbytes += std::max(M * K, M * topk * N) + M * topk * sizeof(float); buffer_size_nbytes += std::max(M * K, M * topk * N) + M * topk * sizeof(float);
} }
if (use_fp8_w8a16) { if (use_fp8_w8a16) {
buffer_size_nbytes += M * topk * 2 * N * 2 + num_threads * BLOCK_N * std::max(K, N) * 2; buffer_size_nbytes += M * topk * 2 * N * 2 + num_threads * MAX_CACHE_BLOCK_SIZE * BLOCK_N * std::max(K, N) * 2;
} }
auto buffer2 = at::empty({buffer_size_nbytes}, hidden_states.options().dtype(at::kChar)); auto buffer2 = at::empty({buffer_size_nbytes}, hidden_states.options().dtype(at::kChar));
...@@ -1268,7 +1238,7 @@ at::Tensor shared_expert_cpu( ...@@ -1268,7 +1238,7 @@ at::Tensor shared_expert_cpu(
// //
// for fp8 w8a16: // for fp8 w8a16:
// 5. intermediate_cache0 : [M, 2N] // 5. intermediate_cache0 : [M, 2N]
// 6. B_tmp: [T, BLOCK_M, max(K, N)] // 6. B_tmp: [T, MAX_CACHE_BLOCK_SIZE, BLOCK_M, max(K, N)]
// //
int num_threads = at::get_num_threads(); int num_threads = at::get_num_threads();
int64_t buffer_size_nbytes = M * N * 2 + num_threads * 2 * BLOCK_M * BLOCK_N * sizeof(float); int64_t buffer_size_nbytes = M * N * 2 + num_threads * 2 * BLOCK_M * BLOCK_N * sizeof(float);
...@@ -1277,7 +1247,7 @@ at::Tensor shared_expert_cpu( ...@@ -1277,7 +1247,7 @@ at::Tensor shared_expert_cpu(
buffer_size_nbytes += std::max(M * K, M * N) + M * sizeof(float); buffer_size_nbytes += std::max(M * K, M * N) + M * sizeof(float);
} }
if (use_fp8_w8a16) { if (use_fp8_w8a16) {
buffer_size_nbytes += M * 2 * N * 2 + num_threads * BLOCK_M * std::max(K, N) * 2; buffer_size_nbytes += M * 2 * N * 2 + num_threads * MAX_CACHE_BLOCK_SIZE * BLOCK_M * std::max(K, N) * 2;
} }
auto buffer = at::empty({buffer_size_nbytes}, hidden_states.options().dtype(at::kChar)); auto buffer = at::empty({buffer_size_nbytes}, hidden_states.options().dtype(at::kChar));
......
...@@ -174,18 +174,18 @@ void fused_experts_fp8_kernel_impl( ...@@ -174,18 +174,18 @@ void fused_experts_fp8_kernel_impl(
const int64_t stride_e = 2 * N * K; const int64_t stride_e = 2 * N * K;
const int64_t stride_n = K; const int64_t stride_n = K;
int64_t avg_M = std::max(int64_t(1), M * topk / E);
const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(avg_M);
int64_t B_tmp_size_per_thread = MAX_CACHE_BLOCK_SIZE * BLOCK_N * std::max(K, N);
// here we only parallel on half of 2N to fuse silu_and_mul with gemm // here we only parallel on half of 2N to fuse silu_and_mul with gemm
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
// get local pointers // get local pointers
int tid = at::get_thread_num(); int tid = get_thread_num();
scalar_t* __restrict__ A = A_tmp + tid * BLOCK_M * K; scalar_t* __restrict__ A = A_tmp + tid * BLOCK_M * K;
bool is_brgemm_used = false; loop_2d<at::Float8_e4m3fn>(mb0, mb1, nb0, nb1, BLOCK_N * K, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
for (int64_t i = begin; i < end; ++i) {
int64_t mb = i / NB;
int64_t nb = i % NB;
int64_t n_size = std::min(2 * N - nb * BLOCK_N, BLOCK_N); int64_t n_size = std::min(2 * N - nb * BLOCK_N, BLOCK_N);
// B shape [K, n_size] in vnni format // B shape [K, n_size] in vnni format
...@@ -194,13 +194,14 @@ void fused_experts_fp8_kernel_impl( ...@@ -194,13 +194,14 @@ void fused_experts_fp8_kernel_impl(
const float* __restrict__ Bs = const float* __restrict__ Bs =
w1s + expert_id * scale_size_N * scale_size_K + (nb / blocks_n_per_group) * scale_size_K; w1s + expert_id * scale_size_N * scale_size_K + (nb / blocks_n_per_group) * scale_size_K;
// do unpacking for the first row or a new expert
int32_t pre_expert_id = mb == 0 ? -1 : expert_ids[mb - 1];
bool do_unpack = (mb == mb0) || (expert_id != pre_expert_id);
// 1.a load A // 1.a load A
const int32_t* A_ids = sorted_ids + mb * BLOCK_M; const int32_t* A_ids = sorted_ids + mb * BLOCK_M;
int64_t m_size = offsets[mb + 1] - offsets[mb]; int64_t m_size = offsets[mb + 1] - offsets[mb];
const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(m_size);
is_brgemm_used = is_brgemm_used || use_brgemm;
for (int64_t m = 0; m < m_size; ++m) { for (int64_t m = 0; m < m_size; ++m) {
int32_t index = A_ids[m] / topk; int32_t index = A_ids[m] / topk;
copy_stub(A + m * K, input + index * K, K); copy_stub(A + m * K, input + index * K, K);
...@@ -211,7 +212,7 @@ void fused_experts_fp8_kernel_impl( ...@@ -211,7 +212,7 @@ void fused_experts_fp8_kernel_impl(
/* A */ A, /* A */ A,
/* B */ B, /* B */ B,
/* C */ ic0 + offset * 2 * N + nb * BLOCK_N, /* C */ ic0 + offset * 2 * N + nb * BLOCK_N,
/* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N), /* Btmp */ B_tmp + tid * B_tmp_size_per_thread + nb_offset * BLOCK_N * K,
/* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N, /* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N,
/* scale */ Bs, /* scale */ Bs,
/* M */ m_size, /* M */ m_size,
...@@ -221,10 +222,11 @@ void fused_experts_fp8_kernel_impl( ...@@ -221,10 +222,11 @@ void fused_experts_fp8_kernel_impl(
/* ldb */ n_size, /* ldb */ n_size,
/* ldc */ 2 * N, /* ldc */ 2 * N,
/* brg */ use_brgemm, /* brg */ use_brgemm,
/* block_size_K */ block_size_K); /* block_size_K */ block_size_K,
} /* do_unpack */ do_unpack);
});
if (is_brgemm_used) { if (use_brgemm) {
at::native::cpublas::brgemm_release(); at::native::cpublas::brgemm_release();
} }
}); });
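// [NB] Why the do_unpack flag above is safe to cache (reasoning sketch, not code
// from the diff): within a thread's range, loop_2d sweeps every row block mb for a
// fixed strip of column blocks, and each nb in the strip owns its own Btmp slot
// (nb_offset). The unpacked bf16 panel depends only on (expert_id, nb), so it is
// rebuilt at the start of the range (mb == mb0) or when expert_ids[mb] changes
// from the previous row block, and reused otherwise.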
...@@ -248,22 +250,14 @@ void fused_experts_fp8_kernel_impl( ...@@ -248,22 +250,14 @@ void fused_experts_fp8_kernel_impl(
const int64_t stride_oc = IC; const int64_t stride_oc = IC;
// parallel on [MB2, NB2] // parallel on [MB2, NB2]
at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) { parallel_2d(MB2, NB2, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
int tid = at::get_thread_num(); int tid = get_thread_num();
alignas(64) scalar_t C[BLOCK_M * BLOCK_K]; alignas(64) scalar_t C[BLOCK_M * BLOCK_K];
bool is_brgemm_used = false; loop_2d<at::Float8_e4m3fn>(mb0, mb1, nb0, nb1, BLOCK_N * IC, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
for (int64_t i = begin; i < end; ++i) {
int64_t mb = i / NB2;
int64_t nb = i % NB2;
int64_t m_size = offsets[mb + 1] - offsets[mb]; int64_t m_size = offsets[mb + 1] - offsets[mb];
int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N); int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(m_size);
is_brgemm_used = is_brgemm_used || use_brgemm;
// A ptr from ic1 of [M * topk, N] in sorted order // A ptr from ic1 of [M * topk, N] in sorted order
// so as to avoid copy A to tmp buffer again // so as to avoid copy A to tmp buffer again
const scalar_t* __restrict__ A = ic1 + offsets[mb] * N; const scalar_t* __restrict__ A = ic1 + offsets[mb] * N;
...@@ -275,11 +269,15 @@ void fused_experts_fp8_kernel_impl( ...@@ -275,11 +269,15 @@ void fused_experts_fp8_kernel_impl(
const float* __restrict__ Bs = const float* __restrict__ Bs =
w2s + expert_id * scale_size_N * scale_size_K + (nb / blocks_n_per_group) * scale_size_K; w2s + expert_id * scale_size_N * scale_size_K + (nb / blocks_n_per_group) * scale_size_K;
// do unpacking for the first row or a new expert
int32_t pre_expert_id = mb == 0 ? -1 : expert_ids[mb - 1];
bool do_unpack = (mb == mb0) || (expert_id != pre_expert_id);
tinygemm_kernel<scalar_t>( tinygemm_kernel<scalar_t>(
/* A */ A, /* A */ A,
/* B */ B, /* B */ B,
/* C */ C, /* C */ C,
/* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N), /* Btmp */ B_tmp + tid * B_tmp_size_per_thread + nb_offset * BLOCK_N * IC,
/* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N, /* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N,
/* scale */ Bs, /* scale */ Bs,
/* M */ m_size, /* M */ m_size,
...@@ -289,7 +287,8 @@ void fused_experts_fp8_kernel_impl( ...@@ -289,7 +287,8 @@ void fused_experts_fp8_kernel_impl(
/* ldb */ n_size, /* ldb */ n_size,
/* ldc */ BLOCK_N, /* ldc */ BLOCK_N,
/* brg */ use_brgemm, /* brg */ use_brgemm,
/* block_size_K */ block_size_K); /* block_size_K */ block_size_K,
/* do_unpack */ do_unpack);
// 2.b copy from C to ic2 in original order // 2.b copy from C to ic2 in original order
// and also mul topk_weights in float32 // and also mul topk_weights in float32
...@@ -298,9 +297,9 @@ void fused_experts_fp8_kernel_impl( ...@@ -298,9 +297,9 @@ void fused_experts_fp8_kernel_impl(
float weight = topk_weights[index]; float weight = topk_weights[index];
copy_mul_stub(ic2 + index * K + nb * BLOCK_N, C + m * BLOCK_N, weight, n_size); copy_mul_stub(ic2 + index * K + nb * BLOCK_N, C + m * BLOCK_N, weight, n_size);
} }
} });
if (is_brgemm_used) { if (use_brgemm) {
at::native::cpublas::brgemm_release(); at::native::cpublas::brgemm_release();
} }
}); });
...@@ -374,20 +373,23 @@ void shared_expert_fp8_kernel_impl( ...@@ -374,20 +373,23 @@ void shared_expert_fp8_kernel_impl(
const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(M); const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(M);
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { int64_t B_tmp_size_per_thread = MAX_CACHE_BLOCK_SIZE * BLOCK_N * std::max(K, N);
int tid = at::get_thread_num();
parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
int tid = get_thread_num();
for (int64_t i = begin; i < end; ++i) { loop_2d<at::Float8_e4m3fn>(mb0, mb1, nb0, nb1, BLOCK_N * K, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
int64_t mb = i / NB;
int64_t nb = i % NB;
int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M); int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
int64_t n_size = std::min(2 * N - nb * BLOCK_N, BLOCK_N); int64_t n_size = std::min(2 * N - nb * BLOCK_N, BLOCK_N);
// do unpacking for the first row
bool do_unpack = (mb == mb0);
tinygemm_kernel<scalar_t>( tinygemm_kernel<scalar_t>(
/* A */ input + mb * BLOCK_M * K, /* A */ input + mb * BLOCK_M * K,
/* B */ packed_w1 + nb * BLOCK_N * K, /* B */ packed_w1 + nb * BLOCK_N * K,
/* C */ ic0 + mb * BLOCK_M * 2 * N + nb * BLOCK_N, /* C */ ic0 + mb * BLOCK_M * 2 * N + nb * BLOCK_N,
/* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N), /* Btmp */ B_tmp + tid * B_tmp_size_per_thread + nb_offset * BLOCK_N * K,
/* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N, /* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N,
/* scale */ w1s + (nb / blocks_n_per_group) * scale_size_K, /* scale */ w1s + (nb / blocks_n_per_group) * scale_size_K,
/* M */ m_size, /* M */ m_size,
...@@ -397,8 +399,9 @@ void shared_expert_fp8_kernel_impl( ...@@ -397,8 +399,9 @@ void shared_expert_fp8_kernel_impl(
/* ldb */ n_size, /* ldb */ n_size,
/* ldc */ 2 * N, /* ldc */ 2 * N,
/* brg */ use_brgemm, /* brg */ use_brgemm,
/* block_size_K */ block_size_K); /* block_size_K */ block_size_K,
} /* do_unpack */ do_unpack);
});
if (use_brgemm) { if (use_brgemm) {
at::native::cpublas::brgemm_release(); at::native::cpublas::brgemm_release();
...@@ -421,22 +424,23 @@ void shared_expert_fp8_kernel_impl( ...@@ -421,22 +424,23 @@ void shared_expert_fp8_kernel_impl(
scale_size_K = div_up(N, block_size_K); scale_size_K = div_up(N, block_size_K);
// parallel on [MB2, NB2] // parallel on [MB2, NB2]
at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) { parallel_2d(MB2, NB2, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
int tid = at::get_thread_num(); int tid = get_thread_num();
alignas(64) scalar_t C[BLOCK_M * BLOCK_K]; alignas(64) scalar_t C[BLOCK_M * BLOCK_K];
for (int64_t i = begin; i < end; ++i) { loop_2d<at::Float8_e4m3fn>(mb0, mb1, nb0, nb1, BLOCK_N * IC, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
int64_t mb = i / NB2;
int64_t nb = i % NB2;
int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M); int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N); int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
// do unpacking for the first row
bool do_unpack = (mb == mb0);
// 2.a gemm: C = A @ B // 2.a gemm: C = A @ B
tinygemm_kernel<scalar_t>( tinygemm_kernel<scalar_t>(
/* A */ ic1 + mb * BLOCK_M * N, /* A */ ic1 + mb * BLOCK_M * N,
/* B */ packed_w2 + nb * BLOCK_N * N, /* B */ packed_w2 + nb * BLOCK_N * N,
/* C */ C, /* C */ C,
/* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N), /* Btmp */ B_tmp + tid * B_tmp_size_per_thread + nb_offset * BLOCK_N * IC,
/* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N, /* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N,
/* scale */ w2s + (nb / blocks_n_per_group) * scale_size_K, /* scale */ w2s + (nb / blocks_n_per_group) * scale_size_K,
/* M */ m_size, /* M */ m_size,
...@@ -446,7 +450,8 @@ void shared_expert_fp8_kernel_impl( ...@@ -446,7 +450,8 @@ void shared_expert_fp8_kernel_impl(
/* ldb */ n_size, /* ldb */ n_size,
/* ldc */ BLOCK_N, /* ldc */ BLOCK_N,
/* brg */ use_brgemm, /* brg */ use_brgemm,
/* block_size_K */ block_size_K); /* block_size_K */ block_size_K,
/* do_unpack */ do_unpack);
// 2.b copy from C to output and add fused_experts_out // 2.b copy from C to output and add fused_experts_out
scalar_t* __restrict__ out = output + mb * BLOCK_M * K + nb * BLOCK_N; scalar_t* __restrict__ out = output + mb * BLOCK_M * K + nb * BLOCK_N;
...@@ -454,7 +459,7 @@ void shared_expert_fp8_kernel_impl( ...@@ -454,7 +459,7 @@ void shared_expert_fp8_kernel_impl(
for (int64_t m = 0; m < m_size; ++m) { for (int64_t m = 0; m < m_size; ++m) {
add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, n_size); add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, n_size);
} }
} });
}); });
if (use_brgemm) { if (use_brgemm) {
......
...@@ -109,6 +109,120 @@ inline void add_mul_stub( ...@@ -109,6 +109,120 @@ inline void add_mul_stub(
} }
} }
template <typename scalar_t, int BLOCK_N>
inline void silu_and_mul(
scalar_t* __restrict__ C,
const int32_t* __restrict__ C0, // x: x0, x1
const int32_t* __restrict__ C1, // y: y0, y1
const float* __restrict__ As,
const float* __restrict__ Bs0,
const float* __restrict__ Bs1,
const int32_t* __restrict__ Bcomp0,
const int32_t* __restrict__ Bcomp1,
int64_t m_size,
int64_t N) {
#if defined(CPU_CAPABILITY_AVX512)
constexpr int COLS = BLOCK_N / 16;
static_assert(COLS % 2 == 0);
__m512 vc0[COLS];
__m512 vc1[COLS];
__m512i vcomp0[COLS];
__m512i vcomp1[COLS];
__m512 vas;
__m512 vbs0[COLS];
__m512 vbs1[COLS];
auto load_scale_and_comp = [&](auto col) {
vcomp0[col] = _mm512_loadu_si512(Bcomp0 + col * 16);
vcomp1[col] = _mm512_loadu_si512(Bcomp1 + col * 16);
vbs0[col] = _mm512_loadu_ps(Bs0 + col * 16);
vbs1[col] = _mm512_loadu_ps(Bs1 + col * 16);
};
Unroll<COLS>{}(load_scale_and_comp);
auto scalec = [&](auto col, int64_t m) {
// update As
vas = _mm512_set1_ps(As[m]);
// C = As * (C - Bcomp) * Bs
__m512i vc32_0 = _mm512_loadu_si512(C0 + m * BLOCK_N + col * 16);
__m512i vc32_1 = _mm512_loadu_si512(C1 + m * BLOCK_N + col * 16);
vc0[col] = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc32_0, vcomp0[col]));
vc1[col] = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc32_1, vcomp1[col]));
vc0[col] = _mm512_mul_ps(_mm512_mul_ps(vc0[col], vas), vbs0[col]);
vc1[col] = _mm512_mul_ps(_mm512_mul_ps(vc1[col], vas), vbs1[col]);
};
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
const fVec one = fVec(1.f);
auto silu_and_mul = [&](auto col) {
fVec x = fVec(vc0[col]);
fVec y = fVec(vc1[col]);
x = x / (one + x.neg().exp_u20());
vc0[col] = x * y;
};
auto storec = [&](auto col, int64_t m) {
if constexpr (col % 2 == 0) {
fVec x0 = fVec(vc0[col + 0]);
fVec x1 = fVec(vc0[col + 1]);
bVec out_vec = convert_from_float_ext<scalar_t>(x0, x1);
out_vec.store(C + m * N + col * 16);
}
};
for (int64_t m = 0; m < m_size; ++m) {
Unroll<COLS>{}(scalec, m);
Unroll<COLS>{}(silu_and_mul);
Unroll<COLS>{}(storec, m);
}
#else
TORCH_CHECK(false, "silu_and_mul: scalar path not implemented!");
#endif
}
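// [NB] Scalar reference for the fused dequant + SiLU + mul epilogue above (a
// readability sketch, not code from this diff; assumes <cmath> for std::exp).
// For each row m and column n:
//   x = As[m] * (C0[m][n] - Bcomp0[n]) * Bs0[n]   // gate half, dequantized
//   y = As[m] * (C1[m][n] - Bcomp1[n]) * Bs1[n]   // up half, dequantized
//   C[m][n] = silu(x) * y = x / (1 + exp(-x)) * y
template <typename scalar_t, int BLOCK_N>
inline void silu_and_mul_ref(
    scalar_t* __restrict__ C,
    const int32_t* __restrict__ C0,
    const int32_t* __restrict__ C1,
    const float* __restrict__ As,
    const float* __restrict__ Bs0,
    const float* __restrict__ Bs1,
    const int32_t* __restrict__ Bcomp0,
    const int32_t* __restrict__ Bcomp1,
    int64_t m_size,
    int64_t N) {
  for (int64_t m = 0; m < m_size; ++m) {
    for (int64_t n = 0; n < BLOCK_N; ++n) {
      float x = As[m] * float(C0[m * BLOCK_N + n] - Bcomp0[n]) * Bs0[n];
      float y = As[m] * float(C1[m * BLOCK_N + n] - Bcomp1[n]) * Bs1[n];
      C[m * N + n] = static_cast<scalar_t>(x / (1.f + std::exp(-x)) * y);
    }
  }
}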
template <int BLOCK_N>
inline void scale_C(
float* __restrict__ C,
const int32_t* __restrict__ Ctmp,
const float* __restrict__ As,
const float* __restrict__ Bs,
const int32_t* __restrict__ Bcomp,
int64_t m_size) {
#if defined(CPU_CAPABILITY_AVX512)
constexpr int COLS = BLOCK_N / 16;
static_assert(COLS % 2 == 0);
__m512 vc[COLS];
__m512i vcomp[COLS];
__m512 vas;
__m512 vbs[COLS];
auto load_scale_and_comp = [&](auto col) {
vcomp[col] = _mm512_loadu_si512(Bcomp + col * 16);
vbs[col] = _mm512_loadu_ps(Bs + col * 16);
};
Unroll<COLS>{}(load_scale_and_comp);
auto scalec = [&](auto col, int64_t m) {
// update As
vas = _mm512_set1_ps(As[m]);
// C = As * (C - Bcomp) * Bs
__m512i vc32 = _mm512_loadu_si512(Ctmp + m * BLOCK_N + col * 16);
vc[col] = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc32, vcomp[col]));
vc[col] = _mm512_mul_ps(_mm512_mul_ps(vc[col], vas), vbs[col]);
_mm512_storeu_ps(C + m * BLOCK_N + col * 16, vc[col]);
};
for (int64_t m = 0; m < m_size; ++m) {
Unroll<COLS>{}(scalec, m);
}
#else
TORCH_CHECK(false, "scale_C: scalar path not implemented!");
#endif
}
/// gemm for w13 /// gemm for w13
template <typename scalar_t, int BLOCK_M, int BLOCK_N> template <typename scalar_t, int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_vnni { struct tinygemm_kernel_vnni {
...@@ -515,28 +629,31 @@ void fused_experts_int8_kernel_impl( ...@@ -515,28 +629,31 @@ void fused_experts_int8_kernel_impl(
const int64_t stride_e = 2 * N * packed_K; const int64_t stride_e = 2 * N * packed_K;
const int64_t stride_n = packed_K; const int64_t stride_n = packed_K;
int64_t avg_M = std::max(int64_t(1), M * topk / E);
const bool use_brgemm = can_use_brgemm<int8_t>(avg_M);
// here we only parallel on half of 2N to fuse silu_and_mul with gemm // here we only parallel on half of 2N to fuse silu_and_mul with gemm
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
// get local pointers // get local pointers
int tid = at::get_thread_num(); int tid = get_thread_num();
uint8_t* __restrict__ A = A_tmp + tid * BLOCK_M * K; uint8_t* __restrict__ A = A_tmp + tid * BLOCK_M * K;
int32_t* __restrict__ C0 = reinterpret_cast<int32_t*>(C_tmp) + tid * 2 * BLOCK_M * BLOCK_N;
int32_t* __restrict__ C1 = C0 + BLOCK_M * BLOCK_N;
alignas(64) float As[BLOCK_M]; alignas(64) float As[BLOCK_M];
for (int64_t i = begin; i < end; ++i) { loop_2d<int8_t>(mb0, mb1, nb0, nb1, BLOCK_N * K * 2, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
int64_t mb = i / NB; // nb_upper from top half and nb_lower from bottom half
int64_t nb = i % NB; int64_t nb_upper = nb, nb_lower = nb + NB;
int64_t n_size = std::min(N - nb * BLOCK_N, BLOCK_N);
// nb0 from top half and nb1 from bottom half
int64_t nb0 = nb, nb1 = nb + NB;
int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N);
// B shape [K, n_size] in vnni format // B shape [K, n_size] in vnni format
int32_t expert_id = expert_ids[mb]; int32_t expert_id = expert_ids[mb];
const int8_t* __restrict__ B0 = packed_w1 + expert_id * stride_e + nb0 * BLOCK_N * stride_n; const int8_t* __restrict__ B0 = packed_w1 + expert_id * stride_e + nb_upper * BLOCK_N * stride_n;
const int8_t* __restrict__ B1 = packed_w1 + expert_id * stride_e + nb1 * BLOCK_N * stride_n; const int8_t* __restrict__ B1 = packed_w1 + expert_id * stride_e + nb_lower * BLOCK_N * stride_n;
const float* __restrict__ Bs0 = w1s + expert_id * 2 * N + nb0 * BLOCK_N; const float* __restrict__ Bs0 = w1s + expert_id * 2 * N + nb_upper * BLOCK_N;
const float* __restrict__ Bs1 = w1s + expert_id * 2 * N + nb1 * BLOCK_N; const float* __restrict__ Bs1 = w1s + expert_id * 2 * N + nb_lower * BLOCK_N;
// 1.a load A // 1.a load A
const int32_t* A_ids = sorted_ids + mb * BLOCK_M; const int32_t* A_ids = sorted_ids + mb * BLOCK_M;
...@@ -548,22 +665,62 @@ void fused_experts_int8_kernel_impl( ...@@ -548,22 +665,62 @@ void fused_experts_int8_kernel_impl(
As[m] = As_tmp[index]; As[m] = As_tmp[index];
} }
// fused 1.b: silu_and_mul(A @ B0, A @ B1) if (use_brgemm) {
const int64_t offset = offsets[mb]; // 1.b gemm: C0 = A @ B0
tinygemm_kernel( at::native::cpublas::brgemm(
/* A */ A, /* M */ m_size,
/* B0 */ B0, /* N */ n_size,
/* B1 */ B1, /* K */ K,
/* C */ ic1 + offset * N + nb * BLOCK_N, /* lda */ K,
/* As */ As, /* ldb */ n_size,
/* Bs0 */ Bs0, /* ldc */ BLOCK_N,
/* Bs1 */ Bs1, /* add_C */ false,
/* M */ m_size, /* A */ A,
/* N */ n_size, /* B */ B0,
/* K */ K, /* C */ C0);
/* lda */ K,
/* ldb */ n_size, // 1.c gemm: C1 = A @ B1
/* ldc */ N); at::native::cpublas::brgemm(
/* M */ m_size,
/* N */ n_size,
/* K */ K,
/* lda */ K,
/* ldb */ n_size,
/* ldc */ BLOCK_N,
/* add_C */ false,
/* A */ A,
/* B */ B1,
/* C */ C1);
const int32_t* Bcomp0 = reinterpret_cast<const int32_t*>(B0 + block_size_n() * K);
const int32_t* Bcomp1 = reinterpret_cast<const int32_t*>(B1 + block_size_n() * K);
// 1.d silu and mul
const int64_t offset = offsets[mb];
silu_and_mul<scalar_t, BLOCK_N>(
ic1 + offset * N + nb * BLOCK_N, C0, C1, As, Bs0, Bs1, Bcomp0, Bcomp1, m_size, N);
} else {
// fused 1.bcd: silu_and_mul(A @ B0, A @ B1)
const int64_t offset = offsets[mb];
tinygemm_kernel(
/* A */ A,
/* B0 */ B0,
/* B1 */ B1,
/* C */ ic1 + offset * N + nb * BLOCK_N,
/* As */ As,
/* Bs0 */ Bs0,
/* Bs1 */ Bs1,
/* M */ m_size,
/* N */ n_size,
/* K */ K,
/* lda */ K,
/* ldb */ n_size,
/* ldc */ N);
}
});
if (use_brgemm) {
at::native::cpublas::brgemm_release();
} }
}); });
...@@ -584,16 +741,13 @@ void fused_experts_int8_kernel_impl( ...@@ -584,16 +741,13 @@ void fused_experts_int8_kernel_impl(
const int64_t stride_oc = packed_N; const int64_t stride_oc = packed_N;
// parallel on [MB2, NB2] // parallel on [MB2, NB2]
at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) { parallel_2d(MB2, NB2, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
// get local pointers // get local pointers
int tid = at::get_thread_num(); int tid = get_thread_num();
// we won't be using C1 for gemm2
float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N; float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N;
int32_t* __restrict__ C32 = reinterpret_cast<int32_t*>(C + BLOCK_M * BLOCK_N);
for (int64_t i = begin; i < end; ++i) { loop_2d<int8_t>(mb0, mb1, nb0, nb1, BLOCK_N * IC, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
int64_t mb = i / NB2;
int64_t nb = i % NB2;
int64_t m_size = offsets[mb + 1] - offsets[mb]; int64_t m_size = offsets[mb + 1] - offsets[mb];
int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N); int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
...@@ -609,18 +763,36 @@ void fused_experts_int8_kernel_impl( ...@@ -609,18 +763,36 @@ void fused_experts_int8_kernel_impl(
const float* __restrict__ Bs = w2s + expert_id * K + nb * BLOCK_N; const float* __restrict__ Bs = w2s + expert_id * K + nb * BLOCK_N;
// 2.a gemm: C = A @ B // 2.a gemm: C = A @ B
tinygemm_kernel<scalar_t>( if (use_brgemm) {
/* A */ A, at::native::cpublas::brgemm(
/* B */ B, /* M */ m_size,
/* C */ C, /* N */ n_size,
/* As */ As, /* K */ IC,
/* Bs */ Bs, /* lda */ IC,
/* M */ m_size, /* ldb */ n_size,
/* N */ n_size, /* ldc */ BLOCK_N,
/* K */ IC, /* add_C */ false,
/* lda */ IC, /* A */ A,
/* ldb */ n_size, /* B */ B,
/* ldc */ BLOCK_N); /* C */ C32);
// apply scales
const int32_t* Bcomp = reinterpret_cast<const int32_t*>(B + block_size_n() * IC);
scale_C<BLOCK_N>(C, C32, As, Bs, Bcomp, m_size);
} else {
tinygemm_kernel<scalar_t>(
/* A */ A,
/* B */ B,
/* C */ C,
/* As */ As,
/* Bs */ Bs,
/* M */ m_size,
/* N */ n_size,
/* K */ IC,
/* lda */ IC,
/* ldb */ n_size,
/* ldc */ BLOCK_N);
}
// 2.b copy from C to ic2 in original order // 2.b copy from C to ic2 in original order
// and also mul topk_weights in float32 // and also mul topk_weights in float32
...@@ -629,6 +801,10 @@ void fused_experts_int8_kernel_impl( ...@@ -629,6 +801,10 @@ void fused_experts_int8_kernel_impl(
float weight = topk_weights[index]; float weight = topk_weights[index];
copy_mul_stub(ic2 + index * K + nb * BLOCK_N, C + m * BLOCK_N, weight, n_size); copy_mul_stub(ic2 + index * K + nb * BLOCK_N, C + m * BLOCK_N, weight, n_size);
} }
});
if (use_brgemm) {
at::native::cpublas::brgemm_release();
} }
}); });
...@@ -708,15 +884,19 @@ void shared_expert_int8_kernel_impl( ...@@ -708,15 +884,19 @@ void shared_expert_int8_kernel_impl(
const int64_t packed_N = get_row_size<int8_t>(N); const int64_t packed_N = get_row_size<int8_t>(N);
const int64_t stride_n = packed_K; const int64_t stride_n = packed_K;
const bool use_brgemm = can_use_brgemm<int8_t>(M);
// here we only parallel on half of 2N to fuse silu_and_mul with gemm // here we only parallel on half of 2N to fuse silu_and_mul with gemm
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
for (int64_t i = begin; i < end; ++i) { // get local pointers
int64_t mb = i / NB; int tid = get_thread_num();
int64_t nb = i % NB; int32_t* __restrict__ C0 = reinterpret_cast<int32_t*>(C_tmp) + tid * 2 * BLOCK_M * BLOCK_N;
int32_t* __restrict__ C1 = C0 + BLOCK_M * BLOCK_N;
// nb0 from top half and nb1 from bottom half
int64_t nb0 = nb, nb1 = nb + NB; loop_2d<int8_t>(mb0, mb1, nb0, nb1, BLOCK_N * K * 2, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N); // nb_upper from top half and nb_lower from bottom half
int64_t nb_upper = nb, nb_lower = nb + NB;
int64_t n_size = std::min(N - nb * BLOCK_N, BLOCK_N);
int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M); int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
// A shape [m_size, K] // A shape [m_size, K]
...@@ -724,26 +904,65 @@ void shared_expert_int8_kernel_impl( ...@@ -724,26 +904,65 @@ void shared_expert_int8_kernel_impl(
const float* As = As_tmp + mb * BLOCK_M; const float* As = As_tmp + mb * BLOCK_M;
// B shape [K, n_size] in vnni format // B shape [K, n_size] in vnni format
const int8_t* __restrict__ B0 = packed_w1 + nb0 * BLOCK_N * stride_n; const int8_t* __restrict__ B0 = packed_w1 + nb_upper * BLOCK_N * stride_n;
const int8_t* __restrict__ B1 = packed_w1 + nb1 * BLOCK_N * stride_n; const int8_t* __restrict__ B1 = packed_w1 + nb_lower * BLOCK_N * stride_n;
const float* __restrict__ Bs0 = w1s + nb0 * BLOCK_N; const float* __restrict__ Bs0 = w1s + nb_upper * BLOCK_N;
const float* __restrict__ Bs1 = w1s + nb1 * BLOCK_N; const float* __restrict__ Bs1 = w1s + nb_lower * BLOCK_N;
// fused 1.b: silu_and_mul(A @ B0, A @ B1) if (use_brgemm) {
tinygemm_kernel( // 1.b gemm: C0 = A @ B0
/* A */ A, at::native::cpublas::brgemm(
/* B0 */ B0, /* M */ m_size,
/* B1 */ B1, /* N */ n_size,
/* C */ ic1 + mb * BLOCK_M * N + nb * BLOCK_N, /* K */ K,
/* As */ As, /* lda */ K,
/* Bs0 */ Bs0, /* ldb */ n_size,
/* Bs1 */ Bs1, /* ldc */ BLOCK_N,
/* M */ m_size, /* add_C */ false,
/* N */ n_size, /* A */ A,
/* K */ K, /* B */ B0,
/* lda */ K, /* C */ C0);
/* ldb */ n_size,
/* ldc */ N); // 1.c gemm: C1 = A @ B1
at::native::cpublas::brgemm(
/* M */ m_size,
/* N */ n_size,
/* K */ K,
/* lda */ K,
/* ldb */ n_size,
/* ldc */ BLOCK_N,
/* add_C */ false,
/* A */ A,
/* B */ B1,
/* C */ C1);
const int32_t* Bcomp0 = reinterpret_cast<const int32_t*>(B0 + block_size_n() * K);
const int32_t* Bcomp1 = reinterpret_cast<const int32_t*>(B1 + block_size_n() * K);
// 1.d silu and mul
silu_and_mul<scalar_t, BLOCK_N>(
ic1 + mb * BLOCK_M * N + nb * BLOCK_N, C0, C1, As, Bs0, Bs1, Bcomp0, Bcomp1, m_size, N);
} else {
// fused 1.bcd: silu_and_mul(A @ B0, A @ B1)
tinygemm_kernel(
/* A */ A,
/* B0 */ B0,
/* B1 */ B1,
/* C */ ic1 + mb * BLOCK_M * N + nb * BLOCK_N,
/* As */ As,
/* Bs0 */ Bs0,
/* Bs1 */ Bs1,
/* M */ m_size,
/* N */ n_size,
/* K */ K,
/* lda */ K,
/* ldb */ n_size,
/* ldc */ N);
}
});
if (use_brgemm) {
at::native::cpublas::brgemm_release();
} }
}); });
...@@ -763,16 +982,13 @@ void shared_expert_int8_kernel_impl( ...@@ -763,16 +982,13 @@ void shared_expert_int8_kernel_impl(
const int64_t stride_oc = packed_N; const int64_t stride_oc = packed_N;
// parallel on [MB2, NB2] // parallel on [MB2, NB2]
at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) { parallel_2d(MB2, NB2, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
// get local pointers // get local pointers
int tid = at::get_thread_num(); int tid = get_thread_num();
// we won't be using C1 for gemm2
float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N; float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N;
int32_t* __restrict__ C32 = reinterpret_cast<int32_t*>(C + BLOCK_M * BLOCK_N);
for (int64_t i = begin; i < end; ++i) { loop_2d<int8_t>(mb0, mb1, nb0, nb1, BLOCK_N * IC, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
int64_t mb = i / NB2;
int64_t nb = i % NB2;
int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M); int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N); int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
...@@ -784,19 +1000,37 @@ void shared_expert_int8_kernel_impl( ...@@ -784,19 +1000,37 @@ void shared_expert_int8_kernel_impl(
const int8_t* __restrict__ B = packed_w2 + nb * BLOCK_N * stride_oc; const int8_t* __restrict__ B = packed_w2 + nb * BLOCK_N * stride_oc;
const float* __restrict__ Bs = w2s + nb * BLOCK_N; const float* __restrict__ Bs = w2s + nb * BLOCK_N;
// 2.a gemm: C = A @ B if (use_brgemm) {
tinygemm_kernel<scalar_t>( at::native::cpublas::brgemm(
/* A */ A, /* M */ m_size,
/* B */ B, /* N */ n_size,
/* C */ C, /* K */ IC,
/* As */ As, /* lda */ IC,
/* Bs */ Bs, /* ldb */ n_size,
/* M */ m_size, /* ldc */ BLOCK_N,
/* N */ n_size, /* add_C */ false,
/* K */ IC, /* A */ A,
/* lda */ IC, /* B */ B,
/* ldb */ n_size, /* C */ C32);
/* ldc */ BLOCK_N);
// apply scales
const int32_t* Bcomp = reinterpret_cast<const int32_t*>(B + block_size_n() * IC);
scale_C<BLOCK_N>(C, C32, As, Bs, Bcomp, m_size);
} else {
// 2.a gemm: C = A @ B
tinygemm_kernel<scalar_t>(
/* A */ A,
/* B */ B,
/* C */ C,
/* As */ As,
/* Bs */ Bs,
/* M */ m_size,
/* N */ n_size,
/* K */ IC,
/* lda */ IC,
/* ldb */ n_size,
/* ldc */ BLOCK_N);
}
// 2.b copy from C to output and add fused_experts_out // 2.b copy from C to output and add fused_experts_out
scalar_t* __restrict__ out = output + mb * BLOCK_M * K + nb * BLOCK_N; scalar_t* __restrict__ out = output + mb * BLOCK_M * K + nb * BLOCK_N;
...@@ -804,6 +1038,10 @@ void shared_expert_int8_kernel_impl( ...@@ -804,6 +1038,10 @@ void shared_expert_int8_kernel_impl(
for (int64_t m = 0; m < m_size; ++m) { for (int64_t m = 0; m < m_size; ++m) {
add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, n_size); add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, n_size);
} }
});
if (use_brgemm) {
at::native::cpublas::brgemm_release();
} }
}); });
} }
......
...@@ -100,8 +100,7 @@ void segment_gemm_kernel_impl( ...@@ -100,8 +100,7 @@ void segment_gemm_kernel_impl(
const int64_t NB1 = div_up(N1, BLOCK_N); const int64_t NB1 = div_up(N1, BLOCK_N);
const int64_t NB = NB0 + NB1; const int64_t NB = NB0 + NB1;
// TODO: brgemm u8s8 depends on PyTorch 2.7 release. const bool use_brgemm = can_use_brgemm<int8_t>(M);
const bool use_brgemm = false;
// K + 4 after compensation // K + 4 after compensation
const int64_t packed_row_size = get_row_size<int8_t>(K); const int64_t packed_row_size = get_row_size<int8_t>(K);
......