removeCM.cc 2.38 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
/**
 * Calculate the center of mass momentum.
 *
 * Each work-group writes one partial momentum sum into cmMomentum[GROUP_ID].
 * For every particle it visits, that sum adds velocity.xyz * RECIP(velocity.w),
 * i.e. the code treats RECIP(w) as the particle's mass. Particles whose w is
 * zero are left out of the sum. Partial sums are accumulated in single
 * precision (float3) regardless of what `mixed` is.
 *
 * The local-memory reduction is written for exactly 64 work-items per group:
 * the scratch array has 64 entries and the first halving pass reads
 * temp[lid+32]. Any other work-group size would read uninitialized or
 * out-of-range scratch entries.
 *
 * @param numAtoms    number of entries of velm to process
 * @param velm        per-particle velocity in xyz; w is inverted to obtain the mass
 * @param cmMomentum  output, one partial momentum sum per work-group
 */

KERNEL void calcCenterOfMassMomentum(int numAtoms, GLOBAL const mixed4* RESTRICT velm, GLOBAL float3* RESTRICT cmMomentum) {
    LOCAL float3 temp[64];

    // Each work-item sums the momentum of a strided subset of the particles.
    float3 partial = make_float3(0, 0, 0);
    for (int atom = GLOBAL_ID; atom < numAtoms; atom += GLOBAL_SIZE) {
        const mixed4 v = velm[atom];
        if (v.w == 0)
            continue; // RECIP(0) is not a usable mass; leave this particle out
        const mixed m = RECIP(v.w);
        partial.x += (float) (v.x*m);
        partial.y += (float) (v.y*m);
        partial.z += (float) (v.z*m);
    }

    // Tree reduction in local memory: each pass folds the upper half of the
    // active range onto the lower half, until only temp[0] and temp[1] remain.
    // The loop trip count is uniform, so every work-item reaches each barrier.
    const int lid = LOCAL_ID;
    temp[lid] = partial;
    SYNC_THREADS;
    for (int stride = 32; stride >= 2; stride >>= 1) {
        if (lid < stride)
            temp[lid] += temp[lid+stride];
        SYNC_THREADS;
    }
    if (lid == 0)
        cmMomentum[GROUP_ID] = temp[0]+temp[1];
}

/**
 * Remove center of mass motion.
 *
 * Every work-group first sums the NUM_GROUPS partial momenta in cmMomentum
 * and scales the total by INVERSE_TOTAL_MASS to obtain the center of mass
 * velocity. It then subtracts that velocity from each particle it visits.
 *
 * Particles whose velm[].w is zero are left untouched. The momentum kernel in
 * this file (calcCenterOfMassMomentum) also excludes them from the momentum
 * sum, so they play no part in the center of mass motion. Skipping them here
 * keeps their stored velocities from drifting each time this kernel runs.
 *
 * NUM_GROUPS and INVERSE_TOTAL_MASS are compile-time symbols defined outside
 * this file. NOTE(review): presumably NUM_GROUPS is the number of work-groups
 * that filled cmMomentum and INVERSE_TOTAL_MASS is 1/(total particle mass) —
 * confirm against the host code.
 *
 * The local-memory reduction requires exactly 64 work-items per group: the
 * scratch array has 64 entries and the first halving pass reads
 * temp[thread+32].
 *
 * @param numAtoms    number of entries of velm to process
 * @param velm        per-particle velocity in xyz (updated in place); w == 0 marks particles to skip
 * @param cmMomentum  per-work-group partial momentum sums
 */

KERNEL void removeCenterOfMassMomentum(int numAtoms, GLOBAL mixed4* RESTRICT velm, GLOBAL const float3* RESTRICT cmMomentum) {
    // First sum all of the momenta that were calculated by individual groups.

    LOCAL float3 temp[64];
    float3 cm = make_float3(0, 0, 0);
    for (int index = LOCAL_ID; index < NUM_GROUPS; index += LOCAL_SIZE)
        cm += cmMomentum[index];
    int thread = LOCAL_ID;
    temp[thread] = cm;
    SYNC_THREADS;

    // Halve the active range each pass until temp[0] and temp[1] hold the
    // remaining partial sums. The trip count is uniform across the group, so
    // every work-item reaches each barrier.
    for (int offset = 32; offset >= 2; offset /= 2) {
        if (thread < offset)
            temp[thread] += temp[thread+offset];
        SYNC_THREADS;
    }
    cm = make_float3(INVERSE_TOTAL_MASS*(temp[0].x+temp[1].x), INVERSE_TOTAL_MASS*(temp[0].y+temp[1].y), INVERSE_TOTAL_MASS*(temp[0].z+temp[1].z));

    // Now remove the center of mass velocity from each atom that took part in
    // the momentum sum. Load and store each velocity once instead of doing
    // three separate read-modify-writes on global memory.

    for (int index = GLOBAL_ID; index < numAtoms; index += GLOBAL_SIZE) {
        mixed4 velocity = velm[index];
        if (velocity.w != 0) {
            velocity.x -= cm.x;
            velocity.y -= cm.y;
            velocity.z -= cm.z;
            velm[index] = velocity;
        }
    }
}