/* -------------------------------------------------------------------------- *
 *                                   OpenMM                                   *
 * -------------------------------------------------------------------------- *
 * This is part of the OpenMM molecular simulation toolkit originating from   *
 * Simbios, the NIH National Center for Physics-Based Simulation of           *
 * Biological Structures at Stanford, funded under the NIH Roadmap for        *
 * Medical Research, grant U54 GM072970. See https://simtk.org.               *
 *                                                                            *
 * Portions copyright (c) 2008-2013 Stanford University and the Authors.      *
 * Authors: Peter Eastman                                                     *
 * Contributors:                                                              *
 *                                                                            *
 * Permission is hereby granted, free of charge, to any person obtaining a    *
 * copy of this software and associated documentation files (the "Software"), *
 * to deal in the Software without restriction, including without limitation  *
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,   *
 * and/or sell copies of the Software, and to permit persons to whom the      *
 * Software is furnished to do so, subject to the following conditions:       *
 *                                                                            *
 * The above copyright notice and this permission notice shall be included in *
 * all copies or substantial portions of the Software.                        *
 *                                                                            *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,   *
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL    *
 * THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,    *
 * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR      *
 * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE  *
 * USE OR OTHER DEALINGS IN THE SOFTWARE.                                     *
 * -------------------------------------------------------------------------- */

32
33
34
35
36
37
#include "openmm/Force.h"
#include "openmm/Integrator.h"
#include "openmm/OpenMMException.h"
#include "openmm/System.h"
#include "openmm/kernels.h"
#include "openmm/internal/ForceImpl.h"
38
#include "openmm/internal/ContextImpl.h"
39
#include "openmm/State.h"
40
#include "openmm/VirtualSite.h"
Peter Eastman's avatar
Peter Eastman committed
41
#include "openmm/Context.h"
42
#include <algorithm>
Peter Eastman's avatar
Peter Eastman committed
43
#include <iostream>
44
#include <map>
45
#include <utility>
46
47
48
#include <vector>

using namespace OpenMM;
Peter Eastman's avatar
Peter Eastman committed
49
using namespace std;
50

51
ContextImpl::ContextImpl(Context& owner, const System& system, Integrator& integrator, Platform* platform, const map<string, string>& properties) :
52
53
        owner(owner), system(system), integrator(integrator), hasInitializedForces(false), hasSetPositions(false), integratorIsDeleted(false),
        lastForceGroups(-1), platform(platform), platformData(NULL) {
54
55
    if (system.getNumParticles() == 0)
        throw OpenMMException("Cannot create a Context for a System with no particles");
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
    
    // Check for errors in virtual sites and massless particles.
    
    for (int i = 0; i < system.getNumParticles(); i++) {
        if (system.isVirtualSite(i)) {
            if (system.getParticleMass(i) != 0.0)
                throw OpenMMException("Virtual site has nonzero mass");
            const VirtualSite& site = system.getVirtualSite(i);
            for (int j = 0; j < site.getNumParticles(); j++)
                if (system.isVirtualSite(site.getParticle(j)))
                    throw OpenMMException("A virtual site cannot depend on another virtual site");
        }
    }
    for (int i = 0; i < system.getNumConstraints(); i++) {
        int particle1, particle2;
        double distance;
        system.getConstraintParameters(i, particle1, particle2, distance);
73
74
75
        double mass1 = system.getParticleMass(particle1);
        double mass2 = system.getParticleMass(particle2);
        if ((mass1 == 0.0 && mass2 != 0.0) || (mass2 == 0.0 && mass1 != 0.0))
76
77
78
            throw OpenMMException("A constraint cannot involve a massless particle");
    }
    
79
80
81
82
83
84
85
86
87
88
89
90
91
92
    // Validate the list of properties.

    const vector<string>& platformProperties = platform->getPropertyNames();
    for (map<string, string>::const_iterator iter = properties.begin(); iter != properties.end(); ++iter) {
        bool valid = false;
        for (int i = 0; i < (int) platformProperties.size(); i++)
            if (platformProperties[i] == iter->first) {
                valid = true;
                break;
            }
        if (!valid)
            throw OpenMMException("Illegal property name: "+iter->first);
    }
    
93
94
    // Find the list of kernels required.
    
95
    vector<string> kernelNames;
96
    kernelNames.push_back(CalcForcesAndEnergyKernel::Name());
97
    kernelNames.push_back(UpdateStateDataKernel::Name());
98
    for (int i = 0; i < system.getNumForces(); ++i) {
99
        forceImpls.push_back(system.getForce(i).createImpl());
100
101
        map<string, double> forceParameters = forceImpls[forceImpls.size()-1]->getDefaultParameters();
        parameters.insert(forceParameters.begin(), forceParameters.end());
102
103
104
        vector<string> forceKernels = forceImpls[forceImpls.size()-1]->getKernelNames();
        kernelNames.insert(kernelNames.begin(), forceKernels.begin(), forceKernels.end());
    }
105
    hasInitializedForces = true;
106
107
    vector<string> integratorKernels = integrator.getKernelNames();
    kernelNames.insert(kernelNames.begin(), integratorKernels.begin(), integratorKernels.end());
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
    
    // Select a platform to use.
    
    vector<pair<double, Platform*> > candidatePlatforms;
    if (platform == NULL) {
        for (int i = 0; i < Platform::getNumPlatforms(); i++) {
            Platform& p = Platform::getPlatform(i);
            if (p.supportsKernels(kernelNames))
                candidatePlatforms.push_back(make_pair(p.getSpeed(), &p));
        }
        if (candidatePlatforms.size() == 0)
            throw OpenMMException("No Platform supports all the requested kernels");
        sort(candidatePlatforms.begin(), candidatePlatforms.end());
    }
    else {
        if (!platform->supportsKernels(kernelNames))
            throw OpenMMException("Specified a Platform for a Context which does not support all required kernels");
        candidatePlatforms.push_back(make_pair(platform->getSpeed(), platform));
    }
    for (int i = candidatePlatforms.size()-1; i >= 0; i--) {
        try {
            this->platform = platform = candidatePlatforms[i].second;
            platform->contextCreated(*this, properties);
            break;
        }
        catch (...) {
            if (i > 0)
                continue;
            throw;
        }
    }
139
140
141
    
    // Create and initialize kernels and other objects.
    
142
    initializeForcesKernel = platform->createKernel(CalcForcesAndEnergyKernel::Name(), *this);
143
    initializeForcesKernel.getAs<CalcForcesAndEnergyKernel>().initialize(system);
144
    updateStateDataKernel = platform->createKernel(UpdateStateDataKernel::Name(), *this);
145
    updateStateDataKernel.getAs<UpdateStateDataKernel>().initialize(system);
146
    applyConstraintsKernel = platform->createKernel(ApplyConstraintsKernel::Name(), *this);
147
    applyConstraintsKernel.getAs<ApplyConstraintsKernel>().initialize(system);
148
    virtualSitesKernel = platform->createKernel(VirtualSitesKernel::Name(), *this);
149
    virtualSitesKernel.getAs<VirtualSitesKernel>().initialize(system);
150
151
    Vec3 periodicBoxVectors[3];
    system.getDefaultPeriodicBoxVectors(periodicBoxVectors[0], periodicBoxVectors[1], periodicBoxVectors[2]);
152
    updateStateDataKernel.getAs<UpdateStateDataKernel>().setPeriodicBoxVectors(*this, periodicBoxVectors[0], periodicBoxVectors[1], periodicBoxVectors[2]);
153
    for (size_t i = 0; i < forceImpls.size(); ++i)
154
        forceImpls[i]->initialize(*this);
155
    integrator.initialize(*this);
156
    updateStateDataKernel.getAs<UpdateStateDataKernel>().setVelocities(*this, vector<Vec3>(system.getNumParticles()));
157
158
}

159
/**
 * Release everything the Context owns: the ForceImpls, the kernels, and finally the
 * platform-side context data.
 */
ContextImpl::~ContextImpl() {
    for (vector<ForceImpl*>::iterator impl = forceImpls.begin(); impl != forceImpls.end(); ++impl)
        delete *impl;

    // Replace each kernel with an empty one so that all of them are released before
    // contextDestroyed() is called.

    const Kernel emptyKernel;
    initializeForcesKernel = emptyKernel;
    updateStateDataKernel = emptyKernel;
    applyConstraintsKernel = emptyKernel;
    virtualSitesKernel = emptyKernel;

    // If the Integrator still exists, the Context is going away first: clean it up and
    // detach it from this Context.

    const bool integratorStillAlive = !integratorIsDeleted;
    if (integratorStillAlive) {
        integrator.cleanup();
        integrator.context = NULL;
    }
    platform->contextDestroyed(*this);
}

178
double ContextImpl::getTime() const {
179
    return updateStateDataKernel.getAs<const UpdateStateDataKernel>().getTime(*this);
180
181
}

182
void ContextImpl::setTime(double t) {
183
    updateStateDataKernel.getAs<UpdateStateDataKernel>().setTime(*this, t);
184
185
186
}

void ContextImpl::getPositions(std::vector<Vec3>& positions) {
187
    updateStateDataKernel.getAs<UpdateStateDataKernel>().getPositions(*this, positions);
188
189
190
}

void ContextImpl::setPositions(const std::vector<Vec3>& positions) {
191
    hasSetPositions = true;
192
    updateStateDataKernel.getAs<UpdateStateDataKernel>().setPositions(*this, positions);
193
    integrator.stateChanged(State::Positions);
194
195
196
}

void ContextImpl::getVelocities(std::vector<Vec3>& velocities) {
197
    updateStateDataKernel.getAs<UpdateStateDataKernel>().getVelocities(*this, velocities);
198
199
200
}

void ContextImpl::setVelocities(const std::vector<Vec3>& velocities) {
201
    updateStateDataKernel.getAs<UpdateStateDataKernel>().setVelocities(*this, velocities);
202
    integrator.stateChanged(State::Velocities);
203
204
205
}

void ContextImpl::getForces(std::vector<Vec3>& forces) {
206
    updateStateDataKernel.getAs<UpdateStateDataKernel>().getForces(*this, forces);
207
208
}

209
210
211
212
const std::map<std::string, double>& ContextImpl::getParameters() const {
    return parameters;
}

213
double ContextImpl::getParameter(std::string name) {
214
215
216
217
218
    if (parameters.find(name) == parameters.end())
        throw OpenMMException("Called getParameter() with invalid parameter name");
    return parameters[name];
}

219
void ContextImpl::setParameter(std::string name, double value) {
220
221
222
    if (parameters.find(name) == parameters.end())
        throw OpenMMException("Called setParameter() with invalid parameter name");
    parameters[name] = value;
223
    integrator.stateChanged(State::Parameters);
224
225
}

226
/**
 * Retrieve the three periodic box vectors from the UpdateStateDataKernel.
 *
 * @param a  on exit, the first box vector
 * @param b  on exit, the second box vector
 * @param c  on exit, the third box vector
 */
void ContextImpl::getPeriodicBoxVectors(Vec3& a, Vec3& b, Vec3& c) {
    UpdateStateDataKernel& stateKernel = updateStateDataKernel.getAs<UpdateStateDataKernel>();
    stateKernel.getPeriodicBoxVectors(*this, a, b, c);
}

void ContextImpl::setPeriodicBoxVectors(const Vec3& a, const Vec3& b, const Vec3& c) {
    if (a[1] != 0.0 || a[2] != 0.0)
        throw OpenMMException("First periodic box vector must be parallel to x.");
    if (b[0] != 0.0 || b[2] != 0.0)
        throw OpenMMException("Second periodic box vector must be parallel to y.");
    if (c[0] != 0.0 || c[1] != 0.0)
        throw OpenMMException("Third periodic box vector must be parallel to z.");
237
    updateStateDataKernel.getAs<UpdateStateDataKernel>().setPeriodicBoxVectors(*this, a, b, c);
238
239
}

240
void ContextImpl::applyConstraints(double tol) {
241
    applyConstraintsKernel.getAs<ApplyConstraintsKernel>().apply(*this, tol);
242
243
}

244
245
246
247
void ContextImpl::applyVelocityConstraints(double tol) {
    applyConstraintsKernel.getAs<ApplyConstraintsKernel>().applyToVelocities(*this, tol);
}

248
void ContextImpl::computeVirtualSites() {
249
    virtualSitesKernel.getAs<VirtualSitesKernel>().computePositions(*this);
250
251
}

252
double ContextImpl::calcForcesAndEnergy(bool includeForces, bool includeEnergy, int groups) {
253
254
    if (!hasSetPositions)
        throw OpenMMException("Particle positions have not been set");
255
    lastForceGroups = groups;
256
    CalcForcesAndEnergyKernel& kernel = initializeForcesKernel.getAs<CalcForcesAndEnergyKernel>();
257
    double energy = 0.0;
258
    kernel.beginComputation(*this, includeForces, includeEnergy, groups);
259
    for (int i = 0; i < (int) forceImpls.size(); ++i)
260
261
        energy += forceImpls[i]->calcForcesAndEnergy(*this, includeForces, includeEnergy, groups);
    energy += kernel.finishComputation(*this, includeForces, includeEnergy, groups);
262
    return energy;
263
264
}

265
266
267
268
int ContextImpl::getLastForceGroups() const {
    return lastForceGroups;
}

269
double ContextImpl::calcKineticEnergy() {
270
    return integrator.computeKineticEnergy();
271
272
}

273
void ContextImpl::updateContextState() {
274
275
276
277
    for (int i = 0; i < (int) forceImpls.size(); ++i)
        forceImpls[i]->updateContextState(*this);
}

278
279
280
281
/**
 * Get read-only access to the list of ForceImpls held by this Context.
 */
const vector<ForceImpl*>& ContextImpl::getForceImpls() const {
    return this->forceImpls;
}

282
283
284
285
/**
 * Get modifiable access to the list of ForceImpls held by this Context.
 */
vector<ForceImpl*>& ContextImpl::getForceImpls() {
    return this->forceImpls;
}

286
void* ContextImpl::getPlatformData() {
287
288
289
    return platformData;
}

290
291
292
293
const void* ContextImpl::getPlatformData() const {
    return platformData;
}

294
void ContextImpl::setPlatformData(void* data) {
295
296
    platformData = data;
}
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316

/**
 * Get the particles grouped into molecules, computing and caching the grouping the first
 * time it is requested.
 *
 * Two particles are treated as connected when they share a constraint, when a ForceImpl
 * reports them as bonded, or when one is a virtual site that depends on the other.
 *
 * @return one vector of particle indices per molecule
 * @throws OpenMMException if called before the ForceImpls have been created
 */
const vector<vector<int> >& ContextImpl::getMolecules() const {
    if (!hasInitializedForces)
        throw OpenMMException("ContextImpl: getMolecules() cannot be called until all ForceImpls have been initialized");
    const int numParticles = system.getNumParticles();
    if (!molecules.empty() || numParticles == 0)
        return molecules;

    // For each particle, list every other particle it is connected to.  Each connection
    // is recorded in both directions.

    vector<vector<int> > particleBonds(numParticles);

    // Constraints.

    const int numConstraints = system.getNumConstraints();
    for (int constraint = 0; constraint < numConstraints; constraint++) {
        int first, second;
        double distance;
        system.getConstraintParameters(constraint, first, second, distance);
        particleBonds[first].push_back(second);
        particleBonds[second].push_back(first);
    }

    // Bonds reported by the forces.

    for (size_t force = 0; force < forceImpls.size(); force++) {
        const vector<pair<int, int> > forceBonds = forceImpls[force]->getBondedParticles();
        for (size_t bond = 0; bond < forceBonds.size(); bond++) {
            particleBonds[forceBonds[bond].first].push_back(forceBonds[bond].second);
            particleBonds[forceBonds[bond].second].push_back(forceBonds[bond].first);
        }
    }

    // Virtual sites and the particles they depend on.

    for (int particle = 0; particle < numParticles; particle++) {
        if (!system.isVirtualSite(particle))
            continue;
        const VirtualSite& site = system.getVirtualSite(particle);
        for (int j = 0; j < site.getNumParticles(); j++) {
            const int parent = site.getParticle(j);
            particleBonds[particle].push_back(parent);
            particleBonds[parent].push_back(particle);
        }
    }

    // Now identify particles by which molecule they belong to.

    molecules = findMolecules(numParticles, particleBonds);
    return molecules;
}

/**
 * Partition the particles into molecules, i.e. the connected components of the bond graph.
 *
 * @param numParticles   the number of particles; indices run from 0 to numParticles-1
 * @param particleBonds  for each particle, the indices of the particles it is bonded to
 * @return one vector per molecule holding its particle indices in increasing order;
 *         molecules are ordered by their lowest-numbered particle
 */
vector<vector<int> > ContextImpl::findMolecules(int numParticles, vector<vector<int> >& particleBonds) {
    // Flood fill using an explicit work list rather than recursion, so that very large
    // molecules cannot overflow the call stack.

    const int unassigned = -1;
    vector<int> moleculeOf(numParticles, unassigned);
    vector<int> pending;
    int moleculeCount = 0;
    for (int seed = 0; seed < numParticles; seed++) {
        if (moleculeOf[seed] != unassigned)
            continue;

        // This particle starts a new molecule; spread its label to everything reachable.

        const int label = moleculeCount++;
        moleculeOf[seed] = label;
        pending.push_back(seed);
        while (!pending.empty()) {
            const int current = pending.back();
            pending.pop_back();
            const vector<int>& neighbors = particleBonds[current];
            for (size_t k = 0; k < neighbors.size(); k++) {
                const int other = neighbors[k];
                if (moleculeOf[other] == unassigned) {
                    moleculeOf[other] = label;
                    pending.push_back(other);
                }
            }
        }
    }

    // Build the final output vector.

    vector<vector<int> > result(moleculeCount);
    for (int particle = 0; particle < numParticles; particle++)
        result[moleculeOf[particle]].push_back(particle);
    return result;
}

Peter Eastman's avatar
Peter Eastman committed
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
/**
 * Write a string to a binary stream as an int byte count followed by the raw characters.
 *
 * @param stream  the stream to write to
 * @param str     the string to write; taken by const reference so it is not copied
 */
static void writeString(std::ostream& stream, const std::string& str) {
    const int length = static_cast<int>(str.size());
    stream.write(reinterpret_cast<const char*>(&length), sizeof(int));
    stream.write(str.data(), length);
}

/**
 * Read a string stored as an int byte count followed by the raw characters.
 *
 * If the byte count cannot be read, is negative, or the stream ends before all the
 * characters have been read, an empty string is returned and the stream is left with
 * its failbit set so the caller can detect the problem.
 *
 * @param stream  the stream to read from
 * @return the string that was read, or an empty string on failure
 */
static std::string readString(std::istream& stream) {
    int length = 0;
    stream.read(reinterpret_cast<char*>(&length), sizeof(int));
    if (!stream || length < 0) {
        // Covers both a failed read and a corrupt (negative) length.
        stream.setstate(std::ios::failbit);
        return std::string();
    }
    std::string str(length, ' ');
    if (length > 0) {
        stream.read(&str[0], length);
        if (!stream)
            return std::string();
    }
    return str;
}

/**
 * Write a checkpoint of this Context to a stream.
 *
 * The data is written in this order: the Platform name, the number of particles, the
 * number of context parameters, each parameter as a (name, value) pair, and finally
 * whatever the UpdateStateDataKernel writes.  The stream is flushed at the end.
 *
 * @param stream  the stream to write the checkpoint to
 */
void ContextImpl::createCheckpoint(ostream& stream) {
    writeString(stream, getPlatform().getName());

    const int particleCount = getSystem().getNumParticles();
    stream.write((char*) &particleCount, sizeof(int));

    const int parameterCount = parameters.size();
    stream.write((char*) &parameterCount, sizeof(int));
    map<string, double>::const_iterator entry = parameters.begin();
    while (entry != parameters.end()) {
        writeString(stream, entry->first);
        stream.write((char*) &entry->second, sizeof(double));
        ++entry;
    }

    UpdateStateDataKernel& stateKernel = updateStateDataKernel.getAs<UpdateStateDataKernel>();
    stateKernel.createCheckpoint(*this, stream);
    stream.flush();
}

/**
 * Restore this Context from a checkpoint written by createCheckpoint().
 *
 * The header (Platform name and particle count) is checked against this Context.  The
 * parameters are read completely before any of them is stored, so a truncated stream
 * does not leave the parameter map half updated.  After the kernel has restored its
 * state, positions are marked as set and the Integrator is told that positions,
 * velocities and parameters changed, as the individual setters do.
 *
 * @param stream  the stream to read the checkpoint from
 * @throws OpenMMException if the checkpoint was created with a different Platform, has a
 *                         different number of particles, or ends unexpectedly
 */
void ContextImpl::loadCheckpoint(istream& stream) {
    string platformName = readString(stream);
    if (!stream)
        throw OpenMMException("loadCheckpoint: Checkpoint data is truncated or corrupt");
    if (platformName != getPlatform().getName())
        throw OpenMMException("loadCheckpoint: Checkpoint was created with a different Platform: "+platformName);
    int numParticles = 0;
    stream.read((char*) &numParticles, sizeof(int));
    if (!stream)
        throw OpenMMException("loadCheckpoint: Checkpoint data is truncated or corrupt");
    if (numParticles != getSystem().getNumParticles())
        throw OpenMMException("loadCheckpoint: Checkpoint contains the wrong number of particles");
    int numParameters = 0;
    stream.read((char*) &numParameters, sizeof(int));
    if (!stream || numParameters < 0)
        throw OpenMMException("loadCheckpoint: Checkpoint data is truncated or corrupt");

    // Read every parameter before storing any of them.

    map<string, double> loadedParameters;
    for (int i = 0; i < numParameters; i++) {
        string name = readString(stream);
        double value = 0.0;
        stream.read((char*) &value, sizeof(double));
        if (!stream)
            throw OpenMMException("loadCheckpoint: Checkpoint data is truncated or corrupt");
        loadedParameters[name] = value;
    }
    for (map<string, double>::const_iterator iter = loadedParameters.begin(); iter != loadedParameters.end(); ++iter)
        parameters[iter->first] = iter->second;
    updateStateDataKernel.getAs<UpdateStateDataKernel>().loadCheckpoint(*this, stream);

    // The checkpoint supplies the positions, so calcForcesAndEnergy() must not reject the
    // Context for never having had setPositions() called on it.

    hasSetPositions = true;
    integrator.stateChanged(State::Positions);
    integrator.stateChanged(State::Velocities);
    integrator.stateChanged(State::Parameters);
}