CpuKernels.cpp 65.2 KB
Newer Older
1
2
3
4
5
6
7
8
/* -------------------------------------------------------------------------- *
 *                                   OpenMM                                   *
 * -------------------------------------------------------------------------- *
 * This is part of the OpenMM molecular simulation toolkit originating from   *
 * Simbios, the NIH National Center for Physics-Based Simulation of           *
 * Biological Structures at Stanford, funded under the NIH Roadmap for        *
 * Medical Research, grant U54 GM072970. See https://simtk.org.               *
 *                                                                            *
 * Portions copyright (c) 2013-2021 Stanford University and the Authors.      *
 * Authors: Peter Eastman                                                     *
 * Contributors:                                                              *
 *                                                                            *
 * Permission is hereby granted, free of charge, to any person obtaining a    *
 * copy of this software and associated documentation files (the "Software"), *
 * to deal in the Software without restriction, including without limitation  *
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,   *
 * and/or sell copies of the Software, and to permit persons to whom the      *
 * Software is furnished to do so, subject to the following conditions:       *
 *                                                                            *
 * The above copyright notice and this permission notice shall be included in *
 * all copies or substantial portions of the Software.                        *
 *                                                                            *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,   *
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL    *
 * THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,    *
 * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR      *
 * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE  *
 * USE OR OTHER DEALINGS IN THE SOFTWARE.                                     *
 * -------------------------------------------------------------------------- */

#include "CpuKernels.h"
peastman's avatar
peastman committed
33
#include "ReferenceAngleBondIxn.h"
34
#include "ReferenceBondForce.h"
35
#include "ReferenceConstraints.h"
36
37
#include "ReferenceKernelFactory.h"
#include "ReferenceKernels.h"
38
#include "ReferenceLJCoulomb14.h"
39
40
#include "ReferenceProperDihedralBond.h"
#include "ReferenceRbDihedralBond.h"
41
#include "ReferenceTabulatedFunction.h"
42
#include "openmm/Context.h"
43
#include "openmm/OpenMMException.h"
peastman's avatar
peastman committed
44
#include "openmm/Vec3.h"
45
46
#include "openmm/internal/ContextImpl.h"
#include "openmm/internal/NonbondedForceImpl.h"
47
#include "openmm/internal/vectorize.h"
48
#include "lepton/CompiledExpression.h"
49
#include "lepton/CustomFunction.h"
50
#include "lepton/Operation.h"
51
#include "lepton/Parser.h"
52
#include <iostream>
53
#include "lepton/ParsedExpression.h"
54
55
56
57

using namespace OpenMM;
using namespace std;

peastman's avatar
peastman committed
58
/**
 * Get the particle positions held in the Reference platform data attached to a context.
 */
static vector<Vec3>& extractPositions(ContextImpl& context) {
    auto* platformData = reinterpret_cast<ReferencePlatform::PlatformData*>(context.getPlatformData());
    return *(platformData->positions);
}

peastman's avatar
peastman committed
63
/**
 * Get the particle velocities held in the Reference platform data attached to a context.
 */
static vector<Vec3>& extractVelocities(ContextImpl& context) {
    auto* platformData = reinterpret_cast<ReferencePlatform::PlatformData*>(context.getPlatformData());
    return *(platformData->velocities);
}

peastman's avatar
peastman committed
68
/**
 * Get the particle forces held in the Reference platform data attached to a context.
 */
static vector<Vec3>& extractForces(ContextImpl& context) {
    auto* platformData = reinterpret_cast<ReferencePlatform::PlatformData*>(context.getPlatformData());
    return *(platformData->forces);
}

peastman's avatar
peastman committed
73
/**
 * Get the periodic box size held in the Reference platform data attached to a context.
 */
static Vec3& extractBoxSize(ContextImpl& context) {
    auto* platformData = reinterpret_cast<ReferencePlatform::PlatformData*>(context.getPlatformData());
    return *(platformData->periodicBoxSize);
}

peastman's avatar
peastman committed
78
/**
 * Get the pointer to the periodic box vectors held in the Reference platform data
 * attached to a context.
 */
static Vec3* extractBoxVectors(ContextImpl& context) {
    auto* platformData = reinterpret_cast<ReferencePlatform::PlatformData*>(context.getPlatformData());
    return platformData->periodicBoxVectors;
}

83
84
/**
 * Get the ReferenceConstraints object held in the Reference platform data attached to a context.
 */
static ReferenceConstraints& extractConstraints(ContextImpl& context) {
    auto* platformData = reinterpret_cast<ReferencePlatform::PlatformData*>(context.getPlatformData());
    return *(platformData->constraints);
}

88
89
/**
 * Get the map of energy parameter derivatives (keyed by name) held in the Reference
 * platform data attached to a context.
 */
static map<string, double>& extractEnergyParameterDerivatives(ContextImpl& context) {
    auto* platformData = reinterpret_cast<ReferencePlatform::PlatformData*>(context.getPlatformData());
    return *(platformData->energyParameterDerivatives);
}

93
94
95
96
97
98
99
/**
 * Recursively check an expression tree and throw an OpenMMException if it refers to a
 * variable whose name is not in the set of allowed variables.
 */
static void validateVariables(const Lepton::ExpressionTreeNode& node, const set<string>& variables) {
    const Lepton::Operation& operation = node.getOperation();
    if (operation.getId() == Lepton::Operation::VARIABLE) {
        const bool isKnown = (variables.find(operation.getName()) != variables.end());
        if (!isKnown)
            throw OpenMMException("Unknown variable in expression: "+operation.getName());
    }

    // Every subexpression has to satisfy the same requirement.

    for (auto& subexpression : node.getChildren())
        validateVariables(subexpression, variables);
}

104
105
106
107
108
/**
 * Compute the kinetic energy of the system after advancing the velocities by timeShift
 * (using the current forces) and applying the velocity constraints.  This accounts for
 * the offset between positions and velocities in a leapfrog integrator.
 *
 * @param context    the context whose positions, velocities, forces and constraints are used
 * @param masses     the mass of each particle; particles whose mass is not positive are
 *                   not shifted and contribute no energy
 * @param timeShift  the amount of time by which to shift the velocities
 * @return the kinetic energy computed from the shifted, constrained velocities
 */
static double computeShiftedKineticEnergy(ContextImpl& context, vector<double>& masses, double timeShift) {
    vector<Vec3>& positions = extractPositions(context);
    vector<Vec3>& velocities = extractVelocities(context);
    vector<Vec3>& forces = extractForces(context);
    const int particleCount = context.getSystem().getNumParticles();

    // Build the shifted velocities and the inverse masses in one pass.

    vector<Vec3> shifted(particleCount);
    vector<double> invMass(particleCount);
    for (int p = 0; p < particleCount; ++p) {
        const double mass = masses[p];
        if (mass > 0)
            shifted[p] = velocities[p]+forces[p]*(timeShift/mass);
        else
            shifted[p] = velocities[p];
        invMass[p] = (mass == 0 ? 0 : 1/mass);
    }

    // Remove velocity components that violate the constraints.

    extractConstraints(context).applyToVelocities(positions, shifted, invMass, 1e-4);

    // Sum m*v^2 over the particles that have mass.

    double twiceEnergy = 0.0;
    for (int p = 0; p < particleCount; ++p) {
        if (masses[p] > 0)
            twiceEnergy += masses[p]*(shifted[p].dot(shifted[p]));
    }
    return 0.5*twiceEnergy;
}

140
141
142
143
144
145
146
147
148
149
150
151
152
/**
 * Copy per-particle values into the fourth element of each particle's entry in the
 * posq array.  Nothing is done if index matches the posq index recorded in the
 * platform data; otherwise that index is recorded and the values are copied.
 */
static void copyChargesToPosq(ContextImpl& context, const vector<float>& charges, int index) {
    CpuPlatform::PlatformData& platformData = CpuPlatform::getPlatformData(context);
    if (platformData.currentPosqIndex != index) {
        platformData.currentPosqIndex = index;
        AlignedArray<float>& posq = platformData.posq;
        int particle = 0;
        while (particle < charges.size()) {
            posq[4*particle+3] = charges[particle];
            particle++;
        }
    }
}

peastman's avatar
peastman committed
153
154
155
156
157
158
159
/**
 * Create the kernel.  It wraps a kernel of the same name built by the Reference
 * platform's kernel factory, to which part of the work is delegated.
 */
CpuCalcForcesAndEnergyKernel::CpuCalcForcesAndEnergyKernel(std::string name, const Platform& platform, CpuPlatform::PlatformData& data, ContextImpl& context) :
        CalcForcesAndEnergyKernel(name, platform), data(data) {
    ReferenceKernelFactory factory;
    KernelImpl* referenceImpl = factory.createKernelImpl(name, platform, context);
    referenceKernel = Kernel(referenceImpl);
}
160

peastman's avatar
peastman committed
161
162
163
164
165
166
167
168
169
170
171
172
173
void CpuCalcForcesAndEnergyKernel::initialize(const System& system) {
    referenceKernel.getAs<ReferenceCalcForcesAndEnergyKernel>().initialize(system);
    lastPositions.resize(system.getNumParticles(), Vec3(1e10, 1e10, 1e10));
}

/**
 * Prepare for a force and energy evaluation.
 *
 * This forwards to the wrapped Reference kernel, then in parallel copies the positions
 * into the single precision posq array (wrapping them into the periodic box when
 * data.isPeriodic is set), checks them for NaN, and zeroes every thread's force buffer.
 * Finally, if a neighbor list exists, it decides whether particles have moved far enough
 * since the list was last built that it must be rebuilt, and rebuilds it if so.
 *
 * @throws OpenMMException if any particle coordinate is NaN
 */
void CpuCalcForcesAndEnergyKernel::beginComputation(ContextImpl& context, bool includeForce, bool includeEnergy, int groups) {
    referenceKernel.getAs<ReferenceCalcForcesAndEnergyKernel>().beginComputation(context, includeForce, includeEnergy, groups);
    
    // Convert positions to single precision and clear the forces.

    const int numParticles = context.getSystem().getNumParticles();

    // Each thread reports invalid coordinates through its own slot.  A single flag shared
    // by all threads would be written concurrently without synchronization, which is a
    // data race.  vector<char> is used rather than vector<bool> so that every slot is a
    // distinct memory location.

    vector<char> foundInvalidPosition(data.threads.getNumThreads(), 0);
    data.threads.execute([&] (ThreadPool& threads, int threadIndex) {
        // Convert the positions to single precision and apply periodic boundary conditions

        AlignedArray<float>& posq = data.posq;
        vector<Vec3>& posData = extractPositions(context);
        Vec3* boxVectors = extractBoxVectors(context);
        double boxSize[3] = {boxVectors[0][0], boxVectors[1][1], boxVectors[2][2]};
        double invBoxSize[3] = {1/boxVectors[0][0], 1/boxVectors[1][1], 1/boxVectors[2][2]};
        bool triclinic = (boxVectors[0][1] != 0 || boxVectors[0][2] != 0 || boxVectors[1][0] != 0 || boxVectors[1][2] != 0 || boxVectors[2][0] != 0 || boxVectors[2][1] != 0);

        // This thread handles the particles in [start, end).

        int numThreads = threads.getNumThreads();
        int start = threadIndex*numParticles/numThreads;
        int end = (threadIndex+1)*numParticles/numThreads;
        if (data.isPeriodic) {
            if (triclinic) {
                for (int i = start; i < end; i++) {
                    // Wrap along the third box vector first, then the second, then the first.

                    Vec3 pos = posData[i];
                    pos -= boxVectors[2]*floor(pos[2]*invBoxSize[2]);
                    pos -= boxVectors[1]*floor(pos[1]*invBoxSize[1]);
                    pos -= boxVectors[0]*floor(pos[0]*invBoxSize[0]);
                    posq[4*i] = (float) pos[0];
                    posq[4*i+1] = (float) pos[1];
                    posq[4*i+2] = (float) pos[2];
                }
            }
            else {
                for (int i = start; i < end; i++) {
                    for (int j = 0; j < 3; j++) {
                        double x = posData[i][j];
                        double base = floor(x*invBoxSize[j])*boxSize[j];
                        posq[4*i+j] = (float) (x-base);
                    }
                }
            }
        }
        else
            for (int i = start; i < end; i++) {
                posq[4*i] = (float) posData[i][0];
                posq[4*i+1] = (float) posData[i][1];
                posq[4*i+2] = (float) posData[i][2];
            }
        
        // Check for invalid positions.  NaN is the only value that is not equal to itself.
        
        for (int i = 4*start; i < 4*end; i += 4)
            if (posq[i] != posq[i] || posq[i+1] != posq[i+1] || posq[i+2] != posq[i+2])
                foundInvalidPosition[threadIndex] = 1;

        // Clear the forces.  Every thread clears its own complete force buffer.

        fvec4 zero(0.0f);
        for (int j = 0; j < numParticles; j++)
            zero.store(&data.threadForce[threadIndex][j*4]);
    });
    data.threads.waitForThreads();
    for (char invalid : foundInvalidPosition)
        if (invalid)
            throw OpenMMException("Particle coordinate is nan");

    // Determine whether we need to recompute the neighbor list.
        
    if (data.neighborList != nullptr) {
        double padding = data.paddedCutoff-data.cutoff;
        bool needRecompute = false;
        double closeCutoff2 = 0.25*padding*padding;
        double farCutoff2 = 0.5*padding*padding;
        int maxNumMoved = numParticles/10;
        vector<int> moved;
        vector<Vec3>& posData = extractPositions(context);
        for (int i = 0; i < numParticles; i++) {
            Vec3 delta = posData[i]-lastPositions[i];
            double dist2 = delta.dot(delta);
            if (dist2 > closeCutoff2) {
                moved.push_back(i);
                if (dist2 > farCutoff2 || (int) moved.size() > maxNumMoved) {
                    needRecompute = true;
                    break;
                }
            }
        }
        if (!needRecompute && moved.size() > 0) {
            // Some particles have moved further than half the padding distance.  Look for pairs
            // that are missing from the neighbor list.

            int numMoved = moved.size();
            double cutoff2 = data.cutoff*data.cutoff;
            double paddedCutoff2 = data.paddedCutoff*data.paddedCutoff;
            for (int i = 1; i < numMoved && !needRecompute; i++)
                for (int j = 0; j < i; j++) {
                    Vec3 delta = posData[moved[i]]-posData[moved[j]];
                    if (delta.dot(delta) < cutoff2) {
                        // These particles should interact.  See if they are in the neighbor list.
                        
                        Vec3 oldDelta = lastPositions[moved[i]]-lastPositions[moved[j]];
                        if (oldDelta.dot(oldDelta) > paddedCutoff2) {
                            needRecompute = true;
                            break;
                        }
                    }
                }
        }
        if (needRecompute) {
            data.neighborList->computeNeighborList(numParticles, data.posq, data.exclusions, extractBoxVectors(context), data.isPeriodic, data.paddedCutoff, data.threads);
            lastPositions = posData;
        }
    }
}

280
/**
 * Finish a force and energy evaluation.  The single precision forces accumulated in the
 * per-thread buffers are added to the context's forces, and then the wrapped Reference
 * kernel's finishComputation() provides the return value.
 */
double CpuCalcForcesAndEnergyKernel::finishComputation(ContextImpl& context, bool includeForce, bool includeEnergy, int groups, bool& valid) {
    data.threads.execute([&] (ThreadPool& threads, int threadIndex) {
        // Each thread reduces the buffers for its own contiguous range of particles.

        const int particleCount = context.getSystem().getNumParticles();
        const int threadCount = threads.getNumThreads();
        const int first = threadIndex*particleCount/threadCount;
        const int last = (threadIndex+1)*particleCount/threadCount;
        vector<Vec3>& forces = extractForces(context);
        for (int particle = first; particle < last; particle++) {
            fvec4 total(0.0f);
            for (int source = 0; source < threadCount; source++)
                total += fvec4(&data.threadForce[source][4*particle]);
            Vec3& target = forces[particle];
            target[0] += total[0];
            target[1] += total[1];
            target[2] += total[2];
        }
    });
    data.threads.waitForThreads();
    auto& reference = referenceKernel.getAs<ReferenceCalcForcesAndEnergyKernel>();
    return reference.finishComputation(context, includeForce, includeEnergy, groups, valid);
}

peastman's avatar
peastman committed
304
305
void CpuCalcHarmonicAngleForceKernel::initialize(const System& system, const HarmonicAngleForce& force) {
    numAngles = force.getNumAngles();
306
307
    angleIndexArray.resize(numAngles, vector<int>(3));
    angleParamArray.resize(numAngles, vector<double>(2));
peastman's avatar
peastman committed
308
309
310
311
312
313
314
    for (int i = 0; i < numAngles; ++i) {
        int particle1, particle2, particle3;
        double angle, k;
        force.getAngleParameters(i, particle1, particle2, particle3, angle, k);
        angleIndexArray[i][0] = particle1;
        angleIndexArray[i][1] = particle2;
        angleIndexArray[i][2] = particle3;
peastman's avatar
peastman committed
315
316
        angleParamArray[i][0] = angle;
        angleParamArray[i][1] = k;
peastman's avatar
peastman committed
317
318
    }
    bondForce.initialize(system.getNumParticles(), numAngles, 3, angleIndexArray, data.threads);
319
    usePeriodic = force.usesPeriodicBoundaryConditions();
peastman's avatar
peastman committed
320
321
322
}

/**
 * Evaluate the harmonic angle interactions, adding to the context's forces.
 *
 * @return the energy if includeEnergy is true, otherwise 0
 */
double CpuCalcHarmonicAngleForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
    ReferenceAngleBondIxn angleIxn;
    if (usePeriodic)
        angleIxn.setPeriodic(extractBoxVectors(context));
    vector<Vec3>& positions = extractPositions(context);
    vector<Vec3>& forces = extractForces(context);

    // The energy is only accumulated when a destination is passed.

    double totalEnergy = 0;
    double* energyTarget = NULL;
    if (includeEnergy)
        energyTarget = &totalEnergy;
    bondForce.calculateForce(positions, angleParamArray, forces, energyTarget, angleIxn);
    return totalEnergy;
}

/**
 * Update the stored (angle, k) parameters from the force.
 *
 * @throws OpenMMException if the number of angles or the particles in any angle differ
 *         from what was recorded by initialize()
 */
void CpuCalcHarmonicAngleForceKernel::copyParametersToContext(ContextImpl& context, const HarmonicAngleForce& force) {
    if (force.getNumAngles() != numAngles)
        throw OpenMMException("updateParametersInContext: The number of angles has changed");

    for (int angleIndex = 0; angleIndex < numAngles; ++angleIndex) {
        int p1, p2, p3;
        double equilibriumAngle, forceConstant;
        force.getAngleParameters(angleIndex, p1, p2, p3, equilibriumAngle, forceConstant);

        // Only the parameters may change, not which particles form the angle.

        const vector<int>& recorded = angleIndexArray[angleIndex];
        const bool sameParticles = (p1 == recorded[0] && p2 == recorded[1] && p3 == recorded[2]);
        if (!sameParticles)
            throw OpenMMException("updateParametersInContext: The set of particles in an angle has changed");
        vector<double>& params = angleParamArray[angleIndex];
        params[0] = equilibriumAngle;
        params[1] = forceConstant;
    }
}

350
351
void CpuCalcPeriodicTorsionForceKernel::initialize(const System& system, const PeriodicTorsionForce& force) {
    numTorsions = force.getNumTorsions();
352
353
    torsionIndexArray.resize(numTorsions, vector<int>(4));
    torsionParamArray.resize(numTorsions, vector<double>(3));
354
355
356
357
358
359
360
361
    for (int i = 0; i < numTorsions; ++i) {
        int particle1, particle2, particle3, particle4, periodicity;
        double phase, k;
        force.getTorsionParameters(i, particle1, particle2, particle3, particle4, periodicity, phase, k);
        torsionIndexArray[i][0] = particle1;
        torsionIndexArray[i][1] = particle2;
        torsionIndexArray[i][2] = particle3;
        torsionIndexArray[i][3] = particle4;
peastman's avatar
peastman committed
362
363
364
        torsionParamArray[i][0] = k;
        torsionParamArray[i][1] = phase;
        torsionParamArray[i][2] = periodicity;
365
366
    }
    bondForce.initialize(system.getNumParticles(), numTorsions, 4, torsionIndexArray, data.threads);
367
    usePeriodic = force.usesPeriodicBoundaryConditions();
368
369
370
}

/**
 * Evaluate the periodic torsion interactions, adding to the context's forces.
 *
 * @return the energy if includeEnergy is true, otherwise 0
 */
double CpuCalcPeriodicTorsionForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
    ReferenceProperDihedralBond torsionIxn;
    if (usePeriodic)
        torsionIxn.setPeriodic(extractBoxVectors(context));
    vector<Vec3>& positions = extractPositions(context);
    vector<Vec3>& forces = extractForces(context);

    // The energy is only accumulated when a destination is passed.

    double totalEnergy = 0;
    double* energyTarget = NULL;
    if (includeEnergy)
        energyTarget = &totalEnergy;
    bondForce.calculateForce(positions, torsionParamArray, forces, energyTarget, torsionIxn);
    return totalEnergy;
}

/**
 * Update the stored (k, phase, periodicity) parameters from the force.
 *
 * @throws OpenMMException if the number of torsions or the particles in any torsion
 *         differ from what was recorded by initialize()
 */
void CpuCalcPeriodicTorsionForceKernel::copyParametersToContext(ContextImpl& context, const PeriodicTorsionForce& force) {
    if (force.getNumTorsions() != numTorsions)
        throw OpenMMException("updateParametersInContext: The number of torsions has changed");

    for (int torsion = 0; torsion < numTorsions; ++torsion) {
        int p1, p2, p3, p4, periodicity;
        double phase, k;
        force.getTorsionParameters(torsion, p1, p2, p3, p4, periodicity, phase, k);

        // Only the parameters may change, not which particles form the torsion.

        const vector<int>& recorded = torsionIndexArray[torsion];
        const bool sameParticles = (p1 == recorded[0] && p2 == recorded[1] && p3 == recorded[2] && p4 == recorded[3]);
        if (!sameParticles)
            throw OpenMMException("updateParametersInContext: The set of particles in a torsion has changed");
        vector<double>& params = torsionParamArray[torsion];
        params[0] = k;
        params[1] = phase;
        params[2] = periodicity;
    }
}

void CpuCalcRBTorsionForceKernel::initialize(const System& system, const RBTorsionForce& force) {
    numTorsions = force.getNumTorsions();
401
402
    torsionIndexArray.resize(numTorsions, vector<int>(4));
    torsionParamArray.resize(numTorsions, vector<double>(6));
403
404
405
406
407
408
409
410
    for (int i = 0; i < numTorsions; ++i) {
        int particle1, particle2, particle3, particle4;
        double c0, c1, c2, c3, c4, c5;
        force.getTorsionParameters(i, particle1, particle2, particle3, particle4, c0, c1, c2, c3, c4, c5);
        torsionIndexArray[i][0] = particle1;
        torsionIndexArray[i][1] = particle2;
        torsionIndexArray[i][2] = particle3;
        torsionIndexArray[i][3] = particle4;
peastman's avatar
peastman committed
411
412
413
414
415
416
        torsionParamArray[i][0] = c0;
        torsionParamArray[i][1] = c1;
        torsionParamArray[i][2] = c2;
        torsionParamArray[i][3] = c3;
        torsionParamArray[i][4] = c4;
        torsionParamArray[i][5] = c5;
417
418
    }
    bondForce.initialize(system.getNumParticles(), numTorsions, 4, torsionIndexArray, data.threads);
419
    usePeriodic = force.usesPeriodicBoundaryConditions();
420
421
422
}

/**
 * Evaluate the Ryckaert-Bellemans torsion interactions, adding to the context's forces.
 *
 * @return the energy if includeEnergy is true, otherwise 0
 */
double CpuCalcRBTorsionForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
    ReferenceRbDihedralBond torsionIxn;
    if (usePeriodic)
        torsionIxn.setPeriodic(extractBoxVectors(context));
    vector<Vec3>& positions = extractPositions(context);
    vector<Vec3>& forces = extractForces(context);

    // The energy is only accumulated when a destination is passed.

    double totalEnergy = 0;
    double* energyTarget = NULL;
    if (includeEnergy)
        energyTarget = &totalEnergy;
    bondForce.calculateForce(positions, torsionParamArray, forces, energyTarget, torsionIxn);
    return totalEnergy;
}

/**
 * Update the stored coefficients (c0 to c5) from the force.
 *
 * @throws OpenMMException if the number of torsions or the particles in any torsion
 *         differ from what was recorded by initialize()
 */
void CpuCalcRBTorsionForceKernel::copyParametersToContext(ContextImpl& context, const RBTorsionForce& force) {
    if (force.getNumTorsions() != numTorsions)
        throw OpenMMException("updateParametersInContext: The number of torsions has changed");

    for (int torsion = 0; torsion < numTorsions; ++torsion) {
        int p1, p2, p3, p4;
        double coeffs[6];
        force.getTorsionParameters(torsion, p1, p2, p3, p4, coeffs[0], coeffs[1], coeffs[2], coeffs[3], coeffs[4], coeffs[5]);

        // Only the coefficients may change, not which particles form the torsion.

        const vector<int>& recorded = torsionIndexArray[torsion];
        const bool sameParticles = (p1 == recorded[0] && p2 == recorded[1] && p3 == recorded[2] && p4 == recorded[3]);
        if (!sameParticles)
            throw OpenMMException("updateParametersInContext: The set of particles in a torsion has changed");

        // Store the coefficients only once the torsion has been validated.

        vector<double>& stored = torsionParamArray[torsion];
        for (int c = 0; c < 6; c++)
            stored[c] = coeffs[c];
    }
}

454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
/**
 * The IO object handed to a CalcPmeReciprocalForceKernel.  It exposes the posq array and
 * adds the forces it is given to a force array.  Both arrays use four floats per
 * particle, of which the first three are read or written here.
 */
class CpuCalcNonbondedForceKernel::PmeIO : public CalcPmeReciprocalForceKernel::IO {
public:
    PmeIO(float* posq, float* force, int numParticles) :
            posq(posq), force(force), numParticles(numParticles) {
    }
    /**
     * Get the posq array that was passed to the constructor.
     */
    float* getPosq() {
        return posq;
    }
    /**
     * Add the x, y and z components of f to the force array.
     */
    void setForce(float* f) {
        for (int particle = 0; particle < numParticles; particle++) {
            const int base = 4*particle;
            for (int axis = 0; axis < 3; axis++)
                force[base+axis] += f[base+axis];
        }
    }
private:
    float* posq;
    float* force;
    int numParticles;
};

474
// Creates the CpuNonbondedForce used by CpuCalcNonbondedForceKernel, whose constructor stores
// the returned pointer and whose destructor deletes it.  Only declared here; the definition
// is not in this part of the file.
CpuNonbondedForce* createCpuNonbondedForceVec();
475
476

/**
 * Create the kernel and the CpuNonbondedForce object that it owns.
 */
CpuCalcNonbondedForceKernel::CpuCalcNonbondedForceKernel(string name, const Platform& platform, CpuPlatform::PlatformData& data) :
        CalcNonbondedForceKernel(name, platform),
        data(data),
        hasInitializedPme(false),
        hasInitializedDispersionPme(false),
        nonbonded(NULL) {
    nonbonded = createCpuNonbondedForceVec();
}

481
/**
 * Release the CpuNonbondedForce object created by the constructor.
 */
CpuCalcNonbondedForceKernel::~CpuCalcNonbondedForceKernel() {
    // Deleting a null pointer has no effect, so no check is required.

    delete nonbonded;
}

void CpuCalcNonbondedForceKernel::initialize(const System& system, const NonbondedForce& force) {
487
488
    chargePosqIndex = data.requestPosqIndex();
    ljPosqIndex = data.requestPosqIndex();
489
490
491

    // Identify which exceptions are 1-4 interactions.

492
493
494
495
496
497
498
499
    set<int> exceptionsWithOffsets;
    for (int i = 0; i < force.getNumExceptionParameterOffsets(); i++) {
        string param;
        int exception;
        double charge, sigma, epsilon;
        force.getExceptionParameterOffset(i, param, exception, charge, sigma, epsilon);
        exceptionsWithOffsets.insert(exception);
    }
500
501
502
    numParticles = force.getNumParticles();
    exclusions.resize(numParticles);
    vector<int> nb14s;
503
    map<int, int> nb14Index;
504
505
506
507
508
509
    for (int i = 0; i < force.getNumExceptions(); i++) {
        int particle1, particle2;
        double chargeProd, sigma, epsilon;
        force.getExceptionParameters(i, particle1, particle2, chargeProd, sigma, epsilon);
        exclusions[particle1].insert(particle2);
        exclusions[particle2].insert(particle1);
510
511
        if (chargeProd != 0.0 || epsilon != 0.0 || exceptionsWithOffsets.find(i) != exceptionsWithOffsets.end()) {
            nb14Index[i] = nb14s.size();
512
            nb14s.push_back(i);
513
        }
514
515
516
517
518
    }

    // Record the particle parameters.

    num14 = nb14s.size();
519
520
    bonded14IndexArray.resize(num14, vector<int>(2));
    bonded14ParamArray.resize(num14, vector<double>(3));
peastman's avatar
peastman committed
521
    particleParams.resize(numParticles);
522
    charges.resize(numParticles);
523
    C6params.resize(numParticles);
524
525
526
    baseParticleParams.resize(numParticles);
    for (int i = 0; i < numParticles; ++i)
       force.getParticleParameters(i, baseParticleParams[i][0], baseParticleParams[i][1], baseParticleParams[i][2]);
527
    
528
    // Record exception parameters.
529
    
530
    baseExceptionParams.resize(num14);
531
532
    for (int i = 0; i < num14; ++i) {
        int particle1, particle2;
533
        force.getExceptionParameters(nb14s[i], particle1, particle2, baseExceptionParams[i][0], baseExceptionParams[i][1], baseExceptionParams[i][2]);
534
535
        bonded14IndexArray[i][0] = particle1;
        bonded14IndexArray[i][1] = particle2;
536
    }
peastman's avatar
peastman committed
537
    bondForce.initialize(system.getNumParticles(), num14, 2, bonded14IndexArray, data.threads);
538
    
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
    // Record information about parameter offsets.
    
    hasParticleOffsets = (force.getNumParticleParameterOffsets() > 0);
    hasExceptionOffsets = (force.getNumExceptionParameterOffsets() > 0);
    particleParamOffsets.resize(force.getNumParticles());
    exceptionParamOffsets.resize(force.getNumExceptions());
    for (int i = 0; i < force.getNumParticleParameterOffsets(); i++) {
        string param;
        int particle;
        double charge, sigma, epsilon;
        force.getParticleParameterOffset(i, param, particle, charge, sigma, epsilon);
        auto paramPos = find(paramNames.begin(), paramNames.end(), param);
        int paramIndex;
        if (paramPos == paramNames.end()) {
            paramIndex = paramNames.size();
            paramNames.push_back(param);
        }
        else
            paramIndex = paramPos-paramNames.begin();
        particleParamOffsets[particle].push_back(make_tuple(charge, sigma, epsilon, paramIndex));
    }
    for (int i = 0; i < force.getNumExceptionParameterOffsets(); i++) {
        string param;
        int exception;
        double charge, sigma, epsilon;
        force.getExceptionParameterOffset(i, param, exception, charge, sigma, epsilon);
        auto paramPos = find(paramNames.begin(), paramNames.end(), param);
        int paramIndex;
        if (paramPos == paramNames.end()) {
            paramIndex = paramNames.size();
            paramNames.push_back(param);
        }
        else
            paramIndex = paramPos-paramNames.begin();
573
        exceptionParamOffsets[nb14Index[exception]].push_back(make_tuple(charge, sigma, epsilon, paramIndex));
574
    }
575
    paramValues.resize(paramNames.size(), 0.0);
576

577
578
    // Record other parameters.
    
579
580
581
582
583
    nonbondedMethod = CalcNonbondedForceKernel::NonbondedMethod(force.getNonbondedMethod());
    nonbondedCutoff = force.getCutoffDistance();
    if (nonbondedMethod == NoCutoff)
        useSwitchingFunction = false;
    else {
584
        data.requestNeighborList(nonbondedCutoff, 0.25*nonbondedCutoff, true, exclusions);
585
586
587
588
589
590
591
592
593
594
        useSwitchingFunction = force.getUseSwitchingFunction();
        switchingDistance = force.getSwitchingDistance();
    }
    if (nonbondedMethod == Ewald) {
        double alpha;
        NonbondedForceImpl::calcEwaldParameters(system, force, alpha, kmax[0], kmax[1], kmax[2]);
        ewaldAlpha = alpha;
    }
    else if (nonbondedMethod == PME) {
        double alpha;
595
        NonbondedForceImpl::calcPMEParameters(system, force, alpha, gridSize[0], gridSize[1], gridSize[2], false);
596
597
        ewaldAlpha = alpha;
    }
598
599
600
    else if (nonbondedMethod == LJPME) {
        double alpha;
        NonbondedForceImpl::calcPMEParameters(system, force, alpha, gridSize[0], gridSize[1], gridSize[2], false);
peastman's avatar
peastman committed
601
        ewaldAlpha = alpha;
602
        NonbondedForceImpl::calcPMEParameters(system, force, alpha, dispersionGridSize[0], dispersionGridSize[1], dispersionGridSize[2], true);
peastman's avatar
peastman committed
603
        ewaldDispersionAlpha = alpha;
604
605
        useSwitchingFunction = false;
    }
606
607
608
609
    if (nonbondedMethod == NoCutoff || nonbondedMethod == CutoffNonPeriodic)
        exceptionsArePeriodic = false;
    else
        exceptionsArePeriodic = force.getExceptionsUsePeriodicBoundaryConditions();
610
611
612
613
614
    rfDielectric = force.getReactionFieldDielectric();
    if (force.getUseDispersionCorrection())
        dispersionCoefficient = NonbondedForceImpl::calcDispersionCorrection(system, force);
    else
        dispersionCoefficient = 0.0;
615
    data.isPeriodic |= (nonbondedMethod == CutoffPeriodic || nonbondedMethod == Ewald || nonbondedMethod == PME || nonbondedMethod == LJPME);
616
617
618
}

/**
 * Evaluate the nonbonded interaction for the current state of a context.
 *
 * The first call checks whether the platform provides optimized reciprocal space PME kernels
 * (for PME and LJPME only) and, if so, creates and initializes them.  Every call then refreshes
 * offset-dependent parameters, configures the direct space calculator, and evaluates the
 * requested parts of the interaction.
 *
 * @param context            the context whose positions, forces and box vectors are used
 * @param includeForces      not consulted by this function
 * @param includeEnergy      whether the interaction calculators should accumulate energy
 * @param includeDirect      whether to evaluate the direct space part, the 1-4 exceptions and
 *                           the dispersion correction
 * @param includeReciprocal  whether to evaluate the reciprocal space part; also controls whether
 *                           the Ewald self energy is included in the result
 * @return the accumulated energy
 */
double CpuCalcNonbondedForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy, bool includeDirect, bool includeReciprocal) {
    const bool usingEwald = (nonbondedMethod == Ewald);
    const bool usingPme = (nonbondedMethod == PME);
    const bool usingLjPme = (nonbondedMethod == LJPME);

    if (!hasInitializedPme) {
        hasInitializedPme = true;
        useOptimizedPme = false;
        computeParameters(context, false);
        if (usingPme || usingLjPme) {
            // If available, use the optimized PME implementation.  LJPME additionally needs
            // the dispersion kernel.

            vector<string> requiredKernels;
            requiredKernels.push_back("CalcPmeReciprocalForce");
            if (usingLjPme)
                requiredKernels.push_back("CalcDispersionPmeReciprocalForce");
            useOptimizedPme = getPlatform().supportsKernels(requiredKernels);
            if (useOptimizedPme) {
                optimizedPme = getPlatform().createKernel(CalcPmeReciprocalForceKernel::Name(), context);
                optimizedPme.getAs<CalcPmeReciprocalForceKernel>().initialize(gridSize[0], gridSize[1], gridSize[2], numParticles, ewaldAlpha, data.deterministicForces);
                if (usingLjPme) {
                    optimizedDispersionPme = getPlatform().createKernel(CalcDispersionPmeReciprocalForceKernel::Name(), context);
                    optimizedDispersionPme.getAs<CalcDispersionPmeReciprocalForceKernel>().initialize(dispersionGridSize[0], dispersionGridSize[1],
                                                                                                      dispersionGridSize[2], numParticles, ewaldDispersionAlpha, data.deterministicForces);
                }
            }
        }
    }

    // Bring parameters and charges up to date, then fetch the per-context data.

    computeParameters(context, true);
    copyChargesToPosq(context, charges, chargePosqIndex);
    AlignedArray<float>& posq = data.posq;
    vector<Vec3>& posData = extractPositions(context);
    vector<Vec3>& forceData = extractForces(context);
    Vec3* boxVectors = extractBoxVectors(context);

    // Configure the direct space calculator.  The setters are called in a fixed order.

    if (nonbondedMethod != NoCutoff)
        nonbonded->setUseCutoff(nonbondedCutoff, *data.neighborList, rfDielectric);
    if (data.isPeriodic) {
        const double minAllowedSize = 1.999999*nonbondedCutoff;
        const bool boxTooSmall = (boxVectors[0][0] < minAllowedSize || boxVectors[1][1] < minAllowedSize || boxVectors[2][2] < minAllowedSize);
        if (boxTooSmall)
            throw OpenMMException("The periodic box size has decreased to less than twice the nonbonded cutoff.");
        nonbonded->setPeriodic(boxVectors);
        nonbonded->setPeriodicExceptions(exceptionsArePeriodic);
    }
    if (usingEwald)
        nonbonded->setUseEwald(ewaldAlpha, kmax[0], kmax[1], kmax[2]);
    if (usingPme)
        nonbonded->setUsePME(ewaldAlpha, gridSize);
    if (useSwitchingFunction)
        nonbonded->setUseSwitchingFunction(switchingDistance);
    if (usingLjPme) {
        nonbonded->setUsePME(ewaldAlpha, gridSize);
        nonbonded->setUseLJPME(ewaldDispersionAlpha, dispersionGridSize);
    }

    // Evaluate the direct and reciprocal space parts.

    double nonbondedEnergy = 0;
    double* nonbondedEnergyTarget = (includeEnergy ? &nonbondedEnergy : NULL);
    if (includeDirect)
        nonbonded->calculateDirectIxn(numParticles, &posq[0], posData, particleParams, C6params, exclusions, data.threadForce, nonbondedEnergyTarget, data.threads);
    if (includeReciprocal) {
        if (!useOptimizedPme)
            nonbonded->calculateReciprocalIxn(numParticles, &posq[0], posData, particleParams, C6params, exclusions, forceData, nonbondedEnergyTarget);
        else {
            PmeIO io(&posq[0], &data.threadForce[0][0], numParticles);
            Vec3 periodicBoxVectors[3] = {boxVectors[0], boxVectors[1], boxVectors[2]};
            optimizedPme.getAs<CalcPmeReciprocalForceKernel>().beginComputation(io, periodicBoxVectors, includeEnergy);
            nonbondedEnergy += optimizedPme.getAs<CalcPmeReciprocalForceKernel>().finishComputation(io);
            if (usingLjPme) {
                // The dispersion kernel reads the C6 coefficients through posq.

                copyChargesToPosq(context, C6params, ljPosqIndex);
                optimizedDispersionPme.getAs<CalcDispersionPmeReciprocalForceKernel>().beginComputation(io, periodicBoxVectors, includeEnergy);
                nonbondedEnergy += optimizedDispersionPme.getAs<CalcDispersionPmeReciprocalForceKernel>().finishComputation(io);
            }
        }
    }

    // The self energy is only counted together with the reciprocal space part.

    double energy = (includeReciprocal ? ewaldSelfEnergy : 0.0);
    energy += nonbondedEnergy;

    // Add the 1-4 exceptions and the dispersion correction.

    if (includeDirect) {
        ReferenceLJCoulomb14 nonbonded14;
        if (exceptionsArePeriodic)
            nonbonded14.setPeriodic(boxVectors);
        bondForce.calculateForce(posData, bonded14ParamArray, forceData, includeEnergy ? &energy : NULL, nonbonded14);
        if (data.isPeriodic && !usingLjPme) {
            const double volume = boxVectors[0][0]*boxVectors[1][1]*boxVectors[2][2];
            energy += dispersionCoefficient/volume;
        }
    }
    return energy;
}

/**
 * Copy updated particle and exception parameters from a NonbondedForce into this kernel.
 *
 * The number of particles and the number of non-excluded exceptions must match the values
 * already held by the kernel; otherwise an OpenMMException is thrown.
 *
 * @param context  the context the parameters are being copied into
 * @param force    the force to read the new parameters from
 */
void CpuCalcNonbondedForceKernel::copyParametersToContext(ContextImpl& context, const NonbondedForce& force) {
    if (force.getNumParticles() != numParticles)
        throw OpenMMException("updateParametersInContext: The number of particles has changed");

    // Find the exceptions that have parameter offsets.  They are treated as 1-4 interactions
    // even if their base charge product and epsilon are both zero.

    set<int> offsetExceptions;
    const int numOffsets = force.getNumExceptionParameterOffsets();
    for (int offsetIndex = 0; offsetIndex < numOffsets; offsetIndex++) {
        string parameterName;
        int exceptionIndex;
        double chargeScale, sigmaScale, epsilonScale;
        force.getExceptionParameterOffset(offsetIndex, parameterName, exceptionIndex, chargeScale, sigmaScale, epsilonScale);
        offsetExceptions.insert(exceptionIndex);
    }

    // Identify which exceptions are 1-4 interactions.

    vector<int> active14s;
    const int numExceptions = force.getNumExceptions();
    for (int exceptionIndex = 0; exceptionIndex < numExceptions; exceptionIndex++) {
        int atomA, atomB;
        double chargeProd, sigma, epsilon;
        force.getExceptionParameters(exceptionIndex, atomA, atomB, chargeProd, sigma, epsilon);
        const bool isPureExclusion = (chargeProd == 0.0 && epsilon == 0.0 && offsetExceptions.count(exceptionIndex) == 0);
        if (!isPureExclusion)
            active14s.push_back(exceptionIndex);
    }
    if (active14s.size() != num14)
        throw OpenMMException("updateParametersInContext: The number of non-excluded exceptions has changed");

    // Record the base (offset-free) values.

    for (int atom = 0; atom < numParticles; ++atom)
        force.getParticleParameters(atom, baseParticleParams[atom][0], baseParticleParams[atom][1], baseParticleParams[atom][2]);
    for (int slot = 0; slot < num14; ++slot) {
        int atomA, atomB;
        force.getExceptionParameters(active14s[slot], atomA, atomB, baseExceptionParams[slot][0], baseExceptionParams[slot][1], baseExceptionParams[slot][2]);
        bonded14IndexArray[slot][0] = atomA;
        bonded14IndexArray[slot][1] = atomB;
    }

    // Rebuild the derived parameters from the new base values.

    computeParameters(context, false);

    // Recompute the coefficient for the dispersion correction.

    const NonbondedForce::NonbondedMethod method = force.getNonbondedMethod();
    const bool methodUsesCorrection = (method == NonbondedForce::CutoffPeriodic || method == NonbondedForce::Ewald || method == NonbondedForce::PME);
    if (methodUsesCorrection && force.getUseDispersionCorrection())
        dispersionCoefficient = NonbondedForceImpl::calcDispersionCorrection(context.getSystem(), force);
}
755

756
void CpuCalcNonbondedForceKernel::getPMEParameters(double& alpha, int& nx, int& ny, int& nz) const {
757
    if (nonbondedMethod != PME && nonbondedMethod != LJPME)
758
759
760
761
762
763
764
765
766
767
768
        throw OpenMMException("getPMEParametersInContext: This Context is not using PME");
    if (useOptimizedPme)
        optimizedPme.getAs<const CalcPmeReciprocalForceKernel>().getPMEParameters(alpha, nx, ny, nz);
    else {
        alpha = ewaldAlpha;
        nx = gridSize[0];
        ny = gridSize[1];
        nz = gridSize[2];
    }
}

769
void CpuCalcNonbondedForceKernel::getLJPMEParameters(double& alpha, int& nx, int& ny, int& nz) const {
770
771
772
    if (nonbondedMethod != LJPME)
        throw OpenMMException("getPMEParametersInContext: This Context is not using PME");
    if (useOptimizedPme)
773
        optimizedDispersionPme.getAs<const CalcPmeReciprocalForceKernel>().getPMEParameters(alpha, nx, ny, nz);
774
    else {
775
776
777
778
        alpha = ewaldDispersionAlpha;
        nx = dispersionGridSize[0];
        ny = dispersionGridSize[1];
        nz = dispersionGridSize[2];
779
780
781
    }
}

782
/**
 * Recompute the effective particle and exception parameters from their base values plus any
 * parameter offsets, using the current values of the global parameters in the context.
 *
 * @param context      the context to read global parameter values from
 * @param offsetsOnly  if true, nothing is done unless a global parameter has changed, and then
 *                     only the parameter sets that actually have offsets are recomputed; if
 *                     false, everything is recomputed unconditionally
 */
void CpuCalcNonbondedForceKernel::computeParameters(ContextImpl& context, bool offsetsOnly) {
    // Refresh the cached global parameter values and note whether any of them changed.

    bool anyParamChanged = false;
    for (int index = 0; index < paramNames.size(); index++) {
        const double current = context.getParameter(paramNames[index]);
        if (current != paramValues[index]) {
            paramValues[index] = current;
            anyParamChanged = true;
        }
    }
    if (offsetsOnly && !anyParamChanged)
        return;

    // Compute particle parameters.

    if (hasParticleOffsets || !offsetsOnly) {
        double chargeSquaredSum = 0.0;
        for (int atom = 0; atom < numParticles; atom++) {
            double charge = baseParticleParams[atom][0];
            double sigma = baseParticleParams[atom][1];
            double epsilon = baseParticleParams[atom][2];
            for (auto& offset : particleParamOffsets[atom]) {
                // Each offset holds three scale factors and the index of the global parameter
                // they multiply.

                const double paramValue = paramValues[get<3>(offset)];
                charge += paramValue*get<0>(offset);
                sigma += paramValue*get<1>(offset);
                epsilon += paramValue*get<2>(offset);
            }
            charges[atom] = (float) charge;
            particleParams[atom] = make_pair((float) (0.5*sigma), (float) (2.0*sqrt(epsilon)));
            C6params[atom] = 8.0*pow(particleParams[atom].first, 3.0) * particleParams[atom].second;
            chargeSquaredSum += charge*charge;
        }

        // The self energy only exists for the Ewald-type methods.

        const bool usesEwaldSum = (nonbondedMethod == Ewald || nonbondedMethod == PME || nonbondedMethod == LJPME);
        if (!usesEwaldSum)
            ewaldSelfEnergy = 0.0;
        else {
            ewaldSelfEnergy = -ONE_4PI_EPS0*ewaldAlpha*chargeSquaredSum/sqrt(M_PI);
            if (nonbondedMethod == LJPME) {
                const double alphaToSixth = pow(ewaldDispersionAlpha, 6.0);
                for (int atom = 0; atom < numParticles; atom++)
                    ewaldSelfEnergy += alphaToSixth * C6params[atom]*C6params[atom] / 12.0;
            }
        }

        // Request fresh posq indices.
        // NOTE(review): presumably this makes copyChargesToPosq() recopy the new values - confirm.

        chargePosqIndex = data.requestPosqIndex();
        ljPosqIndex = data.requestPosqIndex();
    }

    // Compute exception parameters.

    if (hasExceptionOffsets || !offsetsOnly) {
        for (int slot = 0; slot < num14; slot++) {
            double chargeProd = baseExceptionParams[slot][0];
            double sigma = baseExceptionParams[slot][1];
            double epsilon = baseExceptionParams[slot][2];
            for (auto& offset : exceptionParamOffsets[slot]) {
                const double paramValue = paramValues[get<3>(offset)];
                chargeProd += paramValue*get<0>(offset);
                sigma += paramValue*get<1>(offset);
                epsilon += paramValue*get<2>(offset);
            }
            bonded14ParamArray[slot][0] = sigma;
            bonded14ParamArray[slot][1] = 4.0*epsilon;
            bonded14ParamArray[slot][2] = chargeProd;
        }
    }
}

845
/**
 * Create the kernel.  The force copy and the interaction calculator both start out null.
 *
 * @param name      the name of the kernel
 * @param platform  the platform the kernel belongs to
 * @param data      the platform data shared between the CPU kernels
 */
CpuCalcCustomNonbondedForceKernel::CpuCalcCustomNonbondedForceKernel(string name, const Platform& platform, CpuPlatform::PlatformData& data) :
        CalcCustomNonbondedForceKernel(name, platform),
        data(data),
        forceCopy(nullptr),
        nonbonded(nullptr) {
}

849
/**
 * Release the interaction calculator and the stored copy of the force.
 */
CpuCalcCustomNonbondedForceKernel::~CpuCalcCustomNonbondedForceKernel() {
    // Deleting a null pointer is a no-op, so no explicit null checks are needed.

    delete nonbonded;
    delete forceCopy;
}

void CpuCalcCustomNonbondedForceKernel::initialize(const System& system, const CustomNonbondedForce& force) {

    // Record the exclusions.

    numParticles = force.getNumParticles();
    exclusions.resize(numParticles);
    for (int i = 0; i < force.getNumExclusions(); i++) {
        int particle1, particle2;
        force.getExclusionParticles(i, particle1, particle2);
        exclusions[particle1].insert(particle2);
        exclusions[particle2].insert(particle1);
    }

    // Build the arrays.

    int numParameters = force.getNumPerParticleParameters();
872
873
874
    particleParamArray.resize(numParticles);
    for (int i = 0; i < numParticles; ++i)
        force.getParticleParameters(i, particleParamArray[i]);
875
876
877
878
879
    nonbondedMethod = CalcCustomNonbondedForceKernel::NonbondedMethod(force.getNonbondedMethod());
    nonbondedCutoff = force.getCutoffDistance();
    if (nonbondedMethod == NoCutoff)
        useSwitchingFunction = false;
    else {
880
        data.requestNeighborList(nonbondedCutoff, 0.25*nonbondedCutoff, true, exclusions);
881
882
883
884
885
886
887
888
889
890
891
892
893
        useSwitchingFunction = force.getUseSwitchingFunction();
        switchingDistance = force.getSwitchingDistance();
    }

    // Create custom functions for the tabulated functions.

    map<string, Lepton::CustomFunction*> functions;
    for (int i = 0; i < force.getNumFunctions(); i++)
        functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i));

    // Parse the various expressions used to calculate the force.

    Lepton::ParsedExpression expression = Lepton::Parser::parse(force.getEnergyFunction(), functions).optimize();
894
895
    Lepton::CompiledExpression energyExpression = expression.createCompiledExpression();
    Lepton::CompiledExpression forceExpression = expression.differentiate("r").createCompiledExpression();
896
897
898
899
900
901
    for (int i = 0; i < numParameters; i++)
        parameterNames.push_back(force.getPerParticleParameterName(i));
    for (int i = 0; i < force.getNumGlobalParameters(); i++) {
        globalParameterNames.push_back(force.getGlobalParameterName(i));
        globalParamValues[force.getGlobalParameterName(i)] = force.getGlobalParameterDefaultValue(i);
    }
902
903
904
905
906
907
    std::vector<Lepton::CompiledExpression> energyParamDerivExpressions;
    for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) {
        string param = force.getEnergyParameterDerivativeName(i);
        energyParamDerivNames.push_back(param);
        energyParamDerivExpressions.push_back(expression.differentiate(param).createCompiledExpression());
    }
908
909
910
911
912
913
914
915
    set<string> variables;
    variables.insert("r");
    for (int i = 0; i < numParameters; i++) {
        variables.insert(parameterNames[i]+"1");
        variables.insert(parameterNames[i]+"2");
    }
    variables.insert(globalParameterNames.begin(), globalParameterNames.end());
    validateVariables(expression.getRootNode(), variables);
916
917
918

    // Delete the custom functions.

peastman's avatar
peastman committed
919
920
    for (auto& function : functions)
        delete function.second;
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
    
    // Record information for the long range correction.
    
    if (force.getNonbondedMethod() == CustomNonbondedForce::CutoffPeriodic && force.getUseLongRangeCorrection()) {
        forceCopy = new CustomNonbondedForce(force);
        hasInitializedLongRangeCorrection = false;
    }
    else {
        longRangeCoefficient = 0.0;
        hasInitializedLongRangeCorrection = true;
    }
    
    // Record the interaction groups.
    
    for (int i = 0; i < force.getNumInteractionGroups(); i++) {
        set<int> set1, set2;
        force.getInteractionGroupParameters(i, set1, set2);
        interactionGroups.push_back(make_pair(set1, set2));
    }
940
    data.isPeriodic |= (nonbondedMethod == CutoffPeriodic);
941
    nonbonded = new CpuCustomNonbondedForce(energyExpression, forceExpression, parameterNames, exclusions, energyParamDerivExpressions, data.threads);
peastman's avatar
Bug fix  
peastman committed
942
943
    if (interactionGroups.size() > 0)
        nonbonded->setInteractionGroups(interactionGroups);
944
945
946
}

/**
 * Evaluate the custom nonbonded interaction, including the long range correction.
 *
 * @param context        the context in which to execute this kernel
 * @param includeForces  passed on to the interaction calculator
 * @param includeEnergy  passed on to the interaction calculator
 * @return the energy reported by the calculator plus the long range correction
 */
double CpuCalcCustomNonbondedForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
    vector<Vec3>& posData = extractPositions(context);
    vector<Vec3>& forceData = extractForces(context);
    Vec3* boxVectors = extractBoxVectors(context);

    // Configure the cutoff and periodicity.

    if (nonbondedMethod != NoCutoff)
        nonbonded->setUseCutoff(nonbondedCutoff, *data.neighborList);
    if (nonbondedMethod == CutoffPeriodic) {
        const double minAllowedSize = 2*nonbondedCutoff;
        const bool boxTooSmall = (boxVectors[0][0] < minAllowedSize || boxVectors[1][1] < minAllowedSize || boxVectors[2][2] < minAllowedSize);
        if (boxTooSmall)
            throw OpenMMException("The periodic box size has decreased to less than twice the nonbonded cutoff.");
        nonbonded->setPeriodic(boxVectors);
    }

    // Refresh the global parameter values, remembering whether any of them changed.

    bool globalParamsChanged = false;
    for (auto& name : globalParameterNames) {
        const double current = context.getParameter(name);
        double& stored = globalParamValues[name];
        if (stored != current)
            globalParamsChanged = true;
        stored = current;
    }
    if (useSwitchingFunction)
        nonbonded->setUseSwitchingFunction(switchingDistance);

    // Evaluate the pairwise interaction and accumulate the parameter derivatives.  The buffer
    // has one extra element so that taking the address of element 0 is always valid.

    double energy = 0;
    vector<double> energyParamDerivValues(energyParamDerivNames.size()+1, 0.0);
    nonbonded->calculatePairIxn(numParticles, &data.posq[0], posData, particleParamArray, globalParamValues, data.threadForce, includeForces, includeEnergy, energy, &energyParamDerivValues[0]);
    map<string, double>& energyParamDerivs = extractEnergyParameterDerivatives(context);
    for (int index = 0; index < energyParamDerivNames.size(); index++)
        energyParamDerivs[energyParamDerivNames[index]] += energyParamDerivValues[index];

    // Add in the long range correction.  The coefficient is computed on first use, and again
    // whenever a global parameter changes while the correction is enabled.

    const bool firstUse = !hasInitializedLongRangeCorrection;
    if (firstUse)
        longRangeCorrectionData = CustomNonbondedForceImpl::prepareLongRangeCorrection(*forceCopy);
    if (firstUse || (globalParamsChanged && forceCopy != NULL)) {
        CustomNonbondedForceImpl::calcLongRangeCorrection(*forceCopy, longRangeCorrectionData, context.getOwner(), longRangeCoefficient, longRangeCoefficientDerivs, data.threads);
        hasInitializedLongRangeCorrection = true;
    }
    const double volume = boxVectors[0][0]*boxVectors[1][1]*boxVectors[2][2];
    energy += longRangeCoefficient/volume;
    for (int index = 0; index < longRangeCoefficientDerivs.size(); index++)
        energyParamDerivs[energyParamDerivNames[index]] += longRangeCoefficientDerivs[index]/volume;
    return energy;
}

/**
 * Copy updated per-particle parameters from a CustomNonbondedForce into this kernel and, when
 * the long range correction is enabled, recompute it from the updated force.
 *
 * @param context  the context the parameters are being copied into
 * @param force    the force to read the new parameters from
 * @throws OpenMMException if the number of particles differs from the one recorded earlier
 */
void CpuCalcCustomNonbondedForceKernel::copyParametersToContext(ContextImpl& context, const CustomNonbondedForce& force) {
    if (numParticles != force.getNumParticles())
        throw OpenMMException("updateParametersInContext: The number of particles has changed");

    // Record the values.

    const int numParameters = force.getNumPerParticleParameters();
    for (int atom = 0; atom < numParticles; ++atom) {
        vector<double> atomParameters;
        force.getParticleParameters(atom, atomParameters);
        for (int param = 0; param < numParameters; param++)
            particleParamArray[atom][param] = atomParameters[param];
    }

    // If necessary, recompute the long range correction.  A non-null force copy indicates
    // that the correction is enabled.

    if (forceCopy == NULL)
        return;
    longRangeCorrectionData = CustomNonbondedForceImpl::prepareLongRangeCorrection(force);
    CustomNonbondedForceImpl::calcLongRangeCorrection(force, longRangeCorrectionData, context.getOwner(), longRangeCoefficient, longRangeCoefficientDerivs, data.threads);
    hasInitializedLongRangeCorrection = true;
    *forceCopy = force;
}

1016
1017
1018
1019
/**
 * Nothing to release explicitly.
 */
CpuCalcGBSAOBCForceKernel::~CpuCalcGBSAOBCForceKernel() = default;

void CpuCalcGBSAOBCForceKernel::initialize(const System& system, const GBSAOBCForce& force) {
1020
    posqIndex = data.requestPosqIndex();
1021
1022
    int numParticles = system.getNumParticles();
    particleParams.resize(numParticles);
1023
    charges.resize(numParticles);
1024
1025
1026
    for (int i = 0; i < numParticles; ++i) {
        double charge, radius, scalingFactor;
        force.getParticleParameters(i, charge, radius, scalingFactor);
1027
        charges[i] = (float) charge;
1028
1029
        radius -= 0.009;
        particleParams[i] = make_pair((float) radius, (float) (scalingFactor*radius));
1030
1031
1032
1033
    }
    obc.setParticleParameters(particleParams);
    obc.setSolventDielectric((float) force.getSolventDielectric());
    obc.setSoluteDielectric((float) force.getSoluteDielectric());
1034
    obc.setSurfaceAreaEnergy((float) force.getSurfaceAreaEnergy());
1035
1036
    if (force.getNonbondedMethod() != GBSAOBCForce::NoCutoff)
        obc.setUseCutoff((float) force.getCutoffDistance());
1037
    data.isPeriodic |= (force.getNonbondedMethod() == GBSAOBCForce::CutoffPeriodic);
1038
1039
1040
}

/**
 * Evaluate the GBSA OBC interaction.
 *
 * @param context        the context in which to execute this kernel
 * @param includeForces  not consulted by this function
 * @param includeEnergy  whether the calculator should accumulate the energy
 * @return the energy written by the calculator, or 0 if includeEnergy is false
 */
double CpuCalcGBSAOBCForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
    copyChargesToPosq(context, charges, posqIndex);

    // The calculator takes the box size in single precision.

    if (data.isPeriodic) {
        Vec3& boxSize = extractBoxSize(context);
        float boxSizeAsFloat[3];
        for (int axis = 0; axis < 3; axis++)
            boxSizeAsFloat[axis] = (float) boxSize[axis];
        obc.setPeriodic(boxSizeAsFloat);
    }

    double energy = 0.0;
    double* energyTarget = (includeEnergy ? &energy : NULL);
    obc.computeForce(data.posq, data.threadForce, energyTarget, data.threads);
    return energy;
}

/**
 * Copy updated per-particle parameters from a GBSAOBCForce into this kernel.
 *
 * @param context  the context the parameters are being copied into
 * @param force    the force to read the new parameters from
 * @throws OpenMMException if the number of particles differs from the number already held by
 *                         the calculator
 */
void CpuCalcGBSAOBCForceKernel::copyParametersToContext(ContextImpl& context, const GBSAOBCForce& force) {
    const int particleCount = force.getNumParticles();
    if (particleCount != obc.getParticleParameters().size())
        throw OpenMMException("updateParametersInContext: The number of particles has changed");

    // Record the values.  A fresh posq index is requested because the charges may change.
    // NOTE(review): presumably this makes copyChargesToPosq() recopy the charges - confirm.

    posqIndex = data.requestPosqIndex();
    for (int atom = 0; atom < particleCount; ++atom) {
        double charge, radius, scalingFactor;
        force.getParticleParameters(atom, charge, radius, scalingFactor);
        charges[atom] = (float) charge;

        // The stored radius is reduced by 0.009, and the scaled radius is derived from the
        // reduced value.

        radius -= 0.009;
        particleParams[atom] = make_pair((float) radius, (float) (scalingFactor*radius));
    }
    obc.setParticleParameters(particleParams);
}
1069

1070
/**
 * Release the interaction calculator and the neighbor list.
 */
CpuCalcCustomGBForceKernel::~CpuCalcCustomGBForceKernel() {
    // Deleting a null pointer is a no-op, so no explicit null checks are needed.

    delete ixn;
    delete neighborList;
}

void CpuCalcCustomGBForceKernel::initialize(const System& system, const CustomGBForce& force) {
    if (force.getNumComputedValues() > 0) {
        string name, expression;
        CustomGBForce::ComputationType type;
        force.getComputedValueParameters(0, name, expression, type);
        if (type == CustomGBForce::SingleParticle)
            throw OpenMMException("CpuPlatform requires that the first computed value for a CustomGBForce be of type ParticlePair or ParticlePairNoExclusions.");
        for (int i = 1; i < force.getNumComputedValues(); i++) {
            force.getComputedValueParameters(i, name, expression, type);
            if (type != CustomGBForce::SingleParticle)
                throw OpenMMException("CpuPlatform requires that a CustomGBForce only have one computed value of type ParticlePair or ParticlePairNoExclusions.");
        }
    }

    // Record the exclusions.

    numParticles = force.getNumParticles();
    exclusions.resize(numParticles);
    for (int i = 0; i < force.getNumExclusions(); i++) {
        int particle1, particle2;
        force.getExclusionParticles(i, particle1, particle2);
        exclusions[particle1].insert(particle2);
        exclusions[particle2].insert(particle1);
    }

    // Build the arrays.

    int numPerParticleParameters = force.getNumPerParticleParameters();
1105
1106
1107
    particleParamArray.resize(numParticles);
    for (int i = 0; i < numParticles; ++i)
        force.getParticleParameters(i, particleParamArray[i]);
1108
1109
1110
1111
1112
    for (int i = 0; i < numPerParticleParameters; i++)
        particleParameterNames.push_back(force.getPerParticleParameterName(i));
    for (int i = 0; i < force.getNumGlobalParameters(); i++)
        globalParameterNames.push_back(force.getGlobalParameterName(i));
    nonbondedMethod = CalcCustomGBForceKernel::NonbondedMethod(force.getNonbondedMethod());
peastman's avatar
peastman committed
1113
    nonbondedCutoff = force.getCutoffDistance();
1114
    if (nonbondedMethod != NoCutoff)
Peter Eastman's avatar
Peter Eastman committed
1115
        neighborList = new CpuNeighborList(4);
1116
1117
1118
1119
1120
1121
1122
1123
1124

    // Create custom functions for the tabulated functions.

    map<string, Lepton::CustomFunction*> functions;
    for (int i = 0; i < force.getNumFunctions(); i++)
        functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i));

    // Parse the expressions for computed values.

1125
1126
    vector<vector<Lepton::CompiledExpression> > valueDerivExpressions(force.getNumComputedValues());
    vector<vector<Lepton::CompiledExpression> > valueGradientExpressions(force.getNumComputedValues());
1127
    vector<vector<Lepton::CompiledExpression> > valueParamDerivExpressions(force.getNumComputedValues());
1128
1129
    vector<Lepton::CompiledExpression> valueExpressions;
    vector<Lepton::CompiledExpression> energyExpressions;
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
    set<string> particleVariables, pairVariables;
    pairVariables.insert("r");
    particleVariables.insert("x");
    particleVariables.insert("y");
    particleVariables.insert("z");
    for (int i = 0; i < numPerParticleParameters; i++) {
        particleVariables.insert(particleParameterNames[i]);
        pairVariables.insert(particleParameterNames[i]+"1");
        pairVariables.insert(particleParameterNames[i]+"2");
    }
    particleVariables.insert(globalParameterNames.begin(), globalParameterNames.end());
    pairVariables.insert(globalParameterNames.begin(), globalParameterNames.end());
1142
1143
1144
1145
1146
1147
1148
1149
    for (int i = 0; i < force.getNumComputedValues(); i++) {
        string name, expression;
        CustomGBForce::ComputationType type;
        force.getComputedValueParameters(i, name, expression, type);
        Lepton::ParsedExpression ex = Lepton::Parser::parse(expression, functions).optimize();
        valueExpressions.push_back(ex.createCompiledExpression());
        valueTypes.push_back(type);
        valueNames.push_back(name);
1150
        if (i == 0) {
1151
            valueDerivExpressions[i].push_back(ex.differentiate("r").createCompiledExpression());
1152
1153
            validateVariables(ex.getRootNode(), pairVariables);
        }
1154
1155
1156
1157
1158
1159
        else {
            valueGradientExpressions[i].push_back(ex.differentiate("x").createCompiledExpression());
            valueGradientExpressions[i].push_back(ex.differentiate("y").createCompiledExpression());
            valueGradientExpressions[i].push_back(ex.differentiate("z").createCompiledExpression());
            for (int j = 0; j < i; j++)
                valueDerivExpressions[i].push_back(ex.differentiate(valueNames[j]).createCompiledExpression());
1160
            validateVariables(ex.getRootNode(), particleVariables);
1161
        }
1162
1163
1164
1165
1166
        for (int j = 0; j < force.getNumEnergyParameterDerivatives(); j++) {
            string param = force.getEnergyParameterDerivativeName(j);
            energyParamDerivNames.push_back(param);
            valueParamDerivExpressions[i].push_back(ex.differentiate(param).createCompiledExpression());
        }
1167
1168
1169
        particleVariables.insert(name);
        pairVariables.insert(name+"1");
        pairVariables.insert(name+"2");
1170
1171
1172
1173
    }

    // Parse the expressions for energy terms.

1174
1175
    vector<vector<Lepton::CompiledExpression> > energyDerivExpressions(force.getNumEnergyTerms());
    vector<vector<Lepton::CompiledExpression> > energyGradientExpressions(force.getNumEnergyTerms());
1176
    vector<vector<Lepton::CompiledExpression> > energyParamDerivExpressions(force.getNumEnergyTerms());
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
    for (int i = 0; i < force.getNumEnergyTerms(); i++) {
        string expression;
        CustomGBForce::ComputationType type;
        force.getEnergyTermParameters(i, expression, type);
        Lepton::ParsedExpression ex = Lepton::Parser::parse(expression, functions).optimize();
        energyExpressions.push_back(ex.createCompiledExpression());
        energyTypes.push_back(type);
        if (type != CustomGBForce::SingleParticle)
            energyDerivExpressions[i].push_back(ex.differentiate("r").createCompiledExpression());
        for (int j = 0; j < force.getNumComputedValues(); j++) {
            if (type == CustomGBForce::SingleParticle) {
                energyDerivExpressions[i].push_back(ex.differentiate(valueNames[j]).createCompiledExpression());
                energyGradientExpressions[i].push_back(ex.differentiate("x").createCompiledExpression());
                energyGradientExpressions[i].push_back(ex.differentiate("y").createCompiledExpression());
                energyGradientExpressions[i].push_back(ex.differentiate("z").createCompiledExpression());
1192
                validateVariables(ex.getRootNode(), particleVariables);
1193
1194
1195
1196
            }
            else {
                energyDerivExpressions[i].push_back(ex.differentiate(valueNames[j]+"1").createCompiledExpression());
                energyDerivExpressions[i].push_back(ex.differentiate(valueNames[j]+"2").createCompiledExpression());
1197
                validateVariables(ex.getRootNode(), pairVariables);
1198
1199
            }
        }
1200
1201
        for (int j = 0; j < force.getNumEnergyParameterDerivatives(); j++)
            energyParamDerivExpressions[i].push_back(ex.differentiate(force.getEnergyParameterDerivativeName(j)).createCompiledExpression());
1202
1203
1204
1205
    }

    // Delete the custom functions.

peastman's avatar
peastman committed
1206
1207
    for (auto& function : functions)
        delete function.second;
1208
1209
1210
    ixn = new CpuCustomGBForce(numParticles, exclusions, valueExpressions, valueDerivExpressions, valueGradientExpressions, valueParamDerivExpressions,
        valueNames, valueTypes, energyExpressions, energyDerivExpressions, energyGradientExpressions, energyParamDerivExpressions, energyTypes,
        particleParameterNames, data.threads);
1211
    data.isPeriodic |= (force.getNonbondedMethod() == CustomGBForce::CutoffPeriodic);
1212
1213
1214
}

// Evaluate the CustomGBForce interaction for the current state of the context.
//
// context        supplies the box vectors/size, the global parameter values and the
//                map that energy parameter derivatives are accumulated into
// includeForces  forwarded unchanged to the interaction object
// includeEnergy  forwarded unchanged to the interaction object
//
// Returns the value of the local energy accumulator after calculateIxn() has run
// (it starts at zero).
double CpuCalcCustomGBForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
    double energy = 0;
    Vec3* boxVectors = extractBoxVectors(context);
    if (data.isPeriodic)
        ixn->setPeriodic(extractBoxSize(context));
    if (nonbondedMethod != NoCutoff) {
        // The neighbor list is rebuilt on every call, with an empty exclusion set for
        // every particle.
        vector<set<int> > noExclusions(numParticles);
        neighborList->computeNeighborList(numParticles, data.posq, noExclusions, boxVectors, data.isPeriodic, nonbondedCutoff, data.threads);
        ixn->setUseCutoff(nonbondedCutoff, *neighborList);
    }

    // Look up the current value of every global parameter.

    map<string, double> globalParameters;
    for (auto& name : globalParameterNames)
        globalParameters[name] = context.getParameter(name);

    // The buffer has one more slot than there are derivative names, so taking the
    // address of element 0 below is valid even when no derivatives were requested.

    vector<double> energyParamDerivValues(energyParamDerivNames.size()+1, 0.0);
    ixn->calculateIxn(numParticles, &data.posq[0], particleParamArray, globalParameters, data.threadForce, includeForces, includeEnergy, energy, &energyParamDerivValues[0]);

    // Add this force's contribution to the derivatives already stored for the context.

    map<string, double>& energyParamDerivs = extractEnergyParameterDerivatives(context);
    for (size_t i = 0; i < energyParamDerivNames.size(); i++)
        energyParamDerivs[energyParamDerivNames[i]] += energyParamDerivValues[i];
    return energy;
}

// Refresh the cached per-particle parameter values from the given force.
//
// context  not used by this implementation
// force    the CustomGBForce to read the per-particle parameters from
//
// Throws OpenMMException if the force's particle count differs from the count
// stored in this kernel.
void CpuCalcCustomGBForceKernel::copyParametersToContext(ContextImpl& context, const CustomGBForce& force) {
    if (numParticles != force.getNumParticles())
        throw OpenMMException("updateParametersInContext: The number of particles has changed");

    // Record the values.

    const int numParameters = force.getNumPerParticleParameters();
    for (int i = 0; i < numParticles; ++i) {
        vector<double> parameters;
        force.getParticleParameters(i, parameters);
        for (int j = 0; j < numParameters; j++)
            particleParamArray[i][j] = parameters[j];
    }
}

1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
// Destroy the interaction object held by this kernel.
CpuCalcCustomManyParticleForceKernel::~CpuCalcCustomManyParticleForceKernel() {
    // Deleting a null pointer does nothing, so no explicit null test is required.
    delete ixn;
}

void CpuCalcCustomManyParticleForceKernel::initialize(const System& system, const CustomManyParticleForce& force) {

    // Build the arrays.

    numParticles = system.getNumParticles();
1262
    particleParamArray.resize(numParticles);
1263
1264
    for (int i = 0; i < numParticles; ++i) {
        int type;
1265
        force.getParticleParameters(i, particleParamArray[i], type);
1266
1267
1268
    }
    for (int i = 0; i < force.getNumGlobalParameters(); i++)
        globalParameterNames.push_back(force.getGlobalParameterName(i));
1269
    ixn = new CpuCustomManyParticleForce(force, data.threads);
1270
1271
    nonbondedMethod = CalcCustomManyParticleForceKernel::NonbondedMethod(force.getNonbondedMethod());
    cutoffDistance = force.getCutoffDistance();
1272
    data.isPeriodic |= (nonbondedMethod == CutoffPeriodic);
1273
1274
1275
1276
}

// Evaluate the CustomManyParticleForce interaction for the current state of the context.
//
// context        supplies the global parameter values and the periodic box vectors
// includeForces  forwarded unchanged to the interaction object
// includeEnergy  forwarded unchanged to the interaction object
//
// Returns the value of the local energy accumulator after calculateIxn() has run.
// Throws OpenMMException if, with a periodic cutoff, any diagonal element of the box
// vectors is smaller than twice the cutoff distance.
double CpuCalcCustomManyParticleForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {

    // Look up the current value of every global parameter.

    map<string, double> globalParameters;
    for (auto& parameterName : globalParameterNames)
        globalParameters[parameterName] = context.getParameter(parameterName);

    // With periodic boundary conditions, verify the box size before handing it over.

    if (nonbondedMethod == CutoffPeriodic) {
        Vec3* boxVectors = extractBoxVectors(context);
        const double minAllowedSize = 2*cutoffDistance;
        for (int axis = 0; axis < 3; axis++) {
            if (boxVectors[axis][axis] < minAllowedSize)
                throw OpenMMException("The periodic box size has decreased to less than twice the nonbonded cutoff.");
        }
        ixn->setPeriodic(boxVectors);
    }

    double energy = 0;
    ixn->calculateIxn(data.posq, particleParamArray, globalParameters, data.threadForce, includeForces, includeEnergy, energy);
    return energy;
}

// Refresh the cached per-particle parameter values from the given force.
//
// context  not used by this implementation
// force    the CustomManyParticleForce to read the per-particle parameters from
//
// Throws OpenMMException if the force's particle count differs from the count
// stored in this kernel.
void CpuCalcCustomManyParticleForceKernel::copyParametersToContext(ContextImpl& context, const CustomManyParticleForce& force) {
    if (numParticles != force.getNumParticles())
        throw OpenMMException("updateParametersInContext: The number of particles has changed");

    // Record the values.  The particle type returned by the query is not stored.

    const int numParameters = force.getNumPerParticleParameters();
    for (int i = 0; i < numParticles; ++i) {
        vector<double> parameters;
        int type;
        force.getParticleParameters(i, parameters, type);
        for (int j = 0; j < numParameters; j++)
            particleParamArray[i][j] = parameters[j];
    }
}

1308
1309
1310
1311
1312
1313
1314
// Destroy the interaction object held by this kernel.
CpuCalcGayBerneForceKernel::~CpuCalcGayBerneForceKernel() {
    // Deleting a null pointer does nothing, so no explicit null test is required.
    delete ixn;
}

void CpuCalcGayBerneForceKernel::initialize(const System& system, const GayBerneForce& force) {
    ixn = new CpuGayBerneForce(force);
1315
    data.isPeriodic |= (force.getNonbondedMethod() == GayBerneForce::CutoffPeriodic);
1316
1317
1318
1319
    if (force.getNonbondedMethod() != GayBerneForce::NoCutoff) {
        double cutoff = force.getCutoffDistance();
        data.requestNeighborList(cutoff, 0.1*cutoff, true, ixn->getExclusions());
    }
1320
1321
1322
}

// Evaluate the GayBerneForce interaction for the current state of the context.
//
// context        supplies the positions, forces and box vectors
// includeForces  not consulted by this implementation
// includeEnergy  not consulted by this implementation
//
// Returns whatever the interaction object's calculateForce() returns.
double CpuCalcGayBerneForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
    vector<Vec3>& positions = extractPositions(context);
    vector<Vec3>& forces = extractForces(context);
    Vec3* boxVectors = extractBoxVectors(context);
    return ixn->calculateForce(positions, forces, data.threadForce, boxVectors, data);
}

// Rebuild the interaction object so that it reflects the force's current parameters.
//
// context  not used by this implementation
// force    the GayBerneForce to rebuild the interaction object from
void CpuCalcGayBerneForceKernel::copyParametersToContext(ContextImpl& context, const GayBerneForce& force) {
    // Construct the replacement before releasing the current object.  If construction
    // throws, the kernel keeps its existing, still valid interaction object instead of
    // being left with a null pointer.
    CpuGayBerneForce* replacement = new CpuGayBerneForce(force);
    delete ixn;
    ixn = replacement;
}

1332
1333
1334
1335
1336
1337
1338
1339
1340
// Destroy the dynamics object held by this kernel.
CpuIntegrateLangevinStepKernel::~CpuIntegrateLangevinStepKernel() {
    // Deleting a null pointer does nothing, so no explicit null test is required.
    delete dynamics;
}

void CpuIntegrateLangevinStepKernel::initialize(const System& system, const LangevinIntegrator& integrator) {
    int numParticles = system.getNumParticles();
    masses.resize(numParticles);
    for (int i = 0; i < numParticles; ++i)
peastman's avatar
peastman committed
1341
        masses[i] = static_cast<double>(system.getParticleMass(i));
1342
1343
1344
1345
1346
1347
1348
    data.random.initialize(integrator.getRandomNumberSeed(), data.threads.getNumThreads());
}

// Advance the simulation by one step with the Langevin integrator.
//
// context     supplies positions, velocities, forces, constraints and the platform data
// integrator  supplies the temperature, friction, step size and constraint tolerance
void CpuIntegrateLangevinStepKernel::execute(ContextImpl& context, const LangevinIntegrator& integrator) {
    const double temperature = integrator.getTemperature();
    const double friction = integrator.getFriction();
    const double stepSize = integrator.getStepSize();
    vector<Vec3>& posData = extractPositions(context);
    vector<Vec3>& velData = extractVelocities(context);
    vector<Vec3>& forceData = extractForces(context);

    // The dynamics object is (re)built on first use and whenever one of the integrator
    // settings differs from the values it was last built with.  The test for a missing
    // object comes first, so the remembered values are only read once an object exists.

    const bool mustRebuild = !dynamics || temperature != prevTemp || friction != prevFriction || stepSize != prevStepSize;
    if (mustRebuild) {
        // Recreate the computation objects with the new parameters.

        delete dynamics;
        dynamics = new CpuLangevinDynamics(context.getSystem().getNumParticles(), stepSize, friction, temperature, data.threads, data.random);
        dynamics->setReferenceConstraintAlgorithm(&extractConstraints(context));
        prevTemp = temperature;
        prevFriction = friction;
        prevStepSize = stepSize;
    }
    dynamics->update(context.getSystem(), posData, velData, forceData, masses, integrator.getConstraintTolerance());

    // Advance the simulation clock and the step counter kept in the platform data.

    ReferencePlatform::PlatformData* platformData = reinterpret_cast<ReferencePlatform::PlatformData*>(context.getPlatformData());
    platformData->time += stepSize;
    platformData->stepCount++;
}

// Compute the kinetic energy by calling computeShiftedKineticEnergy() with a time
// shift of half the integrator's step size.
//
// context     the context passed on to computeShiftedKineticEnergy()
// integrator  supplies the step size
double CpuIntegrateLangevinStepKernel::computeKineticEnergy(ContextImpl& context, const LangevinIntegrator& integrator) {
    const double halfStep = 0.5*integrator.getStepSize();
    return computeShiftedKineticEnergy(context, masses, halfStep);
}
1372

1373
// Destroy the dynamics object held by this kernel.
CpuIntegrateLangevinMiddleStepKernel::~CpuIntegrateLangevinMiddleStepKernel() {
    // Deleting a null pointer does nothing, so no explicit null test is required.
    delete dynamics;
}

1378
void CpuIntegrateLangevinMiddleStepKernel::initialize(const System& system, const LangevinMiddleIntegrator& integrator) {
1379
1380
1381
1382
1383
1384
1385
    int numParticles = system.getNumParticles();
    masses.resize(numParticles);
    for (int i = 0; i < numParticles; ++i)
        masses[i] = system.getParticleMass(i);
    data.random.initialize(integrator.getRandomNumberSeed(), data.threads.getNumThreads());
}

1386
// Advance the simulation by one step with the Langevin "middle" integrator.
//
// context     supplies positions, velocities, constraints and the platform data
// integrator  supplies the temperature, friction, step size and constraint tolerance
void CpuIntegrateLangevinMiddleStepKernel::execute(ContextImpl& context, const LangevinMiddleIntegrator& integrator) {
    const double temperature = integrator.getTemperature();
    const double friction = integrator.getFriction();
    const double stepSize = integrator.getStepSize();
    vector<Vec3>& posData = extractPositions(context);
    vector<Vec3>& velData = extractVelocities(context);

    // The dynamics object is (re)built on first use and whenever one of the integrator
    // settings differs from the values it was last built with.  The test for a missing
    // object comes first, so the remembered values are only read once an object exists.

    const bool mustRebuild = !dynamics || temperature != prevTemp || friction != prevFriction || stepSize != prevStepSize;
    if (mustRebuild) {
        // Recreate the computation objects with the new parameters.

        delete dynamics;
        dynamics = new CpuLangevinMiddleDynamics(context.getSystem().getNumParticles(), stepSize, friction, temperature, data.threads, data.random);
        dynamics->setReferenceConstraintAlgorithm(&extractConstraints(context));
        prevTemp = temperature;
        prevFriction = friction;
        prevStepSize = stepSize;
    }
    dynamics->update(context, posData, velData, masses, integrator.getConstraintTolerance());

    // Advance the simulation clock and the step counter kept in the platform data.

    ReferencePlatform::PlatformData* platformData = reinterpret_cast<ReferencePlatform::PlatformData*>(context.getPlatformData());
    platformData->time += stepSize;
    platformData->stepCount++;
}

1409
// Compute the kinetic energy by calling computeShiftedKineticEnergy() with a time
// shift of zero.
//
// context     the context passed on to computeShiftedKineticEnergy()
// integrator  not used by this implementation
double CpuIntegrateLangevinMiddleStepKernel::computeKineticEnergy(ContextImpl& context, const LangevinMiddleIntegrator& integrator) {
    const double noShift = 0.0;
    return computeShiftedKineticEnergy(context, masses, noShift);
}