// TestFindExclusions.cpp
/* -------------------------------------------------------------------------- *
 *                                   OpenMM                                   *
 * -------------------------------------------------------------------------- *
 * This is part of the OpenMM molecular simulation toolkit originating from   *
 * Simbios, the NIH National Center for Physics-Based Simulation of           *
 * Biological Structures at Stanford, funded under the NIH Roadmap for        *
 * Medical Research, grant U54 GM072970. See https://simtk.org.               *
 *                                                                            *
 * Portions copyright (c) 2008 Stanford University and the Authors.           *
 * Authors: Peter Eastman                                                     *
 * Contributors:                                                              *
 *                                                                            *
 * Permission is hereby granted, free of charge, to any person obtaining a    *
 * copy of this software and associated documentation files (the "Software"), *
 * to deal in the Software without restriction, including without limitation  *
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,   *
 * and/or sell copies of the Software, and to permit persons to whom the      *
 * Software is furnished to do so, subject to the following conditions:       *
 *                                                                            *
 * The above copyright notice and this permission notice shall be included in *
 * all copies or substantial portions of the Software.                        *
 *                                                                            *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,   *
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL    *
 * THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,    *
 * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR      *
 * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE  *
 * USE OR OTHER DEALINGS IN THE SOFTWARE.                                     *
 * -------------------------------------------------------------------------- */

/**
 * This tests the findExclusions() method of StandardMMForceFieldImpl, which identifies pairs of atoms
 * whose nonbonded interactions are either excluded or scaled down (1-4 pairs).  The test system is a
 * chain with one side-chain atom attached to each main-chain atom:
 * 
 * 1  3  5  7  9  11  13  15  17  19
 * |  |  |  |  |  |   |   |   |   |
 * 0--2--4--6--8--10--12--14--16--18
 */

#include "AssertionUtilities.h"
#include "Kernel.h"
#include "KernelFactory.h"
#include "OpenMMContext.h"
#include "Platform.h"
#include "StandardMMForceField.h"
#include "Stream.h"
#include "StreamFactory.h"
#include "System.h"
#include "VerletIntegrator.h"
#include "kernels.h"
#include <algorithm>
#include <iostream>
#include <iterator>
#include <set>
#include <vector>

using namespace OpenMM;
using namespace std;

// Total atoms in the test chain: 10 main-chain atoms (even indices), each carrying one side-chain atom.
static const int NUM_ATOMS(20);

/**
 * Mark atom1 and atom2 as mutually excluded by adding each to the other's exclusion set.
 * Nothing is recorded when atom2 lies past the end of the chain (atom2 >= NUM_ATOMS).
 */

void addAtomsToExclusions(int atom1, int atom2, vector<set<int> >& exclusions) {
    if (atom2 >= NUM_ATOMS)
        return;
    set<int>& firstSet = exclusions[atom1];
    set<int>& secondSet = exclusions[atom2];
    firstSet.insert(atom2);
    secondSet.insert(atom1);
}

/**
 * Check the exclusion sets against the topology of the branched test chain.
 *
 * Every pair of atoms separated by three or fewer bonds must be excluded.  For main-chain
 * atom m (even index) the partners further along the chain are m+1 .. m+6; for its side-chain
 * atom m+1 they are (m+1)+1 .. (m+1)+3.  Pairs running past the end of the chain are dropped.
 * Each atom's set must match the expected set exactly (same size, every element present).
 */

void verifyExclusions(const vector<set<int> >& exclusions) {
    static const int mainChainOffsets[] = {1, 2, 3, 4, 5, 6};
    static const int sideChainOffsets[] = {1, 2, 3};
    const size_t numMainOffsets = sizeof(mainChainOffsets)/sizeof(mainChainOffsets[0]);
    const size_t numSideOffsets = sizeof(sideChainOffsets)/sizeof(sideChainOffsets[0]);

    // Build the expected exclusion sets from the chain topology.
    vector<set<int> > expected(NUM_ATOMS);
    for (int mainAtom = 0; mainAtom < NUM_ATOMS; mainAtom += 2) {
        const int sideAtom = mainAtom+1;
        for (size_t k = 0; k < numMainOffsets; ++k)
            addAtomsToExclusions(mainAtom, mainAtom+mainChainOffsets[k], expected);
        for (size_t k = 0; k < numSideOffsets; ++k)
            addAtomsToExclusions(sideAtom, sideAtom+sideChainOffsets[k], expected);
    }

    // Compare atom by atom: equal sizes plus every found element being expected means equal sets.
    ASSERT_EQUAL(expected.size(), exclusions.size());
    for (int atom = 0; atom < NUM_ATOMS; ++atom) {
        const set<int>& want = expected[atom];
        const set<int>& got = exclusions[atom];
        ASSERT_EQUAL(want.size(), got.size());
        size_t matched = 0;
        for (set<int>::const_iterator it = got.begin(); it != got.end(); ++it) {
            if (want.count(*it) != 0)
                ++matched;
        }
        ASSERT_EQUAL(want.size(), matched);
    }
}

/**
 * Record the expected 1-4 pair (atom1, atom2), stored in exactly the order given.
 * Pairs whose second atom would lie past the end of the chain are ignored.
 */

void addAtomsTo14List(int atom1, int atom2, set<pair<int, int> >& bonded14Indices) {
    const bool partnerExists = (atom2 < NUM_ATOMS);
    if (!partnerExists)
        return;
    const pair<int, int> entry(atom1, atom2);
    bonded14Indices.insert(entry);
}

/**
 * Verify that the 1-4 pairs are what we expect.
 *
 * The 1-4 pairs are the atom pairs separated by exactly three bonds.  For main-chain atom i
 * (even index) these are (i, i+5), (i, i+6), (i+1, i+3) and (i+1, i+4), dropping any pair
 * whose second atom lies past the end of the chain.  Each entry of bonded14Indices must hold
 * exactly two atom indices; the order of the two atoms within an entry does not matter.
 */

void verify14(const vector<vector<int> >& bonded14Indices) {
    set<pair<int, int> > expected, found;
    for (int i = 0; i < NUM_ATOMS; i += 2) {
        addAtomsTo14List(i, i+5, expected);
        addAtomsTo14List(i, i+6, expected);
        addAtomsTo14List(i+1, i+3, expected);
        addAtomsTo14List(i+1, i+4, expected);
    }
    ASSERT_EQUAL(expected.size(), bonded14Indices.size());
    for (size_t i = 0; i < bonded14Indices.size(); ++i) {
        // Guard the [0]/[1] accesses below against malformed entries.
        ASSERT_EQUAL(static_cast<size_t>(2), bonded14Indices[i].size());
        int atom1 = bonded14Indices[i][0];
        int atom2 = bonded14Indices[i][1];
        // Normalize to (smaller, larger) so the comparison ignores atom order within a pair.
        found.insert(pair<int, int>(min(atom1, atom2), max(atom1, atom2)));
    }
    // The sizes already match, so a full-size intersection means 'found' equals 'expected'.
    // Duplicate entries in bonded14Indices collapse in 'found' and make the intersection smaller.
    vector<pair<int, int> > intersection(0);
    insert_iterator<vector<pair<int, int> > > inserter(intersection, intersection.begin());
    set_intersection(expected.begin(), expected.end(), found.begin(), found.end(), inserter);
    ASSERT_EQUAL(expected.size(), intersection.size());
}

/**
 * The following classes define a dummy Platform whose job is to check whether the correct values were
 * passed to the force kernel's initialize() method.  All other kernels and streams are no-op stubs.
 */

139
class DummyForceKernel : public CalcStandardMMForceFieldKernel {
140
public:
141
    DummyForceKernel(string name, const Platform& platform) : CalcStandardMMForceFieldKernel(name, platform) {
142
143
144
145
146
    }
    void initialize(const vector<vector<int> >& bondIndices, const vector<vector<double> >& bondParameters,
            const vector<vector<int> >& angleIndices, const vector<vector<double> >& angleParameters,
            const vector<vector<int> >& periodicTorsionIndices, const vector<vector<double> >& periodicTorsionParameters,
            const vector<vector<int> >& rbTorsionIndices, const vector<vector<double> >& rbTorsionParameters,
147
            const vector<vector<int> >& bonded14Indices, double lj14Scale, double coulomb14Scale,
148
149
            const vector<set<int> >& exclusions, const vector<vector<double> >& nonbondedParameters,
            NonbondedMethod nonbondedMethod, double nonbondedCutoff, double periodicBoxSize[3]) {
150
151
152
        verifyExclusions(exclusions);
        verify14(bonded14Indices);
    }
153
    void executeForces(const Stream& positions, Stream& forces) {
154
    }
155
    double executeEnergy(const Stream& positions) {
156
		return 0.0;
157
158
159
160
161
    }
};

class DummyIntegratorKernel : public IntegrateVerletStepKernel {
public:
162
    DummyIntegratorKernel(string name, const Platform& platform) : IntegrateVerletStepKernel(name, platform) {
163
164
165
166
167
168
169
170
171
    }
    void initialize(const vector<double>& masses, const vector<vector<int> >& constraintIndices, const vector<double>& constraintLengths) {
    }
    void execute(Stream& positions, Stream& velocities, const Stream& forces, double stepSize) {
    }
};

class DummyKEKernel : public CalcKineticEnergyKernel {
public:
172
    DummyKEKernel(string name, const Platform& platform) : CalcKineticEnergyKernel(name, platform) {
173
174
175
176
    }
    void initialize(const vector<double>& masses) {
    }
    double execute(const Stream& positions) {
177
        return 0.0;
178
179
180
181
182
    }
};

class DummyStreamImpl : public StreamImpl {
public:
183
    DummyStreamImpl(string name, int size, Stream::DataType type, const Platform& platform) : StreamImpl(name, size, type, platform) {
184
185
186
187
188
189
190
191
192
193
194
    }
    void loadFromArray(const void* array) {
    }
    void saveToArray(void* array) {
    }
    void fillWithValue(void* value) {
    }
};

class DummyKernelFactory : public KernelFactory {
public:
195
    KernelImpl* createKernelImpl(string name, const Platform& platform, OpenMMContextImpl& context) const {
196
197
        if (name == CalcStandardMMForceFieldKernel::Name())
            return new DummyForceKernel(name, platform);
198
        if (name == IntegrateVerletStepKernel::Name())
199
            return new DummyIntegratorKernel(name, platform);
200
        if (name == CalcKineticEnergyKernel::Name())
201
            return new DummyKEKernel(name, platform);
202
203
204
205
206
207
        return 0;
    }
};

class DummyStreamFactory : public StreamFactory {
public:
208
    StreamImpl* createStreamImpl(string name, int size, Stream::DataType type, const Platform& platform, OpenMMContextImpl& context) const {
209
        return new DummyStreamImpl(name, size, type, platform);
210
211
212
213
214
215
    }
};

class DummyPlatform : public Platform {
public:
    DummyPlatform() {
216
        registerKernelFactory(CalcStandardMMForceFieldKernel::Name(), new DummyKernelFactory());
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
        registerKernelFactory(IntegrateVerletStepKernel::Name(), new DummyKernelFactory());
        registerKernelFactory(CalcKineticEnergyKernel::Name(), new DummyKernelFactory());
    }
    string getName() const {
        return "Dummy";
    }
    double getSpeed() const {
        return 1.0;
    }
    bool supportsDoublePrecision() const {
        return true;
    }
    const StreamFactory& getDefaultStreamFactory() const {
        return streamFactory;
    }
private:
    DummyStreamFactory streamFactory;
};

int main() {
    try {
        DummyPlatform platform;
        System system(NUM_ATOMS, 0);
        VerletIntegrator integrator(0.01);
        StandardMMForceField* forces = new StandardMMForceField(NUM_ATOMS, NUM_ATOMS-1, 0, 0, 0);
242
243
244
245
246
247
248
249
250
        
        // loop over all main-chain atoms (even numbered atoms)
        for (int i = 0; i < NUM_ATOMS-1; i += 2) 
        {
        	// side-chain bonds
        	forces->setBondParameters(i, i, i+1, 1.0, 1.0);

        	// main-chain bonds
            if (i < NUM_ATOMS-2) // penultimate atom (NUM_ATOMS-2) has no subsequent main-chain atom
251
252
                forces->setBondParameters(i+1, i, i+2, 1.0, 1.0);
        }
253
        
254
255
256
257
258
259
260
261
262
263
        system.addForce(forces);
        OpenMMContext context(system, integrator, platform);
    }
    catch(const exception& e) {
        cout << "exception: " << e.what() << endl;
        return 1;
    }
    cout << "Done" << endl;
    return 0;
}