Commit ed3cdc81 authored by zhuwenwen's avatar zhuwenwen
Browse files

新增fp8—e5m2

parent cf13152f
......@@ -965,6 +965,22 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
CALL_CONVERT_FP8(__nv_bfloat16, uint8_t,
vllm::Fp8KVCacheDataType::kFp8E4M3);
}
} else if (kv_cache_dtype == "fp8_e5m2") {
if (src_cache.dtype() == at::ScalarType::Float) {
CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kFp8E5M2);
} else if (src_cache.dtype() == at::ScalarType::Half) {
CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kFp8E5M2);
} else if (src_cache.dtype() == at::ScalarType::BFloat16) {
CALL_CONVERT_FP8(uint8_t, __nv_bfloat16,
vllm::Fp8KVCacheDataType::kFp8E5M2);
} else if (dst_cache.dtype() == at::ScalarType::Float) {
CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2);
} else if (dst_cache.dtype() == at::ScalarType::Half) {
CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2);
} else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
CALL_CONVERT_FP8(__nv_bfloat16, uint8_t,
vllm::Fp8KVCacheDataType::kFp8E5M2);
}
} else {
TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype);
}
......
......@@ -53,10 +53,6 @@ static inline __device__ uint8_t float_to_fp8(float f) {
return result;
}
// template <typename Tout, typename Tin>
// __inline__ __device__ Tout vec_conversion(const Tin& x) {
// return x;
// }
template <typename Tout, typename Tin>
__inline__ __device__ Tout scaled_vec_conversion(const Tin& x,
......@@ -64,281 +60,6 @@ __inline__ __device__ Tout scaled_vec_conversion(const Tin& x,
return x;
}
// #if HIP_FP8_TYPE_OCP
// using fp8_type = __hip_fp8_e4m3;
// using fp8x2_type = __hip_fp8x2_e4m3;
// #else
// using fp8_type = __hip_fp8_e4m3_fnuz;
// using fp8x2_type = __hip_fp8x2_e4m3_fnuz;
// #endif
// // fp8 -> half
// template <>
// __inline__ __device__ uint16_t
// vec_conversion<uint16_t, uint8_t>(const uint8_t& a) {
// return __hip_cvt_fp8_to_halfraw(a, fp8_type::__default_interpret).x;
// }
// // fp8x2 -> half2
// template <>
// __inline__ __device__ uint32_t
// vec_conversion<uint32_t, uint16_t>(const uint16_t& a) {
// union {
// __half2_raw h2r;
// uint32_t ui32;
// } tmp;
// tmp.h2r = __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
// return tmp.ui32;
// }
// // fp8x4 -> half2x2
// template <>
// __inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a) {
// union {
// uint2 u32x2;
// uint32_t u32[2];
// } tmp;
// tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a);
// tmp.u32[1] = vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U));
// return tmp.u32x2;
// }
// // fp8x8 -> half2x4
// template <>
// __inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a) {
// union {
// uint4 u64x2;
// uint2 u64[2];
// } tmp;
// tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x);
// tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y);
// return tmp.u64x2;
// }
// using __nv_bfloat16 = __hip_bfloat16;
// // fp8 -> __nv_bfloat16
// template <>
// __inline__ __device__ __nv_bfloat16
// vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) {
// fp8_type f8;
// f8.__x = a;
// return __float2bfloat16(static_cast<float>(f8));
// }
// using __nv_bfloat162 = __hip_bfloat162;
// // fp8x2 -> __nv_bfloat162
// template <>
// __inline__ __device__ __nv_bfloat162
// vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a) {
// __nv_bfloat162 res;
// res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a);
// res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U));
// return res;
// }
// // fp8x4 -> bf16_4_t
// template <>
// __inline__ __device__ bf16_4_t
// vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a) {
// bf16_4_t res;
// res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a);
// res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U));
// return res;
// }
// // fp8x8 -> bf16_8_t
// template <>
// __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a) {
// bf16_4_t tmp1, tmp2;
// tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x);
// tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y);
// bf16_8_t res;
// res.x = tmp1.x;
// res.y = tmp1.y;
// res.z = tmp2.x;
// res.w = tmp2.y;
// return res;
// }
// // fp8 -> float
// template <>
// __inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a) {
// fp8_type f8;
// f8.__x = a;
// return static_cast<float>(f8);
// }
// // fp8x2 -> float2
// template <>
// __inline__ __device__ float2
// vec_conversion<float2, uint16_t>(const uint16_t& a) {
// fp8x2_type f8x2;
// f8x2.__x = a;
// return static_cast<float2>(f8x2);
// }
// // fp8x4 -> float4
// template <>
// __inline__ __device__ Float4_
// vec_conversion<Float4_, uint32_t>(const uint32_t& a) {
// Float4_ res;
// res.x = vec_conversion<float2, uint16_t>((uint16_t)a);
// res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U));
// return res;
// }
// // fp8x4 -> float4
// template <>
// __inline__ __device__ float4
// vec_conversion<float4, uint32_t>(const uint32_t& a) {
// Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
// float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
// return res;
// }
// // fp8x8 -> float8
// template <>
// __inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a) {
// Float4_ tmp1, tmp2;
// tmp1 = vec_conversion<Float4_, uint32_t>(a.x);
// tmp2 = vec_conversion<Float4_, uint32_t>(a.y);
// Float8_ res;
// res.x = tmp1.x;
// res.y = tmp1.y;
// res.z = tmp2.x;
// res.w = tmp2.y;
// return res;
// }
// // half -> fp8
// template <>
// __inline__ __device__ uint8_t
// vec_conversion<uint8_t, uint16_t>(const uint16_t& a) {
// __half_raw tmp;
// tmp.x = a;
// return __hip_cvt_halfraw_to_fp8(tmp, fp8_type::__default_saturation,
// fp8_type::__default_interpret);
// }
// template <>
// __inline__ __device__ uint16_t
// vec_conversion<uint16_t, uint32_t>(const uint32_t& a) {
// union {
// uint32_t ui32;
// __half2_raw h2r;
// } tmp;
// tmp.ui32 = a;
// return __hip_cvt_halfraw2_to_fp8x2(tmp.h2r, fp8_type::__default_saturation,
// fp8_type::__default_interpret);
// }
// // bf16 -> fp8
// template <>
// __inline__ __device__ uint8_t
// vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a) {
// return __hip_cvt_float_to_fp8(__bfloat162float(a),
// fp8_type::__default_saturation,
// fp8_type::__default_interpret);
// }
// // float -> fp8
// template <>
// __inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a) {
// return __hip_cvt_float_to_fp8(a, fp8_type::__default_saturation,
// fp8_type::__default_interpret);
// }
// // float2 -> half2
// template <>
// __inline__ __device__ uint32_t
// vec_conversion<uint32_t, float2>(const float2& a) {
// union {
// half2 float16;
// uint32_t uint32;
// };
// float16 = __float22half2_rn(a);
// return uint32;
// }
// // Float4 -> half2x2
// template <>
// __inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a) {
// uint2 b;
// float2 val;
// val.x = a.x.x;
// val.y = a.x.y;
// b.x = vec_conversion<uint32_t, float2>(val);
// val.x = a.y.x;
// val.y = a.y.y;
// b.y = vec_conversion<uint32_t, float2>(val);
// return b;
// }
// // Float4 -> float4
// template <>
// __inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a) {
// float4 b;
// b.x = a.x.x;
// b.y = a.x.y;
// b.z = a.y.x;
// b.w = a.y.y;
// return b;
// }
// // Float8 -> half2x4
// template <>
// __inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a) {
// uint4 b;
// b.x = vec_conversion<uint32_t, float2>(a.x);
// b.y = vec_conversion<uint32_t, float2>(a.y);
// b.z = vec_conversion<uint32_t, float2>(a.z);
// b.w = vec_conversion<uint32_t, float2>(a.w);
// return b;
// }
// // float2 -> bfloat162
// template <>
// __inline__ __device__ __nv_bfloat162
// vec_conversion<__nv_bfloat162, float2>(const float2& a) {
// __nv_bfloat162 b = __float22bfloat162_rn(a);
// return b;
// }
// // Float4 -> bfloat162x2
// template <>
// __inline__ __device__ bf16_4_t
// vec_conversion<bf16_4_t, Float4_>(const Float4_& a) {
// bf16_4_t b;
// b.x = __float22bfloat162_rn(a.x);
// b.y = __float22bfloat162_rn(a.y);
// return b;
// }
// // Float8 -> bfloat162x4
// template <>
// __inline__ __device__ bf16_8_t
// vec_conversion<bf16_8_t, Float8_>(const Float8_& a) {
// bf16_8_t b;
// b.x = __float22bfloat162_rn(a.x);
// b.y = __float22bfloat162_rn(a.y);
// b.z = __float22bfloat162_rn(a.z);
// b.w = __float22bfloat162_rn(a.w);
// return b;
// }
/* Scaled and vectorized conversions, for data exchange between high and low
precision domains
Convention of the scale in API, e.g: FP8_data = Quantization(
High_Precision_data / scale ) s.t. Quantize(HP / scale) => FP8 Dequant(FP8) *
scale => HP
*/
using __nv_bfloat16 = __hip_bfloat16;
// fp8 -> __nv_bfloat16
......@@ -347,9 +68,6 @@ __inline__ __device__ __nv_bfloat16
scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, float scale) {
return __float2bfloat16(fp8_to_float(a) * scale);
// fp8_type f8;
// f8.__x = a;
// return __float2bfloat16(static_cast<float>(f8) * scale);
}
// fp8x2 -> __nv_bfloat162
......@@ -395,9 +113,6 @@ template <>
__inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
const uint8_t& a, float scale) {
return fp8_to_float(a) * scale;
// fp8_type f8;
// f8.__x = a;
// return static_cast<float>(f8) * scale;
}
// fp8x2 -> float2
......@@ -408,10 +123,6 @@ scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, float scale) {
f2r.x = scaled_vec_conversion<float, uint8_t>((uint8_t)a, scale);
f2r.y = scaled_vec_conversion<float, uint8_t>((uint8_t)(a >> 8U), scale);
return f2r;
// [[maybe_unused]]
// fp8x2_type f8x2;
// f8x2.__x = a;
// return static_cast<float2>(f8x2) * scale;
}
// fp8x4 -> float4
......@@ -453,9 +164,6 @@ __inline__ __device__ uint16_t
scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, float scale) {
float res = fp8_to_float(a) * scale;
return float_to_half(res);
// __half_raw res;
// res.data = scaled_vec_conversion<float, uint8_t>(a, scale);
// return res.x;
}
// fp8x2 -> half2
......@@ -469,16 +177,6 @@ scaled_vec_conversion<uint32_t, uint16_t>(const uint16_t& a, float scale) {
res.u16[0] = scaled_vec_conversion<uint16_t, uint8_t>((uint8_t)a, scale);
res.u16[1] = scaled_vec_conversion<uint16_t, uint8_t>((uint8_t)(a >> 8U), scale);
return res.u32;
// [[maybe_unused]] __half2_raw h2r =
// __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
// union {
// __half2_raw h2r;
// uint32_t ui32;
// } tmp;
// tmp.h2r = __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
// tmp.h2r.x.data *= scale;
// tmp.h2r.y.data *= scale;
// return tmp.ui32;
}
// fp8x4 -> half2x2
......@@ -513,11 +211,6 @@ __inline__ __device__ uint8_t
scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, float scale) {
float res_f = half_to_float(a) / scale;
return float_to_fp8(res_f);
// __half_raw tmp;
// tmp.x = a;
// tmp.data /= scale;
// return __hip_cvt_halfraw_to_fp8(tmp, fp8_type::__default_saturation,
// fp8_type::__default_interpret);
}
// halfx2 -> fp8x2
......@@ -536,15 +229,6 @@ scaled_vec_conversion<uint16_t, uint32_t>(const uint32_t& a, float scale) {
tmp.ui8[0] = scaled_vec_conversion<uint8_t, uint16_t>(tmp_a.h2r.data[0], scale);
tmp.ui8[1] = scaled_vec_conversion<uint8_t, uint16_t>(tmp_a.h2r.data[1], scale);
return tmp.ui16;
// union {
// uint32_t ui32;
// __half2_raw h2r;
// } tmp;
// tmp.ui32 = a;
// tmp.h2r.x.data /= scale;
// tmp.h2r.y.data /= scale;
// return __hip_cvt_halfraw2_to_fp8x2(tmp.h2r, fp8_type::__default_saturation,
// fp8_type::__default_interpret);
}
// half2x2 -> fp8x4
......@@ -581,9 +265,6 @@ __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
const __nv_bfloat16& a, float scale) {
float res_f = (static_cast<float>(a)) / scale;
return float_to_fp8(res_f);
// return __hip_cvt_float_to_fp8(__bfloat162float(a) / scale,
// fp8_type::__default_saturation,
// fp8_type::__default_interpret);
}
// bf16x2 -> fp8x2
......@@ -627,8 +308,6 @@ template <>
__inline__ __device__ uint8_t
scaled_vec_conversion<uint8_t, float>(const float& a, float scale) {
return float_to_fp8(a / scale);
// return __hip_cvt_float_to_fp8(a / scale, fp8_type::__default_saturation,
// fp8_type::__default_interpret);
}
// floatx2 -> fp8x2
......@@ -642,8 +321,6 @@ scaled_vec_conversion<uint16_t, float2>(const float2& a, float scale) {
tmp.ui8[0] = scaled_vec_conversion<uint8_t, float>(a.x, scale);
tmp.ui8[1] = scaled_vec_conversion<uint8_t, float>(a.y, scale);
return tmp.ui16;
// return __hip_cvt_float2_to_fp8x2(a / scale, fp8_type::__default_saturation,
// fp8_type::__default_interpret);
}
// floatx4 -> fp8x4
......@@ -658,27 +335,113 @@ scaled_vec_conversion<uint32_t, float4>(const float4& a, float scale) {
tmp.ui16[1] = scaled_vec_conversion<uint16_t, float2>({a.z, a.w}, scale);
return tmp.ui32;
}
// #endif // ENABLE_FP8
// template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
// __inline__ __device__ Tout convert(const Tin& x) {
// #ifdef ENABLE_FP8
// if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
// return vec_conversion<Tout, Tin>(x);
// }
// #endif
// assert(false);
// return {}; // Squash missing return statement warning
// }
inline __device__ uint8_t float_to_fp8e5m2(float f) {
constexpr uint32_t fp32_inf = UINT32_C(255) << 23;
constexpr uint32_t fp8_max = UINT32_C(143) << 23;
constexpr uint32_t denorm_mask = UINT32_C(134) << 23;
uint32_t f_bits = c10::detail::fp32_to_bits(f);
uint8_t result = 0u;
const uint32_t sign = f_bits & UINT32_C(0x80000000);
f_bits ^= sign;
if (f_bits >= fp8_max) {
result = f_bits > fp32_inf ? UINT8_C(0x7F) : UINT8_C(0x7C);
} else {
if (f_bits < (UINT32_C(113) << 23)) {
f_bits = c10::detail::fp32_to_bits(c10::detail::fp32_from_bits(f_bits)
+ c10::detail::fp32_from_bits(denorm_mask));
result = static_cast<uint8_t>(f_bits - denorm_mask);
} else {
uint32_t mant_odd = (f_bits >> 21) & 1;
f_bits += ((uint32_t)(15 - 127) << 23) + 0xFFFFF;
f_bits += mant_odd;
result = static_cast<uint8_t>(f_bits >> 21);
}
}
result |= static_cast<uint8_t>(sign >> 24);
return result;
}
// fp8
template <typename Tin>
__inline__ __device__ uint8_t
scaled_vec_conversion_to_e5m2(const Tin& a, float scale) {
return 0;
}
// float -> fp8
template <>
__inline__ __device__ uint8_t
scaled_vec_conversion_to_e5m2<float>(const float& a, float scale) {
return float_to_fp8e5m2(a / scale);
}
// half -> fp8
template <>
__inline__ __device__ uint8_t
scaled_vec_conversion_to_e5m2<uint16_t>(const uint16_t& a, float scale) {
float res_f = half_to_float(a) / scale;
return float_to_fp8e5m2(res_f);
}
// bf16 -> fp8
template <>
__inline__ __device__ uint8_t
scaled_vec_conversion_to_e5m2<__nv_bfloat16>(const __nv_bfloat16& a, float scale) {
float res_f = (static_cast<float>(a)) / scale;
return float_to_fp8e5m2(res_f);
}
inline __device__ float fp8e5m2_to_fp32(const uint8_t& input) {
union uf16{
uint16_t as_bits;
_Float16 as_value;
} ;
uf16 u16;
u16.as_bits = (uint16_t)input << 8;
return (float)u16.as_value;
}
template <typename Tout>
__inline__ __device__ Tout
scaled_vec_conversion_from_e5m2(const uint8_t& a, float scale) {
return 0;
}
// fp8 -> float
template <>
__inline__ __device__ float
scaled_vec_conversion_from_e5m2<float>(const uint8_t& a, float scale) {
return fp8e5m2_to_fp32(a)*scale;
}
// fp8 -> half
template <>
__inline__ __device__ uint16_t
scaled_vec_conversion_from_e5m2<uint16_t>(const uint8_t& a, float scale) {
return float_to_half(fp8e5m2_to_fp32(a)*scale);
}
// fp8 -> bf16
template <>
__inline__ __device__ __nv_bfloat16
scaled_vec_conversion_from_e5m2<__nv_bfloat16>(const uint8_t& a, float scale) {
return __float2bfloat16(fp8e5m2_to_fp32(a)*scale);
}
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
// #ifdef ENABLE_FP8
// if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
return scaled_vec_conversion<Tout, Tin>(x, scale);
// }
// #endif
// assert(false);
}
else if constexpr(kv_dt == Fp8KVCacheDataType::kFp8E5M2 && sizeof(Tout)==1){
return scaled_vec_conversion_to_e5m2<Tin>(x, scale);
}
else if constexpr(kv_dt == Fp8KVCacheDataType::kFp8E5M2 && sizeof(Tin)==1){
return scaled_vec_conversion_from_e5m2<Tout>(x, scale);
}
return {}; // Squash missing return statement warning
}
......@@ -686,7 +449,7 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
// the data type of the key and value cache. The FN is a macro that calls a
// function with template<typename scalar_t, typename cache_t,
// Fp8KVCacheDataType kv_dt>.
#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \
#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \
if (KV_DTYPE == "auto") { \
if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \
......@@ -719,11 +482,23 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
TORCH_CHECK(false, \
"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else if (KV_DTYPE == "fp8_e5m2") { \
if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
} else { \
TORCH_CHECK(false, \
"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else { \
TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
} \
}
} // namespace fp8
#endif // USE_ROCM
} // namespace vllm
\ No newline at end of file
......@@ -2166,12 +2166,14 @@ def gather_cache(src_cache: torch.Tensor,
kv_dtype = "auto",
scale: float = 1.0,
) -> None:
#支持"kv cache fp8"
if kv_dtype == "fp8":
dst_fp8 = torch.zeros(dst.shape, dtype=torch.uint8, device=dst.device)
convert_fp8(dst_fp8, dst, scale, kv_dtype)
#支持"kv cache fp8" 临时方案,带dtype的gather_cache在vllm0.10后会实现。
if kv_dtype == "fp8" or kv_dtype == "fp8_e5m2" or kv_dtype == "fp8_e4m3":
dst_fp8 = torch.empty(dst.shape, dtype=torch.uint8, device=dst.device)
#convert_fp8(dst_fp8, dst, scale, kv_dtype)
torch.ops._C_cache_ops.gather_cache(src_cache, dst_fp8, block_table,
cu_seq_lens, batch_size, seq_starts)
#dst_fp8->bf16
convert_fp8(dst, dst_fp8, scale, kv_dtype)
else:
torch.ops._C_cache_ops.gather_cache(src_cache, dst, block_table,
cu_seq_lens, batch_size, seq_starts)
......
......@@ -211,9 +211,10 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
"FlashMLAImpl")
if is_quantized_kv_cache(self.kv_cache_dtype):
if self.kv_cache_dtype != "fp8":
raise NotImplementedError(
"FlashMLA with other KV cache not yet supported")
if kv_cache_dtype == "fp8" or kv_cache_dtype == "fp8_e4m3" or kv_cache_dtype == "fp8_e5m2":
return
raise NotImplementedError(
"FlashMLA with other KV cache not yet supported")
def _forward_decode(
self,
......
......@@ -24,6 +24,12 @@ from vllm.platforms import _Backend, current_platform
from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.utils import validate_kv_sharing_target
USE_XFORMERS_OPS = None
try:
tag_cudagraph_unsafe = (torch._C.Tag.cudagraph_unsafe, )
except AttributeError:
tag_cudagraph_unsafe = () # type: ignore[assignment]
class Attention(nn.Module):
"""Attention layer.
......@@ -204,10 +210,12 @@ class Attention(nn.Module):
`vllm.forward_context.get_forward_context().attn_metadata`.
"""
if self.calculate_kv_scales:
attn_metadata = get_forward_context().attn_metadata
if (attn_metadata is not None and getattr(attn_metadata, "enable_kv_scales_calculation", False)):
# attn_metadata = get_forward_context().attn_metadata
# #if (attn_metadata is not None and getattr(attn_metadata, "enable_kv_scales_calculation", False)):
# if key is not None and value is not None:
self.calc_kv_scales(query, key, value)
# self.calc_kv_scales(query, key, value)
torch.ops.vllm.maybe_calc_kv_scales(query, key, value,
self.layer_name)
if self.use_output:
output_shape = (output_shape
if output_shape is not None else query.shape)
......@@ -397,6 +405,44 @@ def maybe_save_kv_layer_to_connector(
attn_metadata[layer_name])
def maybe_calc_kv_scales(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
layer_name: str,
) -> None:
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[layer_name]
# if attn_metadata is None or not getattr(
# attn_metadata, 'enable_kv_scales_calculation', False):
# return
self = forward_context.no_compile_layers[layer_name]
self.calc_kv_scales(query, key, value)
def maybe_calc_kv_scales_fake( query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
layer_name: str,
) -> None:
return
direct_register_custom_op(
op_name="maybe_calc_kv_scales",
op_func=maybe_calc_kv_scales,
mutates_args=[],
fake_impl=maybe_calc_kv_scales_fake,
dispatch_key=current_platform.dispatch_key,
tags=tag_cudagraph_unsafe,)
def unified_attention(
query: torch.Tensor,
key: torch.Tensor,
......
......@@ -99,7 +99,8 @@ def flash_mla_with_kvcache(
if softmax_scale is None:
softmax_scale = q.shape[-1]**(-0.5)
if current_platform.is_rocm():
if kv_cache_dtype == "fp8":
if kv_cache_dtype == "fp8" or kv_cache_dtype == "fp8_e4m3" or kv_cache_dtype == "fp8_e5m2":
kv_dtype = "fp8_e4m3" if kv_cache_dtype == "fp8" else kv_cache_dtype
out, softmax_lse = flash_mla_cuda.fwd_kvcache_quantization_mla(
q,
k_cache,
......@@ -112,7 +113,7 @@ def flash_mla_with_kvcache(
tile_scheduler_metadata,
num_splits,
k_scale,
"fp8_e4m3",
kv_dtype,
)
return out, softmax_lse
out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(
......
......@@ -183,8 +183,8 @@ STR_DTYPE_TO_TORCH_DTYPE = {
"bfloat16": torch.bfloat16,
"float": torch.float,
"fp8": torch.uint8,
# "fp8_e4m3": torch.uint8,
# "fp8_e5m2": torch.uint8,
"fp8_e4m3": torch.uint8,
"fp8_e5m2": torch.uint8,
"int8": torch.int8,
}
......
......@@ -150,9 +150,10 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
"FlashMLAImpl")
if is_quantized_kv_cache(self.kv_cache_dtype):
if self.kv_cache_dtype != "fp8":
raise NotImplementedError(
"FlashMLA with other KV cache not yet supported")
if kv_cache_dtype == "fp8" or kv_cache_dtype == "fp8_e4m3" or kv_cache_dtype == "fp8_e5m2":
return
raise NotImplementedError(
"FlashMLA with other KV cache not yet supported")
def _forward_decode(
self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment