Commit cea85c38 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.11.0-dev' of http://10.16.6.30/dcutoolkit/deeplearing/vllm into v0.11.0-dev

parents 6d8c8719 bc80af59
...@@ -644,10 +644,23 @@ class CustomAllreduce { ...@@ -644,10 +644,23 @@ class CustomAllreduce {
size /= d; size /= d;
auto bytes = size * sizeof(typename packed_t<T>::P); auto bytes = size * sizeof(typename packed_t<T>::P);
int blocks = std::min(block_limit, (size + threads - 1) / threads); int blocks = std::min(block_limit, (size + threads - 1) / threads);
#define KL(ngpus, name) \ // #define KL(ngpus, name) \
name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \ // name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
rank_, size, dev_curr_hdp_reg, world_size_) ; // rank_, size, dev_curr_hdp_reg, world_size_) ;
#define KL(ngpus, name) \
{ \
void* kernelArgs[] = { \
&ptrs, &sg_, &self_sg_, &output, &rank_, &size \
}; \
hipExtLaunchKernel( \
(void*)name<T, ngpus>, \
blocks, threads, \
kernelArgs, 0, \
stream, nullptr, stopEvent, 0 \
); \
}
#define REDUCE_CASE(ngpus) \ #define REDUCE_CASE(ngpus) \
case ngpus: { \ case ngpus: { \
if (world_size_ == 2) { \ if (world_size_ == 2) { \
......
...@@ -27,7 +27,7 @@ static inline __device__ float fp8_to_float(uint8_t input) { ...@@ -27,7 +27,7 @@ static inline __device__ float fp8_to_float(uint8_t input) {
} }
// float -> fp8 // float -> fp8
static inline __device__ uint8_t float_to_fp8(float f) { static inline __device__ uint8_t float_to_fp8_e4m3(float f) {
constexpr uint32_t fp8_max = UINT32_C(1087) << 20; constexpr uint32_t fp8_max = UINT32_C(1087) << 20;
constexpr uint32_t denorm_mask = UINT32_C(141) << 23; constexpr uint32_t denorm_mask = UINT32_C(141) << 23;
uint32_t f_bits = c10::detail::fp32_to_bits(f); uint32_t f_bits = c10::detail::fp32_to_bits(f);
...@@ -53,6 +53,31 @@ static inline __device__ uint8_t float_to_fp8(float f) { ...@@ -53,6 +53,31 @@ static inline __device__ uint8_t float_to_fp8(float f) {
return result; return result;
} }
static inline __device__ uint8_t float_to_fp8_e5m2(float f) {
constexpr uint32_t fp32_inf = UINT32_C(255) << 23;
constexpr uint32_t fp8_max = UINT32_C(143) << 23;
constexpr uint32_t denorm_mask = UINT32_C(134) << 23;
uint32_t f_bits = c10::detail::fp32_to_bits(f);
uint8_t result = 0u;
const uint32_t sign = f_bits & UINT32_C(0x80000000);
f_bits ^= sign;
if (f_bits >= fp8_max) {
result = f_bits > fp32_inf ? UINT8_C(0x7F) : UINT8_C(0x7C);
} else {
if (f_bits < (UINT32_C(113) << 23)) {
f_bits = c10::detail::fp32_to_bits(c10::detail::fp32_from_bits(f_bits)
+ c10::detail::fp32_from_bits(denorm_mask));
result = static_cast<uint8_t>(f_bits - denorm_mask);
} else {
uint32_t mant_odd = (f_bits >> 21) & 1;
f_bits += ((uint32_t)(15 - 127) << 23) + 0xFFFFF;
f_bits += mant_odd;
result = static_cast<uint8_t>(f_bits >> 21);
}
}
result |= static_cast<uint8_t>(sign >> 24);
return result;
}
// template <typename Tout, typename Tin> // template <typename Tout, typename Tin>
// __inline__ __device__ Tout vec_conversion(const Tin& x) { // __inline__ __device__ Tout vec_conversion(const Tin& x) {
// return x; // return x;
...@@ -60,7 +85,7 @@ static inline __device__ uint8_t float_to_fp8(float f) { ...@@ -60,7 +85,7 @@ static inline __device__ uint8_t float_to_fp8(float f) {
template <typename Tout, typename Tin> template <typename Tout, typename Tin>
__inline__ __device__ Tout scaled_vec_conversion(const Tin& x, __inline__ __device__ Tout scaled_vec_conversion(const Tin& x,
const float scale) { const float scale, Fp8KVCacheDataType kv_type) {
return x; return x;
} }
...@@ -344,8 +369,10 @@ using __nv_bfloat16 = __hip_bfloat16; ...@@ -344,8 +369,10 @@ using __nv_bfloat16 = __hip_bfloat16;
// fp8 -> __nv_bfloat16 // fp8 -> __nv_bfloat16
template <> template <>
__inline__ __device__ __nv_bfloat16 __inline__ __device__ __nv_bfloat16
scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, float scale) { scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, float scale, Fp8KVCacheDataType kv_type) {
if (kv_type == vllm::Fp8KVCacheDataType::kFp8E5M2) {
assert(false);
}
return __float2bfloat16(fp8_to_float(a) * scale); return __float2bfloat16(fp8_to_float(a) * scale);
// fp8_type f8; // fp8_type f8;
// f8.__x = a; // f8.__x = a;
...@@ -356,32 +383,32 @@ scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, float scale) { ...@@ -356,32 +383,32 @@ scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, float scale) {
template <> template <>
__inline__ __device__ __nv_bfloat162 __inline__ __device__ __nv_bfloat162
scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a, scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a,
float scale) { float scale, Fp8KVCacheDataType kv_type) {
__nv_bfloat162 res; __nv_bfloat162 res;
res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale); res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale, kv_type);
res.y = res.y =
scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale); scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale, kv_type);
return res; return res;
} }
// fp8x4 -> bf16_4_t // fp8x4 -> bf16_4_t
template <> template <>
__inline__ __device__ bf16_4_t __inline__ __device__ bf16_4_t
scaled_vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a, float scale) { scaled_vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a, float scale, Fp8KVCacheDataType kv_type) {
bf16_4_t res; bf16_4_t res;
res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale); res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale, kv_type);
res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U),
scale); scale, kv_type);
return res; return res;
} }
// fp8x8 -> bf16_8_t // fp8x8 -> bf16_8_t
template <> template <>
__inline__ __device__ bf16_8_t __inline__ __device__ bf16_8_t
scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, float scale) { scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, float scale, Fp8KVCacheDataType kv_type) {
bf16_4_t tmp1, tmp2; bf16_4_t tmp1, tmp2;
tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale); tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale, kv_type);
tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale); tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale, kv_type);
bf16_8_t res; bf16_8_t res;
res.x = tmp1.x; res.x = tmp1.x;
res.y = tmp1.y; res.y = tmp1.y;
...@@ -393,7 +420,10 @@ scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, float scale) { ...@@ -393,7 +420,10 @@ scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, float scale) {
// fp8 -> float // fp8 -> float
template <> template <>
__inline__ __device__ float scaled_vec_conversion<float, uint8_t>( __inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
const uint8_t& a, float scale) { const uint8_t& a, float scale, Fp8KVCacheDataType kv_type) {
if (kv_type == vllm::Fp8KVCacheDataType::kFp8E5M2) {
assert(false);
}
return fp8_to_float(a) * scale; return fp8_to_float(a) * scale;
// fp8_type f8; // fp8_type f8;
// f8.__x = a; // f8.__x = a;
...@@ -403,10 +433,10 @@ __inline__ __device__ float scaled_vec_conversion<float, uint8_t>( ...@@ -403,10 +433,10 @@ __inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
// fp8x2 -> float2 // fp8x2 -> float2
template <> template <>
__inline__ __device__ float2 __inline__ __device__ float2
scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, float scale) { scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, float scale, Fp8KVCacheDataType kv_type) {
float2 f2r; float2 f2r;
f2r.x = scaled_vec_conversion<float, uint8_t>((uint8_t)a, scale); f2r.x = scaled_vec_conversion<float, uint8_t>((uint8_t)a, scale, kv_type);
f2r.y = scaled_vec_conversion<float, uint8_t>((uint8_t)(a >> 8U), scale); f2r.y = scaled_vec_conversion<float, uint8_t>((uint8_t)(a >> 8U), scale, kv_type);
return f2r; return f2r;
// [[maybe_unused]] // [[maybe_unused]]
// fp8x2_type f8x2; // fp8x2_type f8x2;
...@@ -417,28 +447,28 @@ scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, float scale) { ...@@ -417,28 +447,28 @@ scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, float scale) {
// fp8x4 -> float4 // fp8x4 -> float4
template <> template <>
__inline__ __device__ Float4_ __inline__ __device__ Float4_
scaled_vec_conversion<Float4_, uint32_t>(const uint32_t& a, const float scale) { scaled_vec_conversion<Float4_, uint32_t>(const uint32_t& a, const float scale, Fp8KVCacheDataType kv_type) {
Float4_ res; Float4_ res;
res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale); res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale, kv_type);
res.y = scaled_vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), scale); res.y = scaled_vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), scale, kv_type);
return res; return res;
} }
// fp8x4 -> float4 // fp8x4 -> float4
template <> template <>
__inline__ __device__ float4 __inline__ __device__ float4
scaled_vec_conversion<float4, uint32_t>(const uint32_t& a, float scale) { scaled_vec_conversion<float4, uint32_t>(const uint32_t& a, float scale, Fp8KVCacheDataType kv_type) {
Float4_ res = scaled_vec_conversion<Float4_, uint32_t>(a, scale); Float4_ res = scaled_vec_conversion<Float4_, uint32_t>(a, scale, kv_type);
return {res.x.x, res.x.y, res.y.x, res.y.y}; return {res.x.x, res.x.y, res.y.x, res.y.y};
} }
// fp8x8 -> float8 // fp8x8 -> float8
template <> template <>
__inline__ __device__ Float8_ __inline__ __device__ Float8_
scaled_vec_conversion<Float8_, uint2>(const uint2& a, float scale) { scaled_vec_conversion<Float8_, uint2>(const uint2& a, float scale, Fp8KVCacheDataType kv_type) {
Float4_ tmp1, tmp2; Float4_ tmp1, tmp2;
tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale); tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale, kv_type);
tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale); tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale, kv_type);
Float8_ res; Float8_ res;
res.x = tmp1.x; res.x = tmp1.x;
res.y = tmp1.y; res.y = tmp1.y;
...@@ -450,7 +480,10 @@ scaled_vec_conversion<Float8_, uint2>(const uint2& a, float scale) { ...@@ -450,7 +480,10 @@ scaled_vec_conversion<Float8_, uint2>(const uint2& a, float scale) {
// fp8 -> half // fp8 -> half
template <> template <>
__inline__ __device__ uint16_t __inline__ __device__ uint16_t
scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, float scale) { scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, float scale, Fp8KVCacheDataType kv_type) {
if (kv_type == vllm::Fp8KVCacheDataType::kFp8E5M2) {
assert(false);
}
float res = fp8_to_float(a) * scale; float res = fp8_to_float(a) * scale;
return float_to_half(res); return float_to_half(res);
// __half_raw res; // __half_raw res;
...@@ -461,13 +494,13 @@ scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, float scale) { ...@@ -461,13 +494,13 @@ scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, float scale) {
// fp8x2 -> half2 // fp8x2 -> half2
template <> template <>
__inline__ __device__ uint32_t __inline__ __device__ uint32_t
scaled_vec_conversion<uint32_t, uint16_t>(const uint16_t& a, float scale) { scaled_vec_conversion<uint32_t, uint16_t>(const uint16_t& a, float scale, Fp8KVCacheDataType kv_type) {
union { union {
uint16_t u16[2]; uint16_t u16[2];
uint32_t u32; uint32_t u32;
} res; } res;
res.u16[0] = scaled_vec_conversion<uint16_t, uint8_t>((uint8_t)a, scale); res.u16[0] = scaled_vec_conversion<uint16_t, uint8_t>((uint8_t)a, scale, kv_type);
res.u16[1] = scaled_vec_conversion<uint16_t, uint8_t>((uint8_t)(a >> 8U), scale); res.u16[1] = scaled_vec_conversion<uint16_t, uint8_t>((uint8_t)(a >> 8U), scale, kv_type);
return res.u32; return res.u32;
// [[maybe_unused]] __half2_raw h2r = // [[maybe_unused]] __half2_raw h2r =
// __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret); // __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
...@@ -484,35 +517,40 @@ scaled_vec_conversion<uint32_t, uint16_t>(const uint16_t& a, float scale) { ...@@ -484,35 +517,40 @@ scaled_vec_conversion<uint32_t, uint16_t>(const uint16_t& a, float scale) {
// fp8x4 -> half2x2 // fp8x4 -> half2x2
template <> template <>
__inline__ __device__ uint2 __inline__ __device__ uint2
scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, float scale) { scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, float scale, Fp8KVCacheDataType kv_type) {
union { union {
uint2 u32x2; uint2 u32x2;
uint32_t u32[2]; uint32_t u32[2];
} tmp; } tmp;
tmp.u32[0] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale); tmp.u32[0] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale, kv_type);
tmp.u32[1] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), scale); tmp.u32[1] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), scale, kv_type);
return tmp.u32x2; return tmp.u32x2;
} }
// fp8x8 -> half2x4 // fp8x8 -> half2x4
template <> template <>
__inline__ __device__ uint4 scaled_vec_conversion<uint4, uint2>(const uint2& a, __inline__ __device__ uint4 scaled_vec_conversion<uint4, uint2>(const uint2& a,
float scale) { float scale, Fp8KVCacheDataType kv_type) {
union { union {
uint4 u64x2; uint4 u64x2;
uint2 u64[2]; uint2 u64[2];
} tmp; } tmp;
tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale); tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale, kv_type);
tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale); tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale, kv_type);
return tmp.u64x2; return tmp.u64x2;
} }
// half -> fp8 // half -> fp8
template <> template <>
__inline__ __device__ uint8_t __inline__ __device__ uint8_t
scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, float scale) { scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, float scale, Fp8KVCacheDataType kv_type) {
float res_f = half_to_float(a) / scale; float res_f = half_to_float(a) / scale;
return float_to_fp8(res_f); if (kv_type == vllm::Fp8KVCacheDataType::kFp8E4M3) {
return float_to_fp8_e4m3(res_f);
} else {
return float_to_fp8_e5m2(res_f);
}
// __half_raw tmp; // __half_raw tmp;
// tmp.x = a; // tmp.x = a;
// tmp.data /= scale; // tmp.data /= scale;
...@@ -523,7 +561,7 @@ scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, float scale) { ...@@ -523,7 +561,7 @@ scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, float scale) {
// halfx2 -> fp8x2 // halfx2 -> fp8x2
template <> template <>
__inline__ __device__ uint16_t __inline__ __device__ uint16_t
scaled_vec_conversion<uint16_t, uint32_t>(const uint32_t& a, float scale) { scaled_vec_conversion<uint16_t, uint32_t>(const uint32_t& a, float scale, Fp8KVCacheDataType kv_type) {
union { union {
uint8_t ui8[2]; uint8_t ui8[2];
uint16_t ui16; uint16_t ui16;
...@@ -533,8 +571,8 @@ scaled_vec_conversion<uint16_t, uint32_t>(const uint32_t& a, float scale) { ...@@ -533,8 +571,8 @@ scaled_vec_conversion<uint16_t, uint32_t>(const uint32_t& a, float scale) {
half2 h2r; half2 h2r;
} tmp_a; } tmp_a;
tmp_a.ui32 = a; tmp_a.ui32 = a;
tmp.ui8[0] = scaled_vec_conversion<uint8_t, uint16_t>(tmp_a.h2r.data[0], scale); tmp.ui8[0] = scaled_vec_conversion<uint8_t, uint16_t>(tmp_a.h2r.data[0], scale, kv_type);
tmp.ui8[1] = scaled_vec_conversion<uint8_t, uint16_t>(tmp_a.h2r.data[1], scale); tmp.ui8[1] = scaled_vec_conversion<uint8_t, uint16_t>(tmp_a.h2r.data[1], scale, kv_type);
return tmp.ui16; return tmp.ui16;
// union { // union {
// uint32_t ui32; // uint32_t ui32;
...@@ -550,37 +588,41 @@ scaled_vec_conversion<uint16_t, uint32_t>(const uint32_t& a, float scale) { ...@@ -550,37 +588,41 @@ scaled_vec_conversion<uint16_t, uint32_t>(const uint32_t& a, float scale) {
// half2x2 -> fp8x4 // half2x2 -> fp8x4
template <> template <>
__inline__ __device__ uint32_t __inline__ __device__ uint32_t
scaled_vec_conversion<uint32_t, uint2>(const uint2& a, float scale) { scaled_vec_conversion<uint32_t, uint2>(const uint2& a, float scale, Fp8KVCacheDataType kv_type) {
union { union {
uint16_t ui16[2]; uint16_t ui16[2];
uint32_t ui32; uint32_t ui32;
} tmp; } tmp;
tmp.ui16[0] = scaled_vec_conversion<uint16_t, uint32_t>(a.x, scale); tmp.ui16[0] = scaled_vec_conversion<uint16_t, uint32_t>(a.x, scale, kv_type);
tmp.ui16[1] = scaled_vec_conversion<uint16_t, uint32_t>(a.y, scale); tmp.ui16[1] = scaled_vec_conversion<uint16_t, uint32_t>(a.y, scale, kv_type);
return tmp.ui32; return tmp.ui32;
} }
// half2x4 -> fp8x8 // half2x4 -> fp8x8
template <> template <>
__inline__ __device__ uint2 scaled_vec_conversion<uint2, uint4>(const uint4& a, __inline__ __device__ uint2 scaled_vec_conversion<uint2, uint4>(const uint4& a,
float scale) { float scale, Fp8KVCacheDataType kv_type) {
union { union {
uint2 ui2[2]; uint2 ui2[2];
uint4 ui4; uint4 ui4;
} tmp; } tmp;
tmp.ui4 = a; tmp.ui4 = a;
uint2 res; uint2 res;
res.x = scaled_vec_conversion<uint32_t, uint2>(tmp.ui2[0], scale); res.x = scaled_vec_conversion<uint32_t, uint2>(tmp.ui2[0], scale, kv_type);
res.y = scaled_vec_conversion<uint32_t, uint2>(tmp.ui2[1], scale); res.y = scaled_vec_conversion<uint32_t, uint2>(tmp.ui2[1], scale, kv_type);
return res; return res;
} }
// bf16 -> fp8 // bf16 -> fp8
template <> template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>( __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
const __nv_bfloat16& a, float scale) { const __nv_bfloat16& a, float scale, Fp8KVCacheDataType kv_type) {
float res_f = (static_cast<float>(a)) / scale; float res_f = (static_cast<float>(a)) / scale;
return float_to_fp8(res_f); if (kv_type == vllm::Fp8KVCacheDataType::kFp8E4M3) {
return float_to_fp8_e4m3(res_f);
} else {
return float_to_fp8_e5m2(res_f);
}
// return __hip_cvt_float_to_fp8(__bfloat162float(a) / scale, // return __hip_cvt_float_to_fp8(__bfloat162float(a) / scale,
// fp8_type::__default_saturation, // fp8_type::__default_saturation,
// fp8_type::__default_interpret); // fp8_type::__default_interpret);
...@@ -589,44 +631,48 @@ __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>( ...@@ -589,44 +631,48 @@ __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
// bf16x2 -> fp8x2 // bf16x2 -> fp8x2
template <> template <>
__inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, __nv_bfloat162>( __inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, __nv_bfloat162>(
const __nv_bfloat162& a, float scale) { const __nv_bfloat162& a, float scale, Fp8KVCacheDataType kv_type) {
union { union {
uint8_t ui8[2]; uint8_t ui8[2];
uint16_t ui16; uint16_t ui16;
} tmp; } tmp;
tmp.ui8[0] = scaled_vec_conversion<uint8_t, __nv_bfloat16>(a.x, scale); tmp.ui8[0] = scaled_vec_conversion<uint8_t, __nv_bfloat16>(a.x, scale, kv_type);
tmp.ui8[1] = scaled_vec_conversion<uint8_t, __nv_bfloat16>(a.y, scale); tmp.ui8[1] = scaled_vec_conversion<uint8_t, __nv_bfloat16>(a.y, scale, kv_type);
return tmp.ui16; return tmp.ui16;
} }
// bf16x4 -> fp8x4 // bf16x4 -> fp8x4
template <> template <>
__inline__ __device__ uint32_t __inline__ __device__ uint32_t
scaled_vec_conversion<uint32_t, bf16_4_t>(const bf16_4_t& a, float scale) { scaled_vec_conversion<uint32_t, bf16_4_t>(const bf16_4_t& a, float scale, Fp8KVCacheDataType kv_type) {
union { union {
uint16_t ui16[2]; uint16_t ui16[2];
uint32_t ui32; uint32_t ui32;
} tmp; } tmp;
tmp.ui16[0] = scaled_vec_conversion<uint16_t, __nv_bfloat162>(a.x, scale); tmp.ui16[0] = scaled_vec_conversion<uint16_t, __nv_bfloat162>(a.x, scale, kv_type);
tmp.ui16[1] = scaled_vec_conversion<uint16_t, __nv_bfloat162>(a.y, scale); tmp.ui16[1] = scaled_vec_conversion<uint16_t, __nv_bfloat162>(a.y, scale, kv_type);
return tmp.ui32; return tmp.ui32;
} }
// bf16x8 -> fp8x8 // bf16x8 -> fp8x8
template <> template <>
__inline__ __device__ uint2 __inline__ __device__ uint2
scaled_vec_conversion<uint2, bf16_8_t>(const bf16_8_t& a, float scale) { scaled_vec_conversion<uint2, bf16_8_t>(const bf16_8_t& a, float scale, Fp8KVCacheDataType kv_type) {
uint2 res; uint2 res;
res.x = scaled_vec_conversion<uint32_t, bf16_4_t>({a.x, a.y}, scale); res.x = scaled_vec_conversion<uint32_t, bf16_4_t>({a.x, a.y}, scale, kv_type);
res.y = scaled_vec_conversion<uint32_t, bf16_4_t>({a.z, a.w}, scale); res.y = scaled_vec_conversion<uint32_t, bf16_4_t>({a.z, a.w}, scale, kv_type);
return res; return res;
} }
// float -> fp8 // float -> fp8
template <> template <>
__inline__ __device__ uint8_t __inline__ __device__ uint8_t
scaled_vec_conversion<uint8_t, float>(const float& a, float scale) { scaled_vec_conversion<uint8_t, float>(const float& a, float scale, Fp8KVCacheDataType kv_type) {
return float_to_fp8(a / scale); if (kv_type == vllm::Fp8KVCacheDataType::kFp8E4M3) {
return float_to_fp8_e4m3(a / scale);
} else {
return float_to_fp8_e5m2(a / scale);
}
// return __hip_cvt_float_to_fp8(a / scale, fp8_type::__default_saturation, // return __hip_cvt_float_to_fp8(a / scale, fp8_type::__default_saturation,
// fp8_type::__default_interpret); // fp8_type::__default_interpret);
} }
...@@ -634,13 +680,13 @@ scaled_vec_conversion<uint8_t, float>(const float& a, float scale) { ...@@ -634,13 +680,13 @@ scaled_vec_conversion<uint8_t, float>(const float& a, float scale) {
// floatx2 -> fp8x2 // floatx2 -> fp8x2
template <> template <>
__inline__ __device__ uint16_t __inline__ __device__ uint16_t
scaled_vec_conversion<uint16_t, float2>(const float2& a, float scale) { scaled_vec_conversion<uint16_t, float2>(const float2& a, float scale, Fp8KVCacheDataType kv_type) {
union { union {
uint8_t ui8[2]; uint8_t ui8[2];
uint16_t ui16; uint16_t ui16;
} tmp; } tmp;
tmp.ui8[0] = scaled_vec_conversion<uint8_t, float>(a.x, scale); tmp.ui8[0] = scaled_vec_conversion<uint8_t, float>(a.x, scale, kv_type);
tmp.ui8[1] = scaled_vec_conversion<uint8_t, float>(a.y, scale); tmp.ui8[1] = scaled_vec_conversion<uint8_t, float>(a.y, scale, kv_type);
return tmp.ui16; return tmp.ui16;
// return __hip_cvt_float2_to_fp8x2(a / scale, fp8_type::__default_saturation, // return __hip_cvt_float2_to_fp8x2(a / scale, fp8_type::__default_saturation,
// fp8_type::__default_interpret); // fp8_type::__default_interpret);
...@@ -649,13 +695,13 @@ scaled_vec_conversion<uint16_t, float2>(const float2& a, float scale) { ...@@ -649,13 +695,13 @@ scaled_vec_conversion<uint16_t, float2>(const float2& a, float scale) {
// floatx4 -> fp8x4 // floatx4 -> fp8x4
template <> template <>
__inline__ __device__ uint32_t __inline__ __device__ uint32_t
scaled_vec_conversion<uint32_t, float4>(const float4& a, float scale) { scaled_vec_conversion<uint32_t, float4>(const float4& a, float scale, Fp8KVCacheDataType kv_type) {
union { union {
uint16_t ui16[2]; uint16_t ui16[2];
uint32_t ui32; uint32_t ui32;
} tmp; } tmp;
tmp.ui16[0] = scaled_vec_conversion<uint16_t, float2>({a.x, a.y}, scale); tmp.ui16[0] = scaled_vec_conversion<uint16_t, float2>({a.x, a.y}, scale, kv_type);
tmp.ui16[1] = scaled_vec_conversion<uint16_t, float2>({a.z, a.w}, scale); tmp.ui16[1] = scaled_vec_conversion<uint16_t, float2>({a.z, a.w}, scale, kv_type);
return tmp.ui32; return tmp.ui32;
} }
// #endif // ENABLE_FP8 // #endif // ENABLE_FP8
...@@ -674,11 +720,11 @@ scaled_vec_conversion<uint32_t, float4>(const float4& a, float scale) { ...@@ -674,11 +720,11 @@ scaled_vec_conversion<uint32_t, float4>(const float4& a, float scale) {
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt> template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) { __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
// #ifdef ENABLE_FP8 // #ifdef ENABLE_FP8
// if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) { if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3 || kv_dt == Fp8KVCacheDataType::kFp8E5M2) {
return scaled_vec_conversion<Tout, Tin>(x, scale); return scaled_vec_conversion<Tout, Tin>(x, scale, kv_dt);
// } }
// #endif // #endif
// assert(false); assert(false);
return {}; // Squash missing return statement warning return {}; // Squash missing return statement warning
} }
...@@ -719,6 +765,18 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) { ...@@ -719,6 +765,18 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
TORCH_CHECK(false, \ TORCH_CHECK(false, \
"Unsupported input type of kv cache: ", SRC_DTYPE); \ "Unsupported input type of kv cache: ", SRC_DTYPE); \
} \ } \
} \
else if (KV_DTYPE == "fp8_e5m2") { \
if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
} else { \
TORCH_CHECK(false, \
"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else { \ } else { \
TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \ TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
} \ } \
......
...@@ -278,10 +278,8 @@ class CustomAllreduce: ...@@ -278,10 +278,8 @@ class CustomAllreduce:
if envs.VLLM_CUSTOM_CACHE: if envs.VLLM_CUSTOM_CACHE:
return self.all_reduce(input, registered=True) return self.all_reduce(input, registered=True)
else: else:
if not self.fully_connected: return self.all_reduce(input, registered=False)
return self.all_reduce(input, registered=False)
else:
return self.all_reduce(input, registered=True)
else: else:
# If warm up, mimic the allocation pattern since custom # If warm up, mimic the allocation pattern since custom
# allreduce is out-of-place. # allreduce is out-of-place.
......
...@@ -1565,7 +1565,7 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1565,7 +1565,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# flag to control vllm to use optimized kernels # flag to control vllm to use optimized kernels
"VLLM_CUSTOM_CACHE": "VLLM_CUSTOM_CACHE":
lambda: bool(int(os.environ.get("VLLM_CUSTOM_CACHE", "0"))), lambda: bool(int(os.environ.get("VLLM_CUSTOM_CACHE", "1"))),
# flag to control vllm to use optimized kernels # flag to control vllm to use optimized kernels
"VLLM_CUSTOM_ALLREDUCE_SUPPORTED_WORLDSIZE_MAX": "VLLM_CUSTOM_ALLREDUCE_SUPPORTED_WORLDSIZE_MAX":
......
...@@ -143,6 +143,8 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -143,6 +143,8 @@ class FlashAttentionBackend(AttentionBackend):
def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype: def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype:
if kv_cache_dtype in ("fp8", "fp8_e4m3"): if kv_cache_dtype in ("fp8", "fp8_e4m3"):
return torch.float8_e4m3fn return torch.float8_e4m3fn
elif kv_cache_dtype in ("fp8_e5m2"):
return torch.float8_e5m2
else: else:
raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}") raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment