Commit fab1acce authored by zhuwenwen's avatar zhuwenwen
Browse files

[Feature] Support vllm v0.20.0

parent 88d34c64
......@@ -37,7 +37,7 @@ install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS)
set(PYTHON_SUPPORTED_VERSIONS "3.10" "3.11" "3.12" "3.13" "3.14")
# Supported AMD GPU architectures.
set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1102;gfx1103;gfx1150;gfx1151;gfx1152;gfx1153;gfx1200;gfx1201")
set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1102;gfx1103;gfx1150;gfx1151;gfx1152;gfx1153;gfx1200;gfx1201;gfx928;gfx936;gfx938")
# ROCm installation prefix. Default to /opt/rocm but allow override via
# -DROCM_PATH=/your/rocm/path when invoking cmake.
......@@ -1240,7 +1240,7 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
endif()
# For CUDA and HIP builds also build the triton_kernels external package.
if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
if(VLLM_GPU_LANG STREQUAL "CUDA")
include(cmake/external_projects/triton_kernels.cmake)
endif()
......
......@@ -931,6 +931,22 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
CALL_CONVERT_FP8(__nv_bfloat16, uint8_t,
vllm::Fp8KVCacheDataType::kFp8E4M3);
}
} else if (kv_cache_dtype == "fp8_e5m2") {
if (src_cache.dtype() == at::ScalarType::Float) {
CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kFp8E5M2);
} else if (src_cache.dtype() == at::ScalarType::Half) {
CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kFp8E5M2);
} else if (src_cache.dtype() == at::ScalarType::BFloat16) {
CALL_CONVERT_FP8(uint8_t, __nv_bfloat16,
vllm::Fp8KVCacheDataType::kFp8E5M2);
} else if (dst_cache.dtype() == at::ScalarType::Float) {
CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2);
} else if (dst_cache.dtype() == at::ScalarType::Half) {
CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2);
} else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
CALL_CONVERT_FP8(__nv_bfloat16, uint8_t,
vllm::Fp8KVCacheDataType::kFp8E5M2);
}
} else {
TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype);
}
......@@ -1156,9 +1172,9 @@ __global__ void cp_gather_and_upconvert_fp8_kv_cache(
const uint2 fp8_hi = make_uint2(fp8_data.z, fp8_data.w);
#ifdef USE_ROCM
const bf16_8_t bf16_lo =
fp8::scaled_vec_conversion<bf16_8_t, uint2>(fp8_lo, scale);
fp8::scaled_vec_conversion<bf16_8_t, uint2>(fp8_lo, scale, vllm::Fp8KVCacheDataType::kFp8E4M3);
const bf16_8_t bf16_hi =
fp8::scaled_vec_conversion<bf16_8_t, uint2>(fp8_hi, scale);
fp8::scaled_vec_conversion<bf16_8_t, uint2>(fp8_hi, scale, vllm::Fp8KVCacheDataType::kFp8E4M3);
#else
const bf16_8_t bf16_lo =
fp8::scaled_vec_conversion<bf16_8_t, uint2>(fp8_lo, scale, __NV_E4M3);
......
......@@ -8,6 +8,8 @@
#include <cassert>
#ifdef USE_ROCM
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <hip/hip_runtime.h>
#else
#include <cuda_bf16.h>
......
......@@ -40,15 +40,15 @@
#ifdef USE_ROCM
#define FINAL_MASK 0xffffffffffffffffULL
#if defined(HIP_VERSION) && HIP_VERSION < 70000000
// On ROCm versions before 7.0, __syncwarp isn't defined. The below
// implementation is copy/pasted from the implementation in ROCm 7.0
__device__ inline void __syncwarp() {
__builtin_amdgcn_fence(__ATOMIC_RELEASE, "wavefront");
__builtin_amdgcn_wave_barrier();
__builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "wavefront");
}
#endif
// #if defined(HIP_VERSION) && HIP_VERSION < 70000000
// // On ROCm versions before 7.0, __syncwarp isn't defined. The below
// // implementation is copy/pasted from the implementation in ROCm 7.0
// __device__ inline void __syncwarp() {
// __builtin_amdgcn_fence(__ATOMIC_RELEASE, "wavefront");
// __builtin_amdgcn_wave_barrier();
// __builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "wavefront");
// }
// #endif
#else
#define FINAL_MASK 0xffffffff
#endif
......
......@@ -12,7 +12,9 @@ https://github.com/qwopqwop200/GPTQ-for-LLaMa
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#ifndef USE_ROCM
#include "compat.cuh"
#endif
#include "matrix_view.cuh"
#include "qdq_2.cuh"
#include "qdq_3.cuh"
......
#pragma once
#ifndef USE_ROCM
#include <hip/hip_fp8.h>
#endif
#include <hip/hip_fp16.h>
#include <hip/hip_bf16.h>
......@@ -11,318 +13,348 @@ namespace vllm {
#ifdef USE_ROCM
namespace fp8 {
#ifdef ENABLE_FP8
// #ifdef ENABLE_FP8
// Use hardware cvt instruction for fp8 on rocm
template <typename fp8_type>
__device__ __forceinline__ fp8_type cvt_c10(float const r) {
return {};
// KV-CACHE int8
static inline __device__ float fp8_to_float(uint8_t input) {
const uint32_t w = (uint32_t)input << 24;
const uint32_t sign = w & UINT32_C(0x80000000);
const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF);
uint32_t renorm_shift = __clz(nonsign);
renorm_shift = renorm_shift > 4 ? renorm_shift - 4 : 0;
uint32_t result = sign | ((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23));
return c10::detail::fp32_from_bits(result);
}
// __hip_fp8_e4m3 only exists starting in ROCm 6.3. The macro
// HIP_FP8_TYPE_OCP comes from the hip_fp8.h header and also makes
// its first appearance in ROCm 6.3. Since VLLM_DISPATCH_FP8_TYPES
// on ROCm instantiates both OCP and FNUZ kernels, we need to replace
// the new HW cvt with something reasonable that doesn't rely on the
// ROCm 6.3 feature. This allows compiling on ROCm 6.2 or newer.
template <>
__device__ __forceinline__ c10::Float8_e4m3fn cvt_c10(float const r) {
#if HIP_FP8_TYPE_OCP
return c10::Float8_e4m3fn(
__hip_cvt_float_to_fp8(r, __hip_fp8_e4m3::__default_saturation,
__hip_fp8_e4m3::__default_interpret),
c10::Float8_e4m3fn::from_bits());
#else
// Cast implemented by pytorch. Uses bit manipulation instead of HW cvt.
// HW cvt above is faster when it is available (ROCm 6.3 or newer).
return static_cast<c10::Float8_e4m3fn>(r);
#endif
}
// float -> fp8
static inline __device__ uint8_t float_to_fp8_e4m3(float f) {
constexpr uint32_t fp8_max = UINT32_C(1087) << 20;
constexpr uint32_t denorm_mask = UINT32_C(141) << 23;
uint32_t f_bits = c10::detail::fp32_to_bits(f);
uint8_t result = 0u;
const uint32_t sign = f_bits & UINT32_C(0x80000000);
f_bits ^= sign;
if (f_bits >= fp8_max) {
result = 0x7f;
} else {
if (f_bits < (UINT32_C(121) << 23)) {
f_bits =
c10::detail::fp32_to_bits(c10::detail::fp32_from_bits(f_bits) + c10::detail::fp32_from_bits(denorm_mask));
result = static_cast<uint8_t>(f_bits - denorm_mask);
} else {
uint8_t mant_odd = (f_bits >> 20) & 1;
f_bits += ((uint32_t)(7 - 127) << 23) + 0x7FFFF;
f_bits += mant_odd;
result = static_cast<uint8_t>(f_bits >> 20);
}
}
template <>
__device__ __forceinline__ c10::Float8_e4m3fnuz cvt_c10(float const r) {
return c10::Float8_e4m3fnuz(
__hip_cvt_float_to_fp8(r, __hip_fp8_e4m3_fnuz::__default_saturation,
__hip_fp8_e4m3_fnuz::__default_interpret),
c10::Float8_e4m3fnuz::from_bits());
result |= static_cast<uint8_t>(sign >> 24);
return result;
}
static inline __device__ uint8_t float_to_fp8_e5m2(float f) {
constexpr uint32_t fp32_inf = UINT32_C(255) << 23;
constexpr uint32_t fp8_max = UINT32_C(143) << 23;
constexpr uint32_t denorm_mask = UINT32_C(134) << 23;
uint32_t f_bits = c10::detail::fp32_to_bits(f);
uint8_t result = 0u;
const uint32_t sign = f_bits & UINT32_C(0x80000000);
f_bits ^= sign;
if (f_bits >= fp8_max) {
result = f_bits > fp32_inf ? UINT8_C(0x7F) : UINT8_C(0x7C);
} else {
if (f_bits < (UINT32_C(113) << 23)) {
f_bits = c10::detail::fp32_to_bits(c10::detail::fp32_from_bits(f_bits)
+ c10::detail::fp32_from_bits(denorm_mask));
result = static_cast<uint8_t>(f_bits - denorm_mask);
} else {
uint32_t mant_odd = (f_bits >> 21) & 1;
f_bits += ((uint32_t)(15 - 127) << 23) + 0xFFFFF;
f_bits += mant_odd;
result = static_cast<uint8_t>(f_bits >> 21);
}
}
result |= static_cast<uint8_t>(sign >> 24);
return result;
}
template <typename Tout, typename Tin>
__inline__ __device__ Tout vec_conversion(const Tin& x) {
return x;
}
// template <typename Tout, typename Tin>
// __inline__ __device__ Tout vec_conversion(const Tin& x) {
// return x;
// }
template <typename Tout, typename Tin>
__inline__ __device__ Tout scaled_vec_conversion(const Tin& x,
const float scale) {
const float scale, Fp8KVCacheDataType kv_type) {
return x;
}
#if HIP_FP8_TYPE_OCP
using fp8_type = __hip_fp8_e4m3;
using fp8x2_type = __hip_fp8x2_e4m3;
#else
using fp8_type = __hip_fp8_e4m3_fnuz;
using fp8x2_type = __hip_fp8x2_e4m3_fnuz;
#endif
// fp8 -> half
template <>
__inline__ __device__ uint16_t
vec_conversion<uint16_t, uint8_t>(const uint8_t& a) {
return __hip_cvt_fp8_to_halfraw(a, fp8_type::__default_interpret).x;
}
// fp8x2 -> half2
template <>
__inline__ __device__ uint32_t
vec_conversion<uint32_t, uint16_t>(const uint16_t& a) {
union {
__half2_raw h2r;
uint32_t ui32;
} tmp;
tmp.h2r = __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
return tmp.ui32;
}
// fp8x4 -> half2x2
template <>
__inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a) {
union {
uint2 u32x2;
uint32_t u32[2];
} tmp;
tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a);
tmp.u32[1] = vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U));
return tmp.u32x2;
}
// fp8x8 -> half2x4
template <>
__inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a) {
union {
uint4 u64x2;
uint2 u64[2];
} tmp;
tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x);
tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y);
return tmp.u64x2;
}
using __nv_bfloat16 = __hip_bfloat16;
// fp8 -> __nv_bfloat16
template <>
__inline__ __device__ __nv_bfloat16
vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) {
fp8_type f8;
f8.__x = a;
return __float2bfloat16(static_cast<float>(f8));
}
using __nv_bfloat162 = __hip_bfloat162;
// fp8x2 -> __nv_bfloat162
template <>
__inline__ __device__ __nv_bfloat162
vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a) {
__nv_bfloat162 res;
res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a);
res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U));
return res;
}
// fp8x4 -> bf16_4_t
template <>
__inline__ __device__ bf16_4_t
vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a) {
bf16_4_t res;
res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a);
res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U));
return res;
}
// fp8x8 -> bf16_8_t
template <>
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a) {
bf16_4_t tmp1, tmp2;
tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x);
tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y);
bf16_8_t res;
res.x = tmp1.x;
res.y = tmp1.y;
res.z = tmp2.x;
res.w = tmp2.y;
return res;
}
// fp8 -> float
template <>
__inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a) {
fp8_type f8;
f8.__x = a;
return static_cast<float>(f8);
}
// fp8x2 -> float2
template <>
__inline__ __device__ float2
vec_conversion<float2, uint16_t>(const uint16_t& a) {
fp8x2_type f8x2;
f8x2.__x = a;
return static_cast<float2>(f8x2);
}
// fp8x4 -> float4
template <>
__inline__ __device__ Float4_
vec_conversion<Float4_, uint32_t>(const uint32_t& a) {
Float4_ res;
res.x = vec_conversion<float2, uint16_t>((uint16_t)a);
res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U));
return res;
}
// fp8x4 -> float4
template <>
__inline__ __device__ float4
vec_conversion<float4, uint32_t>(const uint32_t& a) {
Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
return res;
}
// fp8x8 -> float8
template <>
__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a) {
Float4_ tmp1, tmp2;
tmp1 = vec_conversion<Float4_, uint32_t>(a.x);
tmp2 = vec_conversion<Float4_, uint32_t>(a.y);
Float8_ res;
res.x = tmp1.x;
res.y = tmp1.y;
res.z = tmp2.x;
res.w = tmp2.y;
return res;
}
// half -> fp8
template <>
__inline__ __device__ uint8_t
vec_conversion<uint8_t, uint16_t>(const uint16_t& a) {
__half_raw tmp;
tmp.x = a;
return __hip_cvt_halfraw_to_fp8(tmp, fp8_type::__default_saturation,
fp8_type::__default_interpret);
}
template <>
__inline__ __device__ uint16_t
vec_conversion<uint16_t, uint32_t>(const uint32_t& a) {
union {
uint32_t ui32;
__half2_raw h2r;
} tmp;
tmp.ui32 = a;
return __hip_cvt_halfraw2_to_fp8x2(tmp.h2r, fp8_type::__default_saturation,
fp8_type::__default_interpret);
}
// bf16 -> fp8
template <>
__inline__ __device__ uint8_t
vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a) {
return __hip_cvt_float_to_fp8(__bfloat162float(a),
fp8_type::__default_saturation,
fp8_type::__default_interpret);
}
// float -> fp8
template <>
__inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a) {
return __hip_cvt_float_to_fp8(a, fp8_type::__default_saturation,
fp8_type::__default_interpret);
}
// float2 -> half2
template <>
__inline__ __device__ uint32_t
vec_conversion<uint32_t, float2>(const float2& a) {
union {
half2 float16;
uint32_t uint32;
};
float16 = __float22half2_rn(a);
return uint32;
}
// Float4 -> half2x2
template <>
__inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a) {
uint2 b;
float2 val;
val.x = a.x.x;
val.y = a.x.y;
b.x = vec_conversion<uint32_t, float2>(val);
val.x = a.y.x;
val.y = a.y.y;
b.y = vec_conversion<uint32_t, float2>(val);
return b;
}
// Float4 -> float4
template <>
__inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a) {
float4 b;
b.x = a.x.x;
b.y = a.x.y;
b.z = a.y.x;
b.w = a.y.y;
return b;
}
// Float8 -> half2x4
template <>
__inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a) {
uint4 b;
b.x = vec_conversion<uint32_t, float2>(a.x);
b.y = vec_conversion<uint32_t, float2>(a.y);
b.z = vec_conversion<uint32_t, float2>(a.z);
b.w = vec_conversion<uint32_t, float2>(a.w);
return b;
}
// float2 -> bfloat162
template <>
__inline__ __device__ __nv_bfloat162
vec_conversion<__nv_bfloat162, float2>(const float2& a) {
__nv_bfloat162 b = __float22bfloat162_rn(a);
return b;
}
// Float4 -> bfloat162x2
template <>
__inline__ __device__ bf16_4_t
vec_conversion<bf16_4_t, Float4_>(const Float4_& a) {
bf16_4_t b;
b.x = __float22bfloat162_rn(a.x);
b.y = __float22bfloat162_rn(a.y);
return b;
}
// Float8 -> bfloat162x4
template <>
__inline__ __device__ bf16_8_t
vec_conversion<bf16_8_t, Float8_>(const Float8_& a) {
bf16_8_t b;
b.x = __float22bfloat162_rn(a.x);
b.y = __float22bfloat162_rn(a.y);
b.z = __float22bfloat162_rn(a.z);
b.w = __float22bfloat162_rn(a.w);
return b;
}
// #if HIP_FP8_TYPE_OCP
// using fp8_type = __hip_fp8_e4m3;
// using fp8x2_type = __hip_fp8x2_e4m3;
// #else
// using fp8_type = __hip_fp8_e4m3_fnuz;
// using fp8x2_type = __hip_fp8x2_e4m3_fnuz;
// #endif
// // fp8 -> half
// template <>
// __inline__ __device__ uint16_t
// vec_conversion<uint16_t, uint8_t>(const uint8_t& a) {
// return __hip_cvt_fp8_to_halfraw(a, fp8_type::__default_interpret).x;
// }
// // fp8x2 -> half2
// template <>
// __inline__ __device__ uint32_t
// vec_conversion<uint32_t, uint16_t>(const uint16_t& a) {
// union {
// __half2_raw h2r;
// uint32_t ui32;
// } tmp;
// tmp.h2r = __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
// return tmp.ui32;
// }
// // fp8x4 -> half2x2
// template <>
// __inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a) {
// union {
// uint2 u32x2;
// uint32_t u32[2];
// } tmp;
// tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a);
// tmp.u32[1] = vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U));
// return tmp.u32x2;
// }
// // fp8x8 -> half2x4
// template <>
// __inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a) {
// union {
// uint4 u64x2;
// uint2 u64[2];
// } tmp;
// tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x);
// tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y);
// return tmp.u64x2;
// }
// using __nv_bfloat16 = __hip_bfloat16;
// // fp8 -> __nv_bfloat16
// template <>
// __inline__ __device__ __nv_bfloat16
// vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) {
// fp8_type f8;
// f8.__x = a;
// return __float2bfloat16(static_cast<float>(f8));
// }
// using __nv_bfloat162 = __hip_bfloat162;
// // fp8x2 -> __nv_bfloat162
// template <>
// __inline__ __device__ __nv_bfloat162
// vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a) {
// __nv_bfloat162 res;
// res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a);
// res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U));
// return res;
// }
// // fp8x4 -> bf16_4_t
// template <>
// __inline__ __device__ bf16_4_t
// vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a) {
// bf16_4_t res;
// res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a);
// res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U));
// return res;
// }
// // fp8x8 -> bf16_8_t
// template <>
// __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a) {
// bf16_4_t tmp1, tmp2;
// tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x);
// tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y);
// bf16_8_t res;
// res.x = tmp1.x;
// res.y = tmp1.y;
// res.z = tmp2.x;
// res.w = tmp2.y;
// return res;
// }
// // fp8 -> float
// template <>
// __inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a) {
// fp8_type f8;
// f8.__x = a;
// return static_cast<float>(f8);
// }
// // fp8x2 -> float2
// template <>
// __inline__ __device__ float2
// vec_conversion<float2, uint16_t>(const uint16_t& a) {
// fp8x2_type f8x2;
// f8x2.__x = a;
// return static_cast<float2>(f8x2);
// }
// // fp8x4 -> float4
// template <>
// __inline__ __device__ Float4_
// vec_conversion<Float4_, uint32_t>(const uint32_t& a) {
// Float4_ res;
// res.x = vec_conversion<float2, uint16_t>((uint16_t)a);
// res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U));
// return res;
// }
// // fp8x4 -> float4
// template <>
// __inline__ __device__ float4
// vec_conversion<float4, uint32_t>(const uint32_t& a) {
// Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
// float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
// return res;
// }
// // fp8x8 -> float8
// template <>
// __inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a) {
// Float4_ tmp1, tmp2;
// tmp1 = vec_conversion<Float4_, uint32_t>(a.x);
// tmp2 = vec_conversion<Float4_, uint32_t>(a.y);
// Float8_ res;
// res.x = tmp1.x;
// res.y = tmp1.y;
// res.z = tmp2.x;
// res.w = tmp2.y;
// return res;
// }
// // half -> fp8
// template <>
// __inline__ __device__ uint8_t
// vec_conversion<uint8_t, uint16_t>(const uint16_t& a) {
// __half_raw tmp;
// tmp.x = a;
// return __hip_cvt_halfraw_to_fp8(tmp, fp8_type::__default_saturation,
// fp8_type::__default_interpret);
// }
// template <>
// __inline__ __device__ uint16_t
// vec_conversion<uint16_t, uint32_t>(const uint32_t& a) {
// union {
// uint32_t ui32;
// __half2_raw h2r;
// } tmp;
// tmp.ui32 = a;
// return __hip_cvt_halfraw2_to_fp8x2(tmp.h2r, fp8_type::__default_saturation,
// fp8_type::__default_interpret);
// }
// // bf16 -> fp8
// template <>
// __inline__ __device__ uint8_t
// vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a) {
// return __hip_cvt_float_to_fp8(__bfloat162float(a),
// fp8_type::__default_saturation,
// fp8_type::__default_interpret);
// }
// // float -> fp8
// template <>
// __inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a) {
// return __hip_cvt_float_to_fp8(a, fp8_type::__default_saturation,
// fp8_type::__default_interpret);
// }
// // float2 -> half2
// template <>
// __inline__ __device__ uint32_t
// vec_conversion<uint32_t, float2>(const float2& a) {
// union {
// half2 float16;
// uint32_t uint32;
// };
// float16 = __float22half2_rn(a);
// return uint32;
// }
// // Float4 -> half2x2
// template <>
// __inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a) {
// uint2 b;
// float2 val;
// val.x = a.x.x;
// val.y = a.x.y;
// b.x = vec_conversion<uint32_t, float2>(val);
// val.x = a.y.x;
// val.y = a.y.y;
// b.y = vec_conversion<uint32_t, float2>(val);
// return b;
// }
// // Float4 -> float4
// template <>
// __inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a) {
// float4 b;
// b.x = a.x.x;
// b.y = a.x.y;
// b.z = a.y.x;
// b.w = a.y.y;
// return b;
// }
// // Float8 -> half2x4
// template <>
// __inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a) {
// uint4 b;
// b.x = vec_conversion<uint32_t, float2>(a.x);
// b.y = vec_conversion<uint32_t, float2>(a.y);
// b.z = vec_conversion<uint32_t, float2>(a.z);
// b.w = vec_conversion<uint32_t, float2>(a.w);
// return b;
// }
// // float2 -> bfloat162
// template <>
// __inline__ __device__ __nv_bfloat162
// vec_conversion<__nv_bfloat162, float2>(const float2& a) {
// __nv_bfloat162 b = __float22bfloat162_rn(a);
// return b;
// }
// // Float4 -> bfloat162x2
// template <>
// __inline__ __device__ bf16_4_t
// vec_conversion<bf16_4_t, Float4_>(const Float4_& a) {
// bf16_4_t b;
// b.x = __float22bfloat162_rn(a.x);
// b.y = __float22bfloat162_rn(a.y);
// return b;
// }
// // Float8 -> bfloat162x4
// template <>
// __inline__ __device__ bf16_8_t
// vec_conversion<bf16_8_t, Float8_>(const Float8_& a) {
// bf16_8_t b;
// b.x = __float22bfloat162_rn(a.x);
// b.y = __float22bfloat162_rn(a.y);
// b.z = __float22bfloat162_rn(a.z);
// b.w = __float22bfloat162_rn(a.w);
// return b;
// }
/* Scaled and vectorized conversions, for data exchange between high and low
precision domains
......@@ -338,42 +370,47 @@ using __nv_bfloat16 = __hip_bfloat16;
// fp8 -> __nv_bfloat16
template <>
__inline__ __device__ __nv_bfloat16
scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, float scale) {
fp8_type f8;
f8.__x = a;
return __float2bfloat16(static_cast<float>(f8) * scale);
scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, float scale, Fp8KVCacheDataType kv_type) {
if (kv_type == vllm::Fp8KVCacheDataType::kFp8E5M2) {
assert(false);
}
return __float2bfloat16(fp8_to_float(a) * scale);
// fp8_type f8;
// f8.__x = a;
// return __float2bfloat16(static_cast<float>(f8) * scale);
}
// fp8x2 -> __nv_bfloat162
template <>
__inline__ __device__ __nv_bfloat162
scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a,
float scale) {
float scale, Fp8KVCacheDataType kv_type) {
__nv_bfloat162 res;
res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale);
res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale, kv_type);
res.y =
scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale);
scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale, kv_type);
return res;
}
// fp8x4 -> bf16_4_t
template <>
__inline__ __device__ bf16_4_t
scaled_vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a, float scale) {
scaled_vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a, float scale, Fp8KVCacheDataType kv_type) {
bf16_4_t res;
res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale);
res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale, kv_type);
res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U),
scale);
scale, kv_type);
return res;
}
// fp8x8 -> bf16_8_t
template <>
__inline__ __device__ bf16_8_t
scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, float scale) {
scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, float scale, Fp8KVCacheDataType kv_type) {
bf16_4_t tmp1, tmp2;
tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale);
tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale);
tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale, kv_type);
tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale, kv_type);
bf16_8_t res;
res.x = tmp1.x;
res.y = tmp1.y;
......@@ -385,46 +422,55 @@ scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, float scale) {
// fp8 -> float
template <>
__inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
const uint8_t& a, float scale) {
fp8_type f8;
f8.__x = a;
return static_cast<float>(f8) * scale;
const uint8_t& a, float scale, Fp8KVCacheDataType kv_type) {
if (kv_type == vllm::Fp8KVCacheDataType::kFp8E5M2) {
assert(false);
}
return fp8_to_float(a) * scale;
// fp8_type f8;
// f8.__x = a;
// return static_cast<float>(f8) * scale;
}
// fp8x2 -> float2
template <>
__inline__ __device__ float2
scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, float scale) {
fp8x2_type f8x2;
f8x2.__x = a;
return static_cast<float2>(f8x2) * scale;
scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, float scale, Fp8KVCacheDataType kv_type) {
float2 f2r;
f2r.x = scaled_vec_conversion<float, uint8_t>((uint8_t)a, scale, kv_type);
f2r.y = scaled_vec_conversion<float, uint8_t>((uint8_t)(a >> 8U), scale, kv_type);
return f2r;
// [[maybe_unused]]
// fp8x2_type f8x2;
// f8x2.__x = a;
// return static_cast<float2>(f8x2) * scale;
}
// fp8x4 -> float4
template <>
__inline__ __device__ Float4_
scaled_vec_conversion<Float4_, uint32_t>(const uint32_t& a, const float scale) {
scaled_vec_conversion<Float4_, uint32_t>(const uint32_t& a, const float scale, Fp8KVCacheDataType kv_type) {
Float4_ res;
res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale);
res.y = scaled_vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), scale);
res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale, kv_type);
res.y = scaled_vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), scale, kv_type);
return res;
}
// fp8x4 -> float4
template <>
__inline__ __device__ float4
scaled_vec_conversion<float4, uint32_t>(const uint32_t& a, float scale) {
Float4_ res = scaled_vec_conversion<Float4_, uint32_t>(a, scale);
scaled_vec_conversion<float4, uint32_t>(const uint32_t& a, float scale, Fp8KVCacheDataType kv_type) {
Float4_ res = scaled_vec_conversion<Float4_, uint32_t>(a, scale, kv_type);
return {res.x.x, res.x.y, res.y.x, res.y.y};
}
// fp8x8 -> float8
template <>
__inline__ __device__ Float8_
scaled_vec_conversion<Float8_, uint2>(const uint2& a, float scale) {
scaled_vec_conversion<Float8_, uint2>(const uint2& a, float scale, Fp8KVCacheDataType kv_type) {
Float4_ tmp1, tmp2;
tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale);
tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale);
tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale, kv_type);
tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale, kv_type);
Float8_ res;
res.x = tmp1.x;
res.y = tmp1.y;
......@@ -436,200 +482,249 @@ scaled_vec_conversion<Float8_, uint2>(const uint2& a, float scale) {
// fp8 -> half
template <>
__inline__ __device__ uint16_t
scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, float scale) {
__half_raw res;
res.data = scaled_vec_conversion<float, uint8_t>(a, scale);
return res.x;
scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, float scale, Fp8KVCacheDataType kv_type) {
if (kv_type == vllm::Fp8KVCacheDataType::kFp8E5M2) {
assert(false);
}
float res = fp8_to_float(a) * scale;
return float_to_half(res);
// __half_raw res;
// res.data = scaled_vec_conversion<float, uint8_t>(a, scale);
// return res.x;
}
// fp8x2 -> half2
template <>
__inline__ __device__ uint32_t
scaled_vec_conversion<uint32_t, uint16_t>(const uint16_t& a, float scale) {
scaled_vec_conversion<uint32_t, uint16_t>(const uint16_t& a, float scale, Fp8KVCacheDataType kv_type) {
union {
__half2_raw h2r;
uint32_t ui32;
} tmp;
tmp.h2r = __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
tmp.h2r.x.data *= scale;
tmp.h2r.y.data *= scale;
return tmp.ui32;
uint16_t u16[2];
uint32_t u32;
} res;
res.u16[0] = scaled_vec_conversion<uint16_t, uint8_t>((uint8_t)a, scale, kv_type);
res.u16[1] = scaled_vec_conversion<uint16_t, uint8_t>((uint8_t)(a >> 8U), scale, kv_type);
return res.u32;
// [[maybe_unused]] __half2_raw h2r =
// __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
// union {
// __half2_raw h2r;
// uint32_t ui32;
// } tmp;
// tmp.h2r = __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
// tmp.h2r.x.data *= scale;
// tmp.h2r.y.data *= scale;
// return tmp.ui32;
}
// fp8x4 -> half2x2
template <>
__inline__ __device__ uint2
scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, float scale) {
scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, float scale, Fp8KVCacheDataType kv_type) {
union {
uint2 u32x2;
uint32_t u32[2];
} tmp;
tmp.u32[0] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale);
tmp.u32[1] =
scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), scale);
tmp.u32[0] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale, kv_type);
tmp.u32[1] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), scale, kv_type);
return tmp.u32x2;
}
// fp8x8 -> half2x4
template <>
__inline__ __device__ uint4 scaled_vec_conversion<uint4, uint2>(const uint2& a,
float scale) {
float scale, Fp8KVCacheDataType kv_type) {
union {
uint4 u64x2;
uint2 u64[2];
} tmp;
tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale);
tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale);
tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale, kv_type);
tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale, kv_type);
return tmp.u64x2;
}
// half -> fp8
template <>
__inline__ __device__ uint8_t
scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, float scale) {
__half_raw tmp;
tmp.x = a;
tmp.data /= scale;
return __hip_cvt_halfraw_to_fp8(tmp, fp8_type::__default_saturation,
fp8_type::__default_interpret);
scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, float scale, Fp8KVCacheDataType kv_type) {
float res_f = half_to_float(a) / scale;
if (kv_type == vllm::Fp8KVCacheDataType::kFp8E4M3) {
return float_to_fp8_e4m3(res_f);
} else {
return float_to_fp8_e5m2(res_f);
}
// __half_raw tmp;
// tmp.x = a;
// tmp.data /= scale;
// return __hip_cvt_halfraw_to_fp8(tmp, fp8_type::__default_saturation,
// fp8_type::__default_interpret);
}
// halfx2 -> fp8x2
template <>
__inline__ __device__ uint16_t
scaled_vec_conversion<uint16_t, uint32_t>(const uint32_t& a, float scale) {
scaled_vec_conversion<uint16_t, uint32_t>(const uint32_t& a, float scale, Fp8KVCacheDataType kv_type) {
union {
uint32_t ui32;
__half2_raw h2r;
uint8_t ui8[2];
uint16_t ui16;
} tmp;
tmp.ui32 = a;
tmp.h2r.x.data /= scale;
tmp.h2r.y.data /= scale;
return __hip_cvt_halfraw2_to_fp8x2(tmp.h2r, fp8_type::__default_saturation,
fp8_type::__default_interpret);
union {
uint32_t ui32;
half2 h2r;
} tmp_a;
tmp_a.ui32 = a;
tmp.ui8[0] = scaled_vec_conversion<uint8_t, uint16_t>(tmp_a.h2r.data[0], scale, kv_type);
tmp.ui8[1] = scaled_vec_conversion<uint8_t, uint16_t>(tmp_a.h2r.data[1], scale, kv_type);
return tmp.ui16;
// union {
// uint32_t ui32;
// __half2_raw h2r;
// } tmp;
// tmp.ui32 = a;
// tmp.h2r.x.data /= scale;
// tmp.h2r.y.data /= scale;
// return __hip_cvt_halfraw2_to_fp8x2(tmp.h2r, fp8_type::__default_saturation,
// fp8_type::__default_interpret);
}
// half2x2 -> fp8x4
template <>
__inline__ __device__ uint32_t
scaled_vec_conversion<uint32_t, uint2>(const uint2& a, float scale) {
scaled_vec_conversion<uint32_t, uint2>(const uint2& a, float scale, Fp8KVCacheDataType kv_type) {
union {
uint16_t ui16[2];
uint32_t ui32;
} tmp;
tmp.ui16[0] = scaled_vec_conversion<uint16_t, uint32_t>(a.x, scale);
tmp.ui16[1] = scaled_vec_conversion<uint16_t, uint32_t>(a.y, scale);
tmp.ui16[0] = scaled_vec_conversion<uint16_t, uint32_t>(a.x, scale, kv_type);
tmp.ui16[1] = scaled_vec_conversion<uint16_t, uint32_t>(a.y, scale, kv_type);
return tmp.ui32;
}
// half2x4 -> fp8x8
template <>
__inline__ __device__ uint2 scaled_vec_conversion<uint2, uint4>(const uint4& a,
float scale) {
float scale, Fp8KVCacheDataType kv_type) {
union {
uint2 ui2[2];
uint4 ui4;
} tmp;
tmp.ui4 = a;
uint2 res;
res.x = scaled_vec_conversion<uint32_t, uint2>(tmp.ui2[0], scale);
res.y = scaled_vec_conversion<uint32_t, uint2>(tmp.ui2[1], scale);
res.x = scaled_vec_conversion<uint32_t, uint2>(tmp.ui2[0], scale, kv_type);
res.y = scaled_vec_conversion<uint32_t, uint2>(tmp.ui2[1], scale, kv_type);
return res;
}
// bf16 -> fp8
template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
const __nv_bfloat16& a, float scale) {
return __hip_cvt_float_to_fp8(__bfloat162float(a) / scale,
fp8_type::__default_saturation,
fp8_type::__default_interpret);
const __nv_bfloat16& a, float scale, Fp8KVCacheDataType kv_type) {
float res_f = (static_cast<float>(a)) / scale;
if (kv_type == vllm::Fp8KVCacheDataType::kFp8E4M3) {
return float_to_fp8_e4m3(res_f);
} else {
return float_to_fp8_e5m2(res_f);
}
// return __hip_cvt_float_to_fp8(__bfloat162float(a) / scale,
// fp8_type::__default_saturation,
// fp8_type::__default_interpret);
}
// bf16x2 -> fp8x2
template <>
__inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, __nv_bfloat162>(
const __nv_bfloat162& a, float scale) {
const __nv_bfloat162& a, float scale, Fp8KVCacheDataType kv_type) {
union {
uint8_t ui8[2];
uint16_t ui16;
} tmp;
tmp.ui8[0] = scaled_vec_conversion<uint8_t, __nv_bfloat16>(a.x, scale);
tmp.ui8[1] = scaled_vec_conversion<uint8_t, __nv_bfloat16>(a.y, scale);
tmp.ui8[0] = scaled_vec_conversion<uint8_t, __nv_bfloat16>(a.x, scale, kv_type);
tmp.ui8[1] = scaled_vec_conversion<uint8_t, __nv_bfloat16>(a.y, scale, kv_type);
return tmp.ui16;
}
// bf16x4 -> fp8x4
template <>
__inline__ __device__ uint32_t
scaled_vec_conversion<uint32_t, bf16_4_t>(const bf16_4_t& a, float scale) {
scaled_vec_conversion<uint32_t, bf16_4_t>(const bf16_4_t& a, float scale, Fp8KVCacheDataType kv_type) {
union {
uint16_t ui16[2];
uint32_t ui32;
} tmp;
tmp.ui16[0] = scaled_vec_conversion<uint16_t, __nv_bfloat162>(a.x, scale);
tmp.ui16[1] = scaled_vec_conversion<uint16_t, __nv_bfloat162>(a.y, scale);
tmp.ui16[0] = scaled_vec_conversion<uint16_t, __nv_bfloat162>(a.x, scale, kv_type);
tmp.ui16[1] = scaled_vec_conversion<uint16_t, __nv_bfloat162>(a.y, scale, kv_type);
return tmp.ui32;
}
// bf16x8 -> fp8x8
template <>
__inline__ __device__ uint2
scaled_vec_conversion<uint2, bf16_8_t>(const bf16_8_t& a, float scale) {
scaled_vec_conversion<uint2, bf16_8_t>(const bf16_8_t& a, float scale, Fp8KVCacheDataType kv_type) {
uint2 res;
res.x = scaled_vec_conversion<uint32_t, bf16_4_t>({a.x, a.y}, scale);
res.y = scaled_vec_conversion<uint32_t, bf16_4_t>({a.z, a.w}, scale);
res.x = scaled_vec_conversion<uint32_t, bf16_4_t>({a.x, a.y}, scale, kv_type);
res.y = scaled_vec_conversion<uint32_t, bf16_4_t>({a.z, a.w}, scale, kv_type);
return res;
}
// float -> fp8
template <>
__inline__ __device__ uint8_t
scaled_vec_conversion<uint8_t, float>(const float& a, float scale) {
return __hip_cvt_float_to_fp8(a / scale, fp8_type::__default_saturation,
fp8_type::__default_interpret);
scaled_vec_conversion<uint8_t, float>(const float& a, float scale, Fp8KVCacheDataType kv_type) {
if (kv_type == vllm::Fp8KVCacheDataType::kFp8E4M3) {
return float_to_fp8_e4m3(a / scale);
} else {
return float_to_fp8_e5m2(a / scale);
}
// return __hip_cvt_float_to_fp8(a / scale, fp8_type::__default_saturation,
// fp8_type::__default_interpret);
}
// floatx2 -> fp8x2
template <>
__inline__ __device__ uint16_t
scaled_vec_conversion<uint16_t, float2>(const float2& a, float scale) {
return __hip_cvt_float2_to_fp8x2(a / scale, fp8_type::__default_saturation,
fp8_type::__default_interpret);
scaled_vec_conversion<uint16_t, float2>(const float2& a, float scale, Fp8KVCacheDataType kv_type) {
union {
uint8_t ui8[2];
uint16_t ui16;
} tmp;
tmp.ui8[0] = scaled_vec_conversion<uint8_t, float>(a.x, scale, kv_type);
tmp.ui8[1] = scaled_vec_conversion<uint8_t, float>(a.y, scale, kv_type);
return tmp.ui16;
// return __hip_cvt_float2_to_fp8x2(a / scale, fp8_type::__default_saturation,
// fp8_type::__default_interpret);
}
// floatx4 -> fp8x4
template <>
__inline__ __device__ uint32_t
scaled_vec_conversion<uint32_t, float4>(const float4& a, float scale) {
scaled_vec_conversion<uint32_t, float4>(const float4& a, float scale, Fp8KVCacheDataType kv_type) {
union {
uint16_t ui16[2];
uint32_t ui32;
} tmp;
tmp.ui16[0] = scaled_vec_conversion<uint16_t, float2>({a.x, a.y}, scale);
tmp.ui16[1] = scaled_vec_conversion<uint16_t, float2>({a.z, a.w}, scale);
tmp.ui16[0] = scaled_vec_conversion<uint16_t, float2>({a.x, a.y}, scale, kv_type);
tmp.ui16[1] = scaled_vec_conversion<uint16_t, float2>({a.z, a.w}, scale, kv_type);
return tmp.ui32;
}
#endif // ENABLE_FP8
// #endif // ENABLE_FP8
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout convert(const Tin& x) {
#ifdef ENABLE_FP8
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
return vec_conversion<Tout, Tin>(x);
}
#endif
assert(false);
return {}; // Squash missing return statement warning
}
// template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
// __inline__ __device__ Tout convert(const Tin& x) {
// #ifdef ENABLE_FP8
// if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
// return vec_conversion<Tout, Tin>(x);
// }
// #endif
// assert(false);
// return {}; // Squash missing return statement warning
// }
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
#ifdef ENABLE_FP8
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
return scaled_vec_conversion<Tout, Tin>(x, scale);
// #ifdef ENABLE_FP8
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3 || kv_dt == Fp8KVCacheDataType::kFp8E5M2) {
return scaled_vec_conversion<Tout, Tin>(x, scale, kv_dt);
}
#endif
// #endif
assert(false);
return {}; // Squash missing return statement warning
}
......@@ -652,19 +747,31 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else if (KV_CACHE_DTYPE == vllm::Fp8KVCacheDataType::kFp8E4M3) { \
if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else { \
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else { \
TORCH_CHECK(false, \
"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else if (KV_CACHE_DTYPE == vllm::Fp8KVCacheDataType::kFp8E5M2) { \
if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
} else { \
TORCH_CHECK(false, \
"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else { \
TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
}
}
} // namespace fp8
#endif // USE_ROCM
} // namespace vllm
} // namespace vllm
\ No newline at end of file
......@@ -47,15 +47,19 @@ __device__ __forceinline__ fp8_type scaled_fp8_conversion(float const val,
x = val / scale;
}
float r =
fmaxf(-quant_type_max_v<fp8_type>, fminf(x, quant_type_max_v<fp8_type>));
// float r =
// fmaxf(-quant_type_max_v<fp8_type>, fminf(x, quant_type_max_v<fp8_type>));
#ifndef USE_ROCM
// Use hardware cvt instruction for fp8 on nvidia
// Currently only support fp8_type = c10::Float8_e4m3fn
return fp8::vec_conversion<fp8_type, float>(r);
#else
fp8_type *test;
uint8_t test_uint8 = fp8::float_to_fp8_e4m3(x);
test = (fp8_type*)(&test_uint8);
return *test;
// Use hardware cvt instruction for fp8 on rocm
return fp8::cvt_c10<fp8_type>(r);
// return fp8::cvt_c10<fp8_type>(r);
#endif
}
......
......@@ -16,8 +16,13 @@ packaging>=24.2
setuptools>=77.0.3,<80.0.0
setuptools-scm>=8
runai-model-streamer[s3,gcs,azure]==0.15.7
conch-triton-kernels==1.2.1
# conch-triton-kernels==1.2.1
timm>=1.0.17
# amd-quark: required for Quark quantization on ROCm
# To be consistent with test_quark.py
amd-quark>=0.8.99
# Other necessary dependencies
torch == 2.10.0
torchvision == 0.25.0
flash_attn == 2.8.3
......@@ -20,6 +20,12 @@ from setuptools import Extension, setup
from setuptools.command.build_ext import build_ext
from setuptools_scm import get_version
from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME
from typing import Optional, Union
pwd = os.path.dirname(os.path.abspath(__file__))
add_git_version = False
if int(os.environ.get('ADD_GIT_VERSION', '0')) == 1:
add_git_version = True
def load_module_from_path(module_name, path):
......@@ -365,7 +371,7 @@ class cmake_build_ext(build_ext):
os.makedirs(os.path.dirname(dst_file), exist_ok=True)
self.copy_file(file, dst_file)
if _is_cuda() or _is_hip():
if _is_cuda():
# copy vllm/third_party/triton_kernels/**/*.py from self.build_lib
# to current directory so that they can be included in the editable
# build
......@@ -895,6 +901,94 @@ def get_nvcc_cuda_version() -> Version:
return nvcc_cuda_version
def get_sha(root: Union[str, Path]) -> str:
try:
return subprocess.check_output(['git', 'rev-parse', 'HEAD'], cwd=root).decode('ascii').strip()
except Exception:
return 'Unknown'
def get_version_add(sha: Optional[str] = None) -> str:
command = "git config --global --add safe.directory "+pwd
subprocess.run(command, shell=True, capture_output=False, text=True)
vllm_root = os.path.dirname(os.path.abspath(__file__))
add_version_path = os.path.join(os.path.join(vllm_root, "vllm"), "version.py")
major, minor, _ = torch.__version__.split('.')
if add_git_version:
if sha != 'Unknown':
if sha is None:
sha = get_sha(vllm_root)
version = 'das.' + sha[:7]
else:
version = 'das'
# dtk version
if os.getenv("ROCM_PATH"):
rocm_path = os.getenv('ROCM_PATH', "")
rocm_version_path = os.path.join(rocm_path, '.info', "rocm_version")
with open(rocm_version_path, 'r',encoding='utf-8') as file:
lines = file.readlines()
rocm_version=lines[0].replace(".", "")
version += ".dtk" + rocm_version
new_version_content = f"""
try:
__version__ = "0.20.0"
__version_tuple__ = (0, 20, 0)
__hcu_version__ = f'0.20.0+{version}'
from vllm.version import __version__, __version_tuple__, __hcu_version__
except Exception as e:
import warnings
warnings.warn(f"Failed to read commit hash:\\n + str(e)",
RuntimeWarning,
stacklevel=2)
__version__ = "dev"
__version_tuple__ = (0, 0, __version__)
def _prev_minor_version_was(version_str):
'''Check whether a given version matches the previous minor version.
Return True if version_str matches the previous minor version.
For example - return True if the current version if 0.7.4 and the
supplied version_str is '0.6'.
Used for --show-hidden-metrics-for-version.
'''
# Match anything if this is a dev tree
if __version_tuple__[0:2] == (0, 0):
return True
# Note - this won't do the right thing when we release 1.0!
# assert __version_tuple__[0] == 0
assert isinstance(__version_tuple__[1], int)
return version_str == f"{{__version_tuple__[0]}}.{{__version_tuple__[1] - 1}}"
def _prev_minor_version():
'''For the purpose of testing, return a previous minor version number.'''
# In dev tree, this will return "0.-1", but that will work fine"
assert isinstance(__version_tuple__[1], int)
return f"{{__version_tuple__[0]}}.{{__version_tuple__[1] - 1}}"
"""
with open(add_version_path, encoding="utf-8",mode="w") as file:
file.write(new_version_content)
file.close()
def get_version():
get_version_add()
version_file = 'vllm/version.py'
with open(version_file, encoding='utf-8') as f:
exec(compile(f.read(), version_file, 'exec'))
return locals()['__hcu_version__']
def get_vllm_version() -> str:
# Allow overriding the version. This is useful to build platform-specific
# wheels (e.g. CPU, TPU) without modifying the source.
......@@ -903,8 +997,9 @@ def get_vllm_version() -> str:
os.environ["SETUPTOOLS_SCM_PRETEND_VERSION"] = env_version
return get_version(write_to="vllm/_version.py")
version = get_version(write_to="vllm/_version.py")
sep = "+" if "+" not in version else "." # dev versions might contain +
if not _is_hip():
version = get_version(write_to="vllm/_version.py")
sep = "+" if "+" not in version else "." # dev versions might contain +
if _no_device():
if envs.VLLM_TARGET_DEVICE == "empty":
......@@ -921,9 +1016,10 @@ def get_vllm_version() -> str:
version += f"{sep}cu{cuda_version_str}"
elif _is_hip():
# Get the Rocm Version
rocm_version = get_rocm_version() or torch.version.hip
if rocm_version and rocm_version != envs.VLLM_MAIN_CUDA_VERSION:
version += f"{sep}rocm{rocm_version.replace('.', '')[:3]}"
# rocm_version = get_rocm_version() or torch.version.hip
# if rocm_version and rocm_version != envs.VLLM_MAIN_CUDA_VERSION:
# version += f"{sep}rocm{rocm_version.replace('.', '')[:3]}"
version = get_version()
elif _is_tpu():
version += f"{sep}tpu"
elif _is_cpu():
......@@ -991,7 +1087,7 @@ if _is_cuda() or _is_hip():
ext_modules.append(CMakeExtension(name="vllm.cumem_allocator"))
# Optional since this doesn't get built (produce an .so file). This is just
# copying the relevant .py files from the source repository.
ext_modules.append(CMakeExtension(name="vllm.triton_kernels", optional=True))
# ext_modules.append(CMakeExtension(name="vllm.triton_kernels", optional=True))
if _is_hip():
ext_modules.append(CMakeExtension(name="vllm._rocm_C"))
......
......@@ -44,10 +44,10 @@ except ImportError as e:
logger.warning("Failed to import from vllm._C with %r", e)
# import custom ops, trigger op registration
try:
import vllm._rocm_C # noqa: F401
except ImportError as e:
logger.warning("Failed to import from vllm._rocm_C with %r", e)
# try:
# import vllm._rocm_C # noqa: F401
# except ImportError as e:
# logger.warning("Failed to import from vllm._rocm_C with %r", e)
# Models not supported by ROCm.
_ROCM_UNSUPPORTED_MODELS: list[str] = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment