hip_vec_fp32_impl.h 3.03 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
#pragma once

#if USE_ROCM

#include <hip/hip_common.h>

// Adapted from flashinfer-rocm [PR#491](https://github.com/flashinfer-ai/flashinfer/pull/491)

namespace sgl_hip {

template <>
struct vec_t<float, 1> {
  // One float lane, held as a plain scalar.
  float data;

  // Element access through a float view of `data`.
  // Only index 0 refers to valid storage for this width.
  SGL_HIP_INLINE float& operator[](size_t idx) {
    float* lanes = reinterpret_cast<float*>(&data);
    return lanes[idx];
  }
  SGL_HIP_INLINE const float& operator[](size_t idx) const {
    const float* lanes = reinterpret_cast<const float*>(&data);
    return lanes[idx];
  }

  // Pointer to the underlying float storage.
  SGL_HIP_INLINE float* ptr() {
    return &data;
  }

  // Reads one float from `src` into this vector.
  SGL_HIP_INLINE void load(const float* src) {
    data = src[0];
  }

  // Writes this vector's single float to `dst`.
  SGL_HIP_INLINE void store(float* dst) const {
    dst[0] = data;
  }

  // Element-type conversions are forwarded to the cast_*_impl helpers,
  // which are defined outside this file.
  template <typename SrcT>
  SGL_HIP_INLINE void cast_from(const vec_t<SrcT, 1>& other) {
    cast_from_impl(*this, other);
  }
  template <typename SrcT>
  SGL_HIP_INLINE void cast_load(const SrcT* src) {
    cast_load_impl(*this, src);
  }
  template <typename DstT>
  SGL_HIP_INLINE void cast_store(DstT* dst) const {
    cast_store_impl(dst, *this);
  }
};

// float x 2

template <>
struct vec_t<float, 2> {
  // Two float lanes packed in a HIP float2.
  float2 data;

  // Element access through a flat float view of `data`; valid indices are 0 and 1.
  SGL_HIP_INLINE float& operator[](size_t idx) {
    return ptr()[idx];
  }
  SGL_HIP_INLINE const float& operator[](size_t idx) const {
    const float* lanes = reinterpret_cast<const float*>(&data);
    return lanes[idx];
  }

  // Pointer to the first float lane of `data`.
  SGL_HIP_INLINE float* ptr() {
    float* first_lane = reinterpret_cast<float*>(&data);
    return first_lane;
  }

  // Copies both lanes from `src` as a single float2 value.
  // NOTE(review): `src` must be suitably aligned for float2 — caller's responsibility.
  SGL_HIP_INLINE void load(const float* src) {
    const float2* packed = reinterpret_cast<const float2*>(src);
    data = *packed;
  }

  // Copies both lanes to `dst` as a single float2 value (same alignment caveat as load).
  SGL_HIP_INLINE void store(float* dst) const {
    float2* packed = reinterpret_cast<float2*>(dst);
    *packed = data;
  }

  // Element-type conversions are forwarded to the cast_*_impl helpers,
  // which are defined outside this file.
  template <typename SrcT>
  SGL_HIP_INLINE void cast_from(const vec_t<SrcT, 2>& other) {
    cast_from_impl(*this, other);
  }
  template <typename SrcT>
  SGL_HIP_INLINE void cast_load(const SrcT* src) {
    cast_load_impl(*this, src);
  }
  template <typename DstT>
  SGL_HIP_INLINE void cast_store(DstT* dst) const {
    cast_store_impl(dst, *this);
  }
};

// float x 4 or more
//
// Storage is an array of float4 chunks, so vec_size must be a multiple of 4
// (widths 1 and 2 are handled by the explicit specializations above). Without
// the static_assert below, e.g. vec_size == 6 would allocate a single float4:
// load/store would silently drop the trailing lanes and operator[] could index
// past the end of `data`.
template <size_t vec_size>
struct vec_t<float, vec_size> {
  static_assert(vec_size % 4 == 0,
                "vec_t<float, vec_size>: vec_size must be 1, 2, or a multiple of 4");

  float4 data[vec_size / 4];

  // Element access into the flattened float view of `data`.
  SGL_HIP_INLINE float& operator[](size_t i) {
    return reinterpret_cast<float*>(data)[i];
  }
  SGL_HIP_INLINE const float& operator[](size_t i) const {
    return reinterpret_cast<const float*>(data)[i];
  }

  // Pointer to the first float lane.
  SGL_HIP_INLINE float* ptr() {
    return reinterpret_cast<float*>(&data);
  }
  // Const overload, matching the const operator[] so read-only vectors can
  // also expose their lanes.
  SGL_HIP_INLINE const float* ptr() const {
    return reinterpret_cast<const float*>(&data);
  }

  // Copies vec_size floats from `src`, one float4 chunk per iteration.
  // NOTE(review): `src` must be suitably aligned for float4 — caller's responsibility.
  SGL_HIP_INLINE void load(const float* src) {
    const float4* chunks = reinterpret_cast<const float4*>(src);
#pragma unroll
    for (size_t i = 0; i < vec_size / 4; ++i) {
      data[i] = chunks[i];
    }
  }

  // Copies vec_size floats to `dst`, one float4 chunk per iteration
  // (same alignment requirement as load).
  SGL_HIP_INLINE void store(float* dst) const {
    float4* chunks = reinterpret_cast<float4*>(dst);
#pragma unroll
    for (size_t i = 0; i < vec_size / 4; ++i) {
      chunks[i] = data[i];
    }
  }

  // Element-type conversions are forwarded to the cast_*_impl helpers,
  // which are defined outside this file.
  template <typename T>
  SGL_HIP_INLINE void cast_from(const vec_t<T, vec_size>& src) {
    cast_from_impl(*this, src);
  }
  template <typename T>
  SGL_HIP_INLINE void cast_load(const T* ptr) {
    cast_load_impl(*this, ptr);
  }
  template <typename T>
  SGL_HIP_INLINE void cast_store(T* ptr) const {
    cast_store_impl(ptr, *this);
  }
};

}  // namespace sgl_hip

#endif