/* -------------------------------------------------------------------------- *
 *                                   OpenMM                                   *
 * -------------------------------------------------------------------------- *
 * This is part of the OpenMM molecular simulation toolkit originating from   *
 * Simbios, the NIH National Center for Physics-Based Simulation of           *
 * Biological Structures at Stanford, funded under the NIH Roadmap for        *
 * Medical Research, grant U54 GM072970. See https://simtk.org.               *
 *                                                                            *
 * Portions copyright (c) 2008-2013 Stanford University and the Authors.      *
 * Authors: Peter Eastman                                                     *
 * Contributors:                                                              *
 *                                                                            *
 * Permission is hereby granted, free of charge, to any person obtaining a    *
 * copy of this software and associated documentation files (the "Software"), *
 * to deal in the Software without restriction, including without limitation  *
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,   *
 * and/or sell copies of the Software, and to permit persons to whom the      *
 * Software is furnished to do so, subject to the following conditions:       *
 *                                                                            *
 * The above copyright notice and this permission notice shall be included in *
 * all copies or substantial portions of the Software.                        *
 *                                                                            *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,   *
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL    *
 * THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,    *
 * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR      *
 * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE  *
 * USE OR OTHER DEALINGS IN THE SOFTWARE.                                     *
 * -------------------------------------------------------------------------- */

#include "openmm/Force.h"
#include "openmm/Integrator.h"
#include "openmm/OpenMMException.h"
#include "openmm/System.h"
#include "openmm/kernels.h"
#include "openmm/internal/ForceImpl.h"
#include "openmm/internal/ContextImpl.h"
#include "openmm/State.h"
#include "openmm/VirtualSite.h"
#include "openmm/Context.h"
#include <iostream>
#include <map>
#include <string>
#include <utility>
#include <vector>

using namespace OpenMM;
using namespace std;

50
ContextImpl::ContextImpl(Context& owner, const System& system, Integrator& integrator, Platform* platform, const map<string, string>& properties) :
51
52
        owner(owner), system(system), integrator(integrator), hasInitializedForces(false), hasSetPositions(false), integratorIsDeleted(false),
        lastForceGroups(-1), platform(platform), platformData(NULL) {
53
54
    if (system.getNumParticles() == 0)
        throw OpenMMException("Cannot create a Context for a System with no particles");
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
    
    // Check for errors in virtual sites and massless particles.
    
    for (int i = 0; i < system.getNumParticles(); i++) {
        if (system.isVirtualSite(i)) {
            if (system.getParticleMass(i) != 0.0)
                throw OpenMMException("Virtual site has nonzero mass");
            const VirtualSite& site = system.getVirtualSite(i);
            for (int j = 0; j < site.getNumParticles(); j++)
                if (system.isVirtualSite(site.getParticle(j)))
                    throw OpenMMException("A virtual site cannot depend on another virtual site");
        }
    }
    for (int i = 0; i < system.getNumConstraints(); i++) {
        int particle1, particle2;
        double distance;
        system.getConstraintParameters(i, particle1, particle2, distance);
        if (system.getParticleMass(particle1) == 0.0 || system.getParticleMass(particle2) == 0.0)
            throw OpenMMException("A constraint cannot involve a massless particle");
    }
    
    // Find the list of kernels required.
    
78
    vector<string> kernelNames;
79
    kernelNames.push_back(CalcForcesAndEnergyKernel::Name());
80
    kernelNames.push_back(UpdateStateDataKernel::Name());
81
    for (int i = 0; i < system.getNumForces(); ++i) {
82
        forceImpls.push_back(system.getForce(i).createImpl());
83
84
        map<string, double> forceParameters = forceImpls[forceImpls.size()-1]->getDefaultParameters();
        parameters.insert(forceParameters.begin(), forceParameters.end());
85
86
87
        vector<string> forceKernels = forceImpls[forceImpls.size()-1]->getKernelNames();
        kernelNames.insert(kernelNames.begin(), forceKernels.begin(), forceKernels.end());
    }
88
    hasInitializedForces = true;
89
90
    vector<string> integratorKernels = integrator.getKernelNames();
    kernelNames.insert(kernelNames.begin(), integratorKernels.begin(), integratorKernels.end());
91
    if (platform == 0)
92
        this->platform = platform = &Platform::findPlatform(kernelNames);
93
    else if (!platform->supportsKernels(kernelNames))
94
        throw OpenMMException("Specified a Platform for a Context which does not support all required kernels");
95
96
97
    
    // Create and initialize kernels and other objects.
    
98
    platform->contextCreated(*this, properties);
99
    initializeForcesKernel = platform->createKernel(CalcForcesAndEnergyKernel::Name(), *this);
100
    initializeForcesKernel.getAs<CalcForcesAndEnergyKernel>().initialize(system);
101
    updateStateDataKernel = platform->createKernel(UpdateStateDataKernel::Name(), *this);
102
    updateStateDataKernel.getAs<UpdateStateDataKernel>().initialize(system);
103
    applyConstraintsKernel = platform->createKernel(ApplyConstraintsKernel::Name(), *this);
104
    applyConstraintsKernel.getAs<ApplyConstraintsKernel>().initialize(system);
105
    virtualSitesKernel = platform->createKernel(VirtualSitesKernel::Name(), *this);
106
    virtualSitesKernel.getAs<VirtualSitesKernel>().initialize(system);
107
108
    Vec3 periodicBoxVectors[3];
    system.getDefaultPeriodicBoxVectors(periodicBoxVectors[0], periodicBoxVectors[1], periodicBoxVectors[2]);
109
    updateStateDataKernel.getAs<UpdateStateDataKernel>().setPeriodicBoxVectors(*this, periodicBoxVectors[0], periodicBoxVectors[1], periodicBoxVectors[2]);
110
    for (size_t i = 0; i < forceImpls.size(); ++i)
111
        forceImpls[i]->initialize(*this);
112
    integrator.initialize(*this);
113
    updateStateDataKernel.getAs<UpdateStateDataKernel>().setVelocities(*this, vector<Vec3>(system.getNumParticles()));
114
115
}

116
ContextImpl::~ContextImpl() {
117
118
    for (int i = 0; i < (int) forceImpls.size(); ++i)
        delete forceImpls[i];
119
120
121
122
123
124
125
    
    // Make sure all kernels get properly deleted before contextDestroyed() is called.
    
    initializeForcesKernel = Kernel();
    updateStateDataKernel = Kernel();
    applyConstraintsKernel = Kernel();
    virtualSitesKernel = Kernel();
126
127
128
129
130
131
    if (!integratorIsDeleted) {
        // The Context is being deleted before the Integrator, so call cleanup() on it now.
        
        integrator.cleanup();
        integrator.context = NULL;
    }
132
    platform->contextDestroyed(*this);
133
134
}

135
double ContextImpl::getTime() const {
136
    return updateStateDataKernel.getAs<const UpdateStateDataKernel>().getTime(*this);
137
138
}

139
void ContextImpl::setTime(double t) {
140
    updateStateDataKernel.getAs<UpdateStateDataKernel>().setTime(*this, t);
141
142
143
}

void ContextImpl::getPositions(std::vector<Vec3>& positions) {
144
    updateStateDataKernel.getAs<UpdateStateDataKernel>().getPositions(*this, positions);
145
146
147
}

void ContextImpl::setPositions(const std::vector<Vec3>& positions) {
148
    hasSetPositions = true;
149
    updateStateDataKernel.getAs<UpdateStateDataKernel>().setPositions(*this, positions);
150
    integrator.stateChanged(State::Positions);
151
152
153
}

void ContextImpl::getVelocities(std::vector<Vec3>& velocities) {
154
    updateStateDataKernel.getAs<UpdateStateDataKernel>().getVelocities(*this, velocities);
155
156
157
}

void ContextImpl::setVelocities(const std::vector<Vec3>& velocities) {
158
    updateStateDataKernel.getAs<UpdateStateDataKernel>().setVelocities(*this, velocities);
159
    integrator.stateChanged(State::Velocities);
160
161
162
}

void ContextImpl::getForces(std::vector<Vec3>& forces) {
163
    updateStateDataKernel.getAs<UpdateStateDataKernel>().getForces(*this, forces);
164
165
}

166
167
168
169
const std::map<std::string, double>& ContextImpl::getParameters() const {
    return parameters;
}

170
double ContextImpl::getParameter(std::string name) {
171
172
173
174
175
    if (parameters.find(name) == parameters.end())
        throw OpenMMException("Called getParameter() with invalid parameter name");
    return parameters[name];
}

176
void ContextImpl::setParameter(std::string name, double value) {
177
178
179
    if (parameters.find(name) == parameters.end())
        throw OpenMMException("Called setParameter() with invalid parameter name");
    parameters[name] = value;
180
    integrator.stateChanged(State::Parameters);
181
182
}

183
void ContextImpl::getPeriodicBoxVectors(Vec3& a, Vec3& b, Vec3& c) {
184
    updateStateDataKernel.getAs<UpdateStateDataKernel>().getPeriodicBoxVectors(*this, a, b, c);
185
186
187
188
189
190
191
192
193
}

void ContextImpl::setPeriodicBoxVectors(const Vec3& a, const Vec3& b, const Vec3& c) {
    if (a[1] != 0.0 || a[2] != 0.0)
        throw OpenMMException("First periodic box vector must be parallel to x.");
    if (b[0] != 0.0 || b[2] != 0.0)
        throw OpenMMException("Second periodic box vector must be parallel to y.");
    if (c[0] != 0.0 || c[1] != 0.0)
        throw OpenMMException("Third periodic box vector must be parallel to z.");
194
    updateStateDataKernel.getAs<UpdateStateDataKernel>().setPeriodicBoxVectors(*this, a, b, c);
195
196
}

197
void ContextImpl::applyConstraints(double tol) {
198
    applyConstraintsKernel.getAs<ApplyConstraintsKernel>().apply(*this, tol);
199
200
}

201
202
203
204
void ContextImpl::applyVelocityConstraints(double tol) {
    applyConstraintsKernel.getAs<ApplyConstraintsKernel>().applyToVelocities(*this, tol);
}

205
void ContextImpl::computeVirtualSites() {
206
    virtualSitesKernel.getAs<VirtualSitesKernel>().computePositions(*this);
207
208
}

209
double ContextImpl::calcForcesAndEnergy(bool includeForces, bool includeEnergy, int groups) {
210
211
    if (!hasSetPositions)
        throw OpenMMException("Particle positions have not been set");
212
    lastForceGroups = groups;
213
    CalcForcesAndEnergyKernel& kernel = initializeForcesKernel.getAs<CalcForcesAndEnergyKernel>();
214
    double energy = 0.0;
215
    kernel.beginComputation(*this, includeForces, includeEnergy, groups);
216
    for (int i = 0; i < (int) forceImpls.size(); ++i)
217
218
        energy += forceImpls[i]->calcForcesAndEnergy(*this, includeForces, includeEnergy, groups);
    energy += kernel.finishComputation(*this, includeForces, includeEnergy, groups);
219
    return energy;
220
221
}

222
223
224
225
int ContextImpl::getLastForceGroups() const {
    return lastForceGroups;
}

226
double ContextImpl::calcKineticEnergy() {
227
    return integrator.computeKineticEnergy();
228
229
}

230
void ContextImpl::updateContextState() {
231
232
233
234
    for (int i = 0; i < (int) forceImpls.size(); ++i)
        forceImpls[i]->updateContextState(*this);
}

235
236
237
238
const vector<ForceImpl*>& ContextImpl::getForceImpls() const {
    return forceImpls;
}

239
240
241
242
vector<ForceImpl*>& ContextImpl::getForceImpls() {
    return forceImpls;
}

243
void* ContextImpl::getPlatformData() {
244
245
246
    return platformData;
}

247
248
249
250
const void* ContextImpl::getPlatformData() const {
    return platformData;
}

251
void ContextImpl::setPlatformData(void* data) {
252
253
    platformData = data;
}
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273

const vector<vector<int> >& ContextImpl::getMolecules() const {
    if (!hasInitializedForces)
        throw OpenMMException("ContextImpl: getMolecules() cannot be called until all ForceImpls have been initialized");
    if (molecules.size() > 0 || system.getNumParticles() == 0)
        return molecules;

    // First make a list of bonds and constraints.

    vector<pair<int, int> > bonds;
    for (int i = 0; i < system.getNumConstraints(); i++) {
        int particle1, particle2;
        double distance;
        system.getConstraintParameters(i, particle1, particle2, distance);
        bonds.push_back(std::make_pair(particle1, particle2));
    }
    for (int i = 0; i < (int) forceImpls.size(); i++) {
        vector<pair<int, int> > forceBonds = forceImpls[i]->getBondedParticles();
        bonds.insert(bonds.end(), forceBonds.begin(), forceBonds.end());
    }
274
275
276
277
278
279
280
    for (int i = 0; i < system.getNumParticles(); i++) {
        if (system.isVirtualSite(i)) {
            const VirtualSite& site = system.getVirtualSite(i);
            for (int j = 0; j < site.getNumParticles(); j++)
                bonds.push_back(std::make_pair(i, site.getParticle(j)));
        }
    }
281
282
283
284
285
286
287
288
289
290

    // Make a list of every other particle to which each particle is connected

    int numParticles = system.getNumParticles();
    vector<vector<int> > particleBonds(numParticles);
    for (int i = 0; i < (int) bonds.size(); i++) {
        particleBonds[bonds[i].first].push_back(bonds[i].second);
        particleBonds[bonds[i].second].push_back(bonds[i].first);
    }

291
    // Now identify particles by which molecule they belong to.
292

293
294
295
296
297
298
299
300
301
    molecules = findMolecules(numParticles, particleBonds);
    return molecules;
}

/**
 * Partition particles into connected components ("molecules") of a bond graph.
 *
 * @param numParticles   the number of particles; indices run from 0 to numParticles-1
 * @param particleBonds  particleBonds[i] lists every particle bonded to particle i (only read)
 * @return one vector per molecule, each listing its particle indices in increasing order;
 *         molecules are numbered in order of their lowest-index particle
 */
vector<vector<int> > ContextImpl::findMolecules(int numParticles, vector<vector<int> >& particleBonds) {
    // This is essentially a recursive depth-first search, reformulated as a loop to avoid
    // stack overflows.  It selects an unassigned particle, marks it as a new molecule, then
    // marks every particle reachable through bonds as also being in that molecule.
    
    vector<int> particleMolecule(numParticles, -1);
    vector<int> particleStack;  // Particles on the current search path.
    vector<int> neighborStack;  // For each entry of particleStack, the next neighbor index to examine.
    int numMolecules = 0;
    for (int i = 0; i < numParticles; i++) {
        if (particleMolecule[i] != -1)
            continue;
        
        // Start a new molecule.  Both stacks are empty here, so they are reused across
        // molecules instead of being reallocated each time.
        
        int molecule = numMolecules++;
        particleStack.push_back(i);
        neighborStack.push_back(0);
        while (!particleStack.empty()) {
            int particle = particleStack.back();
            particleMolecule[particle] = molecule;
            const vector<int>& bonded = particleBonds[particle];
            int numBonded = static_cast<int>(bonded.size());
            int neighbor = neighborStack.back();
            while (neighbor < numBonded && particleMolecule[bonded[neighbor]] != -1)
                neighbor++;
            
            // Save progress before any push_back, which may reallocate neighborStack.
            
            neighborStack.back() = neighbor;
            if (neighbor < numBonded) {
                particleStack.push_back(bonded[neighbor]);
                neighborStack.push_back(0);
            }
            else {
                particleStack.pop_back();
                neighborStack.pop_back();
            }
        }
    }
    
    // Build the final output vector.
    
    vector<vector<int> > molecules(numMolecules);
    for (int i = 0; i < numParticles; i++)
        molecules[particleMolecule[i]].push_back(i);
    return molecules;
}

/**
 * Write a string to a binary stream as its length (an int in native byte order)
 * followed by its raw bytes, with no terminator.
 */
static void writeString(ostream& stream, const string& str) {
    int length = static_cast<int>(str.size());
    stream.write(reinterpret_cast<const char*>(&length), sizeof(int));
    stream.write(str.data(), length);
}

/**
 * Read a string written by writeString(): an int length in native byte order followed
 * by that many raw bytes.
 *
 * If the stream cannot supply the length or the bytes (for example, a truncated checkpoint),
 * or the stored length is negative (corrupt data), the stream's failbit is set and an empty
 * string is returned.  This avoids using an uninitialized length or attempting an enormous
 * allocation.  Callers should check the stream state afterwards.
 */
static string readString(istream& stream) {
    int length = 0;
    stream.read(reinterpret_cast<char*>(&length), sizeof(int));
    if (!stream || length < 0) {
        stream.setstate(ios::failbit);
        return string();
    }
    string str(length, ' ');
    if (length > 0)
        stream.read(&str[0], length);
    if (!stream)
        return string();
    return str;
}

void ContextImpl::createCheckpoint(ostream& stream) {
    writeString(stream, getPlatform().getName());
    int numParticles = getSystem().getNumParticles();
    stream.write((char*) &numParticles, sizeof(int));
    int numParameters = parameters.size();
    stream.write((char*) &numParameters, sizeof(int));
    for (map<string, double>::const_iterator iter = parameters.begin(); iter != parameters.end(); ++iter) {
        writeString(stream, iter->first);
        stream.write((char*) &iter->second, sizeof(double));
    }
365
    updateStateDataKernel.getAs<UpdateStateDataKernel>().createCheckpoint(*this, stream);
Peter Eastman's avatar
Peter Eastman committed
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
    stream.flush();
}

void ContextImpl::loadCheckpoint(istream& stream) {
    string platformName = readString(stream);
    if (platformName != getPlatform().getName())
        throw OpenMMException("loadCheckpoint: Checkpoint was created with a different Platform: "+platformName);
    int numParticles;
    stream.read((char*) &numParticles, sizeof(int));
    if (numParticles != getSystem().getNumParticles())
        throw OpenMMException("loadCheckpoint: Checkpoint contains the wrong number of particles");
    int numParameters;
    stream.read((char*) &numParameters, sizeof(int));
    for (int i = 0; i < numParameters; i++) {
        string name = readString(stream);
        double value;
        stream.read((char*) &value, sizeof(double));
        parameters[name] = value;
    }
385
    updateStateDataKernel.getAs<UpdateStateDataKernel>().loadCheckpoint(*this, stream);
Peter Eastman's avatar
Peter Eastman committed
386
}