Commit fb4959b2 authored by Chunyuan WU, committed by GitHub

Add fp8 gemm kernel for CPU in sgl-kernel and add gemm UT (#6216)


Co-authored-by: YanbingJiang <yanbing.jiang@intel.com>
Co-authored-by: mingfeima <mingfei.ma@intel.com>
parent 9a405274
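
A minimal usage sketch of the new op (assumed call pattern, mirroring the unit test added in this PR; the tensor values below are placeholders rather than a real quantization flow):

```python
import torch
from sgl_kernel.common_ops import convert_weight_packed, fp8_scaled_mm_cpu

M, N, K = 11, 128, 512                              # shapes taken from the fp8 test cases
block_size = [64, 128]                              # [scale_block_size_N, scale_block_size_K]
x = torch.randn(M, K, dtype=torch.bfloat16)         # bf16 activations
w_fp8 = torch.randn(N, K).to(torch.float8_e4m3fn)   # block-quantized weight (placeholder values)
scales = torch.rand(N // 64, K // 128)              # one float32 scale per [64, 128] weight block
w_packed = convert_weight_packed(w_fp8)             # VNNI prepack, so is_vnni=True below
out = fp8_scaled_mm_cpu(x, w_packed, scales, block_size, None, torch.bfloat16, True)
```

In the real flow the fp8 weight, the scale grid and the dequantized reference come from the `convert_weight` helper in the test utilities below.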
...@@ -22,7 +22,7 @@ namespace {
} \
}()
// dispatch: bfloat16, float16, int8_t, fp8_e4m3
#define CPU_DISPATCH_PACKED_TYPES(TYPE, ...) \
[&] { \
switch (TYPE) { \
...@@ -38,6 +38,10 @@ namespace {
using packed_t = int8_t; \
return __VA_ARGS__(); \
} \
case at::ScalarType::Float8_e4m3fn: { \
using packed_t = at::Float8_e4m3fn; \
return __VA_ARGS__(); \
} \
default: \
TORCH_CHECK(false, "Unsupported floating data type.\n"); \
} \
...
...@@ -424,7 +424,8 @@ at::Tensor convert_weight_packed(at::Tensor& weight) {
const int64_t stride = OC * IC;
TORCH_CHECK(
st == at::kBFloat16 || st == at::kHalf || st == at::kChar || st == at::kFloat8_e4m3fn,
"expect weight to be bfloat16, float16, int8 or fp8_e4m3.");
CPU_DISPATCH_PACKED_TYPES(st, [&] {
// adjust most inner dimension size
...
...@@ -33,6 +33,11 @@ inline bool can_use_brgemm<int8_t>(int M) {
return false;
}
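// fp8 weights: AMX brgemm pays off only for larger batches; for M <= 4 the
// AVX512 tinygemm path in gemm_fp8.cpp is used instead.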
template <>
inline bool can_use_brgemm<at::Float8_e4m3fn>(int M) {
return M > 4;
}
// work around compiler internal error
#define BLOCK_K 128 // 4 * TILE_K
...
#include "common.h"
#include "gemm.h"
#include "vec.h"
namespace {
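// copy_stub / copy_add_stub: convert a row of fp32 accumulators (Ctmp) back to
// the output dtype (bf16/fp16), optionally adding a float bias; vectorized body
// plus a scalar tail.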
template <typename scalar_t>
inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ input, int64_t size) {
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
constexpr int kVecSize = bVec::size();
int64_t d;
#pragma GCC unroll 4
for (d = 0; d <= size - kVecSize; d += kVecSize) {
fVec data0 = fVec::loadu(input + d);
fVec data1 = fVec::loadu(input + d + fVec::size());
bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1);
out_vec.store(out + d);
}
for (; d < size; ++d) {
out[d] = static_cast<scalar_t>(input[d]);
}
}
template <typename scalar_t>
inline void copy_add_stub(
scalar_t* __restrict__ out, const float* __restrict__ input, const float* __restrict__ bias, int64_t size) {
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
constexpr int kVecSize = bVec::size();
int64_t d;
#pragma GCC unroll 4
for (d = 0; d <= size - kVecSize; d += kVecSize) {
fVec data0 = fVec::loadu(input + d) + fVec::loadu(bias + d);
fVec data1 = fVec::loadu(input + d + fVec::size()) + fVec::loadu(bias + d + fVec::size());
bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1);
out_vec.store(out + d);
}
for (; d < size; ++d) {
out[d] = static_cast<scalar_t>(input[d] + bias[d]);
}
}
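// unpack_B: dequantize one [K, BLOCK_N] tile of the VNNI-packed fp8 weight
// (layout [K/2, N, 2]) into bf16 Btmp, applying the per-block scale, so that
// cpublas::brgemm can run on bf16 operands.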
inline void unpack_B(
at::BFloat16* __restrict__ Btmp,
const at::Float8_e4m3fn* __restrict__ packed_B,
int N,
int K,
int ldb,
int ldb_tmp,
float scale) {
#if defined(CPU_CAPABILITY_AVX512)
// [K/2, N, 2]
const int K2 = K >> 1;
const int ldb2 = ldb; // ldb * 2 >> 1;
const uint16_t* b_ptr = reinterpret_cast<const uint16_t*>(packed_B);
const __m512 vd = _mm512_set1_ps(scale);
constexpr int BLOCK_N = block_size_n();
static_assert(BLOCK_N == 32);
for (int k = 0; k < K2; ++k) {
for (int n = 0; n < N; n += 64) { // BLOCK_N = 32
__m512i b8 = _mm512_loadu_si512(b_ptr + k * ldb2 + n);
__m256i b8_0 = _mm512_extracti32x8_epi32(b8, 0);
__m256i b8_1 = _mm512_extracti32x8_epi32(b8, 1);
__m512bh bf16_0 = CVT_FP8_TO_BF16(b8_0);
__m512bh bf16_1 = CVT_FP8_TO_BF16(b8_1);
// Apply scale
__m512 f0_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 0));
__m512 f0_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 1));
__m512 f1_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 0));
__m512 f1_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 1));
f0_lo = _mm512_mul_ps(f0_lo, vd);
f0_hi = _mm512_mul_ps(f0_hi, vd);
f1_lo = _mm512_mul_ps(f1_lo, vd);
f1_hi = _mm512_mul_ps(f1_hi, vd);
bf16_0 = _mm512_cvtne2ps_pbh(f0_hi, f0_lo);
bf16_1 = _mm512_cvtne2ps_pbh(f1_hi, f1_lo);
_mm512_storeu_si512(Btmp + k * ldb_tmp * 2 + n * 2 + 0, (__m512i)bf16_0);
_mm512_storeu_si512(Btmp + k * ldb_tmp * 2 + n * 2 + 32, (__m512i)bf16_1);
}
}
#else
TORCH_CHECK(false, "unpack_B: scalar path not implemented!");
#endif
}
template <typename scalar_t, typename packed_t, bool has_bias, int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_nn {
static inline void apply(
const scalar_t* __restrict__ A,
const packed_t* __restrict__ B,
scalar_t* __restrict__ C,
const float* __restrict__ bias,
const float* __restrict__ scale,
int K,
int lda,
int ldb,
int ldc,
int64_t block_size_K) {
TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!");
}
};
#if defined(CPU_CAPABILITY_AVX512)
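// bf16 activation x fp8 weight micro-kernel: B columns are converted to bf16
// on the fly (CVT_FP8_TO_BF16) and rescaled with the per-K-block scale, then
// accumulated into fp32 lanes with _mm512_dpbf16_ps.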
template <bool has_bias, int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_nn<at::BFloat16, at::Float8_e4m3fn, has_bias, BLOCK_M, BLOCK_N> {
static inline void apply(
const at::BFloat16* __restrict__ A,
const at::Float8_e4m3fn* __restrict__ B,
at::BFloat16* __restrict__ C,
const float* __restrict__ bias,
const float* __restrict__ scale,
int K,
int lda,
int ldb,
int ldc,
int64_t block_size_K) {
constexpr int ROWS = BLOCK_M;
constexpr int COLS = BLOCK_N / 16;
// prefetch distance
constexpr int PREFETCH_SIZE_K = 0;
__m512bh va;
__m512bh vb[COLS];
__m512 vc[ROWS * COLS];
auto loadc = [&](auto i) {
constexpr int col = i % COLS;
if constexpr (has_bias) {
vc[i] = _mm512_loadu_ps(bias + col * 16);
} else {
vc[i] = _mm512_set1_ps(0.f);
}
};
Unroll<ROWS * COLS>{}(loadc);
const int K2 = K >> 1;
const int lda2 = lda >> 1;
const int ldb2 = ldb; // ldb * 2 >> 1;
const float* a_ptr = reinterpret_cast<const float*>(A);
const uint16_t* b_ptr = reinterpret_cast<const uint16_t*>(B);
auto compute = [&](auto i, int k) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
int idx = k * 2 / block_size_K;
const __m512 vd = _mm512_set1_ps(scale[idx]);
if constexpr (col == 0) {
va = (__m512bh)(_mm512_set1_ps(a_ptr[row * lda2 + k]));
}
if constexpr (row == 0) {
if constexpr (col % 2 == 0) {
__m512i b8 = _mm512_loadu_si512(b_ptr + k * ldb2 + col * 16);
if constexpr (PREFETCH_SIZE_K > 0) {
_mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0);
}
__m256i b8_0 = _mm512_extracti32x8_epi32(b8, 0);
__m256i b8_1 = _mm512_extracti32x8_epi32(b8, 1);
__m512bh bf16_0 = CVT_FP8_TO_BF16(b8_0);
__m512bh bf16_1 = CVT_FP8_TO_BF16(b8_1);
// Apply scale
__m512 f0_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 0));
__m512 f0_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 1));
__m512 f1_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 0));
__m512 f1_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 1));
f0_lo = _mm512_mul_ps(f0_lo, vd);
f0_hi = _mm512_mul_ps(f0_hi, vd);
f1_lo = _mm512_mul_ps(f1_lo, vd);
f1_hi = _mm512_mul_ps(f1_hi, vd);
vb[col + 0] = _mm512_cvtne2ps_pbh(f0_hi, f0_lo);
vb[col + 1] = _mm512_cvtne2ps_pbh(f1_hi, f1_lo);
}
}
vc[i] = _mm512_dpbf16_ps(vc[i], va, vb[col]);
};
for (int k = 0; k < K2; ++k) {
Unroll<ROWS * COLS>{}(compute, k);
}
auto storec = [&](auto i) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
// for COLS = 1, 3 use 256bit store
// for COLS = 2, 4 use 512bit store
if constexpr (COLS % 2 == 0) {
if constexpr (col % 2 == 0) {
_mm512_storeu_si512(
reinterpret_cast<__m512i*>((C + row * ldc + col * 16)),
(__m512i)(_mm512_cvtne2ps_pbh(vc[row * COLS + col + 1], vc[row * COLS + col])));
}
} else {
_mm256_storeu_si256(reinterpret_cast<__m256i*>(C + row * ldc + col * 16), (__m256i)(_mm512_cvtneps_pbh(vc[i])));
}
};
Unroll<ROWS * COLS>{}(storec);
}
};
#endif
#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE) \
tinygemm_kernel_nn<scalar_t, at::Float8_e4m3fn, has_bias, MB_SIZE, NB_SIZE>::apply( \
A + mb_start * lda, \
B + nb_start * 2, \
C + mb_start * ldc + nb_start, \
has_bias ? bias + nb_start : nullptr, \
scale, \
K, \
lda, \
ldb, \
ldc, \
block_size_K);
template <typename scalar_t, typename packed_t, bool has_bias>
struct brgemm {
static inline void apply(
const scalar_t* __restrict__ A,
const packed_t* __restrict__ B,
scalar_t* __restrict__ C,
scalar_t* __restrict__ Btmp,
float* __restrict__ Ctmp,
const float* __restrict__ bias,
const float* __restrict__ scale,
int M,
int N,
int K,
int lda,
int ldb,
int ldc) {
TORCH_CHECK(false, "struct brgemm: primary template not implemented!");
}
};
template <typename scalar_t, bool has_bias>
struct brgemm<scalar_t, scalar_t, has_bias> {
static inline void apply(
const scalar_t* __restrict__ A,
const scalar_t* __restrict__ B,
scalar_t* __restrict__ C,
scalar_t* __restrict__ Btmp,
float* __restrict__ Ctmp,
const float* __restrict__ bias,
const float* __restrict__ scale,
int M,
int N,
int K,
int lda,
int ldb,
int ldc) {
UNUSED(scale);
constexpr int BLOCK_N = block_size_n();
at::native::cpublas::brgemm(M, N, K, lda, ldb, BLOCK_N, /* add_C */ false, A, B, Ctmp);
// copy from Ctmp to C
for (int m = 0; m < M; ++m) {
if constexpr (has_bias) {
copy_add_stub(C + m * ldc, Ctmp + m * BLOCK_N, bias, N);
} else {
copy_stub(C + m * ldc, Ctmp + m * BLOCK_N, N);
}
}
}
};
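// fp8 weight path: for each BLOCK_K slice, dequantize B into bf16 Btmp via
// unpack_B (the per-block scale is applied there), then accumulate into fp32
// Ctmp with cpublas::brgemm; add_C chains the partial results across K blocks.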
template <bool has_bias>
struct brgemm<at::BFloat16, at::Float8_e4m3fn, has_bias> {
static inline void apply(
const at::BFloat16* __restrict__ A,
const at::Float8_e4m3fn* __restrict__ B,
at::BFloat16* __restrict__ C,
at::BFloat16* __restrict__ Btmp,
float* __restrict__ Ctmp,
const float* __restrict__ bias,
const float* __restrict__ scale,
int M,
int N,
int K,
int lda,
int ldb,
int ldc) {
constexpr int BLOCK_N = block_size_n();
// [BLOCK_K, BLOCK_N] -> [BLOCK_K / 2, BLOCK_N * 2]
const int ldb_tmp = block_size_n();
static_assert(BLOCK_K == 128);
// accumulate across K per BLOCK_K
for (int k = 0; k < K; k += BLOCK_K) {
int kb_size = std::min(BLOCK_K, K - k);
int idx = k >> 7; // k / BLOCK_K where BLOCK_K = 128
unpack_B(Btmp, B + k * ldb, N, kb_size, ldb, ldb_tmp, scale[idx]);
const bool add_C = (k != 0);
at::native::cpublas::brgemm(M, N, kb_size, lda, ldb_tmp, BLOCK_N, add_C, A + k, Btmp, Ctmp);
}
// copy from Ctmp to C
for (int m = 0; m < M; ++m) {
if constexpr (has_bias) {
copy_add_stub(C + m * ldc, Ctmp + m * BLOCK_N, bias, N);
} else {
copy_stub(C + m * ldc, Ctmp + m * BLOCK_N, N);
}
}
}
};
template <typename scalar_t, bool has_bias>
void tinygemm_kernel(
const scalar_t* __restrict__ A,
const at::Float8_e4m3fn* __restrict__ B,
scalar_t* __restrict__ C,
scalar_t* __restrict__ Btmp,
float* __restrict__ Ctmp,
const float* __restrict__ scale,
const float* __restrict__ bias,
int64_t M,
int64_t N,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc,
bool brg,
int64_t block_size_K) {
if (brg) {
brgemm<scalar_t, at::Float8_e4m3fn, has_bias>::apply(A, B, C, Btmp, Ctmp, bias, scale, M, N, K, lda, ldb, ldc);
return;
}
// pattern: 1-4-16
constexpr int64_t BLOCK_M = 4;
constexpr int64_t BLOCK_N = 64;
const int64_t MB = div_up(M, BLOCK_M);
const int64_t NB = div_up(N, BLOCK_N);
for (int mb = 0; mb < MB; ++mb) {
int64_t mb_start = mb * BLOCK_M;
int64_t mb_size = std::min(BLOCK_M, M - mb_start);
for (int64_t nb = 0; nb < NB; ++nb) {
int64_t nb_start = nb * BLOCK_N;
int64_t nb_size = std::min(BLOCK_N, N - nb_start);
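// encode the tile shape: mb_size in the high nibble, nb_size / 16 in the low
// nibble, e.g. mb_size = 1, nb_size = 32 -> 0x12.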
switch (mb_size << 4 | nb_size >> 4) {
// mb_size = 1
case 0x12:
LAUNCH_TINYGEMM_KERNEL_NN(1, 32);
break;
case 0x14:
LAUNCH_TINYGEMM_KERNEL_NN(1, 64);
break;
// mb_size = 2
case 0x22:
LAUNCH_TINYGEMM_KERNEL_NN(2, 32);
break;
case 0x24:
LAUNCH_TINYGEMM_KERNEL_NN(2, 64);
break;
// mb_size = 3
case 0x32:
LAUNCH_TINYGEMM_KERNEL_NN(3, 32);
break;
case 0x34:
LAUNCH_TINYGEMM_KERNEL_NN(3, 64);
break;
// mb_size = 4
case 0x42:
LAUNCH_TINYGEMM_KERNEL_NN(4, 32);
break;
case 0x44:
LAUNCH_TINYGEMM_KERNEL_NN(4, 64);
break;
default:
TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size");
}
}
}
}
template <typename scalar_t>
void fp8_scaled_mm_kernel_impl(
scalar_t* __restrict__ out,
const scalar_t* __restrict__ mat1,
const at::Float8_e4m3fn* __restrict__ mat2,
const float* __restrict__ scales2,
const float* __restrict__ bias,
int64_t M,
int64_t N,
int64_t K,
int64_t mat1_strideM,
int64_t out_strideM,
int64_t block_size_N,
int64_t block_size_K) {
constexpr int64_t BLOCK_M = block_size_m();
constexpr int64_t BLOCK_N = block_size_n();
const int64_t MB = div_up(M, BLOCK_M);
const int64_t NB = div_up(N, BLOCK_N);
const int64_t scale_size_K = div_up(K, block_size_K);
const int64_t blocks_n_per_group = block_size_N / BLOCK_N;
const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(M);
// parallel on [MB, NB]
AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] {
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
int64_t mb{0}, nb{0};
data_index_init(begin, mb, MB, nb, NB);
// for brgemm, use float32 for accumulate
alignas(64) float Ctmp[BLOCK_M * BLOCK_N];
// for brgemm when mat2 is float8_e4m3
alignas(64) scalar_t Btmp[BLOCK_N * BLOCK_K];
for (int64_t i = begin; i < end; ++i) {
UNUSED(i);
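// scales2 is laid out as [div_up(N, block_size_N), div_up(K, block_size_K)];
// each BLOCK_N tile reads the scale row of its enclosing block_size_N group.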
const float* scale_ptr = scales2 + (nb / blocks_n_per_group) * scale_size_K;
int64_t mb_start = mb * BLOCK_M;
int64_t mb_size = std::min(M - mb_start, BLOCK_M);
int64_t nb_start = nb * BLOCK_N;
int64_t nb_size = std::min(N - nb_start, BLOCK_N);
tinygemm_kernel<scalar_t, has_bias>(
/* A */ mat1 + mb_start * mat1_strideM,
/* B */ mat2 + nb_start * K, // nb * BLOCK_N * K
/* C */ out + mb_start * out_strideM + nb_start,
/* Btmp */ Btmp,
/* Ctmp */ Ctmp,
/* scale */ scale_ptr,
/* bias */ bias + nb_start,
/* M */ mb_size,
/* N */ nb_size,
/* K */ K,
/* lda */ mat1_strideM,
/* ldb */ nb_size,
/* ldc */ out_strideM,
/* brg */ use_brgemm,
/* block_size_K */ block_size_K);
// move to the next index
data_index_step(mb, MB, nb, NB);
}
if (use_brgemm) {
at::native::cpublas::brgemm_release();
}
});
});
}
} // anonymous namespace
at::Tensor fp8_scaled_mm_cpu(
at::Tensor& mat1,
at::Tensor& mat2,
at::Tensor& scales2,
std::vector<int64_t> block_size,
std::optional<at::Tensor>& bias,
at::ScalarType out_dtype,
bool is_vnni) {
RECORD_FUNCTION("sgl-kernel::fp8_scaled_mm_cpu", std::vector<c10::IValue>({mat1, mat2, scales2, block_size, bias}));
auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1);
CHECK_INPUT(mat2);
CHECK_INPUT(scales2);
TORCH_CHECK(scales2.scalar_type() == at::kFloat, "fp8_scaled_mm_cpu: expect scales2 to be float32.");
int64_t M = mat1.size(0);
int64_t N = mat2.size(0);
int64_t K = mat2.size(1);
CHECK_EQ(mat1.size(1), K);
CHECK_DIM(2, mat1);
CHECK_DIM(2, mat2);
TORCH_CHECK(block_size.size() == 2, "fp8_scaled_mm_cpu: expect block_size.size() to be 2.");
int64_t block_size_N = block_size[0];
int64_t block_size_K = block_size[1];
constexpr int64_t BLOCK_N = block_size_n();
TORCH_CHECK(block_size_N % BLOCK_N == 0, "fp8_scaled_mm_cpu: expect block_size_N to be multiples of BLOCK_N");
TORCH_CHECK(block_size_K == BLOCK_K, "fp8_scaled_mm_cpu: expect block_size_K equals to BLOCK_K");
CHECK_EQ(scales2.size(0), div_up(N, block_size_N));
CHECK_EQ(scales2.size(1), div_up(K, block_size_K));
const auto st = mat1.scalar_type();
TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf, "fp8_scaled_mm_cpu: expect A to be bfloat16 or half.");
TORCH_CHECK(st == out_dtype, "fp8_scaled_mm_cpu: expect A to have the same dtype as out_dtype.");
TORCH_CHECK(mat2.scalar_type() == at::kFloat8_e4m3fn, "fp8_scaled_mm_cpu: expect mat2 to be fp8_e4m3.");
TORCH_CHECK(scales2.scalar_type() == at::kFloat, "fp8_scaled_mm_cpu: expect scales to be float32.");
auto out = at::empty({M, N}, mat1.options().dtype(out_dtype));
// strides
int64_t mat1_strideM = mat1.stride(0);
int64_t out_strideM = out.stride(0);
const bool has_bias = bias.has_value();
const float* bias_data = nullptr;
if (has_bias) {
CHECK_EQ(bias.value().size(0), N);
bias_data = bias.value().data_ptr<float>();
}
AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "fp8_scaled_mm_kernel_impl", [&] {
fp8_scaled_mm_kernel_impl<scalar_t>(
out.data_ptr<scalar_t>(),
mat1.data_ptr<scalar_t>(),
packed_w.data_ptr<at::Float8_e4m3fn>(),
scales2.data_ptr<float>(),
bias_data,
M,
N,
K,
mat1_strideM,
out_strideM,
block_size_N,
block_size_K);
});
return out;
}
...@@ -94,6 +94,16 @@ at::Tensor int8_scaled_mm_cpu(
at::ScalarType out_dtype,
bool is_vnni);
// fp8 gemm
at::Tensor fp8_scaled_mm_cpu(
at::Tensor& mat1,
at::Tensor& mat2,
at::Tensor& scales2,
std::vector<int64_t> block_size,
std::optional<at::Tensor>& bias,
at::ScalarType out_dtype,
bool is_vnni);
// quant + igemm
at::Tensor int8_scaled_mm_with_quant(
at::Tensor& mat1,
...@@ -198,6 +208,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// igemm
m.def("int8_scaled_mm_cpu", &int8_scaled_mm_cpu, "int8 weight packed linear for intel AMX");
// fp8 gemm
m.def("fp8_scaled_mm_cpu", &fp8_scaled_mm_cpu, "fp8 weight packed linear for intel AMX");
// quant + igemm
m.def(
"int8_scaled_mm_with_quant", &int8_scaled_mm_with_quant, "fused per row quant and int8 scaled mm for intel AMX");
...
...@@ -30,6 +30,66 @@ convert_from_float_ext<at::BFloat16>(const Vectorized<float>& a, const Vectorize
#define CVT_FP16_TO_FP32(a) _mm512_cvtps_ph(a, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC))
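// e4m3 -> bf16 bit manipulation: e4m3 is 1 sign / 4 exponent (bias 7) / 3
// mantissa bits, bf16 is 1 / 8 (bias 127) / 7, so the exponent is rebased by
// +120 and the mantissa is shifted left by 4; the sign moves from bit 7 to bit 15.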
inline __m512bh cvt_e4m3_bf16_intrinsic_without_denorm(__m256i fp8_vec) {
// This conversion ignores e4m3 denormals (flush-to-zero behavior):
// Max subnorm : S.0000.111 = 0.875 * 2**(-6) ~= 0.0137
// Min subnorm : S.0000.001 = 2**(-9)         ~= 0.00195
// i.e. magnitudes in roughly 0.0019 ~ 0.0137 are not converted correctly.
__m512i x = _mm512_cvtepu8_epi16(fp8_vec);
auto mask = _mm512_cmpneq_epi16_mask(
_mm512_and_si512(x, _mm512_set1_epi16(127)),
_mm512_setzero_si512()); // mask = (x & 0x7f) != 0, i.e. not +/-0
auto mask_nan = _mm512_cmpneq_epi16_mask(
_mm512_and_si512(x, _mm512_set1_epi16(127)),
_mm512_set1_epi16(127)); // mask_nan = (x & 0x7f) != 0x7f, i.e. not NaN
auto mantissa = _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(7)), 4); // mantissa = (x & 7) << 4
auto exponent = _mm512_add_epi16(
_mm512_srli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(120)), 3),
_mm512_set1_epi16(120)); // exponent = (((x >> 3) & 15) + 120)
auto nonsign = _mm512_maskz_mov_epi16(mask, _mm512_or_si512(mantissa, _mm512_slli_epi16(exponent, 7)));
nonsign = _mm512_mask_mov_epi16(_mm512_set1_epi16(0x7fff), mask_nan, nonsign); // deal with Nan
return (__m512bh)(_mm512_or_si512(
nonsign,
_mm512_slli_epi16(
_mm512_and_si512(x, _mm512_set1_epi16(128)),
8))); // add sign (x & 128) << 8
}
inline __m512bh cvt_e4m3_bf16_intrinsic_with_denorm(__m256i fp8_vec) {
__m512i x = _mm512_cvtepu8_epi16(fp8_vec);
__m512i lg2mant = _mm512_mask_mov_epi16(
_mm512_mask_mov_epi16(
_mm512_setzero_si512(), _mm512_test_epi16_mask(x, _mm512_set1_epi16(2)), _mm512_set1_epi16(1)),
_mm512_test_epi16_mask(x, _mm512_set1_epi16(4)),
_mm512_set1_epi16(2));
return (__m512bh)(_mm512_or_si512(
_mm512_maskz_mov_epi16(
_mm512_cmpneq_epi16_mask(_mm512_and_si512(x, _mm512_set1_epi16(127)), _mm512_setzero_si512()),
_mm512_mask_blend_epi16(
_mm512_test_epi16_mask(x, _mm512_set1_epi16(120)),
_mm512_or_si512(
_mm512_and_si512(
_mm512_sllv_epi16(
_mm512_and_si512(x, _mm512_set1_epi16(3)), _mm512_sub_epi16(_mm512_set1_epi16(7), lg2mant)),
_mm512_set1_epi16(0x007f)),
_mm512_slli_epi16(_mm512_add_epi16(lg2mant, _mm512_set1_epi16(118)), 7)),
_mm512_or_si512(
_mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(7)), 4),
_mm512_slli_epi16(
_mm512_add_epi16(
_mm512_srli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(120)), 3), _mm512_set1_epi16(120)),
7)))),
_mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(128)), 8)));
}
inline __m512bh CVT_FP8_TO_BF16(__m256i a) {
#ifdef SGLANG_CPU_FP8_CVT_FTZ
return cvt_e4m3_bf16_intrinsic_without_denorm(a);
#else
return cvt_e4m3_bf16_intrinsic_with_denorm(a);
#endif
}
#endif
// vector to scalar reduction
...
...@@ -47,6 +47,8 @@ def _get_version():
return line.split("=")[1].strip().strip('"')
cpu_fp8_ftz = os.getenv("SGLANG_CPU_FP8_CVT_FTZ", "1") == "1"
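# SGLANG_CPU_FP8_CVT_FTZ=1 (the default) compiles the faster fp8 -> bf16
# conversion that flushes e4m3 subnormals to zero; set it to 0 to keep the
# denorm-accurate path (see CVT_FP8_TO_BF16 in csrc/cpu/vec.h).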
operator_namespace = "sgl_kernel"
include_dirs = []
...@@ -56,6 +58,7 @@ sources = [
"csrc/cpu/decode.cpp",
"csrc/cpu/extend.cpp",
"csrc/cpu/gemm.cpp",
"csrc/cpu/gemm_fp8.cpp",
"csrc/cpu/gemm_int8.cpp",
"csrc/cpu/moe.cpp",
"csrc/cpu/moe_int8.cpp",
...@@ -76,6 +79,9 @@ extra_compile_args = {
"-fopenmp",
]
}
if cpu_fp8_ftz:
extra_compile_args["cxx"].append("-DSGLANG_CPU_FP8_CVT_FTZ")
libraries = ["c10", "torch", "torch_python"]
cmdclass = {
"build_ext": BuildExtension.with_options(use_ninja=True),
...
import itertools
import unittest
import torch
import torch.nn as nn
# TODO: use interface in cpu.py
from sgl_kernel.common_ops import (
convert_weight_packed,
fp8_scaled_mm_cpu,
int8_scaled_mm_cpu,
int8_scaled_mm_with_quant,
per_token_quant_int8_cpu,
weight_packed_linear,
)
from utils import (
convert_weight,
native_w8a8_per_token_matmul,
per_token_quant_int8,
precision,
)
from sglang.test.test_utils import CustomTestCase
class Mod(nn.Module):
def __init__(self, input_channel, output_channel, has_bias):
super(Mod, self).__init__()
self.linear = torch.nn.Linear(input_channel, output_channel, has_bias)
def forward(self, x):
return self.linear(x)
class TestGemm(CustomTestCase):
M = [1, 101]
N = [32 * 13]
K = [32 * 16]
has_bias = [False, True]
M_int8 = [2, 128]
N_int8 = [32 * 12]
K_int8 = [32 * 17]
M_fp8 = [1, 11]
N_fp8 = [128, 224]
K_fp8 = [512, 576]
def _bf16_gemm(self, M, N, K, has_bias):
mat1 = torch.randn(M, K, dtype=torch.bfloat16)
mat2 = torch.randn(N, K, dtype=torch.bfloat16)
ref = torch.matmul(mat1.float(), mat2.float().t())
if has_bias:
bias = torch.randn(N, dtype=torch.float32)
ref.add_(bias.bfloat16())
ref = ref.bfloat16()
out = weight_packed_linear(mat1, mat2, bias if has_bias else None, False)
packed_mat2 = convert_weight_packed(mat2)
out2 = weight_packed_linear(mat1, packed_mat2, bias if has_bias else None, True)
atol = rtol = precision[ref.dtype]
self.assertTrue(torch.allclose(ref, out, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(ref, out2, atol=atol, rtol=rtol))
def test_bf16_gemm(self):
for params in itertools.product(
self.M,
self.N,
self.K,
self.has_bias,
):
with self.subTest(
M=params[0],
N=params[1],
K=params[2],
has_bias=params[3],
):
self._bf16_gemm(*params)
def _int8_gemm(self, M, N, K, has_bias):
dtype = torch.bfloat16
A = torch.randn((M, K), dtype=dtype) / 10
Aq, As = per_token_quant_int8(A)
factor_for_scale = 1e-2
int8_max = 127
int8_min = -128
B = (torch.rand((N, K), dtype=torch.float32) - 0.5) * 2
Bq = (B * int8_max).clamp(min=int8_min, max=int8_max).to(torch.int8)
Bs = torch.rand(N) * factor_for_scale
bias = torch.randn(N) if has_bias else None
ref_out = native_w8a8_per_token_matmul(Aq, Bq, As, Bs, bias, dtype)
atol = rtol = precision[ref_out.dtype]
Aq2, As2 = per_token_quant_int8_cpu(A)
out = int8_scaled_mm_cpu(
Aq2, Bq, As2, Bs, bias if has_bias else None, torch.bfloat16, False
)
self.assertTrue(torch.allclose(ref_out, out, atol=atol, rtol=rtol))
# test the fused version
fused_out = int8_scaled_mm_with_quant(
A, Bq, Bs, bias if has_bias else None, torch.bfloat16, False
)
self.assertTrue(torch.allclose(ref_out, fused_out, atol=atol, rtol=rtol))
def test_int8_gemm(self):
for params in itertools.product(
self.M_int8,
self.N_int8,
self.K_int8,
self.has_bias,
):
with self.subTest(
M=params[0],
N=params[1],
K=params[2],
has_bias=params[3],
):
self._int8_gemm(*params)
def _fp8_gemm(self, M, N, K, has_bias):
prepack = True
chunk = False
scale_block_size_N = 64
scale_block_size_K = 128
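# fp8_scaled_mm_cpu requires scale_block_size_K == 128 (the kernel's BLOCK_K)
# and scale_block_size_N to be a multiple of 32 (the kernel's BLOCK_N).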
assert scale_block_size_N <= N
assert scale_block_size_K <= K
A_dtype = torch.bfloat16
model = Mod(K, N, has_bias).eval()
if chunk:
data = torch.randn(M, K + 6, dtype=A_dtype).narrow(1, 0, K)
else:
data = torch.randn(M, K, dtype=A_dtype)
weight = model.linear.weight # (N, K)
if has_bias:
bias = model.linear.bias
fp8_weight, scales, dq_weight = convert_weight(
weight, [scale_block_size_N, scale_block_size_K], A_dtype
)
if has_bias:
ref = torch.matmul(data.to(A_dtype), dq_weight.T) + bias.to(A_dtype)
else:
ref = torch.matmul(data.to(A_dtype), dq_weight.T)
if prepack:
fp8_weight = convert_weight_packed(fp8_weight)
opt = fp8_scaled_mm_cpu(
data,
fp8_weight,
scales,
[scale_block_size_N, scale_block_size_K],
bias if has_bias else None,
data.dtype,
prepack,
)
atol = rtol = precision[ref.dtype]
self.assertTrue(torch.allclose(ref, opt, atol=atol, rtol=rtol))
def test_fp8_gemm(self):
for params in itertools.product(
self.M_fp8,
self.N_fp8,
self.K_fp8,
self.has_bias,
):
with self.subTest(
M=params[0],
N=params[1],
K=params[2],
has_bias=params[3],
):
self._fp8_gemm(*params)
if __name__ == "__main__":
unittest.main()
import math
import torch
precision = {
torch.bfloat16: 1e-2,
torch.float16: 1e-3,
torch.float32: 1e-5,
}
def per_token_quant_int8(x):
x = x.float()
absmax = x.abs().max(dim=-1).values
absmax = absmax.clamp_min(1e-10).unsqueeze(-1)
scale_x = absmax / 127
x_q = x.mul(127 / absmax)
x_q = torch.round(x_q).to(torch.int8)
return x_q, scale_x
def convert_weight(weight, scale_block_size, A_dtype):
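# Block-quantize `weight` (N, K) to float8_e4m3fn with one scale per
# [scale_block_size_N, scale_block_size_K] block; returns the fp8 weight, the
# scale grid, and the dequantized reference weight in A_dtype.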
N, K = weight.size()
fp8_max = 448.0
scale_block_size_N, scale_block_size_K = scale_block_size # (128, 128)
pad_N = (scale_block_size_N - (N % scale_block_size_N)) % scale_block_size_N
pad_K = (scale_block_size_K - (K % scale_block_size_K)) % scale_block_size_K
if pad_N > 0 or pad_K > 0:
weight = torch.nn.functional.pad(weight, (0, pad_K, 0, pad_N))
weight_blocks = weight.view(
math.ceil(N / scale_block_size_N),
scale_block_size_N,
math.ceil(K / scale_block_size_K),
scale_block_size_K,
) # (8, 128, 8, 128)
weight_blocks = weight_blocks.permute(0, 2, 1, 3).contiguous() # (8, 8, 128, 128)
# Step 2: compute per-block max abs values → scale
abs_max = weight_blocks.abs().amax(dim=(-2, -1), keepdim=True) # (8, 8, 1, 1)
scales = abs_max / fp8_max
scales = torch.where(
scales == 0, torch.ones_like(scales), scales
) # avoid division by zero
q_fp8 = (weight_blocks / scales).to(torch.float8_e4m3fn)
q_fp8_reshape = q_fp8.permute(0, 2, 1, 3).contiguous()
if pad_N > 0 or pad_K > 0:
q_fp8_reshape = q_fp8_reshape.view(N + pad_N, K + pad_K)
q_fp8_reshape = q_fp8_reshape[:N, :K].contiguous()
else:
q_fp8_reshape = q_fp8_reshape.view(N, K)
dq_weight = q_fp8.float() * scales
dq_weight = dq_weight.permute(0, 2, 1, 3).contiguous() # (8, 128, 8, 128)
if pad_N > 0 or pad_K > 0:
w_dq = dq_weight.view(N + pad_N, K + pad_K).to(A_dtype)
w_dq = w_dq[:N, :K].contiguous()
else:
w_dq = dq_weight.view(N, K).to(A_dtype)
scales = scales.view(
math.ceil(N / scale_block_size_N), math.ceil(K / scale_block_size_K)
)
return q_fp8_reshape, scales, w_dq
def native_w8a8_per_token_matmul(A, B, As, Bs, bias, output_dtype=torch.bfloat16):
"""Matrix multiplication function that supports per-token input quantization and per-column weight quantization"""
A = A.to(torch.float32)
B = B.to(torch.float32)
assert A.shape[-1] == B.shape[-1], "Dimension mismatch"
assert B.ndim == 2 and B.is_contiguous(), "B must be a 2D contiguous tensor"
# Reshape input
M = A.numel() // A.shape[-1]
B = B.t() # Transpose weight matrix
N, K = B.shape
origin_C_shape = A.shape[:-1] + (K,)
A = A.reshape(M, N)
# As is per-token [M, 1], Bs is per-column [1, K]
C = torch.matmul(A, B) # [M, K]
C = As * C * Bs.view(1, -1) # Broadcast per-column scale
if bias is not None:
C.add_(bias.view(1, -1))
return C.reshape(origin_C_shape).to(output_dtype)