#pragma once

#if defined(__AVX512F__) && defined(__AVX512BF16__) && defined(__AMX_BF16__)
#define CPU_CAPABILITY_AVX512
#endif

#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>

namespace {

using namespace at::vec;

template <typename scalar_t, typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
inline Vectorized<scalar_t> convert_from_float_ext(const Vectorized<float>& a, const Vectorized<float>& b) {
  return at::vec::convert_from_float<scalar_t>(a, b);
}

#if defined(CPU_CAPABILITY_AVX512)

// `at::vec::convert_from_float<>` from PyTorch doesn't use avx512-bf16 intrinsics;
// use the native instruction for float32->bfloat16 conversion
template <>
inline Vectorized<at::BFloat16> convert_from_float_ext<at::BFloat16>(
    const Vectorized<float>& a, const Vectorized<float>& b) {
  return (__m512i)(_mm512_cvtne2ps_pbh(__m512(b), __m512(a)));
}

#define CVT_BF16_TO_FP32(a) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a), 16))

#define CVT_FP16_TO_FP32(a) _mm512_cvt_roundph_ps(a, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC))

#endif

// vector to scalar reduction
#if defined(CPU_CAPABILITY_AVX512) && 0
inline float vec_reduce_sum(const Vectorized<float>& a) {
  return _mm512_reduce_add_ps(__m512(a));
}

inline float vec_reduce_max(const Vectorized<float>& a) {
  return _mm512_reduce_max_ps(__m512(a));
}
#else
inline float vec_reduce_sum(const Vectorized<float>& a) {
  return vec_reduce_all([](Vectorized<float>& x, Vectorized<float>& y) { return x + y; }, a);
}

inline float vec_reduce_max(const Vectorized<float>& a) {
  return vec_reduce_all([](Vectorized<float>& x, Vectorized<float>& y) { return maximum(x, y); }, a);
}
#endif

// Per-row dynamic quantization to unsigned int8 with zero point 128, following:
// https://github.com/InternLM/lmdeploy/blob/086481ed84b59bee3b8e4274e5fc69620040c048/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py#L282
template <typename scalar_t>
inline void quantize_row_int8(
    uint8_t* __restrict__ Aq,
    float& As,
    const scalar_t* __restrict__ A,
    int64_t K,
    float eps = 1e-7) {
  // first pass: find the absolute max of the row
  float amax = 0.f;  // absolute max
  for (int64_t k = 0; k < K; ++k) {
    const float val = static_cast<float>(A[k]);
    amax = std::max(amax, std::abs(val));
  }

  amax = std::max(amax, eps);
  const float scale = amax / 127;
  const float inv_scale = 127 / amax;

  // second pass: scale, round, and shift into the unsigned range
  for (int64_t k = 0; k < K; ++k) {
    const float val = static_cast<float>(A[k]) * inv_scale;
    Aq[k] = (uint8_t)(std::round(val)) + 128;
  }
  As = scale;
}
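// Worked example of the mapping above: for a row with amax = 2.0,
// As = 2.0 / 127 ≈ 0.01575 and inv_scale = 63.5, so an input of 1.0 maps to
// round(1.0 * 63.5) + 128 = 192; a consumer can dequantize element k as
// (static_cast<int>(Aq[k]) - 128) * As ≈ 1.008.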
#if defined(CPU_CAPABILITY_AVX512)
template <>
inline void quantize_row_int8<at::BFloat16>(
    uint8_t* __restrict__ Aq,
    float& As,
    const at::BFloat16* __restrict__ A,
    int64_t K,
    float eps) {
  const __m512 signBit = _mm512_set1_ps(-0.0f);
  const __m512i off = _mm512_set1_epi32(128);

  // K is 32x, no remainder
  // first pass: accumulate the per-lane absolute max over two 16-wide vectors
  float amax = 0.f;
  __m512 vamax0 = _mm512_set1_ps(0.f);
  __m512 vamax1 = _mm512_set1_ps(0.f);
  for (int64_t k = 0; k < K; k += 32) {
    __m512i va = _mm512_loadu_si512((void*)(A + k));
    __m512 va0 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 0));
    __m512 va1 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 1));
    vamax0 = _mm512_max_ps(vamax0, _mm512_andnot_ps(signBit, va0));
    vamax1 = _mm512_max_ps(vamax1, _mm512_andnot_ps(signBit, va1));
  }
  amax = _mm512_reduce_max_ps(_mm512_max_ps(vamax0, vamax1));
  amax = std::max(amax, eps);
  const float scale = amax / 127;
  const float inv_scale = 127 / amax;
  const __m512 vd = _mm512_set1_ps(inv_scale);

  // second pass: scale, round to nearest, add the 128 offset, narrow to uint8
  for (int64_t k = 0; k < K; k += 32) {
    __m512i va = _mm512_loadu_si512((void*)(A + k));
    __m512 va0 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 0));
    __m512 va1 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 1));
    va0 = _mm512_mul_ps(va0, vd);
    va1 = _mm512_mul_ps(va1, vd);
    va0 = _mm512_roundscale_ps(va0, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
    va1 = _mm512_roundscale_ps(va1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
    __m128i i0 = _mm512_cvtepi32_epi8(_mm512_add_epi32(_mm512_cvtps_epi32(va0), off));
    __m128i i1 = _mm512_cvtepi32_epi8(_mm512_add_epi32(_mm512_cvtps_epi32(va1), off));
    _mm256_storeu_si256(reinterpret_cast<__m256i*>(Aq + k), _mm256_set_m128i(i1, i0));
  }
  As = scale;
}
#endif

}  // anonymous namespace
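// Usage sketch (illustrative only; `A_bf16` and the output buffers are
// hypothetical names, and the AVX512 path assumes K is a multiple of 32):
//
//   std::vector<uint8_t> Aq(K);
//   float As = 0.f;
//   quantize_row_int8(Aq.data(), As, A_bf16 /* const at::BFloat16* */, K);
//   // dequantize element k as: (static_cast<int>(Aq[k]) - 128) * As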