#pragma once #if USE_ROCM #include // Adapted from flashinfer-rocm [PR#491](https://github.com/flashinfer-ai/flashinfer/pull/491) namespace sgl_hip { template <> struct vec_t { float data; SGL_HIP_INLINE float& operator[](size_t i) { return ((float*)(&data))[i]; } SGL_HIP_INLINE const float& operator[](size_t i) const { return ((const float*)(&data))[i]; } SGL_HIP_INLINE float* ptr() { return reinterpret_cast(&data); } SGL_HIP_INLINE void load(const float* ptr); SGL_HIP_INLINE void store(float* ptr) const; template SGL_HIP_INLINE void cast_from(const vec_t& src) { cast_from_impl(*this, src); } template SGL_HIP_INLINE void cast_load(const T* ptr) { cast_load_impl(*this, ptr); } template SGL_HIP_INLINE void cast_store(T* ptr) const { cast_store_impl(ptr, *this); } }; SGL_HIP_INLINE void vec_t::load(const float* ptr) { data = *ptr; } SGL_HIP_INLINE void vec_t::store(float* ptr) const { *ptr = data; } // float x 2 template <> struct vec_t { float2 data; SGL_HIP_INLINE float& operator[](size_t i) { return ((float*)(&data))[i]; } SGL_HIP_INLINE const float& operator[](size_t i) const { return ((const float*)(&data))[i]; } SGL_HIP_INLINE float* ptr() { return reinterpret_cast(&data); } SGL_HIP_INLINE void load(const float* ptr); SGL_HIP_INLINE void store(float* ptr) const; template SGL_HIP_INLINE void cast_from(const vec_t& src) { cast_from_impl(*this, src); } template SGL_HIP_INLINE void cast_load(const T* ptr) { cast_load_impl(*this, ptr); } template SGL_HIP_INLINE void cast_store(T* ptr) const { cast_store_impl(ptr, *this); } }; SGL_HIP_INLINE void vec_t::load(const float* ptr) { data = *((float2*)ptr); } SGL_HIP_INLINE void vec_t::store(float* ptr) const { *((float2*)ptr) = data; } // float x 4 or more template struct vec_t { float4 data[vec_size / 4]; SGL_HIP_INLINE float& operator[](size_t i) { return ((float*)(data))[i]; } SGL_HIP_INLINE const float& operator[](size_t i) const { return ((const float*)(data))[i]; } SGL_HIP_INLINE float* ptr() { return reinterpret_cast(&data); } SGL_HIP_INLINE void load(const float* ptr) { #pragma unroll for (size_t i = 0; i < vec_size / 4; ++i) { data[i] = ((float4*)ptr)[i]; } } SGL_HIP_INLINE void store(float* ptr) const { #pragma unroll for (size_t i = 0; i < vec_size / 4; ++i) { ((float4*)ptr)[i] = data[i]; } } template SGL_HIP_INLINE void cast_from(const vec_t& src) { cast_from_impl(*this, src); } template SGL_HIP_INLINE void cast_load(const T* ptr) { cast_load_impl(*this, ptr); } template SGL_HIP_INLINE void cast_store(T* ptr) const { cast_store_impl(ptr, *this); } }; } // namespace sgl_hip #endif