// customCentroidBond.cc
/**
 * Compute the center of each group.
 *
 * Each work group handles groups GROUP_ID, GROUP_ID+NUM_GROUPS, GROUP_ID+2*NUM_GROUPS, ...
 * For each group, its threads accumulate weight*position over the entries of
 * groupParticles/groupWeights in the range [groupOffsets[group], groupOffsets[group+1]).
 * They then combine their per-thread partial sums with a tree reduction in local memory.
 * Thread 0 writes the result to centerPositions[group], with w set to 0.
 *
 * NOTE(review): temp has 64 entries and is indexed directly by LOCAL_ID. The reduction begins
 * by folding entries 32..63 into 0..31, so this kernel presumably must be launched with exactly
 * 64 threads per work group. Confirm against the host-side launch.
 */
KERNEL void computeGroupCenters(int numParticleGroups, GLOBAL const real4* RESTRICT posq, GLOBAL const int* RESTRICT groupParticles,
        GLOBAL const real* RESTRICT groupWeights, GLOBAL const int* RESTRICT groupOffsets, GLOBAL real4* RESTRICT centerPositions) {
    LOCAL volatile real3 temp[64];
    for (int group = GROUP_ID; group < numParticleGroups; group += NUM_GROUPS) {
        // The threads in this block work together to compute the center of one group.

        int firstIndex = groupOffsets[group];
        int lastIndex = groupOffsets[group+1];
        real3 center = make_real3(0);
        for (int index = LOCAL_ID; index < lastIndex-firstIndex; index += LOCAL_SIZE) {
            int atom = groupParticles[firstIndex+index];
            real weight = groupWeights[firstIndex+index];
            real4 pos = posq[atom];
            center.x += weight*pos.x;
            center.y += weight*pos.y;
            center.z += weight*pos.z;
        }

        // Sum the values.  The first step folds entries 32..63 into 0..31 behind a block-wide
        // barrier.  The later steps involve only threads 0-31 and use SYNC_WARPS, which relies
        // on those threads being synchronized at warp granularity.

        int thread = LOCAL_ID;
        temp[thread].x = center.x;
        temp[thread].y = center.y;
        temp[thread].z = center.z;
        SYNC_THREADS;
        if (thread < 32) {
            temp[thread].x += temp[thread+32].x;
            temp[thread].y += temp[thread+32].y;
            temp[thread].z += temp[thread+32].z;
        }
        SYNC_WARPS;
        if (thread < 16) {
            temp[thread].x += temp[thread+16].x;
            temp[thread].y += temp[thread+16].y;
            temp[thread].z += temp[thread+16].z;
        }
        SYNC_WARPS;
        if (thread < 8) {
            temp[thread].x += temp[thread+8].x;
            temp[thread].y += temp[thread+8].y;
            temp[thread].z += temp[thread+8].z;
        }
        SYNC_WARPS;
        if (thread < 4) {
            temp[thread].x += temp[thread+4].x;
            temp[thread].y += temp[thread+4].y;
            temp[thread].z += temp[thread+4].z;
        }
        SYNC_WARPS;
        if (thread < 2) {
            temp[thread].x += temp[thread+2].x;
            temp[thread].y += temp[thread+2].y;
            temp[thread].z += temp[thread+2].z;
        }
        SYNC_WARPS;
        if (thread == 0)
            centerPositions[group] = make_real4(temp[0].x+temp[1].x, temp[0].y+temp[1].y, temp[0].z+temp[1].z, 0);

        // Without this barrier, a thread that moves on to the next group could overwrite its temp
        // entry while other threads are still reading it above.  For example, threads 32-63 could
        // write temp[32..63] before threads 0-31 have read them, or thread 1 could write temp[1]
        // before thread 0 has read it.  The loop bounds depend only on GROUP_ID, NUM_GROUPS and
        // numParticleGroups, so every thread of the work group reaches this barrier.
        SYNC_THREADS;
    }
}

/**
 * Compute the difference between two vectors, setting the fourth component to the squared magnitude.
 *
 * Returns vec1-vec2 in xyz.  When periodic is true, APPLY_PERIODIC_TO_DELTA is applied to the
 * difference before the squared magnitude is computed.  That macro is not defined here.  It
 * presumably reads the periodic box parameters of this function by name, which is why they are
 * passed in even though nothing else in the body uses them.
 */
DEVICE real4 delta(real4 vec1, real4 vec2, bool periodic, real4 periodicBoxSize, real4 invPeriodicBoxSize,
        real4 periodicBoxVecX, real4 periodicBoxVecY, real4 periodicBoxVecZ) {
    real4 result = make_real4(vec1.x-vec2.x, vec1.y-vec2.y, vec1.z-vec2.z, 0);
    if (periodic) {
        // Braced so the whole macro expansion is guarded, even if the macro expands to several
        // statements rather than a single braced block.
        APPLY_PERIODIC_TO_DELTA(result);
    }
    result.w = result.x*result.x + result.y*result.y + result.z*result.z;
    return result;
}

/**
 * Compute the angle between two vectors.  The w component of each vector should contain the squared magnitude.
 *
 * Normally the angle is acos() of the cosine.  When |cosine| > 0.99, acos() is near its
 * singularity.  In that case the angle is recovered with asin() from the cross-product magnitude
 * (|v1 x v2| = |v1||v2| sin(theta)), and mapped to pi-theta when the cosine is negative.
 */
DEVICE real computeAngle(real4 vec1, real4 vec2) {
    real normSqProduct = vec1.w*vec2.w;
    real dotProd = vec1.x*vec2.x + vec1.y*vec2.y + vec1.z*vec2.z;
    real cosTheta = dotProd*RSQRT(normSqProduct);
    if (cosTheta > 0.99f || cosTheta < -0.99f) {
        // Nearly parallel or antiparallel: use the sine form, which stays well conditioned here.
        real3 perpendicular = cross(trimTo3(vec1), trimTo3(vec2));
        real theta = ASIN(SQRT(dot(perpendicular, perpendicular)/normSqProduct));
        if (cosTheta < 0)
            theta = M_PI-theta;
        return theta;
    }
    return ACOS(cosTheta);
}

/**
 * Compute the cross product of two vectors, setting the fourth component to the squared magnitude.
 *
 * Only the xyz components of the inputs are used.
 */
DEVICE real4 computeCross(real4 vec1, real4 vec2) {
    real3 product = cross(trimTo3(vec1), trimTo3(vec2));
    real magnitudeSq = product.x*product.x + product.y*product.y + product.z*product.z;
    return make_real4(product.x, product.y, product.z, magnitudeSq);
}

/**
 * Compute the forces on groups based on the bonds.
 *
 * Each thread handles bonds GLOBAL_ID, GLOBAL_ID+GLOBAL_SIZE, ... while the index is below
 * NUM_BONDS.  The per-bond work comes from the COMPUTE_FORCE macro.  EXTRA_ARGS,
 * INIT_PARAM_DERIVS and SAVE_PARAM_DERIVS are also supplied from outside this file.  Presumably
 * they cover additional per-bond arguments and energy parameter derivatives; confirm against the
 * host code that defines them.  EXTRA_ARGS is presumably empty or begins with a comma, since it
 * follows periodicBoxVecZ directly.
 *
 * The macros may refer to the local names energy and index, so those names must be kept.  The
 * energy accumulated by this thread is added to energyBuffer[GLOBAL_ID].
 */
KERNEL void computeGroupForces(int numParticleGroups, GLOBAL mm_ulong* RESTRICT groupForce, GLOBAL mixed* RESTRICT energyBuffer, GLOBAL const real4* RESTRICT centerPositions,
        GLOBAL const int* RESTRICT bondGroups, real4 periodicBoxSize, real4 invPeriodicBoxSize, real4 periodicBoxVecX, real4 periodicBoxVecY, real4 periodicBoxVecZ
        EXTRA_ARGS) {
    mixed energy = 0;
    INIT_PARAM_DERIVS
    for (int index = GLOBAL_ID; index < NUM_BONDS; index += GLOBAL_SIZE) {
        COMPUTE_FORCE
    }
    energyBuffer[GLOBAL_ID] += energy;
    SAVE_PARAM_DERIVS
}

/**
 * Apply the forces from the group centers to the individual atoms.
 *
 * groupForce stores the x, y and z components of each group's force in three consecutive
 * runs of numParticleGroups integer values (presumably fixed-point).  Work groups stride over
 * groups, and the threads of a work group stride over that group's particles.  Each particle
 * receives weight * (group force), atomically added into atomForce, whose x, y and z components
 * are stored PADDED_NUM_ATOMS apart.
 */
KERNEL void applyForcesToAtoms(int numParticleGroups, GLOBAL const int* RESTRICT groupParticles, GLOBAL const real* RESTRICT groupWeights, GLOBAL const int* RESTRICT groupOffsets,
        GLOBAL const mm_long* RESTRICT groupForce, GLOBAL mm_ulong* RESTRICT atomForce) {
    for (int group = GROUP_ID; group < numParticleGroups; group += NUM_GROUPS) {
        const int start = groupOffsets[group];
        const int count = groupOffsets[group+1]-start;
        const mm_long groupFx = groupForce[group];
        const mm_long groupFy = groupForce[group+numParticleGroups];
        const mm_long groupFz = groupForce[group+2*numParticleGroups];
        for (int i = LOCAL_ID; i < count; i += LOCAL_SIZE) {
            const int atom = groupParticles[start+i];
            const real atomWeight = groupWeights[start+i];
            // Scale the group force by this particle's weight and accumulate it as an integer.
            ATOMIC_ADD(&atomForce[atom], (mm_ulong) ((mm_long) (groupFx*atomWeight)));
            ATOMIC_ADD(&atomForce[atom+PADDED_NUM_ATOMS], (mm_ulong) ((mm_long) (groupFy*atomWeight)));
            ATOMIC_ADD(&atomForce[atom+2*PADDED_NUM_ATOMS], (mm_ulong) ((mm_long) (groupFz*atomWeight)));
        }
    }
}