// vectorize_neon.h
#ifndef OPENMM_VECTORIZE_NEON_H_
#define OPENMM_VECTORIZE_NEON_H_

/* -------------------------------------------------------------------------- *
 *                                   OpenMM                                   *
 * -------------------------------------------------------------------------- *
 * This is part of the OpenMM molecular simulation toolkit originating from   *
 * Simbios, the NIH National Center for Physics-Based Simulation of           *
 * Biological Structures at Stanford, funded under the NIH Roadmap for        *
 * Medical Research, grant U54 GM072970. See https://simtk.org.               *
 *                                                                            *
 * Portions copyright (c) 2013-2014 Stanford University and the Authors.      *
 * Authors: Mateus Lima, Peter Eastman                                        *
 * Contributors:                                                              *
 *                                                                            *
 * Permission is hereby granted, free of charge, to any person obtaining a    *
 * copy of this software and associated documentation files (the "Software"), *
 * to deal in the Software without restriction, including without limitation  *
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,   *
 * and/or sell copies of the Software, and to permit persons to whom the      *
 * Software is furnished to do so, subject to the following conditions:       *
 *                                                                            *
 * The above copyright notice and this permission notice shall be included in *
 * all copies or substantial portions of the Software.                        *
 *                                                                            *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,   *
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL    *
 * THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,    *
 * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR      *
 * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE  *
 * USE OR OTHER DEALINGS IN THE SOFTWARE.                                     *
 * -------------------------------------------------------------------------- */

#include <cpu-features.h>
#include <arm_neon.h>
#include <cmath>

// NOTE(review): int32_t is normally supplied by <stdint.h>; presumably this typedef is kept for
// toolchains where it is not already visible -- confirm before removing.
typedef int int32_t;

// This file defines classes and functions to simplify vectorizing code with NEON.

// These two functions are defined in the vecmath library, which is linked into OpenMM.
// They are wrapped by the fvec4 overloads of exp() and log() further down in this file.
// NOTE(review): presumably each applies the named function to all four elements -- confirm against vecmath.
float32x4_t exp_ps(float32x4_t);
float32x4_t log_ps(float32x4_t);

47
48
49
50
51
52
53
54
/**
 * Determine whether ivec4 and fvec4 are supported on this processor.
 */
static bool isVec4Supported() {
    uint64_t features = android_getCpuFeatures();
    return (features & ANDROID_CPU_ARM_FEATURE_NEON) != 0;
}

55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
class ivec4;

/**
 * A four element vector of floats.
 */
class fvec4 {
public:
    float32x4_t val;

    fvec4() {}
    fvec4(float v) : val(vdupq_n_f32(v)) {}
    fvec4(float v1, float v2, float v3, float v4) {
        float v[] = {v1, v2, v3, v4};
        val = vld1q_f32(v);
    }
    fvec4(float32x4_t v) : val(v) {}
    fvec4(const float* v) : val(vld1q_f32(v)) {}
    operator float32x4_t() const {
        return val;
    }
    float operator[](int i) const {
76
77
78
79
80
81
82
83
84
85
86
        switch (i) {
            case 0:
                return vgetq_lane_f32(val, 0);
            case 1:
                return vgetq_lane_f32(val, 1);
            case 2:
                return vgetq_lane_f32(val, 2);
            case 3:
                return vgetq_lane_f32(val, 3);
        }
        return 0.0f;
87
88
89
90
    }
    void store(float* v) const {
        vst1q_f32(v, val);
    }
91
    fvec4 operator+(const fvec4& other) const {
92
93
        return vaddq_f32(val, other);
    }
94
    fvec4 operator-(const fvec4& other) const {
95
96
        return vsubq_f32(val, other);
    }
97
    fvec4 operator*(const fvec4& other) const {
98
99
        return vmulq_f32(val, other);
    }
100
    fvec4 operator/(const fvec4& other) const {
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
        // NEON does not have a divide float-point operator, so we get the reciprocal and multiply.

        float32x4_t reciprocal = vrecpeq_f32(other);
        reciprocal = vmulq_f32(vrecpsq_f32(other, reciprocal), reciprocal);
        reciprocal = vmulq_f32(vrecpsq_f32(other, reciprocal), reciprocal);
        fvec4 result = vmulq_f32(val,reciprocal);
        return result;
    }
    void operator+=(const fvec4& other) {
        val = vaddq_f32(val, other);
    }
    void operator-=(const fvec4& other) {
        val = vsubq_f32(val, other);
    }
    void operator*=(const fvec4& other) {
        val = vmulq_f32(val, other);
    }
    void operator/=(const fvec4& other) {
119
        val = *this/other;
120
121
122
123
124
125
126
127
    }
    fvec4 operator-() const {
        return vnegq_f32(val);
    }
    fvec4 operator&(const fvec4& other) const {
        return vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(val), vreinterpretq_u32_f32(other)));
    }
    fvec4 operator|(const fvec4& other) const {
128
        return vreinterpretq_f32_u32(vorrq_u32(vreinterpretq_u32_f32(val), vreinterpretq_u32_f32(other)));
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
    }
    fvec4 operator==(const fvec4& other) const {
        return vcvtq_f32_s32(vreinterpretq_s32_u32(vceqq_f32(val, other)));
    }
    fvec4 operator!=(const fvec4& other) const {
        return vcvtq_f32_s32(vreinterpretq_s32_u32(vmvnq_u32(vceqq_f32(val, other)))); // not(equals(val, other))
    }
    fvec4 operator>(const fvec4& other) const {
        return vcvtq_f32_s32(vreinterpretq_s32_u32(vcgtq_f32(val, other)));
    }
    fvec4 operator<(const fvec4& other) const {
        return vcvtq_f32_s32(vreinterpretq_s32_u32(vcltq_f32(val, other)));
    }
    fvec4 operator>=(const fvec4& other) const {
        return vcvtq_f32_s32(vreinterpretq_s32_u32(vcgeq_f32(val, other)));
    }
    fvec4 operator<=(const fvec4& other) const {
        return vcvtq_f32_s32(vreinterpretq_s32_u32(vcleq_f32(val, other)));
    }
    operator ivec4() const;
};

/**
 * A four element vector of ints.
 */
class ivec4 {
public:
    
    int32x4_t val;

    ivec4() {}
    ivec4(int v) : val(vdupq_n_s32(v)) {}
    ivec4(int v1, int v2, int v3, int v4) {
        int v[] = {v1, v2, v3, v4};
        val = vld1q_s32(v);
    }
    ivec4(int32x4_t v) : val(v) {}
    ivec4(const int* v) : val(vld1q_s32(v)) {}
    operator int32x4_t() const {
        return val;
    }
    int operator[](int i) const {
171
172
173
174
175
176
177
178
179
180
181
        switch (i) {
            case 0:
                return vgetq_lane_s32(val, 0);
            case 1:
                return vgetq_lane_s32(val, 1);
            case 2:
                return vgetq_lane_s32(val, 2);
            case 3:
                return vgetq_lane_s32(val, 3);
        }
        return 0;
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
    }
    void store(int* v) const {
        vst1q_s32(v, val);
    }
    ivec4 operator+(const ivec4& other) const {
        return vaddq_s32(val, other);
    }
    ivec4 operator-(const ivec4& other) const {
        return vsubq_s32(val, other);
    }
    ivec4 operator*(const ivec4& other) const {
        return vmulq_s32(val, other);
    }
    void operator+=(const ivec4& other) {
        val = vaddq_s32(val, other);
    }
    void operator-=(const ivec4& other) {
        val = vsubq_s32(val, other);
    }
    void operator*=(const ivec4& other) {
        val = vmulq_s32(val, other);
    }
    ivec4 operator-() const {
        return vnegq_s32(val);
    }
207
208
    ivec4 operator&(const ivec4& other) const {
        return vandq_s32(val, other);
209
210
    }
    ivec4 operator|(const ivec4& other) const {
211
        return vorrq_s32(val, other);
212
213
    }
    ivec4 operator==(const ivec4& other) const {
214
        return vreinterpretq_s32_u32(vceqq_s32(val, other));
215
    }
216
217
    ivec4 operator!=(const ivec4& other) const {
        return vreinterpretq_s32_u32(vmvnq_u32(vceqq_s32(val, other))); // not(equal(val, other))
218
219
    }
    ivec4 operator>(const ivec4& other) const {
220
        return vreinterpretq_s32_u32(vcgtq_s32(val, other));
221
222
    }
    ivec4 operator<(const ivec4& other) const {
223
        return vreinterpretq_s32_u32(vcltq_s32(val, other));
224
225
    }
    ivec4 operator>=(const ivec4& other) const {
226
        return vreinterpretq_s32_u32(vcgeq_s32(val, other));
227
    }
228
229
    ivec4 operator<=(const ivec4& other) const {
        return vreinterpretq_s32_u32(vcleq_s32(val, other));
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
    }
    operator fvec4() const;
};

// Conversion operators.

/**
 * Convert each float element to a signed 32 bit int using vcvtq_s32_f32.
 */
inline fvec4::operator ivec4() const {
    const int32x4_t asInts = vcvtq_s32_f32(val);
    return ivec4(asInts);
}

/**
 * Convert each signed 32 bit int element to a float using vcvtq_f32_s32.
 */
inline ivec4::operator fvec4() const {
    const float32x4_t asFloats = vcvtq_f32_s32(val);
    return fvec4(asFloats);
}

// Functions that operate on fvec4s.

246
static inline fvec4 min(const fvec4& v1, const fvec4& v2) {
247
    return vminq_f32(v1, v2);
248
249
}

250
static inline fvec4 max(const fvec4& v1, const fvec4& v2) {
251
    return vmaxq_f32(v1, v2);
252
253
}

254
255
static inline fvec4 abs(const fvec4& v) {
    return vabsq_f32(v);
256
257
}

peastman's avatar
peastman committed
258
static inline fvec4 rsqrt(const fvec4& v) {
259
260
261
    float32x4_t recipSqrt = vrsqrteq_f32(v);
    recipSqrt = vmulq_f32(recipSqrt, vrsqrtsq_f32(vmulq_f32(recipSqrt, v), recipSqrt));
    recipSqrt = vmulq_f32(recipSqrt, vrsqrtsq_f32(vmulq_f32(recipSqrt, v), recipSqrt));
peastman's avatar
peastman committed
262
263
264
265
266
    return recipSqrt;
}

/**
 * Element-wise square root, computed as rsqrt(v)*v.
 *
 * Elements equal to zero are passed through unchanged instead of being
 * multiplied: the reciprocal square root estimate of 0 is infinite, so the
 * product would otherwise come out as NaN rather than 0.
 */
static inline fvec4 sqrt(const fvec4& v) {
    const float32x4_t root = vmulq_f32(rsqrt(v).val, v.val);
    const uint32x4_t isZero = vceqq_f32(v.val, vdupq_n_f32(0.0f));
    // Select v itself for zero elements and the computed root everywhere else.
    return fvec4(vbslq_f32(isZero, v.val, root));
}

peastman's avatar
peastman committed
269
270
271
272
273
274
275
276
static inline fvec4 exp(const fvec4& v) {
    return fvec4(exp_ps(v.val));
}

/**
 * Apply the vecmath log_ps() routine to the wrapped NEON vector.
 */
static inline fvec4 log(const fvec4& v) {
    const float32x4_t result = log_ps(v.val);
    return fvec4(result);
}

277
278
279
static inline float dot3(const fvec4& v1, const fvec4& v2) {
    fvec4 result = v1*v2;
    return vgetq_lane_f32(result, 0) + vgetq_lane_f32(result, 1) + vgetq_lane_f32(result, 2);
280
281
}

282
283
284
static inline float dot4(const fvec4& v1, const fvec4& v2) {
    fvec4 result = v1*v2;
    return vgetq_lane_f32(result, 0) + vgetq_lane_f32(result, 1) + vgetq_lane_f32(result, 2) + vgetq_lane_f32(result,3);
285
286
}

287
288
289
290
291
292
static inline fvec4 cross(const fvec4& v1, const fvec4& v2) {
    return fvec4(v1[1]*v2[2] - v1[2]*v2[1],
                 v1[2]*v2[0] - v1[0]*v2[2],
                 v1[0]*v2[1] - v1[1]*v2[0], 0);
}

293
294
295
296
297
298
299
300
301
static inline void transpose(fvec4& v1, fvec4& v2, fvec4& v3, fvec4& v4) {
    float32x4x2_t t1 = vuzpq_f32(v1, v3);
    float32x4x2_t t2 = vuzpq_f32(v2, v4);
    float32x4x2_t t3 = vtrnq_f32(t1.val[0], t2.val[0]);
    float32x4x2_t t4 = vtrnq_f32(t1.val[1], t2.val[1]);
    v1 = t3.val[0];
    v2 = t4.val[0];
    v3 = t3.val[1];
    v4 = t4.val[1];
302
303
304
305
}

// Functions that operate on ivec4s.

306
307
static inline ivec4 min(const ivec4& v1, const ivec4& v2) {
    return vminq_s32(v1, v2);
308
309
}

310
311
static inline ivec4 max(const ivec4& v1, const ivec4& v2) {
    return vmaxq_s32(v1, v2);
312
313
}

314
315
static inline ivec4 abs(const ivec4& v) {
    return vabdq_s32(v, ivec4(0));
316
317
}

318
319
static inline bool any(const ivec4& v) {
    return (vgetq_lane_s32(v, 0) != 0 || vgetq_lane_s32(v, 1) != 0 || vgetq_lane_s32(v, 2) != 0 || vgetq_lane_s32(v, 3) != 0);
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
}

// Mathematical operators involving a scalar and a vector.

/** Add the scalar v1 to every element of v2. */
static inline fvec4 operator+(float v1, const fvec4& v2) {
    const fvec4 broadcast(v1);
    return broadcast + v2;
}

/** Subtract every element of v2 from the scalar v1. */
static inline fvec4 operator-(float v1, const fvec4& v2) {
    const fvec4 broadcast(v1);
    return broadcast - v2;
}

/** Multiply every element of v2 by the scalar v1. */
static inline fvec4 operator*(float v1, const fvec4& v2) {
    const fvec4 broadcast(v1);
    return broadcast * v2;
}

/** Divide the scalar v1 by every element of v2. */
static inline fvec4 operator/(float v1, const fvec4& v2) {
    const fvec4 broadcast(v1);
    return broadcast / v2;
}

// Operations for blending fvec4s based on an ivec4.

342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
static inline fvec4 blend(const fvec4& v1, const fvec4& v2, const ivec4& mask) {
    return vbslq_f32(vreinterpretq_u32_s32(mask), v2, v1);
}

// These are at the end since they involve other functions defined above.

/**
 * Round each element of v to the nearest integer value.
 *
 * Adding and then subtracting 2^23 pushes the fractional bits of |v| out of
 * the float mantissa; the original sign bit is then copied back.  Elements
 * whose magnitude is 2^23 or more are returned unchanged: every float that
 * large is already an integer, and the add/subtract step would otherwise
 * perturb it (e.g. 8388609 would come back as 8388608).
 */
static inline fvec4 round(const fvec4& v) {
    fvec4 shift(0x1.0p23f);
    fvec4 magnitude = abs(v);
    fvec4 absResult = (magnitude+shift)-shift;
    // Take everything except the sign bit from absResult, and the sign bit from v.
    fvec4 rounded = blend(v, absResult, ivec4(0x7FFFFFFF));
    // Keep the rounded value only where the trick is valid.
    return blend(v, rounded, magnitude<shift);
}

/**
 * Round each element of v down: round to nearest, then subtract 1 from
 * every element whose rounded value ended up greater than the input.
 */
static inline fvec4 floor(const fvec4& v) {
    const fvec4 nearest = round(v);
    const fvec4 correction = blend(0.0f, -1.0f, nearest>v);
    return nearest + correction;
}

/**
 * Round each element of v up: round to nearest, then add 1 to every
 * element whose rounded value ended up less than the input.
 */
static inline fvec4 ceil(const fvec4& v) {
    const fvec4 nearest = round(v);
    const fvec4 correction = blend(0.0f, 1.0f, nearest<v);
    return nearest + correction;
}

#endif /*OPENMM_VECTORIZE_NEON_H_*/