Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cea85c38
Commit
cea85c38
authored
Jan 16, 2026
by
zhuwenwen
Browse files
Merge branch 'v0.11.0-dev' of
http://10.16.6.30/dcutoolkit/deeplearing/vllm
into v0.11.0-dev
parents
6d8c8719
bc80af59
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
148 additions
and
77 deletions
+148
-77
csrc/custom_all_reduce.cuh
csrc/custom_all_reduce.cuh
+16
-3
csrc/quantization/fp8/amd/quant_utils.cuh
csrc/quantization/fp8/amd/quant_utils.cuh
+127
-69
vllm/distributed/device_communicators/custom_all_reduce.py
vllm/distributed/device_communicators/custom_all_reduce.py
+2
-4
vllm/envs.py
vllm/envs.py
+1
-1
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+2
-0
No files found.
csrc/custom_all_reduce.cuh
View file @
cea85c38
...
...
@@ -644,9 +644,22 @@ class CustomAllreduce {
size
/=
d
;
auto
bytes
=
size
*
sizeof
(
typename
packed_t
<
T
>::
P
);
int
blocks
=
std
::
min
(
block_limit
,
(
size
+
threads
-
1
)
/
threads
);
// #define KL(ngpus, name) \
// name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
// rank_, size, dev_curr_hdp_reg, world_size_) ;
#define KL(ngpus, name) \
name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
rank_, size, dev_curr_hdp_reg, world_size_) ;
{ \
void* kernelArgs[] = { \
&ptrs, &sg_, &self_sg_, &output, &rank_, &size \
}; \
hipExtLaunchKernel( \
(void*)name<T, ngpus>, \
blocks, threads, \
kernelArgs, 0, \
stream, nullptr, stopEvent, 0 \
); \
}
#define REDUCE_CASE(ngpus) \
case ngpus: { \
...
...
csrc/quantization/fp8/amd/quant_utils.cuh
View file @
cea85c38
...
...
@@ -27,7 +27,7 @@ static inline __device__ float fp8_to_float(uint8_t input) {
}
// float -> fp8
static
inline
__device__
uint8_t
float_to_fp8
(
float
f
)
{
static
inline
__device__
uint8_t
float_to_fp8
_e4m3
(
float
f
)
{
constexpr
uint32_t
fp8_max
=
UINT32_C
(
1087
)
<<
20
;
constexpr
uint32_t
denorm_mask
=
UINT32_C
(
141
)
<<
23
;
uint32_t
f_bits
=
c10
::
detail
::
fp32_to_bits
(
f
);
...
...
@@ -53,6 +53,31 @@ static inline __device__ uint8_t float_to_fp8(float f) {
return
result
;
}
static
inline
__device__
uint8_t
float_to_fp8_e5m2
(
float
f
)
{
constexpr
uint32_t
fp32_inf
=
UINT32_C
(
255
)
<<
23
;
constexpr
uint32_t
fp8_max
=
UINT32_C
(
143
)
<<
23
;
constexpr
uint32_t
denorm_mask
=
UINT32_C
(
134
)
<<
23
;
uint32_t
f_bits
=
c10
::
detail
::
fp32_to_bits
(
f
);
uint8_t
result
=
0u
;
const
uint32_t
sign
=
f_bits
&
UINT32_C
(
0x80000000
);
f_bits
^=
sign
;
if
(
f_bits
>=
fp8_max
)
{
result
=
f_bits
>
fp32_inf
?
UINT8_C
(
0x7F
)
:
UINT8_C
(
0x7C
);
}
else
{
if
(
f_bits
<
(
UINT32_C
(
113
)
<<
23
))
{
f_bits
=
c10
::
detail
::
fp32_to_bits
(
c10
::
detail
::
fp32_from_bits
(
f_bits
)
+
c10
::
detail
::
fp32_from_bits
(
denorm_mask
));
result
=
static_cast
<
uint8_t
>
(
f_bits
-
denorm_mask
);
}
else
{
uint32_t
mant_odd
=
(
f_bits
>>
21
)
&
1
;
f_bits
+=
((
uint32_t
)(
15
-
127
)
<<
23
)
+
0xFFFFF
;
f_bits
+=
mant_odd
;
result
=
static_cast
<
uint8_t
>
(
f_bits
>>
21
);
}
}
result
|=
static_cast
<
uint8_t
>
(
sign
>>
24
);
return
result
;
}
// template <typename Tout, typename Tin>
// __inline__ __device__ Tout vec_conversion(const Tin& x) {
// return x;
...
...
@@ -60,7 +85,7 @@ static inline __device__ uint8_t float_to_fp8(float f) {
template
<
typename
Tout
,
typename
Tin
>
__inline__
__device__
Tout
scaled_vec_conversion
(
const
Tin
&
x
,
const
float
scale
)
{
const
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
return
x
;
}
...
...
@@ -344,8 +369,10 @@ using __nv_bfloat16 = __hip_bfloat16;
// fp8 -> __nv_bfloat16
template
<
>
__inline__
__device__
__nv_bfloat16
scaled_vec_conversion
<
__nv_bfloat16
,
uint8_t
>
(
const
uint8_t
&
a
,
float
scale
)
{
scaled_vec_conversion
<
__nv_bfloat16
,
uint8_t
>
(
const
uint8_t
&
a
,
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
if
(
kv_type
==
vllm
::
Fp8KVCacheDataType
::
kFp8E5M2
)
{
assert
(
false
);
}
return
__float2bfloat16
(
fp8_to_float
(
a
)
*
scale
);
// fp8_type f8;
// f8.__x = a;
...
...
@@ -356,32 +383,32 @@ scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, float scale) {
template
<
>
__inline__
__device__
__nv_bfloat162
scaled_vec_conversion
<
__nv_bfloat162
,
uint16_t
>
(
const
uint16_t
&
a
,
float
scale
)
{
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
__nv_bfloat162
res
;
res
.
x
=
scaled_vec_conversion
<
__nv_bfloat16
,
uint8_t
>
((
uint8_t
)
a
,
scale
);
res
.
x
=
scaled_vec_conversion
<
__nv_bfloat16
,
uint8_t
>
((
uint8_t
)
a
,
scale
,
kv_type
);
res
.
y
=
scaled_vec_conversion
<
__nv_bfloat16
,
uint8_t
>
((
uint8_t
)(
a
>>
8U
),
scale
);
scaled_vec_conversion
<
__nv_bfloat16
,
uint8_t
>
((
uint8_t
)(
a
>>
8U
),
scale
,
kv_type
);
return
res
;
}
// fp8x4 -> bf16_4_t
template
<
>
__inline__
__device__
bf16_4_t
scaled_vec_conversion
<
bf16_4_t
,
uint32_t
>
(
const
uint32_t
&
a
,
float
scale
)
{
scaled_vec_conversion
<
bf16_4_t
,
uint32_t
>
(
const
uint32_t
&
a
,
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
bf16_4_t
res
;
res
.
x
=
scaled_vec_conversion
<
__nv_bfloat162
,
uint16_t
>
((
uint16_t
)
a
,
scale
);
res
.
x
=
scaled_vec_conversion
<
__nv_bfloat162
,
uint16_t
>
((
uint16_t
)
a
,
scale
,
kv_type
);
res
.
y
=
scaled_vec_conversion
<
__nv_bfloat162
,
uint16_t
>
((
uint16_t
)(
a
>>
16U
),
scale
);
scale
,
kv_type
);
return
res
;
}
// fp8x8 -> bf16_8_t
template
<
>
__inline__
__device__
bf16_8_t
scaled_vec_conversion
<
bf16_8_t
,
uint2
>
(
const
uint2
&
a
,
float
scale
)
{
scaled_vec_conversion
<
bf16_8_t
,
uint2
>
(
const
uint2
&
a
,
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
bf16_4_t
tmp1
,
tmp2
;
tmp1
=
scaled_vec_conversion
<
bf16_4_t
,
uint32_t
>
(
a
.
x
,
scale
);
tmp2
=
scaled_vec_conversion
<
bf16_4_t
,
uint32_t
>
(
a
.
y
,
scale
);
tmp1
=
scaled_vec_conversion
<
bf16_4_t
,
uint32_t
>
(
a
.
x
,
scale
,
kv_type
);
tmp2
=
scaled_vec_conversion
<
bf16_4_t
,
uint32_t
>
(
a
.
y
,
scale
,
kv_type
);
bf16_8_t
res
;
res
.
x
=
tmp1
.
x
;
res
.
y
=
tmp1
.
y
;
...
...
@@ -393,7 +420,10 @@ scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, float scale) {
// fp8 -> float
template
<
>
__inline__
__device__
float
scaled_vec_conversion
<
float
,
uint8_t
>
(
const
uint8_t
&
a
,
float
scale
)
{
const
uint8_t
&
a
,
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
if
(
kv_type
==
vllm
::
Fp8KVCacheDataType
::
kFp8E5M2
)
{
assert
(
false
);
}
return
fp8_to_float
(
a
)
*
scale
;
// fp8_type f8;
// f8.__x = a;
...
...
@@ -403,10 +433,10 @@ __inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
// fp8x2 -> float2
template
<
>
__inline__
__device__
float2
scaled_vec_conversion
<
float2
,
uint16_t
>
(
const
uint16_t
&
a
,
float
scale
)
{
scaled_vec_conversion
<
float2
,
uint16_t
>
(
const
uint16_t
&
a
,
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
float2
f2r
;
f2r
.
x
=
scaled_vec_conversion
<
float
,
uint8_t
>
((
uint8_t
)
a
,
scale
);
f2r
.
y
=
scaled_vec_conversion
<
float
,
uint8_t
>
((
uint8_t
)(
a
>>
8U
),
scale
);
f2r
.
x
=
scaled_vec_conversion
<
float
,
uint8_t
>
((
uint8_t
)
a
,
scale
,
kv_type
);
f2r
.
y
=
scaled_vec_conversion
<
float
,
uint8_t
>
((
uint8_t
)(
a
>>
8U
),
scale
,
kv_type
);
return
f2r
;
// [[maybe_unused]]
// fp8x2_type f8x2;
...
...
@@ -417,28 +447,28 @@ scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, float scale) {
// fp8x4 -> float4
template
<
>
__inline__
__device__
Float4_
scaled_vec_conversion
<
Float4_
,
uint32_t
>
(
const
uint32_t
&
a
,
const
float
scale
)
{
scaled_vec_conversion
<
Float4_
,
uint32_t
>
(
const
uint32_t
&
a
,
const
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
Float4_
res
;
res
.
x
=
scaled_vec_conversion
<
float2
,
uint16_t
>
((
uint16_t
)
a
,
scale
);
res
.
y
=
scaled_vec_conversion
<
float2
,
uint16_t
>
((
uint16_t
)(
a
>>
16U
),
scale
);
res
.
x
=
scaled_vec_conversion
<
float2
,
uint16_t
>
((
uint16_t
)
a
,
scale
,
kv_type
);
res
.
y
=
scaled_vec_conversion
<
float2
,
uint16_t
>
((
uint16_t
)(
a
>>
16U
),
scale
,
kv_type
);
return
res
;
}
// fp8x4 -> float4
template
<
>
__inline__
__device__
float4
scaled_vec_conversion
<
float4
,
uint32_t
>
(
const
uint32_t
&
a
,
float
scale
)
{
Float4_
res
=
scaled_vec_conversion
<
Float4_
,
uint32_t
>
(
a
,
scale
);
scaled_vec_conversion
<
float4
,
uint32_t
>
(
const
uint32_t
&
a
,
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
Float4_
res
=
scaled_vec_conversion
<
Float4_
,
uint32_t
>
(
a
,
scale
,
kv_type
);
return
{
res
.
x
.
x
,
res
.
x
.
y
,
res
.
y
.
x
,
res
.
y
.
y
};
}
// fp8x8 -> float8
template
<
>
__inline__
__device__
Float8_
scaled_vec_conversion
<
Float8_
,
uint2
>
(
const
uint2
&
a
,
float
scale
)
{
scaled_vec_conversion
<
Float8_
,
uint2
>
(
const
uint2
&
a
,
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
Float4_
tmp1
,
tmp2
;
tmp1
=
scaled_vec_conversion
<
Float4_
,
uint32_t
>
(
a
.
x
,
scale
);
tmp2
=
scaled_vec_conversion
<
Float4_
,
uint32_t
>
(
a
.
y
,
scale
);
tmp1
=
scaled_vec_conversion
<
Float4_
,
uint32_t
>
(
a
.
x
,
scale
,
kv_type
);
tmp2
=
scaled_vec_conversion
<
Float4_
,
uint32_t
>
(
a
.
y
,
scale
,
kv_type
);
Float8_
res
;
res
.
x
=
tmp1
.
x
;
res
.
y
=
tmp1
.
y
;
...
...
@@ -450,7 +480,10 @@ scaled_vec_conversion<Float8_, uint2>(const uint2& a, float scale) {
// fp8 -> half
template
<
>
__inline__
__device__
uint16_t
scaled_vec_conversion
<
uint16_t
,
uint8_t
>
(
const
uint8_t
&
a
,
float
scale
)
{
scaled_vec_conversion
<
uint16_t
,
uint8_t
>
(
const
uint8_t
&
a
,
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
if
(
kv_type
==
vllm
::
Fp8KVCacheDataType
::
kFp8E5M2
)
{
assert
(
false
);
}
float
res
=
fp8_to_float
(
a
)
*
scale
;
return
float_to_half
(
res
);
// __half_raw res;
...
...
@@ -461,13 +494,13 @@ scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, float scale) {
// fp8x2 -> half2
template
<
>
__inline__
__device__
uint32_t
scaled_vec_conversion
<
uint32_t
,
uint16_t
>
(
const
uint16_t
&
a
,
float
scale
)
{
scaled_vec_conversion
<
uint32_t
,
uint16_t
>
(
const
uint16_t
&
a
,
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
union
{
uint16_t
u16
[
2
];
uint32_t
u32
;
}
res
;
res
.
u16
[
0
]
=
scaled_vec_conversion
<
uint16_t
,
uint8_t
>
((
uint8_t
)
a
,
scale
);
res
.
u16
[
1
]
=
scaled_vec_conversion
<
uint16_t
,
uint8_t
>
((
uint8_t
)(
a
>>
8U
),
scale
);
res
.
u16
[
0
]
=
scaled_vec_conversion
<
uint16_t
,
uint8_t
>
((
uint8_t
)
a
,
scale
,
kv_type
);
res
.
u16
[
1
]
=
scaled_vec_conversion
<
uint16_t
,
uint8_t
>
((
uint8_t
)(
a
>>
8U
),
scale
,
kv_type
);
return
res
.
u32
;
// [[maybe_unused]] __half2_raw h2r =
// __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
...
...
@@ -484,35 +517,40 @@ scaled_vec_conversion<uint32_t, uint16_t>(const uint16_t& a, float scale) {
// fp8x4 -> half2x2
template
<
>
__inline__
__device__
uint2
scaled_vec_conversion
<
uint2
,
uint32_t
>
(
const
uint32_t
&
a
,
float
scale
)
{
scaled_vec_conversion
<
uint2
,
uint32_t
>
(
const
uint32_t
&
a
,
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
union
{
uint2
u32x2
;
uint32_t
u32
[
2
];
}
tmp
;
tmp
.
u32
[
0
]
=
scaled_vec_conversion
<
uint32_t
,
uint16_t
>
((
uint16_t
)
a
,
scale
);
tmp
.
u32
[
1
]
=
scaled_vec_conversion
<
uint32_t
,
uint16_t
>
((
uint16_t
)(
a
>>
16U
),
scale
);
tmp
.
u32
[
0
]
=
scaled_vec_conversion
<
uint32_t
,
uint16_t
>
((
uint16_t
)
a
,
scale
,
kv_type
);
tmp
.
u32
[
1
]
=
scaled_vec_conversion
<
uint32_t
,
uint16_t
>
((
uint16_t
)(
a
>>
16U
),
scale
,
kv_type
);
return
tmp
.
u32x2
;
}
// fp8x8 -> half2x4
template
<
>
__inline__
__device__
uint4
scaled_vec_conversion
<
uint4
,
uint2
>
(
const
uint2
&
a
,
float
scale
)
{
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
union
{
uint4
u64x2
;
uint2
u64
[
2
];
}
tmp
;
tmp
.
u64
[
0
]
=
scaled_vec_conversion
<
uint2
,
uint32_t
>
(
a
.
x
,
scale
);
tmp
.
u64
[
1
]
=
scaled_vec_conversion
<
uint2
,
uint32_t
>
(
a
.
y
,
scale
);
tmp
.
u64
[
0
]
=
scaled_vec_conversion
<
uint2
,
uint32_t
>
(
a
.
x
,
scale
,
kv_type
);
tmp
.
u64
[
1
]
=
scaled_vec_conversion
<
uint2
,
uint32_t
>
(
a
.
y
,
scale
,
kv_type
);
return
tmp
.
u64x2
;
}
// half -> fp8
template
<
>
__inline__
__device__
uint8_t
scaled_vec_conversion
<
uint8_t
,
uint16_t
>
(
const
uint16_t
&
a
,
float
scale
)
{
scaled_vec_conversion
<
uint8_t
,
uint16_t
>
(
const
uint16_t
&
a
,
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
float
res_f
=
half_to_float
(
a
)
/
scale
;
return
float_to_fp8
(
res_f
);
if
(
kv_type
==
vllm
::
Fp8KVCacheDataType
::
kFp8E4M3
)
{
return
float_to_fp8_e4m3
(
res_f
);
}
else
{
return
float_to_fp8_e5m2
(
res_f
);
}
// __half_raw tmp;
// tmp.x = a;
// tmp.data /= scale;
...
...
@@ -523,7 +561,7 @@ scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, float scale) {
// halfx2 -> fp8x2
template
<
>
__inline__
__device__
uint16_t
scaled_vec_conversion
<
uint16_t
,
uint32_t
>
(
const
uint32_t
&
a
,
float
scale
)
{
scaled_vec_conversion
<
uint16_t
,
uint32_t
>
(
const
uint32_t
&
a
,
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
union
{
uint8_t
ui8
[
2
];
uint16_t
ui16
;
...
...
@@ -533,8 +571,8 @@ scaled_vec_conversion<uint16_t, uint32_t>(const uint32_t& a, float scale) {
half2
h2r
;
}
tmp_a
;
tmp_a
.
ui32
=
a
;
tmp
.
ui8
[
0
]
=
scaled_vec_conversion
<
uint8_t
,
uint16_t
>
(
tmp_a
.
h2r
.
data
[
0
],
scale
);
tmp
.
ui8
[
1
]
=
scaled_vec_conversion
<
uint8_t
,
uint16_t
>
(
tmp_a
.
h2r
.
data
[
1
],
scale
);
tmp
.
ui8
[
0
]
=
scaled_vec_conversion
<
uint8_t
,
uint16_t
>
(
tmp_a
.
h2r
.
data
[
0
],
scale
,
kv_type
);
tmp
.
ui8
[
1
]
=
scaled_vec_conversion
<
uint8_t
,
uint16_t
>
(
tmp_a
.
h2r
.
data
[
1
],
scale
,
kv_type
);
return
tmp
.
ui16
;
// union {
// uint32_t ui32;
...
...
@@ -550,37 +588,41 @@ scaled_vec_conversion<uint16_t, uint32_t>(const uint32_t& a, float scale) {
// half2x2 -> fp8x4
template
<
>
__inline__
__device__
uint32_t
scaled_vec_conversion
<
uint32_t
,
uint2
>
(
const
uint2
&
a
,
float
scale
)
{
scaled_vec_conversion
<
uint32_t
,
uint2
>
(
const
uint2
&
a
,
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
union
{
uint16_t
ui16
[
2
];
uint32_t
ui32
;
}
tmp
;
tmp
.
ui16
[
0
]
=
scaled_vec_conversion
<
uint16_t
,
uint32_t
>
(
a
.
x
,
scale
);
tmp
.
ui16
[
1
]
=
scaled_vec_conversion
<
uint16_t
,
uint32_t
>
(
a
.
y
,
scale
);
tmp
.
ui16
[
0
]
=
scaled_vec_conversion
<
uint16_t
,
uint32_t
>
(
a
.
x
,
scale
,
kv_type
);
tmp
.
ui16
[
1
]
=
scaled_vec_conversion
<
uint16_t
,
uint32_t
>
(
a
.
y
,
scale
,
kv_type
);
return
tmp
.
ui32
;
}
// half2x4 -> fp8x8
template
<
>
__inline__
__device__
uint2
scaled_vec_conversion
<
uint2
,
uint4
>
(
const
uint4
&
a
,
float
scale
)
{
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
union
{
uint2
ui2
[
2
];
uint4
ui4
;
}
tmp
;
tmp
.
ui4
=
a
;
uint2
res
;
res
.
x
=
scaled_vec_conversion
<
uint32_t
,
uint2
>
(
tmp
.
ui2
[
0
],
scale
);
res
.
y
=
scaled_vec_conversion
<
uint32_t
,
uint2
>
(
tmp
.
ui2
[
1
],
scale
);
res
.
x
=
scaled_vec_conversion
<
uint32_t
,
uint2
>
(
tmp
.
ui2
[
0
],
scale
,
kv_type
);
res
.
y
=
scaled_vec_conversion
<
uint32_t
,
uint2
>
(
tmp
.
ui2
[
1
],
scale
,
kv_type
);
return
res
;
}
// bf16 -> fp8
template
<
>
__inline__
__device__
uint8_t
scaled_vec_conversion
<
uint8_t
,
__nv_bfloat16
>
(
const
__nv_bfloat16
&
a
,
float
scale
)
{
const
__nv_bfloat16
&
a
,
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
float
res_f
=
(
static_cast
<
float
>
(
a
))
/
scale
;
return
float_to_fp8
(
res_f
);
if
(
kv_type
==
vllm
::
Fp8KVCacheDataType
::
kFp8E4M3
)
{
return
float_to_fp8_e4m3
(
res_f
);
}
else
{
return
float_to_fp8_e5m2
(
res_f
);
}
// return __hip_cvt_float_to_fp8(__bfloat162float(a) / scale,
// fp8_type::__default_saturation,
// fp8_type::__default_interpret);
...
...
@@ -589,44 +631,48 @@ __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
// bf16x2 -> fp8x2
template
<
>
__inline__
__device__
uint16_t
scaled_vec_conversion
<
uint16_t
,
__nv_bfloat162
>
(
const
__nv_bfloat162
&
a
,
float
scale
)
{
const
__nv_bfloat162
&
a
,
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
union
{
uint8_t
ui8
[
2
];
uint16_t
ui16
;
}
tmp
;
tmp
.
ui8
[
0
]
=
scaled_vec_conversion
<
uint8_t
,
__nv_bfloat16
>
(
a
.
x
,
scale
);
tmp
.
ui8
[
1
]
=
scaled_vec_conversion
<
uint8_t
,
__nv_bfloat16
>
(
a
.
y
,
scale
);
tmp
.
ui8
[
0
]
=
scaled_vec_conversion
<
uint8_t
,
__nv_bfloat16
>
(
a
.
x
,
scale
,
kv_type
);
tmp
.
ui8
[
1
]
=
scaled_vec_conversion
<
uint8_t
,
__nv_bfloat16
>
(
a
.
y
,
scale
,
kv_type
);
return
tmp
.
ui16
;
}
// bf16x4 -> fp8x4
template
<
>
__inline__
__device__
uint32_t
scaled_vec_conversion
<
uint32_t
,
bf16_4_t
>
(
const
bf16_4_t
&
a
,
float
scale
)
{
scaled_vec_conversion
<
uint32_t
,
bf16_4_t
>
(
const
bf16_4_t
&
a
,
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
union
{
uint16_t
ui16
[
2
];
uint32_t
ui32
;
}
tmp
;
tmp
.
ui16
[
0
]
=
scaled_vec_conversion
<
uint16_t
,
__nv_bfloat162
>
(
a
.
x
,
scale
);
tmp
.
ui16
[
1
]
=
scaled_vec_conversion
<
uint16_t
,
__nv_bfloat162
>
(
a
.
y
,
scale
);
tmp
.
ui16
[
0
]
=
scaled_vec_conversion
<
uint16_t
,
__nv_bfloat162
>
(
a
.
x
,
scale
,
kv_type
);
tmp
.
ui16
[
1
]
=
scaled_vec_conversion
<
uint16_t
,
__nv_bfloat162
>
(
a
.
y
,
scale
,
kv_type
);
return
tmp
.
ui32
;
}
// bf16x8 -> fp8x8
template
<
>
__inline__
__device__
uint2
scaled_vec_conversion
<
uint2
,
bf16_8_t
>
(
const
bf16_8_t
&
a
,
float
scale
)
{
scaled_vec_conversion
<
uint2
,
bf16_8_t
>
(
const
bf16_8_t
&
a
,
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
uint2
res
;
res
.
x
=
scaled_vec_conversion
<
uint32_t
,
bf16_4_t
>
({
a
.
x
,
a
.
y
},
scale
);
res
.
y
=
scaled_vec_conversion
<
uint32_t
,
bf16_4_t
>
({
a
.
z
,
a
.
w
},
scale
);
res
.
x
=
scaled_vec_conversion
<
uint32_t
,
bf16_4_t
>
({
a
.
x
,
a
.
y
},
scale
,
kv_type
);
res
.
y
=
scaled_vec_conversion
<
uint32_t
,
bf16_4_t
>
({
a
.
z
,
a
.
w
},
scale
,
kv_type
);
return
res
;
}
// float -> fp8
template
<
>
__inline__
__device__
uint8_t
scaled_vec_conversion
<
uint8_t
,
float
>
(
const
float
&
a
,
float
scale
)
{
return
float_to_fp8
(
a
/
scale
);
scaled_vec_conversion
<
uint8_t
,
float
>
(
const
float
&
a
,
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
if
(
kv_type
==
vllm
::
Fp8KVCacheDataType
::
kFp8E4M3
)
{
return
float_to_fp8_e4m3
(
a
/
scale
);
}
else
{
return
float_to_fp8_e5m2
(
a
/
scale
);
}
// return __hip_cvt_float_to_fp8(a / scale, fp8_type::__default_saturation,
// fp8_type::__default_interpret);
}
...
...
@@ -634,13 +680,13 @@ scaled_vec_conversion<uint8_t, float>(const float& a, float scale) {
// floatx2 -> fp8x2
template
<
>
__inline__
__device__
uint16_t
scaled_vec_conversion
<
uint16_t
,
float2
>
(
const
float2
&
a
,
float
scale
)
{
scaled_vec_conversion
<
uint16_t
,
float2
>
(
const
float2
&
a
,
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
union
{
uint8_t
ui8
[
2
];
uint16_t
ui16
;
}
tmp
;
tmp
.
ui8
[
0
]
=
scaled_vec_conversion
<
uint8_t
,
float
>
(
a
.
x
,
scale
);
tmp
.
ui8
[
1
]
=
scaled_vec_conversion
<
uint8_t
,
float
>
(
a
.
y
,
scale
);
tmp
.
ui8
[
0
]
=
scaled_vec_conversion
<
uint8_t
,
float
>
(
a
.
x
,
scale
,
kv_type
);
tmp
.
ui8
[
1
]
=
scaled_vec_conversion
<
uint8_t
,
float
>
(
a
.
y
,
scale
,
kv_type
);
return
tmp
.
ui16
;
// return __hip_cvt_float2_to_fp8x2(a / scale, fp8_type::__default_saturation,
// fp8_type::__default_interpret);
...
...
@@ -649,13 +695,13 @@ scaled_vec_conversion<uint16_t, float2>(const float2& a, float scale) {
// floatx4 -> fp8x4
template
<
>
__inline__
__device__
uint32_t
scaled_vec_conversion
<
uint32_t
,
float4
>
(
const
float4
&
a
,
float
scale
)
{
scaled_vec_conversion
<
uint32_t
,
float4
>
(
const
float4
&
a
,
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
union
{
uint16_t
ui16
[
2
];
uint32_t
ui32
;
}
tmp
;
tmp
.
ui16
[
0
]
=
scaled_vec_conversion
<
uint16_t
,
float2
>
({
a
.
x
,
a
.
y
},
scale
);
tmp
.
ui16
[
1
]
=
scaled_vec_conversion
<
uint16_t
,
float2
>
({
a
.
z
,
a
.
w
},
scale
);
tmp
.
ui16
[
0
]
=
scaled_vec_conversion
<
uint16_t
,
float2
>
({
a
.
x
,
a
.
y
},
scale
,
kv_type
);
tmp
.
ui16
[
1
]
=
scaled_vec_conversion
<
uint16_t
,
float2
>
({
a
.
z
,
a
.
w
},
scale
,
kv_type
);
return
tmp
.
ui32
;
}
// #endif // ENABLE_FP8
...
...
@@ -674,11 +720,11 @@ scaled_vec_conversion<uint32_t, float4>(const float4& a, float scale) {
template
<
typename
Tout
,
typename
Tin
,
Fp8KVCacheDataType
kv_dt
>
__inline__
__device__
Tout
scaled_convert
(
const
Tin
&
x
,
const
float
scale
)
{
// #ifdef ENABLE_FP8
//
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
return
scaled_vec_conversion
<
Tout
,
Tin
>
(
x
,
scale
);
//
}
if
constexpr
(
kv_dt
==
Fp8KVCacheDataType
::
kFp8E4M3
||
kv_dt
==
Fp8KVCacheDataType
::
kFp8E5M2
)
{
return
scaled_vec_conversion
<
Tout
,
Tin
>
(
x
,
scale
,
kv_dt
);
}
// #endif
//
assert(false);
assert
(
false
);
return
{};
// Squash missing return statement warning
}
...
...
@@ -719,6 +765,18 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
TORCH_CHECK(false, \
"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} \
else if (KV_DTYPE == "fp8_e5m2") { \
if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
} else { \
TORCH_CHECK(false, \
"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else { \
TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
} \
...
...
vllm/distributed/device_communicators/custom_all_reduce.py
View file @
cea85c38
...
...
@@ -278,10 +278,8 @@ class CustomAllreduce:
if
envs
.
VLLM_CUSTOM_CACHE
:
return
self
.
all_reduce
(
input
,
registered
=
True
)
else
:
if
not
self
.
fully_connected
:
return
self
.
all_reduce
(
input
,
registered
=
False
)
else
:
return
self
.
all_reduce
(
input
,
registered
=
True
)
else
:
# If warm up, mimic the allocation pattern since custom
# allreduce is out-of-place.
...
...
vllm/envs.py
View file @
cea85c38
...
...
@@ -1565,7 +1565,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# flag to control vllm to use optimized kernels
"VLLM_CUSTOM_CACHE"
:
lambda
:
bool
(
int
(
os
.
environ
.
get
(
"VLLM_CUSTOM_CACHE"
,
"
0
"
))),
lambda
:
bool
(
int
(
os
.
environ
.
get
(
"VLLM_CUSTOM_CACHE"
,
"
1
"
))),
# flag to control vllm to use optimized kernels
"VLLM_CUSTOM_ALLREDUCE_SUPPORTED_WORLDSIZE_MAX"
:
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
cea85c38
...
...
@@ -143,6 +143,8 @@ class FlashAttentionBackend(AttentionBackend):
def
get_fp8_dtype_for_flashattn
(
kv_cache_dtype
:
str
)
->
torch
.
dtype
:
if
kv_cache_dtype
in
(
"fp8"
,
"fp8_e4m3"
):
return
torch
.
float8_e4m3fn
elif
kv_cache_dtype
in
(
"fp8_e5m2"
):
return
torch
.
float8_e5m2
else
:
raise
ValueError
(
f
"Unrecognized FP8 dtype:
{
kv_cache_dtype
}
"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment