Commit a10e9cee authored by zhuwenwen's avatar zhuwenwen
Browse files

support bf16b infer

parent 675c0abe
......@@ -87,40 +87,40 @@ struct FloatVec<bf16_8_t> {
// Utility functions for type conversions.
inline __device__ float2 bf1622float2(const __nv_bfloat162 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
// assert(false);
// #else
return __bfloat1622float2(val);
#endif
// #endif
}
inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
// assert(false);
// #else
return __bfloat162bfloat162(val);
#endif
// #endif
}
// Vector addition.
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
// assert(false);
// #else
#ifndef USE_ROCM
return a + b;
#else
return __hadd(a, b);
#endif
#endif
// #endif
}
inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
// assert(false);
// #else
return __hadd2(a, b);
#endif
// #endif
}
inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) {
......@@ -163,20 +163,20 @@ inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) {
// Vector multiplication.
template<>
inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
// assert(false);
// #else
return __hmul(a, b);
#endif
// #endif
}
template<>
inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
// assert(false);
// #else
return __hmul2(a, b);
#endif
// #endif
}
template<>
......@@ -281,19 +281,19 @@ inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) {
// Vector fused multiply-add.
inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
// assert(false);
// #else
return __hfma2(a, b, c);
#endif
// #endif
}
inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
// assert(false);
// #else
return __hfma2(bf162bf162(a), b, c);
#endif
// #endif
}
inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) {
......@@ -406,31 +406,31 @@ inline __device__ void from_float(__nv_bfloat16& dst, float src) {
}
inline __device__ void from_float(__nv_bfloat162& dst, float2 src) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
// assert(false);
// #else
dst = __float22bfloat162_rn(src);
#endif
// #endif
}
inline __device__ void from_float(bf16_4_t& dst, Float4_ src) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
// assert(false);
// #else
dst.x = __float22bfloat162_rn(src.x);
dst.y = __float22bfloat162_rn(src.y);
#endif
// #endif
}
inline __device__ void from_float(bf16_8_t& dst, Float8_ src) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
// assert(false);
// #else
dst.x = __float22bfloat162_rn(src.x);
dst.y = __float22bfloat162_rn(src.y);
dst.z = __float22bfloat162_rn(src.z);
dst.w = __float22bfloat162_rn(src.w);
#endif
// #endif
}
// From bfloat16 to float32.
......@@ -440,12 +440,12 @@ inline __device__ float to_float(__nv_bfloat16 u) {
// Zero-out a variable.
inline __device__ void zero(__nv_bfloat16& dst) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
// assert(false);
// #else
// Same as CUDART_ZERO_BF16 introduced in CUDA 12.2.
dst = __ushort_as_bfloat16((unsigned short)0x0000U);
#endif
// #endif
}
} // namespace vllm
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment