"docs/vscode:/vscode.git/clone" did not exist on "56b7f0efa4d83865afc2da38b40b2f337d778dda"
Commit a10e9cee authored by zhuwenwen's avatar zhuwenwen
Browse files

support bf16b infer

parent 675c0abe
...@@ -87,40 +87,40 @@ struct FloatVec<bf16_8_t> { ...@@ -87,40 +87,40 @@ struct FloatVec<bf16_8_t> {
// Utility functions for type conversions. // Utility functions for type conversions.
// Unpack a bf16 pair into a float2 (low half -> .x, high half -> .y).
// NOTE(review): the original `#if __CUDA_ARCH__ < 800 assert(false)` guard was
// commented out in this commit to support bf16 inference on pre-SM80 GPUs —
// presumably relying on cuda_bf16.h's software fallback; confirm CUDA >= 11.
inline __device__ float2 bf1622float2(const __nv_bfloat162 val) {
  return __bfloat1622float2(val);
}
// Broadcast a scalar bf16 into both lanes of a bf16 pair.
// NOTE(review): the SM80 arch guard (`assert(false)` on < 800) was commented
// out upstream to enable pre-SM80 execution — assumes the toolkit header
// supplies an emulated __bfloat162bfloat162; verify against the CUDA version.
inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) {
  return __bfloat162bfloat162(val);
}
// Vector addition. // Vector addition.
// Scalar bf16 addition. CUDA builds use the overloaded operator+ from
// cuda_bf16.h; ROCm does not provide that overload, so it falls back to the
// __hadd intrinsic.
// NOTE(review): the SM80 arch guard was commented out upstream (bf16 infer on
// older GPUs) — presumably safe via header-level emulation; TODO confirm.
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
#ifndef USE_ROCM
  return a + b;
#else
  return __hadd(a, b);
#endif
}
// Lane-wise addition of two bf16 pairs.
// NOTE(review): original `__CUDA_ARCH__ < 800` assert guard commented out by
// this commit; relies on __hadd2 being available (natively or emulated) on
// the target arch — confirm minimum CUDA toolkit version.
inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) {
  return __hadd2(a, b);
}
inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) { inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) {
...@@ -163,20 +163,20 @@ inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) { ...@@ -163,20 +163,20 @@ inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) {
// Vector multiplication. // Vector multiplication.
// Scalar bf16 multiplication (specialization of the generic mul template).
// NOTE(review): SM80 guard commented out upstream for pre-SM80 bf16 inference;
// assumes __hmul has a fallback path on older arches — verify.
template<>
inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) {
  return __hmul(a, b);
}
// Lane-wise multiplication of two bf16 pairs (specialization of mul).
// NOTE(review): SM80 guard commented out upstream; assumes __hmul2 is usable
// (natively or via emulation) on the deployment arch — verify.
template<>
inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
  return __hmul2(a, b);
}
template<> template<>
...@@ -281,19 +281,19 @@ inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) { ...@@ -281,19 +281,19 @@ inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) {
// Vector fused multiply-add. // Vector fused multiply-add.
// Lane-wise fused multiply-add on bf16 pairs: a * b + c.
// NOTE(review): the `assert(false)` guard for __CUDA_ARCH__ < 800 was
// commented out by this commit — presumably __hfma2 is emulated on older
// arches by cuda_bf16.h; confirm toolkit support.
inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) {
  return __hfma2(a, b, c);
}
// Fused multiply-add with a scalar multiplier: broadcasts `a` to both lanes
// via bf162bf162, then computes a * b + c lane-wise.
// NOTE(review): SM80 guard commented out upstream — same emulation assumption
// as the other bf16 intrinsics here; verify.
inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c) {
  return __hfma2(bf162bf162(a), b, c);
}
inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) { inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) {
...@@ -406,31 +406,31 @@ inline __device__ void from_float(__nv_bfloat16& dst, float src) { ...@@ -406,31 +406,31 @@ inline __device__ void from_float(__nv_bfloat16& dst, float src) {
} }
// Round-to-nearest conversion of a float2 into a packed bf16 pair.
// NOTE(review): `assert(false)` arch guard commented out by this commit to
// allow pre-SM80 use — __float22bfloat162_rn is assumed available there;
// confirm minimum CUDA version.
inline __device__ void from_float(__nv_bfloat162& dst, float2 src) {
  dst = __float22bfloat162_rn(src);
}
// Convert a Float4_ (two float2 halves) into a bf16_4_t, rounding to nearest
// even, one packed pair per half.
// NOTE(review): SM80 arch guard commented out upstream — same pre-SM80
// emulation assumption as from_float(__nv_bfloat162&, float2); verify.
inline __device__ void from_float(bf16_4_t& dst, Float4_ src) {
  dst.x = __float22bfloat162_rn(src.x);
  dst.y = __float22bfloat162_rn(src.y);
}
// Convert a Float8_ (four float2 quarters) into a bf16_8_t, rounding to
// nearest even, one packed bf16 pair per quarter.
// NOTE(review): SM80 arch guard commented out upstream — same pre-SM80
// emulation assumption as the other from_float overloads; verify.
inline __device__ void from_float(bf16_8_t& dst, Float8_ src) {
  dst.x = __float22bfloat162_rn(src.x);
  dst.y = __float22bfloat162_rn(src.y);
  dst.z = __float22bfloat162_rn(src.z);
  dst.w = __float22bfloat162_rn(src.w);
}
// From bfloat16 to float32. // From bfloat16 to float32.
...@@ -440,12 +440,12 @@ inline __device__ float to_float(__nv_bfloat16 u) { ...@@ -440,12 +440,12 @@ inline __device__ float to_float(__nv_bfloat16 u) {
// Zero-out a variable. // Zero-out a variable.
// Write a bf16 zero into dst by reinterpreting the all-zero bit pattern.
// Equivalent to CUDART_ZERO_BF16 (introduced in CUDA 12.2), spelled out so
// older toolkits compile.
// NOTE(review): SM80 arch guard commented out by this commit for pre-SM80
// support; __ushort_as_bfloat16 is a pure bit-cast — presumably arch-safe,
// but confirm header availability on the minimum supported CUDA.
inline __device__ void zero(__nv_bfloat16& dst) {
  dst = __ushort_as_bfloat16((unsigned short)0x0000U);
}
} // namespace vllm } // namespace vllm
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment