Commit 83f2f396 authored by 王敏's avatar 王敏
Browse files

同步0.9.2-ds分支代码

parents d2e57a90 20605c42
...@@ -965,6 +965,22 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, ...@@ -965,6 +965,22 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, CALL_CONVERT_FP8(__nv_bfloat16, uint8_t,
vllm::Fp8KVCacheDataType::kFp8E4M3); vllm::Fp8KVCacheDataType::kFp8E4M3);
} }
} else if (kv_cache_dtype == "fp8_e5m2") {
if (src_cache.dtype() == at::ScalarType::Float) {
CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kFp8E5M2);
} else if (src_cache.dtype() == at::ScalarType::Half) {
CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kFp8E5M2);
} else if (src_cache.dtype() == at::ScalarType::BFloat16) {
CALL_CONVERT_FP8(uint8_t, __nv_bfloat16,
vllm::Fp8KVCacheDataType::kFp8E5M2);
} else if (dst_cache.dtype() == at::ScalarType::Float) {
CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2);
} else if (dst_cache.dtype() == at::ScalarType::Half) {
CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2);
} else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
CALL_CONVERT_FP8(__nv_bfloat16, uint8_t,
vllm::Fp8KVCacheDataType::kFp8E5M2);
}
} else { } else {
TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype); TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype);
} }
......
...@@ -179,7 +179,6 @@ void merge_attn_states(torch::Tensor& output, ...@@ -179,7 +179,6 @@ void merge_attn_states(torch::Tensor& output,
const torch::Tensor& suffix_lse); const torch::Tensor& suffix_lse);
#ifndef USE_ROCM
void convert_vertical_slash_indexes( void convert_vertical_slash_indexes(
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS] torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S] torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
...@@ -204,7 +203,6 @@ void convert_vertical_slash_indexes_mergehead( ...@@ -204,7 +203,6 @@ void convert_vertical_slash_indexes_mergehead(
torch::Tensor vertical_indices_count, // [N_HEADS, ] torch::Tensor vertical_indices_count, // [N_HEADS, ]
torch::Tensor slash_indices_count, int64_t context_size, torch::Tensor slash_indices_count, int64_t context_size,
int64_t block_size_M, int64_t block_size_N, bool causal); int64_t block_size_M, int64_t block_size_N, bool causal);
#endif
void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
double epsilon); double epsilon);
......
...@@ -53,10 +53,6 @@ static inline __device__ uint8_t float_to_fp8(float f) { ...@@ -53,10 +53,6 @@ static inline __device__ uint8_t float_to_fp8(float f) {
return result; return result;
} }
// template <typename Tout, typename Tin>
// __inline__ __device__ Tout vec_conversion(const Tin& x) {
// return x;
// }
template <typename Tout, typename Tin> template <typename Tout, typename Tin>
__inline__ __device__ Tout scaled_vec_conversion(const Tin& x, __inline__ __device__ Tout scaled_vec_conversion(const Tin& x,
...@@ -64,281 +60,6 @@ __inline__ __device__ Tout scaled_vec_conversion(const Tin& x, ...@@ -64,281 +60,6 @@ __inline__ __device__ Tout scaled_vec_conversion(const Tin& x,
return x; return x;
} }
// #if HIP_FP8_TYPE_OCP
// using fp8_type = __hip_fp8_e4m3;
// using fp8x2_type = __hip_fp8x2_e4m3;
// #else
// using fp8_type = __hip_fp8_e4m3_fnuz;
// using fp8x2_type = __hip_fp8x2_e4m3_fnuz;
// #endif
// // fp8 -> half
// template <>
// __inline__ __device__ uint16_t
// vec_conversion<uint16_t, uint8_t>(const uint8_t& a) {
// return __hip_cvt_fp8_to_halfraw(a, fp8_type::__default_interpret).x;
// }
// // fp8x2 -> half2
// template <>
// __inline__ __device__ uint32_t
// vec_conversion<uint32_t, uint16_t>(const uint16_t& a) {
// union {
// __half2_raw h2r;
// uint32_t ui32;
// } tmp;
// tmp.h2r = __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
// return tmp.ui32;
// }
// // fp8x4 -> half2x2
// template <>
// __inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a) {
// union {
// uint2 u32x2;
// uint32_t u32[2];
// } tmp;
// tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a);
// tmp.u32[1] = vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U));
// return tmp.u32x2;
// }
// // fp8x8 -> half2x4
// template <>
// __inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a) {
// union {
// uint4 u64x2;
// uint2 u64[2];
// } tmp;
// tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x);
// tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y);
// return tmp.u64x2;
// }
// using __nv_bfloat16 = __hip_bfloat16;
// // fp8 -> __nv_bfloat16
// template <>
// __inline__ __device__ __nv_bfloat16
// vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) {
// fp8_type f8;
// f8.__x = a;
// return __float2bfloat16(static_cast<float>(f8));
// }
// using __nv_bfloat162 = __hip_bfloat162;
// // fp8x2 -> __nv_bfloat162
// template <>
// __inline__ __device__ __nv_bfloat162
// vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a) {
// __nv_bfloat162 res;
// res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a);
// res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U));
// return res;
// }
// // fp8x4 -> bf16_4_t
// template <>
// __inline__ __device__ bf16_4_t
// vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a) {
// bf16_4_t res;
// res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a);
// res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U));
// return res;
// }
// // fp8x8 -> bf16_8_t
// template <>
// __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a) {
// bf16_4_t tmp1, tmp2;
// tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x);
// tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y);
// bf16_8_t res;
// res.x = tmp1.x;
// res.y = tmp1.y;
// res.z = tmp2.x;
// res.w = tmp2.y;
// return res;
// }
// // fp8 -> float
// template <>
// __inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a) {
// fp8_type f8;
// f8.__x = a;
// return static_cast<float>(f8);
// }
// // fp8x2 -> float2
// template <>
// __inline__ __device__ float2
// vec_conversion<float2, uint16_t>(const uint16_t& a) {
// fp8x2_type f8x2;
// f8x2.__x = a;
// return static_cast<float2>(f8x2);
// }
// // fp8x4 -> float4
// template <>
// __inline__ __device__ Float4_
// vec_conversion<Float4_, uint32_t>(const uint32_t& a) {
// Float4_ res;
// res.x = vec_conversion<float2, uint16_t>((uint16_t)a);
// res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U));
// return res;
// }
// // fp8x4 -> float4
// template <>
// __inline__ __device__ float4
// vec_conversion<float4, uint32_t>(const uint32_t& a) {
// Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
// float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
// return res;
// }
// // fp8x8 -> float8
// template <>
// __inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a) {
// Float4_ tmp1, tmp2;
// tmp1 = vec_conversion<Float4_, uint32_t>(a.x);
// tmp2 = vec_conversion<Float4_, uint32_t>(a.y);
// Float8_ res;
// res.x = tmp1.x;
// res.y = tmp1.y;
// res.z = tmp2.x;
// res.w = tmp2.y;
// return res;
// }
// // half -> fp8
// template <>
// __inline__ __device__ uint8_t
// vec_conversion<uint8_t, uint16_t>(const uint16_t& a) {
// __half_raw tmp;
// tmp.x = a;
// return __hip_cvt_halfraw_to_fp8(tmp, fp8_type::__default_saturation,
// fp8_type::__default_interpret);
// }
// template <>
// __inline__ __device__ uint16_t
// vec_conversion<uint16_t, uint32_t>(const uint32_t& a) {
// union {
// uint32_t ui32;
// __half2_raw h2r;
// } tmp;
// tmp.ui32 = a;
// return __hip_cvt_halfraw2_to_fp8x2(tmp.h2r, fp8_type::__default_saturation,
// fp8_type::__default_interpret);
// }
// // bf16 -> fp8
// template <>
// __inline__ __device__ uint8_t
// vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a) {
// return __hip_cvt_float_to_fp8(__bfloat162float(a),
// fp8_type::__default_saturation,
// fp8_type::__default_interpret);
// }
// // float -> fp8
// template <>
// __inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a) {
// return __hip_cvt_float_to_fp8(a, fp8_type::__default_saturation,
// fp8_type::__default_interpret);
// }
// // float2 -> half2
// template <>
// __inline__ __device__ uint32_t
// vec_conversion<uint32_t, float2>(const float2& a) {
// union {
// half2 float16;
// uint32_t uint32;
// };
// float16 = __float22half2_rn(a);
// return uint32;
// }
// // Float4 -> half2x2
// template <>
// __inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a) {
// uint2 b;
// float2 val;
// val.x = a.x.x;
// val.y = a.x.y;
// b.x = vec_conversion<uint32_t, float2>(val);
// val.x = a.y.x;
// val.y = a.y.y;
// b.y = vec_conversion<uint32_t, float2>(val);
// return b;
// }
// // Float4 -> float4
// template <>
// __inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a) {
// float4 b;
// b.x = a.x.x;
// b.y = a.x.y;
// b.z = a.y.x;
// b.w = a.y.y;
// return b;
// }
// // Float8 -> half2x4
// template <>
// __inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a) {
// uint4 b;
// b.x = vec_conversion<uint32_t, float2>(a.x);
// b.y = vec_conversion<uint32_t, float2>(a.y);
// b.z = vec_conversion<uint32_t, float2>(a.z);
// b.w = vec_conversion<uint32_t, float2>(a.w);
// return b;
// }
// // float2 -> bfloat162
// template <>
// __inline__ __device__ __nv_bfloat162
// vec_conversion<__nv_bfloat162, float2>(const float2& a) {
// __nv_bfloat162 b = __float22bfloat162_rn(a);
// return b;
// }
// // Float4 -> bfloat162x2
// template <>
// __inline__ __device__ bf16_4_t
// vec_conversion<bf16_4_t, Float4_>(const Float4_& a) {
// bf16_4_t b;
// b.x = __float22bfloat162_rn(a.x);
// b.y = __float22bfloat162_rn(a.y);
// return b;
// }
// // Float8 -> bfloat162x4
// template <>
// __inline__ __device__ bf16_8_t
// vec_conversion<bf16_8_t, Float8_>(const Float8_& a) {
// bf16_8_t b;
// b.x = __float22bfloat162_rn(a.x);
// b.y = __float22bfloat162_rn(a.y);
// b.z = __float22bfloat162_rn(a.z);
// b.w = __float22bfloat162_rn(a.w);
// return b;
// }
/* Scaled and vectorized conversions, for data exchange between high and low
precision domains
Convention of the scale in API, e.g: FP8_data = Quantization(
High_Precision_data / scale ) s.t. Quantize(HP / scale) => FP8 Dequant(FP8) *
scale => HP
*/
using __nv_bfloat16 = __hip_bfloat16; using __nv_bfloat16 = __hip_bfloat16;
// fp8 -> __nv_bfloat16 // fp8 -> __nv_bfloat16
...@@ -347,9 +68,6 @@ __inline__ __device__ __nv_bfloat16 ...@@ -347,9 +68,6 @@ __inline__ __device__ __nv_bfloat16
scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, float scale) { scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, float scale) {
return __float2bfloat16(fp8_to_float(a) * scale); return __float2bfloat16(fp8_to_float(a) * scale);
// fp8_type f8;
// f8.__x = a;
// return __float2bfloat16(static_cast<float>(f8) * scale);
} }
// fp8x2 -> __nv_bfloat162 // fp8x2 -> __nv_bfloat162
...@@ -395,9 +113,6 @@ template <> ...@@ -395,9 +113,6 @@ template <>
__inline__ __device__ float scaled_vec_conversion<float, uint8_t>( __inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
const uint8_t& a, float scale) { const uint8_t& a, float scale) {
return fp8_to_float(a) * scale; return fp8_to_float(a) * scale;
// fp8_type f8;
// f8.__x = a;
// return static_cast<float>(f8) * scale;
} }
// fp8x2 -> float2 // fp8x2 -> float2
...@@ -408,10 +123,6 @@ scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, float scale) { ...@@ -408,10 +123,6 @@ scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, float scale) {
f2r.x = scaled_vec_conversion<float, uint8_t>((uint8_t)a, scale); f2r.x = scaled_vec_conversion<float, uint8_t>((uint8_t)a, scale);
f2r.y = scaled_vec_conversion<float, uint8_t>((uint8_t)(a >> 8U), scale); f2r.y = scaled_vec_conversion<float, uint8_t>((uint8_t)(a >> 8U), scale);
return f2r; return f2r;
// [[maybe_unused]]
// fp8x2_type f8x2;
// f8x2.__x = a;
// return static_cast<float2>(f8x2) * scale;
} }
// fp8x4 -> float4 // fp8x4 -> float4
...@@ -453,9 +164,6 @@ __inline__ __device__ uint16_t ...@@ -453,9 +164,6 @@ __inline__ __device__ uint16_t
scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, float scale) { scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, float scale) {
float res = fp8_to_float(a) * scale; float res = fp8_to_float(a) * scale;
return float_to_half(res); return float_to_half(res);
// __half_raw res;
// res.data = scaled_vec_conversion<float, uint8_t>(a, scale);
// return res.x;
} }
// fp8x2 -> half2 // fp8x2 -> half2
...@@ -469,16 +177,6 @@ scaled_vec_conversion<uint32_t, uint16_t>(const uint16_t& a, float scale) { ...@@ -469,16 +177,6 @@ scaled_vec_conversion<uint32_t, uint16_t>(const uint16_t& a, float scale) {
res.u16[0] = scaled_vec_conversion<uint16_t, uint8_t>((uint8_t)a, scale); res.u16[0] = scaled_vec_conversion<uint16_t, uint8_t>((uint8_t)a, scale);
res.u16[1] = scaled_vec_conversion<uint16_t, uint8_t>((uint8_t)(a >> 8U), scale); res.u16[1] = scaled_vec_conversion<uint16_t, uint8_t>((uint8_t)(a >> 8U), scale);
return res.u32; return res.u32;
// [[maybe_unused]] __half2_raw h2r =
// __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
// union {
// __half2_raw h2r;
// uint32_t ui32;
// } tmp;
// tmp.h2r = __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
// tmp.h2r.x.data *= scale;
// tmp.h2r.y.data *= scale;
// return tmp.ui32;
} }
// fp8x4 -> half2x2 // fp8x4 -> half2x2
...@@ -513,11 +211,6 @@ __inline__ __device__ uint8_t ...@@ -513,11 +211,6 @@ __inline__ __device__ uint8_t
scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, float scale) { scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, float scale) {
float res_f = half_to_float(a) / scale; float res_f = half_to_float(a) / scale;
return float_to_fp8(res_f); return float_to_fp8(res_f);
// __half_raw tmp;
// tmp.x = a;
// tmp.data /= scale;
// return __hip_cvt_halfraw_to_fp8(tmp, fp8_type::__default_saturation,
// fp8_type::__default_interpret);
} }
// halfx2 -> fp8x2 // halfx2 -> fp8x2
...@@ -536,15 +229,6 @@ scaled_vec_conversion<uint16_t, uint32_t>(const uint32_t& a, float scale) { ...@@ -536,15 +229,6 @@ scaled_vec_conversion<uint16_t, uint32_t>(const uint32_t& a, float scale) {
tmp.ui8[0] = scaled_vec_conversion<uint8_t, uint16_t>(tmp_a.h2r.data[0], scale); tmp.ui8[0] = scaled_vec_conversion<uint8_t, uint16_t>(tmp_a.h2r.data[0], scale);
tmp.ui8[1] = scaled_vec_conversion<uint8_t, uint16_t>(tmp_a.h2r.data[1], scale); tmp.ui8[1] = scaled_vec_conversion<uint8_t, uint16_t>(tmp_a.h2r.data[1], scale);
return tmp.ui16; return tmp.ui16;
// union {
// uint32_t ui32;
// __half2_raw h2r;
// } tmp;
// tmp.ui32 = a;
// tmp.h2r.x.data /= scale;
// tmp.h2r.y.data /= scale;
// return __hip_cvt_halfraw2_to_fp8x2(tmp.h2r, fp8_type::__default_saturation,
// fp8_type::__default_interpret);
} }
// half2x2 -> fp8x4 // half2x2 -> fp8x4
...@@ -581,9 +265,6 @@ __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>( ...@@ -581,9 +265,6 @@ __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
const __nv_bfloat16& a, float scale) { const __nv_bfloat16& a, float scale) {
float res_f = (static_cast<float>(a)) / scale; float res_f = (static_cast<float>(a)) / scale;
return float_to_fp8(res_f); return float_to_fp8(res_f);
// return __hip_cvt_float_to_fp8(__bfloat162float(a) / scale,
// fp8_type::__default_saturation,
// fp8_type::__default_interpret);
} }
// bf16x2 -> fp8x2 // bf16x2 -> fp8x2
...@@ -627,8 +308,6 @@ template <> ...@@ -627,8 +308,6 @@ template <>
__inline__ __device__ uint8_t __inline__ __device__ uint8_t
scaled_vec_conversion<uint8_t, float>(const float& a, float scale) { scaled_vec_conversion<uint8_t, float>(const float& a, float scale) {
return float_to_fp8(a / scale); return float_to_fp8(a / scale);
// return __hip_cvt_float_to_fp8(a / scale, fp8_type::__default_saturation,
// fp8_type::__default_interpret);
} }
// floatx2 -> fp8x2 // floatx2 -> fp8x2
...@@ -642,8 +321,6 @@ scaled_vec_conversion<uint16_t, float2>(const float2& a, float scale) { ...@@ -642,8 +321,6 @@ scaled_vec_conversion<uint16_t, float2>(const float2& a, float scale) {
tmp.ui8[0] = scaled_vec_conversion<uint8_t, float>(a.x, scale); tmp.ui8[0] = scaled_vec_conversion<uint8_t, float>(a.x, scale);
tmp.ui8[1] = scaled_vec_conversion<uint8_t, float>(a.y, scale); tmp.ui8[1] = scaled_vec_conversion<uint8_t, float>(a.y, scale);
return tmp.ui16; return tmp.ui16;
// return __hip_cvt_float2_to_fp8x2(a / scale, fp8_type::__default_saturation,
// fp8_type::__default_interpret);
} }
// floatx4 -> fp8x4 // floatx4 -> fp8x4
...@@ -658,27 +335,113 @@ scaled_vec_conversion<uint32_t, float4>(const float4& a, float scale) { ...@@ -658,27 +335,113 @@ scaled_vec_conversion<uint32_t, float4>(const float4& a, float scale) {
tmp.ui16[1] = scaled_vec_conversion<uint16_t, float2>({a.z, a.w}, scale); tmp.ui16[1] = scaled_vec_conversion<uint16_t, float2>({a.z, a.w}, scale);
return tmp.ui32; return tmp.ui32;
} }
// #endif // ENABLE_FP8
// template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt> inline __device__ uint8_t float_to_fp8e5m2(float f) {
// __inline__ __device__ Tout convert(const Tin& x) { constexpr uint32_t fp32_inf = UINT32_C(255) << 23;
// #ifdef ENABLE_FP8 constexpr uint32_t fp8_max = UINT32_C(143) << 23;
// if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) { constexpr uint32_t denorm_mask = UINT32_C(134) << 23;
// return vec_conversion<Tout, Tin>(x); uint32_t f_bits = c10::detail::fp32_to_bits(f);
// } uint8_t result = 0u;
// #endif const uint32_t sign = f_bits & UINT32_C(0x80000000);
// assert(false); f_bits ^= sign;
// return {}; // Squash missing return statement warning if (f_bits >= fp8_max) {
// } result = f_bits > fp32_inf ? UINT8_C(0x7F) : UINT8_C(0x7C);
} else {
if (f_bits < (UINT32_C(113) << 23)) {
f_bits = c10::detail::fp32_to_bits(c10::detail::fp32_from_bits(f_bits)
+ c10::detail::fp32_from_bits(denorm_mask));
result = static_cast<uint8_t>(f_bits - denorm_mask);
} else {
uint32_t mant_odd = (f_bits >> 21) & 1;
f_bits += ((uint32_t)(15 - 127) << 23) + 0xFFFFF;
f_bits += mant_odd;
result = static_cast<uint8_t>(f_bits >> 21);
}
}
result |= static_cast<uint8_t>(sign >> 24);
return result;
}
// fp8
template <typename Tin>
__inline__ __device__ uint8_t
scaled_vec_conversion_to_e5m2(const Tin& a, float scale) {
return 0;
}
// float -> fp8
template <>
__inline__ __device__ uint8_t
scaled_vec_conversion_to_e5m2<float>(const float& a, float scale) {
return float_to_fp8e5m2(a / scale);
}
// half -> fp8
template <>
__inline__ __device__ uint8_t
scaled_vec_conversion_to_e5m2<uint16_t>(const uint16_t& a, float scale) {
float res_f = half_to_float(a) / scale;
return float_to_fp8e5m2(res_f);
}
// bf16 -> fp8
template <>
__inline__ __device__ uint8_t
scaled_vec_conversion_to_e5m2<__nv_bfloat16>(const __nv_bfloat16& a, float scale) {
float res_f = (static_cast<float>(a)) / scale;
return float_to_fp8e5m2(res_f);
}
inline __device__ float fp8e5m2_to_fp32(const uint8_t& input) {
union uf16{
uint16_t as_bits;
_Float16 as_value;
} ;
uf16 u16;
u16.as_bits = (uint16_t)input << 8;
return (float)u16.as_value;
}
template <typename Tout>
__inline__ __device__ Tout
scaled_vec_conversion_from_e5m2(const uint8_t& a, float scale) {
return 0;
}
// fp8 -> float
template <>
__inline__ __device__ float
scaled_vec_conversion_from_e5m2<float>(const uint8_t& a, float scale) {
return fp8e5m2_to_fp32(a)*scale;
}
// fp8 -> half
template <>
__inline__ __device__ uint16_t
scaled_vec_conversion_from_e5m2<uint16_t>(const uint8_t& a, float scale) {
return float_to_half(fp8e5m2_to_fp32(a)*scale);
}
// fp8 -> bf16
template <>
__inline__ __device__ __nv_bfloat16
scaled_vec_conversion_from_e5m2<__nv_bfloat16>(const uint8_t& a, float scale) {
return __float2bfloat16(fp8e5m2_to_fp32(a)*scale);
}
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt> template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) { __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
// #ifdef ENABLE_FP8 if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
// if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
return scaled_vec_conversion<Tout, Tin>(x, scale); return scaled_vec_conversion<Tout, Tin>(x, scale);
// } }
// #endif else if constexpr(kv_dt == Fp8KVCacheDataType::kFp8E5M2 && sizeof(Tout)==1){
// assert(false); return scaled_vec_conversion_to_e5m2<Tin>(x, scale);
}
else if constexpr(kv_dt == Fp8KVCacheDataType::kFp8E5M2 && sizeof(Tin)==1){
return scaled_vec_conversion_from_e5m2<Tout>(x, scale);
}
return {}; // Squash missing return statement warning return {}; // Squash missing return statement warning
} }
...@@ -686,7 +449,7 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) { ...@@ -686,7 +449,7 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
// the data type of the key and value cache. The FN is a macro that calls a // the data type of the key and value cache. The FN is a macro that calls a
// function with template<typename scalar_t, typename cache_t, // function with template<typename scalar_t, typename cache_t,
// Fp8KVCacheDataType kv_dt>. // Fp8KVCacheDataType kv_dt>.
#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \ #define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \
if (KV_DTYPE == "auto") { \ if (KV_DTYPE == "auto") { \
if (SRC_DTYPE == at::ScalarType::Float) { \ if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \ FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \
...@@ -697,16 +460,6 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) { ...@@ -697,16 +460,6 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
} else { \ } else { \
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \ TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
} \ } \
} else if (KV_DTYPE == "int8") { \
if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kInt8); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kInt8); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kInt8); \
} else { \
TORCH_CHECK(false,"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else { \ } else { \
if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \ if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \
if (SRC_DTYPE == at::ScalarType::Float) { \ if (SRC_DTYPE == at::ScalarType::Float) { \
...@@ -719,11 +472,23 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) { ...@@ -719,11 +472,23 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
TORCH_CHECK(false, \ TORCH_CHECK(false, \
"Unsupported input type of kv cache: ", SRC_DTYPE); \ "Unsupported input type of kv cache: ", SRC_DTYPE); \
} \ } \
} else if (KV_DTYPE == "fp8_e5m2") { \
if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
} else { \
TORCH_CHECK(false, \
"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else { \ } else { \
TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \ TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
} \ } \
} }
} // namespace fp8 } // namespace fp8
#endif // USE_ROCM #endif // USE_ROCM
} // namespace vllm } // namespace vllm
\ No newline at end of file
...@@ -229,8 +229,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -229,8 +229,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor suffix_lse) -> ()"); " Tensor suffix_lse) -> ()");
ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states); ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states);
#ifndef USE_ROCM
ops.def( ops.def(
"convert_vertical_slash_indexes(" "convert_vertical_slash_indexes("
" Tensor! block_count, Tensor! block_offset, " " Tensor! block_count, Tensor! block_offset, "
...@@ -253,7 +251,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -253,7 +251,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" bool causal) -> ()"); " bool causal) -> ()");
ops.impl("convert_vertical_slash_indexes_mergehead", torch::kCUDA, ops.impl("convert_vertical_slash_indexes_mergehead", torch::kCUDA,
&convert_vertical_slash_indexes_mergehead); &convert_vertical_slash_indexes_mergehead);
#endif
// Activation ops // Activation ops
// Activation function used in SwiGLU. // Activation function used in SwiGLU.
......
...@@ -2162,9 +2162,22 @@ def gather_cache(src_cache: torch.Tensor, ...@@ -2162,9 +2162,22 @@ def gather_cache(src_cache: torch.Tensor,
block_table: torch.Tensor, block_table: torch.Tensor,
cu_seq_lens: torch.Tensor, cu_seq_lens: torch.Tensor,
batch_size: int, batch_size: int,
seq_starts: Optional[torch.Tensor] = None) -> None: seq_starts: Optional[torch.Tensor] = None,
torch.ops._C_cache_ops.gather_cache(src_cache, dst, block_table, kv_dtype = "auto",
cu_seq_lens, batch_size, seq_starts) scale: float = 1.0,
) -> None:
#支持"kv cache fp8" 临时方案,带dtype的gather_cache在vllm0.10后会实现。
if kv_dtype == "fp8" or kv_dtype == "fp8_e5m2" or kv_dtype == "fp8_e4m3":
dst_fp8 = torch.empty(dst.shape, dtype=torch.uint8, device=dst.device)
#convert_fp8(dst_fp8, dst, scale, kv_dtype)
torch.ops._C_cache_ops.gather_cache(src_cache, dst_fp8, block_table,
cu_seq_lens, batch_size, seq_starts)
#dst_fp8->bf16
convert_fp8(dst, dst_fp8, scale, kv_dtype)
else:
torch.ops._C_cache_ops.gather_cache(src_cache, dst, block_table,
cu_seq_lens, batch_size, seq_starts)
def get_device_attribute(attribute: int, device: int) -> int: def get_device_attribute(attribute: int, device: int) -> int:
......
...@@ -211,9 +211,10 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -211,9 +211,10 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
"FlashMLAImpl") "FlashMLAImpl")
if is_quantized_kv_cache(self.kv_cache_dtype): if is_quantized_kv_cache(self.kv_cache_dtype):
if self.kv_cache_dtype != "fp8": if kv_cache_dtype == "fp8" or kv_cache_dtype == "fp8_e4m3" or kv_cache_dtype == "fp8_e5m2":
raise NotImplementedError( return
"FlashMLA with other KV cache not yet supported") raise NotImplementedError(
"FlashMLA with other KV cache not yet supported")
def _forward_decode( def _forward_decode(
self, self,
......
...@@ -1179,6 +1179,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -1179,6 +1179,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
q: torch.Tensor, q: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata, attn_metadata: MLACommonMetadata,
kv_scale=torch.tensor(1.0, dtype=torch.float32),
): ):
prefill_metadata = attn_metadata.prefill_metadata prefill_metadata = attn_metadata.prefill_metadata
assert prefill_metadata is not None assert prefill_metadata is not None
...@@ -1207,6 +1208,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -1207,6 +1208,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
cu_seq_lens=prefill_metadata.context_chunk_cu_seq_lens[i], cu_seq_lens=prefill_metadata.context_chunk_cu_seq_lens[i],
batch_size=prefill_metadata.num_prefills, batch_size=prefill_metadata.num_prefills,
seq_starts=prefill_metadata.context_chunk_starts[i], seq_starts=prefill_metadata.context_chunk_starts[i],
kv_dtype=self.kv_cache_dtype,
scale=kv_scale,
) )
kv_c_normed = workspace[:toks]\ kv_c_normed = workspace[:toks]\
...@@ -1262,6 +1265,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -1262,6 +1265,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
k_pe: torch.Tensor, k_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata, attn_metadata: MLACommonMetadata,
kv_scale=torch.tensor(1.0, dtype=torch.float32),
) -> torch.Tensor: ) -> torch.Tensor:
prefill_metadata = attn_metadata.prefill_metadata prefill_metadata = attn_metadata.prefill_metadata
...@@ -1297,7 +1301,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -1297,7 +1301,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
# ROCm flash_attn_varlen_func will return 3 objects instead of 2 # ROCm flash_attn_varlen_func will return 3 objects instead of 2
suffix_output, suffix_lse = output suffix_output, suffix_lse = output
context_output, context_lse = self._compute_prefill_context( \ context_output, context_lse = self._compute_prefill_context( \
q, kv_c_and_k_pe_cache, attn_metadata) q, kv_c_and_k_pe_cache, attn_metadata, kv_scale)
output = torch.empty_like(suffix_output) output = torch.empty_like(suffix_output)
merge_attn_states( merge_attn_states(
...@@ -1387,7 +1391,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -1387,7 +1391,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
if has_prefill: if has_prefill:
output[:num_prefill_tokens] = self._forward_prefill( output[:num_prefill_tokens] = self._forward_prefill(
prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
attn_metadata) attn_metadata, kv_scale=layer._k_scale)
if has_decode: if has_decode:
decode_q_nope, decode_q_pe = decode_q.split( decode_q_nope, decode_q_pe = decode_q.split(
......
...@@ -24,6 +24,11 @@ from vllm.platforms import _Backend, current_platform ...@@ -24,6 +24,11 @@ from vllm.platforms import _Backend, current_platform
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.utils import validate_kv_sharing_target from vllm.v1.attention.backends.utils import validate_kv_sharing_target
USE_XFORMERS_OPS = None
try:
tag_cudagraph_unsafe = (torch._C.Tag.cudagraph_unsafe, )
except AttributeError:
tag_cudagraph_unsafe = () # type: ignore[assignment]
class Attention(nn.Module): class Attention(nn.Module):
"""Attention layer. """Attention layer.
...@@ -204,9 +209,12 @@ class Attention(nn.Module): ...@@ -204,9 +209,12 @@ class Attention(nn.Module):
`vllm.forward_context.get_forward_context().attn_metadata`. `vllm.forward_context.get_forward_context().attn_metadata`.
""" """
if self.calculate_kv_scales: if self.calculate_kv_scales:
attn_metadata = get_forward_context().attn_metadata # attn_metadata = get_forward_context().attn_metadata
if attn_metadata.enable_kv_scales_calculation: # #if (attn_metadata is not None and getattr(attn_metadata, "enable_kv_scales_calculation", False)):
self.calc_kv_scales(query, key, value) # if key is not None and value is not None:
# self.calc_kv_scales(query, key, value)
torch.ops.vllm.maybe_calc_kv_scales(query, key, value,
self.layer_name)
if self.use_output: if self.use_output:
output_shape = (output_shape output_shape = (output_shape
if output_shape is not None else query.shape) if output_shape is not None else query.shape)
...@@ -394,7 +402,42 @@ def maybe_save_kv_layer_to_connector( ...@@ -394,7 +402,42 @@ def maybe_save_kv_layer_to_connector(
assert isinstance(attn_metadata, dict) assert isinstance(attn_metadata, dict)
connector.save_kv_layer(layer_name, kv_cache_layer, connector.save_kv_layer(layer_name, kv_cache_layer,
attn_metadata[layer_name]) attn_metadata[layer_name])
def maybe_calc_kv_scales(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
layer_name: str,
) -> None:
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[layer_name]
# if attn_metadata is None or not getattr(
# attn_metadata, 'enable_kv_scales_calculation', False):
# return
self = forward_context.no_compile_layers[layer_name]
self.calc_kv_scales(query, key, value)
def maybe_calc_kv_scales_fake( query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
layer_name: str,
) -> None:
return
direct_register_custom_op(
op_name="maybe_calc_kv_scales",
op_func=maybe_calc_kv_scales,
mutates_args=[],
fake_impl=maybe_calc_kv_scales_fake,
dispatch_key=current_platform.dispatch_key,
tags=tag_cudagraph_unsafe,)
def unified_attention( def unified_attention(
query: torch.Tensor, query: torch.Tensor,
......
...@@ -99,7 +99,8 @@ def flash_mla_with_kvcache( ...@@ -99,7 +99,8 @@ def flash_mla_with_kvcache(
if softmax_scale is None: if softmax_scale is None:
softmax_scale = q.shape[-1]**(-0.5) softmax_scale = q.shape[-1]**(-0.5)
if current_platform.is_rocm(): if current_platform.is_rocm():
if kv_cache_dtype == "fp8": if kv_cache_dtype == "fp8" or kv_cache_dtype == "fp8_e4m3" or kv_cache_dtype == "fp8_e5m2":
kv_dtype = "fp8_e4m3" if kv_cache_dtype == "fp8" else kv_cache_dtype
out, softmax_lse = flash_mla_cuda.fwd_kvcache_quantization_mla( out, softmax_lse = flash_mla_cuda.fwd_kvcache_quantization_mla(
q, q,
k_cache, k_cache,
...@@ -112,7 +113,7 @@ def flash_mla_with_kvcache( ...@@ -112,7 +113,7 @@ def flash_mla_with_kvcache(
tile_scheduler_metadata, tile_scheduler_metadata,
num_splits, num_splits,
k_scale, k_scale,
"fp8_e4m3", kv_dtype,
) )
return out, softmax_lse return out, softmax_lse
out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla( out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(
......
...@@ -55,7 +55,7 @@ class ReqMeta: ...@@ -55,7 +55,7 @@ class ReqMeta:
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
) )
@dataclass @dataclass
class P2pNcclConnectorMetadata(KVConnectorMetadata): class P2pNcclConnectorMetadata(KVConnectorMetadata):
requests: list[ReqMeta] requests: list[ReqMeta]
...@@ -95,6 +95,12 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -95,6 +95,12 @@ class P2pNcclConnector(KVConnectorBase_V1):
hostname="", hostname="",
port_offset=self._rank, port_offset=self._rank,
) if role == KVConnectorRole.WORKER else None ) if role == KVConnectorRole.WORKER else None
self.parallel_config = vllm_config.parallel_config
self.model_config = vllm_config.model_config
self.total_num_hidden_layers = getattr(self.model_config.hf_text_config,
"num_hidden_layers", 0)
self.pp_size = self.parallel_config.pipeline_parallel_size
# ============================== # ==============================
# Worker-side methods # Worker-side methods
...@@ -285,13 +291,35 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -285,13 +291,35 @@ class P2pNcclConnector(KVConnectorBase_V1):
ip, port = self.parse_request_id(request_id, True) ip, port = self.parse_request_id(request_id, True)
remote_address = ip + ":" + str(port + self._rank) remote_address = ip + ":" + str(port + self._rank)
kv_cache = extract_kv_from_layer(kv_layer, request.slot_mapping) kv_cache = extract_kv_from_layer(kv_layer, request.slot_mapping)
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
pp_rank = (self.parallel_config.rank // self.parallel_config.tensor_parallel_size
) % self.parallel_config.pipeline_parallel_size
if (self.pp_size == 1):
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address) kv_cache, remote_address)
elif (self.pp_size == 2):
if (pp_rank == 0):
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address)
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, ip + ":" + str(port + self._rank + 4))
else:
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address)
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, ip + ":" + str(port + self._rank - 4))
elif (self.pp_size == 8):
for i in range(8):
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, ip + ":" + str(port + i))
else:
print("Error: only suppprt pp1 pp2 pp8!!!!!!")
def wait_for_save(self): def wait_for_save(self):
if self.is_producer: pass
assert self.p2p_nccl_engine is not None # if self.is_producer:
self.p2p_nccl_engine.wait_for_sent() # assert self.p2p_nccl_engine is not None
# self.p2p_nccl_engine.wait_for_sent()
def get_finished( def get_finished(
self, finished_req_ids: set[str], self, finished_req_ids: set[str],
......
...@@ -63,7 +63,7 @@ class TensorMemoryPool: ...@@ -63,7 +63,7 @@ class TensorMemoryPool:
than min_block_size than min_block_size
""" """
def __init__(self, max_block_size: int, min_block_size: int = 512): def __init__(self, max_block_size: int, min_block_size: int = 128):
if max_block_size <= 0 or min_block_size <= 0: if max_block_size <= 0 or min_block_size <= 0:
raise ValueError("Block sizes must be positive") raise ValueError("Block sizes must be positive")
if max_block_size < min_block_size: if max_block_size < min_block_size:
......
...@@ -164,10 +164,11 @@ if TYPE_CHECKING: ...@@ -164,10 +164,11 @@ if TYPE_CHECKING:
VLLM_USE_FLASH_ATTN_PA: bool = False VLLM_USE_FLASH_ATTN_PA: bool = False
VLLM_USE_APEX_RN: bool = False VLLM_USE_APEX_RN: bool = False
VLLM_USE_GLOBAL_CACHE13: bool = False VLLM_USE_GLOBAL_CACHE13: bool = False
VLLM_USE_LIGHT_OP: bool = False VLLM_USE_LIGHTOP: bool = False
VLLM_USE_TRITON_CAT: bool = False VLLM_USE_OPT_CAT: bool = False
USE_FUSED_RMS_QUANT: bool = False
VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False
USE_FUSED_RMS_QUANT: bool = False
USE_FUSED_SILU_MUL_QUANT: bool = False
VLLM_USE_MORI_EP: bool = False VLLM_USE_MORI_EP: bool = False
def get_default_cache_root(): def get_default_cache_root():
...@@ -1050,7 +1051,7 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1050,7 +1051,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# If there are any problems during use, use environment variables # If there are any problems during use, use environment variables
# to restore the default usage. # to restore the default usage.
"VLLM_HAS_CONTEXT_DEFAULT": "VLLM_HAS_CONTEXT_DEFAULT":
lambda: bool(int(os.getenv("VLLM_HAS_CONTEXT_DEFAULT", "0"))), lambda: bool(int(os.getenv("VLLM_HAS_CONTEXT_DEFAULT", "1"))),
# If set, vLLM will transpose weight to use nn layout # If set, vLLM will transpose weight to use nn layout
"VLLM_USE_NN": "VLLM_USE_NN":
...@@ -1094,15 +1095,15 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1094,15 +1095,15 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_GLOBAL_CACHE13": "VLLM_USE_GLOBAL_CACHE13":
lambda: (os.environ.get("VLLM_USE_GLOBAL_CACHE13", "False").lower() in lambda: (os.environ.get("VLLM_USE_GLOBAL_CACHE13", "False").lower() in
("true", "1")), ("true", "1")),
# vLLM will use global cache for moe # vLLM will use lightop for deepseek-v3
"VLLM_USE_LIGHT_OP": "VLLM_USE_LIGHTOP":
lambda: (os.environ.get("VLLM_USE_LIGHT_OP", "True").lower() in lambda: (os.environ.get("VLLM_USE_LIGHTOP", "False").lower() in
("true", "1")), ("true", "1")),
# vLLM will use global cache for moe # vLLM will use opt cat for deepseek-v3
"VLLM_USE_TRITON_CAT": "VLLM_USE_OPT_CAT":
lambda: (os.environ.get("VLLM_USE_TRITON_CAT", "True").lower() in lambda: (os.environ.get("VLLM_USE_OPT_CAT", "False").lower() in
("true", "1")), ("true", "1")),
# vLLM will use opt merge_aatn_states,not triton # vLLM will use opt merge_aatn_states, not triton
"VLLM_USE_MERGE_ATTN_STATES_OPT": "VLLM_USE_MERGE_ATTN_STATES_OPT":
lambda: (os.environ.get("VLLM_USE_MERGE_ATTN_STATES_OPT", "True").lower() in lambda: (os.environ.get("VLLM_USE_MERGE_ATTN_STATES_OPT", "True").lower() in
("true", "1")), ("true", "1")),
...@@ -1111,6 +1112,16 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1111,6 +1112,16 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: (os.getenv('USE_FUSED_RMS_QUANT', '0').lower() in lambda: (os.getenv('USE_FUSED_RMS_QUANT', '0').lower() in
("true", "1")), ("true", "1")),
# vllm will use lightop's moe_sum fusion operator for deepseek
"VLLM_USE_DEEPSEEK_MOE_SUM_MUL_ADD":
lambda: (os.getenv('VLLM_USE_DEEPSEEK_MOE_SUM_MUL_ADD', 'True').lower() in
("true", "1")),
# vllm will use silu_mul_quant fused op
"USE_FUSED_SILU_MUL_QUANT":
lambda: (os.getenv('USE_FUSED_SILU_MUL_QUANT', '0').lower() in
("true", "1")),
# vLLM will use all_to_all ep mode # vLLM will use all_to_all ep mode
"VLLM_USE_MORI_EP": "VLLM_USE_MORI_EP":
lambda: (os.environ.get("VLLM_USE_MORI_EP", "True").lower() in lambda: (os.environ.get("VLLM_USE_MORI_EP", "True").lower() in
......
...@@ -40,12 +40,19 @@ from vllm.model_executor.layers.fused_moe.utils import ( ...@@ -40,12 +40,19 @@ from vllm.model_executor.layers.fused_moe.utils import (
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
from lightop import op
# from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled # from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
os.environ['DPSK_FP16_QUICK'] = os.environ.get('DPSK_FP16_QUICK', '0')
dpsk_fp16_quick = os.environ.get('DPSK_FP16_QUICK') == '1'
logger = init_logger(__name__) logger = init_logger(__name__)
if envs.VLLM_USE_GLOBAL_CACHE13: if envs.VLLM_USE_GLOBAL_CACHE13:
moe_cache_singleton = None moe_cache_singleton = None
def get_moe_cache(top_k_num,N,K,device,dtype): def get_moe_cache(top_k_num,N,K,device,dtype):
global moe_cache_singleton global moe_cache_singleton
if moe_cache_singleton is None: if moe_cache_singleton is None:
...@@ -1258,14 +1265,14 @@ def inplace_fused_experts(hidden_states: torch.Tensor, ...@@ -1258,14 +1265,14 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
num_local_tokens: Optional[torch.Tensor] = None, shared_output: Optional[torch.Tensor] = None,
true_bs: Optional[int] = None,) -> None: routed_scaling_factor: Optional[float] = None) -> None:
fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True, fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
activation, apply_router_weight_on_input, use_fp8_w8a8, activation, apply_router_weight_on_input, use_fp8_w8a8,
use_int8_w8a8, use_int8_w8a16, use_int4_w4a16,use_int4_w4a8, use_int8_w8a8, use_int8_w8a16, use_int4_w4a16,use_int4_w4a8,
per_channel_quant, global_num_experts, expert_map, per_channel_quant, global_num_experts, expert_map,
w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
block_shape, use_nn_moe, num_local_tokens, true_bs) block_shape, use_nn_moe, shared_output, routed_scaling_factor)
def inplace_fused_experts_fake( def inplace_fused_experts_fake(
...@@ -1292,8 +1299,8 @@ def inplace_fused_experts_fake( ...@@ -1292,8 +1299,8 @@ def inplace_fused_experts_fake(
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
num_local_tokens: Optional[torch.Tensor] = None, shared_output: Optional[torch.Tensor] = None,
true_bs: Optional[int] = None,) -> None: routed_scaling_factor: Optional[float] = None) -> None:
pass pass
...@@ -1330,15 +1337,15 @@ def outplace_fused_experts( ...@@ -1330,15 +1337,15 @@ def outplace_fused_experts(
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
num_local_tokens: Optional[torch.Tensor] = None, shared_output: Optional[torch.Tensor] = None,
true_bs: Optional[int] = None,) -> torch.Tensor: routed_scaling_factor: Optional[float] = None) -> torch.Tensor:
return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
False, activation, apply_router_weight_on_input, False, activation, apply_router_weight_on_input,
use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16,
use_int4_w4a16,use_int4_w4a8, per_channel_quant, use_int4_w4a16,use_int4_w4a8, per_channel_quant,
global_num_experts, expert_map, w1_scale, global_num_experts, expert_map, w1_scale,
w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
block_shape, use_nn_moe, num_local_tokens, true_bs) block_shape, use_nn_moe, shared_output, routed_scaling_factor)
def outplace_fused_experts_fake( def outplace_fused_experts_fake(
...@@ -1364,8 +1371,8 @@ def outplace_fused_experts_fake( ...@@ -1364,8 +1371,8 @@ def outplace_fused_experts_fake(
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
num_local_tokens: Optional[torch.Tensor] = None, shared_output: Optional[torch.Tensor] = None,
true_bs: Optional[int] = None,) -> torch.Tensor: routed_scaling_factor: Optional[float] = None) -> torch.Tensor:
return torch.empty_like(hidden_states) return torch.empty_like(hidden_states)
...@@ -1423,8 +1430,8 @@ def fused_experts( ...@@ -1423,8 +1430,8 @@ def fused_experts(
allow_deep_gemm: bool = False, allow_deep_gemm: bool = False,
allow_cutlass_block_scaled_grouped_gemm: bool = False, allow_cutlass_block_scaled_grouped_gemm: bool = False,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
num_local_tokens: Optional[torch.Tensor] = None, shared_output: Optional[torch.Tensor] = None,
true_bs: Optional[int] = None,) -> torch.Tensor: routed_scaling_factor: Optional[float] = None) -> torch.Tensor:
# For now, disable DeepGemm for small N (<= 512) until better # For now, disable DeepGemm for small N (<= 512) until better
# permute/unpermute ops are available. # permute/unpermute ops are available.
N = w1.size(1) N = w1.size(1)
...@@ -1483,8 +1490,8 @@ def fused_experts( ...@@ -1483,8 +1490,8 @@ def fused_experts(
a2_scale=a2_scale, a2_scale=a2_scale,
block_shape=block_shape, block_shape=block_shape,
use_nn_moe=use_nn_moe, use_nn_moe=use_nn_moe,
num_local_tokens=num_local_tokens, shared_output=shared_output,
true_bs=true_bs) routed_scaling_factor=routed_scaling_factor)
def fused_experts_impl( def fused_experts_impl(
...@@ -1512,8 +1519,8 @@ def fused_experts_impl( ...@@ -1512,8 +1519,8 @@ def fused_experts_impl(
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
num_local_tokens: Optional[torch.Tensor] = None, shared_output: Optional[torch.Tensor] = None,
true_bs: Optional[int] = None, routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor: ) -> torch.Tensor:
num_tokens = hidden_states.size(0) num_tokens = hidden_states.size(0)
if use_nn_moe: if use_nn_moe:
...@@ -1559,8 +1566,8 @@ def fused_experts_impl( ...@@ -1559,8 +1566,8 @@ def fused_experts_impl(
a2_scale=a2_scale, a2_scale=a2_scale,
block_shape=block_shape, block_shape=block_shape,
use_nn_moe=False, use_nn_moe=False,
num_local_tokens=num_local_tokens, shared_output=shared_output,
true_bs=true_bs, routed_scaling_factor=routed_scaling_factor
) )
elif use_int4_w4a8 is True: elif use_int4_w4a8 is True:
return fused_experts_impl_w4a8(hidden_states=hidden_states, return fused_experts_impl_w4a8(hidden_states=hidden_states,
...@@ -1587,7 +1594,9 @@ def fused_experts_impl( ...@@ -1587,7 +1594,9 @@ def fused_experts_impl(
a1_scale=a1_scale, a1_scale=a1_scale,
a2_scale=a2_scale, a2_scale=a2_scale,
block_shape=block_shape, block_shape=block_shape,
use_nn_moe= False use_nn_moe= False,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor
) )
# #
...@@ -1760,8 +1769,28 @@ def fused_experts_impl( ...@@ -1760,8 +1769,28 @@ def fused_experts_impl(
block_shape=block_shape, block_shape=block_shape,
use_nn_moe=use_nn_moe) use_nn_moe=use_nn_moe)
ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()), if envs.VLLM_USE_LIGHTOP and not dpsk_fp16_quick:
out_hidden_states[begin_chunk_idx:end_chunk_idx]) from lightop import op as op
op.moe_sum(input=intermediate_cache3.view(*intermediate_cache3.size()),
output=out_hidden_states[begin_chunk_idx:end_chunk_idx], bias=shared_output[begin_chunk_idx:end_chunk_idx],
expert_mask=None, num_local_tokens=None, factor=routed_scaling_factor)
# else:
# ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()),
# out_hidden_states[begin_chunk_idx:end_chunk_idx])
# if shared_output is not None:
# if hidden_states.dtype != torch.float16 or dpsk_fp16_quick:
# out_hidden_states[begin_chunk_idx:end_chunk_idx] = out_hidden_states[begin_chunk_idx:end_chunk_idx] * routed_scaling_factor + shared_output[begin_chunk_idx:end_chunk_idx]
# else:
# # Fix FP16 overflow
# # See DeepseekV2DecoderLayer for more details.
# out_hidden_states[begin_chunk_idx:end_chunk_idx] + shared_output[begin_chunk_idx:end_chunk_idx] * (1. / routed_scaling_factor)
# else:
# if hidden_states.dtype != torch.float16 or dpsk_fp16_quick:
# ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()),
# out_hidden_states[begin_chunk_idx:end_chunk_idx]) * routed_scaling_factor
else:
ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()),
out_hidden_states[begin_chunk_idx:end_chunk_idx])
return out_hidden_states return out_hidden_states
...@@ -1795,6 +1824,8 @@ def fused_moe( ...@@ -1795,6 +1824,8 @@ def fused_moe(
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
This function computes a Mixture of Experts (MoE) layer using two sets of This function computes a Mixture of Experts (MoE) layer using two sets of
...@@ -1880,7 +1911,9 @@ def fused_moe( ...@@ -1880,7 +1911,9 @@ def fused_moe(
a1_scale=a1_scale, a1_scale=a1_scale,
a2_scale=a2_scale, a2_scale=a2_scale,
block_shape=block_shape, block_shape=block_shape,
use_nn_moe=use_nn_moe) use_nn_moe=use_nn_moe,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor)
class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
...@@ -2097,4 +2130,4 @@ def modular_triton_fused_moe( ...@@ -2097,4 +2130,4 @@ def modular_triton_fused_moe(
per_act_token_quant=per_act_token_quant, per_act_token_quant=per_act_token_quant,
block_shape=block_shape, block_shape=block_shape,
), ),
) )
\ No newline at end of file
...@@ -42,7 +42,7 @@ from vllm.platforms.interface import CpuArchEnum ...@@ -42,7 +42,7 @@ from vllm.platforms.interface import CpuArchEnum
from vllm.utils import direct_register_custom_op, has_deep_ep, has_pplx from vllm.utils import direct_register_custom_op, has_deep_ep, has_pplx
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from lightop import op
if current_platform.is_cuda_alike(): if current_platform.is_cuda_alike():
from .fused_batched_moe import BatchedTritonExperts from .fused_batched_moe import BatchedTritonExperts
...@@ -222,6 +222,7 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -222,6 +222,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
expert_load_view: Optional[torch.Tensor] = None, expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
shared_output: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
raise NotImplementedError raise NotImplementedError
...@@ -373,6 +374,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -373,6 +374,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_load_view: Optional[torch.Tensor] = None, expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
shared_output: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False, use_fused_gate: Optional[bool] = False,
...@@ -397,6 +399,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -397,6 +399,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
e_score_correction_bias=e_score_correction_bias, e_score_correction_bias=e_score_correction_bias,
activation=activation, activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
shared_output=shared_output,
use_nn_moe=use_nn_moe, use_nn_moe=use_nn_moe,
routed_scaling_factor=routed_scaling_factor, routed_scaling_factor=routed_scaling_factor,
use_fused_gate=use_fused_gate) use_fused_gate=use_fused_gate)
...@@ -418,6 +421,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -418,6 +421,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
shared_output: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False, use_fused_gate: Optional[bool] = False,
...@@ -460,7 +464,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -460,7 +464,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
use_nn_moe=use_nn_moe use_nn_moe=use_nn_moe,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor
) )
def forward_cpu( def forward_cpu(
...@@ -1285,7 +1291,8 @@ class FusedMoE(torch.nn.Module): ...@@ -1285,7 +1291,8 @@ class FusedMoE(torch.nn.Module):
assert topk_group is not None assert topk_group is not None
assert num_expert_group is not None assert num_expert_group is not None
if use_fused_gate: if use_fused_gate:
if envs.VLLM_USE_LIGHT_OP: if envs.VLLM_USE_LIGHTOP:
from lightop import op as op
topk_weights, topk_ids = op.moe_fused_gate( topk_weights, topk_ids = op.moe_fused_gate(
router_logits, router_logits,
e_score_correction_bias, e_score_correction_bias,
...@@ -1434,14 +1441,15 @@ class FusedMoE(torch.nn.Module): ...@@ -1434,14 +1441,15 @@ class FusedMoE(torch.nn.Module):
return tensor_model_parallel_all_reduce(final_hidden_states) return tensor_model_parallel_all_reduce(final_hidden_states)
def forward(self, hidden_states: torch.Tensor, def forward(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor): router_logits: torch.Tensor,
shared_output: Optional[torch.Tensor] = None):
# TODO: Once the OOM issue for the TPU backend is resolved, we will # TODO: Once the OOM issue for the TPU backend is resolved, we will
# switch to using the moe_forward custom op. # switch to using the moe_forward custom op.
if current_platform.is_tpu(): if current_platform.is_tpu():
return self.forward_impl(hidden_states, router_logits) return self.forward_impl(hidden_states, router_logits)
else: else:
return torch.ops.vllm.moe_forward(hidden_states, router_logits, return torch.ops.vllm.moe_forward(hidden_states, router_logits,
self.layer_name) self.layer_name, shared_output)
def forward_impl_chunked(self, full_hidden_states: torch.Tensor, def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
full_router_logits: torch.Tensor): full_router_logits: torch.Tensor):
...@@ -1520,7 +1528,8 @@ class FusedMoE(torch.nn.Module): ...@@ -1520,7 +1528,8 @@ class FusedMoE(torch.nn.Module):
return full_final_hidden_states return full_final_hidden_states
def forward_impl(self, hidden_states: torch.Tensor, def forward_impl(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor): router_logits: torch.Tensor,
shared_output: Optional[torch.Tensor] = None):
assert self.quant_method is not None assert self.quant_method is not None
if (self.moe_parallel_config.use_pplx_kernels if (self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels): or self.moe_parallel_config.use_deepep_ll_kernels):
...@@ -1554,6 +1563,7 @@ class FusedMoE(torch.nn.Module): ...@@ -1554,6 +1563,7 @@ class FusedMoE(torch.nn.Module):
expert_load_view=self.expert_load_view, expert_load_view=self.expert_load_view,
logical_to_physical_map=self.logical_to_physical_map, logical_to_physical_map=self.logical_to_physical_map,
logical_replica_count=self.logical_replica_count, logical_replica_count=self.logical_replica_count,
shared_output=shared_output,
use_nn_moe=self.use_nn_moe, use_nn_moe=self.use_nn_moe,
routed_scaling_factor=self.routed_scaling_factor, routed_scaling_factor=self.routed_scaling_factor,
use_fused_gate=self.use_fused_gate use_fused_gate=self.use_fused_gate
...@@ -1626,17 +1636,17 @@ class FusedMoE(torch.nn.Module): ...@@ -1626,17 +1636,17 @@ class FusedMoE(torch.nn.Module):
return s return s
def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor, def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor,
layer_name: str) -> torch.Tensor: layer_name: str, shared_output: Optional[torch.Tensor] = None) -> torch.Tensor:
forward_context: ForwardContext = get_forward_context() forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name] self = forward_context.no_compile_layers[layer_name]
assert self.quant_method is not None assert self.quant_method is not None
return self.forward_impl(hidden_states, router_logits) return self.forward_impl(hidden_states, router_logits, shared_output)
def moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor, def moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor,
layer_name: str) -> torch.Tensor: layer_name: str, shared_output: Optional[torch.Tensor] = None) -> torch.Tensor:
return torch.empty_like(hidden_states) return torch.empty_like(hidden_states)
...@@ -1647,4 +1657,4 @@ direct_register_custom_op( ...@@ -1647,4 +1657,4 @@ direct_register_custom_op(
fake_impl=moe_forward_fake, fake_impl=moe_forward_fake,
dispatch_key=current_platform.dispatch_key, dispatch_key=current_platform.dispatch_key,
tags=(torch.Tag.needs_fixed_stride_order, ), tags=(torch.Tag.needs_fixed_stride_order, ),
) )
\ No newline at end of file
...@@ -9,7 +9,6 @@ from vllm.triton_utils import tl, triton ...@@ -9,7 +9,6 @@ from vllm.triton_utils import tl, triton
from vllm.utils import cdiv, round_up from vllm.utils import cdiv, round_up
import vllm.envs as envs import vllm.envs as envs
from lightop import op
@triton.jit @triton.jit
...@@ -153,7 +152,7 @@ def moe_align_block_size( ...@@ -153,7 +152,7 @@ def moe_align_block_size(
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
pad_sorted_ids: bool = False, pad_sorted_ids: bool = False,
num_token: Optional[int] = None, num_token: Optional[int] = None,
num_local_tokens: Optional[torch.Tensor] = None, expert_mask: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
""" """
Aligns the token distribution across experts to be compatible with block Aligns the token distribution across experts to be compatible with block
...@@ -233,12 +232,16 @@ def moe_align_block_size( ...@@ -233,12 +232,16 @@ def moe_align_block_size(
dtype=torch.int32, dtype=torch.int32,
device=topk_ids.device) device=topk_ids.device)
if envs.VLLM_USE_LIGHT_OP: if envs.VLLM_USE_LIGHTOP or expert_mask is not None:
from lightop import op as op
op.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids, op.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
expert_ids, num_tokens_post_pad, expert_map, None, num_local_tokens) expert_ids, num_tokens_post_pad,
expert_map = expert_map,
expert_mask = expert_mask,
num_local_tokens = None)
else: else:
ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids, ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
expert_ids, num_tokens_post_pad) expert_ids, num_tokens_post_pad)
if expert_map is not None: if expert_map is not None:
expert_ids = expert_map[expert_ids] expert_ids = expert_map[expert_ids]
......
...@@ -10,6 +10,7 @@ import vllm.envs as envs ...@@ -10,6 +10,7 @@ import vllm.envs as envs
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
def is_rocm_aiter_rmsnorm_enabled() -> bool: def is_rocm_aiter_rmsnorm_enabled() -> bool:
...@@ -39,6 +40,33 @@ def rms_norm(x: torch.Tensor, weight: torch.Tensor, ...@@ -39,6 +40,33 @@ def rms_norm(x: torch.Tensor, weight: torch.Tensor,
return out return out
def rms_norm_opt(x: torch.Tensor, weight: torch.Tensor,
variance_epsilon: float) -> torch.Tensor:
from vllm import _custom_ops as ops
from lightop import fused_rms_norm_contiguous
out = torch.empty_like(x)
fused_rms_norm_contiguous(
out,
x,
weight,
variance_epsilon,
)
return out
def rms_norm_opt_fake(x: torch.Tensor, weight: torch.Tensor,
variance_epsilon: float) -> torch.Tensor:
return torch.empty_like(x)
direct_register_custom_op(
op_name="rms_norm_opt",
op_func=rms_norm_opt,
mutates_args=[],
fake_impl=rms_norm_opt_fake,
)
def fused_add_rms_norm( def fused_add_rms_norm(
x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]: variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:
...@@ -187,6 +215,23 @@ class RMSNorm(CustomOp): ...@@ -187,6 +215,23 @@ class RMSNorm(CustomOp):
else: else:
return norm_func(x, self.weight.data, self.variance_epsilon) return norm_func(x, self.weight.data, self.variance_epsilon)
def forward_cuda_opt(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if self.variance_size_override is not None:
return self.forward_native(x, residual)
add_residual = residual is not None
norm_func = dispatch_cuda_rmsnorm_func(add_residual)
if add_residual:
return norm_func(x, residual, self.weight.data,
self.variance_epsilon)
else:
return torch.ops.vllm.rms_norm_opt(x, self.weight.data, self.variance_epsilon)
def forward_apex( def forward_apex(
self, self,
x: torch.Tensor, x: torch.Tensor,
......
...@@ -38,7 +38,13 @@ if envs.USE_FUSED_RMS_QUANT: ...@@ -38,7 +38,13 @@ if envs.USE_FUSED_RMS_QUANT:
from lmslim.quantize.quant_ops import lm_faster_rmsquant from lmslim.quantize.quant_ops import lm_faster_rmsquant
except Exception as e: except Exception as e:
print(f"Error: Import fused rmsquant error: {e}") print(f"Error: Import fused rmsquant error: {e}")
if envs.USE_FUSED_SILU_MUL_QUANT:
try:
# from lightop import fuse_silu_mul_quant
from lmslim.quantize.quant_ops import lm_fuse_silu_mul_quant
except Exception as e:
print(f"Error: Import fused silu_mul_qunat error: {e}")
logger = init_logger(__name__) logger = init_logger(__name__)
WEIGHT_LOADER_V2_SUPPORTED = [ WEIGHT_LOADER_V2_SUPPORTED = [
...@@ -1516,7 +1522,8 @@ class RowParallelLinear(LinearBase): ...@@ -1516,7 +1522,8 @@ class RowParallelLinear(LinearBase):
param.load_row_parallel_weight(loaded_weight=loaded_weight) param.load_row_parallel_weight(loaded_weight=loaded_weight)
def forward( def forward(
self, input_ self, input_,
use_fused_silu_mul_quant: Optional[bool] = False
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
if self.input_is_parallel: if self.input_is_parallel:
input_parallel = input_ input_parallel = input_
...@@ -1531,9 +1538,18 @@ class RowParallelLinear(LinearBase): ...@@ -1531,9 +1538,18 @@ class RowParallelLinear(LinearBase):
# Only fuse bias add into GEMM for rank 0 (this ensures that # Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case) # bias will not get added more than once in TP>1 case)
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
output_parallel = self.quant_method.apply(self, if use_fused_silu_mul_quant:
input_parallel, xq, xs = lm_fuse_silu_mul_quant(input_parallel)
bias=bias_)
silu_quant_args = [xq, xs]
output_parallel = self.quant_method.apply(self,
input_parallel,
bias=bias_,
silu_quant_args=silu_quant_args)
else:
output_parallel = self.quant_method.apply(self,
input_parallel,
bias=bias_)
if self.reduce_results and self.tp_size > 1: if self.reduce_results and self.tp_size > 1:
if envs.VLLM_ENABLE_TBO: if envs.VLLM_ENABLE_TBO:
output = self.tbo_all_reduce(output_parallel) output = self.tbo_all_reduce(output_parallel)
......
...@@ -666,7 +666,8 @@ class CompressedTensorsLinearMethod(LinearMethodBase): ...@@ -666,7 +666,8 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
def apply(self, def apply(self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None): bias: Optional[torch.Tensor] = None,
input_quant_args: Optional[list[torch.Tensor]] = None):
""" """
Use the output of create_weights and the CompressedTensorsScheme Use the output of create_weights and the CompressedTensorsScheme
associated with the layer to apply the forward pass with the associated with the layer to apply the forward pass with the
...@@ -677,7 +678,7 @@ class CompressedTensorsLinearMethod(LinearMethodBase): ...@@ -677,7 +678,7 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
scheme = layer.scheme scheme = layer.scheme
if scheme is None: if scheme is None:
raise ValueError("A scheme must be defined for each layer") raise ValueError("A scheme must be defined for each layer")
return scheme.apply_weights(layer, x, bias=bias) return scheme.apply_weights(layer, x, bias=bias, input_quant_args=input_quant_args)
class CompressedTensorsKVCacheMethod(BaseKVCacheMethod): class CompressedTensorsKVCacheMethod(BaseKVCacheMethod):
......
...@@ -1097,7 +1097,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1097,7 +1097,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
expert_load_view: Optional[torch.Tensor] = None, expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
shared_output: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
if enable_eplb: if enable_eplb:
raise NotImplementedError( raise NotImplementedError(
...@@ -1137,7 +1137,9 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1137,7 +1137,9 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
w2_scale=layer.w2_weight_scale, w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale, a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale, a2_scale=layer.w2_input_scale,
use_nn_moe=False) use_nn_moe=False,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor)
class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
......
...@@ -111,7 +111,8 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme): ...@@ -111,7 +111,8 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
self.kernel.process_weights_after_loading(layer) self.kernel.process_weights_after_loading(layer)
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor: bias: Optional[torch.Tensor],
input_quant_args: Optional[list[torch.Tensor]] = None) -> torch.Tensor:
# return self.kernel.apply_weights(layer, x, bias) # return self.kernel.apply_weights(layer, x, bias)
...@@ -122,5 +123,5 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme): ...@@ -122,5 +123,5 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
input_zero_point=layer.input_zero_point, input_zero_point=layer.input_zero_point,
azp_adj=layer.azp_adj, azp_adj=layer.azp_adj,
bias=bias, bias=bias,
w8a8_strategy=self.w8a8_strategy) w8a8_strategy=self.w8a8_strategy,
input_quant_args=input_quant_args)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment