Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
83f2f396
Commit
83f2f396
authored
Sep 30, 2025
by
王敏
Browse files
同步0.9.2-ds分支代码
parents
d2e57a90
20605c42
Changes
33
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
425 additions
and
437 deletions
+425
-437
csrc/cache_kernels.cu
csrc/cache_kernels.cu
+16
-0
csrc/ops.h
csrc/ops.h
+0
-2
csrc/quantization/fp8/amd/quant_utils.cuh
csrc/quantization/fp8/amd/quant_utils.cuh
+116
-351
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+0
-3
vllm/_custom_ops.py
vllm/_custom_ops.py
+16
-3
vllm/attention/backends/flashmla.py
vllm/attention/backends/flashmla.py
+4
-3
vllm/attention/backends/mla/common.py
vllm/attention/backends/mla/common.py
+6
-2
vllm/attention/layer.py
vllm/attention/layer.py
+46
-3
vllm/attention/ops/flashmla.py
vllm/attention/ops/flashmla.py
+3
-2
vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
...ted/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
+33
-5
vllm/distributed/kv_transfer/kv_connector/v1/p2p/tensor_memory_pool.py
...ted/kv_transfer/kv_connector/v1/p2p/tensor_memory_pool.py
+1
-1
vllm/envs.py
vllm/envs.py
+22
-11
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+56
-23
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+21
-11
vllm/model_executor/layers/fused_moe/moe_align_block_size.py
vllm/model_executor/layers/fused_moe/moe_align_block_size.py
+8
-5
vllm/model_executor/layers/layernorm.py
vllm/model_executor/layers/layernorm.py
+45
-0
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+21
-5
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
...ers/quantization/compressed_tensors/compressed_tensors.py
+3
-2
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+4
-2
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
...ompressed_tensors/schemes/compressed_tensors_w8a8_int8.py
+4
-3
No files found.
csrc/cache_kernels.cu
View file @
83f2f396
...
...
@@ -965,6 +965,22 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
CALL_CONVERT_FP8
(
__nv_bfloat16
,
uint8_t
,
vllm
::
Fp8KVCacheDataType
::
kFp8E4M3
);
}
}
else
if
(
kv_cache_dtype
==
"fp8_e5m2"
)
{
if
(
src_cache
.
dtype
()
==
at
::
ScalarType
::
Float
)
{
CALL_CONVERT_FP8
(
uint8_t
,
float
,
vllm
::
Fp8KVCacheDataType
::
kFp8E5M2
);
}
else
if
(
src_cache
.
dtype
()
==
at
::
ScalarType
::
Half
)
{
CALL_CONVERT_FP8
(
uint8_t
,
uint16_t
,
vllm
::
Fp8KVCacheDataType
::
kFp8E5M2
);
}
else
if
(
src_cache
.
dtype
()
==
at
::
ScalarType
::
BFloat16
)
{
CALL_CONVERT_FP8
(
uint8_t
,
__nv_bfloat16
,
vllm
::
Fp8KVCacheDataType
::
kFp8E5M2
);
}
else
if
(
dst_cache
.
dtype
()
==
at
::
ScalarType
::
Float
)
{
CALL_CONVERT_FP8
(
float
,
uint8_t
,
vllm
::
Fp8KVCacheDataType
::
kFp8E5M2
);
}
else
if
(
dst_cache
.
dtype
()
==
at
::
ScalarType
::
Half
)
{
CALL_CONVERT_FP8
(
uint16_t
,
uint8_t
,
vllm
::
Fp8KVCacheDataType
::
kFp8E5M2
);
}
else
if
(
dst_cache
.
dtype
()
==
at
::
ScalarType
::
BFloat16
)
{
CALL_CONVERT_FP8
(
__nv_bfloat16
,
uint8_t
,
vllm
::
Fp8KVCacheDataType
::
kFp8E5M2
);
}
}
else
{
TORCH_CHECK
(
false
,
"Unsupported data type: "
,
kv_cache_dtype
);
}
...
...
csrc/ops.h
View file @
83f2f396
...
...
@@ -179,7 +179,6 @@ void merge_attn_states(torch::Tensor& output,
const
torch
::
Tensor
&
suffix_lse
);
#ifndef USE_ROCM
void
convert_vertical_slash_indexes
(
torch
::
Tensor
&
block_count
,
// [BATCH, N_HEADS, NUM_ROWS]
torch
::
Tensor
&
block_offset
,
// [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
...
...
@@ -204,7 +203,6 @@ void convert_vertical_slash_indexes_mergehead(
torch
::
Tensor
vertical_indices_count
,
// [N_HEADS, ]
torch
::
Tensor
slash_indices_count
,
int64_t
context_size
,
int64_t
block_size_M
,
int64_t
block_size_N
,
bool
causal
);
#endif
void
rms_norm
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
torch
::
Tensor
&
weight
,
double
epsilon
);
...
...
csrc/quantization/fp8/amd/quant_utils.cuh
View file @
83f2f396
...
...
@@ -53,10 +53,6 @@ static inline __device__ uint8_t float_to_fp8(float f) {
return
result
;
}
// template <typename Tout, typename Tin>
// __inline__ __device__ Tout vec_conversion(const Tin& x) {
// return x;
// }
template
<
typename
Tout
,
typename
Tin
>
__inline__
__device__
Tout
scaled_vec_conversion
(
const
Tin
&
x
,
...
...
@@ -64,281 +60,6 @@ __inline__ __device__ Tout scaled_vec_conversion(const Tin& x,
return
x
;
}
// #if HIP_FP8_TYPE_OCP
// using fp8_type = __hip_fp8_e4m3;
// using fp8x2_type = __hip_fp8x2_e4m3;
// #else
// using fp8_type = __hip_fp8_e4m3_fnuz;
// using fp8x2_type = __hip_fp8x2_e4m3_fnuz;
// #endif
// // fp8 -> half
// template <>
// __inline__ __device__ uint16_t
// vec_conversion<uint16_t, uint8_t>(const uint8_t& a) {
// return __hip_cvt_fp8_to_halfraw(a, fp8_type::__default_interpret).x;
// }
// // fp8x2 -> half2
// template <>
// __inline__ __device__ uint32_t
// vec_conversion<uint32_t, uint16_t>(const uint16_t& a) {
// union {
// __half2_raw h2r;
// uint32_t ui32;
// } tmp;
// tmp.h2r = __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
// return tmp.ui32;
// }
// // fp8x4 -> half2x2
// template <>
// __inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a) {
// union {
// uint2 u32x2;
// uint32_t u32[2];
// } tmp;
// tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a);
// tmp.u32[1] = vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U));
// return tmp.u32x2;
// }
// // fp8x8 -> half2x4
// template <>
// __inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a) {
// union {
// uint4 u64x2;
// uint2 u64[2];
// } tmp;
// tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x);
// tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y);
// return tmp.u64x2;
// }
// using __nv_bfloat16 = __hip_bfloat16;
// // fp8 -> __nv_bfloat16
// template <>
// __inline__ __device__ __nv_bfloat16
// vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) {
// fp8_type f8;
// f8.__x = a;
// return __float2bfloat16(static_cast<float>(f8));
// }
// using __nv_bfloat162 = __hip_bfloat162;
// // fp8x2 -> __nv_bfloat162
// template <>
// __inline__ __device__ __nv_bfloat162
// vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a) {
// __nv_bfloat162 res;
// res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a);
// res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U));
// return res;
// }
// // fp8x4 -> bf16_4_t
// template <>
// __inline__ __device__ bf16_4_t
// vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a) {
// bf16_4_t res;
// res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a);
// res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U));
// return res;
// }
// // fp8x8 -> bf16_8_t
// template <>
// __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a) {
// bf16_4_t tmp1, tmp2;
// tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x);
// tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y);
// bf16_8_t res;
// res.x = tmp1.x;
// res.y = tmp1.y;
// res.z = tmp2.x;
// res.w = tmp2.y;
// return res;
// }
// // fp8 -> float
// template <>
// __inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a) {
// fp8_type f8;
// f8.__x = a;
// return static_cast<float>(f8);
// }
// // fp8x2 -> float2
// template <>
// __inline__ __device__ float2
// vec_conversion<float2, uint16_t>(const uint16_t& a) {
// fp8x2_type f8x2;
// f8x2.__x = a;
// return static_cast<float2>(f8x2);
// }
// // fp8x4 -> float4
// template <>
// __inline__ __device__ Float4_
// vec_conversion<Float4_, uint32_t>(const uint32_t& a) {
// Float4_ res;
// res.x = vec_conversion<float2, uint16_t>((uint16_t)a);
// res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U));
// return res;
// }
// // fp8x4 -> float4
// template <>
// __inline__ __device__ float4
// vec_conversion<float4, uint32_t>(const uint32_t& a) {
// Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
// float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
// return res;
// }
// // fp8x8 -> float8
// template <>
// __inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a) {
// Float4_ tmp1, tmp2;
// tmp1 = vec_conversion<Float4_, uint32_t>(a.x);
// tmp2 = vec_conversion<Float4_, uint32_t>(a.y);
// Float8_ res;
// res.x = tmp1.x;
// res.y = tmp1.y;
// res.z = tmp2.x;
// res.w = tmp2.y;
// return res;
// }
// // half -> fp8
// template <>
// __inline__ __device__ uint8_t
// vec_conversion<uint8_t, uint16_t>(const uint16_t& a) {
// __half_raw tmp;
// tmp.x = a;
// return __hip_cvt_halfraw_to_fp8(tmp, fp8_type::__default_saturation,
// fp8_type::__default_interpret);
// }
// template <>
// __inline__ __device__ uint16_t
// vec_conversion<uint16_t, uint32_t>(const uint32_t& a) {
// union {
// uint32_t ui32;
// __half2_raw h2r;
// } tmp;
// tmp.ui32 = a;
// return __hip_cvt_halfraw2_to_fp8x2(tmp.h2r, fp8_type::__default_saturation,
// fp8_type::__default_interpret);
// }
// // bf16 -> fp8
// template <>
// __inline__ __device__ uint8_t
// vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a) {
// return __hip_cvt_float_to_fp8(__bfloat162float(a),
// fp8_type::__default_saturation,
// fp8_type::__default_interpret);
// }
// // float -> fp8
// template <>
// __inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a) {
// return __hip_cvt_float_to_fp8(a, fp8_type::__default_saturation,
// fp8_type::__default_interpret);
// }
// // float2 -> half2
// template <>
// __inline__ __device__ uint32_t
// vec_conversion<uint32_t, float2>(const float2& a) {
// union {
// half2 float16;
// uint32_t uint32;
// };
// float16 = __float22half2_rn(a);
// return uint32;
// }
// // Float4 -> half2x2
// template <>
// __inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a) {
// uint2 b;
// float2 val;
// val.x = a.x.x;
// val.y = a.x.y;
// b.x = vec_conversion<uint32_t, float2>(val);
// val.x = a.y.x;
// val.y = a.y.y;
// b.y = vec_conversion<uint32_t, float2>(val);
// return b;
// }
// // Float4 -> float4
// template <>
// __inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a) {
// float4 b;
// b.x = a.x.x;
// b.y = a.x.y;
// b.z = a.y.x;
// b.w = a.y.y;
// return b;
// }
// // Float8 -> half2x4
// template <>
// __inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a) {
// uint4 b;
// b.x = vec_conversion<uint32_t, float2>(a.x);
// b.y = vec_conversion<uint32_t, float2>(a.y);
// b.z = vec_conversion<uint32_t, float2>(a.z);
// b.w = vec_conversion<uint32_t, float2>(a.w);
// return b;
// }
// // float2 -> bfloat162
// template <>
// __inline__ __device__ __nv_bfloat162
// vec_conversion<__nv_bfloat162, float2>(const float2& a) {
// __nv_bfloat162 b = __float22bfloat162_rn(a);
// return b;
// }
// // Float4 -> bfloat162x2
// template <>
// __inline__ __device__ bf16_4_t
// vec_conversion<bf16_4_t, Float4_>(const Float4_& a) {
// bf16_4_t b;
// b.x = __float22bfloat162_rn(a.x);
// b.y = __float22bfloat162_rn(a.y);
// return b;
// }
// // Float8 -> bfloat162x4
// template <>
// __inline__ __device__ bf16_8_t
// vec_conversion<bf16_8_t, Float8_>(const Float8_& a) {
// bf16_8_t b;
// b.x = __float22bfloat162_rn(a.x);
// b.y = __float22bfloat162_rn(a.y);
// b.z = __float22bfloat162_rn(a.z);
// b.w = __float22bfloat162_rn(a.w);
// return b;
// }
/* Scaled and vectorized conversions, for data exchange between high and low
precision domains
Convention of the scale in API, e.g: FP8_data = Quantization(
High_Precision_data / scale ) s.t. Quantize(HP / scale) => FP8 Dequant(FP8) *
scale => HP
*/
using
__nv_bfloat16
=
__hip_bfloat16
;
// fp8 -> __nv_bfloat16
...
...
@@ -347,9 +68,6 @@ __inline__ __device__ __nv_bfloat16
scaled_vec_conversion
<
__nv_bfloat16
,
uint8_t
>
(
const
uint8_t
&
a
,
float
scale
)
{
return
__float2bfloat16
(
fp8_to_float
(
a
)
*
scale
);
// fp8_type f8;
// f8.__x = a;
// return __float2bfloat16(static_cast<float>(f8) * scale);
}
// fp8x2 -> __nv_bfloat162
...
...
@@ -395,9 +113,6 @@ template <>
__inline__
__device__
float
scaled_vec_conversion
<
float
,
uint8_t
>
(
const
uint8_t
&
a
,
float
scale
)
{
return
fp8_to_float
(
a
)
*
scale
;
// fp8_type f8;
// f8.__x = a;
// return static_cast<float>(f8) * scale;
}
// fp8x2 -> float2
...
...
@@ -408,10 +123,6 @@ scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, float scale) {
f2r
.
x
=
scaled_vec_conversion
<
float
,
uint8_t
>
((
uint8_t
)
a
,
scale
);
f2r
.
y
=
scaled_vec_conversion
<
float
,
uint8_t
>
((
uint8_t
)(
a
>>
8U
),
scale
);
return
f2r
;
// [[maybe_unused]]
// fp8x2_type f8x2;
// f8x2.__x = a;
// return static_cast<float2>(f8x2) * scale;
}
// fp8x4 -> float4
...
...
@@ -453,9 +164,6 @@ __inline__ __device__ uint16_t
scaled_vec_conversion
<
uint16_t
,
uint8_t
>
(
const
uint8_t
&
a
,
float
scale
)
{
float
res
=
fp8_to_float
(
a
)
*
scale
;
return
float_to_half
(
res
);
// __half_raw res;
// res.data = scaled_vec_conversion<float, uint8_t>(a, scale);
// return res.x;
}
// fp8x2 -> half2
...
...
@@ -469,16 +177,6 @@ scaled_vec_conversion<uint32_t, uint16_t>(const uint16_t& a, float scale) {
res
.
u16
[
0
]
=
scaled_vec_conversion
<
uint16_t
,
uint8_t
>
((
uint8_t
)
a
,
scale
);
res
.
u16
[
1
]
=
scaled_vec_conversion
<
uint16_t
,
uint8_t
>
((
uint8_t
)(
a
>>
8U
),
scale
);
return
res
.
u32
;
// [[maybe_unused]] __half2_raw h2r =
// __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
// union {
// __half2_raw h2r;
// uint32_t ui32;
// } tmp;
// tmp.h2r = __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
// tmp.h2r.x.data *= scale;
// tmp.h2r.y.data *= scale;
// return tmp.ui32;
}
// fp8x4 -> half2x2
...
...
@@ -513,11 +211,6 @@ __inline__ __device__ uint8_t
scaled_vec_conversion
<
uint8_t
,
uint16_t
>
(
const
uint16_t
&
a
,
float
scale
)
{
float
res_f
=
half_to_float
(
a
)
/
scale
;
return
float_to_fp8
(
res_f
);
// __half_raw tmp;
// tmp.x = a;
// tmp.data /= scale;
// return __hip_cvt_halfraw_to_fp8(tmp, fp8_type::__default_saturation,
// fp8_type::__default_interpret);
}
// halfx2 -> fp8x2
...
...
@@ -536,15 +229,6 @@ scaled_vec_conversion<uint16_t, uint32_t>(const uint32_t& a, float scale) {
tmp
.
ui8
[
0
]
=
scaled_vec_conversion
<
uint8_t
,
uint16_t
>
(
tmp_a
.
h2r
.
data
[
0
],
scale
);
tmp
.
ui8
[
1
]
=
scaled_vec_conversion
<
uint8_t
,
uint16_t
>
(
tmp_a
.
h2r
.
data
[
1
],
scale
);
return
tmp
.
ui16
;
// union {
// uint32_t ui32;
// __half2_raw h2r;
// } tmp;
// tmp.ui32 = a;
// tmp.h2r.x.data /= scale;
// tmp.h2r.y.data /= scale;
// return __hip_cvt_halfraw2_to_fp8x2(tmp.h2r, fp8_type::__default_saturation,
// fp8_type::__default_interpret);
}
// half2x2 -> fp8x4
...
...
@@ -581,9 +265,6 @@ __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
const
__nv_bfloat16
&
a
,
float
scale
)
{
float
res_f
=
(
static_cast
<
float
>
(
a
))
/
scale
;
return
float_to_fp8
(
res_f
);
// return __hip_cvt_float_to_fp8(__bfloat162float(a) / scale,
// fp8_type::__default_saturation,
// fp8_type::__default_interpret);
}
// bf16x2 -> fp8x2
...
...
@@ -627,8 +308,6 @@ template <>
__inline__
__device__
uint8_t
scaled_vec_conversion
<
uint8_t
,
float
>
(
const
float
&
a
,
float
scale
)
{
return
float_to_fp8
(
a
/
scale
);
// return __hip_cvt_float_to_fp8(a / scale, fp8_type::__default_saturation,
// fp8_type::__default_interpret);
}
// floatx2 -> fp8x2
...
...
@@ -642,8 +321,6 @@ scaled_vec_conversion<uint16_t, float2>(const float2& a, float scale) {
tmp
.
ui8
[
0
]
=
scaled_vec_conversion
<
uint8_t
,
float
>
(
a
.
x
,
scale
);
tmp
.
ui8
[
1
]
=
scaled_vec_conversion
<
uint8_t
,
float
>
(
a
.
y
,
scale
);
return
tmp
.
ui16
;
// return __hip_cvt_float2_to_fp8x2(a / scale, fp8_type::__default_saturation,
// fp8_type::__default_interpret);
}
// floatx4 -> fp8x4
...
...
@@ -658,27 +335,113 @@ scaled_vec_conversion<uint32_t, float4>(const float4& a, float scale) {
tmp
.
ui16
[
1
]
=
scaled_vec_conversion
<
uint16_t
,
float2
>
({
a
.
z
,
a
.
w
},
scale
);
return
tmp
.
ui32
;
}
// #endif // ENABLE_FP8
// template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
// __inline__ __device__ Tout convert(const Tin& x) {
// #ifdef ENABLE_FP8
// if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
// return vec_conversion<Tout, Tin>(x);
// }
// #endif
// assert(false);
// return {}; // Squash missing return statement warning
// }
inline
__device__
uint8_t
float_to_fp8e5m2
(
float
f
)
{
constexpr
uint32_t
fp32_inf
=
UINT32_C
(
255
)
<<
23
;
constexpr
uint32_t
fp8_max
=
UINT32_C
(
143
)
<<
23
;
constexpr
uint32_t
denorm_mask
=
UINT32_C
(
134
)
<<
23
;
uint32_t
f_bits
=
c10
::
detail
::
fp32_to_bits
(
f
);
uint8_t
result
=
0u
;
const
uint32_t
sign
=
f_bits
&
UINT32_C
(
0x80000000
);
f_bits
^=
sign
;
if
(
f_bits
>=
fp8_max
)
{
result
=
f_bits
>
fp32_inf
?
UINT8_C
(
0x7F
)
:
UINT8_C
(
0x7C
);
}
else
{
if
(
f_bits
<
(
UINT32_C
(
113
)
<<
23
))
{
f_bits
=
c10
::
detail
::
fp32_to_bits
(
c10
::
detail
::
fp32_from_bits
(
f_bits
)
+
c10
::
detail
::
fp32_from_bits
(
denorm_mask
));
result
=
static_cast
<
uint8_t
>
(
f_bits
-
denorm_mask
);
}
else
{
uint32_t
mant_odd
=
(
f_bits
>>
21
)
&
1
;
f_bits
+=
((
uint32_t
)(
15
-
127
)
<<
23
)
+
0xFFFFF
;
f_bits
+=
mant_odd
;
result
=
static_cast
<
uint8_t
>
(
f_bits
>>
21
);
}
}
result
|=
static_cast
<
uint8_t
>
(
sign
>>
24
);
return
result
;
}
// fp8
template
<
typename
Tin
>
__inline__
__device__
uint8_t
scaled_vec_conversion_to_e5m2
(
const
Tin
&
a
,
float
scale
)
{
return
0
;
}
// float -> fp8
template
<
>
__inline__
__device__
uint8_t
scaled_vec_conversion_to_e5m2
<
float
>
(
const
float
&
a
,
float
scale
)
{
return
float_to_fp8e5m2
(
a
/
scale
);
}
// half -> fp8
template
<
>
__inline__
__device__
uint8_t
scaled_vec_conversion_to_e5m2
<
uint16_t
>
(
const
uint16_t
&
a
,
float
scale
)
{
float
res_f
=
half_to_float
(
a
)
/
scale
;
return
float_to_fp8e5m2
(
res_f
);
}
// bf16 -> fp8
template
<
>
__inline__
__device__
uint8_t
scaled_vec_conversion_to_e5m2
<
__nv_bfloat16
>
(
const
__nv_bfloat16
&
a
,
float
scale
)
{
float
res_f
=
(
static_cast
<
float
>
(
a
))
/
scale
;
return
float_to_fp8e5m2
(
res_f
);
}
inline
__device__
float
fp8e5m2_to_fp32
(
const
uint8_t
&
input
)
{
union
uf16
{
uint16_t
as_bits
;
_Float16
as_value
;
}
;
uf16
u16
;
u16
.
as_bits
=
(
uint16_t
)
input
<<
8
;
return
(
float
)
u16
.
as_value
;
}
template
<
typename
Tout
>
__inline__
__device__
Tout
scaled_vec_conversion_from_e5m2
(
const
uint8_t
&
a
,
float
scale
)
{
return
0
;
}
// fp8 -> float
template
<
>
__inline__
__device__
float
scaled_vec_conversion_from_e5m2
<
float
>
(
const
uint8_t
&
a
,
float
scale
)
{
return
fp8e5m2_to_fp32
(
a
)
*
scale
;
}
// fp8 -> half
template
<
>
__inline__
__device__
uint16_t
scaled_vec_conversion_from_e5m2
<
uint16_t
>
(
const
uint8_t
&
a
,
float
scale
)
{
return
float_to_half
(
fp8e5m2_to_fp32
(
a
)
*
scale
);
}
// fp8 -> bf16
template
<
>
__inline__
__device__
__nv_bfloat16
scaled_vec_conversion_from_e5m2
<
__nv_bfloat16
>
(
const
uint8_t
&
a
,
float
scale
)
{
return
__float2bfloat16
(
fp8e5m2_to_fp32
(
a
)
*
scale
);
}
template
<
typename
Tout
,
typename
Tin
,
Fp8KVCacheDataType
kv_dt
>
__inline__
__device__
Tout
scaled_convert
(
const
Tin
&
x
,
const
float
scale
)
{
// #ifdef ENABLE_FP8
// if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
if
constexpr
(
kv_dt
==
Fp8KVCacheDataType
::
kFp8E4M3
)
{
return
scaled_vec_conversion
<
Tout
,
Tin
>
(
x
,
scale
);
// }
// #endif
// assert(false);
}
else
if
constexpr
(
kv_dt
==
Fp8KVCacheDataType
::
kFp8E5M2
&&
sizeof
(
Tout
)
==
1
){
return
scaled_vec_conversion_to_e5m2
<
Tin
>
(
x
,
scale
);
}
else
if
constexpr
(
kv_dt
==
Fp8KVCacheDataType
::
kFp8E5M2
&&
sizeof
(
Tin
)
==
1
){
return
scaled_vec_conversion_from_e5m2
<
Tout
>
(
x
,
scale
);
}
return
{};
// Squash missing return statement warning
}
...
...
@@ -686,7 +449,7 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
// the data type of the key and value cache. The FN is a macro that calls a
// function with template<typename scalar_t, typename cache_t,
// Fp8KVCacheDataType kv_dt>.
#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \
#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \
if (KV_DTYPE == "auto") { \
if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \
...
...
@@ -697,16 +460,6 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
} else { \
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else if (KV_DTYPE == "int8") { \
if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kInt8); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kInt8); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kInt8); \
} else { \
TORCH_CHECK(false,"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else { \
if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \
if (SRC_DTYPE == at::ScalarType::Float) { \
...
...
@@ -719,11 +472,23 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
TORCH_CHECK(false, \
"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else if (KV_DTYPE == "fp8_e5m2") { \
if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
} else { \
TORCH_CHECK(false, \
"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else { \
TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
} \
}
}
// namespace fp8
#endif // USE_ROCM
}
// namespace vllm
\ No newline at end of file
}
// namespace vllm
csrc/torch_bindings.cpp
View file @
83f2f396
...
...
@@ -229,8 +229,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor suffix_lse) -> ()"
);
ops
.
impl
(
"merge_attn_states"
,
torch
::
kCUDA
,
&
merge_attn_states
);
#ifndef USE_ROCM
ops
.
def
(
"convert_vertical_slash_indexes("
" Tensor! block_count, Tensor! block_offset, "
...
...
@@ -253,7 +251,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" bool causal) -> ()"
);
ops
.
impl
(
"convert_vertical_slash_indexes_mergehead"
,
torch
::
kCUDA
,
&
convert_vertical_slash_indexes_mergehead
);
#endif
// Activation ops
// Activation function used in SwiGLU.
...
...
vllm/_custom_ops.py
View file @
83f2f396
...
...
@@ -2162,9 +2162,22 @@ def gather_cache(src_cache: torch.Tensor,
block_table
:
torch
.
Tensor
,
cu_seq_lens
:
torch
.
Tensor
,
batch_size
:
int
,
seq_starts
:
Optional
[
torch
.
Tensor
]
=
None
)
->
None
:
torch
.
ops
.
_C_cache_ops
.
gather_cache
(
src_cache
,
dst
,
block_table
,
cu_seq_lens
,
batch_size
,
seq_starts
)
seq_starts
:
Optional
[
torch
.
Tensor
]
=
None
,
kv_dtype
=
"auto"
,
scale
:
float
=
1.0
,
)
->
None
:
#支持"kv cache fp8" 临时方案,带dtype的gather_cache在vllm0.10后会实现。
if
kv_dtype
==
"fp8"
or
kv_dtype
==
"fp8_e5m2"
or
kv_dtype
==
"fp8_e4m3"
:
dst_fp8
=
torch
.
empty
(
dst
.
shape
,
dtype
=
torch
.
uint8
,
device
=
dst
.
device
)
#convert_fp8(dst_fp8, dst, scale, kv_dtype)
torch
.
ops
.
_C_cache_ops
.
gather_cache
(
src_cache
,
dst_fp8
,
block_table
,
cu_seq_lens
,
batch_size
,
seq_starts
)
#dst_fp8->bf16
convert_fp8
(
dst
,
dst_fp8
,
scale
,
kv_dtype
)
else
:
torch
.
ops
.
_C_cache_ops
.
gather_cache
(
src_cache
,
dst
,
block_table
,
cu_seq_lens
,
batch_size
,
seq_starts
)
def
get_device_attribute
(
attribute
:
int
,
device
:
int
)
->
int
:
...
...
vllm/attention/backends/flashmla.py
View file @
83f2f396
...
...
@@ -211,9 +211,10 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
"FlashMLAImpl"
)
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
):
if
self
.
kv_cache_dtype
!=
"fp8"
:
raise
NotImplementedError
(
"FlashMLA with other KV cache not yet supported"
)
if
kv_cache_dtype
==
"fp8"
or
kv_cache_dtype
==
"fp8_e4m3"
or
kv_cache_dtype
==
"fp8_e5m2"
:
return
raise
NotImplementedError
(
"FlashMLA with other KV cache not yet supported"
)
def
_forward_decode
(
self
,
...
...
vllm/attention/backends/mla/common.py
View file @
83f2f396
...
...
@@ -1179,6 +1179,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
q
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
attn_metadata
:
MLACommonMetadata
,
kv_scale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
),
):
prefill_metadata
=
attn_metadata
.
prefill_metadata
assert
prefill_metadata
is
not
None
...
...
@@ -1207,6 +1208,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
cu_seq_lens
=
prefill_metadata
.
context_chunk_cu_seq_lens
[
i
],
batch_size
=
prefill_metadata
.
num_prefills
,
seq_starts
=
prefill_metadata
.
context_chunk_starts
[
i
],
kv_dtype
=
self
.
kv_cache_dtype
,
scale
=
kv_scale
,
)
kv_c_normed
=
workspace
[:
toks
]
\
...
...
@@ -1262,6 +1265,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
k_pe
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
attn_metadata
:
MLACommonMetadata
,
kv_scale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
),
)
->
torch
.
Tensor
:
prefill_metadata
=
attn_metadata
.
prefill_metadata
...
...
@@ -1297,7 +1301,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
# ROCm flash_attn_varlen_func will return 3 objects instead of 2
suffix_output
,
suffix_lse
=
output
context_output
,
context_lse
=
self
.
_compute_prefill_context
(
\
q
,
kv_c_and_k_pe_cache
,
attn_metadata
)
q
,
kv_c_and_k_pe_cache
,
attn_metadata
,
kv_scale
)
output
=
torch
.
empty_like
(
suffix_output
)
merge_attn_states
(
...
...
@@ -1387,7 +1391,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
if
has_prefill
:
output
[:
num_prefill_tokens
]
=
self
.
_forward_prefill
(
prefill_q
,
prefill_k_c_normed
,
prefill_k_pe
,
kv_cache
,
attn_metadata
)
attn_metadata
,
kv_scale
=
layer
.
_k_scale
)
if
has_decode
:
decode_q_nope
,
decode_q_pe
=
decode_q
.
split
(
...
...
vllm/attention/layer.py
View file @
83f2f396
...
...
@@ -24,6 +24,11 @@ from vllm.platforms import _Backend, current_platform
from
vllm.utils
import
direct_register_custom_op
from
vllm.v1.attention.backends.utils
import
validate_kv_sharing_target
USE_XFORMERS_OPS
=
None
try
:
tag_cudagraph_unsafe
=
(
torch
.
_C
.
Tag
.
cudagraph_unsafe
,
)
except
AttributeError
:
tag_cudagraph_unsafe
=
()
# type: ignore[assignment]
class
Attention
(
nn
.
Module
):
"""Attention layer.
...
...
@@ -204,9 +209,12 @@ class Attention(nn.Module):
`vllm.forward_context.get_forward_context().attn_metadata`.
"""
if
self
.
calculate_kv_scales
:
attn_metadata
=
get_forward_context
().
attn_metadata
if
attn_metadata
.
enable_kv_scales_calculation
:
self
.
calc_kv_scales
(
query
,
key
,
value
)
# attn_metadata = get_forward_context().attn_metadata
# #if (attn_metadata is not None and getattr(attn_metadata, "enable_kv_scales_calculation", False)):
# if key is not None and value is not None:
# self.calc_kv_scales(query, key, value)
torch
.
ops
.
vllm
.
maybe_calc_kv_scales
(
query
,
key
,
value
,
self
.
layer_name
)
if
self
.
use_output
:
output_shape
=
(
output_shape
if
output_shape
is
not
None
else
query
.
shape
)
...
...
@@ -394,7 +402,42 @@ def maybe_save_kv_layer_to_connector(
assert
isinstance
(
attn_metadata
,
dict
)
connector
.
save_kv_layer
(
layer_name
,
kv_cache_layer
,
attn_metadata
[
layer_name
])
def
maybe_calc_kv_scales
(
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
layer_name
:
str
,
)
->
None
:
forward_context
:
ForwardContext
=
get_forward_context
()
attn_metadata
=
forward_context
.
attn_metadata
if
isinstance
(
attn_metadata
,
dict
):
attn_metadata
=
attn_metadata
[
layer_name
]
# if attn_metadata is None or not getattr(
# attn_metadata, 'enable_kv_scales_calculation', False):
# return
self
=
forward_context
.
no_compile_layers
[
layer_name
]
self
.
calc_kv_scales
(
query
,
key
,
value
)
def
maybe_calc_kv_scales_fake
(
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
layer_name
:
str
,
)
->
None
:
return
direct_register_custom_op
(
op_name
=
"maybe_calc_kv_scales"
,
op_func
=
maybe_calc_kv_scales
,
mutates_args
=
[],
fake_impl
=
maybe_calc_kv_scales_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
tags
=
tag_cudagraph_unsafe
,)
def
unified_attention
(
query
:
torch
.
Tensor
,
...
...
vllm/attention/ops/flashmla.py
View file @
83f2f396
...
...
@@ -99,7 +99,8 @@ def flash_mla_with_kvcache(
if
softmax_scale
is
None
:
softmax_scale
=
q
.
shape
[
-
1
]
**
(
-
0.5
)
if
current_platform
.
is_rocm
():
if
kv_cache_dtype
==
"fp8"
:
if
kv_cache_dtype
==
"fp8"
or
kv_cache_dtype
==
"fp8_e4m3"
or
kv_cache_dtype
==
"fp8_e5m2"
:
kv_dtype
=
"fp8_e4m3"
if
kv_cache_dtype
==
"fp8"
else
kv_cache_dtype
out
,
softmax_lse
=
flash_mla_cuda
.
fwd_kvcache_quantization_mla
(
q
,
k_cache
,
...
...
@@ -112,7 +113,7 @@ def flash_mla_with_kvcache(
tile_scheduler_metadata
,
num_splits
,
k_scale
,
"fp8_e4m3"
,
kv_dtype
,
)
return
out
,
softmax_lse
out
,
softmax_lse
=
flash_mla_cuda
.
fwd_kvcache_mla
(
...
...
vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
View file @
83f2f396
...
...
@@ -55,7 +55,7 @@ class ReqMeta:
slot_mapping
=
slot_mapping
,
)
@
dataclass
class
P2pNcclConnectorMetadata
(
KVConnectorMetadata
):
requests
:
list
[
ReqMeta
]
...
...
@@ -95,6 +95,12 @@ class P2pNcclConnector(KVConnectorBase_V1):
hostname
=
""
,
port_offset
=
self
.
_rank
,
)
if
role
==
KVConnectorRole
.
WORKER
else
None
self
.
parallel_config
=
vllm_config
.
parallel_config
self
.
model_config
=
vllm_config
.
model_config
self
.
total_num_hidden_layers
=
getattr
(
self
.
model_config
.
hf_text_config
,
"num_hidden_layers"
,
0
)
self
.
pp_size
=
self
.
parallel_config
.
pipeline_parallel_size
# ==============================
# Worker-side methods
...
...
@@ -285,13 +291,35 @@ class P2pNcclConnector(KVConnectorBase_V1):
ip
,
port
=
self
.
parse_request_id
(
request_id
,
True
)
remote_address
=
ip
+
":"
+
str
(
port
+
self
.
_rank
)
kv_cache
=
extract_kv_from_layer
(
kv_layer
,
request
.
slot_mapping
)
self
.
p2p_nccl_engine
.
send_tensor
(
request_id
+
"#"
+
layer_name
,
pp_rank
=
(
self
.
parallel_config
.
rank
//
self
.
parallel_config
.
tensor_parallel_size
)
%
self
.
parallel_config
.
pipeline_parallel_size
if
(
self
.
pp_size
==
1
):
self
.
p2p_nccl_engine
.
send_tensor
(
request_id
+
"#"
+
layer_name
,
kv_cache
,
remote_address
)
elif
(
self
.
pp_size
==
2
):
if
(
pp_rank
==
0
):
self
.
p2p_nccl_engine
.
send_tensor
(
request_id
+
"#"
+
layer_name
,
kv_cache
,
remote_address
)
self
.
p2p_nccl_engine
.
send_tensor
(
request_id
+
"#"
+
layer_name
,
kv_cache
,
ip
+
":"
+
str
(
port
+
self
.
_rank
+
4
))
else
:
self
.
p2p_nccl_engine
.
send_tensor
(
request_id
+
"#"
+
layer_name
,
kv_cache
,
remote_address
)
self
.
p2p_nccl_engine
.
send_tensor
(
request_id
+
"#"
+
layer_name
,
kv_cache
,
ip
+
":"
+
str
(
port
+
self
.
_rank
-
4
))
elif
(
self
.
pp_size
==
8
):
for
i
in
range
(
8
):
self
.
p2p_nccl_engine
.
send_tensor
(
request_id
+
"#"
+
layer_name
,
kv_cache
,
ip
+
":"
+
str
(
port
+
i
))
else
:
print
(
"Error: only suppprt pp1 pp2 pp8!!!!!!"
)
def
wait_for_save
(
self
):
if
self
.
is_producer
:
assert
self
.
p2p_nccl_engine
is
not
None
self
.
p2p_nccl_engine
.
wait_for_sent
()
pass
# if self.is_producer:
# assert self.p2p_nccl_engine is not None
# self.p2p_nccl_engine.wait_for_sent()
def
get_finished
(
self
,
finished_req_ids
:
set
[
str
],
...
...
vllm/distributed/kv_transfer/kv_connector/v1/p2p/tensor_memory_pool.py
View file @
83f2f396
...
...
@@ -63,7 +63,7 @@ class TensorMemoryPool:
than min_block_size
"""
def
__init__
(
self
,
max_block_size
:
int
,
min_block_size
:
int
=
5
12
):
def
__init__
(
self
,
max_block_size
:
int
,
min_block_size
:
int
=
12
8
):
if
max_block_size
<=
0
or
min_block_size
<=
0
:
raise
ValueError
(
"Block sizes must be positive"
)
if
max_block_size
<
min_block_size
:
...
...
vllm/envs.py
View file @
83f2f396
...
...
@@ -164,10 +164,11 @@ if TYPE_CHECKING:
VLLM_USE_FLASH_ATTN_PA
:
bool
=
False
VLLM_USE_APEX_RN
:
bool
=
False
VLLM_USE_GLOBAL_CACHE13
:
bool
=
False
VLLM_USE_LIGHT_OP
:
bool
=
False
VLLM_USE_TRITON_CAT
:
bool
=
False
USE_FUSED_RMS_QUANT
:
bool
=
False
VLLM_USE_LIGHTOP
:
bool
=
False
VLLM_USE_OPT_CAT
:
bool
=
False
VLLM_USE_MERGE_ATTN_STATES_OPT
:
bool
=
False
USE_FUSED_RMS_QUANT
:
bool
=
False
USE_FUSED_SILU_MUL_QUANT
:
bool
=
False
VLLM_USE_MORI_EP
:
bool
=
False
def
get_default_cache_root
():
...
...
@@ -1050,7 +1051,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# If there are any problems during use, use environment variables
# to restore the default usage.
"VLLM_HAS_CONTEXT_DEFAULT"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_HAS_CONTEXT_DEFAULT"
,
"
0
"
))),
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_HAS_CONTEXT_DEFAULT"
,
"
1
"
))),
# If set, vLLM will transpose weight to use nn layout
"VLLM_USE_NN"
:
...
...
@@ -1094,15 +1095,15 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_GLOBAL_CACHE13"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_GLOBAL_CACHE13"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
# vLLM will use
global cache for moe
"VLLM_USE_LIGHT
_
OP"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_LIGHT
_
OP"
,
"
Tru
e"
).
lower
()
in
# vLLM will use
lightop for deepseek-v3
"VLLM_USE_LIGHTOP"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_LIGHTOP"
,
"
Fals
e"
).
lower
()
in
(
"true"
,
"1"
)),
# vLLM will use
global cache for moe
"VLLM_USE_
TRITON
_CAT"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_
TRITON
_CAT"
,
"
Tru
e"
).
lower
()
in
# vLLM will use
opt cat for deepseek-v3
"VLLM_USE_
OPT
_CAT"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_
OPT
_CAT"
,
"
Fals
e"
).
lower
()
in
(
"true"
,
"1"
)),
# vLLM will use opt merge_aatn_states,not triton
# vLLM will use the optimized merge_attn_states implementation,
not the Triton one
"VLLM_USE_MERGE_ATTN_STATES_OPT"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_MERGE_ATTN_STATES_OPT"
,
"True"
).
lower
()
in
(
"true"
,
"1"
)),
...
...
@@ -1111,6 +1112,16 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda
:
(
os
.
getenv
(
'USE_FUSED_RMS_QUANT'
,
'0'
).
lower
()
in
(
"true"
,
"1"
)),
# vllm will use lightop's moe_sum fusion operator for deepseek
"VLLM_USE_DEEPSEEK_MOE_SUM_MUL_ADD"
:
lambda
:
(
os
.
getenv
(
'VLLM_USE_DEEPSEEK_MOE_SUM_MUL_ADD'
,
'True'
).
lower
()
in
(
"true"
,
"1"
)),
# vllm will use silu_mul_quant fused op
"USE_FUSED_SILU_MUL_QUANT"
:
lambda
:
(
os
.
getenv
(
'USE_FUSED_SILU_MUL_QUANT'
,
'0'
).
lower
()
in
(
"true"
,
"1"
)),
# vLLM will use all_to_all ep mode
"VLLM_USE_MORI_EP"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_MORI_EP"
,
"True"
).
lower
()
in
...
...
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
83f2f396
...
...
@@ -40,12 +40,19 @@ from vllm.model_executor.layers.fused_moe.utils import (
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils
import
direct_register_custom_op
from
lightop
import
op
# from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
os
.
environ
[
'DPSK_FP16_QUICK'
]
=
os
.
environ
.
get
(
'DPSK_FP16_QUICK'
,
'0'
)
dpsk_fp16_quick
=
os
.
environ
.
get
(
'DPSK_FP16_QUICK'
)
==
'1'
logger
=
init_logger
(
__name__
)
if
envs
.
VLLM_USE_GLOBAL_CACHE13
:
moe_cache_singleton
=
None
def
get_moe_cache
(
top_k_num
,
N
,
K
,
device
,
dtype
):
global
moe_cache_singleton
if
moe_cache_singleton
is
None
:
...
...
@@ -1258,14 +1265,14 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
num_local_tokens
:
Optional
[
torch
.
Tensor
]
=
None
,
true_bs
:
Optional
[
in
t
]
=
None
,
)
->
None
:
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
routed_scaling_factor
:
Optional
[
floa
t
]
=
None
)
->
None
:
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
True
,
activation
,
apply_router_weight_on_input
,
use_fp8_w8a8
,
use_int8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
use_int4_w4a8
,
per_channel_quant
,
global_num_experts
,
expert_map
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
,
use_nn_moe
,
num_local_tokens
,
true_bs
)
block_shape
,
use_nn_moe
,
shared_output
,
routed_scaling_factor
)
def
inplace_fused_experts_fake
(
...
...
@@ -1292,8 +1299,8 @@ def inplace_fused_experts_fake(
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
num_local_tokens
:
Optional
[
torch
.
Tensor
]
=
None
,
true_bs
:
Optional
[
in
t
]
=
None
,
)
->
None
:
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
routed_scaling_factor
:
Optional
[
floa
t
]
=
None
)
->
None
:
pass
...
...
@@ -1330,15 +1337,15 @@ def outplace_fused_experts(
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
num_local_tokens
:
Optional
[
torch
.
Tensor
]
=
None
,
true_bs
:
Optional
[
in
t
]
=
None
,
)
->
torch
.
Tensor
:
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
routed_scaling_factor
:
Optional
[
floa
t
]
=
None
)
->
torch
.
Tensor
:
return
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
False
,
activation
,
apply_router_weight_on_input
,
use_fp8_w8a8
,
use_int8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
use_int4_w4a8
,
per_channel_quant
,
global_num_experts
,
expert_map
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
,
use_nn_moe
,
num_local_tokens
,
true_bs
)
block_shape
,
use_nn_moe
,
shared_output
,
routed_scaling_factor
)
def
outplace_fused_experts_fake
(
...
...
@@ -1364,8 +1371,8 @@ def outplace_fused_experts_fake(
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
num_local_tokens
:
Optional
[
torch
.
Tensor
]
=
None
,
true_bs
:
Optional
[
in
t
]
=
None
,
)
->
torch
.
Tensor
:
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
routed_scaling_factor
:
Optional
[
floa
t
]
=
None
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
hidden_states
)
...
...
@@ -1423,8 +1430,8 @@ def fused_experts(
allow_deep_gemm
:
bool
=
False
,
allow_cutlass_block_scaled_grouped_gemm
:
bool
=
False
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
num_local_tokens
:
Optional
[
torch
.
Tensor
]
=
None
,
true_bs
:
Optional
[
in
t
]
=
None
,
)
->
torch
.
Tensor
:
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
routed_scaling_factor
:
Optional
[
floa
t
]
=
None
)
->
torch
.
Tensor
:
# For now, disable DeepGemm for small N (<= 512) until better
# permute/unpermute ops are available.
N
=
w1
.
size
(
1
)
...
...
@@ -1483,8 +1490,8 @@ def fused_experts(
a2_scale
=
a2_scale
,
block_shape
=
block_shape
,
use_nn_moe
=
use_nn_moe
,
num_local_tokens
=
num_local_tokens
,
true_bs
=
true_bs
)
shared_output
=
shared_output
,
routed_scaling_factor
=
routed_scaling_factor
)
def
fused_experts_impl
(
...
...
@@ -1512,8 +1519,8 @@ def fused_experts_impl(
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
num_local_tokens
:
Optional
[
torch
.
Tensor
]
=
None
,
true_bs
:
Optional
[
in
t
]
=
None
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
routed_scaling_factor
:
Optional
[
floa
t
]
=
None
,
)
->
torch
.
Tensor
:
num_tokens
=
hidden_states
.
size
(
0
)
if
use_nn_moe
:
...
...
@@ -1559,8 +1566,8 @@ def fused_experts_impl(
a2_scale
=
a2_scale
,
block_shape
=
block_shape
,
use_nn_moe
=
False
,
num_local_tokens
=
num_local_tokens
,
true_bs
=
true_bs
,
shared_output
=
shared_output
,
routed_scaling_factor
=
routed_scaling_factor
)
elif
use_int4_w4a8
is
True
:
return
fused_experts_impl_w4a8
(
hidden_states
=
hidden_states
,
...
...
@@ -1587,7 +1594,9 @@ def fused_experts_impl(
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
block_shape
=
block_shape
,
use_nn_moe
=
False
use_nn_moe
=
False
,
shared_output
=
shared_output
,
routed_scaling_factor
=
routed_scaling_factor
)
#
...
...
@@ -1760,8 +1769,28 @@ def fused_experts_impl(
block_shape
=
block_shape
,
use_nn_moe
=
use_nn_moe
)
ops
.
moe_sum
(
intermediate_cache3
.
view
(
*
intermediate_cache3
.
size
()),
out_hidden_states
[
begin_chunk_idx
:
end_chunk_idx
])
if
envs
.
VLLM_USE_LIGHTOP
and
not
dpsk_fp16_quick
:
from
lightop
import
op
as
op
op
.
moe_sum
(
input
=
intermediate_cache3
.
view
(
*
intermediate_cache3
.
size
()),
output
=
out_hidden_states
[
begin_chunk_idx
:
end_chunk_idx
],
bias
=
shared_output
[
begin_chunk_idx
:
end_chunk_idx
],
expert_mask
=
None
,
num_local_tokens
=
None
,
factor
=
routed_scaling_factor
)
# else:
# ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()),
# out_hidden_states[begin_chunk_idx:end_chunk_idx])
# if shared_output is not None:
# if hidden_states.dtype != torch.float16 or dpsk_fp16_quick:
# out_hidden_states[begin_chunk_idx:end_chunk_idx] = out_hidden_states[begin_chunk_idx:end_chunk_idx] * routed_scaling_factor + shared_output[begin_chunk_idx:end_chunk_idx]
# else:
# # Fix FP16 overflow
# # See DeepseekV2DecoderLayer for more details.
# out_hidden_states[begin_chunk_idx:end_chunk_idx] + shared_output[begin_chunk_idx:end_chunk_idx] * (1. / routed_scaling_factor)
# else:
# if hidden_states.dtype != torch.float16 or dpsk_fp16_quick:
# ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()),
# out_hidden_states[begin_chunk_idx:end_chunk_idx]) * routed_scaling_factor
else
:
ops
.
moe_sum
(
intermediate_cache3
.
view
(
*
intermediate_cache3
.
size
()),
out_hidden_states
[
begin_chunk_idx
:
end_chunk_idx
])
return
out_hidden_states
...
...
@@ -1795,6 +1824,8 @@ def fused_moe(
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
)
->
torch
.
Tensor
:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
...
...
@@ -1880,7 +1911,9 @@ def fused_moe(
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
block_shape
=
block_shape
,
use_nn_moe
=
use_nn_moe
)
use_nn_moe
=
use_nn_moe
,
shared_output
=
shared_output
,
routed_scaling_factor
=
routed_scaling_factor
)
class
TritonExperts
(
mk
.
FusedMoEPermuteExpertsUnpermute
):
...
...
@@ -2097,4 +2130,4 @@ def modular_triton_fused_moe(
per_act_token_quant
=
per_act_token_quant
,
block_shape
=
block_shape
,
),
)
)
\ No newline at end of file
vllm/model_executor/layers/fused_moe/layer.py
View file @
83f2f396
...
...
@@ -42,7 +42,7 @@ from vllm.platforms.interface import CpuArchEnum
from
vllm.utils
import
direct_register_custom_op
,
has_deep_ep
,
has_pplx
from
vllm
import
_custom_ops
as
ops
from
lightop
import
op
if
current_platform
.
is_cuda_alike
():
from
.fused_batched_moe
import
BatchedTritonExperts
...
...
@@ -222,6 +222,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
expert_load_view
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
raise
NotImplementedError
...
...
@@ -373,6 +374,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_load_view
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
use_fused_gate
:
Optional
[
bool
]
=
False
,
...
...
@@ -397,6 +399,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
e_score_correction_bias
=
e_score_correction_bias
,
activation
=
activation
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
shared_output
=
shared_output
,
use_nn_moe
=
use_nn_moe
,
routed_scaling_factor
=
routed_scaling_factor
,
use_fused_gate
=
use_fused_gate
)
...
...
@@ -418,6 +421,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
use_fused_gate
:
Optional
[
bool
]
=
False
,
...
...
@@ -460,7 +464,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
apply_router_weight_on_input
=
apply_router_weight_on_input
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
use_nn_moe
=
use_nn_moe
use_nn_moe
=
use_nn_moe
,
shared_output
=
shared_output
,
routed_scaling_factor
=
routed_scaling_factor
)
def
forward_cpu
(
...
...
@@ -1285,7 +1291,8 @@ class FusedMoE(torch.nn.Module):
assert
topk_group
is
not
None
assert
num_expert_group
is
not
None
if
use_fused_gate
:
if
envs
.
VLLM_USE_LIGHT_OP
:
if
envs
.
VLLM_USE_LIGHTOP
:
from
lightop
import
op
as
op
topk_weights
,
topk_ids
=
op
.
moe_fused_gate
(
router_logits
,
e_score_correction_bias
,
...
...
@@ -1434,14 +1441,15 @@ class FusedMoE(torch.nn.Module):
return
tensor_model_parallel_all_reduce
(
final_hidden_states
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
):
router_logits
:
torch
.
Tensor
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
):
# TODO: Once the OOM issue for the TPU backend is resolved, we will
# switch to using the moe_forward custom op.
if
current_platform
.
is_tpu
():
return
self
.
forward_impl
(
hidden_states
,
router_logits
)
else
:
return
torch
.
ops
.
vllm
.
moe_forward
(
hidden_states
,
router_logits
,
self
.
layer_name
)
self
.
layer_name
,
shared_output
)
def
forward_impl_chunked
(
self
,
full_hidden_states
:
torch
.
Tensor
,
full_router_logits
:
torch
.
Tensor
):
...
...
@@ -1520,7 +1528,8 @@ class FusedMoE(torch.nn.Module):
return
full_final_hidden_states
def
forward_impl
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
):
router_logits
:
torch
.
Tensor
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
):
assert
self
.
quant_method
is
not
None
if
(
self
.
moe_parallel_config
.
use_pplx_kernels
or
self
.
moe_parallel_config
.
use_deepep_ll_kernels
):
...
...
@@ -1554,6 +1563,7 @@ class FusedMoE(torch.nn.Module):
expert_load_view
=
self
.
expert_load_view
,
logical_to_physical_map
=
self
.
logical_to_physical_map
,
logical_replica_count
=
self
.
logical_replica_count
,
shared_output
=
shared_output
,
use_nn_moe
=
self
.
use_nn_moe
,
routed_scaling_factor
=
self
.
routed_scaling_factor
,
use_fused_gate
=
self
.
use_fused_gate
...
...
@@ -1626,17 +1636,17 @@ class FusedMoE(torch.nn.Module):
return
s
def
moe_forward
(
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
layer_name
:
str
)
->
torch
.
Tensor
:
def
moe_forward
(
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
layer_name
:
str
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
forward_context
:
ForwardContext
=
get_forward_context
()
self
=
forward_context
.
no_compile_layers
[
layer_name
]
assert
self
.
quant_method
is
not
None
return
self
.
forward_impl
(
hidden_states
,
router_logits
)
return
self
.
forward_impl
(
hidden_states
,
router_logits
,
shared_output
)
def
moe_forward_fake
(
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
layer_name
:
str
)
->
torch
.
Tensor
:
layer_name
:
str
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
hidden_states
)
...
...
@@ -1647,4 +1657,4 @@ direct_register_custom_op(
fake_impl
=
moe_forward_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
tags
=
(
torch
.
Tag
.
needs_fixed_stride_order
,
),
)
)
\ No newline at end of file
vllm/model_executor/layers/fused_moe/moe_align_block_size.py
View file @
83f2f396
...
...
@@ -9,7 +9,6 @@ from vllm.triton_utils import tl, triton
from
vllm.utils
import
cdiv
,
round_up
import
vllm.envs
as
envs
from
lightop
import
op
@
triton
.
jit
...
...
@@ -153,7 +152,7 @@ def moe_align_block_size(
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
pad_sorted_ids
:
bool
=
False
,
num_token
:
Optional
[
int
]
=
None
,
num_local_tokens
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_mask
:
Optional
[
torch
.
Tensor
]
=
None
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Aligns the token distribution across experts to be compatible with block
...
...
@@ -233,12 +232,16 @@ def moe_align_block_size(
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
)
if
envs
.
VLLM_USE_LIGHT_OP
:
if
envs
.
VLLM_USE_LIGHTOP
or
expert_mask
is
not
None
:
from
lightop
import
op
as
op
op
.
moe_align_block_size
(
topk_ids
,
num_experts
,
block_size
,
sorted_ids
,
expert_ids
,
num_tokens_post_pad
,
expert_map
,
None
,
num_local_tokens
)
expert_ids
,
num_tokens_post_pad
,
expert_map
=
expert_map
,
expert_mask
=
expert_mask
,
num_local_tokens
=
None
)
else
:
ops
.
moe_align_block_size
(
topk_ids
,
num_experts
,
block_size
,
sorted_ids
,
expert_ids
,
num_tokens_post_pad
)
expert_ids
,
num_tokens_post_pad
)
if
expert_map
is
not
None
:
expert_ids
=
expert_map
[
expert_ids
]
...
...
vllm/model_executor/layers/layernorm.py
View file @
83f2f396
...
...
@@ -10,6 +10,7 @@ import vllm.envs as envs
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.platforms
import
current_platform
from
vllm.utils
import
direct_register_custom_op
def
is_rocm_aiter_rmsnorm_enabled
()
->
bool
:
...
...
@@ -39,6 +40,33 @@ def rms_norm(x: torch.Tensor, weight: torch.Tensor,
return
out
def rms_norm_opt(x: torch.Tensor, weight: torch.Tensor,
                 variance_epsilon: float) -> torch.Tensor:
    """Run lightop's fused RMS-norm kernel on ``x`` and return the result.

    Args:
        x: Input tensor to normalize. The kernel name suggests a contiguous
            layout is expected -- TODO confirm against lightop.
        weight: Weight tensor forwarded unchanged to the kernel.
        variance_epsilon: Epsilon value forwarded unchanged to the kernel.

    Returns:
        A freshly allocated tensor created with ``torch.empty_like(x)`` that
        the kernel is handed as its output buffer.
    """
    # Imported lazily so lightop is only required when this op actually runs.
    # The previous unused `vllm._custom_ops` import was dropped: nothing in
    # this function referenced it.
    from lightop import fused_rms_norm_contiguous

    out = torch.empty_like(x)
    fused_rms_norm_contiguous(out, x, weight, variance_epsilon)
    return out
def rms_norm_opt_fake(x: torch.Tensor, weight: torch.Tensor,
                      variance_epsilon: float) -> torch.Tensor:
    """Fake implementation of ``rms_norm_opt``.

    Performs no normalization: it only produces an uninitialized tensor
    shaped like ``x``. ``weight`` and ``variance_epsilon`` are accepted to
    mirror the real op's signature and are otherwise ignored.
    """
    placeholder = torch.empty_like(x)
    return placeholder
# Register rms_norm_opt as a custom op under the name "rms_norm_opt", with
# rms_norm_opt_fake supplied as its fake implementation.
# mutates_args is empty: rms_norm_opt writes into a tensor it allocates
# itself and returns, rather than into one of its arguments.
direct_register_custom_op
(
op_name
=
"rms_norm_opt"
,
op_func
=
rms_norm_opt
,
mutates_args
=
[],
fake_impl
=
rms_norm_opt_fake
,
)
def
fused_add_rms_norm
(
x
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
variance_epsilon
:
float
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
...
...
@@ -187,6 +215,23 @@ class RMSNorm(CustomOp):
else
:
return
norm_func
(
x
,
self
.
weight
.
data
,
self
.
variance_epsilon
)
def forward_cuda_opt(
    self,
    x: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
    """RMS-norm forward that routes the no-residual case to ``rms_norm_opt``.

    Args:
        x: Input tensor.
        residual: Optional residual tensor. When given, the fused
            add + norm function chosen by ``dispatch_cuda_rmsnorm_func``
            is used.

    Returns:
        Whatever the selected implementation returns: the result of
        ``forward_native`` when ``variance_size_override`` is set, the
        dispatched function's result when ``residual`` is provided, and
        otherwise the tensor produced by ``torch.ops.vllm.rms_norm_opt``.
    """
    # A variance-size override is handled by the native implementation only.
    if self.variance_size_override is not None:
        return self.forward_native(x, residual)

    has_residual = residual is not None
    # Dispatch is performed unconditionally, as before, even though its
    # result is only consumed on the residual path.
    fused_norm = dispatch_cuda_rmsnorm_func(has_residual)

    if not has_residual:
        return torch.ops.vllm.rms_norm_opt(x, self.weight.data,
                                           self.variance_epsilon)
    return fused_norm(x, residual, self.weight.data, self.variance_epsilon)
def
forward_apex
(
self
,
x
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/linear.py
View file @
83f2f396
...
...
@@ -38,7 +38,13 @@ if envs.USE_FUSED_RMS_QUANT:
from
lmslim.quantize.quant_ops
import
lm_faster_rmsquant
except
Exception
as
e
:
print
(
f
"Error: Import fused rmsquant error:
{
e
}
"
)
# Optionally pull in the fused SiLU-mul-quant kernel from lmslim when the
# USE_FUSED_SILU_MUL_QUANT flag is enabled.
if envs.USE_FUSED_SILU_MUL_QUANT:
    try:
        # from lightop import fuse_silu_mul_quant
        from lmslim.quantize.quant_ops import lm_fuse_silu_mul_quant
    except Exception as e:
        # Best-effort import: report the failure and keep loading the module,
        # mirroring how the fused rmsquant import is handled.
        # NOTE(review): on failure `lm_fuse_silu_mul_quant` stays undefined,
        # so any later call to it would raise NameError.
        print(f"Error: Import fused silu_mul_quant error: {e}")
logger
=
init_logger
(
__name__
)
WEIGHT_LOADER_V2_SUPPORTED
=
[
...
...
@@ -1516,7 +1522,8 @@ class RowParallelLinear(LinearBase):
param
.
load_row_parallel_weight
(
loaded_weight
=
loaded_weight
)
def
forward
(
self
,
input_
self
,
input_
,
use_fused_silu_mul_quant
:
Optional
[
bool
]
=
False
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
Optional
[
Parameter
]]]:
if
self
.
input_is_parallel
:
input_parallel
=
input_
...
...
@@ -1531,9 +1538,18 @@ class RowParallelLinear(LinearBase):
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case)
bias_
=
None
if
(
self
.
tp_rank
>
0
or
self
.
skip_bias_add
)
else
self
.
bias
output_parallel
=
self
.
quant_method
.
apply
(
self
,
input_parallel
,
bias
=
bias_
)
if
use_fused_silu_mul_quant
:
xq
,
xs
=
lm_fuse_silu_mul_quant
(
input_parallel
)
silu_quant_args
=
[
xq
,
xs
]
output_parallel
=
self
.
quant_method
.
apply
(
self
,
input_parallel
,
bias
=
bias_
,
silu_quant_args
=
silu_quant_args
)
else
:
output_parallel
=
self
.
quant_method
.
apply
(
self
,
input_parallel
,
bias
=
bias_
)
if
self
.
reduce_results
and
self
.
tp_size
>
1
:
if
envs
.
VLLM_ENABLE_TBO
:
output
=
self
.
tbo_all_reduce
(
output_parallel
)
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
View file @
83f2f396
...
...
@@ -666,7 +666,8 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
):
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
input_quant_args
:
Optional
[
list
[
torch
.
Tensor
]]
=
None
):
"""
Use the output of create_weights and the CompressedTensorsScheme
associated with the layer to apply the forward pass with the
...
...
@@ -677,7 +678,7 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
scheme
=
layer
.
scheme
if
scheme
is
None
:
raise
ValueError
(
"A scheme must be defined for each layer"
)
return
scheme
.
apply_weights
(
layer
,
x
,
bias
=
bias
)
return
scheme
.
apply_weights
(
layer
,
x
,
bias
=
bias
,
input_quant_args
=
input_quant_args
)
class
CompressedTensorsKVCacheMethod
(
BaseKVCacheMethod
):
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
83f2f396
...
...
@@ -1097,7 +1097,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
expert_load_view
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
if
enable_eplb
:
raise
NotImplementedError
(
...
...
@@ -1137,7 +1137,9 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
w2_scale
=
layer
.
w2_weight_scale
,
a1_scale
=
layer
.
w13_input_scale
,
a2_scale
=
layer
.
w2_input_scale
,
use_nn_moe
=
False
)
use_nn_moe
=
False
,
shared_output
=
shared_output
,
routed_scaling_factor
=
routed_scaling_factor
)
class
CompressedTensorsWNA16MarlinMoEMethod
(
CompressedTensorsMoEMethod
):
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
View file @
83f2f396
...
...
@@ -111,7 +111,8 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
self
.
kernel
.
process_weights_after_loading
(
layer
)
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
],
input_quant_args
:
Optional
[
list
[
torch
.
Tensor
]]
=
None
)
->
torch
.
Tensor
:
# return self.kernel.apply_weights(layer, x, bias)
...
...
@@ -122,5 +123,5 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
input_zero_point
=
layer
.
input_zero_point
,
azp_adj
=
layer
.
azp_adj
,
bias
=
bias
,
w8a8_strategy
=
self
.
w8a8_strategy
)
w8a8_strategy
=
self
.
w8a8_strategy
,
input_quant_args
=
input_quant_args
)
\ No newline at end of file
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment