".github/vscode:/vscode.git/clone" did not exist on "3a9d7d97588a1bbc906d8a17be77cf382492a7b6"
Unverified commit fb4959b2 authored by Chunyuan WU, committed by GitHub

Add fp8 gemm kernel for CPU in sgl-kernel and add gemm UT (#6216)


Co-authored-by: YanbingJiang <yanbing.jiang@intel.com>
Co-authored-by: mingfeima <mingfei.ma@intel.com>
parent 9a405274
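For orientation, a minimal sketch of how the new op is meant to be called from Python, based on the unit test added in this commit (names and shapes are illustrative; block_size_K must equal the kernel's BLOCK_K of 128 and block_size_N must be a multiple of 32):

import torch
from sgl_kernel.common_ops import convert_weight_packed, fp8_scaled_mm_cpu

M, N, K = 11, 256, 512
block_size = [64, 128]                                         # [block_size_N, block_size_K]
a = torch.randn(M, K, dtype=torch.bfloat16)
w_fp8 = torch.randn(N, K).to(torch.float8_e4m3fn)              # block-quantized weight (fp8 e4m3)
scales = torch.rand(N // block_size[0], K // block_size[1])    # float32, one scale per weight block
w_packed = convert_weight_packed(w_fp8)                        # repack for the brgemm/AMX path
out = fp8_scaled_mm_cpu(a, w_packed, scales, block_size, None, torch.bfloat16, True)  # [M, N] bf16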
......@@ -22,7 +22,7 @@ namespace {
} \
}()
// dispatch: bfloat16, float16, int8_t, fp8_e4m3
#define CPU_DISPATCH_PACKED_TYPES(TYPE, ...) \
[&] { \
switch (TYPE) { \
......@@ -38,6 +38,10 @@ namespace {
using packed_t = int8_t; \
return __VA_ARGS__(); \
} \
case at::ScalarType::Float8_e4m3fn: { \
using packed_t = at::Float8_e4m3fn; \
return __VA_ARGS__(); \
} \
default: \
TORCH_CHECK(false, "Unsupported floating data type.\n"); \
} \
......
......@@ -424,7 +424,8 @@ at::Tensor convert_weight_packed(at::Tensor& weight) {
const int64_t stride = OC * IC;
TORCH_CHECK(
st == at::kBFloat16 || st == at::kHalf || st == at::kChar, "expect weight to be bfloat16, float16 or int8.");
st == at::kBFloat16 || st == at::kHalf || st == at::kChar || st == at::kFloat8_e4m3fn,
"expect weight to be bfloat16, float16, int8 or fp8_e4m3.");
CPU_DISPATCH_PACKED_TYPES(st, [&] {
// adjust most inner dimension size
......
......@@ -33,6 +33,11 @@ inline bool can_use_brgemm<int8_t>(int M) {
return false;
}
template <>
inline bool can_use_brgemm<at::Float8_e4m3fn>(int M) {
return M > 4;
}
// work around compiler internal error
#define BLOCK_K 128 // 4 * TILE_K
......
#include "common.h"
#include "gemm.h"
#include "vec.h"
namespace {
template <typename scalar_t>
inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ input, int64_t size) {
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
constexpr int kVecSize = bVec::size();
int64_t d;
#pragma GCC unroll 4
for (d = 0; d <= size - kVecSize; d += kVecSize) {
fVec data0 = fVec::loadu(input + d);
fVec data1 = fVec::loadu(input + d + fVec::size());
bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1);
out_vec.store(out + d);
}
for (; d < size; ++d) {
out[d] = static_cast<scalar_t>(input[d]);
}
}
template <typename scalar_t>
inline void copy_add_stub(
scalar_t* __restrict__ out, const float* __restrict__ input, const float* __restrict__ bias, int64_t size) {
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
constexpr int kVecSize = bVec::size();
int64_t d;
#pragma GCC unroll 4
for (d = 0; d <= size - kVecSize; d += kVecSize) {
fVec data0 = fVec::loadu(input + d) + fVec::loadu(bias + d);
fVec data1 = fVec::loadu(input + d + fVec::size()) + fVec::loadu(bias + d + fVec::size());
bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1);
out_vec.store(out + d);
}
for (; d < size; ++d) {
out[d] = static_cast<scalar_t>(input[d] + bias[d]);
}
}
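// unpack_B: dequantize one packed fp8 weight panel into bf16 scratch for brgemm.
// packed_B holds a BLOCK_N-wide panel in VNNI order [K/2, N, 2]; each fp8 value is
// widened to bf16, multiplied by the per-block `scale`, and written to Btmp with the
// same interleaving so at::native::cpublas::brgemm can consume it directly.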
inline void unpack_B(
at::BFloat16* __restrict__ Btmp,
const at::Float8_e4m3fn* __restrict__ packed_B,
int N,
int K,
int ldb,
int ldb_tmp,
float scale) {
#if defined(CPU_CAPABILITY_AVX512)
// [K/2, N, 2]
const int K2 = K >> 1;
const int ldb2 = ldb; // ldb * 2 >> 1;
const uint16_t* b_ptr = reinterpret_cast<const uint16_t*>(packed_B);
const __m512 vd = _mm512_set1_ps(scale);
constexpr int BLOCK_N = block_size_n();
static_assert(BLOCK_N == 32);
for (int k = 0; k < K2; ++k) {
for (int n = 0; n < N; n += 64) { // BLOCK_N = 32
__m512i b8 = _mm512_loadu_si512(b_ptr + k * ldb2 + n);
__m256i b8_0 = _mm512_extracti32x8_epi32(b8, 0);
__m256i b8_1 = _mm512_extracti32x8_epi32(b8, 1);
__m512bh bf16_0 = CVT_FP8_TO_BF16(b8_0);
__m512bh bf16_1 = CVT_FP8_TO_BF16(b8_1);
// Apply scale
__m512 f0_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 0));
__m512 f0_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 1));
__m512 f1_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 0));
__m512 f1_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 1));
f0_lo = _mm512_mul_ps(f0_lo, vd);
f0_hi = _mm512_mul_ps(f0_hi, vd);
f1_lo = _mm512_mul_ps(f1_lo, vd);
f1_hi = _mm512_mul_ps(f1_hi, vd);
bf16_0 = _mm512_cvtne2ps_pbh(f0_hi, f0_lo);
bf16_1 = _mm512_cvtne2ps_pbh(f1_hi, f1_lo);
_mm512_storeu_si512(Btmp + k * ldb_tmp * 2 + n * 2 + 0, (__m512i)bf16_0);
_mm512_storeu_si512(Btmp + k * ldb_tmp * 2 + n * 2 + 32, (__m512i)bf16_1);
}
}
#else
TORCH_CHECK(false, "unpack_B: scalar path not implemented!");
#endif
}
template <typename scalar_t, typename packed_t, bool has_bias, int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_nn {
static inline void apply(
const scalar_t* __restrict__ A,
const packed_t* __restrict__ B,
scalar_t* __restrict__ C,
const float* __restrict__ bias,
const float* __restrict__ scale,
int K,
int lda,
int ldb,
int ldc,
int64_t block_size_K) {
TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!");
}
};
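// AVX512 micro-kernel for bf16 activations x fp8 weights: each step broadcasts a pair
// of bf16 A elements, dequantizes 32 fp8 B values to bf16 with the per-K-block scale
// (idx = k * 2 / block_size_K), and accumulates via _mm512_dpbf16_ps into float32
// registers, which are rounded back to bf16 on store.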
#if defined(CPU_CAPABILITY_AVX512)
template <bool has_bias, int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_nn<at::BFloat16, at::Float8_e4m3fn, has_bias, BLOCK_M, BLOCK_N> {
static inline void apply(
const at::BFloat16* __restrict__ A,
const at::Float8_e4m3fn* __restrict__ B,
at::BFloat16* __restrict__ C,
const float* __restrict__ bias,
const float* __restrict__ scale,
int K,
int lda,
int ldb,
int ldc,
int64_t block_size_K) {
constexpr int ROWS = BLOCK_M;
constexpr int COLS = BLOCK_N / 16;
// prefetch distance
constexpr int PREFETCH_SIZE_K = 0;
__m512bh va;
__m512bh vb[COLS];
__m512 vc[ROWS * COLS];
auto loadc = [&](auto i) {
constexpr int col = i % COLS;
if constexpr (has_bias) {
vc[i] = _mm512_loadu_ps(bias + col * 16);
} else {
vc[i] = _mm512_set1_ps(0.f);
}
};
Unroll<ROWS * COLS>{}(loadc);
const int K2 = K >> 1;
const int lda2 = lda >> 1;
const int ldb2 = ldb; // ldb * 2 >> 1;
const float* a_ptr = reinterpret_cast<const float*>(A);
const uint16_t* b_ptr = reinterpret_cast<const uint16_t*>(B);
auto compute = [&](auto i, int k) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
int idx = k * 2 / block_size_K;
const __m512 vd = _mm512_set1_ps(scale[idx]);
if constexpr (col == 0) {
va = (__m512bh)(_mm512_set1_ps(a_ptr[row * lda2 + k]));
}
if constexpr (row == 0) {
if constexpr (col % 2 == 0) {
__m512i b8 = _mm512_loadu_si512(b_ptr + k * ldb2 + col * 16);
if constexpr (PREFETCH_SIZE_K > 0) {
_mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0);
}
__m256i b8_0 = _mm512_extracti32x8_epi32(b8, 0);
__m256i b8_1 = _mm512_extracti32x8_epi32(b8, 1);
__m512bh bf16_0 = CVT_FP8_TO_BF16(b8_0);
__m512bh bf16_1 = CVT_FP8_TO_BF16(b8_1);
// Apply scale
__m512 f0_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 0));
__m512 f0_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 1));
__m512 f1_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 0));
__m512 f1_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 1));
f0_lo = _mm512_mul_ps(f0_lo, vd);
f0_hi = _mm512_mul_ps(f0_hi, vd);
f1_lo = _mm512_mul_ps(f1_lo, vd);
f1_hi = _mm512_mul_ps(f1_hi, vd);
vb[col + 0] = _mm512_cvtne2ps_pbh(f0_hi, f0_lo);
vb[col + 1] = _mm512_cvtne2ps_pbh(f1_hi, f1_lo);
}
}
vc[i] = _mm512_dpbf16_ps(vc[i], va, vb[col]);
};
for (int k = 0; k < K2; ++k) {
Unroll<ROWS * COLS>{}(compute, k);
}
auto storec = [&](auto i) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
// for COLS = 1, 3 use 256bit store
// for COLS = 2, 4 use 512bit store
if constexpr (COLS % 2 == 0) {
if constexpr (col % 2 == 0) {
_mm512_storeu_si512(
reinterpret_cast<__m512i*>((C + row * ldc + col * 16)),
(__m512i)(_mm512_cvtne2ps_pbh(vc[row * COLS + col + 1], vc[row * COLS + col])));
}
} else {
_mm256_storeu_si256(reinterpret_cast<__m256i*>(C + row * ldc + col * 16), (__m256i)(_mm512_cvtneps_pbh(vc[i])));
}
};
Unroll<ROWS * COLS>{}(storec);
}
};
#endif
#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE) \
tinygemm_kernel_nn<scalar_t, at::Float8_e4m3fn, has_bias, MB_SIZE, NB_SIZE>::apply( \
A + mb_start * lda, \
B + nb_start * 2, \
C + mb_start * ldc + nb_start, \
has_bias ? bias + nb_start : nullptr, \
scale, \
K, \
lda, \
ldb, \
ldc, \
block_size_K);
template <typename scalar_t, typename packed_t, bool has_bias>
struct brgemm {
static inline void apply(
const scalar_t* __restrict__ A,
const packed_t* __restrict__ B,
scalar_t* __restrict__ C,
scalar_t* __restrict__ Btmp,
float* __restrict__ Ctmp,
const float* __restrict__ bias,
const float* __restrict__ scale,
int M,
int N,
int K,
int lda,
int ldb,
int ldc) {
TORCH_CHECK(false, "struct brgemm: primary template not implemented!");
}
};
template <typename scalar_t, bool has_bias>
struct brgemm<scalar_t, scalar_t, has_bias> {
static inline void apply(
const scalar_t* __restrict__ A,
const scalar_t* __restrict__ B,
scalar_t* __restrict__ C,
scalar_t* __restrict__ Btmp,
float* __restrict__ Ctmp,
const float* __restrict__ bias,
const float* __restrict__ scale,
int M,
int N,
int K,
int lda,
int ldb,
int ldc) {
UNUSED(scale);
constexpr int BLOCK_N = block_size_n();
at::native::cpublas::brgemm(M, N, K, lda, ldb, BLOCK_N, /* add_C */ false, A, B, Ctmp);
// copy from Ctmp to C
for (int m = 0; m < M; ++m) {
if constexpr (has_bias) {
copy_add_stub(C + m * ldc, Ctmp + m * BLOCK_N, bias, N);
} else {
copy_stub(C + m * ldc, Ctmp + m * BLOCK_N, N);
}
}
}
};
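// fp8 weight path: walk K in BLOCK_K (= 128) steps, dequantize each slab of B into the
// bf16 scratch Btmp with its per-block scale via unpack_B, and let cpublas::brgemm
// accumulate into the float32 scratch Ctmp (add_C for every slab after the first);
// the accumulator is then downconverted, with optional bias, into C.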
template <bool has_bias>
struct brgemm<at::BFloat16, at::Float8_e4m3fn, has_bias> {
static inline void apply(
const at::BFloat16* __restrict__ A,
const at::Float8_e4m3fn* __restrict__ B,
at::BFloat16* __restrict__ C,
at::BFloat16* __restrict__ Btmp,
float* __restrict__ Ctmp,
const float* __restrict__ bias,
const float* __restrict__ scale,
int M,
int N,
int K,
int lda,
int ldb,
int ldc) {
constexpr int BLOCK_N = block_size_n();
// [BLOCK_K, BLOCK_N] -> [BLOCK_K / 2, BLOCK_N * 2]
const int ldb_tmp = block_size_n();
static_assert(BLOCK_K == 128);
// accumulate across K per BLOCK_K
for (int k = 0; k < K; k += BLOCK_K) {
int kb_size = std::min(BLOCK_K, K - k);
int idx = k >> 7; // k / BLOCK_K where BLOCK_K = 128
unpack_B(Btmp, B + k * ldb, N, kb_size, ldb, ldb_tmp, scale[idx]);
const bool add_C = (k != 0);
at::native::cpublas::brgemm(M, N, kb_size, lda, ldb_tmp, BLOCK_N, add_C, A + k, Btmp, Ctmp);
}
// copy from Ctmp to C
for (int m = 0; m < M; ++m) {
if constexpr (has_bias) {
copy_add_stub(C + m * ldc, Ctmp + m * BLOCK_N, bias, N);
} else {
copy_stub(C + m * ldc, Ctmp + m * BLOCK_N, N);
}
}
}
};
template <typename scalar_t, bool has_bias>
void tinygemm_kernel(
const scalar_t* __restrict__ A,
const at::Float8_e4m3fn* __restrict__ B,
scalar_t* __restrict__ C,
scalar_t* __restrict__ Btmp,
float* __restrict__ Ctmp,
const float* __restrict__ scale,
const float* __restrict__ bias,
int64_t M,
int64_t N,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc,
bool brg,
int64_t block_size_K) {
if (brg) {
brgemm<scalar_t, at::Float8_e4m3fn, has_bias>::apply(A, B, C, Btmp, Ctmp, bias, scale, M, N, K, lda, ldb, ldc);
return;
}
// pattern: 1-4-16
constexpr int64_t BLOCK_M = 4;
constexpr int64_t BLOCK_N = 64;
const int64_t MB = div_up(M, BLOCK_M);
const int64_t NB = div_up(N, BLOCK_N);
for (int mb = 0; mb < MB; ++mb) {
int64_t mb_start = mb * BLOCK_M;
int64_t mb_size = std::min(BLOCK_M, M - mb_start);
for (int64_t nb = 0; nb < NB; ++nb) {
int64_t nb_start = nb * BLOCK_N;
int64_t nb_size = std::min(BLOCK_N, N - nb_start);
switch (mb_size << 4 | nb_size >> 4) { // pack (mb_size, nb_size / 16) into one key, e.g. mb_size = 1, nb_size = 32 -> 0x12
// mb_size = 1
case 0x12:
LAUNCH_TINYGEMM_KERNEL_NN(1, 32);
break;
case 0x14:
LAUNCH_TINYGEMM_KERNEL_NN(1, 64);
break;
// mb_size = 2
case 0x22:
LAUNCH_TINYGEMM_KERNEL_NN(2, 32);
break;
case 0x24:
LAUNCH_TINYGEMM_KERNEL_NN(2, 64);
break;
// mb_size = 3
case 0x32:
LAUNCH_TINYGEMM_KERNEL_NN(3, 32);
break;
case 0x34:
LAUNCH_TINYGEMM_KERNEL_NN(3, 64);
break;
// mb_size = 4
case 0x42:
LAUNCH_TINYGEMM_KERNEL_NN(4, 32);
break;
case 0x44:
LAUNCH_TINYGEMM_KERNEL_NN(4, 64);
break;
default:
TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size");
}
}
}
}
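// Driver: the output is tiled into [BLOCK_M x BLOCK_N] blocks and the MB * NB tiles are
// distributed with at::parallel_for. Each worker keeps a float32 accumulator Ctmp and an
// output-dtype scratch Btmp, and picks the scale row for its column group as
// scales2 + (nb / (block_size_N / BLOCK_N)) * scale_size_K before calling tinygemm_kernel.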
template <typename scalar_t>
void fp8_scaled_mm_kernel_impl(
scalar_t* __restrict__ out,
const scalar_t* __restrict__ mat1,
const at::Float8_e4m3fn* __restrict__ mat2,
const float* __restrict__ scales2,
const float* __restrict__ bias,
int64_t M,
int64_t N,
int64_t K,
int64_t mat1_strideM,
int64_t out_strideM,
int64_t block_size_N,
int64_t block_size_K) {
constexpr int64_t BLOCK_M = block_size_m();
constexpr int64_t BLOCK_N = block_size_n();
const int64_t MB = div_up(M, BLOCK_M);
const int64_t NB = div_up(N, BLOCK_N);
const int64_t scale_size_K = div_up(K, block_size_K);
const int64_t blocks_n_per_group = block_size_N / BLOCK_N;
const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(M);
// parallel on [MB, NB]
AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] {
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
int64_t mb{0}, nb{0};
data_index_init(begin, mb, MB, nb, NB);
// for brgemm, use float32 for accumulate
alignas(64) float Ctmp[BLOCK_M * BLOCK_N];
// for brgemm when mat2 is float8_e4m3
alignas(64) scalar_t Btmp[BLOCK_N * BLOCK_K];
for (int64_t i = begin; i < end; ++i) {
UNUSED(i);
const float* scale_ptr = scales2 + (nb / blocks_n_per_group) * scale_size_K;
int64_t mb_start = mb * BLOCK_M;
int64_t mb_size = std::min(M - mb_start, BLOCK_M);
int64_t nb_start = nb * BLOCK_N;
int64_t nb_size = std::min(N - nb_start, BLOCK_N);
tinygemm_kernel<scalar_t, has_bias>(
/* A */ mat1 + mb_start * mat1_strideM,
/* B */ mat2 + nb_start * K, // nb * BLOCK_N * K
/* C */ out + mb_start * out_strideM + nb_start,
/* Btmp */ Btmp,
/* Ctmp */ Ctmp,
/* scale */ scale_ptr,
/* bias */ bias + nb_start,
/* M */ mb_size,
/* N */ nb_size,
/* K */ K,
/* lda */ mat1_strideM,
/* ldb */ nb_size,
/* ldc */ out_strideM,
/* brg */ use_brgemm,
/* block_size_K */ block_size_K);
// move to the next index
data_index_step(mb, MB, nb, NB);
}
if (use_brgemm) {
at::native::cpublas::brgemm_release();
}
});
});
}
} // anonymous namespace
at::Tensor fp8_scaled_mm_cpu(
at::Tensor& mat1,
at::Tensor& mat2,
at::Tensor& scales2,
std::vector<int64_t> block_size,
std::optional<at::Tensor>& bias,
at::ScalarType out_dtype,
bool is_vnni) {
RECORD_FUNCTION("sgl-kernel::fp8_scaled_mm_cpu", std::vector<c10::IValue>({mat1, mat2, scales2, block_size, bias}));
auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1);
CHECK_INPUT(mat2);
CHECK_INPUT(scales2);
TORCH_CHECK(scales2.scalar_type() == at::kFloat, "fp8_scaled_mm_cpu: expect scales2 to be float32.");
int64_t M = mat1.size(0);
int64_t N = mat2.size(0);
int64_t K = mat2.size(1);
CHECK_EQ(mat1.size(1), K);
CHECK_DIM(2, mat1);
CHECK_DIM(2, mat2);
TORCH_CHECK(block_size.size() == 2, "fp8_scaled_mm_cpu: expect block_size.size() to be 2.");
int64_t block_size_N = block_size[0];
int64_t block_size_K = block_size[1];
constexpr int64_t BLOCK_N = block_size_n();
TORCH_CHECK(block_size_N % BLOCK_N == 0, "fp8_scaled_mm_cpu: expect block_size_N to be a multiple of BLOCK_N");
TORCH_CHECK(block_size_K == BLOCK_K, "fp8_scaled_mm_cpu: expect block_size_K to equal BLOCK_K");
CHECK_EQ(scales2.size(0), div_up(N, block_size_N));
CHECK_EQ(scales2.size(1), div_up(K, block_size_K));
const auto st = mat1.scalar_type();
TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf, "fp8_scaled_mm_cpu: expect A to be bfloat16 or half.");
TORCH_CHECK(st == out_dtype, "fp8_scaled_mm_cpu: expect A to have the same dtype as out_dtype.");
TORCH_CHECK(mat2.scalar_type() == at::kFloat8_e4m3fn, "fp8_scaled_mm_cpu: expect mat2 to be fp8_e4m3.");
TORCH_CHECK(scales2.scalar_type() == at::kFloat, "fp8_scaled_mm_cpu: expect scales to be float32.");
auto out = at::empty({M, N}, mat1.options().dtype(out_dtype));
// strides
int64_t mat1_strideM = mat1.stride(0);
int64_t out_strideM = out.stride(0);
const bool has_bias = bias.has_value();
const float* bias_data = nullptr;
if (has_bias) {
CHECK_EQ(bias.value().size(0), N);
bias_data = bias.value().data_ptr<float>();
}
AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "fp8_scaled_mm_kernel_impl", [&] {
fp8_scaled_mm_kernel_impl<scalar_t>(
out.data_ptr<scalar_t>(),
mat1.data_ptr<scalar_t>(),
packed_w.data_ptr<at::Float8_e4m3fn>(),
scales2.data_ptr<float>(),
bias_data,
M,
N,
K,
mat1_strideM,
out_strideM,
block_size_N,
block_size_K);
});
return out;
}
......@@ -94,6 +94,16 @@ at::Tensor int8_scaled_mm_cpu(
at::ScalarType out_dtype,
bool is_vnni);
// fp8 gemm
at::Tensor fp8_scaled_mm_cpu(
at::Tensor& mat1,
at::Tensor& mat2,
at::Tensor& scales2,
std::vector<int64_t> block_size,
std::optional<at::Tensor>& bias,
at::ScalarType out_dtype,
bool is_vnni);
// quant + igemm
at::Tensor int8_scaled_mm_with_quant(
at::Tensor& mat1,
......@@ -198,6 +208,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// igemm
m.def("int8_scaled_mm_cpu", &int8_scaled_mm_cpu, "int8 weight packed linear for intel AMX");
// fp8 gemm
m.def("fp8_scaled_mm_cpu", &fp8_scaled_mm_cpu, "fp8 weight packed linear for intel AMX");
// quant + igemm
m.def(
"int8_scaled_mm_with_quant", &int8_scaled_mm_with_quant, "fused per row quant and int8 scaled mm for intel AMX");
......
......@@ -30,6 +30,66 @@ convert_from_float_ext<at::BFloat16>(const Vectorized<float>& a, const Vectorize
#define CVT_FP32_TO_FP16(a) _mm512_cvtps_ph(a, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC))
inline __m512bh cvt_e4m3_bf16_intrinsic_without_denorm(__m256i fp8_vec) {
// The following conversion is without denorm behavior, that is to say,
// Max subnorm : S.0000.111 = 0.875 ∗ 2**(−6)
// Min subnorm : S.0000.001 = 2**(−9)
// 0.0019 ~ 0.0137 cannot be converted correctly.
__m512i x = _mm512_cvtepu8_epi16(fp8_vec);
auto mask = _mm512_cmpneq_epi16_mask(
_mm512_and_si512(x, _mm512_set1_epi16(127)),
_mm512_setzero_si512()); // mask = (x & 0x7f) != 0, i.e. nonzero magnitude
auto mask_nan = _mm512_cmpneq_epi16_mask(
_mm512_and_si512(x, _mm512_set1_epi16(127)),
_mm512_set1_epi16(127)); // mask_nan = (x & 0x7f) != 0x7f, false only for the NaN encoding
auto mantissa = _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(7)), 4); // mantissa = (x & 7) << 4
auto exponent = _mm512_add_epi16(
_mm512_srli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(120)), 3),
_mm512_set1_epi16(120)); // exponent = (((x >> 3) & 15) + 120)
auto nonsign = _mm512_maskz_mov_epi16(mask, _mm512_or_si512(mantissa, _mm512_slli_epi16(exponent, 7)));
nonsign = _mm512_mask_mov_epi16(_mm512_set1_epi16(0x7fff), mask_nan, nonsign); // deal with Nan
return (__m512bh)(_mm512_or_si512(
nonsign,
_mm512_slli_epi16(
_mm512_and_si512(x, _mm512_set1_epi16(128)),
8))); // add sign (x & 128) << 8
}
inline __m512bh cvt_e4m3_bf16_intrinsic_with_denorm(__m256i fp8_vec) {
__m512i x = _mm512_cvtepu8_epi16(fp8_vec);
__m512i lg2mant = _mm512_mask_mov_epi16(
_mm512_mask_mov_epi16(
_mm512_setzero_si512(), _mm512_test_epi16_mask(x, _mm512_set1_epi16(2)), _mm512_set1_epi16(1)),
_mm512_test_epi16_mask(x, _mm512_set1_epi16(4)),
_mm512_set1_epi16(2));
return (__m512bh)(_mm512_or_si512(
_mm512_maskz_mov_epi16(
_mm512_cmpneq_epi16_mask(_mm512_and_si512(x, _mm512_set1_epi16(127)), _mm512_setzero_si512()),
_mm512_mask_blend_epi16(
_mm512_test_epi16_mask(x, _mm512_set1_epi16(120)),
_mm512_or_si512(
_mm512_and_si512(
_mm512_sllv_epi16(
_mm512_and_si512(x, _mm512_set1_epi16(3)), _mm512_sub_epi16(_mm512_set1_epi16(7), lg2mant)),
_mm512_set1_epi16(0x007f)),
_mm512_slli_epi16(_mm512_add_epi16(lg2mant, _mm512_set1_epi16(118)), 7)),
_mm512_or_si512(
_mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(7)), 4),
_mm512_slli_epi16(
_mm512_add_epi16(
_mm512_srli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(120)), 3), _mm512_set1_epi16(120)),
7)))),
_mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(128)), 8)));
}
inline __m512bh CVT_FP8_TO_BF16(__m256i a) {
#ifdef SGLANG_CPU_FP8_CVT_FTZ
return cvt_e4m3_bf16_intrinsic_without_denorm(a);
#else
return cvt_e4m3_bf16_intrinsic_with_denorm(a);
#endif
}
#endif
// vector to scalar reduction
......
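A quick numeric check of the subnormal range quoted in cvt_e4m3_bf16_intrinsic_without_denorm above (e4m3: 1 sign bit, 4 exponent bits, 3 mantissa bits, exponent bias 7), as a sketch:

# e4m3 subnormals are (m / 8) * 2**-6 for m in 1..7
min_subnormal = (1 / 8) * 2 ** -6   # 2**-9  ~= 0.00195
max_subnormal = (7 / 8) * 2 ** -6   #           ~= 0.01367
print(min_subnormal, max_subnormal)  # the 0.0019 ~ 0.0137 range mentioned in the comment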
......@@ -47,6 +47,8 @@ def _get_version():
return line.split("=")[1].strip().strip('"')
cpu_fp8_ftz = os.getenv("SGLANG_CPU_FP8_CVT_FTZ", "1") == "1"
operator_namespace = "sgl_kernel"
include_dirs = []
......@@ -56,6 +58,7 @@ sources = [
"csrc/cpu/decode.cpp",
"csrc/cpu/extend.cpp",
"csrc/cpu/gemm.cpp",
"csrc/cpu/gemm_fp8.cpp",
"csrc/cpu/gemm_int8.cpp",
"csrc/cpu/moe.cpp",
"csrc/cpu/moe_int8.cpp",
......@@ -76,6 +79,9 @@ extra_compile_args = {
"-fopenmp",
]
}
if cpu_fp8_ftz:
extra_compile_args["cxx"].append("-DSGLANG_CPU_FP8_CVT_FTZ")
libraries = ["c10", "torch", "torch_python"]
cmdclass = {
"build_ext": BuildExtension.with_options(use_ninja=True),
......
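Since SGLANG_CPU_FP8_CVT_FTZ defaults to "1" above, the flush-to-zero conversion is compiled in by default; a sketch of building with the denorm-preserving path instead (assumes a local sgl-kernel source checkout):

import os, subprocess
# SGLANG_CPU_FP8_CVT_FTZ=0 omits -DSGLANG_CPU_FP8_CVT_FTZ, so CVT_FP8_TO_BF16
# falls back to cvt_e4m3_bf16_intrinsic_with_denorm.
env = dict(os.environ, SGLANG_CPU_FP8_CVT_FTZ="0")
subprocess.run(["pip", "install", "-v", "."], env=env, check=True)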
import itertools
import unittest
import torch
import torch.nn as nn
# TODO: use interface in cpu.py
from sgl_kernel.common_ops import (
convert_weight_packed,
fp8_scaled_mm_cpu,
int8_scaled_mm_cpu,
int8_scaled_mm_with_quant,
per_token_quant_int8_cpu,
weight_packed_linear,
)
from utils import (
convert_weight,
native_w8a8_per_token_matmul,
per_token_quant_int8,
precision,
)
from sglang.test.test_utils import CustomTestCase
class Mod(nn.Module):
def __init__(self, input_channel, output_channel, has_bias):
super(Mod, self).__init__()
self.linear = torch.nn.Linear(input_channel, output_channel, has_bias)
def forward(self, x):
return self.linear(x)
class TestGemm(CustomTestCase):
M = [1, 101]
N = [32 * 13]
K = [32 * 16]
has_bias = [False, True]
M_int8 = [2, 128]
N_int8 = [32 * 12]
K_int8 = [32 * 17]
M_fp8 = [1, 11]
N_fp8 = [128, 224]
K_fp8 = [512, 576]
def _bf16_gemm(self, M, N, K, has_bias):
mat1 = torch.randn(M, K, dtype=torch.bfloat16)
mat2 = torch.randn(N, K, dtype=torch.bfloat16)
ref = torch.matmul(mat1.float(), mat2.float().t())
if has_bias:
bias = torch.randn(N, dtype=torch.float32)
ref.add_(bias.bfloat16())
ref = ref.bfloat16()
out = weight_packed_linear(mat1, mat2, bias if has_bias else None, False)
packed_mat2 = convert_weight_packed(mat2)
out2 = weight_packed_linear(mat1, packed_mat2, bias if has_bias else None, True)
atol = rtol = precision[ref.dtype]
self.assertTrue(torch.allclose(ref, out, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(ref, out2, atol=atol, rtol=rtol))
def test_bf16_gemm(self):
for params in itertools.product(
self.M,
self.N,
self.K,
self.has_bias,
):
with self.subTest(
M=params[0],
N=params[1],
K=params[2],
has_bias=params[3],
):
self._bf16_gemm(*params)
def _int8_gemm(self, M, N, K, has_bias):
dtype = torch.bfloat16
A = torch.randn((M, K), dtype=dtype) / 10
Aq, As = per_token_quant_int8(A)
factor_for_scale = 1e-2
int8_max = 127
int8_min = -128
B = (torch.rand((N, K), dtype=torch.float32) - 0.5) * 2
Bq = (B * int8_max).clamp(min=int8_min, max=int8_max).to(torch.int8)
Bs = torch.rand(N) * factor_for_scale
bias = torch.randn(N) if has_bias else None
ref_out = native_w8a8_per_token_matmul(Aq, Bq, As, Bs, bias, dtype)
atol = rtol = precision[ref_out.dtype]
Aq2, As2 = per_token_quant_int8_cpu(A)
out = int8_scaled_mm_cpu(
Aq2, Bq, As2, Bs, bias if has_bias else None, torch.bfloat16, False
)
self.assertTrue(torch.allclose(ref_out, out, atol=atol, rtol=rtol))
# test the fused version
fused_out = int8_scaled_mm_with_quant(
A, Bq, Bs, bias if has_bias else None, torch.bfloat16, False
)
self.assertTrue(torch.allclose(ref_out, fused_out, atol=atol, rtol=rtol))
def test_int8_gemm(self):
for params in itertools.product(
self.M_int8,
self.N_int8,
self.K_int8,
self.has_bias,
):
with self.subTest(
M=params[0],
N=params[1],
K=params[2],
has_bias=params[3],
):
self._int8_gemm(*params)
def _fp8_gemm(self, M, N, K, has_bias):
prepack = True
chunk = False
scale_block_size_N = 64
scale_block_size_K = 128
assert scale_block_size_N <= N
assert scale_block_size_K <= K
A_dtype = torch.bfloat16
model = Mod(K, N, has_bias).eval()
if chunk:
data = torch.randn(M, K + 6, dtype=A_dtype).narrow(1, 0, K)
else:
data = torch.randn(M, K, dtype=A_dtype)
weight = model.linear.weight # (N, K)
if has_bias:
bias = model.linear.bias
fp8_weight, scales, dq_weight = convert_weight(
weight, [scale_block_size_N, scale_block_size_K], A_dtype
)
if has_bias:
ref = torch.matmul(data.to(A_dtype), dq_weight.T) + bias.to(A_dtype)
else:
ref = torch.matmul(data.to(A_dtype), dq_weight.T)
if prepack:
fp8_weight = convert_weight_packed(fp8_weight)
opt = fp8_scaled_mm_cpu(
data,
fp8_weight,
scales,
[scale_block_size_N, scale_block_size_K],
bias if has_bias else None,
data.dtype,
prepack,
)
atol = rtol = precision[ref.dtype]
self.assertTrue(torch.allclose(ref, opt, atol=atol, rtol=rtol))
def test_fp8_gemm(self):
for params in itertools.product(
self.M_fp8,
self.N_fp8,
self.K_fp8,
self.has_bias,
):
with self.subTest(
M=params[0],
N=params[1],
K=params[2],
has_bias=params[3],
):
self._fp8_gemm(*params)
if __name__ == "__main__":
unittest.main()
import math
import torch
precision = {
torch.bfloat16: 1e-2,
torch.float16: 1e-3,
torch.float32: 1e-5,
}
def per_token_quant_int8(x):
x = x.float()
absmax = x.abs().max(dim=-1).values
absmax = absmax.clamp_min(1e-10).unsqueeze(-1)
scale_x = absmax / 127
x_q = x.mul(127 / absmax)
x_q = torch.round(x_q).to(torch.int8)
return x_q, scale_x
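# convert_weight: block-quantize an (N, K) weight to fp8_e4m3 with one float32 scale per
# (scale_block_size_N x scale_block_size_K) block; returns the quantized weight, the scale
# grid, and a dequantized reference weight in A_dtype for computing the expected output.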
def convert_weight(weight, scale_block_size, A_dtype):
N, K = weight.size()
fp8_max = 448.0
scale_block_size_N, scale_block_size_K = scale_block_size # (128, 128)
pad_N = (scale_block_size_N - (N % scale_block_size_N)) % scale_block_size_N
pad_K = (scale_block_size_K - (K % scale_block_size_K)) % scale_block_size_K
if pad_N > 0 or pad_K > 0:
weight = torch.nn.functional.pad(weight, (0, pad_K, 0, pad_N))
weight_blocks = weight.view(
math.ceil(N / scale_block_size_N),
scale_block_size_N,
math.ceil(K / scale_block_size_K),
scale_block_size_K,
) # (8, 128, 8, 128)
weight_blocks = weight_blocks.permute(0, 2, 1, 3).contiguous() # (8, 8, 128, 128)
# Step 2: compute per-block max abs values → scale
abs_max = weight_blocks.abs().amax(dim=(-2, -1), keepdim=True) # (8, 8, 1, 1)
scales = abs_max / fp8_max
scales = torch.where(
scales == 0, torch.ones_like(scales), scales
) # avoid division by zero
q_fp8 = (weight_blocks / scales).to(torch.float8_e4m3fn)
q_fp8_reshape = q_fp8.permute(0, 2, 1, 3).contiguous()
if pad_N > 0 or pad_K > 0:
q_fp8_reshape = q_fp8_reshape.view(N + pad_N, K + pad_K)
q_fp8_reshape = q_fp8_reshape[:N, :K].contiguous()
else:
q_fp8_reshape = q_fp8_reshape.view(N, K)
dq_weight = q_fp8.float() * scales
dq_weight = dq_weight.permute(0, 2, 1, 3).contiguous() # (8, 128, 8, 128)
if pad_N > 0 or pad_K > 0:
w_dq = dq_weight.view(N + pad_N, K + pad_K).to(A_dtype)
w_dq = w_dq[:N, :K].contiguous()
else:
w_dq = dq_weight.view(N, K).to(A_dtype)
scales = scales.view(
math.ceil(N / scale_block_size_N), math.ceil(K / scale_block_size_K)
)
return q_fp8_reshape, scales, w_dq
def native_w8a8_per_token_matmul(A, B, As, Bs, bias, output_dtype=torch.bfloat16):
"""Matrix multiplication function that supports per-token input quantization and per-column weight quantization"""
A = A.to(torch.float32)
B = B.to(torch.float32)
assert A.shape[-1] == B.shape[-1], "Dimension mismatch"
assert B.ndim == 2 and B.is_contiguous(), "B must be a 2D contiguous tensor"
# Reshape input
M = A.numel() // A.shape[-1]
B = B.t()  # transpose weight to [K_in, N_out]; note the name swap below
N, K = B.shape  # after the transpose: N is the reduction dim (K_in), K is the output dim (N_out)
origin_C_shape = A.shape[:-1] + (K,)
A = A.reshape(M, N)
# As is per-token [M, 1], Bs is per-column [1, K]
C = torch.matmul(A, B) # [M, K]
C = As * C * Bs.view(1, -1) # Broadcast per-column scale
if bias is not None:
C.add_(bias.view(1, -1))
return C.reshape(origin_C_shape).to(output_dtype)