Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ed3cdc81
Commit
ed3cdc81
authored
Oct 13, 2025
by
zhuwenwen
Browse files
新增fp8—e5m2
parent
cf13152f
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
199 additions
and
357 deletions
+199
-357
csrc/cache_kernels.cu
csrc/cache_kernels.cu
+16
-0
csrc/quantization/fp8/amd/quant_utils.cuh
csrc/quantization/fp8/amd/quant_utils.cuh
+115
-340
vllm/_custom_ops.py
vllm/_custom_ops.py
+6
-4
vllm/attention/backends/flashmla.py
vllm/attention/backends/flashmla.py
+4
-3
vllm/attention/layer.py
vllm/attention/layer.py
+49
-3
vllm/attention/ops/flashmla.py
vllm/attention/ops/flashmla.py
+3
-2
vllm/utils/__init__.py
vllm/utils/__init__.py
+2
-2
vllm/v1/attention/backends/mla/flashmla.py
vllm/v1/attention/backends/mla/flashmla.py
+4
-3
No files found.
csrc/cache_kernels.cu
View file @
ed3cdc81
...
@@ -965,6 +965,22 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
...
@@ -965,6 +965,22 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
CALL_CONVERT_FP8
(
__nv_bfloat16
,
uint8_t
,
CALL_CONVERT_FP8
(
__nv_bfloat16
,
uint8_t
,
vllm
::
Fp8KVCacheDataType
::
kFp8E4M3
);
vllm
::
Fp8KVCacheDataType
::
kFp8E4M3
);
}
}
}
else
if
(
kv_cache_dtype
==
"fp8_e5m2"
)
{
if
(
src_cache
.
dtype
()
==
at
::
ScalarType
::
Float
)
{
CALL_CONVERT_FP8
(
uint8_t
,
float
,
vllm
::
Fp8KVCacheDataType
::
kFp8E5M2
);
}
else
if
(
src_cache
.
dtype
()
==
at
::
ScalarType
::
Half
)
{
CALL_CONVERT_FP8
(
uint8_t
,
uint16_t
,
vllm
::
Fp8KVCacheDataType
::
kFp8E5M2
);
}
else
if
(
src_cache
.
dtype
()
==
at
::
ScalarType
::
BFloat16
)
{
CALL_CONVERT_FP8
(
uint8_t
,
__nv_bfloat16
,
vllm
::
Fp8KVCacheDataType
::
kFp8E5M2
);
}
else
if
(
dst_cache
.
dtype
()
==
at
::
ScalarType
::
Float
)
{
CALL_CONVERT_FP8
(
float
,
uint8_t
,
vllm
::
Fp8KVCacheDataType
::
kFp8E5M2
);
}
else
if
(
dst_cache
.
dtype
()
==
at
::
ScalarType
::
Half
)
{
CALL_CONVERT_FP8
(
uint16_t
,
uint8_t
,
vllm
::
Fp8KVCacheDataType
::
kFp8E5M2
);
}
else
if
(
dst_cache
.
dtype
()
==
at
::
ScalarType
::
BFloat16
)
{
CALL_CONVERT_FP8
(
__nv_bfloat16
,
uint8_t
,
vllm
::
Fp8KVCacheDataType
::
kFp8E5M2
);
}
}
else
{
}
else
{
TORCH_CHECK
(
false
,
"Unsupported data type: "
,
kv_cache_dtype
);
TORCH_CHECK
(
false
,
"Unsupported data type: "
,
kv_cache_dtype
);
}
}
...
...
csrc/quantization/fp8/amd/quant_utils.cuh
View file @
ed3cdc81
...
@@ -53,10 +53,6 @@ static inline __device__ uint8_t float_to_fp8(float f) {
...
@@ -53,10 +53,6 @@ static inline __device__ uint8_t float_to_fp8(float f) {
return
result
;
return
result
;
}
}
// template <typename Tout, typename Tin>
// __inline__ __device__ Tout vec_conversion(const Tin& x) {
// return x;
// }
template
<
typename
Tout
,
typename
Tin
>
template
<
typename
Tout
,
typename
Tin
>
__inline__
__device__
Tout
scaled_vec_conversion
(
const
Tin
&
x
,
__inline__
__device__
Tout
scaled_vec_conversion
(
const
Tin
&
x
,
...
@@ -64,281 +60,6 @@ __inline__ __device__ Tout scaled_vec_conversion(const Tin& x,
...
@@ -64,281 +60,6 @@ __inline__ __device__ Tout scaled_vec_conversion(const Tin& x,
return
x
;
return
x
;
}
}
// #if HIP_FP8_TYPE_OCP
// using fp8_type = __hip_fp8_e4m3;
// using fp8x2_type = __hip_fp8x2_e4m3;
// #else
// using fp8_type = __hip_fp8_e4m3_fnuz;
// using fp8x2_type = __hip_fp8x2_e4m3_fnuz;
// #endif
// // fp8 -> half
// template <>
// __inline__ __device__ uint16_t
// vec_conversion<uint16_t, uint8_t>(const uint8_t& a) {
// return __hip_cvt_fp8_to_halfraw(a, fp8_type::__default_interpret).x;
// }
// // fp8x2 -> half2
// template <>
// __inline__ __device__ uint32_t
// vec_conversion<uint32_t, uint16_t>(const uint16_t& a) {
// union {
// __half2_raw h2r;
// uint32_t ui32;
// } tmp;
// tmp.h2r = __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
// return tmp.ui32;
// }
// // fp8x4 -> half2x2
// template <>
// __inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a) {
// union {
// uint2 u32x2;
// uint32_t u32[2];
// } tmp;
// tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a);
// tmp.u32[1] = vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U));
// return tmp.u32x2;
// }
// // fp8x8 -> half2x4
// template <>
// __inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a) {
// union {
// uint4 u64x2;
// uint2 u64[2];
// } tmp;
// tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x);
// tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y);
// return tmp.u64x2;
// }
// using __nv_bfloat16 = __hip_bfloat16;
// // fp8 -> __nv_bfloat16
// template <>
// __inline__ __device__ __nv_bfloat16
// vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) {
// fp8_type f8;
// f8.__x = a;
// return __float2bfloat16(static_cast<float>(f8));
// }
// using __nv_bfloat162 = __hip_bfloat162;
// // fp8x2 -> __nv_bfloat162
// template <>
// __inline__ __device__ __nv_bfloat162
// vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a) {
// __nv_bfloat162 res;
// res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a);
// res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U));
// return res;
// }
// // fp8x4 -> bf16_4_t
// template <>
// __inline__ __device__ bf16_4_t
// vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a) {
// bf16_4_t res;
// res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a);
// res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U));
// return res;
// }
// // fp8x8 -> bf16_8_t
// template <>
// __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a) {
// bf16_4_t tmp1, tmp2;
// tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x);
// tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y);
// bf16_8_t res;
// res.x = tmp1.x;
// res.y = tmp1.y;
// res.z = tmp2.x;
// res.w = tmp2.y;
// return res;
// }
// // fp8 -> float
// template <>
// __inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a) {
// fp8_type f8;
// f8.__x = a;
// return static_cast<float>(f8);
// }
// // fp8x2 -> float2
// template <>
// __inline__ __device__ float2
// vec_conversion<float2, uint16_t>(const uint16_t& a) {
// fp8x2_type f8x2;
// f8x2.__x = a;
// return static_cast<float2>(f8x2);
// }
// // fp8x4 -> float4
// template <>
// __inline__ __device__ Float4_
// vec_conversion<Float4_, uint32_t>(const uint32_t& a) {
// Float4_ res;
// res.x = vec_conversion<float2, uint16_t>((uint16_t)a);
// res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U));
// return res;
// }
// // fp8x4 -> float4
// template <>
// __inline__ __device__ float4
// vec_conversion<float4, uint32_t>(const uint32_t& a) {
// Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
// float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
// return res;
// }
// // fp8x8 -> float8
// template <>
// __inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a) {
// Float4_ tmp1, tmp2;
// tmp1 = vec_conversion<Float4_, uint32_t>(a.x);
// tmp2 = vec_conversion<Float4_, uint32_t>(a.y);
// Float8_ res;
// res.x = tmp1.x;
// res.y = tmp1.y;
// res.z = tmp2.x;
// res.w = tmp2.y;
// return res;
// }
// // half -> fp8
// template <>
// __inline__ __device__ uint8_t
// vec_conversion<uint8_t, uint16_t>(const uint16_t& a) {
// __half_raw tmp;
// tmp.x = a;
// return __hip_cvt_halfraw_to_fp8(tmp, fp8_type::__default_saturation,
// fp8_type::__default_interpret);
// }
// template <>
// __inline__ __device__ uint16_t
// vec_conversion<uint16_t, uint32_t>(const uint32_t& a) {
// union {
// uint32_t ui32;
// __half2_raw h2r;
// } tmp;
// tmp.ui32 = a;
// return __hip_cvt_halfraw2_to_fp8x2(tmp.h2r, fp8_type::__default_saturation,
// fp8_type::__default_interpret);
// }
// // bf16 -> fp8
// template <>
// __inline__ __device__ uint8_t
// vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a) {
// return __hip_cvt_float_to_fp8(__bfloat162float(a),
// fp8_type::__default_saturation,
// fp8_type::__default_interpret);
// }
// // float -> fp8
// template <>
// __inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a) {
// return __hip_cvt_float_to_fp8(a, fp8_type::__default_saturation,
// fp8_type::__default_interpret);
// }
// // float2 -> half2
// template <>
// __inline__ __device__ uint32_t
// vec_conversion<uint32_t, float2>(const float2& a) {
// union {
// half2 float16;
// uint32_t uint32;
// };
// float16 = __float22half2_rn(a);
// return uint32;
// }
// // Float4 -> half2x2
// template <>
// __inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a) {
// uint2 b;
// float2 val;
// val.x = a.x.x;
// val.y = a.x.y;
// b.x = vec_conversion<uint32_t, float2>(val);
// val.x = a.y.x;
// val.y = a.y.y;
// b.y = vec_conversion<uint32_t, float2>(val);
// return b;
// }
// // Float4 -> float4
// template <>
// __inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a) {
// float4 b;
// b.x = a.x.x;
// b.y = a.x.y;
// b.z = a.y.x;
// b.w = a.y.y;
// return b;
// }
// // Float8 -> half2x4
// template <>
// __inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a) {
// uint4 b;
// b.x = vec_conversion<uint32_t, float2>(a.x);
// b.y = vec_conversion<uint32_t, float2>(a.y);
// b.z = vec_conversion<uint32_t, float2>(a.z);
// b.w = vec_conversion<uint32_t, float2>(a.w);
// return b;
// }
// // float2 -> bfloat162
// template <>
// __inline__ __device__ __nv_bfloat162
// vec_conversion<__nv_bfloat162, float2>(const float2& a) {
// __nv_bfloat162 b = __float22bfloat162_rn(a);
// return b;
// }
// // Float4 -> bfloat162x2
// template <>
// __inline__ __device__ bf16_4_t
// vec_conversion<bf16_4_t, Float4_>(const Float4_& a) {
// bf16_4_t b;
// b.x = __float22bfloat162_rn(a.x);
// b.y = __float22bfloat162_rn(a.y);
// return b;
// }
// // Float8 -> bfloat162x4
// template <>
// __inline__ __device__ bf16_8_t
// vec_conversion<bf16_8_t, Float8_>(const Float8_& a) {
// bf16_8_t b;
// b.x = __float22bfloat162_rn(a.x);
// b.y = __float22bfloat162_rn(a.y);
// b.z = __float22bfloat162_rn(a.z);
// b.w = __float22bfloat162_rn(a.w);
// return b;
// }
/* Scaled and vectorized conversions, for data exchange between high and low
precision domains
Convention of the scale in API, e.g: FP8_data = Quantization(
High_Precision_data / scale ) s.t. Quantize(HP / scale) => FP8 Dequant(FP8) *
scale => HP
*/
using
__nv_bfloat16
=
__hip_bfloat16
;
using
__nv_bfloat16
=
__hip_bfloat16
;
// fp8 -> __nv_bfloat16
// fp8 -> __nv_bfloat16
...
@@ -347,9 +68,6 @@ __inline__ __device__ __nv_bfloat16
...
@@ -347,9 +68,6 @@ __inline__ __device__ __nv_bfloat16
scaled_vec_conversion
<
__nv_bfloat16
,
uint8_t
>
(
const
uint8_t
&
a
,
float
scale
)
{
scaled_vec_conversion
<
__nv_bfloat16
,
uint8_t
>
(
const
uint8_t
&
a
,
float
scale
)
{
return
__float2bfloat16
(
fp8_to_float
(
a
)
*
scale
);
return
__float2bfloat16
(
fp8_to_float
(
a
)
*
scale
);
// fp8_type f8;
// f8.__x = a;
// return __float2bfloat16(static_cast<float>(f8) * scale);
}
}
// fp8x2 -> __nv_bfloat162
// fp8x2 -> __nv_bfloat162
...
@@ -395,9 +113,6 @@ template <>
...
@@ -395,9 +113,6 @@ template <>
__inline__
__device__
float
scaled_vec_conversion
<
float
,
uint8_t
>
(
__inline__
__device__
float
scaled_vec_conversion
<
float
,
uint8_t
>
(
const
uint8_t
&
a
,
float
scale
)
{
const
uint8_t
&
a
,
float
scale
)
{
return
fp8_to_float
(
a
)
*
scale
;
return
fp8_to_float
(
a
)
*
scale
;
// fp8_type f8;
// f8.__x = a;
// return static_cast<float>(f8) * scale;
}
}
// fp8x2 -> float2
// fp8x2 -> float2
...
@@ -408,10 +123,6 @@ scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, float scale) {
...
@@ -408,10 +123,6 @@ scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, float scale) {
f2r
.
x
=
scaled_vec_conversion
<
float
,
uint8_t
>
((
uint8_t
)
a
,
scale
);
f2r
.
x
=
scaled_vec_conversion
<
float
,
uint8_t
>
((
uint8_t
)
a
,
scale
);
f2r
.
y
=
scaled_vec_conversion
<
float
,
uint8_t
>
((
uint8_t
)(
a
>>
8U
),
scale
);
f2r
.
y
=
scaled_vec_conversion
<
float
,
uint8_t
>
((
uint8_t
)(
a
>>
8U
),
scale
);
return
f2r
;
return
f2r
;
// [[maybe_unused]]
// fp8x2_type f8x2;
// f8x2.__x = a;
// return static_cast<float2>(f8x2) * scale;
}
}
// fp8x4 -> float4
// fp8x4 -> float4
...
@@ -453,9 +164,6 @@ __inline__ __device__ uint16_t
...
@@ -453,9 +164,6 @@ __inline__ __device__ uint16_t
scaled_vec_conversion
<
uint16_t
,
uint8_t
>
(
const
uint8_t
&
a
,
float
scale
)
{
scaled_vec_conversion
<
uint16_t
,
uint8_t
>
(
const
uint8_t
&
a
,
float
scale
)
{
float
res
=
fp8_to_float
(
a
)
*
scale
;
float
res
=
fp8_to_float
(
a
)
*
scale
;
return
float_to_half
(
res
);
return
float_to_half
(
res
);
// __half_raw res;
// res.data = scaled_vec_conversion<float, uint8_t>(a, scale);
// return res.x;
}
}
// fp8x2 -> half2
// fp8x2 -> half2
...
@@ -469,16 +177,6 @@ scaled_vec_conversion<uint32_t, uint16_t>(const uint16_t& a, float scale) {
...
@@ -469,16 +177,6 @@ scaled_vec_conversion<uint32_t, uint16_t>(const uint16_t& a, float scale) {
res
.
u16
[
0
]
=
scaled_vec_conversion
<
uint16_t
,
uint8_t
>
((
uint8_t
)
a
,
scale
);
res
.
u16
[
0
]
=
scaled_vec_conversion
<
uint16_t
,
uint8_t
>
((
uint8_t
)
a
,
scale
);
res
.
u16
[
1
]
=
scaled_vec_conversion
<
uint16_t
,
uint8_t
>
((
uint8_t
)(
a
>>
8U
),
scale
);
res
.
u16
[
1
]
=
scaled_vec_conversion
<
uint16_t
,
uint8_t
>
((
uint8_t
)(
a
>>
8U
),
scale
);
return
res
.
u32
;
return
res
.
u32
;
// [[maybe_unused]] __half2_raw h2r =
// __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
// union {
// __half2_raw h2r;
// uint32_t ui32;
// } tmp;
// tmp.h2r = __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
// tmp.h2r.x.data *= scale;
// tmp.h2r.y.data *= scale;
// return tmp.ui32;
}
}
// fp8x4 -> half2x2
// fp8x4 -> half2x2
...
@@ -513,11 +211,6 @@ __inline__ __device__ uint8_t
...
@@ -513,11 +211,6 @@ __inline__ __device__ uint8_t
scaled_vec_conversion
<
uint8_t
,
uint16_t
>
(
const
uint16_t
&
a
,
float
scale
)
{
scaled_vec_conversion
<
uint8_t
,
uint16_t
>
(
const
uint16_t
&
a
,
float
scale
)
{
float
res_f
=
half_to_float
(
a
)
/
scale
;
float
res_f
=
half_to_float
(
a
)
/
scale
;
return
float_to_fp8
(
res_f
);
return
float_to_fp8
(
res_f
);
// __half_raw tmp;
// tmp.x = a;
// tmp.data /= scale;
// return __hip_cvt_halfraw_to_fp8(tmp, fp8_type::__default_saturation,
// fp8_type::__default_interpret);
}
}
// halfx2 -> fp8x2
// halfx2 -> fp8x2
...
@@ -536,15 +229,6 @@ scaled_vec_conversion<uint16_t, uint32_t>(const uint32_t& a, float scale) {
...
@@ -536,15 +229,6 @@ scaled_vec_conversion<uint16_t, uint32_t>(const uint32_t& a, float scale) {
tmp
.
ui8
[
0
]
=
scaled_vec_conversion
<
uint8_t
,
uint16_t
>
(
tmp_a
.
h2r
.
data
[
0
],
scale
);
tmp
.
ui8
[
0
]
=
scaled_vec_conversion
<
uint8_t
,
uint16_t
>
(
tmp_a
.
h2r
.
data
[
0
],
scale
);
tmp
.
ui8
[
1
]
=
scaled_vec_conversion
<
uint8_t
,
uint16_t
>
(
tmp_a
.
h2r
.
data
[
1
],
scale
);
tmp
.
ui8
[
1
]
=
scaled_vec_conversion
<
uint8_t
,
uint16_t
>
(
tmp_a
.
h2r
.
data
[
1
],
scale
);
return
tmp
.
ui16
;
return
tmp
.
ui16
;
// union {
// uint32_t ui32;
// __half2_raw h2r;
// } tmp;
// tmp.ui32 = a;
// tmp.h2r.x.data /= scale;
// tmp.h2r.y.data /= scale;
// return __hip_cvt_halfraw2_to_fp8x2(tmp.h2r, fp8_type::__default_saturation,
// fp8_type::__default_interpret);
}
}
// half2x2 -> fp8x4
// half2x2 -> fp8x4
...
@@ -581,9 +265,6 @@ __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
...
@@ -581,9 +265,6 @@ __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
const
__nv_bfloat16
&
a
,
float
scale
)
{
const
__nv_bfloat16
&
a
,
float
scale
)
{
float
res_f
=
(
static_cast
<
float
>
(
a
))
/
scale
;
float
res_f
=
(
static_cast
<
float
>
(
a
))
/
scale
;
return
float_to_fp8
(
res_f
);
return
float_to_fp8
(
res_f
);
// return __hip_cvt_float_to_fp8(__bfloat162float(a) / scale,
// fp8_type::__default_saturation,
// fp8_type::__default_interpret);
}
}
// bf16x2 -> fp8x2
// bf16x2 -> fp8x2
...
@@ -627,8 +308,6 @@ template <>
...
@@ -627,8 +308,6 @@ template <>
__inline__
__device__
uint8_t
__inline__
__device__
uint8_t
scaled_vec_conversion
<
uint8_t
,
float
>
(
const
float
&
a
,
float
scale
)
{
scaled_vec_conversion
<
uint8_t
,
float
>
(
const
float
&
a
,
float
scale
)
{
return
float_to_fp8
(
a
/
scale
);
return
float_to_fp8
(
a
/
scale
);
// return __hip_cvt_float_to_fp8(a / scale, fp8_type::__default_saturation,
// fp8_type::__default_interpret);
}
}
// floatx2 -> fp8x2
// floatx2 -> fp8x2
...
@@ -642,8 +321,6 @@ scaled_vec_conversion<uint16_t, float2>(const float2& a, float scale) {
...
@@ -642,8 +321,6 @@ scaled_vec_conversion<uint16_t, float2>(const float2& a, float scale) {
tmp
.
ui8
[
0
]
=
scaled_vec_conversion
<
uint8_t
,
float
>
(
a
.
x
,
scale
);
tmp
.
ui8
[
0
]
=
scaled_vec_conversion
<
uint8_t
,
float
>
(
a
.
x
,
scale
);
tmp
.
ui8
[
1
]
=
scaled_vec_conversion
<
uint8_t
,
float
>
(
a
.
y
,
scale
);
tmp
.
ui8
[
1
]
=
scaled_vec_conversion
<
uint8_t
,
float
>
(
a
.
y
,
scale
);
return
tmp
.
ui16
;
return
tmp
.
ui16
;
// return __hip_cvt_float2_to_fp8x2(a / scale, fp8_type::__default_saturation,
// fp8_type::__default_interpret);
}
}
// floatx4 -> fp8x4
// floatx4 -> fp8x4
...
@@ -658,27 +335,113 @@ scaled_vec_conversion<uint32_t, float4>(const float4& a, float scale) {
...
@@ -658,27 +335,113 @@ scaled_vec_conversion<uint32_t, float4>(const float4& a, float scale) {
tmp
.
ui16
[
1
]
=
scaled_vec_conversion
<
uint16_t
,
float2
>
({
a
.
z
,
a
.
w
},
scale
);
tmp
.
ui16
[
1
]
=
scaled_vec_conversion
<
uint16_t
,
float2
>
({
a
.
z
,
a
.
w
},
scale
);
return
tmp
.
ui32
;
return
tmp
.
ui32
;
}
}
// #endif // ENABLE_FP8
// template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
inline
__device__
uint8_t
float_to_fp8e5m2
(
float
f
)
{
// __inline__ __device__ Tout convert(const Tin& x) {
constexpr
uint32_t
fp32_inf
=
UINT32_C
(
255
)
<<
23
;
// #ifdef ENABLE_FP8
constexpr
uint32_t
fp8_max
=
UINT32_C
(
143
)
<<
23
;
// if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
constexpr
uint32_t
denorm_mask
=
UINT32_C
(
134
)
<<
23
;
// return vec_conversion<Tout, Tin>(x);
uint32_t
f_bits
=
c10
::
detail
::
fp32_to_bits
(
f
);
// }
uint8_t
result
=
0u
;
// #endif
const
uint32_t
sign
=
f_bits
&
UINT32_C
(
0x80000000
);
// assert(false);
f_bits
^=
sign
;
// return {}; // Squash missing return statement warning
if
(
f_bits
>=
fp8_max
)
{
// }
result
=
f_bits
>
fp32_inf
?
UINT8_C
(
0x7F
)
:
UINT8_C
(
0x7C
);
}
else
{
if
(
f_bits
<
(
UINT32_C
(
113
)
<<
23
))
{
f_bits
=
c10
::
detail
::
fp32_to_bits
(
c10
::
detail
::
fp32_from_bits
(
f_bits
)
+
c10
::
detail
::
fp32_from_bits
(
denorm_mask
));
result
=
static_cast
<
uint8_t
>
(
f_bits
-
denorm_mask
);
}
else
{
uint32_t
mant_odd
=
(
f_bits
>>
21
)
&
1
;
f_bits
+=
((
uint32_t
)(
15
-
127
)
<<
23
)
+
0xFFFFF
;
f_bits
+=
mant_odd
;
result
=
static_cast
<
uint8_t
>
(
f_bits
>>
21
);
}
}
result
|=
static_cast
<
uint8_t
>
(
sign
>>
24
);
return
result
;
}
// fp8
template
<
typename
Tin
>
__inline__
__device__
uint8_t
scaled_vec_conversion_to_e5m2
(
const
Tin
&
a
,
float
scale
)
{
return
0
;
}
// float -> fp8
template
<
>
__inline__
__device__
uint8_t
scaled_vec_conversion_to_e5m2
<
float
>
(
const
float
&
a
,
float
scale
)
{
return
float_to_fp8e5m2
(
a
/
scale
);
}
// half -> fp8
template
<
>
__inline__
__device__
uint8_t
scaled_vec_conversion_to_e5m2
<
uint16_t
>
(
const
uint16_t
&
a
,
float
scale
)
{
float
res_f
=
half_to_float
(
a
)
/
scale
;
return
float_to_fp8e5m2
(
res_f
);
}
// bf16 -> fp8
template
<
>
__inline__
__device__
uint8_t
scaled_vec_conversion_to_e5m2
<
__nv_bfloat16
>
(
const
__nv_bfloat16
&
a
,
float
scale
)
{
float
res_f
=
(
static_cast
<
float
>
(
a
))
/
scale
;
return
float_to_fp8e5m2
(
res_f
);
}
inline
__device__
float
fp8e5m2_to_fp32
(
const
uint8_t
&
input
)
{
union
uf16
{
uint16_t
as_bits
;
_Float16
as_value
;
}
;
uf16
u16
;
u16
.
as_bits
=
(
uint16_t
)
input
<<
8
;
return
(
float
)
u16
.
as_value
;
}
template
<
typename
Tout
>
__inline__
__device__
Tout
scaled_vec_conversion_from_e5m2
(
const
uint8_t
&
a
,
float
scale
)
{
return
0
;
}
// fp8 -> float
template
<
>
__inline__
__device__
float
scaled_vec_conversion_from_e5m2
<
float
>
(
const
uint8_t
&
a
,
float
scale
)
{
return
fp8e5m2_to_fp32
(
a
)
*
scale
;
}
// fp8 -> half
template
<
>
__inline__
__device__
uint16_t
scaled_vec_conversion_from_e5m2
<
uint16_t
>
(
const
uint8_t
&
a
,
float
scale
)
{
return
float_to_half
(
fp8e5m2_to_fp32
(
a
)
*
scale
);
}
// fp8 -> bf16
template
<
>
__inline__
__device__
__nv_bfloat16
scaled_vec_conversion_from_e5m2
<
__nv_bfloat16
>
(
const
uint8_t
&
a
,
float
scale
)
{
return
__float2bfloat16
(
fp8e5m2_to_fp32
(
a
)
*
scale
);
}
template
<
typename
Tout
,
typename
Tin
,
Fp8KVCacheDataType
kv_dt
>
template
<
typename
Tout
,
typename
Tin
,
Fp8KVCacheDataType
kv_dt
>
__inline__
__device__
Tout
scaled_convert
(
const
Tin
&
x
,
const
float
scale
)
{
__inline__
__device__
Tout
scaled_convert
(
const
Tin
&
x
,
const
float
scale
)
{
// #ifdef ENABLE_FP8
if
constexpr
(
kv_dt
==
Fp8KVCacheDataType
::
kFp8E4M3
)
{
// if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
return
scaled_vec_conversion
<
Tout
,
Tin
>
(
x
,
scale
);
return
scaled_vec_conversion
<
Tout
,
Tin
>
(
x
,
scale
);
// }
}
// #endif
else
if
constexpr
(
kv_dt
==
Fp8KVCacheDataType
::
kFp8E5M2
&&
sizeof
(
Tout
)
==
1
){
// assert(false);
return
scaled_vec_conversion_to_e5m2
<
Tin
>
(
x
,
scale
);
}
else
if
constexpr
(
kv_dt
==
Fp8KVCacheDataType
::
kFp8E5M2
&&
sizeof
(
Tin
)
==
1
){
return
scaled_vec_conversion_from_e5m2
<
Tout
>
(
x
,
scale
);
}
return
{};
// Squash missing return statement warning
return
{};
// Squash missing return statement warning
}
}
...
@@ -686,7 +449,7 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
...
@@ -686,7 +449,7 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
// the data type of the key and value cache. The FN is a macro that calls a
// the data type of the key and value cache. The FN is a macro that calls a
// function with template<typename scalar_t, typename cache_t,
// function with template<typename scalar_t, typename cache_t,
// Fp8KVCacheDataType kv_dt>.
// Fp8KVCacheDataType kv_dt>.
#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \
#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \
if (KV_DTYPE == "auto") { \
if (KV_DTYPE == "auto") { \
if (SRC_DTYPE == at::ScalarType::Float) { \
if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \
FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \
...
@@ -719,11 +482,23 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
...
@@ -719,11 +482,23 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
TORCH_CHECK(false, \
TORCH_CHECK(false, \
"Unsupported input type of kv cache: ", SRC_DTYPE); \
"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} \
} else if (KV_DTYPE == "fp8_e5m2") { \
if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
} else { \
TORCH_CHECK(false, \
"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else { \
} else { \
TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
} \
} \
}
}
}
// namespace fp8
}
// namespace fp8
#endif // USE_ROCM
#endif // USE_ROCM
}
// namespace vllm
}
// namespace vllm
\ No newline at end of file
vllm/_custom_ops.py
View file @
ed3cdc81
...
@@ -2166,12 +2166,14 @@ def gather_cache(src_cache: torch.Tensor,
...
@@ -2166,12 +2166,14 @@ def gather_cache(src_cache: torch.Tensor,
kv_dtype
=
"auto"
,
kv_dtype
=
"auto"
,
scale
:
float
=
1.0
,
scale
:
float
=
1.0
,
)
->
None
:
)
->
None
:
#支持"kv cache fp8"
#支持"kv cache fp8"
临时方案,带dtype的gather_cache在vllm0.10后会实现。
if
kv_dtype
==
"fp8"
:
if
kv_dtype
==
"fp8"
or
kv_dtype
==
"fp8_e5m2"
or
kv_dtype
==
"fp8_e4m3"
:
dst_fp8
=
torch
.
zeros
(
dst
.
shape
,
dtype
=
torch
.
uint8
,
device
=
dst
.
device
)
dst_fp8
=
torch
.
empty
(
dst
.
shape
,
dtype
=
torch
.
uint8
,
device
=
dst
.
device
)
convert_fp8
(
dst_fp8
,
dst
,
scale
,
kv_dtype
)
#
convert_fp8(dst_fp8, dst, scale, kv_dtype)
torch
.
ops
.
_C_cache_ops
.
gather_cache
(
src_cache
,
dst_fp8
,
block_table
,
torch
.
ops
.
_C_cache_ops
.
gather_cache
(
src_cache
,
dst_fp8
,
block_table
,
cu_seq_lens
,
batch_size
,
seq_starts
)
cu_seq_lens
,
batch_size
,
seq_starts
)
#dst_fp8->bf16
convert_fp8
(
dst
,
dst_fp8
,
scale
,
kv_dtype
)
else
:
else
:
torch
.
ops
.
_C_cache_ops
.
gather_cache
(
src_cache
,
dst
,
block_table
,
torch
.
ops
.
_C_cache_ops
.
gather_cache
(
src_cache
,
dst
,
block_table
,
cu_seq_lens
,
batch_size
,
seq_starts
)
cu_seq_lens
,
batch_size
,
seq_starts
)
...
...
vllm/attention/backends/flashmla.py
View file @
ed3cdc81
...
@@ -211,9 +211,10 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
...
@@ -211,9 +211,10 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
"FlashMLAImpl"
)
"FlashMLAImpl"
)
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
):
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
):
if
self
.
kv_cache_dtype
!=
"fp8"
:
if
kv_cache_dtype
==
"fp8"
or
kv_cache_dtype
==
"fp8_e4m3"
or
kv_cache_dtype
==
"fp8_e5m2"
:
raise
NotImplementedError
(
return
"FlashMLA with other KV cache not yet supported"
)
raise
NotImplementedError
(
"FlashMLA with other KV cache not yet supported"
)
def
_forward_decode
(
def
_forward_decode
(
self
,
self
,
...
...
vllm/attention/layer.py
View file @
ed3cdc81
...
@@ -24,6 +24,12 @@ from vllm.platforms import _Backend, current_platform
...
@@ -24,6 +24,12 @@ from vllm.platforms import _Backend, current_platform
from
vllm.utils
import
direct_register_custom_op
from
vllm.utils
import
direct_register_custom_op
from
vllm.v1.attention.backends.utils
import
validate_kv_sharing_target
from
vllm.v1.attention.backends.utils
import
validate_kv_sharing_target
USE_XFORMERS_OPS
=
None
try
:
tag_cudagraph_unsafe
=
(
torch
.
_C
.
Tag
.
cudagraph_unsafe
,
)
except
AttributeError
:
tag_cudagraph_unsafe
=
()
# type: ignore[assignment]
class
Attention
(
nn
.
Module
):
class
Attention
(
nn
.
Module
):
"""Attention layer.
"""Attention layer.
...
@@ -204,10 +210,12 @@ class Attention(nn.Module):
...
@@ -204,10 +210,12 @@ class Attention(nn.Module):
`vllm.forward_context.get_forward_context().attn_metadata`.
`vllm.forward_context.get_forward_context().attn_metadata`.
"""
"""
if
self
.
calculate_kv_scales
:
if
self
.
calculate_kv_scales
:
attn_metadata
=
get_forward_context
().
attn_metadata
#
attn_metadata = get_forward_context().attn_metadata
if
(
attn_metadata
is
not
None
and
getattr
(
attn_metadata
,
"enable_kv_scales_calculation"
,
False
)):
# #
if (attn_metadata is not None and getattr(attn_metadata, "enable_kv_scales_calculation", False)):
# if key is not None and value is not None:
# if key is not None and value is not None:
self
.
calc_kv_scales
(
query
,
key
,
value
)
# self.calc_kv_scales(query, key, value)
torch
.
ops
.
vllm
.
maybe_calc_kv_scales
(
query
,
key
,
value
,
self
.
layer_name
)
if
self
.
use_output
:
if
self
.
use_output
:
output_shape
=
(
output_shape
output_shape
=
(
output_shape
if
output_shape
is
not
None
else
query
.
shape
)
if
output_shape
is
not
None
else
query
.
shape
)
...
@@ -397,6 +405,44 @@ def maybe_save_kv_layer_to_connector(
...
@@ -397,6 +405,44 @@ def maybe_save_kv_layer_to_connector(
attn_metadata
[
layer_name
])
attn_metadata
[
layer_name
])
def
maybe_calc_kv_scales
(
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
layer_name
:
str
,
)
->
None
:
forward_context
:
ForwardContext
=
get_forward_context
()
attn_metadata
=
forward_context
.
attn_metadata
if
isinstance
(
attn_metadata
,
dict
):
attn_metadata
=
attn_metadata
[
layer_name
]
# if attn_metadata is None or not getattr(
# attn_metadata, 'enable_kv_scales_calculation', False):
# return
self
=
forward_context
.
no_compile_layers
[
layer_name
]
self
.
calc_kv_scales
(
query
,
key
,
value
)
def
maybe_calc_kv_scales_fake
(
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
layer_name
:
str
,
)
->
None
:
return
direct_register_custom_op
(
op_name
=
"maybe_calc_kv_scales"
,
op_func
=
maybe_calc_kv_scales
,
mutates_args
=
[],
fake_impl
=
maybe_calc_kv_scales_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
tags
=
tag_cudagraph_unsafe
,)
def
unified_attention
(
def
unified_attention
(
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
...
...
vllm/attention/ops/flashmla.py
View file @
ed3cdc81
...
@@ -99,7 +99,8 @@ def flash_mla_with_kvcache(
...
@@ -99,7 +99,8 @@ def flash_mla_with_kvcache(
if
softmax_scale
is
None
:
if
softmax_scale
is
None
:
softmax_scale
=
q
.
shape
[
-
1
]
**
(
-
0.5
)
softmax_scale
=
q
.
shape
[
-
1
]
**
(
-
0.5
)
if
current_platform
.
is_rocm
():
if
current_platform
.
is_rocm
():
if
kv_cache_dtype
==
"fp8"
:
if
kv_cache_dtype
==
"fp8"
or
kv_cache_dtype
==
"fp8_e4m3"
or
kv_cache_dtype
==
"fp8_e5m2"
:
kv_dtype
=
"fp8_e4m3"
if
kv_cache_dtype
==
"fp8"
else
kv_cache_dtype
out
,
softmax_lse
=
flash_mla_cuda
.
fwd_kvcache_quantization_mla
(
out
,
softmax_lse
=
flash_mla_cuda
.
fwd_kvcache_quantization_mla
(
q
,
q
,
k_cache
,
k_cache
,
...
@@ -112,7 +113,7 @@ def flash_mla_with_kvcache(
...
@@ -112,7 +113,7 @@ def flash_mla_with_kvcache(
tile_scheduler_metadata
,
tile_scheduler_metadata
,
num_splits
,
num_splits
,
k_scale
,
k_scale
,
"fp8_e4m3"
,
kv_dtype
,
)
)
return
out
,
softmax_lse
return
out
,
softmax_lse
out
,
softmax_lse
=
flash_mla_cuda
.
fwd_kvcache_mla
(
out
,
softmax_lse
=
flash_mla_cuda
.
fwd_kvcache_mla
(
...
...
vllm/utils/__init__.py
View file @
ed3cdc81
...
@@ -183,8 +183,8 @@ STR_DTYPE_TO_TORCH_DTYPE = {
...
@@ -183,8 +183,8 @@ STR_DTYPE_TO_TORCH_DTYPE = {
"bfloat16"
:
torch
.
bfloat16
,
"bfloat16"
:
torch
.
bfloat16
,
"float"
:
torch
.
float
,
"float"
:
torch
.
float
,
"fp8"
:
torch
.
uint8
,
"fp8"
:
torch
.
uint8
,
#
"fp8_e4m3": torch.uint8,
"fp8_e4m3"
:
torch
.
uint8
,
#
"fp8_e5m2": torch.uint8,
"fp8_e5m2"
:
torch
.
uint8
,
"int8"
:
torch
.
int8
,
"int8"
:
torch
.
int8
,
}
}
...
...
vllm/v1/attention/backends/mla/flashmla.py
View file @
ed3cdc81
...
@@ -150,9 +150,10 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
...
@@ -150,9 +150,10 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
"FlashMLAImpl"
)
"FlashMLAImpl"
)
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
):
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
):
if
self
.
kv_cache_dtype
!=
"fp8"
:
if
kv_cache_dtype
==
"fp8"
or
kv_cache_dtype
==
"fp8_e4m3"
or
kv_cache_dtype
==
"fp8_e5m2"
:
raise
NotImplementedError
(
return
"FlashMLA with other KV cache not yet supported"
)
raise
NotImplementedError
(
"FlashMLA with other KV cache not yet supported"
)
def
_forward_decode
(
def
_forward_decode
(
self
,
self
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment