#pragma once

#if USE_ROCM

// NOTE(review): the header names below were unreadable in the reviewed copy;
// <hip/hip_bf16.h> supplies __hip_bfloat16/__hip_bfloat162 -- confirm both
// names against the build before merging.
#include <hip/hip_bf16.h>
#include <hip/hip_common.h>

// Adapted from flashinfer-rocm [PR#491](https://github.com/flashinfer-ai/flashinfer/pull/491)
//
// bfloat16 specializations of sgl_hip::vec_t, compiled only when USE_ROCM is
// set. The primary vec_t template, SGL_HIP_INLINE and the cast_*_impl helpers
// are not defined in this file and must already be visible where it is
// included.

// CUDA-style spellings for the HIP bfloat16 types.
using nv_bfloat16 = __hip_bfloat16;
using nv_bfloat162 = __hip_bfloat162;

// Packs two bfloat16 scalars into one __hip_bfloat162 (x first, y second).
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 make_bfloat162(const __hip_bfloat16 x, const __hip_bfloat16 y) {
  __hip_bfloat162 t;
  t.x = x;
  t.y = y;
  return t;
}

namespace sgl_hip {

// nv_bfloat16 x 1
// Storage is a single nv_bfloat16 scalar.
template <>
struct vec_t<nv_bfloat16, 1> {
  nv_bfloat16 data;

  // Element access; no bounds checking is performed.
  SGL_HIP_INLINE nv_bfloat16& operator[](size_t i) {
    return reinterpret_cast<nv_bfloat16*>(&data)[i];
  }
  SGL_HIP_INLINE const nv_bfloat16& operator[](size_t i) const {
    return reinterpret_cast<const nv_bfloat16*>(&data)[i];
  }
  // Pointer to the first element of the storage.
  SGL_HIP_INLINE nv_bfloat16* ptr() {
    return reinterpret_cast<nv_bfloat16*>(&data);
  }
  // Copy the whole vector from / to memory at `ptr` (defined after the struct).
  SGL_HIP_INLINE void load(const nv_bfloat16* ptr);
  SGL_HIP_INLINE void store(nv_bfloat16* ptr) const;
  // Element-type conversions; each forwards to the matching cast_*_impl helper.
  template <typename T>
  SGL_HIP_INLINE void cast_from(const vec_t<T, 1>& src) {
    cast_from_impl(*this, src);
  }
  template <typename T>
  SGL_HIP_INLINE void cast_load(const T* ptr) {
    cast_load_impl(*this, ptr);
  }
  template <typename T>
  SGL_HIP_INLINE void cast_store(T* ptr) const {
    cast_store_impl(ptr, *this);
  }
};

SGL_HIP_INLINE void vec_t<nv_bfloat16, 1>::load(const nv_bfloat16* ptr) {
  data = *ptr;
}

SGL_HIP_INLINE void vec_t<nv_bfloat16, 1>::store(nv_bfloat16* ptr) const {
  *ptr = data;
}

// nv_bfloat16 x 2
// Storage is one packed nv_bfloat162.
template <>
struct vec_t<nv_bfloat16, 2> {
  nv_bfloat162 data;

  // Element access; no bounds checking is performed.
  SGL_HIP_INLINE nv_bfloat16& operator[](size_t i) {
    return reinterpret_cast<nv_bfloat16*>(&data)[i];
  }
  SGL_HIP_INLINE const nv_bfloat16& operator[](size_t i) const {
    return reinterpret_cast<const nv_bfloat16*>(&data)[i];
  }
  // Pointer to the first element of the storage.
  SGL_HIP_INLINE nv_bfloat16* ptr() {
    return reinterpret_cast<nv_bfloat16*>(&data);
  }
  // Copy the whole vector from / to memory at `ptr` (defined after the struct).
  SGL_HIP_INLINE void load(const nv_bfloat16* ptr);
  SGL_HIP_INLINE void store(nv_bfloat16* ptr) const;
  // Element-type conversions; each forwards to the matching cast_*_impl helper.
  template <typename T>
  SGL_HIP_INLINE void cast_from(const vec_t<T, 2>& src) {
    cast_from_impl(*this, src);
  }
  template <typename T>
  SGL_HIP_INLINE void cast_load(const T* ptr) {
    cast_load_impl(*this, ptr);
  }
  template <typename T>
  SGL_HIP_INLINE void cast_store(T* ptr) const {
    cast_store_impl(ptr, *this);
  }
};

// The pointer is reinterpreted as nv_bfloat162*, so it presumably has to be
// aligned for that type -- TODO confirm with callers.
SGL_HIP_INLINE void vec_t<nv_bfloat16, 2>::load(const nv_bfloat16* ptr) {
  data = *reinterpret_cast<const nv_bfloat162*>(ptr);
}

SGL_HIP_INLINE void vec_t<nv_bfloat16, 2>::store(nv_bfloat16* ptr) const {
  *reinterpret_cast<nv_bfloat162*>(ptr) = data;
}

// nv_bfloat16 x 4
// Storage is one uint2, moved as a single unit by load/store.
template <>
struct vec_t<nv_bfloat16, 4> {
  uint2 data;

  // Element access; no bounds checking is performed.
  SGL_HIP_INLINE nv_bfloat16& operator[](size_t i) {
    return reinterpret_cast<nv_bfloat16*>(&data)[i];
  }
  SGL_HIP_INLINE const nv_bfloat16& operator[](size_t i) const {
    return reinterpret_cast<const nv_bfloat16*>(&data)[i];
  }
  // Pointer to the first element of the storage.
  SGL_HIP_INLINE nv_bfloat16* ptr() {
    return reinterpret_cast<nv_bfloat16*>(&data);
  }
  // Copy the whole vector from / to memory at `ptr` (defined after the struct).
  SGL_HIP_INLINE void load(const nv_bfloat16* ptr);
  SGL_HIP_INLINE void store(nv_bfloat16* ptr) const;
  // Element-type conversions; each forwards to the matching cast_*_impl helper.
  template <typename T>
  SGL_HIP_INLINE void cast_from(const vec_t<T, 4>& src) {
    cast_from_impl(*this, src);
  }
  template <typename T>
  SGL_HIP_INLINE void cast_load(const T* ptr) {
    cast_load_impl(*this, ptr);
  }
  template <typename T>
  SGL_HIP_INLINE void cast_store(T* ptr) const {
    cast_store_impl(ptr, *this);
  }
};

// The pointer is reinterpreted as uint2*, so it presumably has to be aligned
// for that type -- TODO confirm with callers.
SGL_HIP_INLINE void vec_t<nv_bfloat16, 4>::load(const nv_bfloat16* ptr) {
  data = *reinterpret_cast<const uint2*>(ptr);
}

SGL_HIP_INLINE void vec_t<nv_bfloat16, 4>::store(nv_bfloat16* ptr) const {
  *reinterpret_cast<uint2*>(ptr) = data;
}

// nv_bfloat16 x 8 or more
// Storage is an array of uint4 chunks, one chunk per 8 elements.
template <size_t vec_size>
struct vec_t<nv_bfloat16, vec_size> {
  // vec_size / 8 truncates, so any other size would leave elements without
  // backing storage and make operator[] index out of bounds.
  static_assert(vec_size % 8 == 0, "Invalid vector size");

  uint4 data[vec_size / 8];

  // Element access; no bounds checking is performed.
  SGL_HIP_INLINE nv_bfloat16& operator[](size_t i) {
    return reinterpret_cast<nv_bfloat16*>(data)[i];
  }
  SGL_HIP_INLINE const nv_bfloat16& operator[](size_t i) const {
    return reinterpret_cast<const nv_bfloat16*>(data)[i];
  }
  // Pointer to the first element of the storage.
  SGL_HIP_INLINE nv_bfloat16* ptr() {
    return reinterpret_cast<nv_bfloat16*>(&data);
  }
  // Copies vec_size elements from `ptr`, one uint4 chunk per iteration.
  // The pointer is reinterpreted as uint4*, so it presumably has to be
  // aligned for that type -- TODO confirm with callers.
  SGL_HIP_INLINE void load(const nv_bfloat16* ptr) {
    const uint4* chunks = reinterpret_cast<const uint4*>(ptr);
#pragma unroll
    for (size_t i = 0; i < vec_size / 8; ++i) {
      data[i] = chunks[i];
    }
  }
  // Copies vec_size elements to `ptr`, one uint4 chunk per iteration.
  SGL_HIP_INLINE void store(nv_bfloat16* ptr) const {
    uint4* chunks = reinterpret_cast<uint4*>(ptr);
#pragma unroll
    for (size_t i = 0; i < vec_size / 8; ++i) {
      chunks[i] = data[i];
    }
  }
  // Element-type conversions; each forwards to the matching cast_*_impl helper.
  template <typename T>
  SGL_HIP_INLINE void cast_from(const vec_t<T, vec_size>& src) {
    cast_from_impl(*this, src);
  }
  template <typename T>
  SGL_HIP_INLINE void cast_load(const T* ptr) {
    cast_load_impl(*this, ptr);
  }
  template <typename T>
  SGL_HIP_INLINE void cast_store(T* ptr) const {
    cast_store_impl(ptr, *this);
  }
};

}  // namespace sgl_hip
#endif